"""Training script for hash compressor.""" import os import torch import torch.nn.functional as F from compressors import HashCompressor, HashLoss from configs import cfg_manager from torch import nn from torch.utils.data import DataLoader from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn from transformers import AutoImageProcessor, AutoModel from datasets import load_dataset def save_checkpoint(model: nn.Module, optimizer, epoch, step, path="checkpoint.pt"): config = cfg_manager.get() path = config.output.directory / path ckpt = { "model": model.state_dict(), "optimizer": optimizer.state_dict(), "epoch": epoch, "step": step, } torch.save(ckpt, path) print(f"✅ Saved checkpoint to {path}") def load_checkpoint(model: nn.Module, optimizer, path="checkpoint.pt"): ckpt = torch.load(path, map_location="cpu") model.load_state_dict(ckpt["model"]) optimizer.load_state_dict(ckpt["optimizer"]) start_epoch = ckpt["epoch"] start_step = ckpt["step"] print(f"✅ Loaded checkpoint from {path}") print(f"➡️ Resume from epoch={start_epoch}, step={start_step}") return start_epoch, start_step def train( epoch_size: int = 10, batch_size: int = 64, lr: float = 1e-4, checkpoint_path: str = "hash_checkpoint.pt", ): """Train hash compressor with batch-level retrieval loss. Args: epoch_size: Number of epochs to train batch_size: Batch size for training lr: Learning rate checkpoint_path: Path to save/load checkpoints """ # Auto detect device device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Global variables save_every = 500 start_epoch = 0 global_step = 0 # Load dataset ds_train = load_dataset("uoft-cs/cifar10", split="train").with_format("torch") dataloader = DataLoader( ds_train, batch_size=batch_size, shuffle=True, num_workers=4 ) # Load processor processor = AutoImageProcessor.from_pretrained( "facebook/dinov2-large", device_map=device ) # Load DINO model (frozen) dino = AutoModel.from_pretrained("facebook/dinov2-large", device_map=device) dino.eval() for p in dino.parameters(): p.requires_grad = False # Load hash compressor compressor = HashCompressor(input_dim=1024, hash_bits=512).to(device) # Load loss function loss_fn = HashLoss( contrastive_weight=1.0, distill_weight=0.5, quant_weight=0.01, temperature=0.2, ) # Load optimizer optimizer = torch.optim.AdamW(compressor.parameters(), lr=lr) # Auto load checkpoint output_dir = cfg_manager.get().output.directory if os.path.exists(output_dir / checkpoint_path): start_epoch, global_step = load_checkpoint( compressor, optimizer, output_dir / checkpoint_path ) try: progress = Progress( TextColumn("[progress.description]{task.description}"), BarColumn(), TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), TimeRemainingColumn(), ) with progress: for epoch in range(start_epoch, epoch_size): task_id = progress.add_task( f"Epoch [{epoch + 1}/{epoch_size}]", total=len(dataloader) ) for batch in dataloader: progress.update(task_id, advance=1) global_step += 1 # ---- training step ---- imgs = batch["img"] # ---- teacher forward ---- with torch.no_grad(): inputs = processor(imgs, return_tensors="pt").to(device) teacher_tokens = dino(**inputs).last_hidden_state # [B,N,1024] teacher_embed = teacher_tokens.mean(dim=1) teacher_embed = F.normalize(teacher_embed, dim=-1) # [B,1024] # ---- student forward ---- logits, hash_codes, bits = compressor(teacher_tokens) # ---- generate positive mask ---- labels = batch["label"] # positive_mask[i,j] = True if labels[i] == labels[j] positive_mask = labels.unsqueeze(0) == 
                    # ---- loss ----
                    total_loss, components = loss_fn(
                        logits=logits,
                        hash_codes=hash_codes,
                        teacher_embed=teacher_embed,
                        positive_mask=positive_mask,
                    )

                    # ---- backward ----
                    optimizer.zero_grad()
                    total_loss.backward()
                    optimizer.step()

                    # ---- logging ----
                    progress.update(
                        task_id,
                        description=f"Epoch [{epoch + 1}/{epoch_size}] "
                        f"loss={components['total']:.4f} "
                        f"cont={components['contrastive']:.2f} "
                        f"distill={components['distill']:.3f}",
                    )

                    # ---- periodic save ----
                    if global_step % save_every == 0:
                        save_checkpoint(
                            compressor, optimizer, epoch, global_step, checkpoint_path
                        )
    except KeyboardInterrupt:
        print("\n⚠️ Training interrupted, saving checkpoint...")
        save_checkpoint(compressor, optimizer, epoch, global_step, checkpoint_path)
        print("✅ Checkpoint saved. Exiting.")
        return

    # Save final model
    torch.save(compressor.state_dict(), output_dir / "hash_compressor.pt")
    print("✅ Final hash compressor saved")
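

# Minimal entry point sketch, assuming cfg_manager resolves a default config
# that defines output.directory before train() runs; pass hyperparameters here
# rather than editing the function defaults above.
if __name__ == "__main__":
    train()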