mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
refactor(config): simplify config manager to single unified config
This commit is contained in:
@@ -1,3 +1,3 @@
|
|||||||
from .database import DatabaseManager, db_manager, db_schema
|
from database import DatabaseManager, db_manager, db_schema
|
||||||
|
|
||||||
__all__ = ["DatabaseManager", "db_manager", "db_schema"]
|
__all__ = ["DatabaseManager", "db_manager", "db_schema"]
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ from .models import (
|
|||||||
from .loader import load_yaml, save_yaml, ConfigError
|
from .loader import load_yaml, save_yaml, ConfigError
|
||||||
from .config import (
|
from .config import (
|
||||||
ConfigManager,
|
ConfigManager,
|
||||||
ConfigType,
|
|
||||||
cfg_manager,
|
cfg_manager,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -25,6 +24,5 @@ __all__ = [
|
|||||||
"ConfigError",
|
"ConfigError",
|
||||||
# Manager
|
# Manager
|
||||||
"ConfigManager",
|
"ConfigManager",
|
||||||
"ConfigType",
|
|
||||||
"cfg_manager",
|
"cfg_manager",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,134 +1,80 @@
|
|||||||
"""Configuration manager for multiple configurations."""
|
"""Configuration manager for unified config."""
|
||||||
|
|
||||||
from enum import Enum
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from .loader import load_yaml, save_yaml
|
from .loader import load_yaml, save_yaml
|
||||||
from .models import FeatureCompressorConfig
|
from .models import FeatureCompressorConfig
|
||||||
|
|
||||||
|
|
||||||
class ConfigType(str, Enum):
|
|
||||||
FeatureCompressor = "feature_compressor"
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigManager:
|
class ConfigManager:
|
||||||
"""Singleton configuration manager supporting multiple configs."""
|
"""Singleton configuration manager for unified config."""
|
||||||
|
|
||||||
_instance: Optional["ConfigManager"] = None
|
_instance: Optional["ConfigManager"] = None
|
||||||
_configs: Dict[str, FeatureCompressorConfig] = {}
|
_config: Optional[FeatureCompressorConfig] = None
|
||||||
|
|
||||||
def __new__(cls) -> "ConfigManager":
|
def __new__(cls) -> "ConfigManager":
|
||||||
|
"""Singleton pattern - ensure only one instance exists."""
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
cls._instance = super().__new__(cls)
|
cls._instance = super().__new__(cls)
|
||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
"""Initialize config manager with config directory and path."""
|
||||||
self.config_dir = Path(__file__).parent
|
self.config_dir = Path(__file__).parent
|
||||||
|
self.config_path = self.config_dir / "config.yaml"
|
||||||
|
|
||||||
def load_config(self, config_name: ConfigType) -> FeatureCompressorConfig:
|
def load(self) -> FeatureCompressorConfig:
|
||||||
"""Load configuration from YAML file.
|
"""Load configuration from config.yaml file.
|
||||||
|
|
||||||
Args:
|
|
||||||
config_name: Name of config file without extension
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Loaded and validated FeatureCompressorConfig instance
|
Loaded and validated FeatureCompressorConfig instance
|
||||||
"""
|
"""
|
||||||
config_path = self.config_dir / f"{config_name}.yaml"
|
config = load_yaml(self.config_path, FeatureCompressorConfig)
|
||||||
config = load_yaml(config_path, FeatureCompressorConfig)
|
self._config = config
|
||||||
self._configs[config_name] = config
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def load_all_configs(self) -> Dict[str, FeatureCompressorConfig]:
|
def get(self) -> FeatureCompressorConfig:
|
||||||
"""Load all YAML configuration files from config directory.
|
"""Get loaded configuration, auto-loading if not already loaded.
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary mapping config names to FeatureCompressorConfig instances
|
|
||||||
"""
|
|
||||||
config_files = list(self.config_dir.glob("*.yaml"))
|
|
||||||
loaded_configs = {}
|
|
||||||
|
|
||||||
for config_path in config_files:
|
|
||||||
config_name = config_path.stem
|
|
||||||
if config_name.startswith("_"):
|
|
||||||
continue # Skip private configs
|
|
||||||
config = load_yaml(config_path, FeatureCompressorConfig)
|
|
||||||
loaded_configs[config_name] = config
|
|
||||||
|
|
||||||
self._configs.update(loaded_configs)
|
|
||||||
return loaded_configs
|
|
||||||
|
|
||||||
def get_config(self, config_name: ConfigType) -> FeatureCompressorConfig:
|
|
||||||
"""Get loaded configuration by name.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config_name: Name of configuration to retrieve
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
FeatureCompressorConfig instance
|
FeatureCompressorConfig instance
|
||||||
|
|
||||||
Raises:
|
Note:
|
||||||
ValueError: If configuration not loaded
|
Automatically loads config if not already loaded
|
||||||
"""
|
"""
|
||||||
if config_name not in self._configs:
|
# Auto-load if config not yet loaded
|
||||||
raise ValueError(
|
if self._config is None:
|
||||||
f"Configuration '{config_name}' not loaded. "
|
return self.load()
|
||||||
f"Call load_config('{config_name}') or load_all_configs() first."
|
return self._config
|
||||||
)
|
|
||||||
return self._configs[config_name]
|
|
||||||
|
|
||||||
def get_or_load_config(self, config_name: ConfigType) -> FeatureCompressorConfig:
|
def save(
|
||||||
"""Get loaded configuration by name or load it if not loaded.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config_name: Name of configuration to retrieve
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
FeatureCompressorConfig instance
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If configuration not loaded
|
|
||||||
"""
|
|
||||||
if config_name not in self._configs:
|
|
||||||
return self.load_config(config_name)
|
|
||||||
return self._configs[config_name]
|
|
||||||
|
|
||||||
def list_configs(self) -> list[str]:
|
|
||||||
"""List names of all currently loaded configurations.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of configuration names
|
|
||||||
"""
|
|
||||||
return list(self._configs.keys())
|
|
||||||
|
|
||||||
def save_config(
|
|
||||||
self,
|
self,
|
||||||
config_name: ConfigType,
|
|
||||||
config: Optional[FeatureCompressorConfig] = None,
|
config: Optional[FeatureCompressorConfig] = None,
|
||||||
path: Optional[Path] = None,
|
path: Optional[Path] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Save configuration to YAML file.
|
"""Save configuration to YAML file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config_name: Name of config file without extension
|
config: Configuration to save. If None, saves currently loaded config.
|
||||||
config: Configuration to save. If None, saves currently loaded config for that name.
|
path: Optional custom path. If None, saves to default config.yaml.
|
||||||
path: Optional custom path to save to. If None, saves to config_dir.
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If no configuration loaded for the given name and config is None
|
ValueError: If no configuration loaded and config is None
|
||||||
"""
|
"""
|
||||||
|
# Use provided config or fall back to loaded config
|
||||||
if config is None:
|
if config is None:
|
||||||
if config_name not in self._configs:
|
if self._config is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"No configuration loaded for '{config_name}'. "
|
"No configuration loaded. "
|
||||||
f"Cannot save without providing config parameter."
|
"Cannot save without providing config parameter."
|
||||||
)
|
)
|
||||||
config = self._configs[config_name]
|
config = self._config
|
||||||
|
|
||||||
save_path = path if path else self.config_dir / f"{config_name}.yaml"
|
# Save to custom path or default config.yaml
|
||||||
|
save_path = path if path else self.config_path
|
||||||
save_yaml(save_path, config)
|
save_yaml(save_path, config)
|
||||||
self._configs[config_name] = config
|
self._config = config
|
||||||
|
|
||||||
|
|
||||||
# Global singleton instance
|
# Global singleton instance
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
from pathlib import Path
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import lancedb
|
import lancedb
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
from configs import ConfigType, cfg_manager
|
from configs import cfg_manager
|
||||||
|
|
||||||
db_schema = pa.schema(
|
db_schema = pa.schema(
|
||||||
[
|
[
|
||||||
@@ -27,7 +26,7 @@ class DatabaseManager:
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# 获取数据库位置
|
# 获取数据库位置
|
||||||
config = cfg_manager.get_or_load_config(ConfigType.FeatureCompressor)
|
config = cfg_manager.get()
|
||||||
db_path = config.output.directory / "database"
|
db_path = config.output.directory / "database"
|
||||||
|
|
||||||
# 初始化数据库与表格
|
# 初始化数据库与表格
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ class DINOv2FeatureExtractor:
|
|||||||
FeatureCompressorConfig instance
|
FeatureCompressorConfig instance
|
||||||
"""
|
"""
|
||||||
if config_path is None:
|
if config_path is None:
|
||||||
return cfg_manager.get_or_load_config("feature_compressor")
|
return cfg_manager.get()
|
||||||
else:
|
else:
|
||||||
return load_yaml(Path(config_path), FeatureCompressorConfig)
|
return load_yaml(Path(config_path), FeatureCompressorConfig)
|
||||||
|
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ class FeatureVisualizer:
|
|||||||
Configuration Pydantic model
|
Configuration Pydantic model
|
||||||
"""
|
"""
|
||||||
if config_path is None:
|
if config_path is None:
|
||||||
return cfg_manager.get_or_load_config("feature_compressor")
|
return cfg_manager.get()
|
||||||
else:
|
else:
|
||||||
return load_yaml(Path(config_path), FeatureCompressorConfig)
|
return load_yaml(Path(config_path), FeatureCompressorConfig)
|
||||||
|
|
||||||
|
|||||||
@@ -141,10 +141,8 @@ class TestYamlLoader:
|
|||||||
"""Test suite for YAML loading and saving."""
|
"""Test suite for YAML loading and saving."""
|
||||||
|
|
||||||
def test_load_existing_yaml(self):
|
def test_load_existing_yaml(self):
|
||||||
"""Load feature_compressor.yaml and verify values."""
|
"""Load config.yaml and verify values."""
|
||||||
config_path = (
|
config_path = Path(__file__).parent.parent / "configs" / "config.yaml"
|
||||||
Path(__file__).parent.parent / "configs" / "feature_compressor.yaml"
|
|
||||||
)
|
|
||||||
config = load_yaml(config_path, FeatureCompressorConfig)
|
config = load_yaml(config_path, FeatureCompressorConfig)
|
||||||
|
|
||||||
# Verify model config
|
# Verify model config
|
||||||
@@ -185,7 +183,7 @@ class TestYamlLoader:
|
|||||||
|
|
||||||
def test_save_yaml_roundtrip(self):
|
def test_save_yaml_roundtrip(self):
|
||||||
"""Create config, save to temp, verify file exists with content."""
|
"""Create config, save to temp, verify file exists with content."""
|
||||||
original = cfg_manager.load_config("feature_compressor")
|
original = cfg_manager.load()
|
||||||
|
|
||||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
|
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
|
||||||
temp_path = Path(f.name)
|
temp_path = Path(f.name)
|
||||||
@@ -220,23 +218,22 @@ class TestConfigManager:
|
|||||||
assert manager1 is manager2
|
assert manager1 is manager2
|
||||||
|
|
||||||
def test_load_config(self):
|
def test_load_config(self):
|
||||||
"""Test loading feature_compressor config."""
|
"""Test loading default config."""
|
||||||
config = cfg_manager.load_config("feature_compressor")
|
config = cfg_manager.load()
|
||||||
|
|
||||||
assert config is not None
|
assert config is not None
|
||||||
assert config.model.compression_dim == 256
|
assert config.model.compression_dim == 256
|
||||||
assert config.visualization.point_size == 8
|
assert config.visualization.point_size == 8
|
||||||
|
|
||||||
def test_get_config_not_loaded(self):
|
def test_get_without_load(self):
|
||||||
"""Test that get_config() raises error for unloaded config."""
|
"""Test that get() auto-loads config if not loaded."""
|
||||||
with pytest.raises(ValueError, match="not loaded"):
|
# Reset the singleton's cached config
|
||||||
cfg_manager.get_config("nonexistent_config")
|
cfg_manager._config = None
|
||||||
|
|
||||||
def test_list_configs(self):
|
# get() should auto-load
|
||||||
"""Test listing all loaded configurations."""
|
config = cfg_manager.get()
|
||||||
cfg_manager.load_config("feature_compressor")
|
assert config is not None
|
||||||
configs = cfg_manager.list_configs()
|
assert config.model.compression_dim == 256
|
||||||
assert "feature_compressor" in configs
|
|
||||||
|
|
||||||
def test_save_config(self):
|
def test_save_config(self):
|
||||||
"""Test saving configuration to file."""
|
"""Test saving configuration to file."""
|
||||||
@@ -250,7 +247,7 @@ class TestConfigManager:
|
|||||||
temp_path = Path(f.name)
|
temp_path = Path(f.name)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
cfg_manager.save_config("test_config", config, path=temp_path)
|
cfg_manager.save(config, path=temp_path)
|
||||||
loaded_config = load_yaml(temp_path, FeatureCompressorConfig)
|
loaded_config = load_yaml(temp_path, FeatureCompressorConfig)
|
||||||
|
|
||||||
assert loaded_config.model.compression_dim == 512
|
assert loaded_config.model.compression_dim == 512
|
||||||
|
|||||||
Reference in New Issue
Block a user