feat(visualizer): add image selection and binary data storage

This commit is contained in:
2026-02-06 16:59:22 +08:00
parent e859fef2b3
commit 0b10ab6cfa
5 changed files with 134 additions and 68 deletions

View File

@@ -9,6 +9,7 @@ db_schema = pa.schema(
pa.field("id", pa.int32()), pa.field("id", pa.int32()),
pa.field("label", pa.string()), pa.field("label", pa.string()),
pa.field("vector", pa.list_(pa.float32(), 1024)), pa.field("vector", pa.list_(pa.float32(), 1024)),
pa.field("binary", pa.binary()),
] ]
) )

View File

@@ -4,6 +4,7 @@ import torch
from database import db_manager from database import db_manager
from datasets import load_dataset from datasets import load_dataset
from PIL import Image from PIL import Image
from PIL.PngImagePlugin import PngImageFile
from tqdm.auto import tqdm from tqdm.auto import tqdm
from transformers import AutoImageProcessor, AutoModel from transformers import AutoImageProcessor, AutoModel
@@ -48,7 +49,7 @@ class FeatureRetrieval:
@torch.no_grad() @torch.no_grad()
def establish_database( def establish_database(
self, self,
images: List[Any], images: List[PngImageFile],
labels: List[int] | List[str], labels: List[int] | List[str],
batch_size: int = 64, batch_size: int = 64,
label_map: Optional[Dict[int, str] | List[str]] = None, label_map: Optional[Dict[int, str] | List[str]] = None,
@@ -97,6 +98,7 @@ class FeatureRetrieval:
"id": i + j, "id": i + j,
"label": batch_labels[j], "label": batch_labels[j],
"vector": cls_tokens[j].numpy(), "vector": cls_tokens[j].numpy(),
"binary": batch_imgs[j].tobytes(),
} }
for j in range(actual_batch_size) for j in range(actual_batch_size)
] ]

0
mini-nav/utils/image.py Normal file
View File

View File

