refactor(config): simplify config manager to single unified config

This commit is contained in:
2026-02-05 15:47:05 +08:00
parent 3d90e75441
commit 7ce97c1965
8 changed files with 50 additions and 110 deletions

View File

@@ -1,3 +1,3 @@
from .database import DatabaseManager, db_manager, db_schema
from database import DatabaseManager, db_manager, db_schema
__all__ = ["DatabaseManager", "db_manager", "db_schema"]

View File

@@ -8,7 +8,6 @@ from .models import (
from .loader import load_yaml, save_yaml, ConfigError
from .config import (
ConfigManager,
ConfigType,
cfg_manager,
)
@@ -25,6 +24,5 @@ __all__ = [
"ConfigError",
# Manager
"ConfigManager",
"ConfigType",
"cfg_manager",
]

View File

@@ -1,134 +1,80 @@
"""Configuration manager for multiple configurations."""
"""Configuration manager for unified config."""
from enum import Enum
from pathlib import Path
from typing import Dict, Optional
from typing import Optional
from .loader import load_yaml, save_yaml
from .models import FeatureCompressorConfig
class ConfigType(str, Enum):
FeatureCompressor = "feature_compressor"
class ConfigManager:
"""Singleton configuration manager supporting multiple configs."""
"""Singleton configuration manager for unified config."""
_instance: Optional["ConfigManager"] = None
_configs: Dict[str, FeatureCompressorConfig] = {}
_config: Optional[FeatureCompressorConfig] = None
def __new__(cls) -> "ConfigManager":
"""Singleton pattern - ensure only one instance exists."""
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
"""Initialize config manager with config directory and path."""
self.config_dir = Path(__file__).parent
self.config_path = self.config_dir / "config.yaml"
def load_config(self, config_name: ConfigType) -> FeatureCompressorConfig:
"""Load configuration from YAML file.
Args:
config_name: Name of config file without extension
def load(self) -> FeatureCompressorConfig:
"""Load configuration from config.yaml file.
Returns:
Loaded and validated FeatureCompressorConfig instance
"""
config_path = self.config_dir / f"{config_name}.yaml"
config = load_yaml(config_path, FeatureCompressorConfig)
self._configs[config_name] = config
config = load_yaml(self.config_path, FeatureCompressorConfig)
self._config = config
return config
def load_all_configs(self) -> Dict[str, FeatureCompressorConfig]:
"""Load all YAML configuration files from config directory.
Returns:
Dictionary mapping config names to FeatureCompressorConfig instances
"""
config_files = list(self.config_dir.glob("*.yaml"))
loaded_configs = {}
for config_path in config_files:
config_name = config_path.stem
if config_name.startswith("_"):
continue # Skip private configs
config = load_yaml(config_path, FeatureCompressorConfig)
loaded_configs[config_name] = config
self._configs.update(loaded_configs)
return loaded_configs
def get_config(self, config_name: ConfigType) -> FeatureCompressorConfig:
"""Get loaded configuration by name.
Args:
config_name: Name of configuration to retrieve
def get(self) -> FeatureCompressorConfig:
"""Get loaded configuration, auto-loading if not already loaded.
Returns:
FeatureCompressorConfig instance
Raises:
ValueError: If configuration not loaded
Note:
Automatically loads config if not already loaded
"""
if config_name not in self._configs:
raise ValueError(
f"Configuration '{config_name}' not loaded. "
f"Call load_config('{config_name}') or load_all_configs() first."
)
return self._configs[config_name]
# Auto-load if config not yet loaded
if self._config is None:
return self.load()
return self._config
def get_or_load_config(self, config_name: ConfigType) -> FeatureCompressorConfig:
"""Get loaded configuration by name or load it if not loaded.
Args:
config_name: Name of configuration to retrieve
Returns:
FeatureCompressorConfig instance
Raises:
ValueError: If configuration not loaded
"""
if config_name not in self._configs:
return self.load_config(config_name)
return self._configs[config_name]
def list_configs(self) -> list[str]:
"""List names of all currently loaded configurations.
Returns:
List of configuration names
"""
return list(self._configs.keys())
def save_config(
def save(
self,
config_name: ConfigType,
config: Optional[FeatureCompressorConfig] = None,
path: Optional[Path] = None,
) -> None:
"""Save configuration to YAML file.
Args:
config_name: Name of config file without extension
config: Configuration to save. If None, saves currently loaded config for that name.
path: Optional custom path to save to. If None, saves to config_dir.
config: Configuration to save. If None, saves currently loaded config.
path: Optional custom path. If None, saves to default config.yaml.
Raises:
ValueError: If no configuration loaded for the given name and config is None
ValueError: If no configuration loaded and config is None
"""
# Use provided config or fall back to loaded config
if config is None:
if config_name not in self._configs:
if self._config is None:
raise ValueError(
f"No configuration loaded for '{config_name}'. "
f"Cannot save without providing config parameter."
"No configuration loaded. "
"Cannot save without providing config parameter."
)
config = self._configs[config_name]
config = self._config
save_path = path if path else self.config_dir / f"{config_name}.yaml"
# Save to custom path or default config.yaml
save_path = path if path else self.config_path
save_yaml(save_path, config)
self._configs[config_name] = config
self._config = config
# Global singleton instance

View File

@@ -1,8 +1,7 @@
from pathlib import Path
from typing import Optional
import lancedb
import pyarrow as pa
from configs import ConfigType, cfg_manager
from configs import cfg_manager
db_schema = pa.schema(
[
@@ -27,7 +26,7 @@ class DatabaseManager:
def __init__(self):
# 获取数据库位置
config = cfg_manager.get_or_load_config(ConfigType.FeatureCompressor)
config = cfg_manager.get()
db_path = config.output.directory / "database"
# 初始化数据库与表格

View File

@@ -62,7 +62,7 @@ class DINOv2FeatureExtractor:
FeatureCompressorConfig instance
"""
if config_path is None:
return cfg_manager.get_or_load_config("feature_compressor")
return cfg_manager.get()
else:
return load_yaml(Path(config_path), FeatureCompressorConfig)

View File

@@ -43,7 +43,7 @@ class FeatureVisualizer:
Configuration Pydantic model
"""
if config_path is None:
return cfg_manager.get_or_load_config("feature_compressor")
return cfg_manager.get()
else:
return load_yaml(Path(config_path), FeatureCompressorConfig)

View File

@@ -141,10 +141,8 @@ class TestYamlLoader:
"""Test suite for YAML loading and saving."""
def test_load_existing_yaml(self):
"""Load feature_compressor.yaml and verify values."""
config_path = (
Path(__file__).parent.parent / "configs" / "feature_compressor.yaml"
)
"""Load config.yaml and verify values."""
config_path = Path(__file__).parent.parent / "configs" / "config.yaml"
config = load_yaml(config_path, FeatureCompressorConfig)
# Verify model config
@@ -185,7 +183,7 @@ class TestYamlLoader:
def test_save_yaml_roundtrip(self):
"""Create config, save to temp, verify file exists with content."""
original = cfg_manager.load_config("feature_compressor")
original = cfg_manager.load()
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
temp_path = Path(f.name)
@@ -220,23 +218,22 @@ class TestConfigManager:
assert manager1 is manager2
def test_load_config(self):
"""Test loading feature_compressor config."""
config = cfg_manager.load_config("feature_compressor")
"""Test loading default config."""
config = cfg_manager.load()
assert config is not None
assert config.model.compression_dim == 256
assert config.visualization.point_size == 8
def test_get_config_not_loaded(self):
"""Test that get_config() raises error for unloaded config."""
with pytest.raises(ValueError, match="not loaded"):
cfg_manager.get_config("nonexistent_config")
def test_get_without_load(self):
"""Test that get() auto-loads config if not loaded."""
# Reset the singleton's cached config
cfg_manager._config = None
def test_list_configs(self):
"""Test listing all loaded configurations."""
cfg_manager.load_config("feature_compressor")
configs = cfg_manager.list_configs()
assert "feature_compressor" in configs
# get() should auto-load
config = cfg_manager.get()
assert config is not None
assert config.model.compression_dim == 256
def test_save_config(self):
"""Test saving configuration to file."""
@@ -250,7 +247,7 @@ class TestConfigManager:
temp_path = Path(f.name)
try:
cfg_manager.save_config("test_config", config, path=temp_path)
cfg_manager.save(config, path=temp_path)
loaded_config = load_yaml(temp_path, FeatureCompressorConfig)
assert loaded_config.model.compression_dim == 512