feat(visualizer): implement image selection and display from grid

This commit is contained in:
2026-02-07 11:08:13 +08:00
parent aa6baa87fe
commit d6bb233651
3 changed files with 147 additions and 34 deletions

View File

@@ -1,3 +1,4 @@
import io
from typing import Any, Dict, List, Optional, Union, cast
import torch
@@ -9,6 +10,21 @@ from tqdm.auto import tqdm
from transformers import AutoImageProcessor, AutoModel
def pil_image_to_bytes(image: Image.Image, format: str = "PNG") -> bytes:
"""Convert a PIL Image to bytes in the specified format.
Args:
image: PIL Image to convert.
format: Image format (e.g., 'PNG', 'JPEG').
Returns:
bytes: The encoded image bytes.
"""
buffer = io.BytesIO()
image.save(buffer, format=format)
return buffer.getvalue()
class FeatureRetrieval:
"""Singleton feature retrieval manager for image feature extraction."""
@@ -98,7 +114,7 @@ class FeatureRetrieval:
"id": i + j,
"label": batch_labels[j],
"vector": cls_tokens[j].numpy(),
"binary": batch_imgs[j].tobytes(),
"binary": pil_image_to_bytes(batch_imgs[j]),
}
for j in range(actual_batch_size)
]

View File

