import functools
from functools import partial
from typing import Any, Dict, Optional, Tuple, Union

import chex
import jax
import jax.numpy as jnp
import numpy as np
from chex._src.pytypes import PRNGKey
from gymnax.environments import environment, spaces
from gymnax.environments.environment import TEnvParams, TEnvState
from gymnax.environments.spaces import Space
from jax import lax
from jax2d.engine import PhysicsEngine, create_empty_sim, recalculate_mass_and_inertia
from jax2d.sim_state import CollisionManifold, SimState

from kinetix.environment.env_state import EnvParams, EnvState, StaticEnvParams
from kinetix.environment.wrappers import (
    AutoReplayWrapper,
    AutoResetWrapper,
    UnderspecifiedToGymnaxWrapper,
    DenseRewardWrapper,
    LogWrapper,
)
from kinetix.pcg.pcg import env_state_to_pcg_state, sample_pcg_state
from kinetix.pcg.pcg_state import PCGState
from kinetix.render.renderer_symbolic_entity import make_render_entities
from kinetix.render.renderer_pixels import make_render_pixels, make_render_pixels_rl
from kinetix.render.renderer_symbolic_flat import make_render_symbolic
from kinetix.util.saving import load_pcg_state_pickle

from jaxued.environments import UnderspecifiedEnv


def create_empty_env(static_env_params):
    """Build an ``EnvState`` wrapping an empty jax2d simulation.

    The Kinetix-specific fields (motor/thruster bindings, shape roles, highlight
    flags, densities) are initialised to zeros (densities to ones); every other
    field is copied from the ``SimState`` returned by ``create_empty_sim``.
    """
    sim_state = create_empty_sim(static_env_params)
    return EnvState(
        timestep=0,
        thruster_bindings=jnp.zeros(static_env_params.num_thrusters, dtype=jnp.int32),
        motor_bindings=jnp.zeros(static_env_params.num_joints, dtype=jnp.int32),
        motor_auto=jnp.zeros(static_env_params.num_joints, dtype=bool),
        polygon_shape_roles=jnp.zeros(static_env_params.num_polygons, dtype=jnp.int32),
        circle_shape_roles=jnp.zeros(static_env_params.num_circles, dtype=jnp.int32),
        polygon_highlighted=jnp.zeros(static_env_params.num_polygons, dtype=bool),
        circle_highlighted=jnp.zeros(static_env_params.num_circles, dtype=bool),
        polygon_densities=jnp.ones(static_env_params.num_polygons, dtype=jnp.float32),
        circle_densities=jnp.ones(static_env_params.num_circles, dtype=jnp.float32),
        **sim_state.__dict__,
    )


def index_motor_actions(
    action: jnp.ndarray,
    state: EnvState,
    clip_min=None,
    clip_max=None,
):
    """Expand per-binding motor actions to one action per joint, then clip.

    ``state.motor_bindings[j]`` is the binding index that joint ``j`` listens to,
    so every joint sharing a binding (colour) receives the same action.
    """
    return jnp.clip(action[state.motor_bindings], clip_min, clip_max)


def index_thruster_actions(
    action: jnp.ndarray,
    state: EnvState,
    clip_min=None,
    clip_max=None,
):
    """Expand per-binding thruster actions to one action per thruster, then clip.

    ``state.thruster_bindings[t]`` is the binding index that thruster ``t`` listens
    to, so every thruster sharing a binding (colour) receives the same action.
    """
    return jnp.clip(action[state.thruster_bindings], clip_min, clip_max)


def convert_continuous_actions(
    action: jnp.ndarray, state: EnvState, static_env_params: StaticEnvParams, params: EnvParams
):
    """Map a continuous per-binding action vector to per-joint/per-thruster actions.

    ``action`` holds ``num_motor_bindings`` motor values followed by the thruster
    values. Motors are clipped to [-1, 1], thrusters to [0, 1]. Joints flagged in
    ``state.motor_auto`` are always driven at +1 regardless of the agent's input.
    """
    action_motor = action[: static_env_params.num_motor_bindings]
    action_thruster = action[static_env_params.num_motor_bindings :]
    action_motor = index_motor_actions(action_motor, state, -1, 1)
    action_thruster = index_thruster_actions(action_thruster, state, 0, 1)

    action_motor = jnp.where(state.motor_auto, jnp.ones_like(action_motor), action_motor)

    action_to_perform = jnp.concatenate([action_motor, action_thruster], axis=0)
    return action_to_perform


