mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
feat(visualizer): integrate image upload with similarity search
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
import lancedb
|
||||
import pyarrow as pa
|
||||
from configs import cfg_manager
|
||||
@@ -31,7 +32,7 @@ class DatabaseManager:
|
||||
|
||||
# 初始化数据库与表格
|
||||
self.db = lancedb.connect(db_path)
|
||||
if "default" not in self.db.table_names():
|
||||
if "default" not in self.db.list_tables().tables:
|
||||
self.table = self.db.create_table("default", schema=db_schema)
|
||||
else:
|
||||
self.table = self.db.open_table("default")
|
||||
|
||||
@@ -1,181 +0,0 @@
|
||||
"""Feature visualization using Plotly."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from configs import FeatureCompressorConfig, cfg_manager, load_yaml
|
||||
from plotly.graph_objs import Figure
|
||||
|
||||
from ..utils.plot_utils import (
|
||||
apply_theme,
|
||||
create_comparison_plot,
|
||||
create_histogram,
|
||||
create_pca_scatter_2d,
|
||||
save_figure,
|
||||
)
|
||||
|
||||
|
||||
class FeatureVisualizer:
|
||||
"""Visualize DINOv2 features with interactive Plotly charts.
|
||||
|
||||
Supports histograms, PCA projections, and feature comparisons
|
||||
with multiple export formats.
|
||||
|
||||
Args:
|
||||
config_path: Path to YAML configuration file
|
||||
"""
|
||||
|
||||
def __init__(self, config_path: Optional[str] = None):
|
||||
self.config: FeatureCompressorConfig = self._load_config(config_path)
|
||||
|
||||
def _load_config(
|
||||
self, config_path: Optional[str] = None
|
||||
) -> FeatureCompressorConfig:
|
||||
"""Load configuration from YAML file.
|
||||
|
||||
Args:
|
||||
config_path: Path to config file, or None for default
|
||||
|
||||
Returns:
|
||||
Configuration Pydantic model
|
||||
"""
|
||||
if config_path is None:
|
||||
return cfg_manager.get()
|
||||
else:
|
||||
return load_yaml(Path(config_path), FeatureCompressorConfig)
|
||||
|
||||
def plot_histogram(
|
||||
self, features: torch.Tensor, title: Optional[str] = None
|
||||
) -> Figure:
|
||||
"""Plot histogram of feature values.
|
||||
|
||||
Args:
|
||||
features: Feature tensor [batch, dim]
|
||||
title: Plot title
|
||||
|
||||
Returns:
|
||||
Plotly Figure object
|
||||
"""
|
||||
features_np = features.cpu().numpy()
|
||||
fig = create_histogram(
|
||||
features_np, title="Feature Histogram" if title is None else title
|
||||
)
|
||||
|
||||
fig = apply_theme(fig, self.config.visualization.plot_theme)
|
||||
fig.update_layout(
|
||||
width=self.config.visualization.fig_width,
|
||||
height=self.config.visualization.fig_height,
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
def plot_pca_2d(
|
||||
self, features: torch.Tensor, labels: Optional[List] = None
|
||||
) -> Figure:
|
||||
"""Plot 2D PCA projection of features.
|
||||
|
||||
Args:
|
||||
features: Feature tensor [n_samples, dim]
|
||||
labels: Optional labels for coloring
|
||||
|
||||
Returns:
|
||||
Plotly Figure object
|
||||
"""
|
||||
features_np = features.cpu().numpy()
|
||||
|
||||
fig = create_pca_scatter_2d(
|
||||
features_np,
|
||||
labels=[i for i in range(len(features_np))] if labels is None else labels,
|
||||
)
|
||||
fig = apply_theme(fig, self.config.visualization.plot_theme)
|
||||
fig.update_traces(
|
||||
marker=dict(
|
||||
size=self.config.visualization.point_size,
|
||||
colorscale=self.config.visualization.color_scale,
|
||||
)
|
||||
)
|
||||
fig.update_layout(
|
||||
width=self.config.visualization.fig_width,
|
||||
height=self.config.visualization.fig_height,
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
def plot_comparison(
|
||||
self, features_list: List[torch.Tensor], names: List[str]
|
||||
) -> object:
|
||||
"""Plot comparison of multiple feature sets.
|
||||
|
||||
Args:
|
||||
features_list: List of feature tensors
|
||||
names: Names for each feature set
|
||||
|
||||
Returns:
|
||||
Plotly Figure object
|
||||
"""
|
||||
features_np_list = [f.cpu().numpy() for f in features_list]
|
||||
|
||||
fig = create_comparison_plot(features_np_list, names)
|
||||
|
||||
fig = apply_theme(fig, self.config.visualization.plot_theme)
|
||||
fig.update_layout(
|
||||
width=self.config.visualization.fig_width * len(features_list),
|
||||
height=self.config.visualization.fig_height,
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
def generate_report(
|
||||
self, results: List[dict], output_dir: Union[str, Path]
|
||||
) -> List[str]:
|
||||
"""Generate full feature analysis report.
|
||||
|
||||
Args:
|
||||
results: List of extractor results
|
||||
output_dir: Directory to save visualizations
|
||||
|
||||
Returns:
|
||||
List of generated file paths
|
||||
"""
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
generated_files = []
|
||||
|
||||
# Extract all compressed features
|
||||
all_features = torch.cat([r["compressed_features"] for r in results], dim=0)
|
||||
|
||||
# Create histogram
|
||||
hist_fig = self.plot_histogram(all_features, "Compressed Feature Distribution")
|
||||
hist_path = output_dir / "feature_histogram"
|
||||
self.save(hist_fig, str(hist_path), formats=["html"])
|
||||
generated_files.append(str(hist_path) + ".html")
|
||||
|
||||
# Create PCA
|
||||
pca_fig = self.plot_pca_2d(all_features)
|
||||
pca_path = output_dir / "feature_pca_2d"
|
||||
self.save(pca_fig, str(pca_path), formats=["html", "png"])
|
||||
generated_files.append(str(pca_path) + ".html")
|
||||
generated_files.append(str(pca_path) + ".png")
|
||||
|
||||
return generated_files
|
||||
|
||||
def save(self, fig: Figure, path: str, formats: List[str]) -> None:
|
||||
"""Save figure in multiple formats.
|
||||
|
||||
Args:
|
||||
fig: Plotly Figure object
|
||||
path: Output file path (without extension)
|
||||
formats: List of formats to export
|
||||
"""
|
||||
if formats is None:
|
||||
formats = ["html"]
|
||||
|
||||
for fmt in formats:
|
||||
if fmt == "png":
|
||||
save_figure(fig, path, format="png")
|
||||
else:
|
||||
save_figure(fig, path, format=fmt)
|
||||
@@ -1,6 +1,5 @@
|
||||
from typing import Any, Dict, List, Optional, Union, cast
|
||||
|
||||
import polars as pl
|
||||
import torch
|
||||
from database import db_manager
|
||||
from datasets import load_dataset
|
||||
@@ -104,7 +103,9 @@ class FeatureRetrieval:
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def extract_single_image_feature(self, image: Union[Image.Image, Any]) -> pl.Series:
|
||||
def extract_single_image_feature(
|
||||
self, image: Union[Image.Image, Any]
|
||||
) -> List[float]:
|
||||
"""Extract feature from a single image without storing to database.
|
||||
|
||||
Args:
|
||||
@@ -128,8 +129,8 @@ class FeatureRetrieval:
|
||||
cls_token = feats[:, 0] # [1, D]
|
||||
cls_token = cast(torch.Tensor, cls_token)
|
||||
|
||||
# 返回 Polars Series
|
||||
return pl.Series("feature", cls_token.cpu().squeeze(0).tolist())
|
||||
# 返回 CLS List
|
||||
return cls_token.cpu().squeeze(0).tolist()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,97 +1,35 @@
|
||||
"""Tests for FeatureVisualizer module."""
|
||||
"""Tests for visualizer app image upload similarity search."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import base64
|
||||
import io
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from feature_compressor.core.visualizer import FeatureVisualizer
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class TestFeatureVisualizer:
|
||||
"""Test suite for FeatureVisualizer class."""
|
||||
class TestImageUploadSimilaritySearch:
|
||||
"""Test suite for image upload similarity search functionality."""
|
||||
|
||||
def test_histogram_generation(self):
|
||||
"""Test histogram generation from features."""
|
||||
def test_base64_to_pil_image(self):
|
||||
"""Test conversion from base64 string to PIL Image."""
|
||||
# Create a test image
|
||||
img_array = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
|
||||
img = Image.fromarray(img_array)
|
||||
|
||||
viz = FeatureVisualizer()
|
||||
features = torch.randn(20, 256)
|
||||
# Convert to base64
|
||||
buffer = io.BytesIO()
|
||||
img.save(buffer, format="PNG")
|
||||
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
fig = viz.plot_histogram(features, title="Test Histogram")
|
||||
# Add data URI prefix (as Dash provides)
|
||||
img_base64_with_prefix = f"data:image/png;base64,{img_base64}"
|
||||
|
||||
assert fig is not None
|
||||
assert "Test Histogram" in fig.layout.title.text
|
||||
# Parse base64 to PIL Image
|
||||
# Remove prefix
|
||||
base64_str = img_base64_with_prefix.split(",")[1]
|
||||
img_bytes = base64.b64decode(base64_str)
|
||||
parsed_img = Image.open(io.BytesIO(img_bytes))
|
||||
|
||||
def test_pca_2d_generation(self):
|
||||
"""Test PCA 2D scatter plot generation."""
|
||||
|
||||
viz = FeatureVisualizer()
|
||||
features = torch.randn(20, 256)
|
||||
labels = ["cat"] * 10 + ["dog"] * 10
|
||||
|
||||
fig = viz.plot_pca_2d(features, labels=labels)
|
||||
|
||||
assert fig is not None
|
||||
assert "PCA 2D" in fig.layout.title.text
|
||||
|
||||
def test_comparison_plot_generation(self):
|
||||
"""Test comparison plot generation."""
|
||||
|
||||
viz = FeatureVisualizer()
|
||||
features_list = [torch.randn(20, 256), torch.randn(20, 256)]
|
||||
names = ["Set A", "Set B"]
|
||||
|
||||
fig = viz.plot_comparison(features_list, names)
|
||||
|
||||
assert fig is not None
|
||||
assert "Comparison" in fig.layout.title.text
|
||||
|
||||
def test_html_export(self):
|
||||
"""Test HTML export format."""
|
||||
|
||||
viz = FeatureVisualizer()
|
||||
features = torch.randn(10, 256)
|
||||
|
||||
fig = viz.plot_histogram(features)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
output_path = os.path.join(tmpdir, "test_plot")
|
||||
viz.save(fig, output_path, formats=["html"])
|
||||
|
||||
assert os.path.exists(output_path + ".html")
|
||||
|
||||
def test_png_export(self):
|
||||
"""Test PNG export format."""
|
||||
|
||||
viz = FeatureVisualizer()
|
||||
features = torch.randn(10, 256)
|
||||
|
||||
fig = viz.plot_histogram(features)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
output_path = os.path.join(tmpdir, "test_plot")
|
||||
|
||||
# Skip PNG export if Chrome not available
|
||||
try:
|
||||
viz.save(fig, output_path, formats=["png"])
|
||||
assert os.path.exists(output_path + ".png")
|
||||
except RuntimeError as e:
|
||||
if "Chrome" in str(e):
|
||||
pass
|
||||
else:
|
||||
raise
|
||||
|
||||
def test_json_export(self):
|
||||
"""Test JSON export format."""
|
||||
|
||||
viz = FeatureVisualizer()
|
||||
features = torch.randn(10, 256)
|
||||
|
||||
fig = viz.plot_histogram(features)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
output_path = os.path.join(tmpdir, "test_plot")
|
||||
viz.save(fig, output_path, formats=["json"])
|
||||
|
||||
assert os.path.exists(output_path + ".json")
|
||||
# Verify the image is valid
|
||||
assert parsed_img.size == (224, 224)
|
||||
assert parsed_img.mode == "RGB"
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
import base64
|
||||
import datetime
|
||||
from typing import List, Optional, Union
|
||||
import io
|
||||
from typing import List, Optional
|
||||
|
||||
import dash_ag_grid as dag
|
||||
import dash_mantine_components as dmc
|
||||
from dash import Dash, Input, Output, State, callback, dcc, html
|
||||
from database import db_manager
|
||||
from feature_retrieval import FeatureRetrieval
|
||||
from PIL import Image
|
||||
from transformers import AutoImageProcessor, AutoModel
|
||||
|
||||
|
||||
class APP(Dash):
|
||||
@@ -12,6 +17,9 @@ class APP(Dash):
|
||||
|
||||
_instance: Optional["APP"] = None
|
||||
|
||||
# Feature retrieval singleton
|
||||
_feature_retrieval: FeatureRetrieval
|
||||
|
||||
def __new__(cls) -> "APP":
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
@@ -20,6 +28,11 @@ class APP(Dash):
|
||||
def __init__(self):
|
||||
super().__init__(__name__)
|
||||
|
||||
# Initialize FeatureRetrieval
|
||||
processor = AutoImageProcessor.from_pretrained("facebook/dinov2-large")
|
||||
model = AutoModel.from_pretrained("facebook/dinov2-large")
|
||||
APP._feature_retrieval = FeatureRetrieval(processor, model)
|
||||
|
||||
df = (
|
||||
db_manager.table.search()
|
||||
.select(["id", "label", "vector"])
|
||||
@@ -56,6 +69,7 @@ class APP(Dash):
|
||||
),
|
||||
html.Div(id="output-image-upload"),
|
||||
dag.AgGrid(
|
||||
id="ag-grid",
|
||||
rowData=df.to_dicts(),
|
||||
columnDefs=columnDefs,
|
||||
),
|
||||
@@ -71,6 +85,7 @@ class APP(Dash):
|
||||
|
||||
@callback(
|
||||
Output("output-image-upload", "children"),
|
||||
Output("ag-grid", "rowData"),
|
||||
Input("upload-image", "contents"),
|
||||
State("upload-image", "filename"),
|
||||
State("upload-image", "last_modified"),
|
||||
@@ -80,23 +95,51 @@ class APP(Dash):
|
||||
list_of_names: List[str],
|
||||
list_of_dates: List[int] | List[float],
|
||||
):
|
||||
def parse_contents(contents: str, filename: str, date: Union[int, float]):
|
||||
return html.Div(
|
||||
[
|
||||
html.H5(filename),
|
||||
html.H6(datetime.datetime.fromtimestamp(date)),
|
||||
# HTML images accept base64 encoded strings in the same format
|
||||
# that is supplied by the upload
|
||||
dmc.Image(src=contents),
|
||||
]
|
||||
)
|
||||
def parse_base64_to_pil(contents: str) -> Image.Image:
|
||||
"""Parse base64 string to PIL Image."""
|
||||
# Remove data URI prefix (e.g., "data:image/png;base64,")
|
||||
base64_str = contents.split(",")[1]
|
||||
img_bytes = base64.b64decode(base64_str)
|
||||
return Image.open(io.BytesIO(img_bytes))
|
||||
|
||||
if list_of_contents is not None:
|
||||
# Process first uploaded image for similarity search
|
||||
filename = list_of_names[0]
|
||||
uploaddate = list_of_dates[0]
|
||||
imagecontent = list_of_contents[0]
|
||||
|
||||
pil_image = parse_base64_to_pil(imagecontent)
|
||||
|
||||
# Extract feature vector using DINOv2
|
||||
feature_vector = APP._feature_retrieval.extract_single_image_feature(
|
||||
pil_image
|
||||
)
|
||||
|
||||
# Search for similar images in database
|
||||
results_df = (
|
||||
db_manager.table.search(feature_vector)
|
||||
.select(["id", "label", "vector"])
|
||||
.limit(10)
|
||||
.to_polars()
|
||||
)
|
||||
|
||||
# Convert to AgGrid row format
|
||||
row_data = results_df.to_dicts()
|
||||
|
||||
# Display uploaded images
|
||||
children = [
|
||||
parse_contents(c, n, d)
|
||||
for c, n, d in zip(list_of_contents, list_of_names, list_of_dates)
|
||||
html.H5(filename),
|
||||
html.H6(str(datetime.datetime.fromtimestamp(uploaddate))),
|
||||
# HTML images accept base64 encoded strings in same format
|
||||
# that is supplied by the upload
|
||||
dmc.Image(src=imagecontent),
|
||||
dmc.Text(f"{feature_vector[:5]}", size="xs"),
|
||||
]
|
||||
return children
|
||||
|
||||
return children, row_data
|
||||
|
||||
# Return empty if no content
|
||||
return [], []
|
||||
|
||||
|
||||
app = APP()
|
||||
|
||||
Reference in New Issue
Block a user