mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 20:35:31 +08:00
feat(feature-compressor): add DINOv2 feature extraction and compression pipeline
This commit is contained in:
19
mini-nav/feature_compressor/utils/__init__.py
Normal file
19
mini-nav/feature_compressor/utils/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""Utility modules for image, feature, and plot operations."""
|
||||
|
||||
from .feature_utils import (
|
||||
compute_feature_stats,
|
||||
normalize_features,
|
||||
save_features_to_csv,
|
||||
save_features_to_json,
|
||||
)
|
||||
from .image_utils import load_image, load_images_from_directory, preprocess_image
|
||||
|
||||
__all__ = [
|
||||
"load_image",
|
||||
"load_images_from_directory",
|
||||
"preprocess_image",
|
||||
"normalize_features",
|
||||
"compute_feature_stats",
|
||||
"save_features_to_json",
|
||||
"save_features_to_csv",
|
||||
]
|
||||
83
mini-nav/feature_compressor/utils/feature_utils.py
Normal file
83
mini-nav/feature_compressor/utils/feature_utils.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""Feature processing utilities."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
|
||||
def normalize_features(features: torch.Tensor) -> torch.Tensor:
|
||||
"""L2-normalize features.
|
||||
|
||||
Args:
|
||||
features: Tensor of shape [batch, dim] or [batch, seq, dim]
|
||||
|
||||
Returns:
|
||||
L2-normalized features
|
||||
"""
|
||||
norm = torch.norm(features, p=2, dim=-1, keepdim=True)
|
||||
return features / (norm + 1e-8)
|
||||
|
||||
|
||||
def compute_feature_stats(features: torch.Tensor) -> Dict[str, float]:
|
||||
"""Compute basic statistics for features.
|
||||
|
||||
Args:
|
||||
features: Tensor of shape [batch, dim] or [batch, seq, dim]
|
||||
|
||||
Returns:
|
||||
Dictionary with mean, std, min, max
|
||||
"""
|
||||
with torch.no_grad():
|
||||
return {
|
||||
"mean": float(features.mean().item()),
|
||||
"std": float(features.std().item()),
|
||||
"min": float(features.min().item()),
|
||||
"max": float(features.max().item()),
|
||||
}
|
||||
|
||||
|
||||
def save_features_to_json(
|
||||
features: torch.Tensor, path: Path, metadata: Dict = None
|
||||
) -> None:
|
||||
"""Save features to JSON file.
|
||||
|
||||
Args:
|
||||
features: Tensor to save
|
||||
path: Output file path
|
||||
metadata: Optional metadata dictionary
|
||||
"""
|
||||
path = Path(path)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
features_np = features.cpu().numpy()
|
||||
|
||||
data = {
|
||||
"features": features_np.tolist(),
|
||||
"shape": list(features.shape),
|
||||
}
|
||||
|
||||
if metadata:
|
||||
data["metadata"] = metadata
|
||||
|
||||
with open(path, "w") as f:
|
||||
import json
|
||||
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
|
||||
def save_features_to_csv(features: torch.Tensor, path: Path) -> None:
|
||||
"""Save features to CSV file.
|
||||
|
||||
Args:
|
||||
features: Tensor to save
|
||||
path: Output file path
|
||||
"""
|
||||
path = Path(path)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
features_np = features.cpu().numpy()
|
||||
|
||||
np.savetxt(path, features_np, delimiter=",", fmt="%.6f")
|
||||
76
mini-nav/feature_compressor/utils/image_utils.py
Normal file
76
mini-nav/feature_compressor/utils/image_utils.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""Image loading and preprocessing utilities."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def load_image(path: Union[str, Path]) -> Image.Image:
|
||||
"""Load an image from file path or URL.
|
||||
|
||||
Args:
|
||||
path: File path or URL to image
|
||||
|
||||
Returns:
|
||||
PIL Image object
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If file doesn't exist
|
||||
ValueError: If image cannot be loaded
|
||||
"""
|
||||
path_str = str(path)
|
||||
|
||||
if path_str.startswith(("http://", "https://")):
|
||||
response = requests.get(path_str, stream=True)
|
||||
response.raise_for_status()
|
||||
img = Image.open(response.raw)
|
||||
else:
|
||||
img = Image.open(path)
|
||||
|
||||
return img
|
||||
|
||||
|
||||
def preprocess_image(image: Image.Image, size: int = 224) -> Image.Image:
|
||||
"""Preprocess image to square format with resizing.
|
||||
|
||||
Args:
|
||||
image: PIL Image
|
||||
size: Target size for shortest dimension (default: 224)
|
||||
|
||||
Returns:
|
||||
Resized PIL Image
|
||||
"""
|
||||
if image.mode != "RGB":
|
||||
image = image.convert("RGB")
|
||||
|
||||
# Resize while maintaining aspect ratio, then center crop
|
||||
image = image.resize((size, size), Image.Resampling.LANCZOS)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def load_images_from_directory(
|
||||
dir_path: Union[str, Path], extensions: List[str] = None
|
||||
) -> List[Image.Image]:
|
||||
"""Load all images from a directory.
|
||||
|
||||
Args:
|
||||
dir_path: Path to directory
|
||||
extensions: List of file extensions to include (e.g., ['.jpg', '.png'])
|
||||
|
||||
Returns:
|
||||
List of PIL Images
|
||||
"""
|
||||
if extensions is None:
|
||||
extensions = [".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"]
|
||||
|
||||
dir_path = Path(dir_path)
|
||||
images = []
|
||||
|
||||
for ext in extensions:
|
||||
images.extend([load_image(p) for p in dir_path.glob(f"*{ext}")])
|
||||
images.extend([load_image(p) for p in dir_path.glob(f"*{ext.upper()}")])
|
||||
|
||||
return images
|
||||
167
mini-nav/feature_compressor/utils/plot_utils.py
Normal file
167
mini-nav/feature_compressor/utils/plot_utils.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""Plotting utility functions for feature visualization."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import plotly.graph_objects as go
|
||||
from plotly.subplots import make_subplots
|
||||
|
||||
|
||||
def create_histogram(data: np.ndarray, title: str = None, **kwargs) -> go.Figure:
|
||||
"""Create a histogram plot.
|
||||
|
||||
Args:
|
||||
data: 1D array of values
|
||||
title: Plot title
|
||||
**kwargs: Additional histogram arguments
|
||||
|
||||
Returns:
|
||||
Plotly Figure object
|
||||
"""
|
||||
fig = go.Figure()
|
||||
|
||||
fig.add_trace(
|
||||
go.Histogram(
|
||||
x=data.flatten(),
|
||||
name="Feature Values",
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
|
||||
if title:
|
||||
fig.update_layout(title=title)
|
||||
|
||||
fig.update_layout(
|
||||
xaxis_title="Value",
|
||||
yaxis_title="Count",
|
||||
hovermode="x unified",
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def create_pca_scatter_2d(
|
||||
features: np.ndarray, labels: List = None, **kwargs
|
||||
) -> go.Figure:
|
||||
"""Create a 2D PCA scatter plot.
|
||||
|
||||
Args:
|
||||
features: 2D array [n_samples, n_features]
|
||||
labels: Optional list of labels for coloring
|
||||
**kwargs: Additional scatter arguments
|
||||
|
||||
Returns:
|
||||
Plotly Figure object
|
||||
"""
|
||||
from sklearn.decomposition import PCA
|
||||
|
||||
# Apply PCA
|
||||
pca = PCA(n_components=2)
|
||||
components = pca.fit_transform(features)
|
||||
|
||||
explained_var = pca.explained_variance_ratio_ * 100
|
||||
|
||||
fig = go.Figure()
|
||||
|
||||
if labels is None:
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=components[:, 0],
|
||||
y=components[:, 1],
|
||||
mode="markers",
|
||||
marker=dict(size=8, opacity=0.7),
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
else:
|
||||
for label in set(labels):
|
||||
mask = np.array(labels) == label
|
||||
fig.add_trace(
|
||||
go.Scatter(
|
||||
x=components[mask, 0],
|
||||
y=components[mask, 1],
|
||||
mode="markers",
|
||||
name=str(label),
|
||||
marker=dict(size=8, opacity=0.7),
|
||||
)
|
||||
)
|
||||
|
||||
fig.update_layout(
|
||||
title=f"PCA 2D Projection (Total Variance: {explained_var.sum():.1f}%)",
|
||||
xaxis_title=f"PC 1 ({explained_var[0]:.1f}%)",
|
||||
yaxis_title=f"PC 2 ({explained_var[1]:.1f}%)",
|
||||
hovermode="closest",
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def create_comparison_plot(
|
||||
features_list: List[np.ndarray], names: List[str], **kwargs
|
||||
) -> go.Figure:
|
||||
"""Create a comparison plot of multiple feature sets.
|
||||
|
||||
Args:
|
||||
features_list: List of feature arrays
|
||||
names: List of names for each feature set
|
||||
**kwargs: Additional histogram arguments
|
||||
|
||||
Returns:
|
||||
Plotly Figure object
|
||||
"""
|
||||
fig = make_subplots(rows=1, cols=len(features_list), subplot_titles=names)
|
||||
|
||||
for i, features in enumerate(features_list, 1):
|
||||
fig.add_trace(
|
||||
go.Histogram(
|
||||
x=features.flatten(),
|
||||
name=names[i - 1],
|
||||
showlegend=False,
|
||||
**kwargs,
|
||||
),
|
||||
row=1,
|
||||
col=i,
|
||||
)
|
||||
|
||||
fig.update_layout(
|
||||
title="Feature Distribution Comparison",
|
||||
hovermode="x unified",
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def save_figure(fig: go.Figure, path: str, format: str = "html") -> None:
|
||||
"""Save figure to file.
|
||||
|
||||
Args:
|
||||
fig: Plotly Figure object
|
||||
path: Output file path (without extension)
|
||||
format: Output format ('html', 'png', 'json')
|
||||
"""
|
||||
path = Path(path)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if format == "html":
|
||||
fig.write_html(str(path) + ".html", include_plotlyjs="cdn")
|
||||
elif format == "png":
|
||||
fig.write_image(str(path) + ".png", scale=2)
|
||||
elif format == "json":
|
||||
fig.write_json(str(path) + ".json")
|
||||
else:
|
||||
raise ValueError(f"Unsupported format: {format}")
|
||||
|
||||
|
||||
def apply_theme(fig: go.Figure, theme: str = "plotly_white") -> go.Figure:
|
||||
"""Apply a theme to the figure.
|
||||
|
||||
Args:
|
||||
fig: Plotly Figure object
|
||||
theme: Theme name
|
||||
|
||||
Returns:
|
||||
Updated Plotly Figure object
|
||||
"""
|
||||
fig.update_layout(template=theme)
|
||||
return fig
|
||||
Reference in New Issue
Block a user