feat(benchmark): add multi-object retrieval benchmark with SAM segmentation

This commit is contained in:
2026-03-12 12:51:04 +08:00
parent 2466ab28cd
commit b39ee74e99
7 changed files with 1058 additions and 0 deletions

View File

@@ -0,0 +1,356 @@
"""Multi-object retrieval benchmark task.
This benchmark evaluates retrieval accuracy using multiple objects from a cropped
scene region. It uses SAM for object segmentation, DINO+Hash pipeline for feature
extraction, and LanceDB for vector storage with scene-level score aggregation.
"""
import random
from typing import Any
import lancedb
import numpy as np
import pyarrow as pa
from benchmarks.base import BaseBenchmarkTask
from benchmarks.tasks.registry import RegisterTask
from configs.models import BenchmarkTaskConfig
from rich.progress import track
from torch import nn
from torch.utils.data import DataLoader
from transformers import BitImageProcessorFast
from utils.feature_extractor import extract_single_image_feature
from utils.sam import load_sam_model, segment_image
from utils.common import get_device
def _build_object_schema(vector_dim: int) -> pa.Schema:
    """Construct the PyArrow schema used for object-level vector records.

    Args:
        vector_dim: Dimensionality of the feature vectors.

    Returns:
        Schema with id, image_id, object_id, category, and a fixed-size
        float32 vector field.
    """
    columns = [
        ("id", pa.int32()),
        ("image_id", pa.string()),
        ("object_id", pa.string()),
        ("category", pa.string()),
        ("vector", pa.list_(pa.float32(), vector_dim)),
    ]
    return pa.schema([pa.field(name, dtype) for name, dtype in columns])
