diff --git a/mini-nav/database.py b/mini-nav/database.py index 4e743cd..eccb8a2 100644 --- a/mini-nav/database.py +++ b/mini-nav/database.py @@ -1,4 +1,5 @@ from typing import Optional + import lancedb import pyarrow as pa from configs import cfg_manager @@ -31,7 +32,7 @@ class DatabaseManager: # 初始化数据库与表格 self.db = lancedb.connect(db_path) - if "default" not in self.db.table_names(): + if "default" not in self.db.list_tables().tables: self.table = self.db.create_table("default", schema=db_schema) else: self.table = self.db.open_table("default") diff --git a/mini-nav/feature_compressor/core/visualizer.py b/mini-nav/feature_compressor/core/visualizer.py deleted file mode 100644 index 8ceefa2..0000000 --- a/mini-nav/feature_compressor/core/visualizer.py +++ /dev/null @@ -1,181 +0,0 @@ -"""Feature visualization using Plotly.""" - -import os -from pathlib import Path -from typing import List, Optional, Union - -import numpy as np -import torch -from configs import FeatureCompressorConfig, cfg_manager, load_yaml -from plotly.graph_objs import Figure - -from ..utils.plot_utils import ( - apply_theme, - create_comparison_plot, - create_histogram, - create_pca_scatter_2d, - save_figure, -) - - -class FeatureVisualizer: - """Visualize DINOv2 features with interactive Plotly charts. - - Supports histograms, PCA projections, and feature comparisons - with multiple export formats. - - Args: - config_path: Path to YAML configuration file - """ - - def __init__(self, config_path: Optional[str] = None): - self.config: FeatureCompressorConfig = self._load_config(config_path) - - def _load_config( - self, config_path: Optional[str] = None - ) -> FeatureCompressorConfig: - """Load configuration from YAML file. - - Args: - config_path: Path to config file, or None for default - - Returns: - Configuration Pydantic model - """ - if config_path is None: - return cfg_manager.get() - else: - return load_yaml(Path(config_path), FeatureCompressorConfig) - - def plot_histogram( - self, features: torch.Tensor, title: Optional[str] = None - ) -> Figure: - """Plot histogram of feature values. - - Args: - features: Feature tensor [batch, dim] - title: Plot title - - Returns: - Plotly Figure object - """ - features_np = features.cpu().numpy() - fig = create_histogram( - features_np, title="Feature Histogram" if title is None else title - ) - - fig = apply_theme(fig, self.config.visualization.plot_theme) - fig.update_layout( - width=self.config.visualization.fig_width, - height=self.config.visualization.fig_height, - ) - - return fig - - def plot_pca_2d( - self, features: torch.Tensor, labels: Optional[List] = None - ) -> Figure: - """Plot 2D PCA projection of features. - - Args: - features: Feature tensor [n_samples, dim] - labels: Optional labels for coloring - - Returns: - Plotly Figure object - """ - features_np = features.cpu().numpy() - - fig = create_pca_scatter_2d( - features_np, - labels=[i for i in range(len(features_np))] if labels is None else labels, - ) - fig = apply_theme(fig, self.config.visualization.plot_theme) - fig.update_traces( - marker=dict( - size=self.config.visualization.point_size, - colorscale=self.config.visualization.color_scale, - ) - ) - fig.update_layout( - width=self.config.visualization.fig_width, - height=self.config.visualization.fig_height, - ) - - return fig - - def plot_comparison( - self, features_list: List[torch.Tensor], names: List[str] - ) -> object: - """Plot comparison of multiple feature sets. - - Args: - features_list: List of feature tensors - names: Names for each feature set - - Returns: - Plotly Figure object - """ - features_np_list = [f.cpu().numpy() for f in features_list] - - fig = create_comparison_plot(features_np_list, names) - - fig = apply_theme(fig, self.config.visualization.plot_theme) - fig.update_layout( - width=self.config.visualization.fig_width * len(features_list), - height=self.config.visualization.fig_height, - ) - - return fig - - def generate_report( - self, results: List[dict], output_dir: Union[str, Path] - ) -> List[str]: - """Generate full feature analysis report. - - Args: - results: List of extractor results - output_dir: Directory to save visualizations - - Returns: - List of generated file paths - """ - output_dir = Path(output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - - generated_files = [] - - # Extract all compressed features - all_features = torch.cat([r["compressed_features"] for r in results], dim=0) - - # Create histogram - hist_fig = self.plot_histogram(all_features, "Compressed Feature Distribution") - hist_path = output_dir / "feature_histogram" - self.save(hist_fig, str(hist_path), formats=["html"]) - generated_files.append(str(hist_path) + ".html") - - # Create PCA - pca_fig = self.plot_pca_2d(all_features) - pca_path = output_dir / "feature_pca_2d" - self.save(pca_fig, str(pca_path), formats=["html", "png"]) - generated_files.append(str(pca_path) + ".html") - generated_files.append(str(pca_path) + ".png") - - return generated_files - - def save(self, fig: Figure, path: str, formats: List[str]) -> None: - """Save figure in multiple formats. - - Args: - fig: Plotly Figure object - path: Output file path (without extension) - formats: List of formats to export - """ - if formats is None: - formats = ["html"] - - for fmt in formats: - if fmt == "png": - save_figure(fig, path, format="png") - else: - save_figure(fig, path, format=fmt) diff --git a/mini-nav/feature_retrieval.py b/mini-nav/feature_retrieval.py index cb12b60..f4cf345 100644 --- a/mini-nav/feature_retrieval.py +++ b/mini-nav/feature_retrieval.py @@ -1,6 +1,5 @@ from typing import Any, Dict, List, Optional, Union, cast -import polars as pl import torch from database import db_manager from datasets import load_dataset @@ -104,7 +103,9 @@ class FeatureRetrieval: ) @torch.no_grad() - def extract_single_image_feature(self, image: Union[Image.Image, Any]) -> pl.Series: + def extract_single_image_feature( + self, image: Union[Image.Image, Any] + ) -> List[float]: """Extract feature from a single image without storing to database. Args: @@ -128,8 +129,8 @@ class FeatureRetrieval: cls_token = feats[:, 0] # [1, D] cls_token = cast(torch.Tensor, cls_token) - # 返回 Polars Series - return pl.Series("feature", cls_token.cpu().squeeze(0).tolist()) + # 返回 CLS List + return cls_token.cpu().squeeze(0).tolist() if __name__ == "__main__": diff --git a/mini-nav/tests/test_visualizer.py b/mini-nav/tests/test_visualizer.py index 7f29587..e6283d8 100644 --- a/mini-nav/tests/test_visualizer.py +++ b/mini-nav/tests/test_visualizer.py @@ -1,97 +1,35 @@ -"""Tests for FeatureVisualizer module.""" +"""Tests for visualizer app image upload similarity search.""" -import os -import tempfile +import base64 +import io import numpy as np -import pytest -import torch -from feature_compressor.core.visualizer import FeatureVisualizer +from PIL import Image -class TestFeatureVisualizer: - """Test suite for FeatureVisualizer class.""" +class TestImageUploadSimilaritySearch: + """Test suite for image upload similarity search functionality.""" - def test_histogram_generation(self): - """Test histogram generation from features.""" + def test_base64_to_pil_image(self): + """Test conversion from base64 string to PIL Image.""" + # Create a test image + img_array = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8) + img = Image.fromarray(img_array) - viz = FeatureVisualizer() - features = torch.randn(20, 256) + # Convert to base64 + buffer = io.BytesIO() + img.save(buffer, format="PNG") + img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8") - fig = viz.plot_histogram(features, title="Test Histogram") + # Add data URI prefix (as Dash provides) + img_base64_with_prefix = f"data:image/png;base64,{img_base64}" - assert fig is not None - assert "Test Histogram" in fig.layout.title.text + # Parse base64 to PIL Image + # Remove prefix + base64_str = img_base64_with_prefix.split(",")[1] + img_bytes = base64.b64decode(base64_str) + parsed_img = Image.open(io.BytesIO(img_bytes)) - def test_pca_2d_generation(self): - """Test PCA 2D scatter plot generation.""" - - viz = FeatureVisualizer() - features = torch.randn(20, 256) - labels = ["cat"] * 10 + ["dog"] * 10 - - fig = viz.plot_pca_2d(features, labels=labels) - - assert fig is not None - assert "PCA 2D" in fig.layout.title.text - - def test_comparison_plot_generation(self): - """Test comparison plot generation.""" - - viz = FeatureVisualizer() - features_list = [torch.randn(20, 256), torch.randn(20, 256)] - names = ["Set A", "Set B"] - - fig = viz.plot_comparison(features_list, names) - - assert fig is not None - assert "Comparison" in fig.layout.title.text - - def test_html_export(self): - """Test HTML export format.""" - - viz = FeatureVisualizer() - features = torch.randn(10, 256) - - fig = viz.plot_histogram(features) - - with tempfile.TemporaryDirectory() as tmpdir: - output_path = os.path.join(tmpdir, "test_plot") - viz.save(fig, output_path, formats=["html"]) - - assert os.path.exists(output_path + ".html") - - def test_png_export(self): - """Test PNG export format.""" - - viz = FeatureVisualizer() - features = torch.randn(10, 256) - - fig = viz.plot_histogram(features) - - with tempfile.TemporaryDirectory() as tmpdir: - output_path = os.path.join(tmpdir, "test_plot") - - # Skip PNG export if Chrome not available - try: - viz.save(fig, output_path, formats=["png"]) - assert os.path.exists(output_path + ".png") - except RuntimeError as e: - if "Chrome" in str(e): - pass - else: - raise - - def test_json_export(self): - """Test JSON export format.""" - - viz = FeatureVisualizer() - features = torch.randn(10, 256) - - fig = viz.plot_histogram(features) - - with tempfile.TemporaryDirectory() as tmpdir: - output_path = os.path.join(tmpdir, "test_plot") - viz.save(fig, output_path, formats=["json"]) - - assert os.path.exists(output_path + ".json") + # Verify the image is valid + assert parsed_img.size == (224, 224) + assert parsed_img.mode == "RGB" diff --git a/mini-nav/visualizer/app.py b/mini-nav/visualizer/app.py index dc8bc6c..313125a 100644 --- a/mini-nav/visualizer/app.py +++ b/mini-nav/visualizer/app.py @@ -1,10 +1,15 @@ +import base64 import datetime -from typing import List, Optional, Union +import io +from typing import List, Optional import dash_ag_grid as dag import dash_mantine_components as dmc from dash import Dash, Input, Output, State, callback, dcc, html from database import db_manager +from feature_retrieval import FeatureRetrieval +from PIL import Image +from transformers import AutoImageProcessor, AutoModel class APP(Dash): @@ -12,6 +17,9 @@ class APP(Dash): _instance: Optional["APP"] = None + # Feature retrieval singleton + _feature_retrieval: FeatureRetrieval + def __new__(cls) -> "APP": if cls._instance is None: cls._instance = super().__new__(cls) @@ -20,6 +28,11 @@ class APP(Dash): def __init__(self): super().__init__(__name__) + # Initialize FeatureRetrieval + processor = AutoImageProcessor.from_pretrained("facebook/dinov2-large") + model = AutoModel.from_pretrained("facebook/dinov2-large") + APP._feature_retrieval = FeatureRetrieval(processor, model) + df = ( db_manager.table.search() .select(["id", "label", "vector"]) @@ -56,6 +69,7 @@ class APP(Dash): ), html.Div(id="output-image-upload"), dag.AgGrid( + id="ag-grid", rowData=df.to_dicts(), columnDefs=columnDefs, ), @@ -71,6 +85,7 @@ class APP(Dash): @callback( Output("output-image-upload", "children"), + Output("ag-grid", "rowData"), Input("upload-image", "contents"), State("upload-image", "filename"), State("upload-image", "last_modified"), @@ -80,23 +95,51 @@ class APP(Dash): list_of_names: List[str], list_of_dates: List[int] | List[float], ): - def parse_contents(contents: str, filename: str, date: Union[int, float]): - return html.Div( - [ - html.H5(filename), - html.H6(datetime.datetime.fromtimestamp(date)), - # HTML images accept base64 encoded strings in the same format - # that is supplied by the upload - dmc.Image(src=contents), - ] - ) + def parse_base64_to_pil(contents: str) -> Image.Image: + """Parse base64 string to PIL Image.""" + # Remove data URI prefix (e.g., "data:image/png;base64,") + base64_str = contents.split(",")[1] + img_bytes = base64.b64decode(base64_str) + return Image.open(io.BytesIO(img_bytes)) if list_of_contents is not None: + # Process first uploaded image for similarity search + filename = list_of_names[0] + uploaddate = list_of_dates[0] + imagecontent = list_of_contents[0] + + pil_image = parse_base64_to_pil(imagecontent) + + # Extract feature vector using DINOv2 + feature_vector = APP._feature_retrieval.extract_single_image_feature( + pil_image + ) + + # Search for similar images in database + results_df = ( + db_manager.table.search(feature_vector) + .select(["id", "label", "vector"]) + .limit(10) + .to_polars() + ) + + # Convert to AgGrid row format + row_data = results_df.to_dicts() + + # Display uploaded images children = [ - parse_contents(c, n, d) - for c, n, d in zip(list_of_contents, list_of_names, list_of_dates) + html.H5(filename), + html.H6(str(datetime.datetime.fromtimestamp(uploaddate))), + # HTML images accept base64 encoded strings in same format + # that is supplied by the upload + dmc.Image(src=imagecontent), + dmc.Text(f"{feature_vector[:5]}", size="xs"), ] - return children + + return children, row_data + + # Return empty if no content + return [], [] app = APP()