feat(benchmark): add multi-object retrieval benchmark with SAM segmentation

This commit is contained in:
2026-03-12 12:51:04 +08:00
parent 2466ab28cd
commit b39ee74e99
7 changed files with 1058 additions and 0 deletions

View File

@@ -0,0 +1,356 @@
"""Multi-object retrieval benchmark task.
This benchmark evaluates retrieval accuracy using multiple objects from a cropped
scene region. It uses SAM for object segmentation, DINO+Hash pipeline for feature
extraction, and LanceDB for vector storage with scene-level score aggregation.
"""
import random
from typing import Any
import lancedb
import numpy as np
import pyarrow as pa
from benchmarks.base import BaseBenchmarkTask
from benchmarks.tasks.registry import RegisterTask
from configs.models import BenchmarkTaskConfig
from rich.progress import track
from torch import nn
from torch.utils.data import DataLoader
from transformers import BitImageProcessorFast
from utils.feature_extractor import extract_single_image_feature
from utils.sam import load_sam_model, segment_image
from utils.common import get_device
def _build_object_schema(vector_dim: int) -> pa.Schema:
"""Build PyArrow schema for object-level vectors.
Args:
vector_dim: Feature vector dimension.
Returns:
PyArrow schema with id, image_id, object_id, category, and vector fields.
"""
return pa.schema(
[
pa.field("id", pa.int32()),
pa.field("image_id", pa.string()),
pa.field("object_id", pa.string()),
pa.field("category", pa.string()),
pa.field("vector", pa.list_(pa.float32(), vector_dim)),
]
)
def _compute_scene_score(
query_object_ids: list[str],
retrieved_results: dict[str, list[tuple[float, str]]],
gamma: float,
) -> dict[str, float]:
"""Compute scene-level scores using co-occurrence penalty.
Args:
query_object_ids: List of query object IDs.
retrieved_results: Dict mapping image_id to list of (distance, object_id) results.
gamma: Co-occurrence penalty exponent.
Returns:
Dict mapping image_id to computed scene score.
"""
scene_scores: dict[str, float] = {}
for image_id, results in retrieved_results.items():
# Build a set of retrieved object IDs for this scene
retrieved_ids = {obj_id for _, obj_id in results}
# Count how many query objects are found in this scene
matched_count = sum(1 for q_id in query_object_ids if q_id in retrieved_ids)
if matched_count == 0:
scene_scores[image_id] = 0.0
continue
# Sum of best similarities (using distance as similarity: smaller = better)
# We use 1/(1+distance) to convert distance to similarity
similarities = []
for dist, obj_id in results:
if obj_id in query_object_ids:
sim = 1.0 / (1.0 + dist)
similarities.append(sim)
sum_similarity = sum(similarities) if similarities else 0.0
# Hit rate: ratio of matched objects
hit_rate = matched_count / len(query_object_ids)
# Final score: sum_similarity * (hit_rate)^gamma
score = sum_similarity * (hit_rate ** gamma)
scene_scores[image_id] = score
return scene_scores
@RegisterTask("multi-object-retrieval")
class MultiObjectRetrievalTask(BaseBenchmarkTask):
"""Multi-object retrieval benchmark task."""
def __init__(self, **kwargs: Any):
"""Initialize multi-object retrieval task.
Args:
**kwargs: Configuration parameters from BenchmarkTaskConfig.
"""
# Use config from kwargs or load default config
if kwargs:
config_dict = kwargs
else:
config = BenchmarkTaskConfig(type="multi-object-retrieval")
config_dict = config.model_dump()
super().__init__(**config_dict)
self.config = BenchmarkTaskConfig(**config_dict)
# SAM settings from ModelConfig (passed via kwargs or use defaults)
self.sam_model = kwargs.get("sam_model", "facebook/sam2.1-hiera-large")
self.min_mask_area = kwargs.get("sam_min_mask_area", 32 * 32)
self.max_masks_per_image = kwargs.get("sam_max_masks", 5)
# Lazy-loaded resources
self._sam_model = None
self._mask_generator = None
@property
def sam_model(self) -> Any:
"""Lazy-load SAM model."""
if self._sam_model is None:
self._sam_model, self._mask_generator = load_sam_model(
model_name=self.sam_model,
device=str(get_device()),
)
return self._sam_model
@property
def mask_generator(self) -> Any:
"""Lazy-load mask generator."""
if self._mask_generator is None:
self._sam_model, self._mask_generator = load_sam_model(
model_name=self.sam_model,
device=str(get_device()),
)
return self._mask_generator
def build_database(
self,
model: nn.Module,
processor: BitImageProcessorFast,
train_dataset: Any,
table: lancedb.table.Table,
batch_size: int,
) -> None:
"""Build the evaluation database with object-level vectors.
Args:
model: Feature extraction model.
processor: Image preprocessor.
train_dataset: Training dataset.
table: LanceDB table to store features.
batch_size: Batch size for DataLoader.
"""
# Infer vector dimension from a sample
sample = train_dataset[0]
sample_image = sample["image"]
# Get vector dimension by running a forward pass
vector_dim = self._infer_vector_dim(processor, model, sample_image)
expected_schema = _build_object_schema(vector_dim)
# Check schema compatibility
if table.schema != expected_schema:
raise ValueError(
f"Table schema mismatch. Expected: {expected_schema}, "
f"Got: {table.schema}"
)
# Build database: segment each image, extract features per object
record_id = 0
records = []
for idx in track(range(len(train_dataset)), description="Building object database"):
item = train_dataset[idx]
image = item["image"]
image_id = item.get("image_id", f"image_{idx}")
# Segment image using SAM
masks = segment_image(
self.mask_generator,
image,
min_area=self.min_mask_area,
max_masks=self.max_masks_per_image,
)
if not masks:
continue
# Extract features for each mask
for mask_idx, mask_info in enumerate(masks):
# Extract masked region
masked_image = self._apply_mask(image, mask_info["segment"])
# Extract feature vector
vector = extract_single_image_feature(processor, model, masked_image)
# Create object ID
object_id = f"{image_id}_obj_{mask_idx}"
category = mask_info.get("category", "unknown")
records.append({
"id": record_id,
"image_id": image_id,
"object_id": object_id,
"category": category,
"vector": vector,
})
record_id += 1
# Add all records to table
if records:
table.add(records)
def evaluate(
self,
model: nn.Module,
processor: BitImageProcessorFast,
test_dataset: Any,
table: lancedb.table.Table,
batch_size: int,
) -> dict[str, Any]:
"""Evaluate the model on the test dataset.
Args:
model: Feature extraction model.
processor: Image preprocessor.
test_dataset: Test dataset.
table: LanceDB table to search against.
batch_size: Batch size for DataLoader.
Returns:
Dictionary containing evaluation results with keys:
- accuracy: Recall@K accuracy (0.0 ~ 1.0)
- correct: Number of correct predictions
- total: Total number of test samples
- top_k: The K value used
"""
top_k = self.config.top_k_per_object
correct = 0
total = 0
for idx in track(range(len(test_dataset)), description=f"Evaluating Recall@{top_k}"):
item = test_dataset[idx]
image = item["image"]
target_image_id = item.get("image_id", f"image_{idx}")
# Segment query image
masks = segment_image(
self.mask_generator,
image,
min_area=self.min_mask_area,
max_masks=self.max_masks_per_image,
)
if not masks:
continue
# Randomly sample query objects
num_query = min(self.config.num_query_objects, len(masks))
query_masks = random.sample(masks, num_query)
# Extract features and search for each query object
retrieved_results: dict[str, list[tuple[float, str]]] = {}
for mask_info in query_masks:
# Extract masked region
masked_image = self._apply_mask(image, mask_info["segment"])
# Extract feature vector
vector = extract_single_image_feature(processor, model, masked_image)
# Search in LanceDB
results = (
table.search(vector)
.select(["image_id", "object_id", "_distance"])
.limit(top_k)
.to_polars()
)
# Aggregate results by scene
for row in results.iter_rows():
image_id = row["image_id"]
object_id = row["object_id"]
distance = row["_distance"]
if image_id not in retrieved_results:
retrieved_results[image_id] = []
retrieved_results[image_id].append((distance, object_id))
# Compute scene scores
query_object_ids = [m.get("object_id", f"query_obj_{i}") for i, m in enumerate(query_masks)]
scene_scores = _compute_scene_score(
query_object_ids,
retrieved_results,
self.config.gamma,
)
# Rank scenes by score
ranked_scenes = sorted(scene_scores.items(), key=lambda x: x[1], reverse=True)
# Check if target is in top-K
top_k_scenes = [scene_id for scene_id, _ in ranked_scenes[:top_k]]
if target_image_id in top_k_scenes:
correct += 1
total += 1
accuracy = correct / total if total > 0 else 0.0
return {
"accuracy": accuracy,
"correct": correct,
"total": total,
"top_k": top_k,
}
def _infer_vector_dim(
self,
processor: BitImageProcessorFast,
model: nn.Module,
sample_image: Any,
) -> int:
"""Infer vector dimension from model output."""
vector = extract_single_image_feature(processor, model, sample_image)
return len(vector)
def _apply_mask(self, image: Any, mask: np.ndarray) -> Any:
"""Apply mask to image and return masked image.
Args:
image: PIL Image.
mask: Binary mask as numpy array.
Returns:
Masked PIL Image.
"""
import numpy as np
from PIL import Image
image_np = np.array(image.convert("RGB"))
# Ensure mask is the right shape
if mask.shape != image_np.shape[:2]:
from skimage.transform import resize
mask_resized = resize(mask, image_np.shape[:2], order=0, anti_aliasing=False)
else:
mask_resized = mask
# Apply mask
masked_np = image_np * mask_resized[:, :, np.newaxis]
return Image.fromarray(masked_np.astype(np.uint8))