From e832f9d65643e5b00542f6b8a402932b83be7962 Mon Sep 17 00:00:00 2001 From: SikongJueluo Date: Fri, 6 Mar 2026 16:20:38 +0800 Subject: [PATCH] refactor(ui): replace tqdm with rich for enhanced console output --- mini-nav/benchmarks/runner.py | 62 +++++++++----- mini-nav/benchmarks/tasks/retrieval.py | 11 ++- mini-nav/commands/benchmark.py | 7 -- mini-nav/compressors/train.py | 108 ++++++++++++++----------- mini-nav/configs/config.yaml | 1 - mini-nav/configs/models.py | 1 - mini-nav/data_loading/synthesizer.py | 4 +- mini-nav/feature_retrieval.py | 4 +- mini-nav/utils/feature_extractor.py | 10 +-- 9 files changed, 113 insertions(+), 95 deletions(-) diff --git a/mini-nav/benchmarks/runner.py b/mini-nav/benchmarks/runner.py index 4f8cdab..f6505cb 100644 --- a/mini-nav/benchmarks/runner.py +++ b/mini-nav/benchmarks/runner.py @@ -7,6 +7,10 @@ import lancedb from benchmarks.datasets import HuggingFaceDataset, LocalDataset from benchmarks.tasks import get_task from configs.models import BenchmarkConfig, DatasetSourceConfig +from rich.console import Console +from rich.table import Table + +console = Console() def create_dataset(config: DatasetSourceConfig) -> Any: @@ -92,7 +96,9 @@ def _ensure_table( if table_name in existing_tables: table = db.open_table(table_name) if table.schema != schema: - print(f"Table '{table_name}' schema mismatch, rebuilding.") + console.print( + f"[yellow]Table '{table_name}' schema mismatch, rebuilding.[/yellow]" + ) db.drop_table(table_name) table = db.create_table(table_name, schema=schema) else: @@ -101,6 +107,29 @@ def _ensure_table( return table +def _print_benchmark_info( + config: BenchmarkConfig, vector_dim: int, table_name: str, table_count: int +) -> None: + """Print benchmark configuration info using Rich table. + + Args: + config: Benchmark configuration. + vector_dim: Feature vector dimension. + table_name: Database table name. + table_count: Number of entries in the table. + """ + table = Table(title="Benchmark Configuration", show_header=False) + table.add_column("Key", style="cyan", no_wrap=True) + table.add_column("Value", style="magenta") + + table.add_row("Dataset", f"{config.dataset.source_type} - {config.dataset.path}") + table.add_row("Model Output Dimension", str(vector_dim)) + table.add_row("Table Name", table_name) + table.add_row("Table Entries", str(table_count)) + + console.print(table) + + def run_benchmark( model: Any, processor: Any, @@ -127,13 +156,10 @@ def run_benchmark( Raises: ValueError: If benchmark is not enabled in config. """ - if not config.enabled: - raise ValueError( - "Benchmark is not enabled. Set benchmark.enabled=true in config.yaml" - ) - # Create dataset - print(f"Loading dataset: {config.dataset.source_type} - {config.dataset.path}") + console.print( + f"[cyan]Loading dataset:[/cyan] {config.dataset.source_type} - {config.dataset.path}" + ) dataset = create_dataset(config.dataset) # Get train and test splits @@ -152,7 +178,7 @@ def run_benchmark( from utils.feature_extractor import infer_vector_dim vector_dim = infer_vector_dim(processor, model, sample_image) - print(f"Model output dimension: {vector_dim}") + console.print(f"[cyan]Model output dimension:[/cyan] {vector_dim}") # Ensure table exists with correct schema table = _ensure_table(config, model_name, vector_dim) @@ -161,26 +187,20 @@ def run_benchmark( # Check if database is already built table_count = table.count_rows() if table_count > 0: - print( - f"Table '{table_name}' already has {table_count} entries, skipping database build." + console.print( + f"[yellow]Table '{table_name}' already has {table_count} entries, skipping database build.[/yellow]" ) else: # Create and run benchmark task task = get_task(config.task.type, top_k=config.task.top_k) - print(f"Building database with {len(train_dataset)} training samples...") + console.print( + f"[cyan]Building database[/cyan] with {len(train_dataset)} training samples..." + ) task.build_database(model, processor, train_dataset, table, config.batch_size) - # Run evaluation + # Run evaluation (results with Rich table will be printed by the task) task = get_task(config.task.type, top_k=config.task.top_k) - print(f"Evaluating on {len(test_dataset)} test samples...") + console.print(f"[cyan]Evaluating[/cyan] on {len(test_dataset)} test samples...") results = task.evaluate(model, processor, test_dataset, table, config.batch_size) - # Print results - print("\n=== Benchmark Results ===") - print(f"Dataset: {config.dataset.path}") - print(f"Task: {config.task.name}") - print(f"Top-K: {results['top_k']}") - print(f"Accuracy: {results['accuracy']:.4f}") - print(f"Correct: {results['correct']}/{results['total']}") - return results diff --git a/mini-nav/benchmarks/tasks/retrieval.py b/mini-nav/benchmarks/tasks/retrieval.py index c230dbe..c81c48a 100644 --- a/mini-nav/benchmarks/tasks/retrieval.py +++ b/mini-nav/benchmarks/tasks/retrieval.py @@ -6,11 +6,10 @@ import lancedb import pyarrow as pa from benchmarks.base import BaseBenchmarkTask from benchmarks.tasks.registry import RegisterTask +from rich.progress import track from torch import nn from torch.utils.data import DataLoader -from tqdm.auto import tqdm from transformers import BitImageProcessorFast - from utils.feature_extractor import extract_batch_features, infer_vector_dim @@ -47,11 +46,11 @@ def _establish_eval_database( dataloader: DataLoader for the training dataset. """ # Extract all features using the utility function - all_features = extract_batch_features(processor, model, dataloader, show_progress=True) + all_features = extract_batch_features(processor, model, dataloader) # Store features to database global_idx = 0 - for batch in tqdm(dataloader, desc="Storing eval database"): + for batch in track(dataloader, description="Storing eval database"): labels = batch["label"] labels_list = labels.tolist() batch_size = len(labels_list) @@ -89,13 +88,13 @@ def _evaluate_recall( A tuple of (correct_count, total_count). """ # Extract all features using the utility function - all_features = extract_batch_features(processor, model, dataloader, show_progress=True) + all_features = extract_batch_features(processor, model, dataloader) correct = 0 total = 0 feature_idx = 0 - for batch in tqdm(dataloader, desc=f"Evaluating Recall@{top_k}"): + for batch in track(dataloader, description=f"Evaluating Recall@{top_k}"): labels = batch["label"] labels_list = labels.tolist() diff --git a/mini-nav/commands/benchmark.py b/mini-nav/commands/benchmark.py index 758c6b7..755a0f7 100644 --- a/mini-nav/commands/benchmark.py +++ b/mini-nav/commands/benchmark.py @@ -21,13 +21,6 @@ def benchmark( config = cfg_manager.get() benchmark_cfg = config.benchmark - if not benchmark_cfg.enabled: - typer.echo( - "Benchmark is not enabled. Set benchmark.enabled=true in config.yaml", - err=True, - ) - raise typer.Exit(code=1) - device = get_device() model_cfg = config.model diff --git a/mini-nav/compressors/train.py b/mini-nav/compressors/train.py index dde89f1..905e6c4 100644 --- a/mini-nav/compressors/train.py +++ b/mini-nav/compressors/train.py @@ -8,7 +8,7 @@ from compressors import HashCompressor, HashLoss from configs import cfg_manager from torch import nn from torch.utils.data import DataLoader -from tqdm.auto import tqdm +from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn from transformers import AutoImageProcessor, AutoModel from datasets import load_dataset @@ -104,59 +104,71 @@ def train( ) try: - for epoch in range(start_epoch, epoch_size): - train_bar = tqdm(dataloader, desc=f"Epoch [{epoch + 1}/{epoch_size}]") - - for batch in train_bar: - global_step += 1 - - # ---- training step ---- - imgs = batch["img"] - - # ---- teacher forward ---- - with torch.no_grad(): - inputs = processor(imgs, return_tensors="pt").to(device) - - teacher_tokens = dino(**inputs).last_hidden_state # [B,N,1024] - - teacher_embed = teacher_tokens.mean(dim=1) - teacher_embed = F.normalize(teacher_embed, dim=-1) # [B,1024] - - # ---- student forward ---- - logits, hash_codes, bits = compressor(teacher_tokens) - - # ---- generate positive mask ---- - labels = batch["label"] - # positive_mask[i,j] = True if labels[i] == labels[j] - positive_mask = labels.unsqueeze(0) == labels.unsqueeze(1) # [B, B] - - # ---- loss ---- - total_loss, components = loss_fn( - logits=logits, - hash_codes=hash_codes, - teacher_embed=teacher_embed, - positive_mask=positive_mask, + progress = Progress( + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), + TimeRemainingColumn(), + ) + with progress: + for epoch in range(start_epoch, epoch_size): + task_id = progress.add_task( + f"Epoch [{epoch + 1}/{epoch_size}]", total=len(dataloader) ) - # ---- backward ---- - optimizer.zero_grad() - total_loss.backward() - optimizer.step() + for batch in dataloader: + progress.update(task_id, advance=1) - # ---- logging ---- - train_bar.set_postfix( - loss=f"{components['total']:.4f}", - cont=f"{components['contrastive']:.2f}", - distill=f"{components['distill']:.3f}", - quant=f"{components['quantization']:.2f}", - ) + global_step += 1 - # ---- periodic save ---- - if global_step % save_every == 0: - save_checkpoint( - compressor, optimizer, epoch, global_step, checkpoint_path + # ---- training step ---- + imgs = batch["img"] + + # ---- teacher forward ---- + with torch.no_grad(): + inputs = processor(imgs, return_tensors="pt").to(device) + + teacher_tokens = dino(**inputs).last_hidden_state # [B,N,1024] + + teacher_embed = teacher_tokens.mean(dim=1) + teacher_embed = F.normalize(teacher_embed, dim=-1) # [B,1024] + + # ---- student forward ---- + logits, hash_codes, bits = compressor(teacher_tokens) + + # ---- generate positive mask ---- + labels = batch["label"] + # positive_mask[i,j] = True if labels[i] == labels[j] + positive_mask = labels.unsqueeze(0) == labels.unsqueeze(1) # [B, B] + + # ---- loss ---- + total_loss, components = loss_fn( + logits=logits, + hash_codes=hash_codes, + teacher_embed=teacher_embed, + positive_mask=positive_mask, ) + # ---- backward ---- + optimizer.zero_grad() + total_loss.backward() + optimizer.step() + + # ---- logging ---- + progress.update( + task_id, + description=f"Epoch [{epoch + 1}/{epoch_size}] " + f"loss={components['total']:.4f} " + f"cont={components['contrastive']:.2f} " + f"distill={components['distill']:.3f}", + ) + + # ---- periodic save ---- + if global_step % save_every == 0: + save_checkpoint( + compressor, optimizer, epoch, global_step, checkpoint_path + ) + except KeyboardInterrupt: print("\n⚠️ Training interrupted, saving checkpoint...") save_checkpoint(compressor, optimizer, epoch, global_step, checkpoint_path) diff --git a/mini-nav/configs/config.yaml b/mini-nav/configs/config.yaml index 70d7700..f3535c3 100644 --- a/mini-nav/configs/config.yaml +++ b/mini-nav/configs/config.yaml @@ -21,7 +21,6 @@ dataset: seed: 42 benchmark: - enabled: true dataset: source_type: "huggingface" path: "uoft-cs/cifar10" diff --git a/mini-nav/configs/models.py b/mini-nav/configs/models.py index 679d281..b3e244c 100644 --- a/mini-nav/configs/models.py +++ b/mini-nav/configs/models.py @@ -124,7 +124,6 @@ class BenchmarkConfig(BaseModel): model_config = ConfigDict(extra="ignore") - enabled: bool = Field(default=False, description="Enable benchmark evaluation") dataset: DatasetSourceConfig = Field(default_factory=DatasetSourceConfig) task: BenchmarkTaskConfig = Field(default_factory=BenchmarkTaskConfig) batch_size: int = Field(default=64, gt=0, description="Batch size for DataLoader") diff --git a/mini-nav/data_loading/synthesizer.py b/mini-nav/data_loading/synthesizer.py index 934caa2..5401acc 100644 --- a/mini-nav/data_loading/synthesizer.py +++ b/mini-nav/data_loading/synthesizer.py @@ -7,7 +7,7 @@ from pathlib import Path import numpy as np from PIL import Image from PIL.Image import Resampling -from tqdm.auto import tqdm +from rich.progress import track class ImageSynthesizer: @@ -287,7 +287,7 @@ class ImageSynthesizer: generated_files: list[Path] = [] - for i in tqdm(range(self.num_scenes), desc="Generating scenes"): + for i in track(range(self.num_scenes), description="Generating scenes"): # Update seed for each scene random.seed(self.seed + i) np.random.seed(self.seed + i) diff --git a/mini-nav/feature_retrieval.py b/mini-nav/feature_retrieval.py index 94bef54..1c08262 100644 --- a/mini-nav/feature_retrieval.py +++ b/mini-nav/feature_retrieval.py @@ -6,7 +6,7 @@ from database import db_manager from PIL import Image from PIL.PngImagePlugin import PngImageFile from torch import nn -from tqdm.auto import tqdm +from rich.progress import track from transformers import ( AutoImageProcessor, AutoModel, @@ -93,7 +93,7 @@ class FeatureRetrieval: self.processor, self.model, images, batch_size=batch_size ) - for i in tqdm(range(len(labels)), desc="Storing to database"): + for i in track(range(len(labels)), description="Storing to database"): batch_label = labels[i] if label_map is None else label_map[labels[i]] # Store to database diff --git a/mini-nav/utils/feature_extractor.py b/mini-nav/utils/feature_extractor.py index 01acb29..bbed324 100644 --- a/mini-nav/utils/feature_extractor.py +++ b/mini-nav/utils/feature_extractor.py @@ -7,7 +7,7 @@ from PIL import Image from torch import nn from torch.utils.data import DataLoader from transformers import BitImageProcessorFast -from tqdm.auto import tqdm +from rich.progress import track def _extract_features_from_output(output: Any) -> torch.Tensor: @@ -86,7 +86,6 @@ def extract_batch_features( model: nn.Module, images: Union[List[Any], Any], batch_size: int = 32, - show_progress: bool = False, ) -> torch.Tensor: """Extract features from a batch of images. @@ -95,7 +94,6 @@ def extract_batch_features( model: Feature extraction model. images: List of images, DataLoader, or other iterable. batch_size: Batch size for processing. - show_progress: Whether to show progress bar. Returns: Tensor of shape [batch_size, feature_dim]. @@ -106,8 +104,7 @@ def extract_batch_features( # Handle DataLoader input if isinstance(images, DataLoader): all_features = [] - iterator = tqdm(images, desc="Extracting features") if show_progress else images - for batch in iterator: + for batch in track(images, description="Extracting features"): imgs = batch["img"] if isinstance(batch, dict) else batch[0] inputs = processor(images=imgs, return_tensors="pt") inputs.to(device) @@ -118,8 +115,7 @@ def extract_batch_features( # Handle list of images all_features = [] - iterator = tqdm(range(0, len(images), batch_size), desc="Extracting features") if show_progress else range(0, len(images), batch_size) - for i in iterator: + for i in track(range(0, len(images), batch_size), description="Extracting features"): batch_imgs = images[i : i + batch_size] inputs = processor(images=batch_imgs, return_tensors="pt") inputs.to(device)