Files
Mini-Nav/mini-nav/compressors/float_compressor.py

28 lines
624 B
Python

import torch.nn as nn
import torch.nn.functional as F
class FloatCompressor(nn.Module):
def __init__(self):
super().__init__()
# projection head
self.proj = nn.Sequential(
nn.Linear(1024, 1024),
nn.LayerNorm(1024),
nn.GELU(),
nn.Linear(1024, 512),
)
self.recover = nn.Linear(512, 1024)
def forward(self, tokens):
pooled = tokens.mean(dim=1) # [B,1024]
z512 = self.proj(pooled) # [B,512]
z512 = F.normalize(z512, dim=-1)
recon = self.recover(z512) # [B,1024]
return z512, recon