mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
feat: initialize project with Python 3.10 and feature retrieval module
This commit is contained in:
37
mini-nav/feature_retrieval.py
Normal file
37
mini-nav/feature_retrieval.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from typing import cast
|
||||
import torch
|
||||
from tqdm.auto import tqdm
|
||||
from datasets import Dataset, load_dataset
|
||||
from transformers import AutoImageProcessor, AutoModel
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def establish_database(processor, model, images, batch_size=64):
|
||||
device = model.device
|
||||
model.eval()
|
||||
|
||||
for i in tqdm(range(0, len(images), batch_size)):
|
||||
batch = images[i : i + batch_size]
|
||||
|
||||
inputs = processor(images=batch, return_tensors="pt")
|
||||
|
||||
# 迁移数据到GPU
|
||||
inputs.to(device, non_blocking=True)
|
||||
|
||||
outputs = model(**inputs)
|
||||
feats = outputs.last_hidden_state # [B, N, D]
|
||||
# 后处理 / 存库
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_dataset = load_dataset("uoft-cs/cifar10", split="train")
|
||||
train_dataset = cast(Dataset, train_dataset)
|
||||
|
||||
processor = AutoImageProcessor.from_pretrained(
|
||||
"facebook/dinov2-large", device_map="cuda"
|
||||
)
|
||||
model = AutoModel.from_pretrained("facebook/dinov2-large", device_map="cuda")
|
||||
|
||||
establish_database(processor, model, train_dataset["img"])
|
||||
Reference in New Issue
Block a user