feat(visualizer): integrate image upload with similarity search

This commit is contained in:
2026-02-05 21:41:39 +08:00
parent a0df45ab05
commit e859fef2b3
5 changed files with 89 additions and 287 deletions

View File

@@ -1,4 +1,5 @@
from typing import Optional from typing import Optional
import lancedb import lancedb
import pyarrow as pa import pyarrow as pa
from configs import cfg_manager from configs import cfg_manager
@@ -31,7 +32,7 @@ class DatabaseManager:
# 初始化数据库与表格 # 初始化数据库与表格
self.db = lancedb.connect(db_path) self.db = lancedb.connect(db_path)
if "default" not in self.db.table_names(): if "default" not in self.db.list_tables().tables:
self.table = self.db.create_table("default", schema=db_schema) self.table = self.db.create_table("default", schema=db_schema)
else: else:
self.table = self.db.open_table("default") self.table = self.db.open_table("default")

View File

@@ -1,181 +0,0 @@
"""Feature visualization using Plotly."""
import os
from pathlib import Path
from typing import List, Optional, Union
import numpy as np
import torch
from configs import FeatureCompressorConfig, cfg_manager, load_yaml
from plotly.graph_objs import Figure
from ..utils.plot_utils import (
apply_theme,
create_comparison_plot,
create_histogram,
create_pca_scatter_2d,
save_figure,
)
class FeatureVisualizer:
"""Visualize DINOv2 features with interactive Plotly charts.
Supports histograms, PCA projections, and feature comparisons
with multiple export formats.
Args:
config_path: Path to YAML configuration file
"""
def __init__(self, config_path: Optional[str] = None):
self.config: FeatureCompressorConfig = self._load_config(config_path)
def _load_config(
self, config_path: Optional[str] = None
) -> FeatureCompressorConfig:
"""Load configuration from YAML file.
Args:
config_path: Path to config file, or None for default
Returns:
Configuration Pydantic model
"""
if config_path is None:
return cfg_manager.get()
else:
return load_yaml(Path(config_path), FeatureCompressorConfig)
def plot_histogram(
self, features: torch.Tensor, title: Optional[str] = None
) -> Figure:
"""Plot histogram of feature values.
Args:
features: Feature tensor [batch, dim]
title: Plot title
Returns:
Plotly Figure object
"""
features_np = features.cpu().numpy()
fig = create_histogram(
features_np, title="Feature Histogram" if title is None else title
)
fig = apply_theme(fig, self.config.visualization.plot_theme)
fig.update_layout(
width=self.config.visualization.fig_width,
height=self.config.visualization.fig_height,
)
return fig
def plot_pca_2d(
self, features: torch.Tensor, labels: Optional[List] = None
) -> Figure:
"""Plot 2D PCA projection of features.
Args:
features: Feature tensor [n_samples, dim]
labels: Optional labels for coloring
Returns:
Plotly Figure object
"""
features_np = features.cpu().numpy()
fig = create_pca_scatter_2d(
features_np,
labels=[i for i in range(len(features_np))] if labels is None else labels,
)
fig = apply_theme(fig, self.config.visualization.plot_theme)
fig.update_traces(
marker=dict(
size=self.config.visualization.point_size,
colorscale=self.config.visualization.color_scale,
)
)
fig.update_layout(
width=self.config.visualization.fig_width,
height=self.config.visualization.fig_height,
)
return fig
def plot_comparison(
self, features_list: List[torch.Tensor], names: List[str]
) -> object:
"""Plot comparison of multiple feature sets.
Args:
features_list: List of feature tensors
names: Names for each feature set
Returns:
Plotly Figure object
"""
features_np_list = [f.cpu().numpy() for f in features_list]
fig = create_comparison_plot(features_np_list, names)
fig = apply_theme(fig, self.config.visualization.plot_theme)
fig.update_layout(
width=self.config.visualization.fig_width * len(features_list),
height=self.config.visualization.fig_height,
)
return fig
def generate_report(
self, results: List[dict], output_dir: Union[str, Path]
) -> List[str]:
"""Generate full feature analysis report.
Args:
results: List of extractor results
output_dir: Directory to save visualizations
Returns:
List of generated file paths
"""
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
generated_files = []
# Extract all compressed features
all_features = torch.cat([r["compressed_features"] for r in results], dim=0)
# Create histogram
hist_fig = self.plot_histogram(all_features, "Compressed Feature Distribution")
hist_path = output_dir / "feature_histogram"
self.save(hist_fig, str(hist_path), formats=["html"])
generated_files.append(str(hist_path) + ".html")
# Create PCA
pca_fig = self.plot_pca_2d(all_features)
pca_path = output_dir / "feature_pca_2d"
self.save(pca_fig, str(pca_path), formats=["html", "png"])
generated_files.append(str(pca_path) + ".html")
generated_files.append(str(pca_path) + ".png")
return generated_files
def save(self, fig: Figure, path: str, formats: List[str]) -> None:
"""Save figure in multiple formats.
Args:
fig: Plotly Figure object
path: Output file path (without extension)
formats: List of formats to export
"""
if formats is None:
formats = ["html"]
for fmt in formats:
if fmt == "png":
save_figure(fig, path, format="png")
else:
save_figure(fig, path, format=fmt)

