From 0b10ab6cfa0babef078fe6cf1c50da44285ed2c5 Mon Sep 17 00:00:00 2001 From: SikongJueluo Date: Fri, 6 Feb 2026 16:59:22 +0800 Subject: [PATCH] feat(visualizer): add image selection and binary data storage --- mini-nav/database.py | 1 + mini-nav/feature_retrieval.py | 4 +- mini-nav/utils/image.py | 0 mini-nav/visualizer/app.py | 164 ++++++++++++++++++++-------------- mini-nav/visualizer/events.py | 33 +++++++ 5 files changed, 134 insertions(+), 68 deletions(-) create mode 100644 mini-nav/utils/image.py create mode 100644 mini-nav/visualizer/events.py diff --git a/mini-nav/database.py b/mini-nav/database.py index eccb8a2..acf1435 100644 --- a/mini-nav/database.py +++ b/mini-nav/database.py @@ -9,6 +9,7 @@ db_schema = pa.schema( pa.field("id", pa.int32()), pa.field("label", pa.string()), pa.field("vector", pa.list_(pa.float32(), 1024)), + pa.field("binary", pa.binary()), ] ) diff --git a/mini-nav/feature_retrieval.py b/mini-nav/feature_retrieval.py index f4cf345..0e27476 100644 --- a/mini-nav/feature_retrieval.py +++ b/mini-nav/feature_retrieval.py @@ -4,6 +4,7 @@ import torch from database import db_manager from datasets import load_dataset from PIL import Image +from PIL.PngImagePlugin import PngImageFile from tqdm.auto import tqdm from transformers import AutoImageProcessor, AutoModel @@ -48,7 +49,7 @@ class FeatureRetrieval: @torch.no_grad() def establish_database( self, - images: List[Any], + images: List[PngImageFile], labels: List[int] | List[str], batch_size: int = 64, label_map: Optional[Dict[int, str] | List[str]] = None, @@ -97,6 +98,7 @@ class FeatureRetrieval: "id": i + j, "label": batch_labels[j], "vector": cls_tokens[j].numpy(), + "binary": batch_imgs[j].tobytes(), } for j in range(actual_batch_size) ] diff --git a/mini-nav/utils/image.py b/mini-nav/utils/image.py new file mode 100644 index 0000000..e69de29 diff --git a/mini-nav/visualizer/app.py b/mini-nav/visualizer/app.py index 313125a..193c52b 100644 --- a/mini-nav/visualizer/app.py +++ 
b/mini-nav/visualizer/app.py @@ -11,6 +11,8 @@ from feature_retrieval import FeatureRetrieval from PIL import Image from transformers import AutoImageProcessor, AutoModel +from .events import CellClickedEvent + class APP(Dash): """Singleton Dash Application""" @@ -33,18 +35,6 @@ class APP(Dash): model = AutoModel.from_pretrained("facebook/dinov2-large") APP._feature_retrieval = FeatureRetrieval(processor, model) - df = ( - db_manager.table.search() - .select(["id", "label", "vector"]) - .limit(1000) - .to_polars() - ) - - columnDefs = [ - {"headerName": column.capitalize(), "field": column} - for column in df.columns - ] - self.layout = dmc.MantineProvider( dmc.Container( dmc.Flex( @@ -64,15 +54,21 @@ class APP(Dash): "textAlign": "center", "margin": "10px", }, - # Allow multiple files to be uploaded - multiple=True, + # Disallow multiple files to be uploaded + multiple=False, ), - html.Div(id="output-image-upload"), - dag.AgGrid( - id="ag-grid", - rowData=df.to_dicts(), - columnDefs=columnDefs, + dmc.Flex( + [ + html.Div(id="output-image-upload"), + html.Div(id="output-image-select"), + ], + gap="md", + justify="center", + align="center", + direction="row", + wrap="wrap", ), + dag.AgGrid(id="ag-grid"), ], gap="md", justify="center", @@ -83,63 +79,97 @@ class APP(Dash): ) ) - @callback( - Output("output-image-upload", "children"), - Output("ag-grid", "rowData"), - Input("upload-image", "contents"), - State("upload-image", "filename"), - State("upload-image", "last_modified"), - ) - def update_output( - list_of_contents: List[str], - list_of_names: List[str], - list_of_dates: List[int] | List[float], + +@callback( + Output("output-image-upload", "children"), + Output("ag-grid", "rowData"), + Output("ag-grid", "columnDefs"), + Input("upload-image", "contents"), + State("upload-image", "filename"), + State("upload-image", "last_modified"), +) +def update_output( + list_of_contents: Optional[List[str]], + list_of_names: Optional[List[str]], + list_of_dates: 
Optional[List[int] | List[float]], +): + def parse_base64_to_pil(contents: str) -> Image.Image: + """Parse base64 string to PIL Image.""" + # Remove data URI prefix (e.g., "data:image/png;base64,") + base64_str = contents.split(",")[1] + img_bytes = base64.b64decode(base64_str) + return Image.open(io.BytesIO(img_bytes)) + + if ( + list_of_contents is not None + and list_of_names is not None + and list_of_dates is not None ): - def parse_base64_to_pil(contents: str) -> Image.Image: - """Parse base64 string to PIL Image.""" - # Remove data URI prefix (e.g., "data:image/png;base64,") - base64_str = contents.split(",")[1] - img_bytes = base64.b64decode(base64_str) - return Image.open(io.BytesIO(img_bytes)) + # Process first uploaded image for similarity search + filename = list_of_names[0] + uploaddate = list_of_dates[0] + imagecontent = list_of_contents[0] - if list_of_contents is not None: - # Process first uploaded image for similarity search - filename = list_of_names[0] - uploaddate = list_of_dates[0] - imagecontent = list_of_contents[0] + pil_image = parse_base64_to_pil(imagecontent) - pil_image = parse_base64_to_pil(imagecontent) + # Extract feature vector using DINOv2 + feature_vector = APP._feature_retrieval.extract_single_image_feature(pil_image) - # Extract feature vector using DINOv2 - feature_vector = APP._feature_retrieval.extract_single_image_feature( - pil_image - ) + # Search for similar images in database + # Exclude 'vector' and 'binary' columns as they are not JSON serializable + results_df = ( + db_manager.table.search(feature_vector) + .select(["id", "label"]) + .limit(10) + .to_polars() + ) - # Search for similar images in database - results_df = ( - db_manager.table.search(feature_vector) - .select(["id", "label", "vector"]) - .limit(10) - .to_polars() - ) + # Convert to AgGrid row format + row_data = results_df.to_dicts() - # Convert to AgGrid row format - row_data = results_df.to_dicts() + columnDefs = [ + {"headerName": column.capitalize(), 
"field": column} + for column in results_df.columns + ] - # Display uploaded images - children = [ - html.H5(filename), - html.H6(str(datetime.datetime.fromtimestamp(uploaddate))), - # HTML images accept base64 encoded strings in same format - # that is supplied by the upload - dmc.Image(src=imagecontent), - dmc.Text(f"{feature_vector[:5]}", size="xs"), - ] + # Display uploaded images + children = [ + html.H5(filename), + html.H6(str(datetime.datetime.fromtimestamp(uploaddate))), + # HTML images accept base64 encoded strings in same format + # that is supplied by the upload + dmc.Image(src=imagecontent), + dmc.Text(f"{feature_vector[:5]}", size="xs"), + ] - return children, row_data + return children, row_data, columnDefs + else: + # When contents is None + # Exclude 'vector' and 'binary' columns as they are not JSON serializable + df = db_manager.table.search().select(["id", "label"]).limit(1000).to_polars() - # Return empty if no content - return [], [] + row_data = df.to_dicts() + + columnDefs = [ + {"headerName": column.capitalize(), "field": column} + for column in df.columns + ] + + return [], row_data, columnDefs + + +@callback( + Input("ag-grid", "cellClicked"), + State("ag-grid", "row_data"), + Output("output-image-select", "children"), +) +def update_images_comparison( + clicked_event: Optional[CellClickedEvent], row_data: Optional[dict] +): + if clicked_event is None or CellClickedEvent.rowIndex is None or row_data is None: + return [] + + return [] app = APP() diff --git a/mini-nav/visualizer/events.py b/mini-nav/visualizer/events.py new file mode 100644 index 0000000..951686a --- /dev/null +++ b/mini-nav/visualizer/events.py @@ -0,0 +1,33 @@ +from dataclasses import dataclass +from typing import Optional, Union + + +@dataclass +class CellClickedEvent: + """ + - value (boolean I number | string I dict | list; optional): value of the clicked cell. + - colId (boolean I number I string I dict | list; optional): column where the cell was clicked. 
+ - rowIndex (number; optional): rowIndex, typically a row number. + - rowId (boolean | number | string | dict | list; optional): Row Id from the grid, this could be a number automatically, or set via getRowId. + - timestamp (boolean | number | string | dict | list; optional): timestamp of last action. + """ + + value: Optional[Union[bool, int, float, str, dict, list]] + """ + - value (boolean | number | string | dict | list; optional): value of the clicked cell. + """ + + colId: Optional[Union[bool, int, float, str, dict, list]] + """ + - colId (boolean | number | string | dict | list; optional): column where the cell was clicked. + """ + + rowIndex: Optional[Union[int, float]] + """ + - rowIndex (number; optional): rowIndex, typically a row number. + """ + + timestamp: Optional[Union[bool, int, float, str, dict, list]] + """ + - timestamp (boolean | number | string | dict | list; optional): timestamp of last action. + """