mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-13 04:45:32 +08:00
feat(benchmark): add multi-object retrieval benchmark with SAM segmentation
This commit is contained in:
356
mini-nav/benchmarks/tasks/multi_object_retrieval.py
Normal file
356
mini-nav/benchmarks/tasks/multi_object_retrieval.py
Normal file
@@ -0,0 +1,356 @@
|
|||||||
|
"""Multi-object retrieval benchmark task.
|
||||||
|
|
||||||
|
This benchmark evaluates retrieval accuracy using multiple objects from a cropped
|
||||||
|
scene region. It uses SAM for object segmentation, DINO+Hash pipeline for feature
|
||||||
|
extraction, and LanceDB for vector storage with scene-level score aggregation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import random
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import lancedb
|
||||||
|
import numpy as np
|
||||||
|
import pyarrow as pa
|
||||||
|
from benchmarks.base import BaseBenchmarkTask
|
||||||
|
from benchmarks.tasks.registry import RegisterTask
|
||||||
|
from configs.models import BenchmarkTaskConfig
|
||||||
|
from rich.progress import track
|
||||||
|
from torch import nn
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from transformers import BitImageProcessorFast
|
||||||
|
from utils.feature_extractor import extract_single_image_feature
|
||||||
|
from utils.sam import load_sam_model, segment_image
|
||||||
|
from utils.common import get_device
|
||||||
|
|
||||||
|
|
||||||
|
def _build_object_schema(vector_dim: int) -> pa.Schema:
    """Build the PyArrow schema used for object-level vector records.

    Args:
        vector_dim: Dimensionality of the stored feature vectors.

    Returns:
        Schema with id, image_id, object_id, category, and vector columns.
    """
    # Fixed-size float32 list so LanceDB can index the vector column.
    columns = [
        ("id", pa.int32()),
        ("image_id", pa.string()),
        ("object_id", pa.string()),
        ("category", pa.string()),
        ("vector", pa.list_(pa.float32(), vector_dim)),
    ]
    return pa.schema([pa.field(name, dtype) for name, dtype in columns])
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_scene_score(
|
||||||
|
query_object_ids: list[str],
|
||||||
|
retrieved_results: dict[str, list[tuple[float, str]]],
|
||||||
|
gamma: float,
|
||||||
|
) -> dict[str, float]:
|
||||||
|
"""Compute scene-level scores using co-occurrence penalty.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query_object_ids: List of query object IDs.
|
||||||
|
retrieved_results: Dict mapping image_id to list of (distance, object_id) results.
|
||||||
|
gamma: Co-occurrence penalty exponent.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping image_id to computed scene score.
|
||||||
|
"""
|
||||||
|
scene_scores: dict[str, float] = {}
|
||||||
|
|
||||||
|
for image_id, results in retrieved_results.items():
|
||||||
|
# Build a set of retrieved object IDs for this scene
|
||||||
|
retrieved_ids = {obj_id for _, obj_id in results}
|
||||||
|
|
||||||
|
# Count how many query objects are found in this scene
|
||||||
|
matched_count = sum(1 for q_id in query_object_ids if q_id in retrieved_ids)
|
||||||
|
|
||||||
|
if matched_count == 0:
|
||||||
|
scene_scores[image_id] = 0.0
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Sum of best similarities (using distance as similarity: smaller = better)
|
||||||
|
# We use 1/(1+distance) to convert distance to similarity
|
||||||
|
similarities = []
|
||||||
|
for dist, obj_id in results:
|
||||||
|
if obj_id in query_object_ids:
|
||||||
|
sim = 1.0 / (1.0 + dist)
|
||||||
|
similarities.append(sim)
|
||||||
|
|
||||||
|
sum_similarity = sum(similarities) if similarities else 0.0
|
||||||
|
|
||||||
|
# Hit rate: ratio of matched objects
|
||||||
|
hit_rate = matched_count / len(query_object_ids)
|
||||||
|
|
||||||
|
# Final score: sum_similarity * (hit_rate)^gamma
|
||||||
|
score = sum_similarity * (hit_rate ** gamma)
|
||||||
|
scene_scores[image_id] = score
|
||||||
|
|
||||||
|
return scene_scores
|
||||||
|
|
||||||
|
|
||||||
|
@RegisterTask("multi-object-retrieval")
class MultiObjectRetrievalTask(BaseBenchmarkTask):
    """Multi-object retrieval benchmark task.

    Segments each scene into objects with SAM, stores one feature vector
    per object in LanceDB, and evaluates scene-level Recall@K by running
    one vector search per query object and aggregating hits with a
    co-occurrence penalty (see ``_compute_scene_score``).
    """

    def __init__(self, **kwargs: Any):
        """Initialize multi-object retrieval task.

        Args:
            **kwargs: Configuration parameters from BenchmarkTaskConfig,
                plus optional SAM settings: ``sam_model``,
                ``min_mask_area`` (alias ``sam_min_mask_area``) and
                ``max_masks_per_image`` (alias ``sam_max_masks``).
        """
        # Use config from kwargs or fall back to the default config.
        if kwargs:
            config_dict = kwargs
        else:
            config = BenchmarkTaskConfig(type="multi-object-retrieval")
            config_dict = config.model_dump()

        super().__init__(**config_dict)
        self.config = BenchmarkTaskConfig(**config_dict)

        # SAM settings. ``sam_model`` holds the model *name*; the loaded
        # model object lives in ``self._sam_model``. The previous revision
        # shadowed this attribute with a read-only ``sam_model`` property,
        # which made the assignment below raise AttributeError (and the
        # property recursed via ``model_name=self.sam_model``); the
        # property was removed. Accept both the plain kwarg spellings used
        # by callers and the "sam_"-prefixed ones from ModelConfig.
        self.sam_model = kwargs.get("sam_model", "facebook/sam2.1-hiera-large")
        self.min_mask_area = kwargs.get(
            "min_mask_area", kwargs.get("sam_min_mask_area", 32 * 32)
        )
        self.max_masks_per_image = kwargs.get(
            "max_masks_per_image", kwargs.get("sam_max_masks", 5)
        )

        # Lazy-loaded SAM resources; populated on first mask_generator use.
        self._sam_model = None
        self._mask_generator = None

    @property
    def mask_generator(self) -> Any:
        """Lazy-load SAM and return its automatic mask generator."""
        if self._mask_generator is None:
            self._sam_model, self._mask_generator = load_sam_model(
                model_name=self.sam_model,
                device=str(get_device()),
            )
        return self._mask_generator

    def build_database(
        self,
        model: nn.Module,
        processor: BitImageProcessorFast,
        train_dataset: Any,
        table: lancedb.table.Table,
        batch_size: int,
    ) -> None:
        """Build the evaluation database with object-level vectors.

        Args:
            model: Feature extraction model.
            processor: Image preprocessor.
            train_dataset: Training dataset; items need an "image" key,
                "image_id" is optional and defaults to ``image_{idx}``.
            table: LanceDB table to store features.
            batch_size: Batch size for DataLoader. Currently unused —
                SAM segments one image at a time.

        Raises:
            ValueError: If the table schema does not match the expected
                object-level schema.
        """
        # Infer the vector dimension from one sample forward pass.
        sample = train_dataset[0]
        vector_dim = self._infer_vector_dim(processor, model, sample["image"])
        expected_schema = _build_object_schema(vector_dim)

        # Fail fast on a schema mismatch instead of corrupting the table.
        if table.schema != expected_schema:
            raise ValueError(
                f"Table schema mismatch. Expected: {expected_schema}, "
                f"Got: {table.schema}"
            )

        # Segment each image and extract one feature vector per object.
        record_id = 0
        records = []

        for idx in track(range(len(train_dataset)), description="Building object database"):
            item = train_dataset[idx]
            image = item["image"]
            image_id = item.get("image_id", f"image_{idx}")

            # Segment image using SAM.
            masks = segment_image(
                self.mask_generator,
                image,
                min_area=self.min_mask_area,
                max_masks=self.max_masks_per_image,
            )

            if not masks:
                # Nothing segmentable in this image; skip it.
                continue

            for mask_idx, mask_info in enumerate(masks):
                # Blank out everything outside the object, then embed it.
                masked_image = self._apply_mask(image, mask_info["segment"])
                vector = extract_single_image_feature(processor, model, masked_image)

                records.append({
                    "id": record_id,
                    "image_id": image_id,
                    "object_id": f"{image_id}_obj_{mask_idx}",
                    "category": mask_info.get("category", "unknown"),
                    "vector": vector,
                })
                record_id += 1

        # Single bulk insert keeps the table write path cheap.
        if records:
            table.add(records)

    def evaluate(
        self,
        model: nn.Module,
        processor: BitImageProcessorFast,
        test_dataset: Any,
        table: lancedb.table.Table,
        batch_size: int,
    ) -> dict[str, Any]:
        """Evaluate scene-level Recall@K on the test dataset.

        For each test scene: segment it, sample up to
        ``config.num_query_objects`` objects, run one vector search per
        object, aggregate hits per candidate scene with
        ``_compute_scene_score``, and count a hit when the target scene
        ranks within the top K.

        Args:
            model: Feature extraction model.
            processor: Image preprocessor.
            test_dataset: Test dataset.
            table: LanceDB table to search against.
            batch_size: Batch size for DataLoader (currently unused).

        Returns:
            Dictionary containing evaluation results with keys:
            - accuracy: Recall@K accuracy (0.0 ~ 1.0)
            - correct: Number of correct predictions
            - total: Number of evaluated samples (scenes yielding masks)
            - top_k: The K value used
        """
        top_k = self.config.top_k_per_object

        correct = 0
        total = 0

        for idx in track(range(len(test_dataset)), description=f"Evaluating Recall@{top_k}"):
            item = test_dataset[idx]
            image = item["image"]
            target_image_id = item.get("image_id", f"image_{idx}")

            # Segment query image.
            masks = segment_image(
                self.mask_generator,
                image,
                min_area=self.min_mask_area,
                max_masks=self.max_masks_per_image,
            )

            if not masks:
                # No query objects available; this scene is not counted.
                continue

            # Randomly sample query objects from the segmented masks.
            num_query = min(self.config.num_query_objects, len(masks))
            query_masks = random.sample(masks, num_query)

            # Per-scene accumulator of (distance, object_id) hits.
            retrieved_results: dict[str, list[tuple[float, str]]] = {}

            for mask_info in query_masks:
                masked_image = self._apply_mask(image, mask_info["segment"])
                vector = extract_single_image_feature(processor, model, masked_image)

                results = (
                    table.search(vector)
                    .select(["image_id", "object_id", "_distance"])
                    .limit(top_k)
                    .to_polars()
                )

                # named=True yields dict rows; without it polars yields
                # plain tuples and the key lookups below would fail.
                for row in results.iter_rows(named=True):
                    retrieved_results.setdefault(row["image_id"], []).append(
                        (row["_distance"], row["object_id"])
                    )

            # Aggregate per-object hits into scene-level scores.
            query_object_ids = [
                m.get("object_id", f"query_obj_{i}") for i, m in enumerate(query_masks)
            ]
            scene_scores = _compute_scene_score(
                query_object_ids,
                retrieved_results,
                self.config.gamma,
            )

            # The target scene must rank within the top K by score.
            ranked_scenes = sorted(scene_scores.items(), key=lambda x: x[1], reverse=True)
            top_k_scenes = [scene_id for scene_id, _ in ranked_scenes[:top_k]]
            if target_image_id in top_k_scenes:
                correct += 1
            total += 1

        accuracy = correct / total if total > 0 else 0.0

        return {
            "accuracy": accuracy,
            "correct": correct,
            "total": total,
            "top_k": top_k,
        }

    def _infer_vector_dim(
        self,
        processor: BitImageProcessorFast,
        model: nn.Module,
        sample_image: Any,
    ) -> int:
        """Infer the feature vector dimension from one forward pass."""
        vector = extract_single_image_feature(processor, model, sample_image)
        return len(vector)

    def _apply_mask(self, image: Any, mask: np.ndarray) -> Any:
        """Zero out every pixel outside the mask.

        Args:
            image: PIL Image.
            mask: Binary mask as a (H, W) numpy array.

        Returns:
            Masked PIL Image with the same size as ``image``.
        """
        from PIL import Image

        image_np = np.array(image.convert("RGB"))

        # Resize with nearest-neighbour (order=0) so the mask stays binary.
        if mask.shape != image_np.shape[:2]:
            from skimage.transform import resize

            mask_resized = resize(mask, image_np.shape[:2], order=0, anti_aliasing=False)
        else:
            mask_resized = mask

        # Broadcast the (H, W) mask over the RGB channels.
        masked_np = image_np * mask_resized[:, :, np.newaxis]
        return Image.fromarray(masked_np.astype(np.uint8))
|
||||||
@@ -118,6 +118,17 @@ class BenchmarkTaskConfig(BaseModel):
|
|||||||
type: str = Field(default="retrieval", description="Task type")
|
type: str = Field(default="retrieval", description="Task type")
|
||||||
top_k: int = Field(default=10, gt=0, description="Top K for recall evaluation")
|
top_k: int = Field(default=10, gt=0, description="Top K for recall evaluation")
|
||||||
|
|
||||||
|
# Multi-object retrieval specific settings
|
||||||
|
gamma: float = Field(
|
||||||
|
default=1.0, ge=0, description="Co-occurrence penalty exponent"
|
||||||
|
)
|
||||||
|
top_k_per_object: int = Field(
|
||||||
|
default=50, gt=0, description="Top K results per object query"
|
||||||
|
)
|
||||||
|
num_query_objects: int = Field(
|
||||||
|
default=3, gt=0, description="Number of objects to sample from query image"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class BenchmarkConfig(BaseModel):
|
class BenchmarkConfig(BaseModel):
|
||||||
"""Configuration for benchmark evaluation."""
|
"""Configuration for benchmark evaluation."""
|
||||||
|
|||||||
64
mini-nav/data_loading/insdet_scenes.py
Normal file
64
mini-nav/data_loading/insdet_scenes.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
"""InsDet Scenes dataset for multi-object retrieval benchmark."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from benchmarks.base import BaseDataset
|
||||||
|
from data_loading.loader import load_val_dataset
|
||||||
|
|
||||||
|
|
||||||
|
class InsDetScenesDataset(BaseDataset):
    """InsDet-FULL/Scenes dataset with easy/hard splits.

    Wraps scene images plus object annotations from the Instance
    Detection (InsDet) dataset, exposing identical train and test views
    of the chosen split.
    """

    def __init__(
        self,
        scenes_dir: Path | str,
        split: str = "easy",
    ):
        """Initialize InsDet Scenes dataset.

        Args:
            scenes_dir: Path to the InsDet-FULL/Scenes directory.
            split: Scene split to use ('easy' or 'hard').
        """
        root = Path(scenes_dir)
        self.scenes_dir = root
        self.split = split
        # Train and test views both read from this one underlying dataset.
        self._dataset = load_val_dataset(root, split)

    def get_train_split(self) -> Any:
        """Return the training split (same as the test split here).

        Returns:
            HuggingFace Dataset for training.
        """
        return self._dataset

    def get_test_split(self) -> Any:
        """Return the test/evaluation split.

        Returns:
            HuggingFace Dataset for testing.
        """
        return self._dataset

    def __len__(self) -> int:
        """Return the number of scenes in the split."""
        return len(self._dataset)

    def __getitem__(self, idx: int) -> dict[str, Any]:
        """Return one scene record.

        Args:
            idx: Index of the item.

        Returns:
            Dictionary containing:
            - image: PIL Image
            - image_id: Scene identifier
            - objects: dict with bbox, category, area, id
        """
        return self._dataset[idx]
