mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
feat(compressors): add neural compression modules and training pipeline
This commit is contained in:
5
mini-nav/compressors/__init__.py
Normal file
5
mini-nav/compressors/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
from .float_compressor import FloatCompressor
|
||||||
|
from .int_compressor import IntCompressor
|
||||||
|
from .train import train
|
||||||
|
|
||||||
|
__all__ = ["train", "FloatCompressor", "IntCompressor"]
|
||||||
27
mini-nav/compressors/float_compressor.py
Normal file
27
mini-nav/compressors/float_compressor.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class FloatCompressor(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# projection head
|
||||||
|
self.proj = nn.Sequential(
|
||||||
|
nn.Linear(1024, 1024),
|
||||||
|
nn.LayerNorm(1024),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Linear(1024, 512),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.recover = nn.Linear(512, 1024)
|
||||||
|
|
||||||
|
def forward(self, tokens):
|
||||||
|
pooled = tokens.mean(dim=1) # [B,1024]
|
||||||
|
|
||||||
|
z512 = self.proj(pooled) # [B,512]
|
||||||
|
z512 = F.normalize(z512, dim=-1)
|
||||||
|
|
||||||
|
recon = self.recover(z512) # [B,1024]
|
||||||
|
|
||||||
|
return z512, recon
|
||||||
9
mini-nav/compressors/int_compressor.py
Normal file
9
mini-nav/compressors/int_compressor.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class IntCompressor(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
pass
|
||||||
73
mini-nav/compressors/train.py
Normal file
73
mini-nav/compressors/train.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from compressors import FloatCompressor
|
||||||
|
from datasets import load_dataset
|
||||||
|
from torch import nn
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
from transformers import AutoImageProcessor, AutoModel
|
||||||
|
|
||||||
|
|
||||||
|
def train(dinov2: nn.Module, epoch_size: int, batch_size: int):
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
ds = load_dataset("uoft-cs/cifar10", split="train").with_format("torch")
|
||||||
|
dataloader = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=4)
|
||||||
|
label_map = [
|
||||||
|
"airplane",
|
||||||
|
"automobile",
|
||||||
|
"bird",
|
||||||
|
"cat",
|
||||||
|
"deer",
|
||||||
|
"dog",
|
||||||
|
"frog",
|
||||||
|
"horse",
|
||||||
|
"ship",
|
||||||
|
"truck",
|
||||||
|
]
|
||||||
|
|
||||||
|
processor = AutoImageProcessor.from_pretrained(
|
||||||
|
"facebook/dinov2-large", device_map=device
|
||||||
|
)
|
||||||
|
dino = AutoModel.from_pretrained("facebook/dinov2-large", device_map=device)
|
||||||
|
dino.eval()
|
||||||
|
for p in dino.parameters():
|
||||||
|
p.requires_grad = False
|
||||||
|
|
||||||
|
compressor = FloatCompressor().to(device)
|
||||||
|
|
||||||
|
optimizer = torch.optim.AdamW(compressor.parameters(), lr=1e-4)
|
||||||
|
|
||||||
|
for epoch in range(epoch_size):
|
||||||
|
train_bar = tqdm(dataloader, desc=f"Epoch [{epoch + 1}/{epoch_size}]")
|
||||||
|
|
||||||
|
for batch in train_bar:
|
||||||
|
imgs = batch["img"]
|
||||||
|
|
||||||
|
# ---- teacher forward ----
|
||||||
|
with torch.no_grad():
|
||||||
|
inputs = processor(imgs, return_tensors="pt").to(device)
|
||||||
|
|
||||||
|
teacher_tokens = dino(**inputs).last_hidden_state
|
||||||
|
# [B,N,1024]
|
||||||
|
|
||||||
|
teacher_embed = teacher_tokens.mean(dim=1)
|
||||||
|
teacher_embed = F.normalize(teacher_embed, dim=-1)
|
||||||
|
# [B,1024]
|
||||||
|
|
||||||
|
# ---- student forward ----
|
||||||
|
z512, recon = compressor(teacher_tokens)
|
||||||
|
|
||||||
|
# ---- loss ----
|
||||||
|
mse_loss = F.mse_loss(recon, teacher_embed)
|
||||||
|
|
||||||
|
cos_loss = 1 - F.cosine_similarity(recon, teacher_embed, dim=-1).mean()
|
||||||
|
|
||||||
|
loss = mse_loss + cos_loss
|
||||||
|
|
||||||
|
# ---- backward ----
|
||||||
|
optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
train_bar.set_postfix(loss=loss.item())
|
||||||
@@ -1,4 +1,15 @@
|
|||||||
from visualizer import app
|
import argparse
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
app.run(debug=True)
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("train")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.train:
|
||||||
|
from compressors import FloatCompressor, train
|
||||||
|
|
||||||
|
train(FloatCompressor(), 1, 32)
|
||||||
|
else:
|
||||||
|
from visualizer import app
|
||||||
|
|
||||||
|
app.run(debug=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user