mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 20:35:31 +08:00
refactor(ui): replace tqdm with rich for enhanced console output
This commit is contained in:
@@ -7,7 +7,7 @@ from PIL import Image
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import BitImageProcessorFast
|
||||
from tqdm.auto import tqdm
|
||||
from rich.progress import track
|
||||
|
||||
|
||||
def _extract_features_from_output(output: Any) -> torch.Tensor:
|
||||
@@ -86,7 +86,6 @@ def extract_batch_features(
|
||||
model: nn.Module,
|
||||
images: Union[List[Any], Any],
|
||||
batch_size: int = 32,
|
||||
show_progress: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Extract features from a batch of images.
|
||||
|
||||
@@ -95,7 +94,6 @@ def extract_batch_features(
|
||||
model: Feature extraction model.
|
||||
images: List of images, DataLoader, or other iterable.
|
||||
batch_size: Batch size for processing.
|
||||
show_progress: Whether to show progress bar.
|
||||
|
||||
Returns:
|
||||
Tensor of shape [batch_size, feature_dim].
|
||||
@@ -106,8 +104,7 @@ def extract_batch_features(
|
||||
# Handle DataLoader input
|
||||
if isinstance(images, DataLoader):
|
||||
all_features = []
|
||||
iterator = tqdm(images, desc="Extracting features") if show_progress else images
|
||||
for batch in iterator:
|
||||
for batch in track(images, description="Extracting features"):
|
||||
imgs = batch["img"] if isinstance(batch, dict) else batch[0]
|
||||
inputs = processor(images=imgs, return_tensors="pt")
|
||||
inputs.to(device)
|
||||
@@ -118,8 +115,7 @@ def extract_batch_features(
|
||||
|
||||
# Handle list of images
|
||||
all_features = []
|
||||
iterator = tqdm(range(0, len(images), batch_size), desc="Extracting features") if show_progress else range(0, len(images), batch_size)
|
||||
for i in iterator:
|
||||
for i in track(range(0, len(images), batch_size), description="Extracting features"):
|
||||
batch_imgs = images[i : i + batch_size]
|
||||
inputs = processor(images=batch_imgs, return_tensors="pt")
|
||||
inputs.to(device)
|
||||
|
||||
Reference in New Issue
Block a user