feat(benchmarks): add evaluation framework for DINO-based compressors

This commit is contained in:
2026-02-08 22:43:38 +08:00
parent 3ba3705ba6
commit 7f6732edeb
11 changed files with 217 additions and 42 deletions

View File

@@ -0,0 +1,49 @@
from typing import Literal, cast
import torch
from compressors import DinoCompressor, FloatCompressor
from transformers import AutoImageProcessor, BitImageProcessorFast
from utils import get_device, get_output_diretory
from .task_eval import task_eval
def evaluate(
    compressor_model: Literal["Dinov2", "Dinov2WithCompressor"],
    dataset: Literal["CIFAR-10", "CIFAR-100"],
    benchmark: Literal["Recall@1", "Recall@10"],
):
    """Build the requested embedding model and run a retrieval benchmark on it.

    Args:
        compressor_model: ``"Dinov2"`` evaluates the bare DINOv2 embedding;
            ``"Dinov2WithCompressor"`` additionally attaches a
            ``FloatCompressor`` whose weights are read from ``compressor.pt``
            in the configured output directory.
        dataset: Name of the dataset forwarded to ``task_eval``.
        benchmark: ``"Recall@1"`` or ``"Recall@10"``; selects the ``top_k``
            passed to ``task_eval``.

    Returns:
        Whatever ``task_eval`` returns for the selected configuration.

    Raises:
        ValueError: If ``compressor_model`` or ``benchmark`` is not one of the
            supported values.
    """
    # Validate both selectors up front so a typo fails before any model
    # weights are loaded.
    if compressor_model not in ("Dinov2", "Dinov2WithCompressor"):
        raise ValueError(f"Unknown compressor: {compressor_model}")

    top_k_by_benchmark = {"Recall@1": 1, "Recall@10": 10}
    if benchmark not in top_k_by_benchmark:
        raise ValueError(f"Unknown benchmark: {benchmark}")
    top_k = top_k_by_benchmark[benchmark]

    device = get_device()

    # Both variants share the same image processor.
    processor = cast(
        BitImageProcessorFast,
        AutoImageProcessor.from_pretrained(
            "facebook/dinov2-large", device_map=device
        ),
    )

    compressor = None
    if compressor_model == "Dinov2WithCompressor":
        compressor = FloatCompressor()
        # load_state_dict() mutates the module in place and returns a report
        # of missing/unexpected keys, NOT the module, so its result must not
        # be used as the compressor.
        compressor.load_state_dict(
            torch.load(
                get_output_diretory() / "compressor.pt",
                # Allow a checkpoint saved on another device to be loaded here.
                map_location=device,
            )
        )

    model = DinoCompressor(compressor).to(device)

    return task_eval(processor, model, dataset, top_k)
# Names re-exported by this package: the retrieval task itself and the
# high-level entry point that builds a model and runs it.
__all__ = ["task_eval", "evaluate"]

View File

@@ -0,0 +1,77 @@
from typing import Literal, cast
import polars as pl
import torch
from datasets import Dataset, load_dataset
from torch import Tensor, nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import BitImageProcessorFast
from utils import get_device
def establish_database(
processor: BitImageProcessorFast,
model: nn.Module,
dataset: Dataset,
batch_size: int = 32,
) -> pl.DataFrame:
df = pl.DataFrame()
model.eval()
dataloader = DataLoader(
dataset.with_format("torch"), batch_size=batch_size, shuffle=True, num_workers=4
)
with torch.no_grad():
for batch in tqdm(dataloader, desc="Establish Database"):
imgs = batch["img"]
labels = batch["label"]
inputs = processor(imgs, return_tensors="pt").to(get_device())
outputs = cast(Tensor, model(inputs))
return df
def task_eval(
processor: BitImageProcessorFast,
model: nn.Module,
dataset: Literal["CIFAR-10", "CIFAR-100"],
top_k: int = 10,
batch_size: int = 32,
):
match dataset:
case "CIFAR-10":
train_dataset = load_dataset("uoft-cs/cifar10", split="train")
test_dataset = load_dataset("uoft-cs/cifar10", split="test")
case "CIFAR-100":
train_dataset = load_dataset("uoft-cs/cifar100", split="train")
test_dataset = load_dataset("uoft-cs/cifar100", split="test")
case _:
raise ValueError(
f"Unknown dataset: {dataset}. Only support: 'CIFAR-10', 'CIFAR-100'."
)
# Establish database
df = establish_database(processor, model, train_dataset, batch_size)
# Test
dataloader = DataLoader(
test_dataset.with_format("torch"),
batch_size=batch_size,
shuffle=True,
num_workers=4,
)
with torch.no_grad():
for batch in tqdm(dataloader, desc="Test Evaluation"):
imgs = batch["img"]
labels = batch["label"]
inputs = processor(imgs, return_tensors="pt").to(get_device())
outputs = cast(Tensor, model(inputs))
for vec in outputs:
pass

