mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
feat(train): add checkpointing and training interruption handling
This commit is contained in:
@@ -1,6 +1,9 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from compressors import FloatCompressor
|
from compressors import FloatCompressor
|
||||||
|
from configs import cfg_manager
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
@@ -8,66 +11,122 @@ from tqdm.auto import tqdm
|
|||||||
from transformers import AutoImageProcessor, AutoModel
|
from transformers import AutoImageProcessor, AutoModel
|
||||||
|
|
||||||
|
|
||||||
def train(dinov2: nn.Module, epoch_size: int, batch_size: int):
|
def save_checkpoint(model: nn.Module, optimizer, epoch, step, path="checkpoint.pt"):
    """Write a resumable training checkpoint under the configured output directory.

    The checkpoint is a dict with the keys ``"model"``, ``"optimizer"``,
    ``"epoch"`` and ``"step"``, serialized with ``torch.save``.

    Args:
        model: Module whose ``state_dict()`` is stored under ``"model"``.
        optimizer: Object whose ``state_dict()`` is stored under ``"optimizer"``.
        epoch: Value stored under ``"epoch"``.
        step: Value stored under ``"step"``.
        path: Checkpoint file name; it is joined onto
            ``cfg_manager.get().output.directory`` with the ``/`` operator.
    """
    config = cfg_manager.get()
    path = config.output.directory / path

    ckpt = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
        "step": step,
    }

    # NOTE(review): assumes the joined path is a str or os.PathLike (e.g. a
    # pathlib.Path) -- confirm against the config schema.
    target = os.fspath(path)

    # torch.save does not create missing parent directories.
    os.makedirs(os.path.dirname(target) or ".", exist_ok=True)

    # Write to a sibling temp file and swap it in, so an interrupt during the
    # write cannot leave a truncated file at the real checkpoint path.
    tmp_target = f"{target}.tmp"
    torch.save(ckpt, tmp_target)
    os.replace(tmp_target, target)

    print(f"✅ Saved checkpoint to {path}")
|
||||||
|
|
||||||
|
|
||||||
|
def load_checkpoint(model: nn.Module, optimizer, path="checkpoint.pt"):
    """Restore model and optimizer state from a checkpoint file.

    The file is expected to contain a dict with the keys ``"model"``,
    ``"optimizer"``, ``"epoch"`` and ``"step"``.

    Args:
        model: Module that receives ``ckpt["model"]`` via ``load_state_dict``.
        optimizer: Object that receives ``ckpt["optimizer"]`` via
            ``load_state_dict``.
        path: Location of the checkpoint file. It is used as given; no
            output directory is prepended here.

    Returns:
        A ``(start_epoch, start_step)`` tuple taken from the ``"epoch"`` and
        ``"step"`` entries of the checkpoint.

    Raises:
        KeyError: If one of the expected keys is missing from the checkpoint.
    """
    # The checkpoint only holds state dicts and plain numbers, so the
    # restricted unpickler is sufficient and avoids executing arbitrary
    # pickled code from the file.
    ckpt = torch.load(path, map_location="cpu", weights_only=True)

    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])

    start_epoch = ckpt["epoch"]
    start_step = ckpt["step"]

    print(f"✅ Loaded checkpoint from {path}")
    print(f"➡️ Resume from epoch={start_epoch}, step={start_step}")

    return start_epoch, start_step
|
||||||
|
|
||||||
|
|
||||||
|
def train(
    dinov2: nn.Module, epoch_size: int, batch_size: int, checkpoint_path="checkpoint.pt"
):
    """Train a ``FloatCompressor`` against a frozen DINOv2 teacher on CIFAR-10.

    A checkpoint is written every ``save_every`` optimizer steps and when the
    run is interrupted with Ctrl-C. If a checkpoint already exists in the
    configured output directory, training resumes from it.

    Args:
        dinov2: Not used by this function; the teacher is loaded from
            ``"facebook/dinov2-large"`` below.
        epoch_size: Upper bound of the epoch range (total number of epochs).
        batch_size: Batch size handed to the ``DataLoader``.
        checkpoint_path: Checkpoint file name inside
            ``cfg_manager.get().output.directory``; used for both loading and
            saving.

    Returns:
        None. On normal completion the compressor weights are written to
        ``compressor.pt`` in the output directory; on ``KeyboardInterrupt`` a
        checkpoint is saved and the function returns early.
    """
    # Auto-detect device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Training-loop bookkeeping
    save_every = 500
    start_epoch = 0
    global_step = 0

    # Load dataset
    ds = load_dataset("uoft-cs/cifar10", split="train").with_format("torch")
    dataloader = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=4)

    # Load processor
    processor = AutoImageProcessor.from_pretrained(
        "facebook/dinov2-large", device_map=device
    )

    # Load the teacher and freeze it
    dino = AutoModel.from_pretrained("facebook/dinov2-large", device_map=device)
    dino.eval()
    for p in dino.parameters():
        p.requires_grad = False

    # Load compressor model (the student being trained)
    compressor = FloatCompressor().to(device)

    # Load optimizer
    optimizer = torch.optim.AdamW(compressor.parameters(), lr=1e-4)

    # Auto load checkpoint
    output_dir = cfg_manager.get().output.directory
    if os.path.exists(output_dir / checkpoint_path):
        start_epoch, global_step = load_checkpoint(
            compressor, optimizer, output_dir / checkpoint_path
        )

    # NOTE: a resumed run restarts the saved epoch from its first batch while
    # global_step continues from the saved value.
    try:
        for epoch in range(start_epoch, epoch_size):
            train_bar = tqdm(dataloader, desc=f"Epoch [{epoch + 1}/{epoch_size}]")

            for batch in train_bar:
                global_step += 1

                # ---- training step ----
                imgs = batch["img"]

                # ---- teacher forward ----
                with torch.no_grad():
                    inputs = processor(imgs, return_tensors="pt").to(device)

                    teacher_tokens = dino(**inputs).last_hidden_state
                    # [B,N,1024]

                    teacher_embed = teacher_tokens.mean(dim=1)
                    teacher_embed = F.normalize(teacher_embed, dim=-1)
                    # [B,1024]

                # ---- student forward ----
                # The first output of the compressor is not needed for the loss.
                _, recon = compressor(teacher_tokens)

                # ---- loss ----
                mse_loss = F.mse_loss(recon, teacher_embed)
                cos_loss = 1 - F.cosine_similarity(recon, teacher_embed, dim=-1).mean()
                loss = mse_loss + cos_loss

                # ---- backward ----
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                train_bar.set_postfix(loss=loss.item())

                # ---- periodic save ----
                if global_step % save_every == 0:
                    # Save to the same file the resume check above looks for.
                    save_checkpoint(
                        compressor, optimizer, epoch, global_step, path=checkpoint_path
                    )
    except KeyboardInterrupt:
        print("\n⚠️ Training interrupted, saving checkpoint...")

        save_checkpoint(compressor, optimizer, epoch, global_step, path=checkpoint_path)

        print("✅ Checkpoint saved. Exiting.")
        return

    torch.save(compressor.state_dict(), output_dir / "compressor.pt")
    print("✅ Final compressor saved")
|
||||||
|
|||||||
Reference in New Issue
Block a user