mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 20:35:31 +08:00
feat(configs): implement Pydantic configuration system with type safety
This commit is contained in:
@@ -2,13 +2,12 @@
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
from configs import FeatureCompressorConfig, cfg_manager, load_yaml
|
||||
from transformers import AutoImageProcessor, AutoModel
|
||||
|
||||
from ...configs.config import Config, get_default_config
|
||||
from ..utils.image_utils import load_image, preprocess_image
|
||||
from .compressor import PoolNetCompressor
|
||||
|
||||
@@ -25,47 +24,47 @@ class DINOv2FeatureExtractor:
|
||||
"""
|
||||
|
||||
def __init__(self, config_path: Optional[str] = None, device: str = "auto"):
|
||||
self.config = self._load_config(config_path)
|
||||
self.config: FeatureCompressorConfig = self._load_config(config_path)
|
||||
|
||||
# Set device
|
||||
if device == "auto":
|
||||
device = self.config.get("model", {}).get("device", "auto")
|
||||
device = self.config.model.device
|
||||
if device == "auto":
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
self.device = torch.device(device)
|
||||
|
||||
# Load DINOv2 model and processor
|
||||
model_name = self.config.get("model", {}).get("name", "facebook/dinov2-large")
|
||||
model_name = self.config.model.name
|
||||
self.processor = AutoImageProcessor.from_pretrained(model_name)
|
||||
self.model = AutoModel.from_pretrained(model_name).to(self.device)
|
||||
self.model.eval()
|
||||
|
||||
# Initialize compressor
|
||||
model_config = self.config.get("model", {})
|
||||
self.compressor = PoolNetCompressor(
|
||||
input_dim=self.model.config.hidden_size,
|
||||
compression_dim=model_config.get("compression_dim", 256),
|
||||
top_k_ratio=model_config.get("top_k_ratio", 0.5),
|
||||
hidden_ratio=model_config.get("hidden_ratio", 2.0),
|
||||
dropout_rate=model_config.get("dropout_rate", 0.1),
|
||||
use_residual=model_config.get("use_residual", True),
|
||||
compression_dim=self.config.model.compression_dim,
|
||||
top_k_ratio=self.config.model.top_k_ratio,
|
||||
hidden_ratio=self.config.model.hidden_ratio,
|
||||
dropout_rate=self.config.model.dropout_rate,
|
||||
use_residual=self.config.model.use_residual,
|
||||
device=str(self.device),
|
||||
)
|
||||
|
||||
def _load_config(self, config_path: Optional[str] = None) -> Dict:
|
||||
def _load_config(
|
||||
self, config_path: Optional[str] = None
|
||||
) -> FeatureCompressorConfig:
|
||||
"""Load configuration from YAML file.
|
||||
|
||||
Args:
|
||||
config_path: Path to config file, or None for default
|
||||
|
||||
Returns:
|
||||
Configuration dictionary
|
||||
FeatureCompressorConfig instance
|
||||
"""
|
||||
if config_path is None:
|
||||
return get_default_config(Config.FEATURE_COMPRESSOR)
|
||||
|
||||
with open(config_path) as f:
|
||||
return yaml.safe_load(f)
|
||||
return cfg_manager.get_or_load_config("feature_compressor")
|
||||
else:
|
||||
return load_yaml(Path(config_path), FeatureCompressorConfig)
|
||||
|
||||
def _extract_dinov2_features(self, images: List) -> torch.Tensor:
|
||||
"""Extract DINOv2 last_hidden_state features.
|
||||
@@ -149,14 +148,17 @@ class DINOv2FeatureExtractor:
|
||||
"processing_time": processing_time,
|
||||
"feature_norm": feature_norm,
|
||||
"device": str(self.device),
|
||||
"model_name": self.config.get("model", {}).get("name"),
|
||||
"model_name": self.config.model.name,
|
||||
},
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
def process_batch(
|
||||
self, image_dir: str, batch_size: int = 8, save_features: bool = True
|
||||
self,
|
||||
image_dir: Union[str, Path],
|
||||
batch_size: int = 8,
|
||||
save_features: bool = True,
|
||||
) -> List[Dict[str, object]]:
|
||||
"""Process multiple images in batches.
|
||||
|
||||
@@ -208,7 +210,7 @@ class DINOv2FeatureExtractor:
|
||||
.mean()
|
||||
.item(),
|
||||
"device": str(self.device),
|
||||
"model_name": self.config.get("model", {}).get("name"),
|
||||
"model_name": self.config.model.name,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -216,12 +218,7 @@ class DINOv2FeatureExtractor:
|
||||
|
||||
# Save features if requested
|
||||
if save_features:
|
||||
output_dir = Path(
|
||||
self.config.get("output", {}).get("directory", "./outputs")
|
||||
)
|
||||
# Resolve relative to project root
|
||||
if not output_dir.is_absolute():
|
||||
output_dir = Path(__file__).parent.parent.parent / output_dir
|
||||
output_dir = Path(self.config.output.directory)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
output_path = output_dir / f"{file_path.stem}_features.json"
|
||||
|
||||
@@ -2,11 +2,11 @@
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import yaml
|
||||
from configs import FeatureCompressorConfig, cfg_manager, load_yaml
|
||||
from plotly.graph_objs import Figure
|
||||
|
||||
from ..utils.plot_utils import (
|
||||
@@ -29,28 +29,27 @@ class FeatureVisualizer:
|
||||
"""
|
||||
|
||||
def __init__(self, config_path: Optional[str] = None):
|
||||
self.config = self._load_config(config_path)
|
||||
self.config: FeatureCompressorConfig = self._load_config(config_path)
|
||||
|
||||
def _load_config(self, config_path: Optional[str] = None) -> dict:
|
||||
def _load_config(
|
||||
self, config_path: Optional[str] = None
|
||||
) -> FeatureCompressorConfig:
|
||||
"""Load configuration from YAML file.
|
||||
|
||||
Args:
|
||||
config_path: Path to config file, or None for default
|
||||
|
||||
Returns:
|
||||
Configuration dictionary
|
||||
Configuration Pydantic model
|
||||
"""
|
||||
if config_path is None:
|
||||
config_path = (
|
||||
Path(__file__).parent.parent.parent
|
||||
/ "configs"
|
||||
/ "feature_compressor.yaml"
|
||||
)
|
||||
return cfg_manager.get_or_load_config("feature_compressor")
|
||||
else:
|
||||
return load_yaml(Path(config_path), FeatureCompressorConfig)
|
||||
|
||||
with open(config_path) as f:
|
||||
return yaml.safe_load(f)
|
||||
|
||||
def plot_histogram(self, features: torch.Tensor, title: str = None) -> object:
|
||||
def plot_histogram(
|
||||
self, features: torch.Tensor, title: Optional[str] = None
|
||||
) -> Figure:
|
||||
"""Plot histogram of feature values.
|
||||
|
||||
Args:
|
||||
@@ -61,18 +60,21 @@ class FeatureVisualizer:
|
||||
Plotly Figure object
|
||||
"""
|
||||
features_np = features.cpu().numpy()
|
||||
fig = create_histogram(features_np, title=title)
|
||||
fig = create_histogram(
|
||||
features_np, title="Feature Histogram" if title is None else title
|
||||
)
|
||||
|
||||
viz_config = self.config.get("visualization", {})
|
||||
fig = apply_theme(fig, viz_config.get("plot_theme", "plotly_white"))
|
||||
fig = apply_theme(fig, self.config.visualization.plot_theme)
|
||||
fig.update_layout(
|
||||
width=viz_config.get("fig_width", 900),
|
||||
height=viz_config.get("fig_height", 600),
|
||||
width=self.config.visualization.fig_width,
|
||||
height=self.config.visualization.fig_height,
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
def plot_pca_2d(self, features: torch.Tensor, labels: List = None) -> Figure:
|
||||
def plot_pca_2d(
|
||||
self, features: torch.Tensor, labels: Optional[List] = None
|
||||
) -> Figure:
|
||||
"""Plot 2D PCA projection of features.
|
||||
|
||||
Args:
|
||||
@@ -83,19 +85,21 @@ class FeatureVisualizer:
|
||||
Plotly Figure object
|
||||
"""
|
||||
features_np = features.cpu().numpy()
|
||||
viz_config = self.config.get("visualization", {})
|
||||
|
||||
fig = create_pca_scatter_2d(features_np, labels=labels)
|
||||
fig = apply_theme(fig, viz_config.get("plot_theme", "plotly_white"))
|
||||
fig = create_pca_scatter_2d(
|
||||
features_np,
|
||||
labels=[i for i in range(len(features_np))] if labels is None else labels,
|
||||
)
|
||||
fig = apply_theme(fig, self.config.visualization.plot_theme)
|
||||
fig.update_traces(
|
||||
marker=dict(
|
||||
size=viz_config.get("point_size", 8),
|
||||
colorscale=viz_config.get("color_scale", "viridis"),
|
||||
size=self.config.visualization.point_size,
|
||||
colorscale=self.config.visualization.color_scale,
|
||||
)
|
||||
)
|
||||
fig.update_layout(
|
||||
width=viz_config.get("fig_width", 900),
|
||||
height=viz_config.get("fig_height", 600),
|
||||
width=self.config.visualization.fig_width,
|
||||
height=self.config.visualization.fig_height,
|
||||
)
|
||||
|
||||
return fig
|
||||
@@ -116,16 +120,17 @@ class FeatureVisualizer:
|
||||
|
||||
fig = create_comparison_plot(features_np_list, names)
|
||||
|
||||
viz_config = self.config.get("visualization", {})
|
||||
fig = apply_theme(fig, viz_config.get("plot_theme", "plotly_white"))
|
||||
fig = apply_theme(fig, self.config.visualization.plot_theme)
|
||||
fig.update_layout(
|
||||
width=viz_config.get("fig_width", 900) * len(features_list),
|
||||
height=viz_config.get("fig_height", 600),
|
||||
width=self.config.visualization.fig_width * len(features_list),
|
||||
height=self.config.visualization.fig_height,
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
def generate_report(self, results: List[dict], output_dir: str) -> List[str]:
|
||||
def generate_report(
|
||||
self, results: List[dict], output_dir: Union[str, Path]
|
||||
) -> List[str]:
|
||||
"""Generate full feature analysis report.
|
||||
|
||||
Args:
|
||||
@@ -158,7 +163,7 @@ class FeatureVisualizer:
|
||||
|
||||
return generated_files
|
||||
|
||||
def save(self, fig: object, path: str, formats: List[str] = None) -> None:
|
||||
def save(self, fig: Figure, path: str, formats: List[str]) -> None:
|
||||
"""Save figure in multiple formats.
|
||||
|
||||
Args:
|
||||
@@ -169,8 +174,6 @@ class FeatureVisualizer:
|
||||
if formats is None:
|
||||
formats = ["html"]
|
||||
|
||||
output_config = self.config.get("output", {})
|
||||
|
||||
for fmt in formats:
|
||||
if fmt == "png":
|
||||
save_figure(fig, path, format="png")
|
||||
|
||||
Reference in New Issue
Block a user