feat(visualizer): integrate image upload with similarity search

This commit is contained in:
2026-02-05 21:41:39 +08:00
parent a0df45ab05
commit e859fef2b3
5 changed files with 89 additions and 287 deletions

View File

@@ -1,97 +1,35 @@
"""Tests for FeatureVisualizer module."""
"""Tests for visualizer app image upload similarity search."""
import os
import tempfile
import base64
import io
import numpy as np
import pytest
import torch
from feature_compressor.core.visualizer import FeatureVisualizer
from PIL import Image
class TestFeatureVisualizer:
"""Test suite for FeatureVisualizer class."""
class TestImageUploadSimilaritySearch:
"""Test suite for image upload similarity search functionality."""
def test_histogram_generation(self):
"""Test histogram generation from features."""
def test_base64_to_pil_image(self):
"""Test conversion from base64 string to PIL Image."""
# Create a test image
img_array = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
img = Image.fromarray(img_array)
viz = FeatureVisualizer()
features = torch.randn(20, 256)
# Convert to base64
buffer = io.BytesIO()
img.save(buffer, format="PNG")
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
fig = viz.plot_histogram(features, title="Test Histogram")
# Add data URI prefix (as Dash provides)
img_base64_with_prefix = f"data:image/png;base64,{img_base64}"
assert fig is not None
assert "Test Histogram" in fig.layout.title.text
# Parse base64 to PIL Image
# Remove prefix
base64_str = img_base64_with_prefix.split(",")[1]
img_bytes = base64.b64decode(base64_str)
parsed_img = Image.open(io.BytesIO(img_bytes))
def test_pca_2d_generation(self):
"""Test PCA 2D scatter plot generation."""
viz = FeatureVisualizer()
features = torch.randn(20, 256)
labels = ["cat"] * 10 + ["dog"] * 10
fig = viz.plot_pca_2d(features, labels=labels)
assert fig is not None
assert "PCA 2D" in fig.layout.title.text
def test_comparison_plot_generation(self):
"""Test comparison plot generation."""
viz = FeatureVisualizer()
features_list = [torch.randn(20, 256), torch.randn(20, 256)]
names = ["Set A", "Set B"]
fig = viz.plot_comparison(features_list, names)
assert fig is not None
assert "Comparison" in fig.layout.title.text
def test_html_export(self):
"""Test HTML export format."""
viz = FeatureVisualizer()
features = torch.randn(10, 256)
fig = viz.plot_histogram(features)
with tempfile.TemporaryDirectory() as tmpdir:
output_path = os.path.join(tmpdir, "test_plot")
viz.save(fig, output_path, formats=["html"])
assert os.path.exists(output_path + ".html")
def test_png_export(self):
"""Test PNG export format."""
viz = FeatureVisualizer()
features = torch.randn(10, 256)
fig = viz.plot_histogram(features)
with tempfile.TemporaryDirectory() as tmpdir:
output_path = os.path.join(tmpdir, "test_plot")
# Skip PNG export if Chrome not available
try:
viz.save(fig, output_path, formats=["png"])
assert os.path.exists(output_path + ".png")
except RuntimeError as e:
if "Chrome" in str(e):
pass
else:
raise
def test_json_export(self):
"""Test JSON export format."""
viz = FeatureVisualizer()
features = torch.randn(10, 256)
fig = viz.plot_histogram(features)
with tempfile.TemporaryDirectory() as tmpdir:
output_path = os.path.join(tmpdir, "test_plot")
viz.save(fig, output_path, formats=["json"])
assert os.path.exists(output_path + ".json")
# Verify the image is valid
assert parsed_img.size == (224, 224)
assert parsed_img.mode == "RGB"