mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
refactor(benchmarks): modularize benchmark system with config-driven execution
This commit is contained in:
@@ -1,59 +1,12 @@
|
||||
from typing import Literal, cast
|
||||
"""Benchmark evaluation module.
|
||||
|
||||
import torch
|
||||
from compressors import DinoCompressor, FloatCompressor
|
||||
from configs import cfg_manager
|
||||
from transformers import AutoImageProcessor, BitImageProcessorFast
|
||||
from utils import get_device
|
||||
This module provides a modular benchmark system that supports:
|
||||
- Multiple dataset sources (HuggingFace, local)
|
||||
- Multiple evaluation tasks (retrieval, with extensibility for more)
|
||||
- Configuration-driven execution
|
||||
"""
|
||||
|
||||
from .task_eval import task_eval
|
||||
from .base import BaseBenchmarkTask, BaseDataset
|
||||
from .runner import run_benchmark
|
||||
|
||||
|
||||
def evaluate(
    compressor_model: Literal["Dinov2", "Dinov2WithCompressor"],
    dataset: Literal["CIFAR-10", "CIFAR-100"],
    benchmark: Literal["Recall@1", "Recall@10"],
):
    """Run a retrieval evaluation for the selected model/dataset/metric.

    Args:
        compressor_model: Which model variant to evaluate.
        dataset: Dataset name.
        benchmark: Evaluation metric (determines the retrieval depth).

    Raises:
        ValueError: If the model variant or benchmark name is unknown.
    """
    device = get_device()

    # Validate the model choice before downloading any weights.
    if compressor_model not in ("Dinov2", "Dinov2WithCompressor"):
        raise ValueError(f"Unknown compressor: {compressor_model}")

    # Both variants share the same DINOv2 image preprocessor.
    processor = cast(
        BitImageProcessorFast,
        AutoImageProcessor.from_pretrained(
            "facebook/dinov2-large", device_map=device
        ),
    )

    if compressor_model == "Dinov2":
        model = DinoCompressor().to(device)
    else:
        # Restore the trained float compressor from the configured output dir.
        output_dir = cfg_manager.get().output.directory
        compressor = FloatCompressor()
        compressor.load_state_dict(torch.load(output_dir / "compressor.pt"))
        model = DinoCompressor(compressor).to(device)

    # Map the requested metric onto the retrieval depth (top_k).
    top_k_for = {"Recall@1": 1, "Recall@10": 10}
    if benchmark not in top_k_for:
        raise ValueError(f"Unknown benchmark: {benchmark}")
    task_eval(processor, model, dataset, compressor_model, top_k=top_k_for[benchmark])
|
||||
|
||||
|
||||
__all__ = ["task_eval", "evaluate"]
|
||||
__all__ = ["BaseBenchmarkTask", "BaseDataset", "run_benchmark"]
|
||||
99
mini-nav/benchmarks/base.py
Normal file
99
mini-nav/benchmarks/base.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""Base classes for benchmark datasets and tasks."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Protocol
|
||||
|
||||
import lancedb
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
class BaseDataset(ABC):
    """Abstract interface shared by all benchmark dataset loaders."""

    @abstractmethod
    def get_train_split(self) -> Any:
        """Return the split used to build the retrieval database.

        Returns:
            Dataset object for training.
        """
        ...

    @abstractmethod
    def get_test_split(self) -> Any:
        """Return the split used for evaluation.

        Returns:
            Dataset object for testing.
        """
        ...
