mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
feat(visualizer): add image selection and binary data storage
This commit is contained in:
@@ -9,6 +9,7 @@ db_schema = pa.schema(
|
|||||||
pa.field("id", pa.int32()),
|
pa.field("id", pa.int32()),
|
||||||
pa.field("label", pa.string()),
|
pa.field("label", pa.string()),
|
||||||
pa.field("vector", pa.list_(pa.float32(), 1024)),
|
pa.field("vector", pa.list_(pa.float32(), 1024)),
|
||||||
|
pa.field("binary", pa.binary()),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import torch
|
|||||||
from database import db_manager
|
from database import db_manager
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from PIL.PngImagePlugin import PngImageFile
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
from transformers import AutoImageProcessor, AutoModel
|
from transformers import AutoImageProcessor, AutoModel
|
||||||
|
|
||||||
@@ -48,7 +49,7 @@ class FeatureRetrieval:
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def establish_database(
|
def establish_database(
|
||||||
self,
|
self,
|
||||||
images: List[Any],
|
images: List[PngImageFile],
|
||||||
labels: List[int] | List[str],
|
labels: List[int] | List[str],
|
||||||
batch_size: int = 64,
|
batch_size: int = 64,
|
||||||
label_map: Optional[Dict[int, str] | List[str]] = None,
|
label_map: Optional[Dict[int, str] | List[str]] = None,
|
||||||
@@ -97,6 +98,7 @@ class FeatureRetrieval:
|
|||||||
"id": i + j,
|
"id": i + j,
|
||||||
"label": batch_labels[j],
|
"label": batch_labels[j],
|
||||||
"vector": cls_tokens[j].numpy(),
|
"vector": cls_tokens[j].numpy(),
|
||||||
|
"binary": batch_imgs[j].tobytes(),
|
||||||
}
|
}
|
||||||
for j in range(actual_batch_size)
|
for j in range(actual_batch_size)
|
||||||
]
|
]
|
||||||
|
|||||||
0
mini-nav/utils/image.py
Normal file
0
mini-nav/utils/image.py
Normal file
@@ -11,6 +11,8 @@ from feature_retrieval import FeatureRetrieval
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from transformers import AutoImageProcessor, AutoModel
|
from transformers import AutoImageProcessor, AutoModel
|
||||||
|
|
||||||
|
from .events import CellClickedEvent
|
||||||
|
|
||||||
|
|
||||||
class APP(Dash):
|
class APP(Dash):
|
||||||
"""Singleton Dash Application"""
|
"""Singleton Dash Application"""
|
||||||
@@ -33,18 +35,6 @@ class APP(Dash):
|
|||||||
model = AutoModel.from_pretrained("facebook/dinov2-large")
|
model = AutoModel.from_pretrained("facebook/dinov2-large")
|
||||||
APP._feature_retrieval = FeatureRetrieval(processor, model)
|
APP._feature_retrieval = FeatureRetrieval(processor, model)
|
||||||
|
|
||||||
df = (
|
|
||||||
db_manager.table.search()
|
|
||||||
.select(["id", "label", "vector"])
|
|
||||||
.limit(1000)
|
|
||||||
.to_polars()
|
|
||||||
)
|
|
||||||
|
|
||||||
columnDefs = [
|
|
||||||
{"headerName": column.capitalize(), "field": column}
|
|
||||||
for column in df.columns
|
|
||||||
]
|
|
||||||
|
|
||||||
self.layout = dmc.MantineProvider(
|
self.layout = dmc.MantineProvider(
|
||||||
dmc.Container(
|
dmc.Container(
|
||||||
dmc.Flex(
|
dmc.Flex(
|
||||||
@@ -64,15 +54,21 @@ class APP(Dash):
|
|||||||
"textAlign": "center",
|
"textAlign": "center",
|
||||||
"margin": "10px",
|
"margin": "10px",
|
||||||
},
|
},
|
||||||
# Allow multiple files to be uploaded
|
# Disallow multiple files to be uploaded
|
||||||
multiple=True,
|
multiple=False,
|
||||||
),
|
),
|
||||||
html.Div(id="output-image-upload"),
|
dmc.Flex(
|
||||||
dag.AgGrid(
|
[
|
||||||
id="ag-grid",
|
html.Div(id="output-image-upload"),
|
||||||
rowData=df.to_dicts(),
|
html.Div(id="output-image-select"),
|
||||||
columnDefs=columnDefs,
|
],
|
||||||
|
gap="md",
|
||||||
|
justify="center",
|
||||||
|
align="center",
|
||||||
|
direction="row",
|
||||||
|
wrap="wrap",
|
||||||
),
|
),
|
||||||
|
dag.AgGrid(id="ag-grid"),
|
||||||
],
|
],
|
||||||
gap="md",
|
gap="md",
|
||||||
justify="center",
|
justify="center",
|
||||||
@@ -83,63 +79,97 @@ class APP(Dash):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@callback(
    Output("output-image-upload", "children"),
    Output("ag-grid", "rowData"),
    Output("ag-grid", "columnDefs"),
    Input("upload-image", "contents"),
    State("upload-image", "filename"),
    State("upload-image", "last_modified"),
)
def update_output(
    list_of_contents: Optional[List[str] | str],
    list_of_names: Optional[List[str] | str],
    list_of_dates: Optional[List[int] | List[float] | int | float],
):
    """Refresh the upload preview and the result grid after an image upload.

    When an image has been uploaded, its feature vector is extracted and the
    10 nearest database entries are returned for the grid. When nothing has
    been uploaded yet, the first 1000 database rows are listed instead.

    Args:
        list_of_contents: Data-URI encoded upload contents. Either a list
            (multi-file upload) or a single string (single-file upload).
        list_of_names: Uploaded file name(s), same shape as the contents.
        list_of_dates: Last-modified timestamp(s) in seconds, same shape.

    Returns:
        A ``(children, row_data, column_defs)`` tuple matching the three
        declared outputs.
    """

    def first_item(value):
        """Return the first element of a list/tuple, or the value itself."""
        # NOTE(review): the upload component is configured with
        # multiple=False, in which case the upload properties are expected to
        # arrive as scalars rather than lists. Indexing a scalar with [0]
        # would return the first character of the string (or raise for the
        # numeric timestamp), so both shapes are normalised here.
        if isinstance(value, (list, tuple)):
            return value[0] if value else None
        return value

    def parse_base64_to_pil(contents: str) -> Image.Image:
        """Parse base64 string to PIL Image."""
        # Remove data URI prefix (e.g., "data:image/png;base64,")
        base64_str = contents.split(",")[1]
        img_bytes = base64.b64decode(base64_str)
        return Image.open(io.BytesIO(img_bytes))

    def build_column_defs(columns):
        """Build AgGrid column definitions from dataframe column names."""
        return [
            {"headerName": column.capitalize(), "field": column}
            for column in columns
        ]

    # Process first uploaded image for similarity search
    filename = first_item(list_of_names)
    uploaddate = first_item(list_of_dates)
    imagecontent = first_item(list_of_contents)

    if imagecontent is None or filename is None or uploaddate is None:
        # Nothing uploaded yet: list the stored entries.
        # Exclude 'vector' and 'binary' columns as they are not JSON serializable
        df = db_manager.table.search().select(["id", "label"]).limit(1000).to_polars()
        return [], df.to_dicts(), build_column_defs(df.columns)

    pil_image = parse_base64_to_pil(imagecontent)

    # Extract feature vector using DINOv2
    feature_vector = APP._feature_retrieval.extract_single_image_feature(pil_image)

    # Search for similar images in database
    # Exclude 'vector' and 'binary' columns as they are not JSON serializable
    results_df = (
        db_manager.table.search(feature_vector)
        .select(["id", "label"])
        .limit(10)
        .to_polars()
    )

    # Convert to AgGrid row format
    row_data = results_df.to_dicts()
    column_defs = build_column_defs(results_df.columns)

    # Display uploaded images
    children = [
        html.H5(filename),
        html.H6(str(datetime.datetime.fromtimestamp(uploaddate))),
        # HTML images accept base64 encoded strings in same format
        # that is supplied by the upload
        dmc.Image(src=imagecontent),
        dmc.Text(f"{feature_vector[:5]}", size="xs"),
    ]

    return children, row_data, column_defs
