mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
feat(benchmarks): add evaluation framework for DINO-based compressors
This commit is contained in:
@@ -6,8 +6,14 @@ from database import db_manager
|
||||
from datasets import load_dataset
|
||||
from PIL import Image
|
||||
from PIL.PngImagePlugin import PngImageFile
|
||||
from torch import nn
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import AutoImageProcessor, AutoModel
|
||||
from transformers import (
|
||||
AutoImageProcessor,
|
||||
AutoModel,
|
||||
BitImageProcessorFast,
|
||||
Dinov2Model,
|
||||
)
|
||||
|
||||
|
||||
def pil_image_to_bytes(image: Image.Image, format: str = "PNG") -> bytes:
|
||||
@@ -31,8 +37,8 @@ class FeatureRetrieval:
|
||||
_instance: Optional["FeatureRetrieval"] = None
|
||||
|
||||
_initialized: bool = False
|
||||
processor: Any
|
||||
model: Any
|
||||
processor: BitImageProcessorFast
|
||||
model: nn.Module
|
||||
|
||||
def __new__(cls, *args, **kwargs) -> "FeatureRetrieval":
|
||||
if cls._instance is None:
|
||||
@@ -40,7 +46,9 @@ class FeatureRetrieval:
|
||||
return cls._instance
|
||||
|
||||
def __init__(
|
||||
self, processor: Optional[Any] = None, model: Optional[Any] = None
|
||||
self,
|
||||
processor: Optional[BitImageProcessorFast] = None,
|
||||
model: Optional[nn.Module] = None,
|
||||
) -> None:
|
||||
"""Initialize the singleton with processor and model.
|
||||
|
||||
@@ -84,10 +92,10 @@ class FeatureRetrieval:
|
||||
for i in tqdm(range(0, len(images), batch_size)):
|
||||
batch_imgs = images[i : i + batch_size]
|
||||
|
||||
inputs = self.processor(images=batch_imgs, return_tensors="pt")
|
||||
inputs = self.processor(batch_imgs, return_tensors="pt")
|
||||
|
||||
# 迁移数据到GPU
|
||||
inputs.to(device, non_blocking=True)
|
||||
inputs.to(device)
|
||||
|
||||
outputs = self.model(**inputs)
|
||||
|
||||
@@ -166,10 +174,14 @@ if __name__ == "__main__":
|
||||
"truck",
|
||||
]
|
||||
|
||||
processor = AutoImageProcessor.from_pretrained(
|
||||
"facebook/dinov2-large", device_map="cuda"
|
||||
processor = cast(
|
||||
BitImageProcessorFast,
|
||||
AutoImageProcessor.from_pretrained("facebook/dinov2-large", device_map="cuda"),
|
||||
)
|
||||
model = cast(
|
||||
Dinov2Model,
|
||||
AutoModel.from_pretrained("facebook/dinov2-large", device_map="cuda"),
|
||||
)
|
||||
model = AutoModel.from_pretrained("facebook/dinov2-large", device_map="cuda")
|
||||
|
||||
feature_retrieval = FeatureRetrieval(processor, model)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user