|
||||
|
||||
|
||||
class BaseBenchmarkTask(ABC):
    """Abstract interface for benchmark evaluation tasks.

    Concrete tasks implement two phases: populating a feature database
    from the training split, then scoring the test split against it.
    """

    def __init__(self, **kwargs: Any):
        """Store task-specific configuration.

        Args:
            **kwargs: Task-specific configuration parameters.
        """
        # Kept as a plain dict so subclasses can pass through arbitrary
        # settings without the base class knowing their names.
        self.config = kwargs

    @abstractmethod
    def build_database(
        self,
        model: Any,
        processor: Any,
        train_dataset: Any,
        table: lancedb.table.Table,
        batch_size: int,
    ) -> None:
        """Populate the feature database from the training data.

        Args:
            model: Feature extraction model.
            processor: Image preprocessor.
            train_dataset: Training dataset.
            table: LanceDB table to store features.
            batch_size: Batch size for DataLoader.
        """
        ...

    @abstractmethod
    def evaluate(
        self,
        model: Any,
        processor: Any,
        test_dataset: Any,
        table: lancedb.table.Table,
        batch_size: int,
    ) -> dict[str, Any]:
        """Score the model on the test dataset against the database.

        Args:
            model: Feature extraction model.
            processor: Image preprocessor.
            test_dataset: Test dataset.
            table: LanceDB table to search against.
            batch_size: Batch size for DataLoader.

        Returns:
            Dictionary containing evaluation results.
        """
        ...
|
||||
|
||||
|
||||
class DatasetFactory(Protocol):
    """Structural type for callables that build datasets from config."""

    def __call__(self, config: Any) -> BaseDataset:
        """Build and return a dataset for the given configuration.

        Args:
            config: Dataset configuration.

        Returns:
            Dataset instance.
        """
        ...
