refactor(config): simplify config manager to single unified config

This commit is contained in:
2026-02-05 15:47:05 +08:00
parent 3d90e75441
commit 7ce97c1965
8 changed files with 50 additions and 110 deletions

View File

@@ -1,3 +1,3 @@
from .database import DatabaseManager, db_manager, db_schema from database import DatabaseManager, db_manager, db_schema
__all__ = ["DatabaseManager", "db_manager", "db_schema"] __all__ = ["DatabaseManager", "db_manager", "db_schema"]

View File

@@ -8,7 +8,6 @@ from .models import (
from .loader import load_yaml, save_yaml, ConfigError from .loader import load_yaml, save_yaml, ConfigError
from .config import ( from .config import (
ConfigManager, ConfigManager,
ConfigType,
cfg_manager, cfg_manager,
) )
@@ -25,6 +24,5 @@ __all__ = [
"ConfigError", "ConfigError",
# Manager # Manager
"ConfigManager", "ConfigManager",
"ConfigType",
"cfg_manager", "cfg_manager",
] ]

View File

@@ -1,134 +1,80 @@
"""Configuration manager for multiple configurations.""" """Configuration manager for unified config."""
from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Dict, Optional from typing import Optional
from .loader import load_yaml, save_yaml from .loader import load_yaml, save_yaml
from .models import FeatureCompressorConfig from .models import FeatureCompressorConfig
class ConfigType(str, Enum):
FeatureCompressor = "feature_compressor"
class ConfigManager: class ConfigManager:
"""Singleton configuration manager supporting multiple configs.""" """Singleton configuration manager for unified config."""
_instance: Optional["ConfigManager"] = None _instance: Optional["ConfigManager"] = None
_configs: Dict[str, FeatureCompressorConfig] = {} _config: Optional[FeatureCompressorConfig] = None
def __new__(cls) -> "ConfigManager": def __new__(cls) -> "ConfigManager":
"""Singleton pattern - ensure only one instance exists."""
if cls._instance is None: if cls._instance is None:
cls._instance = super().__new__(cls) cls._instance = super().__new__(cls)
return cls._instance return cls._instance
def __init__(self): def __init__(self):
"""Initialize config manager with config directory and path."""
self.config_dir = Path(__file__).parent self.config_dir = Path(__file__).parent
self.config_path = self.config_dir / "config.yaml"
def load_config(self, config_name: ConfigType) -> FeatureCompressorConfig: def load(self) -> FeatureCompressorConfig:
"""Load configuration from YAML file. """Load configuration from config.yaml file.
Args:
config_name: Name of config file without extension
Returns: Returns:
Loaded and validated FeatureCompressorConfig instance Loaded and validated FeatureCompressorConfig instance
""" """
config_path = self.config_dir / f"{config_name}.yaml" config = load_yaml(self.config_path, FeatureCompressorConfig)
config = load_yaml(config_path, FeatureCompressorConfig) self._config = config
self._configs[config_name] = config
return config return config
def load_all_configs(self) -> Dict[str, FeatureCompressorConfig]: def get(self) -> FeatureCompressorConfig:
"""Load all YAML configuration files from config directory. """Get loaded configuration, auto-loading if not already loaded.
Returns:
Dictionary mapping config names to FeatureCompressorConfig instances
"""
config_files = list(self.config_dir.glob("*.yaml"))
loaded_configs = {}
for config_path in config_files:
config_name = config_path.stem
if config_name.startswith("_"):
continue # Skip private configs
config = load_yaml(config_path, FeatureCompressorConfig)
loaded_configs[config_name] = config
self._configs.update(loaded_configs)
return loaded_configs
def get_config(self, config_name: ConfigType) -> FeatureCompressorConfig:
"""Get loaded configuration by name.
Args:
config_name: Name of configuration to retrieve
Returns: Returns:
FeatureCompressorConfig instance FeatureCompressorConfig instance
Raises: Note:
ValueError: If configuration not loaded Automatically loads config if not already loaded
""" """
if config_name not in self._configs: # Auto-load if config not yet loaded
raise ValueError( if self._config is None:
f"Configuration '{config_name}' not loaded. " return self.load()
f"Call load_config('{config_name}') or load_all_configs() first." return self._config
)
return self._configs[config_name]
def get_or_load_config(self, config_name: ConfigType) -> FeatureCompressorConfig: def save(
"""Get loaded configuration by name or load it if not loaded.
Args:
config_name: Name of configuration to retrieve
Returns:
FeatureCompressorConfig instance
Raises:
ValueError: If configuration not loaded
"""
if config_name not in self._configs:
return self.load_config(config_name)
return self._configs[config_name]
def list_configs(self) -> list[str]:
"""List names of all currently loaded configurations.
Returns:
List of configuration names
"""
return list(self._configs.keys())
def save_config(
self, self,
config_name: ConfigType,
config: Optional[FeatureCompressorConfig] = None, config: Optional[FeatureCompressorConfig] = None,
path: Optional[Path] = None, path: Optional[Path] = None,
) -> None: ) -> None:
"""Save configuration to YAML file. """Save configuration to YAML file.
Args: Args:
config_name: Name of config file without extension config: Configuration to save. If None, saves currently loaded config.
config: Configuration to save. If None, saves currently loaded config for that name. path: Optional custom path. If None, saves to default config.yaml.
path: Optional custom path to save to. If None, saves to config_dir.
Raises: Raises:
ValueError: If no configuration loaded for the given name and config is None ValueError: If no configuration loaded and config is None
""" """
# Use provided config or fall back to loaded config
if config is None: if config is None:
if config_name not in self._configs: if self._config is None:
raise ValueError( raise ValueError(
f"No configuration loaded for '{config_name}'. " "No configuration loaded. "
f"Cannot save without providing config parameter." "Cannot save without providing config parameter."
) )
config = self._configs[config_name] config = self._config
save_path = path if path else self.config_dir / f"{config_name}.yaml" # Save to custom path or default config.yaml
save_path = path if path else self.config_path
save_yaml(save_path, config) save_yaml(save_path, config)
self._configs[config_name] = config self._config = config
# Global singleton instance # Global singleton instance

