feat(configs): implement Pydantic configuration system with type safety

This commit is contained in:
2026-01-31 12:19:11 +08:00
parent 1454647aa6
commit 9e9070bdb4
10 changed files with 628 additions and 78 deletions

View File

@@ -2,11 +2,11 @@
import os
from pathlib import Path
from typing import List, Optional
from typing import List, Optional, Union
import numpy as np
import torch
import yaml
from configs import FeatureCompressorConfig, cfg_manager, load_yaml
from plotly.graph_objs import Figure
from ..utils.plot_utils import (
@@ -29,28 +29,27 @@ class FeatureVisualizer:
"""
def __init__(self, config_path: Optional[str] = None):
    """Create a visualizer and load its configuration once.

    Args:
        config_path: Path to a config file, or None for the default
            configuration (resolved by ``_load_config``).
    """
    # Single typed assignment: loading the config twice (untyped, then
    # typed) would repeat the work for no benefit.
    self.config: FeatureCompressorConfig = self._load_config(config_path)
def _load_config(
    self, config_path: Optional[str] = None
) -> FeatureCompressorConfig:
    """Load the feature-compressor configuration.

    Args:
        config_path: Path to a YAML config file, or None to use the
            "feature_compressor" entry obtained from ``cfg_manager``.

    Returns:
        Configuration Pydantic model
    """
    if config_path is None:
        # The default configuration comes from the shared config manager,
        # so no default file path needs to be built here.
        return cfg_manager.get_or_load_config("feature_compressor")

    # NOTE(review): load_yaml presumably parses the file and validates it
    # against FeatureCompressorConfig -- confirm in the configs package.
    return load_yaml(Path(config_path), FeatureCompressorConfig)
def plot_histogram(
    self, features: torch.Tensor, title: Optional[str] = None
) -> Figure:
    """Plot histogram of feature values.

    Args:
        features: Feature tensor; it is detached, moved to CPU and
            converted to a NumPy array before plotting.
        title: Plot title; "Feature Histogram" is used when None.

    Returns:
        Plotly Figure object
    """
    # detach() first: Tensor.numpy() refuses tensors that require grad.
    features_np = features.detach().cpu().numpy()
    fig = create_histogram(
        features_np, title="Feature Histogram" if title is None else title
    )

    # The config is a Pydantic model, so settings are read as attributes
    # rather than through dict-style .get() lookups.
    viz = self.config.visualization
    fig = apply_theme(fig, viz.plot_theme)
    fig.update_layout(width=viz.fig_width, height=viz.fig_height)
    return fig
def plot_pca_2d(
    self, features: torch.Tensor, labels: Optional[List] = None
) -> Figure:
    """Plot 2D PCA projection of features.

    Args:
        features: Feature tensor; it is detached, moved to CPU and
            converted to a NumPy array before plotting.
        labels: One label per row of ``features``; when None, the row
            indices 0..len(features)-1 are used instead.

    Returns:
        Plotly Figure object
    """
    # detach() first: Tensor.numpy() refuses tensors that require grad.
    features_np = features.detach().cpu().numpy()
    if labels is None:
        labels = list(range(len(features_np)))
    fig = create_pca_scatter_2d(features_np, labels=labels)

    # The config is a Pydantic model, so settings are read as attributes
    # rather than through dict-style .get() lookups.
    viz = self.config.visualization
    fig = apply_theme(fig, viz.plot_theme)
    fig.update_traces(
        marker=dict(
            size=viz.point_size,
            colorscale=viz.color_scale,
        )
    )
    fig.update_layout(width=viz.fig_width, height=viz.fig_height)
    return fig
@@ -116,16 +120,17 @@ class FeatureVisualizer:
fig = create_comparison_plot(features_np_list, names)
viz_config = self.config.get("visualization", {})
fig = apply_theme(fig, viz_config.get("plot_theme", "plotly_white"))
fig = apply_theme(fig, self.config.visualization.plot_theme)
fig.update_layout(
width=viz_config.get("fig_width", 900) * len(features_list),
height=viz_config.get("fig_height", 600),
width=self.config.visualization.fig_width * len(features_list),
height=self.config.visualization.fig_height,
)
return fig
def generate_report(self, results: List[dict], output_dir: str) -> List[str]:
def generate_report(
self, results: List[dict], output_dir: Union[str, Path]
) -> List[str]:
"""Generate full feature analysis report.
Args:
@@ -158,7 +163,7 @@ class FeatureVisualizer:
return generated_files
def save(self, fig: object, path: str, formats: List[str] = None) -> None:
def save(self, fig: Figure, path: str, formats: List[str]) -> None:
"""Save figure in multiple formats.
Args:
@@ -169,8 +174,6 @@ class FeatureVisualizer:
if formats is None:
formats = ["html"]
output_config = self.config.get("output", {})
for fmt in formats:
if fmt == "png":
save_figure(fig, path, format="png")