feat(benchmarks): add evaluation framework for DINO-based compressors

This commit is contained in:
2026-02-08 22:43:38 +08:00
parent 3ba3705ba6
commit 7f6732edeb
11 changed files with 217 additions and 42 deletions

View File

@@ -0,0 +1,49 @@
from typing import Literal, cast
import torch
from compressors import DinoCompressor, FloatCompressor
from transformers import AutoImageProcessor, BitImageProcessorFast
from utils import get_device, get_output_diretory
from .task_eval import task_eval
def evaluate(
compressor_model: Literal["Dinov2", "Dinov2WithCompressor"],
dataset: Literal["CIFAR-10", "CIFAR-100"],
benchmark: Literal["Recall@1", "Recall@10"],
):
match compressor_model:
case "Dinov2":
processor = cast(
BitImageProcessorFast,
AutoImageProcessor.from_pretrained(
"facebook/dinov2-large", device_map=get_device()
),
)
model = DinoCompressor().to(get_device())
case "Dinov2WithCompressor":
processor = cast(
BitImageProcessorFast,
AutoImageProcessor.from_pretrained(
"facebook/dinov2-large", device_map=get_device()
),
)
compressor = FloatCompressor().load_state_dict(
torch.load(get_output_diretory() / "compressor.pt")
)
model = DinoCompressor(compressor).to(get_device())
case _:
raise ValueError(f"Unknown compressor: {compressor_model}")
match benchmark:
case "Recall@1":
task_eval(processor, model, dataset, 1)
case "Recall@10":
task_eval(processor, model, dataset, 10)
case _:
raise ValueError(f"Unknown benchmark: {benchmark}")
# Names exported by a star-import of this package.
__all__ = ["task_eval", "evaluate"]

View File

@@ -0,0 +1,77 @@
from typing import Literal, cast
import polars as pl
import torch
from datasets import Dataset, load_dataset
from torch import Tensor, nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import BitImageProcessorFast
from utils import get_device
def establish_database(
    processor: BitImageProcessorFast,
    model: nn.Module,
    dataset: Dataset,
    batch_size: int = 32,
) -> pl.DataFrame:
    """Embed every sample of ``dataset`` and collect the results in a DataFrame.

    The model is switched to eval mode and run under ``torch.no_grad()``.
    Each batch is read from the ``"img"`` and ``"label"`` columns of the
    dataset.

    Args:
        processor: Image processor applied to each batch of images.
        model: Module called with the processed batch; its output is treated
            as one vector per sample.
        dataset: Dataset providing ``"img"`` and ``"label"`` columns.
        batch_size: Number of samples per forward pass.

    Returns:
        A DataFrame with one row per sample and two columns: ``label`` and
        ``vector`` (the model output for that sample as a list of floats).
        An empty DataFrame is returned when the dataset yields no batches.
    """
    model.eval()
    device = get_device()

    # Row order carries no meaning for a lookup table, so the data is read in
    # dataset order to keep the resulting frame reproducible.
    dataloader = DataLoader(
        dataset.with_format("torch"),
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
    )

    frames: list[pl.DataFrame] = []
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Establish Database"):
            imgs = batch["img"]
            labels = batch["label"]

            inputs = processor(imgs, return_tensors="pt").to(device)
            outputs = cast(Tensor, model(inputs))

            # Move to CPU before converting so device memory is released
            # batch by batch instead of accumulating.
            frames.append(
                pl.DataFrame(
                    {
                        "label": labels.tolist(),
                        "vector": outputs.cpu().tolist(),
                    }
                )
            )

    if not frames:
        return pl.DataFrame()
    return pl.concat(frames)
def task_eval(
processor: BitImageProcessorFast,
model: nn.Module,
dataset: Literal["CIFAR-10", "CIFAR-100"],
top_k: int = 10,
batch_size: int = 32,
):
match dataset:
case "CIFAR-10":
train_dataset = load_dataset("uoft-cs/cifar10", split="train")
test_dataset = load_dataset("uoft-cs/cifar10", split="test")
case "CIFAR-100":
train_dataset = load_dataset("uoft-cs/cifar100", split="train")
test_dataset = load_dataset("uoft-cs/cifar100", split="test")
case _:
raise ValueError(
f"Unknown dataset: {dataset}. Only support: 'CIFAR-10', 'CIFAR-100'."
)
# Establish database
df = establish_database(processor, model, train_dataset, batch_size)
# Test
dataloader = DataLoader(
test_dataset.with_format("torch"),
batch_size=batch_size,
shuffle=True,
num_workers=4,
)
with torch.no_grad():
for batch in tqdm(dataloader, desc="Test Evaluation"):
imgs = batch["img"]
labels = batch["label"]
inputs = processor(imgs, return_tensors="pt").to(get_device())
outputs = cast(Tensor, model(inputs))
for vec in outputs:
pass

View File

@@ -1,5 +1,6 @@
from .dino_compressor import DinoCompressor
from .float_compressor import FloatCompressor
from .int_compressor import IntCompressor
from .train import train
# Names exported by a star-import of this package.
__all__ = ["train", "FloatCompressor", "IntCompressor", "DinoCompressor"]

View File

