-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtrain.py
More file actions
88 lines (76 loc) · 3.53 KB
/
train.py
File metadata and controls
88 lines (76 loc) · 3.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import numpy as np
from tqdm import tqdm
from base_train import BaseTrain
from utils import utils as u
class Train(BaseTrain):
    """Concrete trainer: minibatch optimisation with periodic validation.

    The constructor forwards ``sess``, ``model``, ``data``, ``c`` and
    ``logger`` to ``BaseTrain`` (defined in base_train, not shown here). The
    methods below read them back as ``self.sess``, ``self.model``,
    ``self.data``, ``self.c`` and ``self.logger``. This presumably means
    BaseTrain stores them under those names; confirm against base_train.
    """

    def __init__(self, sess, model, data, c, logger):
        super(Train, self).__init__(sess, model, data, c, logger)

    def train_epoch(self):
        """Run ``self.c.n_epochs`` training passes and log their mean metrics.

        Each pass calls :meth:`train_step`. Its four metrics are printed,
        then averaged with ``np.mean`` over all passes. The averages are sent
        to ``self.logger.summarize`` under the keys ``loss``, ``acc``,
        ``validation_loss`` and ``validation_acc``. The summary is tagged with
        the model's global step as read at the start of this call.

        The model is saved before every pass and once more at the end.
        """
        cur_i = self.model.global_step_tensor.eval(self.sess).item()
        print("Training epoch %d" % self.model.cur_epoch_tensor.eval(
            self.sess).item())
        losses, accs, val_losses, val_accs = [], [], [], []
        for i in tqdm(range(self.c.n_epochs)):
            self.model.save(self.sess)
            loss, acc, val_loss, val_acc = self.train_step()
            losses.append(loss)
            accs.append(acc)
            val_losses.append(val_loss)
            val_accs.append(val_acc)
            print("Iter {} ".format(i) + "training loss: ", loss)
            print("Iter {} ".format(i) + "training accuracy: ", acc)
            print("Iter {} ".format(i) + "validation loss: ", val_loss)
            print("Iter {} ".format(i) + "validation accuracy: ", val_acc)
        # With n_epochs == 0 the lists are empty and np.mean yields nan.
        summaries = {
            'loss': np.mean(losses),
            'acc': np.mean(accs),
            'validation_loss': np.mean(val_losses),
            'validation_acc': np.mean(val_accs),
        }
        print("Epoch {}".format(cur_i) + "total training loss: ",
              summaries['loss'])
        print("Epoch {}".format(cur_i) + "total training accuracy: ",
              summaries['acc'])
        print("Epoch {}".format(cur_i) +
              "total validation loss: ", summaries['validation_loss'])
        print("Epoch {}".format(cur_i) +
              "total validation accuracy: ", summaries['validation_acc'])
        print("reached summaries")
        self.logger.summarize(cur_i, summaries_dict=summaries)
        self.model.save(self.sess)

    def train_step(self):
        """Run one pass of minibatch training over ``self.data[0]``/``[1]``.

        On every 50th minibatch, starting with the first, the model is also
        evaluated on the first ``self.c.b`` rows of ``self.data[3]`` and
        ``self.data[4]``. The latest metrics are printed at that point.

        Returns:
            tuple: ``(loss, acc, val_loss, val_acc)``. ``loss`` and ``acc``
            come from the last training minibatch. ``val_loss`` and
            ``val_acc`` come from the most recent validation evaluation, so
            they can lag by up to 49 minibatches.

        Raises:
            ValueError: if ``u.minibatches`` yields no batches, in which case
                there are no metrics to return.
        """
        print("minibatch training...")
        # The validation slice is the same on every step, so build its feed once.
        val_feed_dict = {self.model.x: self.data[3][:self.c.b],
                         self.model.y: self.data[4][:self.c.b],
                         # Evaluate in inference mode. NOTE(review): assumes
                         # is_training gates dropout / batch-norm behaviour
                         # in the model; confirm against its definition.
                         self.model.is_training: False}
        loss = acc = val_loss = val_acc = None
        # The context manager closes the manually-updated bar even on error.
        with tqdm(total=int(self.data[0].shape[0] // self.c.b)) as pbar:
            batches = u.minibatches(self.data[0], self.data[1], self.c.b)
            for step, (batch_x, batch_y) in enumerate(batches):
                feed_dict = {self.model.x: batch_x, self.model.y: batch_y,
                             self.model.is_training: True}
                _, loss, acc = self.sess.run([self.model.optimizer,
                                              self.model.cost,
                                              self.model.accuracy],
                                             feed_dict=feed_dict)
                # Every 50th step, compute validation loss and accuracy.
                if step % 50 == 0:
                    val_acc, val_loss = self.sess.run([self.model.accuracy,
                                                       self.model.cost],
                                                      feed_dict=val_feed_dict)
                    print("loss shape", loss.shape)
                    print("Iter {}".format(step * self.c.b) +
                          "\nTraining Loss: ", loss)
                    print("Training Accuracy: ", acc)
                    print("Validation Loss:", val_loss)
                    print("Validation Accuracy: ", val_acc)
                pbar.update(1)
        if loss is None:
            raise ValueError("train_step: u.minibatches yielded no batches "
                             "(is self.data[0] empty?)")
        return loss, acc, val_loss, val_acc