Compare commits

...

7 Commits

SHA1 Message Date
431f6844ef chore: remove experimental OPSX workflow commands and skills
- remove OPSX Claude commands for apply, archive, explore, and propose
- remove matching OpenSpec workflow skills under `.claude/skills`
- clean up deprecated experimental workflow integration from `.claude`
2026-03-17 20:59:56 +08:00
34235c605d chore(deps): update Python version and project documentation 2026-03-12 12:53:13 +08:00
b39ee74e99 feat(benchmark): add multi-object retrieval benchmark with SAM segmentation 2026-03-12 12:52:51 +08:00
2466ab28cd feat(serena): Add Serena project configuration 2026-03-08 15:19:26 +08:00
4da08dc3d3 refactor(compressors): Simplify module by removing SAM/DINO separation code
- Remove dino_compressor.py and segament_compressor.py
- Rewrite pipeline.py to inline DINO into HashPipeline
- Maintain backward compatibility: SAMHashPipeline alias
- Update tests and benchmark.py
2026-03-07 22:55:13 +08:00
c8dc5f9301 docs: update project documentation and configuration 2026-03-07 15:45:28 +08:00
bf02a05ffc feat(opsx): add OpenSpec workflow commands and skills 2026-03-07 15:02:08 +08:00
20 changed files with 3685 additions and 3053 deletions

.gitignore vendored

@@ -211,9 +211,11 @@ datasets/
 data/
 deps/
 outputs/
+# Vibe Coding
 .sisyphus
-.claude/
-CLAUDE.md
+.claude/settings.local.json
+openspec/changes/
 # Devenv
 .devenv*


@@ -1,13 +0,0 @@
activate:
micromamba activate ./.venv
update-venv:
micromamba env export --no-builds | grep -v "prefix" > venv.yaml
download-test:
python -m habitat_sim.utils.datasets_download --uids habitat_test_scenes --data-path data/
python -m habitat_sim.utils.datasets_download --uids habitat_test_pointnav_dataset --data-path data/
python -m habitat_sim.utils.datasets_download --uids replica_cad_dataset --data-path data/
python -m habitat_sim.utils.datasets_download --uids rearrange_dataset_v2 --data-path data/
python -m habitat_sim.utils.datasets_download --uids hab_fetch --data-path data/
python -m habitat_sim.utils.datasets_download --uids ycb --data-path data/


@@ -1 +1 @@
-3.10
+3.13

AGENTS.md Normal file

@@ -0,0 +1,104 @@
# Memorix — Automatic Memory Rules
You have access to Memorix memory tools. Follow these rules to maintain persistent context across sessions.
## RULE 1: Session Start — Load Context
At the **beginning of every conversation**, BEFORE responding to the user:
1. Call `memorix_session_start` to get the previous session summary and key memories (this is a direct read, not a search — no fragmentation risk)
2. Then call `memorix_search` with a query related to the user's first message for additional context
3. If search results are found, use `memorix_detail` to fetch the most relevant ones
4. Reference relevant memories naturally — the user should feel you "remember" them
## RULE 2: Store Important Context
**Proactively** call `memorix_store` when any of the following happen:
### What MUST be recorded:
- Architecture/design decisions → type: `decision`
- Bug identified and fixed → type: `problem-solution`
- Unexpected behavior or gotcha → type: `gotcha`
- Config changed (env vars, ports, deps) → type: `what-changed`
- Feature completed or milestone → type: `what-changed`
- Trade-off discussed with conclusion → type: `trade-off`
### What should NOT be recorded:
- Simple file reads, greetings, trivial commands (ls, pwd, git status)
### Use topicKey for evolving topics:
For decisions, architecture docs, or any topic that evolves over time, ALWAYS use `topicKey` parameter.
This ensures the memory is UPDATED instead of creating duplicates.
Use `memorix_suggest_topic_key` to generate a stable key.
Example: `topicKey: "architecture/auth-model"` — subsequent stores with the same key update the existing memory.
### Track progress with the progress parameter:
When working on features or tasks, include the `progress` parameter:
```json
{
"progress": {
"feature": "user authentication",
"status": "in-progress",
"completion": 60
}
}
```
Status values: `in-progress`, `completed`, `blocked`
## RULE 3: Resolve Completed Memories
When a task is completed, a bug is fixed, or information becomes outdated:
1. Call `memorix_resolve` with the observation IDs to mark them as resolved
2. Resolved memories are hidden from default search, preventing context pollution
This is critical — without resolving, old bug reports and completed tasks will keep appearing in future searches.
## RULE 4: Session End — Store Decision Chain Summary
When the conversation is ending, create a **decision chain summary** (not just a checklist):
1. Call `memorix_store` with type `session-request` and `topicKey: "session/latest-summary"`:
**Required structure:**
```
## Goal
[What we were working on — specific, not vague]
## Key Decisions & Reasoning
- Chose X because Y. Rejected Z because [reason].
- [Every architectural/design decision with WHY]
## What Changed
- [File path] — [what changed and why]
## Current State
- [What works now, what's pending]
- [Any blockers or risks]
## Next Steps
- [Concrete next actions, in priority order]
```
**Critical: Include the "Key Decisions & Reasoning" section.** Without it, the next AI session will lack the context to understand WHY things were done a certain way and may suggest conflicting approaches.
2. Call `memorix_resolve` on any memories for tasks completed in this session
## RULE 5: Compact Awareness
Memorix automatically compacts memories on store:
- **With LLM API configured:** Smart dedup — extracts facts, compares with existing, merges or skips duplicates
- **Without LLM (free mode):** Heuristic dedup — uses similarity scores to detect and merge duplicate memories
- **You don't need to manually deduplicate.** Just store naturally and compact handles the rest.
- If you notice excessive duplicate memories, call `memorix_deduplicate` for batch cleanup.
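The free-mode heuristic boils down to a similarity threshold. A rough sketch using stdlib `difflib` ratios (Memorix's actual scorer may differ):

```python
from difflib import SequenceMatcher


def is_duplicate(a: str, b: str, threshold: float = 0.9) -> bool:
    """Treat two memory titles as duplicates when their similarity ratio is high."""
    return SequenceMatcher(None, a.lower(), b.lower()).ratio() >= threshold


assert is_duplicate("Fixed LanceDB index bug", "fixed lancedb index bug")
assert not is_duplicate("Fixed LanceDB index bug", "Added Dash visualizer")
```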
## Guidelines
- **Use concise titles** (~5-10 words) and structured facts
- **Include file paths** in filesModified when relevant
- **Include related concepts** for better searchability
- **Always use topicKey** for recurring topics to prevent duplicates
- **Always resolve** completed tasks and fixed bugs
- **Always include reasoning** — "chose X because Y" is 10x more valuable than "did X"
- Search defaults to `status="active"` — use `status="all"` to include resolved memories

CLAUDE.md Normal file

@@ -0,0 +1,125 @@
# Project Spec & Rules
## Coding Standards
### Google-style code
See https://raw.githubusercontent.com/shendeguize/GooglePythonStyleGuideCN/refs/heads/master/README.md for details.
### Coding principles
- Keep implementations concise, clear, and minimal.
- Never nest conditionals or loops more than three levels deep; return early to reduce branching.
- Comment variables and every conditional or loop branch fully.
- No backward compatibility required; avoid piling on features.
- Write the test suite first, then the implementation.
- After writing the test suite, ask for the user's feedback; continue only after the user confirms.
- Do not write benchmark code unless the user asks for it.
- Comments in English, documentation in Chinese.
- After finishing the code, update the documentation (e.g. CLAUDE.md) without changing its overall structure.
### Testing principles
- Tests must be lean, clean, and fast.
- Core logic and algorithms must be tested.
- Separate tests that need to load transformer models from tests that do not.
- Cases that need no tests:
  - UI-related code
  - overly complex or slow logic
  - benchmark-related code
### Keyword glossary
- "确认" (confirm): the user approves the current implementation or test-suite plan; work may proceed.
- "继续" (continue): the user wants you to re-read the context and resume unfinished work.
### Documentation updates
Update the directory layout section of this document only when the project layout changes.
To modify any other section, ask first and edit only after approval.
## Project Notes
The project is managed with UV; pytest is used for testing, a justfile for shortcut commands, and Jujutsu for version control.
### Directory layout
**Core modules**
- mini-nav/main.py — CLI entry point (Typer)
- mini-nav/database.py — LanceDB singleton manager for vector storage and retrieval
- mini-nav/feature_retrieval.py — DINOv2 image feature extraction and retrieval
**Source directory (mini-nav/)**
- mini-nav/configs/ — configuration management (Pydantic + YAML)
- mini-nav/commands/ — CLI commands (train, benchmark, visualize, generate)
- mini-nav/compressors/ — feature compression algorithms
  - hash_compressor.py — hash compressor and training loss
  - pipeline.py — compression pipeline (DINO feature extraction inlined)
  - train.py — compressor training script
- mini-nav/data_loading/ — data loading and synthesis
  - loader.py — data loaders
  - insdet_scenes.py — InsDet scene dataset loader
  - synthesizer.py — scene synthesizer
- mini-nav/utils/ — utility functions
  - feature_extractor.py — feature extraction helpers
  - sam.py — SAM 2.1 segmentation helpers
- mini-nav/tests/ — pytest test suite
- mini-nav/benchmarks/ — benchmarks (recall@k)
  - tasks/
    - multi_object_retrieval.py — multi-object retrieval benchmark task
- mini-nav/visualizer/ — Dash + Plotly visualization app
**Data directories**
- datasets/ — dataset directory
- outputs/ — default output directory (databases, model weights, etc.)
### Python libraries
See pyproject.toml or run `uv pip list` for details; implement features with the libraries already present.
To add a new library, ask first; only after the user confirms may you run `uv add <package>`.
## Version Control (Jujutsu specifics)
This project uses Jujutsu (jj) for version control, paired with Memorix MCP as the persistent store for architecture decisions and reasoning traces.
- Tool usage: always use the jujutsu tool skills for branching, committing, and describe operations; never run long Git-compatible commands directly through the shell.
- Description conventions (jj desc):
  - The first line of a jj desc must be a concise change title.
  - After one blank line, record only the core business points of the change.
  - Write descriptions in English.
  - Forbidden: never pack complex algorithm logic or lengthy design decisions into jj descriptions.
- Memory linkage (Memorix first):
  - For any architecture change, algorithm decision, or refactoring logic, call memorix_store (or the corresponding add method) before running jj desc.
  - Linking: every Memorix record must include the current jj change ID, so code changes map cleanly onto the reasoning chain.
  - Retrieval: for tasks that need deep context, proactively call memorix_search to look up decisions tied to relevant historical change IDs.
- Zero-footprint principle:
  - Never generate standalone change_log.md files or AI-authored documents in the project tree.
  - All knowledge about "why this was changed" flows into atomic jj commit descriptions or the Memorix knowledge base.
### Description example
```text
refactor(compressors): Simplify module by removing SAM/DINO separation code
- Remove dino_compressor.py and segament_compressor.py
- Rewrite pipeline.py to inline DINO into HashPipeline
- Maintain backward compatibility: SAMHashPipeline alias
- Update tests and benchmark.py
```
### Commit workflow
- Run `jj diff --no-pager` to collect all current changes.
- Summarize them together with the related openspec documents, focusing on what changed and the decision logic behind it.
- Call the memory tool (e.g. Memorix) to store that summary.
- Describe the change with jj, following the description conventions.
- Run `jj new` to start a new change.
## Memory Management (Memorix MCP)
This project uses Memorix as its core context engine for storing architecture decisions, complex logic links, and the reasons behind past refactors.
### When to write memory
- Store proactively: call `memorix.store` after
  - core architecture changes confirmed by the user (e.g. the LanceDB indexing strategy);
  - complex bug-fix logic (record *why* it was fixed that way, to prevent regressions);
  - explicit user preferences stated in conversation (e.g. dislike of a specific Python library);
  - code changes and the decision logic behind them (e.g. changes driven by specific user requirements).
- Structured storage: use the `[Category: Topic] Description` format to keep retrieval efficient.
### When to retrieve memory
- Cold start: at the start of each new conversation or when switching tasks, first call `memorix.search` with keywords (e.g. "project_architecture", "database_schema") to avoid drifting from the existing design.
- Hallucination guard: if unsure about how an old feature is implemented, search memory first; never guess.
### Memory and redundancy control
- Keep entries lean: information stored in Memorix must be concise; never store whole code blocks, only the logic description and decision rationale.
- Cleanup: when stored memories conflict with the current code, proactively prompt the user to update or overwrite them.


@@ -0,0 +1,356 @@
"""Multi-object retrieval benchmark task.
This benchmark evaluates retrieval accuracy using multiple objects from a cropped
scene region. It uses SAM for object segmentation, DINO+Hash pipeline for feature
extraction, and LanceDB for vector storage with scene-level score aggregation.
"""
import random
from typing import Any
import lancedb
import numpy as np
import pyarrow as pa
from benchmarks.base import BaseBenchmarkTask
from benchmarks.tasks.registry import RegisterTask
from configs.models import BenchmarkTaskConfig
from rich.progress import track
from torch import nn
from torch.utils.data import DataLoader
from transformers import BitImageProcessorFast
from utils.feature_extractor import extract_single_image_feature
from utils.sam import load_sam_model, segment_image
from utils.common import get_device
def _build_object_schema(vector_dim: int) -> pa.Schema:
"""Build PyArrow schema for object-level vectors.
Args:
vector_dim: Feature vector dimension.
Returns:
PyArrow schema with id, image_id, object_id, category, and vector fields.
"""
return pa.schema(
[
pa.field("id", pa.int32()),
pa.field("image_id", pa.string()),
pa.field("object_id", pa.string()),
pa.field("category", pa.string()),
pa.field("vector", pa.list_(pa.float32(), vector_dim)),
]
)
def _compute_scene_score(
query_object_ids: list[str],
retrieved_results: dict[str, list[tuple[float, str]]],
gamma: float,
) -> dict[str, float]:
"""Compute scene-level scores using co-occurrence penalty.
Args:
query_object_ids: List of query object IDs.
retrieved_results: Dict mapping image_id to list of (distance, object_id) results.
gamma: Co-occurrence penalty exponent.
Returns:
Dict mapping image_id to computed scene score.
"""
scene_scores: dict[str, float] = {}
for image_id, results in retrieved_results.items():
# Build a set of retrieved object IDs for this scene
retrieved_ids = {obj_id for _, obj_id in results}
# Count how many query objects are found in this scene
matched_count = sum(1 for q_id in query_object_ids if q_id in retrieved_ids)
if matched_count == 0:
scene_scores[image_id] = 0.0
continue
# Sum of best similarities (using distance as similarity: smaller = better)
# We use 1/(1+distance) to convert distance to similarity
similarities = []
for dist, obj_id in results:
if obj_id in query_object_ids:
sim = 1.0 / (1.0 + dist)
similarities.append(sim)
sum_similarity = sum(similarities) if similarities else 0.0
# Hit rate: ratio of matched objects
hit_rate = matched_count / len(query_object_ids)
# Final score: sum_similarity * (hit_rate)^gamma
score = sum_similarity * (hit_rate ** gamma)
scene_scores[image_id] = score
return scene_scores
@RegisterTask("multi-object-retrieval")
class MultiObjectRetrievalTask(BaseBenchmarkTask):
"""Multi-object retrieval benchmark task."""
def __init__(self, **kwargs: Any):
"""Initialize multi-object retrieval task.
Args:
**kwargs: Configuration parameters from BenchmarkTaskConfig.
"""
# Use config from kwargs or load default config
if kwargs:
config_dict = kwargs
else:
config = BenchmarkTaskConfig(type="multi-object-retrieval")
config_dict = config.model_dump()
super().__init__(**config_dict)
self.config = BenchmarkTaskConfig(**config_dict)
# SAM settings from ModelConfig (passed via kwargs or use defaults)
self.sam_model = kwargs.get("sam_model", "facebook/sam2.1-hiera-large")
self.min_mask_area = kwargs.get("sam_min_mask_area", 32 * 32)
self.max_masks_per_image = kwargs.get("sam_max_masks", 5)
# Lazy-loaded resources
self._sam_model = None
self._mask_generator = None
@property
def sam_model(self) -> Any:
"""Lazy-load SAM model."""
if self._sam_model is None:
self._sam_model, self._mask_generator = load_sam_model(
model_name=self.sam_model,
device=str(get_device()),
)
return self._sam_model
@property
def mask_generator(self) -> Any:
"""Lazy-load mask generator."""
if self._mask_generator is None:
self._sam_model, self._mask_generator = load_sam_model(
model_name=self.sam_model,
device=str(get_device()),
)
return self._mask_generator
def build_database(
self,
model: nn.Module,
processor: BitImageProcessorFast,
train_dataset: Any,
table: lancedb.table.Table,
batch_size: int,
) -> None:
"""Build the evaluation database with object-level vectors.
Args:
model: Feature extraction model.
processor: Image preprocessor.
train_dataset: Training dataset.
table: LanceDB table to store features.
batch_size: Batch size for DataLoader.
"""
# Infer vector dimension from a sample
sample = train_dataset[0]
sample_image = sample["image"]
# Get vector dimension by running a forward pass
vector_dim = self._infer_vector_dim(processor, model, sample_image)
expected_schema = _build_object_schema(vector_dim)
# Check schema compatibility
if table.schema != expected_schema:
raise ValueError(
f"Table schema mismatch. Expected: {expected_schema}, "
f"Got: {table.schema}"
)
# Build database: segment each image, extract features per object
record_id = 0
records = []
for idx in track(range(len(train_dataset)), description="Building object database"):
item = train_dataset[idx]
image = item["image"]
image_id = item.get("image_id", f"image_{idx}")
# Segment image using SAM
masks = segment_image(
self.mask_generator,
image,
min_area=self.min_mask_area,
max_masks=self.max_masks_per_image,
)
if not masks:
continue
# Extract features for each mask
for mask_idx, mask_info in enumerate(masks):
# Extract masked region
masked_image = self._apply_mask(image, mask_info["segment"])
# Extract feature vector
vector = extract_single_image_feature(processor, model, masked_image)
# Create object ID
object_id = f"{image_id}_obj_{mask_idx}"
category = mask_info.get("category", "unknown")
records.append({
"id": record_id,
"image_id": image_id,
"object_id": object_id,
"category": category,
"vector": vector,
})
record_id += 1
# Add all records to table
if records:
table.add(records)
def evaluate(
self,
model: nn.Module,
processor: BitImageProcessorFast,
test_dataset: Any,
table: lancedb.table.Table,
batch_size: int,
) -> dict[str, Any]:
"""Evaluate the model on the test dataset.
Args:
model: Feature extraction model.
processor: Image preprocessor.
test_dataset: Test dataset.
table: LanceDB table to search against.
batch_size: Batch size for DataLoader.
Returns:
Dictionary containing evaluation results with keys:
- accuracy: Recall@K accuracy (0.0 ~ 1.0)
- correct: Number of correct predictions
- total: Total number of test samples
- top_k: The K value used
"""
top_k = self.config.top_k_per_object
correct = 0
total = 0
for idx in track(range(len(test_dataset)), description=f"Evaluating Recall@{top_k}"):
item = test_dataset[idx]
image = item["image"]
target_image_id = item.get("image_id", f"image_{idx}")
# Segment query image
masks = segment_image(
self.mask_generator,
image,
min_area=self.min_mask_area,
max_masks=self.max_masks_per_image,
)
if not masks:
continue
# Randomly sample query objects
num_query = min(self.config.num_query_objects, len(masks))
query_masks = random.sample(masks, num_query)
# Extract features and search for each query object
retrieved_results: dict[str, list[tuple[float, str]]] = {}
for mask_info in query_masks:
# Extract masked region
masked_image = self._apply_mask(image, mask_info["segment"])
# Extract feature vector
vector = extract_single_image_feature(processor, model, masked_image)
# Search in LanceDB
results = (
table.search(vector)
.select(["image_id", "object_id", "_distance"])
.limit(top_k)
.to_polars()
)
# Aggregate results by scene
for row in results.iter_rows():
image_id = row["image_id"]
object_id = row["object_id"]
distance = row["_distance"]
if image_id not in retrieved_results:
retrieved_results[image_id] = []
retrieved_results[image_id].append((distance, object_id))
# Compute scene scores
query_object_ids = [m.get("object_id", f"query_obj_{i}") for i, m in enumerate(query_masks)]
scene_scores = _compute_scene_score(
query_object_ids,
retrieved_results,
self.config.gamma,
)
# Rank scenes by score
ranked_scenes = sorted(scene_scores.items(), key=lambda x: x[1], reverse=True)
# Check if target is in top-K
top_k_scenes = [scene_id for scene_id, _ in ranked_scenes[:top_k]]
if target_image_id in top_k_scenes:
correct += 1
total += 1
accuracy = correct / total if total > 0 else 0.0
return {
"accuracy": accuracy,
"correct": correct,
"total": total,
"top_k": top_k,
}
def _infer_vector_dim(
self,
processor: BitImageProcessorFast,
model: nn.Module,
sample_image: Any,
) -> int:
"""Infer vector dimension from model output."""
vector = extract_single_image_feature(processor, model, sample_image)
return len(vector)
def _apply_mask(self, image: Any, mask: np.ndarray) -> Any:
"""Apply mask to image and return masked image.
Args:
image: PIL Image.
mask: Binary mask as numpy array.
Returns:
Masked PIL Image.
"""
import numpy as np
from PIL import Image
image_np = np.array(image.convert("RGB"))
# Ensure mask is the right shape
if mask.shape != image_np.shape[:2]:
from skimage.transform import resize
mask_resized = resize(mask, image_np.shape[:2], order=0, anti_aliasing=False)
else:
mask_resized = mask
# Apply mask
masked_np = image_np * mask_resized[:, :, np.newaxis]
return Image.fromarray(masked_np.astype(np.uint8))
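The scene scoring in `_compute_scene_score` can be sanity-checked by hand. A standalone re-derivation of the formula (score = sum of 1/(1+d) over matched objects, times hit_rate^gamma), using toy distances rather than the benchmark module itself:

```python
def scene_score(query_ids, results, gamma):
    """Score one scene: summed similarity of matched objects times hit-rate penalty."""
    # Similarity of every retrieved result that matches a query object
    matched = [1.0 / (1.0 + d) for d, obj in results if obj in query_ids]
    if not matched:
        return 0.0
    # Hit rate counts distinct matched query objects
    hit_rate = len({obj for _, obj in results if obj in query_ids}) / len(query_ids)
    return sum(matched) * hit_rate**gamma


queries = ["q0", "q1"]
# Scene A retrieves both query objects (distances 0.0 and 1.0): (1.0 + 0.5) * 1.0^2 = 1.5
assert scene_score(queries, [(0.0, "q0"), (1.0, "q1")], gamma=2.0) == 1.5
# Scene B retrieves only one (distance 0.0): 1.0 * 0.5^2 = 0.25
assert scene_score(queries, [(0.0, "q0")], gamma=2.0) == 0.25
```

The co-occurrence penalty is visible directly: scene B found its one object perfectly but is still ranked well below scene A because half the query objects are missing.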


@@ -1,4 +1,4 @@
-from typing import cast
+from typing import Any, Optional, cast
 import typer
 from commands import app
@@ -7,15 +7,15 @@ from commands import app
 @app.command()
 def benchmark(
     ctx: typer.Context,
-    model_path: str = typer.Option(
+    model_path: Optional[str] = typer.Option(
         None, "--model", "-m", help="Path to compressor model weights"
     ),
 ):
     import torch
+    import torch.nn.functional as F
     from benchmarks import run_benchmark
-    from compressors import DinoCompressor
     from configs import cfg_manager
-    from transformers import AutoImageProcessor, BitImageProcessorFast
+    from transformers import AutoImageProcessor, AutoModel, BitImageProcessorFast
     from utils import get_device

     config = cfg_manager.get()
@@ -29,7 +29,12 @@ def benchmark(
         AutoImageProcessor.from_pretrained(model_cfg.dino_model, device_map=device),
     )

-    model = DinoCompressor().to(device)
+    # Load DINO model for feature extraction
+    dino = AutoModel.from_pretrained(model_cfg.dino_model, device_map=device)
+    dino.eval()
+
+    # Optional hash compressor
+    compressor = None
     if model_path:
         from compressors import HashCompressor
@@ -38,7 +43,31 @@ def benchmark(
             hash_bits=model_cfg.compression_dim,
         )
         compressor.load_state_dict(torch.load(model_path))
-        model.compressor = compressor
+        compressor.to(device)
+        compressor.eval()
+
+    # Create wrapper with extract_features method
+    class DinoFeatureExtractor:
+        def __init__(self, dino, compressor=None):
+            self.dino = dino
+            self.compressor = compressor
+
+        def extract_features(self, images: list) -> torch.Tensor:
+            inputs = processor(images, return_tensors="pt").to(device)
+            with torch.no_grad():
+                outputs = self.dino(**inputs)
+            features = outputs.last_hidden_state.mean(dim=1)
+            features = F.normalize(features, dim=-1)
+            return features
+
+        def encode(self, images: list) -> torch.Tensor:
+            if self.compressor is None:
+                return self.extract_features(images)
+            tokens = self.dino(**processor(images, return_tensors="pt").to(device)).last_hidden_state
+            _, _, bits = self.compressor(tokens)
+            return bits
+
+    model = DinoFeatureExtractor(dino, compressor)

     run_benchmark(
         model=model,


@@ -1,18 +1,15 @@
 from .common import BinarySign, bits_to_hash, hamming_distance, hamming_similarity, hash_to_bits
-from .dino_compressor import DinoCompressor
 from .hash_compressor import HashCompressor, HashLoss, VideoPositiveMask
-from .pipeline import SAMHashPipeline, create_pipeline_from_config
-from .segament_compressor import SegmentCompressor
+from .pipeline import HashPipeline, SAMHashPipeline, create_pipeline_from_config
 from .train import train

 __all__ = [
     "train",
-    "DinoCompressor",
     "HashCompressor",
     "HashLoss",
     "VideoPositiveMask",
-    "SegmentCompressor",
-    "SAMHashPipeline",
+    "HashPipeline",
+    "SAMHashPipeline",  # Backward compatibility alias
     "create_pipeline_from_config",
     "BinarySign",
     "hamming_distance",


@@ -1,105 +0,0 @@
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import AutoImageProcessor, AutoModel


class DinoCompressor(nn.Module):
    """DINOv2 feature extractor with optional hash compression.

    When compressor is None: returns normalized DINO embeddings.
    When compressor is provided: returns binary hash bits for CAM storage.
    Supports both PIL Image input and pre-extracted tokens.
    """

    def __init__(
        self,
        model_name: str = "facebook/dinov2-large",
        compressor: Optional[nn.Module] = None,
        device: Optional[str] = None,
    ):
        """Initialize DINOv2 extractor.

        Args:
            model_name: HuggingFace model name
            compressor: Optional hash compressor for producing binary codes
            device: Device to load model on
        """
        super().__init__()
        # Auto detect device
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.model_name = model_name
        self.processor = AutoImageProcessor.from_pretrained(model_name)
        self.dino = AutoModel.from_pretrained(model_name).to(self.device)
        self.dino.eval()
        self.compressor = compressor

    def forward(self, inputs):
        teacher_tokens = self.dino(**inputs).last_hidden_state  # [B, N, 1024]
        teacher_embed = teacher_tokens.mean(dim=1)
        teacher_embed = F.normalize(teacher_embed, dim=-1)  # [B, 1024]
        if self.compressor is None:
            return teacher_embed
        # HashCompressor returns (logits, hash_codes, bits)
        _, _, bits = self.compressor(teacher_tokens)
        return bits  # [B, 512] binary bits for CAM

    def extract_features(self, images: list[Image.Image]) -> torch.Tensor:
        """Extract DINO features from a list of cropped object images.

        Args:
            images: List of PIL Images (cropped objects)

        Returns:
            DINO features [N, feature_dim], normalized
        """
        if len(images) == 0:
            return torch.empty(0, self.dino.config.hidden_size, device=self.device)
        # Process batch of images
        inputs = self.processor(images, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.dino(**inputs)
        # Pool tokens to get global representation
        features = outputs.last_hidden_state.mean(dim=1)  # [N, 1024]
        features = F.normalize(features, dim=-1)
        return features

    def encode(self, images: list[Image.Image]) -> torch.Tensor:
        """Extract features from images and optionally compress to hash codes.

        Args:
            images: List of PIL Images

        Returns:
            If compressor is None: DINO features [N, 1024]
            If compressor is set: Binary hash bits [N, 512]
        """
        if self.compressor is None:
            return self.extract_features(images)
        # Extract features first
        features = self.extract_features(images)  # [N, 1024]
        # Add sequence dimension for compressor (expects [B, N, dim])
        features = features.unsqueeze(1)  # [N, 1, 1024]
        # Compress to hash codes
        _, _, bits = self.compressor(features)
        return bits
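The pooling step this deleted class performed (mean over patch tokens, then L2 normalization) survives inside the new `HashPipeline`. In numpy terms, under the assumption of a `[B, N, D]` token array, it amounts to:

```python
import numpy as np


def pool_and_normalize(tokens: np.ndarray) -> np.ndarray:
    """Mean-pool patch tokens [B, N, D] to [B, D], then L2-normalize each row."""
    pooled = tokens.mean(axis=1)
    return pooled / np.linalg.norm(pooled, axis=-1, keepdims=True)


tokens = np.random.default_rng(0).normal(size=(2, 5, 8))
embed = pool_and_normalize(tokens)
assert embed.shape == (2, 8)
assert np.allclose(np.linalg.norm(embed, axis=-1), 1.0)  # unit-norm rows
```

Unit-norm embeddings make cosine similarity reduce to a plain dot product, which is why both the old and new code normalize before storage.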


@@ -1,79 +1,65 @@
"""Complete pipeline for SAM + DINO + HashCompressor. """Hash compression pipeline with DINO feature extraction.
This pipeline extracts object masks from images using SAM2.1, This pipeline extracts features using DINOv2 and compresses them
crops the objects, extracts features using DINOv2, to binary hash codes using HashCompressor.
and compresses them to binary hash codes using HashCompressor.
""" """
from pathlib import Path
from typing import Optional from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from PIL import Image from PIL import Image
from transformers import AutoImageProcessor, AutoModel
from .dino_compressor import DinoCompressor
from .hash_compressor import HashCompressor
from .segament_compressor import SegmentCompressor
def create_pipeline_from_config(config) -> "SAMHashPipeline": def create_pipeline_from_config(config) -> "HashPipeline":
"""Create SAMHashPipeline from a config object. """Create HashPipeline from a config object.
Args: Args:
config: Configuration object with model settings config: Configuration object with model settings
Returns: Returns:
Initialized SAMHashPipeline Initialized HashPipeline
""" """
return SAMHashPipeline( return HashPipeline(
sam_model=config.model.sam_model, dino_model=config.model.dino_model,
dino_model=config.model.name,
hash_bits=config.model.compression_dim, hash_bits=config.model.compression_dim,
sam_min_mask_area=config.model.sam_min_mask_area,
sam_max_masks=config.model.sam_max_masks,
compressor_path=config.model.compressor_path, compressor_path=config.model.compressor_path,
device=config.model.device if config.model.device != "auto" else None, device=config.model.device if config.model.device != "auto" else None,
) )
class SAMHashPipeline(nn.Module): class HashPipeline(nn.Module):
"""Complete pipeline: SAM segmentation + DINO features + Hash compression. """Pipeline: DINO features + Hash compression.
Pipeline flow: Pipeline flow:
Image -> SAM (extract masks) -> Crop objects -> DINO (features) -> Hash (binary codes) PIL Image -> DINO (features) -> Hash (binary codes)
Usage: Usage:
# Initialize with config # Initialize with config
pipeline = SAMHashPipeline( pipeline = HashPipeline(
sam_model="facebook/sam2.1-hiera-large",
dino_model="facebook/dinov2-large", dino_model="facebook/dinov2-large",
hash_bits=512, hash_bits=512,
) )
# Process image # Process image
image = Image.open("path/to/image.jpg") image = Image.open("path/to/image.jpg")
hash_codes = pipeline(image) # [N, 512] binary bits hash_bits = pipeline(image) # [1, 512] binary bits
""" """
def __init__( def __init__(
self, self,
sam_model: str = "facebook/sam2.1-hiera-large",
dino_model: str = "facebook/dinov2-large", dino_model: str = "facebook/dinov2-large",
hash_bits: int = 512, hash_bits: int = 512,
sam_min_mask_area: int = 100,
sam_max_masks: int = 10,
compressor_path: Optional[str] = None, compressor_path: Optional[str] = None,
device: Optional[str] = None, device: Optional[str] = None,
): ):
"""Initialize the complete pipeline. """Initialize the pipeline.
Args: Args:
sam_model: SAM model name from HuggingFace
dino_model: DINOv2 model name from HuggingFace dino_model: DINOv2 model name from HuggingFace
hash_bits: Number of bits in hash code hash_bits: Number of bits in hash code
sam_min_mask_area: Minimum mask area threshold
sam_max_masks: Maximum number of masks to keep
compressor_path: Optional path to trained HashCompressor weights compressor_path: Optional path to trained HashCompressor weights
device: Device to run models on device: Device to run models on
""" """
@@ -84,87 +70,101 @@ class SAMHashPipeline(nn.Module):
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = torch.device(device) self.device = torch.device(device)
# Initialize components self.dino_model = dino_model
self.segmentor = SegmentCompressor(
model_name=sam_model,
min_mask_area=sam_min_mask_area,
max_masks=sam_max_masks,
device=device,
)
# HashCompressor expects DINO features (1024 dim for dinov2-large) # Initialize DINO processor and model
dino_dim = 1024 if "large" in dino_model else 768 self.processor = AutoImageProcessor.from_pretrained(dino_model)
self.hash_compressor = HashCompressor( self.dino = AutoModel.from_pretrained(dino_model).to(self.device)
input_dim=dino_dim, hash_bits=hash_bits self.dino.eval()
).to(device)
# Determine DINO feature dimension
self.dino_dim = 1024 if "large" in dino_model else 768
# Initialize HashCompressor
self.hash_compressor = nn.Module() # Placeholder, will be replaced
self._init_hash_compressor(hash_bits, compressor_path)
def _init_hash_compressor(
self, hash_bits: int, compressor_path: Optional[str] = None
):
"""Initialize the hash compressor module.
This is called during __init__ but we need to replace it properly.
"""
# Import here to avoid circular imports
from .hash_compressor import HashCompressor
compressor = HashCompressor(input_dim=self.dino_dim, hash_bits=hash_bits).to(
self.device
)
        # Load pretrained compressor if provided
        if compressor_path is not None:
-           self.hash_compressor.load_state_dict(
-               torch.load(compressor_path, map_location=device)
+           compressor.load_state_dict(
+               torch.load(compressor_path, map_location=self.device)
            )
            print(f"[OK] Loaded HashCompressor from {compressor_path}")
-       self.dino = DinoCompressor(
-           model_name=dino_model,
-           compressor=self.hash_compressor,
-           device=device,
-       )
+       # Replace the placeholder
+       self.hash_compressor = compressor
+
+   @property
+   def hash_bits(self):
+       """Return the number of hash bits."""
+       return self.hash_compressor.hash_bits

    def forward(self, image: Image.Image) -> torch.Tensor:
-       """Process a single image through the complete pipeline.
+       """Process a single image through the pipeline.

        Args:
            image: Input PIL Image

        Returns:
-           Binary hash codes [N, hash_bits] where N is number of detected objects
+           Binary hash codes [1, hash_bits] as int32
        """
-       # Step 1: SAM - extract and crop objects
-       cropped_objects = self.segmentor(image)
-       if len(cropped_objects) == 0:
-           # No objects detected, return empty tensor
-           return torch.empty(
-               0, self.hash_compressor.hash_bits, dtype=torch.int32, device=self.device
-           )
-       # Step 2: DINO - extract features from cropped objects
-       # Step 3: HashCompressor - compress features to binary codes
-       hash_codes = self.dino.encode(cropped_objects)
-       return hash_codes
+       # Extract DINO features
+       inputs = self.processor(image, return_tensors="pt").to(self.device)
+       with torch.no_grad():
+           outputs = self.dino(**inputs)
+       tokens = outputs.last_hidden_state  # [1, N, dim]
+
+       # Compress to hash codes
+       _, _, bits = self.hash_compressor(tokens)
+       return bits

-   def extract_features(
-       self, image: Image.Image, use_hash: bool = False
-   ) -> torch.Tensor:
-       """Extract features from image with optional hash compression.
-
-       Args:
-           image: Input PIL Image
-           use_hash: If True, return binary hash codes; else return DINO features
-
-       Returns:
-           Features [N, dim] where dim is 1024 (DINO) or 512 (hash)
-       """
-       cropped_objects = self.segmentor(image)
-       if len(cropped_objects) == 0:
-           dim = self.hash_compressor.hash_bits if use_hash else 1024
-           return torch.empty(0, dim, device=self.device)
-       if use_hash:
-           return self.dino.encode(cropped_objects)
-       else:
-           return self.dino.extract_features(cropped_objects)
-
-   def extract_masks(self, image: Image.Image) -> list[torch.Tensor]:
-       """Extract only masks without full processing (for debugging).
-
-       Args:
-           image: Input PIL Image
-
-       Returns:
-           List of binary masks [H, W]
-       """
-       return self.segmentor.extract_masks(image)
+   def encode(self, image: Image.Image) -> torch.Tensor:
+       """Encode an image to binary hash bits.
+
+       Alias for forward().
+
+       Args:
+           image: Input PIL Image
+
+       Returns:
+           Binary hash codes [1, hash_bits] as int32
+       """
+       return self.forward(image)
+
+   def extract_features(self, image: Image.Image) -> torch.Tensor:
+       """Extract DINO features from an image.
+
+       Args:
+           image: Input PIL Image
+
+       Returns:
+           DINO features [1, dino_dim], normalized
+       """
+       inputs = self.processor(image, return_tensors="pt").to(self.device)
+       with torch.no_grad():
+           outputs = self.dino(**inputs)
+       features = outputs.last_hidden_state.mean(dim=1)  # [1, dim]
+       features = F.normalize(features, dim=-1)
+       return features

# Backward compatibility alias
SAMHashPipeline = HashPipeline
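The pipeline above maps a DINO embedding to a compact binary code that is compared by Hamming distance. A minimal numpy sketch of that hash stage, using a fixed random projection as a stand-in for the learned `HashCompressor` (the projection, shapes, and function names here are illustrative assumptions, not the project's API):

```python
import numpy as np

def hash_encode(features: np.ndarray, proj: np.ndarray) -> np.ndarray:
    # Project features, then binarize to {0, 1} by sign thresholding
    return (features @ proj > 0).astype(np.int32)

def hamming_distance(a: np.ndarray, b: np.ndarray) -> int:
    # Number of differing bits between two binary codes
    return int(np.sum(a != b))

rng = np.random.default_rng(0)
proj = rng.standard_normal((1024, 512))  # stand-in for the learned compressor
feat = rng.standard_normal((1, 1024))    # stand-in for a DINO embedding

code = hash_encode(feat, proj)
assert code.shape == (1, 512)
assert hamming_distance(code[0], code[0]) == 0        # identical codes
assert hamming_distance(code[0], 1 - code[0]) == 512  # fully flipped codes
```

In the real pipeline the projection is trained (see `HashLoss` in the tests below) so that nearby views land on nearby codes, but the encode/compare mechanics are the same.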


@@ -1,180 +0,0 @@
"""Segment Anything 2 feature extractor with mask filtering and image cropping.
Extracts object masks from images using SAM2.1, filters by area and confidence,
then crops the original image to obtain individual object regions.
"""
from typing import Optional
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from transformers import AutoModelForMaskGeneration, AutoProcessor
class SegmentCompressor(nn.Module):
"""SAM2.1 based segmenter with mask filtering.
Extracts object masks from images, filters by area and confidence,
and crops the original image to produce individual object patches.
"""
def __init__(
self,
model_name: str = "facebook/sam2.1-hiera-large",
min_mask_area: int = 100,
max_masks: int = 10,
device: Optional[str] = None,
):
"""Initialize SAM2.1 segmenter.
Args:
model_name: HuggingFace model name for SAM2.1
min_mask_area: Minimum mask pixel area threshold
max_masks: Maximum number of masks to keep
device: Device to load model on (auto-detect if None)
"""
super().__init__()
self.model_name = model_name
self.min_mask_area = min_mask_area
self.max_masks = max_masks
# Auto detect device
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = torch.device(device)
# Load SAM model and processor
self.processor = AutoProcessor.from_pretrained(model_name)
self.model = AutoModelForMaskGeneration.from_pretrained(model_name).to(
self.device
)
self.model.eval()
def forward(self, image: Image.Image) -> list[Image.Image]:
"""Extract object masks and crop object regions.
Args:
image: Input PIL Image
Returns:
List of cropped object images (one per valid mask)
"""
# Run SAM inference
inputs = self.processor(image, return_tensors="pt").to(self.device)
with torch.no_grad():
outputs = self.model(**inputs)
# Post-process masks
masks = self.processor.post_process_masks(
outputs.pred_masks,
inputs["original_sizes"],
inputs["reshaped_input_sizes"],
)[0]
# Filter masks by area and confidence
valid_masks = self._filter_masks(masks)
if len(valid_masks) == 0:
return []
# Crop object regions from original image
cropped_objects = self._crop_objects(image, valid_masks)
return cropped_objects
def _filter_masks(self, masks: torch.Tensor) -> list[dict]:
"""Filter masks by area and keep top-N.
Args:
masks: Predicted masks [N, H, W]
Returns:
List of mask dictionaries with 'mask' and 'area'
"""
valid_masks = []
for mask in masks:
# Calculate mask area
area = mask.sum().item()
# Filter by minimum area
if area < self.min_mask_area:
continue
valid_masks.append({"mask": mask, "area": area})
# Sort by area (descending) and keep top-N
valid_masks = sorted(valid_masks, key=lambda x: x["area"], reverse=True)
valid_masks = valid_masks[: self.max_masks]
return valid_masks
def _crop_objects(
self, image: Image.Image, masks: list[dict]
) -> list[Image.Image]:
"""Crop object regions from image using masks.
Args:
image: Original PIL Image
masks: List of mask dictionaries
Returns:
List of cropped object images
"""
# Convert PIL to numpy for processing
image_np = np.array(image)
h, w = image_np.shape[:2]
cropped_objects = []
for mask_info in masks:
mask = mask_info["mask"].cpu().numpy()
# Find bounding box from mask
rows = mask.any(axis=1)
cols = mask.any(axis=0)
if not rows.any() or not cols.any():
continue
y_min, y_max = rows.argmax(), h - rows[::-1].argmax() - 1
x_min, x_max = cols.argmax(), w - cols[::-1].argmax() - 1
# Add small padding
pad = 5
x_min = max(0, x_min - pad)
y_min = max(0, y_min - pad)
x_max = min(w, x_max + pad)
y_max = min(h, y_max + pad)
# Crop
cropped = image.crop((x_min, y_min, x_max, y_max))
cropped_objects.append(cropped)
return cropped_objects
@torch.no_grad()
def extract_masks(self, image: Image.Image) -> list[torch.Tensor]:
"""Extract only masks without cropping (for debugging).
Args:
image: Input PIL Image
Returns:
List of binary masks [H, W]
"""
inputs = self.processor(image, return_tensors="pt").to(self.device)
outputs = self.model(**inputs)
masks = self.processor.post_process_masks(
outputs.pred_masks,
inputs["original_sizes"],
inputs["reshaped_input_sizes"],
)[0]
valid_masks = self._filter_masks(masks)
return [m["mask"] for m in valid_masks]


@@ -118,6 +118,17 @@ class BenchmarkTaskConfig(BaseModel):
    type: str = Field(default="retrieval", description="Task type")
    top_k: int = Field(default=10, gt=0, description="Top K for recall evaluation")
# Multi-object retrieval specific settings
gamma: float = Field(
default=1.0, ge=0, description="Co-occurrence penalty exponent"
)
top_k_per_object: int = Field(
default=50, gt=0, description="Top K results per object query"
)
num_query_objects: int = Field(
default=3, gt=0, description="Number of objects to sample from query image"
)
class BenchmarkConfig(BaseModel):
    """Configuration for benchmark evaluation."""


@@ -0,0 +1,64 @@
"""InsDet Scenes dataset for multi-object retrieval benchmark."""
from pathlib import Path
from typing import Any
from benchmarks.base import BaseDataset
from data_loading.loader import load_val_dataset
class InsDetScenesDataset(BaseDataset):
"""InsDet-FULL/Scenes dataset with easy/hard splits.
This dataset provides scene images with object annotations from the
Instance Detection (InsDet) dataset, supporting easy and hard splits.
"""
def __init__(
self,
scenes_dir: Path | str,
split: str = "easy",
):
"""Initialize InsDet Scenes dataset.
Args:
scenes_dir: Path to the InsDet-FULL/Scenes directory.
split: Scene split to use ('easy' or 'hard').
"""
self.scenes_dir = Path(scenes_dir)
self.split = split
self._dataset = load_val_dataset(self.scenes_dir, split)
def get_train_split(self) -> Any:
"""Get training split (same as test for this dataset).
Returns:
HuggingFace Dataset for training.
"""
return self._dataset
def get_test_split(self) -> Any:
"""Get test/evaluation split.
Returns:
HuggingFace Dataset for testing.
"""
return self._dataset
def __len__(self) -> int:
"""Get dataset length."""
return len(self._dataset)
def __getitem__(self, idx: int) -> dict[str, Any]:
"""Get a single item from the dataset.
Args:
idx: Index of the item.
Returns:
Dictionary containing:
- image: PIL Image
- image_id: Scene identifier
- objects: dict with bbox, category, area, id
"""
return self._dataset[idx]


@@ -1,13 +1,13 @@
"""Tests for compressor modules (SAM, DINO, HashCompressor, Pipeline).""" """Tests for compressor modules (HashCompressor, Pipeline)."""
import pytest import pytest
import torch import torch
from compressors import ( from compressors import (
BinarySign, BinarySign,
DinoCompressor,
HashCompressor, HashCompressor,
HashPipeline,
SAMHashPipeline, SAMHashPipeline,
SegmentCompressor, VideoPositiveMask,
bits_to_hash, bits_to_hash,
create_pipeline_from_config, create_pipeline_from_config,
hamming_distance, hamming_distance,
@@ -124,87 +124,105 @@ class TestHammingMetrics:
assert sim.item() == 512 # Max similarity assert sim.item() == 512 # Max similarity
class TestSegmentCompressor: class TestHashLoss:
"""Test suite for SegmentCompressor.""" """Test suite for HashLoss."""
@pytest.fixture def test_hash_loss_init(self):
def mock_image(self): """Verify HashLoss initializes with correct parameters."""
"""Create a mock PIL image.""" from compressors import HashLoss
img = Image.new("RGB", (224, 224), color="red")
return img
def test_segment_compressor_init(self): loss_fn = HashLoss(
"""Verify SegmentCompressor initializes with correct parameters.""" contrastive_weight=1.0,
segmentor = SegmentCompressor( distill_weight=0.5,
model_name="facebook/sam2.1-hiera-large", quant_weight=0.01,
min_mask_area=100, temperature=0.2,
max_masks=10,
) )
assert segmentor.model_name == "facebook/sam2.1-hiera-large" assert loss_fn.contrastive_weight == 1.0
assert segmentor.min_mask_area == 100 assert loss_fn.distill_weight == 0.5
assert segmentor.max_masks == 10 assert loss_fn.quant_weight == 0.01
assert loss_fn.temperature == 0.2
def test_filter_masks(self): def test_hash_loss_forward(self):
"""Verify mask filtering logic.""" """Verify HashLoss computes loss correctly."""
# Create segmentor to get default filter params from compressors import HashLoss
segmentor = SegmentCompressor()
# Create mock masks tensor with different areas loss_fn = HashLoss()
# Masks shape: [N, H, W]
masks = []
for area in [50, 200, 150, 300, 10]:
mask = torch.zeros(100, 100)
mask[:1, :area] = 1 # Create mask with specific area
masks.append(mask)
masks_tensor = torch.stack(masks) # [5, 100, 100] batch_size = 4
valid = segmentor._filter_masks(masks_tensor) hash_bits = 512
logits = torch.randn(batch_size, hash_bits)
hash_codes = torch.sign(logits)
teacher_embed = torch.randn(batch_size, 1024)
positive_mask = torch.eye(batch_size, dtype=torch.bool)
# Should filter out 50 and 10 (below min_mask_area=100) total_loss, components = loss_fn(
# Then keep top 3 (max_masks=10) logits=logits,
assert len(valid) == 3 hash_codes=hash_codes,
# Verify sorted by area (descending) teacher_embed=teacher_embed,
areas = [v["area"] for v in valid] positive_mask=positive_mask,
assert areas == sorted(areas, reverse=True) )
assert "contrastive" in components
assert "distill" in components
assert "quantization" in components
assert "total" in components
class TestDinoCompressor: class TestVideoPositiveMask:
"""Test suite for DinoCompressor.""" """Test suite for VideoPositiveMask."""
def test_dino_compressor_init(self): def test_from_frame_indices(self):
"""Verify DinoCompressor initializes correctly.""" """Verify positive mask generation from frame indices."""
dino = DinoCompressor() mask_gen = VideoPositiveMask(temporal_window=2)
assert dino.model_name == "facebook/dinov2-large" frame_indices = torch.tensor([0, 1, 3, 5])
def test_dino_compressor_with_compressor(self): mask = mask_gen.from_frame_indices(frame_indices)
"""Verify DinoCompressor with HashCompressor."""
hash_compressor = HashCompressor(input_dim=1024, hash_bits=512)
dino = DinoCompressor(compressor=hash_compressor)
assert dino.compressor is hash_compressor assert mask.shape == (4, 4)
# Frame 0 and 1 should be positive (distance 1 <= 2)
assert mask[0, 1] == True
# Frame 0 and 3 should be negative (distance 3 > 2)
assert mask[0, 3] == False
def test_from_video_ids(self):
"""Verify positive mask generation from video IDs and frame indices."""
mask_gen = VideoPositiveMask(temporal_window=2)
video_ids = torch.tensor([0, 0, 1, 1])
frame_indices = torch.tensor([0, 1, 0, 1])
mask = mask_gen.from_video_ids(video_ids, frame_indices)
assert mask.shape == (4, 4)
# Same video and temporally close
assert mask[0, 1] == True # video 0, frames 0,1
# Different video
assert mask[0, 2] == False # video 0 vs 1
class TestSAMHashPipeline: class TestHashPipeline:
"""Test suite for SAMHashPipeline.""" """Test suite for HashPipeline."""
def test_pipeline_init(self): def test_pipeline_init(self):
"""Verify pipeline initializes all components.""" """Verify pipeline initializes all components."""
pipeline = SAMHashPipeline( pipeline = HashPipeline(
sam_model="facebook/sam2.1-hiera-large",
dino_model="facebook/dinov2-large", dino_model="facebook/dinov2-large",
hash_bits=512, hash_bits=512,
) )
assert isinstance(pipeline.segmentor, SegmentCompressor) assert pipeline.dino_model == "facebook/dinov2-large"
assert isinstance(pipeline.dino, DinoCompressor) assert pipeline.dino_dim == 1024
assert isinstance(pipeline.hash_compressor, HashCompressor)
def test_pipeline_hash_bits(self): def test_pipeline_hash_bits(self):
"""Verify pipeline uses correct hash bits.""" """Verify pipeline uses correct hash bits."""
pipeline = SAMHashPipeline(hash_bits=256) pipeline = HashPipeline(hash_bits=256)
assert pipeline.hash_compressor.hash_bits == 256 assert pipeline.hash_bits == 256
def test_pipeline_alias(self):
"""Verify SAMHashPipeline is alias for HashPipeline."""
assert SAMHashPipeline is HashPipeline
class TestConfigIntegration: class TestConfigIntegration:
@@ -216,25 +234,21 @@ class TestConfigIntegration:
pipeline = create_pipeline_from_config(config) pipeline = create_pipeline_from_config(config)
assert isinstance(pipeline, SAMHashPipeline) assert isinstance(pipeline, HashPipeline)
assert pipeline.hash_compressor.hash_bits == config.model.compression_dim assert pipeline.hash_bits == config.model.compression_dim
def test_config_sam_settings(self): def test_config_settings(self):
"""Verify config contains SAM settings.""" """Verify config contains required settings."""
config = cfg_manager.load() config = cfg_manager.load()
assert hasattr(config.model, "sam_model") assert hasattr(config.model, "dino_model")
assert hasattr(config.model, "sam_min_mask_area") assert hasattr(config.model, "compression_dim")
assert hasattr(config.model, "sam_max_masks")
assert config.model.sam_model == "facebook/sam2.1-hiera-large"
assert config.model.sam_min_mask_area == 100
assert config.model.sam_max_masks == 10
@pytest.mark.slow
class TestPipelineIntegration: class TestPipelineIntegration:
"""Integration tests for full pipeline (slow, requires model downloads).""" """Integration tests for full pipeline (slow, requires model downloads)."""
@pytest.mark.slow
def test_pipeline_end_to_end(self): def test_pipeline_end_to_end(self):
"""Test full pipeline with actual models (slow test).""" """Test full pipeline with actual models (slow test)."""
# Skip if no GPU # Skip if no GPU
@@ -245,54 +259,32 @@ class TestPipelineIntegration:
image = Image.new("RGB", (640, 480), color=(128, 128, 128)) image = Image.new("RGB", (640, 480), color=(128, 128, 128))
# Initialize pipeline (will download models on first run) # Initialize pipeline (will download models on first run)
pipeline = SAMHashPipeline( pipeline = HashPipeline(
sam_model="facebook/sam2.1-hiera-large",
dino_model="facebook/dinov2-large", dino_model="facebook/dinov2-large",
hash_bits=512, hash_bits=512,
sam_min_mask_area=100,
sam_max_masks=5,
) )
# Run pipeline # Run pipeline
hash_codes = pipeline(image) hash_bits = pipeline(image)
# Verify output shape # Verify output shape
assert hash_codes.dim() == 2 assert hash_bits.dim() == 2
assert hash_codes.shape[1] == 512 assert hash_bits.shape[1] == 512
assert torch.all((hash_codes == 0) | (hash_codes == 1)) assert torch.all((hash_bits == 0) | (hash_bits == 1))
@pytest.mark.slow def test_extract_features(self):
def test_extract_features_without_hash(self): """Test feature extraction."""
"""Test feature extraction without hash compression."""
if not torch.cuda.is_available(): if not torch.cuda.is_available():
pytest.skip("Requires CUDA") pytest.skip("Requires CUDA")
image = Image.new("RGB", (640, 480), color=(128, 128, 128)) image = Image.new("RGB", (640, 480), color=(128, 128, 128))
pipeline = SAMHashPipeline( pipeline = HashPipeline(
sam_model="facebook/sam2.1-hiera-large",
dino_model="facebook/dinov2-large", dino_model="facebook/dinov2-large",
) )
features = pipeline.extract_features(image, use_hash=False) features = pipeline.extract_features(image)
# Should return DINO features (1024 for large) # Should return DINO features (1024 for large)
assert features.dim() == 2 assert features.dim() == 2
assert features.shape[1] == 1024 assert features.shape[1] == 1024
@pytest.mark.slow
def test_extract_masks_only(self):
"""Test mask extraction only."""
if not torch.cuda.is_available():
pytest.skip("Requires CUDA")
image = Image.new("RGB", (640, 480), color=(128, 128, 128))
pipeline = SAMHashPipeline(
sam_model="facebook/sam2.1-hiera-large",
)
masks = pipeline.extract_masks(image)
# Should return a list of masks
assert isinstance(masks, list)
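The `VideoPositiveMask` behavior exercised in the tests above (a pair is positive when both frames come from the same video and their frame indices differ by at most `temporal_window`) can be sketched in plain numpy; the function name and signature here are illustrative, not the project's API:

```python
import numpy as np

def video_positive_mask(video_ids, frame_indices, temporal_window=2):
    # Pairwise positives: same video AND frame gap <= temporal_window
    v = np.asarray(video_ids)
    f = np.asarray(frame_indices)
    same_video = v[:, None] == v[None, :]
    close = np.abs(f[:, None] - f[None, :]) <= temporal_window
    return same_video & close

mask = video_positive_mask([0, 0, 1, 1], [0, 1, 0, 1])
assert mask.shape == (4, 4)
assert mask[0, 1]       # same video, frames 0 and 1
assert not mask[0, 2]   # different videos, never positive
```

This mask is what the contrastive term of the hash loss consumes to decide which pairs should be pulled together.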


@@ -0,0 +1,238 @@
"""Integration tests for multi-object retrieval benchmark pipeline.
These tests verify the end-to-end functionality of the multi-object retrieval
benchmark, including schema building, database population, and evaluation.
"""
import numpy as np
import pytest
from unittest.mock import Mock, patch, MagicMock
from PIL import Image
class TestMultiObjectRetrievalIntegration:
"""Integration tests for multi-object retrieval benchmark."""
@pytest.fixture
def mock_model_processor(self):
"""Create mock model and processor."""
mock_model = Mock()
mock_processor = Mock()
# Mock the feature extraction to return a fixed-size vector
def mock_extract(processor, model, image):
return [0.1] * 256 # 256-dim vector
mock_processor.images = mock_extract
return mock_model, mock_processor
@pytest.fixture
def mock_dataset(self):
"""Create a mock dataset with images and annotations."""
# Create mock items
items = []
for i in range(3):
item = {
"image": Image.new("RGB", (224, 224), color=(i * 50, 100, 150)),
"image_id": f"scene_{i}",
"objects": {
"bbox": [[10, 10, 50, 50], [60, 60, 40, 40]],
"category": ["object_a", "object_b"],
"area": [2500, 1600],
"id": [0, 1],
},
}
items.append(item)
mock_dataset = Mock()
mock_dataset.__len__ = Mock(return_value=len(items))
mock_dataset.__getitem__ = lambda self, idx: items[idx]
mock_dataset.with_format = lambda fmt: mock_dataset
return mock_dataset
def test_build_object_schema(self):
"""Test that object schema is built correctly."""
from benchmarks.tasks.multi_object_retrieval import _build_object_schema
import pyarrow as pa
vector_dim = 256
schema = _build_object_schema(vector_dim)
assert isinstance(schema, pa.Schema)
assert "id" in schema.names
assert "image_id" in schema.names
assert "object_id" in schema.names
assert "category" in schema.names
assert "vector" in schema.names
# Check vector field has correct dimension
vector_field = schema.field("vector")
        assert pa.types.is_list(vector_field.type)
assert vector_field.type.value_type == pa.float32()
@patch("benchmarks.tasks.multi_object_retrieval.load_sam_model")
@patch("benchmarks.tasks.multi_object_retrieval.segment_image")
def test_build_database_with_mocked_sam(
self,
mock_segment,
mock_load_sam,
mock_model_processor,
mock_dataset,
):
"""Test database building with mocked SAM segmentation."""
from benchmarks.tasks.multi_object_retrieval import (
MultiObjectRetrievalTask,
_build_object_schema,
)
mock_model, mock_processor = mock_model_processor
# Mock SAM
mock_load_sam.return_value = (Mock(), Mock())
mock_segment.return_value = [
{
"segment": np.ones((224, 224), dtype=bool),
"area": 50000,
"bbox": [0, 0, 224, 224],
}
]
# Create task with config
task = MultiObjectRetrievalTask(
sam_model="facebook/sam2.1-hiera-large",
min_mask_area=1024,
max_masks_per_image=5,
gamma=1.0,
top_k_per_object=50,
num_query_objects=3,
)
# Create mock table
mock_table = Mock()
mock_table.schema = _build_object_schema(256)
# Build database (this should not raise)
task.build_database(mock_model, mock_processor, mock_dataset, mock_table, batch_size=1)
# Verify table.add was called
assert mock_table.add.called
@patch("benchmarks.tasks.multi_object_retrieval.load_sam_model")
@patch("benchmarks.tasks.multi_object_retrieval.segment_image")
def test_evaluate_with_mocked_sam(
self,
mock_segment,
mock_load_sam,
mock_model_processor,
mock_dataset,
):
"""Test evaluation with mocked SAM segmentation."""
from benchmarks.tasks.multi_object_retrieval import (
MultiObjectRetrievalTask,
_build_object_schema,
)
mock_model, mock_processor = mock_model_processor
# Mock SAM
mock_load_sam.return_value = (Mock(), Mock())
mock_segment.return_value = [
{
"segment": np.ones((224, 224), dtype=bool),
"area": 50000,
"bbox": [0, 0, 224, 224],
"object_id": "query_obj_0",
}
]
# Create mock table with search results
mock_table = Mock()
mock_table.schema = _build_object_schema(256)
# Mock search to return matching result
mock_result = Mock()
mock_result.to_polars.return_value = {
"image_id": ["scene_0"],
"object_id": ["scene_0_obj_0"],
"_distance": [0.1],
}
mock_table.search.return_value.select.return_value.limit.return_value = mock_result
# Create task
task = MultiObjectRetrievalTask(
sam_model="facebook/sam2.1-hiera-large",
min_mask_area=1024,
max_masks_per_image=5,
gamma=1.0,
top_k_per_object=50,
num_query_objects=1,
)
# Evaluate
results = task.evaluate(mock_model, mock_processor, mock_dataset, mock_table, batch_size=1)
# Verify results structure
assert "accuracy" in results
assert "correct" in results
assert "total" in results
assert "top_k" in results
assert results["top_k"] == 50
def test_task_initialization_with_config(self):
"""Test task initialization with custom config."""
from benchmarks.tasks.multi_object_retrieval import MultiObjectRetrievalTask
task = MultiObjectRetrievalTask(
sam_model="facebook/sam2.1-hiera-small",
min_mask_area=500,
max_masks_per_image=3,
gamma=0.5,
top_k_per_object=100,
num_query_objects=5,
)
assert task.sam_model == "facebook/sam2.1-hiera-small"
assert task.min_mask_area == 500
assert task.max_masks_per_image == 3
assert task.config.gamma == 0.5
assert task.config.top_k_per_object == 100
assert task.config.num_query_objects == 5
def test_task_initialization_defaults(self):
"""Test task initialization with default config."""
from benchmarks.tasks.multi_object_retrieval import MultiObjectRetrievalTask
task = MultiObjectRetrievalTask()
# Check defaults from BenchmarkTaskConfig
assert task.config.gamma == 1.0
assert task.config.top_k_per_object == 50
assert task.config.num_query_objects == 3
# SAM settings from ModelConfig defaults
assert task.sam_model == "facebook/sam2.1-hiera-large"
assert task.min_mask_area == 1024
assert task.max_masks_per_image == 5
class TestInsDetScenesDataset:
"""Tests for InsDetScenesDataset class."""
def test_dataset_class_exists(self):
"""Test that InsDetScenesDataset can be imported."""
from data_loading.insdet_scenes import InsDetScenesDataset
assert InsDetScenesDataset is not None
@patch("data_loading.insdet_scenes.load_val_dataset")
def test_dataset_loads_correct_split(self, mock_load):
"""Test dataset loads correct split."""
from data_loading.insdet_scenes import InsDetScenesDataset
mock_load.return_value = Mock()
dataset = InsDetScenesDataset("/path/to/scenes", split="easy")
mock_load.assert_called_once_with("/path/to/scenes", "easy")
assert dataset.split == "easy"

mini-nav/tests/test_sam.py Normal file

@@ -0,0 +1,168 @@
"""Tests for SAM segmentation utilities.
Note: These tests mock the SAM model loading since SAM requires
heavy model weights. The actual SAM integration should be tested
separately in integration tests.
"""
import numpy as np
import pytest
from unittest.mock import Mock, patch
from PIL import Image
class TestSAMSegmentation:
"""Test suite for SAM segmentation utilities."""
def test_segment_image_empty_masks(self):
"""Test segment_image returns empty list when no masks generated."""
from utils.sam import segment_image
# Create mock mask generator that returns empty list
mock_generator = Mock()
mock_generator.generate.return_value = []
result = segment_image(mock_generator, Image.new("RGB", (100, 100)))
assert result == []
def test_segment_image_filters_small_masks(self):
"""Test segment_image filters masks below min_area threshold."""
from utils.sam import segment_image
# Create mock masks with different areas
small_mask = {
"segment": np.zeros((10, 10), dtype=bool),
"area": 50, # Below 32*32 = 1024
"bbox": [0, 0, 10, 10],
"predicted_iou": 0.9,
"stability_score": 0.8,
}
large_mask = {
"segment": np.ones((100, 100), dtype=bool),
"area": 10000, # Above threshold
"bbox": [0, 0, 100, 100],
"predicted_iou": 0.95,
"stability_score": 0.9,
}
mock_generator = Mock()
mock_generator.generate.return_value = [small_mask, large_mask]
result = segment_image(
mock_generator,
Image.new("RGB", (100, 100)),
min_area=32 * 32,
max_masks=5,
)
# Should only return the large mask
assert len(result) == 1
assert result[0]["area"] == 10000
def test_segment_image_limits_max_masks(self):
"""Test segment_image limits to max_masks largest masks."""
from utils.sam import segment_image
# Create 10 masks with different areas
masks = [
{
"segment": np.ones((i + 1, i + 1), dtype=bool),
"area": (i + 1) * (i + 1),
"bbox": [0, 0, i + 1, i + 1],
"predicted_iou": 0.9,
"stability_score": 0.8,
}
for i in range(10)
]
mock_generator = Mock()
mock_generator.generate.return_value = masks
result = segment_image(
mock_generator,
Image.new("RGB", (100, 100)),
min_area=1,
max_masks=3,
)
# Should only return top 3 largest masks
assert len(result) == 3
# Check they are sorted by area (largest first)
areas = [m["area"] for m in result]
assert areas == sorted(areas, reverse=True)
def test_segment_image_sorted_by_area(self):
"""Test segment_image returns masks sorted by area descending."""
from utils.sam import segment_image
# Create masks with known areas (unordered)
mask1 = {"segment": np.ones((5, 5), dtype=bool), "area": 25, "bbox": [0, 0, 5, 5]}
mask2 = {"segment": np.ones((10, 10), dtype=bool), "area": 100, "bbox": [0, 0, 10, 10]}
mask3 = {"segment": np.ones((3, 3), dtype=bool), "area": 9, "bbox": [0, 0, 3, 3]}
mock_generator = Mock()
mock_generator.generate.return_value = [mask1, mask2, mask3]
result = segment_image(
mock_generator,
Image.new("RGB", (100, 100)),
min_area=1,
max_masks=10,
)
# Should be sorted by area descending
assert result[0]["area"] == 100
assert result[1]["area"] == 25
assert result[2]["area"] == 9
class TestExtractMaskedRegion:
"""Test suite for extracting masked regions from images."""
def test_extract_masked_region_binary(self):
"""Test extracting masked region with binary mask."""
from utils.sam import extract_masked_region
# Create a simple image
image = Image.new("RGB", (10, 10), color=(255, 0, 0))
# Create a binary mask (half kept, half masked)
mask = np.zeros((10, 10), dtype=bool)
mask[:, :5] = True
result = extract_masked_region(image, mask)
# Check that left half is red, right half is black
result_np = np.array(result)
left_half = result_np[:, :5, :]
right_half = result_np[:, 5:, :]
assert np.all(left_half == [255, 0, 0])
assert np.all(right_half == [0, 0, 0])
def test_extract_masked_region_all_masked(self):
"""Test extracting when entire image is masked."""
from utils.sam import extract_masked_region
image = Image.new("RGB", (10, 10), color=(255, 0, 0))
mask = np.ones((10, 10), dtype=bool)
result = extract_masked_region(image, mask)
result_np = np.array(result)
# Entire image should be preserved
assert np.all(result_np == [255, 0, 0])
def test_extract_masked_region_all_zero_mask(self):
"""Test extracting when mask is all zeros."""
from utils.sam import extract_masked_region
image = Image.new("RGB", (10, 10), color=(255, 0, 0))
mask = np.zeros((10, 10), dtype=bool)
result = extract_masked_region(image, mask)
result_np = np.array(result)
# Entire image should be black
assert np.all(result_np == [0, 0, 0])
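The tests above pin down `extract_masked_region` precisely: pixels where the mask is True are preserved, everything else goes black. A minimal sketch consistent with those tests (the real `utils.sam` implementation may differ in details):

```python
import numpy as np
from PIL import Image

def extract_masked_region(image: Image.Image, mask: np.ndarray) -> Image.Image:
    # Keep pixels where mask is True; zero out (black) everywhere else
    arr = np.array(image.convert("RGB"))
    arr[~mask] = 0
    return Image.fromarray(arr)

img = Image.new("RGB", (10, 10), color=(255, 0, 0))
mask = np.zeros((10, 10), dtype=bool)
mask[:, :5] = True  # keep the left half only

out = np.array(extract_masked_region(img, mask))
assert np.all(out[:, :5] == [255, 0, 0])  # kept half stays red
assert np.all(out[:, 5:] == [0, 0, 0])    # masked-out half is black
```

Boolean indexing on the first two axes broadcasts across the channel axis, so a single `arr[~mask] = 0` zeroes all three channels of each excluded pixel.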


@@ -0,0 +1,121 @@
"""Tests for scene scoring algorithm in multi-object retrieval."""
import pytest
from benchmarks.tasks.multi_object_retrieval import _compute_scene_score
class TestSceneScoringAlgorithm:
"""Test suite for scene scoring with co-occurrence penalty."""
def test_scene_score_basic(self):
"""Test basic scene scoring with single match."""
query_object_ids = ["obj_1", "obj_2", "obj_3"]
# Scene A has obj_1
retrieved_results = {
"scene_A": [("distance_1", "obj_1")],
}
scores = _compute_scene_score(query_object_ids, retrieved_results, gamma=1.0)
# Hit rate = 1/3, similarity = 1/(1+distance_1)
assert "scene_A" in scores
assert scores["scene_A"] > 0
def test_scene_score_no_match(self):
"""Test scene scoring when no objects match."""
query_object_ids = ["obj_1", "obj_2", "obj_3"]
retrieved_results = {
"scene_A": [("distance_1", "other_obj")],
}
scores = _compute_scene_score(query_object_ids, retrieved_results, gamma=1.0)
assert scores["scene_A"] == 0.0
def test_scene_score_multiple_scenes(self):
"""Test scoring across multiple scenes."""
query_object_ids = ["obj_1", "obj_2"]
retrieved_results = {
"scene_A": [("0.1", "obj_1")],
"scene_B": [("0.1", "obj_2")],
"scene_C": [("0.1", "other")],
}
scores = _compute_scene_score(query_object_ids, retrieved_results, gamma=1.0)
# Scenes with matches should have positive scores
assert scores["scene_A"] > 0
assert scores["scene_B"] > 0
# Scene C has no match, score should be 0
assert scores["scene_C"] == 0.0
def test_scene_score_gamma_zero(self):
"""Test scoring with gamma=0 (no penalty)."""
query_object_ids = ["obj_1", "obj_2", "obj_3", "obj_4", "obj_5"]
retrieved_results = {
"scene_A": [("0.1", "obj_1")],
}
scores_gamma_0 = _compute_scene_score(query_object_ids, retrieved_results, gamma=0.0)
scores_gamma_1 = _compute_scene_score(query_object_ids, retrieved_results, gamma=1.0)
# With gamma=0, hit_rate^0 = 1, so score = similarity
# With gamma=1, hit_rate^1 = 1/5, so score = similarity * 1/5
# scores_gamma_0 should be larger
assert scores_gamma_0["scene_A"] > scores_gamma_1["scene_A"]
def test_scene_score_multiple_matches(self):
"""Test scoring when scene has multiple matching objects."""
query_object_ids = ["obj_1", "obj_2"]
retrieved_results = {
"scene_A": [("0.1", "obj_1"), ("0.2", "obj_2")],
}
scores = _compute_scene_score(query_object_ids, retrieved_results, gamma=1.0)
# Both objects match, hit_rate = 2/2 = 1.0
# Score = (1/(1+0.1) + 1/(1+0.2)) * 1.0
expected_similarity = 1.0 / 1.1 + 1.0 / 1.2
assert abs(scores["scene_A"] - expected_similarity) < 0.01
def test_scene_score_distance_to_similarity(self):
"""Test that smaller distance yields higher score."""
query_object_ids = ["obj_1"]
retrieved_results = {
"scene_close": [("0.01", "obj_1")],
"scene_far": [("10.0", "obj_1")],
}
scores = _compute_scene_score(query_object_ids, retrieved_results, gamma=1.0)
# Closer scene should have higher score
assert scores["scene_close"] > scores["scene_far"]
def test_scene_score_empty_results(self):
"""Test scoring with empty retrieved results."""
query_object_ids = ["obj_1", "obj_2"]
retrieved_results = {}
scores = _compute_scene_score(query_object_ids, retrieved_results, gamma=1.0)
assert scores == {}
def test_scene_score_empty_query(self):
"""Test scoring with empty query objects."""
query_object_ids = []
retrieved_results = {
"scene_A": [("0.1", "obj_1")],
}
scores = _compute_scene_score(query_object_ids, retrieved_results, gamma=1.0)
# With empty query, no scenes should have positive score
assert all(score == 0.0 for score in scores.values())
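The scoring rule these tests describe can be sketched as follows: each scene accumulates similarity `1 / (1 + distance)` over retrieved objects that belong to the query set, then the sum is scaled by `hit_rate ** gamma`, where `hit_rate` is the fraction of query objects the scene matched. This is a sketch consistent with the tests above, assuming numeric distances; the project's `_compute_scene_score` may differ in details:

```python
def compute_scene_score(query_object_ids, retrieved_results, gamma=1.0):
    # retrieved_results: {scene_id: [(distance, object_id), ...]}
    query_set = set(query_object_ids)
    scores = {}
    for scene_id, hits in retrieved_results.items():
        matched = [(d, o) for d, o in hits if o in query_set]
        if not matched or not query_set:
            scores[scene_id] = 0.0
            continue
        # Distance -> similarity: closer objects contribute more
        similarity = sum(1.0 / (1.0 + float(d)) for d, _ in matched)
        # Co-occurrence penalty: scenes matching few query objects are damped
        hit_rate = len({o for _, o in matched}) / len(query_set)
        scores[scene_id] = similarity * hit_rate ** gamma
    return scores

scores = compute_scene_score(
    ["obj_1", "obj_2"], {"scene_A": [(0.1, "obj_1"), (0.2, "obj_2")]}, gamma=1.0
)
assert abs(scores["scene_A"] - (1 / 1.1 + 1 / 1.2)) < 0.01
```

With `gamma=0` the penalty term becomes 1 and scenes are ranked by raw similarity alone; larger `gamma` increasingly favors scenes that contain many of the query objects at once.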

mini-nav/utils/sam.py Normal file

@@ -0,0 +1,100 @@
"""SAM (Segment Anything Model) utilities for object segmentation."""
from pathlib import Path
from typing import Any
import numpy as np
import torch
from PIL import Image
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
def load_sam_model(
model_name: str = "facebook/sam2.1-hiera-large",
device: str = "cuda",
checkpoint_dir: Path | None = None,
) -> tuple[Any, Any]:
"""Load SAM 2.1 model and mask generator.
Args:
model_name: SAM model name (currently supports facebook/sam2.1-hiera-*).
device: Device to load model on (cuda or cpu).
checkpoint_dir: Optional directory for model checkpoint cache.
Returns:
Tuple of (sam_model, mask_generator).
"""
if device == "cuda" and not torch.cuda.is_available():
device = "cpu"
# Build SAM2 model
sam_model = build_sam2(model_name, device=device)
# Create automatic mask generator
mask_generator = SAM2AutomaticMaskGenerator(sam_model)
return sam_model, mask_generator
def segment_image(
mask_generator: Any,
image: Image.Image,
min_area: int = 32 * 32,
max_masks: int = 5,
) -> list[dict[str, Any]]:
"""Segment image using SAM to extract object masks.
Args:
mask_generator: SAM2AutomaticMaskGenerator instance.
image: PIL Image to segment.
min_area: Minimum mask area threshold in pixels.
max_masks: Maximum number of masks to return.
Returns:
List of mask dictionaries with keys:
- segment: Binary mask (numpy array)
- area: Mask area in pixels
- bbox: Bounding box [x, y, width, height]
- predicted_iou: Model's confidence in the mask
- stability_score: Stability score for the mask
"""
# Convert PIL Image to numpy array
image_np = np.array(image.convert("RGB"))
# Generate masks
masks = mask_generator.generate(image_np)
if not masks:
return []
# Filter by minimum area
filtered_masks = [m for m in masks if m["area"] >= min_area]
if not filtered_masks:
return []
# Sort by area (largest first) and limit to max_masks
sorted_masks = sorted(filtered_masks, key=lambda x: x["area"], reverse=True)
return sorted_masks[:max_masks]
def extract_masked_region(
image: Image.Image,
mask: np.ndarray,
) -> Image.Image:
"""Extract masked region from image.
Args:
image: Original PIL Image.
mask: Binary mask as numpy array (True = keep).
Returns:
PIL Image with only the masked region visible.
"""
image_np = np.array(image.convert("RGB"))
# Apply mask
masked_np = image_np * mask[:, :, np.newaxis]
return Image.fromarray(masked_np.astype(np.uint8))
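The post-processing step inside `segment_image` (drop masks below `min_area`, sort the survivors largest-first, keep at most `max_masks`) can be exercised without SAM installed, since it only touches the `area` field of the mask dicts. A sketch mirroring that selection logic on plain dictionaries (the helper name `select_masks` is illustrative, not part of the module):

```python
def select_masks(masks, min_area=32 * 32, max_masks=5):
    """Mirror segment_image's mask selection: filter small masks,
    sort by area descending, cap the count at max_masks."""
    kept = [m for m in masks if m["area"] >= min_area]
    kept.sort(key=lambda m: m["area"], reverse=True)
    return kept[:max_masks]


# Fake SAM output: only the area field matters for selection.
masks = [{"area": a} for a in (100, 5000, 2048, 900, 40000, 1500, 1024)]
top = select_masks(masks)
print([m["area"] for m in top])  # → [40000, 5000, 2048, 1500, 1024]
```

Masks of 100 and 900 pixels fall below the default `32 * 32 = 1024` threshold and are discarded; the rest come back largest-first.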

pyproject.toml

@@ -3,13 +3,16 @@ name = "mini-nav"
 version = "0.1.0"
 description = "Add your description here"
 readme = "README.md"
-requires-python = ">=3.10"
+requires-python = ">=3.13"
 dependencies = [
     "accelerate>=1.12.0",
     "dash>=3.4.0",
     "dash-ag-grid>=33.3.3",
     "dash-mantine-components>=2.5.1",
     "datasets>=4.5.0",
+    "habitat-baselines>=0.3.320250127",
+    "habitat-lab>=0.3.320250127",
+    "habitat-sim-uv-wheels-experimental>=0.2.2a2",
     "httpx[socks]>=0.28.1",
     "lancedb>=0.27.1",
     "polars[database,numpy,pandas,pydantic]>=1.37.1",

uv.lock generated

File diff suppressed because it is too large.