feat(configs): implement Pydantic configuration system with type safety

This commit is contained in:
2026-01-31 12:19:11 +08:00
parent 1454647aa6
commit 9e9070bdb4
10 changed files with 628 additions and 78 deletions

1
.gitignore vendored
View File

@@ -209,6 +209,7 @@ __marimo__/
data/ data/
deps/ deps/
outputs/ outputs/
.sisyphus
# Devenv # Devenv
.devenv* .devenv*

0
mini-nav/__init__.py Normal file
View File

View File

@@ -0,0 +1,28 @@
from .models import (
ModelConfig,
VisualizationConfig,
OutputConfig,
FeatureCompressorConfig,
PoolingType,
)
from .loader import load_yaml, save_yaml, ConfigError
from .config import (
ConfigManager,
cfg_manager,
)
__all__ = [
# Models
"ModelConfig",
"VisualizationConfig",
"OutputConfig",
"FeatureCompressorConfig",
"PoolingType",
# Loader
"load_yaml",
"save_yaml",
"ConfigError",
# Manager
"ConfigManager",
"cfg_manager",
]

View File

@@ -1,20 +1,136 @@
from enum import Enum """Configuration manager for multiple configurations."""
from pathlib import Path from pathlib import Path
from typing import Dict from typing import Dict, Optional
import yaml from .loader import load_yaml, save_yaml
from .models import FeatureCompressorConfig
class Config(Enum): class ConfigManager:
FEATURE_COMPRESSOR = "feature_compressor.yaml" """Singleton configuration manager supporting multiple configs."""
_instance: Optional["ConfigManager"] = None
_configs: Dict[str, FeatureCompressorConfig] = {}
def __new__(cls) -> "ConfigManager":
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
self.config_dir = Path(__file__).parent
def load_config(
self, config_name: str = "feature_compressor"
) -> FeatureCompressorConfig:
"""Load configuration from YAML file.
Args:
config_name: Name of config file without extension
Returns:
Loaded and validated FeatureCompressorConfig instance
"""
config_path = self.config_dir / f"{config_name}.yaml"
config = load_yaml(config_path, FeatureCompressorConfig)
self._configs[config_name] = config
return config
def load_all_configs(self) -> Dict[str, FeatureCompressorConfig]:
"""Load all YAML configuration files from config directory.
Returns:
Dictionary mapping config names to FeatureCompressorConfig instances
"""
config_files = list(self.config_dir.glob("*.yaml"))
loaded_configs = {}
for config_path in config_files:
config_name = config_path.stem
if config_name.startswith("_"):
continue # Skip private configs
config = load_yaml(config_path, FeatureCompressorConfig)
loaded_configs[config_name] = config
self._configs.update(loaded_configs)
return loaded_configs
def get_config(
self, config_name: str = "feature_compressor"
) -> FeatureCompressorConfig:
"""Get loaded configuration by name.
Args:
config_name: Name of configuration to retrieve
Returns:
FeatureCompressorConfig instance
Raises:
ValueError: If configuration not loaded
"""
if config_name not in self._configs:
raise ValueError(
f"Configuration '{config_name}' not loaded. "
f"Call load_config('{config_name}') or load_all_configs() first."
)
return self._configs[config_name]
def get_or_load_config(
self, config_name: str = "feature_compressor"
) -> FeatureCompressorConfig:
"""Get loaded configuration by name or load it if not loaded.
Args:
config_name: Name of configuration to retrieve
Returns:
FeatureCompressorConfig instance
Raises:
ValueError: If configuration not loaded
"""
if config_name not in self._configs:
return self.load_config(config_name)
return self._configs[config_name]
def list_configs(self) -> list[str]:
"""List names of all currently loaded configurations.
Returns:
List of configuration names
"""
return list(self._configs.keys())
def save_config(
self,
config_name: str = "feature_compressor",
config: Optional[FeatureCompressorConfig] = None,
path: Optional[Path] = None,
) -> None:
"""Save configuration to YAML file.
Args:
config_name: Name of config file without extension
config: Configuration to save. If None, saves currently loaded config for that name.
path: Optional custom path to save to. If None, saves to config_dir.
Raises:
ValueError: If no configuration loaded for the given name and config is None
"""
if config is None:
if config_name not in self._configs:
raise ValueError(
f"No configuration loaded for '{config_name}'. "
f"Cannot save without providing config parameter."
)
config = self._configs[config_name]
save_path = path if path else self.config_dir / f"{config_name}.yaml"
save_yaml(save_path, config)
self._configs[config_name] = config
def get_config_dir() -> Path: # Global singleton instance
return Path(__file__).parent cfg_manager = ConfigManager()
def get_default_config(config_type: Config) -> Dict[Unknown, Unknown]:
config_path = get_config_dir() / config_type.value
with open(config_path) as f:
return yaml.safe_load(f)

