diff --git a/vit_pytorch/modules.py b/vit_pytorch/modules.py
index 9e4eaa9..98314e1 100644
--- a/vit_pytorch/modules.py
+++ b/vit_pytorch/modules.py
@@ -10,7 +10,7 @@ class PatchEmbedding(nn.Module):
     Args:
         image_size (int):
             Input image size.
-        
+
         patch_size (int):
             Patch size, input image will be split into
             (image_size // patch_size) ^ 2 patches.
@@ -21,13 +21,13 @@ class PatchEmbedding(nn.Module):
             The embedding dimension.
 
     """
-    def __init__(self, 
-                 image_size=224, 
-                 patch_size=16, 
-                 in_channels=3, 
+    def __init__(self,
+                 image_size=224,
+                 patch_size=16,
+                 in_channels=3,
                  embed_dim=768):
 
-        super(PatchEmbedding, self).__init__()
+        super().__init__()
 
         self.num_patches = (image_size // patch_size) ** 2
         self.proj = nn.Conv2d(in_channels, embed_dim, patch_size, patch_size)
@@ -41,35 +41,21 @@ class PositionalEmbedding(nn.Module):
 
     """
    def __init__(self, seq_len, embed_dim):
-        super(PositionalEmbedding, self).__init__()
+        super().__init__()
         self.embedding = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
-
-    def forward(self, x):
-        return x + self.embedding
-
-
-class GELU(nn.Module):
-    r"""
-    Implementation of Gaussian Error Linerar Units activation function.
-    See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`
-
-    """
-    def __init__(self):
-        super(GELU, self).__init__()
 
     def forward(self, x):
-        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) \
-            * (x + 0.044715 * torch.pow(x, 3))))
+        return x + self.embedding
 
 
-class MLPBlock(nn.Module):
+def MLPBlock(input_dim, hidden_dim, dropout_rate=0.1):
     r"""
     Implementation of the MLP / feed-forward block.
 
     Args:
         input_dim (int):
             Input dimension, same as the output dimension.
-        
+
         hidden_dim (int):
             Dimension of the hidden fully-connected layer.
 
@@ -77,23 +63,14 @@ class MLPBlock(nn.Module):
             Probability of an element to be zeroed. Default: 0.5.
 
     """
-    def __init__(self,
-                 input_dim,
-                 hidden_dim,
-                 dropout_rate=0.1):
-
-        super(MLPBlock, self).__init__()
-
-        self.block = nn.Sequential(
-            nn.Linear(input_dim, hidden_dim),
-            GELU(),
-            nn.Dropout(dropout_rate),
-            nn.Linear(hidden_dim, input_dim),
-            nn.Dropout(dropout_rate)
-        )
-    def forward(self, x):
-        return self.block(x)
+    return nn.Sequential(
+        nn.Linear(input_dim, hidden_dim),
+        nn.GELU(),
+        nn.Dropout(dropout_rate),
+        nn.Linear(hidden_dim, input_dim),
+        nn.Dropout(dropout_rate)
+    )
 
 
 class Attention(nn.Module):
@@ -104,7 +81,7 @@ class Attention(nn.Module):
 
     Args:
         embed_dim (int):
             Embedding dimension.
-        
+
         num_heads (int):
             Number of attention heads.
@@ -115,13 +92,13 @@ class Attention(nn.Module):
             Dropout rate of the output projection.
 
     """
-    def __init__(self, 
-                 embed_dim, 
-                 num_heads, 
-                 atten_drop=0.0, 
+    def __init__(self,
+                 embed_dim,
+                 num_heads,
+                 atten_drop=0.0,
                  proj_drop=0.0):
 
-        super(Attention, self).__init__()
+        super().__init__()
 
         self.embed_dim = embed_dim
         self.num_heads = num_heads
@@ -133,8 +110,8 @@ def __init__(self,
 
     def forward(self, x):
         b, l, c = x.size()
-        assert c == self.embed_dim 
-        
+        assert c == self.embed_dim
+
         qkv = self.qkv_proj(x)
         qkv = qkv.reshape(b, l, 3, self.num_heads, c // self.num_heads)
         qkv = qkv.permute(2, 0, 3, 1, 4)
@@ -151,7 +128,16 @@ def forward(self, x):
 
         return x
 
 
-class EncoderBlock(nn.Module):
+class Residual(nn.Sequential):
+    def forward(self, x):
+        return x + super().forward(x)
+
+
+def EncoderBlock(embed_dim,
+                 num_heads,
+                 hidden_dim,
+                 atten_drop=0.,
+                 proj_drop=0.):
     r"""
     Implementation of the transformer encoder block.
 
