refactor(ui): replace tqdm with rich for enhanced console output

This commit is contained in:
2026-03-06 16:20:38 +08:00
parent 4a6918ce56
commit e832f9d656
9 changed files with 113 additions and 95 deletions

View File

@@ -7,7 +7,7 @@ from PIL import Image
from torch import nn
from torch.utils.data import DataLoader
from transformers import BitImageProcessorFast
from tqdm.auto import tqdm
from rich.progress import track
def _extract_features_from_output(output: Any) -> torch.Tensor:
@@ -86,7 +86,6 @@ def extract_batch_features(
model: nn.Module,
images: Union[List[Any], Any],
batch_size: int = 32,
show_progress: bool = False,
) -> torch.Tensor:
"""Extract features from a batch of images.
@@ -95,7 +94,6 @@ def extract_batch_features(
model: Feature extraction model.
images: List of images, DataLoader, or other iterable.
batch_size: Batch size for processing.
show_progress: Whether to show progress bar.
Returns:
Tensor of shape [batch_size, feature_dim].
@@ -106,8 +104,7 @@ def extract_batch_features(
# Handle DataLoader input
if isinstance(images, DataLoader):
all_features = []
iterator = tqdm(images, desc="Extracting features") if show_progress else images
for batch in iterator:
for batch in track(images, description="Extracting features"):
imgs = batch["img"] if isinstance(batch, dict) else batch[0]
inputs = processor(images=imgs, return_tensors="pt")
inputs.to(device)
@@ -118,8 +115,7 @@ def extract_batch_features(
# Handle list of images
all_features = []
iterator = tqdm(range(0, len(images), batch_size), desc="Extracting features") if show_progress else range(0, len(images), batch_size)
for i in iterator:
for i in track(range(0, len(images), batch_size), description="Extracting features"):
batch_imgs = images[i : i + batch_size]
inputs = processor(images=batch_imgs, return_tensors="pt")
inputs.to(device)