model.py
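"""Encoder-decoder image captioning model with soft attention.

Defines a ResNet-50 based EncoderCNN, an additive Attention module, and a
DecoderRNN built on an LSTM cell that attends over the encoded image at every
decoding step (a "Show, Attend and Tell"-style architecture).
"""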
import torch
import torch.nn as nn
from torchvision import models
from torch.nn.utils.rnn import pack_padded_sequence
import numpy as np
import torch.nn.functional as F


class EncoderCNN(nn.Module):
    def __init__(self):
        super(EncoderCNN, self).__init__()
        resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        # drop the final average-pooling and fully connected layers to keep the spatial feature map
        modules = list(resnet.children())[:-2]
        self.resnet = nn.Sequential(*modules)
        self.fine_tune()

    def fine_tune(self, fine_tune=True):
        # freeze all layers, then optionally unfreeze the later ResNet blocks
        for param in self.resnet.parameters():
            param.requires_grad = False
        for child in list(self.resnet.children())[5:]:
            for param in child.parameters():
                param.requires_grad = fine_tune

    def forward(self, images):
        features = self.resnet(images)  # (batch_size, 2048, H/32, W/32)
        features = features.permute(0, 2, 3, 1)  # (batch_size, H/32, W/32, 2048)
        return features


class Attention(nn.Module):
    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super(Attention, self).__init__()
        # linear layers to transform the encoder's and decoder's outputs
        self.encoder_attn = nn.Linear(encoder_dim, attention_dim)
        self.decoder_attn = nn.Linear(decoder_dim, attention_dim)
        self.full_attn = nn.Linear(attention_dim, 1)

    def forward(self, encoder_out, decoder_hidden):
        attn1 = self.encoder_attn(encoder_out)  # (batch_size, num_pixels, attention_dim)
        attn2 = self.decoder_attn(decoder_hidden)  # (batch_size, attention_dim)
        attn = self.full_attn(F.relu(attn1 + attn2.unsqueeze(1)))  # (batch_size, num_pixels, 1)
        # softmax over pixels gives the weights for the attention-weighted encoding
        alpha = F.softmax(attn, dim=1)  # (batch_size, num_pixels, 1)
        attn_weighted_encoding = (encoder_out * alpha).sum(dim=1)  # (batch_size, encoder_dim)
        alpha = alpha.squeeze(2)  # (batch_size, num_pixels)
        return attn_weighted_encoding, alpha


class DecoderRNN(nn.Module):
    def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, device, encoder_dim=2048, dropout=0.5):
        super(DecoderRNN, self).__init__()
        self.attention_dim = attention_dim
        self.embed_dim = embed_dim
        self.decoder_dim = decoder_dim  # feature size of decoder's RNN
        self.vocab_size = vocab_size
        self.device = device
        self.encoder_dim = encoder_dim  # feature size of encoded images
        self.attention = Attention(encoder_dim, decoder_dim, attention_dim)
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.dropout = nn.Dropout(p=dropout)
        self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, bias=True)
        self.init_h = nn.Linear(encoder_dim, decoder_dim)  # initial hidden state from image features
        self.init_c = nn.Linear(encoder_dim, decoder_dim)  # initial cell state from image features
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)  # gate over the attention-weighted encoding
        self.fc = nn.Linear(decoder_dim, vocab_size)
        self.init_weights()

    def init_weights(self):
        # initialize layers with a uniform distribution for easier convergence
        self.embedding.weight.data.uniform_(-0.1, 0.1)
        self.fc.bias.data.fill_(0)
        self.fc.weight.data.uniform_(-0.1, 0.1)

    def init_hidden_state(self, encoder_out):
        # initialize the LSTM state from the mean of the encoded image
        mean_encoder_out = encoder_out.mean(dim=1)
        h = self.init_h(mean_encoder_out)
        c = self.init_c(mean_encoder_out)
        return h, c

    def forward(self, encoder_out, encoded_captions, caption_lens):
        batch_size = encoder_out.size(0)
        # flatten the encoded image to (batch_size, num_pixels, encoder_dim)
        encoder_out = encoder_out.view(batch_size, -1, self.encoder_dim)
        num_pixels = encoder_out.size(1)
        # sort inputs by decreasing caption length so <pad> tokens are never processed
        caption_lens, sort_idx = caption_lens.sort(dim=0, descending=True)
        encoder_out = encoder_out[sort_idx]
        encoded_captions = encoded_captions[sort_idx]
        embeddings = self.embedding(encoded_captions)
        h, c = self.init_hidden_state(encoder_out)
        # decode lengths exclude the <end> token
        decode_lens = (caption_lens - 1).tolist()
        predictions = torch.zeros(batch_size, max(decode_lens), self.vocab_size).to(self.device)
        alphas = torch.zeros(batch_size, max(decode_lens), num_pixels).to(self.device)
        for t in range(max(decode_lens)):
            # only captions longer than t are still being decoded at this step
            batch_size_t = sum([l > t for l in decode_lens])
            # attention-weighted encoding of the image features
            attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t], h[:batch_size_t])
            # sigmoid gating scalar
            gate = torch.sigmoid(self.f_beta(h[:batch_size_t]))
            attention_weighted_encoding = gate * attention_weighted_encoding
            h, c = self.decode_step(
                torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1),
                (h[:batch_size_t], c[:batch_size_t]))
            # next-word prediction
            preds = self.fc(self.dropout(h))
            # save the predictions and alphas for every time step
            predictions[:batch_size_t, t, :] = preds
            alphas[:batch_size_t, t, :] = alpha
        return predictions, encoded_captions, decode_lens, alphas, sort_idx
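

# A minimal smoke-test sketch, not part of the original module: it wires the
# encoder and decoder together on random tensors to illustrate the expected
# shapes. The dimensions, vocabulary size, and caption lengths below are
# illustrative assumptions, not values prescribed by this file.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    encoder = EncoderCNN().to(device)
    decoder = DecoderRNN(attention_dim=512, embed_dim=512, decoder_dim=512,
                         vocab_size=1000, device=device).to(device)

    images = torch.randn(4, 3, 224, 224).to(device)            # batch of 4 RGB images
    captions = torch.randint(0, 1000, (4, 20)).to(device)      # random padded token ids
    caption_lens = torch.tensor([20, 18, 15, 12]).to(device)   # lengths including <end>

    features = encoder(images)                                 # (4, 7, 7, 2048)
    predictions, encoded_captions, decode_lens, alphas, sort_idx = decoder(
        features, captions, caption_lens)
    print(predictions.shape)  # torch.Size([4, 19, 1000])
    print(alphas.shape)       # torch.Size([4, 19, 49])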