def _compute_scene_score(
query_object_ids: list[str],
retrieved_results: dict[str, list[tuple[float, str]]],
gamma: float,
) -> dict[str, float]:
"""Compute scene-level scores using co-occurrence penalty.
Args:
query_object_ids: List of query object IDs.
retrieved_results: Dict mapping image_id to list of (distance, object_id) results.
gamma: Co-occurrence penalty exponent.
Returns:
Dict mapping image_id to computed scene score.
"""
scene_scores: dict[str, float] = {}
for image_id, results in retrieved_results.items():
# Build a set of retrieved object IDs for this scene
retrieved_ids = {obj_id for _, obj_id in results}
# Count how many query objects are found in this scene
matched_count = sum(1 for q_id in query_object_ids if q_id in retrieved_ids)
if matched_count == 0:
scene_scores[image_id] = 0.0
continue
# Sum of best similarities (using distance as similarity: smaller = better)
# We use 1/(1+distance) to convert distance to similarity
similarities = []
for dist, obj_id in results:
if obj_id in query_object_ids:
sim = 1.0 / (1.0 + dist)
similarities.append(sim)
sum_similarity = sum(similarities) if similarities else 0.0
# Hit rate: ratio of matched objects
hit_rate = matched_count / len(query_object_ids)
# Final score: sum_similarity * (hit_rate)^gamma
score = sum_similarity * (hit_rate ** gamma)
scene_scores[image_id] = score
return scene_scores
@RegisterTask("multi-object-retrieval")
class MultiObjectRetrievalTask(BaseBenchmarkTask):
    """Multi-object retrieval benchmark task.

    Segments each scene into object masks with SAM, stores one feature
    vector per object in LanceDB, and evaluates retrieval by aggregating
    per-object search results into scene-level scores
    (see ``_compute_scene_score``).
    """

    def __init__(self, **kwargs: Any):
        """Initialize multi-object retrieval task.

        Args:
            **kwargs: Configuration parameters from BenchmarkTaskConfig, plus
                optional SAM settings: ``sam_model``, ``min_mask_area``
                (alias ``sam_min_mask_area``) and ``max_masks_per_image``
                (alias ``sam_max_masks``).
        """
        # Use config from kwargs or load the default config.
        if kwargs:
            config_dict = kwargs
        else:
            config = BenchmarkTaskConfig(type="multi-object-retrieval")
            config_dict = config.model_dump()
        super().__init__(**config_dict)
        self.config = BenchmarkTaskConfig(**config_dict)
        # SAM settings: accept both the plain keys (used by callers/tests)
        # and the "sam_"-prefixed keys from ModelConfig.
        self.sam_model = kwargs.get("sam_model", "facebook/sam2.1-hiera-large")
        self.min_mask_area = kwargs.get(
            "min_mask_area", kwargs.get("sam_min_mask_area", 32 * 32)
        )
        self.max_masks_per_image = kwargs.get(
            "max_masks_per_image", kwargs.get("sam_max_masks", 5)
        )
        # Lazy-loaded SAM resources (populated on first property access).
        self._sam_model = None
        self._mask_generator = None

    def _ensure_sam_loaded(self) -> None:
        """Load the SAM model and mask generator exactly once."""
        if self._mask_generator is None:
            self._sam_model, self._mask_generator = load_sam_model(
                model_name=self.sam_model,
                device=str(get_device()),
            )

    @property
    def sam(self) -> Any:
        """Lazily loaded SAM model.

        The model *name* is the plain string attribute ``self.sam_model``;
        exposing the loaded model under the same name would shadow that
        attribute (breaking ``__init__``) and recurse, so the loaded model
        is exposed as ``sam`` instead.
        """
        self._ensure_sam_loaded()
        return self._sam_model

    @property
    def mask_generator(self) -> Any:
        """Lazily loaded SAM automatic mask generator."""
        self._ensure_sam_loaded()
        return self._mask_generator

    def build_database(
        self,
        model: nn.Module,
        processor: BitImageProcessorFast,
        train_dataset: Any,
        table: lancedb.table.Table,
        batch_size: int,
    ) -> None:
        """Build the evaluation database with object-level vectors.

        Each training image is segmented with SAM and one feature vector per
        object mask is stored in ``table``.

        Args:
            model: Feature extraction model.
            processor: Image preprocessor.
            train_dataset: Training dataset (indexable, items are dicts with
                at least an "image" key).
            table: LanceDB table to store features.
            batch_size: Batch size for DataLoader (currently unused; kept
                for interface compatibility).

        Raises:
            ValueError: If the table schema does not match the schema
                expected for the inferred vector dimension.
        """
        # Infer the vector dimension from one forward pass on a sample.
        sample = train_dataset[0]
        vector_dim = self._infer_vector_dim(processor, model, sample["image"])
        expected_schema = _build_object_schema(vector_dim)
        if table.schema != expected_schema:
            raise ValueError(
                f"Table schema mismatch. Expected: {expected_schema}, "
                f"Got: {table.schema}"
            )
        # Build database: segment each image, extract features per object.
        record_id = 0
        records = []
        for idx in track(range(len(train_dataset)), description="Building object database"):
            item = train_dataset[idx]
            image = item["image"]
            image_id = item.get("image_id", f"image_{idx}")
            masks = segment_image(
                self.mask_generator,
                image,
                min_area=self.min_mask_area,
                max_masks=self.max_masks_per_image,
            )
            if not masks:
                continue
            # One record per segmented object.
            for mask_idx, mask_info in enumerate(masks):
                masked_image = self._apply_mask(image, mask_info["segment"])
                vector = extract_single_image_feature(processor, model, masked_image)
                records.append({
                    "id": record_id,
                    "image_id": image_id,
                    "object_id": f"{image_id}_obj_{mask_idx}",
                    "category": mask_info.get("category", "unknown"),
                    "vector": vector,
                })
                record_id += 1
        if records:
            table.add(records)

    def evaluate(
        self,
        model: nn.Module,
        processor: BitImageProcessorFast,
        test_dataset: Any,
        table: lancedb.table.Table,
        batch_size: int,
    ) -> dict[str, Any]:
        """Evaluate scene-level Recall@K on the test dataset.

        For each test image: segment it, sample query objects, search each
        object vector in ``table``, aggregate hits into scene scores, and
        count a success when the source scene ranks in the top K.

        Args:
            model: Feature extraction model.
            processor: Image preprocessor.
            test_dataset: Test dataset.
            table: LanceDB table to search against.
            batch_size: Batch size for DataLoader (currently unused; kept
                for interface compatibility).

        Returns:
            Dictionary containing evaluation results with keys:
                - accuracy: Recall@K accuracy (0.0 ~ 1.0)
                - correct: Number of correct predictions
                - total: Total number of test samples
                - top_k: The K value used
        """
        top_k = self.config.top_k_per_object
        correct = 0
        total = 0
        for idx in track(range(len(test_dataset)), description=f"Evaluating Recall@{top_k}"):
            item = test_dataset[idx]
            image = item["image"]
            target_image_id = item.get("image_id", f"image_{idx}")
            masks = segment_image(
                self.mask_generator,
                image,
                min_area=self.min_mask_area,
                max_masks=self.max_masks_per_image,
            )
            if not masks:
                continue
            # Randomly sample query objects from the segmented masks.
            num_query = min(self.config.num_query_objects, len(masks))
            query_masks = random.sample(masks, num_query)
            # Extract features and search for each query object.
            retrieved_results: dict[str, list[tuple[float, str]]] = {}
            for mask_info in query_masks:
                masked_image = self._apply_mask(image, mask_info["segment"])
                vector = extract_single_image_feature(processor, model, masked_image)
                results = (
                    table.search(vector)
                    .select(["image_id", "object_id", "_distance"])
                    .limit(top_k)
                    .to_polars()
                )
                # named=True yields dict rows; the default yields plain
                # tuples, which cannot be indexed by column name.
                for row in results.iter_rows(named=True):
                    retrieved_results.setdefault(row["image_id"], []).append(
                        (row["_distance"], row["object_id"])
                    )
            # Aggregate per-object hits into scene-level scores.
            query_object_ids = [
                m.get("object_id", f"query_obj_{i}") for i, m in enumerate(query_masks)
            ]
            scene_scores = _compute_scene_score(
                query_object_ids,
                retrieved_results,
                self.config.gamma,
            )
            # Rank scenes by score and check if the target is in the top K.
            ranked_scenes = sorted(scene_scores.items(), key=lambda x: x[1], reverse=True)
            top_k_scenes = [scene_id for scene_id, _ in ranked_scenes[:top_k]]
            if target_image_id in top_k_scenes:
                correct += 1
            total += 1
        accuracy = correct / total if total > 0 else 0.0
        return {
            "accuracy": accuracy,
            "correct": correct,
            "total": total,
            "top_k": top_k,
        }

    def _infer_vector_dim(
        self,
        processor: BitImageProcessorFast,
        model: nn.Module,
        sample_image: Any,
    ) -> int:
        """Infer the feature vector dimension from one model forward pass."""
        vector = extract_single_image_feature(processor, model, sample_image)
        return len(vector)

    def _apply_mask(self, image: Any, mask: np.ndarray) -> Any:
        """Black out pixels outside ``mask`` and return the masked image.

        Args:
            image: PIL Image.
            mask: Binary mask as numpy array (True/1 = keep pixel).

        Returns:
            Masked PIL Image (RGB).
        """
        from PIL import Image

        image_np = np.array(image.convert("RGB"))
        # Resize the mask with nearest-neighbor (keeps it binary) when its
        # shape does not match the image; avoids the skimage dependency.
        if mask.shape != image_np.shape[:2]:
            height, width = image_np.shape[:2]
            mask_img = Image.fromarray(mask.astype(np.uint8) * 255)
            mask_resized = np.array(mask_img.resize((width, height), Image.NEAREST)) > 0
        else:
            mask_resized = mask
        masked_np = image_np * mask_resized[:, :, np.newaxis]
        return Image.fromarray(masked_np.astype(np.uint8))

