feat(database): add vector database with ConfigType enum

This commit is contained in:
2026-02-03 17:25:24 +08:00
parent cf83c09165
commit 9efdbb3327
6 changed files with 592 additions and 72 deletions

View File

@@ -8,6 +8,7 @@ from .models import (
from .loader import load_yaml, save_yaml, ConfigError
from .config import (
ConfigManager,
ConfigType,
cfg_manager,
)
@@ -24,5 +25,6 @@ __all__ = [
"ConfigError",
# Manager
"ConfigManager",
"ConfigType",
"cfg_manager",
]

View File

@@ -1,5 +1,6 @@
"""Configuration manager for multiple configurations."""
from enum import Enum
from pathlib import Path
from typing import Dict, Optional
@@ -7,6 +8,10 @@ from .loader import load_yaml, save_yaml
from .models import FeatureCompressorConfig
class ConfigType(str, Enum):
FeatureCompressor = "feature_compressor"
class ConfigManager:
"""Singleton configuration manager supporting multiple configs."""
@@ -21,9 +26,7 @@ class ConfigManager:
def __init__(self):
self.config_dir = Path(__file__).parent
def load_config(
self, config_name: str = "feature_compressor"
) -> FeatureCompressorConfig:
def load_config(self, config_name: ConfigType) -> FeatureCompressorConfig:
"""Load configuration from YAML file.
Args:
@@ -56,9 +59,7 @@ class ConfigManager:
self._configs.update(loaded_configs)
return loaded_configs
def get_config(
self, config_name: str = "feature_compressor"
) -> FeatureCompressorConfig:
def get_config(self, config_name: ConfigType) -> FeatureCompressorConfig:
"""Get loaded configuration by name.
Args:
@@ -77,9 +78,7 @@ class ConfigManager:
)
return self._configs[config_name]
def get_or_load_config(
self, config_name: str = "feature_compressor"
) -> FeatureCompressorConfig:
def get_or_load_config(self, config_name: ConfigType) -> FeatureCompressorConfig:
"""Get loaded configuration by name or load it if not loaded.
Args:
@@ -105,7 +104,7 @@ class ConfigManager:
def save_config(
self,
config_name: str = "feature_compressor",
config_name: ConfigType,
config: Optional[FeatureCompressorConfig] = None,
path: Optional[Path] = None,
) -> None:

41
mini-nav/database.py Normal file
View File

@@ -0,0 +1,41 @@
from pathlib import Path
from typing import Optional
import lancedb
import pyarrow as pa
from configs import ConfigType, cfg_manager
db_schema = pa.schema(
[
pa.field("id", pa.int32()),
pa.field("label", pa.string()),
pa.field("vector", pa.list_(pa.float32(), 1024)),
]
)
class DatabaseManager:
"""Singleton Database manager"""
_instance: Optional["DatabaseManager"] = None
db: lancedb.DBConnection
table: lancedb.Table
def __new__(cls) -> "DatabaseManager":
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
# 获取数据库位置
config = cfg_manager.get_or_load_config(ConfigType.FeatureCompressor)
db_path = config.output.directory / "database"
# 初始化数据库与表格
self.db = lancedb.connect(db_path)
if "default" not in self.db.table_names():
self.table = self.db.create_table("default", schema=db_schema)
else:
self.table = self.db.open_table("default")
db_manager = DatabaseManager()

View File

@@ -1,28 +1,44 @@
from typing import cast
import torch
from tqdm.auto import tqdm
from database import db_manager
from datasets import Dataset, load_dataset
from tqdm.auto import tqdm
from transformers import AutoImageProcessor, AutoModel
@torch.no_grad()
def establish_database(processor, model, images, batch_size=64):
def establish_database(processor, model, images, labels, batch_size=64):
device = model.device
model.eval()
for i in tqdm(range(0, len(images), batch_size)):
batch = images[i : i + batch_size]
batch_imgs = images[i : i + batch_size]
inputs = processor(images=batch, return_tensors="pt")
inputs = processor(images=batch_imgs, return_tensors="pt")
# 迁移数据到GPU
inputs.to(device, non_blocking=True)
outputs = model(**inputs)
feats = outputs.last_hidden_state # [B, N, D]
# 后处理 / 存库
# 后处理
feats = outputs.last_hidden_state # [B, N, D]
cls_tokens = feats[:, 0] # Get CLS token (first token) for all batch items
cls_tokens = cast(torch.Tensor, cls_tokens)
# 迁移输出到CPU
cls_tokens = cls_tokens.cpu()
batch_labels = labels[i : i + batch_size]
actual_batch_size = len(batch_labels)
# 存库
db_manager.table.add(
[
{"id": i + j, "label": batch_labels[j], "vector": cls_tokens[j].numpy()}
for j in range(actual_batch_size)
]
)
if __name__ == "__main__":
@@ -34,4 +50,4 @@ if __name__ == "__main__":
)
model = AutoModel.from_pretrained("facebook/dinov2-large", device_map="cuda")
establish_database(processor, model, train_dataset["img"])
establish_database(processor, model, train_dataset["img"], train_dataset["label"])