diff --git a/.gitignore b/.gitignore index 5b0fa00..6c08ed5 100644 --- a/.gitignore +++ b/.gitignore @@ -215,6 +215,7 @@ outputs/ # Vibe Coding .sisyphus .claude/settings.local.json +openspec/changes/ # Devenv .devenv* diff --git a/CLAUDE.md b/CLAUDE.md index 431e715..7e20c92 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -1,4 +1,4 @@ -# 开发者必读文档 +# Project Spec & Rules ## 代码规范 @@ -47,11 +47,9 @@ - mini-nav/configs/ — 配置管理 (Pydantic + YAML) - mini-nav/commands/ — CLI 命令 (train, benchmark, visualize, generate) - mini-nav/compressors/ — 特征压缩算法 - - hash_compressor.py — 哈希压缩器 - - dino_compressor.py — DINO 压缩器 - - segament_compressor.py — 分割压缩器 - - pipeline.py — 压缩流水线 - - train.py — 压缩器训练 + - hash_compressor.py — 哈希压缩器与训练loss + - pipeline.py — 压缩流水线(整合 DINO 特征提取) + - train.py — 压缩器训练脚本 - mini-nav/data_loading/ — 数据加载与合成 - loader.py — 数据加载器 - synthesizer.py — 场景合成器 @@ -68,3 +66,56 @@ ### Python库 详细可查询pyproject.toml或使用`uv pip list`获取详细的库信息,请基于目前的库实现功能。 如需添加新库,请先询问,用户确认后才能使用`uv add `新增库。 + +## 版本管理 (Jujutsu 特有) +本项目使用 Jujutsu (jj) 进行版本控制,并配套 Memorix MCP 作为架构决策与思维轨迹的持久化中心。 + +- 技能调用: 必须使用 jujutsu 相关工具技能来执行分支、提交、修改(describe)等操作,禁止直接通过 Shell 执行冗长的 Git 兼容指令。 +- 描述规范 (jj desc): + - 执行 jj desc 时,首行必须是精简的变更标题。 + - 空一行后,仅记录改动的核心业务点。 + - 语言使用英文进行描述 + - 禁忌: 禁止在 jj 描述中堆砌复杂的算法逻辑或长篇的设计决策。 +- 记忆联动 (Memorix 优先): + - 凡涉及架构变更、算法决策或重构逻辑,在执行 jj desc 之前,必须先调用 memorix_store (或对应的添加方法)。 + - 关联标记: 在 Memorix 的存储记录中,必须强制包含当前变更的 jj change ID,以便实现从代码变更到思维链的完美映射。 + - 检索逻辑: 在处理需要深入理解上下文的任务时,应主动调用 memorix_search 检索相关的历史 change_id 决策。 +- 无感记录原则: + - 严禁在工程目录下生成任何独立的 change_log.md 或 AI 自动化文档。 + - 所有关于“为什么这样改”的知识,应当流向 jj 的原子化提交描述或 Memorix 的知识图谱库。 + +### 描述示例 +```text +refactor(compressors): Simplify module by removing SAM/DINO separation code + +- Remove dino_compressor.py and segament_compressor.py +- Rewrite pipeline.py to inline DINO into HashPipeline +- Maintain backward compatibility: SAMHashPipeline alias +- Update tests and benchmark.py +``` + +### 提交步骤 +- 执行`jj diff --no-pager`获取当前所有更改 +- 
根据更改内容,结合 openspec 生成的相关文档进行总结,重点在于更改内容及其决策逻辑 +- 调用记忆功能(如 Memorix),记录上一步总结的内容 +- 遵循描述规范,使用jj进行更改的描述 +- 执行`jj new`开启一个新的更改 + +## 记忆管理 (Memorix MCP) +本项目使用 Memorix 作为核心上下文引擎,用于存储架构决策、复杂逻辑关联和历史重构原因。 + +### 记忆写入准则 +- 主动记录: 在完成以下操作后,必须调用 `memorix_store`: + - 用户确认后的核心架构变更(例如:LanceDB 的索引策略)。 + - 复杂的 bug 修复逻辑(记录“为什么”这么修,防止回滚)。 + - 用户在对话中表达的明确偏好(例如:对特定 Python 库的厌恶)。 + - 代码的修改及其决策逻辑(例如:对于用户特定需求导致的更改)。 +- 结构化存储: 存储时请使用 `[Category: Topic] Description` 的格式,确保检索效率。 + +### 记忆检索准则 +- 冷启动检索: 每一轮新对话开始或切换到新任务时,优先调用 `memorix_search` 检索关键词(如 "project_architecture", "database_schema"),以确保不偏离既有设计。 +- 防止幻觉: 如果对某个旧功能的实现细节不确定,先检索记忆,禁止凭空猜测。 + +### 记忆与冗余控制 +- 精简描述: 存入 Memorix 的信息必须精简,严禁存入整段代码块,仅存储“逻辑描述”和“决策依据”。 +- 清理逻辑: 发现记忆库中存在与当前代码事实冲突的旧信息时,应主动提示用户进行更新或覆盖。 diff --git a/mini-nav/commands/benchmark.py b/mini-nav/commands/benchmark.py index 755a0f7..850bf73 100644 --- a/mini-nav/commands/benchmark.py +++ b/mini-nav/commands/benchmark.py @@ -1,4 +1,4 @@ -from typing import cast +from typing import Any, Optional, cast import typer from commands import app @@ -7,15 +7,15 @@ from commands import app @app.command() def benchmark( ctx: typer.Context, - model_path: str = typer.Option( + model_path: Optional[str] = typer.Option( None, "--model", "-m", help="Path to compressor model weights" ), ): import torch + import torch.nn.functional as F from benchmarks import run_benchmark - from compressors import DinoCompressor from configs import cfg_manager - from transformers import AutoImageProcessor, BitImageProcessorFast + from transformers import AutoImageProcessor, AutoModel, BitImageProcessorFast from utils import get_device config = cfg_manager.get() @@ -29,7 +29,12 @@ def benchmark( AutoImageProcessor.from_pretrained(model_cfg.dino_model, device_map=device), ) - model = DinoCompressor().to(device) + # Load DINO model for feature extraction + dino = AutoModel.from_pretrained(model_cfg.dino_model, device_map=device) + dino.eval() + + # Optional hash compressor + compressor = None if model_path: from 
compressors import HashCompressor
 
