feat(synthesizer): add CSV export and progress bar for dataset generation

This commit is contained in:
2026-03-02 13:01:47 +08:00
parent f0479cc69b
commit 370c4a6588
2 changed files with 29 additions and 10 deletions

View File

@@ -1,11 +1,13 @@
"""Image synthesizer for generating synthetic object detection datasets."""
import csv
import random
from pathlib import Path
import numpy as np
from PIL import Image
from PIL.Image import Resampling
from tqdm.auto import tqdm
class ImageSynthesizer:
@@ -240,9 +242,6 @@ class ImageSynthesizer:
Returns:
Tuple of (synthesized_image, list of (category, xmin, ymin, xmax, ymax))
"""
random.seed(self.seed)
np.random.seed(self.seed)
# Load background
background, _ = self.get_random_background()
@@ -288,7 +287,7 @@ class ImageSynthesizer:
generated_files: list[Path] = []
for i in range(self.num_scenes):
for i in tqdm(range(self.num_scenes), desc="Generating scenes"):
# Update seed for each scene
random.seed(self.seed + i)
np.random.seed(self.seed + i)
@@ -301,9 +300,10 @@ class ImageSynthesizer:
# Save annotation
anno_path = self.output_dir / f"synth_{i:04d}.txt"
with open(anno_path, "w", encoding="utf-8") as f:
with open(anno_path, "w", encoding="utf-8", newline="") as f:
writer = csv.writer(f, quoting=csv.QUOTE_ALL)
for category, xmin, ymin, xmax, ymax in annotations:
f.write(f"{category} {xmin} {ymin} {xmax} {ymax}\n")
writer.writerow([category, xmin, ymin, xmax, ymax])
generated_files.append(img_path)

View File

@@ -4,15 +4,14 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"action",
choices=["train", "benchmark", "visualize"],
help="Action to perform: train, benchmark, or visualize",
choices=["train", "benchmark", "visualize", "generate"],
help="Action to perform: train, benchmark, visualize, or generate",
)
args = parser.parse_args()
if args.action == "train":
from compressors import train
# Start training
train(
epoch_size=10, batch_size=64, lr=1e-4, checkpoint_path="hash_checkpoint.pt"
)
@@ -20,7 +19,27 @@ if __name__ == "__main__":
from benchmarks import evaluate
evaluate("Dinov2", "CIFAR-10", "Recall@10")
else: # visualize
elif args.action == "visualize":
from visualizer import app
app.run(debug=True)
else: # generate
from configs import cfg_manager
from data_loading.synthesizer import ImageSynthesizer
config = cfg_manager.get()
dataset_cfg = config.dataset
synthesizer = ImageSynthesizer(
dataset_root=dataset_cfg.dataset_root,
output_dir=dataset_cfg.output_dir,
num_objects_range=dataset_cfg.num_objects_range,
num_scenes=dataset_cfg.num_scenes,
object_scale_range=dataset_cfg.object_scale_range,
rotation_range=dataset_cfg.rotation_range,
overlap_threshold=dataset_cfg.overlap_threshold,
seed=dataset_cfg.seed,
)
generated_files = synthesizer.generate()
print(f"Generated {len(generated_files)} synthesized images in {dataset_cfg.output_dir}")