View File

@@ -0,0 +1,69 @@
"""Generic YAML loader with Pydantic validation."""
from pathlib import Path
from typing import TypeVar, Type
import yaml
from pydantic import BaseModel, ValidationError
T = TypeVar("T", bound=BaseModel)
class ConfigError(Exception):
"""Configuration loading and validation error."""
pass
def load_yaml(path: Path, model_class: Type[T]) -> T:
"""Load and validate YAML configuration file.
Args:
path: Path to YAML file (str or Path object)
model_class: Pydantic model class to validate against
Returns:
Validated model instance of type T
Raises:
ConfigError: On file not found, YAML parsing error, or validation failure
"""
# Coerce str to Path if needed
if isinstance(path, str):
path = Path(path)
try:
with open(path, "r", encoding="utf-8") as f:
data = yaml.safe_load(f)
except FileNotFoundError as e:
raise ConfigError(f"Configuration file not found: {path}") from e
except yaml.YAMLError as e:
raise ConfigError(f"YAML parsing error: {e}") from e
try:
return model_class.model_validate(data)
except ValidationError as e:
raise ConfigError(f"Configuration validation failed: {e}") from e
def save_yaml(path: Path, model: BaseModel) -> None:
"""Save Pydantic model to YAML file.
Args:
path: Path to YAML file (str or Path object)
model: Pydantic model instance to save
Raises:
ConfigError: On file write error
"""
# Coerce str to Path if needed
if isinstance(path, str):
path = Path(path)
try:
data = model.model_dump(exclude_unset=True)
with open(path, "w", encoding="utf-8") as f:
yaml.dump(data, f, default_flow_style=False, allow_unicode=True)
except Exception as e:
raise ConfigError(f"Failed to save configuration to {path}: {e}") from e

View File

@@ -0,0 +1,77 @@
"""Pydantic data models for feature compressor configuration."""
from enum import Enum
from pathlib import Path
from pydantic import BaseModel, ConfigDict, Field, field_validator
class PoolingType(str, Enum):
"""Enum for pooling types."""
ATTENTION = "attention"
class ModelConfig(BaseModel):
"""Configuration for the vision model and compression."""
model_config = ConfigDict(extra="ignore")
name: str = "facebook/dinov2-large"
compression_dim: int = Field(
default=256, gt=0, description="Output feature dimension"
)
pooling_type: PoolingType = PoolingType.ATTENTION
top_k_ratio: float = Field(
default=0.5, ge=0, le=1, description="Ratio of tokens to keep"
)
hidden_ratio: float = Field(
default=2.0, gt=0, description="MLP hidden dim as multiple of compression_dim"
)
dropout_rate: float = Field(
default=0.1, ge=0, le=1, description="Dropout probability"
)
use_residual: bool = True
device: str = "auto"
class VisualizationConfig(BaseModel):
"""Configuration for visualization settings."""
model_config = ConfigDict(extra="ignore")
plot_theme: str = "plotly_white"
color_scale: str = "viridis"
point_size: int = Field(default=8, gt=0)
fig_width: int = Field(default=900, gt=0)
fig_height: int = Field(default=600, gt=0)
class OutputConfig(BaseModel):
"""Configuration for output settings."""
model_config = ConfigDict(extra="ignore")
directory: Path = Path(__file__).parent.parent.parent / "outputs"
html_self_contained: bool = True
png_scale: int = Field(default=2, gt=0)
@field_validator("directory", mode="after")
def convert_to_absolute(cls, v: Path) -> Path:
"""
Converts the path to an absolute path relative to the current working directory.
This works even if the path doesn't exist on disk.
"""
if v.is_absolute():
return v
return Path(__file__).parent.parent.parent / v
class FeatureCompressorConfig(BaseModel):
"""Root configuration for the feature compressor."""
model_config = ConfigDict(extra="ignore")
model: ModelConfig
visualization: VisualizationConfig
output: OutputConfig

