mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
refactor(benchmarks): overhaul evaluation pipeline with LanceDB integration
This commit is contained in:
@@ -2,8 +2,9 @@ from typing import Literal, cast
|
||||
|
||||
import torch
|
||||
from compressors import DinoCompressor, FloatCompressor
|
||||
from configs import cfg_manager
|
||||
from transformers import AutoImageProcessor, BitImageProcessorFast
|
||||
from utils import get_device, get_output_diretory
|
||||
from utils import get_device
|
||||
|
||||
from .task_eval import task_eval
|
||||
|
||||
@@ -13,35 +14,44 @@ def evaluate(
|
||||
dataset: Literal["CIFAR-10", "CIFAR-100"],
|
||||
benchmark: Literal["Recall@1", "Recall@10"],
|
||||
):
|
||||
"""运行模型评估。
|
||||
|
||||
Args:
|
||||
compressor_model: 压缩模型类型。
|
||||
dataset: 数据集名称。
|
||||
benchmark: 评估指标。
|
||||
"""
|
||||
device = get_device()
|
||||
|
||||
match compressor_model:
|
||||
case "Dinov2":
|
||||
processor = cast(
|
||||
BitImageProcessorFast,
|
||||
AutoImageProcessor.from_pretrained(
|
||||
"facebook/dinov2-large", device_map=get_device()
|
||||
"facebook/dinov2-large", device_map=device
|
||||
),
|
||||
)
|
||||
model = DinoCompressor().to(get_device())
|
||||
model = DinoCompressor().to(device)
|
||||
case "Dinov2WithCompressor":
|
||||
processor = cast(
|
||||
BitImageProcessorFast,
|
||||
AutoImageProcessor.from_pretrained(
|
||||
"facebook/dinov2-large", device_map=get_device()
|
||||
"facebook/dinov2-large", device_map=device
|
||||
),
|
||||
)
|
||||
|
||||
compressor = FloatCompressor().load_state_dict(
|
||||
torch.load(get_output_diretory() / "compressor.pt")
|
||||
)
|
||||
model = DinoCompressor(compressor).to(get_device())
|
||||
output_dir = cfg_manager.get().output.directory
|
||||
compressor = FloatCompressor()
|
||||
compressor.load_state_dict(torch.load(output_dir / "compressor.pt"))
|
||||
model = DinoCompressor(compressor).to(device)
|
||||
case _:
|
||||
raise ValueError(f"Unknown compressor: {compressor_model}")
|
||||
|
||||
# 根据 benchmark 确定 top_k
|
||||
match benchmark:
|
||||
case "Recall@1":
|
||||
task_eval(processor, model, dataset, 1)
|
||||
task_eval(processor, model, dataset, compressor_model, top_k=1)
|
||||
case "Recall@10":
|
||||
task_eval(processor, model, dataset, 10)
|
||||
task_eval(processor, model, dataset, compressor_model, top_k=10)
|
||||
case _:
|
||||
raise ValueError(f"Unknown benchmark: {benchmark}")
|
||||
|
||||
|
||||
@@ -1,77 +1,275 @@
|
||||
from typing import Literal, cast
|
||||
from typing import Dict, Literal, cast
|
||||
|
||||
import polars as pl
|
||||
import lancedb
|
||||
import pyarrow as pa
|
||||
import torch
|
||||
from datasets import Dataset, load_dataset
|
||||
from torch import Tensor, nn
|
||||
from database import db_manager
|
||||
from datasets import load_dataset
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import BitImageProcessorFast
|
||||
from utils import get_device
|
||||
|
||||
# Registry of supported evaluation datasets.
# Maps the user-facing dataset name to a tuple of
# (HuggingFace dataset ID, image column name, label column name).
DATASET_CONFIG: Dict[str, tuple[str, str, str]] = {
    # CIFAR-10 exposes its class index under "label".
    "CIFAR-10": ("uoft-cs/cifar10", "img", "label"),
    # CIFAR-100 is configured to use the fine-grained class column.
    "CIFAR-100": ("uoft-cs/cifar100", "img", "fine_label"),
}
|
||||
|
||||
|
||||
def establish_database(
|
||||
def _get_table_name(dataset: str, model_name: str) -> str:
|
||||
"""Generate database table name from dataset and model name.
|
||||
|
||||
Args:
|
||||
dataset: Dataset name, e.g. "CIFAR-10".
|
||||
model_name: Model name, e.g. "Dinov2".
|
||||
|
||||
Returns:
|
||||
Formatted table name, e.g. "cifar10_dinov2".
|
||||
"""
|
||||
ds_part = dataset.lower().replace("-", "")
|
||||
model_part = model_name.lower()
|
||||
return f"{ds_part}_{model_part}"
|
||||
|
||||
|
||||
def _infer_vector_dim(
|
||||
processor: BitImageProcessorFast,
|
||||
model: nn.Module,
|
||||
dataset: Dataset,
|
||||
batch_size: int = 32,
|
||||
) -> pl.DataFrame:
|
||||
df = pl.DataFrame()
|
||||
sample_image,
|
||||
) -> int:
|
||||
"""Infer model output vector dimension via a single forward pass.
|
||||
|
||||
Args:
|
||||
processor: Image preprocessor.
|
||||
model: Feature extraction model.
|
||||
sample_image: A sample image for dimension inference.
|
||||
|
||||
Returns:
|
||||
Vector dimension.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
model.eval()
|
||||
dataloader = DataLoader(
|
||||
dataset.with_format("torch"), batch_size=batch_size, shuffle=True, num_workers=4
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in tqdm(dataloader, desc="Establish Database"):
|
||||
imgs = batch["img"]
|
||||
labels = batch["label"]
|
||||
inputs = processor(images=sample_image, return_tensors="pt")
|
||||
inputs.to(device)
|
||||
output = model(inputs)
|
||||
|
||||
inputs = processor(imgs, return_tensors="pt").to(get_device())
|
||||
# output shape: [1, dim]
|
||||
return output.shape[-1]
|
||||
|
||||
outputs = cast(Tensor, model(inputs))
|
||||
|
||||
return df
|
||||
def _build_eval_schema(vector_dim: int) -> pa.Schema:
    """Create the PyArrow schema used by the evaluation table.

    Args:
        vector_dim: Length of each stored feature vector.

    Returns:
        A schema with three fields, in order: ``id`` (int32),
        ``label`` (int32) and ``vector`` (list of float32 with size
        ``vector_dim``).
    """
    id_field = pa.field("id", pa.int32())
    label_field = pa.field("label", pa.int32())
    # Passing the size to ``pa.list_`` pins every vector to ``vector_dim`` entries.
    vector_field = pa.field("vector", pa.list_(pa.float32(), vector_dim))
    return pa.schema([id_field, label_field, vector_field])
|
||||
|
||||
|
||||
@torch.no_grad()
def _establish_eval_database(
    processor: BitImageProcessorFast,
    model: nn.Module,
    table: lancedb.table.Table,
    dataloader: DataLoader,
    img_col: str = "img",
    label_col: str = "label",
) -> None:
    """Extract features for every batch and store them in a database table.

    Each batch is preprocessed, run through the model in eval mode without
    gradients, moved to CPU and inserted with a single ``table.add`` call.
    Row ids are assigned sequentially, starting at 0, in dataloader order.

    Args:
        processor: Image preprocessor.
        model: Feature extraction model. Must own at least one parameter,
            because the target device is taken from its first parameter.
        table: LanceDB table that receives ``id``/``label``/``vector`` rows.
        dataloader: DataLoader yielding dict batches for the training set.
        img_col: Batch key holding the images. Defaults to "img".
        label_col: Batch key holding the labels. Defaults to "label";
            datasets whose label column has another name (e.g.
            "fine_label") must pass it here.
    """
    device = next(model.parameters()).device
    model.eval()

    next_id = 0
    for batch in tqdm(dataloader, desc="Building eval database"):
        # Preprocess and run one forward pass for the whole batch.
        inputs = processor(batch[img_col], return_tensors="pt").to(device)
        outputs = model(inputs)  # expected [B, dim]

        # Move the whole batch to CPU once instead of per sample.
        features = cast(torch.Tensor, outputs).cpu().numpy()
        labels = batch[label_col].tolist()

        # One insert per batch; ids continue from the previous batch.
        rows = [
            {"id": next_id + offset, "label": label, "vector": vector}
            for offset, (label, vector) in enumerate(zip(labels, features))
        ]
        table.add(rows)
        next_id += len(rows)
|
||||
|
||||
|
||||
@torch.no_grad()
def _evaluate_recall(
    processor: BitImageProcessorFast,
    model: nn.Module,
    table: lancedb.table.Table,
    dataloader: DataLoader,
    top_k: int,
    img_col: str = "img",
    label_col: str = "label",
) -> tuple[int, int]:
    """Evaluate Recall@K by searching the database for each test image.

    For each batch, features are extracted in one forward pass and moved to
    CPU; then every sample is searched individually against the table. A
    sample counts as correct when its true label appears among the labels
    of the ``top_k`` retrieved rows.

    Args:
        processor: Image preprocessor.
        model: Feature extraction model. Must own at least one parameter,
            because the target device is taken from its first parameter.
        table: LanceDB table to search against; its label column is "label".
        dataloader: DataLoader yielding dict batches for the test set.
        top_k: Number of top results to retrieve per query.
        img_col: Batch key holding the images. Defaults to "img".
        label_col: Batch key holding the labels. Defaults to "label";
            datasets whose label column has another name (e.g.
            "fine_label") must pass it here.

    Returns:
        A tuple of (correct_count, total_count).
    """
    device = next(model.parameters()).device
    model.eval()

    correct = 0
    total = 0

    for batch in tqdm(dataloader, desc=f"Evaluating Recall@{top_k}"):
        # Batched forward pass.
        inputs = processor(batch[img_col], return_tensors="pt").to(device)
        outputs = model(inputs)  # expected [B, dim]

        # Move the whole batch to CPU once and convert to plain lists.
        features = cast(torch.Tensor, outputs).cpu().tolist()
        labels = batch[label_col].tolist()

        # Search and check one sample at a time.
        for feature, true_label in zip(features, labels):
            results = (
                table.search(feature)
                .select(["label", "_distance"])
                .limit(top_k)
                .to_polars()
            )

            # A hit means the true label is among the top_k retrieved labels.
            if true_label in results["label"].to_list():
                correct += 1
            total += 1

    return correct, total
|
||||
|
||||
|
||||
def _with_canonical_columns(split, img_col: str, label_col: str):
    """Keep only the image/label columns and rename them to "img"/"label"."""
    split = split.select_columns([img_col, label_col])
    renames = {
        source: target
        for source, target in ((img_col, "img"), (label_col, "label"))
        if source != target
    }
    return split.rename_columns(renames) if renames else split


def task_eval(
    processor: BitImageProcessorFast,
    model: nn.Module,
    dataset: Literal["CIFAR-10", "CIFAR-100"],
    model_name: str,
    top_k: int = 10,
    batch_size: int = 64,
) -> float:
    """Evaluate model Recall@K accuracy on a dataset using vector retrieval.

    Workflow:
        1. Create or open a database table named by dataset and model.
        2. Build the database from training set features (skipped if a table
           with a matching schema already exists).
        3. Evaluate on the test set: extract features in batches, search
           top_k, check whether the correct label appears in the results.

    Args:
        processor: Image preprocessor.
        model: Feature extraction model.
        dataset: Dataset name; must be a key of ``DATASET_CONFIG``.
        model_name: Model name, used for table name generation.
        top_k: Number of top similar results to retrieve.
        batch_size: Batch size for both DataLoaders.

    Returns:
        Recall@K accuracy (0.0 ~ 1.0); 0.0 if the test split yields no samples.

    Raises:
        ValueError: If dataset name is not supported.
    """
    if dataset not in DATASET_CONFIG:
        raise ValueError(
            f"Unknown dataset: {dataset}. Only support: {list(DATASET_CONFIG.keys())}."
        )
    hf_id, img_col, label_col = DATASET_CONFIG[dataset]

    # Load both splits and normalise their column names: the feature
    # extraction helpers read the "img" and "label" batch keys, while some
    # datasets (e.g. CIFAR-100 with "fine_label") name the label differently.
    train_dataset = _with_canonical_columns(
        load_dataset(hf_id, split="train"), img_col, label_col
    )
    test_dataset = _with_canonical_columns(
        load_dataset(hf_id, split="test"), img_col, label_col
    )

    # Derive the table name and the schema expected for this model.
    table_name = _get_table_name(dataset, model_name)
    vector_dim = _infer_vector_dim(processor, model, train_dataset[0]["img"])
    expected_schema = _build_eval_schema(vector_dim)

    # Reuse an existing table only when its schema still matches (e.g. the
    # label type or the vector size may have changed); otherwise drop it.
    table_exists = table_name in db_manager.db.list_tables().tables
    if table_exists:
        old_table = db_manager.db.open_table(table_name)
        if old_table.schema != expected_schema:
            print(f"Table '{table_name}' schema mismatch, rebuilding.")
            db_manager.db.drop_table(table_name)
            table_exists = False

    if table_exists:
        print(f"Table '{table_name}' already exists, skipping database build.")
        table = db_manager.db.open_table(table_name)
    else:
        table = db_manager.db.create_table(table_name, schema=expected_schema)

        # Fill the new table from the training split, batch by batch.
        train_loader = DataLoader(
            train_dataset.with_format("torch"),
            batch_size=batch_size,
            shuffle=False,
            num_workers=4,
        )
        _establish_eval_database(processor, model, table, train_loader)

    # Evaluate on the test split in a fixed (unshuffled) order.
    test_loader = DataLoader(
        test_dataset.with_format("torch"),
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
    )
    correct, total = _evaluate_recall(processor, model, table, test_loader, top_k)

    # Guard against an empty test split instead of dividing by zero.
    accuracy = correct / total if total else 0.0
    print(f"\nRecall@{top_k} on {dataset} with {model_name}: {accuracy:.4f}")
    print(f"Correct: {correct}/{total}")

    return accuracy
|
||||
|
||||
Reference in New Issue
Block a user