mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 20:35:31 +08:00
fix(compressors): fix incorrect usage of the loss function in the training pipeline
This commit is contained in:
@@ -6,12 +6,13 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from compressors import HashCompressor, HashLoss
|
||||
from configs import cfg_manager
|
||||
from datasets import load_dataset
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import AutoImageProcessor, AutoModel
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
|
||||
def save_checkpoint(model: nn.Module, optimizer, epoch, step, path="checkpoint.pt"):
|
||||
config = cfg_manager.get()
|
||||
@@ -65,8 +66,10 @@ def train(
|
||||
global_step = 0
|
||||
|
||||
# Load dataset
|
||||
ds = load_dataset("uoft-cs/cifar10", split="train").with_format("torch")
|
||||
dataloader = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=4)
|
||||
ds_train = load_dataset("uoft-cs/cifar10", split="train").with_format("torch")
|
||||
dataloader = DataLoader(
|
||||
ds_train, batch_size=batch_size, shuffle=True, num_workers=4
|
||||
)
|
||||
|
||||
# Load processor
|
||||
processor = AutoImageProcessor.from_pretrained(
|
||||
@@ -122,11 +125,17 @@ def train(
|
||||
# ---- student forward ----
|
||||
logits, hash_codes, bits = compressor(teacher_tokens)
|
||||
|
||||
# ---- generate positive mask ----
|
||||
labels = batch["label"]
|
||||
# positive_mask[i,j] = True if labels[i] == labels[j]
|
||||
positive_mask = labels.unsqueeze(0) == labels.unsqueeze(1) # [B, B]
|
||||
|
||||
# ---- loss ----
|
||||
total_loss, components = loss_fn(
|
||||
logits=logits,
|
||||
hash_codes=hash_codes,
|
||||
teacher_embed=teacher_embed,
|
||||
positive_mask=positive_mask,
|
||||
)
|
||||
|
||||
# ---- backward ----
|
||||
@@ -144,7 +153,9 @@ def train(
|
||||
|
||||
# ---- periodic save ----
|
||||
if global_step % save_every == 0:
|
||||
save_checkpoint(compressor, optimizer, epoch, global_step, checkpoint_path)
|
||||
save_checkpoint(
|
||||
compressor, optimizer, epoch, global_step, checkpoint_path
|
||||
)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n⚠️ Training interrupted, saving checkpoint...")
|
||||
|
||||
Reference in New Issue
Block a user