diff --git a/mini-nav/compressors/__init__.py b/mini-nav/compressors/__init__.py
new file mode 100644
index 0000000..6f77d5e
--- /dev/null
+++ b/mini-nav/compressors/__init__.py
@@ -0,0 +1,5 @@
+from .float_compressor import FloatCompressor
+from .int_compressor import IntCompressor
+from .train import train
+
+__all__ = ["train", "FloatCompressor", "IntCompressor"]
diff --git a/mini-nav/compressors/float_compressor.py b/mini-nav/compressors/float_compressor.py
new file mode 100644
index 0000000..fb08e28
--- /dev/null
+++ b/mini-nav/compressors/float_compressor.py
@@ -0,0 +1,27 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class FloatCompressor(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        # projection head: 1024-d pooled DINOv2 token -> 512-d code
+        self.proj = nn.Sequential(
+            nn.Linear(1024, 1024),
+            nn.LayerNorm(1024),
+            nn.GELU(),
+            nn.Linear(1024, 512),
+        )
+
+        self.recover = nn.Linear(512, 1024)
+
+    def forward(self, tokens):
+        pooled = tokens.mean(dim=1)  # [B,1024]
+
+        z512 = self.proj(pooled)  # [B,512]
+        z512 = F.normalize(z512, dim=-1)
+
+        recon = self.recover(z512)  # [B,1024]
+
+        return z512, recon
diff --git a/mini-nav/compressors/int_compressor.py b/mini-nav/compressors/int_compressor.py
new file mode 100644
index 0000000..f07da04
--- /dev/null
+++ b/mini-nav/compressors/int_compressor.py
@@ -0,0 +1,9 @@
+import torch.nn as nn
+
+
+class IntCompressor(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        raise NotImplementedError
diff --git a/mini-nav/compressors/train.py b/mini-nav/compressors/train.py
new file mode 100644
index 0000000..2bda281
--- /dev/null
+++ b/mini-nav/compressors/train.py
@@ -0,0 +1,72 @@
+import torch
+import torch.nn.functional as F
+from datasets import load_dataset
+from torch import nn
+from torch.utils.data import DataLoader
+from tqdm.auto import tqdm
+from transformers import AutoImageProcessor, AutoModel
+
+
+def train(compressor: nn.Module, epoch_size: int, batch_size: int):
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    ds = load_dataset("uoft-cs/cifar10", split="train").with_format("torch")
+    dataloader = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=4)
+    # CIFAR-10 class names, kept for reference (not used during training)
+    label_map = [
+        "airplane",
+        "automobile",
+        "bird",
+        "cat",
+        "deer",
+        "dog",
+        "frog",
+        "horse",
+        "ship",
+        "truck",
+    ]
+
+    # frozen DINOv2-large teacher
+    processor = AutoImageProcessor.from_pretrained("facebook/dinov2-large")
+    dino = AutoModel.from_pretrained("facebook/dinov2-large").to(device)
+    dino.eval()
+    for p in dino.parameters():
+        p.requires_grad = False
+
+    compressor = compressor.to(device)
+
+    optimizer = torch.optim.AdamW(compressor.parameters(), lr=1e-4)
+
+    for epoch in range(epoch_size):
+        train_bar = tqdm(dataloader, desc=f"Epoch [{epoch + 1}/{epoch_size}]")
+
+        for batch in train_bar:
+            imgs = batch["img"]
+
+            # ---- teacher forward ----
+            with torch.no_grad():
+                inputs = processor(imgs, return_tensors="pt").to(device)
+
+                teacher_tokens = dino(**inputs).last_hidden_state
+                # [B,N,1024]
+
+                teacher_embed = teacher_tokens.mean(dim=1)
+                teacher_embed = F.normalize(teacher_embed, dim=-1)
+                # [B,1024]
+
+            # ---- student forward ----
+            z512, recon = compressor(teacher_tokens)
+
+            # ---- loss: reconstruct the pooled teacher embedding ----
+            mse_loss = F.mse_loss(recon, teacher_embed)
+
+            cos_loss = 1 - F.cosine_similarity(recon, teacher_embed, dim=-1).mean()
+
+            loss = mse_loss + cos_loss
+
+            # ---- backward ----
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+            train_bar.set_postfix(loss=loss.item())
diff --git a/mini-nav/main.py b/mini-nav/main.py
index 9568012..bd499f8 100644
--- a/mini-nav/main.py
+++ b/mini-nav/main.py
@@ -1,4 +1,15 @@
-from visualizer import app
+import argparse
 
 if __name__ == "__main__":
-    app.run(debug=True)
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--train", action="store_true", help="train the compressor")
+    args = parser.parse_args()
+
+    if args.train:
+        from compressors import FloatCompressor, train
+
+        train(FloatCompressor(), 1, 32)
+    else:
+        from visualizer import app
+
+        app.run(debug=True)