mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 20:35:31 +08:00
feat(feature-compressor): add DINOv2 feature extraction and compression pipeline
This commit is contained in:
7
mini-nav/feature_compressor/core/__init__.py
Normal file
7
mini-nav/feature_compressor/core/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""Core compression, extraction, and visualization modules."""
|
||||
|
||||
from .compressor import PoolNetCompressor
|
||||
from .extractor import DINOv2FeatureExtractor
|
||||
from .visualizer import FeatureVisualizer
|
||||
|
||||
__all__ = ["PoolNetCompressor", "DINOv2FeatureExtractor", "FeatureVisualizer"]
|
||||
135
mini-nav/feature_compressor/core/compressor.py
Normal file
135
mini-nav/feature_compressor/core/compressor.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""Feature compression module with attention-based pooling and MLP."""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class PoolNetCompressor(nn.Module):
    """Pool + Network feature compressor for DINOv2 embeddings.

    Combines attention-based Top-K token pooling with a 2-layer MLP to compress
    DINOv2's last_hidden_state from [batch, seq_len, hidden_dim] to
    [batch, compression_dim].

    Args:
        input_dim: Input feature dimension (e.g., 1024 for DINOv2-large)
        compression_dim: Output feature dimension (default: 256)
        top_k_ratio: Ratio of tokens to keep via attention pooling,
            must be in (0, 1] (default: 0.5)
        hidden_ratio: Hidden layer dimension as multiple of compression_dim
            (default: 2.0)
        dropout_rate: Dropout probability (default: 0.1)
        use_residual: Whether to use residual connection (default: True)
        device: Device to place model on ('auto', 'cpu', or 'cuda')

    Raises:
        ValueError: If top_k_ratio is not in (0, 1].
    """

    def __init__(
        self,
        input_dim: int,
        compression_dim: int = 256,
        top_k_ratio: float = 0.5,
        hidden_ratio: float = 2.0,
        dropout_rate: float = 0.1,
        use_residual: bool = True,
        device: str = "auto",
    ):
        super().__init__()

        if not 0.0 < top_k_ratio <= 1.0:
            # Fail fast: a ratio > 1 would later make torch.topk raise with an
            # opaque error, and a ratio <= 0 would silently keep only 1 token.
            raise ValueError(f"top_k_ratio must be in (0, 1], got {top_k_ratio}")

        self.input_dim = input_dim
        self.compression_dim = compression_dim
        self.top_k_ratio = top_k_ratio
        self.use_residual = use_residual

        # Attention mechanism for token selection: scores each token with a
        # small bottleneck network (input_dim -> input_dim // 4 -> 1).
        self.attention = nn.Sequential(
            nn.Linear(input_dim, input_dim // 4),
            nn.Tanh(),
            nn.Linear(input_dim // 4, 1),
        )

        # Compression network: 2-layer MLP with LayerNorm, GELU and Dropout.
        hidden_dim = int(compression_dim * hidden_ratio)
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, compression_dim),
        )

        # Residual projection when dimensions don't match; when they do match,
        # the pooled vector is added directly (see forward()).
        if use_residual and input_dim != compression_dim:
            self.residual_proj = nn.Linear(input_dim, compression_dim)
        else:
            self.residual_proj = None

        # Resolve 'auto' device and move all parameters there.
        if device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.to(self.device)

    def _compute_attention_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Compute a scalar attention score for each token.

        Args:
            x: Input tensor [batch, seq_len, input_dim]

        Returns:
            Attention scores [batch, seq_len]
        """
        scores = self.attention(x)  # [batch, seq_len, 1]
        return scores.squeeze(-1)  # [batch, seq_len]

    def _apply_pooling(self, x: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        """Apply Top-K attention pooling to select important tokens.

        Args:
            x: Input tensor [batch, seq_len, input_dim]
            scores: Attention scores [batch, seq_len]

        Returns:
            Pooled features [batch, k, input_dim] where
            k = max(1, floor(seq_len * top_k_ratio))
        """
        batch_size, seq_len, _ = x.shape
        # int() floors; max(1, ...) guarantees at least one token survives.
        k = max(1, int(seq_len * self.top_k_ratio))

        # Indices of the k highest-scoring tokens (descending score order).
        _, top_k_indices = torch.topk(scores, k=k, dim=-1)  # [batch, k]

        # Gather the selected tokens via advanced indexing.
        batch_indices = (
            torch.arange(batch_size, device=x.device).unsqueeze(1).expand(-1, k)
        )
        return x[batch_indices, top_k_indices, :]  # [batch, k, input_dim]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compress a sequence of token features into a single vector.

        Args:
            x: Input features [batch, seq_len, input_dim]

        Returns:
            Compressed features [batch, compression_dim]
        """
        # Score tokens and keep only the Top-K most attended ones.
        scores = self._compute_attention_scores(x)
        pooled = self._apply_pooling(x, scores)

        # Average over the selected tokens: [batch, input_dim].
        pooled = pooled.mean(dim=1)

        # 2-layer MLP compression: [batch, compression_dim].
        compressed = self.net(pooled)

        # Add residual connection if enabled.
        if self.use_residual:
            if self.residual_proj is not None:
                residual = self.residual_proj(pooled)
            else:
                # residual_proj is None only when input_dim == compression_dim,
                # so this slice is in fact the full pooled vector.
                residual = pooled[:, : self.compression_dim]
            compressed = compressed + residual

        return compressed
