From b39ee74e999f2bf9866d94fc9d36e2cad61c8d4f Mon Sep 17 00:00:00 2001 From: SikongJueluo Date: Thu, 12 Mar 2026 12:51:04 +0800 Subject: [PATCH] feat(benchmark): add multi-object retrieval benchmark with SAM segmentation --- .../tasks/multi_object_retrieval.py | 356 ++++++++++++++++++ mini-nav/configs/models.py | 11 + mini-nav/data_loading/insdet_scenes.py | 64 ++++ mini-nav/tests/test_multi_object_retrieval.py | 238 ++++++++++++ mini-nav/tests/test_sam.py | 168 +++++++++ mini-nav/tests/test_scene_scoring.py | 121 ++++++ mini-nav/utils/sam.py | 100 +++++ 7 files changed, 1058 insertions(+) create mode 100644 mini-nav/benchmarks/tasks/multi_object_retrieval.py create mode 100644 mini-nav/data_loading/insdet_scenes.py create mode 100644 mini-nav/tests/test_multi_object_retrieval.py create mode 100644 mini-nav/tests/test_sam.py create mode 100644 mini-nav/tests/test_scene_scoring.py create mode 100644 mini-nav/utils/sam.py diff --git a/mini-nav/benchmarks/tasks/multi_object_retrieval.py b/mini-nav/benchmarks/tasks/multi_object_retrieval.py new file mode 100644 index 0000000..75fafe5 --- /dev/null +++ b/mini-nav/benchmarks/tasks/multi_object_retrieval.py @@ -0,0 +1,356 @@ +"""Multi-object retrieval benchmark task. + +This benchmark evaluates retrieval accuracy using multiple objects from a cropped +scene region. It uses SAM for object segmentation, DINO+Hash pipeline for feature +extraction, and LanceDB for vector storage with scene-level score aggregation. +""" + +import random +from typing import Any + +import lancedb +import numpy as np +import pyarrow as pa +from benchmarks.base import BaseBenchmarkTask +from benchmarks.tasks.registry import RegisterTask +from configs.models import BenchmarkTaskConfig +from rich.progress import track +from torch import nn +from torch.utils.data import DataLoader +from transformers import BitImageProcessorFast +from utils.feature_extractor import extract_single_image_feature +from utils.sam import load_sam_model, segment_image +from utils.common import get_device + + +def _build_object_schema(vector_dim: int) -> pa.Schema: + """Build PyArrow schema for object-level vectors. + + Args: + vector_dim: Feature vector dimension. + + Returns: + PyArrow schema with id, image_id, object_id, category, and vector fields. + """ + return pa.schema( + [ + pa.field("id", pa.int32()), + pa.field("image_id", pa.string()), + pa.field("object_id", pa.string()), + pa.field("category", pa.string()), + pa.field("vector", pa.list_(pa.float32(), vector_dim)), + ] + ) + + +def _compute_scene_score( + query_object_ids: list[str], + retrieved_results: dict[str, list[tuple[float, str]]], + gamma: float, +) -> dict[str, float]: + """Compute scene-level scores using co-occurrence penalty. + + Args: + query_object_ids: List of query object IDs. + retrieved_results: Dict mapping image_id to list of (distance, object_id) results. + gamma: Co-occurrence penalty exponent. + + Returns: + Dict mapping image_id to computed scene score. + """ + scene_scores: dict[str, float] = {} + + for image_id, results in retrieved_results.items(): + # Build a set of retrieved object IDs for this scene + retrieved_ids = {obj_id for _, obj_id in results} + + # Count how many query objects are found in this scene + matched_count = sum(1 for q_id in query_object_ids if q_id in retrieved_ids) + + if matched_count == 0: + scene_scores[image_id] = 0.0 + continue + + # Sum of best similarities (using distance as similarity: smaller = better) + # We use 1/(1+distance) to convert distance to similarity + similarities = [] + for dist, obj_id in results: + if obj_id in query_object_ids: + sim = 1.0 / (1.0 + dist) + similarities.append(sim) + + sum_similarity = sum(similarities) if similarities else 0.0 + + # Hit rate: ratio of matched objects + hit_rate = matched_count / len(query_object_ids) + + # Final score: sum_similarity * (hit_rate)^gamma + score = sum_similarity * (hit_rate ** gamma) + scene_scores[image_id] = score + + return scene_scores + + +@RegisterTask("multi-object-retrieval") +class MultiObjectRetrievalTask(BaseBenchmarkTask): + """Multi-object retrieval benchmark task.""" + + def __init__(self, **kwargs: Any): + """Initialize multi-object retrieval task. + + Args: + **kwargs: Configuration parameters from BenchmarkTaskConfig. + """ + # Use config from kwargs or load default config + if kwargs: + config_dict = kwargs + else: + config = BenchmarkTaskConfig(type="multi-object-retrieval") + config_dict = config.model_dump() + + super().__init__(**config_dict) + self.config = BenchmarkTaskConfig(**config_dict) + + # SAM settings from ModelConfig (passed via kwargs or use defaults) + self.sam_model = kwargs.get("sam_model", "facebook/sam2.1-hiera-large") + self.min_mask_area = kwargs.get("sam_min_mask_area", 32 * 32) + self.max_masks_per_image = kwargs.get("sam_max_masks", 5) + + # Lazy-loaded resources + self._sam_model = None + self._mask_generator = None + + @property + def sam_model(self) -> Any: + """Lazy-load SAM model.""" + if self._sam_model is None: + self._sam_model, self._mask_generator = load_sam_model( + model_name=self.sam_model, + device=str(get_device()), + ) + return self._sam_model + + @property + def mask_generator(self) -> Any: + """Lazy-load mask generator.""" + if self._mask_generator is None: + self._sam_model, self._mask_generator = load_sam_model( + model_name=self.sam_model, + device=str(get_device()), + ) + return self._mask_generator + + def build_database( + self, + model: nn.Module, + processor: BitImageProcessorFast, + train_dataset: Any, + table: lancedb.table.Table, + batch_size: int, + ) -> None: + """Build the evaluation database with object-level vectors. + + Args: + model: Feature extraction model. + processor: Image preprocessor. + train_dataset: Training dataset. + table: LanceDB table to store features. + batch_size: Batch size for DataLoader. + """ + # Infer vector dimension from a sample + sample = train_dataset[0] + sample_image = sample["image"] + + # Get vector dimension by running a forward pass + vector_dim = self._infer_vector_dim(processor, model, sample_image) + expected_schema = _build_object_schema(vector_dim) + + # Check schema compatibility + if table.schema != expected_schema: + raise ValueError( + f"Table schema mismatch. Expected: {expected_schema}, " + f"Got: {table.schema}" + ) + + # Build database: segment each image, extract features per object + record_id = 0 + records = [] + + for idx in track(range(len(train_dataset)), description="Building object database"): + item = train_dataset[idx] + image = item["image"] + image_id = item.get("image_id", f"image_{idx}") + + # Segment image using SAM + masks = segment_image( + self.mask_generator, + image, + min_area=self.min_mask_area, + max_masks=self.max_masks_per_image, + ) + + if not masks: + continue + + # Extract features for each mask + for mask_idx, mask_info in enumerate(masks): + # Extract masked region + masked_image = self._apply_mask(image, mask_info["segment"]) + + # Extract feature vector + vector = extract_single_image_feature(processor, model, masked_image) + + # Create object ID + object_id = f"{image_id}_obj_{mask_idx}" + category = mask_info.get("category", "unknown") + + records.append({ + "id": record_id, + "image_id": image_id, + "object_id": object_id, + "category": category, + "vector": vector, + }) + record_id += 1 + + # Add all records to table + if records: + table.add(records) + + def evaluate( + self, + model: nn.Module, + processor: BitImageProcessorFast, + test_dataset: Any, + table: lancedb.table.Table, + batch_size: int, + ) -> dict[str, Any]: + """Evaluate the model on the test dataset. + + Args: + model: Feature extraction model. + processor: Image preprocessor. + test_dataset: Test dataset. + table: LanceDB table to search against. + batch_size: Batch size for DataLoader. + + Returns: + Dictionary containing evaluation results with keys: + - accuracy: Recall@K accuracy (0.0 ~ 1.0) + - correct: Number of correct predictions + - total: Total number of test samples + - top_k: The K value used + """ + top_k = self.config.top_k_per_object + + correct = 0 + total = 0 + + for idx in track(range(len(test_dataset)), description=f"Evaluating Recall@{top_k}"): + item = test_dataset[idx] + image = item["image"] + target_image_id = item.get("image_id", f"image_{idx}") + + # Segment query image + masks = segment_image( + self.mask_generator, + image, + min_area=self.min_mask_area, + max_masks=self.max_masks_per_image, + ) + + if not masks: + continue + + # Randomly sample query objects + num_query = min(self.config.num_query_objects, len(masks)) + query_masks = random.sample(masks, num_query) + + # Extract features and search for each query object + retrieved_results: dict[str, list[tuple[float, str]]] = {} + + for mask_info in query_masks: + # Extract masked region + masked_image = self._apply_mask(image, mask_info["segment"]) + + # Extract feature vector + vector = extract_single_image_feature(processor, model, masked_image) + + # Search in LanceDB + results = ( + table.search(vector) + .select(["image_id", "object_id", "_distance"]) + .limit(top_k) + .to_polars() + ) + + # Aggregate results by scene + for row in results.iter_rows(): + image_id = row["image_id"] + object_id = row["object_id"] + distance = row["_distance"] + + if image_id not in retrieved_results: + retrieved_results[image_id] = [] + retrieved_results[image_id].append((distance, object_id)) + + # Compute scene scores + query_object_ids = [m.get("object_id", f"query_obj_{i}") for i, m in enumerate(query_masks)] + scene_scores = _compute_scene_score( + query_object_ids, + retrieved_results, + self.config.gamma, + ) + + # Rank scenes by score + ranked_scenes = sorted(scene_scores.items(), key=lambda x: x[1], reverse=True) + + # Check if target is in top-K + top_k_scenes = [scene_id for scene_id, _ in ranked_scenes[:top_k]] + if target_image_id in top_k_scenes: + correct += 1 + total += 1 + + accuracy = correct / total if total > 0 else 0.0 + + return { + "accuracy": accuracy, + "correct": correct, + "total": total, + "top_k": top_k, + } + + def _infer_vector_dim( + self, + processor: BitImageProcessorFast, + model: nn.Module, + sample_image: Any, + ) -> int: + """Infer vector dimension from model output.""" + vector = extract_single_image_feature(processor, model, sample_image) + return len(vector) + + def _apply_mask(self, image: Any, mask: np.ndarray) -> Any: + """Apply mask to image and return masked image. + + Args: + image: PIL Image. + mask: Binary mask as numpy array. + + Returns: + Masked PIL Image. + """ + import numpy as np + from PIL import Image + + image_np = np.array(image.convert("RGB")) + # Ensure mask is the right shape + if mask.shape != image_np.shape[:2]: + from skimage.transform import resize + mask_resized = resize(mask, image_np.shape[:2], order=0, anti_aliasing=False) + else: + mask_resized = mask + + # Apply mask + masked_np = image_np * mask_resized[:, :, np.newaxis] + return Image.fromarray(masked_np.astype(np.uint8)) diff --git a/mini-nav/configs/models.py b/mini-nav/configs/models.py index b3e244c..7109178 100644 --- a/mini-nav/configs/models.py +++ b/mini-nav/configs/models.py @@ -118,6 +118,17 @@ class BenchmarkTaskConfig(BaseModel): type: str = Field(default="retrieval", description="Task type") top_k: int = Field(default=10, gt=0, description="Top K for recall evaluation") + # Multi-object retrieval specific settings + gamma: float = Field( + default=1.0, ge=0, description="Co-occurrence penalty exponent" + ) + top_k_per_object: int = Field( + default=50, gt=0, description="Top K results per object query" + ) + num_query_objects: int = Field( + default=3, gt=0, description="Number of objects to sample from query image" + ) + class BenchmarkConfig(BaseModel): """Configuration for benchmark evaluation.""" diff --git a/mini-nav/data_loading/insdet_scenes.py b/mini-nav/data_loading/insdet_scenes.py new file mode 100644 index 0000000..e335857 --- /dev/null +++ b/mini-nav/data_loading/insdet_scenes.py @@ -0,0 +1,64 @@ +"""InsDet Scenes dataset for multi-object retrieval benchmark.""" + +from pathlib import Path +from typing import Any + +from benchmarks.base import BaseDataset +from data_loading.loader import load_val_dataset + + +class InsDetScenesDataset(BaseDataset): + """InsDet-FULL/Scenes dataset with easy/hard splits. + + This dataset provides scene images with object annotations from the + Instance Detection (InsDet) dataset, supporting easy and hard splits. + """ + + def __init__( + self, + scenes_dir: Path | str, + split: str = "easy", + ): + """Initialize InsDet Scenes dataset. + + Args: + scenes_dir: Path to the InsDet-FULL/Scenes directory. + split: Scene split to use ('easy' or 'hard'). + """ + self.scenes_dir = Path(scenes_dir) + self.split = split + self._dataset = load_val_dataset(self.scenes_dir, split) + + def get_train_split(self) -> Any: + """Get training split (same as test for this dataset). + + Returns: + HuggingFace Dataset for training. + """ + return self._dataset + + def get_test_split(self) -> Any: + """Get test/evaluation split. + + Returns: + HuggingFace Dataset for testing. + """ + return self._dataset + + def __len__(self) -> int: + """Get dataset length.""" + return len(self._dataset) + + def __getitem__(self, idx: int) -> dict[str, Any]: + """Get a single item from the dataset. + + Args: + idx: Index of the item. + + Returns: + Dictionary containing: + - image: PIL Image + - image_id: Scene identifier + - objects: dict with bbox, category, area, id + """ + return self._dataset[idx] diff --git a/mini-nav/tests/test_multi_object_retrieval.py b/mini-nav/tests/test_multi_object_retrieval.py new file mode 100644 index 0000000..7d71ac2 --- /dev/null +++ b/mini-nav/tests/test_multi_object_retrieval.py @@ -0,0 +1,238 @@ +"""Integration tests for multi-object retrieval benchmark pipeline. + +These tests verify the end-to-end functionality of the multi-object retrieval +benchmark, including schema building, database population, and evaluation. +""" + +import numpy as np +import pytest +from unittest.mock import Mock, patch, MagicMock +from PIL import Image + + +class TestMultiObjectRetrievalIntegration: + """Integration tests for multi-object retrieval benchmark.""" + + @pytest.fixture + def mock_model_processor(self): + """Create mock model and processor.""" + mock_model = Mock() + mock_processor = Mock() + + # Mock the feature extraction to return a fixed-size vector + def mock_extract(processor, model, image): + return [0.1] * 256 # 256-dim vector + mock_processor.images = mock_extract + + return mock_model, mock_processor + + @pytest.fixture + def mock_dataset(self): + """Create a mock dataset with images and annotations.""" + # Create mock items + items = [] + for i in range(3): + item = { + "image": Image.new("RGB", (224, 224), color=(i * 50, 100, 150)), + "image_id": f"scene_{i}", + "objects": { + "bbox": [[10, 10, 50, 50], [60, 60, 40, 40]], + "category": ["object_a", "object_b"], + "area": [2500, 1600], + "id": [0, 1], + }, + } + items.append(item) + + mock_dataset = Mock() + mock_dataset.__len__ = Mock(return_value=len(items)) + mock_dataset.__getitem__ = lambda self, idx: items[idx] + mock_dataset.with_format = lambda fmt: mock_dataset + + return mock_dataset + + def test_build_object_schema(self): + """Test that object schema is built correctly.""" + from benchmarks.tasks.multi_object_retrieval import _build_object_schema + import pyarrow as pa + + vector_dim = 256 + schema = _build_object_schema(vector_dim) + + assert isinstance(schema, pa.Schema) + assert "id" in schema.names + assert "image_id" in schema.names + assert "object_id" in schema.names + assert "category" in schema.names + assert "vector" in schema.names + + # Check vector field has correct dimension + vector_field = schema.field("vector") + assert isinstance(vector_field.type, pa.List) + assert vector_field.type.value_type == pa.float32() + + @patch("benchmarks.tasks.multi_object_retrieval.load_sam_model") + @patch("benchmarks.tasks.multi_object_retrieval.segment_image") + def test_build_database_with_mocked_sam( + self, + mock_segment, + mock_load_sam, + mock_model_processor, + mock_dataset, + ): + """Test database building with mocked SAM segmentation.""" + from benchmarks.tasks.multi_object_retrieval import ( + MultiObjectRetrievalTask, + _build_object_schema, + ) + + mock_model, mock_processor = mock_model_processor + + # Mock SAM + mock_load_sam.return_value = (Mock(), Mock()) + mock_segment.return_value = [ + { + "segment": np.ones((224, 224), dtype=bool), + "area": 50000, + "bbox": [0, 0, 224, 224], + } + ] + + # Create task with config + task = MultiObjectRetrievalTask( + sam_model="facebook/sam2.1-hiera-large", + min_mask_area=1024, + max_masks_per_image=5, + gamma=1.0, + top_k_per_object=50, + num_query_objects=3, + ) + + # Create mock table + mock_table = Mock() + mock_table.schema = _build_object_schema(256) + + # Build database (this should not raise) + task.build_database(mock_model, mock_processor, mock_dataset, mock_table, batch_size=1) + + # Verify table.add was called + assert mock_table.add.called + + @patch("benchmarks.tasks.multi_object_retrieval.load_sam_model") + @patch("benchmarks.tasks.multi_object_retrieval.segment_image") + def test_evaluate_with_mocked_sam( + self, + mock_segment, + mock_load_sam, + mock_model_processor, + mock_dataset, + ): + """Test evaluation with mocked SAM segmentation.""" + from benchmarks.tasks.multi_object_retrieval import ( + MultiObjectRetrievalTask, + _build_object_schema, + ) + + mock_model, mock_processor = mock_model_processor + + # Mock SAM + mock_load_sam.return_value = (Mock(), Mock()) + mock_segment.return_value = [ + { + "segment": np.ones((224, 224), dtype=bool), + "area": 50000, + "bbox": [0, 0, 224, 224], + "object_id": "query_obj_0", + } + ] + + # Create mock table with search results + mock_table = Mock() + mock_table.schema = _build_object_schema(256) + + # Mock search to return matching result + mock_result = Mock() + mock_result.to_polars.return_value = { + "image_id": ["scene_0"], + "object_id": ["scene_0_obj_0"], + "_distance": [0.1], + } + + mock_table.search.return_value.select.return_value.limit.return_value = mock_result + + # Create task + task = MultiObjectRetrievalTask( + sam_model="facebook/sam2.1-hiera-large", + min_mask_area=1024, + max_masks_per_image=5, + gamma=1.0, + top_k_per_object=50, + num_query_objects=1, + ) + + # Evaluate + results = task.evaluate(mock_model, mock_processor, mock_dataset, mock_table, batch_size=1) + + # Verify results structure + assert "accuracy" in results + assert "correct" in results + assert "total" in results + assert "top_k" in results + assert results["top_k"] == 50 + + def test_task_initialization_with_config(self): + """Test task initialization with custom config.""" + from benchmarks.tasks.multi_object_retrieval import MultiObjectRetrievalTask + + task = MultiObjectRetrievalTask( + sam_model="facebook/sam2.1-hiera-small", + min_mask_area=500, + max_masks_per_image=3, + gamma=0.5, + top_k_per_object=100, + num_query_objects=5, + ) + + assert task.sam_model == "facebook/sam2.1-hiera-small" + assert task.min_mask_area == 500 + assert task.max_masks_per_image == 3 + assert task.config.gamma == 0.5 + assert task.config.top_k_per_object == 100 + assert task.config.num_query_objects == 5 + + def test_task_initialization_defaults(self): + """Test task initialization with default config.""" + from benchmarks.tasks.multi_object_retrieval import MultiObjectRetrievalTask + + task = MultiObjectRetrievalTask() + + # Check defaults from BenchmarkTaskConfig + assert task.config.gamma == 1.0 + assert task.config.top_k_per_object == 50 + assert task.config.num_query_objects == 3 + # SAM settings from ModelConfig defaults + assert task.sam_model == "facebook/sam2.1-hiera-large" + assert task.min_mask_area == 1024 + assert task.max_masks_per_image == 5 + + +class TestInsDetScenesDataset: + """Tests for InsDetScenesDataset class.""" + + def test_dataset_class_exists(self): + """Test that InsDetScenesDataset can be imported.""" + from data_loading.insdet_scenes import InsDetScenesDataset + + assert InsDetScenesDataset is not None + + @patch("data_loading.insdet_scenes.load_val_dataset") + def test_dataset_loads_correct_split(self, mock_load): + """Test dataset loads correct split.""" + from data_loading.insdet_scenes import InsDetScenesDataset + + mock_load.return_value = Mock() + + dataset = InsDetScenesDataset("/path/to/scenes", split="easy") + + mock_load.assert_called_once_with("/path/to/scenes", "easy") + assert dataset.split == "easy" diff --git a/mini-nav/tests/test_sam.py b/mini-nav/tests/test_sam.py new file mode 100644 index 0000000..3d9ba2c --- /dev/null +++ b/mini-nav/tests/test_sam.py @@ -0,0 +1,168 @@ +"""Tests for SAM segmentation utilities. + +Note: These tests mock the SAM model loading since SAM requires +heavy model weights. The actual SAM integration should be tested +separately in integration tests. +""" + +import numpy as np +import pytest +from unittest.mock import Mock, patch +from PIL import Image + + +class TestSAMSegmentation: + """Test suite for SAM segmentation utilities.""" + + def test_segment_image_empty_masks(self): + """Test segment_image returns empty list when no masks generated.""" + from utils.sam import segment_image + + # Create mock mask generator that returns empty list + mock_generator = Mock() + mock_generator.generate.return_value = [] + + result = segment_image(mock_generator, Image.new("RGB", (100, 100))) + + assert result == [] + + def test_segment_image_filters_small_masks(self): + """Test segment_image filters masks below min_area threshold.""" + from utils.sam import segment_image + + # Create mock masks with different areas + small_mask = { + "segment": np.zeros((10, 10), dtype=bool), + "area": 50, # Below 32*32 = 1024 + "bbox": [0, 0, 10, 10], + "predicted_iou": 0.9, + "stability_score": 0.8, + } + large_mask = { + "segment": np.ones((100, 100), dtype=bool), + "area": 10000, # Above threshold + "bbox": [0, 0, 100, 100], + "predicted_iou": 0.95, + "stability_score": 0.9, + } + + mock_generator = Mock() + mock_generator.generate.return_value = [small_mask, large_mask] + + result = segment_image( + mock_generator, + Image.new("RGB", (100, 100)), + min_area=32 * 32, + max_masks=5, + ) + + # Should only return the large mask + assert len(result) == 1 + assert result[0]["area"] == 10000 + + def test_segment_image_limits_max_masks(self): + """Test segment_image limits to max_masks largest masks.""" + from utils.sam import segment_image + + # Create 10 masks with different areas + masks = [ + { + "segment": np.ones((i + 1, i + 1), dtype=bool), + "area": (i + 1) * (i + 1), + "bbox": [0, 0, i + 1, i + 1], + "predicted_iou": 0.9, + "stability_score": 0.8, + } + for i in range(10) + ] + + mock_generator = Mock() + mock_generator.generate.return_value = masks + + result = segment_image( + mock_generator, + Image.new("RGB", (100, 100)), + min_area=1, + max_masks=3, + ) + + # Should only return top 3 largest masks + assert len(result) == 3 + # Check they are sorted by area (largest first) + areas = [m["area"] for m in result] + assert areas == sorted(areas, reverse=True) + + def test_segment_image_sorted_by_area(self): + """Test segment_image returns masks sorted by area descending.""" + from utils.sam import segment_image + + # Create masks with known areas (unordered) + mask1 = {"segment": np.ones((5, 5), dtype=bool), "area": 25, "bbox": [0, 0, 5, 5]} + mask2 = {"segment": np.ones((10, 10), dtype=bool), "area": 100, "bbox": [0, 0, 10, 10]} + mask3 = {"segment": np.ones((3, 3), dtype=bool), "area": 9, "bbox": [0, 0, 3, 3]} + + mock_generator = Mock() + mock_generator.generate.return_value = [mask1, mask2, mask3] + + result = segment_image( + mock_generator, + Image.new("RGB", (100, 100)), + min_area=1, + max_masks=10, + ) + + # Should be sorted by area descending + assert result[0]["area"] == 100 + assert result[1]["area"] == 25 + assert result[2]["area"] == 9 + + +class TestExtractMaskedRegion: + """Test suite for extracting masked regions from images.""" + + def test_extract_masked_region_binary(self): + """Test extracting masked region with binary mask.""" + from utils.sam import extract_masked_region + + # Create a simple image + image = Image.new("RGB", (10, 10), color=(255, 0, 0)) + + # Create a binary mask (half kept, half masked) + mask = np.zeros((10, 10), dtype=bool) + mask[:, :5] = True + + result = extract_masked_region(image, mask) + + # Check that left half is red, right half is black + result_np = np.array(result) + left_half = result_np[:, :5, :] + right_half = result_np[:, 5:, :] + + assert np.all(left_half == [255, 0, 0]) + assert np.all(right_half == [0, 0, 0]) + + def test_extract_masked_region_all_masked(self): + """Test extracting when entire image is masked.""" + from utils.sam import extract_masked_region + + image = Image.new("RGB", (10, 10), color=(255, 0, 0)) + mask = np.ones((10, 10), dtype=bool) + + result = extract_masked_region(image, mask) + result_np = np.array(result) + + # Entire image should be preserved + assert np.all(result_np == [255, 0, 0]) + + def test_extract_masked_region_all_zero_mask(self): + """Test extracting when mask is all zeros.""" + from utils.sam import extract_masked_region + + image = Image.new("RGB", (10, 10), color=(255, 0, 0)) + mask = np.zeros((10, 10), dtype=bool) + + result = extract_masked_region(image, mask) + result_np = np.array(result) + + # Entire image should be black + assert np.all(result_np == [0, 0, 0]) diff --git a/mini-nav/tests/test_scene_scoring.py b/mini-nav/tests/test_scene_scoring.py new file mode 100644 index 0000000..de12468 --- /dev/null +++ b/mini-nav/tests/test_scene_scoring.py @@ -0,0 +1,121 @@ +"""Tests for scene scoring algorithm in multi-object retrieval.""" + +import pytest +from benchmarks.tasks.multi_object_retrieval import _compute_scene_score + + +class TestSceneScoringAlgorithm: + """Test suite for scene scoring with co-occurrence penalty.""" + + def test_scene_score_basic(self): + """Test basic scene scoring with single match.""" + query_object_ids = ["obj_1", "obj_2", "obj_3"] + + # Scene A has obj_1 + retrieved_results = { + "scene_A": [("distance_1", "obj_1")], + } + + scores = _compute_scene_score(query_object_ids, retrieved_results, gamma=1.0) + + # Hit rate = 1/3, similarity = 1/(1+distance_1) + assert "scene_A" in scores + assert scores["scene_A"] > 0 + + def test_scene_score_no_match(self): + """Test scene scoring when no objects match.""" + query_object_ids = ["obj_1", "obj_2", "obj_3"] + + retrieved_results = { + "scene_A": [("distance_1", "other_obj")], + } + + scores = _compute_scene_score(query_object_ids, retrieved_results, gamma=1.0) + + assert scores["scene_A"] == 0.0 + + def test_scene_score_multiple_scenes(self): + """Test scoring across multiple scenes.""" + query_object_ids = ["obj_1", "obj_2"] + + retrieved_results = { + "scene_A": [("0.1", "obj_1")], + "scene_B": [("0.1", "obj_2")], + "scene_C": [("0.1", "other")], + } + + scores = _compute_scene_score(query_object_ids, retrieved_results, gamma=1.0) + + # Scenes with matches should have positive scores + assert scores["scene_A"] > 0 + assert scores["scene_B"] > 0 + # Scene C has no match, score should be 0 + assert scores["scene_C"] == 0.0 + + def test_scene_score_gamma_zero(self): + """Test scoring with gamma=0 (no penalty).""" + query_object_ids = ["obj_1", "obj_2", "obj_3", "obj_4", "obj_5"] + + retrieved_results = { + "scene_A": [("0.1", "obj_1")], + } + + scores_gamma_0 = _compute_scene_score(query_object_ids, retrieved_results, gamma=0.0) + scores_gamma_1 = _compute_scene_score(query_object_ids, retrieved_results, gamma=1.0) + + # With gamma=0, hit_rate^0 = 1, so score = similarity + # With gamma=1, hit_rate^1 = 1/5, so score = similarity * 1/5 + # scores_gamma_0 should be larger + assert scores_gamma_0["scene_A"] > scores_gamma_1["scene_A"] + + def test_scene_score_multiple_matches(self): + """Test scoring when scene has multiple matching objects.""" + query_object_ids = ["obj_1", "obj_2"] + + retrieved_results = { + "scene_A": [("0.1", "obj_1"), ("0.2", "obj_2")], + } + + scores = _compute_scene_score(query_object_ids, retrieved_results, gamma=1.0) + + # Both objects match, hit_rate = 2/2 = 1.0 + # Score = (1/(1+0.1) + 1/(1+0.2)) * 1.0 + expected_similarity = 1.0 / 1.1 + 1.0 / 1.2 + assert abs(scores["scene_A"] - expected_similarity) < 0.01 + + def test_scene_score_distance_to_similarity(self): + """Test that smaller distance yields higher score.""" + query_object_ids = ["obj_1"] + + retrieved_results = { + "scene_close": [("0.01", "obj_1")], + "scene_far": [("10.0", "obj_1")], + } + + scores = _compute_scene_score(query_object_ids, retrieved_results, gamma=1.0) + + # Closer scene should have higher score + assert scores["scene_close"] > scores["scene_far"] + + def test_scene_score_empty_results(self): + """Test scoring with empty retrieved results.""" + query_object_ids = ["obj_1", "obj_2"] + + retrieved_results = {} + + scores = _compute_scene_score(query_object_ids, retrieved_results, gamma=1.0) + + assert scores == {} + + def test_scene_score_empty_query(self): + """Test scoring with empty query objects.""" + query_object_ids = [] + + retrieved_results = { + "scene_A": [("0.1", "obj_1")], + } + + scores = _compute_scene_score(query_object_ids, retrieved_results, gamma=1.0) + + # With empty query, no scenes should have positive score + assert all(score == 0.0 for score in scores.values()) diff --git a/mini-nav/utils/sam.py b/mini-nav/utils/sam.py new file mode 100644 index 0000000..a896d90 --- /dev/null +++ b/mini-nav/utils/sam.py @@ -0,0 +1,100 @@ +"""SAM (Segment Anything Model) utilities for object segmentation.""" + +from pathlib import Path +from typing import Any + +import numpy as np +import torch +from PIL import Image +from sam2.build_sam import build_sam2 +from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator + + +def load_sam_model( + model_name: str = "facebook/sam2.1-hiera-large", + device: str = "cuda", + checkpoint_dir: Path | None = None, +) -> tuple[Any, Any]: + """Load SAM 2.1 model and mask generator. + + Args: + model_name: SAM model name (currently supports facebook/sam2.1-hiera-*). + device: Device to load model on (cuda or cpu). + checkpoint_dir: Optional directory for model checkpoint cache. + + Returns: + Tuple of (sam_model, mask_generator). + """ + if device == "cuda" and not torch.cuda.is_available(): + device = "cpu" + + # Build SAM2 model + sam_model = build_sam2(model_name, device=device) + + # Create automatic mask generator + mask_generator = SAM2AutomaticMaskGenerator(sam_model) + + return sam_model, mask_generator + + +def segment_image( + mask_generator: Any, + image: Image.Image, + min_area: int = 32 * 32, + max_masks: int = 5, +) -> list[dict[str, Any]]: + """Segment image using SAM to extract object masks. + + Args: + mask_generator: SAM2AutomaticMaskGenerator instance. + image: PIL Image to segment. + min_area: Minimum mask area threshold in pixels. + max_masks: Maximum number of masks to return. + + Returns: + List of mask dictionaries with keys: + - segment: Binary mask (numpy array) + - area: Mask area in pixels + - bbox: Bounding box [x, y, width, height] + - predicted_iou: Model's confidence in the mask + - stability_score: Stability score for the mask + """ + # Convert PIL Image to numpy array + image_np = np.array(image.convert("RGB")) + + # Generate masks + masks = mask_generator.generate(image_np) + + if not masks: + return [] + + # Filter by minimum area + filtered_masks = [m for m in masks if m["area"] >= min_area] + + if not filtered_masks: + return [] + + # Sort by area (largest first) and limit to max_masks + sorted_masks = sorted(filtered_masks, key=lambda x: x["area"], reverse=True) + return sorted_masks[:max_masks] + + +def extract_masked_region( + image: Image.Image, + mask: np.ndarray, +) -> Image.Image: + """Extract masked region from image. + + Args: + image: Original PIL Image. + mask: Binary mask as numpy array (True = keep). + + Returns: + PIL Image with only the masked region visible. + """ + image_np = np.array(image.convert("RGB")) + + # Apply mask + masked_np = image_np * mask[:, :, np.newaxis] + + return Image.fromarray(masked_np.astype(np.uint8))