"""Common utilities for compressor modules."""

import torch
import torch.nn.functional as F


class BinarySign(torch.autograd.Function):
    """Binary sign function with a Straight-Through Estimator (STE).

    Forward: returns torch.sign(x), i.e. values in {-1, 0, +1}
        (note: exact zeros map to 0, not +1).
    Backward: passes gradients through unchanged, as if forward were
        the identity.

    For CAM storage, convert: bits = (sign_output + 1) / 2
    """

    @staticmethod
    def forward(ctx, x):
        # Nothing is saved for backward: the STE gradient is a pure
        # identity and never reads x, so saving it would only waste memory.
        return x.sign()

    @staticmethod
    def backward(ctx, grad_output):
        # STE: treat the non-differentiable sign() as identity.
        return grad_output.clone()


def hamming_distance(b1, b2):
    """Compute Hamming distance between binary codes.

    Args:
        b1: Binary codes {0,1}, shape [N, D] or [D]
        b2: Binary codes {0,1}, shape [M, D] or [D]

    Returns:
        Hamming distances, shape [N, M]; a scalar when both inputs are 1-D.
    """
    if b1.dim() == 1 and b2.dim() == 1:
        return (b1 != b2).sum()

    # Broadcast to [N, 1, D] vs [1, M, D] for pairwise comparison.
    return (b1.unsqueeze(1) != b2.unsqueeze(0)).sum(dim=-1)  # [N, M]


def hamming_similarity(h1, h2):
    """Compute Hamming similarity for {-1, +1} codes.

    Args:
        h1: Hash codes {-1, +1}, shape [N, D] or [D]
        h2: Hash codes {-1, +1}, shape [M, D] or [D]

    Returns:
        Similarity scores in [-D, D], shape [N, M]; a scalar when both
        inputs are 1-D. Higher is more similar.
    """
    if h1.dim() == 1 and h2.dim() == 1:
        return (h1 * h2).sum()

    return h1 @ h2.t()  # [N, M]


def bits_to_hash(b):
    """Convert {0,1} bits to {-1,+1} hash codes.

    Args:
        b: Binary bits {0,1}, any shape

    Returns:
        Hash codes {-1,+1}, same shape
    """
    return b * 2 - 1


def hash_to_bits(h):
    """Convert {-1,+1} hash codes to {0,1} bits.

    Args:
        h: Hash codes {-1,+1}, any shape

    Returns:
        Binary bits {0,1}, same shape (float-valued, since the result
        of (h + 1) / 2 is a true division).
    """
    return (h + 1) / 2
