sgoodfriend's picture
PPO playing QbertNoFrameskip-v4 from https://github.com/sgoodfriend/rl-algo-impls/tree/0511de345b17175b7cf1ea706c3e05981f11761c
460072a
raw
history blame
3.44 kB
from typing import Dict, Optional, Tuple, Type
import numpy as np
import torch
import torch.nn as nn
from numpy.typing import NDArray
from torch.distributions import Distribution, constraints
from rl_algo_impls.shared.actor.actor import Actor, PiForward
from rl_algo_impls.shared.actor.categorical import MaskedCategorical
from rl_algo_impls.shared.encoder import EncoderOutDim
from rl_algo_impls.shared.module.module import mlp
class MultiCategorical(Distribution):
    """Joint distribution over several independent categorical sub-actions.

    The last dimension of ``probs``/``logits`` (and of ``masks``) is the
    concatenation of one block of width ``nvec[i]`` per sub-action. Each block
    becomes its own ``MaskedCategorical``. Log-probabilities and entropies are
    summed across sub-actions; samples and modes are stacked along a new last
    dimension (one entry per sub-action).
    """

    def __init__(
        self,
        nvec: NDArray[np.int64],
        probs=None,
        logits=None,
        validate_args=None,
        masks: Optional[torch.Tensor] = None,
    ):
        """Build one ``MaskedCategorical`` per sub-action.

        Args:
            nvec: number of categories for each sub-action. Any integer
                sequence is accepted.
            probs: concatenated per-sub-action probabilities. Exactly one of
                ``probs`` / ``logits`` must be given.
            logits: concatenated per-sub-action logits.
            validate_args: forwarded to every child distribution and to the
                ``Distribution`` base class.
            masks: optional mask laid out like ``probs``/``logits`` along the
                last dimension; it is split the same way and each slice is
                passed as ``mask=`` to the matching child.

        Raises:
            AssertionError: if both or neither of ``probs``/``logits`` are set.
        """
        # Either probs or logits should be set
        assert (probs is None) != (logits is None)
        split_sizes = [int(n) for n in nvec]
        # Split along the last (event) dim, matching how batch_shape is
        # derived below; for 2-D inputs this is the same as dim=1.
        masks_split = (
            torch.split(masks, split_sizes, dim=-1)
            if masks is not None
            else [None] * len(split_sizes)
        )
        # Explicit None check: truth-testing a multi-element tensor raises
        # "Boolean value of Tensor with more than one value is ambiguous".
        if probs is not None:
            self.dists = [
                MaskedCategorical(probs=p, validate_args=validate_args, mask=m)
                for p, m in zip(torch.split(probs, split_sizes, dim=-1), masks_split)
            ]
            param = probs
        else:
            self.dists = [
                MaskedCategorical(logits=lg, validate_args=validate_args, mask=m)
                for lg, m in zip(torch.split(logits, split_sizes, dim=-1), masks_split)
            ]
            param = logits
        batch_shape = param.size()[:-1] if param.ndimension() > 1 else torch.Size()
        super().__init__(batch_shape=batch_shape, validate_args=validate_args)

    def log_prob(self, action: torch.Tensor) -> torch.Tensor:
        """Sum of the per-sub-action log-probabilities.

        ``action`` holds one index per sub-action along its last dimension.
        """
        # unbind(-1) yields action[..., i] for each sub-action; unlike
        # ``action.T`` it is not deprecated and is correct for >2-D inputs.
        per_dist = [c.log_prob(a) for a, c in zip(action.unbind(dim=-1), self.dists)]
        return torch.stack(per_dist, dim=-1).sum(dim=-1)

    def entropy(self) -> torch.Tensor:
        """Sum of the per-sub-action entropies."""
        return torch.stack([c.entropy() for c in self.dists], dim=-1).sum(dim=-1)

    def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
        """Sample every sub-action and stack them along the last dimension."""
        return torch.stack([c.sample(sample_shape) for c in self.dists], dim=-1)

    @property
    def mode(self) -> torch.Tensor:
        """Per-sub-action modes stacked along the last dimension."""
        return torch.stack([c.mode for c in self.dists], dim=-1)

    @property
    def arg_constraints(self) -> Dict[str, constraints.Constraint]:
        # Constraints handled by child distributions in dist
        return {}
class MultiDiscreteActorHead(Actor):
    """Actor head for MultiDiscrete action spaces.

    An MLP maps the encoder features to ``nvec.sum()`` logits, which
    ``MultiCategorical`` splits into one categorical per sub-action.
    """

    def __init__(
        self,
        nvec: NDArray[np.int64],
        in_dim: EncoderOutDim,
        hidden_sizes: Tuple[int, ...] = (32,),
        activation: Type[nn.Module] = nn.ReLU,
        init_layers_orthogonal: bool = True,
    ) -> None:
        """
        Args:
            nvec: number of categories for each sub-action.
            in_dim: encoder output size; this head requires a plain ``int``.
            hidden_sizes: widths of the hidden MLP layers.
            activation: activation module class used between MLP layers.
            init_layers_orthogonal: forwarded to ``mlp``.
        """
        super().__init__()
        self.nvec = nvec
        # Only flat (integer-sized) encoder outputs are supported here.
        assert isinstance(in_dim, int)
        n_logits = nvec.sum()
        sizes = (in_dim,) + hidden_sizes + (n_logits,)
        # NOTE(review): final_layer_gain=0.01 presumably shrinks the output
        # layer's init so initial logits start near uniform — confirm in mlp.
        self._fc = mlp(
            sizes,
            activation,
            init_layers_orthogonal=init_layers_orthogonal,
            final_layer_gain=0.01,
        )

    def forward(
        self,
        obs: torch.Tensor,
        actions: Optional[torch.Tensor] = None,
        action_masks: Optional[torch.Tensor] = None,
    ) -> PiForward:
        """Build the policy distribution for ``obs`` and hand it to ``pi_forward``.

        ``action_masks`` (if given) is passed to ``MultiCategorical`` as
        ``masks``; ``actions`` is forwarded unchanged to ``self.pi_forward``.
        """
        distribution = MultiCategorical(
            self.nvec,
            logits=self._fc(obs),
            masks=action_masks,
        )
        return self.pi_forward(distribution, actions)

    @property
    def action_shape(self) -> Tuple[int, ...]:
        """One integer entry per sub-action."""
        n_sub_actions = len(self.nvec)
        return (n_sub_actions,)