Files
Mini-Nav/mini-nav/feature_compressor/core/visualizer.py

182 lines
5.3 KiB
Python

"""Feature visualization using Plotly."""
import os
from pathlib import Path
from typing import List, Optional, Union
import numpy as np
import torch
from configs import FeatureCompressorConfig, cfg_manager, load_yaml
from plotly.graph_objs import Figure
from ..utils.plot_utils import (
apply_theme,
create_comparison_plot,
create_histogram,
create_pca_scatter_2d,
save_figure,
)
class FeatureVisualizer:
    """Visualize DINOv2 features with interactive Plotly charts.

    Supports histograms, PCA projections, and feature comparisons
    with multiple export formats.

    Args:
        config_path: Path to YAML configuration file, or None to use the
            configuration returned by ``cfg_manager.get()``.
    """

    def __init__(self, config_path: Optional[str] = None):
        self.config: FeatureCompressorConfig = self._load_config(config_path)

    def _load_config(
        self, config_path: Optional[str] = None
    ) -> FeatureCompressorConfig:
        """Load configuration from YAML file.

        Args:
            config_path: Path to config file, or None for the default
                configuration held by ``cfg_manager``.

        Returns:
            Configuration Pydantic model
        """
        if config_path is None:
            return cfg_manager.get()
        return load_yaml(Path(config_path), FeatureCompressorConfig)

    @staticmethod
    def _to_numpy(features: torch.Tensor) -> np.ndarray:
        """Convert a tensor to a NumPy array on the CPU.

        ``Tensor.numpy()`` refuses tensors that require grad, so the tensor
        is detached first; ``detach()`` is harmless for tensors that don't.
        """
        return features.detach().cpu().numpy()

    def _apply_size(self, fig: Figure, width_factor: int = 1) -> None:
        """Resize ``fig`` to the configured dimensions.

        Args:
            fig: Figure to resize in place.
            width_factor: Multiplier applied to the configured width
                (used for side-by-side comparison plots).
        """
        viz_cfg = self.config.visualization
        fig.update_layout(
            width=viz_cfg.fig_width * width_factor,
            height=viz_cfg.fig_height,
        )

    def plot_histogram(
        self, features: torch.Tensor, title: Optional[str] = None
    ) -> Figure:
        """Plot histogram of feature values.

        Args:
            features: Feature tensor [batch, dim]
            title: Plot title; defaults to "Feature Histogram".

        Returns:
            Plotly Figure object
        """
        features_np = self._to_numpy(features)
        fig = create_histogram(
            features_np, title="Feature Histogram" if title is None else title
        )
        fig = apply_theme(fig, self.config.visualization.plot_theme)
        self._apply_size(fig)
        return fig

    def plot_pca_2d(
        self, features: torch.Tensor, labels: Optional[List] = None
    ) -> Figure:
        """Plot 2D PCA projection of features.

        Args:
            features: Feature tensor [n_samples, dim]
            labels: Optional labels for coloring; defaults to the row
                indices ``0 .. n_samples - 1``.

        Returns:
            Plotly Figure object
        """
        features_np = self._to_numpy(features)
        if labels is None:
            labels = list(range(len(features_np)))
        fig = create_pca_scatter_2d(features_np, labels=labels)
        fig = apply_theme(fig, self.config.visualization.plot_theme)
        # Marker size and colorscale come from the visualization config.
        fig.update_traces(
            marker=dict(
                size=self.config.visualization.point_size,
                colorscale=self.config.visualization.color_scale,
            )
        )
        self._apply_size(fig)
        return fig

    def plot_comparison(
        self, features_list: List[torch.Tensor], names: List[str]
    ) -> Figure:
        """Plot comparison of multiple feature sets.

        Args:
            features_list: List of feature tensors
            names: Names for each feature set

        Returns:
            Plotly Figure object whose width is the configured width
            multiplied by the number of feature sets.
        """
        features_np_list = [self._to_numpy(f) for f in features_list]
        fig = create_comparison_plot(features_np_list, names)
        fig = apply_theme(fig, self.config.visualization.plot_theme)
        self._apply_size(fig, width_factor=len(features_list))
        return fig

    def generate_report(
        self, results: List[dict], output_dir: Union[str, Path]
    ) -> List[str]:
        """Generate full feature analysis report.

        Concatenates the ``"compressed_features"`` tensor of every result
        along dim 0, then writes a histogram (HTML) and a 2D PCA scatter
        (HTML and PNG) into ``output_dir``.

        Args:
            results: List of extractor results, each containing a
                ``"compressed_features"`` tensor.
            output_dir: Directory to save visualizations (created if missing)

        Returns:
            List of generated file paths; empty when ``results`` is empty.
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        generated_files: List[str] = []

        if not results:
            # Nothing to plot; torch.cat would raise on an empty list.
            return generated_files

        all_features = torch.cat([r["compressed_features"] for r in results], dim=0)

        # Histogram. The returned paths assume save_figure appends the
        # extension to the extension-less path — TODO confirm.
        hist_fig = self.plot_histogram(all_features, "Compressed Feature Distribution")
        hist_path = output_dir / "feature_histogram"
        self.save(hist_fig, str(hist_path), formats=["html"])
        generated_files.append(str(hist_path) + ".html")

        # PCA projection
        pca_fig = self.plot_pca_2d(all_features)
        pca_path = output_dir / "feature_pca_2d"
        self.save(pca_fig, str(pca_path), formats=["html", "png"])
        generated_files.append(str(pca_path) + ".html")
        generated_files.append(str(pca_path) + ".png")

        return generated_files

    def save(
        self, fig: Figure, path: str, formats: Optional[List[str]] = None
    ) -> None:
        """Save figure in multiple formats.

        Args:
            fig: Plotly Figure object
            path: Output file path (without extension)
            formats: List of formats to export; defaults to ``["html"]``.
        """
        if formats is None:
            formats = ["html"]
        for fmt in formats:
            save_figure(fig, path, format=fmt)