# Mirror of https://github.com/SikongJueluo/Mini-Nav.git
# Synced 2026-03-12 20:35:31 +08:00; 299 lines, 9.5 KiB, Python.
"""Tests for compressor modules (SAM, DINO, HashCompressor, Pipeline)."""
|
|
|
|
import pytest
|
|
import torch
|
|
from compressors import (
|
|
BinarySign,
|
|
DinoCompressor,
|
|
HashCompressor,
|
|
SAMHashPipeline,
|
|
SegmentCompressor,
|
|
bits_to_hash,
|
|
create_pipeline_from_config,
|
|
hamming_distance,
|
|
hamming_similarity,
|
|
hash_to_bits,
|
|
)
|
|
from configs import cfg_manager
|
|
from PIL import Image
|
|
|
|
|
|
class TestHashCompressor:
    """Unit tests for HashCompressor: construction, forward, encode, similarity."""

    def test_hash_compressor_init(self):
        """Constructor should record the requested input width and code length."""
        hc = HashCompressor(input_dim=1024, hash_bits=512)

        assert (hc.input_dim, hc.hash_bits) == (1024, 512)

    def test_hash_compressor_forward(self):
        """A forward pass should return three [B, hash_bits] tensors, the last one binary."""
        hc = HashCompressor(input_dim=1024, hash_bits=512)
        batch = torch.randn(4, 197, 1024)  # [B, N, input_dim]

        logits, hash_codes, bits = hc(batch)

        for out in (logits, hash_codes, bits):
            assert out.shape == (4, 512)
        # The bit tensor may only contain 0s and 1s.
        assert torch.all(bits.eq(0) | bits.eq(1))

    def test_hash_compressor_encode(self):
        """encode() should return an int32 tensor of 0/1 bits shaped [B, hash_bits]."""
        hc = HashCompressor(input_dim=1024, hash_bits=512)
        batch = torch.randn(2, 197, 1024)

        codes = hc.encode(batch)

        assert codes.shape == (2, 512)
        assert codes.dtype == torch.int32
        assert torch.all(codes.eq(0) | codes.eq(1))

    def test_hash_compressor_similarity(self):
        """compute_similarity should yield one score per (query, database) pair."""
        hc = HashCompressor(input_dim=1024, hash_bits=512)
        # Random {0, 1} codes: 3 queries scored against 5 database entries.
        queries = torch.randint(0, 2, (3, 512))
        database = torch.randint(0, 2, (5, 512))

        scores = hc.compute_similarity(queries, database)

        assert scores.shape == (3, 5)
|
|
|
|
|
|
class TestBinarySign:
    """Tests for the BinarySign function and the bits <-> hash conversions."""

    def test_binary_sign_forward(self):
        """Applying BinarySign should leave only -1 and +1 entries."""
        values = torch.randn(4, 512)
        signs = BinarySign.apply(values)

        assert torch.all(signs.eq(1) | signs.eq(-1))

    def test_binary_sign_round_trip(self):
        """Converting bits to hash codes and back should recover the original bits."""
        original = torch.randint(0, 2, (4, 512))

        recovered = hash_to_bits(bits_to_hash(original))

        assert torch.equal(original, recovered)
|
|
|
|
|
|
class TestHammingMetrics:
    """Test suite for Hamming distance and similarity."""

    def test_hamming_distance_same_codes(self):
        """Verify hamming distance is 0 for two identical 1-D codes."""
        bits1 = torch.randint(0, 2, (512,))
        bits2 = bits1.clone()

        dist = hamming_distance(bits1, bits2)

        assert dist.item() == 0

    def test_hamming_distance_self_comparison(self):
        """Verify every code in a batch is at distance 0 from itself (zero diagonal)."""
        bits = torch.randint(0, 2, (10, 512))

        dist = hamming_distance(bits, bits)

        # Entry (i, i) compares code i with itself, so it must be 0.
        # NOTE(review): assumes 2-D inputs give a [10, 10] pairwise matrix -- confirm.
        diagonal = torch.diag(dist)
        assert torch.all(diagonal == 0)

    def test_hamming_distance_different(self):
        """Verify the distance equals the code length when every bit differs."""
        bits1 = torch.zeros(1, 512, dtype=torch.int32)
        bits2 = torch.ones(1, 512, dtype=torch.int32)

        dist = hamming_distance(bits1, bits2)

        # All 512 positions disagree.
        assert dist.item() == 512

    def test_hamming_similarity(self):
        """Verify identical hash codes reach the maximum similarity (the code length).

        The previous docstring only promised a "positive" similarity for
        "similar" codes, but the assertion pins the exact maximum for
        *identical* codes.
        """
        hash1 = torch.ones(1, 512)
        hash2 = torch.ones(1, 512)

        sim = hamming_similarity(hash1, hash2)

        assert sim.item() == 512  # all 512 positions agree -> max similarity