|
||||
6
mini-nav/benchmarks/datasets/__init__.py
Normal file
6
mini-nav/benchmarks/datasets/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Dataset loaders for benchmark evaluation."""
|
||||
|
||||
from .huggingface import HuggingFaceDataset
|
||||
from .local import LocalDataset
|
||||
|
||||
__all__ = ["HuggingFaceDataset", "LocalDataset"]
|
||||
66
mini-nav/benchmarks/datasets/huggingface.py
Normal file
66
mini-nav/benchmarks/datasets/huggingface.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""HuggingFace dataset loader for benchmark evaluation."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
from ..base import BaseDataset
|
||||
|
||||
|
||||
class HuggingFaceDataset(BaseDataset):
    """Dataset loader for HuggingFace datasets."""

    def __init__(
        self,
        hf_id: str,
        img_column: str = "img",
        label_column: str = "label",
    ):
        """Initialize HuggingFace dataset loader.

        Args:
            hf_id: HuggingFace dataset ID.
            img_column: Name of the image column.
            label_column: Name of the label column.
        """
        self.hf_id = hf_id
        self.img_column = img_column
        self.label_column = label_column
        # Splits are fetched lazily on first access; None means "absent".
        self._train_dataset: Any = None
        self._test_dataset: Any = None
        # Tracks whether load_dataset() has already run. Without this flag a
        # hub dataset lacking a 'train' split would be re-downloaded on every
        # accessor call (the old check keyed on `_train_dataset is None`).
        self._loaded: bool = False

    def _load(self) -> tuple[Any, Any]:
        """Load dataset from HuggingFace (at most once).

        Returns:
            Tuple of (train_dataset, test_dataset); either element may be
            None when the hub dataset does not provide that split.
        """
        if not self._loaded:
            dataset = load_dataset(self.hf_id)
            # Handle datasets that use 'train' and 'test' splits
            if "train" in dataset:
                self._train_dataset = dataset["train"]
            if "test" in dataset:
                self._test_dataset = dataset["test"]
            # Handle datasets that use 'train' and 'validation' splits
            elif "validation" in dataset:
                self._test_dataset = dataset["validation"]
            self._loaded = True
        return self._train_dataset, self._test_dataset

    def get_train_split(self) -> Any:
        """Get training split of the dataset.

        Returns:
            Training dataset, or None when the hub dataset has no train split.
        """
        train, _ = self._load()
        return train

    def get_test_split(self) -> Any:
        """Get test/evaluation split of the dataset.

        Returns:
            Test dataset, or None when neither a test nor validation split exists.
        """
        _, test = self._load()
        return test
|
||||
157
mini-nav/benchmarks/datasets/local.py
Normal file
157
mini-nav/benchmarks/datasets/local.py
Normal file
@@ -0,0 +1,157 @@
|
||||
"""Local dataset loader for benchmark evaluation."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..base import BaseDataset
|
||||
|
||||
|
||||
class LocalDataset(BaseDataset):
    """Dataset loader for local datasets (CSV file or directory tree)."""

    # Image file extensions recognised when scanning a directory dataset.
    _IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp"}

    def __init__(
        self,
        local_path: str,
        img_column: str = "image_path",
        label_column: str = "label",
    ):
        """Initialize local dataset loader.

        Args:
            local_path: Path to local dataset directory or CSV file.
            img_column: Name of the image path column.
            label_column: Name of the label column.
        """
        self.local_path = Path(local_path)
        self.img_column = img_column
        self.label_column = label_column
        self._train_dataset: Optional[Any] = None
        self._test_dataset: Optional[Any] = None
        # Guards against re-reading the CSV / re-scanning the filesystem on
        # every accessor call when a split turns out to be absent.
        self._loaded: bool = False

    def _load_csv_dataset(self) -> tuple[Any, Any]:
        """Load dataset from CSV file.

        Expected CSV format:
            label,image_path,x1,y1,x2,y2
            "class_name","path/to/image.jpg",100,200,300,400

        Returns:
            Tuple of (train_dataset, test_dataset).
        """
        import pandas as pd

        from torch.utils.data import Dataset as TorchDataset

        # Load CSV file
        df = pd.read_csv(self.local_path)

        class CSVDataset(TorchDataset):
            """Thin torch Dataset over the rows of a dataframe."""

            def __init__(self, dataframe: pd.DataFrame, img_col: str, label_col: str):
                self.df = dataframe.reset_index(drop=True)
                self.img_col = img_col
                self.label_col = label_col

            def __len__(self) -> int:
                return len(self.df)

            def __getitem__(self, idx: int) -> dict[str, Any]:
                row = self.df.iloc[idx]
                return {
                    "img": row[self.img_col],
                    "label": row[self.label_col],
                }

        # Deterministic 80/20 train/test split by row order.
        split_idx = int(len(df) * 0.8)
        train_df = df.iloc[:split_idx]
        test_df = df.iloc[split_idx:]

        self._train_dataset = CSVDataset(train_df, self.img_column, self.label_column)
        self._test_dataset = CSVDataset(test_df, self.img_column, self.label_column)

        return self._train_dataset, self._test_dataset

    def _load_directory_dataset(self) -> tuple[Any, Any]:
        """Load dataset from directory structure.

        Expected structure:
            local_path/
                train/
                    class_name_1/
                        image1.jpg
                        image2.jpg
                    class_name_2/
                        image1.jpg
                test/
                    class_name_1/
                        image1.jpg

        Returns:
            Tuple of (train_dataset, test_dataset); an element is None when
            the corresponding 'train'/'test' directory does not exist.
        """
        from torch.utils.data import Dataset as TorchDataset
        from PIL import Image

        image_suffixes = self._IMAGE_SUFFIXES

        class DirectoryDataset(TorchDataset):
            """Torch Dataset over a class-per-subdirectory image tree."""

            def __init__(self, root_dir: Path, transform=None):
                self.root_dir = root_dir
                self.transform = transform
                self.samples: list[tuple[Path, int]] = []

                # Sort class names so the label assignment and sample order
                # are stable across runs and filesystems (iterdir() order is
                # not guaranteed by the OS).
                classes = sorted(d.name for d in root_dir.iterdir() if d.is_dir())
                self.label_map = {cls: idx for idx, cls in enumerate(classes)}

                for cls_name in classes:
                    label = self.label_map[cls_name]
                    for img_path in sorted((root_dir / cls_name).iterdir()):
                        if img_path.suffix.lower() in image_suffixes:
                            self.samples.append((img_path, label))

            def __len__(self) -> int:
                return len(self.samples)

            def __getitem__(self, idx: int) -> dict[str, Any]:
                img_path, label = self.samples[idx]
                image = Image.open(img_path).convert("RGB")
                return {"img": image, "label": label}

        train_dir = self.local_path / "train"
        test_dir = self.local_path / "test"

        if train_dir.exists():
            self._train_dataset = DirectoryDataset(train_dir)
        if test_dir.exists():
            self._test_dataset = DirectoryDataset(test_dir)

        return self._train_dataset, self._test_dataset

    def _ensure_loaded(self) -> None:
        """Load the dataset once, dispatching on CSV file vs directory."""
        if self._loaded:
            return
        if self.local_path.suffix.lower() == ".csv":
            self._load_csv_dataset()
        else:
            self._load_directory_dataset()
        self._loaded = True

    def get_train_split(self) -> Any:
        """Get training split of the dataset.

        Returns:
            Training dataset, or None if no training data was found.
        """
        self._ensure_loaded()
        return self._train_dataset

    def get_test_split(self) -> Any:
        """Get test/evaluation split of the dataset.

        Returns:
            Test dataset, or None if no test data was found.
        """
        self._ensure_loaded()
        return self._test_dataset