|
||||||
238
mini-nav/tests/test_multi_object_retrieval.py
Normal file
238
mini-nav/tests/test_multi_object_retrieval.py
Normal file
@@ -0,0 +1,238 @@
|
|||||||
|
"""Integration tests for multi-object retrieval benchmark pipeline.
|
||||||
|
|
||||||
|
These tests verify the end-to-end functionality of the multi-object retrieval
|
||||||
|
benchmark, including schema building, database population, and evaluation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import Mock, patch, MagicMock
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
class TestMultiObjectRetrievalIntegration:
    """Integration tests for multi-object retrieval benchmark."""

    @pytest.fixture
    def mock_model_processor(self):
        """Create mock model and processor.

        The model/processor are only handed through to the (patched)
        ``extract_single_image_feature``, so bare Mocks are sufficient.
        The previous ``mock_processor.images = ...`` assignment did not
        mock anything the task actually calls and was removed.
        """
        return Mock(), Mock()

    @pytest.fixture
    def mock_dataset(self):
        """Create a mock dataset with images and annotations."""
        items = []
        for i in range(3):
            item = {
                "image": Image.new("RGB", (224, 224), color=(i * 50, 100, 150)),
                "image_id": f"scene_{i}",
                "objects": {
                    "bbox": [[10, 10, 50, 50], [60, 60, 40, 40]],
                    "category": ["object_a", "object_b"],
                    "area": [2500, 1600],
                    "id": [0, 1],
                },
            }
            items.append(item)

        # Must be a MagicMock: len()/[] look dunder methods up on the
        # type, so assigning __len__/__getitem__ on a plain Mock instance
        # is ignored and len(dataset)/dataset[i] raise TypeError.
        mock_dataset = MagicMock()
        mock_dataset.__len__.return_value = len(items)
        mock_dataset.__getitem__.side_effect = lambda idx: items[idx]
        mock_dataset.with_format = lambda fmt: mock_dataset

        return mock_dataset

    def test_build_object_schema(self):
        """Test that object schema is built correctly."""
        from benchmarks.tasks.multi_object_retrieval import _build_object_schema
        import pyarrow as pa

        vector_dim = 256
        schema = _build_object_schema(vector_dim)

        assert isinstance(schema, pa.Schema)
        assert "id" in schema.names
        assert "image_id" in schema.names
        assert "object_id" in schema.names
        assert "category" in schema.names
        assert "vector" in schema.names

        # pa.list_(value_type, size) builds a *fixed-size* list type;
        # check via pa.types (``pa.List`` is not a pyarrow class, so the
        # previous isinstance check raised AttributeError).
        vector_field = schema.field("vector")
        assert pa.types.is_fixed_size_list(vector_field.type)
        assert vector_field.type.value_type == pa.float32()
        assert vector_field.type.list_size == vector_dim

    @patch("benchmarks.tasks.multi_object_retrieval.extract_single_image_feature")
    @patch("benchmarks.tasks.multi_object_retrieval.load_sam_model")
    @patch("benchmarks.tasks.multi_object_retrieval.segment_image")
    def test_build_database_with_mocked_sam(
        self,
        mock_segment,
        mock_load_sam,
        mock_extract,
        mock_model_processor,
        mock_dataset,
    ):
        """Test database building with mocked SAM segmentation."""
        from benchmarks.tasks.multi_object_retrieval import (
            MultiObjectRetrievalTask,
            _build_object_schema,
        )

        mock_model, mock_processor = mock_model_processor

        # Mock SAM and the feature extractor (a Mock model cannot run a
        # real forward pass, so extraction must be patched too).
        mock_load_sam.return_value = (Mock(), Mock())
        mock_segment.return_value = [
            {
                "segment": np.ones((224, 224), dtype=bool),
                "area": 50000,
                "bbox": [0, 0, 224, 224],
            }
        ]
        mock_extract.return_value = [0.1] * 256  # 256-dim vector

        # Create task with config
        task = MultiObjectRetrievalTask(
            sam_model="facebook/sam2.1-hiera-large",
            min_mask_area=1024,
            max_masks_per_image=5,
            gamma=1.0,
            top_k_per_object=50,
            num_query_objects=3,
        )

        # Create mock table whose schema matches the mocked vector dim.
        mock_table = Mock()
        mock_table.schema = _build_object_schema(256)

        # Build database (this should not raise)
        task.build_database(mock_model, mock_processor, mock_dataset, mock_table, batch_size=1)

        # Verify table.add was called
        assert mock_table.add.called

    @patch("benchmarks.tasks.multi_object_retrieval.extract_single_image_feature")
    @patch("benchmarks.tasks.multi_object_retrieval.load_sam_model")
    @patch("benchmarks.tasks.multi_object_retrieval.segment_image")
    def test_evaluate_with_mocked_sam(
        self,
        mock_segment,
        mock_load_sam,
        mock_extract,
        mock_model_processor,
        mock_dataset,
    ):
        """Test evaluation with mocked SAM segmentation."""
        from benchmarks.tasks.multi_object_retrieval import (
            MultiObjectRetrievalTask,
            _build_object_schema,
        )

        mock_model, mock_processor = mock_model_processor

        # Mock SAM and the feature extractor.
        mock_load_sam.return_value = (Mock(), Mock())
        mock_segment.return_value = [
            {
                "segment": np.ones((224, 224), dtype=bool),
                "area": 50000,
                "bbox": [0, 0, 224, 224],
                "object_id": "query_obj_0",
            }
        ]
        mock_extract.return_value = [0.1] * 256

        # Create mock table with search results
        mock_table = Mock()
        mock_table.schema = _build_object_schema(256)

        # evaluate() iterates the polars frame via iter_rows(), so the
        # mocked to_polars() result must be an object exposing iter_rows()
        # that yields dict rows (a plain dict has no iter_rows method).
        mock_result = Mock()
        mock_result.iter_rows = Mock(
            return_value=[
                {"image_id": "scene_0", "object_id": "scene_0_obj_0", "_distance": 0.1}
            ]
        )
        mock_table.search.return_value.select.return_value.limit.return_value.to_polars.return_value = mock_result

        # Create task
        task = MultiObjectRetrievalTask(
            sam_model="facebook/sam2.1-hiera-large",
            min_mask_area=1024,
            max_masks_per_image=5,
            gamma=1.0,
            top_k_per_object=50,
            num_query_objects=1,
        )

        # Evaluate
        results = task.evaluate(mock_model, mock_processor, mock_dataset, mock_table, batch_size=1)

        # Verify results structure
        assert "accuracy" in results
        assert "correct" in results
        assert "total" in results
        assert "top_k" in results
        assert results["top_k"] == 50

    def test_task_initialization_with_config(self):
        """Test task initialization with custom config."""
        from benchmarks.tasks.multi_object_retrieval import MultiObjectRetrievalTask

        task = MultiObjectRetrievalTask(
            sam_model="facebook/sam2.1-hiera-small",
            min_mask_area=500,
            max_masks_per_image=3,
            gamma=0.5,
            top_k_per_object=100,
            num_query_objects=5,
        )

        assert task.sam_model == "facebook/sam2.1-hiera-small"
        assert task.min_mask_area == 500
        assert task.max_masks_per_image == 3
        assert task.config.gamma == 0.5
        assert task.config.top_k_per_object == 100
        assert task.config.num_query_objects == 5

    def test_task_initialization_defaults(self):
        """Test task initialization with default config."""
        from benchmarks.tasks.multi_object_retrieval import MultiObjectRetrievalTask

        task = MultiObjectRetrievalTask()

        # Check defaults from BenchmarkTaskConfig
        assert task.config.gamma == 1.0
        assert task.config.top_k_per_object == 50
        assert task.config.num_query_objects == 3
        # SAM settings from ModelConfig defaults
        assert task.sam_model == "facebook/sam2.1-hiera-large"
        assert task.min_mask_area == 1024
        assert task.max_masks_per_image == 5
|
||||||
|
|
||||||
|
|
||||||
|
class TestInsDetScenesDataset:
    """Tests for InsDetScenesDataset class."""

    def test_dataset_class_exists(self):
        """Test that InsDetScenesDataset can be imported."""
        from data_loading.insdet_scenes import InsDetScenesDataset

        assert InsDetScenesDataset is not None

    @patch("data_loading.insdet_scenes.load_val_dataset")
    def test_dataset_loads_correct_split(self, mock_load):
        """Test dataset loads correct split."""
        from pathlib import Path

        from data_loading.insdet_scenes import InsDetScenesDataset

        mock_load.return_value = Mock()

        dataset = InsDetScenesDataset("/path/to/scenes", split="easy")

        # The dataset converts scenes_dir to a Path before calling the
        # loader, so the expected argument is Path("/path/to/scenes");
        # asserting with the raw string fails because Path != str.
        mock_load.assert_called_once_with(Path("/path/to/scenes"), "easy")
        assert dataset.split == "easy"
        assert dataset.scenes_dir == Path("/path/to/scenes")
|
||||||
168
mini-nav/tests/test_sam.py
Normal file
168
mini-nav/tests/test_sam.py
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
"""Tests for SAM segmentation utilities.
|
||||||
|
|
||||||
|
Note: These tests mock the SAM model loading since SAM requires
|
||||||
|
heavy model weights. The actual SAM integration should be tested
|
||||||
|
separately in integration tests.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
class TestSAMSegmentation:
    """Test suite for SAM segmentation utilities.

    All tests drive ``segment_image`` through a mocked mask generator, so
    they exercise only the filtering/sorting contract, not SAM itself.
    NOTE(review): these tests assume segment_image reads the precomputed
    "area" key from each mask dict rather than recounting pixels — confirm
    against utils.sam.
    """

    def test_segment_image_empty_masks(self):
        """Test segment_image returns empty list when no masks generated."""
        from utils.sam import segment_image

        # Create mock mask generator that returns empty list
        mock_generator = Mock()
        mock_generator.generate.return_value = []

        result = segment_image(mock_generator, Image.new("RGB", (100, 100)))

        assert result == []

    def test_segment_image_filters_small_masks(self):
        """Test segment_image filters masks below min_area threshold."""
        from utils.sam import segment_image

        # Create mock masks with different areas
        small_mask = {
            "segment": np.zeros((10, 10), dtype=bool),
            "area": 50,  # Below 32*32 = 1024
            "bbox": [0, 0, 10, 10],
            "predicted_iou": 0.9,
            "stability_score": 0.8,
        }
        large_mask = {
            "segment": np.ones((100, 100), dtype=bool),
            "area": 10000,  # Above threshold
            "bbox": [0, 0, 100, 100],
            "predicted_iou": 0.95,
            "stability_score": 0.9,
        }

        mock_generator = Mock()
        mock_generator.generate.return_value = [small_mask, large_mask]

        result = segment_image(
            mock_generator,
            Image.new("RGB", (100, 100)),
            min_area=32 * 32,
            max_masks=5,
        )

        # Should only return the large mask
        assert len(result) == 1
        assert result[0]["area"] == 10000

    def test_segment_image_limits_max_masks(self):
        """Test segment_image limits to max_masks largest masks."""
        from utils.sam import segment_image

        # Create 10 masks with different areas
        # Areas range from 1 to 100 ((i+1)^2), all above min_area=1.
        masks = [
            {
                "segment": np.ones((i + 1, i + 1), dtype=bool),
                "area": (i + 1) * (i + 1),
                "bbox": [0, 0, i + 1, i + 1],
                "predicted_iou": 0.9,
                "stability_score": 0.8,
            }
            for i in range(10)
        ]

        mock_generator = Mock()
        mock_generator.generate.return_value = masks

        result = segment_image(
            mock_generator,
            Image.new("RGB", (100, 100)),
            min_area=1,
            max_masks=3,
        )

        # Should only return top 3 largest masks
        assert len(result) == 3
        # Check they are sorted by area (largest first)
        areas = [m["area"] for m in result]
        assert areas == sorted(areas, reverse=True)

    def test_segment_image_sorted_by_area(self):
        """Test segment_image returns masks sorted by area descending."""
        from utils.sam import segment_image

        # Create masks with known areas (unordered)
        mask1 = {"segment": np.ones((5, 5), dtype=bool), "area": 25, "bbox": [0, 0, 5, 5]}
        mask2 = {"segment": np.ones((10, 10), dtype=bool), "area": 100, "bbox": [0, 0, 10, 10]}
        mask3 = {"segment": np.ones((3, 3), dtype=bool), "area": 9, "bbox": [0, 0, 3, 3]}

        mock_generator = Mock()
        mock_generator.generate.return_value = [mask1, mask2, mask3]

        result = segment_image(
            mock_generator,
            Image.new("RGB", (100, 100)),
            min_area=1,
            max_masks=10,
        )

        # Should be sorted by area descending
        assert result[0]["area"] == 100
        assert result[1]["area"] == 25
        assert result[2]["area"] == 9
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtractMaskedRegion:
    """Test suite for extracting masked regions from images."""

    def test_extract_masked_region_binary(self):
        """Test extracting masked region with binary mask."""
        from utils.sam import extract_masked_region

        # Solid red image with the left half kept by the mask.
        source = Image.new("RGB", (10, 10), color=(255, 0, 0))
        keep_left = np.zeros((10, 10), dtype=bool)
        keep_left[:, :5] = True

        masked = np.array(extract_masked_region(source, keep_left))

        # Left half keeps the red pixels; right half is zeroed out.
        assert np.all(masked[:, :5, :] == [255, 0, 0])
        assert np.all(masked[:, 5:, :] == [0, 0, 0])

    def test_extract_masked_region_all_masked(self):
        """Test extracting when entire image is masked."""
        from utils.sam import extract_masked_region

        source = Image.new("RGB", (10, 10), color=(255, 0, 0))
        full_mask = np.ones((10, 10), dtype=bool)

        masked = np.array(extract_masked_region(source, full_mask))

        # A full mask preserves every pixel.
        assert np.all(masked == [255, 0, 0])

    def test_extract_masked_region_all_zero_mask(self):
        """Test extracting when mask is all zeros."""
        from utils.sam import extract_masked_region

        source = Image.new("RGB", (10, 10), color=(255, 0, 0))
        empty_mask = np.zeros((10, 10), dtype=bool)

        masked = np.array(extract_masked_region(source, empty_mask))

        # An empty mask blacks out the whole image.
        assert np.all(masked == [0, 0, 0])
