Spaces:
Runtime error
Runtime error
File size: 11,932 Bytes
e0f25ed |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 |
import functools
from chex._src.pytypes import PRNGKey
import jax
import jax.numpy as jnp
import chex
from jax.numpy import ndarray
import numpy as np
from flax import struct
from functools import partial
from typing import Callable, Dict, Optional, Tuple, Union, Any
from gymnax.environments import spaces, environment
from kinetix.environment.env_state import EnvParams, EnvState
from jaxued.environments import UnderspecifiedEnv
class UnderspecifiedEnvWrapper(UnderspecifiedEnv):
    """Base class for wrappers around an `UnderspecifiedEnv`.

    Stores the wrapped environment in ``self._env`` and forwards every attribute
    that is not defined on the wrapper itself to the wrapped environment.
    """

    def __init__(self, env):
        self._env = env

    # provide proxy access to regular attributes of wrapped object
    def __getattr__(self, name):
        # ``_env`` is only missing on a half-built instance (e.g. while ``copy`` or
        # ``pickle`` reconstruct the object before restoring ``__dict__``). Forwarding
        # the lookup would call ``__getattr__("_env")`` again and recurse until
        # RecursionError, so fail with the AttributeError that ``hasattr`` expects.
        if name == "_env":
            raise AttributeError(name)
        return getattr(self._env, name)
class GymnaxWrapper(object):
    """Base class for Gymnax wrappers.

    Stores the wrapped environment in ``self._env`` and forwards every attribute
    that is not defined on the wrapper itself to the wrapped environment.
    """

    def __init__(self, env):
        self._env = env

    # provide proxy access to regular attributes of wrapped object
    def __getattr__(self, name):
        # ``_env`` is only missing on a half-built instance (e.g. while ``copy`` or
        # ``pickle`` reconstruct the object before restoring ``__dict__``). Forwarding
        # the lookup would call ``__getattr__("_env")`` again and recurse until
        # RecursionError, so fail with the AttributeError that ``hasattr`` expects.
        if name == "_env":
            raise AttributeError(name)
        return getattr(self._env, name)
# Adapted from: https://github.com/DramaCow/jaxued/blob/main/src/jaxued/wrappers/autoreset.py
class AutoResetWrapper(UnderspecifiedEnvWrapper):
    """
    Wrap an `UnderspecifiedEnv` so that it resets automatically when an episode ends.

    This mirrors the default Gymnax auto-reset behaviour. On every ``step_env`` the
    wrapped environment is stepped *and* a fresh reset is computed; where ``done`` is
    true the reset observation/state are returned instead of the stepped ones. Each
    reset draws a new level from the user-supplied ``sample_level(rng)`` callable.

    Warning:
        To stay compliant with the UnderspecifiedEnv interface, ``reset_env_to_level``
        lets callers reset to an arbitrary level, including levels outside the support
        of ``sample_level()``.
    """

    def __init__(self, env: UnderspecifiedEnv, sample_level: Callable[[chex.PRNGKey], EnvState]):
        self._env = env
        self.sample_level = sample_level

    @property
    def default_params(self) -> EnvParams:
        return self._env.default_params

    def reset_env(self, rng: chex.PRNGKey, params: EnvParams) -> Tuple[chex.ArrayTree, EnvState]:
        """Sample a level with ``sample_level`` and reset the wrapped env to it.

        Delegates to the wrapped env's ``reset_env_to_pcg_level``.
        """
        # A 3-way split is kept (first key unused) so the sampling/reset keys are
        # identical to previous behaviour.
        _, rng_sample, rng_reset = jax.random.split(rng, 3)
        state_to_reset_to = self.sample_level(rng_sample)
        return self._env.reset_env_to_pcg_level(rng_reset, state_to_reset_to, params)

    def step_env(
        self,
        rng: chex.PRNGKey,
        state: EnvState,
        action: Union[int, float],
        params: EnvParams,
    ) -> Tuple[chex.ArrayTree, EnvState, float, bool, dict]:
        """Step the wrapped env, substituting a freshly reset obs/state where ``done``.

        Both the step and the reset are always computed (as required under
        ``jit``); ``jax.lax.select`` picks leaf-by-leaf. ``reward``, ``done`` and
        ``info`` always come from the step.
        """
        rng_reset, rng_step = jax.random.split(rng, 2)
        obs_st, env_state_st, reward, done, info = self._env.step_env(rng_step, state, action, params)
        obs_re, env_state_re = self.reset_env(rng_reset, params)
        # ``jax.tree_map`` is deprecated and removed in recent JAX releases;
        # ``jax.tree_util.tree_map`` is the stable spelling.
        env_state = jax.tree_util.tree_map(lambda x, y: jax.lax.select(done, x, y), env_state_re, env_state_st)
        obs = jax.tree_util.tree_map(lambda x, y: jax.lax.select(done, x, y), obs_re, obs_st)
        return obs, env_state, reward, done, info

    def reset_env_to_level(self, rng: chex.PRNGKey, level: EnvState, params: EnvParams) -> Tuple[Any, EnvState]:
        """Reset the wrapped env to an explicit ``level`` (bypasses ``sample_level``)."""
        obs, env_state = self._env.reset_to_level(rng, level, params)
        return obs, env_state

    def action_space(self, params: EnvParams) -> Any:
        return self._env.action_space(params)
