refactor(ui): replace tqdm with rich for enhanced console output

This commit is contained in:
2026-03-06 16:20:38 +08:00
parent 4a6918ce56
commit e832f9d656
9 changed files with 113 additions and 95 deletions

View File

@@ -7,6 +7,10 @@ import lancedb
from benchmarks.datasets import HuggingFaceDataset, LocalDataset
from benchmarks.tasks import get_task
from configs.models import BenchmarkConfig, DatasetSourceConfig
from rich.console import Console
from rich.table import Table
console = Console()
def create_dataset(config: DatasetSourceConfig) -> Any:
@@ -92,7 +96,9 @@ def _ensure_table(
if table_name in existing_tables:
table = db.open_table(table_name)
if table.schema != schema:
print(f"Table '{table_name}' schema mismatch, rebuilding.")
console.print(
f"[yellow]Table '{table_name}' schema mismatch, rebuilding.[/yellow]"
)
db.drop_table(table_name)
table = db.create_table(table_name, schema=schema)
else:
@@ -101,6 +107,29 @@ def _ensure_table(
return table
def _print_benchmark_info(
    config: BenchmarkConfig, vector_dim: int, table_name: str, table_count: int
) -> None:
    """Display the benchmark configuration summary as a Rich table.

    Args:
        config: Benchmark configuration.
        vector_dim: Feature vector dimension.
        table_name: Database table name.
        table_count: Number of entries in the table.
    """
    info = Table(title="Benchmark Configuration", show_header=False)
    info.add_column("Key", style="cyan", no_wrap=True)
    info.add_column("Value", style="magenta")
    # Collect the key/value pairs once, then add them in a single pass.
    rows = (
        ("Dataset", f"{config.dataset.source_type} - {config.dataset.path}"),
        ("Model Output Dimension", str(vector_dim)),
        ("Table Name", table_name),
        ("Table Entries", str(table_count)),
    )
    for key, value in rows:
        info.add_row(key, value)
    console.print(info)
def run_benchmark(
model: Any,
processor: Any,
@@ -127,13 +156,10 @@ def run_benchmark(
Raises:
ValueError: If benchmark is not enabled in config.
"""
if not config.enabled:
raise ValueError(
"Benchmark is not enabled. Set benchmark.enabled=true in config.yaml"
)
# Create dataset
print(f"Loading dataset: {config.dataset.source_type} - {config.dataset.path}")
console.print(
f"[cyan]Loading dataset:[/cyan] {config.dataset.source_type} - {config.dataset.path}"
)
dataset = create_dataset(config.dataset)
# Get train and test splits
@@ -152,7 +178,7 @@ def run_benchmark(
from utils.feature_extractor import infer_vector_dim
vector_dim = infer_vector_dim(processor, model, sample_image)
print(f"Model output dimension: {vector_dim}")
console.print(f"[cyan]Model output dimension:[/cyan] {vector_dim}")
# Ensure table exists with correct schema
table = _ensure_table(config, model_name, vector_dim)
@@ -161,26 +187,20 @@ def run_benchmark(
# Check if database is already built
table_count = table.count_rows()
if table_count > 0:
print(
f"Table '{table_name}' already has {table_count} entries, skipping database build."
console.print(
f"[yellow]Table '{table_name}' already has {table_count} entries, skipping database build.[/yellow]"
)
else:
# Create and run benchmark task
task = get_task(config.task.type, top_k=config.task.top_k)
print(f"Building database with {len(train_dataset)} training samples...")
console.print(
f"[cyan]Building database[/cyan] with {len(train_dataset)} training samples..."
)
task.build_database(model, processor, train_dataset, table, config.batch_size)
# Run evaluation
# Run evaluation (results with Rich table will be printed by the task)
task = get_task(config.task.type, top_k=config.task.top_k)
print(f"Evaluating on {len(test_dataset)} test samples...")
console.print(f"[cyan]Evaluating[/cyan] on {len(test_dataset)} test samples...")
results = task.evaluate(model, processor, test_dataset, table, config.batch_size)
# Print results
print("\n=== Benchmark Results ===")
print(f"Dataset: {config.dataset.path}")
print(f"Task: {config.task.name}")
print(f"Top-K: {results['top_k']}")
print(f"Accuracy: {results['accuracy']:.4f}")
print(f"Correct: {results['correct']}/{results['total']}")
return results

View File

@@ -6,11 +6,10 @@ import lancedb
import pyarrow as pa
from benchmarks.base import BaseBenchmarkTask
from benchmarks.tasks.registry import RegisterTask
from rich.progress import track
from torch import nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import BitImageProcessorFast
from utils.feature_extractor import extract_batch_features, infer_vector_dim
@@ -47,11 +46,11 @@ def _establish_eval_database(
dataloader: DataLoader for the training dataset.
"""
# Extract all features using the utility function
all_features = extract_batch_features(processor, model, dataloader, show_progress=True)
all_features = extract_batch_features(processor, model, dataloader)
# Store features to database
global_idx = 0
for batch in tqdm(dataloader, desc="Storing eval database"):
for batch in track(dataloader, description="Storing eval database"):
labels = batch["label"]
labels_list = labels.tolist()
batch_size = len(labels_list)
@@ -89,13 +88,13 @@ def _evaluate_recall(
A tuple of (correct_count, total_count).
"""
# Extract all features using the utility function
all_features = extract_batch_features(processor, model, dataloader, show_progress=True)
all_features = extract_batch_features(processor, model, dataloader)
correct = 0
total = 0
feature_idx = 0
for batch in tqdm(dataloader, desc=f"Evaluating Recall@{top_k}"):
for batch in track(dataloader, description=f"Evaluating Recall@{top_k}"):
labels = batch["label"]
labels_list = labels.tolist()

View File

@@ -21,13 +21,6 @@ def benchmark(
config = cfg_manager.get()
benchmark_cfg = config.benchmark
if not benchmark_cfg.enabled:
typer.echo(
"Benchmark is not enabled. Set benchmark.enabled=true in config.yaml",
err=True,
)
raise typer.Exit(code=1)
device = get_device()
model_cfg = config.model

View File

@@ -8,7 +8,7 @@ from compressors import HashCompressor, HashLoss
from configs import cfg_manager
from torch import nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn
from transformers import AutoImageProcessor, AutoModel
from datasets import load_dataset
@@ -104,59 +104,71 @@ def train(
)
try:
for epoch in range(start_epoch, epoch_size):
train_bar = tqdm(dataloader, desc=f"Epoch [{epoch + 1}/{epoch_size}]")
for batch in train_bar:
global_step += 1
# ---- training step ----
imgs = batch["img"]
# ---- teacher forward ----
with torch.no_grad():
inputs = processor(imgs, return_tensors="pt").to(device)
teacher_tokens = dino(**inputs).last_hidden_state # [B,N,1024]
teacher_embed = teacher_tokens.mean(dim=1)
teacher_embed = F.normalize(teacher_embed, dim=-1) # [B,1024]
# ---- student forward ----
logits, hash_codes, bits = compressor(teacher_tokens)
# ---- generate positive mask ----
labels = batch["label"]
# positive_mask[i,j] = True if labels[i] == labels[j]
positive_mask = labels.unsqueeze(0) == labels.unsqueeze(1) # [B, B]
# ---- loss ----
total_loss, components = loss_fn(
logits=logits,
hash_codes=hash_codes,
teacher_embed=teacher_embed,
positive_mask=positive_mask,
progress = Progress(
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
TimeRemainingColumn(),
)
with progress:
for epoch in range(start_epoch, epoch_size):
task_id = progress.add_task(
f"Epoch [{epoch + 1}/{epoch_size}]", total=len(dataloader)
)
# ---- backward ----
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
for batch in dataloader:
progress.update(task_id, advance=1)
# ---- logging ----
train_bar.set_postfix(
loss=f"{components['total']:.4f}",
cont=f"{components['contrastive']:.2f}",
distill=f"{components['distill']:.3f}",
quant=f"{components['quantization']:.2f}",
)
global_step += 1
# ---- periodic save ----
if global_step % save_every == 0:
save_checkpoint(
compressor, optimizer, epoch, global_step, checkpoint_path
# ---- training step ----
imgs = batch["img"]
# ---- teacher forward ----
with torch.no_grad():
inputs = processor(imgs, return_tensors="pt").to(device)
teacher_tokens = dino(**inputs).last_hidden_state # [B,N,1024]
teacher_embed = teacher_tokens.mean(dim=1)
teacher_embed = F.normalize(teacher_embed, dim=-1) # [B,1024]
# ---- student forward ----
logits, hash_codes, bits = compressor(teacher_tokens)
# ---- generate positive mask ----
labels = batch["label"]
# positive_mask[i,j] = True if labels[i] == labels[j]
positive_mask = labels.unsqueeze(0) == labels.unsqueeze(1) # [B, B]
# ---- loss ----
total_loss, components = loss_fn(
logits=logits,
hash_codes=hash_codes,
teacher_embed=teacher_embed,
positive_mask=positive_mask,
)
# ---- backward ----
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
# ---- logging ----
progress.update(
task_id,
description=f"Epoch [{epoch + 1}/{epoch_size}] "
f"loss={components['total']:.4f} "
f"cont={components['contrastive']:.2f} "
f"distill={components['distill']:.3f}",
)
# ---- periodic save ----
if global_step % save_every == 0:
save_checkpoint(
compressor, optimizer, epoch, global_step, checkpoint_path
)
except KeyboardInterrupt:
print("\n⚠️ Training interrupted, saving checkpoint...")
save_checkpoint(compressor, optimizer, epoch, global_step, checkpoint_path)

View File

@@ -21,7 +21,6 @@ dataset:
seed: 42
benchmark:
enabled: true
dataset:
source_type: "huggingface"
path: "uoft-cs/cifar10"

View File

@@ -124,7 +124,6 @@ class BenchmarkConfig(BaseModel):
model_config = ConfigDict(extra="ignore")
enabled: bool = Field(default=False, description="Enable benchmark evaluation")
dataset: DatasetSourceConfig = Field(default_factory=DatasetSourceConfig)
task: BenchmarkTaskConfig = Field(default_factory=BenchmarkTaskConfig)
batch_size: int = Field(default=64, gt=0, description="Batch size for DataLoader")

View File

@@ -7,7 +7,7 @@ from pathlib import Path
import numpy as np
from PIL import Image
from PIL.Image import Resampling
from tqdm.auto import tqdm
from rich.progress import track
class ImageSynthesizer:
@@ -287,7 +287,7 @@ class ImageSynthesizer:
generated_files: list[Path] = []
for i in tqdm(range(self.num_scenes), desc="Generating scenes"):
for i in track(range(self.num_scenes), description="Generating scenes"):
# Update seed for each scene
random.seed(self.seed + i)
np.random.seed(self.seed + i)

View File

@@ -6,7 +6,7 @@ from database import db_manager
from PIL import Image
from PIL.PngImagePlugin import PngImageFile
from torch import nn
from tqdm.auto import tqdm
from rich.progress import track
from transformers import (
AutoImageProcessor,
AutoModel,
@@ -93,7 +93,7 @@ class FeatureRetrieval:
self.processor, self.model, images, batch_size=batch_size
)
for i in tqdm(range(len(labels)), desc="Storing to database"):
for i in track(range(len(labels)), description="Storing to database"):
batch_label = labels[i] if label_map is None else label_map[labels[i]]
# Store to database

View File

@@ -7,7 +7,7 @@ from PIL import Image
from torch import nn
from torch.utils.data import DataLoader
from transformers import BitImageProcessorFast
from tqdm.auto import tqdm
from rich.progress import track
def _extract_features_from_output(output: Any) -> torch.Tensor:
@@ -86,7 +86,6 @@ def extract_batch_features(
model: nn.Module,
images: Union[List[Any], Any],
batch_size: int = 32,
show_progress: bool = False,
) -> torch.Tensor:
"""Extract features from a batch of images.
@@ -95,7 +94,6 @@ def extract_batch_features(
model: Feature extraction model.
images: List of images, DataLoader, or other iterable.
batch_size: Batch size for processing.
show_progress: Whether to show progress bar.
Returns:
Tensor of shape [batch_size, feature_dim].
@@ -106,8 +104,7 @@ def extract_batch_features(
# Handle DataLoader input
if isinstance(images, DataLoader):
all_features = []
iterator = tqdm(images, desc="Extracting features") if show_progress else images
for batch in iterator:
for batch in track(images, description="Extracting features"):
imgs = batch["img"] if isinstance(batch, dict) else batch[0]
inputs = processor(images=imgs, return_tensors="pt")
inputs.to(device)
@@ -118,8 +115,7 @@ def extract_batch_features(
# Handle list of images
all_features = []
iterator = tqdm(range(0, len(images), batch_size), desc="Extracting features") if show_progress else range(0, len(images), batch_size)
for i in iterator:
for i in track(range(0, len(images), batch_size), description="Extracting features"):
batch_imgs = images[i : i + batch_size]
inputs = processor(images=batch_imgs, return_tensors="pt")
inputs.to(device)