From a16b376dd764eaa62d12ab5a00abeb6f74e7b73d Mon Sep 17 00:00:00 2001 From: SikongJueluo Date: Mon, 2 Mar 2026 16:00:36 +0800 Subject: [PATCH] refactor(benchmarks): modularize benchmark system with config-driven execution --- .gitignore | 1 + mini-nav/benchmarks/__init__.py | 65 +----- mini-nav/benchmarks/base.py | 99 +++++++++ mini-nav/benchmarks/datasets/__init__.py | 6 + mini-nav/benchmarks/datasets/huggingface.py | 66 ++++++ mini-nav/benchmarks/datasets/local.py | 157 ++++++++++++++ mini-nav/benchmarks/runner.py | 186 +++++++++++++++++ mini-nav/benchmarks/tasks/__init__.py | 6 + mini-nav/benchmarks/tasks/registry.py | 59 ++++++ .../{task_eval.py => tasks/retrieval.py} | 193 ++++++++---------- mini-nav/configs/config.yaml | 22 +- mini-nav/configs/models.py | 38 +++- mini-nav/main.py | 46 ++++- mini-nav/tests/test_compressors.py | 15 +- 14 files changed, 779 insertions(+), 180 deletions(-) create mode 100644 mini-nav/benchmarks/base.py create mode 100644 mini-nav/benchmarks/datasets/__init__.py create mode 100644 mini-nav/benchmarks/datasets/huggingface.py create mode 100644 mini-nav/benchmarks/datasets/local.py create mode 100644 mini-nav/benchmarks/runner.py create mode 100644 mini-nav/benchmarks/tasks/__init__.py create mode 100644 mini-nav/benchmarks/tasks/registry.py rename mini-nav/benchmarks/{task_eval.py => tasks/retrieval.py} (50%) diff --git a/.gitignore b/.gitignore index 461da70..54078a0 100644 --- a/.gitignore +++ b/.gitignore @@ -207,6 +207,7 @@ __marimo__/ # Projects datasets/ +!mini-nav/**/datasets/ data/ deps/ outputs/ diff --git a/mini-nav/benchmarks/__init__.py b/mini-nav/benchmarks/__init__.py index 06f74c5..45cbf75 100644 --- a/mini-nav/benchmarks/__init__.py +++ b/mini-nav/benchmarks/__init__.py @@ -1,59 +1,12 @@ -from typing import Literal, cast +"""Benchmark evaluation module. 
-import torch
-from compressors import DinoCompressor, FloatCompressor
-from configs import cfg_manager
-from transformers import AutoImageProcessor, BitImageProcessorFast
-from utils import get_device
+This module provides a modular benchmark system that supports:
+- Multiple dataset sources (HuggingFace, local)
+- Multiple evaluation tasks (retrieval, with extensibility for more)
+- Configuration-driven execution
+"""
 
-from .task_eval import task_eval
+from .base import BaseBenchmarkTask, BaseDataset
+from .runner import run_benchmark
 
-
-def evaluate(
-    compressor_model: Literal["Dinov2", "Dinov2WithCompressor"],
-    dataset: Literal["CIFAR-10", "CIFAR-100"],
-    benchmark: Literal["Recall@1", "Recall@10"],
-):
-    """Run model evaluation.
-
-    Args:
-        compressor_model: Compressor model type.
-        dataset: Dataset name.
-        benchmark: Evaluation metric.
-    """
-    device = get_device()
-
-    match compressor_model:
-        case "Dinov2":
-            processor = cast(
-                BitImageProcessorFast,
-                AutoImageProcessor.from_pretrained(
-                    "facebook/dinov2-large", device_map=device
-                ),
-            )
-            model = DinoCompressor().to(device)
-        case "Dinov2WithCompressor":
-            processor = cast(
-                BitImageProcessorFast,
-                AutoImageProcessor.from_pretrained(
-                    "facebook/dinov2-large", device_map=device
-                ),
-            )
-            output_dir = cfg_manager.get().output.directory
-            compressor = FloatCompressor()
-            compressor.load_state_dict(torch.load(output_dir / "compressor.pt"))
-            model = DinoCompressor(compressor).to(device)
-        case _:
-            raise ValueError(f"Unknown compressor: {compressor_model}")
-
-    # Determine top_k from the benchmark name
-    match benchmark:
-        case "Recall@1":
-            task_eval(processor, model, dataset, compressor_model, top_k=1)
-        case "Recall@10":
-            task_eval(processor, model, dataset, compressor_model, top_k=10)
-        case _:
-            raise ValueError(f"Unknown benchmark: {benchmark}")
-
-
-__all__ = ["task_eval", "evaluate"]
+__all__ = ["BaseBenchmarkTask", "BaseDataset", "run_benchmark"]
\ No newline at end of file
diff --git a/mini-nav/benchmarks/base.py b/mini-nav/benchmarks/base.py
new file mode 100644
index 0000000..e013393
--- /dev/null
+++ b/mini-nav/benchmarks/base.py
@@ -0,0 +1,99 @@
+"""Base classes for benchmark datasets and tasks."""
+
+from abc import ABC, abstractmethod
+from typing import Any, Protocol
+
+import lancedb
+from torch.utils.data import DataLoader
+
+
+class BaseDataset(ABC):
+    """Abstract base class for benchmark datasets."""
+
+    @abstractmethod
+    def get_train_split(self) -> Any:
+        """Get training split of the dataset.
+
+        Returns:
+            Dataset object for training.
+        """
+        pass
+
+    @abstractmethod
+    def get_test_split(self) -> Any:
+        """Get test/evaluation split of the dataset.
+
+        Returns:
+            Dataset object for testing.
+        """
+        pass
+
+
+class BaseBenchmarkTask(ABC):
+    """Abstract base class for benchmark evaluation tasks."""
+
+    def __init__(self, **kwargs: Any):
+        """Initialize the benchmark task.
+
+        Args:
+            **kwargs: Task-specific configuration parameters.
+        """
+        self.config = kwargs
+
+    @abstractmethod
+    def build_database(
+        self,
+        model: Any,
+        processor: Any,
+        train_dataset: Any,
+        table: lancedb.table.Table,
+        batch_size: int,
+    ) -> None:
+        """Build the evaluation database from training data.
+
+        Args:
+            model: Feature extraction model.
+            processor: Image preprocessor.
+            train_dataset: Training dataset.
+            table: LanceDB table to store features.
+            batch_size: Batch size for DataLoader.
+ """ + pass + + @abstractmethod + def evaluate( + self, + model: Any, + processor: Any, + test_dataset: Any, + table: lancedb.table.Table, + batch_size: int, + ) -> dict[str, Any]: + """Evaluate the model on the test dataset. + + Args: + model: Feature extraction model. + processor: Image preprocessor. + test_dataset: Test dataset. + table: LanceDB table to search against. + batch_size: Batch size for DataLoader. + + Returns: + Dictionary containing evaluation results. + """ + pass + + +class DatasetFactory(Protocol): + """Protocol for dataset factory.""" + + def __call__(self, config: Any) -> BaseDataset: + """Create a dataset from configuration. + + Args: + config: Dataset configuration. + + Returns: + Dataset instance. + """ + ... diff --git a/mini-nav/benchmarks/datasets/__init__.py b/mini-nav/benchmarks/datasets/__init__.py new file mode 100644 index 0000000..cb4eecd --- /dev/null +++ b/mini-nav/benchmarks/datasets/__init__.py @@ -0,0 +1,6 @@ +"""Dataset loaders for benchmark evaluation.""" + +from .huggingface import HuggingFaceDataset +from .local import LocalDataset + +__all__ = ["HuggingFaceDataset", "LocalDataset"] \ No newline at end of file diff --git a/mini-nav/benchmarks/datasets/huggingface.py b/mini-nav/benchmarks/datasets/huggingface.py new file mode 100644 index 0000000..4841a09 --- /dev/null +++ b/mini-nav/benchmarks/datasets/huggingface.py @@ -0,0 +1,66 @@ +"""HuggingFace dataset loader for benchmark evaluation.""" + +from typing import Any + +from datasets import load_dataset + +from ..base import BaseDataset + + +class HuggingFaceDataset(BaseDataset): + """Dataset loader for HuggingFace datasets.""" + + def __init__( + self, + hf_id: str, + img_column: str = "img", + label_column: str = "label", + ): + """Initialize HuggingFace dataset loader. + + Args: + hf_id: HuggingFace dataset ID. + img_column: Name of the image column. + label_column: Name of the label column. + """ + self.hf_id = hf_id + self.img_column = img_column + self.label_column = label_column + self._train_dataset: Any = None + self._test_dataset: Any = None + + def _load(self) -> tuple[Any, Any]: + """Load dataset from HuggingFace. + + Returns: + Tuple of (train_dataset, test_dataset). + """ + if self._train_dataset is None: + dataset = load_dataset(self.hf_id) + # Handle datasets that use 'train' and 'test' splits + if "train" in dataset: + self._train_dataset = dataset["train"] + if "test" in dataset: + self._test_dataset = dataset["test"] + # Handle datasets that use 'train' and 'validation' splits + elif "validation" in dataset: + self._test_dataset = dataset["validation"] + return self._train_dataset, self._test_dataset + + def get_train_split(self) -> Any: + """Get training split of the dataset. + + Returns: + Training dataset. + """ + train, _ = self._load() + return train + + def get_test_split(self) -> Any: + """Get test/evaluation split of the dataset. + + Returns: + Test dataset. 
+        """
+        _, test = self._load()
+        return test
diff --git a/mini-nav/benchmarks/datasets/local.py b/mini-nav/benchmarks/datasets/local.py
new file mode 100644
index 0000000..e3ea3b1
--- /dev/null
+++ b/mini-nav/benchmarks/datasets/local.py
@@ -0,0 +1,157 @@
+"""Local dataset loader for benchmark evaluation."""
+
+from pathlib import Path
+from typing import Any, Optional
+
+from ..base import BaseDataset
+
+
+class LocalDataset(BaseDataset):
+    """Dataset loader for local datasets."""
+
+    def __init__(
+        self,
+        local_path: str,
+        img_column: str = "image_path",
+        label_column: str = "label",
+    ):
+        """Initialize local dataset loader.
+
+        Args:
+            local_path: Path to local dataset directory or CSV file.
+            img_column: Name of the image path column.
+            label_column: Name of the label column.
+        """
+        self.local_path = Path(local_path)
+        self.img_column = img_column
+        self.label_column = label_column
+        self._train_dataset: Optional[Any] = None
+        self._test_dataset: Optional[Any] = None
+
+    def _load_csv_dataset(self) -> tuple[Any, Any]:
+        """Load dataset from CSV file.
+
+        Expected CSV format:
+            label,image_path,x1,y1,x2,y2
+            "class_name","path/to/image.jpg",100,200,300,400
+
+        Returns:
+            Tuple of (train_dataset, test_dataset).
+        """
+        import pandas as pd
+        from PIL import Image
+        from torch.utils.data import Dataset as TorchDataset
+
+        # Load CSV file
+        df = pd.read_csv(self.local_path)
+
+        # Create a simple dataset class
+        class CSVDataset(TorchDataset):
+            def __init__(self, dataframe: pd.DataFrame, img_col: str, label_col: str):
+                self.df = dataframe.reset_index(drop=True)
+                self.img_col = img_col
+                self.label_col = label_col
+
+            def __len__(self) -> int:
+                return len(self.df)
+
+            def __getitem__(self, idx: int) -> dict[str, Any]:
+                row = self.df.iloc[idx]
+                return {
+                    "img": Image.open(row[self.img_col]).convert("RGB"),
+                    "label": row[self.label_col],
+                }
+
+        # Split into train/test (80/20)
+        split_idx = int(len(df) * 0.8)
+        train_df = df.iloc[:split_idx]
+        test_df = df.iloc[split_idx:]
+
+        self._train_dataset = CSVDataset(train_df, self.img_column, self.label_column)
+        self._test_dataset = CSVDataset(test_df, self.img_column, self.label_column)
+
+        return self._train_dataset, self._test_dataset
+
+    def _load_directory_dataset(self) -> tuple[Any, Any]:
+        """Load dataset from directory structure.
+
+        Expected structure:
+            local_path/
+                train/
+                    class_name_1/
+                        image1.jpg
+                        image2.jpg
+                    class_name_2/
+                        image1.jpg
+                test/
+                    class_name_1/
+                        image1.jpg
+
+        Returns:
+            Tuple of (train_dataset, test_dataset).
+ """ + from torch.utils.data import Dataset as TorchDataset + from PIL import Image + + class DirectoryDataset(TorchDataset): + def __init__(self, root_dir: Path, transform=None): + self.root_dir = root_dir + self.transform = transform + self.samples = [] + self.label_map = {} + + # Build label map + classes = sorted([d.name for d in root_dir.iterdir() if d.is_dir()]) + self.label_map = {cls: idx for idx, cls in enumerate(classes)} + + # Build sample list + for cls_dir in root_dir.iterdir(): + if cls_dir.is_dir(): + label = self.label_map[cls_dir.name] + for img_path in cls_dir.iterdir(): + if img_path.suffix.lower() in [".jpg", ".jpeg", ".png", ".bmp"]: + self.samples.append((img_path, label)) + + def __len__(self) -> int: + return len(self.samples) + + def __getitem__(self, idx: int) -> dict[str, Any]: + img_path, label = self.samples[idx] + image = Image.open(img_path).convert("RGB") + return {"img": image, "label": label} + + train_dir = self.local_path / "train" + test_dir = self.local_path / "test" + + if train_dir.exists(): + self._train_dataset = DirectoryDataset(train_dir) + if test_dir.exists(): + self._test_dataset = DirectoryDataset(test_dir) + + return self._train_dataset, self._test_dataset + + def get_train_split(self) -> Any: + """Get training split of the dataset. + + Returns: + Training dataset. + """ + if self._train_dataset is None: + if self.local_path.suffix.lower() == ".csv": + self._load_csv_dataset() + else: + self._load_directory_dataset() + return self._train_dataset + + def get_test_split(self) -> Any: + """Get test/evaluation split of the dataset. + + Returns: + Test dataset. + """ + if self._test_dataset is None: + if self.local_path.suffix.lower() == ".csv": + self._load_csv_dataset() + else: + self._load_directory_dataset() + return self._test_dataset diff --git a/mini-nav/benchmarks/runner.py b/mini-nav/benchmarks/runner.py new file mode 100644 index 0000000..4f11cb8 --- /dev/null +++ b/mini-nav/benchmarks/runner.py @@ -0,0 +1,186 @@ +"""Benchmark runner for executing evaluations.""" + +from pathlib import Path +from typing import Any + +import lancedb +from benchmarks.datasets import HuggingFaceDataset, LocalDataset +from benchmarks.tasks import get_task +from configs.models import BenchmarkConfig, DatasetSourceConfig + + +def create_dataset(config: DatasetSourceConfig) -> Any: + """Create a dataset instance from configuration. + + Args: + config: Dataset source configuration. + + Returns: + Dataset instance. + + Raises: + ValueError: If source_type is not supported. + """ + if config.source_type == "huggingface": + return HuggingFaceDataset( + hf_id=config.path, + img_column=config.img_column, + label_column=config.label_column, + ) + elif config.source_type == "local": + return LocalDataset( + local_path=config.path, + img_column=config.img_column, + label_column=config.label_column, + ) + else: + raise ValueError( + f"Unsupported source_type: {config.source_type}. " + f"Supported types: 'huggingface', 'local'" + ) + + +def _get_table_name(config: BenchmarkConfig, model_name: str) -> str: + """Generate database table name from config and model name. + + Args: + config: Benchmark configuration. + model_name: Model name for table naming. + + Returns: + Formatted table name. 
+ """ + prefix = config.model_table_prefix + # Use dataset path as part of table name (sanitize) + dataset_name = Path(config.dataset.path).name.lower().replace("-", "_") + return f"{prefix}_{dataset_name}_{model_name}" + + +def _ensure_table( + config: BenchmarkConfig, + model_name: str, + vector_dim: int, +) -> lancedb.table.Table: + """Ensure the LanceDB table exists with correct schema. + + Args: + config: Benchmark configuration. + model_name: Model name for table naming. + vector_dim: Feature vector dimension. + + Returns: + LanceDB table instance. + """ + import pyarrow as pa + from database import db_manager + + table_name = _get_table_name(config, model_name) + + # Build expected schema + schema = pa.schema( + [ + pa.field("id", pa.int32()), + pa.field("label", pa.int32()), + pa.field("vector", pa.list_(pa.float32(), vector_dim)), + ] + ) + + db = db_manager.db + existing_tables = db.list_tables().tables + + # Check if table exists and has correct schema + if table_name in existing_tables: + table = db.open_table(table_name) + if table.schema != schema: + print(f"Table '{table_name}' schema mismatch, rebuilding.") + db.drop_table(table_name) + table = db.create_table(table_name, schema=schema) + else: + table = db.create_table(table_name, schema=schema) + + return table + + +def run_benchmark( + model: Any, + processor: Any, + config: BenchmarkConfig, + model_name: str = "model", +) -> dict[str, Any]: + """Run benchmark evaluation. + + Workflow: + 1. Create dataset from configuration + 2. Create benchmark task from configuration + 3. Build evaluation database from training set + 4. Evaluate on test set + + Args: + model: Feature extraction model. + processor: Image preprocessor. + config: Benchmark configuration. + model_name: Model name for table naming. + + Returns: + Dictionary containing evaluation results. + + Raises: + ValueError: If benchmark is not enabled in config. + """ + if not config.enabled: + raise ValueError( + "Benchmark is not enabled. Set benchmark.enabled=true in config.yaml" + ) + + # Create dataset + print(f"Loading dataset: {config.dataset.source_type} - {config.dataset.path}") + dataset = create_dataset(config.dataset) + + # Get train and test splits + train_dataset = dataset.get_train_split() + test_dataset = dataset.get_test_split() + + if train_dataset is None or test_dataset is None: + raise ValueError( + f"Dataset {config.dataset.path} does not have train/test splits" + ) + + # Infer vector dimension from a sample + sample = train_dataset[0] + sample_image = sample["img"] + + from .tasks.retrieval import _infer_vector_dim + + vector_dim = _infer_vector_dim(processor, model, sample_image) + print(f"Model output dimension: {vector_dim}") + + # Ensure table exists with correct schema + table = _ensure_table(config, model_name, vector_dim) + table_name = _get_table_name(config, model_name) + + # Check if database is already built + table_count = table.count_rows() + if table_count > 0: + print( + f"Table '{table_name}' already has {table_count} entries, skipping database build." 
+        )
+    else:
+        # Create and run benchmark task
+        task = get_task(config.task.type, top_k=config.task.top_k)
+        print(f"Building database with {len(train_dataset)} training samples...")
+        task.build_database(model, processor, train_dataset, table, config.batch_size)
+
+    # Run evaluation
+    task = get_task(config.task.type, top_k=config.task.top_k)
+    print(f"Evaluating on {len(test_dataset)} test samples...")
+    results = task.evaluate(model, processor, test_dataset, table, config.batch_size)
+
+    # Print results
+    print("\n=== Benchmark Results ===")
+    print(f"Dataset: {config.dataset.path}")
+    print(f"Task: {config.task.name}")
+    print(f"Top-K: {results['top_k']}")
+    print(f"Accuracy: {results['accuracy']:.4f}")
+    print(f"Correct: {results['correct']}/{results['total']}")
+
+    return results
diff --git a/mini-nav/benchmarks/tasks/__init__.py b/mini-nav/benchmarks/tasks/__init__.py
new file mode 100644
index 0000000..4bd7a2a
--- /dev/null
+++ b/mini-nav/benchmarks/tasks/__init__.py
@@ -0,0 +1,6 @@
+"""Benchmark evaluation tasks."""
+
+from .retrieval import RetrievalTask
+from .registry import TASK_REGISTRY, get_task
+
+__all__ = ["RetrievalTask", "TASK_REGISTRY", "get_task"]
\ No newline at end of file
diff --git a/mini-nav/benchmarks/tasks/registry.py b/mini-nav/benchmarks/tasks/registry.py
new file mode 100644
index 0000000..d6636de
--- /dev/null
+++ b/mini-nav/benchmarks/tasks/registry.py
@@ -0,0 +1,59 @@
+"""Task registry for benchmark evaluation."""
+
+from typing import Any, Type
+
+from benchmarks.base import BaseBenchmarkTask
+
+# Task registry: maps task type string to task class
+TASK_REGISTRY: dict[str, Type[BaseBenchmarkTask]] = {}
+
+
+class RegisterTask:
+    """Decorator class to register a benchmark task.
+
+    Usage:
+        @RegisterTask("retrieval")
+        class RetrievalTask(BaseBenchmarkTask):
+            ...
+    """
+
+    def __init__(self, task_type: str):
+        """Initialize the decorator with task type.
+
+        Args:
+            task_type: Task type identifier.
+        """
+        self.task_type = task_type
+
+    def __call__(self, cls: type[BaseBenchmarkTask]) -> type[BaseBenchmarkTask]:
+        """Register the decorated class to the task registry.
+
+        Args:
+            cls: The class to be decorated.
+
+        Returns:
+            The unmodified class.
+        """
+        TASK_REGISTRY[self.task_type] = cls
+        return cls
+
+
+def get_task(task_type: str, **kwargs: Any) -> BaseBenchmarkTask:
+    """Get a benchmark task instance by type.
+
+    Args:
+        task_type: Task type identifier.
+        **kwargs: Additional arguments passed to task constructor.
+
+    Returns:
+        Task instance.
+
+    Raises:
+        ValueError: If task type is not registered.
+    """
+    if task_type not in TASK_REGISTRY:
+        available = list(TASK_REGISTRY.keys())
+        raise ValueError(
+            f"Unknown task type: {task_type}. Available tasks: {available}"
+        )
+    return TASK_REGISTRY[task_type](**kwargs)
diff --git a/mini-nav/benchmarks/task_eval.py b/mini-nav/benchmarks/tasks/retrieval.py
similarity index 50%
rename from mini-nav/benchmarks/task_eval.py
rename to mini-nav/benchmarks/tasks/retrieval.py
index 5009587..2878c51 100644
--- a/mini-nav/benchmarks/task_eval.py
+++ b/mini-nav/benchmarks/tasks/retrieval.py
@@ -1,41 +1,22 @@
-from typing import Dict, Literal, cast
+"""Retrieval task for benchmark evaluation (Recall@K)."""
+
+from typing import Any, cast
 
 import lancedb
 import pyarrow as pa
 import torch
-from database import db_manager
-from datasets import load_dataset
+from benchmarks.base import BaseBenchmarkTask
+from benchmarks.tasks.registry import RegisterTask
 from torch import nn
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
 from transformers import BitImageProcessorFast
 
-# Dataset config: dataset name -> (HuggingFace ID, image column, label column)
-DATASET_CONFIG: Dict[str, tuple[str, str, str]] = {
-    "CIFAR-10": ("uoft-cs/cifar10", "img", "label"),
-    "CIFAR-100": ("uoft-cs/cifar100", "img", "fine_label"),
-}
-
-
-def _get_table_name(dataset: str, model_name: str) -> str:
-    """Generate database table name from dataset and model name.
-
-    Args:
-        dataset: Dataset name, e.g. "CIFAR-10".
-        model_name: Model name, e.g. "Dinov2".
-
-    Returns:
-        Formatted table name, e.g. "cifar10_dinov2".
-    """
-    ds_part = dataset.lower().replace("-", "")
-    model_part = model_name.lower()
-    return f"{ds_part}_{model_part}"
-
 
 def _infer_vector_dim(
     processor: BitImageProcessorFast,
     model: nn.Module,
-    sample_image,
+    sample_image: Any,
 ) -> int:
     """Infer model output vector dimension via a single forward pass.
@@ -55,7 +36,6 @@
     inputs.to(device)
 
     output = model(inputs)
-    # output shape: [1, dim]
 
     return output.shape[-1]
 
@@ -100,16 +80,13 @@
         imgs = batch["img"]
         labels = batch["label"]
 
-        # Preprocess and run inference
         inputs = processor(imgs, return_tensors="pt")
        inputs.to(device)
-        outputs = model(inputs)  # [B, dim]
+        outputs = model(inputs)
 
-        # Move the whole batch to CPU at once
         features = cast(torch.Tensor, outputs).cpu()
         labels_list = labels.tolist()
 
-        # Write to the database entry by entry
         batch_size = len(labels_list)
         table.add(
             [
@@ -134,9 +111,6 @@
 ) -> tuple[int, int]:
     """Evaluate Recall@K by searching the database for each test image.
 
-    For each batch, features are extracted in one forward pass and moved to CPU,
-    then each sample is searched individually against the database.
-
     Args:
         processor: Image preprocessor.
         model: Feature extraction model.
@@ -157,21 +131,17 @@
         imgs = batch["img"]
         labels = batch["label"]
 
-        # Batched forward inference
         inputs = processor(imgs, return_tensors="pt")
         inputs.to(device)
-        outputs = model(inputs)  # [B, dim]
+        outputs = model(inputs)
 
-        # Move the whole batch to CPU at once
         features = cast(torch.Tensor, outputs).cpu()
         labels_list = labels.tolist()
 
-        # Search and verify sample by sample
         for j in range(len(labels_list)):
             feature = features[j].tolist()
             true_label = labels_list[j]
 
-            # Search for the top_k most similar results
             results = (
                 table.search(feature)
                 .select(["label", "_distance"])
@@ -179,7 +149,6 @@
                 .to_polars()
             )
 
-            # Check whether the correct label appears in the top_k
             retrieved_labels = results["label"].to_list()
             if true_label in retrieved_labels:
                 correct += 1
@@ -188,69 +157,51 @@
     return correct, total
 
 
-def task_eval(
-    processor: BitImageProcessorFast,
-    model: nn.Module,
-    dataset: Literal["CIFAR-10", "CIFAR-100"],
-    model_name: str,
-    top_k: int = 10,
-    batch_size: int = 64,
-) -> float:
-    """Evaluate model Recall@K accuracy on a dataset using vector retrieval.
+@RegisterTask("retrieval")
+class RetrievalTask(BaseBenchmarkTask):
+    """Retrieval evaluation task (Recall@K)."""
 
-    Workflow:
-    1. Create or open a database table named by dataset and model.
-    2. Build database from training set features (skip if table exists).
-    3. Evaluate on test set: extract features in batches, search top_k,
-       check if correct label appears in results.
+    def __init__(self, top_k: int = 10):
+        """Initialize retrieval task.
 
-    Args:
-        processor: Image preprocessor.
-        model: Feature extraction model.
-        dataset: Dataset name.
-        model_name: Model name, used for table name generation.
-        top_k: Number of top similar results to retrieve.
-        batch_size: Batch size for DataLoader.
+        Args:
+            top_k: Number of top results to retrieve for recall calculation.
+        """
+        super().__init__(top_k=top_k)
+        self.top_k = top_k
 
-    Returns:
-        Recall@K accuracy (0.0 ~ 1.0).
+    def build_database(
+        self,
+        model: Any,
+        processor: Any,
+        train_dataset: Any,
+        table: lancedb.table.Table,
+        batch_size: int,
+    ) -> None:
+        """Build the evaluation database from training data.
 
-    Raises:
-        ValueError: If dataset name is not supported.
-    """
-    if dataset not in DATASET_CONFIG:
-        raise ValueError(
-            f"Unknown dataset: {dataset}. Only support: {list(DATASET_CONFIG.keys())}."
-        )
-    hf_id, img_col, label_col = DATASET_CONFIG[dataset]
+        Args:
+            model: Feature extraction model.
+            processor: Image preprocessor.
+            train_dataset: Training dataset.
+            table: LanceDB table to store features.
+            batch_size: Batch size for DataLoader.
+        """
+        # Get a sample image to infer vector dimension
+        sample = train_dataset[0]
+        sample_image = sample["img"]
 
-    # Load the datasets
-    train_dataset = load_dataset(hf_id, split="train")
-    test_dataset = load_dataset(hf_id, split="test")
+        vector_dim = _infer_vector_dim(processor, model, sample_image)
+        expected_schema = _build_eval_schema(vector_dim)
 
-    # Generate the table name and infer the vector dimension
-    table_name = _get_table_name(dataset, model_name)
-    vector_dim = _infer_vector_dim(processor, model, train_dataset[0][img_col])
-    expected_schema = _build_eval_schema(vector_dim)
-    existing_tables = db_manager.db.list_tables().tables
+        # Check schema compatibility
+        if table.schema != expected_schema:
+            raise ValueError(
+                f"Table schema mismatch. Expected: {expected_schema}, "
+                f"Got: {table.schema}"
+            )
 
-    # If the old table schema mismatches (e.g. label type changed), drop and rebuild
-    if table_name in existing_tables:
-        old_table = db_manager.db.open_table(table_name)
-        if old_table.schema != expected_schema:
-            print(f"Table '{table_name}' schema mismatch, rebuilding.")
-            db_manager.db.drop_table(table_name)
-            existing_tables = []
-
-    if table_name in existing_tables:
-        # Table already exists with a matching schema, skip the database build
-        print(f"Table '{table_name}' already exists, skipping database build.")
-        table = db_manager.db.open_table(table_name)
-    else:
-        # Create a new table
-        table = db_manager.db.create_table(table_name, schema=expected_schema)
-
-    # Build the database in batches with a DataLoader
+        # Build database
         train_loader = DataLoader(
             train_dataset.with_format("torch"),
             batch_size=batch_size,
@@ -259,17 +210,45 @@
         )
         _establish_eval_database(processor, model, table, train_loader)
 
-    # Evaluate in batches with a DataLoader
-    test_loader = DataLoader(
-        test_dataset.with_format("torch"),
-        batch_size=batch_size,
-        shuffle=False,
-        num_workers=4,
-    )
-    correct, total = _evaluate_recall(processor, model, table, test_loader, top_k)
+    def evaluate(
+        self,
+        model: Any,
+        processor: Any,
+        test_dataset: Any,
+        table: lancedb.table.Table,
+        batch_size: int,
+    ) -> dict[str, Any]:
+        """Evaluate the model on the test dataset.
 
-    accuracy = correct / total
-    print(f"\nRecall@{top_k} on {dataset} with {model_name}: {accuracy:.4f}")
-    print(f"Correct: {correct}/{total}")
+        Args:
+            model: Feature extraction model.
+            processor: Image preprocessor.
+            test_dataset: Test dataset.
+            table: LanceDB table to search against.
+            batch_size: Batch size for DataLoader.
 
-    return accuracy
+        Returns:
+            Dictionary containing evaluation results with keys:
+            - accuracy: Recall@K accuracy (0.0 ~ 1.0)
+            - correct: Number of correct predictions
+            - total: Total number of test samples
+            - top_k: The K value used
+        """
+        test_loader = DataLoader(
+            test_dataset.with_format("torch"),
+            batch_size=batch_size,
+            shuffle=False,
+            num_workers=4,
+        )
+        correct, total = _evaluate_recall(
+            processor, model, table, test_loader, self.top_k
+        )
+
+        accuracy = correct / total if total > 0 else 0.0
+
+        return {
+            "accuracy": accuracy,
+            "correct": correct,
+            "total": total,
+            "top_k": self.top_k,
+        }
diff --git a/mini-nav/configs/config.yaml b/mini-nav/configs/config.yaml
index 230f99a..122bfec 100644
--- a/mini-nav/configs/config.yaml
+++ b/mini-nav/configs/config.yaml
@@ -2,10 +2,10 @@ model:
   name: "facebook/dinov2-large"
   compression_dim: 512
   device: "auto" # auto-detect GPU
-  sam_model: "facebook/sam2.1-hiera-large" # SAM model name
-  sam_min_mask_area: 100 # Minimum mask area threshold
-  sam_max_masks: 10 # Maximum number of masks to keep
-  compressor_path: null # Path to trained HashCompressor weights (optional)
+  sam_model: "facebook/sam2.1-hiera-large"  # SAM model name
+  sam_min_mask_area: 100  # Minimum mask area threshold
+  sam_max_masks: 10  # Maximum number of masks to keep
+  compressor_path: null  # Path to trained HashCompressor weights (optional)
 
 output:
   directory: "./outputs"
@@ -19,3 +19,17 @@ dataset:
   rotation_range: [-30, 30]
   overlap_threshold: 0.3
   seed: 42
+
+benchmark:
+  enabled: true
+  dataset:
+    source_type: "huggingface"
+    path: "uoft-cs/cifar10"
+    img_column: "img"
+    label_column: "label"
+  task:
+    name: "recall_at_k"
+    type: "retrieval"
+    top_k: 10
+  batch_size: 64
+  model_table_prefix: "benchmark"
diff --git a/mini-nav/configs/models.py b/mini-nav/configs/models.py
index ec26b2b..52bf795 100644
--- a/mini-nav/configs/models.py
+++ b/mini-nav/configs/models.py
@@ -1,7 
+1,7 @@ """Pydantic data models for feature compressor configuration.""" from pathlib import Path -from typing import Optional +from typing import Literal, Optional from pydantic import BaseModel, ConfigDict, Field, field_validator @@ -98,6 +98,41 @@ class DatasetConfig(BaseModel): return v +class DatasetSourceConfig(BaseModel): + """Configuration for benchmark dataset source.""" + + model_config = ConfigDict(extra="ignore") + + source_type: Literal["huggingface", "local"] = "huggingface" + path: str = Field(default="", description="HuggingFace dataset ID or local path") + img_column: str = Field(default="img", description="Image column name") + label_column: str = Field(default="label", description="Label column name") + + +class BenchmarkTaskConfig(BaseModel): + """Configuration for benchmark task.""" + + model_config = ConfigDict(extra="ignore") + + name: str = Field(default="recall_at_k", description="Task name") + type: str = Field(default="retrieval", description="Task type") + top_k: int = Field(default=10, gt=0, description="Top K for recall evaluation") + + +class BenchmarkConfig(BaseModel): + """Configuration for benchmark evaluation.""" + + model_config = ConfigDict(extra="ignore") + + enabled: bool = Field(default=False, description="Enable benchmark evaluation") + dataset: DatasetSourceConfig = Field(default_factory=DatasetSourceConfig) + task: BenchmarkTaskConfig = Field(default_factory=BenchmarkTaskConfig) + batch_size: int = Field(default=64, gt=0, description="Batch size for DataLoader") + model_table_prefix: str = Field( + default="benchmark", description="Prefix for LanceDB table names" + ) + + class Config(BaseModel): """Root configuration for the feature compressor.""" @@ -106,3 +141,4 @@ class Config(BaseModel): model: ModelConfig = Field(default_factory=ModelConfig) output: OutputConfig = Field(default_factory=OutputConfig) dataset: DatasetConfig = Field(default_factory=DatasetConfig) + benchmark: BenchmarkConfig = Field(default_factory=BenchmarkConfig) diff --git a/mini-nav/main.py b/mini-nav/main.py index 2bd62df..30aa32e 100644 --- a/mini-nav/main.py +++ b/mini-nav/main.py @@ -16,9 +16,51 @@ if __name__ == "__main__": epoch_size=10, batch_size=64, lr=1e-4, checkpoint_path="hash_checkpoint.pt" ) elif args.action == "benchmark": - from benchmarks import evaluate + from typing import cast - evaluate("Dinov2", "CIFAR-10", "Recall@10") + import torch + from benchmarks import run_benchmark + from compressors import DinoCompressor + from configs import cfg_manager + from transformers import AutoImageProcessor, BitImageProcessorFast + from utils import get_device + + config = cfg_manager.get() + benchmark_cfg = config.benchmark + + if not benchmark_cfg.enabled: + print("Benchmark is not enabled. 
Set benchmark.enabled=true in config.yaml") + exit(1) + + device = get_device() + + # Load model and processor based on config + model_cfg = config.model + processor = cast( + BitImageProcessorFast, + AutoImageProcessor.from_pretrained(model_cfg.name, device_map=device), + ) + + # Load compressor weights if specified in model config + model = DinoCompressor().to(device) + if model_cfg.compressor_path is not None: + from compressors import HashCompressor + + compressor = HashCompressor( + input_dim=model_cfg.compression_dim, + output_dim=model_cfg.compression_dim, + ) + compressor.load_state_dict(torch.load(model_cfg.compressor_path)) + # Wrap with compressor if path is specified + model.compressor = compressor + + # Run benchmark + run_benchmark( + model=model, + processor=processor, + config=benchmark_cfg, + model_name="dinov2", + ) elif args.action == "visualize": from visualizer import app diff --git a/mini-nav/tests/test_compressors.py b/mini-nav/tests/test_compressors.py index 374861b..a6e29d1 100644 --- a/mini-nav/tests/test_compressors.py +++ b/mini-nav/tests/test_compressors.py @@ -1,26 +1,21 @@ """Tests for compressor modules (SAM, DINO, HashCompressor, Pipeline).""" -import tempfile -from pathlib import Path -from unittest.mock import MagicMock, patch - import pytest import torch -from PIL import Image - -from configs import cfg_manager from compressors import ( BinarySign, DinoCompressor, HashCompressor, - SegmentCompressor, SAMHashPipeline, - create_pipeline_from_config, + SegmentCompressor, bits_to_hash, - hash_to_bits, + create_pipeline_from_config, hamming_distance, hamming_similarity, + hash_to_bits, ) +from configs import cfg_manager +from PIL import Image class TestHashCompressor:
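
The registry in benchmarks/tasks/registry.py is the intended extension point for
new evaluation tasks. A minimal sketch of a custom task registration follows,
assuming only the RegisterTask, BaseBenchmarkTask, and get_task interfaces
introduced above; the "smoke" task name and its placeholder return values are
hypothetical:

    from typing import Any

    import lancedb

    from benchmarks.base import BaseBenchmarkTask
    from benchmarks.tasks.registry import RegisterTask


    @RegisterTask("smoke")
    class SmokeTask(BaseBenchmarkTask):
        """Counts test samples without touching the vector table."""

        def build_database(
            self,
            model: Any,
            processor: Any,
            train_dataset: Any,
            table: lancedb.table.Table,
            batch_size: int,
        ) -> None:
            pass  # nothing to index for this toy task

        def evaluate(
            self,
            model: Any,
            processor: Any,
            test_dataset: Any,
            table: lancedb.table.Table,
            batch_size: int,
        ) -> dict[str, Any]:
            # run_benchmark() prints these keys, so every task should return them
            return {"accuracy": 0.0, "correct": 0, "total": len(test_dataset), "top_k": 0}

Selecting the task is then a config change only (task.type: "smoke" under the
new benchmark section), since get_task() resolves the type string through
TASK_REGISTRY at runtime. Note that registration happens at import time, so the
module defining the task must be imported before get_task() is called.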
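For the runner itself, an end-to-end invocation with a local directory dataset
can be assembled from the Pydantic models this patch adds. In this sketch the
dataset path and top_k value are made up, and the model/processor setup mirrors
the benchmark branch of main.py:

    from benchmarks import run_benchmark
    from compressors import DinoCompressor
    from configs.models import (
        BenchmarkConfig,
        BenchmarkTaskConfig,
        DatasetSourceConfig,
    )
    from transformers import AutoImageProcessor
    from utils import get_device

    device = get_device()
    processor = AutoImageProcessor.from_pretrained(
        "facebook/dinov2-large", device_map=device
    )
    model = DinoCompressor().to(device)

    cfg = BenchmarkConfig(
        enabled=True,  # run_benchmark() raises ValueError otherwise
        dataset=DatasetSourceConfig(
            source_type="local",
            path="./data/my_dataset",  # expects train/<class>/*.jpg and test/<class>/*.jpg
        ),
        task=BenchmarkTaskConfig(name="recall_at_k", type="retrieval", top_k=5),
        batch_size=32,
    )
    results = run_benchmark(model, processor, cfg, model_name="dinov2")
    print(results["accuracy"])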