PPO playing HalfCheetahBulletEnv-v0 from https://github.com/sgoodfriend/rl-algo-impls/tree/5598ebc4b03054f16eebe76792486ba7bcacfc5c
872aa5c
from stable_baselines3.common.vec_env.base_vec_env import VecEnv | |
from typing import Optional, Sequence | |
from gym.spaces import Box, Discrete | |
from shared.policy.on_policy import ActorCritic | |
class PPOActorCritic(ActorCritic):
    """PPO actor-critic that picks default MLP sizes from the observation space.

    When the caller does not supply hidden-layer sizes for the policy (``pi``)
    or value (``v``) network, defaults are chosen from
    ``env.observation_space``; all further construction is delegated to
    ``ActorCritic``.
    """

    def __init__(
        self,
        env: VecEnv,
        pi_hidden_sizes: Optional[Sequence[int]] = None,
        v_hidden_sizes: Optional[Sequence[int]] = None,
        **kwargs,
    ) -> None:
        """Build the actor-critic, filling in default hidden sizes if omitted.

        Args:
            env: Vectorized environment; only ``env.observation_space`` is read
                here, and ``env`` is forwarded to ``ActorCritic``.
            pi_hidden_sizes: Hidden-layer sizes for the policy network. ``None``
                selects the observation-space default; an explicit empty
                sequence is honored as "no hidden layers".
            v_hidden_sizes: Hidden-layer sizes for the value network, with the
                same ``None`` semantics as ``pi_hidden_sizes``.
            **kwargs: Forwarded unchanged to ``ActorCritic.__init__``.

        Raises:
            ValueError: If the observation space is not a 1-D or 3-D ``Box``
                or a ``Discrete`` space (checked even when both sizes are
                given).
        """
        # Always compute the defaults so unsupported observation spaces are
        # rejected regardless of what the caller passed.
        default_sizes = self._default_hidden_sizes(env.observation_space)
        # Compare against None rather than relying on truthiness: an explicit
        # empty sequence is a valid request for no hidden layers.
        if pi_hidden_sizes is None:
            pi_hidden_sizes = list(default_sizes)
        if v_hidden_sizes is None:
            # Separate list object from pi's default so neither aliases the other.
            v_hidden_sizes = list(default_sizes)
        super().__init__(
            env,
            pi_hidden_sizes,
            v_hidden_sizes,
            **kwargs,
        )

    @staticmethod
    def _default_hidden_sizes(obs_space) -> Sequence[int]:
        """Return the default hidden-layer sizes for ``obs_space``.

        Raises:
            ValueError: If ``obs_space`` is not a 1-D or 3-D ``Box`` or a
                ``Discrete`` space.
        """
        if isinstance(obs_space, Box):
            if len(obs_space.shape) == 3:
                # NOTE(review): presumably image observations whose features come
                # from an extractor inside ActorCritic, hence no extra MLP layers
                # by default — confirm against ActorCritic.
                return []
            if len(obs_space.shape) == 1:
                return [64, 64]
        elif isinstance(obs_space, Discrete):
            return [64]
        raise ValueError(f"Unsupported observation space: {obs_space}")