feat(configs): implement Pydantic configuration system with type safety

This commit is contained in:
2026-01-31 12:19:11 +08:00
parent 1454647aa6
commit 9e9070bdb4
10 changed files with 628 additions and 78 deletions

1
.gitignore vendored
View File

@@ -209,6 +209,7 @@ __marimo__/
data/
deps/
outputs/
.sisyphus
# Devenv
.devenv*

0
mini-nav/__init__.py Normal file
View File

View File

@@ -0,0 +1,28 @@
from .models import (
ModelConfig,
VisualizationConfig,
OutputConfig,
FeatureCompressorConfig,
PoolingType,
)
from .loader import load_yaml, save_yaml, ConfigError
from .config import (
ConfigManager,
cfg_manager,
)
__all__ = [
# Models
"ModelConfig",
"VisualizationConfig",
"OutputConfig",
"FeatureCompressorConfig",
"PoolingType",
# Loader
"load_yaml",
"save_yaml",
"ConfigError",
# Manager
"ConfigManager",
"cfg_manager",
]

View File

@@ -1,20 +1,136 @@
from enum import Enum
"""Configuration manager for multiple configurations."""
from pathlib import Path
from typing import Dict
from typing import Dict, Optional
import yaml
from .loader import load_yaml, save_yaml
from .models import FeatureCompressorConfig
class Config(Enum):
FEATURE_COMPRESSOR = "feature_compressor.yaml"
class ConfigManager:
"""Singleton configuration manager supporting multiple configs."""
_instance: Optional["ConfigManager"] = None
_configs: Dict[str, FeatureCompressorConfig] = {}
def __new__(cls) -> "ConfigManager":
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
self.config_dir = Path(__file__).parent
def load_config(
self, config_name: str = "feature_compressor"
) -> FeatureCompressorConfig:
"""Load configuration from YAML file.
Args:
config_name: Name of config file without extension
Returns:
Loaded and validated FeatureCompressorConfig instance
"""
config_path = self.config_dir / f"{config_name}.yaml"
config = load_yaml(config_path, FeatureCompressorConfig)
self._configs[config_name] = config
return config
def load_all_configs(self) -> Dict[str, FeatureCompressorConfig]:
"""Load all YAML configuration files from config directory.
Returns:
Dictionary mapping config names to FeatureCompressorConfig instances
"""
config_files = list(self.config_dir.glob("*.yaml"))
loaded_configs = {}
for config_path in config_files:
config_name = config_path.stem
if config_name.startswith("_"):
continue # Skip private configs
config = load_yaml(config_path, FeatureCompressorConfig)
loaded_configs[config_name] = config
self._configs.update(loaded_configs)
return loaded_configs
def get_config(
self, config_name: str = "feature_compressor"
) -> FeatureCompressorConfig:
"""Get loaded configuration by name.
Args:
config_name: Name of configuration to retrieve
Returns:
FeatureCompressorConfig instance
Raises:
ValueError: If configuration not loaded
"""
if config_name not in self._configs:
raise ValueError(
f"Configuration '{config_name}' not loaded. "
f"Call load_config('{config_name}') or load_all_configs() first."
)
return self._configs[config_name]
def get_or_load_config(
self, config_name: str = "feature_compressor"
) -> FeatureCompressorConfig:
"""Get loaded configuration by name or load it if not loaded.
Args:
config_name: Name of configuration to retrieve
Returns:
FeatureCompressorConfig instance
Raises:
ValueError: If configuration not loaded
"""
if config_name not in self._configs:
return self.load_config(config_name)
return self._configs[config_name]
def list_configs(self) -> list[str]:
"""List names of all currently loaded configurations.
Returns:
List of configuration names
"""
return list(self._configs.keys())
def save_config(
self,
config_name: str = "feature_compressor",
config: Optional[FeatureCompressorConfig] = None,
path: Optional[Path] = None,
) -> None:
"""Save configuration to YAML file.
Args:
config_name: Name of config file without extension
config: Configuration to save. If None, saves currently loaded config for that name.
path: Optional custom path to save to. If None, saves to config_dir.
Raises:
ValueError: If no configuration loaded for the given name and config is None
"""
if config is None:
if config_name not in self._configs:
raise ValueError(
f"No configuration loaded for '{config_name}'. "
f"Cannot save without providing config parameter."
)
config = self._configs[config_name]
save_path = path if path else self.config_dir / f"{config_name}.yaml"
save_yaml(save_path, config)
self._configs[config_name] = config
def get_config_dir() -> Path:
return Path(__file__).parent
def get_default_config(config_type: Config) -> Dict[Unknown, Unknown]:
config_path = get_config_dir() / config_type.value
with open(config_path) as f:
return yaml.safe_load(f)
# Global singleton instance
cfg_manager = ConfigManager()

