feat(feature-retrieval): add label mapping and visualization upload

This commit is contained in:
2026-02-05 13:27:29 +08:00
parent 701fa9f289
commit 90087194ec
2 changed files with 108 additions and 9 deletions

View File

@@ -1,4 +1,4 @@
from typing import cast
from typing import Any, Dict, List, Optional, cast
import torch
from database import db_manager
@@ -8,7 +8,14 @@ from transformers import AutoImageProcessor, AutoModel
@torch.no_grad()
def establish_database(processor, model, images, labels, batch_size=64):
def establish_database(
processor,
model,
images: List[Any],
labels: List[int] | List[str],
batch_size=64,
label_map: Optional[Dict[int, str] | List[str]] = None,
):
device = model.device
model.eval()
@@ -29,7 +36,13 @@ def establish_database(processor, model, images, labels, batch_size=64):
# 迁移输出到CPU
cls_tokens = cls_tokens.cpu()
batch_labels = labels[i : i + batch_size]
batch_labels = (
labels[i : i + batch_size]
if label_map is None
else list(
map(lambda x: label_map[cast(int, x)], labels[i : i + batch_size])
)
)
actual_batch_size = len(batch_labels)
# 存库
@@ -44,10 +57,28 @@ def establish_database(processor, model, images, labels, batch_size=64):
if __name__ == "__main__":
train_dataset = load_dataset("uoft-cs/cifar10", split="train")
train_dataset = cast(Dataset, train_dataset)
label_map = [
"airplane",
"automobile",
"bird",
"cat",
"deer",
"dog",
"frog",
"horse",
"ship",
"truck",
]
processor = AutoImageProcessor.from_pretrained(
"facebook/dinov2-large", device_map="cuda"
)
model = AutoModel.from_pretrained("facebook/dinov2-large", device_map="cuda")
establish_database(processor, model, train_dataset["img"], train_dataset["label"])
establish_database(
processor,
model,
train_dataset["img"],
train_dataset["label"],
label_map=label_map,
)