Files
Mini-Nav/mini-nav/compressors/train.py

74 lines
2.2 KiB
Python

import torch
import torch.nn.functional as F
from compressors import FloatCompressor
from datasets import load_dataset
from torch import nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoImageProcessor, AutoModel
def train(dinov2: nn.Module, epoch_size: int, batch_size: int):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ds = load_dataset("uoft-cs/cifar10", split="train").with_format("torch")
dataloader = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=4)
label_map = [
"airplane",
"automobile",
"bird",
"cat",
"deer",
"dog",
"frog",
"horse",
"ship",
"truck",
]
processor = AutoImageProcessor.from_pretrained(
"facebook/dinov2-large", device_map=device
)
dino = AutoModel.from_pretrained("facebook/dinov2-large", device_map=device)
dino.eval()
for p in dino.parameters():
p.requires_grad = False
compressor = FloatCompressor().to(device)
optimizer = torch.optim.AdamW(compressor.parameters(), lr=1e-4)
for epoch in range(epoch_size):
train_bar = tqdm(dataloader, desc=f"Epoch [{epoch + 1}/{epoch_size}]")
for batch in train_bar:
imgs = batch["img"]
# ---- teacher forward ----
with torch.no_grad():
inputs = processor(imgs, return_tensors="pt").to(device)
teacher_tokens = dino(**inputs).last_hidden_state
# [B,N,1024]
teacher_embed = teacher_tokens.mean(dim=1)
teacher_embed = F.normalize(teacher_embed, dim=-1)
# [B,1024]
# ---- student forward ----
z512, recon = compressor(teacher_tokens)
# ---- loss ----
mse_loss = F.mse_loss(recon, teacher_embed)
cos_loss = 1 - F.cosine_similarity(recon, teacher_embed, dim=-1).mean()
loss = mse_loss + cos_loss
# ---- backward ----
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_bar.set_postfix(loss=loss.item())