mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
refactor(compressors): Simplify module by removing SAM/DINO separation code
- Remove dino_compressor.py and segament_compressor.py - Rewrite pipeline.py to inline DINO into HashPipeline - Maintain backward compatibility: SAMHashPipeline alias - Update tests and benchmark.py
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from typing import cast
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import typer
|
||||
from commands import app
|
||||
@@ -7,15 +7,15 @@ from commands import app
|
||||
@app.command()
|
||||
def benchmark(
|
||||
ctx: typer.Context,
|
||||
model_path: str = typer.Option(
|
||||
model_path: Optional[str] = typer.Option(
|
||||
None, "--model", "-m", help="Path to compressor model weights"
|
||||
),
|
||||
):
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from benchmarks import run_benchmark
|
||||
from compressors import DinoCompressor
|
||||
from configs import cfg_manager
|
||||
from transformers import AutoImageProcessor, BitImageProcessorFast
|
||||
from transformers import AutoImageProcessor, AutoModel, BitImageProcessorFast
|
||||
from utils import get_device
|
||||
|
||||
config = cfg_manager.get()
|
||||
@@ -29,7 +29,12 @@ def benchmark(
|
||||
AutoImageProcessor.from_pretrained(model_cfg.dino_model, device_map=device),
|
||||
)
|
||||
|
||||
model = DinoCompressor().to(device)
|
||||
# Load DINO model for feature extraction
|
||||
dino = AutoModel.from_pretrained(model_cfg.dino_model, device_map=device)
|
||||
dino.eval()
|
||||
|
||||
# Optional hash compressor
|
||||
compressor = None
|
||||
if model_path:
|
||||
from compressors import HashCompressor
|
||||
|
||||
@@ -38,7 +43,31 @@ def benchmark(
|
||||
hash_bits=model_cfg.compression_dim,
|
||||
)
|
||||
compressor.load_state_dict(torch.load(model_path))
|
||||
model.compressor = compressor
|
||||
compressor.to(device)
|
||||
compressor.eval()
|
||||
|
||||
# Create wrapper with extract_features method
|
||||
class DinoFeatureExtractor:
    """Adapter exposing ``extract_features`` / ``encode`` on top of a DINO model.

    ``processor`` and ``device`` are not attributes of this class; they are
    looked up in the scope in which the class is defined.

    Args:
        dino: Model invoked as ``dino(**inputs)``; its output must expose
            ``last_hidden_state``.
        compressor: Optional callable applied to the hidden-state tensor. It
            must return a 3-tuple whose last element is the bit tensor. When
            ``None``, ``encode`` falls back to ``extract_features``.
    """

    def __init__(self, dino, compressor=None):
        self.dino = dino
        self.compressor = compressor

    def _tokens(self, images: list) -> torch.Tensor:
        """Preprocess ``images`` and return the model's ``last_hidden_state``."""
        inputs = processor(images, return_tensors="pt").to(device)
        return self.dino(**inputs).last_hidden_state

    def extract_features(self, images: list) -> torch.Tensor:
        """Return L2-normalized features for ``images``.

        The hidden state is averaged over dim 1 (presumably the token axis --
        confirm against the model's output layout) and then normalized along
        the last dimension.
        """
        with torch.no_grad():
            pooled = self._tokens(images).mean(dim=1)
            return F.normalize(pooled, dim=-1)

    def encode(self, images: list) -> torch.Tensor:
        """Return the compressor's bit tensor, or plain features without one."""
        if self.compressor is None:
            return self.extract_features(images)
        # Inference only: without no_grad the forward pass would record an
        # autograd graph for every batch and return grad-tracking tensors,
        # unlike extract_features.
        with torch.no_grad():
            _, _, bits = self.compressor(self._tokens(images))
        return bits
|
||||
|
||||
model = DinoFeatureExtractor(dino, compressor)
|
||||
|
||||
run_benchmark(
|
||||
model=model,
|
||||
|
||||
Reference in New Issue
Block a user