mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 20:35:31 +08:00
feat(compressors): add neural compression modules and training pipeline
This commit is contained in:
27
mini-nav/compressors/float_compressor.py
Normal file
27
mini-nav/compressors/float_compressor.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class FloatCompressor(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
# projection head
|
||||
self.proj = nn.Sequential(
|
||||
nn.Linear(1024, 1024),
|
||||
nn.LayerNorm(1024),
|
||||
nn.GELU(),
|
||||
nn.Linear(1024, 512),
|
||||
)
|
||||
|
||||
self.recover = nn.Linear(512, 1024)
|
||||
|
||||
def forward(self, tokens):
|
||||
pooled = tokens.mean(dim=1) # [B,1024]
|
||||
|
||||
z512 = self.proj(pooled) # [B,512]
|
||||
z512 = F.normalize(z512, dim=-1)
|
||||
|
||||
recon = self.recover(z512) # [B,1024]
|
||||
|
||||
return z512, recon
|
||||
Reference in New Issue
Block a user