mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
refactor(configs, data_loading): improve code clarity and add docstrings
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
"""Configuration management module for unified config."""
|
||||
|
||||
from .config import (
|
||||
ConfigManager,
|
||||
cfg_manager,
|
||||
@@ -8,7 +10,6 @@ from .models import (
|
||||
DatasetConfig,
|
||||
ModelConfig,
|
||||
OutputConfig,
|
||||
PoolingType,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -17,7 +18,6 @@ __all__ = [
|
||||
"OutputConfig",
|
||||
"DatasetConfig",
|
||||
"Config",
|
||||
"PoolingType",
|
||||
# Loader
|
||||
"load_yaml",
|
||||
"save_yaml",
|
||||
|
||||
@@ -28,7 +28,7 @@ class ConfigManager:
|
||||
"""Load configuration from config.yaml file.
|
||||
|
||||
Returns:
|
||||
Loaded and validated FeatureCompressorConfig instance
|
||||
Loaded and validated Config instance
|
||||
"""
|
||||
config = load_yaml(self.config_path, Config)
|
||||
self._config = config
|
||||
@@ -38,7 +38,7 @@ class ConfigManager:
|
||||
"""Get loaded configuration, auto-loading if not already loaded.
|
||||
|
||||
Returns:
|
||||
FeatureCompressorConfig instance
|
||||
Config instance
|
||||
|
||||
Note:
|
||||
Automatically loads config if not already loaded
|
||||
|
||||
@@ -26,8 +26,8 @@ class OutputConfig(BaseModel):
|
||||
|
||||
@field_validator("directory", mode="after")
|
||||
def convert_to_absolute(cls, v: Path) -> Path:
|
||||
"""
|
||||
Converts the path to an absolute path relative to the current working directory.
|
||||
"""Converts the path to an absolute path relative to the project root.
|
||||
|
||||
This works even if the path doesn't exist on disk.
|
||||
"""
|
||||
if v.is_absolute():
|
||||
@@ -55,8 +55,8 @@ class DatasetConfig(BaseModel):
|
||||
|
||||
@field_validator("dataset_root", "output_dir", mode="after")
|
||||
def convert_to_absolute(cls, v: Path) -> Path:
|
||||
"""
|
||||
Converts the path to an absolute path relative to the project root.
|
||||
"""Converts the path to an absolute path relative to the project root.
|
||||
|
||||
This works even if the path doesn't exist on disk.
|
||||
"""
|
||||
if v.is_absolute():
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Data loading module for synthetic and validation datasets."""
|
||||
|
||||
from .loader import load_synth_dataset, load_val_dataset
|
||||
from .synthesizer import ImageSynthesizer
|
||||
|
||||
|
||||
@@ -1,10 +1,93 @@
|
||||
"""Data loaders for synthetic and validation datasets using Hugging Face datasets."""
|
||||
|
||||
import xml.etree.ElementTree as ET
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from datasets import Dataset, Image
|
||||
|
||||
# Type alias for objects annotation
|
||||
ObjectsDict = dict[str, list[list[float]] | list[str] | list[int] | list[float]]
|
||||
|
||||
|
||||
def _parse_bbox_line(line: str) -> tuple[str, list[float], float] | None:
|
||||
"""Parse a single line from synth annotation file.
|
||||
|
||||
Args:
|
||||
line: Line in format "category xmin ymin xmax ymax"
|
||||
|
||||
Returns:
|
||||
Tuple of (category, [xmin, ymin, width, height], area) or None if invalid
|
||||
"""
|
||||
parts = line.split()
|
||||
if len(parts) != 5:
|
||||
return None
|
||||
|
||||
category = parts[0]
|
||||
xmin, ymin, xmax, ymax = map(int, parts[1:])
|
||||
width = xmax - xmin
|
||||
height = ymax - ymin
|
||||
area = width * height
|
||||
|
||||
return category, [float(xmin), float(ymin), float(width), float(height)], float(area)
|
||||
|
||||
|
||||
def _get_element_text(element: ET.Element | None, default: str = "0") -> str:
|
||||
"""Get text from XML element, returning default if element or text is None."""
|
||||
if element is None:
|
||||
return default
|
||||
return element.text if element.text is not None else default
|
||||
|
||||
|
||||
def _parse_voc_xml(xml_path: Path) -> ObjectsDict:
|
||||
"""Parse a VOC-format XML annotation file.
|
||||
|
||||
Args:
|
||||
xml_path: Path to the XML annotation file
|
||||
|
||||
Returns:
|
||||
Dictionary containing:
|
||||
- bbox: List of bounding boxes in [xmin, ymin, width, height] format
|
||||
- category: List of object category names
|
||||
- area: List of bounding box areas
|
||||
- id: List of object IDs (0-based indices)
|
||||
"""
|
||||
tree = ET.parse(xml_path)
|
||||
root = tree.getroot()
|
||||
|
||||
bboxes: list[list[float]] = []
|
||||
categories: list[str] = []
|
||||
areas: list[float] = []
|
||||
|
||||
for obj in root.findall("object"):
|
||||
name_elem = obj.find("name")
|
||||
bndbox = obj.find("bndbox")
|
||||
|
||||
if name_elem is None or bndbox is None:
|
||||
continue
|
||||
|
||||
name = name_elem.text
|
||||
if name is None:
|
||||
continue
|
||||
|
||||
xmin = int(_get_element_text(bndbox.find("xmin")))
|
||||
ymin = int(_get_element_text(bndbox.find("ymin")))
|
||||
xmax = int(_get_element_text(bndbox.find("xmax")))
|
||||
ymax = int(_get_element_text(bndbox.find("ymax")))
|
||||
|
||||
width = xmax - xmin
|
||||
height = ymax - ymin
|
||||
|
||||
bboxes.append([float(xmin), float(ymin), float(width), float(height)])
|
||||
categories.append(name)
|
||||
areas.append(float(width * height))
|
||||
|
||||
return {
|
||||
"bbox": bboxes,
|
||||
"category": categories,
|
||||
"area": areas,
|
||||
"id": list(range(len(bboxes))),
|
||||
}
|
||||
|
||||
|
||||
def load_synth_dataset(
|
||||
synth_dir: Path,
|
||||
@@ -14,10 +97,23 @@ def load_synth_dataset(
|
||||
|
||||
Args:
|
||||
synth_dir: Directory containing synthesized images and annotations
|
||||
annotations_suffix: Suffix for annotation files
|
||||
annotations_suffix: Suffix for annotation files (default: ".txt")
|
||||
|
||||
Returns:
|
||||
Hugging Face Dataset with image and objects columns
|
||||
Hugging Face Dataset with the following columns:
|
||||
- image: PIL Image
|
||||
- objects: dict containing:
|
||||
- bbox: List of bounding boxes in [xmin, ymin, width, height] format
|
||||
- category: List of object category names
|
||||
- area: List of bounding box areas
|
||||
- id: List of object IDs (0-based indices)
|
||||
|
||||
Example:
|
||||
>>> dataset = load_synth_dataset(Path("outputs/synth"))
|
||||
>>> sample = dataset[0]
|
||||
>>> # sample["image"] - PIL Image
|
||||
>>> # sample["objects"]["bbox"] - [[xmin, ymin, width, height], ...]
|
||||
>>> # sample["objects"]["category"] - ["category_name", ...]
|
||||
"""
|
||||
synth_dir = Path(synth_dir)
|
||||
image_files = sorted(synth_dir.glob("synth_*.jpg"))
|
||||
@@ -26,7 +122,7 @@ def load_synth_dataset(
|
||||
return Dataset.from_dict({"image": [], "objects": []}).cast_column("image", Image())
|
||||
|
||||
image_paths: list[str] = []
|
||||
all_objects: list[dict[str, Any]] = []
|
||||
all_objects: list[ObjectsDict] = []
|
||||
|
||||
for img_path in image_files:
|
||||
image_paths.append(str(img_path))
|
||||
@@ -39,26 +135,24 @@ def load_synth_dataset(
|
||||
bboxes: list[list[float]] = []
|
||||
categories: list[str] = []
|
||||
areas: list[float] = []
|
||||
ids: list[int] = []
|
||||
obj_id = 0
|
||||
|
||||
with open(anno_path, "r") as f:
|
||||
for idx, line in enumerate(f):
|
||||
with open(anno_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if not (line := line.strip()):
|
||||
continue
|
||||
|
||||
parts = line.split()
|
||||
if len(parts) != 5:
|
||||
result = _parse_bbox_line(line)
|
||||
if result is None:
|
||||
continue
|
||||
|
||||
xmin, ymin, xmax, ymax = map(int, parts[1:])
|
||||
width, height = xmax - xmin, ymax - ymin
|
||||
category, bbox, area = result
|
||||
bboxes.append(bbox)
|
||||
categories.append(category)
|
||||
areas.append(area)
|
||||
obj_id += 1
|
||||
|
||||
bboxes.append([float(xmin), float(ymin), float(width), float(height)])
|
||||
categories.append(parts[0])
|
||||
areas.append(float(width * height))
|
||||
ids.append(idx)
|
||||
|
||||
all_objects.append({"bbox": bboxes, "category": categories, "area": areas, "id": ids})
|
||||
all_objects.append({"bbox": bboxes, "category": categories, "area": areas, "id": list(range(len(bboxes)))})
|
||||
|
||||
dataset = Dataset.from_dict({"image": image_paths, "objects": all_objects})
|
||||
return dataset.cast_column("image", Image())
|
||||
@@ -68,14 +162,29 @@ def load_val_dataset(
|
||||
scenes_dir: Path,
|
||||
split: str = "easy",
|
||||
) -> Dataset:
|
||||
"""Load validation dataset from scene images.
|
||||
"""Load validation dataset from scene images with VOC-format XML annotations.
|
||||
|
||||
Args:
|
||||
scenes_dir: Directory containing scene subdirectories
|
||||
split: Scene split to load ('easy' or 'hard')
|
||||
|
||||
Returns:
|
||||
Hugging Face Dataset with image and image_id columns
|
||||
Hugging Face Dataset with the following columns:
|
||||
- image: PIL Image
|
||||
- image_id: Image identifier (filename stem without extension)
|
||||
- objects: dict containing (loaded from XML annotations):
|
||||
- bbox: List of bounding boxes in [xmin, ymin, width, height] format
|
||||
- category: List of object category names
|
||||
- area: List of bounding box areas
|
||||
- id: List of object IDs (0-based indices)
|
||||
|
||||
Example:
|
||||
>>> dataset = load_val_dataset(Path("datasets/InsDet-FULL/Scenes"), "easy")
|
||||
>>> sample = dataset[0]
|
||||
>>> # sample["image"] - PIL Image
|
||||
>>> # sample["image_id"] - "rgb_000"
|
||||
>>> # sample["objects"]["bbox"] - [[xmin, ymin, width, height], ...]
|
||||
>>> # sample["objects"]["category"] - ["category_name", ...]
|
||||
"""
|
||||
scenes_dir = Path(scenes_dir)
|
||||
split_dir = scenes_dir / split
|
||||
@@ -86,11 +195,27 @@ def load_val_dataset(
|
||||
rgb_files = sorted(split_dir.glob("*/rgb_*.jpg"))
|
||||
|
||||
if not rgb_files:
|
||||
return Dataset.from_dict({"image": [], "image_id": []}).cast_column("image", Image())
|
||||
return Dataset.from_dict({"image": [], "image_id": [], "objects": []}).cast_column("image", Image())
|
||||
|
||||
image_paths: list[str] = []
|
||||
image_ids: list[str] = []
|
||||
all_objects: list[ObjectsDict] = []
|
||||
|
||||
for img_path in rgb_files:
|
||||
image_paths.append(str(img_path))
|
||||
image_ids.append(img_path.stem)
|
||||
|
||||
xml_path = img_path.with_suffix(".xml")
|
||||
if xml_path.exists():
|
||||
objects: ObjectsDict = _parse_voc_xml(xml_path)
|
||||
else:
|
||||
objects = {"bbox": [], "category": [], "area": [], "id": []}
|
||||
|
||||
all_objects.append(objects)
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"image": [str(p) for p in rgb_files],
|
||||
"image_id": [p.stem for p in rgb_files],
|
||||
"image": image_paths,
|
||||
"image_id": image_ids,
|
||||
"objects": all_objects,
|
||||
})
|
||||
|
||||
return dataset.cast_column("image", Image())
|
||||
|
||||
@@ -57,7 +57,11 @@ class ImageSynthesizer:
|
||||
"""List of background image paths."""
|
||||
if self._background_categories is None:
|
||||
self._background_categories = sorted(
|
||||
[p.name for p in self.background_dir.iterdir() if p.suffix in [".jpg", ".jpeg", ".png"]]
|
||||
[
|
||||
p.name
|
||||
for p in self.background_dir.iterdir()
|
||||
if p.suffix in [".jpg", ".jpeg", ".png"]
|
||||
]
|
||||
)
|
||||
# Return as list of Path for type compatibility
|
||||
return [self.background_dir / name for name in self._background_categories] # type: ignore[return-value]
|
||||
@@ -126,7 +130,9 @@ class ImageSynthesizer:
|
||||
mask = mask.rotate(angle, resample=Resampling.BILINEAR, expand=True)
|
||||
return image, mask
|
||||
|
||||
def _compute_overlap(self, box1: tuple[int, int, int, int], box2: tuple[int, int, int, int]) -> float:
|
||||
def _compute_overlap(
|
||||
self, box1: tuple[int, int, int, int], box2: tuple[int, int, int, int]
|
||||
) -> float:
|
||||
"""Compute overlap ratio between two boxes.
|
||||
|
||||
Args:
|
||||
@@ -153,7 +159,26 @@ class ImageSynthesizer:
|
||||
box2_area = (x2_max - x2_min) * (y2_max - y2_min)
|
||||
min_area = min(box1_area, box2_area)
|
||||
|
||||
return inter_area / min_area if min_area > 0 else 0.0
|
||||
return inter_area / (min_area if min_area > 0 else 0.0)
|
||||
|
||||
def _has_overlap(
|
||||
self,
|
||||
new_box: tuple[int, int, int, int],
|
||||
existing_boxes: list[tuple[int, int, int, int]],
|
||||
) -> bool:
|
||||
"""Check if new_box overlaps with any existing boxes.
|
||||
|
||||
Args:
|
||||
new_box: The new bounding box (xmin, ymin, xmax, ymax)
|
||||
existing_boxes: List of existing bounding boxes
|
||||
|
||||
Returns:
|
||||
True if any overlap exceeds threshold, False otherwise
|
||||
"""
|
||||
for existing_box in existing_boxes:
|
||||
if self._compute_overlap(new_box, existing_box) > self.overlap_threshold:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _place_object(
|
||||
self,
|
||||
@@ -182,7 +207,7 @@ class ImageSynthesizer:
|
||||
new_w = int(obj_w * scale)
|
||||
new_h = int(obj_h * scale)
|
||||
|
||||
if new_w <= 0 or new_h <= 0:
|
||||
if new_w <= 0 or new_h <= 0 or new_w > bg_w or new_h > bg_h:
|
||||
return None
|
||||
|
||||
obj_image = obj_image.resize((new_w, new_h), Resampling.LANCZOS)
|
||||
@@ -198,32 +223,18 @@ class ImageSynthesizer:
|
||||
new_box = (x, y, x + new_w, y + new_h)
|
||||
|
||||
# Check overlap with existing boxes
|
||||
valid = True
|
||||
for existing_box in existing_boxes:
|
||||
overlap = self._compute_overlap(new_box, existing_box)
|
||||
if overlap > self.overlap_threshold:
|
||||
valid = False
|
||||
break
|
||||
|
||||
if valid:
|
||||
# Composite object onto background
|
||||
if not self._has_overlap(new_box, existing_boxes):
|
||||
# Composite object onto background using Pillow's paste method
|
||||
background = background.copy()
|
||||
mask_array = np.array(obj_mask) / 255.0
|
||||
bg_array = np.array(background)
|
||||
obj_array = np.array(obj_image)
|
||||
background.paste(obj_image, (x, y), mask=obj_mask)
|
||||
|
||||
# Apply mask
|
||||
mask_3d = np.stack([mask_array] * 3, axis=-1)
|
||||
bg_array[y:y+new_h, x:x+new_w] = (
|
||||
bg_array[y:y+new_h, x:x+new_w] * (1 - mask_3d) +
|
||||
obj_array * mask_3d
|
||||
)
|
||||
|
||||
return Image.fromarray(bg_array), obj_mask, new_box
|
||||
return background, obj_mask, new_box
|
||||
|
||||
return None
|
||||
|
||||
def synthesize_scene(self) -> tuple[Image.Image, list[tuple[str, int, int, int, int]]]:
|
||||
def synthesize_scene(
|
||||
self,
|
||||
) -> tuple[Image.Image, list[tuple[str, int, int, int, int]]]:
|
||||
"""Synthesize a single scene with random objects.
|
||||
|
||||
Returns:
|
||||
@@ -251,10 +262,14 @@ class ImageSynthesizer:
|
||||
|
||||
# Get random rotation
|
||||
angle = random.uniform(*self.rotation_range)
|
||||
obj_image, obj_mask = self._rotate_image_and_mask(obj_image, obj_mask, angle)
|
||||
obj_image, obj_mask = self._rotate_image_and_mask(
|
||||
obj_image, obj_mask, angle
|
||||
)
|
||||
|
||||
# Try to place object
|
||||
result = self._place_object(background, obj_image, obj_mask, placed_boxes, scale)
|
||||
result = self._place_object(
|
||||
background, obj_image, obj_mask, placed_boxes, scale
|
||||
)
|
||||
|
||||
if result is not None:
|
||||
background, _, box = result
|
||||
@@ -286,7 +301,7 @@ class ImageSynthesizer:
|
||||
|
||||
# Save annotation
|
||||
anno_path = self.output_dir / f"synth_{i:04d}.txt"
|
||||
with open(anno_path, "w") as f:
|
||||
with open(anno_path, "w", encoding="utf-8") as f:
|
||||
for category, xmin, ymin, xmax, ymax in annotations:
|
||||
f.write(f"{category} {xmin} {ymin} {xmax} {ymax}\n")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user