Files
Mini-Nav/mini-nav/feature_retrieval.py

143 lines
4.0 KiB
Python

import io
from typing import Dict, List, Optional, cast
import torch
from database import db_manager
from PIL import Image
from PIL.PngImagePlugin import PngImageFile
from torch import nn
from tqdm.auto import tqdm
from transformers import (
AutoImageProcessor,
AutoModel,
BitImageProcessorFast,
Dinov2Model,
)
from utils.feature_extractor import extract_batch_features
from datasets import load_dataset
def pil_image_to_bytes(image: Image.Image, format: str = "PNG") -> bytes:
    """Serialize a PIL image into an in-memory byte string.

    Args:
        image: The PIL image to encode.
        format: Name of the target encoding handed to ``image.save``
            (for example ``"PNG"`` or ``"JPEG"``).

    Returns:
        The encoded image as raw bytes.
    """
    # Encode into memory rather than onto disk, then copy the bytes out
    # before the stream is closed.
    with io.BytesIO() as stream:
        image.save(stream, format=format)
        encoded = stream.getvalue()
    return encoded
class FeatureRetrieval:
    """Singleton manager that extracts image features and stores them.

    The first successful construction must supply both ``processor`` and
    ``model``. Every later ``FeatureRetrieval(...)`` call returns that same
    object and leaves its processor and model untouched, whatever arguments
    are passed.
    """

    # The single shared instance; created lazily by ``__new__``.
    _instance: Optional["FeatureRetrieval"] = None
    # Set to True on the instance once ``__init__`` has completed.
    _initialized: bool = False

    processor: BitImageProcessorFast
    model: nn.Module

    def __new__(cls, *args, **kwargs) -> "FeatureRetrieval":
        """Return the shared instance, allocating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(
        self,
        processor: Optional[BitImageProcessorFast] = None,
        model: Optional[nn.Module] = None,
    ) -> None:
        """Initialize the singleton with processor and model.

        Args:
            processor: Image processor for preprocessing images.
            model: Model for feature extraction.

        Raises:
            ValueError: If the singleton has not been initialized yet and
                either ``processor`` or ``model`` is missing.
        """
        # Already initialized: keep the existing processor/model as they are.
        if self._initialized:
            return

        # The first initialization must provide both processor and model.
        if processor is None or model is None:
            raise ValueError(
                "Processor and model must be provided on first initialization."
            )

        self.processor = processor
        self.model = model
        self._initialized = True

    @torch.no_grad()
    def establish_database(
        self,
        images: List[PngImageFile],
        labels: List[int] | List[str],
        batch_size: int = 64,
        label_map: Optional[Dict[int, str] | List[str]] = None,
    ) -> None:
        """Extract features from images and store them in the database.

        One row is written per image, holding its positional index as ``id``,
        its (optionally mapped) label, its feature vector and the image
        encoded by ``pil_image_to_bytes``. Runs with gradient tracking
        disabled.

        Args:
            images: Images to process.
            labels: Labels corresponding to ``images``, in the same order.
            batch_size: Number of images handed to the feature extractor per
                batch.
            label_map: Optional lookup applied to every label before it is
                stored; when omitted the raw label is stored.

        Raises:
            ValueError: If ``images`` and ``labels`` differ in length.
        """
        total = len(labels)

        # Fail before any expensive work or partial database write: a
        # mismatch would otherwise either crash midway through storing or
        # silently drop the surplus images.
        if len(images) != total:
            raise ValueError(
                "images and labels must have the same length "
                f"(got {len(images)} images and {total} labels)."
            )

        # Nothing to extract or store.
        if total == 0:
            return

        # Extract features using the utility function. The result is indexed
        # per image below (presumably one CLS-token embedding each -- confirm
        # against extract_batch_features).
        cls_tokens = extract_batch_features(
            self.processor, self.model, images, batch_size=batch_size
        )

        # NOTE(review): rows are written with one add() call each; if the
        # backing table supports multi-row adds, batching them is worth
        # considering for large datasets.
        indexed_pairs = enumerate(zip(images, labels))
        for i, (image, label) in tqdm(
            indexed_pairs, total=total, desc="Storing to database"
        ):
            label_name = label if label_map is None else label_map[label]
            db_manager.table.add(
                [
                    {
                        "id": i,
                        "label": label_name,
                        # .numpy() refuses tensors that live on an
                        # accelerator; .cpu() is a no-op for CPU tensors.
                        "vector": cls_tokens[i].cpu().numpy(),
                        "binary": pil_image_to_bytes(image),
                    }
                ]
            )
if __name__ == "__main__":
    # Script entry point: run the CIFAR-10 training split through a DINOv2
    # checkpoint and store the resulting features in the database.
    checkpoint = "facebook/dinov2-large"

    cifar_train = load_dataset("uoft-cs/cifar10", split="train")

    # Class names; establish_database looks each label up in this list, so a
    # name's position must equal its integer label.
    class_names = [
        "airplane", "automobile", "bird", "cat", "deer",
        "dog", "frog", "horse", "ship", "truck",
    ]

    # cast() only narrows the static type of what the Auto* factories return;
    # it performs no conversion at runtime.
    image_processor = cast(
        BitImageProcessorFast,
        AutoImageProcessor.from_pretrained(checkpoint, device_map="cuda"),
    )
    backbone = cast(
        Dinov2Model,
        AutoModel.from_pretrained(checkpoint, device_map="cuda"),
    )

    retriever = FeatureRetrieval(image_processor, backbone)
    retriever.establish_database(
        cifar_train["img"],
        cifar_train["label"],
        label_map=class_names,
    )