mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
refactor(feature-retrieval): convert standalone function to singleton class
This commit is contained in:
@@ -7,51 +7,99 @@ from tqdm.auto import tqdm
|
|||||||
from transformers import AutoImageProcessor, AutoModel
|
from transformers import AutoImageProcessor, AutoModel
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
class FeatureRetrieval:
|
||||||
def establish_database(
|
"""Singleton feature retrieval manager for image feature extraction."""
|
||||||
processor,
|
|
||||||
model,
|
|
||||||
images: List[Any],
|
|
||||||
labels: List[int] | List[str],
|
|
||||||
batch_size=64,
|
|
||||||
label_map: Optional[Dict[int, str] | List[str]] = None,
|
|
||||||
):
|
|
||||||
device = model.device
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
for i in tqdm(range(0, len(images), batch_size)):
|
_instance: Optional["FeatureRetrieval"] = None
|
||||||
batch_imgs = images[i : i + batch_size]
|
|
||||||
|
|
||||||
inputs = processor(images=batch_imgs, return_tensors="pt")
|
_initialized: bool = False
|
||||||
|
processor: Any
|
||||||
|
model: Any
|
||||||
|
|
||||||
# 迁移数据到GPU
|
def __new__(cls, *args, **kwargs) -> "FeatureRetrieval":
|
||||||
inputs.to(device, non_blocking=True)
|
if cls._instance is None:
|
||||||
|
cls._instance = super().__new__(cls)
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
outputs = model(**inputs)
|
def __init__(
|
||||||
|
self, processor: Optional[Any] = None, model: Optional[Any] = None
|
||||||
|
) -> None:
|
||||||
|
"""Initialize the singleton with processor and model.
|
||||||
|
|
||||||
# 后处理
|
Args:
|
||||||
feats = outputs.last_hidden_state # [B, N, D]
|
processor: Image processor for preprocessing images.
|
||||||
cls_tokens = feats[:, 0] # Get CLS token (first token) for all batch items
|
model: Model for feature extraction.
|
||||||
cls_tokens = cast(torch.Tensor, cls_tokens)
|
"""
|
||||||
|
# 如果已经初始化过,直接返回
|
||||||
|
if self._initialized:
|
||||||
|
return
|
||||||
|
|
||||||
# 迁移输出到CPU
|
# 首次初始化时必须提供 processor 和 model
|
||||||
cls_tokens = cls_tokens.cpu()
|
if processor is None or model is None:
|
||||||
batch_labels = (
|
raise ValueError(
|
||||||
labels[i : i + batch_size]
|
"Processor and model must be provided on first initialization."
|
||||||
if label_map is None
|
|
||||||
else list(
|
|
||||||
map(lambda x: label_map[cast(int, x)], labels[i : i + batch_size])
|
|
||||||
)
|
)
|
||||||
)
|
|
||||||
actual_batch_size = len(batch_labels)
|
|
||||||
|
|
||||||
# 存库
|
self.processor = processor
|
||||||
db_manager.table.add(
|
self.model = model
|
||||||
[
|
self._initialized = True
|
||||||
{"id": i + j, "label": batch_labels[j], "vector": cls_tokens[j].numpy()}
|
|
||||||
for j in range(actual_batch_size)
|
@torch.no_grad()
|
||||||
]
|
def establish_database(
|
||||||
)
|
self,
|
||||||
|
images: List[Any],
|
||||||
|
labels: List[int] | List[str],
|
||||||
|
batch_size: int = 64,
|
||||||
|
label_map: Optional[Dict[int, str] | List[str]] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Extract features from images and store them in the database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images: List of images to process.
|
||||||
|
labels: List of labels corresponding to images.
|
||||||
|
batch_size: Number of images to process in a batch.
|
||||||
|
label_map: Optional mapping from label indices to string names.
|
||||||
|
"""
|
||||||
|
device = self.model.device
|
||||||
|
self.model.eval()
|
||||||
|
|
||||||
|
for i in tqdm(range(0, len(images), batch_size)):
|
||||||
|
batch_imgs = images[i : i + batch_size]
|
||||||
|
|
||||||
|
inputs = self.processor(images=batch_imgs, return_tensors="pt")
|
||||||
|
|
||||||
|
# 迁移数据到GPU
|
||||||
|
inputs.to(device, non_blocking=True)
|
||||||
|
|
||||||
|
outputs = self.model(**inputs)
|
||||||
|
|
||||||
|
# 后处理
|
||||||
|
feats = outputs.last_hidden_state # [B, N, D]
|
||||||
|
cls_tokens = feats[:, 0] # Get CLS token (first token) for all batch items
|
||||||
|
cls_tokens = cast(torch.Tensor, cls_tokens)
|
||||||
|
|
||||||
|
# 迁移输出到CPU
|
||||||
|
cls_tokens = cls_tokens.cpu()
|
||||||
|
batch_labels = (
|
||||||
|
labels[i : i + batch_size]
|
||||||
|
if label_map is None
|
||||||
|
else list(
|
||||||
|
map(lambda x: label_map[cast(int, x)], labels[i : i + batch_size])
|
||||||
|
)
|
||||||
|
)
|
||||||
|
actual_batch_size = len(batch_labels)
|
||||||
|
|
||||||
|
# 存库
|
||||||
|
db_manager.table.add(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": i + j,
|
||||||
|
"label": batch_labels[j],
|
||||||
|
"vector": cls_tokens[j].numpy(),
|
||||||
|
}
|
||||||
|
for j in range(actual_batch_size)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@@ -75,9 +123,9 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
model = AutoModel.from_pretrained("facebook/dinov2-large", device_map="cuda")
|
model = AutoModel.from_pretrained("facebook/dinov2-large", device_map="cuda")
|
||||||
|
|
||||||
establish_database(
|
feature_retrieval = FeatureRetrieval(processor, model)
|
||||||
processor,
|
|
||||||
model,
|
feature_retrieval.establish_database(
|
||||||
train_dataset["img"],
|
train_dataset["img"],
|
||||||
train_dataset["label"],
|
train_dataset["label"],
|
||||||
label_map=label_map,
|
label_map=label_map,
|
||||||
|
|||||||
Reference in New Issue
Block a user