-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
65 lines (52 loc) · 1.71 KB
/
main.py
File metadata and controls
65 lines (52 loc) · 1.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import click
from game import Game
from pathlib import Path
from lit_module import LitModule
import lightning as L
import yaml
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import ModelCheckpoint
@click.group()
def cli():
    """Autonomous Driver CLI"""
    # Root click command group; subcommands register themselves with
    # @cli.command(). The docstring alone is a complete function body and is
    # also what click prints for the top-level --help, so it is kept verbatim.
@cli.command()
@click.option("--checkpoint-path", type=Path, help="Path to the checkpoint file")
def run(checkpoint_path: Path):
    """Run the autonomous driver with a trained model"""
    # The docstring above is the click help text for `run`.
    # NOTE(review): --checkpoint-path is optional, so Game may receive None
    # here; confirm Game handles a missing checkpoint.
    driver = Game(checkpoint_path)
    # setup() must precede run(); this mirrors the required call order.
    driver.setup()
    driver.run()
@cli.command()
@click.option('--config-path', type=Path, required=True, help='Path to the config file')
@click.option('--checkpoint-path', type=Path, help='Path to the checkpoint file')
def train(config_path: Path, checkpoint_path: Path):
    """Train the autonomous driver model"""
    # The docstring above is the click help text for `train`; extra
    # documentation lives in comments so `--help` output is unchanged.
    config = load_config(config_path)
    model = LitModule(config)

    logger = TensorBoardLogger(
        save_dir="lightning_logs",
        name="autonomous-driver",
    )

    # Store checkpoints inside this run's logger directory. A pathlib join
    # (instead of `str + '/checkpoints'`) yields a correct path on every
    # platform; ModelCheckpoint accepts path-like values for dirpath.
    checkpoint_callback = ModelCheckpoint(
        # Assumes LitModule logs a metric named "train_loss"; confirm. The
        # filename template below also interpolates it.
        monitor='train_loss',
        dirpath=Path(logger.log_dir) / 'checkpoints',
        filename='autonomous-driver-{step:06d}-{train_loss:.2f}',
        save_top_k=5,  # keep the 5 best checkpoints by train_loss
        mode='min',  # lower train_loss is better
        save_last=True,  # additionally keep the most recent checkpoint
    )

    # NOTE(review): Lightning treats max_epochs=None with no max_steps as
    # max_epochs=1000, not "unlimited". If unbounded training is intended,
    # max_epochs=-1 is the documented setting; confirm before changing.
    trainer = L.Trainer(
        max_epochs=None,
        logger=logger,
        callbacks=[checkpoint_callback],
    )

    # With checkpoint_path=None training starts fresh; otherwise Lightning
    # resumes training from the given checkpoint file.
    trainer.fit(
        model,
        ckpt_path=checkpoint_path,
    )
def load_config(config: Path):
    """Read and parse a YAML configuration file.

    Args:
        config: Path to the YAML file.

    Returns:
        Whatever the YAML document contains (``None`` for an empty file).
    """
    # Explicit UTF-8 so parsing doesn't depend on the platform's locale
    # encoding. safe_load is equivalent to load(..., Loader=SafeLoader): it
    # only builds plain Python types and never arbitrary objects.
    with open(config, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)
if __name__ == "__main__":
    # Script entry point: hand control to the click group, which parses the
    # command line and dispatches to the selected subcommand.
    cli()