class AutoReplayWrapper(UnderspecifiedEnv):
    """
    Replay the **same** level over and over by resetting to it after each episode.

    Useful for training on, or rolling out, a single level many times. The level
    given to ``reset_env_to_level`` is stored in the returned ``AutoReplayState``
    and is the reset target whenever ``done`` is true.

    NOTE(review): this definition is dead code. A second ``class AutoReplayWrapper``
    further down this module (deriving from ``UnderspecifiedEnvWrapper``) rebinds
    the name, so this class is never visible to importers. Consider deleting it.
    """

    def __init__(self, env: UnderspecifiedEnv):
        self._env = env

    @property
    def default_params(self) -> EnvParams:
        return self._env.default_params

    def step_env(
        self,
        rng: chex.PRNGKey,
        state: "AutoReplayState",
        action: Union[int, float],
        params: EnvParams,
    ) -> Tuple[chex.ArrayTree, "AutoReplayState", float, bool, dict]:
        """Step the wrapped env; where ``done``, swap in a reset to ``state.level``."""
        rng_reset, rng_step = jax.random.split(rng)
        obs_re, env_state_re = self._env.reset_to_level(rng_reset, state.level, params)
        obs_st, env_state_st, reward, done, info = self._env.step_env(rng_step, state.env_state, action, params)
        # ``jax.tree_map`` is deprecated/removed in recent JAX; use ``jax.tree_util.tree_map``.
        env_state = jax.tree_util.tree_map(lambda x, y: jax.lax.select(done, x, y), env_state_re, env_state_st)
        obs = jax.tree_util.tree_map(lambda x, y: jax.lax.select(done, x, y), obs_re, obs_st)
        return obs, state.replace(env_state=env_state), reward, done, info

    def reset_env_to_level(
        self, rng: chex.PRNGKey, level: EnvState, params: EnvParams
    ) -> Tuple[Any, "AutoReplayState"]:
        """Reset the wrapped env to ``level`` and remember ``level`` for later replays."""
        obs, env_state = self._env.reset_to_level(rng, level, params)
        return obs, AutoReplayState(env_state=env_state, level=level)

    def action_space(self, params: EnvParams) -> Any:
        return self._env.action_space(params)
@struct.dataclass
class AutoReplayState:
    """State carried by `AutoReplayWrapper`: the live env state plus the level to replay."""

    env_state: EnvState  # current state of the wrapped environment
    level: EnvState  # level passed to reset_env_to_level; reset target whenever an episode ends