@@ -175,82 +161,34 @@ class EncoderBlock(nn.Module):
             Dropout rate in `MLPBlock` module.
 
     """
-    def __init__(self,
-                 embed_dim,
-                 num_heads,
-                 hidden_dim,
-                 atten_drop=0.,
-                 proj_drop=0.):
-
-        super(EncoderBlock, self).__init__()
-
-        self.norm1 = nn.LayerNorm(embed_dim)
-        self.atten = Attention(embed_dim, num_heads, atten_drop, proj_drop)
-        self.norm2 = nn.LayerNorm(embed_dim)
-        self.mlp = MLPBlock(embed_dim, hidden_dim, proj_drop)
 
-    def forward(self, x):
-        x = x + self.atten(self.norm1(x))
-        x = x + self.mlp(self.norm2(x))
-        return x
+    return nn.Sequential(
+        Residual(nn.LayerNorm(embed_dim), Attention(embed_dim, num_heads, atten_drop, proj_drop)),
+        Residual(nn.LayerNorm(embed_dim), MLPBlock(embed_dim, hidden_dim, proj_drop))
+    )
 
 
-class Transformer(nn.Module):
+def Transformer(num_layers,
+                embed_dim,
+                num_heads,
+                hidden_dim,
+                seq_length,
+                atten_drop=0.,
+                proj_drop=0.):
     r"""
     Implementation of transformer encoder for feature extraction.
-    See `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>.` 
+    See `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>.`
 
     """
-    def __init__(self,
-                 num_layers,
-                 embed_dim,
-                 num_heads,
-                 hidden_dim,
-                 seq_length,
-                 atten_drop=0.,
-                 proj_drop=0.):
-
-        super(Transformer, self).__init__()
-
-        self.pos_embedding = PositionalEmbedding(seq_length, embed_dim)
-        self.dropout = nn.Dropout(proj_drop)
-        self.norm = nn.LayerNorm(embed_dim)
-
-        self.blocks = nn.ModuleList(
-            [EncoderBlock(embed_dim,
-                          num_heads,
-                          hidden_dim,
-                          atten_drop,
-                          proj_drop) for _ in range(num_layers)])
-
-
-    def forward(self, x):
-        x = self.pos_embedding(x)
-        x = self.dropout(x)
-
-        for block in self.blocks:
-            x = block(x)
-
-        return self.norm(x)
-
-class PreLogitsLayer(nn.Module):
-    r"""
-    Implementation of pre-logits layer.
+    encoders = [EncoderBlock(embed_dim, num_heads, hidden_dim, atten_drop, proj_drop)
+                for _ in range(num_layers)]
 
-    """
-    def __init__(self, embed_dim, repr_dim=None):
-        super(PreLogitsLayer, self).__init__()
-        if repr_dim is not None:
-            self.proj = nn.Sequential(
-                nn.Linear(embed_dim, repr_dim),
-                nn.Tanh()
-            )
-        else:
-            self.proj = nn.Identity()
-
-    def forward(self, x):
-        return self.proj(x)
+    return nn.Sequential(
+        PositionalEmbedding(seq_length, embed_dim),
+        nn.Dropout(proj_drop),
+        nn.Sequential(*encoders),
+        nn.LayerNorm(embed_dim))
 
 
 class ViT(nn.Module):
@@ -265,13 +203,13 @@ def __init__(self,
                  hidden_dim=3072,
                  atten_drop=0.,
                  proj_drop=0.1,
-                 repr_dim=None,
-                 **kwargs):
+                 repr_dim=None):
 
-        super(ViT, self).__init__()
+        super().__init__()
 
         self.num_classes = num_classes
         self.embed_dim = embed_dim
+        self.repr_dim = repr_dim
 
         seq_length = (image_size // patch_size) ** 2 + 1
@@ -285,26 +223,26 @@ def __init__(self,
                                        hidden_dim,
                                        seq_length,
                                        atten_drop,
-                                       proj_drop,)
+                                       proj_drop)
 
-        self.pre_logits = PreLogitsLayer(embed_dim, repr_dim)
-
-        repr_dim = repr_dim if repr_dim else embed_dim
-        self.head = nn.Linear(repr_dim, num_classes)
+        if repr_dim is not None:
+            self.head = nn.Sequential(
+                nn.Linear(embed_dim, repr_dim),
+                nn.Tanh(),
+                nn.Linear(repr_dim, num_classes)
+            )
+        else:
+            self.head = nn.Linear(embed_dim, num_classes)
 
-    def forward(self, x):
-        """
-        Defined the forward operation.
-        """
-        b, _, _, _ = x.size()
+    def forward(self, x, *, cls_only=True):
         x = self.patch_embedding(x)
-        x = torch.cat([self.cls_token.expand(b, -1, -1), x], dim=1)
+        b, n, e = x.size()
+        x = torch.cat([self.cls_token.expand(b, 1, e), x], dim=1)
         x = self.transformer(x)
-        x = self.pre_logits(x)
-        # only support cls tocken now
-        x = x[:, 0]
+        if cls_only:
+            x = x[:, 0]
 
         return self.head(x)