refactor(benchmarks): modularize benchmark system with config-driven execution

This commit is contained in:
2026-03-02 16:00:36 +08:00
parent a7b01cb49e
commit a16b376dd7
14 changed files with 779 additions and 180 deletions

View File

@@ -1,59 +1,12 @@
from typing import Literal, cast
"""Benchmark evaluation module.
import torch
from compressors import DinoCompressor, FloatCompressor
from configs import cfg_manager
from transformers import AutoImageProcessor, BitImageProcessorFast
from utils import get_device
This module provides a modular benchmark system that supports:
- Multiple dataset sources (HuggingFace, local)
- Multiple evaluation tasks (retrieval, with extensibility for more)
- Configuration-driven execution
"""
from .task_eval import task_eval
from .base import BaseBenchmarkTask, BaseDataset
from .runner import run_benchmark
def evaluate(
compressor_model: Literal["Dinov2", "Dinov2WithCompressor"],
dataset: Literal["CIFAR-10", "CIFAR-100"],
benchmark: Literal["Recall@1", "Recall@10"],
):
"""运行模型评估。
Args:
compressor_model: 压缩模型类型。
dataset: 数据集名称。
benchmark: 评估指标。
"""
device = get_device()
match compressor_model:
case "Dinov2":
processor = cast(
BitImageProcessorFast,
AutoImageProcessor.from_pretrained(
"facebook/dinov2-large", device_map=device
),
)
model = DinoCompressor().to(device)
case "Dinov2WithCompressor":
processor = cast(
BitImageProcessorFast,
AutoImageProcessor.from_pretrained(
"facebook/dinov2-large", device_map=device
),
)
output_dir = cfg_manager.get().output.directory
compressor = FloatCompressor()
compressor.load_state_dict(torch.load(output_dir / "compressor.pt"))
model = DinoCompressor(compressor).to(device)
case _:
raise ValueError(f"Unknown compressor: {compressor_model}")
# 根据 benchmark 确定 top_k
match benchmark:
case "Recall@1":
task_eval(processor, model, dataset, compressor_model, top_k=1)
case "Recall@10":
task_eval(processor, model, dataset, compressor_model, top_k=10)
case _:
raise ValueError(f"Unknown benchmark: {benchmark}")
__all__ = ["task_eval", "evaluate"]
__all__ = ["BaseBenchmarkTask", "BaseDataset", "run_benchmark"]

View File

@@ -0,0 +1,99 @@
"""Base classes for benchmark datasets and tasks."""
from abc import ABC, abstractmethod
from typing import Any, Protocol
import lancedb
from torch.utils.data import DataLoader
class BaseDataset(ABC):
    """Interface that every benchmark dataset loader must implement.

    Concrete subclasses wrap a particular data source (HuggingFace hub,
    local files, ...) and expose its train and test partitions.
    """

    @abstractmethod
    def get_train_split(self) -> Any:
        """Return the dataset partition used to build the feature database."""
        ...

    @abstractmethod
    def get_test_split(self) -> Any:
        """Return the dataset partition used for evaluation."""
        ...
class BaseBenchmarkTask(ABC):
    """Interface for a benchmark evaluation task.

    A task knows how to populate a feature database from a training split
    and how to score a model against a test split.
    """

    def __init__(self, **kwargs: Any):
        """Store task-specific configuration.

        Args:
            **kwargs: Arbitrary task configuration parameters, kept in
                ``self.config`` for subclasses to consult.
        """
        self.config = kwargs

    @abstractmethod
    def build_database(
        self,
        model: Any,
        processor: Any,
        train_dataset: Any,
        table: lancedb.table.Table,
        batch_size: int,
    ) -> None:
        """Populate *table* with features extracted from *train_dataset*.

        Args:
            model: Feature extraction model.
            processor: Image preprocessor.
            train_dataset: Training dataset.
            table: LanceDB table that receives the feature rows.
            batch_size: Batch size for DataLoader.
        """
        ...

    @abstractmethod
    def evaluate(
        self,
        model: Any,
        processor: Any,
        test_dataset: Any,
        table: lancedb.table.Table,
        batch_size: int,
    ) -> dict[str, Any]:
        """Score *model* on *test_dataset* against the stored features.

        Args:
            model: Feature extraction model.
            processor: Image preprocessor.
            test_dataset: Test dataset.
            table: LanceDB table to search against.
            batch_size: Batch size for DataLoader.

        Returns:
            Dictionary containing evaluation results.
        """
        ...
