refactor(ui): replace tqdm with rich for enhanced console output

This commit is contained in:
2026-03-06 16:20:38 +08:00
parent 4a6918ce56
commit e832f9d656
9 changed files with 113 additions and 95 deletions

View File

@@ -8,7 +8,7 @@ from compressors import HashCompressor, HashLoss
from configs import cfg_manager
from torch import nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn
from transformers import AutoImageProcessor, AutoModel
from datasets import load_dataset
@@ -104,59 +104,71 @@ def train(
)
try:
for epoch in range(start_epoch, epoch_size):
train_bar = tqdm(dataloader, desc=f"Epoch [{epoch + 1}/{epoch_size}]")
for batch in train_bar:
global_step += 1
# ---- training step ----
imgs = batch["img"]
# ---- teacher forward ----
with torch.no_grad():
inputs = processor(imgs, return_tensors="pt").to(device)
teacher_tokens = dino(**inputs).last_hidden_state # [B,N,1024]
teacher_embed = teacher_tokens.mean(dim=1)
teacher_embed = F.normalize(teacher_embed, dim=-1) # [B,1024]
# ---- student forward ----
logits, hash_codes, bits = compressor(teacher_tokens)
# ---- generate positive mask ----
labels = batch["label"]
# positive_mask[i,j] = True if labels[i] == labels[j]
positive_mask = labels.unsqueeze(0) == labels.unsqueeze(1) # [B, B]
# ---- loss ----
total_loss, components = loss_fn(
logits=logits,
hash_codes=hash_codes,
teacher_embed=teacher_embed,
positive_mask=positive_mask,
progress = Progress(
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
TimeRemainingColumn(),
)
with progress:
for epoch in range(start_epoch, epoch_size):
task_id = progress.add_task(
f"Epoch [{epoch + 1}/{epoch_size}]", total=len(dataloader)
)
# ---- backward ----
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
for batch in dataloader:
progress.update(task_id, advance=1)
# ---- logging ----
train_bar.set_postfix(
loss=f"{components['total']:.4f}",
cont=f"{components['contrastive']:.2f}",
distill=f"{components['distill']:.3f}",
quant=f"{components['quantization']:.2f}",
)
global_step += 1
# ---- periodic save ----
if global_step % save_every == 0:
save_checkpoint(
compressor, optimizer, epoch, global_step, checkpoint_path
# ---- training step ----
imgs = batch["img"]
# ---- teacher forward ----
with torch.no_grad():
inputs = processor(imgs, return_tensors="pt").to(device)
teacher_tokens = dino(**inputs).last_hidden_state # [B,N,1024]
teacher_embed = teacher_tokens.mean(dim=1)
teacher_embed = F.normalize(teacher_embed, dim=-1) # [B,1024]
# ---- student forward ----
logits, hash_codes, bits = compressor(teacher_tokens)
# ---- generate positive mask ----
labels = batch["label"]
# positive_mask[i,j] = True if labels[i] == labels[j]
positive_mask = labels.unsqueeze(0) == labels.unsqueeze(1) # [B, B]
# ---- loss ----
total_loss, components = loss_fn(
logits=logits,
hash_codes=hash_codes,
teacher_embed=teacher_embed,
positive_mask=positive_mask,
)
# ---- backward ----
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
# ---- logging ----
progress.update(
task_id,
description=f"Epoch [{epoch + 1}/{epoch_size}] "
f"loss={components['total']:.4f} "
f"cont={components['contrastive']:.2f} "
f"distill={components['distill']:.3f}",
)
# ---- periodic save ----
if global_step % save_every == 0:
save_checkpoint(
compressor, optimizer, epoch, global_step, checkpoint_path
)
except KeyboardInterrupt:
print("\n⚠️ Training interrupted, saving checkpoint...")
save_checkpoint(compressor, optimizer, epoch, global_step, checkpoint_path)