View File

@@ -0,0 +1,69 @@
"""Generic YAML loader with Pydantic validation."""
from pathlib import Path
from typing import TypeVar, Type
import yaml
from pydantic import BaseModel, ValidationError
T = TypeVar("T", bound=BaseModel)
class ConfigError(Exception):
"""Configuration loading and validation error."""
pass
def load_yaml(path: Path, model_class: Type[T]) -> T:
"""Load and validate YAML configuration file.
Args:
path: Path to YAML file (str or Path object)
model_class: Pydantic model class to validate against
Returns:
Validated model instance of type T
Raises:
ConfigError: On file not found, YAML parsing error, or validation failure
"""
# Coerce str to Path if needed
if isinstance(path, str):
path = Path(path)
try:
with open(path, "r", encoding="utf-8") as f:
data = yaml.safe_load(f)
except FileNotFoundError as e:
raise ConfigError(f"Configuration file not found: {path}") from e
except yaml.YAMLError as e:
raise ConfigError(f"YAML parsing error: {e}") from e
try:
return model_class.model_validate(data)
except ValidationError as e:
raise ConfigError(f"Configuration validation failed: {e}") from e
def save_yaml(path: Path, model: BaseModel) -> None:
"""Save Pydantic model to YAML file.
Args:
path: Path to YAML file (str or Path object)
model: Pydantic model instance to save
Raises:
ConfigError: On file write error
"""
# Coerce str to Path if needed
if isinstance(path, str):
path = Path(path)
try:
data = model.model_dump(exclude_unset=True)
with open(path, "w", encoding="utf-8") as f:
yaml.dump(data, f, default_flow_style=False, allow_unicode=True)
except Exception as e:
raise ConfigError(f"Failed to save configuration to {path}: {e}") from e

View File

@@ -0,0 +1,77 @@
"""Pydantic data models for feature compressor configuration."""
from enum import Enum
from pathlib import Path
from pydantic import BaseModel, ConfigDict, Field, field_validator
class PoolingType(str, Enum):
"""Enum for pooling types."""
ATTENTION = "attention"
class ModelConfig(BaseModel):
"""Configuration for the vision model and compression."""
model_config = ConfigDict(extra="ignore")
name: str = "facebook/dinov2-large"
compression_dim: int = Field(
default=256, gt=0, description="Output feature dimension"
)
pooling_type: PoolingType = PoolingType.ATTENTION
top_k_ratio: float = Field(
default=0.5, ge=0, le=1, description="Ratio of tokens to keep"
)
hidden_ratio: float = Field(
default=2.0, gt=0, description="MLP hidden dim as multiple of compression_dim"
)
dropout_rate: float = Field(
default=0.1, ge=0, le=1, description="Dropout probability"
)
use_residual: bool = True
device: str = "auto"
class VisualizationConfig(BaseModel):
"""Configuration for visualization settings."""
model_config = ConfigDict(extra="ignore")
plot_theme: str = "plotly_white"
color_scale: str = "viridis"
point_size: int = Field(default=8, gt=0)
fig_width: int = Field(default=900, gt=0)
fig_height: int = Field(default=600, gt=0)
class OutputConfig(BaseModel):
"""Configuration for output settings."""
model_config = ConfigDict(extra="ignore")
directory: Path = Path(__file__).parent.parent.parent / "outputs"
html_self_contained: bool = True
png_scale: int = Field(default=2, gt=0)
@field_validator("directory", mode="after")
def convert_to_absolute(cls, v: Path) -> Path:
"""
Converts the path to an absolute path relative to the current working directory.
This works even if the path doesn't exist on disk.
"""
if v.is_absolute():
return v
return Path(__file__).parent.parent.parent / v
class FeatureCompressorConfig(BaseModel):
"""Root configuration for the feature compressor."""
model_config = ConfigDict(extra="ignore")
model: ModelConfig
visualization: VisualizationConfig
output: OutputConfig