class AutoReplayWrapper(UnderspecifiedEnvWrapper):
    """
    Replay the **same** level over and over by resetting to it after each episode.

    Useful for training on, or rolling out, a single level many times. The level
    given to ``reset_env_to_level`` is stored in the returned ``AutoReplayState``
    and is the reset target whenever ``done`` is true.
    """

    def __init__(self, env: UnderspecifiedEnv):
        self._env = env

    @property
    def default_params(self) -> EnvParams:
        return self._env.default_params

    def step_env(
        self,
        rng: chex.PRNGKey,
        state: "AutoReplayState",
        action: Union[int, float],
        params: EnvParams,
    ) -> Tuple[chex.ArrayTree, "AutoReplayState", float, bool, dict]:
        """Step the wrapped env; where ``done``, swap in a reset to ``state.level``.

        Both the reset and the step are always computed; ``jax.lax.select`` picks
        leaf-by-leaf. ``reward``, ``done`` and ``info`` always come from the step.
        """
        rng_reset, rng_step = jax.random.split(rng)
        obs_re, env_state_re = self._env.reset_to_level(rng_reset, state.level, params)
        obs_st, env_state_st, reward, done, info = self._env.step_env(rng_step, state.env_state, action, params)
        # ``jax.tree_map`` is deprecated/removed in recent JAX; use ``jax.tree_util.tree_map``.
        env_state = jax.tree_util.tree_map(lambda x, y: jax.lax.select(done, x, y), env_state_re, env_state_st)
        obs = jax.tree_util.tree_map(lambda x, y: jax.lax.select(done, x, y), obs_re, obs_st)
        return obs, state.replace(env_state=env_state), reward, done, info

    def reset_env_to_level(
        self, rng: chex.PRNGKey, level: EnvState, params: EnvParams
    ) -> Tuple[Any, "AutoReplayState"]:
        """Reset the wrapped env to ``level`` and remember ``level`` for later replays."""
        obs, env_state = self._env.reset_to_level(rng, level, params)
        return obs, AutoReplayState(env_state=env_state, level=level)

    def action_space(self, params: EnvParams) -> Any:
        return self._env.action_space(params)
class UnderspecifiedToGymnaxWrapper(environment.Environment):
    """Expose an `UnderspecifiedEnv` through the gymnax ``environment.Environment`` API.

    ``step_env``, ``reset_env``, ``action_space`` and ``default_params`` are forwarded
    verbatim to the wrapped env; any other attribute is proxied via ``__getattr__``.
    """

    def __init__(self, env):
        self._env = env

    # provide proxy access to regular attributes of wrapped object
    def __getattr__(self, name):
        # ``_env`` is only missing on a half-built instance (e.g. during ``copy`` or
        # unpickling); forwarding it would recurse until RecursionError.
        if name == "_env":
            raise AttributeError(name)
        return getattr(self._env, name)

    @property
    def default_params(self) -> Any:
        return self._env.default_params

    def step_env(
        self, key: jax.Array, state: Any, action: int | float | jax.Array | ndarray | np.bool_ | np.number, params: Any
    ) -> Tuple[Any, Any, Any, Any, Dict[Any, Any]]:
        """Forward to the wrapped env; returns ``(obs, state, reward, done, info)``."""
        return self._env.step_env(key, state, action, params)

    def reset_env(self, key: PRNGKey, params: Any) -> Tuple[Any, Any]:
        """Forward to the wrapped env; returns ``(obs, state)``."""
        return self._env.reset_env(key, params)

    def action_space(self, params: Any):
        return self._env.action_space(params)