class DatasetFactory(Protocol):
    """Structural type for callables that build a dataset from configuration."""

    def __call__(self, config: Any) -> BaseDataset:
        """Build a dataset.

        Args:
            config: Dataset source configuration.

        Returns:
            A ``BaseDataset`` instance.
        """
        ...

View File

@@ -0,0 +1,6 @@
"""Dataset loaders for benchmark evaluation."""
from .huggingface import HuggingFaceDataset
from .local import LocalDataset
__all__ = ["HuggingFaceDataset", "LocalDataset"]

View File

@@ -0,0 +1,66 @@
"""HuggingFace dataset loader for benchmark evaluation."""
from typing import Any
from datasets import load_dataset
from ..base import BaseDataset
class HuggingFaceDataset(BaseDataset):
    """Benchmark dataset backed by a dataset hosted on the HuggingFace hub."""

    def __init__(
        self,
        hf_id: str,
        img_column: str = "img",
        label_column: str = "label",
    ):
        """Configure the loader; nothing is downloaded until a split is requested.

        Args:
            hf_id: HuggingFace dataset ID.
            img_column: Name of the image column.
            label_column: Name of the label column.
        """
        self.hf_id = hf_id
        self.img_column = img_column
        self.label_column = label_column
        # Split caches, filled on first access.
        self._train_dataset: Any = None
        self._test_dataset: Any = None

    def _load(self) -> tuple[Any, Any]:
        """Download the dataset once and cache its splits.

        Returns:
            Tuple of (train_dataset, test_dataset); either may stay None
            when the hub dataset lacks the corresponding split.
        """
        if self._train_dataset is None:
            splits = load_dataset(self.hf_id)
            if "train" in splits:
                self._train_dataset = splits["train"]
            # Prefer an explicit 'test' split; fall back to 'validation'
            # for datasets that only publish train/validation.
            if "test" in splits:
                self._test_dataset = splits["test"]
            elif "validation" in splits:
                self._test_dataset = splits["validation"]
        return self._train_dataset, self._test_dataset

    def get_train_split(self) -> Any:
        """Return the training split, loading it on first use."""
        return self._load()[0]

    def get_test_split(self) -> Any:
        """Return the test/evaluation split, loading it on first use."""
        return self._load()[1]

View File