@@ -1,17 +1,16 @@
import base64
import datetime
import io
from typing import List, Optional
import dash_ag_grid as dag
import dash_mantine_components as dmc
import numpy as np
from dash import Dash, Input, Output, State, callback, dcc, html
from database import db_manager
from feature_retrieval import FeatureRetrieval
from PIL import Image
from transformers import AutoImageProcessor, AutoModel
from .events import CellClickedEvent
from visualizer.events import CellClickedEvent
class APP(Dash):
@@ -59,12 +58,59 @@ class APP(Dash):
),
dmc.Flex(
[
html.Div(id="output-image-upload"),
html.Div(id="output-image-select"),
html.Div(
[
html.H4(
"Uploaded Image",
style={"textAlign": "center"},
),
html.Div(
id="output-image-upload",
style={
"minWidth": "300px",
"minHeight": "300px",
"display": "flex",
"flexDirection": "column",
"alignItems": "center",
"justifyContent": "center",
"border": "1px dashed #ccc",
"borderRadius": "5px",
"padding": "10px",
},
),
],
style={"flex": 1, "maxWidth": "45%"},
),
dmc.Divider(
variant="solid", orientation="vertical", size="sm"
),
html.Div(
[
html.H4(
"Selected Image",
style={"textAlign": "center"},
),
html.Div(
id="output-image-select",
style={
"minWidth": "300px",
"minHeight": "300px",
"display": "flex",
"flexDirection": "column",
"alignItems": "center",
"justifyContent": "center",
"border": "1px dashed #ccc",
"borderRadius": "5px",
"padding": "10px",
},
),
],
style={"flex": 1, "maxWidth": "45%"},
),
],
gap="md",
justify="center",
align="center",
align="stretch",
direction="row",
wrap="wrap",
),
@@ -89,9 +135,9 @@ class APP(Dash):
State("upload-image", "last_modified"),
)
def update_output(
list_of_contents: Optional[List[str]],
list_of_names: Optional[List[str]],
list_of_dates: Optional[List[int] | List[float]],
image_content: Optional[str],
filename: Optional[str],
timestamp: Optional[int | float],
):
def parse_base64_to_pil(contents: str) -> Image.Image:
"""Parse base64 string to PIL Image."""
@@ -100,17 +146,8 @@ def update_output(
img_bytes = base64.b64decode(base64_str)
return Image.open(io.BytesIO(img_bytes))
if (
list_of_contents is not None
and list_of_names is not None
and list_of_dates is not None
):
# Process first uploaded image for similarity search
filename = list_of_names[0]
uploaddate = list_of_dates[0]
imagecontent = list_of_contents[0]
pil_image = parse_base64_to_pil(imagecontent)
if image_content is not None and filename is not None and timestamp is not None:
pil_image = parse_base64_to_pil(image_content)
# Extract feature vector using DINOv2
feature_vector = APP._feature_retrieval.extract_single_image_feature(pil_image)
@@ -134,11 +171,9 @@ def update_output(
# Display uploaded images
children = [
html.H5(filename),
html.H6(str(datetime.datetime.fromtimestamp(uploaddate))),
# HTML images accept base64 encoded strings in same format
# that is supplied by the upload
dmc.Image(src=imagecontent),
dmc.Image(src=image_content, w="100%", h="auto"),
dmc.Text(f"{feature_vector[:5]}", size="xs"),
]
@@ -146,7 +181,7 @@ def update_output(
else:
# When contents is None
# Exclude 'vector' and 'binary' columns as they are not JSON serializable
df = db_manager.table.search().select(["id", "label"]).limit(1000).to_polars()
df = db_manager.table.search().select(["id", "label"]).limit(100).to_polars()
row_data = df.to_dicts()
@@ -159,17 +194,81 @@ def update_output(
@callback(
Input("ag-grid", "cellClicked"),
State("ag-grid", "row_data"),
Output("output-image-select", "children"),
Input("ag-grid", "cellClicked"),
State("ag-grid", "rowData"),
)
def update_images_comparison(
clicked_event: Optional[CellClickedEvent], row_data: Optional[dict]
clicked_event: Optional[CellClickedEvent],
row_data: Optional[List[dict]],
):
if (
clicked_event is None
or clicked_event["rowIndex"] is None
or row_data is None
or len(row_data) == 0
):
if clicked_event is None or CellClickedEvent.rowIndex is None or row_data is None:
return []
# Get the selected row's data
row_index = int(clicked_event["rowIndex"])
if row_index >= len(row_data):
return []
selected_row = row_data[row_index]
image_id = selected_row.get("id")
if image_id is None:
return []
# Query database for binary data using the id
result = (
db_manager.table.search()
.where(f"id = {image_id}")
.select(["id", "label", "vector", "binary"])
.limit(1)
.to_polars()
)
if result.height == 0:
return []
# Get binary data
binary_data = result.row(0, named=True)["binary"]
vector = result.row(0, named=True)["vector"]
# Try to detect if binary_data is a valid image format (PNG/JPEG) or raw pixels
# PNG files start with bytes: 89 50 4E 47 (hex) = \x89PNG
# JPEG files start with bytes: FF D8 FF
is_png = binary_data[:4] == b"\x89PNG"
is_jpeg = binary_data[:3] == b"\xff\xd8\xff"
if is_png or is_jpeg:
# Binary data is already in a valid image format
mime_type = "image/png" if is_png else "image/jpeg"
base64_str = base64.b64encode(binary_data).decode("utf-8")
image_content = f"data:{mime_type};base64,{base64_str}"
else:
# Legacy format: raw pixel bytes (CIFAR-10 images are 32x32 RGB)
img_array = np.frombuffer(binary_data, dtype=np.uint8).reshape(32, 32, 3)
pil_image = Image.fromarray(img_array)
# Encode as PNG to get proper image format
buffer = io.BytesIO()
pil_image.save(buffer, format="PNG")
buffer.seek(0)
# Convert to base64 with correct MIME type
base64_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
image_content = f"data:image/png;base64,{base64_str}"
# Display selected image
children = [
dmc.Image(src=image_content, w="100%", h="auto"),
dmc.Text(f"{vector[:5]}", size="xs"),
]
return children
app = APP()

View File

@@ -1,9 +1,7 @@
from dataclasses import dataclass
from typing import Optional, Union
from typing import Optional, TypedDict, Union
@dataclass
class CellClickedEvent:
class CellClickedEvent(TypedDict):
"""
- value (boolean I number | string I dict | list; optional): value of the clicked cell.
- colId (boolean I number I string I dict | list; optional): column where the cell was clicked.