feat(visualizer): add image selection and binary data storage

This commit is contained in:
2026-02-06 16:59:22 +08:00
parent e859fef2b3
commit 0b10ab6cfa
5 changed files with 134 additions and 68 deletions

View File

@@ -9,6 +9,7 @@ db_schema = pa.schema(
pa.field("id", pa.int32()),
pa.field("label", pa.string()),
pa.field("vector", pa.list_(pa.float32(), 1024)),
pa.field("binary", pa.binary()),
]
)

View File

@@ -4,6 +4,7 @@ import torch
from database import db_manager
from datasets import load_dataset
from PIL import Image
from PIL.PngImagePlugin import PngImageFile
from tqdm.auto import tqdm
from transformers import AutoImageProcessor, AutoModel
@@ -48,7 +49,7 @@ class FeatureRetrieval:
@torch.no_grad()
def establish_database(
self,
images: List[Any],
images: List[PngImageFile],
labels: List[int] | List[str],
batch_size: int = 64,
label_map: Optional[Dict[int, str] | List[str]] = None,
@@ -97,6 +98,7 @@ class FeatureRetrieval:
"id": i + j,
"label": batch_labels[j],
"vector": cls_tokens[j].numpy(),
"binary": batch_imgs[j].tobytes(),
}
for j in range(actual_batch_size)
]

0
mini-nav/utils/image.py Normal file
View File

View File

