Files
Mini-Nav/mini-nav/feature_compressor/core/visualizer.py

179 lines
5.3 KiB
Python

"""Feature visualization using Plotly."""
import os
from pathlib import Path
from typing import List, Optional
import numpy as np
import torch
import yaml
from plotly.graph_objs import Figure
from ..utils.plot_utils import (
apply_theme,
create_comparison_plot,
create_histogram,
create_pca_scatter_2d,
save_figure,
)
class FeatureVisualizer:
"""Visualize DINOv2 features with interactive Plotly charts.
Supports histograms, PCA projections, and feature comparisons
with multiple export formats.
Args:
config_path: Path to YAML configuration file
"""
def __init__(self, config_path: Optional[str] = None):
self.config = self._load_config(config_path)
def _load_config(self, config_path: Optional[str] = None) -> dict:
"""Load configuration from YAML file.
Args:
config_path: Path to config file, or None for default
Returns:
Configuration dictionary
"""
if config_path is None:
config_path = (
Path(__file__).parent.parent.parent
/ "configs"
/ "feature_compressor.yaml"
)
with open(config_path) as f:
return yaml.safe_load(f)
def plot_histogram(self, features: torch.Tensor, title: str = None) -> object:
"""Plot histogram of feature values.
Args:
features: Feature tensor [batch, dim]
title: Plot title
Returns:
Plotly Figure object
"""
features_np = features.cpu().numpy()
fig = create_histogram(features_np, title=title)
viz_config = self.config.get("visualization", {})
fig = apply_theme(fig, viz_config.get("plot_theme", "plotly_white"))
fig.update_layout(
width=viz_config.get("fig_width", 900),
height=viz_config.get("fig_height", 600),
)
return fig
def plot_pca_2d(self, features: torch.Tensor, labels: List = None) -> Figure:
"""Plot 2D PCA projection of features.
Args:
features: Feature tensor [n_samples, dim]
labels: Optional labels for coloring
Returns:
Plotly Figure object
"""
features_np = features.cpu().numpy()
viz_config = self.config.get("visualization", {})
fig = create_pca_scatter_2d(features_np, labels=labels)
fig = apply_theme(fig, viz_config.get("plot_theme", "plotly_white"))
fig.update_traces(
marker=dict(
size=viz_config.get("point_size", 8),
colorscale=viz_config.get("color_scale", "viridis"),
)
)
fig.update_layout(
width=viz_config.get("fig_width", 900),
height=viz_config.get("fig_height", 600),
)
return fig
def plot_comparison(
self, features_list: List[torch.Tensor], names: List[str]
) -> object:
"""Plot comparison of multiple feature sets.
Args:
features_list: List of feature tensors
names: Names for each feature set
Returns:
Plotly Figure object
"""
features_np_list = [f.cpu().numpy() for f in features_list]
fig = create_comparison_plot(features_np_list, names)
viz_config = self.config.get("visualization", {})
fig = apply_theme(fig, viz_config.get("plot_theme", "plotly_white"))
fig.update_layout(
width=viz_config.get("fig_width", 900) * len(features_list),
height=viz_config.get("fig_height", 600),
)
return fig
def generate_report(self, results: List[dict], output_dir: str) -> List[str]:
"""Generate full feature analysis report.
Args:
results: List of extractor results
output_dir: Directory to save visualizations
Returns:
List of generated file paths
"""
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
generated_files = []
# Extract all compressed features
all_features = torch.cat([r["compressed_features"] for r in results], dim=0)
# Create histogram
hist_fig = self.plot_histogram(all_features, "Compressed Feature Distribution")
hist_path = output_dir / "feature_histogram"
self.save(hist_fig, str(hist_path), formats=["html"])
generated_files.append(str(hist_path) + ".html")
# Create PCA
pca_fig = self.plot_pca_2d(all_features)
pca_path = output_dir / "feature_pca_2d"
self.save(pca_fig, str(pca_path), formats=["html", "png"])
generated_files.append(str(pca_path) + ".html")
generated_files.append(str(pca_path) + ".png")
return generated_files
def save(self, fig: object, path: str, formats: List[str] = None) -> None:
"""Save figure in multiple formats.
Args:
fig: Plotly Figure object
path: Output file path (without extension)
formats: List of formats to export
"""
if formats is None:
formats = ["html"]
output_config = self.config.get("output", {})
for fmt in formats:
if fmt == "png":
save_figure(fig, path, format="png")
else:
save_figure(fig, path, format=fmt)