mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
refactor(configs): remove unused settings
This commit is contained in:
@@ -1,22 +1,20 @@
|
|||||||
from .models import (
|
|
||||||
ModelConfig,
|
|
||||||
VisualizationConfig,
|
|
||||||
OutputConfig,
|
|
||||||
FeatureCompressorConfig,
|
|
||||||
PoolingType,
|
|
||||||
)
|
|
||||||
from .loader import load_yaml, save_yaml, ConfigError
|
|
||||||
from .config import (
|
from .config import (
|
||||||
ConfigManager,
|
ConfigManager,
|
||||||
cfg_manager,
|
cfg_manager,
|
||||||
)
|
)
|
||||||
|
from .loader import ConfigError, load_yaml, save_yaml
|
||||||
|
from .models import (
|
||||||
|
Config,
|
||||||
|
ModelConfig,
|
||||||
|
OutputConfig,
|
||||||
|
PoolingType,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Models
|
# Models
|
||||||
"ModelConfig",
|
"ModelConfig",
|
||||||
"VisualizationConfig",
|
|
||||||
"OutputConfig",
|
"OutputConfig",
|
||||||
"FeatureCompressorConfig",
|
"Config",
|
||||||
"PoolingType",
|
"PoolingType",
|
||||||
# Loader
|
# Loader
|
||||||
"load_yaml",
|
"load_yaml",
|
||||||
|
|||||||
@@ -4,14 +4,14 @@ from pathlib import Path
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from .loader import load_yaml, save_yaml
|
from .loader import load_yaml, save_yaml
|
||||||
from .models import FeatureCompressorConfig
|
from .models import Config
|
||||||
|
|
||||||
|
|
||||||
class ConfigManager:
|
class ConfigManager:
|
||||||
"""Singleton configuration manager for unified config."""
|
"""Singleton configuration manager for unified config."""
|
||||||
|
|
||||||
_instance: Optional["ConfigManager"] = None
|
_instance: Optional["ConfigManager"] = None
|
||||||
_config: Optional[FeatureCompressorConfig] = None
|
_config: Optional[Config] = None
|
||||||
|
|
||||||
def __new__(cls) -> "ConfigManager":
|
def __new__(cls) -> "ConfigManager":
|
||||||
"""Singleton pattern - ensure only one instance exists."""
|
"""Singleton pattern - ensure only one instance exists."""
|
||||||
@@ -24,17 +24,17 @@ class ConfigManager:
|
|||||||
self.config_dir = Path(__file__).parent
|
self.config_dir = Path(__file__).parent
|
||||||
self.config_path = self.config_dir / "config.yaml"
|
self.config_path = self.config_dir / "config.yaml"
|
||||||
|
|
||||||
def load(self) -> FeatureCompressorConfig:
|
def load(self) -> Config:
|
||||||
"""Load configuration from config.yaml file.
|
"""Load configuration from config.yaml file.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Loaded and validated FeatureCompressorConfig instance
|
Loaded and validated FeatureCompressorConfig instance
|
||||||
"""
|
"""
|
||||||
config = load_yaml(self.config_path, FeatureCompressorConfig)
|
config = load_yaml(self.config_path, Config)
|
||||||
self._config = config
|
self._config = config
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def get(self) -> FeatureCompressorConfig:
|
def get(self) -> Config:
|
||||||
"""Get loaded configuration, auto-loading if not already loaded.
|
"""Get loaded configuration, auto-loading if not already loaded.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -50,7 +50,7 @@ class ConfigManager:
|
|||||||
|
|
||||||
def save(
|
def save(
|
||||||
self,
|
self,
|
||||||
config: Optional[FeatureCompressorConfig] = None,
|
config: Optional[Config] = None,
|
||||||
path: Optional[Path] = None,
|
path: Optional[Path] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Save configuration to YAML file.
|
"""Save configuration to YAML file.
|
||||||
|
|||||||
@@ -1,21 +1,7 @@
|
|||||||
model:
|
model:
|
||||||
name: "facebook/dinov2-large"
|
name: "facebook/dinov2-large"
|
||||||
compression_dim: 256
|
compression_dim: 512
|
||||||
pooling_type: "attention" # attention-based Top-K
|
device: "auto" # auto-detect GPU
|
||||||
top_k_ratio: 0.5 # Keep 50% of tokens
|
|
||||||
hidden_ratio: 2.0 # MLP hidden = compression_dim * 2
|
|
||||||
dropout_rate: 0.1
|
|
||||||
use_residual: true
|
|
||||||
device: "auto" # auto-detect GPU
|
|
||||||
|
|
||||||
visualization:
|
|
||||||
plot_theme: "plotly_white"
|
|
||||||
color_scale: "viridis"
|
|
||||||
point_size: 8
|
|
||||||
fig_width: 900
|
|
||||||
fig_height: 600
|
|
||||||
|
|
||||||
output:
|
output:
|
||||||
directory: "./outputs"
|
directory: "./outputs"
|
||||||
html_self_contained: true
|
|
||||||
png_scale: 2 # 2x resolution for PNG
|
|
||||||
|
|||||||
@@ -19,42 +19,17 @@ class ModelConfig(BaseModel):
|
|||||||
|
|
||||||
name: str = "facebook/dinov2-large"
|
name: str = "facebook/dinov2-large"
|
||||||
compression_dim: int = Field(
|
compression_dim: int = Field(
|
||||||
default=256, gt=0, description="Output feature dimension"
|
default=512, gt=0, description="Output feature dimension"
|
||||||
)
|
)
|
||||||
pooling_type: PoolingType = PoolingType.ATTENTION
|
|
||||||
top_k_ratio: float = Field(
|
|
||||||
default=0.5, ge=0, le=1, description="Ratio of tokens to keep"
|
|
||||||
)
|
|
||||||
hidden_ratio: float = Field(
|
|
||||||
default=2.0, gt=0, description="MLP hidden dim as multiple of compression_dim"
|
|
||||||
)
|
|
||||||
dropout_rate: float = Field(
|
|
||||||
default=0.1, ge=0, le=1, description="Dropout probability"
|
|
||||||
)
|
|
||||||
use_residual: bool = True
|
|
||||||
device: str = "auto"
|
device: str = "auto"
|
||||||
|
|
||||||
|
|
||||||
class VisualizationConfig(BaseModel):
|
|
||||||
"""Configuration for visualization settings."""
|
|
||||||
|
|
||||||
model_config = ConfigDict(extra="ignore")
|
|
||||||
|
|
||||||
plot_theme: str = "plotly_white"
|
|
||||||
color_scale: str = "viridis"
|
|
||||||
point_size: int = Field(default=8, gt=0)
|
|
||||||
fig_width: int = Field(default=900, gt=0)
|
|
||||||
fig_height: int = Field(default=600, gt=0)
|
|
||||||
|
|
||||||
|
|
||||||
class OutputConfig(BaseModel):
|
class OutputConfig(BaseModel):
|
||||||
"""Configuration for output settings."""
|
"""Configuration for output settings."""
|
||||||
|
|
||||||
model_config = ConfigDict(extra="ignore")
|
model_config = ConfigDict(extra="ignore")
|
||||||
|
|
||||||
directory: Path = Path(__file__).parent.parent.parent / "outputs"
|
directory: Path = Path(__file__).parent.parent.parent / "outputs"
|
||||||
html_self_contained: bool = True
|
|
||||||
png_scale: int = Field(default=2, gt=0)
|
|
||||||
|
|
||||||
@field_validator("directory", mode="after")
|
@field_validator("directory", mode="after")
|
||||||
def convert_to_absolute(cls, v: Path) -> Path:
|
def convert_to_absolute(cls, v: Path) -> Path:
|
||||||
@@ -67,11 +42,10 @@ class OutputConfig(BaseModel):
|
|||||||
return Path(__file__).parent.parent.parent / v
|
return Path(__file__).parent.parent.parent / v
|
||||||
|
|
||||||
|
|
||||||
class FeatureCompressorConfig(BaseModel):
|
class Config(BaseModel):
|
||||||
"""Root configuration for the feature compressor."""
|
"""Root configuration for the feature compressor."""
|
||||||
|
|
||||||
model_config = ConfigDict(extra="ignore")
|
model_config = ConfigDict(extra="ignore")
|
||||||
|
|
||||||
model: ModelConfig
|
model: ModelConfig
|
||||||
visualization: VisualizationConfig
|
|
||||||
output: OutputConfig
|
output: OutputConfig
|
||||||
|
|||||||
@@ -6,13 +6,12 @@ from pathlib import Path
|
|||||||
import pytest
|
import pytest
|
||||||
import yaml
|
import yaml
|
||||||
from configs import (
|
from configs import (
|
||||||
|
Config,
|
||||||
ConfigError,
|
ConfigError,
|
||||||
ConfigManager,
|
ConfigManager,
|
||||||
FeatureCompressorConfig,
|
|
||||||
ModelConfig,
|
ModelConfig,
|
||||||
OutputConfig,
|
OutputConfig,
|
||||||
PoolingType,
|
PoolingType,
|
||||||
VisualizationConfig,
|
|
||||||
cfg_manager,
|
cfg_manager,
|
||||||
load_yaml,
|
load_yaml,
|
||||||
save_yaml,
|
save_yaml,
|
||||||
@@ -27,12 +26,7 @@ class TestConfigModels:
|
|||||||
"""Verify ModelConfig creates with correct defaults."""
|
"""Verify ModelConfig creates with correct defaults."""
|
||||||
config = ModelConfig()
|
config = ModelConfig()
|
||||||
assert config.name == "facebook/dinov2-large"
|
assert config.name == "facebook/dinov2-large"
|
||||||
assert config.compression_dim == 256
|
assert config.compression_dim == 512
|
||||||
assert config.pooling_type == PoolingType.ATTENTION
|
|
||||||
assert config.top_k_ratio == 0.5
|
|
||||||
assert config.hidden_ratio == 2.0
|
|
||||||
assert config.dropout_rate == 0.1
|
|
||||||
assert config.use_residual is True
|
|
||||||
assert config.device == "auto"
|
assert config.device == "auto"
|
||||||
|
|
||||||
def test_model_config_validation(self):
|
def test_model_config_validation(self):
|
||||||
@@ -44,76 +38,12 @@ class TestConfigModels:
|
|||||||
with pytest.raises(ValidationError, match="greater than 0"):
|
with pytest.raises(ValidationError, match="greater than 0"):
|
||||||
ModelConfig(compression_dim=-1)
|
ModelConfig(compression_dim=-1)
|
||||||
|
|
||||||
# Test top_k_ratio in [0, 1]
|
|
||||||
with pytest.raises(ValidationError, match="less than or equal to 1"):
|
|
||||||
ModelConfig(top_k_ratio=1.5)
|
|
||||||
|
|
||||||
with pytest.raises(ValidationError, match="greater than or equal to 0"):
|
|
||||||
ModelConfig(top_k_ratio=-0.1)
|
|
||||||
|
|
||||||
# Test dropout_rate in [0, 1]
|
|
||||||
with pytest.raises(ValidationError, match="less than or equal to 1"):
|
|
||||||
ModelConfig(dropout_rate=1.5)
|
|
||||||
|
|
||||||
with pytest.raises(ValidationError, match="greater than or equal to 0"):
|
|
||||||
ModelConfig(dropout_rate=-0.1)
|
|
||||||
|
|
||||||
# Test hidden_ratio > 0
|
|
||||||
with pytest.raises(ValidationError, match="greater than 0"):
|
|
||||||
ModelConfig(hidden_ratio=0)
|
|
||||||
|
|
||||||
with pytest.raises(ValidationError, match="greater than 0"):
|
|
||||||
ModelConfig(hidden_ratio=-1)
|
|
||||||
|
|
||||||
def test_visualization_config_defaults(self):
|
|
||||||
"""Verify VisualizationConfig creates with correct defaults."""
|
|
||||||
config = VisualizationConfig()
|
|
||||||
assert config.plot_theme == "plotly_white"
|
|
||||||
assert config.color_scale == "viridis"
|
|
||||||
assert config.point_size == 8
|
|
||||||
assert config.fig_width == 900
|
|
||||||
assert config.fig_height == 600
|
|
||||||
|
|
||||||
def test_visualization_config_validation(self):
|
|
||||||
"""Test validation constraints for VisualizationConfig."""
|
|
||||||
# Test fig_width > 0
|
|
||||||
with pytest.raises(ValidationError, match="greater than 0"):
|
|
||||||
VisualizationConfig(fig_width=0)
|
|
||||||
|
|
||||||
with pytest.raises(ValidationError, match="greater than 0"):
|
|
||||||
VisualizationConfig(fig_width=-1)
|
|
||||||
|
|
||||||
# Test fig_height > 0
|
|
||||||
with pytest.raises(ValidationError, match="greater than 0"):
|
|
||||||
VisualizationConfig(fig_height=0)
|
|
||||||
|
|
||||||
with pytest.raises(ValidationError, match="greater than 0"):
|
|
||||||
VisualizationConfig(fig_height=-1)
|
|
||||||
|
|
||||||
# Test point_size > 0
|
|
||||||
with pytest.raises(ValidationError, match="greater than 0"):
|
|
||||||
VisualizationConfig(point_size=0)
|
|
||||||
|
|
||||||
with pytest.raises(ValidationError, match="greater than 0"):
|
|
||||||
VisualizationConfig(point_size=-1)
|
|
||||||
|
|
||||||
def test_output_config_defaults(self):
|
def test_output_config_defaults(self):
|
||||||
"""Verify OutputConfig creates with correct defaults."""
|
"""Verify OutputConfig creates with correct defaults."""
|
||||||
config = OutputConfig()
|
config = OutputConfig()
|
||||||
output_dir = Path(__file__).parent.parent.parent / "outputs"
|
output_dir = Path(__file__).parent.parent.parent / "outputs"
|
||||||
|
|
||||||
assert config.directory == output_dir
|
assert config.directory == output_dir
|
||||||
assert config.html_self_contained is True
|
|
||||||
assert config.png_scale == 2
|
|
||||||
|
|
||||||
def test_output_config_validation(self):
|
|
||||||
"""Test validation constraints for OutputConfig."""
|
|
||||||
# Test png_scale > 0
|
|
||||||
with pytest.raises(ValidationError, match="greater than 0"):
|
|
||||||
OutputConfig(png_scale=0)
|
|
||||||
|
|
||||||
with pytest.raises(ValidationError, match="greater than 0"):
|
|
||||||
OutputConfig(png_scale=-1)
|
|
||||||
|
|
||||||
def test_pooling_type_enum(self):
|
def test_pooling_type_enum(self):
|
||||||
"""Verify PoolingType enum values."""
|
"""Verify PoolingType enum values."""
|
||||||
@@ -123,17 +53,14 @@ class TestConfigModels:
|
|||||||
def test_feature_compressor_config(self):
|
def test_feature_compressor_config(self):
|
||||||
"""Verify FeatureCompressorConfig nests all models correctly."""
|
"""Verify FeatureCompressorConfig nests all models correctly."""
|
||||||
model_cfg = ModelConfig(compression_dim=512)
|
model_cfg = ModelConfig(compression_dim=512)
|
||||||
viz_cfg = VisualizationConfig(point_size=16)
|
|
||||||
out_cfg = OutputConfig(directory="/tmp/outputs")
|
out_cfg = OutputConfig(directory="/tmp/outputs")
|
||||||
|
|
||||||
config = FeatureCompressorConfig(
|
config = Config(
|
||||||
model=model_cfg,
|
model=model_cfg,
|
||||||
visualization=viz_cfg,
|
|
||||||
output=out_cfg,
|
output=out_cfg,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert config.model.compression_dim == 512
|
assert config.model.compression_dim == 512
|
||||||
assert config.visualization.point_size == 16
|
|
||||||
assert config.output.directory == Path("/tmp/outputs")
|
assert config.output.directory == Path("/tmp/outputs")
|
||||||
|
|
||||||
|
|
||||||
@@ -143,30 +70,16 @@ class TestYamlLoader:
|
|||||||
def test_load_existing_yaml(self):
|
def test_load_existing_yaml(self):
|
||||||
"""Load config.yaml and verify values."""
|
"""Load config.yaml and verify values."""
|
||||||
config_path = Path(__file__).parent.parent / "configs" / "config.yaml"
|
config_path = Path(__file__).parent.parent / "configs" / "config.yaml"
|
||||||
config = load_yaml(config_path, FeatureCompressorConfig)
|
config = load_yaml(config_path, Config)
|
||||||
|
|
||||||
# Verify model config
|
# Verify model config
|
||||||
assert config.model.name == "facebook/dinov2-large"
|
assert config.model.name == "facebook/dinov2-large"
|
||||||
assert config.model.compression_dim == 256
|
assert config.model.compression_dim == 256
|
||||||
assert config.model.pooling_type == PoolingType.ATTENTION
|
|
||||||
assert config.model.top_k_ratio == 0.5
|
|
||||||
assert config.model.hidden_ratio == 2.0
|
|
||||||
assert config.model.dropout_rate == 0.1
|
|
||||||
assert config.model.use_residual is True
|
|
||||||
|
|
||||||
# Verify visualization config
|
|
||||||
assert config.visualization.plot_theme == "plotly_white"
|
|
||||||
assert config.visualization.color_scale == "viridis"
|
|
||||||
assert config.visualization.point_size == 8
|
|
||||||
assert config.visualization.fig_width == 900
|
|
||||||
assert config.visualization.fig_height == 600
|
|
||||||
|
|
||||||
# Verify output config
|
# Verify output config
|
||||||
output_dir = Path(__file__).parent.parent.parent / "outputs"
|
output_dir = Path(__file__).parent.parent.parent / "outputs"
|
||||||
|
|
||||||
assert config.output.directory == output_dir
|
assert config.output.directory == output_dir
|
||||||
assert config.output.html_self_contained is True
|
|
||||||
assert config.output.png_scale == 2
|
|
||||||
|
|
||||||
def test_load_yaml_validation(self):
|
def test_load_yaml_validation(self):
|
||||||
"""Test that invalid data raises ConfigError."""
|
"""Test that invalid data raises ConfigError."""
|
||||||
@@ -177,7 +90,7 @@ class TestYamlLoader:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
with pytest.raises(ConfigError, match="validation failed"):
|
with pytest.raises(ConfigError, match="validation failed"):
|
||||||
load_yaml(Path(temp_path), FeatureCompressorConfig)
|
load_yaml(Path(temp_path), Config)
|
||||||
finally:
|
finally:
|
||||||
Path(temp_path).unlink()
|
Path(temp_path).unlink()
|
||||||
|
|
||||||
@@ -205,7 +118,7 @@ class TestYamlLoader:
|
|||||||
def test_load_yaml_file_not_found(self):
|
def test_load_yaml_file_not_found(self):
|
||||||
"""Verify FileNotFoundError raises ConfigError."""
|
"""Verify FileNotFoundError raises ConfigError."""
|
||||||
with pytest.raises(ConfigError, match="not found"):
|
with pytest.raises(ConfigError, match="not found"):
|
||||||
load_yaml(Path("/nonexistent/path/config.yaml"), FeatureCompressorConfig)
|
load_yaml(Path("/nonexistent/path/config.yaml"), Config)
|
||||||
|
|
||||||
|
|
||||||
class TestConfigManager:
|
class TestConfigManager:
|
||||||
@@ -223,7 +136,6 @@ class TestConfigManager:
|
|||||||
|
|
||||||
assert config is not None
|
assert config is not None
|
||||||
assert config.model.compression_dim == 256
|
assert config.model.compression_dim == 256
|
||||||
assert config.visualization.point_size == 8
|
|
||||||
|
|
||||||
def test_get_without_load(self):
|
def test_get_without_load(self):
|
||||||
"""Test that get() auto-loads config if not loaded."""
|
"""Test that get() auto-loads config if not loaded."""
|
||||||
@@ -237,9 +149,8 @@ class TestConfigManager:
|
|||||||
|
|
||||||
def test_save_config(self):
|
def test_save_config(self):
|
||||||
"""Test saving configuration to file."""
|
"""Test saving configuration to file."""
|
||||||
config = FeatureCompressorConfig(
|
config = Config(
|
||||||
model=ModelConfig(compression_dim=512),
|
model=ModelConfig(compression_dim=512),
|
||||||
visualization=VisualizationConfig(),
|
|
||||||
output=OutputConfig(),
|
output=OutputConfig(),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -248,7 +159,7 @@ class TestConfigManager:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
cfg_manager.save(config, path=temp_path)
|
cfg_manager.save(config, path=temp_path)
|
||||||
loaded_config = load_yaml(temp_path, FeatureCompressorConfig)
|
loaded_config = load_yaml(temp_path, Config)
|
||||||
|
|
||||||
assert loaded_config.model.compression_dim == 512
|
assert loaded_config.model.compression_dim == 512
|
||||||
finally:
|
finally:
|
||||||
|
|||||||
Reference in New Issue
Block a user