|
||||||
121
mini-nav/tests/test_scene_scoring.py
Normal file
121
mini-nav/tests/test_scene_scoring.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
"""Tests for scene scoring algorithm in multi-object retrieval."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from benchmarks.tasks.multi_object_retrieval import _compute_scene_score
|
||||||
|
|
||||||
|
|
||||||
|
class TestSceneScoringAlgorithm:
    """Tests for ``_compute_scene_score`` (hit-rate penalised similarity)."""

    def test_scene_score_basic(self):
        """One of three query objects matching yields a positive score."""
        queries = ["obj_1", "obj_2", "obj_3"]
        results = {"scene_A": [("distance_1", "obj_1")]}

        scores = _compute_scene_score(queries, results, gamma=1.0)

        # Hit rate = 1/3, similarity = 1/(1+distance_1): positive overall.
        assert "scene_A" in scores
        assert scores["scene_A"] > 0

    def test_scene_score_no_match(self):
        """A scene whose retrieved objects overlap no query scores zero."""
        queries = ["obj_1", "obj_2", "obj_3"]
        results = {"scene_A": [("distance_1", "other_obj")]}

        scores = _compute_scene_score(queries, results, gamma=1.0)
        assert scores["scene_A"] == 0.0

    def test_scene_score_multiple_scenes(self):
        """Only scenes with at least one matching object score above zero."""
        queries = ["obj_1", "obj_2"]
        results = {
            "scene_A": [("0.1", "obj_1")],
            "scene_B": [("0.1", "obj_2")],
            "scene_C": [("0.1", "other")],
        }

        scores = _compute_scene_score(queries, results, gamma=1.0)

        assert scores["scene_A"] > 0
        assert scores["scene_B"] > 0
        # scene_C matched nothing from the query set.
        assert scores["scene_C"] == 0.0

    def test_scene_score_gamma_zero(self):
        """gamma=0 disables the hit-rate penalty, raising the score."""
        queries = ["obj_1", "obj_2", "obj_3", "obj_4", "obj_5"]
        results = {"scene_A": [("0.1", "obj_1")]}

        no_penalty = _compute_scene_score(queries, results, gamma=0.0)
        full_penalty = _compute_scene_score(queries, results, gamma=1.0)

        # hit_rate^0 == 1 vs hit_rate^1 == 1/5, so the unpenalised
        # score must strictly dominate.
        assert no_penalty["scene_A"] > full_penalty["scene_A"]

    def test_scene_score_multiple_matches(self):
        """All query objects matching gives a full hit rate of 1.0."""
        queries = ["obj_1", "obj_2"]
        results = {"scene_A": [("0.1", "obj_1"), ("0.2", "obj_2")]}

        scores = _compute_scene_score(queries, results, gamma=1.0)

        # Score = (1/(1+0.1) + 1/(1+0.2)) * (2/2)^1
        expected = 1.0 / 1.1 + 1.0 / 1.2
        assert abs(scores["scene_A"] - expected) < 0.01

    def test_scene_score_distance_to_similarity(self):
        """Smaller retrieval distance must map to a larger score."""
        queries = ["obj_1"]
        results = {
            "scene_close": [("0.01", "obj_1")],
            "scene_far": [("10.0", "obj_1")],
        }

        scores = _compute_scene_score(queries, results, gamma=1.0)
        assert scores["scene_close"] > scores["scene_far"]

    def test_scene_score_empty_results(self):
        """No retrieved scenes produces an empty score map."""
        scores = _compute_scene_score(["obj_1", "obj_2"], {}, gamma=1.0)
        assert scores == {}

    def test_scene_score_empty_query(self):
        """An empty query set can never produce a positive score."""
        results = {"scene_A": [("0.1", "obj_1")]}

        scores = _compute_scene_score([], results, gamma=1.0)
        assert all(score == 0.0 for score in scores.values())
|
||||||
100
mini-nav/utils/sam.py
Normal file
100
mini-nav/utils/sam.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
"""SAM (Segment Anything Model) utilities for object segmentation."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from sam2.build_sam import build_sam2
|
||||||
|
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
|
||||||
|
|
||||||
|
|
||||||
|
def load_sam_model(
    model_name: str = "facebook/sam2.1-hiera-large",
    device: str = "cuda",
    checkpoint_dir: Path | None = None,
) -> tuple[Any, Any]:
    """Load SAM 2.1 model and mask generator.

    Falls back to CPU silently when CUDA is requested but unavailable.

    Args:
        model_name: SAM model name (currently supports facebook/sam2.1-hiera-*).
        device: Device to load model on (cuda or cpu).
        checkpoint_dir: Optional directory for model checkpoint cache.
            NOTE(review): this parameter is currently accepted but never
            used anywhere in the function body — confirm whether it should
            be forwarded to the model loader or removed.

    Returns:
        Tuple of (sam_model, mask_generator).
    """
    # Degrade to CPU rather than failing when no GPU is present.
    # NOTE(review): this fallback is silent; consider logging a warning so
    # unexpectedly slow CPU runs are diagnosable.
    if device == "cuda" and not torch.cuda.is_available():
        device = "cpu"

    # Build SAM2 model
    # NOTE(review): build_sam2 conventionally takes a config file path,
    # while Hugging Face model IDs like "facebook/sam2.1-hiera-large" are
    # typically loaded via build_sam2_hf — confirm this call resolves the
    # default model_name correctly.
    sam_model = build_sam2(model_name, device=device)

    # Create automatic mask generator
    mask_generator = SAM2AutomaticMaskGenerator(sam_model)

    return sam_model, mask_generator