View File

@@ -118,6 +118,17 @@ class BenchmarkTaskConfig(BaseModel):
type: str = Field(default="retrieval", description="Task type")
top_k: int = Field(default=10, gt=0, description="Top K for recall evaluation")
# Multi-object retrieval specific settings
# gamma: exponent applied to the query-object hit rate when aggregating
# per-object similarities into a scene score; 0 disables the penalty.
gamma: float = Field(
    default=1.0, ge=0, description="Co-occurrence penalty exponent"
)
# top_k_per_object: nearest neighbours retrieved per query-object vector.
top_k_per_object: int = Field(
    default=50, gt=0, description="Top K results per object query"
)
# num_query_objects: how many segmented objects to sample from a query image.
num_query_objects: int = Field(
    default=3, gt=0, description="Number of objects to sample from query image"
)
class BenchmarkConfig(BaseModel):
"""Configuration for benchmark evaluation."""

View File

@@ -0,0 +1,64 @@
"""InsDet Scenes dataset for multi-object retrieval benchmark."""
from pathlib import Path
from typing import Any
from benchmarks.base import BaseDataset
from data_loading.loader import load_val_dataset
class InsDetScenesDataset(BaseDataset):
    """InsDet-FULL/Scenes dataset with easy/hard splits.

    Wraps the Instance Detection (InsDet) scene images and their object
    annotations. The same underlying split serves as both the train and the
    test split.
    """

    def __init__(
        self,
        scenes_dir: Path | str,
        split: str = "easy",
    ):
        """Initialize InsDet Scenes dataset.

        Args:
            scenes_dir: Path to the InsDet-FULL/Scenes directory.
            split: Scene split to use ('easy' or 'hard').
        """
        self.scenes_dir = Path(scenes_dir)
        self.split = split
        self._dataset = load_val_dataset(self.scenes_dir, split)

    def get_train_split(self) -> Any:
        """Return the training split (identical to the test split here)."""
        return self._dataset

    def get_test_split(self) -> Any:
        """Return the test/evaluation split."""
        return self._dataset

    def __len__(self) -> int:
        """Return the number of scenes."""
        return len(self._dataset)

    def __getitem__(self, idx: int) -> dict[str, Any]:
        """Return one scene record.

        Args:
            idx: Index of the item.

        Returns:
            Dict with ``image`` (PIL Image), ``image_id`` (scene identifier)
            and ``objects`` (dict with bbox, category, area, id).
        """
        return self._dataset[idx]

