diff --git a/mini-nav/feature_retrieval.py b/mini-nav/feature_retrieval.py index 0e27476..0181fd5 100644 --- a/mini-nav/feature_retrieval.py +++ b/mini-nav/feature_retrieval.py @@ -1,3 +1,4 @@ +import io from typing import Any, Dict, List, Optional, Union, cast import torch @@ -9,6 +10,21 @@ from tqdm.auto import tqdm from transformers import AutoImageProcessor, AutoModel +def pil_image_to_bytes(image: Image.Image, format: str = "PNG") -> bytes: + """Convert a PIL Image to bytes in the specified format. + + Args: + image: PIL Image to convert. + format: Image format (e.g., 'PNG', 'JPEG'). + + Returns: + bytes: The encoded image bytes. + """ + buffer = io.BytesIO() + image.save(buffer, format=format) + return buffer.getvalue() + + class FeatureRetrieval: """Singleton feature retrieval manager for image feature extraction.""" @@ -98,7 +114,7 @@ class FeatureRetrieval: "id": i + j, "label": batch_labels[j], "vector": cls_tokens[j].numpy(), - "binary": batch_imgs[j].tobytes(), + "binary": pil_image_to_bytes(batch_imgs[j]), } for j in range(actual_batch_size) ] diff --git a/mini-nav/visualizer/app.py b/mini-nav/visualizer/app.py index 193c52b..73ee2a3 100644 --- a/mini-nav/visualizer/app.py +++ b/mini-nav/visualizer/app.py @@ -1,17 +1,16 @@ import base64 -import datetime import io from typing import List, Optional import dash_ag_grid as dag import dash_mantine_components as dmc +import numpy as np from dash import Dash, Input, Output, State, callback, dcc, html from database import db_manager from feature_retrieval import FeatureRetrieval from PIL import Image from transformers import AutoImageProcessor, AutoModel - -from .events import CellClickedEvent +from visualizer.events import CellClickedEvent class APP(Dash): @@ -59,12 +58,59 @@ class APP(Dash): ), dmc.Flex( [ - html.Div(id="output-image-upload"), - html.Div(id="output-image-select"), + html.Div( + [ + html.H4( + "Uploaded Image", + style={"textAlign": "center"}, + ), + html.Div( + id="output-image-upload", + style={ + "minWidth": "300px", + "minHeight": "300px", + "display": "flex", + "flexDirection": "column", + "alignItems": "center", + "justifyContent": "center", + "border": "1px dashed #ccc", + "borderRadius": "5px", + "padding": "10px", + }, + ), + ], + style={"flex": 1, "maxWidth": "45%"}, + ), + dmc.Divider( + variant="solid", orientation="vertical", size="sm" + ), + html.Div( + [ + html.H4( + "Selected Image", + style={"textAlign": "center"}, + ), + html.Div( + id="output-image-select", + style={ + "minWidth": "300px", + "minHeight": "300px", + "display": "flex", + "flexDirection": "column", + "alignItems": "center", + "justifyContent": "center", + "border": "1px dashed #ccc", + "borderRadius": "5px", + "padding": "10px", + }, + ), + ], + style={"flex": 1, "maxWidth": "45%"}, + ), ], gap="md", justify="center", - align="center", + align="stretch", direction="row", wrap="wrap", ), @@ -89,9 +135,9 @@ class APP(Dash): State("upload-image", "last_modified"), ) def update_output( - list_of_contents: Optional[List[str]], - list_of_names: Optional[List[str]], - list_of_dates: Optional[List[int] | List[float]], + image_content: Optional[str], + filename: Optional[str], + timestamp: Optional[int | float], ): def parse_base64_to_pil(contents: str) -> Image.Image: """Parse base64 string to PIL Image.""" @@ -100,17 +146,8 @@ def update_output( img_bytes = base64.b64decode(base64_str) return Image.open(io.BytesIO(img_bytes)) - if ( - list_of_contents is not None - and list_of_names is not None - and list_of_dates is not None - ): - # Process first uploaded image for similarity search - filename = list_of_names[0] - uploaddate = list_of_dates[0] - imagecontent = list_of_contents[0] - - pil_image = parse_base64_to_pil(imagecontent) + if image_content is not None and filename is not None and timestamp is not None: + pil_image = parse_base64_to_pil(image_content) # Extract feature vector using DINOv2 feature_vector = APP._feature_retrieval.extract_single_image_feature(pil_image) @@ -134,11 +171,9 @@ def update_output( # Display uploaded images children = [ - html.H5(filename), - html.H6(str(datetime.datetime.fromtimestamp(uploaddate))), # HTML images accept base64 encoded strings in same format # that is supplied by the upload - dmc.Image(src=imagecontent), + dmc.Image(src=image_content, w="100%", h="auto"), dmc.Text(f"{feature_vector[:5]}", size="xs"), ] @@ -146,7 +181,7 @@ def update_output( else: # When contents is None # Exclude 'vector' and 'binary' columns as they are not JSON serializable - df = db_manager.table.search().select(["id", "label"]).limit(1000).to_polars() + df = db_manager.table.search().select(["id", "label"]).limit(100).to_polars() row_data = df.to_dicts() @@ -159,17 +194,81 @@ def update_output( @callback( - Input("ag-grid", "cellClicked"), - State("ag-grid", "row_data"), Output("output-image-select", "children"), + Input("ag-grid", "cellClicked"), + State("ag-grid", "rowData"), ) def update_images_comparison( - clicked_event: Optional[CellClickedEvent], row_data: Optional[dict] + clicked_event: Optional[CellClickedEvent], + row_data: Optional[List[dict]], ): - if clicked_event is None or CellClickedEvent.rowIndex is None or row_data is None: + if ( + clicked_event is None + or clicked_event["rowIndex"] is None + or row_data is None + or len(row_data) == 0 + ): return [] - return [] + # Get the selected row's data + row_index = int(clicked_event["rowIndex"]) + if row_index >= len(row_data): + return [] + + selected_row = row_data[row_index] + image_id = selected_row.get("id") + + if image_id is None: + return [] + + # Query database for binary data using the id + result = ( + db_manager.table.search() + .where(f"id = {image_id}") + .select(["id", "label", "vector", "binary"]) + .limit(1) + .to_polars() + ) + + if result.height == 0: + return [] + + # Get binary data + binary_data = result.row(0, named=True)["binary"] + vector = result.row(0, named=True)["vector"] + + # Try to detect if binary_data is a valid image format (PNG/JPEG) or raw pixels + # PNG files start with bytes: 89 50 4E 47 (hex) = \x89PNG + # JPEG files start with bytes: FF D8 FF + is_png = binary_data[:4] == b"\x89PNG" + is_jpeg = binary_data[:3] == b"\xff\xd8\xff" + + if is_png or is_jpeg: + # Binary data is already in a valid image format + mime_type = "image/png" if is_png else "image/jpeg" + base64_str = base64.b64encode(binary_data).decode("utf-8") + image_content = f"data:{mime_type};base64,{base64_str}" + else: + # Legacy format: raw pixel bytes (CIFAR-10 images are 32x32 RGB) + img_array = np.frombuffer(binary_data, dtype=np.uint8).reshape(32, 32, 3) + pil_image = Image.fromarray(img_array) + + # Encode as PNG to get proper image format + buffer = io.BytesIO() + pil_image.save(buffer, format="PNG") + buffer.seek(0) + + # Convert to base64 with correct MIME type + base64_str = base64.b64encode(buffer.getvalue()).decode("utf-8") + image_content = f"data:image/png;base64,{base64_str}" + + # Display selected image + children = [ + dmc.Image(src=image_content, w="100%", h="auto"), + dmc.Text(f"{vector[:5]}", size="xs"), + ] + + return children app = APP() diff --git a/mini-nav/visualizer/events.py b/mini-nav/visualizer/events.py index 951686a..0559ce2 100644 --- a/mini-nav/visualizer/events.py +++ b/mini-nav/visualizer/events.py @@ -1,9 +1,7 @@ -from dataclasses import dataclass -from typing import Optional, Union +from typing import Optional, TypedDict, Union -@dataclass -class CellClickedEvent: +class CellClickedEvent(TypedDict): """ - value (boolean I number | string I dict | list; optional): value of the clicked cell. - colId (boolean I number I string I dict | list; optional): column where the cell was clicked.