def convert_discrete_actions(action: int, state: EnvState, static_env_params: StaticEnvParams, params: EnvParams):
    """Map a single categorical action to per-joint/per-thruster actions.

    With NMB = ``num_motor_bindings`` and NTB = ``num_thruster_bindings``:

    * ``0 .. 2*NMB - 1``: motor actions. ``action // 2`` selects the motor binding
      and ``action % 2`` the direction (0 -> -1, 1 -> +1).
    * ``2*NMB``: no-op.
    * ``2*NMB + 1 .. 2*NMB + NTB``: fire thruster binding ``action - 2*NMB - 1``.
    """
    num_motor_bindings = static_env_params.num_motor_bindings
    which_idx = action // 2
    which_dir = action % 2
    actions = (
        jnp.zeros(num_motor_bindings + static_env_params.num_thruster_bindings)
        .at[which_idx]
        .set(which_dir * 2 - 1)
    )
    # For the no-op and for thruster actions ``which_idx`` is not a motor binding,
    # so wipe whatever was written above: the joints must not be controlled.
    actions = actions * (1 - (action >= num_motor_bindings * 2))
    # Thruster actions: switch on the chosen thruster binding (stored after the motor slots).
    actions = jax.lax.select(
        action > num_motor_bindings * 2,
        actions.at[action - num_motor_bindings * 2 - 1 + num_motor_bindings].set(1),
        actions,
    )

    action_motor = index_motor_actions(actions[:num_motor_bindings], state, -1, 1)
    action_motor = jnp.where(state.motor_auto, jnp.ones_like(action_motor), action_motor)
    action_thruster = index_thruster_actions(actions[num_motor_bindings:], state, 0, 1)

    action_to_perform = jnp.concatenate([action_motor, action_thruster], axis=0)
    return action_to_perform


def convert_multi_discrete_actions(
    action: jnp.ndarray, state: EnvState, static_env_params: StaticEnvParams, params: EnvParams
):
    """Map a multi-discrete action vector to per-joint/per-thruster actions.

    Each motor entry is in {0, 1, 2} and becomes {0., 1., -1.} respectively;
    each thruster entry is in {0, 1} and becomes {0., 1.}.
    """

    def _single_motor_action(act):
        return jax.lax.switch(
            act,
            [lambda: 0.0, lambda: 1.0, lambda: -1.0],
        )

    def _single_thruster_act(act):
        return jax.lax.select(
            act == 0,
            0.0,
            1.0,
        )

    action_motor = jax.vmap(_single_motor_action)(action[: static_env_params.num_motor_bindings])
    action_thruster = jax.vmap(_single_thruster_act)(action[static_env_params.num_motor_bindings :])

    action_motor = index_motor_actions(action_motor, state, -1, 1)
    action_thruster = index_thruster_actions(action_thruster, state, 0, 1)

    action_motor = jnp.where(state.motor_auto, jnp.ones_like(action_motor), action_motor)

    action_to_perform = jnp.concatenate([action_motor, action_thruster], axis=0)
    return action_to_perform