View File

@@ -1,8 +1,7 @@
from pathlib import Path
from typing import Optional from typing import Optional
import lancedb import lancedb
import pyarrow as pa import pyarrow as pa
from configs import ConfigType, cfg_manager from configs import cfg_manager
db_schema = pa.schema( db_schema = pa.schema(
[ [
@@ -27,7 +26,7 @@ class DatabaseManager:
def __init__(self): def __init__(self):
# 获取数据库位置 # 获取数据库位置
config = cfg_manager.get_or_load_config(ConfigType.FeatureCompressor) config = cfg_manager.get()
db_path = config.output.directory / "database" db_path = config.output.directory / "database"
# 初始化数据库与表格 # 初始化数据库与表格

View File

@@ -62,7 +62,7 @@ class DINOv2FeatureExtractor:
FeatureCompressorConfig instance FeatureCompressorConfig instance
""" """
if config_path is None: if config_path is None:
return cfg_manager.get_or_load_config("feature_compressor") return cfg_manager.get()
else: else:
return load_yaml(Path(config_path), FeatureCompressorConfig) return load_yaml(Path(config_path), FeatureCompressorConfig)

View File

@@ -43,7 +43,7 @@ class FeatureVisualizer:
Configuration Pydantic model Configuration Pydantic model
""" """
if config_path is None: if config_path is None:
return cfg_manager.get_or_load_config("feature_compressor") return cfg_manager.get()
else: else:
return load_yaml(Path(config_path), FeatureCompressorConfig) return load_yaml(Path(config_path), FeatureCompressorConfig)

View File

@@ -141,10 +141,8 @@ class TestYamlLoader:
"""Test suite for YAML loading and saving.""" """Test suite for YAML loading and saving."""
def test_load_existing_yaml(self): def test_load_existing_yaml(self):
"""Load feature_compressor.yaml and verify values.""" """Load config.yaml and verify values."""
config_path = ( config_path = Path(__file__).parent.parent / "configs" / "config.yaml"
Path(__file__).parent.parent / "configs" / "feature_compressor.yaml"
)
config = load_yaml(config_path, FeatureCompressorConfig) config = load_yaml(config_path, FeatureCompressorConfig)
# Verify model config # Verify model config
@@ -185,7 +183,7 @@ class TestYamlLoader:
def test_save_yaml_roundtrip(self): def test_save_yaml_roundtrip(self):
"""Create config, save to temp, verify file exists with content.""" """Create config, save to temp, verify file exists with content."""
original = cfg_manager.load_config("feature_compressor") original = cfg_manager.load()
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
temp_path = Path(f.name) temp_path = Path(f.name)
@@ -220,23 +218,22 @@ class TestConfigManager:
assert manager1 is manager2 assert manager1 is manager2
def test_load_config(self): def test_load_config(self):
"""Test loading feature_compressor config.""" """Test loading default config."""
config = cfg_manager.load_config("feature_compressor") config = cfg_manager.load()
assert config is not None assert config is not None
assert config.model.compression_dim == 256 assert config.model.compression_dim == 256
assert config.visualization.point_size == 8 assert config.visualization.point_size == 8
def test_get_config_not_loaded(self): def test_get_without_load(self):
"""Test that get_config() raises error for unloaded config.""" """Test that get() auto-loads config if not loaded."""
with pytest.raises(ValueError, match="not loaded"): # Reset the singleton's cached config
cfg_manager.get_config("nonexistent_config") cfg_manager._config = None
def test_list_configs(self): # get() should auto-load
"""Test listing all loaded configurations.""" config = cfg_manager.get()
cfg_manager.load_config("feature_compressor") assert config is not None
configs = cfg_manager.list_configs() assert config.model.compression_dim == 256
assert "feature_compressor" in configs
def test_save_config(self): def test_save_config(self):
"""Test saving configuration to file.""" """Test saving configuration to file."""
@@ -250,7 +247,7 @@ class TestConfigManager:
temp_path = Path(f.name) temp_path = Path(f.name)
try: try:
cfg_manager.save_config("test_config", config, path=temp_path) cfg_manager.save(config, path=temp_path)
loaded_config = load_yaml(temp_path, FeatureCompressorConfig) loaded_config = load_yaml(temp_path, FeatureCompressorConfig)
assert loaded_config.model.compression_dim == 512 assert loaded_config.model.compression_dim == 512