feat(compressors): replace float/int compressors with hash-based compression for CAM

2026-02-20 20:29:38 +08:00
parent 5f9d2bfcd8
commit 1926cb53e2
9 changed files with 527 additions and 64 deletions


@@ -1,8 +1,10 @@
 """Training script for hash compressor."""
 import os
 import torch
 import torch.nn.functional as F
-from compressors import FloatCompressor
+from compressors import HashCompressor, HashLoss
 from configs import cfg_manager
 from datasets import load_dataset
 from torch import nn
@@ -41,9 +43,20 @@ def load_checkpoint(model: nn.Module, optimizer, path="checkpoint.pt"):
 def train(
-    dinov2: nn.Module, epoch_size: int, batch_size: int, checkpoint_path="checkpoint.pt"
+    epoch_size: int = 10,
+    batch_size: int = 64,
+    lr: float = 1e-4,
+    checkpoint_path: str = "hash_checkpoint.pt",
 ):
-    # Auto dectect device
+    """Train hash compressor with batch-level retrieval loss.
+
+    Args:
+        epoch_size: Number of epochs to train
+        batch_size: Batch size for training
+        lr: Learning rate
+        checkpoint_path: Path to save/load checkpoints
+    """
+    # Auto detect device
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     # Global variables
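
Note: the dinov2 parameter is gone from the signature; as the next hunk shows, the teacher model is now loaded and frozen inside train itself. A minimal invocation with the diff's own defaults (the enclosing module's filename is not shown in this commit):

    train(epoch_size=10, batch_size=64, lr=1e-4, checkpoint_path="hash_checkpoint.pt")
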
@@ -60,17 +73,25 @@ def train(
         "facebook/dinov2-large", device_map=device
     )
-    # Load model
+    # Load DINO model (frozen)
+    dino = AutoModel.from_pretrained("facebook/dinov2-large", device_map=device)
+    dino.eval()
+    for p in dino.parameters():
+        p.requires_grad = False
-    # Load compressor model
-    compressor = FloatCompressor().to(device)
+    # Load hash compressor
+    compressor = HashCompressor(input_dim=1024, hash_bits=512).to(device)
+    # Load loss function
+    loss_fn = HashLoss(
+        contrastive_weight=1.0,
+        distill_weight=0.5,
+        quant_weight=0.01,
+        temperature=0.2,
+    )
     # Load optimizer
-    optimizer = torch.optim.AdamW(compressor.parameters(), lr=1e-4)
+    optimizer = torch.optim.AdamW(compressor.parameters(), lr=lr)
     # Auto load checkpoint
     output_dir = cfg_manager.get().output.directory
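
HashCompressor itself lives in one of the other changed files and does not appear in this diff. Below is a minimal sketch consistent with the constructor call above (input_dim=1024, hash_bits=512) and with the three values unpacked from its forward pass in the next hunk (logits, hash_codes, bits); the mean pooling and the hidden layer are assumptions, not the repository's actual implementation:

    import torch
    from torch import nn

    class HashCompressor(nn.Module):
        """Sketch: compress DINOv2 patch tokens into a binary hash code."""

        def __init__(self, input_dim: int = 1024, hash_bits: int = 512):
            super().__init__()
            self.proj = nn.Sequential(
                nn.Linear(input_dim, input_dim),
                nn.GELU(),
                nn.Linear(input_dim, hash_bits),
            )

        def forward(self, tokens: torch.Tensor):
            # tokens: [B, T, input_dim] patch tokens from the frozen teacher
            pooled = tokens.mean(dim=1)          # [B, input_dim]
            logits = self.proj(pooled)           # [B, hash_bits], pre-activation
            hash_codes = torch.tanh(logits)      # relaxed codes in (-1, 1), differentiable
            bits = (logits > 0).to(torch.uint8)  # hard {0, 1} bits for storage/retrieval
            return logits, hash_codes, bits
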
@@ -99,32 +120,38 @@ def train(
                 teacher_embed = F.normalize(teacher_embed, dim=-1)  # [B,1024]
                 # ---- student forward ----
-                z512, recon = compressor(teacher_tokens)
+                logits, hash_codes, bits = compressor(teacher_tokens)
                 # ---- loss ----
-                mse_loss = F.mse_loss(recon, teacher_embed)
-                cos_loss = 1 - F.cosine_similarity(recon, teacher_embed, dim=-1).mean()
-                loss = mse_loss + cos_loss
+                total_loss, components = loss_fn(
+                    logits=logits,
+                    hash_codes=hash_codes,
+                    teacher_embed=teacher_embed,
+                )
                 # ---- backward ----
                 optimizer.zero_grad()
-                loss.backward()
+                total_loss.backward()
                 optimizer.step()
-                train_bar.set_postfix(loss=loss.item())
+                # ---- logging ----
+                train_bar.set_postfix(
+                    loss=f"{components['total']:.4f}",
+                    cont=f"{components['contrastive']:.2f}",
+                    distill=f"{components['distill']:.3f}",
+                    quant=f"{components['quantization']:.2f}",
+                )
                 # ---- periodic save ----
                 if global_step % save_every == 0:
-                    save_checkpoint(compressor, optimizer, epoch, global_step)
+                    save_checkpoint(compressor, optimizer, epoch, global_step, checkpoint_path)
     except KeyboardInterrupt:
         print("\n⚠️ Training interrupted, saving checkpoint...")
-        save_checkpoint(compressor, optimizer, epoch, global_step)
+        save_checkpoint(compressor, optimizer, epoch, global_step, checkpoint_path)
         print("✅ Checkpoint saved. Exiting.")
         return
-    torch.save(compressor.state_dict(), output_dir / "compressor.pt")
-    print("✅ Final compressor saved")
+    # Save final model
+    torch.save(compressor.state_dict(), output_dir / "hash_compressor.pt")
+    print("✅ Final hash compressor saved")