From 5f9d2bfcd870b0e1502048a2fb147785e0db10e3 Mon Sep 17 00:00:00 2001 From: SikongJueluo Date: Tue, 10 Feb 2026 17:01:51 +0800 Subject: [PATCH] refactor(benchmarks): overhaul evaluation pipeline with LanceDB integration --- mini-nav/benchmarks/__init__.py | 32 ++-- mini-nav/benchmarks/task_eval.py | 290 ++++++++++++++++++++++++++----- 2 files changed, 265 insertions(+), 57 deletions(-) diff --git a/mini-nav/benchmarks/__init__.py b/mini-nav/benchmarks/__init__.py index 62413a7..06f74c5 100644 --- a/mini-nav/benchmarks/__init__.py +++ b/mini-nav/benchmarks/__init__.py @@ -2,8 +2,9 @@ from typing import Literal, cast import torch from compressors import DinoCompressor, FloatCompressor +from configs import cfg_manager from transformers import AutoImageProcessor, BitImageProcessorFast -from utils import get_device, get_output_diretory +from utils import get_device from .task_eval import task_eval @@ -13,35 +14,44 @@ def evaluate( dataset: Literal["CIFAR-10", "CIFAR-100"], benchmark: Literal["Recall@1", "Recall@10"], ): + """运行模型评估。 + + Args: + compressor_model: 压缩模型类型。 + dataset: 数据集名称。 + benchmark: 评估指标。 + """ + device = get_device() + match compressor_model: case "Dinov2": processor = cast( BitImageProcessorFast, AutoImageProcessor.from_pretrained( - "facebook/dinov2-large", device_map=get_device() + "facebook/dinov2-large", device_map=device ), ) - model = DinoCompressor().to(get_device()) + model = DinoCompressor().to(device) case "Dinov2WithCompressor": processor = cast( BitImageProcessorFast, AutoImageProcessor.from_pretrained( - "facebook/dinov2-large", device_map=get_device() + "facebook/dinov2-large", device_map=device ), ) - - compressor = FloatCompressor().load_state_dict( - torch.load(get_output_diretory() / "compressor.pt") - ) - model = DinoCompressor(compressor).to(get_device()) + output_dir = cfg_manager.get().output.directory + compressor = FloatCompressor() + compressor.load_state_dict(torch.load(output_dir / "compressor.pt")) + model = DinoCompressor(compressor).to(device) case _: raise ValueError(f"Unknown compressor: {compressor_model}") + # 根据 benchmark 确定 top_k match benchmark: case "Recall@1": - task_eval(processor, model, dataset, 1) + task_eval(processor, model, dataset, compressor_model, top_k=1) case "Recall@10": - task_eval(processor, model, dataset, 10) + task_eval(processor, model, dataset, compressor_model, top_k=10) case _: raise ValueError(f"Unknown benchmark: {benchmark}") diff --git a/mini-nav/benchmarks/task_eval.py b/mini-nav/benchmarks/task_eval.py index 7775652..5009587 100644 --- a/mini-nav/benchmarks/task_eval.py +++ b/mini-nav/benchmarks/task_eval.py @@ -1,77 +1,275 @@ -from typing import Literal, cast +from typing import Dict, Literal, cast -import polars as pl +import lancedb +import pyarrow as pa import torch -from datasets import Dataset, load_dataset -from torch import Tensor, nn +from database import db_manager +from datasets import load_dataset +from torch import nn from torch.utils.data import DataLoader from tqdm.auto import tqdm from transformers import BitImageProcessorFast -from utils import get_device + +# 数据集配置:数据集名称 -> (HuggingFace ID, 图片列名, 标签列名) +DATASET_CONFIG: Dict[str, tuple[str, str, str]] = { + "CIFAR-10": ("uoft-cs/cifar10", "img", "label"), + "CIFAR-100": ("uoft-cs/cifar100", "img", "fine_label"), +} -def establish_database( +def _get_table_name(dataset: str, model_name: str) -> str: + """Generate database table name from dataset and model name. + + Args: + dataset: Dataset name, e.g. "CIFAR-10". + model_name: Model name, e.g. "Dinov2". + + Returns: + Formatted table name, e.g. "cifar10_dinov2". + """ + ds_part = dataset.lower().replace("-", "") + model_part = model_name.lower() + return f"{ds_part}_{model_part}" + + +def _infer_vector_dim( processor: BitImageProcessorFast, model: nn.Module, - dataset: Dataset, - batch_size: int = 32, -) -> pl.DataFrame: - df = pl.DataFrame() + sample_image, +) -> int: + """Infer model output vector dimension via a single forward pass. + Args: + processor: Image preprocessor. + model: Feature extraction model. + sample_image: A sample image for dimension inference. + + Returns: + Vector dimension. + """ + device = next(model.parameters()).device model.eval() - dataloader = DataLoader( - dataset.with_format("torch"), batch_size=batch_size, shuffle=True, num_workers=4 - ) with torch.no_grad(): - for batch in tqdm(dataloader, desc="Establish Database"): - imgs = batch["img"] - labels = batch["label"] + inputs = processor(images=sample_image, return_tensors="pt") + inputs.to(device) + output = model(inputs) - inputs = processor(imgs, return_tensors="pt").to(get_device()) + # output shape: [1, dim] + return output.shape[-1] - outputs = cast(Tensor, model(inputs)) - return df +def _build_eval_schema(vector_dim: int) -> pa.Schema: + """Build PyArrow schema for evaluation database table. + + Args: + vector_dim: Feature vector dimension. + + Returns: + PyArrow schema with id, label, and vector fields. + """ + return pa.schema( + [ + pa.field("id", pa.int32()), + pa.field("label", pa.int32()), + pa.field("vector", pa.list_(pa.float32(), vector_dim)), + ] + ) + + +@torch.no_grad() +def _establish_eval_database( + processor: BitImageProcessorFast, + model: nn.Module, + table: lancedb.table.Table, + dataloader: DataLoader, +) -> None: + """Extract features from training images and store them in a database table. + + Args: + processor: Image preprocessor. + model: Feature extraction model. + table: LanceDB table to store features. + dataloader: DataLoader for the training dataset. + """ + device = next(model.parameters()).device + model.eval() + + global_idx = 0 + for batch in tqdm(dataloader, desc="Building eval database"): + imgs = batch["img"] + labels = batch["label"] + + # 预处理并推理 + inputs = processor(imgs, return_tensors="pt") + inputs.to(device) + outputs = model(inputs) # [B, dim] + + # 整个batch一次性转到CPU + features = cast(torch.Tensor, outputs).cpu() + labels_list = labels.tolist() + + # 逐条写入数据库 + batch_size = len(labels_list) + table.add( + [ + { + "id": global_idx + j, + "label": labels_list[j], + "vector": features[j].numpy(), + } + for j in range(batch_size) + ] + ) + global_idx += batch_size + + +@torch.no_grad() +def _evaluate_recall( + processor: BitImageProcessorFast, + model: nn.Module, + table: lancedb.table.Table, + dataloader: DataLoader, + top_k: int, +) -> tuple[int, int]: + """Evaluate Recall@K by searching the database for each test image. + + For each batch, features are extracted in one forward pass and moved to CPU, + then each sample is searched individually against the database. + + Args: + processor: Image preprocessor. + model: Feature extraction model. + table: LanceDB table to search against. + dataloader: DataLoader for the test dataset. + top_k: Number of top results to retrieve. + + Returns: + A tuple of (correct_count, total_count). + """ + device = next(model.parameters()).device + model.eval() + + correct = 0 + total = 0 + + for batch in tqdm(dataloader, desc=f"Evaluating Recall@{top_k}"): + imgs = batch["img"] + labels = batch["label"] + + # 批量前向推理 + inputs = processor(imgs, return_tensors="pt") + inputs.to(device) + outputs = model(inputs) # [B, dim] + + # 整个batch一次性转到CPU + features = cast(torch.Tensor, outputs).cpu() + labels_list = labels.tolist() + + # 逐条搜索并验证 + for j in range(len(labels_list)): + feature = features[j].tolist() + true_label = labels_list[j] + + # 搜索 top_k 最相似结果 + results = ( + table.search(feature) + .select(["label", "_distance"]) + .limit(top_k) + .to_polars() + ) + + # 检查 top_k 中是否包含正确标签 + retrieved_labels = results["label"].to_list() + if true_label in retrieved_labels: + correct += 1 + total += 1 + + return correct, total def task_eval( processor: BitImageProcessorFast, model: nn.Module, dataset: Literal["CIFAR-10", "CIFAR-100"], + model_name: str, top_k: int = 10, - batch_size: int = 32, -): - match dataset: - case "CIFAR-10": - train_dataset = load_dataset("uoft-cs/cifar10", split="train") - test_dataset = load_dataset("uoft-cs/cifar10", split="test") - case "CIFAR-100": - train_dataset = load_dataset("uoft-cs/cifar100", split="train") - test_dataset = load_dataset("uoft-cs/cifar100", split="test") - case _: - raise ValueError( - f"Unknown dataset: {dataset}. Only support: 'CIFAR-10', 'CIFAR-100'." - ) + batch_size: int = 64, +) -> float: + """Evaluate model Recall@K accuracy on a dataset using vector retrieval. - # Establish database - df = establish_database(processor, model, train_dataset, batch_size) + Workflow: + 1. Create or open a database table named by dataset and model. + 2. Build database from training set features (skip if table exists). + 3. Evaluate on test set: extract features in batches, search top_k, + check if correct label appears in results. - # Test - dataloader = DataLoader( + Args: + processor: Image preprocessor. + model: Feature extraction model. + dataset: Dataset name. + model_name: Model name, used for table name generation. + top_k: Number of top similar results to retrieve. + batch_size: Batch size for DataLoader. + + Returns: + Recall@K accuracy (0.0 ~ 1.0). + + Raises: + ValueError: If dataset name is not supported. + """ + if dataset not in DATASET_CONFIG: + raise ValueError( + f"Unknown dataset: {dataset}. Only support: {list(DATASET_CONFIG.keys())}." + ) + hf_id, img_col, label_col = DATASET_CONFIG[dataset] + + # 加载数据集 + train_dataset = load_dataset(hf_id, split="train") + test_dataset = load_dataset(hf_id, split="test") + + # 生成表名,推断向量维度 + table_name = _get_table_name(dataset, model_name) + vector_dim = _infer_vector_dim(processor, model, train_dataset[0][img_col]) + expected_schema = _build_eval_schema(vector_dim) + existing_tables = db_manager.db.list_tables().tables + + # 如果旧表 schema 不匹配(如 label 类型变更),删除重建 + if table_name in existing_tables: + old_table = db_manager.db.open_table(table_name) + if old_table.schema != expected_schema: + print(f"Table '{table_name}' schema mismatch, rebuilding.") + db_manager.db.drop_table(table_name) + existing_tables = [] + + if table_name in existing_tables: + # 表已存在且 schema 匹配,跳过建库 + print(f"Table '{table_name}' already exists, skipping database build.") + table = db_manager.db.open_table(table_name) + else: + # 创建新表 + table = db_manager.db.create_table(table_name, schema=expected_schema) + + # 使用 DataLoader 批量建库 + train_loader = DataLoader( + train_dataset.with_format("torch"), + batch_size=batch_size, + shuffle=False, + num_workers=4, + ) + _establish_eval_database(processor, model, table, train_loader) + + # 使用 DataLoader 批量评估 + test_loader = DataLoader( test_dataset.with_format("torch"), batch_size=batch_size, - shuffle=True, + shuffle=False, num_workers=4, ) + correct, total = _evaluate_recall(processor, model, table, test_loader, top_k) - with torch.no_grad(): - for batch in tqdm(dataloader, desc="Test Evaluation"): - imgs = batch["img"] - labels = batch["label"] + accuracy = correct / total + print(f"\nRecall@{top_k} on {dataset} with {model_name}: {accuracy:.4f}") + print(f"Correct: {correct}/{total}") - inputs = processor(imgs, return_tensors="pt").to(get_device()) - - outputs = cast(Tensor, model(inputs)) - for vec in outputs: - pass + return accuracy