mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
feat(feature-compressor): add DINOv2 feature extraction and compression pipeline
This commit is contained in:
178
mini-nav/feature_compressor/core/visualizer.py
Normal file
178
mini-nav/feature_compressor/core/visualizer.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""Feature visualization using Plotly."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import yaml
|
||||
from plotly.graph_objs import Figure
|
||||
|
||||
from ..utils.plot_utils import (
|
||||
apply_theme,
|
||||
create_comparison_plot,
|
||||
create_histogram,
|
||||
create_pca_scatter_2d,
|
||||
save_figure,
|
||||
)
|
||||
|
||||
|
||||
class FeatureVisualizer:
|
||||
"""Visualize DINOv2 features with interactive Plotly charts.
|
||||
|
||||
Supports histograms, PCA projections, and feature comparisons
|
||||
with multiple export formats.
|
||||
|
||||
Args:
|
||||
config_path: Path to YAML configuration file
|
||||
"""
|
||||
|
||||
def __init__(self, config_path: Optional[str] = None):
|
||||
self.config = self._load_config(config_path)
|
||||
|
||||
def _load_config(self, config_path: Optional[str] = None) -> dict:
|
||||
"""Load configuration from YAML file.
|
||||
|
||||
Args:
|
||||
config_path: Path to config file, or None for default
|
||||
|
||||
Returns:
|
||||
Configuration dictionary
|
||||
"""
|
||||
if config_path is None:
|
||||
config_path = (
|
||||
Path(__file__).parent.parent.parent
|
||||
/ "configs"
|
||||
/ "feature_compressor.yaml"
|
||||
)
|
||||
|
||||
with open(config_path) as f:
|
||||
return yaml.safe_load(f)
|
||||
|
||||
def plot_histogram(self, features: torch.Tensor, title: str = None) -> object:
|
||||
"""Plot histogram of feature values.
|
||||
|
||||
Args:
|
||||
features: Feature tensor [batch, dim]
|
||||
title: Plot title
|
||||
|
||||
Returns:
|
||||
Plotly Figure object
|
||||
"""
|
||||
features_np = features.cpu().numpy()
|
||||
fig = create_histogram(features_np, title=title)
|
||||
|
||||
viz_config = self.config.get("visualization", {})
|
||||
fig = apply_theme(fig, viz_config.get("plot_theme", "plotly_white"))
|
||||
fig.update_layout(
|
||||
width=viz_config.get("fig_width", 900),
|
||||
height=viz_config.get("fig_height", 600),
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
def plot_pca_2d(self, features: torch.Tensor, labels: List = None) -> Figure:
|
||||
"""Plot 2D PCA projection of features.
|
||||
|
||||
Args:
|
||||
features: Feature tensor [n_samples, dim]
|
||||
labels: Optional labels for coloring
|
||||
|
||||
Returns:
|
||||
Plotly Figure object
|
||||
"""
|
||||
features_np = features.cpu().numpy()
|
||||
viz_config = self.config.get("visualization", {})
|
||||
|
||||
fig = create_pca_scatter_2d(features_np, labels=labels)
|
||||
fig = apply_theme(fig, viz_config.get("plot_theme", "plotly_white"))
|
||||
fig.update_traces(
|
||||
marker=dict(
|
||||
size=viz_config.get("point_size", 8),
|
||||
colorscale=viz_config.get("color_scale", "viridis"),
|
||||
)
|
||||
)
|
||||
fig.update_layout(
|
||||
width=viz_config.get("fig_width", 900),
|
||||
height=viz_config.get("fig_height", 600),
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
def plot_comparison(
|
||||
self, features_list: List[torch.Tensor], names: List[str]
|
||||
) -> object:
|
||||
"""Plot comparison of multiple feature sets.
|
||||
|
||||
Args:
|
||||
features_list: List of feature tensors
|
||||
names: Names for each feature set
|
||||
|
||||
Returns:
|
||||
Plotly Figure object
|
||||
"""
|
||||
features_np_list = [f.cpu().numpy() for f in features_list]
|
||||
|
||||
fig = create_comparison_plot(features_np_list, names)
|
||||
|
||||
viz_config = self.config.get("visualization", {})
|
||||
fig = apply_theme(fig, viz_config.get("plot_theme", "plotly_white"))
|
||||
fig.update_layout(
|
||||
width=viz_config.get("fig_width", 900) * len(features_list),
|
||||
height=viz_config.get("fig_height", 600),
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
def generate_report(self, results: List[dict], output_dir: str) -> List[str]:
|
||||
"""Generate full feature analysis report.
|
||||
|
||||
Args:
|
||||
results: List of extractor results
|
||||
output_dir: Directory to save visualizations
|
||||
|
||||
Returns:
|
||||
List of generated file paths
|
||||
"""
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
generated_files = []
|
||||
|
||||
# Extract all compressed features
|
||||
all_features = torch.cat([r["compressed_features"] for r in results], dim=0)
|
||||
|
||||
# Create histogram
|
||||
hist_fig = self.plot_histogram(all_features, "Compressed Feature Distribution")
|
||||
hist_path = output_dir / "feature_histogram"
|
||||
self.save(hist_fig, str(hist_path), formats=["html"])
|
||||
generated_files.append(str(hist_path) + ".html")
|
||||
|
||||
# Create PCA
|
||||
pca_fig = self.plot_pca_2d(all_features)
|
||||
pca_path = output_dir / "feature_pca_2d"
|
||||
self.save(pca_fig, str(pca_path), formats=["html", "png"])
|
||||
generated_files.append(str(pca_path) + ".html")
|
||||
generated_files.append(str(pca_path) + ".png")
|
||||
|
||||
return generated_files
|
||||
|
||||
def save(self, fig: object, path: str, formats: List[str] = None) -> None:
|
||||
"""Save figure in multiple formats.
|
||||
|
||||
Args:
|
||||
fig: Plotly Figure object
|
||||
path: Output file path (without extension)
|
||||
formats: List of formats to export
|
||||
"""
|
||||
if formats is None:
|
||||
formats = ["html"]
|
||||
|
||||
output_config = self.config.get("output", {})
|
||||
|
||||
for fmt in formats:
|
||||
if fmt == "png":
|
||||
save_figure(fig, path, format="png")
|
||||
else:
|
||||
save_figure(fig, path, format=fmt)
|
||||
Reference in New Issue
Block a user