mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 20:35:31 +08:00
fix(compressors): fix the wrong usage of loss function in training pipeline
This commit is contained in:
@@ -4,9 +4,12 @@ Converts DINO features to 512-bit binary hash codes suitable for
|
||||
Content Addressable Memory (CAM) retrieval.
|
||||
"""
|
||||
|
||||
from typing import cast
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
from .common import BinarySign, hamming_similarity
|
||||
|
||||
@@ -53,10 +56,12 @@ class HashCompressor(nn.Module):
|
||||
)
|
||||
|
||||
# Initialize last layer with smaller weights for stable training
|
||||
nn.init.xavier_uniform_(self.proj[-1].weight, gain=0.1)
|
||||
nn.init.zeros_(self.proj[-1].bias)
|
||||
nn.init.xavier_uniform_(cast(Tensor, self.proj[-1].weight), gain=0.1)
|
||||
nn.init.zeros_(cast(Tensor, self.proj[-1].bias))
|
||||
|
||||
def forward(self, tokens: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
def forward(
|
||||
self, tokens: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Forward pass producing hash codes.
|
||||
|
||||
Args:
|
||||
@@ -96,7 +101,9 @@ class HashCompressor(nn.Module):
|
||||
_, _, bits = self.forward(tokens)
|
||||
return bits
|
||||
|
||||
def compute_similarity(self, query_bits: torch.Tensor, db_bits: torch.Tensor) -> torch.Tensor:
|
||||
def compute_similarity(
|
||||
self, query_bits: torch.Tensor, db_bits: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""Compute Hamming similarity between query and database entries.
|
||||
|
||||
Higher score = more similar (fewer differing bits).
|
||||
@@ -259,7 +266,7 @@ class HashLoss(nn.Module):
|
||||
logits: torch.Tensor,
|
||||
hash_codes: torch.Tensor,
|
||||
teacher_embed: torch.Tensor,
|
||||
positive_mask: torch.Tensor | None = None,
|
||||
positive_mask: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, dict[str, float]]:
|
||||
"""Compute combined hash training loss.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user