feat(database): add vector database with ConfigType enum

This commit is contained in:
2026-02-03 17:25:24 +08:00
parent cf83c09165
commit 9efdbb3327
6 changed files with 592 additions and 72 deletions

View File

@@ -1,28 +1,44 @@
from typing import cast
import torch
from tqdm.auto import tqdm
from database import db_manager
from datasets import Dataset, load_dataset
from tqdm.auto import tqdm
from transformers import AutoImageProcessor, AutoModel
@torch.no_grad()
def establish_database(processor, model, images, labels, batch_size=64):
    """Embed images in batches and store their CLS-token vectors in the database.

    Each batch of images is passed through ``processor`` and ``model``. The
    first token of ``last_hidden_state`` is taken as the per-image embedding,
    moved to the CPU and written to ``db_manager.table`` together with its
    label. The row ``id`` is the position of the image within ``images``.

    Args:
        processor: Callable invoked as
            ``processor(images=..., return_tensors="pt")``. Its result must
            support ``.to(device, non_blocking=True)`` and ``**`` unpacking.
        model: Model exposing ``.device`` and ``.eval()`` whose output has a
            ``last_hidden_state`` attribute.
        images: Sliceable sequence of images.
        labels: Sliceable sequence of labels, aligned index-by-index with
            ``images``.
        batch_size: Number of images processed per forward pass.

    Raises:
        ValueError: If ``batch_size`` is not positive, or if ``images`` and
            ``labels`` differ in length.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    total = len(images)
    # A length mismatch would otherwise silently drop rows, because the number
    # of rows written per batch is driven by the label slice.
    if len(labels) != total:
        raise ValueError(
            f"images and labels must have the same length "
            f"({total} != {len(labels)})"
        )

    device = model.device
    model.eval()

    for start in tqdm(range(0, total, batch_size)):
        stop = start + batch_size
        batch_imgs = images[start:stop]
        batch_labels = labels[start:stop]

        inputs = processor(images=batch_imgs, return_tensors="pt")

        # Move the inputs to the model's device; rebind the result instead of
        # relying on ``to`` mutating the object in place.
        inputs = inputs.to(device, non_blocking=True)

        outputs = model(**inputs)

        # Post-processing: keep only the first (CLS) token of every item.
        feats = outputs.last_hidden_state  # [B, N, D]
        cls_tokens = cast(torch.Tensor, feats[:, 0])

        # Move the embeddings back to the CPU before converting to NumPy.
        cls_tokens = cls_tokens.cpu()

        # Store one row per image in the database.
        db_manager.table.add(
            [
                {"id": start + offset, "label": label, "vector": vector.numpy()}
                for offset, (label, vector) in enumerate(
                    zip(batch_labels, cls_tokens)
                )
            ]
        )
if __name__ == "__main__":
@@ -34,4 +50,4 @@ if __name__ == "__main__":
)
model = AutoModel.from_pretrained("facebook/dinov2-large", device_map="cuda")
establish_database(processor, model, train_dataset["img"])
establish_database(processor, model, train_dataset["img"], train_dataset["label"])