mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
feat(visualizer): implement image selection and display from grid
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import io
|
||||
from typing import Any, Dict, List, Optional, Union, cast
|
||||
|
||||
import torch
|
||||
@@ -9,6 +10,21 @@ from tqdm.auto import tqdm
|
||||
from transformers import AutoImageProcessor, AutoModel
|
||||
|
||||
|
||||
def pil_image_to_bytes(image: Image.Image, format: str = "PNG") -> bytes:
|
||||
"""Convert a PIL Image to bytes in the specified format.
|
||||
|
||||
Args:
|
||||
image: PIL Image to convert.
|
||||
format: Image format (e.g., 'PNG', 'JPEG').
|
||||
|
||||
Returns:
|
||||
bytes: The encoded image bytes.
|
||||
"""
|
||||
buffer = io.BytesIO()
|
||||
image.save(buffer, format=format)
|
||||
return buffer.getvalue()
|
||||
|
||||
|
||||
class FeatureRetrieval:
|
||||
"""Singleton feature retrieval manager for image feature extraction."""
|
||||
|
||||
@@ -98,7 +114,7 @@ class FeatureRetrieval:
|
||||
"id": i + j,
|
||||
"label": batch_labels[j],
|
||||
"vector": cls_tokens[j].numpy(),
|
||||
"binary": batch_imgs[j].tobytes(),
|
||||
"binary": pil_image_to_bytes(batch_imgs[j]),
|
||||
}
|
||||
for j in range(actual_batch_size)
|
||||
]
|
||||
|
||||
@@ -1,17 +1,16 @@
|
||||
import base64
|
||||
import datetime
|
||||
import io
|
||||
from typing import List, Optional
|
||||
|
||||
import dash_ag_grid as dag
|
||||
import dash_mantine_components as dmc
|
||||
import numpy as np
|
||||
from dash import Dash, Input, Output, State, callback, dcc, html
|
||||
from database import db_manager
|
||||
from feature_retrieval import FeatureRetrieval
|
||||
from PIL import Image
|
||||
from transformers import AutoImageProcessor, AutoModel
|
||||
|
||||
from .events import CellClickedEvent
|
||||
from visualizer.events import CellClickedEvent
|
||||
|
||||
|
||||
class APP(Dash):
|
||||
@@ -59,12 +58,59 @@ class APP(Dash):
|
||||
),
|
||||
dmc.Flex(
|
||||
[
|
||||
html.Div(id="output-image-upload"),
|
||||
html.Div(id="output-image-select"),
|
||||
html.Div(
|
||||
[
|
||||
html.H4(
|
||||
"Uploaded Image",
|
||||
style={"textAlign": "center"},
|
||||
),
|
||||
html.Div(
|
||||
id="output-image-upload",
|
||||
style={
|
||||
"minWidth": "300px",
|
||||
"minHeight": "300px",
|
||||
"display": "flex",
|
||||
"flexDirection": "column",
|
||||
"alignItems": "center",
|
||||
"justifyContent": "center",
|
||||
"border": "1px dashed #ccc",
|
||||
"borderRadius": "5px",
|
||||
"padding": "10px",
|
||||
},
|
||||
),
|
||||
],
|
||||
style={"flex": 1, "maxWidth": "45%"},
|
||||
),
|
||||
dmc.Divider(
|
||||
variant="solid", orientation="vertical", size="sm"
|
||||
),
|
||||
html.Div(
|
||||
[
|
||||
html.H4(
|
||||
"Selected Image",
|
||||
style={"textAlign": "center"},
|
||||
),
|
||||
html.Div(
|
||||
id="output-image-select",
|
||||
style={
|
||||
"minWidth": "300px",
|
||||
"minHeight": "300px",
|
||||
"display": "flex",
|
||||
"flexDirection": "column",
|
||||
"alignItems": "center",
|
||||
"justifyContent": "center",
|
||||
"border": "1px dashed #ccc",
|
||||
"borderRadius": "5px",
|
||||
"padding": "10px",
|
||||
},
|
||||
),
|
||||
],
|
||||
style={"flex": 1, "maxWidth": "45%"},
|
||||
),
|
||||
],
|
||||
gap="md",
|
||||
justify="center",
|
||||
align="center",
|
||||
align="stretch",
|
||||
direction="row",
|
||||
wrap="wrap",
|
||||
),
|
||||
@@ -89,9 +135,9 @@ class APP(Dash):
|
||||
State("upload-image", "last_modified"),
|
||||
)
|
||||
def update_output(
|
||||
list_of_contents: Optional[List[str]],
|
||||
list_of_names: Optional[List[str]],
|
||||
list_of_dates: Optional[List[int] | List[float]],
|
||||
image_content: Optional[str],
|
||||
filename: Optional[str],
|
||||
timestamp: Optional[int | float],
|
||||
):
|
||||
def parse_base64_to_pil(contents: str) -> Image.Image:
|
||||
"""Parse base64 string to PIL Image."""
|
||||
@@ -100,17 +146,8 @@ def update_output(
|
||||
img_bytes = base64.b64decode(base64_str)
|
||||
return Image.open(io.BytesIO(img_bytes))
|
||||
|
||||
if (
|
||||
list_of_contents is not None
|
||||
and list_of_names is not None
|
||||
and list_of_dates is not None
|
||||
):
|
||||
# Process first uploaded image for similarity search
|
||||
filename = list_of_names[0]
|
||||
uploaddate = list_of_dates[0]
|
||||
imagecontent = list_of_contents[0]
|
||||
|
||||
pil_image = parse_base64_to_pil(imagecontent)
|
||||
if image_content is not None and filename is not None and timestamp is not None:
|
||||
pil_image = parse_base64_to_pil(image_content)
|
||||
|
||||
# Extract feature vector using DINOv2
|
||||
feature_vector = APP._feature_retrieval.extract_single_image_feature(pil_image)
|
||||
@@ -134,11 +171,9 @@ def update_output(
|
||||
|
||||
# Display uploaded images
|
||||
children = [
|
||||
html.H5(filename),
|
||||
html.H6(str(datetime.datetime.fromtimestamp(uploaddate))),
|
||||
# HTML images accept base64 encoded strings in same format
|
||||
# that is supplied by the upload
|
||||
dmc.Image(src=imagecontent),
|
||||
dmc.Image(src=image_content, w="100%", h="auto"),
|
||||
dmc.Text(f"{feature_vector[:5]}", size="xs"),
|
||||
]
|
||||
|
||||
@@ -146,7 +181,7 @@ def update_output(
|
||||
else:
|
||||
# When contents is None
|
||||
# Exclude 'vector' and 'binary' columns as they are not JSON serializable
|
||||
df = db_manager.table.search().select(["id", "label"]).limit(1000).to_polars()
|
||||
df = db_manager.table.search().select(["id", "label"]).limit(100).to_polars()
|
||||
|
||||
row_data = df.to_dicts()
|
||||
|
||||
@@ -159,17 +194,81 @@ def update_output(
|
||||
|
||||
|
||||
@callback(
|
||||
Input("ag-grid", "cellClicked"),
|
||||
State("ag-grid", "row_data"),
|
||||
Output("output-image-select", "children"),
|
||||
Input("ag-grid", "cellClicked"),
|
||||
State("ag-grid", "rowData"),
|
||||
)
|
||||
def update_images_comparison(
|
||||
clicked_event: Optional[CellClickedEvent], row_data: Optional[dict]
|
||||
clicked_event: Optional[CellClickedEvent],
|
||||
row_data: Optional[List[dict]],
|
||||
):
|
||||
if clicked_event is None or CellClickedEvent.rowIndex is None or row_data is None:
|
||||
if (
|
||||
clicked_event is None
|
||||
or clicked_event["rowIndex"] is None
|
||||
or row_data is None
|
||||
or len(row_data) == 0
|
||||
):
|
||||
return []
|
||||
|
||||
return []
|
||||
# Get the selected row's data
|
||||
row_index = int(clicked_event["rowIndex"])
|
||||
if row_index >= len(row_data):
|
||||
return []
|
||||
|
||||
selected_row = row_data[row_index]
|
||||
image_id = selected_row.get("id")
|
||||
|
||||
if image_id is None:
|
||||
return []
|
||||
|
||||
# Query database for binary data using the id
|
||||
result = (
|
||||
db_manager.table.search()
|
||||
.where(f"id = {image_id}")
|
||||
.select(["id", "label", "vector", "binary"])
|
||||
.limit(1)
|
||||
.to_polars()
|
||||
)
|
||||
|
||||
if result.height == 0:
|
||||
return []
|
||||
|
||||
# Get binary data
|
||||
binary_data = result.row(0, named=True)["binary"]
|
||||
vector = result.row(0, named=True)["vector"]
|
||||
|
||||
# Try to detect if binary_data is a valid image format (PNG/JPEG) or raw pixels
|
||||
# PNG files start with bytes: 89 50 4E 47 (hex) = \x89PNG
|
||||
# JPEG files start with bytes: FF D8 FF
|
||||
is_png = binary_data[:4] == b"\x89PNG"
|
||||
is_jpeg = binary_data[:3] == b"\xff\xd8\xff"
|
||||
|
||||
if is_png or is_jpeg:
|
||||
# Binary data is already in a valid image format
|
||||
mime_type = "image/png" if is_png else "image/jpeg"
|
||||
base64_str = base64.b64encode(binary_data).decode("utf-8")
|
||||
image_content = f"data:{mime_type};base64,{base64_str}"
|
||||
else:
|
||||
# Legacy format: raw pixel bytes (CIFAR-10 images are 32x32 RGB)
|
||||
img_array = np.frombuffer(binary_data, dtype=np.uint8).reshape(32, 32, 3)
|
||||
pil_image = Image.fromarray(img_array)
|
||||
|
||||
# Encode as PNG to get proper image format
|
||||
buffer = io.BytesIO()
|
||||
pil_image.save(buffer, format="PNG")
|
||||
buffer.seek(0)
|
||||
|
||||
# Convert to base64 with correct MIME type
|
||||
base64_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
image_content = f"data:image/png;base64,{base64_str}"
|
||||
|
||||
# Display selected image
|
||||
children = [
|
||||
dmc.Image(src=image_content, w="100%", h="auto"),
|
||||
dmc.Text(f"{vector[:5]}", size="xs"),
|
||||
]
|
||||
|
||||
return children
|
||||
|
||||
|
||||
app = APP()
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
from typing import Optional, TypedDict, Union
|
||||
|
||||
|
||||
@dataclass
|
||||
class CellClickedEvent:
|
||||
class CellClickedEvent(TypedDict):
|
||||
"""
|
||||
- value (boolean I number | string I dict | list; optional): value of the clicked cell.
|
||||
- colId (boolean I number I string I dict | list; optional): column where the cell was clicked.
|
||||
|
||||
Reference in New Issue
Block a user