# Mirror of https://github.com/SikongJueluo/Mini-Nav.git
# (synced 2026-03-12 12:25:32 +08:00; 144 lines, 4.9 KiB, Python)
"""Pydantic data models for feature compressor configuration."""
|
|
|
|
from pathlib import Path
|
|
from typing import Literal, Optional
|
|
|
|
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
|
|
|
|
|
class ModelConfig(BaseModel):
    """Settings for the vision model and feature compression.

    Groups the DINO model identifier, the compressed feature size, the
    compute device, the SAM mask-generation knobs, and an optional path to
    trained compressor weights.
    """

    model_config = ConfigDict(extra="ignore")

    # --- DINO model / compression ------------------------------------------
    # Model identifier; defaults to "facebook/dinov2-large".
    dino_model: str = "facebook/dinov2-large"
    compression_dim: int = Field(
        512,
        gt=0,
        description="Output feature dimension",
    )
    # Free-form device string, "auto" by default.
    # NOTE(review): how "auto" maps to a concrete device is decided by the
    # consumer of this config -- confirm there.
    device: str = "auto"

    # --- SAM mask generation -----------------------------------------------
    sam_model: str = Field(
        "facebook/sam2.1-hiera-large",
        description="SAM model name from HuggingFace",
    )
    sam_min_mask_area: int = Field(
        100,
        gt=0,
        description="Minimum mask area threshold",
    )
    sam_max_masks: int = Field(
        10,
        gt=0,
        description="Maximum number of masks to keep",
    )

    # --- Optional trained compressor ---------------------------------------
    # None means no weights path was configured.
    compressor_path: Optional[str] = Field(
        None,
        description="Path to trained HashCompressor weights",
    )
|
|
|
|
|
|
class OutputConfig(BaseModel):
    """Configuration for output settings."""

    model_config = ConfigDict(extra="ignore")

    # Output location. Defaults to an ``outputs`` folder at the project root,
    # i.e. three ``.parent`` hops up from this file.
    directory: Path = Path(__file__).parent.parent.parent / "outputs"

    @field_validator("directory", mode="after")
    @classmethod
    def convert_to_absolute(cls, v: Path) -> Path:
        """Anchor ``directory`` so it no longer depends on the working directory.

        A leading ``~`` / ``~user`` is expanded first. A path that is then
        absolute is returned unchanged; any other path is joined onto the
        project root. This works even if the path doesn't exist on disk.

        Args:
            v: The path as parsed from the configuration.

        Returns:
            The expanded path if it is absolute, otherwise the path joined
            onto the project root.

        Raises:
            RuntimeError: Propagated from ``Path.expanduser`` when the path
                starts with ``~`` and the home directory cannot be resolved.
        """
        # Expand "~" before the absoluteness check: "~/outputs" is not
        # absolute, so without this it would turn into
        # "<project root>/~/outputs", a directory literally named "~".
        v = v.expanduser()
        if v.is_absolute():
            return v
        return Path(__file__).parent.parent.parent / v