@@ -0,0 +1,157 @@
"""Local dataset loader for benchmark evaluation."""
from pathlib import Path
from typing import Any, Optional
from ..base import BaseDataset
class LocalDataset(BaseDataset):
    """Dataset loader for local datasets (a CSV file or a class-per-directory tree)."""

    def __init__(
        self,
        local_path: str,
        img_column: str = "image_path",
        label_column: str = "label",
    ):
        """Initialize local dataset loader.

        Args:
            local_path: Path to local dataset directory or CSV file.
            img_column: Name of the image path column.
            label_column: Name of the label column.
        """
        self.local_path = Path(local_path)
        self.img_column = img_column
        self.label_column = label_column
        # Split caches; populated lazily by the first get_*_split() call.
        self._train_dataset: Optional[Any] = None
        self._test_dataset: Optional[Any] = None

    def _load_csv_dataset(self) -> tuple[Any, Any]:
        """Load dataset from CSV file.

        Expected CSV format:
            label,image_path,x1,y1,x2,y2
            "class_name","path/to/image.jpg",100,200,300,400

        Returns:
            Tuple of (train_dataset, test_dataset).
        """
        # Imported lazily so pandas/torch are only required for CSV datasets.
        import pandas as pd
        from torch.utils.data import Dataset as TorchDataset

        # Load CSV file
        df = pd.read_csv(self.local_path)

        # Create a simple dataset class
        class CSVDataset(TorchDataset):
            def __init__(self, dataframe: pd.DataFrame, img_col: str, label_col: str):
                self.df = dataframe.reset_index(drop=True)
                self.img_col = img_col
                self.label_col = label_col

            def __len__(self) -> int:
                return len(self.df)

            def __getitem__(self, idx: int) -> dict[str, Any]:
                # NOTE(review): "img" holds the raw CSV cell (an image *path*,
                # not a decoded image) — confirm downstream processors accept paths.
                row = self.df.iloc[idx]
                return {
                    "img": row[self.img_col],
                    "label": row[self.label_col],
                }

        # Split into train/test (80/20)
        # NOTE(review): split is positional, not shuffled — assumes the CSV
        # rows are not ordered by label; verify for real datasets.
        split_idx = int(len(df) * 0.8)
        train_df = df.iloc[:split_idx]
        test_df = df.iloc[split_idx:]
        self._train_dataset = CSVDataset(train_df, self.img_column, self.label_column)
        self._test_dataset = CSVDataset(test_df, self.img_column, self.label_column)
        return self._train_dataset, self._test_dataset

    def _load_directory_dataset(self) -> tuple[Any, Any]:
        """Load dataset from directory structure.

        Expected structure:
            local_path/
                train/
                    class_name_1/
                        image1.jpg
                        image2.jpg
                    class_name_2/
                        image1.jpg
                test/
                    class_name_1/
                        image1.jpg

        Returns:
            Tuple of (train_dataset, test_dataset). Either element may be
            None when the corresponding train/ or test/ directory is absent.
        """
        from torch.utils.data import Dataset as TorchDataset
        from PIL import Image

        class DirectoryDataset(TorchDataset):
            def __init__(self, root_dir: Path, transform=None):
                self.root_dir = root_dir
                self.transform = transform  # NOTE(review): stored but never applied in __getitem__
                self.samples = []
                self.label_map = {}
                # Build label map
                # Sorted so the class -> index assignment is deterministic.
                classes = sorted([d.name for d in root_dir.iterdir() if d.is_dir()])
                self.label_map = {cls: idx for idx, cls in enumerate(classes)}
                # Build sample list
                for cls_dir in root_dir.iterdir():
                    if cls_dir.is_dir():
                        label = self.label_map[cls_dir.name]
                        for img_path in cls_dir.iterdir():
                            if img_path.suffix.lower() in [".jpg", ".jpeg", ".png", ".bmp"]:
                                self.samples.append((img_path, label))

            def __len__(self) -> int:
                return len(self.samples)

            def __getitem__(self, idx: int) -> dict[str, Any]:
                img_path, label = self.samples[idx]
                # Images are decoded on access and normalized to RGB.
                image = Image.open(img_path).convert("RGB")
                return {"img": image, "label": label}

        train_dir = self.local_path / "train"
        test_dir = self.local_path / "test"
        if train_dir.exists():
            self._train_dataset = DirectoryDataset(train_dir)
        if test_dir.exists():
            self._test_dataset = DirectoryDataset(test_dir)
        return self._train_dataset, self._test_dataset

    def get_train_split(self) -> Any:
        """Get training split of the dataset.

        Returns:
            Training dataset.
        """
        # A .csv suffix selects the CSV loader; any other path is treated
        # as a class-per-directory layout.
        if self._train_dataset is None:
            if self.local_path.suffix.lower() == ".csv":
                self._load_csv_dataset()
            else:
                self._load_directory_dataset()
        return self._train_dataset

    def get_test_split(self) -> Any:
        """Get test/evaluation split of the dataset.

        Returns:
            Test dataset.
        """
        if self._test_dataset is None:
            if self.local_path.suffix.lower() == ".csv":
                self._load_csv_dataset()
            else:
                self._load_directory_dataset()
        return self._test_dataset

View File