@@ -0,0 +1,29 @@
from typing import Optional, cast
import torch.nn.functional as F
from torch import nn
from transformers import AutoModel, Dinov2Model
class DinoCompressor(nn.Module):
    """DINOv2 backbone with an optional compressor applied to its tokens.

    Without a compressor the module returns the mean-pooled, L2-normalized
    token embedding. With a compressor it returns the first element of the
    compressor's output.
    """

    def __init__(self, compressor: Optional[nn.Module] = None):
        """Load the ``facebook/dinov2-large`` backbone and store ``compressor``.

        Args:
            compressor: Optional module called with the backbone's last hidden
                state; it must return a pair whose first item is the feature
                tensor. ``None`` selects the plain pooled embedding.
        """
        super().__init__()

        self.dino = cast(
            Dinov2Model,
            AutoModel.from_pretrained("facebook/dinov2-large"),
        )
        self.compressor = compressor

    def forward(self, inputs):
        """Embed a processed batch.

        Args:
            inputs: Mapping of keyword arguments unpacked into the backbone call.

        Returns:
            The compressor features when a compressor is attached, otherwise
            the normalized mean of the backbone tokens.
        """
        teacher_tokens = self.dino(**inputs).last_hidden_state  # [B,N,1024]

        if self.compressor is not None:
            # Only the features are needed here; the second output is ignored.
            feats, _recon = self.compressor(teacher_tokens)
            return feats

        # Pooling is only computed on the path that actually returns it.
        teacher_embed = teacher_tokens.mean(dim=1)
        return F.normalize(teacher_embed, dim=-1)  # [B,1024]

View File

@@ -93,12 +93,10 @@ def train(
with torch.no_grad():
inputs = processor(imgs, return_tensors="pt").to(device)
teacher_tokens = dino(**inputs).last_hidden_state
# [B,N,1024]
teacher_tokens = dino(**inputs).last_hidden_state # [B,N,1024]
teacher_embed = teacher_tokens.mean(dim=1)
teacher_embed = F.normalize(teacher_embed, dim=-1)
# [B,1024]
teacher_embed = F.normalize(teacher_embed, dim=-1) # [B,1024]
# ---- student forward ----
z512, recon = compressor(teacher_tokens)

View File

@@ -6,8 +6,14 @@ from database import db_manager
from datasets import load_dataset
from PIL import Image
from PIL.PngImagePlugin import PngImageFile
from torch import nn
from tqdm.auto import tqdm
from transformers import AutoImageProcessor, AutoModel
from transformers import (
AutoImageProcessor,
AutoModel,
BitImageProcessorFast,
Dinov2Model,
)
def pil_image_to_bytes(image: Image.Image, format: str = "PNG") -> bytes:
@@ -31,8 +37,8 @@ class FeatureRetrieval:
_instance: Optional["FeatureRetrieval"] = None
_initialized: bool = False
processor: Any
model: Any
processor: BitImageProcessorFast
model: nn.Module
def __new__(cls, *args, **kwargs) -> "FeatureRetrieval":
if cls._instance is None:
@@ -40,7 +46,9 @@ class FeatureRetrieval:
return cls._instance
def __init__(
self, processor: Optional[Any] = None, model: Optional[Any] = None
self,
processor: Optional[BitImageProcessorFast] = None,
model: Optional[nn.Module] = None,
) -> None:
"""Initialize the singleton with processor and model.
@@ -84,10 +92,10 @@ class FeatureRetrieval:
for i in tqdm(range(0, len(images), batch_size)):
batch_imgs = images[i : i + batch_size]
inputs = self.processor(images=batch_imgs, return_tensors="pt")
inputs = self.processor(batch_imgs, return_tensors="pt")
# 迁移数据到GPU
inputs.to(device, non_blocking=True)
inputs.to(device)
outputs = self.model(**inputs)
@@ -166,10 +174,14 @@ if __name__ == "__main__":
"truck",
]
processor = AutoImageProcessor.from_pretrained(
"facebook/dinov2-large", device_map="cuda"
processor = cast(
BitImageProcessorFast,
AutoImageProcessor.from_pretrained("facebook/dinov2-large", device_map="cuda"),
)
model = cast(
Dinov2Model,
AutoModel.from_pretrained("facebook/dinov2-large", device_map="cuda"),
)
model = AutoModel.from_pretrained("facebook/dinov2-large", device_map="cuda")
feature_retrieval = FeatureRetrieval(processor, model)

View File

@@ -2,14 +2,22 @@ import argparse
def main(argv=None):
    """Parse the command line and run the selected action.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "action",
        choices=["train", "benchmark", "visualize"],
        help="Action to perform: train, benchmark, or visualize",
    )
    args = parser.parse_args(argv)

    # Imports stay inside each branch, so only the selected action's package
    # is imported.
    if args.action == "train":
        from compressors import FloatCompressor, train

        train(FloatCompressor(), 1, 32)
    elif args.action == "benchmark":
        from benchmarks import evaluate

        evaluate("Dinov2", "CIFAR-10", "Recall@10")
    else:  # visualize
        from visualizer import app

        app.run(debug=True)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,3 @@
from .common import get_device, get_output_diretory
# Names exported by a star-import of this package.
# NOTE(review): "get_output_diretory" looks like a misspelling of "directory";
# renaming it would require updating every importer, so it is left as is.
__all__ = ["get_device", "get_output_diretory"]

23
mini-nav/utils/common.py Normal file
View File

@@ -0,0 +1,23 @@
from functools import lru_cache
from pathlib import Path
import torch
from configs import cfg_manager
from torch.types import Device
@lru_cache(maxsize=1)
def get_device() -> Device:
    """Return the torch device named by the ``model.device`` config entry.

    The value ``"auto"`` selects CUDA when ``torch.cuda.is_available()`` is
    true and the CPU otherwise; any other value is passed to ``torch.device``.
    The result is cached, so the configuration is read only on the first call.
    """
    requested = cfg_manager.get().model.device

    if requested != "auto":
        return torch.device(requested)

    backend = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(backend)
@lru_cache(maxsize=1)
def get_output_diretory() -> Path:
    """Return the ``output.directory`` entry of the project configuration.

    The result is cached, so the configuration is read only on the first call.
    """
    return cfg_manager.get().output.directory