feat(benchmarks): add evaluation framework for DINO-based compressors

This commit is contained in:
2026-02-08 22:43:38 +08:00
parent 3ba3705ba6
commit 7f6732edeb
11 changed files with 217 additions and 42 deletions

View File

@@ -0,0 +1,49 @@
from typing import Literal, cast
import torch
from compressors import DinoCompressor, FloatCompressor
from transformers import AutoImageProcessor, BitImageProcessorFast
from utils import get_device, get_output_diretory
from .task_eval import task_eval
def evaluate(
compressor_model: Literal["Dinov2", "Dinov2WithCompressor"],
dataset: Literal["CIFAR-10", "CIFAR-100"],
benchmark: Literal["Recall@1", "Recall@10"],
):
match compressor_model:
case "Dinov2":
processor = cast(
BitImageProcessorFast,
AutoImageProcessor.from_pretrained(
"facebook/dinov2-large", device_map=get_device()
),
)
model = DinoCompressor().to(get_device())
case "Dinov2WithCompressor":
processor = cast(
BitImageProcessorFast,
AutoImageProcessor.from_pretrained(
"facebook/dinov2-large", device_map=get_device()
),
)
compressor = FloatCompressor().load_state_dict(
torch.load(get_output_diretory() / "compressor.pt")
)
model = DinoCompressor(compressor).to(get_device())
case _:
raise ValueError(f"Unknown compressor: {compressor_model}")
match benchmark:
case "Recall@1":
task_eval(processor, model, dataset, 1)
case "Recall@10":
task_eval(processor, model, dataset, 10)
case _:
raise ValueError(f"Unknown benchmark: {benchmark}")
__all__ = ["task_eval", "evaluate"]