feat(utils): add feature extraction utilities and tests

This commit is contained in:
2026-03-05 20:48:53 +08:00
parent a16b376dd7
commit 5be4709acf
13 changed files with 247 additions and 138 deletions

View File

@@ -1,10 +1,9 @@
"""Retrieval task for benchmark evaluation (Recall@K)."""
from typing import Any, cast
from typing import Any
import lancedb
import pyarrow as pa
import torch
from benchmarks.base import BaseBenchmarkTask
from benchmarks.tasks.registry import RegisterTask
from torch import nn
@@ -12,31 +11,7 @@ from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import BitImageProcessorFast
def _infer_vector_dim(
processor: BitImageProcessorFast,
model: nn.Module,
sample_image: Any,
) -> int:
"""Infer model output vector dimension via a single forward pass.
Args:
processor: Image preprocessor.
model: Feature extraction model.
sample_image: A sample image for dimension inference.
Returns:
Vector dimension.
"""
device = next(model.parameters()).device
model.eval()
with torch.no_grad():
inputs = processor(images=sample_image, return_tensors="pt")
inputs.to(device)
output = model(inputs)
return output.shape[-1]
from utils.feature_extractor import extract_batch_features, infer_vector_dim
def _build_eval_schema(vector_dim: int) -> pa.Schema:
@@ -57,7 +32,6 @@ def _build_eval_schema(vector_dim: int) -> pa.Schema:
)
@torch.no_grad()
def _establish_eval_database(
processor: BitImageProcessorFast,
model: nn.Module,
@@ -72,28 +46,22 @@ def _establish_eval_database(
table: LanceDB table to store features.
dataloader: DataLoader for the training dataset.
"""
device = next(model.parameters()).device
model.eval()
# Extract all features using the utility function
all_features = extract_batch_features(processor, model, dataloader, show_progress=True)
# Store features to database
global_idx = 0
for batch in tqdm(dataloader, desc="Building eval database"):
imgs = batch["img"]
for batch in tqdm(dataloader, desc="Storing eval database"):
labels = batch["label"]
inputs = processor(imgs, return_tensors="pt")
inputs.to(device)
outputs = model(inputs)
features = cast(torch.Tensor, outputs).cpu()
labels_list = labels.tolist()
batch_size = len(labels_list)
table.add(
[
{
"id": global_idx + j,
"label": labels_list[j],
"vector": features[j].numpy(),
"vector": all_features[global_idx + j].numpy(),
}
for j in range(batch_size)
]
@@ -101,7 +69,6 @@ def _establish_eval_database(
global_idx += batch_size
@torch.no_grad()
def _evaluate_recall(
processor: BitImageProcessorFast,
model: nn.Module,
@@ -121,25 +88,19 @@ def _evaluate_recall(
Returns:
A tuple of (correct_count, total_count).
"""
device = next(model.parameters()).device
model.eval()
# Extract all features using the utility function
all_features = extract_batch_features(processor, model, dataloader, show_progress=True)
correct = 0
total = 0
feature_idx = 0
for batch in tqdm(dataloader, desc=f"Evaluating Recall@{top_k}"):
imgs = batch["img"]
labels = batch["label"]
inputs = processor(imgs, return_tensors="pt")
inputs.to(device)
outputs = model(inputs)
features = cast(torch.Tensor, outputs).cpu()
labels_list = labels.tolist()
for j in range(len(labels_list)):
feature = features[j].tolist()
feature = all_features[feature_idx + j].tolist()
true_label = labels_list[j]
results = (
@@ -154,6 +115,8 @@ def _evaluate_recall(
correct += 1
total += 1
feature_idx += len(labels_list)
return correct, total
@@ -191,7 +154,7 @@ class RetrievalTask(BaseBenchmarkTask):
sample = train_dataset[0]
sample_image = sample["img"]
vector_dim = _infer_vector_dim(processor, model, sample_image)
vector_dim = infer_vector_dim(processor, model, sample_image)
expected_schema = _build_eval_schema(vector_dim)
# Check schema compatibility