mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
feat(visualizer): implement image selection and display from grid
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
|
import io
|
||||||
from typing import Any, Dict, List, Optional, Union, cast
|
from typing import Any, Dict, List, Optional, Union, cast
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -9,6 +10,21 @@ from tqdm.auto import tqdm
|
|||||||
from transformers import AutoImageProcessor, AutoModel
|
from transformers import AutoImageProcessor, AutoModel
|
||||||
|
|
||||||
|
|
||||||
|
def pil_image_to_bytes(image: Image.Image, format: str = "PNG") -> bytes:
    """Serialize a PIL image into an in-memory byte string.

    Args:
        image: The PIL image to encode.
        format: Name of the target encoding handed to ``image.save``
            (for example ``"PNG"`` or ``"JPEG"``).

    Returns:
        The encoded image data.
    """
    # Encode into a scratch stream and copy the bytes out before it is closed.
    with io.BytesIO() as stream:
        image.save(stream, format=format)
        encoded = stream.getvalue()
    return encoded
|
||||||
|
|
||||||
|
|
||||||
class FeatureRetrieval:
|
class FeatureRetrieval:
|
||||||
"""Singleton feature retrieval manager for image feature extraction."""
|
"""Singleton feature retrieval manager for image feature extraction."""
|
||||||
|
|
||||||
@@ -98,7 +114,7 @@ class FeatureRetrieval:
|
|||||||
"id": i + j,
|
"id": i + j,
|
||||||
"label": batch_labels[j],
|
"label": batch_labels[j],
|
||||||
"vector": cls_tokens[j].numpy(),
|
"vector": cls_tokens[j].numpy(),
|
||||||
"binary": batch_imgs[j].tobytes(),
|
"binary": pil_image_to_bytes(batch_imgs[j]),
|
||||||
}
|
}
|
||||||
for j in range(actual_batch_size)
|
for j in range(actual_batch_size)
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,17 +1,16 @@
|
|||||||
import base64
|
import base64
|
||||||
import datetime
|
|
||||||
import io
|
import io
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import dash_ag_grid as dag
|
import dash_ag_grid as dag
|
||||||
import dash_mantine_components as dmc
|
import dash_mantine_components as dmc
|
||||||
|
import numpy as np
|
||||||
from dash import Dash, Input, Output, State, callback, dcc, html
|
from dash import Dash, Input, Output, State, callback, dcc, html
|
||||||
from database import db_manager
|
from database import db_manager
|
||||||
from feature_retrieval import FeatureRetrieval
|
from feature_retrieval import FeatureRetrieval
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from transformers import AutoImageProcessor, AutoModel
|
from transformers import AutoImageProcessor, AutoModel
|
||||||
|
from visualizer.events import CellClickedEvent
|
||||||
from .events import CellClickedEvent
|
|
||||||
|
|
||||||
|
|
||||||
class APP(Dash):
|
class APP(Dash):
|
||||||
@@ -59,12 +58,59 @@ class APP(Dash):
|
|||||||
),
|
),
|
||||||
dmc.Flex(
|
dmc.Flex(
|
||||||
[
|
[
|
||||||
html.Div(id="output-image-upload"),
|
html.Div(
|
||||||
html.Div(id="output-image-select"),
|
[
|
||||||
|
html.H4(
|
||||||
|
"Uploaded Image",
|
||||||
|
style={"textAlign": "center"},
|
||||||
|
),
|
||||||
|
html.Div(
|
||||||
|
id="output-image-upload",
|
||||||
|
style={
|
||||||
|
"minWidth": "300px",
|
||||||
|
"minHeight": "300px",
|
||||||
|
"display": "flex",
|
||||||
|
"flexDirection": "column",
|
||||||
|
"alignItems": "center",
|
||||||
|
"justifyContent": "center",
|
||||||
|
"border": "1px dashed #ccc",
|
||||||
|
"borderRadius": "5px",
|
||||||
|
"padding": "10px",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
],
|
||||||
|
style={"flex": 1, "maxWidth": "45%"},
|
||||||
|
),
|
||||||
|
dmc.Divider(
|
||||||
|
variant="solid", orientation="vertical", size="sm"
|
||||||
|
),
|
||||||
|
html.Div(
|
||||||
|
[
|
||||||
|
html.H4(
|
||||||
|
"Selected Image",
|
||||||
|
style={"textAlign": "center"},
|
||||||
|
),
|
||||||
|
html.Div(
|
||||||
|
id="output-image-select",
|
||||||
|
style={
|
||||||
|
"minWidth": "300px",
|
||||||
|
"minHeight": "300px",
|
||||||
|
"display": "flex",
|
||||||
|
"flexDirection": "column",
|
||||||
|
"alignItems": "center",
|
||||||
|
"justifyContent": "center",
|
||||||
|
"border": "1px dashed #ccc",
|
||||||
|
"borderRadius": "5px",
|
||||||
|
"padding": "10px",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
],
|
||||||
|
style={"flex": 1, "maxWidth": "45%"},
|
||||||
|
),
|
||||||
],
|
],
|
||||||
gap="md",
|
gap="md",
|
||||||
justify="center",
|
justify="center",
|
||||||
align="center",
|
align="stretch",
|
||||||
direction="row",
|
direction="row",
|
||||||
wrap="wrap",
|
wrap="wrap",
|
||||||
),
|
),
|
||||||
@@ -89,9 +135,9 @@ class APP(Dash):
|
|||||||
State("upload-image", "last_modified"),
|
State("upload-image", "last_modified"),
|
||||||
)
|
)
|
||||||
def update_output(
|
def update_output(
|
||||||
list_of_contents: Optional[List[str]],
|
image_content: Optional[str],
|
||||||
list_of_names: Optional[List[str]],
|
filename: Optional[str],
|
||||||
list_of_dates: Optional[List[int] | List[float]],
|
timestamp: Optional[int | float],
|
||||||
):
|
):
|
||||||
def parse_base64_to_pil(contents: str) -> Image.Image:
|
def parse_base64_to_pil(contents: str) -> Image.Image:
|
||||||
"""Parse base64 string to PIL Image."""
|
"""Parse base64 string to PIL Image."""
|
||||||
@@ -100,17 +146,8 @@ def update_output(
|
|||||||
img_bytes = base64.b64decode(base64_str)
|
img_bytes = base64.b64decode(base64_str)
|
||||||
return Image.open(io.BytesIO(img_bytes))
|
return Image.open(io.BytesIO(img_bytes))
|
||||||
|
|
||||||
if (
|
if image_content is not None and filename is not None and timestamp is not None:
|
||||||
list_of_contents is not None
|
pil_image = parse_base64_to_pil(image_content)
|
||||||
and list_of_names is not None
|
|
||||||
and list_of_dates is not None
|
|
||||||
):
|
|
||||||
# Process first uploaded image for similarity search
|
|
||||||
filename = list_of_names[0]
|
|
||||||
uploaddate = list_of_dates[0]
|
|
||||||
imagecontent = list_of_contents[0]
|
|
||||||
|
|
||||||
pil_image = parse_base64_to_pil(imagecontent)
|
|
||||||
|
|
||||||
# Extract feature vector using DINOv2
|
# Extract feature vector using DINOv2
|
||||||
feature_vector = APP._feature_retrieval.extract_single_image_feature(pil_image)
|
feature_vector = APP._feature_retrieval.extract_single_image_feature(pil_image)
|
||||||
@@ -134,11 +171,9 @@ def update_output(
|
|||||||
|
|
||||||
# Display uploaded images
|
# Display uploaded images
|
||||||
children = [
|
children = [
|
||||||
html.H5(filename),
|
|
||||||
html.H6(str(datetime.datetime.fromtimestamp(uploaddate))),
|
|
||||||
# HTML images accept base64 encoded strings in same format
|
# HTML images accept base64 encoded strings in same format
|
||||||
# that is supplied by the upload
|
# that is supplied by the upload
|
||||||
dmc.Image(src=imagecontent),
|
dmc.Image(src=image_content, w="100%", h="auto"),
|
||||||
dmc.Text(f"{feature_vector[:5]}", size="xs"),
|
dmc.Text(f"{feature_vector[:5]}", size="xs"),
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -146,7 +181,7 @@ def update_output(
|
|||||||
else:
|
else:
|
||||||
# When contents is None
|
# When contents is None
|
||||||
# Exclude 'vector' and 'binary' columns as they are not JSON serializable
|
# Exclude 'vector' and 'binary' columns as they are not JSON serializable
|
||||||
df = db_manager.table.search().select(["id", "label"]).limit(1000).to_polars()
|
df = db_manager.table.search().select(["id", "label"]).limit(100).to_polars()
|
||||||
|
|
||||||
row_data = df.to_dicts()
|
row_data = df.to_dicts()
|
||||||
|
|
||||||
@@ -159,17 +194,81 @@ def update_output(
|
|||||||
|
|
||||||
|
|
||||||
def _stored_image_to_data_uri(binary_data: bytes) -> Optional[str]:
    """Convert image bytes stored in the database into a ``data:`` URI.

    Args:
        binary_data: Either an encoded PNG/JPEG file, or legacy raw pixel
            bytes for a 32x32 RGB image.

    Returns:
        A base64 ``data:`` URI usable as an image ``src``, or ``None`` when
        the bytes are neither PNG/JPEG nor exactly 32*32*3 raw pixel bytes.
    """
    # PNG files start with 89 50 4E 47 (\x89PNG); JPEG files start with FF D8 FF.
    if binary_data[:4] == b"\x89PNG":
        mime_type, payload = "image/png", binary_data
    elif binary_data[:3] == b"\xff\xd8\xff":
        mime_type, payload = "image/jpeg", binary_data
    else:
        # Legacy format: raw pixel bytes (CIFAR-10 images are 32x32 RGB).
        # Check the size first so a malformed blob does not raise in reshape.
        if len(binary_data) != 32 * 32 * 3:
            return None
        img_array = np.frombuffer(binary_data, dtype=np.uint8).reshape(32, 32, 3)
        buffer = io.BytesIO()
        # Re-encode as PNG so the browser receives a real image format.
        Image.fromarray(img_array).save(buffer, format="PNG")
        mime_type, payload = "image/png", buffer.getvalue()

    base64_str = base64.b64encode(payload).decode("utf-8")
    return f"data:{mime_type};base64,{base64_str}"


@callback(
    Output("output-image-select", "children"),
    Input("ag-grid", "cellClicked"),
    State("ag-grid", "rowData"),
)
def update_images_comparison(
    clicked_event: Optional[CellClickedEvent],
    row_data: Optional[List[dict]],
):
    """Show the image belonging to the grid row the user clicked.

    Args:
        clicked_event: The grid's ``cellClicked`` payload; its ``rowIndex``
            entry is used to pick a row.
        row_data: The grid's current ``rowData``; each row is expected to
            carry an ``id`` entry.

    Returns:
        A list of components (the image and the first five vector values)
        for ``output-image-select``, or an empty list when nothing can be
        shown (no click, bad index, missing id, no database match, or
        undecodable image bytes).
    """
    if not clicked_event or not row_data:
        return []

    # NOTE(review): rowIndex is presumably the displayed row position; if the
    # grid is sorted or filtered it may not match the rowData order - confirm.
    try:
        row_index = int(clicked_event.get("rowIndex"))
    except (TypeError, ValueError):
        return []
    # Reject negative indices as well, which would otherwise wrap around.
    if not 0 <= row_index < len(row_data):
        return []

    # Both arguments come from the browser, so force the id to an integer
    # before it is interpolated into the filter expression.
    try:
        image_id = int(row_data[row_index].get("id"))
    except (TypeError, ValueError):
        return []

    # Query database for the stored image and vector using the id
    result = (
        db_manager.table.search()
        .where(f"id = {image_id}")
        .select(["id", "label", "vector", "binary"])
        .limit(1)
        .to_polars()
    )
    if result.height == 0:
        return []

    record = result.row(0, named=True)
    image_content = _stored_image_to_data_uri(record["binary"])
    if image_content is None:
        return []

    # Display selected image
    return [
        dmc.Image(src=image_content, w="100%", h="auto"),
        dmc.Text(f"{record['vector'][:5]}", size="xs"),
    ]
|
||||||
|
|
||||||
|
|
||||||
app = APP()
|
app = APP()
|
||||||
|
|||||||
@@ -1,9 +1,7 @@
|
|||||||
from dataclasses import dataclass
|
from typing import Optional, TypedDict, Union
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
class CellClickedEvent(TypedDict):
|
||||||
class CellClickedEvent:
|
|
||||||
"""
|
"""
|
||||||
- value (boolean | number | string | dict | list; optional): value of the clicked cell.
|
- value (boolean | number | string | dict | list; optional): value of the clicked cell.
|
||||||
- colId (boolean | number | string | dict | list; optional): column where the cell was clicked.
|
- colId (boolean | number | string | dict | list; optional): column where the cell was clicked.
|
||||||
|
|||||||
Reference in New Issue
Block a user