feat(utils): add feature extraction utilities and tests

This commit is contained in:
2026-03-05 20:48:53 +08:00
parent a16b376dd7
commit 5be4709acf
13 changed files with 247 additions and 138 deletions

View File

@@ -1,3 +1,3 @@
from database import DatabaseManager, db_manager, db_schema
from database import DatabaseManager, db_manager
__all__ = ["DatabaseManager", "db_manager", "db_schema"]
__all__ = ["DatabaseManager", "db_manager"]

View File

@@ -4,7 +4,6 @@ from abc import ABC, abstractmethod
from typing import Any, Protocol
import lancedb
from torch.utils.data import DataLoader
class BaseDataset(ABC):

View File

@@ -149,9 +149,9 @@ def run_benchmark(
sample = train_dataset[0]
sample_image = sample["img"]
from .tasks.retrieval import _infer_vector_dim
from utils.feature_extractor import infer_vector_dim
vector_dim = _infer_vector_dim(processor, model, sample_image)
vector_dim = infer_vector_dim(processor, model, sample_image)
print(f"Model output dimension: {vector_dim}")
# Ensure table exists with correct schema

View File

@@ -1,10 +1,9 @@
"""Retrieval task for benchmark evaluation (Recall@K)."""
from typing import Any, cast
from typing import Any
import lancedb
import pyarrow as pa
import torch
from benchmarks.base import BaseBenchmarkTask
from benchmarks.tasks.registry import RegisterTask
from torch import nn
@@ -12,31 +11,7 @@ from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import BitImageProcessorFast
def _infer_vector_dim(
processor: BitImageProcessorFast,
model: nn.Module,
sample_image: Any,
) -> int:
"""Infer model output vector dimension via a single forward pass.
Args:
processor: Image preprocessor.
model: Feature extraction model.
sample_image: A sample image for dimension inference.
Returns:
Vector dimension.
"""
device = next(model.parameters()).device
model.eval()
with torch.no_grad():
inputs = processor(images=sample_image, return_tensors="pt")
inputs.to(device)
output = model(inputs)
return output.shape[-1]
from utils.feature_extractor import extract_batch_features, infer_vector_dim
def _build_eval_schema(vector_dim: int) -> pa.Schema:
@@ -57,7 +32,6 @@ def _build_eval_schema(vector_dim: int) -> pa.Schema:
)
@torch.no_grad()
def _establish_eval_database(
processor: BitImageProcessorFast,
model: nn.Module,
@@ -72,28 +46,22 @@ def _establish_eval_database(
table: LanceDB table to store features.
dataloader: DataLoader for the training dataset.
"""
device = next(model.parameters()).device
model.eval()
# Extract all features using the utility function
all_features = extract_batch_features(processor, model, dataloader, show_progress=True)
# Store features to database
global_idx = 0
for batch in tqdm(dataloader, desc="Building eval database"):
imgs = batch["img"]
for batch in tqdm(dataloader, desc="Storing eval database"):
labels = batch["label"]
inputs = processor(imgs, return_tensors="pt")
inputs.to(device)
outputs = model(inputs)
features = cast(torch.Tensor, outputs).cpu()
labels_list = labels.tolist()
batch_size = len(labels_list)
table.add(
[
{
"id": global_idx + j,
"label": labels_list[j],
"vector": features[j].numpy(),
"vector": all_features[global_idx + j].numpy(),
}
for j in range(batch_size)
]
@@ -101,7 +69,6 @@ def _establish_eval_database(
global_idx += batch_size
@torch.no_grad()
def _evaluate_recall(
processor: BitImageProcessorFast,
model: nn.Module,
@@ -121,25 +88,19 @@ def _evaluate_recall(
Returns:
A tuple of (correct_count, total_count).
"""
device = next(model.parameters()).device
model.eval()
# Extract all features using the utility function
all_features = extract_batch_features(processor, model, dataloader, show_progress=True)
correct = 0
total = 0
feature_idx = 0
for batch in tqdm(dataloader, desc=f"Evaluating Recall@{top_k}"):
imgs = batch["img"]
labels = batch["label"]
inputs = processor(imgs, return_tensors="pt")
inputs.to(device)
outputs = model(inputs)
features = cast(torch.Tensor, outputs).cpu()
labels_list = labels.tolist()
for j in range(len(labels_list)):
feature = features[j].tolist()
feature = all_features[feature_idx + j].tolist()
true_label = labels_list[j]
results = (
@@ -154,6 +115,8 @@ def _evaluate_recall(
correct += 1
total += 1
feature_idx += len(labels_list)
return correct, total
@@ -191,7 +154,7 @@ class RetrievalTask(BaseBenchmarkTask):
sample = train_dataset[0]
sample_image = sample["img"]
vector_dim = _infer_vector_dim(processor, model, sample_image)
vector_dim = infer_vector_dim(processor, model, sample_image)
expected_schema = _build_eval_schema(vector_dim)
# Check schema compatibility

View File

@@ -30,6 +30,6 @@ benchmark:
task:
name: "recall_at_k"
type: "retrieval"
top_k: 10
top_k: 1
batch_size: 64
model_table_prefix: "benchmark"

View File

@@ -11,7 +11,7 @@ class ModelConfig(BaseModel):
model_config = ConfigDict(extra="ignore")
name: str = "facebook/dinov2-large"
dino_model: str = "facebook/dinov2-large"
compression_dim: int = Field(
default=512, gt=0, description="Output feature dimension"
)

View File

@@ -4,7 +4,9 @@ import lancedb
import pyarrow as pa
from configs import cfg_manager
db_schema = pa.schema(
def _build_database_schema():
return pa.schema(
[
pa.field("id", pa.int32()),
pa.field("label", pa.string()),
@@ -34,7 +36,9 @@ class DatabaseManager:
# 初始化数据库与表格
self.db = lancedb.connect(db_path)
if "default" not in self.db.list_tables().tables:
self.table = self.db.create_table("default", schema=db_schema)
self.table = self.db.create_table(
"default", schema=_build_database_schema()
)
else:
self.table = self.db.open_table("default")

View File

@@ -1,9 +1,8 @@
import io
from typing import Any, Dict, List, Optional, Union, cast
from typing import Dict, List, Optional, cast
import torch
from database import db_manager
from datasets import load_dataset
from PIL import Image
from PIL.PngImagePlugin import PngImageFile
from torch import nn
@@ -14,6 +13,9 @@ from transformers import (
BitImageProcessorFast,
Dinov2Model,
)
from utils.feature_extractor import extract_batch_features
from datasets import load_dataset
def pil_image_to_bytes(image: Image.Image, format: str = "PNG") -> bytes:
@@ -86,78 +88,26 @@ class FeatureRetrieval:
batch_size: Number of images to process in a batch.
label_map: Optional mapping from label indices to string names.
"""
device = self.model.device
self.model.eval()
for i in tqdm(range(0, len(images), batch_size)):
batch_imgs = images[i : i + batch_size]
inputs = self.processor(batch_imgs, return_tensors="pt")
# 迁移数据到GPU
inputs.to(device)
outputs = self.model(**inputs)
# 后处理
feats = outputs.last_hidden_state # [B, N, D]
cls_tokens = feats[:, 0] # Get CLS token (first token) for all batch items
cls_tokens = cast(torch.Tensor, cls_tokens)
# 迁移输出到CPU
cls_tokens = cls_tokens.cpu()
batch_labels = (
labels[i : i + batch_size]
if label_map is None
else list(
map(lambda x: label_map[cast(int, x)], labels[i : i + batch_size])
# Extract features using the utility function
cls_tokens = extract_batch_features(
self.processor, self.model, images, batch_size=batch_size
)
)
actual_batch_size = len(batch_labels)
# 存库
for i in tqdm(range(len(labels)), desc="Storing to database"):
batch_label = labels[i] if label_map is None else label_map[labels[i]]
# Store to database
db_manager.table.add(
[
{
"id": i + j,
"label": batch_labels[j],
"vector": cls_tokens[j].numpy(),
"binary": pil_image_to_bytes(batch_imgs[j]),
"id": i,
"label": batch_label,
"vector": cls_tokens[i].numpy(),
"binary": pil_image_to_bytes(images[i]),
}
for j in range(actual_batch_size)
]
)
@torch.no_grad()
def extract_single_image_feature(
self, image: Union[Image.Image, Any]
) -> List[float]:
"""Extract feature from a single image without storing to database.
Args:
image: A single image (PIL Image or other supported format).
Returns:
pl.Series: The extracted CLS token feature vector as a Polars Series.
"""
device = self.model.device
self.model.eval()
# 预处理图片
inputs = self.processor(images=image, return_tensors="pt")
inputs.to(device, non_blocking=True)
# 提取特征
outputs = self.model(**inputs)
# 获取 CLS token
feats = outputs.last_hidden_state # [1, N, D]
cls_token = feats[:, 0] # [1, D]
cls_token = cast(torch.Tensor, cls_token)
# 返回 CLS List
return cls_token.cpu().squeeze(0).tolist()
if __name__ == "__main__":
train_dataset = load_dataset("uoft-cs/cifar10", split="train")

View File

@@ -38,7 +38,7 @@ if __name__ == "__main__":
model_cfg = config.model
processor = cast(
BitImageProcessorFast,
AutoImageProcessor.from_pretrained(model_cfg.name, device_map=device),
AutoImageProcessor.from_pretrained(model_cfg.dino_model, device_map=device),
)
# Load compressor weights if specified in model config
@@ -84,4 +84,6 @@ if __name__ == "__main__":
)
generated_files = synthesizer.generate()
print(f"Generated {len(generated_files)} synthesized images in {dataset_cfg.output_dir}")
print(
f"Generated {len(generated_files)} synthesized images in {dataset_cfg.output_dir}"
)

View File

@@ -25,7 +25,7 @@ class TestConfigModels:
def test_model_config_defaults(self):
"""Verify ModelConfig creates with correct defaults."""
config = ModelConfig()
assert config.name == "facebook/dinov2-large"
assert config.dino_model == "facebook/dinov2-large"
assert config.compression_dim == 512
assert config.device == "auto"
@@ -73,7 +73,7 @@ class TestYamlLoader:
config = load_yaml(config_path, Config)
# Verify model config
assert config.model.name == "facebook/dinov2-large"
assert config.model.dino_model == "facebook/dinov2-large"
assert config.model.compression_dim == 256
# Verify output config

View File

@@ -0,0 +1,50 @@
"""Tests for feature extraction utilities."""
import pytest
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModel
from utils.feature_extractor import (
extract_batch_features,
extract_single_image_feature,
infer_vector_dim,
)
TEST_MODEL_NAME = "facebook/dinov2-base"
@pytest.fixture
def model_and_processor():
    """Provide a DINOv2-base processor/model pair, released after the test."""
    image_processor = AutoImageProcessor.from_pretrained(TEST_MODEL_NAME)
    backbone = AutoModel.from_pretrained(TEST_MODEL_NAME)
    backbone.eval()
    yield image_processor, backbone
    # Teardown: drop references so the model can be garbage-collected.
    del backbone
    del image_processor
def test_infer_vector_dim(model_and_processor):
    """infer_vector_dim should report DINOv2-base's 768-d hidden size."""
    image_processor, backbone = model_and_processor
    probe = Image.new("RGB", (224, 224), color="blue")
    assert infer_vector_dim(image_processor, backbone, probe) == 768
def test_extract_single_image_feature(model_and_processor):
    """A single image should yield a plain Python list of 768 floats."""
    image_processor, backbone = model_and_processor
    probe = Image.new("RGB", (224, 224), color="red")
    vector = extract_single_image_feature(image_processor, backbone, probe)
    assert isinstance(vector, list)
    assert len(vector) == 768
def test_extract_batch_features(model_and_processor):
    """Three images in, one [3, 768] feature tensor out."""
    image_processor, backbone = model_and_processor
    batch = [Image.new("RGB", (224, 224), color="red") for _ in range(3)]
    out = extract_batch_features(image_processor, backbone, batch)
    assert isinstance(out, torch.Tensor)
    assert out.shape == (3, 768)

View File

@@ -1,3 +1,14 @@
from .common import get_device, get_output_diretory
from .feature_extractor import (
extract_batch_features,
extract_single_image_feature,
infer_vector_dim,
)
__all__ = ["get_device", "get_output_diretory"]
__all__ = [
"get_device",
"get_output_diretory",
"infer_vector_dim",
"extract_single_image_feature",
"extract_batch_features",
]

View File

@@ -0,0 +1,130 @@
"""Feature extraction utilities for image models."""
from typing import Any, List, Union, cast
import torch
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader
from transformers import BitImageProcessorFast
from tqdm.auto import tqdm
def _extract_features_from_output(output: Any) -> torch.Tensor:
    """Normalize a model's forward result into a [B, D] feature tensor.

    Supports both HuggingFace ModelOutput objects (which carry a
    ``last_hidden_state``) and models that already return a raw tensor.

    Args:
        output: Model output — a ModelOutput or a plain tensor.

    Returns:
        Feature tensor of shape [B, D].
    """
    if not hasattr(output, "last_hidden_state"):
        # Raw tensor path (e.g. a compressor head that already emits [B, D]).
        return cast(torch.Tensor, output)
    # HuggingFace path: take the CLS token (position 0) from every sequence.
    return output.last_hidden_state[:, 0]
@torch.no_grad()
def infer_vector_dim(
    processor: BitImageProcessorFast,
    model: nn.Module,
    sample_image: Any,
) -> int:
    """Infer model output vector dimension via a single forward pass.

    Uses the ``@torch.no_grad()`` decorator for consistency with the other
    extraction helpers in this module (previously an inline ``with`` block).

    Args:
        processor: Image preprocessor.
        model: Feature extraction model.
        sample_image: A sample image used only for dimension inference.

    Returns:
        Dimension of the feature vector the model produces.
    """
    device = next(model.parameters()).device
    model.eval()
    inputs = processor(images=sample_image, return_tensors="pt")
    # BatchFeature.to() moves its tensors in place and returns self, so the
    # unassigned call is sufficient.
    inputs.to(device)
    # NOTE(review): the model is invoked with the whole BatchFeature rather
    # than model(**inputs); presumably every model passed here is a wrapper
    # accepting that object — confirm this holds for raw HF backbones.
    output = model(inputs)
    features = _extract_features_from_output(output)
    return features.shape[-1]
@torch.no_grad()
def extract_single_image_feature(
    processor: BitImageProcessorFast,
    model: nn.Module,
    image: Union[Image.Image, Any],
) -> List[float]:
    """Extract the feature vector for one image.

    Args:
        processor: Image preprocessor.
        model: Feature extraction model.
        image: A single image (PIL Image or any format the processor accepts).

    Returns:
        The CLS-token feature vector as a plain list of floats.
    """
    model.eval()
    target_device = next(model.parameters()).device
    batch = processor(images=image, return_tensors="pt")
    batch.to(target_device, non_blocking=True)
    raw_output = model(batch)
    feature = _extract_features_from_output(raw_output)  # [1, D]
    return feature.cpu().squeeze(0).tolist()
@torch.no_grad()
def extract_batch_features(
    processor: BitImageProcessorFast,
    model: nn.Module,
    images: Union[List[Any], Any],
    batch_size: int = 32,
    show_progress: bool = False,
) -> torch.Tensor:
    """Extract features from a batch of images.

    Args:
        processor: Image preprocessor.
        model: Feature extraction model.
        images: A DataLoader yielding batches, or an indexable sequence of
            images.
        batch_size: Batch size when ``images`` is a plain sequence (ignored
            for DataLoader input, which batches itself).
        show_progress: Whether to show a progress bar.

    Returns:
        CPU tensor of shape [num_images, feature_dim].

    Raises:
        RuntimeError: If ``images`` yields no items (``torch.cat`` on empty).
    """
    device = next(model.parameters()).device
    model.eval()

    def _forward(imgs: Any) -> torch.Tensor:
        # Shared per-batch pipeline: preprocess -> device -> model -> CPU.
        # Previously duplicated verbatim in both branches below.
        inputs = processor(images=imgs, return_tensors="pt")
        inputs.to(device)
        outputs = model(inputs)
        return _extract_features_from_output(outputs).cpu()  # [B, D]

    if isinstance(images, DataLoader):
        source = tqdm(images, desc="Extracting features") if show_progress else images
        chunks = [
            # DataLoader batches may be dicts ({"img": ...}) or tuples.
            _forward(batch["img"] if isinstance(batch, dict) else batch[0])
            for batch in source
        ]
    else:
        starts = range(0, len(images), batch_size)
        if show_progress:
            starts = tqdm(starts, desc="Extracting features")
        chunks = [_forward(images[i : i + batch_size]) for i in starts]

    return torch.cat(chunks, dim=0)