From a7b01cb49e08ef8e2da4758f2b611f4ad9551590 Mon Sep 17 00:00:00 2001 From: SikongJueluo Date: Mon, 2 Mar 2026 14:22:44 +0800 Subject: [PATCH] feat(compressors): add SAM+DINO+Hash pipeline for object feature extraction --- mini-nav/compressors/__init__.py | 5 + mini-nav/compressors/dino_compressor.py | 85 +++++- mini-nav/compressors/pipeline.py | 170 +++++++++++ mini-nav/compressors/segament_compressor.py | 180 ++++++++++++ mini-nav/configs/config.yaml | 4 + mini-nav/configs/models.py | 14 + mini-nav/tests/test_compressors.py | 303 ++++++++++++++++++++ 7 files changed, 753 insertions(+), 8 deletions(-) create mode 100644 mini-nav/compressors/pipeline.py create mode 100644 mini-nav/tests/test_compressors.py diff --git a/mini-nav/compressors/__init__.py b/mini-nav/compressors/__init__.py index 3e1ec0e..91394bd 100644 --- a/mini-nav/compressors/__init__.py +++ b/mini-nav/compressors/__init__.py @@ -1,6 +1,8 @@ from .common import BinarySign, bits_to_hash, hamming_distance, hamming_similarity, hash_to_bits from .dino_compressor import DinoCompressor from .hash_compressor import HashCompressor, HashLoss, VideoPositiveMask +from .pipeline import SAMHashPipeline, create_pipeline_from_config +from .segament_compressor import SegmentCompressor from .train import train __all__ = [ @@ -9,6 +11,9 @@ __all__ = [ "HashCompressor", "HashLoss", "VideoPositiveMask", + "SegmentCompressor", + "SAMHashPipeline", + "create_pipeline_from_config", "BinarySign", "hamming_distance", "hamming_similarity", diff --git a/mini-nav/compressors/dino_compressor.py b/mini-nav/compressors/dino_compressor.py index 33a948a..e991700 100644 --- a/mini-nav/compressors/dino_compressor.py +++ b/mini-nav/compressors/dino_compressor.py @@ -1,8 +1,10 @@ -from typing import Optional, cast +from typing import Optional +import torch +import torch.nn as nn import torch.nn.functional as F -from torch import nn -from transformers import AutoModel, Dinov2Model +from PIL import Image +from transformers import 
"""DINOv2 feature extractor with optional hash compression."""

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import AutoImageProcessor, AutoModel


class DinoCompressor(nn.Module):
    """DINOv2 feature extractor producing embeddings or binary hash bits.

    When ``compressor`` is None: returns normalized DINO embeddings.
    When ``compressor`` is provided: returns binary hash bits for CAM storage.

    Supports both PIL Image input and pre-extracted tokens.
    """

    def __init__(
        self,
        model_name: str = "facebook/dinov2-large",
        compressor: Optional[nn.Module] = None,
        device: Optional[str] = None,
    ):
        """Initialize the DINOv2 extractor.

        Args:
            model_name: HuggingFace model name.
            compressor: Optional hash compressor for producing binary codes.
            device: Device to load the model on; auto-detected when None.
        """
        super().__init__()
        # Auto-detect device when the caller does not specify one.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        self.model_name = model_name
        self.processor = AutoImageProcessor.from_pretrained(model_name)
        self.dino = AutoModel.from_pretrained(model_name).to(self.device)
        # Inference only: freeze dropout/batch-norm style behavior.
        self.dino.eval()

        self.compressor = compressor

    def extract_features(self, images: list[Image.Image]) -> torch.Tensor:
        """Extract normalized DINO features from cropped object images.

        Args:
            images: List of PIL Images (cropped objects).

        Returns:
            Normalized DINO features ``[N, hidden_size]`` (1024 for
            dinov2-large). Empty input yields an empty ``[0, hidden_size]``
            tensor so downstream shape logic still works.
        """
        if not images:
            return torch.empty(0, self.dino.config.hidden_size, device=self.device)

        inputs = self.processor(images, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.dino(**inputs)

        # Mean-pool all tokens into one global vector per image, then
        # L2-normalize so cosine similarity reduces to a dot product.
        features = outputs.last_hidden_state.mean(dim=1)
        return F.normalize(features, dim=-1)

    def encode(self, images: list[Image.Image]) -> torch.Tensor:
        """Extract features from images and optionally compress to hash codes.

        Args:
            images: List of PIL Images.

        Returns:
            If ``compressor`` is None: DINO features ``[N, hidden_size]``.
            If ``compressor`` is set: binary hash bits ``[N, hash_bits]``.
        """
        features = self.extract_features(images)
        if self.compressor is None:
            return features

        # The compressor expects a token axis: [B, N_tokens, dim].
        # HashCompressor returns (logits, hash_codes, bits); only the
        # binary bits are needed for CAM storage.
        _, _, bits = self.compressor(features.unsqueeze(1))
        return bits
+""" + +from pathlib import Path +from typing import Optional + +import torch +import torch.nn as nn +from PIL import Image + +from .dino_compressor import DinoCompressor +from .hash_compressor import HashCompressor +from .segament_compressor import SegmentCompressor + + +def create_pipeline_from_config(config) -> "SAMHashPipeline": + """Create SAMHashPipeline from a config object. + + Args: + config: Configuration object with model settings + + Returns: + Initialized SAMHashPipeline + """ + return SAMHashPipeline( + sam_model=config.model.sam_model, + dino_model=config.model.name, + hash_bits=config.model.compression_dim, + sam_min_mask_area=config.model.sam_min_mask_area, + sam_max_masks=config.model.sam_max_masks, + compressor_path=config.model.compressor_path, + device=config.model.device if config.model.device != "auto" else None, + ) + + +class SAMHashPipeline(nn.Module): + """Complete pipeline: SAM segmentation + DINO features + Hash compression. + + Pipeline flow: + Image -> SAM (extract masks) -> Crop objects -> DINO (features) -> Hash (binary codes) + + Usage: + # Initialize with config + pipeline = SAMHashPipeline( + sam_model="facebook/sam2.1-hiera-large", + dino_model="facebook/dinov2-large", + hash_bits=512, + ) + + # Process image + image = Image.open("path/to/image.jpg") + hash_codes = pipeline(image) # [N, 512] binary bits + """ + + def __init__( + self, + sam_model: str = "facebook/sam2.1-hiera-large", + dino_model: str = "facebook/dinov2-large", + hash_bits: int = 512, + sam_min_mask_area: int = 100, + sam_max_masks: int = 10, + compressor_path: Optional[str] = None, + device: Optional[str] = None, + ): + """Initialize the complete pipeline. 
+ + Args: + sam_model: SAM model name from HuggingFace + dino_model: DINOv2 model name from HuggingFace + hash_bits: Number of bits in hash code + sam_min_mask_area: Minimum mask area threshold + sam_max_masks: Maximum number of masks to keep + compressor_path: Optional path to trained HashCompressor weights + device: Device to run models on + """ + super().__init__() + + # Auto detect device + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.device = torch.device(device) + + # Initialize components + self.segmentor = SegmentCompressor( + model_name=sam_model, + min_mask_area=sam_min_mask_area, + max_masks=sam_max_masks, + device=device, + ) + + # HashCompressor expects DINO features (1024 dim for dinov2-large) + dino_dim = 1024 if "large" in dino_model else 768 + self.hash_compressor = HashCompressor( + input_dim=dino_dim, hash_bits=hash_bits + ).to(device) + + # Load pretrained compressor if provided + if compressor_path is not None: + self.hash_compressor.load_state_dict( + torch.load(compressor_path, map_location=device) + ) + print(f"[OK] Loaded HashCompressor from {compressor_path}") + + self.dino = DinoCompressor( + model_name=dino_model, + compressor=self.hash_compressor, + device=device, + ) + + def forward(self, image: Image.Image) -> torch.Tensor: + """Process a single image through the complete pipeline. 
+ + Args: + image: Input PIL Image + + Returns: + Binary hash codes [N, hash_bits] where N is number of detected objects + """ + # Step 1: SAM - extract and crop objects + cropped_objects = self.segmentor(image) + + if len(cropped_objects) == 0: + # No objects detected, return empty tensor + return torch.empty( + 0, self.hash_compressor.hash_bits, dtype=torch.int32, device=self.device + ) + + # Step 2: DINO - extract features from cropped objects + # Step 3: HashCompressor - compress features to binary codes + hash_codes = self.dino.encode(cropped_objects) + + return hash_codes + + def extract_features( + self, image: Image.Image, use_hash: bool = False + ) -> torch.Tensor: + """Extract features from image with optional hash compression. + + Args: + image: Input PIL Image + use_hash: If True, return binary hash codes; else return DINO features + + Returns: + Features [N, dim] where dim is 1024 (DINO) or 512 (hash) + """ + cropped_objects = self.segmentor(image) + + if len(cropped_objects) == 0: + dim = self.hash_compressor.hash_bits if use_hash else 1024 + return torch.empty(0, dim, device=self.device) + + if use_hash: + return self.dino.encode(cropped_objects) + else: + return self.dino.extract_features(cropped_objects) + + def extract_masks(self, image: Image.Image) -> list[torch.Tensor]: + """Extract only masks without full processing (for debugging). + + Args: + image: Input PIL Image + + Returns: + List of binary masks [H, W] + """ + return self.segmentor.extract_masks(image) diff --git a/mini-nav/compressors/segament_compressor.py b/mini-nav/compressors/segament_compressor.py index e69de29..9b32ad0 100644 --- a/mini-nav/compressors/segament_compressor.py +++ b/mini-nav/compressors/segament_compressor.py @@ -0,0 +1,180 @@ +"""Segment Anything 2 feature extractor with mask filtering and image cropping. + +Extracts object masks from images using SAM2.1, filters by area and confidence, +then crops the original image to obtain individual object regions. 
+""" + +from typing import Optional + +import numpy as np +import torch +import torch.nn as nn +from PIL import Image +from transformers import AutoModelForMaskGeneration, AutoProcessor + + +class SegmentCompressor(nn.Module): + """SAM2.1 based segmenter with mask filtering. + + Extracts object masks from images, filters by area and confidence, + and crops the original image to produce individual object patches. + """ + + def __init__( + self, + model_name: str = "facebook/sam2.1-hiera-large", + min_mask_area: int = 100, + max_masks: int = 10, + device: Optional[str] = None, + ): + """Initialize SAM2.1 segmenter. + + Args: + model_name: HuggingFace model name for SAM2.1 + min_mask_area: Minimum mask pixel area threshold + max_masks: Maximum number of masks to keep + device: Device to load model on (auto-detect if None) + """ + super().__init__() + + self.model_name = model_name + self.min_mask_area = min_mask_area + self.max_masks = max_masks + + # Auto detect device + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.device = torch.device(device) + + # Load SAM model and processor + self.processor = AutoProcessor.from_pretrained(model_name) + self.model = AutoModelForMaskGeneration.from_pretrained(model_name).to( + self.device + ) + self.model.eval() + + def forward(self, image: Image.Image) -> list[Image.Image]: + """Extract object masks and crop object regions. 
+ + Args: + image: Input PIL Image + + Returns: + List of cropped object images (one per valid mask) + """ + # Run SAM inference + inputs = self.processor(image, return_tensors="pt").to(self.device) + + with torch.no_grad(): + outputs = self.model(**inputs) + + # Post-process masks + masks = self.processor.post_process_masks( + outputs.pred_masks, + inputs["original_sizes"], + inputs["reshaped_input_sizes"], + )[0] + + # Filter masks by area and confidence + valid_masks = self._filter_masks(masks) + + if len(valid_masks) == 0: + return [] + + # Crop object regions from original image + cropped_objects = self._crop_objects(image, valid_masks) + + return cropped_objects + + def _filter_masks(self, masks: torch.Tensor) -> list[dict]: + """Filter masks by area and keep top-N. + + Args: + masks: Predicted masks [N, H, W] + + Returns: + List of mask dictionaries with 'mask' and 'area' + """ + valid_masks = [] + + for mask in masks: + # Calculate mask area + area = mask.sum().item() + + # Filter by minimum area + if area < self.min_mask_area: + continue + + valid_masks.append({"mask": mask, "area": area}) + + # Sort by area (descending) and keep top-N + valid_masks = sorted(valid_masks, key=lambda x: x["area"], reverse=True) + valid_masks = valid_masks[: self.max_masks] + + return valid_masks + + def _crop_objects( + self, image: Image.Image, masks: list[dict] + ) -> list[Image.Image]: + """Crop object regions from image using masks. 
+ + Args: + image: Original PIL Image + masks: List of mask dictionaries + + Returns: + List of cropped object images + """ + # Convert PIL to numpy for processing + image_np = np.array(image) + h, w = image_np.shape[:2] + + cropped_objects = [] + + for mask_info in masks: + mask = mask_info["mask"].cpu().numpy() + + # Find bounding box from mask + rows = mask.any(axis=1) + cols = mask.any(axis=0) + + if not rows.any() or not cols.any(): + continue + + y_min, y_max = rows.argmax(), h - rows[::-1].argmax() - 1 + x_min, x_max = cols.argmax(), w - cols[::-1].argmax() - 1 + + # Add small padding + pad = 5 + x_min = max(0, x_min - pad) + y_min = max(0, y_min - pad) + x_max = min(w, x_max + pad) + y_max = min(h, y_max + pad) + + # Crop + cropped = image.crop((x_min, y_min, x_max, y_max)) + cropped_objects.append(cropped) + + return cropped_objects + + @torch.no_grad() + def extract_masks(self, image: Image.Image) -> list[torch.Tensor]: + """Extract only masks without cropping (for debugging). 
+ + Args: + image: Input PIL Image + + Returns: + List of binary masks [H, W] + """ + inputs = self.processor(image, return_tensors="pt").to(self.device) + outputs = self.model(**inputs) + + masks = self.processor.post_process_masks( + outputs.pred_masks, + inputs["original_sizes"], + inputs["reshaped_input_sizes"], + )[0] + + valid_masks = self._filter_masks(masks) + return [m["mask"] for m in valid_masks] diff --git a/mini-nav/configs/config.yaml b/mini-nav/configs/config.yaml index 9c05347..230f99a 100644 --- a/mini-nav/configs/config.yaml +++ b/mini-nav/configs/config.yaml @@ -2,6 +2,10 @@ model: name: "facebook/dinov2-large" compression_dim: 512 device: "auto" # auto-detect GPU + sam_model: "facebook/sam2.1-hiera-large" # SAM model name + sam_min_mask_area: 100 # Minimum mask area threshold + sam_max_masks: 10 # Maximum number of masks to keep + compressor_path: null # Path to trained HashCompressor weights (optional) output: directory: "./outputs" diff --git a/mini-nav/configs/models.py b/mini-nav/configs/models.py index a87d837..ec26b2b 100644 --- a/mini-nav/configs/models.py +++ b/mini-nav/configs/models.py @@ -1,6 +1,7 @@ """Pydantic data models for feature compressor configuration.""" from pathlib import Path +from typing import Optional from pydantic import BaseModel, ConfigDict, Field, field_validator @@ -15,6 +16,19 @@ class ModelConfig(BaseModel): default=512, gt=0, description="Output feature dimension" ) device: str = "auto" + sam_model: str = Field( + default="facebook/sam2.1-hiera-large", + description="SAM model name from HuggingFace", + ) + sam_min_mask_area: int = Field( + default=100, gt=0, description="Minimum mask area threshold" + ) + sam_max_masks: int = Field( + default=10, gt=0, description="Maximum number of masks to keep" + ) + compressor_path: Optional[str] = Field( + default=None, description="Path to trained HashCompressor weights" + ) class OutputConfig(BaseModel): diff --git a/mini-nav/tests/test_compressors.py 
"""Tests for compressor modules (SAM, DINO, HashCompressor, Pipeline)."""

import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest
import torch
from PIL import Image

from configs import cfg_manager
from compressors import (
    BinarySign,
    DinoCompressor,
    HashCompressor,
    SegmentCompressor,
    SAMHashPipeline,
    create_pipeline_from_config,
    bits_to_hash,
    hash_to_bits,
    hamming_distance,
    hamming_similarity,
)


class TestHashCompressor:
    """Test suite for HashCompressor (pure torch, no model downloads)."""

    def test_hash_compressor_init(self):
        """Verify HashCompressor initializes with correct dimensions."""
        compressor = HashCompressor(input_dim=1024, hash_bits=512)
        assert compressor.input_dim == 1024
        assert compressor.hash_bits == 512

    def test_hash_compressor_forward(self):
        """Verify forward pass produces correct output shapes."""
        compressor = HashCompressor(input_dim=1024, hash_bits=512)
        tokens = torch.randn(4, 197, 1024)  # [B, N, input_dim]

        logits, hash_codes, bits = compressor(tokens)

        assert logits.shape == (4, 512)
        assert hash_codes.shape == (4, 512)
        assert bits.shape == (4, 512)
        # Bits must be strictly binary (0 or 1).
        assert torch.all((bits == 0) | (bits == 1))

    def test_hash_compressor_encode(self):
        """Verify encode method returns int32 binary bits."""
        compressor = HashCompressor(input_dim=1024, hash_bits=512)
        tokens = torch.randn(2, 197, 1024)

        bits = compressor.encode(tokens)

        assert bits.shape == (2, 512)
        assert bits.dtype == torch.int32
        assert torch.all((bits == 0) | (bits == 1))

    def test_hash_compressor_similarity(self):
        """Verify compute_similarity returns an [N1, N2] matrix."""
        compressor = HashCompressor(input_dim=1024, hash_bits=512)
        bits1 = torch.randint(0, 2, (3, 512))
        bits2 = torch.randint(0, 2, (5, 512))

        sim = compressor.compute_similarity(bits1, bits2)

        assert sim.shape == (3, 5)


class TestBinarySign:
    """Test suite for the BinarySign straight-through function."""

    def test_binary_sign_forward(self):
        """Verify BinarySign produces {-1, +1} outputs."""
        x = torch.randn(4, 512)
        result = BinarySign.apply(x)
        assert torch.all((result == 1) | (result == -1))

    def test_binary_sign_round_trip(self):
        """Verify bits -> hash -> bits preserves values."""
        bits = torch.randint(0, 2, (4, 512))
        recovered = hash_to_bits(bits_to_hash(bits))
        assert torch.equal(bits, recovered)


class TestHammingMetrics:
    """Test suite for Hamming distance and similarity."""

    def test_hamming_distance_same_codes(self):
        """Hamming distance is 0 for identical single codes."""
        bits1 = torch.randint(0, 2, (512,))
        bits2 = bits1.clone()
        assert hamming_distance(bits1, bits2).item() == 0

    def test_hamming_distance_self_comparison(self):
        """Diagonal of the pairwise distance matrix is 0 (self-distance)."""
        bits = torch.randint(0, 2, (10, 512))
        dist = hamming_distance(bits, bits)
        assert torch.all(torch.diag(dist) == 0)

    def test_hamming_distance_different(self):
        """Fully-opposite codes are at maximum distance."""
        zeros = torch.zeros(1, 512, dtype=torch.int32)
        ones = torch.ones(1, 512, dtype=torch.int32)
        assert hamming_distance(zeros, ones).item() == 512

    def test_hamming_similarity(self):
        """Identical hash codes reach maximum similarity."""
        hash1 = torch.ones(1, 512)
        hash2 = torch.ones(1, 512)
        assert hamming_similarity(hash1, hash2).item() == 512


class TestSegmentCompressor:
    """Test suite for SegmentCompressor.

    NOTE: constructing SegmentCompressor downloads the SAM checkpoint,
    so these tests are marked slow, consistent with the integration tests.
    """

    @pytest.fixture
    def mock_image(self):
        """Create a small solid-color PIL image."""
        return Image.new("RGB", (224, 224), color="red")

    @pytest.mark.slow
    def test_segment_compressor_init(self):
        """Verify SegmentCompressor stores its constructor parameters."""
        segmentor = SegmentCompressor(
            model_name="facebook/sam2.1-hiera-large",
            min_mask_area=100,
            max_masks=10,
        )

        assert segmentor.model_name == "facebook/sam2.1-hiera-large"
        assert segmentor.min_mask_area == 100
        assert segmentor.max_masks == 10

    @pytest.mark.slow
    def test_filter_masks(self):
        """Verify mask filtering logic (area threshold + sort order)."""
        segmentor = SegmentCompressor()

        # Build masks with known areas: [50, 200, 150, 300, 10] pixels.
        masks = []
        for area in [50, 200, 150, 300, 10]:
            mask = torch.zeros(100, 100)
            mask[:1, :area] = 1
            masks.append(mask)

        valid = segmentor._filter_masks(torch.stack(masks))

        # Areas 50 and 10 fall below min_mask_area=100, leaving 3 masks;
        # max_masks=10 does not truncate further.
        assert len(valid) == 3
        areas = [v["area"] for v in valid]
        assert areas == sorted(areas, reverse=True)


class TestDinoCompressor:
    """Test suite for DinoCompressor (downloads dinov2-large -> slow)."""

    @pytest.mark.slow
    def test_dino_compressor_init(self):
        """Verify DinoCompressor records its default model name."""
        dino = DinoCompressor()
        assert dino.model_name == "facebook/dinov2-large"

    @pytest.mark.slow
    def test_dino_compressor_with_compressor(self):
        """Verify DinoCompressor keeps a reference to its HashCompressor."""
        hash_compressor = HashCompressor(input_dim=1024, hash_bits=512)
        dino = DinoCompressor(compressor=hash_compressor)
        assert dino.compressor is hash_compressor


class TestSAMHashPipeline:
    """Test suite for SAMHashPipeline (downloads SAM + DINO -> slow)."""

    @pytest.mark.slow
    def test_pipeline_init(self):
        """Verify pipeline initializes all three components."""
        pipeline = SAMHashPipeline(
            sam_model="facebook/sam2.1-hiera-large",
            dino_model="facebook/dinov2-large",
            hash_bits=512,
        )

        assert isinstance(pipeline.segmentor, SegmentCompressor)
        assert isinstance(pipeline.dino, DinoCompressor)
        assert isinstance(pipeline.hash_compressor, HashCompressor)

    @pytest.mark.slow
    def test_pipeline_hash_bits(self):
        """Verify pipeline propagates the hash_bits setting."""
        pipeline = SAMHashPipeline(hash_bits=256)
        assert pipeline.hash_compressor.hash_bits == 256


class TestConfigIntegration:
    """Test suite for config integration with pipeline."""

    @pytest.mark.slow
    def test_create_pipeline_from_config(self):
        """Verify pipeline can be created from config (loads models)."""
        config = cfg_manager.load()

        pipeline = create_pipeline_from_config(config)

        assert isinstance(pipeline, SAMHashPipeline)
        assert pipeline.hash_compressor.hash_bits == config.model.compression_dim

    def test_config_sam_settings(self):
        """Verify config contains the SAM-related settings."""
        config = cfg_manager.load()

        assert hasattr(config.model, "sam_model")
        assert hasattr(config.model, "sam_min_mask_area")
        assert hasattr(config.model, "sam_max_masks")
        assert config.model.sam_model == "facebook/sam2.1-hiera-large"
        assert config.model.sam_min_mask_area == 100
        assert config.model.sam_max_masks == 10


class TestPipelineIntegration:
    """Integration tests for the full pipeline (slow, downloads models)."""

    @pytest.mark.slow
    def test_pipeline_end_to_end(self):
        """Test full pipeline with actual models (slow test)."""
        if not torch.cuda.is_available():
            pytest.skip("Requires CUDA")

        image = Image.new("RGB", (640, 480), color=(128, 128, 128))

        pipeline = SAMHashPipeline(
            sam_model="facebook/sam2.1-hiera-large",
            dino_model="facebook/dinov2-large",
            hash_bits=512,
            sam_min_mask_area=100,
            sam_max_masks=5,
        )

        hash_codes = pipeline(image)

        assert hash_codes.dim() == 2
        assert hash_codes.shape[1] == 512
        assert torch.all((hash_codes == 0) | (hash_codes == 1))

    @pytest.mark.slow
    def test_extract_features_without_hash(self):
        """Test feature extraction without hash compression."""
        if not torch.cuda.is_available():
            pytest.skip("Requires CUDA")

        image = Image.new("RGB", (640, 480), color=(128, 128, 128))

        pipeline = SAMHashPipeline(
            sam_model="facebook/sam2.1-hiera-large",
            dino_model="facebook/dinov2-large",
        )

        features = pipeline.extract_features(image, use_hash=False)

        # dinov2-large features are 1024-dimensional.
        assert features.dim() == 2
        assert features.shape[1] == 1024

    @pytest.mark.slow
    def test_extract_masks_only(self):
        """Test mask extraction only."""
        if not torch.cuda.is_available():
            pytest.skip("Requires CUDA")

        image = Image.new("RGB", (640, 480), color=(128, 128, 128))

        pipeline = SAMHashPipeline(
            sam_model="facebook/sam2.1-hiera-large",
        )

        masks = pipeline.extract_masks(image)

        assert isinstance(masks, list)