View File

@@ -2,13 +2,12 @@
import time
from pathlib import Path
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union
import torch
import yaml
from configs import FeatureCompressorConfig, cfg_manager, load_yaml
from transformers import AutoImageProcessor, AutoModel
from ...configs.config import Config, get_default_config
from ..utils.image_utils import load_image, preprocess_image
from .compressor import PoolNetCompressor
@@ -25,47 +24,47 @@ class DINOv2FeatureExtractor:
"""
def __init__(self, config_path: Optional[str] = None, device: str = "auto"):
self.config = self._load_config(config_path)
self.config: FeatureCompressorConfig = self._load_config(config_path)
# Set device
if device == "auto":
device = self.config.get("model", {}).get("device", "auto")
device = self.config.model.device
if device == "auto":
device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = torch.device(device)
# Load DINOv2 model and processor
model_name = self.config.get("model", {}).get("name", "facebook/dinov2-large")
model_name = self.config.model.name
self.processor = AutoImageProcessor.from_pretrained(model_name)
self.model = AutoModel.from_pretrained(model_name).to(self.device)
self.model.eval()
# Initialize compressor
model_config = self.config.get("model", {})
self.compressor = PoolNetCompressor(
input_dim=self.model.config.hidden_size,
compression_dim=model_config.get("compression_dim", 256),
top_k_ratio=model_config.get("top_k_ratio", 0.5),
hidden_ratio=model_config.get("hidden_ratio", 2.0),
dropout_rate=model_config.get("dropout_rate", 0.1),
use_residual=model_config.get("use_residual", True),
compression_dim=self.config.model.compression_dim,
top_k_ratio=self.config.model.top_k_ratio,
hidden_ratio=self.config.model.hidden_ratio,
dropout_rate=self.config.model.dropout_rate,
use_residual=self.config.model.use_residual,
device=str(self.device),
)
def _load_config(self, config_path: Optional[str] = None) -> Dict:
def _load_config(
self, config_path: Optional[str] = None
) -> FeatureCompressorConfig:
"""Load configuration from YAML file.
Args:
config_path: Path to config file, or None for default
Returns:
Configuration dictionary
FeatureCompressorConfig instance
"""
if config_path is None:
return get_default_config(Config.FEATURE_COMPRESSOR)
with open(config_path) as f:
return yaml.safe_load(f)
return cfg_manager.get_or_load_config("feature_compressor")
else:
return load_yaml(Path(config_path), FeatureCompressorConfig)
def _extract_dinov2_features(self, images: List) -> torch.Tensor:
"""Extract DINOv2 last_hidden_state features.
@@ -149,14 +148,17 @@ class DINOv2FeatureExtractor:
"processing_time": processing_time,
"feature_norm": feature_norm,
"device": str(self.device),
"model_name": self.config.get("model", {}).get("name"),
"model_name": self.config.model.name,
},
}
return result
def process_batch(
self, image_dir: str, batch_size: int = 8, save_features: bool = True
self,
image_dir: Union[str, Path],
batch_size: int = 8,
save_features: bool = True,
) -> List[Dict[str, object]]:
"""Process multiple images in batches.
@@ -208,7 +210,7 @@ class DINOv2FeatureExtractor:
.mean()
.item(),
"device": str(self.device),
"model_name": self.config.get("model", {}).get("name"),
"model_name": self.config.model.name,
},
}
@@ -216,12 +218,7 @@ class DINOv2FeatureExtractor:
# Save features if requested
if save_features:
output_dir = Path(
self.config.get("output", {}).get("directory", "./outputs")
)
# Resolve relative to project root
if not output_dir.is_absolute():
output_dir = Path(__file__).parent.parent.parent / output_dir
output_dir = Path(self.config.output.directory)
output_dir.mkdir(parents=True, exist_ok=True)
output_path = output_dir / f"{file_path.stem}_features.json"

View File

@@ -2,11 +2,11 @@
import os
from pathlib import Path
from typing import List, Optional
from typing import List, Optional, Union
import numpy as np
import torch
import yaml
from configs import FeatureCompressorConfig, cfg_manager, load_yaml
from plotly.graph_objs import Figure
from ..utils.plot_utils import (
@@ -29,28 +29,27 @@ class FeatureVisualizer:
"""
def __init__(self, config_path: Optional[str] = None):
self.config = self._load_config(config_path)
self.config: FeatureCompressorConfig = self._load_config(config_path)
def _load_config(self, config_path: Optional[str] = None) -> dict:
def _load_config(
self, config_path: Optional[str] = None
) -> FeatureCompressorConfig:
"""Load configuration from YAML file.
Args:
config_path: Path to config file, or None for default
Returns:
Configuration dictionary
Configuration Pydantic model
"""
if config_path is None:
config_path = (
Path(__file__).parent.parent.parent
/ "configs"
/ "feature_compressor.yaml"
)
return cfg_manager.get_or_load_config("feature_compressor")
else:
return load_yaml(Path(config_path), FeatureCompressorConfig)
with open(config_path) as f:
return yaml.safe_load(f)
def plot_histogram(self, features: torch.Tensor, title: str = None) -> object:
def plot_histogram(
self, features: torch.Tensor, title: Optional[str] = None
) -> Figure:
"""Plot histogram of feature values.
Args:
@@ -61,18 +60,21 @@ class FeatureVisualizer:
Plotly Figure object
"""
features_np = features.cpu().numpy()
fig = create_histogram(features_np, title=title)
fig = create_histogram(
features_np, title="Feature Histogram" if title is None else title
)
viz_config = self.config.get("visualization", {})
fig = apply_theme(fig, viz_config.get("plot_theme", "plotly_white"))
fig = apply_theme(fig, self.config.visualization.plot_theme)
fig.update_layout(
width=viz_config.get("fig_width", 900),
height=viz_config.get("fig_height", 600),
width=self.config.visualization.fig_width,
height=self.config.visualization.fig_height,
)
return fig
def plot_pca_2d(self, features: torch.Tensor, labels: List = None) -> Figure:
def plot_pca_2d(
self, features: torch.Tensor, labels: Optional[List] = None
) -> Figure:
"""Plot 2D PCA projection of features.
Args:
@@ -83,19 +85,21 @@ class FeatureVisualizer:
Plotly Figure object
"""
features_np = features.cpu().numpy()
viz_config = self.config.get("visualization", {})
fig = create_pca_scatter_2d(features_np, labels=labels)
fig = apply_theme(fig, viz_config.get("plot_theme", "plotly_white"))
fig = create_pca_scatter_2d(
features_np,
labels=[i for i in range(len(features_np))] if labels is None else labels,
)
fig = apply_theme(fig, self.config.visualization.plot_theme)
fig.update_traces(
marker=dict(
size=viz_config.get("point_size", 8),
colorscale=viz_config.get("color_scale", "viridis"),
size=self.config.visualization.point_size,
colorscale=self.config.visualization.color_scale,
)
)
fig.update_layout(
width=viz_config.get("fig_width", 900),
height=viz_config.get("fig_height", 600),
width=self.config.visualization.fig_width,
height=self.config.visualization.fig_height,
)
return fig
@@ -116,16 +120,17 @@ class FeatureVisualizer:
fig = create_comparison_plot(features_np_list, names)
viz_config = self.config.get("visualization", {})
fig = apply_theme(fig, viz_config.get("plot_theme", "plotly_white"))
fig = apply_theme(fig, self.config.visualization.plot_theme)
fig.update_layout(
width=viz_config.get("fig_width", 900) * len(features_list),
height=viz_config.get("fig_height", 600),
width=self.config.visualization.fig_width * len(features_list),
height=self.config.visualization.fig_height,
)
return fig
def generate_report(self, results: List[dict], output_dir: str) -> List[str]:
def generate_report(
self, results: List[dict], output_dir: Union[str, Path]
) -> List[str]:
"""Generate full feature analysis report.
Args:
@@ -158,7 +163,7 @@ class FeatureVisualizer:
return generated_files
def save(self, fig: object, path: str, formats: List[str] = None) -> None:
def save(self, fig: Figure, path: str, formats: List[str]) -> None:
"""Save figure in multiple formats.
Args:
@@ -169,8 +174,6 @@ class FeatureVisualizer:
if formats is None:
formats = ["html"]
output_config = self.config.get("output", {})
for fmt in formats:
if fmt == "png":
save_figure(fig, path, format="png")

