diff --git a/mini-nav/feature_retrieval.py b/mini-nav/feature_retrieval.py index ed0f1a5..df4ab32 100644 --- a/mini-nav/feature_retrieval.py +++ b/mini-nav/feature_retrieval.py @@ -7,51 +7,99 @@ from tqdm.auto import tqdm from transformers import AutoImageProcessor, AutoModel -@torch.no_grad() -def establish_database( - processor, - model, - images: List[Any], - labels: List[int] | List[str], - batch_size=64, - label_map: Optional[Dict[int, str] | List[str]] = None, -): - device = model.device - model.eval() +class FeatureRetrieval: + """Singleton feature retrieval manager for image feature extraction.""" - for i in tqdm(range(0, len(images), batch_size)): - batch_imgs = images[i : i + batch_size] + _instance: Optional["FeatureRetrieval"] = None - inputs = processor(images=batch_imgs, return_tensors="pt") + _initialized: bool = False + processor: Any + model: Any - # 迁移数据到GPU - inputs.to(device, non_blocking=True) + def __new__(cls, *args, **kwargs) -> "FeatureRetrieval": + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance - outputs = model(**inputs) + def __init__( + self, processor: Optional[Any] = None, model: Optional[Any] = None + ) -> None: + """Initialize the singleton with processor and model. - # 后处理 - feats = outputs.last_hidden_state # [B, N, D] - cls_tokens = feats[:, 0] # Get CLS token (first token) for all batch items - cls_tokens = cast(torch.Tensor, cls_tokens) + Args: + processor: Image processor for preprocessing images. + model: Model for feature extraction. + """ + # 如果已经初始化过,直接返回 + if self._initialized: + return - # 迁移输出到CPU - cls_tokens = cls_tokens.cpu() - batch_labels = ( - labels[i : i + batch_size] - if label_map is None - else list( - map(lambda x: label_map[cast(int, x)], labels[i : i + batch_size]) + # 首次初始化时必须提供 processor 和 model + if processor is None or model is None: + raise ValueError( + "Processor and model must be provided on first initialization." ) - ) - actual_batch_size = len(batch_labels) - # 存库 - db_manager.table.add( - [ - {"id": i + j, "label": batch_labels[j], "vector": cls_tokens[j].numpy()} - for j in range(actual_batch_size) - ] - ) + self.processor = processor + self.model = model + self._initialized = True + + @torch.no_grad() + def establish_database( + self, + images: List[Any], + labels: List[int] | List[str], + batch_size: int = 64, + label_map: Optional[Dict[int, str] | List[str]] = None, + ) -> None: + """Extract features from images and store them in the database. + + Args: + images: List of images to process. + labels: List of labels corresponding to images. + batch_size: Number of images to process in a batch. + label_map: Optional mapping from label indices to string names. + """ + device = self.model.device + self.model.eval() + + for i in tqdm(range(0, len(images), batch_size)): + batch_imgs = images[i : i + batch_size] + + inputs = self.processor(images=batch_imgs, return_tensors="pt") + + # 迁移数据到GPU + inputs.to(device, non_blocking=True) + + outputs = self.model(**inputs) + + # 后处理 + feats = outputs.last_hidden_state # [B, N, D] + cls_tokens = feats[:, 0] # Get CLS token (first token) for all batch items + cls_tokens = cast(torch.Tensor, cls_tokens) + + # 迁移输出到CPU + cls_tokens = cls_tokens.cpu() + batch_labels = ( + labels[i : i + batch_size] + if label_map is None + else list( + map(lambda x: label_map[cast(int, x)], labels[i : i + batch_size]) + ) + ) + actual_batch_size = len(batch_labels) + + # 存库 + db_manager.table.add( + [ + { + "id": i + j, + "label": batch_labels[j], + "vector": cls_tokens[j].numpy(), + } + for j in range(actual_batch_size) + ] + ) if __name__ == "__main__": @@ -75,9 +123,9 @@ if __name__ == "__main__": ) model = AutoModel.from_pretrained("facebook/dinov2-large", device_map="cuda") - establish_database( - processor, - model, + feature_retrieval = FeatureRetrieval(processor, model) + + feature_retrieval.establish_database( train_dataset["img"], train_dataset["label"], label_map=label_map,