refactor(configs, data_loading): improve code clarity and add docstrings

This commit is contained in:
2026-02-28 22:07:25 +08:00
parent 88d1d0790d
commit f0479cc69b
6 changed files with 201 additions and 59 deletions

View File

@@ -1,3 +1,5 @@
"""Configuration management module for unified config."""
from .config import (
ConfigManager,
cfg_manager,
@@ -8,7 +10,6 @@ from .models import (
DatasetConfig,
ModelConfig,
OutputConfig,
PoolingType,
)
__all__ = [
@@ -17,7 +18,6 @@ __all__ = [
"OutputConfig",
"DatasetConfig",
"Config",
"PoolingType",
# Loader
"load_yaml",
"save_yaml",

View File

@@ -28,7 +28,7 @@ class ConfigManager:
"""Load configuration from config.yaml file.
Returns:
Loaded and validated FeatureCompressorConfig instance
Loaded and validated Config instance
"""
config = load_yaml(self.config_path, Config)
self._config = config
@@ -38,7 +38,7 @@ class ConfigManager:
"""Get loaded configuration, auto-loading if not already loaded.
Returns:
FeatureCompressorConfig instance
Config instance
Note:
Automatically loads config if not already loaded

View File

@@ -26,8 +26,8 @@ class OutputConfig(BaseModel):
@field_validator("directory", mode="after")
def convert_to_absolute(cls, v: Path) -> Path:
"""
Converts the path to an absolute path relative to the current working directory.
"""Converts the path to an absolute path relative to the project root.
This works even if the path doesn't exist on disk.
"""
if v.is_absolute():
@@ -55,8 +55,8 @@ class DatasetConfig(BaseModel):
@field_validator("dataset_root", "output_dir", mode="after")
def convert_to_absolute(cls, v: Path) -> Path:
"""
Converts the path to an absolute path relative to the project root.
"""Converts the path to an absolute path relative to the project root.
This works even if the path doesn't exist on disk.
"""
if v.is_absolute():

View File

@@ -1,3 +1,5 @@
"""Data loading module for synthetic and validation datasets."""
from .loader import load_synth_dataset, load_val_dataset
from .synthesizer import ImageSynthesizer

View File

@@ -1,10 +1,93 @@
"""Data loaders for synthetic and validation datasets using Hugging Face datasets."""
import xml.etree.ElementTree as ET
from pathlib import Path
from typing import Any
from datasets import Dataset, Image
# Type alias for objects annotation
ObjectsDict = dict[str, list[list[float]] | list[str] | list[int] | list[float]]
def _parse_bbox_line(line: str) -> tuple[str, list[float], float] | None:
"""Parse a single line from synth annotation file.
Args:
line: Line in format "category xmin ymin xmax ymax"
Returns:
Tuple of (category, [xmin, ymin, width, height], area) or None if invalid
"""
parts = line.split()
if len(parts) != 5:
return None
category = parts[0]
xmin, ymin, xmax, ymax = map(int, parts[1:])
width = xmax - xmin
height = ymax - ymin
area = width * height
return category, [float(xmin), float(ymin), float(width), float(height)], float(area)
def _get_element_text(element: ET.Element | None, default: str = "0") -> str:
"""Get text from XML element, returning default if element or text is None."""
if element is None:
return default
return element.text if element.text is not None else default
def _parse_voc_xml(xml_path: Path) -> ObjectsDict:
    """Parse a VOC-format XML annotation file.

    Args:
        xml_path: Path to the XML annotation file.

    Returns:
        Dictionary containing:
            - bbox: list of bounding boxes in [xmin, ymin, width, height] format
            - category: list of object category names
            - area: list of bounding box areas
            - id: list of object IDs (0-based indices)
    """
    root = ET.parse(xml_path).getroot()

    bboxes: list[list[float]] = []
    categories: list[str] = []
    areas: list[float] = []

    for obj in root.findall("object"):
        name_elem = obj.find("name")
        bndbox = obj.find("bndbox")
        # Skip malformed entries that lack a name, its text, or a bounding box.
        if name_elem is None or bndbox is None or name_elem.text is None:
            continue
        xmin, ymin, xmax, ymax = (
            int(_get_element_text(bndbox.find(tag)))
            for tag in ("xmin", "ymin", "xmax", "ymax")
        )
        box_w = xmax - xmin
        box_h = ymax - ymin
        bboxes.append([float(xmin), float(ymin), float(box_w), float(box_h)])
        categories.append(name_elem.text)
        areas.append(float(box_w * box_h))

    return {
        "bbox": bboxes,
        "category": categories,
        "area": areas,
        "id": list(range(len(bboxes))),
    }
def load_synth_dataset(
synth_dir: Path,
@@ -14,10 +97,23 @@ def load_synth_dataset(
Args:
synth_dir: Directory containing synthesized images and annotations
annotations_suffix: Suffix for annotation files
annotations_suffix: Suffix for annotation files (default: ".txt")
Returns:
Hugging Face Dataset with image and objects columns
Hugging Face Dataset with the following columns:
- image: PIL Image
- objects: dict containing:
- bbox: List of bounding boxes in [xmin, ymin, width, height] format
- category: List of object category names
- area: List of bounding box areas
- id: List of object IDs (0-based indices)
Example:
>>> dataset = load_synth_dataset(Path("outputs/synth"))
>>> sample = dataset[0]
>>> # sample["image"] - PIL Image
>>> # sample["objects"]["bbox"] - [[xmin, ymin, width, height], ...]
>>> # sample["objects"]["category"] - ["category_name", ...]
"""
synth_dir = Path(synth_dir)
image_files = sorted(synth_dir.glob("synth_*.jpg"))
@@ -26,7 +122,7 @@ def load_synth_dataset(
return Dataset.from_dict({"image": [], "objects": []}).cast_column("image", Image())
image_paths: list[str] = []
all_objects: list[dict[str, Any]] = []
all_objects: list[ObjectsDict] = []
for img_path in image_files:
image_paths.append(str(img_path))
@@ -39,26 +135,24 @@ def load_synth_dataset(
bboxes: list[list[float]] = []
categories: list[str] = []
areas: list[float] = []
ids: list[int] = []
obj_id = 0
with open(anno_path, "r") as f:
for idx, line in enumerate(f):
with open(anno_path, "r", encoding="utf-8") as f:
for line in f:
if not (line := line.strip()):
continue
parts = line.split()
if len(parts) != 5:
result = _parse_bbox_line(line)
if result is None:
continue
xmin, ymin, xmax, ymax = map(int, parts[1:])
width, height = xmax - xmin, ymax - ymin
category, bbox, area = result
bboxes.append(bbox)
categories.append(category)
areas.append(area)
obj_id += 1
bboxes.append([float(xmin), float(ymin), float(width), float(height)])
categories.append(parts[0])
areas.append(float(width * height))
ids.append(idx)
all_objects.append({"bbox": bboxes, "category": categories, "area": areas, "id": ids})
all_objects.append({"bbox": bboxes, "category": categories, "area": areas, "id": list(range(len(bboxes)))})
dataset = Dataset.from_dict({"image": image_paths, "objects": all_objects})
return dataset.cast_column("image", Image())
@@ -68,14 +162,29 @@ def load_val_dataset(
scenes_dir: Path,
split: str = "easy",
) -> Dataset:
"""Load validation dataset from scene images.
"""Load validation dataset from scene images with VOC-format XML annotations.
Args:
scenes_dir: Directory containing scene subdirectories
split: Scene split to load ('easy' or 'hard')
Returns:
Hugging Face Dataset with image and image_id columns
Hugging Face Dataset with the following columns:
- image: PIL Image
- image_id: Image identifier (filename stem without extension)
- objects: dict containing (loaded from XML annotations):
- bbox: List of bounding boxes in [xmin, ymin, width, height] format
- category: List of object category names
- area: List of bounding box areas
- id: List of object IDs (0-based indices)
Example:
>>> dataset = load_val_dataset(Path("datasets/InsDet-FULL/Scenes"), "easy")
>>> sample = dataset[0]
>>> # sample["image"] - PIL Image
>>> # sample["image_id"] - "rgb_000"
>>> # sample["objects"]["bbox"] - [[xmin, ymin, width, height], ...]
>>> # sample["objects"]["category"] - ["category_name", ...]
"""
scenes_dir = Path(scenes_dir)
split_dir = scenes_dir / split
@@ -86,11 +195,27 @@ def load_val_dataset(
rgb_files = sorted(split_dir.glob("*/rgb_*.jpg"))
if not rgb_files:
return Dataset.from_dict({"image": [], "image_id": []}).cast_column("image", Image())
return Dataset.from_dict({"image": [], "image_id": [], "objects": []}).cast_column("image", Image())
image_paths: list[str] = []
image_ids: list[str] = []
all_objects: list[ObjectsDict] = []
for img_path in rgb_files:
image_paths.append(str(img_path))
image_ids.append(img_path.stem)
xml_path = img_path.with_suffix(".xml")
if xml_path.exists():
objects: ObjectsDict = _parse_voc_xml(xml_path)
else:
objects = {"bbox": [], "category": [], "area": [], "id": []}
all_objects.append(objects)
dataset = Dataset.from_dict({
"image": [str(p) for p in rgb_files],
"image_id": [p.stem for p in rgb_files],
"image": image_paths,
"image_id": image_ids,
"objects": all_objects,
})
return dataset.cast_column("image", Image())

View File

@@ -57,7 +57,11 @@ class ImageSynthesizer:
"""List of background image paths."""
if self._background_categories is None:
self._background_categories = sorted(
[p.name for p in self.background_dir.iterdir() if p.suffix in [".jpg", ".jpeg", ".png"]]
[
p.name
for p in self.background_dir.iterdir()
if p.suffix in [".jpg", ".jpeg", ".png"]
]
)
# Return as list of Path for type compatibility
return [self.background_dir / name for name in self._background_categories] # type: ignore[return-value]
@@ -126,7 +130,9 @@ class ImageSynthesizer:
mask = mask.rotate(angle, resample=Resampling.BILINEAR, expand=True)
return image, mask
def _compute_overlap(self, box1: tuple[int, int, int, int], box2: tuple[int, int, int, int]) -> float:
def _compute_overlap(
self, box1: tuple[int, int, int, int], box2: tuple[int, int, int, int]
) -> float:
"""Compute overlap ratio between two boxes.
Args:
@@ -153,7 +159,26 @@ class ImageSynthesizer:
box2_area = (x2_max - x2_min) * (y2_max - y2_min)
min_area = min(box1_area, box2_area)
return inter_area / min_area if min_area > 0 else 0.0
return inter_area / (min_area if min_area > 0 else 0.0)
def _has_overlap(
    self,
    new_box: tuple[int, int, int, int],
    existing_boxes: list[tuple[int, int, int, int]],
) -> bool:
    """Check whether *new_box* overlaps any already-placed box.

    Args:
        new_box: The new bounding box (xmin, ymin, xmax, ymax)
        existing_boxes: List of existing bounding boxes

    Returns:
        True if the overlap ratio with any existing box exceeds the
        configured threshold, False otherwise
    """
    return any(
        self._compute_overlap(new_box, placed) > self.overlap_threshold
        for placed in existing_boxes
    )
def _place_object(
self,
@@ -182,7 +207,7 @@ class ImageSynthesizer:
new_w = int(obj_w * scale)
new_h = int(obj_h * scale)
if new_w <= 0 or new_h <= 0:
if new_w <= 0 or new_h <= 0 or new_w > bg_w or new_h > bg_h:
return None
obj_image = obj_image.resize((new_w, new_h), Resampling.LANCZOS)
@@ -198,32 +223,18 @@ class ImageSynthesizer:
new_box = (x, y, x + new_w, y + new_h)
# Check overlap with existing boxes
valid = True
for existing_box in existing_boxes:
overlap = self._compute_overlap(new_box, existing_box)
if overlap > self.overlap_threshold:
valid = False
break
if valid:
# Composite object onto background
if not self._has_overlap(new_box, existing_boxes):
# Composite object onto background using Pillow's paste method
background = background.copy()
mask_array = np.array(obj_mask) / 255.0
bg_array = np.array(background)
obj_array = np.array(obj_image)
background.paste(obj_image, (x, y), mask=obj_mask)
# Apply mask
mask_3d = np.stack([mask_array] * 3, axis=-1)
bg_array[y:y+new_h, x:x+new_w] = (
bg_array[y:y+new_h, x:x+new_w] * (1 - mask_3d) +
obj_array * mask_3d
)
return Image.fromarray(bg_array), obj_mask, new_box
return background, obj_mask, new_box
return None
def synthesize_scene(self) -> tuple[Image.Image, list[tuple[str, int, int, int, int]]]:
def synthesize_scene(
self,
) -> tuple[Image.Image, list[tuple[str, int, int, int, int]]]:
"""Synthesize a single scene with random objects.
Returns:
@@ -251,10 +262,14 @@ class ImageSynthesizer:
# Get random rotation
angle = random.uniform(*self.rotation_range)
obj_image, obj_mask = self._rotate_image_and_mask(obj_image, obj_mask, angle)
obj_image, obj_mask = self._rotate_image_and_mask(
obj_image, obj_mask, angle
)
# Try to place object
result = self._place_object(background, obj_image, obj_mask, placed_boxes, scale)
result = self._place_object(
background, obj_image, obj_mask, placed_boxes, scale
)
if result is not None:
background, _, box = result
@@ -286,7 +301,7 @@ class ImageSynthesizer:
# Save annotation
anno_path = self.output_dir / f"synth_{i:04d}.txt"
with open(anno_path, "w") as f:
with open(anno_path, "w", encoding="utf-8") as f:
for category, xmin, ymin, xmax, ymax in annotations:
f.write(f"{category} {xmin} {ymin} {xmax} {ymax}\n")