import torch | |
import torch.nn as nn | |
class PerceptionAgent(nn.Module):
    """Small CNN encoder mapping an image batch to a feature vector.

    Architecture:
        cnn_layers: Conv2d(in_channels -> 16, 3x3, stride 1, padding 1) -> ReLU
                    -> MaxPool2d(2x2, stride 2)
        fc_layers:  Linear(16 * (H // 2) * (W // 2) -> 256) -> ReLU
                    -> Linear(256 -> perception_output_size)

    The 3x3 conv with padding 1 keeps the spatial size, and the 2x2/stride-2
    pool floors each spatial dimension to half. The first Linear's input size
    is therefore derived from the configured input resolution instead of
    being hard-coded.

    Config keys:
        perception_output_size (required): size of the output feature vector.
        perception_in_channels (optional, default 3): input channel count.
        perception_input_height (optional, default 64): input height H.
        perception_input_width (optional, default 64): input width W.

    With the defaults, the layers are identical to the original
    ``Linear(16 * 32 * 32, 256)`` layout, so existing checkpoints still load.
    """

    def __init__(self, config) -> None:
        super().__init__()

        def _opt(key, default):
            # Only subscripting is relied upon (as for the required key),
            # so any mapping-like config that raises KeyError works.
            try:
                return config[key]
            except KeyError:
                return default

        in_channels = int(_opt("perception_in_channels", 3))
        height = int(_opt("perception_input_height", 64))
        width = int(_opt("perception_input_width", 64))
        if height < 2 or width < 2:
            # A 2x2 / stride-2 pool on a dimension < 2 yields zero features.
            raise ValueError(
                f"perception input must be at least 2x2, got {height}x{width}"
            )

        conv_channels = 16
        self.cnn_layers = nn.Sequential(
            nn.Conv2d(in_channels, conv_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # NOTE: if further conv/pool stages are added here, the
            # flattened-size computation below must be updated to match.
        )

        # Conv (padding=1) preserves H x W; the pool floors each dim by 2.
        flat_features = conv_channels * (height // 2) * (width // 2)
        self.fc_layers = nn.Sequential(
            nn.Linear(flat_features, 256),
            nn.ReLU(),
            nn.Linear(256, config["perception_output_size"]),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` of shape (N, C, H, W) into (N, perception_output_size).

        H and W must match the configured input resolution (64x64 by
        default). Otherwise, the first Linear layer raises a shape error.
        """
        x = self.cnn_layers(x)
        # flatten (unlike view) also handles non-contiguous activations,
        # e.g. channels_last; the element order is the same NCHW order.
        x = torch.flatten(x, 1)
        return self.fc_layers(x)