mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
feat(compressors): replace float/int compressors with hash-based compression for CAM
This commit is contained in:
@@ -1,6 +1,17 @@
|
|||||||
|
from .common import BinarySign, bits_to_hash, hamming_distance, hamming_similarity, hash_to_bits
|
||||||
from .dino_compressor import DinoCompressor
|
from .dino_compressor import DinoCompressor
|
||||||
from .float_compressor import FloatCompressor
|
from .hash_compressor import HashCompressor, HashLoss, VideoPositiveMask
|
||||||
from .int_compressor import IntCompressor
|
|
||||||
from .train import train
|
from .train import train
|
||||||
|
|
||||||
__all__ = ["train", "FloatCompressor", "IntCompressor", "DinoCompressor"]
|
__all__ = [
|
||||||
|
"train",
|
||||||
|
"DinoCompressor",
|
||||||
|
"HashCompressor",
|
||||||
|
"HashLoss",
|
||||||
|
"VideoPositiveMask",
|
||||||
|
"BinarySign",
|
||||||
|
"hamming_distance",
|
||||||
|
"hamming_similarity",
|
||||||
|
"bits_to_hash",
|
||||||
|
"hash_to_bits",
|
||||||
|
]
|
||||||
|
|||||||
87
mini-nav/compressors/common.py
Normal file
87
mini-nav/compressors/common.py
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
"""Common utilities for compressor modules."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class BinarySign(torch.autograd.Function):
    """Binary sign function with Straight-Through Estimator (STE).

    Forward: returns values strictly in {-1, +1} (zeros map to +1).
    Backward: passes gradients through unchanged, as if identity.

    For CAM storage, convert: bits = (sign_output + 1) / 2
    """

    @staticmethod
    def forward(ctx, x):
        # Map x >= 0 -> +1 and x < 0 -> -1.  Plain x.sign() would return 0
        # for exact zeros, breaking the {-1, +1} contract and making the
        # downstream (code > 0) bit conversion ambiguous at the boundary.
        return torch.where(x >= 0, torch.ones_like(x), -torch.ones_like(x))

    @staticmethod
    def backward(ctx, grad_output):
        # STE: treat the sign as identity for gradient purposes.  Nothing
        # from the forward pass is needed, so no tensors are saved.
        return grad_output.clone()
|
||||||
|
|
||||||
|
|
||||||
|
def hamming_distance(b1, b2):
    """Compute Hamming distance between binary codes.

    Args:
        b1: Binary codes {0,1}, shape [N, D] or [D]
        b2: Binary codes {0,1}, shape [M, D] or [D]

    Returns:
        Hamming distances, shape [N, M] or scalar
    """
    # Vector-vs-vector case: count differing positions directly.
    if b1.dim() == 1 and b2.dim() == 1:
        return b1.ne(b2).sum()

    # Pairwise case: broadcast [N, 1, D] against [1, M, D] and
    # reduce over the bit axis to get an [N, M] distance matrix.
    return b1.unsqueeze(1).ne(b2.unsqueeze(0)).sum(dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def hamming_similarity(h1, h2):
    """Compute Hamming similarity for {-1, +1} codes.

    Args:
        h1: Hash codes {-1, +1}, shape [N, D] or [D]
        h2: Hash codes {-1, +1}, shape [M, D] or [D]

    Returns:
        Similarity scores in [-D, D], shape [N, M] or scalar.
        Higher is more similar.
    """
    # 1-D inputs: elementwise product summed is the inner product.
    if h1.dim() == 1 and h2.dim() == 1:
        return h1.mul(h2).sum()

    # Batched inputs: a single matmul yields all pairwise inner products.
    return torch.matmul(h1, h2.transpose(0, 1))  # [N, M]
|
||||||
|
|
||||||
|
|
||||||
|
def bits_to_hash(b):
    """Convert {0,1} bits to {-1,+1} hash codes.

    Args:
        b: Binary bits {0,1}, any shape

    Returns:
        Hash codes {-1,+1}, same shape
    """
    # Affine map: 0 -> -1, 1 -> +1.
    return 2 * b - 1
|
||||||
|
|
||||||
|
|
||||||
|
def hash_to_bits(h):
    """Convert {-1,+1} hash codes to {0,1} bits.

    Args:
        h: Hash codes {-1,+1}, any shape

    Returns:
        Binary bits {0,1}, same shape
    """
    # Inverse affine map: -1 -> 0, +1 -> 1.  The result is floating
    # point, matching the original true-division behavior.
    return (h + 1) * 0.5
|
||||||
@@ -6,6 +6,12 @@ from transformers import AutoModel, Dinov2Model
|
|||||||
|
|
||||||
|
|
||||||
class DinoCompressor(nn.Module):
|
class DinoCompressor(nn.Module):
|
||||||
|
"""DINOv2 feature extractor with optional hash compression.
|
||||||
|
|
||||||
|
When compressor is None: returns normalized DINO embeddings.
|
||||||
|
When compressor is provided: returns binary hash bits for CAM storage.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, compressor: Optional[nn.Module] = None):
|
def __init__(self, compressor: Optional[nn.Module] = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -25,5 +31,6 @@ class DinoCompressor(nn.Module):
|
|||||||
if self.compressor is None:
|
if self.compressor is None:
|
||||||
return teacher_embed
|
return teacher_embed
|
||||||
|
|
||||||
feats, recon = self.compressor(teacher_tokens)
|
# HashCompressor returns (logits, hash_codes, bits)
|
||||||
return feats
|
_, _, bits = self.compressor(teacher_tokens)
|
||||||
|
return bits # [B, 512] binary bits for CAM
|
||||||
|
|||||||
@@ -1,27 +0,0 @@
|
|||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
|
|
||||||
class FloatCompressor(nn.Module):
    """Compress pooled DINO tokens to a normalized 512-d code.

    Also produces a 1024-d reconstruction of the pooled feature for
    autoencoder-style training.
    """

    def __init__(self):
        super().__init__()

        # Projection head: 1024 -> 512 with one hidden layer.
        self.proj = nn.Sequential(
            nn.Linear(1024, 1024),
            nn.LayerNorm(1024),
            nn.GELU(),
            nn.Linear(1024, 512),
        )

        # Decoder mapping the code back to the teacher dimension.
        self.recover = nn.Linear(512, 1024)

    def forward(self, tokens):
        """Compress a token sequence.

        Args:
            tokens: Patch tokens, shape [B, N, 1024]

        Returns:
            Tuple (code, reconstruction): L2-normalized code [B, 512]
            and its 1024-d reconstruction [B, 1024].
        """
        mean_feat = tokens.mean(dim=1)  # [B, 1024]

        code = self.proj(mean_feat)  # [B, 512]
        code = F.normalize(code, dim=-1)

        reconstructed = self.recover(code)  # [B, 1024]

        return code, reconstructed
|
|
||||||
364
mini-nav/compressors/hash_compressor.py
Normal file
364
mini-nav/compressors/hash_compressor.py
Normal file
@@ -0,0 +1,364 @@
|
|||||||
|
"""Hash-based compressor for CAM-compatible binary codes.
|
||||||
|
|
||||||
|
Converts DINO features to 512-bit binary hash codes suitable for
|
||||||
|
Content Addressable Memory (CAM) retrieval.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from .common import BinarySign, hamming_similarity
|
||||||
|
|
||||||
|
|
||||||
|
class HashCompressor(nn.Module):
    """Compress DINO tokens to 512-bit binary codes for CAM storage.

    Architecture:
        tokens -> mean pool -> projection -> binary sign -> hash codes

    Output formats:
        - logits: continuous values for training (before sign)
        - hash_codes: {-1, +1} for similarity computation
        - bits: {0, 1} for CAM storage

    Example:
        >>> compressor = HashCompressor()
        >>> tokens = torch.randn(4, 197, 1024)  # DINO output
        >>> logits, hash_codes, bits = compressor(tokens)
        >>> bits.shape
        torch.Size([4, 512])
        >>> bits.dtype
        torch.int32
    """

    def __init__(self, input_dim: int = 1024, hash_bits: int = 512):
        """Initialize hash compressor.

        Args:
            input_dim: Input feature dimension (DINO output = 1024)
            hash_bits: Number of bits in hash code (CAM constraint = 512)
        """
        super().__init__()

        self.input_dim = input_dim
        self.hash_bits = hash_bits

        # Projection head: maps DINO features to hash logits
        self.proj = nn.Sequential(
            nn.Linear(input_dim, input_dim),
            nn.LayerNorm(input_dim),
            nn.GELU(),
            nn.Linear(input_dim, hash_bits),
        )

        # Initialize last layer with smaller weights for stable training:
        # small logits keep the quantization loss from saturating early.
        nn.init.xavier_uniform_(self.proj[-1].weight, gain=0.1)
        nn.init.zeros_(self.proj[-1].bias)

    def forward(self, tokens: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward pass producing hash codes.

        Args:
            tokens: DINO patch tokens, shape [B, N, input_dim]

        Returns:
            Tuple of (logits, hash_codes, bits):
            - logits: [B, hash_bits] continuous values for training
            - hash_codes: [B, hash_bits] {-1, +1} values
            - bits: [B, hash_bits] {0, 1} values for CAM storage
        """
        # Pool tokens to single feature vector
        pooled = tokens.mean(dim=1)  # [B, input_dim]

        # Project to hash dimension
        logits = self.proj(pooled)  # [B, hash_bits]

        # Binary hash codes with STE for backprop
        hash_codes = BinarySign.apply(logits)  # [B, hash_bits] in {-1, +1}

        # Convert to bits for CAM storage (detached from autograd by > and .int())
        bits = (hash_codes > 0).int()  # [B, hash_bits] in {0, 1}

        return logits, hash_codes, bits

    def encode(self, tokens: torch.Tensor) -> torch.Tensor:
        """Encode tokens to binary bits for CAM storage.

        This is the inference-time method for database insertion.

        Args:
            tokens: DINO patch tokens, shape [B, N, input_dim]

        Returns:
            Binary bits [B, hash_bits] as int32 for CAM
        """
        _, _, bits = self.forward(tokens)
        return bits

    def compute_similarity(self, query_bits: torch.Tensor, db_bits: torch.Tensor) -> torch.Tensor:
        """Compute Hamming similarity between query and database entries.

        Higher score = more similar (fewer differing bits).

        Args:
            query_bits: Query bits {0,1}, shape [Q, hash_bits]
            db_bits: Database bits {0,1}, shape [N, hash_bits]

        Returns:
            Similarity scores [Q, N], range [-hash_bits, hash_bits]
            (inner product of {-1,+1} codes: hash_bits when identical,
            -hash_bits when every bit differs)
        """
        # Convert bits to hash codes
        query_hash = query_bits * 2 - 1  # {0,1} -> {-1,+1}
        db_hash = db_bits * 2 - 1

        return hamming_similarity(query_hash, db_hash)
|
||||||
|
|
||||||
|
|
||||||
|
class HashLoss(nn.Module):
    """Batch-level retrieval loss for hash code learning.

    Combines three objectives:
    1. Contrastive: similar inputs have similar hash codes
    2. Distillation: hash preserves original DINO similarity structure
    3. Quantization: hash codes are close to binary {-1, +1}

    All losses are computed within batch - no full database retrieval needed.
    """

    def __init__(
        self,
        contrastive_weight: float = 1.0,
        distill_weight: float = 0.5,
        quant_weight: float = 0.01,
        temperature: float = 0.2,
    ):
        """Initialize loss function.

        Args:
            contrastive_weight: Weight for contrastive loss
            distill_weight: Weight for distillation loss
            quant_weight: Weight for quantization loss
            temperature: Temperature for contrastive similarity scaling
        """
        super().__init__()
        self.contrastive_weight = contrastive_weight
        self.distill_weight = distill_weight
        self.quant_weight = quant_weight
        self.temperature = temperature

    def contrastive_loss(
        self,
        logits: torch.Tensor,
        hash_codes: torch.Tensor,
        positive_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """InfoNCE-style contrastive loss within batch.

        Learns that positive pairs (similar images) have similar hash codes,
        and negative pairs (different images) have dissimilar codes.

        Args:
            logits: Continuous logits [B, hash_bits]
            hash_codes: Binary hash codes {-1,+1} [B, hash_bits]
                NOTE(review): currently unused in this method — the loss is
                computed on the continuous logits only; confirm intentional.
            positive_mask: Boolean mask [B, B] where True indicates positive pair
                If None, uses identity matrix (each sample is its own positive)

        Returns:
            Scalar contrastive loss
        """
        batch_size = logits.size(0)
        device = logits.device

        # Use cosine similarity on continuous logits (more stable during training)
        logits_norm = F.normalize(logits, dim=-1)
        sim_matrix = logits_norm @ logits_norm.t() / self.temperature  # [B, B]

        # Create positive mask: diagonal is always positive (self-similarity)
        if positive_mask is None:
            positive_mask = torch.eye(batch_size, device=device, dtype=torch.bool)

        # InfoNCE: for each sample, positives should have high similarity
        # Mask out self-similarity for numerical stability
        mask_self = torch.eye(batch_size, device=device, dtype=torch.bool)
        sim_matrix_masked = sim_matrix.masked_fill(mask_self, float("-inf"))

        # For each anchor, positives are the target
        # We use a symmetric formulation: each positive pair contributes
        # NOTE(review): if every anchor has zero positives (or batch_size is 0),
        # `loss` stays a Python float 0.0 rather than a tensor — .backward()
        # on the combined loss would then fail; verify callers never hit this.
        loss = 0.0
        num_positives = 0  # counts anchors with >= 1 positive, not pairs

        for i in range(batch_size):
            pos_indices = positive_mask[i].nonzero(as_tuple=True)[0]
            if len(pos_indices) == 0:
                continue

            # Numerator: similarity to positives (taken from the UNmasked
            # matrix, so the diagonal self-similarity can serve as a positive)
            pos_sim = sim_matrix[i, pos_indices]  # [num_positives]

            # Denominator: similarity to all negatives (self excluded via -inf)
            # NOTE(review): with batch_size == 1 this row is entirely -inf,
            # making log_denom NaN — assumed batches are always > 1.
            neg_sim = sim_matrix_masked[i]  # [B]

            # Log-sum-exp for numerical stability
            max_sim = neg_sim.max()
            log_denom = max_sim + torch.log(torch.exp(neg_sim - max_sim).sum())

            # Loss for this anchor: -log(exp(pos)/sum(exp(neg))), averaged
            # over this anchor's positives
            loss += -pos_sim.mean() + log_denom
            num_positives += 1

        return loss / max(num_positives, 1)

    def distillation_loss(
        self,
        hash_codes: torch.Tensor,
        teacher_embed: torch.Tensor,
    ) -> torch.Tensor:
        """Distillation loss preserving DINO similarity structure.

        Ensures that if two images are similar in DINO space,
        they remain similar in hash space.

        Args:
            hash_codes: Binary hash codes {-1,+1} [B, hash_bits]
            teacher_embed: DINO embeddings [B, teacher_dim], assumed normalized
                (forward() normalizes before calling this method)

        Returns:
            Scalar distillation loss
        """
        hash_bits = hash_codes.size(-1)

        # Hash similarity: inner product of {-1,+1} gives range [-hash_bits, hash_bits]
        hash_sim = hash_codes @ hash_codes.t()  # [B, B]
        hash_sim = hash_sim / hash_bits  # Normalize to [-1, 1]

        # Teacher similarity: cosine (assumes teacher_embed is normalized)
        teacher_sim = teacher_embed @ teacher_embed.t()  # [B, B]

        # MSE between similarity matrices aligns the two similarity structures
        loss = F.mse_loss(hash_sim, teacher_sim)

        return loss

    def quantization_loss(self, logits: torch.Tensor) -> torch.Tensor:
        """Quantization loss pushing logits toward {-1, +1}.

        Without this, logits stay near 0 and sign() is unstable.

        Args:
            logits: Continuous logits [B, hash_bits]

        Returns:
            Scalar quantization loss
        """
        # Push |logit| toward 1: penalty is | |x| - 1 |, zero exactly at ±1
        return torch.mean(torch.abs(logits.abs() - 1))

    def forward(
        self,
        logits: torch.Tensor,
        hash_codes: torch.Tensor,
        teacher_embed: torch.Tensor,
        positive_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, dict[str, float]]:
        """Compute combined hash training loss.

        Args:
            logits: Continuous logits [B, hash_bits]
            hash_codes: Binary hash codes {-1,+1} [B, hash_bits]
            teacher_embed: DINO embeddings [B, teacher_dim]
            positive_mask: Optional positive pair mask [B, B]

        Returns:
            Tuple of (total_loss, loss_components_dict) — the dict holds
            detached float values for logging under keys "contrastive",
            "distill", "quantization", "total"
        """
        # Ensure teacher embeddings are normalized (idempotent if already so)
        teacher_embed = F.normalize(teacher_embed, dim=-1)

        # Compute individual losses
        loss_cont = self.contrastive_loss(logits, hash_codes, positive_mask)
        loss_distill = self.distillation_loss(hash_codes, teacher_embed)
        loss_quant = self.quantization_loss(logits)

        # Weighted combination
        total_loss = (
            self.contrastive_weight * loss_cont
            + self.distill_weight * loss_distill
            + self.quant_weight * loss_quant
        )

        # Return components for logging
        components = {
            "contrastive": loss_cont.item(),
            "distill": loss_distill.item(),
            "quantization": loss_quant.item(),
            "total": total_loss.item(),
        }

        return total_loss, components
|
||||||
|
|
||||||
|
|
||||||
|
class VideoPositiveMask:
    """Generate positive pair masks for video sequences.

    In indoor navigation, consecutive video frames are positive pairs
    (same location, different viewpoint/lighting).
    """

    def __init__(self, temporal_window: int = 3):
        """Initialize mask generator.

        Args:
            temporal_window: Frames within this distance are considered positive
        """
        self.temporal_window = temporal_window

    def from_frame_indices(self, frame_indices: torch.Tensor) -> torch.Tensor:
        """Create positive mask from frame indices.

        Args:
            frame_indices: Frame index for each sample [B]

        Returns:
            Boolean mask [B, B] where True indicates positive pair.
            The diagonal stays True on purpose — the loss handles
            self-similarity separately.
        """
        # Pairwise absolute frame distance via broadcasting:
        # [B, 1] - [1, B] -> [B, B]
        temporal_dist = (frame_indices.unsqueeze(1) - frame_indices.unsqueeze(0)).abs()

        # Positive if within temporal window
        return temporal_dist <= self.temporal_window

    def from_video_ids(
        self, video_ids: torch.Tensor, frame_indices: torch.Tensor
    ) -> torch.Tensor:
        """Create positive mask considering both video ID and frame index.

        Frames from different videos are never positives, even if their
        frame indices happen to be close.

        Args:
            video_ids: Video ID for each sample [B]
            frame_indices: Frame index within video [B]

        Returns:
            Boolean mask [B, B] where True indicates positive pair
        """
        # Same video
        same_video = video_ids.unsqueeze(1) == video_ids.unsqueeze(0)  # [B, B]

        # Temporal proximity
        temporal_dist = (frame_indices.unsqueeze(1) - frame_indices.unsqueeze(0)).abs()
        temporal_close = temporal_dist <= self.temporal_window

        # Positive if same video AND temporally close
        return same_video & temporal_close
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
|
|
||||||
class IntCompressor(nn.Module):
    """Placeholder integer compressor; the forward pass is not implemented."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Stub: intentionally does nothing and returns None.
        pass
|
|
||||||
0
mini-nav/compressors/segament_compressor.py
Normal file
0
mini-nav/compressors/segament_compressor.py
Normal file
@@ -1,8 +1,10 @@
|
|||||||
|
"""Training script for hash compressor."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from compressors import FloatCompressor
|
from compressors import HashCompressor, HashLoss
|
||||||
from configs import cfg_manager
|
from configs import cfg_manager
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -41,9 +43,20 @@ def load_checkpoint(model: nn.Module, optimizer, path="checkpoint.pt"):
|
|||||||
|
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
dinov2: nn.Module, epoch_size: int, batch_size: int, checkpoint_path="checkpoint.pt"
|
epoch_size: int = 10,
|
||||||
|
batch_size: int = 64,
|
||||||
|
lr: float = 1e-4,
|
||||||
|
checkpoint_path: str = "hash_checkpoint.pt",
|
||||||
):
|
):
|
||||||
# Auto dectect device
|
"""Train hash compressor with batch-level retrieval loss.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
epoch_size: Number of epochs to train
|
||||||
|
batch_size: Batch size for training
|
||||||
|
lr: Learning rate
|
||||||
|
checkpoint_path: Path to save/load checkpoints
|
||||||
|
"""
|
||||||
|
# Auto detect device
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
# Global variables
|
# Global variables
|
||||||
@@ -60,17 +73,25 @@ def train(
|
|||||||
"facebook/dinov2-large", device_map=device
|
"facebook/dinov2-large", device_map=device
|
||||||
)
|
)
|
||||||
|
|
||||||
# Load model
|
# Load DINO model (frozen)
|
||||||
dino = AutoModel.from_pretrained("facebook/dinov2-large", device_map=device)
|
dino = AutoModel.from_pretrained("facebook/dinov2-large", device_map=device)
|
||||||
dino.eval()
|
dino.eval()
|
||||||
for p in dino.parameters():
|
for p in dino.parameters():
|
||||||
p.requires_grad = False
|
p.requires_grad = False
|
||||||
|
|
||||||
# Load compressor model
|
# Load hash compressor
|
||||||
compressor = FloatCompressor().to(device)
|
compressor = HashCompressor(input_dim=1024, hash_bits=512).to(device)
|
||||||
|
|
||||||
|
# Load loss function
|
||||||
|
loss_fn = HashLoss(
|
||||||
|
contrastive_weight=1.0,
|
||||||
|
distill_weight=0.5,
|
||||||
|
quant_weight=0.01,
|
||||||
|
temperature=0.2,
|
||||||
|
)
|
||||||
|
|
||||||
# Load optimizer
|
# Load optimizer
|
||||||
optimizer = torch.optim.AdamW(compressor.parameters(), lr=1e-4)
|
optimizer = torch.optim.AdamW(compressor.parameters(), lr=lr)
|
||||||
|
|
||||||
# Auto load checkpoint
|
# Auto load checkpoint
|
||||||
output_dir = cfg_manager.get().output.directory
|
output_dir = cfg_manager.get().output.directory
|
||||||
@@ -99,32 +120,38 @@ def train(
|
|||||||
teacher_embed = F.normalize(teacher_embed, dim=-1) # [B,1024]
|
teacher_embed = F.normalize(teacher_embed, dim=-1) # [B,1024]
|
||||||
|
|
||||||
# ---- student forward ----
|
# ---- student forward ----
|
||||||
z512, recon = compressor(teacher_tokens)
|
logits, hash_codes, bits = compressor(teacher_tokens)
|
||||||
|
|
||||||
# ---- loss ----
|
# ---- loss ----
|
||||||
mse_loss = F.mse_loss(recon, teacher_embed)
|
total_loss, components = loss_fn(
|
||||||
|
logits=logits,
|
||||||
cos_loss = 1 - F.cosine_similarity(recon, teacher_embed, dim=-1).mean()
|
hash_codes=hash_codes,
|
||||||
|
teacher_embed=teacher_embed,
|
||||||
loss = mse_loss + cos_loss
|
)
|
||||||
|
|
||||||
# ---- backward ----
|
# ---- backward ----
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss.backward()
|
total_loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
train_bar.set_postfix(loss=loss.item())
|
# ---- logging ----
|
||||||
|
train_bar.set_postfix(
|
||||||
|
loss=f"{components['total']:.4f}",
|
||||||
|
cont=f"{components['contrastive']:.2f}",
|
||||||
|
distill=f"{components['distill']:.3f}",
|
||||||
|
quant=f"{components['quantization']:.2f}",
|
||||||
|
)
|
||||||
|
|
||||||
# ---- periodic save ----
|
# ---- periodic save ----
|
||||||
if global_step % save_every == 0:
|
if global_step % save_every == 0:
|
||||||
save_checkpoint(compressor, optimizer, epoch, global_step)
|
save_checkpoint(compressor, optimizer, epoch, global_step, checkpoint_path)
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print("\n⚠️ Training interrupted, saving checkpoint...")
|
print("\n⚠️ Training interrupted, saving checkpoint...")
|
||||||
|
save_checkpoint(compressor, optimizer, epoch, global_step, checkpoint_path)
|
||||||
save_checkpoint(compressor, optimizer, epoch, global_step)
|
|
||||||
|
|
||||||
print("✅ Checkpoint saved. Exiting.")
|
print("✅ Checkpoint saved. Exiting.")
|
||||||
return
|
return
|
||||||
|
|
||||||
torch.save(compressor.state_dict(), output_dir / "compressor.pt")
|
# Save final model
|
||||||
print("✅ Final compressor saved")
|
torch.save(compressor.state_dict(), output_dir / "hash_compressor.pt")
|
||||||
|
print("✅ Final hash compressor saved")
|
||||||
|
|||||||
@@ -10,9 +10,12 @@ if __name__ == "__main__":
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.action == "train":
|
if args.action == "train":
|
||||||
from compressors import FloatCompressor, train
|
from compressors import train
|
||||||
|
|
||||||
train(FloatCompressor(), 1, 32)
|
# 启动训练
|
||||||
|
train(
|
||||||
|
epoch_size=10, batch_size=64, lr=1e-4, checkpoint_path="hash_checkpoint.pt"
|
||||||
|
)
|
||||||
elif args.action == "benchmark":
|
elif args.action == "benchmark":
|
||||||
from benchmarks import evaluate
|
from benchmarks import evaluate
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user