refactor(benchmarks): overhaul evaluation pipeline with LanceDB integration

This commit is contained in:
2026-02-10 17:01:51 +08:00
parent 6f60cd94d3
commit 5f9d2bfcd8
2 changed files with 265 additions and 57 deletions

View File

@@ -2,8 +2,9 @@ from typing import Literal, cast
import torch
from compressors import DinoCompressor, FloatCompressor
from configs import cfg_manager
from transformers import AutoImageProcessor, BitImageProcessorFast
from utils import get_device, get_output_diretory
from utils import get_device
from .task_eval import task_eval
@@ -13,35 +14,44 @@ def evaluate(
dataset: Literal["CIFAR-10", "CIFAR-100"],
benchmark: Literal["Recall@1", "Recall@10"],
):
"""运行模型评估。
Args:
compressor_model: 压缩模型类型。
dataset: 数据集名称。
benchmark: 评估指标。
"""
device = get_device()
match compressor_model:
case "Dinov2":
processor = cast(
BitImageProcessorFast,
AutoImageProcessor.from_pretrained(
"facebook/dinov2-large", device_map=get_device()
"facebook/dinov2-large", device_map=device
),
)
model = DinoCompressor().to(get_device())
model = DinoCompressor().to(device)
case "Dinov2WithCompressor":
processor = cast(
BitImageProcessorFast,
AutoImageProcessor.from_pretrained(
"facebook/dinov2-large", device_map=get_device()
"facebook/dinov2-large", device_map=device
),
)
compressor = FloatCompressor().load_state_dict(
torch.load(get_output_diretory() / "compressor.pt")
)
model = DinoCompressor(compressor).to(get_device())
output_dir = cfg_manager.get().output.directory
compressor = FloatCompressor()
compressor.load_state_dict(torch.load(output_dir / "compressor.pt"))
model = DinoCompressor(compressor).to(device)
case _:
raise ValueError(f"Unknown compressor: {compressor_model}")
# 根据 benchmark 确定 top_k
match benchmark:
case "Recall@1":
task_eval(processor, model, dataset, 1)
task_eval(processor, model, dataset, compressor_model, top_k=1)
case "Recall@10":
task_eval(processor, model, dataset, 10)
task_eval(processor, model, dataset, compressor_model, top_k=10)
case _:
raise ValueError(f"Unknown benchmark: {benchmark}")

View File

