mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
100 lines
3.0 KiB
Python
100 lines
3.0 KiB
Python
"""Tests for PoolNetCompressor module."""
|
|
|
|
import pytest
|
|
import torch
|
|
from feature_compressor.core.compressor import PoolNetCompressor
|
|
|
|
|
|
class TestPoolNetCompressor:
    """Test suite for PoolNetCompressor class.

    Inputs simulate DINOv2 token sequences of shape [batch, seq_len, dim],
    with seq_len = 257 (CLS token + 256 patch tokens) and dim = 1024.
    """

    BATCH = 2
    SEQ_LEN = 257  # CLS + 256 patches
    INPUT_DIM = 1024
    COMPRESSION_DIM = 256

    def _make_input(self, device="cpu"):
        """Return a seeded random input of shape [BATCH, SEQ_LEN, INPUT_DIM].

        A local fixed-seed generator keeps failures reproducible without
        touching torch's global RNG state, which would leak into other tests.
        The tensor is created on CPU and then moved to ``device``.
        """
        generator = torch.Generator().manual_seed(0)
        x = torch.randn(
            self.BATCH, self.SEQ_LEN, self.INPUT_DIM, generator=generator
        )
        return x.to(device)

    def test_compressor_init(self):
        """Test PoolNetCompressor initializes with correct parameters."""
        compressor = PoolNetCompressor(
            input_dim=self.INPUT_DIM,
            compression_dim=self.COMPRESSION_DIM,
            top_k_ratio=0.5,
            hidden_ratio=2.0,
            dropout_rate=0.1,
            use_residual=True,
        )

        assert compressor.input_dim == self.INPUT_DIM
        assert compressor.compression_dim == self.COMPRESSION_DIM
        assert compressor.top_k_ratio == 0.5

    def test_compressor_forward_shape(self):
        """Test output shape is [batch, compression_dim]."""
        compressor = PoolNetCompressor(
            input_dim=self.INPUT_DIM,
            compression_dim=self.COMPRESSION_DIM,
            top_k_ratio=0.5,
        )

        out = compressor(self._make_input())

        expected = (self.BATCH, self.COMPRESSION_DIM)
        assert out.shape == expected, (
            f"Expected {expected}, got {tuple(out.shape)}"
        )

    def test_attention_scores_shape(self):
        """Test attention scores have shape [batch, seq_len]."""
        compressor = PoolNetCompressor(
            input_dim=self.INPUT_DIM, compression_dim=self.COMPRESSION_DIM
        )

        scores = compressor._compute_attention_scores(self._make_input())

        expected = (self.BATCH, self.SEQ_LEN)
        assert scores.shape == expected, (
            f"Expected {expected}, got {tuple(scores.shape)}"
        )

    def test_top_k_selection(self):
        """Test that only top_k_ratio tokens are selected."""
        top_k_ratio = 0.5
        compressor = PoolNetCompressor(
            input_dim=self.INPUT_DIM,
            compression_dim=self.COMPRESSION_DIM,
            top_k_ratio=top_k_ratio,
        )

        x = self._make_input()
        pooled = compressor._apply_pooling(
            x, compressor._compute_attention_scores(x)
        )

        # Derive k from the inputs rather than hard-coding it; int() truncates,
        # so 257 * 0.5 = 128.5 -> 128.
        expected_k = int(self.SEQ_LEN * top_k_ratio)
        assert pooled.shape[1] == expected_k, (
            f"Expected seq_len={expected_k}, got {pooled.shape[1]}"
        )

    def test_residual_connection(self):
        """Test the forward pass with the residual path enabled and disabled.

        Only output shape and finiteness are checked: the module's weights are
        randomly initialised, so the exact residual contribution cannot be
        pinned down by this test.
        """
        x = self._make_input()
        expected = (self.BATCH, self.COMPRESSION_DIM)

        for use_residual in (True, False):
            compressor = PoolNetCompressor(
                input_dim=self.INPUT_DIM,
                compression_dim=self.COMPRESSION_DIM,
                use_residual=use_residual,
            )
            out = compressor(x)

            assert out.shape == expected, (
                f"use_residual={use_residual}: expected {expected}, "
                f"got {tuple(out.shape)}"
            )
            assert torch.isfinite(out).all(), (
                f"use_residual={use_residual}: output contains NaN/Inf"
            )

    def test_gpu_device(self):
        """Test the output lands on the requested device.

        Uses CUDA when available and falls back to CPU otherwise, so this test
        also runs on CPU-only machines.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"

        compressor = PoolNetCompressor(
            input_dim=self.INPUT_DIM,
            compression_dim=self.COMPRESSION_DIM,
            device=device,
        )

        out = compressor(self._make_input(device))

        assert out.device.type == device