feat(compressors): add SAM+DINO+Hash pipeline for object feature extraction

This commit is contained in:
2026-03-02 14:22:44 +08:00
parent 370c4a6588
commit a7b01cb49e
7 changed files with 753 additions and 8 deletions

View File

@@ -1,6 +1,8 @@
from .common import BinarySign, bits_to_hash, hamming_distance, hamming_similarity, hash_to_bits from .common import BinarySign, bits_to_hash, hamming_distance, hamming_similarity, hash_to_bits
from .dino_compressor import DinoCompressor from .dino_compressor import DinoCompressor
from .hash_compressor import HashCompressor, HashLoss, VideoPositiveMask from .hash_compressor import HashCompressor, HashLoss, VideoPositiveMask
from .pipeline import SAMHashPipeline, create_pipeline_from_config
from .segament_compressor import SegmentCompressor
from .train import train from .train import train
__all__ = [ __all__ = [
@@ -9,6 +11,9 @@ __all__ = [
"HashCompressor", "HashCompressor",
"HashLoss", "HashLoss",
"VideoPositiveMask", "VideoPositiveMask",
"SegmentCompressor",
"SAMHashPipeline",
"create_pipeline_from_config",
"BinarySign", "BinarySign",
"hamming_distance", "hamming_distance",
"hamming_similarity", "hamming_similarity",

View File

@@ -1,8 +1,10 @@
from typing import Optional, cast from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from PIL import Image
from transformers import AutoModel, Dinov2Model from transformers import AutoImageProcessor, AutoModel
class DinoCompressor(nn.Module): class DinoCompressor(nn.Module):
@@ -10,15 +12,34 @@ class DinoCompressor(nn.Module):
When compressor is None: returns normalized DINO embeddings. When compressor is None: returns normalized DINO embeddings.
When compressor is provided: returns binary hash bits for CAM storage. When compressor is provided: returns binary hash bits for CAM storage.
Supports both PIL Image input and pre-extracted tokens.
""" """
def __init__(self, compressor: Optional[nn.Module] = None): def __init__(
self,
model_name: str = "facebook/dinov2-large",
compressor: Optional[nn.Module] = None,
device: Optional[str] = None,
):
"""Initialize DINOv2 extractor.
Args:
model_name: HuggingFace model name
compressor: Optional hash compressor for producing binary codes
device: Device to load model on
"""
super().__init__() super().__init__()
self.dino = cast( # Auto detect device
Dinov2Model, if device is None:
AutoModel.from_pretrained("facebook/dinov2-large"), device = "cuda" if torch.cuda.is_available() else "cpu"
) self.device = torch.device(device)
self.model_name = model_name
self.processor = AutoImageProcessor.from_pretrained(model_name)
self.dino = AutoModel.from_pretrained(model_name).to(self.device)
self.dino.eval()
self.compressor = compressor self.compressor = compressor
@@ -34,3 +55,51 @@ class DinoCompressor(nn.Module):
# HashCompressor returns (logits, hash_codes, bits) # HashCompressor returns (logits, hash_codes, bits)
_, _, bits = self.compressor(teacher_tokens) _, _, bits = self.compressor(teacher_tokens)
return bits # [B, 512] binary bits for CAM return bits # [B, 512] binary bits for CAM
def extract_features(self, images: list[Image.Image]) -> torch.Tensor:
    """Compute L2-normalized DINO embeddings for a batch of object crops.

    Args:
        images: PIL images, typically object crops produced by SAM.

    Returns:
        Normalized feature tensor [N, hidden_size]; an empty tensor of
        the correct width when ``images`` is empty.
    """
    if not images:
        # Preserve the feature width so downstream code can concatenate safely.
        return torch.empty(0, self.dino.config.hidden_size, device=self.device)
    # Preprocess the whole batch in one processor call.
    batch = self.processor(images, return_tensors="pt").to(self.device)
    with torch.no_grad():
        hidden = self.dino(**batch).last_hidden_state
    # Mean-pool over the token axis to get one global vector per image,
    # then L2-normalize each row.
    pooled = hidden.mean(dim=1)
    return F.normalize(pooled, dim=-1)
def encode(self, images: list[Image.Image]) -> torch.Tensor:
    """Encode images to DINO features or, when a compressor is set, hash bits.

    Args:
        images: PIL images to encode.

    Returns:
        [N, hidden_size] normalized DINO features when no compressor is
        configured, otherwise [N, hash_bits] binary hash bits.
    """
    feats = self.extract_features(images)
    if self.compressor is None:
        return feats
    # The compressor consumes token sequences [B, N, dim]; treat each
    # pooled feature vector as a length-1 sequence.
    _, _, bits = self.compressor(feats.unsqueeze(1))
    return bits

View File

@@ -0,0 +1,170 @@
"""Complete pipeline for SAM + DINO + HashCompressor.
This pipeline extracts object masks from images using SAM2.1,
crops the objects, extracts features using DINOv2,
and compresses them to binary hash codes using HashCompressor.
"""
from pathlib import Path
from typing import Optional
import torch
import torch.nn as nn
from PIL import Image
from .dino_compressor import DinoCompressor
from .hash_compressor import HashCompressor
from .segament_compressor import SegmentCompressor
def create_pipeline_from_config(config) -> "SAMHashPipeline":
    """Build a SAMHashPipeline from a validated configuration object.

    Args:
        config: Configuration object with a ``model`` section providing
            sam_model, name, compression_dim, sam_min_mask_area,
            sam_max_masks, compressor_path and device.

    Returns:
        Initialized SAMHashPipeline.
    """
    model_cfg = config.model
    # "auto" means: let the pipeline auto-detect cuda/cpu itself.
    device = None if model_cfg.device == "auto" else model_cfg.device
    return SAMHashPipeline(
        sam_model=model_cfg.sam_model,
        dino_model=model_cfg.name,
        hash_bits=model_cfg.compression_dim,
        sam_min_mask_area=model_cfg.sam_min_mask_area,
        sam_max_masks=model_cfg.sam_max_masks,
        compressor_path=model_cfg.compressor_path,
        device=device,
    )
class SAMHashPipeline(nn.Module):
    """Complete pipeline: SAM segmentation + DINO features + Hash compression.

    Pipeline flow:
        Image -> SAM (extract masks) -> Crop objects -> DINO (features)
              -> Hash (binary codes)

    Usage:
        pipeline = SAMHashPipeline(
            sam_model="facebook/sam2.1-hiera-large",
            dino_model="facebook/dinov2-large",
            hash_bits=512,
        )
        image = Image.open("path/to/image.jpg")
        hash_codes = pipeline(image)  # [N, 512] binary bits
    """

    def __init__(
        self,
        sam_model: str = "facebook/sam2.1-hiera-large",
        dino_model: str = "facebook/dinov2-large",
        hash_bits: int = 512,
        sam_min_mask_area: int = 100,
        sam_max_masks: int = 10,
        compressor_path: Optional[str] = None,
        device: Optional[str] = None,
    ):
        """Initialize the complete pipeline.

        Args:
            sam_model: SAM model name from HuggingFace.
            dino_model: DINOv2 model name from HuggingFace.
            hash_bits: Number of bits in the hash code.
            sam_min_mask_area: Minimum mask area threshold (pixels).
            sam_max_masks: Maximum number of masks to keep per image.
            compressor_path: Optional path to trained HashCompressor weights.
            device: Device to run models on; auto-detected when None.
        """
        super().__init__()
        # Auto detect device
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        # Initialize SAM-based segmenter.
        self.segmentor = SegmentCompressor(
            model_name=sam_model,
            min_mask_area=sam_min_mask_area,
            max_masks=sam_max_masks,
            device=device,
        )

        # The compressor's input width must match the DINO hidden size.
        # Bug fix: the old check (`1024 if "large" in name else 768`)
        # silently built a mismatched compressor for small/giant variants.
        dino_dim = self._dino_hidden_dim(dino_model)
        self.hash_compressor = HashCompressor(
            input_dim=dino_dim, hash_bits=hash_bits
        ).to(device)

        # Load pretrained compressor weights if provided.
        if compressor_path is not None:
            self.hash_compressor.load_state_dict(
                torch.load(compressor_path, map_location=device)
            )
            print(f"[OK] Loaded HashCompressor from {compressor_path}")

        self.dino = DinoCompressor(
            model_name=dino_model,
            compressor=self.hash_compressor,
            device=device,
        )

    @staticmethod
    def _dino_hidden_dim(model_name: str) -> int:
        """Return the hidden size for a DINOv2 checkpoint name.

        Recognizes the four published variants by name substring; falls
        back to 768 (the previous default) for unrecognized names.
        """
        for tag, dim in (("giant", 1536), ("large", 1024), ("small", 384), ("base", 768)):
            if tag in model_name:
                return dim
        return 768  # historical fallback for unrecognized names

    def forward(self, image: Image.Image) -> torch.Tensor:
        """Process a single image through the complete pipeline.

        Args:
            image: Input PIL Image.

        Returns:
            Binary hash codes [N, hash_bits] where N is the number of
            detected objects (possibly 0).
        """
        # Step 1: SAM - extract and crop objects.
        cropped_objects = self.segmentor(image)
        if len(cropped_objects) == 0:
            # No objects detected: empty [0, hash_bits] result, int32 to
            # match the bit dtype produced by the compressor.
            return torch.empty(
                0, self.hash_compressor.hash_bits, dtype=torch.int32, device=self.device
            )
        # Steps 2+3: DINO feature extraction + hash compression.
        return self.dino.encode(cropped_objects)

    def extract_features(
        self, image: Image.Image, use_hash: bool = False
    ) -> torch.Tensor:
        """Extract features from an image with optional hash compression.

        Args:
            image: Input PIL Image.
            use_hash: If True, return binary hash codes; else DINO features.

        Returns:
            Features [N, dim] where dim is the DINO hidden size or hash_bits.
        """
        cropped_objects = self.segmentor(image)
        if len(cropped_objects) == 0:
            # Bug fix: use the actual DINO hidden size instead of assuming
            # 1024, so the empty result is consistent for non-"large" models.
            dim = (
                self.hash_compressor.hash_bits
                if use_hash
                else self.dino.dino.config.hidden_size
            )
            return torch.empty(0, dim, device=self.device)
        if use_hash:
            return self.dino.encode(cropped_objects)
        return self.dino.extract_features(cropped_objects)

    def extract_masks(self, image: Image.Image) -> list[torch.Tensor]:
        """Extract only masks without full processing (for debugging).

        Args:
            image: Input PIL Image.

        Returns:
            List of binary masks [H, W].
        """
        return self.segmentor.extract_masks(image)

View File

@@ -0,0 +1,180 @@
"""Segment Anything 2 feature extractor with mask filtering and image cropping.
Extracts object masks from images using SAM2.1, filters by area and confidence,
then crops the original image to obtain individual object regions.
"""
from typing import Optional
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from transformers import AutoModelForMaskGeneration, AutoProcessor
class SegmentCompressor(nn.Module):
    """SAM2.1-based segmenter with mask filtering.

    Extracts object masks from images, filters them by area, and crops
    the original image to produce individual object patches.
    """

    def __init__(
        self,
        model_name: str = "facebook/sam2.1-hiera-large",
        min_mask_area: int = 100,
        max_masks: int = 10,
        device: Optional[str] = None,
    ):
        """Initialize the SAM2.1 segmenter.

        Args:
            model_name: HuggingFace model name for SAM2.1.
            min_mask_area: Minimum mask pixel area threshold.
            max_masks: Maximum number of masks to keep.
            device: Device to load the model on (auto-detect if None).
        """
        super().__init__()
        self.model_name = model_name
        self.min_mask_area = min_mask_area
        self.max_masks = max_masks
        # Auto detect device
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        # Load SAM model and processor.
        # NOTE(review): SAM models normally expect point/box prompts; this
        # code calls the mask-generation model on a bare image and relies on
        # the processor's defaults — confirm this matches the checkpoint.
        self.processor = AutoProcessor.from_pretrained(model_name)
        self.model = AutoModelForMaskGeneration.from_pretrained(model_name).to(
            self.device
        )
        self.model.eval()

    def _predict_masks(self, image: Image.Image) -> torch.Tensor:
        """Run SAM on one image and return post-processed masks.

        Shared by forward() and extract_masks() so the inference and
        post-processing sequence lives in exactly one place (it was
        previously duplicated in both methods).
        """
        inputs = self.processor(image, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Take the first (and only) image of the batch.
        return self.processor.post_process_masks(
            outputs.pred_masks,
            inputs["original_sizes"],
            inputs["reshaped_input_sizes"],
        )[0]

    def forward(self, image: Image.Image) -> list[Image.Image]:
        """Extract object masks and crop object regions.

        Args:
            image: Input PIL Image.

        Returns:
            List of cropped object images (one per valid mask); empty list
            when no mask survives filtering.
        """
        masks = self._predict_masks(image)
        valid_masks = self._filter_masks(masks)
        if len(valid_masks) == 0:
            return []
        return self._crop_objects(image, valid_masks)

    def _filter_masks(self, masks: torch.Tensor) -> list[dict]:
        """Filter masks by area and keep the largest top-N.

        Args:
            masks: Predicted masks [N, H, W].

        Returns:
            List of dicts with keys 'mask' and 'area', sorted by area
            descending, truncated to ``max_masks`` entries.
        """
        valid_masks = []
        for mask in masks:
            # Pixel count of the mask (assumes binary/boolean values).
            area = mask.sum().item()
            if area < self.min_mask_area:
                continue
            valid_masks.append({"mask": mask, "area": area})
        # Largest objects first; keep at most max_masks.
        valid_masks = sorted(valid_masks, key=lambda x: x["area"], reverse=True)
        return valid_masks[: self.max_masks]

    def _crop_objects(
        self, image: Image.Image, masks: list[dict]
    ) -> list[Image.Image]:
        """Crop object regions from the image using mask bounding boxes.

        Args:
            image: Original PIL Image.
            masks: List of mask dictionaries (from _filter_masks).

        Returns:
            List of cropped object images.
        """
        # Numpy view only used for bbox computation; cropping uses PIL.
        image_np = np.array(image)
        h, w = image_np.shape[:2]
        cropped_objects = []
        for mask_info in masks:
            mask = mask_info["mask"].cpu().numpy()
            # Tight bounding box: first/last occupied row and column.
            rows = mask.any(axis=1)
            cols = mask.any(axis=0)
            if not rows.any() or not cols.any():
                continue  # degenerate (empty) mask — skip defensively
            y_min, y_max = rows.argmax(), h - rows[::-1].argmax() - 1
            x_min, x_max = cols.argmax(), w - cols[::-1].argmax() - 1
            # Small padding so the crop includes a little context; also
            # covers PIL's exclusive right/bottom crop bound.
            pad = 5
            x_min = max(0, x_min - pad)
            y_min = max(0, y_min - pad)
            x_max = min(w, x_max + pad)
            y_max = min(h, y_max + pad)
            cropped_objects.append(image.crop((x_min, y_min, x_max, y_max)))
        return cropped_objects

    @torch.no_grad()
    def extract_masks(self, image: Image.Image) -> list[torch.Tensor]:
        """Extract only filtered masks without cropping (for debugging).

        Args:
            image: Input PIL Image.

        Returns:
            List of binary masks [H, W].
        """
        valid_masks = self._filter_masks(self._predict_masks(image))
        return [m["mask"] for m in valid_masks]

View File

@@ -2,6 +2,10 @@ model:
name: "facebook/dinov2-large" name: "facebook/dinov2-large"
compression_dim: 512 compression_dim: 512
device: "auto" # auto-detect GPU device: "auto" # auto-detect GPU
sam_model: "facebook/sam2.1-hiera-large" # SAM model name
sam_min_mask_area: 100 # Minimum mask area threshold
sam_max_masks: 10 # Maximum number of masks to keep
compressor_path: null # Path to trained HashCompressor weights (optional)
output: output:
directory: "./outputs" directory: "./outputs"

View File

@@ -1,6 +1,7 @@
"""Pydantic data models for feature compressor configuration.""" """Pydantic data models for feature compressor configuration."""
from pathlib import Path from pathlib import Path
from typing import Optional
from pydantic import BaseModel, ConfigDict, Field, field_validator from pydantic import BaseModel, ConfigDict, Field, field_validator
@@ -15,6 +16,19 @@ class ModelConfig(BaseModel):
default=512, gt=0, description="Output feature dimension" default=512, gt=0, description="Output feature dimension"
) )
device: str = "auto" device: str = "auto"
sam_model: str = Field(
default="facebook/sam2.1-hiera-large",
description="SAM model name from HuggingFace",
)
sam_min_mask_area: int = Field(
default=100, gt=0, description="Minimum mask area threshold"
)
sam_max_masks: int = Field(
default=10, gt=0, description="Maximum number of masks to keep"
)
compressor_path: Optional[str] = Field(
default=None, description="Path to trained HashCompressor weights"
)
class OutputConfig(BaseModel): class OutputConfig(BaseModel):

View File

@@ -0,0 +1,303 @@
"""Tests for compressor modules (SAM, DINO, HashCompressor, Pipeline)."""
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
import torch
from PIL import Image
from configs import cfg_manager
from compressors import (
BinarySign,
DinoCompressor,
HashCompressor,
SegmentCompressor,
SAMHashPipeline,
create_pipeline_from_config,
bits_to_hash,
hash_to_bits,
hamming_distance,
hamming_similarity,
)
class TestHashCompressor:
    """Unit tests for HashCompressor."""

    def test_hash_compressor_init(self):
        """The constructor stores the requested dimensions."""
        hc = HashCompressor(input_dim=1024, hash_bits=512)
        assert (hc.input_dim, hc.hash_bits) == (1024, 512)

    def test_hash_compressor_forward(self):
        """Forward yields logits, codes and bits, all shaped [B, hash_bits]."""
        hc = HashCompressor(input_dim=1024, hash_bits=512)
        tokens = torch.randn(4, 197, 1024)  # [B, N, input_dim]
        logits, codes, bits = hc(tokens)
        for out in (logits, codes, bits):
            assert out.shape == (4, 512)
        # Bits must be strictly binary.
        assert torch.all((bits == 0) | (bits == 1))

    def test_hash_compressor_encode(self):
        """encode() returns int32 binary bits of shape [B, hash_bits]."""
        hc = HashCompressor(input_dim=1024, hash_bits=512)
        bits = hc.encode(torch.randn(2, 197, 1024))
        assert bits.shape == (2, 512)
        assert bits.dtype == torch.int32
        assert torch.all((bits == 0) | (bits == 1))

    def test_hash_compressor_similarity(self):
        """compute_similarity produces a full pairwise matrix."""
        hc = HashCompressor(input_dim=1024, hash_bits=512)
        lhs = torch.randint(0, 2, (3, 512))
        rhs = torch.randint(0, 2, (5, 512))
        assert hc.compute_similarity(lhs, rhs).shape == (3, 5)
class TestBinarySign:
    """Unit tests for BinarySign and the bit/hash conversion helpers."""

    def test_binary_sign_forward(self):
        """Every output value is either -1 or +1."""
        out = BinarySign.apply(torch.randn(4, 512))
        assert torch.all((out == -1) | (out == 1))

    def test_binary_sign_round_trip(self):
        """bits -> hash -> bits is lossless."""
        original = torch.randint(0, 2, (4, 512))
        recovered = hash_to_bits(bits_to_hash(original))
        assert torch.equal(recovered, original)
class TestHammingMetrics:
    """Unit tests for Hamming distance and similarity helpers."""

    def test_hamming_distance_same_codes(self):
        """Identical single codes are at distance zero."""
        code = torch.randint(0, 2, (512,))
        assert hamming_distance(code, code.clone()).item() == 0

    def test_hamming_distance_self_comparison(self):
        """The pairwise distance matrix has a zero diagonal."""
        codes = torch.randint(0, 2, (10, 512))
        pairwise = hamming_distance(codes, codes)
        assert torch.all(torch.diag(pairwise) == 0)

    def test_hamming_distance_different(self):
        """Fully complementary codes differ in every single bit."""
        zeros = torch.zeros(1, 512, dtype=torch.int32)
        ones = torch.ones(1, 512, dtype=torch.int32)
        assert hamming_distance(zeros, ones).item() == 512

    def test_hamming_similarity(self):
        """Identical hash codes reach the maximum similarity (512)."""
        code = torch.ones(1, 512)
        assert hamming_similarity(code, code.clone()).item() == 512
class TestSegmentCompressor:
    """Test suite for SegmentCompressor."""

    @pytest.fixture
    def mock_image(self):
        """Create a mock PIL image."""
        return Image.new("RGB", (224, 224), color="red")

    def test_segment_compressor_init(self):
        """Verify SegmentCompressor initializes with correct parameters."""
        segmentor = SegmentCompressor(
            model_name="facebook/sam2.1-hiera-large",
            min_mask_area=100,
            max_masks=10,
        )
        assert segmentor.model_name == "facebook/sam2.1-hiera-large"
        assert segmentor.min_mask_area == 100
        assert segmentor.max_masks == 10

    def test_filter_masks(self):
        """Verify mask filtering drops small masks and sorts by area.

        Bug fix: the previous version painted `mask[:1, :area] = 1`, which
        caps the painted area at the mask width (100 px), so the intended
        areas 200/150/300 all collapsed to 100 and the descending-sort
        assertion was vacuous. Painting through a flattened view yields
        the exact intended areas.
        """
        segmentor = SegmentCompressor()
        masks = []
        for area in [50, 200, 150, 300, 10]:
            mask = torch.zeros(100, 100)
            mask.view(-1)[:area] = 1  # paint exactly `area` pixels
            masks.append(mask)
        masks_tensor = torch.stack(masks)  # [5, 100, 100]
        valid = segmentor._filter_masks(masks_tensor)
        # 50 and 10 fall below min_mask_area=100; survivors are sorted
        # by area, descending.
        assert [v["area"] for v in valid] == [300, 200, 150]
class TestDinoCompressor:
    """Unit tests for DinoCompressor."""

    def test_dino_compressor_init(self):
        """The default checkpoint is dinov2-large."""
        extractor = DinoCompressor()
        assert extractor.model_name == "facebook/dinov2-large"

    def test_dino_compressor_with_compressor(self):
        """A supplied HashCompressor is stored as-is (same object)."""
        hasher = HashCompressor(input_dim=1024, hash_bits=512)
        extractor = DinoCompressor(compressor=hasher)
        assert extractor.compressor is hasher
class TestSAMHashPipeline:
    """Unit tests for SAMHashPipeline wiring."""

    def test_pipeline_init(self):
        """All three sub-modules are constructed with the expected types."""
        pipeline = SAMHashPipeline(
            sam_model="facebook/sam2.1-hiera-large",
            dino_model="facebook/dinov2-large",
            hash_bits=512,
        )
        components = [
            (pipeline.segmentor, SegmentCompressor),
            (pipeline.dino, DinoCompressor),
            (pipeline.hash_compressor, HashCompressor),
        ]
        for component, expected_cls in components:
            assert isinstance(component, expected_cls)

    def test_pipeline_hash_bits(self):
        """The hash_bits argument is forwarded to the HashCompressor."""
        pipeline = SAMHashPipeline(hash_bits=256)
        assert pipeline.hash_compressor.hash_bits == 256
class TestConfigIntegration:
    """Tests for config-driven pipeline construction."""

    def test_create_pipeline_from_config(self):
        """A pipeline built from config mirrors the config's hash width."""
        cfg = cfg_manager.load()
        pipeline = create_pipeline_from_config(cfg)
        assert isinstance(pipeline, SAMHashPipeline)
        assert pipeline.hash_compressor.hash_bits == cfg.model.compression_dim

    def test_config_sam_settings(self):
        """The loaded config exposes the expected SAM defaults."""
        cfg = cfg_manager.load()
        expected = {
            "sam_model": "facebook/sam2.1-hiera-large",
            "sam_min_mask_area": 100,
            "sam_max_masks": 10,
        }
        for attr, value in expected.items():
            assert hasattr(cfg.model, attr)
            assert getattr(cfg.model, attr) == value
class TestPipelineIntegration:
    """Integration tests for the full pipeline (slow; downloads models)."""

    @pytest.mark.slow
    def test_pipeline_end_to_end(self):
        """A real end-to-end run yields [N, 512] binary hash codes."""
        # These tests need a GPU; skip otherwise.
        if not torch.cuda.is_available():
            pytest.skip("Requires CUDA")
        gray = Image.new("RGB", (640, 480), color=(128, 128, 128))
        # Models are downloaded on first use.
        pipeline = SAMHashPipeline(
            sam_model="facebook/sam2.1-hiera-large",
            dino_model="facebook/dinov2-large",
            hash_bits=512,
            sam_min_mask_area=100,
            sam_max_masks=5,
        )
        codes = pipeline(gray)
        assert codes.dim() == 2
        assert codes.shape[1] == 512
        # Codes must be strictly binary.
        assert torch.all((codes == 0) | (codes == 1))

    @pytest.mark.slow
    def test_extract_features_without_hash(self):
        """Raw DINO features come back 1024-wide for the large model."""
        if not torch.cuda.is_available():
            pytest.skip("Requires CUDA")
        gray = Image.new("RGB", (640, 480), color=(128, 128, 128))
        pipeline = SAMHashPipeline(
            sam_model="facebook/sam2.1-hiera-large",
            dino_model="facebook/dinov2-large",
        )
        feats = pipeline.extract_features(gray, use_hash=False)
        assert feats.dim() == 2
        assert feats.shape[1] == 1024

    @pytest.mark.slow
    def test_extract_masks_only(self):
        """extract_masks returns a plain list of masks."""
        if not torch.cuda.is_available():
            pytest.skip("Requires CUDA")
        gray = Image.new("RGB", (640, 480), color=(128, 128, 128))
        pipeline = SAMHashPipeline(sam_model="facebook/sam2.1-hiera-large")
        assert isinstance(pipeline.extract_masks(gray), list)