feat(benchmark): add multi-object retrieval benchmark with SAM segmentation

This commit is contained in:
2026-03-12 12:51:04 +08:00
parent 2466ab28cd
commit b39ee74e99
7 changed files with 1058 additions and 0 deletions

View File

@@ -0,0 +1,238 @@
"""Integration tests for multi-object retrieval benchmark pipeline.
These tests verify the end-to-end functionality of the multi-object retrieval
benchmark, including schema building, database population, and evaluation.
"""
import numpy as np
import pytest
from unittest.mock import Mock, patch, MagicMock
from PIL import Image
class TestMultiObjectRetrievalIntegration:
"""Integration tests for multi-object retrieval benchmark."""
@pytest.fixture
def mock_model_processor(self):
"""Create mock model and processor."""
mock_model = Mock()
mock_processor = Mock()
# Mock the feature extraction to return a fixed-size vector
def mock_extract(processor, model, image):
return [0.1] * 256 # 256-dim vector
mock_processor.images = mock_extract
return mock_model, mock_processor
@pytest.fixture
def mock_dataset(self):
"""Create a mock dataset with images and annotations."""
# Create mock items
items = []
for i in range(3):
item = {
"image": Image.new("RGB", (224, 224), color=(i * 50, 100, 150)),
"image_id": f"scene_{i}",
"objects": {
"bbox": [[10, 10, 50, 50], [60, 60, 40, 40]],
"category": ["object_a", "object_b"],
"area": [2500, 1600],
"id": [0, 1],
},
}
items.append(item)
mock_dataset = Mock()
mock_dataset.__len__ = Mock(return_value=len(items))
mock_dataset.__getitem__ = lambda self, idx: items[idx]
mock_dataset.with_format = lambda fmt: mock_dataset
return mock_dataset
def test_build_object_schema(self):
"""Test that object schema is built correctly."""
from benchmarks.tasks.multi_object_retrieval import _build_object_schema
import pyarrow as pa
vector_dim = 256
schema = _build_object_schema(vector_dim)
assert isinstance(schema, pa.Schema)
assert "id" in schema.names
assert "image_id" in schema.names
assert "object_id" in schema.names
assert "category" in schema.names
assert "vector" in schema.names
# Check vector field has correct dimension
vector_field = schema.field("vector")
assert isinstance(vector_field.type, pa.List)
assert vector_field.type.value_type == pa.float32()
@patch("benchmarks.tasks.multi_object_retrieval.load_sam_model")
@patch("benchmarks.tasks.multi_object_retrieval.segment_image")
def test_build_database_with_mocked_sam(
self,
mock_segment,
mock_load_sam,
mock_model_processor,
mock_dataset,
):
"""Test database building with mocked SAM segmentation."""
from benchmarks.tasks.multi_object_retrieval import (
MultiObjectRetrievalTask,
_build_object_schema,
)
mock_model, mock_processor = mock_model_processor
# Mock SAM
mock_load_sam.return_value = (Mock(), Mock())
mock_segment.return_value = [
{
"segment": np.ones((224, 224), dtype=bool),
"area": 50000,
"bbox": [0, 0, 224, 224],
}
]
# Create task with config
task = MultiObjectRetrievalTask(
sam_model="facebook/sam2.1-hiera-large",
min_mask_area=1024,
max_masks_per_image=5,
gamma=1.0,
top_k_per_object=50,
num_query_objects=3,
)
# Create mock table
mock_table = Mock()
mock_table.schema = _build_object_schema(256)
# Build database (this should not raise)
task.build_database(mock_model, mock_processor, mock_dataset, mock_table, batch_size=1)
# Verify table.add was called
assert mock_table.add.called
@patch("benchmarks.tasks.multi_object_retrieval.load_sam_model")
@patch("benchmarks.tasks.multi_object_retrieval.segment_image")
def test_evaluate_with_mocked_sam(
self,
mock_segment,
mock_load_sam,
mock_model_processor,
mock_dataset,
):
"""Test evaluation with mocked SAM segmentation."""
from benchmarks.tasks.multi_object_retrieval import (
MultiObjectRetrievalTask,
_build_object_schema,
)
mock_model, mock_processor = mock_model_processor
# Mock SAM
mock_load_sam.return_value = (Mock(), Mock())
mock_segment.return_value = [
{
"segment": np.ones((224, 224), dtype=bool),
"area": 50000,
"bbox": [0, 0, 224, 224],
"object_id": "query_obj_0",
}
]
# Create mock table with search results
mock_table = Mock()
mock_table.schema = _build_object_schema(256)
# Mock search to return matching result
mock_result = Mock()
mock_result.to_polars.return_value = {
"image_id": ["scene_0"],
"object_id": ["scene_0_obj_0"],
"_distance": [0.1],
}
mock_table.search.return_value.select.return_value.limit.return_value = mock_result
# Create task
task = MultiObjectRetrievalTask(
sam_model="facebook/sam2.1-hiera-large",
min_mask_area=1024,
max_masks_per_image=5,
gamma=1.0,
top_k_per_object=50,
num_query_objects=1,
)
# Evaluate
results = task.evaluate(mock_model, mock_processor, mock_dataset, mock_table, batch_size=1)
# Verify results structure
assert "accuracy" in results
assert "correct" in results
assert "total" in results
assert "top_k" in results
assert results["top_k"] == 50
def test_task_initialization_with_config(self):
"""Test task initialization with custom config."""
from benchmarks.tasks.multi_object_retrieval import MultiObjectRetrievalTask
task = MultiObjectRetrievalTask(
sam_model="facebook/sam2.1-hiera-small",
min_mask_area=500,
max_masks_per_image=3,
gamma=0.5,
top_k_per_object=100,
num_query_objects=5,
)
assert task.sam_model == "facebook/sam2.1-hiera-small"
assert task.min_mask_area == 500
assert task.max_masks_per_image == 3
assert task.config.gamma == 0.5
assert task.config.top_k_per_object == 100
assert task.config.num_query_objects == 5
def test_task_initialization_defaults(self):
"""Test task initialization with default config."""
from benchmarks.tasks.multi_object_retrieval import MultiObjectRetrievalTask
task = MultiObjectRetrievalTask()
# Check defaults from BenchmarkTaskConfig
assert task.config.gamma == 1.0
assert task.config.top_k_per_object == 50
assert task.config.num_query_objects == 3
# SAM settings from ModelConfig defaults
assert task.sam_model == "facebook/sam2.1-hiera-large"
assert task.min_mask_area == 1024
assert task.max_masks_per_image == 5
class TestInsDetScenesDataset:
"""Tests for InsDetScenesDataset class."""
def test_dataset_class_exists(self):
"""Test that InsDetScenesDataset can be imported."""
from data_loading.insdet_scenes import InsDetScenesDataset
assert InsDetScenesDataset is not None
@patch("data_loading.insdet_scenes.load_val_dataset")
def test_dataset_loads_correct_split(self, mock_load):
"""Test dataset loads correct split."""
from data_loading.insdet_scenes import InsDetScenesDataset
mock_load.return_value = Mock()
dataset = InsDetScenesDataset("/path/to/scenes", split="easy")
mock_load.assert_called_once_with("/path/to/scenes", "easy")
assert dataset.split == "easy"