-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
55 lines (43 loc) · 1.93 KB
/
Copy pathmain.py
File metadata and controls
55 lines (43 loc) · 1.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from data_loader import DataLoader
from model import *
from trainer import Trainer
from tensorflow.keras.optimizers import Adam
import argparse
# Command-line interface. Arguments are parsed once, at import time, and
# exposed as the module-level globals below, which main() reads.
parser = argparse.ArgumentParser(
    description="Train YasuoNet",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# Not required=True: argparse never applies the default of a required option,
# so 'train' could never take effect. `choices` rejects typos at parse time
# instead of letting an unknown mode silently do nothing.
parser.add_argument("--mode", default='train', type=str,
                    choices=('train', 'test', 'predict'),
                    help="action to run ('predict' is not implemented yet)")
parser.add_argument("--data_dir", type=str, required=True,
                    help="dataset directory handed to DataLoader")
# Only the training path consumes these two, so they are enforced after
# parsing for --mode train rather than being required for every mode.
parser.add_argument("--batch_size", type=int,
                    help="training batch size (required for --mode train)")
parser.add_argument("--epochs", type=int,
                    help="number of training epochs (required for --mode train)")
parser.add_argument("--learning_rate", default=1e-3, type=float,
                    help="Adam learning rate")
parser.add_argument("--ckpt_dir", default='./checkpoints', type=str,
                    help="checkpoint directory handed to Trainer")
# TODO: resuming training from a checkpoint (--train_continue) is not wired up.
args = parser.parse_args()
if args.mode == 'train' and (args.batch_size is None or args.epochs is None):
    # parser.error prints usage and exits with status 2, like a missing
    # required argument would.
    parser.error("--batch_size and --epochs are required when --mode is 'train'")

# parameters
mode = args.mode
data_dir = args.data_dir
batch_size = args.batch_size
epochs = args.epochs
learning_rate = args.learning_rate
ckpt_dir = args.ckpt_dir
def main():
    """Build the YasuoNet sequence model and run the action chosen by ``mode``.

    Reads the module-level settings parsed from the command line (``mode``,
    ``data_dir``, ``batch_size``, ``epochs``, ``learning_rate``, ``ckpt_dir``).

    Modes:
        train   -- print the model summary, then train it with an Adam
                   optimizer built from ``learning_rate``.
        test    -- run ``Trainer.test()`` on the model.
        predict -- not implemented yet.

    Raises:
        ValueError: ``mode`` is not one of the supported modes.
        NotImplementedError: ``mode`` is ``'predict'``.
    """
    # Check the mode before loading data and building the model, so a bad or
    # unimplemented mode fails immediately instead of after the expensive
    # setup (and instead of silently exiting without doing anything).
    if mode not in ('train', 'test', 'predict'):
        raise ValueError(
            f"unknown mode {mode!r}; expected 'train', 'test' or 'predict'")
    if mode == 'predict':
        raise NotImplementedError("predict mode is not implemented yet")

    # Sequence model. To use the basic model instead, construct the loader
    # without x_expand and call build_basic_model(input_shape_dict).
    data_loader = DataLoader(data_dir, x_includes=['video', 'audio'], x_expand=2)
    input_shape_dict = data_loader.get_metadata()['data_shape']
    model = build_sequence_model(input_shape_dict)

    if mode == 'train':
        model.summary()
        # Equal weight for every class (presumably one entry per class --
        # TODO confirm against Trainer.train).
        class_weights = (1, 1)
        trainer = Trainer(model, data_loader, ckpt_dir)
        trainer.train(Adam(learning_rate), epochs, batch_size, class_weights)
    else:  # mode == 'test'
        trainer = Trainer(model, data_loader, ckpt_dir)
        trainer.test()


if __name__ == '__main__':
    main()