@@ -1,77 +1,275 @@
from typing import Literal, cast
from typing import Dict, Literal, cast
import polars as pl
import lancedb
import pyarrow as pa
import torch
from datasets import Dataset, load_dataset
from torch import Tensor, nn
from database import db_manager
from datasets import load_dataset
from torch import nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import BitImageProcessorFast
from utils import get_device
# Dataset registry: dataset name -> (HuggingFace dataset ID, image column, label column).
# Note: CIFAR-100 stores its 100-class labels under "fine_label", not "label".
DATASET_CONFIG: dict[str, tuple[str, str, str]] = {
    "CIFAR-10": ("uoft-cs/cifar10", "img", "label"),
    "CIFAR-100": ("uoft-cs/cifar100", "img", "fine_label"),
}
def establish_database(
def _get_table_name(dataset: str, model_name: str) -> str:
"""Generate database table name from dataset and model name.
Args:
dataset: Dataset name, e.g. "CIFAR-10".
model_name: Model name, e.g. "Dinov2".
Returns:
Formatted table name, e.g. "cifar10_dinov2".
"""
ds_part = dataset.lower().replace("-", "")
model_part = model_name.lower()
return f"{ds_part}_{model_part}"
def _infer_vector_dim(
    processor: BitImageProcessorFast,
    model: nn.Module,
    sample_image,
) -> int:
    """Infer the model's output vector dimension via a single forward pass.

    Args:
        processor: Image preprocessor.
        model: Feature extraction model.
        sample_image: One raw image used solely for dimension inference.

    Returns:
        Vector dimension (size of the last axis of the model output).
    """
    device = next(model.parameters()).device
    model.eval()
    with torch.no_grad():
        inputs = processor(images=sample_image, return_tensors="pt")
        # NOTE(review): assumes BatchFeature.to moves tensors in place and
        # returning the result is unnecessary — confirm against transformers.
        inputs.to(device)
        output = model(inputs)
        # output shape: [1, dim]
        return output.shape[-1]
def _build_eval_schema(vector_dim: int) -> pa.Schema:
    """Construct the PyArrow schema for the evaluation vector table.

    Args:
        vector_dim: Dimensionality of the stored feature vectors.

    Returns:
        Schema with ``id``, ``label`` and a fixed-size ``vector`` field.
    """
    fields = [
        pa.field("id", pa.int32()),
        pa.field("label", pa.int32()),
        # Fixed-size float32 list: LanceDB's vector-search column layout.
        pa.field("vector", pa.list_(pa.float32(), vector_dim)),
    ]
    return pa.schema(fields)
@torch.no_grad()
def _establish_eval_database(
    processor: BitImageProcessorFast,
    model: nn.Module,
    table: lancedb.table.Table,
    dataloader: DataLoader,
) -> None:
    """Extract features from training images and store them in a database table.

    Args:
        processor: Image preprocessor.
        model: Feature extraction model.
        table: LanceDB table to store features.
        dataloader: DataLoader for the training dataset; batches are expected
            to expose "img" and "label" keys.
    """
    device = next(model.parameters()).device
    model.eval()
    # Running row counter so ids stay unique across batches.
    global_idx = 0
    for batch in tqdm(dataloader, desc="Building eval database"):
        imgs = batch["img"]
        labels = batch["label"]
        # Preprocess and run inference on the whole batch at once.
        inputs = processor(imgs, return_tensors="pt")
        inputs.to(device)
        outputs = model(inputs)  # [B, dim]
        # Move the entire batch of features to the CPU in one transfer.
        features = cast(torch.Tensor, outputs).cpu()
        labels_list = labels.tolist()
        # Insert the batch as one table.add call (one record per sample).
        batch_size = len(labels_list)
        table.add(
            [
                {
                    "id": global_idx + j,
                    "label": labels_list[j],
                    "vector": features[j].numpy(),
                }
                for j in range(batch_size)
            ]
        )
        global_idx += batch_size
@torch.no_grad()
def _evaluate_recall(
    processor: BitImageProcessorFast,
    model: nn.Module,
    table: lancedb.table.Table,
    dataloader: DataLoader,
    top_k: int,
) -> tuple[int, int]:
    """Compute Recall@K hit statistics by querying the vector table.

    Features for each batch come from one forward pass and are moved to the
    CPU together; each sample is then searched individually against the
    table and counted as a hit when its true label appears among the top-k
    retrieved neighbours.

    Args:
        processor: Image preprocessor.
        model: Feature extraction model.
        table: LanceDB table holding the reference vectors.
        dataloader: DataLoader over the test dataset.
        top_k: How many nearest neighbours to retrieve per query.

    Returns:
        A tuple of (correct_count, total_count).
    """
    device = next(model.parameters()).device
    model.eval()
    hits, seen = 0, 0
    for batch in tqdm(dataloader, desc=f"Evaluating Recall@{top_k}"):
        # One forward pass per batch, then a single device-to-CPU transfer.
        inputs = processor(batch["img"], return_tensors="pt")
        inputs.to(device)
        feats = cast(torch.Tensor, model(inputs)).cpu()  # [B, dim]
        for feat, true_label in zip(feats, batch["label"].tolist()):
            # Retrieve the top_k most similar stored vectors for this sample.
            neighbours = (
                table.search(feat.tolist())
                .select(["label", "_distance"])
                .limit(top_k)
                .to_polars()
            )
            if true_label in neighbours["label"].to_list():
                hits += 1
            seen += 1
    return hits, seen
def task_eval(
    processor: BitImageProcessorFast,
    model: nn.Module,
    dataset: Literal["CIFAR-10", "CIFAR-100"],
    model_name: str,
    top_k: int = 10,
    batch_size: int = 64,
) -> float:
    """Evaluate model Recall@K accuracy on a dataset using vector retrieval.

    Workflow:
        1. Create or open a database table named by dataset and model.
        2. Build the table from training-set features (skipped when a table
           with a matching schema already exists).
        3. Evaluate on the test set: extract features in batches, search the
           top_k nearest vectors, and check whether the correct label appears.

    Args:
        processor: Image preprocessor.
        model: Feature extraction model.
        dataset: Dataset name.
        model_name: Model name, used for table name generation.
        top_k: Number of top similar results to retrieve.
        batch_size: Batch size for the DataLoaders.

    Returns:
        Recall@K accuracy (0.0 ~ 1.0).

    Raises:
        ValueError: If dataset name is not supported.
    """
    if dataset not in DATASET_CONFIG:
        raise ValueError(
            f"Unknown dataset: {dataset}. Only support: {list(DATASET_CONFIG.keys())}."
        )
    hf_id, img_col, label_col = DATASET_CONFIG[dataset]
    # Load train/test splits.
    train_dataset = load_dataset(hf_id, split="train")
    test_dataset = load_dataset(hf_id, split="test")
    # The helpers below read batch["img"] / batch["label"], but CIFAR-100
    # names its label column "fine_label". Normalize the column names here so
    # every configured dataset works downstream (without this, CIFAR-100
    # batches have no "label" key and evaluation crashes).
    if img_col != "img":
        train_dataset = train_dataset.rename_column(img_col, "img")
        test_dataset = test_dataset.rename_column(img_col, "img")
    if label_col != "label":
        train_dataset = train_dataset.rename_column(label_col, "label")
        test_dataset = test_dataset.rename_column(label_col, "label")
    # Derive the table name and the vector dimension for the expected schema.
    table_name = _get_table_name(dataset, model_name)
    vector_dim = _infer_vector_dim(processor, model, train_dataset[0]["img"])
    expected_schema = _build_eval_schema(vector_dim)
    existing_tables = db_manager.db.list_tables().tables
    # Drop and rebuild when a stale table's schema no longer matches
    # (e.g. the label type or the vector dimension changed).
    if table_name in existing_tables:
        old_table = db_manager.db.open_table(table_name)
        if old_table.schema != expected_schema:
            print(f"Table '{table_name}' schema mismatch, rebuilding.")
            db_manager.db.drop_table(table_name)
            existing_tables = []
    if table_name in existing_tables:
        # Table exists with a matching schema: reuse it and skip the build.
        print(f"Table '{table_name}' already exists, skipping database build.")
        table = db_manager.db.open_table(table_name)
    else:
        table = db_manager.db.create_table(table_name, schema=expected_schema)
        # Build the database from the training split in batches. No shuffle:
        # insertion order determines the stored ids.
        train_loader = DataLoader(
            train_dataset.with_format("torch"),
            batch_size=batch_size,
            shuffle=False,
            num_workers=4,
        )
        _establish_eval_database(processor, model, table, train_loader)
    # Batched evaluation over the test split.
    test_loader = DataLoader(
        test_dataset.with_format("torch"),
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
    )
    correct, total = _evaluate_recall(processor, model, table, test_loader, top_k)
    accuracy = correct / total
    print(f"\nRecall@{top_k} on {dataset} with {model_name}: {accuracy:.4f}")
    print(f"Correct: {correct}/{total}")
    return accuracy