From f0479cc69b5c1dc3b77529c37eb005fa08823e5f Mon Sep 17 00:00:00 2001 From: SikongJueluo Date: Sat, 28 Feb 2026 22:07:25 +0800 Subject: [PATCH] refactor(configs, data_loading): improve code clarity and add docstrings --- mini-nav/configs/__init__.py | 4 +- mini-nav/configs/config.py | 4 +- mini-nav/configs/models.py | 8 +- mini-nav/data_loading/__init__.py | 2 + mini-nav/data_loading/loader.py | 171 +++++++++++++++++++++++---- mini-nav/data_loading/synthesizer.py | 71 ++++++----- 6 files changed, 201 insertions(+), 59 deletions(-) diff --git a/mini-nav/configs/__init__.py b/mini-nav/configs/__init__.py index 68771ba..cf4f113 100644 --- a/mini-nav/configs/__init__.py +++ b/mini-nav/configs/__init__.py @@ -1,3 +1,5 @@ +"""Configuration management module for unified config.""" + from .config import ( ConfigManager, cfg_manager, @@ -8,7 +10,6 @@ from .models import ( DatasetConfig, ModelConfig, OutputConfig, - PoolingType, ) __all__ = [ @@ -17,7 +18,6 @@ __all__ = [ "OutputConfig", "DatasetConfig", "Config", - "PoolingType", # Loader "load_yaml", "save_yaml", diff --git a/mini-nav/configs/config.py b/mini-nav/configs/config.py index 6d23d28..d984c27 100644 --- a/mini-nav/configs/config.py +++ b/mini-nav/configs/config.py @@ -28,7 +28,7 @@ class ConfigManager: """Load configuration from config.yaml file. Returns: - Loaded and validated FeatureCompressorConfig instance + Loaded and validated Config instance """ config = load_yaml(self.config_path, Config) self._config = config @@ -38,7 +38,7 @@ class ConfigManager: """Get loaded configuration, auto-loading if not already loaded. 
Returns: - FeatureCompressorConfig instance + Config instance Note: Automatically loads config if not already loaded diff --git a/mini-nav/configs/models.py b/mini-nav/configs/models.py index dd6c9cd..a87d837 100644 --- a/mini-nav/configs/models.py +++ b/mini-nav/configs/models.py @@ -26,8 +26,8 @@ class OutputConfig(BaseModel): @field_validator("directory", mode="after") def convert_to_absolute(cls, v: Path) -> Path: - """ - Converts the path to an absolute path relative to the current working directory. + """Converts the path to an absolute path relative to the project root. + This works even if the path doesn't exist on disk. """ if v.is_absolute(): @@ -55,8 +55,8 @@ class DatasetConfig(BaseModel): @field_validator("dataset_root", "output_dir", mode="after") def convert_to_absolute(cls, v: Path) -> Path: - """ - Converts the path to an absolute path relative to the project root. + """Converts the path to an absolute path relative to the project root. + This works even if the path doesn't exist on disk. 
""" if v.is_absolute(): diff --git a/mini-nav/data_loading/__init__.py b/mini-nav/data_loading/__init__.py index 52f0001..429558a 100644 --- a/mini-nav/data_loading/__init__.py +++ b/mini-nav/data_loading/__init__.py @@ -1,3 +1,5 @@ +"""Data loading module for synthetic and validation datasets.""" + from .loader import load_synth_dataset, load_val_dataset from .synthesizer import ImageSynthesizer diff --git a/mini-nav/data_loading/loader.py b/mini-nav/data_loading/loader.py index 0893584..7402d5b 100644 --- a/mini-nav/data_loading/loader.py +++ b/mini-nav/data_loading/loader.py @@ -1,10 +1,93 @@ """Data loaders for synthetic and validation datasets using Hugging Face datasets.""" +import xml.etree.ElementTree as ET from pathlib import Path -from typing import Any from datasets import Dataset, Image +# Type alias for objects annotation +ObjectsDict = dict[str, list[list[float]] | list[str] | list[int] | list[float]] + + +def _parse_bbox_line(line: str) -> tuple[str, list[float], float] | None: + """Parse a single line from synth annotation file. + + Args: + line: Line in format "category xmin ymin xmax ymax" + + Returns: + Tuple of (category, [xmin, ymin, width, height], area) or None if invalid + """ + parts = line.split() + if len(parts) != 5: + return None + + category = parts[0] + xmin, ymin, xmax, ymax = map(int, parts[1:]) + width = xmax - xmin + height = ymax - ymin + area = width * height + + return category, [float(xmin), float(ymin), float(width), float(height)], float(area) + + +def _get_element_text(element: ET.Element | None, default: str = "0") -> str: + """Get text from XML element, returning default if element or text is None.""" + if element is None: + return default + return element.text if element.text is not None else default + + +def _parse_voc_xml(xml_path: Path) -> ObjectsDict: + """Parse a VOC-format XML annotation file. 
+ + Args: + xml_path: Path to the XML annotation file + + Returns: + Dictionary containing: + - bbox: List of bounding boxes in [xmin, ymin, width, height] format + - category: List of object category names + - area: List of bounding box areas + - id: List of object IDs (0-based indices) + """ + tree = ET.parse(xml_path) + root = tree.getroot() + + bboxes: list[list[float]] = [] + categories: list[str] = [] + areas: list[float] = [] + + for obj in root.findall("object"): + name_elem = obj.find("name") + bndbox = obj.find("bndbox") + + if name_elem is None or bndbox is None: + continue + + name = name_elem.text + if name is None: + continue + + xmin = int(_get_element_text(bndbox.find("xmin"))) + ymin = int(_get_element_text(bndbox.find("ymin"))) + xmax = int(_get_element_text(bndbox.find("xmax"))) + ymax = int(_get_element_text(bndbox.find("ymax"))) + + width = xmax - xmin + height = ymax - ymin + + bboxes.append([float(xmin), float(ymin), float(width), float(height)]) + categories.append(name) + areas.append(float(width * height)) + + return { + "bbox": bboxes, + "category": categories, + "area": areas, + "id": list(range(len(bboxes))), + } + def load_synth_dataset( synth_dir: Path, @@ -14,10 +97,23 @@ def load_synth_dataset( Args: synth_dir: Directory containing synthesized images and annotations - annotations_suffix: Suffix for annotation files + annotations_suffix: Suffix for annotation files (default: ".txt") Returns: - Hugging Face Dataset with image and objects columns + Hugging Face Dataset with the following columns: + - image: PIL Image + - objects: dict containing: + - bbox: List of bounding boxes in [xmin, ymin, width, height] format + - category: List of object category names + - area: List of bounding box areas + - id: List of object IDs (0-based indices) + + Example: + >>> dataset = load_synth_dataset(Path("outputs/synth")) + >>> sample = dataset[0] + >>> # sample["image"] - PIL Image + >>> # sample["objects"]["bbox"] - [[xmin, ymin, width, height], 
...] + >>> # sample["objects"]["category"] - ["category_name", ...] """ synth_dir = Path(synth_dir) image_files = sorted(synth_dir.glob("synth_*.jpg")) @@ -26,7 +122,7 @@ def load_synth_dataset( return Dataset.from_dict({"image": [], "objects": []}).cast_column("image", Image()) image_paths: list[str] = [] - all_objects: list[dict[str, Any]] = [] + all_objects: list[ObjectsDict] = [] for img_path in image_files: image_paths.append(str(img_path)) @@ -39,26 +135,24 @@ def load_synth_dataset( bboxes: list[list[float]] = [] categories: list[str] = [] areas: list[float] = [] - ids: list[int] = [] + obj_id = 0 - with open(anno_path, "r") as f: - for idx, line in enumerate(f): + with open(anno_path, "r", encoding="utf-8") as f: + for line in f: if not (line := line.strip()): continue - parts = line.split() - if len(parts) != 5: + result = _parse_bbox_line(line) + if result is None: continue - xmin, ymin, xmax, ymax = map(int, parts[1:]) - width, height = xmax - xmin, ymax - ymin + category, bbox, area = result + bboxes.append(bbox) + categories.append(category) + areas.append(area) + obj_id += 1 - bboxes.append([float(xmin), float(ymin), float(width), float(height)]) - categories.append(parts[0]) - areas.append(float(width * height)) - ids.append(idx) - - all_objects.append({"bbox": bboxes, "category": categories, "area": areas, "id": ids}) + all_objects.append({"bbox": bboxes, "category": categories, "area": areas, "id": list(range(len(bboxes)))}) dataset = Dataset.from_dict({"image": image_paths, "objects": all_objects}) return dataset.cast_column("image", Image()) @@ -68,14 +162,29 @@ def load_val_dataset( scenes_dir: Path, split: str = "easy", ) -> Dataset: - """Load validation dataset from scene images. + """Load validation dataset from scene images with VOC-format XML annotations. 
Args: scenes_dir: Directory containing scene subdirectories split: Scene split to load ('easy' or 'hard') Returns: - Hugging Face Dataset with image and image_id columns + Hugging Face Dataset with the following columns: + - image: PIL Image + - image_id: Image identifier (filename stem without extension) + - objects: dict containing (loaded from XML annotations): + - bbox: List of bounding boxes in [xmin, ymin, width, height] format + - category: List of object category names + - area: List of bounding box areas + - id: List of object IDs (0-based indices) + + Example: + >>> dataset = load_val_dataset(Path("datasets/InsDet-FULL/Scenes"), "easy") + >>> sample = dataset[0] + >>> # sample["image"] - PIL Image + >>> # sample["image_id"] - "rgb_000" + >>> # sample["objects"]["bbox"] - [[xmin, ymin, width, height], ...] + >>> # sample["objects"]["category"] - ["category_name", ...] """ scenes_dir = Path(scenes_dir) split_dir = scenes_dir / split @@ -86,11 +195,28 @@ rgb_files = sorted(split_dir.glob("*/rgb_*.jpg")) if not rgb_files: - return Dataset.from_dict({"image": [], "image_id": []}).cast_column("image", Image()) + return Dataset.from_dict({"image": [], "image_id": [], "objects": []}).cast_column("image", Image()) + + image_paths: list[str] = [] + image_ids: list[str] = [] + all_objects: list[ObjectsDict] = [] + + for img_path in rgb_files: + image_paths.append(str(img_path)) + image_ids.append(img_path.stem) + + xml_path = img_path.with_suffix(".xml") + if xml_path.exists(): + objects: ObjectsDict = _parse_voc_xml(xml_path) + else: + objects = {"bbox": [], "category": [], "area": [], "id": []} + + all_objects.append(objects) dataset = Dataset.from_dict({ - "image": [str(p) for p in rgb_files], - "image_id": [p.stem for p in rgb_files], + "image": image_paths, + "image_id": image_ids, + "objects": all_objects, }) return dataset.cast_column("image", Image()) diff --git a/mini-nav/data_loading/synthesizer.py 
b/mini-nav/data_loading/synthesizer.py index 553b466..eecebf0 100644 --- a/mini-nav/data_loading/synthesizer.py +++ b/mini-nav/data_loading/synthesizer.py @@ -57,7 +57,11 @@ class ImageSynthesizer: """List of background image paths.""" if self._background_categories is None: self._background_categories = sorted( - [p.name for p in self.background_dir.iterdir() if p.suffix in [".jpg", ".jpeg", ".png"]] + [ + p.name + for p in self.background_dir.iterdir() + if p.suffix in [".jpg", ".jpeg", ".png"] + ] ) # Return as list of Path for type compatibility return [self.background_dir / name for name in self._background_categories] # type: ignore[return-value] @@ -126,7 +130,9 @@ class ImageSynthesizer: mask = mask.rotate(angle, resample=Resampling.BILINEAR, expand=True) return image, mask - def _compute_overlap(self, box1: tuple[int, int, int, int], box2: tuple[int, int, int, int]) -> float: + def _compute_overlap( + self, box1: tuple[int, int, int, int], box2: tuple[int, int, int, int] + ) -> float: """Compute overlap ratio between two boxes. Args: @@ -153,7 +159,26 @@ class ImageSynthesizer: box2_area = (x2_max - x2_min) * (y2_max - y2_min) min_area = min(box1_area, box2_area) return inter_area / min_area if min_area > 0 else 0.0 + + def _has_overlap( + self, + new_box: tuple[int, int, int, int], + existing_boxes: list[tuple[int, int, int, int]], + ) -> bool: + """Check if new_box overlaps with any existing boxes. 
+ + Args: + new_box: The new bounding box (xmin, ymin, xmax, ymax) + existing_boxes: List of existing bounding boxes + + Returns: + True if any overlap exceeds threshold, False otherwise + """ + for existing_box in existing_boxes: + if self._compute_overlap(new_box, existing_box) > self.overlap_threshold: + return True + return False def _place_object( self, @@ -182,7 +207,7 @@ class ImageSynthesizer: new_w = int(obj_w * scale) new_h = int(obj_h * scale) - if new_w <= 0 or new_h <= 0: + if new_w <= 0 or new_h <= 0 or new_w > bg_w or new_h > bg_h: return None obj_image = obj_image.resize((new_w, new_h), Resampling.LANCZOS) @@ -198,32 +223,18 @@ class ImageSynthesizer: new_box = (x, y, x + new_w, y + new_h) # Check overlap with existing boxes - valid = True - for existing_box in existing_boxes: - overlap = self._compute_overlap(new_box, existing_box) - if overlap > self.overlap_threshold: - valid = False - break - - if valid: - # Composite object onto background + if not self._has_overlap(new_box, existing_boxes): + # Composite object onto background using Pillow's paste method background = background.copy() - mask_array = np.array(obj_mask) / 255.0 - bg_array = np.array(background) - obj_array = np.array(obj_image) + background.paste(obj_image, (x, y), mask=obj_mask) - # Apply mask - mask_3d = np.stack([mask_array] * 3, axis=-1) - bg_array[y:y+new_h, x:x+new_w] = ( - bg_array[y:y+new_h, x:x+new_w] * (1 - mask_3d) + - obj_array * mask_3d - ) - - return Image.fromarray(bg_array), obj_mask, new_box + return background, obj_mask, new_box return None - def synthesize_scene(self) -> tuple[Image.Image, list[tuple[str, int, int, int, int]]]: + def synthesize_scene( + self, + ) -> tuple[Image.Image, list[tuple[str, int, int, int, int]]]: """Synthesize a single scene with random objects. 
Returns: @@ -251,10 +262,14 @@ class ImageSynthesizer: # Get random rotation angle = random.uniform(*self.rotation_range) - obj_image, obj_mask = self._rotate_image_and_mask(obj_image, obj_mask, angle) + obj_image, obj_mask = self._rotate_image_and_mask( + obj_image, obj_mask, angle + ) # Try to place object - result = self._place_object(background, obj_image, obj_mask, placed_boxes, scale) + result = self._place_object( + background, obj_image, obj_mask, placed_boxes, scale + ) if result is not None: background, _, box = result @@ -286,7 +301,7 @@ class ImageSynthesizer: # Save annotation anno_path = self.output_dir / f"synth_{i:04d}.txt" - with open(anno_path, "w") as f: + with open(anno_path, "w", encoding="utf-8") as f: for category, xmin, ymin, xmax, ymax in annotations: f.write(f"{category} {xmin} {ymin} {xmax} {ymax}\n")