mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
146 lines
4.7 KiB
Python
146 lines
4.7 KiB
Python
import base64
|
|
import datetime
|
|
import io
|
|
from typing import List, Optional
|
|
|
|
import dash_ag_grid as dag
|
|
import dash_mantine_components as dmc
|
|
from dash import Dash, Input, Output, State, callback, dcc, html
|
|
from database import db_manager
|
|
from feature_retrieval import FeatureRetrieval
|
|
from PIL import Image
|
|
from transformers import AutoImageProcessor, AutoModel
|
|
|
|
|
|
class APP(Dash):
    """Singleton Dash application for browsing and querying image features.

    ``APP()`` always returns the same instance. The expensive setup in
    ``__init__`` (Dash initialisation, loading the pretrained processor and
    model, querying the database and building the layout) runs only for the
    first successful construction; later ``APP()`` calls are no-ops.
    """

    _instance: Optional["APP"] = None

    # Feature retrieval singleton
    _feature_retrieval: FeatureRetrieval

    # True once __init__ has completed; guards against re-initialisation,
    # because Python calls __init__ on every APP() even though __new__
    # hands back the cached instance.
    _initialized: bool = False

    # Identifier passed to `from_pretrained` for both processor and model.
    _MODEL_NAME = "facebook/dinov2-large"

    def __new__(cls) -> "APP":
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """Set up the Dash app, the feature extractor and the page layout.

        Does nothing when the singleton has already been initialised.
        """
        if APP._initialized:
            return

        super().__init__(__name__)

        # Initialize FeatureRetrieval
        processor = AutoImageProcessor.from_pretrained(self._MODEL_NAME)
        model = AutoModel.from_pretrained(self._MODEL_NAME)
        APP._feature_retrieval = FeatureRetrieval(processor, model)

        # Initial grid content: up to 1000 rows of id / label / vector.
        # NOTE(review): presumably `to_polars()` yields a polars DataFrame;
        # only `.columns` and `.to_dicts()` are relied on here.
        df = (
            db_manager.table.search()
            .select(["id", "label", "vector"])
            .limit(1000)
            .to_polars()
        )

        # One AgGrid column per dataframe column, header capitalised.
        column_defs = [
            {"headerName": column.capitalize(), "field": column}
            for column in df.columns
        ]

        self.layout = dmc.MantineProvider(
            dmc.Container(
                dmc.Flex(
                    [
                        dcc.Upload(
                            id="upload-image",
                            children=html.Div(
                                ["Drag and Drop or ", html.A("Select Files")]
                            ),
                            style={
                                "width": "100%",
                                "height": "60px",
                                "lineHeight": "60px",
                                "borderWidth": "1px",
                                "borderStyle": "dashed",
                                "borderRadius": "5px",
                                "textAlign": "center",
                                "margin": "10px",
                            },
                            # Allow multiple files to be uploaded
                            multiple=True,
                        ),
                        html.Div(id="output-image-upload"),
                        dag.AgGrid(
                            id="ag-grid",
                            rowData=df.to_dicts(),
                            columnDefs=column_defs,
                        ),
                    ],
                    gap="md",
                    justify="center",
                    align="center",
                    direction="column",
                    wrap="wrap",
                ),
            )
        )

        # Mark as initialised only after everything above succeeded, so a
        # failed first attempt can be retried by calling APP() again.
        APP._initialized = True
|
|
|
|
@callback(
    Output("output-image-upload", "children"),
    Output("ag-grid", "rowData"),
    Input("upload-image", "contents"),
    State("upload-image", "filename"),
    State("upload-image", "last_modified"),
)
def update_output(
    list_of_contents: Optional[List[str]],
    list_of_names: Optional[List[str]],
    list_of_dates: Optional[List[int] | List[float]],
):
    """Run a similarity search for the first uploaded image.

    Args:
        list_of_contents: Uploaded files as data-URI strings
            (e.g. ``"data:image/png;base64,..."``), or ``None``/empty when
            nothing has been uploaded.
        list_of_names: File names, parallel to ``list_of_contents``.
        list_of_dates: Last-modified timestamps, parallel to
            ``list_of_contents``; passed to ``datetime.fromtimestamp``.

    Returns:
        A ``(children, row_data)`` pair: the preview components for the
        upload area and the rows for the grid. When there is no upload the
        preview is emptied and the grid is left unchanged.
    """
    # Imported locally so this block does not depend on a change to the
    # module-level import list.
    from dash import no_update

    def parse_base64_to_pil(contents: str) -> Image.Image:
        """Parse a base64 data-URI string to a PIL Image."""
        # Drop the data URI prefix (e.g. "data:image/png;base64,") and keep
        # everything after the first comma.
        _, _, base64_str = contents.partition(",")
        img_bytes = base64.b64decode(base64_str)
        return Image.open(io.BytesIO(img_bytes))

    # Covers both None (callback fired without an upload, e.g. on page load)
    # and an empty list. Returning no_update for the grid keeps the rows that
    # were placed in the layout instead of wiping them.
    if not list_of_contents:
        return [], no_update

    # Process first uploaded image for similarity search
    filename = list_of_names[0]
    upload_date = list_of_dates[0]
    image_content = list_of_contents[0]

    # NOTE(review): the image is passed on in whatever mode PIL decoded it
    # (e.g. RGBA, L) - confirm FeatureRetrieval converts to RGB if needed.
    pil_image = parse_base64_to_pil(image_content)

    # Extract feature vector using DINOv2
    feature_vector = APP._feature_retrieval.extract_single_image_feature(pil_image)

    # Search for similar images in database (ten results)
    results_df = (
        db_manager.table.search(feature_vector)
        .select(["id", "label", "vector"])
        .limit(10)
        .to_polars()
    )

    # Convert to AgGrid row format
    row_data = results_df.to_dicts()

    # Display the uploaded image with its name, timestamp and the first five
    # entries of its feature vector.
    children = [
        html.H5(filename),
        html.H6(str(datetime.datetime.fromtimestamp(upload_date))),
        # HTML images accept base64 encoded strings in same format
        # that is supplied by the upload
        dmc.Image(src=image_content),
        dmc.Text(f"{feature_vector[:5]}", size="xs"),
    ]

    return children, row_data
|
|
|
|
|
|
# Module-level application instance, created when this module is imported.
app = APP()
|