# Mirror of https://github.com/SikongJueluo/Mini-Nav.git
# (synced 2026-03-12 12:25:32 +08:00 — 74 lines, 2.2 KiB, Python)
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader

from datasets import load_dataset
from tqdm.auto import tqdm
from transformers import AutoImageProcessor, AutoModel

from compressors import FloatCompressor


def train(dinov2: nn.Module, epoch_size: int, batch_size: int):
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
ds = load_dataset("uoft-cs/cifar10", split="train").with_format("torch")
|
|
dataloader = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=4)
|
|
label_map = [
|
|
"airplane",
|
|
"automobile",
|
|
"bird",
|
|
"cat",
|
|
"deer",
|
|
"dog",
|
|
"frog",
|
|
"horse",
|
|
"ship",
|
|
"truck",
|
|
]
|
|
|
|
processor = AutoImageProcessor.from_pretrained(
|
|
"facebook/dinov2-large", device_map=device
|
|
)
|
|
dino = AutoModel.from_pretrained("facebook/dinov2-large", device_map=device)
|
|
dino.eval()
|
|
for p in dino.parameters():
|
|
p.requires_grad = False
|
|
|
|
compressor = FloatCompressor().to(device)
|
|
|
|
optimizer = torch.optim.AdamW(compressor.parameters(), lr=1e-4)
|
|
|
|
for epoch in range(epoch_size):
|
|
train_bar = tqdm(dataloader, desc=f"Epoch [{epoch + 1}/{epoch_size}]")
|
|
|
|
for batch in train_bar:
|
|
imgs = batch["img"]
|
|
|
|
# ---- teacher forward ----
|
|
with torch.no_grad():
|
|
inputs = processor(imgs, return_tensors="pt").to(device)
|
|
|
|
teacher_tokens = dino(**inputs).last_hidden_state
|
|
# [B,N,1024]
|
|
|
|
teacher_embed = teacher_tokens.mean(dim=1)
|
|
teacher_embed = F.normalize(teacher_embed, dim=-1)
|
|
# [B,1024]
|
|
|
|
# ---- student forward ----
|
|
z512, recon = compressor(teacher_tokens)
|
|
|
|
# ---- loss ----
|
|
mse_loss = F.mse_loss(recon, teacher_embed)
|
|
|
|
cos_loss = 1 - F.cosine_similarity(recon, teacher_embed, dim=-1).mean()
|
|
|
|
loss = mse_loss + cos_loss
|
|
|
|
# ---- backward ----
|
|
optimizer.zero_grad()
|
|
loss.backward()
|
|
optimizer.step()
|
|
|
|
train_bar.set_postfix(loss=loss.item())
|