Mirror of https://github.com/SikongJueluo/Mini-Nav.git
Synced 2026-03-12 12:25:32 +08:00
fix(compressors): correct loss function usage in the training pipeline
.gitignore (vendored) | 1 +
@@ -206,6 +206,7 @@ marimo/_lsp/
 __marimo__/
 
 # Projects
+datasets/
 data/
 deps/
 outputs/
@@ -13,6 +13,7 @@
 - Write the test suite first, then implement the code
 - After the test suite is implemented, ask the user for feedback; continue only after the user confirms
 - Unless the user asks for them, do not write benchmark tests
+- Comments in English
 
 ### Principles for writing tests
 - Lean, clean, fast
@@ -4,9 +4,12 @@ Converts DINO features to 512-bit binary hash codes suitable for
 Content Addressable Memory (CAM) retrieval.
 """
 
+from typing import cast
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch import Tensor
 
 from .common import BinarySign, hamming_similarity
 
@@ -53,10 +56,12 @@ class HashCompressor(nn.Module):
         )
 
         # Initialize last layer with smaller weights for stable training
-        nn.init.xavier_uniform_(self.proj[-1].weight, gain=0.1)
-        nn.init.zeros_(self.proj[-1].bias)
+        nn.init.xavier_uniform_(cast(Tensor, self.proj[-1].weight), gain=0.1)
+        nn.init.zeros_(cast(Tensor, self.proj[-1].bias))
 
-    def forward(self, tokens: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def forward(
+        self, tokens: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Forward pass producing hash codes.
 
         Args:
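A note on the cast(Tensor, ...) wrappers above: indexing an nn.Sequential is typed as returning nn.Module, so attribute access like .weight resolves to Tensor | nn.Module for static checkers, which then reject it as an argument to the nn.init functions. typing.cast narrows the type and is a no-op at runtime. A minimal sketch; the layer sizes here are made up for illustration, the real ones live in HashCompressor:

    from typing import cast

    import torch.nn as nn
    from torch import Tensor

    # Hypothetical projection head standing in for self.proj.
    proj = nn.Sequential(nn.Linear(768, 1024), nn.GELU(), nn.Linear(1024, 512))

    # proj[-1] is typed as nn.Module, so checkers cannot prove .weight is a
    # Tensor; cast() narrows the type without changing runtime behavior.
    nn.init.xavier_uniform_(cast(Tensor, proj[-1].weight), gain=0.1)
    nn.init.zeros_(cast(Tensor, proj[-1].bias))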
@@ -96,7 +101,9 @@ class HashCompressor(nn.Module):
         _, _, bits = self.forward(tokens)
         return bits
 
-    def compute_similarity(self, query_bits: torch.Tensor, db_bits: torch.Tensor) -> torch.Tensor:
+    def compute_similarity(
+        self, query_bits: torch.Tensor, db_bits: torch.Tensor
+    ) -> torch.Tensor:
         """Compute Hamming similarity between query and database entries.
 
         Higher score = more similar (fewer differing bits).
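compute_similarity delegates to hamming_similarity from .common, which this diff does not show. For {-1, +1} codes, the matching-bit count falls out of a single matrix product, since each matching position contributes +1 to the dot product and each mismatch contributes -1. A hypothetical sketch under that assumption:

    import torch

    def hamming_similarity(query_bits: torch.Tensor, db_bits: torch.Tensor) -> torch.Tensor:
        # query_bits: [Q, n_bits], db_bits: [N, n_bits], values in {-1, +1}.
        # dot = (#matches - #mismatches), so matches = (dot + n_bits) / 2.
        n_bits = query_bits.shape[-1]
        dot = query_bits.float() @ db_bits.float().T  # [Q, N]
        return (dot + n_bits) / 2

    q = torch.tensor([[1.0, -1.0, 1.0, 1.0]])
    db = torch.tensor([[1.0, -1.0, 1.0, 1.0],      # identical -> 4 matches
                       [-1.0, 1.0, -1.0, -1.0]])   # all flipped -> 0 matches
    print(hamming_similarity(q, db))  # tensor([[4., 0.]])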
@@ -259,7 +266,7 @@ class HashLoss(nn.Module):
         logits: torch.Tensor,
         hash_codes: torch.Tensor,
         teacher_embed: torch.Tensor,
-        positive_mask: torch.Tensor | None = None,
+        positive_mask: torch.Tensor,
     ) -> tuple[torch.Tensor, dict[str, float]]:
         """Compute combined hash training loss.
 
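Making positive_mask a required parameter is the substance of this fix: with the old `| None = None` default, the training loop could call HashLoss without ever passing a mask, so any similarity term inside the loss that depends on knowing which pairs share a label was silently skipped or misweighted. HashLoss internals are not part of this diff; the sketch below is only a hypothetical example of the kind of term that needs the mask:

    import torch
    import torch.nn.functional as F

    def masked_similarity_loss(hash_codes: torch.Tensor,
                               positive_mask: torch.Tensor) -> torch.Tensor:
        # hash_codes: [B, n_bits] relaxed codes; positive_mask: [B, B] bool.
        z = F.normalize(hash_codes, dim=-1)
        sim = z @ z.T                              # cosine similarity, [B, B]
        target = positive_mask.float() * 2 - 1     # +1 positives, -1 negatives
        off_diag = ~torch.eye(len(sim), dtype=torch.bool, device=sim.device)
        # Without the mask there is no way to know which pairs to pull together.
        return F.mse_loss(sim[off_diag], target[off_diag])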
@@ -6,12 +6,13 @@ import torch
 import torch.nn.functional as F
 from compressors import HashCompressor, HashLoss
 from configs import cfg_manager
-from datasets import load_dataset
 from torch import nn
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
 from transformers import AutoImageProcessor, AutoModel
 
+from datasets import load_dataset
+
 
 def save_checkpoint(model: nn.Module, optimizer, epoch, step, path="checkpoint.pt"):
     config = cfg_manager.get()
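Moving `from datasets import load_dataset` into its own import group pairs with the new `datasets/` entry in .gitignore: the project now keeps a local datasets/ directory, which could shadow the Hugging Face `datasets` library if it ever gained an __init__.py and landed on sys.path. A defensive sanity check, purely illustrative and not in the repo:

    import datasets

    # Heuristic: the installed library lives under site-packages; a local
    # datasets/ package shadowing it would resolve to a different path.
    assert datasets.__file__ and "site-packages" in datasets.__file__, datasets.__file__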
@@ -65,8 +66,10 @@ def train(
     global_step = 0
 
     # Load dataset
-    ds = load_dataset("uoft-cs/cifar10", split="train").with_format("torch")
-    dataloader = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=4)
+    ds_train = load_dataset("uoft-cs/cifar10", split="train").with_format("torch")
+    dataloader = DataLoader(
+        ds_train, batch_size=batch_size, shuffle=True, num_workers=4
+    )
 
     # Load processor
     processor = AutoImageProcessor.from_pretrained(
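For reference, load_dataset(...).with_format("torch") yields dict-style rows that the default DataLoader collate stacks into dict batches. For this CIFAR-10 build the columns are "img" and "label" ("label" is used in the hunk below; "img" is assumed from the dataset card). A quick inspection sketch:

    from datasets import load_dataset
    from torch.utils.data import DataLoader

    ds_train = load_dataset("uoft-cs/cifar10", split="train").with_format("torch")
    dataloader = DataLoader(ds_train, batch_size=4, shuffle=True)

    batch = next(iter(dataloader))
    print(batch.keys())          # dict_keys(['img', 'label'])
    print(batch["label"].shape)  # torch.Size([4])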
@@ -122,11 +125,17 @@ def train(
             # ---- student forward ----
             logits, hash_codes, bits = compressor(teacher_tokens)
 
+            # ---- generate positive mask ----
+            labels = batch["label"]
+            # positive_mask[i,j] = True if labels[i] == labels[j]
+            positive_mask = labels.unsqueeze(0) == labels.unsqueeze(1)  # [B, B]
+
             # ---- loss ----
             total_loss, components = loss_fn(
                 logits=logits,
                 hash_codes=hash_codes,
                 teacher_embed=teacher_embed,
+                positive_mask=positive_mask,
             )
 
             # ---- backward ----
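The broadcast comparison builds the [B, B] mask in one line: unsqueeze(0) gives shape [1, B], unsqueeze(1) gives [B, 1], and the elementwise == broadcasts to [B, B]. The result is symmetric with an all-True diagonal (each sample is its own positive); whether the diagonal should be excluded is up to HashLoss. A quick check:

    import torch

    labels = torch.tensor([0, 1, 0, 2])
    positive_mask = labels.unsqueeze(0) == labels.unsqueeze(1)  # [B, B]

    print(positive_mask.int())
    # tensor([[1, 0, 1, 0],
    #         [0, 1, 0, 0],
    #         [1, 0, 1, 0],
    #         [0, 0, 0, 1]])
    assert torch.equal(positive_mask, positive_mask.T)  # symmetric
    assert bool(positive_mask.diagonal().all())         # self-pairs are positives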
@@ -144,7 +153,9 @@ def train(
 
             # ---- periodic save ----
             if global_step % save_every == 0:
-                save_checkpoint(compressor, optimizer, epoch, global_step, checkpoint_path)
+                save_checkpoint(
+                    compressor, optimizer, epoch, global_step, checkpoint_path
+                )
 
     except KeyboardInterrupt:
         print("\n⚠️ Training interrupted, saving checkpoint...")
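save_checkpoint is only shown down to `config = cfg_manager.get()`. A minimal body consistent with its signature might look like the following; this is a hypothetical sketch, and the real function presumably also serializes the config it fetches:

    import torch
    from torch import nn

    def save_checkpoint(model: nn.Module, optimizer, epoch, step, path="checkpoint.pt"):
        # Persist everything needed to resume: weights, optimizer state, position.
        torch.save(
            {
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "epoch": epoch,
                "step": step,
            },
            path,
        )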