mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
feat(visualizer): integrate image upload with similarity search
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import lancedb
|
import lancedb
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
from configs import cfg_manager
|
from configs import cfg_manager
|
||||||
@@ -31,7 +32,7 @@ class DatabaseManager:
|
|||||||
|
|
||||||
# 初始化数据库与表格
|
# 初始化数据库与表格
|
||||||
self.db = lancedb.connect(db_path)
|
self.db = lancedb.connect(db_path)
|
||||||
if "default" not in self.db.table_names():
|
if "default" not in self.db.list_tables().tables:
|
||||||
self.table = self.db.create_table("default", schema=db_schema)
|
self.table = self.db.create_table("default", schema=db_schema)
|
||||||
else:
|
else:
|
||||||
self.table = self.db.open_table("default")
|
self.table = self.db.open_table("default")
|
||||||
|
|||||||
@@ -1,181 +0,0 @@
|
|||||||
"""Feature visualization using Plotly."""
|
|
||||||
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import List, Optional, Union
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from configs import FeatureCompressorConfig, cfg_manager, load_yaml
|
|
||||||
from plotly.graph_objs import Figure
|
|
||||||
|
|
||||||
from ..utils.plot_utils import (
|
|
||||||
apply_theme,
|
|
||||||
create_comparison_plot,
|
|
||||||
create_histogram,
|
|
||||||
create_pca_scatter_2d,
|
|
||||||
save_figure,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class FeatureVisualizer:
    """Visualize DINOv2 features with interactive Plotly charts.

    Supports histograms, PCA projections, and feature comparisons
    with multiple export formats.

    Args:
        config_path: Path to YAML configuration file. When ``None`` the
            configuration returned by ``cfg_manager.get()`` is used.
    """

    def __init__(self, config_path: Optional[str] = None):
        self.config: FeatureCompressorConfig = self._load_config(config_path)

    @staticmethod
    def _to_numpy(features: torch.Tensor) -> np.ndarray:
        """Return *features* as a NumPy array on the CPU.

        ``detach()`` is applied first because ``Tensor.numpy()`` refuses
        tensors that are still attached to an autograd graph.
        """
        return features.detach().cpu().numpy()

    def _apply_figure_size(self, fig: Figure, width_factor: int = 1) -> None:
        """Set the configured width/height on *fig* (width scaled by *width_factor*)."""
        vis_cfg = self.config.visualization
        fig.update_layout(
            width=vis_cfg.fig_width * width_factor,
            height=vis_cfg.fig_height,
        )

    def _load_config(
        self, config_path: Optional[str] = None
    ) -> FeatureCompressorConfig:
        """Load configuration from YAML file.

        Args:
            config_path: Path to config file, or None to use the
                configuration held by ``cfg_manager``.

        Returns:
            Configuration Pydantic model
        """
        if config_path is None:
            return cfg_manager.get()
        return load_yaml(Path(config_path), FeatureCompressorConfig)

    def plot_histogram(
        self, features: torch.Tensor, title: Optional[str] = None
    ) -> Figure:
        """Plot histogram of feature values.

        Args:
            features: Feature tensor [batch, dim]
            title: Plot title; defaults to "Feature Histogram"

        Returns:
            Plotly Figure object, themed and sized from the configuration
        """
        fig = create_histogram(
            self._to_numpy(features),
            title="Feature Histogram" if title is None else title,
        )
        fig = apply_theme(fig, self.config.visualization.plot_theme)
        self._apply_figure_size(fig)
        return fig

    def plot_pca_2d(
        self, features: torch.Tensor, labels: Optional[List] = None
    ) -> Figure:
        """Plot 2D PCA projection of features.

        Args:
            features: Feature tensor [n_samples, dim]
            labels: Optional labels for coloring; when omitted each sample
                is labelled with its own row index

        Returns:
            Plotly Figure object, themed and sized from the configuration
        """
        features_np = self._to_numpy(features)
        if labels is None:
            labels = list(range(len(features_np)))

        fig = create_pca_scatter_2d(features_np, labels=labels)
        fig = apply_theme(fig, self.config.visualization.plot_theme)

        vis_cfg = self.config.visualization
        fig.update_traces(
            marker=dict(
                size=vis_cfg.point_size,
                colorscale=vis_cfg.color_scale,
            )
        )
        self._apply_figure_size(fig)
        return fig

    def plot_comparison(
        self, features_list: List[torch.Tensor], names: List[str]
    ) -> Figure:
        """Plot comparison of multiple feature sets.

        Args:
            features_list: List of feature tensors
            names: Names for each feature set

        Returns:
            Plotly Figure object; its width is the configured figure width
            multiplied by the number of feature sets
        """
        features_np_list = [self._to_numpy(f) for f in features_list]

        fig = create_comparison_plot(features_np_list, names)
        fig = apply_theme(fig, self.config.visualization.plot_theme)
        self._apply_figure_size(fig, width_factor=len(features_list))
        return fig

    def generate_report(
        self, results: List[dict], output_dir: Union[str, Path]
    ) -> List[str]:
        """Generate full feature analysis report.

        Writes a histogram (HTML) and a 2D PCA scatter (HTML and PNG) of the
        concatenated ``"compressed_features"`` entries of *results*.

        Args:
            results: List of extractor results; each must provide a
                ``"compressed_features"`` tensor
            output_dir: Directory to save visualizations (created if missing)

        Returns:
            List of generated file paths
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        generated_files: List[str] = []

        def export(fig: Figure, stem: str, formats: List[str]) -> None:
            # Record exactly the files that were requested from save(), so the
            # returned list cannot drift from the formats actually written.
            # NOTE(review): assumes save_figure appends ".<format>" to the
            # extension-less path - confirm against plot_utils.save_figure.
            base = str(output_dir / stem)
            self.save(fig, base, formats=formats)
            generated_files.extend(f"{base}.{fmt}" for fmt in formats)

        # Extract all compressed features
        all_features = torch.cat([r["compressed_features"] for r in results], dim=0)

        export(
            self.plot_histogram(all_features, "Compressed Feature Distribution"),
            "feature_histogram",
            ["html"],
        )
        export(self.plot_pca_2d(all_features), "feature_pca_2d", ["html", "png"])

        return generated_files

    def save(
        self, fig: Figure, path: str, formats: Optional[List[str]] = None
    ) -> None:
        """Save figure in multiple formats.

        Args:
            fig: Plotly Figure object
            path: Output file path (without extension)
            formats: List of formats to export; ``None`` means ``["html"]``
        """
        if formats is None:
            formats = ["html"]

        # Every format goes through the same helper; the former special case
        # for "png" made the identical call.
        for fmt in formats:
            save_figure(fig, path, format=fmt)
|
|
||||||
@@ -1,6 +1,5 @@
|
|||||||
from typing import Any, Dict, List, Optional, Union, cast
|
from typing import Any, Dict, List, Optional, Union, cast
|
||||||
|
|
||||||
import polars as pl
|
|
||||||
import torch
|
import torch
|
||||||
from database import db_manager
|
from database import db_manager
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
@@ -104,7 +103,9 @@ class FeatureRetrieval:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def extract_single_image_feature(self, image: Union[Image.Image, Any]) -> pl.Series:
|
def extract_single_image_feature(
|
||||||
|
self, image: Union[Image.Image, Any]
|
||||||
|
) -> List[float]:
|
||||||
"""Extract feature from a single image without storing to database.
|
"""Extract feature from a single image without storing to database.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -128,8 +129,8 @@ class FeatureRetrieval:
|
|||||||
cls_token = feats[:, 0] # [1, D]
|
cls_token = feats[:, 0] # [1, D]
|
||||||
cls_token = cast(torch.Tensor, cls_token)
|
cls_token = cast(torch.Tensor, cls_token)
|
||||||
|
|
||||||
# 返回 Polars Series
|
# 返回 CLS List
|
||||||
return pl.Series("feature", cls_token.cpu().squeeze(0).tolist())
|
return cls_token.cpu().squeeze(0).tolist()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -1,97 +1,35 @@
|
|||||||
"""Tests for FeatureVisualizer module."""
|
"""Tests for visualizer app image upload similarity search."""
|
||||||
|
|
||||||
import os
|
import base64
|
||||||
import tempfile
|
import io
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
from PIL import Image
|
||||||
import torch
|
|
||||||
from feature_compressor.core.visualizer import FeatureVisualizer
|
|
||||||
|
|
||||||
|
|
||||||
class TestFeatureVisualizer:
|
class TestImageUploadSimilaritySearch:
|
||||||
"""Test suite for FeatureVisualizer class."""
|
"""Test suite for image upload similarity search functionality."""
|
||||||
|
|
||||||
def test_histogram_generation(self):
|
def test_base64_to_pil_image(self):
|
||||||
"""Test histogram generation from features."""
|
"""Test conversion from base64 string to PIL Image."""
|
||||||
|
# Create a test image
|
||||||
|
img_array = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
|
||||||
|
img = Image.fromarray(img_array)
|
||||||
|
|
||||||
viz = FeatureVisualizer()
|
# Convert to base64
|
||||||
features = torch.randn(20, 256)
|
buffer = io.BytesIO()
|
||||||
|
img.save(buffer, format="PNG")
|
||||||
|
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||||
|
|
||||||
fig = viz.plot_histogram(features, title="Test Histogram")
|
# Add data URI prefix (as Dash provides)
|
||||||
|
img_base64_with_prefix = f"data:image/png;base64,{img_base64}"
|
||||||
|
|
||||||
assert fig is not None
|
# Parse base64 to PIL Image
|
||||||
assert "Test Histogram" in fig.layout.title.text
|
# Remove prefix
|
||||||
|
base64_str = img_base64_with_prefix.split(",")[1]
|
||||||
|
img_bytes = base64.b64decode(base64_str)
|
||||||
|
parsed_img = Image.open(io.BytesIO(img_bytes))
|
||||||
|
|
||||||
def test_pca_2d_generation(self):
|
# Verify the image is valid
|
||||||
"""Test PCA 2D scatter plot generation."""
|
assert parsed_img.size == (224, 224)
|
||||||
|
assert parsed_img.mode == "RGB"
|
||||||
viz = FeatureVisualizer()
|
|
||||||
features = torch.randn(20, 256)
|
|
||||||
labels = ["cat"] * 10 + ["dog"] * 10
|
|
||||||
|
|
||||||
fig = viz.plot_pca_2d(features, labels=labels)
|
|
||||||
|
|
||||||
assert fig is not None
|
|
||||||
assert "PCA 2D" in fig.layout.title.text
|
|
||||||
|
|
||||||
def test_comparison_plot_generation(self):
|
|
||||||
"""Test comparison plot generation."""
|
|
||||||
|
|
||||||
viz = FeatureVisualizer()
|
|
||||||
features_list = [torch.randn(20, 256), torch.randn(20, 256)]
|
|
||||||
names = ["Set A", "Set B"]
|
|
||||||
|
|
||||||
fig = viz.plot_comparison(features_list, names)
|
|
||||||
|
|
||||||
assert fig is not None
|
|
||||||
assert "Comparison" in fig.layout.title.text
|
|
||||||
|
|
||||||
def test_html_export(self):
|
|
||||||
"""Test HTML export format."""
|
|
||||||
|
|
||||||
viz = FeatureVisualizer()
|
|
||||||
features = torch.randn(10, 256)
|
|
||||||
|
|
||||||
fig = viz.plot_histogram(features)
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
|
||||||
output_path = os.path.join(tmpdir, "test_plot")
|
|
||||||
viz.save(fig, output_path, formats=["html"])
|
|
||||||
|
|
||||||
assert os.path.exists(output_path + ".html")
|
|
||||||
|
|
||||||
def test_png_export(self):
|
|
||||||
"""Test PNG export format."""
|
|
||||||
|
|
||||||
viz = FeatureVisualizer()
|
|
||||||
features = torch.randn(10, 256)
|
|
||||||
|
|
||||||
fig = viz.plot_histogram(features)
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
|
||||||
output_path = os.path.join(tmpdir, "test_plot")
|
|
||||||
|
|
||||||
# Skip PNG export if Chrome not available
|
|
||||||
try:
|
|
||||||
viz.save(fig, output_path, formats=["png"])
|
|
||||||
assert os.path.exists(output_path + ".png")
|
|
||||||
except RuntimeError as e:
|
|
||||||
if "Chrome" in str(e):
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
raise
|
|
||||||
|
|
||||||
def test_json_export(self):
|
|
||||||
"""Test JSON export format."""
|
|
||||||
|
|
||||||
viz = FeatureVisualizer()
|
|
||||||
features = torch.randn(10, 256)
|
|
||||||
|
|
||||||
fig = viz.plot_histogram(features)
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
|
||||||
output_path = os.path.join(tmpdir, "test_plot")
|
|
||||||
viz.save(fig, output_path, formats=["json"])
|
|
||||||
|
|
||||||
assert os.path.exists(output_path + ".json")
|
|
||||||
|
|||||||
@@ -1,10 +1,15 @@
|
|||||||
|
import base64
|
||||||
import datetime
|
import datetime
|
||||||
from typing import List, Optional, Union
|
import io
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
import dash_ag_grid as dag
|
import dash_ag_grid as dag
|
||||||
import dash_mantine_components as dmc
|
import dash_mantine_components as dmc
|
||||||
from dash import Dash, Input, Output, State, callback, dcc, html
|
from dash import Dash, Input, Output, State, callback, dcc, html
|
||||||
from database import db_manager
|
from database import db_manager
|
||||||
|
from feature_retrieval import FeatureRetrieval
|
||||||
|
from PIL import Image
|
||||||
|
from transformers import AutoImageProcessor, AutoModel
|
||||||
|
|
||||||
|
|
||||||
class APP(Dash):
|
class APP(Dash):
|
||||||
@@ -12,6 +17,9 @@ class APP(Dash):
|
|||||||
|
|
||||||
_instance: Optional["APP"] = None
|
_instance: Optional["APP"] = None
|
||||||
|
|
||||||
|
# Feature retrieval singleton
|
||||||
|
_feature_retrieval: FeatureRetrieval
|
||||||
|
|
||||||
def __new__(cls) -> "APP":
|
def __new__(cls) -> "APP":
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
cls._instance = super().__new__(cls)
|
cls._instance = super().__new__(cls)
|
||||||
@@ -20,6 +28,11 @@ class APP(Dash):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(__name__)
|
super().__init__(__name__)
|
||||||
|
|
||||||
|
# Initialize FeatureRetrieval
|
||||||
|
processor = AutoImageProcessor.from_pretrained("facebook/dinov2-large")
|
||||||
|
model = AutoModel.from_pretrained("facebook/dinov2-large")
|
||||||
|
APP._feature_retrieval = FeatureRetrieval(processor, model)
|
||||||
|
|
||||||
df = (
|
df = (
|
||||||
db_manager.table.search()
|
db_manager.table.search()
|
||||||
.select(["id", "label", "vector"])
|
.select(["id", "label", "vector"])
|
||||||
@@ -56,6 +69,7 @@ class APP(Dash):
|
|||||||
),
|
),
|
||||||
html.Div(id="output-image-upload"),
|
html.Div(id="output-image-upload"),
|
||||||
dag.AgGrid(
|
dag.AgGrid(
|
||||||
|
id="ag-grid",
|
||||||
rowData=df.to_dicts(),
|
rowData=df.to_dicts(),
|
||||||
columnDefs=columnDefs,
|
columnDefs=columnDefs,
|
||||||
),
|
),
|
||||||
@@ -71,6 +85,7 @@ class APP(Dash):
|
|||||||
|
|
||||||
@callback(
|
@callback(
|
||||||
Output("output-image-upload", "children"),
|
Output("output-image-upload", "children"),
|
||||||
|
Output("ag-grid", "rowData"),
|
||||||
Input("upload-image", "contents"),
|
Input("upload-image", "contents"),
|
||||||
State("upload-image", "filename"),
|
State("upload-image", "filename"),
|
||||||
State("upload-image", "last_modified"),
|
State("upload-image", "last_modified"),
|
||||||
@@ -80,23 +95,51 @@ class APP(Dash):
|
|||||||
list_of_names: List[str],
|
list_of_names: List[str],
|
||||||
list_of_dates: List[int] | List[float],
|
list_of_dates: List[int] | List[float],
|
||||||
):
|
):
|
||||||
def parse_contents(contents: str, filename: str, date: Union[int, float]):
|
def parse_base64_to_pil(contents: str) -> Image.Image:
|
||||||
return html.Div(
|
"""Parse base64 string to PIL Image."""
|
||||||
[
|
# Remove data URI prefix (e.g., "data:image/png;base64,")
|
||||||
html.H5(filename),
|
base64_str = contents.split(",")[1]
|
||||||
html.H6(datetime.datetime.fromtimestamp(date)),
|
img_bytes = base64.b64decode(base64_str)
|
||||||
# HTML images accept base64 encoded strings in the same format
|
return Image.open(io.BytesIO(img_bytes))
|
||||||
# that is supplied by the upload
|
|
||||||
dmc.Image(src=contents),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
if list_of_contents is not None:
|
if list_of_contents is not None:
|
||||||
|
# Process first uploaded image for similarity search
|
||||||
|
filename = list_of_names[0]
|
||||||
|
uploaddate = list_of_dates[0]
|
||||||
|
imagecontent = list_of_contents[0]
|
||||||
|
|
||||||
|
pil_image = parse_base64_to_pil(imagecontent)
|
||||||
|
|
||||||
|
# Extract feature vector using DINOv2
|
||||||
|
feature_vector = APP._feature_retrieval.extract_single_image_feature(
|
||||||
|
pil_image
|
||||||
|
)
|
||||||
|
|
||||||
|
# Search for similar images in database
|
||||||
|
results_df = (
|
||||||
|
db_manager.table.search(feature_vector)
|
||||||
|
.select(["id", "label", "vector"])
|
||||||
|
.limit(10)
|
||||||
|
.to_polars()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert to AgGrid row format
|
||||||
|
row_data = results_df.to_dicts()
|
||||||
|
|
||||||
|
# Display uploaded images
|
||||||
children = [
|
children = [
|
||||||
parse_contents(c, n, d)
|
html.H5(filename),
|
||||||
for c, n, d in zip(list_of_contents, list_of_names, list_of_dates)
|
html.H6(str(datetime.datetime.fromtimestamp(uploaddate))),
|
||||||
|
# HTML images accept base64 encoded strings in same format
|
||||||
|
# that is supplied by the upload
|
||||||
|
dmc.Image(src=imagecontent),
|
||||||
|
dmc.Text(f"{feature_vector[:5]}", size="xs"),
|
||||||
]
|
]
|
||||||
return children
|
|
||||||
|
return children, row_data
|
||||||
|
|
||||||
|
# Return empty if no content
|
||||||
|
return [], []
|
||||||
|
|
||||||
|
|
||||||
app = APP()
|
app = APP()
|
||||||
|
|||||||
Reference in New Issue
Block a user