mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
feat(visualizer): integrate image upload with similarity search
This commit is contained in:
@@ -1,97 +1,35 @@
|
||||
"""Tests for FeatureVisualizer module."""
|
||||
"""Tests for visualizer app image upload similarity search."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import base64
|
||||
import io
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from feature_compressor.core.visualizer import FeatureVisualizer
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class TestFeatureVisualizer:
|
||||
"""Test suite for FeatureVisualizer class."""
|
||||
class TestImageUploadSimilaritySearch:
|
||||
"""Test suite for image upload similarity search functionality."""
|
||||
|
||||
def test_histogram_generation(self):
|
||||
"""Test histogram generation from features."""
|
||||
def test_base64_to_pil_image(self):
|
||||
"""Test conversion from base64 string to PIL Image."""
|
||||
# Create a test image
|
||||
img_array = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
|
||||
img = Image.fromarray(img_array)
|
||||
|
||||
viz = FeatureVisualizer()
|
||||
features = torch.randn(20, 256)
|
||||
# Convert to base64
|
||||
buffer = io.BytesIO()
|
||||
img.save(buffer, format="PNG")
|
||||
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
fig = viz.plot_histogram(features, title="Test Histogram")
|
||||
# Add data URI prefix (as Dash provides)
|
||||
img_base64_with_prefix = f"data:image/png;base64,{img_base64}"
|
||||
|
||||
assert fig is not None
|
||||
assert "Test Histogram" in fig.layout.title.text
|
||||
# Parse base64 to PIL Image
|
||||
# Remove prefix
|
||||
base64_str = img_base64_with_prefix.split(",")[1]
|
||||
img_bytes = base64.b64decode(base64_str)
|
||||
parsed_img = Image.open(io.BytesIO(img_bytes))
|
||||
|
||||
def test_pca_2d_generation(self):
|
||||
"""Test PCA 2D scatter plot generation."""
|
||||
|
||||
viz = FeatureVisualizer()
|
||||
features = torch.randn(20, 256)
|
||||
labels = ["cat"] * 10 + ["dog"] * 10
|
||||
|
||||
fig = viz.plot_pca_2d(features, labels=labels)
|
||||
|
||||
assert fig is not None
|
||||
assert "PCA 2D" in fig.layout.title.text
|
||||
|
||||
def test_comparison_plot_generation(self):
|
||||
"""Test comparison plot generation."""
|
||||
|
||||
viz = FeatureVisualizer()
|
||||
features_list = [torch.randn(20, 256), torch.randn(20, 256)]
|
||||
names = ["Set A", "Set B"]
|
||||
|
||||
fig = viz.plot_comparison(features_list, names)
|
||||
|
||||
assert fig is not None
|
||||
assert "Comparison" in fig.layout.title.text
|
||||
|
||||
def test_html_export(self):
|
||||
"""Test HTML export format."""
|
||||
|
||||
viz = FeatureVisualizer()
|
||||
features = torch.randn(10, 256)
|
||||
|
||||
fig = viz.plot_histogram(features)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
output_path = os.path.join(tmpdir, "test_plot")
|
||||
viz.save(fig, output_path, formats=["html"])
|
||||
|
||||
assert os.path.exists(output_path + ".html")
|
||||
|
||||
def test_png_export(self):
|
||||
"""Test PNG export format."""
|
||||
|
||||
viz = FeatureVisualizer()
|
||||
features = torch.randn(10, 256)
|
||||
|
||||
fig = viz.plot_histogram(features)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
output_path = os.path.join(tmpdir, "test_plot")
|
||||
|
||||
# Skip PNG export if Chrome not available
|
||||
try:
|
||||
viz.save(fig, output_path, formats=["png"])
|
||||
assert os.path.exists(output_path + ".png")
|
||||
except RuntimeError as e:
|
||||
if "Chrome" in str(e):
|
||||
pass
|
||||
else:
|
||||
raise
|
||||
|
||||
def test_json_export(self):
|
||||
"""Test JSON export format."""
|
||||
|
||||
viz = FeatureVisualizer()
|
||||
features = torch.randn(10, 256)
|
||||
|
||||
fig = viz.plot_histogram(features)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
output_path = os.path.join(tmpdir, "test_plot")
|
||||
viz.save(fig, output_path, formats=["json"])
|
||||
|
||||
assert os.path.exists(output_path + ".json")
|
||||
# Verify the image is valid
|
||||
assert parsed_img.size == (224, 224)
|
||||
assert parsed_img.mode == "RGB"
|
||||
|
||||
Reference in New Issue
Block a user