feat(visualizer): add image selection and binary data storage

This commit is contained in:
2026-02-06 16:59:22 +08:00
parent e859fef2b3
commit 0b10ab6cfa
5 changed files with 134 additions and 68 deletions

View File

@@ -9,6 +9,7 @@ db_schema = pa.schema(
pa.field("id", pa.int32()), pa.field("id", pa.int32()),
pa.field("label", pa.string()), pa.field("label", pa.string()),
pa.field("vector", pa.list_(pa.float32(), 1024)), pa.field("vector", pa.list_(pa.float32(), 1024)),
pa.field("binary", pa.binary()),
] ]
) )

View File

@@ -4,6 +4,7 @@ import torch
from database import db_manager from database import db_manager
from datasets import load_dataset from datasets import load_dataset
from PIL import Image from PIL import Image
from PIL.PngImagePlugin import PngImageFile
from tqdm.auto import tqdm from tqdm.auto import tqdm
from transformers import AutoImageProcessor, AutoModel from transformers import AutoImageProcessor, AutoModel
@@ -48,7 +49,7 @@ class FeatureRetrieval:
@torch.no_grad() @torch.no_grad()
def establish_database( def establish_database(
self, self,
images: List[Any], images: List[PngImageFile],
labels: List[int] | List[str], labels: List[int] | List[str],
batch_size: int = 64, batch_size: int = 64,
label_map: Optional[Dict[int, str] | List[str]] = None, label_map: Optional[Dict[int, str] | List[str]] = None,
@@ -97,6 +98,7 @@ class FeatureRetrieval:
"id": i + j, "id": i + j,
"label": batch_labels[j], "label": batch_labels[j],
"vector": cls_tokens[j].numpy(), "vector": cls_tokens[j].numpy(),
"binary": batch_imgs[j].tobytes(),
} }
for j in range(actual_batch_size) for j in range(actual_batch_size)
] ]

0
mini-nav/utils/image.py Normal file
View File

View File

