refactor(ui): replace tqdm with rich for enhanced console output

This commit is contained in:
2026-03-06 16:20:38 +08:00
parent 4a6918ce56
commit e832f9d656
9 changed files with 113 additions and 95 deletions

View File

@@ -7,6 +7,10 @@ import lancedb
from benchmarks.datasets import HuggingFaceDataset, LocalDataset
from benchmarks.tasks import get_task
from configs.models import BenchmarkConfig, DatasetSourceConfig
from rich.console import Console
from rich.table import Table
console = Console()
def create_dataset(config: DatasetSourceConfig) -> Any:
@@ -92,7 +96,9 @@ def _ensure_table(
if table_name in existing_tables:
table = db.open_table(table_name)
if table.schema != schema:
print(f"Table '{table_name}' schema mismatch, rebuilding.")
console.print(
f"[yellow]Table '{table_name}' schema mismatch, rebuilding.[/yellow]"
)
db.drop_table(table_name)
table = db.create_table(table_name, schema=schema)
else:
@@ -101,6 +107,29 @@ def _ensure_table(
return table
def _print_benchmark_info(
config: BenchmarkConfig, vector_dim: int, table_name: str, table_count: int
) -> None:
"""Print benchmark configuration info using Rich table.
Args:
config: Benchmark configuration.
vector_dim: Feature vector dimension.
table_name: Database table name.
table_count: Number of entries in the table.
"""
table = Table(title="Benchmark Configuration", show_header=False)
table.add_column("Key", style="cyan", no_wrap=True)
table.add_column("Value", style="magenta")
table.add_row("Dataset", f"{config.dataset.source_type} - {config.dataset.path}")
table.add_row("Model Output Dimension", str(vector_dim))
table.add_row("Table Name", table_name)
table.add_row("Table Entries", str(table_count))
console.print(table)
def run_benchmark(
model: Any,
processor: Any,
@@ -127,13 +156,10 @@ def run_benchmark(
Raises:
ValueError: If benchmark is not enabled in config.
"""
if not config.enabled:
raise ValueError(
"Benchmark is not enabled. Set benchmark.enabled=true in config.yaml"
)
# Create dataset
print(f"Loading dataset: {config.dataset.source_type} - {config.dataset.path}")
console.print(
f"[cyan]Loading dataset:[/cyan] {config.dataset.source_type} - {config.dataset.path}"
)
dataset = create_dataset(config.dataset)
# Get train and test splits
@@ -152,7 +178,7 @@ def run_benchmark(
from utils.feature_extractor import infer_vector_dim
vector_dim = infer_vector_dim(processor, model, sample_image)
print(f"Model output dimension: {vector_dim}")
console.print(f"[cyan]Model output dimension:[/cyan] {vector_dim}")
# Ensure table exists with correct schema
table = _ensure_table(config, model_name, vector_dim)
@@ -161,26 +187,20 @@ def run_benchmark(
# Check if database is already built
table_count = table.count_rows()
if table_count > 0:
print(
f"Table '{table_name}' already has {table_count} entries, skipping database build."
console.print(
f"[yellow]Table '{table_name}' already has {table_count} entries, skipping database build.[/yellow]"
)
else:
# Create and run benchmark task
task = get_task(config.task.type, top_k=config.task.top_k)
print(f"Building database with {len(train_dataset)} training samples...")
console.print(
f"[cyan]Building database[/cyan] with {len(train_dataset)} training samples..."
)
task.build_database(model, processor, train_dataset, table, config.batch_size)
# Run evaluation
# Run evaluation (results with Rich table will be printed by the task)
task = get_task(config.task.type, top_k=config.task.top_k)
print(f"Evaluating on {len(test_dataset)} test samples...")
console.print(f"[cyan]Evaluating[/cyan] on {len(test_dataset)} test samples...")
results = task.evaluate(model, processor, test_dataset, table, config.batch_size)
# Print results
print("\n=== Benchmark Results ===")
print(f"Dataset: {config.dataset.path}")
print(f"Task: {config.task.name}")
print(f"Top-K: {results['top_k']}")
print(f"Accuracy: {results['accuracy']:.4f}")
print(f"Correct: {results['correct']}/{results['total']}")
return results

View File

@@ -6,11 +6,10 @@ import lancedb
import pyarrow as pa
from benchmarks.base import BaseBenchmarkTask
from benchmarks.tasks.registry import RegisterTask
from rich.progress import track
from torch import nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import BitImageProcessorFast
from utils.feature_extractor import extract_batch_features, infer_vector_dim
@@ -47,11 +46,11 @@ def _establish_eval_database(
dataloader: DataLoader for the training dataset.
"""
# Extract all features using the utility function
all_features = extract_batch_features(processor, model, dataloader, show_progress=True)
all_features = extract_batch_features(processor, model, dataloader)
# Store features to database
global_idx = 0
for batch in tqdm(dataloader, desc="Storing eval database"):
for batch in track(dataloader, description="Storing eval database"):
labels = batch["label"]
labels_list = labels.tolist()
batch_size = len(labels_list)
@@ -89,13 +88,13 @@ def _evaluate_recall(
A tuple of (correct_count, total_count).
"""
# Extract all features using the utility function
all_features = extract_batch_features(processor, model, dataloader, show_progress=True)
all_features = extract_batch_features(processor, model, dataloader)
correct = 0
total = 0
feature_idx = 0
for batch in tqdm(dataloader, desc=f"Evaluating Recall@{top_k}"):
for batch in track(dataloader, description=f"Evaluating Recall@{top_k}"):
labels = batch["label"]
labels_list = labels.tolist()