mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
feat(database): add vector database with ConfigType enum
This commit is contained in:
@@ -8,6 +8,7 @@ from .models import (
|
||||
from .loader import load_yaml, save_yaml, ConfigError
|
||||
from .config import (
|
||||
ConfigManager,
|
||||
ConfigType,
|
||||
cfg_manager,
|
||||
)
|
||||
|
||||
@@ -24,5 +25,6 @@ __all__ = [
|
||||
"ConfigError",
|
||||
# Manager
|
||||
"ConfigManager",
|
||||
"ConfigType",
|
||||
"cfg_manager",
|
||||
]
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Configuration manager for multiple configurations."""
|
||||
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
@@ -7,6 +8,10 @@ from .loader import load_yaml, save_yaml
|
||||
from .models import FeatureCompressorConfig
|
||||
|
||||
|
||||
class ConfigType(str, Enum):
|
||||
FeatureCompressor = "feature_compressor"
|
||||
|
||||
|
||||
class ConfigManager:
|
||||
"""Singleton configuration manager supporting multiple configs."""
|
||||
|
||||
@@ -21,9 +26,7 @@ class ConfigManager:
|
||||
def __init__(self):
|
||||
self.config_dir = Path(__file__).parent
|
||||
|
||||
def load_config(
|
||||
self, config_name: str = "feature_compressor"
|
||||
) -> FeatureCompressorConfig:
|
||||
def load_config(self, config_name: ConfigType) -> FeatureCompressorConfig:
|
||||
"""Load configuration from YAML file.
|
||||
|
||||
Args:
|
||||
@@ -56,9 +59,7 @@ class ConfigManager:
|
||||
self._configs.update(loaded_configs)
|
||||
return loaded_configs
|
||||
|
||||
def get_config(
|
||||
self, config_name: str = "feature_compressor"
|
||||
) -> FeatureCompressorConfig:
|
||||
def get_config(self, config_name: ConfigType) -> FeatureCompressorConfig:
|
||||
"""Get loaded configuration by name.
|
||||
|
||||
Args:
|
||||
@@ -77,9 +78,7 @@ class ConfigManager:
|
||||
)
|
||||
return self._configs[config_name]
|
||||
|
||||
def get_or_load_config(
|
||||
self, config_name: str = "feature_compressor"
|
||||
) -> FeatureCompressorConfig:
|
||||
def get_or_load_config(self, config_name: ConfigType) -> FeatureCompressorConfig:
|
||||
"""Get loaded configuration by name or load it if not loaded.
|
||||
|
||||
Args:
|
||||
@@ -105,7 +104,7 @@ class ConfigManager:
|
||||
|
||||
def save_config(
|
||||
self,
|
||||
config_name: str = "feature_compressor",
|
||||
config_name: ConfigType,
|
||||
config: Optional[FeatureCompressorConfig] = None,
|
||||
path: Optional[Path] = None,
|
||||
) -> None:
|
||||
|
||||
41
mini-nav/database.py
Normal file
41
mini-nav/database.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import lancedb
|
||||
import pyarrow as pa
|
||||
from configs import ConfigType, cfg_manager
|
||||
|
||||
db_schema = pa.schema(
|
||||
[
|
||||
pa.field("id", pa.int32()),
|
||||
pa.field("label", pa.string()),
|
||||
pa.field("vector", pa.list_(pa.float32(), 1024)),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class DatabaseManager:
|
||||
"""Singleton Database manager"""
|
||||
|
||||
_instance: Optional["DatabaseManager"] = None
|
||||
db: lancedb.DBConnection
|
||||
table: lancedb.Table
|
||||
|
||||
def __new__(cls) -> "DatabaseManager":
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
# 获取数据库位置
|
||||
config = cfg_manager.get_or_load_config(ConfigType.FeatureCompressor)
|
||||
db_path = config.output.directory / "database"
|
||||
|
||||
# 初始化数据库与表格
|
||||
self.db = lancedb.connect(db_path)
|
||||
if "default" not in self.db.table_names():
|
||||
self.table = self.db.create_table("default", schema=db_schema)
|
||||
else:
|
||||
self.table = self.db.open_table("default")
|
||||
|
||||
|
||||
db_manager = DatabaseManager()
|
||||
@@ -1,28 +1,44 @@
|
||||
from typing import cast
|
||||
|
||||
import torch
|
||||
from tqdm.auto import tqdm
|
||||
from database import db_manager
|
||||
from datasets import Dataset, load_dataset
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import AutoImageProcessor, AutoModel
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def establish_database(processor, model, images, batch_size=64):
|
||||
def establish_database(processor, model, images, labels, batch_size=64):
|
||||
device = model.device
|
||||
model.eval()
|
||||
|
||||
for i in tqdm(range(0, len(images), batch_size)):
|
||||
batch = images[i : i + batch_size]
|
||||
batch_imgs = images[i : i + batch_size]
|
||||
|
||||
inputs = processor(images=batch, return_tensors="pt")
|
||||
inputs = processor(images=batch_imgs, return_tensors="pt")
|
||||
|
||||
# 迁移数据到GPU
|
||||
inputs.to(device, non_blocking=True)
|
||||
|
||||
outputs = model(**inputs)
|
||||
feats = outputs.last_hidden_state # [B, N, D]
|
||||
# 后处理 / 存库
|
||||
|
||||
|
||||
# 后处理
|
||||
feats = outputs.last_hidden_state # [B, N, D]
|
||||
cls_tokens = feats[:, 0] # Get CLS token (first token) for all batch items
|
||||
cls_tokens = cast(torch.Tensor, cls_tokens)
|
||||
|
||||
# 迁移输出到CPU
|
||||
cls_tokens = cls_tokens.cpu()
|
||||
batch_labels = labels[i : i + batch_size]
|
||||
actual_batch_size = len(batch_labels)
|
||||
|
||||
# 存库
|
||||
db_manager.table.add(
|
||||
[
|
||||
{"id": i + j, "label": batch_labels[j], "vector": cls_tokens[j].numpy()}
|
||||
for j in range(actual_batch_size)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -34,4 +50,4 @@ if __name__ == "__main__":
|
||||
)
|
||||
model = AutoModel.from_pretrained("facebook/dinov2-large", device_map="cuda")
|
||||
|
||||
establish_database(processor, model, train_dataset["img"])
|
||||
establish_database(processor, model, train_dataset["img"], train_dataset["label"])
|
||||
|
||||
Reference in New Issue
Block a user