mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
50 lines
1.6 KiB
Python
50 lines
1.6 KiB
Python
from typing import Literal, cast
|
|
|
|
import torch
|
|
from compressors import DinoCompressor, FloatCompressor
|
|
from transformers import AutoImageProcessor, BitImageProcessorFast
|
|
from utils import get_device, get_output_diretory
|
|
|
|
from .task_eval import task_eval
|
|
|
|
|
|
def evaluate(
|
|
compressor_model: Literal["Dinov2", "Dinov2WithCompressor"],
|
|
dataset: Literal["CIFAR-10", "CIFAR-100"],
|
|
benchmark: Literal["Recall@1", "Recall@10"],
|
|
):
|
|
match compressor_model:
|
|
case "Dinov2":
|
|
processor = cast(
|
|
BitImageProcessorFast,
|
|
AutoImageProcessor.from_pretrained(
|
|
"facebook/dinov2-large", device_map=get_device()
|
|
),
|
|
)
|
|
model = DinoCompressor().to(get_device())
|
|
case "Dinov2WithCompressor":
|
|
processor = cast(
|
|
BitImageProcessorFast,
|
|
AutoImageProcessor.from_pretrained(
|
|
"facebook/dinov2-large", device_map=get_device()
|
|
),
|
|
)
|
|
|
|
compressor = FloatCompressor().load_state_dict(
|
|
torch.load(get_output_diretory() / "compressor.pt")
|
|
)
|
|
model = DinoCompressor(compressor).to(get_device())
|
|
case _:
|
|
raise ValueError(f"Unknown compressor: {compressor_model}")
|
|
|
|
match benchmark:
|
|
case "Recall@1":
|
|
task_eval(processor, model, dataset, 1)
|
|
case "Recall@10":
|
|
task_eval(processor, model, dataset, 10)
|
|
case _:
|
|
raise ValueError(f"Unknown benchmark: {benchmark}")
|
|
|
|
|
|
# Names exported by a star-import of this module: the imported ``task_eval``
# and the ``evaluate`` entry point defined in this file.
__all__ = [
    "task_eval",
    "evaluate",
]
|