diff --git a/mini-nav/feature_retrieval.py b/mini-nav/feature_retrieval.py index df4ab32..cb12b60 100644 --- a/mini-nav/feature_retrieval.py +++ b/mini-nav/feature_retrieval.py @@ -1,8 +1,10 @@ -from typing import Any, Dict, List, Optional, cast +from typing import Any, Dict, List, Optional, Union, cast +import polars as pl import torch from database import db_manager -from datasets import Dataset, load_dataset +from datasets import load_dataset +from PIL import Image from tqdm.auto import tqdm from transformers import AutoImageProcessor, AutoModel @@ -101,10 +103,37 @@ class FeatureRetrieval: ] ) + @torch.no_grad() + def extract_single_image_feature(self, image: Union[Image.Image, Any]) -> pl.Series: + """Extract feature from a single image without storing to database. + + Args: + image: A single image (PIL Image or other supported format). + + Returns: + pl.Series: The extracted CLS token feature vector as a Polars Series. + """ + device = self.model.device + self.model.eval() + + # 预处理图片 + inputs = self.processor(images=image, return_tensors="pt") + inputs.to(device, non_blocking=True) + + # 提取特征 + outputs = self.model(**inputs) + + # 获取 CLS token + feats = outputs.last_hidden_state # [1, N, D] + cls_token = feats[:, 0] # [1, D] + cls_token = cast(torch.Tensor, cls_token) + + # 返回 Polars Series + return pl.Series("feature", cls_token.cpu().squeeze(0).tolist()) + if __name__ == "__main__": train_dataset = load_dataset("uoft-cs/cifar10", split="train") - train_dataset = cast(Dataset, train_dataset) label_map = [ "airplane", "automobile", diff --git a/mini-nav/visualizer/app.py b/mini-nav/visualizer/app.py index 2425d5a..dc8bc6c 100644 --- a/mini-nav/visualizer/app.py +++ b/mini-nav/visualizer/app.py @@ -1,5 +1,5 @@ import datetime -from typing import Optional +from typing import List, Optional, Union import dash_ag_grid as dag import dash_mantine_components as dmc @@ -7,24 +7,6 @@ from dash import Dash, Input, Output, State, callback, dcc, html from database import db_manager -def parse_contents(contents, filename, date): - return html.Div( - [ - html.H5(filename), - html.H6(datetime.datetime.fromtimestamp(date)), - # HTML images accept base64 encoded strings in the same format - # that is supplied by the upload - html.Img(src=contents), - html.Hr(), - html.Div("Raw Content"), - html.Pre( - contents[0:200] + "...", - style={"whiteSpace": "pre-wrap", "wordBreak": "break-all"}, - ), - ] - ) - - class APP(Dash): """Singleton Dash Application""" @@ -93,7 +75,22 @@ class APP(Dash): State("upload-image", "filename"), State("upload-image", "last_modified"), ) - def update_output(list_of_contents, list_of_names, list_of_dates): + def update_output( + list_of_contents: List[str], + list_of_names: List[str], + list_of_dates: List[int] | List[float], + ): + def parse_contents(contents: str, filename: str, date: Union[int, float]): + return html.Div( + [ + html.H5(filename), + html.H6(datetime.datetime.fromtimestamp(date)), + # HTML images accept base64 encoded strings in the same format + # that is supplied by the upload + dmc.Image(src=contents), + ] + ) + if list_of_contents is not None: children = [ parse_contents(c, n, d) diff --git a/pyproject.toml b/pyproject.toml index 4af41db..119b52c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,3 +38,6 @@ torch = [ torchvision = [ { index = "pytorch-cu130", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, ] + +[tool.ty.environment] +root = ["./mini-nav"]