View File

@@ -0,0 +1,238 @@
"""Integration tests for multi-object retrieval benchmark pipeline.
These tests verify the end-to-end functionality of the multi-object retrieval
benchmark, including schema building, database population, and evaluation.
"""
import numpy as np
import pytest
from unittest.mock import Mock, patch, MagicMock
from PIL import Image
class TestMultiObjectRetrievalIntegration:
    """Integration tests for multi-object retrieval benchmark."""

    @pytest.fixture
    def mock_model_processor(self):
        """Create mock model and processor.

        Feature extraction itself is patched in the tests, so plain Mocks
        are sufficient here.
        """
        return Mock(), Mock()

    @pytest.fixture
    def mock_dataset(self):
        """Create a mock dataset with images and annotations."""
        items = [
            {
                "image": Image.new("RGB", (224, 224), color=(i * 50, 100, 150)),
                "image_id": f"scene_{i}",
                "objects": {
                    "bbox": [[10, 10, 50, 50], [60, 60, 40, 40]],
                    "category": ["object_a", "object_b"],
                    "area": [2500, 1600],
                    "id": [0, 1],
                },
            }
            for i in range(3)
        ]
        # MagicMock (not Mock) so the dunder protocols work with len() and
        # indexing; instance-assigned dunders on a plain Mock are ignored.
        mock_dataset = MagicMock()
        mock_dataset.__len__.return_value = len(items)
        mock_dataset.__getitem__.side_effect = lambda idx: items[idx]
        mock_dataset.with_format = lambda fmt: mock_dataset
        return mock_dataset

    def test_build_object_schema(self):
        """Test that object schema is built correctly."""
        from benchmarks.tasks.multi_object_retrieval import _build_object_schema
        import pyarrow as pa

        vector_dim = 256
        schema = _build_object_schema(vector_dim)
        assert isinstance(schema, pa.Schema)
        for name in ("id", "image_id", "object_id", "category", "vector"):
            assert name in schema.names
        # pa.list_ with an explicit size produces a fixed-size list type.
        vector_field = schema.field("vector")
        assert isinstance(vector_field.type, pa.FixedSizeListType)
        assert vector_field.type.value_type == pa.float32()
        assert vector_field.type.list_size == vector_dim

    @patch("benchmarks.tasks.multi_object_retrieval.load_sam_model")
    @patch("benchmarks.tasks.multi_object_retrieval.segment_image")
    @patch("benchmarks.tasks.multi_object_retrieval.extract_single_image_feature")
    def test_build_database_with_mocked_sam(
        self,
        mock_extract,
        mock_segment,
        mock_load_sam,
        mock_model_processor,
        mock_dataset,
    ):
        """Test database building with mocked SAM segmentation."""
        from benchmarks.tasks.multi_object_retrieval import (
            MultiObjectRetrievalTask,
            _build_object_schema,
        )

        mock_model, mock_processor = mock_model_processor
        # Mock SAM and the feature extractor (fixed 256-dim vector).
        mock_load_sam.return_value = (Mock(), Mock())
        mock_segment.return_value = [
            {
                "segment": np.ones((224, 224), dtype=bool),
                "area": 50000,
                "bbox": [0, 0, 224, 224],
            }
        ]
        mock_extract.return_value = [0.1] * 256
        # Create task with config
        task = MultiObjectRetrievalTask(
            sam_model="facebook/sam2.1-hiera-large",
            min_mask_area=1024,
            max_masks_per_image=5,
            gamma=1.0,
            top_k_per_object=50,
            num_query_objects=3,
        )
        # Create mock table with the matching schema.
        mock_table = Mock()
        mock_table.schema = _build_object_schema(256)
        # Build database (this should not raise)
        task.build_database(mock_model, mock_processor, mock_dataset, mock_table, batch_size=1)
        # Verify table.add was called
        assert mock_table.add.called

    @patch("benchmarks.tasks.multi_object_retrieval.load_sam_model")
    @patch("benchmarks.tasks.multi_object_retrieval.segment_image")
    @patch("benchmarks.tasks.multi_object_retrieval.extract_single_image_feature")
    def test_evaluate_with_mocked_sam(
        self,
        mock_extract,
        mock_segment,
        mock_load_sam,
        mock_model_processor,
        mock_dataset,
    ):
        """Test evaluation with mocked SAM segmentation."""
        from benchmarks.tasks.multi_object_retrieval import (
            MultiObjectRetrievalTask,
            _build_object_schema,
        )

        mock_model, mock_processor = mock_model_processor
        # Mock SAM and the feature extractor.
        mock_load_sam.return_value = (Mock(), Mock())
        mock_segment.return_value = [
            {
                "segment": np.ones((224, 224), dtype=bool),
                "area": 50000,
                "bbox": [0, 0, 224, 224],
                "object_id": "query_obj_0",
            }
        ]
        mock_extract.return_value = [0.1] * 256
        # Create mock table with search results
        mock_table = Mock()
        mock_table.schema = _build_object_schema(256)
        # Fake polars frame: iter_rows(...) yields dict rows, matching how
        # evaluate() consumes the search results.
        fake_frame = Mock()
        fake_frame.iter_rows.return_value = [
            {"image_id": "scene_0", "object_id": "scene_0_obj_0", "_distance": 0.1}
        ]
        mock_result = Mock()
        mock_result.to_polars.return_value = fake_frame
        mock_table.search.return_value.select.return_value.limit.return_value = mock_result
        # Create task
        task = MultiObjectRetrievalTask(
            sam_model="facebook/sam2.1-hiera-large",
            min_mask_area=1024,
            max_masks_per_image=5,
            gamma=1.0,
            top_k_per_object=50,
            num_query_objects=1,
        )
        # Evaluate
        results = task.evaluate(mock_model, mock_processor, mock_dataset, mock_table, batch_size=1)
        # Verify results structure
        assert "accuracy" in results
        assert "correct" in results
        assert "total" in results
        assert "top_k" in results
        assert results["top_k"] == 50

    def test_task_initialization_with_config(self):
        """Test task initialization with custom config."""
        from benchmarks.tasks.multi_object_retrieval import MultiObjectRetrievalTask

        task = MultiObjectRetrievalTask(
            sam_model="facebook/sam2.1-hiera-small",
            min_mask_area=500,
            max_masks_per_image=3,
            gamma=0.5,
            top_k_per_object=100,
            num_query_objects=5,
        )
        assert task.sam_model == "facebook/sam2.1-hiera-small"
        assert task.min_mask_area == 500
        assert task.max_masks_per_image == 3
        assert task.config.gamma == 0.5
        assert task.config.top_k_per_object == 100
        assert task.config.num_query_objects == 5

    def test_task_initialization_defaults(self):
        """Test task initialization with default config."""
        from benchmarks.tasks.multi_object_retrieval import MultiObjectRetrievalTask

        task = MultiObjectRetrievalTask()
        # Check defaults from BenchmarkTaskConfig
        assert task.config.gamma == 1.0
        assert task.config.top_k_per_object == 50
        assert task.config.num_query_objects == 3
        # SAM settings from ModelConfig defaults
        assert task.sam_model == "facebook/sam2.1-hiera-large"
        assert task.min_mask_area == 1024
        assert task.max_masks_per_image == 5
