import functools

from functools import partial
from typing import Any, Dict, Optional, Tuple, Union

import chex
import jax
import jax.numpy as jnp
import numpy as np
from chex._src.pytypes import PRNGKey
from gymnax.environments import environment, spaces
from gymnax.environments.environment import TEnvParams, TEnvState
from gymnax.environments.spaces import Space
from jax import lax
from jax2d.engine import PhysicsEngine, create_empty_sim, recalculate_mass_and_inertia
from jax2d.sim_state import CollisionManifold, SimState
from jaxued.environments import UnderspecifiedEnv

from kinetix.environment.env_state import EnvParams, EnvState, StaticEnvParams
from kinetix.environment.wrappers import (
    AutoReplayWrapper,
    AutoResetWrapper,
    UnderspecifiedToGymnaxWrapper,
    DenseRewardWrapper,
    LogWrapper,
)
from kinetix.pcg.pcg import env_state_to_pcg_state, sample_pcg_state
from kinetix.pcg.pcg_state import PCGState
from kinetix.render.renderer_symbolic_entity import make_render_entities
from kinetix.render.renderer_pixels import make_render_pixels, make_render_pixels_rl
from kinetix.render.renderer_symbolic_flat import make_render_symbolic
from kinetix.util.saving import load_pcg_state_pickle


def create_empty_env(static_env_params):
    sim_state = create_empty_sim(static_env_params)
    return EnvState(
        timestep=0,
        thruster_bindings=jnp.zeros(static_env_params.num_thrusters, dtype=jnp.int32),
        motor_bindings=jnp.zeros(static_env_params.num_joints, dtype=jnp.int32),
        motor_auto=jnp.zeros(static_env_params.num_joints, dtype=bool),
        polygon_shape_roles=jnp.zeros(static_env_params.num_polygons, dtype=jnp.int32),
        circle_shape_roles=jnp.zeros(static_env_params.num_circles, dtype=jnp.int32),
        polygon_highlighted=jnp.zeros(static_env_params.num_polygons, dtype=bool),
        circle_highlighted=jnp.zeros(static_env_params.num_circles, dtype=bool),
        polygon_densities=jnp.ones(static_env_params.num_polygons, dtype=jnp.float32),
        circle_densities=jnp.ones(static_env_params.num_circles, dtype=jnp.float32),
        **sim_state.__dict__,
    )
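

# Illustrative sketch (an addition for demonstration, not part of the original module):
# build an empty level and convert it into a PCGState, mirroring what BasePhysicsEnv does
# below when no reset file is given. The helper name `_example_empty_level` is hypothetical.
def _example_empty_level() -> PCGState:
    static_env_params = StaticEnvParams()
    env_state = create_empty_env(static_env_params)
    return env_state_to_pcg_state(env_state)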


def index_motor_actions(
    action: jnp.ndarray,
    state: EnvState,
    clip_min=None,
    clip_max=None,
):
    # Expand the motor actions to all joints with the same colour
    return jnp.clip(action[state.motor_bindings], clip_min, clip_max)


def index_thruster_actions(
    action: jnp.ndarray,
    state: EnvState,
    clip_min=None,
    clip_max=None,
):
    # Expand the thruster actions to all thrusters with the same colour
    return jnp.clip(action[state.thruster_bindings], clip_min, clip_max)


def convert_continuous_actions(
    action: jnp.ndarray, state: SimState, static_env_params: StaticEnvParams, params: EnvParams
):
    action_motor = action[: static_env_params.num_motor_bindings]
    action_thruster = action[static_env_params.num_motor_bindings :]
    action_motor = index_motor_actions(action_motor, state, -1, 1)
    action_thruster = index_thruster_actions(action_thruster, state, 0, 1)
    action_motor = jnp.where(state.motor_auto, jnp.ones_like(action_motor), action_motor)
    action_to_perform = jnp.concatenate([action_motor, action_thruster], axis=0)
    return action_to_perform


def convert_discrete_actions(action: int, state: SimState, static_env_params: StaticEnvParams, params: EnvParams):
    # The categorical action is laid out as follows (NJC = num_motor_bindings, NTC = num_thruster_bindings):
    #   0 to NJC * 2 - 1:             joint actions (two directions per motor binding)
    #   NJC * 2:                      no-op
    #   NJC * 2 + 1 to NJC * 2 + NTC: thruster actions
    which_idx = action // 2
    which_dir = action % 2
    actions = (
        jnp.zeros(static_env_params.num_motor_bindings + static_env_params.num_thruster_bindings)
        .at[which_idx]
        .set(which_dir * 2 - 1)
    )
    # If the action is the no-op or a thruster action (i.e. >= NJC * 2), zero out the joint
    # actions so that no motor is actuated.
    actions = actions * (1 - (action >= static_env_params.num_motor_bindings * 2))
    actions = jax.lax.select(
        action > static_env_params.num_motor_bindings * 2,
        actions.at[action - static_env_params.num_motor_bindings * 2 - 1 + static_env_params.num_motor_bindings].set(1),
        actions,
    )
    action_motor = index_motor_actions(actions[: static_env_params.num_motor_bindings], state, -1, 1)
    action_motor = jnp.where(state.motor_auto, jnp.ones_like(action_motor), action_motor)
    action_thruster = index_thruster_actions(actions[static_env_params.num_motor_bindings :], state, 0, 1)
    action_to_perform = jnp.concatenate([action_motor, action_thruster], axis=0)
    return action_to_perform
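

# Illustrative sketch (an addition for demonstration, not part of the original module):
# decode a few categorical actions on an empty level. With NJC = num_motor_bindings,
# action 0 drives motor binding 0 at -1, action NJC * 2 is the no-op, and action
# NJC * 2 + 1 fires thruster binding 0. `_example_discrete_decode` is hypothetical.
def _example_discrete_decode():
    static_env_params, params = StaticEnvParams(), EnvParams()
    state = create_empty_env(static_env_params)
    noop = static_env_params.num_motor_bindings * 2
    return {
        "motor_0_backward": convert_discrete_actions(0, state, static_env_params, params),
        "noop": convert_discrete_actions(noop, state, static_env_params, params),
        "thruster_0_on": convert_discrete_actions(noop + 1, state, static_env_params, params),
    }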


def convert_multi_discrete_actions(
    action: jnp.ndarray, state: SimState, static_env_params: StaticEnvParams, params: EnvParams
):
    # Actions come in with each entry in {0, 1, 2} for joints and {0, 1} for thrusters.
    # Convert to [-1., 1.] for joints and [0., 1.] for thrusters.
    def _single_motor_action(act):
        return jax.lax.switch(
            act,
            [lambda: 0.0, lambda: 1.0, lambda: -1.0],
        )

    def _single_thruster_act(act):
        return jax.lax.select(
            act == 0,
            0.0,
            1.0,
        )

    action_motor = jax.vmap(_single_motor_action)(action[: static_env_params.num_motor_bindings])
    action_thruster = jax.vmap(_single_thruster_act)(action[static_env_params.num_motor_bindings :])
    action_motor = index_motor_actions(action_motor, state, -1, 1)
    action_thruster = index_thruster_actions(action_thruster, state, 0, 1)
    action_motor = jnp.where(state.motor_auto, jnp.ones_like(action_motor), action_motor)
    action_to_perform = jnp.concatenate([action_motor, action_thruster], axis=0)
    return action_to_perform
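

# Illustrative sketch (an addition for demonstration, not part of the original module):
# decode one multi-discrete action vector on an empty level. Per binding, motors use
# {0: off, 1: forward, 2: backward} and thrusters use {0: off, 1: on}.
# `_example_multi_discrete_decode` is hypothetical.
def _example_multi_discrete_decode():
    static_env_params, params = StaticEnvParams(), EnvParams()
    state = create_empty_env(static_env_params)
    raw_action = jnp.concatenate(
        [
            jnp.ones(static_env_params.num_motor_bindings, dtype=jnp.int32),  # all motors forward
            jnp.zeros(static_env_params.num_thruster_bindings, dtype=jnp.int32),  # all thrusters off
        ]
    )
    return convert_multi_discrete_actions(raw_action, state, static_env_params, params)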


def make_kinetix_env_from_args(
    obs_type, action_type, reset_type, static_env_params=None, auto_reset_fn=None, dense_reward_scale=1.0
):
    if obs_type == "entity":
        if action_type == "multidiscrete":
            env = KinetixEntityMultiDiscreteActions(should_do_pcg_reset=True, static_env_params=static_env_params)
        elif action_type == "discrete":
            env = KinetixEntityDiscreteActions(should_do_pcg_reset=True, static_env_params=static_env_params)
        elif action_type == "continuous":
            env = KinetixEntityContinuousActions(should_do_pcg_reset=True, static_env_params=static_env_params)
        else:
            raise ValueError(f"Unknown action type: {action_type}")
    elif obs_type == "symbolic":
        if action_type == "multidiscrete":
            env = KinetixSymbolicMultiDiscreteActions(should_do_pcg_reset=True, static_env_params=static_env_params)
        elif action_type == "discrete":
            env = KinetixSymbolicDiscreteActions(should_do_pcg_reset=True, static_env_params=static_env_params)
        elif action_type == "continuous":
            env = KinetixSymbolicContinuousActions(should_do_pcg_reset=True, static_env_params=static_env_params)
        else:
            raise ValueError(f"Unknown action type: {action_type}")
    elif obs_type == "pixels":
        if action_type == "multidiscrete":
            env = KinetixPixelsMultiDiscreteActions(should_do_pcg_reset=True, static_env_params=static_env_params)
        elif action_type == "discrete":
            env = KinetixPixelsDiscreteActions(should_do_pcg_reset=True, static_env_params=static_env_params)
        elif action_type == "continuous":
            env = KinetixPixelsContinuousActions(should_do_pcg_reset=True, static_env_params=static_env_params)
        else:
            raise ValueError(f"Unknown action type: {action_type}")
    elif obs_type == "blind":
        if action_type == "discrete":
            env = KinetixBlindDiscreteActions(should_do_pcg_reset=True, static_env_params=static_env_params)
        elif action_type == "continuous":
            env = KinetixBlindContinuousActions(should_do_pcg_reset=True, static_env_params=static_env_params)
        else:
            raise ValueError(f"Unknown action type: {action_type}")
    else:
        raise ValueError(f"Unknown observation type: {obs_type}")

    # Wrap
    if reset_type == "replay":
        env = AutoReplayWrapper(env)
    elif reset_type == "reset":
        env = AutoResetWrapper(env, sample_level=auto_reset_fn)
    else:
        raise ValueError(f"Unknown reset type {reset_type}")

    env = UnderspecifiedToGymnaxWrapper(env)
    env = DenseRewardWrapper(env, dense_reward_scale=dense_reward_scale)
    env = LogWrapper(env)
    return env


def make_kinetix_env_from_name(name, static_env_params=None):
    kwargs = dict(filename_to_use_for_reset=None, should_do_pcg_reset=True, static_env_params=static_env_params)
    values = {
        "Kinetix-Pixels-MultiDiscrete-v1": KinetixPixelsMultiDiscreteActions,
        "Kinetix-Pixels-Discrete-v1": KinetixPixelsDiscreteActions,
        "Kinetix-Pixels-Continuous-v1": KinetixPixelsContinuousActions,
        "Kinetix-Symbolic-MultiDiscrete-v1": KinetixSymbolicMultiDiscreteActions,
        "Kinetix-Symbolic-Discrete-v1": KinetixSymbolicDiscreteActions,
        "Kinetix-Symbolic-Continuous-v1": KinetixSymbolicContinuousActions,
        "Kinetix-Blind-Discrete-v1": KinetixBlindDiscreteActions,
        "Kinetix-Blind-Continuous-v1": KinetixBlindContinuousActions,
        "Kinetix-Entity-Discrete-v1": KinetixEntityDiscreteActions,
        "Kinetix-Entity-Continuous-v1": KinetixEntityContinuousActions,
        "Kinetix-Entity-MultiDiscrete-v1": KinetixEntityMultiDiscreteActions,
    }
    return values[name](**kwargs)
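

# Illustrative usage sketch (an addition for demonstration, not part of the original
# module): construct an unwrapped environment by registered name, and the fully wrapped
# training variant via make_kinetix_env_from_args. `_example_make_envs` is hypothetical.
def _example_make_envs():
    static_env_params = StaticEnvParams()
    # Unwrapped environment, addressed by name.
    env = make_kinetix_env_from_name("Kinetix-Symbolic-Discrete-v1", static_env_params=static_env_params)
    action_space = env.action_space(env.default_params)
    # Wrapped environment: auto-replay on episode end, dense reward shaping and logging.
    wrapped_env = make_kinetix_env_from_args(
        obs_type="symbolic",
        action_type="discrete",
        reset_type="replay",
        static_env_params=static_env_params,
    )
    return env, action_space, wrapped_env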


class ObservationSpace:
    def __init__(self, params: EnvParams, static_env_params: StaticEnvParams):
        pass

    def get_obs(self, state: EnvState):
        raise NotImplementedError()

    def observation_space(self, params: EnvParams):
        raise NotImplementedError()


class PixelObservations(ObservationSpace):
    def __init__(self, params: EnvParams, static_env_params: StaticEnvParams):
        self.render_function = make_render_pixels_rl(params, static_env_params)
        self.static_env_params = static_env_params

    def get_obs(self, state: EnvState):
        return self.render_function(state)

    def observation_space(self, params: EnvParams) -> spaces.Box:
        return spaces.Box(
            0.0,
            1.0,
            tuple(a // self.static_env_params.downscale for a in self.static_env_params.screen_dim) + (3,),
            dtype=jnp.float32,
        )


class SymbolicObservations(ObservationSpace):
    def __init__(self, params: EnvParams, static_env_params: StaticEnvParams):
        self.render_function = make_render_symbolic(params, static_env_params)

    def get_obs(self, state: EnvState):
        return self.render_function(state)


class EntityObservations(ObservationSpace):
    def __init__(self, params: EnvParams, static_env_params: StaticEnvParams):
        self.render_function = make_render_entities(params, static_env_params)

    def get_obs(self, state: EnvState):
        return self.render_function(state)


class BlindObservations(ObservationSpace):
    def __init__(self, params: EnvParams, static_env_params: StaticEnvParams):
        self.params = params

    def get_obs(self, state: EnvState):
        return jax.nn.one_hot(state.timestep, self.params.max_timesteps + 1)


def get_observation_space_from_name(name: str, params, static_env_params):
    if "Pixels" in name:
        return PixelObservations(params, static_env_params)
    elif "Symbolic" in name:
        return SymbolicObservations(params, static_env_params)
    elif "Entity" in name:
        return EntityObservations(params, static_env_params)
    elif "Blind" in name:
        return BlindObservations(params, static_env_params)
    else:
        raise ValueError(f"Unknown name {name}")
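

# Illustrative sketch (an addition for demonstration, not part of the original module):
# look up an observation space by environment name and render the empty default level
# once. `_example_observation` is hypothetical.
def _example_observation():
    params, static_env_params = EnvParams(), StaticEnvParams()
    obs_space = get_observation_space_from_name("Kinetix-Symbolic-Discrete-v1", params, static_env_params)
    state = create_empty_env(static_env_params)
    return obs_space.get_obs(state)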


class ActionType:
    def __init__(self, params: EnvParams, static_env_params: StaticEnvParams):
        # This is the processed, unified action space size that is shared by all action types:
        # 1 dim per motor binding and 1 per thruster binding.
        self.unified_action_space_size = static_env_params.num_motor_bindings + static_env_params.num_thruster_bindings

    def action_space(self, params: Optional[EnvParams] = None) -> Union[spaces.Discrete, spaces.Box]:
        raise NotImplementedError()

    def process_action(self, action: jnp.ndarray, state: EnvState, static_env_params: StaticEnvParams) -> jnp.ndarray:
        raise NotImplementedError()

    def noop_action(self) -> jnp.ndarray:
        raise NotImplementedError()

    def random_action(self, rng: chex.PRNGKey):
        raise NotImplementedError()


class ActionTypeContinuous(ActionType):
    def __init__(self, params: EnvParams, static_env_params: StaticEnvParams):
        super().__init__(params, static_env_params)
        self.params = params
        self.static_env_params = static_env_params

    def action_space(self, params: EnvParams | None = None) -> spaces.Discrete | spaces.Box:
        return spaces.Box(
            low=jnp.ones(self.unified_action_space_size) * -1.0,
            high=jnp.ones(self.unified_action_space_size) * 1.0,
            shape=(self.unified_action_space_size,),
        )

    def process_action(self, action: jnp.ndarray, state: EnvState, static_env_params: StaticEnvParams) -> jnp.ndarray:
        return convert_continuous_actions(action, state, static_env_params, self.params)

    def noop_action(self) -> jnp.ndarray:
        return jnp.zeros(self.unified_action_space_size, dtype=jnp.float32)

    def random_action(self, rng: chex.PRNGKey) -> jnp.ndarray:
        actions = jax.random.uniform(rng, shape=(self.unified_action_space_size,), minval=-1.0, maxval=1.0)
        # Motors between -1 and 1, thrusters between 0 and 1
        actions = actions.at[self.static_env_params.num_motor_bindings :].set(
            jnp.abs(actions[self.static_env_params.num_motor_bindings :])
        )
        return actions


class ActionTypeDiscrete(ActionType):
    def __init__(self, params: EnvParams, static_env_params: StaticEnvParams):
        super().__init__(params, static_env_params)
        self.params = params
        self.static_env_params = static_env_params
        self._n_actions = (
            self.static_env_params.num_motor_bindings * 2 + 1 + self.static_env_params.num_thruster_bindings
        )

    def action_space(self, params: Optional[EnvParams] = None) -> spaces.Discrete:
        return spaces.Discrete(self._n_actions)

    def process_action(self, action: jnp.ndarray, state: EnvState, static_env_params: StaticEnvParams) -> jnp.ndarray:
        return convert_discrete_actions(action, state, static_env_params, self.params)

    def noop_action(self) -> int:
        return self.static_env_params.num_motor_bindings * 2

    def random_action(self, rng: chex.PRNGKey):
        return jax.random.randint(rng, shape=(), minval=0, maxval=self._n_actions)


class MultiDiscrete(Space):
    def __init__(self, n, number_of_dims_per_distribution):
        self.number_of_dims_per_distribution = number_of_dims_per_distribution
        self.n = n
        self.shape = (number_of_dims_per_distribution.shape[0],)
        self.dtype = jnp.int_

    def sample(self, rng: chex.PRNGKey) -> chex.Array:
        """Sample random action uniformly from set of categorical choices."""
        uniform_sample = jax.random.uniform(rng, shape=self.shape) * self.number_of_dims_per_distribution
        md_dist = jnp.floor(uniform_sample)
        return md_dist.astype(self.dtype)

    def contains(self, x) -> jnp.ndarray:
        """Check whether specific object is within space."""
        range_cond = jnp.logical_and(x >= 0, (x < self.number_of_dims_per_distribution).all())
        return range_cond
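

# Illustrative sketch (an addition for demonstration, not part of the original module):
# build the MultiDiscrete space used by the multi-discrete action type by hand, then
# sample from it. `_example_multi_discrete_space` is hypothetical.
def _example_multi_discrete_space(rng: chex.PRNGKey):
    static_env_params = StaticEnvParams()
    dims_per_distribution = jnp.concatenate(
        [
            jnp.full(static_env_params.num_motor_bindings, 3, dtype=jnp.int32),  # motors: off / forward / backward
            jnp.full(static_env_params.num_thruster_bindings, 2, dtype=jnp.int32),  # thrusters: off / on
        ]
    )
    space = MultiDiscrete(dims_per_distribution.sum(), dims_per_distribution)
    sample = space.sample(rng)
    return sample, space.contains(sample)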


class ActionTypeMultiDiscrete(ActionType):
    def __init__(self, params: EnvParams, static_env_params: StaticEnvParams):
        super().__init__(params, static_env_params)
        self.params = params
        self.static_env_params = static_env_params
        # This is the action space that will be used internally by an agent:
        # 3 dims per motor (forward, backward, off) and 2 per thruster (on, off).
        self.n_hot_action_space_size = (
            self.static_env_params.num_motor_bindings * 3 + self.static_env_params.num_thruster_bindings * 2
        )

        def _make_sample_random():
            minval = jnp.zeros(self.unified_action_space_size, dtype=jnp.int32)
            maxval = jnp.ones(self.unified_action_space_size, dtype=jnp.int32) * 3
            maxval = maxval.at[self.static_env_params.num_motor_bindings :].set(2)

            def random(rng):
                return jax.random.randint(rng, shape=(self.unified_action_space_size,), minval=minval, maxval=maxval)

            return random

        self._random = _make_sample_random
        self.number_of_dims_per_distribution = jnp.concatenate(
            [
                np.ones(self.static_env_params.num_motor_bindings) * 3,
                np.ones(self.static_env_params.num_thruster_bindings) * 2,
            ]
        ).astype(np.int32)

    def action_space(self, params: Optional[EnvParams] = None) -> MultiDiscrete:
        return MultiDiscrete(self.n_hot_action_space_size, self.number_of_dims_per_distribution)

    def process_action(self, action: jnp.ndarray, state: EnvState, static_env_params: StaticEnvParams) -> jnp.ndarray:
        return convert_multi_discrete_actions(action, state, static_env_params, self.params)

    def noop_action(self):
        return jnp.zeros(self.unified_action_space_size, dtype=jnp.int32)

    def random_action(self, rng: chex.PRNGKey):
        return self._random()(rng)
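

# Illustrative sketch (an addition for demonstration, not part of the original module):
# all three action parameterisations decode into the same unified
# [motor_bindings, thruster_bindings] vector. `_example_action_types` is hypothetical.
def _example_action_types(rng: chex.PRNGKey):
    params, static_env_params = EnvParams(), StaticEnvParams()
    state = create_empty_env(static_env_params)
    decoded = {}
    for label, action_type in (
        ("continuous", ActionTypeContinuous(params, static_env_params)),
        ("discrete", ActionTypeDiscrete(params, static_env_params)),
        ("multidiscrete", ActionTypeMultiDiscrete(params, static_env_params)),
    ):
        rng, _rng = jax.random.split(rng)
        action = action_type.random_action(_rng)
        decoded[label] = action_type.process_action(action, state, static_env_params)
    return decoded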


class BasePhysicsEnv(UnderspecifiedEnv):
    def __init__(
        self,
        action_type: ActionType,
        observation_space: ObservationSpace,
        static_env_params: StaticEnvParams,
        target_index: int = 0,
        filename_to_use_for_reset=None,  # e.g. "worlds/games/bipedal_v1"
        should_do_pcg_reset: bool = False,
    ):
        super().__init__()
        self.target_index = target_index
        self.static_env_params = static_env_params
        self.action_type = action_type
        self._observation_space = observation_space
        self.physics_engine = PhysicsEngine(self.static_env_params)
        self.should_do_pcg_reset = should_do_pcg_reset

        self.filename_to_use_for_reset = filename_to_use_for_reset
        if self.filename_to_use_for_reset is not None:
            self.reset_state = load_pcg_state_pickle(filename_to_use_for_reset)
        else:
            env_state = create_empty_env(self.static_env_params)
            self.reset_state = env_state_to_pcg_state(env_state)

    # Action / Observations
    def action_space(self, params: Optional[EnvParams] = None) -> Union[spaces.Discrete, spaces.Box]:
        return self.action_type.action_space(params)

    def observation_space(self, params: Any):
        return self._observation_space.observation_space(params)

    def get_obs(self, state: EnvState):
        return self._observation_space.get_obs(state)

    def step_env(self, rng, state, action: jnp.ndarray, params):
        action_processed = self.action_type.process_action(action, state, self.static_env_params)
        return self.engine_step(state, action_processed, params)

    def reset_env(self, rng, params):
        # Level sampling is handled by AutoResetWrapper / AutoReplayWrapper, so the base
        # environment does not implement it.
        raise NotImplementedError()

    def reset_env_to_level(self, rng, state: EnvState, params):
        if isinstance(state, PCGState):
            return self.reset_env_to_pcg_level(rng, state, params)
        return self.get_obs(state), state

    def reset_env_to_pcg_level(self, rng, state: PCGState, params):
        env_state = sample_pcg_state(rng, state, params, self.static_env_params)
        return self.get_obs(env_state), env_state

    @property
    def default_params(self) -> EnvParams:
        return EnvParams()

    @staticmethod
    def default_static_params() -> StaticEnvParams:
        return StaticEnvParams()

    def compute_reward_info(
        self, state: EnvState, manifolds: tuple[CollisionManifold, CollisionManifold, CollisionManifold]
    ) -> float:
        def get_active(manifold: CollisionManifold) -> jnp.ndarray:
            return manifold.active

        def dist(a, b):
            return jnp.linalg.norm(a - b)

        def dist_rr(idxa, idxb):
            return dist(state.polygon.position[idxa], state.polygon.position[idxb])

        def dist_cc(idxa, idxb):
            return dist(state.circle.position[idxa], state.circle.position[idxb])

        def dist_cr(idxa, idxb):
            return dist(state.circle.position[idxa], state.polygon.position[idxb])

        info = {
            "GoalR": False,
        }
        negative_reward = 0
        reward = 0
        distance = 0

        # The product of two shape roles identifies the contact type: it equals 2 only for a
        # (role 1, role 2) pair and 3 only for a (role 1, role 3) pair. Inactive shapes have
        # their roles zeroed out so they never contribute.
        rs = state.polygon_shape_roles * state.polygon.active
        cs = state.circle_shape_roles * state.circle.active

        # Polygon-Polygon
        r1 = rs[self.physics_engine.poly_poly_pairs[:, 0]]
        r2 = rs[self.physics_engine.poly_poly_pairs[:, 1]]
        reward += ((r1 * r2 == 2) * get_active(manifolds[0])).sum()
        negative_reward += ((r1 * r2 == 3) * get_active(manifolds[0])).sum()
        distance += (
            (r1 * r2 == 2)
            * dist_rr(self.physics_engine.poly_poly_pairs[:, 0], self.physics_engine.poly_poly_pairs[:, 1])
        ).sum()

        # Circle-Polygon
        c1 = cs[self.physics_engine.circle_poly_pairs[:, 0]]
        r2 = rs[self.physics_engine.circle_poly_pairs[:, 1]]
        reward += ((c1 * r2 == 2) * get_active(manifolds[1])).sum()
        negative_reward += ((c1 * r2 == 3) * get_active(manifolds[1])).sum()
        t = dist_cr(self.physics_engine.circle_poly_pairs[:, 0], self.physics_engine.circle_poly_pairs[:, 1])
        distance += ((c1 * r2 == 2) * t).sum()

        # Circle-Circle
        c1 = cs[self.physics_engine.circle_circle_pairs[:, 0]]
        c2 = cs[self.physics_engine.circle_circle_pairs[:, 1]]
        reward += ((c1 * c2 == 2) * get_active(manifolds[2])).sum()
        negative_reward += ((c1 * c2 == 3) * get_active(manifolds[2])).sum()
        distance += (
            (c1 * c2 == 2)
            * dist_cc(self.physics_engine.circle_circle_pairs[:, 0], self.physics_engine.circle_circle_pairs[:, 1])
        ).sum()

        # Any penalised contact dominates and yields -1; otherwise any rewarded contact yields +1.
        reward = jax.lax.select(
            negative_reward > 0,
            -1.0,
            jax.lax.select(
                reward > 0,
                1.0,
                0.0,
            ),
        )
        info["GoalR"] = reward > 0
        info["distance"] = distance
        return reward, info

    def engine_step(self, env_state, action_to_perform, env_params):
        def _single_step(env_state, unused):
            env_state, mfolds = self.physics_engine.step(
                env_state,
                env_params,
                action_to_perform,
            )
            reward, info = self.compute_reward_info(env_state, mfolds)
            done = reward != 0
            info = {"rr_manifolds": None, "cr_manifolds": None} | info
            return env_state, (reward, done, info)

        # Repeat the physics step `frame_skip` times with the same action.
        env_state, (rewards, dones, infos) = jax.lax.scan(
            _single_step, env_state, xs=None, length=self.static_env_params.frame_skip
        )
        env_state = env_state.replace(timestep=env_state.timestep + 1)
        reward = rewards.max()
        # The episode ends if any sub-step terminated, if the simulation produced NaNs, or if
        # the time limit has been reached.
        done = (dones.sum() > 0) | jax.tree.reduce(
            jnp.logical_or, jax.tree.map(lambda x: jnp.isnan(x).any(), env_state), False
        )
        done |= env_state.timestep >= env_params.max_timesteps
        info = jax.tree.map(lambda x: x[-1], infos)
        return (
            lax.stop_gradient(self.get_obs(env_state)),
            lax.stop_gradient(env_state),
            reward,
            done,
            info,
        )

    def step(
        self,
        key: chex.PRNGKey,
        state: TEnvState,
        action: Union[int, float, chex.Array],
        params: Optional[TEnvParams] = None,
    ) -> Tuple[chex.Array, TEnvState, jnp.ndarray, jnp.ndarray, Dict[Any, Any]]:
        raise NotImplementedError()


class KinetixPixelsDiscreteActions(BasePhysicsEnv):
    def __init__(
        self,
        static_env_params: StaticEnvParams | None = None,
        **kwargs,
    ):
        params = self.default_params
        static_env_params = static_env_params or self.default_static_params()
        super().__init__(
            action_type=ActionTypeDiscrete(params, static_env_params),
            observation_space=PixelObservations(params, static_env_params),
            static_env_params=static_env_params,
            **kwargs,
        )

    def name(self) -> str:
        return "Kinetix-Pixels-Discrete-v1"


class KinetixPixelsContinuousActions(BasePhysicsEnv):
    def __init__(
        self,
        static_env_params: StaticEnvParams | None = None,
        **kwargs,
    ):
        params = self.default_params
        static_env_params = static_env_params or self.default_static_params()
        super().__init__(
            action_type=ActionTypeContinuous(params, static_env_params),
            observation_space=PixelObservations(params, static_env_params),
            static_env_params=static_env_params,
            **kwargs,
        )

    def name(self) -> str:
        return "Kinetix-Pixels-Continuous-v1"


class KinetixPixelsMultiDiscreteActions(BasePhysicsEnv):
    def __init__(
        self,
        static_env_params: StaticEnvParams | None = None,
        **kwargs,
    ):
        params = self.default_params
        static_env_params = static_env_params or self.default_static_params()
        super().__init__(
            action_type=ActionTypeMultiDiscrete(params, static_env_params),
            observation_space=PixelObservations(params, static_env_params),
            static_env_params=static_env_params,
            **kwargs,
        )

    def name(self) -> str:
        return "Kinetix-Pixels-MultiDiscrete-v1"


class KinetixSymbolicDiscreteActions(BasePhysicsEnv):
    def __init__(
        self,
        static_env_params: StaticEnvParams | None = None,
        **kwargs,
    ):
        params = self.default_params
        static_env_params = static_env_params or self.default_static_params()
        super().__init__(
            action_type=ActionTypeDiscrete(params, static_env_params),
            observation_space=SymbolicObservations(params, static_env_params),
            static_env_params=static_env_params,
            **kwargs,
        )

    def name(self) -> str:
        return "Kinetix-Symbolic-Discrete-v1"


class KinetixSymbolicContinuousActions(BasePhysicsEnv):
    def __init__(
        self,
        static_env_params: StaticEnvParams | None = None,
        **kwargs,
    ):
        params = self.default_params
        static_env_params = static_env_params or self.default_static_params()
        super().__init__(
            action_type=ActionTypeContinuous(params, static_env_params),
            observation_space=SymbolicObservations(params, static_env_params),
            static_env_params=static_env_params,
            **kwargs,
        )

    def name(self) -> str:
        return "Kinetix-Symbolic-Continuous-v1"


class KinetixSymbolicMultiDiscreteActions(BasePhysicsEnv):
    def __init__(
        self,
        static_env_params: StaticEnvParams | None = None,
        **kwargs,
    ):
        params = self.default_params
        static_env_params = static_env_params or self.default_static_params()
        super().__init__(
            action_type=ActionTypeMultiDiscrete(params, static_env_params),
            observation_space=SymbolicObservations(params, static_env_params),
            static_env_params=static_env_params,
            **kwargs,
        )

    def name(self) -> str:
        return "Kinetix-Symbolic-MultiDiscrete-v1"


class KinetixEntityDiscreteActions(BasePhysicsEnv):
    def __init__(
        self,
        static_env_params: StaticEnvParams | None = None,
        **kwargs,
    ):
        params = self.default_params
        static_env_params = static_env_params or self.default_static_params()
        super().__init__(
            action_type=ActionTypeDiscrete(params, static_env_params),
            observation_space=EntityObservations(params, static_env_params),
            static_env_params=static_env_params,
            **kwargs,
        )

    def name(self) -> str:
        return "Kinetix-Entity-Discrete-v1"


class KinetixEntityContinuousActions(BasePhysicsEnv):
    def __init__(
        self,
        static_env_params: StaticEnvParams | None = None,
        **kwargs,
    ):
        params = self.default_params
        static_env_params = static_env_params or self.default_static_params()
        super().__init__(
            action_type=ActionTypeContinuous(params, static_env_params),
            observation_space=EntityObservations(params, static_env_params),
            static_env_params=static_env_params,
            **kwargs,
        )

    def name(self) -> str:
        return "Kinetix-Entity-Continuous-v1"


class KinetixEntityMultiDiscreteActions(BasePhysicsEnv):
    def __init__(
        self,
        static_env_params: StaticEnvParams | None = None,
        **kwargs,
    ):
        params = self.default_params
        static_env_params = static_env_params or self.default_static_params()
        super().__init__(
            action_type=ActionTypeMultiDiscrete(params, static_env_params),
            observation_space=EntityObservations(params, static_env_params),
            static_env_params=static_env_params,
            **kwargs,
        )

    def name(self) -> str:
        return "Kinetix-Entity-MultiDiscrete-v1"


class KinetixBlindDiscreteActions(BasePhysicsEnv):
    def __init__(
        self,
        static_env_params: StaticEnvParams | None = None,
        **kwargs,
    ):
        params = self.default_params
        static_env_params = static_env_params or self.default_static_params()
        super().__init__(
            action_type=ActionTypeDiscrete(params, static_env_params),
            observation_space=BlindObservations(params, static_env_params),
            static_env_params=static_env_params,
            **kwargs,
        )

    def name(self) -> str:
        return "Kinetix-Blind-Discrete-v1"


class KinetixBlindContinuousActions(BasePhysicsEnv):
    def __init__(
        self,
        static_env_params: StaticEnvParams | None = None,
        **kwargs,
    ):
        params = self.default_params
        static_env_params = static_env_params or self.default_static_params()
        super().__init__(
            action_type=ActionTypeContinuous(params, static_env_params),
            observation_space=BlindObservations(params, static_env_params),
            static_env_params=static_env_params,
            **kwargs,
        )

    def name(self) -> str:
        return "Kinetix-Blind-Continuous-v1"
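

# Illustrative end-to-end smoke test (an addition for demonstration, not part of the
# original module): reset to the empty default level and take one random step.
if __name__ == "__main__":
    rng = jax.random.PRNGKey(0)
    env = make_kinetix_env_from_name("Kinetix-Symbolic-Discrete-v1")
    params = env.default_params

    rng, _rng = jax.random.split(rng)
    obs, state = env.reset_env_to_level(_rng, env.reset_state, params)

    rng, _rng = jax.random.split(rng)
    action = env.action_type.random_action(_rng)

    rng, _rng = jax.random.split(rng)
    obs, state, reward, done, info = env.step_env(_rng, state, action, params)
    print("reward:", reward, "done:", done)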