Files
Mini-Nav/mini-nav/compressors/pipeline.py
SikongJueluo 4da08dc3d3 refactor(compressors): Simplify module by removing SAM/DINO separation code
- Remove dino_compressor.py and segament_compressor.py
- Rewrite pipeline.py to inline DINO into HashPipeline
- Maintain backward compatibility: SAMHashPipeline alias
- Update tests and benchmark.py
2026-03-07 22:55:13 +08:00

171 lines
4.9 KiB
Python

"""Hash compression pipeline with DINO feature extraction.
This pipeline extracts features using DINOv2 and compresses them
to binary hash codes using HashCompressor.
"""
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import AutoImageProcessor, AutoModel
def create_pipeline_from_config(config) -> "HashPipeline":
    """Build a :class:`HashPipeline` from a config object.

    Args:
        config: Configuration object with model settings (reads
            ``config.model.dino_model``, ``compression_dim``,
            ``compressor_path`` and ``device``)

    Returns:
        Initialized HashPipeline
    """
    model_cfg = config.model
    # "auto" means: let HashPipeline pick the device itself (None sentinel).
    device = None if model_cfg.device == "auto" else model_cfg.device
    return HashPipeline(
        dino_model=model_cfg.dino_model,
        hash_bits=model_cfg.compression_dim,
        compressor_path=model_cfg.compressor_path,
        device=device,
    )
class HashPipeline(nn.Module):
    """Pipeline: DINO features + Hash compression.

    Pipeline flow:
        PIL Image -> DINO (features) -> Hash (binary codes)

    Usage:
        # Initialize with config
        pipeline = HashPipeline(
            dino_model="facebook/dinov2-large",
            hash_bits=512,
        )

        # Process image
        image = Image.open("path/to/image.jpg")
        hash_bits = pipeline(image)  # [1, 512] binary bits
    """

    def __init__(
        self,
        dino_model: str = "facebook/dinov2-large",
        hash_bits: int = 512,
        compressor_path: Optional[str] = None,
        device: Optional[str] = None,
    ):
        """Initialize the pipeline.

        Args:
            dino_model: DINOv2 model name from HuggingFace
            hash_bits: Number of bits in hash code
            compressor_path: Optional path to trained HashCompressor weights
            device: Device to run models on (None = auto-detect CUDA)
        """
        super().__init__()

        # Auto detect device
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.dino_model = dino_model

        # Initialize DINO processor and model (frozen, eval-only)
        self.processor = AutoImageProcessor.from_pretrained(dino_model)
        self.dino = AutoModel.from_pretrained(dino_model).to(self.device)
        self.dino.eval()

        # Read the feature dimension from the loaded model's config instead of
        # guessing by name: the old `1024 if "large" else 768` heuristic was
        # wrong for e.g. dinov2-small (384) and dinov2-giant (1536). Keep the
        # heuristic only as a fallback for configs without `hidden_size`.
        self.dino_dim = getattr(
            self.dino.config, "hidden_size", 1024 if "large" in dino_model else 768
        )

        # Initialize HashCompressor
        self.hash_compressor = nn.Module()  # Placeholder, will be replaced
        self._init_hash_compressor(hash_bits, compressor_path)

    def _init_hash_compressor(
        self, hash_bits: int, compressor_path: Optional[str] = None
    ):
        """Build (and optionally load weights into) the hash compressor.

        Args:
            hash_bits: Number of bits in hash code
            compressor_path: Optional path to a saved state_dict
        """
        # Import here to avoid circular imports
        from .hash_compressor import HashCompressor

        compressor = HashCompressor(input_dim=self.dino_dim, hash_bits=hash_bits).to(
            self.device
        )

        # Load pretrained compressor if provided.
        if compressor_path is not None:
            # weights_only=True: the checkpoint is a plain state_dict of
            # tensors, so refuse pickle's arbitrary-object (and arbitrary
            # code) execution on a possibly untrusted file.
            compressor.load_state_dict(
                torch.load(
                    compressor_path, map_location=self.device, weights_only=True
                )
            )
            print(f"[OK] Loaded HashCompressor from {compressor_path}")

        # Replace the placeholder
        self.hash_compressor = compressor

    @property
    def hash_bits(self):
        """Return the number of hash bits (delegates to the compressor)."""
        return self.hash_compressor.hash_bits

    def forward(self, image: Image.Image) -> torch.Tensor:
        """Process a single image through the pipeline.

        Args:
            image: Input PIL Image

        Returns:
            Binary hash codes [1, hash_bits] as int32
        """
        inputs = self.processor(image, return_tensors="pt").to(self.device)
        # Inference only: keep both DINO and the compressor inside no_grad so
        # no autograd graph is built (the returned bits are discrete anyway).
        with torch.no_grad():
            outputs = self.dino(**inputs)
            tokens = outputs.last_hidden_state  # [1, N, dim]
            # Compress to hash codes
            _, _, bits = self.hash_compressor(tokens)
        return bits

    def encode(self, image: Image.Image) -> torch.Tensor:
        """Encode an image to binary hash bits.

        Alias for forward().

        Args:
            image: Input PIL Image

        Returns:
            Binary hash codes [1, hash_bits] as int32
        """
        return self.forward(image)

    def extract_features(self, image: Image.Image) -> torch.Tensor:
        """Extract DINO features from an image.

        Args:
            image: Input PIL Image

        Returns:
            DINO features [1, dino_dim], L2-normalized (mean over tokens)
        """
        inputs = self.processor(image, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.dino(**inputs)
            features = outputs.last_hidden_state.mean(dim=1)  # [1, dim]
            features = F.normalize(features, dim=-1)
        return features
# Backward compatibility alias: the former SAM/DINO split exposed a
# SAMHashPipeline class; existing imports keep working and now resolve
# to the unified HashPipeline.
SAMHashPipeline = HashPipeline