View File

@@ -2,13 +2,12 @@
import time import time
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional from typing import Dict, List, Optional, Union
import torch import torch
import yaml from configs import FeatureCompressorConfig, cfg_manager, load_yaml
from transformers import AutoImageProcessor, AutoModel from transformers import AutoImageProcessor, AutoModel
from ...configs.config import Config, get_default_config
from ..utils.image_utils import load_image, preprocess_image from ..utils.image_utils import load_image, preprocess_image
from .compressor import PoolNetCompressor from .compressor import PoolNetCompressor
@@ -25,47 +24,47 @@ class DINOv2FeatureExtractor:
""" """
def __init__(self, config_path: Optional[str] = None, device: str = "auto"): def __init__(self, config_path: Optional[str] = None, device: str = "auto"):
self.config = self._load_config(config_path) self.config: FeatureCompressorConfig = self._load_config(config_path)
# Set device # Set device
if device == "auto": if device == "auto":
device = self.config.get("model", {}).get("device", "auto") device = self.config.model.device
if device == "auto": if device == "auto":
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = torch.device(device) self.device = torch.device(device)
# Load DINOv2 model and processor # Load DINOv2 model and processor
model_name = self.config.get("model", {}).get("name", "facebook/dinov2-large") model_name = self.config.model.name
self.processor = AutoImageProcessor.from_pretrained(model_name) self.processor = AutoImageProcessor.from_pretrained(model_name)
self.model = AutoModel.from_pretrained(model_name).to(self.device) self.model = AutoModel.from_pretrained(model_name).to(self.device)
self.model.eval() self.model.eval()
# Initialize compressor # Initialize compressor
model_config = self.config.get("model", {})
self.compressor = PoolNetCompressor( self.compressor = PoolNetCompressor(
input_dim=self.model.config.hidden_size, input_dim=self.model.config.hidden_size,
compression_dim=model_config.get("compression_dim", 256), compression_dim=self.config.model.compression_dim,
top_k_ratio=model_config.get("top_k_ratio", 0.5), top_k_ratio=self.config.model.top_k_ratio,
hidden_ratio=model_config.get("hidden_ratio", 2.0), hidden_ratio=self.config.model.hidden_ratio,
dropout_rate=model_config.get("dropout_rate", 0.1), dropout_rate=self.config.model.dropout_rate,
use_residual=model_config.get("use_residual", True), use_residual=self.config.model.use_residual,
device=str(self.device), device=str(self.device),
) )
def _load_config(self, config_path: Optional[str] = None) -> Dict: def _load_config(
self, config_path: Optional[str] = None
) -> FeatureCompressorConfig:
"""Load configuration from YAML file. """Load configuration from YAML file.
Args: Args:
config_path: Path to config file, or None for default config_path: Path to config file, or None for default
Returns: Returns:
Configuration dictionary FeatureCompressorConfig instance
""" """
if config_path is None: if config_path is None:
return get_default_config(Config.FEATURE_COMPRESSOR) return cfg_manager.get_or_load_config("feature_compressor")
else:
with open(config_path) as f: return load_yaml(Path(config_path), FeatureCompressorConfig)
return yaml.safe_load(f)
def _extract_dinov2_features(self, images: List) -> torch.Tensor: def _extract_dinov2_features(self, images: List) -> torch.Tensor:
"""Extract DINOv2 last_hidden_state features. """Extract DINOv2 last_hidden_state features.
@@ -149,14 +148,17 @@ class DINOv2FeatureExtractor:
"processing_time": processing_time, "processing_time": processing_time,
"feature_norm": feature_norm, "feature_norm": feature_norm,
"device": str(self.device), "device": str(self.device),
"model_name": self.config.get("model", {}).get("name"), "model_name": self.config.model.name,
}, },
} }
return result return result
def process_batch( def process_batch(
self, image_dir: str, batch_size: int = 8, save_features: bool = True self,
image_dir: Union[str, Path],
batch_size: int = 8,
save_features: bool = True,
) -> List[Dict[str, object]]: ) -> List[Dict[str, object]]:
"""Process multiple images in batches. """Process multiple images in batches.
@@ -208,7 +210,7 @@ class DINOv2FeatureExtractor:
.mean() .mean()
.item(), .item(),
"device": str(self.device), "device": str(self.device),
"model_name": self.config.get("model", {}).get("name"), "model_name": self.config.model.name,
}, },
} }
@@ -216,12 +218,7 @@ class DINOv2FeatureExtractor:
# Save features if requested # Save features if requested
if save_features: if save_features:
output_dir = Path( output_dir = Path(self.config.output.directory)
self.config.get("output", {}).get("directory", "./outputs")
)
# Resolve relative to project root
if not output_dir.is_absolute():
output_dir = Path(__file__).parent.parent.parent / output_dir
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)
output_path = output_dir / f"{file_path.stem}_features.json" output_path = output_dir / f"{file_path.stem}_features.json"

