feat(feature-compressor): add DINOv2 feature extraction and compression pipeline

This commit is contained in:
2026-01-31 10:33:37 +08:00
parent f9a359fc28
commit 1454647aa6
22 changed files with 1486 additions and 16 deletions

View File

@@ -0,0 +1,7 @@
"""Core compression, extraction, and visualization modules."""
from .compressor import PoolNetCompressor
from .extractor import DINOv2FeatureExtractor
from .visualizer import FeatureVisualizer
__all__ = ["PoolNetCompressor", "DINOv2FeatureExtractor", "FeatureVisualizer"]

View File

@@ -0,0 +1,135 @@
"""Feature compression module with attention-based pooling and MLP."""
import torch
import torch.nn as nn
import torch.nn.functional as F
class PoolNetCompressor(nn.Module):
"""Pool + Network feature compressor for DINOv2 embeddings.
Combines attention-based Top-K token pooling with a 2-layer MLP to compress
DINOv2's last_hidden_state from [batch, seq_len, hidden_dim] to [batch, compression_dim].
Args:
input_dim: Input feature dimension (e.g., 1024 for DINOv2-large)
compression_dim: Output feature dimension (default: 256)
top_k_ratio: Ratio of tokens to keep via attention pooling (default: 0.5)
hidden_ratio: Hidden layer dimension as multiple of compression_dim (default: 2.0)
dropout_rate: Dropout probability (default: 0.1)
use_residual: Whether to use residual connection (default: True)
device: Device to place model on ('auto', 'cpu', or 'cuda')
"""
def __init__(
self,
input_dim: int,
compression_dim: int = 256,
top_k_ratio: float = 0.5,
hidden_ratio: float = 2.0,
dropout_rate: float = 0.1,
use_residual: bool = True,
device: str = "auto",
):
super().__init__()
self.input_dim = input_dim
self.compression_dim = compression_dim
self.top_k_ratio = top_k_ratio
self.use_residual = use_residual
# Attention mechanism for token selection
self.attention = nn.Sequential(
nn.Linear(input_dim, input_dim // 4),
nn.Tanh(),
nn.Linear(input_dim // 4, 1),
)
# Compression network: 2-layer MLP
hidden_dim = int(compression_dim * hidden_ratio)
self.net = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.GELU(),
nn.Dropout(dropout_rate),
nn.Linear(hidden_dim, compression_dim),
)
# Residual projection if dimensions don't match
if use_residual and input_dim != compression_dim:
self.residual_proj = nn.Linear(input_dim, compression_dim)
else:
self.residual_proj = None
# Set device
if device == "auto":
device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = torch.device(device)
self.to(self.device)
def _compute_attention_scores(self, x: torch.Tensor) -> torch.Tensor:
"""Compute attention scores for each token.
Args:
x: Input tensor [batch, seq_len, input_dim]
Returns:
Attention scores [batch, seq_len, 1]
"""
scores = self.attention(x) # [batch, seq_len, 1]
return scores.squeeze(-1) # [batch, seq_len]
def _apply_pooling(self, x: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
"""Apply Top-K attention pooling to select important tokens.
Args:
x: Input tensor [batch, seq_len, input_dim]
scores: Attention scores [batch, seq_len]
Returns:
Pooled features [batch, k, input_dim] where k = ceil(seq_len * top_k_ratio)
"""
batch_size, seq_len, _ = x.shape
k = max(1, int(seq_len * self.top_k_ratio))
# Get top-k indices
top_k_values, top_k_indices = torch.topk(scores, k=k, dim=-1) # [batch, k]
# Select top-k tokens
batch_indices = (
torch.arange(batch_size, device=x.device).unsqueeze(1).expand(-1, k)
)
pooled = x[batch_indices, top_k_indices, :] # [batch, k, input_dim]
return pooled
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass through compressor.
Args:
x: Input features [batch, seq_len, input_dim]
Returns:
Compressed features [batch, compression_dim]
"""
# Compute attention scores
scores = self._compute_attention_scores(x)
# Apply Top-K pooling
pooled = self._apply_pooling(x, scores)
# Average pool over selected tokens to get [batch, input_dim]
pooled = pooled.mean(dim=1) # [batch, input_dim]
# Apply compression network
compressed = self.net(pooled) # [batch, compression_dim]
# Add residual connection if enabled
if self.use_residual:
if self.residual_proj is not None:
residual = self.residual_proj(pooled)
else:
residual = pooled[:, : self.compression_dim]
compressed = compressed + residual
return compressed

View File

@@ -0,0 +1,236 @@
"""DINOv2 feature extraction and compression pipeline."""
import time
from pathlib import Path
from typing import Dict, List, Optional
import torch
import yaml
from transformers import AutoImageProcessor, AutoModel
from ...configs.config import Config, get_default_config
from ..utils.image_utils import load_image, preprocess_image
from .compressor import PoolNetCompressor
class DINOv2FeatureExtractor:
"""End-to-end DINOv2 feature extraction with compression.
Loads DINOv2 model, extracts last_hidden_state features,
and applies PoolNetCompressor for dimensionality reduction.
Args:
config_path: Path to YAML configuration file
device: Device to use ('auto', 'cpu', or 'cuda')
"""
def __init__(self, config_path: Optional[str] = None, device: str = "auto"):
self.config = self._load_config(config_path)
# Set device
if device == "auto":
device = self.config.get("model", {}).get("device", "auto")
if device == "auto":
device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = torch.device(device)
# Load DINOv2 model and processor
model_name = self.config.get("model", {}).get("name", "facebook/dinov2-large")
self.processor = AutoImageProcessor.from_pretrained(model_name)
self.model = AutoModel.from_pretrained(model_name).to(self.device)
self.model.eval()
# Initialize compressor
model_config = self.config.get("model", {})
self.compressor = PoolNetCompressor(
input_dim=self.model.config.hidden_size,
compression_dim=model_config.get("compression_dim", 256),
top_k_ratio=model_config.get("top_k_ratio", 0.5),
hidden_ratio=model_config.get("hidden_ratio", 2.0),
dropout_rate=model_config.get("dropout_rate", 0.1),
use_residual=model_config.get("use_residual", True),
device=str(self.device),
)
def _load_config(self, config_path: Optional[str] = None) -> Dict:
"""Load configuration from YAML file.
Args:
config_path: Path to config file, or None for default
Returns:
Configuration dictionary
"""
if config_path is None:
return get_default_config(Config.FEATURE_COMPRESSOR)
with open(config_path) as f:
return yaml.safe_load(f)
def _extract_dinov2_features(self, images: List) -> torch.Tensor:
"""Extract DINOv2 last_hidden_state features.
Args:
images: List of PIL Images
Returns:
last_hidden_state [batch, seq_len, hidden_dim]
"""
with torch.no_grad():
inputs = self.processor(images=images, return_tensors="pt").to(self.device)
outputs = self.model(**inputs)
features = outputs.last_hidden_state
return features
def _compress_features(self, features: torch.Tensor) -> torch.Tensor:
"""Compress features using PoolNetCompressor.
Args:
features: [batch, seq_len, hidden_dim]
Returns:
compressed [batch, compression_dim]
"""
with torch.no_grad():
compressed = self.compressor(features)
return compressed
def process_image(
self, image_path: str, visualize: bool = False
) -> Dict[str, object]:
"""Process a single image and extract compressed features.
Args:
image_path: Path to image file
visualize: Whether to generate visualizations
Returns:
Dictionary with original_features, compressed_features, metadata
"""
start_time = time.time()
# Load and preprocess image
image = load_image(image_path)
image = preprocess_image(image, size=224)
# Extract DINOv2 features
original_features = self._extract_dinov2_features([image])
# Compute feature stats for compression ratio
original_dim = original_features.shape[-1]
compressed_dim = self.compressor.compression_dim
compression_ratio = original_dim / compressed_dim
# Compress features
compressed_features = self._compress_features(original_features)
# Get pooled features (before compression) for analysis
pooled_features = self.compressor._apply_pooling(
original_features,
self.compressor._compute_attention_scores(original_features),
)
pooled_features = pooled_features.mean(dim=1)
# Compute feature norm
feature_norm = torch.norm(compressed_features, p=2, dim=-1).mean().item()
processing_time = time.time() - start_time
# Build result dictionary
result = {
"original_features": original_features.cpu(),
"compressed_features": compressed_features.cpu(),
"pooled_features": pooled_features.cpu(),
"metadata": {
"image_path": str(image_path),
"compression_ratio": compression_ratio,
"processing_time": processing_time,
"feature_norm": feature_norm,
"device": str(self.device),
"model_name": self.config.get("model", {}).get("name"),
},
}
return result
def process_batch(
self, image_dir: str, batch_size: int = 8, save_features: bool = True
) -> List[Dict[str, object]]:
"""Process multiple images in batches.
Args:
image_dir: Directory containing images
batch_size: Number of images per batch
save_features: Whether to save features to disk
Returns:
List of result dictionaries, one per image
"""
image_dir = Path(image_dir)
image_files = sorted(image_dir.glob("*.*"))
results = []
# Process in batches
for i in range(0, len(image_files), batch_size):
batch_files = image_files[i : i + batch_size]
# Load and preprocess batch
images = [preprocess_image(load_image(f), size=224) for f in batch_files]
# Extract features for batch
original_features = self._extract_dinov2_features(images)
compressed_features = self._compress_features(original_features)
# Create individual results
for j, file_path in enumerate(batch_files):
pooled_features = self.compressor._apply_pooling(
original_features[j : j + 1],
self.compressor._compute_attention_scores(
original_features[j : j + 1]
),
).mean(dim=1)
result = {
"original_features": original_features[j : j + 1].cpu(),
"compressed_features": compressed_features[j : j + 1].cpu(),
"pooled_features": pooled_features.cpu(),
"metadata": {
"image_path": str(file_path),
"compression_ratio": original_features.shape[-1]
/ self.compressor.compression_dim,
"processing_time": 0.0,
"feature_norm": torch.norm(
compressed_features[j : j + 1], p=2, dim=-1
)
.mean()
.item(),
"device": str(self.device),
"model_name": self.config.get("model", {}).get("name"),
},
}
results.append(result)
# Save features if requested
if save_features:
output_dir = Path(
self.config.get("output", {}).get("directory", "./outputs")
)
# Resolve relative to project root
if not output_dir.is_absolute():
output_dir = Path(__file__).parent.parent.parent / output_dir
output_dir.mkdir(parents=True, exist_ok=True)
output_path = output_dir / f"{file_path.stem}_features.json"
from ..utils.feature_utils import save_features_to_json
save_features_to_json(
result["compressed_features"],
output_path,
result["metadata"],
)
return results

View File

@@ -0,0 +1,178 @@
"""Feature visualization using Plotly."""
import os
from pathlib import Path
from typing import List, Optional
import numpy as np
import torch
import yaml
from plotly.graph_objs import Figure
from ..utils.plot_utils import (
apply_theme,
create_comparison_plot,
create_histogram,
create_pca_scatter_2d,
save_figure,
)
class FeatureVisualizer:
"""Visualize DINOv2 features with interactive Plotly charts.
Supports histograms, PCA projections, and feature comparisons
with multiple export formats.
Args:
config_path: Path to YAML configuration file
"""
def __init__(self, config_path: Optional[str] = None):
self.config = self._load_config(config_path)
def _load_config(self, config_path: Optional[str] = None) -> dict:
"""Load configuration from YAML file.
Args:
config_path: Path to config file, or None for default
Returns:
Configuration dictionary
"""
if config_path is None:
config_path = (
Path(__file__).parent.parent.parent
/ "configs"
/ "feature_compressor.yaml"
)
with open(config_path) as f:
return yaml.safe_load(f)
def plot_histogram(self, features: torch.Tensor, title: str = None) -> object:
"""Plot histogram of feature values.
Args:
features: Feature tensor [batch, dim]
title: Plot title
Returns:
Plotly Figure object
"""
features_np = features.cpu().numpy()
fig = create_histogram(features_np, title=title)
viz_config = self.config.get("visualization", {})
fig = apply_theme(fig, viz_config.get("plot_theme", "plotly_white"))
fig.update_layout(
width=viz_config.get("fig_width", 900),
height=viz_config.get("fig_height", 600),
)
return fig
def plot_pca_2d(self, features: torch.Tensor, labels: List = None) -> Figure:
"""Plot 2D PCA projection of features.
Args:
features: Feature tensor [n_samples, dim]
labels: Optional labels for coloring
Returns:
Plotly Figure object
"""
features_np = features.cpu().numpy()
viz_config = self.config.get("visualization", {})
fig = create_pca_scatter_2d(features_np, labels=labels)
fig = apply_theme(fig, viz_config.get("plot_theme", "plotly_white"))
fig.update_traces(
marker=dict(
size=viz_config.get("point_size", 8),
colorscale=viz_config.get("color_scale", "viridis"),
)
)
fig.update_layout(
width=viz_config.get("fig_width", 900),
height=viz_config.get("fig_height", 600),
)
return fig
def plot_comparison(
self, features_list: List[torch.Tensor], names: List[str]
) -> object:
"""Plot comparison of multiple feature sets.
Args:
features_list: List of feature tensors
names: Names for each feature set
Returns:
Plotly Figure object
"""
features_np_list = [f.cpu().numpy() for f in features_list]
fig = create_comparison_plot(features_np_list, names)
viz_config = self.config.get("visualization", {})
fig = apply_theme(fig, viz_config.get("plot_theme", "plotly_white"))
fig.update_layout(
width=viz_config.get("fig_width", 900) * len(features_list),
height=viz_config.get("fig_height", 600),
)
return fig
def generate_report(self, results: List[dict], output_dir: str) -> List[str]:
"""Generate full feature analysis report.
Args:
results: List of extractor results
output_dir: Directory to save visualizations
Returns:
List of generated file paths
"""
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
generated_files = []
# Extract all compressed features
all_features = torch.cat([r["compressed_features"] for r in results], dim=0)
# Create histogram
hist_fig = self.plot_histogram(all_features, "Compressed Feature Distribution")
hist_path = output_dir / "feature_histogram"
self.save(hist_fig, str(hist_path), formats=["html"])
generated_files.append(str(hist_path) + ".html")
# Create PCA
pca_fig = self.plot_pca_2d(all_features)
pca_path = output_dir / "feature_pca_2d"
self.save(pca_fig, str(pca_path), formats=["html", "png"])
generated_files.append(str(pca_path) + ".html")
generated_files.append(str(pca_path) + ".png")
return generated_files
def save(self, fig: object, path: str, formats: List[str] = None) -> None:
"""Save figure in multiple formats.
Args:
fig: Plotly Figure object
path: Output file path (without extension)
formats: List of formats to export
"""
if formats is None:
formats = ["html"]
output_config = self.config.get("output", {})
for fmt in formats:
if fmt == "png":
save_figure(fig, path, format="png")
else:
save_figure(fig, path, format=fmt)