|
|
|
|
|
|
class DatasetConfig(BaseModel):
    """Configuration for synthetic dataset generation."""

    model_config = ConfigDict(extra="ignore")

    # Source dataset location; defaults to <project root>/datasets/InsDet-FULL.
    dataset_root: Path = (
        Path(__file__).parent.parent.parent / "datasets" / "InsDet-FULL"
    )
    # Destination for synthesized data. NOTE: this default is a fixed path,
    # not derived from ``dataset_root`` -- overriding ``dataset_root`` alone
    # does not move ``output_dir``.
    output_dir: Path = (
        Path(__file__).parent.parent.parent / "datasets" / "InsDet-FULL" / "Synthesized"
    )
    # (min, max) object count, validated to min >= 1 and min <= max.
    # Presumably objects placed per scene -- confirm with the generator.
    num_objects_range: tuple[int, int] = (3, 8)
    num_scenes: int = 1000
    # (min, max) object scale, validated to positive values with min <= max.
    object_scale_range: tuple[float, float] = (0.1, 0.4)
    # (min, max) rotation, validated to min <= max.
    # Units presumably degrees -- confirm with the generator.
    rotation_range: tuple[int, int] = (-30, 30)
    # Validated to lie in [0, 1].
    overlap_threshold: float = 0.3
    seed: int = 42

    @field_validator("dataset_root", "output_dir", mode="after")
    @classmethod
    def convert_to_absolute(cls, v: Path) -> Path:
        """Anchor a dataset path so it no longer depends on the working directory.

        A leading ``~`` / ``~user`` is expanded first. A path that is then
        absolute is returned unchanged; any other path is joined onto the
        project root. This works even if the path doesn't exist on disk.

        Args:
            v: The path as parsed from the configuration.

        Returns:
            The expanded path if it is absolute, otherwise the path joined
            onto the project root.

        Raises:
            RuntimeError: Propagated from ``Path.expanduser`` when the path
                starts with ``~`` and the home directory cannot be resolved.
        """
        # Without expansion "~/data" is relative and would silently become
        # "<project root>/~/data".
        v = v.expanduser()
        if v.is_absolute():
            return v
        return Path(__file__).parent.parent.parent / v

    @field_validator("num_objects_range", mode="after")
    @classmethod
    def validate_num_objects(cls, v: tuple[int, int]) -> tuple[int, int]:
        """Require at least one object and a non-inverted (min, max) range."""
        if v[0] < 1 or v[1] < v[0]:
            raise ValueError("num_objects_range must have min >= 1 and min <= max")
        return v

    @field_validator("object_scale_range", mode="after")
    @classmethod
    def validate_scale(cls, v: tuple[float, float]) -> tuple[float, float]:
        """Require strictly positive scales and a non-inverted (min, max) range."""
        if v[0] <= 0 or v[1] <= 0 or v[1] < v[0]:
            raise ValueError(
                "object_scale_range must have positive values and min <= max"
            )
        return v

    @field_validator("rotation_range", mode="after")
    @classmethod
    def validate_rotation(cls, v: tuple[int, int]) -> tuple[int, int]:
        """Reject an inverted (min, max) rotation range.

        Mirrors the ordering checks on the other range fields, so a swapped
        pair fails at config load instead of later during generation.
        """
        if v[1] < v[0]:
            raise ValueError("rotation_range must have min <= max")
        return v

    @field_validator("overlap_threshold", mode="after")
    @classmethod
    def validate_overlap(cls, v: float) -> float:
        """Require the overlap threshold to lie in the closed interval [0, 1]."""
        if not 0 <= v <= 1:
            raise ValueError("overlap_threshold must be between 0 and 1")
        return v
|
|
|
|
|
|
class DatasetSourceConfig(BaseModel):
    """Describes where the benchmark dataset is loaded from.

    ``source_type`` chooses between a HuggingFace dataset and a local one;
    ``path`` holds the matching HuggingFace dataset ID or local path, and the
    two ``*_column`` fields name the image and label columns.
    """

    model_config = ConfigDict(extra="ignore")

    # Only these two literal values are accepted; "huggingface" by default.
    source_type: Literal["huggingface", "local"] = "huggingface"
    # Empty string by default, so a real value must come from the config.
    path: str = Field(
        "",
        description="HuggingFace dataset ID or local path",
    )
    img_column: str = Field(
        "img",
        description="Image column name",
    )
    label_column: str = Field(
        "label",
        description="Label column name",
    )
|
|
|
|
|
|
class BenchmarkTaskConfig(BaseModel):
    """Identifies the benchmark task and its evaluation cut-off.

    By default this describes a ``recall_at_k`` task of type ``retrieval``
    with ``top_k`` set to 10.
    """

    model_config = ConfigDict(extra="ignore")

    name: str = Field("recall_at_k", description="Task name")
    # NOTE: this field is named ``type``, which shadows the builtin inside
    # the class body. It is part of the config schema, so the name stays.
    type: str = Field("retrieval", description="Task type")
    # Must be strictly positive.
    top_k: int = Field(10, gt=0, description="Top K for recall evaluation")
|
|
|
|
|
|
class BenchmarkConfig(BaseModel):
    """Settings for running a benchmark evaluation.

    Bundles the dataset-source and task sub-sections with the loader batch
    size and the prefix used for LanceDB table names.
    """

    model_config = ConfigDict(extra="ignore")

    # Sub-sections: default_factory builds a fresh instance for every
    # BenchmarkConfig rather than sharing one default object.
    dataset: DatasetSourceConfig = Field(default_factory=DatasetSourceConfig)
    task: BenchmarkTaskConfig = Field(default_factory=BenchmarkTaskConfig)

    # Scalar options.
    batch_size: int = Field(
        64,
        gt=0,
        description="Batch size for DataLoader",
    )
    model_table_prefix: str = Field(
        "benchmark",
        description="Prefix for LanceDB table names",
    )
|
|
|
|
|
|
class Config(BaseModel):
    """Top-level configuration object for the feature compressor.

    Each attribute is one configuration section. A section that is absent
    from the input is built from its own defaults via ``default_factory``.
    """

    model_config = ConfigDict(extra="ignore")

    # Vision model / compression settings.
    model: ModelConfig = Field(default_factory=ModelConfig)
    # Where outputs are written.
    output: OutputConfig = Field(default_factory=OutputConfig)
    # Synthetic dataset generation.
    dataset: DatasetConfig = Field(default_factory=DatasetConfig)
    # Benchmark evaluation.
    benchmark: BenchmarkConfig = Field(default_factory=BenchmarkConfig)
|