def make_kinetix_env_from_args(
    obs_type, action_type, reset_type, static_env_params=None, auto_reset_fn=None, dense_reward_scale=1.0
):
    """Construct a fully wrapped Kinetix environment.

    Args:
        obs_type: one of "entity", "symbolic", "pixels", "blind".
        action_type: one of "multidiscrete", "discrete", "continuous"
            ("blind" supports only "discrete" and "continuous").
        reset_type: "replay" (AutoReplayWrapper) or "reset" (AutoResetWrapper).
        static_env_params: forwarded to the environment constructor.
        auto_reset_fn: level sampler passed to AutoResetWrapper as ``sample_level``.
        dense_reward_scale: forwarded to DenseRewardWrapper.

    Returns:
        The environment wrapped as reset wrapper -> UnderspecifiedToGymnaxWrapper
        -> DenseRewardWrapper -> LogWrapper.

    Raises:
        ValueError: on an unknown observation, action or reset type.
    """
    if obs_type == "entity":
        if action_type == "multidiscrete":
            env = KinetixEntityMultiDiscreteActions(should_do_pcg_reset=True, static_env_params=static_env_params)
        elif action_type == "discrete":
            env = KinetixEntityDiscreteActions(should_do_pcg_reset=True, static_env_params=static_env_params)
        elif action_type == "continuous":
            env = KinetixEntityContinuousActions(should_do_pcg_reset=True, static_env_params=static_env_params)
        else:
            raise ValueError(f"Unknown action type: {action_type}")

    elif obs_type == "symbolic":
        if action_type == "multidiscrete":
            env = KinetixSymbolicMultiDiscreteActions(should_do_pcg_reset=True, static_env_params=static_env_params)
        elif action_type == "discrete":
            env = KinetixSymbolicDiscreteActions(should_do_pcg_reset=True, static_env_params=static_env_params)
        elif action_type == "continuous":
            env = KinetixSymbolicContinuousActions(should_do_pcg_reset=True, static_env_params=static_env_params)
        else:
            raise ValueError(f"Unknown action type: {action_type}")

    elif obs_type == "pixels":
        if action_type == "multidiscrete":
            env = KinetixPixelsMultiDiscreteActions(should_do_pcg_reset=True, static_env_params=static_env_params)
        elif action_type == "discrete":
            env = KinetixPixelsDiscreteActions(should_do_pcg_reset=True, static_env_params=static_env_params)
        elif action_type == "continuous":
            env = KinetixPixelsContinuousActions(should_do_pcg_reset=True, static_env_params=static_env_params)
        else:
            raise ValueError(f"Unknown action type: {action_type}")

    elif obs_type == "blind":
        if action_type == "discrete":
            env = KinetixBlindDiscreteActions(should_do_pcg_reset=True, static_env_params=static_env_params)
        elif action_type == "continuous":
            env = KinetixBlindContinuousActions(should_do_pcg_reset=True, static_env_params=static_env_params)
        else:
            raise ValueError(f"Unknown action type: {action_type}")

    else:
        raise ValueError(f"Unknown observation type: {obs_type}")

    # Wrap
    if reset_type == "replay":
        env = AutoReplayWrapper(env)
    elif reset_type == "reset":
        env = AutoResetWrapper(env, sample_level=auto_reset_fn)
    else:
        raise ValueError(f"Unknown reset type {reset_type}")

    env = UnderspecifiedToGymnaxWrapper(env)
    env = DenseRewardWrapper(env, dense_reward_scale=dense_reward_scale)
    env = LogWrapper(env)

    return env


def make_kinetix_env_from_name(name, static_env_params=None):
    """Instantiate an (unwrapped) Kinetix environment from its registered name.

    Raises:
        KeyError: if ``name`` is not one of the registered environment names.
    """
    kwargs = dict(filename_to_use_for_reset=None, should_do_pcg_reset=True, static_env_params=static_env_params)
    values = {
        "Kinetix-Pixels-MultiDiscrete-v1": KinetixPixelsMultiDiscreteActions,
        "Kinetix-Pixels-Discrete-v1": KinetixPixelsDiscreteActions,
        "Kinetix-Pixels-Continuous-v1": KinetixPixelsContinuousActions,
        "Kinetix-Symbolic-MultiDiscrete-v1": KinetixSymbolicMultiDiscreteActions,
        "Kinetix-Symbolic-Discrete-v1": KinetixSymbolicDiscreteActions,
        "Kinetix-Symbolic-Continuous-v1": KinetixSymbolicContinuousActions,
        "Kinetix-Blind-Discrete-v1": KinetixBlindDiscreteActions,
        "Kinetix-Blind-Continuous-v1": KinetixBlindContinuousActions,
        "Kinetix-Entity-Discrete-v1": KinetixEntityDiscreteActions,
        "Kinetix-Entity-Continuous-v1": KinetixEntityContinuousActions,
        "Kinetix-Entity-MultiDiscrete-v1": KinetixEntityMultiDiscreteActions,
    }

    return values[name](**kwargs)


class ObservationSpace:
    """Interface for turning an ``EnvState`` into an agent observation."""

    def __init__(self, params: EnvParams, static_env_params: StaticEnvParams):
        pass

    def get_obs(self, state: EnvState):
        raise NotImplementedError()

    def observation_space(self, params: EnvParams):
        raise NotImplementedError()


