Files
Mini-Nav/mini-nav/datasets/loader.py

106 lines
3.3 KiB
Python

"""Data loaders for synthetic and validation datasets."""
from collections.abc import Iterator
from pathlib import Path
from PIL import Image
class SynthDataset:
    """Dataset loader for synthesized training images.

    Images are discovered as ``synth_*.jpg`` directly inside ``synth_dir``.
    Each image may have a sibling annotation file with the same stem and
    ``annotations_suffix`` as its extension, holding one whitespace-separated
    ``category xmin ymin xmax ymax`` record per line.
    """

    def __init__(self, synth_dir: Path, annotations_suffix: str = ".txt"):
        """Initialize the synthetic dataset loader.

        Args:
            synth_dir: Directory containing synthesized images and annotations
            annotations_suffix: Suffix for annotation files. It is passed to
                ``Path.with_suffix``, so it must include the leading dot.
        """
        self.synth_dir = Path(synth_dir)
        self.annotations_suffix = annotations_suffix
        # glob() order is arbitrary; sorting keeps index -> file deterministic.
        self.image_files = sorted(self.synth_dir.glob("synth_*.jpg"))

    def __len__(self) -> int:
        """Return the number of images found in ``synth_dir``."""
        return len(self.image_files)

    def __getitem__(
        self, idx: int
    ) -> tuple[Image.Image, list[tuple[str, int, int, int, int]]]:
        """Get a single item.

        Args:
            idx: Index of the item

        Returns:
            Tuple of (image, annotations) where image is converted to RGB and
            annotations is a list of (category, xmin, ymin, xmax, ymax).
            The list is empty when the image has no annotation file.

        Raises:
            IndexError: If ``idx`` is out of range.
            ValueError: If an annotation line has non-integer coordinates.
        """
        img_path = self.image_files[idx]
        # convert() returns an independent in-memory image, so the source file
        # handle can be released as soon as the conversion is done.
        with Image.open(img_path) as source:
            image = source.convert("RGB")

        anno_path = img_path.with_suffix(self.annotations_suffix)
        return image, self._read_annotations(anno_path)

    @staticmethod
    def _read_annotations(anno_path: Path) -> list[tuple[str, int, int, int, int]]:
        """Parse one annotation file; a missing file yields an empty list."""
        annotations: list[tuple[str, int, int, int, int]] = []
        if not anno_path.exists():
            return annotations

        # Explicit encoding so parsing does not depend on the platform locale.
        with open(anno_path, "r", encoding="utf-8") as f:
            for line_no, line in enumerate(f, start=1):
                parts = line.split()
                # Blank lines and records without exactly five fields are
                # skipped, not treated as errors.
                if len(parts) != 5:
                    continue
                category, *coords = parts
                try:
                    xmin, ymin, xmax, ymax = map(int, coords)
                except ValueError as err:
                    # Same exception type as before, now naming the culprit.
                    raise ValueError(
                        f"{anno_path}:{line_no}: invalid box coordinates {coords!r}"
                    ) from err
                annotations.append((category, xmin, ymin, xmax, ymax))
        return annotations

    def __iter__(
        self,
    ) -> Iterator[tuple[Image.Image, list[tuple[str, int, int, int, int]]]]:
        """Iterate over the dataset in index order."""
        for i in range(len(self)):
            yield self[i]
class ValDataset:
    """Dataset loader for validation scene images.

    Images are discovered as ``rgb_*.jpg`` inside each immediate subdirectory
    of ``scenes_dir / split``.
    """

    def __init__(self, scenes_dir: Path, split: str = "easy"):
        """Initialize the validation dataset loader.

        Args:
            scenes_dir: Directory containing scene subdirectories
            split: Name of the split subdirectory to load
                (e.g. 'easy' or 'hard')

        Raises:
            ValueError: If ``scenes_dir / split`` is not an existing directory.
        """
        self.scenes_dir = Path(scenes_dir)
        self.split = split
        self.split_dir = self.scenes_dir / split
        # is_dir() rather than exists(): a regular file at this path would
        # otherwise pass the check and silently produce an empty dataset.
        if not self.split_dir.is_dir():
            raise ValueError(f"Scene split directory not found: {self.split_dir}")
        # glob() order is arbitrary; sorting keeps index -> file deterministic.
        self.image_files = sorted(self.split_dir.glob("*/rgb_*.jpg"))

    def __len__(self) -> int:
        """Return the number of RGB images found in the split."""
        return len(self.image_files)

    def __getitem__(self, idx: int) -> tuple[Image.Image, Path]:
        """Get a single item.

        Args:
            idx: Index of the item

        Returns:
            Tuple of (image, scene_path) where image is converted to RGB and
            scene_path is the directory that contains the image file.

        Raises:
            IndexError: If ``idx`` is out of range.
        """
        img_path = self.image_files[idx]
        # convert() returns an independent in-memory image, so the source file
        # handle can be released as soon as the conversion is done.
        with Image.open(img_path) as source:
            image = source.convert("RGB")
        return image, img_path.parent

    def __iter__(self) -> Iterator[tuple[Image.Image, Path]]:
        """Iterate over the dataset in index order."""
        for i in range(len(self)):
            yield self[i]