View File

@@ -1,5 +1,6 @@
from .dino_compressor import DinoCompressor
from .float_compressor import FloatCompressor from .float_compressor import FloatCompressor
from .int_compressor import IntCompressor from .int_compressor import IntCompressor
from .train import train from .train import train
__all__ = ["train", "FloatCompressor", "IntCompressor"] __all__ = ["train", "FloatCompressor", "IntCompressor", "DinoCompressor"]

View File

@@ -0,0 +1,29 @@
from typing import Optional, cast
import torch.nn.functional as F
from torch import nn
from transformers import AutoModel, Dinov2Model
class DinoCompressor(nn.Module):
    """DINOv2 backbone with an optional compression head.

    Without a compressor the module returns the mean-pooled, L2-normalised
    token embedding of the backbone. With a compressor the backbone tokens are
    handed to it and the first element of its output is returned.
    """

    def __init__(self, compressor: Optional[nn.Module] = None):
        """Load the ``facebook/dinov2-large`` backbone and store the head.

        Args:
            compressor: Optional module called with the backbone tokens; it
                must return a pair whose first element is the feature tensor.
        """
        super().__init__()

        self.dino = cast(
            Dinov2Model,
            AutoModel.from_pretrained("facebook/dinov2-large"),
        )
        self.compressor = compressor

    def forward(self, inputs):
        """Embed a processed image batch.

        Args:
            inputs: Mapping of keyword arguments accepted by the backbone
                (it is unpacked with ``**``).

        Returns:
            The compressor's features when a compressor is attached, otherwise
            the normalised mean of the backbone tokens.
        """
        teacher_tokens = self.dino(**inputs).last_hidden_state  # [B,N,1024]

        if self.compressor is not None:
            # The second output is not needed here, so only the features are
            # kept.
            feats, _recon = self.compressor(teacher_tokens)
            return feats

        # Pool and normalise only when this embedding is actually returned.
        teacher_embed = teacher_tokens.mean(dim=1)
        return F.normalize(teacher_embed, dim=-1)  # [B,1024]

View File

@@ -93,12 +93,10 @@ def train(
with torch.no_grad(): with torch.no_grad():
inputs = processor(imgs, return_tensors="pt").to(device) inputs = processor(imgs, return_tensors="pt").to(device)
teacher_tokens = dino(**inputs).last_hidden_state teacher_tokens = dino(**inputs).last_hidden_state # [B,N,1024]
# [B,N,1024]
teacher_embed = teacher_tokens.mean(dim=1) teacher_embed = teacher_tokens.mean(dim=1)
teacher_embed = F.normalize(teacher_embed, dim=-1) teacher_embed = F.normalize(teacher_embed, dim=-1) # [B,1024]
# [B,1024]
# ---- student forward ---- # ---- student forward ----
z512, recon = compressor(teacher_tokens) z512, recon = compressor(teacher_tokens)

View File

@@ -6,8 +6,14 @@ from database import db_manager
from datasets import load_dataset from datasets import load_dataset
from PIL import Image from PIL import Image
from PIL.PngImagePlugin import PngImageFile from PIL.PngImagePlugin import PngImageFile
from torch import nn
from tqdm.auto import tqdm from tqdm.auto import tqdm
from transformers import AutoImageProcessor, AutoModel from transformers import (
AutoImageProcessor,
AutoModel,
BitImageProcessorFast,
Dinov2Model,
)
def pil_image_to_bytes(image: Image.Image, format: str = "PNG") -> bytes: def pil_image_to_bytes(image: Image.Image, format: str = "PNG") -> bytes:
@@ -31,8 +37,8 @@ class FeatureRetrieval:
_instance: Optional["FeatureRetrieval"] = None _instance: Optional["FeatureRetrieval"] = None
_initialized: bool = False _initialized: bool = False
processor: Any processor: BitImageProcessorFast
model: Any model: nn.Module
def __new__(cls, *args, **kwargs) -> "FeatureRetrieval": def __new__(cls, *args, **kwargs) -> "FeatureRetrieval":
if cls._instance is None: if cls._instance is None:
@@ -40,7 +46,9 @@ class FeatureRetrieval:
return cls._instance return cls._instance
def __init__( def __init__(
self, processor: Optional[Any] = None, model: Optional[Any] = None self,
processor: Optional[BitImageProcessorFast] = None,
model: Optional[nn.Module] = None,
) -> None: ) -> None:
"""Initialize the singleton with processor and model. """Initialize the singleton with processor and model.
@@ -84,10 +92,10 @@ class FeatureRetrieval:
for i in tqdm(range(0, len(images), batch_size)): for i in tqdm(range(0, len(images), batch_size)):
batch_imgs = images[i : i + batch_size] batch_imgs = images[i : i + batch_size]
inputs = self.processor(images=batch_imgs, return_tensors="pt") inputs = self.processor(batch_imgs, return_tensors="pt")
# 迁移数据到GPU # 迁移数据到GPU
inputs.to(device, non_blocking=True) inputs.to(device)
outputs = self.model(**inputs) outputs = self.model(**inputs)
@@ -166,10 +174,14 @@ if __name__ == "__main__":
"truck", "truck",
] ]
processor = AutoImageProcessor.from_pretrained( processor = cast(
"facebook/dinov2-large", device_map="cuda" BitImageProcessorFast,
AutoImageProcessor.from_pretrained("facebook/dinov2-large", device_map="cuda"),
)
model = cast(
Dinov2Model,
AutoModel.from_pretrained("facebook/dinov2-large", device_map="cuda"),
) )
model = AutoModel.from_pretrained("facebook/dinov2-large", device_map="cuda")
feature_retrieval = FeatureRetrieval(processor, model) feature_retrieval = FeatureRetrieval(processor, model)

