mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
feat(utils): add feature extraction utilities and tests
This commit is contained in:
@@ -1,3 +1,3 @@
|
||||
from database import DatabaseManager, db_manager, db_schema
|
||||
from database import DatabaseManager, db_manager
|
||||
|
||||
__all__ = ["DatabaseManager", "db_manager", "db_schema"]
|
||||
__all__ = ["DatabaseManager", "db_manager"]
|
||||
|
||||
@@ -4,7 +4,6 @@ from abc import ABC, abstractmethod
|
||||
from typing import Any, Protocol
|
||||
|
||||
import lancedb
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
class BaseDataset(ABC):
|
||||
|
||||
@@ -149,9 +149,9 @@ def run_benchmark(
|
||||
sample = train_dataset[0]
|
||||
sample_image = sample["img"]
|
||||
|
||||
from .tasks.retrieval import _infer_vector_dim
|
||||
from utils.feature_extractor import infer_vector_dim
|
||||
|
||||
vector_dim = _infer_vector_dim(processor, model, sample_image)
|
||||
vector_dim = infer_vector_dim(processor, model, sample_image)
|
||||
print(f"Model output dimension: {vector_dim}")
|
||||
|
||||
# Ensure table exists with correct schema
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
"""Retrieval task for benchmark evaluation (Recall@K)."""
|
||||
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
|
||||
import lancedb
|
||||
import pyarrow as pa
|
||||
import torch
|
||||
from benchmarks.base import BaseBenchmarkTask
|
||||
from benchmarks.tasks.registry import RegisterTask
|
||||
from torch import nn
|
||||
@@ -12,31 +11,7 @@ from torch.utils.data import DataLoader
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import BitImageProcessorFast
|
||||
|
||||
|
||||
def _infer_vector_dim(
|
||||
processor: BitImageProcessorFast,
|
||||
model: nn.Module,
|
||||
sample_image: Any,
|
||||
) -> int:
|
||||
"""Infer model output vector dimension via a single forward pass.
|
||||
|
||||
Args:
|
||||
processor: Image preprocessor.
|
||||
model: Feature extraction model.
|
||||
sample_image: A sample image for dimension inference.
|
||||
|
||||
Returns:
|
||||
Vector dimension.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
inputs = processor(images=sample_image, return_tensors="pt")
|
||||
inputs.to(device)
|
||||
output = model(inputs)
|
||||
|
||||
return output.shape[-1]
|
||||
from utils.feature_extractor import extract_batch_features, infer_vector_dim
|
||||
|
||||
|
||||
def _build_eval_schema(vector_dim: int) -> pa.Schema:
|
||||
@@ -57,7 +32,6 @@ def _build_eval_schema(vector_dim: int) -> pa.Schema:
|
||||
)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def _establish_eval_database(
|
||||
processor: BitImageProcessorFast,
|
||||
model: nn.Module,
|
||||
@@ -72,28 +46,22 @@ def _establish_eval_database(
|
||||
table: LanceDB table to store features.
|
||||
dataloader: DataLoader for the training dataset.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
model.eval()
|
||||
# Extract all features using the utility function
|
||||
all_features = extract_batch_features(processor, model, dataloader, show_progress=True)
|
||||
|
||||
# Store features to database
|
||||
global_idx = 0
|
||||
for batch in tqdm(dataloader, desc="Building eval database"):
|
||||
imgs = batch["img"]
|
||||
for batch in tqdm(dataloader, desc="Storing eval database"):
|
||||
labels = batch["label"]
|
||||
|
||||
inputs = processor(imgs, return_tensors="pt")
|
||||
inputs.to(device)
|
||||
outputs = model(inputs)
|
||||
|
||||
features = cast(torch.Tensor, outputs).cpu()
|
||||
labels_list = labels.tolist()
|
||||
|
||||
batch_size = len(labels_list)
|
||||
|
||||
table.add(
|
||||
[
|
||||
{
|
||||
"id": global_idx + j,
|
||||
"label": labels_list[j],
|
||||
"vector": features[j].numpy(),
|
||||
"vector": all_features[global_idx + j].numpy(),
|
||||
}
|
||||
for j in range(batch_size)
|
||||
]
|
||||
@@ -101,7 +69,6 @@ def _establish_eval_database(
|
||||
global_idx += batch_size
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def _evaluate_recall(
|
||||
processor: BitImageProcessorFast,
|
||||
model: nn.Module,
|
||||
@@ -121,25 +88,19 @@ def _evaluate_recall(
|
||||
Returns:
|
||||
A tuple of (correct_count, total_count).
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
model.eval()
|
||||
# Extract all features using the utility function
|
||||
all_features = extract_batch_features(processor, model, dataloader, show_progress=True)
|
||||
|
||||
correct = 0
|
||||
total = 0
|
||||
feature_idx = 0
|
||||
|
||||
for batch in tqdm(dataloader, desc=f"Evaluating Recall@{top_k}"):
|
||||
imgs = batch["img"]
|
||||
labels = batch["label"]
|
||||
|
||||
inputs = processor(imgs, return_tensors="pt")
|
||||
inputs.to(device)
|
||||
outputs = model(inputs)
|
||||
|
||||
features = cast(torch.Tensor, outputs).cpu()
|
||||
labels_list = labels.tolist()
|
||||
|
||||
for j in range(len(labels_list)):
|
||||
feature = features[j].tolist()
|
||||
feature = all_features[feature_idx + j].tolist()
|
||||
true_label = labels_list[j]
|
||||
|
||||
results = (
|
||||
@@ -154,6 +115,8 @@ def _evaluate_recall(
|
||||
correct += 1
|
||||
total += 1
|
||||
|
||||
feature_idx += len(labels_list)
|
||||
|
||||
return correct, total
|
||||
|
||||
|
||||
@@ -191,7 +154,7 @@ class RetrievalTask(BaseBenchmarkTask):
|
||||
sample = train_dataset[0]
|
||||
sample_image = sample["img"]
|
||||
|
||||
vector_dim = _infer_vector_dim(processor, model, sample_image)
|
||||
vector_dim = infer_vector_dim(processor, model, sample_image)
|
||||
expected_schema = _build_eval_schema(vector_dim)
|
||||
|
||||
# Check schema compatibility
|
||||
|
||||
@@ -30,6 +30,6 @@ benchmark:
|
||||
task:
|
||||
name: "recall_at_k"
|
||||
type: "retrieval"
|
||||
top_k: 10
|
||||
top_k: 1
|
||||
batch_size: 64
|
||||
model_table_prefix: "benchmark"
|
||||
|
||||
@@ -11,7 +11,7 @@ class ModelConfig(BaseModel):
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
name: str = "facebook/dinov2-large"
|
||||
dino_model: str = "facebook/dinov2-large"
|
||||
compression_dim: int = Field(
|
||||
default=512, gt=0, description="Output feature dimension"
|
||||
)
|
||||
|
||||
@@ -4,14 +4,16 @@ import lancedb
|
||||
import pyarrow as pa
|
||||
from configs import cfg_manager
|
||||
|
||||
db_schema = pa.schema(
|
||||
|
||||
def _build_database_schema():
|
||||
return pa.schema(
|
||||
[
|
||||
pa.field("id", pa.int32()),
|
||||
pa.field("label", pa.string()),
|
||||
pa.field("vector", pa.list_(pa.float32(), 1024)),
|
||||
pa.field("binary", pa.binary()),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class DatabaseManager:
|
||||
@@ -34,7 +36,9 @@ class DatabaseManager:
|
||||
# 初始化数据库与表格
|
||||
self.db = lancedb.connect(db_path)
|
||||
if "default" not in self.db.list_tables().tables:
|
||||
self.table = self.db.create_table("default", schema=db_schema)
|
||||
self.table = self.db.create_table(
|
||||
"default", schema=_build_database_schema()
|
||||
)
|
||||
else:
|
||||
self.table = self.db.open_table("default")
|
||||
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
import io
|
||||
from typing import Any, Dict, List, Optional, Union, cast
|
||||
from typing import Dict, List, Optional, cast
|
||||
|
||||
import torch
|
||||
from database import db_manager
|
||||
from datasets import load_dataset
|
||||
from PIL import Image
|
||||
from PIL.PngImagePlugin import PngImageFile
|
||||
from torch import nn
|
||||
@@ -14,6 +13,9 @@ from transformers import (
|
||||
BitImageProcessorFast,
|
||||
Dinov2Model,
|
||||
)
|
||||
from utils.feature_extractor import extract_batch_features
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
|
||||
def pil_image_to_bytes(image: Image.Image, format: str = "PNG") -> bytes:
|
||||
@@ -86,78 +88,26 @@ class FeatureRetrieval:
|
||||
batch_size: Number of images to process in a batch.
|
||||
label_map: Optional mapping from label indices to string names.
|
||||
"""
|
||||
device = self.model.device
|
||||
self.model.eval()
|
||||
|
||||
for i in tqdm(range(0, len(images), batch_size)):
|
||||
batch_imgs = images[i : i + batch_size]
|
||||
|
||||
inputs = self.processor(batch_imgs, return_tensors="pt")
|
||||
|
||||
# 迁移数据到GPU
|
||||
inputs.to(device)
|
||||
|
||||
outputs = self.model(**inputs)
|
||||
|
||||
# 后处理
|
||||
feats = outputs.last_hidden_state # [B, N, D]
|
||||
cls_tokens = feats[:, 0] # Get CLS token (first token) for all batch items
|
||||
cls_tokens = cast(torch.Tensor, cls_tokens)
|
||||
|
||||
# 迁移输出到CPU
|
||||
cls_tokens = cls_tokens.cpu()
|
||||
batch_labels = (
|
||||
labels[i : i + batch_size]
|
||||
if label_map is None
|
||||
else list(
|
||||
map(lambda x: label_map[cast(int, x)], labels[i : i + batch_size])
|
||||
# Extract features using the utility function
|
||||
cls_tokens = extract_batch_features(
|
||||
self.processor, self.model, images, batch_size=batch_size
|
||||
)
|
||||
)
|
||||
actual_batch_size = len(batch_labels)
|
||||
|
||||
# 存库
|
||||
for i in tqdm(range(len(labels)), desc="Storing to database"):
|
||||
batch_label = labels[i] if label_map is None else label_map[labels[i]]
|
||||
|
||||
# Store to database
|
||||
db_manager.table.add(
|
||||
[
|
||||
{
|
||||
"id": i + j,
|
||||
"label": batch_labels[j],
|
||||
"vector": cls_tokens[j].numpy(),
|
||||
"binary": pil_image_to_bytes(batch_imgs[j]),
|
||||
"id": i,
|
||||
"label": batch_label,
|
||||
"vector": cls_tokens[i].numpy(),
|
||||
"binary": pil_image_to_bytes(images[i]),
|
||||
}
|
||||
for j in range(actual_batch_size)
|
||||
]
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def extract_single_image_feature(
|
||||
self, image: Union[Image.Image, Any]
|
||||
) -> List[float]:
|
||||
"""Extract feature from a single image without storing to database.
|
||||
|
||||
Args:
|
||||
image: A single image (PIL Image or other supported format).
|
||||
|
||||
Returns:
|
||||
pl.Series: The extracted CLS token feature vector as a Polars Series.
|
||||
"""
|
||||
device = self.model.device
|
||||
self.model.eval()
|
||||
|
||||
# 预处理图片
|
||||
inputs = self.processor(images=image, return_tensors="pt")
|
||||
inputs.to(device, non_blocking=True)
|
||||
|
||||
# 提取特征
|
||||
outputs = self.model(**inputs)
|
||||
|
||||
# 获取 CLS token
|
||||
feats = outputs.last_hidden_state # [1, N, D]
|
||||
cls_token = feats[:, 0] # [1, D]
|
||||
cls_token = cast(torch.Tensor, cls_token)
|
||||
|
||||
# 返回 CLS List
|
||||
return cls_token.cpu().squeeze(0).tolist()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_dataset = load_dataset("uoft-cs/cifar10", split="train")
|
||||
|
||||
@@ -38,7 +38,7 @@ if __name__ == "__main__":
|
||||
model_cfg = config.model
|
||||
processor = cast(
|
||||
BitImageProcessorFast,
|
||||
AutoImageProcessor.from_pretrained(model_cfg.name, device_map=device),
|
||||
AutoImageProcessor.from_pretrained(model_cfg.dino_model, device_map=device),
|
||||
)
|
||||
|
||||
# Load compressor weights if specified in model config
|
||||
@@ -84,4 +84,6 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
generated_files = synthesizer.generate()
|
||||
print(f"Generated {len(generated_files)} synthesized images in {dataset_cfg.output_dir}")
|
||||
print(
|
||||
f"Generated {len(generated_files)} synthesized images in {dataset_cfg.output_dir}"
|
||||
)
|
||||
|
||||
@@ -25,7 +25,7 @@ class TestConfigModels:
|
||||
def test_model_config_defaults(self):
|
||||
"""Verify ModelConfig creates with correct defaults."""
|
||||
config = ModelConfig()
|
||||
assert config.name == "facebook/dinov2-large"
|
||||
assert config.dino_model == "facebook/dinov2-large"
|
||||
assert config.compression_dim == 512
|
||||
assert config.device == "auto"
|
||||
|
||||
@@ -73,7 +73,7 @@ class TestYamlLoader:
|
||||
config = load_yaml(config_path, Config)
|
||||
|
||||
# Verify model config
|
||||
assert config.model.name == "facebook/dinov2-large"
|
||||
assert config.model.dino_model == "facebook/dinov2-large"
|
||||
assert config.model.compression_dim == 256
|
||||
|
||||
# Verify output config
|
||||
|
||||
50
mini-nav/tests/test_feature_extractor.py
Normal file
50
mini-nav/tests/test_feature_extractor.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""Tests for feature extraction utilities."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoImageProcessor, AutoModel
|
||||
|
||||
from utils.feature_extractor import (
|
||||
extract_batch_features,
|
||||
extract_single_image_feature,
|
||||
infer_vector_dim,
|
||||
)
|
||||
|
||||
TEST_MODEL_NAME = "facebook/dinov2-base"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def model_and_processor():
    """Load the pretrained test model and its image processor once per module.

    Module scope keeps the checkpoint from being reloaded for every test that
    requests this fixture; the tests only read from the model, so sharing one
    instance between them is safe.

    Returns:
        A ``(processor, model)`` tuple with the model already in eval mode.
    """
    processor = AutoImageProcessor.from_pretrained(TEST_MODEL_NAME)
    model = AutoModel.from_pretrained(TEST_MODEL_NAME)
    model.eval()
    return processor, model
|
||||
|
||||
|
||||
def test_infer_vector_dim(model_and_processor):
    """infer_vector_dim must report a 768-wide feature for the test model."""
    image_processor, backbone = model_and_processor
    blue_image = Image.new("RGB", (224, 224), color="blue")
    assert infer_vector_dim(image_processor, backbone, blue_image) == 768
|
||||
|
||||
|
||||
def test_extract_single_image_feature(model_and_processor):
    """A single image must yield a plain Python list holding 768 values."""
    image_processor, backbone = model_and_processor
    red_image = Image.new("RGB", (224, 224), color="red")
    vector = extract_single_image_feature(image_processor, backbone, red_image)
    assert isinstance(vector, list)
    assert len(vector) == 768
|
||||
|
||||
|
||||
def test_extract_batch_features(model_and_processor):
    """Three input images must produce a (3, 768) tensor, one row per image."""
    image_processor, backbone = model_and_processor
    red_images = []
    for _ in range(3):
        red_images.append(Image.new("RGB", (224, 224), color="red"))
    stacked = extract_batch_features(image_processor, backbone, red_images)
    assert isinstance(stacked, torch.Tensor)
    assert stacked.shape == (3, 768)
|
||||
@@ -1,3 +1,14 @@
|
||||
from .common import get_device, get_output_diretory
|
||||
from .feature_extractor import (
|
||||
extract_batch_features,
|
||||
extract_single_image_feature,
|
||||
infer_vector_dim,
|
||||
)
|
||||
|
||||
__all__ = ["get_device", "get_output_diretory"]
|
||||
__all__ = [
|
||||
"get_device",
|
||||
"get_output_diretory",
|
||||
"infer_vector_dim",
|
||||
"extract_single_image_feature",
|
||||
"extract_batch_features",
|
||||
]
|
||||
|
||||
130
mini-nav/utils/feature_extractor.py
Normal file
130
mini-nav/utils/feature_extractor.py
Normal file
@@ -0,0 +1,130 @@
|
||||
"""Feature extraction utilities for image models."""
|
||||
|
||||
from typing import Any, List, Union, cast
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import BitImageProcessorFast
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
|
||||
def _extract_features_from_output(output: Any) -> torch.Tensor:
    """Reduce a model's forward output to one feature vector per image.

    Args:
        output: Either an object exposing ``last_hidden_state`` (for example a
            HuggingFace ``ModelOutput``) or a raw feature tensor.

    Returns:
        ``output.last_hidden_state[:, 0]`` (the first token of every sequence,
        used as the CLS token) when that attribute is present, otherwise
        ``output`` itself.

    Raises:
        TypeError: If ``output`` carries no ``last_hidden_state`` and is not a
            ``torch.Tensor``.
    """
    hidden_state = getattr(output, "last_hidden_state", None)
    if hidden_state is not None:
        return hidden_state[:, 0]  # [B, N, D] -> [B, D]

    # typing.cast() performs no runtime check, so validate explicitly; without
    # this an unsupported output would only fail later, far from its cause.
    if not isinstance(output, torch.Tensor):
        raise TypeError(
            "Unsupported model output type "
            f"{type(output).__name__!r}: expected a torch.Tensor or an object "
            "with a 'last_hidden_state' attribute"
        )
    return output
|
||||
|
||||
|
||||
def infer_vector_dim(
    processor: BitImageProcessorFast,
    model: nn.Module,
    sample_image: Any,
) -> int:
    """Infer the model's output vector dimension via a single forward pass.

    The model is switched to eval mode and left in that mode. Gradients are
    disabled for the forward pass.

    Args:
        processor: Image preprocessor, called as
            ``processor(images=..., return_tensors="pt")``.
        model: Feature extraction model. If its ``forward`` declares a
            ``pixel_values`` parameter, the processor output is unpacked as
            keyword arguments; otherwise the processor output is passed as a
            single positional argument.
        sample_image: A sample image for dimension inference.

    Returns:
        Size of the last dimension of the extracted feature tensor.
    """
    import inspect

    device = next(model.parameters()).device
    model.eval()

    with torch.no_grad():
        inputs = processor(images=sample_image, return_tensors="pt")
        # Keep the return value instead of relying on in-place mutation.
        inputs = inputs.to(device)
        # Handing the whole batch object to a model that expects
        # ``pixel_values=`` would bind the container, not the image tensor.
        if "pixel_values" in inspect.signature(model.forward).parameters:
            output = model(**inputs)
        else:
            output = model(inputs)

    features = _extract_features_from_output(output)
    return features.shape[-1]
|
||||
|
||||
|
||||
@torch.no_grad()
def extract_single_image_feature(
    processor: BitImageProcessorFast,
    model: nn.Module,
    image: Union[Image.Image, Any],
) -> List[float]:
    """Extract the feature vector of a single image.

    The model is switched to eval mode and left in that mode.

    Args:
        processor: Image preprocessor, called as
            ``processor(images=..., return_tensors="pt")``.
        model: Feature extraction model. If its ``forward`` declares a
            ``pixel_values`` parameter, the processor output is unpacked as
            keyword arguments; otherwise the processor output is passed as a
            single positional argument.
        image: A single image (PIL Image or other format the processor
            accepts).

    Returns:
        The extracted feature vector as a list of floats, with the leading
        batch dimension squeezed away.
    """
    import inspect

    device = next(model.parameters()).device
    model.eval()

    inputs = processor(images=image, return_tensors="pt")
    # Keep the return value instead of relying on in-place mutation.
    inputs = inputs.to(device, non_blocking=True)

    # Handing the whole batch object to a model that expects ``pixel_values=``
    # would bind the container, not the image tensor.
    if "pixel_values" in inspect.signature(model.forward).parameters:
        outputs = model(**inputs)
    else:
        outputs = model(inputs)

    features = _extract_features_from_output(outputs)  # [1, D]
    return features.cpu().squeeze(0).tolist()
|
||||
|
||||
|
||||
@torch.no_grad()
def extract_batch_features(
    processor: BitImageProcessorFast,
    model: nn.Module,
    images: Union[List[Any], Any],
    batch_size: int = 32,
    show_progress: bool = False,
) -> torch.Tensor:
    """Extract features for many images, processing them batch by batch.

    The model is switched to eval mode and left in that mode.

    Args:
        processor: Image preprocessor, called as
            ``processor(images=..., return_tensors="pt")``.
        model: Feature extraction model. If its ``forward`` declares a
            ``pixel_values`` parameter, the processor output is unpacked as
            keyword arguments; otherwise the processor output is passed as a
            single positional argument.
        images: Either a ``DataLoader`` or a sliceable sequence of images.
            For a ``DataLoader`` the images are read from ``batch["img"]``
            when a batch is a dict and from ``batch[0]`` otherwise.
        batch_size: Number of images per forward pass when ``images`` is a
            sequence. Ignored for a ``DataLoader``, which brings its own
            batching.
        show_progress: Whether to show a tqdm progress bar.

    Returns:
        CPU tensor of shape ``[num_images, feature_dim]`` with rows in
        iteration order. When there is nothing to process, an empty tensor of
        shape ``[0, 0]`` is returned because the feature size is unknown.
    """
    import inspect

    device = next(model.parameters()).device
    model.eval()

    # Decide the calling convention once rather than per batch.
    unpack_inputs = "pixel_values" in inspect.signature(model.forward).parameters

    if isinstance(images, DataLoader):
        source = tqdm(images, desc="Extracting features") if show_progress else images
        batches = (
            batch["img"] if isinstance(batch, dict) else batch[0] for batch in source
        )
    else:
        starts = range(0, len(images), batch_size)
        if show_progress:
            starts = tqdm(starts, desc="Extracting features")
        batches = (images[start : start + batch_size] for start in starts)

    all_features = []
    for batch_imgs in batches:
        inputs = processor(images=batch_imgs, return_tensors="pt")
        # Keep the return value instead of relying on in-place mutation.
        inputs = inputs.to(device)
        outputs = model(**inputs) if unpack_inputs else model(inputs)
        features = _extract_features_from_output(outputs)  # [B, D]
        # Move each batch to the CPU right away so accumulated results do not
        # stay on the model's device.
        all_features.append(features.cpu())

    if not all_features:
        # torch.cat() rejects an empty list; give callers an empty result.
        return torch.empty((0, 0))
    return torch.cat(all_features, dim=0)
|
||||
Reference in New Issue
Block a user