mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
refactor(compressors): Simplify module by removing SAM/DINO separation code
- Remove dino_compressor.py and segament_compressor.py
- Rewrite pipeline.py to inline DINO into HashPipeline
- Maintain backward compatibility: keep SAMHashPipeline as an alias of HashPipeline
- Update tests and benchmark.py
This commit is contained in:
@@ -1,13 +1,13 @@
|
||||
"""Tests for compressor modules (SAM, DINO, HashCompressor, Pipeline)."""
|
||||
"""Tests for compressor modules (HashCompressor, Pipeline)."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from compressors import (
|
||||
BinarySign,
|
||||
DinoCompressor,
|
||||
HashCompressor,
|
||||
HashPipeline,
|
||||
SAMHashPipeline,
|
||||
SegmentCompressor,
|
||||
VideoPositiveMask,
|
||||
bits_to_hash,
|
||||
create_pipeline_from_config,
|
||||
hamming_distance,
|
||||
@@ -124,87 +124,105 @@ class TestHammingMetrics:
|
||||
assert sim.item() == 512 # Max similarity
|
||||
|
||||
|
||||
class TestSegmentCompressor:
|
||||
"""Test suite for SegmentCompressor."""
|
||||
class TestHashLoss:
|
||||
"""Test suite for HashLoss."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_image(self):
|
||||
"""Create a mock PIL image."""
|
||||
img = Image.new("RGB", (224, 224), color="red")
|
||||
return img
|
||||
def test_hash_loss_init(self):
|
||||
"""Verify HashLoss initializes with correct parameters."""
|
||||
from compressors import HashLoss
|
||||
|
||||
def test_segment_compressor_init(self):
|
||||
"""Verify SegmentCompressor initializes with correct parameters."""
|
||||
segmentor = SegmentCompressor(
|
||||
model_name="facebook/sam2.1-hiera-large",
|
||||
min_mask_area=100,
|
||||
max_masks=10,
|
||||
loss_fn = HashLoss(
|
||||
contrastive_weight=1.0,
|
||||
distill_weight=0.5,
|
||||
quant_weight=0.01,
|
||||
temperature=0.2,
|
||||
)
|
||||
|
||||
assert segmentor.model_name == "facebook/sam2.1-hiera-large"
|
||||
assert segmentor.min_mask_area == 100
|
||||
assert segmentor.max_masks == 10
|
||||
assert loss_fn.contrastive_weight == 1.0
|
||||
assert loss_fn.distill_weight == 0.5
|
||||
assert loss_fn.quant_weight == 0.01
|
||||
assert loss_fn.temperature == 0.2
|
||||
|
||||
def test_filter_masks(self):
|
||||
"""Verify mask filtering logic."""
|
||||
# Create segmentor to get default filter params
|
||||
segmentor = SegmentCompressor()
|
||||
def test_hash_loss_forward(self):
|
||||
"""Verify HashLoss computes loss correctly."""
|
||||
from compressors import HashLoss
|
||||
|
||||
# Create mock masks tensor with different areas
|
||||
# Masks shape: [N, H, W]
|
||||
masks = []
|
||||
for area in [50, 200, 150, 300, 10]:
|
||||
mask = torch.zeros(100, 100)
|
||||
mask[:1, :area] = 1 # Create mask with specific area
|
||||
masks.append(mask)
|
||||
loss_fn = HashLoss()
|
||||
|
||||
masks_tensor = torch.stack(masks) # [5, 100, 100]
|
||||
valid = segmentor._filter_masks(masks_tensor)
|
||||
batch_size = 4
|
||||
hash_bits = 512
|
||||
logits = torch.randn(batch_size, hash_bits)
|
||||
hash_codes = torch.sign(logits)
|
||||
teacher_embed = torch.randn(batch_size, 1024)
|
||||
positive_mask = torch.eye(batch_size, dtype=torch.bool)
|
||||
|
||||
# Should filter out 50 and 10 (below min_mask_area=100)
|
||||
# Then keep top 3 (max_masks=10)
|
||||
assert len(valid) == 3
|
||||
# Verify sorted by area (descending)
|
||||
areas = [v["area"] for v in valid]
|
||||
assert areas == sorted(areas, reverse=True)
|
||||
total_loss, components = loss_fn(
|
||||
logits=logits,
|
||||
hash_codes=hash_codes,
|
||||
teacher_embed=teacher_embed,
|
||||
positive_mask=positive_mask,
|
||||
)
|
||||
|
||||
assert "contrastive" in components
|
||||
assert "distill" in components
|
||||
assert "quantization" in components
|
||||
assert "total" in components
|
||||
|
||||
|
||||
class TestDinoCompressor:
|
||||
"""Test suite for DinoCompressor."""
|
||||
class TestVideoPositiveMask:
|
||||
"""Test suite for VideoPositiveMask."""
|
||||
|
||||
def test_dino_compressor_init(self):
|
||||
"""Verify DinoCompressor initializes correctly."""
|
||||
dino = DinoCompressor()
|
||||
def test_from_frame_indices(self):
|
||||
"""Verify positive mask generation from frame indices."""
|
||||
mask_gen = VideoPositiveMask(temporal_window=2)
|
||||
|
||||
assert dino.model_name == "facebook/dinov2-large"
|
||||
frame_indices = torch.tensor([0, 1, 3, 5])
|
||||
|
||||
def test_dino_compressor_with_compressor(self):
|
||||
"""Verify DinoCompressor with HashCompressor."""
|
||||
hash_compressor = HashCompressor(input_dim=1024, hash_bits=512)
|
||||
dino = DinoCompressor(compressor=hash_compressor)
|
||||
mask = mask_gen.from_frame_indices(frame_indices)
|
||||
|
||||
assert dino.compressor is hash_compressor
|
||||
assert mask.shape == (4, 4)
|
||||
# Frame 0 and 1 should be positive (distance 1 <= 2)
|
||||
assert mask[0, 1] == True
|
||||
# Frame 0 and 3 should be negative (distance 3 > 2)
|
||||
assert mask[0, 3] == False
|
||||
|
||||
def test_from_video_ids(self):
|
||||
"""Verify positive mask generation from video IDs and frame indices."""
|
||||
mask_gen = VideoPositiveMask(temporal_window=2)
|
||||
|
||||
video_ids = torch.tensor([0, 0, 1, 1])
|
||||
frame_indices = torch.tensor([0, 1, 0, 1])
|
||||
|
||||
mask = mask_gen.from_video_ids(video_ids, frame_indices)
|
||||
|
||||
assert mask.shape == (4, 4)
|
||||
# Same video and temporally close
|
||||
assert mask[0, 1] == True # video 0, frames 0,1
|
||||
# Different video
|
||||
assert mask[0, 2] == False # video 0 vs 1
|
||||
|
||||
|
||||
class TestSAMHashPipeline:
|
||||
"""Test suite for SAMHashPipeline."""
|
||||
class TestHashPipeline:
|
||||
"""Test suite for HashPipeline."""
|
||||
|
||||
def test_pipeline_init(self):
|
||||
"""Verify pipeline initializes all components."""
|
||||
pipeline = SAMHashPipeline(
|
||||
sam_model="facebook/sam2.1-hiera-large",
|
||||
pipeline = HashPipeline(
|
||||
dino_model="facebook/dinov2-large",
|
||||
hash_bits=512,
|
||||
)
|
||||
|
||||
assert isinstance(pipeline.segmentor, SegmentCompressor)
|
||||
assert isinstance(pipeline.dino, DinoCompressor)
|
||||
assert isinstance(pipeline.hash_compressor, HashCompressor)
|
||||
assert pipeline.dino_model == "facebook/dinov2-large"
|
||||
assert pipeline.dino_dim == 1024
|
||||
|
||||
def test_pipeline_hash_bits(self):
|
||||
"""Verify pipeline uses correct hash bits."""
|
||||
pipeline = SAMHashPipeline(hash_bits=256)
|
||||
assert pipeline.hash_compressor.hash_bits == 256
|
||||
pipeline = HashPipeline(hash_bits=256)
|
||||
assert pipeline.hash_bits == 256
|
||||
|
||||
def test_pipeline_alias(self):
|
||||
"""Verify SAMHashPipeline is alias for HashPipeline."""
|
||||
assert SAMHashPipeline is HashPipeline
|
||||
|
||||
|
||||
class TestConfigIntegration:
|
||||
@@ -216,25 +234,21 @@ class TestConfigIntegration:
|
||||
|
||||
pipeline = create_pipeline_from_config(config)
|
||||
|
||||
assert isinstance(pipeline, SAMHashPipeline)
|
||||
assert pipeline.hash_compressor.hash_bits == config.model.compression_dim
|
||||
assert isinstance(pipeline, HashPipeline)
|
||||
assert pipeline.hash_bits == config.model.compression_dim
|
||||
|
||||
def test_config_sam_settings(self):
|
||||
"""Verify config contains SAM settings."""
|
||||
def test_config_settings(self):
|
||||
"""Verify config contains required settings."""
|
||||
config = cfg_manager.load()
|
||||
|
||||
assert hasattr(config.model, "sam_model")
|
||||
assert hasattr(config.model, "sam_min_mask_area")
|
||||
assert hasattr(config.model, "sam_max_masks")
|
||||
assert config.model.sam_model == "facebook/sam2.1-hiera-large"
|
||||
assert config.model.sam_min_mask_area == 100
|
||||
assert config.model.sam_max_masks == 10
|
||||
assert hasattr(config.model, "dino_model")
|
||||
assert hasattr(config.model, "compression_dim")
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
class TestPipelineIntegration:
|
||||
"""Integration tests for full pipeline (slow, requires model downloads)."""
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_pipeline_end_to_end(self):
|
||||
"""Test full pipeline with actual models (slow test)."""
|
||||
# Skip if no GPU
|
||||
@@ -245,54 +259,32 @@ class TestPipelineIntegration:
|
||||
image = Image.new("RGB", (640, 480), color=(128, 128, 128))
|
||||
|
||||
# Initialize pipeline (will download models on first run)
|
||||
pipeline = SAMHashPipeline(
|
||||
sam_model="facebook/sam2.1-hiera-large",
|
||||
pipeline = HashPipeline(
|
||||
dino_model="facebook/dinov2-large",
|
||||
hash_bits=512,
|
||||
sam_min_mask_area=100,
|
||||
sam_max_masks=5,
|
||||
)
|
||||
|
||||
# Run pipeline
|
||||
hash_codes = pipeline(image)
|
||||
hash_bits = pipeline(image)
|
||||
|
||||
# Verify output shape
|
||||
assert hash_codes.dim() == 2
|
||||
assert hash_codes.shape[1] == 512
|
||||
assert torch.all((hash_codes == 0) | (hash_codes == 1))
|
||||
assert hash_bits.dim() == 2
|
||||
assert hash_bits.shape[1] == 512
|
||||
assert torch.all((hash_bits == 0) | (hash_bits == 1))
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_extract_features_without_hash(self):
|
||||
"""Test feature extraction without hash compression."""
|
||||
def test_extract_features(self):
|
||||
"""Test feature extraction."""
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("Requires CUDA")
|
||||
|
||||
image = Image.new("RGB", (640, 480), color=(128, 128, 128))
|
||||
|
||||
pipeline = SAMHashPipeline(
|
||||
sam_model="facebook/sam2.1-hiera-large",
|
||||
pipeline = HashPipeline(
|
||||
dino_model="facebook/dinov2-large",
|
||||
)
|
||||
|
||||
features = pipeline.extract_features(image, use_hash=False)
|
||||
features = pipeline.extract_features(image)
|
||||
|
||||
# Should return DINO features (1024 for large)
|
||||
assert features.dim() == 2
|
||||
assert features.shape[1] == 1024
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_extract_masks_only(self):
|
||||
"""Test mask extraction only."""
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("Requires CUDA")
|
||||
|
||||
image = Image.new("RGB", (640, 480), color=(128, 128, 128))
|
||||
|
||||
pipeline = SAMHashPipeline(
|
||||
sam_model="facebook/sam2.1-hiera-large",
|
||||
)
|
||||
|
||||
masks = pipeline.extract_masks(image)
|
||||
|
||||
# Should return a list of masks
|
||||
assert isinstance(masks, list)
|
||||
|
||||
Reference in New Issue
Block a user