diff --git a/.gitignore b/.gitignore index 11ad470..10d1859 100644 --- a/.gitignore +++ b/.gitignore @@ -209,6 +209,7 @@ __marimo__/ data/ deps/ outputs/ +.sisyphus # Devenv .devenv* diff --git a/mini-nav/__init__.py b/mini-nav/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mini-nav/configs/__init__.py b/mini-nav/configs/__init__.py index e69de29..4fee91d 100644 --- a/mini-nav/configs/__init__.py +++ b/mini-nav/configs/__init__.py @@ -0,0 +1,28 @@ +from .models import ( + ModelConfig, + VisualizationConfig, + OutputConfig, + FeatureCompressorConfig, + PoolingType, +) +from .loader import load_yaml, save_yaml, ConfigError +from .config import ( + ConfigManager, + cfg_manager, +) + +__all__ = [ + # Models + "ModelConfig", + "VisualizationConfig", + "OutputConfig", + "FeatureCompressorConfig", + "PoolingType", + # Loader + "load_yaml", + "save_yaml", + "ConfigError", + # Manager + "ConfigManager", + "cfg_manager", +] diff --git a/mini-nav/configs/config.py b/mini-nav/configs/config.py index 420133c..bb5d8b6 100644 --- a/mini-nav/configs/config.py +++ b/mini-nav/configs/config.py @@ -1,20 +1,136 @@ -from enum import Enum +"""Configuration manager for multiple configurations.""" + from pathlib import Path -from typing import Dict +from typing import Dict, Optional -import yaml +from .loader import load_yaml, save_yaml +from .models import FeatureCompressorConfig -class Config(Enum): - FEATURE_COMPRESSOR = "feature_compressor.yaml" +class ConfigManager: + """Singleton configuration manager supporting multiple configs.""" + + _instance: Optional["ConfigManager"] = None + _configs: Dict[str, FeatureCompressorConfig] = {} + + def __new__(cls) -> "ConfigManager": + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self): + self.config_dir = Path(__file__).parent + + def load_config( + self, config_name: str = "feature_compressor" + ) -> FeatureCompressorConfig: + """Load configuration from YAML file. + + Args: + config_name: Name of config file without extension + + Returns: + Loaded and validated FeatureCompressorConfig instance + """ + config_path = self.config_dir / f"{config_name}.yaml" + config = load_yaml(config_path, FeatureCompressorConfig) + self._configs[config_name] = config + return config + + def load_all_configs(self) -> Dict[str, FeatureCompressorConfig]: + """Load all YAML configuration files from config directory. + + Returns: + Dictionary mapping config names to FeatureCompressorConfig instances + """ + config_files = list(self.config_dir.glob("*.yaml")) + loaded_configs = {} + + for config_path in config_files: + config_name = config_path.stem + if config_name.startswith("_"): + continue # Skip private configs + config = load_yaml(config_path, FeatureCompressorConfig) + loaded_configs[config_name] = config + + self._configs.update(loaded_configs) + return loaded_configs + + def get_config( + self, config_name: str = "feature_compressor" + ) -> FeatureCompressorConfig: + """Get loaded configuration by name. + + Args: + config_name: Name of configuration to retrieve + + Returns: + FeatureCompressorConfig instance + + Raises: + ValueError: If configuration not loaded + """ + if config_name not in self._configs: + raise ValueError( + f"Configuration '{config_name}' not loaded. " + f"Call load_config('{config_name}') or load_all_configs() first." + ) + return self._configs[config_name] + + def get_or_load_config( + self, config_name: str = "feature_compressor" + ) -> FeatureCompressorConfig: + """Get loaded configuration by name or load it if not loaded. + + Args: + config_name: Name of configuration to retrieve + + Returns: + FeatureCompressorConfig instance + + Raises: + ValueError: If configuration not loaded + """ + if config_name not in self._configs: + return self.load_config(config_name) + return self._configs[config_name] + + def list_configs(self) -> list[str]: + """List names of all currently loaded configurations. + + Returns: + List of configuration names + """ + return list(self._configs.keys()) + + def save_config( + self, + config_name: str = "feature_compressor", + config: Optional[FeatureCompressorConfig] = None, + path: Optional[Path] = None, + ) -> None: + """Save configuration to YAML file. + + Args: + config_name: Name of config file without extension + config: Configuration to save. If None, saves currently loaded config for that name. + path: Optional custom path to save to. If None, saves to config_dir. + + Raises: + ValueError: If no configuration loaded for the given name and config is None + """ + if config is None: + if config_name not in self._configs: + raise ValueError( + f"No configuration loaded for '{config_name}'. " + f"Cannot save without providing config parameter." + ) + config = self._configs[config_name] + + save_path = path if path else self.config_dir / f"{config_name}.yaml" + save_yaml(save_path, config) + self._configs[config_name] = config -def get_config_dir() -> Path: - return Path(__file__).parent - - -def get_default_config(config_type: Config) -> Dict[Unknown, Unknown]: - config_path = get_config_dir() / config_type.value - - with open(config_path) as f: - return yaml.safe_load(f) +# Global singleton instance +cfg_manager = ConfigManager() diff --git a/mini-nav/configs/loader.py b/mini-nav/configs/loader.py new file mode 100644 index 0000000..6fd3dc5 --- /dev/null +++ b/mini-nav/configs/loader.py @@ -0,0 +1,69 @@ +"""Generic YAML loader with Pydantic validation.""" + +from pathlib import Path +from typing import TypeVar, Type + +import yaml +from pydantic import BaseModel, ValidationError + + +T = TypeVar("T", bound=BaseModel) + + +class ConfigError(Exception): + """Configuration loading and validation error.""" + + pass + + +def load_yaml(path: Path, model_class: Type[T]) -> T: + """Load and validate YAML configuration file. + + Args: + path: Path to YAML file (str or Path object) + model_class: Pydantic model class to validate against + + Returns: + Validated model instance of type T + + Raises: + ConfigError: On file not found, YAML parsing error, or validation failure + """ + # Coerce str to Path if needed + if isinstance(path, str): + path = Path(path) + + try: + with open(path, "r", encoding="utf-8") as f: + data = yaml.safe_load(f) + except FileNotFoundError as e: + raise ConfigError(f"Configuration file not found: {path}") from e + except yaml.YAMLError as e: + raise ConfigError(f"YAML parsing error: {e}") from e + + try: + return model_class.model_validate(data) + except ValidationError as e: + raise ConfigError(f"Configuration validation failed: {e}") from e + + +def save_yaml(path: Path, model: BaseModel) -> None: + """Save Pydantic model to YAML file. + + Args: + path: Path to YAML file (str or Path object) + model: Pydantic model instance to save + + Raises: + ConfigError: On file write error + """ + # Coerce str to Path if needed + if isinstance(path, str): + path = Path(path) + + try: + data = model.model_dump(exclude_unset=True) + with open(path, "w", encoding="utf-8") as f: + yaml.dump(data, f, default_flow_style=False, allow_unicode=True) + except Exception as e: + raise ConfigError(f"Failed to save configuration to {path}: {e}") from e diff --git a/mini-nav/configs/models.py b/mini-nav/configs/models.py new file mode 100644 index 0000000..35d1832 --- /dev/null +++ b/mini-nav/configs/models.py @@ -0,0 +1,77 @@ +"""Pydantic data models for feature compressor configuration.""" + +from enum import Enum +from pathlib import Path + +from pydantic import BaseModel, ConfigDict, Field, field_validator + + +class PoolingType(str, Enum): + """Enum for pooling types.""" + + ATTENTION = "attention" + + +class ModelConfig(BaseModel): + """Configuration for the vision model and compression.""" + + model_config = ConfigDict(extra="ignore") + + name: str = "facebook/dinov2-large" + compression_dim: int = Field( + default=256, gt=0, description="Output feature dimension" + ) + pooling_type: PoolingType = PoolingType.ATTENTION + top_k_ratio: float = Field( + default=0.5, ge=0, le=1, description="Ratio of tokens to keep" + ) + hidden_ratio: float = Field( + default=2.0, gt=0, description="MLP hidden dim as multiple of compression_dim" + ) + dropout_rate: float = Field( + default=0.1, ge=0, le=1, description="Dropout probability" + ) + use_residual: bool = True + device: str = "auto" + + +class VisualizationConfig(BaseModel): + """Configuration for visualization settings.""" + + model_config = ConfigDict(extra="ignore") + + plot_theme: str = "plotly_white" + color_scale: str = "viridis" + point_size: int = Field(default=8, gt=0) + fig_width: int = Field(default=900, gt=0) + fig_height: int = Field(default=600, gt=0) + + +class OutputConfig(BaseModel): + """Configuration for output settings.""" + + model_config = ConfigDict(extra="ignore") + + directory: Path = Path(__file__).parent.parent.parent / "outputs" + html_self_contained: bool = True + png_scale: int = Field(default=2, gt=0) + + @field_validator("directory", mode="after") + def convert_to_absolute(cls, v: Path) -> Path: + """ + Converts the path to an absolute path relative to the current working directory. + This works even if the path doesn't exist on disk. + """ + if v.is_absolute(): + return v + return Path(__file__).parent.parent.parent / v + + +class FeatureCompressorConfig(BaseModel): + """Root configuration for the feature compressor.""" + + model_config = ConfigDict(extra="ignore") + + model: ModelConfig + visualization: VisualizationConfig + output: OutputConfig diff --git a/mini-nav/feature_compressor/core/extractor.py b/mini-nav/feature_compressor/core/extractor.py index 6021033..a1f101e 100644 --- a/mini-nav/feature_compressor/core/extractor.py +++ b/mini-nav/feature_compressor/core/extractor.py @@ -2,13 +2,12 @@ import time from pathlib import Path -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union import torch -import yaml +from configs import FeatureCompressorConfig, cfg_manager, load_yaml from transformers import AutoImageProcessor, AutoModel -from ...configs.config import Config, get_default_config from ..utils.image_utils import load_image, preprocess_image from .compressor import PoolNetCompressor @@ -25,47 +24,47 @@ class DINOv2FeatureExtractor: """ def __init__(self, config_path: Optional[str] = None, device: str = "auto"): - self.config = self._load_config(config_path) + self.config: FeatureCompressorConfig = self._load_config(config_path) # Set device if device == "auto": - device = self.config.get("model", {}).get("device", "auto") + device = self.config.model.device if device == "auto": device = "cuda" if torch.cuda.is_available() else "cpu" self.device = torch.device(device) # Load DINOv2 model and processor - model_name = self.config.get("model", {}).get("name", "facebook/dinov2-large") + model_name = self.config.model.name self.processor = AutoImageProcessor.from_pretrained(model_name) self.model = AutoModel.from_pretrained(model_name).to(self.device) self.model.eval() # Initialize compressor - model_config = self.config.get("model", {}) self.compressor = PoolNetCompressor( input_dim=self.model.config.hidden_size, - compression_dim=model_config.get("compression_dim", 256), - top_k_ratio=model_config.get("top_k_ratio", 0.5), - hidden_ratio=model_config.get("hidden_ratio", 2.0), - dropout_rate=model_config.get("dropout_rate", 0.1), - use_residual=model_config.get("use_residual", True), + compression_dim=self.config.model.compression_dim, + top_k_ratio=self.config.model.top_k_ratio, + hidden_ratio=self.config.model.hidden_ratio, + dropout_rate=self.config.model.dropout_rate, + use_residual=self.config.model.use_residual, device=str(self.device), ) - def _load_config(self, config_path: Optional[str] = None) -> Dict: + def _load_config( + self, config_path: Optional[str] = None + ) -> FeatureCompressorConfig: """Load configuration from YAML file. Args: config_path: Path to config file, or None for default Returns: - Configuration dictionary + FeatureCompressorConfig instance """ if config_path is None: - return get_default_config(Config.FEATURE_COMPRESSOR) - - with open(config_path) as f: - return yaml.safe_load(f) + return cfg_manager.get_or_load_config("feature_compressor") + else: + return load_yaml(Path(config_path), FeatureCompressorConfig) def _extract_dinov2_features(self, images: List) -> torch.Tensor: """Extract DINOv2 last_hidden_state features. @@ -149,14 +148,17 @@ class DINOv2FeatureExtractor: "processing_time": processing_time, "feature_norm": feature_norm, "device": str(self.device), - "model_name": self.config.get("model", {}).get("name"), + "model_name": self.config.model.name, }, } return result def process_batch( - self, image_dir: str, batch_size: int = 8, save_features: bool = True + self, + image_dir: Union[str, Path], + batch_size: int = 8, + save_features: bool = True, ) -> List[Dict[str, object]]: """Process multiple images in batches. @@ -208,7 +210,7 @@ class DINOv2FeatureExtractor: .mean() .item(), "device": str(self.device), - "model_name": self.config.get("model", {}).get("name"), + "model_name": self.config.model.name, }, } @@ -216,12 +218,7 @@ class DINOv2FeatureExtractor: # Save features if requested if save_features: - output_dir = Path( - self.config.get("output", {}).get("directory", "./outputs") - ) - # Resolve relative to project root - if not output_dir.is_absolute(): - output_dir = Path(__file__).parent.parent.parent / output_dir + output_dir = Path(self.config.output.directory) output_dir.mkdir(parents=True, exist_ok=True) output_path = output_dir / f"{file_path.stem}_features.json" diff --git a/mini-nav/feature_compressor/core/visualizer.py b/mini-nav/feature_compressor/core/visualizer.py index 307696f..6852297 100644 --- a/mini-nav/feature_compressor/core/visualizer.py +++ b/mini-nav/feature_compressor/core/visualizer.py @@ -2,11 +2,11 @@ import os from pathlib import Path -from typing import List, Optional +from typing import List, Optional, Union import numpy as np import torch -import yaml +from configs import FeatureCompressorConfig, cfg_manager, load_yaml from plotly.graph_objs import Figure from ..utils.plot_utils import ( @@ -29,28 +29,27 @@ class FeatureVisualizer: """ def __init__(self, config_path: Optional[str] = None): - self.config = self._load_config(config_path) + self.config: FeatureCompressorConfig = self._load_config(config_path) - def _load_config(self, config_path: Optional[str] = None) -> dict: + def _load_config( + self, config_path: Optional[str] = None + ) -> FeatureCompressorConfig: """Load configuration from YAML file. Args: config_path: Path to config file, or None for default Returns: - Configuration dictionary + Configuration Pydantic model """ if config_path is None: - config_path = ( - Path(__file__).parent.parent.parent - / "configs" - / "feature_compressor.yaml" - ) + return cfg_manager.get_or_load_config("feature_compressor") + else: + return load_yaml(Path(config_path), FeatureCompressorConfig) - with open(config_path) as f: - return yaml.safe_load(f) - - def plot_histogram(self, features: torch.Tensor, title: str = None) -> object: + def plot_histogram( + self, features: torch.Tensor, title: Optional[str] = None + ) -> Figure: """Plot histogram of feature values. Args: @@ -61,18 +60,21 @@ class FeatureVisualizer: Plotly Figure object """ features_np = features.cpu().numpy() - fig = create_histogram(features_np, title=title) + fig = create_histogram( + features_np, title="Feature Histogram" if title is None else title + ) - viz_config = self.config.get("visualization", {}) - fig = apply_theme(fig, viz_config.get("plot_theme", "plotly_white")) + fig = apply_theme(fig, self.config.visualization.plot_theme) fig.update_layout( - width=viz_config.get("fig_width", 900), - height=viz_config.get("fig_height", 600), + width=self.config.visualization.fig_width, + height=self.config.visualization.fig_height, ) return fig - def plot_pca_2d(self, features: torch.Tensor, labels: List = None) -> Figure: + def plot_pca_2d( + self, features: torch.Tensor, labels: Optional[List] = None + ) -> Figure: """Plot 2D PCA projection of features. Args: @@ -83,19 +85,21 @@ class FeatureVisualizer: Plotly Figure object """ features_np = features.cpu().numpy() - viz_config = self.config.get("visualization", {}) - fig = create_pca_scatter_2d(features_np, labels=labels) - fig = apply_theme(fig, viz_config.get("plot_theme", "plotly_white")) + fig = create_pca_scatter_2d( + features_np, + labels=[i for i in range(len(features_np))] if labels is None else labels, + ) + fig = apply_theme(fig, self.config.visualization.plot_theme) fig.update_traces( marker=dict( - size=viz_config.get("point_size", 8), - colorscale=viz_config.get("color_scale", "viridis"), + size=self.config.visualization.point_size, + colorscale=self.config.visualization.color_scale, ) ) fig.update_layout( - width=viz_config.get("fig_width", 900), - height=viz_config.get("fig_height", 600), + width=self.config.visualization.fig_width, + height=self.config.visualization.fig_height, ) return fig @@ -116,16 +120,17 @@ class FeatureVisualizer: fig = create_comparison_plot(features_np_list, names) - viz_config = self.config.get("visualization", {}) - fig = apply_theme(fig, viz_config.get("plot_theme", "plotly_white")) + fig = apply_theme(fig, self.config.visualization.plot_theme) fig.update_layout( - width=viz_config.get("fig_width", 900) * len(features_list), - height=viz_config.get("fig_height", 600), + width=self.config.visualization.fig_width * len(features_list), + height=self.config.visualization.fig_height, ) return fig - def generate_report(self, results: List[dict], output_dir: str) -> List[str]: + def generate_report( + self, results: List[dict], output_dir: Union[str, Path] + ) -> List[str]: """Generate full feature analysis report. Args: @@ -158,7 +163,7 @@ class FeatureVisualizer: return generated_files - def save(self, fig: object, path: str, formats: List[str] = None) -> None: + def save(self, fig: Figure, path: str, formats: List[str]) -> None: """Save figure in multiple formats. Args: @@ -169,8 +174,6 @@ class FeatureVisualizer: if formats is None: formats = ["html"] - output_config = self.config.get("output", {}) - for fmt in formats: if fmt == "png": save_figure(fig, path, format="png") diff --git a/mini-nav/feature_compressor/utils/image_utils.py b/mini-nav/feature_compressor/utils/image_utils.py index 9b8078e..722f522 100644 --- a/mini-nav/feature_compressor/utils/image_utils.py +++ b/mini-nav/feature_compressor/utils/image_utils.py @@ -1,7 +1,7 @@ """Image loading and preprocessing utilities.""" from pathlib import Path -from typing import List, Union +from typing import List, Optional, Union import requests from PIL import Image @@ -52,7 +52,7 @@ def preprocess_image(image: Image.Image, size: int = 224) -> Image.Image: def load_images_from_directory( - dir_path: Union[str, Path], extensions: List[str] = None + dir_path: Union[str, Path], extensions: Optional[List[str]] = None ) -> List[Image.Image]: """Load all images from a directory. diff --git a/mini-nav/tests/test_config.py b/mini-nav/tests/test_config.py new file mode 100644 index 0000000..467080f --- /dev/null +++ b/mini-nav/tests/test_config.py @@ -0,0 +1,259 @@ +"""Tests for configuration system using Pydantic models.""" + +import tempfile +from pathlib import Path + +import pytest +import yaml +from configs import ( + ConfigError, + ConfigManager, + FeatureCompressorConfig, + ModelConfig, + OutputConfig, + PoolingType, + VisualizationConfig, + cfg_manager, + load_yaml, + save_yaml, +) +from pydantic import ValidationError + + +class TestConfigModels: + """Test suite for Pydantic configuration models.""" + + def test_model_config_defaults(self): + """Verify ModelConfig creates with correct defaults.""" + config = ModelConfig() + assert config.name == "facebook/dinov2-large" + assert config.compression_dim == 256 + assert config.pooling_type == PoolingType.ATTENTION + assert config.top_k_ratio == 0.5 + assert config.hidden_ratio == 2.0 + assert config.dropout_rate == 0.1 + assert config.use_residual is True + assert config.device == "auto" + + def test_model_config_validation(self): + """Test validation constraints for ModelConfig.""" + # Test compression_dim > 0 + with pytest.raises(ValidationError, match="greater than 0"): + ModelConfig(compression_dim=0) + + with pytest.raises(ValidationError, match="greater than 0"): + ModelConfig(compression_dim=-1) + + # Test top_k_ratio in [0, 1] + with pytest.raises(ValidationError, match="less than or equal to 1"): + ModelConfig(top_k_ratio=1.5) + + with pytest.raises(ValidationError, match="greater than or equal to 0"): + ModelConfig(top_k_ratio=-0.1) + + # Test dropout_rate in [0, 1] + with pytest.raises(ValidationError, match="less than or equal to 1"): + ModelConfig(dropout_rate=1.5) + + with pytest.raises(ValidationError, match="greater than or equal to 0"): + ModelConfig(dropout_rate=-0.1) + + # Test hidden_ratio > 0 + with pytest.raises(ValidationError, match="greater than 0"): + ModelConfig(hidden_ratio=0) + + with pytest.raises(ValidationError, match="greater than 0"): + ModelConfig(hidden_ratio=-1) + + def test_visualization_config_defaults(self): + """Verify VisualizationConfig creates with correct defaults.""" + config = VisualizationConfig() + assert config.plot_theme == "plotly_white" + assert config.color_scale == "viridis" + assert config.point_size == 8 + assert config.fig_width == 900 + assert config.fig_height == 600 + + def test_visualization_config_validation(self): + """Test validation constraints for VisualizationConfig.""" + # Test fig_width > 0 + with pytest.raises(ValidationError, match="greater than 0"): + VisualizationConfig(fig_width=0) + + with pytest.raises(ValidationError, match="greater than 0"): + VisualizationConfig(fig_width=-1) + + # Test fig_height > 0 + with pytest.raises(ValidationError, match="greater than 0"): + VisualizationConfig(fig_height=0) + + with pytest.raises(ValidationError, match="greater than 0"): + VisualizationConfig(fig_height=-1) + + # Test point_size > 0 + with pytest.raises(ValidationError, match="greater than 0"): + VisualizationConfig(point_size=0) + + with pytest.raises(ValidationError, match="greater than 0"): + VisualizationConfig(point_size=-1) + + def test_output_config_defaults(self): + """Verify OutputConfig creates with correct defaults.""" + config = OutputConfig() + output_dir = Path(__file__).parent.parent.parent / "outputs" + + assert config.directory == output_dir + assert config.html_self_contained is True + assert config.png_scale == 2 + + def test_output_config_validation(self): + """Test validation constraints for OutputConfig.""" + # Test png_scale > 0 + with pytest.raises(ValidationError, match="greater than 0"): + OutputConfig(png_scale=0) + + with pytest.raises(ValidationError, match="greater than 0"): + OutputConfig(png_scale=-1) + + def test_pooling_type_enum(self): + """Verify PoolingType enum values.""" + assert PoolingType.ATTENTION.value == "attention" + assert PoolingType.ATTENTION == PoolingType("attention") + + def test_feature_compressor_config(self): + """Verify FeatureCompressorConfig nests all models correctly.""" + model_cfg = ModelConfig(compression_dim=512) + viz_cfg = VisualizationConfig(point_size=16) + out_cfg = OutputConfig(directory="/tmp/outputs") + + config = FeatureCompressorConfig( + model=model_cfg, + visualization=viz_cfg, + output=out_cfg, + ) + + assert config.model.compression_dim == 512 + assert config.visualization.point_size == 16 + assert config.output.directory == Path("/tmp/outputs") + + +class TestYamlLoader: + """Test suite for YAML loading and saving.""" + + def test_load_existing_yaml(self): + """Load feature_compressor.yaml and verify values.""" + config_path = ( + Path(__file__).parent.parent / "configs" / "feature_compressor.yaml" + ) + config = load_yaml(config_path, FeatureCompressorConfig) + + # Verify model config + assert config.model.name == "facebook/dinov2-large" + assert config.model.compression_dim == 256 + assert config.model.pooling_type == PoolingType.ATTENTION + assert config.model.top_k_ratio == 0.5 + assert config.model.hidden_ratio == 2.0 + assert config.model.dropout_rate == 0.1 + assert config.model.use_residual is True + + # Verify visualization config + assert config.visualization.plot_theme == "plotly_white" + assert config.visualization.color_scale == "viridis" + assert config.visualization.point_size == 8 + assert config.visualization.fig_width == 900 + assert config.visualization.fig_height == 600 + + # Verify output config + output_dir = Path(__file__).parent.parent.parent / "outputs" + + assert config.output.directory == output_dir + assert config.output.html_self_contained is True + assert config.output.png_scale == 2 + + def test_load_yaml_validation(self): + """Test that invalid data raises ConfigError.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + # Write invalid config (missing required fields) + yaml.dump({"invalid": "data"}, f) + temp_path = f.name + + try: + with pytest.raises(ConfigError, match="validation failed"): + load_yaml(Path(temp_path), FeatureCompressorConfig) + finally: + Path(temp_path).unlink() + + def test_save_yaml_roundtrip(self): + """Create config, save to temp, verify file exists with content.""" + original = cfg_manager.load_config("feature_compressor") + + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + temp_path = Path(f.name) + + try: + save_yaml(temp_path, original) + + # Verify file exists and has content + assert Path(temp_path).exists() + with open(temp_path, "r") as f: + content = f.read() + assert len(content) > 0 + assert "model" in content + assert "visualization" in content + assert "output" in content + finally: + Path(temp_path).unlink() + + def test_load_yaml_file_not_found(self): + """Verify FileNotFoundError raises ConfigError.""" + with pytest.raises(ConfigError, match="not found"): + load_yaml(Path("/nonexistent/path/config.yaml"), FeatureCompressorConfig) + + +class TestConfigManager: + """Test suite for ConfigManager singleton with multi-config support.""" + + def test_singleton_pattern(self): + """Verify ConfigManager() returns same instance.""" + manager1 = ConfigManager() + manager2 = ConfigManager() + assert manager1 is manager2 + + def test_load_config(self): + """Test loading feature_compressor config.""" + config = cfg_manager.load_config("feature_compressor") + + assert config is not None + assert config.model.compression_dim == 256 + assert config.visualization.point_size == 8 + + def test_get_config_not_loaded(self): + """Test that get_config() raises error for unloaded config.""" + with pytest.raises(ValueError, match="not loaded"): + cfg_manager.get_config("nonexistent_config") + + def test_list_configs(self): + """Test listing all loaded configurations.""" + cfg_manager.load_config("feature_compressor") + configs = cfg_manager.list_configs() + assert "feature_compressor" in configs + + def test_save_config(self): + """Test saving configuration to file.""" + config = FeatureCompressorConfig( + model=ModelConfig(compression_dim=512), + visualization=VisualizationConfig(), + output=OutputConfig(), + ) + + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + temp_path = Path(f.name) + + try: + cfg_manager.save_config("test_config", config, path=temp_path) + loaded_config = load_yaml(temp_path, FeatureCompressorConfig) + + assert loaded_config.model.compression_dim == 512 + finally: + if temp_path.exists(): + temp_path.unlink()