fix(compressors): correct loss function usage in training pipeline

This commit is contained in:
2026-02-28 17:49:55 +08:00
parent 1926cb53e2
commit f61857feba
4 changed files with 29 additions and 9 deletions

1
.gitignore vendored
View File

@@ -206,6 +206,7 @@ marimo/_lsp/
 __marimo__/
 # Projects
+datasets/
 data/
 deps/
 outputs/

View File

@@ -13,6 +13,7 @@
 - 先编写测试集,再实现代码
 - 实现测试集后,先询问用户意见,用户确认后才能继续
 - 如非用户要求,无需编写基准测试代码
+- 英文注释
 ### 测试编写原则
 - 精简、干净、快速

View File

@@ -4,9 +4,12 @@ Converts DINO features to 512-bit binary hash codes suitable for
 Content Addressable Memory (CAM) retrieval.
 """
+from typing import cast
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch import Tensor

 from .common import BinarySign, hamming_similarity
@@ -53,10 +56,12 @@ class HashCompressor(nn.Module):
         )
         # Initialize last layer with smaller weights for stable training
-        nn.init.xavier_uniform_(self.proj[-1].weight, gain=0.1)
-        nn.init.zeros_(self.proj[-1].bias)
+        nn.init.xavier_uniform_(cast(Tensor, self.proj[-1].weight), gain=0.1)
+        nn.init.zeros_(cast(Tensor, self.proj[-1].bias))

-    def forward(self, tokens: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def forward(
+        self, tokens: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Forward pass producing hash codes.

         Args:
@@ -96,7 +101,9 @@ class HashCompressor(nn.Module):
         _, _, bits = self.forward(tokens)
         return bits

-    def compute_similarity(self, query_bits: torch.Tensor, db_bits: torch.Tensor) -> torch.Tensor:
+    def compute_similarity(
+        self, query_bits: torch.Tensor, db_bits: torch.Tensor
+    ) -> torch.Tensor:
         """Compute Hamming similarity between query and database entries.

         Higher score = more similar (fewer differing bits).
@@ -259,7 +266,7 @@ class HashLoss(nn.Module):
         logits: torch.Tensor,
         hash_codes: torch.Tensor,
         teacher_embed: torch.Tensor,
-        positive_mask: torch.Tensor | None = None,
+        positive_mask: torch.Tensor,
     ) -> tuple[torch.Tensor, dict[str, float]]:
         """Compute combined hash training loss.

View File

@@ -6,12 +6,13 @@ import torch
 import torch.nn.functional as F

 from compressors import HashCompressor, HashLoss
 from configs import cfg_manager
+from datasets import load_dataset
 from torch import nn
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
 from transformers import AutoImageProcessor, AutoModel
-from datasets import load_dataset


 def save_checkpoint(model: nn.Module, optimizer, epoch, step, path="checkpoint.pt"):
     config = cfg_manager.get()
@@ -65,8 +66,10 @@ def train(
     global_step = 0

     # Load dataset
-    ds = load_dataset("uoft-cs/cifar10", split="train").with_format("torch")
-    dataloader = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=4)
+    ds_train = load_dataset("uoft-cs/cifar10", split="train").with_format("torch")
+    dataloader = DataLoader(
+        ds_train, batch_size=batch_size, shuffle=True, num_workers=4
+    )

     # Load processor
     processor = AutoImageProcessor.from_pretrained(
@@ -122,11 +125,17 @@ def train(
                 # ---- student forward ----
                 logits, hash_codes, bits = compressor(teacher_tokens)

+                # ---- generate positive mask ----
+                labels = batch["label"]
+                # positive_mask[i,j] = True if labels[i] == labels[j]
+                positive_mask = labels.unsqueeze(0) == labels.unsqueeze(1)  # [B, B]
+
                 # ---- loss ----
                 total_loss, components = loss_fn(
                     logits=logits,
                     hash_codes=hash_codes,
                     teacher_embed=teacher_embed,
+                    positive_mask=positive_mask,
                 )

                 # ---- backward ----
@@ -144,7 +153,9 @@ def train(
                 # ---- periodic save ----
                 if global_step % save_every == 0:
-                    save_checkpoint(compressor, optimizer, epoch, global_step, checkpoint_path)
+                    save_checkpoint(
+                        compressor, optimizer, epoch, global_step, checkpoint_path
+                    )

     except KeyboardInterrupt:
         print("\n⚠️ Training interrupted, saving checkpoint...")