diff --git a/mini-nav/benchmarks/__init__.py b/mini-nav/benchmarks/__init__.py new file mode 100644 index 0000000..62413a7 --- /dev/null +++ b/mini-nav/benchmarks/__init__.py @@ -0,0 +1,49 @@ +from typing import Literal, cast + +import torch +from compressors import DinoCompressor, FloatCompressor +from transformers import AutoImageProcessor, BitImageProcessorFast +from utils import get_device, get_output_diretory + +from .task_eval import task_eval + + +def evaluate( + compressor_model: Literal["Dinov2", "Dinov2WithCompressor"], + dataset: Literal["CIFAR-10", "CIFAR-100"], + benchmark: Literal["Recall@1", "Recall@10"], +): + match compressor_model: + case "Dinov2": + processor = cast( + BitImageProcessorFast, + AutoImageProcessor.from_pretrained( + "facebook/dinov2-large", device_map=get_device() + ), + ) + model = DinoCompressor().to(get_device()) + case "Dinov2WithCompressor": + processor = cast( + BitImageProcessorFast, + AutoImageProcessor.from_pretrained( + "facebook/dinov2-large", device_map=get_device() + ), + ) + + compressor = FloatCompressor().load_state_dict( + torch.load(get_output_diretory() / "compressor.pt") + ) + model = DinoCompressor(compressor).to(get_device()) + case _: + raise ValueError(f"Unknown compressor: {compressor_model}") + + match benchmark: + case "Recall@1": + task_eval(processor, model, dataset, 1) + case "Recall@10": + task_eval(processor, model, dataset, 10) + case _: + raise ValueError(f"Unknown benchmark: {benchmark}") + + +__all__ = ["task_eval", "evaluate"] diff --git a/mini-nav/benchmarks/task_eval.py b/mini-nav/benchmarks/task_eval.py new file mode 100644 index 0000000..7775652 --- /dev/null +++ b/mini-nav/benchmarks/task_eval.py @@ -0,0 +1,77 @@ +from typing import Literal, cast + +import polars as pl +import torch +from datasets import Dataset, load_dataset +from torch import Tensor, nn +from torch.utils.data import DataLoader +from tqdm.auto import tqdm +from transformers import BitImageProcessorFast +from utils import get_device + + +def establish_database( + processor: BitImageProcessorFast, + model: nn.Module, + dataset: Dataset, + batch_size: int = 32, +) -> pl.DataFrame: + df = pl.DataFrame() + + model.eval() + dataloader = DataLoader( + dataset.with_format("torch"), batch_size=batch_size, shuffle=True, num_workers=4 + ) + + with torch.no_grad(): + for batch in tqdm(dataloader, desc="Establish Database"): + imgs = batch["img"] + labels = batch["label"] + + inputs = processor(imgs, return_tensors="pt").to(get_device()) + + outputs = cast(Tensor, model(inputs)) + + return df + + +def task_eval( + processor: BitImageProcessorFast, + model: nn.Module, + dataset: Literal["CIFAR-10", "CIFAR-100"], + top_k: int = 10, + batch_size: int = 32, +): + match dataset: + case "CIFAR-10": + train_dataset = load_dataset("uoft-cs/cifar10", split="train") + test_dataset = load_dataset("uoft-cs/cifar10", split="test") + case "CIFAR-100": + train_dataset = load_dataset("uoft-cs/cifar100", split="train") + test_dataset = load_dataset("uoft-cs/cifar100", split="test") + case _: + raise ValueError( + f"Unknown dataset: {dataset}. Only support: 'CIFAR-10', 'CIFAR-100'." + ) + + # Establish database + df = establish_database(processor, model, train_dataset, batch_size) + + # Test + dataloader = DataLoader( + test_dataset.with_format("torch"), + batch_size=batch_size, + shuffle=True, + num_workers=4, + ) + + with torch.no_grad(): + for batch in tqdm(dataloader, desc="Test Evaluation"): + imgs = batch["img"] + labels = batch["label"] + + inputs = processor(imgs, return_tensors="pt").to(get_device()) + + outputs = cast(Tensor, model(inputs)) + for vec in outputs: + pass diff --git a/mini-nav/compressors/__init__.py b/mini-nav/compressors/__init__.py index 6f77d5e..3d9efd3 100644 --- a/mini-nav/compressors/__init__.py +++ b/mini-nav/compressors/__init__.py @@ -1,5 +1,6 @@ +from .dino_compressor import DinoCompressor from .float_compressor import FloatCompressor from .int_compressor import IntCompressor from .train import train -__all__ = ["train", "FloatCompressor", "IntCompressor"] +__all__ = ["train", "FloatCompressor", "IntCompressor", "DinoCompressor"] diff --git a/mini-nav/compressors/dino_compressor.py b/mini-nav/compressors/dino_compressor.py new file mode 100644 index 0000000..1f5018b --- /dev/null +++ b/mini-nav/compressors/dino_compressor.py @@ -0,0 +1,29 @@ +from typing import Optional, cast + +import torch.nn.functional as F +from torch import nn +from transformers import AutoModel, Dinov2Model + + +class DinoCompressor(nn.Module): + def __init__(self, compressor: Optional[nn.Module] = None): + super().__init__() + + self.dino = cast( + Dinov2Model, + AutoModel.from_pretrained("facebook/dinov2-large"), + ) + + self.compressor = compressor + + def forward(self, inputs): + teacher_tokens = self.dino(**inputs).last_hidden_state # [B,N,1024] + + teacher_embed = teacher_tokens.mean(dim=1) + teacher_embed = F.normalize(teacher_embed, dim=-1) # [B,1024] + + if self.compressor is None: + return teacher_embed + + feats, recon = self.compressor(teacher_tokens) + return feats diff --git a/mini-nav/compressors/train.py b/mini-nav/compressors/train.py index 3aa9b50..2fa52aa 100644 --- a/mini-nav/compressors/train.py +++ b/mini-nav/compressors/train.py @@ -93,12 +93,10 @@ def train( with torch.no_grad(): inputs = processor(imgs, return_tensors="pt").to(device) - teacher_tokens = dino(**inputs).last_hidden_state - # [B,N,1024] + teacher_tokens = dino(**inputs).last_hidden_state # [B,N,1024] teacher_embed = teacher_tokens.mean(dim=1) - teacher_embed = F.normalize(teacher_embed, dim=-1) - # [B,1024] + teacher_embed = F.normalize(teacher_embed, dim=-1) # [B,1024] # ---- student forward ---- z512, recon = compressor(teacher_tokens) diff --git a/mini-nav/feature_retrieval.py b/mini-nav/feature_retrieval.py index 0181fd5..e256b6c 100644 --- a/mini-nav/feature_retrieval.py +++ b/mini-nav/feature_retrieval.py @@ -6,8 +6,14 @@ from database import db_manager from datasets import load_dataset from PIL import Image from PIL.PngImagePlugin import PngImageFile +from torch import nn from tqdm.auto import tqdm -from transformers import AutoImageProcessor, AutoModel +from transformers import ( + AutoImageProcessor, + AutoModel, + BitImageProcessorFast, + Dinov2Model, +) def pil_image_to_bytes(image: Image.Image, format: str = "PNG") -> bytes: @@ -31,8 +37,8 @@ class FeatureRetrieval: _instance: Optional["FeatureRetrieval"] = None _initialized: bool = False - processor: Any - model: Any + processor: BitImageProcessorFast + model: nn.Module def __new__(cls, *args, **kwargs) -> "FeatureRetrieval": if cls._instance is None: @@ -40,7 +46,9 @@ class FeatureRetrieval: return cls._instance def __init__( - self, processor: Optional[Any] = None, model: Optional[Any] = None + self, + processor: Optional[BitImageProcessorFast] = None, + model: Optional[nn.Module] = None, ) -> None: """Initialize the singleton with processor and model. @@ -84,10 +92,10 @@ class FeatureRetrieval: for i in tqdm(range(0, len(images), batch_size)): batch_imgs = images[i : i + batch_size] - inputs = self.processor(images=batch_imgs, return_tensors="pt") + inputs = self.processor(batch_imgs, return_tensors="pt") # 迁移数据到GPU - inputs.to(device, non_blocking=True) + inputs.to(device) outputs = self.model(**inputs) @@ -166,10 +174,14 @@ if __name__ == "__main__": "truck", ] - processor = AutoImageProcessor.from_pretrained( - "facebook/dinov2-large", device_map="cuda" + processor = cast( + BitImageProcessorFast, + AutoImageProcessor.from_pretrained("facebook/dinov2-large", device_map="cuda"), + ) + model = cast( + Dinov2Model, + AutoModel.from_pretrained("facebook/dinov2-large", device_map="cuda"), ) - model = AutoModel.from_pretrained("facebook/dinov2-large", device_map="cuda") feature_retrieval = FeatureRetrieval(processor, model) diff --git a/mini-nav/main.py b/mini-nav/main.py index bd499f8..8dfccfc 100644 --- a/mini-nav/main.py +++ b/mini-nav/main.py @@ -2,14 +2,22 @@ import argparse if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("train") + parser.add_argument( + "action", + choices=["train", "benchmark", "visualize"], + help="Action to perform: train, benchmark, or visualize", + ) args = parser.parse_args() - if args.train: + if args.action == "train": from compressors import FloatCompressor, train train(FloatCompressor(), 1, 32) - else: + elif args.action == "benchmark": + from benchmarks import evaluate + + evaluate("Dinov2", "CIFAR-10", "Recall@10") + else: # visualize from visualizer import app app.run(debug=True) diff --git a/mini-nav/utils/__init__.py b/mini-nav/utils/__init__.py new file mode 100644 index 0000000..bf23b30 --- /dev/null +++ b/mini-nav/utils/__init__.py @@ -0,0 +1,3 @@ +from .common import get_device, get_output_diretory + +__all__ = ["get_device", "get_output_diretory"] diff --git a/mini-nav/utils/common.py b/mini-nav/utils/common.py new file mode 100644 index 0000000..0041a50 --- /dev/null +++ b/mini-nav/utils/common.py @@ -0,0 +1,23 @@ +from functools import lru_cache +from pathlib import Path + +import torch +from configs import cfg_manager +from torch.types import Device + + +@lru_cache(maxsize=1) +def get_device() -> Device: + config = cfg_manager.get() + device = config.model.device + if device == "auto": + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + device = torch.device(device) + return device + + +@lru_cache(maxsize=1) +def get_output_diretory() -> Path: + config = cfg_manager.get() + return config.output.directory diff --git a/pyproject.toml b/pyproject.toml index ece73b7..7a9929b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,6 @@ dependencies = [ "polars[database,numpy,pandas,pydantic]>=1.37.1", "pyarrow>=23.0.0", "pydantic>=2.12.5", - "pytest-benchmark>=5.2.3", "scikit-learn>=1.7.2", "torch>=2.10.0", "torchvision>=0.25.0", diff --git a/uv.lock b/uv.lock index 753d2df..820ed27 100644 --- a/uv.lock +++ b/uv.lock @@ -1114,7 +1114,6 @@ dependencies = [ { name = "polars", extra = ["database", "numpy", "pandas", "pydantic"] }, { name = "pyarrow" }, { name = "pydantic" }, - { name = "pytest-benchmark" }, { name = "scikit-learn", version = "1.7.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "scikit-learn", version = "1.8.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "torch", version = "2.10.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, @@ -1142,7 +1141,6 @@ requires-dist = [ { name = "polars", extras = ["database", "numpy", "pandas", "pydantic"], specifier = ">=1.37.1" }, { name = "pyarrow", specifier = ">=23.0.0" }, { name = "pydantic", specifier = ">=2.12.5" }, - { name = "pytest-benchmark", specifier = ">=5.2.3" }, { name = "scikit-learn", specifier = ">=1.7.2" }, { name = "torch", marker = "sys_platform != 'linux' and sys_platform != 'win32'", specifier = ">=2.10.0" }, { name = "torch", marker = "sys_platform == 'linux' or sys_platform == 'win32'", specifier = ">=2.10.0", index = "https://download.pytorch.org/whl/cu130" }, @@ -2141,15 +2139,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8c/c7/7bb2e321574b10df20cbde462a94e2b71d05f9bbda251ef27d104668306a/psutil-7.2.2-cp37-abi3-win_arm64.whl", hash = "sha256:8c233660f575a5a89e6d4cb65d9f938126312bca76d8fe087b947b3a1aaac9ee", size = 134617, upload-time = "2026-01-28T18:15:36.514Z" }, ] -[[package]] -name = "py-cpuinfo" -version = "9.0.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/37/a8/d832f7293ebb21690860d2e01d8115e5ff6f2ae8bbdc953f0eb0fa4bd2c7/py-cpuinfo-9.0.0.tar.gz", hash = "sha256:3cdbbf3fac90dc6f118bfd64384f309edeadd902d7c8fb17f02ffa1fc3f49690", size = 104716, upload-time = "2022-10-25T20:38:06.303Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e0/a9/023730ba63db1e494a271cb018dcd361bd2c917ba7004c3e49d5daf795a2/py_cpuinfo-9.0.0-py3-none-any.whl", hash = "sha256:859625bc251f64e21f077d099d4162689c762b5d6a4c3c97553d56241c9674d5", size = 22335, upload-time = "2022-10-25T20:38:27.636Z" }, -] - [[package]] name = "pyarrow" version = "23.0.0" @@ -2367,19 +2356,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" }, ] -[[package]] -name = "pytest-benchmark" -version = "5.2.3" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "py-cpuinfo" }, - { name = "pytest" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/24/34/9f732b76456d64faffbef6232f1f9dbec7a7c4999ff46282fa418bd1af66/pytest_benchmark-5.2.3.tar.gz", hash = "sha256:deb7317998a23c650fd4ff76e1230066a76cb45dcece0aca5607143c619e7779", size = 341340, upload-time = "2025-11-09T18:48:43.215Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/33/29/e756e715a48959f1c0045342088d7ca9762a2f509b945f362a316e9412b7/pytest_benchmark-5.2.3-py3-none-any.whl", hash = "sha256:bc839726ad20e99aaa0d11a127445457b4219bdb9e80a1afc4b51da7f96b0803", size = 45255, upload-time = "2025-11-09T18:48:39.765Z" }, -] - [[package]] name = "pytest-mpi" version = "0.6"