class PixelObservations(ObservationSpace):
    """Rendered RGB image observations (see ``make_render_pixels_rl``)."""

    def __init__(self, params: EnvParams, static_env_params: StaticEnvParams):
        self.render_function = make_render_pixels_rl(params, static_env_params)
        self.static_env_params = static_env_params

    def get_obs(self, state: EnvState):
        return self.render_function(state)

    def observation_space(self, params: EnvParams) -> spaces.Box:
        # Screen dimensions divided by the downscale factor, with 3 colour channels in [0, 1].
        return spaces.Box(
            0.0,
            1.0,
            tuple(a // self.static_env_params.downscale for a in self.static_env_params.screen_dim) + (3,),
            dtype=jnp.float32,
        )


class SymbolicObservations(ObservationSpace):
    """Flat symbolic observations (see ``make_render_symbolic``)."""

    def __init__(self, params: EnvParams, static_env_params: StaticEnvParams):
        self.render_function = make_render_symbolic(params, static_env_params)

    def get_obs(self, state: EnvState):
        return self.render_function(state)


class EntityObservations(ObservationSpace):
    """Entity-based symbolic observations (see ``make_render_entities``)."""

    def __init__(self, params: EnvParams, static_env_params: StaticEnvParams):
        self.render_function = make_render_entities(params, static_env_params)

    def get_obs(self, state: EnvState):
        return self.render_function(state)


class BlindObservations(ObservationSpace):
    """Observations that only reveal time: a one-hot of the current timestep."""

    def __init__(self, params: EnvParams, static_env_params: StaticEnvParams):
        self.params = params

    def get_obs(self, state: EnvState):
        # Length max_timesteps + 1 so that timestep == max_timesteps is still representable.
        return jax.nn.one_hot(state.timestep, self.params.max_timesteps + 1)


def get_observation_space_from_name(name: str, params, static_env_params):
    """Pick the observation type whose keyword appears in an environment name.

    Raises:
        ValueError: if none of "Pixels", "Symbolic", "Entity", "Blind" is in ``name``.
    """
    if "Pixels" in name:
        return PixelObservations(params, static_env_params)
    elif "Symbolic" in name:
        return SymbolicObservations(params, static_env_params)
    elif "Entity" in name:
        return EntityObservations(params, static_env_params)
    elif "Blind" in name:
        return BlindObservations(params, static_env_params)
    else:
        raise ValueError(f"Unknown name {name}")


class ActionType:
    """Interface mapping an agent-facing action space to the unified engine action."""

    def __init__(self, params: EnvParams, static_env_params: StaticEnvParams):
        # This is the processed, unified action space size that is shared with all action types
        # 1 dim per motor and thruster
        self.unified_action_space_size = static_env_params.num_motor_bindings + static_env_params.num_thruster_bindings

    def action_space(self, params: Optional[EnvParams] = None) -> Union[spaces.Discrete, spaces.Box]:
        raise NotImplementedError()

    def process_action(self, action: jnp.ndarray, state: EnvState, static_env_params: StaticEnvParams) -> jnp.ndarray:
        raise NotImplementedError()

    def noop_action(self) -> jnp.ndarray:
        raise NotImplementedError()

    def random_action(self, rng: chex.PRNGKey):
        raise NotImplementedError()


class ActionTypeContinuous(ActionType):
    """One continuous value per motor binding and per thruster binding."""

    def __init__(self, params: EnvParams, static_env_params: StaticEnvParams):
        super().__init__(params, static_env_params)

        self.params = params
        self.static_env_params = static_env_params

    def action_space(self, params: EnvParams | None = None) -> spaces.Discrete | spaces.Box:
        return spaces.Box(
            low=jnp.ones(self.unified_action_space_size) * -1.0,
            high=jnp.ones(self.unified_action_space_size) * 1.0,
            shape=(self.unified_action_space_size,),
        )

    def process_action(self, action: jnp.ndarray, state: EnvState, static_env_params: StaticEnvParams) -> jnp.ndarray:
        return convert_continuous_actions(action, state, static_env_params, self.params)

    def noop_action(self) -> jnp.ndarray:
        return jnp.zeros(self.unified_action_space_size, dtype=jnp.float32)

    def random_action(self, rng: chex.PRNGKey) -> jnp.ndarray:
        actions = jax.random.uniform(rng, shape=(self.unified_action_space_size,), minval=-1.0, maxval=1.0)
        # Motors between -1 and 1, thrusters between 0 and 1
        actions = actions.at[self.static_env_params.num_motor_bindings :].set(
            jnp.abs(actions[self.static_env_params.num_motor_bindings :])
        )

        return actions


class ActionTypeDiscrete(ActionType):
    """A single categorical action; see ``convert_discrete_actions`` for the layout."""

    def __init__(self, params: EnvParams, static_env_params: StaticEnvParams):
        super().__init__(params, static_env_params)

        self.params = params
        self.static_env_params = static_env_params
        # 2 directions per motor binding, 1 no-op, 1 per thruster binding.
        self._n_actions = (
            self.static_env_params.num_motor_bindings * 2 + 1 + self.static_env_params.num_thruster_bindings
        )

    def action_space(self, params: Optional[EnvParams] = None) -> spaces.Discrete:
        return spaces.Discrete(self._n_actions)

    def process_action(self, action: jnp.ndarray, state: EnvState, static_env_params: StaticEnvParams) -> jnp.ndarray:
        return convert_discrete_actions(action, state, static_env_params, self.params)

    def noop_action(self) -> int:
        return self.static_env_params.num_motor_bindings * 2

    def random_action(self, rng: chex.PRNGKey):
        return jax.random.randint(rng, shape=(), minval=0, maxval=self._n_actions)


class MultiDiscrete(Space):
    """A vector of independent categorical choices.

    Entry ``i`` takes values in ``[0, number_of_dims_per_distribution[i])``.
    """

    def __init__(self, n, number_of_dims_per_distribution):
        self.number_of_dims_per_distribution = number_of_dims_per_distribution
        self.n = n
        self.shape = (number_of_dims_per_distribution.shape[0],)
        self.dtype = jnp.int_

    def sample(self, rng: chex.PRNGKey) -> chex.Array:
        """Sample random action uniformly from set of categorical choices."""
        uniform_sample = jax.random.uniform(rng, shape=self.shape) * self.number_of_dims_per_distribution
        md_dist = jnp.floor(uniform_sample)
        return md_dist.astype(self.dtype)

    def contains(self, x) -> jnp.ndarray:
        """Check whether specific object is within space.

        Returns a scalar boolean: every entry must satisfy
        ``0 <= x[i] < number_of_dims_per_distribution[i]``.
        """
        x = jnp.asarray(x)
        # Reduce BOTH bounds to scalars; previously the lower bound stayed element-wise.
        range_cond = jnp.logical_and((x >= 0).all(), (x < self.number_of_dims_per_distribution).all())
        return range_cond


class ActionTypeMultiDiscrete(ActionType):
    """One categorical choice per binding: 3 options per motor, 2 per thruster."""

    def __init__(self, params: EnvParams, static_env_params: StaticEnvParams):
        super().__init__(params, static_env_params)

        self.params = params
        self.static_env_params = static_env_params
        # This is the action space that will be used internally by an agent
        # 3 dims per motor (off, forward, backward) and 2 per thruster (off, on)
        self.n_hot_action_space_size = (
            self.static_env_params.num_motor_bindings * 3 + self.static_env_params.num_thruster_bindings * 2
        )

        def _make_sample_random():
            # Upper bounds are exclusive: motors sample {0, 1, 2}, thrusters {0, 1}.
            minval = jnp.zeros(self.unified_action_space_size, dtype=jnp.int32)
            maxval = jnp.ones(self.unified_action_space_size, dtype=jnp.int32) * 3
            maxval = maxval.at[self.static_env_params.num_motor_bindings :].set(2)

            def random(rng):
                return jax.random.randint(rng, shape=(self.unified_action_space_size,), minval=minval, maxval=maxval)

            return random

        self._random = _make_sample_random

        self.number_of_dims_per_distribution = jnp.concatenate(
            [
                np.ones(self.static_env_params.num_motor_bindings) * 3,
                np.ones(self.static_env_params.num_thruster_bindings) * 2,
            ]
        ).astype(np.int32)

    def action_space(self, params: Optional[EnvParams] = None) -> MultiDiscrete:
        return MultiDiscrete(self.n_hot_action_space_size, self.number_of_dims_per_distribution)

    def process_action(self, action: jnp.ndarray, state: EnvState, static_env_params: StaticEnvParams) -> jnp.ndarray:
        return convert_multi_discrete_actions(action, state, static_env_params, self.params)

    def noop_action(self):
        return jnp.zeros(self.unified_action_space_size, dtype=jnp.int32)

    def random_action(self, rng: chex.PRNGKey):
        return self._random()(rng)


class BasePhysicsEnv(UnderspecifiedEnv):
    """Shared physics/reward logic for every Kinetix environment variant.

    Subclasses choose an ``ActionType`` and an ``ObservationSpace``; resetting is
    expected to be provided by an AutoReset/AutoReplay wrapper.
    """

    def __init__(
        self,
        action_type: ActionType,
        observation_space: ObservationSpace,
        static_env_params: StaticEnvParams,
        target_index: int = 0,
        filename_to_use_for_reset=None,  # "worlds/games/bipedal_v1",
        should_do_pcg_reset: bool = False,
    ):
        super().__init__()
        self.target_index = target_index
        self.static_env_params = static_env_params
        self.action_type = action_type
        self._observation_space = observation_space
        self.physics_engine = PhysicsEngine(self.static_env_params)
        self.should_do_pcg_reset = should_do_pcg_reset

        self.filename_to_use_for_reset = filename_to_use_for_reset
        # Reset level: either loaded from a pickled PCG state or an empty world.
        if self.filename_to_use_for_reset is not None:
            self.reset_state = load_pcg_state_pickle(filename_to_use_for_reset)
        else:
            env_state = create_empty_env(self.static_env_params)
            self.reset_state = env_state_to_pcg_state(env_state)

    # Action / Observations
    def action_space(self, params: Optional[EnvParams] = None) -> Union[spaces.Discrete, spaces.Box]:
        return self.action_type.action_space(params)

    def observation_space(self, params: Any):
        return self._observation_space.observation_space(params)

    def get_obs(self, state: EnvState):
        return self._observation_space.get_obs(state)

    def step_env(self, rng, state, action: jnp.ndarray, params):
        """Convert the agent action to the unified engine action and advance one env step."""
        action_processed = self.action_type.process_action(action, state, self.static_env_params)
        return self.engine_step(state, action_processed, params)

    def reset_env(self, rng, params):
        # Wrap in AutoResetWrapper or AutoReplayWrapper
        raise NotImplementedError()

    def reset_env_to_level(self, rng, state: EnvState, params):
        """Reset to ``state``; a ``PCGState`` is first sampled into a concrete ``EnvState``."""
        if isinstance(state, PCGState):
            return self.reset_env_to_pcg_level(rng, state, params)
        return self.get_obs(state), state

    def reset_env_to_pcg_level(self, rng, state: PCGState, params):
        env_state = sample_pcg_state(rng, state, params, self.static_env_params)
        return self.get_obs(env_state), env_state

    @property
    def default_params(self) -> EnvParams:
        return EnvParams()

    @staticmethod
    def default_static_params() -> StaticEnvParams:
        return StaticEnvParams()

    def compute_reward_info(
        self, state: EnvState, manifolds: tuple[CollisionManifold, CollisionManifold, CollisionManifold]
    ) -> tuple[jnp.ndarray, dict]:
        """Compute the sparse reward and info dict from the current contacts.

        ``manifolds`` is (polygon-polygon, circle-polygon, circle-circle), indexed
        consistently with ``physics_engine.*_pairs``. Shape roles of inactive
        shapes are masked to 0. For each active contact the product of the two
        roles is inspected: 2 (presumably roles 1 and 2) counts as a goal contact,
        3 (presumably roles 1 and 3) as a failure contact — assumes roles are small
        non-negative ints.

        Returns:
            ``(reward, info)`` where reward is -1.0 if any failure contact exists,
            else 1.0 if any goal contact exists, else 0.0; ``info`` holds
            ``"GoalR"`` (reward > 0) and ``"distance"`` (summed distances between
            the centres of role-product-2 pairs, regardless of contact).
        """

        def get_active(manifold: CollisionManifold) -> jnp.ndarray:
            return manifold.active

        def dist(a, b):
            return jnp.linalg.norm(a - b)

        @jax.vmap
        def dist_rr(idxa, idxb):
            return dist(state.polygon.position[idxa], state.polygon.position[idxb])

        @jax.vmap
        def dist_cc(idxa, idxb):
            return dist(state.circle.position[idxa], state.circle.position[idxb])

        @jax.vmap
        def dist_cr(idxa, idxb):
            return dist(state.circle.position[idxa], state.polygon.position[idxb])

        negative_reward = 0
        reward = 0
        distance = 0
        rs = state.polygon_shape_roles * state.polygon.active
        cs = state.circle_shape_roles * state.circle.active

        # Polygon Polygon
        poly_poly = self.physics_engine.poly_poly_pairs
        r1 = rs[poly_poly[:, 0]]
        r2 = rs[poly_poly[:, 1]]
        reward += ((r1 * r2 == 2) * get_active(manifolds[0])).sum()
        negative_reward += ((r1 * r2 == 3) * get_active(manifolds[0])).sum()
        distance += ((r1 * r2 == 2) * dist_rr(poly_poly[:, 0], poly_poly[:, 1])).sum()

        # Circle Polygon
        circle_poly = self.physics_engine.circle_poly_pairs
        c1 = cs[circle_poly[:, 0]]
        r2 = rs[circle_poly[:, 1]]
        reward += ((c1 * r2 == 2) * get_active(manifolds[1])).sum()
        negative_reward += ((c1 * r2 == 3) * get_active(manifolds[1])).sum()
        distance += ((c1 * r2 == 2) * dist_cr(circle_poly[:, 0], circle_poly[:, 1])).sum()

        # Circle Circle
        circle_circle = self.physics_engine.circle_circle_pairs
        c1 = cs[circle_circle[:, 0]]
        c2 = cs[circle_circle[:, 1]]
        reward += ((c1 * c2 == 2) * get_active(manifolds[2])).sum()
        negative_reward += ((c1 * c2 == 3) * get_active(manifolds[2])).sum()
        distance += ((c1 * c2 == 2) * dist_cc(circle_circle[:, 0], circle_circle[:, 1])).sum()

        # Failure contacts take precedence over goal contacts.
        reward = jax.lax.select(
            negative_reward > 0,
            -1.0,
            jax.lax.select(
                reward > 0,
                1.0,
                0.0,
            ),
        )

        info = {
            "GoalR": reward > 0,
            "distance": distance,
        }
        return reward, info

    @partial(jax.jit, static_argnums=(0,))
    def engine_step(self, env_state, action_to_perform, env_params):
        """Advance the physics ``frame_skip`` times with a fixed action.

        Returns ``(obs, state, reward, done, info)`` where reward is the max over
        the sub-steps, and ``done`` is set if any sub-step produced a non-zero
        reward, any leaf of the new state contains NaN, or the episode reached
        ``env_params.max_timesteps``. ``info`` is taken from the last sub-step.
        """

        def _single_step(env_state, unused):
            env_state, mfolds = self.physics_engine.step(
                env_state,
                env_params,
                action_to_perform,
            )

            reward, info = self.compute_reward_info(env_state, mfolds)

            done = reward != 0

            info = {"rr_manifolds": None, "cr_manifolds": None} | info

            return env_state, (reward, done, info)

        env_state, (rewards, dones, infos) = jax.lax.scan(
            _single_step, env_state, xs=None, length=self.static_env_params.frame_skip
        )
        env_state = env_state.replace(timestep=env_state.timestep + 1)

        reward = rewards.max()

        # A NaN anywhere in the state means the simulation diverged: end the episode.
        has_nan = jax.tree.reduce(
            jnp.logical_or, jax.tree.map(lambda x: jnp.isnan(x).any(), env_state), False
        )
        # The comparison must be parenthesised: ``|`` binds tighter than ``>`` in
        # Python, so ``a > 0 | b`` would mean ``a > (0 | b)``.
        done = (dones.sum() > 0) | has_nan
        done |= env_state.timestep >= env_params.max_timesteps

        info = jax.tree.map(lambda x: x[-1], infos)

        return (
            lax.stop_gradient(self.get_obs(env_state)),
            lax.stop_gradient(env_state),
            reward,
            done,
            info,
        )

    @functools.partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        key: chex.PRNGKey,
        state: TEnvState,
        action: Union[int, float, chex.Array],
        params: Optional[TEnvParams] = None,
    ) -> Tuple[chex.Array, TEnvState, jnp.ndarray, jnp.ndarray, Dict[Any, Any]]:
        raise NotImplementedError()


