diff --git a/mini-nav/commands/__init__.py b/mini-nav/commands/__init__.py index a067577..540bbaa 100644 --- a/mini-nav/commands/__init__.py +++ b/mini-nav/commands/__init__.py @@ -1,6 +1,7 @@ -from .train import train +from .app import app from .benchmark import benchmark -from .visualize import visualize from .generate import generate +from .train import train +from .visualize import visualize -__all__ = ["train", "benchmark", "visualize", "generate"] +__all__ = ["app", "train", "benchmark", "visualize", "generate"] diff --git a/mini-nav/commands/app.py b/mini-nav/commands/app.py new file mode 100644 index 0000000..f4fa8eb --- /dev/null +++ b/mini-nav/commands/app.py @@ -0,0 +1,7 @@ +import typer + +app = typer.Typer( + name="mini-nav", + help="Mini-Nav: A vision-language navigation system", + add_completion=False, +) diff --git a/mini-nav/commands/benchmark.py b/mini-nav/commands/benchmark.py index 4cab0ba..758c6b7 100644 --- a/mini-nav/commands/benchmark.py +++ b/mini-nav/commands/benchmark.py @@ -1,8 +1,10 @@ from typing import cast import typer +from commands import app +@app.command() def benchmark( ctx: typer.Context, model_path: str = typer.Option( diff --git a/mini-nav/commands/generate.py b/mini-nav/commands/generate.py index e7c0024..ba65686 100644 --- a/mini-nav/commands/generate.py +++ b/mini-nav/commands/generate.py @@ -1,6 +1,8 @@ import typer +from commands import app +@app.command() def generate(ctx: typer.Context): from configs import cfg_manager from data_loading.synthesizer import ImageSynthesizer diff --git a/mini-nav/commands/train.py b/mini-nav/commands/train.py index 4362d1d..221d169 100644 --- a/mini-nav/commands/train.py +++ b/mini-nav/commands/train.py @@ -1,6 +1,8 @@ import typer +from commands import app +@app.command() def train( ctx: typer.Context, epoch_size: int = typer.Option(10, "--epoch", "-e", help="Number of epochs"), diff --git a/mini-nav/commands/visualize.py b/mini-nav/commands/visualize.py index 58e2503..4bfe6ef 100644 --- a/mini-nav/commands/visualize.py +++ b/mini-nav/commands/visualize.py @@ -1,6 +1,8 @@ import typer +from commands import app +@app.command() def visualize( ctx: typer.Context, host: str = typer.Option("127.0.0.1", "--host", help="Server host"), diff --git a/mini-nav/main.py b/mini-nav/main.py index 1bc18e7..466faa4 100644 --- a/mini-nav/main.py +++ b/mini-nav/main.py @@ -1,16 +1,4 @@ -import typer -from commands import benchmark, generate, train, visualize - -app = typer.Typer( - name="mini-nav", - help="Mini-Nav: A vision-language navigation system", - add_completion=False, -) - -app.command(name="train")(train) -app.command(name="benchmark")(benchmark) -app.command(name="visualize")(visualize) -app.command(name="generate")(generate) +from commands import app if __name__ == "__main__": app()