@@ -11,6 +11,8 @@ from feature_retrieval import FeatureRetrieval
from PIL import Image from PIL import Image
from transformers import AutoImageProcessor, AutoModel from transformers import AutoImageProcessor, AutoModel
from .events import CellClickedEvent
class APP(Dash): class APP(Dash):
"""Singleton Dash Application""" """Singleton Dash Application"""
@@ -33,18 +35,6 @@ class APP(Dash):
model = AutoModel.from_pretrained("facebook/dinov2-large") model = AutoModel.from_pretrained("facebook/dinov2-large")
APP._feature_retrieval = FeatureRetrieval(processor, model) APP._feature_retrieval = FeatureRetrieval(processor, model)
df = (
db_manager.table.search()
.select(["id", "label", "vector"])
.limit(1000)
.to_polars()
)
columnDefs = [
{"headerName": column.capitalize(), "field": column}
for column in df.columns
]
self.layout = dmc.MantineProvider( self.layout = dmc.MantineProvider(
dmc.Container( dmc.Container(
dmc.Flex( dmc.Flex(
@@ -64,15 +54,21 @@ class APP(Dash):
"textAlign": "center", "textAlign": "center",
"margin": "10px", "margin": "10px",
}, },
# Allow multiple files to be uploaded # Disallow multiple files to be uploaded
multiple=True, multiple=False,
), ),
dmc.Flex(
[
html.Div(id="output-image-upload"), html.Div(id="output-image-upload"),
dag.AgGrid( html.Div(id="output-image-select"),
id="ag-grid", ],
rowData=df.to_dicts(), gap="md",
columnDefs=columnDefs, justify="center",
align="center",
direction="row",
wrap="wrap",
), ),
dag.AgGrid(id="ag-grid"),
], ],
gap="md", gap="md",
justify="center", justify="center",
@@ -83,18 +79,20 @@ class APP(Dash):
) )
) )
@callback(
@callback(
Output("output-image-upload", "children"), Output("output-image-upload", "children"),
Output("ag-grid", "rowData"), Output("ag-grid", "rowData"),
Output("ag-grid", "columnDefs"),
Input("upload-image", "contents"), Input("upload-image", "contents"),
State("upload-image", "filename"), State("upload-image", "filename"),
State("upload-image", "last_modified"), State("upload-image", "last_modified"),
) )
def update_output( def update_output(
list_of_contents: List[str], list_of_contents: Optional[List[str]],
list_of_names: List[str], list_of_names: Optional[List[str]],
list_of_dates: List[int] | List[float], list_of_dates: Optional[List[int] | List[float]],
): ):
def parse_base64_to_pil(contents: str) -> Image.Image: def parse_base64_to_pil(contents: str) -> Image.Image:
"""Parse base64 string to PIL Image.""" """Parse base64 string to PIL Image."""
# Remove data URI prefix (e.g., "data:image/png;base64,") # Remove data URI prefix (e.g., "data:image/png;base64,")
@@ -102,7 +100,11 @@ class APP(Dash):
img_bytes = base64.b64decode(base64_str) img_bytes = base64.b64decode(base64_str)
return Image.open(io.BytesIO(img_bytes)) return Image.open(io.BytesIO(img_bytes))
if list_of_contents is not None: if (
list_of_contents is not None
and list_of_names is not None
and list_of_dates is not None
):
# Process first uploaded image for similarity search # Process first uploaded image for similarity search
filename = list_of_names[0] filename = list_of_names[0]
uploaddate = list_of_dates[0] uploaddate = list_of_dates[0]
@@ -111,14 +113,13 @@ class APP(Dash):
pil_image = parse_base64_to_pil(imagecontent) pil_image = parse_base64_to_pil(imagecontent)
# Extract feature vector using DINOv2 # Extract feature vector using DINOv2
feature_vector = APP._feature_retrieval.extract_single_image_feature( feature_vector = APP._feature_retrieval.extract_single_image_feature(pil_image)
pil_image
)
# Search for similar images in database # Search for similar images in database
# Exclude 'vector' and 'binary' columns as they are not JSON serializable
results_df = ( results_df = (
db_manager.table.search(feature_vector) db_manager.table.search(feature_vector)
.select(["id", "label", "vector"]) .select(["id", "label"])
.limit(10) .limit(10)
.to_polars() .to_polars()
) )
@@ -126,6 +127,11 @@ class APP(Dash):
# Convert to AgGrid row format # Convert to AgGrid row format
row_data = results_df.to_dicts() row_data = results_df.to_dicts()
columnDefs = [
{"headerName": column.capitalize(), "field": column}
for column in results_df.columns
]
# Display uploaded images # Display uploaded images
children = [ children = [
html.H5(filename), html.H5(filename),
@@ -136,10 +142,34 @@ class APP(Dash):
dmc.Text(f"{feature_vector[:5]}", size="xs"), dmc.Text(f"{feature_vector[:5]}", size="xs"),
] ]
return children, row_data return children, row_data, columnDefs
else:
# When contents is None
# Exclude 'vector' and 'binary' columns as they are not JSON serializable
df = db_manager.table.search().select(["id", "label"]).limit(1000).to_polars()
# Return empty if no content row_data = df.to_dicts()
return [], []
columnDefs = [
{"headerName": column.capitalize(), "field": column}
for column in df.columns
]
return [], row_data, columnDefs
@callback(
Input("ag-grid", "cellClicked"),
State("ag-grid", "row_data"),
Output("output-image-select", "children"),
)
def update_images_comparison(
clicked_event: Optional[CellClickedEvent], row_data: Optional[dict]
):
if clicked_event is None or CellClickedEvent.rowIndex is None or row_data is None:
return []
return []
app = APP() app = APP()

View File

@@ -0,0 +1,33 @@
from dataclasses import dataclass
from typing import Optional, Union
@dataclass
class CellClickedEvent:
"""
- value (boolean I number | string I dict | list; optional): value of the clicked cell.
- colId (boolean I number I string I dict | list; optional): column where the cell was clicked.
- rowIndex (number; optional): rowIndex, typically a row number.
- rowId (boolean I number I string I dict | list; optional): Row Id from the grid, this could be a number automatically, orset via getRowId.
- timestamp (boolean I number I string I dict I list; optional): timestamp of last action.
"""
value: Optional[Union[bool, int, float, str, dict, list]]
"""
- value (boolean I number | string I dict | list; optional): value of the clicked cell.
"""
colId: Optional[Union[bool, int, float, str, dict, list]]
"""
- colId (boolean I number I string I dict | list; optional): column where the cell was clicked.
"""
rowIndex: Optional[Union[int, float]]
"""
- rowIndex (number; optional): rowIndex, typically a row number.
"""
timestamp: Optional[Union[bool, int, float, str, dict, list]]
"""
- timestamp (boolean I number I string I dict I list; optional): timestamp of last action.
"""