class TestInsDetScenesDataset:
    """Tests for InsDetScenesDataset class."""

    def test_dataset_class_exists(self):
        """Test that InsDetScenesDataset can be imported."""
        from data_loading.insdet_scenes import InsDetScenesDataset

        assert InsDetScenesDataset is not None

    @patch("data_loading.insdet_scenes.load_val_dataset")
    def test_dataset_loads_correct_split(self, mock_load):
        """Test dataset loads correct split."""
        from pathlib import Path

        from data_loading.insdet_scenes import InsDetScenesDataset

        mock_load.return_value = Mock()
        dataset = InsDetScenesDataset("/path/to/scenes", split="easy")
        # The dataset converts scenes_dir to a Path before delegating, so the
        # call must be asserted with a Path, not the raw string.
        mock_load.assert_called_once_with(Path("/path/to/scenes"), "easy")
        assert dataset.split == "easy"

168
mini-nav/tests/test_sam.py Normal file
View File

@@ -0,0 +1,168 @@
"""Tests for SAM segmentation utilities.
Note: These tests mock the SAM model loading since SAM requires
heavy model weights. The actual SAM integration should be tested
separately in integration tests.
"""
import numpy as np
import pytest
from unittest.mock import Mock, patch
from PIL import Image
class TestSAMSegmentation:
    """Test suite for SAM segmentation utilities."""

    @staticmethod
    def _generator_returning(masks):
        """Build a mock mask generator whose ``generate`` yields *masks*."""
        generator = Mock()
        generator.generate.return_value = masks
        return generator

    def test_segment_image_empty_masks(self):
        """segment_image returns an empty list when SAM produces no masks."""
        from utils.sam import segment_image

        generator = self._generator_returning([])
        assert segment_image(generator, Image.new("RGB", (100, 100))) == []

    def test_segment_image_filters_small_masks(self):
        """Masks below the min_area threshold are dropped."""
        from utils.sam import segment_image

        below_threshold = {
            "segment": np.zeros((10, 10), dtype=bool),
            "area": 50,  # below 32*32 = 1024
            "bbox": [0, 0, 10, 10],
            "predicted_iou": 0.9,
            "stability_score": 0.8,
        }
        above_threshold = {
            "segment": np.ones((100, 100), dtype=bool),
            "area": 10000,  # above threshold
            "bbox": [0, 0, 100, 100],
            "predicted_iou": 0.95,
            "stability_score": 0.9,
        }
        generator = self._generator_returning([below_threshold, above_threshold])
        kept = segment_image(
            generator,
            Image.new("RGB", (100, 100)),
            min_area=32 * 32,
            max_masks=5,
        )
        # Only the large mask survives the filter.
        assert len(kept) == 1
        assert kept[0]["area"] == 10000

    def test_segment_image_limits_max_masks(self):
        """Only the max_masks largest masks are returned."""
        from utils.sam import segment_image

        candidates = [
            {
                "segment": np.ones((side, side), dtype=bool),
                "area": side * side,
                "bbox": [0, 0, side, side],
                "predicted_iou": 0.9,
                "stability_score": 0.8,
            }
            for side in range(1, 11)
        ]
        generator = self._generator_returning(candidates)
        kept = segment_image(
            generator,
            Image.new("RGB", (100, 100)),
            min_area=1,
            max_masks=3,
        )
        # Top 3 largest masks, sorted by area descending.
        assert len(kept) == 3
        areas = [m["area"] for m in kept]
        assert areas == sorted(areas, reverse=True)

    def test_segment_image_sorted_by_area(self):
        """Masks come back sorted by area, largest first."""
        from utils.sam import segment_image

        unordered = [
            {"segment": np.ones((5, 5), dtype=bool), "area": 25, "bbox": [0, 0, 5, 5]},
            {"segment": np.ones((10, 10), dtype=bool), "area": 100, "bbox": [0, 0, 10, 10]},
            {"segment": np.ones((3, 3), dtype=bool), "area": 9, "bbox": [0, 0, 3, 3]},
        ]
        generator = self._generator_returning(unordered)
        kept = segment_image(
            generator,
            Image.new("RGB", (100, 100)),
            min_area=1,
            max_masks=10,
        )
        assert [m["area"] for m in kept] == [100, 25, 9]