class KinetixPixelsDiscreteActions(BasePhysicsEnv):
    """Pixel observations with a single categorical action."""

    def __init__(
        self,
        static_env_params: StaticEnvParams | None = None,
        **kwargs,
    ):
        params = self.default_params
        static_env_params = static_env_params or self.default_static_params()

        super().__init__(
            action_type=ActionTypeDiscrete(params, static_env_params),
            observation_space=PixelObservations(params, static_env_params),
            static_env_params=static_env_params,
            **kwargs,
        )

    @property
    def name(self) -> str:
        return "Kinetix-Pixels-Discrete-v1"


class KinetixPixelsContinuousActions(BasePhysicsEnv):
    """Pixel observations with continuous actions."""

    def __init__(
        self,
        static_env_params: StaticEnvParams | None = None,
        **kwargs,
    ):
        params = self.default_params
        static_env_params = static_env_params or self.default_static_params()

        super().__init__(
            action_type=ActionTypeContinuous(params, static_env_params),
            observation_space=PixelObservations(params, static_env_params),
            static_env_params=static_env_params,
            **kwargs,
        )

    @property
    def name(self) -> str:
        return "Kinetix-Pixels-Continuous-v1"


class KinetixPixelsMultiDiscreteActions(BasePhysicsEnv):
    """Pixel observations with multi-discrete actions."""

    def __init__(
        self,
        static_env_params: StaticEnvParams | None = None,
        **kwargs,
    ):
        params = self.default_params
        static_env_params = static_env_params or self.default_static_params()

        super().__init__(
            action_type=ActionTypeMultiDiscrete(params, static_env_params),
            observation_space=PixelObservations(params, static_env_params),
            static_env_params=static_env_params,
            **kwargs,
        )

    @property
    def name(self) -> str:
        return "Kinetix-Pixels-MultiDiscrete-v1"


