fix(compressors): correct incorrect loss function usage in the training pipeline

2026-02-28 17:49:55 +08:00
parent 1926cb53e2
commit f61857feba
4 changed files with 29 additions and 9 deletions

@@ -6,12 +6,13 @@ import torch
 import torch.nn.functional as F
 from compressors import HashCompressor, HashLoss
 from configs import cfg_manager
+from datasets import load_dataset
 from torch import nn
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
 from transformers import AutoImageProcessor, AutoModel
-from datasets import load_dataset
 
 
 def save_checkpoint(model: nn.Module, optimizer, epoch, step, path="checkpoint.pt"):
     config = cfg_manager.get()
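
The context lines above show save_checkpoint's full signature. For reference, a minimal sketch of what such a helper usually does, assuming a plain torch.save wrapper (the dict keys here are assumptions, not necessarily the repo's actual checkpoint layout):

def save_checkpoint(model: nn.Module, optimizer, epoch, step, path="checkpoint.pt"):
    # Bundle model and optimizer state plus progress counters into one file.
    torch.save(
        {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "epoch": epoch,
            "step": step,
        },
        path,
    )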
@@ -65,8 +66,10 @@ def train(
     global_step = 0
 
     # Load dataset
-    ds = load_dataset("uoft-cs/cifar10", split="train").with_format("torch")
-    dataloader = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=4)
+    ds_train = load_dataset("uoft-cs/cifar10", split="train").with_format("torch")
+    dataloader = DataLoader(
+        ds_train, batch_size=batch_size, shuffle=True, num_workers=4
+    )
 
     # Load processor
     processor = AutoImageProcessor.from_pretrained(
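
Since the dataset is loaded with .with_format("torch"), each DataLoader batch arrives as a dict of tensors keyed by column name. A quick sanity check, assuming the standard uoft-cs/cifar10 columns "img" and "label" (the "label" column is what the positive-mask code further down relies on):

batch = next(iter(dataloader))
print(batch["img"].shape)    # e.g. torch.Size([B, 32, 32, 3]), uint8 images
print(batch["label"].shape)  # torch.Size([B]), integer class ids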
@@ -122,11 +125,17 @@ def train(
             # ---- student forward ----
             logits, hash_codes, bits = compressor(teacher_tokens)
 
+            # ---- generate positive mask ----
+            labels = batch["label"]
+            # positive_mask[i,j] = True if labels[i] == labels[j]
+            positive_mask = labels.unsqueeze(0) == labels.unsqueeze(1)  # [B, B]
+
             # ---- loss ----
+            total_loss, components = loss_fn(
+                logits=logits,
+                hash_codes=hash_codes,
+                teacher_embed=teacher_embed,
+                positive_mask=positive_mask,
+            )
 
             # ---- backward ----
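
The broadcasted comparison above builds a [B, B] boolean matrix whose (i, j) entry is True exactly when samples i and j share a class label, i.e. the pairs the loss should treat as positives. A toy check of the mask construction:

import torch

labels = torch.tensor([0, 1, 0, 2])
positive_mask = labels.unsqueeze(0) == labels.unsqueeze(1)
# tensor([[ True, False,  True, False],
#         [False,  True, False, False],
#         [ True, False,  True, False],
#         [False, False, False,  True]])

Note the diagonal is always True (every sample matches itself); whether HashLoss excludes self-pairs is up to its implementation.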
@@ -144,7 +153,9 @@ def train(
             # ---- periodic save ----
             if global_step % save_every == 0:
-                save_checkpoint(compressor, optimizer, epoch, global_step, checkpoint_path)
+                save_checkpoint(
+                    compressor, optimizer, epoch, global_step, checkpoint_path
+                )
 
     except KeyboardInterrupt:
         print("\n⚠️ Training interrupted, saving checkpoint...")