feat(benchmarks): add evaluation framework for DINO-based compressors

This commit is contained in:
2026-02-08 22:43:38 +08:00
parent 3ba3705ba6
commit 7f6732edeb
11 changed files with 217 additions and 42 deletions

View File

@@ -6,8 +6,14 @@ from database import db_manager
from datasets import load_dataset
from PIL import Image
from PIL.PngImagePlugin import PngImageFile
from torch import nn
from tqdm.auto import tqdm
from transformers import AutoImageProcessor, AutoModel
from transformers import (
AutoImageProcessor,
AutoModel,
BitImageProcessorFast,
Dinov2Model,
)
def pil_image_to_bytes(image: Image.Image, format: str = "PNG") -> bytes:
@@ -31,8 +37,8 @@ class FeatureRetrieval:
_instance: Optional["FeatureRetrieval"] = None
_initialized: bool = False
processor: Any
model: Any
processor: BitImageProcessorFast
model: nn.Module
def __new__(cls, *args, **kwargs) -> "FeatureRetrieval":
if cls._instance is None:
@@ -40,7 +46,9 @@ class FeatureRetrieval:
return cls._instance
def __init__(
self, processor: Optional[Any] = None, model: Optional[Any] = None
self,
processor: Optional[BitImageProcessorFast] = None,
model: Optional[nn.Module] = None,
) -> None:
"""Initialize the singleton with processor and model.
@@ -84,10 +92,10 @@ class FeatureRetrieval:
for i in tqdm(range(0, len(images), batch_size)):
batch_imgs = images[i : i + batch_size]
inputs = self.processor(images=batch_imgs, return_tensors="pt")
inputs = self.processor(batch_imgs, return_tensors="pt")
# Move the data to the GPU
inputs.to(device, non_blocking=True)
inputs.to(device)
outputs = self.model(**inputs)
@@ -166,10 +174,14 @@ if __name__ == "__main__":
"truck",
]
processor = AutoImageProcessor.from_pretrained(
"facebook/dinov2-large", device_map="cuda"
processor = cast(
BitImageProcessorFast,
AutoImageProcessor.from_pretrained("facebook/dinov2-large", device_map="cuda"),
)
model = cast(
Dinov2Model,
AutoModel.from_pretrained("facebook/dinov2-large", device_map="cuda"),
)
model = AutoModel.from_pretrained("facebook/dinov2-large", device_map="cuda")
feature_retrieval = FeatureRetrieval(processor, model)