|
||||
236
mini-nav/feature_compressor/core/extractor.py
Normal file
236
mini-nav/feature_compressor/core/extractor.py
Normal file
@@ -0,0 +1,236 @@
|
||||
"""DINOv2 feature extraction and compression pipeline."""
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
from transformers import AutoImageProcessor, AutoModel
|
||||
|
||||
from ...configs.config import Config, get_default_config
|
||||
from ..utils.image_utils import load_image, preprocess_image
|
||||
from .compressor import PoolNetCompressor
|
||||
|
||||
|
||||
class DINOv2FeatureExtractor:
    """End-to-end DINOv2 feature extraction with compression.

    Loads DINOv2 model, extracts last_hidden_state features,
    and applies PoolNetCompressor for dimensionality reduction.

    Args:
        config_path: Path to YAML configuration file
        device: Device to use ('auto', 'cpu', or 'cuda')
    """

    def __init__(self, config_path: Optional[str] = None, device: str = "auto"):
        self.config = self._load_config(config_path)

        # Resolve device: explicit argument > config value > hardware probe.
        if device == "auto":
            device = self.config.get("model", {}).get("device", "auto")
        if device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        # Load DINOv2 model and processor (inference only, eval mode).
        model_name = self.config.get("model", {}).get("name", "facebook/dinov2-large")
        self.processor = AutoImageProcessor.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(self.device)
        self.model.eval()

        # Initialize compressor; input dim comes from the loaded model config.
        model_config = self.config.get("model", {})
        self.compressor = PoolNetCompressor(
            input_dim=self.model.config.hidden_size,
            compression_dim=model_config.get("compression_dim", 256),
            top_k_ratio=model_config.get("top_k_ratio", 0.5),
            hidden_ratio=model_config.get("hidden_ratio", 2.0),
            dropout_rate=model_config.get("dropout_rate", 0.1),
            use_residual=model_config.get("use_residual", True),
            device=str(self.device),
        )

    def _load_config(self, config_path: Optional[str] = None) -> Dict:
        """Load configuration from YAML file.

        Args:
            config_path: Path to config file, or None for default

        Returns:
            Configuration dictionary
        """
        if config_path is None:
            return get_default_config(Config.FEATURE_COMPRESSOR)

        with open(config_path) as f:
            return yaml.safe_load(f)

    def _extract_dinov2_features(self, images: List) -> torch.Tensor:
        """Extract DINOv2 last_hidden_state features.

        Args:
            images: List of PIL Images

        Returns:
            last_hidden_state [batch, seq_len, hidden_dim]
        """
        with torch.no_grad():
            inputs = self.processor(images=images, return_tensors="pt").to(self.device)
            outputs = self.model(**inputs)
            features = outputs.last_hidden_state

        return features

    def _compress_features(self, features: torch.Tensor) -> torch.Tensor:
        """Compress features using PoolNetCompressor.

        Args:
            features: [batch, seq_len, hidden_dim]

        Returns:
            compressed [batch, compression_dim]
        """
        with torch.no_grad():
            compressed = self.compressor(features)

        return compressed

    def _pooled_mean(self, features: torch.Tensor) -> torch.Tensor:
        """Return the mean of the Top-K attention-pooled tokens.

        Exposes the compressor's intermediate (pre-MLP) representation
        for analysis.

        Args:
            features: [batch, seq_len, hidden_dim]

        Returns:
            Pooled features [batch, hidden_dim]
        """
        with torch.no_grad():
            scores = self.compressor._compute_attention_scores(features)
            pooled = self.compressor._apply_pooling(features, scores)
        return pooled.mean(dim=1)

    def _build_result(
        self,
        image_path,
        original_features: torch.Tensor,
        compressed_features: torch.Tensor,
        processing_time: float,
    ) -> Dict[str, object]:
        """Assemble the per-image result dictionary (tensors moved to CPU)."""
        pooled_features = self._pooled_mean(original_features)
        return {
            "original_features": original_features.cpu(),
            "compressed_features": compressed_features.cpu(),
            "pooled_features": pooled_features.cpu(),
            "metadata": {
                "image_path": str(image_path),
                # Ratio of per-token feature width to compressed width.
                "compression_ratio": original_features.shape[-1]
                / self.compressor.compression_dim,
                "processing_time": processing_time,
                "feature_norm": torch.norm(compressed_features, p=2, dim=-1)
                .mean()
                .item(),
                "device": str(self.device),
                "model_name": self.config.get("model", {}).get("name"),
            },
        }

    def process_image(
        self, image_path: str, visualize: bool = False
    ) -> Dict[str, object]:
        """Process a single image and extract compressed features.

        Args:
            image_path: Path to image file
            visualize: Whether to generate visualizations
                (NOTE(review): currently unused; kept for interface
                compatibility — confirm intended behavior)

        Returns:
            Dictionary with original_features, compressed_features,
            pooled_features, and metadata
        """
        start_time = time.time()

        # Load and preprocess image to the model's expected resolution.
        image = load_image(image_path)
        image = preprocess_image(image, size=224)

        # Extract DINOv2 features and compress them.
        original_features = self._extract_dinov2_features([image])
        compressed_features = self._compress_features(original_features)

        processing_time = time.time() - start_time
        return self._build_result(
            image_path, original_features, compressed_features, processing_time
        )

    def process_batch(
        self, image_dir: str, batch_size: int = 8, save_features: bool = True
    ) -> List[Dict[str, object]]:
        """Process multiple images in batches.

        Args:
            image_dir: Directory containing images
            batch_size: Number of images per batch
            save_features: Whether to save features to disk

        Returns:
            List of result dictionaries, one per image
        """
        image_dir = Path(image_dir)
        # NOTE(review): "*.*" matches any file with an extension, not only
        # images — non-image files will make load_image fail; confirm intent.
        image_files = sorted(image_dir.glob("*.*"))

        results = []

        if save_features:
            # Imported lazily (but once, not per image as before) so the
            # utils module is only loaded when saving is requested.
            from ..utils.feature_utils import save_features_to_json

        # Process in batches.
        for i in range(0, len(image_files), batch_size):
            batch_files = image_files[i : i + batch_size]
            batch_start = time.time()

            # Load and preprocess batch.
            images = [preprocess_image(load_image(f), size=224) for f in batch_files]

            # Extract and compress features for the whole batch at once.
            original_features = self._extract_dinov2_features(images)
            compressed_features = self._compress_features(original_features)

            # Amortize the batch's wall-clock time over its images (the
            # previous implementation hard-coded 0.0 here).
            per_image_time = (time.time() - batch_start) / len(batch_files)

            # Create individual results.
            for j, file_path in enumerate(batch_files):
                result = self._build_result(
                    file_path,
                    original_features[j : j + 1],
                    compressed_features[j : j + 1],
                    per_image_time,
                )
                results.append(result)

                # Save features if requested.
                if save_features:
                    output_dir = Path(
                        self.config.get("output", {}).get("directory", "./outputs")
                    )
                    # Resolve relative paths against the project root.
                    if not output_dir.is_absolute():
                        output_dir = Path(__file__).parent.parent.parent / output_dir
                    output_dir.mkdir(parents=True, exist_ok=True)

                    output_path = output_dir / f"{file_path.stem}_features.json"
                    save_features_to_json(
                        result["compressed_features"],
                        output_path,
                        result["metadata"],
                    )

        return results