class TestExtractMaskedRegion:
    """Test suite for extracting masked regions from images."""

    # Solid fill colour used by every test image.
    RED = (255, 0, 0)

    def test_extract_masked_region_binary(self):
        """A half-image mask keeps one half and blacks out the other."""
        from utils.sam import extract_masked_region

        source = Image.new("RGB", (10, 10), color=self.RED)
        keep_left = np.zeros((10, 10), dtype=bool)
        keep_left[:, :5] = True
        output = np.array(extract_masked_region(source, keep_left))
        # Left half keeps the red fill; right half is zeroed.
        assert np.all(output[:, :5, :] == [255, 0, 0])
        assert np.all(output[:, 5:, :] == [0, 0, 0])

    def test_extract_masked_region_all_masked(self):
        """An all-True mask leaves the image untouched."""
        from utils.sam import extract_masked_region

        source = Image.new("RGB", (10, 10), color=self.RED)
        output = np.array(extract_masked_region(source, np.ones((10, 10), dtype=bool)))
        assert np.all(output == [255, 0, 0])

    def test_extract_masked_region_all_zero_mask(self):
        """An all-False mask blacks out the whole image."""
        from utils.sam import extract_masked_region

        source = Image.new("RGB", (10, 10), color=self.RED)
        output = np.array(extract_masked_region(source, np.zeros((10, 10), dtype=bool)))
        assert np.all(output == [0, 0, 0])

