feat(train): add checkpointing and training interruption handling

This commit is contained in:
2026-02-08 17:24:03 +08:00
parent 8f417b674c
commit 76a572ee12

View File

@@ -1,6 +1,9 @@
import os
import torch
import torch.nn.functional as F
from compressors import FloatCompressor
from configs import cfg_manager
from datasets import load_dataset
from torch import nn
from torch.utils.data import DataLoader
@@ -8,40 +11,82 @@ from tqdm.auto import tqdm
from transformers import AutoImageProcessor, AutoModel
def train(dinov2: nn.Module, epoch_size: int, batch_size: int):
def save_checkpoint(model: nn.Module, optimizer, epoch, step, path="checkpoint.pt"):
config = cfg_manager.get()
path = config.output.directory / path
ckpt = {
"model": model.state_dict(),
"optimizer": optimizer.state_dict(),
"epoch": epoch,
"step": step,
}
torch.save(ckpt, path)
print(f"✅ Saved checkpoint to {path}")
def load_checkpoint(model: nn.Module, optimizer, path="checkpoint.pt"):
ckpt = torch.load(path, map_location="cpu")
model.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])
start_epoch = ckpt["epoch"]
start_step = ckpt["step"]
print(f"✅ Loaded checkpoint from {path}")
print(f"➡️ Resume from epoch={start_epoch}, step={start_step}")
return start_epoch, start_step
def train(
dinov2: nn.Module, epoch_size: int, batch_size: int, checkpoint_path="checkpoint.pt"
):
# Auto dectect device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Global variables
save_every = 500
start_epoch = 0
global_step = 0
# Load dataset
ds = load_dataset("uoft-cs/cifar10", split="train").with_format("torch")
dataloader = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=4)
label_map = [
"airplane",
"automobile",
"bird",
"cat",
"deer",
"dog",
"frog",
"horse",
"ship",
"truck",
]
# Load processor
processor = AutoImageProcessor.from_pretrained(
"facebook/dinov2-large", device_map=device
)
# Load model
dino = AutoModel.from_pretrained("facebook/dinov2-large", device_map=device)
dino.eval()
for p in dino.parameters():
p.requires_grad = False
# Load compressor model
compressor = FloatCompressor().to(device)
# Load optimizer
optimizer = torch.optim.AdamW(compressor.parameters(), lr=1e-4)
for epoch in range(epoch_size):
# Auto load checkpoint
output_dir = cfg_manager.get().output.directory
if os.path.exists(output_dir / checkpoint_path):
start_epoch, global_step = load_checkpoint(
compressor, optimizer, output_dir / checkpoint_path
)
try:
for epoch in range(start_epoch, epoch_size):
train_bar = tqdm(dataloader, desc=f"Epoch [{epoch + 1}/{epoch_size}]")
for batch in train_bar:
global_step += 1
# ---- training step ----
imgs = batch["img"]
# ---- teacher forward ----
@@ -71,3 +116,17 @@ def train(dinov2: nn.Module, epoch_size: int, batch_size: int):
optimizer.step()
train_bar.set_postfix(loss=loss.item())
# ---- periodic save ----
if global_step % save_every == 0:
save_checkpoint(compressor, optimizer, epoch, global_step)
except KeyboardInterrupt:
print("\n⚠️ Training interrupted, saving checkpoint...")
save_checkpoint(compressor, optimizer, epoch, global_step)
print("✅ Checkpoint saved. Exiting.")
return
torch.save(compressor.state_dict(), output_dir / "compressor.pt")
print("✅ Final compressor saved")