feat(compressors): add neural compression modules and training pipeline

This commit is contained in:
2026-02-08 16:40:44 +08:00
parent b93381accc
commit 8f417b674c
5 changed files with 127 additions and 2 deletions

View File

@@ -0,0 +1,5 @@
from .float_compressor import FloatCompressor
from .int_compressor import IntCompressor
from .train import train
__all__ = ["train", "FloatCompressor", "IntCompressor"]

View File

@@ -0,0 +1,27 @@
import torch.nn as nn
import torch.nn.functional as F
class FloatCompressor(nn.Module):
    """Compress mean-pooled transformer tokens into a smaller unit-norm embedding.

    A two-layer projection head maps the pooled ``in_dim`` embedding down to
    ``out_dim``; a linear ``recover`` head maps it back up so a reconstruction
    loss can be applied against the teacher embedding.

    Args:
        in_dim: dimensionality of the incoming token embeddings
            (default 1024, matching DINOv2-large's hidden size).
        out_dim: dimensionality of the compressed embedding (default 512).
    """

    def __init__(self, in_dim: int = 1024, out_dim: int = 512):
        super().__init__()
        # projection head: in_dim -> in_dim -> out_dim with LayerNorm + GELU
        self.proj = nn.Sequential(
            nn.Linear(in_dim, in_dim),
            nn.LayerNorm(in_dim),
            nn.GELU(),
            nn.Linear(in_dim, out_dim),
        )
        # linear decoder used only for the reconstruction objective
        self.recover = nn.Linear(out_dim, in_dim)

    def forward(self, tokens):
        """Pool token embeddings, compress, and reconstruct.

        Args:
            tokens: float tensor of shape ``[B, N, in_dim]``.

        Returns:
            Tuple ``(z, recon)``: ``z`` is the L2-normalized compressed
            embedding ``[B, out_dim]``; ``recon`` is the reconstruction
            ``[B, in_dim]``.
        """
        pooled = tokens.mean(dim=1)       # [B, in_dim]
        z = self.proj(pooled)             # [B, out_dim]
        z = F.normalize(z, dim=-1)        # unit norm for cosine-based losses
        recon = self.recover(z)           # [B, in_dim]
        return z, recon

View File

@@ -0,0 +1,9 @@
import torch.nn as nn
class IntCompressor(nn.Module):
    """Placeholder for an integer/quantized compressor.

    Not yet implemented: ``forward`` currently returns ``None`` for any
    input. Exported from the package so callers can already import it.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # TODO: implement quantization-based compression.
        return None

View File

@@ -0,0 +1,73 @@
import torch
import torch.nn.functional as F
from compressors import FloatCompressor
from datasets import load_dataset
from torch import nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoImageProcessor, AutoModel
def train(dinov2: nn.Module, epoch_size: int, batch_size: int):
    """Distill frozen DINOv2-large embeddings into the given compressor.

    Trains the passed compressor to reconstruct the L2-normalized,
    mean-pooled DINOv2-large embedding of CIFAR-10 images, using an
    MSE + (1 - cosine-similarity) objective.

    Args:
        dinov2: the compressor module to train (e.g. ``FloatCompressor``).
            NOTE(review): the name is misleading — this is the *student*
            compressor, not the DINOv2 teacher, which is loaded internally.
            Kept unchanged for backward compatibility with existing callers.
        epoch_size: number of epochs to run.
        batch_size: mini-batch size for the CIFAR-10 dataloader.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # CIFAR-10 training split, yielding torch tensors.
    ds = load_dataset("uoft-cs/cifar10", split="train").with_format("torch")
    dataloader = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=4)

    # Frozen teacher: DINOv2-large image encoder.
    # NOTE(review): `device_map` is not a documented AutoImageProcessor
    # argument — confirm it is accepted/ignored by this transformers version.
    processor = AutoImageProcessor.from_pretrained(
        "facebook/dinov2-large", device_map=device
    )
    dino = AutoModel.from_pretrained("facebook/dinov2-large", device_map=device)
    dino.eval()
    for p in dino.parameters():
        p.requires_grad = False

    # Train the module supplied by the caller. Previously a fresh
    # FloatCompressor was constructed here, so the `dinov2` argument
    # was silently ignored and the caller's module never trained.
    compressor = dinov2.to(device)
    optimizer = torch.optim.AdamW(compressor.parameters(), lr=1e-4)

    for epoch in range(epoch_size):
        train_bar = tqdm(dataloader, desc=f"Epoch [{epoch + 1}/{epoch_size}]")
        for batch in train_bar:
            imgs = batch["img"]

            # ---- teacher forward (no gradients into the frozen encoder) ----
            with torch.no_grad():
                inputs = processor(imgs, return_tensors="pt").to(device)
                teacher_tokens = dino(**inputs).last_hidden_state  # [B, N, 1024]
                # Target: unit-norm mean-pooled teacher embedding, [B, 1024].
                teacher_embed = F.normalize(teacher_tokens.mean(dim=1), dim=-1)

            # ---- student forward ----
            z512, recon = compressor(teacher_tokens)

            # ---- loss: reconstruct the normalized teacher embedding ----
            mse_loss = F.mse_loss(recon, teacher_embed)
            cos_loss = 1 - F.cosine_similarity(recon, teacher_embed, dim=-1).mean()
            loss = mse_loss + cos_loss

            # ---- backward ----
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_bar.set_postfix(loss=loss.item())

View File

@@ -1,4 +1,15 @@
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # --train is a boolean flag. The previous positional argument was a
    # required string, so `if args.train:` was truthy for any non-empty
    # value and the visualizer branch was effectively unreachable.
    parser.add_argument("--train", action="store_true")
    args = parser.parse_args()

    if args.train:
        # Import lazily so visualizer deps are not loaded for training runs
        # (an eager top-level `from visualizer import app` defeated this).
        from compressors import FloatCompressor, train

        train(FloatCompressor(), 1, 32)
    else:
        from visualizer import app

        app.run(debug=True)