tree3po commited on
Commit
581eeac
·
verified ·
1 Parent(s): ec333dc

Upload 46 files

Browse files
Files changed (46) hide show
  1. kinetix/__init__.py +0 -0
  2. kinetix/assets/circle.png +0 -0
  3. kinetix/assets/edit.png +0 -0
  4. kinetix/assets/fjoint.png +0 -0
  5. kinetix/assets/fjoint2.png +0 -0
  6. kinetix/assets/hand.png +0 -0
  7. kinetix/assets/joint.png +0 -0
  8. kinetix/assets/play.png +0 -0
  9. kinetix/assets/rjoint.png +0 -0
  10. kinetix/assets/rjoint2.png +0 -0
  11. kinetix/assets/rotate.png +0 -0
  12. kinetix/assets/square.png +0 -0
  13. kinetix/assets/thruster.png +0 -0
  14. kinetix/assets/thruster6.png +0 -0
  15. kinetix/assets/triangle.png +0 -0
  16. kinetix/editor.py +0 -0
  17. kinetix/environment/__init__.py +0 -0
  18. kinetix/environment/env.py +829 -0
  19. kinetix/environment/env_state.py +43 -0
  20. kinetix/environment/ued/distributions.py +349 -0
  21. kinetix/environment/ued/mutators.py +1157 -0
  22. kinetix/environment/ued/ued.py +249 -0
  23. kinetix/environment/ued/ued_state.py +53 -0
  24. kinetix/environment/ued/util.py +358 -0
  25. kinetix/environment/utils.py +66 -0
  26. kinetix/environment/wrappers.py +309 -0
  27. kinetix/models/.gitignore +2 -0
  28. kinetix/models/__init__.py +65 -0
  29. kinetix/models/action_spaces.py +58 -0
  30. kinetix/models/actor_critic.py +206 -0
  31. kinetix/models/rel_multi_head.py +546 -0
  32. kinetix/models/transformer_model.py +302 -0
  33. kinetix/pcg/__init__.py +0 -0
  34. kinetix/pcg/pcg.py +97 -0
  35. kinetix/pcg/pcg_state.py +24 -0
  36. kinetix/render/__init__.py +0 -0
  37. kinetix/render/renderer_pixels.py +290 -0
  38. kinetix/render/renderer_symbolic_common.py +190 -0
  39. kinetix/render/renderer_symbolic_entity.py +121 -0
  40. kinetix/render/renderer_symbolic_flat.py +102 -0
  41. kinetix/render/textures.py +43 -0
  42. kinetix/util/__init__.py +0 -0
  43. kinetix/util/config.py +229 -0
  44. kinetix/util/learning.py +565 -0
  45. kinetix/util/saving.py +540 -0
  46. kinetix/util/timing.py +15 -0
kinetix/__init__.py ADDED
File without changes
kinetix/assets/circle.png ADDED
kinetix/assets/edit.png ADDED
kinetix/assets/fjoint.png ADDED
kinetix/assets/fjoint2.png ADDED
kinetix/assets/hand.png ADDED
kinetix/assets/joint.png ADDED
kinetix/assets/play.png ADDED
kinetix/assets/rjoint.png ADDED
kinetix/assets/rjoint2.png ADDED
kinetix/assets/rotate.png ADDED
kinetix/assets/square.png ADDED
kinetix/assets/thruster.png ADDED
kinetix/assets/thruster6.png ADDED
kinetix/assets/triangle.png ADDED
kinetix/editor.py ADDED
The diff for this file is too large to render. See raw diff
 
