mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 20:35:31 +08:00
101 lines
2.8 KiB
Python
101 lines
2.8 KiB
Python
"""SAM (Segment Anything Model) utilities for object segmentation."""
|
|
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import numpy as np
|
|
import torch
|
|
from PIL import Image
|
|
from sam2.build_sam import build_sam2
|
|
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
|
|
|
|
|
|
def load_sam_model(
    model_name: str = "facebook/sam2.1-hiera-large",
    device: str = "cuda",
    checkpoint_dir: Path | None = None,
) -> tuple[Any, Any]:
    """Load a SAM 2 model and wrap it in an automatic mask generator.

    Args:
        model_name: Either a Hugging Face model id (for example
            ``facebook/sam2.1-hiera-large``) or, for backward compatibility,
            a SAM 2 config file name ending in ``.yaml`` / ``.yml``.
        device: Device to load the model on (``cuda`` or ``cpu``). A request
            for ``cuda`` falls back to ``cpu`` when CUDA is not available.
        checkpoint_dir: Optional directory for the model checkpoint cache.
            NOTE(review): accepted for interface compatibility, but this
            function does not forward it to the loader yet.

    Returns:
        Tuple of (sam_model, mask_generator).
    """
    # Degrade gracefully on machines without a usable GPU.
    if device == "cuda" and not torch.cuda.is_available():
        device = "cpu"

    if model_name.endswith((".yaml", ".yml")):
        # Config-file names keep the original call (no checkpoint path is
        # passed here, exactly as before).
        sam_model = build_sam2(model_name, device=device)
    else:
        # `build_sam2` expects a config file name, not a Hugging Face id;
        # `build_sam2_hf` resolves the id to its config and downloads the
        # matching checkpoint before building the model.
        from sam2.build_sam import build_sam2_hf

        sam_model = build_sam2_hf(model_name, device=device)

    mask_generator = SAM2AutomaticMaskGenerator(sam_model)
    return sam_model, mask_generator
|
|
|
|
|
|
def segment_image(
    mask_generator: Any,
    image: Image.Image,
    min_area: int = 32 * 32,
    max_masks: int = 5,
) -> list[dict[str, Any]]:
    """Segment an image and return its largest object masks.

    Args:
        mask_generator: Object exposing ``generate(image_array)``, e.g. a
            SAM2AutomaticMaskGenerator instance.
        image: PIL Image to segment; it is converted to RGB first.
        min_area: Minimum mask area in pixels; smaller masks are dropped.
        max_masks: Maximum number of masks to return. Zero or a negative
            value yields an empty list.

    Returns:
        The generator's mask records whose ``area`` is at least ``min_area``,
        sorted by ``area`` (largest first) and truncated to ``max_masks``.
        Records are passed through unchanged; only ``area`` is read here.
        NOTE(review): SAM 2 documents the record keys as ``segmentation``
        (binary mask), ``area``, ``bbox`` ([x, y, width, height]),
        ``predicted_iou`` and ``stability_score`` -- confirm against the
        installed version.
    """
    # A non-positive limit can only produce an empty result, so skip the
    # (expensive) inference. This also avoids a negative slice bound, which
    # would drop masks from the tail instead of limiting the count.
    if max_masks <= 0:
        return []

    image_np = np.array(image.convert("RGB"))
    masks = mask_generator.generate(image_np)
    if not masks:
        return []

    large_enough = [m for m in masks if m["area"] >= min_area]
    # Stable sort: masks with equal area keep the generator's order.
    large_enough.sort(key=lambda m: m["area"], reverse=True)
    return large_enough[:max_masks]
|
|
|
|
|
|
def extract_masked_region(
    image: Image.Image,
    mask: np.ndarray,
) -> Image.Image:
    """Black out everything outside a binary mask.

    Args:
        image: Original PIL Image; it is converted to RGB first.
        mask: Binary mask with the image's (height, width) shape. Any
            non-zero / True element marks a pixel to keep, so boolean,
            0/1 and 0/255 masks all behave the same.

    Returns:
        RGB PIL Image in which pixels outside the mask are set to zero.

    Raises:
        ValueError: If the mask shape does not match the image size.
    """
    image_np = np.array(image.convert("RGB"))

    # Coerce to bool before multiplying: a uint8 mask holding 0/255 would
    # otherwise wrap around modulo 256 and corrupt the kept pixels.
    keep = np.asarray(mask).astype(bool)
    if keep.shape != image_np.shape[:2]:
        raise ValueError(
            f"mask shape {keep.shape} does not match image "
            f"(height, width) {image_np.shape[:2]}"
        )

    # Broadcast the 2-D mask over the three colour channels.
    masked_np = image_np * keep[:, :, np.newaxis]
    return Image.fromarray(masked_np.astype(np.uint8))
|