class KinetixSymbolicDiscreteActions(BasePhysicsEnv):
    """Flat symbolic observations with a single categorical action."""

    def __init__(
        self,
        static_env_params: StaticEnvParams | None = None,
        **kwargs,
    ):
        params = self.default_params
        static_env_params = static_env_params or self.default_static_params()

        super().__init__(
            action_type=ActionTypeDiscrete(params, static_env_params),
            observation_space=SymbolicObservations(params, static_env_params),
            static_env_params=static_env_params,
            **kwargs,
        )

    @property
    def name(self) -> str:
        return "Kinetix-Symbolic-Discrete-v1"


class KinetixSymbolicContinuousActions(BasePhysicsEnv):
    """Flat symbolic observations with continuous actions."""

    def __init__(
        self,
        static_env_params: StaticEnvParams | None = None,
        **kwargs,
    ):
        params = self.default_params
        static_env_params = static_env_params or self.default_static_params()

        super().__init__(
            action_type=ActionTypeContinuous(params, static_env_params),
            observation_space=SymbolicObservations(params, static_env_params),
            static_env_params=static_env_params,
            **kwargs,
        )

    @property
    def name(self) -> str:
        return "Kinetix-Symbolic-Continuous-v1"


class KinetixSymbolicMultiDiscreteActions(BasePhysicsEnv):
    """Flat symbolic observations with multi-discrete actions."""

    def __init__(
        self,
        static_env_params: StaticEnvParams | None = None,
        **kwargs,
    ):
        params = self.default_params
        static_env_params = static_env_params or self.default_static_params()

        super().__init__(
            action_type=ActionTypeMultiDiscrete(params, static_env_params),
            observation_space=SymbolicObservations(params, static_env_params),
            static_env_params=static_env_params,
            **kwargs,
        )

    @property
    def name(self) -> str:
        return "Kinetix-Symbolic-MultiDiscrete-v1"


