204 changes: 71 additions & 133 deletions vit_pytorch/modules.py
@@ -10,7 +10,7 @@ class PatchEmbedding(nn.Module):
Args:
image_size (int):
Input image size.

patch_size (int):
Patch size, input image will be split into (image_size // patch_size) ^ 2 patches.

@@ -21,13 +21,13 @@ class PatchEmbedding(nn.Module):
The embedding dimension.

"""
def __init__(self,
image_size=224,
patch_size=16,
in_channels=3,
def __init__(self,
image_size=224,
patch_size=16,
in_channels=3,
embed_dim=768):

super(PatchEmbedding, self).__init__()
super().__init__()
self.num_patches = (image_size // patch_size) ** 2
self.proj = nn.Conv2d(in_channels, embed_dim, patch_size, patch_size)
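
A minimal sketch (not part of the diff) of what this patch projection does to an input batch, using the assumed default sizes; the flatten/transpose step is the conventional follow-up and sits outside the visible hunk:

    import torch
    import torch.nn as nn

    # Same projection as self.proj above, with the default sizes assumed.
    proj = nn.Conv2d(3, 768, kernel_size=16, stride=16)
    images = torch.randn(2, 3, 224, 224)
    features = proj(images)                       # (2, 768, 14, 14): one 768-d vector per 16x16 patch
    tokens = features.flatten(2).transpose(1, 2)  # (2, 196, 768), i.e. (batch, num_patches, embed_dim)
    print(tokens.shape)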

@@ -41,59 +41,36 @@ class PositionalEmbedding(nn.Module):

"""
def __init__(self, seq_len, embed_dim):
super(PositionalEmbedding, self).__init__()
super().__init__()
self.embedding = nn.Parameter(torch.zeros(1, seq_len, embed_dim))

def forward(self, x):
return x + self.embedding


class GELU(nn.Module):
r"""
Implementation of Gaussian Error Linear Units activation function.
See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`

"""
def __init__(self):
super(GELU, self).__init__()

def forward(self, x):
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) \
* (x + 0.044715 * torch.pow(x, 3))))
return x + self.embedding


class MLPBlock(nn.Module):
def MLPBlock(input_dim, hidden_dim, dropout_rate=0.1):
r"""
Implementation of the MLP / feed-forward block.

Args:
input_dim (int):
Input dimension, same as the output dimension.

hidden_dim (int):
Dimension of the hidden fully-connected layer.

dropout_rate (float):
Probability of an element to be zeroed. Default: 0.1.

"""
def __init__(self,
input_dim,
hidden_dim,
dropout_rate=0.1):

super(MLPBlock, self).__init__()

self.block = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
GELU(),
nn.Dropout(dropout_rate),
nn.Linear(hidden_dim, input_dim),
nn.Dropout(dropout_rate)
)

def forward(self, x):
return self.block(x)
return nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout_rate),
nn.Linear(hidden_dim, input_dim),
nn.Dropout(dropout_rate)
)
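
The custom GELU class is dropped in favour of nn.GELU, and MLPBlock becomes a plain builder that returns an nn.Sequential. Note that nn.GELU defaults to the exact erf form, while the deleted class used the tanh approximation; a quick standalone check (illustrative only) shows the two stay very close:

    import math
    import torch
    import torch.nn as nn

    x = torch.linspace(-3, 3, 7)
    exact = nn.GELU()(x)  # erf-based GELU, now used inside MLPBlock
    approx = 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x ** 3)))
    print((exact - approx).abs().max())  # well below 1e-3, so the swap is numerically benign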


class Attention(nn.Module):
@@ -104,7 +81,7 @@ class Attention(nn.Module):
Args:
embed_dim (int):
Embedding dimension.

num_heads (int):
Number of attention heads.

@@ -115,13 +92,13 @@ class Attention(nn.Module):
Dropout rate of the output projection.

"""
def __init__(self,
embed_dim,
num_heads,
atten_drop=0.0,
def __init__(self,
embed_dim,
num_heads,
atten_drop=0.0,
proj_drop=0.0):

super(Attention, self).__init__()
super().__init__()

self.embed_dim = embed_dim
self.num_heads = num_heads
@@ -133,8 +110,8 @@ def __init__(self,

def forward(self, x):
b, l, c = x.size()
assert c == self.embed_dim
assert c == self.embed_dim

qkv = self.qkv_proj(x)
qkv = qkv.reshape(b, l, 3, self.num_heads, c // self.num_heads)
qkv = qkv.permute(2, 0, 3, 1, 4)
@@ -151,7 +128,16 @@ def forward(self, x):
return x
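
A shape walk-through of the qkv handling above (hypothetical sizes); the scaling, softmax and output projection are collapsed in this hunk, so the middle lines below simply assume the standard formulation:

    import torch

    b, l, c, heads = 2, 197, 768, 12                   # assumed batch, tokens, embed_dim, num_heads
    qkv = torch.randn(b, l, 3 * c)                     # what self.qkv_proj(x) would produce
    qkv = qkv.reshape(b, l, 3, heads, c // heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]                   # each (b, heads, l, c // heads)
    # Assumed standard attention body (the real lines sit outside the visible hunk):
    attn = (q @ k.transpose(-2, -1)) * (c // heads) ** -0.5   # (b, heads, l, l)
    out = (attn.softmax(dim=-1) @ v).transpose(1, 2).reshape(b, l, c)
    print(out.shape)                                   # torch.Size([2, 197, 768])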


class EncoderBlock(nn.Module):
class Residual(nn.Sequential):
def forward(self, x):
return x + super().forward(x)


def EncoderBlock(embed_dim,
num_heads,
hidden_dim,
atten_drop=0.,
proj_drop=0.):
r"""
Implementation of the transformer encoder block.

@@ -175,82 +161,34 @@ class EncoderBlock(nn.Module):
Dropout rate in `MLPBlock` module.

"""
def __init__(self,
embed_dim,
num_heads,
hidden_dim,
atten_drop=0.,
proj_drop=0.):

super(EncoderBlock, self).__init__()

self.norm1 = nn.LayerNorm(embed_dim)
self.atten = Attention(embed_dim, num_heads, atten_drop, proj_drop)
self.norm2 = nn.LayerNorm(embed_dim)
self.mlp = MLPBlock(embed_dim, hidden_dim, proj_drop)

def forward(self, x):
x = x + self.atten(self.norm1(x))
x = x + self.mlp(self.norm2(x))
return x
return nn.Sequential(
Residual(nn.LayerNorm(embed_dim), Attention(embed_dim, num_heads, atten_drop, proj_drop)),
Residual(nn.LayerNorm(embed_dim), MLPBlock(embed_dim, hidden_dim, proj_drop))
)
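
The new Residual helper replaces the explicit `x = x + self.atten(self.norm1(x))` pattern: it runs its child modules in order and adds the input back. A self-contained equivalence check with toy modules (not the ones from this file):

    import torch
    import torch.nn as nn

    class Residual(nn.Sequential):
        # Same helper as above: apply the wrapped layers, then add the input back.
        def forward(self, x):
            return x + super().forward(x)

    dim = 8
    block = Residual(nn.LayerNorm(dim), nn.Linear(dim, dim))
    x = torch.randn(2, 3, dim)
    reference = x + block[1](block[0](x))  # the pre-norm residual written out by hand
    assert torch.allclose(block(x), reference)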


class Transformer(nn.Module):
def Transformer(num_layers,
embed_dim,
num_heads,
hidden_dim,
seq_length,
atten_drop=0.,
proj_drop=0.):
r"""
Implementation of transformer encoder for feature extraction.
See `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>.`
See `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>.`

"""
def __init__(self,
num_layers,
embed_dim,
num_heads,
hidden_dim,
seq_length,
atten_drop=0.,
proj_drop=0.):

super(Transformer, self).__init__()

self.pos_embedding = PositionalEmbedding(seq_length, embed_dim)
self.dropout = nn.Dropout(proj_drop)
self.norm = nn.LayerNorm(embed_dim)

self.blocks = nn.ModuleList(
[EncoderBlock(embed_dim,
num_heads,
hidden_dim,
atten_drop,
proj_drop) for _ in range(num_layers)])


def forward(self, x):
x = self.pos_embedding(x)
x = self.dropout(x)

for block in self.blocks:
x = block(x)

return self.norm(x)


class PreLogitsLayer(nn.Module):
r"""
Implementation of pre-logits layer.
encoders = [EncoderBlock(embed_dim, num_heads, hidden_dim, atten_drop, proj_drop)
for _ in range(num_layers)]

"""
def __init__(self, embed_dim, repr_dim=None):
super(PreLogitsLayer, self).__init__()
if repr_dim is not None:
self.proj = nn.Sequential(
nn.Linear(embed_dim, repr_dim),
nn.Tanh()
)
else:
self.proj = nn.Identity()

def forward(self, x):
return self.proj(x)
return nn.Sequential(
PositionalEmbedding(seq_length, embed_dim),
nn.Dropout(proj_drop),
nn.Sequential(*encoders),
nn.LayerNorm(embed_dim))
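
With the class replaced by a builder function, the encoder is now a plain nn.Sequential and can be constructed and called directly. A usage sketch, assuming the patched vit_pytorch/modules.py is importable:

    import torch
    from vit_pytorch.modules import Transformer  # assumes the patched module is on the import path

    encoder = Transformer(num_layers=2, embed_dim=64, num_heads=4,
                          hidden_dim=128, seq_length=17)
    tokens = torch.randn(2, 17, 64)  # (batch, seq_length, embed_dim)
    print(encoder(tokens).shape)     # torch.Size([2, 17, 64])

One side effect worth noting: the blocks now live in nested Sequential containers with numeric names rather than attributes such as `blocks.0.atten`, so state_dict keys from checkpoints saved with the old classes will not match the new layout without remapping.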


class ViT(nn.Module):
@@ -265,13 +203,13 @@ def __init__(self,
hidden_dim=3072,
atten_drop=0.,
proj_drop=0.1,
repr_dim=None,
**kwargs):
repr_dim=None):

super(ViT, self).__init__()
super().__init__()

self.num_classes = num_classes
self.embed_dim = embed_dim
self.repr_dim = repr_dim

seq_length = (image_size // patch_size) ** 2 + 1

@@ -285,26 +223,26 @@ def __init__(self,
hidden_dim,
seq_length,
atten_drop,
proj_drop,)
proj_drop)

self.pre_logits = PreLogitsLayer(embed_dim, repr_dim)

repr_dim = repr_dim if repr_dim else embed_dim
self.head = nn.Linear(repr_dim, num_classes)
if repr_dim is not None:
self.head = nn.Sequential(
nn.Linear(embed_dim, repr_dim),
nn.Tanh(),
nn.Linear(repr_dim, num_classes)
)
else:
self.head = nn.Linear(embed_dim, num_classes)

def forward(self, x):
"""
Defines the forward operation.
"""
b, _, _, _ = x.size()
def forward(self, x, *, cls_only=True):
x = self.patch_embedding(x)
x = torch.cat([self.cls_token.expand(b, -1, -1), x], dim=1)
b, n, e = x.size()
x = torch.cat([self.cls_token.expand(b, 1, e), x], dim=1)

x = self.transformer(x)
x = self.pre_logits(x)

# only support cls token now
x = x[:, 0]
if cls_only:
x = x[:, 0]

return self.head(x)
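
End to end the model is used as before, and the new keyword-only `cls_only` flag controls whether the head sees only the class token or every token. A usage sketch; the constructor arguments other than num_classes are collapsed in this diff, so the call below leans on the remaining defaults (assumed to be the ViT-Base values suggested by hidden_dim=3072) and assumes the patched module is importable:

    import torch
    from vit_pytorch.modules import ViT  # assumes the patched module is on the import path

    model = ViT(num_classes=10)                  # other constructor defaults taken as-is
    images = torch.randn(2, 3, 224, 224)
    print(model(images).shape)                   # torch.Size([2, 10]): head applied to the cls token
    print(model(images, cls_only=False).shape)   # torch.Size([2, 197, 10]) with the assumed 224/16 defaults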