View File

@@ -2,14 +2,22 @@ import argparse
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("train") parser.add_argument(
"action",
choices=["train", "benchmark", "visualize"],
help="Action to perform: train, benchmark, or visualize",
)
args = parser.parse_args() args = parser.parse_args()
if args.train: if args.action == "train":
from compressors import FloatCompressor, train from compressors import FloatCompressor, train
train(FloatCompressor(), 1, 32) train(FloatCompressor(), 1, 32)
else: elif args.action == "benchmark":
from benchmarks import evaluate
evaluate("Dinov2", "CIFAR-10", "Recall@10")
else: # visualize
from visualizer import app from visualizer import app
app.run(debug=True) app.run(debug=True)

View File

@@ -0,0 +1,3 @@
from .common import get_device, get_output_diretory
# Helpers re-exported from ``.common`` for use by the rest of the project.
__all__ = ["get_device", "get_output_diretory"]

23
mini-nav/utils/common.py Normal file
View File

@@ -0,0 +1,23 @@
from functools import lru_cache
from pathlib import Path
import torch
from configs import cfg_manager
from torch.types import Device
@lru_cache(maxsize=1)
def get_device() -> Device:
    """Resolve the torch device named by the ``model.device`` config entry.

    The value ``"auto"`` selects CUDA when it is available and the CPU
    otherwise; any other value is passed to ``torch.device`` unchanged. The
    result is cached, so the configuration is consulted only on the first call.
    """
    requested = cfg_manager.get().model.device

    if requested != "auto":
        return torch.device(requested)

    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
@lru_cache(maxsize=1)
def get_output_diretory() -> Path:
    """Return the ``output.directory`` entry of the configuration.

    The result is cached, so the configuration is consulted only on the first
    call.
    """
    return cfg_manager.get().output.directory

View File

