Mirror of https://github.com/SikongJueluo/Mini-Nav.git
Synced: 2026-03-17 22:45:32 +08:00
Compare commits: e832f9d656...main — 7 commits

- 431f6844ef
- 34235c605d
- b39ee74e99
- 2466ab28cd
- 4da08dc3d3
- c8dc5f9301
- bf02a05ffc
.gitignore
vendored
6
.gitignore
vendored
@@ -211,9 +211,11 @@ datasets/
|
||||
data/
|
||||
deps/
|
||||
outputs/
|
||||
|
||||
# Vibe Coding
|
||||
.sisyphus
|
||||
.claude/
|
||||
CLAUDE.md
|
||||
.claude/settings.local.json
|
||||
openspec/changes/
|
||||
|
||||
# Devenv
|
||||
.devenv*
|
||||
|
||||
.justfile (deleted, 13 lines)

```
@@ -1,13 +0,0 @@
activate:
    micromamba activate ./.venv

update-venv:
    micromamba env export --no-builds | grep -v "prefix" > venv.yaml

download-test:
    python -m habitat_sim.utils.datasets_download --uids habitat_test_scenes --data-path data/
    python -m habitat_sim.utils.datasets_download --uids habitat_test_pointnav_dataset --data-path data/
    python -m habitat_sim.utils.datasets_download --uids replica_cad_dataset --data-path data/
    python -m habitat_sim.utils.datasets_download --uids rearrange_dataset_v2 --data-path data/
    python -m habitat_sim.utils.datasets_download --uids hab_fetch --data-path data/
    python -m habitat_sim.utils.datasets_download --uids ycb --data-path data/
```
```diff
@@ -1 +1 @@
-3.10
+3.13
```
AGENTS.md (new file, 104 lines)

# Memorix — Automatic Memory Rules

You have access to Memorix memory tools. Follow these rules to maintain persistent context across sessions.

## RULE 1: Session Start — Load Context

At the **beginning of every conversation**, BEFORE responding to the user:

1. Call `memorix_session_start` to get the previous session summary and key memories (this is a direct read, not a search — no fragmentation risk)
2. Then call `memorix_search` with a query related to the user's first message for additional context
3. If search results are found, use `memorix_detail` to fetch the most relevant ones
4. Reference relevant memories naturally — the user should feel you "remember" them

## RULE 2: Store Important Context

**Proactively** call `memorix_store` when any of the following happen:

### What MUST be recorded:
- Architecture/design decisions → type: `decision`
- Bug identified and fixed → type: `problem-solution`
- Unexpected behavior or gotcha → type: `gotcha`
- Config changed (env vars, ports, deps) → type: `what-changed`
- Feature completed or milestone → type: `what-changed`
- Trade-off discussed with conclusion → type: `trade-off`

### What should NOT be recorded:
- Simple file reads, greetings, trivial commands (ls, pwd, git status)

### Use topicKey for evolving topics:
For decisions, architecture docs, or any topic that evolves over time, ALWAYS use the `topicKey` parameter.
This ensures the memory is UPDATED instead of creating duplicates.
Use `memorix_suggest_topic_key` to generate a stable key.

Example: `topicKey: "architecture/auth-model"` — subsequent stores with the same key update the existing memory.

### Track progress with the progress parameter:
When working on features or tasks, include the `progress` parameter:

```json
{
  "progress": {
    "feature": "user authentication",
    "status": "in-progress",
    "completion": 60
  }
}
```

Status values: `in-progress`, `completed`, `blocked`

## RULE 3: Resolve Completed Memories

When a task is completed, a bug is fixed, or information becomes outdated:

1. Call `memorix_resolve` with the observation IDs to mark them as resolved
2. Resolved memories are hidden from default search, preventing context pollution

This is critical — without resolving, old bug reports and completed tasks will keep appearing in future searches.

## RULE 4: Session End — Store Decision Chain Summary

When the conversation is ending, create a **decision chain summary** (not just a checklist):

1. Call `memorix_store` with type `session-request` and `topicKey: "session/latest-summary"`:

**Required structure:**
```
## Goal
[What we were working on — specific, not vague]

## Key Decisions & Reasoning
- Chose X because Y. Rejected Z because [reason].
- [Every architectural/design decision with WHY]

## What Changed
- [File path] — [what changed and why]

## Current State
- [What works now, what's pending]
- [Any blockers or risks]

## Next Steps
- [Concrete next actions, in priority order]
```

**Critical: Include the "Key Decisions & Reasoning" section.** Without it, the next AI session will lack the context to understand WHY things were done a certain way and may suggest conflicting approaches.

2. Call `memorix_resolve` on any memories for tasks completed in this session

## RULE 5: Compact Awareness

Memorix automatically compacts memories on store:
- **With LLM API configured:** Smart dedup — extracts facts, compares with existing, merges or skips duplicates
- **Without LLM (free mode):** Heuristic dedup — uses similarity scores to detect and merge duplicate memories
- **You don't need to manually deduplicate.** Just store naturally and compaction handles the rest.
- If you notice excessive duplicate memories, call `memorix_deduplicate` for batch cleanup.

## Guidelines

- **Use concise titles** (~5-10 words) and structured facts
- **Include file paths** in filesModified when relevant
- **Include related concepts** for better searchability
- **Always use topicKey** for recurring topics to prevent duplicates
- **Always resolve** completed tasks and fixed bugs
- **Always include reasoning** — "chose X because Y" is 10x more valuable than "did X"
- Search defaults to `status="active"` — use `status="all"` to include resolved memories
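The Rule 1 session-start flow can be sketched in Python. The `memorix_*` functions below are hypothetical stand-ins for the MCP tools (their real signatures are not documented here); only the orchestration order is taken from the rule:

```python
from typing import Any


# Hypothetical stand-ins for the Memorix MCP tools; real signatures may differ.
def memorix_session_start() -> dict[str, Any]:
    return {"summary": "previous session summary", "key_memories": ["mem-1"]}


def memorix_search(query: str) -> list[dict[str, Any]]:
    return [{"id": "mem-2", "title": f"result for {query!r}"}]


def memorix_detail(memory_id: str) -> dict[str, Any]:
    return {"id": memory_id, "facts": ["detail text"]}


def load_session_context(first_user_message: str) -> dict[str, Any]:
    """Orchestrate Rule 1: direct read first, then targeted search, then detail fetch."""
    context = memorix_session_start()  # step 1: direct read, no fragmentation risk
    results = memorix_search(first_user_message)  # step 2: query from the user's first message
    details = [memorix_detail(r["id"]) for r in results]  # step 3: fetch the relevant details
    return {"session": context, "details": details}


print(load_session_context("fix the LanceDB schema mismatch"))
```

The point is the ordering: the direct session read happens before any search, so the summary is never lost to search fragmentation.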
CLAUDE.md (new file, 125 lines)

# Project Spec & Rules

## Code Conventions

### Google-Style Code
See: https://raw.githubusercontent.com/shendeguize/GooglePythonStyleGuideCN/refs/heads/master/README.md

### Coding Principles
- Concise, clear, minimal implementations
- No more than three levels of conditional or loop nesting; return early to reduce branching
- Fully comment variables and conditional/loop branches
- No backward compatibility required; avoid adding excess features
- Write the test suite first, then implement the code
- After writing the test suite, ask the user for feedback; continue only after confirmation
- Do not write benchmark code unless the user asks for it
- Comments in English, documentation in Chinese
- After finishing the code, update the documentation (e.g. CLAUDE.md) without changing its structure

### Testing Principles
- Lean, clean, fast
- Core logic and algorithms must be tested
- Separate tests that need a loaded transformer model from tests that do not
- Cases that do not need tests:
  - UI-related code
  - Overly complex or time-consuming logic
  - Benchmark-related code

### Keywords
- "确认" (confirm): the user approves the current implementation plan or test suite; work may begin
- "继续" (continue): the user wants you to re-read the context and continue unfinished work

### Documentation Updates
Update the directory section of this document only when the project layout changes.
To modify any other section, ask first, then make the change.

## Project Overview
The project is managed with UV; pytest for testing, justfile for shortcut commands, Jujutsu for version control.

### Directory Layout

**Core modules**
- mini-nav/main.py — CLI entry point (Typer)
- mini-nav/database.py — LanceDB singleton manager for vector storage and retrieval
- mini-nav/feature_retrieval.py — DINOv2 image feature extraction and retrieval

**Source directory (mini-nav/)**
- mini-nav/configs/ — configuration management (Pydantic + YAML)
- mini-nav/commands/ — CLI commands (train, benchmark, visualize, generate)
- mini-nav/compressors/ — feature compression algorithms
  - hash_compressor.py — hash compressor and training loss
  - pipeline.py — compression pipeline (DINO feature extraction inlined)
  - train.py — compressor training script
- mini-nav/data_loading/ — data loading and synthesis
  - loader.py — data loader
  - insdet_scenes.py — InsDet scene dataset loading
  - synthesizer.py — scene synthesizer
- mini-nav/utils/ — utilities
  - feature_extractor.py — feature extraction helpers
  - sam.py — SAM 2.1 segmentation helpers
- mini-nav/tests/ — pytest test suite
- mini-nav/benchmarks/ — benchmarks (recall@k)
  - tasks/
    - multi_object_retrieval.py — multi-object retrieval benchmark task
- mini-nav/visualizer/ — Dash + Plotly visualization app

**Data directories**
- datasets/ — dataset directory
- outputs/ — default output directory (databases, model weights, etc.)

### Python Libraries
See pyproject.toml or run `uv pip list` for details; implement features with the currently installed libraries.
To add a new library, ask first; only after user confirmation may you run `uv add <package>`.

## Version Control (Jujutsu-Specific)
This project uses Jujutsu (jj) for version control, with Memorix MCP as the persistence hub for architecture decisions and reasoning traces.

- Tool skills: branch, commit, and describe operations must go through the jujutsu tool skills; do not run long Git-compatible commands directly through the shell.
- Description conventions (jj desc):
  - The first line of a jj desc must be a concise change title.
  - After one blank line, record only the core business points of the change.
  - Write descriptions in English.
  - Forbidden: do not pack complex algorithm logic or long design decisions into jj descriptions.
- Memory linkage (Memorix first):
  - For any architecture change, algorithm decision, or refactoring logic, call memorix_store (or the corresponding add method) BEFORE running jj desc.
  - Linking: the Memorix record must include the current jj change ID, so each code change maps back to its chain of reasoning.
  - Retrieval: for tasks that require deep context, proactively call memorix_search for decisions tied to past change_ids.
- Invisible-record principle:
  - Never generate standalone change_log.md files or AI-authored docs in the project tree.
  - All "why was this changed" knowledge flows into jj's atomic commit descriptions or the Memorix knowledge graph.

### Description Example
```text
refactor(compressors): Simplify module by removing SAM/DINO separation code

- Remove dino_compressor.py and segament_compressor.py
- Rewrite pipeline.py to inline DINO into HashPipeline
- Maintain backward compatibility: SAMHashPipeline alias
- Update tests and benchmark.py
```

### Commit Steps
- Run `jj diff --no-pager` to get all current changes
- Summarize the changes against the related openspec documents, focusing on what changed and the decision logic behind it
- Call the memory feature (e.g. Memorix) to record that summary
- Describe the change with jj, following the description conventions
- Run `jj new` to start a new change

## Memory Management (Memorix MCP)
This project uses Memorix as its core context engine for storing architecture decisions, complex logical links, and historical refactoring rationale.

### Memory Write Rules
- Record proactively: call `memorix.store` after:
  - Core architecture changes confirmed by the user (e.g. the LanceDB indexing strategy).
  - Complex bug-fix logic (record WHY it was fixed that way, to prevent regressions).
  - Explicit user preferences stated in conversation (e.g. dislike of a specific Python library).
  - Code changes and their decision logic (e.g. changes driven by a specific user requirement).
- Structured storage: use the format `[Category: Topic] Description` for efficient retrieval.

### Memory Retrieval Rules
- Cold-start retrieval: at the start of each new conversation or when switching to a new task, call `memorix.search` with keywords (e.g. "project_architecture", "database_schema") first, to stay aligned with the existing design.
- Hallucination guard: if unsure about how an old feature is implemented, search memory first; never guess.

### Memory and Redundancy Control
- Keep it brief: information stored in Memorix must be concise; never store whole code blocks, only the "logic description" and the "decision rationale".
- Cleanup: when stored memories conflict with the current code, proactively prompt the user to update or overwrite them.
mini-nav/benchmarks/tasks/multi_object_retrieval.py (new file, 356 lines)

```python
"""Multi-object retrieval benchmark task.

This benchmark evaluates retrieval accuracy using multiple objects from a cropped
scene region. It uses SAM for object segmentation, the DINO+Hash pipeline for feature
extraction, and LanceDB for vector storage with scene-level score aggregation.
"""

import random
from typing import Any

import lancedb
import numpy as np
import pyarrow as pa
from benchmarks.base import BaseBenchmarkTask
from benchmarks.tasks.registry import RegisterTask
from configs.models import BenchmarkTaskConfig
from PIL import Image
from rich.progress import track
from torch import nn
from transformers import BitImageProcessorFast
from utils.common import get_device
from utils.feature_extractor import extract_single_image_feature
from utils.sam import load_sam_model, segment_image


def _build_object_schema(vector_dim: int) -> pa.Schema:
    """Build PyArrow schema for object-level vectors.

    Args:
        vector_dim: Feature vector dimension.

    Returns:
        PyArrow schema with id, image_id, object_id, category, and vector fields.
    """
    return pa.schema(
        [
            pa.field("id", pa.int32()),
            pa.field("image_id", pa.string()),
            pa.field("object_id", pa.string()),
            pa.field("category", pa.string()),
            pa.field("vector", pa.list_(pa.float32(), vector_dim)),
        ]
    )


def _compute_scene_score(
    query_object_ids: list[str],
    retrieved_results: dict[str, list[tuple[float, str]]],
    gamma: float,
) -> dict[str, float]:
    """Compute scene-level scores using a co-occurrence penalty.

    Args:
        query_object_ids: List of query object IDs.
        retrieved_results: Dict mapping image_id to list of (distance, object_id) results.
        gamma: Co-occurrence penalty exponent.

    Returns:
        Dict mapping image_id to computed scene score.
    """
    scene_scores: dict[str, float] = {}

    for image_id, results in retrieved_results.items():
        # Build a set of retrieved object IDs for this scene
        retrieved_ids = {obj_id for _, obj_id in results}

        # Count how many query objects are found in this scene
        matched_count = sum(1 for q_id in query_object_ids if q_id in retrieved_ids)

        if matched_count == 0:
            scene_scores[image_id] = 0.0
            continue

        # Sum of similarities over matched results; smaller distance = better,
        # so convert distance to similarity via 1 / (1 + distance)
        similarities = [
            1.0 / (1.0 + dist) for dist, obj_id in results if obj_id in query_object_ids
        ]
        sum_similarity = sum(similarities)

        # Hit rate: ratio of matched query objects
        hit_rate = matched_count / len(query_object_ids)

        # Final score: sum_similarity * hit_rate^gamma
        scene_scores[image_id] = sum_similarity * (hit_rate**gamma)

    return scene_scores


@RegisterTask("multi-object-retrieval")
class MultiObjectRetrievalTask(BaseBenchmarkTask):
    """Multi-object retrieval benchmark task."""

    def __init__(self, **kwargs: Any):
        """Initialize multi-object retrieval task.

        Args:
            **kwargs: Configuration parameters from BenchmarkTaskConfig.
        """
        # Use config from kwargs or load default config
        if kwargs:
            config_dict = kwargs
        else:
            config = BenchmarkTaskConfig(type="multi-object-retrieval")
            config_dict = config.model_dump()

        super().__init__(**config_dict)
        self.config = BenchmarkTaskConfig(**config_dict)

        # SAM settings from ModelConfig (passed via kwargs or use defaults).
        # Stored as sam_model_name so it does not shadow the sam_model property below.
        self.sam_model_name = kwargs.get("sam_model", "facebook/sam2.1-hiera-large")
        self.min_mask_area = kwargs.get("sam_min_mask_area", 32 * 32)
        self.max_masks_per_image = kwargs.get("sam_max_masks", 5)

        # Lazy-loaded resources
        self._sam_model = None
        self._mask_generator = None

    @property
    def sam_model(self) -> Any:
        """Lazy-load SAM model."""
        if self._sam_model is None:
            self._sam_model, self._mask_generator = load_sam_model(
                model_name=self.sam_model_name,
                device=str(get_device()),
            )
        return self._sam_model

    @property
    def mask_generator(self) -> Any:
        """Lazy-load mask generator."""
        if self._mask_generator is None:
            self._sam_model, self._mask_generator = load_sam_model(
                model_name=self.sam_model_name,
                device=str(get_device()),
            )
        return self._mask_generator

    def build_database(
        self,
        model: nn.Module,
        processor: BitImageProcessorFast,
        train_dataset: Any,
        table: lancedb.table.Table,
        batch_size: int,
    ) -> None:
        """Build the evaluation database with object-level vectors.

        Args:
            model: Feature extraction model.
            processor: Image preprocessor.
            train_dataset: Training dataset.
            table: LanceDB table to store features.
            batch_size: Batch size for DataLoader.
        """
        # Infer vector dimension from a sample
        sample = train_dataset[0]
        sample_image = sample["image"]

        # Get vector dimension by running a forward pass
        vector_dim = self._infer_vector_dim(processor, model, sample_image)
        expected_schema = _build_object_schema(vector_dim)

        # Check schema compatibility
        if table.schema != expected_schema:
            raise ValueError(
                f"Table schema mismatch. Expected: {expected_schema}, "
                f"Got: {table.schema}"
            )

        # Build database: segment each image, extract features per object
        record_id = 0
        records = []

        for idx in track(range(len(train_dataset)), description="Building object database"):
            item = train_dataset[idx]
            image = item["image"]
            image_id = item.get("image_id", f"image_{idx}")

            # Segment image using SAM
            masks = segment_image(
                self.mask_generator,
                image,
                min_area=self.min_mask_area,
                max_masks=self.max_masks_per_image,
            )

            if not masks:
                continue

            # Extract features for each mask
            for mask_idx, mask_info in enumerate(masks):
                # Extract masked region
                masked_image = self._apply_mask(image, mask_info["segment"])

                # Extract feature vector
                vector = extract_single_image_feature(processor, model, masked_image)

                # Create object ID
                object_id = f"{image_id}_obj_{mask_idx}"
                category = mask_info.get("category", "unknown")

                records.append({
                    "id": record_id,
                    "image_id": image_id,
                    "object_id": object_id,
                    "category": category,
                    "vector": vector,
                })
                record_id += 1

        # Add all records to table
        if records:
            table.add(records)

    def evaluate(
        self,
        model: nn.Module,
        processor: BitImageProcessorFast,
        test_dataset: Any,
        table: lancedb.table.Table,
        batch_size: int,
    ) -> dict[str, Any]:
        """Evaluate the model on the test dataset.

        Args:
            model: Feature extraction model.
            processor: Image preprocessor.
            test_dataset: Test dataset.
            table: LanceDB table to search against.
            batch_size: Batch size for DataLoader.

        Returns:
            Dictionary containing evaluation results with keys:
                - accuracy: Recall@K accuracy (0.0 ~ 1.0)
                - correct: Number of correct predictions
                - total: Total number of test samples
                - top_k: The K value used
        """
        top_k = self.config.top_k_per_object

        correct = 0
        total = 0

        for idx in track(range(len(test_dataset)), description=f"Evaluating Recall@{top_k}"):
            item = test_dataset[idx]
            image = item["image"]
            target_image_id = item.get("image_id", f"image_{idx}")

            # Segment query image
            masks = segment_image(
                self.mask_generator,
                image,
                min_area=self.min_mask_area,
                max_masks=self.max_masks_per_image,
            )

            if not masks:
                continue

            # Randomly sample query objects
            num_query = min(self.config.num_query_objects, len(masks))
            query_masks = random.sample(masks, num_query)

            # Extract features and search for each query object
            retrieved_results: dict[str, list[tuple[float, str]]] = {}

            for mask_info in query_masks:
                # Extract masked region
                masked_image = self._apply_mask(image, mask_info["segment"])

                # Extract feature vector
                vector = extract_single_image_feature(processor, model, masked_image)

                # Search in LanceDB
                results = (
                    table.search(vector)
                    .select(["image_id", "object_id", "_distance"])
                    .limit(top_k)
                    .to_polars()
                )

                # Aggregate results by scene (named=True yields dict rows)
                for row in results.iter_rows(named=True):
                    image_id = row["image_id"]
                    object_id = row["object_id"]
                    distance = row["_distance"]

                    if image_id not in retrieved_results:
                        retrieved_results[image_id] = []
                    retrieved_results[image_id].append((distance, object_id))

            # Compute scene scores
            query_object_ids = [
                m.get("object_id", f"query_obj_{i}") for i, m in enumerate(query_masks)
            ]
            scene_scores = _compute_scene_score(
                query_object_ids,
                retrieved_results,
                self.config.gamma,
            )

            # Rank scenes by score
            ranked_scenes = sorted(scene_scores.items(), key=lambda x: x[1], reverse=True)

            # Check if target is in top-K
            top_k_scenes = [scene_id for scene_id, _ in ranked_scenes[:top_k]]
            if target_image_id in top_k_scenes:
                correct += 1
            total += 1

        accuracy = correct / total if total > 0 else 0.0

        return {
            "accuracy": accuracy,
            "correct": correct,
            "total": total,
            "top_k": top_k,
        }

    def _infer_vector_dim(
        self,
        processor: BitImageProcessorFast,
        model: nn.Module,
        sample_image: Any,
    ) -> int:
        """Infer vector dimension from model output."""
        vector = extract_single_image_feature(processor, model, sample_image)
        return len(vector)

    def _apply_mask(self, image: Any, mask: np.ndarray) -> Any:
        """Apply mask to image and return the masked image.

        Args:
            image: PIL Image.
            mask: Binary mask as numpy array.

        Returns:
            Masked PIL Image.
        """
        image_np = np.array(image.convert("RGB"))

        # Ensure mask matches the image's spatial shape
        if mask.shape != image_np.shape[:2]:
            from skimage.transform import resize  # lazy import: optional dependency

            mask_resized = resize(mask, image_np.shape[:2], order=0, anti_aliasing=False)
        else:
            mask_resized = mask

        # Zero out pixels outside the mask
        masked_np = image_np * mask_resized[:, :, np.newaxis]
        return Image.fromarray(masked_np.astype(np.uint8))
```
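The scene-scoring rule above (similarity 1/(1+d) summed over matched query objects, weighted by hit_rate^gamma) can be checked on toy data. This standalone sketch re-implements the formula rather than importing the repo module:

```python
def scene_score(query_ids: list[str], results: list[tuple[float, str]], gamma: float) -> float:
    """Score one scene: sum of 1/(1+d) over matched objects, times hit_rate^gamma."""
    matched = {oid for _, oid in results if oid in query_ids}
    if not matched:
        return 0.0
    sim = sum(1.0 / (1.0 + d) for d, oid in results if oid in query_ids)
    hit_rate = len(matched) / len(query_ids)
    return sim * hit_rate**gamma


# Scene A matches both query objects at distance 1.0 each; scene B matches one at distance 0.0.
queries = ["q1", "q2"]
scene_a = [(1.0, "q1"), (1.0, "q2")]
scene_b = [(0.0, "q1"), (3.0, "other")]
print(scene_score(queries, scene_a, gamma=2.0))  # 0.5 + 0.5 = 1.0, hit_rate 1.0 -> 1.0
print(scene_score(queries, scene_b, gamma=2.0))  # sim 1.0 * (0.5 ** 2) -> 0.25
```

Note how gamma > 1 penalizes scenes that match only a subset of the query objects: scene B has a closer individual match, yet ranks below scene A.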
```diff
@@ -1,4 +1,4 @@
-from typing import cast
+from typing import Any, Optional, cast
 
 import typer
 from commands import app
@@ -7,15 +7,15 @@ from commands import app
 @app.command()
 def benchmark(
     ctx: typer.Context,
-    model_path: str = typer.Option(
+    model_path: Optional[str] = typer.Option(
         None, "--model", "-m", help="Path to compressor model weights"
     ),
 ):
     import torch
     import torch.nn.functional as F
     from benchmarks import run_benchmark
     from compressors import DinoCompressor
     from configs import cfg_manager
-    from transformers import AutoImageProcessor, BitImageProcessorFast
+    from transformers import AutoImageProcessor, AutoModel, BitImageProcessorFast
     from utils import get_device
 
     config = cfg_manager.get()
@@ -29,7 +29,12 @@ def benchmark(
         AutoImageProcessor.from_pretrained(model_cfg.dino_model, device_map=device),
     )
 
-    model = DinoCompressor().to(device)
+    # Load DINO model for feature extraction
+    dino = AutoModel.from_pretrained(model_cfg.dino_model, device_map=device)
+    dino.eval()
+
+    # Optional hash compressor
+    compressor = None
     if model_path:
         from compressors import HashCompressor
 
@@ -38,7 +43,31 @@ def benchmark(
             hash_bits=model_cfg.compression_dim,
         )
         compressor.load_state_dict(torch.load(model_path))
-        model.compressor = compressor
+        compressor.to(device)
+        compressor.eval()
+
+    # Create wrapper with extract_features method
+    class DinoFeatureExtractor:
+        def __init__(self, dino, compressor=None):
+            self.dino = dino
+            self.compressor = compressor
+
+        def extract_features(self, images: list) -> torch.Tensor:
+            inputs = processor(images, return_tensors="pt").to(device)
+            with torch.no_grad():
+                outputs = self.dino(**inputs)
+            features = outputs.last_hidden_state.mean(dim=1)
+            features = F.normalize(features, dim=-1)
+            return features
+
+        def encode(self, images: list) -> torch.Tensor:
+            if self.compressor is None:
+                return self.extract_features(images)
+            tokens = self.dino(**processor(images, return_tensors="pt").to(device)).last_hidden_state
+            _, _, bits = self.compressor(tokens)
+            return bits
+
+    model = DinoFeatureExtractor(dino, compressor)
 
     run_benchmark(
         model=model,
```
```diff
@@ -1,18 +1,15 @@
 from .common import BinarySign, bits_to_hash, hamming_distance, hamming_similarity, hash_to_bits
-from .dino_compressor import DinoCompressor
 from .hash_compressor import HashCompressor, HashLoss, VideoPositiveMask
-from .pipeline import SAMHashPipeline, create_pipeline_from_config
-from .segament_compressor import SegmentCompressor
+from .pipeline import HashPipeline, SAMHashPipeline, create_pipeline_from_config
 from .train import train
 
 __all__ = [
     "train",
-    "DinoCompressor",
     "HashCompressor",
     "HashLoss",
     "VideoPositiveMask",
-    "SegmentCompressor",
-    "SAMHashPipeline",
+    "HashPipeline",
+    "SAMHashPipeline",  # Backward compatibility alias
     "create_pipeline_from_config",
     "BinarySign",
     "hamming_distance",
```
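The module exports `hamming_distance` and `hamming_similarity` for comparing binary hash codes. Independent of the repo's implementation (whose exact signatures are not shown here), the idea can be sketched in plain Python with bit vectors as lists of 0/1:

```python
def hamming_distance(a: list[int], b: list[int]) -> int:
    """Number of bit positions where the two codes differ."""
    assert len(a) == len(b), "hash codes must have equal length"
    return sum(x != y for x, y in zip(a, b))


def hamming_similarity(a: list[int], b: list[int]) -> float:
    """Fraction of matching bits, in [0, 1]."""
    return 1.0 - hamming_distance(a, b) / len(a)


code_a = [1, 0, 1, 1, 0, 0, 1, 0]
code_b = [1, 0, 0, 1, 0, 1, 1, 0]
print(hamming_distance(code_a, code_b))    # 2
print(hamming_similarity(code_a, code_b))  # 0.75
```

Hamming comparisons are the reason binary hashes are attractive for retrieval: on packed bit representations they reduce to an XOR plus a popcount, far cheaper than float dot products.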
dino_compressor.py (deleted, 105 lines)

```python
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import AutoImageProcessor, AutoModel


class DinoCompressor(nn.Module):
    """DINOv2 feature extractor with optional hash compression.

    When compressor is None: returns normalized DINO embeddings.
    When compressor is provided: returns binary hash bits for CAM storage.

    Supports both PIL Image input and pre-extracted tokens.
    """

    def __init__(
        self,
        model_name: str = "facebook/dinov2-large",
        compressor: Optional[nn.Module] = None,
        device: Optional[str] = None,
    ):
        """Initialize DINOv2 extractor.

        Args:
            model_name: HuggingFace model name
            compressor: Optional hash compressor for producing binary codes
            device: Device to load model on
        """
        super().__init__()

        # Auto-detect device
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        self.model_name = model_name
        self.processor = AutoImageProcessor.from_pretrained(model_name)
        self.dino = AutoModel.from_pretrained(model_name).to(self.device)
        self.dino.eval()

        self.compressor = compressor

    def forward(self, inputs):
        teacher_tokens = self.dino(**inputs).last_hidden_state  # [B, N, 1024]

        teacher_embed = teacher_tokens.mean(dim=1)
        teacher_embed = F.normalize(teacher_embed, dim=-1)  # [B, 1024]

        if self.compressor is None:
            return teacher_embed

        # HashCompressor returns (logits, hash_codes, bits)
        _, _, bits = self.compressor(teacher_tokens)
        return bits  # [B, 512] binary bits for CAM

    def extract_features(self, images: list[Image.Image]) -> torch.Tensor:
        """Extract DINO features from a list of cropped object images.

        Args:
            images: List of PIL Images (cropped objects)

        Returns:
            DINO features [N, feature_dim], normalized
        """
        if len(images) == 0:
            return torch.empty(0, self.dino.config.hidden_size, device=self.device)

        # Process batch of images
        inputs = self.processor(images, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.dino(**inputs)

        # Pool tokens to get global representation
        features = outputs.last_hidden_state.mean(dim=1)  # [N, 1024]
        features = F.normalize(features, dim=-1)

        return features

    def encode(self, images: list[Image.Image]) -> torch.Tensor:
        """Extract features from images and optionally compress to hash codes.

        Args:
            images: List of PIL Images

        Returns:
            If compressor is None: DINO features [N, 1024]
            If compressor is set: Binary hash bits [N, 512]
        """
        if self.compressor is None:
            return self.extract_features(images)

        # Extract features first
        features = self.extract_features(images)  # [N, 1024]

        # Add sequence dimension for compressor (expects [B, N, dim])
        features = features.unsqueeze(1)  # [N, 1, 1024]

        # Compress to hash codes
        _, _, bits = self.compressor(features)

        return bits
```
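The mean-pool-then-normalize step used throughout (token sequence [batch, n_tokens, dim] averaged over tokens, then L2-normalized) is easy to verify with numpy; this sketch mirrors the math, not the torch API:

```python
import numpy as np


def pool_and_normalize(tokens: np.ndarray) -> np.ndarray:
    """Mean-pool token embeddings [batch, n_tokens, dim] and L2-normalize each row."""
    pooled = tokens.mean(axis=1)                            # [batch, dim]
    norms = np.linalg.norm(pooled, axis=-1, keepdims=True)  # [batch, 1]
    return pooled / np.clip(norms, 1e-12, None)             # unit-length rows


rng = np.random.default_rng(0)
tokens = rng.normal(size=(2, 5, 8))  # 2 images, 5 tokens, 8-dim features
embeddings = pool_and_normalize(tokens)
print(np.linalg.norm(embeddings, axis=-1))  # each row has unit norm
```

Normalizing to unit length makes cosine similarity equivalent to a dot product, which is what makes the downstream distance-based retrieval well behaved.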
@@ -1,79 +1,65 @@
|
||||
"""Complete pipeline for SAM + DINO + HashCompressor.
|
||||
"""Hash compression pipeline with DINO feature extraction.
|
||||
|
||||
This pipeline extracts object masks from images using SAM2.1,
|
||||
crops the objects, extracts features using DINOv2,
|
||||
and compresses them to binary hash codes using HashCompressor.
|
||||
This pipeline extracts features using DINOv2 and compresses them
|
||||
to binary hash codes using HashCompressor.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
|
||||
from .dino_compressor import DinoCompressor
|
||||
from .hash_compressor import HashCompressor
|
||||
from .segament_compressor import SegmentCompressor
|
||||
from transformers import AutoImageProcessor, AutoModel
|
||||
|
||||
|
||||
def create_pipeline_from_config(config) -> "SAMHashPipeline":
|
||||
"""Create SAMHashPipeline from a config object.
|
||||
def create_pipeline_from_config(config) -> "HashPipeline":
|
||||
"""Create HashPipeline from a config object.
|
||||
|
||||
Args:
|
||||
config: Configuration object with model settings
|
||||
|
||||
Returns:
|
||||
Initialized SAMHashPipeline
|
||||
Initialized HashPipeline
|
||||
"""
|
||||
return SAMHashPipeline(
|
||||
sam_model=config.model.sam_model,
|
||||
dino_model=config.model.name,
|
||||
return HashPipeline(
|
||||
dino_model=config.model.dino_model,
|
||||
hash_bits=config.model.compression_dim,
|
||||
sam_min_mask_area=config.model.sam_min_mask_area,
|
||||
sam_max_masks=config.model.sam_max_masks,
|
||||
compressor_path=config.model.compressor_path,
|
||||
device=config.model.device if config.model.device != "auto" else None,
|
||||
)
|
||||
|
||||
|
||||
-class SAMHashPipeline(nn.Module):
-    """Complete pipeline: SAM segmentation + DINO features + Hash compression.
+class HashPipeline(nn.Module):
+    """Pipeline: DINO features + Hash compression.

     Pipeline flow:
-        Image -> SAM (extract masks) -> Crop objects -> DINO (features) -> Hash (binary codes)
+        PIL Image -> DINO (features) -> Hash (binary codes)

     Usage:
         # Initialize with config
-        pipeline = SAMHashPipeline(
-            sam_model="facebook/sam2.1-hiera-large",
+        pipeline = HashPipeline(
             dino_model="facebook/dinov2-large",
             hash_bits=512,
         )

         # Process image
         image = Image.open("path/to/image.jpg")
-        hash_codes = pipeline(image)  # [N, 512] binary bits
+        hash_bits = pipeline(image)  # [1, 512] binary bits
     """

     def __init__(
         self,
-        sam_model: str = "facebook/sam2.1-hiera-large",
         dino_model: str = "facebook/dinov2-large",
         hash_bits: int = 512,
-        sam_min_mask_area: int = 100,
-        sam_max_masks: int = 10,
         compressor_path: Optional[str] = None,
         device: Optional[str] = None,
     ):
-        """Initialize the complete pipeline.
+        """Initialize the pipeline.

         Args:
-            sam_model: SAM model name from HuggingFace
             dino_model: DINOv2 model name from HuggingFace
             hash_bits: Number of bits in hash code
-            sam_min_mask_area: Minimum mask area threshold
-            sam_max_masks: Maximum number of masks to keep
             compressor_path: Optional path to trained HashCompressor weights
             device: Device to run models on
         """
@@ -84,87 +70,101 @@ class SAMHashPipeline(nn.Module):
             device = "cuda" if torch.cuda.is_available() else "cpu"
         self.device = torch.device(device)

-        # Initialize components
-        self.segmentor = SegmentCompressor(
-            model_name=sam_model,
-            min_mask_area=sam_min_mask_area,
-            max_masks=sam_max_masks,
-            device=device,
-        )
+        self.dino_model = dino_model

-        # HashCompressor expects DINO features (1024 dim for dinov2-large)
-        dino_dim = 1024 if "large" in dino_model else 768
-        self.hash_compressor = HashCompressor(
-            input_dim=dino_dim, hash_bits=hash_bits
-        ).to(device)
+        # Initialize DINO processor and model
+        self.processor = AutoImageProcessor.from_pretrained(dino_model)
+        self.dino = AutoModel.from_pretrained(dino_model).to(self.device)
+        self.dino.eval()
+
+        # Determine DINO feature dimension
+        self.dino_dim = 1024 if "large" in dino_model else 768
+
+        # Initialize HashCompressor
+        self.hash_compressor = nn.Module()  # Placeholder, will be replaced
+        self._init_hash_compressor(hash_bits, compressor_path)
+
+    def _init_hash_compressor(
+        self, hash_bits: int, compressor_path: Optional[str] = None
+    ):
+        """Initialize the hash compressor module.
+
+        This is called during __init__ but we need to replace it properly.
+        """
+        # Import here to avoid circular imports
+        from .hash_compressor import HashCompressor
+
+        compressor = HashCompressor(input_dim=self.dino_dim, hash_bits=hash_bits).to(
+            self.device
+        )

         # Load pretrained compressor if provided
         if compressor_path is not None:
-            self.hash_compressor.load_state_dict(
-                torch.load(compressor_path, map_location=device)
+            compressor.load_state_dict(
+                torch.load(compressor_path, map_location=self.device)
             )
             print(f"[OK] Loaded HashCompressor from {compressor_path}")

-        self.dino = DinoCompressor(
-            model_name=dino_model,
-            compressor=self.hash_compressor,
-            device=device,
-        )
+        # Replace the placeholder
+        self.hash_compressor = compressor
+
+    @property
+    def hash_bits(self):
+        """Return the number of hash bits."""
+        return self.hash_compressor.hash_bits
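The HashCompressor itself does not appear in this diff; the pipeline only relies on its forward returning a 3-tuple whose last element is the binary code (`_, _, bits = self.hash_compressor(tokens)`). A minimal sketch of a module with that interface follows; the mean pooling, linear projection, tanh relaxation, and {0, 1} binarization are all assumptions, not the repo's actual implementation:

```python
import torch
import torch.nn as nn


class TinyHashCompressor(nn.Module):
    """Hypothetical stand-in for HashCompressor: pool token features,
    project to hash_bits logits, and binarize with a sign threshold."""

    def __init__(self, input_dim: int = 1024, hash_bits: int = 512):
        super().__init__()
        self.hash_bits = hash_bits
        self.proj = nn.Linear(input_dim, hash_bits)

    def forward(self, tokens: torch.Tensor):
        # tokens: [B, N, input_dim] -> mean-pool to [B, input_dim]
        pooled = tokens.mean(dim=1)
        logits = self.proj(pooled)            # real-valued logits
        codes = torch.tanh(logits)            # differentiable relaxation for training
        bits = (logits > 0).to(torch.int32)   # hard {0, 1} bits for retrieval
        return logits, codes, bits


tokens = torch.randn(2, 257, 1024)  # e.g. DINO patch tokens plus CLS
logits, codes, bits = TinyHashCompressor()(tokens)
print(bits.shape)  # torch.Size([2, 512])
```

Exposing `hash_bits` on the module is what lets the pipeline's `hash_bits` property simply delegate to it.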

     def forward(self, image: Image.Image) -> torch.Tensor:
-        """Process a single image through the complete pipeline.
+        """Process a single image through the pipeline.

         Args:
             image: Input PIL Image

         Returns:
-            Binary hash codes [N, hash_bits] where N is number of detected objects
+            Binary hash codes [1, hash_bits] as int32
         """
-        # Step 1: SAM - extract and crop objects
-        cropped_objects = self.segmentor(image)
+        # Extract DINO features
+        inputs = self.processor(image, return_tensors="pt").to(self.device)

-        if len(cropped_objects) == 0:
-            # No objects detected, return empty tensor
-            return torch.empty(
-                0, self.hash_compressor.hash_bits, dtype=torch.int32, device=self.device
-            )
+        with torch.no_grad():
+            outputs = self.dino(**inputs)
+            tokens = outputs.last_hidden_state  # [1, N, dim]

-        # Step 2: DINO - extract features from cropped objects
-        # Step 3: HashCompressor - compress features to binary codes
-        hash_codes = self.dino.encode(cropped_objects)
+        # Compress to hash codes
+        _, _, bits = self.hash_compressor(tokens)

-        return hash_codes
+        return bits
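Once images are reduced to {0, 1} bit vectors like the ones `forward` returns, retrieval is a Hamming-distance nearest-neighbor lookup. A small illustration; the helper name `hamming_dist` is ours, while the repo exports its own `hamming_distance` whose signature is not shown in this chunk:

```python
import torch


def hamming_dist(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Pairwise Hamming distances between {0, 1} bit matrices [Na, K] and [Nb, K]."""
    return (a.unsqueeze(1) != b.unsqueeze(0)).sum(dim=-1)


query = torch.tensor([[0, 1, 1, 0]])
database = torch.tensor([
    [0, 1, 1, 0],   # exact match
    [1, 1, 1, 1],   # differs in 2 positions
    [0, 0, 0, 0],   # differs in 2 positions
])

dists = hamming_dist(query, database)
print(dists)                # tensor([[0, 2, 2]])
print(dists.argmin(dim=1))  # tensor([0])
```

With 512-bit codes the same comparison can also be done on packed integers with XOR and popcount, which is how hash indexes usually scale this up.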

-    def extract_features(
-        self, image: Image.Image, use_hash: bool = False
-    ) -> torch.Tensor:
-        """Extract features from image with optional hash compression.
+    def encode(self, image: Image.Image) -> torch.Tensor:
+        """Encode an image to binary hash bits.
+
+        Alias for forward().

         Args:
             image: Input PIL Image
-            use_hash: If True, return binary hash codes; else return DINO features

         Returns:
-            Features [N, dim] where dim is 1024 (DINO) or 512 (hash)
+            Binary hash codes [1, hash_bits] as int32
         """
-        cropped_objects = self.segmentor(image)
-
-        if len(cropped_objects) == 0:
-            dim = self.hash_compressor.hash_bits if use_hash else 1024
-            return torch.empty(0, dim, device=self.device)
-
-        if use_hash:
-            return self.dino.encode(cropped_objects)
-        else:
-            return self.dino.extract_features(cropped_objects)
-
-    def extract_masks(self, image: Image.Image) -> list[torch.Tensor]:
-        """Extract only masks without full processing (for debugging).
-
-        Args:
-            image: Input PIL Image
-
-        Returns:
-            List of binary masks [H, W]
-        """
-        return self.segmentor.extract_masks(image)
+        return self.forward(image)
+
+    def extract_features(self, image: Image.Image) -> torch.Tensor:
+        """Extract DINO features from an image.
+
+        Args:
+            image: Input PIL Image
+
+        Returns:
+            DINO features [1, dino_dim], normalized
+        """
+        inputs = self.processor(image, return_tensors="pt").to(self.device)
+
+        with torch.no_grad():
+            outputs = self.dino(**inputs)
+            features = outputs.last_hidden_state.mean(dim=1)  # [1, dim]
+            features = F.normalize(features, dim=-1)
+
+        return features
+
+
+# Backward compatibility alias
+SAMHashPipeline = HashPipeline
@@ -1,180 +0,0 @@
-"""Segment Anything 2 feature extractor with mask filtering and image cropping.
-
-Extracts object masks from images using SAM2.1, filters by area and confidence,
-then crops the original image to obtain individual object regions.
-"""
-
-from typing import Optional
-
-import numpy as np
-import torch
-import torch.nn as nn
-from PIL import Image
-from transformers import AutoModelForMaskGeneration, AutoProcessor
-
-
-class SegmentCompressor(nn.Module):
-    """SAM2.1 based segmenter with mask filtering.
-
-    Extracts object masks from images, filters by area and confidence,
-    and crops the original image to produce individual object patches.
-    """
-
-    def __init__(
-        self,
-        model_name: str = "facebook/sam2.1-hiera-large",
-        min_mask_area: int = 100,
-        max_masks: int = 10,
-        device: Optional[str] = None,
-    ):
-        """Initialize SAM2.1 segmenter.
-
-        Args:
-            model_name: HuggingFace model name for SAM2.1
-            min_mask_area: Minimum mask pixel area threshold
-            max_masks: Maximum number of masks to keep
-            device: Device to load model on (auto-detect if None)
-        """
-        super().__init__()
-
-        self.model_name = model_name
-        self.min_mask_area = min_mask_area
-        self.max_masks = max_masks
-
-        # Auto detect device
-        if device is None:
-            device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.device = torch.device(device)
-
-        # Load SAM model and processor
-        self.processor = AutoProcessor.from_pretrained(model_name)
-        self.model = AutoModelForMaskGeneration.from_pretrained(model_name).to(
-            self.device
-        )
-        self.model.eval()
-
-    def forward(self, image: Image.Image) -> list[Image.Image]:
-        """Extract object masks and crop object regions.
-
-        Args:
-            image: Input PIL Image
-
-        Returns:
-            List of cropped object images (one per valid mask)
-        """
-        # Run SAM inference
-        inputs = self.processor(image, return_tensors="pt").to(self.device)
-
-        with torch.no_grad():
-            outputs = self.model(**inputs)
-
-        # Post-process masks
-        masks = self.processor.post_process_masks(
-            outputs.pred_masks,
-            inputs["original_sizes"],
-            inputs["reshaped_input_sizes"],
-        )[0]
-
-        # Filter masks by area and confidence
-        valid_masks = self._filter_masks(masks)
-
-        if len(valid_masks) == 0:
-            return []
-
-        # Crop object regions from original image
-        cropped_objects = self._crop_objects(image, valid_masks)
-
-        return cropped_objects
-
-    def _filter_masks(self, masks: torch.Tensor) -> list[dict]:
-        """Filter masks by area and keep top-N.
-
-        Args:
-            masks: Predicted masks [N, H, W]
-
-        Returns:
-            List of mask dictionaries with 'mask' and 'area'
-        """
-        valid_masks = []
-
-        for mask in masks:
-            # Calculate mask area
-            area = mask.sum().item()
-
-            # Filter by minimum area
-            if area < self.min_mask_area:
-                continue
-
-            valid_masks.append({"mask": mask, "area": area})
-
-        # Sort by area (descending) and keep top-N
-        valid_masks = sorted(valid_masks, key=lambda x: x["area"], reverse=True)
-        valid_masks = valid_masks[: self.max_masks]
-
-        return valid_masks
-
-    def _crop_objects(
-        self, image: Image.Image, masks: list[dict]
-    ) -> list[Image.Image]:
-        """Crop object regions from image using masks.
-
-        Args:
-            image: Original PIL Image
-            masks: List of mask dictionaries
-
-        Returns:
-            List of cropped object images
-        """
-        # Convert PIL to numpy for processing
-        image_np = np.array(image)
-        h, w = image_np.shape[:2]
-
-        cropped_objects = []
-
-        for mask_info in masks:
-            mask = mask_info["mask"].cpu().numpy()
-
-            # Find bounding box from mask
-            rows = mask.any(axis=1)
-            cols = mask.any(axis=0)
-
-            if not rows.any() or not cols.any():
-                continue
-
-            y_min, y_max = rows.argmax(), h - rows[::-1].argmax() - 1
-            x_min, x_max = cols.argmax(), w - cols[::-1].argmax() - 1
-
-            # Add small padding
-            pad = 5
-            x_min = max(0, x_min - pad)
-            y_min = max(0, y_min - pad)
-            x_max = min(w, x_max + pad)
-            y_max = min(h, y_max + pad)
-
-            # Crop
-            cropped = image.crop((x_min, y_min, x_max, y_max))
-            cropped_objects.append(cropped)
-
-        return cropped_objects
-
-    @torch.no_grad()
-    def extract_masks(self, image: Image.Image) -> list[torch.Tensor]:
-        """Extract only masks without cropping (for debugging).
-
-        Args:
-            image: Input PIL Image
-
-        Returns:
-            List of binary masks [H, W]
-        """
-        inputs = self.processor(image, return_tensors="pt").to(self.device)
-        outputs = self.model(**inputs)
-
-        masks = self.processor.post_process_masks(
-            outputs.pred_masks,
-            inputs["original_sizes"],
-            inputs["reshaped_input_sizes"],
-        )[0]
-
-        valid_masks = self._filter_masks(masks)
-        return [m["mask"] for m in valid_masks]
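The bounding-box logic in the removed `_crop_objects` (argmax over row/column occupancy for the near edge, argmax over the reversed array for the far edge) is easy to verify in isolation. A standalone sketch of that computation, with the function name being our own:

```python
import numpy as np


def mask_to_bbox(mask: np.ndarray, pad: int = 5):
    """Tight (x_min, y_min, x_max, y_max) box around a binary mask,
    padded and clipped exactly as in the removed _crop_objects method."""
    h, w = mask.shape
    rows = mask.any(axis=1)
    cols = mask.any(axis=0)
    if not rows.any() or not cols.any():
        return None  # empty mask

    # argmax finds the first True; the reversed argmax finds the last one
    y_min, y_max = rows.argmax(), h - rows[::-1].argmax() - 1
    x_min, x_max = cols.argmax(), w - cols[::-1].argmax() - 1

    return (
        int(max(0, x_min - pad)),
        int(max(0, y_min - pad)),
        int(min(w, x_max + pad)),
        int(min(h, y_max + pad)),
    )


mask = np.zeros((100, 100), dtype=bool)
mask[20:40, 30:60] = True  # occupied rows 20-39, cols 30-59
print(mask_to_bbox(mask))  # (25, 15, 64, 44)
```

Note the asymmetry carried over from the original: the minimum is clamped to 0 while the maximum is clamped to the full width/height, which is what `Image.crop` expects for its exclusive right/bottom edges.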
@@ -118,6 +118,17 @@ class BenchmarkTaskConfig(BaseModel):
     type: str = Field(default="retrieval", description="Task type")
     top_k: int = Field(default=10, gt=0, description="Top K for recall evaluation")

+    # Multi-object retrieval specific settings
+    gamma: float = Field(
+        default=1.0, ge=0, description="Co-occurrence penalty exponent"
+    )
+    top_k_per_object: int = Field(
+        default=50, gt=0, description="Top K results per object query"
+    )
+    num_query_objects: int = Field(
+        default=3, gt=0, description="Number of objects to sample from query image"
+    )
+

 class BenchmarkConfig(BaseModel):
     """Configuration for benchmark evaluation."""
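The `gt=0` / `ge=0` constraints on these Field declarations are enforced by pydantic at construction time, so an invalid benchmark config fails fast instead of surfacing mid-run. A minimal sketch assuming pydantic v2, cut down to just the new fields (the class name here is hypothetical):

```python
from pydantic import BaseModel, Field, ValidationError


class TaskConfigSketch(BaseModel):
    """Cut-down stand-in for BenchmarkTaskConfig's new fields."""

    gamma: float = Field(default=1.0, ge=0)
    top_k_per_object: int = Field(default=50, gt=0)
    num_query_objects: int = Field(default=3, gt=0)


cfg = TaskConfigSketch(gamma=0.5)
print(cfg.gamma, cfg.top_k_per_object)  # 0.5 50

try:
    TaskConfigSketch(num_query_objects=0)  # violates gt=0
except ValidationError:
    print("rejected: num_query_objects must be > 0")
```

Note the distinction the diff draws: `gamma` uses `ge=0` so a penalty exponent of exactly zero (no penalty) stays legal, while the count-like fields use `gt=0`.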
64  mini-nav/data_loading/insdet_scenes.py  Normal file
@@ -0,0 +1,64 @@
+"""InsDet Scenes dataset for multi-object retrieval benchmark."""
+
+from pathlib import Path
+from typing import Any
+
+from benchmarks.base import BaseDataset
+from data_loading.loader import load_val_dataset
+
+
+class InsDetScenesDataset(BaseDataset):
+    """InsDet-FULL/Scenes dataset with easy/hard splits.
+
+    This dataset provides scene images with object annotations from the
+    Instance Detection (InsDet) dataset, supporting easy and hard splits.
+    """
+
+    def __init__(
+        self,
+        scenes_dir: Path | str,
+        split: str = "easy",
+    ):
+        """Initialize InsDet Scenes dataset.
+
+        Args:
+            scenes_dir: Path to the InsDet-FULL/Scenes directory.
+            split: Scene split to use ('easy' or 'hard').
+        """
+        self.scenes_dir = Path(scenes_dir)
+        self.split = split
+        self._dataset = load_val_dataset(self.scenes_dir, split)
+
+    def get_train_split(self) -> Any:
+        """Get training split (same as test for this dataset).
+
+        Returns:
+            HuggingFace Dataset for training.
+        """
+        return self._dataset
+
+    def get_test_split(self) -> Any:
+        """Get test/evaluation split.
+
+        Returns:
+            HuggingFace Dataset for testing.
+        """
+        return self._dataset
+
+    def __len__(self) -> int:
+        """Get dataset length."""
+        return len(self._dataset)
+
+    def __getitem__(self, idx: int) -> dict[str, Any]:
+        """Get a single item from the dataset.
+
+        Args:
+            idx: Index of the item.
+
+        Returns:
+            Dictionary containing:
+                - image: PIL Image
+                - image_id: Scene identifier
+                - objects: dict with bbox, category, area, id
+        """
+        return self._dataset[idx]
@@ -1,13 +1,13 @@
-"""Tests for compressor modules (SAM, DINO, HashCompressor, Pipeline)."""
+"""Tests for compressor modules (HashCompressor, Pipeline)."""

 import pytest
 import torch
 from compressors import (
     BinarySign,
-    DinoCompressor,
     HashCompressor,
+    HashPipeline,
     SAMHashPipeline,
-    SegmentCompressor,
+    VideoPositiveMask,
     bits_to_hash,
     create_pipeline_from_config,
     hamming_distance,
@@ -124,87 +124,105 @@ class TestHammingMetrics:
         assert sim.item() == 512  # Max similarity

-class TestSegmentCompressor:
-    """Test suite for SegmentCompressor."""
+class TestHashLoss:
+    """Test suite for HashLoss."""

-    @pytest.fixture
-    def mock_image(self):
-        """Create a mock PIL image."""
-        img = Image.new("RGB", (224, 224), color="red")
-        return img
+    def test_hash_loss_init(self):
+        """Verify HashLoss initializes with correct parameters."""
+        from compressors import HashLoss

-    def test_segment_compressor_init(self):
-        """Verify SegmentCompressor initializes with correct parameters."""
-        segmentor = SegmentCompressor(
-            model_name="facebook/sam2.1-hiera-large",
-            min_mask_area=100,
-            max_masks=10,
+        loss_fn = HashLoss(
+            contrastive_weight=1.0,
+            distill_weight=0.5,
+            quant_weight=0.01,
+            temperature=0.2,
         )

-        assert segmentor.model_name == "facebook/sam2.1-hiera-large"
-        assert segmentor.min_mask_area == 100
-        assert segmentor.max_masks == 10
+        assert loss_fn.contrastive_weight == 1.0
+        assert loss_fn.distill_weight == 0.5
+        assert loss_fn.quant_weight == 0.01
+        assert loss_fn.temperature == 0.2

-    def test_filter_masks(self):
-        """Verify mask filtering logic."""
-        # Create segmentor to get default filter params
-        segmentor = SegmentCompressor()
+    def test_hash_loss_forward(self):
+        """Verify HashLoss computes loss correctly."""
+        from compressors import HashLoss

-        # Create mock masks tensor with different areas
-        # Masks shape: [N, H, W]
-        masks = []
-        for area in [50, 200, 150, 300, 10]:
-            mask = torch.zeros(100, 100)
-            mask[:1, :area] = 1  # Create mask with specific area
-            masks.append(mask)
+        loss_fn = HashLoss()

-        masks_tensor = torch.stack(masks)  # [5, 100, 100]
-        valid = segmentor._filter_masks(masks_tensor)
+        batch_size = 4
+        hash_bits = 512
+        logits = torch.randn(batch_size, hash_bits)
+        hash_codes = torch.sign(logits)
+        teacher_embed = torch.randn(batch_size, 1024)
+        positive_mask = torch.eye(batch_size, dtype=torch.bool)

-        # Should filter out 50 and 10 (below min_mask_area=100)
-        # Then keep top 3 (max_masks=10)
-        assert len(valid) == 3
-        # Verify sorted by area (descending)
-        areas = [v["area"] for v in valid]
-        assert areas == sorted(areas, reverse=True)
+        total_loss, components = loss_fn(
+            logits=logits,
+            hash_codes=hash_codes,
+            teacher_embed=teacher_embed,
+            positive_mask=positive_mask,
+        )
+
+        assert "contrastive" in components
+        assert "distill" in components
+        assert "quantization" in components
+        assert "total" in components


-class TestDinoCompressor:
-    """Test suite for DinoCompressor."""
+class TestVideoPositiveMask:
+    """Test suite for VideoPositiveMask."""

-    def test_dino_compressor_init(self):
-        """Verify DinoCompressor initializes correctly."""
-        dino = DinoCompressor()
+    def test_from_frame_indices(self):
+        """Verify positive mask generation from frame indices."""
+        mask_gen = VideoPositiveMask(temporal_window=2)

-        assert dino.model_name == "facebook/dinov2-large"
+        frame_indices = torch.tensor([0, 1, 3, 5])

-    def test_dino_compressor_with_compressor(self):
-        """Verify DinoCompressor with HashCompressor."""
-        hash_compressor = HashCompressor(input_dim=1024, hash_bits=512)
-        dino = DinoCompressor(compressor=hash_compressor)
+        mask = mask_gen.from_frame_indices(frame_indices)

-        assert dino.compressor is hash_compressor
+        assert mask.shape == (4, 4)
+        # Frame 0 and 1 should be positive (distance 1 <= 2)
+        assert mask[0, 1] == True
+        # Frame 0 and 3 should be negative (distance 3 > 2)
+        assert mask[0, 3] == False
+
+    def test_from_video_ids(self):
+        """Verify positive mask generation from video IDs and frame indices."""
+        mask_gen = VideoPositiveMask(temporal_window=2)
+
+        video_ids = torch.tensor([0, 0, 1, 1])
+        frame_indices = torch.tensor([0, 1, 0, 1])
+
+        mask = mask_gen.from_video_ids(video_ids, frame_indices)
+
+        assert mask.shape == (4, 4)
+        # Same video and temporally close
+        assert mask[0, 1] == True  # video 0, frames 0,1
+        # Different video
+        assert mask[0, 2] == False  # video 0 vs 1
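VideoPositiveMask's implementation is not included in this chunk, but the tests above pin down its contract: a pair is positive when the two samples share a video id and their frame indices differ by at most `temporal_window`. A sketch reproducing exactly that behavior (the function name and broadcasting approach are ours, not the repo's internals):

```python
import torch


def temporal_positive_mask(
    video_ids: torch.Tensor, frame_indices: torch.Tensor, window: int = 2
) -> torch.Tensor:
    """Boolean [B, B] mask: True where two samples come from the same video
    and are at most `window` frames apart."""
    same_video = video_ids.unsqueeze(0) == video_ids.unsqueeze(1)
    close = (frame_indices.unsqueeze(0) - frame_indices.unsqueeze(1)).abs() <= window
    return same_video & close


vids = torch.tensor([0, 0, 1, 1])
frames = torch.tensor([0, 1, 0, 1])
mask = temporal_positive_mask(vids, frames)
print(mask[0, 1].item(), mask[0, 2].item())  # True False
```

Such a mask is what the contrastive term of HashLoss consumes as `positive_mask` in the test above.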


-class TestSAMHashPipeline:
-    """Test suite for SAMHashPipeline."""
+class TestHashPipeline:
+    """Test suite for HashPipeline."""

     def test_pipeline_init(self):
         """Verify pipeline initializes all components."""
-        pipeline = SAMHashPipeline(
-            sam_model="facebook/sam2.1-hiera-large",
+        pipeline = HashPipeline(
             dino_model="facebook/dinov2-large",
             hash_bits=512,
         )

-        assert isinstance(pipeline.segmentor, SegmentCompressor)
-        assert isinstance(pipeline.dino, DinoCompressor)
         assert isinstance(pipeline.hash_compressor, HashCompressor)
+        assert pipeline.dino_model == "facebook/dinov2-large"
+        assert pipeline.dino_dim == 1024

     def test_pipeline_hash_bits(self):
         """Verify pipeline uses correct hash bits."""
-        pipeline = SAMHashPipeline(hash_bits=256)
-        assert pipeline.hash_compressor.hash_bits == 256
+        pipeline = HashPipeline(hash_bits=256)
+        assert pipeline.hash_bits == 256
+
+    def test_pipeline_alias(self):
+        """Verify SAMHashPipeline is alias for HashPipeline."""
+        assert SAMHashPipeline is HashPipeline


 class TestConfigIntegration:
@@ -216,25 +234,21 @@ class TestConfigIntegration:

         pipeline = create_pipeline_from_config(config)

-        assert isinstance(pipeline, SAMHashPipeline)
-        assert pipeline.hash_compressor.hash_bits == config.model.compression_dim
+        assert isinstance(pipeline, HashPipeline)
+        assert pipeline.hash_bits == config.model.compression_dim

-    def test_config_sam_settings(self):
-        """Verify config contains SAM settings."""
+    def test_config_settings(self):
+        """Verify config contains required settings."""
         config = cfg_manager.load()

-        assert hasattr(config.model, "sam_model")
-        assert hasattr(config.model, "sam_min_mask_area")
-        assert hasattr(config.model, "sam_max_masks")
-        assert config.model.sam_model == "facebook/sam2.1-hiera-large"
-        assert config.model.sam_min_mask_area == 100
-        assert config.model.sam_max_masks == 10
+        assert hasattr(config.model, "dino_model")
+        assert hasattr(config.model, "compression_dim")


+@pytest.mark.slow
 class TestPipelineIntegration:
     """Integration tests for full pipeline (slow, requires model downloads)."""

-    @pytest.mark.slow
     def test_pipeline_end_to_end(self):
         """Test full pipeline with actual models (slow test)."""
         # Skip if no GPU
@@ -245,54 +259,32 @@ class TestPipelineIntegration:
         image = Image.new("RGB", (640, 480), color=(128, 128, 128))

         # Initialize pipeline (will download models on first run)
-        pipeline = SAMHashPipeline(
-            sam_model="facebook/sam2.1-hiera-large",
+        pipeline = HashPipeline(
             dino_model="facebook/dinov2-large",
             hash_bits=512,
-            sam_min_mask_area=100,
-            sam_max_masks=5,
         )

         # Run pipeline
-        hash_codes = pipeline(image)
+        hash_bits = pipeline(image)

         # Verify output shape
-        assert hash_codes.dim() == 2
-        assert hash_codes.shape[1] == 512
-        assert torch.all((hash_codes == 0) | (hash_codes == 1))
+        assert hash_bits.dim() == 2
+        assert hash_bits.shape[1] == 512
+        assert torch.all((hash_bits == 0) | (hash_bits == 1))

-    @pytest.mark.slow
-    def test_extract_features_without_hash(self):
-        """Test feature extraction without hash compression."""
+    def test_extract_features(self):
+        """Test feature extraction."""
         if not torch.cuda.is_available():
             pytest.skip("Requires CUDA")

         image = Image.new("RGB", (640, 480), color=(128, 128, 128))

-        pipeline = SAMHashPipeline(
-            sam_model="facebook/sam2.1-hiera-large",
+        pipeline = HashPipeline(
             dino_model="facebook/dinov2-large",
         )

-        features = pipeline.extract_features(image, use_hash=False)
+        features = pipeline.extract_features(image)

         # Should return DINO features (1024 for large)
         assert features.dim() == 2
         assert features.shape[1] == 1024
-
-    @pytest.mark.slow
-    def test_extract_masks_only(self):
-        """Test mask extraction only."""
-        if not torch.cuda.is_available():
-            pytest.skip("Requires CUDA")
-
-        image = Image.new("RGB", (640, 480), color=(128, 128, 128))
-
-        pipeline = SAMHashPipeline(
-            sam_model="facebook/sam2.1-hiera-large",
-        )
-
-        masks = pipeline.extract_masks(image)
-
-        # Should return a list of masks
-        assert isinstance(masks, list)
238  mini-nav/tests/test_multi_object_retrieval.py  Normal file
@@ -0,0 +1,238 @@
+"""Integration tests for multi-object retrieval benchmark pipeline.
+
+These tests verify the end-to-end functionality of the multi-object retrieval
+benchmark, including schema building, database population, and evaluation.
+"""
+
+import numpy as np
+import pytest
+from unittest.mock import Mock, patch, MagicMock
+from PIL import Image
+
+
+class TestMultiObjectRetrievalIntegration:
+    """Integration tests for multi-object retrieval benchmark."""
+
+    @pytest.fixture
+    def mock_model_processor(self):
+        """Create mock model and processor."""
+        mock_model = Mock()
+        mock_processor = Mock()
+
+        # Mock the feature extraction to return a fixed-size vector
+        def mock_extract(processor, model, image):
+            return [0.1] * 256  # 256-dim vector
+
+        mock_processor.images = mock_extract
+
+        return mock_model, mock_processor
+
+    @pytest.fixture
+    def mock_dataset(self):
+        """Create a mock dataset with images and annotations."""
+        # Create mock items
+        items = []
+        for i in range(3):
+            item = {
+                "image": Image.new("RGB", (224, 224), color=(i * 50, 100, 150)),
+                "image_id": f"scene_{i}",
+                "objects": {
+                    "bbox": [[10, 10, 50, 50], [60, 60, 40, 40]],
+                    "category": ["object_a", "object_b"],
+                    "area": [2500, 1600],
+                    "id": [0, 1],
+                },
+            }
+            items.append(item)
+
+        mock_dataset = Mock()
+        mock_dataset.__len__ = Mock(return_value=len(items))
+        mock_dataset.__getitem__ = lambda self, idx: items[idx]
+        mock_dataset.with_format = lambda fmt: mock_dataset
+
+        return mock_dataset
+
+    def test_build_object_schema(self):
+        """Test that object schema is built correctly."""
+        from benchmarks.tasks.multi_object_retrieval import _build_object_schema
+        import pyarrow as pa
+
+        vector_dim = 256
+        schema = _build_object_schema(vector_dim)
+
+        assert isinstance(schema, pa.Schema)
+        assert "id" in schema.names
+        assert "image_id" in schema.names
+        assert "object_id" in schema.names
+        assert "category" in schema.names
+        assert "vector" in schema.names
+
+        # Check vector field has correct dimension
+        vector_field = schema.field("vector")
+        assert isinstance(vector_field.type, pa.ListType)
+        assert vector_field.type.value_type == pa.float32()
+    @patch("benchmarks.tasks.multi_object_retrieval.load_sam_model")
+    @patch("benchmarks.tasks.multi_object_retrieval.segment_image")
+    def test_build_database_with_mocked_sam(
+        self,
+        mock_segment,
+        mock_load_sam,
+        mock_model_processor,
+        mock_dataset,
+    ):
+        """Test database building with mocked SAM segmentation."""
+        from benchmarks.tasks.multi_object_retrieval import (
+            MultiObjectRetrievalTask,
+            _build_object_schema,
+        )
+
+        mock_model, mock_processor = mock_model_processor
+
+        # Mock SAM
+        mock_load_sam.return_value = (Mock(), Mock())
+        mock_segment.return_value = [
+            {
+                "segment": np.ones((224, 224), dtype=bool),
+                "area": 50000,
+                "bbox": [0, 0, 224, 224],
+            }
+        ]
+
+        # Create task with config
+        task = MultiObjectRetrievalTask(
+            sam_model="facebook/sam2.1-hiera-large",
+            min_mask_area=1024,
+            max_masks_per_image=5,
+            gamma=1.0,
+            top_k_per_object=50,
+            num_query_objects=3,
+        )
+
+        # Create mock table
+        mock_table = Mock()
+        mock_table.schema = _build_object_schema(256)
+
+        # Build database (this should not raise)
+        task.build_database(mock_model, mock_processor, mock_dataset, mock_table, batch_size=1)
+
+        # Verify table.add was called
+        assert mock_table.add.called
+
+    @patch("benchmarks.tasks.multi_object_retrieval.load_sam_model")
+    @patch("benchmarks.tasks.multi_object_retrieval.segment_image")
+    def test_evaluate_with_mocked_sam(
+        self,
+        mock_segment,
+        mock_load_sam,
+        mock_model_processor,
+        mock_dataset,
+    ):
+        """Test evaluation with mocked SAM segmentation."""
+        from benchmarks.tasks.multi_object_retrieval import (
+            MultiObjectRetrievalTask,
+            _build_object_schema,
+        )
+
+        mock_model, mock_processor = mock_model_processor
+
+        # Mock SAM
+        mock_load_sam.return_value = (Mock(), Mock())
+        mock_segment.return_value = [
+            {
+                "segment": np.ones((224, 224), dtype=bool),
+                "area": 50000,
+                "bbox": [0, 0, 224, 224],
+                "object_id": "query_obj_0",
+            }
+        ]
+
+        # Create mock table with search results
+        mock_table = Mock()
+        mock_table.schema = _build_object_schema(256)
+
+        # Mock search to return matching result
+        mock_result = Mock()
+        mock_result.to_polars.return_value = {
+            "image_id": ["scene_0"],
+            "object_id": ["scene_0_obj_0"],
+            "_distance": [0.1],
+        }
+
+        mock_table.search.return_value.select.return_value.limit.return_value = mock_result
+
+        # Create task
+        task = MultiObjectRetrievalTask(
+            sam_model="facebook/sam2.1-hiera-large",
+            min_mask_area=1024,
+            max_masks_per_image=5,
+            gamma=1.0,
+            top_k_per_object=50,
+            num_query_objects=1,
+        )
+
+        # Evaluate
+        results = task.evaluate(mock_model, mock_processor, mock_dataset, mock_table, batch_size=1)
+
+        # Verify results structure
+        assert "accuracy" in results
+        assert "correct" in results
+        assert "total" in results
+        assert "top_k" in results
+        assert results["top_k"] == 50
+
+    def test_task_initialization_with_config(self):
+        """Test task initialization with custom config."""
+        from benchmarks.tasks.multi_object_retrieval import MultiObjectRetrievalTask
+
+        task = MultiObjectRetrievalTask(
+            sam_model="facebook/sam2.1-hiera-small",
+            min_mask_area=500,
+            max_masks_per_image=3,
+            gamma=0.5,
+            top_k_per_object=100,
+            num_query_objects=5,
+        )
+
+        assert task.sam_model == "facebook/sam2.1-hiera-small"
+        assert task.min_mask_area == 500
+        assert task.max_masks_per_image == 3
+        assert task.config.gamma == 0.5
+        assert task.config.top_k_per_object == 100
+        assert task.config.num_query_objects == 5
+
+    def test_task_initialization_defaults(self):
+        """Test task initialization with default config."""
+        from benchmarks.tasks.multi_object_retrieval import MultiObjectRetrievalTask
+
+        task = MultiObjectRetrievalTask()
+
+        # Check defaults from BenchmarkTaskConfig
+        assert task.config.gamma == 1.0
+        assert task.config.top_k_per_object == 50
+        assert task.config.num_query_objects == 3
+        # SAM settings from ModelConfig defaults
+        assert task.sam_model == "facebook/sam2.1-hiera-large"
+        assert task.min_mask_area == 1024
+        assert task.max_masks_per_image == 5
+
+
+class TestInsDetScenesDataset:
+    """Tests for InsDetScenesDataset class."""
+
+    def test_dataset_class_exists(self):
+        """Test that InsDetScenesDataset can be imported."""
+        from data_loading.insdet_scenes import InsDetScenesDataset
+
+        assert InsDetScenesDataset is not None
+
+    @patch("data_loading.insdet_scenes.load_val_dataset")
+    def test_dataset_loads_correct_split(self, mock_load):
+        """Test dataset loads correct split."""
+        from data_loading.insdet_scenes import InsDetScenesDataset
+
+        mock_load.return_value = Mock()
+
+        dataset = InsDetScenesDataset("/path/to/scenes", split="easy")
+
+        mock_load.assert_called_once_with("/path/to/scenes", "easy")
+        assert dataset.split == "easy"
168  mini-nav/tests/test_sam.py  Normal file
@@ -0,0 +1,168 @@
"""Tests for SAM segmentation utilities.

Note: These tests mock the SAM model loading since SAM requires
heavy model weights. The actual SAM integration should be tested
separately in integration tests.
"""

import numpy as np
import pytest
from unittest.mock import Mock, patch
from PIL import Image


class TestSAMSegmentation:
    """Test suite for SAM segmentation utilities."""

    def test_segment_image_empty_masks(self):
        """Test segment_image returns empty list when no masks generated."""
        from utils.sam import segment_image

        # Create mock mask generator that returns empty list
        mock_generator = Mock()
        mock_generator.generate.return_value = []

        result = segment_image(mock_generator, Image.new("RGB", (100, 100)))

        assert result == []

    def test_segment_image_filters_small_masks(self):
        """Test segment_image filters masks below min_area threshold."""
        from utils.sam import segment_image

        # Create mock masks with different areas
        small_mask = {
            "segment": np.zeros((10, 10), dtype=bool),
            "area": 50,  # Below 32*32 = 1024
            "bbox": [0, 0, 10, 10],
            "predicted_iou": 0.9,
            "stability_score": 0.8,
        }
        large_mask = {
            "segment": np.ones((100, 100), dtype=bool),
            "area": 10000,  # Above threshold
            "bbox": [0, 0, 100, 100],
            "predicted_iou": 0.95,
            "stability_score": 0.9,
        }

        mock_generator = Mock()
        mock_generator.generate.return_value = [small_mask, large_mask]

        result = segment_image(
            mock_generator,
            Image.new("RGB", (100, 100)),
            min_area=32 * 32,
            max_masks=5,
        )

        # Should only return the large mask
        assert len(result) == 1
        assert result[0]["area"] == 10000

    def test_segment_image_limits_max_masks(self):
        """Test segment_image limits to max_masks largest masks."""
        from utils.sam import segment_image

        # Create 10 masks with different areas
        masks = [
            {
                "segment": np.ones((i + 1, i + 1), dtype=bool),
                "area": (i + 1) * (i + 1),
                "bbox": [0, 0, i + 1, i + 1],
                "predicted_iou": 0.9,
                "stability_score": 0.8,
            }
            for i in range(10)
        ]

        mock_generator = Mock()
        mock_generator.generate.return_value = masks

        result = segment_image(
            mock_generator,
            Image.new("RGB", (100, 100)),
            min_area=1,
            max_masks=3,
        )

        # Should only return top 3 largest masks
        assert len(result) == 3
        # Check they are sorted by area (largest first)
        areas = [m["area"] for m in result]
        assert areas == sorted(areas, reverse=True)

    def test_segment_image_sorted_by_area(self):
        """Test segment_image returns masks sorted by area descending."""
        from utils.sam import segment_image

        # Create masks with known areas (unordered)
        mask1 = {"segment": np.ones((5, 5), dtype=bool), "area": 25, "bbox": [0, 0, 5, 5]}
        mask2 = {"segment": np.ones((10, 10), dtype=bool), "area": 100, "bbox": [0, 0, 10, 10]}
        mask3 = {"segment": np.ones((3, 3), dtype=bool), "area": 9, "bbox": [0, 0, 3, 3]}

        mock_generator = Mock()
        mock_generator.generate.return_value = [mask1, mask2, mask3]

        result = segment_image(
            mock_generator,
            Image.new("RGB", (100, 100)),
            min_area=1,
            max_masks=10,
        )

        # Should be sorted by area descending
        assert result[0]["area"] == 100
        assert result[1]["area"] == 25
        assert result[2]["area"] == 9


class TestExtractMaskedRegion:
    """Test suite for extracting masked regions from images."""

    def test_extract_masked_region_binary(self):
        """Test extracting masked region with binary mask."""
        from utils.sam import extract_masked_region

        # Create a simple image
        image = Image.new("RGB", (10, 10), color=(255, 0, 0))

        # Create a binary mask (half kept, half masked)
        mask = np.zeros((10, 10), dtype=bool)
        mask[:, :5] = True

        result = extract_masked_region(image, mask)

        # Check that left half is red, right half is black
        result_np = np.array(result)
        left_half = result_np[:, :5, :]
        right_half = result_np[:, 5:, :]

        assert np.all(left_half == [255, 0, 0])
        assert np.all(right_half == [0, 0, 0])

    def test_extract_masked_region_all_masked(self):
        """Test extracting when entire image is masked."""
        from utils.sam import extract_masked_region

        image = Image.new("RGB", (10, 10), color=(255, 0, 0))
        mask = np.ones((10, 10), dtype=bool)

        result = extract_masked_region(image, mask)
        result_np = np.array(result)

        # Entire image should be preserved
        assert np.all(result_np == [255, 0, 0])

    def test_extract_masked_region_all_zero_mask(self):
        """Test extracting when mask is all zeros."""
        from utils.sam import extract_masked_region

        image = Image.new("RGB", (10, 10), color=(255, 0, 0))
        mask = np.zeros((10, 10), dtype=bool)

        result = extract_masked_region(image, mask)
        result_np = np.array(result)

        # Entire image should be black
        assert np.all(result_np == [0, 0, 0])
121  mini-nav/tests/test_scene_scoring.py  Normal file
@@ -0,0 +1,121 @@
"""Tests for scene scoring algorithm in multi-object retrieval."""

import pytest
from benchmarks.tasks.multi_object_retrieval import _compute_scene_score


class TestSceneScoringAlgorithm:
    """Test suite for scene scoring with co-occurrence penalty."""

    def test_scene_score_basic(self):
        """Test basic scene scoring with single match."""
        query_object_ids = ["obj_1", "obj_2", "obj_3"]

        # Scene A has obj_1
        retrieved_results = {
            "scene_A": [("distance_1", "obj_1")],
        }

        scores = _compute_scene_score(query_object_ids, retrieved_results, gamma=1.0)

        # Hit rate = 1/3, similarity = 1/(1+distance_1)
        assert "scene_A" in scores
        assert scores["scene_A"] > 0

    def test_scene_score_no_match(self):
        """Test scene scoring when no objects match."""
        query_object_ids = ["obj_1", "obj_2", "obj_3"]

        retrieved_results = {
            "scene_A": [("distance_1", "other_obj")],
        }

        scores = _compute_scene_score(query_object_ids, retrieved_results, gamma=1.0)

        assert scores["scene_A"] == 0.0

    def test_scene_score_multiple_scenes(self):
        """Test scoring across multiple scenes."""
        query_object_ids = ["obj_1", "obj_2"]

        retrieved_results = {
            "scene_A": [("0.1", "obj_1")],
            "scene_B": [("0.1", "obj_2")],
            "scene_C": [("0.1", "other")],
        }

        scores = _compute_scene_score(query_object_ids, retrieved_results, gamma=1.0)

        # Scenes with matches should have positive scores
        assert scores["scene_A"] > 0
        assert scores["scene_B"] > 0
        # Scene C has no match, score should be 0
        assert scores["scene_C"] == 0.0

    def test_scene_score_gamma_zero(self):
        """Test scoring with gamma=0 (no penalty)."""
        query_object_ids = ["obj_1", "obj_2", "obj_3", "obj_4", "obj_5"]

        retrieved_results = {
            "scene_A": [("0.1", "obj_1")],
        }

        scores_gamma_0 = _compute_scene_score(query_object_ids, retrieved_results, gamma=0.0)
        scores_gamma_1 = _compute_scene_score(query_object_ids, retrieved_results, gamma=1.0)

        # With gamma=0, hit_rate^0 = 1, so score = similarity
        # With gamma=1, hit_rate^1 = 1/5, so score = similarity * 1/5
        # scores_gamma_0 should be larger
        assert scores_gamma_0["scene_A"] > scores_gamma_1["scene_A"]

    def test_scene_score_multiple_matches(self):
        """Test scoring when scene has multiple matching objects."""
        query_object_ids = ["obj_1", "obj_2"]

        retrieved_results = {
            "scene_A": [("0.1", "obj_1"), ("0.2", "obj_2")],
        }

        scores = _compute_scene_score(query_object_ids, retrieved_results, gamma=1.0)

        # Both objects match, hit_rate = 2/2 = 1.0
        # Score = (1/(1+0.1) + 1/(1+0.2)) * 1.0
        expected_similarity = 1.0 / 1.1 + 1.0 / 1.2
        assert abs(scores["scene_A"] - expected_similarity) < 0.01

    def test_scene_score_distance_to_similarity(self):
        """Test that smaller distance yields higher score."""
        query_object_ids = ["obj_1"]

        retrieved_results = {
            "scene_close": [("0.01", "obj_1")],
            "scene_far": [("10.0", "obj_1")],
        }

        scores = _compute_scene_score(query_object_ids, retrieved_results, gamma=1.0)

        # Closer scene should have higher score
        assert scores["scene_close"] > scores["scene_far"]

    def test_scene_score_empty_results(self):
        """Test scoring with empty retrieved results."""
        query_object_ids = ["obj_1", "obj_2"]

        retrieved_results = {}

        scores = _compute_scene_score(query_object_ids, retrieved_results, gamma=1.0)

        assert scores == {}

    def test_scene_score_empty_query(self):
        """Test scoring with empty query objects."""
        query_object_ids = []

        retrieved_results = {
            "scene_A": [("0.1", "obj_1")],
        }

        scores = _compute_scene_score(query_object_ids, retrieved_results, gamma=1.0)

        # With empty query, no scenes should have positive score
        assert all(score == 0.0 for score in scores.values())
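Taken together, these tests pin down the scoring contract: each matched hit contributes similarity 1/(1+distance), per-scene contributions are summed, and the sum is scaled by hit_rate^gamma, where hit_rate is the fraction of query objects retrieved in that scene. A minimal sketch that satisfies those assertions (a hypothetical reconstruction, not the actual `_compute_scene_score` implementation from `benchmarks/tasks/multi_object_retrieval.py`):

```python
def compute_scene_score_sketch(query_object_ids, retrieved_results, gamma=1.0):
    """Hypothetical reconstruction of the scene-scoring contract above.

    Distances arrive as strings in the tests; non-numeric placeholders
    (e.g. "distance_1") are treated here as distance 0 so the score stays
    positive, which is all the basic test asserts.
    """
    scores = {}
    query_set = set(query_object_ids)
    for scene_id, hits in retrieved_results.items():
        matched = {obj for _, obj in hits if obj in query_set}
        if not matched:  # also covers an empty query: nothing can match
            scores[scene_id] = 0.0
            continue
        similarity = 0.0
        for distance, obj in hits:
            if obj not in query_set:
                continue
            try:
                d = float(distance)
            except ValueError:
                d = 0.0  # non-numeric mock distance: assume perfect match
            similarity += 1.0 / (1.0 + d)
        hit_rate = len(matched) / len(query_set)  # co-occurrence penalty base
        scores[scene_id] = similarity * hit_rate**gamma
    return scores
```

With gamma=0 the penalty term becomes 1 and the score reduces to the raw similarity sum, matching `test_scene_score_gamma_zero`.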
100  mini-nav/utils/sam.py  Normal file
@@ -0,0 +1,100 @@
"""SAM (Segment Anything Model) utilities for object segmentation."""

from pathlib import Path
from typing import Any

import numpy as np
import torch
from PIL import Image
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator


def load_sam_model(
    model_name: str = "facebook/sam2.1-hiera-large",
    device: str = "cuda",
    checkpoint_dir: Path | None = None,
) -> tuple[Any, Any]:
    """Load SAM 2.1 model and mask generator.

    Args:
        model_name: SAM model name (currently supports facebook/sam2.1-hiera-*).
        device: Device to load model on (cuda or cpu).
        checkpoint_dir: Optional directory for model checkpoint cache.

    Returns:
        Tuple of (sam_model, mask_generator).
    """
    if device == "cuda" and not torch.cuda.is_available():
        device = "cpu"

    # Build SAM2 model
    sam_model = build_sam2(model_name, device=device)

    # Create automatic mask generator
    mask_generator = SAM2AutomaticMaskGenerator(sam_model)

    return sam_model, mask_generator


def segment_image(
    mask_generator: Any,
    image: Image.Image,
    min_area: int = 32 * 32,
    max_masks: int = 5,
) -> list[dict[str, Any]]:
    """Segment image using SAM to extract object masks.

    Args:
        mask_generator: SAM2AutomaticMaskGenerator instance.
        image: PIL Image to segment.
        min_area: Minimum mask area threshold in pixels.
        max_masks: Maximum number of masks to return.

    Returns:
        List of mask dictionaries with keys:
        - segment: Binary mask (numpy array)
        - area: Mask area in pixels
        - bbox: Bounding box [x, y, width, height]
        - predicted_iou: Model's confidence in the mask
        - stability_score: Stability score for the mask
    """
    # Convert PIL Image to numpy array
    image_np = np.array(image.convert("RGB"))

    # Generate masks
    masks = mask_generator.generate(image_np)

    if not masks:
        return []

    # Filter by minimum area
    filtered_masks = [m for m in masks if m["area"] >= min_area]

    if not filtered_masks:
        return []

    # Sort by area (largest first) and limit to max_masks
    sorted_masks = sorted(filtered_masks, key=lambda x: x["area"], reverse=True)
    return sorted_masks[:max_masks]


def extract_masked_region(
    image: Image.Image,
    mask: np.ndarray,
) -> Image.Image:
    """Extract masked region from image.

    Args:
        image: Original PIL Image.
        mask: Binary mask as numpy array (True = keep).

    Returns:
        PIL Image with only the masked region visible.
    """
    image_np = np.array(image.convert("RGB"))

    # Apply mask
    masked_np = image_np * mask[:, :, np.newaxis]

    return Image.fromarray(masked_np.astype(np.uint8))
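Because `segment_image` only calls `generate` on whatever object it receives, its post-processing (area filter, descending area sort, cap at `max_masks`) is plain Python and can be sketched without SAM weights. `select_masks` below is an illustrative stand-in for that pipeline, not a function exported by `utils/sam.py`:

```python
def select_masks(masks: list[dict], min_area: int = 32 * 32, max_masks: int = 5) -> list[dict]:
    """Mirror segment_image's post-processing on already-generated mask dicts."""
    kept = [m for m in masks if m["area"] >= min_area]  # drop masks below the area threshold
    kept.sort(key=lambda m: m["area"], reverse=True)    # largest masks first
    return kept[:max_masks]                             # cap the number returned


masks = [{"area": a} for a in (50, 10000, 2048, 4096)]
print([m["area"] for m in select_masks(masks, max_masks=2)])  # prints [10000, 4096]
```

The 50-pixel mask falls below the default 32*32 threshold and is dropped; the survivors come back sorted by area, truncated to the two largest, matching the behavior the unit tests exercise through a mocked generator.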
@@ -3,13 +3,16 @@ name = "mini-nav"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.10"
requires-python = ">=3.13"
dependencies = [
    "accelerate>=1.12.0",
    "dash>=3.4.0",
    "dash-ag-grid>=33.3.3",
    "dash-mantine-components>=2.5.1",
    "datasets>=4.5.0",
    "habitat-baselines>=0.3.320250127",
    "habitat-lab>=0.3.320250127",
    "habitat-sim-uv-wheels-experimental>=0.2.2a2",
    "httpx[socks]>=0.28.1",
    "lancedb>=0.27.1",
    "polars[database,numpy,pandas,pydantic]>=1.37.1",