Spaces:
Runtime error
Runtime error
Upload 46 files
Browse files- kinetix/__init__.py +0 -0
- kinetix/assets/circle.png +0 -0
- kinetix/assets/edit.png +0 -0
- kinetix/assets/fjoint.png +0 -0
- kinetix/assets/fjoint2.png +0 -0
- kinetix/assets/hand.png +0 -0
- kinetix/assets/joint.png +0 -0
- kinetix/assets/play.png +0 -0
- kinetix/assets/rjoint.png +0 -0
- kinetix/assets/rjoint2.png +0 -0
- kinetix/assets/rotate.png +0 -0
- kinetix/assets/square.png +0 -0
- kinetix/assets/thruster.png +0 -0
- kinetix/assets/thruster6.png +0 -0
- kinetix/assets/triangle.png +0 -0
- kinetix/editor.py +0 -0
- kinetix/environment/__init__.py +0 -0
- kinetix/environment/env.py +829 -0
- kinetix/environment/env_state.py +43 -0
- kinetix/environment/ued/distributions.py +349 -0
- kinetix/environment/ued/mutators.py +1157 -0
- kinetix/environment/ued/ued.py +249 -0
- kinetix/environment/ued/ued_state.py +53 -0
- kinetix/environment/ued/util.py +358 -0
- kinetix/environment/utils.py +66 -0
- kinetix/environment/wrappers.py +309 -0
- kinetix/models/.gitignore +2 -0
- kinetix/models/__init__.py +65 -0
- kinetix/models/action_spaces.py +58 -0
- kinetix/models/actor_critic.py +206 -0
- kinetix/models/rel_multi_head.py +546 -0
- kinetix/models/transformer_model.py +302 -0
- kinetix/pcg/__init__.py +0 -0
- kinetix/pcg/pcg.py +97 -0
- kinetix/pcg/pcg_state.py +24 -0
- kinetix/render/__init__.py +0 -0
- kinetix/render/renderer_pixels.py +290 -0
- kinetix/render/renderer_symbolic_common.py +190 -0
- kinetix/render/renderer_symbolic_entity.py +121 -0
- kinetix/render/renderer_symbolic_flat.py +102 -0
- kinetix/render/textures.py +43 -0
- kinetix/util/__init__.py +0 -0
- kinetix/util/config.py +229 -0
- kinetix/util/learning.py +565 -0
- kinetix/util/saving.py +540 -0
- kinetix/util/timing.py +15 -0
kinetix/__init__.py
ADDED
File without changes
|
kinetix/assets/circle.png
ADDED
![]() |
kinetix/assets/edit.png
ADDED
![]() |
kinetix/assets/fjoint.png
ADDED
![]() |
kinetix/assets/fjoint2.png
ADDED
![]() |
kinetix/assets/hand.png
ADDED
![]() |
kinetix/assets/joint.png
ADDED
![]() |
kinetix/assets/play.png
ADDED
![]() |
kinetix/assets/rjoint.png
ADDED
![]() |
kinetix/assets/rjoint2.png
ADDED
![]() |
kinetix/assets/rotate.png
ADDED
![]() |
kinetix/assets/square.png
ADDED
![]() |
kinetix/assets/thruster.png
ADDED
![]() |
kinetix/assets/thruster6.png
ADDED
![]() |
kinetix/assets/triangle.png
ADDED
![]() |
kinetix/editor.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
kinetix/environment/__init__.py
ADDED
File without changes
|
kinetix/environment/env.py
ADDED
@@ -0,0 +1,829 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
from functools import partial
|
3 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
4 |
+
|
5 |
+
import chex
|
6 |
+
import jax
|
7 |
+
import jax.numpy as jnp
|
8 |
+
import numpy as np
|
9 |
+
from chex._src.pytypes import PRNGKey
|
10 |
+
from gymnax.environments import environment, spaces
|
11 |
+
from gymnax.environments.environment import TEnvParams, TEnvState
|
12 |
+
from gymnax.environments.spaces import Space
|
13 |
+
from jax import lax
|
14 |
+
|
15 |
+
from jax2d.engine import PhysicsEngine, create_empty_sim, recalculate_mass_and_inertia
|
16 |
+
from jax2d.sim_state import CollisionManifold, SimState
|
17 |
+
from kinetix.environment.env_state import EnvParams, EnvState, StaticEnvParams
|
18 |
+
from kinetix.environment.wrappers import (
|
19 |
+
AutoReplayWrapper,
|
20 |
+
AutoResetWrapper,
|
21 |
+
UnderspecifiedToGymnaxWrapper,
|
22 |
+
DenseRewardWrapper,
|
23 |
+
LogWrapper,
|
24 |
+
)
|
25 |
+
|
26 |
+
from kinetix.pcg.pcg import env_state_to_pcg_state, sample_pcg_state
|
27 |
+
from kinetix.pcg.pcg_state import PCGState
|
28 |
+
from kinetix.render.renderer_symbolic_entity import make_render_entities
|
29 |
+
from kinetix.render.renderer_pixels import make_render_pixels, make_render_pixels_rl
|
30 |
+
from kinetix.render.renderer_symbolic_flat import make_render_symbolic
|
31 |
+
|
32 |
+
from kinetix.util.saving import load_pcg_state_pickle
|
33 |
+
from jaxued.environments import UnderspecifiedEnv
|
34 |
+
|
35 |
+
|
36 |
+
def create_empty_env(static_env_params):
|
37 |
+
sim_state = create_empty_sim(static_env_params)
|
38 |
+
return EnvState(
|
39 |
+
timestep=0,
|
40 |
+
thruster_bindings=jnp.zeros(static_env_params.num_thrusters, dtype=jnp.int32),
|
41 |
+
motor_bindings=jnp.zeros(static_env_params.num_joints, dtype=jnp.int32),
|
42 |
+
motor_auto=jnp.zeros(static_env_params.num_joints, dtype=bool),
|
43 |
+
polygon_shape_roles=jnp.zeros(static_env_params.num_polygons, dtype=jnp.int32),
|
44 |
+
circle_shape_roles=jnp.zeros(static_env_params.num_circles, dtype=jnp.int32),
|
45 |
+
polygon_highlighted=jnp.zeros(static_env_params.num_polygons, dtype=bool),
|
46 |
+
circle_highlighted=jnp.zeros(static_env_params.num_circles, dtype=bool),
|
47 |
+
polygon_densities=jnp.ones(static_env_params.num_polygons, dtype=jnp.float32),
|
48 |
+
circle_densities=jnp.ones(static_env_params.num_circles, dtype=jnp.float32),
|
49 |
+
**sim_state.__dict__,
|
50 |
+
)
|
51 |
+
|
52 |
+
|
53 |
+
def index_motor_actions(
|
54 |
+
action: jnp.ndarray,
|
55 |
+
state: EnvState,
|
56 |
+
clip_min=None,
|
57 |
+
clip_max=None,
|
58 |
+
):
|
59 |
+
# Expand the motor actions to all joints with the same colour
|
60 |
+
return jnp.clip(action[state.motor_bindings], clip_min, clip_max)
|
61 |
+
|
62 |
+
|
63 |
+
def index_thruster_actions(
|
64 |
+
action: jnp.ndarray,
|
65 |
+
state: EnvState,
|
66 |
+
clip_min=None,
|
67 |
+
clip_max=None,
|
68 |
+
):
|
69 |
+
# Expand the thruster actions to all joints with the same colour
|
70 |
+
return jnp.clip(action[state.thruster_bindings], clip_min, clip_max)
|
71 |
+
|
72 |
+
|
73 |
+
def convert_continuous_actions(
|
74 |
+
action: jnp.ndarray, state: SimState, static_env_params: StaticEnvParams, params: EnvParams
|
75 |
+
):
|
76 |
+
action_motor = action[: static_env_params.num_motor_bindings]
|
77 |
+
action_thruster = action[static_env_params.num_motor_bindings :]
|
78 |
+
action_motor = index_motor_actions(action_motor, state, -1, 1)
|
79 |
+
action_thruster = index_thruster_actions(action_thruster, state, 0, 1)
|
80 |
+
|
81 |
+
action_motor = jnp.where(state.motor_auto, jnp.ones_like(action_motor), action_motor)
|
82 |
+
|
83 |
+
action_to_perform = jnp.concatenate([action_motor, action_thruster], axis=0)
|
84 |
+
return action_to_perform
|
85 |
+
|
86 |
+
|
87 |
+
def convert_discrete_actions(action: int, state: SimState, static_env_params: StaticEnvParams, params: EnvParams):
|
88 |
+
# so, we have
|
89 |
+
# 0 to NJC * 2 - 1: Joint Actions
|
90 |
+
# NJC * 2: No-op
|
91 |
+
# NJC * 2 + 1 to NJC * 2 + 1 + NTC - 1: Thruster Actions
|
92 |
+
# action here is a categorical action
|
93 |
+
which_idx = action // 2
|
94 |
+
which_dir = action % 2
|
95 |
+
actions = (
|
96 |
+
jnp.zeros(static_env_params.num_motor_bindings + static_env_params.num_thruster_bindings)
|
97 |
+
.at[which_idx]
|
98 |
+
.set(which_dir * 2 - 1)
|
99 |
+
)
|
100 |
+
actions = actions * (
|
101 |
+
1 - (action >= static_env_params.num_motor_bindings * 2)
|
102 |
+
) # if action is the last one, set it to zero, i.e., a no-op. Alternatively, if the action is larger than NJC * 2, then it is a thruster action and we shouldn't control the joints.
|
103 |
+
|
104 |
+
actions = jax.lax.select(
|
105 |
+
action > static_env_params.num_motor_bindings * 2,
|
106 |
+
actions.at[action - static_env_params.num_motor_bindings * 2 - 1 + static_env_params.num_motor_bindings].set(1),
|
107 |
+
actions,
|
108 |
+
)
|
109 |
+
|
110 |
+
action_motor = index_motor_actions(actions[: static_env_params.num_motor_bindings], state, -1, 1)
|
111 |
+
action_motor = jnp.where(state.motor_auto, jnp.ones_like(action_motor), action_motor)
|
112 |
+
action_thruster = index_thruster_actions(actions[static_env_params.num_motor_bindings :], state, 0, 1)
|
113 |
+
action_to_perform = jnp.concatenate([action_motor, action_thruster], axis=0)
|
114 |
+
return action_to_perform
|
115 |
+
|
116 |
+
|
117 |
+
def convert_multi_discrete_actions(
|
118 |
+
action: jnp.ndarray, state: SimState, static_env_params: StaticEnvParams, params: EnvParams
|
119 |
+
):
|
120 |
+
# Comes in with each action being in {0,1,2} for joints and {0,1} for thrusters
|
121 |
+
# Convert to [-1., 1.] for joints and [0., 1.] for thrusters
|
122 |
+
|
123 |
+
def _single_motor_action(act):
|
124 |
+
return jax.lax.switch(
|
125 |
+
act,
|
126 |
+
[lambda: 0.0, lambda: 1.0, lambda: -1.0],
|
127 |
+
)
|
128 |
+
|
129 |
+
def _single_thruster_act(act):
|
130 |
+
return jax.lax.select(
|
131 |
+
act == 0,
|
132 |
+
0.0,
|
133 |
+
1.0,
|
134 |
+
)
|
135 |
+
|
136 |
+
action_motor = jax.vmap(_single_motor_action)(action[: static_env_params.num_motor_bindings])
|
137 |
+
action_thruster = jax.vmap(_single_thruster_act)(action[static_env_params.num_motor_bindings :])
|
138 |
+
|
139 |
+
action_motor = index_motor_actions(action_motor, state, -1, 1)
|
140 |
+
action_thruster = index_thruster_actions(action_thruster, state, 0, 1)
|
141 |
+
|
142 |
+
action_motor = jnp.where(state.motor_auto, jnp.ones_like(action_motor), action_motor)
|
143 |
+
|
144 |
+
action_to_perform = jnp.concatenate([action_motor, action_thruster], axis=0)
|
145 |
+
return action_to_perform
|
146 |
+
|
147 |
+
|
148 |
+
def make_kinetix_env_from_args(
|
149 |
+
obs_type, action_type, reset_type, static_env_params=None, auto_reset_fn=None, dense_reward_scale=1.0
|
150 |
+
):
|
151 |
+
if obs_type == "entity":
|
152 |
+
if action_type == "multidiscrete":
|
153 |
+
env = KinetixEntityMultiDiscreteActions(should_do_pcg_reset=True, static_env_params=static_env_params)
|
154 |
+
elif action_type == "discrete":
|
155 |
+
env = KinetixEntityDiscreteActions(should_do_pcg_reset=True, static_env_params=static_env_params)
|
156 |
+
elif action_type == "continuous":
|
157 |
+
env = KinetixEntityContinuousActions(should_do_pcg_reset=True, static_env_params=static_env_params)
|
158 |
+
else:
|
159 |
+
raise ValueError(f"Unknown action type: {action_type}")
|
160 |
+
|
161 |
+
elif obs_type == "symbolic":
|
162 |
+
if action_type == "multidiscrete":
|
163 |
+
env = KinetixSymbolicMultiDiscreteActions(should_do_pcg_reset=True, static_env_params=static_env_params)
|
164 |
+
elif action_type == "discrete":
|
165 |
+
env = KinetixSymbolicDiscreteActions(should_do_pcg_reset=True, static_env_params=static_env_params)
|
166 |
+
elif action_type == "continuous":
|
167 |
+
env = KinetixSymbolicContinuousActions(should_do_pcg_reset=True, static_env_params=static_env_params)
|
168 |
+
else:
|
169 |
+
raise ValueError(f"Unknown action type: {action_type}")
|
170 |
+
|
171 |
+
elif obs_type == "pixels":
|
172 |
+
if action_type == "multidiscrete":
|
173 |
+
env = KinetixPixelsMultiDiscreteActions(should_do_pcg_reset=True, static_env_params=static_env_params)
|
174 |
+
elif action_type == "discrete":
|
175 |
+
env = KinetixPixelsDiscreteActions(should_do_pcg_reset=True, static_env_params=static_env_params)
|
176 |
+
elif action_type == "continuous":
|
177 |
+
env = KinetixPixelsContinuousActions(should_do_pcg_reset=True, static_env_params=static_env_params)
|
178 |
+
else:
|
179 |
+
raise ValueError(f"Unknown action type: {action_type}")
|
180 |
+
|
181 |
+
elif obs_type == "blind":
|
182 |
+
if action_type == "discrete":
|
183 |
+
env = KinetixBlindDiscreteActions(should_do_pcg_reset=True, static_env_params=static_env_params)
|
184 |
+
elif action_type == "continuous":
|
185 |
+
env = KinetixBlindContinuousActions(should_do_pcg_reset=True, static_env_params=static_env_params)
|
186 |
+
else:
|
187 |
+
raise ValueError(f"Unknown action type: {action_type}")
|
188 |
+
|
189 |
+
else:
|
190 |
+
raise ValueError(f"Unknown observation type: {obs_type}")
|
191 |
+
|
192 |
+
# Wrap
|
193 |
+
if reset_type == "replay":
|
194 |
+
env = AutoReplayWrapper(env)
|
195 |
+
elif reset_type == "reset":
|
196 |
+
env = AutoResetWrapper(env, sample_level=auto_reset_fn)
|
197 |
+
else:
|
198 |
+
raise ValueError(f"Unknown reset type {reset_type}")
|
199 |
+
|
200 |
+
env = UnderspecifiedToGymnaxWrapper(env)
|
201 |
+
env = DenseRewardWrapper(env, dense_reward_scale=dense_reward_scale)
|
202 |
+
env = LogWrapper(env)
|
203 |
+
|
204 |
+
return env
|
205 |
+
|
206 |
+
|
207 |
+
def make_kinetix_env_from_name(name, static_env_params=None):
|
208 |
+
kwargs = dict(filename_to_use_for_reset=None, should_do_pcg_reset=True, static_env_params=static_env_params)
|
209 |
+
values = {
|
210 |
+
"Kinetix-Pixels-MultiDiscrete-v1": KinetixPixelsMultiDiscreteActions,
|
211 |
+
"Kinetix-Pixels-Discrete-v1": KinetixPixelsDiscreteActions,
|
212 |
+
"Kinetix-Pixels-Continuous-v1": KinetixPixelsContinuousActions,
|
213 |
+
"Kinetix-Symbolic-MultiDiscrete-v1": KinetixSymbolicMultiDiscreteActions,
|
214 |
+
"Kinetix-Symbolic-Discrete-v1": KinetixSymbolicDiscreteActions,
|
215 |
+
"Kinetix-Symbolic-Continuous-v1": KinetixSymbolicContinuousActions,
|
216 |
+
"Kinetix-Blind-Discrete-v1": KinetixBlindDiscreteActions,
|
217 |
+
"Kinetix-Blind-Continuous-v1": KinetixBlindContinuousActions,
|
218 |
+
"Kinetix-Entity-Discrete-v1": KinetixEntityDiscreteActions,
|
219 |
+
"Kinetix-Entity-Continuous-v1": KinetixEntityContinuousActions,
|
220 |
+
"Kinetix-Entity-MultiDiscrete-v1": KinetixEntityMultiDiscreteActions,
|
221 |
+
}
|
222 |
+
|
223 |
+
return values[name](**kwargs)
|
224 |
+
|
225 |
+
|
226 |
+
class ObservationSpace:
|
227 |
+
def __init__(self, params: EnvParams, static_env_params: StaticEnvParams):
|
228 |
+
pass
|
229 |
+
|
230 |
+
def get_obs(self, state: EnvState):
|
231 |
+
raise NotImplementedError()
|
232 |
+
|
233 |
+
def observation_space(self, params: EnvParams):
|
234 |
+
raise NotImplementedError()
|
235 |
+
|
236 |
+
|
237 |
+
class PixelObservations(ObservationSpace):
|
238 |
+
def __init__(self, params: EnvParams, static_env_params: StaticEnvParams):
|
239 |
+
self.render_function = make_render_pixels_rl(params, static_env_params)
|
240 |
+
self.static_env_params = static_env_params
|
241 |
+
|
242 |
+
def get_obs(self, state: EnvState):
|
243 |
+
return self.render_function(state)
|
244 |
+
|
245 |
+
def observation_space(self, params: EnvParams) -> spaces.Box:
|
246 |
+
return spaces.Box(
|
247 |
+
0.0,
|
248 |
+
1.0,
|
249 |
+
tuple(a // self.static_env_params.downscale for a in self.static_env_params.screen_dim) + (3,),
|
250 |
+
dtype=jnp.float32,
|
251 |
+
)
|
252 |
+
|
253 |
+
|
254 |
+
class SymbolicObservations(ObservationSpace):
|
255 |
+
def __init__(self, params: EnvParams, static_env_params: StaticEnvParams):
|
256 |
+
self.render_function = make_render_symbolic(params, static_env_params)
|
257 |
+
|
258 |
+
def get_obs(self, state: EnvState):
|
259 |
+
return self.render_function(state)
|
260 |
+
|
261 |
+
|
262 |
+
class EntityObservations(ObservationSpace):
|
263 |
+
def __init__(self, params: EnvParams, static_env_params: StaticEnvParams):
|
264 |
+
self.render_function = make_render_entities(params, static_env_params)
|
265 |
+
|
266 |
+
def get_obs(self, state: EnvState):
|
267 |
+
return self.render_function(state)
|
268 |
+
|
269 |
+
|
270 |
+
class BlindObservations(ObservationSpace):
|
271 |
+
def __init__(self, params: EnvParams, static_env_params: StaticEnvParams):
|
272 |
+
self.params = params
|
273 |
+
|
274 |
+
def get_obs(self, state: EnvState):
|
275 |
+
return jax.nn.one_hot(state.timestep, self.params.max_timesteps + 1)
|
276 |
+
|
277 |
+
|
278 |
+
def get_observation_space_from_name(name: str, params, static_env_params):
|
279 |
+
if "Pixels" in name:
|
280 |
+
return PixelObservations(params, static_env_params)
|
281 |
+
elif "Symbolic" in name:
|
282 |
+
return SymbolicObservations(params, static_env_params)
|
283 |
+
elif "Entity" in name:
|
284 |
+
return EntityObservations(params, static_env_params)
|
285 |
+
if "Blind" in name:
|
286 |
+
return BlindObservations(params, static_env_params)
|
287 |
+
else:
|
288 |
+
raise ValueError(f"Unknown name {name}")
|
289 |
+
|
290 |
+
|
291 |
+
class ActionType:
|
292 |
+
def __init__(self, params: EnvParams, static_env_params: StaticEnvParams):
|
293 |
+
# This is the processed, unified action space size that is shared with all action types
|
294 |
+
# 1 dim per motor and thruster
|
295 |
+
self.unified_action_space_size = static_env_params.num_motor_bindings + static_env_params.num_thruster_bindings
|
296 |
+
|
297 |
+
def action_space(self, params: Optional[EnvParams] = None) -> Union[spaces.Discrete, spaces.Box]:
|
298 |
+
raise NotImplementedError()
|
299 |
+
|
300 |
+
def process_action(self, action: jnp.ndarray, state: EnvState, static_env_params: StaticEnvParams) -> jnp.ndarray:
|
301 |
+
raise NotImplementedError()
|
302 |
+
|
303 |
+
def noop_action(self) -> jnp.ndarray:
|
304 |
+
raise NotImplementedError()
|
305 |
+
|
306 |
+
def random_action(self, rng: chex.PRNGKey):
|
307 |
+
raise NotImplementedError()
|
308 |
+
|
309 |
+
|
310 |
+
class ActionTypeContinuous(ActionType):
|
311 |
+
def __init__(self, params: EnvParams, static_env_params: StaticEnvParams):
|
312 |
+
super().__init__(params, static_env_params)
|
313 |
+
|
314 |
+
self.params = params
|
315 |
+
self.static_env_params = static_env_params
|
316 |
+
|
317 |
+
def action_space(self, params: EnvParams | None = None) -> spaces.Discrete | spaces.Box:
|
318 |
+
return spaces.Box(
|
319 |
+
low=jnp.ones(self.unified_action_space_size) * -1.0,
|
320 |
+
high=jnp.ones(self.unified_action_space_size) * 1.0,
|
321 |
+
shape=(self.unified_action_space_size,),
|
322 |
+
)
|
323 |
+
|
324 |
+
def process_action(self, action: PRNGKey, state: EnvState, static_env_params: StaticEnvParams) -> PRNGKey:
|
325 |
+
return convert_continuous_actions(action, state, static_env_params, self.params)
|
326 |
+
|
327 |
+
def noop_action(self) -> jnp.ndarray:
|
328 |
+
return jnp.zeros(self.unified_action_space_size, dtype=jnp.float32)
|
329 |
+
|
330 |
+
def random_action(self, rng: chex.PRNGKey) -> jnp.ndarray:
|
331 |
+
actions = jax.random.uniform(rng, shape=(self.unified_action_space_size,), minval=-1.0, maxval=1.0)
|
332 |
+
# Motors between -1 and 1, thrusters between 0 and 1
|
333 |
+
actions = actions.at[self.static_env_params.num_motor_bindings :].set(
|
334 |
+
jnp.abs(actions[self.static_env_params.num_motor_bindings :])
|
335 |
+
)
|
336 |
+
|
337 |
+
return actions
|
338 |
+
|
339 |
+
|
340 |
+
class ActionTypeDiscrete(ActionType):
|
341 |
+
def __init__(self, params: EnvParams, static_env_params: StaticEnvParams):
|
342 |
+
super().__init__(params, static_env_params)
|
343 |
+
|
344 |
+
self.params = params
|
345 |
+
self.static_env_params = static_env_params
|
346 |
+
|
347 |
+
self._n_actions = (
|
348 |
+
self.static_env_params.num_motor_bindings * 2 + 1 + self.static_env_params.num_thruster_bindings
|
349 |
+
)
|
350 |
+
|
351 |
+
def action_space(self, params: Optional[EnvParams] = None) -> spaces.Discrete:
|
352 |
+
return spaces.Discrete(self._n_actions)
|
353 |
+
|
354 |
+
def process_action(self, action: jnp.ndarray, state: EnvState, static_env_params: StaticEnvParams) -> jnp.ndarray:
|
355 |
+
return convert_discrete_actions(action, state, static_env_params, self.params)
|
356 |
+
|
357 |
+
def noop_action(self) -> int:
|
358 |
+
return self.static_env_params.num_motor_bindings * 2
|
359 |
+
|
360 |
+
def random_action(self, rng: chex.PRNGKey):
|
361 |
+
return jax.random.randint(rng, shape=(), minval=0, maxval=self._n_actions)
|
362 |
+
|
363 |
+
|
364 |
+
class MultiDiscrete(Space):
|
365 |
+
def __init__(self, n, number_of_dims_per_distribution):
|
366 |
+
self.number_of_dims_per_distribution = number_of_dims_per_distribution
|
367 |
+
self.n = n
|
368 |
+
self.shape = (number_of_dims_per_distribution.shape[0],)
|
369 |
+
self.dtype = jnp.int_
|
370 |
+
|
371 |
+
def sample(self, rng: chex.PRNGKey) -> chex.Array:
|
372 |
+
"""Sample random action uniformly from set of categorical choices."""
|
373 |
+
uniform_sample = jax.random.uniform(rng, shape=self.shape) * self.number_of_dims_per_distribution
|
374 |
+
md_dist = jnp.floor(uniform_sample)
|
375 |
+
return md_dist.astype(self.dtype)
|
376 |
+
|
377 |
+
def contains(self, x) -> jnp.ndarray:
|
378 |
+
"""Check whether specific object is within space."""
|
379 |
+
range_cond = jnp.logical_and(x >= 0, (x < self.number_of_dims_per_distribution).all())
|
380 |
+
return range_cond
|
381 |
+
|
382 |
+
|
383 |
+
class ActionTypeMultiDiscrete(ActionType):
|
384 |
+
def __init__(self, params: EnvParams, static_env_params: StaticEnvParams):
|
385 |
+
super().__init__(params, static_env_params)
|
386 |
+
|
387 |
+
self.params = params
|
388 |
+
self.static_env_params = static_env_params
|
389 |
+
# This is the action space that will be used internally by an agent
|
390 |
+
# 3 dims per motor (foward, backward, off) and 2 per thruster (on, off)
|
391 |
+
self.n_hot_action_space_size = (
|
392 |
+
self.static_env_params.num_motor_bindings * 3 + self.static_env_params.num_thruster_bindings * 2
|
393 |
+
)
|
394 |
+
|
395 |
+
def _make_sample_random():
|
396 |
+
minval = jnp.zeros(self.unified_action_space_size, dtype=jnp.int32)
|
397 |
+
maxval = jnp.ones(self.unified_action_space_size, dtype=jnp.int32) * 3
|
398 |
+
maxval = maxval.at[self.static_env_params.num_motor_bindings :].set(2)
|
399 |
+
|
400 |
+
def random(rng):
|
401 |
+
return jax.random.randint(rng, shape=(self.unified_action_space_size,), minval=minval, maxval=maxval)
|
402 |
+
|
403 |
+
return random
|
404 |
+
|
405 |
+
self._random = _make_sample_random
|
406 |
+
|
407 |
+
self.number_of_dims_per_distribution = jnp.concatenate(
|
408 |
+
[
|
409 |
+
np.ones(self.static_env_params.num_motor_bindings) * 3,
|
410 |
+
np.ones(self.static_env_params.num_thruster_bindings) * 2,
|
411 |
+
]
|
412 |
+
).astype(np.int32)
|
413 |
+
|
414 |
+
def action_space(self, params: Optional[EnvParams] = None) -> MultiDiscrete:
|
415 |
+
return MultiDiscrete(self.n_hot_action_space_size, self.number_of_dims_per_distribution)
|
416 |
+
|
417 |
+
def process_action(self, action: jnp.ndarray, state: EnvState, static_env_params: StaticEnvParams) -> jnp.ndarray:
|
418 |
+
return convert_multi_discrete_actions(action, state, static_env_params, self.params)
|
419 |
+
|
420 |
+
def noop_action(self):
|
421 |
+
return jnp.zeros(self.unified_action_space_size, dtype=jnp.int32)
|
422 |
+
|
423 |
+
def random_action(self, rng: chex.PRNGKey):
|
424 |
+
return self._random()(rng)
|
425 |
+
|
426 |
+
|
427 |
+
class BasePhysicsEnv(UnderspecifiedEnv):
|
428 |
+
def __init__(
|
429 |
+
self,
|
430 |
+
action_type: ActionType,
|
431 |
+
observation_space: ObservationSpace,
|
432 |
+
static_env_params: StaticEnvParams,
|
433 |
+
target_index: int = 0,
|
434 |
+
filename_to_use_for_reset=None, # "worlds/games/bipedal_v1",
|
435 |
+
should_do_pcg_reset: bool = False,
|
436 |
+
):
|
437 |
+
super().__init__()
|
438 |
+
self.target_index = target_index
|
439 |
+
self.static_env_params = static_env_params
|
440 |
+
self.action_type = action_type
|
441 |
+
self._observation_space = observation_space
|
442 |
+
self.physics_engine = PhysicsEngine(self.static_env_params)
|
443 |
+
self.should_do_pcg_reset = should_do_pcg_reset
|
444 |
+
|
445 |
+
self.filename_to_use_for_reset = filename_to_use_for_reset
|
446 |
+
if self.filename_to_use_for_reset is not None:
|
447 |
+
self.reset_state = load_pcg_state_pickle(filename_to_use_for_reset)
|
448 |
+
else:
|
449 |
+
env_state = create_empty_env(self.static_env_params)
|
450 |
+
self.reset_state = env_state_to_pcg_state(env_state)
|
451 |
+
|
452 |
+
# Action / Observations
|
453 |
+
def action_space(self, params: Optional[EnvParams] = None) -> Union[spaces.Discrete, spaces.Box]:
|
454 |
+
return self.action_type.action_space(params)
|
455 |
+
|
456 |
+
def observation_space(self, params: Any):
|
457 |
+
return self._observation_space.observation_space(params)
|
458 |
+
|
459 |
+
def get_obs(self, state: EnvState):
|
460 |
+
return self._observation_space.get_obs(state)
|
461 |
+
|
462 |
+
def step_env(self, rng, state, action: jnp.ndarray, params):
|
463 |
+
action_processed = self.action_type.process_action(action, state, self.static_env_params)
|
464 |
+
return self.engine_step(state, action_processed, params)
|
465 |
+
|
466 |
+
def reset_env(self, rng, params):
|
467 |
+
# Wrap in AutoResetWrapper or AutoReplayWrapper
|
468 |
+
raise NotImplementedError()
|
469 |
+
|
470 |
+
def reset_env_to_level(self, rng, state: EnvState, params):
|
471 |
+
if isinstance(state, PCGState):
|
472 |
+
return self.reset_env_to_pcg_level(rng, state, params)
|
473 |
+
return self.get_obs(state), state
|
474 |
+
|
475 |
+
def reset_env_to_pcg_level(self, rng, state: PCGState, params):
|
476 |
+
env_state = sample_pcg_state(rng, state, params, self.static_env_params)
|
477 |
+
return self.get_obs(env_state), env_state
|
478 |
+
|
479 |
+
@property
|
480 |
+
def default_params(self) -> EnvParams:
|
481 |
+
return EnvParams()
|
482 |
+
|
483 |
+
@staticmethod
|
484 |
+
def default_static_params() -> StaticEnvParams:
|
485 |
+
return StaticEnvParams()
|
486 |
+
|
487 |
+
def compute_reward_info(
|
488 |
+
self, state: EnvState, manifolds: tuple[CollisionManifold, CollisionManifold, CollisionManifold]
|
489 |
+
) -> float:
|
490 |
+
def get_active(manifold: CollisionManifold) -> jnp.ndarray:
|
491 |
+
return manifold.active
|
492 |
+
|
493 |
+
def dist(a, b):
|
494 |
+
return jnp.linalg.norm(a - b)
|
495 |
+
|
496 |
+
@jax.vmap
|
497 |
+
def dist_rr(idxa, idxb):
|
498 |
+
return dist(state.polygon.position[idxa], state.polygon.position[idxb])
|
499 |
+
|
500 |
+
@jax.vmap
|
501 |
+
def dist_cc(idxa, idxb):
|
502 |
+
return dist(state.circle.position[idxa], state.circle.position[idxb])
|
503 |
+
|
504 |
+
@jax.vmap
|
505 |
+
def dist_cr(idxa, idxb):
|
506 |
+
return dist(state.circle.position[idxa], state.polygon.position[idxb])
|
507 |
+
|
508 |
+
info = {
|
509 |
+
"GoalR": False,
|
510 |
+
}
|
511 |
+
negative_reward = 0
|
512 |
+
reward = 0
|
513 |
+
distance = 0
|
514 |
+
rs = state.polygon_shape_roles * state.polygon.active
|
515 |
+
cs = state.circle_shape_roles * state.circle.active
|
516 |
+
|
517 |
+
# Polygon Polygon
|
518 |
+
r1 = rs[self.physics_engine.poly_poly_pairs[:, 0]]
|
519 |
+
r2 = rs[self.physics_engine.poly_poly_pairs[:, 1]]
|
520 |
+
reward += ((r1 * r2 == 2) * get_active(manifolds[0])).sum()
|
521 |
+
negative_reward += ((r1 * r2 == 3) * get_active(manifolds[0])).sum()
|
522 |
+
|
523 |
+
distance += (
|
524 |
+
(r1 * r2 == 2)
|
525 |
+
* dist_rr(self.physics_engine.poly_poly_pairs[:, 0], self.physics_engine.poly_poly_pairs[:, 1])
|
526 |
+
).sum()
|
527 |
+
|
528 |
+
# Circle Polygon
|
529 |
+
c1 = cs[self.physics_engine.circle_poly_pairs[:, 0]]
|
530 |
+
r2 = rs[self.physics_engine.circle_poly_pairs[:, 1]]
|
531 |
+
reward += ((c1 * r2 == 2) * get_active(manifolds[1])).sum()
|
532 |
+
negative_reward += ((c1 * r2 == 3) * get_active(manifolds[1])).sum()
|
533 |
+
|
534 |
+
t = dist_cr(self.physics_engine.circle_poly_pairs[:, 0], self.physics_engine.circle_poly_pairs[:, 1])
|
535 |
+
distance += ((c1 * r2 == 2) * t).sum()
|
536 |
+
|
537 |
+
# Circle Circle
|
538 |
+
c1 = cs[self.physics_engine.circle_circle_pairs[:, 0]]
|
539 |
+
c2 = cs[self.physics_engine.circle_circle_pairs[:, 1]]
|
540 |
+
reward += ((c1 * c2 == 2) * get_active(manifolds[2])).sum()
|
541 |
+
negative_reward += ((c1 * c2 == 3) * get_active(manifolds[2])).sum()
|
542 |
+
|
543 |
+
distance += (
|
544 |
+
(c1 * c2 == 2)
|
545 |
+
* dist_cc(self.physics_engine.circle_circle_pairs[:, 0], self.physics_engine.circle_circle_pairs[:, 1])
|
546 |
+
).sum()
|
547 |
+
|
548 |
+
reward = jax.lax.select(
|
549 |
+
negative_reward > 0,
|
550 |
+
-1.0,
|
551 |
+
jax.lax.select(
|
552 |
+
reward > 0,
|
553 |
+
1.0,
|
554 |
+
0.0,
|
555 |
+
),
|
556 |
+
)
|
557 |
+
|
558 |
+
info["GoalR"] = reward > 0
|
559 |
+
info["distance"] = distance
|
560 |
+
return reward, info
|
561 |
+
|
562 |
+
@partial(jax.jit, static_argnums=(0,))
|
563 |
+
def engine_step(self, env_state, action_to_perform, env_params):
|
564 |
+
def _single_step(env_state, unused):
|
565 |
+
env_state, mfolds = self.physics_engine.step(
|
566 |
+
env_state,
|
567 |
+
env_params,
|
568 |
+
action_to_perform,
|
569 |
+
)
|
570 |
+
|
571 |
+
reward, info = self.compute_reward_info(env_state, mfolds)
|
572 |
+
|
573 |
+
done = reward != 0
|
574 |
+
|
575 |
+
info = {"rr_manifolds": None, "cr_manifolds": None} | info
|
576 |
+
|
577 |
+
return env_state, (reward, done, info)
|
578 |
+
|
579 |
+
env_state, (rewards, dones, infos) = jax.lax.scan(
|
580 |
+
_single_step, env_state, xs=None, length=self.static_env_params.frame_skip
|
581 |
+
)
|
582 |
+
env_state = env_state.replace(timestep=env_state.timestep + 1)
|
583 |
+
|
584 |
+
reward = rewards.max()
|
585 |
+
done = dones.sum() > 0 | jax.tree.reduce(
|
586 |
+
jnp.logical_or, jax.tree.map(lambda x: jnp.isnan(x).any(), env_state), False
|
587 |
+
)
|
588 |
+
done |= env_state.timestep >= env_params.max_timesteps
|
589 |
+
|
590 |
+
info = jax.tree.map(lambda x: x[-1], infos)
|
591 |
+
|
592 |
+
return (
|
593 |
+
lax.stop_gradient(self.get_obs(env_state)),
|
594 |
+
lax.stop_gradient(env_state),
|
595 |
+
reward,
|
596 |
+
done,
|
597 |
+
info,
|
598 |
+
)
|
599 |
+
|
600 |
+
@functools.partial(jax.jit, static_argnums=(0,))
|
601 |
+
def step(
|
602 |
+
self,
|
603 |
+
key: chex.PRNGKey,
|
604 |
+
state: TEnvState,
|
605 |
+
action: Union[int, float, chex.Array],
|
606 |
+
params: Optional[TEnvParams] = None,
|
607 |
+
) -> Tuple[chex.Array, TEnvState, jnp.ndarray, jnp.ndarray, Dict[Any, Any]]:
|
608 |
+
raise NotImplementedError()
|
609 |
+
|
610 |
+
|
611 |
+
class KinetixPixelsDiscreteActions(BasePhysicsEnv):
|
612 |
+
def __init__(
|
613 |
+
self,
|
614 |
+
static_env_params: StaticEnvParams | None = None,
|
615 |
+
**kwargs,
|
616 |
+
):
|
617 |
+
|
618 |
+
params = self.default_params
|
619 |
+
static_env_params = static_env_params or self.default_static_params()
|
620 |
+
super().__init__(
|
621 |
+
action_type=ActionTypeDiscrete(params, static_env_params),
|
622 |
+
observation_space=PixelObservations(params, static_env_params),
|
623 |
+
static_env_params=static_env_params,
|
624 |
+
**kwargs,
|
625 |
+
)
|
626 |
+
|
627 |
+
@property
|
628 |
+
def name(self) -> str:
|
629 |
+
return "Kinetix-Pixels-Discrete-v1"
|
630 |
+
|
631 |
+
|
632 |
+
class KinetixPixelsContinuousActions(BasePhysicsEnv):
|
633 |
+
def __init__(
|
634 |
+
self,
|
635 |
+
static_env_params: StaticEnvParams | None = None,
|
636 |
+
**kwargs,
|
637 |
+
):
|
638 |
+
params = self.default_params
|
639 |
+
static_env_params = static_env_params or self.default_static_params()
|
640 |
+
super().__init__(
|
641 |
+
action_type=ActionTypeContinuous(params, static_env_params),
|
642 |
+
observation_space=PixelObservations(params, static_env_params),
|
643 |
+
static_env_params=static_env_params,
|
644 |
+
**kwargs,
|
645 |
+
)
|
646 |
+
|
647 |
+
@property
|
648 |
+
def name(self) -> str:
|
649 |
+
return "Kinetix-Pixels-Continuous-v1"
|
650 |
+
|
651 |
+
|
652 |
+
class KinetixPixelsMultiDiscreteActions(BasePhysicsEnv):
|
653 |
+
def __init__(
|
654 |
+
self,
|
655 |
+
static_env_params: StaticEnvParams | None = None,
|
656 |
+
**kwargs,
|
657 |
+
):
|
658 |
+
params = self.default_params
|
659 |
+
static_env_params = static_env_params or self.default_static_params()
|
660 |
+
super().__init__(
|
661 |
+
action_type=ActionTypeMultiDiscrete(params, static_env_params),
|
662 |
+
observation_space=PixelObservations(params, static_env_params),
|
663 |
+
static_env_params=static_env_params,
|
664 |
+
**kwargs,
|
665 |
+
)
|
666 |
+
|
667 |
+
@property
|
668 |
+
def name(self) -> str:
|
669 |
+
return "Kinetix-Pixels-MultiDiscrete-v1"
|
670 |
+
|
671 |
+
|
672 |
+
class KinetixSymbolicDiscreteActions(BasePhysicsEnv):
|
673 |
+
def __init__(
|
674 |
+
self,
|
675 |
+
static_env_params: StaticEnvParams | None = None,
|
676 |
+
**kwargs,
|
677 |
+
):
|
678 |
+
params = self.default_params
|
679 |
+
static_env_params = static_env_params or self.default_static_params()
|
680 |
+
super().__init__(
|
681 |
+
action_type=ActionTypeDiscrete(params, static_env_params),
|
682 |
+
observation_space=SymbolicObservations(params, static_env_params),
|
683 |
+
static_env_params=static_env_params,
|
684 |
+
**kwargs,
|
685 |
+
)
|
686 |
+
|
687 |
+
@property
|
688 |
+
def name(self) -> str:
|
689 |
+
return "Kinetix-Symbolic-Discrete-v1"
|
690 |
+
|
691 |
+
|
692 |
+
class KinetixSymbolicContinuousActions(BasePhysicsEnv):
|
693 |
+
def __init__(
|
694 |
+
self,
|
695 |
+
static_env_params: StaticEnvParams | None = None,
|
696 |
+
**kwargs,
|
697 |
+
):
|
698 |
+
params = self.default_params
|
699 |
+
static_env_params = static_env_params or self.default_static_params()
|
700 |
+
super().__init__(
|
701 |
+
action_type=ActionTypeContinuous(params, static_env_params),
|
702 |
+
observation_space=SymbolicObservations(params, static_env_params),
|
703 |
+
static_env_params=static_env_params,
|
704 |
+
**kwargs,
|
705 |
+
)
|
706 |
+
|
707 |
+
@property
|
708 |
+
def name(self) -> str:
|
709 |
+
return "Kinetix-Symbolic-Continuous-v1"
|
710 |
+
|
711 |
+
|
712 |
+
class KinetixSymbolicMultiDiscreteActions(BasePhysicsEnv):
|
713 |
+
def __init__(
|
714 |
+
self,
|
715 |
+
static_env_params: StaticEnvParams | None = None,
|
716 |
+
**kwargs,
|
717 |
+
):
|
718 |
+
params = self.default_params
|
719 |
+
static_env_params = static_env_params or self.default_static_params()
|
720 |
+
super().__init__(
|
721 |
+
action_type=ActionTypeMultiDiscrete(params, static_env_params),
|
722 |
+
observation_space=SymbolicObservations(params, static_env_params),
|
723 |
+
static_env_params=static_env_params,
|
724 |
+
**kwargs,
|
725 |
+
)
|
726 |
+
|
727 |
+
@property
|
728 |
+
def name(self) -> str:
|
729 |
+
return "Kinetix-Symbolic-MultiDiscrete-v1"
|
730 |
+
|
731 |
+
|
732 |
+
class KinetixEntityDiscreteActions(BasePhysicsEnv):
|
733 |
+
def __init__(
|
734 |
+
self,
|
735 |
+
static_env_params: StaticEnvParams | None = None,
|
736 |
+
**kwargs,
|
737 |
+
):
|
738 |
+
params = self.default_params
|
739 |
+
static_env_params = static_env_params or self.default_static_params()
|
740 |
+
super().__init__(
|
741 |
+
action_type=ActionTypeDiscrete(params, static_env_params),
|
742 |
+
observation_space=EntityObservations(params, static_env_params),
|
743 |
+
static_env_params=static_env_params,
|
744 |
+
**kwargs,
|
745 |
+
)
|
746 |
+
|
747 |
+
@property
|
748 |
+
def name(self) -> str:
|
749 |
+
return "Kinetix-Entity-Discrete-v1"
|
750 |
+
|
751 |
+
|
752 |
+
class KinetixEntityContinuousActions(BasePhysicsEnv):
|
753 |
+
def __init__(
|
754 |
+
self,
|
755 |
+
static_env_params: StaticEnvParams | None = None,
|
756 |
+
**kwargs,
|
757 |
+
):
|
758 |
+
params = self.default_params
|
759 |
+
static_env_params = static_env_params or self.default_static_params()
|
760 |
+
super().__init__(
|
761 |
+
action_type=ActionTypeContinuous(params, static_env_params),
|
762 |
+
observation_space=EntityObservations(params, static_env_params),
|
763 |
+
static_env_params=static_env_params,
|
764 |
+
**kwargs,
|
765 |
+
)
|
766 |
+
|
767 |
+
@property
|
768 |
+
def name(self) -> str:
|
769 |
+
return "Kinetix-Entity-Continuous-v1"
|
770 |
+
|
771 |
+
|
772 |
+
class KinetixEntityMultiDiscreteActions(BasePhysicsEnv):
|
773 |
+
def __init__(
|
774 |
+
self,
|
775 |
+
static_env_params: StaticEnvParams | None = None,
|
776 |
+
**kwargs,
|
777 |
+
):
|
778 |
+
params = self.default_params
|
779 |
+
static_env_params = static_env_params or self.default_static_params()
|
780 |
+
super().__init__(
|
781 |
+
action_type=ActionTypeMultiDiscrete(params, static_env_params),
|
782 |
+
observation_space=EntityObservations(params, static_env_params),
|
783 |
+
static_env_params=static_env_params,
|
784 |
+
**kwargs,
|
785 |
+
)
|
786 |
+
|
787 |
+
@property
|
788 |
+
def name(self) -> str:
|
789 |
+
return "Kinetix-Entity-MultiDiscrete-v1"
|
790 |
+
|
791 |
+
|
792 |
+
class KinetixBlindDiscreteActions(BasePhysicsEnv):
|
793 |
+
def __init__(
|
794 |
+
self,
|
795 |
+
static_env_params: StaticEnvParams | None = None,
|
796 |
+
**kwargs,
|
797 |
+
):
|
798 |
+
params = self.default_params
|
799 |
+
static_env_params = static_env_params or self.default_static_params()
|
800 |
+
super().__init__(
|
801 |
+
action_type=ActionTypeDiscrete(params, static_env_params),
|
802 |
+
observation_space=BlindObservations(params, static_env_params),
|
803 |
+
static_env_params=static_env_params,
|
804 |
+
**kwargs,
|
805 |
+
)
|
806 |
+
|
807 |
+
@property
|
808 |
+
def name(self) -> str:
|
809 |
+
return "Kinetix-Blind-Discrete-v1"
|
810 |
+
|
811 |
+
|
812 |
+
class KinetixBlindContinuousActions(BasePhysicsEnv):
|
813 |
+
def __init__(
|
814 |
+
self,
|
815 |
+
static_env_params: StaticEnvParams | None = None,
|
816 |
+
**kwargs,
|
817 |
+
):
|
818 |
+
params = self.default_params
|
819 |
+
static_env_params = static_env_params or self.default_static_params()
|
820 |
+
super().__init__(
|
821 |
+
action_type=ActionTypeContinuous(params, static_env_params),
|
822 |
+
observation_space=BlindObservations(params, static_env_params),
|
823 |
+
static_env_params=static_env_params,
|
824 |
+
**kwargs,
|
825 |
+
)
|
826 |
+
|
827 |
+
@property
|
828 |
+
def name(self) -> str:
|
829 |
+
return "Kinetix-Blind-Continuous-v1"
|
kinetix/environment/env_state.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import field
|
2 |
+
import jax.numpy as jnp
|
3 |
+
from flax import struct
|
4 |
+
|
5 |
+
from jax2d.sim_state import SimState, SimParams, StaticSimParams
|
6 |
+
|
7 |
+
|
8 |
+
@struct.dataclass
|
9 |
+
class EnvState(SimState):
|
10 |
+
thruster_bindings: jnp.ndarray
|
11 |
+
motor_bindings: jnp.ndarray
|
12 |
+
motor_auto: jnp.ndarray
|
13 |
+
|
14 |
+
polygon_shape_roles: jnp.ndarray
|
15 |
+
circle_shape_roles: jnp.ndarray
|
16 |
+
|
17 |
+
polygon_highlighted: jnp.ndarray
|
18 |
+
circle_highlighted: jnp.ndarray
|
19 |
+
|
20 |
+
polygon_densities: jnp.ndarray
|
21 |
+
circle_densities: jnp.ndarray
|
22 |
+
|
23 |
+
timestep: int = 0
|
24 |
+
|
25 |
+
|
26 |
+
@struct.dataclass
|
27 |
+
class EnvParams(SimParams):
|
28 |
+
max_timesteps: int = 256
|
29 |
+
pixels_per_unit: int = 100
|
30 |
+
dense_reward_scale: float = 0.1
|
31 |
+
num_shape_roles: int = 4
|
32 |
+
|
33 |
+
|
34 |
+
@struct.dataclass
|
35 |
+
class StaticEnvParams(StaticSimParams):
|
36 |
+
screen_dim: tuple[int, int] = (500, 500)
|
37 |
+
downscale: int = 4
|
38 |
+
|
39 |
+
frame_skip: int = 1
|
40 |
+
max_shape_size: int = 2
|
41 |
+
|
42 |
+
num_motor_bindings: int = 4
|
43 |
+
num_thruster_bindings: int = 2
|
kinetix/environment/ued/distributions.py
ADDED
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
import math
|
3 |
+
|
4 |
+
import chex
|
5 |
+
import jax
|
6 |
+
import jax.numpy as jnp
|
7 |
+
from flax.serialization import to_state_dict
|
8 |
+
from jax2d.engine import (
|
9 |
+
calculate_collision_matrix,
|
10 |
+
calc_inverse_mass_polygon,
|
11 |
+
calc_inverse_mass_circle,
|
12 |
+
calc_inverse_inertia_circle,
|
13 |
+
calc_inverse_inertia_polygon,
|
14 |
+
recalculate_mass_and_inertia,
|
15 |
+
select_shape,
|
16 |
+
PhysicsEngine,
|
17 |
+
)
|
18 |
+
from jax2d.sim_state import SimState, RigidBody, Joint, Thruster
|
19 |
+
from jax2d.maths import rmat
|
20 |
+
from kinetix.environment.env_state import EnvParams, EnvState, StaticEnvParams
|
21 |
+
from kinetix.environment.ued.mutators import (
|
22 |
+
mutate_add_connected_shape_proper,
|
23 |
+
mutate_add_shape,
|
24 |
+
mutate_add_connected_shape,
|
25 |
+
mutate_add_thruster,
|
26 |
+
)
|
27 |
+
from kinetix.environment.ued.ued_state import UEDParams
|
28 |
+
from kinetix.environment.ued.util import (
|
29 |
+
get_role,
|
30 |
+
sample_dimensions,
|
31 |
+
is_space_for_shape,
|
32 |
+
random_position_on_polygon,
|
33 |
+
random_position_on_circle,
|
34 |
+
are_there_shapes_present,
|
35 |
+
is_space_for_joint,
|
36 |
+
)
|
37 |
+
from kinetix.environment.utils import permute_state
|
38 |
+
from kinetix.util.saving import load_world_state_pickle
|
39 |
+
from flax import struct
|
40 |
+
from kinetix.environment.env import create_empty_env
|
41 |
+
|
42 |
+
|
43 |
+
@partial(jax.jit, static_argnums=(1, 3, 5, 6, 7, 8, 9, 10))
|
44 |
+
def create_vmapped_filtered_distribution(
|
45 |
+
rng,
|
46 |
+
level_sampler,
|
47 |
+
env_params: EnvParams,
|
48 |
+
static_env_params: StaticEnvParams,
|
49 |
+
ued_params: UEDParams,
|
50 |
+
n_samples: int,
|
51 |
+
env,
|
52 |
+
do_filter_levels: bool,
|
53 |
+
level_filter_sample_ratio: int,
|
54 |
+
env_size_name: str,
|
55 |
+
level_filter_n_steps: int,
|
56 |
+
):
|
57 |
+
|
58 |
+
if do_filter_levels and level_filter_n_steps > 0:
|
59 |
+
sample_ratio = level_filter_sample_ratio
|
60 |
+
n_unfiltered_samples = sample_ratio * n_samples
|
61 |
+
rng, _rng = jax.random.split(rng)
|
62 |
+
_rngs = jax.random.split(_rng, n_unfiltered_samples)
|
63 |
+
|
64 |
+
# unfiltered_levels = jax.vmap(level_sampler, in_axes=(0, None, None, None, None))(
|
65 |
+
# _rngs, env_params, static_env_params, ued_params, env_size_name
|
66 |
+
# )
|
67 |
+
unfiltered_levels = jax.vmap(level_sampler, in_axes=(0,))(_rngs)
|
68 |
+
#
|
69 |
+
|
70 |
+
# No-op filtering
|
71 |
+
|
72 |
+
def _noop_step(states, rng):
|
73 |
+
rng, _rng = jax.random.split(rng)
|
74 |
+
_rngs = jax.random.split(_rng, n_unfiltered_samples)
|
75 |
+
|
76 |
+
action = jnp.zeros((n_unfiltered_samples, *env.action_space(env_params).shape), dtype=jnp.int32)
|
77 |
+
|
78 |
+
obs, states, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
|
79 |
+
_rngs, states, action, env_params
|
80 |
+
)
|
81 |
+
|
82 |
+
return states, (done, reward)
|
83 |
+
|
84 |
+
# Wrap levels
|
85 |
+
rng, _rng = jax.random.split(rng)
|
86 |
+
_rngs = jax.random.split(_rng, n_unfiltered_samples)
|
87 |
+
obsv, unfiltered_levels_wrapped = jax.vmap(env.reset_to_level, in_axes=(0, 0, None))(
|
88 |
+
_rngs, unfiltered_levels, env_params
|
89 |
+
)
|
90 |
+
|
91 |
+
rng, _rng = jax.random.split(rng)
|
92 |
+
_rngs = jax.random.split(_rng, level_filter_n_steps)
|
93 |
+
_, (done, rewards) = jax.lax.scan(_noop_step, unfiltered_levels_wrapped, xs=_rngs)
|
94 |
+
|
95 |
+
done_indexes = jnp.argmax(done, axis=0)
|
96 |
+
done_rewards = rewards[done_indexes, jnp.arange(n_unfiltered_samples)]
|
97 |
+
|
98 |
+
noop_solved_indexes = done_rewards > 0.5
|
99 |
+
p = noop_solved_indexes * 0.001 + (1 - noop_solved_indexes) * 1.0
|
100 |
+
p /= p.sum()
|
101 |
+
|
102 |
+
rng, _rng = jax.random.split(rng)
|
103 |
+
level_indexes = jax.random.choice(
|
104 |
+
_rng, jnp.arange(n_unfiltered_samples), shape=(n_samples,), replace=False, p=p
|
105 |
+
)
|
106 |
+
|
107 |
+
levels = jax.tree.map(lambda x: x[level_indexes], unfiltered_levels)
|
108 |
+
|
109 |
+
else:
|
110 |
+
rng, _rng = jax.random.split(rng)
|
111 |
+
_rngs = jax.random.split(_rng, n_samples)
|
112 |
+
|
113 |
+
levels = jax.vmap(level_sampler, in_axes=(0,))(_rngs)
|
114 |
+
|
115 |
+
return levels
|
116 |
+
|
117 |
+
|
118 |
+
@partial(jax.jit, static_argnums=(1, 3, 4, 5))
|
119 |
+
def sample_kinetix_level(
|
120 |
+
rng,
|
121 |
+
engine: PhysicsEngine,
|
122 |
+
env_params: EnvParams,
|
123 |
+
static_env_params: StaticEnvParams,
|
124 |
+
ued_params: UEDParams,
|
125 |
+
env_size_name: str = "l",
|
126 |
+
):
|
127 |
+
rng, _rng = jax.random.split(rng)
|
128 |
+
_rngs = jax.random.split(_rng, 12)
|
129 |
+
|
130 |
+
small_force_no_fixate = env_size_name == "s"
|
131 |
+
|
132 |
+
# Start with empty state
|
133 |
+
state = create_empty_env(static_env_params)
|
134 |
+
|
135 |
+
# Set the floor
|
136 |
+
prob_of_floor_colour = jnp.array(
|
137 |
+
[
|
138 |
+
ued_params.floor_prob_normal,
|
139 |
+
ued_params.floor_prob_green,
|
140 |
+
ued_params.floor_prob_blue,
|
141 |
+
ued_params.floor_prob_red,
|
142 |
+
]
|
143 |
+
)
|
144 |
+
floor_colour = jax.random.choice(_rngs[0], jnp.arange(4), p=prob_of_floor_colour)
|
145 |
+
state = state.replace(polygon_shape_roles=state.polygon_shape_roles.at[0].set(floor_colour))
|
146 |
+
|
147 |
+
# When we add shapes we don't want them to collide with already existing shapes
|
148 |
+
def _choose_proposal_with_least_collisions(proposals, bias=None):
|
149 |
+
rr, cr, cc = jax.vmap(engine.calculate_collision_manifolds)(proposals)
|
150 |
+
|
151 |
+
rr_collisions = jnp.sum(jnp.sum(rr.active.astype(jnp.int32), axis=-1), axis=-1)
|
152 |
+
cr_collisions = jnp.sum(cr.active.astype(jnp.int32), axis=-1)
|
153 |
+
cc_collisions = jnp.sum(cc.active.astype(jnp.int32), axis=-1)
|
154 |
+
|
155 |
+
all_collisions = jnp.concatenate(
|
156 |
+
[rr_collisions[:, None], cr_collisions[:, None], cc_collisions[:, None]], axis=1
|
157 |
+
)
|
158 |
+
num_collisions = jnp.sum(all_collisions, axis=-1)
|
159 |
+
if bias is not None:
|
160 |
+
num_collisions = num_collisions + bias
|
161 |
+
|
162 |
+
chosen_addition_idx = jnp.argmin(num_collisions)
|
163 |
+
|
164 |
+
return jax.tree.map(lambda x: x[chosen_addition_idx], proposals)
|
165 |
+
|
166 |
+
def _add_filtered_shape(rng, state, force_no_fixate=False):
|
167 |
+
rng, _rng = jax.random.split(rng)
|
168 |
+
_rngs = jax.random.split(_rng, ued_params.add_shape_n_proposals)
|
169 |
+
proposed_additions = jax.vmap(mutate_add_shape, in_axes=(0, None, None, None, None, None))(
|
170 |
+
_rngs,
|
171 |
+
state,
|
172 |
+
env_params,
|
173 |
+
static_env_params,
|
174 |
+
ued_params,
|
175 |
+
jnp.logical_or(force_no_fixate, small_force_no_fixate),
|
176 |
+
)
|
177 |
+
|
178 |
+
return _choose_proposal_with_least_collisions(proposed_additions)
|
179 |
+
|
180 |
+
def _add_filtered_connected_shape(rng, state, force_rjoint=False):
|
181 |
+
rng, _rng = jax.random.split(rng)
|
182 |
+
_rngs = jax.random.split(_rng, ued_params.add_shape_n_proposals)
|
183 |
+
|
184 |
+
proposed_additions, valid = jax.vmap(mutate_add_connected_shape, in_axes=(0, None, None, None, None, None))(
|
185 |
+
_rngs, state, env_params, static_env_params, ued_params, force_rjoint
|
186 |
+
)
|
187 |
+
|
188 |
+
bias = (jnp.ones(ued_params.add_shape_n_proposals) - 1 * valid) * ued_params.connect_no_visibility_bias
|
189 |
+
|
190 |
+
return _choose_proposal_with_least_collisions(proposed_additions, bias=bias)
|
191 |
+
|
192 |
+
# Add green and blue - make sure they're not both fixated
|
193 |
+
force_green_no_fixate = (jax.random.uniform(_rngs[1]) < 0.5) | (state.polygon_shape_roles[0] == 2)
|
194 |
+
state = _add_filtered_shape(_rngs[2], state, force_green_no_fixate)
|
195 |
+
state = _add_filtered_shape(_rngs[3], state, ~force_green_no_fixate)
|
196 |
+
|
197 |
+
# Forced controls
|
198 |
+
forced_control = jnp.array([[0, 1], [1, 0], [1, 1]])[jax.random.randint(_rngs[4], (), 0, 3)]
|
199 |
+
force_thruster, force_motor = forced_control[0], forced_control[1]
|
200 |
+
|
201 |
+
# Forced motor
|
202 |
+
state = jax.lax.cond(
|
203 |
+
force_motor,
|
204 |
+
lambda: _add_filtered_connected_shape(_rngs[5], state, force_rjoint=True), # force the rjoint
|
205 |
+
lambda: _add_filtered_shape(_rngs[6], state),
|
206 |
+
)
|
207 |
+
|
208 |
+
# Forced thruster
|
209 |
+
state = jax.lax.cond(
|
210 |
+
force_thruster,
|
211 |
+
lambda: mutate_add_thruster(_rngs[7], state, env_params, static_env_params, ued_params),
|
212 |
+
lambda: state,
|
213 |
+
)
|
214 |
+
|
215 |
+
# Add rest of shapes
|
216 |
+
n_shapes_to_add = (
|
217 |
+
static_env_params.num_polygons + static_env_params.num_circles - 3 - static_env_params.num_static_fixated_polys
|
218 |
+
)
|
219 |
+
|
220 |
+
def _add_shape(state, rng):
|
221 |
+
rng, _rng = jax.random.split(rng)
|
222 |
+
_rngs = jax.random.split(_rng, 3)
|
223 |
+
shape_add_type = jax.random.choice(
|
224 |
+
_rngs[0],
|
225 |
+
jnp.arange(3),
|
226 |
+
p=jnp.array(
|
227 |
+
[ued_params.add_connected_shape_chance, ued_params.add_shape_chance, ued_params.add_no_shape_chance]
|
228 |
+
),
|
229 |
+
)
|
230 |
+
|
231 |
+
state = jax.lax.switch(
|
232 |
+
shape_add_type,
|
233 |
+
[
|
234 |
+
lambda: _add_filtered_connected_shape(_rngs[1], state),
|
235 |
+
lambda: _add_filtered_shape(_rngs[2], state),
|
236 |
+
lambda: state,
|
237 |
+
],
|
238 |
+
)
|
239 |
+
|
240 |
+
return state, None
|
241 |
+
|
242 |
+
state, _ = jax.lax.scan(_add_shape, state, jax.random.split(_rngs[8], n_shapes_to_add))
|
243 |
+
|
244 |
+
# Add thrusters
|
245 |
+
n_thrusters_to_add = static_env_params.num_thrusters - 1
|
246 |
+
|
247 |
+
def _add_thruster(state, rng):
|
248 |
+
rng, _rng = jax.random.split(rng)
|
249 |
+
_rngs = jax.random.split(_rng, 3)
|
250 |
+
state = jax.lax.cond(
|
251 |
+
jax.random.uniform(_rngs[0]) < ued_params.add_thruster_chance,
|
252 |
+
lambda: mutate_add_thruster(_rngs[1], state, env_params, static_env_params, ued_params),
|
253 |
+
lambda: state,
|
254 |
+
)
|
255 |
+
|
256 |
+
return state, None
|
257 |
+
|
258 |
+
state, _ = jax.lax.scan(_add_thruster, state, jax.random.split(_rngs[9], n_thrusters_to_add))
|
259 |
+
|
260 |
+
# Randomly swap green and blue to remove left-right bias
|
261 |
+
def _swap_roles(do_swap_roles, roles):
|
262 |
+
role1 = roles == 1
|
263 |
+
role2 = roles == 2
|
264 |
+
|
265 |
+
swapped_roles = roles * ~(role1 | role2) + role1.astype(int) * 2 + role2.astype(int) * 1
|
266 |
+
return jax.lax.select(do_swap_roles, swapped_roles, roles)
|
267 |
+
|
268 |
+
do_swap_roles = jax.random.uniform(_rngs[10], shape=()) < 0.5
|
269 |
+
# Don't want to swap if floor is non-standard
|
270 |
+
do_swap_roles &= state.polygon_shape_roles[0] == 0
|
271 |
+
state = state.replace(
|
272 |
+
polygon_shape_roles=_swap_roles(do_swap_roles, state.polygon_shape_roles),
|
273 |
+
circle_shape_roles=_swap_roles(do_swap_roles, state.circle_shape_roles),
|
274 |
+
)
|
275 |
+
|
276 |
+
return permute_state(_rngs[11], state, static_env_params)
|
277 |
+
|
278 |
+
|
279 |
+
@partial(jax.jit, static_argnums=(2, 4, 5))
|
280 |
+
def create_random_starting_distribution(
|
281 |
+
rng,
|
282 |
+
env_params: EnvParams,
|
283 |
+
static_env_params: StaticEnvParams,
|
284 |
+
ued_params: UEDParams,
|
285 |
+
env_size_name: str,
|
286 |
+
controllable=True,
|
287 |
+
):
|
288 |
+
rng, _rng = jax.random.split(rng)
|
289 |
+
_rngs = jax.random.split(_rng, 15)
|
290 |
+
d = to_state_dict(ued_params)
|
291 |
+
ued_params = UEDParams(
|
292 |
+
**(
|
293 |
+
d
|
294 |
+
| dict(
|
295 |
+
goal_body_size_factor=2.0,
|
296 |
+
thruster_power_multiplier=2.0,
|
297 |
+
max_shape_size=0.5,
|
298 |
+
)
|
299 |
+
),
|
300 |
+
)
|
301 |
+
|
302 |
+
prob_of_large_shapes = 0.05
|
303 |
+
|
304 |
+
ued_params_large_shapes = ued_params.replace(
|
305 |
+
max_shape_size=static_env_params.max_shape_size * 1.0, goal_body_size_factor=1.0
|
306 |
+
)
|
307 |
+
|
308 |
+
state = create_empty_env(env_params, static_env_params)
|
309 |
+
|
310 |
+
def _get_ued_params(rng):
|
311 |
+
rng, _rng, _rng2 = jax.random.split(rng, 3)
|
312 |
+
large_shapes = jax.random.uniform(_rng) < prob_of_large_shapes
|
313 |
+
params_to_use = jax.tree.map(
|
314 |
+
lambda x, y: jax.lax.select(large_shapes, x, y), ued_params_large_shapes, ued_params
|
315 |
+
)
|
316 |
+
return params_to_use
|
317 |
+
|
318 |
+
def _my_add_shape(rng, state):
|
319 |
+
rng, _rng, _rng2 = jax.random.split(rng, 3)
|
320 |
+
return mutate_add_shape(_rng, state, env_params, static_env_params, _get_ued_params(_rng2))
|
321 |
+
|
322 |
+
def _my_add_connected_shape(rng, state, **kwargs):
|
323 |
+
rng, _rng, _rng2 = jax.random.split(rng, 3)
|
324 |
+
return mutate_add_connected_shape_proper(
|
325 |
+
_rng, state, env_params, static_env_params, _get_ued_params(_rng2), **kwargs
|
326 |
+
)
|
327 |
+
|
328 |
+
# Add the green thing and blue thing
|
329 |
+
state = _my_add_shape(_rngs[0], state)
|
330 |
+
state = _my_add_shape(_rngs[1], state)
|
331 |
+
if controllable:
|
332 |
+
# Forced controls
|
333 |
+
forced_control = jnp.array([[0, 1], [1, 0], [1, 1]])[jax.random.randint(_rngs[2], (), 0, 3)]
|
334 |
+
force_thruster, force_motor = forced_control[0], forced_control[1]
|
335 |
+
|
336 |
+
# Forced motor
|
337 |
+
state = jax.lax.cond(
|
338 |
+
force_motor,
|
339 |
+
lambda: _my_add_connected_shape(_rngs[3], state, force_rjoint=True), # force the rjoint
|
340 |
+
lambda: state,
|
341 |
+
)
|
342 |
+
|
343 |
+
# Forced thruster
|
344 |
+
state = jax.lax.cond(
|
345 |
+
force_thruster,
|
346 |
+
lambda: mutate_add_thruster(_rngs[4], state, env_params, static_env_params, ued_params),
|
347 |
+
lambda: state,
|
348 |
+
)
|
349 |
+
return permute_state(_rngs[7], state, static_env_params)
|
kinetix/environment/ued/mutators.py
ADDED
@@ -0,0 +1,1157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
import math
|
3 |
+
|
4 |
+
import chex
|
5 |
+
import jax
|
6 |
+
import jax.numpy as jnp
|
7 |
+
from flax.serialization import to_state_dict
|
8 |
+
from jax2d.engine import (
|
9 |
+
PhysicsEngine,
|
10 |
+
calculate_collision_matrix,
|
11 |
+
calc_inverse_mass_polygon,
|
12 |
+
calc_inverse_mass_circle,
|
13 |
+
calc_inverse_inertia_circle,
|
14 |
+
calc_inverse_inertia_polygon,
|
15 |
+
recalculate_mass_and_inertia,
|
16 |
+
select_shape,
|
17 |
+
)
|
18 |
+
from jax2d.sim_state import SimState, RigidBody, Joint, Thruster
|
19 |
+
from jax2d.maths import rmat
|
20 |
+
from kinetix.environment.env_state import EnvParams, EnvState, StaticEnvParams
|
21 |
+
from kinetix.environment.ued.ued_state import UEDParams
|
22 |
+
from kinetix.environment.ued.util import (
|
23 |
+
count_roles,
|
24 |
+
is_space_for_joint,
|
25 |
+
make_velocities_zero,
|
26 |
+
sample_dimensions,
|
27 |
+
random_position_on_polygon,
|
28 |
+
random_position_on_circle,
|
29 |
+
get_role,
|
30 |
+
is_space_for_shape,
|
31 |
+
are_there_shapes_present,
|
32 |
+
)
|
33 |
+
from kinetix.util.saving import load_world_state_pickle
|
34 |
+
from flax import struct
|
35 |
+
from kinetix.environment.env import create_empty_env
|
36 |
+
from kinetix.environment.ued.util import make_do_dummy_step
|
37 |
+
|
38 |
+
|
39 |
+
@partial(jax.jit, static_argnums=(3, 4))
|
40 |
+
def mutate_add_shape(
|
41 |
+
rng,
|
42 |
+
state: EnvState,
|
43 |
+
params: EnvParams,
|
44 |
+
static_env_params: StaticEnvParams,
|
45 |
+
ued_params: UEDParams,
|
46 |
+
force_no_fixate: bool = False,
|
47 |
+
):
|
48 |
+
def do_dummy(rng, state):
|
49 |
+
return state
|
50 |
+
|
51 |
+
def do_add(rng, state):
|
52 |
+
rng, _rng = jax.random.split(rng)
|
53 |
+
_rngs = jax.random.split(_rng, 9)
|
54 |
+
|
55 |
+
space_for_new_rect = state.polygon.active.astype(int).sum() < static_env_params.num_polygons
|
56 |
+
space_for_new_circle = state.circle.active.astype(int).sum() < static_env_params.num_circles
|
57 |
+
|
58 |
+
is_rect_p = jnp.array([space_for_new_rect * 1.0, space_for_new_circle * 1.0])
|
59 |
+
is_rect = jax.random.choice(_rngs[0], jnp.array([True, False], dtype=bool), p=is_rect_p)
|
60 |
+
|
61 |
+
rect_index = jnp.argmin(state.polygon.active)
|
62 |
+
circle_index = jnp.argmin(state.circle.active)
|
63 |
+
|
64 |
+
shape_role = get_role(_rngs[1], state, static_env_params)
|
65 |
+
|
66 |
+
max_shape_size = (
|
67 |
+
jnp.array([1.0, ued_params.goal_body_size_factor, ued_params.goal_body_size_factor, 1.0])[shape_role]
|
68 |
+
* ued_params.max_shape_size
|
69 |
+
)
|
70 |
+
|
71 |
+
vertices, half_dimensions, radius = sample_dimensions(
|
72 |
+
_rngs[2],
|
73 |
+
static_env_params,
|
74 |
+
is_rect,
|
75 |
+
ued_params,
|
76 |
+
max_shape_size=max_shape_size,
|
77 |
+
)
|
78 |
+
n_vertices = jax.lax.select(ued_params.generate_triangles, jax.random.choice(_rngs[3], jnp.array([3, 4])), 4)
|
79 |
+
|
80 |
+
largest = jnp.max(jnp.array([half_dimensions[0] * jnp.sqrt(2), half_dimensions[1] * jnp.sqrt(2), radius]))
|
81 |
+
|
82 |
+
screen_dim_world = (
|
83 |
+
static_env_params.screen_dim[0] / params.pixels_per_unit,
|
84 |
+
static_env_params.screen_dim[1] / params.pixels_per_unit,
|
85 |
+
)
|
86 |
+
min_x = largest
|
87 |
+
max_x = screen_dim_world[0] - largest
|
88 |
+
min_y = largest + 0.4
|
89 |
+
max_y = screen_dim_world[1] - largest
|
90 |
+
|
91 |
+
def _og_minmax():
|
92 |
+
return min_x, max_x, min_y, max_y
|
93 |
+
|
94 |
+
def _opposite_minmax():
|
95 |
+
return jax.lax.switch(
|
96 |
+
shape_role,
|
97 |
+
[
|
98 |
+
(lambda: (min_x, max_x, min_y, max_y)),
|
99 |
+
(lambda: (min_x, max_x - screen_dim_world[0] / 2, min_y, max_y)),
|
100 |
+
(lambda: (min_x + screen_dim_world[0] / 2, max_x, min_y, max_y)),
|
101 |
+
(lambda: (min_x, max_x, min_y, max_y)),
|
102 |
+
],
|
103 |
+
)
|
104 |
+
|
105 |
+
min_x, max_x, min_y, max_y = jax.lax.cond(
|
106 |
+
jax.random.uniform(_rngs[4], shape=()) < ued_params.goal_body_opposide_side_chance,
|
107 |
+
_opposite_minmax,
|
108 |
+
_og_minmax,
|
109 |
+
)
|
110 |
+
|
111 |
+
position = jax.random.uniform(_rngs[5], shape=(2,)) * jnp.array(
|
112 |
+
[
|
113 |
+
max_x - min_x,
|
114 |
+
max_y - min_y,
|
115 |
+
]
|
116 |
+
) + jnp.array([min_x, min_y])
|
117 |
+
|
118 |
+
rotation = jax.random.uniform(_rngs[6], shape=()) * 2 * math.pi
|
119 |
+
velocity = jnp.array([0.0, 0.0])
|
120 |
+
angular_velocity = 0.0
|
121 |
+
|
122 |
+
density = 1.0
|
123 |
+
inverse_mass = jax.lax.select(
|
124 |
+
is_rect,
|
125 |
+
calc_inverse_mass_polygon(vertices, n_vertices, static_env_params, density)[0],
|
126 |
+
calc_inverse_mass_circle(radius, density),
|
127 |
+
)
|
128 |
+
|
129 |
+
inverse_inertia = jax.lax.select(
|
130 |
+
is_rect,
|
131 |
+
calc_inverse_inertia_polygon(vertices, n_vertices, static_env_params, density),
|
132 |
+
calc_inverse_inertia_circle(radius, density),
|
133 |
+
)
|
134 |
+
|
135 |
+
fixate_chance = ued_params.fixate_chance_min + (1.0 / inverse_mass) * ued_params.fixate_chance_scale
|
136 |
+
fixate_chance = jnp.minimum(fixate_chance, ued_params.fixate_chance_max)
|
137 |
+
is_fixated = jax.random.uniform(_rngs[7], shape=()) < fixate_chance
|
138 |
+
is_fixated &= ~force_no_fixate
|
139 |
+
|
140 |
+
inverse_mass *= 1 - is_fixated
|
141 |
+
inverse_inertia *= 1 - is_fixated
|
142 |
+
|
143 |
+
# We want to bias fixated shapes to starting nearer the bottom half of the screen
|
144 |
+
fixate_shape_bottom_bias = (
|
145 |
+
ued_params.fixate_shape_bottom_bias + ued_params.fixate_shape_bottom_bias_special_role * (shape_role != 0)
|
146 |
+
)
|
147 |
+
is_forcing_bottom = jax.random.uniform(_rngs[8]) < fixate_shape_bottom_bias
|
148 |
+
|
149 |
+
|
150 |
+
half_screen_height = (static_env_params.screen_dim[1] / params.pixels_per_unit) / 2.0
|
151 |
+
position = jax.lax.select(
|
152 |
+
is_fixated & is_forcing_bottom & (position[1] >= half_screen_height),
|
153 |
+
position.at[1].add(-half_screen_height),
|
154 |
+
position,
|
155 |
+
)
|
156 |
+
|
157 |
+
# This could be either a rect or a circle
|
158 |
+
new_rigid_body = RigidBody(
|
159 |
+
position=position,
|
160 |
+
velocity=velocity,
|
161 |
+
inverse_mass=inverse_mass,
|
162 |
+
inverse_inertia=inverse_inertia,
|
163 |
+
rotation=rotation,
|
164 |
+
angular_velocity=angular_velocity,
|
165 |
+
radius=radius,
|
166 |
+
active=True,
|
167 |
+
friction=1.0,
|
168 |
+
vertices=vertices,
|
169 |
+
n_vertices=n_vertices,
|
170 |
+
collision_mode=1,
|
171 |
+
restitution=0.0,
|
172 |
+
)
|
173 |
+
|
174 |
+
state = state.replace(
|
175 |
+
polygon=jax.tree.map(
|
176 |
+
lambda x, y: jax.lax.select(is_rect, y.at[rect_index].set(x), y), new_rigid_body, state.polygon
|
177 |
+
),
|
178 |
+
circle=jax.tree.map(
|
179 |
+
lambda x, y: jax.lax.select(jnp.logical_not(is_rect), y.at[circle_index].set(x), y),
|
180 |
+
new_rigid_body,
|
181 |
+
state.circle,
|
182 |
+
),
|
183 |
+
polygon_shape_roles=jax.lax.select(
|
184 |
+
is_rect,
|
185 |
+
state.polygon_shape_roles.at[rect_index].set(shape_role),
|
186 |
+
state.polygon_shape_roles,
|
187 |
+
),
|
188 |
+
circle_shape_roles=jax.lax.select(
|
189 |
+
jnp.logical_not(is_rect),
|
190 |
+
state.circle_shape_roles.at[circle_index].set(shape_role),
|
191 |
+
state.circle_shape_roles,
|
192 |
+
),
|
193 |
+
)
|
194 |
+
return recalculate_mass_and_inertia(state, static_env_params, state.polygon_densities, state.circle_densities)
|
195 |
+
|
196 |
+
return jax.lax.cond(is_space_for_shape(state), do_add, do_dummy, rng, state)
|
197 |
+
|
198 |
+
|
199 |
+
@partial(jax.jit, static_argnums=(3, 4))
|
200 |
+
def mutate_add_connected_shape(
|
201 |
+
rng,
|
202 |
+
state: EnvState,
|
203 |
+
params: EnvParams,
|
204 |
+
static_env_params: StaticEnvParams,
|
205 |
+
ued_params: UEDParams,
|
206 |
+
force_rjoint: bool = False,
|
207 |
+
):
|
208 |
+
def do_dummy(rng, state):
|
209 |
+
return state, False
|
210 |
+
|
211 |
+
def do_add(rng, state):
|
212 |
+
rng, _rng = jax.random.split(rng)
|
213 |
+
_rngs = jax.random.split(_rng, 21)
|
214 |
+
|
215 |
+
# Select a random index amongst the currently active shapes.
|
216 |
+
p_rect = state.polygon.active.at[: static_env_params.num_static_fixated_polys].set(False)
|
217 |
+
p_circle = state.circle.active
|
218 |
+
|
219 |
+
p_rect = p_rect.astype(jnp.float32)
|
220 |
+
p_circle = p_circle.astype(jnp.float32)
|
221 |
+
|
222 |
+
p_rect *= (state.polygon.inverse_mass == 0) * ued_params.connect_to_fixated_prob_coeff + (
|
223 |
+
state.polygon.inverse_mass != 0
|
224 |
+
) * 1.0
|
225 |
+
p_circle *= (state.circle.inverse_mass == 0) * ued_params.connect_to_fixated_prob_coeff + (
|
226 |
+
state.circle.inverse_mass != 0
|
227 |
+
) * 1.0
|
228 |
+
|
229 |
+
# Bias based on number of existing connections
|
230 |
+
rect_connections = jnp.zeros(static_env_params.num_polygons)
|
231 |
+
circle_connections = jnp.zeros(static_env_params.num_circles)
|
232 |
+
|
233 |
+
rect_connections = rect_connections.at[state.joint.a_index].add(
|
234 |
+
jnp.ones(static_env_params.num_joints)
|
235 |
+
* state.joint.active
|
236 |
+
* (state.joint.a_index < static_env_params.num_polygons)
|
237 |
+
)
|
238 |
+
rect_connections = rect_connections.at[state.joint.b_index].add(
|
239 |
+
jnp.ones(static_env_params.num_joints)
|
240 |
+
* state.joint.active
|
241 |
+
* (state.joint.b_index < static_env_params.num_polygons)
|
242 |
+
)
|
243 |
+
|
244 |
+
circle_connections = circle_connections.at[state.joint.a_index - static_env_params.num_polygons].add(
|
245 |
+
jnp.ones(static_env_params.num_joints)
|
246 |
+
* state.joint.active
|
247 |
+
* (state.joint.a_index >= static_env_params.num_polygons)
|
248 |
+
)
|
249 |
+
circle_connections = circle_connections.at[state.joint.b_index - static_env_params.num_polygons].add(
|
250 |
+
jnp.ones(static_env_params.num_joints)
|
251 |
+
* state.joint.active
|
252 |
+
* (state.joint.b_index >= static_env_params.num_polygons)
|
253 |
+
)
|
254 |
+
|
255 |
+
# Rectangles can have up to 2 connections
|
256 |
+
p_rect *= (-rect_connections + 2.0) / 2.0
|
257 |
+
p_rect = jnp.maximum(p_rect, 0.0)
|
258 |
+
# Circles can have 1 connection
|
259 |
+
p_circle *= circle_connections == 0
|
260 |
+
|
261 |
+
# To sample a target rect/circle, we have to have at least one.
|
262 |
+
target_rect_p = jnp.array(
|
263 |
+
[
|
264 |
+
(state.polygon.active.astype(int).sum() > static_env_params.num_static_fixated_polys) * 1.0,
|
265 |
+
(state.circle.active.astype(int).sum() > 0) * 1.0,
|
266 |
+
]
|
267 |
+
)
|
268 |
+
|
269 |
+
# Don't connect to a circle if no connection-free ones exist
|
270 |
+
target_rect_p = target_rect_p.at[1].mul(p_circle.sum() > 0)
|
271 |
+
|
272 |
+
space_for_new_rect = state.polygon.active.astype(int).sum() < static_env_params.num_polygons
|
273 |
+
space_for_new_circle = state.circle.active.astype(int).sum() < static_env_params.num_circles
|
274 |
+
|
275 |
+
is_target_rect = jax.random.choice(_rngs[0], jnp.array([True, False], dtype=bool), p=target_rect_p) | (
|
276 |
+
~space_for_new_rect
|
277 |
+
)
|
278 |
+
|
279 |
+
is_rect_p = jnp.array([space_for_new_rect * 1.0, space_for_new_circle * 1.0])
|
280 |
+
is_rect = jax.random.choice(_rngs[1], jnp.array([True, False], dtype=bool), p=is_rect_p) | (
|
281 |
+
~is_target_rect & space_for_new_rect
|
282 |
+
)
|
283 |
+
|
284 |
+
shape_index = jax.lax.select(
|
285 |
+
is_rect,
|
286 |
+
jnp.argmin(state.polygon.active),
|
287 |
+
jnp.argmin(state.circle.active),
|
288 |
+
)
|
289 |
+
unified_shape_index = shape_index + (~is_rect) * static_env_params.num_polygons
|
290 |
+
|
291 |
+
vertices, half_dimensions, radius = sample_dimensions(
|
292 |
+
_rngs[2], static_env_params, is_rect, ued_params, max_shape_size=ued_params.max_shape_size
|
293 |
+
)
|
294 |
+
n_vertices = jax.lax.select(ued_params.generate_triangles, jax.random.choice(_rngs[3], jnp.array([3, 4])), 4)
|
295 |
+
|
296 |
+
rotation = jax.random.uniform(_rngs[4], shape=()) * 2 * math.pi
|
297 |
+
velocity = jnp.array([0.0, 0.0])
|
298 |
+
angular_velocity = 0.0
|
299 |
+
|
300 |
+
density = 1.0
|
301 |
+
inverse_mass = jax.lax.select(
|
302 |
+
is_rect,
|
303 |
+
calc_inverse_mass_polygon(vertices, n_vertices, static_env_params, density)[0],
|
304 |
+
calc_inverse_mass_circle(radius, density),
|
305 |
+
)
|
306 |
+
|
307 |
+
inverse_inertia = jax.lax.select(
|
308 |
+
is_rect,
|
309 |
+
calc_inverse_inertia_polygon(vertices, n_vertices, static_env_params, density),
|
310 |
+
calc_inverse_inertia_circle(radius, density),
|
311 |
+
)
|
312 |
+
|
313 |
+
# Joint
|
314 |
+
|
315 |
+
current_num_rjoints = (jnp.logical_not(state.joint.is_fixed_joint) * state.joint.active).sum()
|
316 |
+
is_rjoint = jnp.logical_or(
|
317 |
+
jnp.logical_or(jax.random.uniform(_rngs[5]) < 0.5, force_rjoint),
|
318 |
+
current_num_rjoints < ued_params.min_rjoints_bias,
|
319 |
+
)
|
320 |
+
|
321 |
+
joint_index = jnp.argmin(state.joint.active)
|
322 |
+
|
323 |
+
local_joint_position_rect = random_position_on_polygon(_rngs[6], vertices, n_vertices, static_env_params)
|
324 |
+
local_joint_position_circle = random_position_on_circle(_rngs[7], radius, on_centre_chance=1.0)
|
325 |
+
|
326 |
+
local_joint_position = jax.lax.select(is_rect, local_joint_position_rect, local_joint_position_circle)
|
327 |
+
|
328 |
+
p_rect = jax.lax.select(p_rect.sum() == 0, state.polygon.active.astype(jnp.float32), p_rect)
|
329 |
+
p_circle = jax.lax.select(p_circle.sum() == 0, state.circle.active.astype(jnp.float32), p_circle)
|
330 |
+
|
331 |
+
target_index = jax.lax.select(
|
332 |
+
is_target_rect,
|
333 |
+
jax.random.choice(
|
334 |
+
_rngs[8],
|
335 |
+
jnp.arange(static_env_params.num_polygons),
|
336 |
+
p=p_rect,
|
337 |
+
),
|
338 |
+
jax.random.choice(
|
339 |
+
_rngs[9],
|
340 |
+
jnp.arange(static_env_params.num_circles),
|
341 |
+
p=p_circle,
|
342 |
+
),
|
343 |
+
)
|
344 |
+
|
345 |
+
unified_target_index = target_index + jnp.logical_not(is_target_rect) * static_env_params.num_polygons
|
346 |
+
target_shape = select_shape(state, unified_target_index, static_env_params)
|
347 |
+
|
348 |
+
target_joint_position_rect = random_position_on_polygon(
|
349 |
+
_rngs[10], state.polygon.vertices[target_index], state.polygon.n_vertices[target_index], static_env_params
|
350 |
+
)
|
351 |
+
target_joint_position_circle = random_position_on_circle(
|
352 |
+
_rngs[11], state.circle.radius[target_index], on_centre_chance=1.0
|
353 |
+
)
|
354 |
+
|
355 |
+
target_joint_position = jax.lax.select(is_target_rect, target_joint_position_rect, target_joint_position_circle)
|
356 |
+
|
357 |
+
# Calculate the world position of the new shape
|
358 |
+
# We know the rotation of the new shape. We also know the position of the current shape, which we want to remain fixed.
|
359 |
+
# Set `position` such that local_joint_position is the same as `target_joint_position`
|
360 |
+
global_joint_pos = target_shape.position + jnp.matmul(rmat(target_shape.rotation), target_joint_position)
|
361 |
+
position = global_joint_pos - jnp.matmul(rmat(rotation), local_joint_position)
|
362 |
+
|
363 |
+
_, pos_diff = calc_inverse_mass_polygon(vertices, n_vertices, static_env_params, density)
|
364 |
+
position = jax.lax.select(is_rect, position + pos_diff, position)
|
365 |
+
local_joint_position = jax.lax.select(is_rect, local_joint_position - pos_diff, local_joint_position)
|
366 |
+
vertices = jax.lax.select(is_rect, vertices - pos_diff[None], vertices)
|
367 |
+
|
368 |
+
target_role = jax.lax.select(
|
369 |
+
is_target_rect, state.polygon_shape_roles[target_index], state.circle_shape_roles[target_index]
|
370 |
+
)
|
371 |
+
|
372 |
+
# We cannot have role 1 and role 2 being connected.
|
373 |
+
p = jnp.array([1.0, 1.0, 1.0, 1.0])
|
374 |
+
# If role is 0, keep all probs at 1, otherwise set the target role's complement to 0 prob
|
375 |
+
# 3 - role turns 1 to 2 and 2 to 1
|
376 |
+
# If the target role is three, we set everything to zero except for the default
|
377 |
+
p = jax.lax.select(
|
378 |
+
target_role == 0,
|
379 |
+
p,
|
380 |
+
jax.lax.select(
|
381 |
+
target_role <= 2,
|
382 |
+
p.at[3 - target_role].set(False).at[3].set(False),
|
383 |
+
(p.at[2].set(False).at[1].set(False)),
|
384 |
+
),
|
385 |
+
)
|
386 |
+
|
387 |
+
shape_role = get_role(_rngs[12], state, static_env_params, initial_p=p)
|
388 |
+
|
389 |
+
# This could be either a rect or a circle
|
390 |
+
new_rigid_body = RigidBody(
|
391 |
+
position=position,
|
392 |
+
velocity=velocity,
|
393 |
+
inverse_mass=inverse_mass,
|
394 |
+
inverse_inertia=inverse_inertia,
|
395 |
+
rotation=rotation,
|
396 |
+
angular_velocity=angular_velocity,
|
397 |
+
radius=radius,
|
398 |
+
active=True,
|
399 |
+
friction=1.0,
|
400 |
+
vertices=vertices,
|
401 |
+
n_vertices=n_vertices,
|
402 |
+
collision_mode=1,
|
403 |
+
restitution=0.0,
|
404 |
+
)
|
405 |
+
|
406 |
+
# Change the shape indices such that a_index is less than b_index
|
407 |
+
a_index = shape_index + (1 - is_rect) * static_env_params.num_polygons
|
408 |
+
b_index = target_index + (1 - is_target_rect) * static_env_params.num_polygons
|
409 |
+
|
410 |
+
should_swap = a_index > b_index
|
411 |
+
a_index, b_index, local_joint_position, target_joint_position, shape_a, shape_b = jax.lax.cond(
|
412 |
+
should_swap,
|
413 |
+
lambda x: (x[1], x[0], x[3], x[2], x[5], x[4]), # pairwise swap
|
414 |
+
lambda x: x,
|
415 |
+
(a_index, b_index, local_joint_position, target_joint_position, new_rigid_body, target_shape),
|
416 |
+
)
|
417 |
+
|
418 |
+
motor_on = jax.random.uniform(_rngs[13], shape=()) < ued_params.motor_on_chance
|
419 |
+
joint_colour = jax.random.randint(_rngs[14], shape=(), minval=0, maxval=static_env_params.num_motor_bindings)
|
420 |
+
joint_rotation = shape_b.rotation - shape_a.rotation
|
421 |
+
|
422 |
+
motor_speed = jax.random.uniform(
|
423 |
+
_rngs[15], shape=(), minval=ued_params.motor_min_speed, maxval=ued_params.motor_max_speed
|
424 |
+
)
|
425 |
+
|
426 |
+
motor_power = jax.random.uniform(
|
427 |
+
_rngs[16], shape=(), minval=ued_params.motor_min_power, maxval=ued_params.motor_max_power
|
428 |
+
)
|
429 |
+
wheel_power = jax.random.uniform(
|
430 |
+
_rngs[20], shape=(), minval=ued_params.motor_min_power, maxval=ued_params.wheel_max_power
|
431 |
+
)
|
432 |
+
|
433 |
+
# High-powered wheels break the physics engine - this is a temporary fix
|
434 |
+
motor_power = jax.lax.select(is_rect & is_target_rect, motor_power, wheel_power)
|
435 |
+
|
436 |
+
motor_has_joint_limits = jax.random.uniform(_rngs[17], shape=()) < ued_params.joint_limit_chance
|
437 |
+
motor_has_joint_limits &= is_rect & is_target_rect
|
438 |
+
joint_limit_min = (
|
439 |
+
jax.random.uniform(_rngs[18], shape=(), minval=-ued_params.joint_limit_max, maxval=0.0)
|
440 |
+
* motor_has_joint_limits
|
441 |
+
)
|
442 |
+
joint_limit_max = (
|
443 |
+
jax.random.uniform(_rngs[19], shape=(), minval=0.0, maxval=ued_params.joint_limit_max)
|
444 |
+
* motor_has_joint_limits
|
445 |
+
)
|
446 |
+
|
447 |
+
rjoint = Joint(
|
448 |
+
a_index=a_index,
|
449 |
+
b_index=b_index,
|
450 |
+
a_relative_pos=local_joint_position,
|
451 |
+
b_relative_pos=target_joint_position,
|
452 |
+
global_position=global_joint_pos,
|
453 |
+
active=True,
|
454 |
+
motor_speed=motor_speed,
|
455 |
+
motor_power=motor_power,
|
456 |
+
motor_on=motor_on,
|
457 |
+
# colour=joint_colour,
|
458 |
+
motor_has_joint_limits=motor_has_joint_limits,
|
459 |
+
min_rotation=joint_limit_min,
|
460 |
+
max_rotation=joint_limit_max,
|
461 |
+
is_fixed_joint=False,
|
462 |
+
rotation=0.0,
|
463 |
+
acc_impulse=jnp.zeros((2,), dtype=jnp.float32),
|
464 |
+
acc_r_impulse=jnp.zeros((), dtype=jnp.float32),
|
465 |
+
)
|
466 |
+
|
467 |
+
fjoint = Joint(
|
468 |
+
a_index=a_index,
|
469 |
+
b_index=b_index,
|
470 |
+
a_relative_pos=local_joint_position,
|
471 |
+
b_relative_pos=target_joint_position,
|
472 |
+
global_position=global_joint_pos,
|
473 |
+
active=True,
|
474 |
+
rotation=joint_rotation,
|
475 |
+
acc_impulse=jnp.zeros((2,), dtype=jnp.float32),
|
476 |
+
acc_r_impulse=jnp.zeros((), dtype=jnp.float32),
|
477 |
+
is_fixed_joint=True,
|
478 |
+
motor_has_joint_limits=False,
|
479 |
+
min_rotation=0.0,
|
480 |
+
max_rotation=0.0,
|
481 |
+
motor_on=False,
|
482 |
+
motor_power=0.0,
|
483 |
+
motor_speed=0.0,
|
484 |
+
)
|
485 |
+
|
486 |
+
state = state.replace(
|
487 |
+
polygon=jax.tree.map(
|
488 |
+
lambda x, y: jax.lax.select(is_rect, y.at[shape_index].set(x), y), new_rigid_body, state.polygon
|
489 |
+
),
|
490 |
+
circle=jax.tree.map(
|
491 |
+
lambda x, y: jax.lax.select(jnp.logical_not(is_rect), y.at[shape_index].set(x), y),
|
492 |
+
new_rigid_body,
|
493 |
+
state.circle,
|
494 |
+
),
|
495 |
+
joint=jax.tree.map(
|
496 |
+
lambda rj, fj, y: jax.lax.select(is_rjoint, y.at[joint_index].set(rj), y.at[joint_index].set(fj)),
|
497 |
+
rjoint,
|
498 |
+
fjoint,
|
499 |
+
state.joint,
|
500 |
+
),
|
501 |
+
polygon_shape_roles=jax.lax.select(
|
502 |
+
is_rect,
|
503 |
+
state.polygon_shape_roles.at[shape_index].set(shape_role),
|
504 |
+
state.polygon_shape_roles,
|
505 |
+
),
|
506 |
+
circle_shape_roles=jax.lax.select(
|
507 |
+
jnp.logical_not(is_rect),
|
508 |
+
state.circle_shape_roles.at[shape_index].set(shape_role),
|
509 |
+
state.circle_shape_roles,
|
510 |
+
),
|
511 |
+
motor_bindings=state.motor_bindings.at[joint_index].set(joint_colour),
|
512 |
+
)
|
513 |
+
|
514 |
+
# We need the new collision matrix.
|
515 |
+
state = state.replace(collision_matrix=calculate_collision_matrix(static_env_params, state.joint))
|
516 |
+
state = recalculate_mass_and_inertia(state, static_env_params, state.polygon_densities, state.circle_densities)
|
517 |
+
|
518 |
+
# Was this a valid addition?
|
519 |
+
# We calculate whether (assuming the possiblity of 360 degree rotation around the joint)
|
520 |
+
# both shapes can be visible
|
521 |
+
# This is to remove the common degenerate pattern of connected shapes being fully inside each other
|
522 |
+
def _get_min_rect_dist(r_id, local_pos):
|
523 |
+
rect: RigidBody = jax.tree.map(lambda x: x[r_id], state.polygon)
|
524 |
+
|
525 |
+
half_width = (jnp.max(rect.vertices[:, 0]) - jnp.min(rect.vertices[:, 0])) / 2.0
|
526 |
+
half_height = (jnp.max(rect.vertices[:, 1]) - jnp.min(rect.vertices[:, 1])) / 2.0
|
527 |
+
|
528 |
+
dist_x = half_width - jnp.abs(local_pos[0])
|
529 |
+
dist_y = half_height - jnp.abs(local_pos[1])
|
530 |
+
|
531 |
+
return jnp.minimum(dist_x, dist_y)
|
532 |
+
|
533 |
+
def _get_max_rect_dist(r_id, local_pos):
|
534 |
+
rect: RigidBody = jax.tree.map(lambda x: x[r_id], state.polygon)
|
535 |
+
|
536 |
+
half_width = (jnp.max(rect.vertices[:, 0]) - jnp.min(rect.vertices[:, 0])) / 2.0
|
537 |
+
half_height = (jnp.max(rect.vertices[:, 1]) - jnp.min(rect.vertices[:, 1])) / 2.0
|
538 |
+
|
539 |
+
dist_x = jnp.maximum(
|
540 |
+
jnp.abs(half_width - local_pos[0]),
|
541 |
+
jnp.abs(-half_width - local_pos[0]),
|
542 |
+
)
|
543 |
+
|
544 |
+
dist_y = jnp.maximum(
|
545 |
+
jnp.abs(half_height - local_pos[1]),
|
546 |
+
jnp.abs(-half_height - local_pos[1]),
|
547 |
+
)
|
548 |
+
|
549 |
+
return jnp.sqrt(dist_x * dist_x + dist_y * dist_y)
|
550 |
+
|
551 |
+
def are_both_shapes_showing(idx1, idx2, local_pos1, local_pos2):
|
552 |
+
def _is_small_shape_showing(small_idx, big_idx, small_local_pos, big_local_pos):
|
553 |
+
small_is_poly = small_idx < static_env_params.num_polygons
|
554 |
+
big_is_poly = big_idx < static_env_params.num_polygons
|
555 |
+
|
556 |
+
# CC
|
557 |
+
cc_result = False
|
558 |
+
|
559 |
+
# CR
|
560 |
+
cr_r_dist = _get_min_rect_dist(big_idx, big_local_pos)
|
561 |
+
cr_result = (
|
562 |
+
cr_r_dist + ued_params.connect_visibility_min
|
563 |
+
< state.circle.radius[small_idx - static_env_params.num_polygons]
|
564 |
+
)
|
565 |
+
|
566 |
+
# RC
|
567 |
+
rc_r_dist = _get_max_rect_dist(small_idx, small_local_pos)
|
568 |
+
rc_result = (
|
569 |
+
rc_r_dist
|
570 |
+
> state.circle.radius[big_idx - static_env_params.num_polygons] + ued_params.connect_visibility_min
|
571 |
+
)
|
572 |
+
|
573 |
+
# RR
|
574 |
+
rr_small_dist = _get_max_rect_dist(small_idx, small_local_pos)
|
575 |
+
rr_big_dist = _get_min_rect_dist(big_idx, big_local_pos)
|
576 |
+
rr_result = rr_small_dist > rr_big_dist + ued_params.connect_visibility_min
|
577 |
+
|
578 |
+
# Select
|
579 |
+
return jax.lax.select(
|
580 |
+
small_is_poly,
|
581 |
+
jax.lax.select(big_is_poly, rr_result, rc_result),
|
582 |
+
jax.lax.select(big_is_poly, cr_result, cc_result),
|
583 |
+
)
|
584 |
+
|
585 |
+
# Are both shapes showing?
|
586 |
+
return _is_small_shape_showing(idx1, idx2, local_pos1, local_pos2) & _is_small_shape_showing(
|
587 |
+
idx2, idx1, local_pos2, local_pos1
|
588 |
+
)
|
589 |
+
|
590 |
+
valid = are_both_shapes_showing(
|
591 |
+
unified_shape_index, unified_target_index, local_joint_position, target_joint_position
|
592 |
+
)
|
593 |
+
return state, valid
|
594 |
+
|
595 |
+
# To add a connected shape, we must have both at least one existing shape and space
|
596 |
+
return jax.lax.cond(
|
597 |
+
is_space_for_shape(state) & are_there_shapes_present(state, static_env_params) & is_space_for_joint(state),
|
598 |
+
do_add,
|
599 |
+
do_dummy,
|
600 |
+
rng,
|
601 |
+
state,
|
602 |
+
)
|
603 |
+
|
604 |
+
|
605 |
+
@partial(jax.jit, static_argnums=(3, 4))
|
606 |
+
def mutate_add_connected_shape_proper(
|
607 |
+
rng,
|
608 |
+
state: EnvState,
|
609 |
+
params: EnvParams,
|
610 |
+
static_env_params: StaticEnvParams,
|
611 |
+
ued_params: UEDParams,
|
612 |
+
force_rjoint: bool = False,
|
613 |
+
):
|
614 |
+
return mutate_add_connected_shape(rng, state, params, static_env_params, ued_params, force_rjoint=force_rjoint)[0]
|
615 |
+
|
616 |
+
|
617 |
+
@partial(jax.jit, static_argnums=(3, 4))
|
618 |
+
def mutate_remove_shape(
|
619 |
+
rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams
|
620 |
+
):
|
621 |
+
|
622 |
+
can_remove_mask = (
|
623 |
+
jnp.concatenate([state.polygon.active, state.circle.active])
|
624 |
+
.at[: static_env_params.num_static_fixated_polys]
|
625 |
+
.set(False)
|
626 |
+
)
|
627 |
+
|
628 |
+
def dummy(rng, state):
|
629 |
+
return state
|
630 |
+
|
631 |
+
def do_remove(rng, state: EnvState):
|
632 |
+
rng, _rng = jax.random.split(rng)
|
633 |
+
rngs = jax.random.split(_rng, 2)
|
634 |
+
p = can_remove_mask.astype(jnp.float32)
|
635 |
+
index_to_remove = jax.random.choice(rngs[0], jnp.arange(can_remove_mask.shape[0]), p=p)
|
636 |
+
is_rect = index_to_remove < static_env_params.num_polygons
|
637 |
+
state = state.replace(
|
638 |
+
polygon=state.polygon.replace(
|
639 |
+
active=jax.lax.select(
|
640 |
+
is_rect, state.polygon.active.at[index_to_remove].set(False), state.polygon.active
|
641 |
+
)
|
642 |
+
),
|
643 |
+
circle=state.circle.replace(
|
644 |
+
active=jax.lax.select(
|
645 |
+
jnp.logical_not(is_rect),
|
646 |
+
state.circle.active.at[index_to_remove - static_env_params.num_polygons].set(False),
|
647 |
+
state.circle.active,
|
648 |
+
)
|
649 |
+
),
|
650 |
+
)
|
651 |
+
# We need to now remove any joints connected to this shape
|
652 |
+
joints_to_remove = (state.joint.a_index == index_to_remove) | (state.joint.b_index == index_to_remove)
|
653 |
+
|
654 |
+
thrusters_to_remove = state.thruster.object_index == index_to_remove
|
655 |
+
|
656 |
+
state = state.replace(
|
657 |
+
joint=state.joint.replace(active=jnp.where(joints_to_remove, False, state.joint.active)),
|
658 |
+
thruster=state.thruster.replace(active=jnp.where(thrusters_to_remove, False, state.thruster.active)),
|
659 |
+
)
|
660 |
+
# Now recalculate collision matrix
|
661 |
+
state = state.replace(collision_matrix=calculate_collision_matrix(static_env_params, state.joint))
|
662 |
+
return state
|
663 |
+
|
664 |
+
return jax.lax.cond(can_remove_mask.sum() > 0, do_remove, dummy, rng, state)
|
665 |
+
|
666 |
+
|
667 |
+
@partial(jax.jit, static_argnums=(3, 4))
|
668 |
+
def mutate_remove_joint(
|
669 |
+
rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams
|
670 |
+
):
|
671 |
+
can_remove_mask = state.joint.active
|
672 |
+
|
673 |
+
def dummy(rng, state):
|
674 |
+
return state
|
675 |
+
|
676 |
+
def do_remove(rng, state):
|
677 |
+
rng, _rng = jax.random.split(rng)
|
678 |
+
rngs = jax.random.split(_rng, 2)
|
679 |
+
p = can_remove_mask.astype(jnp.float32)
|
680 |
+
index_to_remove = jax.random.choice(rngs[0], jnp.arange(can_remove_mask.shape[0]), p=p)
|
681 |
+
state = state.replace(joint=state.joint.replace(active=state.joint.active.at[index_to_remove].set(False)))
|
682 |
+
# Recalculate collision matrix.
|
683 |
+
state = state.replace(collision_matrix=calculate_collision_matrix(static_env_params, state.joint))
|
684 |
+
return state
|
685 |
+
|
686 |
+
return jax.lax.cond(can_remove_mask.sum() > 0, do_remove, dummy, rng, state)
|
687 |
+
|
688 |
+
|
689 |
+
@partial(jax.jit, static_argnums=(3, 4))
|
690 |
+
def mutate_swap_role(
|
691 |
+
rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams
|
692 |
+
):
|
693 |
+
def _cr(*args):
|
694 |
+
return count_roles(*args, include_static_polys=False)
|
695 |
+
|
696 |
+
role_counts = jax.vmap(_cr, (None, None, 0))(state, static_env_params, jnp.arange(4))
|
697 |
+
are_there_multiple_roles = (role_counts > 0).sum() > 1
|
698 |
+
|
699 |
+
def dummy(rng, state):
|
700 |
+
return state
|
701 |
+
|
702 |
+
def do_swap(rng, state):
|
703 |
+
rng, _rng = jax.random.split(rng)
|
704 |
+
rngs = jax.random.split(_rng, 2)
|
705 |
+
all_roles = jnp.concatenate([state.polygon_shape_roles, state.circle_shape_roles])
|
706 |
+
|
707 |
+
p = (
|
708 |
+
(jnp.concatenate([state.polygon.active, state.circle.active]))
|
709 |
+
.astype(jnp.float32)
|
710 |
+
.at[: static_env_params.num_static_fixated_polys]
|
711 |
+
.set(0.0)
|
712 |
+
)
|
713 |
+
shape_idx_a = jax.random.choice(
|
714 |
+
rngs[0], jnp.arange(static_env_params.num_polygons + static_env_params.num_circles), p=p
|
715 |
+
)
|
716 |
+
role_a = all_roles[shape_idx_a]
|
717 |
+
p = jnp.where(all_roles == role_a, 0.0, p)
|
718 |
+
shape_idx_b = jax.random.choice(
|
719 |
+
rngs[1], jnp.arange(static_env_params.num_polygons + static_env_params.num_circles), p=p
|
720 |
+
)
|
721 |
+
role_b = all_roles[shape_idx_b]
|
722 |
+
role_a, role_b = role_b, role_a
|
723 |
+
|
724 |
+
for idx, role in [(shape_idx_a, role_a), (shape_idx_b, role_b)]:
|
725 |
+
is_rect = idx < static_env_params.num_polygons
|
726 |
+
state = state.replace(
|
727 |
+
polygon_shape_roles=jax.lax.select(
|
728 |
+
is_rect, state.polygon_shape_roles.at[idx].set(role), state.polygon_shape_roles
|
729 |
+
),
|
730 |
+
circle_shape_roles=jax.lax.select(
|
731 |
+
jnp.logical_not(is_rect),
|
732 |
+
state.circle_shape_roles.at[idx - static_env_params.num_polygons].set(role),
|
733 |
+
state.circle_shape_roles,
|
734 |
+
),
|
735 |
+
)
|
736 |
+
return state
|
737 |
+
|
738 |
+
return jax.lax.cond(are_there_multiple_roles, do_swap, dummy, rng, state)
|
739 |
+
|
740 |
+
|
741 |
+
@partial(jax.jit, static_argnums=(3, 4))
|
742 |
+
def mutate_toggle_fixture(
|
743 |
+
rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams
|
744 |
+
):
|
745 |
+
can_toggle_mask = (
|
746 |
+
jnp.concatenate([state.polygon.active, state.circle.active])
|
747 |
+
.at[: static_env_params.num_static_fixated_polys]
|
748 |
+
.set(False)
|
749 |
+
)
|
750 |
+
|
751 |
+
def dummy(rng, state):
|
752 |
+
return state
|
753 |
+
|
754 |
+
def do_toggle(rng, state: EnvState):
|
755 |
+
rng, _rng = jax.random.split(rng)
|
756 |
+
rngs = jax.random.split(_rng, 2)
|
757 |
+
p = can_toggle_mask.astype(jnp.float32)
|
758 |
+
index_to_remove = jax.random.choice(rngs[0], jnp.arange(can_toggle_mask.shape[0]), p=p)
|
759 |
+
is_rect = index_to_remove < static_env_params.num_polygons
|
760 |
+
is_current_fixed = (
|
761 |
+
jax.lax.select(
|
762 |
+
is_rect,
|
763 |
+
state.polygon.inverse_inertia[index_to_remove],
|
764 |
+
state.circle.inverse_inertia[index_to_remove - static_env_params.num_polygons],
|
765 |
+
)
|
766 |
+
== 0.0
|
767 |
+
)
|
768 |
+
|
769 |
+
is_current_fixed = is_current_fixed * 1.0 # if it is fixed, we set it to 1.0 and recalc.
|
770 |
+
# If it is not fixed, this is 0.0, and it makes it fixed.
|
771 |
+
|
772 |
+
state = state.replace(
|
773 |
+
polygon=state.polygon.replace(
|
774 |
+
inverse_inertia=jax.lax.select(
|
775 |
+
is_rect,
|
776 |
+
state.polygon.inverse_inertia.at[index_to_remove].set(is_current_fixed),
|
777 |
+
state.polygon.inverse_inertia,
|
778 |
+
),
|
779 |
+
inverse_mass=jax.lax.select(
|
780 |
+
is_rect,
|
781 |
+
state.polygon.inverse_mass.at[index_to_remove].set(is_current_fixed),
|
782 |
+
state.polygon.inverse_mass,
|
783 |
+
),
|
784 |
+
),
|
785 |
+
circle=state.circle.replace(
|
786 |
+
inverse_inertia=jax.lax.select(
|
787 |
+
jnp.logical_not(is_rect),
|
788 |
+
state.circle.inverse_inertia.at[index_to_remove - static_env_params.num_polygons].set(
|
789 |
+
is_current_fixed
|
790 |
+
),
|
791 |
+
state.circle.inverse_inertia,
|
792 |
+
),
|
793 |
+
inverse_mass=jax.lax.select(
|
794 |
+
jnp.logical_not(is_rect),
|
795 |
+
state.circle.inverse_mass.at[index_to_remove - static_env_params.num_polygons].set(
|
796 |
+
is_current_fixed
|
797 |
+
),
|
798 |
+
state.circle.inverse_mass,
|
799 |
+
),
|
800 |
+
),
|
801 |
+
)
|
802 |
+
|
803 |
+
state = recalculate_mass_and_inertia(state, static_env_params, state.polygon_densities, state.circle_densities)
|
804 |
+
return state
|
805 |
+
|
806 |
+
return jax.lax.cond(can_toggle_mask.sum() > 0, do_toggle, dummy, rng, state)
|
807 |
+
|
808 |
+
|
809 |
+
@partial(jax.jit, static_argnums=(3, 4))
|
810 |
+
def mutate_add_thruster(
|
811 |
+
rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams
|
812 |
+
):
|
813 |
+
is_fixated = jnp.concatenate([state.polygon.inverse_mass == 0, state.circle.inverse_mass == 0])
|
814 |
+
# is_fixated = jnp.zeros_like(is_fixated, dtype=bool)
|
815 |
+
is_active = jnp.concatenate([state.polygon.active, state.circle.active])
|
816 |
+
can_add_mask = is_active & (~is_fixated)
|
817 |
+
can_add_mask = jnp.logical_and(is_active, jnp.logical_not(is_fixated))
|
818 |
+
|
819 |
+
def dummy(rng, state):
|
820 |
+
return state
|
821 |
+
|
822 |
+
def do_add(rng, state: EnvState):
|
823 |
+
rng, _rng = jax.random.split(rng)
|
824 |
+
_rngs = jax.random.split(_rng, 10)
|
825 |
+
p = can_add_mask.astype(jnp.float32)
|
826 |
+
shape_index = jax.random.choice(_rngs[0], jnp.arange(can_add_mask.shape[0]), p=p)
|
827 |
+
thruster_idx = jnp.argmin(state.thruster.active)
|
828 |
+
|
829 |
+
shape = select_shape(state, shape_index, static_env_params)
|
830 |
+
|
831 |
+
position_to_add_thruster = jax.lax.select(
|
832 |
+
shape_index < static_env_params.num_polygons,
|
833 |
+
random_position_on_polygon(_rngs[1], shape.vertices, shape.n_vertices, static_env_params),
|
834 |
+
random_position_on_circle(_rngs[2], shape.radius, on_centre_chance=0.0),
|
835 |
+
)
|
836 |
+
|
837 |
+
direction_to_com = ((jax.random.uniform(_rngs[3]) > 0.5) * 2 - 1) * position_to_add_thruster
|
838 |
+
direction_to_com = jax.lax.select(
|
839 |
+
jnp.linalg.norm(direction_to_com) == 0.0, jnp.array([1.0, 0.0]), direction_to_com
|
840 |
+
)
|
841 |
+
|
842 |
+
thruster_angle = jax.lax.select(
|
843 |
+
jax.random.uniform(_rngs[4]) < ued_params.thruster_align_com_prob,
|
844 |
+
jnp.atan2(direction_to_com[1], direction_to_com[0]), # test this
|
845 |
+
jax.random.uniform(
|
846 |
+
_rngs[5],
|
847 |
+
(),
|
848 |
+
)
|
849 |
+
* 2
|
850 |
+
* jnp.pi,
|
851 |
+
)
|
852 |
+
|
853 |
+
thruster_power = jax.random.uniform(_rngs[6]) * 1.5 + 0.5
|
854 |
+
|
855 |
+
thruster = Thruster(
|
856 |
+
object_index=shape_index,
|
857 |
+
active=True,
|
858 |
+
relative_position=position_to_add_thruster, # jnp.array([0.0, 0.0]), # a bit of a hack but reasonable.
|
859 |
+
rotation=thruster_angle, # jax.random.choice(rngs[1], jnp.arange(4) * jnp.pi / 2),
|
860 |
+
power=1.0
|
861 |
+
/ jax.lax.select(shape.inverse_mass == 0, 1.0, shape.inverse_mass)
|
862 |
+
* ued_params.thruster_power_multiplier
|
863 |
+
* thruster_power,
|
864 |
+
global_position=shape.position + jnp.matmul(rmat(shape.rotation), position_to_add_thruster),
|
865 |
+
)
|
866 |
+
thruster_colour = jax.random.randint(
|
867 |
+
_rngs[7], shape=(), minval=0, maxval=static_env_params.num_thruster_bindings
|
868 |
+
)
|
869 |
+
|
870 |
+
state = state.replace(
|
871 |
+
thruster=jax.tree_map(lambda y, x: y.at[thruster_idx].set(x), state.thruster, thruster),
|
872 |
+
thruster_bindings=state.thruster_bindings.at[thruster_idx].set(thruster_colour),
|
873 |
+
)
|
874 |
+
|
875 |
+
return state
|
876 |
+
|
877 |
+
return jax.lax.cond(
|
878 |
+
jnp.logical_and((can_add_mask.sum() > 0), (jnp.logical_not(state.thruster.active).sum() > 0)),
|
879 |
+
do_add,
|
880 |
+
dummy,
|
881 |
+
rng,
|
882 |
+
state,
|
883 |
+
)
|
884 |
+
|
885 |
+
|
886 |
+
@partial(jax.jit, static_argnums=(3, 4))
|
887 |
+
def mutate_change_gravity(
|
888 |
+
rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams
|
889 |
+
):
|
890 |
+
rng, _rng = jax.random.split(rng)
|
891 |
+
rngs = jax.random.split(_rng, 2)
|
892 |
+
new_gravity = jax.lax.select(
|
893 |
+
jax.random.uniform(rngs[0]) < 0.5,
|
894 |
+
jnp.array([0.0, -9.8]),
|
895 |
+
jnp.array([0.0, jax.random.uniform(rngs[1], minval=-9.8, maxval=0)]),
|
896 |
+
)
|
897 |
+
|
898 |
+
return state.replace(gravity=new_gravity)
|
899 |
+
|
900 |
+
|
901 |
+
@partial(jax.jit, static_argnums=(3, 4))
|
902 |
+
def mutate_remove_thruster(
|
903 |
+
rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams
|
904 |
+
):
|
905 |
+
are_there_thrusters = state.thruster.active
|
906 |
+
|
907 |
+
def dummy(rng, state):
|
908 |
+
return state
|
909 |
+
|
910 |
+
def do_remove(rng, state):
|
911 |
+
rng, _rng = jax.random.split(rng)
|
912 |
+
rngs = jax.random.split(_rng, 2)
|
913 |
+
p = are_there_thrusters.astype(jnp.float32)
|
914 |
+
thruster_idx = jax.random.choice(rngs[0], jnp.arange(are_there_thrusters.shape[0]), p=p)
|
915 |
+
return state.replace(thruster=state.thruster.replace(active=state.thruster.active.at[thruster_idx].set(False)))
|
916 |
+
|
917 |
+
return jax.lax.cond(are_there_thrusters.sum() > 0, do_remove, dummy, rng, state)
|
918 |
+
|
919 |
+
|
920 |
+
def make_mutate_change_shape_size(params, static_env_params):
|
921 |
+
do_dummy_step = make_do_dummy_step(params, static_env_params)
|
922 |
+
|
923 |
+
@partial(jax.jit, static_argnums=(3, 4))
|
924 |
+
def mutate_change_shape_size(
|
925 |
+
rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams
|
926 |
+
):
|
927 |
+
shape_active = jnp.concatenate(
|
928 |
+
[state.polygon.active.at[: static_env_params.num_static_fixated_polys].set(False), state.circle.active]
|
929 |
+
)
|
930 |
+
|
931 |
+
def dummy(rng, state):
|
932 |
+
return state
|
933 |
+
|
934 |
+
def do_change(rng, state):
|
935 |
+
rng, _rng = jax.random.split(rng)
|
936 |
+
rngs = jax.random.split(_rng, 10)
|
937 |
+
p = shape_active.astype(jnp.float32)
|
938 |
+
shape_idx = jax.random.choice(rngs[0], jnp.arange(shape_active.shape[0]), p=p)
|
939 |
+
is_rect = shape_idx < static_env_params.num_polygons
|
940 |
+
vertices, _, radius = sample_dimensions(
|
941 |
+
rngs[1], static_env_params, is_rect, ued_params, max_shape_size=ued_params.max_shape_size
|
942 |
+
)
|
943 |
+
|
944 |
+
idx_new_top_left = jnp.argmin(vertices[:, 0] * 100 + vertices[:, 1])
|
945 |
+
idx_old_top_left = jnp.argmin(
|
946 |
+
state.polygon.vertices[shape_idx, :, 0] * 100 + state.polygon.vertices[shape_idx, :, 1]
|
947 |
+
)
|
948 |
+
scale_rect = (vertices[idx_new_top_left]) / (state.polygon.vertices[shape_idx, idx_old_top_left])
|
949 |
+
scale_circle = radius / state.circle.radius[shape_idx - static_env_params.num_polygons]
|
950 |
+
vertices = state.polygon.vertices[shape_idx] * scale_rect
|
951 |
+
|
952 |
+
scale = jax.lax.select(
|
953 |
+
is_rect,
|
954 |
+
scale_rect,
|
955 |
+
jnp.array([scale_circle, scale_circle]),
|
956 |
+
)
|
957 |
+
|
958 |
+
is_a = ((state.joint.a_index == shape_idx) & state.joint.active)[:, None]
|
959 |
+
is_b = ((state.joint.b_index == shape_idx) & state.joint.active)[:, None]
|
960 |
+
state = state.replace(
|
961 |
+
joint=state.joint.replace(
|
962 |
+
a_relative_pos=(state.joint.a_relative_pos * scale[None]) * is_a
|
963 |
+
+ (1 - is_a) * state.joint.a_relative_pos,
|
964 |
+
b_relative_pos=(state.joint.b_relative_pos * scale[None]) * is_b
|
965 |
+
+ (1 - is_b) * state.joint.b_relative_pos,
|
966 |
+
),
|
967 |
+
polygon=state.polygon.replace(
|
968 |
+
vertices=jax.lax.select(
|
969 |
+
is_rect, state.polygon.vertices.at[shape_idx].set(vertices), state.polygon.vertices
|
970 |
+
),
|
971 |
+
),
|
972 |
+
circle=state.circle.replace(
|
973 |
+
radius=jax.lax.select(
|
974 |
+
jnp.logical_not(is_rect),
|
975 |
+
state.circle.radius.at[shape_idx - static_env_params.num_polygons].set(radius),
|
976 |
+
state.circle.radius,
|
977 |
+
)
|
978 |
+
),
|
979 |
+
)
|
980 |
+
|
981 |
+
def _ss(state, _):
|
982 |
+
return do_dummy_step(state), None
|
983 |
+
|
984 |
+
state = jax.lax.scan(_ss, state, jnp.arange(5))[0]
|
985 |
+
return recalculate_mass_and_inertia(
|
986 |
+
state, static_env_params, state.polygon_densities, state.circle_densities
|
987 |
+
)
|
988 |
+
|
989 |
+
return jax.lax.cond(shape_active.sum() > 0, do_change, dummy, rng, state)
|
990 |
+
|
991 |
+
return mutate_change_shape_size
|
992 |
+
|
993 |
+
|
994 |
+
@partial(jax.jit, static_argnums=(3, 4))
|
995 |
+
def mutate_change_shape_location(
|
996 |
+
rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams
|
997 |
+
):
|
998 |
+
shape_active = jnp.concatenate(
|
999 |
+
[state.polygon.active.at[: static_env_params.num_static_fixated_polys].set(False), state.circle.active]
|
1000 |
+
)
|
1001 |
+
|
1002 |
+
def dummy(rng, state):
|
1003 |
+
return state
|
1004 |
+
|
1005 |
+
def do_change(rng, state):
|
1006 |
+
rng, _rng = jax.random.split(rng)
|
1007 |
+
rngs = jax.random.split(_rng, 10)
|
1008 |
+
p = shape_active.astype(jnp.float32)
|
1009 |
+
shape_idx = jax.random.choice(rngs[0], jnp.arange(shape_active.shape[0]), p=p)
|
1010 |
+
delta_pos = jax.random.uniform(rngs[1], shape=(2,)) - 0.5 # [-0.5, 0.5]
|
1011 |
+
|
1012 |
+
positions = jnp.concatenate([state.polygon.position, state.circle.position])
|
1013 |
+
|
1014 |
+
mask_of_shape_locations_to_change = (
|
1015 |
+
(state.collision_matrix[shape_idx] == 0).at[: static_env_params.num_static_fixated_polys].set(False)
|
1016 |
+
)
|
1017 |
+
# check the new positions, but then maybe revert if any shape becomes out of bounds now.
|
1018 |
+
new_positions_tentative = positions * (
|
1019 |
+
1 - mask_of_shape_locations_to_change[:, None]
|
1020 |
+
) + mask_of_shape_locations_to_change[:, None] * (positions + delta_pos[None])
|
1021 |
+
|
1022 |
+
polys = state.polygon
|
1023 |
+
p_pos = new_positions_tentative[: static_env_params.num_polygons]
|
1024 |
+
c_pos = new_positions_tentative[static_env_params.num_polygons :] # state.circle.position
|
1025 |
+
rad = state.circle.radius
|
1026 |
+
rect_vertex_mask = jnp.arange(static_env_params.max_polygon_vertices)[None] < polys.n_vertices[:, None]
|
1027 |
+
rect_mask = polys.active.at[: static_env_params.num_static_fixated_polys].set(False)
|
1028 |
+
circ_mask = state.circle.active
|
1029 |
+
# check if new pos maybe goes out of bounds:
|
1030 |
+
min_x, max_x, min_y, max_y = (
|
1031 |
+
jnp.minimum(
|
1032 |
+
jnp.min(
|
1033 |
+
p_pos[:, 0] + jnp.min(polys.vertices[:, :, 0], where=rect_vertex_mask, initial=0, axis=1),
|
1034 |
+
where=rect_mask,
|
1035 |
+
initial=jnp.inf,
|
1036 |
+
),
|
1037 |
+
jnp.min(c_pos[:, 0] - rad, where=circ_mask, initial=jnp.inf),
|
1038 |
+
),
|
1039 |
+
jnp.maximum(
|
1040 |
+
jnp.max(
|
1041 |
+
p_pos[:, 0] + jnp.max(polys.vertices[:, :, 0], where=rect_vertex_mask, initial=0, axis=1),
|
1042 |
+
where=rect_mask,
|
1043 |
+
initial=-jnp.inf,
|
1044 |
+
),
|
1045 |
+
jnp.max(c_pos[:, 0] + rad, where=circ_mask, initial=-jnp.inf),
|
1046 |
+
),
|
1047 |
+
jnp.minimum(
|
1048 |
+
jnp.min(
|
1049 |
+
p_pos[:, 1] + jnp.min(polys.vertices[:, :, 1], where=rect_vertex_mask, initial=0, axis=1),
|
1050 |
+
where=rect_mask,
|
1051 |
+
initial=jnp.inf,
|
1052 |
+
),
|
1053 |
+
jnp.min(c_pos[:, 1] - rad, where=circ_mask, initial=jnp.inf),
|
1054 |
+
),
|
1055 |
+
jnp.maximum(
|
1056 |
+
jnp.max(
|
1057 |
+
p_pos[:, 1] + jnp.max(polys.vertices[:, :, 1], where=rect_vertex_mask, initial=0, axis=1),
|
1058 |
+
where=rect_mask,
|
1059 |
+
initial=-jnp.inf,
|
1060 |
+
),
|
1061 |
+
jnp.max(c_pos[:, 1] + rad, where=circ_mask, initial=-jnp.inf),
|
1062 |
+
),
|
1063 |
+
)
|
1064 |
+
|
1065 |
+
how_much_oob_x_left = jnp.maximum(0, 0 - min_x)
|
1066 |
+
how_much_oob_x_right = jnp.maximum(0, max_x - static_env_params.screen_dim[0] / params.pixels_per_unit)
|
1067 |
+
how_much_oob_y_down = jnp.maximum(0, 0.4 - min_y) # this is for the floor
|
1068 |
+
how_much_oob_y_up = jnp.maximum(0, max_y - static_env_params.screen_dim[1] / params.pixels_per_unit)
|
1069 |
+
|
1070 |
+
# correct by out of bounds factor
|
1071 |
+
positions = (
|
1072 |
+
new_positions_tentative
|
1073 |
+
+ jnp.array(
|
1074 |
+
[
|
1075 |
+
how_much_oob_x_left - how_much_oob_x_right,
|
1076 |
+
how_much_oob_y_down - how_much_oob_y_up,
|
1077 |
+
]
|
1078 |
+
)[None]
|
1079 |
+
* mask_of_shape_locations_to_change[:, None]
|
1080 |
+
)
|
1081 |
+
|
1082 |
+
state = state.replace(
|
1083 |
+
polygon=state.polygon.replace(
|
1084 |
+
position=positions[: static_env_params.num_polygons],
|
1085 |
+
),
|
1086 |
+
circle=state.circle.replace(
|
1087 |
+
position=positions[static_env_params.num_polygons :],
|
1088 |
+
),
|
1089 |
+
)
|
1090 |
+
return recalculate_mass_and_inertia(state, static_env_params, state.polygon_densities, state.circle_densities)
|
1091 |
+
|
1092 |
+
return jax.lax.cond(shape_active.sum() > 0, do_change, dummy, rng, state)
|
1093 |
+
|
1094 |
+
|
1095 |
+
def make_mutate_change_shape_rotation(params, static_env_params):
|
1096 |
+
do_dummy_step = make_do_dummy_step(params, static_env_params)
|
1097 |
+
|
1098 |
+
@partial(jax.jit, static_argnums=(3, 4))
|
1099 |
+
def mutate_change_shape_rotation(
|
1100 |
+
rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams
|
1101 |
+
):
|
1102 |
+
shape_active = jnp.concatenate(
|
1103 |
+
[state.polygon.active.at[: static_env_params.num_static_fixated_polys].set(False), state.circle.active]
|
1104 |
+
)
|
1105 |
+
|
1106 |
+
def dummy(rng, state):
|
1107 |
+
return state
|
1108 |
+
|
1109 |
+
def do_change(rng, state):
|
1110 |
+
rng, _rng = jax.random.split(rng)
|
1111 |
+
rngs = jax.random.split(_rng, 10)
|
1112 |
+
p = shape_active.astype(jnp.float32)
|
1113 |
+
shape_idx = jax.random.choice(rngs[0], jnp.arange(shape_active.shape[0]), p=p)
|
1114 |
+
is_rect = shape_idx < static_env_params.num_polygons
|
1115 |
+
|
1116 |
+
rotation_delta = jax.random.uniform(rngs[1], shape=()) * math.pi / 2
|
1117 |
+
|
1118 |
+
has_fixed_joint_a = (state.joint.a_index == shape_idx) & state.joint.is_fixed_joint & state.joint.active
|
1119 |
+
has_fixed_joint_b = (state.joint.b_index == shape_idx) & state.joint.is_fixed_joint & state.joint.active
|
1120 |
+
|
1121 |
+
state = state.replace(
|
1122 |
+
joint=state.joint.replace(
|
1123 |
+
rotation=jax.lax.select(
|
1124 |
+
has_fixed_joint_a,
|
1125 |
+
state.joint.rotation - rotation_delta,
|
1126 |
+
jax.lax.select(
|
1127 |
+
has_fixed_joint_b,
|
1128 |
+
state.joint.rotation + rotation_delta,
|
1129 |
+
state.joint.rotation,
|
1130 |
+
),
|
1131 |
+
)
|
1132 |
+
),
|
1133 |
+
polygon=state.polygon.replace(
|
1134 |
+
rotation=jax.lax.select(
|
1135 |
+
is_rect, state.polygon.rotation.at[shape_idx].add(rotation_delta), state.polygon.rotation
|
1136 |
+
),
|
1137 |
+
),
|
1138 |
+
circle=state.circle.replace(
|
1139 |
+
rotation=jax.lax.select(
|
1140 |
+
jnp.logical_not(is_rect),
|
1141 |
+
state.circle.rotation.at[shape_idx - static_env_params.num_polygons].add(rotation_delta),
|
1142 |
+
state.circle.rotation,
|
1143 |
+
)
|
1144 |
+
),
|
1145 |
+
)
|
1146 |
+
|
1147 |
+
def _ss(state, _):
|
1148 |
+
return do_dummy_step(state), None
|
1149 |
+
|
1150 |
+
state = jax.lax.scan(_ss, state, jnp.arange(5))[0]
|
1151 |
+
return recalculate_mass_and_inertia(
|
1152 |
+
state, static_env_params, state.polygon_densities, state.circle_densities
|
1153 |
+
)
|
1154 |
+
|
1155 |
+
return jax.lax.cond(shape_active.sum() > 0, do_change, dummy, rng, state)
|
1156 |
+
|
1157 |
+
return mutate_change_shape_rotation
|
kinetix/environment/ued/ued.py
ADDED
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
import math
|
3 |
+
import os
|
4 |
+
|
5 |
+
import chex
|
6 |
+
import jax
|
7 |
+
import jax.numpy as jnp
|
8 |
+
from flax.serialization import to_state_dict
|
9 |
+
from jax2d.engine import (
|
10 |
+
calculate_collision_matrix,
|
11 |
+
calc_inverse_mass_polygon,
|
12 |
+
calc_inverse_mass_circle,
|
13 |
+
calc_inverse_inertia_circle,
|
14 |
+
calc_inverse_inertia_polygon,
|
15 |
+
recalculate_mass_and_inertia,
|
16 |
+
select_shape,
|
17 |
+
PhysicsEngine,
|
18 |
+
)
|
19 |
+
from jax2d.sim_state import SimState, RigidBody, Joint, Thruster
|
20 |
+
from jax2d.maths import rmat
|
21 |
+
from kinetix.environment.env_state import EnvParams, EnvState, StaticEnvParams
|
22 |
+
from kinetix.environment.ued.distributions import (
|
23 |
+
create_vmapped_filtered_distribution,
|
24 |
+
sample_kinetix_level,
|
25 |
+
)
|
26 |
+
from kinetix.environment.ued.mutators import (
|
27 |
+
make_mutate_change_shape_rotation,
|
28 |
+
make_mutate_change_shape_size,
|
29 |
+
mutate_add_connected_shape_proper,
|
30 |
+
mutate_add_shape,
|
31 |
+
mutate_add_connected_shape,
|
32 |
+
mutate_change_shape_location,
|
33 |
+
mutate_remove_joint,
|
34 |
+
mutate_remove_shape,
|
35 |
+
mutate_swap_role,
|
36 |
+
mutate_toggle_fixture,
|
37 |
+
mutate_add_thruster,
|
38 |
+
mutate_remove_thruster,
|
39 |
+
mutate_change_gravity,
|
40 |
+
)
|
41 |
+
from kinetix.environment.ued.ued_state import UEDParams
|
42 |
+
from kinetix.environment.utils import permute_pcg_state
|
43 |
+
from kinetix.pcg.pcg import env_state_to_pcg_state, sample_pcg_state
|
44 |
+
from kinetix.util.config import generate_ued_params_from_config, generate_params_from_config
|
45 |
+
from kinetix.util.saving import get_pcg_state_from_json, load_pcg_state_pickle, load_world_state_pickle, stack_list_of_pytrees, expand_env_state
|
46 |
+
from flax import struct
|
47 |
+
from kinetix.environment.env import create_empty_env
|
48 |
+
from kinetix.util.learning import BASE_DIR, general_eval, get_eval_levels
|
49 |
+
|
50 |
+
|
51 |
+
def make_mutate_env(static_env_params: StaticEnvParams, params: EnvParams, ued_params: UEDParams):
|
52 |
+
mutate_size = make_mutate_change_shape_size(params, static_env_params)
|
53 |
+
mutate_rot = make_mutate_change_shape_rotation(params, static_env_params)
|
54 |
+
|
55 |
+
def mutate_level(rng, level: EnvState, n=1):
|
56 |
+
def inner(carry: tuple[chex.PRNGKey, EnvState], _):
|
57 |
+
rng, level = carry
|
58 |
+
rng, _rng, _rng2 = jax.random.split(rng, 3)
|
59 |
+
|
60 |
+
any_rects_left = jnp.logical_not(level.polygon.active).sum() > 0
|
61 |
+
any_circles_left = jnp.logical_not(level.circle.active).sum() > 0
|
62 |
+
any_joints_left = jnp.logical_not(level.joint.active).sum() > 0
|
63 |
+
any_thrust_left = jnp.logical_not(level.thruster.active).sum() > 0
|
64 |
+
has_any_thursters = level.thruster.active.sum() > 0
|
65 |
+
|
66 |
+
can_do_add_shape = any_rects_left | any_circles_left
|
67 |
+
can_do_add_joint = can_do_add_shape & any_joints_left
|
68 |
+
|
69 |
+
all_mutations = [
|
70 |
+
mutate_add_shape,
|
71 |
+
mutate_add_connected_shape_proper,
|
72 |
+
mutate_remove_joint,
|
73 |
+
mutate_remove_shape,
|
74 |
+
mutate_swap_role,
|
75 |
+
mutate_add_thruster,
|
76 |
+
mutate_remove_thruster,
|
77 |
+
mutate_toggle_fixture,
|
78 |
+
mutate_size,
|
79 |
+
mutate_change_shape_location,
|
80 |
+
mutate_rot,
|
81 |
+
]
|
82 |
+
|
83 |
+
def mypartial(f):
|
84 |
+
def inner(rng, level):
|
85 |
+
return f(rng, level, params, static_env_params, ued_params)
|
86 |
+
|
87 |
+
return inner
|
88 |
+
|
89 |
+
probs = jnp.array(
|
90 |
+
[
|
91 |
+
can_do_add_shape * 1.0,
|
92 |
+
can_do_add_joint * 1.0,
|
93 |
+
0.0,
|
94 |
+
0.0,
|
95 |
+
1.0,
|
96 |
+
any_thrust_left * 1.0,
|
97 |
+
has_any_thursters * 1.0,
|
98 |
+
0.1,
|
99 |
+
1.0,
|
100 |
+
1.0,
|
101 |
+
1.0,
|
102 |
+
]
|
103 |
+
)
|
104 |
+
|
105 |
+
all_mutations = [mypartial(i) for i in all_mutations]
|
106 |
+
index = jax.random.choice(_rng, jnp.arange(len(all_mutations)), (), p=probs)
|
107 |
+
level = jax.lax.switch(index, all_mutations, _rng2, level)
|
108 |
+
|
109 |
+
return (rng, level), None
|
110 |
+
|
111 |
+
(_, level), _ = jax.lax.scan(inner, (rng, level), None, length=n)
|
112 |
+
return level
|
113 |
+
|
114 |
+
return mutate_level
|
115 |
+
|
116 |
+
|
117 |
+
def make_create_eval_env():
|
118 |
+
eval_level1 = load_world_state_pickle("worlds/eval/eval_0610_car1")
|
119 |
+
eval_level2 = load_world_state_pickle("worlds/eval/eval_0610_car2")
|
120 |
+
eval_level3 = load_world_state_pickle("worlds/eval/eval_0628_ball_left")
|
121 |
+
eval_level4 = load_world_state_pickle("worlds/eval/eval_0628_ball_right")
|
122 |
+
eval_level5 = load_world_state_pickle("worlds/eval/eval_0628_hard_car_obstacle")
|
123 |
+
eval_level6 = load_world_state_pickle("worlds/eval/eval_0628_swingup")
|
124 |
+
|
125 |
+
def _create_eval_env(rng, env_params, static_env_params, index):
|
126 |
+
return jax.lax.switch(
|
127 |
+
index,
|
128 |
+
[
|
129 |
+
lambda: eval_level1,
|
130 |
+
lambda: eval_level2,
|
131 |
+
lambda: eval_level3,
|
132 |
+
lambda: eval_level4,
|
133 |
+
lambda: eval_level5,
|
134 |
+
lambda: eval_level6,
|
135 |
+
],
|
136 |
+
)
|
137 |
+
return jax.tree.map(lambda x, y: jax.lax.select(index == 0, x, y), eval_level1, eval_level2)
|
138 |
+
|
139 |
+
return _create_eval_env
|
140 |
+
|
141 |
+
|
142 |
+
def make_reset_train_function_with_mutations(
|
143 |
+
engine: PhysicsEngine, env_params: EnvParams, static_env_params: StaticEnvParams, config, make_pcg_state=True
|
144 |
+
):
|
145 |
+
ued_params = generate_ued_params_from_config(config)
|
146 |
+
|
147 |
+
def reset(rng):
|
148 |
+
inner = sample_kinetix_level(
|
149 |
+
rng, engine, env_params, static_env_params, ued_params, env_size_name=config["env_size_name"]
|
150 |
+
)
|
151 |
+
|
152 |
+
if make_pcg_state:
|
153 |
+
return env_state_to_pcg_state(inner)
|
154 |
+
else:
|
155 |
+
return inner
|
156 |
+
|
157 |
+
return reset
|
158 |
+
|
159 |
+
|
160 |
+
def make_vmapped_filtered_level_sampler(
|
161 |
+
level_sampler, env_params: EnvParams, static_env_params: StaticEnvParams, config, make_pcg_state, env
|
162 |
+
):
|
163 |
+
ued_params = generate_ued_params_from_config(config)
|
164 |
+
|
165 |
+
def reset(rng, n_samples):
|
166 |
+
inner = create_vmapped_filtered_distribution(
|
167 |
+
rng,
|
168 |
+
level_sampler,
|
169 |
+
env_params,
|
170 |
+
static_env_params,
|
171 |
+
ued_params,
|
172 |
+
n_samples,
|
173 |
+
env,
|
174 |
+
config["filter_levels"],
|
175 |
+
config["level_filter_sample_ratio"],
|
176 |
+
config["env_size_name"],
|
177 |
+
config["level_filter_n_steps"],
|
178 |
+
)
|
179 |
+
if make_pcg_state:
|
180 |
+
return env_state_to_pcg_state(inner)
|
181 |
+
else:
|
182 |
+
return inner
|
183 |
+
|
184 |
+
return reset
|
185 |
+
|
186 |
+
|
187 |
+
def make_reset_train_function_with_list_of_levels(config, levels, static_env_params, make_pcg_state=True,
|
188 |
+
is_loading_train_levels=False):
|
189 |
+
assert len(levels) > 0, "Need to provide at least one level to train on"
|
190 |
+
if config["load_train_levels_legacy"]:
|
191 |
+
ls = [get_pcg_state_from_json(os.path.join(BASE_DIR, l + ("" if l.endswith(".json") else ".json"))) for l in levels]
|
192 |
+
v = stack_list_of_pytrees(ls)
|
193 |
+
elif is_loading_train_levels:
|
194 |
+
v = get_eval_levels(levels, static_env_params)
|
195 |
+
else:
|
196 |
+
_, static_env_params = generate_params_from_config(
|
197 |
+
config["eval_env_size_true"] | {"frame_skip": config["frame_skip"]}
|
198 |
+
)
|
199 |
+
v = get_eval_levels(levels, static_env_params)
|
200 |
+
|
201 |
+
def reset(rng):
|
202 |
+
rng, _rng, _rng2 = jax.random.split(rng, 3)
|
203 |
+
idx = jax.random.randint(_rng, (), 0, len(levels))
|
204 |
+
state_to_return = jax.tree.map(lambda x: x[idx], v)
|
205 |
+
|
206 |
+
if config["permute_state_during_training"]:
|
207 |
+
state_to_return = permute_pcg_state(rng, state_to_return, static_env_params)
|
208 |
+
if not make_pcg_state:
|
209 |
+
state_to_return = sample_pcg_state(_rng2, state_to_return, params=None, static_params=static_env_params)
|
210 |
+
|
211 |
+
return state_to_return
|
212 |
+
|
213 |
+
return reset
|
214 |
+
|
215 |
+
|
216 |
+
ALL_MUTATION_FNS = [
|
217 |
+
mutate_add_shape,
|
218 |
+
mutate_add_connected_shape,
|
219 |
+
mutate_remove_joint,
|
220 |
+
mutate_swap_role,
|
221 |
+
mutate_toggle_fixture,
|
222 |
+
mutate_add_thruster,
|
223 |
+
mutate_remove_thruster,
|
224 |
+
mutate_remove_shape,
|
225 |
+
mutate_change_gravity,
|
226 |
+
]
|
227 |
+
|
228 |
+
|
229 |
+
def test_ued():
|
230 |
+
from kinetix.environment.env import create_empty_env
|
231 |
+
|
232 |
+
env_params = EnvParams()
|
233 |
+
static_env_params = StaticEnvParams()
|
234 |
+
ued_params = UEDParams()
|
235 |
+
rng = jax.random.PRNGKey(0)
|
236 |
+
rng, _rng = jax.random.split(rng)
|
237 |
+
state = create_empty_env(env_params, static_env_params)
|
238 |
+
state = mutate_add_shape(_rng, state, env_params, static_env_params, ued_params)
|
239 |
+
state = mutate_add_connected_shape(_rng, state, env_params, static_env_params, ued_params)
|
240 |
+
state = mutate_remove_shape(_rng, state, env_params, static_env_params, ued_params)
|
241 |
+
state = mutate_remove_joint(_rng, state, env_params, static_env_params, ued_params)
|
242 |
+
state = mutate_swap_role(_rng, state, env_params, static_env_params, ued_params)
|
243 |
+
state = mutate_toggle_fixture(_rng, state, env_params, static_env_params, ued_params)
|
244 |
+
|
245 |
+
print("Successfully did this")
|
246 |
+
|
247 |
+
|
248 |
+
if __name__ == "__main__":
|
249 |
+
test_ued()
|
kinetix/environment/ued/ued_state.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
from flax import struct
|
4 |
+
|
5 |
+
|
6 |
+
@struct.dataclass
|
7 |
+
class UEDParams:
|
8 |
+
max_shape_size: float = 1.0
|
9 |
+
goal_body_opposide_side_chance: float = 0.5
|
10 |
+
goal_body_size_factor: float = 1.0
|
11 |
+
min_rjoints_bias: int = 2
|
12 |
+
|
13 |
+
large_rect_dim_chance: float = 0.3
|
14 |
+
large_rect_dim_scale: float = 2.0
|
15 |
+
|
16 |
+
generate_triangles: bool = False
|
17 |
+
thruster_power_multiplier: float = 2.0
|
18 |
+
|
19 |
+
thruster_align_com_prob: float = 0.8
|
20 |
+
|
21 |
+
motor_on_chance: float = 0.8
|
22 |
+
motor_min_speed: float = 0.4
|
23 |
+
motor_max_speed: float = 3.0
|
24 |
+
motor_min_power: float = 1.0
|
25 |
+
motor_max_power: float = 3.0
|
26 |
+
wheel_max_power: float = 1.0
|
27 |
+
|
28 |
+
joint_limit_chance: float = 0.4
|
29 |
+
joint_limit_max: float = math.pi
|
30 |
+
joint_fixed_chance: float = 0.1
|
31 |
+
|
32 |
+
fixate_chance_min: float = 0.02
|
33 |
+
fixate_chance_max: float = 1.0
|
34 |
+
fixate_chance_scale: float = 4.0 # Fixation probability scales with size
|
35 |
+
fixate_shape_bottom_bias: float = 0.0
|
36 |
+
fixate_shape_bottom_bias_special_role: float = 0.6
|
37 |
+
|
38 |
+
circle_max_size_coeff: float = 0.8
|
39 |
+
|
40 |
+
connect_to_fixated_prob_coeff: float = 0.05
|
41 |
+
connect_visibility_min: float = 0.05
|
42 |
+
connect_no_visibility_bias: float = 10.0
|
43 |
+
|
44 |
+
add_shape_chance: float = 0.35
|
45 |
+
add_connected_shape_chance: float = 0.35
|
46 |
+
add_no_shape_chance: float = 0.3
|
47 |
+
add_thruster_chance: float = 0.3
|
48 |
+
add_shape_n_proposals: int = 8
|
49 |
+
|
50 |
+
floor_prob_normal: float = 0.9
|
51 |
+
floor_prob_green: float = 0.0
|
52 |
+
floor_prob_blue: float = 0.02
|
53 |
+
floor_prob_red: float = 0.08
|
kinetix/environment/ued/util.py
ADDED
@@ -0,0 +1,358 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from functools import partial
|
3 |
+
|
4 |
+
import jax
|
5 |
+
import jax.numpy as jnp
|
6 |
+
|
7 |
+
from jax2d.engine import PhysicsEngine, calculate_collision_matrix, recalculate_mass_and_inertia, select_shape
|
8 |
+
from jax2d.sim_state import RigidBody, Thruster
|
9 |
+
from kinetix.environment.env_state import EnvParams, EnvState, StaticEnvParams
|
10 |
+
|
11 |
+
|
12 |
+
def sample_dimensions(rng, static_env_params: StaticEnvParams, is_rect: bool, ued_params, max_shape_size=None):
|
13 |
+
if max_shape_size is None:
|
14 |
+
max_shape_size = static_env_params.max_shape_size
|
15 |
+
# Returns (half_dimensions, radius)
|
16 |
+
|
17 |
+
rng, _rng = jax.random.split(rng)
|
18 |
+
# Don't want overly small shapes
|
19 |
+
min_rect_size = 0.05
|
20 |
+
min_circle_size = 0.1
|
21 |
+
cap_rect = max_shape_size / 2.0 / jnp.sqrt(2.0)
|
22 |
+
cap_circ = max_shape_size / 2.0 * ued_params.circle_max_size_coeff
|
23 |
+
half_dimensions = (
|
24 |
+
jax.lax.select(is_rect, jax.random.uniform(_rng, shape=(2,)), jnp.zeros(2, dtype=jnp.float32))
|
25 |
+
* (cap_rect - min_rect_size)
|
26 |
+
+ min_rect_size
|
27 |
+
)
|
28 |
+
|
29 |
+
rng, _rng, __rng = jax.random.split(rng, 3)
|
30 |
+
dim_scale = (
|
31 |
+
jnp.ones(2)
|
32 |
+
.at[jax.random.randint(_rng, shape=(), minval=0, maxval=2)]
|
33 |
+
.set(
|
34 |
+
jax.lax.select(
|
35 |
+
jax.random.uniform(__rng) < ued_params.large_rect_dim_chance, ued_params.large_rect_dim_scale, 1.0
|
36 |
+
)
|
37 |
+
)
|
38 |
+
)
|
39 |
+
half_dimensions *= dim_scale
|
40 |
+
|
41 |
+
vertices = jnp.array(
|
42 |
+
[
|
43 |
+
half_dimensions * jnp.array([1, 1]),
|
44 |
+
half_dimensions * jnp.array([1, -1]),
|
45 |
+
half_dimensions * jnp.array([-1, -1]),
|
46 |
+
half_dimensions * jnp.array([-1, 1]),
|
47 |
+
]
|
48 |
+
)
|
49 |
+
|
50 |
+
rng, _rng = jax.random.split(rng)
|
51 |
+
radius = (
|
52 |
+
jax.lax.select(is_rect, jnp.zeros((), dtype=jnp.float32), jax.random.uniform(_rng, shape=()))
|
53 |
+
* (cap_circ - min_circle_size)
|
54 |
+
+ min_circle_size
|
55 |
+
)
|
56 |
+
return vertices, half_dimensions, radius
|
57 |
+
|
58 |
+
|
59 |
+
def count_roles(state: EnvState, static_env_params: StaticEnvParams, role: int, include_static_polys=True) -> int:
|
60 |
+
active_to_use = state.polygon.active
|
61 |
+
if not include_static_polys:
|
62 |
+
active_to_use = active_to_use.at[: static_env_params.num_static_fixated_polys].set(False)
|
63 |
+
return ((state.polygon_shape_roles == role) * active_to_use).sum() + (
|
64 |
+
(state.circle_shape_roles == role) * state.circle.active
|
65 |
+
).sum()
|
66 |
+
|
67 |
+
|
68 |
+
def random_position_on_triangle(rng, vertices):
|
69 |
+
verts = vertices[:3]
|
70 |
+
rng, _rng, _rng2 = jax.random.split(rng, 3)
|
71 |
+
f1 = jax.random.uniform(_rng)
|
72 |
+
f2 = jax.random.uniform(_rng2)
|
73 |
+
# https://www.reddit.com/r/godot/comments/mqp29g/how_do_i_get_a_random_position_inside_a_collision/
|
74 |
+
return verts[0] + jnp.sqrt(f1) * (-verts[0] + verts[1] + f2 * (verts[2] - verts[1]))
|
75 |
+
|
76 |
+
|
77 |
+
def random_position_on_rectangle(rng, vertices):
|
78 |
+
verts = vertices[:4]
|
79 |
+
rng, _rng, _rng2 = jax.random.split(rng, 3)
|
80 |
+
f1 = jax.random.uniform(_rng)
|
81 |
+
f2 = jax.random.uniform(_rng2)
|
82 |
+
|
83 |
+
min_x, max_x = jnp.min(verts[:, 0]), jnp.max(verts[:, 0])
|
84 |
+
min_y, max_y = jnp.min(verts[:, 1]), jnp.max(verts[:, 1])
|
85 |
+
random_x_pos = min_x + f1 * (max_x - min_x)
|
86 |
+
random_y_pos = min_y + f2 * (max_y - min_y)
|
87 |
+
|
88 |
+
return jnp.array([random_x_pos, random_y_pos])
|
89 |
+
|
90 |
+
|
91 |
+
def random_position_on_polygon(rng, vertices, n_vertices, static_env_params: StaticEnvParams):
|
92 |
+
assert static_env_params.max_polygon_vertices <= 4, "Only supports up to 4 vertices"
|
93 |
+
return jax.lax.select(
|
94 |
+
n_vertices <= 3, random_position_on_triangle(rng, vertices), random_position_on_rectangle(rng, vertices)
|
95 |
+
)
|
96 |
+
|
97 |
+
|
98 |
+
def random_position_on_circle(rng, radius, on_centre_chance):
|
99 |
+
rngs = jax.random.split(rng, 3)
|
100 |
+
|
101 |
+
on_centre = jax.random.uniform(rngs[0]) < on_centre_chance
|
102 |
+
|
103 |
+
local_joint_position_circle_theta = jax.random.uniform(rngs[1], shape=()) * 2 * math.pi
|
104 |
+
local_joint_position_circle_r = jax.random.uniform(rngs[2], shape=()) * radius
|
105 |
+
local_joint_position_circle = jnp.array(
|
106 |
+
[
|
107 |
+
local_joint_position_circle_r * jnp.cos(local_joint_position_circle_theta),
|
108 |
+
local_joint_position_circle_r * jnp.sin(local_joint_position_circle_theta),
|
109 |
+
]
|
110 |
+
)
|
111 |
+
|
112 |
+
return jax.lax.select(on_centre, jnp.array([0.0, 0.0]), local_joint_position_circle)
|
113 |
+
|
114 |
+
|
115 |
+
def get_role(rng, state: EnvState, static_env_params: StaticEnvParams, initial_p=None) -> int:
|
116 |
+
|
117 |
+
if initial_p is None:
|
118 |
+
initial_p = jnp.array([1.0, 1.0, 1.0, 1.0])
|
119 |
+
|
120 |
+
needs_ball = count_roles(state, static_env_params, 1) == 0
|
121 |
+
needs_goal = count_roles(state, static_env_params, 2) == 0
|
122 |
+
needs_lava = count_roles(state, static_env_params, 3) == 0
|
123 |
+
|
124 |
+
# always put goal/ball first.
|
125 |
+
prob_of_something_else = (needs_ball == 0) & (needs_goal == 0)
|
126 |
+
p = initial_p * jnp.array(
|
127 |
+
[prob_of_something_else, needs_ball, needs_goal, prob_of_something_else * needs_lava / 3]
|
128 |
+
) # This ensures we cannot more than one ball or goal.
|
129 |
+
return jax.random.choice(rng, jnp.array([0, 1, 2, 3]), p=p)
|
130 |
+
|
131 |
+
|
132 |
+
def is_space_for_shape(state: EnvState):
|
133 |
+
return jnp.logical_not(jnp.concatenate([state.polygon.active, state.circle.active])).sum() > 0
|
134 |
+
|
135 |
+
|
136 |
+
def is_space_for_joint(state: EnvState):
|
137 |
+
return jnp.logical_not(state.joint.active).sum() > 0
|
138 |
+
|
139 |
+
|
140 |
+
def are_there_shapes_present(state: EnvState, static_env_params: StaticEnvParams):
|
141 |
+
m = (
|
142 |
+
jnp.concatenate([state.polygon.active, state.circle.active])
|
143 |
+
.at[: static_env_params.num_static_fixated_polys]
|
144 |
+
.set(False)
|
145 |
+
)
|
146 |
+
return m.sum() > 0
|
147 |
+
|
148 |
+
|
149 |
+
@partial(jax.jit, static_argnums=(2, 9))
|
150 |
+
def add_rigidbody_to_state(
|
151 |
+
state: EnvState,
|
152 |
+
env_params: EnvParams,
|
153 |
+
static_env_params: StaticEnvParams,
|
154 |
+
position: jnp.ndarray,
|
155 |
+
vertices: jnp.ndarray,
|
156 |
+
n_vertices: int,
|
157 |
+
radius: float,
|
158 |
+
shape_role: int,
|
159 |
+
density: float = 1,
|
160 |
+
is_circle: bool = False,
|
161 |
+
):
|
162 |
+
|
163 |
+
new_rigid_body = RigidBody(
|
164 |
+
position=position,
|
165 |
+
velocity=jnp.array([0.0, 0.0]),
|
166 |
+
inverse_mass=1.0,
|
167 |
+
inverse_inertia=1.0,
|
168 |
+
rotation=0.0,
|
169 |
+
angular_velocity=0.0,
|
170 |
+
radius=radius,
|
171 |
+
active=True,
|
172 |
+
friction=1.0,
|
173 |
+
vertices=vertices,
|
174 |
+
n_vertices=n_vertices,
|
175 |
+
collision_mode=1,
|
176 |
+
restitution=0.0,
|
177 |
+
)
|
178 |
+
|
179 |
+
if is_circle:
|
180 |
+
actives = state.circle.active
|
181 |
+
else:
|
182 |
+
actives = state.polygon.active
|
183 |
+
|
184 |
+
idx = jnp.argmin(actives)
|
185 |
+
|
186 |
+
def noop(state):
|
187 |
+
return state
|
188 |
+
|
189 |
+
def replace(state):
|
190 |
+
add_func = lambda all, new: all.at[idx].set(new)
|
191 |
+
if is_circle:
|
192 |
+
state = state.replace(
|
193 |
+
circle=jax.tree.map(add_func, state.circle, new_rigid_body),
|
194 |
+
circle_densities=state.circle_densities.at[idx].set(density),
|
195 |
+
circle_shape_roles=state.circle_shape_roles.at[idx].set(shape_role),
|
196 |
+
)
|
197 |
+
else:
|
198 |
+
state = state.replace(
|
199 |
+
polygon=jax.tree.map(add_func, state.polygon, new_rigid_body),
|
200 |
+
polygon_densities=state.polygon_densities.at[idx].set(density),
|
201 |
+
polygon_shape_roles=state.polygon_shape_roles.at[idx].set(shape_role),
|
202 |
+
)
|
203 |
+
|
204 |
+
state = state.replace(
|
205 |
+
collision_matrix=calculate_collision_matrix(static_env_params, state.joint),
|
206 |
+
)
|
207 |
+
|
208 |
+
state = recalculate_mass_and_inertia(state, static_env_params, state.polygon_densities, state.circle_densities)
|
209 |
+
return state
|
210 |
+
|
211 |
+
return jax.lax.cond(jnp.logical_not(actives).sum() > 0, replace, noop, state)
|
212 |
+
|
213 |
+
|
214 |
+
def rectangle_vertices(half_dim):
|
215 |
+
return jnp.array(
|
216 |
+
[
|
217 |
+
half_dim * jnp.array([1, 1]),
|
218 |
+
half_dim * jnp.array([1, -1]),
|
219 |
+
half_dim * jnp.array([-1, -1]),
|
220 |
+
half_dim * jnp.array([-1, 1]),
|
221 |
+
]
|
222 |
+
)
|
223 |
+
|
224 |
+
|
225 |
+
# More Manual Control
|
226 |
+
@partial(jax.jit, static_argnums=(2,))
|
227 |
+
def add_rectangle_to_state(
|
228 |
+
state: EnvState,
|
229 |
+
env_params: EnvParams,
|
230 |
+
static_env_params: StaticEnvParams,
|
231 |
+
position: jnp.ndarray,
|
232 |
+
width: float,
|
233 |
+
height: float,
|
234 |
+
shape_role: int,
|
235 |
+
density: float = 1,
|
236 |
+
):
|
237 |
+
|
238 |
+
return add_rigidbody_to_state(
|
239 |
+
state,
|
240 |
+
env_params,
|
241 |
+
static_env_params,
|
242 |
+
position,
|
243 |
+
rectangle_vertices(jnp.array([width, height]) / 2),
|
244 |
+
4,
|
245 |
+
0.0,
|
246 |
+
shape_role,
|
247 |
+
density,
|
248 |
+
is_circle=False,
|
249 |
+
)
|
250 |
+
|
251 |
+
|
252 |
+
@partial(jax.jit, static_argnums=(2,))
|
253 |
+
def add_circle_to_state(
|
254 |
+
state: EnvState,
|
255 |
+
env_params: EnvParams,
|
256 |
+
static_env_params: StaticEnvParams,
|
257 |
+
position: jnp.ndarray,
|
258 |
+
radius: float,
|
259 |
+
shape_role: int,
|
260 |
+
density: float = 1,
|
261 |
+
):
|
262 |
+
return add_rigidbody_to_state(
|
263 |
+
state,
|
264 |
+
env_params,
|
265 |
+
static_env_params,
|
266 |
+
position,
|
267 |
+
jnp.array([0.0, 0.0]),
|
268 |
+
0,
|
269 |
+
radius,
|
270 |
+
shape_role,
|
271 |
+
density,
|
272 |
+
is_circle=True,
|
273 |
+
)
|
274 |
+
|
275 |
+
|
276 |
+
@partial(jax.jit, static_argnums=(2,))
|
277 |
+
def add_thruster_to_object(
|
278 |
+
state: EnvState,
|
279 |
+
env_params: EnvParams,
|
280 |
+
static_env_params: StaticEnvParams,
|
281 |
+
shape_index: int,
|
282 |
+
rotation: float,
|
283 |
+
colour: int,
|
284 |
+
thruster_power_multiplier: float,
|
285 |
+
):
|
286 |
+
def dummy(state):
|
287 |
+
return state
|
288 |
+
|
289 |
+
def do_add(state: EnvState):
|
290 |
+
thruster_idx = jnp.argmin(state.thruster.active)
|
291 |
+
|
292 |
+
shape = select_shape(state, shape_index, static_env_params)
|
293 |
+
|
294 |
+
thruster = Thruster(
|
295 |
+
object_index=shape_index,
|
296 |
+
active=True,
|
297 |
+
relative_position=jnp.array([0.0, 0.0]), # a bit of a hack but reasonable.
|
298 |
+
rotation=rotation,
|
299 |
+
power=1.0 / jax.lax.select(shape.inverse_mass == 0, 1.0, shape.inverse_mass) * thruster_power_multiplier,
|
300 |
+
global_position=select_shape(state, shape_index, static_env_params).position,
|
301 |
+
)
|
302 |
+
|
303 |
+
state = state.replace(
|
304 |
+
thruster=jax.tree_map(lambda y, x: y.at[thruster_idx].set(x), state.thruster, thruster),
|
305 |
+
thruster_bindings=state.thruster_bindings.at[thruster_idx].set(colour),
|
306 |
+
)
|
307 |
+
|
308 |
+
return state
|
309 |
+
|
310 |
+
return jax.lax.cond(
|
311 |
+
(select_shape(state, shape_index, static_env_params).active)
|
312 |
+
& (jnp.logical_not(state.thruster.active).sum() > 0),
|
313 |
+
do_add,
|
314 |
+
dummy,
|
315 |
+
state,
|
316 |
+
)
|
317 |
+
|
318 |
+
|
319 |
+
def make_velocities_zero(state: EnvState):
|
320 |
+
def inner(state):
|
321 |
+
return state.replace(
|
322 |
+
polygon=state.polygon.replace(
|
323 |
+
angular_velocity=state.polygon.angular_velocity * 0,
|
324 |
+
velocity=state.polygon.velocity * 0,
|
325 |
+
),
|
326 |
+
circle=state.circle.replace(
|
327 |
+
angular_velocity=state.circle.angular_velocity * 0,
|
328 |
+
velocity=state.circle.velocity * 0,
|
329 |
+
),
|
330 |
+
)
|
331 |
+
|
332 |
+
return inner(state)
|
333 |
+
|
334 |
+
|
335 |
+
def make_do_dummy_step(
|
336 |
+
params: EnvParams, static_sim_params: StaticEnvParams, zero_collisions=True, zero_velocities=True
|
337 |
+
):
|
338 |
+
env = PhysicsEngine(static_sim_params)
|
339 |
+
|
340 |
+
@jax.jit
|
341 |
+
def _step_fn(state):
|
342 |
+
state, _ = env.step(state, params, jnp.zeros((static_sim_params.num_joints + static_sim_params.num_thrusters,)))
|
343 |
+
return state
|
344 |
+
|
345 |
+
def do_dummy_step(state: EnvState) -> EnvState:
|
346 |
+
rng = jax.random.PRNGKey(0)
|
347 |
+
og_col = state.collision_matrix
|
348 |
+
g = state.gravity
|
349 |
+
state = state.replace(
|
350 |
+
collision_matrix=state.collision_matrix & (not zero_collisions), gravity=state.gravity * 0
|
351 |
+
)
|
352 |
+
state = _step_fn(state)
|
353 |
+
state = state.replace(gravity=g, collision_matrix=og_col)
|
354 |
+
if zero_velocities:
|
355 |
+
state = make_velocities_zero(state)
|
356 |
+
return state
|
357 |
+
|
358 |
+
return do_dummy_step
|
kinetix/environment/utils.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import chex
|
2 |
+
import jax
|
3 |
+
from jax2d.engine import calculate_collision_matrix
|
4 |
+
from kinetix.environment.env_state import EnvState, StaticEnvParams
|
5 |
+
import jax.numpy as jnp
|
6 |
+
|
7 |
+
from kinetix.pcg.pcg_state import PCGState
|
8 |
+
|
9 |
+
|
10 |
+
def permute_state(rng: chex.PRNGKey, env_state: EnvState, static_env_params: StaticEnvParams):
|
11 |
+
idxs_circles = jnp.arange(static_env_params.num_circles)
|
12 |
+
idxs_polygons = jnp.arange(static_env_params.num_polygons)
|
13 |
+
idxs_joints = jnp.arange(static_env_params.num_joints)
|
14 |
+
idxs_thrusters = jnp.arange(static_env_params.num_thrusters)
|
15 |
+
|
16 |
+
rng, *_rngs = jax.random.split(rng, 5)
|
17 |
+
idxs_circles_permuted = jax.random.permutation(_rngs[0], idxs_circles, independent=True)
|
18 |
+
idxs_polygons_permuted = idxs_polygons.at[static_env_params.num_static_fixated_polys :].set(
|
19 |
+
jax.random.permutation(_rngs[1], idxs_polygons[static_env_params.num_static_fixated_polys :], independent=True)
|
20 |
+
)
|
21 |
+
|
22 |
+
idxs_joints_permuted = jax.random.permutation(_rngs[2], idxs_joints, independent=True)
|
23 |
+
idxs_thrusters_permuted = jax.random.permutation(_rngs[3], idxs_thrusters, independent=True)
|
24 |
+
|
25 |
+
combined = jnp.concatenate([idxs_polygons_permuted, idxs_circles_permuted + static_env_params.num_polygons])
|
26 |
+
# Change the ordering of the shapes, and also remember to change the indices associated with the joints
|
27 |
+
|
28 |
+
inverse_permutation = jnp.argsort(combined)
|
29 |
+
|
30 |
+
env_state = env_state.replace(
|
31 |
+
polygon_shape_roles=env_state.polygon_shape_roles[idxs_polygons_permuted],
|
32 |
+
circle_shape_roles=env_state.circle_shape_roles[idxs_circles_permuted],
|
33 |
+
polygon_highlighted=env_state.polygon_highlighted[idxs_polygons_permuted],
|
34 |
+
circle_highlighted=env_state.circle_highlighted[idxs_circles_permuted],
|
35 |
+
polygon_densities=env_state.polygon_densities[idxs_polygons_permuted],
|
36 |
+
circle_densities=env_state.circle_densities[idxs_circles_permuted],
|
37 |
+
polygon=jax.tree.map(lambda x: x[idxs_polygons_permuted], env_state.polygon),
|
38 |
+
circle=jax.tree.map(lambda x: x[idxs_circles_permuted], env_state.circle),
|
39 |
+
joint=env_state.joint.replace(
|
40 |
+
a_index=inverse_permutation[env_state.joint.a_index],
|
41 |
+
b_index=inverse_permutation[env_state.joint.b_index],
|
42 |
+
),
|
43 |
+
thruster=env_state.thruster.replace(
|
44 |
+
object_index=inverse_permutation[env_state.thruster.object_index],
|
45 |
+
),
|
46 |
+
)
|
47 |
+
|
48 |
+
# And now permute the thrusters and joints
|
49 |
+
env_state = env_state.replace(
|
50 |
+
thruster_bindings=env_state.thruster_bindings[idxs_thrusters_permuted],
|
51 |
+
motor_bindings=env_state.motor_bindings[idxs_joints_permuted],
|
52 |
+
motor_auto=env_state.motor_auto[idxs_joints_permuted],
|
53 |
+
joint=jax.tree.map(lambda x: x[idxs_joints_permuted], env_state.joint),
|
54 |
+
thruster=jax.tree.map(lambda x: x[idxs_thrusters_permuted], env_state.thruster),
|
55 |
+
)
|
56 |
+
# and collision matrix
|
57 |
+
env_state = env_state.replace(collision_matrix=calculate_collision_matrix(static_env_params, env_state.joint))
|
58 |
+
return env_state
|
59 |
+
|
60 |
+
|
61 |
+
def permute_pcg_state(rng: chex.PRNGKey, pcg_state: PCGState, static_env_params: StaticEnvParams):
|
62 |
+
return pcg_state.replace(
|
63 |
+
env_state=permute_state(rng, pcg_state.env_state, static_env_params),
|
64 |
+
env_state_max=permute_state(rng, pcg_state.env_state_max, static_env_params),
|
65 |
+
env_state_pcg_mask=jax.tree.map(lambda x: jnp.zeros_like(x, dtype=bool), pcg_state.env_state_pcg_mask),
|
66 |
+
)
|
kinetix/environment/wrappers.py
ADDED
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
from chex._src.pytypes import PRNGKey
|
3 |
+
import jax
|
4 |
+
import jax.numpy as jnp
|
5 |
+
import chex
|
6 |
+
from jax.numpy import ndarray
|
7 |
+
import numpy as np
|
8 |
+
from flax import struct
|
9 |
+
from functools import partial
|
10 |
+
from typing import Callable, Dict, Optional, Tuple, Union, Any
|
11 |
+
|
12 |
+
from gymnax.environments import spaces, environment
|
13 |
+
|
14 |
+
from kinetix.environment.env_state import EnvParams, EnvState
|
15 |
+
from jaxued.environments import UnderspecifiedEnv
|
16 |
+
|
17 |
+
|
18 |
+
class UnderspecifiedEnvWrapper(UnderspecifiedEnv):
|
19 |
+
"""Base class for Gymnax wrappers."""
|
20 |
+
|
21 |
+
def __init__(self, env):
|
22 |
+
self._env = env
|
23 |
+
|
24 |
+
# provide proxy access to regular attributes of wrapped object
|
25 |
+
def __getattr__(self, name):
|
26 |
+
return getattr(self._env, name)
|
27 |
+
|
28 |
+
|
29 |
+
class GymnaxWrapper(object):
|
30 |
+
"""Base class for Gymnax wrappers."""
|
31 |
+
|
32 |
+
def __init__(self, env):
|
33 |
+
self._env = env
|
34 |
+
|
35 |
+
# provide proxy access to regular attributes of wrapped object
|
36 |
+
def __getattr__(self, name):
|
37 |
+
return getattr(self._env, name)
|
38 |
+
|
39 |
+
|
40 |
+
# From Here: https://github.com/DramaCow/jaxued/blob/main/src/jaxued/wrappers/autoreset.py
|
41 |
+
class AutoResetWrapper(UnderspecifiedEnvWrapper):
|
42 |
+
"""
|
43 |
+
This is a wrapper around an `UnderspecifiedEnv`, allowing for the environment to be automatically reset upon completion of an episode. This behaviour is similar to the default Gymnax interface. The user can specify a callable `sample_level` that takes in a PRNGKey and returns a level.
|
44 |
+
|
45 |
+
Warning:
|
46 |
+
To maintain compliance with UnderspecifiedEnv interface, user can reset to an
|
47 |
+
arbitrary level. This includes levels outside the support of sample_level(). Consequently,
|
48 |
+
the tagged rng is defaulted to jax.random.PRNGKey(0). If your code relies on this, careful
|
49 |
+
attention may be required.
|
50 |
+
"""
|
51 |
+
|
52 |
+
def __init__(self, env: UnderspecifiedEnv, sample_level: Callable[[chex.PRNGKey], EnvState]):
|
53 |
+
self._env = env
|
54 |
+
self.sample_level = sample_level
|
55 |
+
|
56 |
+
@property
|
57 |
+
def default_params(self) -> EnvParams:
|
58 |
+
return self._env.default_params
|
59 |
+
|
60 |
+
def reset_env(self, rng, params):
|
61 |
+
rng, rng_sample, rng_reset = jax.random.split(rng, 3)
|
62 |
+
state_to_reset_to = self.sample_level(rng_sample)
|
63 |
+
return self._env.reset_env_to_pcg_level(rng_reset, state_to_reset_to, params)
|
64 |
+
|
65 |
+
def step_env(
|
66 |
+
self,
|
67 |
+
rng: chex.PRNGKey,
|
68 |
+
state: EnvState,
|
69 |
+
action: Union[int, float],
|
70 |
+
params: EnvParams,
|
71 |
+
) -> Tuple[chex.ArrayTree, EnvState, float, bool, dict]:
|
72 |
+
|
73 |
+
rng_reset, rng_step = jax.random.split(rng, 2)
|
74 |
+
obs_st, env_state_st, reward, done, info = self._env.step_env(rng_step, state, action, params)
|
75 |
+
obs_re, env_state_re = self.reset_env(rng_reset, params)
|
76 |
+
|
77 |
+
env_state = jax.tree_map(lambda x, y: jax.lax.select(done, x, y), env_state_re, env_state_st)
|
78 |
+
obs = jax.tree_map(lambda x, y: jax.lax.select(done, x, y), obs_re, obs_st)
|
79 |
+
|
80 |
+
return obs, env_state, reward, done, info
|
81 |
+
|
82 |
+
def reset_env_to_level(self, rng: chex.PRNGKey, level: EnvState, params: EnvParams) -> Tuple[Any, EnvState]:
|
83 |
+
# raise NotImplementedError("This method should not be called directly. Use reset instead.")
|
84 |
+
obs, env_state = self._env.reset_to_level(rng, level, params)
|
85 |
+
return obs, env_state
|
86 |
+
|
87 |
+
def action_space(self, params: EnvParams) -> Any:
|
88 |
+
return self._env.action_space(params)
|
89 |
+
|
90 |
+
|
91 |
+
class AutoReplayWrapper(UnderspecifiedEnv):
|
92 |
+
"""
|
93 |
+
This wrapper replay the **same** level over and over again by resetting to the same level after each episode.
|
94 |
+
This is useful for training/rolling out multiple times on the same level.
|
95 |
+
"""
|
96 |
+
|
97 |
+
def __init__(self, env: UnderspecifiedEnv):
|
98 |
+
self._env = env
|
99 |
+
|
100 |
+
@property
|
101 |
+
def default_params(self) -> EnvParams:
|
102 |
+
return self._env.default_params
|
103 |
+
|
104 |
+
def step_env(
|
105 |
+
self,
|
106 |
+
rng: chex.PRNGKey,
|
107 |
+
state: EnvState,
|
108 |
+
action: Union[int, float],
|
109 |
+
params: EnvParams,
|
110 |
+
) -> Tuple[chex.ArrayTree, EnvState, float, bool, dict]:
|
111 |
+
rng_reset, rng_step = jax.random.split(rng)
|
112 |
+
obs_re, env_state_re = self._env.reset_to_level(rng_reset, state.level, params)
|
113 |
+
obs_st, env_state_st, reward, done, info = self._env.step_env(rng_step, state.env_state, action, params)
|
114 |
+
env_state = jax.tree_map(lambda x, y: jax.lax.select(done, x, y), env_state_re, env_state_st)
|
115 |
+
obs = jax.tree_map(lambda x, y: jax.lax.select(done, x, y), obs_re, obs_st)
|
116 |
+
return obs, state.replace(env_state=env_state), reward, done, info
|
117 |
+
|
118 |
+
def reset_env_to_level(self, rng: chex.PRNGKey, level: EnvState, params: EnvParams) -> Tuple[Any, EnvState]:
|
119 |
+
obs, env_state = self._env.reset_to_level(rng, level, params)
|
120 |
+
return obs, AutoReplayState(env_state=env_state, level=level)
|
121 |
+
|
122 |
+
def action_space(self, params: EnvParams) -> Any:
|
123 |
+
return self._env.action_space(params)
|
124 |
+
|
125 |
+
|
126 |
+
@struct.dataclass
|
127 |
+
class AutoReplayState:
|
128 |
+
env_state: EnvState
|
129 |
+
level: EnvState
|
130 |
+
|
131 |
+
|
132 |
+
class AutoReplayWrapper(UnderspecifiedEnvWrapper):
|
133 |
+
"""
|
134 |
+
This wrapper replay the **same** level over and over again by resetting to the same level after each episode.
|
135 |
+
This is useful for training/rolling out multiple times on the same level.
|
136 |
+
"""
|
137 |
+
|
138 |
+
def __init__(self, env: UnderspecifiedEnv):
|
139 |
+
self._env = env
|
140 |
+
|
141 |
+
@property
|
142 |
+
def default_params(self) -> EnvParams:
|
143 |
+
return self._env.default_params
|
144 |
+
|
145 |
+
def step_env(
|
146 |
+
self,
|
147 |
+
rng: chex.PRNGKey,
|
148 |
+
state: EnvState,
|
149 |
+
action: Union[int, float],
|
150 |
+
params: EnvParams,
|
151 |
+
) -> Tuple[chex.ArrayTree, EnvState, float, bool, dict]:
|
152 |
+
rng_reset, rng_step = jax.random.split(rng)
|
153 |
+
obs_re, env_state_re = self._env.reset_to_level(rng_reset, state.level, params)
|
154 |
+
obs_st, env_state_st, reward, done, info = self._env.step_env(rng_step, state.env_state, action, params)
|
155 |
+
env_state = jax.tree_map(lambda x, y: jax.lax.select(done, x, y), env_state_re, env_state_st)
|
156 |
+
obs = jax.tree_map(lambda x, y: jax.lax.select(done, x, y), obs_re, obs_st)
|
157 |
+
return obs, state.replace(env_state=env_state), reward, done, info
|
158 |
+
|
159 |
+
def reset_env_to_level(self, rng: chex.PRNGKey, level: EnvState, params: EnvParams) -> Tuple[Any, EnvState]:
|
160 |
+
obs, env_state = self._env.reset_to_level(rng, level, params)
|
161 |
+
return obs, AutoReplayState(env_state=env_state, level=level)
|
162 |
+
|
163 |
+
def action_space(self, params: EnvParams) -> Any:
|
164 |
+
return self._env.action_space(params)
|
165 |
+
|
166 |
+
|
167 |
+
class UnderspecifiedToGymnaxWrapper(environment.Environment):
|
168 |
+
def __init__(self, env):
|
169 |
+
self._env = env
|
170 |
+
|
171 |
+
# provide proxy access to regular attributes of wrapped object
|
172 |
+
def __getattr__(self, name):
|
173 |
+
return getattr(self._env, name)
|
174 |
+
|
175 |
+
@property
|
176 |
+
def default_params(self) -> Any:
|
177 |
+
return self._env.default_params
|
178 |
+
|
179 |
+
def step_env(
|
180 |
+
self, key: jax.Array, state: Any, action: int | float | jax.Array | ndarray | np.bool_ | np.number, params: Any
|
181 |
+
) -> Tuple[jax.Array | ndarray | np.bool_ | np.number | Any | Dict[Any, Any]]:
|
182 |
+
return self._env.step_env(key, state, action, params)
|
183 |
+
|
184 |
+
def reset_env(self, key: PRNGKey, params: Any) -> Tuple[PRNGKey | np.ndarray | np.bool_ | np.number | Any]:
|
185 |
+
return self._env.reset_env(key, params)
|
186 |
+
|
187 |
+
def action_space(self, params: Any):
|
188 |
+
return self._env.action_space(params)
|
189 |
+
|
190 |
+
|
191 |
+
class BatchEnvWrapper(GymnaxWrapper):
|
192 |
+
"""Batches reset and step functions"""
|
193 |
+
|
194 |
+
def __init__(self, env, num_envs: int):
|
195 |
+
super().__init__(env)
|
196 |
+
|
197 |
+
self.num_envs = num_envs
|
198 |
+
|
199 |
+
self.reset_fn = jax.vmap(self._env.reset, in_axes=(0, None))
|
200 |
+
self.reset_to_level_fn = jax.vmap(self._env.reset_to_level, in_axes=(0, 0, None))
|
201 |
+
self.step_fn = jax.vmap(self._env.step, in_axes=(0, 0, 0, None))
|
202 |
+
|
203 |
+
@partial(jax.jit, static_argnums=(0, 2))
|
204 |
+
def reset(self, rng, params=None):
|
205 |
+
rng, _rng = jax.random.split(rng)
|
206 |
+
rngs = jax.random.split(_rng, self.num_envs)
|
207 |
+
obs, env_state = self.reset_fn(rngs, params)
|
208 |
+
return obs, env_state
|
209 |
+
|
210 |
+
@partial(jax.jit, static_argnums=(0, 3))
|
211 |
+
def reset_to_level(self, rng, level, params=None):
|
212 |
+
rng, _rng = jax.random.split(rng)
|
213 |
+
rngs = jax.random.split(_rng, self.num_envs)
|
214 |
+
obs, env_state = self.reset_to_level_fn(rngs, level, params)
|
215 |
+
return obs, env_state
|
216 |
+
|
217 |
+
@partial(jax.jit, static_argnums=(0, 4))
|
218 |
+
def step(self, rng, state, action, params=None):
|
219 |
+
rng, _rng = jax.random.split(rng)
|
220 |
+
rngs = jax.random.split(_rng, self.num_envs)
|
221 |
+
obs, state, reward, done, info = self.step_fn(rngs, state, action, params)
|
222 |
+
|
223 |
+
return obs, state, reward, done, info
|
224 |
+
|
225 |
+
|
226 |
+
@struct.dataclass
|
227 |
+
class DenseRewardState:
|
228 |
+
env_state: EnvState
|
229 |
+
last_distance: float = -1.0
|
230 |
+
|
231 |
+
|
232 |
+
class DenseRewardWrapper(GymnaxWrapper):
|
233 |
+
def __init__(self, env, dense_reward_scale: float = 1.0) -> None:
|
234 |
+
super().__init__(env)
|
235 |
+
self.dense_reward_scale = dense_reward_scale
|
236 |
+
|
237 |
+
def step(self, key, state, action: int, params=None):
|
238 |
+
obs, env_state, reward, done, info = self._env.step_env(key, state.env_state, action, params)
|
239 |
+
delta_dist = (
|
240 |
+
-(info["distance"] - state.last_distance) * params.dense_reward_scale
|
241 |
+
) # if distance got less, then reward is positive
|
242 |
+
|
243 |
+
delta_dist = jnp.nan_to_num(delta_dist, nan=0.0, posinf=0.0, neginf=0.0)
|
244 |
+
reward = reward + jax.lax.select(
|
245 |
+
(state.last_distance == -1) | (self.dense_reward_scale == 0.0), 0.0, delta_dist * self.dense_reward_scale
|
246 |
+
)
|
247 |
+
return obs, DenseRewardState(env_state, info["distance"]), reward, done, info
|
248 |
+
|
249 |
+
def reset(self, rng, params=None):
|
250 |
+
obs, env_state = self._env.reset(rng, params)
|
251 |
+
return obs, DenseRewardState(env_state, -1.0)
|
252 |
+
|
253 |
+
def reset_to_level(self, rng, level, params=None):
|
254 |
+
obs, env_state = self._env.reset_to_level(rng, level, params)
|
255 |
+
return obs, DenseRewardState(env_state, -1.0)
|
256 |
+
|
257 |
+
|
258 |
+
@struct.dataclass
|
259 |
+
class LogEnvState:
|
260 |
+
env_state: Any
|
261 |
+
episode_returns: float
|
262 |
+
episode_lengths: int
|
263 |
+
returned_episode_returns: float
|
264 |
+
returned_episode_lengths: int
|
265 |
+
timestep: int
|
266 |
+
|
267 |
+
|
268 |
+
class LogWrapper(GymnaxWrapper):
|
269 |
+
"""Log the episode returns and lengths."""
|
270 |
+
|
271 |
+
def __init__(self, env):
|
272 |
+
super().__init__(env)
|
273 |
+
|
274 |
+
@partial(jax.jit, static_argnums=(0, 2))
|
275 |
+
def reset(self, key: chex.PRNGKey, params=None):
|
276 |
+
obs, env_state = self._env.reset(key, params)
|
277 |
+
state = LogEnvState(env_state, 0.0, 0, 0.0, 0, 0)
|
278 |
+
return obs, state
|
279 |
+
|
280 |
+
def reset_to_level(self, key: chex.PRNGKey, level: EnvState, params=None):
|
281 |
+
obs, env_state = self._env.reset_to_level(key, level, params)
|
282 |
+
state = LogEnvState(env_state, 0.0, 0, 0.0, 0, 0)
|
283 |
+
return obs, state
|
284 |
+
|
285 |
+
@partial(jax.jit, static_argnums=(0, 4))
|
286 |
+
def step(
|
287 |
+
self,
|
288 |
+
key: chex.PRNGKey,
|
289 |
+
state,
|
290 |
+
action: Union[int, float],
|
291 |
+
params=None,
|
292 |
+
):
|
293 |
+
obs, env_state, reward, done, info = self._env.step(key, state.env_state, action, params)
|
294 |
+
new_episode_return = state.episode_returns + reward
|
295 |
+
new_episode_length = state.episode_lengths + 1
|
296 |
+
state = LogEnvState(
|
297 |
+
env_state=env_state,
|
298 |
+
episode_returns=new_episode_return * (1 - done),
|
299 |
+
episode_lengths=new_episode_length * (1 - done),
|
300 |
+
returned_episode_returns=state.returned_episode_returns * (1 - done) + new_episode_return * done,
|
301 |
+
returned_episode_lengths=state.returned_episode_lengths * (1 - done) + new_episode_length * done,
|
302 |
+
timestep=state.timestep + 1,
|
303 |
+
)
|
304 |
+
info["returned_episode_returns"] = state.returned_episode_returns
|
305 |
+
info["returned_episode_lengths"] = state.returned_episode_lengths
|
306 |
+
info["returned_episode_solved"] = info["GoalR"]
|
307 |
+
info["timestep"] = state.timestep
|
308 |
+
info["returned_episode"] = done
|
309 |
+
return obs, state, reward, done, info
|
kinetix/models/.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
actor_critic_old.py
|
2 |
+
gactor_gritic_old.py
|
kinetix/models/__init__.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from kinetix.models.actor_critic import (
|
2 |
+
ActorCriticPixelsRNN,
|
3 |
+
ActorCriticSymbolicRNN,
|
4 |
+
)
|
5 |
+
from kinetix.models.transformer_model import ActorCriticTransformer
|
6 |
+
|
7 |
+
|
8 |
+
def make_network_from_config(env, env_params, config, network_kws={}):
|
9 |
+
|
10 |
+
env_name = config["env_name"]
|
11 |
+
if "MultiDiscrete" in env_name:
|
12 |
+
action_mode = "multi_discrete"
|
13 |
+
elif "Discrete" in env_name:
|
14 |
+
action_mode = "discrete"
|
15 |
+
elif "Continuous" in env_name:
|
16 |
+
action_mode = "continuous"
|
17 |
+
elif "Hybrid" in env_name:
|
18 |
+
action_mode = "hybrid"
|
19 |
+
else:
|
20 |
+
raise ValueError(f"Unknown action mode for {env_name}")
|
21 |
+
action_dim = (
|
22 |
+
env.action_space(env_params).shape[0] if action_mode == "continuous" else env.action_space(env_params).n
|
23 |
+
)
|
24 |
+
if "hybrid_action_continuous_dim" not in network_kws:
|
25 |
+
network_kws["hybrid_action_continuous_dim"] = action_dim
|
26 |
+
|
27 |
+
if "multi_discrete_number_of_dims_per_distribution" not in network_kws:
|
28 |
+
num_joint_bindings = config["static_env_params"]["num_motor_bindings"]
|
29 |
+
num_thruster_bindings = config["static_env_params"]["num_thruster_bindings"]
|
30 |
+
network_kws["multi_discrete_number_of_dims_per_distribution"] = [3 for _ in range(num_joint_bindings)] + [
|
31 |
+
2 for _ in range(num_thruster_bindings)
|
32 |
+
]
|
33 |
+
network_kws["recurrent"] = config.get("recurrent_model", True)
|
34 |
+
|
35 |
+
if "Pixels" in env_name:
|
36 |
+
cls_to_use = ActorCriticPixelsRNN
|
37 |
+
elif "Symbolic" in env_name or "Blind" in env_name:
|
38 |
+
cls_to_use = ActorCriticSymbolicRNN
|
39 |
+
|
40 |
+
if "Entity" in env_name:
|
41 |
+
network = ActorCriticTransformer(
|
42 |
+
action_dim=action_dim,
|
43 |
+
fc_layer_width=config["fc_layer_width"],
|
44 |
+
fc_layer_depth=config["fc_layer_depth"],
|
45 |
+
action_mode=action_mode,
|
46 |
+
num_heads=config["num_heads"],
|
47 |
+
transformer_depth=config["transformer_depth"],
|
48 |
+
transformer_size=config["transformer_size"],
|
49 |
+
transformer_encoder_size=config["transformer_encoder_size"],
|
50 |
+
aggregate_mode=config["aggregate_mode"],
|
51 |
+
full_attention_mask=config["full_attention_mask"],
|
52 |
+
activation=config["activation"],
|
53 |
+
**network_kws,
|
54 |
+
)
|
55 |
+
else:
|
56 |
+
network = cls_to_use(
|
57 |
+
action_dim,
|
58 |
+
fc_layer_width=config["fc_layer_width"],
|
59 |
+
fc_layer_depth=config["fc_layer_depth"],
|
60 |
+
activation=config["activation"],
|
61 |
+
action_mode=action_mode,
|
62 |
+
**network_kws,
|
63 |
+
)
|
64 |
+
|
65 |
+
return network
|
kinetix/models/action_spaces.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Sequence
|
2 |
+
from chex import PRNGKey
|
3 |
+
import distrax
|
4 |
+
from flax import struct
|
5 |
+
import jax
|
6 |
+
import jax.numpy as jnp
|
7 |
+
|
8 |
+
|
9 |
+
@struct.dataclass
|
10 |
+
class HybridAction:
|
11 |
+
discrete: int
|
12 |
+
continuous: jnp.ndarray
|
13 |
+
|
14 |
+
|
15 |
+
class HybridActionDistribution(distrax.Distribution):
|
16 |
+
def __init__(self, discrete_logits, continuous_mu, continuous_sigma) -> None:
|
17 |
+
self.discrete = distrax.Categorical(logits=discrete_logits)
|
18 |
+
self.continuous = distrax.MultivariateNormalDiag(continuous_mu, continuous_sigma)
|
19 |
+
|
20 |
+
def _sample_n(self, rng: PRNGKey, n: int) -> Any:
|
21 |
+
rng, _rng, _rng2 = jax.random.split(rng, 3)
|
22 |
+
a = self.discrete._sample_n(_rng, n)
|
23 |
+
b = self.continuous._sample_n(_rng2, n)
|
24 |
+
return HybridAction(a, b)
|
25 |
+
|
26 |
+
def log_prob(self, value: Any):
|
27 |
+
a = self.discrete.log_prob(value.discrete)
|
28 |
+
b = self.continuous.log_prob(value.continuous)
|
29 |
+
return a + b # log probs, we add.
|
30 |
+
|
31 |
+
def entropy(self):
|
32 |
+
return self.discrete.entropy() + self.continuous.entropy()
|
33 |
+
|
34 |
+
def event_shape(self) -> Sequence[int]:
|
35 |
+
return ()
|
36 |
+
|
37 |
+
|
38 |
+
class MultiDiscreteActionDistribution(distrax.Distribution):
|
39 |
+
def __init__(self, flat_logits, number_of_dims_per_distribution) -> None:
|
40 |
+
self.distributions = []
|
41 |
+
total_dims = 0
|
42 |
+
for dims in number_of_dims_per_distribution:
|
43 |
+
self.distributions.append(distrax.Categorical(logits=flat_logits[..., total_dims : total_dims + dims]))
|
44 |
+
total_dims += dims
|
45 |
+
|
46 |
+
def _sample_n(self, key: PRNGKey, n: int) -> Any:
|
47 |
+
rngs = jax.random.split(key, len(self.distributions))
|
48 |
+
samples = [jnp.expand_dims(d._sample_n(rng, n), axis=-1) for rng, d in zip(rngs, self.distributions)]
|
49 |
+
return jnp.concatenate(samples, axis=-1)
|
50 |
+
|
51 |
+
def log_prob(self, value: Any):
|
52 |
+
return sum(d.log_prob(value[..., i]) for i, d in enumerate(self.distributions))
|
53 |
+
|
54 |
+
def entropy(self):
|
55 |
+
return sum(d.entropy() for d in self.distributions)
|
56 |
+
|
57 |
+
def event_shape(self) -> Sequence[int]:
|
58 |
+
return ()
|
kinetix/models/actor_critic.py
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
import jax
|
3 |
+
import jax.numpy as jnp
|
4 |
+
import flax.linen as nn
|
5 |
+
import numpy as np
|
6 |
+
from flax.linen.initializers import constant, orthogonal
|
7 |
+
from typing import List, Sequence
|
8 |
+
|
9 |
+
import distrax
|
10 |
+
|
11 |
+
from kinetix.models.action_spaces import HybridActionDistribution, MultiDiscreteActionDistribution
|
12 |
+
|
13 |
+
|
14 |
+
class ScannedRNN(nn.Module):
|
15 |
+
@functools.partial(
|
16 |
+
nn.scan,
|
17 |
+
variable_broadcast="params",
|
18 |
+
in_axes=0,
|
19 |
+
out_axes=0,
|
20 |
+
split_rngs={"params": False},
|
21 |
+
)
|
22 |
+
@nn.compact
|
23 |
+
def __call__(self, carry, x):
|
24 |
+
"""Applies the module."""
|
25 |
+
rnn_state = carry
|
26 |
+
ins, resets = x
|
27 |
+
rnn_state = jnp.where(
|
28 |
+
resets[:, np.newaxis],
|
29 |
+
self.initialize_carry(ins.shape[0], 256),
|
30 |
+
rnn_state,
|
31 |
+
)
|
32 |
+
new_rnn_state, y = nn.GRUCell(features=256)(rnn_state, ins)
|
33 |
+
return new_rnn_state, y
|
34 |
+
|
35 |
+
@staticmethod
|
36 |
+
def initialize_carry(batch_size, hidden_size=256):
|
37 |
+
# Use a dummy key since the default state init fn is just zeros.
|
38 |
+
cell = nn.GRUCell(features=256)
|
39 |
+
return cell.initialize_carry(jax.random.PRNGKey(0), (batch_size, hidden_size))
|
40 |
+
|
41 |
+
|
42 |
+
class GeneralActorCriticRNN(nn.Module):
|
43 |
+
action_dim: Sequence[int]
|
44 |
+
fc_layer_depth: int
|
45 |
+
fc_layer_width: int
|
46 |
+
action_mode: str # "continuous" or "discrete" or "hybrid"
|
47 |
+
hybrid_action_continuous_dim: int
|
48 |
+
multi_discrete_number_of_dims_per_distribution: List[int]
|
49 |
+
add_generator_embedding: bool = False
|
50 |
+
generator_embedding_number_of_timesteps: int = 10
|
51 |
+
recurrent: bool = False
|
52 |
+
|
53 |
+
# Given an embedding, return the action/values, since this is shared across all models.
|
54 |
+
@nn.compact
|
55 |
+
def __call__(self, hidden, obs, embedding, dones, activation):
|
56 |
+
|
57 |
+
if self.add_generator_embedding:
|
58 |
+
raise NotImplementedError()
|
59 |
+
|
60 |
+
if self.recurrent:
|
61 |
+
rnn_in = (embedding, dones)
|
62 |
+
hidden, embedding = ScannedRNN()(hidden, rnn_in)
|
63 |
+
|
64 |
+
actor_mean = embedding
|
65 |
+
critic = embedding
|
66 |
+
actor_mean_last = embedding
|
67 |
+
for _ in range(self.fc_layer_depth):
|
68 |
+
actor_mean = nn.Dense(
|
69 |
+
self.fc_layer_width,
|
70 |
+
kernel_init=orthogonal(np.sqrt(2)),
|
71 |
+
bias_init=constant(0.0),
|
72 |
+
)(actor_mean)
|
73 |
+
actor_mean = activation(actor_mean)
|
74 |
+
|
75 |
+
critic = nn.Dense(
|
76 |
+
self.fc_layer_width,
|
77 |
+
kernel_init=orthogonal(np.sqrt(2)),
|
78 |
+
bias_init=constant(0.0),
|
79 |
+
)(critic)
|
80 |
+
critic = activation(critic)
|
81 |
+
|
82 |
+
actor_mean_last = actor_mean
|
83 |
+
actor_mean = nn.Dense(self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0))(actor_mean)
|
84 |
+
if self.action_mode == "discrete":
|
85 |
+
pi = distrax.Categorical(logits=actor_mean)
|
86 |
+
elif self.action_mode == "continuous":
|
87 |
+
actor_logtstd = self.param("log_std", nn.initializers.zeros, (self.action_dim,))
|
88 |
+
pi = distrax.MultivariateNormalDiag(actor_mean, jnp.exp(actor_logtstd))
|
89 |
+
elif self.action_mode == "multi_discrete":
|
90 |
+
pi = MultiDiscreteActionDistribution(actor_mean, self.multi_discrete_number_of_dims_per_distribution)
|
91 |
+
else:
|
92 |
+
actor_mean_continuous = nn.Dense(
|
93 |
+
self.hybrid_action_continuous_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
|
94 |
+
)(actor_mean_last)
|
95 |
+
actor_mean_sigma = jnp.exp(
|
96 |
+
nn.Dense(self.hybrid_action_continuous_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0))(
|
97 |
+
actor_mean_last
|
98 |
+
)
|
99 |
+
)
|
100 |
+
pi = HybridActionDistribution(actor_mean, actor_mean_continuous, actor_mean_sigma)
|
101 |
+
|
102 |
+
critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(critic)
|
103 |
+
return hidden, pi, jnp.squeeze(critic, axis=-1)
|
104 |
+
|
105 |
+
|
106 |
+
class ActorCriticPixelsRNN(nn.Module):
|
107 |
+
|
108 |
+
action_dim: Sequence[int]
|
109 |
+
fc_layer_depth: int
|
110 |
+
fc_layer_width: int
|
111 |
+
action_mode: str
|
112 |
+
hybrid_action_continuous_dim: int
|
113 |
+
multi_discrete_number_of_dims_per_distribution: List[int]
|
114 |
+
activation: str
|
115 |
+
add_generator_embedding: bool = False
|
116 |
+
generator_embedding_number_of_timesteps: int = 10
|
117 |
+
recurrent: bool = True
|
118 |
+
|
119 |
+
@nn.compact
|
120 |
+
def __call__(self, hidden, x, **kwargs):
|
121 |
+
if self.activation == "relu":
|
122 |
+
activation = nn.relu
|
123 |
+
else:
|
124 |
+
activation = nn.tanh
|
125 |
+
og_obs, dones = x
|
126 |
+
|
127 |
+
if self.add_generator_embedding:
|
128 |
+
obs = og_obs.obs
|
129 |
+
else:
|
130 |
+
obs = og_obs
|
131 |
+
|
132 |
+
image = obs.image
|
133 |
+
global_info = obs.global_info
|
134 |
+
|
135 |
+
x = nn.Conv(features=16, kernel_size=(8, 8), strides=(4, 4))(image)
|
136 |
+
x = nn.relu(x)
|
137 |
+
x = nn.Conv(features=32, kernel_size=(4, 4), strides=(2, 2))(x)
|
138 |
+
x = nn.relu(x)
|
139 |
+
embedding = x.reshape(x.shape[0], x.shape[1], -1)
|
140 |
+
|
141 |
+
embedding = jnp.concatenate([embedding, global_info], axis=-1)
|
142 |
+
|
143 |
+
return GeneralActorCriticRNN(
|
144 |
+
action_dim=self.action_dim,
|
145 |
+
fc_layer_depth=self.fc_layer_depth,
|
146 |
+
fc_layer_width=self.fc_layer_width,
|
147 |
+
action_mode=self.action_mode,
|
148 |
+
hybrid_action_continuous_dim=self.hybrid_action_continuous_dim,
|
149 |
+
multi_discrete_number_of_dims_per_distribution=self.multi_discrete_number_of_dims_per_distribution,
|
150 |
+
add_generator_embedding=self.add_generator_embedding,
|
151 |
+
generator_embedding_number_of_timesteps=self.generator_embedding_number_of_timesteps,
|
152 |
+
recurrent=self.recurrent,
|
153 |
+
)(hidden, og_obs, embedding, dones, activation)
|
154 |
+
|
155 |
+
@staticmethod
|
156 |
+
def initialize_carry(batch_size, hidden_size=256):
|
157 |
+
return ScannedRNN.initialize_carry(batch_size, hidden_size)
|
158 |
+
|
159 |
+
|
160 |
+
class ActorCriticSymbolicRNN(nn.Module):
|
161 |
+
action_dim: Sequence[int]
|
162 |
+
fc_layer_width: int
|
163 |
+
action_mode: str
|
164 |
+
hybrid_action_continuous_dim: int
|
165 |
+
multi_discrete_number_of_dims_per_distribution: List[int]
|
166 |
+
fc_layer_depth: int
|
167 |
+
activation: str
|
168 |
+
add_generator_embedding: bool = False
|
169 |
+
generator_embedding_number_of_timesteps: int = 10
|
170 |
+
recurrent: bool = True
|
171 |
+
|
172 |
+
@nn.compact
|
173 |
+
def __call__(self, hidden, x):
|
174 |
+
if self.activation == "relu":
|
175 |
+
activation = nn.relu
|
176 |
+
else:
|
177 |
+
activation = nn.tanh
|
178 |
+
|
179 |
+
og_obs, dones = x
|
180 |
+
if self.add_generator_embedding:
|
181 |
+
obs = og_obs.obs
|
182 |
+
else:
|
183 |
+
obs = og_obs
|
184 |
+
|
185 |
+
embedding = nn.Dense(
|
186 |
+
self.fc_layer_width,
|
187 |
+
kernel_init=orthogonal(np.sqrt(2)),
|
188 |
+
bias_init=constant(0.0),
|
189 |
+
)(obs)
|
190 |
+
embedding = nn.relu(embedding)
|
191 |
+
|
192 |
+
return GeneralActorCriticRNN(
|
193 |
+
action_dim=self.action_dim,
|
194 |
+
fc_layer_depth=self.fc_layer_depth,
|
195 |
+
fc_layer_width=self.fc_layer_width,
|
196 |
+
action_mode=self.action_mode,
|
197 |
+
hybrid_action_continuous_dim=self.hybrid_action_continuous_dim,
|
198 |
+
multi_discrete_number_of_dims_per_distribution=self.multi_discrete_number_of_dims_per_distribution,
|
199 |
+
add_generator_embedding=self.add_generator_embedding,
|
200 |
+
generator_embedding_number_of_timesteps=self.generator_embedding_number_of_timesteps,
|
201 |
+
recurrent=self.recurrent,
|
202 |
+
)(hidden, og_obs, embedding, dones, activation)
|
203 |
+
|
204 |
+
@staticmethod
|
205 |
+
def initialize_carry(batch_size, hidden_size=256):
|
206 |
+
return ScannedRNN.initialize_carry(batch_size, hidden_size)
|
kinetix/models/rel_multi_head.py
ADDED
@@ -0,0 +1,546 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The Flax Authors.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
# CODE IS HEAVILY INSPIRED FROM https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/deprecated/transfo_xl/modeling_transfo_xl.py
|
17 |
+
# MOST OF THE TIME JUST A CONVERSION IN JAX
|
18 |
+
|
19 |
+
|
20 |
+
"""Relative Attention HEAVILY INSPIRED FROM https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/deprecated/transfo_xl/modeling_transfo_xl.py
|
21 |
+
, flax attention, https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py#L143, most of the time just a flax/jax conversion """
|
22 |
+
|
23 |
+
import functools
|
24 |
+
from typing import Any, Callable, Optional, Tuple
|
25 |
+
from flax.linen.dtypes import promote_dtype
|
26 |
+
|
27 |
+
from flax.linen import initializers
|
28 |
+
from flax.linen.linear import default_kernel_init
|
29 |
+
from flax.linen.linear import DenseGeneral
|
30 |
+
from flax.linen.linear import DotGeneralT
|
31 |
+
from flax.linen.linear import PrecisionLike
|
32 |
+
from flax.linen.module import compact
|
33 |
+
from flax.linen.module import merge_param
|
34 |
+
from flax.linen.module import Module
|
35 |
+
|
36 |
+
import jax
|
37 |
+
from jax import lax
|
38 |
+
from jax import random
|
39 |
+
import jax.numpy as jnp
|
40 |
+
|
41 |
+
PRNGKey = Any
|
42 |
+
Shape = Tuple[int, ...]
|
43 |
+
Dtype = Any
|
44 |
+
Array = Any
|
45 |
+
|
46 |
+
roll_vmap = jax.vmap(jnp.roll, in_axes=(-2, 0, None), out_axes=-2)
|
47 |
+
|
48 |
+
|
49 |
+
def _rel_shift(x):
|
50 |
+
zero_pad_shape = x.shape[:-2] + (x.shape[-2], 1)
|
51 |
+
zero_pad = jnp.zeros(zero_pad_shape, dtype=x.dtype)
|
52 |
+
x_padded = jnp.concatenate([zero_pad, x], axis=-1)
|
53 |
+
|
54 |
+
x_padded_shape = x.shape[:-2] + (x.shape[-1] + 1, x.shape[-2])
|
55 |
+
x_padded = x_padded.reshape(x_padded_shape)
|
56 |
+
# x_padded=jnp.swapaxes(x_padded,0,1)
|
57 |
+
|
58 |
+
x = jnp.take(x_padded, jnp.arange(1, x_padded.shape[-2]), axis=-2).reshape(x.shape)
|
59 |
+
|
60 |
+
return x
|
61 |
+
|
62 |
+
|
63 |
+
def dot_product_attention_weights(
|
64 |
+
query: Array,
|
65 |
+
key: Array,
|
66 |
+
r_pos_embed,
|
67 |
+
r_r_bias,
|
68 |
+
r_w_bias,
|
69 |
+
bias: Optional[Array] = None,
|
70 |
+
mask: Optional[Array] = None,
|
71 |
+
broadcast_dropout: bool = True,
|
72 |
+
dropout_rng: Optional[PRNGKey] = None,
|
73 |
+
dropout_rate: float = 0.0,
|
74 |
+
deterministic: bool = False,
|
75 |
+
dtype: Optional[Dtype] = None,
|
76 |
+
precision: PrecisionLike = None,
|
77 |
+
):
|
78 |
+
"""Computes dot-product attention weights given query and key.
|
79 |
+
|
80 |
+
Used by :func:`dot_product_attention`, which is what you'll most likely use.
|
81 |
+
But if you want access to the attention weights for introspection, then
|
82 |
+
you can directly call this function and call einsum yourself.
|
83 |
+
|
84 |
+
Args:
|
85 |
+
query: queries for calculating attention with shape of
|
86 |
+
`[batch..., q_length, num_heads, qk_depth_per_head]`.
|
87 |
+
key: keys for calculating attention with shape of
|
88 |
+
`[batch..., kv_length, num_heads, qk_depth_per_head]`.
|
89 |
+
bias: bias for the attention weights. This should be broadcastable to the
|
90 |
+
shape `[batch..., num_heads, q_length, kv_length]`.
|
91 |
+
This can be used for incorporating causal masks, padding masks,
|
92 |
+
proximity bias, etc.
|
93 |
+
mask: mask for the attention weights. This should be broadcastable to the
|
94 |
+
shape `[batch..., num_heads, q_length, kv_length]`.
|
95 |
+
This can be used for incorporating causal masks.
|
96 |
+
Attention weights are masked out if their corresponding mask value
|
97 |
+
is `False`.
|
98 |
+
broadcast_dropout: bool: use a broadcasted dropout along batch dims.
|
99 |
+
dropout_rng: JAX PRNGKey: to be used for dropout
|
100 |
+
dropout_rate: dropout rate
|
101 |
+
deterministic: bool, deterministic or not (to apply dropout)
|
102 |
+
dtype: the dtype of the computation (default: infer from inputs and params)
|
103 |
+
precision: numerical precision of the computation see `jax.lax.Precision`
|
104 |
+
for details.
|
105 |
+
|
106 |
+
Returns:
|
107 |
+
Output of shape `[batch..., num_heads, q_length, kv_length]`.
|
108 |
+
"""
|
109 |
+
query, key = promote_dtype(query, key, dtype=dtype)
|
110 |
+
dtype = query.dtype
|
111 |
+
|
112 |
+
assert query.ndim == key.ndim, "q, k must have same rank."
|
113 |
+
assert query.shape[:-3] == key.shape[:-3], "q, k batch dims must match."
|
114 |
+
assert query.shape[-2] == key.shape[-2], "q, k num_heads must match."
|
115 |
+
assert query.shape[-1] == key.shape[-1], "q, k depths must match."
|
116 |
+
|
117 |
+
# calculate attention matrix
|
118 |
+
depth = query.shape[-1]
|
119 |
+
# query = query
|
120 |
+
|
121 |
+
# attn weight shape is (batch..., num_heads, q_length, kv_length)
|
122 |
+
attn_weights = jnp.einsum("...qhd,...khd->...hqk", query + r_w_bias, key, precision=precision)
|
123 |
+
|
124 |
+
attn_weights_r = jnp.einsum("...qhd,khd->...hqk", query + r_r_bias, r_pos_embed, precision=precision)
|
125 |
+
|
126 |
+
attn_weights_r = roll_vmap(attn_weights_r, jnp.arange(0, query.shape[-3]) - (query.shape[-3] - 1), -1)
|
127 |
+
# attn_weights_r=_rel_shift(attn_weights_r)
|
128 |
+
attn_weights = attn_weights + attn_weights_r
|
129 |
+
|
130 |
+
attn_weights = attn_weights / jnp.sqrt(depth).astype(dtype)
|
131 |
+
|
132 |
+
# apply attention bias: masking, dropout, proximity bias, etc.
|
133 |
+
if bias is not None:
|
134 |
+
attn_weights = attn_weights + bias
|
135 |
+
# apply attention mask
|
136 |
+
if mask is not None:
|
137 |
+
big_neg = jnp.finfo(dtype).min
|
138 |
+
attn_weights = jnp.where(mask, attn_weights, big_neg)
|
139 |
+
|
140 |
+
# normalize the attention weights
|
141 |
+
attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
|
142 |
+
|
143 |
+
# apply attention dropout
|
144 |
+
if not deterministic and dropout_rate > 0.0:
|
145 |
+
keep_prob = 1.0 - dropout_rate
|
146 |
+
if broadcast_dropout:
|
147 |
+
# dropout is broadcast across the batch + head dimensions
|
148 |
+
dropout_shape = tuple([1] * (key.ndim - 2)) + attn_weights.shape[-2:]
|
149 |
+
keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape) # type: ignore
|
150 |
+
else:
|
151 |
+
keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape) # type: ignore
|
152 |
+
multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
|
153 |
+
attn_weights = attn_weights * multiplier
|
154 |
+
|
155 |
+
return attn_weights
|
156 |
+
|
157 |
+
|
158 |
+
def dot_product_attention(
|
159 |
+
query: Array,
|
160 |
+
key: Array,
|
161 |
+
value: Array,
|
162 |
+
r_pos_embed,
|
163 |
+
r_r_bias,
|
164 |
+
r_w_bias,
|
165 |
+
bias: Optional[Array] = None,
|
166 |
+
mask: Optional[Array] = None,
|
167 |
+
broadcast_dropout: bool = True,
|
168 |
+
dropout_rng: Optional[PRNGKey] = None,
|
169 |
+
dropout_rate: float = 0.0,
|
170 |
+
deterministic: bool = False,
|
171 |
+
dtype: Optional[Dtype] = None,
|
172 |
+
precision: PrecisionLike = None,
|
173 |
+
):
|
174 |
+
"""Computes dot-product attention given query, key, and value.
|
175 |
+
|
176 |
+
This is the core function for applying attention based on
|
177 |
+
https://arxiv.org/abs/1706.03762. It calculates the attention weights given
|
178 |
+
query and key and combines the values using the attention weights.
|
179 |
+
|
180 |
+
Note: query, key, value needn't have any batch dimensions.
|
181 |
+
|
182 |
+
Args:
|
183 |
+
query: queries for calculating attention with shape of
|
184 |
+
`[batch..., q_length, num_heads, qk_depth_per_head]`.
|
185 |
+
key: keys for calculating attention with shape of
|
186 |
+
`[batch..., kv_length, num_heads, qk_depth_per_head]`.
|
187 |
+
value: values to be used in attention with shape of
|
188 |
+
`[batch..., kv_length, num_heads, v_depth_per_head]`.
|
189 |
+
bias: bias for the attention weights. This should be broadcastable to the
|
190 |
+
shape `[batch..., num_heads, q_length, kv_length]`.
|
191 |
+
This can be used for incorporating causal masks, padding masks,
|
192 |
+
proximity bias, etc.
|
193 |
+
mask: mask for the attention weights. This should be broadcastable to the
|
194 |
+
shape `[batch..., num_heads, q_length, kv_length]`.
|
195 |
+
This can be used for incorporating causal masks.
|
196 |
+
Attention weights are masked out if their corresponding mask value
|
197 |
+
is `False`.
|
198 |
+
broadcast_dropout: bool: use a broadcasted dropout along batch dims.
|
199 |
+
dropout_rng: JAX PRNGKey: to be used for dropout
|
200 |
+
dropout_rate: dropout rate
|
201 |
+
deterministic: bool, deterministic or not (to apply dropout)
|
202 |
+
dtype: the dtype of the computation (default: infer from inputs)
|
203 |
+
precision: numerical precision of the computation see `jax.lax.Precision`
|
204 |
+
for details.
|
205 |
+
|
206 |
+
Returns:
|
207 |
+
Output of shape `[batch..., q_length, num_heads, v_depth_per_head]`.
|
208 |
+
"""
|
209 |
+
query, key, value = promote_dtype(query, key, value, dtype=dtype)
|
210 |
+
dtype = query.dtype
|
211 |
+
assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank."
|
212 |
+
assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], "q, k, v batch dims must match."
|
213 |
+
assert query.shape[-2] == key.shape[-2] == value.shape[-2], "q, k, v num_heads must match."
|
214 |
+
assert key.shape[-3] == value.shape[-3], "k, v lengths must match."
|
215 |
+
|
216 |
+
# compute attention weights
|
217 |
+
attn_weights = dot_product_attention_weights(
|
218 |
+
query,
|
219 |
+
key,
|
220 |
+
r_pos_embed,
|
221 |
+
r_r_bias,
|
222 |
+
r_w_bias,
|
223 |
+
bias,
|
224 |
+
mask,
|
225 |
+
broadcast_dropout,
|
226 |
+
dropout_rng,
|
227 |
+
dropout_rate,
|
228 |
+
deterministic,
|
229 |
+
dtype,
|
230 |
+
precision,
|
231 |
+
)
|
232 |
+
|
233 |
+
# return weighted sum over values for each query position
|
234 |
+
return jnp.einsum("...hqk,...khd->...qhd", attn_weights, value, precision=precision)
|
235 |
+
|
236 |
+
|
237 |
+
class RelMultiHeadDotProductAttention(Module):
|
238 |
+
"""Multi-head dot-product attention.
|
239 |
+
|
240 |
+
Attributes:
|
241 |
+
num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
|
242 |
+
should be divisible by the number of heads.
|
243 |
+
dtype: the dtype of the computation (default: infer from inputs and params)
|
244 |
+
param_dtype: the dtype passed to parameter initializers (default: float32)
|
245 |
+
qkv_features: dimension of the key, query, and value.
|
246 |
+
out_features: dimension of the last projection
|
247 |
+
broadcast_dropout: bool: use a broadcasted dropout along batch dims.
|
248 |
+
dropout_rate: dropout rate
|
249 |
+
deterministic: if false, the attention weight is masked randomly using
|
250 |
+
dropout, whereas if true, the attention weights are deterministic.
|
251 |
+
precision: numerical precision of the computation see `jax.lax.Precision`
|
252 |
+
for details.
|
253 |
+
kernel_init: initializer for the kernel of the Dense layers.
|
254 |
+
bias_init: initializer for the bias of the Dense layers.
|
255 |
+
use_bias: bool: whether pointwise QKVO dense transforms use bias.
|
256 |
+
attention_fn: dot_product_attention or compatible function. Accepts query,
|
257 |
+
key, value, and returns output of shape `[bs, dim1, dim2, ..., dimN,,
|
258 |
+
num_heads, value_channels]``
|
259 |
+
decode: whether to prepare and use an autoregressive cache.
|
260 |
+
"""
|
261 |
+
|
262 |
+
num_heads: int
|
263 |
+
dtype: Optional[Dtype] = None
|
264 |
+
param_dtype: Dtype = jnp.float32
|
265 |
+
qkv_features: Optional[int] = None
|
266 |
+
out_features: Optional[int] = None
|
267 |
+
broadcast_dropout: bool = True
|
268 |
+
dropout_rate: float = 0.0
|
269 |
+
deterministic: Optional[bool] = None
|
270 |
+
precision: PrecisionLike = None
|
271 |
+
kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
|
272 |
+
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros_init()
|
273 |
+
use_bias: bool = True
|
274 |
+
attention_fn: Callable[..., Array] = dot_product_attention
|
275 |
+
decode: bool = False
|
276 |
+
qkv_dot_general: DotGeneralT = lax.dot_general
|
277 |
+
out_dot_general: DotGeneralT = lax.dot_general
|
278 |
+
|
279 |
+
@compact
|
280 |
+
def __call__(
|
281 |
+
self,
|
282 |
+
inputs_q: Array,
|
283 |
+
inputs_kv: Array,
|
284 |
+
pos_embed: Array,
|
285 |
+
mask: Optional[Array] = None,
|
286 |
+
deterministic: Optional[bool] = None,
|
287 |
+
):
|
288 |
+
"""Applies multi-head dot product attention on the input data.
|
289 |
+
|
290 |
+
Projects the inputs into multi-headed query, key, and value vectors,
|
291 |
+
applies dot-product attention and project the results to an output vector.
|
292 |
+
|
293 |
+
Args:
|
294 |
+
inputs_q: input queries of shape
|
295 |
+
`[batch_sizes..., length, features]`.
|
296 |
+
inputs_kv: key/values of shape
|
297 |
+
`[batch_sizes..., length, features]`.
|
298 |
+
mask: attention mask of shape
|
299 |
+
`[batch_sizes..., num_heads, query_length, key/value_length]`.
|
300 |
+
Attention weights are masked out if their corresponding mask value
|
301 |
+
is `False`.
|
302 |
+
deterministic: if false, the attention weight is masked randomly
|
303 |
+
using dropout, whereas if true, the attention weights
|
304 |
+
are deterministic.
|
305 |
+
|
306 |
+
Returns:
|
307 |
+
output of shape `[batch_sizes..., length, features]`.
|
308 |
+
"""
|
309 |
+
features = self.out_features or inputs_q.shape[-1]
|
310 |
+
qkv_features = self.qkv_features or inputs_q.shape[-1]
|
311 |
+
assert qkv_features % self.num_heads == 0, (
|
312 |
+
f"Memory dimension ({qkv_features}) must be divisible by number of" f" heads ({self.num_heads})."
|
313 |
+
)
|
314 |
+
head_dim = qkv_features // self.num_heads
|
315 |
+
|
316 |
+
dense = functools.partial(
|
317 |
+
DenseGeneral,
|
318 |
+
axis=-1,
|
319 |
+
dtype=self.dtype,
|
320 |
+
param_dtype=self.param_dtype,
|
321 |
+
features=(self.num_heads, head_dim),
|
322 |
+
kernel_init=self.kernel_init,
|
323 |
+
bias_init=self.bias_init,
|
324 |
+
use_bias=self.use_bias,
|
325 |
+
precision=self.precision,
|
326 |
+
dot_general=self.qkv_dot_general,
|
327 |
+
)
|
328 |
+
# project inputs_q to multi-headed q/k/v
|
329 |
+
# dimensions are then [batch..., length, n_heads, n_features_per_head]
|
330 |
+
query, key, value = (
|
331 |
+
dense(name="query")(inputs_q),
|
332 |
+
dense(name="key")(inputs_kv),
|
333 |
+
dense(name="value")(inputs_kv),
|
334 |
+
)
|
335 |
+
|
336 |
+
# different bc no bias
|
337 |
+
dense_relpos = functools.partial(
|
338 |
+
DenseGeneral,
|
339 |
+
axis=-1,
|
340 |
+
dtype=self.dtype,
|
341 |
+
param_dtype=self.param_dtype,
|
342 |
+
features=(self.num_heads, head_dim),
|
343 |
+
kernel_init=self.kernel_init,
|
344 |
+
use_bias=False,
|
345 |
+
precision=self.precision,
|
346 |
+
dot_general=self.qkv_dot_general,
|
347 |
+
)
|
348 |
+
|
349 |
+
r_pos_embed = dense_relpos(name="pos_embed_mat")(pos_embed)
|
350 |
+
|
351 |
+
r_r_bias = self.param("r_r_bias", self.bias_init, (self.num_heads, head_dim)) # Initialization function
|
352 |
+
r_w_bias = self.param("r_w_bias", self.bias_init, (self.num_heads, head_dim)) # Initialization function
|
353 |
+
|
354 |
+
# During fast autoregressive decoding, we feed one position at a time,
|
355 |
+
# and cache the keys and values step by step.
|
356 |
+
if self.decode:
|
357 |
+
# detect if we're initializing by absence of existing cache data.
|
358 |
+
is_initialized = self.has_variable("cache", "cached_key")
|
359 |
+
cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
|
360 |
+
cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
|
361 |
+
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
|
362 |
+
if is_initialized:
|
363 |
+
(
|
364 |
+
*batch_dims,
|
365 |
+
max_length,
|
366 |
+
num_heads,
|
367 |
+
depth_per_head,
|
368 |
+
) = cached_key.value.shape
|
369 |
+
# shape check of cached keys against query input
|
370 |
+
expected_shape = tuple(batch_dims) + (1, num_heads, depth_per_head)
|
371 |
+
if expected_shape != query.shape:
|
372 |
+
raise ValueError(
|
373 |
+
"Autoregressive cache shape error, "
|
374 |
+
"expected query shape %s instead got %s." % (expected_shape, query.shape)
|
375 |
+
)
|
376 |
+
# update key, value caches with our new 1d spatial slices
|
377 |
+
cur_index = cache_index.value
|
378 |
+
indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
|
379 |
+
key = lax.dynamic_update_slice(cached_key.value, key, indices)
|
380 |
+
value = lax.dynamic_update_slice(cached_value.value, value, indices)
|
381 |
+
cached_key.value = key
|
382 |
+
cached_value.value = value
|
383 |
+
cache_index.value = cache_index.value + 1
|
384 |
+
# causal mask for cached decoder self-attention:
|
385 |
+
# our single query position should only attend to those key
|
386 |
+
# positions that have already been generated and cached,
|
387 |
+
# not the remaining zero elements.
|
388 |
+
mask = combine_masks(
|
389 |
+
mask,
|
390 |
+
jnp.broadcast_to(
|
391 |
+
jnp.arange(max_length) <= cur_index,
|
392 |
+
tuple(batch_dims) + (1, 1, max_length),
|
393 |
+
),
|
394 |
+
)
|
395 |
+
|
396 |
+
dropout_rng = None
|
397 |
+
if self.dropout_rate > 0.0: # Require `deterministic` only if using dropout.
|
398 |
+
m_deterministic = merge_param("deterministic", self.deterministic, deterministic)
|
399 |
+
if not m_deterministic:
|
400 |
+
dropout_rng = self.make_rng("dropout")
|
401 |
+
else:
|
402 |
+
m_deterministic = True
|
403 |
+
|
404 |
+
# apply attention
|
405 |
+
x = self.attention_fn(
|
406 |
+
query,
|
407 |
+
key,
|
408 |
+
value,
|
409 |
+
r_pos_embed,
|
410 |
+
r_r_bias,
|
411 |
+
r_w_bias,
|
412 |
+
mask=mask,
|
413 |
+
dropout_rng=dropout_rng,
|
414 |
+
dropout_rate=self.dropout_rate,
|
415 |
+
broadcast_dropout=self.broadcast_dropout,
|
416 |
+
deterministic=m_deterministic,
|
417 |
+
dtype=self.dtype,
|
418 |
+
precision=self.precision,
|
419 |
+
) # pytype: disable=wrong-keyword-args
|
420 |
+
# back to the original inputs dimensions
|
421 |
+
out = DenseGeneral(
|
422 |
+
features=features,
|
423 |
+
axis=(-2, -1),
|
424 |
+
kernel_init=self.kernel_init,
|
425 |
+
bias_init=self.bias_init,
|
426 |
+
use_bias=self.use_bias,
|
427 |
+
dtype=self.dtype,
|
428 |
+
param_dtype=self.param_dtype,
|
429 |
+
precision=self.precision,
|
430 |
+
dot_general=self.out_dot_general,
|
431 |
+
name="out", # type: ignore[call-arg]
|
432 |
+
)(x)
|
433 |
+
return out
|
434 |
+
|
435 |
+
|
436 |
+
class SelfAttention(RelMultiHeadDotProductAttention):
|
437 |
+
"""Self-attention special case of multi-head dot-product attention."""
|
438 |
+
|
439 |
+
@compact
|
440 |
+
def __call__( # type: ignore
|
441 |
+
self,
|
442 |
+
inputs_q: Array,
|
443 |
+
mask: Optional[Array] = None,
|
444 |
+
deterministic: Optional[bool] = None,
|
445 |
+
):
|
446 |
+
"""Applies multi-head dot product self-attention on the input data.
|
447 |
+
|
448 |
+
Projects the inputs into multi-headed query, key, and value vectors,
|
449 |
+
applies dot-product attention and project the results to an output vector.
|
450 |
+
|
451 |
+
Args:
|
452 |
+
inputs_q: input queries of shape
|
453 |
+
`[batch_sizes..., length, features]`.
|
454 |
+
mask: attention mask of shape
|
455 |
+
`[batch_sizes..., num_heads, query_length, key/value_length]`.
|
456 |
+
Attention weights are masked out if their corresponding mask value
|
457 |
+
is `False`.
|
458 |
+
deterministic: if false, the attention weight is masked randomly
|
459 |
+
using dropout, whereas if true, the attention weights
|
460 |
+
are deterministic.
|
461 |
+
|
462 |
+
Returns:
|
463 |
+
output of shape `[batch_sizes..., length, features]`.
|
464 |
+
"""
|
465 |
+
return super().__call__(inputs_q, inputs_q, mask, deterministic=deterministic)
|
466 |
+
|
467 |
+
|
468 |
+
# mask-making utility functions
|
469 |
+
|
470 |
+
|
471 |
+
def make_attention_mask(
|
472 |
+
query_input: Array,
|
473 |
+
key_input: Array,
|
474 |
+
pairwise_fn: Callable[..., Any] = jnp.multiply,
|
475 |
+
extra_batch_dims: int = 0,
|
476 |
+
dtype: Dtype = jnp.float32,
|
477 |
+
):
|
478 |
+
"""Mask-making helper for attention weights.
|
479 |
+
|
480 |
+
In case of 1d inputs (i.e., `[batch..., len_q]`, `[batch..., len_kv]`, the
|
481 |
+
attention weights will be `[batch..., heads, len_q, len_kv]` and this
|
482 |
+
function will produce `[batch..., 1, len_q, len_kv]`.
|
483 |
+
|
484 |
+
Args:
|
485 |
+
query_input: a batched, flat input of query_length size
|
486 |
+
key_input: a batched, flat input of key_length size
|
487 |
+
pairwise_fn: broadcasting elementwise comparison function
|
488 |
+
extra_batch_dims: number of extra batch dims to add singleton
|
489 |
+
axes for, none by default
|
490 |
+
dtype: mask return dtype
|
491 |
+
|
492 |
+
Returns:
|
493 |
+
A `[batch..., 1, len_q, len_kv]` shaped mask for 1d attention.
|
494 |
+
"""
|
495 |
+
mask = pairwise_fn(jnp.expand_dims(query_input, axis=-1), jnp.expand_dims(key_input, axis=-2))
|
496 |
+
mask = jnp.expand_dims(mask, axis=-3)
|
497 |
+
mask = jnp.expand_dims(mask, axis=tuple(range(extra_batch_dims)))
|
498 |
+
return mask.astype(dtype)
|
499 |
+
|
500 |
+
|
501 |
+
def make_causal_mask(x: Array, extra_batch_dims: int = 0, dtype: Dtype = jnp.float32) -> Array:
|
502 |
+
"""Make a causal mask for self-attention.
|
503 |
+
|
504 |
+
In case of 1d inputs (i.e., `[batch..., len]`, the self-attention weights
|
505 |
+
will be `[batch..., heads, len, len]` and this function will produce a
|
506 |
+
causal mask of shape `[batch..., 1, len, len]`.
|
507 |
+
|
508 |
+
Args:
|
509 |
+
x: input array of shape `[batch..., len]`
|
510 |
+
extra_batch_dims: number of batch dims to add singleton axes for,
|
511 |
+
none by default
|
512 |
+
dtype: mask return dtype
|
513 |
+
|
514 |
+
Returns:
|
515 |
+
A `[batch..., 1, len, len]` shaped causal mask for 1d attention.
|
516 |
+
"""
|
517 |
+
idxs = jnp.broadcast_to(jnp.arange(x.shape[-1], dtype=jnp.int32), x.shape)
|
518 |
+
return make_attention_mask(
|
519 |
+
idxs,
|
520 |
+
idxs,
|
521 |
+
jnp.greater_equal,
|
522 |
+
extra_batch_dims=extra_batch_dims,
|
523 |
+
dtype=dtype,
|
524 |
+
)
|
525 |
+
|
526 |
+
|
527 |
+
def combine_masks(*masks: Optional[Array], dtype: Dtype = jnp.float32) -> Array:
|
528 |
+
"""Combine attention masks.
|
529 |
+
|
530 |
+
Args:
|
531 |
+
*masks: set of attention mask arguments to combine, some can be None.
|
532 |
+
dtype: dtype for the returned mask.
|
533 |
+
|
534 |
+
Returns:
|
535 |
+
Combined mask, reduced by logical and, returns None if no masks given.
|
536 |
+
"""
|
537 |
+
masks_list = [m for m in masks if m is not None]
|
538 |
+
if not masks_list:
|
539 |
+
return None
|
540 |
+
assert all(
|
541 |
+
map(lambda x: x.ndim == masks_list[0].ndim, masks_list)
|
542 |
+
), f"masks must have same rank: {tuple(map(lambda x: x.ndim, masks_list))}"
|
543 |
+
mask, *other_masks = masks_list
|
544 |
+
for other_mask in other_masks:
|
545 |
+
mask = jnp.logical_and(mask, other_mask)
|
546 |
+
return mask.astype(dtype)
|
kinetix/models/transformer_model.py
ADDED
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
import jax.numpy as jnp
|
3 |
+
import flax.linen as nn
|
4 |
+
import numpy as np
|
5 |
+
from flax.linen.initializers import constant, orthogonal
|
6 |
+
from typing import List, Sequence
|
7 |
+
|
8 |
+
import distrax
|
9 |
+
import jax
|
10 |
+
|
11 |
+
from kinetix.models.actor_critic import GeneralActorCriticRNN, ScannedRNN
|
12 |
+
|
13 |
+
|
14 |
+
from kinetix.render.renderer_symbolic_entity import EntityObservation
|
15 |
+
|
16 |
+
from flax.linen.attention import MultiHeadDotProductAttention
|
17 |
+
|
18 |
+
|
19 |
+
class Gating(nn.Module):
|
20 |
+
# code taken from https://github.com/dhruvramani/Transformers-RL/blob/master/layers.py
|
21 |
+
d_input: int
|
22 |
+
bg: float = 0.0
|
23 |
+
|
24 |
+
@nn.compact
|
25 |
+
def __call__(self, x, y):
|
26 |
+
r = jax.nn.sigmoid(nn.Dense(self.d_input, use_bias=False)(y) + nn.Dense(self.d_input, use_bias=False)(x))
|
27 |
+
z = jax.nn.sigmoid(
|
28 |
+
nn.Dense(self.d_input, use_bias=False)(y)
|
29 |
+
+ nn.Dense(self.d_input, use_bias=False)(x)
|
30 |
+
- self.param("gating_bias", constant(self.bg), (self.d_input,))
|
31 |
+
)
|
32 |
+
h = jnp.tanh(nn.Dense(self.d_input, use_bias=False)(y) + nn.Dense(self.d_input, use_bias=False)(r * x))
|
33 |
+
g = (1 - z) * x + (z * h)
|
34 |
+
return g
|
35 |
+
|
36 |
+
|
37 |
+
class transformer_layer(nn.Module):
|
38 |
+
num_heads: int
|
39 |
+
out_features: int
|
40 |
+
qkv_features: int
|
41 |
+
gating: bool = False
|
42 |
+
gating_bias: float = 0.0
|
43 |
+
|
44 |
+
def setup(self):
|
45 |
+
self.attention1 = MultiHeadDotProductAttention(
|
46 |
+
num_heads=self.num_heads, qkv_features=self.qkv_features, out_features=self.out_features
|
47 |
+
)
|
48 |
+
|
49 |
+
self.ln1 = nn.LayerNorm()
|
50 |
+
|
51 |
+
self.dense1 = nn.Dense(self.out_features)
|
52 |
+
|
53 |
+
self.dense2 = nn.Dense(self.out_features)
|
54 |
+
|
55 |
+
self.ln2 = nn.LayerNorm()
|
56 |
+
if self.gating:
|
57 |
+
self.gate1 = Gating(self.out_features, self.gating_bias)
|
58 |
+
self.gate2 = Gating(self.out_features, self.gating_bias)
|
59 |
+
|
60 |
+
def __call__(self, queries: jnp.ndarray, mask: jnp.ndarray):
|
61 |
+
# After reading the paper, this is what I think we should do:
|
62 |
+
# First layernorm, then do attention
|
63 |
+
queries_n = self.ln1(queries)
|
64 |
+
y = self.attention1(queries_n, mask=mask)
|
65 |
+
if self.gating: # and gate
|
66 |
+
y = self.gate1(queries, jax.nn.relu(y))
|
67 |
+
else:
|
68 |
+
y = queries + y
|
69 |
+
# Dense after norming, crucially no relu.
|
70 |
+
e = self.dense1(self.ln2(y))
|
71 |
+
if self.gating: # and gate again
|
72 |
+
# This may be the wrong way around
|
73 |
+
e = self.gate2(y, jax.nn.relu(e))
|
74 |
+
else:
|
75 |
+
e = y + e
|
76 |
+
|
77 |
+
return e
|
78 |
+
|
79 |
+
|
80 |
+
class Transformer(nn.Module):
|
81 |
+
encoder_size: int
|
82 |
+
num_heads: int
|
83 |
+
qkv_features: int
|
84 |
+
num_layers: int
|
85 |
+
gating: bool = False
|
86 |
+
gating_bias: float = 0.0
|
87 |
+
|
88 |
+
def setup(self):
|
89 |
+
# self.encoder = nn.Dense(self.encoder_size)
|
90 |
+
|
91 |
+
# self.positional_encoding = PositionalEncoding(self.encoder_size, max_len=self.max_len)
|
92 |
+
|
93 |
+
self.tf_layers = [
|
94 |
+
transformer_layer(
|
95 |
+
num_heads=self.num_heads,
|
96 |
+
qkv_features=self.qkv_features,
|
97 |
+
out_features=self.encoder_size,
|
98 |
+
gating=self.gating,
|
99 |
+
gating_bias=self.gating_bias,
|
100 |
+
)
|
101 |
+
for _ in range(self.num_layers)
|
102 |
+
]
|
103 |
+
|
104 |
+
self.joint_layers = [nn.Dense(self.encoder_size) for _ in range(self.num_layers)]
|
105 |
+
self.thruster_layers = [nn.Dense(self.encoder_size) for _ in range(self.num_layers)]
|
106 |
+
|
107 |
+
# self.pos_emb=PositionalEmbedding(self.encoder_size)
|
108 |
+
|
109 |
+
def __call__(
|
110 |
+
self,
|
111 |
+
shape_embeddings: jnp.ndarray,
|
112 |
+
shape_attention_mask,
|
113 |
+
joint_embeddings,
|
114 |
+
joint_mask,
|
115 |
+
joint_indexes,
|
116 |
+
thruster_embeddings,
|
117 |
+
thruster_mask,
|
118 |
+
thruster_indexes,
|
119 |
+
):
|
120 |
+
# forward eval so obs is only one timestep
|
121 |
+
# encoded = self.encoder(shape_embeddings)
|
122 |
+
# pos_embed=self.pos_emb(jnp.arange(1+memories.shape[-3],-1,-1))[:1+memories.shape[-3]]
|
123 |
+
|
124 |
+
for tf_layer, joint_layer, thruster_layer in zip(self.tf_layers, self.joint_layers, self.thruster_layers):
|
125 |
+
# Do attention
|
126 |
+
shape_embeddings = tf_layer(shape_embeddings, shape_attention_mask)
|
127 |
+
|
128 |
+
# Joints
|
129 |
+
# T, B, 2J, (2SE + JE)
|
130 |
+
|
131 |
+
@jax.vmap
|
132 |
+
@jax.vmap
|
133 |
+
def do_index2(to_ind, ind):
|
134 |
+
return to_ind[ind]
|
135 |
+
|
136 |
+
joint_shape_embeddings = jnp.concatenate(
|
137 |
+
[
|
138 |
+
do_index2(shape_embeddings, joint_indexes[..., 0]),
|
139 |
+
do_index2(shape_embeddings, joint_indexes[..., 1]),
|
140 |
+
joint_embeddings,
|
141 |
+
],
|
142 |
+
axis=-1,
|
143 |
+
)
|
144 |
+
|
145 |
+
shape_joint_entity_delta = joint_layer(joint_shape_embeddings) * joint_mask[..., None]
|
146 |
+
|
147 |
+
@jax.vmap
|
148 |
+
@jax.vmap
|
149 |
+
def add2(addee, index, adder):
|
150 |
+
return addee.at[index].add(adder)
|
151 |
+
|
152 |
+
# Thrusters
|
153 |
+
thruster_shape_embeddings = jnp.concatenate(
|
154 |
+
[
|
155 |
+
do_index2(shape_embeddings, thruster_indexes),
|
156 |
+
thruster_embeddings,
|
157 |
+
],
|
158 |
+
axis=-1,
|
159 |
+
)
|
160 |
+
|
161 |
+
shape_thruster_entity_delta = thruster_layer(thruster_shape_embeddings) * thruster_mask[..., None]
|
162 |
+
|
163 |
+
shape_embeddings = add2(shape_embeddings, joint_indexes[..., 0], shape_joint_entity_delta)
|
164 |
+
shape_embeddings = add2(shape_embeddings, thruster_indexes, shape_thruster_entity_delta)
|
165 |
+
|
166 |
+
return shape_embeddings
|
167 |
+
|
168 |
+
|
169 |
+
class ActorCriticTransformer(nn.Module):
|
170 |
+
action_dim: Sequence[int]
|
171 |
+
fc_layer_width: int
|
172 |
+
action_mode: str
|
173 |
+
hybrid_action_continuous_dim: int
|
174 |
+
multi_discrete_number_of_dims_per_distribution: List[int]
|
175 |
+
transformer_size: int
|
176 |
+
transformer_encoder_size: int
|
177 |
+
transformer_depth: int
|
178 |
+
fc_layer_depth: int
|
179 |
+
num_heads: int
|
180 |
+
activation: str
|
181 |
+
aggregate_mode: str # "dummy" or "mean" or "dummy_and_mean"
|
182 |
+
full_attention_mask: bool # if true, only mask out inactives, and have everything attend to everything else
|
183 |
+
add_generator_embedding: bool = False
|
184 |
+
generator_embedding_number_of_timesteps: int = 10
|
185 |
+
recurrent: bool = True
|
186 |
+
|
187 |
+
@nn.compact
|
188 |
+
def __call__(self, hidden, x):
|
189 |
+
if self.activation == "relu":
|
190 |
+
activation = nn.relu
|
191 |
+
else:
|
192 |
+
activation = nn.tanh
|
193 |
+
|
194 |
+
og_obs, dones = x
|
195 |
+
if self.add_generator_embedding:
|
196 |
+
obs = og_obs.obs
|
197 |
+
else:
|
198 |
+
obs = og_obs
|
199 |
+
|
200 |
+
# obs._ is [T, B, N, L]
|
201 |
+
# B - batch size
|
202 |
+
# T - time
|
203 |
+
# N - number of things
|
204 |
+
# L - unembedded entity size
|
205 |
+
obs: EntityObservation
|
206 |
+
|
207 |
+
def _single_encoder(features, entity_id, concat=True):
|
208 |
+
# assume two entity types
|
209 |
+
num_to_remove = 1 if concat else 0
|
210 |
+
embedding = activation(
|
211 |
+
nn.Dense(
|
212 |
+
self.transformer_encoder_size - num_to_remove,
|
213 |
+
kernel_init=orthogonal(np.sqrt(2)),
|
214 |
+
bias_init=constant(0.0),
|
215 |
+
)(features)
|
216 |
+
)
|
217 |
+
if concat:
|
218 |
+
id_1h = jnp.zeros((*embedding.shape[:3], 1)).at[:, :, :, entity_id].set(entity_id)
|
219 |
+
return jnp.concatenate([embedding, id_1h], axis=-1)
|
220 |
+
else:
|
221 |
+
return embedding
|
222 |
+
|
223 |
+
circle_encodings = _single_encoder(obs.circles, 0)
|
224 |
+
polygon_encodings = _single_encoder(obs.polygons, 1)
|
225 |
+
joint_encodings = _single_encoder(obs.joints, -1, False)
|
226 |
+
thruster_encodings = _single_encoder(obs.thrusters, -1, False)
|
227 |
+
# Size of this is something like (T, B, N, K) (time, batch, num_entities, embedding_size)
|
228 |
+
|
229 |
+
# T, B, M, K
|
230 |
+
shape_encodings = jnp.concatenate([polygon_encodings, circle_encodings], axis=2)
|
231 |
+
# T, B, M
|
232 |
+
shape_mask = jnp.concatenate([obs.polygon_mask, obs.circle_mask], axis=2)
|
233 |
+
|
234 |
+
def mask_out_inactives(flat_active_mask, matrix_attention_mask):
|
235 |
+
matrix_attention_mask = matrix_attention_mask & (flat_active_mask[:, None]) & (flat_active_mask[None, :])
|
236 |
+
return matrix_attention_mask
|
237 |
+
|
238 |
+
joint_indexes = obs.joint_indexes
|
239 |
+
thruster_indexes = obs.thruster_indexes
|
240 |
+
|
241 |
+
if self.aggregate_mode == "dummy" or self.aggregate_mode == "dummy_and_mean":
|
242 |
+
T, B, _, K = circle_encodings.shape
|
243 |
+
dummy = jnp.ones((T, B, 1, K))
|
244 |
+
shape_encodings = jnp.concatenate([dummy, shape_encodings], axis=2)
|
245 |
+
shape_mask = jnp.concatenate(
|
246 |
+
[jnp.ones((T, B, 1), dtype=bool), shape_mask],
|
247 |
+
axis=2,
|
248 |
+
)
|
249 |
+
N = obs.attention_mask.shape[-1]
|
250 |
+
overall_mask = (
|
251 |
+
jnp.ones((T, B, obs.attention_mask.shape[2], N + 1, N + 1), dtype=bool)
|
252 |
+
.at[:, :, :, 1:, 1:]
|
253 |
+
.set(obs.attention_mask)
|
254 |
+
)
|
255 |
+
overall_mask = jax.vmap(jax.vmap(mask_out_inactives))(shape_mask, overall_mask)
|
256 |
+
|
257 |
+
# To account for the dummy entity
|
258 |
+
joint_indexes = joint_indexes + 1
|
259 |
+
thruster_indexes = thruster_indexes + 1
|
260 |
+
|
261 |
+
else:
|
262 |
+
overall_mask = obs.attention_mask
|
263 |
+
|
264 |
+
if self.full_attention_mask:
|
265 |
+
overall_mask = jnp.ones(overall_mask.shape, dtype=bool)
|
266 |
+
overall_mask = jax.vmap(jax.vmap(mask_out_inactives))(shape_mask, overall_mask)
|
267 |
+
|
268 |
+
# Now do attention on these
|
269 |
+
embedding = Transformer(
|
270 |
+
num_layers=self.transformer_depth,
|
271 |
+
num_heads=self.num_heads,
|
272 |
+
qkv_features=self.transformer_size,
|
273 |
+
encoder_size=self.transformer_encoder_size,
|
274 |
+
gating=True,
|
275 |
+
gating_bias=0.0,
|
276 |
+
)(
|
277 |
+
shape_encodings,
|
278 |
+
jnp.repeat(overall_mask, repeats=self.num_heads // overall_mask.shape[2], axis=2),
|
279 |
+
joint_encodings,
|
280 |
+
obs.joint_mask,
|
281 |
+
joint_indexes,
|
282 |
+
thruster_encodings,
|
283 |
+
obs.thruster_mask,
|
284 |
+
thruster_indexes,
|
285 |
+
) # add the extra dimension for the heads
|
286 |
+
|
287 |
+
if self.aggregate_mode == "mean" or self.aggregate_mode == "dummy_and_mean":
|
288 |
+
embedding = jnp.mean(embedding, axis=2, where=shape_mask[..., None])
|
289 |
+
else:
|
290 |
+
embedding = embedding[:, :, 0] # Take the dummy entity as the embedding of the entire scene.
|
291 |
+
|
292 |
+
return GeneralActorCriticRNN(
|
293 |
+
action_dim=self.action_dim,
|
294 |
+
fc_layer_depth=self.fc_layer_depth,
|
295 |
+
fc_layer_width=self.fc_layer_width,
|
296 |
+
action_mode=self.action_mode,
|
297 |
+
hybrid_action_continuous_dim=self.hybrid_action_continuous_dim,
|
298 |
+
multi_discrete_number_of_dims_per_distribution=self.multi_discrete_number_of_dims_per_distribution,
|
299 |
+
add_generator_embedding=self.add_generator_embedding,
|
300 |
+
generator_embedding_number_of_timesteps=self.generator_embedding_number_of_timesteps,
|
301 |
+
recurrent=self.recurrent,
|
302 |
+
)(hidden, og_obs, embedding, dones, activation)
|
kinetix/pcg/__init__.py
ADDED
File without changes
|
kinetix/pcg/pcg.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
|
3 |
+
from jax2d.engine import recalculate_mass_and_inertia, recompute_global_joint_positions, select_shape
|
4 |
+
from kinetix.environment.env_state import EnvState, StaticEnvParams
|
5 |
+
from kinetix.pcg.pcg_state import PCGState
|
6 |
+
import jax
|
7 |
+
import jax.numpy as jnp
|
8 |
+
|
9 |
+
|
10 |
+
def _process_tied_together_shapes(pcg_state: PCGState, sampled_state: EnvState, static_params: StaticEnvParams):
|
11 |
+
|
12 |
+
# Get the matrix of tied together positions. Since we vmap, we only want one entry active for any (i, j, k). Thus, we mask out some of the duplicate ones.
|
13 |
+
tied = jnp.triu(pcg_state.tied_together & jnp.logical_not(jnp.eye(pcg_state.tied_together.shape[0], dtype=bool)))
|
14 |
+
has_anything_in_column = tied.any(axis=0)
|
15 |
+
tied = (
|
16 |
+
tied * jnp.logical_not(has_anything_in_column)[:, None]
|
17 |
+
) # if there is something in a column, it means a previous one with a lower index has already been processed
|
18 |
+
|
19 |
+
should_use_delta_positions = tied.any(axis=0)
|
20 |
+
|
21 |
+
# This is the delta we have moved after sampling
|
22 |
+
delta_positions = jnp.concatenate(
|
23 |
+
[
|
24 |
+
sampled_state.polygon.position - pcg_state.env_state.polygon.position,
|
25 |
+
sampled_state.circle.position - pcg_state.env_state.circle.position,
|
26 |
+
]
|
27 |
+
)
|
28 |
+
|
29 |
+
def _get_effect_of_shape_i_on_all_others(item_index, item_row_of_what_is_tied):
|
30 |
+
delta_pos = delta_positions[item_index]
|
31 |
+
return jnp.arange(len(item_row_of_what_is_tied)), delta_pos[None] * item_row_of_what_is_tied[:, None]
|
32 |
+
|
33 |
+
indices, positions = jax.vmap(_get_effect_of_shape_i_on_all_others, (0, 0))(jnp.arange(tied.shape[0]), tied)
|
34 |
+
indices = indices.flatten()
|
35 |
+
positions = positions.reshape(indices.shape[0], -1)
|
36 |
+
|
37 |
+
default_positions = jnp.concatenate(
|
38 |
+
[pcg_state.env_state.polygon.position, pcg_state.env_state.circle.position], axis=0
|
39 |
+
)
|
40 |
+
sampled_positions = jnp.concatenate([sampled_state.polygon.position, sampled_state.circle.position], axis=0)
|
41 |
+
|
42 |
+
updated_positions = default_positions.at[indices].add(positions)
|
43 |
+
# Use the deltas or the sampled positions
|
44 |
+
positions = jnp.where(should_use_delta_positions[:, None], updated_positions, sampled_positions)
|
45 |
+
|
46 |
+
sampled_state = sampled_state.replace(
|
47 |
+
polygon=sampled_state.polygon.replace(position=positions[: static_params.num_polygons]),
|
48 |
+
circle=sampled_state.circle.replace(position=positions[static_params.num_polygons :]),
|
49 |
+
)
|
50 |
+
return sampled_state
|
51 |
+
|
52 |
+
|
53 |
+
@partial(jax.jit, static_argnums=(3,))
|
54 |
+
def sample_pcg_state(rng, pcg_state: PCGState, params, static_params):
|
55 |
+
def _pcg_fn(rng, main_val, max_val, mask):
|
56 |
+
pcg_val = jax.random.uniform(rng, shape=main_val.shape) * (
|
57 |
+
max_val.astype(float) - main_val.astype(float)
|
58 |
+
) + main_val.astype(float)
|
59 |
+
if jnp.issubdtype(main_val.dtype, jnp.integer) or jnp.issubdtype(main_val.dtype, jnp.bool_):
|
60 |
+
pcg_val = jnp.round(pcg_val)
|
61 |
+
pcg_val = pcg_val.astype(main_val.dtype)
|
62 |
+
new_val = jax.lax.select(mask.astype(bool), pcg_val, main_val)
|
63 |
+
return new_val
|
64 |
+
|
65 |
+
def _random_split_like_tree(rng, target):
|
66 |
+
tree_def = jax.tree_structure(target)
|
67 |
+
rngs = jax.random.split(rng, tree_def.num_leaves)
|
68 |
+
return jax.tree_unflatten(tree_def, rngs)
|
69 |
+
|
70 |
+
rng, _rng = jax.random.split(rng)
|
71 |
+
rng_tree = _random_split_like_tree(_rng, pcg_state.env_state)
|
72 |
+
|
73 |
+
sampled_state = jax.tree_util.tree_map(
|
74 |
+
_pcg_fn, rng_tree, pcg_state.env_state, pcg_state.env_state_max, pcg_state.env_state_pcg_mask
|
75 |
+
)
|
76 |
+
|
77 |
+
sampled_state = _process_tied_together_shapes(pcg_state, sampled_state, static_params)
|
78 |
+
|
79 |
+
sampled_state = recompute_global_joint_positions(sampled_state, static_params)
|
80 |
+
|
81 |
+
env_state = recalculate_mass_and_inertia(
|
82 |
+
sampled_state, static_params, sampled_state.polygon_densities, sampled_state.circle_densities
|
83 |
+
)
|
84 |
+
|
85 |
+
return env_state
|
86 |
+
|
87 |
+
|
88 |
+
def env_state_to_pcg_state(env_state: EnvState):
|
89 |
+
N = env_state.polygon.active.shape[0] + env_state.circle.active.shape[0]
|
90 |
+
pcg_state = PCGState(
|
91 |
+
env_state=env_state,
|
92 |
+
env_state_max=env_state,
|
93 |
+
env_state_pcg_mask=jax.tree_util.tree_map(lambda x: jnp.zeros_like(x, dtype=bool), env_state),
|
94 |
+
tied_together=jnp.zeros((N, N), dtype=bool),
|
95 |
+
)
|
96 |
+
|
97 |
+
return pcg_state
|
kinetix/pcg/pcg_state.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import field
|
2 |
+
import jax.numpy as jnp
|
3 |
+
from flax import struct
|
4 |
+
|
5 |
+
from jax2d.sim_state import SimState, SimParams, StaticSimParams, RigidBody, Joint, Thruster, CollisionManifold
|
6 |
+
from kinetix.environment.env_state import EnvState
|
7 |
+
|
8 |
+
|
9 |
+
@struct.dataclass
|
10 |
+
class PCGState:
|
11 |
+
# Primary env state
|
12 |
+
env_state: EnvState
|
13 |
+
# The PCG mask. If a value is truthy in this, then it is PCG not static
|
14 |
+
env_state_pcg_mask: EnvState
|
15 |
+
# In the case that a value is PCG, the env_state value is the min and this state represents the max
|
16 |
+
env_state_max: EnvState
|
17 |
+
|
18 |
+
tied_together: jnp.ndarray # NxN matrix of booleans, where N is the number of shapes
|
19 |
+
|
20 |
+
def __setstate__(self, state):
|
21 |
+
if "tied_together" not in state:
|
22 |
+
num_shapes = state["env_state"].polygon.active.shape[0] + state["env_state"].circle.active.shape[0]
|
23 |
+
state["tied_together"] = jnp.zeros((num_shapes, num_shapes), dtype=bool)
|
24 |
+
object.__setattr__(self, "__dict__", state)
|
kinetix/render/__init__.py
ADDED
File without changes
|
kinetix/render/renderer_pixels.py
ADDED
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
|
3 |
+
import jax
|
4 |
+
import jax.numpy as jnp
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from jax2d import joint
|
8 |
+
from jax2d.engine import select_shape
|
9 |
+
from jax2d.maths import rmat
|
10 |
+
from jax2d.sim_state import RigidBody
|
11 |
+
from jaxgl.maths import dist_from_line
|
12 |
+
from jaxgl.renderer import clear_screen, make_renderer
|
13 |
+
from jaxgl.shaders import (
|
14 |
+
fragment_shader_quad,
|
15 |
+
fragment_shader_edged_quad,
|
16 |
+
make_fragment_shader_texture,
|
17 |
+
nearest_neighbour,
|
18 |
+
make_fragment_shader_quad_textured,
|
19 |
+
)
|
20 |
+
|
21 |
+
from kinetix.render.textures import (
|
22 |
+
THRUSTER_TEXTURE_16_RGBA,
|
23 |
+
RJOINT_TEXTURE_6_RGBA,
|
24 |
+
FJOINT_TEXTURE_6_RGBA,
|
25 |
+
)
|
26 |
+
from kinetix.environment.env_state import StaticEnvParams, EnvParams, EnvState
|
27 |
+
from flax import struct
|
28 |
+
|
29 |
+
|
30 |
+
def make_render_pixels(
|
31 |
+
params,
|
32 |
+
static_params: StaticEnvParams,
|
33 |
+
):
|
34 |
+
screen_dim = static_params.screen_dim
|
35 |
+
downscale = static_params.downscale
|
36 |
+
|
37 |
+
joint_tex_size = 6
|
38 |
+
thruster_tex_size = 16
|
39 |
+
|
40 |
+
FIXATED_COLOUR = jnp.array([80, 80, 80])
|
41 |
+
JOINT_COLOURS = jnp.array(
|
42 |
+
[
|
43 |
+
# [0, 0, 255],
|
44 |
+
[255, 255, 255], # yellow
|
45 |
+
[255, 255, 0], # yellow
|
46 |
+
[255, 0, 255], # purple/magenta
|
47 |
+
[0, 255, 255], # cyan
|
48 |
+
[255, 153, 51], # white
|
49 |
+
]
|
50 |
+
)
|
51 |
+
|
52 |
+
def colour_thruster_texture(colour):
|
53 |
+
return THRUSTER_TEXTURE_16_RGBA.at[:9, :, :3].mul(colour[None, None, :] / 255.0)
|
54 |
+
|
55 |
+
coloured_thruster_textures = jax.vmap(colour_thruster_texture)(JOINT_COLOURS)
|
56 |
+
|
57 |
+
ROLE_COLOURS = jnp.array(
|
58 |
+
[
|
59 |
+
[160.0, 160.0, 160.0], # None
|
60 |
+
[0.0, 204.0, 0.0], # Green: The ball
|
61 |
+
[0.0, 102.0, 204.0], # Blue: The goal
|
62 |
+
[255.0, 102.0, 102.0], # Red: Death Objects
|
63 |
+
]
|
64 |
+
)
|
65 |
+
|
66 |
+
BACKGROUND_COLOUR = jnp.array([255.0, 255.0, 255.0])
|
67 |
+
|
68 |
+
def _get_colour(shape_role, inverse_inertia):
|
69 |
+
base_colour = ROLE_COLOURS[shape_role]
|
70 |
+
f = (inverse_inertia == 0) * 1
|
71 |
+
is_not_normal = (shape_role != 0) * 1
|
72 |
+
|
73 |
+
return jnp.array(
|
74 |
+
[
|
75 |
+
base_colour,
|
76 |
+
base_colour,
|
77 |
+
FIXATED_COLOUR,
|
78 |
+
base_colour * 0.5,
|
79 |
+
]
|
80 |
+
)[2 * f + is_not_normal]
|
81 |
+
|
82 |
+
# Pixels per unit distance
|
83 |
+
ppud = params.pixels_per_unit // downscale
|
84 |
+
|
85 |
+
downscaled_screen_dim = (screen_dim[0] // downscale, screen_dim[1] // downscale)
|
86 |
+
|
87 |
+
full_screen_size = (
|
88 |
+
downscaled_screen_dim[0] + (static_params.max_shape_size * 2 * ppud),
|
89 |
+
downscaled_screen_dim[1] + (static_params.max_shape_size * 2 * ppud),
|
90 |
+
)
|
91 |
+
cleared_screen = clear_screen(full_screen_size, BACKGROUND_COLOUR)
|
92 |
+
|
93 |
+
def _world_space_to_pixel_space(x):
|
94 |
+
return (x + static_params.max_shape_size) * ppud
|
95 |
+
|
96 |
+
def fragment_shader_kinetix_circle(position, current_frag, unit_position, uniform):
|
97 |
+
centre, radius, rotation, colour, mask = uniform
|
98 |
+
|
99 |
+
dist = jnp.sqrt(jnp.square(position - centre).sum())
|
100 |
+
inside = dist <= radius
|
101 |
+
on_edge = dist > radius - 2
|
102 |
+
|
103 |
+
# TODO - precompute?
|
104 |
+
normal = jnp.array([jnp.sin(rotation), -jnp.cos(rotation)])
|
105 |
+
|
106 |
+
dist = dist_from_line(position, centre, centre + normal)
|
107 |
+
|
108 |
+
on_edge |= (dist < 1) & (jnp.dot(normal, position - centre) <= 0)
|
109 |
+
|
110 |
+
fragment = jax.lax.select(on_edge, jnp.zeros(3), colour)
|
111 |
+
|
112 |
+
return jax.lax.select(inside & mask, fragment, current_frag)
|
113 |
+
|
114 |
+
def fragment_shader_kinetix_joint(position, current_frag, unit_position, uniform):
|
115 |
+
texture, colour, mask = uniform
|
116 |
+
|
117 |
+
tex_coord = (
|
118 |
+
jnp.array(
|
119 |
+
[
|
120 |
+
joint_tex_size * unit_position[0],
|
121 |
+
joint_tex_size * unit_position[1],
|
122 |
+
]
|
123 |
+
)
|
124 |
+
- 0.5
|
125 |
+
)
|
126 |
+
|
127 |
+
tex_frag = nearest_neighbour(texture, tex_coord)
|
128 |
+
tex_frag = tex_frag.at[3].mul(mask)
|
129 |
+
tex_frag = tex_frag.at[:3].mul(colour / 255.0)
|
130 |
+
|
131 |
+
tex_frag = (tex_frag[3] * tex_frag[:3]) + ((1.0 - tex_frag[3]) * current_frag)
|
132 |
+
|
133 |
+
return tex_frag
|
134 |
+
|
135 |
+
thruster_pixel_size = thruster_tex_size // downscale
|
136 |
+
thruster_pixel_size_diagonal = (thruster_pixel_size * np.sqrt(2)).astype(jnp.int32) + 1
|
137 |
+
|
138 |
+
def fragment_shader_kinetix_thruster(fragment_position, current_frag, unit_position, uniform):
|
139 |
+
thruster_position, rotation, texture, mask = uniform
|
140 |
+
|
141 |
+
tex_position = jnp.matmul(rmat(-rotation), (fragment_position - thruster_position)) / thruster_pixel_size + 0.5
|
142 |
+
|
143 |
+
mask &= (tex_position[0] >= 0) & (tex_position[0] <= 1) & (tex_position[1] >= 0) & (tex_position[1] <= 1)
|
144 |
+
|
145 |
+
eps = 0.001
|
146 |
+
tex_coord = (
|
147 |
+
jnp.array(
|
148 |
+
[
|
149 |
+
thruster_tex_size * tex_position[0],
|
150 |
+
thruster_tex_size * tex_position[1],
|
151 |
+
]
|
152 |
+
)
|
153 |
+
- 0.5
|
154 |
+
+ eps
|
155 |
+
)
|
156 |
+
|
157 |
+
tex_frag = nearest_neighbour(texture, tex_coord)
|
158 |
+
tex_frag = tex_frag.at[3].mul(mask)
|
159 |
+
|
160 |
+
tex_frag = (tex_frag[3] * tex_frag[:3]) + ((1.0 - tex_frag[3]) * current_frag)
|
161 |
+
|
162 |
+
return tex_frag
|
163 |
+
|
164 |
+
patch_size_1d = static_params.max_shape_size * ppud
|
165 |
+
patch_size = (patch_size_1d, patch_size_1d)
|
166 |
+
|
167 |
+
circle_renderer = make_renderer(full_screen_size, fragment_shader_kinetix_circle, patch_size, batched=True)
|
168 |
+
quad_renderer = make_renderer(full_screen_size, fragment_shader_edged_quad, patch_size, batched=True)
|
169 |
+
big_quad_renderer = make_renderer(full_screen_size, fragment_shader_edged_quad, downscaled_screen_dim)
|
170 |
+
|
171 |
+
joint_pixel_size = joint_tex_size // downscale
|
172 |
+
joint_renderer = make_renderer(
|
173 |
+
full_screen_size, fragment_shader_kinetix_joint, (joint_pixel_size, joint_pixel_size), batched=True
|
174 |
+
)
|
175 |
+
|
176 |
+
thruster_renderer = make_renderer(
|
177 |
+
full_screen_size,
|
178 |
+
fragment_shader_kinetix_thruster,
|
179 |
+
(thruster_pixel_size_diagonal, thruster_pixel_size_diagonal),
|
180 |
+
batched=True,
|
181 |
+
)
|
182 |
+
|
183 |
+
@jax.jit
|
184 |
+
def render_pixels(state: EnvState):
|
185 |
+
pixels = cleared_screen
|
186 |
+
|
187 |
+
# Floor
|
188 |
+
floor_uniform = (
|
189 |
+
_world_space_to_pixel_space(state.polygon.position[0, None, :] + state.polygon.vertices[0]),
|
190 |
+
_get_colour(state.polygon_shape_roles[0], 0),
|
191 |
+
jnp.zeros(3),
|
192 |
+
True,
|
193 |
+
)
|
194 |
+
|
195 |
+
pixels = big_quad_renderer(pixels, _world_space_to_pixel_space(jnp.zeros(2, dtype=jnp.int32)), floor_uniform)
|
196 |
+
|
197 |
+
# Rectangles
|
198 |
+
rectangle_patch_positions = _world_space_to_pixel_space(
|
199 |
+
state.polygon.position - (static_params.max_shape_size / 2.0)
|
200 |
+
).astype(jnp.int32)
|
201 |
+
|
202 |
+
rectangle_rmats = jax.vmap(rmat)(state.polygon.rotation)
|
203 |
+
rectangle_rmats = jnp.repeat(rectangle_rmats[:, None, :, :], repeats=static_params.max_polygon_vertices, axis=1)
|
204 |
+
rectangle_vertices_pixel_space = _world_space_to_pixel_space(
|
205 |
+
state.polygon.position[:, None, :] + jax.vmap(jax.vmap(jnp.matmul))(rectangle_rmats, state.polygon.vertices)
|
206 |
+
)
|
207 |
+
rectangle_colours = jax.vmap(_get_colour)(state.polygon_shape_roles, state.polygon.inverse_mass)
|
208 |
+
rectangle_edge_colours = jnp.zeros((static_params.num_polygons, 3))
|
209 |
+
|
210 |
+
rectangle_uniforms = (
|
211 |
+
rectangle_vertices_pixel_space,
|
212 |
+
rectangle_colours,
|
213 |
+
rectangle_edge_colours,
|
214 |
+
state.polygon.active,
|
215 |
+
)
|
216 |
+
|
217 |
+
pixels = quad_renderer(pixels, rectangle_patch_positions, rectangle_uniforms)
|
218 |
+
|
219 |
+
# Circles
|
220 |
+
circle_positions_pixel_space = _world_space_to_pixel_space(state.circle.position)
|
221 |
+
circle_radii_pixel_space = state.circle.radius * ppud
|
222 |
+
circle_patch_positions = _world_space_to_pixel_space(
|
223 |
+
state.circle.position - (static_params.max_shape_size / 2.0)
|
224 |
+
).astype(jnp.int32)
|
225 |
+
|
226 |
+
circle_colours = jax.vmap(_get_colour)(state.circle_shape_roles, state.circle.inverse_mass)
|
227 |
+
|
228 |
+
circle_uniforms = (
|
229 |
+
circle_positions_pixel_space,
|
230 |
+
circle_radii_pixel_space,
|
231 |
+
state.circle.rotation,
|
232 |
+
circle_colours,
|
233 |
+
state.circle.active,
|
234 |
+
)
|
235 |
+
|
236 |
+
pixels = circle_renderer(pixels, circle_patch_positions, circle_uniforms)
|
237 |
+
|
238 |
+
# Joints
|
239 |
+
joint_patch_positions = jnp.round(
|
240 |
+
_world_space_to_pixel_space(state.joint.global_position) - (joint_pixel_size // 2)
|
241 |
+
).astype(jnp.int32)
|
242 |
+
joint_textures = jax.vmap(jax.lax.select, in_axes=(0, None, None))(
|
243 |
+
state.joint.is_fixed_joint, FJOINT_TEXTURE_6_RGBA, RJOINT_TEXTURE_6_RGBA
|
244 |
+
)
|
245 |
+
joint_colours = JOINT_COLOURS[
|
246 |
+
(state.motor_bindings + 1) * (state.joint.motor_on & (~state.joint.is_fixed_joint))
|
247 |
+
]
|
248 |
+
|
249 |
+
joint_uniforms = (joint_textures, joint_colours, state.joint.active)
|
250 |
+
|
251 |
+
pixels = joint_renderer(pixels, joint_patch_positions, joint_uniforms)
|
252 |
+
|
253 |
+
# Thrusters
|
254 |
+
thruster_positions = jnp.round(_world_space_to_pixel_space(state.thruster.global_position)).astype(jnp.int32)
|
255 |
+
thruster_patch_positions = thruster_positions - (thruster_pixel_size_diagonal // 2)
|
256 |
+
thruster_textures = coloured_thruster_textures[state.thruster_bindings + 1]
|
257 |
+
thruster_rotations = (
|
258 |
+
state.thruster.rotation
|
259 |
+
+ jax.vmap(select_shape, in_axes=(None, 0, None))(
|
260 |
+
state, state.thruster.object_index, static_params
|
261 |
+
).rotation
|
262 |
+
)
|
263 |
+
thruster_uniforms = (thruster_positions, thruster_rotations, thruster_textures, state.thruster.active)
|
264 |
+
|
265 |
+
pixels = thruster_renderer(pixels, thruster_patch_positions, thruster_uniforms)
|
266 |
+
|
267 |
+
# Crop out the sides
|
268 |
+
crop_amount = static_params.max_shape_size * ppud
|
269 |
+
return pixels[crop_amount:-crop_amount, crop_amount:-crop_amount]
|
270 |
+
|
271 |
+
return render_pixels
|
272 |
+
|
273 |
+
|
274 |
+
@struct.dataclass
|
275 |
+
class PixelsObservation:
|
276 |
+
image: jnp.ndarray
|
277 |
+
global_info: jnp.ndarray
|
278 |
+
|
279 |
+
|
280 |
+
def make_render_pixels_rl(params, static_params: StaticEnvParams):
|
281 |
+
render_fn = make_render_pixels(params, static_params)
|
282 |
+
|
283 |
+
def inner(state):
|
284 |
+
pixels = render_fn(state) / 255.0
|
285 |
+
return PixelsObservation(
|
286 |
+
image=pixels,
|
287 |
+
global_info=jnp.array([state.gravity[1] / 10.0]),
|
288 |
+
)
|
289 |
+
|
290 |
+
return inner
|
kinetix/render/renderer_symbolic_common.py
ADDED
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import jax
|
2 |
+
from jax2d.sim_state import RigidBody
|
3 |
+
import jax.numpy as jnp
|
4 |
+
from kinetix.environment.env_state import EnvParams, EnvState, StaticEnvParams
|
5 |
+
|
6 |
+
|
7 |
+
def _get_base_shape_features(
|
8 |
+
density: jnp.ndarray, roles: jnp.ndarray, shapes: RigidBody, env_params: EnvParams
|
9 |
+
) -> jnp.ndarray:
|
10 |
+
cos = jnp.cos(shapes.rotation)
|
11 |
+
sin = jnp.sin(shapes.rotation)
|
12 |
+
return jnp.concatenate(
|
13 |
+
[
|
14 |
+
shapes.position,
|
15 |
+
shapes.velocity,
|
16 |
+
jnp.expand_dims(shapes.inverse_mass, axis=1),
|
17 |
+
jnp.expand_dims(shapes.inverse_inertia, axis=1),
|
18 |
+
jnp.expand_dims(density, axis=1),
|
19 |
+
jnp.expand_dims(jnp.tanh(shapes.angular_velocity / 10), axis=1),
|
20 |
+
jax.nn.one_hot(roles, env_params.num_shape_roles),
|
21 |
+
jnp.expand_dims(sin, axis=1),
|
22 |
+
jnp.expand_dims(cos, axis=1),
|
23 |
+
jnp.expand_dims(shapes.friction, axis=1),
|
24 |
+
jnp.expand_dims(shapes.restitution, axis=1),
|
25 |
+
],
|
26 |
+
axis=1,
|
27 |
+
)
|
28 |
+
|
29 |
+
|
30 |
+
def add_circle_features(
|
31 |
+
base_features: jnp.ndarray, shapes: RigidBody, env_params: EnvParams, static_env_params: StaticEnvParams
|
32 |
+
):
|
33 |
+
return jnp.concatenate(
|
34 |
+
[
|
35 |
+
base_features,
|
36 |
+
shapes.radius[:, None],
|
37 |
+
jnp.ones_like(base_features[:, :1]), # one for circle
|
38 |
+
],
|
39 |
+
axis=1,
|
40 |
+
)
|
41 |
+
|
42 |
+
|
43 |
+
def make_circle_features(
|
44 |
+
state: EnvState, env_params: EnvParams, static_env_params: StaticEnvParams
|
45 |
+
) -> tuple[jnp.ndarray, jnp.ndarray]:
|
46 |
+
base_features = _get_base_shape_features(state.circle_densities, state.circle_shape_roles, state.circle, env_params)
|
47 |
+
node_features = add_circle_features(base_features, state.circle, env_params, static_env_params)
|
48 |
+
return node_features, state.circle.active
|
49 |
+
|
50 |
+
|
51 |
+
def add_polygon_features(
|
52 |
+
base_features: jnp.ndarray, shapes: RigidBody, env_params: EnvParams, static_env_params: StaticEnvParams
|
53 |
+
):
|
54 |
+
vertices = jnp.where(
|
55 |
+
jnp.arange(static_env_params.max_polygon_vertices)[None, :, None] < shapes.n_vertices[:, None, None],
|
56 |
+
shapes.vertices,
|
57 |
+
jnp.zeros_like(shapes.vertices) - 1,
|
58 |
+
)
|
59 |
+
|
60 |
+
return jnp.concatenate(
|
61 |
+
[
|
62 |
+
base_features,
|
63 |
+
jnp.zeros_like(base_features[:, :1]), # zero for polygon
|
64 |
+
vertices.reshape((vertices.shape[0], -1)),
|
65 |
+
jnp.expand_dims((shapes.n_vertices <= 3), axis=1),
|
66 |
+
],
|
67 |
+
axis=1,
|
68 |
+
)
|
69 |
+
|
70 |
+
|
71 |
+
def make_polygon_features(
|
72 |
+
state: EnvState, env_params: EnvParams, static_env_params: StaticEnvParams
|
73 |
+
) -> tuple[jnp.ndarray, jnp.ndarray]:
|
74 |
+
base_features = _get_base_shape_features(
|
75 |
+
state.polygon_densities, state.polygon_shape_roles, state.polygon, env_params
|
76 |
+
)
|
77 |
+
node_features = add_polygon_features(base_features, state.polygon, env_params, static_env_params)
|
78 |
+
return node_features, state.polygon.active
|
79 |
+
|
80 |
+
|
81 |
+
def make_unified_shape_features(
|
82 |
+
state: EnvState, env_params: EnvParams, static_env_params: StaticEnvParams
|
83 |
+
) -> tuple[jnp.ndarray, jnp.ndarray]:
|
84 |
+
base_p = _get_base_shape_features(state.polygon_densities, state.polygon_shape_roles, state.polygon, env_params)
|
85 |
+
base_c = _get_base_shape_features(state.circle_densities, state.circle_shape_roles, state.circle, env_params)
|
86 |
+
base_p = add_polygon_features(base_p, state.polygon, env_params, static_env_params)
|
87 |
+
base_p = add_circle_features(base_p, state.polygon, env_params, static_env_params)
|
88 |
+
|
89 |
+
base_c = add_polygon_features(base_c, state.circle, env_params, static_env_params)
|
90 |
+
base_c = add_circle_features(base_c, state.circle, env_params, static_env_params)
|
91 |
+
|
92 |
+
return jnp.concatenate([base_p, base_c], axis=0), jnp.concatenate(
|
93 |
+
[state.polygon.active, state.circle.active], axis=0
|
94 |
+
)
|
95 |
+
|
96 |
+
|
97 |
+
def make_joint_features(
|
98 |
+
state: EnvState, env_params: EnvParams, static_env_params: StaticEnvParams
|
99 |
+
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
|
100 |
+
# Returns joint_features, indexes, mask, of shape:
|
101 |
+
# (2 * J, K), (2 * J, 2), (2 * J,)
|
102 |
+
def _create_joint_features(joints):
|
103 |
+
# 2, J, A
|
104 |
+
J = joints.active.shape[0]
|
105 |
+
|
106 |
+
def _create_1way_joint_features(direction):
|
107 |
+
from_pos = jax.lax.select(direction, joints.a_relative_pos, joints.b_relative_pos)
|
108 |
+
to_pos = jax.lax.select(direction, joints.b_relative_pos, joints.a_relative_pos)
|
109 |
+
|
110 |
+
rotation_sin, rotation_cos = jnp.sin(joints.rotation), jnp.cos(joints.rotation)
|
111 |
+
rotation_max_sin = jnp.sin(joints.max_rotation) * joints.motor_has_joint_limits
|
112 |
+
rotation_max_cos = jnp.cos(joints.max_rotation) * joints.motor_has_joint_limits
|
113 |
+
rotation_min_sin = jnp.sin(joints.min_rotation) * joints.motor_has_joint_limits
|
114 |
+
rotation_min_cos = jnp.cos(joints.min_rotation) * joints.motor_has_joint_limits
|
115 |
+
|
116 |
+
rotation_diff_max = (joints.max_rotation - joints.rotation) * joints.motor_has_joint_limits
|
117 |
+
rotation_diff_min = (joints.min_rotation - joints.rotation) * joints.motor_has_joint_limits
|
118 |
+
|
119 |
+
base_features = jnp.concatenate(
|
120 |
+
[
|
121 |
+
(joints.active * 1.0)[:, None],
|
122 |
+
(joints.is_fixed_joint * 1.0)[:, None], # J, 1
|
123 |
+
from_pos,
|
124 |
+
to_pos,
|
125 |
+
rotation_sin[:, None],
|
126 |
+
rotation_cos[:, None],
|
127 |
+
],
|
128 |
+
axis=1,
|
129 |
+
)
|
130 |
+
rjoint_features = (
|
131 |
+
jnp.concatenate(
|
132 |
+
[
|
133 |
+
joints.motor_speed[:, None],
|
134 |
+
joints.motor_power[:, None],
|
135 |
+
(joints.motor_on * 1.0)[:, None],
|
136 |
+
(joints.motor_has_joint_limits * 1.0)[:, None],
|
137 |
+
jax.nn.one_hot(state.motor_bindings, num_classes=static_env_params.num_motor_bindings),
|
138 |
+
rotation_min_sin[:, None],
|
139 |
+
rotation_min_cos[:, None],
|
140 |
+
rotation_max_sin[:, None],
|
141 |
+
rotation_max_cos[:, None],
|
142 |
+
rotation_diff_min[:, None],
|
143 |
+
rotation_diff_max[:, None],
|
144 |
+
],
|
145 |
+
axis=1,
|
146 |
+
)
|
147 |
+
* (1.0 - (joints.is_fixed_joint * 1.0))[:, None]
|
148 |
+
)
|
149 |
+
|
150 |
+
return jnp.concatenate([base_features, rjoint_features], axis=1)
|
151 |
+
|
152 |
+
# 2, J, A
|
153 |
+
joint_features = jax.vmap(_create_1way_joint_features)(jnp.array([False, True]))
|
154 |
+
|
155 |
+
# J, 2
|
156 |
+
indexes_from = jnp.concatenate([joints.b_index[:, None], joints.a_index[:, None]], axis=1)
|
157 |
+
indexes_to = jnp.concatenate([joints.a_index[:, None], joints.b_index[:, None]], axis=1)
|
158 |
+
|
159 |
+
indexes_from = jnp.where(joints.active[:, None], indexes_from, jnp.zeros_like(indexes_from))
|
160 |
+
indexes_to = jnp.where(joints.active[:, None], indexes_to, jnp.zeros_like(indexes_to))
|
161 |
+
|
162 |
+
indexes = jnp.concatenate([indexes_from, indexes_to], axis=0)
|
163 |
+
mask = jnp.concatenate([joints.active, joints.active], axis=0)
|
164 |
+
|
165 |
+
return joint_features.reshape((2 * J, -1)), indexes, mask
|
166 |
+
|
167 |
+
return _create_joint_features(state.joint)
|
168 |
+
|
169 |
+
|
170 |
+
def make_thruster_features(
|
171 |
+
state: EnvState, env_params: EnvParams, static_env_params: StaticEnvParams
|
172 |
+
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
|
173 |
+
# Returns thruster_features, indexes, mask, of shape:
|
174 |
+
# (T, K), (T,), (T,)
|
175 |
+
def _create_thruster_features(thrusters):
|
176 |
+
cos = jnp.cos(thrusters.rotation)
|
177 |
+
sin = jnp.sin(thrusters.rotation)
|
178 |
+
return jnp.concatenate(
|
179 |
+
[
|
180 |
+
(thrusters.active * 1.0)[:, None],
|
181 |
+
(thrusters.relative_position),
|
182 |
+
jax.nn.one_hot(state.thruster_bindings, num_classes=static_env_params.num_thruster_bindings),
|
183 |
+
sin[:, None],
|
184 |
+
cos[:, None],
|
185 |
+
thrusters.power[:, None],
|
186 |
+
],
|
187 |
+
axis=1,
|
188 |
+
)
|
189 |
+
|
190 |
+
return _create_thruster_features(state.thruster), state.thruster.object_index, state.thruster.active
|
kinetix/render/renderer_symbolic_entity.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from cmath import rect
|
2 |
+
from functools import partial
|
3 |
+
|
4 |
+
import jax
|
5 |
+
import jax.numpy as jnp
|
6 |
+
from flax import struct
|
7 |
+
from jax2d.engine import get_pairwise_interaction_indices
|
8 |
+
from kinetix.environment.env_state import EnvState
|
9 |
+
from kinetix.render.renderer_symbolic_common import (
|
10 |
+
make_circle_features,
|
11 |
+
make_joint_features,
|
12 |
+
make_polygon_features,
|
13 |
+
make_thruster_features,
|
14 |
+
make_unified_shape_features,
|
15 |
+
)
|
16 |
+
|
17 |
+
|
18 |
+
@struct.dataclass
|
19 |
+
class EntityObservation:
|
20 |
+
circles: jnp.ndarray
|
21 |
+
polygons: jnp.ndarray
|
22 |
+
joints: jnp.ndarray
|
23 |
+
thrusters: jnp.ndarray
|
24 |
+
|
25 |
+
circle_mask: jnp.ndarray
|
26 |
+
polygon_mask: jnp.ndarray
|
27 |
+
joint_mask: jnp.ndarray
|
28 |
+
thruster_mask: jnp.ndarray
|
29 |
+
attention_mask: jnp.ndarray
|
30 |
+
# collision_mask: jnp.ndarray
|
31 |
+
|
32 |
+
joint_indexes: jnp.ndarray
|
33 |
+
thruster_indexes: jnp.ndarray
|
34 |
+
|
35 |
+
|
36 |
+
def make_render_entities(params, static_params):
|
37 |
+
_, _, _, circle_circle_pairs, circle_rect_pairs, rect_rect_pairs = get_pairwise_interaction_indices(static_params)
|
38 |
+
circle_rect_pairs = circle_rect_pairs.at[:, 0].add(static_params.num_polygons)
|
39 |
+
circle_circle_pairs = circle_circle_pairs + static_params.num_polygons
|
40 |
+
|
41 |
+
def render_entities(state: EnvState):
|
42 |
+
state = jax.tree_util.tree_map(lambda x: jnp.nan_to_num(x), state)
|
43 |
+
|
44 |
+
joint_features, joint_indexes, joint_mask = make_joint_features(state, params, static_params)
|
45 |
+
thruster_features, thruster_indexes, thruster_mask = make_thruster_features(state, params, static_params)
|
46 |
+
|
47 |
+
poly_nodes, poly_mask = make_polygon_features(state, params, static_params)
|
48 |
+
circle_nodes, circle_mask = make_circle_features(state, params, static_params)
|
49 |
+
|
50 |
+
def _add_grav(nodes):
|
51 |
+
return jnp.concatenate(
|
52 |
+
[nodes, jnp.zeros((nodes.shape[0], 1)) + state.gravity[1] / 10], axis=-1
|
53 |
+
) # add gravity to each shape's embedding
|
54 |
+
|
55 |
+
poly_nodes = _add_grav(poly_nodes)
|
56 |
+
circle_nodes = _add_grav(circle_nodes)
|
57 |
+
|
58 |
+
# Shape of something like (NPoly + NCircle + 2 * NJoint + NThruster )
|
59 |
+
mask_flat_shapes = jnp.concatenate([poly_mask, circle_mask], axis=0)
|
60 |
+
num_shapes = static_params.num_polygons + static_params.num_circles
|
61 |
+
|
62 |
+
def make_n_squared_mask(val):
|
63 |
+
# val has shape N of bools.
|
64 |
+
N = val.shape[0]
|
65 |
+
A = jnp.eye(N, N, dtype=bool) # also have things attend to themselves
|
66 |
+
# Make the shapes fully connected
|
67 |
+
full_mask = A.at[:num_shapes, :num_shapes].set(jnp.ones((num_shapes, num_shapes), dtype=bool))
|
68 |
+
|
69 |
+
one_hop_connected = jnp.zeros((N, N), dtype=bool)
|
70 |
+
one_hop_connected = one_hop_connected.at[joint_indexes[:, 0], joint_indexes[:, 1]].set(True)
|
71 |
+
one_hop_connected = one_hop_connected.at[0, 0].set(False) # invalid joints have indices of (0, 0)
|
72 |
+
|
73 |
+
multi_hop_connected = jnp.logical_not(state.collision_matrix)
|
74 |
+
|
75 |
+
collision_mask = state.collision_matrix
|
76 |
+
|
77 |
+
# where val is false, we want to mask out the row and column.
|
78 |
+
full_mask = full_mask & (val[:, None]) & (val[None, :])
|
79 |
+
collision_mask = collision_mask & (val[:, None]) & (val[None, :])
|
80 |
+
multi_hop_connected = multi_hop_connected & (val[:, None]) & (val[None, :])
|
81 |
+
one_hop_connected = one_hop_connected & (val[:, None]) & (val[None, :])
|
82 |
+
collision_manifold_mask = jnp.zeros_like(collision_mask)
|
83 |
+
|
84 |
+
def _set(collision_manifold_mask, pairs, active):
|
85 |
+
return collision_manifold_mask.at[
|
86 |
+
pairs[:, 0],
|
87 |
+
pairs[:, 1],
|
88 |
+
].set(active)
|
89 |
+
|
90 |
+
collision_manifold_mask = _set(
|
91 |
+
collision_manifold_mask,
|
92 |
+
rect_rect_pairs,
|
93 |
+
jnp.logical_or(state.acc_rr_manifolds.active[..., 0], state.acc_rr_manifolds.active[..., 1]),
|
94 |
+
)
|
95 |
+
|
96 |
+
collision_manifold_mask = _set(collision_manifold_mask, circle_rect_pairs, state.acc_cr_manifolds.active)
|
97 |
+
collision_manifold_mask = _set(collision_manifold_mask, circle_circle_pairs, state.acc_cc_manifolds.active)
|
98 |
+
collision_manifold_mask = collision_manifold_mask & (val[:, None]) & (val[None, :])
|
99 |
+
|
100 |
+
return jnp.concatenate(
|
101 |
+
[full_mask[None], multi_hop_connected[None], one_hop_connected[None], collision_manifold_mask[None]],
|
102 |
+
axis=0,
|
103 |
+
)
|
104 |
+
|
105 |
+
mask_n_squared = make_n_squared_mask(mask_flat_shapes)
|
106 |
+
|
107 |
+
return EntityObservation(
|
108 |
+
circles=circle_nodes,
|
109 |
+
polygons=poly_nodes,
|
110 |
+
joints=joint_features,
|
111 |
+
thrusters=thruster_features,
|
112 |
+
circle_mask=circle_mask,
|
113 |
+
polygon_mask=poly_mask,
|
114 |
+
joint_mask=joint_mask,
|
115 |
+
thruster_mask=thruster_mask,
|
116 |
+
attention_mask=mask_n_squared,
|
117 |
+
joint_indexes=joint_indexes,
|
118 |
+
thruster_indexes=thruster_indexes,
|
119 |
+
)
|
120 |
+
|
121 |
+
return render_entities
|
kinetix/render/renderer_symbolic_flat.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
|
3 |
+
import jax
|
4 |
+
import jax.numpy as jnp
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from jax2d import joint
|
8 |
+
from jax2d.engine import select_shape
|
9 |
+
from jax2d.maths import rmat
|
10 |
+
from jax2d.sim_state import RigidBody
|
11 |
+
from jaxgl.maths import dist_from_line
|
12 |
+
from jaxgl.renderer import clear_screen, make_renderer
|
13 |
+
from jaxgl.shaders import (
|
14 |
+
fragment_shader_quad,
|
15 |
+
fragment_shader_edged_quad,
|
16 |
+
make_fragment_shader_texture,
|
17 |
+
nearest_neighbour,
|
18 |
+
make_fragment_shader_quad_textured,
|
19 |
+
)
|
20 |
+
from kinetix.render.renderer_symbolic_common import (
|
21 |
+
make_circle_features,
|
22 |
+
make_joint_features,
|
23 |
+
make_polygon_features,
|
24 |
+
make_thruster_features,
|
25 |
+
)
|
26 |
+
from kinetix.environment.env_state import StaticEnvParams, EnvParams, EnvState
|
27 |
+
from flax import struct
|
28 |
+
|
29 |
+
|
30 |
+
def make_render_symbolic(params, static_params: StaticEnvParams):
|
31 |
+
def render_symbolic(state):
|
32 |
+
|
33 |
+
n_polys = static_params.num_polygons
|
34 |
+
nshapes = n_polys + static_params.num_circles
|
35 |
+
|
36 |
+
polygon_features, polygon_mask = make_polygon_features(state, params, static_params)
|
37 |
+
mask_to_ignore_walls_ceiling = np.ones(static_params.num_polygons, dtype=bool)
|
38 |
+
mask_to_ignore_walls_ceiling[np.array([1, 2, 3])] = False
|
39 |
+
|
40 |
+
polygon_features = polygon_features[mask_to_ignore_walls_ceiling]
|
41 |
+
polygon_mask = polygon_mask[mask_to_ignore_walls_ceiling]
|
42 |
+
|
43 |
+
circle_features, circle_mask = make_circle_features(state, params, static_params)
|
44 |
+
joint_features, joint_idxs, joint_mask = make_joint_features(state, params, static_params)
|
45 |
+
thruster_features, thruster_idxs, thruster_mask = make_thruster_features(state, params, static_params)
|
46 |
+
|
47 |
+
two_J = joint_features.shape[0]
|
48 |
+
J = two_J // 2 # for symbolic only have the one
|
49 |
+
joint_features = jnp.concatenate(
|
50 |
+
[
|
51 |
+
joint_features[:J], # shape (2 * J, K)
|
52 |
+
jax.nn.one_hot(joint_idxs[:J, 0], nshapes), # shape (2 * J, N)
|
53 |
+
jax.nn.one_hot(joint_idxs[:J, 1], nshapes), # shape (2 * J, N)
|
54 |
+
],
|
55 |
+
axis=1,
|
56 |
+
)
|
57 |
+
thruster_features = jnp.concatenate(
|
58 |
+
[
|
59 |
+
thruster_features,
|
60 |
+
jax.nn.one_hot(thruster_idxs, nshapes),
|
61 |
+
],
|
62 |
+
axis=1,
|
63 |
+
)
|
64 |
+
|
65 |
+
polygon_features = jnp.where(polygon_mask[:, None], polygon_features, 0.0).flatten()
|
66 |
+
circle_features = jnp.where(circle_mask[:, None], circle_features, 0.0).flatten()
|
67 |
+
joint_features = jnp.where(joint_mask[:J, None], joint_features, 0.0).flatten()
|
68 |
+
thruster_features = jnp.where(thruster_mask[:, None], thruster_features, 0.0).flatten()
|
69 |
+
|
70 |
+
def _get_manifold_features(manifold):
|
71 |
+
collision_mask_features = jnp.concatenate(
|
72 |
+
[
|
73 |
+
manifold.normal,
|
74 |
+
jnp.expand_dims(manifold.penetration, axis=-1),
|
75 |
+
manifold.collision_point,
|
76 |
+
jnp.expand_dims(manifold.acc_impulse_normal, axis=-1),
|
77 |
+
jnp.expand_dims(manifold.acc_impulse_tangent, axis=-1),
|
78 |
+
],
|
79 |
+
axis=-1,
|
80 |
+
)
|
81 |
+
|
82 |
+
return (collision_mask_features * manifold.active[..., None]).flatten()
|
83 |
+
|
84 |
+
obs = jnp.concatenate(
|
85 |
+
[
|
86 |
+
polygon_features,
|
87 |
+
circle_features,
|
88 |
+
joint_features,
|
89 |
+
thruster_features,
|
90 |
+
jnp.array([state.gravity[1]]) / 10,
|
91 |
+
# _get_manifold_features(state.acc_cc_manifolds),
|
92 |
+
# _get_manifold_features(state.acc_cr_manifolds),
|
93 |
+
# _get_manifold_features(state.acc_rr_manifolds),
|
94 |
+
],
|
95 |
+
axis=0,
|
96 |
+
)
|
97 |
+
|
98 |
+
obs = jnp.clip(obs, a_min=-10.0, a_max=10.0)
|
99 |
+
obs = jnp.nan_to_num(obs)
|
100 |
+
return obs
|
101 |
+
|
102 |
+
return render_symbolic
|
kinetix/render/textures.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import pathlib
|
3 |
+
from enum import Enum
|
4 |
+
import jax.numpy as jnp
|
5 |
+
import imageio.v3 as iio
|
6 |
+
import numpy as np
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
|
10 |
+
def load_texture(filename, render_size):
|
11 |
+
filename = os.path.join(pathlib.Path(__file__).parent.parent.resolve(), "assets", filename)
|
12 |
+
img = iio.imread(filename)
|
13 |
+
jnp_img = jnp.array(img).astype(jnp.int32)
|
14 |
+
|
15 |
+
if jnp_img.shape[2] == 4:
|
16 |
+
jnp_img = jnp_img.at[:, :, 3].set(jnp_img[:, :, 3] // 255)
|
17 |
+
|
18 |
+
img = np.array(jnp_img, dtype=np.uint8)
|
19 |
+
image = Image.fromarray(img)
|
20 |
+
image = image.resize((render_size, render_size), resample=Image.NEAREST)
|
21 |
+
jnp_img = jnp.array(image, dtype=jnp.float32)
|
22 |
+
|
23 |
+
return jnp_img.transpose((1, 0, 2))
|
24 |
+
|
25 |
+
|
26 |
+
EDIT_TEXTURE_RGBA = load_texture("edit.png", 64)
|
27 |
+
PLAY_TEXTURE_RGBA = load_texture("play.png", 64)
|
28 |
+
|
29 |
+
CIRCLE_TEXTURE_RGBA = load_texture("circle.png", 32)
|
30 |
+
RECT_TEXTURE_RGBA = load_texture("square.png", 32)
|
31 |
+
TRIANGLE_TEXTURE_RGBA = load_texture("triangle.png", 32)
|
32 |
+
RJOINT_TEXTURE_6_RGBA = load_texture("rjoint.png", 6)
|
33 |
+
RJOINT_TEXTURE_RGBA = load_texture("rjoint2.png", 32)
|
34 |
+
|
35 |
+
FJOINT_TEXTURE_6_RGBA = load_texture("fjoint.png", 6)
|
36 |
+
FJOINT_TEXTURE_RGBA = load_texture("fjoint2.png", 32)
|
37 |
+
|
38 |
+
|
39 |
+
ROTATION_TEXTURE_RGBA = load_texture("rotate.png", 32)
|
40 |
+
SELECT_TEXTURE_RGBA = load_texture("hand.png", 32)
|
41 |
+
|
42 |
+
THRUSTER_TEXTURE_RGBA = jnp.rot90(load_texture("thruster6.png", 32), k=3)
|
43 |
+
THRUSTER_TEXTURE_16_RGBA = jnp.rot90(load_texture("thruster.png", 16), k=3)
|
kinetix/util/__init__.py
ADDED
File without changes
|
kinetix/util/config.py
ADDED
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import datetime
|
3 |
+
import gzip
|
4 |
+
import json
|
5 |
+
import os
|
6 |
+
from hashlib import md5
|
7 |
+
|
8 |
+
import jax
|
9 |
+
import jax.numpy as jnp
|
10 |
+
import numpy as np
|
11 |
+
from numpy import isin
|
12 |
+
from kinetix.environment.ued.ued_state import UEDParams
|
13 |
+
from omegaconf import OmegaConf
|
14 |
+
from pandas import isna
|
15 |
+
from typing import List, Tuple
|
16 |
+
import wandb
|
17 |
+
from kinetix.environment.env_state import EnvParams, StaticEnvParams
|
18 |
+
from collections import defaultdict
|
19 |
+
|
20 |
+
from kinetix.util.saving import load_from_json_file
|
21 |
+
|
22 |
+
|
23 |
+
def get_hash_without_seed(config):
|
24 |
+
old_seed = config["seed"]
|
25 |
+
config["seed"] = 0
|
26 |
+
ans = md5(OmegaConf.to_yaml(config, sort_keys=True).encode()).hexdigest()
|
27 |
+
config["seed"] = old_seed
|
28 |
+
return ans
|
29 |
+
|
30 |
+
|
31 |
+
def get_date() -> str:
|
32 |
+
return datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
|
33 |
+
|
34 |
+
|
35 |
+
def generate_params_from_config(config):
|
36 |
+
if config.get("env_size_type", "predefined") == "custom":
|
37 |
+
# must load env params from a file
|
38 |
+
_, static_env_params, env_params = load_from_json_file(os.path.join("worlds", config["custom_path"]))
|
39 |
+
return env_params, static_env_params.replace(
|
40 |
+
frame_skip=config["frame_skip"],
|
41 |
+
)
|
42 |
+
env_params = EnvParams()
|
43 |
+
|
44 |
+
static_env_params = StaticEnvParams().replace(
|
45 |
+
num_polygons=config["num_polygons"],
|
46 |
+
num_circles=config["num_circles"],
|
47 |
+
num_joints=config["num_joints"],
|
48 |
+
num_thrusters=config["num_thrusters"],
|
49 |
+
frame_skip=config["frame_skip"],
|
50 |
+
num_motor_bindings=config["num_motor_bindings"],
|
51 |
+
num_thruster_bindings=config["num_thruster_bindings"],
|
52 |
+
)
|
53 |
+
|
54 |
+
return env_params, static_env_params
|
55 |
+
|
56 |
+
|
57 |
+
def generate_ued_params_from_config(config) -> UEDParams:
|
58 |
+
ans = UEDParams()
|
59 |
+
|
60 |
+
if config["env_size_name"] == "s":
|
61 |
+
ans = ans.replace(add_shape_n_proposals=1) # otherwise we get a very weird XLA bug.
|
62 |
+
if "fixate_chance_max" in config:
|
63 |
+
print("Changing fixate chance max to", config["fixate_chance_max"])
|
64 |
+
ans = ans.replace(fixate_chance_max=config["fixate_chance_max"])
|
65 |
+
return ans
|
66 |
+
|
67 |
+
|
68 |
+
def get_eval_level_groups(eval_levels: List[str]) -> List[Tuple[str, str]]:
|
69 |
+
def get_groups(s):
|
70 |
+
# This is the size group
|
71 |
+
group_one = s.split("/")[0]
|
72 |
+
group_two = s.split("/")[1].split("_")[0]
|
73 |
+
group_two = "".join([i for i in group_two if not i.isdigit()])
|
74 |
+
if group_two == "h":
|
75 |
+
group_two = "handmade"
|
76 |
+
if group_two == "r":
|
77 |
+
group_two = "random"
|
78 |
+
return f"{group_one}_all", f"{group_one}_{group_two}"
|
79 |
+
|
80 |
+
indices = defaultdict(list)
|
81 |
+
|
82 |
+
for idx, s in enumerate(eval_levels):
|
83 |
+
groups = get_groups(s)
|
84 |
+
for group in groups:
|
85 |
+
indices[group].append(idx)
|
86 |
+
|
87 |
+
indices2 = {}
|
88 |
+
for g in indices:
|
89 |
+
indices2[g] = np.array(indices[g])
|
90 |
+
|
91 |
+
return indices2
|
92 |
+
|
93 |
+
|
94 |
+
def normalise_config(config, name, editor_config=False):
|
95 |
+
old_config = copy.deepcopy(config)
|
96 |
+
keys = ["env", "learning", "model", "misc", "eval", "ued", "env_size", "train_levels"]
|
97 |
+
for k in keys:
|
98 |
+
if k not in config:
|
99 |
+
config[k] = {}
|
100 |
+
small_d = config[k]
|
101 |
+
del config[k]
|
102 |
+
for kk, vv in small_d.items():
|
103 |
+
assert kk not in config, kk
|
104 |
+
config[kk] = vv
|
105 |
+
|
106 |
+
if not editor_config:
|
107 |
+
config["eval_env_size_true"] = config["eval_env_size"]
|
108 |
+
if config["num_train_envs"] == 2048 and "Pixels" in config["env_name"]:
|
109 |
+
config["num_train_envs"] = 512
|
110 |
+
if "SFL" in name and config["env_size_name"] in ["m", "l"]:
|
111 |
+
config["eval_num_attempts"] = 6 # to avoid a very weird XLA bug.
|
112 |
+
config["hash"] = get_hash_without_seed(config)
|
113 |
+
|
114 |
+
config["random_hash"] = np.random.randint(2**31)
|
115 |
+
|
116 |
+
config["log_save_path"] = f"logs/{config['hash']}/{config['seed']}-{get_date()}"
|
117 |
+
os.makedirs(config["log_save_path"], exist_ok=True)
|
118 |
+
with open(f"{config['log_save_path']}/config.yaml", "w") as f:
|
119 |
+
f.write(OmegaConf.to_yaml(old_config))
|
120 |
+
if config["group"] == "auto":
|
121 |
+
config["group"] = f"{name}-" + config["group_auto_prefix"] + config["env_name"].replace("Kinetix-", "")
|
122 |
+
config["group"] += "-" + str(config["env_size_name"])
|
123 |
+
|
124 |
+
if config["eval_levels"] == ["auto"] or config["eval_levels"] == "auto":
|
125 |
+
config["eval_levels"] = config["train_levels_list"]
|
126 |
+
print("Using Auto eval levels:", config["eval_levels"])
|
127 |
+
config["num_eval_levels"] = len(config["eval_levels"])
|
128 |
+
|
129 |
+
steps = (
|
130 |
+
config["num_steps"]
|
131 |
+
* config.get("outer_rollout_steps", 1)
|
132 |
+
* config["num_train_envs"]
|
133 |
+
* (2 if name == "PAIRED" else 1)
|
134 |
+
)
|
135 |
+
config["num_updates"] = int(config["total_timesteps"]) // steps
|
136 |
+
|
137 |
+
nsteps = int(config["total_timesteps"] // 1e6)
|
138 |
+
letter = "M"
|
139 |
+
if nsteps >= 1000:
|
140 |
+
nsteps = nsteps // 1000
|
141 |
+
letter = "B"
|
142 |
+
config["run_name"] = (
|
143 |
+
config["env_name"] + f"-{name}-" + str(nsteps) + letter + "-" + str(config["num_train_envs"])
|
144 |
+
)
|
145 |
+
|
146 |
+
if config["checkpoint_save_freq"] >= config["num_updates"]:
|
147 |
+
config["checkpoint_save_freq"] = config["num_updates"]
|
148 |
+
return config
|
149 |
+
|
150 |
+
|
151 |
+
def get_tags(config, name):
|
152 |
+
return [name]
|
153 |
+
tags = [name]
|
154 |
+
if name in ["PLR", "ACCEL", "DR"]:
|
155 |
+
if config["use_accel"]:
|
156 |
+
tags.append("ACCEL")
|
157 |
+
else:
|
158 |
+
tags.append("PLR")
|
159 |
+
return tags
|
160 |
+
|
161 |
+
|
162 |
+
def init_wandb(config, name) -> wandb.run:
|
163 |
+
run = wandb.init(
|
164 |
+
config=config,
|
165 |
+
project=config["wandb_project"],
|
166 |
+
group=config["group"],
|
167 |
+
name=config["run_name"],
|
168 |
+
entity=config["wandb_entity"],
|
169 |
+
mode=config["wandb_mode"],
|
170 |
+
tags=get_tags(config, name),
|
171 |
+
)
|
172 |
+
wandb.define_metric("timing/num_updates")
|
173 |
+
wandb.define_metric("timing/num_env_steps")
|
174 |
+
wandb.define_metric("*", step_metric="timing/num_env_steps")
|
175 |
+
wandb.define_metric("timing/sps", step_metric="timing/num_env_steps")
|
176 |
+
return run
|
177 |
+
|
178 |
+
|
179 |
+
def save_data_to_local_file(data_to_save, config):
|
180 |
+
if not config.get("save_local_data", False):
|
181 |
+
return
|
182 |
+
|
183 |
+
def reverse_in(li, value):
|
184 |
+
for i, v in enumerate(li):
|
185 |
+
if v in value:
|
186 |
+
return True
|
187 |
+
return False
|
188 |
+
|
189 |
+
clean_data = {k: v for k, v in data_to_save.items() if not reverse_in(["media/", "images/"], k)}
|
190 |
+
|
191 |
+
def _clean(x):
|
192 |
+
if isinstance(x, jnp.ndarray):
|
193 |
+
return x.tolist()
|
194 |
+
elif isinstance(x, jnp.float32):
|
195 |
+
if jnp.isnan(x):
|
196 |
+
return -float("inf")
|
197 |
+
return round(float(x) * 1000) / 1000
|
198 |
+
elif isinstance(x, jnp.int32):
|
199 |
+
return int(x)
|
200 |
+
return x
|
201 |
+
|
202 |
+
clean_data = jax.tree_map(lambda x: _clean(x), clean_data)
|
203 |
+
print("Saving this data:", clean_data)
|
204 |
+
with open(f"{config['log_save_path']}/data.jsonl", "a+") as f:
|
205 |
+
f.write(json.dumps(clean_data) + "\n")
|
206 |
+
|
207 |
+
|
208 |
+
def compress_log_files_after_run(config):
|
209 |
+
fpath = f"{config['log_save_path']}/data.jsonl"
|
210 |
+
with open(fpath, "rb") as f_in, gzip.open(fpath + ".gz", "wb") as f_out:
|
211 |
+
f_out.writelines(f_in)
|
212 |
+
|
213 |
+
|
214 |
+
def get_video_frequency(config, update_step):
|
215 |
+
frac_through_training = update_step / config["num_updates"]
|
216 |
+
vid_frequency = (
|
217 |
+
config["eval_freq"]
|
218 |
+
* config["video_frequency"]
|
219 |
+
* jax.lax.select(
|
220 |
+
(0.1 <= frac_through_training) & (frac_through_training < 0.3),
|
221 |
+
1,
|
222 |
+
jax.lax.select(
|
223 |
+
(0.3 <= frac_through_training) & (frac_through_training < 0.6),
|
224 |
+
2,
|
225 |
+
4,
|
226 |
+
),
|
227 |
+
)
|
228 |
+
)
|
229 |
+
return vid_frequency
|
kinetix/util/learning.py
ADDED
@@ -0,0 +1,565 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import re
|
5 |
+
import time
|
6 |
+
from enum import IntEnum
|
7 |
+
from typing import Tuple
|
8 |
+
|
9 |
+
import chex
|
10 |
+
import jax
|
11 |
+
import jax.numpy as jnp
|
12 |
+
import numpy as np
|
13 |
+
import optax
|
14 |
+
import orbax.checkpoint as ocp
|
15 |
+
from flax import core, struct
|
16 |
+
from flax.training.train_state import TrainState as BaseTrainState
|
17 |
+
|
18 |
+
import wandb
|
19 |
+
from jaxued.environments.underspecified_env import EnvParams, EnvState, Observation, UnderspecifiedEnv
|
20 |
+
from jaxued.level_sampler import LevelSampler
|
21 |
+
from jaxued.utils import compute_max_returns, max_mc, positive_value_loss
|
22 |
+
|
23 |
+
from kinetix.environment.env import PixelObservations, make_kinetix_env_from_name
|
24 |
+
from kinetix.environment.env_state import StaticEnvParams
|
25 |
+
from kinetix.environment.utils import permute_pcg_state
|
26 |
+
from kinetix.environment.wrappers import (
|
27 |
+
UnderspecifiedToGymnaxWrapper,
|
28 |
+
LogWrapper,
|
29 |
+
DenseRewardWrapper,
|
30 |
+
AutoReplayWrapper,
|
31 |
+
)
|
32 |
+
from kinetix.models import make_network_from_config
|
33 |
+
from kinetix.pcg.pcg import env_state_to_pcg_state
|
34 |
+
from kinetix.render.renderer_pixels import make_render_pixels
|
35 |
+
from kinetix.models.actor_critic import ScannedRNN
|
36 |
+
from kinetix.util.saving import (
|
37 |
+
expand_pcg_state,
|
38 |
+
get_pcg_state_from_json,
|
39 |
+
load_pcg_state_pickle,
|
40 |
+
load_world_state_pickle,
|
41 |
+
stack_list_of_pytrees,
|
42 |
+
import_env_state_from_json,
|
43 |
+
load_from_json_file,
|
44 |
+
)
|
45 |
+
from flax.training.train_state import TrainState
|
46 |
+
|
47 |
+
BASE_DIR = "worlds"
|
48 |
+
|
49 |
+
DEFAULT_EVAL_LEVELS = [
|
50 |
+
"easy.cartpole",
|
51 |
+
"easy.flappy_bird",
|
52 |
+
"easy.unicycle",
|
53 |
+
"easy.car_left",
|
54 |
+
"easy.car_right",
|
55 |
+
"easy.pinball",
|
56 |
+
"easy.swing_up",
|
57 |
+
"easy.thruster",
|
58 |
+
]
|
59 |
+
|
60 |
+
|
61 |
+
def get_eval_levels(eval_levels, static_env_params):
|
62 |
+
should_permute = [".permute" in l for l in eval_levels]
|
63 |
+
eval_levels = [re.sub(r"\.permute\d+", "", l) for l in eval_levels]
|
64 |
+
ls = [get_pcg_state_from_json(os.path.join(BASE_DIR, l + ("" if l.endswith(".json") else ".json"))) for l in eval_levels]
|
65 |
+
ls = [expand_pcg_state(l, static_env_params) for l in ls]
|
66 |
+
new_ls = []
|
67 |
+
rng = jax.random.PRNGKey(0)
|
68 |
+
for sp, l in zip(should_permute, ls):
|
69 |
+
rng, _rng = jax.random.split(rng)
|
70 |
+
if sp:
|
71 |
+
l = permute_pcg_state(_rng, l, static_env_params)
|
72 |
+
new_ls.append(l)
|
73 |
+
return stack_list_of_pytrees(new_ls)
|
74 |
+
|
75 |
+
|
76 |
+
def evaluate_rnn( # from jaxued
|
77 |
+
rng: chex.PRNGKey,
|
78 |
+
env: UnderspecifiedEnv,
|
79 |
+
env_params: EnvParams,
|
80 |
+
train_state: TrainState,
|
81 |
+
init_hstate: chex.ArrayTree,
|
82 |
+
init_obs: Observation,
|
83 |
+
init_env_state: EnvState,
|
84 |
+
max_episode_length: int,
|
85 |
+
keep_states=True,
|
86 |
+
return_trajectories=False,
|
87 |
+
) -> Tuple[chex.Array, chex.Array, chex.Array]:
|
88 |
+
"""This runs the RNN on the environment, given an initial state and observation, and returns (states, rewards, episode_lengths)
|
89 |
+
|
90 |
+
Args:
|
91 |
+
rng (chex.PRNGKey):
|
92 |
+
env (UnderspecifiedEnv):
|
93 |
+
env_params (EnvParams):
|
94 |
+
train_state (TrainState):
|
95 |
+
init_hstate (chex.ArrayTree): Shape (num_levels, )
|
96 |
+
init_obs (Observation): Shape (num_levels, )
|
97 |
+
init_env_state (EnvState): Shape (num_levels, )
|
98 |
+
max_episode_length (int):
|
99 |
+
|
100 |
+
Returns:
|
101 |
+
Tuple[chex.Array, chex.Array, chex.Array]: (States, rewards, episode lengths) ((NUM_STEPS, NUM_LEVELS), (NUM_STEPS, NUM_LEVELS), (NUM_LEVELS,)
|
102 |
+
"""
|
103 |
+
num_levels = jax.tree_util.tree_flatten(init_obs)[0][0].shape[0]
|
104 |
+
|
105 |
+
def step(carry, _):
|
106 |
+
rng, hstate, obs, state, done, mask, episode_length = carry
|
107 |
+
rng, rng_action, rng_step = jax.random.split(rng, 3)
|
108 |
+
|
109 |
+
x = jax.tree.map(lambda x: x[None, ...], (obs, done))
|
110 |
+
hstate, pi, _ = train_state.apply_fn(train_state.params, hstate, x)
|
111 |
+
action = pi.sample(seed=rng_action).squeeze(0)
|
112 |
+
|
113 |
+
obs, next_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
|
114 |
+
jax.random.split(rng_step, num_levels), state, action, env_params
|
115 |
+
)
|
116 |
+
|
117 |
+
next_mask = mask & ~done
|
118 |
+
episode_length += mask
|
119 |
+
|
120 |
+
if keep_states:
|
121 |
+
return (rng, hstate, obs, next_state, done, next_mask, episode_length), (state, reward, done, info)
|
122 |
+
else:
|
123 |
+
return (rng, hstate, obs, next_state, done, next_mask, episode_length), (None, reward, done, info)
|
124 |
+
|
125 |
+
(_, _, _, _, _, _, episode_lengths), (states, rewards, dones, infos) = jax.lax.scan(
|
126 |
+
step,
|
127 |
+
(
|
128 |
+
rng,
|
129 |
+
init_hstate,
|
130 |
+
init_obs,
|
131 |
+
init_env_state,
|
132 |
+
jnp.zeros(num_levels, dtype=bool),
|
133 |
+
jnp.ones(num_levels, dtype=bool),
|
134 |
+
jnp.zeros(num_levels, dtype=jnp.int32),
|
135 |
+
),
|
136 |
+
None,
|
137 |
+
length=max_episode_length,
|
138 |
+
)
|
139 |
+
done_idx = jnp.argmax(dones, axis=0)
|
140 |
+
|
141 |
+
to_return = (states, rewards, done_idx, episode_lengths, infos)
|
142 |
+
if return_trajectories:
|
143 |
+
return to_return, (dones, rewards)
|
144 |
+
return to_return
|
145 |
+
|
146 |
+
|
147 |
+
def general_eval(
|
148 |
+
rng: chex.PRNGKey,
|
149 |
+
eval_env: UnderspecifiedEnv,
|
150 |
+
env_params: EnvParams,
|
151 |
+
train_state: TrainState,
|
152 |
+
levels: EnvState,
|
153 |
+
num_eval_steps: int,
|
154 |
+
num_levels: int,
|
155 |
+
keep_states=True,
|
156 |
+
return_trajectories=False,
|
157 |
+
):
|
158 |
+
"""
|
159 |
+
This evaluates the current policy on the set of evaluation levels
|
160 |
+
It returns (states, cum_rewards, episode_lengths), with shapes (num_steps, num_eval_levels, ...), (num_eval_levels,), (num_eval_levels,)
|
161 |
+
"""
|
162 |
+
rng, rng_reset = jax.random.split(rng)
|
163 |
+
init_obs, init_env_state = jax.vmap(eval_env.reset_to_level, (0, 0, None))(
|
164 |
+
jax.random.split(rng_reset, num_levels), levels, env_params
|
165 |
+
)
|
166 |
+
init_hstate = ScannedRNN.initialize_carry(num_levels)
|
167 |
+
(states, rewards, done_idx, episode_lengths, infos), (dones, reward) = evaluate_rnn(
|
168 |
+
rng,
|
169 |
+
eval_env,
|
170 |
+
env_params,
|
171 |
+
train_state,
|
172 |
+
init_hstate,
|
173 |
+
init_obs,
|
174 |
+
init_env_state,
|
175 |
+
num_eval_steps,
|
176 |
+
keep_states=keep_states,
|
177 |
+
return_trajectories=True,
|
178 |
+
)
|
179 |
+
mask = jnp.arange(num_eval_steps)[..., None] < episode_lengths
|
180 |
+
cum_rewards = (rewards * mask).sum(axis=0)
|
181 |
+
to_return = (
|
182 |
+
states,
|
183 |
+
cum_rewards,
|
184 |
+
done_idx,
|
185 |
+
episode_lengths,
|
186 |
+
infos,
|
187 |
+
) # (num_steps, num_eval_levels, ...), (num_eval_levels,), (num_eval_levels,)
|
188 |
+
|
189 |
+
if return_trajectories:
|
190 |
+
return to_return, (dones, reward)
|
191 |
+
return to_return
|
192 |
+
|
193 |
+
|
194 |
+
def compute_gae(
|
195 |
+
gamma: float,
|
196 |
+
lambd: float,
|
197 |
+
last_value: chex.Array,
|
198 |
+
values: chex.Array,
|
199 |
+
rewards: chex.Array,
|
200 |
+
dones: chex.Array,
|
201 |
+
) -> Tuple[chex.Array, chex.Array]:
|
202 |
+
"""This takes in arrays of shape (NUM_STEPS, NUM_ENVS) and returns the advantages and targets.
|
203 |
+
|
204 |
+
Args:
|
205 |
+
gamma (float):
|
206 |
+
lambd (float):
|
207 |
+
last_value (chex.Array): Shape (NUM_ENVS)
|
208 |
+
values (chex.Array): Shape (NUM_STEPS, NUM_ENVS)
|
209 |
+
rewards (chex.Array): Shape (NUM_STEPS, NUM_ENVS)
|
210 |
+
dones (chex.Array): Shape (NUM_STEPS, NUM_ENVS)
|
211 |
+
|
212 |
+
Returns:
|
213 |
+
Tuple[chex.Array, chex.Array]: advantages, targets; each of shape (NUM_STEPS, NUM_ENVS)
|
214 |
+
"""
|
215 |
+
|
216 |
+
def compute_gae_at_timestep(carry, x):
|
217 |
+
gae, next_value = carry
|
218 |
+
value, reward, done = x
|
219 |
+
delta = reward + gamma * next_value * (1 - done) - value
|
220 |
+
gae = delta + gamma * lambd * (1 - done) * gae
|
221 |
+
return (gae, value), gae
|
222 |
+
|
223 |
+
_, advantages = jax.lax.scan(
|
224 |
+
compute_gae_at_timestep,
|
225 |
+
(jnp.zeros_like(last_value), last_value),
|
226 |
+
(values, rewards, dones),
|
227 |
+
reverse=True,
|
228 |
+
unroll=16,
|
229 |
+
)
|
230 |
+
return advantages, advantages + values
|
231 |
+
|
232 |
+
|
233 |
+
def sample_trajectories_rnn(
|
234 |
+
rng: chex.PRNGKey,
|
235 |
+
env: UnderspecifiedEnv,
|
236 |
+
env_params: EnvParams,
|
237 |
+
train_state: TrainState,
|
238 |
+
init_hstate: chex.ArrayTree,
|
239 |
+
init_obs: Observation,
|
240 |
+
init_env_state: EnvState,
|
241 |
+
num_envs: int,
|
242 |
+
max_episode_length: int,
|
243 |
+
return_states: bool = False,
|
244 |
+
) -> Tuple[
|
245 |
+
Tuple[chex.PRNGKey, TrainState, chex.ArrayTree, Observation, EnvState, chex.Array],
|
246 |
+
Tuple[Observation, chex.Array, chex.Array, chex.Array, chex.Array, chex.Array, dict],
|
247 |
+
]:
|
248 |
+
"""This samples trajectories from the environment using the agent specified by the `train_state`.
|
249 |
+
|
250 |
+
Args:
|
251 |
+
|
252 |
+
rng (chex.PRNGKey): Singleton
|
253 |
+
env (UnderspecifiedEnv):
|
254 |
+
env_params (EnvParams):
|
255 |
+
train_state (TrainState): Singleton
|
256 |
+
init_hstate (chex.ArrayTree): This is the init RNN hidden state, has to have shape (NUM_ENVS, ...)
|
257 |
+
init_obs (Observation): The initial observation, shape (NUM_ENVS, ...)
|
258 |
+
init_env_state (EnvState): The initial env state (NUM_ENVS, ...)
|
259 |
+
num_envs (int): The number of envs that are vmapped over.
|
260 |
+
max_episode_length (int): The maximum episode length, i.e., the number of steps to do the rollouts for.
|
261 |
+
|
262 |
+
Returns:
|
263 |
+
Tuple[Tuple[chex.PRNGKey, TrainState, chex.ArrayTree, Observation, EnvState, chex.Array], Tuple[Observation, chex.Array, chex.Array, chex.Array, chex.Array, chex.Array, dict]]: (rng, train_state, hstate, last_obs, last_env_state, last_value), traj, where traj is (obs, action, reward, done, log_prob, value, info). The first element in the tuple consists of arrays that have shapes (NUM_ENVS, ...) (except `rng` and and `train_state` which are singleton). The second element in the tuple is of shape (NUM_STEPS, NUM_ENVS, ...), and it contains the trajectory.
|
264 |
+
"""
|
265 |
+
|
266 |
+
def sample_step(carry, _):
|
267 |
+
rng, train_state, hstate, obs, env_state, last_done = carry
|
268 |
+
prev_state = env_state
|
269 |
+
rng, rng_action, rng_step = jax.random.split(rng, 3)
|
270 |
+
|
271 |
+
x = jax.tree.map(lambda x: x[None, ...], (obs, last_done))
|
272 |
+
hstate, pi, value = train_state.apply_fn(train_state.params, hstate, x)
|
273 |
+
action = pi.sample(seed=rng_action)
|
274 |
+
log_prob = pi.log_prob(action)
|
275 |
+
value, action, log_prob = jax.tree.map(lambda x: x.squeeze(0), (value, action, log_prob))
|
276 |
+
|
277 |
+
next_obs, env_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
|
278 |
+
jax.random.split(rng_step, num_envs), env_state, action, env_params
|
279 |
+
)
|
280 |
+
|
281 |
+
carry = (rng, train_state, hstate, next_obs, env_state, done)
|
282 |
+
step = (obs, action, reward, done, log_prob, value, info)
|
283 |
+
if return_states:
|
284 |
+
step += (prev_state,)
|
285 |
+
return carry, step
|
286 |
+
|
287 |
+
(rng, train_state, hstate, last_obs, last_env_state, last_done), traj = jax.lax.scan(
|
288 |
+
sample_step,
|
289 |
+
(
|
290 |
+
rng,
|
291 |
+
train_state,
|
292 |
+
init_hstate,
|
293 |
+
init_obs,
|
294 |
+
init_env_state,
|
295 |
+
jnp.zeros(num_envs, dtype=bool),
|
296 |
+
),
|
297 |
+
None,
|
298 |
+
length=max_episode_length,
|
299 |
+
)
|
300 |
+
|
301 |
+
x = jax.tree.map(lambda x: x[None, ...], (last_obs, last_done))
|
302 |
+
_, _, last_value = train_state.apply_fn(train_state.params, hstate, x)
|
303 |
+
|
304 |
+
my_obs = traj[0]
|
305 |
+
rew = traj[2]
|
306 |
+
|
307 |
+
return (rng, train_state, hstate, last_obs, last_env_state, last_value.squeeze(0)), traj
|
308 |
+
|
309 |
+
|
310 |
+
def update_actor_critic_rnn(
|
311 |
+
rng: chex.PRNGKey,
|
312 |
+
train_state: TrainState,
|
313 |
+
init_hstate: chex.ArrayTree,
|
314 |
+
batch: chex.ArrayTree,
|
315 |
+
num_envs: int,
|
316 |
+
n_steps: int,
|
317 |
+
n_minibatch: int,
|
318 |
+
n_epochs: int,
|
319 |
+
clip_eps: float,
|
320 |
+
entropy_coeff: float,
|
321 |
+
critic_coeff: float,
|
322 |
+
update_grad: bool = True,
|
323 |
+
) -> Tuple[Tuple[chex.PRNGKey, TrainState], chex.ArrayTree]:
|
324 |
+
"""This function takes in a rollout, and PPO hyperparameters, and updates the train state.
|
325 |
+
|
326 |
+
Args:
|
327 |
+
rng (chex.PRNGKey):
|
328 |
+
train_state (TrainState):
|
329 |
+
init_hstate (chex.ArrayTree):
|
330 |
+
batch (chex.ArrayTree): obs, actions, dones, log_probs, values, targets, advantages
|
331 |
+
num_envs (int):
|
332 |
+
n_steps (int):
|
333 |
+
n_minibatch (int):
|
334 |
+
n_epochs (int):
|
335 |
+
clip_eps (float):
|
336 |
+
entropy_coeff (float):
|
337 |
+
critic_coeff (float):
|
338 |
+
update_grad (bool, optional): If False, the train state does not actually get updated. Defaults to True.
|
339 |
+
|
340 |
+
Returns:
|
341 |
+
Tuple[Tuple[chex.PRNGKey, TrainState], chex.ArrayTree]: It returns a new rng, the updated train_state, and the losses. The losses have structure (loss, (l_vf, l_clip, entropy))
|
342 |
+
"""
|
343 |
+
obs, actions, dones, log_probs, values, targets, advantages = batch
|
344 |
+
last_dones = jnp.roll(dones, 1, axis=0).at[0].set(False)
|
345 |
+
batch = obs, actions, last_dones, log_probs, values, targets, advantages
|
346 |
+
|
347 |
+
def update_epoch(carry, _):
|
348 |
+
def update_minibatch(train_state, minibatch):
|
349 |
+
init_hstate, obs, actions, last_dones, log_probs, values, targets, advantages = minibatch
|
350 |
+
|
351 |
+
def loss_fn(params):
|
352 |
+
_, pi, values_pred = train_state.apply_fn(params, init_hstate, (obs, last_dones))
|
353 |
+
log_probs_pred = pi.log_prob(actions)
|
354 |
+
entropy = pi.entropy().mean()
|
355 |
+
|
356 |
+
ratio = jnp.exp(log_probs_pred - log_probs)
|
357 |
+
A = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
|
358 |
+
l_clip = (-jnp.minimum(ratio * A, jnp.clip(ratio, 1 - clip_eps, 1 + clip_eps) * A)).mean()
|
359 |
+
|
360 |
+
values_pred_clipped = values + (values_pred - values).clip(-clip_eps, clip_eps)
|
361 |
+
l_vf = 0.5 * jnp.maximum((values_pred - targets) ** 2, (values_pred_clipped - targets) ** 2).mean()
|
362 |
+
|
363 |
+
loss = l_clip + critic_coeff * l_vf - entropy_coeff * entropy
|
364 |
+
|
365 |
+
return loss, (l_vf, l_clip, entropy)
|
366 |
+
|
367 |
+
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
|
368 |
+
loss, grads = grad_fn(train_state.params)
|
369 |
+
if update_grad:
|
370 |
+
train_state = train_state.apply_gradients(grads=grads)
|
371 |
+
grad_norm = jnp.linalg.norm(
|
372 |
+
jnp.concatenate(jax.tree_util.tree_map(lambda x: x.flatten(), jax.tree_util.tree_flatten(grads)[0]))
|
373 |
+
)
|
374 |
+
return train_state, (loss, grad_norm)
|
375 |
+
|
376 |
+
rng, train_state = carry
|
377 |
+
rng, rng_perm = jax.random.split(rng)
|
378 |
+
permutation = jax.random.permutation(rng_perm, num_envs)
|
379 |
+
minibatches = (
|
380 |
+
jax.tree.map(
|
381 |
+
lambda x: jnp.take(x, permutation, axis=0).reshape(n_minibatch, -1, *x.shape[1:]),
|
382 |
+
init_hstate,
|
383 |
+
),
|
384 |
+
*jax.tree.map(
|
385 |
+
lambda x: jnp.take(x, permutation, axis=1)
|
386 |
+
.reshape(x.shape[0], n_minibatch, -1, *x.shape[2:])
|
387 |
+
.swapaxes(0, 1),
|
388 |
+
batch,
|
389 |
+
),
|
390 |
+
)
|
391 |
+
train_state, (losses, grads) = jax.lax.scan(update_minibatch, train_state, minibatches)
|
392 |
+
return (rng, train_state), (losses, grads)
|
393 |
+
|
394 |
+
return jax.lax.scan(update_epoch, (rng, train_state), None, n_epochs)
|
395 |
+
|
396 |
+
|
397 |
+
@partial(jax.jit, static_argnums=(0, 2, 8, 9))
|
398 |
+
def sample_trajectories_and_learn(
|
399 |
+
env: UnderspecifiedEnv,
|
400 |
+
env_params: EnvParams,
|
401 |
+
config: dict,
|
402 |
+
rng: chex.PRNGKey,
|
403 |
+
train_state: TrainState,
|
404 |
+
init_hstate: chex.Array,
|
405 |
+
init_obs: Observation,
|
406 |
+
init_env_state: EnvState,
|
407 |
+
update_grad: bool = True,
|
408 |
+
return_states: bool = False,
|
409 |
+
) -> Tuple[
|
410 |
+
Tuple[chex.PRNGKey, TrainState, Observation, EnvState],
|
411 |
+
Tuple[
|
412 |
+
Observation,
|
413 |
+
chex.Array,
|
414 |
+
chex.Array,
|
415 |
+
chex.Array,
|
416 |
+
chex.Array,
|
417 |
+
chex.Array,
|
418 |
+
dict,
|
419 |
+
chex.Array,
|
420 |
+
chex.Array,
|
421 |
+
chex.ArrayTree,
|
422 |
+
chex.Array,
|
423 |
+
],
|
424 |
+
]:
|
425 |
+
"""This function loops the following:
|
426 |
+
- rollout for config['num_steps']
|
427 |
+
- learn / update policy
|
428 |
+
|
429 |
+
And it loops it for config['outer_rollout_steps'].
|
430 |
+
What is returns is a new carry (rng, train_state, init_obs, init_env_state), and concatenated rollouts. The shape of the rollouts are config['num_steps'] * config['outer_rollout_steps']. In other words, the trajectories returned by this function are the same as if we ran rollouts for config['num_steps'] * config['outer_rollout_steps'] steps, but the agent does perform PPO updates in between.
|
431 |
+
|
432 |
+
Args:
|
433 |
+
env (UnderspecifiedEnv):
|
434 |
+
env_params (EnvParams):
|
435 |
+
config (dict):
|
436 |
+
rng (chex.PRNGKey):
|
437 |
+
train_state (TrainState):
|
438 |
+
init_obs (Observation):
|
439 |
+
init_env_state (EnvState):
|
440 |
+
update_grad (bool, optional): Defaults to True.
|
441 |
+
|
442 |
+
Returns:
|
443 |
+
Tuple[Tuple[chex.PRNGKey, TrainState, Observation, EnvState], Tuple[Observation, chex.Array, chex.Array, chex.Array, chex.Array, chex.Array, dict, chex.Array, chex.Array, chex.ArrayTree, chex.Array]]: This returns a tuple:
|
444 |
+
(
|
445 |
+
(rng, train_state, init_obs, init_env_state),
|
446 |
+
(obs, actions, rewards, dones, log_probs, values, info, advantages, targets, losses, grads)
|
447 |
+
)
|
448 |
+
"""
|
449 |
+
|
450 |
+
def single_step(carry, _):
|
451 |
+
rng, train_state, init_hstate, init_obs, init_env_state = carry
|
452 |
+
((rng, train_state, new_hstate, last_obs, last_env_state, last_value), traj,) = sample_trajectories_rnn(
|
453 |
+
rng,
|
454 |
+
env,
|
455 |
+
env_params,
|
456 |
+
train_state,
|
457 |
+
init_hstate,
|
458 |
+
init_obs,
|
459 |
+
init_env_state,
|
460 |
+
config["num_train_envs"],
|
461 |
+
config["num_steps"],
|
462 |
+
return_states=return_states,
|
463 |
+
)
|
464 |
+
if return_states:
|
465 |
+
states = traj[-1]
|
466 |
+
traj = traj[:-1]
|
467 |
+
|
468 |
+
(obs, actions, rewards, dones, log_probs, values, info) = traj
|
469 |
+
advantages, targets = compute_gae(config["gamma"], config["gae_lambda"], last_value, values, rewards, dones)
|
470 |
+
|
471 |
+
# Update the policy using trajectories collected from replay levels
|
472 |
+
(rng, train_state), (losses, grads) = update_actor_critic_rnn(
|
473 |
+
rng,
|
474 |
+
train_state,
|
475 |
+
init_hstate,
|
476 |
+
(obs, actions, dones, log_probs, values, targets, advantages),
|
477 |
+
config["num_train_envs"],
|
478 |
+
config["num_steps"],
|
479 |
+
config["num_minibatches"],
|
480 |
+
config["update_epochs"],
|
481 |
+
config["clip_eps"],
|
482 |
+
config["ent_coef"],
|
483 |
+
config["vf_coef"],
|
484 |
+
update_grad=update_grad,
|
485 |
+
)
|
486 |
+
new_carry = (rng, train_state, new_hstate, last_obs, last_env_state)
|
487 |
+
step = (obs, actions, rewards, dones, log_probs, values, info, advantages, targets, losses, grads)
|
488 |
+
if return_states:
|
489 |
+
step += (states,)
|
490 |
+
return new_carry, step
|
491 |
+
|
492 |
+
carry = (rng, train_state, init_hstate, init_obs, init_env_state)
|
493 |
+
new_carry, all_rollouts = jax.lax.scan(single_step, carry, None, length=config["outer_rollout_steps"])
|
494 |
+
|
495 |
+
all_rollouts = jax.tree_util.tree_map(lambda x: jnp.concatenate(x, axis=0), all_rollouts)
|
496 |
+
return new_carry, all_rollouts
|
497 |
+
|
498 |
+
|
499 |
+
def no_op_rollout(
|
500 |
+
env: UnderspecifiedEnv,
|
501 |
+
env_params: EnvParams,
|
502 |
+
rng: chex.PRNGKey,
|
503 |
+
init_obs: Observation,
|
504 |
+
init_env_state: EnvState,
|
505 |
+
num_envs: int,
|
506 |
+
max_episode_length: int,
|
507 |
+
do_random=False,
|
508 |
+
):
|
509 |
+
|
510 |
+
noop = jnp.array(env.action_type.noop_action())
|
511 |
+
zero_action = jnp.repeat(noop[None, ...], num_envs, axis=0)
|
512 |
+
SHAPE = zero_action.shape
|
513 |
+
|
514 |
+
def sample_step(carry, _):
|
515 |
+
rng, obs, env_state, last_done = carry
|
516 |
+
rng, rng_step, _rng = jax.random.split(rng, 3)
|
517 |
+
if do_random:
|
518 |
+
action = jax.vmap(env.action_space(env_params).sample)(jax.random.split(_rng, num_envs))
|
519 |
+
else:
|
520 |
+
action = zero_action
|
521 |
+
|
522 |
+
next_obs, env_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
|
523 |
+
jax.random.split(rng_step, num_envs), env_state, action, env_params
|
524 |
+
)
|
525 |
+
|
526 |
+
carry = (rng, next_obs, env_state, done)
|
527 |
+
return carry, (obs, action, reward, done, info)
|
528 |
+
|
529 |
+
(rng, last_obs, last_env_state, last_done), traj = jax.lax.scan(
|
530 |
+
sample_step,
|
531 |
+
(
|
532 |
+
rng,
|
533 |
+
init_obs,
|
534 |
+
init_env_state,
|
535 |
+
jnp.zeros(num_envs, dtype=bool),
|
536 |
+
),
|
537 |
+
None,
|
538 |
+
length=max_episode_length,
|
539 |
+
)
|
540 |
+
|
541 |
+
info = traj[-1]
|
542 |
+
dones = traj[-2]
|
543 |
+
|
544 |
+
returns_per_env = (info["returned_episode_returns"] * dones).sum(axis=0) / jnp.maximum(1, dones.sum(axis=0))
|
545 |
+
lens_per_env = (info["returned_episode_lengths"] * dones).sum(axis=0) / jnp.maximum(1, dones.sum(axis=0))
|
546 |
+
success_per_env = (info["returned_episode_solved"] * dones).sum(axis=0) / jnp.maximum(1, dones.sum(axis=0))
|
547 |
+
return returns_per_env, lens_per_env, success_per_env
|
548 |
+
|
549 |
+
|
550 |
+
def no_op_and_random_rollout(
|
551 |
+
env: UnderspecifiedEnv,
|
552 |
+
env_params: EnvParams,
|
553 |
+
rng: chex.PRNGKey,
|
554 |
+
init_obs: Observation,
|
555 |
+
init_env_state: EnvState,
|
556 |
+
num_envs: int,
|
557 |
+
max_episode_length: int,
|
558 |
+
):
|
559 |
+
returns_noop, lens_noop, success_noop = no_op_rollout(
|
560 |
+
env, env_params, rng, init_obs, init_env_state, num_envs, max_episode_length, do_random=False
|
561 |
+
)
|
562 |
+
returns_random, lens_random, success_random = no_op_rollout(
|
563 |
+
env, env_params, rng, init_obs, init_env_state, num_envs, max_episode_length, do_random=True
|
564 |
+
)
|
565 |
+
return returns_noop, lens_noop, success_noop, returns_random, lens_random, success_random
|
kinetix/util/saving.py
ADDED
@@ -0,0 +1,540 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import pickle
|
4 |
+
from typing import Any, Dict, Union
|
5 |
+
|
6 |
+
import flax.serialization
|
7 |
+
import flax.serialization
|
8 |
+
import flax.serialization
|
9 |
+
import flax.serialization
|
10 |
+
import flax.serialization
|
11 |
+
import flax.serialization
|
12 |
+
import flax.serialization
|
13 |
+
import jax
|
14 |
+
import jax.numpy as jnp
|
15 |
+
import flax
|
16 |
+
import wandb
|
17 |
+
from jax2d.engine import (
|
18 |
+
calculate_collision_matrix,
|
19 |
+
get_empty_collision_manifolds,
|
20 |
+
get_pairwise_interaction_indices,
|
21 |
+
recalculate_mass_and_inertia,
|
22 |
+
)
|
23 |
+
from jax2d.sim_state import RigidBody, SimState
|
24 |
+
from kinetix.environment.env_state import EnvState, StaticEnvParams, EnvParams
|
25 |
+
|
26 |
+
from flax.traverse_util import flatten_dict, unflatten_dict
|
27 |
+
|
28 |
+
from safetensors.flax import save_file, load_file
|
29 |
+
|
30 |
+
from kinetix.pcg.pcg import env_state_to_pcg_state
|
31 |
+
from kinetix.pcg.pcg_state import PCGState
|
32 |
+
import bz2
|
33 |
+
|
34 |
+
|
35 |
+
def check_if_mass_and_inertia_are_correct(state: SimState, env_params: EnvParams, static_params):
|
36 |
+
new = recalculate_mass_and_inertia(state, static_params, state.polygon_densities, state.circle_densities)
|
37 |
+
|
38 |
+
def _check(a, b, shape, name):
|
39 |
+
a = jnp.where(shape.active, a, jnp.zeros_like(a))
|
40 |
+
b = jnp.where(shape.active, b, jnp.zeros_like(b))
|
41 |
+
|
42 |
+
if not jnp.allclose(a, b):
|
43 |
+
idxs = jnp.arange(len(shape.active))[(a != b) & shape.active]
|
44 |
+
new_one = a[idxs]
|
45 |
+
old_one = b[idxs]
|
46 |
+
raise ValueError(
|
47 |
+
f"Error: {name} is not the same after loading. Indexes {idxs} are incorrect. New = {new_one} | Before = {old_one}"
|
48 |
+
)
|
49 |
+
|
50 |
+
_check(new.polygon.inverse_mass, state.polygon.inverse_mass, state.polygon, "Polygon inverse mass")
|
51 |
+
_check(new.circle.inverse_mass, state.circle.inverse_mass, state.circle, "Circle inverse mass")
|
52 |
+
_check(new.polygon.inverse_inertia, state.polygon.inverse_inertia, state.polygon, "Polygon inverse inertia")
|
53 |
+
_check(new.circle.inverse_inertia, state.circle.inverse_inertia, state.circle, "Circle inverse inertia")
|
54 |
+
return True
|
55 |
+
|
56 |
+
|
57 |
+
def save_pickle(filename, state):
|
58 |
+
with open(filename, "wb") as f:
|
59 |
+
pickle.dump(state, f)
|
60 |
+
|
61 |
+
|
62 |
+
def load_pcg_state_pickle(filename):
|
63 |
+
with open(filename, "rb") as f:
|
64 |
+
return pickle.load(f)
|
65 |
+
|
66 |
+
|
67 |
+
def expand_env_state(env_state: EnvState, static_env_params: StaticEnvParams, ignore_collision_matrix=False):
|
68 |
+
|
69 |
+
num_rects = len(env_state.polygon.position)
|
70 |
+
num_circles = len(env_state.circle.position)
|
71 |
+
num_joints = len(env_state.joint.a_index)
|
72 |
+
num_thrusters = len(env_state.thruster.object_index)
|
73 |
+
|
74 |
+
def _add_dummy(num_to_add, obj):
|
75 |
+
return jax.tree_map(
|
76 |
+
lambda current: jnp.concatenate(
|
77 |
+
[current, jnp.zeros((num_to_add, *current.shape[1:]), dtype=current.dtype)], axis=0
|
78 |
+
),
|
79 |
+
obj,
|
80 |
+
)
|
81 |
+
|
82 |
+
does_need_to_change = False
|
83 |
+
added_rects = 0
|
84 |
+
|
85 |
+
if (
|
86 |
+
num_rects > static_env_params.num_polygons
|
87 |
+
or num_circles > static_env_params.num_circles
|
88 |
+
or num_joints > static_env_params.num_joints
|
89 |
+
):
|
90 |
+
raise Exception(
|
91 |
+
f"The current static_env_params is too small to accommodate the loaded env_state (needs num_rects={num_rects}, num_circles={num_circles}, num_joints={num_joints} but current is {static_env_params.num_polygons}, {static_env_params.num_circles}, {static_env_params.num_joints})."
|
92 |
+
)
|
93 |
+
|
94 |
+
if num_rects < static_env_params.num_polygons:
|
95 |
+
added_rects = static_env_params.num_polygons - num_rects
|
96 |
+
does_need_to_change = True
|
97 |
+
env_state = env_state.replace(
|
98 |
+
polygon=_add_dummy(added_rects, env_state.polygon),
|
99 |
+
polygon_shape_roles=_add_dummy(added_rects, env_state.polygon_shape_roles),
|
100 |
+
polygon_highlighted=_add_dummy(added_rects, env_state.polygon_highlighted),
|
101 |
+
polygon_densities=_add_dummy(added_rects, env_state.polygon_densities),
|
102 |
+
)
|
103 |
+
|
104 |
+
if num_circles < static_env_params.num_circles:
|
105 |
+
does_need_to_change = True
|
106 |
+
n_to_add = static_env_params.num_circles - num_circles
|
107 |
+
env_state = env_state.replace(
|
108 |
+
circle=_add_dummy(n_to_add, env_state.circle),
|
109 |
+
circle_shape_roles=_add_dummy(n_to_add, env_state.circle_shape_roles),
|
110 |
+
circle_highlighted=_add_dummy(n_to_add, env_state.circle_highlighted),
|
111 |
+
circle_densities=_add_dummy(n_to_add, env_state.circle_densities),
|
112 |
+
)
|
113 |
+
|
114 |
+
if num_joints < static_env_params.num_joints:
|
115 |
+
does_need_to_change = True
|
116 |
+
n_to_add = static_env_params.num_joints - num_joints
|
117 |
+
env_state = env_state.replace(
|
118 |
+
joint=_add_dummy(n_to_add, env_state.joint),
|
119 |
+
motor_bindings=_add_dummy(n_to_add, env_state.motor_bindings),
|
120 |
+
motor_auto=_add_dummy(n_to_add, env_state.motor_auto),
|
121 |
+
)
|
122 |
+
|
123 |
+
if num_thrusters < static_env_params.num_thrusters:
|
124 |
+
does_need_to_change = True
|
125 |
+
n_to_add = static_env_params.num_thrusters - num_thrusters
|
126 |
+
env_state = env_state.replace(
|
127 |
+
thruster=_add_dummy(n_to_add, env_state.thruster),
|
128 |
+
thruster_bindings=_add_dummy(n_to_add, env_state.thruster_bindings),
|
129 |
+
)
|
130 |
+
|
131 |
+
# This fixes the indices
|
132 |
+
def _modify_index(old_indices):
|
133 |
+
return jnp.where(old_indices >= num_rects, old_indices + added_rects, old_indices)
|
134 |
+
|
135 |
+
if added_rects > 0:
|
136 |
+
env_state = env_state.replace(
|
137 |
+
joint=env_state.joint.replace(
|
138 |
+
a_index=_modify_index(env_state.joint.a_index),
|
139 |
+
b_index=_modify_index(env_state.joint.b_index),
|
140 |
+
),
|
141 |
+
thruster=env_state.thruster.replace(
|
142 |
+
object_index=_modify_index(env_state.thruster.object_index),
|
143 |
+
),
|
144 |
+
)
|
145 |
+
# Double check the collision manifolds are fine
|
146 |
+
if does_need_to_change or 1:
|
147 |
+
# print("Loading but changing the shapes to match the current static params.")
|
148 |
+
acc_rr_manifolds, acc_cr_manifolds, acc_cc_manifolds = get_empty_collision_manifolds(static_env_params)
|
149 |
+
env_state = env_state.replace(
|
150 |
+
collision_matrix=(
|
151 |
+
env_state.collision_matrix
|
152 |
+
if ignore_collision_matrix
|
153 |
+
else calculate_collision_matrix(static_env_params, env_state.joint)
|
154 |
+
),
|
155 |
+
acc_rr_manifolds=acc_rr_manifolds,
|
156 |
+
acc_cr_manifolds=acc_cr_manifolds,
|
157 |
+
acc_cc_manifolds=acc_cc_manifolds,
|
158 |
+
)
|
159 |
+
return env_state
|
160 |
+
|
161 |
+
|
162 |
+
def expand_pcg_state(pcg_state: PCGState, static_env_params):
|
163 |
+
new_pcg_state = pcg_state.replace(
|
164 |
+
env_state=expand_env_state(pcg_state.env_state, static_env_params),
|
165 |
+
env_state_max=expand_env_state(pcg_state.env_state_max, static_env_params),
|
166 |
+
env_state_pcg_mask=expand_env_state(
|
167 |
+
pcg_state.env_state_pcg_mask, static_env_params, ignore_collision_matrix=True
|
168 |
+
),
|
169 |
+
)
|
170 |
+
new_pcg_state = new_pcg_state.replace(
|
171 |
+
env_state_pcg_mask=new_pcg_state.env_state_pcg_mask.replace(
|
172 |
+
collision_matrix=jnp.zeros_like(new_pcg_state.env_state.collision_matrix, dtype=bool),
|
173 |
+
)
|
174 |
+
)
|
175 |
+
num_shapes = new_pcg_state.env_state.polygon.active.shape[0] + new_pcg_state.env_state.circle.active.shape[0]
|
176 |
+
|
177 |
+
return new_pcg_state.replace(
|
178 |
+
tied_together=jnp.zeros((num_shapes, num_shapes), dtype=bool)
|
179 |
+
.at[
|
180 |
+
: pcg_state.tied_together.shape[0],
|
181 |
+
: pcg_state.tied_together.shape[1],
|
182 |
+
]
|
183 |
+
.set(pcg_state.tied_together)
|
184 |
+
)
|
185 |
+
|
186 |
+
|
187 |
+
def load_world_state_pickle(filename, params=None, static_env_params=None):
|
188 |
+
static_params = static_env_params or StaticEnvParams()
|
189 |
+
with open(filename, "rb") as f:
|
190 |
+
state: SimState = pickle.load(f)
|
191 |
+
state = jax.tree.map(lambda x: jnp.nan_to_num(x), state)
|
192 |
+
# Check if the mass and inertia are reasonable.
|
193 |
+
check_if_mass_and_inertia_are_correct(state, params or EnvParams(), static_params)
|
194 |
+
|
195 |
+
# Now check if the shapes are correct
|
196 |
+
return expand_env_state(state, static_params)
|
197 |
+
|
198 |
+
|
199 |
+
def stack_list_of_pytrees(list_of_pytrees):
|
200 |
+
v = jax.tree_map(lambda x: jnp.expand_dims(x, 0), list_of_pytrees[0])
|
201 |
+
for l in list_of_pytrees[1:]:
|
202 |
+
v = jax.tree_map(lambda x, y: jnp.concatenate([x, jnp.expand_dims(y, 0)], axis=0), v, l)
|
203 |
+
return v
|
204 |
+
|
205 |
+
def get_pcg_state_from_json(json_filename) -> PCGState:
|
206 |
+
env_state, _, _ = load_from_json_file(json_filename)
|
207 |
+
return env_state_to_pcg_state(env_state)
|
208 |
+
|
209 |
+
def my_load_file(filename):
|
210 |
+
data = bz2.BZ2File(filename, "rb")
|
211 |
+
data = pickle.load(data)
|
212 |
+
return data
|
213 |
+
|
214 |
+
|
215 |
+
def my_save_file(obj, filename):
|
216 |
+
with bz2.BZ2File(filename, "w") as f:
|
217 |
+
pickle.dump(obj, f)
|
218 |
+
|
219 |
+
|
220 |
+
def save_params(params: Dict, filename: Union[str, os.PathLike]) -> None:
|
221 |
+
my_save_file(params, filename)
|
222 |
+
|
223 |
+
|
224 |
+
def load_params(filename: Union[str, os.PathLike], legacy=False) -> Dict:
|
225 |
+
if legacy:
|
226 |
+
filename = filename.replace("full_model.pbz2", "model.safetensors")
|
227 |
+
filename = filename.replace(".pbz2", ".safetensors")
|
228 |
+
return unflatten_dict(load_file(filename), sep=",")
|
229 |
+
return my_load_file(filename)
|
230 |
+
|
231 |
+
|
232 |
+
def load_params_from_wandb_artifact_path(checkpoint_name, legacy=False):
|
233 |
+
api = wandb.Api()
|
234 |
+
name = api.artifact(checkpoint_name).download()
|
235 |
+
network_params = load_params(name + "/model.pbz2", legacy=legacy)
|
236 |
+
return network_params
|
237 |
+
|
238 |
+
|
239 |
+
def save_params_to_wandb(params, timesteps, config):
|
240 |
+
if config["checkpoint_human_numbers"]:
|
241 |
+
timesteps = str(round(timesteps / 1e9)) + "B"
|
242 |
+
|
243 |
+
run_name = config["run_name"] + "-" + str(config["random_hash"]) + "-" + str(timesteps)
|
244 |
+
save_dir = os.path.join(config["save_path"], run_name)
|
245 |
+
os.makedirs(save_dir, exist_ok=True)
|
246 |
+
save_params(params, f"{save_dir}/model.pbz2")
|
247 |
+
|
248 |
+
# upload this to wandb as an artifact
|
249 |
+
artifact = wandb.Artifact(f"{run_name}-checkpoint", type="checkpoint")
|
250 |
+
artifact.add_file(f"{save_dir}/model.pbz2")
|
251 |
+
artifact.save()
|
252 |
+
print(f"Parameters of model saved in {save_dir}/model.pbz2")
|
253 |
+
|
254 |
+
|
255 |
+
def load_params_wandb_artifact_path_full_model(checkpoint_name):
|
256 |
+
api = wandb.Api()
|
257 |
+
name = api.artifact(checkpoint_name).download()
|
258 |
+
all_dict = load_params(name + "/full_model.pbz2")
|
259 |
+
return all_dict["params"]
|
260 |
+
|
261 |
+
|
262 |
+
def load_train_state_from_wandb_artifact_path(train_state, checkpoint_name, load_only_params=False, legacy=False):
|
263 |
+
api = wandb.Api()
|
264 |
+
name = api.artifact(checkpoint_name).download()
|
265 |
+
all_dict = load_params(name + "/full_model.pbz2", legacy=legacy)
|
266 |
+
if legacy:
|
267 |
+
return train_state.replace(params=all_dict)
|
268 |
+
train_state = train_state.replace(params=all_dict["params"])
|
269 |
+
if not load_only_params:
|
270 |
+
train_state = train_state.replace(
|
271 |
+
# step=all_dict["step"],
|
272 |
+
opt_state=all_dict["opt_state"]
|
273 |
+
)
|
274 |
+
return train_state
|
275 |
+
|
276 |
+
|
277 |
+
def save_params_to_wandb(params, timesteps, config):
|
278 |
+
return save_dict_to_wandb(params, timesteps, config, "params")
|
279 |
+
|
280 |
+
|
281 |
+
def save_dict_to_wandb(dict, timesteps, config, name):
|
282 |
+
timesteps = str(round(timesteps / 1e9)) + "B"
|
283 |
+
run_name = config["run_name"] + "-" + str(config["random_hash"]) + "-" + str(timesteps)
|
284 |
+
save_dir = os.path.join(config["save_path"], run_name)
|
285 |
+
os.makedirs(save_dir, exist_ok=True)
|
286 |
+
save_params(dict, f"{save_dir}/{name}.pbz2")
|
287 |
+
|
288 |
+
# upload this to wandb as an artifact
|
289 |
+
artifact = wandb.Artifact(f"{run_name}-checkpoint", type="checkpoint")
|
290 |
+
artifact.add_file(f"{save_dir}/{name}.pbz2")
|
291 |
+
artifact.save()
|
292 |
+
print(f"Parameters of model saved in {save_dir}/{name}.pbz2")
|
293 |
+
|
294 |
+
|
295 |
+
def save_model_to_wandb(train_state, timesteps, config, is_final=False):
|
296 |
+
dict_to_use = {"step": train_state.step, "params": train_state.params, "opt_state": train_state.opt_state}
|
297 |
+
step = int(train_state.step)
|
298 |
+
if config["economical_saving"]:
|
299 |
+
if step in [2048, 10240, 40960, 81920] or is_final:
|
300 |
+
save_dict_to_wandb(dict_to_use, timesteps, config, "full_model")
|
301 |
+
else:
|
302 |
+
print("Not saving model because step is", step)
|
303 |
+
else:
|
304 |
+
save_dict_to_wandb(dict_to_use, timesteps, config, "full_model")
|
305 |
+
|
306 |
+
|
307 |
+
def import_env_state_from_json(json_file: dict[str, Any]) -> tuple[EnvState, StaticEnvParams, EnvParams]:
|
308 |
+
from kinetix.environment.env import create_empty_env
|
309 |
+
|
310 |
+
def normalise(k, v):
|
311 |
+
if k == "screen_dim":
|
312 |
+
return v
|
313 |
+
if type(v) == dict and "0" in v:
|
314 |
+
return jnp.array([normalise(k, v[str(i)]) for i in range(len(v))])
|
315 |
+
return v
|
316 |
+
|
317 |
+
env_state = json_file["env_state"]
|
318 |
+
env_params = json_file["env_params"]
|
319 |
+
static_env_params = json_file["static_env_params"]
|
320 |
+
env_params_target = EnvParams()
|
321 |
+
static_env_params_target = StaticEnvParams()
|
322 |
+
new_env_params = flax.serialization.from_state_dict(
|
323 |
+
env_params_target, {k: normalise(k, v) for k, v in env_params.items()}
|
324 |
+
)
|
325 |
+
norm_static = {k: normalise(k, v) for k, v in static_env_params.items()}
|
326 |
+
# norm_static["screen_dim"] = tuple(static_env_params_target.screen_dim)
|
327 |
+
norm_static["downscale"] = static_env_params_target.downscale
|
328 |
+
# print(
|
329 |
+
# static_env_params_target,
|
330 |
+
# )
|
331 |
+
new_static_env_params = flax.serialization.from_state_dict(static_env_params_target, norm_static)
|
332 |
+
new_static_env_params = new_static_env_params.replace(screen_dim=static_env_params_target.screen_dim)
|
333 |
+
|
334 |
+
env_state_target = create_empty_env(new_static_env_params)
|
335 |
+
|
336 |
+
def astype(x, all):
|
337 |
+
return jnp.astype(x, all.dtype)
|
338 |
+
|
339 |
+
def _load_rigidbody(env_state_target, i, is_poly):
|
340 |
+
|
341 |
+
to_load_from: dict[str, Any] = env_state["circle" if not is_poly else "polygon"][i]
|
342 |
+
role = to_load_from.pop("role")
|
343 |
+
density = to_load_from.pop("density")
|
344 |
+
if "highlighted" in to_load_from:
|
345 |
+
_ = to_load_from.pop("highlighted")
|
346 |
+
new_obj = flax.serialization.from_state_dict(
|
347 |
+
jax.tree.map(lambda x: x[i], env_state_target.circle if not is_poly else env_state_target.polygon),
|
348 |
+
{k: normalise(k, v) for k, v in to_load_from.items()},
|
349 |
+
)
|
350 |
+
|
351 |
+
if is_poly:
|
352 |
+
env_state_target = env_state_target.replace(
|
353 |
+
polygon_shape_roles=env_state_target.polygon_shape_roles.at[i].set(role),
|
354 |
+
polygon_densities=env_state_target.polygon_densities.at[i].set(density),
|
355 |
+
polygon=jax.tree.map(
|
356 |
+
lambda all, new: all.at[i].set(astype(new, all)), env_state_target.polygon, new_obj
|
357 |
+
),
|
358 |
+
)
|
359 |
+
else:
|
360 |
+
env_state_target = env_state_target.replace(
|
361 |
+
circle_shape_roles=env_state_target.circle_shape_roles.at[i].set(role),
|
362 |
+
circle_densities=env_state_target.circle_densities.at[i].set(density),
|
363 |
+
circle=jax.tree.map(lambda all, new: all.at[i].set(astype(new, all)), env_state_target.circle, new_obj),
|
364 |
+
)
|
365 |
+
return env_state_target
|
366 |
+
|
367 |
+
# Now load the env state:
|
368 |
+
for i in range(new_static_env_params.num_circles):
|
369 |
+
env_state_target = _load_rigidbody(env_state_target, i, False)
|
370 |
+
for i in range(new_static_env_params.num_polygons):
|
371 |
+
env_state_target = _load_rigidbody(env_state_target, i, True)
|
372 |
+
|
373 |
+
for i in range(new_static_env_params.num_joints):
|
374 |
+
to_load_from = env_state["joint"][i]
|
375 |
+
motor_binding = to_load_from.pop("motor_binding")
|
376 |
+
new_obj = flax.serialization.from_state_dict(
|
377 |
+
jax.tree.map(lambda x: x[i], env_state_target.joint), {k: normalise(k, v) for k, v in to_load_from.items()}
|
378 |
+
)
|
379 |
+
env_state_target = env_state_target.replace(
|
380 |
+
joint=jax.tree.map(lambda all, new: all.at[i].set(astype(new, all)), env_state_target.joint, new_obj),
|
381 |
+
motor_bindings=env_state_target.motor_bindings.at[i].set(motor_binding),
|
382 |
+
)
|
383 |
+
|
384 |
+
for i in range(new_static_env_params.num_thrusters):
|
385 |
+
to_load_from = env_state["thruster"][i]
|
386 |
+
thruster_binding = to_load_from.pop("thruster_binding")
|
387 |
+
new_obj = flax.serialization.from_state_dict(
|
388 |
+
jax.tree.map(lambda x: x[i], env_state_target.thruster),
|
389 |
+
{k: normalise(k, v) for k, v in to_load_from.items()},
|
390 |
+
)
|
391 |
+
|
392 |
+
env_state_target = env_state_target.replace(
|
393 |
+
thruster=jax.tree.map(lambda all, new: all.at[i].set(astype(new, all)), env_state_target.thruster, new_obj),
|
394 |
+
thruster_bindings=env_state_target.thruster_bindings.at[i].set(thruster_binding),
|
395 |
+
)
|
396 |
+
|
397 |
+
env_state_target = env_state_target.replace(
|
398 |
+
collision_matrix=flax.serialization.from_state_dict(
|
399 |
+
env_state_target.collision_matrix, normalise("collision_matrix", env_state["collision_matrix"])
|
400 |
+
)
|
401 |
+
)
|
402 |
+
|
403 |
+
for i in range(env_state_target.acc_rr_manifolds.active.shape[0]):
|
404 |
+
a = flax.serialization.from_state_dict(
|
405 |
+
jax.tree.map(lambda x: x[i], env_state_target.acc_rr_manifolds),
|
406 |
+
{k: normalise(k, v) for k, v in env_state["acc_rr_manifolds"][i].items()},
|
407 |
+
)
|
408 |
+
b = flax.serialization.from_state_dict(
|
409 |
+
jax.tree.map(lambda x: x[i], env_state_target.acc_rr_manifolds),
|
410 |
+
{k: normalise(k, v) for k, v in env_state["acc_rr_manifolds"][i + 1].items()},
|
411 |
+
)
|
412 |
+
env_state_target = env_state_target.replace(
|
413 |
+
acc_rr_manifolds=jax.tree.map(
|
414 |
+
lambda all, new: all.at[i].set(astype(new, all)), env_state_target.acc_rr_manifolds, a
|
415 |
+
),
|
416 |
+
)
|
417 |
+
env_state_target.replace(
|
418 |
+
acc_rr_manifolds=jax.tree.map(
|
419 |
+
lambda all, new: all.at[i + 1].set(astype(new, all)), env_state_target.acc_rr_manifolds, b
|
420 |
+
)
|
421 |
+
)
|
422 |
+
for i in range(env_state_target.acc_cr_manifolds.active.shape[0]):
|
423 |
+
a = flax.serialization.from_state_dict(
|
424 |
+
jax.tree.map(lambda x: x[i], env_state_target.acc_cr_manifolds),
|
425 |
+
{k: normalise(k, v) for k, v in env_state["acc_cr_manifolds"][i].items()},
|
426 |
+
)
|
427 |
+
env_state_target = env_state_target.replace(
|
428 |
+
acc_cr_manifolds=jax.tree.map(
|
429 |
+
lambda all, new: all.at[i].set(astype(new, all)), env_state_target.acc_cr_manifolds, a
|
430 |
+
),
|
431 |
+
)
|
432 |
+
for i in range(env_state_target.acc_cc_manifolds.active.shape[0]):
|
433 |
+
a = flax.serialization.from_state_dict(
|
434 |
+
jax.tree.map(lambda x: x[i], env_state_target.acc_cc_manifolds),
|
435 |
+
{k: normalise(k, v) for k, v in env_state["acc_cc_manifolds"][i].items()},
|
436 |
+
)
|
437 |
+
env_state_target = env_state_target.replace(
|
438 |
+
acc_cc_manifolds=jax.tree.map(
|
439 |
+
lambda all, new: all.at[i].set(astype(new, all)), env_state_target.acc_cc_manifolds, a
|
440 |
+
),
|
441 |
+
)
|
442 |
+
|
443 |
+
env_state_target = env_state_target.replace(
|
444 |
+
collision_matrix=calculate_collision_matrix(new_static_env_params, env_state_target.joint)
|
445 |
+
)
|
446 |
+
|
447 |
+
return (
|
448 |
+
env_state_target,
|
449 |
+
new_static_env_params,
|
450 |
+
new_env_params.replace(max_timesteps=env_params_target.max_timesteps),
|
451 |
+
)
|
452 |
+
|
453 |
+
|
454 |
+
def export_env_state_to_json(
|
455 |
+
filename: str, env_state: EnvState, static_env_params: StaticEnvParams, env_params: EnvParams
|
456 |
+
):
|
457 |
+
json_to_save = {
|
458 |
+
"polygon": [],
|
459 |
+
"circle": [],
|
460 |
+
"joint": [],
|
461 |
+
"thruster": [],
|
462 |
+
"collision_matrix": flax.serialization.to_state_dict(env_state.collision_matrix.tolist()),
|
463 |
+
"acc_rr_manifolds": [],
|
464 |
+
"acc_cr_manifolds": [],
|
465 |
+
"acc_cc_manifolds": [],
|
466 |
+
"gravity": flax.serialization.to_state_dict(env_state.gravity.tolist()),
|
467 |
+
}
|
468 |
+
|
469 |
+
def _rigidbody_to_json(index: int, is_poly):
|
470 |
+
main_arr = env_state.polygon if is_poly else env_state.circle
|
471 |
+
c = jax.tree.map(lambda x: x[index].tolist(), main_arr)
|
472 |
+
roles = env_state.polygon_shape_roles if is_poly else env_state.circle_shape_roles
|
473 |
+
densities = env_state.polygon_densities if is_poly else env_state.circle_densities
|
474 |
+
highlighted = env_state.polygon_highlighted if is_poly else env_state.circle_highlighted
|
475 |
+
|
476 |
+
d = flax.serialization.to_state_dict(c)
|
477 |
+
d["role"] = roles[index].tolist()
|
478 |
+
d["density"] = densities[index].tolist()
|
479 |
+
d["highlighted"] = highlighted[index].tolist()
|
480 |
+
return d
|
481 |
+
|
482 |
+
def _joint_to_json(i):
|
483 |
+
joint = jax.tree.map(lambda x: x[i].tolist(), env_state.joint)
|
484 |
+
d = flax.serialization.to_state_dict(joint)
|
485 |
+
d["motor_binding"] = env_state.motor_bindings[i].tolist()
|
486 |
+
return d
|
487 |
+
|
488 |
+
def _thruster_to_json(i):
|
489 |
+
thruster = jax.tree.map(lambda x: x[i].tolist(), env_state.thruster)
|
490 |
+
d = flax.serialization.to_state_dict(thruster)
|
491 |
+
d["thruster_binding"] = env_state.thruster_bindings[i].tolist()
|
492 |
+
return d
|
493 |
+
|
494 |
+
for i in range(static_env_params.num_circles):
|
495 |
+
json_to_save["circle"].append(_rigidbody_to_json(i, False))
|
496 |
+
for i in range(static_env_params.num_polygons):
|
497 |
+
json_to_save["polygon"].append(_rigidbody_to_json(i, True))
|
498 |
+
for i in range(static_env_params.num_joints):
|
499 |
+
json_to_save["joint"].append(_joint_to_json(i))
|
500 |
+
for i in range(static_env_params.num_thrusters):
|
501 |
+
json_to_save["thruster"].append(_thruster_to_json(i))
|
502 |
+
|
503 |
+
ncc, ncr, nrr, circle_circle_pairs, circle_rect_pairs, rect_rect_pairs = get_pairwise_interaction_indices(
|
504 |
+
static_env_params
|
505 |
+
)
|
506 |
+
for i in range(nrr):
|
507 |
+
a = jax.tree.map(lambda x: x[i, 0].tolist(), env_state.acc_rr_manifolds)
|
508 |
+
b = jax.tree.map(lambda x: x[i, 1].tolist(), env_state.acc_rr_manifolds)
|
509 |
+
json_to_save["acc_rr_manifolds"].append(flax.serialization.to_state_dict(a))
|
510 |
+
json_to_save["acc_rr_manifolds"].append(flax.serialization.to_state_dict(b))
|
511 |
+
for i in range(ncr):
|
512 |
+
a = jax.tree.map(lambda x: x[i].tolist(), env_state.acc_cr_manifolds)
|
513 |
+
json_to_save["acc_cr_manifolds"].append(flax.serialization.to_state_dict(a))
|
514 |
+
|
515 |
+
for i in range(ncc):
|
516 |
+
a = jax.tree.map(lambda x: x[i].tolist(), env_state.acc_cc_manifolds)
|
517 |
+
json_to_save["acc_cc_manifolds"].append(flax.serialization.to_state_dict(a))
|
518 |
+
|
519 |
+
to_save = {
|
520 |
+
"env_state": json_to_save,
|
521 |
+
"env_params": flax.serialization.to_state_dict(
|
522 |
+
jax.tree.map(lambda x: x.tolist() if type(x) == jnp.ndarray else x, env_params)
|
523 |
+
),
|
524 |
+
"static_env_params": flax.serialization.to_state_dict(
|
525 |
+
jax.tree.map(lambda x: x.tolist() if type(x) == jnp.ndarray else x, static_env_params)
|
526 |
+
),
|
527 |
+
}
|
528 |
+
with open(filename, "w+") as f:
|
529 |
+
json.dump(to_save, f)
|
530 |
+
|
531 |
+
return to_save
|
532 |
+
|
533 |
+
|
534 |
+
def load_from_json_file(filename):
|
535 |
+
with open(filename, "r") as f:
|
536 |
+
return import_env_state_from_json(json.load(f))
|
537 |
+
|
538 |
+
|
539 |
+
if __name__ == "__main__":
|
540 |
+
pass
|
kinetix/util/timing.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from timeit import default_timer as tmr
|
2 |
+
|
3 |
+
counter = 0
|
4 |
+
|
5 |
+
|
6 |
+
def time_function(f, name):
|
7 |
+
global counter
|
8 |
+
t = "\t" * counter
|
9 |
+
# print(f"{t}Starting... {name}")
|
10 |
+
ss = tmr()
|
11 |
+
counter += 1
|
12 |
+
a = f()
|
13 |
+
counter -= 1
|
14 |
+
print(f"{t}{name} took {tmr() - ss} seconds")
|
15 |
+
return a
|