mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
feat(visualizer): add image selection and binary data storage
This commit is contained in:
@@ -9,6 +9,7 @@ db_schema = pa.schema(
|
||||
pa.field("id", pa.int32()),
|
||||
pa.field("label", pa.string()),
|
||||
pa.field("vector", pa.list_(pa.float32(), 1024)),
|
||||
pa.field("binary", pa.binary()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import torch
|
||||
from database import db_manager
|
||||
from datasets import load_dataset
|
||||
from PIL import Image
|
||||
from PIL.PngImagePlugin import PngImageFile
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import AutoImageProcessor, AutoModel
|
||||
|
||||
@@ -48,7 +49,7 @@ class FeatureRetrieval:
|
||||
@torch.no_grad()
|
||||
def establish_database(
|
||||
self,
|
||||
images: List[Any],
|
||||
images: List[PngImageFile],
|
||||
labels: List[int] | List[str],
|
||||
batch_size: int = 64,
|
||||
label_map: Optional[Dict[int, str] | List[str]] = None,
|
||||
@@ -97,6 +98,7 @@ class FeatureRetrieval:
|
||||
"id": i + j,
|
||||
"label": batch_labels[j],
|
||||
"vector": cls_tokens[j].numpy(),
|
||||
"binary": batch_imgs[j].tobytes(),
|
||||
}
|
||||
for j in range(actual_batch_size)
|
||||
]
|
||||
|
||||
0
mini-nav/utils/image.py
Normal file
0
mini-nav/utils/image.py
Normal file
@@ -11,6 +11,8 @@ from feature_retrieval import FeatureRetrieval
|
||||
from PIL import Image
|
||||
from transformers import AutoImageProcessor, AutoModel
|
||||
|
||||
from .events import CellClickedEvent
|
||||
|
||||
|
||||
class APP(Dash):
|
||||
"""Singleton Dash Application"""
|
||||
@@ -33,18 +35,6 @@ class APP(Dash):
|
||||
model = AutoModel.from_pretrained("facebook/dinov2-large")
|
||||
APP._feature_retrieval = FeatureRetrieval(processor, model)
|
||||
|
||||
df = (
|
||||
db_manager.table.search()
|
||||
.select(["id", "label", "vector"])
|
||||
.limit(1000)
|
||||
.to_polars()
|
||||
)
|
||||
|
||||
columnDefs = [
|
||||
{"headerName": column.capitalize(), "field": column}
|
||||
for column in df.columns
|
||||
]
|
||||
|
||||
self.layout = dmc.MantineProvider(
|
||||
dmc.Container(
|
||||
dmc.Flex(
|
||||
@@ -64,15 +54,21 @@ class APP(Dash):
|
||||
"textAlign": "center",
|
||||
"margin": "10px",
|
||||
},
|
||||
# Allow multiple files to be uploaded
|
||||
multiple=True,
|
||||
# Disallow multiple files to be uploaded
|
||||
multiple=False,
|
||||
),
|
||||
html.Div(id="output-image-upload"),
|
||||
dag.AgGrid(
|
||||
id="ag-grid",
|
||||
rowData=df.to_dicts(),
|
||||
columnDefs=columnDefs,
|
||||
dmc.Flex(
|
||||
[
|
||||
html.Div(id="output-image-upload"),
|
||||
html.Div(id="output-image-select"),
|
||||
],
|
||||
gap="md",
|
||||
justify="center",
|
||||
align="center",
|
||||
direction="row",
|
||||
wrap="wrap",
|
||||
),
|
||||
dag.AgGrid(id="ag-grid"),
|
||||
],
|
||||
gap="md",
|
||||
justify="center",
|
||||
@@ -83,63 +79,97 @@ class APP(Dash):
|
||||
)
|
||||
)
|
||||
|
||||
@callback(
|
||||
Output("output-image-upload", "children"),
|
||||
Output("ag-grid", "rowData"),
|
||||
Input("upload-image", "contents"),
|
||||
State("upload-image", "filename"),
|
||||
State("upload-image", "last_modified"),
|
||||
)
|
||||
def update_output(
|
||||
list_of_contents: List[str],
|
||||
list_of_names: List[str],
|
||||
list_of_dates: List[int] | List[float],
|
||||
|
||||
@callback(
|
||||
Output("output-image-upload", "children"),
|
||||
Output("ag-grid", "rowData"),
|
||||
Output("ag-grid", "columnDefs"),
|
||||
Input("upload-image", "contents"),
|
||||
State("upload-image", "filename"),
|
||||
State("upload-image", "last_modified"),
|
||||
)
|
||||
def update_output(
|
||||
list_of_contents: Optional[List[str]],
|
||||
list_of_names: Optional[List[str]],
|
||||
list_of_dates: Optional[List[int] | List[float]],
|
||||
):
|
||||
def parse_base64_to_pil(contents: str) -> Image.Image:
|
||||
"""Parse base64 string to PIL Image."""
|
||||
# Remove data URI prefix (e.g., "data:image/png;base64,")
|
||||
base64_str = contents.split(",")[1]
|
||||
img_bytes = base64.b64decode(base64_str)
|
||||
return Image.open(io.BytesIO(img_bytes))
|
||||
|
||||
if (
|
||||
list_of_contents is not None
|
||||
and list_of_names is not None
|
||||
and list_of_dates is not None
|
||||
):
|
||||
def parse_base64_to_pil(contents: str) -> Image.Image:
|
||||
"""Parse base64 string to PIL Image."""
|
||||
# Remove data URI prefix (e.g., "data:image/png;base64,")
|
||||
base64_str = contents.split(",")[1]
|
||||
img_bytes = base64.b64decode(base64_str)
|
||||
return Image.open(io.BytesIO(img_bytes))
|
||||
# Process first uploaded image for similarity search
|
||||
filename = list_of_names[0]
|
||||
uploaddate = list_of_dates[0]
|
||||
imagecontent = list_of_contents[0]
|
||||
|
||||
if list_of_contents is not None:
|
||||
# Process first uploaded image for similarity search
|
||||
filename = list_of_names[0]
|
||||
uploaddate = list_of_dates[0]
|
||||
imagecontent = list_of_contents[0]
|
||||
pil_image = parse_base64_to_pil(imagecontent)
|
||||
|
||||
pil_image = parse_base64_to_pil(imagecontent)
|
||||
# Extract feature vector using DINOv2
|
||||
feature_vector = APP._feature_retrieval.extract_single_image_feature(pil_image)
|
||||
|
||||
# Extract feature vector using DINOv2
|
||||
feature_vector = APP._feature_retrieval.extract_single_image_feature(
|
||||
pil_image
|
||||
)
|
||||
# Search for similar images in database
|
||||
# Exclude 'vector' and 'binary' columns as they are not JSON serializable
|
||||
results_df = (
|
||||
db_manager.table.search(feature_vector)
|
||||
.select(["id", "label"])
|
||||
.limit(10)
|
||||
.to_polars()
|
||||
)
|
||||
|
||||
# Search for similar images in database
|
||||
results_df = (
|
||||
db_manager.table.search(feature_vector)
|
||||
.select(["id", "label", "vector"])
|
||||
.limit(10)
|
||||
.to_polars()
|
||||
)
|
||||
# Convert to AgGrid row format
|
||||
row_data = results_df.to_dicts()
|
||||
|
||||
# Convert to AgGrid row format
|
||||
row_data = results_df.to_dicts()
|
||||
columnDefs = [
|
||||
{"headerName": column.capitalize(), "field": column}
|
||||
for column in results_df.columns
|
||||
]
|
||||
|
||||
# Display uploaded images
|
||||
children = [
|
||||
html.H5(filename),
|
||||
html.H6(str(datetime.datetime.fromtimestamp(uploaddate))),
|
||||
# HTML images accept base64 encoded strings in same format
|
||||
# that is supplied by the upload
|
||||
dmc.Image(src=imagecontent),
|
||||
dmc.Text(f"{feature_vector[:5]}", size="xs"),
|
||||
]
|
||||
# Display uploaded images
|
||||
children = [
|
||||
html.H5(filename),
|
||||
html.H6(str(datetime.datetime.fromtimestamp(uploaddate))),
|
||||
# HTML images accept base64 encoded strings in same format
|
||||
# that is supplied by the upload
|
||||
dmc.Image(src=imagecontent),
|
||||
dmc.Text(f"{feature_vector[:5]}", size="xs"),
|
||||
]
|
||||
|
||||
return children, row_data
|
||||
return children, row_data, columnDefs
|
||||
else:
|
||||
# When contents is None
|
||||
# Exclude 'vector' and 'binary' columns as they are not JSON serializable
|
||||
df = db_manager.table.search().select(["id", "label"]).limit(1000).to_polars()
|
||||
|
||||
# Return empty if no content
|
||||
return [], []
|
||||
row_data = df.to_dicts()
|
||||
|
||||
columnDefs = [
|
||||
{"headerName": column.capitalize(), "field": column}
|
||||
for column in df.columns
|
||||
]
|
||||
|
||||
return [], row_data, columnDefs
|
||||
|
||||
|
||||
@callback(
|
||||
Input("ag-grid", "cellClicked"),
|
||||
State("ag-grid", "row_data"),
|
||||
Output("output-image-select", "children"),
|
||||
)
|
||||
def update_images_comparison(
|
||||
clicked_event: Optional[CellClickedEvent], row_data: Optional[dict]
|
||||
):
|
||||
if clicked_event is None or CellClickedEvent.rowIndex is None or row_data is None:
|
||||
return []
|
||||
|
||||
return []
|
||||
|
||||
|
||||
app = APP()
|
||||
|
||||
33
mini-nav/visualizer/events.py
Normal file
33
mini-nav/visualizer/events.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
|
||||
|
||||
@dataclass
|
||||
class CellClickedEvent:
|
||||
"""
|
||||
- value (boolean I number | string I dict | list; optional): value of the clicked cell.
|
||||
- colId (boolean I number I string I dict | list; optional): column where the cell was clicked.
|
||||
- rowIndex (number; optional): rowIndex, typically a row number.
|
||||
- rowId (boolean I number I string I dict | list; optional): Row Id from the grid, this could be a number automatically, orset via getRowId.
|
||||
- timestamp (boolean I number I string I dict I list; optional): timestamp of last action.
|
||||
"""
|
||||
|
||||
value: Optional[Union[bool, int, float, str, dict, list]]
|
||||
"""
|
||||
- value (boolean I number | string I dict | list; optional): value of the clicked cell.
|
||||
"""
|
||||
|
||||
colId: Optional[Union[bool, int, float, str, dict, list]]
|
||||
"""
|
||||
- colId (boolean I number I string I dict | list; optional): column where the cell was clicked.
|
||||
"""
|
||||
|
||||
rowIndex: Optional[Union[int, float]]
|
||||
"""
|
||||
- rowIndex (number; optional): rowIndex, typically a row number.
|
||||
"""
|
||||
|
||||
timestamp: Optional[Union[bool, int, float, str, dict, list]]
|
||||
"""
|
||||
- timestamp (boolean I number I string I dict I list; optional): timestamp of last action.
|
||||
"""
|
||||
Reference in New Issue
Block a user