mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
refactor(benchmarks): modularize benchmark system with config-driven execution
This commit is contained in:
@@ -2,10 +2,10 @@ model:
|
||||
name: "facebook/dinov2-large"
|
||||
compression_dim: 512
|
||||
device: "auto" # auto-detect GPU
|
||||
sam_model: "facebook/sam2.1-hiera-large" # SAM model name
|
||||
sam_min_mask_area: 100 # Minimum mask area threshold
|
||||
sam_max_masks: 10 # Maximum number of masks to keep
|
||||
compressor_path: null # Path to trained HashCompressor weights (optional)
|
||||
sam_model: "facebook/sam2.1-hiera-large" # SAM model name
|
||||
sam_min_mask_area: 100 # Minimum mask area threshold
|
||||
sam_max_masks: 10 # Maximum number of masks to keep
|
||||
compressor_path: null # Path to trained HashCompressor weights (optional)
|
||||
|
||||
output:
|
||||
directory: "./outputs"
|
||||
@@ -19,3 +19,17 @@ dataset:
|
||||
rotation_range: [-30, 30]
|
||||
overlap_threshold: 0.3
|
||||
seed: 42
|
||||
|
||||
benchmark:
|
||||
enabled: true
|
||||
dataset:
|
||||
source_type: "huggingface"
|
||||
path: "uoft-cs/cifar10"
|
||||
img_column: "img"
|
||||
label_column: "label"
|
||||
task:
|
||||
name: "recall_at_k"
|
||||
type: "retrieval"
|
||||
top_k: 10
|
||||
batch_size: 64
|
||||
model_table_prefix: "benchmark"
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Pydantic data models for feature compressor configuration."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
@@ -98,6 +98,41 @@ class DatasetConfig(BaseModel):
|
||||
return v
|
||||
|
||||
|
||||
class DatasetSourceConfig(BaseModel):
|
||||
"""Configuration for benchmark dataset source."""
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
source_type: Literal["huggingface", "local"] = "huggingface"
|
||||
path: str = Field(default="", description="HuggingFace dataset ID or local path")
|
||||
img_column: str = Field(default="img", description="Image column name")
|
||||
label_column: str = Field(default="label", description="Label column name")
|
||||
|
||||
|
||||
class BenchmarkTaskConfig(BaseModel):
|
||||
"""Configuration for benchmark task."""
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
name: str = Field(default="recall_at_k", description="Task name")
|
||||
type: str = Field(default="retrieval", description="Task type")
|
||||
top_k: int = Field(default=10, gt=0, description="Top K for recall evaluation")
|
||||
|
||||
|
||||
class BenchmarkConfig(BaseModel):
|
||||
"""Configuration for benchmark evaluation."""
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
enabled: bool = Field(default=False, description="Enable benchmark evaluation")
|
||||
dataset: DatasetSourceConfig = Field(default_factory=DatasetSourceConfig)
|
||||
task: BenchmarkTaskConfig = Field(default_factory=BenchmarkTaskConfig)
|
||||
batch_size: int = Field(default=64, gt=0, description="Batch size for DataLoader")
|
||||
model_table_prefix: str = Field(
|
||||
default="benchmark", description="Prefix for LanceDB table names"
|
||||
)
|
||||
|
||||
|
||||
class Config(BaseModel):
|
||||
"""Root configuration for the feature compressor."""
|
||||
|
||||
@@ -106,3 +141,4 @@ class Config(BaseModel):
|
||||
model: ModelConfig = Field(default_factory=ModelConfig)
|
||||
output: OutputConfig = Field(default_factory=OutputConfig)
|
||||
dataset: DatasetConfig = Field(default_factory=DatasetConfig)
|
||||
benchmark: BenchmarkConfig = Field(default_factory=BenchmarkConfig)
|
||||
|
||||
Reference in New Issue
Block a user