@@ -0,0 +1,186 @@
"""Benchmark runner for executing evaluations."""
from pathlib import Path
from typing import Any
import lancedb
from benchmarks.datasets import HuggingFaceDataset, LocalDataset
from benchmarks.tasks import get_task
from configs.models import BenchmarkConfig, DatasetSourceConfig
def create_dataset(config: DatasetSourceConfig) -> Any:
"""Create a dataset instance from configuration.
Args:
config: Dataset source configuration.
Returns:
Dataset instance.
Raises:
ValueError: If source_type is not supported.
"""
if config.source_type == "huggingface":
return HuggingFaceDataset(
hf_id=config.path,
img_column=config.img_column,
label_column=config.label_column,
)
elif config.source_type == "local":
return LocalDataset(
local_path=config.path,
img_column=config.img_column,
label_column=config.label_column,
)
else:
raise ValueError(
f"Unsupported source_type: {config.source_type}. "
f"Supported types: 'huggingface', 'local'"
)
def _get_table_name(config: BenchmarkConfig, model_name: str) -> str:
    """Derive the LanceDB table name for this benchmark run.

    The name is ``<prefix>_<dataset>_<model>`` where the dataset part is
    the final component of the configured dataset path, lowercased, with
    ``-`` replaced by ``_``.

    Args:
        config: Benchmark configuration.
        model_name: Model name for table naming.

    Returns:
        Formatted table name.
    """
    dataset_part = Path(config.dataset.path).name
    dataset_part = dataset_part.lower().replace("-", "_")
    return "_".join([config.model_table_prefix, dataset_part, model_name])
def _ensure_table(
    config: BenchmarkConfig,
    model_name: str,
    vector_dim: int,
) -> lancedb.table.Table:
    """Open the benchmark feature table, (re)creating it when needed.

    Args:
        config: Benchmark configuration.
        model_name: Model name for table naming.
        vector_dim: Feature vector dimension.

    Returns:
        A LanceDB table whose schema matches the expected layout.
    """
    import pyarrow as pa
    from database import db_manager

    table_name = _get_table_name(config, model_name)
    # Expected layout: integer id/label plus a fixed-size float32 vector.
    schema = pa.schema(
        [
            pa.field("id", pa.int32()),
            pa.field("label", pa.int32()),
            pa.field("vector", pa.list_(pa.float32(), vector_dim)),
        ]
    )
    db = db_manager.db
    if table_name not in db.list_tables().tables:
        return db.create_table(table_name, schema=schema)
    table = db.open_table(table_name)
    if table.schema == schema:
        return table
    # Stale layout (e.g. a different vector_dim): rebuild from scratch.
    print(f"Table '{table_name}' schema mismatch, rebuilding.")
    db.drop_table(table_name)
    return db.create_table(table_name, schema=schema)
def run_benchmark(
    model: Any,
    processor: Any,
    config: BenchmarkConfig,
    model_name: str = "model",
) -> dict[str, Any]:
    """Run benchmark evaluation.

    Workflow:
        1. Create dataset from configuration
        2. Create benchmark task from configuration
        3. Build evaluation database from training set (skipped when the
           table already contains rows)
        4. Evaluate on test set

    Args:
        model: Feature extraction model.
        processor: Image preprocessor.
        config: Benchmark configuration.
        model_name: Model name for table naming.

    Returns:
        Dictionary containing evaluation results.

    Raises:
        ValueError: If benchmark is not enabled in config, or the dataset
            lacks train/test splits.
    """
    if not config.enabled:
        raise ValueError(
            "Benchmark is not enabled. Set benchmark.enabled=true in config.yaml"
        )

    # Create dataset
    print(f"Loading dataset: {config.dataset.source_type} - {config.dataset.path}")
    dataset = create_dataset(config.dataset)

    # Get train and test splits
    train_dataset = dataset.get_train_split()
    test_dataset = dataset.get_test_split()
    if train_dataset is None or test_dataset is None:
        raise ValueError(
            f"Dataset {config.dataset.path} does not have train/test splits"
        )

    # Infer vector dimension from a sample
    sample = train_dataset[0]
    sample_image = sample["img"]
    from .tasks.retrieval import _infer_vector_dim

    vector_dim = _infer_vector_dim(processor, model, sample_image)
    print(f"Model output dimension: {vector_dim}")

    # Ensure table exists with correct schema
    table = _ensure_table(config, model_name, vector_dim)
    table_name = _get_table_name(config, model_name)

    # FIX: construct the task once and reuse it for both phases
    # (previously get_task was called twice, creating two instances).
    task = get_task(config.task.type, top_k=config.task.top_k)

    # Build the database only when the table is still empty.
    table_count = table.count_rows()
    if table_count > 0:
        print(
            f"Table '{table_name}' already has {table_count} entries, skipping database build."
        )
    else:
        print(f"Building database with {len(train_dataset)} training samples...")
        task.build_database(model, processor, train_dataset, table, config.batch_size)

    # Run evaluation
    print(f"Evaluating on {len(test_dataset)} test samples...")
    results = task.evaluate(model, processor, test_dataset, table, config.batch_size)

    # Print results
    print("\n=== Benchmark Results ===")
    print(f"Dataset: {config.dataset.path}")
    print(f"Task: {config.task.name}")
    print(f"Top-K: {results['top_k']}")
    print(f"Accuracy: {results['accuracy']:.4f}")
    print(f"Correct: {results['correct']}/{results['total']}")
    return results

