mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 20:35:31 +08:00
feat(utils): add feature extraction utilities and tests
This commit is contained in:
@@ -1,10 +1,9 @@
|
||||
"""Retrieval task for benchmark evaluation (Recall@K)."""
|
||||
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
|
||||
import lancedb
|
||||
import pyarrow as pa
|
||||
import torch
|
||||
from benchmarks.base import BaseBenchmarkTask
|
||||
from benchmarks.tasks.registry import RegisterTask
|
||||
from torch import nn
|
||||
@@ -12,31 +11,7 @@ from torch.utils.data import DataLoader
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import BitImageProcessorFast
|
||||
|
||||
|
||||
def _infer_vector_dim(
|
||||
processor: BitImageProcessorFast,
|
||||
model: nn.Module,
|
||||
sample_image: Any,
|
||||
) -> int:
|
||||
"""Infer model output vector dimension via a single forward pass.
|
||||
|
||||
Args:
|
||||
processor: Image preprocessor.
|
||||
model: Feature extraction model.
|
||||
sample_image: A sample image for dimension inference.
|
||||
|
||||
Returns:
|
||||
Vector dimension.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
inputs = processor(images=sample_image, return_tensors="pt")
|
||||
inputs.to(device)
|
||||
output = model(inputs)
|
||||
|
||||
return output.shape[-1]
|
||||
from utils.feature_extractor import extract_batch_features, infer_vector_dim
|
||||
|
||||
|
||||
def _build_eval_schema(vector_dim: int) -> pa.Schema:
|
||||
@@ -57,7 +32,6 @@ def _build_eval_schema(vector_dim: int) -> pa.Schema:
|
||||
)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def _establish_eval_database(
|
||||
processor: BitImageProcessorFast,
|
||||
model: nn.Module,
|
||||
@@ -72,28 +46,22 @@ def _establish_eval_database(
|
||||
table: LanceDB table to store features.
|
||||
dataloader: DataLoader for the training dataset.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
model.eval()
|
||||
# Extract all features using the utility function
|
||||
all_features = extract_batch_features(processor, model, dataloader, show_progress=True)
|
||||
|
||||
# Store features to database
|
||||
global_idx = 0
|
||||
for batch in tqdm(dataloader, desc="Building eval database"):
|
||||
imgs = batch["img"]
|
||||
for batch in tqdm(dataloader, desc="Storing eval database"):
|
||||
labels = batch["label"]
|
||||
|
||||
inputs = processor(imgs, return_tensors="pt")
|
||||
inputs.to(device)
|
||||
outputs = model(inputs)
|
||||
|
||||
features = cast(torch.Tensor, outputs).cpu()
|
||||
labels_list = labels.tolist()
|
||||
|
||||
batch_size = len(labels_list)
|
||||
|
||||
table.add(
|
||||
[
|
||||
{
|
||||
"id": global_idx + j,
|
||||
"label": labels_list[j],
|
||||
"vector": features[j].numpy(),
|
||||
"vector": all_features[global_idx + j].numpy(),
|
||||
}
|
||||
for j in range(batch_size)
|
||||
]
|
||||
@@ -101,7 +69,6 @@ def _establish_eval_database(
|
||||
global_idx += batch_size
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def _evaluate_recall(
|
||||
processor: BitImageProcessorFast,
|
||||
model: nn.Module,
|
||||
@@ -121,25 +88,19 @@ def _evaluate_recall(
|
||||
Returns:
|
||||
A tuple of (correct_count, total_count).
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
model.eval()
|
||||
# Extract all features using the utility function
|
||||
all_features = extract_batch_features(processor, model, dataloader, show_progress=True)
|
||||
|
||||
correct = 0
|
||||
total = 0
|
||||
feature_idx = 0
|
||||
|
||||
for batch in tqdm(dataloader, desc=f"Evaluating Recall@{top_k}"):
|
||||
imgs = batch["img"]
|
||||
labels = batch["label"]
|
||||
|
||||
inputs = processor(imgs, return_tensors="pt")
|
||||
inputs.to(device)
|
||||
outputs = model(inputs)
|
||||
|
||||
features = cast(torch.Tensor, outputs).cpu()
|
||||
labels_list = labels.tolist()
|
||||
|
||||
for j in range(len(labels_list)):
|
||||
feature = features[j].tolist()
|
||||
feature = all_features[feature_idx + j].tolist()
|
||||
true_label = labels_list[j]
|
||||
|
||||
results = (
|
||||
@@ -154,6 +115,8 @@ def _evaluate_recall(
|
||||
correct += 1
|
||||
total += 1
|
||||
|
||||
feature_idx += len(labels_list)
|
||||
|
||||
return correct, total
|
||||
|
||||
|
||||
@@ -191,7 +154,7 @@ class RetrievalTask(BaseBenchmarkTask):
|
||||
sample = train_dataset[0]
|
||||
sample_image = sample["img"]
|
||||
|
||||
vector_dim = _infer_vector_dim(processor, model, sample_image)
|
||||
vector_dim = infer_vector_dim(processor, model, sample_image)
|
||||
expected_schema = _build_eval_schema(vector_dim)
|
||||
|
||||
# Check schema compatibility
|
||||
|
||||
Reference in New Issue
Block a user