mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
23 lines
639 B
Python
23 lines
639 B
Python
import typer
|
|
from commands import app
|
|
|
|
|
|
@app.command()
def train(
    ctx: typer.Context,
    epoch_size: int = typer.Option(
        10, "--epoch", "-e", help="Number of epochs"
    ),
    batch_size: int = typer.Option(
        64, "--batch", "-b", help="Batch size"
    ),
    lr: float = typer.Option(
        1e-4, "--lr", "-l", help="Learning rate"
    ),
    checkpoint_path: str = typer.Option(
        "hash_checkpoint.pt", "--checkpoint", "-c", help="Checkpoint path"
    ),
):
    """Run training by forwarding the CLI options to ``compressors.train``.

    Args:
        ctx: Typer context for this invocation; not read by this function.
        epoch_size: Value of ``--epoch`` / ``-e`` (default 10).
        batch_size: Value of ``--batch`` / ``-b`` (default 64).
        lr: Value of ``--lr`` / ``-l`` (default 1e-4).
        checkpoint_path: Value of ``--checkpoint`` / ``-c``
            (default ``"hash_checkpoint.pt"``).
    """
    # Imported inside the command rather than at module level.
    # NOTE(review): presumably so that loading the CLI does not pull in the
    # training code until this command actually runs -- confirm.
    from compressors import train as run_training

    # Every option is passed through by keyword, unchanged.
    training_options = {
        "epoch_size": epoch_size,
        "batch_size": batch_size,
        "lr": lr,
        "checkpoint_path": checkpoint_path,
    }
    run_training(**training_options)
|