mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
feat(feature-retrieval): add label mapping and visualization upload
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
from typing import cast
|
from typing import Any, Dict, List, Optional, cast
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from database import db_manager
|
from database import db_manager
|
||||||
@@ -8,7 +8,14 @@ from transformers import AutoImageProcessor, AutoModel
|
|||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def establish_database(processor, model, images, labels, batch_size=64):
|
def establish_database(
|
||||||
|
processor,
|
||||||
|
model,
|
||||||
|
images: List[Any],
|
||||||
|
labels: List[int] | List[str],
|
||||||
|
batch_size=64,
|
||||||
|
label_map: Optional[Dict[int, str] | List[str]] = None,
|
||||||
|
):
|
||||||
device = model.device
|
device = model.device
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
@@ -29,7 +36,13 @@ def establish_database(processor, model, images, labels, batch_size=64):
|
|||||||
|
|
||||||
# 迁移输出到CPU
|
# 迁移输出到CPU
|
||||||
cls_tokens = cls_tokens.cpu()
|
cls_tokens = cls_tokens.cpu()
|
||||||
batch_labels = labels[i : i + batch_size]
|
batch_labels = (
|
||||||
|
labels[i : i + batch_size]
|
||||||
|
if label_map is None
|
||||||
|
else list(
|
||||||
|
map(lambda x: label_map[cast(int, x)], labels[i : i + batch_size])
|
||||||
|
)
|
||||||
|
)
|
||||||
actual_batch_size = len(batch_labels)
|
actual_batch_size = len(batch_labels)
|
||||||
|
|
||||||
# 存库
|
# 存库
|
||||||
@@ -44,10 +57,28 @@ def establish_database(processor, model, images, labels, batch_size=64):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
train_dataset = load_dataset("uoft-cs/cifar10", split="train")
|
train_dataset = load_dataset("uoft-cs/cifar10", split="train")
|
||||||
train_dataset = cast(Dataset, train_dataset)
|
train_dataset = cast(Dataset, train_dataset)
|
||||||
|
label_map = [
|
||||||
|
"airplane",
|
||||||
|
"automobile",
|
||||||
|
"bird",
|
||||||
|
"cat",
|
||||||
|
"deer",
|
||||||
|
"dog",
|
||||||
|
"frog",
|
||||||
|
"horse",
|
||||||
|
"ship",
|
||||||
|
"truck",
|
||||||
|
]
|
||||||
|
|
||||||
processor = AutoImageProcessor.from_pretrained(
|
processor = AutoImageProcessor.from_pretrained(
|
||||||
"facebook/dinov2-large", device_map="cuda"
|
"facebook/dinov2-large", device_map="cuda"
|
||||||
)
|
)
|
||||||
model = AutoModel.from_pretrained("facebook/dinov2-large", device_map="cuda")
|
model = AutoModel.from_pretrained("facebook/dinov2-large", device_map="cuda")
|
||||||
|
|
||||||
establish_database(processor, model, train_dataset["img"], train_dataset["label"])
|
establish_database(
|
||||||
|
processor,
|
||||||
|
model,
|
||||||
|
train_dataset["img"],
|
||||||
|
train_dataset["label"],
|
||||||
|
label_map=label_map,
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,11 +1,30 @@
|
|||||||
|
import datetime
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import dash_ag_grid as dag
|
import dash_ag_grid as dag
|
||||||
import dash_mantine_components as dmc
|
import dash_mantine_components as dmc
|
||||||
from dash import Dash
|
from dash import Dash, Input, Output, State, callback, dcc, html
|
||||||
from database import db_manager
|
from database import db_manager
|
||||||
|
|
||||||
|
|
||||||
|
def parse_contents(contents, filename, date):
|
||||||
|
return html.Div(
|
||||||
|
[
|
||||||
|
html.H5(filename),
|
||||||
|
html.H6(datetime.datetime.fromtimestamp(date)),
|
||||||
|
# HTML images accept base64 encoded strings in the same format
|
||||||
|
# that is supplied by the upload
|
||||||
|
html.Img(src=contents),
|
||||||
|
html.Hr(),
|
||||||
|
html.Div("Raw Content"),
|
||||||
|
html.Pre(
|
||||||
|
contents[0:200] + "...",
|
||||||
|
style={"whiteSpace": "pre-wrap", "wordBreak": "break-all"},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class APP(Dash):
|
class APP(Dash):
|
||||||
"""Singleton Dash Application"""
|
"""Singleton Dash Application"""
|
||||||
|
|
||||||
@@ -19,7 +38,12 @@ class APP(Dash):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(__name__)
|
super().__init__(__name__)
|
||||||
|
|
||||||
df = db_manager.table.search().select(["id", "label", "vector"]).to_polars()
|
df = (
|
||||||
|
db_manager.table.search()
|
||||||
|
.select(["id", "label", "vector"])
|
||||||
|
.limit(1000)
|
||||||
|
.to_polars()
|
||||||
|
)
|
||||||
|
|
||||||
columnDefs = [
|
columnDefs = [
|
||||||
{"headerName": column.capitalize(), "field": column}
|
{"headerName": column.capitalize(), "field": column}
|
||||||
@@ -27,11 +51,55 @@ class APP(Dash):
|
|||||||
]
|
]
|
||||||
|
|
||||||
self.layout = dmc.MantineProvider(
|
self.layout = dmc.MantineProvider(
|
||||||
dag.AgGrid(
|
dmc.Container(
|
||||||
rowData=df.to_dicts(),
|
dmc.Flex(
|
||||||
columnDefs=columnDefs,
|
[
|
||||||
|
dcc.Upload(
|
||||||
|
id="upload-image",
|
||||||
|
children=html.Div(
|
||||||
|
["Drag and Drop or ", html.A("Select Files")]
|
||||||
|
),
|
||||||
|
style={
|
||||||
|
"width": "100%",
|
||||||
|
"height": "60px",
|
||||||
|
"lineHeight": "60px",
|
||||||
|
"borderWidth": "1px",
|
||||||
|
"borderStyle": "dashed",
|
||||||
|
"borderRadius": "5px",
|
||||||
|
"textAlign": "center",
|
||||||
|
"margin": "10px",
|
||||||
|
},
|
||||||
|
# Allow multiple files to be uploaded
|
||||||
|
multiple=True,
|
||||||
|
),
|
||||||
|
html.Div(id="output-image-upload"),
|
||||||
|
dag.AgGrid(
|
||||||
|
rowData=df.to_dicts(),
|
||||||
|
columnDefs=columnDefs,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
gap="md",
|
||||||
|
justify="center",
|
||||||
|
align="center",
|
||||||
|
direction="column",
|
||||||
|
wrap="wrap",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@callback(
|
||||||
|
Output("output-image-upload", "children"),
|
||||||
|
Input("upload-image", "contents"),
|
||||||
|
State("upload-image", "filename"),
|
||||||
|
State("upload-image", "last_modified"),
|
||||||
|
)
|
||||||
|
def update_output(list_of_contents, list_of_names, list_of_dates):
|
||||||
|
if list_of_contents is not None:
|
||||||
|
children = [
|
||||||
|
parse_contents(c, n, d)
|
||||||
|
for c, n, d in zip(list_of_contents, list_of_names, list_of_dates)
|
||||||
|
]
|
||||||
|
return children
|
||||||
|
|
||||||
|
|
||||||
app = APP()
|
app = APP()
|
||||||
|
|||||||
Reference in New Issue
Block a user