mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
feat(benchmarks): add evaluation framework for DINO-based compressors
This commit is contained in:
49
mini-nav/benchmarks/__init__.py
Normal file
49
mini-nav/benchmarks/__init__.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
from typing import Literal, cast
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from compressors import DinoCompressor, FloatCompressor
|
||||||
|
from transformers import AutoImageProcessor, BitImageProcessorFast
|
||||||
|
from utils import get_device, get_output_diretory
|
||||||
|
|
||||||
|
from .task_eval import task_eval
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(
|
||||||
|
compressor_model: Literal["Dinov2", "Dinov2WithCompressor"],
|
||||||
|
dataset: Literal["CIFAR-10", "CIFAR-100"],
|
||||||
|
benchmark: Literal["Recall@1", "Recall@10"],
|
||||||
|
):
|
||||||
|
match compressor_model:
|
||||||
|
case "Dinov2":
|
||||||
|
processor = cast(
|
||||||
|
BitImageProcessorFast,
|
||||||
|
AutoImageProcessor.from_pretrained(
|
||||||
|
"facebook/dinov2-large", device_map=get_device()
|
||||||
|
),
|
||||||
|
)
|
||||||
|
model = DinoCompressor().to(get_device())
|
||||||
|
case "Dinov2WithCompressor":
|
||||||
|
processor = cast(
|
||||||
|
BitImageProcessorFast,
|
||||||
|
AutoImageProcessor.from_pretrained(
|
||||||
|
"facebook/dinov2-large", device_map=get_device()
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
compressor = FloatCompressor().load_state_dict(
|
||||||
|
torch.load(get_output_diretory() / "compressor.pt")
|
||||||
|
)
|
||||||
|
model = DinoCompressor(compressor).to(get_device())
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Unknown compressor: {compressor_model}")
|
||||||
|
|
||||||
|
match benchmark:
|
||||||
|
case "Recall@1":
|
||||||
|
task_eval(processor, model, dataset, 1)
|
||||||
|
case "Recall@10":
|
||||||
|
task_eval(processor, model, dataset, 10)
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Unknown benchmark: {benchmark}")
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["task_eval", "evaluate"]
|
||||||
77
mini-nav/benchmarks/task_eval.py
Normal file
77
mini-nav/benchmarks/task_eval.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
from typing import Literal, cast
|
||||||
|
|
||||||
|
import polars as pl
|
||||||
|
import torch
|
||||||
|
from datasets import Dataset, load_dataset
|
||||||
|
from torch import Tensor, nn
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
from transformers import BitImageProcessorFast
|
||||||
|
from utils import get_device
|
||||||
|
|
||||||
|
|
||||||
|
def establish_database(
|
||||||
|
processor: BitImageProcessorFast,
|
||||||
|
model: nn.Module,
|
||||||
|
dataset: Dataset,
|
||||||
|
batch_size: int = 32,
|
||||||
|
) -> pl.DataFrame:
|
||||||
|
df = pl.DataFrame()
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
dataloader = DataLoader(
|
||||||
|
dataset.with_format("torch"), batch_size=batch_size, shuffle=True, num_workers=4
|
||||||
|
)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch in tqdm(dataloader, desc="Establish Database"):
|
||||||
|
imgs = batch["img"]
|
||||||
|
labels = batch["label"]
|
||||||
|
|
||||||
|
inputs = processor(imgs, return_tensors="pt").to(get_device())
|
||||||
|
|
||||||
|
outputs = cast(Tensor, model(inputs))
|
||||||
|
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
def task_eval(
|
||||||
|
processor: BitImageProcessorFast,
|
||||||
|
model: nn.Module,
|
||||||
|
dataset: Literal["CIFAR-10", "CIFAR-100"],
|
||||||
|
top_k: int = 10,
|
||||||
|
batch_size: int = 32,
|
||||||
|
):
|
||||||
|
match dataset:
|
||||||
|
case "CIFAR-10":
|
||||||
|
train_dataset = load_dataset("uoft-cs/cifar10", split="train")
|
||||||
|
test_dataset = load_dataset("uoft-cs/cifar10", split="test")
|
||||||
|
case "CIFAR-100":
|
||||||
|
train_dataset = load_dataset("uoft-cs/cifar100", split="train")
|
||||||
|
test_dataset = load_dataset("uoft-cs/cifar100", split="test")
|
||||||
|
case _:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown dataset: {dataset}. Only support: 'CIFAR-10', 'CIFAR-100'."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Establish database
|
||||||
|
df = establish_database(processor, model, train_dataset, batch_size)
|
||||||
|
|
||||||
|
# Test
|
||||||
|
dataloader = DataLoader(
|
||||||
|
test_dataset.with_format("torch"),
|
||||||
|
batch_size=batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
num_workers=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch in tqdm(dataloader, desc="Test Evaluation"):
|
||||||
|
imgs = batch["img"]
|
||||||
|
labels = batch["label"]
|
||||||
|
|
||||||
|
inputs = processor(imgs, return_tensors="pt").to(get_device())
|
||||||
|
|
||||||
|
outputs = cast(Tensor, model(inputs))
|
||||||
|
for vec in outputs:
|
||||||
|
pass
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
|
from .dino_compressor import DinoCompressor
|
||||||
from .float_compressor import FloatCompressor
|
from .float_compressor import FloatCompressor
|
||||||
from .int_compressor import IntCompressor
|
from .int_compressor import IntCompressor
|
||||||
from .train import train
|
from .train import train
|
||||||
|
|
||||||
__all__ = ["train", "FloatCompressor", "IntCompressor"]
|
__all__ = ["train", "FloatCompressor", "IntCompressor", "DinoCompressor"]
|
||||||
|
|||||||
29
mini-nav/compressors/dino_compressor.py
Normal file
29
mini-nav/compressors/dino_compressor.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
from typing import Optional, cast
|
||||||
|
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import nn
|
||||||
|
from transformers import AutoModel, Dinov2Model
|
||||||
|
|
||||||
|
|
||||||
|
class DinoCompressor(nn.Module):
|
||||||
|
def __init__(self, compressor: Optional[nn.Module] = None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dino = cast(
|
||||||
|
Dinov2Model,
|
||||||
|
AutoModel.from_pretrained("facebook/dinov2-large"),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.compressor = compressor
|
||||||
|
|
||||||
|
def forward(self, inputs):
|
||||||
|
teacher_tokens = self.dino(**inputs).last_hidden_state # [B,N,1024]
|
||||||
|
|
||||||
|
teacher_embed = teacher_tokens.mean(dim=1)
|
||||||
|
teacher_embed = F.normalize(teacher_embed, dim=-1) # [B,1024]
|
||||||
|
|
||||||
|
if self.compressor is None:
|
||||||
|
return teacher_embed
|
||||||
|
|
||||||
|
feats, recon = self.compressor(teacher_tokens)
|
||||||
|
return feats
|
||||||
@@ -93,12 +93,10 @@ def train(
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
inputs = processor(imgs, return_tensors="pt").to(device)
|
inputs = processor(imgs, return_tensors="pt").to(device)
|
||||||
|
|
||||||
teacher_tokens = dino(**inputs).last_hidden_state
|
teacher_tokens = dino(**inputs).last_hidden_state # [B,N,1024]
|
||||||
# [B,N,1024]
|
|
||||||
|
|
||||||
teacher_embed = teacher_tokens.mean(dim=1)
|
teacher_embed = teacher_tokens.mean(dim=1)
|
||||||
teacher_embed = F.normalize(teacher_embed, dim=-1)
|
teacher_embed = F.normalize(teacher_embed, dim=-1) # [B,1024]
|
||||||
# [B,1024]
|
|
||||||
|
|
||||||
# ---- student forward ----
|
# ---- student forward ----
|
||||||
z512, recon = compressor(teacher_tokens)
|
z512, recon = compressor(teacher_tokens)
|
||||||
|
|||||||
@@ -6,8 +6,14 @@ from database import db_manager
|
|||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from PIL.PngImagePlugin import PngImageFile
|
from PIL.PngImagePlugin import PngImageFile
|
||||||
|
from torch import nn
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
from transformers import AutoImageProcessor, AutoModel
|
from transformers import (
|
||||||
|
AutoImageProcessor,
|
||||||
|
AutoModel,
|
||||||
|
BitImageProcessorFast,
|
||||||
|
Dinov2Model,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def pil_image_to_bytes(image: Image.Image, format: str = "PNG") -> bytes:
|
def pil_image_to_bytes(image: Image.Image, format: str = "PNG") -> bytes:
|
||||||
@@ -31,8 +37,8 @@ class FeatureRetrieval:
|
|||||||
_instance: Optional["FeatureRetrieval"] = None
|
_instance: Optional["FeatureRetrieval"] = None
|
||||||
|
|
||||||
_initialized: bool = False
|
_initialized: bool = False
|
||||||
processor: Any
|
processor: BitImageProcessorFast
|
||||||
model: Any
|
model: nn.Module
|
||||||
|
|
||||||
def __new__(cls, *args, **kwargs) -> "FeatureRetrieval":
|
def __new__(cls, *args, **kwargs) -> "FeatureRetrieval":
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
@@ -40,7 +46,9 @@ class FeatureRetrieval:
|
|||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, processor: Optional[Any] = None, model: Optional[Any] = None
|
self,
|
||||||
|
processor: Optional[BitImageProcessorFast] = None,
|
||||||
|
model: Optional[nn.Module] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the singleton with processor and model.
|
"""Initialize the singleton with processor and model.
|
||||||
|
|
||||||
@@ -84,10 +92,10 @@ class FeatureRetrieval:
|
|||||||
for i in tqdm(range(0, len(images), batch_size)):
|
for i in tqdm(range(0, len(images), batch_size)):
|
||||||
batch_imgs = images[i : i + batch_size]
|
batch_imgs = images[i : i + batch_size]
|
||||||
|
|
||||||
inputs = self.processor(images=batch_imgs, return_tensors="pt")
|
inputs = self.processor(batch_imgs, return_tensors="pt")
|
||||||
|
|
||||||
# 迁移数据到GPU
|
# 迁移数据到GPU
|
||||||
inputs.to(device, non_blocking=True)
|
inputs.to(device)
|
||||||
|
|
||||||
outputs = self.model(**inputs)
|
outputs = self.model(**inputs)
|
||||||
|
|
||||||
@@ -166,10 +174,14 @@ if __name__ == "__main__":
|
|||||||
"truck",
|
"truck",
|
||||||
]
|
]
|
||||||
|
|
||||||
processor = AutoImageProcessor.from_pretrained(
|
processor = cast(
|
||||||
"facebook/dinov2-large", device_map="cuda"
|
BitImageProcessorFast,
|
||||||
|
AutoImageProcessor.from_pretrained("facebook/dinov2-large", device_map="cuda"),
|
||||||
|
)
|
||||||
|
model = cast(
|
||||||
|
Dinov2Model,
|
||||||
|
AutoModel.from_pretrained("facebook/dinov2-large", device_map="cuda"),
|
||||||
)
|
)
|
||||||
model = AutoModel.from_pretrained("facebook/dinov2-large", device_map="cuda")
|
|
||||||
|
|
||||||
feature_retrieval = FeatureRetrieval(processor, model)
|
feature_retrieval = FeatureRetrieval(processor, model)
|
||||||
|
|
||||||
|
|||||||
@@ -2,14 +2,22 @@ import argparse
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("train")
|
parser.add_argument(
|
||||||
|
"action",
|
||||||
|
choices=["train", "benchmark", "visualize"],
|
||||||
|
help="Action to perform: train, benchmark, or visualize",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.train:
|
if args.action == "train":
|
||||||
from compressors import FloatCompressor, train
|
from compressors import FloatCompressor, train
|
||||||
|
|
||||||
train(FloatCompressor(), 1, 32)
|
train(FloatCompressor(), 1, 32)
|
||||||
else:
|
elif args.action == "benchmark":
|
||||||
|
from benchmarks import evaluate
|
||||||
|
|
||||||
|
evaluate("Dinov2", "CIFAR-10", "Recall@10")
|
||||||
|
else: # visualize
|
||||||
from visualizer import app
|
from visualizer import app
|
||||||
|
|
||||||
app.run(debug=True)
|
app.run(debug=True)
|
||||||
|
|||||||
3
mini-nav/utils/__init__.py
Normal file
3
mini-nav/utils/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .common import get_device, get_output_diretory
|
||||||
|
|
||||||
|
__all__ = ["get_device", "get_output_diretory"]
|
||||||
23
mini-nav/utils/common.py
Normal file
23
mini-nav/utils/common.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
from functools import lru_cache
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from configs import cfg_manager
|
||||||
|
from torch.types import Device
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
|
def get_device() -> Device:
|
||||||
|
config = cfg_manager.get()
|
||||||
|
device = config.model.device
|
||||||
|
if device == "auto":
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
else:
|
||||||
|
device = torch.device(device)
|
||||||
|
return device
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
|
def get_output_diretory() -> Path:
|
||||||
|
config = cfg_manager.get()
|
||||||
|
return config.output.directory
|
||||||
@@ -15,7 +15,6 @@ dependencies = [
|
|||||||
"polars[database,numpy,pandas,pydantic]>=1.37.1",
|
"polars[database,numpy,pandas,pydantic]>=1.37.1",
|
||||||
"pyarrow>=23.0.0",
|
"pyarrow>=23.0.0",
|
||||||
"pydantic>=2.12.5",
|
"pydantic>=2.12.5",
|
||||||
"pytest-benchmark>=5.2.3",
|
|
||||||
"scikit-learn>=1.7.2",
|
"scikit-learn>=1.7.2",
|
||||||
"torch>=2.10.0",
|
"torch>=2.10.0",
|
||||||
"torchvision>=0.25.0",
|
"torchvision>=0.25.0",
|
||||||
|
|||||||
24
uv.lock
generated
24
uv.lock
generated
@@ -1114,7 +1114,6 @@ dependencies = [
|
|||||||
{ name = "polars", extra = ["database", "numpy", "pandas", "pydantic"] },
|
{ name = "polars", extra = ["database", "numpy", "pandas", "pydantic"] },
|
||||||
{ name = "pyarrow" },
|
{ name = "pyarrow" },
|
||||||
{ name = "pydantic" },
|
{ name = "pydantic" },
|
||||||
{ name = "pytest-benchmark" },
|
|
||||||
{ name = "scikit-learn", version = "1.7.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
|
{ name = "scikit-learn", version = "1.7.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
|
||||||
{ name = "scikit-learn", version = "1.8.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
|
{ name = "scikit-learn", version = "1.8.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
|
||||||
{ name = "torch", version = "2.10.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
|
{ name = "torch", version = "2.10.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" },
|
||||||
@@ -1142,7 +1141,6 @@ requires-dist = [
|
|||||||
{ name = "polars", extras = ["database", "numpy", "pandas", "pydantic"], specifier = ">=1.37.1" },
|
{ name = "polars", extras = ["database", "numpy", "pandas", "pydantic"], specifier = ">=1.37.1" },
|
||||||
{ name = "pyarrow", specifier = ">=23.0.0" },
|
{ name = "pyarrow", specifier = ">=23.0.0" },
|
||||||
{ name = "pydantic", specifier = ">=2.12.5" },
|
{ name = "pydantic", specifier = ">=2.12.5" },
|
||||||
{ name = "pytest-benchmark", specifier = ">=5.2.3" },
|
|
||||||
{ name = "scikit-learn", specifier = ">=1.7.2" },
|
{ name = "scikit-learn", specifier = ">=1.7.2" },
|
||||||
{ name = "torch", marker = "sys_platform != 'linux' and sys_platform != 'win32'", specifier = ">=2.10.0" },
|
{ name = "torch", marker = "sys_platform != 'linux' and sys_platform != 'win32'", specifier = ">=2.10.0" },
|
||||||
{ name = "torch", marker = "sys_platform == 'linux' or sys_platform == 'win32'", specifier = ">=2.10.0", index = "https://download.pytorch.org/whl/cu130" },
|
{ name = "torch", marker = "sys_platform == 'linux' or sys_platform == 'win32'", specifier = ">=2.10.0", index = "https://download.pytorch.org/whl/cu130" },
|
||||||
@@ -2141,15 +2139,6 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/8c/c7/7bb2e321574b10df20cbde462a94e2b71d05f9bbda251ef27d104668306a/psutil-7.2.2-cp37-abi3-win_arm64.whl", hash = "sha256:8c233660f575a5a89e6d4cb65d9f938126312bca76d8fe087b947b3a1aaac9ee", size = 134617, upload-time = "2026-01-28T18:15:36.514Z" },
|
{ url = "https://files.pythonhosted.org/packages/8c/c7/7bb2e321574b10df20cbde462a94e2b71d05f9bbda251ef27d104668306a/psutil-7.2.2-cp37-abi3-win_arm64.whl", hash = "sha256:8c233660f575a5a89e6d4cb65d9f938126312bca76d8fe087b947b3a1aaac9ee", size = 134617, upload-time = "2026-01-28T18:15:36.514Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "py-cpuinfo"
|
|
||||||
version = "9.0.0"
|
|
||||||
source = { registry = "https://pypi.org/simple" }
|
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/37/a8/d832f7293ebb21690860d2e01d8115e5ff6f2ae8bbdc953f0eb0fa4bd2c7/py-cpuinfo-9.0.0.tar.gz", hash = "sha256:3cdbbf3fac90dc6f118bfd64384f309edeadd902d7c8fb17f02ffa1fc3f49690", size = 104716, upload-time = "2022-10-25T20:38:06.303Z" }
|
|
||||||
wheels = [
|
|
||||||
{ url = "https://files.pythonhosted.org/packages/e0/a9/023730ba63db1e494a271cb018dcd361bd2c917ba7004c3e49d5daf795a2/py_cpuinfo-9.0.0-py3-none-any.whl", hash = "sha256:859625bc251f64e21f077d099d4162689c762b5d6a4c3c97553d56241c9674d5", size = 22335, upload-time = "2022-10-25T20:38:27.636Z" },
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pyarrow"
|
name = "pyarrow"
|
||||||
version = "23.0.0"
|
version = "23.0.0"
|
||||||
@@ -2367,19 +2356,6 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" },
|
{ url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "pytest-benchmark"
|
|
||||||
version = "5.2.3"
|
|
||||||
source = { registry = "https://pypi.org/simple" }
|
|
||||||
dependencies = [
|
|
||||||
{ name = "py-cpuinfo" },
|
|
||||||
{ name = "pytest" },
|
|
||||||
]
|
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/24/34/9f732b76456d64faffbef6232f1f9dbec7a7c4999ff46282fa418bd1af66/pytest_benchmark-5.2.3.tar.gz", hash = "sha256:deb7317998a23c650fd4ff76e1230066a76cb45dcece0aca5607143c619e7779", size = 341340, upload-time = "2025-11-09T18:48:43.215Z" }
|
|
||||||
wheels = [
|
|
||||||
{ url = "https://files.pythonhosted.org/packages/33/29/e756e715a48959f1c0045342088d7ca9762a2f509b945f362a316e9412b7/pytest_benchmark-5.2.3-py3-none-any.whl", hash = "sha256:bc839726ad20e99aaa0d11a127445457b4219bdb9e80a1afc4b51da7f96b0803", size = 45255, upload-time = "2025-11-09T18:48:39.765Z" },
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pytest-mpi"
|
name = "pytest-mpi"
|
||||||
version = "0.6"
|
version = "0.6"
|
||||||
|
|||||||
Reference in New Issue
Block a user