@@ -11,6 +11,8 @@ from feature_retrieval import FeatureRetrieval
from PIL import Image from PIL import Image
from transformers import AutoImageProcessor, AutoModel from transformers import AutoImageProcessor, AutoModel
from .events import CellClickedEvent
class APP(Dash): class APP(Dash):
"""Singleton Dash Application""" """Singleton Dash Application"""
@@ -33,18 +35,6 @@ class APP(Dash):
model = AutoModel.from_pretrained("facebook/dinov2-large") model = AutoModel.from_pretrained("facebook/dinov2-large")
APP._feature_retrieval = FeatureRetrieval(processor, model) APP._feature_retrieval = FeatureRetrieval(processor, model)
df = (
db_manager.table.search()
.select(["id", "label", "vector"])
.limit(1000)
.to_polars()
)
columnDefs = [
{"headerName": column.capitalize(), "field": column}
for column in df.columns
]
self.layout = dmc.MantineProvider( self.layout = dmc.MantineProvider(
dmc.Container( dmc.Container(
dmc.Flex( dmc.Flex(
@@ -64,15 +54,21 @@ class APP(Dash):
"textAlign": "center", "textAlign": "center",
"margin": "10px", "margin": "10px",
}, },
# Allow multiple files to be uploaded # Disallow multiple files to be uploaded
multiple=True, multiple=False,
), ),
html.Div(id="output-image-upload"), dmc.Flex(
dag.AgGrid( [
id="ag-grid", html.Div(id="output-image-upload"),
rowData=df.to_dicts(), html.Div(id="output-image-select"),
columnDefs=columnDefs, ],
gap="md",
justify="center",
align="center",
direction="row",
wrap="wrap",
), ),
dag.AgGrid(id="ag-grid"),
], ],
gap="md", gap="md",
justify="center", justify="center",
@@ -83,63 +79,97 @@ class APP(Dash):
) )
) )
@callback(
Output("output-image-upload", "children"), @callback(
Output("ag-grid", "rowData"), Output("output-image-upload", "children"),
Input("upload-image", "contents"), Output("ag-grid", "rowData"),
State("upload-image", "filename"), Output("ag-grid", "columnDefs"),
State("upload-image", "last_modified"), Input("upload-image", "contents"),
) State("upload-image", "filename"),
def update_output( State("upload-image", "last_modified"),
list_of_contents: List[str], )
list_of_names: List[str], def update_output(
list_of_dates: List[int] | List[float], list_of_contents: Optional[List[str]],
list_of_names: Optional[List[str]],
list_of_dates: Optional[List[int] | List[float]],
):
def parse_base64_to_pil(contents: str) -> Image.Image:
"""Parse base64 string to PIL Image."""
# Remove data URI prefix (e.g., "data:image/png;base64,")
base64_str = contents.split(",")[1]
img_bytes = base64.b64decode(base64_str)
return Image.open(io.BytesIO(img_bytes))
if (
list_of_contents is not None
and list_of_names is not None
and list_of_dates is not None
): ):
def parse_base64_to_pil(contents: str) -> Image.Image: # Process first uploaded image for similarity search
"""Parse base64 string to PIL Image.""" filename = list_of_names[0]
# Remove data URI prefix (e.g., "data:image/png;base64,") uploaddate = list_of_dates[0]
base64_str = contents.split(",")[1] imagecontent = list_of_contents[0]
img_bytes = base64.b64decode(base64_str)
return Image.open(io.BytesIO(img_bytes))
if list_of_contents is not None: pil_image = parse_base64_to_pil(imagecontent)
# Process first uploaded image for similarity search
filename = list_of_names[0]
uploaddate = list_of_dates[0]
imagecontent = list_of_contents[0]
pil_image = parse_base64_to_pil(imagecontent) # Extract feature vector using DINOv2
feature_vector = APP._feature_retrieval.extract_single_image_feature(pil_image)
# Extract feature vector using DINOv2 # Search for similar images in database
feature_vector = APP._feature_retrieval.extract_single_image_feature( # Exclude 'vector' and 'binary' columns as they are not JSON serializable
pil_image results_df = (
) db_manager.table.search(feature_vector)
.select(["id", "label"])
.limit(10)
.to_polars()
)
# Search for similar images in database # Convert to AgGrid row format
results_df = ( row_data = results_df.to_dicts()
db_manager.table.search(feature_vector)
.select(["id", "label", "vector"])
.limit(10)
.to_polars()
)
# Convert to AgGrid row format columnDefs = [
row_data = results_df.to_dicts() {"headerName": column.capitalize(), "field": column}
for column in results_df.columns
]
# Display uploaded images # Display uploaded images
children = [ children = [
html.H5(filename), html.H5(filename),
html.H6(str(datetime.datetime.fromtimestamp(uploaddate))), html.H6(str(datetime.datetime.fromtimestamp(uploaddate))),
# HTML images accept base64 encoded strings in same format # HTML images accept base64 encoded strings in same format
# that is supplied by the upload # that is supplied by the upload
dmc.Image(src=imagecontent), dmc.Image(src=imagecontent),
dmc.Text(f"{feature_vector[:5]}", size="xs"), dmc.Text(f"{feature_vector[:5]}", size="xs"),
] ]
return children, row_data return children, row_data, columnDefs
else:
# When contents is None
# Exclude 'vector' and 'binary' columns as they are not JSON serializable
df = db_manager.table.search().select(["id", "label"]).limit(1000).to_polars()
# Return empty if no content row_data = df.to_dicts()
return [], []
columnDefs = [
{"headerName": column.capitalize(), "field": column}
for column in df.columns
]
return [], row_data, columnDefs
@callback(
Input("ag-grid", "cellClicked"),
State("ag-grid", "row_data"),
Output("output-image-select", "children"),
)
def update_images_comparison(
clicked_event: Optional[CellClickedEvent], row_data: Optional[dict]
):
if clicked_event is None or CellClickedEvent.rowIndex is None or row_data is None:
return []
return []
app = APP() app = APP()

View File

@@ -0,0 +1,33 @@
from dataclasses import dataclass
from typing import Optional, Union
@dataclass
class CellClickedEvent:
"""
- value (boolean I number | string I dict | list; optional): value of the clicked cell.
- colId (boolean I number I string I dict | list; optional): column where the cell was clicked.
- rowIndex (number; optional): rowIndex, typically a row number.
- rowId (boolean I number I string I dict | list; optional): Row Id from the grid, this could be a number automatically, orset via getRowId.
- timestamp (boolean I number I string I dict I list; optional): timestamp of last action.
"""
value: Optional[Union[bool, int, float, str, dict, list]]
"""
- value (boolean I number | string I dict | list; optional): value of the clicked cell.
"""
colId: Optional[Union[bool, int, float, str, dict, list]]
"""
- colId (boolean I number I string I dict | list; optional): column where the cell was clicked.
"""
rowIndex: Optional[Union[int, float]]
"""
- rowIndex (number; optional): rowIndex, typically a row number.
"""
timestamp: Optional[Union[bool, int, float, str, dict, list]]
"""
- timestamp (boolean I number I string I dict I list; optional): timestamp of last action.
"""