class BatchEnvWrapper(GymnaxWrapper):
    """Run ``num_envs`` copies of the wrapped environment in lock-step via ``jax.vmap``.

    Each public method takes a single PRNG key and fans it out into ``num_envs`` keys.
    ``params`` is shared by all copies and is a static argument of the jitted methods
    (so it must be hashable). ``level``, ``state`` and ``action`` are mapped over their
    leading axis, which must have size ``num_envs``.
    """

    def __init__(self, env, num_envs: int):
        super().__init__(env)
        self.num_envs = num_envs
        inner = self._env
        # Axis 0 is mapped for keys / levels / states / actions; params (None) is broadcast.
        self.reset_fn = jax.vmap(inner.reset, in_axes=(0, None))
        self.reset_to_level_fn = jax.vmap(inner.reset_to_level, in_axes=(0, 0, None))
        self.step_fn = jax.vmap(inner.step, in_axes=(0, 0, 0, None))

    def _per_env_keys(self, rng):
        """Derive one key per environment from ``rng``."""
        # Split in two, discard the first half, and fan the second half out.
        _, fan_out_key = jax.random.split(rng)
        return jax.random.split(fan_out_key, self.num_envs)

    @partial(jax.jit, static_argnums=(0, 2))
    def reset(self, rng, params=None):
        obs, batched_state = self.reset_fn(self._per_env_keys(rng), params)
        return obs, batched_state

    @partial(jax.jit, static_argnums=(0, 3))
    def reset_to_level(self, rng, level, params=None):
        obs, batched_state = self.reset_to_level_fn(self._per_env_keys(rng), level, params)
        return obs, batched_state

    @partial(jax.jit, static_argnums=(0, 4))
    def step(self, rng, state, action, params=None):
        obs, next_state, reward, done, info = self.step_fn(self._per_env_keys(rng), state, action, params)
        return obs, next_state, reward, done, info
@struct.dataclass
class DenseRewardState:
    """State of `DenseRewardWrapper`: the wrapped env state plus the last observed distance."""

    env_state: EnvState  # state of the wrapped environment
    last_distance: float = -1.0  # info["distance"] from the previous step; -1.0 means "no previous step" (set on reset)
class DenseRewardWrapper(GymnaxWrapper):
    """Add a distance-based shaping term to the wrapped environment's reward.

    The wrapped environment's ``step_env`` must report ``info["distance"]``. Each
    step adds ``-(distance - last_distance) * params.dense_reward_scale *
    self.dense_reward_scale`` to the reward (positive when the distance shrank).
    NaN/inf values of the unscaled-by-self delta are mapped to 0. No shaping is
    added on the first step of an episode (``last_distance == -1``) or when
    ``self.dense_reward_scale == 0``.

    Args:
        env: environment to wrap; its ``step_env``, ``reset`` and ``reset_to_level`` are used.
        dense_reward_scale: extra multiplier applied on top of
            ``params.dense_reward_scale``; ``0.0`` disables shaping.
    """

    def __init__(self, env, dense_reward_scale: float = 1.0) -> None:
        super().__init__(env)
        self.dense_reward_scale = dense_reward_scale

    def step(self, key, state, action: int, params=None):
        # ``params=None`` would fail on ``params.dense_reward_scale``; fall back to the
        # wrapped environment's defaults (resolved through ``__getattr__``).
        if params is None:
            params = self.default_params
        obs, env_state, reward, done, info = self._env.step_env(key, state.env_state, action, params)
        distance = info["distance"]
        # If the distance got smaller, the shaping reward is positive.
        delta_dist = -(distance - state.last_distance) * params.dense_reward_scale
        delta_dist = jnp.nan_to_num(delta_dist, nan=0.0, posinf=0.0, neginf=0.0)
        # NOTE(review): the term is scaled twice (params.dense_reward_scale above and
        # self.dense_reward_scale here). Kept to preserve reward magnitudes; confirm intent.
        reward = reward + jax.lax.select(
            (state.last_distance == -1) | (self.dense_reward_scale == 0.0), 0.0, delta_dist * self.dense_reward_scale
        )
        # The wrapped ``step_env`` may auto-reset internally (e.g. AutoResetWrapper), and
        # this wrapper never sees that reset. Re-arm the -1 sentinel when the episode
        # ends so the next episode's first step is not shaped against the previous
        # episode's final distance.
        next_last_distance = jnp.where(done, -1.0, distance)
        return obs, DenseRewardState(env_state, next_last_distance), reward, done, info

    def reset(self, rng, params=None):
        obs, env_state = self._env.reset(rng, params)
        return obs, DenseRewardState(env_state, -1.0)

    def reset_to_level(self, rng, level, params=None):
        obs, env_state = self._env.reset_to_level(rng, level, params)
        return obs, DenseRewardState(env_state, -1.0)
