diff --git a/mini-nav/compressors/train.py b/mini-nav/compressors/train.py index 2bda281..3aa9b50 100644 --- a/mini-nav/compressors/train.py +++ b/mini-nav/compressors/train.py @@ -1,6 +1,9 @@ +import os + import torch import torch.nn.functional as F from compressors import FloatCompressor +from configs import cfg_manager from datasets import load_dataset from torch import nn from torch.utils.data import DataLoader @@ -8,66 +11,122 @@ from tqdm.auto import tqdm from transformers import AutoImageProcessor, AutoModel -def train(dinov2: nn.Module, epoch_size: int, batch_size: int): +def save_checkpoint(model: nn.Module, optimizer, epoch, step, path="checkpoint.pt"): + config = cfg_manager.get() + path = config.output.directory / path + + ckpt = { + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + "epoch": epoch, + "step": step, + } + torch.save(ckpt, path) + print(f"✅ Saved checkpoint to {path}") + + +def load_checkpoint(model: nn.Module, optimizer, path="checkpoint.pt"): + ckpt = torch.load(path, map_location="cpu") + + model.load_state_dict(ckpt["model"]) + optimizer.load_state_dict(ckpt["optimizer"]) + + start_epoch = ckpt["epoch"] + start_step = ckpt["step"] + + print(f"✅ Loaded checkpoint from {path}") + print(f"➡️ Resume from epoch={start_epoch}, step={start_step}") + + return start_epoch, start_step + + +def train( + dinov2: nn.Module, epoch_size: int, batch_size: int, checkpoint_path="checkpoint.pt" +): + # Auto dectect device device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + # Global variables + save_every = 500 + start_epoch = 0 + global_step = 0 + + # Load dataset ds = load_dataset("uoft-cs/cifar10", split="train").with_format("torch") dataloader = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=4) - label_map = [ - "airplane", - "automobile", - "bird", - "cat", - "deer", - "dog", - "frog", - "horse", - "ship", - "truck", - ] + # Load processor processor = AutoImageProcessor.from_pretrained( "facebook/dinov2-large", device_map=device ) + + # Load model dino = AutoModel.from_pretrained("facebook/dinov2-large", device_map=device) dino.eval() for p in dino.parameters(): p.requires_grad = False + # Load compressor model compressor = FloatCompressor().to(device) + # Load optimizer optimizer = torch.optim.AdamW(compressor.parameters(), lr=1e-4) - for epoch in range(epoch_size): - train_bar = tqdm(dataloader, desc=f"Epoch [{epoch + 1}/{epoch_size}]") + # Auto load checkpoint + output_dir = cfg_manager.get().output.directory + if os.path.exists(output_dir / checkpoint_path): + start_epoch, global_step = load_checkpoint( + compressor, optimizer, output_dir / checkpoint_path + ) - for batch in train_bar: - imgs = batch["img"] + try: + for epoch in range(start_epoch, epoch_size): + train_bar = tqdm(dataloader, desc=f"Epoch [{epoch + 1}/{epoch_size}]") - # ---- teacher forward ---- - with torch.no_grad(): - inputs = processor(imgs, return_tensors="pt").to(device) + for batch in train_bar: + global_step += 1 - teacher_tokens = dino(**inputs).last_hidden_state - # [B,N,1024] + # ---- training step ---- + imgs = batch["img"] - teacher_embed = teacher_tokens.mean(dim=1) - teacher_embed = F.normalize(teacher_embed, dim=-1) - # [B,1024] + # ---- teacher forward ---- + with torch.no_grad(): + inputs = processor(imgs, return_tensors="pt").to(device) - # ---- student forward ---- - z512, recon = compressor(teacher_tokens) + teacher_tokens = dino(**inputs).last_hidden_state + # [B,N,1024] - # ---- loss ---- - mse_loss = F.mse_loss(recon, teacher_embed) + teacher_embed = teacher_tokens.mean(dim=1) + teacher_embed = F.normalize(teacher_embed, dim=-1) + # [B,1024] - cos_loss = 1 - F.cosine_similarity(recon, teacher_embed, dim=-1).mean() + # ---- student forward ---- + z512, recon = compressor(teacher_tokens) - loss = mse_loss + cos_loss + # ---- loss ---- + mse_loss = F.mse_loss(recon, teacher_embed) - # ---- backward ---- - optimizer.zero_grad() - loss.backward() - optimizer.step() + cos_loss = 1 - F.cosine_similarity(recon, teacher_embed, dim=-1).mean() - train_bar.set_postfix(loss=loss.item()) + loss = mse_loss + cos_loss + + # ---- backward ---- + optimizer.zero_grad() + loss.backward() + optimizer.step() + + train_bar.set_postfix(loss=loss.item()) + + # ---- periodic save ---- + if global_step % save_every == 0: + save_checkpoint(compressor, optimizer, epoch, global_step) + except KeyboardInterrupt: + print("\n⚠️ Training interrupted, saving checkpoint...") + + save_checkpoint(compressor, optimizer, epoch, global_step) + + print("✅ Checkpoint saved. Exiting.") + return + + torch.save(compressor.state_dict(), output_dir / "compressor.pt") + print("✅ Final compressor saved")