@@ -11,6 +11,8 @@ from feature_retrieval import FeatureRetrieval
from PIL import Image
from transformers import AutoImageProcessor, AutoModel
from .events import CellClickedEvent
class APP(Dash):
"""Singleton Dash Application"""
@@ -33,18 +35,6 @@ class APP(Dash):
model = AutoModel.from_pretrained("facebook/dinov2-large")
APP._feature_retrieval = FeatureRetrieval(processor, model)
df = (
db_manager.table.search()
.select(["id", "label", "vector"])
.limit(1000)
.to_polars()
)
columnDefs = [
{"headerName": column.capitalize(), "field": column}
for column in df.columns
]
self.layout = dmc.MantineProvider(
dmc.Container(
dmc.Flex(
@@ -64,15 +54,21 @@ class APP(Dash):
"textAlign": "center",
"margin": "10px",
},
# Allow multiple files to be uploaded
multiple=True,
# Disallow multiple files to be uploaded
multiple=False,
),
html.Div(id="output-image-upload"),
dag.AgGrid(
id="ag-grid",
rowData=df.to_dicts(),
columnDefs=columnDefs,
dmc.Flex(
[
html.Div(id="output-image-upload"),
html.Div(id="output-image-select"),
],
gap="md",
justify="center",
align="center",
direction="row",
wrap="wrap",
),
dag.AgGrid(id="ag-grid"),
],
gap="md",
justify="center",
@@ -83,63 +79,97 @@ class APP(Dash):
)
)
@callback(
Output("output-image-upload", "children"),
Output("ag-grid", "rowData"),
Input("upload-image", "contents"),
State("upload-image", "filename"),
State("upload-image", "last_modified"),
)
def update_output(
list_of_contents: List[str],
list_of_names: List[str],
list_of_dates: List[int] | List[float],
@callback(
Output("output-image-upload", "children"),
Output("ag-grid", "rowData"),
Output("ag-grid", "columnDefs"),
Input("upload-image", "contents"),
State("upload-image", "filename"),
State("upload-image", "last_modified"),
)
def update_output(
list_of_contents: Optional[List[str]],
list_of_names: Optional[List[str]],
list_of_dates: Optional[List[int] | List[float]],
):
def parse_base64_to_pil(contents: str) -> Image.Image:
"""Parse base64 string to PIL Image."""
# Remove data URI prefix (e.g., "data:image/png;base64,")
base64_str = contents.split(",")[1]
img_bytes = base64.b64decode(base64_str)
return Image.open(io.BytesIO(img_bytes))
if (
list_of_contents is not None
and list_of_names is not None
and list_of_dates is not None
):
def parse_base64_to_pil(contents: str) -> Image.Image:
"""Parse base64 string to PIL Image."""
# Remove data URI prefix (e.g., "data:image/png;base64,")
base64_str = contents.split(",")[1]
img_bytes = base64.b64decode(base64_str)
return Image.open(io.BytesIO(img_bytes))
# Process first uploaded image for similarity search
filename = list_of_names[0]
uploaddate = list_of_dates[0]
imagecontent = list_of_contents[0]
if list_of_contents is not None:
# Process first uploaded image for similarity search
filename = list_of_names[0]
uploaddate = list_of_dates[0]
imagecontent = list_of_contents[0]
pil_image = parse_base64_to_pil(imagecontent)
pil_image = parse_base64_to_pil(imagecontent)
# Extract feature vector using DINOv2
feature_vector = APP._feature_retrieval.extract_single_image_feature(pil_image)
# Extract feature vector using DINOv2
feature_vector = APP._feature_retrieval.extract_single_image_feature(
pil_image
)
# Search for similar images in database
# Exclude 'vector' and 'binary' columns as they are not JSON serializable
results_df = (
db_manager.table.search(feature_vector)
.select(["id", "label"])
.limit(10)
.to_polars()
)
# Search for similar images in database
results_df = (
db_manager.table.search(feature_vector)
.select(["id", "label", "vector"])
.limit(10)
.to_polars()
)
# Convert to AgGrid row format
row_data = results_df.to_dicts()
# Convert to AgGrid row format
row_data = results_df.to_dicts()
columnDefs = [
{"headerName": column.capitalize(), "field": column}
for column in results_df.columns
]
# Display uploaded images
children = [
html.H5(filename),
html.H6(str(datetime.datetime.fromtimestamp(uploaddate))),
# HTML images accept base64 encoded strings in same format
# that is supplied by the upload
dmc.Image(src=imagecontent),
dmc.Text(f"{feature_vector[:5]}", size="xs"),
]
# Display uploaded images
children = [
html.H5(filename),
html.H6(str(datetime.datetime.fromtimestamp(uploaddate))),
# HTML images accept base64 encoded strings in same format
# that is supplied by the upload
dmc.Image(src=imagecontent),
dmc.Text(f"{feature_vector[:5]}", size="xs"),
]
return children, row_data
return children, row_data, columnDefs
else:
# When contents is None
# Exclude 'vector' and 'binary' columns as they are not JSON serializable
df = db_manager.table.search().select(["id", "label"]).limit(1000).to_polars()
# Return empty if no content
return [], []
row_data = df.to_dicts()
columnDefs = [
{"headerName": column.capitalize(), "field": column}
for column in df.columns
]
return [], row_data, columnDefs
@callback(
    # Dash requires dependencies in the order Output(s), Input(s), State(s).
    Output("output-image-select", "children"),
    Input("ag-grid", "cellClicked"),
    # The AgGrid property is ``rowData`` (camelCase), matching the Output
    # used by the upload callback.
    State("ag-grid", "rowData"),
)
def update_images_comparison(
    clicked_event: Optional[CellClickedEvent], row_data: Optional[List[dict]]
):
    """Build the children of ``output-image-select`` for a clicked grid cell.

    Args:
        clicked_event: Value of the grid's ``cellClicked`` property, or
            ``None`` when no cell has been clicked yet. Both a plain mapping
            and an object exposing ``rowIndex`` are accepted.
        row_data: Current value of the grid's ``rowData`` property, or
            ``None`` when the grid has no rows.

    Returns:
        A list of components to display. Currently always empty: rendering
        of the selected image is not implemented yet.
    """
    if clicked_event is None or row_data is None:
        return []

    # Inspect the event that was actually received, not the
    # ``CellClickedEvent`` class itself.
    if isinstance(clicked_event, dict):
        row_index = clicked_event.get("rowIndex")
    else:
        row_index = getattr(clicked_event, "rowIndex", None)
    if row_index is None:
        return []

    # TODO: look up the stored image for ``row_data[row_index]`` and render it.
    return []
app = APP()

View File

@@ -0,0 +1,33 @@
from dataclasses import dataclass
from typing import Optional, Union
@dataclass
class CellClickedEvent:
    """Payload describing a click on a grid cell.

    Every field is optional and defaults to ``None``.

    Attributes:
        value (boolean | number | string | dict | list; optional): value of
            the clicked cell.
        colId (boolean | number | string | dict | list; optional): column
            where the cell was clicked.
        rowIndex (number; optional): rowIndex, typically a row number.
        timestamp (boolean | number | string | dict | list; optional):
            timestamp of last action.
        rowId (boolean | number | string | dict | list; optional): Row Id
            from the grid; this could be a number assigned automatically,
            or set via getRowId.
    """

    # Value of the clicked cell.
    value: Optional[Union[bool, int, float, str, dict, list]] = None
    # Column where the cell was clicked.
    colId: Optional[Union[bool, int, float, str, dict, list]] = None
    # Index of the clicked row, typically a row number.
    rowIndex: Optional[Union[int, float]] = None
    # Timestamp of the last action.
    timestamp: Optional[Union[bool, int, float, str, dict, list]] = None
    # Row id from the grid. Declared last so the positional order of the
    # fields above is unchanged for existing callers.
    rowId: Optional[Union[bool, int, float, str, dict, list]] = None