mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
refactor(ui): replace tqdm with rich for enhanced console output
This commit is contained in:
@@ -8,7 +8,7 @@ from compressors import HashCompressor, HashLoss
|
||||
from configs import cfg_manager
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm.auto import tqdm
|
||||
from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn
|
||||
from transformers import AutoImageProcessor, AutoModel
|
||||
|
||||
from datasets import load_dataset
|
||||
@@ -104,59 +104,71 @@ def train(
|
||||
)
|
||||
|
||||
try:
|
||||
for epoch in range(start_epoch, epoch_size):
|
||||
train_bar = tqdm(dataloader, desc=f"Epoch [{epoch + 1}/{epoch_size}]")
|
||||
|
||||
for batch in train_bar:
|
||||
global_step += 1
|
||||
|
||||
# ---- training step ----
|
||||
imgs = batch["img"]
|
||||
|
||||
# ---- teacher forward ----
|
||||
with torch.no_grad():
|
||||
inputs = processor(imgs, return_tensors="pt").to(device)
|
||||
|
||||
teacher_tokens = dino(**inputs).last_hidden_state # [B,N,1024]
|
||||
|
||||
teacher_embed = teacher_tokens.mean(dim=1)
|
||||
teacher_embed = F.normalize(teacher_embed, dim=-1) # [B,1024]
|
||||
|
||||
# ---- student forward ----
|
||||
logits, hash_codes, bits = compressor(teacher_tokens)
|
||||
|
||||
# ---- generate positive mask ----
|
||||
labels = batch["label"]
|
||||
# positive_mask[i,j] = True if labels[i] == labels[j]
|
||||
positive_mask = labels.unsqueeze(0) == labels.unsqueeze(1) # [B, B]
|
||||
|
||||
# ---- loss ----
|
||||
total_loss, components = loss_fn(
|
||||
logits=logits,
|
||||
hash_codes=hash_codes,
|
||||
teacher_embed=teacher_embed,
|
||||
positive_mask=positive_mask,
|
||||
progress = Progress(
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
BarColumn(),
|
||||
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
||||
TimeRemainingColumn(),
|
||||
)
|
||||
with progress:
|
||||
for epoch in range(start_epoch, epoch_size):
|
||||
task_id = progress.add_task(
|
||||
f"Epoch [{epoch + 1}/{epoch_size}]", total=len(dataloader)
|
||||
)
|
||||
|
||||
# ---- backward ----
|
||||
optimizer.zero_grad()
|
||||
total_loss.backward()
|
||||
optimizer.step()
|
||||
for batch in dataloader:
|
||||
progress.update(task_id, advance=1)
|
||||
|
||||
# ---- logging ----
|
||||
train_bar.set_postfix(
|
||||
loss=f"{components['total']:.4f}",
|
||||
cont=f"{components['contrastive']:.2f}",
|
||||
distill=f"{components['distill']:.3f}",
|
||||
quant=f"{components['quantization']:.2f}",
|
||||
)
|
||||
global_step += 1
|
||||
|
||||
# ---- periodic save ----
|
||||
if global_step % save_every == 0:
|
||||
save_checkpoint(
|
||||
compressor, optimizer, epoch, global_step, checkpoint_path
|
||||
# ---- training step ----
|
||||
imgs = batch["img"]
|
||||
|
||||
# ---- teacher forward ----
|
||||
with torch.no_grad():
|
||||
inputs = processor(imgs, return_tensors="pt").to(device)
|
||||
|
||||
teacher_tokens = dino(**inputs).last_hidden_state # [B,N,1024]
|
||||
|
||||
teacher_embed = teacher_tokens.mean(dim=1)
|
||||
teacher_embed = F.normalize(teacher_embed, dim=-1) # [B,1024]
|
||||
|
||||
# ---- student forward ----
|
||||
logits, hash_codes, bits = compressor(teacher_tokens)
|
||||
|
||||
# ---- generate positive mask ----
|
||||
labels = batch["label"]
|
||||
# positive_mask[i,j] = True if labels[i] == labels[j]
|
||||
positive_mask = labels.unsqueeze(0) == labels.unsqueeze(1) # [B, B]
|
||||
|
||||
# ---- loss ----
|
||||
total_loss, components = loss_fn(
|
||||
logits=logits,
|
||||
hash_codes=hash_codes,
|
||||
teacher_embed=teacher_embed,
|
||||
positive_mask=positive_mask,
|
||||
)
|
||||
|
||||
# ---- backward ----
|
||||
optimizer.zero_grad()
|
||||
total_loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# ---- logging ----
|
||||
progress.update(
|
||||
task_id,
|
||||
description=f"Epoch [{epoch + 1}/{epoch_size}] "
|
||||
f"loss={components['total']:.4f} "
|
||||
f"cont={components['contrastive']:.2f} "
|
||||
f"distill={components['distill']:.3f}",
|
||||
)
|
||||
|
||||
# ---- periodic save ----
|
||||
if global_step % save_every == 0:
|
||||
save_checkpoint(
|
||||
compressor, optimizer, epoch, global_step, checkpoint_path
|
||||
)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n⚠️ Training interrupted, saving checkpoint...")
|
||||
save_checkpoint(compressor, optimizer, epoch, global_step, checkpoint_path)
|
||||
|
||||
Reference in New Issue
Block a user