feat(feature-retrieval): add single image feature extraction method

This commit is contained in:
2026-02-05 21:08:49 +08:00
parent 7ce97c1965
commit a0df45ab05
3 changed files with 52 additions and 23 deletions

View File

@@ -1,8 +1,10 @@
from typing import Any, Dict, List, Optional, cast
from typing import Any, Dict, List, Optional, Union, cast
import polars as pl
import torch
from database import db_manager
from datasets import Dataset, load_dataset
from datasets import load_dataset
from PIL import Image
from tqdm.auto import tqdm
from transformers import AutoImageProcessor, AutoModel
@@ -101,10 +103,37 @@ class FeatureRetrieval:
]
)
@torch.no_grad()
def extract_single_image_feature(self, image: Union[Image.Image, Any]) -> pl.Series:
"""Extract feature from a single image without storing to database.
Args:
image: A single image (PIL Image or other supported format).
Returns:
pl.Series: The extracted CLS token feature vector as a Polars Series.
"""
device = self.model.device
self.model.eval()
# 预处理图片
inputs = self.processor(images=image, return_tensors="pt")
inputs.to(device, non_blocking=True)
# 提取特征
outputs = self.model(**inputs)
# 获取 CLS token
feats = outputs.last_hidden_state # [1, N, D]
cls_token = feats[:, 0] # [1, D]
cls_token = cast(torch.Tensor, cls_token)
# 返回 Polars Series
return pl.Series("feature", cls_token.cpu().squeeze(0).tolist())
if __name__ == "__main__":
train_dataset = load_dataset("uoft-cs/cifar10", split="train")
train_dataset = cast(Dataset, train_dataset)
label_map = [
"airplane",
"automobile",

View File

@@ -1,5 +1,5 @@
import datetime
from typing import Optional
from typing import List, Optional, Union
import dash_ag_grid as dag
import dash_mantine_components as dmc
@@ -7,24 +7,6 @@ from dash import Dash, Input, Output, State, callback, dcc, html
from database import db_manager
def parse_contents(contents, filename, date):
return html.Div(
[
html.H5(filename),
html.H6(datetime.datetime.fromtimestamp(date)),
# HTML images accept base64 encoded strings in the same format
# that is supplied by the upload
html.Img(src=contents),
html.Hr(),
html.Div("Raw Content"),
html.Pre(
contents[0:200] + "...",
style={"whiteSpace": "pre-wrap", "wordBreak": "break-all"},
),
]
)
class APP(Dash):
"""Singleton Dash Application"""
@@ -93,7 +75,22 @@ class APP(Dash):
State("upload-image", "filename"),
State("upload-image", "last_modified"),
)
def update_output(list_of_contents, list_of_names, list_of_dates):
def update_output(
list_of_contents: List[str],
list_of_names: List[str],
list_of_dates: List[int] | List[float],
):
def parse_contents(contents: str, filename: str, date: Union[int, float]):
return html.Div(
[
html.H5(filename),
html.H6(datetime.datetime.fromtimestamp(date)),
# HTML images accept base64 encoded strings in the same format
# that is supplied by the upload
dmc.Image(src=contents),
]
)
if list_of_contents is not None:
children = [
parse_contents(c, n, d)