View File

@@ -0,0 +1,6 @@
"""Benchmark evaluation tasks."""
from .retrieval import RetrievalTask
from .registry import TASK_REGISTRY, get_task
__all__ = ["RetrievalTask", "TASK_REGISTRY", "get_task"]

View File

@@ -0,0 +1,59 @@
"""Task registry for benchmark evaluation."""
from typing import Any, Type
from benchmarks.base import BaseBenchmarkTask
# Task registry: maps task type string to task class.
# Uses builtin `type[...]` for consistency with RegisterTask.__call__.
TASK_REGISTRY: dict[str, type[BaseBenchmarkTask]] = {}
class RegisterTask:
    """Decorator class to register a benchmark task.

    Usage:
        @RegisterTask("retrieval")
        class RetrievalTask(BaseBenchmarkTask):
            ...
    """

    def __init__(self, task_type: str):
        """Initialize the decorator with task type.

        Args:
            task_type: Task type identifier.
        """
        self.task_type = task_type

    def __call__(self, cls: type[BaseBenchmarkTask]) -> type[BaseBenchmarkTask]:
        """Register the decorated class to the task registry.

        Args:
            cls: The class to be decorated.

        Returns:
            The unmodified class.
        """
        # NOTE: a later registration for the same task_type silently
        # overrides an earlier one.
        TASK_REGISTRY[self.task_type] = cls
        return cls
def get_task(task_type: str, **kwargs: Any) -> BaseBenchmarkTask:
    """Look up *task_type* in the registry and instantiate it.

    Args:
        task_type: Task type identifier.
        **kwargs: Additional arguments passed to task constructor.

    Returns:
        Task instance.

    Raises:
        ValueError: If task type is not registered.
    """
    if task_type in TASK_REGISTRY:
        return TASK_REGISTRY[task_type](**kwargs)
    available = list(TASK_REGISTRY.keys())
    raise ValueError(
        f"Unknown task type: {task_type}. Available tasks: {available}"
    )

View File

