refactor(compressors): Simplify module by removing SAM/DINO separation code

- Remove dino_compressor.py and segament_compressor.py
- Rewrite pipeline.py to inline DINO into HashPipeline
- Maintain backward compatibility: SAMHashPipeline alias
- Update tests and benchmark.py
This commit is contained in:
2026-03-07 21:33:42 +08:00
parent c8dc5f9301
commit 4da08dc3d3
8 changed files with 276 additions and 490 deletions

1
.gitignore vendored
View File

@@ -215,6 +215,7 @@ outputs/
# Vibe Coding
.sisyphus
.claude/settings.local.json
openspec/changes/
# Devenv
.devenv*

View File

@@ -1,4 +1,4 @@
# 开发者必读文档
# Project Spec & Rules
## 代码规范
@@ -47,11 +47,9 @@
- mini-nav/configs/ — 配置管理 (Pydantic + YAML)
- mini-nav/commands/ — CLI 命令 (train, benchmark, visualize, generate)
- mini-nav/compressors/ — 特征压缩算法
- hash_compressor.py — 哈希压缩器
- dino_compressor.py — DINO 压缩器
- segament_compressor.py — 分割压缩器
- pipeline.py — 压缩流水线
- train.py — 压缩器训练
- hash_compressor.py — 哈希压缩器与训练loss
- pipeline.py — 压缩流水线(整合 DINO 特征提取)
- train.py — 压缩器训练脚本
- mini-nav/data_loading/ — 数据加载与合成
- loader.py — 数据加载器
- synthesizer.py — 场景合成器
@@ -68,3 +66,56 @@
### Python库
详细可查询pyproject.toml或使用`uv pip list`获取详细的库信息,请基于目前的库实现功能。
如需添加新库,请先询问,用户确认后才能使用`uv add <package>`新增库。
## 版本管理 (Jujutsu 特有)
本项目使用 Jujutsu (jj) 进行版本控制,并配套 Memorix MCP 作为架构决策与思维轨迹的持久化中心。
- 技能调用: 必须使用 jujutsu 相关工具技能来执行分支、提交、修改describe等操作禁止直接通过 Shell 执行冗长的 Git 兼容指令。
- 描述规范 (jj desc):
- 执行 jj desc 时,首行必须是精简的变更标题。
- 空一行后,仅记录改动的核心业务点。
- 语言使用英文进行描述
- 禁忌: 禁止在 jj 描述中堆砌复杂的算法逻辑或长篇的设计决策。
- 记忆联动 (Memorix 优先):
- 凡涉及架构变更、算法决策或重构逻辑,在执行 jj desc 之前,必须先调用 memorix_store (或对应的添加方法)。
- 关联标记: 在 Memorix 的存储记录中,必须强制包含当前变更的 jj change ID以便实现从代码变更到思维链的完美映射。
- 检索逻辑: 在处理需要深入理解上下文的任务时,应主动调用 memorix_search 检索相关的历史 change_id 决策。
- 无感记录原则:
- 严禁在工程目录下生成任何独立的 change_log.md 或 AI 自动化文档。
- 所有关于“为什么这样改”的知识,应当流向 jj 的原子化提交描述或 Memorix 的知识图谱库。
### 描述示例
```text
refactor(compressors): Simplify module by removing SAM/DINO separation code
- Remove dino_compressor.py and segament_compressor.py
- Rewrite pipeline.py to inline DINO into HashPipeline
- Maintain backward compatibility: SAMHashPipeline alias
- Update tests and benchmark.py
```
### 提交步骤
- 执行`jj diff --no-pager`获取当前所有更改
- 根据更改内容与openspec生成的相关文档进行总结重点在于更改内容及其决策逻辑
- 调用记忆功能如Memorix记忆先前总结的内容
- 遵循描述规范使用jj进行更改的描述
- 执行`jj new`开启一个新的更改
## 记忆管理 (Memorix MCP)
本项目使用 Memorix 作为核心上下文引擎,用于存储架构决策、复杂逻辑关联和历史重构原因。
### 记忆写入准则
- 主动记录: 在完成以下操作后,必须调用 `memorix.store`
- 用户确认后的核心架构变更例如LanceDB 的索引策略)。
- 复杂的 bug 修复逻辑(记录“为什么”这么修,防止回滚)。
- 用户在对话中表达的明确偏好(例如:对特定 Python 库的厌恶)。
- 代码的修改及其决策逻辑(例如:对于用户特定需求导致的更改)。
- 结构化存储: 存储时请使用 `[Category: Topic] Description` 的格式,确保检索效率。
### 记忆检索准则
- 冷启动检索: 每一轮新对话开始或切换到新任务时,优先调用 `memorix.search` 关键词(如 "project_architecture", "database_schema"),以确保不偏离既有设计。
- 防止幻觉: 如果对某个旧功能的实现细节不确定,先检索记忆,禁止凭空猜测。
### 内存与冗余控制
- 精简描述: 存入 Memorix 的信息必须精简,严禁存入整段代码块,仅存储“逻辑描述”和“决策依据”。
- 清理逻辑: 发现记忆库中存在与当前代码事实冲突的旧信息时,应主动提示用户进行更新或覆盖。

