mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-13 04:45:32 +08:00
feat(benchmark): add multi-object retrieval benchmark with SAM segmentation
This commit is contained in:
356
mini-nav/benchmarks/tasks/multi_object_retrieval.py
Normal file
356
mini-nav/benchmarks/tasks/multi_object_retrieval.py
Normal file
@@ -0,0 +1,356 @@
|
||||
"""Multi-object retrieval benchmark task.
|
||||
|
||||
This benchmark evaluates retrieval accuracy using multiple objects from a cropped
|
||||
scene region. It uses SAM for object segmentation, DINO+Hash pipeline for feature
|
||||
extraction, and LanceDB for vector storage with scene-level score aggregation.
|
||||
"""
|
||||
|
||||
import random
|
||||
from typing import Any
|
||||
|
||||
import lancedb
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
from benchmarks.base import BaseBenchmarkTask
|
||||
from benchmarks.tasks.registry import RegisterTask
|
||||
from configs.models import BenchmarkTaskConfig
|
||||
from rich.progress import track
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import BitImageProcessorFast
|
||||
from utils.feature_extractor import extract_single_image_feature
|
||||
from utils.sam import load_sam_model, segment_image
|
||||
from utils.common import get_device
|
||||
|
||||
|
||||
def _build_object_schema(vector_dim: int) -> pa.Schema:
    """Build the PyArrow schema used for object-level feature records.

    Args:
        vector_dim: Dimensionality of the fixed-size feature vector field.

    Returns:
        Schema with id, image_id, object_id, category, and a fixed-size
        float32 vector field.
    """
    fields = [
        ("id", pa.int32()),
        ("image_id", pa.string()),
        ("object_id", pa.string()),
        ("category", pa.string()),
        ("vector", pa.list_(pa.float32(), vector_dim)),
    ]
    return pa.schema([pa.field(name, dtype) for name, dtype in fields])
