mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
refactor(benchmarks): modularize benchmark system with config-driven execution
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -207,6 +207,7 @@ __marimo__/
|
|||||||
|
|
||||||
# Projects
|
# Projects
|
||||||
datasets/
|
datasets/
|
||||||
|
!mini-nav/**/datasets/
|
||||||
data/
|
data/
|
||||||
deps/
|
deps/
|
||||||
outputs/
|
outputs/
|
||||||
|
|||||||
@@ -1,59 +1,12 @@
|
|||||||
from typing import Literal, cast
|
"""Benchmark evaluation module.
|
||||||
|
|
||||||
import torch
|
This module provides a modular benchmark system that supports:
|
||||||
from compressors import DinoCompressor, FloatCompressor
|
- Multiple dataset sources (HuggingFace, local)
|
||||||
from configs import cfg_manager
|
- Multiple evaluation tasks (retrieval, with extensibility for more)
|
||||||
from transformers import AutoImageProcessor, BitImageProcessorFast
|
- Configuration-driven execution
|
||||||
from utils import get_device
|
"""
|
||||||
|
|
||||||
from .task_eval import task_eval
|
from .base import BaseBenchmarkTask, BaseDataset
|
||||||
|
from .runner import run_benchmark
|
||||||
|
|
||||||
|
__all__ = ["BaseBenchmarkTask", "BaseDataset", "run_benchmark"]
|
||||||
def evaluate(
|
|
||||||
compressor_model: Literal["Dinov2", "Dinov2WithCompressor"],
|
|
||||||
dataset: Literal["CIFAR-10", "CIFAR-100"],
|
|
||||||
benchmark: Literal["Recall@1", "Recall@10"],
|
|
||||||
):
|
|
||||||
"""运行模型评估。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
compressor_model: 压缩模型类型。
|
|
||||||
dataset: 数据集名称。
|
|
||||||
benchmark: 评估指标。
|
|
||||||
"""
|
|
||||||
device = get_device()
|
|
||||||
|
|
||||||
match compressor_model:
|
|
||||||
case "Dinov2":
|
|
||||||
processor = cast(
|
|
||||||
BitImageProcessorFast,
|
|
||||||
AutoImageProcessor.from_pretrained(
|
|
||||||
"facebook/dinov2-large", device_map=device
|
|
||||||
),
|
|
||||||
)
|
|
||||||
model = DinoCompressor().to(device)
|
|
||||||
case "Dinov2WithCompressor":
|
|
||||||
processor = cast(
|
|
||||||
BitImageProcessorFast,
|
|
||||||
AutoImageProcessor.from_pretrained(
|
|
||||||
"facebook/dinov2-large", device_map=device
|
|
||||||
),
|
|
||||||
)
|
|
||||||
output_dir = cfg_manager.get().output.directory
|
|
||||||
compressor = FloatCompressor()
|
|
||||||
compressor.load_state_dict(torch.load(output_dir / "compressor.pt"))
|
|
||||||
model = DinoCompressor(compressor).to(device)
|
|
||||||
case _:
|
|
||||||
raise ValueError(f"Unknown compressor: {compressor_model}")
|
|
||||||
|
|
||||||
# 根据 benchmark 确定 top_k
|
|
||||||
match benchmark:
|
|
||||||
case "Recall@1":
|
|
||||||
task_eval(processor, model, dataset, compressor_model, top_k=1)
|
|
||||||
case "Recall@10":
|
|
||||||
task_eval(processor, model, dataset, compressor_model, top_k=10)
|
|
||||||
case _:
|
|
||||||
raise ValueError(f"Unknown benchmark: {benchmark}")
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["task_eval", "evaluate"]
|
|
||||||
99
mini-nav/benchmarks/base.py
Normal file
99
mini-nav/benchmarks/base.py
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
"""Base classes for benchmark datasets and tasks."""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Protocol
|
||||||
|
|
||||||
|
import lancedb
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
|
||||||
|
class BaseDataset(ABC):
|
||||||
|
"""Abstract base class for benchmark datasets."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_train_split(self) -> Any:
|
||||||
|
"""Get training split of the dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dataset object for training.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_test_split(self) -> Any:
|
||||||
|
"""Get test/evaluation split of the dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dataset object for testing.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BaseBenchmarkTask(ABC):
|
||||||
|
"""Abstract base class for benchmark evaluation tasks."""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs: Any):
|
||||||
|
"""Initialize the benchmark task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**kwargs: Task-specific configuration parameters.
|
||||||
|
"""
|
||||||
|
self.config = kwargs
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def build_database(
|
||||||
|
self,
|
||||||
|
model: Any,
|
||||||
|
processor: Any,
|
||||||
|
train_dataset: Any,
|
||||||
|
table: lancedb.table.Table,
|
||||||
|
batch_size: int,
|
||||||
|
) -> None:
|
||||||
|
"""Build the evaluation database from training data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Feature extraction model.
|
||||||
|
processor: Image preprocessor.
|
||||||
|
train_dataset: Training dataset.
|
||||||
|
table: LanceDB table to store features.
|
||||||
|
batch_size: Batch size for DataLoader.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def evaluate(
|
||||||
|
self,
|
||||||
|
model: Any,
|
||||||
|
processor: Any,
|
||||||
|
test_dataset: Any,
|
||||||
|
table: lancedb.table.Table,
|
||||||
|
batch_size: int,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Evaluate the model on the test dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Feature extraction model.
|
||||||
|
processor: Image preprocessor.
|
||||||
|
test_dataset: Test dataset.
|
||||||
|
table: LanceDB table to search against.
|
||||||
|
batch_size: Batch size for DataLoader.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary containing evaluation results.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetFactory(Protocol):
|
||||||
|
"""Protocol for dataset factory."""
|
||||||
|
|
||||||
|
def __call__(self, config: Any) -> BaseDataset:
|
||||||
|
"""Create a dataset from configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Dataset configuration.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dataset instance.
|
||||||
|
"""
|
||||||
|
...
|
||||||
6
mini-nav/benchmarks/datasets/__init__.py
Normal file
6
mini-nav/benchmarks/datasets/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
"""Dataset loaders for benchmark evaluation."""
|
||||||
|
|
||||||
|
from .huggingface import HuggingFaceDataset
|
||||||
|
from .local import LocalDataset
|
||||||
|
|
||||||
|
__all__ = ["HuggingFaceDataset", "LocalDataset"]
|
||||||
66
mini-nav/benchmarks/datasets/huggingface.py
Normal file
66
mini-nav/benchmarks/datasets/huggingface.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
"""HuggingFace dataset loader for benchmark evaluation."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
|
from ..base import BaseDataset
|
||||||
|
|
||||||
|
|
||||||
|
class HuggingFaceDataset(BaseDataset):
|
||||||
|
"""Dataset loader for HuggingFace datasets."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hf_id: str,
|
||||||
|
img_column: str = "img",
|
||||||
|
label_column: str = "label",
|
||||||
|
):
|
||||||
|
"""Initialize HuggingFace dataset loader.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hf_id: HuggingFace dataset ID.
|
||||||
|
img_column: Name of the image column.
|
||||||
|
label_column: Name of the label column.
|
||||||
|
"""
|
||||||
|
self.hf_id = hf_id
|
||||||
|
self.img_column = img_column
|
||||||
|
self.label_column = label_column
|
||||||
|
self._train_dataset: Any = None
|
||||||
|
self._test_dataset: Any = None
|
||||||
|
|
||||||
|
def _load(self) -> tuple[Any, Any]:
|
||||||
|
"""Load dataset from HuggingFace.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (train_dataset, test_dataset).
|
||||||
|
"""
|
||||||
|
if self._train_dataset is None:
|
||||||
|
dataset = load_dataset(self.hf_id)
|
||||||
|
# Handle datasets that use 'train' and 'test' splits
|
||||||
|
if "train" in dataset:
|
||||||
|
self._train_dataset = dataset["train"]
|
||||||
|
if "test" in dataset:
|
||||||
|
self._test_dataset = dataset["test"]
|
||||||
|
# Handle datasets that use 'train' and 'validation' splits
|
||||||
|
elif "validation" in dataset:
|
||||||
|
self._test_dataset = dataset["validation"]
|
||||||
|
return self._train_dataset, self._test_dataset
|
||||||
|
|
||||||
|
def get_train_split(self) -> Any:
|
||||||
|
"""Get training split of the dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Training dataset.
|
||||||
|
"""
|
||||||
|
train, _ = self._load()
|
||||||
|
return train
|
||||||
|
|
||||||
|
def get_test_split(self) -> Any:
|
||||||
|
"""Get test/evaluation split of the dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Test dataset.
|
||||||
|
"""
|
||||||
|
_, test = self._load()
|
||||||
|
return test
|
||||||
157
mini-nav/benchmarks/datasets/local.py
Normal file
157
mini-nav/benchmarks/datasets/local.py
Normal file
@@ -0,0 +1,157 @@
|
|||||||
|
"""Local dataset loader for benchmark evaluation."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from ..base import BaseDataset
|
||||||
|
|
||||||
|
|
||||||
|
class LocalDataset(BaseDataset):
|
||||||
|
"""Dataset loader for local datasets."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
local_path: str,
|
||||||
|
img_column: str = "image_path",
|
||||||
|
label_column: str = "label",
|
||||||
|
):
|
||||||
|
"""Initialize local dataset loader.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
local_path: Path to local dataset directory or CSV file.
|
||||||
|
img_column: Name of the image path column.
|
||||||
|
label_column: Name of the label column.
|
||||||
|
"""
|
||||||
|
self.local_path = Path(local_path)
|
||||||
|
self.img_column = img_column
|
||||||
|
self.label_column = label_column
|
||||||
|
self._train_dataset: Optional[Any] = None
|
||||||
|
self._test_dataset: Optional[Any] = None
|
||||||
|
|
||||||
|
def _load_csv_dataset(self) -> tuple[Any, Any]:
|
||||||
|
"""Load dataset from CSV file.
|
||||||
|
|
||||||
|
Expected CSV format:
|
||||||
|
label,image_path,x1,y1,x2,y2
|
||||||
|
"class_name","path/to/image.jpg",100,200,300,400
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (train_dataset, test_dataset).
|
||||||
|
"""
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
from torch.utils.data import Dataset as TorchDataset
|
||||||
|
|
||||||
|
# Load CSV file
|
||||||
|
df = pd.read_csv(self.local_path)
|
||||||
|
|
||||||
|
# Create a simple dataset class
|
||||||
|
class CSVDataset(TorchDataset):
|
||||||
|
def __init__(self, dataframe: pd.DataFrame, img_col: str, label_col: str):
|
||||||
|
self.df = dataframe.reset_index(drop=True)
|
||||||
|
self.img_col = img_col
|
||||||
|
self.label_col = label_col
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self.df)
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int) -> dict[str, Any]:
|
||||||
|
row = self.df.iloc[idx]
|
||||||
|
return {
|
||||||
|
"img": row[self.img_col],
|
||||||
|
"label": row[self.label_col],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Split into train/test (80/20)
|
||||||
|
split_idx = int(len(df) * 0.8)
|
||||||
|
train_df = df.iloc[:split_idx]
|
||||||
|
test_df = df.iloc[split_idx:]
|
||||||
|
|
||||||
|
self._train_dataset = CSVDataset(train_df, self.img_column, self.label_column)
|
||||||
|
self._test_dataset = CSVDataset(test_df, self.img_column, self.label_column)
|
||||||
|
|
||||||
|
return self._train_dataset, self._test_dataset
|
||||||
|
|
||||||
|
def _load_directory_dataset(self) -> tuple[Any, Any]:
|
||||||
|
"""Load dataset from directory structure.
|
||||||
|
|
||||||
|
Expected structure:
|
||||||
|
local_path/
|
||||||
|
train/
|
||||||
|
class_name_1/
|
||||||
|
image1.jpg
|
||||||
|
image2.jpg
|
||||||
|
class_name_2/
|
||||||
|
image1.jpg
|
||||||
|
test/
|
||||||
|
class_name_1/
|
||||||
|
image1.jpg
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (train_dataset, test_dataset).
|
||||||
|
"""
|
||||||
|
from torch.utils.data import Dataset as TorchDataset
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
class DirectoryDataset(TorchDataset):
|
||||||
|
def __init__(self, root_dir: Path, transform=None):
|
||||||
|
self.root_dir = root_dir
|
||||||
|
self.transform = transform
|
||||||
|
self.samples = []
|
||||||
|
self.label_map = {}
|
||||||
|
|
||||||
|
# Build label map
|
||||||
|
classes = sorted([d.name for d in root_dir.iterdir() if d.is_dir()])
|
||||||
|
self.label_map = {cls: idx for idx, cls in enumerate(classes)}
|
||||||
|
|
||||||
|
# Build sample list
|
||||||
|
for cls_dir in root_dir.iterdir():
|
||||||
|
if cls_dir.is_dir():
|
||||||
|
label = self.label_map[cls_dir.name]
|
||||||
|
for img_path in cls_dir.iterdir():
|
||||||
|
if img_path.suffix.lower() in [".jpg", ".jpeg", ".png", ".bmp"]:
|
||||||
|
self.samples.append((img_path, label))
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self.samples)
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int) -> dict[str, Any]:
|
||||||
|
img_path, label = self.samples[idx]
|
||||||
|
image = Image.open(img_path).convert("RGB")
|
||||||
|
return {"img": image, "label": label}
|
||||||
|
|
||||||
|
train_dir = self.local_path / "train"
|
||||||
|
test_dir = self.local_path / "test"
|
||||||
|
|
||||||
|
if train_dir.exists():
|
||||||
|
self._train_dataset = DirectoryDataset(train_dir)
|
||||||
|
if test_dir.exists():
|
||||||
|
self._test_dataset = DirectoryDataset(test_dir)
|
||||||
|
|
||||||
|
return self._train_dataset, self._test_dataset
|
||||||
|
|
||||||
|
def get_train_split(self) -> Any:
|
||||||
|
"""Get training split of the dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Training dataset.
|
||||||
|
"""
|
||||||
|
if self._train_dataset is None:
|
||||||
|
if self.local_path.suffix.lower() == ".csv":
|
||||||
|
self._load_csv_dataset()
|
||||||
|
else:
|
||||||
|
self._load_directory_dataset()
|
||||||
|
return self._train_dataset
|
||||||
|
|
||||||
|
def get_test_split(self) -> Any:
|
||||||
|
"""Get test/evaluation split of the dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Test dataset.
|
||||||
|
"""
|
||||||
|
if self._test_dataset is None:
|
||||||
|
if self.local_path.suffix.lower() == ".csv":
|
||||||
|
self._load_csv_dataset()
|
||||||
|
else:
|
||||||
|
self._load_directory_dataset()
|
||||||
|
return self._test_dataset
|
||||||
186
mini-nav/benchmarks/runner.py
Normal file
186
mini-nav/benchmarks/runner.py
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
"""Benchmark runner for executing evaluations."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import lancedb
|
||||||
|
from benchmarks.datasets import HuggingFaceDataset, LocalDataset
|
||||||
|
from benchmarks.tasks import get_task
|
||||||
|
from configs.models import BenchmarkConfig, DatasetSourceConfig
|
||||||
|
|
||||||
|
|
||||||
|
def create_dataset(config: DatasetSourceConfig) -> Any:
|
||||||
|
"""Create a dataset instance from configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Dataset source configuration.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dataset instance.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If source_type is not supported.
|
||||||
|
"""
|
||||||
|
if config.source_type == "huggingface":
|
||||||
|
return HuggingFaceDataset(
|
||||||
|
hf_id=config.path,
|
||||||
|
img_column=config.img_column,
|
||||||
|
label_column=config.label_column,
|
||||||
|
)
|
||||||
|
elif config.source_type == "local":
|
||||||
|
return LocalDataset(
|
||||||
|
local_path=config.path,
|
||||||
|
img_column=config.img_column,
|
||||||
|
label_column=config.label_column,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported source_type: {config.source_type}. "
|
||||||
|
f"Supported types: 'huggingface', 'local'"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_table_name(config: BenchmarkConfig, model_name: str) -> str:
|
||||||
|
"""Generate database table name from config and model name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Benchmark configuration.
|
||||||
|
model_name: Model name for table naming.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted table name.
|
||||||
|
"""
|
||||||
|
prefix = config.model_table_prefix
|
||||||
|
# Use dataset path as part of table name (sanitize)
|
||||||
|
dataset_name = Path(config.dataset.path).name.lower().replace("-", "_")
|
||||||
|
return f"{prefix}_{dataset_name}_{model_name}"
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_table(
|
||||||
|
config: BenchmarkConfig,
|
||||||
|
model_name: str,
|
||||||
|
vector_dim: int,
|
||||||
|
) -> lancedb.table.Table:
|
||||||
|
"""Ensure the LanceDB table exists with correct schema.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Benchmark configuration.
|
||||||
|
model_name: Model name for table naming.
|
||||||
|
vector_dim: Feature vector dimension.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LanceDB table instance.
|
||||||
|
"""
|
||||||
|
import pyarrow as pa
|
||||||
|
from database import db_manager
|
||||||
|
|
||||||
|
table_name = _get_table_name(config, model_name)
|
||||||
|
|
||||||
|
# Build expected schema
|
||||||
|
schema = pa.schema(
|
||||||
|
[
|
||||||
|
pa.field("id", pa.int32()),
|
||||||
|
pa.field("label", pa.int32()),
|
||||||
|
pa.field("vector", pa.list_(pa.float32(), vector_dim)),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
db = db_manager.db
|
||||||
|
existing_tables = db.list_tables().tables
|
||||||
|
|
||||||
|
# Check if table exists and has correct schema
|
||||||
|
if table_name in existing_tables:
|
||||||
|
table = db.open_table(table_name)
|
||||||
|
if table.schema != schema:
|
||||||
|
print(f"Table '{table_name}' schema mismatch, rebuilding.")
|
||||||
|
db.drop_table(table_name)
|
||||||
|
table = db.create_table(table_name, schema=schema)
|
||||||
|
else:
|
||||||
|
table = db.create_table(table_name, schema=schema)
|
||||||
|
|
||||||
|
return table
|
||||||
|
|
||||||
|
|
||||||
|
def run_benchmark(
|
||||||
|
model: Any,
|
||||||
|
processor: Any,
|
||||||
|
config: BenchmarkConfig,
|
||||||
|
model_name: str = "model",
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Run benchmark evaluation.
|
||||||
|
|
||||||
|
Workflow:
|
||||||
|
1. Create dataset from configuration
|
||||||
|
2. Create benchmark task from configuration
|
||||||
|
3. Build evaluation database from training set
|
||||||
|
4. Evaluate on test set
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Feature extraction model.
|
||||||
|
processor: Image preprocessor.
|
||||||
|
config: Benchmark configuration.
|
||||||
|
model_name: Model name for table naming.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary containing evaluation results.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If benchmark is not enabled in config.
|
||||||
|
"""
|
||||||
|
if not config.enabled:
|
||||||
|
raise ValueError(
|
||||||
|
"Benchmark is not enabled. Set benchmark.enabled=true in config.yaml"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create dataset
|
||||||
|
print(f"Loading dataset: {config.dataset.source_type} - {config.dataset.path}")
|
||||||
|
dataset = create_dataset(config.dataset)
|
||||||
|
|
||||||
|
# Get train and test splits
|
||||||
|
train_dataset = dataset.get_train_split()
|
||||||
|
test_dataset = dataset.get_test_split()
|
||||||
|
|
||||||
|
if train_dataset is None or test_dataset is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Dataset {config.dataset.path} does not have train/test splits"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Infer vector dimension from a sample
|
||||||
|
sample = train_dataset[0]
|
||||||
|
sample_image = sample["img"]
|
||||||
|
|
||||||
|
from .tasks.retrieval import _infer_vector_dim
|
||||||
|
|
||||||
|
vector_dim = _infer_vector_dim(processor, model, sample_image)
|
||||||
|
print(f"Model output dimension: {vector_dim}")
|
||||||
|
|
||||||
|
# Ensure table exists with correct schema
|
||||||
|
table = _ensure_table(config, model_name, vector_dim)
|
||||||
|
table_name = _get_table_name(config, model_name)
|
||||||
|
|
||||||
|
# Check if database is already built
|
||||||
|
table_count = table.count_rows()
|
||||||
|
if table_count > 0:
|
||||||
|
print(
|
||||||
|
f"Table '{table_name}' already has {table_count} entries, skipping database build."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Create and run benchmark task
|
||||||
|
task = get_task(config.task.type, top_k=config.task.top_k)
|
||||||
|
print(f"Building database with {len(train_dataset)} training samples...")
|
||||||
|
task.build_database(model, processor, train_dataset, table, config.batch_size)
|
||||||
|
|
||||||
|
# Run evaluation
|
||||||
|
task = get_task(config.task.type, top_k=config.task.top_k)
|
||||||
|
print(f"Evaluating on {len(test_dataset)} test samples...")
|
||||||
|
results = task.evaluate(model, processor, test_dataset, table, config.batch_size)
|
||||||
|
|
||||||
|
# Print results
|
||||||
|
print("\n=== Benchmark Results ===")
|
||||||
|
print(f"Dataset: {config.dataset.path}")
|
||||||
|
print(f"Task: {config.task.name}")
|
||||||
|
print(f"Top-K: {results['top_k']}")
|
||||||
|
print(f"Accuracy: {results['accuracy']:.4f}")
|
||||||
|
print(f"Correct: {results['correct']}/{results['total']}")
|
||||||
|
|
||||||
|
return results
|
||||||
6
mini-nav/benchmarks/tasks/__init__.py
Normal file
6
mini-nav/benchmarks/tasks/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
"""Benchmark evaluation tasks."""
|
||||||
|
|
||||||
|
from .retrieval import RetrievalTask
|
||||||
|
from .registry import TASK_REGISTRY, get_task
|
||||||
|
|
||||||
|
__all__ = ["RetrievalTask", "TASK_REGISTRY", "get_task"]
|
||||||
59
mini-nav/benchmarks/tasks/registry.py
Normal file
59
mini-nav/benchmarks/tasks/registry.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
"""Task registry for benchmark evaluation."""
|
||||||
|
|
||||||
|
from typing import Any, Type
|
||||||
|
|
||||||
|
from benchmarks.base import BaseBenchmarkTask
|
||||||
|
|
||||||
|
# Task registry: maps task type string to task class
|
||||||
|
TASK_REGISTRY: dict[str, Type[BaseBenchmarkTask]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
class RegisterTask:
|
||||||
|
"""Decorator class to register a benchmark task.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
@register_task("retrieval")
|
||||||
|
class RetrievalTask(BaseBenchmarkTask):
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, task_type: str):
|
||||||
|
"""Initialize the decorator with task type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_type: Task type identifier.
|
||||||
|
"""
|
||||||
|
self.task_type = task_type
|
||||||
|
|
||||||
|
def __call__(self, cls: type[BaseBenchmarkTask]) -> type[BaseBenchmarkTask]:
|
||||||
|
"""Register the decorated class to the task registry.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cls: The class to be decorated.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The unmodified class.
|
||||||
|
"""
|
||||||
|
TASK_REGISTRY[self.task_type] = cls
|
||||||
|
return cls
|
||||||
|
|
||||||
|
|
||||||
|
def get_task(task_type: str, **kwargs: Any) -> BaseBenchmarkTask:
|
||||||
|
"""Get a benchmark task instance by type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_type: Task type identifier.
|
||||||
|
**kwargs: Additional arguments passed to task constructor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Task instance.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If task type is not registered.
|
||||||
|
"""
|
||||||
|
if task_type not in TASK_REGISTRY:
|
||||||
|
available = list(TASK_REGISTRY.keys())
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown task type: {task_type}. Available tasks: {available}"
|
||||||
|
)
|
||||||
|
return TASK_REGISTRY[task_type](**kwargs)
|
||||||
@@ -1,41 +1,22 @@
|
|||||||
from typing import Dict, Literal, cast
|
"""Retrieval task for benchmark evaluation (Recall@K)."""
|
||||||
|
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
import lancedb
|
import lancedb
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
import torch
|
import torch
|
||||||
from database import db_manager
|
from benchmarks.base import BaseBenchmarkTask
|
||||||
from datasets import load_dataset
|
from benchmarks.tasks.registry import RegisterTask
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
from transformers import BitImageProcessorFast
|
from transformers import BitImageProcessorFast
|
||||||
|
|
||||||
# 数据集配置:数据集名称 -> (HuggingFace ID, 图片列名, 标签列名)
|
|
||||||
DATASET_CONFIG: Dict[str, tuple[str, str, str]] = {
|
|
||||||
"CIFAR-10": ("uoft-cs/cifar10", "img", "label"),
|
|
||||||
"CIFAR-100": ("uoft-cs/cifar100", "img", "fine_label"),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _get_table_name(dataset: str, model_name: str) -> str:
|
|
||||||
"""Generate database table name from dataset and model name.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dataset: Dataset name, e.g. "CIFAR-10".
|
|
||||||
model_name: Model name, e.g. "Dinov2".
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Formatted table name, e.g. "cifar10_dinov2".
|
|
||||||
"""
|
|
||||||
ds_part = dataset.lower().replace("-", "")
|
|
||||||
model_part = model_name.lower()
|
|
||||||
return f"{ds_part}_{model_part}"
|
|
||||||
|
|
||||||
|
|
||||||
def _infer_vector_dim(
|
def _infer_vector_dim(
|
||||||
processor: BitImageProcessorFast,
|
processor: BitImageProcessorFast,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
sample_image,
|
sample_image: Any,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Infer model output vector dimension via a single forward pass.
|
"""Infer model output vector dimension via a single forward pass.
|
||||||
|
|
||||||
@@ -55,7 +36,6 @@ def _infer_vector_dim(
|
|||||||
inputs.to(device)
|
inputs.to(device)
|
||||||
output = model(inputs)
|
output = model(inputs)
|
||||||
|
|
||||||
# output shape: [1, dim]
|
|
||||||
return output.shape[-1]
|
return output.shape[-1]
|
||||||
|
|
||||||
|
|
||||||
@@ -100,16 +80,13 @@ def _establish_eval_database(
|
|||||||
imgs = batch["img"]
|
imgs = batch["img"]
|
||||||
labels = batch["label"]
|
labels = batch["label"]
|
||||||
|
|
||||||
# 预处理并推理
|
|
||||||
inputs = processor(imgs, return_tensors="pt")
|
inputs = processor(imgs, return_tensors="pt")
|
||||||
inputs.to(device)
|
inputs.to(device)
|
||||||
outputs = model(inputs) # [B, dim]
|
outputs = model(inputs)
|
||||||
|
|
||||||
# 整个batch一次性转到CPU
|
|
||||||
features = cast(torch.Tensor, outputs).cpu()
|
features = cast(torch.Tensor, outputs).cpu()
|
||||||
labels_list = labels.tolist()
|
labels_list = labels.tolist()
|
||||||
|
|
||||||
# 逐条写入数据库
|
|
||||||
batch_size = len(labels_list)
|
batch_size = len(labels_list)
|
||||||
table.add(
|
table.add(
|
||||||
[
|
[
|
||||||
@@ -134,9 +111,6 @@ def _evaluate_recall(
|
|||||||
) -> tuple[int, int]:
|
) -> tuple[int, int]:
|
||||||
"""Evaluate Recall@K by searching the database for each test image.
|
"""Evaluate Recall@K by searching the database for each test image.
|
||||||
|
|
||||||
For each batch, features are extracted in one forward pass and moved to CPU,
|
|
||||||
then each sample is searched individually against the database.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
processor: Image preprocessor.
|
processor: Image preprocessor.
|
||||||
model: Feature extraction model.
|
model: Feature extraction model.
|
||||||
@@ -157,21 +131,17 @@ def _evaluate_recall(
|
|||||||
imgs = batch["img"]
|
imgs = batch["img"]
|
||||||
labels = batch["label"]
|
labels = batch["label"]
|
||||||
|
|
||||||
# 批量前向推理
|
|
||||||
inputs = processor(imgs, return_tensors="pt")
|
inputs = processor(imgs, return_tensors="pt")
|
||||||
inputs.to(device)
|
inputs.to(device)
|
||||||
outputs = model(inputs) # [B, dim]
|
outputs = model(inputs)
|
||||||
|
|
||||||
# 整个batch一次性转到CPU
|
|
||||||
features = cast(torch.Tensor, outputs).cpu()
|
features = cast(torch.Tensor, outputs).cpu()
|
||||||
labels_list = labels.tolist()
|
labels_list = labels.tolist()
|
||||||
|
|
||||||
# 逐条搜索并验证
|
|
||||||
for j in range(len(labels_list)):
|
for j in range(len(labels_list)):
|
||||||
feature = features[j].tolist()
|
feature = features[j].tolist()
|
||||||
true_label = labels_list[j]
|
true_label = labels_list[j]
|
||||||
|
|
||||||
# 搜索 top_k 最相似结果
|
|
||||||
results = (
|
results = (
|
||||||
table.search(feature)
|
table.search(feature)
|
||||||
.select(["label", "_distance"])
|
.select(["label", "_distance"])
|
||||||
@@ -179,7 +149,6 @@ def _evaluate_recall(
|
|||||||
.to_polars()
|
.to_polars()
|
||||||
)
|
)
|
||||||
|
|
||||||
# 检查 top_k 中是否包含正确标签
|
|
||||||
retrieved_labels = results["label"].to_list()
|
retrieved_labels = results["label"].to_list()
|
||||||
if true_label in retrieved_labels:
|
if true_label in retrieved_labels:
|
||||||
correct += 1
|
correct += 1
|
||||||
@@ -188,69 +157,51 @@ def _evaluate_recall(
|
|||||||
return correct, total
|
return correct, total
|
||||||
|
|
||||||
|
|
||||||
def task_eval(
|
@RegisterTask("retrieval")
|
||||||
processor: BitImageProcessorFast,
|
class RetrievalTask(BaseBenchmarkTask):
|
||||||
model: nn.Module,
|
"""Retrieval evaluation task (Recall@K)."""
|
||||||
dataset: Literal["CIFAR-10", "CIFAR-100"],
|
|
||||||
model_name: str,
|
|
||||||
top_k: int = 10,
|
|
||||||
batch_size: int = 64,
|
|
||||||
) -> float:
|
|
||||||
"""Evaluate model Recall@K accuracy on a dataset using vector retrieval.
|
|
||||||
|
|
||||||
Workflow:
|
def __init__(self, top_k: int = 10):
|
||||||
1. Create or open a database table named by dataset and model.
|
"""Initialize retrieval task.
|
||||||
2. Build database from training set features (skip if table exists).
|
|
||||||
3. Evaluate on test set: extract features in batches, search top_k,
|
|
||||||
check if correct label appears in results.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
processor: Image preprocessor.
|
top_k: Number of top results to retrieve for recall calculation.
|
||||||
model: Feature extraction model.
|
"""
|
||||||
dataset: Dataset name.
|
super().__init__(top_k=top_k)
|
||||||
model_name: Model name, used for table name generation.
|
self.top_k = top_k
|
||||||
top_k: Number of top similar results to retrieve.
|
|
||||||
batch_size: Batch size for DataLoader.
|
|
||||||
|
|
||||||
Returns:
|
def build_database(
|
||||||
Recall@K accuracy (0.0 ~ 1.0).
|
self,
|
||||||
|
model: Any,
|
||||||
|
processor: Any,
|
||||||
|
train_dataset: Any,
|
||||||
|
table: lancedb.table.Table,
|
||||||
|
batch_size: int,
|
||||||
|
) -> None:
|
||||||
|
"""Build the evaluation database from training data.
|
||||||
|
|
||||||
Raises:
|
Args:
|
||||||
ValueError: If dataset name is not supported.
|
model: Feature extraction model.
|
||||||
"""
|
processor: Image preprocessor.
|
||||||
if dataset not in DATASET_CONFIG:
|
train_dataset: Training dataset.
|
||||||
raise ValueError(
|
table: LanceDB table to store features.
|
||||||
f"Unknown dataset: {dataset}. Only support: {list(DATASET_CONFIG.keys())}."
|
batch_size: Batch size for DataLoader.
|
||||||
)
|
"""
|
||||||
hf_id, img_col, label_col = DATASET_CONFIG[dataset]
|
# Get a sample image to infer vector dimension
|
||||||
|
sample = train_dataset[0]
|
||||||
|
sample_image = sample["img"]
|
||||||
|
|
||||||
# 加载数据集
|
vector_dim = _infer_vector_dim(processor, model, sample_image)
|
||||||
train_dataset = load_dataset(hf_id, split="train")
|
expected_schema = _build_eval_schema(vector_dim)
|
||||||
test_dataset = load_dataset(hf_id, split="test")
|
|
||||||
|
|
||||||
# 生成表名,推断向量维度
|
# Check schema compatibility
|
||||||
table_name = _get_table_name(dataset, model_name)
|
if table.schema != expected_schema:
|
||||||
vector_dim = _infer_vector_dim(processor, model, train_dataset[0][img_col])
|
raise ValueError(
|
||||||
expected_schema = _build_eval_schema(vector_dim)
|
f"Table schema mismatch. Expected: {expected_schema}, "
|
||||||
existing_tables = db_manager.db.list_tables().tables
|
f"Got: {table.schema}"
|
||||||
|
)
|
||||||
|
|
||||||
# 如果旧表 schema 不匹配(如 label 类型变更),删除重建
|
# Build database
|
||||||
if table_name in existing_tables:
|
|
||||||
old_table = db_manager.db.open_table(table_name)
|
|
||||||
if old_table.schema != expected_schema:
|
|
||||||
print(f"Table '{table_name}' schema mismatch, rebuilding.")
|
|
||||||
db_manager.db.drop_table(table_name)
|
|
||||||
existing_tables = []
|
|
||||||
|
|
||||||
if table_name in existing_tables:
|
|
||||||
# 表已存在且 schema 匹配,跳过建库
|
|
||||||
print(f"Table '{table_name}' already exists, skipping database build.")
|
|
||||||
table = db_manager.db.open_table(table_name)
|
|
||||||
else:
|
|
||||||
# 创建新表
|
|
||||||
table = db_manager.db.create_table(table_name, schema=expected_schema)
|
|
||||||
|
|
||||||
# 使用 DataLoader 批量建库
|
|
||||||
train_loader = DataLoader(
|
train_loader = DataLoader(
|
||||||
train_dataset.with_format("torch"),
|
train_dataset.with_format("torch"),
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
@@ -259,17 +210,45 @@ def task_eval(
|
|||||||
)
|
)
|
||||||
_establish_eval_database(processor, model, table, train_loader)
|
_establish_eval_database(processor, model, table, train_loader)
|
||||||
|
|
||||||
# 使用 DataLoader 批量评估
|
def evaluate(
|
||||||
test_loader = DataLoader(
|
self,
|
||||||
test_dataset.with_format("torch"),
|
model: Any,
|
||||||
batch_size=batch_size,
|
processor: Any,
|
||||||
shuffle=False,
|
test_dataset: Any,
|
||||||
num_workers=4,
|
table: lancedb.table.Table,
|
||||||
)
|
batch_size: int,
|
||||||
correct, total = _evaluate_recall(processor, model, table, test_loader, top_k)
|
) -> dict[str, Any]:
|
||||||
|
"""Evaluate the model on the test dataset.
|
||||||
|
|
||||||
accuracy = correct / total
|
Args:
|
||||||
print(f"\nRecall@{top_k} on {dataset} with {model_name}: {accuracy:.4f}")
|
model: Feature extraction model.
|
||||||
print(f"Correct: {correct}/{total}")
|
processor: Image preprocessor.
|
||||||
|
test_dataset: Test dataset.
|
||||||
|
table: LanceDB table to search against.
|
||||||
|
batch_size: Batch size for DataLoader.
|
||||||
|
|
||||||
return accuracy
|
Returns:
|
||||||
|
Dictionary containing evaluation results with keys:
|
||||||
|
- accuracy: Recall@K accuracy (0.0 ~ 1.0)
|
||||||
|
- correct: Number of correct predictions
|
||||||
|
- total: Total number of test samples
|
||||||
|
- top_k: The K value used
|
||||||
|
"""
|
||||||
|
test_loader = DataLoader(
|
||||||
|
test_dataset.with_format("torch"),
|
||||||
|
batch_size=batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=4,
|
||||||
|
)
|
||||||
|
correct, total = _evaluate_recall(
|
||||||
|
processor, model, table, test_loader, self.top_k
|
||||||
|
)
|
||||||
|
|
||||||
|
accuracy = correct / total if total > 0 else 0.0
|
||||||
|
|
||||||
|
return {
|
||||||
|
"accuracy": accuracy,
|
||||||
|
"correct": correct,
|
||||||
|
"total": total,
|
||||||
|
"top_k": self.top_k,
|
||||||
|
}
|
||||||
@@ -2,10 +2,10 @@ model:
|
|||||||
name: "facebook/dinov2-large"
|
name: "facebook/dinov2-large"
|
||||||
compression_dim: 512
|
compression_dim: 512
|
||||||
device: "auto" # auto-detect GPU
|
device: "auto" # auto-detect GPU
|
||||||
sam_model: "facebook/sam2.1-hiera-large" # SAM model name
|
sam_model: "facebook/sam2.1-hiera-large" # SAM model name
|
||||||
sam_min_mask_area: 100 # Minimum mask area threshold
|
sam_min_mask_area: 100 # Minimum mask area threshold
|
||||||
sam_max_masks: 10 # Maximum number of masks to keep
|
sam_max_masks: 10 # Maximum number of masks to keep
|
||||||
compressor_path: null # Path to trained HashCompressor weights (optional)
|
compressor_path: null # Path to trained HashCompressor weights (optional)
|
||||||
|
|
||||||
output:
|
output:
|
||||||
directory: "./outputs"
|
directory: "./outputs"
|
||||||
@@ -19,3 +19,17 @@ dataset:
|
|||||||
rotation_range: [-30, 30]
|
rotation_range: [-30, 30]
|
||||||
overlap_threshold: 0.3
|
overlap_threshold: 0.3
|
||||||
seed: 42
|
seed: 42
|
||||||
|
|
||||||
|
benchmark:
|
||||||
|
enabled: true
|
||||||
|
dataset:
|
||||||
|
source_type: "huggingface"
|
||||||
|
path: "uoft-cs/cifar10"
|
||||||
|
img_column: "img"
|
||||||
|
label_column: "label"
|
||||||
|
task:
|
||||||
|
name: "recall_at_k"
|
||||||
|
type: "retrieval"
|
||||||
|
top_k: 10
|
||||||
|
batch_size: 64
|
||||||
|
model_table_prefix: "benchmark"
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
"""Pydantic data models for feature compressor configuration."""
|
"""Pydantic data models for feature compressor configuration."""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Literal, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||||
|
|
||||||
@@ -98,6 +98,41 @@ class DatasetConfig(BaseModel):
|
|||||||
return v
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetSourceConfig(BaseModel):
|
||||||
|
"""Configuration for benchmark dataset source."""
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="ignore")
|
||||||
|
|
||||||
|
source_type: Literal["huggingface", "local"] = "huggingface"
|
||||||
|
path: str = Field(default="", description="HuggingFace dataset ID or local path")
|
||||||
|
img_column: str = Field(default="img", description="Image column name")
|
||||||
|
label_column: str = Field(default="label", description="Label column name")
|
||||||
|
|
||||||
|
|
||||||
|
class BenchmarkTaskConfig(BaseModel):
|
||||||
|
"""Configuration for benchmark task."""
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="ignore")
|
||||||
|
|
||||||
|
name: str = Field(default="recall_at_k", description="Task name")
|
||||||
|
type: str = Field(default="retrieval", description="Task type")
|
||||||
|
top_k: int = Field(default=10, gt=0, description="Top K for recall evaluation")
|
||||||
|
|
||||||
|
|
||||||
|
class BenchmarkConfig(BaseModel):
|
||||||
|
"""Configuration for benchmark evaluation."""
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="ignore")
|
||||||
|
|
||||||
|
enabled: bool = Field(default=False, description="Enable benchmark evaluation")
|
||||||
|
dataset: DatasetSourceConfig = Field(default_factory=DatasetSourceConfig)
|
||||||
|
task: BenchmarkTaskConfig = Field(default_factory=BenchmarkTaskConfig)
|
||||||
|
batch_size: int = Field(default=64, gt=0, description="Batch size for DataLoader")
|
||||||
|
model_table_prefix: str = Field(
|
||||||
|
default="benchmark", description="Prefix for LanceDB table names"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Config(BaseModel):
|
class Config(BaseModel):
|
||||||
"""Root configuration for the feature compressor."""
|
"""Root configuration for the feature compressor."""
|
||||||
|
|
||||||
@@ -106,3 +141,4 @@ class Config(BaseModel):
|
|||||||
model: ModelConfig = Field(default_factory=ModelConfig)
|
model: ModelConfig = Field(default_factory=ModelConfig)
|
||||||
output: OutputConfig = Field(default_factory=OutputConfig)
|
output: OutputConfig = Field(default_factory=OutputConfig)
|
||||||
dataset: DatasetConfig = Field(default_factory=DatasetConfig)
|
dataset: DatasetConfig = Field(default_factory=DatasetConfig)
|
||||||
|
benchmark: BenchmarkConfig = Field(default_factory=BenchmarkConfig)
|
||||||
|
|||||||
@@ -16,9 +16,51 @@ if __name__ == "__main__":
|
|||||||
epoch_size=10, batch_size=64, lr=1e-4, checkpoint_path="hash_checkpoint.pt"
|
epoch_size=10, batch_size=64, lr=1e-4, checkpoint_path="hash_checkpoint.pt"
|
||||||
)
|
)
|
||||||
elif args.action == "benchmark":
|
elif args.action == "benchmark":
|
||||||
from benchmarks import evaluate
|
from typing import cast
|
||||||
|
|
||||||
evaluate("Dinov2", "CIFAR-10", "Recall@10")
|
import torch
|
||||||
|
from benchmarks import run_benchmark
|
||||||
|
from compressors import DinoCompressor
|
||||||
|
from configs import cfg_manager
|
||||||
|
from transformers import AutoImageProcessor, BitImageProcessorFast
|
||||||
|
from utils import get_device
|
||||||
|
|
||||||
|
config = cfg_manager.get()
|
||||||
|
benchmark_cfg = config.benchmark
|
||||||
|
|
||||||
|
if not benchmark_cfg.enabled:
|
||||||
|
print("Benchmark is not enabled. Set benchmark.enabled=true in config.yaml")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
device = get_device()
|
||||||
|
|
||||||
|
# Load model and processor based on config
|
||||||
|
model_cfg = config.model
|
||||||
|
processor = cast(
|
||||||
|
BitImageProcessorFast,
|
||||||
|
AutoImageProcessor.from_pretrained(model_cfg.name, device_map=device),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load compressor weights if specified in model config
|
||||||
|
model = DinoCompressor().to(device)
|
||||||
|
if model_cfg.compressor_path is not None:
|
||||||
|
from compressors import HashCompressor
|
||||||
|
|
||||||
|
compressor = HashCompressor(
|
||||||
|
input_dim=model_cfg.compression_dim,
|
||||||
|
output_dim=model_cfg.compression_dim,
|
||||||
|
)
|
||||||
|
compressor.load_state_dict(torch.load(model_cfg.compressor_path))
|
||||||
|
# Wrap with compressor if path is specified
|
||||||
|
model.compressor = compressor
|
||||||
|
|
||||||
|
# Run benchmark
|
||||||
|
run_benchmark(
|
||||||
|
model=model,
|
||||||
|
processor=processor,
|
||||||
|
config=benchmark_cfg,
|
||||||
|
model_name="dinov2",
|
||||||
|
)
|
||||||
elif args.action == "visualize":
|
elif args.action == "visualize":
|
||||||
from visualizer import app
|
from visualizer import app
|
||||||
|
|
||||||
|
|||||||
@@ -1,26 +1,21 @@
|
|||||||
"""Tests for compressor modules (SAM, DINO, HashCompressor, Pipeline)."""
|
"""Tests for compressor modules (SAM, DINO, HashCompressor, Pipeline)."""
|
||||||
|
|
||||||
import tempfile
|
|
||||||
from pathlib import Path
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from configs import cfg_manager
|
|
||||||
from compressors import (
|
from compressors import (
|
||||||
BinarySign,
|
BinarySign,
|
||||||
DinoCompressor,
|
DinoCompressor,
|
||||||
HashCompressor,
|
HashCompressor,
|
||||||
SegmentCompressor,
|
|
||||||
SAMHashPipeline,
|
SAMHashPipeline,
|
||||||
create_pipeline_from_config,
|
SegmentCompressor,
|
||||||
bits_to_hash,
|
bits_to_hash,
|
||||||
hash_to_bits,
|
create_pipeline_from_config,
|
||||||
hamming_distance,
|
hamming_distance,
|
||||||
hamming_similarity,
|
hamming_similarity,
|
||||||
|
hash_to_bits,
|
||||||
)
|
)
|
||||||
|
from configs import cfg_manager
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
class TestHashCompressor:
|
class TestHashCompressor:
|
||||||
|
|||||||
Reference in New Issue
Block a user