-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdiffytest.py
More file actions
87 lines (51 loc) · 1.89 KB
/
diffytest.py
File metadata and controls
87 lines (51 loc) · 1.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import torch as T
from dnc import DNC
import pickle
print('Loading data...')
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load these files from a trusted source.
with open('./IMDb-data/imdbwords.pickle', 'rb') as file:
    words = pickle.load(file)
with open('trainset', 'rb') as file:
    # Sequences are mutated in place later (zero-padded with list appends), so
    # this is presumably a list of lists — confirm against the pickle writer.
    trainset = pickle.load(file)
with open('trainlabels', 'rb') as file:
    # Bound as `trainlabels` because that is the name the training loop reads.
    trainlabels = pickle.load(file)
# Not iterated by the script. `trainset` itself must stay a list, because the
# later steps take its length, index it and pad it in place; an iterator over
# this loader supports none of that.
train_data_loader = T.utils.data.DataLoader(dataset=trainset, batch_size=1, shuffle=False)
print('Defining model...')
# Objective: mean squared error between the model's summed output and the
# per-sample label.
loss_fn = T.nn.MSELoss()
# Model from the project's `dnc` package. The two positional arguments are
# presumably the per-step input width (25, the feature width the inputs are
# reshaped to) and the controller hidden size (128) — confirm against `dnc`.
diffy = DNC(
    25,
    128,
    num_layers=2,
    independent_linears=True,
)
# Adam with a small learning rate and beta2 lowered from PyTorch's 0.999
# default to 0.98.
optimizer = T.optim.Adam(
    diffy.parameters(),
    lr=0.0001,
    betas=[0.9, 0.98],
)
print('Finding max...')
# Longest sequence in the training set and its length; every sequence is padded
# up to (at least) this length afterwards. max() returns the first of several
# equally long items, matching a strict "<" scan, and default=[] leaves
# maxVal == 0 when trainset is empty.
maxItem = max(trainset, key=len, default=[])
maxVal = len(maxItem)
print('Padding values...')
# Width of one time step; must equal the DNC's input width (first argument
# passed to DNC when the model is built).
features = 25
# Ceiling division: round the longest length up to a multiple of `features` so
# every padded sequence reshapes cleanly into (steps, features). When maxVal is
# already a multiple of `features` this is exactly maxVal.
padded_len = -(-maxVal // features) * features
steps = padded_len // features
for sequence in trainset:
    # Zero padding in place; 0 presumably acts as a padding token — confirm
    # it is not also a real word index.
    sequence.extend([0] * (padded_len - len(sequence)))
# One (steps, features) sample per sequence. The sample and step counts come
# from the data, so the loader below yields one (1, steps, features) batch per
# sequence rather than a single item holding the whole dataset.
inputs = T.tensor(trainset, dtype=T.float).reshape((len(trainset), steps, features))
inloader = T.utils.data.DataLoader(dataset=inputs, batch_size=1, shuffle=False)
inputset = iter(inloader)
# Recurrent state; only the DNC memory is fed back into the next call, and it
# persists across samples because reset_experience=False.
(controller_hidden, memory, read_vectors) = (None, None, None)
num_epochs = 2
ranges = num_epochs * len(inloader)  # total number of optimizer steps
print('Beginning training loop...')
for epoch in range(num_epochs):
    # Start a fresh pass over the loader each epoch; reusing one iterator for
    # every step would raise StopIteration once the first pass is exhausted.
    for idx, seq in enumerate(inloader):
        it = epoch * len(inloader) + idx  # global step number
        optimizer.zero_grad()
        # Forward pass
        output, (controller_hidden, memory, read_vectors) = diffy(
            seq, (None, memory, None), reset_experience=False
        )
        # Sum over dims 1 and 2 (keeping them as size-1) to get one scalar
        # prediction per batch element.
        final_out = T.sum(output, (1, 2), keepdim=True)
        # The loader is unshuffled with batch_size=1, so batch idx is sample idx.
        # as_tensor accepts either a tensor or a plain number as the label.
        target = T.as_tensor(trainlabels[idx]).to(T.float).reshape((1, 1, 1))
        loss = loss_fn(final_out, target)
        loss.backward()
        optimizer.step()
        # Detach the carried memory so the next step does not backpropagate
        # into this step's (already freed) graph.
        memory = {k: (v.detach() if T.is_tensor(v) else v) for k, v in memory.items()}
        if it % 10 == 9:
            print('Step: {}, Loss: {}'.format(it + 1, loss.item()))