Files
Mini-Nav/mini-nav/visualizer/app.py

176 lines
5.6 KiB
Python

import base64
import datetime
import io
from typing import List, Optional
import dash_ag_grid as dag
import dash_mantine_components as dmc
from dash import Dash, Input, Output, State, callback, dcc, html
from database import db_manager
from feature_retrieval import FeatureRetrieval
from PIL import Image
from transformers import AutoImageProcessor, AutoModel
from .events import CellClickedEvent
class APP(Dash):
    """Singleton Dash application.

    ``APP()`` always returns the same instance. The first construction
    loads the ``facebook/dinov2-large`` processor and model, wraps them in a
    ``FeatureRetrieval`` stored on the class, and installs the page layout
    (upload area, two image panes and an AgGrid table).
    """

    _instance: Optional["APP"] = None

    # Feature retrieval singleton
    _feature_retrieval: FeatureRetrieval

    # Flipped to True once __init__ has completed. Without this guard every
    # APP() call would re-run Dash.__init__ on the shared instance and reload
    # the model, because Python calls __init__ on whatever __new__ returns.
    _initialized: bool = False

    def __new__(cls) -> "APP":
        """Create the instance on first use and return the cached one after."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """Initialise Dash, the feature extractor and the layout exactly once."""
        if APP._initialized:
            return

        super().__init__(__name__)

        # Initialize FeatureRetrieval
        processor = AutoImageProcessor.from_pretrained("facebook/dinov2-large")
        model = AutoModel.from_pretrained("facebook/dinov2-large")
        APP._feature_retrieval = FeatureRetrieval(processor, model)

        upload_area = dcc.Upload(
            id="upload-image",
            children=html.Div(["Drag and Drop or ", html.A("Select Files")]),
            style={
                "width": "100%",
                "height": "60px",
                "lineHeight": "60px",
                "borderWidth": "1px",
                "borderStyle": "dashed",
                "borderRadius": "5px",
                "textAlign": "center",
                "margin": "10px",
            },
            # Disallow multiple files to be uploaded
            multiple=False,
        )

        # Side-by-side panes: uploaded image on the left, selected row on the right.
        image_panes = dmc.Flex(
            [
                html.Div(id="output-image-upload"),
                html.Div(id="output-image-select"),
            ],
            gap="md",
            justify="center",
            align="center",
            direction="row",
            wrap="wrap",
        )

        self.layout = dmc.MantineProvider(
            dmc.Container(
                dmc.Flex(
                    [
                        upload_area,
                        image_panes,
                        dag.AgGrid(id="ag-grid"),
                    ],
                    gap="md",
                    justify="center",
                    align="center",
                    direction="column",
                    wrap="wrap",
                ),
            )
        )

        # Only mark as initialised after everything above succeeded, so a
        # failed first attempt (e.g. model load error) can be retried.
        APP._initialized = True
@callback(
    Output("output-image-upload", "children"),
    Output("ag-grid", "rowData"),
    Output("ag-grid", "columnDefs"),
    Input("upload-image", "contents"),
    State("upload-image", "filename"),
    State("upload-image", "last_modified"),
)
def update_output(
    list_of_contents: Optional[List[str] | str],
    list_of_names: Optional[List[str] | str],
    list_of_dates: Optional[List[int] | List[float] | int | float],
):
    """Refresh the upload preview and the grid after an image upload.

    With an uploaded image: decode it, extract its feature vector, search the
    database table for the 10 nearest rows and show them in the grid together
    with a preview of the upload. Without an upload: list up to 1000 rows
    from the table and show no preview.

    The upload component is declared with ``multiple=False``, in which case
    Dash sends a single string / number per property rather than lists.
    Both shapes are accepted; for lists only the first item is used.

    Args:
        list_of_contents: Data-URI string of the upload (or a list of them).
        list_of_names: Filename of the upload (or a list of them).
        list_of_dates: Last-modified Unix timestamp (or a list of them).

    Returns:
        A ``(children, rowData, columnDefs)`` tuple matching the three
        declared outputs.
    """

    def first_or_scalar(value):
        """Return ``value[0]`` for a non-empty list/tuple, else ``value``."""
        if isinstance(value, (list, tuple)):
            return value[0] if value else None
        return value

    def parse_base64_to_pil(contents: str) -> Image.Image:
        """Parse base64 string to PIL Image."""
        # Remove data URI prefix (e.g., "data:image/png;base64,")
        base64_str = contents.split(",", 1)[1]
        img_bytes = base64.b64decode(base64_str)
        return Image.open(io.BytesIO(img_bytes))

    def grid_payload(frame):
        """Turn a polars frame into AgGrid ``rowData`` and ``columnDefs``."""
        column_defs = [
            {"headerName": column.capitalize(), "field": column}
            for column in frame.columns
        ]
        return frame.to_dicts(), column_defs

    imagecontent = first_or_scalar(list_of_contents)
    filename = first_or_scalar(list_of_names)
    uploaddate = first_or_scalar(list_of_dates)

    if imagecontent is None or filename is None or uploaddate is None:
        # Nothing uploaded (yet): show a plain listing of the table.
        # Exclude 'vector' and 'binary' columns as they are not JSON serializable
        df = db_manager.table.search().select(["id", "label"]).limit(1000).to_polars()
        row_data, column_defs = grid_payload(df)
        return [], row_data, column_defs

    pil_image = parse_base64_to_pil(imagecontent)

    # Extract feature vector using DINOv2
    feature_vector = APP._feature_retrieval.extract_single_image_feature(pil_image)

    # Search for similar images in database
    # Exclude 'vector' and 'binary' columns as they are not JSON serializable
    results_df = (
        db_manager.table.search(feature_vector)
        .select(["id", "label"])
        .limit(10)
        .to_polars()
    )
    row_data, column_defs = grid_payload(results_df)

    # Display uploaded image
    children = [
        html.H5(filename),
        html.H6(str(datetime.datetime.fromtimestamp(uploaddate))),
        # HTML images accept base64 encoded strings in same format
        # that is supplied by the upload
        dmc.Image(src=imagecontent),
        # First few vector components, shown as a quick sanity check.
        dmc.Text(f"{feature_vector[:5]}", size="xs"),
    ]
    return children, row_data, column_defs
# Dash requires Outputs to be listed before Inputs/States, and the AgGrid
# property holding the rows is ``rowData`` (as used by ``update_output``).
@callback(
    Output("output-image-select", "children"),
    Input("ag-grid", "cellClicked"),
    State("ag-grid", "rowData"),
)
def update_images_comparison(
    clicked_event: Optional[CellClickedEvent], row_data: Optional[List[dict]]
):
    """Update the selected-image pane when a grid cell is clicked.

    Currently a placeholder: every path returns an empty children list.

    Args:
        clicked_event: The grid's ``cellClicked`` payload, or ``None`` before
            any click. Dash delivers it as a plain mapping at runtime.
        row_data: The grid's current ``rowData``, or ``None`` if unset.

    Returns:
        Children for ``output-image-select`` (always ``[]`` for now).
    """
    if clicked_event is None or row_data is None:
        return []

    # Inspect the event that was actually received, not the CellClickedEvent
    # class itself.
    if clicked_event.get("rowIndex") is None:
        return []

    # TODO: render the image of the clicked row next to the uploaded one.
    return []
# Module-level application instance. Constructing APP loads the DINOv2
# processor/model and builds the layout (see APP.__init__).
app = APP()