View File

@@ -1,7 +1,7 @@
"""Image loading and preprocessing utilities."""
from pathlib import Path
from typing import List, Union
from typing import List, Optional, Union
import requests
from PIL import Image
@@ -52,7 +52,7 @@ def preprocess_image(image: Image.Image, size: int = 224) -> Image.Image:
def load_images_from_directory(
dir_path: Union[str, Path], extensions: List[str] = None
dir_path: Union[str, Path], extensions: Optional[List[str]] = None
) -> List[Image.Image]:
"""Load all images from a directory.

View File

@@ -0,0 +1,259 @@
"""Tests for configuration system using Pydantic models."""
import tempfile
from pathlib import Path
import pytest
import yaml
from configs import (
ConfigError,
ConfigManager,
FeatureCompressorConfig,
ModelConfig,
OutputConfig,
PoolingType,
VisualizationConfig,
cfg_manager,
load_yaml,
save_yaml,
)
from pydantic import ValidationError
class TestConfigModels:
"""Test suite for Pydantic configuration models."""
def test_model_config_defaults(self):
"""Verify ModelConfig creates with correct defaults."""
config = ModelConfig()
assert config.name == "facebook/dinov2-large"
assert config.compression_dim == 256
assert config.pooling_type == PoolingType.ATTENTION
assert config.top_k_ratio == 0.5
assert config.hidden_ratio == 2.0
assert config.dropout_rate == 0.1
assert config.use_residual is True
assert config.device == "auto"
def test_model_config_validation(self):
"""Test validation constraints for ModelConfig."""
# Test compression_dim > 0
with pytest.raises(ValidationError, match="greater than 0"):
ModelConfig(compression_dim=0)
with pytest.raises(ValidationError, match="greater than 0"):
ModelConfig(compression_dim=-1)
# Test top_k_ratio in [0, 1]
with pytest.raises(ValidationError, match="less than or equal to 1"):
ModelConfig(top_k_ratio=1.5)
with pytest.raises(ValidationError, match="greater than or equal to 0"):
ModelConfig(top_k_ratio=-0.1)
# Test dropout_rate in [0, 1]
with pytest.raises(ValidationError, match="less than or equal to 1"):
ModelConfig(dropout_rate=1.5)
with pytest.raises(ValidationError, match="greater than or equal to 0"):
ModelConfig(dropout_rate=-0.1)
# Test hidden_ratio > 0
with pytest.raises(ValidationError, match="greater than 0"):
ModelConfig(hidden_ratio=0)
with pytest.raises(ValidationError, match="greater than 0"):
ModelConfig(hidden_ratio=-1)
def test_visualization_config_defaults(self):
"""Verify VisualizationConfig creates with correct defaults."""
config = VisualizationConfig()
assert config.plot_theme == "plotly_white"
assert config.color_scale == "viridis"
assert config.point_size == 8
assert config.fig_width == 900
assert config.fig_height == 600
def test_visualization_config_validation(self):
"""Test validation constraints for VisualizationConfig."""
# Test fig_width > 0
with pytest.raises(ValidationError, match="greater than 0"):
VisualizationConfig(fig_width=0)
with pytest.raises(ValidationError, match="greater than 0"):
VisualizationConfig(fig_width=-1)
# Test fig_height > 0
with pytest.raises(ValidationError, match="greater than 0"):
VisualizationConfig(fig_height=0)
with pytest.raises(ValidationError, match="greater than 0"):
VisualizationConfig(fig_height=-1)
# Test point_size > 0
with pytest.raises(ValidationError, match="greater than 0"):
VisualizationConfig(point_size=0)
with pytest.raises(ValidationError, match="greater than 0"):
VisualizationConfig(point_size=-1)
def test_output_config_defaults(self):
"""Verify OutputConfig creates with correct defaults."""
config = OutputConfig()
output_dir = Path(__file__).parent.parent.parent / "outputs"
assert config.directory == output_dir
assert config.html_self_contained is True
assert config.png_scale == 2
def test_output_config_validation(self):
"""Test validation constraints for OutputConfig."""
# Test png_scale > 0
with pytest.raises(ValidationError, match="greater than 0"):
OutputConfig(png_scale=0)
with pytest.raises(ValidationError, match="greater than 0"):
OutputConfig(png_scale=-1)
def test_pooling_type_enum(self):
"""Verify PoolingType enum values."""
assert PoolingType.ATTENTION.value == "attention"
assert PoolingType.ATTENTION == PoolingType("attention")
def test_feature_compressor_config(self):
"""Verify FeatureCompressorConfig nests all models correctly."""
model_cfg = ModelConfig(compression_dim=512)
viz_cfg = VisualizationConfig(point_size=16)
out_cfg = OutputConfig(directory="/tmp/outputs")
config = FeatureCompressorConfig(
model=model_cfg,
visualization=viz_cfg,
output=out_cfg,
)
assert config.model.compression_dim == 512
assert config.visualization.point_size == 16
assert config.output.directory == Path("/tmp/outputs")
class TestYamlLoader:
"""Test suite for YAML loading and saving."""
def test_load_existing_yaml(self):
"""Load feature_compressor.yaml and verify values."""
config_path = (
Path(__file__).parent.parent / "configs" / "feature_compressor.yaml"
)
config = load_yaml(config_path, FeatureCompressorConfig)
# Verify model config
assert config.model.name == "facebook/dinov2-large"
assert config.model.compression_dim == 256
assert config.model.pooling_type == PoolingType.ATTENTION
assert config.model.top_k_ratio == 0.5
assert config.model.hidden_ratio == 2.0
assert config.model.dropout_rate == 0.1
assert config.model.use_residual is True
# Verify visualization config
assert config.visualization.plot_theme == "plotly_white"
assert config.visualization.color_scale == "viridis"
assert config.visualization.point_size == 8
assert config.visualization.fig_width == 900
assert config.visualization.fig_height == 600
# Verify output config
output_dir = Path(__file__).parent.parent.parent / "outputs"
assert config.output.directory == output_dir
assert config.output.html_self_contained is True
assert config.output.png_scale == 2
def test_load_yaml_validation(self):
"""Test that invalid data raises ConfigError."""
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
# Write invalid config (missing required fields)
yaml.dump({"invalid": "data"}, f)
temp_path = f.name
try:
with pytest.raises(ConfigError, match="validation failed"):
load_yaml(Path(temp_path), FeatureCompressorConfig)
finally:
Path(temp_path).unlink()
def test_save_yaml_roundtrip(self):
"""Create config, save to temp, verify file exists with content."""
original = cfg_manager.load_config("feature_compressor")
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
temp_path = Path(f.name)
try:
save_yaml(temp_path, original)
# Verify file exists and has content
assert Path(temp_path).exists()
with open(temp_path, "r") as f:
content = f.read()
assert len(content) > 0
assert "model" in content
assert "visualization" in content
assert "output" in content
finally:
Path(temp_path).unlink()
def test_load_yaml_file_not_found(self):
"""Verify FileNotFoundError raises ConfigError."""
with pytest.raises(ConfigError, match="not found"):
load_yaml(Path("/nonexistent/path/config.yaml"), FeatureCompressorConfig)
class TestConfigManager:
"""Test suite for ConfigManager singleton with multi-config support."""
def test_singleton_pattern(self):
"""Verify ConfigManager() returns same instance."""
manager1 = ConfigManager()
manager2 = ConfigManager()
assert manager1 is manager2
def test_load_config(self):
"""Test loading feature_compressor config."""
config = cfg_manager.load_config("feature_compressor")
assert config is not None
assert config.model.compression_dim == 256
assert config.visualization.point_size == 8
def test_get_config_not_loaded(self):
"""Test that get_config() raises error for unloaded config."""
with pytest.raises(ValueError, match="not loaded"):
cfg_manager.get_config("nonexistent_config")
def test_list_configs(self):
"""Test listing all loaded configurations."""
cfg_manager.load_config("feature_compressor")
configs = cfg_manager.list_configs()
assert "feature_compressor" in configs
def test_save_config(self):
"""Test saving configuration to file."""
config = FeatureCompressorConfig(
model=ModelConfig(compression_dim=512),
visualization=VisualizationConfig(),
output=OutputConfig(),
)
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
temp_path = Path(f.name)
try:
cfg_manager.save_config("test_config", config, path=temp_path)
loaded_config = load_yaml(temp_path, FeatureCompressorConfig)
assert loaded_config.model.compression_dim == 512
finally:
if temp_path.exists():
temp_path.unlink()