fix(compressors): fix incorrect usage of the loss function in the training pipeline

2026-02-28 17:49:55 +08:00
parent 1926cb53e2
commit f61857feba
4 changed files with 29 additions and 9 deletions
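
In short: `HashLoss.forward` declared `positive_mask` as `torch.Tensor | None = None`, so the training loop could (and did) omit it, and the supervised term was presumably never applied. The parameter is now required, and the training loop derives the mask from batch labels. A minimal sketch of the intended semantics (shapes inferred from the diff below):

labels = batch["label"]                                     # [B] class labels
positive_mask = labels.unsqueeze(0) == labels.unsqueeze(1)  # [B, B]; True where labels match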

.gitignore
View File

@@ -206,6 +206,7 @@ marimo/_lsp/
__marimo__/
# Projects
datasets/
+data/
deps/
outputs/

View File

@@ -13,6 +13,7 @@
- Write the test suite first, then implement the code
- After the test suite is implemented, ask for the user's feedback first and continue only after they confirm
- Do not write benchmark code unless the user requests it
+- Comments must be in English
### Test-writing principles
- Concise, clean, and fast

View File

@@ -4,9 +4,12 @@ Converts DINO features to 512-bit binary hash codes suitable for
Content Addressable Memory (CAM) retrieval.
"""
+from typing import cast

import torch
import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor

from .common import BinarySign, hamming_similarity
@@ -53,10 +56,12 @@ class HashCompressor(nn.Module):
        )

        # Initialize last layer with smaller weights for stable training
-        nn.init.xavier_uniform_(self.proj[-1].weight, gain=0.1)
-        nn.init.zeros_(self.proj[-1].bias)
+        nn.init.xavier_uniform_(cast(Tensor, self.proj[-1].weight), gain=0.1)
+        nn.init.zeros_(cast(Tensor, self.proj[-1].bias))

-    def forward(self, tokens: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def forward(
+        self, tokens: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward pass producing hash codes.

        Args:
@@ -96,7 +101,9 @@ class HashCompressor(nn.Module):
        _, _, bits = self.forward(tokens)
        return bits

-    def compute_similarity(self, query_bits: torch.Tensor, db_bits: torch.Tensor) -> torch.Tensor:
+    def compute_similarity(
+        self, query_bits: torch.Tensor, db_bits: torch.Tensor
+    ) -> torch.Tensor:
        """Compute Hamming similarity between query and database entries.

        Higher score = more similar (fewer differing bits).
@@ -259,7 +266,7 @@ class HashLoss(nn.Module):
        logits: torch.Tensor,
        hash_codes: torch.Tensor,
        teacher_embed: torch.Tensor,
-        positive_mask: torch.Tensor | None = None,
+        positive_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, dict[str, float]]:
        """Compute combined hash training loss.

View File

@@ -6,12 +6,13 @@ import torch
import torch.nn.functional as F

from compressors import HashCompressor, HashLoss
from configs import cfg_manager
+from datasets import load_dataset
from torch import nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoImageProcessor, AutoModel
-from datasets import load_dataset

def save_checkpoint(model: nn.Module, optimizer, epoch, step, path="checkpoint.pt"):
config = cfg_manager.get()
@@ -65,8 +66,10 @@ def train(
    global_step = 0

    # Load dataset
-    ds = load_dataset("uoft-cs/cifar10", split="train").with_format("torch")
-    dataloader = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=4)
+    ds_train = load_dataset("uoft-cs/cifar10", split="train").with_format("torch")
+    dataloader = DataLoader(
+        ds_train, batch_size=batch_size, shuffle=True, num_workers=4
+    )

    # Load processor
    processor = AutoImageProcessor.from_pretrained(
@@ -122,11 +125,17 @@ def train(
                # ---- student forward ----
                logits, hash_codes, bits = compressor(teacher_tokens)

+                # ---- generate positive mask ----
+                labels = batch["label"]
+                # positive_mask[i,j] = True if labels[i] == labels[j]
+                positive_mask = labels.unsqueeze(0) == labels.unsqueeze(1)  # [B, B]

                # ---- loss ----
                total_loss, components = loss_fn(
                    logits=logits,
                    hash_codes=hash_codes,
                    teacher_embed=teacher_embed,
+                    positive_mask=positive_mask,
                )

                # ---- backward ----
@@ -144,7 +153,9 @@ def train(
                # ---- periodic save ----
                if global_step % save_every == 0:
-                    save_checkpoint(compressor, optimizer, epoch, global_step, checkpoint_path)
+                    save_checkpoint(
+                        compressor, optimizer, epoch, global_step, checkpoint_path
+                    )

    except KeyboardInterrupt:
        print("\n⚠️ Training interrupted, saving checkpoint...")