mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
feat(feature-retrieval): add single image feature extraction method
This commit is contained in:
@@ -1,8 +1,10 @@
|
|||||||
from typing import Any, Dict, List, Optional, cast
|
from typing import Any, Dict, List, Optional, Union, cast
|
||||||
|
|
||||||
|
import polars as pl
|
||||||
import torch
|
import torch
|
||||||
from database import db_manager
|
from database import db_manager
|
||||||
from datasets import Dataset, load_dataset
|
from datasets import load_dataset
|
||||||
|
from PIL import Image
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
from transformers import AutoImageProcessor, AutoModel
|
from transformers import AutoImageProcessor, AutoModel
|
||||||
|
|
||||||
@@ -101,10 +103,37 @@ class FeatureRetrieval:
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@torch.no_grad()
def extract_single_image_feature(self, image: Union[Image.Image, Any]) -> pl.Series:
    """Compute the CLS-token feature of one image; nothing is written to the database.

    Args:
        image: The image to embed, in whatever form ``self.processor``
            accepts for its ``images`` argument (e.g. a PIL image).

    Returns:
        pl.Series: A Series named ``"feature"`` holding the values at
        index 0 of the token axis of the model's ``last_hidden_state``,
        after the leading axis has been squeezed.
    """
    model = self.model
    target_device = model.device
    model.eval()

    # Preprocess the image into PyTorch tensors.
    batch = self.processor(images=image, return_tensors="pt")
    # The return value is intentionally unused: this relies on the processor
    # output's ``.to()`` updating the batch in place
    # (NOTE(review): true for transformers' BatchFeature -- confirm if the
    # processor type ever changes).
    batch.to(target_device, non_blocking=True)

    # Forward pass; gradients are disabled by the decorator.
    model_output = model(**batch)

    # Take the first token of every sequence -- the CLS token.
    # Presumably last_hidden_state is [1, N, D], so this slice is [1, D].
    cls_embedding = cast(torch.Tensor, model_output.last_hidden_state[:, 0])

    # Move to CPU, drop the leading axis, and hand back a Polars Series.
    feature_values = cls_embedding.cpu().squeeze(0).tolist()
    return pl.Series("feature", feature_values)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
train_dataset = load_dataset("uoft-cs/cifar10", split="train")
|
train_dataset = load_dataset("uoft-cs/cifar10", split="train")
|
||||||
train_dataset = cast(Dataset, train_dataset)
|
|
||||||
label_map = [
|
label_map = [
|
||||||
"airplane",
|
"airplane",
|
||||||
"automobile",
|
"automobile",
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import datetime
|
import datetime
|
||||||
from typing import Optional
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import dash_ag_grid as dag
|
import dash_ag_grid as dag
|
||||||
import dash_mantine_components as dmc
|
import dash_mantine_components as dmc
|
||||||
@@ -7,24 +7,6 @@ from dash import Dash, Input, Output, State, callback, dcc, html
|
|||||||
from database import db_manager
|
from database import db_manager
|
||||||
|
|
||||||
|
|
||||||
def parse_contents(contents, filename, date):
    """Build the preview component for one uploaded file.

    Args:
        contents: The uploaded file's contents; used as the image ``src``
            and sliced to its first 200 characters for the raw preview.
        filename: Name shown as the heading.
        date: Timestamp passed to ``datetime.datetime.fromtimestamp``.

    Returns:
        An ``html.Div`` containing, in order: the filename heading, the
        converted timestamp, the image, a horizontal rule, a
        "Raw Content" label, and the truncated raw contents.
    """
    sections = [html.H5(filename)]
    sections.append(html.H6(datetime.datetime.fromtimestamp(date)))

    # HTML images accept base64 encoded strings in the same format
    # that is supplied by the upload.
    sections.append(html.Img(src=contents))
    sections.append(html.Hr())
    sections.append(html.Div("Raw Content"))

    # Only a short prefix of the payload is displayed.
    raw_preview = contents[0:200] + "..."
    wrap_style = {"whiteSpace": "pre-wrap", "wordBreak": "break-all"}
    sections.append(html.Pre(raw_preview, style=wrap_style))

    return html.Div(sections)
|
|
||||||
|
|
||||||
|
|
||||||
class APP(Dash):
|
class APP(Dash):
|
||||||
"""Singleton Dash Application"""
|
"""Singleton Dash Application"""
|
||||||
|
|
||||||
@@ -93,7 +75,22 @@ class APP(Dash):
|
|||||||
State("upload-image", "filename"),
|
State("upload-image", "filename"),
|
||||||
State("upload-image", "last_modified"),
|
State("upload-image", "last_modified"),
|
||||||
)
|
)
|
||||||
def update_output(list_of_contents, list_of_names, list_of_dates):
|
def update_output(
|
||||||
|
list_of_contents: List[str],
|
||||||
|
list_of_names: List[str],
|
||||||
|
list_of_dates: List[int] | List[float],
|
||||||
|
):
|
||||||
|
def parse_contents(contents: str, filename: str, date: Union[int, float]):
    """Build the preview component for one uploaded image.

    Args:
        contents: The uploaded file's contents, used directly as the
            image ``src``.
        filename: Name shown as the heading.
        date: Timestamp passed to ``datetime.datetime.fromtimestamp``.

    Returns:
        An ``html.Div`` holding the filename heading, the converted
        timestamp, and the image, in that order.
    """
    heading = html.H5(filename)
    uploaded_at = html.H6(datetime.datetime.fromtimestamp(date))
    # HTML images accept base64 encoded strings in the same format
    # that is supplied by the upload.
    picture = dmc.Image(src=contents)
    return html.Div([heading, uploaded_at, picture])
|
||||||
|
|
||||||
if list_of_contents is not None:
|
if list_of_contents is not None:
|
||||||
children = [
|
children = [
|
||||||
parse_contents(c, n, d)
|
parse_contents(c, n, d)
|
||||||
|
|||||||
@@ -38,3 +38,6 @@ torch = [
|
|||||||
torchvision = [
|
torchvision = [
|
||||||
{ index = "pytorch-cu130", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
|
{ index = "pytorch-cu130", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[tool.ty.environment]
|
||||||
|
root = ["./mini-nav"]
|
||||||
|
|||||||
Reference in New Issue
Block a user