|
||||||
|
|
||||||
|
|
||||||
|
@callback(
    # Dash requires Outputs first, then Inputs, then States.
    Output("output-image-select", "children"),
    Input("ag-grid", "cellClicked"),
    # The grid property is "rowData" (as used by the upload callback's
    # Output), not "row_data".
    State("ag-grid", "rowData"),
)
def update_images_comparison(
    clicked_event: Optional[CellClickedEvent], row_data: Optional[List[dict]]
):
    """Update the selected-image panel when a grid cell is clicked.

    Args:
        clicked_event: The grid's ``cellClicked`` payload, or ``None`` when
            no cell has been clicked yet.
        row_data: The rows currently shown in the grid, or ``None``.

    Returns:
        Children for the ``output-image-select`` container. Rendering of the
        selected image is not implemented yet, so this is always an empty
        list.
    """
    if clicked_event is None or row_data is None:
        return []

    # Inspect the received event rather than the CellClickedEvent class
    # itself. NOTE(review): callback payloads presumably arrive as plain
    # dicts; attribute access is kept as a fallback for a typed event object.
    if isinstance(clicked_event, dict):
        row_index = clicked_event.get("rowIndex")
    else:
        row_index = getattr(clicked_event, "rowIndex", None)

    if row_index is None:
        return []

    # TODO: render the image stored for row_data[row_index].
    return []
|
||||||
|
|
||||||
|
|
||||||
app = APP()
|
app = APP()
|
||||||
|
|||||||
33
mini-nav/visualizer/events.py
Normal file
33
mini-nav/visualizer/events.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class CellClickedEvent:
    """Description of a clicked grid cell (the grid's ``cellClicked`` value).

    Attributes:
        value (boolean | number | string | dict | list; optional): value of
            the clicked cell.
        colId (boolean | number | string | dict | list; optional): column
            where the cell was clicked.
        rowIndex (number; optional): rowIndex, typically a row number.
        timestamp (boolean | number | string | dict | list; optional):
            timestamp of last action.
        rowId (boolean | number | string | dict | list; optional): Row Id
            from the grid, this could be a number automatically, or set via
            getRowId.
    """

    # Value of the clicked cell.
    value: Optional[Union[bool, int, float, str, dict, list]]

    # Column where the cell was clicked.
    colId: Optional[Union[bool, int, float, str, dict, list]]

    # Row index, typically a row number.
    rowIndex: Optional[Union[int, float]]

    # Timestamp of last action.
    timestamp: Optional[Union[bool, int, float, str, dict, list]]

    # Row id from the grid. Declared last and defaulted so that existing
    # four-argument construction keeps working.
    rowId: Optional[Union[bool, int, float, str, dict, list]] = None
|
||||||
Reference in New Issue
Block a user