feat(feature-retrieval): add label mapping and visualization upload

This commit is contained in:
2026-02-05 13:27:29 +08:00
parent 701fa9f289
commit 90087194ec
2 changed files with 108 additions and 9 deletions

View File

@@ -1,4 +1,4 @@
from typing import cast from typing import Any, Dict, List, Optional, cast
import torch import torch
from database import db_manager from database import db_manager
@@ -8,7 +8,14 @@ from transformers import AutoImageProcessor, AutoModel
@torch.no_grad() @torch.no_grad()
def establish_database(processor, model, images, labels, batch_size=64): def establish_database(
processor,
model,
images: List[Any],
labels: List[int] | List[str],
batch_size=64,
label_map: Optional[Dict[int, str] | List[str]] = None,
):
device = model.device device = model.device
model.eval() model.eval()
@@ -29,7 +36,13 @@ def establish_database(processor, model, images, labels, batch_size=64):
# 迁移输出到CPU # 迁移输出到CPU
cls_tokens = cls_tokens.cpu() cls_tokens = cls_tokens.cpu()
batch_labels = labels[i : i + batch_size] batch_labels = (
labels[i : i + batch_size]
if label_map is None
else list(
map(lambda x: label_map[cast(int, x)], labels[i : i + batch_size])
)
)
actual_batch_size = len(batch_labels) actual_batch_size = len(batch_labels)
# 存库 # 存库
@@ -44,10 +57,28 @@ def establish_database(processor, model, images, labels, batch_size=64):
if __name__ == "__main__": if __name__ == "__main__":
train_dataset = load_dataset("uoft-cs/cifar10", split="train") train_dataset = load_dataset("uoft-cs/cifar10", split="train")
train_dataset = cast(Dataset, train_dataset) train_dataset = cast(Dataset, train_dataset)
label_map = [
"airplane",
"automobile",
"bird",
"cat",
"deer",
"dog",
"frog",
"horse",
"ship",
"truck",
]
processor = AutoImageProcessor.from_pretrained( processor = AutoImageProcessor.from_pretrained(
"facebook/dinov2-large", device_map="cuda" "facebook/dinov2-large", device_map="cuda"
) )
model = AutoModel.from_pretrained("facebook/dinov2-large", device_map="cuda") model = AutoModel.from_pretrained("facebook/dinov2-large", device_map="cuda")
establish_database(processor, model, train_dataset["img"], train_dataset["label"]) establish_database(
processor,
model,
train_dataset["img"],
train_dataset["label"],
label_map=label_map,
)

View File

@@ -1,11 +1,30 @@
import datetime
from typing import Optional from typing import Optional
import dash_ag_grid as dag import dash_ag_grid as dag
import dash_mantine_components as dmc import dash_mantine_components as dmc
from dash import Dash from dash import Dash, Input, Output, State, callback, dcc, html
from database import db_manager from database import db_manager
def parse_contents(contents, filename, date):
    """Build the preview panel shown for a single uploaded file.

    Args:
        contents: Upload payload string. It is used directly as the
            ``src`` of the preview image, and its first 200 characters
            are echoed as raw text.
        filename: Name of the uploaded file, rendered as the heading.
        date: Timestamp of the file, in the form accepted by
            ``datetime.datetime.fromtimestamp``.

    Returns:
        An ``html.Div`` holding the heading, timestamp, image preview,
        a separator and the truncated raw payload.
    """
    modified_at = datetime.datetime.fromtimestamp(date)
    # Only the head of the payload is echoed, followed by an ellipsis marker.
    raw_preview = contents[0:200] + "..."
    raw_style = {"whiteSpace": "pre-wrap", "wordBreak": "break-all"}

    panel_children = [
        html.H5(filename),
        html.H6(modified_at),
        # HTML images accept base64 encoded strings in the same format
        # that is supplied by the upload, so the payload is passed as-is.
        html.Img(src=contents),
        html.Hr(),
        html.Div("Raw Content"),
        html.Pre(raw_preview, style=raw_style),
    ]
    return html.Div(panel_children)
class APP(Dash): class APP(Dash):
"""Singleton Dash Application""" """Singleton Dash Application"""
@@ -19,7 +38,12 @@ class APP(Dash):
def __init__(self): def __init__(self):
super().__init__(__name__) super().__init__(__name__)
df = db_manager.table.search().select(["id", "label", "vector"]).to_polars() df = (
db_manager.table.search()
.select(["id", "label", "vector"])
.limit(1000)
.to_polars()
)
columnDefs = [ columnDefs = [
{"headerName": column.capitalize(), "field": column} {"headerName": column.capitalize(), "field": column}
@@ -27,11 +51,55 @@ class APP(Dash):
] ]
self.layout = dmc.MantineProvider( self.layout = dmc.MantineProvider(
dag.AgGrid( dmc.Container(
rowData=df.to_dicts(), dmc.Flex(
columnDefs=columnDefs, [
dcc.Upload(
id="upload-image",
children=html.Div(
["Drag and Drop or ", html.A("Select Files")]
),
style={
"width": "100%",
"height": "60px",
"lineHeight": "60px",
"borderWidth": "1px",
"borderStyle": "dashed",
"borderRadius": "5px",
"textAlign": "center",
"margin": "10px",
},
# Allow multiple files to be uploaded
multiple=True,
),
html.Div(id="output-image-upload"),
dag.AgGrid(
rowData=df.to_dicts(),
columnDefs=columnDefs,
),
],
gap="md",
justify="center",
align="center",
direction="column",
wrap="wrap",
),
) )
) )
@callback(
    Output("output-image-upload", "children"),
    Input("upload-image", "contents"),
    State("upload-image", "filename"),
    State("upload-image", "last_modified"),
)
def update_output(list_of_contents, list_of_names, list_of_dates):
    """Render one preview panel per uploaded file.

    Triggered when the ``contents`` of the ``upload-image`` component
    change; the matching file names and ``last_modified`` values are read
    as state.

    Args:
        list_of_contents: Uploaded payloads, or ``None`` when nothing has
            been uploaded.
        list_of_names: File names, parallel to ``list_of_contents``.
        list_of_dates: ``last_modified`` values, parallel to
            ``list_of_contents``.

    Returns:
        A list of preview components (one per file) for the
        ``output-image-upload`` container, or ``None`` when there is no
        upload to show.
    """
    if list_of_contents is None:
        return None

    # map() walks the three lists in lockstep and stops at the shortest,
    # which is the same pairing zip() would produce.
    return list(
        map(parse_contents, list_of_contents, list_of_names, list_of_dates)
    )
app = APP() app = APP()