feat(compressors): replace float/int compressors with hash-based compression for CAM

@@ -1,8 +1,10 @@
+"""Training script for hash compressor."""
+
 import os
 
 import torch
 import torch.nn.functional as F
-from compressors import FloatCompressor
+from compressors import HashCompressor, HashLoss
 from configs import cfg_manager
 from datasets import load_dataset
 from torch import nn
@@ -41,9 +43,20 @@ def load_checkpoint(model: nn.Module, optimizer, path="checkpoint.pt"):
 
 
 def train(
-    dinov2: nn.Module, epoch_size: int, batch_size: int, checkpoint_path="checkpoint.pt"
+    epoch_size: int = 10,
+    batch_size: int = 64,
+    lr: float = 1e-4,
+    checkpoint_path: str = "hash_checkpoint.pt",
 ):
-    # Auto dectect device
+    """Train hash compressor with batch-level retrieval loss.
+
+    Args:
+        epoch_size: Number of epochs to train
+        batch_size: Batch size for training
+        lr: Learning rate
+        checkpoint_path: Path to save/load checkpoints
+    """
+    # Auto detect device
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     # Global variables
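
With the backbone argument dropped (the script now loads the frozen DINOv2 itself, per the next hunk), a run can be launched with defaults alone. A minimal sketch; the module name train_hash is an assumption for illustration, not confirmed by the diff:

# Hypothetical entry point; module name assumed for illustration.
from train_hash import train

# All arguments are optional after this change; override as needed.
train(epoch_size=10, batch_size=64, lr=1e-4, checkpoint_path="hash_checkpoint.pt")
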
@@ -60,17 +73,25 @@ def train(
         "facebook/dinov2-large", device_map=device
     )
 
-    # Load model
+    # Load DINO model (frozen)
     dino = AutoModel.from_pretrained("facebook/dinov2-large", device_map=device)
     dino.eval()
     for p in dino.parameters():
         p.requires_grad = False
 
-    # Load compressor model
-    compressor = FloatCompressor().to(device)
+    # Load hash compressor
+    compressor = HashCompressor(input_dim=1024, hash_bits=512).to(device)
+
+    # Load loss function
+    loss_fn = HashLoss(
+        contrastive_weight=1.0,
+        distill_weight=0.5,
+        quant_weight=0.01,
+        temperature=0.2,
+    )
 
     # Load optimizer
-    optimizer = torch.optim.AdamW(compressor.parameters(), lr=1e-4)
+    optimizer = torch.optim.AdamW(compressor.parameters(), lr=lr)
 
     # Auto load checkpoint
     output_dir = cfg_manager.get().output.directory
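
The call sites pin down the public interfaces: HashCompressor(input_dim=1024, hash_bits=512) maps teacher tokens to a (logits, hash_codes, bits) triple, and HashLoss returns (total_loss, components) with 'contrastive', 'distill', and 'quantization' entries. The sketch below is merely consistent with those signatures; the straight-through binarization and the exact form of the three loss terms are assumptions, not the repository's implementation:

import torch
import torch.nn.functional as F
from torch import nn

class HashCompressorSketch(nn.Module):
    """Illustrative stand-in: pool tokens, project to hash_bits logits."""

    def __init__(self, input_dim: int = 1024, hash_bits: int = 512):
        super().__init__()
        self.proj = nn.Linear(input_dim, hash_bits)

    def forward(self, tokens: torch.Tensor):
        pooled = tokens.mean(dim=1)            # [B, N, D] -> [B, D]
        logits = self.proj(pooled)             # real-valued pre-binarization
        bits = (logits > 0).float()            # hard {0,1} codes for storage
        soft = torch.tanh(logits)              # relaxed codes in (-1, 1)
        # straight-through estimator: forward uses hard +/-1, backward uses tanh
        hash_codes = soft + ((2 * bits - 1) - soft).detach()
        return logits, hash_codes, bits

class HashLossSketch(nn.Module):
    """Illustrative weighted sum matching the logged component names."""

    def __init__(self, contrastive_weight=1.0, distill_weight=0.5,
                 quant_weight=0.01, temperature=0.2):
        super().__init__()
        self.cw, self.dw, self.qw = contrastive_weight, distill_weight, quant_weight
        self.t = temperature

    def forward(self, logits, hash_codes, teacher_embed):
        z = F.normalize(hash_codes, dim=-1)
        sim = z @ z.T
        # contrastive: spread distinct samples apart (the diagonal is the positive)
        contrastive = F.cross_entropy(sim / self.t,
                                      torch.arange(z.size(0), device=z.device))
        # distill: match the pairwise similarity structure of the normalized teacher
        distill = F.mse_loss(sim, teacher_embed @ teacher_embed.T)
        # quantization: push relaxed codes away from the undecided region near 0
        quant = (1 - torch.tanh(logits).abs()).mean()
        total = self.cw * contrastive + self.dw * distill + self.qw * quant
        components = {"total": total.item(), "contrastive": contrastive.item(),
                      "distill": distill.item(), "quantization": quant.item()}
        return total, components
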
@@ -99,32 +120,38 @@ def train(
                 teacher_embed = F.normalize(teacher_embed, dim=-1)  # [B,1024]
 
                 # ---- student forward ----
-                z512, recon = compressor(teacher_tokens)
+                logits, hash_codes, bits = compressor(teacher_tokens)
 
                 # ---- loss ----
-                mse_loss = F.mse_loss(recon, teacher_embed)
-
-                cos_loss = 1 - F.cosine_similarity(recon, teacher_embed, dim=-1).mean()
-
-                loss = mse_loss + cos_loss
+                total_loss, components = loss_fn(
+                    logits=logits,
+                    hash_codes=hash_codes,
+                    teacher_embed=teacher_embed,
+                )
 
                 # ---- backward ----
                 optimizer.zero_grad()
-                loss.backward()
+                total_loss.backward()
                 optimizer.step()
 
-                train_bar.set_postfix(loss=loss.item())
+                # ---- logging ----
+                train_bar.set_postfix(
+                    loss=f"{components['total']:.4f}",
+                    cont=f"{components['contrastive']:.2f}",
+                    distill=f"{components['distill']:.3f}",
+                    quant=f"{components['quantization']:.2f}",
+                )
 
                 # ---- periodic save ----
                 if global_step % save_every == 0:
-                    save_checkpoint(compressor, optimizer, epoch, global_step)
+                    save_checkpoint(compressor, optimizer, epoch, global_step, checkpoint_path)
 
     except KeyboardInterrupt:
         print("\n⚠️ Training interrupted, saving checkpoint...")
-
-        save_checkpoint(compressor, optimizer, epoch, global_step)
+        save_checkpoint(compressor, optimizer, epoch, global_step, checkpoint_path)
         print("✅ Checkpoint saved. Exiting.")
         return
 
-    torch.save(compressor.state_dict(), output_dir / "compressor.pt")
-    print("✅ Final compressor saved")
+    # Save final model
+    torch.save(compressor.state_dict(), output_dir / "hash_compressor.pt")
+    print("✅ Final hash compressor saved")
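
Downstream, the hard bits output is what the CAM would hold, and lookup reduces to a Hamming-distance comparison. A hypothetical retrieval sketch over a stored gallery; names and shapes are assumed, not from the diff:

import torch

def hamming_topk(query_bits: torch.Tensor, db_bits: torch.Tensor, k: int = 5):
    """query_bits: [hash_bits] in {0,1}; db_bits: [N, hash_bits]. Top-k nearest."""
    dist = (db_bits != query_bits).sum(dim=-1)   # Hamming distance per entry
    return torch.topk(dist, k, largest=False).indices

# db_bits would be collected from compressor(...)'s third output over the gallery;
# a smaller Hamming distance means closer in the learned hash space.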