diff --git a/.gitignore b/.gitignore index 10d1859..95a6253 100644 --- a/.gitignore +++ b/.gitignore @@ -206,6 +206,7 @@ marimo/_lsp/ __marimo__/ # Projects +datasets/ data/ deps/ outputs/ diff --git a/DEVELOPMENT.md b/CLAUDE.md similarity index 99% rename from DEVELOPMENT.md rename to CLAUDE.md index 3c85fdd..f0aa60a 100644 --- a/DEVELOPMENT.md +++ b/CLAUDE.md @@ -13,6 +13,7 @@ - 先编写测试集,再实现代码 - 实现测试集后,先询问用户意见,用户确认后才能继续 - 如非用户要求,无需编写基准测试代码 +- 英文注释 ### 测试编写原则 - 精简、干净、快速 diff --git a/mini-nav/compressors/hash_compressor.py b/mini-nav/compressors/hash_compressor.py index 248c78c..7f87fbd 100644 --- a/mini-nav/compressors/hash_compressor.py +++ b/mini-nav/compressors/hash_compressor.py @@ -4,9 +4,12 @@ Converts DINO features to 512-bit binary hash codes suitable for Content Addressable Memory (CAM) retrieval. """ +from typing import cast + import torch import torch.nn as nn import torch.nn.functional as F +from torch import Tensor from .common import BinarySign, hamming_similarity @@ -53,10 +56,12 @@ class HashCompressor(nn.Module): ) # Initialize last layer with smaller weights for stable training - nn.init.xavier_uniform_(self.proj[-1].weight, gain=0.1) - nn.init.zeros_(self.proj[-1].bias) + nn.init.xavier_uniform_(cast(Tensor, self.proj[-1].weight), gain=0.1) + nn.init.zeros_(cast(Tensor, self.proj[-1].bias)) - def forward(self, tokens: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def forward( + self, tokens: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Forward pass producing hash codes. Args: @@ -96,7 +101,9 @@ class HashCompressor(nn.Module): _, _, bits = self.forward(tokens) return bits - def compute_similarity(self, query_bits: torch.Tensor, db_bits: torch.Tensor) -> torch.Tensor: + def compute_similarity( + self, query_bits: torch.Tensor, db_bits: torch.Tensor + ) -> torch.Tensor: """Compute Hamming similarity between query and database entries. Higher score = more similar (fewer differing bits). @@ -259,7 +266,7 @@ class HashLoss(nn.Module): logits: torch.Tensor, hash_codes: torch.Tensor, teacher_embed: torch.Tensor, - positive_mask: torch.Tensor | None = None, + positive_mask: torch.Tensor, ) -> tuple[torch.Tensor, dict[str, float]]: """Compute combined hash training loss. diff --git a/mini-nav/compressors/train.py b/mini-nav/compressors/train.py index 279c72e..dde89f1 100644 --- a/mini-nav/compressors/train.py +++ b/mini-nav/compressors/train.py @@ -6,12 +6,13 @@ import torch import torch.nn.functional as F from compressors import HashCompressor, HashLoss from configs import cfg_manager -from datasets import load_dataset from torch import nn from torch.utils.data import DataLoader from tqdm.auto import tqdm from transformers import AutoImageProcessor, AutoModel +from datasets import load_dataset + def save_checkpoint(model: nn.Module, optimizer, epoch, step, path="checkpoint.pt"): config = cfg_manager.get() @@ -65,8 +66,10 @@ def train( global_step = 0 # Load dataset - ds = load_dataset("uoft-cs/cifar10", split="train").with_format("torch") - dataloader = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=4) + ds_train = load_dataset("uoft-cs/cifar10", split="train").with_format("torch") + dataloader = DataLoader( + ds_train, batch_size=batch_size, shuffle=True, num_workers=4 + ) # Load processor processor = AutoImageProcessor.from_pretrained( @@ -122,11 +125,17 @@ def train( # ---- student forward ---- logits, hash_codes, bits = compressor(teacher_tokens) + # ---- generate positive mask ---- + labels = batch["label"] + # positive_mask[i,j] = True if labels[i] == labels[j] + positive_mask = labels.unsqueeze(0) == labels.unsqueeze(1) # [B, B] + # ---- loss ---- total_loss, components = loss_fn( logits=logits, hash_codes=hash_codes, teacher_embed=teacher_embed, + positive_mask=positive_mask, ) # ---- backward ---- @@ -144,7 +153,9 @@ def train( # ---- periodic save ---- if global_step % save_every == 0: - save_checkpoint(compressor, optimizer, epoch, global_step, checkpoint_path) + save_checkpoint( + compressor, optimizer, epoch, global_step, checkpoint_path + ) except KeyboardInterrupt: print("\n⚠️ Training interrupted, saving checkpoint...")