mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
feat(utils): add feature extraction utilities and tests
This commit is contained in:
@@ -1,3 +1,3 @@
|
|||||||
from database import DatabaseManager, db_manager, db_schema
|
from database import DatabaseManager, db_manager
|
||||||
|
|
||||||
__all__ = ["DatabaseManager", "db_manager", "db_schema"]
|
__all__ = ["DatabaseManager", "db_manager"]
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ from abc import ABC, abstractmethod
|
|||||||
from typing import Any, Protocol
|
from typing import Any, Protocol
|
||||||
|
|
||||||
import lancedb
|
import lancedb
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
|
|
||||||
class BaseDataset(ABC):
|
class BaseDataset(ABC):
|
||||||
|
|||||||
@@ -149,9 +149,9 @@ def run_benchmark(
|
|||||||
sample = train_dataset[0]
|
sample = train_dataset[0]
|
||||||
sample_image = sample["img"]
|
sample_image = sample["img"]
|
||||||
|
|
||||||
from .tasks.retrieval import _infer_vector_dim
|
from utils.feature_extractor import infer_vector_dim
|
||||||
|
|
||||||
vector_dim = _infer_vector_dim(processor, model, sample_image)
|
vector_dim = infer_vector_dim(processor, model, sample_image)
|
||||||
print(f"Model output dimension: {vector_dim}")
|
print(f"Model output dimension: {vector_dim}")
|
||||||
|
|
||||||
# Ensure table exists with correct schema
|
# Ensure table exists with correct schema
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
"""Retrieval task for benchmark evaluation (Recall@K)."""
|
"""Retrieval task for benchmark evaluation (Recall@K)."""
|
||||||
|
|
||||||
from typing import Any, cast
|
from typing import Any
|
||||||
|
|
||||||
import lancedb
|
import lancedb
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
import torch
|
|
||||||
from benchmarks.base import BaseBenchmarkTask
|
from benchmarks.base import BaseBenchmarkTask
|
||||||
from benchmarks.tasks.registry import RegisterTask
|
from benchmarks.tasks.registry import RegisterTask
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -12,31 +11,7 @@ from torch.utils.data import DataLoader
|
|||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
from transformers import BitImageProcessorFast
|
from transformers import BitImageProcessorFast
|
||||||
|
|
||||||
|
from utils.feature_extractor import extract_batch_features, infer_vector_dim
|
||||||
def _infer_vector_dim(
|
|
||||||
processor: BitImageProcessorFast,
|
|
||||||
model: nn.Module,
|
|
||||||
sample_image: Any,
|
|
||||||
) -> int:
|
|
||||||
"""Infer model output vector dimension via a single forward pass.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
processor: Image preprocessor.
|
|
||||||
model: Feature extraction model.
|
|
||||||
sample_image: A sample image for dimension inference.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Vector dimension.
|
|
||||||
"""
|
|
||||||
device = next(model.parameters()).device
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
inputs = processor(images=sample_image, return_tensors="pt")
|
|
||||||
inputs.to(device)
|
|
||||||
output = model(inputs)
|
|
||||||
|
|
||||||
return output.shape[-1]
|
|
||||||
|
|
||||||
|
|
||||||
def _build_eval_schema(vector_dim: int) -> pa.Schema:
|
def _build_eval_schema(vector_dim: int) -> pa.Schema:
|
||||||
@@ -57,7 +32,6 @@ def _build_eval_schema(vector_dim: int) -> pa.Schema:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def _establish_eval_database(
|
def _establish_eval_database(
|
||||||
processor: BitImageProcessorFast,
|
processor: BitImageProcessorFast,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
@@ -72,28 +46,22 @@ def _establish_eval_database(
|
|||||||
table: LanceDB table to store features.
|
table: LanceDB table to store features.
|
||||||
dataloader: DataLoader for the training dataset.
|
dataloader: DataLoader for the training dataset.
|
||||||
"""
|
"""
|
||||||
device = next(model.parameters()).device
|
# Extract all features using the utility function
|
||||||
model.eval()
|
all_features = extract_batch_features(processor, model, dataloader, show_progress=True)
|
||||||
|
|
||||||
|
# Store features to database
|
||||||
global_idx = 0
|
global_idx = 0
|
||||||
for batch in tqdm(dataloader, desc="Building eval database"):
|
for batch in tqdm(dataloader, desc="Storing eval database"):
|
||||||
imgs = batch["img"]
|
|
||||||
labels = batch["label"]
|
labels = batch["label"]
|
||||||
|
|
||||||
inputs = processor(imgs, return_tensors="pt")
|
|
||||||
inputs.to(device)
|
|
||||||
outputs = model(inputs)
|
|
||||||
|
|
||||||
features = cast(torch.Tensor, outputs).cpu()
|
|
||||||
labels_list = labels.tolist()
|
labels_list = labels.tolist()
|
||||||
|
|
||||||
batch_size = len(labels_list)
|
batch_size = len(labels_list)
|
||||||
|
|
||||||
table.add(
|
table.add(
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"id": global_idx + j,
|
"id": global_idx + j,
|
||||||
"label": labels_list[j],
|
"label": labels_list[j],
|
||||||
"vector": features[j].numpy(),
|
"vector": all_features[global_idx + j].numpy(),
|
||||||
}
|
}
|
||||||
for j in range(batch_size)
|
for j in range(batch_size)
|
||||||
]
|
]
|
||||||
@@ -101,7 +69,6 @@ def _establish_eval_database(
|
|||||||
global_idx += batch_size
|
global_idx += batch_size
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def _evaluate_recall(
|
def _evaluate_recall(
|
||||||
processor: BitImageProcessorFast,
|
processor: BitImageProcessorFast,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
@@ -121,25 +88,19 @@ def _evaluate_recall(
|
|||||||
Returns:
|
Returns:
|
||||||
A tuple of (correct_count, total_count).
|
A tuple of (correct_count, total_count).
|
||||||
"""
|
"""
|
||||||
device = next(model.parameters()).device
|
# Extract all features using the utility function
|
||||||
model.eval()
|
all_features = extract_batch_features(processor, model, dataloader, show_progress=True)
|
||||||
|
|
||||||
correct = 0
|
correct = 0
|
||||||
total = 0
|
total = 0
|
||||||
|
feature_idx = 0
|
||||||
|
|
||||||
for batch in tqdm(dataloader, desc=f"Evaluating Recall@{top_k}"):
|
for batch in tqdm(dataloader, desc=f"Evaluating Recall@{top_k}"):
|
||||||
imgs = batch["img"]
|
|
||||||
labels = batch["label"]
|
labels = batch["label"]
|
||||||
|
|
||||||
inputs = processor(imgs, return_tensors="pt")
|
|
||||||
inputs.to(device)
|
|
||||||
outputs = model(inputs)
|
|
||||||
|
|
||||||
features = cast(torch.Tensor, outputs).cpu()
|
|
||||||
labels_list = labels.tolist()
|
labels_list = labels.tolist()
|
||||||
|
|
||||||
for j in range(len(labels_list)):
|
for j in range(len(labels_list)):
|
||||||
feature = features[j].tolist()
|
feature = all_features[feature_idx + j].tolist()
|
||||||
true_label = labels_list[j]
|
true_label = labels_list[j]
|
||||||
|
|
||||||
results = (
|
results = (
|
||||||
@@ -154,6 +115,8 @@ def _evaluate_recall(
|
|||||||
correct += 1
|
correct += 1
|
||||||
total += 1
|
total += 1
|
||||||
|
|
||||||
|
feature_idx += len(labels_list)
|
||||||
|
|
||||||
return correct, total
|
return correct, total
|
||||||
|
|
||||||
|
|
||||||
@@ -191,7 +154,7 @@ class RetrievalTask(BaseBenchmarkTask):
|
|||||||
sample = train_dataset[0]
|
sample = train_dataset[0]
|
||||||
sample_image = sample["img"]
|
sample_image = sample["img"]
|
||||||
|
|
||||||
vector_dim = _infer_vector_dim(processor, model, sample_image)
|
vector_dim = infer_vector_dim(processor, model, sample_image)
|
||||||
expected_schema = _build_eval_schema(vector_dim)
|
expected_schema = _build_eval_schema(vector_dim)
|
||||||
|
|
||||||
# Check schema compatibility
|
# Check schema compatibility
|
||||||
|
|||||||
@@ -30,6 +30,6 @@ benchmark:
|
|||||||
task:
|
task:
|
||||||
name: "recall_at_k"
|
name: "recall_at_k"
|
||||||
type: "retrieval"
|
type: "retrieval"
|
||||||
top_k: 10
|
top_k: 1
|
||||||
batch_size: 64
|
batch_size: 64
|
||||||
model_table_prefix: "benchmark"
|
model_table_prefix: "benchmark"
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ class ModelConfig(BaseModel):
|
|||||||
|
|
||||||
model_config = ConfigDict(extra="ignore")
|
model_config = ConfigDict(extra="ignore")
|
||||||
|
|
||||||
name: str = "facebook/dinov2-large"
|
dino_model: str = "facebook/dinov2-large"
|
||||||
compression_dim: int = Field(
|
compression_dim: int = Field(
|
||||||
default=512, gt=0, description="Output feature dimension"
|
default=512, gt=0, description="Output feature dimension"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -4,14 +4,16 @@ import lancedb
|
|||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
from configs import cfg_manager
|
from configs import cfg_manager
|
||||||
|
|
||||||
db_schema = pa.schema(
|
|
||||||
[
|
def _build_database_schema():
|
||||||
pa.field("id", pa.int32()),
|
return pa.schema(
|
||||||
pa.field("label", pa.string()),
|
[
|
||||||
pa.field("vector", pa.list_(pa.float32(), 1024)),
|
pa.field("id", pa.int32()),
|
||||||
pa.field("binary", pa.binary()),
|
pa.field("label", pa.string()),
|
||||||
]
|
pa.field("vector", pa.list_(pa.float32(), 1024)),
|
||||||
)
|
pa.field("binary", pa.binary()),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DatabaseManager:
|
class DatabaseManager:
|
||||||
@@ -34,7 +36,9 @@ class DatabaseManager:
|
|||||||
# 初始化数据库与表格
|
# 初始化数据库与表格
|
||||||
self.db = lancedb.connect(db_path)
|
self.db = lancedb.connect(db_path)
|
||||||
if "default" not in self.db.list_tables().tables:
|
if "default" not in self.db.list_tables().tables:
|
||||||
self.table = self.db.create_table("default", schema=db_schema)
|
self.table = self.db.create_table(
|
||||||
|
"default", schema=_build_database_schema()
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.table = self.db.open_table("default")
|
self.table = self.db.open_table("default")
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
import io
|
import io
|
||||||
from typing import Any, Dict, List, Optional, Union, cast
|
from typing import Dict, List, Optional, cast
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from database import db_manager
|
from database import db_manager
|
||||||
from datasets import load_dataset
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from PIL.PngImagePlugin import PngImageFile
|
from PIL.PngImagePlugin import PngImageFile
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -14,6 +13,9 @@ from transformers import (
|
|||||||
BitImageProcessorFast,
|
BitImageProcessorFast,
|
||||||
Dinov2Model,
|
Dinov2Model,
|
||||||
)
|
)
|
||||||
|
from utils.feature_extractor import extract_batch_features
|
||||||
|
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
|
|
||||||
def pil_image_to_bytes(image: Image.Image, format: str = "PNG") -> bytes:
|
def pil_image_to_bytes(image: Image.Image, format: str = "PNG") -> bytes:
|
||||||
@@ -86,78 +88,26 @@ class FeatureRetrieval:
|
|||||||
batch_size: Number of images to process in a batch.
|
batch_size: Number of images to process in a batch.
|
||||||
label_map: Optional mapping from label indices to string names.
|
label_map: Optional mapping from label indices to string names.
|
||||||
"""
|
"""
|
||||||
device = self.model.device
|
# Extract features using the utility function
|
||||||
self.model.eval()
|
cls_tokens = extract_batch_features(
|
||||||
|
self.processor, self.model, images, batch_size=batch_size
|
||||||
|
)
|
||||||
|
|
||||||
for i in tqdm(range(0, len(images), batch_size)):
|
for i in tqdm(range(len(labels)), desc="Storing to database"):
|
||||||
batch_imgs = images[i : i + batch_size]
|
batch_label = labels[i] if label_map is None else label_map[labels[i]]
|
||||||
|
|
||||||
inputs = self.processor(batch_imgs, return_tensors="pt")
|
# Store to database
|
||||||
|
|
||||||
# 迁移数据到GPU
|
|
||||||
inputs.to(device)
|
|
||||||
|
|
||||||
outputs = self.model(**inputs)
|
|
||||||
|
|
||||||
# 后处理
|
|
||||||
feats = outputs.last_hidden_state # [B, N, D]
|
|
||||||
cls_tokens = feats[:, 0] # Get CLS token (first token) for all batch items
|
|
||||||
cls_tokens = cast(torch.Tensor, cls_tokens)
|
|
||||||
|
|
||||||
# 迁移输出到CPU
|
|
||||||
cls_tokens = cls_tokens.cpu()
|
|
||||||
batch_labels = (
|
|
||||||
labels[i : i + batch_size]
|
|
||||||
if label_map is None
|
|
||||||
else list(
|
|
||||||
map(lambda x: label_map[cast(int, x)], labels[i : i + batch_size])
|
|
||||||
)
|
|
||||||
)
|
|
||||||
actual_batch_size = len(batch_labels)
|
|
||||||
|
|
||||||
# 存库
|
|
||||||
db_manager.table.add(
|
db_manager.table.add(
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"id": i + j,
|
"id": i,
|
||||||
"label": batch_labels[j],
|
"label": batch_label,
|
||||||
"vector": cls_tokens[j].numpy(),
|
"vector": cls_tokens[i].numpy(),
|
||||||
"binary": pil_image_to_bytes(batch_imgs[j]),
|
"binary": pil_image_to_bytes(images[i]),
|
||||||
}
|
}
|
||||||
for j in range(actual_batch_size)
|
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def extract_single_image_feature(
|
|
||||||
self, image: Union[Image.Image, Any]
|
|
||||||
) -> List[float]:
|
|
||||||
"""Extract feature from a single image without storing to database.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image: A single image (PIL Image or other supported format).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
pl.Series: The extracted CLS token feature vector as a Polars Series.
|
|
||||||
"""
|
|
||||||
device = self.model.device
|
|
||||||
self.model.eval()
|
|
||||||
|
|
||||||
# 预处理图片
|
|
||||||
inputs = self.processor(images=image, return_tensors="pt")
|
|
||||||
inputs.to(device, non_blocking=True)
|
|
||||||
|
|
||||||
# 提取特征
|
|
||||||
outputs = self.model(**inputs)
|
|
||||||
|
|
||||||
# 获取 CLS token
|
|
||||||
feats = outputs.last_hidden_state # [1, N, D]
|
|
||||||
cls_token = feats[:, 0] # [1, D]
|
|
||||||
cls_token = cast(torch.Tensor, cls_token)
|
|
||||||
|
|
||||||
# 返回 CLS List
|
|
||||||
return cls_token.cpu().squeeze(0).tolist()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
train_dataset = load_dataset("uoft-cs/cifar10", split="train")
|
train_dataset = load_dataset("uoft-cs/cifar10", split="train")
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ if __name__ == "__main__":
|
|||||||
model_cfg = config.model
|
model_cfg = config.model
|
||||||
processor = cast(
|
processor = cast(
|
||||||
BitImageProcessorFast,
|
BitImageProcessorFast,
|
||||||
AutoImageProcessor.from_pretrained(model_cfg.name, device_map=device),
|
AutoImageProcessor.from_pretrained(model_cfg.dino_model, device_map=device),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Load compressor weights if specified in model config
|
# Load compressor weights if specified in model config
|
||||||
@@ -84,4 +84,6 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
generated_files = synthesizer.generate()
|
generated_files = synthesizer.generate()
|
||||||
print(f"Generated {len(generated_files)} synthesized images in {dataset_cfg.output_dir}")
|
print(
|
||||||
|
f"Generated {len(generated_files)} synthesized images in {dataset_cfg.output_dir}"
|
||||||
|
)
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ class TestConfigModels:
|
|||||||
def test_model_config_defaults(self):
|
def test_model_config_defaults(self):
|
||||||
"""Verify ModelConfig creates with correct defaults."""
|
"""Verify ModelConfig creates with correct defaults."""
|
||||||
config = ModelConfig()
|
config = ModelConfig()
|
||||||
assert config.name == "facebook/dinov2-large"
|
assert config.dino_model == "facebook/dinov2-large"
|
||||||
assert config.compression_dim == 512
|
assert config.compression_dim == 512
|
||||||
assert config.device == "auto"
|
assert config.device == "auto"
|
||||||
|
|
||||||
@@ -73,7 +73,7 @@ class TestYamlLoader:
|
|||||||
config = load_yaml(config_path, Config)
|
config = load_yaml(config_path, Config)
|
||||||
|
|
||||||
# Verify model config
|
# Verify model config
|
||||||
assert config.model.name == "facebook/dinov2-large"
|
assert config.model.dino_model == "facebook/dinov2-large"
|
||||||
assert config.model.compression_dim == 256
|
assert config.model.compression_dim == 256
|
||||||
|
|
||||||
# Verify output config
|
# Verify output config
|
||||||
|
|||||||
50
mini-nav/tests/test_feature_extractor.py
Normal file
50
mini-nav/tests/test_feature_extractor.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
"""Tests for feature extraction utilities."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from transformers import AutoImageProcessor, AutoModel
|
||||||
|
|
||||||
|
from utils.feature_extractor import (
|
||||||
|
extract_batch_features,
|
||||||
|
extract_single_image_feature,
|
||||||
|
infer_vector_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
TEST_MODEL_NAME = "facebook/dinov2-base"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def model_and_processor():
    """Provide a ``(processor, model)`` pair built from ``TEST_MODEL_NAME``.

    The model is switched to eval mode before being handed to the test.
    Both local references are dropped once the test has finished.
    """
    image_processor = AutoImageProcessor.from_pretrained(TEST_MODEL_NAME)
    backbone = AutoModel.from_pretrained(TEST_MODEL_NAME)
    backbone.eval()

    yield image_processor, backbone

    # Teardown: release the local references in the same order as before.
    del backbone
    del image_processor
|
||||||
|
|
||||||
|
|
||||||
|
def test_infer_vector_dim(model_and_processor):
    """infer_vector_dim should report 768 for the fixture's model."""
    image_processor, backbone = model_and_processor
    blue_square = Image.new("RGB", (224, 224), color="blue")

    inferred = infer_vector_dim(image_processor, backbone, blue_square)

    assert inferred == 768
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_single_image_feature(model_and_processor):
    """A single image should yield a plain list holding 768 values."""
    image_processor, backbone = model_and_processor
    red_square = Image.new("RGB", (224, 224), color="red")

    vector = extract_single_image_feature(image_processor, backbone, red_square)

    assert isinstance(vector, list)
    assert len(vector) == 768
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_batch_features(model_and_processor):
    """Three input images should yield a tensor of shape (3, 768)."""
    image_processor, backbone = model_and_processor
    red_squares = [Image.new("RGB", (224, 224), color="red") for _ in range(3)]

    stacked = extract_batch_features(image_processor, backbone, red_squares)

    assert isinstance(stacked, torch.Tensor)
    assert stacked.shape == (3, 768)
|
||||||
@@ -1,3 +1,14 @@
|
|||||||
from .common import get_device, get_output_diretory
|
from .common import get_device, get_output_diretory
|
||||||
|
from .feature_extractor import (
|
||||||
|
extract_batch_features,
|
||||||
|
extract_single_image_feature,
|
||||||
|
infer_vector_dim,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = ["get_device", "get_output_diretory"]
|
__all__ = [
|
||||||
|
"get_device",
|
||||||
|
"get_output_diretory",
|
||||||
|
"infer_vector_dim",
|
||||||
|
"extract_single_image_feature",
|
||||||
|
"extract_batch_features",
|
||||||
|
]
|
||||||
|
|||||||
130
mini-nav/utils/feature_extractor.py
Normal file
130
mini-nav/utils/feature_extractor.py
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
"""Feature extraction utilities for image models."""
|
||||||
|
|
||||||
|
from typing import Any, List, Union, cast
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from torch import nn
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from transformers import BitImageProcessorFast
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_features_from_output(output: Any) -> torch.Tensor:
|
||||||
|
"""Extract features from model output, handling both HuggingFace ModelOutput and raw tensors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output: Model output (either ModelOutput with .last_hidden_state or raw tensor).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Feature tensor of shape [B, D].
|
||||||
|
"""
|
||||||
|
# Handle HuggingFace ModelOutput (has .last_hidden_state)
|
||||||
|
if hasattr(output, "last_hidden_state"):
|
||||||
|
return output.last_hidden_state[:, 0] # [B, D] - CLS token
|
||||||
|
# Handle raw tensor output (like DinoCompressor)
|
||||||
|
return cast(torch.Tensor, output)
|
||||||
|
|
||||||
|
|
||||||
|
def infer_vector_dim(
    processor: BitImageProcessorFast,
    model: nn.Module,
    sample_image: Any,
) -> int:
    """Determine the feature dimension by running one image through the model.

    The model is put into eval mode and the forward pass runs with gradient
    tracking disabled.

    Args:
        processor: Image preprocessor, called as
            ``processor(images=..., return_tensors="pt")``.
        model: Feature extraction model; its first parameter decides the
            device the inputs are moved to.
        sample_image: One sample image used only for the probe forward pass.

    Returns:
        Size of the last axis of the extracted feature tensor.
    """
    target_device = next(model.parameters()).device
    model.eval()

    with torch.no_grad():
        encoded = processor(images=sample_image, return_tensors="pt")
        # The return value is discarded, so this relies on ``.to`` moving the
        # batch in place -- presumably true for the processor's output type;
        # TODO confirm.
        encoded.to(target_device)
        # NOTE(review): the processor output is passed as ONE positional
        # argument (not unpacked with **); confirm the model's forward
        # accepts that calling convention.
        raw_output = model(encoded)

    return _extract_features_from_output(raw_output).shape[-1]
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
def extract_single_image_feature(
    processor: BitImageProcessorFast,
    model: nn.Module,
    image: Union[Image.Image, Any],
) -> List[float]:
    """Compute the feature vector of one image.

    Runs with gradient tracking disabled and switches the model to eval mode.

    Args:
        processor: Image preprocessor, called as
            ``processor(images=..., return_tensors="pt")``.
        model: Feature extraction model; its first parameter decides the
            device the inputs are moved to.
        image: A single image (PIL image or another format the processor
            accepts).

    Returns:
        The extracted features moved to CPU, with the leading axis squeezed
        away and converted via ``tolist()`` (a flat list of floats, assuming
        the model yields a ``[1, D]`` tensor for a single image).
    """
    target_device = next(model.parameters()).device
    model.eval()

    encoded = processor(images=image, return_tensors="pt")
    # Result of ``.to`` is discarded; relies on an in-place move (TODO confirm).
    encoded.to(target_device, non_blocking=True)

    # NOTE(review): the processor output is passed positionally, not as
    # **kwargs; confirm the model's forward accepts this.
    feature_row = _extract_features_from_output(model(encoded))  # expected [1, D]
    on_cpu = feature_row.cpu()
    return on_cpu.squeeze(0).tolist()
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
def extract_batch_features(
    processor: BitImageProcessorFast,
    model: nn.Module,
    images: Union[List[Any], Any],
    batch_size: int = 32,
    show_progress: bool = False,
) -> torch.Tensor:
    """Extract features for many images, one batch at a time.

    Runs with gradient tracking disabled and switches the model to eval mode.

    Args:
        processor: Image preprocessor, called as
            ``processor(images=..., return_tensors="pt")``.
        model: Feature extraction model; its first parameter decides the
            device the inputs are moved to.
        images: Either a ``DataLoader`` or a sliceable sequence of images.
            For a ``DataLoader``, each batch's images are read from
            ``batch["img"]`` when the batch is a dict and from ``batch[0]``
            otherwise.
        batch_size: Slice length used when ``images`` is a sequence. It is
            not used for ``DataLoader`` input, where the loader's own
            batching applies.
        show_progress: Wrap the batch iteration in a ``tqdm`` progress bar.

    Returns:
        The per-batch feature tensors, each moved to CPU, concatenated along
        dim 0 in iteration order (one row per image, assuming the model
        yields ``[B, D]`` features per batch).
    """
    target_device = next(model.parameters()).device
    model.eval()

    def _embed(image_batch: Any) -> torch.Tensor:
        """Preprocess one batch, run the model, return CPU features."""
        encoded = processor(images=image_batch, return_tensors="pt")
        # Result of ``.to`` is discarded; relies on an in-place move
        # (TODO confirm).
        encoded.to(target_device)
        # NOTE(review): inputs are passed positionally, not as **kwargs;
        # confirm the model's forward accepts this.
        return _extract_features_from_output(model(encoded)).cpu()

    # DataLoader input: batches come pre-built from the loader.
    if isinstance(images, DataLoader):
        batches = images
        if show_progress:
            batches = tqdm(images, desc="Extracting features")
        chunks = [
            _embed(batch["img"] if isinstance(batch, dict) else batch[0])
            for batch in batches
        ]
        return torch.cat(chunks, dim=0)

    # Sequence input: slice it into windows of ``batch_size`` images.
    starts = range(0, len(images), batch_size)
    if show_progress:
        starts = tqdm(starts, desc="Extracting features")
    chunks = [_embed(images[lo : lo + batch_size]) for lo in starts]
    return torch.cat(chunks, dim=0)
|
||||||
Reference in New Issue
Block a user