sgoodfriend's picture
PPO playing MountainCarContinuous-v0 from https://github.com/sgoodfriend/rl-algo-impls/tree/fbc943f151b95afc4905a67a3835fb6b18c6a5e4
54227bc
raw
history blame
884 Bytes
import gym
import torch as th
import torch.nn as nn
from gym.spaces import Discrete
from typing import Sequence, Type
from shared.module import FeatureExtractor, mlp
class QNetwork(nn.Module):
    """Network mapping observations to one output per discrete action.

    Observations are passed through a ``FeatureExtractor`` and the resulting
    features through the module returned by ``mlp``. The layer sizes requested
    from ``mlp`` are ``(feature_extractor.out_dim, *hidden_sizes,
    action_space.n)``, so the final requested width equals the number of
    discrete actions.
    """

    def __init__(
        self,
        observation_space: gym.Space,
        action_space: gym.Space,
        hidden_sizes: Sequence[int] = (),
        activation: Type[nn.Module] = nn.ReLU,  # Used by stable-baselines3
    ) -> None:
        """Build the feature extractor and the fully connected head.

        Args:
            observation_space: Space handed to ``FeatureExtractor``.
            action_space: Must be a ``Discrete`` space; its ``n`` is used as
                the last layer size.
            hidden_sizes: Widths of the layers between the extracted features
                and the output layer. Empty (the default) requests no hidden
                layers.
            activation: Module class forwarded to both ``FeatureExtractor``
                and ``mlp``.

        Raises:
            AssertionError: If ``action_space`` is not ``Discrete``.
        """
        super().__init__()
        # Kept as an assert so callers observe the same exception type as
        # before. NOTE(review): asserts are stripped under ``python -O``, in
        # which case a non-Discrete space fails later on ``action_space.n``.
        assert isinstance(
            action_space, Discrete
        ), f"QNetwork requires a Discrete action space, got {type(action_space).__name__}"
        self._feature_extractor = FeatureExtractor(observation_space, activation)
        # The default for hidden_sizes is an immutable tuple rather than a
        # shared list; any sequence is normalised to a tuple here anyway.
        layer_sizes = (
            (self._feature_extractor.out_dim,)
            + tuple(hidden_sizes)
            + (action_space.n,)
        )
        self._fc = mlp(layer_sizes, activation)

    def forward(self, obs: th.Tensor) -> th.Tensor:
        """Return the head's output for ``obs``.

        Args:
            obs: Observation tensor accepted by the feature extractor.

        Returns:
            The result of applying the fully connected head to the extracted
            features (presumably one value per action along the last
            dimension; depends on ``mlp``).
        """
        features = self._feature_extractor(obs)
        return self._fc(features)