refactor(benchmarks): modularize benchmark system with config-driven execution

This commit is contained in:
2026-03-02 16:00:36 +08:00
parent a7b01cb49e
commit a16b376dd7
14 changed files with 779 additions and 180 deletions

View File

@@ -16,9 +16,51 @@ if __name__ == "__main__":
epoch_size=10, batch_size=64, lr=1e-4, checkpoint_path="hash_checkpoint.pt"
)
elif args.action == "benchmark":
from benchmarks import evaluate
from typing import cast
evaluate("Dinov2", "CIFAR-10", "Recall@10")
import torch
from benchmarks import run_benchmark
from compressors import DinoCompressor
from configs import cfg_manager
from transformers import AutoImageProcessor, BitImageProcessorFast
from utils import get_device
config = cfg_manager.get()
benchmark_cfg = config.benchmark
if not benchmark_cfg.enabled:
print("Benchmark is not enabled. Set benchmark.enabled=true in config.yaml")
exit(1)
device = get_device()
# Load model and processor based on config
model_cfg = config.model
processor = cast(
BitImageProcessorFast,
AutoImageProcessor.from_pretrained(model_cfg.name, device_map=device),
)
# Load compressor weights if specified in model config
model = DinoCompressor().to(device)
if model_cfg.compressor_path is not None:
from compressors import HashCompressor
compressor = HashCompressor(
input_dim=model_cfg.compression_dim,
output_dim=model_cfg.compression_dim,
)
compressor.load_state_dict(torch.load(model_cfg.compressor_path))
# Wrap with compressor if path is specified
model.compressor = compressor
# Run benchmark
run_benchmark(
model=model,
processor=processor,
config=benchmark_cfg,
model_name="dinov2",
)
elif args.action == "visualize":
from visualizer import app