|
|
|
|
|
|
class TestSegmentCompressor:
    """Test suite for SegmentCompressor."""

    @pytest.fixture
    def mock_image(self):
        """Create a solid-red 224x224 RGB PIL image (not used by the tests in this class)."""
        img = Image.new("RGB", (224, 224), color="red")
        return img

    def test_segment_compressor_init(self):
        """Verify SegmentCompressor stores the parameters it was constructed with."""
        segmentor = SegmentCompressor(
            model_name="facebook/sam2.1-hiera-large",
            min_mask_area=100,
            max_masks=10,
        )

        assert segmentor.model_name == "facebook/sam2.1-hiera-large"
        assert segmentor.min_mask_area == 100
        assert segmentor.max_masks == 10

    def test_filter_masks(self):
        """Verify masks below ``min_mask_area`` are dropped and survivors sorted by area."""
        # Pass the filter parameters explicitly so the expected outcome below
        # does not silently depend on the constructor's defaults.
        segmentor = SegmentCompressor(min_mask_area=100, max_masks=10)

        # Build [N, H, W] masks whose foreground pixel count equals each area.
        # Filling a prefix of the *flattened* view is required: slicing a
        # single row (``mask[:1, :area]``) clips at the 100-pixel width, so
        # areas 150/200/300 would all collapse to exactly 100.
        height, width = 100, 100
        masks = []
        for area in [50, 200, 150, 300, 10]:
            mask = torch.zeros(height, width)
            mask.view(-1)[:area] = 1  # `view` shares storage with `mask`
            masks.append(mask)

        masks_tensor = torch.stack(masks)  # [5, 100, 100]
        valid = segmentor._filter_masks(masks_tensor)

        # 50 and 10 fall below min_mask_area=100. The remaining three are
        # within max_masks=10, so none are truncated.
        assert len(valid) == 3
        # Survivors come back sorted by area, largest first.
        areas = [int(v["area"]) for v in valid]
        assert areas == [300, 200, 150]
|
|
|
|
|
|
class TestDinoCompressor:
    """Construction tests for DinoCompressor."""

    def test_dino_compressor_init(self):
        """Default construction should select the DINOv2-large checkpoint."""
        model = DinoCompressor()
        expected_name = "facebook/dinov2-large"

        assert model.model_name == expected_name

    def test_dino_compressor_with_compressor(self):
        """A supplied HashCompressor should be kept as the very same object."""
        head = HashCompressor(input_dim=1024, hash_bits=512)
        model = DinoCompressor(compressor=head)

        assert model.compressor is head
|
|
|
|
|
|
class TestSAMHashPipeline:
    """Construction tests for SAMHashPipeline."""

    def test_pipeline_init(self):
        """The pipeline should build a segmentor, a DINO encoder and a hash head."""
        pipe = SAMHashPipeline(
            sam_model="facebook/sam2.1-hiera-large",
            dino_model="facebook/dinov2-large",
            hash_bits=512,
        )

        components = (
            (pipe.segmentor, SegmentCompressor),
            (pipe.dino, DinoCompressor),
            (pipe.hash_compressor, HashCompressor),
        )
        for component, expected_type in components:
            assert isinstance(component, expected_type)

    def test_pipeline_hash_bits(self):
        """The requested code length should be forwarded to the hash head."""
        pipe = SAMHashPipeline(hash_bits=256)
        head = pipe.hash_compressor

        assert head.hash_bits == 256
|
|
|
|
|
|
class TestConfigIntegration:
    """Checks that the loaded configuration drives pipeline construction."""

    def test_create_pipeline_from_config(self):
        """create_pipeline_from_config should honour the configured compression size."""
        cfg = cfg_manager.load()
        pipe = create_pipeline_from_config(cfg)

        assert isinstance(pipe, SAMHashPipeline)
        assert pipe.hash_compressor.hash_bits == cfg.model.compression_dim

    def test_config_sam_settings(self):
        """The model section should expose the expected SAM settings."""
        model_cfg = cfg_manager.load().model

        for name in ("sam_model", "sam_min_mask_area", "sam_max_masks"):
            assert hasattr(model_cfg, name)
        assert model_cfg.sam_model == "facebook/sam2.1-hiera-large"
        assert model_cfg.sam_min_mask_area == 100
        assert model_cfg.sam_max_masks == 10
|
|
|
|
|
|
class TestPipelineIntegration:
    """Integration tests for the full pipeline (slow, requires model downloads).

    Every test here loads the real SAM 2.1 / DINOv2 checkpoints, so the whole
    class is marked ``slow`` and skipped when CUDA is unavailable.
    """

    # Class-level marks apply to every test method below. The declarative
    # ``skipif`` replaces the ``if not torch.cuda.is_available(): pytest.skip``
    # check that was copy-pasted into each test.
    pytestmark = [
        pytest.mark.slow,
        pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA"),
    ]

    @staticmethod
    def _gray_image():
        """Return a uniform mid-gray 640x480 RGB test image."""
        return Image.new("RGB", (640, 480), color=(128, 128, 128))

    def test_pipeline_end_to_end(self):
        """Test full pipeline with actual models: output is [num_masks, 512] bits."""
        image = self._gray_image()

        # Initialize pipeline (downloads models on first run).
        pipeline = SAMHashPipeline(
            sam_model="facebook/sam2.1-hiera-large",
            dino_model="facebook/dinov2-large",
            hash_bits=512,
            sam_min_mask_area=100,
            sam_max_masks=5,
        )

        hash_codes = pipeline(image)

        # NOTE(review): a uniform image may yield few or no SAM masks; these
        # checks only pin rank, code width and bit values.
        assert hash_codes.dim() == 2
        assert hash_codes.shape[1] == 512
        assert torch.all((hash_codes == 0) | (hash_codes == 1))

    def test_extract_features_without_hash(self):
        """Test feature extraction without hash compression returns raw DINO features."""
        image = self._gray_image()

        pipeline = SAMHashPipeline(
            sam_model="facebook/sam2.1-hiera-large",
            dino_model="facebook/dinov2-large",
        )

        features = pipeline.extract_features(image, use_hash=False)

        # DINOv2-large features are 1024-dimensional.
        assert features.dim() == 2
        assert features.shape[1] == 1024

    def test_extract_masks_only(self):
        """Test that mask extraction alone returns a list of masks."""
        image = self._gray_image()

        pipeline = SAMHashPipeline(
            sam_model="facebook/sam2.1-hiera-large",
        )

        masks = pipeline.extract_masks(image)

        assert isinstance(masks, list)
|