mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 20:35:31 +08:00
refactor(benchmarks): overhaul evaluation pipeline with LanceDB integration
This commit is contained in:
@@ -2,8 +2,9 @@ from typing import Literal, cast
|
||||
|
||||
import torch
|
||||
from compressors import DinoCompressor, FloatCompressor
|
||||
from configs import cfg_manager
|
||||
from transformers import AutoImageProcessor, BitImageProcessorFast
|
||||
from utils import get_device, get_output_diretory
|
||||
from utils import get_device
|
||||
|
||||
from .task_eval import task_eval
|
||||
|
||||
@@ -13,35 +14,44 @@ def evaluate(
|
||||
dataset: Literal["CIFAR-10", "CIFAR-100"],
|
||||
benchmark: Literal["Recall@1", "Recall@10"],
|
||||
):
|
||||
"""运行模型评估。
|
||||
|
||||
Args:
|
||||
compressor_model: 压缩模型类型。
|
||||
dataset: 数据集名称。
|
||||
benchmark: 评估指标。
|
||||
"""
|
||||
device = get_device()
|
||||
|
||||
match compressor_model:
|
||||
case "Dinov2":
|
||||
processor = cast(
|
||||
BitImageProcessorFast,
|
||||
AutoImageProcessor.from_pretrained(
|
||||
"facebook/dinov2-large", device_map=get_device()
|
||||
"facebook/dinov2-large", device_map=device
|
||||
),
|
||||
)
|
||||
model = DinoCompressor().to(get_device())
|
||||
model = DinoCompressor().to(device)
|
||||
case "Dinov2WithCompressor":
|
||||
processor = cast(
|
||||
BitImageProcessorFast,
|
||||
AutoImageProcessor.from_pretrained(
|
||||
"facebook/dinov2-large", device_map=get_device()
|
||||
"facebook/dinov2-large", device_map=device
|
||||
),
|
||||
)
|
||||
|
||||
compressor = FloatCompressor().load_state_dict(
|
||||
torch.load(get_output_diretory() / "compressor.pt")
|
||||
)
|
||||
model = DinoCompressor(compressor).to(get_device())
|
||||
output_dir = cfg_manager.get().output.directory
|
||||
compressor = FloatCompressor()
|
||||
compressor.load_state_dict(torch.load(output_dir / "compressor.pt"))
|
||||
model = DinoCompressor(compressor).to(device)
|
||||
case _:
|
||||
raise ValueError(f"Unknown compressor: {compressor_model}")
|
||||
|
||||
# 根据 benchmark 确定 top_k
|
||||
match benchmark:
|
||||
case "Recall@1":
|
||||
task_eval(processor, model, dataset, 1)
|
||||
task_eval(processor, model, dataset, compressor_model, top_k=1)
|
||||
case "Recall@10":
|
||||
task_eval(processor, model, dataset, 10)
|
||||
task_eval(processor, model, dataset, compressor_model, top_k=10)
|
||||
case _:
|
||||
raise ValueError(f"Unknown benchmark: {benchmark}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user