refactor(ui): replace tqdm with rich for enhanced console output

This commit is contained in:
2026-03-06 16:20:38 +08:00
parent 4a6918ce56
commit e832f9d656
9 changed files with 113 additions and 95 deletions

View File

@@ -6,11 +6,10 @@ import lancedb
import pyarrow as pa
from benchmarks.base import BaseBenchmarkTask
from benchmarks.tasks.registry import RegisterTask
from rich.progress import track
from torch import nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import BitImageProcessorFast
from utils.feature_extractor import extract_batch_features, infer_vector_dim
@@ -47,11 +46,11 @@ def _establish_eval_database(
dataloader: DataLoader for the training dataset.
"""
# Extract all features using the utility function
all_features = extract_batch_features(processor, model, dataloader, show_progress=True)
all_features = extract_batch_features(processor, model, dataloader)
# Store features to database
global_idx = 0
for batch in tqdm(dataloader, desc="Storing eval database"):
for batch in track(dataloader, description="Storing eval database"):
labels = batch["label"]
labels_list = labels.tolist()
batch_size = len(labels_list)
@@ -89,13 +88,13 @@ def _evaluate_recall(
A tuple of (correct_count, total_count).
"""
# Extract all features using the utility function
all_features = extract_batch_features(processor, model, dataloader, show_progress=True)
all_features = extract_batch_features(processor, model, dataloader)
correct = 0
total = 0
feature_idx = 0
for batch in tqdm(dataloader, desc=f"Evaluating Recall@{top_k}"):
for batch in track(dataloader, description=f"Evaluating Recall@{top_k}"):
labels = batch["label"]
labels_list = labels.tolist()