|
||||
186
mini-nav/benchmarks/runner.py
Normal file
186
mini-nav/benchmarks/runner.py
Normal file
@@ -0,0 +1,186 @@
|
||||
"""Benchmark runner for executing evaluations."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import lancedb
|
||||
from benchmarks.datasets import HuggingFaceDataset, LocalDataset
|
||||
from benchmarks.tasks import get_task
|
||||
from configs.models import BenchmarkConfig, DatasetSourceConfig
|
||||
|
||||
|
||||
def create_dataset(config: DatasetSourceConfig) -> Any:
    """Create a dataset instance from configuration.

    Args:
        config: Dataset source configuration.

    Returns:
        Dataset instance.

    Raises:
        ValueError: If source_type is not supported.
    """
    source = config.source_type
    if source == "huggingface":
        return HuggingFaceDataset(
            hf_id=config.path,
            img_column=config.img_column,
            label_column=config.label_column,
        )
    if source == "local":
        return LocalDataset(
            local_path=config.path,
            img_column=config.img_column,
            label_column=config.label_column,
        )
    raise ValueError(
        f"Unsupported source_type: {config.source_type}. "
        f"Supported types: 'huggingface', 'local'"
    )
|
||||
|
||||
|
||||
def _get_table_name(config: BenchmarkConfig, model_name: str) -> str:
    """Generate database table name from config and model name.

    Args:
        config: Benchmark configuration.
        model_name: Model name for table naming.

    Returns:
        Formatted table name.
    """
    # Derive a table-name-safe token from the last path component of the
    # dataset path (e.g. "CIFAR-10" -> "cifar_10").
    dataset_token = Path(config.dataset.path).name.lower().replace("-", "_")
    return "_".join([config.model_table_prefix, dataset_token, model_name])
|
||||
|
||||
|
||||
def _ensure_table(
    config: BenchmarkConfig,
    model_name: str,
    vector_dim: int,
) -> lancedb.table.Table:
    """Ensure the LanceDB table exists with correct schema.

    Args:
        config: Benchmark configuration.
        model_name: Model name for table naming.
        vector_dim: Feature vector dimension.

    Returns:
        LanceDB table instance.
    """
    import pyarrow as pa
    from database import db_manager

    table_name = _get_table_name(config, model_name)

    # Schema written by the evaluation pipeline: row id, class label,
    # fixed-length feature vector.
    schema = pa.schema(
        [
            pa.field("id", pa.int32()),
            pa.field("label", pa.int32()),
            pa.field("vector", pa.list_(pa.float32(), vector_dim)),
        ]
    )

    db = db_manager.db

    # Fast path: no table yet -> create it fresh.
    if table_name not in db.list_tables().tables:
        return db.create_table(table_name, schema=schema)

    table = db.open_table(table_name)
    if table.schema != schema:
        # Stale schema (e.g. vector_dim changed): drop and rebuild.
        print(f"Table '{table_name}' schema mismatch, rebuilding.")
        db.drop_table(table_name)
        table = db.create_table(table_name, schema=schema)
    return table
|
||||
|
||||
|
||||
def run_benchmark(
    model: Any,
    processor: Any,
    config: BenchmarkConfig,
    model_name: str = "model",
) -> dict[str, Any]:
    """Run benchmark evaluation.

    Workflow:
        1. Create dataset from configuration
        2. Create benchmark task from configuration
        3. Build evaluation database from training set (skipped if populated)
        4. Evaluate on test set

    Args:
        model: Feature extraction model.
        processor: Image preprocessor.
        config: Benchmark configuration.
        model_name: Model name for table naming.

    Returns:
        Dictionary containing evaluation results.

    Raises:
        ValueError: If benchmark is not enabled in config, or the dataset
            lacks train/test splits.
    """
    if not config.enabled:
        raise ValueError(
            "Benchmark is not enabled. Set benchmark.enabled=true in config.yaml"
        )

    # Create dataset
    print(f"Loading dataset: {config.dataset.source_type} - {config.dataset.path}")
    dataset = create_dataset(config.dataset)

    # Get train and test splits
    train_dataset = dataset.get_train_split()
    test_dataset = dataset.get_test_split()

    if train_dataset is None or test_dataset is None:
        raise ValueError(
            f"Dataset {config.dataset.path} does not have train/test splits"
        )

    # Infer vector dimension from a sample so the table schema matches the
    # model's output width.
    sample = train_dataset[0]
    sample_image = sample["img"]

    from .tasks.retrieval import _infer_vector_dim

    vector_dim = _infer_vector_dim(processor, model, sample_image)
    print(f"Model output dimension: {vector_dim}")

    # Ensure table exists with correct schema
    table = _ensure_table(config, model_name, vector_dim)
    table_name = _get_table_name(config, model_name)

    # Create the benchmark task ONCE and reuse it for both the (optional)
    # database build and the evaluation. The previous version constructed a
    # second identical task instance just for evaluation.
    task = get_task(config.task.type, top_k=config.task.top_k)

    # Build the database only when the table is not already populated.
    table_count = table.count_rows()
    if table_count > 0:
        print(
            f"Table '{table_name}' already has {table_count} entries, skipping database build."
        )
    else:
        print(f"Building database with {len(train_dataset)} training samples...")
        task.build_database(model, processor, train_dataset, table, config.batch_size)

    # Run evaluation
    print(f"Evaluating on {len(test_dataset)} test samples...")
    results = task.evaluate(model, processor, test_dataset, table, config.batch_size)

    # Print results
    print("\n=== Benchmark Results ===")
    print(f"Dataset: {config.dataset.path}")
    print(f"Task: {config.task.name}")
    print(f"Top-K: {results['top_k']}")
    print(f"Accuracy: {results['accuracy']:.4f}")
    print(f"Correct: {results['correct']}/{results['total']}")

    return results
|
||||
6
mini-nav/benchmarks/tasks/__init__.py
Normal file
6
mini-nav/benchmarks/tasks/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Benchmark evaluation tasks."""
|
||||
|
||||
from .retrieval import RetrievalTask
|
||||
from .registry import TASK_REGISTRY, get_task
|
||||
|
||||
__all__ = ["RetrievalTask", "TASK_REGISTRY", "get_task"]
|
||||
59
mini-nav/benchmarks/tasks/registry.py
Normal file
59
mini-nav/benchmarks/tasks/registry.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""Task registry for benchmark evaluation."""
|
||||
|
||||
from typing import Any, Type
|
||||
|
||||
from benchmarks.base import BaseBenchmarkTask
|
||||
|
||||
# Task registry: maps task type string to task class
|
||||
TASK_REGISTRY: dict[str, Type[BaseBenchmarkTask]] = {}
|
||||
|
||||
|
||||
class RegisterTask:
    """Decorator class to register a benchmark task.

    Registers the decorated class in the module-level ``TASK_REGISTRY``
    under the given task-type string, then returns the class unchanged.

    Usage:
        @RegisterTask("retrieval")
        class RetrievalTask(BaseBenchmarkTask):
            ...
    """

    def __init__(self, task_type: str):
        """Initialize the decorator with task type.

        Args:
            task_type: Task type identifier (the registry key).
        """
        self.task_type = task_type

    def __call__(self, cls: type[BaseBenchmarkTask]) -> type[BaseBenchmarkTask]:
        """Register the decorated class to the task registry.

        Args:
            cls: The class to be decorated.

        Returns:
            The unmodified class.
        """
        # NOTE(review): a duplicate task_type silently overwrites the earlier
        # registration — confirm this is the intended last-wins behavior.
        TASK_REGISTRY[self.task_type] = cls
        return cls
|
||||
|
||||
|
||||
def get_task(task_type: str, **kwargs: Any) -> BaseBenchmarkTask:
    """Get a benchmark task instance by type.

    Args:
        task_type: Task type identifier.
        **kwargs: Additional arguments passed to task constructor.

    Returns:
        Task instance.

    Raises:
        ValueError: If task type is not registered.
    """
    # EAFP lookup: report the registered alternatives on a miss.
    try:
        task_cls = TASK_REGISTRY[task_type]
    except KeyError:
        available = list(TASK_REGISTRY.keys())
        raise ValueError(
            f"Unknown task type: {task_type}. Available tasks: {available}"
        ) from None
    return task_cls(**kwargs)
|
||||
@@ -1,41 +1,22 @@
|
||||
from typing import Dict, Literal, cast
|
||||
"""Retrieval task for benchmark evaluation (Recall@K)."""
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
import lancedb
|
||||
import pyarrow as pa
|
||||
import torch
|
||||
from database import db_manager
|
||||
from datasets import load_dataset
|
||||
from benchmarks.base import BaseBenchmarkTask
|
||||
from benchmarks.tasks.registry import RegisterTask
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import BitImageProcessorFast
|
||||
|
||||
# 数据集配置:数据集名称 -> (HuggingFace ID, 图片列名, 标签列名)
|
||||
DATASET_CONFIG: Dict[str, tuple[str, str, str]] = {
|
||||
"CIFAR-10": ("uoft-cs/cifar10", "img", "label"),
|
||||
"CIFAR-100": ("uoft-cs/cifar100", "img", "fine_label"),
|
||||
}
|
||||
|
||||
|
||||
def _get_table_name(dataset: str, model_name: str) -> str:
|
||||
"""Generate database table name from dataset and model name.
|
||||
|
||||
Args:
|
||||
dataset: Dataset name, e.g. "CIFAR-10".
|
||||
model_name: Model name, e.g. "Dinov2".
|
||||
|
||||
Returns:
|
||||
Formatted table name, e.g. "cifar10_dinov2".
|
||||
"""
|
||||
ds_part = dataset.lower().replace("-", "")
|
||||
model_part = model_name.lower()
|
||||
return f"{ds_part}_{model_part}"
|
||||
|
||||
|
||||
def _infer_vector_dim(
|
||||
processor: BitImageProcessorFast,
|
||||
model: nn.Module,
|
||||
sample_image,
|
||||
sample_image: Any,
|
||||
) -> int:
|
||||
"""Infer model output vector dimension via a single forward pass.
|
||||
|
||||
@@ -55,7 +36,6 @@ def _infer_vector_dim(
|
||||
inputs.to(device)
|
||||
output = model(inputs)
|
||||
|
||||
# output shape: [1, dim]
|
||||
return output.shape[-1]
|
||||
|
||||
|
||||
@@ -100,16 +80,13 @@ def _establish_eval_database(
|
||||
imgs = batch["img"]
|
||||
labels = batch["label"]
|
||||
|
||||
# 预处理并推理
|
||||
inputs = processor(imgs, return_tensors="pt")
|
||||
inputs.to(device)
|
||||
outputs = model(inputs) # [B, dim]
|
||||
outputs = model(inputs)
|
||||
|
||||
# 整个batch一次性转到CPU
|
||||
features = cast(torch.Tensor, outputs).cpu()
|
||||
labels_list = labels.tolist()
|
||||
|
||||
# 逐条写入数据库
|
||||
batch_size = len(labels_list)
|
||||
table.add(
|
||||
[
|
||||
@@ -134,9 +111,6 @@ def _evaluate_recall(
|
||||
) -> tuple[int, int]:
|
||||
"""Evaluate Recall@K by searching the database for each test image.
|
||||
|
||||
For each batch, features are extracted in one forward pass and moved to CPU,
|
||||
then each sample is searched individually against the database.
|
||||
|
||||
Args:
|
||||
processor: Image preprocessor.
|
||||
model: Feature extraction model.
|
||||
@@ -157,21 +131,17 @@ def _evaluate_recall(
|
||||
imgs = batch["img"]
|
||||
labels = batch["label"]
|
||||
|
||||
# 批量前向推理
|
||||
inputs = processor(imgs, return_tensors="pt")
|
||||
inputs.to(device)
|
||||
outputs = model(inputs) # [B, dim]
|
||||
outputs = model(inputs)
|
||||
|
||||
# 整个batch一次性转到CPU
|
||||
features = cast(torch.Tensor, outputs).cpu()
|
||||
labels_list = labels.tolist()
|
||||
|
||||
# 逐条搜索并验证
|
||||
for j in range(len(labels_list)):
|
||||
feature = features[j].tolist()
|
||||
true_label = labels_list[j]
|
||||
|
||||
# 搜索 top_k 最相似结果
|
||||
results = (
|
||||
table.search(feature)
|
||||
.select(["label", "_distance"])
|
||||
@@ -179,7 +149,6 @@ def _evaluate_recall(
|
||||
.to_polars()
|
||||
)
|
||||
|
||||
# 检查 top_k 中是否包含正确标签
|
||||
retrieved_labels = results["label"].to_list()
|
||||
if true_label in retrieved_labels:
|
||||
correct += 1
|
||||
@@ -188,69 +157,51 @@ def _evaluate_recall(
|
||||
return correct, total
|
||||
|
||||
|
||||
def task_eval(
|
||||
processor: BitImageProcessorFast,
|
||||
model: nn.Module,
|
||||
dataset: Literal["CIFAR-10", "CIFAR-100"],
|
||||
model_name: str,
|
||||
top_k: int = 10,
|
||||
batch_size: int = 64,
|
||||
) -> float:
|
||||
"""Evaluate model Recall@K accuracy on a dataset using vector retrieval.
|
||||
@RegisterTask("retrieval")
|
||||
class RetrievalTask(BaseBenchmarkTask):
|
||||
"""Retrieval evaluation task (Recall@K)."""
|
||||
|
||||
Workflow:
|
||||
1. Create or open a database table named by dataset and model.
|
||||
2. Build database from training set features (skip if table exists).
|
||||
3. Evaluate on test set: extract features in batches, search top_k,
|
||||
check if correct label appears in results.
|
||||
def __init__(self, top_k: int = 10):
|
||||
"""Initialize retrieval task.
|
||||
|
||||
Args:
|
||||
processor: Image preprocessor.
|
||||
model: Feature extraction model.
|
||||
dataset: Dataset name.
|
||||
model_name: Model name, used for table name generation.
|
||||
top_k: Number of top similar results to retrieve.
|
||||
batch_size: Batch size for DataLoader.
|
||||
Args:
|
||||
top_k: Number of top results to retrieve for recall calculation.
|
||||
"""
|
||||
super().__init__(top_k=top_k)
|
||||
self.top_k = top_k
|
||||
|
||||
Returns:
|
||||
Recall@K accuracy (0.0 ~ 1.0).
|
||||
def build_database(
|
||||
self,
|
||||
model: Any,
|
||||
processor: Any,
|
||||
train_dataset: Any,
|
||||
table: lancedb.table.Table,
|
||||
batch_size: int,
|
||||
) -> None:
|
||||
"""Build the evaluation database from training data.
|
||||
|
||||
Raises:
|
||||
ValueError: If dataset name is not supported.
|
||||
"""
|
||||
if dataset not in DATASET_CONFIG:
|
||||
raise ValueError(
|
||||
f"Unknown dataset: {dataset}. Only support: {list(DATASET_CONFIG.keys())}."
|
||||
)
|
||||
hf_id, img_col, label_col = DATASET_CONFIG[dataset]
|
||||
Args:
|
||||
model: Feature extraction model.
|
||||
processor: Image preprocessor.
|
||||
train_dataset: Training dataset.
|
||||
table: LanceDB table to store features.
|
||||
batch_size: Batch size for DataLoader.
|
||||
"""
|
||||
# Get a sample image to infer vector dimension
|
||||
sample = train_dataset[0]
|
||||
sample_image = sample["img"]
|
||||
|
||||
# 加载数据集
|
||||
train_dataset = load_dataset(hf_id, split="train")
|
||||
test_dataset = load_dataset(hf_id, split="test")
|
||||
vector_dim = _infer_vector_dim(processor, model, sample_image)
|
||||
expected_schema = _build_eval_schema(vector_dim)
|
||||
|
||||
# 生成表名,推断向量维度
|
||||
table_name = _get_table_name(dataset, model_name)
|
||||
vector_dim = _infer_vector_dim(processor, model, train_dataset[0][img_col])
|
||||
expected_schema = _build_eval_schema(vector_dim)
|
||||
existing_tables = db_manager.db.list_tables().tables
|
||||
# Check schema compatibility
|
||||
if table.schema != expected_schema:
|
||||
raise ValueError(
|
||||
f"Table schema mismatch. Expected: {expected_schema}, "
|
||||
f"Got: {table.schema}"
|
||||
)
|
||||
|
||||
# 如果旧表 schema 不匹配(如 label 类型变更),删除重建
|
||||
if table_name in existing_tables:
|
||||
old_table = db_manager.db.open_table(table_name)
|
||||
if old_table.schema != expected_schema:
|
||||
print(f"Table '{table_name}' schema mismatch, rebuilding.")
|
||||
db_manager.db.drop_table(table_name)
|
||||
existing_tables = []
|
||||
|
||||
if table_name in existing_tables:
|
||||
# 表已存在且 schema 匹配,跳过建库
|
||||
print(f"Table '{table_name}' already exists, skipping database build.")
|
||||
table = db_manager.db.open_table(table_name)
|
||||
else:
|
||||
# 创建新表
|
||||
table = db_manager.db.create_table(table_name, schema=expected_schema)
|
||||
|
||||
# 使用 DataLoader 批量建库
|
||||
# Build database
|
||||
train_loader = DataLoader(
|
||||
train_dataset.with_format("torch"),
|
||||
batch_size=batch_size,
|
||||
@@ -259,17 +210,45 @@ def task_eval(
|
||||
)
|
||||
_establish_eval_database(processor, model, table, train_loader)
|
||||
|
||||
# 使用 DataLoader 批量评估
|
||||
test_loader = DataLoader(
|
||||
test_dataset.with_format("torch"),
|
||||
batch_size=batch_size,
|
||||
shuffle=False,
|
||||
num_workers=4,
|
||||
)
|
||||
correct, total = _evaluate_recall(processor, model, table, test_loader, top_k)
|
||||
def evaluate(
|
||||
self,
|
||||
model: Any,
|
||||
processor: Any,
|
||||
test_dataset: Any,
|
||||
table: lancedb.table.Table,
|
||||
batch_size: int,
|
||||
) -> dict[str, Any]:
|
||||
"""Evaluate the model on the test dataset.
|
||||
|
||||
accuracy = correct / total
|
||||
print(f"\nRecall@{top_k} on {dataset} with {model_name}: {accuracy:.4f}")
|
||||
print(f"Correct: {correct}/{total}")
|
||||
Args:
|
||||
model: Feature extraction model.
|
||||
processor: Image preprocessor.
|
||||
test_dataset: Test dataset.
|
||||
table: LanceDB table to search against.
|
||||
batch_size: Batch size for DataLoader.
|
||||
|
||||
return accuracy
|
||||
Returns:
|
||||
Dictionary containing evaluation results with keys:
|
||||
- accuracy: Recall@K accuracy (0.0 ~ 1.0)
|
||||
- correct: Number of correct predictions
|
||||
- total: Total number of test samples
|
||||
- top_k: The K value used
|
||||
"""
|
||||
test_loader = DataLoader(
|
||||
test_dataset.with_format("torch"),
|
||||
batch_size=batch_size,
|
||||
shuffle=False,
|
||||
num_workers=4,
|
||||
)
|
||||
correct, total = _evaluate_recall(
|
||||
processor, model, table, test_loader, self.top_k
|
||||
)
|
||||
|
||||
accuracy = correct / total if total > 0 else 0.0
|
||||
|
||||
return {
|
||||
"accuracy": accuracy,
|
||||
"correct": correct,
|
||||
"total": total,
|
||||
"top_k": self.top_k,
|
||||
}
|
||||
Reference in New Issue
Block a user