diff --git a/mini-nav/__init__.py b/mini-nav/__init__.py
index 2b77514..0eab1fe 100644
--- a/mini-nav/__init__.py
+++ b/mini-nav/__init__.py
@@ -1,3 +1,3 @@
-from database import DatabaseManager, db_manager, db_schema
+from database import DatabaseManager, db_manager
 
-__all__ = ["DatabaseManager", "db_manager", "db_schema"]
+__all__ = ["DatabaseManager", "db_manager"]
diff --git a/mini-nav/benchmarks/base.py b/mini-nav/benchmarks/base.py
index e013393..e1b0d82 100644
--- a/mini-nav/benchmarks/base.py
+++ b/mini-nav/benchmarks/base.py
@@ -4,7 +4,6 @@ from abc import ABC, abstractmethod
 from typing import Any, Protocol
 
 import lancedb
-from torch.utils.data import DataLoader
 
 
 class BaseDataset(ABC):
diff --git a/mini-nav/benchmarks/runner.py b/mini-nav/benchmarks/runner.py
index 4f11cb8..4f8cdab 100644
--- a/mini-nav/benchmarks/runner.py
+++ b/mini-nav/benchmarks/runner.py
@@ -149,9 +149,9 @@ def run_benchmark(
     sample = train_dataset[0]
     sample_image = sample["img"]
 
-    from .tasks.retrieval import _infer_vector_dim
+    from utils.feature_extractor import infer_vector_dim
 
-    vector_dim = _infer_vector_dim(processor, model, sample_image)
+    vector_dim = infer_vector_dim(processor, model, sample_image)
     print(f"Model output dimension: {vector_dim}")
 
     # Ensure table exists with correct schema
diff --git a/mini-nav/benchmarks/tasks/retrieval.py b/mini-nav/benchmarks/tasks/retrieval.py
index 2878c51..c230dbe 100644
--- a/mini-nav/benchmarks/tasks/retrieval.py
+++ b/mini-nav/benchmarks/tasks/retrieval.py
@@ -1,10 +1,9 @@
 """Retrieval task for benchmark evaluation (Recall@K)."""
 
-from typing import Any, cast
+from typing import Any
 
 import lancedb
 import pyarrow as pa
-import torch
 from benchmarks.base import BaseBenchmarkTask
 from benchmarks.tasks.registry import RegisterTask
 from torch import nn
@@ -12,31 +11,7 @@ from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
 from transformers import BitImageProcessorFast
 
-
-def _infer_vector_dim(
-    processor: BitImageProcessorFast,
-    model: nn.Module,
-    sample_image: Any,
-) -> int:
-    """Infer model output vector dimension via a single forward pass.
-
-    Args:
-        processor: Image preprocessor.
-        model: Feature extraction model.
-        sample_image: A sample image for dimension inference.
-
-    Returns:
-        Vector dimension.
-    """
-    device = next(model.parameters()).device
-    model.eval()
-
-    with torch.no_grad():
-        inputs = processor(images=sample_image, return_tensors="pt")
-        inputs.to(device)
-        output = model(inputs)
-
-    return output.shape[-1]
+from utils.feature_extractor import extract_batch_features, infer_vector_dim
 
 
 def _build_eval_schema(vector_dim: int) -> pa.Schema:
@@ -57,7 +32,6 @@ def _build_eval_schema(vector_dim: int) -> pa.Schema:
     )
 
 
-@torch.no_grad()
 def _establish_eval_database(
     processor: BitImageProcessorFast,
     model: nn.Module,
@@ -72,28 +46,22 @@
         table: LanceDB table to store features.
         dataloader: DataLoader for the training dataset.
     """
-    device = next(model.parameters()).device
-    model.eval()
+    # Extract all features using the utility function
+    all_features = extract_batch_features(processor, model, dataloader, show_progress=True)
 
+    # Store features to database
     global_idx = 0
-    for batch in tqdm(dataloader, desc="Building eval database"):
-        imgs = batch["img"]
+    for batch in tqdm(dataloader, desc="Storing eval database"):
         labels = batch["label"]
-
-        inputs = processor(imgs, return_tensors="pt")
-        inputs.to(device)
-        outputs = model(inputs)
-
-        features = cast(torch.Tensor, outputs).cpu()
         labels_list = labels.tolist()
-
         batch_size = len(labels_list)
+
         table.add(
             [
                 {
                     "id": global_idx + j,
                     "label": labels_list[j],
-                    "vector": features[j].numpy(),
+                    "vector": all_features[global_idx + j].numpy(),
                 }
                 for j in range(batch_size)
             ]
@@ -101,7 +69,6 @@ _establish_eval_database(
         global_idx += batch_size
 
 
-@torch.no_grad()
 def _evaluate_recall(
     processor: BitImageProcessorFast,
     model: nn.Module,
@@ -121,25 +88,19 @@
     Returns:
         A tuple of (correct_count, total_count).
     """
-    device = next(model.parameters()).device
-    model.eval()
+    # Extract all features using the utility function
+    all_features = extract_batch_features(processor, model, dataloader, show_progress=True)
 
     correct = 0
     total = 0
+    feature_idx = 0
 
     for batch in tqdm(dataloader, desc=f"Evaluating Recall@{top_k}"):
-        imgs = batch["img"]
         labels = batch["label"]
-
-        inputs = processor(imgs, return_tensors="pt")
-        inputs.to(device)
-        outputs = model(inputs)
-
-        features = cast(torch.Tensor, outputs).cpu()
         labels_list = labels.tolist()
 
         for j in range(len(labels_list)):
-            feature = features[j].tolist()
+            feature = all_features[feature_idx + j].tolist()
            true_label = labels_list[j]
 
             results = (
@@ -154,6 +115,8 @@
                 correct += 1
             total += 1
+
+        feature_idx += len(labels_list)
 
     return correct, total
 
 
@@ -191,7 +154,7 @@ class RetrievalTask(BaseBenchmarkTask):
         sample = train_dataset[0]
         sample_image = sample["img"]
 
-        vector_dim = _infer_vector_dim(processor, model, sample_image)
+        vector_dim = infer_vector_dim(processor, model, sample_image)
         expected_schema = _build_eval_schema(vector_dim)
 
         # Check schema compatibility
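The refactored task above delegates both schema sizing and bulk extraction to the shared helpers. A minimal usage sketch of that flow, with an assumed DINOv2 checkpoint and a toy in-memory dataset (the collate_fn below only mimics the project's batches, which carry "img" and "label" keys):

    import torch
    from PIL import Image
    from torch.utils.data import DataLoader
    from transformers import AutoImageProcessor, AutoModel

    from utils.feature_extractor import extract_batch_features, infer_vector_dim

    processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
    model = AutoModel.from_pretrained("facebook/dinov2-base").eval()

    # Toy stand-in for the benchmark's train split.
    samples = [{"img": Image.new("RGB", (224, 224)), "label": i % 2} for i in range(8)]
    loader = DataLoader(
        samples,
        batch_size=4,
        # Keep PIL images as a plain list; collect labels into a tensor.
        collate_fn=lambda batch: {
            "img": [s["img"] for s in batch],
            "label": torch.tensor([s["label"] for s in batch]),
        },
    )

    # One forward pass to size the eval schema, then one pass over the loader.
    vector_dim = infer_vector_dim(processor, model, samples[0]["img"])
    all_features = extract_batch_features(processor, model, loader, show_progress=True)
    assert all_features.shape == (len(samples), vector_dim)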
""" - device = next(model.parameters()).device - model.eval() + # Extract all features using the utility function + all_features = extract_batch_features(processor, model, dataloader, show_progress=True) + # Store features to database global_idx = 0 - for batch in tqdm(dataloader, desc="Building eval database"): - imgs = batch["img"] + for batch in tqdm(dataloader, desc="Storing eval database"): labels = batch["label"] - - inputs = processor(imgs, return_tensors="pt") - inputs.to(device) - outputs = model(inputs) - - features = cast(torch.Tensor, outputs).cpu() labels_list = labels.tolist() - batch_size = len(labels_list) + table.add( [ { "id": global_idx + j, "label": labels_list[j], - "vector": features[j].numpy(), + "vector": all_features[global_idx + j].numpy(), } for j in range(batch_size) ] @@ -101,7 +69,6 @@ def _establish_eval_database( global_idx += batch_size -@torch.no_grad() def _evaluate_recall( processor: BitImageProcessorFast, model: nn.Module, @@ -121,25 +88,19 @@ def _evaluate_recall( Returns: A tuple of (correct_count, total_count). """ - device = next(model.parameters()).device - model.eval() + # Extract all features using the utility function + all_features = extract_batch_features(processor, model, dataloader, show_progress=True) correct = 0 total = 0 + feature_idx = 0 for batch in tqdm(dataloader, desc=f"Evaluating Recall@{top_k}"): - imgs = batch["img"] labels = batch["label"] - - inputs = processor(imgs, return_tensors="pt") - inputs.to(device) - outputs = model(inputs) - - features = cast(torch.Tensor, outputs).cpu() labels_list = labels.tolist() for j in range(len(labels_list)): - feature = features[j].tolist() + feature = all_features[feature_idx + j].tolist() true_label = labels_list[j] results = ( @@ -154,6 +115,8 @@ def _evaluate_recall( correct += 1 total += 1 + feature_idx += len(labels_list) + return correct, total @@ -191,7 +154,7 @@ class RetrievalTask(BaseBenchmarkTask): sample = train_dataset[0] sample_image = sample["img"] - vector_dim = _infer_vector_dim(processor, model, sample_image) + vector_dim = infer_vector_dim(processor, model, sample_image) expected_schema = _build_eval_schema(vector_dim) # Check schema compatibility diff --git a/mini-nav/configs/config.yaml b/mini-nav/configs/config.yaml index 122bfec..70d7700 100644 --- a/mini-nav/configs/config.yaml +++ b/mini-nav/configs/config.yaml @@ -30,6 +30,6 @@ benchmark: task: name: "recall_at_k" type: "retrieval" - top_k: 10 + top_k: 1 batch_size: 64 model_table_prefix: "benchmark" diff --git a/mini-nav/configs/models.py b/mini-nav/configs/models.py index 52bf795..679d281 100644 --- a/mini-nav/configs/models.py +++ b/mini-nav/configs/models.py @@ -11,7 +11,7 @@ class ModelConfig(BaseModel): model_config = ConfigDict(extra="ignore") - name: str = "facebook/dinov2-large" + dino_model: str = "facebook/dinov2-large" compression_dim: int = Field( default=512, gt=0, description="Output feature dimension" ) diff --git a/mini-nav/database.py b/mini-nav/database.py index acf1435..da86521 100644 --- a/mini-nav/database.py +++ b/mini-nav/database.py @@ -4,14 +4,16 @@ import lancedb import pyarrow as pa from configs import cfg_manager -db_schema = pa.schema( - [ - pa.field("id", pa.int32()), - pa.field("label", pa.string()), - pa.field("vector", pa.list_(pa.float32(), 1024)), - pa.field("binary", pa.binary()), - ] -) + +def _build_database_schema(): + return pa.schema( + [ + pa.field("id", pa.int32()), + pa.field("label", pa.string()), + pa.field("vector", pa.list_(pa.float32(), 1024)), + pa.field("binary", 
diff --git a/mini-nav/database.py b/mini-nav/database.py
index acf1435..da86521 100644
--- a/mini-nav/database.py
+++ b/mini-nav/database.py
@@ -4,14 +4,16 @@ import lancedb
 import pyarrow as pa
 from configs import cfg_manager
 
-db_schema = pa.schema(
-    [
-        pa.field("id", pa.int32()),
-        pa.field("label", pa.string()),
-        pa.field("vector", pa.list_(pa.float32(), 1024)),
-        pa.field("binary", pa.binary()),
-    ]
-)
+
+def _build_database_schema():
+    return pa.schema(
+        [
+            pa.field("id", pa.int32()),
+            pa.field("label", pa.string()),
+            pa.field("vector", pa.list_(pa.float32(), 1024)),
+            pa.field("binary", pa.binary()),
+        ]
+    )
 
 
 class DatabaseManager:
@@ -34,7 +36,9 @@ class DatabaseManager:
 
         # 初始化数据库与表格
         self.db = lancedb.connect(db_path)
         if "default" not in self.db.list_tables().tables:
-            self.table = self.db.create_table("default", schema=db_schema)
+            self.table = self.db.create_table(
+                "default", schema=_build_database_schema()
+            )
         else:
             self.table = self.db.open_table("default")
diff --git a/mini-nav/feature_retrieval.py b/mini-nav/feature_retrieval.py
index e256b6c..94bef54 100644
--- a/mini-nav/feature_retrieval.py
+++ b/mini-nav/feature_retrieval.py
@@ -1,9 +1,8 @@
 import io
-from typing import Any, Dict, List, Optional, Union, cast
+from typing import Dict, List, Optional, cast
 
 import torch
 from database import db_manager
-from datasets import load_dataset
 from PIL import Image
 from PIL.PngImagePlugin import PngImageFile
 from torch import nn
@@ -14,6 +13,9 @@ from transformers import (
     BitImageProcessorFast,
     Dinov2Model,
 )
+from utils.feature_extractor import extract_batch_features
+
+from datasets import load_dataset
 
 
 def pil_image_to_bytes(image: Image.Image, format: str = "PNG") -> bytes:
@@ -86,78 +88,26 @@ class FeatureRetrieval:
             batch_size: Number of images to process in a batch.
             label_map: Optional mapping from label indices to string names.
         """
-        device = self.model.device
-        self.model.eval()
+        # Extract features using the utility function
+        cls_tokens = extract_batch_features(
+            self.processor, self.model, images, batch_size=batch_size
+        )
 
-        for i in tqdm(range(0, len(images), batch_size)):
-            batch_imgs = images[i : i + batch_size]
+        for i in tqdm(range(len(labels)), desc="Storing to database"):
+            batch_label = labels[i] if label_map is None else label_map[labels[i]]
 
-            inputs = self.processor(batch_imgs, return_tensors="pt")
-
-            # 迁移数据到GPU
-            inputs.to(device)
-
-            outputs = self.model(**inputs)
-
-            # 后处理
-            feats = outputs.last_hidden_state  # [B, N, D]
-            cls_tokens = feats[:, 0]  # Get CLS token (first token) for all batch items
-            cls_tokens = cast(torch.Tensor, cls_tokens)
-
-            # 迁移输出到CPU
-            cls_tokens = cls_tokens.cpu()
-            batch_labels = (
-                labels[i : i + batch_size]
-                if label_map is None
-                else list(
-                    map(lambda x: label_map[cast(int, x)], labels[i : i + batch_size])
-                )
-            )
-            actual_batch_size = len(batch_labels)
-
-            # 存库
+            # Store to database
             db_manager.table.add(
                 [
                     {
-                        "id": i + j,
-                        "label": batch_labels[j],
-                        "vector": cls_tokens[j].numpy(),
-                        "binary": pil_image_to_bytes(batch_imgs[j]),
+                        "id": i,
+                        "label": batch_label,
+                        "vector": cls_tokens[i].numpy(),
+                        "binary": pil_image_to_bytes(images[i]),
                     }
-                    for j in range(actual_batch_size)
                 ]
             )
 
-    @torch.no_grad()
-    def extract_single_image_feature(
-        self, image: Union[Image.Image, Any]
-    ) -> List[float]:
-        """Extract feature from a single image without storing to database.
-
-        Args:
-            image: A single image (PIL Image or other supported format).
-
-        Returns:
-            pl.Series: The extracted CLS token feature vector as a Polars Series.
-        """
-        device = self.model.device
-        self.model.eval()
-
-        # 预处理图片
-        inputs = self.processor(images=image, return_tensors="pt")
-        inputs.to(device, non_blocking=True)
-
-        # 提取特征
-        outputs = self.model(**inputs)
-
-        # 获取 CLS token
-        feats = outputs.last_hidden_state  # [1, N, D]
-        cls_token = feats[:, 0]  # [1, D]
-        cls_token = cast(torch.Tensor, cls_token)
-
-        # 返回 CLS List
-        return cls_token.cpu().squeeze(0).tolist()
-
 
 if __name__ == "__main__":
     train_dataset = load_dataset("uoft-cs/cifar10", split="train")
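For orientation, a sketch of the row shape the ingestion loop writes into the default table, and a top-1 lookup against it (the search(...).limit(...).to_list() chain is assumed from LanceDB's Python query API; the zero vector is only a placeholder):

    from PIL import Image

    from database import db_manager
    from feature_retrieval import pil_image_to_bytes

    img = Image.new("RGB", (32, 32))
    db_manager.table.add(
        [
            {
                "id": 0,
                "label": "cat",
                "vector": [0.0] * 1024,  # must match the 1024-dim schema built above
                "binary": pil_image_to_bytes(img),
            }
        ]
    )

    hits = db_manager.table.search([0.0] * 1024).limit(1).to_list()
    print(hits[0]["id"], hits[0]["label"])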
- """ - device = self.model.device - self.model.eval() - - # 预处理图片 - inputs = self.processor(images=image, return_tensors="pt") - inputs.to(device, non_blocking=True) - - # 提取特征 - outputs = self.model(**inputs) - - # 获取 CLS token - feats = outputs.last_hidden_state # [1, N, D] - cls_token = feats[:, 0] # [1, D] - cls_token = cast(torch.Tensor, cls_token) - - # 返回 CLS List - return cls_token.cpu().squeeze(0).tolist() - if __name__ == "__main__": train_dataset = load_dataset("uoft-cs/cifar10", split="train") diff --git a/mini-nav/main.py b/mini-nav/main.py index 30aa32e..d4cf882 100644 --- a/mini-nav/main.py +++ b/mini-nav/main.py @@ -38,7 +38,7 @@ if __name__ == "__main__": model_cfg = config.model processor = cast( BitImageProcessorFast, - AutoImageProcessor.from_pretrained(model_cfg.name, device_map=device), + AutoImageProcessor.from_pretrained(model_cfg.dino_model, device_map=device), ) # Load compressor weights if specified in model config @@ -84,4 +84,6 @@ if __name__ == "__main__": ) generated_files = synthesizer.generate() - print(f"Generated {len(generated_files)} synthesized images in {dataset_cfg.output_dir}") + print( + f"Generated {len(generated_files)} synthesized images in {dataset_cfg.output_dir}" + ) diff --git a/mini-nav/tests/test_config.py b/mini-nav/tests/test_config.py index 5b9e4be..0e4019a 100644 --- a/mini-nav/tests/test_config.py +++ b/mini-nav/tests/test_config.py @@ -25,7 +25,7 @@ class TestConfigModels: def test_model_config_defaults(self): """Verify ModelConfig creates with correct defaults.""" config = ModelConfig() - assert config.name == "facebook/dinov2-large" + assert config.dino_model == "facebook/dinov2-large" assert config.compression_dim == 512 assert config.device == "auto" @@ -73,7 +73,7 @@ class TestYamlLoader: config = load_yaml(config_path, Config) # Verify model config - assert config.model.name == "facebook/dinov2-large" + assert config.model.dino_model == "facebook/dinov2-large" assert config.model.compression_dim == 256 # Verify output config diff --git a/mini-nav/tests/test_feature_extractor.py b/mini-nav/tests/test_feature_extractor.py new file mode 100644 index 0000000..2a3274c --- /dev/null +++ b/mini-nav/tests/test_feature_extractor.py @@ -0,0 +1,50 @@ +"""Tests for feature extraction utilities.""" + +import pytest +import torch +from PIL import Image +from transformers import AutoImageProcessor, AutoModel + +from utils.feature_extractor import ( + extract_batch_features, + extract_single_image_feature, + infer_vector_dim, +) + +TEST_MODEL_NAME = "facebook/dinov2-base" + + +@pytest.fixture +def model_and_processor(): + processor = AutoImageProcessor.from_pretrained(TEST_MODEL_NAME) + model = AutoModel.from_pretrained(TEST_MODEL_NAME) + model.eval() + yield processor, model + del model + del processor + + +def test_infer_vector_dim(model_and_processor): + """Verify infer_vector_dim returns correct dimension.""" + processor, model = model_and_processor + sample_image = Image.new("RGB", (224, 224), color="blue") + dim = infer_vector_dim(processor, model, sample_image) + assert dim == 768 + + +def test_extract_single_image_feature(model_and_processor): + """Verify single image feature extraction.""" + processor, model = model_and_processor + sample_image = Image.new("RGB", (224, 224), color="red") + features = extract_single_image_feature(processor, model, sample_image) + assert isinstance(features, list) + assert len(features) == 768 + + +def test_extract_batch_features(model_and_processor): + """Verify batch feature extraction.""" + 
diff --git a/mini-nav/tests/test_feature_extractor.py b/mini-nav/tests/test_feature_extractor.py
new file mode 100644
index 0000000..2a3274c
--- /dev/null
+++ b/mini-nav/tests/test_feature_extractor.py
@@ -0,0 +1,50 @@
+"""Tests for feature extraction utilities."""
+
+import pytest
+import torch
+from PIL import Image
+from transformers import AutoImageProcessor, AutoModel
+
+from utils.feature_extractor import (
+    extract_batch_features,
+    extract_single_image_feature,
+    infer_vector_dim,
+)
+
+TEST_MODEL_NAME = "facebook/dinov2-base"
+
+
+@pytest.fixture
+def model_and_processor():
+    processor = AutoImageProcessor.from_pretrained(TEST_MODEL_NAME)
+    model = AutoModel.from_pretrained(TEST_MODEL_NAME)
+    model.eval()
+    yield processor, model
+    del model
+    del processor
+
+
+def test_infer_vector_dim(model_and_processor):
+    """Verify infer_vector_dim returns correct dimension."""
+    processor, model = model_and_processor
+    sample_image = Image.new("RGB", (224, 224), color="blue")
+    dim = infer_vector_dim(processor, model, sample_image)
+    assert dim == 768
+
+
+def test_extract_single_image_feature(model_and_processor):
+    """Verify single image feature extraction."""
+    processor, model = model_and_processor
+    sample_image = Image.new("RGB", (224, 224), color="red")
+    features = extract_single_image_feature(processor, model, sample_image)
+    assert isinstance(features, list)
+    assert len(features) == 768
+
+
+def test_extract_batch_features(model_and_processor):
+    """Verify batch feature extraction."""
+    processor, model = model_and_processor
+    images = [Image.new("RGB", (224, 224), color="red") for _ in range(3)]
+    features = extract_batch_features(processor, model, images)
+    assert isinstance(features, torch.Tensor)
+    assert features.shape == (3, 768)
diff --git a/mini-nav/utils/__init__.py b/mini-nav/utils/__init__.py
index bf23b30..1591fce 100644
--- a/mini-nav/utils/__init__.py
+++ b/mini-nav/utils/__init__.py
@@ -1,3 +1,14 @@
 from .common import get_device, get_output_diretory
+from .feature_extractor import (
+    extract_batch_features,
+    extract_single_image_feature,
+    infer_vector_dim,
+)
 
-__all__ = ["get_device", "get_output_diretory"]
+__all__ = [
+    "get_device",
+    "get_output_diretory",
+    "infer_vector_dim",
+    "extract_single_image_feature",
+    "extract_batch_features",
+]
diff --git a/mini-nav/utils/feature_extractor.py b/mini-nav/utils/feature_extractor.py
new file mode 100644
index 0000000..01acb29
--- /dev/null
+++ b/mini-nav/utils/feature_extractor.py
@@ -0,0 +1,130 @@
+"""Feature extraction utilities for image models."""
+
+from typing import Any, List, Union, cast
+
+import torch
+from PIL import Image
+from torch import nn
+from torch.utils.data import DataLoader
+from transformers import BitImageProcessorFast
+from tqdm.auto import tqdm
+
+
+def _extract_features_from_output(output: Any) -> torch.Tensor:
+    """Extract features from model output, handling both HuggingFace ModelOutput and raw tensors.
+
+    Args:
+        output: Model output (either ModelOutput with .last_hidden_state or raw tensor).
+
+    Returns:
+        Feature tensor of shape [B, D].
+    """
+    # Handle HuggingFace ModelOutput (has .last_hidden_state)
+    if hasattr(output, "last_hidden_state"):
+        return output.last_hidden_state[:, 0]  # [B, D] - CLS token
+    # Handle raw tensor output (like DinoCompressor)
+    return cast(torch.Tensor, output)
+
+
+def infer_vector_dim(
+    processor: BitImageProcessorFast,
+    model: nn.Module,
+    sample_image: Any,
+) -> int:
+    """Infer model output vector dimension via a single forward pass.
+
+    Args:
+        processor: Image preprocessor.
+        model: Feature extraction model.
+        sample_image: A sample image for dimension inference.
+
+    Returns:
+        Vector dimension.
+    """
+    device = next(model.parameters()).device
+    model.eval()
+
+    with torch.no_grad():
+        inputs = processor(images=sample_image, return_tensors="pt")
+        inputs.to(device)
+        output = model(inputs)
+
+    features = _extract_features_from_output(output)
+    return features.shape[-1]
+
+
+@torch.no_grad()
+def extract_single_image_feature(
+    processor: BitImageProcessorFast,
+    model: nn.Module,
+    image: Union[Image.Image, Any],
+) -> List[float]:
+    """Extract feature from a single image.
+
+    Args:
+        processor: Image preprocessor.
+        model: Feature extraction model.
+        image: A single image (PIL Image or other supported format).
+
+    Returns:
+        The extracted CLS token feature vector as a list of floats.
+    """
+    device = next(model.parameters()).device
+    model.eval()
+
+    inputs = processor(images=image, return_tensors="pt")
+    inputs.to(device, non_blocking=True)
+    outputs = model(inputs)
+
+    features = _extract_features_from_output(outputs)  # [1, D]
+    return features.cpu().squeeze(0).tolist()
+
+
+@torch.no_grad()
+def extract_batch_features(
+    processor: BitImageProcessorFast,
+    model: nn.Module,
+    images: Union[List[Any], Any],
+    batch_size: int = 32,
+    show_progress: bool = False,
+) -> torch.Tensor:
+    """Extract features from a batch of images.
+
+    Args:
+        processor: Image preprocessor.
+        model: Feature extraction model.
+        images: List of images, DataLoader, or other iterable.
+        batch_size: Batch size for processing.
+        show_progress: Whether to show progress bar.
+
+    Returns:
+        Tensor of shape [num_images, feature_dim].
+    """
+    device = next(model.parameters()).device
+    model.eval()
+
+    # Handle DataLoader input
+    if isinstance(images, DataLoader):
+        all_features = []
+        iterator = tqdm(images, desc="Extracting features") if show_progress else images
+        for batch in iterator:
+            imgs = batch["img"] if isinstance(batch, dict) else batch[0]
+            inputs = processor(images=imgs, return_tensors="pt")
+            inputs.to(device)
+            outputs = model(inputs)
+            features = _extract_features_from_output(outputs)  # [B, D]
+            all_features.append(features.cpu())
+        return torch.cat(all_features, dim=0)
+
+    # Handle list of images
+    all_features = []
+    iterator = tqdm(range(0, len(images), batch_size), desc="Extracting features") if show_progress else range(0, len(images), batch_size)
+    for i in iterator:
+        batch_imgs = images[i : i + batch_size]
+        inputs = processor(images=batch_imgs, return_tensors="pt")
+        inputs.to(device)
+        outputs = model(inputs)
+        features = _extract_features_from_output(outputs)  # [B, D]
+        all_features.append(features.cpu())
+
+    return torch.cat(all_features, dim=0)
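Taken together, the new module gives one entry point for single-image, list, and DataLoader inputs. A short usage sketch mirroring the new tests (same dinov2-base checkpoint as TEST_MODEL_NAME):

    from PIL import Image
    from transformers import AutoImageProcessor, AutoModel

    from utils.feature_extractor import (
        extract_batch_features,
        extract_single_image_feature,
        infer_vector_dim,
    )

    processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
    model = AutoModel.from_pretrained("facebook/dinov2-base").eval()

    image = Image.new("RGB", (224, 224), color="red")
    dim = infer_vector_dim(processor, model, image)               # 768 for dinov2-base
    vec = extract_single_image_feature(processor, model, image)  # list of 768 floats

    images = [Image.new("RGB", (224, 224)) for _ in range(3)]
    feats = extract_batch_features(processor, model, images)     # tensor of shape [3, dim]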