@@ -38,7 +43,39 @@ def benchmark(
             hash_bits=model_cfg.compression_dim,
         )
-        compressor.load_state_dict(torch.load(model_path))
-        model.compressor = compressor
+        # map_location lets weights saved on another device load onto `device`.
+        compressor.load_state_dict(torch.load(model_path, map_location=device))
+        compressor.to(device)
+        compressor.eval()
+
+    # Create wrapper with extract_features method
+    class DinoFeatureExtractor:
+        """Wrap the DINO backbone and the optional hash compressor."""
+
+        def __init__(self, dino, compressor=None):
+            self.dino = dino
+            self.compressor = compressor
+
+        def _tokens(self, images: list) -> torch.Tensor:
+            """Run DINO on `images` without autograd; return last_hidden_state."""
+            inputs = processor(images, return_tensors="pt").to(device)
+            with torch.no_grad():
+                return self.dino(**inputs).last_hidden_state
+
+        def extract_features(self, images: list) -> torch.Tensor:
+            """Return L2-normalized DINO features, mean-pooled over tokens."""
+            features = self._tokens(images).mean(dim=1)
+            return F.normalize(features, dim=-1)
+
+        def encode(self, images: list) -> torch.Tensor:
+            """Return hash bits, or DINO features if no compressor is set."""
+            if self.compressor is None:
+                return self.extract_features(images)
+            # Inference only: skip autograd for the compressor forward as well.
+            with torch.no_grad():
+                _, _, bits = self.compressor(self._tokens(images))
+            return bits
+
+    model = DinoFeatureExtractor(dino, compressor)
 
     run_benchmark(
         model=model,
diff --git a/mini-nav/compressors/__init__.py b/mini-nav/compressors/__init__.py
index 91394bd..5ea1831 100644
--- a/mini-nav/compressors/__init__.py
+++ b/mini-nav/compressors/__init__.py
@@ -1,18 +1,15 @@
 from .common import BinarySign, bits_to_hash, hamming_distance, hamming_similarity, hash_to_bits
-from .dino_compressor import DinoCompressor
 from .hash_compressor import HashCompressor, HashLoss, VideoPositiveMask
-from .pipeline import SAMHashPipeline, create_pipeline_from_config
-from .segament_compressor import SegmentCompressor
+from .pipeline import HashPipeline, SAMHashPipeline, create_pipeline_from_config
 from .train import train
 
 __all__ = [
     "train",
-    "DinoCompressor",
     "HashCompressor",
     "HashLoss",
     "VideoPositiveMask",
-    "SegmentCompressor",
-    "SAMHashPipeline",
+    "HashPipeline",
+    "SAMHashPipeline",  # Backward compatibility alias
     "create_pipeline_from_config",
     "BinarySign",
     "hamming_distance",
diff --git 
a/mini-nav/compressors/dino_compressor.py b/mini-nav/compressors/dino_compressor.py deleted file mode 100644 index e991700..0000000 --- a/mini-nav/compressors/dino_compressor.py +++ /dev/null @@ -1,105 +0,0 @@ -from typing import Optional - -import torch -import torch.nn as nn -import torch.nn.functional as F -from PIL import Image -from transformers import AutoImageProcessor, AutoModel - - -class DinoCompressor(nn.Module): - """DINOv2 feature extractor with optional hash compression. - - When compressor is None: returns normalized DINO embeddings. - When compressor is provided: returns binary hash bits for CAM storage. - - Supports both PIL Image input and pre-extracted tokens. - """ - - def __init__( - self, - model_name: str = "facebook/dinov2-large", - compressor: Optional[nn.Module] = None, - device: Optional[str] = None, - ): - """Initialize DINOv2 extractor. - - Args: - model_name: HuggingFace model name - compressor: Optional hash compressor for producing binary codes - device: Device to load model on - """ - super().__init__() - - # Auto detect device - if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" - self.device = torch.device(device) - - self.model_name = model_name - self.processor = AutoImageProcessor.from_pretrained(model_name) - self.dino = AutoModel.from_pretrained(model_name).to(self.device) - self.dino.eval() - - self.compressor = compressor - - def forward(self, inputs): - teacher_tokens = self.dino(**inputs).last_hidden_state # [B,N,1024] - - teacher_embed = teacher_tokens.mean(dim=1) - teacher_embed = F.normalize(teacher_embed, dim=-1) # [B,1024] - - if self.compressor is None: - return teacher_embed - - # HashCompressor returns (logits, hash_codes, bits) - _, _, bits = self.compressor(teacher_tokens) - return bits # [B, 512] binary bits for CAM - - def extract_features(self, images: list[Image.Image]) -> torch.Tensor: - """Extract DINO features from a list of cropped object images. 
- - Args: - images: List of PIL Images (cropped objects) - - Returns: - DINO features [N, feature_dim], normalized - """ - if len(images) == 0: - return torch.empty(0, self.dino.config.hidden_size, device=self.device) - - # Process batch of images - inputs = self.processor(images, return_tensors="pt").to(self.device) - - with torch.no_grad(): - outputs = self.dino(**inputs) - - # Pool tokens to get global representation - features = outputs.last_hidden_state.mean(dim=1) # [N, 1024] - features = F.normalize(features, dim=-1) - - return features - - def encode(self, images: list[Image.Image]) -> torch.Tensor: - """Extract features from images and optionally compress to hash codes. - - Args: - images: List of PIL Images - - Returns: - If compressor is None: DINO features [N, 1024] - If compressor is set: Binary hash bits [N, 512] - """ - if self.compressor is None: - return self.extract_features(images) - - # Extract features first - features = self.extract_features(images) # [N, 1024] - - # Add sequence dimension for compressor (expects [B, N, dim]) - features = features.unsqueeze(1) # [N, 1, 1024] - - # Compress to hash codes - _, _, bits = self.compressor(features) - - return bits diff --git a/mini-nav/compressors/pipeline.py b/mini-nav/compressors/pipeline.py index 18e5415..0451a68 100644 --- a/mini-nav/compressors/pipeline.py +++ b/mini-nav/compressors/pipeline.py @@ -1,78 +1,65 @@ -"""Complete pipeline for SAM + DINO + HashCompressor. +"""Hash compression pipeline with DINO feature extraction. -This pipeline extracts object masks from images using SAM2.1, -crops the objects, extracts features using DINOv2, -and compresses them to binary hash codes using HashCompressor. +This pipeline extracts features using DINOv2 and compresses them +to binary hash codes using HashCompressor. 
""" from typing import Optional import torch import torch.nn as nn +import torch.nn.functional as F from PIL import Image - -from .dino_compressor import DinoCompressor -from .hash_compressor import HashCompressor -from .segament_compressor import SegmentCompressor +from transformers import AutoImageProcessor, AutoModel -def create_pipeline_from_config(config) -> "SAMHashPipeline": - """Create SAMHashPipeline from a config object. +def create_pipeline_from_config(config) -> "HashPipeline": + """Create HashPipeline from a config object. Args: config: Configuration object with model settings Returns: - Initialized SAMHashPipeline + Initialized HashPipeline """ - return SAMHashPipeline( - sam_model=config.model.sam_model, - dino_model=config.model.name, + return HashPipeline( + dino_model=config.model.dino_model, hash_bits=config.model.compression_dim, - sam_min_mask_area=config.model.sam_min_mask_area, - sam_max_masks=config.model.sam_max_masks, compressor_path=config.model.compressor_path, device=config.model.device if config.model.device != "auto" else None, ) -class SAMHashPipeline(nn.Module): - """Complete pipeline: SAM segmentation + DINO features + Hash compression. +class HashPipeline(nn.Module): + """Pipeline: DINO features + Hash compression. 
Pipeline flow: - Image -> SAM (extract masks) -> Crop objects -> DINO (features) -> Hash (binary codes) + PIL Image -> DINO (features) -> Hash (binary codes) Usage: # Initialize with config - pipeline = SAMHashPipeline( - sam_model="facebook/sam2.1-hiera-large", + pipeline = HashPipeline( dino_model="facebook/dinov2-large", hash_bits=512, ) # Process image image = Image.open("path/to/image.jpg") - hash_codes = pipeline(image) # [N, 512] binary bits + hash_bits = pipeline(image) # [1, 512] binary bits """ def __init__( self, - sam_model: str = "facebook/sam2.1-hiera-large", dino_model: str = "facebook/dinov2-large", hash_bits: int = 512, - sam_min_mask_area: int = 100, - sam_max_masks: int = 10, compressor_path: Optional[str] = None, device: Optional[str] = None, ): - """Initialize the complete pipeline. + """Initialize the pipeline. Args: - sam_model: SAM model name from HuggingFace dino_model: DINOv2 model name from HuggingFace hash_bits: Number of bits in hash code - sam_min_mask_area: Minimum mask area threshold - sam_max_masks: Maximum number of masks to keep compressor_path: Optional path to trained HashCompressor weights device: Device to run models on """ @@ -83,87 +70,101 @@ class SAMHashPipeline(nn.Module): device = "cuda" if torch.cuda.is_available() else "cpu" self.device = torch.device(device) - # Initialize components - self.segmentor = SegmentCompressor( - model_name=sam_model, - min_mask_area=sam_min_mask_area, - max_masks=sam_max_masks, - device=device, - ) + self.dino_model = dino_model - # HashCompressor expects DINO features (1024 dim for dinov2-large) - dino_dim = 1024 if "large" in dino_model else 768 - self.hash_compressor = HashCompressor( - input_dim=dino_dim, hash_bits=hash_bits - ).to(device) + # Initialize DINO processor and model + self.processor = AutoImageProcessor.from_pretrained(dino_model) + self.dino = AutoModel.from_pretrained(dino_model).to(self.device) + self.dino.eval() + + # Determine DINO feature dimension + self.dino_dim = 
self.dino.config.hidden_size  # real width for any DINOv2 size
+
+        # Initialize HashCompressor
+        self.hash_compressor = nn.Module()  # Placeholder, will be replaced
+        self._init_hash_compressor(hash_bits, compressor_path)
+
+    def _init_hash_compressor(
+        self, hash_bits: int, compressor_path: Optional[str] = None
+    ):
+        """Build the HashCompressor and optionally load trained weights.
+
+        Called once from __init__; replaces the placeholder module.
+        """
+        # Import here to avoid circular imports
+        from .hash_compressor import HashCompressor
+
+        compressor = HashCompressor(input_dim=self.dino_dim, hash_bits=hash_bits).to(
+            self.device
+        )
 
         # Load pretrained compressor if provided
         if compressor_path is not None:
-            self.hash_compressor.load_state_dict(
-                torch.load(compressor_path, map_location=device)
+            compressor.load_state_dict(
+                torch.load(compressor_path, map_location=self.device)
             )
             print(f"[OK] Loaded HashCompressor from {compressor_path}")
 
-        self.dino = DinoCompressor(
-            model_name=dino_model,
-            compressor=self.hash_compressor,
-            device=device,
-        )
+        # Replace the placeholder
+        self.hash_compressor = compressor
+
+    @property
+    def hash_bits(self) -> int:
+        """Return the number of hash bits of the wrapped compressor."""
+        return self.hash_compressor.hash_bits
 
     def forward(self, image: Image.Image) -> torch.Tensor:
-        """Process a single image through the complete pipeline.
+        """Process a single image through the pipeline.
Args: image: Input PIL Image Returns: - Binary hash codes [N, hash_bits] where N is number of detected objects + Binary hash codes [1, hash_bits] as int32 """ - # Step 1: SAM - extract and crop objects - cropped_objects = self.segmentor(image) + # Extract DINO features + inputs = self.processor(image, return_tensors="pt").to(self.device) - if len(cropped_objects) == 0: - # No objects detected, return empty tensor - return torch.empty( - 0, self.hash_compressor.hash_bits, dtype=torch.int32, device=self.device - ) + with torch.no_grad(): + outputs = self.dino(**inputs) + tokens = outputs.last_hidden_state # [1, N, dim] - # Step 2: DINO - extract features from cropped objects - # Step 3: HashCompressor - compress features to binary codes - hash_codes = self.dino.encode(cropped_objects) + # Compress to hash codes + _, _, bits = self.hash_compressor(tokens) - return hash_codes + return bits - def extract_features( - self, image: Image.Image, use_hash: bool = False - ) -> torch.Tensor: - """Extract features from image with optional hash compression. + def encode(self, image: Image.Image) -> torch.Tensor: + """Encode an image to binary hash bits. - Args: - image: Input PIL Image - use_hash: If True, return binary hash codes; else return DINO features - - Returns: - Features [N, dim] where dim is 1024 (DINO) or 512 (hash) - """ - cropped_objects = self.segmentor(image) - - if len(cropped_objects) == 0: - dim = self.hash_compressor.hash_bits if use_hash else 1024 - return torch.empty(0, dim, device=self.device) - - if use_hash: - return self.dino.encode(cropped_objects) - else: - return self.dino.extract_features(cropped_objects) - - def extract_masks(self, image: Image.Image) -> list[torch.Tensor]: - """Extract only masks without full processing (for debugging). + Alias for forward(). 
Args: image: Input PIL Image Returns: - List of binary masks [H, W] + Binary hash codes [1, hash_bits] as int32 """ - return self.segmentor.extract_masks(image) + return self.forward(image) + + def extract_features(self, image: Image.Image) -> torch.Tensor: + """Extract DINO features from an image. + + Args: + image: Input PIL Image + + Returns: + DINO features [1, dino_dim], normalized + """ + inputs = self.processor(image, return_tensors="pt").to(self.device) + + with torch.no_grad(): + outputs = self.dino(**inputs) + features = outputs.last_hidden_state.mean(dim=1) # [1, dim] + features = F.normalize(features, dim=-1) + + return features + + +# Backward compatibility alias +SAMHashPipeline = HashPipeline diff --git a/mini-nav/compressors/segament_compressor.py b/mini-nav/compressors/segament_compressor.py deleted file mode 100644 index 9b32ad0..0000000 --- a/mini-nav/compressors/segament_compressor.py +++ /dev/null @@ -1,180 +0,0 @@ -"""Segment Anything 2 feature extractor with mask filtering and image cropping. - -Extracts object masks from images using SAM2.1, filters by area and confidence, -then crops the original image to obtain individual object regions. -""" - -from typing import Optional - -import numpy as np -import torch -import torch.nn as nn -from PIL import Image -from transformers import AutoModelForMaskGeneration, AutoProcessor - - -class SegmentCompressor(nn.Module): - """SAM2.1 based segmenter with mask filtering. - - Extracts object masks from images, filters by area and confidence, - and crops the original image to produce individual object patches. - """ - - def __init__( - self, - model_name: str = "facebook/sam2.1-hiera-large", - min_mask_area: int = 100, - max_masks: int = 10, - device: Optional[str] = None, - ): - """Initialize SAM2.1 segmenter. 
- - Args: - model_name: HuggingFace model name for SAM2.1 - min_mask_area: Minimum mask pixel area threshold - max_masks: Maximum number of masks to keep - device: Device to load model on (auto-detect if None) - """ - super().__init__() - - self.model_name = model_name - self.min_mask_area = min_mask_area - self.max_masks = max_masks - - # Auto detect device - if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" - self.device = torch.device(device) - - # Load SAM model and processor - self.processor = AutoProcessor.from_pretrained(model_name) - self.model = AutoModelForMaskGeneration.from_pretrained(model_name).to( - self.device - ) - self.model.eval() - - def forward(self, image: Image.Image) -> list[Image.Image]: - """Extract object masks and crop object regions. - - Args: - image: Input PIL Image - - Returns: - List of cropped object images (one per valid mask) - """ - # Run SAM inference - inputs = self.processor(image, return_tensors="pt").to(self.device) - - with torch.no_grad(): - outputs = self.model(**inputs) - - # Post-process masks - masks = self.processor.post_process_masks( - outputs.pred_masks, - inputs["original_sizes"], - inputs["reshaped_input_sizes"], - )[0] - - # Filter masks by area and confidence - valid_masks = self._filter_masks(masks) - - if len(valid_masks) == 0: - return [] - - # Crop object regions from original image - cropped_objects = self._crop_objects(image, valid_masks) - - return cropped_objects - - def _filter_masks(self, masks: torch.Tensor) -> list[dict]: - """Filter masks by area and keep top-N. 
- - Args: - masks: Predicted masks [N, H, W] - - Returns: - List of mask dictionaries with 'mask' and 'area' - """ - valid_masks = [] - - for mask in masks: - # Calculate mask area - area = mask.sum().item() - - # Filter by minimum area - if area < self.min_mask_area: - continue - - valid_masks.append({"mask": mask, "area": area}) - - # Sort by area (descending) and keep top-N - valid_masks = sorted(valid_masks, key=lambda x: x["area"], reverse=True) - valid_masks = valid_masks[: self.max_masks] - - return valid_masks - - def _crop_objects( - self, image: Image.Image, masks: list[dict] - ) -> list[Image.Image]: - """Crop object regions from image using masks. - - Args: - image: Original PIL Image - masks: List of mask dictionaries - - Returns: - List of cropped object images - """ - # Convert PIL to numpy for processing - image_np = np.array(image) - h, w = image_np.shape[:2] - - cropped_objects = [] - - for mask_info in masks: - mask = mask_info["mask"].cpu().numpy() - - # Find bounding box from mask - rows = mask.any(axis=1) - cols = mask.any(axis=0) - - if not rows.any() or not cols.any(): - continue - - y_min, y_max = rows.argmax(), h - rows[::-1].argmax() - 1 - x_min, x_max = cols.argmax(), w - cols[::-1].argmax() - 1 - - # Add small padding - pad = 5 - x_min = max(0, x_min - pad) - y_min = max(0, y_min - pad) - x_max = min(w, x_max + pad) - y_max = min(h, y_max + pad) - - # Crop - cropped = image.crop((x_min, y_min, x_max, y_max)) - cropped_objects.append(cropped) - - return cropped_objects - - @torch.no_grad() - def extract_masks(self, image: Image.Image) -> list[torch.Tensor]: - """Extract only masks without cropping (for debugging). 
- - Args: - image: Input PIL Image - - Returns: - List of binary masks [H, W] - """ - inputs = self.processor(image, return_tensors="pt").to(self.device) - outputs = self.model(**inputs) - - masks = self.processor.post_process_masks( - outputs.pred_masks, - inputs["original_sizes"], - inputs["reshaped_input_sizes"], - )[0] - - valid_masks = self._filter_masks(masks) - return [m["mask"] for m in valid_masks] diff --git a/mini-nav/tests/test_compressors.py b/mini-nav/tests/test_compressors.py index a6e29d1..99677f1 100644 --- a/mini-nav/tests/test_compressors.py +++ b/mini-nav/tests/test_compressors.py @@ -1,13 +1,13 @@ -"""Tests for compressor modules (SAM, DINO, HashCompressor, Pipeline).""" +"""Tests for compressor modules (HashCompressor, Pipeline).""" import pytest import torch from compressors import ( BinarySign, - DinoCompressor, HashCompressor, + HashPipeline, SAMHashPipeline, - SegmentCompressor, + VideoPositiveMask, bits_to_hash, create_pipeline_from_config, hamming_distance, @@ -124,87 +124,105 @@ class TestHammingMetrics: assert sim.item() == 512 # Max similarity -class TestSegmentCompressor: - """Test suite for SegmentCompressor.""" +class TestHashLoss: + """Test suite for HashLoss.""" - @pytest.fixture - def mock_image(self): - """Create a mock PIL image.""" - img = Image.new("RGB", (224, 224), color="red") - return img + def test_hash_loss_init(self): + """Verify HashLoss initializes with correct parameters.""" + from compressors import HashLoss - def test_segment_compressor_init(self): - """Verify SegmentCompressor initializes with correct parameters.""" - segmentor = SegmentCompressor( - model_name="facebook/sam2.1-hiera-large", - min_mask_area=100, - max_masks=10, + loss_fn = HashLoss( + contrastive_weight=1.0, + distill_weight=0.5, + quant_weight=0.01, + temperature=0.2, ) - assert segmentor.model_name == "facebook/sam2.1-hiera-large" - assert segmentor.min_mask_area == 100 - assert segmentor.max_masks == 10 + assert loss_fn.contrastive_weight == 
1.0 + assert loss_fn.distill_weight == 0.5 + assert loss_fn.quant_weight == 0.01 + assert loss_fn.temperature == 0.2 - def test_filter_masks(self): - """Verify mask filtering logic.""" - # Create segmentor to get default filter params - segmentor = SegmentCompressor() + def test_hash_loss_forward(self): + """Verify HashLoss computes loss correctly.""" + from compressors import HashLoss - # Create mock masks tensor with different areas - # Masks shape: [N, H, W] - masks = [] - for area in [50, 200, 150, 300, 10]: - mask = torch.zeros(100, 100) - mask[:1, :area] = 1 # Create mask with specific area - masks.append(mask) + loss_fn = HashLoss() - masks_tensor = torch.stack(masks) # [5, 100, 100] - valid = segmentor._filter_masks(masks_tensor) + batch_size = 4 + hash_bits = 512 + logits = torch.randn(batch_size, hash_bits) + hash_codes = torch.sign(logits) + teacher_embed = torch.randn(batch_size, 1024) + positive_mask = torch.eye(batch_size, dtype=torch.bool) - # Should filter out 50 and 10 (below min_mask_area=100) - # Then keep top 3 (max_masks=10) - assert len(valid) == 3 - # Verify sorted by area (descending) - areas = [v["area"] for v in valid] - assert areas == sorted(areas, reverse=True) + total_loss, components = loss_fn( + logits=logits, + hash_codes=hash_codes, + teacher_embed=teacher_embed, + positive_mask=positive_mask, + ) + + assert "contrastive" in components + assert "distill" in components + assert "quantization" in components + assert "total" in components -class TestDinoCompressor: - """Test suite for DinoCompressor.""" +class TestVideoPositiveMask: + """Test suite for VideoPositiveMask.""" - def test_dino_compressor_init(self): - """Verify DinoCompressor initializes correctly.""" - dino = DinoCompressor() + def test_from_frame_indices(self): + """Verify positive mask generation from frame indices.""" + mask_gen = VideoPositiveMask(temporal_window=2) - assert dino.model_name == "facebook/dinov2-large" + frame_indices = torch.tensor([0, 1, 3, 5]) - def 
test_dino_compressor_with_compressor(self):
-        """Verify DinoCompressor with HashCompressor."""
-        hash_compressor = HashCompressor(input_dim=1024, hash_bits=512)
-        dino = DinoCompressor(compressor=hash_compressor)
+        mask = mask_gen.from_frame_indices(frame_indices)
 
-        assert dino.compressor is hash_compressor
+        assert mask.shape == (4, 4)
+        # Samples 0 and 1 (frames 0 and 1) should be positive (distance 1 <= 2)
+        assert mask[0, 1]
+        # Samples 0 and 3 (frames 0 and 5) should be negative (distance 5 > 2)
+        assert not mask[0, 3]
+
+    def test_from_video_ids(self):
+        """Verify positive mask generation from video IDs and frame indices."""
+        mask_gen = VideoPositiveMask(temporal_window=2)
+
+        video_ids = torch.tensor([0, 0, 1, 1])
+        frame_indices = torch.tensor([0, 1, 0, 1])
+
+        mask = mask_gen.from_video_ids(video_ids, frame_indices)
+
+        assert mask.shape == (4, 4)
+        # Same video and temporally close
+        assert mask[0, 1]  # video 0, frames 0,1
+        # Different video
+        assert not mask[0, 2]  # video 0 vs 1
 
 
-class TestSAMHashPipeline:
-    """Test suite for SAMHashPipeline."""
+class TestHashPipeline:
+    """Test suite for HashPipeline."""
 
     def test_pipeline_init(self):
         """Verify pipeline initializes all components."""
-        pipeline = SAMHashPipeline(
-            sam_model="facebook/sam2.1-hiera-large",
+        pipeline = HashPipeline(
             dino_model="facebook/dinov2-large",
             hash_bits=512,
         )
 
-        assert isinstance(pipeline.segmentor, SegmentCompressor)
-        assert isinstance(pipeline.dino, DinoCompressor)
-        assert isinstance(pipeline.hash_compressor, HashCompressor)
+        assert pipeline.dino_model == "facebook/dinov2-large"
+        assert pipeline.dino_dim == 1024
 
     def test_pipeline_hash_bits(self):
         """Verify pipeline uses correct hash bits."""
-        pipeline = SAMHashPipeline(hash_bits=256)
-        assert pipeline.hash_compressor.hash_bits == 256
+        pipeline = HashPipeline(hash_bits=256)
+        assert pipeline.hash_bits == 256
+
+    def test_pipeline_alias(self):
+        """Verify SAMHashPipeline is alias for HashPipeline."""
+        assert SAMHashPipeline is 
HashPipeline class TestConfigIntegration: @@ -216,25 +234,21 @@ class TestConfigIntegration: pipeline = create_pipeline_from_config(config) - assert isinstance(pipeline, SAMHashPipeline) - assert pipeline.hash_compressor.hash_bits == config.model.compression_dim + assert isinstance(pipeline, HashPipeline) + assert pipeline.hash_bits == config.model.compression_dim - def test_config_sam_settings(self): - """Verify config contains SAM settings.""" + def test_config_settings(self): + """Verify config contains required settings.""" config = cfg_manager.load() - assert hasattr(config.model, "sam_model") - assert hasattr(config.model, "sam_min_mask_area") - assert hasattr(config.model, "sam_max_masks") - assert config.model.sam_model == "facebook/sam2.1-hiera-large" - assert config.model.sam_min_mask_area == 100 - assert config.model.sam_max_masks == 10 + assert hasattr(config.model, "dino_model") + assert hasattr(config.model, "compression_dim") +@pytest.mark.slow class TestPipelineIntegration: """Integration tests for full pipeline (slow, requires model downloads).""" - @pytest.mark.slow def test_pipeline_end_to_end(self): """Test full pipeline with actual models (slow test).""" # Skip if no GPU @@ -245,54 +259,32 @@ class TestPipelineIntegration: image = Image.new("RGB", (640, 480), color=(128, 128, 128)) # Initialize pipeline (will download models on first run) - pipeline = SAMHashPipeline( - sam_model="facebook/sam2.1-hiera-large", + pipeline = HashPipeline( dino_model="facebook/dinov2-large", hash_bits=512, - sam_min_mask_area=100, - sam_max_masks=5, ) # Run pipeline - hash_codes = pipeline(image) + hash_bits = pipeline(image) # Verify output shape - assert hash_codes.dim() == 2 - assert hash_codes.shape[1] == 512 - assert torch.all((hash_codes == 0) | (hash_codes == 1)) + assert hash_bits.dim() == 2 + assert hash_bits.shape[1] == 512 + assert torch.all((hash_bits == 0) | (hash_bits == 1)) - @pytest.mark.slow - def test_extract_features_without_hash(self): - 
"""Test feature extraction without hash compression.""" + def test_extract_features(self): + """Test feature extraction.""" if not torch.cuda.is_available(): pytest.skip("Requires CUDA") image = Image.new("RGB", (640, 480), color=(128, 128, 128)) - pipeline = SAMHashPipeline( - sam_model="facebook/sam2.1-hiera-large", + pipeline = HashPipeline( dino_model="facebook/dinov2-large", ) - features = pipeline.extract_features(image, use_hash=False) + features = pipeline.extract_features(image) # Should return DINO features (1024 for large) assert features.dim() == 2 assert features.shape[1] == 1024 - - @pytest.mark.slow - def test_extract_masks_only(self): - """Test mask extraction only.""" - if not torch.cuda.is_available(): - pytest.skip("Requires CUDA") - - image = Image.new("RGB", (640, 480), color=(128, 128, 128)) - - pipeline = SAMHashPipeline( - sam_model="facebook/sam2.1-hiera-large", - ) - - masks = pipeline.extract_masks(image) - - # Should return a list of masks - assert isinstance(masks, list)