mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
feat(benchmarks): add evaluation framework for DINO-based compressors
This commit is contained in:
49
mini-nav/benchmarks/__init__.py
Normal file
49
mini-nav/benchmarks/__init__.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from typing import Literal, cast
|
||||
|
||||
import torch
|
||||
from compressors import DinoCompressor, FloatCompressor
|
||||
from transformers import AutoImageProcessor, BitImageProcessorFast
|
||||
from utils import get_device, get_output_diretory
|
||||
|
||||
from .task_eval import task_eval
|
||||
|
||||
|
||||
def evaluate(
|
||||
compressor_model: Literal["Dinov2", "Dinov2WithCompressor"],
|
||||
dataset: Literal["CIFAR-10", "CIFAR-100"],
|
||||
benchmark: Literal["Recall@1", "Recall@10"],
|
||||
):
|
||||
match compressor_model:
|
||||
case "Dinov2":
|
||||
processor = cast(
|
||||
BitImageProcessorFast,
|
||||
AutoImageProcessor.from_pretrained(
|
||||
"facebook/dinov2-large", device_map=get_device()
|
||||
),
|
||||
)
|
||||
model = DinoCompressor().to(get_device())
|
||||
case "Dinov2WithCompressor":
|
||||
processor = cast(
|
||||
BitImageProcessorFast,
|
||||
AutoImageProcessor.from_pretrained(
|
||||
"facebook/dinov2-large", device_map=get_device()
|
||||
),
|
||||
)
|
||||
|
||||
compressor = FloatCompressor().load_state_dict(
|
||||
torch.load(get_output_diretory() / "compressor.pt")
|
||||
)
|
||||
model = DinoCompressor(compressor).to(get_device())
|
||||
case _:
|
||||
raise ValueError(f"Unknown compressor: {compressor_model}")
|
||||
|
||||
match benchmark:
|
||||
case "Recall@1":
|
||||
task_eval(processor, model, dataset, 1)
|
||||
case "Recall@10":
|
||||
task_eval(processor, model, dataset, 10)
|
||||
case _:
|
||||
raise ValueError(f"Unknown benchmark: {benchmark}")
|
||||
|
||||
|
||||
# Names exported by `from benchmarks import *`: the low-level `task_eval`
# (imported from .task_eval) and the `evaluate` entry point defined in this module.
__all__ = ["task_eval", "evaluate"]
|
||||
77
mini-nav/benchmarks/task_eval.py
Normal file
77
mini-nav/benchmarks/task_eval.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from typing import Literal, cast
|
||||
|
||||
import polars as pl
|
||||
import torch
|
||||
from datasets import Dataset, load_dataset
|
||||
from torch import Tensor, nn
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import BitImageProcessorFast
|
||||
from utils import get_device
|
||||
|
||||
|
||||
def establish_database(
    processor: BitImageProcessorFast,
    model: nn.Module,
    dataset: Dataset,
    batch_size: int = 32,
) -> pl.DataFrame:
    """Feed every batch of ``dataset`` through ``processor`` and ``model``.

    The model is switched to eval mode and run under ``torch.no_grad()``.

    Args:
        processor: Image processor applied to each batch's ``"img"`` entry.
        model: Module called on the processed batch.
        dataset: Dataset whose items expose ``"img"`` and ``"label"``.
        batch_size: Number of samples per batch.

    Returns:
        A polars ``DataFrame``.

    NOTE(review): the frame is created empty and never written to -- the
    per-batch embeddings and labels are computed and then dropped, so the
    returned database is currently always empty.
    """
    database = pl.DataFrame()

    model.eval()
    torch_view = dataset.with_format("torch")
    loader = DataLoader(
        torch_view,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
    )

    with torch.no_grad():
        for batch in tqdm(loader, desc="Establish Database"):
            images, targets = batch["img"], batch["label"]

            encoded = processor(images, return_tensors="pt")
            # `targets` and `embeddings` are not stored anywhere yet.
            embeddings = cast(Tensor, model(encoded.to(get_device())))

    return database
|
||||
|
||||
|
||||
def task_eval(
    processor: BitImageProcessorFast,
    model: nn.Module,
    dataset: Literal["CIFAR-10", "CIFAR-100"],
    top_k: int = 10,
    batch_size: int = 32,
):
    """Build a database from the train split, then run the test split.

    Args:
        processor: Image processor applied to each batch's ``"img"`` entry.
        model: Module called on the processed batch.
        dataset: ``"CIFAR-10"`` or ``"CIFAR-100"``; selects which dataset is
            loaded through ``load_dataset``.
        top_k: Accepted but not used by the current body.
        batch_size: Batch size for both the database pass and the test pass.

    Raises:
        ValueError: If ``dataset`` is not one of the two supported names.

    NOTE(review): the test loop only iterates over the model outputs; no
    ranking or metric is computed yet and the function returns ``None``.
    """
    if dataset == "CIFAR-10":
        hub_id = "uoft-cs/cifar10"
    elif dataset == "CIFAR-100":
        hub_id = "uoft-cs/cifar100"
    else:
        raise ValueError(
            f"Unknown dataset: {dataset}. Only support: 'CIFAR-10', 'CIFAR-100'."
        )

    gallery_split = load_dataset(hub_id, split="train")
    query_split = load_dataset(hub_id, split="test")

    # Database pass over the train split (result is not consumed below yet).
    database = establish_database(processor, model, gallery_split, batch_size)

    # Query pass over the test split.
    query_loader = DataLoader(
        query_split.with_format("torch"),
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
    )

    with torch.no_grad():
        for batch in tqdm(query_loader, desc="Test Evaluation"):
            images, targets = batch["img"], batch["label"]

            encoded = processor(images, return_tensors="pt")
            embeddings = cast(Tensor, model(encoded.to(get_device())))

            for _embedding in embeddings:
                # Placeholder: per-query lookup is not implemented.
                pass
|
||||
Reference in New Issue
Block a user