View File

@@ -1,6 +1,5 @@
from typing import Any, Dict, List, Optional, Union, cast from typing import Any, Dict, List, Optional, Union, cast
import polars as pl
import torch import torch
from database import db_manager from database import db_manager
from datasets import load_dataset from datasets import load_dataset
@@ -104,7 +103,9 @@ class FeatureRetrieval:
) )
@torch.no_grad() @torch.no_grad()
def extract_single_image_feature(self, image: Union[Image.Image, Any]) -> pl.Series: def extract_single_image_feature(
self, image: Union[Image.Image, Any]
) -> List[float]:
"""Extract feature from a single image without storing to database. """Extract feature from a single image without storing to database.
Args: Args:
@@ -128,8 +129,8 @@ class FeatureRetrieval:
cls_token = feats[:, 0] # [1, D] cls_token = feats[:, 0] # [1, D]
cls_token = cast(torch.Tensor, cls_token) cls_token = cast(torch.Tensor, cls_token)
# 返回 Polars Series # 返回 CLS List
return pl.Series("feature", cls_token.cpu().squeeze(0).tolist()) return cls_token.cpu().squeeze(0).tolist()
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -1,97 +1,35 @@
"""Tests for FeatureVisualizer module.""" """Tests for visualizer app image upload similarity search."""
import os import base64
import tempfile import io
import numpy as np import numpy as np
import pytest from PIL import Image
import torch
from feature_compressor.core.visualizer import FeatureVisualizer
class TestFeatureVisualizer: class TestImageUploadSimilaritySearch:
"""Test suite for FeatureVisualizer class.""" """Test suite for image upload similarity search functionality."""
def test_histogram_generation(self): def test_base64_to_pil_image(self):
"""Test histogram generation from features.""" """Test conversion from base64 string to PIL Image."""
# Create a test image
img_array = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
img = Image.fromarray(img_array)
viz = FeatureVisualizer() # Convert to base64
features = torch.randn(20, 256) buffer = io.BytesIO()
img.save(buffer, format="PNG")
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
fig = viz.plot_histogram(features, title="Test Histogram") # Add data URI prefix (as Dash provides)
img_base64_with_prefix = f"data:image/png;base64,{img_base64}"
assert fig is not None # Parse base64 to PIL Image
assert "Test Histogram" in fig.layout.title.text # Remove prefix
base64_str = img_base64_with_prefix.split(",")[1]
img_bytes = base64.b64decode(base64_str)
parsed_img = Image.open(io.BytesIO(img_bytes))
def test_pca_2d_generation(self): # Verify the image is valid
"""Test PCA 2D scatter plot generation.""" assert parsed_img.size == (224, 224)
assert parsed_img.mode == "RGB"
viz = FeatureVisualizer()
features = torch.randn(20, 256)
labels = ["cat"] * 10 + ["dog"] * 10
fig = viz.plot_pca_2d(features, labels=labels)
assert fig is not None
assert "PCA 2D" in fig.layout.title.text
def test_comparison_plot_generation(self):
"""Test comparison plot generation."""
viz = FeatureVisualizer()
features_list = [torch.randn(20, 256), torch.randn(20, 256)]
names = ["Set A", "Set B"]
fig = viz.plot_comparison(features_list, names)
assert fig is not None
assert "Comparison" in fig.layout.title.text
def test_html_export(self):
"""Test HTML export format."""
viz = FeatureVisualizer()
features = torch.randn(10, 256)
fig = viz.plot_histogram(features)
with tempfile.TemporaryDirectory() as tmpdir:
output_path = os.path.join(tmpdir, "test_plot")
viz.save(fig, output_path, formats=["html"])
assert os.path.exists(output_path + ".html")
def test_png_export(self):
"""Test PNG export format."""
viz = FeatureVisualizer()
features = torch.randn(10, 256)
fig = viz.plot_histogram(features)
with tempfile.TemporaryDirectory() as tmpdir:
output_path = os.path.join(tmpdir, "test_plot")
# Skip PNG export if Chrome not available
try:
viz.save(fig, output_path, formats=["png"])
assert os.path.exists(output_path + ".png")
except RuntimeError as e:
if "Chrome" in str(e):
pass
else:
raise
def test_json_export(self):
"""Test JSON export format."""
viz = FeatureVisualizer()
features = torch.randn(10, 256)
fig = viz.plot_histogram(features)
with tempfile.TemporaryDirectory() as tmpdir:
output_path = os.path.join(tmpdir, "test_plot")
viz.save(fig, output_path, formats=["json"])
assert os.path.exists(output_path + ".json")

