feat(visualizer): add cosine similarity computation for image comparison

This commit is contained in:
2026-02-07 15:05:12 +08:00
parent d6bb233651
commit 051bae5483
4 changed files with 362 additions and 7 deletions

View File

@@ -5,6 +5,7 @@ import io
import numpy as np
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
class TestImageUploadSimilaritySearch:
@@ -33,3 +34,44 @@ class TestImageUploadSimilaritySearch:
# Verify the image is valid
assert parsed_img.size == (224, 224)
assert parsed_img.mode == "RGB"
class TestCosineSimilarity:
    """Test suite for cosine similarity computation between feature vectors.

    Random inputs are drawn from a locally seeded generator so that any
    failure can be reproduced exactly on a re-run.
    """

    # Fixed seed for the random-vector tests (reproducibility).
    _SEED = 0
    # Slack allowed around the mathematical bounds [-1, 1]; the computed
    # value is a floating-point quotient and may overshoot by rounding error.
    _TOLERANCE = 1e-9

    def test_identical_vectors_return_one(self):
        """Identical vectors should have cosine similarity of 1.0."""
        rng = np.random.default_rng(self._SEED)
        vec = rng.standard_normal(1024).tolist()
        similarity = cosine_similarity([vec], [vec])[0][0]
        assert np.isclose(similarity, 1.0)

    def test_orthogonal_vectors_return_zero(self):
        """Orthogonal vectors should have cosine similarity of 0.0."""
        vec_a = [1.0, 0.0]
        vec_b = [0.0, 1.0]
        similarity = cosine_similarity([vec_a], [vec_b])[0][0]
        assert np.isclose(similarity, 0.0)

    def test_opposite_vectors_return_negative_one(self):
        """Opposite vectors should have cosine similarity of -1.0."""
        vec_a = [1.0, 0.0, 0.0]
        vec_b = [-1.0, 0.0, 0.0]
        similarity = cosine_similarity([vec_a], [vec_b])[0][0]
        assert np.isclose(similarity, -1.0)

    def test_similarity_range(self):
        """Cosine similarity should stay within [-1, 1] (up to rounding)."""
        rng = np.random.default_rng(self._SEED)
        lower = -1.0 - self._TOLERANCE
        upper = 1.0 + self._TOLERANCE
        # Ten independent pairs of random 1024-dimensional vectors.
        for _ in range(10):
            vec_a = rng.standard_normal(1024).tolist()
            vec_b = rng.standard_normal(1024).tolist()
            similarity = cosine_similarity([vec_a], [vec_b])[0][0]
            assert lower <= similarity <= upper

    def test_similarity_with_list_input(self):
        """Cosine similarity should work with Python list inputs (as stored in dcc.Store)."""
        # Simulate feature vectors stored as Python lists in dcc.Store
        vec_a = [0.1, 0.2, 0.3, 0.4, 0.5]
        vec_b = [0.1, 0.2, 0.3, 0.4, 0.5]
        similarity = cosine_similarity([vec_a], [vec_b])[0][0]
        assert np.isclose(similarity, 1.0)

View File

@@ -9,6 +9,7 @@ from dash import Dash, Input, Output, State, callback, dcc, html
from database import db_manager
from feature_retrieval import FeatureRetrieval
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoImageProcessor, AutoModel
from visualizer.events import CellClickedEvent
@@ -114,6 +115,18 @@ class APP(Dash):
direction="row",
wrap="wrap",
),
# Store feature vectors for cosine similarity computation
dcc.Store(id="store-upload-vector"),
dcc.Store(id="store-select-vector"),
# Cosine similarity display (below image comparison)
html.Div(
id="cosine-similarity-display",
style={
"textAlign": "center",
"padding": "10px",
"margin": "10px",
},
),
dag.AgGrid(id="ag-grid"),
],
gap="md",
@@ -130,6 +143,7 @@ class APP(Dash):
Output("output-image-upload", "children"),
Output("ag-grid", "rowData"),
Output("ag-grid", "columnDefs"),
Output("store-upload-vector", "data"),
Input("upload-image", "contents"),
State("upload-image", "filename"),
State("upload-image", "last_modified"),
@@ -177,7 +191,7 @@ def update_output(
dmc.Text(f"{feature_vector[:5]}", size="xs"),
]
return children, row_data, columnDefs
return children, row_data, columnDefs, feature_vector
else:
# When contents is None
# Exclude 'vector' and 'binary' columns as they are not JSON serializable
@@ -190,11 +204,12 @@ def update_output(
for column in df.columns
]
return [], row_data, columnDefs
return [], row_data, columnDefs, None
@callback(
Output("output-image-select", "children"),
Output("store-select-vector", "data"),
Input("ag-grid", "cellClicked"),
State("ag-grid", "rowData"),
)
@@ -208,18 +223,18 @@ def update_images_comparison(
or row_data is None
or len(row_data) == 0
):
return []
return [], None
# Get the selected row's data
row_index = int(clicked_event["rowIndex"])
if row_index >= len(row_data):
return []
return [], None
selected_row = row_data[row_index]
image_id = selected_row.get("id")
if image_id is None:
return []
return [], None
# Query database for binary data using the id
result = (
@@ -231,7 +246,7 @@ def update_images_comparison(
)
if result.height == 0:
return []
return [], None
# Get binary data
binary_data = result.row(0, named=True)["binary"]
@@ -268,7 +283,34 @@ def update_images_comparison(
dmc.Text(f"{vector[:5]}", size="xs"),
]
return children
# Convert vector to list for JSON serialization in dcc.Store
select_vector = vector if isinstance(vector, list) else list(vector)
return children, select_vector
@callback(
    Output("cosine-similarity-display", "children"),
    Input("store-upload-vector", "data"),
    Input("store-select-vector", "data"),
)
def update_cosine_similarity(
    upload_vector: Optional[List[float]],
    select_vector: Optional[List[float]],
):
    """Compute and display cosine similarity when both vectors are available.

    Args:
        upload_vector: Feature vector read from the ``store-upload-vector``
            store, or ``None`` when nothing is stored there.
        select_vector: Feature vector read from the ``store-select-vector``
            store, or ``None`` when nothing is stored there.

    Returns:
        A ``dmc.Text`` component showing the similarity formatted to four
        decimal places, or an empty list (nothing rendered) when a
        similarity cannot be computed for the given inputs.
    """
    # Only display when both upload and select images are present
    if upload_vector is None or select_vector is None:
        return []

    # Empty or differently sized vectors cannot be compared; clear the
    # display instead of letting cosine_similarity raise inside the callback,
    # which would leave a previously shown value on screen.
    if not upload_vector or len(upload_vector) != len(select_vector):
        return []

    # cosine_similarity expects 2-D input, so wrap each vector as one row and
    # read the single entry of the resulting 1x1 matrix.
    similarity = cosine_similarity([upload_vector], [select_vector])[0][0]

    return dmc.Text(
        f"Cosine Similarity: {similarity:.4f}",
        size="lg",
        fw=700,
        ta="center",
    )
app = APP()