mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
feat(compressors): add neural compression modules and training pipeline
This commit is contained in:
5
mini-nav/compressors/__init__.py
Normal file
5
mini-nav/compressors/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .float_compressor import FloatCompressor
|
||||
from .int_compressor import IntCompressor
|
||||
from .train import train
|
||||
|
||||
__all__ = ["train", "FloatCompressor", "IntCompressor"]
|
||||
27
mini-nav/compressors/float_compressor.py
Normal file
27
mini-nav/compressors/float_compressor.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class FloatCompressor(nn.Module):
    """Project mean-pooled token features to a unit-norm code and reconstruct them.

    The module pools the token axis, passes the pooled vector through a small
    MLP projection head, L2-normalises the result, and linearly maps the
    normalised code back to the input width.

    Args:
        in_dim: Width of each input token and of the reconstruction.
        hidden_dim: Width of the hidden layer inside the projection head.
        out_dim: Width of the compressed code.

    The defaults (1024 -> 1024 -> 512) reproduce the original fixed layout, so
    ``FloatCompressor()`` builds exactly the same layers as before.
    """

    def __init__(self, in_dim: int = 1024, hidden_dim: int = 1024, out_dim: int = 512):
        super().__init__()

        # Projection head: in_dim -> hidden_dim -> out_dim.
        self.proj = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, out_dim),
        )

        # Linear decoder from the compressed code back to the input width.
        self.recover = nn.Linear(out_dim, in_dim)

    def forward(self, tokens):
        """Compress ``tokens`` and reconstruct the pooled feature.

        Args:
            tokens: Tensor whose dim 1 is averaged away and whose last dim is
                ``in_dim`` (e.g. ``[B, N, in_dim]``).

        Returns:
            A ``(code, recon)`` tuple: ``code`` is the L2-normalised
            ``[B, out_dim]`` embedding and ``recon`` is the ``[B, in_dim]``
            reconstruction produced from it.
        """
        pooled = tokens.mean(dim=1)  # [B, in_dim]

        # Normalise along the feature axis so every code has unit L2 norm.
        code = F.normalize(self.proj(pooled), dim=-1)  # [B, out_dim]

        recon = self.recover(code)  # [B, in_dim]

        return code, recon
|
||||
9
mini-nav/compressors/int_compressor.py
Normal file
9
mini-nav/compressors/int_compressor.py
Normal file
@@ -0,0 +1,9 @@
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class IntCompressor(nn.Module):
    """Placeholder module: it defines no layers and no forward computation yet."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Not implemented.

        Raises:
            NotImplementedError: Always. Failing loudly keeps an accidental
                call from silently propagating ``None`` into downstream code.
        """
        raise NotImplementedError("IntCompressor.forward is not implemented yet")
|
||||
73
mini-nav/compressors/train.py
Normal file
73
mini-nav/compressors/train.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from compressors import FloatCompressor
|
||||
from datasets import load_dataset
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import AutoImageProcessor, AutoModel
|
||||
|
||||
|
||||
def train(dinov2: nn.Module, epoch_size: int, batch_size: int) -> nn.Module:
    """Train a ``FloatCompressor`` to reproduce mean-pooled DINOv2 features.

    A frozen ``facebook/dinov2-large`` model acts as the teacher. For every
    CIFAR-10 training batch the teacher's last hidden state is mean-pooled and
    L2-normalised; the compressor receives the raw teacher tokens and is
    optimised so that its reconstruction matches that normalised embedding
    (MSE plus ``1 - cosine similarity``).

    Args:
        dinov2: Currently unused. The teacher is always loaded from
            ``facebook/dinov2-large`` inside this function; the parameter is
            kept so existing callers keep working.
        epoch_size: Number of passes over the training split.
        batch_size: Batch size handed to the ``DataLoader``.

    Returns:
        The trained compressor, on the device used for training. The function
        used to return ``None``, which discarded the training result.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ds = load_dataset("uoft-cs/cifar10", split="train").with_format("torch")
    dataloader = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=4)

    processor = AutoImageProcessor.from_pretrained(
        "facebook/dinov2-large", device_map=device
    )
    dino = AutoModel.from_pretrained("facebook/dinov2-large", device_map=device)

    # Freeze the teacher: eval mode and no gradients for any of its weights.
    dino.eval()
    for p in dino.parameters():
        p.requires_grad = False

    compressor = FloatCompressor().to(device)

    # Only the compressor's parameters are optimised.
    optimizer = torch.optim.AdamW(compressor.parameters(), lr=1e-4)

    for epoch in range(epoch_size):
        train_bar = tqdm(dataloader, desc=f"Epoch [{epoch + 1}/{epoch_size}]")

        for batch in train_bar:
            imgs = batch["img"]

            # ---- teacher forward ----
            with torch.no_grad():
                inputs = processor(imgs, return_tensors="pt").to(device)

                teacher_tokens = dino(**inputs).last_hidden_state  # [B,N,1024]

                teacher_embed = teacher_tokens.mean(dim=1)
                teacher_embed = F.normalize(teacher_embed, dim=-1)  # [B,1024]

            # ---- student forward ----
            # The compressed code is not part of the loss; only the
            # reconstruction is supervised.
            _, recon = compressor(teacher_tokens)

            # ---- loss ----
            mse_loss = F.mse_loss(recon, teacher_embed)
            cos_loss = 1 - F.cosine_similarity(recon, teacher_embed, dim=-1).mean()
            loss = mse_loss + cos_loss

            # ---- backward ----
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_bar.set_postfix(loss=loss.item())

    return compressor
|
||||
@@ -1,4 +1,15 @@
|
||||
from visualizer import app
|
||||
import argparse
|
||||
|
||||
def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this entry point.

    The single positional ``command`` is optional: ``train`` runs compressor
    training, ``visualize`` (the default when nothing is given) starts the
    visualizer app. A required positional would make every invocation carry a
    non-empty value, leaving the visualizer branch unreachable.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "command",
        nargs="?",
        choices=("train", "visualize"),
        default="visualize",
        help="'train' to train the compressor, 'visualize' (default) to run the app",
    )
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()

    if args.command == "train":
        # Imported lazily so the visualizer path does not load the training stack.
        from compressors import FloatCompressor, train

        # NOTE(review): train() as written in compressors/train.py does not
        # read its first argument; the call is kept as-is for compatibility.
        train(FloatCompressor(), 1, 32)
    else:
        from visualizer import app

        app.run(debug=True)
|
||||
|
||||
Reference in New Issue
Block a user