View File

@@ -1,10 +1,15 @@
import base64
import datetime import datetime
from typing import List, Optional, Union import io
from typing import List, Optional
import dash_ag_grid as dag import dash_ag_grid as dag
import dash_mantine_components as dmc import dash_mantine_components as dmc
from dash import Dash, Input, Output, State, callback, dcc, html from dash import Dash, Input, Output, State, callback, dcc, html
from database import db_manager from database import db_manager
from feature_retrieval import FeatureRetrieval
from PIL import Image
from transformers import AutoImageProcessor, AutoModel
class APP(Dash): class APP(Dash):
@@ -12,6 +17,9 @@ class APP(Dash):
_instance: Optional["APP"] = None _instance: Optional["APP"] = None
# Feature retrieval singleton
_feature_retrieval: FeatureRetrieval
def __new__(cls) -> "APP": def __new__(cls) -> "APP":
if cls._instance is None: if cls._instance is None:
cls._instance = super().__new__(cls) cls._instance = super().__new__(cls)
@@ -20,6 +28,11 @@ class APP(Dash):
def __init__(self): def __init__(self):
super().__init__(__name__) super().__init__(__name__)
# Initialize FeatureRetrieval
processor = AutoImageProcessor.from_pretrained("facebook/dinov2-large")
model = AutoModel.from_pretrained("facebook/dinov2-large")
APP._feature_retrieval = FeatureRetrieval(processor, model)
df = ( df = (
db_manager.table.search() db_manager.table.search()
.select(["id", "label", "vector"]) .select(["id", "label", "vector"])
@@ -56,6 +69,7 @@ class APP(Dash):
), ),
html.Div(id="output-image-upload"), html.Div(id="output-image-upload"),
dag.AgGrid( dag.AgGrid(
id="ag-grid",
rowData=df.to_dicts(), rowData=df.to_dicts(),
columnDefs=columnDefs, columnDefs=columnDefs,
), ),
@@ -71,6 +85,7 @@ class APP(Dash):
@callback( @callback(
Output("output-image-upload", "children"), Output("output-image-upload", "children"),
Output("ag-grid", "rowData"),
Input("upload-image", "contents"), Input("upload-image", "contents"),
State("upload-image", "filename"), State("upload-image", "filename"),
State("upload-image", "last_modified"), State("upload-image", "last_modified"),
@@ -80,23 +95,51 @@ class APP(Dash):
list_of_names: List[str], list_of_names: List[str],
list_of_dates: List[int] | List[float], list_of_dates: List[int] | List[float],
): ):
def parse_contents(contents: str, filename: str, date: Union[int, float]): def parse_base64_to_pil(contents: str) -> Image.Image:
return html.Div( """Parse base64 string to PIL Image."""
[ # Remove data URI prefix (e.g., "data:image/png;base64,")
html.H5(filename), base64_str = contents.split(",")[1]
html.H6(datetime.datetime.fromtimestamp(date)), img_bytes = base64.b64decode(base64_str)
# HTML images accept base64 encoded strings in the same format return Image.open(io.BytesIO(img_bytes))
# that is supplied by the upload
dmc.Image(src=contents),
]
)
if list_of_contents is not None: if list_of_contents is not None:
# Process first uploaded image for similarity search
filename = list_of_names[0]
uploaddate = list_of_dates[0]
imagecontent = list_of_contents[0]
pil_image = parse_base64_to_pil(imagecontent)
# Extract feature vector using DINOv2
feature_vector = APP._feature_retrieval.extract_single_image_feature(
pil_image
)
# Search for similar images in database
results_df = (
db_manager.table.search(feature_vector)
.select(["id", "label", "vector"])
.limit(10)
.to_polars()
)
# Convert to AgGrid row format
row_data = results_df.to_dicts()
# Display uploaded images
children = [ children = [
parse_contents(c, n, d) html.H5(filename),
for c, n, d in zip(list_of_contents, list_of_names, list_of_dates) html.H6(str(datetime.datetime.fromtimestamp(uploaddate))),
# HTML images accept base64 encoded strings in same format
# that is supplied by the upload
dmc.Image(src=imagecontent),
dmc.Text(f"{feature_vector[:5]}", size="xs"),
] ]
return children
return children, row_data
# Return empty if no content
return [], []
app = APP() app = APP()