From 7dbd704d6b7be471a6a2d266590cdb52d0451cfa Mon Sep 17 00:00:00 2001 From: SikongJueluo Date: Fri, 6 Mar 2026 11:41:35 +0800 Subject: [PATCH] refactor(cli): migrate from argparse to typer for command-line interface --- mini-nav/commands/__init__.py | 6 ++ mini-nav/commands/benchmark.py | 53 +++++++++++++++++ mini-nav/commands/generate.py | 25 ++++++++ mini-nav/commands/train.py | 20 +++++++ mini-nav/commands/visualize.py | 12 ++++ mini-nav/main.py | 101 +++++---------------------------- pyproject.toml | 1 + uv.lock | 68 ++++++++++++++++++++-- 8 files changed, 195 insertions(+), 91 deletions(-) create mode 100644 mini-nav/commands/__init__.py create mode 100644 mini-nav/commands/benchmark.py create mode 100644 mini-nav/commands/generate.py create mode 100644 mini-nav/commands/train.py create mode 100644 mini-nav/commands/visualize.py diff --git a/mini-nav/commands/__init__.py b/mini-nav/commands/__init__.py new file mode 100644 index 0000000..a067577 --- /dev/null +++ b/mini-nav/commands/__init__.py @@ -0,0 +1,6 @@ +from .train import train +from .benchmark import benchmark +from .visualize import visualize +from .generate import generate + +__all__ = ["train", "benchmark", "visualize", "generate"] diff --git a/mini-nav/commands/benchmark.py b/mini-nav/commands/benchmark.py new file mode 100644 index 0000000..4cab0ba --- /dev/null +++ b/mini-nav/commands/benchmark.py @@ -0,0 +1,53 @@ +from typing import cast + +import typer + + +def benchmark( + ctx: typer.Context, + model_path: str = typer.Option( + None, "--model", "-m", help="Path to compressor model weights" + ), +): + import torch + from benchmarks import run_benchmark + from compressors import DinoCompressor + from configs import cfg_manager + from transformers import AutoImageProcessor, BitImageProcessorFast + from utils import get_device + + config = cfg_manager.get() + benchmark_cfg = config.benchmark + + if not benchmark_cfg.enabled: + typer.echo( + "Benchmark is not enabled. Set benchmark.enabled=true in config.yaml", + err=True, + ) + raise typer.Exit(code=1) + + device = get_device() + + model_cfg = config.model + processor = cast( + BitImageProcessorFast, + AutoImageProcessor.from_pretrained(model_cfg.dino_model, device_map=device), + ) + + model = DinoCompressor().to(device) + if model_path: + from compressors import HashCompressor + + compressor = HashCompressor( + input_dim=model_cfg.compression_dim, + hash_bits=model_cfg.compression_dim, + ) + compressor.load_state_dict(torch.load(model_path)) + model.compressor = compressor + + run_benchmark( + model=model, + processor=processor, + config=benchmark_cfg, + model_name="dinov2", + ) diff --git a/mini-nav/commands/generate.py b/mini-nav/commands/generate.py new file mode 100644 index 0000000..e7c0024 --- /dev/null +++ b/mini-nav/commands/generate.py @@ -0,0 +1,25 @@ +import typer + + +def generate(ctx: typer.Context): + from configs import cfg_manager + from data_loading.synthesizer import ImageSynthesizer + + config = cfg_manager.get() + dataset_cfg = config.dataset + + synthesizer = ImageSynthesizer( + dataset_root=dataset_cfg.dataset_root, + output_dir=dataset_cfg.output_dir, + num_objects_range=dataset_cfg.num_objects_range, + num_scenes=dataset_cfg.num_scenes, + object_scale_range=dataset_cfg.object_scale_range, + rotation_range=dataset_cfg.rotation_range, + overlap_threshold=dataset_cfg.overlap_threshold, + seed=dataset_cfg.seed, + ) + + generated_files = synthesizer.generate() + typer.echo( + f"Generated {len(generated_files)} synthesized images in {dataset_cfg.output_dir}" + ) diff --git a/mini-nav/commands/train.py b/mini-nav/commands/train.py new file mode 100644 index 0000000..4362d1d --- /dev/null +++ b/mini-nav/commands/train.py @@ -0,0 +1,20 @@ +import typer + + +def train( + ctx: typer.Context, + epoch_size: int = typer.Option(10, "--epoch", "-e", help="Number of epochs"), + batch_size: int = typer.Option(64, "--batch", "-b", help="Batch size"), + lr: float = typer.Option(1e-4, "--lr", "-l", help="Learning rate"), + checkpoint_path: str = typer.Option( + "hash_checkpoint.pt", "--checkpoint", "-c", help="Checkpoint path" + ), +): + from compressors import train as train_module + + train_module( + epoch_size=epoch_size, + batch_size=batch_size, + lr=lr, + checkpoint_path=checkpoint_path, + ) diff --git a/mini-nav/commands/visualize.py b/mini-nav/commands/visualize.py new file mode 100644 index 0000000..58e2503 --- /dev/null +++ b/mini-nav/commands/visualize.py @@ -0,0 +1,12 @@ +import typer + + +def visualize( + ctx: typer.Context, + host: str = typer.Option("127.0.0.1", "--host", help="Server host"), + port: int = typer.Option(8050, "--port", "-p", help="Server port"), + debug: bool = typer.Option(True, "--debug/--no-debug", help="Enable debug mode"), +): + from visualizer import app as dash_app + + dash_app.run(host=host, port=port, debug=debug) diff --git a/mini-nav/main.py b/mini-nav/main.py index d4cf882..1bc18e7 100644 --- a/mini-nav/main.py +++ b/mini-nav/main.py @@ -1,89 +1,16 @@ -import argparse +import typer +from commands import benchmark, generate, train, visualize + +app = typer.Typer( + name="mini-nav", + help="Mini-Nav: A vision-language navigation system", + add_completion=False, +) + +app.command(name="train")(train) +app.command(name="benchmark")(benchmark) +app.command(name="visualize")(visualize) +app.command(name="generate")(generate) if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "action", - choices=["train", "benchmark", "visualize", "generate"], - help="Action to perform: train, benchmark, visualize, or generate", - ) - args = parser.parse_args() - - if args.action == "train": - from compressors import train - - train( - epoch_size=10, batch_size=64, lr=1e-4, checkpoint_path="hash_checkpoint.pt" - ) - elif args.action == "benchmark": - from typing import cast - - import torch - from benchmarks import run_benchmark - from compressors import DinoCompressor - from configs import cfg_manager - from transformers import AutoImageProcessor, BitImageProcessorFast - from utils import get_device - - config = cfg_manager.get() - benchmark_cfg = config.benchmark - - if not benchmark_cfg.enabled: - print("Benchmark is not enabled. Set benchmark.enabled=true in config.yaml") - exit(1) - - device = get_device() - - # Load model and processor based on config - model_cfg = config.model - processor = cast( - BitImageProcessorFast, - AutoImageProcessor.from_pretrained(model_cfg.dino_model, device_map=device), - ) - - # Load compressor weights if specified in model config - model = DinoCompressor().to(device) - if model_cfg.compressor_path is not None: - from compressors import HashCompressor - - compressor = HashCompressor( - input_dim=model_cfg.compression_dim, - output_dim=model_cfg.compression_dim, - ) - compressor.load_state_dict(torch.load(model_cfg.compressor_path)) - # Wrap with compressor if path is specified - model.compressor = compressor - - # Run benchmark - run_benchmark( - model=model, - processor=processor, - config=benchmark_cfg, - model_name="dinov2", - ) - elif args.action == "visualize": - from visualizer import app - - app.run(debug=True) - else: # generate - from configs import cfg_manager - from data_loading.synthesizer import ImageSynthesizer - - config = cfg_manager.get() - dataset_cfg = config.dataset - - synthesizer = ImageSynthesizer( - dataset_root=dataset_cfg.dataset_root, - output_dir=dataset_cfg.output_dir, - num_objects_range=dataset_cfg.num_objects_range, - num_scenes=dataset_cfg.num_scenes, - object_scale_range=dataset_cfg.object_scale_range, - rotation_range=dataset_cfg.rotation_range, - overlap_threshold=dataset_cfg.overlap_threshold, - seed=dataset_cfg.seed, - ) - - generated_files = synthesizer.generate() - print( - f"Generated {len(generated_files)} synthesized images in {dataset_cfg.output_dir}" - ) + app() diff --git a/pyproject.toml b/pyproject.toml index a216237..9d77d09 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ dependencies = [ "torch>=2.10.0", "torchvision>=0.25.0", "transformers>=5.0.0", + "typer>=0.24.1", ] [dependency-groups] diff --git a/uv.lock b/uv.lock index f0e05aa..90f5e24 100644 --- a/uv.lock +++ b/uv.lock @@ -250,6 +250,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e", size = 7490, upload-time = "2025-07-03T22:54:42.156Z" }, ] +[[package]] +name = "annotated-doc" +version = "0.0.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/57/ba/046ceea27344560984e26a590f90bc7f4a75b06701f653222458922b558c/annotated_doc-0.0.4.tar.gz", hash = "sha256:fbcda96e87e9c92ad167c2e53839e57503ecfda18804ea28102353485033faa4", size = 7288, upload-time = "2025-11-10T22:07:42.062Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/d3/26bf1008eb3d2daa8ef4cacc7f3bfdc11818d111f7e2d0201bc6e3b49d45/annotated_doc-0.0.4-py3-none-any.whl", hash = "sha256:571ac1dc6991c450b25a9c2d84a3705e2ae7a53467b5d111c24fa8baabbed320", size = 5303, upload-time = "2025-11-10T22:07:40.673Z" }, +] + [[package]] name = "annotated-types" version = "0.7.0" @@ -1026,6 +1035,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3d/96/fa3cb37a6ffe7b81073d8c74f7cb95204d0922ac1668b264685aa34add20/lancedb-0.27.1-cp39-abi3-win_amd64.whl", hash = "sha256:f2150a66758ce6fe3cff226ac1ffcac2d5f5e2c9b35bc4c2d5923abcebef98cc", size = 53374010, upload-time = "2026-01-27T03:57:13.434Z" }, ] +[[package]] +name = "markdown-it-py" +version = "4.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mdurl" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5b/f5/4ec618ed16cc4f8fb3b701563655a69816155e79e24a17b651541804721d/markdown_it_py-4.0.0.tar.gz", hash = "sha256:cb0a2b4aa34f932c007117b194e945bd74e0ec24133ceb5bac59009cda1cb9f3", size = 73070, upload-time = "2025-08-11T12:57:52.854Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/94/54/e7d793b573f298e1c9013b8c4dade17d481164aa517d1d7148619c2cedbf/markdown_it_py-4.0.0-py3-none-any.whl", hash = "sha256:87327c59b172c5011896038353a81343b6754500a08cd7a4973bb48c6d578147", size = 87321, upload-time = "2025-08-11T12:57:51.923Z" }, +] + [[package]] name = "markupsafe" version = "3.0.3" @@ -1111,6 +1132,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/70/bc/6f1c2f612465f5fa89b95bead1f44dcb607670fd42891d8fdcd5d039f4f4/markupsafe-3.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:32001d6a8fc98c8cb5c947787c5d08b0a50663d139f1305bac5885d98d9b40fa", size = 14146, upload-time = "2025-09-27T18:37:28.327Z" }, ] +[[package]] +name = "mdurl" +version = "0.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d6/54/cfe61301667036ec958cb99bd3efefba235e65cdeb9c84d24a8293ba1d90/mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", size = 8729, upload-time = "2022-08-14T12:40:10.846Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" }, +] + [[package]] name = "mini-nav" version = "0.1.0" @@ -1133,6 +1163,7 @@ dependencies = [ { name = "torchvision", version = "0.25.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, { name = "torchvision", version = "0.25.0+cu130", source = { registry = "https://download.pytorch.org/whl/cu130" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "transformers" }, + { name = "typer" }, ] [package.dev-dependencies] @@ -1160,6 +1191,7 @@ requires-dist = [ { name = "torchvision", marker = "sys_platform != 'linux' and sys_platform != 'win32'", specifier = ">=0.25.0" }, { name = "torchvision", marker = "sys_platform == 'linux' or sys_platform == 'win32'", specifier = ">=0.25.0", index = "https://download.pytorch.org/whl/cu130" }, { name = "transformers", specifier = ">=5.0.0" }, + { name = "typer", specifier = ">=0.24.1" }, ] [package.metadata.requires-dev] @@ -2628,6 +2660,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/67/f3/6cd296376653270ac1b423bb30bd70942d9916b6978c6f40472d6ac038e7/retrying-1.4.2-py3-none-any.whl", hash = "sha256:bbc004aeb542a74f3569aeddf42a2516efefcdaff90df0eb38fbfbf19f179f59", size = 10859, upload-time = "2025-08-03T03:35:23.829Z" }, ] +[[package]] +name = "rich" +version = "14.3.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markdown-it-py" }, + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b3/c6/f3b320c27991c46f43ee9d856302c70dc2d0fb2dba4842ff739d5f46b393/rich-14.3.3.tar.gz", hash = "sha256:b8daa0b9e4eef54dd8cf7c86c03713f53241884e814f4e2f5fb342fe520f639b", size = 230582, upload-time = "2026-02-19T17:23:12.474Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/14/25/b208c5683343959b670dc001595f2f3737e051da617f66c31f7c4fa93abc/rich-14.3.3-py3-none-any.whl", hash = "sha256:793431c1f8619afa7d3b52b2cdec859562b950ea0d4b6b505397612db8d5362d", size = 310458, upload-time = "2026-02-19T17:23:13.732Z" }, +] + [[package]] name = "safetensors" version = "0.7.0" @@ -3117,10 +3162,10 @@ dependencies = [ { name = "typing-extensions", marker = "sys_platform != 'linux' and sys_platform != 'win32'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/e3/ea/304cf7afb744aa626fa9855245526484ee55aba610d9973a0521c552a843/torch-2.10.0-1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:c37fc46eedd9175f9c81814cc47308f1b42cfe4987e532d4b423d23852f2bf63", size = 79411450, upload-time = "2026-02-06T17:37:35.75Z" }, - { url = "https://files.pythonhosted.org/packages/25/d8/9e6b8e7df981a1e3ea3907fd5a74673e791da483e8c307f0b6ff012626d0/torch-2.10.0-1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:f699f31a236a677b3118bc0a3ef3d89c0c29b5ec0b20f4c4bf0b110378487464", size = 79423460, upload-time = "2026-02-06T17:37:39.657Z" }, - { url = "https://files.pythonhosted.org/packages/c9/2f/0b295dd8d199ef71e6f176f576473d645d41357b7b8aa978cc6b042575df/torch-2.10.0-1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:6abb224c2b6e9e27b592a1c0015c33a504b00a0e0938f1499f7f514e9b7bfb5c", size = 79498197, upload-time = "2026-02-06T17:37:27.627Z" }, - { url = "https://files.pythonhosted.org/packages/a4/1b/af5fccb50c341bd69dc016769503cb0857c1423fbe9343410dfeb65240f2/torch-2.10.0-1-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:7350f6652dfd761f11f9ecb590bfe95b573e2961f7a242eccb3c8e78348d26fe", size = 79498248, upload-time = "2026-02-06T17:37:31.982Z" }, + { url = "https://files.pythonhosted.org/packages/5b/30/bfebdd8ec77db9a79775121789992d6b3b75ee5494971294d7b4b7c999bc/torch-2.10.0-2-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:2b980edd8d7c0a68c4e951ee1856334a43193f98730d97408fbd148c1a933313", size = 79411457, upload-time = "2026-02-10T21:44:59.189Z" }, + { url = "https://files.pythonhosted.org/packages/0f/8b/4b61d6e13f7108f36910df9ab4b58fd389cc2520d54d81b88660804aad99/torch-2.10.0-2-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:418997cb02d0a0f1497cf6a09f63166f9f5df9f3e16c8a716ab76a72127c714f", size = 79423467, upload-time = "2026-02-10T21:44:48.711Z" }, + { url = "https://files.pythonhosted.org/packages/d3/54/a2ba279afcca44bbd320d4e73675b282fcee3d81400ea1b53934efca6462/torch-2.10.0-2-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:13ec4add8c3faaed8d13e0574f5cd4a323c11655546f91fbe6afa77b57423574", size = 79498202, upload-time = "2026-02-10T21:44:52.603Z" }, + { url = "https://files.pythonhosted.org/packages/ec/23/2c9fe0c9c27f7f6cb865abcea8a4568f29f00acaeadfc6a37f6801f84cb4/torch-2.10.0-2-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:e521c9f030a3774ed770a9c011751fb47c4d12029a3d6522116e48431f2ff89e", size = 79498254, upload-time = "2026-02-10T21:44:44.095Z" }, { url = "https://files.pythonhosted.org/packages/76/bb/d820f90e69cda6c8169b32a0c6a3ab7b17bf7990b8f2c680077c24a3c14c/torch-2.10.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:35e407430795c8d3edb07a1d711c41cc1f9eaddc8b2f1cc0a165a6767a8fb73d", size = 79411450, upload-time = "2026-01-21T16:25:30.692Z" }, { url = "https://files.pythonhosted.org/packages/61/d8/15b9d9d3a6b0c01b883787bd056acbe5cc321090d4b216d3ea89a8fcfdf3/torch-2.10.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:b7bd80f3477b830dd166c707c5b0b82a898e7b16f59a7d9d42778dd058272e8b", size = 79423461, upload-time = "2026-01-21T16:24:50.266Z" }, { url = "https://files.pythonhosted.org/packages/c9/5c/dee910b87c4d5c0fcb41b50839ae04df87c1cfc663cf1b5fca7ea565eeaa/torch-2.10.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:6d3707a61863d1c4d6ebba7be4ca320f42b869ee657e9b2c21c736bf17000294", size = 79498198, upload-time = "2026-01-21T16:24:34.704Z" }, @@ -3312,6 +3357,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f6/56/6113c23ff46c00aae423333eb58b3e60bdfe9179d542781955a5e1514cb3/triton-3.6.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:46bd1c1af4b6704e554cad2eeb3b0a6513a980d470ccfa63189737340c7746a7", size = 188397994, upload-time = "2026-01-20T16:01:14.236Z" }, ] +[[package]] +name = "typer" +version = "0.24.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "annotated-doc" }, + { name = "click" }, + { name = "rich" }, + { name = "shellingham" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f5/24/cb09efec5cc954f7f9b930bf8279447d24618bb6758d4f6adf2574c41780/typer-0.24.1.tar.gz", hash = "sha256:e39b4732d65fbdcde189ae76cf7cd48aeae72919dea1fdfc16593be016256b45", size = 118613, upload-time = "2026-02-21T16:54:40.609Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4a/91/48db081e7a63bb37284f9fbcefda7c44c277b18b0e13fbc36ea2335b71e6/typer-0.24.1-py3-none-any.whl", hash = "sha256:112c1f0ce578bfb4cab9ffdabc68f031416ebcc216536611ba21f04e9aa84c9e", size = 56085, upload-time = "2026-02-21T16:54:41.616Z" }, +] + [[package]] name = "typer-slim" version = "0.21.1"