mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 20:35:31 +08:00
feat(compressors): replace float/int compressors with hash-based compression for CAM
This commit is contained in:
87
mini-nav/compressors/common.py
Normal file
87
mini-nav/compressors/common.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""Common utilities for compressor modules."""
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class BinarySign(torch.autograd.Function):
|
||||
"""Binary sign function with Straight-Through Estimator (STE).
|
||||
|
||||
Forward: returns sign(x) in {-1, +1}
|
||||
Backward: passes gradients through as if identity
|
||||
|
||||
For CAM storage, convert: bits = (sign_output + 1) / 2
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x):
|
||||
ctx.save_for_backward(x)
|
||||
return x.sign()
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
(x,) = ctx.saved_tensors
|
||||
# STE: treat as identity
|
||||
# Optional: gradient clipping for stability
|
||||
return grad_output.clone()
|
||||
|
||||
|
||||
def hamming_distance(b1, b2):
|
||||
"""Compute Hamming distance between binary codes.
|
||||
|
||||
Args:
|
||||
b1: Binary codes {0,1}, shape [N, D] or [D]
|
||||
b2: Binary codes {0,1}, shape [M, D] or [D]
|
||||
|
||||
Returns:
|
||||
Hamming distances, shape [N, M] or scalar
|
||||
"""
|
||||
if b1.dim() == 1 and b2.dim() == 1:
|
||||
return (b1 != b2).sum()
|
||||
|
||||
# Expand for pairwise computation
|
||||
b1 = b1.unsqueeze(1) # [N, 1, D]
|
||||
b2 = b2.unsqueeze(0) # [1, M, D]
|
||||
|
||||
return (b1 != b2).sum(dim=-1) # [N, M]
|
||||
|
||||
|
||||
def hamming_similarity(h1, h2):
|
||||
"""Compute Hamming similarity for {-1, +1} codes.
|
||||
|
||||
Args:
|
||||
h1: Hash codes {-1, +1}, shape [N, D] or [D]
|
||||
h2: Hash codes {-1, +1}, shape [M, D] or [D]
|
||||
|
||||
Returns:
|
||||
Similarity scores in [-D, D], shape [N, M] or scalar
|
||||
Higher is more similar
|
||||
"""
|
||||
if h1.dim() == 1 and h2.dim() == 1:
|
||||
return (h1 * h2).sum()
|
||||
|
||||
return h1 @ h2.t() # [N, M]
|
||||
|
||||
|
||||
def bits_to_hash(b):
|
||||
"""Convert {0,1} bits to {-1,+1} hash codes.
|
||||
|
||||
Args:
|
||||
b: Binary bits {0,1}, any shape
|
||||
|
||||
Returns:
|
||||
Hash codes {-1,+1}, same shape
|
||||
"""
|
||||
return b * 2 - 1
|
||||
|
||||
|
||||
def hash_to_bits(h):
|
||||
"""Convert {-1,+1} hash codes to {0,1} bits.
|
||||
|
||||
Args:
|
||||
h: Hash codes {-1,+1}, any shape
|
||||
|
||||
Returns:
|
||||
Binary bits {0,1}, same shape
|
||||
"""
|
||||
return (h + 1) / 2
|
||||
Reference in New Issue
Block a user