mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
feat(configs): implement Pydantic configuration system with type safety
This commit is contained in:
@@ -0,0 +1,28 @@
|
||||
from .models import (
|
||||
ModelConfig,
|
||||
VisualizationConfig,
|
||||
OutputConfig,
|
||||
FeatureCompressorConfig,
|
||||
PoolingType,
|
||||
)
|
||||
from .loader import load_yaml, save_yaml, ConfigError
|
||||
from .config import (
|
||||
ConfigManager,
|
||||
cfg_manager,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Models
|
||||
"ModelConfig",
|
||||
"VisualizationConfig",
|
||||
"OutputConfig",
|
||||
"FeatureCompressorConfig",
|
||||
"PoolingType",
|
||||
# Loader
|
||||
"load_yaml",
|
||||
"save_yaml",
|
||||
"ConfigError",
|
||||
# Manager
|
||||
"ConfigManager",
|
||||
"cfg_manager",
|
||||
]
|
||||
|
||||
@@ -1,20 +1,136 @@
|
||||
from enum import Enum
|
||||
"""Configuration manager for multiple configurations."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from typing import Dict, Optional
|
||||
|
||||
import yaml
|
||||
from .loader import load_yaml, save_yaml
|
||||
from .models import FeatureCompressorConfig
|
||||
|
||||
|
||||
class Config(Enum):
|
||||
FEATURE_COMPRESSOR = "feature_compressor.yaml"
|
||||
class ConfigManager:
|
||||
"""Singleton configuration manager supporting multiple configs."""
|
||||
|
||||
_instance: Optional["ConfigManager"] = None
|
||||
_configs: Dict[str, FeatureCompressorConfig] = {}
|
||||
|
||||
def __new__(cls) -> "ConfigManager":
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
self.config_dir = Path(__file__).parent
|
||||
|
||||
def load_config(
|
||||
self, config_name: str = "feature_compressor"
|
||||
) -> FeatureCompressorConfig:
|
||||
"""Load configuration from YAML file.
|
||||
|
||||
Args:
|
||||
config_name: Name of config file without extension
|
||||
|
||||
Returns:
|
||||
Loaded and validated FeatureCompressorConfig instance
|
||||
"""
|
||||
config_path = self.config_dir / f"{config_name}.yaml"
|
||||
config = load_yaml(config_path, FeatureCompressorConfig)
|
||||
self._configs[config_name] = config
|
||||
return config
|
||||
|
||||
def load_all_configs(self) -> Dict[str, FeatureCompressorConfig]:
|
||||
"""Load all YAML configuration files from config directory.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping config names to FeatureCompressorConfig instances
|
||||
"""
|
||||
config_files = list(self.config_dir.glob("*.yaml"))
|
||||
loaded_configs = {}
|
||||
|
||||
for config_path in config_files:
|
||||
config_name = config_path.stem
|
||||
if config_name.startswith("_"):
|
||||
continue # Skip private configs
|
||||
config = load_yaml(config_path, FeatureCompressorConfig)
|
||||
loaded_configs[config_name] = config
|
||||
|
||||
self._configs.update(loaded_configs)
|
||||
return loaded_configs
|
||||
|
||||
def get_config(
|
||||
self, config_name: str = "feature_compressor"
|
||||
) -> FeatureCompressorConfig:
|
||||
"""Get loaded configuration by name.
|
||||
|
||||
Args:
|
||||
config_name: Name of configuration to retrieve
|
||||
|
||||
Returns:
|
||||
FeatureCompressorConfig instance
|
||||
|
||||
Raises:
|
||||
ValueError: If configuration not loaded
|
||||
"""
|
||||
if config_name not in self._configs:
|
||||
raise ValueError(
|
||||
f"Configuration '{config_name}' not loaded. "
|
||||
f"Call load_config('{config_name}') or load_all_configs() first."
|
||||
)
|
||||
return self._configs[config_name]
|
||||
|
||||
def get_or_load_config(
|
||||
self, config_name: str = "feature_compressor"
|
||||
) -> FeatureCompressorConfig:
|
||||
"""Get loaded configuration by name or load it if not loaded.
|
||||
|
||||
Args:
|
||||
config_name: Name of configuration to retrieve
|
||||
|
||||
Returns:
|
||||
FeatureCompressorConfig instance
|
||||
|
||||
Raises:
|
||||
ValueError: If configuration not loaded
|
||||
"""
|
||||
if config_name not in self._configs:
|
||||
return self.load_config(config_name)
|
||||
return self._configs[config_name]
|
||||
|
||||
def list_configs(self) -> list[str]:
|
||||
"""List names of all currently loaded configurations.
|
||||
|
||||
Returns:
|
||||
List of configuration names
|
||||
"""
|
||||
return list(self._configs.keys())
|
||||
|
||||
def save_config(
|
||||
self,
|
||||
config_name: str = "feature_compressor",
|
||||
config: Optional[FeatureCompressorConfig] = None,
|
||||
path: Optional[Path] = None,
|
||||
) -> None:
|
||||
"""Save configuration to YAML file.
|
||||
|
||||
Args:
|
||||
config_name: Name of config file without extension
|
||||
config: Configuration to save. If None, saves currently loaded config for that name.
|
||||
path: Optional custom path to save to. If None, saves to config_dir.
|
||||
|
||||
Raises:
|
||||
ValueError: If no configuration loaded for the given name and config is None
|
||||
"""
|
||||
if config is None:
|
||||
if config_name not in self._configs:
|
||||
raise ValueError(
|
||||
f"No configuration loaded for '{config_name}'. "
|
||||
f"Cannot save without providing config parameter."
|
||||
)
|
||||
config = self._configs[config_name]
|
||||
|
||||
save_path = path if path else self.config_dir / f"{config_name}.yaml"
|
||||
save_yaml(save_path, config)
|
||||
self._configs[config_name] = config
|
||||
|
||||
|
||||
def get_config_dir() -> Path:
|
||||
return Path(__file__).parent
|
||||
|
||||
|
||||
def get_default_config(config_type: Config) -> Dict[Unknown, Unknown]:
|
||||
config_path = get_config_dir() / config_type.value
|
||||
|
||||
with open(config_path) as f:
|
||||
return yaml.safe_load(f)
|
||||
# Global singleton instance
|
||||
cfg_manager = ConfigManager()
|
||||
|
||||
69
mini-nav/configs/loader.py
Normal file
69
mini-nav/configs/loader.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""Generic YAML loader with Pydantic validation."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TypeVar, Type
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
class ConfigError(Exception):
|
||||
"""Configuration loading and validation error."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def load_yaml(path: Path, model_class: Type[T]) -> T:
|
||||
"""Load and validate YAML configuration file.
|
||||
|
||||
Args:
|
||||
path: Path to YAML file (str or Path object)
|
||||
model_class: Pydantic model class to validate against
|
||||
|
||||
Returns:
|
||||
Validated model instance of type T
|
||||
|
||||
Raises:
|
||||
ConfigError: On file not found, YAML parsing error, or validation failure
|
||||
"""
|
||||
# Coerce str to Path if needed
|
||||
if isinstance(path, str):
|
||||
path = Path(path)
|
||||
|
||||
try:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
except FileNotFoundError as e:
|
||||
raise ConfigError(f"Configuration file not found: {path}") from e
|
||||
except yaml.YAMLError as e:
|
||||
raise ConfigError(f"YAML parsing error: {e}") from e
|
||||
|
||||
try:
|
||||
return model_class.model_validate(data)
|
||||
except ValidationError as e:
|
||||
raise ConfigError(f"Configuration validation failed: {e}") from e
|
||||
|
||||
|
||||
def save_yaml(path: Path, model: BaseModel) -> None:
|
||||
"""Save Pydantic model to YAML file.
|
||||
|
||||
Args:
|
||||
path: Path to YAML file (str or Path object)
|
||||
model: Pydantic model instance to save
|
||||
|
||||
Raises:
|
||||
ConfigError: On file write error
|
||||
"""
|
||||
# Coerce str to Path if needed
|
||||
if isinstance(path, str):
|
||||
path = Path(path)
|
||||
|
||||
try:
|
||||
data = model.model_dump(exclude_unset=True)
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
yaml.dump(data, f, default_flow_style=False, allow_unicode=True)
|
||||
except Exception as e:
|
||||
raise ConfigError(f"Failed to save configuration to {path}: {e}") from e
|
||||
77
mini-nav/configs/models.py
Normal file
77
mini-nav/configs/models.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""Pydantic data models for feature compressor configuration."""
|
||||
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
|
||||
class PoolingType(str, Enum):
|
||||
"""Enum for pooling types."""
|
||||
|
||||
ATTENTION = "attention"
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
"""Configuration for the vision model and compression."""
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
name: str = "facebook/dinov2-large"
|
||||
compression_dim: int = Field(
|
||||
default=256, gt=0, description="Output feature dimension"
|
||||
)
|
||||
pooling_type: PoolingType = PoolingType.ATTENTION
|
||||
top_k_ratio: float = Field(
|
||||
default=0.5, ge=0, le=1, description="Ratio of tokens to keep"
|
||||
)
|
||||
hidden_ratio: float = Field(
|
||||
default=2.0, gt=0, description="MLP hidden dim as multiple of compression_dim"
|
||||
)
|
||||
dropout_rate: float = Field(
|
||||
default=0.1, ge=0, le=1, description="Dropout probability"
|
||||
)
|
||||
use_residual: bool = True
|
||||
device: str = "auto"
|
||||
|
||||
|
||||
class VisualizationConfig(BaseModel):
|
||||
"""Configuration for visualization settings."""
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
plot_theme: str = "plotly_white"
|
||||
color_scale: str = "viridis"
|
||||
point_size: int = Field(default=8, gt=0)
|
||||
fig_width: int = Field(default=900, gt=0)
|
||||
fig_height: int = Field(default=600, gt=0)
|
||||
|
||||
|
||||
class OutputConfig(BaseModel):
|
||||
"""Configuration for output settings."""
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
directory: Path = Path(__file__).parent.parent.parent / "outputs"
|
||||
html_self_contained: bool = True
|
||||
png_scale: int = Field(default=2, gt=0)
|
||||
|
||||
@field_validator("directory", mode="after")
|
||||
def convert_to_absolute(cls, v: Path) -> Path:
|
||||
"""
|
||||
Converts the path to an absolute path relative to the current working directory.
|
||||
This works even if the path doesn't exist on disk.
|
||||
"""
|
||||
if v.is_absolute():
|
||||
return v
|
||||
return Path(__file__).parent.parent.parent / v
|
||||
|
||||
|
||||
class FeatureCompressorConfig(BaseModel):
|
||||
"""Root configuration for the feature compressor."""
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
model: ModelConfig
|
||||
visualization: VisualizationConfig
|
||||
output: OutputConfig
|
||||
Reference in New Issue
Block a user