mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
refactor(ui): replace tqdm with rich for enhanced console output
This commit is contained in:
@@ -7,6 +7,10 @@ import lancedb
|
||||
from benchmarks.datasets import HuggingFaceDataset, LocalDataset
|
||||
from benchmarks.tasks import get_task
|
||||
from configs.models import BenchmarkConfig, DatasetSourceConfig
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
def create_dataset(config: DatasetSourceConfig) -> Any:
|
||||
@@ -92,7 +96,9 @@ def _ensure_table(
|
||||
if table_name in existing_tables:
|
||||
table = db.open_table(table_name)
|
||||
if table.schema != schema:
|
||||
print(f"Table '{table_name}' schema mismatch, rebuilding.")
|
||||
console.print(
|
||||
f"[yellow]Table '{table_name}' schema mismatch, rebuilding.[/yellow]"
|
||||
)
|
||||
db.drop_table(table_name)
|
||||
table = db.create_table(table_name, schema=schema)
|
||||
else:
|
||||
@@ -101,6 +107,29 @@ def _ensure_table(
|
||||
return table
|
||||
|
||||
|
||||
def _print_benchmark_info(
    config: BenchmarkConfig, vector_dim: int, table_name: str, table_count: int
) -> None:
    """Render a headerless two-column summary of the benchmark setup.

    The summary is a Rich table titled "Benchmark Configuration" that is
    written through the module-level ``console``.

    Args:
        config: Benchmark configuration; only ``config.dataset.source_type``
            and ``config.dataset.path`` are read here.
        vector_dim: Feature vector dimension, shown as text.
        table_name: Database table name.
        table_count: Number of entries in the table, shown as text.
    """
    summary = Table(title="Benchmark Configuration", show_header=False)
    summary.add_column("Key", style="cyan", no_wrap=True)
    summary.add_column("Value", style="magenta")

    # (label, value) pairs in display order; every value must already be a
    # string because the cells are passed to the table as-is.
    entries = (
        ("Dataset", f"{config.dataset.source_type} - {config.dataset.path}"),
        ("Model Output Dimension", str(vector_dim)),
        ("Table Name", table_name),
        ("Table Entries", str(table_count)),
    )
    for label, value in entries:
        summary.add_row(label, value)

    console.print(summary)
|
||||
|
||||
|
||||
def run_benchmark(
|
||||
model: Any,
|
||||
processor: Any,
|
||||
@@ -127,13 +156,10 @@ def run_benchmark(
|
||||
Raises:
|
||||
ValueError: If benchmark is not enabled in config.
|
||||
"""
|
||||
if not config.enabled:
|
||||
raise ValueError(
|
||||
"Benchmark is not enabled. Set benchmark.enabled=true in config.yaml"
|
||||
)
|
||||
|
||||
# Create dataset
|
||||
print(f"Loading dataset: {config.dataset.source_type} - {config.dataset.path}")
|
||||
console.print(
|
||||
f"[cyan]Loading dataset:[/cyan] {config.dataset.source_type} - {config.dataset.path}"
|
||||
)
|
||||
dataset = create_dataset(config.dataset)
|
||||
|
||||
# Get train and test splits
|
||||
@@ -152,7 +178,7 @@ def run_benchmark(
|
||||
from utils.feature_extractor import infer_vector_dim
|
||||
|
||||
vector_dim = infer_vector_dim(processor, model, sample_image)
|
||||
print(f"Model output dimension: {vector_dim}")
|
||||
console.print(f"[cyan]Model output dimension:[/cyan] {vector_dim}")
|
||||
|
||||
# Ensure table exists with correct schema
|
||||
table = _ensure_table(config, model_name, vector_dim)
|
||||
@@ -161,26 +187,20 @@ def run_benchmark(
|
||||
# Check if database is already built
|
||||
table_count = table.count_rows()
|
||||
if table_count > 0:
|
||||
print(
|
||||
f"Table '{table_name}' already has {table_count} entries, skipping database build."
|
||||
console.print(
|
||||
f"[yellow]Table '{table_name}' already has {table_count} entries, skipping database build.[/yellow]"
|
||||
)
|
||||
else:
|
||||
# Create and run benchmark task
|
||||
task = get_task(config.task.type, top_k=config.task.top_k)
|
||||
print(f"Building database with {len(train_dataset)} training samples...")
|
||||
console.print(
|
||||
f"[cyan]Building database[/cyan] with {len(train_dataset)} training samples..."
|
||||
)
|
||||
task.build_database(model, processor, train_dataset, table, config.batch_size)
|
||||
|
||||
# Run evaluation
|
||||
# Run evaluation (results with Rich table will be printed by the task)
|
||||
task = get_task(config.task.type, top_k=config.task.top_k)
|
||||
print(f"Evaluating on {len(test_dataset)} test samples...")
|
||||
console.print(f"[cyan]Evaluating[/cyan] on {len(test_dataset)} test samples...")
|
||||
results = task.evaluate(model, processor, test_dataset, table, config.batch_size)
|
||||
|
||||
# Print results
|
||||
print("\n=== Benchmark Results ===")
|
||||
print(f"Dataset: {config.dataset.path}")
|
||||
print(f"Task: {config.task.name}")
|
||||
print(f"Top-K: {results['top_k']}")
|
||||
print(f"Accuracy: {results['accuracy']:.4f}")
|
||||
print(f"Correct: {results['correct']}/{results['total']}")
|
||||
|
||||
return results
|
||||
|
||||
@@ -6,11 +6,10 @@ import lancedb
|
||||
import pyarrow as pa
|
||||
from benchmarks.base import BaseBenchmarkTask
|
||||
from benchmarks.tasks.registry import RegisterTask
|
||||
from rich.progress import track
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import BitImageProcessorFast
|
||||
|
||||
from utils.feature_extractor import extract_batch_features, infer_vector_dim
|
||||
|
||||
|
||||
@@ -47,11 +46,11 @@ def _establish_eval_database(
|
||||
dataloader: DataLoader for the training dataset.
|
||||
"""
|
||||
# Extract all features using the utility function
|
||||
all_features = extract_batch_features(processor, model, dataloader, show_progress=True)
|
||||
all_features = extract_batch_features(processor, model, dataloader)
|
||||
|
||||
# Store features to database
|
||||
global_idx = 0
|
||||
for batch in tqdm(dataloader, desc="Storing eval database"):
|
||||
for batch in track(dataloader, description="Storing eval database"):
|
||||
labels = batch["label"]
|
||||
labels_list = labels.tolist()
|
||||
batch_size = len(labels_list)
|
||||
@@ -89,13 +88,13 @@ def _evaluate_recall(
|
||||
A tuple of (correct_count, total_count).
|
||||
"""
|
||||
# Extract all features using the utility function
|
||||
all_features = extract_batch_features(processor, model, dataloader, show_progress=True)
|
||||
all_features = extract_batch_features(processor, model, dataloader)
|
||||
|
||||
correct = 0
|
||||
total = 0
|
||||
feature_idx = 0
|
||||
|
||||
for batch in tqdm(dataloader, desc=f"Evaluating Recall@{top_k}"):
|
||||
for batch in track(dataloader, description=f"Evaluating Recall@{top_k}"):
|
||||
labels = batch["label"]
|
||||
labels_list = labels.tolist()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user