@@ -1,41 +1,22 @@
from typing import Dict, Literal, cast
"""Retrieval task for benchmark evaluation (Recall@K)."""
from typing import Any, cast
import lancedb
import pyarrow as pa
import torch
from database import db_manager
from datasets import load_dataset
from benchmarks.base import BaseBenchmarkTask
from benchmarks.tasks.registry import RegisterTask
from torch import nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import BitImageProcessorFast
# 数据集配置:数据集名称 -> (HuggingFace ID, 图片列名, 标签列名)
DATASET_CONFIG: Dict[str, tuple[str, str, str]] = {
"CIFAR-10": ("uoft-cs/cifar10", "img", "label"),
"CIFAR-100": ("uoft-cs/cifar100", "img", "fine_label"),
}
def _get_table_name(dataset: str, model_name: str) -> str:
"""Generate database table name from dataset and model name.
Args:
dataset: Dataset name, e.g. "CIFAR-10".
model_name: Model name, e.g. "Dinov2".
Returns:
Formatted table name, e.g. "cifar10_dinov2".
"""
ds_part = dataset.lower().replace("-", "")
model_part = model_name.lower()
return f"{ds_part}_{model_part}"
def _infer_vector_dim(
processor: BitImageProcessorFast,
model: nn.Module,
sample_image,
sample_image: Any,
) -> int:
"""Infer model output vector dimension via a single forward pass.
@@ -55,7 +36,6 @@ def _infer_vector_dim(
inputs.to(device)
output = model(inputs)
# output shape: [1, dim]
return output.shape[-1]
@@ -100,16 +80,13 @@ def _establish_eval_database(
imgs = batch["img"]
labels = batch["label"]
# 预处理并推理
inputs = processor(imgs, return_tensors="pt")
inputs.to(device)
outputs = model(inputs) # [B, dim]
outputs = model(inputs)
# 整个batch一次性转到CPU
features = cast(torch.Tensor, outputs).cpu()
labels_list = labels.tolist()
# 逐条写入数据库
batch_size = len(labels_list)
table.add(
[
@@ -134,9 +111,6 @@ def _evaluate_recall(
) -> tuple[int, int]:
"""Evaluate Recall@K by searching the database for each test image.
For each batch, features are extracted in one forward pass and moved to CPU,
then each sample is searched individually against the database.
Args:
processor: Image preprocessor.
model: Feature extraction model.
@@ -157,21 +131,17 @@ def _evaluate_recall(
imgs = batch["img"]
labels = batch["label"]
# 批量前向推理
inputs = processor(imgs, return_tensors="pt")
inputs.to(device)
outputs = model(inputs) # [B, dim]
outputs = model(inputs)
# 整个batch一次性转到CPU
features = cast(torch.Tensor, outputs).cpu()
labels_list = labels.tolist()
# 逐条搜索并验证
for j in range(len(labels_list)):
feature = features[j].tolist()
true_label = labels_list[j]
# 搜索 top_k 最相似结果
results = (
table.search(feature)
.select(["label", "_distance"])
@@ -179,7 +149,6 @@ def _evaluate_recall(
.to_polars()
)
# 检查 top_k 中是否包含正确标签
retrieved_labels = results["label"].to_list()
if true_label in retrieved_labels:
correct += 1
@@ -188,69 +157,51 @@ def _evaluate_recall(
return correct, total
def task_eval(
processor: BitImageProcessorFast,
model: nn.Module,
dataset: Literal["CIFAR-10", "CIFAR-100"],
model_name: str,
top_k: int = 10,
batch_size: int = 64,
) -> float:
"""Evaluate model Recall@K accuracy on a dataset using vector retrieval.
@RegisterTask("retrieval")
class RetrievalTask(BaseBenchmarkTask):
"""Retrieval evaluation task (Recall@K)."""
Workflow:
1. Create or open a database table named by dataset and model.
2. Build database from training set features (skip if table exists).
3. Evaluate on test set: extract features in batches, search top_k,
check if correct label appears in results.
def __init__(self, top_k: int = 10):
"""Initialize retrieval task.
Args:
processor: Image preprocessor.
model: Feature extraction model.
dataset: Dataset name.
model_name: Model name, used for table name generation.
top_k: Number of top similar results to retrieve.
batch_size: Batch size for DataLoader.
Args:
top_k: Number of top results to retrieve for recall calculation.
"""
super().__init__(top_k=top_k)
self.top_k = top_k
Returns:
Recall@K accuracy (0.0 ~ 1.0).
def build_database(
self,
model: Any,
processor: Any,
train_dataset: Any,
table: lancedb.table.Table,
batch_size: int,
) -> None:
"""Build the evaluation database from training data.
Raises:
ValueError: If dataset name is not supported.
"""
if dataset not in DATASET_CONFIG:
raise ValueError(
f"Unknown dataset: {dataset}. Only support: {list(DATASET_CONFIG.keys())}."
)
hf_id, img_col, label_col = DATASET_CONFIG[dataset]
Args:
model: Feature extraction model.
processor: Image preprocessor.
train_dataset: Training dataset.
table: LanceDB table to store features.
batch_size: Batch size for DataLoader.
"""
# Get a sample image to infer vector dimension
sample = train_dataset[0]
sample_image = sample["img"]
# 加载数据集
train_dataset = load_dataset(hf_id, split="train")
test_dataset = load_dataset(hf_id, split="test")
vector_dim = _infer_vector_dim(processor, model, sample_image)
expected_schema = _build_eval_schema(vector_dim)
# 生成表名,推断向量维度
table_name = _get_table_name(dataset, model_name)
vector_dim = _infer_vector_dim(processor, model, train_dataset[0][img_col])
expected_schema = _build_eval_schema(vector_dim)
existing_tables = db_manager.db.list_tables().tables
# Check schema compatibility
if table.schema != expected_schema:
raise ValueError(
f"Table schema mismatch. Expected: {expected_schema}, "
f"Got: {table.schema}"
)
# 如果旧表 schema 不匹配(如 label 类型变更),删除重建
if table_name in existing_tables:
old_table = db_manager.db.open_table(table_name)
if old_table.schema != expected_schema:
print(f"Table '{table_name}' schema mismatch, rebuilding.")
db_manager.db.drop_table(table_name)
existing_tables = []
if table_name in existing_tables:
# 表已存在且 schema 匹配,跳过建库
print(f"Table '{table_name}' already exists, skipping database build.")
table = db_manager.db.open_table(table_name)
else:
# 创建新表
table = db_manager.db.create_table(table_name, schema=expected_schema)
# 使用 DataLoader 批量建库
# Build database
train_loader = DataLoader(
train_dataset.with_format("torch"),
batch_size=batch_size,
@@ -259,17 +210,45 @@ def task_eval(
)
_establish_eval_database(processor, model, table, train_loader)
# 使用 DataLoader 批量评估
test_loader = DataLoader(
test_dataset.with_format("torch"),
batch_size=batch_size,
shuffle=False,
num_workers=4,
)
correct, total = _evaluate_recall(processor, model, table, test_loader, top_k)
def evaluate(
self,
model: Any,
processor: Any,
test_dataset: Any,
table: lancedb.table.Table,
batch_size: int,
) -> dict[str, Any]:
"""Evaluate the model on the test dataset.
accuracy = correct / total
print(f"\nRecall@{top_k} on {dataset} with {model_name}: {accuracy:.4f}")
print(f"Correct: {correct}/{total}")
Args:
model: Feature extraction model.
processor: Image preprocessor.
test_dataset: Test dataset.
table: LanceDB table to search against.
batch_size: Batch size for DataLoader.
return accuracy
Returns:
Dictionary containing evaluation results with keys:
- accuracy: Recall@K accuracy (0.0 ~ 1.0)
- correct: Number of correct predictions
- total: Total number of test samples
- top_k: The K value used
"""
test_loader = DataLoader(
test_dataset.with_format("torch"),
batch_size=batch_size,
shuffle=False,
num_workers=4,
)
correct, total = _evaluate_recall(
processor, model, table, test_loader, self.top_k
)
accuracy = correct / total if total > 0 else 0.0
return {
"accuracy": accuracy,
"correct": correct,
"total": total,
"top_k": self.top_k,
}

View File

@@ -2,10 +2,10 @@ model:
name: "facebook/dinov2-large"
compression_dim: 512
device: "auto" # auto-detect GPU
sam_model: "facebook/sam2.1-hiera-large" # SAM model name
sam_min_mask_area: 100 # Minimum mask area threshold
sam_max_masks: 10 # Maximum number of masks to keep
compressor_path: null # Path to trained HashCompressor weights (optional)
sam_model: "facebook/sam2.1-hiera-large" # SAM model name
sam_min_mask_area: 100 # Minimum mask area threshold
sam_max_masks: 10 # Maximum number of masks to keep
compressor_path: null # Path to trained HashCompressor weights (optional)
output:
directory: "./outputs"
@@ -19,3 +19,17 @@ dataset:
rotation_range: [-30, 30]
overlap_threshold: 0.3
seed: 42
benchmark:
enabled: true
dataset:
source_type: "huggingface"
path: "uoft-cs/cifar10"
img_column: "img"
label_column: "label"
task:
name: "recall_at_k"
type: "retrieval"
top_k: 10
batch_size: 64
model_table_prefix: "benchmark"

View File

@@ -1,7 +1,7 @@
"""Pydantic data models for feature compressor configuration."""
from pathlib import Path
from typing import Optional
from typing import Literal, Optional
from pydantic import BaseModel, ConfigDict, Field, field_validator
@@ -98,6 +98,41 @@ class DatasetConfig(BaseModel):
return v
class DatasetSourceConfig(BaseModel):
"""Configuration for benchmark dataset source."""
model_config = ConfigDict(extra="ignore")
source_type: Literal["huggingface", "local"] = "huggingface"
path: str = Field(default="", description="HuggingFace dataset ID or local path")
img_column: str = Field(default="img", description="Image column name")
label_column: str = Field(default="label", description="Label column name")
class BenchmarkTaskConfig(BaseModel):
"""Configuration for benchmark task."""
model_config = ConfigDict(extra="ignore")
name: str = Field(default="recall_at_k", description="Task name")
type: str = Field(default="retrieval", description="Task type")
top_k: int = Field(default=10, gt=0, description="Top K for recall evaluation")
class BenchmarkConfig(BaseModel):
"""Configuration for benchmark evaluation."""
model_config = ConfigDict(extra="ignore")
enabled: bool = Field(default=False, description="Enable benchmark evaluation")
dataset: DatasetSourceConfig = Field(default_factory=DatasetSourceConfig)
task: BenchmarkTaskConfig = Field(default_factory=BenchmarkTaskConfig)
batch_size: int = Field(default=64, gt=0, description="Batch size for DataLoader")
model_table_prefix: str = Field(
default="benchmark", description="Prefix for LanceDB table names"
)
class Config(BaseModel):
"""Root configuration for the feature compressor."""
@@ -106,3 +141,4 @@ class Config(BaseModel):
model: ModelConfig = Field(default_factory=ModelConfig)
output: OutputConfig = Field(default_factory=OutputConfig)
dataset: DatasetConfig = Field(default_factory=DatasetConfig)
benchmark: BenchmarkConfig = Field(default_factory=BenchmarkConfig)

View File

@@ -16,9 +16,51 @@ if __name__ == "__main__":
epoch_size=10, batch_size=64, lr=1e-4, checkpoint_path="hash_checkpoint.pt"
)
elif args.action == "benchmark":
from benchmarks import evaluate
from typing import cast
evaluate("Dinov2", "CIFAR-10", "Recall@10")
import torch
from benchmarks import run_benchmark
from compressors import DinoCompressor
from configs import cfg_manager
from transformers import AutoImageProcessor, BitImageProcessorFast
from utils import get_device
config = cfg_manager.get()
benchmark_cfg = config.benchmark
if not benchmark_cfg.enabled:
print("Benchmark is not enabled. Set benchmark.enabled=true in config.yaml")
exit(1)
device = get_device()
# Load model and processor based on config
model_cfg = config.model
processor = cast(
BitImageProcessorFast,
AutoImageProcessor.from_pretrained(model_cfg.name, device_map=device),
)
# Load compressor weights if specified in model config
model = DinoCompressor().to(device)
if model_cfg.compressor_path is not None:
from compressors import HashCompressor
compressor = HashCompressor(
input_dim=model_cfg.compression_dim,
output_dim=model_cfg.compression_dim,
)
compressor.load_state_dict(torch.load(model_cfg.compressor_path))
# Wrap with compressor if path is specified
model.compressor = compressor
# Run benchmark
run_benchmark(
model=model,
processor=processor,
config=benchmark_cfg,
model_name="dinov2",
)
elif args.action == "visualize":
from visualizer import app

View File

@@ -1,26 +1,21 @@
"""Tests for compressor modules (SAM, DINO, HashCompressor, Pipeline)."""
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
import torch
from PIL import Image
from configs import cfg_manager
from compressors import (
BinarySign,
DinoCompressor,
HashCompressor,
SegmentCompressor,
SAMHashPipeline,
create_pipeline_from_config,
SegmentCompressor,
bits_to_hash,
hash_to_bits,
create_pipeline_from_config,
hamming_distance,
hamming_similarity,
hash_to_bits,
)
from configs import cfg_manager
from PIL import Image
class TestHashCompressor: