mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
refactor(ui): replace tqdm with rich for enhanced console output
This commit is contained in:
@@ -7,6 +7,10 @@ import lancedb
|
|||||||
from benchmarks.datasets import HuggingFaceDataset, LocalDataset
|
from benchmarks.datasets import HuggingFaceDataset, LocalDataset
|
||||||
from benchmarks.tasks import get_task
|
from benchmarks.tasks import get_task
|
||||||
from configs.models import BenchmarkConfig, DatasetSourceConfig
|
from configs.models import BenchmarkConfig, DatasetSourceConfig
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.table import Table
|
||||||
|
|
||||||
|
console = Console()
|
||||||
|
|
||||||
|
|
||||||
def create_dataset(config: DatasetSourceConfig) -> Any:
|
def create_dataset(config: DatasetSourceConfig) -> Any:
|
||||||
@@ -92,7 +96,9 @@ def _ensure_table(
|
|||||||
if table_name in existing_tables:
|
if table_name in existing_tables:
|
||||||
table = db.open_table(table_name)
|
table = db.open_table(table_name)
|
||||||
if table.schema != schema:
|
if table.schema != schema:
|
||||||
print(f"Table '{table_name}' schema mismatch, rebuilding.")
|
console.print(
|
||||||
|
f"[yellow]Table '{table_name}' schema mismatch, rebuilding.[/yellow]"
|
||||||
|
)
|
||||||
db.drop_table(table_name)
|
db.drop_table(table_name)
|
||||||
table = db.create_table(table_name, schema=schema)
|
table = db.create_table(table_name, schema=schema)
|
||||||
else:
|
else:
|
||||||
@@ -101,6 +107,29 @@ def _ensure_table(
|
|||||||
return table
|
return table
|
||||||
|
|
||||||
|
|
||||||
|
def _print_benchmark_info(
|
||||||
|
config: BenchmarkConfig, vector_dim: int, table_name: str, table_count: int
|
||||||
|
) -> None:
|
||||||
|
"""Print benchmark configuration info using Rich table.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Benchmark configuration.
|
||||||
|
vector_dim: Feature vector dimension.
|
||||||
|
table_name: Database table name.
|
||||||
|
table_count: Number of entries in the table.
|
||||||
|
"""
|
||||||
|
table = Table(title="Benchmark Configuration", show_header=False)
|
||||||
|
table.add_column("Key", style="cyan", no_wrap=True)
|
||||||
|
table.add_column("Value", style="magenta")
|
||||||
|
|
||||||
|
table.add_row("Dataset", f"{config.dataset.source_type} - {config.dataset.path}")
|
||||||
|
table.add_row("Model Output Dimension", str(vector_dim))
|
||||||
|
table.add_row("Table Name", table_name)
|
||||||
|
table.add_row("Table Entries", str(table_count))
|
||||||
|
|
||||||
|
console.print(table)
|
||||||
|
|
||||||
|
|
||||||
def run_benchmark(
|
def run_benchmark(
|
||||||
model: Any,
|
model: Any,
|
||||||
processor: Any,
|
processor: Any,
|
||||||
@@ -127,13 +156,10 @@ def run_benchmark(
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If benchmark is not enabled in config.
|
ValueError: If benchmark is not enabled in config.
|
||||||
"""
|
"""
|
||||||
if not config.enabled:
|
|
||||||
raise ValueError(
|
|
||||||
"Benchmark is not enabled. Set benchmark.enabled=true in config.yaml"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create dataset
|
# Create dataset
|
||||||
print(f"Loading dataset: {config.dataset.source_type} - {config.dataset.path}")
|
console.print(
|
||||||
|
f"[cyan]Loading dataset:[/cyan] {config.dataset.source_type} - {config.dataset.path}"
|
||||||
|
)
|
||||||
dataset = create_dataset(config.dataset)
|
dataset = create_dataset(config.dataset)
|
||||||
|
|
||||||
# Get train and test splits
|
# Get train and test splits
|
||||||
@@ -152,7 +178,7 @@ def run_benchmark(
|
|||||||
from utils.feature_extractor import infer_vector_dim
|
from utils.feature_extractor import infer_vector_dim
|
||||||
|
|
||||||
vector_dim = infer_vector_dim(processor, model, sample_image)
|
vector_dim = infer_vector_dim(processor, model, sample_image)
|
||||||
print(f"Model output dimension: {vector_dim}")
|
console.print(f"[cyan]Model output dimension:[/cyan] {vector_dim}")
|
||||||
|
|
||||||
# Ensure table exists with correct schema
|
# Ensure table exists with correct schema
|
||||||
table = _ensure_table(config, model_name, vector_dim)
|
table = _ensure_table(config, model_name, vector_dim)
|
||||||
@@ -161,26 +187,20 @@ def run_benchmark(
|
|||||||
# Check if database is already built
|
# Check if database is already built
|
||||||
table_count = table.count_rows()
|
table_count = table.count_rows()
|
||||||
if table_count > 0:
|
if table_count > 0:
|
||||||
print(
|
console.print(
|
||||||
f"Table '{table_name}' already has {table_count} entries, skipping database build."
|
f"[yellow]Table '{table_name}' already has {table_count} entries, skipping database build.[/yellow]"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Create and run benchmark task
|
# Create and run benchmark task
|
||||||
task = get_task(config.task.type, top_k=config.task.top_k)
|
task = get_task(config.task.type, top_k=config.task.top_k)
|
||||||
print(f"Building database with {len(train_dataset)} training samples...")
|
console.print(
|
||||||
|
f"[cyan]Building database[/cyan] with {len(train_dataset)} training samples..."
|
||||||
|
)
|
||||||
task.build_database(model, processor, train_dataset, table, config.batch_size)
|
task.build_database(model, processor, train_dataset, table, config.batch_size)
|
||||||
|
|
||||||
# Run evaluation
|
# Run evaluation (results with Rich table will be printed by the task)
|
||||||
task = get_task(config.task.type, top_k=config.task.top_k)
|
task = get_task(config.task.type, top_k=config.task.top_k)
|
||||||
print(f"Evaluating on {len(test_dataset)} test samples...")
|
console.print(f"[cyan]Evaluating[/cyan] on {len(test_dataset)} test samples...")
|
||||||
results = task.evaluate(model, processor, test_dataset, table, config.batch_size)
|
results = task.evaluate(model, processor, test_dataset, table, config.batch_size)
|
||||||
|
|
||||||
# Print results
|
|
||||||
print("\n=== Benchmark Results ===")
|
|
||||||
print(f"Dataset: {config.dataset.path}")
|
|
||||||
print(f"Task: {config.task.name}")
|
|
||||||
print(f"Top-K: {results['top_k']}")
|
|
||||||
print(f"Accuracy: {results['accuracy']:.4f}")
|
|
||||||
print(f"Correct: {results['correct']}/{results['total']}")
|
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|||||||
@@ -6,11 +6,10 @@ import lancedb
|
|||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
from benchmarks.base import BaseBenchmarkTask
|
from benchmarks.base import BaseBenchmarkTask
|
||||||
from benchmarks.tasks.registry import RegisterTask
|
from benchmarks.tasks.registry import RegisterTask
|
||||||
|
from rich.progress import track
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from tqdm.auto import tqdm
|
|
||||||
from transformers import BitImageProcessorFast
|
from transformers import BitImageProcessorFast
|
||||||
|
|
||||||
from utils.feature_extractor import extract_batch_features, infer_vector_dim
|
from utils.feature_extractor import extract_batch_features, infer_vector_dim
|
||||||
|
|
||||||
|
|
||||||
@@ -47,11 +46,11 @@ def _establish_eval_database(
|
|||||||
dataloader: DataLoader for the training dataset.
|
dataloader: DataLoader for the training dataset.
|
||||||
"""
|
"""
|
||||||
# Extract all features using the utility function
|
# Extract all features using the utility function
|
||||||
all_features = extract_batch_features(processor, model, dataloader, show_progress=True)
|
all_features = extract_batch_features(processor, model, dataloader)
|
||||||
|
|
||||||
# Store features to database
|
# Store features to database
|
||||||
global_idx = 0
|
global_idx = 0
|
||||||
for batch in tqdm(dataloader, desc="Storing eval database"):
|
for batch in track(dataloader, description="Storing eval database"):
|
||||||
labels = batch["label"]
|
labels = batch["label"]
|
||||||
labels_list = labels.tolist()
|
labels_list = labels.tolist()
|
||||||
batch_size = len(labels_list)
|
batch_size = len(labels_list)
|
||||||
@@ -89,13 +88,13 @@ def _evaluate_recall(
|
|||||||
A tuple of (correct_count, total_count).
|
A tuple of (correct_count, total_count).
|
||||||
"""
|
"""
|
||||||
# Extract all features using the utility function
|
# Extract all features using the utility function
|
||||||
all_features = extract_batch_features(processor, model, dataloader, show_progress=True)
|
all_features = extract_batch_features(processor, model, dataloader)
|
||||||
|
|
||||||
correct = 0
|
correct = 0
|
||||||
total = 0
|
total = 0
|
||||||
feature_idx = 0
|
feature_idx = 0
|
||||||
|
|
||||||
for batch in tqdm(dataloader, desc=f"Evaluating Recall@{top_k}"):
|
for batch in track(dataloader, description=f"Evaluating Recall@{top_k}"):
|
||||||
labels = batch["label"]
|
labels = batch["label"]
|
||||||
labels_list = labels.tolist()
|
labels_list = labels.tolist()
|
||||||
|
|
||||||
|
|||||||
@@ -21,13 +21,6 @@ def benchmark(
|
|||||||
config = cfg_manager.get()
|
config = cfg_manager.get()
|
||||||
benchmark_cfg = config.benchmark
|
benchmark_cfg = config.benchmark
|
||||||
|
|
||||||
if not benchmark_cfg.enabled:
|
|
||||||
typer.echo(
|
|
||||||
"Benchmark is not enabled. Set benchmark.enabled=true in config.yaml",
|
|
||||||
err=True,
|
|
||||||
)
|
|
||||||
raise typer.Exit(code=1)
|
|
||||||
|
|
||||||
device = get_device()
|
device = get_device()
|
||||||
|
|
||||||
model_cfg = config.model
|
model_cfg = config.model
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from compressors import HashCompressor, HashLoss
|
|||||||
from configs import cfg_manager
|
from configs import cfg_manager
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from tqdm.auto import tqdm
|
from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn
|
||||||
from transformers import AutoImageProcessor, AutoModel
|
from transformers import AutoImageProcessor, AutoModel
|
||||||
|
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
@@ -104,59 +104,71 @@ def train(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for epoch in range(start_epoch, epoch_size):
|
progress = Progress(
|
||||||
train_bar = tqdm(dataloader, desc=f"Epoch [{epoch + 1}/{epoch_size}]")
|
TextColumn("[progress.description]{task.description}"),
|
||||||
|
BarColumn(),
|
||||||
for batch in train_bar:
|
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
||||||
global_step += 1
|
TimeRemainingColumn(),
|
||||||
|
)
|
||||||
# ---- training step ----
|
with progress:
|
||||||
imgs = batch["img"]
|
for epoch in range(start_epoch, epoch_size):
|
||||||
|
task_id = progress.add_task(
|
||||||
# ---- teacher forward ----
|
f"Epoch [{epoch + 1}/{epoch_size}]", total=len(dataloader)
|
||||||
with torch.no_grad():
|
|
||||||
inputs = processor(imgs, return_tensors="pt").to(device)
|
|
||||||
|
|
||||||
teacher_tokens = dino(**inputs).last_hidden_state # [B,N,1024]
|
|
||||||
|
|
||||||
teacher_embed = teacher_tokens.mean(dim=1)
|
|
||||||
teacher_embed = F.normalize(teacher_embed, dim=-1) # [B,1024]
|
|
||||||
|
|
||||||
# ---- student forward ----
|
|
||||||
logits, hash_codes, bits = compressor(teacher_tokens)
|
|
||||||
|
|
||||||
# ---- generate positive mask ----
|
|
||||||
labels = batch["label"]
|
|
||||||
# positive_mask[i,j] = True if labels[i] == labels[j]
|
|
||||||
positive_mask = labels.unsqueeze(0) == labels.unsqueeze(1) # [B, B]
|
|
||||||
|
|
||||||
# ---- loss ----
|
|
||||||
total_loss, components = loss_fn(
|
|
||||||
logits=logits,
|
|
||||||
hash_codes=hash_codes,
|
|
||||||
teacher_embed=teacher_embed,
|
|
||||||
positive_mask=positive_mask,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# ---- backward ----
|
for batch in dataloader:
|
||||||
optimizer.zero_grad()
|
progress.update(task_id, advance=1)
|
||||||
total_loss.backward()
|
|
||||||
optimizer.step()
|
|
||||||
|
|
||||||
# ---- logging ----
|
global_step += 1
|
||||||
train_bar.set_postfix(
|
|
||||||
loss=f"{components['total']:.4f}",
|
|
||||||
cont=f"{components['contrastive']:.2f}",
|
|
||||||
distill=f"{components['distill']:.3f}",
|
|
||||||
quant=f"{components['quantization']:.2f}",
|
|
||||||
)
|
|
||||||
|
|
||||||
# ---- periodic save ----
|
# ---- training step ----
|
||||||
if global_step % save_every == 0:
|
imgs = batch["img"]
|
||||||
save_checkpoint(
|
|
||||||
compressor, optimizer, epoch, global_step, checkpoint_path
|
# ---- teacher forward ----
|
||||||
|
with torch.no_grad():
|
||||||
|
inputs = processor(imgs, return_tensors="pt").to(device)
|
||||||
|
|
||||||
|
teacher_tokens = dino(**inputs).last_hidden_state # [B,N,1024]
|
||||||
|
|
||||||
|
teacher_embed = teacher_tokens.mean(dim=1)
|
||||||
|
teacher_embed = F.normalize(teacher_embed, dim=-1) # [B,1024]
|
||||||
|
|
||||||
|
# ---- student forward ----
|
||||||
|
logits, hash_codes, bits = compressor(teacher_tokens)
|
||||||
|
|
||||||
|
# ---- generate positive mask ----
|
||||||
|
labels = batch["label"]
|
||||||
|
# positive_mask[i,j] = True if labels[i] == labels[j]
|
||||||
|
positive_mask = labels.unsqueeze(0) == labels.unsqueeze(1) # [B, B]
|
||||||
|
|
||||||
|
# ---- loss ----
|
||||||
|
total_loss, components = loss_fn(
|
||||||
|
logits=logits,
|
||||||
|
hash_codes=hash_codes,
|
||||||
|
teacher_embed=teacher_embed,
|
||||||
|
positive_mask=positive_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# ---- backward ----
|
||||||
|
optimizer.zero_grad()
|
||||||
|
total_loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
# ---- logging ----
|
||||||
|
progress.update(
|
||||||
|
task_id,
|
||||||
|
description=f"Epoch [{epoch + 1}/{epoch_size}] "
|
||||||
|
f"loss={components['total']:.4f} "
|
||||||
|
f"cont={components['contrastive']:.2f} "
|
||||||
|
f"distill={components['distill']:.3f}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# ---- periodic save ----
|
||||||
|
if global_step % save_every == 0:
|
||||||
|
save_checkpoint(
|
||||||
|
compressor, optimizer, epoch, global_step, checkpoint_path
|
||||||
|
)
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print("\n⚠️ Training interrupted, saving checkpoint...")
|
print("\n⚠️ Training interrupted, saving checkpoint...")
|
||||||
save_checkpoint(compressor, optimizer, epoch, global_step, checkpoint_path)
|
save_checkpoint(compressor, optimizer, epoch, global_step, checkpoint_path)
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ dataset:
|
|||||||
seed: 42
|
seed: 42
|
||||||
|
|
||||||
benchmark:
|
benchmark:
|
||||||
enabled: true
|
|
||||||
dataset:
|
dataset:
|
||||||
source_type: "huggingface"
|
source_type: "huggingface"
|
||||||
path: "uoft-cs/cifar10"
|
path: "uoft-cs/cifar10"
|
||||||
|
|||||||
@@ -124,7 +124,6 @@ class BenchmarkConfig(BaseModel):
|
|||||||
|
|
||||||
model_config = ConfigDict(extra="ignore")
|
model_config = ConfigDict(extra="ignore")
|
||||||
|
|
||||||
enabled: bool = Field(default=False, description="Enable benchmark evaluation")
|
|
||||||
dataset: DatasetSourceConfig = Field(default_factory=DatasetSourceConfig)
|
dataset: DatasetSourceConfig = Field(default_factory=DatasetSourceConfig)
|
||||||
task: BenchmarkTaskConfig = Field(default_factory=BenchmarkTaskConfig)
|
task: BenchmarkTaskConfig = Field(default_factory=BenchmarkTaskConfig)
|
||||||
batch_size: int = Field(default=64, gt=0, description="Batch size for DataLoader")
|
batch_size: int = Field(default=64, gt=0, description="Batch size for DataLoader")
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from pathlib import Path
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from PIL.Image import Resampling
|
from PIL.Image import Resampling
|
||||||
from tqdm.auto import tqdm
|
from rich.progress import track
|
||||||
|
|
||||||
|
|
||||||
class ImageSynthesizer:
|
class ImageSynthesizer:
|
||||||
@@ -287,7 +287,7 @@ class ImageSynthesizer:
|
|||||||
|
|
||||||
generated_files: list[Path] = []
|
generated_files: list[Path] = []
|
||||||
|
|
||||||
for i in tqdm(range(self.num_scenes), desc="Generating scenes"):
|
for i in track(range(self.num_scenes), description="Generating scenes"):
|
||||||
# Update seed for each scene
|
# Update seed for each scene
|
||||||
random.seed(self.seed + i)
|
random.seed(self.seed + i)
|
||||||
np.random.seed(self.seed + i)
|
np.random.seed(self.seed + i)
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from database import db_manager
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from PIL.PngImagePlugin import PngImageFile
|
from PIL.PngImagePlugin import PngImageFile
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from tqdm.auto import tqdm
|
from rich.progress import track
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoImageProcessor,
|
AutoImageProcessor,
|
||||||
AutoModel,
|
AutoModel,
|
||||||
@@ -93,7 +93,7 @@ class FeatureRetrieval:
|
|||||||
self.processor, self.model, images, batch_size=batch_size
|
self.processor, self.model, images, batch_size=batch_size
|
||||||
)
|
)
|
||||||
|
|
||||||
for i in tqdm(range(len(labels)), desc="Storing to database"):
|
for i in track(range(len(labels)), description="Storing to database"):
|
||||||
batch_label = labels[i] if label_map is None else label_map[labels[i]]
|
batch_label = labels[i] if label_map is None else label_map[labels[i]]
|
||||||
|
|
||||||
# Store to database
|
# Store to database
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from PIL import Image
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from transformers import BitImageProcessorFast
|
from transformers import BitImageProcessorFast
|
||||||
from tqdm.auto import tqdm
|
from rich.progress import track
|
||||||
|
|
||||||
|
|
||||||
def _extract_features_from_output(output: Any) -> torch.Tensor:
|
def _extract_features_from_output(output: Any) -> torch.Tensor:
|
||||||
@@ -86,7 +86,6 @@ def extract_batch_features(
|
|||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
images: Union[List[Any], Any],
|
images: Union[List[Any], Any],
|
||||||
batch_size: int = 32,
|
batch_size: int = 32,
|
||||||
show_progress: bool = False,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Extract features from a batch of images.
|
"""Extract features from a batch of images.
|
||||||
|
|
||||||
@@ -95,7 +94,6 @@ def extract_batch_features(
|
|||||||
model: Feature extraction model.
|
model: Feature extraction model.
|
||||||
images: List of images, DataLoader, or other iterable.
|
images: List of images, DataLoader, or other iterable.
|
||||||
batch_size: Batch size for processing.
|
batch_size: Batch size for processing.
|
||||||
show_progress: Whether to show progress bar.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tensor of shape [batch_size, feature_dim].
|
Tensor of shape [batch_size, feature_dim].
|
||||||
@@ -106,8 +104,7 @@ def extract_batch_features(
|
|||||||
# Handle DataLoader input
|
# Handle DataLoader input
|
||||||
if isinstance(images, DataLoader):
|
if isinstance(images, DataLoader):
|
||||||
all_features = []
|
all_features = []
|
||||||
iterator = tqdm(images, desc="Extracting features") if show_progress else images
|
for batch in track(images, description="Extracting features"):
|
||||||
for batch in iterator:
|
|
||||||
imgs = batch["img"] if isinstance(batch, dict) else batch[0]
|
imgs = batch["img"] if isinstance(batch, dict) else batch[0]
|
||||||
inputs = processor(images=imgs, return_tensors="pt")
|
inputs = processor(images=imgs, return_tensors="pt")
|
||||||
inputs.to(device)
|
inputs.to(device)
|
||||||
@@ -118,8 +115,7 @@ def extract_batch_features(
|
|||||||
|
|
||||||
# Handle list of images
|
# Handle list of images
|
||||||
all_features = []
|
all_features = []
|
||||||
iterator = tqdm(range(0, len(images), batch_size), desc="Extracting features") if show_progress else range(0, len(images), batch_size)
|
for i in track(range(0, len(images), batch_size), description="Extracting features"):
|
||||||
for i in iterator:
|
|
||||||
batch_imgs = images[i : i + batch_size]
|
batch_imgs = images[i : i + batch_size]
|
||||||
inputs = processor(images=batch_imgs, return_tensors="pt")
|
inputs = processor(images=batch_imgs, return_tensors="pt")
|
||||||
inputs.to(device)
|
inputs.to(device)
|
||||||
|
|||||||
Reference in New Issue
Block a user