mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
refactor(benchmarks): overhaul evaluation pipeline with LanceDB integration
This commit is contained in:
@@ -2,8 +2,9 @@ from typing import Literal, cast
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from compressors import DinoCompressor, FloatCompressor
|
from compressors import DinoCompressor, FloatCompressor
|
||||||
|
from configs import cfg_manager
|
||||||
from transformers import AutoImageProcessor, BitImageProcessorFast
|
from transformers import AutoImageProcessor, BitImageProcessorFast
|
||||||
from utils import get_device, get_output_diretory
|
from utils import get_device
|
||||||
|
|
||||||
from .task_eval import task_eval
|
from .task_eval import task_eval
|
||||||
|
|
||||||
@@ -13,35 +14,44 @@ def evaluate(
|
|||||||
dataset: Literal["CIFAR-10", "CIFAR-100"],
|
dataset: Literal["CIFAR-10", "CIFAR-100"],
|
||||||
benchmark: Literal["Recall@1", "Recall@10"],
|
benchmark: Literal["Recall@1", "Recall@10"],
|
||||||
):
|
):
|
||||||
|
"""运行模型评估。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
compressor_model: 压缩模型类型。
|
||||||
|
dataset: 数据集名称。
|
||||||
|
benchmark: 评估指标。
|
||||||
|
"""
|
||||||
|
device = get_device()
|
||||||
|
|
||||||
match compressor_model:
|
match compressor_model:
|
||||||
case "Dinov2":
|
case "Dinov2":
|
||||||
processor = cast(
|
processor = cast(
|
||||||
BitImageProcessorFast,
|
BitImageProcessorFast,
|
||||||
AutoImageProcessor.from_pretrained(
|
AutoImageProcessor.from_pretrained(
|
||||||
"facebook/dinov2-large", device_map=get_device()
|
"facebook/dinov2-large", device_map=device
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
model = DinoCompressor().to(get_device())
|
model = DinoCompressor().to(device)
|
||||||
case "Dinov2WithCompressor":
|
case "Dinov2WithCompressor":
|
||||||
processor = cast(
|
processor = cast(
|
||||||
BitImageProcessorFast,
|
BitImageProcessorFast,
|
||||||
AutoImageProcessor.from_pretrained(
|
AutoImageProcessor.from_pretrained(
|
||||||
"facebook/dinov2-large", device_map=get_device()
|
"facebook/dinov2-large", device_map=device
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
output_dir = cfg_manager.get().output.directory
|
||||||
compressor = FloatCompressor().load_state_dict(
|
compressor = FloatCompressor()
|
||||||
torch.load(get_output_diretory() / "compressor.pt")
|
compressor.load_state_dict(torch.load(output_dir / "compressor.pt"))
|
||||||
)
|
model = DinoCompressor(compressor).to(device)
|
||||||
model = DinoCompressor(compressor).to(get_device())
|
|
||||||
case _:
|
case _:
|
||||||
raise ValueError(f"Unknown compressor: {compressor_model}")
|
raise ValueError(f"Unknown compressor: {compressor_model}")
|
||||||
|
|
||||||
|
# 根据 benchmark 确定 top_k
|
||||||
match benchmark:
|
match benchmark:
|
||||||
case "Recall@1":
|
case "Recall@1":
|
||||||
task_eval(processor, model, dataset, 1)
|
task_eval(processor, model, dataset, compressor_model, top_k=1)
|
||||||
case "Recall@10":
|
case "Recall@10":
|
||||||
task_eval(processor, model, dataset, 10)
|
task_eval(processor, model, dataset, compressor_model, top_k=10)
|
||||||
case _:
|
case _:
|
||||||
raise ValueError(f"Unknown benchmark: {benchmark}")
|
raise ValueError(f"Unknown benchmark: {benchmark}")
|
||||||
|
|
||||||
|
|||||||
@@ -1,77 +1,275 @@
|
|||||||
from typing import Literal, cast
|
from typing import Dict, Literal, cast
|
||||||
|
|
||||||
import polars as pl
|
import lancedb
|
||||||
|
import pyarrow as pa
|
||||||
import torch
|
import torch
|
||||||
from datasets import Dataset, load_dataset
|
from database import db_manager
|
||||||
from torch import Tensor, nn
|
from datasets import load_dataset
|
||||||
|
from torch import nn
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
from transformers import BitImageProcessorFast
|
from transformers import BitImageProcessorFast
|
||||||
from utils import get_device
|
|
||||||
|
# Dataset configuration: dataset name -> (HuggingFace ID, image column name, label column name)
|
||||||
|
DATASET_CONFIG: Dict[str, tuple[str, str, str]] = {
|
||||||
|
"CIFAR-10": ("uoft-cs/cifar10", "img", "label"),
|
||||||
|
"CIFAR-100": ("uoft-cs/cifar100", "img", "fine_label"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def establish_database(
|
def _get_table_name(dataset: str, model_name: str) -> str:
|
||||||
|
"""Generate database table name from dataset and model name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: Dataset name, e.g. "CIFAR-10".
|
||||||
|
model_name: Model name, e.g. "Dinov2".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted table name, e.g. "cifar10_dinov2".
|
||||||
|
"""
|
||||||
|
ds_part = dataset.lower().replace("-", "")
|
||||||
|
model_part = model_name.lower()
|
||||||
|
return f"{ds_part}_{model_part}"
|
||||||
|
|
||||||
|
|
||||||
|
def _infer_vector_dim(
|
||||||
processor: BitImageProcessorFast,
|
processor: BitImageProcessorFast,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
dataset: Dataset,
|
sample_image,
|
||||||
batch_size: int = 32,
|
) -> int:
|
||||||
) -> pl.DataFrame:
|
"""Infer model output vector dimension via a single forward pass.
|
||||||
df = pl.DataFrame()
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
processor: Image preprocessor.
|
||||||
|
model: Feature extraction model.
|
||||||
|
sample_image: A sample image for dimension inference.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Vector dimension.
|
||||||
|
"""
|
||||||
|
device = next(model.parameters()).device
|
||||||
model.eval()
|
model.eval()
|
||||||
dataloader = DataLoader(
|
|
||||||
dataset.with_format("torch"), batch_size=batch_size, shuffle=True, num_workers=4
|
|
||||||
)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for batch in tqdm(dataloader, desc="Establish Database"):
|
inputs = processor(images=sample_image, return_tensors="pt")
|
||||||
imgs = batch["img"]
|
inputs.to(device)
|
||||||
labels = batch["label"]
|
output = model(inputs)
|
||||||
|
|
||||||
inputs = processor(imgs, return_tensors="pt").to(get_device())
|
# output shape: [1, dim]
|
||||||
|
return output.shape[-1]
|
||||||
|
|
||||||
outputs = cast(Tensor, model(inputs))
|
|
||||||
|
|
||||||
return df
|
def _build_eval_schema(vector_dim: int) -> pa.Schema:
    """Return the PyArrow schema used by the evaluation table.

    Args:
        vector_dim: Length of each stored feature vector.

    Returns:
        A schema with three fields: ``id`` (int32), ``label`` (int32) and
        ``vector`` (fixed-size list of ``vector_dim`` float32 values).
    """
    id_field = pa.field("id", pa.int32())
    label_field = pa.field("label", pa.int32())
    # Passing a list size makes this a fixed-size list column.
    vector_field = pa.field("vector", pa.list_(pa.float32(), vector_dim))
    return pa.schema([id_field, label_field, vector_field])
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
def _establish_eval_database(
    processor: BitImageProcessorFast,
    model: nn.Module,
    table: lancedb.table.Table,
    dataloader: DataLoader,
    img_col: str = "img",
    label_col: str = "label",
) -> None:
    """Extract features for every batch and append them to a database table.

    Each batch is preprocessed, run through the model in one forward pass,
    moved to the CPU and written to ``table`` with a single ``add`` call.
    Row ids are assigned sequentially, starting at 0, in dataloader order.

    Args:
        processor: Image preprocessor.
        model: Feature extraction model; it is switched to eval mode.
        table: LanceDB table that receives the ``id``/``label``/``vector`` rows.
        dataloader: DataLoader yielding dict batches.
        img_col: Batch key holding the images.
        label_col: Batch key holding the labels (for example "fine_label"
            when the dataset does not expose a "label" column).
    """
    device = next(model.parameters()).device
    model.eval()

    next_id = 0
    for batch in tqdm(dataloader, desc="Building eval database"):
        labels = batch[label_col].tolist()

        # Preprocess and run the forward pass; rebind the result of ``.to``
        # instead of relying on it moving the tensors in place.
        inputs = processor(batch[img_col], return_tensors="pt").to(device)
        outputs = model(inputs)  # expected [B, dim]

        # Move the whole batch to the CPU once, then write it in one call.
        features = cast(torch.Tensor, outputs).cpu().numpy()
        table.add(
            [
                {"id": next_id + offset, "label": label, "vector": vector}
                for offset, (label, vector) in enumerate(zip(labels, features))
            ]
        )
        next_id += len(labels)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
def _evaluate_recall(
    processor: BitImageProcessorFast,
    model: nn.Module,
    table: lancedb.table.Table,
    dataloader: DataLoader,
    top_k: int,
    img_col: str = "img",
    label_col: str = "label",
) -> tuple[int, int]:
    """Count how many query images have their label among the top-k matches.

    For each batch, features are extracted in one forward pass and moved to
    the CPU; each sample is then searched individually against ``table``.
    A sample counts as correct when its true label appears in the ``label``
    column of the ``top_k`` returned rows.

    Args:
        processor: Image preprocessor.
        model: Feature extraction model; it is switched to eval mode.
        table: LanceDB table to search against.
        dataloader: DataLoader yielding dict batches of query images.
        top_k: Number of nearest results to retrieve per query.
        img_col: Batch key holding the images.
        label_col: Batch key holding the labels (for example "fine_label"
            when the dataset does not expose a "label" column).

    Returns:
        A tuple of (correct_count, total_count).
    """
    device = next(model.parameters()).device
    model.eval()

    correct = 0
    total = 0

    for batch in tqdm(dataloader, desc=f"Evaluating Recall@{top_k}"):
        labels = batch[label_col].tolist()

        # Batched forward pass; rebind the result of ``.to`` instead of
        # relying on it moving the tensors in place.
        inputs = processor(batch[img_col], return_tensors="pt").to(device)
        outputs = model(inputs)  # expected [B, dim]

        # Move the whole batch to the CPU once.
        features = cast(torch.Tensor, outputs).cpu().tolist()

        # Search and verify one sample at a time.
        for true_label, feature in zip(labels, features):
            results = (
                table.search(feature)
                .select(["label", "_distance"])
                .limit(top_k)
                .to_polars()
            )
            if true_label in results["label"].to_list():
                correct += 1
            total += 1

    return correct, total
|
||||||
|
|
||||||
|
|
||||||
def task_eval(
    processor: BitImageProcessorFast,
    model: nn.Module,
    dataset: Literal["CIFAR-10", "CIFAR-100"],
    model_name: str,
    top_k: int = 10,
    batch_size: int = 64,
) -> float:
    """Evaluate model Recall@K accuracy on a dataset using vector retrieval.

    Workflow:
        1. Load the train/test splits and normalise their column names.
        2. Create or open a database table named by dataset and model.
        3. Build the database from training-set features. The build is
           skipped only when a table with the expected schema and the
           expected number of rows already exists.
        4. Evaluate on the test set: extract features in batches, search
           top_k, check if the correct label appears in the results.

    Args:
        processor: Image preprocessor.
        model: Feature extraction model.
        dataset: Dataset name; must be a key of ``DATASET_CONFIG``.
        model_name: Model name, used for table name generation.
        top_k: Number of top similar results to retrieve.
        batch_size: Batch size for both DataLoaders.

    Returns:
        Recall@K accuracy (0.0 ~ 1.0); 0.0 when no test sample was evaluated.

    Raises:
        ValueError: If dataset name is not supported.
    """
    if dataset not in DATASET_CONFIG:
        raise ValueError(
            f"Unknown dataset: {dataset}. Only support: {list(DATASET_CONFIG.keys())}."
        )
    hf_id, img_col, label_col = DATASET_CONFIG[dataset]

    # Load both splits.
    train_dataset = load_dataset(hf_id, split="train")
    test_dataset = load_dataset(hf_id, split="test")

    # The helpers below are called without column names and read the "img"
    # and "label" batch keys. CIFAR-100 stores its labels under "fine_label",
    # so rename the configured columns to those canonical names first.
    renames = {
        src: dst
        for src, dst in ((img_col, "img"), (label_col, "label"))
        if src != dst
    }
    if renames:
        train_dataset = train_dataset.rename_columns(renames)
        test_dataset = test_dataset.rename_columns(renames)

    # Derive the table name and the schema expected for this model.
    table_name = _get_table_name(dataset, model_name)
    vector_dim = _infer_vector_dim(processor, model, train_dataset[0]["img"])
    expected_schema = _build_eval_schema(vector_dim)

    db = db_manager.db
    loader_options = {"batch_size": batch_size, "shuffle": False, "num_workers": 4}

    # Reuse an existing table only when it is both compatible and complete.
    table = None
    if table_name in db.list_tables().tables:
        candidate = db.open_table(table_name)
        if candidate.schema != expected_schema:
            # e.g. the label type or the vector dimension changed.
            print(f"Table '{table_name}' schema mismatch, rebuilding.")
            db.drop_table(table_name)
        elif candidate.count_rows() != len(train_dataset):
            # A previous build was interrupted part-way through.
            print(f"Table '{table_name}' is incomplete, rebuilding.")
            db.drop_table(table_name)
        else:
            print(f"Table '{table_name}' already exists, skipping database build.")
            table = candidate

    if table is None:
        table = db.create_table(table_name, schema=expected_schema)
        train_loader = DataLoader(train_dataset.with_format("torch"), **loader_options)
        _establish_eval_database(processor, model, table, train_loader)

    # Batched evaluation on the test split.
    test_loader = DataLoader(test_dataset.with_format("torch"), **loader_options)
    correct, total = _evaluate_recall(processor, model, table, test_loader, top_k)

    # Guard against an empty test split instead of raising ZeroDivisionError.
    accuracy = correct / total if total else 0.0
    print(f"\nRecall@{top_k} on {dataset} with {model_name}: {accuracy:.4f}")
    print(f"Correct: {correct}/{total}")

    return accuracy
|
|
||||||
|
|||||||
Reference in New Issue
Block a user