|
||||||
|
|
||||||
|
|
||||||
|
def segment_image(
    mask_generator: Any,
    image: "Image.Image",
    min_area: int = 32 * 32,
    max_masks: int = 5,
) -> list[dict[str, Any]]:
    """Segment image using SAM to extract object masks.

    Args:
        mask_generator: SAM2AutomaticMaskGenerator instance.
        image: PIL Image to segment.
        min_area: Minimum mask area threshold in pixels.
        max_masks: Maximum number of masks to return.

    Returns:
        List of mask dictionaries with keys:
        - segment: Binary mask (numpy array)
        - area: Mask area in pixels
        - bbox: Bounding box [x, y, width, height]
        - predicted_iou: Model's confidence in the mask
        - stability_score: Stability score for the mask
    """
    # SAM operates on numpy arrays, so normalise to RGB pixels first.
    rgb_pixels = np.array(image.convert("RGB"))

    candidates = mask_generator.generate(rgb_pixels)
    if not candidates:
        return []

    # Discard fragments below the area threshold.
    kept = [candidate for candidate in candidates if candidate["area"] >= min_area]
    if not kept:
        return []

    # Largest objects first, capped at max_masks entries.
    kept.sort(key=lambda candidate: candidate["area"], reverse=True)
    return kept[:max_masks]
|
||||||
|
|
||||||
|
|
||||||
|
def extract_masked_region(
    image: "Image.Image",
    mask: np.ndarray,
) -> "Image.Image":
    """Extract masked region from image.

    Args:
        image: Original PIL Image.
        mask: Binary mask as numpy array (True = keep).

    Returns:
        PIL Image with only the masked region visible.
    """
    pixels = np.array(image.convert("RGB"))

    # Broadcasting the (H, W) mask over the channel axis zeroes every
    # pixel outside the mask while leaving kept pixels untouched.
    kept_pixels = pixels * mask[..., None]

    return Image.fromarray(kept_pixels.astype(np.uint8))
|
||||||
Reference in New Issue
Block a user