PPO playing QbertNoFrameskip-v4 from https://github.com/sgoodfriend/rl-algo-impls/tree/0511de345b17175b7cf1ea706c3e05981f11761c
460072a
from typing import Optional, Tuple, Type | |
import torch | |
import torch.nn as nn | |
from torch.distributions import Categorical | |
from rl_algo_impls.shared.actor import Actor, PiForward | |
from rl_algo_impls.shared.module.module import mlp | |
class MaskedCategorical(Categorical): | |
def __init__( | |
self, | |
probs=None, | |
logits=None, | |
validate_args=None, | |
mask: Optional[torch.Tensor] = None, | |
): | |
if mask is not None: | |
assert logits is not None, "mask requires logits and not probs" | |
logits = torch.where(mask, logits, -1e8) | |
self.mask = mask | |
super().__init__(probs, logits, validate_args) | |
def entropy(self) -> torch.Tensor: | |
if self.mask is None: | |
return super().entropy() | |
# If mask set, then use approximation for entropy | |
p_log_p = self.logits * self.probs # type: ignore | |
masked = torch.where(self.mask, p_log_p, 0) | |
return -masked.sum(-1) | |
class CategoricalActorHead(Actor): | |
def __init__( | |
self, | |
act_dim: int, | |
in_dim: int, | |
hidden_sizes: Tuple[int, ...] = (32,), | |
activation: Type[nn.Module] = nn.Tanh, | |
init_layers_orthogonal: bool = True, | |
) -> None: | |
super().__init__() | |
layer_sizes = (in_dim,) + hidden_sizes + (act_dim,) | |
self._fc = mlp( | |
layer_sizes, | |
activation, | |
init_layers_orthogonal=init_layers_orthogonal, | |
final_layer_gain=0.01, | |
) | |
def forward( | |
self, | |
obs: torch.Tensor, | |
actions: Optional[torch.Tensor] = None, | |
action_masks: Optional[torch.Tensor] = None, | |
) -> PiForward: | |
logits = self._fc(obs) | |
pi = MaskedCategorical(logits=logits, mask=action_masks) | |
return self.pi_forward(pi, actions) | |
def action_shape(self) -> Tuple[int, ...]: | |
return () | |