kinetix/environment/__init__.py ADDED
File without changes
kinetix/environment/env.py ADDED
@@ -0,0 +1,829 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ from functools import partial
3
+ from typing import Any, Dict, Optional, Tuple, Union
4
+
5
+ import chex
6
+ import jax
7
+ import jax.numpy as jnp
8
+ import numpy as np
9
+ from chex._src.pytypes import PRNGKey
10
+ from gymnax.environments import environment, spaces
11
+ from gymnax.environments.environment import TEnvParams, TEnvState
12
+ from gymnax.environments.spaces import Space
13
+ from jax import lax
14
+
15
+ from jax2d.engine import PhysicsEngine, create_empty_sim, recalculate_mass_and_inertia
16
+ from jax2d.sim_state import CollisionManifold, SimState
17
+ from kinetix.environment.env_state import EnvParams, EnvState, StaticEnvParams
18
+ from kinetix.environment.wrappers import (
19
+ AutoReplayWrapper,
20
+ AutoResetWrapper,
21
+ UnderspecifiedToGymnaxWrapper,
22
+ DenseRewardWrapper,
23
+ LogWrapper,
24
+ )
25
+
26
+ from kinetix.pcg.pcg import env_state_to_pcg_state, sample_pcg_state
27
+ from kinetix.pcg.pcg_state import PCGState
28
+ from kinetix.render.renderer_symbolic_entity import make_render_entities
29
+ from kinetix.render.renderer_pixels import make_render_pixels, make_render_pixels_rl
30
+ from kinetix.render.renderer_symbolic_flat import make_render_symbolic
31
+
32
+ from kinetix.util.saving import load_pcg_state_pickle
33
+ from jaxued.environments import UnderspecifiedEnv
34
+
35
+
36
+ def create_empty_env(static_env_params):
37
+ sim_state = create_empty_sim(static_env_params)
38
+ return EnvState(
39
+ timestep=0,
40
+ thruster_bindings=jnp.zeros(static_env_params.num_thrusters, dtype=jnp.int32),
41
+ motor_bindings=jnp.zeros(static_env_params.num_joints, dtype=jnp.int32),
42
+ motor_auto=jnp.zeros(static_env_params.num_joints, dtype=bool),
43
+ polygon_shape_roles=jnp.zeros(static_env_params.num_polygons, dtype=jnp.int32),
44
+ circle_shape_roles=jnp.zeros(static_env_params.num_circles, dtype=jnp.int32),
45
+ polygon_highlighted=jnp.zeros(static_env_params.num_polygons, dtype=bool),
46
+ circle_highlighted=jnp.zeros(static_env_params.num_circles, dtype=bool),
47
+ polygon_densities=jnp.ones(static_env_params.num_polygons, dtype=jnp.float32),
48
+ circle_densities=jnp.ones(static_env_params.num_circles, dtype=jnp.float32),
49
+ **sim_state.__dict__,
50
+ )
51
+
52
+
53
+ def index_motor_actions(
54
+ action: jnp.ndarray,
55
+ state: EnvState,
56
+ clip_min=None,
57
+ clip_max=None,
58
+ ):
59
+ # Expand the motor actions to all joints with the same colour
60
+ return jnp.clip(action[state.motor_bindings], clip_min, clip_max)
61
+
62
+
63
+ def index_thruster_actions(
64
+ action: jnp.ndarray,
65
+ state: EnvState,
66
+ clip_min=None,
67
+ clip_max=None,
68
+ ):
69
+ # Expand the thruster actions to all joints with the same colour
70
+ return jnp.clip(action[state.thruster_bindings], clip_min, clip_max)
71
+
72
+
73
+ def convert_continuous_actions(
74
+ action: jnp.ndarray, state: SimState, static_env_params: StaticEnvParams, params: EnvParams
75
+ ):
76
+ action_motor = action[: static_env_params.num_motor_bindings]
77
+ action_thruster = action[static_env_params.num_motor_bindings :]
78
+ action_motor = index_motor_actions(action_motor, state, -1, 1)
79
+ action_thruster = index_thruster_actions(action_thruster, state, 0, 1)
80
+
81
+ action_motor = jnp.where(state.motor_auto, jnp.ones_like(action_motor), action_motor)
82
+
83
+ action_to_perform = jnp.concatenate([action_motor, action_thruster], axis=0)
84
+ return action_to_perform
85
+
86
+
87
+ def convert_discrete_actions(action: int, state: SimState, static_env_params: StaticEnvParams, params: EnvParams):
88
+ # so, we have
89
+ # 0 to NJC * 2 - 1: Joint Actions
90
+ # NJC * 2: No-op
91
+ # NJC * 2 + 1 to NJC * 2 + 1 + NTC - 1: Thruster Actions
92
+ # action here is a categorical action
93
+ which_idx = action // 2
94
+ which_dir = action % 2
95
+ actions = (
96
+ jnp.zeros(static_env_params.num_motor_bindings + static_env_params.num_thruster_bindings)
97
+ .at[which_idx]
98
+ .set(which_dir * 2 - 1)
99
+ )
100
+ actions = actions * (
101
+ 1 - (action >= static_env_params.num_motor_bindings * 2)
102
+ ) # if action is the last one, set it to zero, i.e., a no-op. Alternatively, if the action is larger than NJC * 2, then it is a thruster action and we shouldn't control the joints.
103
+
104
+ actions = jax.lax.select(
105
+ action > static_env_params.num_motor_bindings * 2,
106
+ actions.at[action - static_env_params.num_motor_bindings * 2 - 1 + static_env_params.num_motor_bindings].set(1),
107
+ actions,
108
+ )
109
+
110
+ action_motor = index_motor_actions(actions[: static_env_params.num_motor_bindings], state, -1, 1)
111
+ action_motor = jnp.where(state.motor_auto, jnp.ones_like(action_motor), action_motor)
112
+ action_thruster = index_thruster_actions(actions[static_env_params.num_motor_bindings :], state, 0, 1)
113
+ action_to_perform = jnp.concatenate([action_motor, action_thruster], axis=0)
114
+ return action_to_perform
115
+
116
+
117
+ def convert_multi_discrete_actions(
118
+ action: jnp.ndarray, state: SimState, static_env_params: StaticEnvParams, params: EnvParams
119
+ ):
120
+ # Comes in with each action being in {0,1,2} for joints and {0,1} for thrusters
121
+ # Convert to [-1., 1.] for joints and [0., 1.] for thrusters
122
+
123
+ def _single_motor_action(act):
124
+ return jax.lax.switch(
125
+ act,
126
+ [lambda: 0.0, lambda: 1.0, lambda: -1.0],
127
+ )
128
+
129
+ def _single_thruster_act(act):
130
+ return jax.lax.select(
131
+ act == 0,
132
+ 0.0,
133
+ 1.0,
134
+ )
135
+
136
+ action_motor = jax.vmap(_single_motor_action)(action[: static_env_params.num_motor_bindings])
137
+ action_thruster = jax.vmap(_single_thruster_act)(action[static_env_params.num_motor_bindings :])
138
+
139
+ action_motor = index_motor_actions(action_motor, state, -1, 1)
140
+ action_thruster = index_thruster_actions(action_thruster, state, 0, 1)
141
+
142
+ action_motor = jnp.where(state.motor_auto, jnp.ones_like(action_motor), action_motor)
143
+
144
+ action_to_perform = jnp.concatenate([action_motor, action_thruster], axis=0)
145
+ return action_to_perform
146
+
147
+
148
+ def make_kinetix_env_from_args(
149
+ obs_type, action_type, reset_type, static_env_params=None, auto_reset_fn=None, dense_reward_scale=1.0
150
+ ):
151
+ if obs_type == "entity":
152
+ if action_type == "multidiscrete":
153
+ env = KinetixEntityMultiDiscreteActions(should_do_pcg_reset=True, static_env_params=static_env_params)
154
+ elif action_type == "discrete":
155
+ env = KinetixEntityDiscreteActions(should_do_pcg_reset=True, static_env_params=static_env_params)
156
+ elif action_type == "continuous":
157
+ env = KinetixEntityContinuousActions(should_do_pcg_reset=True, static_env_params=static_env_params)
158
+ else:
159
+ raise ValueError(f"Unknown action type: {action_type}")
160
+
161
+ elif obs_type == "symbolic":
162
+ if action_type == "multidiscrete":
163
+ env = KinetixSymbolicMultiDiscreteActions(should_do_pcg_reset=True, static_env_params=static_env_params)
164
+ elif action_type == "discrete":
165
+ env = KinetixSymbolicDiscreteActions(should_do_pcg_reset=True, static_env_params=static_env_params)
166
+ elif action_type == "continuous":
167
+ env = KinetixSymbolicContinuousActions(should_do_pcg_reset=True, static_env_params=static_env_params)
168
+ else:
169
+ raise ValueError(f"Unknown action type: {action_type}")
170
+
171
+ elif obs_type == "pixels":
172
+ if action_type == "multidiscrete":
173
+ env = KinetixPixelsMultiDiscreteActions(should_do_pcg_reset=True, static_env_params=static_env_params)
174
+ elif action_type == "discrete":
175
+ env = KinetixPixelsDiscreteActions(should_do_pcg_reset=True, static_env_params=static_env_params)
176
+ elif action_type == "continuous":
177
+ env = KinetixPixelsContinuousActions(should_do_pcg_reset=True, static_env_params=static_env_params)
178
+ else:
179
+ raise ValueError(f"Unknown action type: {action_type}")
180
+
181
+ elif obs_type == "blind":
182
+ if action_type == "discrete":
183
+ env = KinetixBlindDiscreteActions(should_do_pcg_reset=True, static_env_params=static_env_params)
184
+ elif action_type == "continuous":
185
+ env = KinetixBlindContinuousActions(should_do_pcg_reset=True, static_env_params=static_env_params)
186
+ else:
187
+ raise ValueError(f"Unknown action type: {action_type}")
188
+
189
+ else:
190
+ raise ValueError(f"Unknown observation type: {obs_type}")
191
+
192
+ # Wrap
193
+ if reset_type == "replay":
194
+ env = AutoReplayWrapper(env)
195
+ elif reset_type == "reset":
196
+ env = AutoResetWrapper(env, sample_level=auto_reset_fn)
197
+ else:
198
+ raise ValueError(f"Unknown reset type {reset_type}")
199
+
200
+ env = UnderspecifiedToGymnaxWrapper(env)
201
+ env = DenseRewardWrapper(env, dense_reward_scale=dense_reward_scale)
202
+ env = LogWrapper(env)
203
+
204
+ return env
205
+
206
+
207
+ def make_kinetix_env_from_name(name, static_env_params=None):
208
+ kwargs = dict(filename_to_use_for_reset=None, should_do_pcg_reset=True, static_env_params=static_env_params)
209
+ values = {
210
+ "Kinetix-Pixels-MultiDiscrete-v1": KinetixPixelsMultiDiscreteActions,
211
+ "Kinetix-Pixels-Discrete-v1": KinetixPixelsDiscreteActions,
212
+ "Kinetix-Pixels-Continuous-v1": KinetixPixelsContinuousActions,
213
+ "Kinetix-Symbolic-MultiDiscrete-v1": KinetixSymbolicMultiDiscreteActions,
214
+ "Kinetix-Symbolic-Discrete-v1": KinetixSymbolicDiscreteActions,
215
+ "Kinetix-Symbolic-Continuous-v1": KinetixSymbolicContinuousActions,
216
+ "Kinetix-Blind-Discrete-v1": KinetixBlindDiscreteActions,
217
+ "Kinetix-Blind-Continuous-v1": KinetixBlindContinuousActions,
218
+ "Kinetix-Entity-Discrete-v1": KinetixEntityDiscreteActions,
219
+ "Kinetix-Entity-Continuous-v1": KinetixEntityContinuousActions,
220
+ "Kinetix-Entity-MultiDiscrete-v1": KinetixEntityMultiDiscreteActions,
221
+ }
222
+
223
+ return values[name](**kwargs)
224
+
225
+
226
+ class ObservationSpace:
227
+ def __init__(self, params: EnvParams, static_env_params: StaticEnvParams):
228
+ pass
229
+
230
+ def get_obs(self, state: EnvState):
231
+ raise NotImplementedError()
232
+
233
+ def observation_space(self, params: EnvParams):
234
+ raise NotImplementedError()
235
+
236
+
237
+ class PixelObservations(ObservationSpace):
238
+ def __init__(self, params: EnvParams, static_env_params: StaticEnvParams):
239
+ self.render_function = make_render_pixels_rl(params, static_env_params)
240
+ self.static_env_params = static_env_params
241
+
242
+ def get_obs(self, state: EnvState):
243
+ return self.render_function(state)
244
+
245
+ def observation_space(self, params: EnvParams) -> spaces.Box:
246
+ return spaces.Box(
247
+ 0.0,
248
+ 1.0,
249
+ tuple(a // self.static_env_params.downscale for a in self.static_env_params.screen_dim) + (3,),
250
+ dtype=jnp.float32,
251
+ )
252
+
253
+
254
+ class SymbolicObservations(ObservationSpace):
255
+ def __init__(self, params: EnvParams, static_env_params: StaticEnvParams):
256
+ self.render_function = make_render_symbolic(params, static_env_params)
257
+
258
+ def get_obs(self, state: EnvState):
259
+ return self.render_function(state)
260
+
261
+
262
+ class EntityObservations(ObservationSpace):
263
+ def __init__(self, params: EnvParams, static_env_params: StaticEnvParams):
264
+ self.render_function = make_render_entities(params, static_env_params)
265
+
266
+ def get_obs(self, state: EnvState):
267
+ return self.render_function(state)
268
+
269
+
270
+ class BlindObservations(ObservationSpace):
271
+ def __init__(self, params: EnvParams, static_env_params: StaticEnvParams):
272
+ self.params = params
273
+
274
+ def get_obs(self, state: EnvState):
275
+ return jax.nn.one_hot(state.timestep, self.params.max_timesteps + 1)
276
+
277
+
278
+ def get_observation_space_from_name(name: str, params, static_env_params):
279
+ if "Pixels" in name:
280
+ return PixelObservations(params, static_env_params)
281
+ elif "Symbolic" in name:
282
+ return SymbolicObservations(params, static_env_params)
283
+ elif "Entity" in name:
284
+ return EntityObservations(params, static_env_params)
285
+ if "Blind" in name:
286
+ return BlindObservations(params, static_env_params)
287
+ else:
288
+ raise ValueError(f"Unknown name {name}")
289
+
290
+
291
+ class ActionType:
292
+ def __init__(self, params: EnvParams, static_env_params: StaticEnvParams):
293
+ # This is the processed, unified action space size that is shared with all action types
294
+ # 1 dim per motor and thruster
295
+ self.unified_action_space_size = static_env_params.num_motor_bindings + static_env_params.num_thruster_bindings
296
+
297
+ def action_space(self, params: Optional[EnvParams] = None) -> Union[spaces.Discrete, spaces.Box]:
298
+ raise NotImplementedError()
299
+
300
+ def process_action(self, action: jnp.ndarray, state: EnvState, static_env_params: StaticEnvParams) -> jnp.ndarray:
301
+ raise NotImplementedError()
302
+
303
+ def noop_action(self) -> jnp.ndarray:
304
+ raise NotImplementedError()
305
+
306
+ def random_action(self, rng: chex.PRNGKey):
307
+ raise NotImplementedError()
308
+
309
+
310
+ class ActionTypeContinuous(ActionType):
311
+ def __init__(self, params: EnvParams, static_env_params: StaticEnvParams):
312
+ super().__init__(params, static_env_params)
313
+
314
+ self.params = params
315
+ self.static_env_params = static_env_params
316
+
317
+ def action_space(self, params: EnvParams | None = None) -> spaces.Discrete | spaces.Box:
318
+ return spaces.Box(
319
+ low=jnp.ones(self.unified_action_space_size) * -1.0,
320
+ high=jnp.ones(self.unified_action_space_size) * 1.0,
321
+ shape=(self.unified_action_space_size,),
322
+ )
323
+
324
+ def process_action(self, action: PRNGKey, state: EnvState, static_env_params: StaticEnvParams) -> PRNGKey:
325
+ return convert_continuous_actions(action, state, static_env_params, self.params)
326
+
327
+ def noop_action(self) -> jnp.ndarray:
328
+ return jnp.zeros(self.unified_action_space_size, dtype=jnp.float32)
329
+
330
+ def random_action(self, rng: chex.PRNGKey) -> jnp.ndarray:
331
+ actions = jax.random.uniform(rng, shape=(self.unified_action_space_size,), minval=-1.0, maxval=1.0)
332
+ # Motors between -1 and 1, thrusters between 0 and 1
333
+ actions = actions.at[self.static_env_params.num_motor_bindings :].set(
334
+ jnp.abs(actions[self.static_env_params.num_motor_bindings :])
335
+ )
336
+
337
+ return actions
338
+
339
+
340
+ class ActionTypeDiscrete(ActionType):
341
+ def __init__(self, params: EnvParams, static_env_params: StaticEnvParams):
342
+ super().__init__(params, static_env_params)
343
+
344
+ self.params = params
345
+ self.static_env_params = static_env_params
346
+
347
+ self._n_actions = (
348
+ self.static_env_params.num_motor_bindings * 2 + 1 + self.static_env_params.num_thruster_bindings
349
+ )
350
+
351
+ def action_space(self, params: Optional[EnvParams] = None) -> spaces.Discrete:
352
+ return spaces.Discrete(self._n_actions)
353
+
354
+ def process_action(self, action: jnp.ndarray, state: EnvState, static_env_params: StaticEnvParams) -> jnp.ndarray:
355
+ return convert_discrete_actions(action, state, static_env_params, self.params)
356
+
357
+ def noop_action(self) -> int:
358
+ return self.static_env_params.num_motor_bindings * 2
359
+
360
+ def random_action(self, rng: chex.PRNGKey):
361
+ return jax.random.randint(rng, shape=(), minval=0, maxval=self._n_actions)
362
+
363
+
364
+ class MultiDiscrete(Space):
365
+ def __init__(self, n, number_of_dims_per_distribution):
366
+ self.number_of_dims_per_distribution = number_of_dims_per_distribution
367
+ self.n = n
368
+ self.shape = (number_of_dims_per_distribution.shape[0],)
369
+ self.dtype = jnp.int_
370
+
371
+ def sample(self, rng: chex.PRNGKey) -> chex.Array:
372
+ """Sample random action uniformly from set of categorical choices."""
373
+ uniform_sample = jax.random.uniform(rng, shape=self.shape) * self.number_of_dims_per_distribution
374
+ md_dist = jnp.floor(uniform_sample)
375
+ return md_dist.astype(self.dtype)
376
+
377
+ def contains(self, x) -> jnp.ndarray:
378
+ """Check whether specific object is within space."""
379
+ range_cond = jnp.logical_and(x >= 0, (x < self.number_of_dims_per_distribution).all())
380
+ return range_cond
381
+
382
+
383
+ class ActionTypeMultiDiscrete(ActionType):
384
+ def __init__(self, params: EnvParams, static_env_params: StaticEnvParams):
385
+ super().__init__(params, static_env_params)
386
+
387
+ self.params = params
388
+ self.static_env_params = static_env_params
389
+ # This is the action space that will be used internally by an agent
390
+ # 3 dims per motor (foward, backward, off) and 2 per thruster (on, off)
391
+ self.n_hot_action_space_size = (
392
+ self.static_env_params.num_motor_bindings * 3 + self.static_env_params.num_thruster_bindings * 2
393
+ )
394
+
395
+ def _make_sample_random():
396
+ minval = jnp.zeros(self.unified_action_space_size, dtype=jnp.int32)
397
+ maxval = jnp.ones(self.unified_action_space_size, dtype=jnp.int32) * 3
398
+ maxval = maxval.at[self.static_env_params.num_motor_bindings :].set(2)
399
+
400
+ def random(rng):
401
+ return jax.random.randint(rng, shape=(self.unified_action_space_size,), minval=minval, maxval=maxval)
402
+
403
+ return random
404
+
405
+ self._random = _make_sample_random
406
+
407
+ self.number_of_dims_per_distribution = jnp.concatenate(
408
+ [
409
+ np.ones(self.static_env_params.num_motor_bindings) * 3,
410
+ np.ones(self.static_env_params.num_thruster_bindings) * 2,
411
+ ]
412
+ ).astype(np.int32)
413
+
414
+ def action_space(self, params: Optional[EnvParams] = None) -> MultiDiscrete:
415
+ return MultiDiscrete(self.n_hot_action_space_size, self.number_of_dims_per_distribution)
416
+
417
+ def process_action(self, action: jnp.ndarray, state: EnvState, static_env_params: StaticEnvParams) -> jnp.ndarray:
418
+ return convert_multi_discrete_actions(action, state, static_env_params, self.params)
419
+
420
+ def noop_action(self):
421
+ return jnp.zeros(self.unified_action_space_size, dtype=jnp.int32)
422
+
423
+ def random_action(self, rng: chex.PRNGKey):
424
+ return self._random()(rng)
425
+
426
+
427
+ class BasePhysicsEnv(UnderspecifiedEnv):
428
+ def __init__(
429
+ self,
430
+ action_type: ActionType,
431
+ observation_space: ObservationSpace,
432
+ static_env_params: StaticEnvParams,
433
+ target_index: int = 0,
434
+ filename_to_use_for_reset=None, # "worlds/games/bipedal_v1",
435
+ should_do_pcg_reset: bool = False,
436
+ ):
437
+ super().__init__()
438
+ self.target_index = target_index
439
+ self.static_env_params = static_env_params
440
+ self.action_type = action_type
441
+ self._observation_space = observation_space
442
+ self.physics_engine = PhysicsEngine(self.static_env_params)
443
+ self.should_do_pcg_reset = should_do_pcg_reset
444
+
445
+ self.filename_to_use_for_reset = filename_to_use_for_reset
446
+ if self.filename_to_use_for_reset is not None:
447
+ self.reset_state = load_pcg_state_pickle(filename_to_use_for_reset)
448
+ else:
449
+ env_state = create_empty_env(self.static_env_params)
450
+ self.reset_state = env_state_to_pcg_state(env_state)
451
+
452
+ # Action / Observations
453
+ def action_space(self, params: Optional[EnvParams] = None) -> Union[spaces.Discrete, spaces.Box]:
454
+ return self.action_type.action_space(params)
455
+
456
+ def observation_space(self, params: Any):
457
+ return self._observation_space.observation_space(params)
458
+
459
+ def get_obs(self, state: EnvState):
460
+ return self._observation_space.get_obs(state)
461
+
462
+ def step_env(self, rng, state, action: jnp.ndarray, params):
463
+ action_processed = self.action_type.process_action(action, state, self.static_env_params)
464
+ return self.engine_step(state, action_processed, params)
465
+
466
+ def reset_env(self, rng, params):
467
+ # Wrap in AutoResetWrapper or AutoReplayWrapper
468
+ raise NotImplementedError()
469
+
470
+ def reset_env_to_level(self, rng, state: EnvState, params):
471
+ if isinstance(state, PCGState):
472
+ return self.reset_env_to_pcg_level(rng, state, params)
473
+ return self.get_obs(state), state
474
+
475
+ def reset_env_to_pcg_level(self, rng, state: PCGState, params):
476
+ env_state = sample_pcg_state(rng, state, params, self.static_env_params)
477
+ return self.get_obs(env_state), env_state
478
+
479
+ @property
480
+ def default_params(self) -> EnvParams:
481
+ return EnvParams()
482
+
483
+ @staticmethod
484
+ def default_static_params() -> StaticEnvParams:
485
+ return StaticEnvParams()
486
+
487
+ def compute_reward_info(
488
+ self, state: EnvState, manifolds: tuple[CollisionManifold, CollisionManifold, CollisionManifold]
489
+ ) -> float:
490
+ def get_active(manifold: CollisionManifold) -> jnp.ndarray:
491
+ return manifold.active
492
+
493
+ def dist(a, b):
494
+ return jnp.linalg.norm(a - b)
495
+
496
+ @jax.vmap
497
+ def dist_rr(idxa, idxb):
498
+ return dist(state.polygon.position[idxa], state.polygon.position[idxb])
499
+
500
+ @jax.vmap
501
+ def dist_cc(idxa, idxb):
502
+ return dist(state.circle.position[idxa], state.circle.position[idxb])
503
+
504
+ @jax.vmap
505
+ def dist_cr(idxa, idxb):
506
+ return dist(state.circle.position[idxa], state.polygon.position[idxb])
507
+
508
+ info = {
509
+ "GoalR": False,
510
+ }
511
+ negative_reward = 0
512
+ reward = 0
513
+ distance = 0
514
+ rs = state.polygon_shape_roles * state.polygon.active
515
+ cs = state.circle_shape_roles * state.circle.active
516
+
517
+ # Polygon Polygon
518
+ r1 = rs[self.physics_engine.poly_poly_pairs[:, 0]]
519
+ r2 = rs[self.physics_engine.poly_poly_pairs[:, 1]]
520
+ reward += ((r1 * r2 == 2) * get_active(manifolds[0])).sum()
521
+ negative_reward += ((r1 * r2 == 3) * get_active(manifolds[0])).sum()
522
+
523
+ distance += (
524
+ (r1 * r2 == 2)
525
+ * dist_rr(self.physics_engine.poly_poly_pairs[:, 0], self.physics_engine.poly_poly_pairs[:, 1])
526
+ ).sum()
527
+
528
+ # Circle Polygon
529
+ c1 = cs[self.physics_engine.circle_poly_pairs[:, 0]]
530
+ r2 = rs[self.physics_engine.circle_poly_pairs[:, 1]]
531
+ reward += ((c1 * r2 == 2) * get_active(manifolds[1])).sum()
532
+ negative_reward += ((c1 * r2 == 3) * get_active(manifolds[1])).sum()
533
+
534
+ t = dist_cr(self.physics_engine.circle_poly_pairs[:, 0], self.physics_engine.circle_poly_pairs[:, 1])
535
+ distance += ((c1 * r2 == 2) * t).sum()
536
+
537
+ # Circle Circle
538
+ c1 = cs[self.physics_engine.circle_circle_pairs[:, 0]]
539
+ c2 = cs[self.physics_engine.circle_circle_pairs[:, 1]]
540
+ reward += ((c1 * c2 == 2) * get_active(manifolds[2])).sum()
541
+ negative_reward += ((c1 * c2 == 3) * get_active(manifolds[2])).sum()
542
+
543
+ distance += (
544
+ (c1 * c2 == 2)
545
+ * dist_cc(self.physics_engine.circle_circle_pairs[:, 0], self.physics_engine.circle_circle_pairs[:, 1])
546
+ ).sum()
547
+
548
+ reward = jax.lax.select(
549
+ negative_reward > 0,
550
+ -1.0,
551
+ jax.lax.select(
552
+ reward > 0,
553
+ 1.0,
554
+ 0.0,
555
+ ),
556
+ )
557
+
558
+ info["GoalR"] = reward > 0
559
+ info["distance"] = distance
560
+ return reward, info
561
+
562
+ @partial(jax.jit, static_argnums=(0,))
563
+ def engine_step(self, env_state, action_to_perform, env_params):
564
+ def _single_step(env_state, unused):
565
+ env_state, mfolds = self.physics_engine.step(
566
+ env_state,
567
+ env_params,
568
+ action_to_perform,
569
+ )
570
+
571
+ reward, info = self.compute_reward_info(env_state, mfolds)
572
+
573
+ done = reward != 0
574
+
575
+ info = {"rr_manifolds": None, "cr_manifolds": None} | info
576
+
577
+ return env_state, (reward, done, info)
578
+
579
+ env_state, (rewards, dones, infos) = jax.lax.scan(
580
+ _single_step, env_state, xs=None, length=self.static_env_params.frame_skip
581
+ )
582
+ env_state = env_state.replace(timestep=env_state.timestep + 1)
583
+
584
+ reward = rewards.max()
585
+ done = dones.sum() > 0 | jax.tree.reduce(
586
+ jnp.logical_or, jax.tree.map(lambda x: jnp.isnan(x).any(), env_state), False
587
+ )
588
+ done |= env_state.timestep >= env_params.max_timesteps
589
+
590
+ info = jax.tree.map(lambda x: x[-1], infos)
591
+
592
+ return (
593
+ lax.stop_gradient(self.get_obs(env_state)),
594
+ lax.stop_gradient(env_state),
595
+ reward,
596
+ done,
597
+ info,
598
+ )
599
+
600
+ @functools.partial(jax.jit, static_argnums=(0,))
601
+ def step(
602
+ self,
603
+ key: chex.PRNGKey,
604
+ state: TEnvState,
605
+ action: Union[int, float, chex.Array],
606
+ params: Optional[TEnvParams] = None,
607
+ ) -> Tuple[chex.Array, TEnvState, jnp.ndarray, jnp.ndarray, Dict[Any, Any]]:
608
+ raise NotImplementedError()
609
+
610
+
611
+ class KinetixPixelsDiscreteActions(BasePhysicsEnv):
612
+ def __init__(
613
+ self,
614
+ static_env_params: StaticEnvParams | None = None,
615
+ **kwargs,
616
+ ):
617
+
618
+ params = self.default_params
619
+ static_env_params = static_env_params or self.default_static_params()
620
+ super().__init__(
621
+ action_type=ActionTypeDiscrete(params, static_env_params),
622
+ observation_space=PixelObservations(params, static_env_params),
623
+ static_env_params=static_env_params,
624
+ **kwargs,
625
+ )
626
+
627
+ @property
628
+ def name(self) -> str:
629
+ return "Kinetix-Pixels-Discrete-v1"
630
+
631
+
632
+ class KinetixPixelsContinuousActions(BasePhysicsEnv):
633
+ def __init__(
634
+ self,
635
+ static_env_params: StaticEnvParams | None = None,
636
+ **kwargs,
637
+ ):
638
+ params = self.default_params
639
+ static_env_params = static_env_params or self.default_static_params()
640
+ super().__init__(
641
+ action_type=ActionTypeContinuous(params, static_env_params),
642
+ observation_space=PixelObservations(params, static_env_params),
643
+ static_env_params=static_env_params,
644
+ **kwargs,
645
+ )
646
+
647
+ @property
648
+ def name(self) -> str:
649
+ return "Kinetix-Pixels-Continuous-v1"
650
+
651
+
652
+ class KinetixPixelsMultiDiscreteActions(BasePhysicsEnv):
653
+ def __init__(
654
+ self,
655
+ static_env_params: StaticEnvParams | None = None,
656
+ **kwargs,
657
+ ):
658
+ params = self.default_params
659
+ static_env_params = static_env_params or self.default_static_params()
660
+ super().__init__(
661
+ action_type=ActionTypeMultiDiscrete(params, static_env_params),
662
+ observation_space=PixelObservations(params, static_env_params),
663
+ static_env_params=static_env_params,
664
+ **kwargs,
665
+ )
666
+
667
+ @property
668
+ def name(self) -> str:
669
+ return "Kinetix-Pixels-MultiDiscrete-v1"
670
+
671
+
672
+ class KinetixSymbolicDiscreteActions(BasePhysicsEnv):
673
+ def __init__(
674
+ self,
675
+ static_env_params: StaticEnvParams | None = None,
676
+ **kwargs,
677
+ ):
678
+ params = self.default_params
679
+ static_env_params = static_env_params or self.default_static_params()
680
+ super().__init__(
681
+ action_type=ActionTypeDiscrete(params, static_env_params),
682
+ observation_space=SymbolicObservations(params, static_env_params),
683
+ static_env_params=static_env_params,
684
+ **kwargs,
685
+ )
686
+
687
+ @property
688
+ def name(self) -> str:
689
+ return "Kinetix-Symbolic-Discrete-v1"
690
+
691
+
692
+ class KinetixSymbolicContinuousActions(BasePhysicsEnv):
693
+ def __init__(
694
+ self,
695
+ static_env_params: StaticEnvParams | None = None,
696
+ **kwargs,
697
+ ):
698
+ params = self.default_params
699
+ static_env_params = static_env_params or self.default_static_params()
700
+ super().__init__(
701
+ action_type=ActionTypeContinuous(params, static_env_params),
702
+ observation_space=SymbolicObservations(params, static_env_params),
703
+ static_env_params=static_env_params,
704
+ **kwargs,
705
+ )
706
+
707
+ @property
708
+ def name(self) -> str:
709
+ return "Kinetix-Symbolic-Continuous-v1"
710
+
711
+
712
+ class KinetixSymbolicMultiDiscreteActions(BasePhysicsEnv):
713
+ def __init__(
714
+ self,
715
+ static_env_params: StaticEnvParams | None = None,
716
+ **kwargs,
717
+ ):
718
+ params = self.default_params
719
+ static_env_params = static_env_params or self.default_static_params()
720
+ super().__init__(
721
+ action_type=ActionTypeMultiDiscrete(params, static_env_params),
722
+ observation_space=SymbolicObservations(params, static_env_params),
723
+ static_env_params=static_env_params,
724
+ **kwargs,
725
+ )
726
+
727
+ @property
728
+ def name(self) -> str:
729
+ return "Kinetix-Symbolic-MultiDiscrete-v1"
730
+
731
+
732
+ class KinetixEntityDiscreteActions(BasePhysicsEnv):
733
+ def __init__(
734
+ self,
735
+ static_env_params: StaticEnvParams | None = None,
736
+ **kwargs,
737
+ ):
738
+ params = self.default_params
739
+ static_env_params = static_env_params or self.default_static_params()
740
+ super().__init__(
741
+ action_type=ActionTypeDiscrete(params, static_env_params),
742
+ observation_space=EntityObservations(params, static_env_params),
743
+ static_env_params=static_env_params,
744
+ **kwargs,
745
+ )
746
+
747
+ @property
748
+ def name(self) -> str:
749
+ return "Kinetix-Entity-Discrete-v1"
750
+
751
+
752
+ class KinetixEntityContinuousActions(BasePhysicsEnv):
753
+ def __init__(
754
+ self,
755
+ static_env_params: StaticEnvParams | None = None,
756
+ **kwargs,
757
+ ):
758
+ params = self.default_params
759
+ static_env_params = static_env_params or self.default_static_params()
760
+ super().__init__(
761
+ action_type=ActionTypeContinuous(params, static_env_params),
762
+ observation_space=EntityObservations(params, static_env_params),
763
+ static_env_params=static_env_params,
764
+ **kwargs,
765
+ )
766
+
767
+ @property
768
+ def name(self) -> str:
769
+ return "Kinetix-Entity-Continuous-v1"
770
+
771
+
772
+ class KinetixEntityMultiDiscreteActions(BasePhysicsEnv):
773
+ def __init__(
774
+ self,
775
+ static_env_params: StaticEnvParams | None = None,
776
+ **kwargs,
777
+ ):
778
+ params = self.default_params
779
+ static_env_params = static_env_params or self.default_static_params()
780
+ super().__init__(
781
+ action_type=ActionTypeMultiDiscrete(params, static_env_params),
782
+ observation_space=EntityObservations(params, static_env_params),
783
+ static_env_params=static_env_params,
784
+ **kwargs,
785
+ )
786
+
787
+ @property
788
+ def name(self) -> str:
789
+ return "Kinetix-Entity-MultiDiscrete-v1"
790
+
791
+
792
+ class KinetixBlindDiscreteActions(BasePhysicsEnv):
793
+ def __init__(
794
+ self,
795
+ static_env_params: StaticEnvParams | None = None,
796
+ **kwargs,
797
+ ):
798
+ params = self.default_params
799
+ static_env_params = static_env_params or self.default_static_params()
800
+ super().__init__(
801
+ action_type=ActionTypeDiscrete(params, static_env_params),
802
+ observation_space=BlindObservations(params, static_env_params),
803
+ static_env_params=static_env_params,
804
+ **kwargs,
805
+ )
806
+
807
+ @property
808
+ def name(self) -> str:
809
+ return "Kinetix-Blind-Discrete-v1"
810
+
811
+
812
+ class KinetixBlindContinuousActions(BasePhysicsEnv):
813
+ def __init__(
814
+ self,
815
+ static_env_params: StaticEnvParams | None = None,
816
+ **kwargs,
817
+ ):
818
+ params = self.default_params
819
+ static_env_params = static_env_params or self.default_static_params()
820
+ super().__init__(
821
+ action_type=ActionTypeContinuous(params, static_env_params),
822
+ observation_space=BlindObservations(params, static_env_params),
823
+ static_env_params=static_env_params,
824
+ **kwargs,
825
+ )
826
+
827
+ @property
828
+ def name(self) -> str:
829
+ return "Kinetix-Blind-Continuous-v1"
kinetix/environment/env_state.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import field
2
+ import jax.numpy as jnp
3
+ from flax import struct
4
+
5
+ from jax2d.sim_state import SimState, SimParams, StaticSimParams
6
+
7
+
8
+ @struct.dataclass
9
+ class EnvState(SimState):
10
+ thruster_bindings: jnp.ndarray
11
+ motor_bindings: jnp.ndarray
12
+ motor_auto: jnp.ndarray
13
+
14
+ polygon_shape_roles: jnp.ndarray
15
+ circle_shape_roles: jnp.ndarray
16
+
17
+ polygon_highlighted: jnp.ndarray
18
+ circle_highlighted: jnp.ndarray
19
+
20
+ polygon_densities: jnp.ndarray
21
+ circle_densities: jnp.ndarray
22
+
23
+ timestep: int = 0
24
+
25
+
26
+ @struct.dataclass
27
+ class EnvParams(SimParams):
28
+ max_timesteps: int = 256
29
+ pixels_per_unit: int = 100
30
+ dense_reward_scale: float = 0.1
31
+ num_shape_roles: int = 4
32
+
33
+
34
+ @struct.dataclass
35
+ class StaticEnvParams(StaticSimParams):
36
+ screen_dim: tuple[int, int] = (500, 500)
37
+ downscale: int = 4
38
+
39
+ frame_skip: int = 1
40
+ max_shape_size: int = 2
41
+
42
+ num_motor_bindings: int = 4
43
+ num_thruster_bindings: int = 2
kinetix/environment/ued/distributions.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ import math
3
+
4
+ import chex
5
+ import jax
6
+ import jax.numpy as jnp
7
+ from flax.serialization import to_state_dict
8
+ from jax2d.engine import (
9
+ calculate_collision_matrix,
10
+ calc_inverse_mass_polygon,
11
+ calc_inverse_mass_circle,
12
+ calc_inverse_inertia_circle,
13
+ calc_inverse_inertia_polygon,
14
+ recalculate_mass_and_inertia,
15
+ select_shape,
16
+ PhysicsEngine,
17
+ )
18
+ from jax2d.sim_state import SimState, RigidBody, Joint, Thruster
19
+ from jax2d.maths import rmat
20
+ from kinetix.environment.env_state import EnvParams, EnvState, StaticEnvParams
21
+ from kinetix.environment.ued.mutators import (
22
+ mutate_add_connected_shape_proper,
23
+ mutate_add_shape,
24
+ mutate_add_connected_shape,
25
+ mutate_add_thruster,
26
+ )
27
+ from kinetix.environment.ued.ued_state import UEDParams
28
+ from kinetix.environment.ued.util import (
29
+ get_role,
30
+ sample_dimensions,
31
+ is_space_for_shape,
32
+ random_position_on_polygon,
33
+ random_position_on_circle,
34
+ are_there_shapes_present,
35
+ is_space_for_joint,
36
+ )
37
+ from kinetix.environment.utils import permute_state
38
+ from kinetix.util.saving import load_world_state_pickle
39
+ from flax import struct
40
+ from kinetix.environment.env import create_empty_env
41
+
42
+
43
+ @partial(jax.jit, static_argnums=(1, 3, 5, 6, 7, 8, 9, 10))
44
+ def create_vmapped_filtered_distribution(
45
+ rng,
46
+ level_sampler,
47
+ env_params: EnvParams,
48
+ static_env_params: StaticEnvParams,
49
+ ued_params: UEDParams,
50
+ n_samples: int,
51
+ env,
52
+ do_filter_levels: bool,
53
+ level_filter_sample_ratio: int,
54
+ env_size_name: str,
55
+ level_filter_n_steps: int,
56
+ ):
57
+
58
+ if do_filter_levels and level_filter_n_steps > 0:
59
+ sample_ratio = level_filter_sample_ratio
60
+ n_unfiltered_samples = sample_ratio * n_samples
61
+ rng, _rng = jax.random.split(rng)
62
+ _rngs = jax.random.split(_rng, n_unfiltered_samples)
63
+
64
+ # unfiltered_levels = jax.vmap(level_sampler, in_axes=(0, None, None, None, None))(
65
+ # _rngs, env_params, static_env_params, ued_params, env_size_name
66
+ # )
67
+ unfiltered_levels = jax.vmap(level_sampler, in_axes=(0,))(_rngs)
68
+ #
69
+
70
+ # No-op filtering
71
+
72
+ def _noop_step(states, rng):
73
+ rng, _rng = jax.random.split(rng)
74
+ _rngs = jax.random.split(_rng, n_unfiltered_samples)
75
+
76
+ action = jnp.zeros((n_unfiltered_samples, *env.action_space(env_params).shape), dtype=jnp.int32)
77
+
78
+ obs, states, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
79
+ _rngs, states, action, env_params
80
+ )
81
+
82
+ return states, (done, reward)
83
+
84
+ # Wrap levels
85
+ rng, _rng = jax.random.split(rng)
86
+ _rngs = jax.random.split(_rng, n_unfiltered_samples)
87
+ obsv, unfiltered_levels_wrapped = jax.vmap(env.reset_to_level, in_axes=(0, 0, None))(
88
+ _rngs, unfiltered_levels, env_params
89
+ )
90
+
91
+ rng, _rng = jax.random.split(rng)
92
+ _rngs = jax.random.split(_rng, level_filter_n_steps)
93
+ _, (done, rewards) = jax.lax.scan(_noop_step, unfiltered_levels_wrapped, xs=_rngs)
94
+
95
+ done_indexes = jnp.argmax(done, axis=0)
96
+ done_rewards = rewards[done_indexes, jnp.arange(n_unfiltered_samples)]
97
+
98
+ noop_solved_indexes = done_rewards > 0.5
99
+ p = noop_solved_indexes * 0.001 + (1 - noop_solved_indexes) * 1.0
100
+ p /= p.sum()
101
+
102
+ rng, _rng = jax.random.split(rng)
103
+ level_indexes = jax.random.choice(
104
+ _rng, jnp.arange(n_unfiltered_samples), shape=(n_samples,), replace=False, p=p
105
+ )
106
+
107
+ levels = jax.tree.map(lambda x: x[level_indexes], unfiltered_levels)
108
+
109
+ else:
110
+ rng, _rng = jax.random.split(rng)
111
+ _rngs = jax.random.split(_rng, n_samples)
112
+
113
+ levels = jax.vmap(level_sampler, in_axes=(0,))(_rngs)
114
+
115
+ return levels
116
+
117
+
118
+ @partial(jax.jit, static_argnums=(1, 3, 4, 5))
119
+ def sample_kinetix_level(
120
+ rng,
121
+ engine: PhysicsEngine,
122
+ env_params: EnvParams,
123
+ static_env_params: StaticEnvParams,
124
+ ued_params: UEDParams,
125
+ env_size_name: str = "l",
126
+ ):
127
+ rng, _rng = jax.random.split(rng)
128
+ _rngs = jax.random.split(_rng, 12)
129
+
130
+ small_force_no_fixate = env_size_name == "s"
131
+
132
+ # Start with empty state
133
+ state = create_empty_env(static_env_params)
134
+
135
+ # Set the floor
136
+ prob_of_floor_colour = jnp.array(
137
+ [
138
+ ued_params.floor_prob_normal,
139
+ ued_params.floor_prob_green,
140
+ ued_params.floor_prob_blue,
141
+ ued_params.floor_prob_red,
142
+ ]
143
+ )
144
+ floor_colour = jax.random.choice(_rngs[0], jnp.arange(4), p=prob_of_floor_colour)
145
+ state = state.replace(polygon_shape_roles=state.polygon_shape_roles.at[0].set(floor_colour))
146
+
147
+ # When we add shapes we don't want them to collide with already existing shapes
148
+ def _choose_proposal_with_least_collisions(proposals, bias=None):
149
+ rr, cr, cc = jax.vmap(engine.calculate_collision_manifolds)(proposals)
150
+
151
+ rr_collisions = jnp.sum(jnp.sum(rr.active.astype(jnp.int32), axis=-1), axis=-1)
152
+ cr_collisions = jnp.sum(cr.active.astype(jnp.int32), axis=-1)
153
+ cc_collisions = jnp.sum(cc.active.astype(jnp.int32), axis=-1)
154
+
155
+ all_collisions = jnp.concatenate(
156
+ [rr_collisions[:, None], cr_collisions[:, None], cc_collisions[:, None]], axis=1
157
+ )
158
+ num_collisions = jnp.sum(all_collisions, axis=-1)
159
+ if bias is not None:
160
+ num_collisions = num_collisions + bias
161
+
162
+ chosen_addition_idx = jnp.argmin(num_collisions)
163
+
164
+ return jax.tree.map(lambda x: x[chosen_addition_idx], proposals)
165
+
166
+ def _add_filtered_shape(rng, state, force_no_fixate=False):
167
+ rng, _rng = jax.random.split(rng)
168
+ _rngs = jax.random.split(_rng, ued_params.add_shape_n_proposals)
169
+ proposed_additions = jax.vmap(mutate_add_shape, in_axes=(0, None, None, None, None, None))(
170
+ _rngs,
171
+ state,
172
+ env_params,
173
+ static_env_params,
174
+ ued_params,
175
+ jnp.logical_or(force_no_fixate, small_force_no_fixate),
176
+ )
177
+
178
+ return _choose_proposal_with_least_collisions(proposed_additions)
179
+
180
+ def _add_filtered_connected_shape(rng, state, force_rjoint=False):
181
+ rng, _rng = jax.random.split(rng)
182
+ _rngs = jax.random.split(_rng, ued_params.add_shape_n_proposals)
183
+
184
+ proposed_additions, valid = jax.vmap(mutate_add_connected_shape, in_axes=(0, None, None, None, None, None))(
185
+ _rngs, state, env_params, static_env_params, ued_params, force_rjoint
186
+ )
187
+
188
+ bias = (jnp.ones(ued_params.add_shape_n_proposals) - 1 * valid) * ued_params.connect_no_visibility_bias
189
+
190
+ return _choose_proposal_with_least_collisions(proposed_additions, bias=bias)
191
+
192
+ # Add green and blue - make sure they're not both fixated
193
+ force_green_no_fixate = (jax.random.uniform(_rngs[1]) < 0.5) | (state.polygon_shape_roles[0] == 2)
194
+ state = _add_filtered_shape(_rngs[2], state, force_green_no_fixate)
195
+ state = _add_filtered_shape(_rngs[3], state, ~force_green_no_fixate)
196
+
197
+ # Forced controls
198
+ forced_control = jnp.array([[0, 1], [1, 0], [1, 1]])[jax.random.randint(_rngs[4], (), 0, 3)]
199
+ force_thruster, force_motor = forced_control[0], forced_control[1]
200
+
201
+ # Forced motor
202
+ state = jax.lax.cond(
203
+ force_motor,
204
+ lambda: _add_filtered_connected_shape(_rngs[5], state, force_rjoint=True), # force the rjoint
205
+ lambda: _add_filtered_shape(_rngs[6], state),
206
+ )
207
+
208
+ # Forced thruster
209
+ state = jax.lax.cond(
210
+ force_thruster,
211
+ lambda: mutate_add_thruster(_rngs[7], state, env_params, static_env_params, ued_params),
212
+ lambda: state,
213
+ )
214
+
215
+ # Add rest of shapes
216
+ n_shapes_to_add = (
217
+ static_env_params.num_polygons + static_env_params.num_circles - 3 - static_env_params.num_static_fixated_polys
218
+ )
219
+
220
+ def _add_shape(state, rng):
221
+ rng, _rng = jax.random.split(rng)
222
+ _rngs = jax.random.split(_rng, 3)
223
+ shape_add_type = jax.random.choice(
224
+ _rngs[0],
225
+ jnp.arange(3),
226
+ p=jnp.array(
227
+ [ued_params.add_connected_shape_chance, ued_params.add_shape_chance, ued_params.add_no_shape_chance]
228
+ ),
229
+ )
230
+
231
+ state = jax.lax.switch(
232
+ shape_add_type,
233
+ [
234
+ lambda: _add_filtered_connected_shape(_rngs[1], state),
235
+ lambda: _add_filtered_shape(_rngs[2], state),
236
+ lambda: state,
237
+ ],
238
+ )
239
+
240
+ return state, None
241
+
242
+ state, _ = jax.lax.scan(_add_shape, state, jax.random.split(_rngs[8], n_shapes_to_add))
243
+
244
+ # Add thrusters
245
+ n_thrusters_to_add = static_env_params.num_thrusters - 1
246
+
247
+ def _add_thruster(state, rng):
248
+ rng, _rng = jax.random.split(rng)
249
+ _rngs = jax.random.split(_rng, 3)
250
+ state = jax.lax.cond(
251
+ jax.random.uniform(_rngs[0]) < ued_params.add_thruster_chance,
252
+ lambda: mutate_add_thruster(_rngs[1], state, env_params, static_env_params, ued_params),
253
+ lambda: state,
254
+ )
255
+
256
+ return state, None
257
+
258
+ state, _ = jax.lax.scan(_add_thruster, state, jax.random.split(_rngs[9], n_thrusters_to_add))
259
+
260
+ # Randomly swap green and blue to remove left-right bias
261
+ def _swap_roles(do_swap_roles, roles):
262
+ role1 = roles == 1
263
+ role2 = roles == 2
264
+
265
+ swapped_roles = roles * ~(role1 | role2) + role1.astype(int) * 2 + role2.astype(int) * 1
266
+ return jax.lax.select(do_swap_roles, swapped_roles, roles)
267
+
268
+ do_swap_roles = jax.random.uniform(_rngs[10], shape=()) < 0.5
269
+ # Don't want to swap if floor is non-standard
270
+ do_swap_roles &= state.polygon_shape_roles[0] == 0
271
+ state = state.replace(
272
+ polygon_shape_roles=_swap_roles(do_swap_roles, state.polygon_shape_roles),
273
+ circle_shape_roles=_swap_roles(do_swap_roles, state.circle_shape_roles),
274
+ )
275
+
276
+ return permute_state(_rngs[11], state, static_env_params)
277
+
278
+
279
+ @partial(jax.jit, static_argnums=(2, 4, 5))
280
+ def create_random_starting_distribution(
281
+ rng,
282
+ env_params: EnvParams,
283
+ static_env_params: StaticEnvParams,
284
+ ued_params: UEDParams,
285
+ env_size_name: str,
286
+ controllable=True,
287
+ ):
288
+ rng, _rng = jax.random.split(rng)
289
+ _rngs = jax.random.split(_rng, 15)
290
+ d = to_state_dict(ued_params)
291
+ ued_params = UEDParams(
292
+ **(
293
+ d
294
+ | dict(
295
+ goal_body_size_factor=2.0,
296
+ thruster_power_multiplier=2.0,
297
+ max_shape_size=0.5,
298
+ )
299
+ ),
300
+ )
301
+
302
+ prob_of_large_shapes = 0.05
303
+
304
+ ued_params_large_shapes = ued_params.replace(
305
+ max_shape_size=static_env_params.max_shape_size * 1.0, goal_body_size_factor=1.0
306
+ )
307
+
308
+ state = create_empty_env(env_params, static_env_params)
309
+
310
+ def _get_ued_params(rng):
311
+ rng, _rng, _rng2 = jax.random.split(rng, 3)
312
+ large_shapes = jax.random.uniform(_rng) < prob_of_large_shapes
313
+ params_to_use = jax.tree.map(
314
+ lambda x, y: jax.lax.select(large_shapes, x, y), ued_params_large_shapes, ued_params
315
+ )
316
+ return params_to_use
317
+
318
+ def _my_add_shape(rng, state):
319
+ rng, _rng, _rng2 = jax.random.split(rng, 3)
320
+ return mutate_add_shape(_rng, state, env_params, static_env_params, _get_ued_params(_rng2))
321
+
322
+ def _my_add_connected_shape(rng, state, **kwargs):
323
+ rng, _rng, _rng2 = jax.random.split(rng, 3)
324
+ return mutate_add_connected_shape_proper(
325
+ _rng, state, env_params, static_env_params, _get_ued_params(_rng2), **kwargs
326
+ )
327
+
328
+ # Add the green thing and blue thing
329
+ state = _my_add_shape(_rngs[0], state)
330
+ state = _my_add_shape(_rngs[1], state)
331
+ if controllable:
332
+ # Forced controls
333
+ forced_control = jnp.array([[0, 1], [1, 0], [1, 1]])[jax.random.randint(_rngs[2], (), 0, 3)]
334
+ force_thruster, force_motor = forced_control[0], forced_control[1]
335
+
336
+ # Forced motor
337
+ state = jax.lax.cond(
338
+ force_motor,
339
+ lambda: _my_add_connected_shape(_rngs[3], state, force_rjoint=True), # force the rjoint
340
+ lambda: state,
341
+ )
342
+
343
+ # Forced thruster
344
+ state = jax.lax.cond(
345
+ force_thruster,
346
+ lambda: mutate_add_thruster(_rngs[4], state, env_params, static_env_params, ued_params),
347
+ lambda: state,
348
+ )
349
+ return permute_state(_rngs[7], state, static_env_params)
kinetix/environment/ued/mutators.py ADDED
@@ -0,0 +1,1157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ import math
3
+
4
+ import chex
5
+ import jax
6
+ import jax.numpy as jnp
7
+ from flax.serialization import to_state_dict
8
+ from jax2d.engine import (
9
+ PhysicsEngine,
10
+ calculate_collision_matrix,
11
+ calc_inverse_mass_polygon,
12
+ calc_inverse_mass_circle,
13
+ calc_inverse_inertia_circle,
14
+ calc_inverse_inertia_polygon,
15
+ recalculate_mass_and_inertia,
16
+ select_shape,
17
+ )
18
+ from jax2d.sim_state import SimState, RigidBody, Joint, Thruster
19
+ from jax2d.maths import rmat
20
+ from kinetix.environment.env_state import EnvParams, EnvState, StaticEnvParams
21
+ from kinetix.environment.ued.ued_state import UEDParams
22
+ from kinetix.environment.ued.util import (
23
+ count_roles,
24
+ is_space_for_joint,
25
+ make_velocities_zero,
26
+ sample_dimensions,
27
+ random_position_on_polygon,
28
+ random_position_on_circle,
29
+ get_role,
30
+ is_space_for_shape,
31
+ are_there_shapes_present,
32
+ )
33
+ from kinetix.util.saving import load_world_state_pickle
34
+ from flax import struct
35
+ from kinetix.environment.env import create_empty_env
36
+ from kinetix.environment.ued.util import make_do_dummy_step
37
+
38
+
39
+ @partial(jax.jit, static_argnums=(3, 4))
40
+ def mutate_add_shape(
41
+ rng,
42
+ state: EnvState,
43
+ params: EnvParams,
44
+ static_env_params: StaticEnvParams,
45
+ ued_params: UEDParams,
46
+ force_no_fixate: bool = False,
47
+ ):
48
+ def do_dummy(rng, state):
49
+ return state
50
+
51
+ def do_add(rng, state):
52
+ rng, _rng = jax.random.split(rng)
53
+ _rngs = jax.random.split(_rng, 9)
54
+
55
+ space_for_new_rect = state.polygon.active.astype(int).sum() < static_env_params.num_polygons
56
+ space_for_new_circle = state.circle.active.astype(int).sum() < static_env_params.num_circles
57
+
58
+ is_rect_p = jnp.array([space_for_new_rect * 1.0, space_for_new_circle * 1.0])
59
+ is_rect = jax.random.choice(_rngs[0], jnp.array([True, False], dtype=bool), p=is_rect_p)
60
+
61
+ rect_index = jnp.argmin(state.polygon.active)
62
+ circle_index = jnp.argmin(state.circle.active)
63
+
64
+ shape_role = get_role(_rngs[1], state, static_env_params)
65
+
66
+ max_shape_size = (
67
+ jnp.array([1.0, ued_params.goal_body_size_factor, ued_params.goal_body_size_factor, 1.0])[shape_role]
68
+ * ued_params.max_shape_size
69
+ )
70
+
71
+ vertices, half_dimensions, radius = sample_dimensions(
72
+ _rngs[2],
73
+ static_env_params,
74
+ is_rect,
75
+ ued_params,
76
+ max_shape_size=max_shape_size,
77
+ )
78
+ n_vertices = jax.lax.select(ued_params.generate_triangles, jax.random.choice(_rngs[3], jnp.array([3, 4])), 4)
79
+
80
+ largest = jnp.max(jnp.array([half_dimensions[0] * jnp.sqrt(2), half_dimensions[1] * jnp.sqrt(2), radius]))
81
+
82
+ screen_dim_world = (
83
+ static_env_params.screen_dim[0] / params.pixels_per_unit,
84
+ static_env_params.screen_dim[1] / params.pixels_per_unit,
85
+ )
86
+ min_x = largest
87
+ max_x = screen_dim_world[0] - largest
88
+ min_y = largest + 0.4
89
+ max_y = screen_dim_world[1] - largest
90
+
91
+ def _og_minmax():
92
+ return min_x, max_x, min_y, max_y
93
+
94
+ def _opposite_minmax():
95
+ return jax.lax.switch(
96
+ shape_role,
97
+ [
98
+ (lambda: (min_x, max_x, min_y, max_y)),
99
+ (lambda: (min_x, max_x - screen_dim_world[0] / 2, min_y, max_y)),
100
+ (lambda: (min_x + screen_dim_world[0] / 2, max_x, min_y, max_y)),
101
+ (lambda: (min_x, max_x, min_y, max_y)),
102
+ ],
103
+ )
104
+
105
+ min_x, max_x, min_y, max_y = jax.lax.cond(
106
+ jax.random.uniform(_rngs[4], shape=()) < ued_params.goal_body_opposide_side_chance,
107
+ _opposite_minmax,
108
+ _og_minmax,
109
+ )
110
+
111
+ position = jax.random.uniform(_rngs[5], shape=(2,)) * jnp.array(
112
+ [
113
+ max_x - min_x,
114
+ max_y - min_y,
115
+ ]
116
+ ) + jnp.array([min_x, min_y])
117
+
118
+ rotation = jax.random.uniform(_rngs[6], shape=()) * 2 * math.pi
119
+ velocity = jnp.array([0.0, 0.0])
120
+ angular_velocity = 0.0
121
+
122
+ density = 1.0
123
+ inverse_mass = jax.lax.select(
124
+ is_rect,
125
+ calc_inverse_mass_polygon(vertices, n_vertices, static_env_params, density)[0],
126
+ calc_inverse_mass_circle(radius, density),
127
+ )
128
+
129
+ inverse_inertia = jax.lax.select(
130
+ is_rect,
131
+ calc_inverse_inertia_polygon(vertices, n_vertices, static_env_params, density),
132
+ calc_inverse_inertia_circle(radius, density),
133
+ )
134
+
135
+ fixate_chance = ued_params.fixate_chance_min + (1.0 / inverse_mass) * ued_params.fixate_chance_scale
136
+ fixate_chance = jnp.minimum(fixate_chance, ued_params.fixate_chance_max)
137
+ is_fixated = jax.random.uniform(_rngs[7], shape=()) < fixate_chance
138
+ is_fixated &= ~force_no_fixate
139
+
140
+ inverse_mass *= 1 - is_fixated
141
+ inverse_inertia *= 1 - is_fixated
142
+
143
+ # We want to bias fixated shapes to starting nearer the bottom half of the screen
144
+ fixate_shape_bottom_bias = (
145
+ ued_params.fixate_shape_bottom_bias + ued_params.fixate_shape_bottom_bias_special_role * (shape_role != 0)
146
+ )
147
+ is_forcing_bottom = jax.random.uniform(_rngs[8]) < fixate_shape_bottom_bias
148
+
149
+
150
+ half_screen_height = (static_env_params.screen_dim[1] / params.pixels_per_unit) / 2.0
151
+ position = jax.lax.select(
152
+ is_fixated & is_forcing_bottom & (position[1] >= half_screen_height),
153
+ position.at[1].add(-half_screen_height),
154
+ position,
155
+ )
156
+
157
+ # This could be either a rect or a circle
158
+ new_rigid_body = RigidBody(
159
+ position=position,
160
+ velocity=velocity,
161
+ inverse_mass=inverse_mass,
162
+ inverse_inertia=inverse_inertia,
163
+ rotation=rotation,
164
+ angular_velocity=angular_velocity,
165
+ radius=radius,
166
+ active=True,
167
+ friction=1.0,
168
+ vertices=vertices,
169
+ n_vertices=n_vertices,
170
+ collision_mode=1,
171
+ restitution=0.0,
172
+ )
173
+
174
+ state = state.replace(
175
+ polygon=jax.tree.map(
176
+ lambda x, y: jax.lax.select(is_rect, y.at[rect_index].set(x), y), new_rigid_body, state.polygon
177
+ ),
178
+ circle=jax.tree.map(
179
+ lambda x, y: jax.lax.select(jnp.logical_not(is_rect), y.at[circle_index].set(x), y),
180
+ new_rigid_body,
181
+ state.circle,
182
+ ),
183
+ polygon_shape_roles=jax.lax.select(
184
+ is_rect,
185
+ state.polygon_shape_roles.at[rect_index].set(shape_role),
186
+ state.polygon_shape_roles,
187
+ ),
188
+ circle_shape_roles=jax.lax.select(
189
+ jnp.logical_not(is_rect),
190
+ state.circle_shape_roles.at[circle_index].set(shape_role),
191
+ state.circle_shape_roles,
192
+ ),
193
+ )
194
+ return recalculate_mass_and_inertia(state, static_env_params, state.polygon_densities, state.circle_densities)
195
+
196
+ return jax.lax.cond(is_space_for_shape(state), do_add, do_dummy, rng, state)
197
+
198
+
199
+ @partial(jax.jit, static_argnums=(3, 4))
200
+ def mutate_add_connected_shape(
201
+ rng,
202
+ state: EnvState,
203
+ params: EnvParams,
204
+ static_env_params: StaticEnvParams,
205
+ ued_params: UEDParams,
206
+ force_rjoint: bool = False,
207
+ ):
208
+ def do_dummy(rng, state):
209
+ return state, False
210
+
211
+ def do_add(rng, state):
212
+ rng, _rng = jax.random.split(rng)
213
+ _rngs = jax.random.split(_rng, 21)
214
+
215
+ # Select a random index amongst the currently active shapes.
216
+ p_rect = state.polygon.active.at[: static_env_params.num_static_fixated_polys].set(False)
217
+ p_circle = state.circle.active
218
+
219
+ p_rect = p_rect.astype(jnp.float32)
220
+ p_circle = p_circle.astype(jnp.float32)
221
+
222
+ p_rect *= (state.polygon.inverse_mass == 0) * ued_params.connect_to_fixated_prob_coeff + (
223
+ state.polygon.inverse_mass != 0
224
+ ) * 1.0
225
+ p_circle *= (state.circle.inverse_mass == 0) * ued_params.connect_to_fixated_prob_coeff + (
226
+ state.circle.inverse_mass != 0
227
+ ) * 1.0
228
+
229
+ # Bias based on number of existing connections
230
+ rect_connections = jnp.zeros(static_env_params.num_polygons)
231
+ circle_connections = jnp.zeros(static_env_params.num_circles)
232
+
233
+ rect_connections = rect_connections.at[state.joint.a_index].add(
234
+ jnp.ones(static_env_params.num_joints)
235
+ * state.joint.active
236
+ * (state.joint.a_index < static_env_params.num_polygons)
237
+ )
238
+ rect_connections = rect_connections.at[state.joint.b_index].add(
239
+ jnp.ones(static_env_params.num_joints)
240
+ * state.joint.active
241
+ * (state.joint.b_index < static_env_params.num_polygons)
242
+ )
243
+
244
+ circle_connections = circle_connections.at[state.joint.a_index - static_env_params.num_polygons].add(
245
+ jnp.ones(static_env_params.num_joints)
246
+ * state.joint.active
247
+ * (state.joint.a_index >= static_env_params.num_polygons)
248
+ )
249
+ circle_connections = circle_connections.at[state.joint.b_index - static_env_params.num_polygons].add(
250
+ jnp.ones(static_env_params.num_joints)
251
+ * state.joint.active
252
+ * (state.joint.b_index >= static_env_params.num_polygons)
253
+ )
254
+
255
+ # Rectangles can have up to 2 connections
256
+ p_rect *= (-rect_connections + 2.0) / 2.0
257
+ p_rect = jnp.maximum(p_rect, 0.0)
258
+ # Circles can have 1 connection
259
+ p_circle *= circle_connections == 0
260
+
261
+ # To sample a target rect/circle, we have to have at least one.
262
+ target_rect_p = jnp.array(
263
+ [
264
+ (state.polygon.active.astype(int).sum() > static_env_params.num_static_fixated_polys) * 1.0,
265
+ (state.circle.active.astype(int).sum() > 0) * 1.0,
266
+ ]
267
+ )
268
+
269
+ # Don't connect to a circle if no connection-free ones exist
270
+ target_rect_p = target_rect_p.at[1].mul(p_circle.sum() > 0)
271
+
272
+ space_for_new_rect = state.polygon.active.astype(int).sum() < static_env_params.num_polygons
273
+ space_for_new_circle = state.circle.active.astype(int).sum() < static_env_params.num_circles
274
+
275
+ is_target_rect = jax.random.choice(_rngs[0], jnp.array([True, False], dtype=bool), p=target_rect_p) | (
276
+ ~space_for_new_rect
277
+ )
278
+
279
+ is_rect_p = jnp.array([space_for_new_rect * 1.0, space_for_new_circle * 1.0])
280
+ is_rect = jax.random.choice(_rngs[1], jnp.array([True, False], dtype=bool), p=is_rect_p) | (
281
+ ~is_target_rect & space_for_new_rect
282
+ )
283
+
284
+ shape_index = jax.lax.select(
285
+ is_rect,
286
+ jnp.argmin(state.polygon.active),
287
+ jnp.argmin(state.circle.active),
288
+ )
289
+ unified_shape_index = shape_index + (~is_rect) * static_env_params.num_polygons
290
+
291
+ vertices, half_dimensions, radius = sample_dimensions(
292
+ _rngs[2], static_env_params, is_rect, ued_params, max_shape_size=ued_params.max_shape_size
293
+ )
294
+ n_vertices = jax.lax.select(ued_params.generate_triangles, jax.random.choice(_rngs[3], jnp.array([3, 4])), 4)
295
+
296
+ rotation = jax.random.uniform(_rngs[4], shape=()) * 2 * math.pi
297
+ velocity = jnp.array([0.0, 0.0])
298
+ angular_velocity = 0.0
299
+
300
+ density = 1.0
301
+ inverse_mass = jax.lax.select(
302
+ is_rect,
303
+ calc_inverse_mass_polygon(vertices, n_vertices, static_env_params, density)[0],
304
+ calc_inverse_mass_circle(radius, density),
305
+ )
306
+
307
+ inverse_inertia = jax.lax.select(
308
+ is_rect,
309
+ calc_inverse_inertia_polygon(vertices, n_vertices, static_env_params, density),
310
+ calc_inverse_inertia_circle(radius, density),
311
+ )
312
+
313
+ # Joint
314
+
315
+ current_num_rjoints = (jnp.logical_not(state.joint.is_fixed_joint) * state.joint.active).sum()
316
+ is_rjoint = jnp.logical_or(
317
+ jnp.logical_or(jax.random.uniform(_rngs[5]) < 0.5, force_rjoint),
318
+ current_num_rjoints < ued_params.min_rjoints_bias,
319
+ )
320
+
321
+ joint_index = jnp.argmin(state.joint.active)
322
+
323
+ local_joint_position_rect = random_position_on_polygon(_rngs[6], vertices, n_vertices, static_env_params)
324
+ local_joint_position_circle = random_position_on_circle(_rngs[7], radius, on_centre_chance=1.0)
325
+
326
+ local_joint_position = jax.lax.select(is_rect, local_joint_position_rect, local_joint_position_circle)
327
+
328
+ p_rect = jax.lax.select(p_rect.sum() == 0, state.polygon.active.astype(jnp.float32), p_rect)
329
+ p_circle = jax.lax.select(p_circle.sum() == 0, state.circle.active.astype(jnp.float32), p_circle)
330
+
331
+ target_index = jax.lax.select(
332
+ is_target_rect,
333
+ jax.random.choice(
334
+ _rngs[8],
335
+ jnp.arange(static_env_params.num_polygons),
336
+ p=p_rect,
337
+ ),
338
+ jax.random.choice(
339
+ _rngs[9],
340
+ jnp.arange(static_env_params.num_circles),
341
+ p=p_circle,
342
+ ),
343
+ )
344
+
345
+ unified_target_index = target_index + jnp.logical_not(is_target_rect) * static_env_params.num_polygons
346
+ target_shape = select_shape(state, unified_target_index, static_env_params)
347
+
348
+ target_joint_position_rect = random_position_on_polygon(
349
+ _rngs[10], state.polygon.vertices[target_index], state.polygon.n_vertices[target_index], static_env_params
350
+ )
351
+ target_joint_position_circle = random_position_on_circle(
352
+ _rngs[11], state.circle.radius[target_index], on_centre_chance=1.0
353
+ )
354
+
355
+ target_joint_position = jax.lax.select(is_target_rect, target_joint_position_rect, target_joint_position_circle)
356
+
357
+ # Calculate the world position of the new shape
358
+ # We know the rotation of the new shape. We also know the position of the current shape, which we want to remain fixed.
359
+ # Set `position` such that local_joint_position is the same as `target_joint_position`
360
+ global_joint_pos = target_shape.position + jnp.matmul(rmat(target_shape.rotation), target_joint_position)
361
+ position = global_joint_pos - jnp.matmul(rmat(rotation), local_joint_position)
362
+
363
+ _, pos_diff = calc_inverse_mass_polygon(vertices, n_vertices, static_env_params, density)
364
+ position = jax.lax.select(is_rect, position + pos_diff, position)
365
+ local_joint_position = jax.lax.select(is_rect, local_joint_position - pos_diff, local_joint_position)
366
+ vertices = jax.lax.select(is_rect, vertices - pos_diff[None], vertices)
367
+
368
+ target_role = jax.lax.select(
369
+ is_target_rect, state.polygon_shape_roles[target_index], state.circle_shape_roles[target_index]
370
+ )
371
+
372
+ # We cannot have role 1 and role 2 being connected.
373
+ p = jnp.array([1.0, 1.0, 1.0, 1.0])
374
+ # If role is 0, keep all probs at 1, otherwise set the target role's complement to 0 prob
375
+ # 3 - role turns 1 to 2 and 2 to 1
376
+ # If the target role is three, we set everything to zero except for the default
377
+ p = jax.lax.select(
378
+ target_role == 0,
379
+ p,
380
+ jax.lax.select(
381
+ target_role <= 2,
382
+ p.at[3 - target_role].set(False).at[3].set(False),
383
+ (p.at[2].set(False).at[1].set(False)),
384
+ ),
385
+ )
386
+
387
+ shape_role = get_role(_rngs[12], state, static_env_params, initial_p=p)
388
+
389
+ # This could be either a rect or a circle
390
+ new_rigid_body = RigidBody(
391
+ position=position,
392
+ velocity=velocity,
393
+ inverse_mass=inverse_mass,
394
+ inverse_inertia=inverse_inertia,
395
+ rotation=rotation,
396
+ angular_velocity=angular_velocity,
397
+ radius=radius,
398
+ active=True,
399
+ friction=1.0,
400
+ vertices=vertices,
401
+ n_vertices=n_vertices,
402
+ collision_mode=1,
403
+ restitution=0.0,
404
+ )
405
+
406
+ # Change the shape indices such that a_index is less than b_index
407
+ a_index = shape_index + (1 - is_rect) * static_env_params.num_polygons
408
+ b_index = target_index + (1 - is_target_rect) * static_env_params.num_polygons
409
+
410
+ should_swap = a_index > b_index
411
+ a_index, b_index, local_joint_position, target_joint_position, shape_a, shape_b = jax.lax.cond(
412
+ should_swap,
413
+ lambda x: (x[1], x[0], x[3], x[2], x[5], x[4]), # pairwise swap
414
+ lambda x: x,
415
+ (a_index, b_index, local_joint_position, target_joint_position, new_rigid_body, target_shape),
416
+ )
417
+
418
+ motor_on = jax.random.uniform(_rngs[13], shape=()) < ued_params.motor_on_chance
419
+ joint_colour = jax.random.randint(_rngs[14], shape=(), minval=0, maxval=static_env_params.num_motor_bindings)
420
+ joint_rotation = shape_b.rotation - shape_a.rotation
421
+
422
+ motor_speed = jax.random.uniform(
423
+ _rngs[15], shape=(), minval=ued_params.motor_min_speed, maxval=ued_params.motor_max_speed
424
+ )
425
+
426
+ motor_power = jax.random.uniform(
427
+ _rngs[16], shape=(), minval=ued_params.motor_min_power, maxval=ued_params.motor_max_power
428
+ )
429
+ wheel_power = jax.random.uniform(
430
+ _rngs[20], shape=(), minval=ued_params.motor_min_power, maxval=ued_params.wheel_max_power
431
+ )
432
+
433
+ # High-powered wheels break the physics engine - this is a temporary fix
434
+ motor_power = jax.lax.select(is_rect & is_target_rect, motor_power, wheel_power)
435
+
436
+ motor_has_joint_limits = jax.random.uniform(_rngs[17], shape=()) < ued_params.joint_limit_chance
437
+ motor_has_joint_limits &= is_rect & is_target_rect
438
+ joint_limit_min = (
439
+ jax.random.uniform(_rngs[18], shape=(), minval=-ued_params.joint_limit_max, maxval=0.0)
440
+ * motor_has_joint_limits
441
+ )
442
+ joint_limit_max = (
443
+ jax.random.uniform(_rngs[19], shape=(), minval=0.0, maxval=ued_params.joint_limit_max)
444
+ * motor_has_joint_limits
445
+ )
446
+
447
+ rjoint = Joint(
448
+ a_index=a_index,
449
+ b_index=b_index,
450
+ a_relative_pos=local_joint_position,
451
+ b_relative_pos=target_joint_position,
452
+ global_position=global_joint_pos,
453
+ active=True,
454
+ motor_speed=motor_speed,
455
+ motor_power=motor_power,
456
+ motor_on=motor_on,
457
+ # colour=joint_colour,
458
+ motor_has_joint_limits=motor_has_joint_limits,
459
+ min_rotation=joint_limit_min,
460
+ max_rotation=joint_limit_max,
461
+ is_fixed_joint=False,
462
+ rotation=0.0,
463
+ acc_impulse=jnp.zeros((2,), dtype=jnp.float32),
464
+ acc_r_impulse=jnp.zeros((), dtype=jnp.float32),
465
+ )
466
+
467
+ fjoint = Joint(
468
+ a_index=a_index,
469
+ b_index=b_index,
470
+ a_relative_pos=local_joint_position,
471
+ b_relative_pos=target_joint_position,
472
+ global_position=global_joint_pos,
473
+ active=True,
474
+ rotation=joint_rotation,
475
+ acc_impulse=jnp.zeros((2,), dtype=jnp.float32),
476
+ acc_r_impulse=jnp.zeros((), dtype=jnp.float32),
477
+ is_fixed_joint=True,
478
+ motor_has_joint_limits=False,
479
+ min_rotation=0.0,
480
+ max_rotation=0.0,
481
+ motor_on=False,
482
+ motor_power=0.0,
483
+ motor_speed=0.0,
484
+ )
485
+
486
+ state = state.replace(
487
+ polygon=jax.tree.map(
488
+ lambda x, y: jax.lax.select(is_rect, y.at[shape_index].set(x), y), new_rigid_body, state.polygon
489
+ ),
490
+ circle=jax.tree.map(
491
+ lambda x, y: jax.lax.select(jnp.logical_not(is_rect), y.at[shape_index].set(x), y),
492
+ new_rigid_body,
493
+ state.circle,
494
+ ),
495
+ joint=jax.tree.map(
496
+ lambda rj, fj, y: jax.lax.select(is_rjoint, y.at[joint_index].set(rj), y.at[joint_index].set(fj)),
497
+ rjoint,
498
+ fjoint,
499
+ state.joint,
500
+ ),
501
+ polygon_shape_roles=jax.lax.select(
502
+ is_rect,
503
+ state.polygon_shape_roles.at[shape_index].set(shape_role),
504
+ state.polygon_shape_roles,
505
+ ),
506
+ circle_shape_roles=jax.lax.select(
507
+ jnp.logical_not(is_rect),
508
+ state.circle_shape_roles.at[shape_index].set(shape_role),
509
+ state.circle_shape_roles,
510
+ ),
511
+ motor_bindings=state.motor_bindings.at[joint_index].set(joint_colour),
512
+ )
513
+
514
+ # We need the new collision matrix.
515
+ state = state.replace(collision_matrix=calculate_collision_matrix(static_env_params, state.joint))
516
+ state = recalculate_mass_and_inertia(state, static_env_params, state.polygon_densities, state.circle_densities)
517
+
518
+ # Was this a valid addition?
519
+ # We calculate whether (assuming the possiblity of 360 degree rotation around the joint)
520
+ # both shapes can be visible
521
+ # This is to remove the common degenerate pattern of connected shapes being fully inside each other
522
+ def _get_min_rect_dist(r_id, local_pos):
523
+ rect: RigidBody = jax.tree.map(lambda x: x[r_id], state.polygon)
524
+
525
+ half_width = (jnp.max(rect.vertices[:, 0]) - jnp.min(rect.vertices[:, 0])) / 2.0
526
+ half_height = (jnp.max(rect.vertices[:, 1]) - jnp.min(rect.vertices[:, 1])) / 2.0
527
+
528
+ dist_x = half_width - jnp.abs(local_pos[0])
529
+ dist_y = half_height - jnp.abs(local_pos[1])
530
+
531
+ return jnp.minimum(dist_x, dist_y)
532
+
533
+ def _get_max_rect_dist(r_id, local_pos):
534
+ rect: RigidBody = jax.tree.map(lambda x: x[r_id], state.polygon)
535
+
536
+ half_width = (jnp.max(rect.vertices[:, 0]) - jnp.min(rect.vertices[:, 0])) / 2.0
537
+ half_height = (jnp.max(rect.vertices[:, 1]) - jnp.min(rect.vertices[:, 1])) / 2.0
538
+
539
+ dist_x = jnp.maximum(
540
+ jnp.abs(half_width - local_pos[0]),
541
+ jnp.abs(-half_width - local_pos[0]),
542
+ )
543
+
544
+ dist_y = jnp.maximum(
545
+ jnp.abs(half_height - local_pos[1]),
546
+ jnp.abs(-half_height - local_pos[1]),
547
+ )
548
+
549
+ return jnp.sqrt(dist_x * dist_x + dist_y * dist_y)
550
+
551
+ def are_both_shapes_showing(idx1, idx2, local_pos1, local_pos2):
552
+ def _is_small_shape_showing(small_idx, big_idx, small_local_pos, big_local_pos):
553
+ small_is_poly = small_idx < static_env_params.num_polygons
554
+ big_is_poly = big_idx < static_env_params.num_polygons
555
+
556
+ # CC
557
+ cc_result = False
558
+
559
+ # CR
560
+ cr_r_dist = _get_min_rect_dist(big_idx, big_local_pos)
561
+ cr_result = (
562
+ cr_r_dist + ued_params.connect_visibility_min
563
+ < state.circle.radius[small_idx - static_env_params.num_polygons]
564
+ )
565
+
566
+ # RC
567
+ rc_r_dist = _get_max_rect_dist(small_idx, small_local_pos)
568
+ rc_result = (
569
+ rc_r_dist
570
+ > state.circle.radius[big_idx - static_env_params.num_polygons] + ued_params.connect_visibility_min
571
+ )
572
+
573
+ # RR
574
+ rr_small_dist = _get_max_rect_dist(small_idx, small_local_pos)
575
+ rr_big_dist = _get_min_rect_dist(big_idx, big_local_pos)
576
+ rr_result = rr_small_dist > rr_big_dist + ued_params.connect_visibility_min
577
+
578
+ # Select
579
+ return jax.lax.select(
580
+ small_is_poly,
581
+ jax.lax.select(big_is_poly, rr_result, rc_result),
582
+ jax.lax.select(big_is_poly, cr_result, cc_result),
583
+ )
584
+
585
+ # Are both shapes showing?
586
+ return _is_small_shape_showing(idx1, idx2, local_pos1, local_pos2) & _is_small_shape_showing(
587
+ idx2, idx1, local_pos2, local_pos1
588
+ )
589
+
590
+ valid = are_both_shapes_showing(
591
+ unified_shape_index, unified_target_index, local_joint_position, target_joint_position
592
+ )
593
+ return state, valid
594
+
595
+ # To add a connected shape, we must have both at least one existing shape and space
596
+ return jax.lax.cond(
597
+ is_space_for_shape(state) & are_there_shapes_present(state, static_env_params) & is_space_for_joint(state),
598
+ do_add,
599
+ do_dummy,
600
+ rng,
601
+ state,
602
+ )
603
+
604
+
605
+ @partial(jax.jit, static_argnums=(3, 4))
606
+ def mutate_add_connected_shape_proper(
607
+ rng,
608
+ state: EnvState,
609
+ params: EnvParams,
610
+ static_env_params: StaticEnvParams,
611
+ ued_params: UEDParams,
612
+ force_rjoint: bool = False,
613
+ ):
614
+ return mutate_add_connected_shape(rng, state, params, static_env_params, ued_params, force_rjoint=force_rjoint)[0]
615
+
616
+
617
+ @partial(jax.jit, static_argnums=(3, 4))
618
+ def mutate_remove_shape(
619
+ rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams
620
+ ):
621
+
622
+ can_remove_mask = (
623
+ jnp.concatenate([state.polygon.active, state.circle.active])
624
+ .at[: static_env_params.num_static_fixated_polys]
625
+ .set(False)
626
+ )
627
+
628
+ def dummy(rng, state):
629
+ return state
630
+
631
+ def do_remove(rng, state: EnvState):
632
+ rng, _rng = jax.random.split(rng)
633
+ rngs = jax.random.split(_rng, 2)
634
+ p = can_remove_mask.astype(jnp.float32)
635
+ index_to_remove = jax.random.choice(rngs[0], jnp.arange(can_remove_mask.shape[0]), p=p)
636
+ is_rect = index_to_remove < static_env_params.num_polygons
637
+ state = state.replace(
638
+ polygon=state.polygon.replace(
639
+ active=jax.lax.select(
640
+ is_rect, state.polygon.active.at[index_to_remove].set(False), state.polygon.active
641
+ )
642
+ ),
643
+ circle=state.circle.replace(
644
+ active=jax.lax.select(
645
+ jnp.logical_not(is_rect),
646
+ state.circle.active.at[index_to_remove - static_env_params.num_polygons].set(False),
647
+ state.circle.active,
648
+ )
649
+ ),
650
+ )
651
+ # We need to now remove any joints connected to this shape
652
+ joints_to_remove = (state.joint.a_index == index_to_remove) | (state.joint.b_index == index_to_remove)
653
+
654
+ thrusters_to_remove = state.thruster.object_index == index_to_remove
655
+
656
+ state = state.replace(
657
+ joint=state.joint.replace(active=jnp.where(joints_to_remove, False, state.joint.active)),
658
+ thruster=state.thruster.replace(active=jnp.where(thrusters_to_remove, False, state.thruster.active)),
659
+ )
660
+ # Now recalculate collision matrix
661
+ state = state.replace(collision_matrix=calculate_collision_matrix(static_env_params, state.joint))
662
+ return state
663
+
664
+ return jax.lax.cond(can_remove_mask.sum() > 0, do_remove, dummy, rng, state)
665
+
666
+
667
+ @partial(jax.jit, static_argnums=(3, 4))
668
+ def mutate_remove_joint(
669
+ rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams
670
+ ):
671
+ can_remove_mask = state.joint.active
672
+
673
+ def dummy(rng, state):
674
+ return state
675
+
676
+ def do_remove(rng, state):
677
+ rng, _rng = jax.random.split(rng)
678
+ rngs = jax.random.split(_rng, 2)
679
+ p = can_remove_mask.astype(jnp.float32)
680
+ index_to_remove = jax.random.choice(rngs[0], jnp.arange(can_remove_mask.shape[0]), p=p)
681
+ state = state.replace(joint=state.joint.replace(active=state.joint.active.at[index_to_remove].set(False)))
682
+ # Recalculate collision matrix.
683
+ state = state.replace(collision_matrix=calculate_collision_matrix(static_env_params, state.joint))
684
+ return state
685
+
686
+ return jax.lax.cond(can_remove_mask.sum() > 0, do_remove, dummy, rng, state)
687
+
688
+
689
+ @partial(jax.jit, static_argnums=(3, 4))
690
+ def mutate_swap_role(
691
+ rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams
692
+ ):
693
+ def _cr(*args):
694
+ return count_roles(*args, include_static_polys=False)
695
+
696
+ role_counts = jax.vmap(_cr, (None, None, 0))(state, static_env_params, jnp.arange(4))
697
+ are_there_multiple_roles = (role_counts > 0).sum() > 1
698
+
699
+ def dummy(rng, state):
700
+ return state
701
+
702
+ def do_swap(rng, state):
703
+ rng, _rng = jax.random.split(rng)
704
+ rngs = jax.random.split(_rng, 2)
705
+ all_roles = jnp.concatenate([state.polygon_shape_roles, state.circle_shape_roles])
706
+
707
+ p = (
708
+ (jnp.concatenate([state.polygon.active, state.circle.active]))
709
+ .astype(jnp.float32)
710
+ .at[: static_env_params.num_static_fixated_polys]
711
+ .set(0.0)
712
+ )
713
+ shape_idx_a = jax.random.choice(
714
+ rngs[0], jnp.arange(static_env_params.num_polygons + static_env_params.num_circles), p=p
715
+ )
716
+ role_a = all_roles[shape_idx_a]
717
+ p = jnp.where(all_roles == role_a, 0.0, p)
718
+ shape_idx_b = jax.random.choice(
719
+ rngs[1], jnp.arange(static_env_params.num_polygons + static_env_params.num_circles), p=p
720
+ )
721
+ role_b = all_roles[shape_idx_b]
722
+ role_a, role_b = role_b, role_a
723
+
724
+ for idx, role in [(shape_idx_a, role_a), (shape_idx_b, role_b)]:
725
+ is_rect = idx < static_env_params.num_polygons
726
+ state = state.replace(
727
+ polygon_shape_roles=jax.lax.select(
728
+ is_rect, state.polygon_shape_roles.at[idx].set(role), state.polygon_shape_roles
729
+ ),
730
+ circle_shape_roles=jax.lax.select(
731
+ jnp.logical_not(is_rect),
732
+ state.circle_shape_roles.at[idx - static_env_params.num_polygons].set(role),
733
+ state.circle_shape_roles,
734
+ ),
735
+ )
736
+ return state
737
+
738
+ return jax.lax.cond(are_there_multiple_roles, do_swap, dummy, rng, state)
739
+
740
+
741
+ @partial(jax.jit, static_argnums=(3, 4))
742
+ def mutate_toggle_fixture(
743
+ rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams
744
+ ):
745
+ can_toggle_mask = (
746
+ jnp.concatenate([state.polygon.active, state.circle.active])
747
+ .at[: static_env_params.num_static_fixated_polys]
748
+ .set(False)
749
+ )
750
+
751
+ def dummy(rng, state):
752
+ return state
753
+
754
+ def do_toggle(rng, state: EnvState):
755
+ rng, _rng = jax.random.split(rng)
756
+ rngs = jax.random.split(_rng, 2)
757
+ p = can_toggle_mask.astype(jnp.float32)
758
+ index_to_remove = jax.random.choice(rngs[0], jnp.arange(can_toggle_mask.shape[0]), p=p)
759
+ is_rect = index_to_remove < static_env_params.num_polygons
760
+ is_current_fixed = (
761
+ jax.lax.select(
762
+ is_rect,
763
+ state.polygon.inverse_inertia[index_to_remove],
764
+ state.circle.inverse_inertia[index_to_remove - static_env_params.num_polygons],
765
+ )
766
+ == 0.0
767
+ )
768
+
769
+ is_current_fixed = is_current_fixed * 1.0 # if it is fixed, we set it to 1.0 and recalc.
770
+ # If it is not fixed, this is 0.0, and it makes it fixed.
771
+
772
+ state = state.replace(
773
+ polygon=state.polygon.replace(
774
+ inverse_inertia=jax.lax.select(
775
+ is_rect,
776
+ state.polygon.inverse_inertia.at[index_to_remove].set(is_current_fixed),
777
+ state.polygon.inverse_inertia,
778
+ ),
779
+ inverse_mass=jax.lax.select(
780
+ is_rect,
781
+ state.polygon.inverse_mass.at[index_to_remove].set(is_current_fixed),
782
+ state.polygon.inverse_mass,
783
+ ),
784
+ ),
785
+ circle=state.circle.replace(
786
+ inverse_inertia=jax.lax.select(
787
+ jnp.logical_not(is_rect),
788
+ state.circle.inverse_inertia.at[index_to_remove - static_env_params.num_polygons].set(
789
+ is_current_fixed
790
+ ),
791
+ state.circle.inverse_inertia,
792
+ ),
793
+ inverse_mass=jax.lax.select(
794
+ jnp.logical_not(is_rect),
795
+ state.circle.inverse_mass.at[index_to_remove - static_env_params.num_polygons].set(
796
+ is_current_fixed
797
+ ),
798
+ state.circle.inverse_mass,
799
+ ),
800
+ ),
801
+ )
802
+
803
+ state = recalculate_mass_and_inertia(state, static_env_params, state.polygon_densities, state.circle_densities)
804
+ return state
805
+
806
+ return jax.lax.cond(can_toggle_mask.sum() > 0, do_toggle, dummy, rng, state)
807
+
808
+
809
+ @partial(jax.jit, static_argnums=(3, 4))
810
+ def mutate_add_thruster(
811
+ rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams
812
+ ):
813
+ is_fixated = jnp.concatenate([state.polygon.inverse_mass == 0, state.circle.inverse_mass == 0])
814
+ # is_fixated = jnp.zeros_like(is_fixated, dtype=bool)
815
+ is_active = jnp.concatenate([state.polygon.active, state.circle.active])
816
+ can_add_mask = is_active & (~is_fixated)
817
+ can_add_mask = jnp.logical_and(is_active, jnp.logical_not(is_fixated))
818
+
819
+ def dummy(rng, state):
820
+ return state
821
+
822
+ def do_add(rng, state: EnvState):
823
+ rng, _rng = jax.random.split(rng)
824
+ _rngs = jax.random.split(_rng, 10)
825
+ p = can_add_mask.astype(jnp.float32)
826
+ shape_index = jax.random.choice(_rngs[0], jnp.arange(can_add_mask.shape[0]), p=p)
827
+ thruster_idx = jnp.argmin(state.thruster.active)
828
+
829
+ shape = select_shape(state, shape_index, static_env_params)
830
+
831
+ position_to_add_thruster = jax.lax.select(
832
+ shape_index < static_env_params.num_polygons,
833
+ random_position_on_polygon(_rngs[1], shape.vertices, shape.n_vertices, static_env_params),
834
+ random_position_on_circle(_rngs[2], shape.radius, on_centre_chance=0.0),
835
+ )
836
+
837
+ direction_to_com = ((jax.random.uniform(_rngs[3]) > 0.5) * 2 - 1) * position_to_add_thruster
838
+ direction_to_com = jax.lax.select(
839
+ jnp.linalg.norm(direction_to_com) == 0.0, jnp.array([1.0, 0.0]), direction_to_com
840
+ )
841
+
842
+ thruster_angle = jax.lax.select(
843
+ jax.random.uniform(_rngs[4]) < ued_params.thruster_align_com_prob,
844
+ jnp.atan2(direction_to_com[1], direction_to_com[0]), # test this
845
+ jax.random.uniform(
846
+ _rngs[5],
847
+ (),
848
+ )
849
+ * 2
850
+ * jnp.pi,
851
+ )
852
+
853
+ thruster_power = jax.random.uniform(_rngs[6]) * 1.5 + 0.5
854
+
855
+ thruster = Thruster(
856
+ object_index=shape_index,
857
+ active=True,
858
+ relative_position=position_to_add_thruster, # jnp.array([0.0, 0.0]), # a bit of a hack but reasonable.
859
+ rotation=thruster_angle, # jax.random.choice(rngs[1], jnp.arange(4) * jnp.pi / 2),
860
+ power=1.0
861
+ / jax.lax.select(shape.inverse_mass == 0, 1.0, shape.inverse_mass)
862
+ * ued_params.thruster_power_multiplier
863
+ * thruster_power,
864
+ global_position=shape.position + jnp.matmul(rmat(shape.rotation), position_to_add_thruster),
865
+ )
866
+ thruster_colour = jax.random.randint(
867
+ _rngs[7], shape=(), minval=0, maxval=static_env_params.num_thruster_bindings
868
+ )
869
+
870
+ state = state.replace(
871
+ thruster=jax.tree_map(lambda y, x: y.at[thruster_idx].set(x), state.thruster, thruster),
872
+ thruster_bindings=state.thruster_bindings.at[thruster_idx].set(thruster_colour),
873
+ )
874
+
875
+ return state
876
+
877
+ return jax.lax.cond(
878
+ jnp.logical_and((can_add_mask.sum() > 0), (jnp.logical_not(state.thruster.active).sum() > 0)),
879
+ do_add,
880
+ dummy,
881
+ rng,
882
+ state,
883
+ )
884
+
885
+
886
+ @partial(jax.jit, static_argnums=(3, 4))
887
+ def mutate_change_gravity(
888
+ rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams
889
+ ):
890
+ rng, _rng = jax.random.split(rng)
891
+ rngs = jax.random.split(_rng, 2)
892
+ new_gravity = jax.lax.select(
893
+ jax.random.uniform(rngs[0]) < 0.5,
894
+ jnp.array([0.0, -9.8]),
895
+ jnp.array([0.0, jax.random.uniform(rngs[1], minval=-9.8, maxval=0)]),
896
+ )
897
+
898
+ return state.replace(gravity=new_gravity)
899
+
900
+
901
+ @partial(jax.jit, static_argnums=(3, 4))
902
+ def mutate_remove_thruster(
903
+ rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams
904
+ ):
905
+ are_there_thrusters = state.thruster.active
906
+
907
+ def dummy(rng, state):
908
+ return state
909
+
910
+ def do_remove(rng, state):
911
+ rng, _rng = jax.random.split(rng)
912
+ rngs = jax.random.split(_rng, 2)
913
+ p = are_there_thrusters.astype(jnp.float32)
914
+ thruster_idx = jax.random.choice(rngs[0], jnp.arange(are_there_thrusters.shape[0]), p=p)
915
+ return state.replace(thruster=state.thruster.replace(active=state.thruster.active.at[thruster_idx].set(False)))
916
+
917
+ return jax.lax.cond(are_there_thrusters.sum() > 0, do_remove, dummy, rng, state)
918
+
919
+
920
+ def make_mutate_change_shape_size(params, static_env_params):
921
+ do_dummy_step = make_do_dummy_step(params, static_env_params)
922
+
923
+ @partial(jax.jit, static_argnums=(3, 4))
924
+ def mutate_change_shape_size(
925
+ rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams
926
+ ):
927
+ shape_active = jnp.concatenate(
928
+ [state.polygon.active.at[: static_env_params.num_static_fixated_polys].set(False), state.circle.active]
929
+ )
930
+
931
+ def dummy(rng, state):
932
+ return state
933
+
934
+ def do_change(rng, state):
935
+ rng, _rng = jax.random.split(rng)
936
+ rngs = jax.random.split(_rng, 10)
937
+ p = shape_active.astype(jnp.float32)
938
+ shape_idx = jax.random.choice(rngs[0], jnp.arange(shape_active.shape[0]), p=p)
939
+ is_rect = shape_idx < static_env_params.num_polygons
940
+ vertices, _, radius = sample_dimensions(
941
+ rngs[1], static_env_params, is_rect, ued_params, max_shape_size=ued_params.max_shape_size
942
+ )
943
+
944
+ idx_new_top_left = jnp.argmin(vertices[:, 0] * 100 + vertices[:, 1])
945
+ idx_old_top_left = jnp.argmin(
946
+ state.polygon.vertices[shape_idx, :, 0] * 100 + state.polygon.vertices[shape_idx, :, 1]
947
+ )
948
+ scale_rect = (vertices[idx_new_top_left]) / (state.polygon.vertices[shape_idx, idx_old_top_left])
949
+ scale_circle = radius / state.circle.radius[shape_idx - static_env_params.num_polygons]
950
+ vertices = state.polygon.vertices[shape_idx] * scale_rect
951
+
952
+ scale = jax.lax.select(
953
+ is_rect,
954
+ scale_rect,
955
+ jnp.array([scale_circle, scale_circle]),
956
+ )
957
+
958
+ is_a = ((state.joint.a_index == shape_idx) & state.joint.active)[:, None]
959
+ is_b = ((state.joint.b_index == shape_idx) & state.joint.active)[:, None]
960
+ state = state.replace(
961
+ joint=state.joint.replace(
962
+ a_relative_pos=(state.joint.a_relative_pos * scale[None]) * is_a
963
+ + (1 - is_a) * state.joint.a_relative_pos,
964
+ b_relative_pos=(state.joint.b_relative_pos * scale[None]) * is_b
965
+ + (1 - is_b) * state.joint.b_relative_pos,
966
+ ),
967
+ polygon=state.polygon.replace(
968
+ vertices=jax.lax.select(
969
+ is_rect, state.polygon.vertices.at[shape_idx].set(vertices), state.polygon.vertices
970
+ ),
971
+ ),
972
+ circle=state.circle.replace(
973
+ radius=jax.lax.select(
974
+ jnp.logical_not(is_rect),
975
+ state.circle.radius.at[shape_idx - static_env_params.num_polygons].set(radius),
976
+ state.circle.radius,
977
+ )
978
+ ),
979
+ )
980
+
981
+ def _ss(state, _):
982
+ return do_dummy_step(state), None
983
+
984
+ state = jax.lax.scan(_ss, state, jnp.arange(5))[0]
985
+ return recalculate_mass_and_inertia(
986
+ state, static_env_params, state.polygon_densities, state.circle_densities
987
+ )
988
+
989
+ return jax.lax.cond(shape_active.sum() > 0, do_change, dummy, rng, state)
990
+
991
+ return mutate_change_shape_size
992
+
993
+
994
+ @partial(jax.jit, static_argnums=(3, 4))
995
+ def mutate_change_shape_location(
996
+ rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams
997
+ ):
998
+ shape_active = jnp.concatenate(
999
+ [state.polygon.active.at[: static_env_params.num_static_fixated_polys].set(False), state.circle.active]
1000
+ )
1001
+
1002
+ def dummy(rng, state):
1003
+ return state
1004
+
1005
+ def do_change(rng, state):
1006
+ rng, _rng = jax.random.split(rng)
1007
+ rngs = jax.random.split(_rng, 10)
1008
+ p = shape_active.astype(jnp.float32)
1009
+ shape_idx = jax.random.choice(rngs[0], jnp.arange(shape_active.shape[0]), p=p)
1010
+ delta_pos = jax.random.uniform(rngs[1], shape=(2,)) - 0.5 # [-0.5, 0.5]
1011
+
1012
+ positions = jnp.concatenate([state.polygon.position, state.circle.position])
1013
+
1014
+ mask_of_shape_locations_to_change = (
1015
+ (state.collision_matrix[shape_idx] == 0).at[: static_env_params.num_static_fixated_polys].set(False)
1016
+ )
1017
+ # check the new positions, but then maybe revert if any shape becomes out of bounds now.
1018
+ new_positions_tentative = positions * (
1019
+ 1 - mask_of_shape_locations_to_change[:, None]
1020
+ ) + mask_of_shape_locations_to_change[:, None] * (positions + delta_pos[None])
1021
+
1022
+ polys = state.polygon
1023
+ p_pos = new_positions_tentative[: static_env_params.num_polygons]
1024
+ c_pos = new_positions_tentative[static_env_params.num_polygons :] # state.circle.position
1025
+ rad = state.circle.radius
1026
+ rect_vertex_mask = jnp.arange(static_env_params.max_polygon_vertices)[None] < polys.n_vertices[:, None]
1027
+ rect_mask = polys.active.at[: static_env_params.num_static_fixated_polys].set(False)
1028
+ circ_mask = state.circle.active
1029
+ # check if new pos maybe goes out of bounds:
1030
+ min_x, max_x, min_y, max_y = (
1031
+ jnp.minimum(
1032
+ jnp.min(
1033
+ p_pos[:, 0] + jnp.min(polys.vertices[:, :, 0], where=rect_vertex_mask, initial=0, axis=1),
1034
+ where=rect_mask,
1035
+ initial=jnp.inf,
1036
+ ),
1037
+ jnp.min(c_pos[:, 0] - rad, where=circ_mask, initial=jnp.inf),
1038
+ ),
1039
+ jnp.maximum(
1040
+ jnp.max(
1041
+ p_pos[:, 0] + jnp.max(polys.vertices[:, :, 0], where=rect_vertex_mask, initial=0, axis=1),
1042
+ where=rect_mask,
1043
+ initial=-jnp.inf,
1044
+ ),
1045
+ jnp.max(c_pos[:, 0] + rad, where=circ_mask, initial=-jnp.inf),
1046
+ ),
1047
+ jnp.minimum(
1048
+ jnp.min(
1049
+ p_pos[:, 1] + jnp.min(polys.vertices[:, :, 1], where=rect_vertex_mask, initial=0, axis=1),
1050
+ where=rect_mask,
1051
+ initial=jnp.inf,
1052
+ ),
1053
+ jnp.min(c_pos[:, 1] - rad, where=circ_mask, initial=jnp.inf),
1054
+ ),
1055
+ jnp.maximum(
1056
+ jnp.max(
1057
+ p_pos[:, 1] + jnp.max(polys.vertices[:, :, 1], where=rect_vertex_mask, initial=0, axis=1),
1058
+ where=rect_mask,
1059
+ initial=-jnp.inf,
1060
+ ),
1061
+ jnp.max(c_pos[:, 1] + rad, where=circ_mask, initial=-jnp.inf),
1062
+ ),
1063
+ )
1064
+
1065
+ how_much_oob_x_left = jnp.maximum(0, 0 - min_x)
1066
+ how_much_oob_x_right = jnp.maximum(0, max_x - static_env_params.screen_dim[0] / params.pixels_per_unit)
1067
+ how_much_oob_y_down = jnp.maximum(0, 0.4 - min_y) # this is for the floor
1068
+ how_much_oob_y_up = jnp.maximum(0, max_y - static_env_params.screen_dim[1] / params.pixels_per_unit)
1069
+
1070
+ # correct by out of bounds factor
1071
+ positions = (
1072
+ new_positions_tentative
1073
+ + jnp.array(
1074
+ [
1075
+ how_much_oob_x_left - how_much_oob_x_right,
1076
+ how_much_oob_y_down - how_much_oob_y_up,
1077
+ ]
1078
+ )[None]
1079
+ * mask_of_shape_locations_to_change[:, None]
1080
+ )
1081
+
1082
+ state = state.replace(
1083
+ polygon=state.polygon.replace(
1084
+ position=positions[: static_env_params.num_polygons],
1085
+ ),
1086
+ circle=state.circle.replace(
1087
+ position=positions[static_env_params.num_polygons :],
1088
+ ),
1089
+ )
1090
+ return recalculate_mass_and_inertia(state, static_env_params, state.polygon_densities, state.circle_densities)
1091
+
1092
+ return jax.lax.cond(shape_active.sum() > 0, do_change, dummy, rng, state)
1093
+
1094
+
1095
+ def make_mutate_change_shape_rotation(params, static_env_params):
1096
+ do_dummy_step = make_do_dummy_step(params, static_env_params)
1097
+
1098
+ @partial(jax.jit, static_argnums=(3, 4))
1099
+ def mutate_change_shape_rotation(
1100
+ rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams
1101
+ ):
1102
+ shape_active = jnp.concatenate(
1103
+ [state.polygon.active.at[: static_env_params.num_static_fixated_polys].set(False), state.circle.active]
1104
+ )
1105
+
1106
+ def dummy(rng, state):
1107
+ return state
1108
+
1109
+ def do_change(rng, state):
1110
+ rng, _rng = jax.random.split(rng)
1111
+ rngs = jax.random.split(_rng, 10)
1112
+ p = shape_active.astype(jnp.float32)
1113
+ shape_idx = jax.random.choice(rngs[0], jnp.arange(shape_active.shape[0]), p=p)
1114
+ is_rect = shape_idx < static_env_params.num_polygons
1115
+
1116
+ rotation_delta = jax.random.uniform(rngs[1], shape=()) * math.pi / 2
1117
+
1118
+ has_fixed_joint_a = (state.joint.a_index == shape_idx) & state.joint.is_fixed_joint & state.joint.active
1119
+ has_fixed_joint_b = (state.joint.b_index == shape_idx) & state.joint.is_fixed_joint & state.joint.active
1120
+
1121
+ state = state.replace(
1122
+ joint=state.joint.replace(
1123
+ rotation=jax.lax.select(
1124
+ has_fixed_joint_a,
1125
+ state.joint.rotation - rotation_delta,
1126
+ jax.lax.select(
1127
+ has_fixed_joint_b,
1128
+ state.joint.rotation + rotation_delta,
1129
+ state.joint.rotation,
1130
+ ),
1131
+ )
1132
+ ),
1133
+ polygon=state.polygon.replace(
1134
+ rotation=jax.lax.select(
1135
+ is_rect, state.polygon.rotation.at[shape_idx].add(rotation_delta), state.polygon.rotation
1136
+ ),
1137
+ ),
1138
+ circle=state.circle.replace(
1139
+ rotation=jax.lax.select(
1140
+ jnp.logical_not(is_rect),
1141
+ state.circle.rotation.at[shape_idx - static_env_params.num_polygons].add(rotation_delta),
1142
+ state.circle.rotation,
1143
+ )
1144
+ ),
1145
+ )
1146
+
1147
+ def _ss(state, _):
1148
+ return do_dummy_step(state), None
1149
+
1150
+ state = jax.lax.scan(_ss, state, jnp.arange(5))[0]
1151
+ return recalculate_mass_and_inertia(
1152
+ state, static_env_params, state.polygon_densities, state.circle_densities
1153
+ )
1154
+
1155
+ return jax.lax.cond(shape_active.sum() > 0, do_change, dummy, rng, state)
1156
+
1157
+ return mutate_change_shape_rotation
kinetix/environment/ued/ued.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ import math
3
+ import os
4
+
5
+ import chex
6
+ import jax
7
+ import jax.numpy as jnp
8
+ from flax.serialization import to_state_dict
9
+ from jax2d.engine import (
10
+ calculate_collision_matrix,
11
+ calc_inverse_mass_polygon,
12
+ calc_inverse_mass_circle,
13
+ calc_inverse_inertia_circle,
14
+ calc_inverse_inertia_polygon,
15
+ recalculate_mass_and_inertia,
16
+ select_shape,
17
+ PhysicsEngine,
18
+ )
19
+ from jax2d.sim_state import SimState, RigidBody, Joint, Thruster
20
+ from jax2d.maths import rmat
21
+ from kinetix.environment.env_state import EnvParams, EnvState, StaticEnvParams
22
+ from kinetix.environment.ued.distributions import (
23
+ create_vmapped_filtered_distribution,
24
+ sample_kinetix_level,
25
+ )
26
+ from kinetix.environment.ued.mutators import (
27
+ make_mutate_change_shape_rotation,
28
+ make_mutate_change_shape_size,
29
+ mutate_add_connected_shape_proper,
30
+ mutate_add_shape,
31
+ mutate_add_connected_shape,
32
+ mutate_change_shape_location,
33
+ mutate_remove_joint,
34
+ mutate_remove_shape,
35
+ mutate_swap_role,
36
+ mutate_toggle_fixture,
37
+ mutate_add_thruster,
38
+ mutate_remove_thruster,
39
+ mutate_change_gravity,
40
+ )
41
+ from kinetix.environment.ued.ued_state import UEDParams
42
+ from kinetix.environment.utils import permute_pcg_state
43
+ from kinetix.pcg.pcg import env_state_to_pcg_state, sample_pcg_state
44
+ from kinetix.util.config import generate_ued_params_from_config, generate_params_from_config
45
+ from kinetix.util.saving import get_pcg_state_from_json, load_pcg_state_pickle, load_world_state_pickle, stack_list_of_pytrees, expand_env_state
46
+ from flax import struct
47
+ from kinetix.environment.env import create_empty_env
48
+ from kinetix.util.learning import BASE_DIR, general_eval, get_eval_levels
49
+
50
+
51
+ def make_mutate_env(static_env_params: StaticEnvParams, params: EnvParams, ued_params: UEDParams):
52
+ mutate_size = make_mutate_change_shape_size(params, static_env_params)
53
+ mutate_rot = make_mutate_change_shape_rotation(params, static_env_params)
54
+
55
+ def mutate_level(rng, level: EnvState, n=1):
56
+ def inner(carry: tuple[chex.PRNGKey, EnvState], _):
57
+ rng, level = carry
58
+ rng, _rng, _rng2 = jax.random.split(rng, 3)
59
+
60
+ any_rects_left = jnp.logical_not(level.polygon.active).sum() > 0
61
+ any_circles_left = jnp.logical_not(level.circle.active).sum() > 0
62
+ any_joints_left = jnp.logical_not(level.joint.active).sum() > 0
63
+ any_thrust_left = jnp.logical_not(level.thruster.active).sum() > 0
64
+ has_any_thursters = level.thruster.active.sum() > 0
65
+
66
+ can_do_add_shape = any_rects_left | any_circles_left
67
+ can_do_add_joint = can_do_add_shape & any_joints_left
68
+
69
+ all_mutations = [
70
+ mutate_add_shape,
71
+ mutate_add_connected_shape_proper,
72
+ mutate_remove_joint,
73
+ mutate_remove_shape,
74
+ mutate_swap_role,
75
+ mutate_add_thruster,
76
+ mutate_remove_thruster,
77
+ mutate_toggle_fixture,
78
+ mutate_size,
79
+ mutate_change_shape_location,
80
+ mutate_rot,
81
+ ]
82
+
83
+ def mypartial(f):
84
+ def inner(rng, level):
85
+ return f(rng, level, params, static_env_params, ued_params)
86
+
87
+ return inner
88
+
89
+ probs = jnp.array(
90
+ [
91
+ can_do_add_shape * 1.0,
92
+ can_do_add_joint * 1.0,
93
+ 0.0,
94
+ 0.0,
95
+ 1.0,
96
+ any_thrust_left * 1.0,
97
+ has_any_thursters * 1.0,
98
+ 0.1,
99
+ 1.0,
100
+ 1.0,
101
+ 1.0,
102
+ ]
103
+ )
104
+
105
+ all_mutations = [mypartial(i) for i in all_mutations]
106
+ index = jax.random.choice(_rng, jnp.arange(len(all_mutations)), (), p=probs)
107
+ level = jax.lax.switch(index, all_mutations, _rng2, level)
108
+
109
+ return (rng, level), None
110
+
111
+ (_, level), _ = jax.lax.scan(inner, (rng, level), None, length=n)
112
+ return level
113
+
114
+ return mutate_level
115
+
116
+
117
+ def make_create_eval_env():
118
+ eval_level1 = load_world_state_pickle("worlds/eval/eval_0610_car1")
119
+ eval_level2 = load_world_state_pickle("worlds/eval/eval_0610_car2")
120
+ eval_level3 = load_world_state_pickle("worlds/eval/eval_0628_ball_left")
121
+ eval_level4 = load_world_state_pickle("worlds/eval/eval_0628_ball_right")
122
+ eval_level5 = load_world_state_pickle("worlds/eval/eval_0628_hard_car_obstacle")
123
+ eval_level6 = load_world_state_pickle("worlds/eval/eval_0628_swingup")
124
+
125
+ def _create_eval_env(rng, env_params, static_env_params, index):
126
+ return jax.lax.switch(
127
+ index,
128
+ [
129
+ lambda: eval_level1,
130
+ lambda: eval_level2,
131
+ lambda: eval_level3,
132
+ lambda: eval_level4,
133
+ lambda: eval_level5,
134
+ lambda: eval_level6,
135
+ ],
136
+ )
137
+ return jax.tree.map(lambda x, y: jax.lax.select(index == 0, x, y), eval_level1, eval_level2)
138
+
139
+ return _create_eval_env
140
+
141
+
142
+ def make_reset_train_function_with_mutations(
143
+ engine: PhysicsEngine, env_params: EnvParams, static_env_params: StaticEnvParams, config, make_pcg_state=True
144
+ ):
145
+ ued_params = generate_ued_params_from_config(config)
146
+
147
+ def reset(rng):
148
+ inner = sample_kinetix_level(
149
+ rng, engine, env_params, static_env_params, ued_params, env_size_name=config["env_size_name"]
150
+ )
151
+
152
+ if make_pcg_state:
153
+ return env_state_to_pcg_state(inner)
154
+ else:
155
+ return inner
156
+
157
+ return reset
158
+
159
+
160
+ def make_vmapped_filtered_level_sampler(
161
+ level_sampler, env_params: EnvParams, static_env_params: StaticEnvParams, config, make_pcg_state, env
162
+ ):
163
+ ued_params = generate_ued_params_from_config(config)
164
+
165
+ def reset(rng, n_samples):
166
+ inner = create_vmapped_filtered_distribution(
167
+ rng,
168
+ level_sampler,
169
+ env_params,
170
+ static_env_params,
171
+ ued_params,
172
+ n_samples,
173
+ env,
174
+ config["filter_levels"],
175
+ config["level_filter_sample_ratio"],
176
+ config["env_size_name"],
177
+ config["level_filter_n_steps"],
178
+ )
179
+ if make_pcg_state:
180
+ return env_state_to_pcg_state(inner)
181
+ else:
182
+ return inner
183
+
184
+ return reset
185
+
186
+
187
+ def make_reset_train_function_with_list_of_levels(config, levels, static_env_params, make_pcg_state=True,
188
+ is_loading_train_levels=False):
189
+ assert len(levels) > 0, "Need to provide at least one level to train on"
190
+ if config["load_train_levels_legacy"]:
191
+ ls = [get_pcg_state_from_json(os.path.join(BASE_DIR, l + ("" if l.endswith(".json") else ".json"))) for l in levels]
192
+ v = stack_list_of_pytrees(ls)
193
+ elif is_loading_train_levels:
194
+ v = get_eval_levels(levels, static_env_params)
195
+ else:
196
+ _, static_env_params = generate_params_from_config(
197
+ config["eval_env_size_true"] | {"frame_skip": config["frame_skip"]}
198
+ )
199
+ v = get_eval_levels(levels, static_env_params)
200
+
201
+ def reset(rng):
202
+ rng, _rng, _rng2 = jax.random.split(rng, 3)
203
+ idx = jax.random.randint(_rng, (), 0, len(levels))
204
+ state_to_return = jax.tree.map(lambda x: x[idx], v)
205
+
206
+ if config["permute_state_during_training"]:
207
+ state_to_return = permute_pcg_state(rng, state_to_return, static_env_params)
208
+ if not make_pcg_state:
209
+ state_to_return = sample_pcg_state(_rng2, state_to_return, params=None, static_params=static_env_params)
210
+
211
+ return state_to_return
212
+
213
+ return reset
214
+
215
+
216
+ ALL_MUTATION_FNS = [
217
+ mutate_add_shape,
218
+ mutate_add_connected_shape,
219
+ mutate_remove_joint,
220
+ mutate_swap_role,
221
+ mutate_toggle_fixture,
222
+ mutate_add_thruster,
223
+ mutate_remove_thruster,
224
+ mutate_remove_shape,
225
+ mutate_change_gravity,
226
+ ]
227
+
228
+
229
+ def test_ued():
230
+ from kinetix.environment.env import create_empty_env
231
+
232
+ env_params = EnvParams()
233
+ static_env_params = StaticEnvParams()
234
+ ued_params = UEDParams()
235
+ rng = jax.random.PRNGKey(0)
236
+ rng, _rng = jax.random.split(rng)
237
+ state = create_empty_env(env_params, static_env_params)
238
+ state = mutate_add_shape(_rng, state, env_params, static_env_params, ued_params)
239
+ state = mutate_add_connected_shape(_rng, state, env_params, static_env_params, ued_params)
240
+ state = mutate_remove_shape(_rng, state, env_params, static_env_params, ued_params)
241
+ state = mutate_remove_joint(_rng, state, env_params, static_env_params, ued_params)
242
+ state = mutate_swap_role(_rng, state, env_params, static_env_params, ued_params)
243
+ state = mutate_toggle_fixture(_rng, state, env_params, static_env_params, ued_params)
244
+
245
+ print("Successfully did this")
246
+
247
+
248
+ if __name__ == "__main__":
249
+ test_ued()
kinetix/environment/ued/ued_state.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ from flax import struct
4
+
5
+
6
+ @struct.dataclass
7
+ class UEDParams:
8
+ max_shape_size: float = 1.0
9
+ goal_body_opposide_side_chance: float = 0.5
10
+ goal_body_size_factor: float = 1.0
11
+ min_rjoints_bias: int = 2
12
+
13
+ large_rect_dim_chance: float = 0.3
14
+ large_rect_dim_scale: float = 2.0
15
+
16
+ generate_triangles: bool = False
17
+ thruster_power_multiplier: float = 2.0
18
+
19
+ thruster_align_com_prob: float = 0.8
20
+
21
+ motor_on_chance: float = 0.8
22
+ motor_min_speed: float = 0.4
23
+ motor_max_speed: float = 3.0
24
+ motor_min_power: float = 1.0
25
+ motor_max_power: float = 3.0
26
+ wheel_max_power: float = 1.0
27
+
28
+ joint_limit_chance: float = 0.4
29
+ joint_limit_max: float = math.pi
30
+ joint_fixed_chance: float = 0.1
31
+
32
+ fixate_chance_min: float = 0.02
33
+ fixate_chance_max: float = 1.0
34
+ fixate_chance_scale: float = 4.0 # Fixation probability scales with size
35
+ fixate_shape_bottom_bias: float = 0.0
36
+ fixate_shape_bottom_bias_special_role: float = 0.6
37
+
38
+ circle_max_size_coeff: float = 0.8
39
+
40
+ connect_to_fixated_prob_coeff: float = 0.05
41
+ connect_visibility_min: float = 0.05
42
+ connect_no_visibility_bias: float = 10.0
43
+
44
+ add_shape_chance: float = 0.35
45
+ add_connected_shape_chance: float = 0.35
46
+ add_no_shape_chance: float = 0.3
47
+ add_thruster_chance: float = 0.3
48
+ add_shape_n_proposals: int = 8
49
+
50
+ floor_prob_normal: float = 0.9
51
+ floor_prob_green: float = 0.0
52
+ floor_prob_blue: float = 0.02
53
+ floor_prob_red: float = 0.08
kinetix/environment/ued/util.py ADDED
@@ -0,0 +1,358 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from functools import partial
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+
7
+ from jax2d.engine import PhysicsEngine, calculate_collision_matrix, recalculate_mass_and_inertia, select_shape
8
+ from jax2d.sim_state import RigidBody, Thruster
9
+ from kinetix.environment.env_state import EnvParams, EnvState, StaticEnvParams
10
+
11
+
12
+ def sample_dimensions(rng, static_env_params: StaticEnvParams, is_rect: bool, ued_params, max_shape_size=None):
13
+ if max_shape_size is None:
14
+ max_shape_size = static_env_params.max_shape_size
15
+ # Returns (half_dimensions, radius)
16
+
17
+ rng, _rng = jax.random.split(rng)
18
+ # Don't want overly small shapes
19
+ min_rect_size = 0.05
20
+ min_circle_size = 0.1
21
+ cap_rect = max_shape_size / 2.0 / jnp.sqrt(2.0)
22
+ cap_circ = max_shape_size / 2.0 * ued_params.circle_max_size_coeff
23
+ half_dimensions = (
24
+ jax.lax.select(is_rect, jax.random.uniform(_rng, shape=(2,)), jnp.zeros(2, dtype=jnp.float32))
25
+ * (cap_rect - min_rect_size)
26
+ + min_rect_size
27
+ )
28
+
29
+ rng, _rng, __rng = jax.random.split(rng, 3)
30
+ dim_scale = (
31
+ jnp.ones(2)
32
+ .at[jax.random.randint(_rng, shape=(), minval=0, maxval=2)]
33
+ .set(
34
+ jax.lax.select(
35
+ jax.random.uniform(__rng) < ued_params.large_rect_dim_chance, ued_params.large_rect_dim_scale, 1.0
36
+ )
37
+ )
38
+ )
39
+ half_dimensions *= dim_scale
40
+
41
+ vertices = jnp.array(
42
+ [
43
+ half_dimensions * jnp.array([1, 1]),
44
+ half_dimensions * jnp.array([1, -1]),
45
+ half_dimensions * jnp.array([-1, -1]),
46
+ half_dimensions * jnp.array([-1, 1]),
47
+ ]
48
+ )
49
+
50
+ rng, _rng = jax.random.split(rng)
51
+ radius = (
52
+ jax.lax.select(is_rect, jnp.zeros((), dtype=jnp.float32), jax.random.uniform(_rng, shape=()))
53
+ * (cap_circ - min_circle_size)
54
+ + min_circle_size
55
+ )
56
+ return vertices, half_dimensions, radius
57
+
58
+
59
+ def count_roles(state: EnvState, static_env_params: StaticEnvParams, role: int, include_static_polys=True) -> int:
60
+ active_to_use = state.polygon.active
61
+ if not include_static_polys:
62
+ active_to_use = active_to_use.at[: static_env_params.num_static_fixated_polys].set(False)
63
+ return ((state.polygon_shape_roles == role) * active_to_use).sum() + (
64
+ (state.circle_shape_roles == role) * state.circle.active
65
+ ).sum()
66
+
67
+
68
+ def random_position_on_triangle(rng, vertices):
69
+ verts = vertices[:3]
70
+ rng, _rng, _rng2 = jax.random.split(rng, 3)
71
+ f1 = jax.random.uniform(_rng)
72
+ f2 = jax.random.uniform(_rng2)
73
+ # https://www.reddit.com/r/godot/comments/mqp29g/how_do_i_get_a_random_position_inside_a_collision/
74
+ return verts[0] + jnp.sqrt(f1) * (-verts[0] + verts[1] + f2 * (verts[2] - verts[1]))
75
+
76
+
77
+ def random_position_on_rectangle(rng, vertices):
78
+ verts = vertices[:4]
79
+ rng, _rng, _rng2 = jax.random.split(rng, 3)
80
+ f1 = jax.random.uniform(_rng)
81
+ f2 = jax.random.uniform(_rng2)
82
+
83
+ min_x, max_x = jnp.min(verts[:, 0]), jnp.max(verts[:, 0])
84
+ min_y, max_y = jnp.min(verts[:, 1]), jnp.max(verts[:, 1])
85
+ random_x_pos = min_x + f1 * (max_x - min_x)
86
+ random_y_pos = min_y + f2 * (max_y - min_y)
87
+
88
+ return jnp.array([random_x_pos, random_y_pos])
89
+
90
+
91
+ def random_position_on_polygon(rng, vertices, n_vertices, static_env_params: StaticEnvParams):
92
+ assert static_env_params.max_polygon_vertices <= 4, "Only supports up to 4 vertices"
93
+ return jax.lax.select(
94
+ n_vertices <= 3, random_position_on_triangle(rng, vertices), random_position_on_rectangle(rng, vertices)
95
+ )
96
+
97
+
98
+ def random_position_on_circle(rng, radius, on_centre_chance):
99
+ rngs = jax.random.split(rng, 3)
100
+
101
+ on_centre = jax.random.uniform(rngs[0]) < on_centre_chance
102
+
103
+ local_joint_position_circle_theta = jax.random.uniform(rngs[1], shape=()) * 2 * math.pi
104
+ local_joint_position_circle_r = jax.random.uniform(rngs[2], shape=()) * radius
105
+ local_joint_position_circle = jnp.array(
106
+ [
107
+ local_joint_position_circle_r * jnp.cos(local_joint_position_circle_theta),
108
+ local_joint_position_circle_r * jnp.sin(local_joint_position_circle_theta),
109
+ ]
110
+ )
111
+
112
+ return jax.lax.select(on_centre, jnp.array([0.0, 0.0]), local_joint_position_circle)
113
+
114
+
115
+ def get_role(rng, state: EnvState, static_env_params: StaticEnvParams, initial_p=None) -> int:
116
+
117
+ if initial_p is None:
118
+ initial_p = jnp.array([1.0, 1.0, 1.0, 1.0])
119
+
120
+ needs_ball = count_roles(state, static_env_params, 1) == 0
121
+ needs_goal = count_roles(state, static_env_params, 2) == 0
122
+ needs_lava = count_roles(state, static_env_params, 3) == 0
123
+
124
+ # always put goal/ball first.
125
+ prob_of_something_else = (needs_ball == 0) & (needs_goal == 0)
126
+ p = initial_p * jnp.array(
127
+ [prob_of_something_else, needs_ball, needs_goal, prob_of_something_else * needs_lava / 3]
128
+ ) # This ensures we cannot more than one ball or goal.
129
+ return jax.random.choice(rng, jnp.array([0, 1, 2, 3]), p=p)
130
+
131
+
132
+ def is_space_for_shape(state: EnvState):
133
+ return jnp.logical_not(jnp.concatenate([state.polygon.active, state.circle.active])).sum() > 0
134
+
135
+
136
+ def is_space_for_joint(state: EnvState):
137
+ return jnp.logical_not(state.joint.active).sum() > 0
138
+
139
+
140
+ def are_there_shapes_present(state: EnvState, static_env_params: StaticEnvParams):
141
+ m = (
142
+ jnp.concatenate([state.polygon.active, state.circle.active])
143
+ .at[: static_env_params.num_static_fixated_polys]
144
+ .set(False)
145
+ )
146
+ return m.sum() > 0
147
+
148
+
149
+ @partial(jax.jit, static_argnums=(2, 9))
150
+ def add_rigidbody_to_state(
151
+ state: EnvState,
152
+ env_params: EnvParams,
153
+ static_env_params: StaticEnvParams,
154
+ position: jnp.ndarray,
155
+ vertices: jnp.ndarray,
156
+ n_vertices: int,
157
+ radius: float,
158
+ shape_role: int,
159
+ density: float = 1,
160
+ is_circle: bool = False,
161
+ ):
162
+
163
+ new_rigid_body = RigidBody(
164
+ position=position,
165
+ velocity=jnp.array([0.0, 0.0]),
166
+ inverse_mass=1.0,
167
+ inverse_inertia=1.0,
168
+ rotation=0.0,
169
+ angular_velocity=0.0,
170
+ radius=radius,
171
+ active=True,
172
+ friction=1.0,
173
+ vertices=vertices,
174
+ n_vertices=n_vertices,
175
+ collision_mode=1,
176
+ restitution=0.0,
177
+ )
178
+
179
+ if is_circle:
180
+ actives = state.circle.active
181
+ else:
182
+ actives = state.polygon.active
183
+
184
+ idx = jnp.argmin(actives)
185
+
186
+ def noop(state):
187
+ return state
188
+
189
+ def replace(state):
190
+ add_func = lambda all, new: all.at[idx].set(new)
191
+ if is_circle:
192
+ state = state.replace(
193
+ circle=jax.tree.map(add_func, state.circle, new_rigid_body),
194
+ circle_densities=state.circle_densities.at[idx].set(density),
195
+ circle_shape_roles=state.circle_shape_roles.at[idx].set(shape_role),
196
+ )
197
+ else:
198
+ state = state.replace(
199
+ polygon=jax.tree.map(add_func, state.polygon, new_rigid_body),
200
+ polygon_densities=state.polygon_densities.at[idx].set(density),
201
+ polygon_shape_roles=state.polygon_shape_roles.at[idx].set(shape_role),
202
+ )
203
+
204
+ state = state.replace(
205
+ collision_matrix=calculate_collision_matrix(static_env_params, state.joint),
206
+ )
207
+
208
+ state = recalculate_mass_and_inertia(state, static_env_params, state.polygon_densities, state.circle_densities)
209
+ return state
210
+
211
+ return jax.lax.cond(jnp.logical_not(actives).sum() > 0, replace, noop, state)
212
+
213
+
214
+ def rectangle_vertices(half_dim):
215
+ return jnp.array(
216
+ [
217
+ half_dim * jnp.array([1, 1]),
218
+ half_dim * jnp.array([1, -1]),
219
+ half_dim * jnp.array([-1, -1]),
220
+ half_dim * jnp.array([-1, 1]),
221
+ ]
222
+ )
223
+
224
+
225
+ # More Manual Control
226
+ @partial(jax.jit, static_argnums=(2,))
227
+ def add_rectangle_to_state(
228
+ state: EnvState,
229
+ env_params: EnvParams,
230
+ static_env_params: StaticEnvParams,
231
+ position: jnp.ndarray,
232
+ width: float,
233
+ height: float,
234
+ shape_role: int,
235
+ density: float = 1,
236
+ ):
237
+
238
+ return add_rigidbody_to_state(
239
+ state,
240
+ env_params,
241
+ static_env_params,
242
+ position,
243
+ rectangle_vertices(jnp.array([width, height]) / 2),
244
+ 4,
245
+ 0.0,
246
+ shape_role,
247
+ density,
248
+ is_circle=False,
249
+ )
250
+
251
+
252
+ @partial(jax.jit, static_argnums=(2,))
253
+ def add_circle_to_state(
254
+ state: EnvState,
255
+ env_params: EnvParams,
256
+ static_env_params: StaticEnvParams,
257
+ position: jnp.ndarray,
258
+ radius: float,
259
+ shape_role: int,
260
+ density: float = 1,
261
+ ):
262
+ return add_rigidbody_to_state(
263
+ state,
264
+ env_params,
265
+ static_env_params,
266
+ position,
267
+ jnp.array([0.0, 0.0]),
268
+ 0,
269
+ radius,
270
+ shape_role,
271
+ density,
272
+ is_circle=True,
273
+ )
274
+
275
+
276
+ @partial(jax.jit, static_argnums=(2,))
277
+ def add_thruster_to_object(
278
+ state: EnvState,
279
+ env_params: EnvParams,
280
+ static_env_params: StaticEnvParams,
281
+ shape_index: int,
282
+ rotation: float,
283
+ colour: int,
284
+ thruster_power_multiplier: float,
285
+ ):
286
+ def dummy(state):
287
+ return state
288
+
289
+ def do_add(state: EnvState):
290
+ thruster_idx = jnp.argmin(state.thruster.active)
291
+
292
+ shape = select_shape(state, shape_index, static_env_params)
293
+
294
+ thruster = Thruster(
295
+ object_index=shape_index,
296
+ active=True,
297
+ relative_position=jnp.array([0.0, 0.0]), # a bit of a hack but reasonable.
298
+ rotation=rotation,
299
+ power=1.0 / jax.lax.select(shape.inverse_mass == 0, 1.0, shape.inverse_mass) * thruster_power_multiplier,
300
+ global_position=select_shape(state, shape_index, static_env_params).position,
301
+ )
302
+
303
+ state = state.replace(
304
+ thruster=jax.tree_map(lambda y, x: y.at[thruster_idx].set(x), state.thruster, thruster),
305
+ thruster_bindings=state.thruster_bindings.at[thruster_idx].set(colour),
306
+ )
307
+
308
+ return state
309
+
310
+ return jax.lax.cond(
311
+ (select_shape(state, shape_index, static_env_params).active)
312
+ & (jnp.logical_not(state.thruster.active).sum() > 0),
313
+ do_add,
314
+ dummy,
315
+ state,
316
+ )
317
+
318
+
319
+ def make_velocities_zero(state: EnvState):
320
+ def inner(state):
321
+ return state.replace(
322
+ polygon=state.polygon.replace(
323
+ angular_velocity=state.polygon.angular_velocity * 0,
324
+ velocity=state.polygon.velocity * 0,
325
+ ),
326
+ circle=state.circle.replace(
327
+ angular_velocity=state.circle.angular_velocity * 0,
328
+ velocity=state.circle.velocity * 0,
329
+ ),
330
+ )
331
+
332
+ return inner(state)
333
+
334
+
335
+ def make_do_dummy_step(
336
+ params: EnvParams, static_sim_params: StaticEnvParams, zero_collisions=True, zero_velocities=True
337
+ ):
338
+ env = PhysicsEngine(static_sim_params)
339
+
340
+ @jax.jit
341
+ def _step_fn(state):
342
+ state, _ = env.step(state, params, jnp.zeros((static_sim_params.num_joints + static_sim_params.num_thrusters,)))
343
+ return state
344
+
345
+ def do_dummy_step(state: EnvState) -> EnvState:
346
+ rng = jax.random.PRNGKey(0)
347
+ og_col = state.collision_matrix
348
+ g = state.gravity
349
+ state = state.replace(
350
+ collision_matrix=state.collision_matrix & (not zero_collisions), gravity=state.gravity * 0
351
+ )
352
+ state = _step_fn(state)
353
+ state = state.replace(gravity=g, collision_matrix=og_col)
354
+ if zero_velocities:
355
+ state = make_velocities_zero(state)
356
+ return state
357
+
358
+ return do_dummy_step
kinetix/environment/utils.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import chex
2
+ import jax
3
+ from jax2d.engine import calculate_collision_matrix
4
+ from kinetix.environment.env_state import EnvState, StaticEnvParams
5
+ import jax.numpy as jnp
6
+
7
+ from kinetix.pcg.pcg_state import PCGState
8
+
9
+
10
+ def permute_state(rng: chex.PRNGKey, env_state: EnvState, static_env_params: StaticEnvParams):
11
+ idxs_circles = jnp.arange(static_env_params.num_circles)
12
+ idxs_polygons = jnp.arange(static_env_params.num_polygons)
13
+ idxs_joints = jnp.arange(static_env_params.num_joints)
14
+ idxs_thrusters = jnp.arange(static_env_params.num_thrusters)
15
+
16
+ rng, *_rngs = jax.random.split(rng, 5)
17
+ idxs_circles_permuted = jax.random.permutation(_rngs[0], idxs_circles, independent=True)
18
+ idxs_polygons_permuted = idxs_polygons.at[static_env_params.num_static_fixated_polys :].set(
19
+ jax.random.permutation(_rngs[1], idxs_polygons[static_env_params.num_static_fixated_polys :], independent=True)
20
+ )
21
+
22
+ idxs_joints_permuted = jax.random.permutation(_rngs[2], idxs_joints, independent=True)
23
+ idxs_thrusters_permuted = jax.random.permutation(_rngs[3], idxs_thrusters, independent=True)
24
+
25
+ combined = jnp.concatenate([idxs_polygons_permuted, idxs_circles_permuted + static_env_params.num_polygons])
26
+ # Change the ordering of the shapes, and also remember to change the indices associated with the joints
27
+
28
+ inverse_permutation = jnp.argsort(combined)
29
+
30
+ env_state = env_state.replace(
31
+ polygon_shape_roles=env_state.polygon_shape_roles[idxs_polygons_permuted],
32
+ circle_shape_roles=env_state.circle_shape_roles[idxs_circles_permuted],
33
+ polygon_highlighted=env_state.polygon_highlighted[idxs_polygons_permuted],
34
+ circle_highlighted=env_state.circle_highlighted[idxs_circles_permuted],
35
+ polygon_densities=env_state.polygon_densities[idxs_polygons_permuted],
36
+ circle_densities=env_state.circle_densities[idxs_circles_permuted],
37
+ polygon=jax.tree.map(lambda x: x[idxs_polygons_permuted], env_state.polygon),
38
+ circle=jax.tree.map(lambda x: x[idxs_circles_permuted], env_state.circle),
39
+ joint=env_state.joint.replace(
40
+ a_index=inverse_permutation[env_state.joint.a_index],
41
+ b_index=inverse_permutation[env_state.joint.b_index],
42
+ ),
43
+ thruster=env_state.thruster.replace(
44
+ object_index=inverse_permutation[env_state.thruster.object_index],
45
+ ),
46
+ )
47
+
48
+ # And now permute the thrusters and joints
49
+ env_state = env_state.replace(
50
+ thruster_bindings=env_state.thruster_bindings[idxs_thrusters_permuted],
51
+ motor_bindings=env_state.motor_bindings[idxs_joints_permuted],
52
+ motor_auto=env_state.motor_auto[idxs_joints_permuted],
53
+ joint=jax.tree.map(lambda x: x[idxs_joints_permuted], env_state.joint),
54
+ thruster=jax.tree.map(lambda x: x[idxs_thrusters_permuted], env_state.thruster),
55
+ )
56
+ # and collision matrix
57
+ env_state = env_state.replace(collision_matrix=calculate_collision_matrix(static_env_params, env_state.joint))
58
+ return env_state
59
+
60
+
61
+ def permute_pcg_state(rng: chex.PRNGKey, pcg_state: PCGState, static_env_params: StaticEnvParams):
62
+ return pcg_state.replace(
63
+ env_state=permute_state(rng, pcg_state.env_state, static_env_params),
64
+ env_state_max=permute_state(rng, pcg_state.env_state_max, static_env_params),
65
+ env_state_pcg_mask=jax.tree.map(lambda x: jnp.zeros_like(x, dtype=bool), pcg_state.env_state_pcg_mask),
66
+ )
kinetix/environment/wrappers.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ from chex._src.pytypes import PRNGKey
3
+ import jax
4
+ import jax.numpy as jnp
5
+ import chex
6
+ from jax.numpy import ndarray
7
+ import numpy as np
8
+ from flax import struct
9
+ from functools import partial
10
+ from typing import Callable, Dict, Optional, Tuple, Union, Any
11
+
12
+ from gymnax.environments import spaces, environment
13
+
14
+ from kinetix.environment.env_state import EnvParams, EnvState
15
+ from jaxued.environments import UnderspecifiedEnv
16
+
17
+
18
+ class UnderspecifiedEnvWrapper(UnderspecifiedEnv):
19
+ """Base class for Gymnax wrappers."""
20
+
21
+ def __init__(self, env):
22
+ self._env = env
23
+
24
+ # provide proxy access to regular attributes of wrapped object
25
+ def __getattr__(self, name):
26
+ return getattr(self._env, name)
27
+
28
+
29
+ class GymnaxWrapper(object):
30
+ """Base class for Gymnax wrappers."""
31
+
32
+ def __init__(self, env):
33
+ self._env = env
34
+
35
+ # provide proxy access to regular attributes of wrapped object
36
+ def __getattr__(self, name):
37
+ return getattr(self._env, name)
38
+
39
+
40
+ # From Here: https://github.com/DramaCow/jaxued/blob/main/src/jaxued/wrappers/autoreset.py
41
+ class AutoResetWrapper(UnderspecifiedEnvWrapper):
42
+ """
43
+ This is a wrapper around an `UnderspecifiedEnv`, allowing for the environment to be automatically reset upon completion of an episode. This behaviour is similar to the default Gymnax interface. The user can specify a callable `sample_level` that takes in a PRNGKey and returns a level.
44
+
45
+ Warning:
46
+ To maintain compliance with UnderspecifiedEnv interface, user can reset to an
47
+ arbitrary level. This includes levels outside the support of sample_level(). Consequently,
48
+ the tagged rng is defaulted to jax.random.PRNGKey(0). If your code relies on this, careful
49
+ attention may be required.
50
+ """
51
+
52
+ def __init__(self, env: UnderspecifiedEnv, sample_level: Callable[[chex.PRNGKey], EnvState]):
53
+ self._env = env
54
+ self.sample_level = sample_level
55
+
56
+ @property
57
+ def default_params(self) -> EnvParams:
58
+ return self._env.default_params
59
+
60
+ def reset_env(self, rng, params):
61
+ rng, rng_sample, rng_reset = jax.random.split(rng, 3)
62
+ state_to_reset_to = self.sample_level(rng_sample)
63
+ return self._env.reset_env_to_pcg_level(rng_reset, state_to_reset_to, params)
64
+
65
+ def step_env(
66
+ self,
67
+ rng: chex.PRNGKey,
68
+ state: EnvState,
69
+ action: Union[int, float],
70
+ params: EnvParams,
71
+ ) -> Tuple[chex.ArrayTree, EnvState, float, bool, dict]:
72
+
73
+ rng_reset, rng_step = jax.random.split(rng, 2)
74
+ obs_st, env_state_st, reward, done, info = self._env.step_env(rng_step, state, action, params)
75
+ obs_re, env_state_re = self.reset_env(rng_reset, params)
76
+
77
+ env_state = jax.tree_map(lambda x, y: jax.lax.select(done, x, y), env_state_re, env_state_st)
78
+ obs = jax.tree_map(lambda x, y: jax.lax.select(done, x, y), obs_re, obs_st)
79
+
80
+ return obs, env_state, reward, done, info
81
+
82
+ def reset_env_to_level(self, rng: chex.PRNGKey, level: EnvState, params: EnvParams) -> Tuple[Any, EnvState]:
83
+ # raise NotImplementedError("This method should not be called directly. Use reset instead.")
84
+ obs, env_state = self._env.reset_to_level(rng, level, params)
85
+ return obs, env_state
86
+
87
+ def action_space(self, params: EnvParams) -> Any:
88
+ return self._env.action_space(params)
89
+
90
+
91
+ class AutoReplayWrapper(UnderspecifiedEnv):
92
+ """
93
+ This wrapper replay the **same** level over and over again by resetting to the same level after each episode.
94
+ This is useful for training/rolling out multiple times on the same level.
95
+ """
96
+
97
+ def __init__(self, env: UnderspecifiedEnv):
98
+ self._env = env
99
+
100
+ @property
101
+ def default_params(self) -> EnvParams:
102
+ return self._env.default_params
103
+
104
+ def step_env(
105
+ self,
106
+ rng: chex.PRNGKey,
107
+ state: EnvState,
108
+ action: Union[int, float],
109
+ params: EnvParams,
110
+ ) -> Tuple[chex.ArrayTree, EnvState, float, bool, dict]:
111
+ rng_reset, rng_step = jax.random.split(rng)
112
+ obs_re, env_state_re = self._env.reset_to_level(rng_reset, state.level, params)
113
+ obs_st, env_state_st, reward, done, info = self._env.step_env(rng_step, state.env_state, action, params)
114
+ env_state = jax.tree_map(lambda x, y: jax.lax.select(done, x, y), env_state_re, env_state_st)
115
+ obs = jax.tree_map(lambda x, y: jax.lax.select(done, x, y), obs_re, obs_st)
116
+ return obs, state.replace(env_state=env_state), reward, done, info
117
+
118
+ def reset_env_to_level(self, rng: chex.PRNGKey, level: EnvState, params: EnvParams) -> Tuple[Any, EnvState]:
119
+ obs, env_state = self._env.reset_to_level(rng, level, params)
120
+ return obs, AutoReplayState(env_state=env_state, level=level)
121
+
122
+ def action_space(self, params: EnvParams) -> Any:
123
+ return self._env.action_space(params)
124
+
125
+
126
+ @struct.dataclass
127
+ class AutoReplayState:
128
+ env_state: EnvState
129
+ level: EnvState
130
+
131
+
132
+ class AutoReplayWrapper(UnderspecifiedEnvWrapper):
133
+ """
134
+ This wrapper replay the **same** level over and over again by resetting to the same level after each episode.
135
+ This is useful for training/rolling out multiple times on the same level.
136
+ """
137
+
138
+ def __init__(self, env: UnderspecifiedEnv):
139
+ self._env = env
140
+
141
+ @property
142
+ def default_params(self) -> EnvParams:
143
+ return self._env.default_params
144
+
145
+ def step_env(
146
+ self,
147
+ rng: chex.PRNGKey,
148
+ state: EnvState,
149
+ action: Union[int, float],
150
+ params: EnvParams,
151
+ ) -> Tuple[chex.ArrayTree, EnvState, float, bool, dict]:
152
+ rng_reset, rng_step = jax.random.split(rng)
153
+ obs_re, env_state_re = self._env.reset_to_level(rng_reset, state.level, params)
154
+ obs_st, env_state_st, reward, done, info = self._env.step_env(rng_step, state.env_state, action, params)
155
+ env_state = jax.tree_map(lambda x, y: jax.lax.select(done, x, y), env_state_re, env_state_st)
156
+ obs = jax.tree_map(lambda x, y: jax.lax.select(done, x, y), obs_re, obs_st)
157
+ return obs, state.replace(env_state=env_state), reward, done, info
158
+
159
+ def reset_env_to_level(self, rng: chex.PRNGKey, level: EnvState, params: EnvParams) -> Tuple[Any, EnvState]:
160
+ obs, env_state = self._env.reset_to_level(rng, level, params)
161
+ return obs, AutoReplayState(env_state=env_state, level=level)
162
+
163
+ def action_space(self, params: EnvParams) -> Any:
164
+ return self._env.action_space(params)
165
+
166
+
167
+ class UnderspecifiedToGymnaxWrapper(environment.Environment):
168
+ def __init__(self, env):
169
+ self._env = env
170
+
171
+ # provide proxy access to regular attributes of wrapped object
172
+ def __getattr__(self, name):
173
+ return getattr(self._env, name)
174
+
175
+ @property
176
+ def default_params(self) -> Any:
177
+ return self._env.default_params
178
+
179
+ def step_env(
180
+ self, key: jax.Array, state: Any, action: int | float | jax.Array | ndarray | np.bool_ | np.number, params: Any
181
+ ) -> Tuple[jax.Array | ndarray | np.bool_ | np.number | Any | Dict[Any, Any]]:
182
+ return self._env.step_env(key, state, action, params)
183
+
184
+ def reset_env(self, key: PRNGKey, params: Any) -> Tuple[PRNGKey | np.ndarray | np.bool_ | np.number | Any]:
185
+ return self._env.reset_env(key, params)
186
+
187
+ def action_space(self, params: Any):
188
+ return self._env.action_space(params)
189
+
190
+
191
+ class BatchEnvWrapper(GymnaxWrapper):
192
+ """Batches reset and step functions"""
193
+
194
+ def __init__(self, env, num_envs: int):
195
+ super().__init__(env)
196
+
197
+ self.num_envs = num_envs
198
+
199
+ self.reset_fn = jax.vmap(self._env.reset, in_axes=(0, None))
200
+ self.reset_to_level_fn = jax.vmap(self._env.reset_to_level, in_axes=(0, 0, None))
201
+ self.step_fn = jax.vmap(self._env.step, in_axes=(0, 0, 0, None))
202
+
203
+ @partial(jax.jit, static_argnums=(0, 2))
204
+ def reset(self, rng, params=None):
205
+ rng, _rng = jax.random.split(rng)
206
+ rngs = jax.random.split(_rng, self.num_envs)
207
+ obs, env_state = self.reset_fn(rngs, params)
208
+ return obs, env_state
209
+
210
+ @partial(jax.jit, static_argnums=(0, 3))
211
+ def reset_to_level(self, rng, level, params=None):
212
+ rng, _rng = jax.random.split(rng)
213
+ rngs = jax.random.split(_rng, self.num_envs)
214
+ obs, env_state = self.reset_to_level_fn(rngs, level, params)
215
+ return obs, env_state
216
+
217
+ @partial(jax.jit, static_argnums=(0, 4))
218
+ def step(self, rng, state, action, params=None):
219
+ rng, _rng = jax.random.split(rng)
220
+ rngs = jax.random.split(_rng, self.num_envs)
221
+ obs, state, reward, done, info = self.step_fn(rngs, state, action, params)
222
+
223
+ return obs, state, reward, done, info
224
+
225
+
226
+ @struct.dataclass
227
+ class DenseRewardState:
228
+ env_state: EnvState
229
+ last_distance: float = -1.0
230
+
231
+
232
+ class DenseRewardWrapper(GymnaxWrapper):
233
+ def __init__(self, env, dense_reward_scale: float = 1.0) -> None:
234
+ super().__init__(env)
235
+ self.dense_reward_scale = dense_reward_scale
236
+
237
+ def step(self, key, state, action: int, params=None):
238
+ obs, env_state, reward, done, info = self._env.step_env(key, state.env_state, action, params)
239
+ delta_dist = (
240
+ -(info["distance"] - state.last_distance) * params.dense_reward_scale
241
+ ) # if distance got less, then reward is positive
242
+
243
+ delta_dist = jnp.nan_to_num(delta_dist, nan=0.0, posinf=0.0, neginf=0.0)
244
+ reward = reward + jax.lax.select(
245
+ (state.last_distance == -1) | (self.dense_reward_scale == 0.0), 0.0, delta_dist * self.dense_reward_scale
246
+ )
247
+ return obs, DenseRewardState(env_state, info["distance"]), reward, done, info
248
+
249
+ def reset(self, rng, params=None):
250
+ obs, env_state = self._env.reset(rng, params)
251
+ return obs, DenseRewardState(env_state, -1.0)
252
+
253
+ def reset_to_level(self, rng, level, params=None):
254
+ obs, env_state = self._env.reset_to_level(rng, level, params)
255
+ return obs, DenseRewardState(env_state, -1.0)
256
+
257
+
258
+ @struct.dataclass
259
+ class LogEnvState:
260
+ env_state: Any
261
+ episode_returns: float
262
+ episode_lengths: int
263
+ returned_episode_returns: float
264
+ returned_episode_lengths: int
265
+ timestep: int
266
+
267
+
268
+ class LogWrapper(GymnaxWrapper):
269
+ """Log the episode returns and lengths."""
270
+
271
+ def __init__(self, env):
272
+ super().__init__(env)
273
+
274
+ @partial(jax.jit, static_argnums=(0, 2))
275
+ def reset(self, key: chex.PRNGKey, params=None):
276
+ obs, env_state = self._env.reset(key, params)
277
+ state = LogEnvState(env_state, 0.0, 0, 0.0, 0, 0)
278
+ return obs, state
279
+
280
+ def reset_to_level(self, key: chex.PRNGKey, level: EnvState, params=None):
281
+ obs, env_state = self._env.reset_to_level(key, level, params)
282
+ state = LogEnvState(env_state, 0.0, 0, 0.0, 0, 0)
283
+ return obs, state
284
+
285
+ @partial(jax.jit, static_argnums=(0, 4))
286
+ def step(
287
+ self,
288
+ key: chex.PRNGKey,
289
+ state,
290
+ action: Union[int, float],
291
+ params=None,
292
+ ):
293
+ obs, env_state, reward, done, info = self._env.step(key, state.env_state, action, params)
294
+ new_episode_return = state.episode_returns + reward
295
+ new_episode_length = state.episode_lengths + 1
296
+ state = LogEnvState(
297
+ env_state=env_state,
298
+ episode_returns=new_episode_return * (1 - done),
299
+ episode_lengths=new_episode_length * (1 - done),
300
+ returned_episode_returns=state.returned_episode_returns * (1 - done) + new_episode_return * done,
301
+ returned_episode_lengths=state.returned_episode_lengths * (1 - done) + new_episode_length * done,
302
+ timestep=state.timestep + 1,
303
+ )
304
+ info["returned_episode_returns"] = state.returned_episode_returns
305
+ info["returned_episode_lengths"] = state.returned_episode_lengths
306
+ info["returned_episode_solved"] = info["GoalR"]
307
+ info["timestep"] = state.timestep
308
+ info["returned_episode"] = done
309
+ return obs, state, reward, done, info
kinetix/models/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ actor_critic_old.py
2
+ gactor_gritic_old.py
kinetix/models/__init__.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from kinetix.models.actor_critic import (
2
+ ActorCriticPixelsRNN,
3
+ ActorCriticSymbolicRNN,
4
+ )
5
+ from kinetix.models.transformer_model import ActorCriticTransformer
6
+
7
+
8
+ def make_network_from_config(env, env_params, config, network_kws={}):
9
+
10
+ env_name = config["env_name"]
11
+ if "MultiDiscrete" in env_name:
12
+ action_mode = "multi_discrete"
13
+ elif "Discrete" in env_name:
14
+ action_mode = "discrete"
15
+ elif "Continuous" in env_name:
16
+ action_mode = "continuous"
17
+ elif "Hybrid" in env_name:
18
+ action_mode = "hybrid"
19
+ else:
20
+ raise ValueError(f"Unknown action mode for {env_name}")
21
+ action_dim = (
22
+ env.action_space(env_params).shape[0] if action_mode == "continuous" else env.action_space(env_params).n
23
+ )
24
+ if "hybrid_action_continuous_dim" not in network_kws:
25
+ network_kws["hybrid_action_continuous_dim"] = action_dim
26
+
27
+ if "multi_discrete_number_of_dims_per_distribution" not in network_kws:
28
+ num_joint_bindings = config["static_env_params"]["num_motor_bindings"]
29
+ num_thruster_bindings = config["static_env_params"]["num_thruster_bindings"]
30
+ network_kws["multi_discrete_number_of_dims_per_distribution"] = [3 for _ in range(num_joint_bindings)] + [
31
+ 2 for _ in range(num_thruster_bindings)
32
+ ]
33
+ network_kws["recurrent"] = config.get("recurrent_model", True)
34
+
35
+ if "Pixels" in env_name:
36
+ cls_to_use = ActorCriticPixelsRNN
37
+ elif "Symbolic" in env_name or "Blind" in env_name:
38
+ cls_to_use = ActorCriticSymbolicRNN
39
+
40
+ if "Entity" in env_name:
41
+ network = ActorCriticTransformer(
42
+ action_dim=action_dim,
43
+ fc_layer_width=config["fc_layer_width"],
44
+ fc_layer_depth=config["fc_layer_depth"],
45
+ action_mode=action_mode,
46
+ num_heads=config["num_heads"],
47
+ transformer_depth=config["transformer_depth"],
48
+ transformer_size=config["transformer_size"],
49
+ transformer_encoder_size=config["transformer_encoder_size"],
50
+ aggregate_mode=config["aggregate_mode"],
51
+ full_attention_mask=config["full_attention_mask"],
52
+ activation=config["activation"],
53
+ **network_kws,
54
+ )
55
+ else:
56
+ network = cls_to_use(
57
+ action_dim,
58
+ fc_layer_width=config["fc_layer_width"],
59
+ fc_layer_depth=config["fc_layer_depth"],
60
+ activation=config["activation"],
61
+ action_mode=action_mode,
62
+ **network_kws,
63
+ )
64
+
65
+ return network
kinetix/models/action_spaces.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Sequence
2
+ from chex import PRNGKey
3
+ import distrax
4
+ from flax import struct
5
+ import jax
6
+ import jax.numpy as jnp
7
+
8
+
9
+ @struct.dataclass
10
+ class HybridAction:
11
+ discrete: int
12
+ continuous: jnp.ndarray
13
+
14
+
15
+ class HybridActionDistribution(distrax.Distribution):
16
+ def __init__(self, discrete_logits, continuous_mu, continuous_sigma) -> None:
17
+ self.discrete = distrax.Categorical(logits=discrete_logits)
18
+ self.continuous = distrax.MultivariateNormalDiag(continuous_mu, continuous_sigma)
19
+
20
+ def _sample_n(self, rng: PRNGKey, n: int) -> Any:
21
+ rng, _rng, _rng2 = jax.random.split(rng, 3)
22
+ a = self.discrete._sample_n(_rng, n)
23
+ b = self.continuous._sample_n(_rng2, n)
24
+ return HybridAction(a, b)
25
+
26
+ def log_prob(self, value: Any):
27
+ a = self.discrete.log_prob(value.discrete)
28
+ b = self.continuous.log_prob(value.continuous)
29
+ return a + b # log probs, we add.
30
+
31
+ def entropy(self):
32
+ return self.discrete.entropy() + self.continuous.entropy()
33
+
34
+ def event_shape(self) -> Sequence[int]:
35
+ return ()
36
+
37
+
38
+ class MultiDiscreteActionDistribution(distrax.Distribution):
39
+ def __init__(self, flat_logits, number_of_dims_per_distribution) -> None:
40
+ self.distributions = []
41
+ total_dims = 0
42
+ for dims in number_of_dims_per_distribution:
43
+ self.distributions.append(distrax.Categorical(logits=flat_logits[..., total_dims : total_dims + dims]))
44
+ total_dims += dims
45
+
46
+ def _sample_n(self, key: PRNGKey, n: int) -> Any:
47
+ rngs = jax.random.split(key, len(self.distributions))
48
+ samples = [jnp.expand_dims(d._sample_n(rng, n), axis=-1) for rng, d in zip(rngs, self.distributions)]
49
+ return jnp.concatenate(samples, axis=-1)
50
+
51
+ def log_prob(self, value: Any):
52
+ return sum(d.log_prob(value[..., i]) for i, d in enumerate(self.distributions))
53
+
54
+ def entropy(self):
55
+ return sum(d.entropy() for d in self.distributions)
56
+
57
+ def event_shape(self) -> Sequence[int]:
58
+ return ()
kinetix/models/actor_critic.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import jax
3
+ import jax.numpy as jnp
4
+ import flax.linen as nn
5
+ import numpy as np
6
+ from flax.linen.initializers import constant, orthogonal
7
+ from typing import List, Sequence
8
+
9
+ import distrax
10
+
11
+ from kinetix.models.action_spaces import HybridActionDistribution, MultiDiscreteActionDistribution
12
+
13
+
14
+ class ScannedRNN(nn.Module):
15
+ @functools.partial(
16
+ nn.scan,
17
+ variable_broadcast="params",
18
+ in_axes=0,
19
+ out_axes=0,
20
+ split_rngs={"params": False},
21
+ )
22
+ @nn.compact
23
+ def __call__(self, carry, x):
24
+ """Applies the module."""
25
+ rnn_state = carry
26
+ ins, resets = x
27
+ rnn_state = jnp.where(
28
+ resets[:, np.newaxis],
29
+ self.initialize_carry(ins.shape[0], 256),
30
+ rnn_state,
31
+ )
32
+ new_rnn_state, y = nn.GRUCell(features=256)(rnn_state, ins)
33
+ return new_rnn_state, y
34
+
35
+ @staticmethod
36
+ def initialize_carry(batch_size, hidden_size=256):
37
+ # Use a dummy key since the default state init fn is just zeros.
38
+ cell = nn.GRUCell(features=256)
39
+ return cell.initialize_carry(jax.random.PRNGKey(0), (batch_size, hidden_size))
40
+
41
+
42
+ class GeneralActorCriticRNN(nn.Module):
43
+ action_dim: Sequence[int]
44
+ fc_layer_depth: int
45
+ fc_layer_width: int
46
+ action_mode: str # "continuous" or "discrete" or "hybrid"
47
+ hybrid_action_continuous_dim: int
48
+ multi_discrete_number_of_dims_per_distribution: List[int]
49
+ add_generator_embedding: bool = False
50
+ generator_embedding_number_of_timesteps: int = 10
51
+ recurrent: bool = False
52
+
53
+ # Given an embedding, return the action/values, since this is shared across all models.
54
+ @nn.compact
55
+ def __call__(self, hidden, obs, embedding, dones, activation):
56
+
57
+ if self.add_generator_embedding:
58
+ raise NotImplementedError()
59
+
60
+ if self.recurrent:
61
+ rnn_in = (embedding, dones)
62
+ hidden, embedding = ScannedRNN()(hidden, rnn_in)
63
+
64
+ actor_mean = embedding
65
+ critic = embedding
66
+ actor_mean_last = embedding
67
+ for _ in range(self.fc_layer_depth):
68
+ actor_mean = nn.Dense(
69
+ self.fc_layer_width,
70
+ kernel_init=orthogonal(np.sqrt(2)),
71
+ bias_init=constant(0.0),
72
+ )(actor_mean)
73
+ actor_mean = activation(actor_mean)
74
+
75
+ critic = nn.Dense(
76
+ self.fc_layer_width,
77
+ kernel_init=orthogonal(np.sqrt(2)),
78
+ bias_init=constant(0.0),
79
+ )(critic)
80
+ critic = activation(critic)
81
+
82
+ actor_mean_last = actor_mean
83
+ actor_mean = nn.Dense(self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0))(actor_mean)
84
+ if self.action_mode == "discrete":
85
+ pi = distrax.Categorical(logits=actor_mean)
86
+ elif self.action_mode == "continuous":
87
+ actor_logtstd = self.param("log_std", nn.initializers.zeros, (self.action_dim,))
88
+ pi = distrax.MultivariateNormalDiag(actor_mean, jnp.exp(actor_logtstd))
89
+ elif self.action_mode == "multi_discrete":
90
+ pi = MultiDiscreteActionDistribution(actor_mean, self.multi_discrete_number_of_dims_per_distribution)
91
+ else:
92
+ actor_mean_continuous = nn.Dense(
93
+ self.hybrid_action_continuous_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
94
+ )(actor_mean_last)
95
+ actor_mean_sigma = jnp.exp(
96
+ nn.Dense(self.hybrid_action_continuous_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0))(
97
+ actor_mean_last
98
+ )
99
+ )
100
+ pi = HybridActionDistribution(actor_mean, actor_mean_continuous, actor_mean_sigma)
101
+
102
+ critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(critic)
103
+ return hidden, pi, jnp.squeeze(critic, axis=-1)
104
+
105
+
106
+ class ActorCriticPixelsRNN(nn.Module):
107
+
108
+ action_dim: Sequence[int]
109
+ fc_layer_depth: int
110
+ fc_layer_width: int
111
+ action_mode: str
112
+ hybrid_action_continuous_dim: int
113
+ multi_discrete_number_of_dims_per_distribution: List[int]
114
+ activation: str
115
+ add_generator_embedding: bool = False
116
+ generator_embedding_number_of_timesteps: int = 10
117
+ recurrent: bool = True
118
+
119
+ @nn.compact
120
+ def __call__(self, hidden, x, **kwargs):
121
+ if self.activation == "relu":
122
+ activation = nn.relu
123
+ else:
124
+ activation = nn.tanh
125
+ og_obs, dones = x
126
+
127
+ if self.add_generator_embedding:
128
+ obs = og_obs.obs
129
+ else:
130
+ obs = og_obs
131
+
132
+ image = obs.image
133
+ global_info = obs.global_info
134
+
135
+ x = nn.Conv(features=16, kernel_size=(8, 8), strides=(4, 4))(image)
136
+ x = nn.relu(x)
137
+ x = nn.Conv(features=32, kernel_size=(4, 4), strides=(2, 2))(x)
138
+ x = nn.relu(x)
139
+ embedding = x.reshape(x.shape[0], x.shape[1], -1)
140
+
141
+ embedding = jnp.concatenate([embedding, global_info], axis=-1)
142
+
143
+ return GeneralActorCriticRNN(
144
+ action_dim=self.action_dim,
145
+ fc_layer_depth=self.fc_layer_depth,
146
+ fc_layer_width=self.fc_layer_width,
147
+ action_mode=self.action_mode,
148
+ hybrid_action_continuous_dim=self.hybrid_action_continuous_dim,
149
+ multi_discrete_number_of_dims_per_distribution=self.multi_discrete_number_of_dims_per_distribution,
150
+ add_generator_embedding=self.add_generator_embedding,
151
+ generator_embedding_number_of_timesteps=self.generator_embedding_number_of_timesteps,
152
+ recurrent=self.recurrent,
153
+ )(hidden, og_obs, embedding, dones, activation)
154
+
155
+ @staticmethod
156
+ def initialize_carry(batch_size, hidden_size=256):
157
+ return ScannedRNN.initialize_carry(batch_size, hidden_size)
158
+
159
+
160
+ class ActorCriticSymbolicRNN(nn.Module):
161
+ action_dim: Sequence[int]
162
+ fc_layer_width: int
163
+ action_mode: str
164
+ hybrid_action_continuous_dim: int
165
+ multi_discrete_number_of_dims_per_distribution: List[int]
166
+ fc_layer_depth: int
167
+ activation: str
168
+ add_generator_embedding: bool = False
169
+ generator_embedding_number_of_timesteps: int = 10
170
+ recurrent: bool = True
171
+
172
+ @nn.compact
173
+ def __call__(self, hidden, x):
174
+ if self.activation == "relu":
175
+ activation = nn.relu
176
+ else:
177
+ activation = nn.tanh
178
+
179
+ og_obs, dones = x
180
+ if self.add_generator_embedding:
181
+ obs = og_obs.obs
182
+ else:
183
+ obs = og_obs
184
+
185
+ embedding = nn.Dense(
186
+ self.fc_layer_width,
187
+ kernel_init=orthogonal(np.sqrt(2)),
188
+ bias_init=constant(0.0),
189
+ )(obs)
190
+ embedding = nn.relu(embedding)
191
+
192
+ return GeneralActorCriticRNN(
193
+ action_dim=self.action_dim,
194
+ fc_layer_depth=self.fc_layer_depth,
195
+ fc_layer_width=self.fc_layer_width,
196
+ action_mode=self.action_mode,
197
+ hybrid_action_continuous_dim=self.hybrid_action_continuous_dim,
198
+ multi_discrete_number_of_dims_per_distribution=self.multi_discrete_number_of_dims_per_distribution,
199
+ add_generator_embedding=self.add_generator_embedding,
200
+ generator_embedding_number_of_timesteps=self.generator_embedding_number_of_timesteps,
201
+ recurrent=self.recurrent,
202
+ )(hidden, og_obs, embedding, dones, activation)
203
+
204
+ @staticmethod
205
+ def initialize_carry(batch_size, hidden_size=256):
206
+ return ScannedRNN.initialize_carry(batch_size, hidden_size)
kinetix/models/rel_multi_head.py ADDED
@@ -0,0 +1,546 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The Flax Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ # CODE IS HEAVILY INSPIRED FROM https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/deprecated/transfo_xl/modeling_transfo_xl.py
17
+ # MOST OF THE TIME JUST A CONVERSION IN JAX
18
+
19
+
20
+ """Relative Attention HEAVILY INSPIRED FROM https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/deprecated/transfo_xl/modeling_transfo_xl.py
21
+ , flax attention, https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py#L143, most of the time just a flax/jax conversion """
22
+
23
+ import functools
24
+ from typing import Any, Callable, Optional, Tuple
25
+ from flax.linen.dtypes import promote_dtype
26
+
27
+ from flax.linen import initializers
28
+ from flax.linen.linear import default_kernel_init
29
+ from flax.linen.linear import DenseGeneral
30
+ from flax.linen.linear import DotGeneralT
31
+ from flax.linen.linear import PrecisionLike
32
+ from flax.linen.module import compact
33
+ from flax.linen.module import merge_param
34
+ from flax.linen.module import Module
35
+
36
+ import jax
37
+ from jax import lax
38
+ from jax import random
39
+ import jax.numpy as jnp
40
+
41
+ PRNGKey = Any
42
+ Shape = Tuple[int, ...]
43
+ Dtype = Any
44
+ Array = Any
45
+
46
+ roll_vmap = jax.vmap(jnp.roll, in_axes=(-2, 0, None), out_axes=-2)
47
+
48
+
49
+ def _rel_shift(x):
50
+ zero_pad_shape = x.shape[:-2] + (x.shape[-2], 1)
51
+ zero_pad = jnp.zeros(zero_pad_shape, dtype=x.dtype)
52
+ x_padded = jnp.concatenate([zero_pad, x], axis=-1)
53
+
54
+ x_padded_shape = x.shape[:-2] + (x.shape[-1] + 1, x.shape[-2])
55
+ x_padded = x_padded.reshape(x_padded_shape)
56
+ # x_padded=jnp.swapaxes(x_padded,0,1)
57
+
58
+ x = jnp.take(x_padded, jnp.arange(1, x_padded.shape[-2]), axis=-2).reshape(x.shape)
59
+
60
+ return x
61
+
62
+
63
+ def dot_product_attention_weights(
64
+ query: Array,
65
+ key: Array,
66
+ r_pos_embed,
67
+ r_r_bias,
68
+ r_w_bias,
69
+ bias: Optional[Array] = None,
70
+ mask: Optional[Array] = None,
71
+ broadcast_dropout: bool = True,
72
+ dropout_rng: Optional[PRNGKey] = None,
73
+ dropout_rate: float = 0.0,
74
+ deterministic: bool = False,
75
+ dtype: Optional[Dtype] = None,
76
+ precision: PrecisionLike = None,
77
+ ):
78
+ """Computes dot-product attention weights given query and key.
79
+
80
+ Used by :func:`dot_product_attention`, which is what you'll most likely use.
81
+ But if you want access to the attention weights for introspection, then
82
+ you can directly call this function and call einsum yourself.
83
+
84
+ Args:
85
+ query: queries for calculating attention with shape of
86
+ `[batch..., q_length, num_heads, qk_depth_per_head]`.
87
+ key: keys for calculating attention with shape of
88
+ `[batch..., kv_length, num_heads, qk_depth_per_head]`.
89
+ bias: bias for the attention weights. This should be broadcastable to the
90
+ shape `[batch..., num_heads, q_length, kv_length]`.
91
+ This can be used for incorporating causal masks, padding masks,
92
+ proximity bias, etc.
93
+ mask: mask for the attention weights. This should be broadcastable to the
94
+ shape `[batch..., num_heads, q_length, kv_length]`.
95
+ This can be used for incorporating causal masks.
96
+ Attention weights are masked out if their corresponding mask value
97
+ is `False`.
98
+ broadcast_dropout: bool: use a broadcasted dropout along batch dims.
99
+ dropout_rng: JAX PRNGKey: to be used for dropout
100
+ dropout_rate: dropout rate
101
+ deterministic: bool, deterministic or not (to apply dropout)
102
+ dtype: the dtype of the computation (default: infer from inputs and params)
103
+ precision: numerical precision of the computation see `jax.lax.Precision`
104
+ for details.
105
+
106
+ Returns:
107
+ Output of shape `[batch..., num_heads, q_length, kv_length]`.
108
+ """
109
+ query, key = promote_dtype(query, key, dtype=dtype)
110
+ dtype = query.dtype
111
+
112
+ assert query.ndim == key.ndim, "q, k must have same rank."
113
+ assert query.shape[:-3] == key.shape[:-3], "q, k batch dims must match."
114
+ assert query.shape[-2] == key.shape[-2], "q, k num_heads must match."
115
+ assert query.shape[-1] == key.shape[-1], "q, k depths must match."
116
+
117
+ # calculate attention matrix
118
+ depth = query.shape[-1]
119
+ # query = query
120
+
121
+ # attn weight shape is (batch..., num_heads, q_length, kv_length)
122
+ attn_weights = jnp.einsum("...qhd,...khd->...hqk", query + r_w_bias, key, precision=precision)
123
+
124
+ attn_weights_r = jnp.einsum("...qhd,khd->...hqk", query + r_r_bias, r_pos_embed, precision=precision)
125
+
126
+ attn_weights_r = roll_vmap(attn_weights_r, jnp.arange(0, query.shape[-3]) - (query.shape[-3] - 1), -1)
127
+ # attn_weights_r=_rel_shift(attn_weights_r)
128
+ attn_weights = attn_weights + attn_weights_r
129
+
130
+ attn_weights = attn_weights / jnp.sqrt(depth).astype(dtype)
131
+
132
+ # apply attention bias: masking, dropout, proximity bias, etc.
133
+ if bias is not None:
134
+ attn_weights = attn_weights + bias
135
+ # apply attention mask
136
+ if mask is not None:
137
+ big_neg = jnp.finfo(dtype).min
138
+ attn_weights = jnp.where(mask, attn_weights, big_neg)
139
+
140
+ # normalize the attention weights
141
+ attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
142
+
143
+ # apply attention dropout
144
+ if not deterministic and dropout_rate > 0.0:
145
+ keep_prob = 1.0 - dropout_rate
146
+ if broadcast_dropout:
147
+ # dropout is broadcast across the batch + head dimensions
148
+ dropout_shape = tuple([1] * (key.ndim - 2)) + attn_weights.shape[-2:]
149
+ keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape) # type: ignore
150
+ else:
151
+ keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape) # type: ignore
152
+ multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
153
+ attn_weights = attn_weights * multiplier
154
+
155
+ return attn_weights
156
+
157
+
158
+ def dot_product_attention(
159
+ query: Array,
160
+ key: Array,
161
+ value: Array,
162
+ r_pos_embed,
163
+ r_r_bias,
164
+ r_w_bias,
165
+ bias: Optional[Array] = None,
166
+ mask: Optional[Array] = None,
167
+ broadcast_dropout: bool = True,
168
+ dropout_rng: Optional[PRNGKey] = None,
169
+ dropout_rate: float = 0.0,
170
+ deterministic: bool = False,
171
+ dtype: Optional[Dtype] = None,
172
+ precision: PrecisionLike = None,
173
+ ):
174
+ """Computes dot-product attention given query, key, and value.
175
+
176
+ This is the core function for applying attention based on
177
+ https://arxiv.org/abs/1706.03762. It calculates the attention weights given
178
+ query and key and combines the values using the attention weights.
179
+
180
+ Note: query, key, value needn't have any batch dimensions.
181
+
182
+ Args:
183
+ query: queries for calculating attention with shape of
184
+ `[batch..., q_length, num_heads, qk_depth_per_head]`.
185
+ key: keys for calculating attention with shape of
186
+ `[batch..., kv_length, num_heads, qk_depth_per_head]`.
187
+ value: values to be used in attention with shape of
188
+ `[batch..., kv_length, num_heads, v_depth_per_head]`.
189
+ bias: bias for the attention weights. This should be broadcastable to the
190
+ shape `[batch..., num_heads, q_length, kv_length]`.
191
+ This can be used for incorporating causal masks, padding masks,
192
+ proximity bias, etc.
193
+ mask: mask for the attention weights. This should be broadcastable to the
194
+ shape `[batch..., num_heads, q_length, kv_length]`.
195
+ This can be used for incorporating causal masks.
196
+ Attention weights are masked out if their corresponding mask value
197
+ is `False`.
198
+ broadcast_dropout: bool: use a broadcasted dropout along batch dims.
199
+ dropout_rng: JAX PRNGKey: to be used for dropout
200
+ dropout_rate: dropout rate
201
+ deterministic: bool, deterministic or not (to apply dropout)
202
+ dtype: the dtype of the computation (default: infer from inputs)
203
+ precision: numerical precision of the computation see `jax.lax.Precision`
204
+ for details.
205
+
206
+ Returns:
207
+ Output of shape `[batch..., q_length, num_heads, v_depth_per_head]`.
208
+ """
209
+ query, key, value = promote_dtype(query, key, value, dtype=dtype)
210
+ dtype = query.dtype
211
+ assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank."
212
+ assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], "q, k, v batch dims must match."
213
+ assert query.shape[-2] == key.shape[-2] == value.shape[-2], "q, k, v num_heads must match."
214
+ assert key.shape[-3] == value.shape[-3], "k, v lengths must match."
215
+
216
+ # compute attention weights
217
+ attn_weights = dot_product_attention_weights(
218
+ query,
219
+ key,
220
+ r_pos_embed,
221
+ r_r_bias,
222
+ r_w_bias,
223
+ bias,
224
+ mask,
225
+ broadcast_dropout,
226
+ dropout_rng,
227
+ dropout_rate,
228
+ deterministic,
229
+ dtype,
230
+ precision,
231
+ )
232
+
233
+ # return weighted sum over values for each query position
234
+ return jnp.einsum("...hqk,...khd->...qhd", attn_weights, value, precision=precision)
235
+
236
+
237
+ class RelMultiHeadDotProductAttention(Module):
238
+ """Multi-head dot-product attention.
239
+
240
+ Attributes:
241
+ num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
242
+ should be divisible by the number of heads.
243
+ dtype: the dtype of the computation (default: infer from inputs and params)
244
+ param_dtype: the dtype passed to parameter initializers (default: float32)
245
+ qkv_features: dimension of the key, query, and value.
246
+ out_features: dimension of the last projection
247
+ broadcast_dropout: bool: use a broadcasted dropout along batch dims.
248
+ dropout_rate: dropout rate
249
+ deterministic: if false, the attention weight is masked randomly using
250
+ dropout, whereas if true, the attention weights are deterministic.
251
+ precision: numerical precision of the computation see `jax.lax.Precision`
252
+ for details.
253
+ kernel_init: initializer for the kernel of the Dense layers.
254
+ bias_init: initializer for the bias of the Dense layers.
255
+ use_bias: bool: whether pointwise QKVO dense transforms use bias.
256
+ attention_fn: dot_product_attention or compatible function. Accepts query,
257
+ key, value, and returns output of shape `[bs, dim1, dim2, ..., dimN,,
258
+ num_heads, value_channels]``
259
+ decode: whether to prepare and use an autoregressive cache.
260
+ """
261
+
262
+ num_heads: int
263
+ dtype: Optional[Dtype] = None
264
+ param_dtype: Dtype = jnp.float32
265
+ qkv_features: Optional[int] = None
266
+ out_features: Optional[int] = None
267
+ broadcast_dropout: bool = True
268
+ dropout_rate: float = 0.0
269
+ deterministic: Optional[bool] = None
270
+ precision: PrecisionLike = None
271
+ kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
272
+ bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros_init()
273
+ use_bias: bool = True
274
+ attention_fn: Callable[..., Array] = dot_product_attention
275
+ decode: bool = False
276
+ qkv_dot_general: DotGeneralT = lax.dot_general
277
+ out_dot_general: DotGeneralT = lax.dot_general
278
+
279
+ @compact
280
+ def __call__(
281
+ self,
282
+ inputs_q: Array,
283
+ inputs_kv: Array,
284
+ pos_embed: Array,
285
+ mask: Optional[Array] = None,
286
+ deterministic: Optional[bool] = None,
287
+ ):
288
+ """Applies multi-head dot product attention on the input data.
289
+
290
+ Projects the inputs into multi-headed query, key, and value vectors,
291
+ applies dot-product attention and project the results to an output vector.
292
+
293
+ Args:
294
+ inputs_q: input queries of shape
295
+ `[batch_sizes..., length, features]`.
296
+ inputs_kv: key/values of shape
297
+ `[batch_sizes..., length, features]`.
298
+ mask: attention mask of shape
299
+ `[batch_sizes..., num_heads, query_length, key/value_length]`.
300
+ Attention weights are masked out if their corresponding mask value
301
+ is `False`.
302
+ deterministic: if false, the attention weight is masked randomly
303
+ using dropout, whereas if true, the attention weights
304
+ are deterministic.
305
+
306
+ Returns:
307
+ output of shape `[batch_sizes..., length, features]`.
308
+ """
309
+ features = self.out_features or inputs_q.shape[-1]
310
+ qkv_features = self.qkv_features or inputs_q.shape[-1]
311
+ assert qkv_features % self.num_heads == 0, (
312
+ f"Memory dimension ({qkv_features}) must be divisible by number of" f" heads ({self.num_heads})."
313
+ )
314
+ head_dim = qkv_features // self.num_heads
315
+
316
+ dense = functools.partial(
317
+ DenseGeneral,
318
+ axis=-1,
319
+ dtype=self.dtype,
320
+ param_dtype=self.param_dtype,
321
+ features=(self.num_heads, head_dim),
322
+ kernel_init=self.kernel_init,
323
+ bias_init=self.bias_init,
324
+ use_bias=self.use_bias,
325
+ precision=self.precision,
326
+ dot_general=self.qkv_dot_general,
327
+ )
328
+ # project inputs_q to multi-headed q/k/v
329
+ # dimensions are then [batch..., length, n_heads, n_features_per_head]
330
+ query, key, value = (
331
+ dense(name="query")(inputs_q),
332
+ dense(name="key")(inputs_kv),
333
+ dense(name="value")(inputs_kv),
334
+ )
335
+
336
+ # different bc no bias
337
+ dense_relpos = functools.partial(
338
+ DenseGeneral,
339
+ axis=-1,
340
+ dtype=self.dtype,
341
+ param_dtype=self.param_dtype,
342
+ features=(self.num_heads, head_dim),
343
+ kernel_init=self.kernel_init,
344
+ use_bias=False,
345
+ precision=self.precision,
346
+ dot_general=self.qkv_dot_general,
347
+ )
348
+
349
+ r_pos_embed = dense_relpos(name="pos_embed_mat")(pos_embed)
350
+
351
+ r_r_bias = self.param("r_r_bias", self.bias_init, (self.num_heads, head_dim)) # Initialization function
352
+ r_w_bias = self.param("r_w_bias", self.bias_init, (self.num_heads, head_dim)) # Initialization function
353
+
354
+ # During fast autoregressive decoding, we feed one position at a time,
355
+ # and cache the keys and values step by step.
356
+ if self.decode:
357
+ # detect if we're initializing by absence of existing cache data.
358
+ is_initialized = self.has_variable("cache", "cached_key")
359
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
360
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
361
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
362
+ if is_initialized:
363
+ (
364
+ *batch_dims,
365
+ max_length,
366
+ num_heads,
367
+ depth_per_head,
368
+ ) = cached_key.value.shape
369
+ # shape check of cached keys against query input
370
+ expected_shape = tuple(batch_dims) + (1, num_heads, depth_per_head)
371
+ if expected_shape != query.shape:
372
+ raise ValueError(
373
+ "Autoregressive cache shape error, "
374
+ "expected query shape %s instead got %s." % (expected_shape, query.shape)
375
+ )
376
+ # update key, value caches with our new 1d spatial slices
377
+ cur_index = cache_index.value
378
+ indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
379
+ key = lax.dynamic_update_slice(cached_key.value, key, indices)
380
+ value = lax.dynamic_update_slice(cached_value.value, value, indices)
381
+ cached_key.value = key
382
+ cached_value.value = value
383
+ cache_index.value = cache_index.value + 1
384
+ # causal mask for cached decoder self-attention:
385
+ # our single query position should only attend to those key
386
+ # positions that have already been generated and cached,
387
+ # not the remaining zero elements.
388
+ mask = combine_masks(
389
+ mask,
390
+ jnp.broadcast_to(
391
+ jnp.arange(max_length) <= cur_index,
392
+ tuple(batch_dims) + (1, 1, max_length),
393
+ ),
394
+ )
395
+
396
+ dropout_rng = None
397
+ if self.dropout_rate > 0.0: # Require `deterministic` only if using dropout.
398
+ m_deterministic = merge_param("deterministic", self.deterministic, deterministic)
399
+ if not m_deterministic:
400
+ dropout_rng = self.make_rng("dropout")
401
+ else:
402
+ m_deterministic = True
403
+
404
+ # apply attention
405
+ x = self.attention_fn(
406
+ query,
407
+ key,
408
+ value,
409
+ r_pos_embed,
410
+ r_r_bias,
411
+ r_w_bias,
412
+ mask=mask,
413
+ dropout_rng=dropout_rng,
414
+ dropout_rate=self.dropout_rate,
415
+ broadcast_dropout=self.broadcast_dropout,
416
+ deterministic=m_deterministic,
417
+ dtype=self.dtype,
418
+ precision=self.precision,
419
+ ) # pytype: disable=wrong-keyword-args
420
+ # back to the original inputs dimensions
421
+ out = DenseGeneral(
422
+ features=features,
423
+ axis=(-2, -1),
424
+ kernel_init=self.kernel_init,
425
+ bias_init=self.bias_init,
426
+ use_bias=self.use_bias,
427
+ dtype=self.dtype,
428
+ param_dtype=self.param_dtype,
429
+ precision=self.precision,
430
+ dot_general=self.out_dot_general,
431
+ name="out", # type: ignore[call-arg]
432
+ )(x)
433
+ return out
434
+
435
+
436
+ class SelfAttention(RelMultiHeadDotProductAttention):
437
+ """Self-attention special case of multi-head dot-product attention."""
438
+
439
+ @compact
440
+ def __call__( # type: ignore
441
+ self,
442
+ inputs_q: Array,
443
+ mask: Optional[Array] = None,
444
+ deterministic: Optional[bool] = None,
445
+ ):
446
+ """Applies multi-head dot product self-attention on the input data.
447
+
448
+ Projects the inputs into multi-headed query, key, and value vectors,
449
+ applies dot-product attention and project the results to an output vector.
450
+
451
+ Args:
452
+ inputs_q: input queries of shape
453
+ `[batch_sizes..., length, features]`.
454
+ mask: attention mask of shape
455
+ `[batch_sizes..., num_heads, query_length, key/value_length]`.
456
+ Attention weights are masked out if their corresponding mask value
457
+ is `False`.
458
+ deterministic: if false, the attention weight is masked randomly
459
+ using dropout, whereas if true, the attention weights
460
+ are deterministic.
461
+
462
+ Returns:
463
+ output of shape `[batch_sizes..., length, features]`.
464
+ """
465
+ return super().__call__(inputs_q, inputs_q, mask, deterministic=deterministic)
466
+
467
+
468
+ # mask-making utility functions
469
+
470
+
471
+ def make_attention_mask(
472
+ query_input: Array,
473
+ key_input: Array,
474
+ pairwise_fn: Callable[..., Any] = jnp.multiply,
475
+ extra_batch_dims: int = 0,
476
+ dtype: Dtype = jnp.float32,
477
+ ):
478
+ """Mask-making helper for attention weights.
479
+
480
+ In case of 1d inputs (i.e., `[batch..., len_q]`, `[batch..., len_kv]`, the
481
+ attention weights will be `[batch..., heads, len_q, len_kv]` and this
482
+ function will produce `[batch..., 1, len_q, len_kv]`.
483
+
484
+ Args:
485
+ query_input: a batched, flat input of query_length size
486
+ key_input: a batched, flat input of key_length size
487
+ pairwise_fn: broadcasting elementwise comparison function
488
+ extra_batch_dims: number of extra batch dims to add singleton
489
+ axes for, none by default
490
+ dtype: mask return dtype
491
+
492
+ Returns:
493
+ A `[batch..., 1, len_q, len_kv]` shaped mask for 1d attention.
494
+ """
495
+ mask = pairwise_fn(jnp.expand_dims(query_input, axis=-1), jnp.expand_dims(key_input, axis=-2))
496
+ mask = jnp.expand_dims(mask, axis=-3)
497
+ mask = jnp.expand_dims(mask, axis=tuple(range(extra_batch_dims)))
498
+ return mask.astype(dtype)
499
+
500
+
501
+ def make_causal_mask(x: Array, extra_batch_dims: int = 0, dtype: Dtype = jnp.float32) -> Array:
502
+ """Make a causal mask for self-attention.
503
+
504
+ In case of 1d inputs (i.e., `[batch..., len]`, the self-attention weights
505
+ will be `[batch..., heads, len, len]` and this function will produce a
506
+ causal mask of shape `[batch..., 1, len, len]`.
507
+
508
+ Args:
509
+ x: input array of shape `[batch..., len]`
510
+ extra_batch_dims: number of batch dims to add singleton axes for,
511
+ none by default
512
+ dtype: mask return dtype
513
+
514
+ Returns:
515
+ A `[batch..., 1, len, len]` shaped causal mask for 1d attention.
516
+ """
517
+ idxs = jnp.broadcast_to(jnp.arange(x.shape[-1], dtype=jnp.int32), x.shape)
518
+ return make_attention_mask(
519
+ idxs,
520
+ idxs,
521
+ jnp.greater_equal,
522
+ extra_batch_dims=extra_batch_dims,
523
+ dtype=dtype,
524
+ )
525
+
526
+
527
+ def combine_masks(*masks: Optional[Array], dtype: Dtype = jnp.float32) -> Array:
528
+ """Combine attention masks.
529
+
530
+ Args:
531
+ *masks: set of attention mask arguments to combine, some can be None.
532
+ dtype: dtype for the returned mask.
533
+
534
+ Returns:
535
+ Combined mask, reduced by logical and, returns None if no masks given.
536
+ """
537
+ masks_list = [m for m in masks if m is not None]
538
+ if not masks_list:
539
+ return None
540
+ assert all(
541
+ map(lambda x: x.ndim == masks_list[0].ndim, masks_list)
542
+ ), f"masks must have same rank: {tuple(map(lambda x: x.ndim, masks_list))}"
543
+ mask, *other_masks = masks_list
544
+ for other_mask in other_masks:
545
+ mask = jnp.logical_and(mask, other_mask)
546
+ return mask.astype(dtype)
kinetix/models/transformer_model.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import jax.numpy as jnp
3
+ import flax.linen as nn
4
+ import numpy as np
5
+ from flax.linen.initializers import constant, orthogonal
6
+ from typing import List, Sequence
7
+
8
+ import distrax
9
+ import jax
10
+
11
+ from kinetix.models.actor_critic import GeneralActorCriticRNN, ScannedRNN
12
+
13
+
14
+ from kinetix.render.renderer_symbolic_entity import EntityObservation
15
+
16
+ from flax.linen.attention import MultiHeadDotProductAttention
17
+
18
+
19
+ class Gating(nn.Module):
20
+ # code taken from https://github.com/dhruvramani/Transformers-RL/blob/master/layers.py
21
+ d_input: int
22
+ bg: float = 0.0
23
+
24
+ @nn.compact
25
+ def __call__(self, x, y):
26
+ r = jax.nn.sigmoid(nn.Dense(self.d_input, use_bias=False)(y) + nn.Dense(self.d_input, use_bias=False)(x))
27
+ z = jax.nn.sigmoid(
28
+ nn.Dense(self.d_input, use_bias=False)(y)
29
+ + nn.Dense(self.d_input, use_bias=False)(x)
30
+ - self.param("gating_bias", constant(self.bg), (self.d_input,))
31
+ )
32
+ h = jnp.tanh(nn.Dense(self.d_input, use_bias=False)(y) + nn.Dense(self.d_input, use_bias=False)(r * x))
33
+ g = (1 - z) * x + (z * h)
34
+ return g
35
+
36
+
37
+ class transformer_layer(nn.Module):
38
+ num_heads: int
39
+ out_features: int
40
+ qkv_features: int
41
+ gating: bool = False
42
+ gating_bias: float = 0.0
43
+
44
+ def setup(self):
45
+ self.attention1 = MultiHeadDotProductAttention(
46
+ num_heads=self.num_heads, qkv_features=self.qkv_features, out_features=self.out_features
47
+ )
48
+
49
+ self.ln1 = nn.LayerNorm()
50
+
51
+ self.dense1 = nn.Dense(self.out_features)
52
+
53
+ self.dense2 = nn.Dense(self.out_features)
54
+
55
+ self.ln2 = nn.LayerNorm()
56
+ if self.gating:
57
+ self.gate1 = Gating(self.out_features, self.gating_bias)
58
+ self.gate2 = Gating(self.out_features, self.gating_bias)
59
+
60
+ def __call__(self, queries: jnp.ndarray, mask: jnp.ndarray):
61
+ # After reading the paper, this is what I think we should do:
62
+ # First layernorm, then do attention
63
+ queries_n = self.ln1(queries)
64
+ y = self.attention1(queries_n, mask=mask)
65
+ if self.gating: # and gate
66
+ y = self.gate1(queries, jax.nn.relu(y))
67
+ else:
68
+ y = queries + y
69
+ # Dense after norming, crucially no relu.
70
+ e = self.dense1(self.ln2(y))
71
+ if self.gating: # and gate again
72
+ # This may be the wrong way around
73
+ e = self.gate2(y, jax.nn.relu(e))
74
+ else:
75
+ e = y + e
76
+
77
+ return e
78
+
79
+
80
+ class Transformer(nn.Module):
81
+ encoder_size: int
82
+ num_heads: int
83
+ qkv_features: int
84
+ num_layers: int
85
+ gating: bool = False
86
+ gating_bias: float = 0.0
87
+
88
+ def setup(self):
89
+ # self.encoder = nn.Dense(self.encoder_size)
90
+
91
+ # self.positional_encoding = PositionalEncoding(self.encoder_size, max_len=self.max_len)
92
+
93
+ self.tf_layers = [
94
+ transformer_layer(
95
+ num_heads=self.num_heads,
96
+ qkv_features=self.qkv_features,
97
+ out_features=self.encoder_size,
98
+ gating=self.gating,
99
+ gating_bias=self.gating_bias,
100
+ )
101
+ for _ in range(self.num_layers)
102
+ ]
103
+
104
+ self.joint_layers = [nn.Dense(self.encoder_size) for _ in range(self.num_layers)]
105
+ self.thruster_layers = [nn.Dense(self.encoder_size) for _ in range(self.num_layers)]
106
+
107
+ # self.pos_emb=PositionalEmbedding(self.encoder_size)
108
+
109
+ def __call__(
110
+ self,
111
+ shape_embeddings: jnp.ndarray,
112
+ shape_attention_mask,
113
+ joint_embeddings,
114
+ joint_mask,
115
+ joint_indexes,
116
+ thruster_embeddings,
117
+ thruster_mask,
118
+ thruster_indexes,
119
+ ):
120
+ # forward eval so obs is only one timestep
121
+ # encoded = self.encoder(shape_embeddings)
122
+ # pos_embed=self.pos_emb(jnp.arange(1+memories.shape[-3],-1,-1))[:1+memories.shape[-3]]
123
+
124
+ for tf_layer, joint_layer, thruster_layer in zip(self.tf_layers, self.joint_layers, self.thruster_layers):
125
+ # Do attention
126
+ shape_embeddings = tf_layer(shape_embeddings, shape_attention_mask)
127
+
128
+ # Joints
129
+ # T, B, 2J, (2SE + JE)
130
+
131
+ @jax.vmap
132
+ @jax.vmap
133
+ def do_index2(to_ind, ind):
134
+ return to_ind[ind]
135
+
136
+ joint_shape_embeddings = jnp.concatenate(
137
+ [
138
+ do_index2(shape_embeddings, joint_indexes[..., 0]),
139
+ do_index2(shape_embeddings, joint_indexes[..., 1]),
140
+ joint_embeddings,
141
+ ],
142
+ axis=-1,
143
+ )
144
+
145
+ shape_joint_entity_delta = joint_layer(joint_shape_embeddings) * joint_mask[..., None]
146
+
147
+ @jax.vmap
148
+ @jax.vmap
149
+ def add2(addee, index, adder):
150
+ return addee.at[index].add(adder)
151
+
152
+ # Thrusters
153
+ thruster_shape_embeddings = jnp.concatenate(
154
+ [
155
+ do_index2(shape_embeddings, thruster_indexes),
156
+ thruster_embeddings,
157
+ ],
158
+ axis=-1,
159
+ )
160
+
161
+ shape_thruster_entity_delta = thruster_layer(thruster_shape_embeddings) * thruster_mask[..., None]
162
+
163
+ shape_embeddings = add2(shape_embeddings, joint_indexes[..., 0], shape_joint_entity_delta)
164
+ shape_embeddings = add2(shape_embeddings, thruster_indexes, shape_thruster_entity_delta)
165
+
166
+ return shape_embeddings
167
+
168
+
169
+ class ActorCriticTransformer(nn.Module):
170
+ action_dim: Sequence[int]
171
+ fc_layer_width: int
172
+ action_mode: str
173
+ hybrid_action_continuous_dim: int
174
+ multi_discrete_number_of_dims_per_distribution: List[int]
175
+ transformer_size: int
176
+ transformer_encoder_size: int
177
+ transformer_depth: int
178
+ fc_layer_depth: int
179
+ num_heads: int
180
+ activation: str
181
+ aggregate_mode: str # "dummy" or "mean" or "dummy_and_mean"
182
+ full_attention_mask: bool # if true, only mask out inactives, and have everything attend to everything else
183
+ add_generator_embedding: bool = False
184
+ generator_embedding_number_of_timesteps: int = 10
185
+ recurrent: bool = True
186
+
187
+ @nn.compact
188
+ def __call__(self, hidden, x):
189
+ if self.activation == "relu":
190
+ activation = nn.relu
191
+ else:
192
+ activation = nn.tanh
193
+
194
+ og_obs, dones = x
195
+ if self.add_generator_embedding:
196
+ obs = og_obs.obs
197
+ else:
198
+ obs = og_obs
199
+
200
+ # obs._ is [T, B, N, L]
201
+ # B - batch size
202
+ # T - time
203
+ # N - number of things
204
+ # L - unembedded entity size
205
+ obs: EntityObservation
206
+
207
+ def _single_encoder(features, entity_id, concat=True):
208
+ # assume two entity types
209
+ num_to_remove = 1 if concat else 0
210
+ embedding = activation(
211
+ nn.Dense(
212
+ self.transformer_encoder_size - num_to_remove,
213
+ kernel_init=orthogonal(np.sqrt(2)),
214
+ bias_init=constant(0.0),
215
+ )(features)
216
+ )
217
+ if concat:
218
+ id_1h = jnp.zeros((*embedding.shape[:3], 1)).at[:, :, :, entity_id].set(entity_id)
219
+ return jnp.concatenate([embedding, id_1h], axis=-1)
220
+ else:
221
+ return embedding
222
+
223
+ circle_encodings = _single_encoder(obs.circles, 0)
224
+ polygon_encodings = _single_encoder(obs.polygons, 1)
225
+ joint_encodings = _single_encoder(obs.joints, -1, False)
226
+ thruster_encodings = _single_encoder(obs.thrusters, -1, False)
227
+ # Size of this is something like (T, B, N, K) (time, batch, num_entities, embedding_size)
228
+
229
+ # T, B, M, K
230
+ shape_encodings = jnp.concatenate([polygon_encodings, circle_encodings], axis=2)
231
+ # T, B, M
232
+ shape_mask = jnp.concatenate([obs.polygon_mask, obs.circle_mask], axis=2)
233
+
234
+ def mask_out_inactives(flat_active_mask, matrix_attention_mask):
235
+ matrix_attention_mask = matrix_attention_mask & (flat_active_mask[:, None]) & (flat_active_mask[None, :])
236
+ return matrix_attention_mask
237
+
238
+ joint_indexes = obs.joint_indexes
239
+ thruster_indexes = obs.thruster_indexes
240
+
241
+ if self.aggregate_mode == "dummy" or self.aggregate_mode == "dummy_and_mean":
242
+ T, B, _, K = circle_encodings.shape
243
+ dummy = jnp.ones((T, B, 1, K))
244
+ shape_encodings = jnp.concatenate([dummy, shape_encodings], axis=2)
245
+ shape_mask = jnp.concatenate(
246
+ [jnp.ones((T, B, 1), dtype=bool), shape_mask],
247
+ axis=2,
248
+ )
249
+ N = obs.attention_mask.shape[-1]
250
+ overall_mask = (
251
+ jnp.ones((T, B, obs.attention_mask.shape[2], N + 1, N + 1), dtype=bool)
252
+ .at[:, :, :, 1:, 1:]
253
+ .set(obs.attention_mask)
254
+ )
255
+ overall_mask = jax.vmap(jax.vmap(mask_out_inactives))(shape_mask, overall_mask)
256
+
257
+ # To account for the dummy entity
258
+ joint_indexes = joint_indexes + 1
259
+ thruster_indexes = thruster_indexes + 1
260
+
261
+ else:
262
+ overall_mask = obs.attention_mask
263
+
264
+ if self.full_attention_mask:
265
+ overall_mask = jnp.ones(overall_mask.shape, dtype=bool)
266
+ overall_mask = jax.vmap(jax.vmap(mask_out_inactives))(shape_mask, overall_mask)
267
+
268
+ # Now do attention on these
269
+ embedding = Transformer(
270
+ num_layers=self.transformer_depth,
271
+ num_heads=self.num_heads,
272
+ qkv_features=self.transformer_size,
273
+ encoder_size=self.transformer_encoder_size,
274
+ gating=True,
275
+ gating_bias=0.0,
276
+ )(
277
+ shape_encodings,
278
+ jnp.repeat(overall_mask, repeats=self.num_heads // overall_mask.shape[2], axis=2),
279
+ joint_encodings,
280
+ obs.joint_mask,
281
+ joint_indexes,
282
+ thruster_encodings,
283
+ obs.thruster_mask,
284
+ thruster_indexes,
285
+ ) # add the extra dimension for the heads
286
+
287
+ if self.aggregate_mode == "mean" or self.aggregate_mode == "dummy_and_mean":
288
+ embedding = jnp.mean(embedding, axis=2, where=shape_mask[..., None])
289
+ else:
290
+ embedding = embedding[:, :, 0] # Take the dummy entity as the embedding of the entire scene.
291
+
292
+ return GeneralActorCriticRNN(
293
+ action_dim=self.action_dim,
294
+ fc_layer_depth=self.fc_layer_depth,
295
+ fc_layer_width=self.fc_layer_width,
296
+ action_mode=self.action_mode,
297
+ hybrid_action_continuous_dim=self.hybrid_action_continuous_dim,
298
+ multi_discrete_number_of_dims_per_distribution=self.multi_discrete_number_of_dims_per_distribution,
299
+ add_generator_embedding=self.add_generator_embedding,
300
+ generator_embedding_number_of_timesteps=self.generator_embedding_number_of_timesteps,
301
+ recurrent=self.recurrent,
302
+ )(hidden, og_obs, embedding, dones, activation)
kinetix/pcg/__init__.py ADDED
File without changes
kinetix/pcg/pcg.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+
3
+ from jax2d.engine import recalculate_mass_and_inertia, recompute_global_joint_positions, select_shape
4
+ from kinetix.environment.env_state import EnvState, StaticEnvParams
5
+ from kinetix.pcg.pcg_state import PCGState
6
+ import jax
7
+ import jax.numpy as jnp
8
+
9
+
10
+ def _process_tied_together_shapes(pcg_state: PCGState, sampled_state: EnvState, static_params: StaticEnvParams):
11
+
12
+ # Get the matrix of tied together positions. Since we vmap, we only want one entry active for any (i, j, k). Thus, we mask out some of the duplicate ones.
13
+ tied = jnp.triu(pcg_state.tied_together & jnp.logical_not(jnp.eye(pcg_state.tied_together.shape[0], dtype=bool)))
14
+ has_anything_in_column = tied.any(axis=0)
15
+ tied = (
16
+ tied * jnp.logical_not(has_anything_in_column)[:, None]
17
+ ) # if there is something in a column, it means a previous one with a lower index has already been processed
18
+
19
+ should_use_delta_positions = tied.any(axis=0)
20
+
21
+ # This is the delta we have moved after sampling
22
+ delta_positions = jnp.concatenate(
23
+ [
24
+ sampled_state.polygon.position - pcg_state.env_state.polygon.position,
25
+ sampled_state.circle.position - pcg_state.env_state.circle.position,
26
+ ]
27
+ )
28
+
29
+ def _get_effect_of_shape_i_on_all_others(item_index, item_row_of_what_is_tied):
30
+ delta_pos = delta_positions[item_index]
31
+ return jnp.arange(len(item_row_of_what_is_tied)), delta_pos[None] * item_row_of_what_is_tied[:, None]
32
+
33
+ indices, positions = jax.vmap(_get_effect_of_shape_i_on_all_others, (0, 0))(jnp.arange(tied.shape[0]), tied)
34
+ indices = indices.flatten()
35
+ positions = positions.reshape(indices.shape[0], -1)
36
+
37
+ default_positions = jnp.concatenate(
38
+ [pcg_state.env_state.polygon.position, pcg_state.env_state.circle.position], axis=0
39
+ )
40
+ sampled_positions = jnp.concatenate([sampled_state.polygon.position, sampled_state.circle.position], axis=0)
41
+
42
+ updated_positions = default_positions.at[indices].add(positions)
43
+ # Use the deltas or the sampled positions
44
+ positions = jnp.where(should_use_delta_positions[:, None], updated_positions, sampled_positions)
45
+
46
+ sampled_state = sampled_state.replace(
47
+ polygon=sampled_state.polygon.replace(position=positions[: static_params.num_polygons]),
48
+ circle=sampled_state.circle.replace(position=positions[static_params.num_polygons :]),
49
+ )
50
+ return sampled_state
51
+
52
+
53
+ @partial(jax.jit, static_argnums=(3,))
54
+ def sample_pcg_state(rng, pcg_state: PCGState, params, static_params):
55
+ def _pcg_fn(rng, main_val, max_val, mask):
56
+ pcg_val = jax.random.uniform(rng, shape=main_val.shape) * (
57
+ max_val.astype(float) - main_val.astype(float)
58
+ ) + main_val.astype(float)
59
+ if jnp.issubdtype(main_val.dtype, jnp.integer) or jnp.issubdtype(main_val.dtype, jnp.bool_):
60
+ pcg_val = jnp.round(pcg_val)
61
+ pcg_val = pcg_val.astype(main_val.dtype)
62
+ new_val = jax.lax.select(mask.astype(bool), pcg_val, main_val)
63
+ return new_val
64
+
65
+ def _random_split_like_tree(rng, target):
66
+ tree_def = jax.tree_structure(target)
67
+ rngs = jax.random.split(rng, tree_def.num_leaves)
68
+ return jax.tree_unflatten(tree_def, rngs)
69
+
70
+ rng, _rng = jax.random.split(rng)
71
+ rng_tree = _random_split_like_tree(_rng, pcg_state.env_state)
72
+
73
+ sampled_state = jax.tree_util.tree_map(
74
+ _pcg_fn, rng_tree, pcg_state.env_state, pcg_state.env_state_max, pcg_state.env_state_pcg_mask
75
+ )
76
+
77
+ sampled_state = _process_tied_together_shapes(pcg_state, sampled_state, static_params)
78
+
79
+ sampled_state = recompute_global_joint_positions(sampled_state, static_params)
80
+
81
+ env_state = recalculate_mass_and_inertia(
82
+ sampled_state, static_params, sampled_state.polygon_densities, sampled_state.circle_densities
83
+ )
84
+
85
+ return env_state
86
+
87
+
88
+ def env_state_to_pcg_state(env_state: EnvState):
89
+ N = env_state.polygon.active.shape[0] + env_state.circle.active.shape[0]
90
+ pcg_state = PCGState(
91
+ env_state=env_state,
92
+ env_state_max=env_state,
93
+ env_state_pcg_mask=jax.tree_util.tree_map(lambda x: jnp.zeros_like(x, dtype=bool), env_state),
94
+ tied_together=jnp.zeros((N, N), dtype=bool),
95
+ )
96
+
97
+ return pcg_state
kinetix/pcg/pcg_state.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import field
2
+ import jax.numpy as jnp
3
+ from flax import struct
4
+
5
+ from jax2d.sim_state import SimState, SimParams, StaticSimParams, RigidBody, Joint, Thruster, CollisionManifold
6
+ from kinetix.environment.env_state import EnvState
7
+
8
+
9
+ @struct.dataclass
10
+ class PCGState:
11
+ # Primary env state
12
+ env_state: EnvState
13
+ # The PCG mask. If a value is truthy in this, then it is PCG not static
14
+ env_state_pcg_mask: EnvState
15
+ # In the case that a value is PCG, the env_state value is the min and this state represents the max
16
+ env_state_max: EnvState
17
+
18
+ tied_together: jnp.ndarray # NxN matrix of booleans, where N is the number of shapes
19
+
20
+ def __setstate__(self, state):
21
+ if "tied_together" not in state:
22
+ num_shapes = state["env_state"].polygon.active.shape[0] + state["env_state"].circle.active.shape[0]
23
+ state["tied_together"] = jnp.zeros((num_shapes, num_shapes), dtype=bool)
24
+ object.__setattr__(self, "__dict__", state)
kinetix/render/__init__.py ADDED
File without changes
kinetix/render/renderer_pixels.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+ import numpy as np
6
+
7
+ from jax2d import joint
8
+ from jax2d.engine import select_shape
9
+ from jax2d.maths import rmat
10
+ from jax2d.sim_state import RigidBody
11
+ from jaxgl.maths import dist_from_line
12
+ from jaxgl.renderer import clear_screen, make_renderer
13
+ from jaxgl.shaders import (
14
+ fragment_shader_quad,
15
+ fragment_shader_edged_quad,
16
+ make_fragment_shader_texture,
17
+ nearest_neighbour,
18
+ make_fragment_shader_quad_textured,
19
+ )
20
+
21
+ from kinetix.render.textures import (
22
+ THRUSTER_TEXTURE_16_RGBA,
23
+ RJOINT_TEXTURE_6_RGBA,
24
+ FJOINT_TEXTURE_6_RGBA,
25
+ )
26
+ from kinetix.environment.env_state import StaticEnvParams, EnvParams, EnvState
27
+ from flax import struct
28
+
29
+
30
+ def make_render_pixels(
31
+ params,
32
+ static_params: StaticEnvParams,
33
+ ):
34
+ screen_dim = static_params.screen_dim
35
+ downscale = static_params.downscale
36
+
37
+ joint_tex_size = 6
38
+ thruster_tex_size = 16
39
+
40
+ FIXATED_COLOUR = jnp.array([80, 80, 80])
41
+ JOINT_COLOURS = jnp.array(
42
+ [
43
+ # [0, 0, 255],
44
+ [255, 255, 255], # yellow
45
+ [255, 255, 0], # yellow
46
+ [255, 0, 255], # purple/magenta
47
+ [0, 255, 255], # cyan
48
+ [255, 153, 51], # white
49
+ ]
50
+ )
51
+
52
+ def colour_thruster_texture(colour):
53
+ return THRUSTER_TEXTURE_16_RGBA.at[:9, :, :3].mul(colour[None, None, :] / 255.0)
54
+
55
+ coloured_thruster_textures = jax.vmap(colour_thruster_texture)(JOINT_COLOURS)
56
+
57
+ ROLE_COLOURS = jnp.array(
58
+ [
59
+ [160.0, 160.0, 160.0], # None
60
+ [0.0, 204.0, 0.0], # Green: The ball
61
+ [0.0, 102.0, 204.0], # Blue: The goal
62
+ [255.0, 102.0, 102.0], # Red: Death Objects
63
+ ]
64
+ )
65
+
66
+ BACKGROUND_COLOUR = jnp.array([255.0, 255.0, 255.0])
67
+
68
+ def _get_colour(shape_role, inverse_inertia):
69
+ base_colour = ROLE_COLOURS[shape_role]
70
+ f = (inverse_inertia == 0) * 1
71
+ is_not_normal = (shape_role != 0) * 1
72
+
73
+ return jnp.array(
74
+ [
75
+ base_colour,
76
+ base_colour,
77
+ FIXATED_COLOUR,
78
+ base_colour * 0.5,
79
+ ]
80
+ )[2 * f + is_not_normal]
81
+
82
+ # Pixels per unit distance
83
+ ppud = params.pixels_per_unit // downscale
84
+
85
+ downscaled_screen_dim = (screen_dim[0] // downscale, screen_dim[1] // downscale)
86
+
87
+ full_screen_size = (
88
+ downscaled_screen_dim[0] + (static_params.max_shape_size * 2 * ppud),
89
+ downscaled_screen_dim[1] + (static_params.max_shape_size * 2 * ppud),
90
+ )
91
+ cleared_screen = clear_screen(full_screen_size, BACKGROUND_COLOUR)
92
+
93
+ def _world_space_to_pixel_space(x):
94
+ return (x + static_params.max_shape_size) * ppud
95
+
96
+ def fragment_shader_kinetix_circle(position, current_frag, unit_position, uniform):
97
+ centre, radius, rotation, colour, mask = uniform
98
+
99
+ dist = jnp.sqrt(jnp.square(position - centre).sum())
100
+ inside = dist <= radius
101
+ on_edge = dist > radius - 2
102
+
103
+ # TODO - precompute?
104
+ normal = jnp.array([jnp.sin(rotation), -jnp.cos(rotation)])
105
+
106
+ dist = dist_from_line(position, centre, centre + normal)
107
+
108
+ on_edge |= (dist < 1) & (jnp.dot(normal, position - centre) <= 0)
109
+
110
+ fragment = jax.lax.select(on_edge, jnp.zeros(3), colour)
111
+
112
+ return jax.lax.select(inside & mask, fragment, current_frag)
113
+
114
+ def fragment_shader_kinetix_joint(position, current_frag, unit_position, uniform):
115
+ texture, colour, mask = uniform
116
+
117
+ tex_coord = (
118
+ jnp.array(
119
+ [
120
+ joint_tex_size * unit_position[0],
121
+ joint_tex_size * unit_position[1],
122
+ ]
123
+ )
124
+ - 0.5
125
+ )
126
+
127
+ tex_frag = nearest_neighbour(texture, tex_coord)
128
+ tex_frag = tex_frag.at[3].mul(mask)
129
+ tex_frag = tex_frag.at[:3].mul(colour / 255.0)
130
+
131
+ tex_frag = (tex_frag[3] * tex_frag[:3]) + ((1.0 - tex_frag[3]) * current_frag)
132
+
133
+ return tex_frag
134
+
135
+ thruster_pixel_size = thruster_tex_size // downscale
136
+ thruster_pixel_size_diagonal = (thruster_pixel_size * np.sqrt(2)).astype(jnp.int32) + 1
137
+
138
+ def fragment_shader_kinetix_thruster(fragment_position, current_frag, unit_position, uniform):
139
+ thruster_position, rotation, texture, mask = uniform
140
+
141
+ tex_position = jnp.matmul(rmat(-rotation), (fragment_position - thruster_position)) / thruster_pixel_size + 0.5
142
+
143
+ mask &= (tex_position[0] >= 0) & (tex_position[0] <= 1) & (tex_position[1] >= 0) & (tex_position[1] <= 1)
144
+
145
+ eps = 0.001
146
+ tex_coord = (
147
+ jnp.array(
148
+ [
149
+ thruster_tex_size * tex_position[0],
150
+ thruster_tex_size * tex_position[1],
151
+ ]
152
+ )
153
+ - 0.5
154
+ + eps
155
+ )
156
+
157
+ tex_frag = nearest_neighbour(texture, tex_coord)
158
+ tex_frag = tex_frag.at[3].mul(mask)
159
+
160
+ tex_frag = (tex_frag[3] * tex_frag[:3]) + ((1.0 - tex_frag[3]) * current_frag)
161
+
162
+ return tex_frag
163
+
164
+ patch_size_1d = static_params.max_shape_size * ppud
165
+ patch_size = (patch_size_1d, patch_size_1d)
166
+
167
+ circle_renderer = make_renderer(full_screen_size, fragment_shader_kinetix_circle, patch_size, batched=True)
168
+ quad_renderer = make_renderer(full_screen_size, fragment_shader_edged_quad, patch_size, batched=True)
169
+ big_quad_renderer = make_renderer(full_screen_size, fragment_shader_edged_quad, downscaled_screen_dim)
170
+
171
+ joint_pixel_size = joint_tex_size // downscale
172
+ joint_renderer = make_renderer(
173
+ full_screen_size, fragment_shader_kinetix_joint, (joint_pixel_size, joint_pixel_size), batched=True
174
+ )
175
+
176
+ thruster_renderer = make_renderer(
177
+ full_screen_size,
178
+ fragment_shader_kinetix_thruster,
179
+ (thruster_pixel_size_diagonal, thruster_pixel_size_diagonal),
180
+ batched=True,
181
+ )
182
+
183
+ @jax.jit
184
+ def render_pixels(state: EnvState):
185
+ pixels = cleared_screen
186
+
187
+ # Floor
188
+ floor_uniform = (
189
+ _world_space_to_pixel_space(state.polygon.position[0, None, :] + state.polygon.vertices[0]),
190
+ _get_colour(state.polygon_shape_roles[0], 0),
191
+ jnp.zeros(3),
192
+ True,
193
+ )
194
+
195
+ pixels = big_quad_renderer(pixels, _world_space_to_pixel_space(jnp.zeros(2, dtype=jnp.int32)), floor_uniform)
196
+
197
+ # Rectangles
198
+ rectangle_patch_positions = _world_space_to_pixel_space(
199
+ state.polygon.position - (static_params.max_shape_size / 2.0)
200
+ ).astype(jnp.int32)
201
+
202
+ rectangle_rmats = jax.vmap(rmat)(state.polygon.rotation)
203
+ rectangle_rmats = jnp.repeat(rectangle_rmats[:, None, :, :], repeats=static_params.max_polygon_vertices, axis=1)
204
+ rectangle_vertices_pixel_space = _world_space_to_pixel_space(
205
+ state.polygon.position[:, None, :] + jax.vmap(jax.vmap(jnp.matmul))(rectangle_rmats, state.polygon.vertices)
206
+ )
207
+ rectangle_colours = jax.vmap(_get_colour)(state.polygon_shape_roles, state.polygon.inverse_mass)
208
+ rectangle_edge_colours = jnp.zeros((static_params.num_polygons, 3))
209
+
210
+ rectangle_uniforms = (
211
+ rectangle_vertices_pixel_space,
212
+ rectangle_colours,
213
+ rectangle_edge_colours,
214
+ state.polygon.active,
215
+ )
216
+
217
+ pixels = quad_renderer(pixels, rectangle_patch_positions, rectangle_uniforms)
218
+
219
+ # Circles
220
+ circle_positions_pixel_space = _world_space_to_pixel_space(state.circle.position)
221
+ circle_radii_pixel_space = state.circle.radius * ppud
222
+ circle_patch_positions = _world_space_to_pixel_space(
223
+ state.circle.position - (static_params.max_shape_size / 2.0)
224
+ ).astype(jnp.int32)
225
+
226
+ circle_colours = jax.vmap(_get_colour)(state.circle_shape_roles, state.circle.inverse_mass)
227
+
228
+ circle_uniforms = (
229
+ circle_positions_pixel_space,
230
+ circle_radii_pixel_space,
231
+ state.circle.rotation,
232
+ circle_colours,
233
+ state.circle.active,
234
+ )
235
+
236
+ pixels = circle_renderer(pixels, circle_patch_positions, circle_uniforms)
237
+
238
+ # Joints
239
+ joint_patch_positions = jnp.round(
240
+ _world_space_to_pixel_space(state.joint.global_position) - (joint_pixel_size // 2)
241
+ ).astype(jnp.int32)
242
+ joint_textures = jax.vmap(jax.lax.select, in_axes=(0, None, None))(
243
+ state.joint.is_fixed_joint, FJOINT_TEXTURE_6_RGBA, RJOINT_TEXTURE_6_RGBA
244
+ )
245
+ joint_colours = JOINT_COLOURS[
246
+ (state.motor_bindings + 1) * (state.joint.motor_on & (~state.joint.is_fixed_joint))
247
+ ]
248
+
249
+ joint_uniforms = (joint_textures, joint_colours, state.joint.active)
250
+
251
+ pixels = joint_renderer(pixels, joint_patch_positions, joint_uniforms)
252
+
253
+ # Thrusters
254
+ thruster_positions = jnp.round(_world_space_to_pixel_space(state.thruster.global_position)).astype(jnp.int32)
255
+ thruster_patch_positions = thruster_positions - (thruster_pixel_size_diagonal // 2)
256
+ thruster_textures = coloured_thruster_textures[state.thruster_bindings + 1]
257
+ thruster_rotations = (
258
+ state.thruster.rotation
259
+ + jax.vmap(select_shape, in_axes=(None, 0, None))(
260
+ state, state.thruster.object_index, static_params
261
+ ).rotation
262
+ )
263
+ thruster_uniforms = (thruster_positions, thruster_rotations, thruster_textures, state.thruster.active)
264
+
265
+ pixels = thruster_renderer(pixels, thruster_patch_positions, thruster_uniforms)
266
+
267
+ # Crop out the sides
268
+ crop_amount = static_params.max_shape_size * ppud
269
+ return pixels[crop_amount:-crop_amount, crop_amount:-crop_amount]
270
+
271
+ return render_pixels
272
+
273
+
274
+ @struct.dataclass
275
+ class PixelsObservation:
276
+ image: jnp.ndarray
277
+ global_info: jnp.ndarray
278
+
279
+
280
+ def make_render_pixels_rl(params, static_params: StaticEnvParams):
281
+ render_fn = make_render_pixels(params, static_params)
282
+
283
+ def inner(state):
284
+ pixels = render_fn(state) / 255.0
285
+ return PixelsObservation(
286
+ image=pixels,
287
+ global_info=jnp.array([state.gravity[1] / 10.0]),
288
+ )
289
+
290
+ return inner
kinetix/render/renderer_symbolic_common.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import jax
2
+ from jax2d.sim_state import RigidBody
3
+ import jax.numpy as jnp
4
+ from kinetix.environment.env_state import EnvParams, EnvState, StaticEnvParams
5
+
6
+
7
+ def _get_base_shape_features(
8
+ density: jnp.ndarray, roles: jnp.ndarray, shapes: RigidBody, env_params: EnvParams
9
+ ) -> jnp.ndarray:
10
+ cos = jnp.cos(shapes.rotation)
11
+ sin = jnp.sin(shapes.rotation)
12
+ return jnp.concatenate(
13
+ [
14
+ shapes.position,
15
+ shapes.velocity,
16
+ jnp.expand_dims(shapes.inverse_mass, axis=1),
17
+ jnp.expand_dims(shapes.inverse_inertia, axis=1),
18
+ jnp.expand_dims(density, axis=1),
19
+ jnp.expand_dims(jnp.tanh(shapes.angular_velocity / 10), axis=1),
20
+ jax.nn.one_hot(roles, env_params.num_shape_roles),
21
+ jnp.expand_dims(sin, axis=1),
22
+ jnp.expand_dims(cos, axis=1),
23
+ jnp.expand_dims(shapes.friction, axis=1),
24
+ jnp.expand_dims(shapes.restitution, axis=1),
25
+ ],
26
+ axis=1,
27
+ )
28
+
29
+
30
+ def add_circle_features(
31
+ base_features: jnp.ndarray, shapes: RigidBody, env_params: EnvParams, static_env_params: StaticEnvParams
32
+ ):
33
+ return jnp.concatenate(
34
+ [
35
+ base_features,
36
+ shapes.radius[:, None],
37
+ jnp.ones_like(base_features[:, :1]), # one for circle
38
+ ],
39
+ axis=1,
40
+ )
41
+
42
+
43
+ def make_circle_features(
44
+ state: EnvState, env_params: EnvParams, static_env_params: StaticEnvParams
45
+ ) -> tuple[jnp.ndarray, jnp.ndarray]:
46
+ base_features = _get_base_shape_features(state.circle_densities, state.circle_shape_roles, state.circle, env_params)
47
+ node_features = add_circle_features(base_features, state.circle, env_params, static_env_params)
48
+ return node_features, state.circle.active
49
+
50
+
51
+ def add_polygon_features(
52
+ base_features: jnp.ndarray, shapes: RigidBody, env_params: EnvParams, static_env_params: StaticEnvParams
53
+ ):
54
+ vertices = jnp.where(
55
+ jnp.arange(static_env_params.max_polygon_vertices)[None, :, None] < shapes.n_vertices[:, None, None],
56
+ shapes.vertices,
57
+ jnp.zeros_like(shapes.vertices) - 1,
58
+ )
59
+
60
+ return jnp.concatenate(
61
+ [
62
+ base_features,
63
+ jnp.zeros_like(base_features[:, :1]), # zero for polygon
64
+ vertices.reshape((vertices.shape[0], -1)),
65
+ jnp.expand_dims((shapes.n_vertices <= 3), axis=1),
66
+ ],
67
+ axis=1,
68
+ )
69
+
70
+
71
+ def make_polygon_features(
72
+ state: EnvState, env_params: EnvParams, static_env_params: StaticEnvParams
73
+ ) -> tuple[jnp.ndarray, jnp.ndarray]:
74
+ base_features = _get_base_shape_features(
75
+ state.polygon_densities, state.polygon_shape_roles, state.polygon, env_params
76
+ )
77
+ node_features = add_polygon_features(base_features, state.polygon, env_params, static_env_params)
78
+ return node_features, state.polygon.active
79
+
80
+
81
+ def make_unified_shape_features(
82
+ state: EnvState, env_params: EnvParams, static_env_params: StaticEnvParams
83
+ ) -> tuple[jnp.ndarray, jnp.ndarray]:
84
+ base_p = _get_base_shape_features(state.polygon_densities, state.polygon_shape_roles, state.polygon, env_params)
85
+ base_c = _get_base_shape_features(state.circle_densities, state.circle_shape_roles, state.circle, env_params)
86
+ base_p = add_polygon_features(base_p, state.polygon, env_params, static_env_params)
87
+ base_p = add_circle_features(base_p, state.polygon, env_params, static_env_params)
88
+
89
+ base_c = add_polygon_features(base_c, state.circle, env_params, static_env_params)
90
+ base_c = add_circle_features(base_c, state.circle, env_params, static_env_params)
91
+
92
+ return jnp.concatenate([base_p, base_c], axis=0), jnp.concatenate(
93
+ [state.polygon.active, state.circle.active], axis=0
94
+ )
95
+
96
+
97
+ def make_joint_features(
98
+ state: EnvState, env_params: EnvParams, static_env_params: StaticEnvParams
99
+ ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
100
+ # Returns joint_features, indexes, mask, of shape:
101
+ # (2 * J, K), (2 * J, 2), (2 * J,)
102
+ def _create_joint_features(joints):
103
+ # 2, J, A
104
+ J = joints.active.shape[0]
105
+
106
+ def _create_1way_joint_features(direction):
107
+ from_pos = jax.lax.select(direction, joints.a_relative_pos, joints.b_relative_pos)
108
+ to_pos = jax.lax.select(direction, joints.b_relative_pos, joints.a_relative_pos)
109
+
110
+ rotation_sin, rotation_cos = jnp.sin(joints.rotation), jnp.cos(joints.rotation)
111
+ rotation_max_sin = jnp.sin(joints.max_rotation) * joints.motor_has_joint_limits
112
+ rotation_max_cos = jnp.cos(joints.max_rotation) * joints.motor_has_joint_limits
113
+ rotation_min_sin = jnp.sin(joints.min_rotation) * joints.motor_has_joint_limits
114
+ rotation_min_cos = jnp.cos(joints.min_rotation) * joints.motor_has_joint_limits
115
+
116
+ rotation_diff_max = (joints.max_rotation - joints.rotation) * joints.motor_has_joint_limits
117
+ rotation_diff_min = (joints.min_rotation - joints.rotation) * joints.motor_has_joint_limits
118
+
119
+ base_features = jnp.concatenate(
120
+ [
121
+ (joints.active * 1.0)[:, None],
122
+ (joints.is_fixed_joint * 1.0)[:, None], # J, 1
123
+ from_pos,
124
+ to_pos,
125
+ rotation_sin[:, None],
126
+ rotation_cos[:, None],
127
+ ],
128
+ axis=1,
129
+ )
130
+ rjoint_features = (
131
+ jnp.concatenate(
132
+ [
133
+ joints.motor_speed[:, None],
134
+ joints.motor_power[:, None],
135
+ (joints.motor_on * 1.0)[:, None],
136
+ (joints.motor_has_joint_limits * 1.0)[:, None],
137
+ jax.nn.one_hot(state.motor_bindings, num_classes=static_env_params.num_motor_bindings),
138
+ rotation_min_sin[:, None],
139
+ rotation_min_cos[:, None],
140
+ rotation_max_sin[:, None],
141
+ rotation_max_cos[:, None],
142
+ rotation_diff_min[:, None],
143
+ rotation_diff_max[:, None],
144
+ ],
145
+ axis=1,
146
+ )
147
+ * (1.0 - (joints.is_fixed_joint * 1.0))[:, None]
148
+ )
149
+
150
+ return jnp.concatenate([base_features, rjoint_features], axis=1)
151
+
152
+ # 2, J, A
153
+ joint_features = jax.vmap(_create_1way_joint_features)(jnp.array([False, True]))
154
+
155
+ # J, 2
156
+ indexes_from = jnp.concatenate([joints.b_index[:, None], joints.a_index[:, None]], axis=1)
157
+ indexes_to = jnp.concatenate([joints.a_index[:, None], joints.b_index[:, None]], axis=1)
158
+
159
+ indexes_from = jnp.where(joints.active[:, None], indexes_from, jnp.zeros_like(indexes_from))
160
+ indexes_to = jnp.where(joints.active[:, None], indexes_to, jnp.zeros_like(indexes_to))
161
+
162
+ indexes = jnp.concatenate([indexes_from, indexes_to], axis=0)
163
+ mask = jnp.concatenate([joints.active, joints.active], axis=0)
164
+
165
+ return joint_features.reshape((2 * J, -1)), indexes, mask
166
+
167
+ return _create_joint_features(state.joint)
168
+
169
+
170
+ def make_thruster_features(
171
+ state: EnvState, env_params: EnvParams, static_env_params: StaticEnvParams
172
+ ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
173
+ # Returns thruster_features, indexes, mask, of shape:
174
+ # (T, K), (T,), (T,)
175
+ def _create_thruster_features(thrusters):
176
+ cos = jnp.cos(thrusters.rotation)
177
+ sin = jnp.sin(thrusters.rotation)
178
+ return jnp.concatenate(
179
+ [
180
+ (thrusters.active * 1.0)[:, None],
181
+ (thrusters.relative_position),
182
+ jax.nn.one_hot(state.thruster_bindings, num_classes=static_env_params.num_thruster_bindings),
183
+ sin[:, None],
184
+ cos[:, None],
185
+ thrusters.power[:, None],
186
+ ],
187
+ axis=1,
188
+ )
189
+
190
+ return _create_thruster_features(state.thruster), state.thruster.object_index, state.thruster.active
kinetix/render/renderer_symbolic_entity.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from cmath import rect
2
+ from functools import partial
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+ from flax import struct
7
+ from jax2d.engine import get_pairwise_interaction_indices
8
+ from kinetix.environment.env_state import EnvState
9
+ from kinetix.render.renderer_symbolic_common import (
10
+ make_circle_features,
11
+ make_joint_features,
12
+ make_polygon_features,
13
+ make_thruster_features,
14
+ make_unified_shape_features,
15
+ )
16
+
17
+
18
+ @struct.dataclass
19
+ class EntityObservation:
20
+ circles: jnp.ndarray
21
+ polygons: jnp.ndarray
22
+ joints: jnp.ndarray
23
+ thrusters: jnp.ndarray
24
+
25
+ circle_mask: jnp.ndarray
26
+ polygon_mask: jnp.ndarray
27
+ joint_mask: jnp.ndarray
28
+ thruster_mask: jnp.ndarray
29
+ attention_mask: jnp.ndarray
30
+ # collision_mask: jnp.ndarray
31
+
32
+ joint_indexes: jnp.ndarray
33
+ thruster_indexes: jnp.ndarray
34
+
35
+
36
+ def make_render_entities(params, static_params):
37
+ _, _, _, circle_circle_pairs, circle_rect_pairs, rect_rect_pairs = get_pairwise_interaction_indices(static_params)
38
+ circle_rect_pairs = circle_rect_pairs.at[:, 0].add(static_params.num_polygons)
39
+ circle_circle_pairs = circle_circle_pairs + static_params.num_polygons
40
+
41
+ def render_entities(state: EnvState):
42
+ state = jax.tree_util.tree_map(lambda x: jnp.nan_to_num(x), state)
43
+
44
+ joint_features, joint_indexes, joint_mask = make_joint_features(state, params, static_params)
45
+ thruster_features, thruster_indexes, thruster_mask = make_thruster_features(state, params, static_params)
46
+
47
+ poly_nodes, poly_mask = make_polygon_features(state, params, static_params)
48
+ circle_nodes, circle_mask = make_circle_features(state, params, static_params)
49
+
50
+ def _add_grav(nodes):
51
+ return jnp.concatenate(
52
+ [nodes, jnp.zeros((nodes.shape[0], 1)) + state.gravity[1] / 10], axis=-1
53
+ ) # add gravity to each shape's embedding
54
+
55
+ poly_nodes = _add_grav(poly_nodes)
56
+ circle_nodes = _add_grav(circle_nodes)
57
+
58
+ # Shape of something like (NPoly + NCircle + 2 * NJoint + NThruster )
59
+ mask_flat_shapes = jnp.concatenate([poly_mask, circle_mask], axis=0)
60
+ num_shapes = static_params.num_polygons + static_params.num_circles
61
+
62
+ def make_n_squared_mask(val):
63
+ # val has shape N of bools.
64
+ N = val.shape[0]
65
+ A = jnp.eye(N, N, dtype=bool) # also have things attend to themselves
66
+ # Make the shapes fully connected
67
+ full_mask = A.at[:num_shapes, :num_shapes].set(jnp.ones((num_shapes, num_shapes), dtype=bool))
68
+
69
+ one_hop_connected = jnp.zeros((N, N), dtype=bool)
70
+ one_hop_connected = one_hop_connected.at[joint_indexes[:, 0], joint_indexes[:, 1]].set(True)
71
+ one_hop_connected = one_hop_connected.at[0, 0].set(False) # invalid joints have indices of (0, 0)
72
+
73
+ multi_hop_connected = jnp.logical_not(state.collision_matrix)
74
+
75
+ collision_mask = state.collision_matrix
76
+
77
+ # where val is false, we want to mask out the row and column.
78
+ full_mask = full_mask & (val[:, None]) & (val[None, :])
79
+ collision_mask = collision_mask & (val[:, None]) & (val[None, :])
80
+ multi_hop_connected = multi_hop_connected & (val[:, None]) & (val[None, :])
81
+ one_hop_connected = one_hop_connected & (val[:, None]) & (val[None, :])
82
+ collision_manifold_mask = jnp.zeros_like(collision_mask)
83
+
84
+ def _set(collision_manifold_mask, pairs, active):
85
+ return collision_manifold_mask.at[
86
+ pairs[:, 0],
87
+ pairs[:, 1],
88
+ ].set(active)
89
+
90
+ collision_manifold_mask = _set(
91
+ collision_manifold_mask,
92
+ rect_rect_pairs,
93
+ jnp.logical_or(state.acc_rr_manifolds.active[..., 0], state.acc_rr_manifolds.active[..., 1]),
94
+ )
95
+
96
+ collision_manifold_mask = _set(collision_manifold_mask, circle_rect_pairs, state.acc_cr_manifolds.active)
97
+ collision_manifold_mask = _set(collision_manifold_mask, circle_circle_pairs, state.acc_cc_manifolds.active)
98
+ collision_manifold_mask = collision_manifold_mask & (val[:, None]) & (val[None, :])
99
+
100
+ return jnp.concatenate(
101
+ [full_mask[None], multi_hop_connected[None], one_hop_connected[None], collision_manifold_mask[None]],
102
+ axis=0,
103
+ )
104
+
105
+ mask_n_squared = make_n_squared_mask(mask_flat_shapes)
106
+
107
+ return EntityObservation(
108
+ circles=circle_nodes,
109
+ polygons=poly_nodes,
110
+ joints=joint_features,
111
+ thrusters=thruster_features,
112
+ circle_mask=circle_mask,
113
+ polygon_mask=poly_mask,
114
+ joint_mask=joint_mask,
115
+ thruster_mask=thruster_mask,
116
+ attention_mask=mask_n_squared,
117
+ joint_indexes=joint_indexes,
118
+ thruster_indexes=thruster_indexes,
119
+ )
120
+
121
+ return render_entities
kinetix/render/renderer_symbolic_flat.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+ import numpy as np
6
+
7
+ from jax2d import joint
8
+ from jax2d.engine import select_shape
9
+ from jax2d.maths import rmat
10
+ from jax2d.sim_state import RigidBody
11
+ from jaxgl.maths import dist_from_line
12
+ from jaxgl.renderer import clear_screen, make_renderer
13
+ from jaxgl.shaders import (
14
+ fragment_shader_quad,
15
+ fragment_shader_edged_quad,
16
+ make_fragment_shader_texture,
17
+ nearest_neighbour,
18
+ make_fragment_shader_quad_textured,
19
+ )
20
+ from kinetix.render.renderer_symbolic_common import (
21
+ make_circle_features,
22
+ make_joint_features,
23
+ make_polygon_features,
24
+ make_thruster_features,
25
+ )
26
+ from kinetix.environment.env_state import StaticEnvParams, EnvParams, EnvState
27
+ from flax import struct
28
+
29
+
30
+ def make_render_symbolic(params, static_params: StaticEnvParams):
31
+ def render_symbolic(state):
32
+
33
+ n_polys = static_params.num_polygons
34
+ nshapes = n_polys + static_params.num_circles
35
+
36
+ polygon_features, polygon_mask = make_polygon_features(state, params, static_params)
37
+ mask_to_ignore_walls_ceiling = np.ones(static_params.num_polygons, dtype=bool)
38
+ mask_to_ignore_walls_ceiling[np.array([1, 2, 3])] = False
39
+
40
+ polygon_features = polygon_features[mask_to_ignore_walls_ceiling]
41
+ polygon_mask = polygon_mask[mask_to_ignore_walls_ceiling]
42
+
43
+ circle_features, circle_mask = make_circle_features(state, params, static_params)
44
+ joint_features, joint_idxs, joint_mask = make_joint_features(state, params, static_params)
45
+ thruster_features, thruster_idxs, thruster_mask = make_thruster_features(state, params, static_params)
46
+
47
+ two_J = joint_features.shape[0]
48
+ J = two_J // 2 # for symbolic only have the one
49
+ joint_features = jnp.concatenate(
50
+ [
51
+ joint_features[:J], # shape (2 * J, K)
52
+ jax.nn.one_hot(joint_idxs[:J, 0], nshapes), # shape (2 * J, N)
53
+ jax.nn.one_hot(joint_idxs[:J, 1], nshapes), # shape (2 * J, N)
54
+ ],
55
+ axis=1,
56
+ )
57
+ thruster_features = jnp.concatenate(
58
+ [
59
+ thruster_features,
60
+ jax.nn.one_hot(thruster_idxs, nshapes),
61
+ ],
62
+ axis=1,
63
+ )
64
+
65
+ polygon_features = jnp.where(polygon_mask[:, None], polygon_features, 0.0).flatten()
66
+ circle_features = jnp.where(circle_mask[:, None], circle_features, 0.0).flatten()
67
+ joint_features = jnp.where(joint_mask[:J, None], joint_features, 0.0).flatten()
68
+ thruster_features = jnp.where(thruster_mask[:, None], thruster_features, 0.0).flatten()
69
+
70
+ def _get_manifold_features(manifold):
71
+ collision_mask_features = jnp.concatenate(
72
+ [
73
+ manifold.normal,
74
+ jnp.expand_dims(manifold.penetration, axis=-1),
75
+ manifold.collision_point,
76
+ jnp.expand_dims(manifold.acc_impulse_normal, axis=-1),
77
+ jnp.expand_dims(manifold.acc_impulse_tangent, axis=-1),
78
+ ],
79
+ axis=-1,
80
+ )
81
+
82
+ return (collision_mask_features * manifold.active[..., None]).flatten()
83
+
84
+ obs = jnp.concatenate(
85
+ [
86
+ polygon_features,
87
+ circle_features,
88
+ joint_features,
89
+ thruster_features,
90
+ jnp.array([state.gravity[1]]) / 10,
91
+ # _get_manifold_features(state.acc_cc_manifolds),
92
+ # _get_manifold_features(state.acc_cr_manifolds),
93
+ # _get_manifold_features(state.acc_rr_manifolds),
94
+ ],
95
+ axis=0,
96
+ )
97
+
98
+ obs = jnp.clip(obs, a_min=-10.0, a_max=10.0)
99
+ obs = jnp.nan_to_num(obs)
100
+ return obs
101
+
102
+ return render_symbolic
kinetix/render/textures.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pathlib
3
+ from enum import Enum
4
+ import jax.numpy as jnp
5
+ import imageio.v3 as iio
6
+ import numpy as np
7
+ from PIL import Image
8
+
9
+
10
+ def load_texture(filename, render_size):
11
+ filename = os.path.join(pathlib.Path(__file__).parent.parent.resolve(), "assets", filename)
12
+ img = iio.imread(filename)
13
+ jnp_img = jnp.array(img).astype(jnp.int32)
14
+
15
+ if jnp_img.shape[2] == 4:
16
+ jnp_img = jnp_img.at[:, :, 3].set(jnp_img[:, :, 3] // 255)
17
+
18
+ img = np.array(jnp_img, dtype=np.uint8)
19
+ image = Image.fromarray(img)
20
+ image = image.resize((render_size, render_size), resample=Image.NEAREST)
21
+ jnp_img = jnp.array(image, dtype=jnp.float32)
22
+
23
+ return jnp_img.transpose((1, 0, 2))
24
+
25
+
26
+ EDIT_TEXTURE_RGBA = load_texture("edit.png", 64)
27
+ PLAY_TEXTURE_RGBA = load_texture("play.png", 64)
28
+
29
+ CIRCLE_TEXTURE_RGBA = load_texture("circle.png", 32)
30
+ RECT_TEXTURE_RGBA = load_texture("square.png", 32)
31
+ TRIANGLE_TEXTURE_RGBA = load_texture("triangle.png", 32)
32
+ RJOINT_TEXTURE_6_RGBA = load_texture("rjoint.png", 6)
33
+ RJOINT_TEXTURE_RGBA = load_texture("rjoint2.png", 32)
34
+
35
+ FJOINT_TEXTURE_6_RGBA = load_texture("fjoint.png", 6)
36
+ FJOINT_TEXTURE_RGBA = load_texture("fjoint2.png", 32)
37
+
38
+
39
+ ROTATION_TEXTURE_RGBA = load_texture("rotate.png", 32)
40
+ SELECT_TEXTURE_RGBA = load_texture("hand.png", 32)
41
+
42
+ THRUSTER_TEXTURE_RGBA = jnp.rot90(load_texture("thruster6.png", 32), k=3)
43
+ THRUSTER_TEXTURE_16_RGBA = jnp.rot90(load_texture("thruster.png", 16), k=3)
kinetix/util/__init__.py ADDED
File without changes
kinetix/util/config.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import datetime
3
+ import gzip
4
+ import json
5
+ import os
6
+ from hashlib import md5
7
+
8
+ import jax
9
+ import jax.numpy as jnp
10
+ import numpy as np
11
+ from numpy import isin
12
+ from kinetix.environment.ued.ued_state import UEDParams
13
+ from omegaconf import OmegaConf
14
+ from pandas import isna
15
+ from typing import List, Tuple
16
+ import wandb
17
+ from kinetix.environment.env_state import EnvParams, StaticEnvParams
18
+ from collections import defaultdict
19
+
20
+ from kinetix.util.saving import load_from_json_file
21
+
22
+
23
+ def get_hash_without_seed(config):
24
+ old_seed = config["seed"]
25
+ config["seed"] = 0
26
+ ans = md5(OmegaConf.to_yaml(config, sort_keys=True).encode()).hexdigest()
27
+ config["seed"] = old_seed
28
+ return ans
29
+
30
+
31
+ def get_date() -> str:
32
+ return datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
33
+
34
+
35
+ def generate_params_from_config(config):
36
+ if config.get("env_size_type", "predefined") == "custom":
37
+ # must load env params from a file
38
+ _, static_env_params, env_params = load_from_json_file(os.path.join("worlds", config["custom_path"]))
39
+ return env_params, static_env_params.replace(
40
+ frame_skip=config["frame_skip"],
41
+ )
42
+ env_params = EnvParams()
43
+
44
+ static_env_params = StaticEnvParams().replace(
45
+ num_polygons=config["num_polygons"],
46
+ num_circles=config["num_circles"],
47
+ num_joints=config["num_joints"],
48
+ num_thrusters=config["num_thrusters"],
49
+ frame_skip=config["frame_skip"],
50
+ num_motor_bindings=config["num_motor_bindings"],
51
+ num_thruster_bindings=config["num_thruster_bindings"],
52
+ )
53
+
54
+ return env_params, static_env_params
55
+
56
+
57
+ def generate_ued_params_from_config(config) -> UEDParams:
58
+ ans = UEDParams()
59
+
60
+ if config["env_size_name"] == "s":
61
+ ans = ans.replace(add_shape_n_proposals=1) # otherwise we get a very weird XLA bug.
62
+ if "fixate_chance_max" in config:
63
+ print("Changing fixate chance max to", config["fixate_chance_max"])
64
+ ans = ans.replace(fixate_chance_max=config["fixate_chance_max"])
65
+ return ans
66
+
67
+
68
+ def get_eval_level_groups(eval_levels: List[str]) -> List[Tuple[str, str]]:
69
+ def get_groups(s):
70
+ # This is the size group
71
+ group_one = s.split("/")[0]
72
+ group_two = s.split("/")[1].split("_")[0]
73
+ group_two = "".join([i for i in group_two if not i.isdigit()])
74
+ if group_two == "h":
75
+ group_two = "handmade"
76
+ if group_two == "r":
77
+ group_two = "random"
78
+ return f"{group_one}_all", f"{group_one}_{group_two}"
79
+
80
+ indices = defaultdict(list)
81
+
82
+ for idx, s in enumerate(eval_levels):
83
+ groups = get_groups(s)
84
+ for group in groups:
85
+ indices[group].append(idx)
86
+
87
+ indices2 = {}
88
+ for g in indices:
89
+ indices2[g] = np.array(indices[g])
90
+
91
+ return indices2
92
+
93
+
94
+ def normalise_config(config, name, editor_config=False):
95
+ old_config = copy.deepcopy(config)
96
+ keys = ["env", "learning", "model", "misc", "eval", "ued", "env_size", "train_levels"]
97
+ for k in keys:
98
+ if k not in config:
99
+ config[k] = {}
100
+ small_d = config[k]
101
+ del config[k]
102
+ for kk, vv in small_d.items():
103
+ assert kk not in config, kk
104
+ config[kk] = vv
105
+
106
+ if not editor_config:
107
+ config["eval_env_size_true"] = config["eval_env_size"]
108
+ if config["num_train_envs"] == 2048 and "Pixels" in config["env_name"]:
109
+ config["num_train_envs"] = 512
110
+ if "SFL" in name and config["env_size_name"] in ["m", "l"]:
111
+ config["eval_num_attempts"] = 6 # to avoid a very weird XLA bug.
112
+ config["hash"] = get_hash_without_seed(config)
113
+
114
+ config["random_hash"] = np.random.randint(2**31)
115
+
116
+ config["log_save_path"] = f"logs/{config['hash']}/{config['seed']}-{get_date()}"
117
+ os.makedirs(config["log_save_path"], exist_ok=True)
118
+ with open(f"{config['log_save_path']}/config.yaml", "w") as f:
119
+ f.write(OmegaConf.to_yaml(old_config))
120
+ if config["group"] == "auto":
121
+ config["group"] = f"{name}-" + config["group_auto_prefix"] + config["env_name"].replace("Kinetix-", "")
122
+ config["group"] += "-" + str(config["env_size_name"])
123
+
124
+ if config["eval_levels"] == ["auto"] or config["eval_levels"] == "auto":
125
+ config["eval_levels"] = config["train_levels_list"]
126
+ print("Using Auto eval levels:", config["eval_levels"])
127
+ config["num_eval_levels"] = len(config["eval_levels"])
128
+
129
+ steps = (
130
+ config["num_steps"]
131
+ * config.get("outer_rollout_steps", 1)
132
+ * config["num_train_envs"]
133
+ * (2 if name == "PAIRED" else 1)
134
+ )
135
+ config["num_updates"] = int(config["total_timesteps"]) // steps
136
+
137
+ nsteps = int(config["total_timesteps"] // 1e6)
138
+ letter = "M"
139
+ if nsteps >= 1000:
140
+ nsteps = nsteps // 1000
141
+ letter = "B"
142
+ config["run_name"] = (
143
+ config["env_name"] + f"-{name}-" + str(nsteps) + letter + "-" + str(config["num_train_envs"])
144
+ )
145
+
146
+ if config["checkpoint_save_freq"] >= config["num_updates"]:
147
+ config["checkpoint_save_freq"] = config["num_updates"]
148
+ return config
149
+
150
+
151
+ def get_tags(config, name):
152
+ return [name]
153
+ tags = [name]
154
+ if name in ["PLR", "ACCEL", "DR"]:
155
+ if config["use_accel"]:
156
+ tags.append("ACCEL")
157
+ else:
158
+ tags.append("PLR")
159
+ return tags
160
+
161
+
162
+ def init_wandb(config, name) -> wandb.run:
163
+ run = wandb.init(
164
+ config=config,
165
+ project=config["wandb_project"],
166
+ group=config["group"],
167
+ name=config["run_name"],
168
+ entity=config["wandb_entity"],
169
+ mode=config["wandb_mode"],
170
+ tags=get_tags(config, name),
171
+ )
172
+ wandb.define_metric("timing/num_updates")
173
+ wandb.define_metric("timing/num_env_steps")
174
+ wandb.define_metric("*", step_metric="timing/num_env_steps")
175
+ wandb.define_metric("timing/sps", step_metric="timing/num_env_steps")
176
+ return run
177
+
178
+
179
+ def save_data_to_local_file(data_to_save, config):
180
+ if not config.get("save_local_data", False):
181
+ return
182
+
183
+ def reverse_in(li, value):
184
+ for i, v in enumerate(li):
185
+ if v in value:
186
+ return True
187
+ return False
188
+
189
+ clean_data = {k: v for k, v in data_to_save.items() if not reverse_in(["media/", "images/"], k)}
190
+
191
+ def _clean(x):
192
+ if isinstance(x, jnp.ndarray):
193
+ return x.tolist()
194
+ elif isinstance(x, jnp.float32):
195
+ if jnp.isnan(x):
196
+ return -float("inf")
197
+ return round(float(x) * 1000) / 1000
198
+ elif isinstance(x, jnp.int32):
199
+ return int(x)
200
+ return x
201
+
202
+ clean_data = jax.tree_map(lambda x: _clean(x), clean_data)
203
+ print("Saving this data:", clean_data)
204
+ with open(f"{config['log_save_path']}/data.jsonl", "a+") as f:
205
+ f.write(json.dumps(clean_data) + "\n")
206
+
207
+
208
+ def compress_log_files_after_run(config):
209
+ fpath = f"{config['log_save_path']}/data.jsonl"
210
+ with open(fpath, "rb") as f_in, gzip.open(fpath + ".gz", "wb") as f_out:
211
+ f_out.writelines(f_in)
212
+
213
+
214
+ def get_video_frequency(config, update_step):
215
+ frac_through_training = update_step / config["num_updates"]
216
+ vid_frequency = (
217
+ config["eval_freq"]
218
+ * config["video_frequency"]
219
+ * jax.lax.select(
220
+ (0.1 <= frac_through_training) & (frac_through_training < 0.3),
221
+ 1,
222
+ jax.lax.select(
223
+ (0.3 <= frac_through_training) & (frac_through_training < 0.6),
224
+ 2,
225
+ 4,
226
+ ),
227
+ )
228
+ )
229
+ return vid_frequency
kinetix/util/learning.py ADDED
@@ -0,0 +1,565 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ import json
3
+ import os
4
+ import re
5
+ import time
6
+ from enum import IntEnum
7
+ from typing import Tuple
8
+
9
+ import chex
10
+ import jax
11
+ import jax.numpy as jnp
12
+ import numpy as np
13
+ import optax
14
+ import orbax.checkpoint as ocp
15
+ from flax import core, struct
16
+ from flax.training.train_state import TrainState as BaseTrainState
17
+
18
+ import wandb
19
+ from jaxued.environments.underspecified_env import EnvParams, EnvState, Observation, UnderspecifiedEnv
20
+ from jaxued.level_sampler import LevelSampler
21
+ from jaxued.utils import compute_max_returns, max_mc, positive_value_loss
22
+
23
+ from kinetix.environment.env import PixelObservations, make_kinetix_env_from_name
24
+ from kinetix.environment.env_state import StaticEnvParams
25
+ from kinetix.environment.utils import permute_pcg_state
26
+ from kinetix.environment.wrappers import (
27
+ UnderspecifiedToGymnaxWrapper,
28
+ LogWrapper,
29
+ DenseRewardWrapper,
30
+ AutoReplayWrapper,
31
+ )
32
+ from kinetix.models import make_network_from_config
33
+ from kinetix.pcg.pcg import env_state_to_pcg_state
34
+ from kinetix.render.renderer_pixels import make_render_pixels
35
+ from kinetix.models.actor_critic import ScannedRNN
36
+ from kinetix.util.saving import (
37
+ expand_pcg_state,
38
+ get_pcg_state_from_json,
39
+ load_pcg_state_pickle,
40
+ load_world_state_pickle,
41
+ stack_list_of_pytrees,
42
+ import_env_state_from_json,
43
+ load_from_json_file,
44
+ )
45
+ from flax.training.train_state import TrainState
46
+
47
+ BASE_DIR = "worlds"
48
+
49
+ DEFAULT_EVAL_LEVELS = [
50
+ "easy.cartpole",
51
+ "easy.flappy_bird",
52
+ "easy.unicycle",
53
+ "easy.car_left",
54
+ "easy.car_right",
55
+ "easy.pinball",
56
+ "easy.swing_up",
57
+ "easy.thruster",
58
+ ]
59
+
60
+
61
+ def get_eval_levels(eval_levels, static_env_params):
62
+ should_permute = [".permute" in l for l in eval_levels]
63
+ eval_levels = [re.sub(r"\.permute\d+", "", l) for l in eval_levels]
64
+ ls = [get_pcg_state_from_json(os.path.join(BASE_DIR, l + ("" if l.endswith(".json") else ".json"))) for l in eval_levels]
65
+ ls = [expand_pcg_state(l, static_env_params) for l in ls]
66
+ new_ls = []
67
+ rng = jax.random.PRNGKey(0)
68
+ for sp, l in zip(should_permute, ls):
69
+ rng, _rng = jax.random.split(rng)
70
+ if sp:
71
+ l = permute_pcg_state(_rng, l, static_env_params)
72
+ new_ls.append(l)
73
+ return stack_list_of_pytrees(new_ls)
74
+
75
+
76
+ def evaluate_rnn( # from jaxued
77
+ rng: chex.PRNGKey,
78
+ env: UnderspecifiedEnv,
79
+ env_params: EnvParams,
80
+ train_state: TrainState,
81
+ init_hstate: chex.ArrayTree,
82
+ init_obs: Observation,
83
+ init_env_state: EnvState,
84
+ max_episode_length: int,
85
+ keep_states=True,
86
+ return_trajectories=False,
87
+ ) -> Tuple[chex.Array, chex.Array, chex.Array]:
88
+ """This runs the RNN on the environment, given an initial state and observation, and returns (states, rewards, episode_lengths)
89
+
90
+ Args:
91
+ rng (chex.PRNGKey):
92
+ env (UnderspecifiedEnv):
93
+ env_params (EnvParams):
94
+ train_state (TrainState):
95
+ init_hstate (chex.ArrayTree): Shape (num_levels, )
96
+ init_obs (Observation): Shape (num_levels, )
97
+ init_env_state (EnvState): Shape (num_levels, )
98
+ max_episode_length (int):
99
+
100
+ Returns:
101
+ Tuple[chex.Array, chex.Array, chex.Array]: (States, rewards, episode lengths) ((NUM_STEPS, NUM_LEVELS), (NUM_STEPS, NUM_LEVELS), (NUM_LEVELS,)
102
+ """
103
+ num_levels = jax.tree_util.tree_flatten(init_obs)[0][0].shape[0]
104
+
105
+ def step(carry, _):
106
+ rng, hstate, obs, state, done, mask, episode_length = carry
107
+ rng, rng_action, rng_step = jax.random.split(rng, 3)
108
+
109
+ x = jax.tree.map(lambda x: x[None, ...], (obs, done))
110
+ hstate, pi, _ = train_state.apply_fn(train_state.params, hstate, x)
111
+ action = pi.sample(seed=rng_action).squeeze(0)
112
+
113
+ obs, next_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
114
+ jax.random.split(rng_step, num_levels), state, action, env_params
115
+ )
116
+
117
+ next_mask = mask & ~done
118
+ episode_length += mask
119
+
120
+ if keep_states:
121
+ return (rng, hstate, obs, next_state, done, next_mask, episode_length), (state, reward, done, info)
122
+ else:
123
+ return (rng, hstate, obs, next_state, done, next_mask, episode_length), (None, reward, done, info)
124
+
125
+ (_, _, _, _, _, _, episode_lengths), (states, rewards, dones, infos) = jax.lax.scan(
126
+ step,
127
+ (
128
+ rng,
129
+ init_hstate,
130
+ init_obs,
131
+ init_env_state,
132
+ jnp.zeros(num_levels, dtype=bool),
133
+ jnp.ones(num_levels, dtype=bool),
134
+ jnp.zeros(num_levels, dtype=jnp.int32),
135
+ ),
136
+ None,
137
+ length=max_episode_length,
138
+ )
139
+ done_idx = jnp.argmax(dones, axis=0)
140
+
141
+ to_return = (states, rewards, done_idx, episode_lengths, infos)
142
+ if return_trajectories:
143
+ return to_return, (dones, rewards)
144
+ return to_return
145
+
146
+
147
+ def general_eval(
148
+ rng: chex.PRNGKey,
149
+ eval_env: UnderspecifiedEnv,
150
+ env_params: EnvParams,
151
+ train_state: TrainState,
152
+ levels: EnvState,
153
+ num_eval_steps: int,
154
+ num_levels: int,
155
+ keep_states=True,
156
+ return_trajectories=False,
157
+ ):
158
+ """
159
+ This evaluates the current policy on the set of evaluation levels
160
+ It returns (states, cum_rewards, episode_lengths), with shapes (num_steps, num_eval_levels, ...), (num_eval_levels,), (num_eval_levels,)
161
+ """
162
+ rng, rng_reset = jax.random.split(rng)
163
+ init_obs, init_env_state = jax.vmap(eval_env.reset_to_level, (0, 0, None))(
164
+ jax.random.split(rng_reset, num_levels), levels, env_params
165
+ )
166
+ init_hstate = ScannedRNN.initialize_carry(num_levels)
167
+ (states, rewards, done_idx, episode_lengths, infos), (dones, reward) = evaluate_rnn(
168
+ rng,
169
+ eval_env,
170
+ env_params,
171
+ train_state,
172
+ init_hstate,
173
+ init_obs,
174
+ init_env_state,
175
+ num_eval_steps,
176
+ keep_states=keep_states,
177
+ return_trajectories=True,
178
+ )
179
+ mask = jnp.arange(num_eval_steps)[..., None] < episode_lengths
180
+ cum_rewards = (rewards * mask).sum(axis=0)
181
+ to_return = (
182
+ states,
183
+ cum_rewards,
184
+ done_idx,
185
+ episode_lengths,
186
+ infos,
187
+ ) # (num_steps, num_eval_levels, ...), (num_eval_levels,), (num_eval_levels,)
188
+
189
+ if return_trajectories:
190
+ return to_return, (dones, reward)
191
+ return to_return
192
+
193
+
194
+ def compute_gae(
195
+ gamma: float,
196
+ lambd: float,
197
+ last_value: chex.Array,
198
+ values: chex.Array,
199
+ rewards: chex.Array,
200
+ dones: chex.Array,
201
+ ) -> Tuple[chex.Array, chex.Array]:
202
+ """This takes in arrays of shape (NUM_STEPS, NUM_ENVS) and returns the advantages and targets.
203
+
204
+ Args:
205
+ gamma (float):
206
+ lambd (float):
207
+ last_value (chex.Array): Shape (NUM_ENVS)
208
+ values (chex.Array): Shape (NUM_STEPS, NUM_ENVS)
209
+ rewards (chex.Array): Shape (NUM_STEPS, NUM_ENVS)
210
+ dones (chex.Array): Shape (NUM_STEPS, NUM_ENVS)
211
+
212
+ Returns:
213
+ Tuple[chex.Array, chex.Array]: advantages, targets; each of shape (NUM_STEPS, NUM_ENVS)
214
+ """
215
+
216
+ def compute_gae_at_timestep(carry, x):
217
+ gae, next_value = carry
218
+ value, reward, done = x
219
+ delta = reward + gamma * next_value * (1 - done) - value
220
+ gae = delta + gamma * lambd * (1 - done) * gae
221
+ return (gae, value), gae
222
+
223
+ _, advantages = jax.lax.scan(
224
+ compute_gae_at_timestep,
225
+ (jnp.zeros_like(last_value), last_value),
226
+ (values, rewards, dones),
227
+ reverse=True,
228
+ unroll=16,
229
+ )
230
+ return advantages, advantages + values
231
+
232
+
233
+ def sample_trajectories_rnn(
234
+ rng: chex.PRNGKey,
235
+ env: UnderspecifiedEnv,
236
+ env_params: EnvParams,
237
+ train_state: TrainState,
238
+ init_hstate: chex.ArrayTree,
239
+ init_obs: Observation,
240
+ init_env_state: EnvState,
241
+ num_envs: int,
242
+ max_episode_length: int,
243
+ return_states: bool = False,
244
+ ) -> Tuple[
245
+ Tuple[chex.PRNGKey, TrainState, chex.ArrayTree, Observation, EnvState, chex.Array],
246
+ Tuple[Observation, chex.Array, chex.Array, chex.Array, chex.Array, chex.Array, dict],
247
+ ]:
248
+ """This samples trajectories from the environment using the agent specified by the `train_state`.
249
+
250
+ Args:
251
+
252
+ rng (chex.PRNGKey): Singleton
253
+ env (UnderspecifiedEnv):
254
+ env_params (EnvParams):
255
+ train_state (TrainState): Singleton
256
+ init_hstate (chex.ArrayTree): This is the init RNN hidden state, has to have shape (NUM_ENVS, ...)
257
+ init_obs (Observation): The initial observation, shape (NUM_ENVS, ...)
258
+ init_env_state (EnvState): The initial env state (NUM_ENVS, ...)
259
+ num_envs (int): The number of envs that are vmapped over.
260
+ max_episode_length (int): The maximum episode length, i.e., the number of steps to do the rollouts for.
261
+
262
+ Returns:
263
+ Tuple[Tuple[chex.PRNGKey, TrainState, chex.ArrayTree, Observation, EnvState, chex.Array], Tuple[Observation, chex.Array, chex.Array, chex.Array, chex.Array, chex.Array, dict]]: (rng, train_state, hstate, last_obs, last_env_state, last_value), traj, where traj is (obs, action, reward, done, log_prob, value, info). The first element in the tuple consists of arrays that have shapes (NUM_ENVS, ...) (except `rng` and and `train_state` which are singleton). The second element in the tuple is of shape (NUM_STEPS, NUM_ENVS, ...), and it contains the trajectory.
264
+ """
265
+
266
+ def sample_step(carry, _):
267
+ rng, train_state, hstate, obs, env_state, last_done = carry
268
+ prev_state = env_state
269
+ rng, rng_action, rng_step = jax.random.split(rng, 3)
270
+
271
+ x = jax.tree.map(lambda x: x[None, ...], (obs, last_done))
272
+ hstate, pi, value = train_state.apply_fn(train_state.params, hstate, x)
273
+ action = pi.sample(seed=rng_action)
274
+ log_prob = pi.log_prob(action)
275
+ value, action, log_prob = jax.tree.map(lambda x: x.squeeze(0), (value, action, log_prob))
276
+
277
+ next_obs, env_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
278
+ jax.random.split(rng_step, num_envs), env_state, action, env_params
279
+ )
280
+
281
+ carry = (rng, train_state, hstate, next_obs, env_state, done)
282
+ step = (obs, action, reward, done, log_prob, value, info)
283
+ if return_states:
284
+ step += (prev_state,)
285
+ return carry, step
286
+
287
+ (rng, train_state, hstate, last_obs, last_env_state, last_done), traj = jax.lax.scan(
288
+ sample_step,
289
+ (
290
+ rng,
291
+ train_state,
292
+ init_hstate,
293
+ init_obs,
294
+ init_env_state,
295
+ jnp.zeros(num_envs, dtype=bool),
296
+ ),
297
+ None,
298
+ length=max_episode_length,
299
+ )
300
+
301
+ x = jax.tree.map(lambda x: x[None, ...], (last_obs, last_done))
302
+ _, _, last_value = train_state.apply_fn(train_state.params, hstate, x)
303
+
304
+ my_obs = traj[0]
305
+ rew = traj[2]
306
+
307
+ return (rng, train_state, hstate, last_obs, last_env_state, last_value.squeeze(0)), traj
308
+
309
+
310
+ def update_actor_critic_rnn(
311
+ rng: chex.PRNGKey,
312
+ train_state: TrainState,
313
+ init_hstate: chex.ArrayTree,
314
+ batch: chex.ArrayTree,
315
+ num_envs: int,
316
+ n_steps: int,
317
+ n_minibatch: int,
318
+ n_epochs: int,
319
+ clip_eps: float,
320
+ entropy_coeff: float,
321
+ critic_coeff: float,
322
+ update_grad: bool = True,
323
+ ) -> Tuple[Tuple[chex.PRNGKey, TrainState], chex.ArrayTree]:
324
+ """This function takes in a rollout, and PPO hyperparameters, and updates the train state.
325
+
326
+ Args:
327
+ rng (chex.PRNGKey):
328
+ train_state (TrainState):
329
+ init_hstate (chex.ArrayTree):
330
+ batch (chex.ArrayTree): obs, actions, dones, log_probs, values, targets, advantages
331
+ num_envs (int):
332
+ n_steps (int):
333
+ n_minibatch (int):
334
+ n_epochs (int):
335
+ clip_eps (float):
336
+ entropy_coeff (float):
337
+ critic_coeff (float):
338
+ update_grad (bool, optional): If False, the train state does not actually get updated. Defaults to True.
339
+
340
+ Returns:
341
+ Tuple[Tuple[chex.PRNGKey, TrainState], chex.ArrayTree]: It returns a new rng, the updated train_state, and the losses. The losses have structure (loss, (l_vf, l_clip, entropy))
342
+ """
343
+ obs, actions, dones, log_probs, values, targets, advantages = batch
344
+ last_dones = jnp.roll(dones, 1, axis=0).at[0].set(False)
345
+ batch = obs, actions, last_dones, log_probs, values, targets, advantages
346
+
347
+ def update_epoch(carry, _):
348
+ def update_minibatch(train_state, minibatch):
349
+ init_hstate, obs, actions, last_dones, log_probs, values, targets, advantages = minibatch
350
+
351
+ def loss_fn(params):
352
+ _, pi, values_pred = train_state.apply_fn(params, init_hstate, (obs, last_dones))
353
+ log_probs_pred = pi.log_prob(actions)
354
+ entropy = pi.entropy().mean()
355
+
356
+ ratio = jnp.exp(log_probs_pred - log_probs)
357
+ A = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
358
+ l_clip = (-jnp.minimum(ratio * A, jnp.clip(ratio, 1 - clip_eps, 1 + clip_eps) * A)).mean()
359
+
360
+ values_pred_clipped = values + (values_pred - values).clip(-clip_eps, clip_eps)
361
+ l_vf = 0.5 * jnp.maximum((values_pred - targets) ** 2, (values_pred_clipped - targets) ** 2).mean()
362
+
363
+ loss = l_clip + critic_coeff * l_vf - entropy_coeff * entropy
364
+
365
+ return loss, (l_vf, l_clip, entropy)
366
+
367
+ grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
368
+ loss, grads = grad_fn(train_state.params)
369
+ if update_grad:
370
+ train_state = train_state.apply_gradients(grads=grads)
371
+ grad_norm = jnp.linalg.norm(
372
+ jnp.concatenate(jax.tree_util.tree_map(lambda x: x.flatten(), jax.tree_util.tree_flatten(grads)[0]))
373
+ )
374
+ return train_state, (loss, grad_norm)
375
+
376
+ rng, train_state = carry
377
+ rng, rng_perm = jax.random.split(rng)
378
+ permutation = jax.random.permutation(rng_perm, num_envs)
379
+ minibatches = (
380
+ jax.tree.map(
381
+ lambda x: jnp.take(x, permutation, axis=0).reshape(n_minibatch, -1, *x.shape[1:]),
382
+ init_hstate,
383
+ ),
384
+ *jax.tree.map(
385
+ lambda x: jnp.take(x, permutation, axis=1)
386
+ .reshape(x.shape[0], n_minibatch, -1, *x.shape[2:])
387
+ .swapaxes(0, 1),
388
+ batch,
389
+ ),
390
+ )
391
+ train_state, (losses, grads) = jax.lax.scan(update_minibatch, train_state, minibatches)
392
+ return (rng, train_state), (losses, grads)
393
+
394
+ return jax.lax.scan(update_epoch, (rng, train_state), None, n_epochs)
395
+
396
+
397
+ @partial(jax.jit, static_argnums=(0, 2, 8, 9))
398
+ def sample_trajectories_and_learn(
399
+ env: UnderspecifiedEnv,
400
+ env_params: EnvParams,
401
+ config: dict,
402
+ rng: chex.PRNGKey,
403
+ train_state: TrainState,
404
+ init_hstate: chex.Array,
405
+ init_obs: Observation,
406
+ init_env_state: EnvState,
407
+ update_grad: bool = True,
408
+ return_states: bool = False,
409
+ ) -> Tuple[
410
+ Tuple[chex.PRNGKey, TrainState, Observation, EnvState],
411
+ Tuple[
412
+ Observation,
413
+ chex.Array,
414
+ chex.Array,
415
+ chex.Array,
416
+ chex.Array,
417
+ chex.Array,
418
+ dict,
419
+ chex.Array,
420
+ chex.Array,
421
+ chex.ArrayTree,
422
+ chex.Array,
423
+ ],
424
+ ]:
425
+ """This function loops the following:
426
+ - rollout for config['num_steps']
427
+ - learn / update policy
428
+
429
+ And it loops it for config['outer_rollout_steps'].
430
+ What is returns is a new carry (rng, train_state, init_obs, init_env_state), and concatenated rollouts. The shape of the rollouts are config['num_steps'] * config['outer_rollout_steps']. In other words, the trajectories returned by this function are the same as if we ran rollouts for config['num_steps'] * config['outer_rollout_steps'] steps, but the agent does perform PPO updates in between.
431
+
432
+ Args:
433
+ env (UnderspecifiedEnv):
434
+ env_params (EnvParams):
435
+ config (dict):
436
+ rng (chex.PRNGKey):
437
+ train_state (TrainState):
438
+ init_obs (Observation):
439
+ init_env_state (EnvState):
440
+ update_grad (bool, optional): Defaults to True.
441
+
442
+ Returns:
443
+ Tuple[Tuple[chex.PRNGKey, TrainState, Observation, EnvState], Tuple[Observation, chex.Array, chex.Array, chex.Array, chex.Array, chex.Array, dict, chex.Array, chex.Array, chex.ArrayTree, chex.Array]]: This returns a tuple:
444
+ (
445
+ (rng, train_state, init_obs, init_env_state),
446
+ (obs, actions, rewards, dones, log_probs, values, info, advantages, targets, losses, grads)
447
+ )
448
+ """
449
+
450
+ def single_step(carry, _):
451
+ rng, train_state, init_hstate, init_obs, init_env_state = carry
452
+ ((rng, train_state, new_hstate, last_obs, last_env_state, last_value), traj,) = sample_trajectories_rnn(
453
+ rng,
454
+ env,
455
+ env_params,
456
+ train_state,
457
+ init_hstate,
458
+ init_obs,
459
+ init_env_state,
460
+ config["num_train_envs"],
461
+ config["num_steps"],
462
+ return_states=return_states,
463
+ )
464
+ if return_states:
465
+ states = traj[-1]
466
+ traj = traj[:-1]
467
+
468
+ (obs, actions, rewards, dones, log_probs, values, info) = traj
469
+ advantages, targets = compute_gae(config["gamma"], config["gae_lambda"], last_value, values, rewards, dones)
470
+
471
+ # Update the policy using trajectories collected from replay levels
472
+ (rng, train_state), (losses, grads) = update_actor_critic_rnn(
473
+ rng,
474
+ train_state,
475
+ init_hstate,
476
+ (obs, actions, dones, log_probs, values, targets, advantages),
477
+ config["num_train_envs"],
478
+ config["num_steps"],
479
+ config["num_minibatches"],
480
+ config["update_epochs"],
481
+ config["clip_eps"],
482
+ config["ent_coef"],
483
+ config["vf_coef"],
484
+ update_grad=update_grad,
485
+ )
486
+ new_carry = (rng, train_state, new_hstate, last_obs, last_env_state)
487
+ step = (obs, actions, rewards, dones, log_probs, values, info, advantages, targets, losses, grads)
488
+ if return_states:
489
+ step += (states,)
490
+ return new_carry, step
491
+
492
+ carry = (rng, train_state, init_hstate, init_obs, init_env_state)
493
+ new_carry, all_rollouts = jax.lax.scan(single_step, carry, None, length=config["outer_rollout_steps"])
494
+
495
+ all_rollouts = jax.tree_util.tree_map(lambda x: jnp.concatenate(x, axis=0), all_rollouts)
496
+ return new_carry, all_rollouts
497
+
498
+
499
+ def no_op_rollout(
500
+ env: UnderspecifiedEnv,
501
+ env_params: EnvParams,
502
+ rng: chex.PRNGKey,
503
+ init_obs: Observation,
504
+ init_env_state: EnvState,
505
+ num_envs: int,
506
+ max_episode_length: int,
507
+ do_random=False,
508
+ ):
509
+
510
+ noop = jnp.array(env.action_type.noop_action())
511
+ zero_action = jnp.repeat(noop[None, ...], num_envs, axis=0)
512
+ SHAPE = zero_action.shape
513
+
514
+ def sample_step(carry, _):
515
+ rng, obs, env_state, last_done = carry
516
+ rng, rng_step, _rng = jax.random.split(rng, 3)
517
+ if do_random:
518
+ action = jax.vmap(env.action_space(env_params).sample)(jax.random.split(_rng, num_envs))
519
+ else:
520
+ action = zero_action
521
+
522
+ next_obs, env_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
523
+ jax.random.split(rng_step, num_envs), env_state, action, env_params
524
+ )
525
+
526
+ carry = (rng, next_obs, env_state, done)
527
+ return carry, (obs, action, reward, done, info)
528
+
529
+ (rng, last_obs, last_env_state, last_done), traj = jax.lax.scan(
530
+ sample_step,
531
+ (
532
+ rng,
533
+ init_obs,
534
+ init_env_state,
535
+ jnp.zeros(num_envs, dtype=bool),
536
+ ),
537
+ None,
538
+ length=max_episode_length,
539
+ )
540
+
541
+ info = traj[-1]
542
+ dones = traj[-2]
543
+
544
+ returns_per_env = (info["returned_episode_returns"] * dones).sum(axis=0) / jnp.maximum(1, dones.sum(axis=0))
545
+ lens_per_env = (info["returned_episode_lengths"] * dones).sum(axis=0) / jnp.maximum(1, dones.sum(axis=0))
546
+ success_per_env = (info["returned_episode_solved"] * dones).sum(axis=0) / jnp.maximum(1, dones.sum(axis=0))
547
+ return returns_per_env, lens_per_env, success_per_env
548
+
549
+
550
+ def no_op_and_random_rollout(
551
+ env: UnderspecifiedEnv,
552
+ env_params: EnvParams,
553
+ rng: chex.PRNGKey,
554
+ init_obs: Observation,
555
+ init_env_state: EnvState,
556
+ num_envs: int,
557
+ max_episode_length: int,
558
+ ):
559
+ returns_noop, lens_noop, success_noop = no_op_rollout(
560
+ env, env_params, rng, init_obs, init_env_state, num_envs, max_episode_length, do_random=False
561
+ )
562
+ returns_random, lens_random, success_random = no_op_rollout(
563
+ env, env_params, rng, init_obs, init_env_state, num_envs, max_episode_length, do_random=True
564
+ )
565
+ return returns_noop, lens_noop, success_noop, returns_random, lens_random, success_random
kinetix/util/saving.py ADDED
@@ -0,0 +1,540 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import pickle
4
+ from typing import Any, Dict, Union
5
+
6
+ import flax.serialization
7
+ import flax.serialization
8
+ import flax.serialization
9
+ import flax.serialization
10
+ import flax.serialization
11
+ import flax.serialization
12
+ import flax.serialization
13
+ import jax
14
+ import jax.numpy as jnp
15
+ import flax
16
+ import wandb
17
+ from jax2d.engine import (
18
+ calculate_collision_matrix,
19
+ get_empty_collision_manifolds,
20
+ get_pairwise_interaction_indices,
21
+ recalculate_mass_and_inertia,
22
+ )
23
+ from jax2d.sim_state import RigidBody, SimState
24
+ from kinetix.environment.env_state import EnvState, StaticEnvParams, EnvParams
25
+
26
+ from flax.traverse_util import flatten_dict, unflatten_dict
27
+
28
+ from safetensors.flax import save_file, load_file
29
+
30
+ from kinetix.pcg.pcg import env_state_to_pcg_state
31
+ from kinetix.pcg.pcg_state import PCGState
32
+ import bz2
33
+
34
+
35
+ def check_if_mass_and_inertia_are_correct(state: SimState, env_params: EnvParams, static_params):
36
+ new = recalculate_mass_and_inertia(state, static_params, state.polygon_densities, state.circle_densities)
37
+
38
+ def _check(a, b, shape, name):
39
+ a = jnp.where(shape.active, a, jnp.zeros_like(a))
40
+ b = jnp.where(shape.active, b, jnp.zeros_like(b))
41
+
42
+ if not jnp.allclose(a, b):
43
+ idxs = jnp.arange(len(shape.active))[(a != b) & shape.active]
44
+ new_one = a[idxs]
45
+ old_one = b[idxs]
46
+ raise ValueError(
47
+ f"Error: {name} is not the same after loading. Indexes {idxs} are incorrect. New = {new_one} | Before = {old_one}"
48
+ )
49
+
50
+ _check(new.polygon.inverse_mass, state.polygon.inverse_mass, state.polygon, "Polygon inverse mass")
51
+ _check(new.circle.inverse_mass, state.circle.inverse_mass, state.circle, "Circle inverse mass")
52
+ _check(new.polygon.inverse_inertia, state.polygon.inverse_inertia, state.polygon, "Polygon inverse inertia")
53
+ _check(new.circle.inverse_inertia, state.circle.inverse_inertia, state.circle, "Circle inverse inertia")
54
+ return True
55
+
56
+
57
+ def save_pickle(filename, state):
58
+ with open(filename, "wb") as f:
59
+ pickle.dump(state, f)
60
+
61
+
62
+ def load_pcg_state_pickle(filename):
63
+ with open(filename, "rb") as f:
64
+ return pickle.load(f)
65
+
66
+
67
+ def expand_env_state(env_state: EnvState, static_env_params: StaticEnvParams, ignore_collision_matrix=False):
68
+
69
+ num_rects = len(env_state.polygon.position)
70
+ num_circles = len(env_state.circle.position)
71
+ num_joints = len(env_state.joint.a_index)
72
+ num_thrusters = len(env_state.thruster.object_index)
73
+
74
+ def _add_dummy(num_to_add, obj):
75
+ return jax.tree_map(
76
+ lambda current: jnp.concatenate(
77
+ [current, jnp.zeros((num_to_add, *current.shape[1:]), dtype=current.dtype)], axis=0
78
+ ),
79
+ obj,
80
+ )
81
+
82
+ does_need_to_change = False
83
+ added_rects = 0
84
+
85
+ if (
86
+ num_rects > static_env_params.num_polygons
87
+ or num_circles > static_env_params.num_circles
88
+ or num_joints > static_env_params.num_joints
89
+ ):
90
+ raise Exception(
91
+ f"The current static_env_params is too small to accommodate the loaded env_state (needs num_rects={num_rects}, num_circles={num_circles}, num_joints={num_joints} but current is {static_env_params.num_polygons}, {static_env_params.num_circles}, {static_env_params.num_joints})."
92
+ )
93
+
94
+ if num_rects < static_env_params.num_polygons:
95
+ added_rects = static_env_params.num_polygons - num_rects
96
+ does_need_to_change = True
97
+ env_state = env_state.replace(
98
+ polygon=_add_dummy(added_rects, env_state.polygon),
99
+ polygon_shape_roles=_add_dummy(added_rects, env_state.polygon_shape_roles),
100
+ polygon_highlighted=_add_dummy(added_rects, env_state.polygon_highlighted),
101
+ polygon_densities=_add_dummy(added_rects, env_state.polygon_densities),
102
+ )
103
+
104
+ if num_circles < static_env_params.num_circles:
105
+ does_need_to_change = True
106
+ n_to_add = static_env_params.num_circles - num_circles
107
+ env_state = env_state.replace(
108
+ circle=_add_dummy(n_to_add, env_state.circle),
109
+ circle_shape_roles=_add_dummy(n_to_add, env_state.circle_shape_roles),
110
+ circle_highlighted=_add_dummy(n_to_add, env_state.circle_highlighted),
111
+ circle_densities=_add_dummy(n_to_add, env_state.circle_densities),
112
+ )
113
+
114
+ if num_joints < static_env_params.num_joints:
115
+ does_need_to_change = True
116
+ n_to_add = static_env_params.num_joints - num_joints
117
+ env_state = env_state.replace(
118
+ joint=_add_dummy(n_to_add, env_state.joint),
119
+ motor_bindings=_add_dummy(n_to_add, env_state.motor_bindings),
120
+ motor_auto=_add_dummy(n_to_add, env_state.motor_auto),
121
+ )
122
+
123
+ if num_thrusters < static_env_params.num_thrusters:
124
+ does_need_to_change = True
125
+ n_to_add = static_env_params.num_thrusters - num_thrusters
126
+ env_state = env_state.replace(
127
+ thruster=_add_dummy(n_to_add, env_state.thruster),
128
+ thruster_bindings=_add_dummy(n_to_add, env_state.thruster_bindings),
129
+ )
130
+
131
+ # This fixes the indices
132
+ def _modify_index(old_indices):
133
+ return jnp.where(old_indices >= num_rects, old_indices + added_rects, old_indices)
134
+
135
+ if added_rects > 0:
136
+ env_state = env_state.replace(
137
+ joint=env_state.joint.replace(
138
+ a_index=_modify_index(env_state.joint.a_index),
139
+ b_index=_modify_index(env_state.joint.b_index),
140
+ ),
141
+ thruster=env_state.thruster.replace(
142
+ object_index=_modify_index(env_state.thruster.object_index),
143
+ ),
144
+ )
145
+ # Double check the collision manifolds are fine
146
+ if does_need_to_change or 1:
147
+ # print("Loading but changing the shapes to match the current static params.")
148
+ acc_rr_manifolds, acc_cr_manifolds, acc_cc_manifolds = get_empty_collision_manifolds(static_env_params)
149
+ env_state = env_state.replace(
150
+ collision_matrix=(
151
+ env_state.collision_matrix
152
+ if ignore_collision_matrix
153
+ else calculate_collision_matrix(static_env_params, env_state.joint)
154
+ ),
155
+ acc_rr_manifolds=acc_rr_manifolds,
156
+ acc_cr_manifolds=acc_cr_manifolds,
157
+ acc_cc_manifolds=acc_cc_manifolds,
158
+ )
159
+ return env_state
160
+
161
+
162
+ def expand_pcg_state(pcg_state: PCGState, static_env_params):
163
+ new_pcg_state = pcg_state.replace(
164
+ env_state=expand_env_state(pcg_state.env_state, static_env_params),
165
+ env_state_max=expand_env_state(pcg_state.env_state_max, static_env_params),
166
+ env_state_pcg_mask=expand_env_state(
167
+ pcg_state.env_state_pcg_mask, static_env_params, ignore_collision_matrix=True
168
+ ),
169
+ )
170
+ new_pcg_state = new_pcg_state.replace(
171
+ env_state_pcg_mask=new_pcg_state.env_state_pcg_mask.replace(
172
+ collision_matrix=jnp.zeros_like(new_pcg_state.env_state.collision_matrix, dtype=bool),
173
+ )
174
+ )
175
+ num_shapes = new_pcg_state.env_state.polygon.active.shape[0] + new_pcg_state.env_state.circle.active.shape[0]
176
+
177
+ return new_pcg_state.replace(
178
+ tied_together=jnp.zeros((num_shapes, num_shapes), dtype=bool)
179
+ .at[
180
+ : pcg_state.tied_together.shape[0],
181
+ : pcg_state.tied_together.shape[1],
182
+ ]
183
+ .set(pcg_state.tied_together)
184
+ )
185
+
186
+
187
+ def load_world_state_pickle(filename, params=None, static_env_params=None):
188
+ static_params = static_env_params or StaticEnvParams()
189
+ with open(filename, "rb") as f:
190
+ state: SimState = pickle.load(f)
191
+ state = jax.tree.map(lambda x: jnp.nan_to_num(x), state)
192
+ # Check if the mass and inertia are reasonable.
193
+ check_if_mass_and_inertia_are_correct(state, params or EnvParams(), static_params)
194
+
195
+ # Now check if the shapes are correct
196
+ return expand_env_state(state, static_params)
197
+
198
+
199
+ def stack_list_of_pytrees(list_of_pytrees):
200
+ v = jax.tree_map(lambda x: jnp.expand_dims(x, 0), list_of_pytrees[0])
201
+ for l in list_of_pytrees[1:]:
202
+ v = jax.tree_map(lambda x, y: jnp.concatenate([x, jnp.expand_dims(y, 0)], axis=0), v, l)
203
+ return v
204
+
205
+ def get_pcg_state_from_json(json_filename) -> PCGState:
206
+ env_state, _, _ = load_from_json_file(json_filename)
207
+ return env_state_to_pcg_state(env_state)
208
+
209
+ def my_load_file(filename):
210
+ data = bz2.BZ2File(filename, "rb")
211
+ data = pickle.load(data)
212
+ return data
213
+
214
+
215
+ def my_save_file(obj, filename):
216
+ with bz2.BZ2File(filename, "w") as f:
217
+ pickle.dump(obj, f)
218
+
219
+
220
+ def save_params(params: Dict, filename: Union[str, os.PathLike]) -> None:
221
+ my_save_file(params, filename)
222
+
223
+
224
+ def load_params(filename: Union[str, os.PathLike], legacy=False) -> Dict:
225
+ if legacy:
226
+ filename = filename.replace("full_model.pbz2", "model.safetensors")
227
+ filename = filename.replace(".pbz2", ".safetensors")
228
+ return unflatten_dict(load_file(filename), sep=",")
229
+ return my_load_file(filename)
230
+
231
+
232
+ def load_params_from_wandb_artifact_path(checkpoint_name, legacy=False):
233
+ api = wandb.Api()
234
+ name = api.artifact(checkpoint_name).download()
235
+ network_params = load_params(name + "/model.pbz2", legacy=legacy)
236
+ return network_params
237
+
238
+
239
+ def save_params_to_wandb(params, timesteps, config):
240
+ if config["checkpoint_human_numbers"]:
241
+ timesteps = str(round(timesteps / 1e9)) + "B"
242
+
243
+ run_name = config["run_name"] + "-" + str(config["random_hash"]) + "-" + str(timesteps)
244
+ save_dir = os.path.join(config["save_path"], run_name)
245
+ os.makedirs(save_dir, exist_ok=True)
246
+ save_params(params, f"{save_dir}/model.pbz2")
247
+
248
+ # upload this to wandb as an artifact
249
+ artifact = wandb.Artifact(f"{run_name}-checkpoint", type="checkpoint")
250
+ artifact.add_file(f"{save_dir}/model.pbz2")
251
+ artifact.save()
252
+ print(f"Parameters of model saved in {save_dir}/model.pbz2")
253
+
254
+
255
+ def load_params_wandb_artifact_path_full_model(checkpoint_name):
256
+ api = wandb.Api()
257
+ name = api.artifact(checkpoint_name).download()
258
+ all_dict = load_params(name + "/full_model.pbz2")
259
+ return all_dict["params"]
260
+
261
+
262
+ def load_train_state_from_wandb_artifact_path(train_state, checkpoint_name, load_only_params=False, legacy=False):
263
+ api = wandb.Api()
264
+ name = api.artifact(checkpoint_name).download()
265
+ all_dict = load_params(name + "/full_model.pbz2", legacy=legacy)
266
+ if legacy:
267
+ return train_state.replace(params=all_dict)
268
+ train_state = train_state.replace(params=all_dict["params"])
269
+ if not load_only_params:
270
+ train_state = train_state.replace(
271
+ # step=all_dict["step"],
272
+ opt_state=all_dict["opt_state"]
273
+ )
274
+ return train_state
275
+
276
+
277
+ def save_params_to_wandb(params, timesteps, config):
278
+ return save_dict_to_wandb(params, timesteps, config, "params")
279
+
280
+
281
+ def save_dict_to_wandb(dict, timesteps, config, name):
282
+ timesteps = str(round(timesteps / 1e9)) + "B"
283
+ run_name = config["run_name"] + "-" + str(config["random_hash"]) + "-" + str(timesteps)
284
+ save_dir = os.path.join(config["save_path"], run_name)
285
+ os.makedirs(save_dir, exist_ok=True)
286
+ save_params(dict, f"{save_dir}/{name}.pbz2")
287
+
288
+ # upload this to wandb as an artifact
289
+ artifact = wandb.Artifact(f"{run_name}-checkpoint", type="checkpoint")
290
+ artifact.add_file(f"{save_dir}/{name}.pbz2")
291
+ artifact.save()
292
+ print(f"Parameters of model saved in {save_dir}/{name}.pbz2")
293
+
294
+
295
+ def save_model_to_wandb(train_state, timesteps, config, is_final=False):
296
+ dict_to_use = {"step": train_state.step, "params": train_state.params, "opt_state": train_state.opt_state}
297
+ step = int(train_state.step)
298
+ if config["economical_saving"]:
299
+ if step in [2048, 10240, 40960, 81920] or is_final:
300
+ save_dict_to_wandb(dict_to_use, timesteps, config, "full_model")
301
+ else:
302
+ print("Not saving model because step is", step)
303
+ else:
304
+ save_dict_to_wandb(dict_to_use, timesteps, config, "full_model")
305
+
306
+
307
+ def import_env_state_from_json(json_file: dict[str, Any]) -> tuple[EnvState, StaticEnvParams, EnvParams]:
308
+ from kinetix.environment.env import create_empty_env
309
+
310
+ def normalise(k, v):
311
+ if k == "screen_dim":
312
+ return v
313
+ if type(v) == dict and "0" in v:
314
+ return jnp.array([normalise(k, v[str(i)]) for i in range(len(v))])
315
+ return v
316
+
317
+ env_state = json_file["env_state"]
318
+ env_params = json_file["env_params"]
319
+ static_env_params = json_file["static_env_params"]
320
+ env_params_target = EnvParams()
321
+ static_env_params_target = StaticEnvParams()
322
+ new_env_params = flax.serialization.from_state_dict(
323
+ env_params_target, {k: normalise(k, v) for k, v in env_params.items()}
324
+ )
325
+ norm_static = {k: normalise(k, v) for k, v in static_env_params.items()}
326
+ # norm_static["screen_dim"] = tuple(static_env_params_target.screen_dim)
327
+ norm_static["downscale"] = static_env_params_target.downscale
328
+ # print(
329
+ # static_env_params_target,
330
+ # )
331
+ new_static_env_params = flax.serialization.from_state_dict(static_env_params_target, norm_static)
332
+ new_static_env_params = new_static_env_params.replace(screen_dim=static_env_params_target.screen_dim)
333
+
334
+ env_state_target = create_empty_env(new_static_env_params)
335
+
336
+ def astype(x, all):
337
+ return jnp.astype(x, all.dtype)
338
+
339
+ def _load_rigidbody(env_state_target, i, is_poly):
340
+
341
+ to_load_from: dict[str, Any] = env_state["circle" if not is_poly else "polygon"][i]
342
+ role = to_load_from.pop("role")
343
+ density = to_load_from.pop("density")
344
+ if "highlighted" in to_load_from:
345
+ _ = to_load_from.pop("highlighted")
346
+ new_obj = flax.serialization.from_state_dict(
347
+ jax.tree.map(lambda x: x[i], env_state_target.circle if not is_poly else env_state_target.polygon),
348
+ {k: normalise(k, v) for k, v in to_load_from.items()},
349
+ )
350
+
351
+ if is_poly:
352
+ env_state_target = env_state_target.replace(
353
+ polygon_shape_roles=env_state_target.polygon_shape_roles.at[i].set(role),
354
+ polygon_densities=env_state_target.polygon_densities.at[i].set(density),
355
+ polygon=jax.tree.map(
356
+ lambda all, new: all.at[i].set(astype(new, all)), env_state_target.polygon, new_obj
357
+ ),
358
+ )
359
+ else:
360
+ env_state_target = env_state_target.replace(
361
+ circle_shape_roles=env_state_target.circle_shape_roles.at[i].set(role),
362
+ circle_densities=env_state_target.circle_densities.at[i].set(density),
363
+ circle=jax.tree.map(lambda all, new: all.at[i].set(astype(new, all)), env_state_target.circle, new_obj),
364
+ )
365
+ return env_state_target
366
+
367
+ # Now load the env state:
368
+ for i in range(new_static_env_params.num_circles):
369
+ env_state_target = _load_rigidbody(env_state_target, i, False)
370
+ for i in range(new_static_env_params.num_polygons):
371
+ env_state_target = _load_rigidbody(env_state_target, i, True)
372
+
373
+ for i in range(new_static_env_params.num_joints):
374
+ to_load_from = env_state["joint"][i]
375
+ motor_binding = to_load_from.pop("motor_binding")
376
+ new_obj = flax.serialization.from_state_dict(
377
+ jax.tree.map(lambda x: x[i], env_state_target.joint), {k: normalise(k, v) for k, v in to_load_from.items()}
378
+ )
379
+ env_state_target = env_state_target.replace(
380
+ joint=jax.tree.map(lambda all, new: all.at[i].set(astype(new, all)), env_state_target.joint, new_obj),
381
+ motor_bindings=env_state_target.motor_bindings.at[i].set(motor_binding),
382
+ )
383
+
384
+ for i in range(new_static_env_params.num_thrusters):
385
+ to_load_from = env_state["thruster"][i]
386
+ thruster_binding = to_load_from.pop("thruster_binding")
387
+ new_obj = flax.serialization.from_state_dict(
388
+ jax.tree.map(lambda x: x[i], env_state_target.thruster),
389
+ {k: normalise(k, v) for k, v in to_load_from.items()},
390
+ )
391
+
392
+ env_state_target = env_state_target.replace(
393
+ thruster=jax.tree.map(lambda all, new: all.at[i].set(astype(new, all)), env_state_target.thruster, new_obj),
394
+ thruster_bindings=env_state_target.thruster_bindings.at[i].set(thruster_binding),
395
+ )
396
+
397
+ env_state_target = env_state_target.replace(
398
+ collision_matrix=flax.serialization.from_state_dict(
399
+ env_state_target.collision_matrix, normalise("collision_matrix", env_state["collision_matrix"])
400
+ )
401
+ )
402
+
403
+ for i in range(env_state_target.acc_rr_manifolds.active.shape[0]):
404
+ a = flax.serialization.from_state_dict(
405
+ jax.tree.map(lambda x: x[i], env_state_target.acc_rr_manifolds),
406
+ {k: normalise(k, v) for k, v in env_state["acc_rr_manifolds"][i].items()},
407
+ )
408
+ b = flax.serialization.from_state_dict(
409
+ jax.tree.map(lambda x: x[i], env_state_target.acc_rr_manifolds),
410
+ {k: normalise(k, v) for k, v in env_state["acc_rr_manifolds"][i + 1].items()},
411
+ )
412
+ env_state_target = env_state_target.replace(
413
+ acc_rr_manifolds=jax.tree.map(
414
+ lambda all, new: all.at[i].set(astype(new, all)), env_state_target.acc_rr_manifolds, a
415
+ ),
416
+ )
417
+ env_state_target.replace(
418
+ acc_rr_manifolds=jax.tree.map(
419
+ lambda all, new: all.at[i + 1].set(astype(new, all)), env_state_target.acc_rr_manifolds, b
420
+ )
421
+ )
422
+ for i in range(env_state_target.acc_cr_manifolds.active.shape[0]):
423
+ a = flax.serialization.from_state_dict(
424
+ jax.tree.map(lambda x: x[i], env_state_target.acc_cr_manifolds),
425
+ {k: normalise(k, v) for k, v in env_state["acc_cr_manifolds"][i].items()},
426
+ )
427
+ env_state_target = env_state_target.replace(
428
+ acc_cr_manifolds=jax.tree.map(
429
+ lambda all, new: all.at[i].set(astype(new, all)), env_state_target.acc_cr_manifolds, a
430
+ ),
431
+ )
432
+ for i in range(env_state_target.acc_cc_manifolds.active.shape[0]):
433
+ a = flax.serialization.from_state_dict(
434
+ jax.tree.map(lambda x: x[i], env_state_target.acc_cc_manifolds),
435
+ {k: normalise(k, v) for k, v in env_state["acc_cc_manifolds"][i].items()},
436
+ )
437
+ env_state_target = env_state_target.replace(
438
+ acc_cc_manifolds=jax.tree.map(
439
+ lambda all, new: all.at[i].set(astype(new, all)), env_state_target.acc_cc_manifolds, a
440
+ ),
441
+ )
442
+
443
+ env_state_target = env_state_target.replace(
444
+ collision_matrix=calculate_collision_matrix(new_static_env_params, env_state_target.joint)
445
+ )
446
+
447
+ return (
448
+ env_state_target,
449
+ new_static_env_params,
450
+ new_env_params.replace(max_timesteps=env_params_target.max_timesteps),
451
+ )
452
+
453
+
454
+ def export_env_state_to_json(
455
+ filename: str, env_state: EnvState, static_env_params: StaticEnvParams, env_params: EnvParams
456
+ ):
457
+ json_to_save = {
458
+ "polygon": [],
459
+ "circle": [],
460
+ "joint": [],
461
+ "thruster": [],
462
+ "collision_matrix": flax.serialization.to_state_dict(env_state.collision_matrix.tolist()),
463
+ "acc_rr_manifolds": [],
464
+ "acc_cr_manifolds": [],
465
+ "acc_cc_manifolds": [],
466
+ "gravity": flax.serialization.to_state_dict(env_state.gravity.tolist()),
467
+ }
468
+
469
+ def _rigidbody_to_json(index: int, is_poly):
470
+ main_arr = env_state.polygon if is_poly else env_state.circle
471
+ c = jax.tree.map(lambda x: x[index].tolist(), main_arr)
472
+ roles = env_state.polygon_shape_roles if is_poly else env_state.circle_shape_roles
473
+ densities = env_state.polygon_densities if is_poly else env_state.circle_densities
474
+ highlighted = env_state.polygon_highlighted if is_poly else env_state.circle_highlighted
475
+
476
+ d = flax.serialization.to_state_dict(c)
477
+ d["role"] = roles[index].tolist()
478
+ d["density"] = densities[index].tolist()
479
+ d["highlighted"] = highlighted[index].tolist()
480
+ return d
481
+
482
+ def _joint_to_json(i):
483
+ joint = jax.tree.map(lambda x: x[i].tolist(), env_state.joint)
484
+ d = flax.serialization.to_state_dict(joint)
485
+ d["motor_binding"] = env_state.motor_bindings[i].tolist()
486
+ return d
487
+
488
+ def _thruster_to_json(i):
489
+ thruster = jax.tree.map(lambda x: x[i].tolist(), env_state.thruster)
490
+ d = flax.serialization.to_state_dict(thruster)
491
+ d["thruster_binding"] = env_state.thruster_bindings[i].tolist()
492
+ return d
493
+
494
+ for i in range(static_env_params.num_circles):
495
+ json_to_save["circle"].append(_rigidbody_to_json(i, False))
496
+ for i in range(static_env_params.num_polygons):
497
+ json_to_save["polygon"].append(_rigidbody_to_json(i, True))
498
+ for i in range(static_env_params.num_joints):
499
+ json_to_save["joint"].append(_joint_to_json(i))
500
+ for i in range(static_env_params.num_thrusters):
501
+ json_to_save["thruster"].append(_thruster_to_json(i))
502
+
503
+ ncc, ncr, nrr, circle_circle_pairs, circle_rect_pairs, rect_rect_pairs = get_pairwise_interaction_indices(
504
+ static_env_params
505
+ )
506
+ for i in range(nrr):
507
+ a = jax.tree.map(lambda x: x[i, 0].tolist(), env_state.acc_rr_manifolds)
508
+ b = jax.tree.map(lambda x: x[i, 1].tolist(), env_state.acc_rr_manifolds)
509
+ json_to_save["acc_rr_manifolds"].append(flax.serialization.to_state_dict(a))
510
+ json_to_save["acc_rr_manifolds"].append(flax.serialization.to_state_dict(b))
511
+ for i in range(ncr):
512
+ a = jax.tree.map(lambda x: x[i].tolist(), env_state.acc_cr_manifolds)
513
+ json_to_save["acc_cr_manifolds"].append(flax.serialization.to_state_dict(a))
514
+
515
+ for i in range(ncc):
516
+ a = jax.tree.map(lambda x: x[i].tolist(), env_state.acc_cc_manifolds)
517
+ json_to_save["acc_cc_manifolds"].append(flax.serialization.to_state_dict(a))
518
+
519
+ to_save = {
520
+ "env_state": json_to_save,
521
+ "env_params": flax.serialization.to_state_dict(
522
+ jax.tree.map(lambda x: x.tolist() if type(x) == jnp.ndarray else x, env_params)
523
+ ),
524
+ "static_env_params": flax.serialization.to_state_dict(
525
+ jax.tree.map(lambda x: x.tolist() if type(x) == jnp.ndarray else x, static_env_params)
526
+ ),
527
+ }
528
+ with open(filename, "w+") as f:
529
+ json.dump(to_save, f)
530
+
531
+ return to_save
532
+
533
+
534
+ def load_from_json_file(filename):
535
+ with open(filename, "r") as f:
536
+ return import_env_state_from_json(json.load(f))
537
+
538
+
539
+ if __name__ == "__main__":
540
+ pass
kinetix/util/timing.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from timeit import default_timer as tmr
2
+
3
+ counter = 0
4
+
5
+
6
+ def time_function(f, name):
7
+ global counter
8
+ t = "\t" * counter
9
+ # print(f"{t}Starting... {name}")
10
+ ss = tmr()
11
+ counter += 1
12
+ a = f()
13
+ counter -= 1
14
+ print(f"{t}{name} took {tmr() - ss} seconds")
15
+ return a