-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
37 lines (32 loc) · 1.21 KB
/
model.py
File metadata and controls
37 lines (32 loc) · 1.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch.nn as nn
import torchvision.models as models
import math
import torch
class Model(nn.Module):
    """Action classifier that fuses an image embedding with an action-history embedding.

    A torchvision ResNet-18 encodes the image and a two-layer MLP encodes the
    flattened action history. The two embeddings are concatenated and decoded
    by a small MLP into one raw (un-normalized) score per action class.

    Args:
        num_action_classes: Number of output action classes.
        action_history_shape: Shape of a single action-history sample,
            excluding the batch dimension. The history is flattened, so any
            rank is accepted; the encoder's input width is the product of
            these dimensions.
        hidden_dim: Width of both embeddings and of the decoder's hidden
            layer. Defaults to 512.
        pretrained: If True, initialise the ResNet-18 backbone with
            torchvision's ``ResNet18_Weights.IMAGENET1K_V1`` (needs a
            torchvision version that provides the ``weights`` API and may
            download the checkpoint). Defaults to False, which gives a
            randomly initialised backbone.
    """

    def __init__(
        self,
        num_action_classes: int,
        action_history_shape: tuple[int, ...],
        *,
        hidden_dim: int = 512,
        pretrained: bool = False,
    ) -> None:
        super().__init__()

        # Backbone is randomly initialised unless ImageNet weights are
        # explicitly requested via ``pretrained=True``.
        if pretrained:
            self.image_encoder = models.resnet18(
                weights=models.ResNet18_Weights.IMAGENET1K_V1
            )
        else:
            self.image_encoder = models.resnet18()
        # Replace the 1000-way ImageNet classifier head with a projection to
        # ``hidden_dim`` features so it can be fused with the history embedding.
        self.image_encoder.fc = nn.Linear(
            self.image_encoder.fc.in_features, hidden_dim
        )

        self.action_history_encoder = nn.Sequential(
            # Flattens every dim after the batch dim:
            # (batch, *action_history_shape) -> (batch, prod(action_history_shape)).
            nn.Flatten(),
            nn.Linear(math.prod(action_history_shape), hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )

        # Input is the concatenation of the image and history embeddings,
        # each ``hidden_dim`` wide.
        self.action_decoder = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_action_classes),
        )

    def forward(
        self, image: torch.Tensor, action_history: torch.Tensor
    ) -> torch.Tensor:
        """Score every action class for a batch of images and action histories.

        Args:
            image: Batch of images fed to the ResNet-18 backbone; presumably
                shaped (batch, 3, H, W) as torchvision's ResNet expects —
                TODO confirm against the data pipeline.
            action_history: Tensor shaped (batch, *action_history_shape).

        Returns:
            Tensor of shape (batch, num_action_classes) with raw scores; no
            softmax is applied here.
        """
        image_x = self.image_encoder(image)
        action_history_x = self.action_history_encoder(action_history)
        # Both embeddings are (batch, hidden_dim); join along the feature axis.
        x = torch.cat([image_x, action_history_x], dim=1)
        return self.action_decoder(x)