mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
feat(visualizer): add cosine similarity computation for image comparison
This commit is contained in:
@@ -5,6 +5,7 @@ import io
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
|
||||
|
||||
class TestImageUploadSimilaritySearch:
|
||||
@@ -33,3 +34,44 @@ class TestImageUploadSimilaritySearch:
|
||||
# Verify the image is valid
|
||||
assert parsed_img.size == (224, 224)
|
||||
assert parsed_img.mode == "RGB"
|
||||
|
||||
|
||||
class TestCosineSimilarity:
    """Test suite for cosine similarity computation between feature vectors."""

    # Fixed seed: the random vectors are identical on every run, so any
    # failure can be reproduced exactly instead of depending on global
    # ``np.random`` state.
    _SEED = 0

    def test_identical_vectors_return_one(self):
        """Identical vectors should have cosine similarity of 1.0."""
        rng = np.random.default_rng(self._SEED)
        vec = rng.standard_normal(1024).tolist()
        similarity = cosine_similarity([vec], [vec])[0][0]
        assert np.isclose(similarity, 1.0)

    def test_orthogonal_vectors_return_zero(self):
        """Orthogonal vectors should have cosine similarity of 0.0."""
        vec_a = [1.0, 0.0]
        vec_b = [0.0, 1.0]
        similarity = cosine_similarity([vec_a], [vec_b])[0][0]
        assert np.isclose(similarity, 0.0)

    def test_opposite_vectors_return_negative_one(self):
        """Opposite vectors should have cosine similarity of -1.0."""
        vec_a = [1.0, 0.0, 0.0]
        vec_b = [-1.0, 0.0, 0.0]
        similarity = cosine_similarity([vec_a], [vec_b])[0][0]
        assert np.isclose(similarity, -1.0)

    def test_similarity_range(self):
        """Cosine similarity should always be within [-1, 1]."""
        # Seeded random vectors: ten independent pairs from one generator.
        rng = np.random.default_rng(self._SEED)
        for _ in range(10):
            vec_a = rng.standard_normal(1024).tolist()
            vec_b = rng.standard_normal(1024).tolist()
            similarity = cosine_similarity([vec_a], [vec_b])[0][0]
            assert -1.0 <= similarity <= 1.0

    def test_similarity_with_list_input(self):
        """Cosine similarity should work with Python list inputs (as stored in dcc.Store)."""
        # Simulate feature vectors stored as Python lists in dcc.Store
        vec_a = [0.1, 0.2, 0.3, 0.4, 0.5]
        vec_b = [0.1, 0.2, 0.3, 0.4, 0.5]
        similarity = cosine_similarity([vec_a], [vec_b])[0][0]
        assert np.isclose(similarity, 1.0)