View File

@@ -0,0 +1,121 @@
"""Tests for scene scoring algorithm in multi-object retrieval."""
import pytest
from benchmarks.tasks.multi_object_retrieval import _compute_scene_score
class TestSceneScoringAlgorithm:
    """Test suite for scene scoring with co-occurrence penalty.

    Distances are numeric floats throughout: the scorer computes
    1/(1+distance), so string distances would raise a TypeError.
    """

    def test_scene_score_basic(self):
        """Test basic scene scoring with single match."""
        query_object_ids = ["obj_1", "obj_2", "obj_3"]
        # Scene A has obj_1 at distance 0.5.
        retrieved_results = {
            "scene_A": [(0.5, "obj_1")],
        }
        scores = _compute_scene_score(query_object_ids, retrieved_results, gamma=1.0)
        # Hit rate = 1/3, similarity = 1/(1+0.5)
        assert "scene_A" in scores
        assert scores["scene_A"] > 0

    def test_scene_score_no_match(self):
        """Test scene scoring when no objects match."""
        query_object_ids = ["obj_1", "obj_2", "obj_3"]
        retrieved_results = {
            "scene_A": [(0.5, "other_obj")],
        }
        scores = _compute_scene_score(query_object_ids, retrieved_results, gamma=1.0)
        assert scores["scene_A"] == 0.0

    def test_scene_score_multiple_scenes(self):
        """Test scoring across multiple scenes."""
        query_object_ids = ["obj_1", "obj_2"]
        retrieved_results = {
            "scene_A": [(0.1, "obj_1")],
            "scene_B": [(0.1, "obj_2")],
            "scene_C": [(0.1, "other")],
        }
        scores = _compute_scene_score(query_object_ids, retrieved_results, gamma=1.0)
        # Scenes with matches should have positive scores
        assert scores["scene_A"] > 0
        assert scores["scene_B"] > 0
        # Scene C has no match, score should be 0
        assert scores["scene_C"] == 0.0

    def test_scene_score_gamma_zero(self):
        """Test scoring with gamma=0 (no penalty)."""
        query_object_ids = ["obj_1", "obj_2", "obj_3", "obj_4", "obj_5"]
        retrieved_results = {
            "scene_A": [(0.1, "obj_1")],
        }
        scores_gamma_0 = _compute_scene_score(query_object_ids, retrieved_results, gamma=0.0)
        scores_gamma_1 = _compute_scene_score(query_object_ids, retrieved_results, gamma=1.0)
        # With gamma=0, hit_rate^0 = 1, so score = similarity
        # With gamma=1, hit_rate^1 = 1/5, so score = similarity * 1/5
        # scores_gamma_0 should be larger
        assert scores_gamma_0["scene_A"] > scores_gamma_1["scene_A"]

    def test_scene_score_multiple_matches(self):
        """Test scoring when scene has multiple matching objects."""
        query_object_ids = ["obj_1", "obj_2"]
        retrieved_results = {
            "scene_A": [(0.1, "obj_1"), (0.2, "obj_2")],
        }
        scores = _compute_scene_score(query_object_ids, retrieved_results, gamma=1.0)
        # Both objects match, hit_rate = 2/2 = 1.0
        # Score = (1/(1+0.1) + 1/(1+0.2)) * 1.0
        expected_similarity = 1.0 / 1.1 + 1.0 / 1.2
        assert abs(scores["scene_A"] - expected_similarity) < 0.01

    def test_scene_score_distance_to_similarity(self):
        """Test that smaller distance yields higher score."""
        query_object_ids = ["obj_1"]
        retrieved_results = {
            "scene_close": [(0.01, "obj_1")],
            "scene_far": [(10.0, "obj_1")],
        }
        scores = _compute_scene_score(query_object_ids, retrieved_results, gamma=1.0)
        # Closer scene should have higher score
        assert scores["scene_close"] > scores["scene_far"]

    def test_scene_score_empty_results(self):
        """Test scoring with empty retrieved results."""
        query_object_ids = ["obj_1", "obj_2"]
        retrieved_results = {}
        scores = _compute_scene_score(query_object_ids, retrieved_results, gamma=1.0)
        assert scores == {}

    def test_scene_score_empty_query(self):
        """Test scoring with empty query objects."""
        query_object_ids = []
        retrieved_results = {
            "scene_A": [(0.1, "obj_1")],
        }
        scores = _compute_scene_score(query_object_ids, retrieved_results, gamma=1.0)
        # With empty query, no scenes should have positive score
        assert all(score == 0.0 for score in scores.values())

