mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
refactor(configs, data_loading): improve code clarity and add docstrings
This commit is contained in:
@@ -1,3 +1,5 @@
|
|||||||
|
"""Configuration management module for unified config."""
|
||||||
|
|
||||||
from .config import (
|
from .config import (
|
||||||
ConfigManager,
|
ConfigManager,
|
||||||
cfg_manager,
|
cfg_manager,
|
||||||
@@ -8,7 +10,6 @@ from .models import (
|
|||||||
DatasetConfig,
|
DatasetConfig,
|
||||||
ModelConfig,
|
ModelConfig,
|
||||||
OutputConfig,
|
OutputConfig,
|
||||||
PoolingType,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@@ -17,7 +18,6 @@ __all__ = [
|
|||||||
"OutputConfig",
|
"OutputConfig",
|
||||||
"DatasetConfig",
|
"DatasetConfig",
|
||||||
"Config",
|
"Config",
|
||||||
"PoolingType",
|
|
||||||
# Loader
|
# Loader
|
||||||
"load_yaml",
|
"load_yaml",
|
||||||
"save_yaml",
|
"save_yaml",
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ class ConfigManager:
|
|||||||
"""Load configuration from config.yaml file.
|
"""Load configuration from config.yaml file.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Loaded and validated FeatureCompressorConfig instance
|
Loaded and validated Config instance
|
||||||
"""
|
"""
|
||||||
config = load_yaml(self.config_path, Config)
|
config = load_yaml(self.config_path, Config)
|
||||||
self._config = config
|
self._config = config
|
||||||
@@ -38,7 +38,7 @@ class ConfigManager:
|
|||||||
"""Get loaded configuration, auto-loading if not already loaded.
|
"""Get loaded configuration, auto-loading if not already loaded.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
FeatureCompressorConfig instance
|
Config instance
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
Automatically loads config if not already loaded
|
Automatically loads config if not already loaded
|
||||||
|
|||||||
@@ -26,8 +26,8 @@ class OutputConfig(BaseModel):
|
|||||||
|
|
||||||
@field_validator("directory", mode="after")
|
@field_validator("directory", mode="after")
|
||||||
def convert_to_absolute(cls, v: Path) -> Path:
|
def convert_to_absolute(cls, v: Path) -> Path:
|
||||||
"""
|
"""Converts the path to an absolute path relative to the project root.
|
||||||
Converts the path to an absolute path relative to the current working directory.
|
|
||||||
This works even if the path doesn't exist on disk.
|
This works even if the path doesn't exist on disk.
|
||||||
"""
|
"""
|
||||||
if v.is_absolute():
|
if v.is_absolute():
|
||||||
@@ -55,8 +55,8 @@ class DatasetConfig(BaseModel):
|
|||||||
|
|
||||||
@field_validator("dataset_root", "output_dir", mode="after")
|
@field_validator("dataset_root", "output_dir", mode="after")
|
||||||
def convert_to_absolute(cls, v: Path) -> Path:
|
def convert_to_absolute(cls, v: Path) -> Path:
|
||||||
"""
|
"""Converts the path to an absolute path relative to the project root.
|
||||||
Converts the path to an absolute path relative to the project root.
|
|
||||||
This works even if the path doesn't exist on disk.
|
This works even if the path doesn't exist on disk.
|
||||||
"""
|
"""
|
||||||
if v.is_absolute():
|
if v.is_absolute():
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
"""Data loading module for synthetic and validation datasets."""
|
||||||
|
|
||||||
from .loader import load_synth_dataset, load_val_dataset
|
from .loader import load_synth_dataset, load_val_dataset
|
||||||
from .synthesizer import ImageSynthesizer
|
from .synthesizer import ImageSynthesizer
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,93 @@
|
|||||||
"""Data loaders for synthetic and validation datasets using Hugging Face datasets."""
|
"""Data loaders for synthetic and validation datasets using Hugging Face datasets."""
|
||||||
|
|
||||||
|
import xml.etree.ElementTree as ET
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from datasets import Dataset, Image
|
from datasets import Dataset, Image
|
||||||
|
|
||||||
|
# Type alias for objects annotation
|
||||||
|
ObjectsDict = dict[str, list[list[float]] | list[str] | list[int] | list[float]]
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_bbox_line(line: str) -> tuple[str, list[float], float] | None:
|
||||||
|
"""Parse a single line from synth annotation file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
line: Line in format "category xmin ymin xmax ymax"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (category, [xmin, ymin, width, height], area) or None if invalid
|
||||||
|
"""
|
||||||
|
parts = line.split()
|
||||||
|
if len(parts) != 5:
|
||||||
|
return None
|
||||||
|
|
||||||
|
category = parts[0]
|
||||||
|
xmin, ymin, xmax, ymax = map(int, parts[1:])
|
||||||
|
width = xmax - xmin
|
||||||
|
height = ymax - ymin
|
||||||
|
area = width * height
|
||||||
|
|
||||||
|
return category, [float(xmin), float(ymin), float(width), float(height)], float(area)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_element_text(element: ET.Element | None, default: str = "0") -> str:
|
||||||
|
"""Get text from XML element, returning default if element or text is None."""
|
||||||
|
if element is None:
|
||||||
|
return default
|
||||||
|
return element.text if element.text is not None else default
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_voc_xml(xml_path: Path) -> ObjectsDict:
|
||||||
|
"""Parse a VOC-format XML annotation file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
xml_path: Path to the XML annotation file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary containing:
|
||||||
|
- bbox: List of bounding boxes in [xmin, ymin, width, height] format
|
||||||
|
- category: List of object category names
|
||||||
|
- area: List of bounding box areas
|
||||||
|
- id: List of object IDs (0-based indices)
|
||||||
|
"""
|
||||||
|
tree = ET.parse(xml_path)
|
||||||
|
root = tree.getroot()
|
||||||
|
|
||||||
|
bboxes: list[list[float]] = []
|
||||||
|
categories: list[str] = []
|
||||||
|
areas: list[float] = []
|
||||||
|
|
||||||
|
for obj in root.findall("object"):
|
||||||
|
name_elem = obj.find("name")
|
||||||
|
bndbox = obj.find("bndbox")
|
||||||
|
|
||||||
|
if name_elem is None or bndbox is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
name = name_elem.text
|
||||||
|
if name is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
xmin = int(_get_element_text(bndbox.find("xmin")))
|
||||||
|
ymin = int(_get_element_text(bndbox.find("ymin")))
|
||||||
|
xmax = int(_get_element_text(bndbox.find("xmax")))
|
||||||
|
ymax = int(_get_element_text(bndbox.find("ymax")))
|
||||||
|
|
||||||
|
width = xmax - xmin
|
||||||
|
height = ymax - ymin
|
||||||
|
|
||||||
|
bboxes.append([float(xmin), float(ymin), float(width), float(height)])
|
||||||
|
categories.append(name)
|
||||||
|
areas.append(float(width * height))
|
||||||
|
|
||||||
|
return {
|
||||||
|
"bbox": bboxes,
|
||||||
|
"category": categories,
|
||||||
|
"area": areas,
|
||||||
|
"id": list(range(len(bboxes))),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def load_synth_dataset(
|
def load_synth_dataset(
|
||||||
synth_dir: Path,
|
synth_dir: Path,
|
||||||
@@ -14,10 +97,23 @@ def load_synth_dataset(
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
synth_dir: Directory containing synthesized images and annotations
|
synth_dir: Directory containing synthesized images and annotations
|
||||||
annotations_suffix: Suffix for annotation files
|
annotations_suffix: Suffix for annotation files (default: ".txt")
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Hugging Face Dataset with image and objects columns
|
Hugging Face Dataset with the following columns:
|
||||||
|
- image: PIL Image
|
||||||
|
- objects: dict containing:
|
||||||
|
- bbox: List of bounding boxes in [xmin, ymin, width, height] format
|
||||||
|
- category: List of object category names
|
||||||
|
- area: List of bounding box areas
|
||||||
|
- id: List of object IDs (0-based indices)
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> dataset = load_synth_dataset(Path("outputs/synth"))
|
||||||
|
>>> sample = dataset[0]
|
||||||
|
>>> # sample["image"] - PIL Image
|
||||||
|
>>> # sample["objects"]["bbox"] - [[xmin, ymin, width, height], ...]
|
||||||
|
>>> # sample["objects"]["category"] - ["category_name", ...]
|
||||||
"""
|
"""
|
||||||
synth_dir = Path(synth_dir)
|
synth_dir = Path(synth_dir)
|
||||||
image_files = sorted(synth_dir.glob("synth_*.jpg"))
|
image_files = sorted(synth_dir.glob("synth_*.jpg"))
|
||||||
@@ -26,7 +122,7 @@ def load_synth_dataset(
|
|||||||
return Dataset.from_dict({"image": [], "objects": []}).cast_column("image", Image())
|
return Dataset.from_dict({"image": [], "objects": []}).cast_column("image", Image())
|
||||||
|
|
||||||
image_paths: list[str] = []
|
image_paths: list[str] = []
|
||||||
all_objects: list[dict[str, Any]] = []
|
all_objects: list[ObjectsDict] = []
|
||||||
|
|
||||||
for img_path in image_files:
|
for img_path in image_files:
|
||||||
image_paths.append(str(img_path))
|
image_paths.append(str(img_path))
|
||||||
@@ -39,26 +135,24 @@ def load_synth_dataset(
|
|||||||
bboxes: list[list[float]] = []
|
bboxes: list[list[float]] = []
|
||||||
categories: list[str] = []
|
categories: list[str] = []
|
||||||
areas: list[float] = []
|
areas: list[float] = []
|
||||||
ids: list[int] = []
|
obj_id = 0
|
||||||
|
|
||||||
with open(anno_path, "r") as f:
|
with open(anno_path, "r", encoding="utf-8") as f:
|
||||||
for idx, line in enumerate(f):
|
for line in f:
|
||||||
if not (line := line.strip()):
|
if not (line := line.strip()):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
parts = line.split()
|
result = _parse_bbox_line(line)
|
||||||
if len(parts) != 5:
|
if result is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
xmin, ymin, xmax, ymax = map(int, parts[1:])
|
category, bbox, area = result
|
||||||
width, height = xmax - xmin, ymax - ymin
|
bboxes.append(bbox)
|
||||||
|
categories.append(category)
|
||||||
|
areas.append(area)
|
||||||
|
obj_id += 1
|
||||||
|
|
||||||
bboxes.append([float(xmin), float(ymin), float(width), float(height)])
|
all_objects.append({"bbox": bboxes, "category": categories, "area": areas, "id": list(range(len(bboxes)))})
|
||||||
categories.append(parts[0])
|
|
||||||
areas.append(float(width * height))
|
|
||||||
ids.append(idx)
|
|
||||||
|
|
||||||
all_objects.append({"bbox": bboxes, "category": categories, "area": areas, "id": ids})
|
|
||||||
|
|
||||||
dataset = Dataset.from_dict({"image": image_paths, "objects": all_objects})
|
dataset = Dataset.from_dict({"image": image_paths, "objects": all_objects})
|
||||||
return dataset.cast_column("image", Image())
|
return dataset.cast_column("image", Image())
|
||||||
@@ -68,14 +162,29 @@ def load_val_dataset(
|
|||||||
scenes_dir: Path,
|
scenes_dir: Path,
|
||||||
split: str = "easy",
|
split: str = "easy",
|
||||||
) -> Dataset:
|
) -> Dataset:
|
||||||
"""Load validation dataset from scene images.
|
"""Load validation dataset from scene images with VOC-format XML annotations.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
scenes_dir: Directory containing scene subdirectories
|
scenes_dir: Directory containing scene subdirectories
|
||||||
split: Scene split to load ('easy' or 'hard')
|
split: Scene split to load ('easy' or 'hard')
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Hugging Face Dataset with image and image_id columns
|
Hugging Face Dataset with the following columns:
|
||||||
|
- image: PIL Image
|
||||||
|
- image_id: Image identifier (filename stem without extension)
|
||||||
|
- objects: dict containing (loaded from XML annotations):
|
||||||
|
- bbox: List of bounding boxes in [xmin, ymin, width, height] format
|
||||||
|
- category: List of object category names
|
||||||
|
- area: List of bounding box areas
|
||||||
|
- id: List of object IDs (0-based indices)
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> dataset = load_val_dataset(Path("datasets/InsDet-FULL/Scenes"), "easy")
|
||||||
|
>>> sample = dataset[0]
|
||||||
|
>>> # sample["image"] - PIL Image
|
||||||
|
>>> # sample["image_id"] - "rgb_000"
|
||||||
|
>>> # sample["objects"]["bbox"] - [[xmin, ymin, width, height], ...]
|
||||||
|
>>> # sample["objects"]["category"] - ["category_name", ...]
|
||||||
"""
|
"""
|
||||||
scenes_dir = Path(scenes_dir)
|
scenes_dir = Path(scenes_dir)
|
||||||
split_dir = scenes_dir / split
|
split_dir = scenes_dir / split
|
||||||
@@ -86,11 +195,27 @@ def load_val_dataset(
|
|||||||
rgb_files = sorted(split_dir.glob("*/rgb_*.jpg"))
|
rgb_files = sorted(split_dir.glob("*/rgb_*.jpg"))
|
||||||
|
|
||||||
if not rgb_files:
|
if not rgb_files:
|
||||||
return Dataset.from_dict({"image": [], "image_id": []}).cast_column("image", Image())
|
return Dataset.from_dict({"image": [], "image_id": [], "objects": []}).cast_column("image", Image())
|
||||||
|
|
||||||
|
image_paths: list[str] = []
|
||||||
|
image_ids: list[str] = []
|
||||||
|
all_objects: list[ObjectsDict] = []
|
||||||
|
|
||||||
|
for img_path in rgb_files:
|
||||||
|
image_paths.append(str(img_path))
|
||||||
|
image_ids.append(img_path.stem)
|
||||||
|
|
||||||
|
xml_path = img_path.with_suffix(".xml")
|
||||||
|
if xml_path.exists():
|
||||||
|
objects: ObjectsDict = _parse_voc_xml(xml_path)
|
||||||
|
else:
|
||||||
|
objects = {"bbox": [], "category": [], "area": [], "id": []}
|
||||||
|
|
||||||
|
all_objects.append(objects)
|
||||||
|
|
||||||
dataset = Dataset.from_dict({
|
dataset = Dataset.from_dict({
|
||||||
"image": [str(p) for p in rgb_files],
|
"image": image_paths,
|
||||||
"image_id": [p.stem for p in rgb_files],
|
"image_id": image_ids,
|
||||||
|
"objects": all_objects,
|
||||||
})
|
})
|
||||||
|
|
||||||
return dataset.cast_column("image", Image())
|
return dataset.cast_column("image", Image())
|
||||||
|
|||||||
@@ -57,7 +57,11 @@ class ImageSynthesizer:
|
|||||||
"""List of background image paths."""
|
"""List of background image paths."""
|
||||||
if self._background_categories is None:
|
if self._background_categories is None:
|
||||||
self._background_categories = sorted(
|
self._background_categories = sorted(
|
||||||
[p.name for p in self.background_dir.iterdir() if p.suffix in [".jpg", ".jpeg", ".png"]]
|
[
|
||||||
|
p.name
|
||||||
|
for p in self.background_dir.iterdir()
|
||||||
|
if p.suffix in [".jpg", ".jpeg", ".png"]
|
||||||
|
]
|
||||||
)
|
)
|
||||||
# Return as list of Path for type compatibility
|
# Return as list of Path for type compatibility
|
||||||
return [self.background_dir / name for name in self._background_categories] # type: ignore[return-value]
|
return [self.background_dir / name for name in self._background_categories] # type: ignore[return-value]
|
||||||
@@ -126,7 +130,9 @@ class ImageSynthesizer:
|
|||||||
mask = mask.rotate(angle, resample=Resampling.BILINEAR, expand=True)
|
mask = mask.rotate(angle, resample=Resampling.BILINEAR, expand=True)
|
||||||
return image, mask
|
return image, mask
|
||||||
|
|
||||||
def _compute_overlap(self, box1: tuple[int, int, int, int], box2: tuple[int, int, int, int]) -> float:
|
def _compute_overlap(
|
||||||
|
self, box1: tuple[int, int, int, int], box2: tuple[int, int, int, int]
|
||||||
|
) -> float:
|
||||||
"""Compute overlap ratio between two boxes.
|
"""Compute overlap ratio between two boxes.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -153,7 +159,26 @@ class ImageSynthesizer:
|
|||||||
box2_area = (x2_max - x2_min) * (y2_max - y2_min)
|
box2_area = (x2_max - x2_min) * (y2_max - y2_min)
|
||||||
min_area = min(box1_area, box2_area)
|
min_area = min(box1_area, box2_area)
|
||||||
|
|
||||||
return inter_area / min_area if min_area > 0 else 0.0
|
return inter_area / (min_area if min_area > 0 else 0.0)
|
||||||
|
|
||||||
|
def _has_overlap(
|
||||||
|
self,
|
||||||
|
new_box: tuple[int, int, int, int],
|
||||||
|
existing_boxes: list[tuple[int, int, int, int]],
|
||||||
|
) -> bool:
|
||||||
|
"""Check if new_box overlaps with any existing boxes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
new_box: The new bounding box (xmin, ymin, xmax, ymax)
|
||||||
|
existing_boxes: List of existing bounding boxes
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if any overlap exceeds threshold, False otherwise
|
||||||
|
"""
|
||||||
|
for existing_box in existing_boxes:
|
||||||
|
if self._compute_overlap(new_box, existing_box) > self.overlap_threshold:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def _place_object(
|
def _place_object(
|
||||||
self,
|
self,
|
||||||
@@ -182,7 +207,7 @@ class ImageSynthesizer:
|
|||||||
new_w = int(obj_w * scale)
|
new_w = int(obj_w * scale)
|
||||||
new_h = int(obj_h * scale)
|
new_h = int(obj_h * scale)
|
||||||
|
|
||||||
if new_w <= 0 or new_h <= 0:
|
if new_w <= 0 or new_h <= 0 or new_w > bg_w or new_h > bg_h:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
obj_image = obj_image.resize((new_w, new_h), Resampling.LANCZOS)
|
obj_image = obj_image.resize((new_w, new_h), Resampling.LANCZOS)
|
||||||
@@ -198,32 +223,18 @@ class ImageSynthesizer:
|
|||||||
new_box = (x, y, x + new_w, y + new_h)
|
new_box = (x, y, x + new_w, y + new_h)
|
||||||
|
|
||||||
# Check overlap with existing boxes
|
# Check overlap with existing boxes
|
||||||
valid = True
|
if not self._has_overlap(new_box, existing_boxes):
|
||||||
for existing_box in existing_boxes:
|
# Composite object onto background using Pillow's paste method
|
||||||
overlap = self._compute_overlap(new_box, existing_box)
|
|
||||||
if overlap > self.overlap_threshold:
|
|
||||||
valid = False
|
|
||||||
break
|
|
||||||
|
|
||||||
if valid:
|
|
||||||
# Composite object onto background
|
|
||||||
background = background.copy()
|
background = background.copy()
|
||||||
mask_array = np.array(obj_mask) / 255.0
|
background.paste(obj_image, (x, y), mask=obj_mask)
|
||||||
bg_array = np.array(background)
|
|
||||||
obj_array = np.array(obj_image)
|
|
||||||
|
|
||||||
# Apply mask
|
return background, obj_mask, new_box
|
||||||
mask_3d = np.stack([mask_array] * 3, axis=-1)
|
|
||||||
bg_array[y:y+new_h, x:x+new_w] = (
|
|
||||||
bg_array[y:y+new_h, x:x+new_w] * (1 - mask_3d) +
|
|
||||||
obj_array * mask_3d
|
|
||||||
)
|
|
||||||
|
|
||||||
return Image.fromarray(bg_array), obj_mask, new_box
|
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def synthesize_scene(self) -> tuple[Image.Image, list[tuple[str, int, int, int, int]]]:
|
def synthesize_scene(
|
||||||
|
self,
|
||||||
|
) -> tuple[Image.Image, list[tuple[str, int, int, int, int]]]:
|
||||||
"""Synthesize a single scene with random objects.
|
"""Synthesize a single scene with random objects.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -251,10 +262,14 @@ class ImageSynthesizer:
|
|||||||
|
|
||||||
# Get random rotation
|
# Get random rotation
|
||||||
angle = random.uniform(*self.rotation_range)
|
angle = random.uniform(*self.rotation_range)
|
||||||
obj_image, obj_mask = self._rotate_image_and_mask(obj_image, obj_mask, angle)
|
obj_image, obj_mask = self._rotate_image_and_mask(
|
||||||
|
obj_image, obj_mask, angle
|
||||||
|
)
|
||||||
|
|
||||||
# Try to place object
|
# Try to place object
|
||||||
result = self._place_object(background, obj_image, obj_mask, placed_boxes, scale)
|
result = self._place_object(
|
||||||
|
background, obj_image, obj_mask, placed_boxes, scale
|
||||||
|
)
|
||||||
|
|
||||||
if result is not None:
|
if result is not None:
|
||||||
background, _, box = result
|
background, _, box = result
|
||||||
@@ -286,7 +301,7 @@ class ImageSynthesizer:
|
|||||||
|
|
||||||
# Save annotation
|
# Save annotation
|
||||||
anno_path = self.output_dir / f"synth_{i:04d}.txt"
|
anno_path = self.output_dir / f"synth_{i:04d}.txt"
|
||||||
with open(anno_path, "w") as f:
|
with open(anno_path, "w", encoding="utf-8") as f:
|
||||||
for category, xmin, ymin, xmax, ymax in annotations:
|
for category, xmin, ymin, xmax, ymax in annotations:
|
||||||
f.write(f"{category} {xmin} {ymin} {xmax} {ymax}\n")
|
f.write(f"{category} {xmin} {ymin} {xmax} {ymax}\n")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user