mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
182 lines
5.3 KiB
Python
182 lines
5.3 KiB
Python
"""Feature visualization using Plotly."""
|
|
|
|
import os
|
|
from pathlib import Path
|
|
from typing import List, Optional, Union
|
|
|
|
import numpy as np
|
|
import torch
|
|
from configs import FeatureCompressorConfig, cfg_manager, load_yaml
|
|
from plotly.graph_objs import Figure
|
|
|
|
from ..utils.plot_utils import (
|
|
apply_theme,
|
|
create_comparison_plot,
|
|
create_histogram,
|
|
create_pca_scatter_2d,
|
|
save_figure,
|
|
)
|
|
|
|
|
|
class FeatureVisualizer:
    """Visualize DINOv2 features with interactive Plotly charts.

    Supports histograms, PCA projections, and feature comparisons
    with multiple export formats.

    Args:
        config_path: Path to YAML configuration file. When ``None``, the
            configuration returned by ``cfg_manager.get()`` is used.
    """

    def __init__(self, config_path: Optional[str] = None):
        self.config: FeatureCompressorConfig = self._load_config(config_path)

    def _load_config(
        self, config_path: Optional[str] = None
    ) -> FeatureCompressorConfig:
        """Load configuration from YAML file.

        Args:
            config_path: Path to config file, or None for default

        Returns:
            Configuration Pydantic model
        """
        if config_path is None:
            return cfg_manager.get()
        return load_yaml(Path(config_path), FeatureCompressorConfig)

    @staticmethod
    def _to_numpy(features: torch.Tensor) -> np.ndarray:
        """Convert a tensor to a NumPy array on the CPU.

        The tensor is detached first because ``Tensor.numpy()`` raises for
        tensors that require grad.

        Args:
            features: Tensor to convert

        Returns:
            NumPy array holding the tensor's values
        """
        return features.detach().cpu().numpy()

    def plot_histogram(
        self, features: torch.Tensor, title: Optional[str] = None
    ) -> Figure:
        """Plot histogram of feature values.

        Args:
            features: Feature tensor [batch, dim]
            title: Plot title; defaults to "Feature Histogram" when None

        Returns:
            Plotly Figure object
        """
        fig = create_histogram(
            self._to_numpy(features),
            title="Feature Histogram" if title is None else title,
        )

        vis_cfg = self.config.visualization
        fig = apply_theme(fig, vis_cfg.plot_theme)
        fig.update_layout(width=vis_cfg.fig_width, height=vis_cfg.fig_height)

        return fig

    def plot_pca_2d(
        self, features: torch.Tensor, labels: Optional[List] = None
    ) -> Figure:
        """Plot 2D PCA projection of features.

        Args:
            features: Feature tensor [n_samples, dim]
            labels: Optional labels for coloring; when None, each sample is
                labelled with its own index

        Returns:
            Plotly Figure object
        """
        features_np = self._to_numpy(features)

        fig = create_pca_scatter_2d(
            features_np,
            labels=list(range(len(features_np))) if labels is None else labels,
        )

        vis_cfg = self.config.visualization
        fig = apply_theme(fig, vis_cfg.plot_theme)
        fig.update_traces(
            marker=dict(
                size=vis_cfg.point_size,
                colorscale=vis_cfg.color_scale,
            )
        )
        fig.update_layout(width=vis_cfg.fig_width, height=vis_cfg.fig_height)

        return fig

    def plot_comparison(
        self, features_list: List[torch.Tensor], names: List[str]
    ) -> Figure:
        """Plot comparison of multiple feature sets.

        Args:
            features_list: List of feature tensors
            names: Names for each feature set

        Returns:
            Plotly Figure object
        """
        features_np_list = [self._to_numpy(f) for f in features_list]

        fig = create_comparison_plot(features_np_list, names)

        vis_cfg = self.config.visualization
        fig = apply_theme(fig, vis_cfg.plot_theme)
        fig.update_layout(
            # Total width grows with the number of feature sets compared.
            width=vis_cfg.fig_width * len(features_list),
            height=vis_cfg.fig_height,
        )

        return fig

    def generate_report(
        self, results: List[dict], output_dir: Union[str, Path]
    ) -> List[str]:
        """Generate full feature analysis report.

        Creates ``output_dir`` if needed, then saves a histogram (HTML) and
        a 2D PCA scatter (HTML and PNG) of all compressed features.

        Args:
            results: List of extractor results; each entry must provide a
                tensor under the "compressed_features" key
            output_dir: Directory to save visualizations

        Returns:
            List of generated file paths; empty when ``results`` is empty
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        generated_files: List[str] = []

        # torch.cat rejects an empty list, so there is nothing to plot.
        if not results:
            return generated_files

        # Extract all compressed features
        all_features = torch.cat([r["compressed_features"] for r in results], dim=0)

        # NOTE(review): the recorded paths assume save_figure appends
        # ".<format>" to the extension-less path it is given -- confirm.

        # Create histogram
        hist_fig = self.plot_histogram(all_features, "Compressed Feature Distribution")
        hist_path = output_dir / "feature_histogram"
        self.save(hist_fig, str(hist_path), formats=["html"])
        generated_files.append(str(hist_path) + ".html")

        # Create PCA
        pca_fig = self.plot_pca_2d(all_features)
        pca_path = output_dir / "feature_pca_2d"
        self.save(pca_fig, str(pca_path), formats=["html", "png"])
        generated_files.append(str(pca_path) + ".html")
        generated_files.append(str(pca_path) + ".png")

        return generated_files

    def save(
        self, fig: Figure, path: str, formats: Optional[List[str]] = None
    ) -> None:
        """Save figure in multiple formats.

        Args:
            fig: Plotly Figure object
            path: Output file path (without extension)
            formats: List of formats to export; defaults to ["html"] when
                omitted or None
        """
        if formats is None:
            formats = ["html"]

        # Every format goes through the same save_figure call.
        for fmt in formats:
            save_figure(fig, path, format=fmt)
|