|
||||
|
||||
@@ -9,6 +9,7 @@ from dash import Dash, Input, Output, State, callback, dcc, html
|
||||
from database import db_manager
|
||||
from feature_retrieval import FeatureRetrieval
|
||||
from PIL import Image
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
from transformers import AutoImageProcessor, AutoModel
|
||||
from visualizer.events import CellClickedEvent
|
||||
|
||||
@@ -114,6 +115,18 @@ class APP(Dash):
|
||||
direction="row",
|
||||
wrap="wrap",
|
||||
),
|
||||
# Store feature vectors for cosine similarity computation
|
||||
dcc.Store(id="store-upload-vector"),
|
||||
dcc.Store(id="store-select-vector"),
|
||||
# Cosine similarity display (below image comparison)
|
||||
html.Div(
|
||||
id="cosine-similarity-display",
|
||||
style={
|
||||
"textAlign": "center",
|
||||
"padding": "10px",
|
||||
"margin": "10px",
|
||||
},
|
||||
),
|
||||
dag.AgGrid(id="ag-grid"),
|
||||
],
|
||||
gap="md",
|
||||
@@ -130,6 +143,7 @@ class APP(Dash):
|
||||
Output("output-image-upload", "children"),
|
||||
Output("ag-grid", "rowData"),
|
||||
Output("ag-grid", "columnDefs"),
|
||||
Output("store-upload-vector", "data"),
|
||||
Input("upload-image", "contents"),
|
||||
State("upload-image", "filename"),
|
||||
State("upload-image", "last_modified"),
|
||||
@@ -177,7 +191,7 @@ def update_output(
|
||||
dmc.Text(f"{feature_vector[:5]}", size="xs"),
|
||||
]
|
||||
|
||||
return children, row_data, columnDefs
|
||||
return children, row_data, columnDefs, feature_vector
|
||||
else:
|
||||
# When contents is None
|
||||
# Exclude 'vector' and 'binary' columns as they are not JSON serializable
|
||||
@@ -190,11 +204,12 @@ def update_output(
|
||||
for column in df.columns
|
||||
]
|
||||
|
||||
return [], row_data, columnDefs
|
||||
return [], row_data, columnDefs, None
|
||||
|
||||
|
||||
@callback(
|
||||
Output("output-image-select", "children"),
|
||||
Output("store-select-vector", "data"),
|
||||
Input("ag-grid", "cellClicked"),
|
||||
State("ag-grid", "rowData"),
|
||||
)
|
||||
@@ -208,18 +223,18 @@ def update_images_comparison(
|
||||
or row_data is None
|
||||
or len(row_data) == 0
|
||||
):
|
||||
return []
|
||||
return [], None
|
||||
|
||||
# Get the selected row's data
|
||||
row_index = int(clicked_event["rowIndex"])
|
||||
if row_index >= len(row_data):
|
||||
return []
|
||||
return [], None
|
||||
|
||||
selected_row = row_data[row_index]
|
||||
image_id = selected_row.get("id")
|
||||
|
||||
if image_id is None:
|
||||
return []
|
||||
return [], None
|
||||
|
||||
# Query database for binary data using the id
|
||||
result = (
|
||||
@@ -231,7 +246,7 @@ def update_images_comparison(
|
||||
)
|
||||
|
||||
if result.height == 0:
|
||||
return []
|
||||
return [], None
|
||||
|
||||
# Get binary data
|
||||
binary_data = result.row(0, named=True)["binary"]
|
||||
@@ -268,7 +283,34 @@ def update_images_comparison(
|
||||
dmc.Text(f"{vector[:5]}", size="xs"),
|
||||
]
|
||||
|
||||
return children
|
||||
# Convert vector to list for JSON serialization in dcc.Store
|
||||
select_vector = vector if isinstance(vector, list) else list(vector)
|
||||
|
||||
return children, select_vector
|
||||
|
||||
|
||||
@callback(
    Output("cosine-similarity-display", "children"),
    Input("store-upload-vector", "data"),
    Input("store-select-vector", "data"),
)
def update_cosine_similarity(
    upload_vector: Optional[List[float]],
    select_vector: Optional[List[float]],
):
    """Compute and display cosine similarity when both vectors are available.

    Args:
        upload_vector: Data of the ``store-upload-vector`` store, or ``None``
            when nothing is stored there.
        select_vector: Data of the ``store-select-vector`` store, or ``None``
            when nothing is stored there.

    Returns:
        An empty list (nothing rendered) when either vector is missing or
        empty, or when the two vectors differ in length; otherwise a
        ``dmc.Text`` showing the similarity with four decimal places.
    """
    # Only display when both upload and select images are present
    if upload_vector is None or select_vector is None:
        return []

    # cosine_similarity raises ValueError for zero-feature input and for
    # inputs with different feature counts; render nothing instead of
    # failing inside the callback.
    if len(upload_vector) == 0 or len(upload_vector) != len(select_vector):
        return []

    # Each vector is wrapped as a single-row matrix, so the result is a
    # 1x1 matrix whose only entry is the similarity.
    similarity = cosine_similarity([upload_vector], [select_vector])[0][0]

    return dmc.Text(
        f"Cosine Similarity: {similarity:.4f}",
        size="lg",
        fw=700,
        ta="center",
    )
|
||||
|
||||
|
||||
# Module-level instance of the APP class (a Dash subclass).
app = APP()
|
||||
|
||||
Reference in New Issue
Block a user