Files
Mini-Nav/mini-nav/compressors/common.py

88 lines
2.0 KiB
Python

"""Common utilities for compressor modules."""
import torch
import torch.nn.functional as F
class BinarySign(torch.autograd.Function):
    """Binary sign function with a Straight-Through Estimator (STE).

    Forward: maps x to a code in {-1, +1}. Exact zeros map to +1, because
        ``x.sign()`` alone returns 0 there, which is not a valid binary code.
    Backward: passes the incoming gradient through as if forward were the
        identity.

    For CAM storage, convert: bits = (sign_output + 1) / 2
    """

    @staticmethod
    def forward(ctx, x):
        """Binarize ``x`` element-wise.

        Args:
            ctx: Autograd context (nothing is stored on it).
            x: Input tensor, any shape.

        Returns:
            Tensor of the same shape and dtype as ``x``: -1 where x < 0 and
            +1 where x >= 0. NaN entries stay NaN.
        """
        signs = x.sign()
        # sign(0) == 0 would leak a third symbol into the code; send it to +1.
        # NaN != 0, so NaN entries are kept as they are.
        return torch.where(signs == 0, torch.ones_like(signs), signs)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the upstream gradient unchanged (straight-through).

        Args:
            ctx: Autograd context (unused; the STE needs no saved tensors).
            grad_output: Gradient of the loss w.r.t. the forward output.

        Returns:
            A copy of ``grad_output`` as the gradient w.r.t. ``x``.
        """
        # STE: treat the forward pass as identity.
        return grad_output.clone()
def hamming_distance(b1, b2):
    """Compute Hamming distance between binary codes.

    A 1-D input paired with a 2-D input is treated as a single code (one
    row), so the pairwise result keeps two dimensions.

    Args:
        b1: Binary codes {0,1}, shape [N, D] or [D]
        b2: Binary codes {0,1}, shape [M, D] or [D]

    Returns:
        Number of differing positions:
            - scalar tensor when both inputs are 1-D
            - shape [N, M] otherwise, with N (or M) equal to 1 for a 1-D
              operand
    """
    if b1.dim() == 1 and b2.dim() == 1:
        return (b1 != b2).sum()

    # Promote a lone vector to a one-row batch. Without this, a 1-D b1 would
    # be unsqueezed to [D, 1] and broadcast against [1, M, D] incorrectly.
    if b1.dim() == 1:
        b1 = b1.unsqueeze(0)
    if b2.dim() == 1:
        b2 = b2.unsqueeze(0)

    # Expand for pairwise computation: [N, 1, D] vs [1, M, D] -> [N, M, D]
    mismatches = b1.unsqueeze(1) != b2.unsqueeze(0)
    return mismatches.sum(dim=-1)  # [N, M]
def hamming_similarity(h1, h2):
    """Score how closely {-1, +1} hash codes agree via their inner product.

    Args:
        h1: Hash codes {-1, +1}, shape [N, D] or [D]
        h2: Hash codes {-1, +1}, shape [M, D] or [D]

    Returns:
        Similarity scores in [-D, D], shape [N, M] or scalar.
        Larger values mean more matching positions.
    """
    both_vectors = h1.dim() == 1 and h2.dim() == 1
    if not both_vectors:
        # Batched case: every row of h1 against every row of h2.
        return h1.matmul(h2.t())  # [N, M]
    # Two single codes: element-wise product summed over all positions.
    return h1.mul(h2).sum()
def bits_to_hash(b):
    """Map {0,1} bits onto {-1,+1} hash codes (0 -> -1, 1 -> +1).

    Args:
        b: Binary bits {0,1}, any shape

    Returns:
        Hash codes {-1,+1}, same shape
    """
    doubled = b * 2  # {0, 1} -> {0, 2}
    return doubled - 1
def hash_to_bits(h):
    """Map {-1,+1} hash codes onto {0,1} bits (-1 -> 0, +1 -> 1).

    Args:
        h: Hash codes {-1,+1}, any shape

    Returns:
        Binary bits {0,1}, same shape (true division is used)
    """
    shifted = h + 1  # {-1, +1} -> {0, 2}
    return shifted / 2