refactor(compressors): Simplify module by removing SAM/DINO separation code

- Remove dino_compressor.py and segament_compressor.py
- Rewrite pipeline.py to inline DINO into HashPipeline
- Maintain backward compatibility: SAMHashPipeline alias
- Update tests and benchmark.py
This commit is contained in:
2026-03-07 21:33:42 +08:00
parent c8dc5f9301
commit 4da08dc3d3
8 changed files with 276 additions and 490 deletions

View File

@@ -1,78 +1,65 @@
"""Complete pipeline for SAM + DINO + HashCompressor.
"""Hash compression pipeline with DINO feature extraction.
This pipeline extracts object masks from images using SAM2.1,
crops the objects, extracts features using DINOv2,
and compresses them to binary hash codes using HashCompressor.
This pipeline extracts features using DINOv2 and compresses them
to binary hash codes using HashCompressor.
"""
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from .dino_compressor import DinoCompressor
from .hash_compressor import HashCompressor
from .segament_compressor import SegmentCompressor
from transformers import AutoImageProcessor, AutoModel
def create_pipeline_from_config(config) -> "HashPipeline":
    """Create a HashPipeline from a config object.

    Args:
        config: Configuration object exposing ``config.model`` with the
            attributes ``dino_model``, ``compression_dim``,
            ``compressor_path`` and ``device``.

    Returns:
        Initialized HashPipeline.
    """
    model_cfg = config.model

    # The string "auto" is forwarded as None, leaving device selection to the
    # HashPipeline constructor.
    device = None if model_cfg.device == "auto" else model_cfg.device

    return HashPipeline(
        dino_model=model_cfg.dino_model,
        hash_bits=model_cfg.compression_dim,
        compressor_path=model_cfg.compressor_path,
        device=device,
    )
class SAMHashPipeline(nn.Module):
"""Complete pipeline: SAM segmentation + DINO features + Hash compression.
class HashPipeline(nn.Module):
"""Pipeline: DINO features + Hash compression.
Pipeline flow:
Image -> SAM (extract masks) -> Crop objects -> DINO (features) -> Hash (binary codes)
PIL Image -> DINO (features) -> Hash (binary codes)
Usage:
# Initialize with config
pipeline = SAMHashPipeline(
sam_model="facebook/sam2.1-hiera-large",
pipeline = HashPipeline(
dino_model="facebook/dinov2-large",
hash_bits=512,
)
# Process image
image = Image.open("path/to/image.jpg")
hash_codes = pipeline(image) # [N, 512] binary bits
hash_bits = pipeline(image) # [1, 512] binary bits
"""
def __init__(
self,
sam_model: str = "facebook/sam2.1-hiera-large",
dino_model: str = "facebook/dinov2-large",
hash_bits: int = 512,
sam_min_mask_area: int = 100,
sam_max_masks: int = 10,
compressor_path: Optional[str] = None,
device: Optional[str] = None,
):
"""Initialize the complete pipeline.
"""Initialize the pipeline.
Args:
sam_model: SAM model name from HuggingFace
dino_model: DINOv2 model name from HuggingFace
hash_bits: Number of bits in hash code
sam_min_mask_area: Minimum mask area threshold
sam_max_masks: Maximum number of masks to keep
compressor_path: Optional path to trained HashCompressor weights
device: Device to run models on
"""
@@ -83,87 +70,101 @@ class SAMHashPipeline(nn.Module):
device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = torch.device(device)
# Initialize components
self.segmentor = SegmentCompressor(
model_name=sam_model,
min_mask_area=sam_min_mask_area,
max_masks=sam_max_masks,
device=device,
)
self.dino_model = dino_model
# HashCompressor expects DINO features (1024 dim for dinov2-large)
dino_dim = 1024 if "large" in dino_model else 768
self.hash_compressor = HashCompressor(
input_dim=dino_dim, hash_bits=hash_bits
).to(device)
# Initialize DINO processor and model
self.processor = AutoImageProcessor.from_pretrained(dino_model)
self.dino = AutoModel.from_pretrained(dino_model).to(self.device)
self.dino.eval()
# Determine DINO feature dimension
self.dino_dim = 1024 if "large" in dino_model else 768
# Initialize HashCompressor
self.hash_compressor = nn.Module() # Placeholder, will be replaced
self._init_hash_compressor(hash_bits, compressor_path)
def _init_hash_compressor(
    self, hash_bits: int, compressor_path: Optional[str] = None
):
    """Build the HashCompressor and install it as ``self.hash_compressor``.

    Called from ``__init__`` after a placeholder module has been assigned to
    ``self.hash_compressor``; this method replaces that placeholder with the
    real compressor.

    Args:
        hash_bits: Number of bits in the hash code.
        compressor_path: Optional path to trained HashCompressor weights.
            When given, the state dict is loaded into the new compressor.
    """
    # Import here to avoid circular imports
    from .hash_compressor import HashCompressor

    # NOTE(review): self.dino_dim is derived from the model name in __init__
    # (1024 when the name contains "large", otherwise 768) — verify that this
    # matches the hidden size of every DINOv2 variant that may be configured.
    compressor = HashCompressor(input_dim=self.dino_dim, hash_bits=hash_bits).to(
        self.device
    )

    # Load pretrained compressor if provided
    if compressor_path is not None:
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        state_dict = torch.load(compressor_path, map_location=self.device)
        compressor.load_state_dict(state_dict)
        print(f"[OK] Loaded HashCompressor from {compressor_path}")

    # Replace the placeholder
    self.hash_compressor = compressor
@property
def hash_bits(self):
    """Number of hash bits, as reported by the hash compressor."""
    compressor = self.hash_compressor
    return compressor.hash_bits
def forward(self, image: Image.Image) -> torch.Tensor:
    """Process a single image through the pipeline.

    The image is preprocessed with ``self.processor``, passed through the
    DINO model without gradient tracking, and the resulting token sequence
    is handed to ``self.hash_compressor``.

    Args:
        image: Input PIL Image

    Returns:
        The third value returned by the hash compressor; documented as
        binary hash codes [1, hash_bits] as int32.
    """
    # Extract DINO features
    inputs = self.processor(image, return_tensors="pt").to(self.device)

    # Only the DINO backbone runs under no_grad; the compressor call below
    # is outside the context.
    with torch.no_grad():
        outputs = self.dino(**inputs)
        tokens = outputs.last_hidden_state  # [1, N, dim]

    # Compress to hash codes
    _, _, bits = self.hash_compressor(tokens)

    return bits
def encode(self, image: Image.Image) -> torch.Tensor:
    """Encode an image to binary hash bits.

    Alias for forward().

    Args:
        image: Input PIL Image

    Returns:
        Whatever ``forward`` returns for the image; documented there as
        binary hash codes [1, hash_bits] as int32.
    """
    return self.forward(image)
def extract_features(self, image: Image.Image) -> torch.Tensor:
    """Return a pooled, unit-normalized DINO embedding for one image.

    The model's ``last_hidden_state`` is averaged over the token axis
    (dim=1) and the result is normalized along the last dimension.

    Args:
        image: Input PIL Image

    Returns:
        DINO features [1, dino_dim], normalized
    """
    model_inputs = self.processor(image, return_tensors="pt").to(self.device)

    with torch.no_grad():
        hidden = self.dino(**model_inputs).last_hidden_state
        pooled = hidden.mean(dim=1)  # [1, dim]
        return F.normalize(pooled, dim=-1)
# Backward compatibility alias: code that still imports SAMHashPipeline gets
# HashPipeline under the old name.
# NOTE(review): the alias preserves the name only. The HashPipeline constructor
# shown in this file takes no sam_* arguments, so callers that pass them need
# updating — confirm no such callers remain.
SAMHashPipeline = HashPipeline