refactor(feature-retrieval): convert standalone function to singleton class

This commit is contained in:
2026-02-05 14:11:26 +08:00
parent 90087194ec
commit ea747f0e5b

View File

@@ -7,51 +7,99 @@ from tqdm.auto import tqdm
from transformers import AutoImageProcessor, AutoModel from transformers import AutoImageProcessor, AutoModel
class FeatureRetrieval:
    """Singleton feature-retrieval manager for image feature extraction.

    The first instantiation must supply a processor and a model; every
    subsequent ``FeatureRetrieval(...)`` call returns the same instance and
    ignores its arguments.
    """

    # Shared singleton instance and one-shot init guard.
    _instance: Optional["FeatureRetrieval"] = None
    _initialized: bool = False

    processor: Any  # image processor (e.g. AutoImageProcessor) — duck-typed
    model: Any      # feature-extraction model (e.g. AutoModel) — duck-typed

    def __new__(cls, *args, **kwargs) -> "FeatureRetrieval":
        # Classic singleton: create the instance once, then always reuse it.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(
        self, processor: Optional[Any] = None, model: Optional[Any] = None
    ) -> None:
        """Initialize the singleton with a processor and a model.

        Args:
            processor: Image processor for preprocessing images.
            model: Model for feature extraction.

        Raises:
            ValueError: If either argument is missing on the very first
                initialization. Later calls never raise — they return early.
        """
        # Already initialized: this is a repeat construction, do nothing.
        if self._initialized:
            return
        # Both dependencies are mandatory the first time around.
        if processor is None or model is None:
            raise ValueError(
                "Processor and model must be provided on first initialization."
            )
        self.processor = processor
        self.model = model
        self._initialized = True

    @torch.no_grad()
    def establish_database(
        self,
        images: List[Any],
        labels: List[int] | List[str],
        batch_size: int = 64,
        label_map: Optional[Dict[int, str] | List[str]] = None,
    ) -> None:
        """Extract CLS-token features from images and store them in the database.

        Args:
            images: List of images to process.
            labels: List of labels corresponding to images.
            batch_size: Number of images to process per batch.
            label_map: Optional mapping from integer label indices to string
                names; when given, stored labels are mapped through it.
        """
        device = self.model.device
        self.model.eval()
        for i in tqdm(range(0, len(images), batch_size)):
            batch_imgs = images[i : i + batch_size]
            inputs = self.processor(images=batch_imgs, return_tensors="pt")
            # Move inputs to the model's device. Rebind the result: relying on
            # BatchEncoding.to() mutating in place is fragile if the processor
            # ever returns a container whose .to() is not in-place.
            inputs = inputs.to(device, non_blocking=True)
            outputs = self.model(**inputs)
            # Post-process: take the CLS token (first token) of every item.
            feats = outputs.last_hidden_state  # [B, N, D]
            cls_tokens = cast(torch.Tensor, feats[:, 0])
            # Move features back to CPU before converting to numpy.
            cls_tokens = cls_tokens.cpu()
            batch_labels = (
                labels[i : i + batch_size]
                if label_map is None
                else [label_map[cast(int, x)] for x in labels[i : i + batch_size]]
            )
            actual_batch_size = len(batch_labels)
            # Persist to the vector database. NOTE(review): `db_manager` is a
            # module-level global defined elsewhere in this file — confirm it
            # is initialized before this method runs.
            db_manager.table.add(
                [
                    {
                        "id": i + j,  # global sample index: batch start + offset
                        "label": batch_labels[j],
                        "vector": cls_tokens[j].numpy(),
                    }
                    for j in range(actual_batch_size)
                ]
            )
if __name__ == "__main__": if __name__ == "__main__":
@@ -75,9 +123,9 @@ if __name__ == "__main__":
) )
model = AutoModel.from_pretrained("facebook/dinov2-large", device_map="cuda") model = AutoModel.from_pretrained("facebook/dinov2-large", device_map="cuda")
establish_database( feature_retrieval = FeatureRetrieval(processor, model)
processor,
model, feature_retrieval.establish_database(
train_dataset["img"], train_dataset["img"],
train_dataset["label"], train_dataset["label"],
label_map=label_map, label_map=label_map,