@@ -15,7 +15,6 @@ dependencies = [
"polars[database,numpy,pandas,pydantic]>=1.37.1", "polars[database,numpy,pandas,pydantic]>=1.37.1",
"pyarrow>=23.0.0", "pyarrow>=23.0.0",
"pydantic>=2.12.5", "pydantic>=2.12.5",
"pytest-benchmark>=5.2.3",
"scikit-learn>=1.7.2", "scikit-learn>=1.7.2",
"torch>=2.10.0", "torch>=2.10.0",
"torchvision>=0.25.0", "torchvision>=0.25.0",

24
uv.lock generated
View File

@@ -1114,7 +1114,6 @@ dependencies = [
{ name = "polars", extra = ["database", "numpy", "pandas", "pydantic"] }, { name = "polars", extra = ["database", "numpy", "pandas", "pydantic"] },
{ name = "pyarrow" }, { name = "pyarrow" },
{ name = "pydantic" }, { name = "pydantic" },
{ name = "pytest-benchmark" },
{ name = "scikit-learn", version = "1.7.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "scikit-learn", version = "1.7.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
{ name = "scikit-learn", version = "1.8.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "scikit-learn", version = "1.8.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
{ name = "torch", version = "2.10.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, { name = "torch", version = "2.10.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
@@ -1142,7 +1141,6 @@ requires-dist = [
{ name = "polars", extras = ["database", "numpy", "pandas", "pydantic"], specifier = ">=1.37.1" }, { name = "polars", extras = ["database", "numpy", "pandas", "pydantic"], specifier = ">=1.37.1" },
{ name = "pyarrow", specifier = ">=23.0.0" }, { name = "pyarrow", specifier = ">=23.0.0" },
{ name = "pydantic", specifier = ">=2.12.5" }, { name = "pydantic", specifier = ">=2.12.5" },
{ name = "pytest-benchmark", specifier = ">=5.2.3" },
{ name = "scikit-learn", specifier = ">=1.7.2" }, { name = "scikit-learn", specifier = ">=1.7.2" },
{ name = "torch", marker = "sys_platform != 'linux' and sys_platform != 'win32'", specifier = ">=2.10.0" }, { name = "torch", marker = "sys_platform != 'linux' and sys_platform != 'win32'", specifier = ">=2.10.0" },
{ name = "torch", marker = "sys_platform == 'linux' or sys_platform == 'win32'", specifier = ">=2.10.0", index = "https://download.pytorch.org/whl/cu130" }, { name = "torch", marker = "sys_platform == 'linux' or sys_platform == 'win32'", specifier = ">=2.10.0", index = "https://download.pytorch.org/whl/cu130" },
@@ -2141,15 +2139,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/8c/c7/7bb2e321574b10df20cbde462a94e2b71d05f9bbda251ef27d104668306a/psutil-7.2.2-cp37-abi3-win_arm64.whl", hash = "sha256:8c233660f575a5a89e6d4cb65d9f938126312bca76d8fe087b947b3a1aaac9ee", size = 134617, upload-time = "2026-01-28T18:15:36.514Z" }, { url = "https://files.pythonhosted.org/packages/8c/c7/7bb2e321574b10df20cbde462a94e2b71d05f9bbda251ef27d104668306a/psutil-7.2.2-cp37-abi3-win_arm64.whl", hash = "sha256:8c233660f575a5a89e6d4cb65d9f938126312bca76d8fe087b947b3a1aaac9ee", size = 134617, upload-time = "2026-01-28T18:15:36.514Z" },
] ]
[[package]]
name = "py-cpuinfo"
version = "9.0.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/37/a8/d832f7293ebb21690860d2e01d8115e5ff6f2ae8bbdc953f0eb0fa4bd2c7/py-cpuinfo-9.0.0.tar.gz", hash = "sha256:3cdbbf3fac90dc6f118bfd64384f309edeadd902d7c8fb17f02ffa1fc3f49690", size = 104716, upload-time = "2022-10-25T20:38:06.303Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/e0/a9/023730ba63db1e494a271cb018dcd361bd2c917ba7004c3e49d5daf795a2/py_cpuinfo-9.0.0-py3-none-any.whl", hash = "sha256:859625bc251f64e21f077d099d4162689c762b5d6a4c3c97553d56241c9674d5", size = 22335, upload-time = "2022-10-25T20:38:27.636Z" },
]
[[package]] [[package]]
name = "pyarrow" name = "pyarrow"
version = "23.0.0" version = "23.0.0"
@@ -2367,19 +2356,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" }, { url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" },
] ]
[[package]]
name = "pytest-benchmark"
version = "5.2.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "py-cpuinfo" },
{ name = "pytest" },
]
sdist = { url = "https://files.pythonhosted.org/packages/24/34/9f732b76456d64faffbef6232f1f9dbec7a7c4999ff46282fa418bd1af66/pytest_benchmark-5.2.3.tar.gz", hash = "sha256:deb7317998a23c650fd4ff76e1230066a76cb45dcece0aca5607143c619e7779", size = 341340, upload-time = "2025-11-09T18:48:43.215Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/33/29/e756e715a48959f1c0045342088d7ca9762a2f509b945f362a316e9412b7/pytest_benchmark-5.2.3-py3-none-any.whl", hash = "sha256:bc839726ad20e99aaa0d11a127445457b4219bdb9e80a1afc4b51da7f96b0803", size = 45255, upload-time = "2025-11-09T18:48:39.765Z" },
]
[[package]] [[package]]
name = "pytest-mpi" name = "pytest-mpi"
version = "0.6" version = "0.6"