# PPO playing HalfCheetahBulletEnv-v0 from
# https://github.com/sgoodfriend/rl-algo-impls/tree/5598ebc4b03054f16eebe76792486ba7bcacfc5c
from typing import Optional, Sequence

from gym.spaces import Box, Discrete
from stable_baselines3.common.vec_env.base_vec_env import VecEnv

from shared.policy.on_policy import ActorCritic

class PPOActorCritic(ActorCritic):
    def __init__(
        self,
        env: VecEnv,
        pi_hidden_sizes: Optional[Sequence[int]] = None,
        v_hidden_sizes: Optional[Sequence[int]] = None,
        **kwargs,
    ) -> None:
        obs_space = env.observation_space
        if isinstance(obs_space, Box):
            if len(obs_space.shape) == 3:
                # Image observations (HxWxC): the feature extractor does the
                # representation work, so no hidden layers on either head.
                pi_hidden_sizes = pi_hidden_sizes or []
                v_hidden_sizes = v_hidden_sizes or []
            elif len(obs_space.shape) == 1:
                # Flat vector observations: two 64-unit hidden layers per head.
                pi_hidden_sizes = pi_hidden_sizes or [64, 64]
                v_hidden_sizes = v_hidden_sizes or [64, 64]
            else:
                raise ValueError(f"Unsupported observation space: {obs_space}")
        elif isinstance(obs_space, Discrete):
            # Discrete observations: a single 64-unit hidden layer per head.
            pi_hidden_sizes = pi_hidden_sizes or [64]
            v_hidden_sizes = v_hidden_sizes or [64]
        else:
            raise ValueError(f"Unsupported observation space: {obs_space}")
        super().__init__(
            env,
            pi_hidden_sizes,
            v_hidden_sizes,
            **kwargs,
        )
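

# A minimal usage sketch, not part of the original file. It assumes
# pybullet_envs is installed (importing it registers HalfCheetahBulletEnv-v0
# with gym) and that ActorCritic requires no kwargs beyond those shown in
# __init__ above; both are assumptions, not verified against the repo.
if __name__ == "__main__":
    import pybullet_envs  # noqa: F401  (side effect: registers bullet envs)
    from stable_baselines3.common.env_util import make_vec_env

    venv = make_vec_env("HalfCheetahBulletEnv-v0", n_envs=4)
    # HalfCheetah has 1-D Box observations, so both heads fall back to the
    # [64, 64] hidden-layer defaults from the branch above.
    policy = PPOActorCritic(venv)
    print(policy)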