+"""Hash-based compressor for CAM-compatible binary codes. + +Converts DINO features to 512-bit binary hash codes suitable for +Content Addressable Memory (CAM) retrieval. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .common import BinarySign, hamming_similarity + + +class HashCompressor(nn.Module): + """Compress DINO tokens to 512-bit binary codes for CAM storage. + + Architecture: + tokens -> mean pool -> projection -> binary sign -> hash codes + + Output formats: + - logits: continuous values for training (before sign) + - hash_codes: {-1, +1} for similarity computation + - bits: {0, 1} for CAM storage + + Example: + >>> compressor = HashCompressor() + >>> tokens = torch.randn(4, 197, 1024) # DINO output + >>> logits, hash_codes, bits = compressor(tokens) + >>> bits.shape + torch.Size([4, 512]) + >>> bits.dtype + torch.int32 + """ + + def __init__(self, input_dim: int = 1024, hash_bits: int = 512): + """Initialize hash compressor. + + Args: + input_dim: Input feature dimension (DINO output = 1024) + hash_bits: Number of bits in hash code (CAM constraint = 512) + """ + super().__init__() + + self.input_dim = input_dim + self.hash_bits = hash_bits + + # Projection head: maps DINO features to hash logits + self.proj = nn.Sequential( + nn.Linear(input_dim, input_dim), + nn.LayerNorm(input_dim), + nn.GELU(), + nn.Linear(input_dim, hash_bits), + ) + + # Initialize last layer with smaller weights for stable training + nn.init.xavier_uniform_(self.proj[-1].weight, gain=0.1) + nn.init.zeros_(self.proj[-1].bias) + + def forward(self, tokens: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Forward pass producing hash codes. 
+ + Args: + tokens: DINO patch tokens, shape [B, N, input_dim] + + Returns: + Tuple of (logits, hash_codes, bits): + - logits: [B, hash_bits] continuous values for training + - hash_codes: [B, hash_bits] {-1, +1} values + - bits: [B, hash_bits] {0, 1} values for CAM storage + """ + # Pool tokens to single feature vector + pooled = tokens.mean(dim=1) # [B, input_dim] + + # Project to hash dimension + logits = self.proj(pooled) # [B, hash_bits] + + # Binary hash codes with STE for backprop + hash_codes = BinarySign.apply(logits) # [B, hash_bits] in {-1, +1} + + # Convert to bits for CAM storage + bits = (hash_codes > 0).int() # [B, hash_bits] in {0, 1} + + return logits, hash_codes, bits + + def encode(self, tokens: torch.Tensor) -> torch.Tensor: + """Encode tokens to binary bits for CAM storage. + + This is the inference-time method for database insertion. + + Args: + tokens: DINO patch tokens, shape [B, N, input_dim] + + Returns: + Binary bits [B, hash_bits] as int32 for CAM + """ + _, _, bits = self.forward(tokens) + return bits + + def compute_similarity(self, query_bits: torch.Tensor, db_bits: torch.Tensor) -> torch.Tensor: + """Compute Hamming similarity between query and database entries. + + Higher score = more similar (fewer differing bits). + + Args: + query_bits: Query bits {0,1}, shape [Q, hash_bits] + db_bits: Database bits {0,1}, shape [N, hash_bits] + + Returns: + Similarity scores [Q, N], range [0, hash_bits] + """ + # Convert bits to hash codes + query_hash = query_bits * 2 - 1 # {0,1} -> {-1,+1} + db_hash = db_bits * 2 - 1 + + return hamming_similarity(query_hash, db_hash) + + +class HashLoss(nn.Module): + """Batch-level retrieval loss for hash code learning. + + Combines three objectives: + 1. Contrastive: similar inputs have similar hash codes + 2. Distillation: hash preserves original DINO similarity structure + 3. Quantization: hash codes are close to binary {-1, +1} + + All losses are computed within batch - no full database retrieval needed. 
+ """ + + def __init__( + self, + contrastive_weight: float = 1.0, + distill_weight: float = 0.5, + quant_weight: float = 0.01, + temperature: float = 0.2, + ): + """Initialize loss function. + + Args: + contrastive_weight: Weight for contrastive loss + distill_weight: Weight for distillation loss + quant_weight: Weight for quantization loss + temperature: Temperature for contrastive similarity scaling + """ + super().__init__() + self.contrastive_weight = contrastive_weight + self.distill_weight = distill_weight + self.quant_weight = quant_weight + self.temperature = temperature + + def contrastive_loss( + self, + logits: torch.Tensor, + hash_codes: torch.Tensor, + positive_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + """InfoNCE-style contrastive loss within batch. + + Learns that positive pairs (similar images) have similar hash codes, + and negative pairs (different images) have dissimilar codes. + + Args: + logits: Continuous logits [B, hash_bits] + hash_codes: Binary hash codes {-1,+1} [B, hash_bits] + positive_mask: Boolean mask [B, B] where True indicates positive pair + If None, uses identity matrix (each sample is its own positive) + + Returns: + Scalar contrastive loss + """ + batch_size = logits.size(0) + device = logits.device + + # Use cosine similarity on continuous logits (more stable during training) + logits_norm = F.normalize(logits, dim=-1) + sim_matrix = logits_norm @ logits_norm.t() / self.temperature # [B, B] + + # Create positive mask: diagonal is always positive (self-similarity) + if positive_mask is None: + positive_mask = torch.eye(batch_size, device=device, dtype=torch.bool) + + # InfoNCE: for each sample, positives should have high similarity + # Mask out self-similarity for numerical stability + mask_self = torch.eye(batch_size, device=device, dtype=torch.bool) + sim_matrix_masked = sim_matrix.masked_fill(mask_self, float("-inf")) + + # For each anchor, positives are the target + # We use a symmetric formulation: each 
positive pair contributes + loss = 0.0 + num_positives = 0 + + for i in range(batch_size): + pos_indices = positive_mask[i].nonzero(as_tuple=True)[0] + if len(pos_indices) == 0: + continue + + # Numerator: similarity to positives + pos_sim = sim_matrix[i, pos_indices] # [num_positives] + + # Denominator: similarity to all negatives (including self as neg for stability) + neg_sim = sim_matrix_masked[i] # [B] + + # Log-sum-exp for numerical stability + max_sim = neg_sim.max() + log_denom = max_sim + torch.log(torch.exp(neg_sim - max_sim).sum()) + + # Loss for this anchor + loss += -pos_sim.mean() + log_denom + num_positives += 1 + + return loss / max(num_positives, 1) + + def distillation_loss( + self, + hash_codes: torch.Tensor, + teacher_embed: torch.Tensor, + ) -> torch.Tensor: + """Distillation loss preserving DINO similarity structure. + + Ensures that if two images are similar in DINO space, + they remain similar in hash space. + + Args: + hash_codes: Binary hash codes {-1,+1} [B, hash_bits] + teacher_embed: DINO embeddings [B, teacher_dim], assumed normalized + + Returns: + Scalar distillation loss + """ + hash_bits = hash_codes.size(-1) + + # Hash similarity: inner product of {-1,+1} gives range [-hash_bits, hash_bits] + hash_sim = hash_codes @ hash_codes.t() # [B, B] + hash_sim = hash_sim / hash_bits # Normalize to [-1, 1] + + # Teacher similarity: cosine (assumes teacher_embed is normalized) + teacher_sim = teacher_embed @ teacher_embed.t() # [B, B] + + # MSE between similarity matrices + loss = F.mse_loss(hash_sim, teacher_sim) + + return loss + + def quantization_loss(self, logits: torch.Tensor) -> torch.Tensor: + """Quantization loss pushing logits toward {-1, +1}. + + Without this, logits stay near 0 and sign() is unstable. 
+ + Args: + logits: Continuous logits [B, hash_bits] + + Returns: + Scalar quantization loss + """ + # Push |logit| toward 1 + return torch.mean(torch.abs(logits.abs() - 1)) + + def forward( + self, + logits: torch.Tensor, + hash_codes: torch.Tensor, + teacher_embed: torch.Tensor, + positive_mask: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, dict[str, float]]: + """Compute combined hash training loss. + + Args: + logits: Continuous logits [B, hash_bits] + hash_codes: Binary hash codes {-1,+1} [B, hash_bits] + teacher_embed: DINO embeddings [B, teacher_dim] + positive_mask: Optional positive pair mask [B, B] + + Returns: + Tuple of (total_loss, loss_components_dict) + """ + # Ensure teacher embeddings are normalized + teacher_embed = F.normalize(teacher_embed, dim=-1) + + # Compute individual losses + loss_cont = self.contrastive_loss(logits, hash_codes, positive_mask) + loss_distill = self.distillation_loss(hash_codes, teacher_embed) + loss_quant = self.quantization_loss(logits) + + # Combine + total_loss = ( + self.contrastive_weight * loss_cont + + self.distill_weight * loss_distill + + self.quant_weight * loss_quant + ) + + # Return components for logging + components = { + "contrastive": loss_cont.item(), + "distill": loss_distill.item(), + "quantization": loss_quant.item(), + "total": total_loss.item(), + } + + return total_loss, components + + +class VideoPositiveMask: + """Generate positive pair masks for video sequences. + + In indoor navigation, consecutive video frames are positive pairs + (same location, different viewpoint/lighting). + """ + + def __init__(self, temporal_window: int = 3): + """Initialize mask generator. + + Args: + temporal_window: Frames within this distance are considered positive + """ + self.temporal_window = temporal_window + + def from_frame_indices(self, frame_indices: torch.Tensor) -> torch.Tensor: + """Create positive mask from frame indices. 
class VideoPositiveMask:
    """Generate positive pair masks for video sequences.

    In indoor navigation, consecutive video frames are positive pairs
    (same location, different viewpoint/lighting).
    """

    def __init__(self, temporal_window: int = 3):
        """Initialize mask generator.

        Args:
            temporal_window: Frames within this distance are considered positive
        """
        self.temporal_window = temporal_window

    def from_frame_indices(self, frame_indices: torch.Tensor) -> torch.Tensor:
        """Create positive mask from frame indices.

        The diagonal (self-pairs) is intentionally kept True; the loss
        handles self-similarity separately.

        Args:
            frame_indices: Frame index for each sample [B]

        Returns:
            Boolean mask [B, B] where True indicates positive pair
        """
        # Pairwise absolute temporal distance via broadcasting:
        # [B, 1] - [1, B] -> [B, B]
        temporal_dist = (frame_indices.unsqueeze(1) - frame_indices.unsqueeze(0)).abs()

        # Positive if within temporal window
        return temporal_dist <= self.temporal_window

    def from_video_ids(
        self, video_ids: torch.Tensor, frame_indices: torch.Tensor
    ) -> torch.Tensor:
        """Create positive mask considering both video ID and frame index.

        Args:
            video_ids: Video ID for each sample [B]
            frame_indices: Frame index within video [B]

        Returns:
            Boolean mask [B, B] where True indicates positive pair
        """
        # Same video
        same_video = video_ids.unsqueeze(1) == video_ids.unsqueeze(0)  # [B, B]

        # Positive if same video AND temporally close (reuse the
        # temporal-window logic instead of duplicating it)
        return same_video & self.from_frame_indices(frame_indices)