View File

@@ -2,11 +2,11 @@
import os import os
from pathlib import Path from pathlib import Path
from typing import List, Optional from typing import List, Optional, Union
import numpy as np import numpy as np
import torch import torch
import yaml from configs import FeatureCompressorConfig, cfg_manager, load_yaml
from plotly.graph_objs import Figure from plotly.graph_objs import Figure
from ..utils.plot_utils import ( from ..utils.plot_utils import (
@@ -29,28 +29,27 @@ class FeatureVisualizer:
""" """
def __init__(self, config_path: Optional[str] = None): def __init__(self, config_path: Optional[str] = None):
self.config = self._load_config(config_path) self.config: FeatureCompressorConfig = self._load_config(config_path)
def _load_config(self, config_path: Optional[str] = None) -> dict: def _load_config(
self, config_path: Optional[str] = None
) -> FeatureCompressorConfig:
"""Load configuration from YAML file. """Load configuration from YAML file.
Args: Args:
config_path: Path to config file, or None for default config_path: Path to config file, or None for default
Returns: Returns:
Configuration dictionary Configuration Pydantic model
""" """
if config_path is None: if config_path is None:
config_path = ( return cfg_manager.get_or_load_config("feature_compressor")
Path(__file__).parent.parent.parent else:
/ "configs" return load_yaml(Path(config_path), FeatureCompressorConfig)
/ "feature_compressor.yaml"
)
with open(config_path) as f: def plot_histogram(
return yaml.safe_load(f) self, features: torch.Tensor, title: Optional[str] = None
) -> Figure:
def plot_histogram(self, features: torch.Tensor, title: str = None) -> object:
"""Plot histogram of feature values. """Plot histogram of feature values.
Args: Args:
@@ -61,18 +60,21 @@ class FeatureVisualizer:
Plotly Figure object Plotly Figure object
""" """
features_np = features.cpu().numpy() features_np = features.cpu().numpy()
fig = create_histogram(features_np, title=title) fig = create_histogram(
features_np, title="Feature Histogram" if title is None else title
)
viz_config = self.config.get("visualization", {}) fig = apply_theme(fig, self.config.visualization.plot_theme)
fig = apply_theme(fig, viz_config.get("plot_theme", "plotly_white"))
fig.update_layout( fig.update_layout(
width=viz_config.get("fig_width", 900), width=self.config.visualization.fig_width,
height=viz_config.get("fig_height", 600), height=self.config.visualization.fig_height,
) )
return fig return fig
def plot_pca_2d(self, features: torch.Tensor, labels: List = None) -> Figure: def plot_pca_2d(
self, features: torch.Tensor, labels: Optional[List] = None
) -> Figure:
"""Plot 2D PCA projection of features. """Plot 2D PCA projection of features.
Args: Args:
@@ -83,19 +85,21 @@ class FeatureVisualizer:
Plotly Figure object Plotly Figure object
""" """
features_np = features.cpu().numpy() features_np = features.cpu().numpy()
viz_config = self.config.get("visualization", {})
fig = create_pca_scatter_2d(features_np, labels=labels) fig = create_pca_scatter_2d(
fig = apply_theme(fig, viz_config.get("plot_theme", "plotly_white")) features_np,
labels=[i for i in range(len(features_np))] if labels is None else labels,
)
fig = apply_theme(fig, self.config.visualization.plot_theme)
fig.update_traces( fig.update_traces(
marker=dict( marker=dict(
size=viz_config.get("point_size", 8), size=self.config.visualization.point_size,
colorscale=viz_config.get("color_scale", "viridis"), colorscale=self.config.visualization.color_scale,
) )
) )
fig.update_layout( fig.update_layout(
width=viz_config.get("fig_width", 900), width=self.config.visualization.fig_width,
height=viz_config.get("fig_height", 600), height=self.config.visualization.fig_height,
) )
return fig return fig
@@ -116,16 +120,17 @@ class FeatureVisualizer:
fig = create_comparison_plot(features_np_list, names) fig = create_comparison_plot(features_np_list, names)
viz_config = self.config.get("visualization", {}) fig = apply_theme(fig, self.config.visualization.plot_theme)
fig = apply_theme(fig, viz_config.get("plot_theme", "plotly_white"))
fig.update_layout( fig.update_layout(
width=viz_config.get("fig_width", 900) * len(features_list), width=self.config.visualization.fig_width * len(features_list),
height=viz_config.get("fig_height", 600), height=self.config.visualization.fig_height,
) )
return fig return fig
def generate_report(self, results: List[dict], output_dir: str) -> List[str]: def generate_report(
self, results: List[dict], output_dir: Union[str, Path]
) -> List[str]:
"""Generate full feature analysis report. """Generate full feature analysis report.
Args: Args:
@@ -158,7 +163,7 @@ class FeatureVisualizer:
return generated_files return generated_files
def save(self, fig: object, path: str, formats: List[str] = None) -> None: def save(self, fig: Figure, path: str, formats: List[str]) -> None:
"""Save figure in multiple formats. """Save figure in multiple formats.
Args: Args:
@@ -169,8 +174,6 @@ class FeatureVisualizer:
if formats is None: if formats is None:
formats = ["html"] formats = ["html"]
output_config = self.config.get("output", {})
for fmt in formats: for fmt in formats:
if fmt == "png": if fmt == "png":
save_figure(fig, path, format="png") save_figure(fig, path, format="png")

View File

@@ -1,7 +1,7 @@
"""Image loading and preprocessing utilities.""" """Image loading and preprocessing utilities."""
from pathlib import Path from pathlib import Path
from typing import List, Union from typing import List, Optional, Union
import requests import requests
from PIL import Image from PIL import Image
@@ -52,7 +52,7 @@ def preprocess_image(image: Image.Image, size: int = 224) -> Image.Image:
def load_images_from_directory( def load_images_from_directory(
dir_path: Union[str, Path], extensions: List[str] = None dir_path: Union[str, Path], extensions: Optional[List[str]] = None
) -> List[Image.Image]: ) -> List[Image.Image]:
"""Load all images from a directory. """Load all images from a directory.

View File

@@ -0,0 +1,259 @@
"""Tests for configuration system using Pydantic models."""
import tempfile
from pathlib import Path
import pytest
import yaml
from configs import (
ConfigError,
ConfigManager,
FeatureCompressorConfig,
ModelConfig,
OutputConfig,
PoolingType,
VisualizationConfig,
cfg_manager,
load_yaml,
save_yaml,
)
from pydantic import ValidationError
class TestConfigModels:
"""Test suite for Pydantic configuration models."""
def test_model_config_defaults(self):
"""Verify ModelConfig creates with correct defaults."""
config = ModelConfig()
assert config.name == "facebook/dinov2-large"
assert config.compression_dim == 256
assert config.pooling_type == PoolingType.ATTENTION
assert config.top_k_ratio == 0.5
assert config.hidden_ratio == 2.0
assert config.dropout_rate == 0.1
assert config.use_residual is True
assert config.device == "auto"
def test_model_config_validation(self):
"""Test validation constraints for ModelConfig."""
# Test compression_dim > 0
with pytest.raises(ValidationError, match="greater than 0"):
ModelConfig(compression_dim=0)
with pytest.raises(ValidationError, match="greater than 0"):
ModelConfig(compression_dim=-1)
# Test top_k_ratio in [0, 1]
with pytest.raises(ValidationError, match="less than or equal to 1"):
ModelConfig(top_k_ratio=1.5)
with pytest.raises(ValidationError, match="greater than or equal to 0"):
ModelConfig(top_k_ratio=-0.1)
# Test dropout_rate in [0, 1]
with pytest.raises(ValidationError, match="less than or equal to 1"):
ModelConfig(dropout_rate=1.5)
with pytest.raises(ValidationError, match="greater than or equal to 0"):
ModelConfig(dropout_rate=-0.1)
# Test hidden_ratio > 0
with pytest.raises(ValidationError, match="greater than 0"):
ModelConfig(hidden_ratio=0)
with pytest.raises(ValidationError, match="greater than 0"):
ModelConfig(hidden_ratio=-1)
def test_visualization_config_defaults(self):
"""Verify VisualizationConfig creates with correct defaults."""
config = VisualizationConfig()
assert config.plot_theme == "plotly_white"
assert config.color_scale == "viridis"
assert config.point_size == 8
assert config.fig_width == 900
assert config.fig_height == 600
def test_visualization_config_validation(self):
"""Test validation constraints for VisualizationConfig."""
# Test fig_width > 0
with pytest.raises(ValidationError, match="greater than 0"):
VisualizationConfig(fig_width=0)
with pytest.raises(ValidationError, match="greater than 0"):
VisualizationConfig(fig_width=-1)
# Test fig_height > 0
with pytest.raises(ValidationError, match="greater than 0"):
VisualizationConfig(fig_height=0)
with pytest.raises(ValidationError, match="greater than 0"):
VisualizationConfig(fig_height=-1)
# Test point_size > 0
with pytest.raises(ValidationError, match="greater than 0"):
VisualizationConfig(point_size=0)
with pytest.raises(ValidationError, match="greater than 0"):
VisualizationConfig(point_size=-1)
def test_output_config_defaults(self):
"""Verify OutputConfig creates with correct defaults."""
config = OutputConfig()
output_dir = Path(__file__).parent.parent.parent / "outputs"
assert config.directory == output_dir
assert config.html_self_contained is True
assert config.png_scale == 2
def test_output_config_validation(self):
"""Test validation constraints for OutputConfig."""
# Test png_scale > 0
with pytest.raises(ValidationError, match="greater than 0"):
OutputConfig(png_scale=0)
with pytest.raises(ValidationError, match="greater than 0"):
OutputConfig(png_scale=-1)
def test_pooling_type_enum(self):
"""Verify PoolingType enum values."""
assert PoolingType.ATTENTION.value == "attention"
assert PoolingType.ATTENTION == PoolingType("attention")
def test_feature_compressor_config(self):
"""Verify FeatureCompressorConfig nests all models correctly."""
model_cfg = ModelConfig(compression_dim=512)
viz_cfg = VisualizationConfig(point_size=16)
out_cfg = OutputConfig(directory="/tmp/outputs")
config = FeatureCompressorConfig(
model=model_cfg,
visualization=viz_cfg,
output=out_cfg,
)
assert config.model.compression_dim == 512
assert config.visualization.point_size == 16
assert config.output.directory == Path("/tmp/outputs")
class TestYamlLoader:
"""Test suite for YAML loading and saving."""
def test_load_existing_yaml(self):
"""Load feature_compressor.yaml and verify values."""
config_path = (
Path(__file__).parent.parent / "configs" / "feature_compressor.yaml"
)
config = load_yaml(config_path, FeatureCompressorConfig)
# Verify model config
assert config.model.name == "facebook/dinov2-large"
assert config.model.compression_dim == 256
assert config.model.pooling_type == PoolingType.ATTENTION
assert config.model.top_k_ratio == 0.5
assert config.model.hidden_ratio == 2.0
assert config.model.dropout_rate == 0.1
assert config.model.use_residual is True
# Verify visualization config
assert config.visualization.plot_theme == "plotly_white"
assert config.visualization.color_scale == "viridis"
assert config.visualization.point_size == 8
assert config.visualization.fig_width == 900
assert config.visualization.fig_height == 600
# Verify output config
output_dir = Path(__file__).parent.parent.parent / "outputs"
assert config.output.directory == output_dir
assert config.output.html_self_contained is True
assert config.output.png_scale == 2
def test_load_yaml_validation(self):
"""Test that invalid data raises ConfigError."""
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
# Write invalid config (missing required fields)
yaml.dump({"invalid": "data"}, f)
temp_path = f.name
try:
with pytest.raises(ConfigError, match="validation failed"):
load_yaml(Path(temp_path), FeatureCompressorConfig)
finally:
Path(temp_path).unlink()
def test_save_yaml_roundtrip(self):
"""Create config, save to temp, verify file exists with content."""
original = cfg_manager.load_config("feature_compressor")
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
temp_path = Path(f.name)
try:
save_yaml(temp_path, original)
# Verify file exists and has content
assert Path(temp_path).exists()
with open(temp_path, "r") as f:
content = f.read()
assert len(content) > 0
assert "model" in content
assert "visualization" in content
assert "output" in content
finally:
Path(temp_path).unlink()
def test_load_yaml_file_not_found(self):
"""Verify FileNotFoundError raises ConfigError."""
with pytest.raises(ConfigError, match="not found"):
load_yaml(Path("/nonexistent/path/config.yaml"), FeatureCompressorConfig)
class TestConfigManager:
"""Test suite for ConfigManager singleton with multi-config support."""
def test_singleton_pattern(self):
"""Verify ConfigManager() returns same instance."""
manager1 = ConfigManager()
manager2 = ConfigManager()
assert manager1 is manager2
def test_load_config(self):
"""Test loading feature_compressor config."""
config = cfg_manager.load_config("feature_compressor")
assert config is not None
assert config.model.compression_dim == 256
assert config.visualization.point_size == 8
def test_get_config_not_loaded(self):
"""Test that get_config() raises error for unloaded config."""
with pytest.raises(ValueError, match="not loaded"):
cfg_manager.get_config("nonexistent_config")
def test_list_configs(self):
"""Test listing all loaded configurations."""
cfg_manager.load_config("feature_compressor")
configs = cfg_manager.list_configs()
assert "feature_compressor" in configs
def test_save_config(self):
"""Test saving configuration to file."""
config = FeatureCompressorConfig(
model=ModelConfig(compression_dim=512),
visualization=VisualizationConfig(),
output=OutputConfig(),
)
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
temp_path = Path(f.name)
try:
cfg_manager.save_config("test_config", config, path=temp_path)
loaded_config = load_yaml(temp_path, FeatureCompressorConfig)
assert loaded_config.model.compression_dim == 512
finally:
if temp_path.exists():
temp_path.unlink()