refactor(configs, data_loading): improve code clarity and add docstrings

This commit is contained in:
2026-02-28 22:07:25 +08:00
parent 88d1d0790d
commit f0479cc69b
6 changed files with 201 additions and 59 deletions

View File

@@ -1,3 +1,5 @@
"""Configuration management module for unified config."""
from .config import ( from .config import (
ConfigManager, ConfigManager,
cfg_manager, cfg_manager,
@@ -8,7 +10,6 @@ from .models import (
DatasetConfig, DatasetConfig,
ModelConfig, ModelConfig,
OutputConfig, OutputConfig,
PoolingType,
) )
__all__ = [ __all__ = [
@@ -17,7 +18,6 @@ __all__ = [
"OutputConfig", "OutputConfig",
"DatasetConfig", "DatasetConfig",
"Config", "Config",
"PoolingType",
# Loader # Loader
"load_yaml", "load_yaml",
"save_yaml", "save_yaml",

View File

@@ -28,7 +28,7 @@ class ConfigManager:
"""Load configuration from config.yaml file. """Load configuration from config.yaml file.
Returns: Returns:
Loaded and validated FeatureCompressorConfig instance Loaded and validated Config instance
""" """
config = load_yaml(self.config_path, Config) config = load_yaml(self.config_path, Config)
self._config = config self._config = config
@@ -38,7 +38,7 @@ class ConfigManager:
"""Get loaded configuration, auto-loading if not already loaded. """Get loaded configuration, auto-loading if not already loaded.
Returns: Returns:
FeatureCompressorConfig instance Config instance
Note: Note:
Automatically loads config if not already loaded Automatically loads config if not already loaded

View File

@@ -26,8 +26,8 @@ class OutputConfig(BaseModel):
@field_validator("directory", mode="after") @field_validator("directory", mode="after")
def convert_to_absolute(cls, v: Path) -> Path: def convert_to_absolute(cls, v: Path) -> Path:
""" """Converts the path to an absolute path relative to the project root.
Converts the path to an absolute path relative to the current working directory.
This works even if the path doesn't exist on disk. This works even if the path doesn't exist on disk.
""" """
if v.is_absolute(): if v.is_absolute():
@@ -55,8 +55,8 @@ class DatasetConfig(BaseModel):
@field_validator("dataset_root", "output_dir", mode="after") @field_validator("dataset_root", "output_dir", mode="after")
def convert_to_absolute(cls, v: Path) -> Path: def convert_to_absolute(cls, v: Path) -> Path:
""" """Converts the path to an absolute path relative to the project root.
Converts the path to an absolute path relative to the project root.
This works even if the path doesn't exist on disk. This works even if the path doesn't exist on disk.
""" """
if v.is_absolute(): if v.is_absolute():

View File

@@ -1,3 +1,5 @@
"""Data loading module for synthetic and validation datasets."""
from .loader import load_synth_dataset, load_val_dataset from .loader import load_synth_dataset, load_val_dataset
from .synthesizer import ImageSynthesizer from .synthesizer import ImageSynthesizer

View File

@@ -1,10 +1,93 @@
"""Data loaders for synthetic and validation datasets using Hugging Face datasets.""" """Data loaders for synthetic and validation datasets using Hugging Face datasets."""
import xml.etree.ElementTree as ET
from pathlib import Path from pathlib import Path
from typing import Any
from datasets import Dataset, Image from datasets import Dataset, Image
# Type alias for objects annotation
ObjectsDict = dict[str, list[list[float]] | list[str] | list[int] | list[float]]
def _parse_bbox_line(line: str) -> tuple[str, list[float], float] | None:
"""Parse a single line from synth annotation file.
Args:
line: Line in format "category xmin ymin xmax ymax"
Returns:
Tuple of (category, [xmin, ymin, width, height], area) or None if invalid
"""
parts = line.split()
if len(parts) != 5:
return None
category = parts[0]
xmin, ymin, xmax, ymax = map(int, parts[1:])
width = xmax - xmin
height = ymax - ymin
area = width * height
return category, [float(xmin), float(ymin), float(width), float(height)], float(area)
def _get_element_text(element: ET.Element | None, default: str = "0") -> str:
"""Get text from XML element, returning default if element or text is None."""
if element is None:
return default
return element.text if element.text is not None else default
def _parse_voc_xml(xml_path: Path) -> ObjectsDict:
"""Parse a VOC-format XML annotation file.
Args:
xml_path: Path to the XML annotation file
Returns:
Dictionary containing:
- bbox: List of bounding boxes in [xmin, ymin, width, height] format
- category: List of object category names
- area: List of bounding box areas
- id: List of object IDs (0-based indices)
"""
tree = ET.parse(xml_path)
root = tree.getroot()
bboxes: list[list[float]] = []
categories: list[str] = []
areas: list[float] = []
for obj in root.findall("object"):
name_elem = obj.find("name")
bndbox = obj.find("bndbox")
if name_elem is None or bndbox is None:
continue
name = name_elem.text
if name is None:
continue
xmin = int(_get_element_text(bndbox.find("xmin")))
ymin = int(_get_element_text(bndbox.find("ymin")))
xmax = int(_get_element_text(bndbox.find("xmax")))
ymax = int(_get_element_text(bndbox.find("ymax")))
width = xmax - xmin
height = ymax - ymin
bboxes.append([float(xmin), float(ymin), float(width), float(height)])
categories.append(name)
areas.append(float(width * height))
return {
"bbox": bboxes,
"category": categories,
"area": areas,
"id": list(range(len(bboxes))),
}
def load_synth_dataset( def load_synth_dataset(
synth_dir: Path, synth_dir: Path,
@@ -14,10 +97,23 @@ def load_synth_dataset(
Args: Args:
synth_dir: Directory containing synthesized images and annotations synth_dir: Directory containing synthesized images and annotations
annotations_suffix: Suffix for annotation files annotations_suffix: Suffix for annotation files (default: ".txt")
Returns: Returns:
Hugging Face Dataset with image and objects columns Hugging Face Dataset with the following columns:
- image: PIL Image
- objects: dict containing:
- bbox: List of bounding boxes in [xmin, ymin, width, height] format
- category: List of object category names
- area: List of bounding box areas
- id: List of object IDs (0-based indices)
Example:
>>> dataset = load_synth_dataset(Path("outputs/synth"))
>>> sample = dataset[0]
>>> # sample["image"] - PIL Image
>>> # sample["objects"]["bbox"] - [[xmin, ymin, width, height], ...]
>>> # sample["objects"]["category"] - ["category_name", ...]
""" """
synth_dir = Path(synth_dir) synth_dir = Path(synth_dir)
image_files = sorted(synth_dir.glob("synth_*.jpg")) image_files = sorted(synth_dir.glob("synth_*.jpg"))
@@ -26,7 +122,7 @@ def load_synth_dataset(
return Dataset.from_dict({"image": [], "objects": []}).cast_column("image", Image()) return Dataset.from_dict({"image": [], "objects": []}).cast_column("image", Image())
image_paths: list[str] = [] image_paths: list[str] = []
all_objects: list[dict[str, Any]] = [] all_objects: list[ObjectsDict] = []
for img_path in image_files: for img_path in image_files:
image_paths.append(str(img_path)) image_paths.append(str(img_path))
@@ -39,26 +135,24 @@ def load_synth_dataset(
bboxes: list[list[float]] = [] bboxes: list[list[float]] = []
categories: list[str] = [] categories: list[str] = []
areas: list[float] = [] areas: list[float] = []
ids: list[int] = [] obj_id = 0
with open(anno_path, "r") as f: with open(anno_path, "r", encoding="utf-8") as f:
for idx, line in enumerate(f): for line in f:
if not (line := line.strip()): if not (line := line.strip()):
continue continue
parts = line.split() result = _parse_bbox_line(line)
if len(parts) != 5: if result is None:
continue continue
xmin, ymin, xmax, ymax = map(int, parts[1:]) category, bbox, area = result
width, height = xmax - xmin, ymax - ymin bboxes.append(bbox)
categories.append(category)
areas.append(area)
obj_id += 1
bboxes.append([float(xmin), float(ymin), float(width), float(height)]) all_objects.append({"bbox": bboxes, "category": categories, "area": areas, "id": list(range(len(bboxes)))})
categories.append(parts[0])
areas.append(float(width * height))
ids.append(idx)
all_objects.append({"bbox": bboxes, "category": categories, "area": areas, "id": ids})
dataset = Dataset.from_dict({"image": image_paths, "objects": all_objects}) dataset = Dataset.from_dict({"image": image_paths, "objects": all_objects})
return dataset.cast_column("image", Image()) return dataset.cast_column("image", Image())
@@ -68,14 +162,29 @@ def load_val_dataset(
scenes_dir: Path, scenes_dir: Path,
split: str = "easy", split: str = "easy",
) -> Dataset: ) -> Dataset:
"""Load validation dataset from scene images. """Load validation dataset from scene images with VOC-format XML annotations.
Args: Args:
scenes_dir: Directory containing scene subdirectories scenes_dir: Directory containing scene subdirectories
split: Scene split to load ('easy' or 'hard') split: Scene split to load ('easy' or 'hard')
Returns: Returns:
Hugging Face Dataset with image and image_id columns Hugging Face Dataset with the following columns:
- image: PIL Image
- image_id: Image identifier (filename stem without extension)
- objects: dict containing (loaded from XML annotations):
- bbox: List of bounding boxes in [xmin, ymin, width, height] format
- category: List of object category names
- area: List of bounding box areas
- id: List of object IDs (0-based indices)
Example:
>>> dataset = load_val_dataset(Path("datasets/InsDet-FULL/Scenes"), "easy")
>>> sample = dataset[0]
>>> # sample["image"] - PIL Image
>>> # sample["image_id"] - "rgb_000"
>>> # sample["objects"]["bbox"] - [[xmin, ymin, width, height], ...]
>>> # sample["objects"]["category"] - ["category_name", ...]
""" """
scenes_dir = Path(scenes_dir) scenes_dir = Path(scenes_dir)
split_dir = scenes_dir / split split_dir = scenes_dir / split
@@ -86,11 +195,27 @@ def load_val_dataset(
rgb_files = sorted(split_dir.glob("*/rgb_*.jpg")) rgb_files = sorted(split_dir.glob("*/rgb_*.jpg"))
if not rgb_files: if not rgb_files:
return Dataset.from_dict({"image": [], "image_id": []}).cast_column("image", Image()) return Dataset.from_dict({"image": [], "image_id": [], "objects": []}).cast_column("image", Image())
image_paths: list[str] = []
image_ids: list[str] = []
all_objects: list[ObjectsDict] = []
for img_path in rgb_files:
image_paths.append(str(img_path))
image_ids.append(img_path.stem)
xml_path = img_path.with_suffix(".xml")
if xml_path.exists():
objects: ObjectsDict = _parse_voc_xml(xml_path)
else:
objects = {"bbox": [], "category": [], "area": [], "id": []}
all_objects.append(objects)
dataset = Dataset.from_dict({ dataset = Dataset.from_dict({
"image": [str(p) for p in rgb_files], "image": image_paths,
"image_id": [p.stem for p in rgb_files], "image_id": image_ids,
"objects": all_objects,
}) })
return dataset.cast_column("image", Image()) return dataset.cast_column("image", Image())

View File

@@ -57,7 +57,11 @@ class ImageSynthesizer:
"""List of background image paths.""" """List of background image paths."""
if self._background_categories is None: if self._background_categories is None:
self._background_categories = sorted( self._background_categories = sorted(
[p.name for p in self.background_dir.iterdir() if p.suffix in [".jpg", ".jpeg", ".png"]] [
p.name
for p in self.background_dir.iterdir()
if p.suffix in [".jpg", ".jpeg", ".png"]
]
) )
# Return as list of Path for type compatibility # Return as list of Path for type compatibility
return [self.background_dir / name for name in self._background_categories] # type: ignore[return-value] return [self.background_dir / name for name in self._background_categories] # type: ignore[return-value]
@@ -126,7 +130,9 @@ class ImageSynthesizer:
mask = mask.rotate(angle, resample=Resampling.BILINEAR, expand=True) mask = mask.rotate(angle, resample=Resampling.BILINEAR, expand=True)
return image, mask return image, mask
def _compute_overlap(self, box1: tuple[int, int, int, int], box2: tuple[int, int, int, int]) -> float: def _compute_overlap(
self, box1: tuple[int, int, int, int], box2: tuple[int, int, int, int]
) -> float:
"""Compute overlap ratio between two boxes. """Compute overlap ratio between two boxes.
Args: Args:
@@ -153,7 +159,26 @@ class ImageSynthesizer:
box2_area = (x2_max - x2_min) * (y2_max - y2_min) box2_area = (x2_max - x2_min) * (y2_max - y2_min)
min_area = min(box1_area, box2_area) min_area = min(box1_area, box2_area)
return inter_area / min_area if min_area > 0 else 0.0 return inter_area / min_area if min_area > 0 else 0.0
def _has_overlap(
    self,
    new_box: tuple[int, int, int, int],
    existing_boxes: list[tuple[int, int, int, int]],
) -> bool:
    """Check whether *new_box* collides with any already-placed box.

    Args:
        new_box: Candidate bounding box as (xmin, ymin, xmax, ymax)
        existing_boxes: Bounding boxes already placed in the scene

    Returns:
        True if the overlap ratio with any existing box exceeds
        ``self.overlap_threshold``, False otherwise.
    """
    return any(
        self._compute_overlap(new_box, placed) > self.overlap_threshold
        for placed in existing_boxes
    )
def _place_object( def _place_object(
self, self,
@@ -182,7 +207,7 @@ class ImageSynthesizer:
new_w = int(obj_w * scale) new_w = int(obj_w * scale)
new_h = int(obj_h * scale) new_h = int(obj_h * scale)
if new_w <= 0 or new_h <= 0: if new_w <= 0 or new_h <= 0 or new_w > bg_w or new_h > bg_h:
return None return None
obj_image = obj_image.resize((new_w, new_h), Resampling.LANCZOS) obj_image = obj_image.resize((new_w, new_h), Resampling.LANCZOS)
@@ -198,32 +223,18 @@ class ImageSynthesizer:
new_box = (x, y, x + new_w, y + new_h) new_box = (x, y, x + new_w, y + new_h)
# Check overlap with existing boxes # Check overlap with existing boxes
valid = True if not self._has_overlap(new_box, existing_boxes):
for existing_box in existing_boxes: # Composite object onto background using Pillow's paste method
overlap = self._compute_overlap(new_box, existing_box)
if overlap > self.overlap_threshold:
valid = False
break
if valid:
# Composite object onto background
background = background.copy() background = background.copy()
mask_array = np.array(obj_mask) / 255.0 background.paste(obj_image, (x, y), mask=obj_mask)
bg_array = np.array(background)
obj_array = np.array(obj_image)
# Apply mask return background, obj_mask, new_box
mask_3d = np.stack([mask_array] * 3, axis=-1)
bg_array[y:y+new_h, x:x+new_w] = (
bg_array[y:y+new_h, x:x+new_w] * (1 - mask_3d) +
obj_array * mask_3d
)
return Image.fromarray(bg_array), obj_mask, new_box
return None return None
def synthesize_scene(self) -> tuple[Image.Image, list[tuple[str, int, int, int, int]]]: def synthesize_scene(
self,
) -> tuple[Image.Image, list[tuple[str, int, int, int, int]]]:
"""Synthesize a single scene with random objects. """Synthesize a single scene with random objects.
Returns: Returns:
@@ -251,10 +262,14 @@ class ImageSynthesizer:
# Get random rotation # Get random rotation
angle = random.uniform(*self.rotation_range) angle = random.uniform(*self.rotation_range)
obj_image, obj_mask = self._rotate_image_and_mask(obj_image, obj_mask, angle) obj_image, obj_mask = self._rotate_image_and_mask(
obj_image, obj_mask, angle
)
# Try to place object # Try to place object
result = self._place_object(background, obj_image, obj_mask, placed_boxes, scale) result = self._place_object(
background, obj_image, obj_mask, placed_boxes, scale
)
if result is not None: if result is not None:
background, _, box = result background, _, box = result
@@ -286,7 +301,7 @@ class ImageSynthesizer:
# Save annotation # Save annotation
anno_path = self.output_dir / f"synth_{i:04d}.txt" anno_path = self.output_dir / f"synth_{i:04d}.txt"
with open(anno_path, "w") as f: with open(anno_path, "w", encoding="utf-8") as f:
for category, xmin, ymin, xmax, ymax in annotations: for category, xmin, ymin, xmax, ymax in annotations:
f.write(f"{category} {xmin} {ymin} {xmax} {ymax}\n") f.write(f"{category} {xmin} {ymin} {xmax} {ymax}\n")