Patch summary (free text before the first "diff --git" header).

mini-nav/feature_retrieval.py
  * establish_database() gains annotations for `images` / `labels` and a new
    optional `label_map` argument. When `label_map` is given, every label in
    a batch is replaced by `label_map[label]` before the batch is stored;
    when it is None the labels are stored unchanged.
  * The __main__ script now passes a ten-entry list of class names as
    `label_map`.

mini-nav/visualizer/app.py
  * Adds parse_contents(), which renders one uploaded file (name, timestamp,
    image, first 200 characters of the raw content).
  * The table query is capped with .limit(1000).
  * The layout becomes Container > Flex holding a dcc.Upload area, an output
    Div and the existing AgGrid; update_output() fills the output Div from
    the uploaded files.

NOTE(review): this patch was recovered from a copy whose line breaks and
indentation had been collapsed. Blank lines were placed so that every hunk
matches its "@@" line counts, and indentation was inferred from the Python
syntax. Two placements could not be determined from the collapsed text and
should be confirmed against the target files:
  - the blank context lines in the "-29,7" hunk of feature_retrieval.py;
  - whether update_output() is nested in APP.__init__ (as written here) or
    sits at class level.
NOTE(review): `List[int] | List[str]` in the new signature is evaluated when
the function is defined; presumably the project targets Python 3.10+, confirm.
NOTE(review): update_output() returns None implicitly when nothing has been
uploaded yet.

diff --git a/mini-nav/feature_retrieval.py b/mini-nav/feature_retrieval.py
index b3eec8a..ed0f1a5 100644
--- a/mini-nav/feature_retrieval.py
+++ b/mini-nav/feature_retrieval.py
@@ -1,4 +1,4 @@
-from typing import cast
+from typing import Any, Dict, List, Optional, cast
 
 import torch
 from database import db_manager
@@ -8,7 +8,14 @@ from transformers import AutoImageProcessor, AutoModel
 
 
 @torch.no_grad()
-def establish_database(processor, model, images, labels, batch_size=64):
+def establish_database(
+    processor,
+    model,
+    images: List[Any],
+    labels: List[int] | List[str],
+    batch_size=64,
+    label_map: Optional[Dict[int, str] | List[str]] = None,
+):
     device = model.device
     model.eval()
 
@@ -29,7 +36,13 @@ def establish_database(processor, model, images, labels, batch_size=64):
 
         # 迁移输出到CPU
         cls_tokens = cls_tokens.cpu()
-        batch_labels = labels[i : i + batch_size]
+        batch_labels = (
+            labels[i : i + batch_size]
+            if label_map is None
+            else list(
+                map(lambda x: label_map[cast(int, x)], labels[i : i + batch_size])
+            )
+        )
         actual_batch_size = len(batch_labels)
 
         # 存库
@@ -44,10 +57,28 @@ def establish_database(processor, model, images, labels, batch_size=64):
 if __name__ == "__main__":
     train_dataset = load_dataset("uoft-cs/cifar10", split="train")
     train_dataset = cast(Dataset, train_dataset)
+    label_map = [
+        "airplane",
+        "automobile",
+        "bird",
+        "cat",
+        "deer",
+        "dog",
+        "frog",
+        "horse",
+        "ship",
+        "truck",
+    ]
 
     processor = AutoImageProcessor.from_pretrained(
         "facebook/dinov2-large", device_map="cuda"
     )
     model = AutoModel.from_pretrained("facebook/dinov2-large", device_map="cuda")
 
-    establish_database(processor, model, train_dataset["img"], train_dataset["label"])
+    establish_database(
+        processor,
+        model,
+        train_dataset["img"],
+        train_dataset["label"],
+        label_map=label_map,
+    )
diff --git a/mini-nav/visualizer/app.py b/mini-nav/visualizer/app.py
index fe4adee..2425d5a 100644
--- a/mini-nav/visualizer/app.py
+++ b/mini-nav/visualizer/app.py
@@ -1,11 +1,30 @@
+import datetime
 from typing import Optional
 
 import dash_ag_grid as dag
 import dash_mantine_components as dmc
-from dash import Dash
+from dash import Dash, Input, Output, State, callback, dcc, html
 from database import db_manager
 
 
+def parse_contents(contents, filename, date):
+    return html.Div(
+        [
+            html.H5(filename),
+            html.H6(datetime.datetime.fromtimestamp(date)),
+            # HTML images accept base64 encoded strings in the same format
+            # that is supplied by the upload
+            html.Img(src=contents),
+            html.Hr(),
+            html.Div("Raw Content"),
+            html.Pre(
+                contents[0:200] + "...",
+                style={"whiteSpace": "pre-wrap", "wordBreak": "break-all"},
+            ),
+        ]
+    )
+
+
 class APP(Dash):
     """Singleton Dash Application"""
 
@@ -19,7 +38,12 @@ class APP(Dash):
     def __init__(self):
         super().__init__(__name__)
 
-        df = db_manager.table.search().select(["id", "label", "vector"]).to_polars()
+        df = (
+            db_manager.table.search()
+            .select(["id", "label", "vector"])
+            .limit(1000)
+            .to_polars()
+        )
 
         columnDefs = [
             {"headerName": column.capitalize(), "field": column}
@@ -27,11 +51,55 @@ class APP(Dash):
         ]
 
         self.layout = dmc.MantineProvider(
-            dag.AgGrid(
-                rowData=df.to_dicts(),
-                columnDefs=columnDefs,
+            dmc.Container(
+                dmc.Flex(
+                    [
+                        dcc.Upload(
+                            id="upload-image",
+                            children=html.Div(
+                                ["Drag and Drop or ", html.A("Select Files")]
+                            ),
+                            style={
+                                "width": "100%",
+                                "height": "60px",
+                                "lineHeight": "60px",
+                                "borderWidth": "1px",
+                                "borderStyle": "dashed",
+                                "borderRadius": "5px",
+                                "textAlign": "center",
+                                "margin": "10px",
+                            },
+                            # Allow multiple files to be uploaded
+                            multiple=True,
+                        ),
+                        html.Div(id="output-image-upload"),
+                        dag.AgGrid(
+                            rowData=df.to_dicts(),
+                            columnDefs=columnDefs,
+                        ),
+                    ],
+                    gap="md",
+                    justify="center",
+                    align="center",
+                    direction="column",
+                    wrap="wrap",
+                ),
             )
         )
+
+        @callback(
+            Output("output-image-upload", "children"),
+            Input("upload-image", "contents"),
+            State("upload-image", "filename"),
+            State("upload-image", "last_modified"),
+        )
+        def update_output(list_of_contents, list_of_names, list_of_dates):
+            if list_of_contents is not None:
+                children = [
+                    parse_contents(c, n, d)
+                    for c, n, d in zip(list_of_contents, list_of_names, list_of_dates)
+                ]
+                return children
 
 
 app = APP()