mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-13 04:45:32 +08:00
feat(benchmark): add multi-object retrieval benchmark with SAM segmentation
This commit is contained in:
238
mini-nav/tests/test_multi_object_retrieval.py
Normal file
238
mini-nav/tests/test_multi_object_retrieval.py
Normal file
@@ -0,0 +1,238 @@
|
||||
"""Integration tests for multi-object retrieval benchmark pipeline.
|
||||
|
||||
These tests verify the end-to-end functionality of the multi-object retrieval
|
||||
benchmark, including schema building, database population, and evaluation.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class TestMultiObjectRetrievalIntegration:
    """Integration tests for multi-object retrieval benchmark.

    These tests exercise schema building, database population, and
    evaluation end-to-end, with the SAM segmenter and the feature
    extractor replaced by mocks.
    """

    @pytest.fixture
    def mock_model_processor(self):
        """Create mock model and processor.

        The processor's ``images`` hook is stubbed to return a fixed
        256-dim feature vector for any input image.
        """
        mock_model = Mock()
        mock_processor = Mock()

        # Mock the feature extraction to return a fixed-size vector.
        def mock_extract(processor, model, image):
            return [0.1] * 256  # 256-dim vector

        mock_processor.images = mock_extract

        return mock_model, mock_processor

    @pytest.fixture
    def mock_dataset(self):
        """Create a mock dataset with images and annotations.

        Yields three synthetic scenes, each with two annotated objects.
        """
        items = []
        for i in range(3):
            item = {
                "image": Image.new("RGB", (224, 224), color=(i * 50, 100, 150)),
                "image_id": f"scene_{i}",
                "objects": {
                    "bbox": [[10, 10, 50, 50], [60, 60, 40, 40]],
                    "category": ["object_a", "object_b"],
                    "area": [2500, 1600],
                    "id": [0, 1],
                },
            }
            items.append(item)

        # Use MagicMock: Python looks special methods up on the *type*,
        # so assigning __len__/__getitem__ on a plain Mock instance would
        # not make len()/indexing work. MagicMock pre-configures dunders;
        # configure them via return_value / side_effect.
        mock_dataset = MagicMock()
        mock_dataset.__len__.return_value = len(items)
        mock_dataset.__getitem__.side_effect = lambda idx: items[idx]
        mock_dataset.with_format.return_value = mock_dataset

        return mock_dataset

    def test_build_object_schema(self):
        """Test that object schema is built correctly."""
        from benchmarks.tasks.multi_object_retrieval import _build_object_schema
        import pyarrow as pa

        vector_dim = 256
        schema = _build_object_schema(vector_dim)

        assert isinstance(schema, pa.Schema)
        assert "id" in schema.names
        assert "image_id" in schema.names
        assert "object_id" in schema.names
        assert "category" in schema.names
        assert "vector" in schema.names

        # Check the vector field is a list of float32.
        # NOTE: pyarrow exposes the list type class as pa.ListType — there
        # is no pa.List attribute, so the original isinstance check would
        # raise AttributeError; pa.types.is_list is the supported check.
        vector_field = schema.field("vector")
        assert pa.types.is_list(vector_field.type)
        assert vector_field.type.value_type == pa.float32()

    @patch("benchmarks.tasks.multi_object_retrieval.load_sam_model")
    @patch("benchmarks.tasks.multi_object_retrieval.segment_image")
    def test_build_database_with_mocked_sam(
        self,
        mock_segment,
        mock_load_sam,
        mock_model_processor,
        mock_dataset,
    ):
        """Test database building with mocked SAM segmentation."""
        from benchmarks.tasks.multi_object_retrieval import (
            MultiObjectRetrievalTask,
            _build_object_schema,
        )

        mock_model, mock_processor = mock_model_processor

        # Mock SAM: one full-frame mask per image.
        mock_load_sam.return_value = (Mock(), Mock())
        mock_segment.return_value = [
            {
                "segment": np.ones((224, 224), dtype=bool),
                "area": 50000,
                "bbox": [0, 0, 224, 224],
            }
        ]

        # Create task with config.
        task = MultiObjectRetrievalTask(
            sam_model="facebook/sam2.1-hiera-large",
            min_mask_area=1024,
            max_masks_per_image=5,
            gamma=1.0,
            top_k_per_object=50,
            num_query_objects=3,
        )

        # Create mock table.
        mock_table = Mock()
        mock_table.schema = _build_object_schema(256)

        # Build database (this should not raise).
        task.build_database(mock_model, mock_processor, mock_dataset, mock_table, batch_size=1)

        # Verify table.add was called.
        assert mock_table.add.called

    @patch("benchmarks.tasks.multi_object_retrieval.load_sam_model")
    @patch("benchmarks.tasks.multi_object_retrieval.segment_image")
    def test_evaluate_with_mocked_sam(
        self,
        mock_segment,
        mock_load_sam,
        mock_model_processor,
        mock_dataset,
    ):
        """Test evaluation with mocked SAM segmentation."""
        from benchmarks.tasks.multi_object_retrieval import (
            MultiObjectRetrievalTask,
            _build_object_schema,
        )

        mock_model, mock_processor = mock_model_processor

        # Mock SAM: one full-frame query mask.
        mock_load_sam.return_value = (Mock(), Mock())
        mock_segment.return_value = [
            {
                "segment": np.ones((224, 224), dtype=bool),
                "area": 50000,
                "bbox": [0, 0, 224, 224],
                "object_id": "query_obj_0",
            }
        ]

        # Create mock table with search results.
        mock_table = Mock()
        mock_table.schema = _build_object_schema(256)

        # Mock search to return a matching result.
        mock_result = Mock()
        mock_result.to_polars.return_value = {
            "image_id": ["scene_0"],
            "object_id": ["scene_0_obj_0"],
            "_distance": [0.1],
        }

        mock_table.search.return_value.select.return_value.limit.return_value = mock_result

        # Create task.
        task = MultiObjectRetrievalTask(
            sam_model="facebook/sam2.1-hiera-large",
            min_mask_area=1024,
            max_masks_per_image=5,
            gamma=1.0,
            top_k_per_object=50,
            num_query_objects=1,
        )

        # Evaluate.
        results = task.evaluate(mock_model, mock_processor, mock_dataset, mock_table, batch_size=1)

        # Verify results structure.
        assert "accuracy" in results
        assert "correct" in results
        assert "total" in results
        assert "top_k" in results
        assert results["top_k"] == 50

    def test_task_initialization_with_config(self):
        """Test task initialization with custom config."""
        from benchmarks.tasks.multi_object_retrieval import MultiObjectRetrievalTask

        task = MultiObjectRetrievalTask(
            sam_model="facebook/sam2.1-hiera-small",
            min_mask_area=500,
            max_masks_per_image=3,
            gamma=0.5,
            top_k_per_object=100,
            num_query_objects=5,
        )

        assert task.sam_model == "facebook/sam2.1-hiera-small"
        assert task.min_mask_area == 500
        assert task.max_masks_per_image == 3
        assert task.config.gamma == 0.5
        assert task.config.top_k_per_object == 100
        assert task.config.num_query_objects == 5

    def test_task_initialization_defaults(self):
        """Test task initialization with default config."""
        from benchmarks.tasks.multi_object_retrieval import MultiObjectRetrievalTask

        task = MultiObjectRetrievalTask()

        # Check defaults from BenchmarkTaskConfig.
        assert task.config.gamma == 1.0
        assert task.config.top_k_per_object == 50
        assert task.config.num_query_objects == 3
        # SAM settings from ModelConfig defaults.
        assert task.sam_model == "facebook/sam2.1-hiera-large"
        assert task.min_mask_area == 1024
        assert task.max_masks_per_image == 5
|
||||
|
||||
|
||||
class TestInsDetScenesDataset:
    """Tests for InsDetScenesDataset class."""

    def test_dataset_class_exists(self):
        """Test that InsDetScenesDataset can be imported."""
        from data_loading.insdet_scenes import InsDetScenesDataset

        assert InsDetScenesDataset is not None

    @patch("data_loading.insdet_scenes.load_val_dataset")
    def test_dataset_loads_correct_split(self, mock_load):
        """Test dataset loads correct split."""
        from data_loading.insdet_scenes import InsDetScenesDataset

        mock_load.return_value = Mock()

        ds = InsDetScenesDataset("/path/to/scenes", split="easy")

        # The loader must be invoked exactly once with path and split,
        # and the split must be recorded on the instance.
        mock_load.assert_called_once_with("/path/to/scenes", "easy")
        assert ds.split == "easy"
|
||||
168
mini-nav/tests/test_sam.py
Normal file
168
mini-nav/tests/test_sam.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""Tests for SAM segmentation utilities.
|
||||
|
||||
Note: These tests mock the SAM model loading since SAM requires
|
||||
heavy model weights. The actual SAM integration should be tested
|
||||
separately in integration tests.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class TestSAMSegmentation:
    """Test suite for SAM segmentation utilities."""

    def test_segment_image_empty_masks(self):
        """Test segment_image returns empty list when no masks generated."""
        from utils.sam import segment_image

        # A generator that produces no candidate masks at all.
        generator = Mock()
        generator.generate.return_value = []

        masks = segment_image(generator, Image.new("RGB", (100, 100)))

        assert masks == []

    def test_segment_image_filters_small_masks(self):
        """Test segment_image filters masks below min_area threshold."""
        from utils.sam import segment_image

        # One mask below the 32*32 = 1024 area threshold, one well above it.
        tiny = {
            "segment": np.zeros((10, 10), dtype=bool),
            "area": 50,  # Below 32*32 = 1024
            "bbox": [0, 0, 10, 10],
            "predicted_iou": 0.9,
            "stability_score": 0.8,
        }
        big = {
            "segment": np.ones((100, 100), dtype=bool),
            "area": 10000,  # Above threshold
            "bbox": [0, 0, 100, 100],
            "predicted_iou": 0.95,
            "stability_score": 0.9,
        }

        generator = Mock()
        generator.generate.return_value = [tiny, big]

        kept = segment_image(
            generator,
            Image.new("RGB", (100, 100)),
            min_area=32 * 32,
            max_masks=5,
        )

        # Only the large mask survives the area filter.
        assert len(kept) == 1
        assert kept[0]["area"] == 10000

    def test_segment_image_limits_max_masks(self):
        """Test segment_image limits to max_masks largest masks."""
        from utils.sam import segment_image

        # Ten square masks with strictly increasing areas (1..100 px).
        candidates = [
            {
                "segment": np.ones((size, size), dtype=bool),
                "area": size * size,
                "bbox": [0, 0, size, size],
                "predicted_iou": 0.9,
                "stability_score": 0.8,
            }
            for size in range(1, 11)
        ]

        generator = Mock()
        generator.generate.return_value = candidates

        kept = segment_image(
            generator,
            Image.new("RGB", (100, 100)),
            min_area=1,
            max_masks=3,
        )

        # Only the three largest masks remain, sorted largest first.
        assert len(kept) == 3
        areas = [m["area"] for m in kept]
        assert areas == sorted(areas, reverse=True)

    def test_segment_image_sorted_by_area(self):
        """Test segment_image returns masks sorted by area descending."""
        from utils.sam import segment_image

        # Three masks supplied out of order by area.
        medium = {"segment": np.ones((5, 5), dtype=bool), "area": 25, "bbox": [0, 0, 5, 5]}
        largest = {"segment": np.ones((10, 10), dtype=bool), "area": 100, "bbox": [0, 0, 10, 10]}
        smallest = {"segment": np.ones((3, 3), dtype=bool), "area": 9, "bbox": [0, 0, 3, 3]}

        generator = Mock()
        generator.generate.return_value = [medium, largest, smallest]

        kept = segment_image(
            generator,
            Image.new("RGB", (100, 100)),
            min_area=1,
            max_masks=10,
        )

        # Output must be ordered by area, largest first.
        assert [m["area"] for m in kept] == [100, 25, 9]
|
||||
|
||||
|
||||
class TestExtractMaskedRegion:
    """Test suite for extracting masked regions from images."""

    def test_extract_masked_region_binary(self):
        """Test extracting masked region with binary mask."""
        from utils.sam import extract_masked_region

        # Solid red image; the mask keeps only the left half.
        image = Image.new("RGB", (10, 10), color=(255, 0, 0))
        mask = np.zeros((10, 10), dtype=bool)
        mask[:, :5] = True

        extracted = np.array(extract_masked_region(image, mask))

        # Kept pixels stay red; masked-out pixels become black.
        assert np.all(extracted[:, :5, :] == [255, 0, 0])
        assert np.all(extracted[:, 5:, :] == [0, 0, 0])

    def test_extract_masked_region_all_masked(self):
        """Test extracting when entire image is masked."""
        from utils.sam import extract_masked_region

        image = Image.new("RGB", (10, 10), color=(255, 0, 0))
        full_mask = np.ones((10, 10), dtype=bool)

        extracted = np.array(extract_masked_region(image, full_mask))

        # A full mask preserves every pixel.
        assert np.all(extracted == [255, 0, 0])

    def test_extract_masked_region_all_zero_mask(self):
        """Test extracting when mask is all zeros."""
        from utils.sam import extract_masked_region

        image = Image.new("RGB", (10, 10), color=(255, 0, 0))
        empty_mask = np.zeros((10, 10), dtype=bool)

        extracted = np.array(extract_masked_region(image, empty_mask))

        # An all-zero mask blanks the whole image to black.
        assert np.all(extracted == [0, 0, 0])
|
||||
121
mini-nav/tests/test_scene_scoring.py
Normal file
121
mini-nav/tests/test_scene_scoring.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""Tests for scene scoring algorithm in multi-object retrieval."""
|
||||
|
||||
import pytest
|
||||
from benchmarks.tasks.multi_object_retrieval import _compute_scene_score
|
||||
|
||||
|
||||
class TestSceneScoringAlgorithm:
    """Test suite for scene scoring with co-occurrence penalty."""

    # NOTE(review): distances in these fixtures are strings (e.g. "0.1",
    # "distance_1") as in the original tests — this assumes
    # _compute_scene_score coerces them; confirm against the implementation.

    def test_scene_score_basic(self):
        """Test basic scene scoring with single match."""
        query_ids = ["obj_1", "obj_2", "obj_3"]

        # Scene A contains a hit for obj_1 only.
        retrieved = {
            "scene_A": [("distance_1", "obj_1")],
        }

        scores = _compute_scene_score(query_ids, retrieved, gamma=1.0)

        # Hit rate = 1/3, similarity = 1/(1+distance_1)
        assert "scene_A" in scores
        assert scores["scene_A"] > 0

    def test_scene_score_no_match(self):
        """Test scene scoring when no objects match."""
        query_ids = ["obj_1", "obj_2", "obj_3"]

        # The retrieved object is not among the query objects.
        retrieved = {
            "scene_A": [("distance_1", "other_obj")],
        }

        scores = _compute_scene_score(query_ids, retrieved, gamma=1.0)

        assert scores["scene_A"] == 0.0

    def test_scene_score_multiple_scenes(self):
        """Test scoring across multiple scenes."""
        query_ids = ["obj_1", "obj_2"]

        retrieved = {
            "scene_A": [("0.1", "obj_1")],
            "scene_B": [("0.1", "obj_2")],
            "scene_C": [("0.1", "other")],
        }

        scores = _compute_scene_score(query_ids, retrieved, gamma=1.0)

        # Scenes with at least one matching object score positive;
        # scene C has no match, so it must score zero.
        assert scores["scene_A"] > 0
        assert scores["scene_B"] > 0
        assert scores["scene_C"] == 0.0

    def test_scene_score_gamma_zero(self):
        """Test scoring with gamma=0 (no penalty)."""
        query_ids = ["obj_1", "obj_2", "obj_3", "obj_4", "obj_5"]

        retrieved = {
            "scene_A": [("0.1", "obj_1")],
        }

        no_penalty = _compute_scene_score(query_ids, retrieved, gamma=0.0)
        full_penalty = _compute_scene_score(query_ids, retrieved, gamma=1.0)

        # With gamma=0, hit_rate^0 = 1, so score = similarity.
        # With gamma=1, hit_rate^1 = 1/5, so score = similarity * 1/5.
        # The unpenalized score must therefore be larger.
        assert no_penalty["scene_A"] > full_penalty["scene_A"]

    def test_scene_score_multiple_matches(self):
        """Test scoring when scene has multiple matching objects."""
        query_ids = ["obj_1", "obj_2"]

        retrieved = {
            "scene_A": [("0.1", "obj_1"), ("0.2", "obj_2")],
        }

        scores = _compute_scene_score(query_ids, retrieved, gamma=1.0)

        # Both objects match, so hit_rate = 2/2 = 1.0 and
        # score = (1/(1+0.1) + 1/(1+0.2)) * 1.0.
        expected = 1.0 / 1.1 + 1.0 / 1.2
        assert abs(scores["scene_A"] - expected) < 0.01

    def test_scene_score_distance_to_similarity(self):
        """Test that smaller distance yields higher score."""
        query_ids = ["obj_1"]

        retrieved = {
            "scene_close": [("0.01", "obj_1")],
            "scene_far": [("10.0", "obj_1")],
        }

        scores = _compute_scene_score(query_ids, retrieved, gamma=1.0)

        # The closer match must outscore the distant one.
        assert scores["scene_close"] > scores["scene_far"]

    def test_scene_score_empty_results(self):
        """Test scoring with empty retrieved results."""
        query_ids = ["obj_1", "obj_2"]

        scores = _compute_scene_score(query_ids, {}, gamma=1.0)

        assert scores == {}

    def test_scene_score_empty_query(self):
        """Test scoring with empty query objects."""
        retrieved = {
            "scene_A": [("0.1", "obj_1")],
        }

        scores = _compute_scene_score([], retrieved, gamma=1.0)

        # With no query objects, no scene can accumulate a positive score.
        assert all(score == 0.0 for score in scores.values())
|
||||
Reference in New Issue
Block a user