feat(visualizer): implement image selection and display from grid

This commit is contained in:
2026-02-07 11:08:13 +08:00
parent aa6baa87fe
commit d6bb233651
3 changed files with 147 additions and 34 deletions

View File

@@ -1,3 +1,4 @@
import io
from typing import Any, Dict, List, Optional, Union, cast from typing import Any, Dict, List, Optional, Union, cast
import torch import torch
@@ -9,6 +10,21 @@ from tqdm.auto import tqdm
from transformers import AutoImageProcessor, AutoModel from transformers import AutoImageProcessor, AutoModel
def pil_image_to_bytes(image: Image.Image, format: str = "PNG") -> bytes:
"""Convert a PIL Image to bytes in the specified format.
Args:
image: PIL Image to convert.
format: Image format (e.g., 'PNG', 'JPEG').
Returns:
bytes: The encoded image bytes.
"""
buffer = io.BytesIO()
image.save(buffer, format=format)
return buffer.getvalue()
class FeatureRetrieval: class FeatureRetrieval:
"""Singleton feature retrieval manager for image feature extraction.""" """Singleton feature retrieval manager for image feature extraction."""
@@ -98,7 +114,7 @@ class FeatureRetrieval:
"id": i + j, "id": i + j,
"label": batch_labels[j], "label": batch_labels[j],
"vector": cls_tokens[j].numpy(), "vector": cls_tokens[j].numpy(),
"binary": batch_imgs[j].tobytes(), "binary": pil_image_to_bytes(batch_imgs[j]),
} }
for j in range(actual_batch_size) for j in range(actual_batch_size)
] ]

View File