100
mini-nav/utils/sam.py Normal file
View File

@@ -0,0 +1,100 @@
"""SAM (Segment Anything Model) utilities for object segmentation."""
from pathlib import Path
from typing import Any
import numpy as np
import torch
from PIL import Image
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
def load_sam_model(
    model_name: str = "facebook/sam2.1-hiera-large",
    device: str = "cuda",
    checkpoint_dir: Path | None = None,
) -> tuple[Any, Any]:
    """Load a SAM 2.1 model and an automatic mask generator.

    Args:
        model_name: Hugging Face model id (e.g. ``facebook/sam2.1-hiera-large``).
        device: Device to load model on (cuda or cpu); silently falls back to
            CPU when CUDA is unavailable.
        checkpoint_dir: Optional directory for model checkpoint cache.
            NOTE(review): currently unused; kept for interface compatibility.

    Returns:
        Tuple of (sam_model, mask_generator).
    """
    # ``build_sam2`` expects a local hydra config path; Hugging Face model
    # ids such as "facebook/sam2.1-hiera-large" must go through
    # ``build_sam2_hf``, which resolves config and checkpoint from the hub.
    from sam2.build_sam import build_sam2_hf

    if device == "cuda" and not torch.cuda.is_available():
        device = "cpu"
    sam_model = build_sam2_hf(model_name, device=device)
    # Create automatic mask generator over the loaded model.
    mask_generator = SAM2AutomaticMaskGenerator(sam_model)
    return sam_model, mask_generator
def segment_image(
    mask_generator: Any,
    image: Image.Image,
    min_area: int = 32 * 32,
    max_masks: int = 5,
) -> list[dict[str, Any]]:
    """Segment an image with SAM and keep the largest object masks.

    Args:
        mask_generator: SAM2AutomaticMaskGenerator instance.
        image: PIL Image to segment.
        min_area: Minimum mask area threshold in pixels.
        max_masks: Maximum number of masks to return.

    Returns:
        Up to ``max_masks`` mask dicts sorted by area (largest first), each
        with keys: segment (binary numpy mask), area (pixels), bbox
        [x, y, width, height], predicted_iou, stability_score.
    """
    # SAM consumes an RGB numpy array.
    rgb = np.array(image.convert("RGB"))
    candidates = mask_generator.generate(rgb)
    # Drop masks smaller than the area threshold.
    large_enough = [m for m in candidates if m["area"] >= min_area]
    if not large_enough:
        return []
    # Keep only the biggest masks, largest first.
    large_enough.sort(key=lambda m: m["area"], reverse=True)
    return large_enough[:max_masks]
def extract_masked_region(
    image: Image.Image,
    mask: np.ndarray,
) -> Image.Image:
    """Black out everything outside *mask* and return the result.

    Args:
        image: Original PIL Image.
        mask: Binary mask as numpy array (True = keep).

    Returns:
        PIL Image where unselected pixels are zeroed (black).
    """
    pixels = np.array(image.convert("RGB"))
    # Broadcasting the (H, W) mask over the channel axis zeroes the
    # unselected pixels in all three channels at once.
    kept = pixels * mask[..., None]
    return Image.fromarray(kept.astype(np.uint8))