mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
feat(utils): add feature extraction utilities and tests
This commit is contained in:
@@ -1,9 +1,8 @@
|
||||
import io
|
||||
from typing import Any, Dict, List, Optional, Union, cast
|
||||
from typing import Dict, List, Optional, cast
|
||||
|
||||
import torch
|
||||
from database import db_manager
|
||||
from datasets import load_dataset
|
||||
from PIL import Image
|
||||
from PIL.PngImagePlugin import PngImageFile
|
||||
from torch import nn
|
||||
@@ -14,6 +13,9 @@ from transformers import (
|
||||
BitImageProcessorFast,
|
||||
Dinov2Model,
|
||||
)
|
||||
from utils.feature_extractor import extract_batch_features
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
|
||||
def pil_image_to_bytes(image: Image.Image, format: str = "PNG") -> bytes:
|
||||
@@ -86,78 +88,26 @@ class FeatureRetrieval:
|
||||
batch_size: Number of images to process in a batch.
|
||||
label_map: Optional mapping from label indices to string names.
|
||||
"""
|
||||
device = self.model.device
|
||||
self.model.eval()
|
||||
# Extract features using the utility function
|
||||
cls_tokens = extract_batch_features(
|
||||
self.processor, self.model, images, batch_size=batch_size
|
||||
)
|
||||
|
||||
for i in tqdm(range(0, len(images), batch_size)):
|
||||
batch_imgs = images[i : i + batch_size]
|
||||
for i in tqdm(range(len(labels)), desc="Storing to database"):
|
||||
batch_label = labels[i] if label_map is None else label_map[labels[i]]
|
||||
|
||||
inputs = self.processor(batch_imgs, return_tensors="pt")
|
||||
|
||||
# 迁移数据到GPU
|
||||
inputs.to(device)
|
||||
|
||||
outputs = self.model(**inputs)
|
||||
|
||||
# 后处理
|
||||
feats = outputs.last_hidden_state # [B, N, D]
|
||||
cls_tokens = feats[:, 0] # Get CLS token (first token) for all batch items
|
||||
cls_tokens = cast(torch.Tensor, cls_tokens)
|
||||
|
||||
# 迁移输出到CPU
|
||||
cls_tokens = cls_tokens.cpu()
|
||||
batch_labels = (
|
||||
labels[i : i + batch_size]
|
||||
if label_map is None
|
||||
else list(
|
||||
map(lambda x: label_map[cast(int, x)], labels[i : i + batch_size])
|
||||
)
|
||||
)
|
||||
actual_batch_size = len(batch_labels)
|
||||
|
||||
# 存库
|
||||
# Store to database
|
||||
db_manager.table.add(
|
||||
[
|
||||
{
|
||||
"id": i + j,
|
||||
"label": batch_labels[j],
|
||||
"vector": cls_tokens[j].numpy(),
|
||||
"binary": pil_image_to_bytes(batch_imgs[j]),
|
||||
"id": i,
|
||||
"label": batch_label,
|
||||
"vector": cls_tokens[i].numpy(),
|
||||
"binary": pil_image_to_bytes(images[i]),
|
||||
}
|
||||
for j in range(actual_batch_size)
|
||||
]
|
||||
)
|
||||
|
||||
@torch.no_grad()
def extract_single_image_feature(
    self, image: Union[Image.Image, Any]
) -> List[float]:
    """Extract the CLS-token feature of a single image without storing it.

    The image is run through ``self.processor`` and ``self.model``; nothing
    is written to the database.

    Args:
        image: A single image (PIL Image or another format accepted by
            ``self.processor``).

    Returns:
        The CLS token feature vector as a plain list of floats.
    """
    device = self.model.device
    self.model.eval()

    # Preprocess the image into model inputs.
    inputs = self.processor(images=image, return_tensors="pt")
    # The return value is not rebound, so this relies on ``.to`` updating
    # ``inputs`` in place.
    inputs.to(device, non_blocking=True)

    # Run the forward pass to extract features.
    outputs = self.model(**inputs)

    # The CLS token is the first token of the sequence.
    feats = outputs.last_hidden_state  # [1, N, D]
    cls_token = cast(torch.Tensor, feats[:, 0])  # [1, D]

    # Move to CPU, drop the batch dimension, and return a Python list.
    return cls_token.cpu().squeeze(0).tolist()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_dataset = load_dataset("uoft-cs/cifar10", split="train")
|
||||
|
||||
Reference in New Issue
Block a user