-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathbase_train.py
More file actions
31 lines (25 loc) · 923 Bytes
/
base_train.py
File metadata and controls
31 lines (25 loc) · 923 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import numpy as np
import tensorflow as tf
import os
from tqdm import tqdm
class BaseTrain:
    """Base class for graph/session-style TensorFlow training loops.

    Subclasses implement :meth:`train_epoch` (and typically :meth:`train_step`).
    :meth:`train` drives the epoch loop, starting from the epoch stored in the
    model's ``cur_epoch_tensor`` and advancing it after every epoch.
    """

    def __init__(self, sess, model, data, c, logger):
        """Store collaborators and run the graph's variable initializers.

        Args:
            sess: Session used to run graph ops (presumably a ``tf.Session``).
            model: Model exposing ``cur_epoch_tensor`` and
                ``increment_cur_epoch_tensor`` (both fetched in ``train``).
            data: Training data source; not used here, stored for subclasses.
            c: Config object; ``train`` reads ``c.n_epochs``.
            logger: Logging/summary helper; not used here, stored for
                subclasses.
        """
        self.sess = sess
        self.data = data
        self.c = c
        self.model = model
        self.logger = logger
        # Initialize global and local variables in a single session call.
        # NOTE(review): this (re)initializes all variables, so any checkpoint
        # restore should happen after the trainer is constructed — confirm
        # against the caller.
        self.init = tf.group(tf.global_variables_initializer(),
                             tf.local_variables_initializer())
        self.sess.run(self.init)

    def train(self):
        """Run epochs from the model's stored epoch counter up to ``c.n_epochs``.

        The start epoch is read from ``model.cur_epoch_tensor``, so training
        can begin from a non-zero epoch. After each call to
        :meth:`train_epoch`, ``model.increment_cur_epoch_tensor`` is run to
        advance the counter.
        """
        # Fetch through our own session explicitly. ``Tensor.eval`` takes
        # ``feed_dict`` as its first positional argument, so ``eval(self.sess)``
        # passed the session as a feed dict and required a default session.
        start_epoch = int(self.sess.run(self.model.cur_epoch_tensor))
        for _ in range(start_epoch, self.c.n_epochs):
            self.train_epoch()
            self.sess.run(self.model.increment_cur_epoch_tensor)

    def train_epoch(self):
        """Loop over number of iterations, add train steps and summaries"""
        raise NotImplementedError

    def train_step(self):
        """Run session and return stuff to summary"""
        raise NotImplementedError