a/mini-nav/compressors/train.py +++ b/mini-nav/compressors/train.py @@ -1,8 +1,10 @@ +"""Training script for hash compressor.""" + import os import torch import torch.nn.functional as F -from compressors import FloatCompressor +from compressors import HashCompressor, HashLoss from configs import cfg_manager from datasets import load_dataset from torch import nn @@ -41,9 +43,20 @@ def load_checkpoint(model: nn.Module, optimizer, path="checkpoint.pt"): def train( - dinov2: nn.Module, epoch_size: int, batch_size: int, checkpoint_path="checkpoint.pt" + epoch_size: int = 10, + batch_size: int = 64, + lr: float = 1e-4, + checkpoint_path: str = "hash_checkpoint.pt", ): - # Auto dectect device + """Train hash compressor with batch-level retrieval loss. + + Args: + epoch_size: Number of epochs to train + batch_size: Batch size for training + lr: Learning rate + checkpoint_path: Path to save/load checkpoints + """ + # Auto detect device device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Global variables @@ -60,17 +73,25 @@ def train( "facebook/dinov2-large", device_map=device ) - # Load model + # Load DINO model (frozen) dino = AutoModel.from_pretrained("facebook/dinov2-large", device_map=device) dino.eval() for p in dino.parameters(): p.requires_grad = False - # Load compressor model - compressor = FloatCompressor().to(device) + # Load hash compressor + compressor = HashCompressor(input_dim=1024, hash_bits=512).to(device) + + # Load loss function + loss_fn = HashLoss( + contrastive_weight=1.0, + distill_weight=0.5, + quant_weight=0.01, + temperature=0.2, + ) # Load optimizer - optimizer = torch.optim.AdamW(compressor.parameters(), lr=1e-4) + optimizer = torch.optim.AdamW(compressor.parameters(), lr=lr) # Auto load checkpoint output_dir = cfg_manager.get().output.directory @@ -99,32 +120,38 @@ def train( teacher_embed = F.normalize(teacher_embed, dim=-1) # [B,1024] # ---- student forward ---- - z512, recon = compressor(teacher_tokens) + logits, 
hash_codes, bits = compressor(teacher_tokens) # ---- loss ---- - mse_loss = F.mse_loss(recon, teacher_embed) - - cos_loss = 1 - F.cosine_similarity(recon, teacher_embed, dim=-1).mean() - - loss = mse_loss + cos_loss + total_loss, components = loss_fn( + logits=logits, + hash_codes=hash_codes, + teacher_embed=teacher_embed, + ) # ---- backward ---- optimizer.zero_grad() - loss.backward() + total_loss.backward() optimizer.step() - train_bar.set_postfix(loss=loss.item()) + # ---- logging ---- + train_bar.set_postfix( + loss=f"{components['total']:.4f}", + cont=f"{components['contrastive']:.2f}", + distill=f"{components['distill']:.3f}", + quant=f"{components['quantization']:.2f}", + ) # ---- periodic save ---- if global_step % save_every == 0: - save_checkpoint(compressor, optimizer, epoch, global_step) + save_checkpoint(compressor, optimizer, epoch, global_step, checkpoint_path) + except KeyboardInterrupt: print("\n⚠️ Training interrupted, saving checkpoint...") - - save_checkpoint(compressor, optimizer, epoch, global_step) - + save_checkpoint(compressor, optimizer, epoch, global_step, checkpoint_path) print("✅ Checkpoint saved. Exiting.") return - torch.save(compressor.state_dict(), output_dir / "compressor.pt") - print("✅ Final compressor saved") + # Save final model + torch.save(compressor.state_dict(), output_dir / "hash_compressor.pt") + print("✅ Final hash compressor saved") diff --git a/mini-nav/main.py b/mini-nav/main.py index 8dfccfc..af38e0d 100644 --- a/mini-nav/main.py +++ b/mini-nav/main.py @@ -10,9 +10,12 @@ if __name__ == "__main__": args = parser.parse_args() if args.action == "train": - from compressors import FloatCompressor, train + from compressors import train - train(FloatCompressor(), 1, 32) + # 启动训练 + train( + epoch_size=10, batch_size=64, lr=1e-4, checkpoint_path="hash_checkpoint.pt" + ) elif args.action == "benchmark": from benchmarks import evaluate