@struct.dataclass
class LogEnvState:
    """State of `LogWrapper`: the wrapped env state plus episode bookkeeping counters."""

    env_state: Any  # state of the wrapped environment
    episode_returns: float  # reward summed over the current, unfinished episode
    episode_lengths: int  # steps taken in the current, unfinished episode
    returned_episode_returns: float  # total reward of the most recently finished episode
    returned_episode_lengths: int  # length of the most recently finished episode
    timestep: int  # steps since the last reset (not cleared when an episode ends)
class LogWrapper(GymnaxWrapper):
    """Record episode returns and lengths and surface them through ``info``.

    The state handed back to callers is a ``LogEnvState`` wrapping the inner state.
    On each ``step``:

    * ``episode_returns`` / ``episode_lengths`` accumulate reward and step count for
      the running episode and drop back to 0 on the step where ``done`` is set;
    * ``returned_episode_returns`` / ``returned_episode_lengths`` hold the totals of
      the last finished episode and change only when ``done`` is set;
    * ``timestep`` counts every step since the last reset.

    ``info`` gains ``returned_episode_returns``, ``returned_episode_lengths``,
    ``returned_episode_solved`` (copied from ``info["GoalR"]``, which the inner
    environment must provide), ``timestep`` and ``returned_episode``.
    """

    def __init__(self, env):
        super().__init__(env)

    @staticmethod
    def _zeroed_log_state(env_state):
        """Wrap ``env_state`` in a ``LogEnvState`` whose counters all start at zero."""
        return LogEnvState(
            env_state=env_state,
            episode_returns=0.0,
            episode_lengths=0,
            returned_episode_returns=0.0,
            returned_episode_lengths=0,
            timestep=0,
        )

    @partial(jax.jit, static_argnums=(0, 2))
    def reset(self, key: chex.PRNGKey, params=None):
        obs, inner_state = self._env.reset(key, params)
        return obs, self._zeroed_log_state(inner_state)

    def reset_to_level(self, key: chex.PRNGKey, level: EnvState, params=None):
        obs, inner_state = self._env.reset_to_level(key, level, params)
        return obs, self._zeroed_log_state(inner_state)

    @partial(jax.jit, static_argnums=(0, 4))
    def step(
        self,
        key: chex.PRNGKey,
        state,
        action: Union[int, float],
        params=None,
    ):
        obs, inner_state, reward, done, info = self._env.step(key, state.env_state, action, params)
        # 1 while the episode continues, 0 on the terminal step.
        still_running = 1 - done
        episode_return = state.episode_returns + reward
        episode_length = state.episode_lengths + 1
        logged = LogEnvState(
            env_state=inner_state,
            episode_returns=episode_return * still_running,
            episode_lengths=episode_length * still_running,
            returned_episode_returns=state.returned_episode_returns * still_running + episode_return * done,
            returned_episode_lengths=state.returned_episode_lengths * still_running + episode_length * done,
            timestep=state.timestep + 1,
        )
        info.update(
            returned_episode_returns=logged.returned_episode_returns,
            returned_episode_lengths=logged.returned_episode_lengths,
            returned_episode_solved=info["GoalR"],
            timestep=logged.timestep,
            returned_episode=done,
        )
        return obs, logged, reward, done, info
|