|
||||
178
mini-nav/feature_compressor/core/visualizer.py
Normal file
178
mini-nav/feature_compressor/core/visualizer.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""Feature visualization using Plotly."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import yaml
|
||||
from plotly.graph_objs import Figure
|
||||
|
||||
from ..utils.plot_utils import (
|
||||
apply_theme,
|
||||
create_comparison_plot,
|
||||
create_histogram,
|
||||
create_pca_scatter_2d,
|
||||
save_figure,
|
||||
)
|
||||
|
||||
|
||||
class FeatureVisualizer:
    """Visualize DINOv2 features with interactive Plotly charts.

    Supports histograms, PCA projections, and feature comparisons
    with multiple export formats.

    Args:
        config_path: Path to YAML configuration file
    """

    def __init__(self, config_path: Optional[str] = None):
        self.config = self._load_config(config_path)

    def _load_config(self, config_path: Optional[str] = None) -> dict:
        """Load configuration from YAML file.

        Args:
            config_path: Path to config file, or None for default

        Returns:
            Configuration dictionary
        """
        if config_path is None:
            # Default config lives at <project_root>/configs/feature_compressor.yaml.
            config_path = (
                Path(__file__).parent.parent.parent
                / "configs"
                / "feature_compressor.yaml"
            )

        with open(config_path) as f:
            return yaml.safe_load(f)

    def _viz_config(self) -> dict:
        """Return the 'visualization' section of the config (may be empty)."""
        return self.config.get("visualization", {})

    def _finalize(self, fig: Figure, width_mult: int = 1) -> Figure:
        """Apply the configured theme and figure dimensions to a figure.

        Args:
            fig: Plotly Figure to style
            width_mult: Multiplier applied to the configured width
                (used by comparison plots to scale with panel count)

        Returns:
            The styled figure
        """
        viz_config = self._viz_config()
        fig = apply_theme(fig, viz_config.get("plot_theme", "plotly_white"))
        fig.update_layout(
            width=viz_config.get("fig_width", 900) * width_mult,
            height=viz_config.get("fig_height", 600),
        )
        return fig

    def plot_histogram(
        self, features: torch.Tensor, title: Optional[str] = None
    ) -> Figure:
        """Plot histogram of feature values.

        Args:
            features: Feature tensor [batch, dim]
            title: Plot title

        Returns:
            Plotly Figure object
        """
        fig = create_histogram(features.cpu().numpy(), title=title)
        return self._finalize(fig)

    def plot_pca_2d(
        self, features: torch.Tensor, labels: Optional[List] = None
    ) -> Figure:
        """Plot 2D PCA projection of features.

        Args:
            features: Feature tensor [n_samples, dim]
            labels: Optional labels for coloring

        Returns:
            Plotly Figure object
        """
        viz_config = self._viz_config()

        fig = create_pca_scatter_2d(features.cpu().numpy(), labels=labels)
        # Preserve the original styling order: theme first, then marker
        # overrides, then layout sizing.
        fig = apply_theme(fig, viz_config.get("plot_theme", "plotly_white"))
        fig.update_traces(
            marker=dict(
                size=viz_config.get("point_size", 8),
                colorscale=viz_config.get("color_scale", "viridis"),
            )
        )
        fig.update_layout(
            width=viz_config.get("fig_width", 900),
            height=viz_config.get("fig_height", 600),
        )

        return fig

    def plot_comparison(
        self, features_list: List[torch.Tensor], names: List[str]
    ) -> Figure:
        """Plot comparison of multiple feature sets.

        Args:
            features_list: List of feature tensors
            names: Names for each feature set

        Returns:
            Plotly Figure object
        """
        arrays = [f.cpu().numpy() for f in features_list]
        fig = create_comparison_plot(arrays, names)
        # Width scales with the number of compared feature sets.
        return self._finalize(fig, width_mult=len(features_list))

    def generate_report(self, results: List[dict], output_dir: str) -> List[str]:
        """Generate full feature analysis report.

        Args:
            results: List of extractor results
            output_dir: Directory to save visualizations

        Returns:
            List of generated file paths
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        generated_files = []

        # Concatenate all per-image compressed features into one tensor.
        all_features = torch.cat([r["compressed_features"] for r in results], dim=0)

        # Histogram of compressed feature values.
        hist_fig = self.plot_histogram(all_features, "Compressed Feature Distribution")
        hist_path = output_dir / "feature_histogram"
        self.save(hist_fig, str(hist_path), formats=["html"])
        generated_files.append(str(hist_path) + ".html")

        # 2D PCA scatter of compressed features.
        pca_fig = self.plot_pca_2d(all_features)
        pca_path = output_dir / "feature_pca_2d"
        self.save(pca_fig, str(pca_path), formats=["html", "png"])
        generated_files.append(str(pca_path) + ".html")
        generated_files.append(str(pca_path) + ".png")

        return generated_files

    def save(
        self, fig: Figure, path: str, formats: Optional[List[str]] = None
    ) -> None:
        """Save figure in multiple formats.

        Args:
            fig: Plotly Figure object
            path: Output file path (without extension)
            formats: List of formats to export (default: ["html"])
        """
        if formats is None:
            formats = ["html"]

        # The previous implementation special-cased "png" with a branch whose
        # body was identical to the else branch; one call covers all formats.
        for fmt in formats:
            save_figure(fig, path, format=fmt)
|
||||
Reference in New Issue
Block a user