mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
feat(feature-retrieval): add label mapping and visualization upload
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from typing import cast
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
import torch
|
||||
from database import db_manager
|
||||
@@ -8,7 +8,14 @@ from transformers import AutoImageProcessor, AutoModel
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def establish_database(processor, model, images, labels, batch_size=64):
|
||||
def establish_database(
|
||||
processor,
|
||||
model,
|
||||
images: List[Any],
|
||||
labels: List[int] | List[str],
|
||||
batch_size=64,
|
||||
label_map: Optional[Dict[int, str] | List[str]] = None,
|
||||
):
|
||||
device = model.device
|
||||
model.eval()
|
||||
|
||||
@@ -29,7 +36,13 @@ def establish_database(processor, model, images, labels, batch_size=64):
|
||||
|
||||
# 迁移输出到CPU
|
||||
cls_tokens = cls_tokens.cpu()
|
||||
batch_labels = labels[i : i + batch_size]
|
||||
batch_labels = (
|
||||
labels[i : i + batch_size]
|
||||
if label_map is None
|
||||
else list(
|
||||
map(lambda x: label_map[cast(int, x)], labels[i : i + batch_size])
|
||||
)
|
||||
)
|
||||
actual_batch_size = len(batch_labels)
|
||||
|
||||
# 存库
|
||||
@@ -44,10 +57,28 @@ def establish_database(processor, model, images, labels, batch_size=64):
|
||||
if __name__ == "__main__":
|
||||
train_dataset = load_dataset("uoft-cs/cifar10", split="train")
|
||||
train_dataset = cast(Dataset, train_dataset)
|
||||
label_map = [
|
||||
"airplane",
|
||||
"automobile",
|
||||
"bird",
|
||||
"cat",
|
||||
"deer",
|
||||
"dog",
|
||||
"frog",
|
||||
"horse",
|
||||
"ship",
|
||||
"truck",
|
||||
]
|
||||
|
||||
processor = AutoImageProcessor.from_pretrained(
|
||||
"facebook/dinov2-large", device_map="cuda"
|
||||
)
|
||||
model = AutoModel.from_pretrained("facebook/dinov2-large", device_map="cuda")
|
||||
|
||||
establish_database(processor, model, train_dataset["img"], train_dataset["label"])
|
||||
establish_database(
|
||||
processor,
|
||||
model,
|
||||
train_dataset["img"],
|
||||
train_dataset["label"],
|
||||
label_map=label_map,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user