|
||||
|
||||
|
||||
def _compute_scene_score(
|
||||
query_object_ids: list[str],
|
||||
retrieved_results: dict[str, list[tuple[float, str]]],
|
||||
gamma: float,
|
||||
) -> dict[str, float]:
|
||||
"""Compute scene-level scores using co-occurrence penalty.
|
||||
|
||||
Args:
|
||||
query_object_ids: List of query object IDs.
|
||||
retrieved_results: Dict mapping image_id to list of (distance, object_id) results.
|
||||
gamma: Co-occurrence penalty exponent.
|
||||
|
||||
Returns:
|
||||
Dict mapping image_id to computed scene score.
|
||||
"""
|
||||
scene_scores: dict[str, float] = {}
|
||||
|
||||
for image_id, results in retrieved_results.items():
|
||||
# Build a set of retrieved object IDs for this scene
|
||||
retrieved_ids = {obj_id for _, obj_id in results}
|
||||
|
||||
# Count how many query objects are found in this scene
|
||||
matched_count = sum(1 for q_id in query_object_ids if q_id in retrieved_ids)
|
||||
|
||||
if matched_count == 0:
|
||||
scene_scores[image_id] = 0.0
|
||||
continue
|
||||
|
||||
# Sum of best similarities (using distance as similarity: smaller = better)
|
||||
# We use 1/(1+distance) to convert distance to similarity
|
||||
similarities = []
|
||||
for dist, obj_id in results:
|
||||
if obj_id in query_object_ids:
|
||||
sim = 1.0 / (1.0 + dist)
|
||||
similarities.append(sim)
|
||||
|
||||
sum_similarity = sum(similarities) if similarities else 0.0
|
||||
|
||||
# Hit rate: ratio of matched objects
|
||||
hit_rate = matched_count / len(query_object_ids)
|
||||
|
||||
# Final score: sum_similarity * (hit_rate)^gamma
|
||||
score = sum_similarity * (hit_rate ** gamma)
|
||||
scene_scores[image_id] = score
|
||||
|
||||
return scene_scores
|
||||
|
||||
|
||||
@RegisterTask("multi-object-retrieval")
class MultiObjectRetrievalTask(BaseBenchmarkTask):
    """Multi-object retrieval benchmark task.

    Segments scenes into objects with SAM, stores one feature vector per
    object in LanceDB, and at evaluation time ranks candidate scenes by
    aggregating per-object search results through a co-occurrence-penalized
    scene score (see ``_compute_scene_score``).
    """

    def __init__(self, **kwargs: Any):
        """Initialize multi-object retrieval task.

        Args:
            **kwargs: Configuration parameters from BenchmarkTaskConfig.
        """
        # Use config from kwargs or load default config.
        if kwargs:
            config_dict = kwargs
        else:
            config = BenchmarkTaskConfig(type="multi-object-retrieval")
            config_dict = config.model_dump()

        super().__init__(**config_dict)
        self.config = BenchmarkTaskConfig(**config_dict)

        # SAM settings from ModelConfig (passed via kwargs or use defaults).
        # BUGFIX: stored as `sam_model_name` — `sam_model` is a read-only
        # property below, so assigning `self.sam_model = ...` raised
        # AttributeError at construction time.
        self.sam_model_name = kwargs.get("sam_model", "facebook/sam2.1-hiera-large")
        self.min_mask_area = kwargs.get("sam_min_mask_area", 32 * 32)
        self.max_masks_per_image = kwargs.get("sam_max_masks", 5)

        # Lazy-loaded resources (populated on first property access).
        self._sam_model = None
        self._mask_generator = None

    def _load_sam(self) -> None:
        """Load SAM model and mask generator once (idempotent).

        BUGFIX: the original properties passed ``model_name=self.sam_model``,
        which re-entered the property getter while ``_sam_model`` was still
        None and recursed forever; we pass the configured name instead.
        """
        if self._sam_model is None or self._mask_generator is None:
            self._sam_model, self._mask_generator = load_sam_model(
                model_name=self.sam_model_name,
                device=str(get_device()),
            )

    @property
    def sam_model(self) -> Any:
        """Lazily loaded SAM model."""
        self._load_sam()
        return self._sam_model

    @property
    def mask_generator(self) -> Any:
        """Lazily loaded SAM mask generator."""
        self._load_sam()
        return self._mask_generator

    def build_database(
        self,
        model: nn.Module,
        processor: BitImageProcessorFast,
        train_dataset: Any,
        table: lancedb.table.Table,
        batch_size: int,
    ) -> None:
        """Build the evaluation database with object-level vectors.

        Each training image is segmented with SAM; every masked object
        region is embedded and stored as one row in the LanceDB table.

        Args:
            model: Feature extraction model.
            processor: Image preprocessor.
            train_dataset: Training dataset; items expose "image" and
                optionally "image_id".
            table: LanceDB table to store features.
            batch_size: Batch size for DataLoader. NOTE(review): currently
                unused — images are processed one at a time because SAM
                segmentation is per-image.

        Raises:
            ValueError: If the table schema does not match the expected
                object-level schema.
        """
        # Infer vector dimension by running one sample through the model.
        sample_image = train_dataset[0]["image"]
        vector_dim = self._infer_vector_dim(processor, model, sample_image)
        expected_schema = _build_object_schema(vector_dim)

        # Fail fast on an incompatible table rather than inserting bad rows.
        if table.schema != expected_schema:
            raise ValueError(
                f"Table schema mismatch. Expected: {expected_schema}, "
                f"Got: {table.schema}"
            )

        record_id = 0
        records = []

        for idx in track(range(len(train_dataset)), description="Building object database"):
            item = train_dataset[idx]
            image = item["image"]
            image_id = item.get("image_id", f"image_{idx}")

            # Segment the image into candidate object masks.
            masks = segment_image(
                self.mask_generator,
                image,
                min_area=self.min_mask_area,
                max_masks=self.max_masks_per_image,
            )
            if not masks:
                continue

            # One record per segmented object.
            for mask_idx, mask_info in enumerate(masks):
                masked_image = self._apply_mask(image, mask_info["segment"])
                vector = extract_single_image_feature(processor, model, masked_image)

                records.append({
                    "id": record_id,
                    "image_id": image_id,
                    "object_id": f"{image_id}_obj_{mask_idx}",
                    "category": mask_info.get("category", "unknown"),
                    "vector": vector,
                })
                record_id += 1

        # Single bulk insert keeps LanceDB writes cheap.
        if records:
            table.add(records)

    def evaluate(
        self,
        model: nn.Module,
        processor: BitImageProcessorFast,
        test_dataset: Any,
        table: lancedb.table.Table,
        batch_size: int,
    ) -> dict[str, Any]:
        """Evaluate the model on the test dataset.

        Args:
            model: Feature extraction model.
            processor: Image preprocessor.
            test_dataset: Test dataset; items expose "image" and optionally
                "image_id".
            table: LanceDB table to search against.
            batch_size: Batch size for DataLoader. NOTE(review): currently
                unused — queries are processed one at a time.

        Returns:
            Dictionary containing evaluation results with keys:
                - accuracy: Recall@K accuracy (0.0 ~ 1.0)
                - correct: Number of correct predictions
                - total: Total number of test samples
                - top_k: The K value used
        """
        # top_k is used both as the per-object search limit and as the final
        # scene-ranking cutoff, matching the original design.
        top_k = self.config.top_k_per_object

        correct = 0
        total = 0

        for idx in track(range(len(test_dataset)), description=f"Evaluating Recall@{top_k}"):
            item = test_dataset[idx]
            image = item["image"]
            target_image_id = item.get("image_id", f"image_{idx}")

            # Segment the query image into objects.
            masks = segment_image(
                self.mask_generator,
                image,
                min_area=self.min_mask_area,
                max_masks=self.max_masks_per_image,
            )
            if not masks:
                continue

            # Randomly sample query objects from the segmented masks.
            num_query = min(self.config.num_query_objects, len(masks))
            query_masks = random.sample(masks, num_query)
            query_object_ids = [
                m.get("object_id", f"query_obj_{i}") for i, m in enumerate(query_masks)
            ]

            # Per-scene list of (distance, query_object_id) hits.
            retrieved_results: dict[str, list[tuple[float, str]]] = {}

            for q_id, mask_info in zip(query_object_ids, query_masks):
                masked_image = self._apply_mask(image, mask_info["segment"])
                vector = extract_single_image_feature(processor, model, masked_image)

                # Search the object database for this query object.
                results = (
                    table.search(vector)
                    .select(["image_id", "object_id", "_distance"])
                    .limit(top_k)
                    .to_polars()
                )

                # BUGFIX: tag each hit with the QUERY object's id, not the
                # stored object_id. Database ids follow "{image_id}_obj_{idx}"
                # while query ids default to "query_obj_{i}", so they could
                # never match and every scene scored 0.0. Tagging by query id
                # lets _compute_scene_score count which query objects a scene
                # satisfied. BUGFIX: iter_rows(named=True) — without named=True
                # polars yields tuples and row["image_id"] raises TypeError.
                for row in results.iter_rows(named=True):
                    scene_id = row["image_id"]
                    distance = row["_distance"]
                    retrieved_results.setdefault(scene_id, []).append((distance, q_id))

            # Score and rank candidate scenes.
            scene_scores = _compute_scene_score(
                query_object_ids,
                retrieved_results,
                self.config.gamma,
            )
            ranked_scenes = sorted(scene_scores.items(), key=lambda x: x[1], reverse=True)

            # Recall@K: did the ground-truth scene make the top-K?
            top_k_scenes = [scene_id for scene_id, _ in ranked_scenes[:top_k]]
            if target_image_id in top_k_scenes:
                correct += 1
            total += 1

        accuracy = correct / total if total > 0 else 0.0

        return {
            "accuracy": accuracy,
            "correct": correct,
            "total": total,
            "top_k": top_k,
        }

    def _infer_vector_dim(
        self,
        processor: BitImageProcessorFast,
        model: nn.Module,
        sample_image: Any,
    ) -> int:
        """Infer the feature vector dimension from one model forward pass."""
        vector = extract_single_image_feature(processor, model, sample_image)
        return len(vector)

    def _apply_mask(self, image: Any, mask: np.ndarray) -> Any:
        """Apply a binary mask to an image, zeroing out unmasked pixels.

        Args:
            image: PIL Image.
            mask: Binary mask as numpy array (H, W).

        Returns:
            Masked PIL Image in RGB mode.
        """
        import numpy as np
        from PIL import Image

        image_np = np.array(image.convert("RGB"))

        # Resize the mask if it does not match the image; order=0
        # (nearest-neighbor) keeps the mask binary.
        if mask.shape != image_np.shape[:2]:
            from skimage.transform import resize
            mask_resized = resize(mask, image_np.shape[:2], order=0, anti_aliasing=False)
        else:
            mask_resized = mask

        # Broadcast the (H, W) mask over the RGB channels.
        masked_np = image_np * mask_resized[:, :, np.newaxis]
        return Image.fromarray(masked_np.astype(np.uint8))
|
||||
Reference in New Issue
Block a user