fix(compressors): correct misuse of the loss function in the training pipeline

This commit is contained in:
2026-02-28 17:49:55 +08:00
parent 1926cb53e2
commit f61857feba
4 changed files with 29 additions and 9 deletions

View File

@@ -4,9 +4,12 @@ Converts DINO features to 512-bit binary hash codes suitable for
Content Addressable Memory (CAM) retrieval.
"""
from typing import cast
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from .common import BinarySign, hamming_similarity
@@ -53,10 +56,12 @@ class HashCompressor(nn.Module):
)
# Initialize last layer with smaller weights for stable training
nn.init.xavier_uniform_(self.proj[-1].weight, gain=0.1)
nn.init.zeros_(self.proj[-1].bias)
nn.init.xavier_uniform_(cast(Tensor, self.proj[-1].weight), gain=0.1)
nn.init.zeros_(cast(Tensor, self.proj[-1].bias))
def forward(self, tokens: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
def forward(
self, tokens: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Forward pass producing hash codes.
Args:
@@ -96,7 +101,9 @@ class HashCompressor(nn.Module):
_, _, bits = self.forward(tokens)
return bits
def compute_similarity(self, query_bits: torch.Tensor, db_bits: torch.Tensor) -> torch.Tensor:
def compute_similarity(
self, query_bits: torch.Tensor, db_bits: torch.Tensor
) -> torch.Tensor:
"""Compute Hamming similarity between query and database entries.
Higher score = more similar (fewer differing bits).
@@ -259,7 +266,7 @@ class HashLoss(nn.Module):
logits: torch.Tensor,
hash_codes: torch.Tensor,
teacher_embed: torch.Tensor,
positive_mask: torch.Tensor | None = None,
positive_mask: torch.Tensor,
) -> tuple[torch.Tensor, dict[str, float]]:
"""Compute combined hash training loss.