class KinetixEntityDiscreteActions(BasePhysicsEnv):
    """Entity observations with a single categorical action."""

    def __init__(
        self,
        static_env_params: StaticEnvParams | None = None,
        **kwargs,
    ):
        params = self.default_params
        static_env_params = static_env_params or self.default_static_params()

        super().__init__(
            action_type=ActionTypeDiscrete(params, static_env_params),
            observation_space=EntityObservations(params, static_env_params),
            static_env_params=static_env_params,
            **kwargs,
        )

    @property
    def name(self) -> str:
        return "Kinetix-Entity-Discrete-v1"


class KinetixEntityContinuousActions(BasePhysicsEnv):
    """Entity observations with continuous actions."""

    def __init__(
        self,
        static_env_params: StaticEnvParams | None = None,
        **kwargs,
    ):
        params = self.default_params
        static_env_params = static_env_params or self.default_static_params()

        super().__init__(
            action_type=ActionTypeContinuous(params, static_env_params),
            observation_space=EntityObservations(params, static_env_params),
            static_env_params=static_env_params,
            **kwargs,
        )

    @property
    def name(self) -> str:
        return "Kinetix-Entity-Continuous-v1"


class KinetixEntityMultiDiscreteActions(BasePhysicsEnv):
    """Entity observations with multi-discrete actions."""

    def __init__(
        self,
        static_env_params: StaticEnvParams | None = None,
        **kwargs,
    ):
        params = self.default_params
        static_env_params = static_env_params or self.default_static_params()

        super().__init__(
            action_type=ActionTypeMultiDiscrete(params, static_env_params),
            observation_space=EntityObservations(params, static_env_params),
            static_env_params=static_env_params,
            **kwargs,
        )

    @property
    def name(self) -> str:
        return "Kinetix-Entity-MultiDiscrete-v1"


class KinetixBlindDiscreteActions(BasePhysicsEnv):
    """Timestep-only (blind) observations with a single categorical action."""

    def __init__(
        self,
        static_env_params: StaticEnvParams | None = None,
        **kwargs,
    ):
        params = self.default_params
        static_env_params = static_env_params or self.default_static_params()

        super().__init__(
            action_type=ActionTypeDiscrete(params, static_env_params),
            observation_space=BlindObservations(params, static_env_params),
            static_env_params=static_env_params,
            **kwargs,
        )

    @property
    def name(self) -> str:
        return "Kinetix-Blind-Discrete-v1"


class KinetixBlindContinuousActions(BasePhysicsEnv):
    """Timestep-only (blind) observations with continuous actions."""

    def __init__(
        self,
        static_env_params: StaticEnvParams | None = None,
        **kwargs,
    ):
        params = self.default_params
        static_env_params = static_env_params or self.default_static_params()

        super().__init__(
            action_type=ActionTypeContinuous(params, static_env_params),
            observation_space=BlindObservations(params, static_env_params),
            static_env_params=static_env_params,
            **kwargs,
        )

    @property
    def name(self) -> str:
        return "Kinetix-Blind-Continuous-v1"