View File

@@ -1,4 +1,4 @@
from typing import cast
from typing import Any, Optional, cast
import typer
from commands import app
@@ -7,15 +7,15 @@ from commands import app
@app.command()
def benchmark(
ctx: typer.Context,
model_path: str = typer.Option(
model_path: Optional[str] = typer.Option(
None, "--model", "-m", help="Path to compressor model weights"
),
):
import torch
import torch.nn.functional as F
from benchmarks import run_benchmark
from compressors import DinoCompressor
from configs import cfg_manager
from transformers import AutoImageProcessor, BitImageProcessorFast
from transformers import AutoImageProcessor, AutoModel, BitImageProcessorFast
from utils import get_device
config = cfg_manager.get()
@@ -29,7 +29,12 @@ def benchmark(
AutoImageProcessor.from_pretrained(model_cfg.dino_model, device_map=device),
)
model = DinoCompressor().to(device)
# Load DINO model for feature extraction
dino = AutoModel.from_pretrained(model_cfg.dino_model, device_map=device)
dino.eval()
# Optional hash compressor
compressor = None
if model_path:
from compressors import HashCompressor
@@ -38,7 +43,31 @@ def benchmark(
hash_bits=model_cfg.compression_dim,
)
compressor.load_state_dict(torch.load(model_path))
model.compressor = compressor
compressor.to(device)
compressor.eval()
# Create wrapper with extract_features method
class DinoFeatureExtractor:
def __init__(self, dino, compressor=None):
self.dino = dino
self.compressor = compressor
def extract_features(self, images: list) -> torch.Tensor:
inputs = processor(images, return_tensors="pt").to(device)
with torch.no_grad():
outputs = self.dino(**inputs)
features = outputs.last_hidden_state.mean(dim=1)
features = F.normalize(features, dim=-1)
return features
def encode(self, images: list) -> torch.Tensor:
if self.compressor is None:
return self.extract_features(images)
tokens = self.dino(**processor(images, return_tensors="pt").to(device)).last_hidden_state
_, _, bits = self.compressor(tokens)
return bits
model = DinoFeatureExtractor(dino, compressor)
run_benchmark(
model=model,

View File

@@ -1,18 +1,15 @@
from .common import BinarySign, bits_to_hash, hamming_distance, hamming_similarity, hash_to_bits
from .dino_compressor import DinoCompressor
from .hash_compressor import HashCompressor, HashLoss, VideoPositiveMask
from .pipeline import SAMHashPipeline, create_pipeline_from_config
from .segament_compressor import SegmentCompressor
from .pipeline import HashPipeline, SAMHashPipeline, create_pipeline_from_config
from .train import train
__all__ = [
"train",
"DinoCompressor",
"HashCompressor",
"HashLoss",
"VideoPositiveMask",
"SegmentCompressor",
"SAMHashPipeline",
"HashPipeline",
"SAMHashPipeline", # Backward compatibility alias
"create_pipeline_from_config",
"BinarySign",
"hamming_distance",

View File

@@ -1,105 +0,0 @@
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import AutoImageProcessor, AutoModel
class DinoCompressor(nn.Module):
"""DINOv2 feature extractor with optional hash compression.
When compressor is None: returns normalized DINO embeddings.
When compressor is provided: returns binary hash bits for CAM storage.
Supports both PIL Image input and pre-extracted tokens.
"""
def __init__(
self,
model_name: str = "facebook/dinov2-large",
compressor: Optional[nn.Module] = None,
device: Optional[str] = None,
):
"""Initialize DINOv2 extractor.
Args:
model_name: HuggingFace model name
compressor: Optional hash compressor for producing binary codes
device: Device to load model on
"""
super().__init__()
# Auto detect device
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = torch.device(device)
self.model_name = model_name
self.processor = AutoImageProcessor.from_pretrained(model_name)
self.dino = AutoModel.from_pretrained(model_name).to(self.device)
self.dino.eval()
self.compressor = compressor
def forward(self, inputs):
teacher_tokens = self.dino(**inputs).last_hidden_state # [B,N,1024]
teacher_embed = teacher_tokens.mean(dim=1)
teacher_embed = F.normalize(teacher_embed, dim=-1) # [B,1024]
if self.compressor is None:
return teacher_embed
# HashCompressor returns (logits, hash_codes, bits)
_, _, bits = self.compressor(teacher_tokens)
return bits # [B, 512] binary bits for CAM
def extract_features(self, images: list[Image.Image]) -> torch.Tensor:
"""Extract DINO features from a list of cropped object images.
Args:
images: List of PIL Images (cropped objects)
Returns:
DINO features [N, feature_dim], normalized
"""
if len(images) == 0:
return torch.empty(0, self.dino.config.hidden_size, device=self.device)
# Process batch of images
inputs = self.processor(images, return_tensors="pt").to(self.device)
with torch.no_grad():
outputs = self.dino(**inputs)
# Pool tokens to get global representation
features = outputs.last_hidden_state.mean(dim=1) # [N, 1024]
features = F.normalize(features, dim=-1)
return features
def encode(self, images: list[Image.Image]) -> torch.Tensor:
"""Extract features from images and optionally compress to hash codes.
Args:
images: List of PIL Images
Returns:
If compressor is None: DINO features [N, 1024]
If compressor is set: Binary hash bits [N, 512]
"""
if self.compressor is None:
return self.extract_features(images)
# Extract features first
features = self.extract_features(images) # [N, 1024]
# Add sequence dimension for compressor (expects [B, N, dim])
features = features.unsqueeze(1) # [N, 1, 1024]
# Compress to hash codes
_, _, bits = self.compressor(features)
return bits

View File

@@ -1,78 +1,65 @@
"""Complete pipeline for SAM + DINO + HashCompressor.
"""Hash compression pipeline with DINO feature extraction.
This pipeline extracts object masks from images using SAM2.1,
crops the objects, extracts features using DINOv2,
and compresses them to binary hash codes using HashCompressor.
This pipeline extracts features using DINOv2 and compresses them
to binary hash codes using HashCompressor.
"""
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from .dino_compressor import DinoCompressor
from .hash_compressor import HashCompressor
from .segament_compressor import SegmentCompressor
from transformers import AutoImageProcessor, AutoModel
def create_pipeline_from_config(config) -> "SAMHashPipeline":
"""Create SAMHashPipeline from a config object.
def create_pipeline_from_config(config) -> "HashPipeline":
"""Create HashPipeline from a config object.
Args:
config: Configuration object with model settings
Returns:
Initialized SAMHashPipeline
Initialized HashPipeline
"""
return SAMHashPipeline(
sam_model=config.model.sam_model,
dino_model=config.model.name,
return HashPipeline(
dino_model=config.model.dino_model,
hash_bits=config.model.compression_dim,
sam_min_mask_area=config.model.sam_min_mask_area,
sam_max_masks=config.model.sam_max_masks,
compressor_path=config.model.compressor_path,
device=config.model.device if config.model.device != "auto" else None,
)
class SAMHashPipeline(nn.Module):
"""Complete pipeline: SAM segmentation + DINO features + Hash compression.
class HashPipeline(nn.Module):
"""Pipeline: DINO features + Hash compression.
Pipeline flow:
Image -> SAM (extract masks) -> Crop objects -> DINO (features) -> Hash (binary codes)
PIL Image -> DINO (features) -> Hash (binary codes)
Usage:
# Initialize with config
pipeline = SAMHashPipeline(
sam_model="facebook/sam2.1-hiera-large",
pipeline = HashPipeline(
dino_model="facebook/dinov2-large",
hash_bits=512,
)
# Process image
image = Image.open("path/to/image.jpg")
hash_codes = pipeline(image) # [N, 512] binary bits
hash_bits = pipeline(image) # [1, 512] binary bits
"""
def __init__(
self,
sam_model: str = "facebook/sam2.1-hiera-large",
dino_model: str = "facebook/dinov2-large",
hash_bits: int = 512,
sam_min_mask_area: int = 100,
sam_max_masks: int = 10,
compressor_path: Optional[str] = None,
device: Optional[str] = None,
):
"""Initialize the complete pipeline.
"""Initialize the pipeline.
Args:
sam_model: SAM model name from HuggingFace
dino_model: DINOv2 model name from HuggingFace
hash_bits: Number of bits in hash code
sam_min_mask_area: Minimum mask area threshold
sam_max_masks: Maximum number of masks to keep
compressor_path: Optional path to trained HashCompressor weights
device: Device to run models on
"""
@@ -83,87 +70,101 @@ class SAMHashPipeline(nn.Module):
device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = torch.device(device)
# Initialize components
self.segmentor = SegmentCompressor(
model_name=sam_model,
min_mask_area=sam_min_mask_area,
max_masks=sam_max_masks,
device=device,
)
self.dino_model = dino_model
# HashCompressor expects DINO features (1024 dim for dinov2-large)
dino_dim = 1024 if "large" in dino_model else 768
self.hash_compressor = HashCompressor(
input_dim=dino_dim, hash_bits=hash_bits
).to(device)
# Initialize DINO processor and model
self.processor = AutoImageProcessor.from_pretrained(dino_model)
self.dino = AutoModel.from_pretrained(dino_model).to(self.device)
self.dino.eval()
# Determine DINO feature dimension
self.dino_dim = 1024 if "large" in dino_model else 768
# Initialize HashCompressor
self.hash_compressor = nn.Module() # Placeholder, will be replaced
self._init_hash_compressor(hash_bits, compressor_path)
def _init_hash_compressor(
self, hash_bits: int, compressor_path: Optional[str] = None
):
"""Initialize the hash compressor module.
This is called during __init__ but we need to replace it properly.
"""
# Import here to avoid circular imports
from .hash_compressor import HashCompressor
compressor = HashCompressor(input_dim=self.dino_dim, hash_bits=hash_bits).to(
self.device
)
# Load pretrained compressor if provided
if compressor_path is not None:
self.hash_compressor.load_state_dict(
torch.load(compressor_path, map_location=device)
compressor.load_state_dict(
torch.load(compressor_path, map_location=self.device)
)
print(f"[OK] Loaded HashCompressor from {compressor_path}")
self.dino = DinoCompressor(
model_name=dino_model,
compressor=self.hash_compressor,
device=device,
)
# Replace the placeholder
self.hash_compressor = compressor
@property
def hash_bits(self):
"""Return the number of hash bits."""
return self.hash_compressor.hash_bits
def forward(self, image: Image.Image) -> torch.Tensor:
"""Process a single image through the complete pipeline.
"""Process a single image through the pipeline.
Args:
image: Input PIL Image
Returns:
Binary hash codes [N, hash_bits] where N is number of detected objects
Binary hash codes [1, hash_bits] as int32
"""
# Step 1: SAM - extract and crop objects
cropped_objects = self.segmentor(image)
# Extract DINO features
inputs = self.processor(image, return_tensors="pt").to(self.device)
if len(cropped_objects) == 0:
# No objects detected, return empty tensor
return torch.empty(
0, self.hash_compressor.hash_bits, dtype=torch.int32, device=self.device
)
with torch.no_grad():
outputs = self.dino(**inputs)
tokens = outputs.last_hidden_state # [1, N, dim]
# Step 2: DINO - extract features from cropped objects
# Step 3: HashCompressor - compress features to binary codes
hash_codes = self.dino.encode(cropped_objects)
# Compress to hash codes
_, _, bits = self.hash_compressor(tokens)
return hash_codes
return bits
def extract_features(
self, image: Image.Image, use_hash: bool = False
) -> torch.Tensor:
"""Extract features from image with optional hash compression.
def encode(self, image: Image.Image) -> torch.Tensor:
"""Encode an image to binary hash bits.
Args:
image: Input PIL Image
use_hash: If True, return binary hash codes; else return DINO features
Returns:
Features [N, dim] where dim is 1024 (DINO) or 512 (hash)
"""
cropped_objects = self.segmentor(image)
if len(cropped_objects) == 0:
dim = self.hash_compressor.hash_bits if use_hash else 1024
return torch.empty(0, dim, device=self.device)
if use_hash:
return self.dino.encode(cropped_objects)
else:
return self.dino.extract_features(cropped_objects)
def extract_masks(self, image: Image.Image) -> list[torch.Tensor]:
"""Extract only masks without full processing (for debugging).
Alias for forward().
Args:
image: Input PIL Image
Returns:
List of binary masks [H, W]
Binary hash codes [1, hash_bits] as int32
"""
return self.segmentor.extract_masks(image)
return self.forward(image)
def extract_features(self, image: Image.Image) -> torch.Tensor:
"""Extract DINO features from an image.
Args:
image: Input PIL Image
Returns:
DINO features [1, dino_dim], normalized
"""
inputs = self.processor(image, return_tensors="pt").to(self.device)
with torch.no_grad():
outputs = self.dino(**inputs)
features = outputs.last_hidden_state.mean(dim=1) # [1, dim]
features = F.normalize(features, dim=-1)
return features
# Backward compatibility alias
SAMHashPipeline = HashPipeline

View File

@@ -1,180 +0,0 @@
"""Segment Anything 2 feature extractor with mask filtering and image cropping.
Extracts object masks from images using SAM2.1, filters by area and confidence,
then crops the original image to obtain individual object regions.
"""
from typing import Optional
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from transformers import AutoModelForMaskGeneration, AutoProcessor
class SegmentCompressor(nn.Module):
"""SAM2.1 based segmenter with mask filtering.
Extracts object masks from images, filters by area and confidence,
and crops the original image to produce individual object patches.
"""
def __init__(
self,
model_name: str = "facebook/sam2.1-hiera-large",
min_mask_area: int = 100,
max_masks: int = 10,
device: Optional[str] = None,
):
"""Initialize SAM2.1 segmenter.
Args:
model_name: HuggingFace model name for SAM2.1
min_mask_area: Minimum mask pixel area threshold
max_masks: Maximum number of masks to keep
device: Device to load model on (auto-detect if None)
"""
super().__init__()
self.model_name = model_name
self.min_mask_area = min_mask_area
self.max_masks = max_masks
# Auto detect device
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = torch.device(device)
# Load SAM model and processor
self.processor = AutoProcessor.from_pretrained(model_name)
self.model = AutoModelForMaskGeneration.from_pretrained(model_name).to(
self.device
)
self.model.eval()
def forward(self, image: Image.Image) -> list[Image.Image]:
"""Extract object masks and crop object regions.
Args:
image: Input PIL Image
Returns:
List of cropped object images (one per valid mask)
"""
# Run SAM inference
inputs = self.processor(image, return_tensors="pt").to(self.device)
with torch.no_grad():
outputs = self.model(**inputs)
# Post-process masks
masks = self.processor.post_process_masks(
outputs.pred_masks,
inputs["original_sizes"],
inputs["reshaped_input_sizes"],
)[0]
# Filter masks by area and confidence
valid_masks = self._filter_masks(masks)
if len(valid_masks) == 0:
return []
# Crop object regions from original image
cropped_objects = self._crop_objects(image, valid_masks)
return cropped_objects
def _filter_masks(self, masks: torch.Tensor) -> list[dict]:
"""Filter masks by area and keep top-N.
Args:
masks: Predicted masks [N, H, W]
Returns:
List of mask dictionaries with 'mask' and 'area'
"""
valid_masks = []
for mask in masks:
# Calculate mask area
area = mask.sum().item()
# Filter by minimum area
if area < self.min_mask_area:
continue
valid_masks.append({"mask": mask, "area": area})
# Sort by area (descending) and keep top-N
valid_masks = sorted(valid_masks, key=lambda x: x["area"], reverse=True)
valid_masks = valid_masks[: self.max_masks]
return valid_masks
def _crop_objects(
self, image: Image.Image, masks: list[dict]
) -> list[Image.Image]:
"""Crop object regions from image using masks.
Args:
image: Original PIL Image
masks: List of mask dictionaries
Returns:
List of cropped object images
"""
# Convert PIL to numpy for processing
image_np = np.array(image)
h, w = image_np.shape[:2]
cropped_objects = []
for mask_info in masks:
mask = mask_info["mask"].cpu().numpy()
# Find bounding box from mask
rows = mask.any(axis=1)
cols = mask.any(axis=0)
if not rows.any() or not cols.any():
continue
y_min, y_max = rows.argmax(), h - rows[::-1].argmax() - 1
x_min, x_max = cols.argmax(), w - cols[::-1].argmax() - 1
# Add small padding
pad = 5
x_min = max(0, x_min - pad)
y_min = max(0, y_min - pad)
x_max = min(w, x_max + pad)
y_max = min(h, y_max + pad)
# Crop
cropped = image.crop((x_min, y_min, x_max, y_max))
cropped_objects.append(cropped)
return cropped_objects
@torch.no_grad()
def extract_masks(self, image: Image.Image) -> list[torch.Tensor]:
"""Extract only masks without cropping (for debugging).
Args:
image: Input PIL Image
Returns:
List of binary masks [H, W]
"""
inputs = self.processor(image, return_tensors="pt").to(self.device)
outputs = self.model(**inputs)
masks = self.processor.post_process_masks(
outputs.pred_masks,
inputs["original_sizes"],
inputs["reshaped_input_sizes"],
)[0]
valid_masks = self._filter_masks(masks)
return [m["mask"] for m in valid_masks]

View File

@@ -1,13 +1,13 @@
"""Tests for compressor modules (SAM, DINO, HashCompressor, Pipeline)."""
"""Tests for compressor modules (HashCompressor, Pipeline)."""
import pytest
import torch
from compressors import (
BinarySign,
DinoCompressor,
HashCompressor,
HashPipeline,
SAMHashPipeline,
SegmentCompressor,
VideoPositiveMask,
bits_to_hash,
create_pipeline_from_config,
hamming_distance,
@@ -124,87 +124,105 @@ class TestHammingMetrics:
assert sim.item() == 512 # Max similarity
class TestSegmentCompressor:
"""Test suite for SegmentCompressor."""
class TestHashLoss:
"""Test suite for HashLoss."""
@pytest.fixture
def mock_image(self):
"""Create a mock PIL image."""
img = Image.new("RGB", (224, 224), color="red")
return img
def test_hash_loss_init(self):
"""Verify HashLoss initializes with correct parameters."""
from compressors import HashLoss
def test_segment_compressor_init(self):
"""Verify SegmentCompressor initializes with correct parameters."""
segmentor = SegmentCompressor(
model_name="facebook/sam2.1-hiera-large",
min_mask_area=100,
max_masks=10,
loss_fn = HashLoss(
contrastive_weight=1.0,
distill_weight=0.5,
quant_weight=0.01,
temperature=0.2,
)
assert segmentor.model_name == "facebook/sam2.1-hiera-large"
assert segmentor.min_mask_area == 100
assert segmentor.max_masks == 10
assert loss_fn.contrastive_weight == 1.0
assert loss_fn.distill_weight == 0.5
assert loss_fn.quant_weight == 0.01
assert loss_fn.temperature == 0.2
def test_filter_masks(self):
"""Verify mask filtering logic."""
# Create segmentor to get default filter params
segmentor = SegmentCompressor()
def test_hash_loss_forward(self):
"""Verify HashLoss computes loss correctly."""
from compressors import HashLoss
# Create mock masks tensor with different areas
# Masks shape: [N, H, W]
masks = []
for area in [50, 200, 150, 300, 10]:
mask = torch.zeros(100, 100)
mask[:1, :area] = 1 # Create mask with specific area
masks.append(mask)
loss_fn = HashLoss()
masks_tensor = torch.stack(masks) # [5, 100, 100]
valid = segmentor._filter_masks(masks_tensor)
batch_size = 4
hash_bits = 512
logits = torch.randn(batch_size, hash_bits)
hash_codes = torch.sign(logits)
teacher_embed = torch.randn(batch_size, 1024)
positive_mask = torch.eye(batch_size, dtype=torch.bool)
# Should filter out 50 and 10 (below min_mask_area=100)
# Then keep top 3 (max_masks=10)
assert len(valid) == 3
# Verify sorted by area (descending)
areas = [v["area"] for v in valid]
assert areas == sorted(areas, reverse=True)
total_loss, components = loss_fn(
logits=logits,
hash_codes=hash_codes,
teacher_embed=teacher_embed,
positive_mask=positive_mask,
)
assert "contrastive" in components
assert "distill" in components
assert "quantization" in components
assert "total" in components
class TestDinoCompressor:
"""Test suite for DinoCompressor."""
class TestVideoPositiveMask:
"""Test suite for VideoPositiveMask."""
def test_dino_compressor_init(self):
"""Verify DinoCompressor initializes correctly."""
dino = DinoCompressor()
def test_from_frame_indices(self):
"""Verify positive mask generation from frame indices."""
mask_gen = VideoPositiveMask(temporal_window=2)
assert dino.model_name == "facebook/dinov2-large"
frame_indices = torch.tensor([0, 1, 3, 5])
def test_dino_compressor_with_compressor(self):
"""Verify DinoCompressor with HashCompressor."""
hash_compressor = HashCompressor(input_dim=1024, hash_bits=512)
dino = DinoCompressor(compressor=hash_compressor)
mask = mask_gen.from_frame_indices(frame_indices)
assert dino.compressor is hash_compressor
assert mask.shape == (4, 4)
# Frame 0 and 1 should be positive (distance 1 <= 2)
assert mask[0, 1] == True
# Frame 0 and 3 should be negative (distance 3 > 2)
assert mask[0, 3] == False
def test_from_video_ids(self):
"""Verify positive mask generation from video IDs and frame indices."""
mask_gen = VideoPositiveMask(temporal_window=2)
video_ids = torch.tensor([0, 0, 1, 1])
frame_indices = torch.tensor([0, 1, 0, 1])
mask = mask_gen.from_video_ids(video_ids, frame_indices)
assert mask.shape == (4, 4)
# Same video and temporally close
assert mask[0, 1] == True # video 0, frames 0,1
# Different video
assert mask[0, 2] == False # video 0 vs 1
class TestSAMHashPipeline:
"""Test suite for SAMHashPipeline."""
class TestHashPipeline:
"""Test suite for HashPipeline."""
def test_pipeline_init(self):
"""Verify pipeline initializes all components."""
pipeline = SAMHashPipeline(
sam_model="facebook/sam2.1-hiera-large",
pipeline = HashPipeline(
dino_model="facebook/dinov2-large",
hash_bits=512,
)
assert isinstance(pipeline.segmentor, SegmentCompressor)
assert isinstance(pipeline.dino, DinoCompressor)
assert isinstance(pipeline.hash_compressor, HashCompressor)
assert pipeline.dino_model == "facebook/dinov2-large"
assert pipeline.dino_dim == 1024
def test_pipeline_hash_bits(self):
"""Verify pipeline uses correct hash bits."""
pipeline = SAMHashPipeline(hash_bits=256)
assert pipeline.hash_compressor.hash_bits == 256
pipeline = HashPipeline(hash_bits=256)
assert pipeline.hash_bits == 256
def test_pipeline_alias(self):
"""Verify SAMHashPipeline is alias for HashPipeline."""
assert SAMHashPipeline is HashPipeline
class TestConfigIntegration:
@@ -216,25 +234,21 @@ class TestConfigIntegration:
pipeline = create_pipeline_from_config(config)
assert isinstance(pipeline, SAMHashPipeline)
assert pipeline.hash_compressor.hash_bits == config.model.compression_dim
assert isinstance(pipeline, HashPipeline)
assert pipeline.hash_bits == config.model.compression_dim
def test_config_sam_settings(self):
"""Verify config contains SAM settings."""
def test_config_settings(self):
"""Verify config contains required settings."""
config = cfg_manager.load()
assert hasattr(config.model, "sam_model")
assert hasattr(config.model, "sam_min_mask_area")
assert hasattr(config.model, "sam_max_masks")
assert config.model.sam_model == "facebook/sam2.1-hiera-large"
assert config.model.sam_min_mask_area == 100
assert config.model.sam_max_masks == 10
assert hasattr(config.model, "dino_model")
assert hasattr(config.model, "compression_dim")
@pytest.mark.slow
class TestPipelineIntegration:
"""Integration tests for full pipeline (slow, requires model downloads)."""
@pytest.mark.slow
def test_pipeline_end_to_end(self):
"""Test full pipeline with actual models (slow test)."""
# Skip if no GPU
@@ -245,54 +259,32 @@ class TestPipelineIntegration:
image = Image.new("RGB", (640, 480), color=(128, 128, 128))
# Initialize pipeline (will download models on first run)
pipeline = SAMHashPipeline(
sam_model="facebook/sam2.1-hiera-large",
pipeline = HashPipeline(
dino_model="facebook/dinov2-large",
hash_bits=512,
sam_min_mask_area=100,
sam_max_masks=5,
)
# Run pipeline
hash_codes = pipeline(image)
hash_bits = pipeline(image)
# Verify output shape
assert hash_codes.dim() == 2
assert hash_codes.shape[1] == 512
assert torch.all((hash_codes == 0) | (hash_codes == 1))
assert hash_bits.dim() == 2
assert hash_bits.shape[1] == 512
assert torch.all((hash_bits == 0) | (hash_bits == 1))
@pytest.mark.slow
def test_extract_features_without_hash(self):
"""Test feature extraction without hash compression."""
def test_extract_features(self):
"""Test feature extraction."""
if not torch.cuda.is_available():
pytest.skip("Requires CUDA")
image = Image.new("RGB", (640, 480), color=(128, 128, 128))
pipeline = SAMHashPipeline(
sam_model="facebook/sam2.1-hiera-large",
pipeline = HashPipeline(
dino_model="facebook/dinov2-large",
)
features = pipeline.extract_features(image, use_hash=False)
features = pipeline.extract_features(image)
# Should return DINO features (1024 for large)
assert features.dim() == 2
assert features.shape[1] == 1024
@pytest.mark.slow
def test_extract_masks_only(self):
"""Test mask extraction only."""
if not torch.cuda.is_available():
pytest.skip("Requires CUDA")
image = Image.new("RGB", (640, 480), color=(128, 128, 128))
pipeline = SAMHashPipeline(
sam_model="facebook/sam2.1-hiera-large",
)
masks = pipeline.extract_masks(image)
# Should return a list of masks
assert isinstance(masks, list)