@@ -1,17 +1,16 @@
import base64 import base64
import datetime
import io import io
from typing import List, Optional from typing import List, Optional
import dash_ag_grid as dag import dash_ag_grid as dag
import dash_mantine_components as dmc import dash_mantine_components as dmc
import numpy as np
from dash import Dash, Input, Output, State, callback, dcc, html from dash import Dash, Input, Output, State, callback, dcc, html
from database import db_manager from database import db_manager
from feature_retrieval import FeatureRetrieval from feature_retrieval import FeatureRetrieval
from PIL import Image from PIL import Image
from transformers import AutoImageProcessor, AutoModel from transformers import AutoImageProcessor, AutoModel
from visualizer.events import CellClickedEvent
from .events import CellClickedEvent
class APP(Dash): class APP(Dash):
@@ -59,12 +58,59 @@ class APP(Dash):
), ),
dmc.Flex( dmc.Flex(
[ [
html.Div(id="output-image-upload"), html.Div(
html.Div(id="output-image-select"), [
html.H4(
"Uploaded Image",
style={"textAlign": "center"},
),
html.Div(
id="output-image-upload",
style={
"minWidth": "300px",
"minHeight": "300px",
"display": "flex",
"flexDirection": "column",
"alignItems": "center",
"justifyContent": "center",
"border": "1px dashed #ccc",
"borderRadius": "5px",
"padding": "10px",
},
),
],
style={"flex": 1, "maxWidth": "45%"},
),
dmc.Divider(
variant="solid", orientation="vertical", size="sm"
),
html.Div(
[
html.H4(
"Selected Image",
style={"textAlign": "center"},
),
html.Div(
id="output-image-select",
style={
"minWidth": "300px",
"minHeight": "300px",
"display": "flex",
"flexDirection": "column",
"alignItems": "center",
"justifyContent": "center",
"border": "1px dashed #ccc",
"borderRadius": "5px",
"padding": "10px",
},
),
],
style={"flex": 1, "maxWidth": "45%"},
),
], ],
gap="md", gap="md",
justify="center", justify="center",
align="center", align="stretch",
direction="row", direction="row",
wrap="wrap", wrap="wrap",
), ),
@@ -89,9 +135,9 @@ class APP(Dash):
State("upload-image", "last_modified"), State("upload-image", "last_modified"),
) )
def update_output( def update_output(
list_of_contents: Optional[List[str]], image_content: Optional[str],
list_of_names: Optional[List[str]], filename: Optional[str],
list_of_dates: Optional[List[int] | List[float]], timestamp: Optional[int | float],
): ):
def parse_base64_to_pil(contents: str) -> Image.Image: def parse_base64_to_pil(contents: str) -> Image.Image:
"""Parse base64 string to PIL Image.""" """Parse base64 string to PIL Image."""
@@ -100,17 +146,8 @@ def update_output(
img_bytes = base64.b64decode(base64_str) img_bytes = base64.b64decode(base64_str)
return Image.open(io.BytesIO(img_bytes)) return Image.open(io.BytesIO(img_bytes))
if ( if image_content is not None and filename is not None and timestamp is not None:
list_of_contents is not None pil_image = parse_base64_to_pil(image_content)
and list_of_names is not None
and list_of_dates is not None
):
# Process first uploaded image for similarity search
filename = list_of_names[0]
uploaddate = list_of_dates[0]
imagecontent = list_of_contents[0]
pil_image = parse_base64_to_pil(imagecontent)
# Extract feature vector using DINOv2 # Extract feature vector using DINOv2
feature_vector = APP._feature_retrieval.extract_single_image_feature(pil_image) feature_vector = APP._feature_retrieval.extract_single_image_feature(pil_image)
@@ -134,11 +171,9 @@ def update_output(
# Display uploaded images # Display uploaded images
children = [ children = [
html.H5(filename),
html.H6(str(datetime.datetime.fromtimestamp(uploaddate))),
# HTML images accept base64 encoded strings in same format # HTML images accept base64 encoded strings in same format
# that is supplied by the upload # that is supplied by the upload
dmc.Image(src=imagecontent), dmc.Image(src=image_content, w="100%", h="auto"),
dmc.Text(f"{feature_vector[:5]}", size="xs"), dmc.Text(f"{feature_vector[:5]}", size="xs"),
] ]
@@ -146,7 +181,7 @@ def update_output(
else: else:
# When contents is None # When contents is None
# Exclude 'vector' and 'binary' columns as they are not JSON serializable # Exclude 'vector' and 'binary' columns as they are not JSON serializable
df = db_manager.table.search().select(["id", "label"]).limit(1000).to_polars() df = db_manager.table.search().select(["id", "label"]).limit(100).to_polars()
row_data = df.to_dicts() row_data = df.to_dicts()
@@ -159,17 +194,81 @@ def update_output(
@callback( @callback(
Input("ag-grid", "cellClicked"),
State("ag-grid", "row_data"),
Output("output-image-select", "children"), Output("output-image-select", "children"),
Input("ag-grid", "cellClicked"),
State("ag-grid", "rowData"),
) )
def update_images_comparison( def update_images_comparison(
clicked_event: Optional[CellClickedEvent], row_data: Optional[dict] clicked_event: Optional[CellClickedEvent],
row_data: Optional[List[dict]],
): ):
if clicked_event is None or CellClickedEvent.rowIndex is None or row_data is None: if (
clicked_event is None
or clicked_event["rowIndex"] is None
or row_data is None
or len(row_data) == 0
):
return [] return []
return [] # Get the selected row's data
row_index = int(clicked_event["rowIndex"])
if row_index >= len(row_data):
return []
selected_row = row_data[row_index]
image_id = selected_row.get("id")
if image_id is None:
return []
# Query database for binary data using the id
result = (
db_manager.table.search()
.where(f"id = {image_id}")
.select(["id", "label", "vector", "binary"])
.limit(1)
.to_polars()
)
if result.height == 0:
return []
# Get binary data
binary_data = result.row(0, named=True)["binary"]
vector = result.row(0, named=True)["vector"]
# Try to detect if binary_data is a valid image format (PNG/JPEG) or raw pixels
# PNG files start with bytes: 89 50 4E 47 (hex) = \x89PNG
# JPEG files start with bytes: FF D8 FF
is_png = binary_data[:4] == b"\x89PNG"
is_jpeg = binary_data[:3] == b"\xff\xd8\xff"
if is_png or is_jpeg:
# Binary data is already in a valid image format
mime_type = "image/png" if is_png else "image/jpeg"
base64_str = base64.b64encode(binary_data).decode("utf-8")
image_content = f"data:{mime_type};base64,{base64_str}"
else:
# Legacy format: raw pixel bytes (CIFAR-10 images are 32x32 RGB)
img_array = np.frombuffer(binary_data, dtype=np.uint8).reshape(32, 32, 3)
pil_image = Image.fromarray(img_array)
# Encode as PNG to get proper image format
buffer = io.BytesIO()
pil_image.save(buffer, format="PNG")
buffer.seek(0)
# Convert to base64 with correct MIME type
base64_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
image_content = f"data:image/png;base64,{base64_str}"
# Display selected image
children = [
dmc.Image(src=image_content, w="100%", h="auto"),
dmc.Text(f"{vector[:5]}", size="xs"),
]
return children
app = APP() app = APP()

View File

@@ -1,9 +1,7 @@
from dataclasses import dataclass from typing import Optional, TypedDict, Union
from typing import Optional, Union
@dataclass class CellClickedEvent(TypedDict):
class CellClickedEvent:
""" """
- value (boolean I number | string I dict | list; optional): value of the clicked cell. - value (boolean I number | string I dict | list; optional): value of the clicked cell.
- colId (boolean I number I string I dict | list; optional): column where the cell was clicked. - colId (boolean I number I string I dict | list; optional): column where the cell was clicked.