Spaces:
Runtime error
Runtime error
from functools import partial | |
from jax2d.engine import recalculate_mass_and_inertia, recompute_global_joint_positions, select_shape | |
from kinetix.environment.env_state import EnvState, StaticEnvParams | |
from kinetix.pcg.pcg_state import PCGState | |
import jax | |
import jax.numpy as jnp | |
def _process_tied_together_shapes(pcg_state: PCGState, sampled_state: EnvState, static_params: StaticEnvParams):
    """Move tied-together shapes by the displacement of their lead shape.

    Shapes are indexed with all polygons first, then all circles. This is the
    order in which the position arrays are concatenated (axis 0) below.

    Only the strict upper triangle of ``pcg_state.tied_together`` is read, so
    each pair is considered once as (lower index, higher index). A shape
    "leads" when no lower-indexed shape is tied to it. Every shape tied to a
    leader (at a higher index) is a follower. A follower is placed at its base
    position (from ``pcg_state.env_state``) plus the displacement its
    leader(s) received from sampling. It does not use its own sampled
    position. All other shapes keep their sampled position.

    NOTE(review): if a follower is tied to several leaders, their
    displacements are summed. A shape tied only to a follower, and not to
    that follower's leader, is not moved. Confirm that tie matrices are
    always fully connected per group.

    Args:
        pcg_state: provides base positions and the boolean ``tied_together`` matrix.
        sampled_state: env state after per-leaf random sampling.
        static_params: provides ``num_polygons``, used to split the joint
            position array back into polygons and circles.

    Returns:
        ``sampled_state`` with polygon and circle positions replaced.
    """
    n_shapes = pcg_state.tied_together.shape[0]

    # Strict upper triangle: every unordered pair once, self-ties dropped.
    pairs = jnp.triu(pcg_state.tied_together, k=1)
    # A shape already tied to an earlier one is handled through that earlier
    # row, so its own row must not act as a leader as well.
    tied_to_earlier = pairs.any(axis=0)
    leader_pairs = pairs & jnp.logical_not(tied_to_earlier)[:, None]
    is_follower = leader_pairs.any(axis=0)

    base_positions = jnp.concatenate(
        [pcg_state.env_state.polygon.position, pcg_state.env_state.circle.position], axis=0
    )
    sampled_positions = jnp.concatenate([sampled_state.polygon.position, sampled_state.circle.position], axis=0)
    # How far each shape moved during sampling.
    displacement = sampled_positions - base_positions

    # contributions[i, j] = displacement[i] if shape j follows leader i, else 0.
    contributions = leader_pairs[:, :, None] * displacement[:, None, :]
    # Row-major flattening: entry (i, j) lands on target shape j.
    target_rows = jnp.tile(jnp.arange(n_shapes), n_shapes)
    followed_positions = base_positions.at[target_rows].add(contributions.reshape(n_shapes * n_shapes, -1))

    final_positions = jnp.where(is_follower[:, None], followed_positions, sampled_positions)

    n_polygons = static_params.num_polygons
    return sampled_state.replace(
        polygon=sampled_state.polygon.replace(position=final_positions[:n_polygons]),
        circle=sampled_state.circle.replace(position=final_positions[n_polygons:]),
    )
def sample_pcg_state(rng, pcg_state: PCGState, params, static_params):
    """Draw one concrete env state from a PCG state.

    Each leaf of ``pcg_state.env_state`` whose ``env_state_pcg_mask`` entry is
    true is replaced by a value drawn uniformly between that leaf and the
    matching leaf of ``pcg_state.env_state_max``. Integer and bool leaves are
    rounded to the nearest value. Unmasked leaves are kept unchanged.

    After that:
    - tied-together shapes are moved with their lead shape;
    - global joint positions are recomputed;
    - mass and inertia are recalculated from the sampled densities.

    Args:
        rng: JAX PRNG key.
        pcg_state: base values, max values, per-leaf sampling mask and the
            shape tie matrix.
        params: not used in this function; kept for interface compatibility.
        static_params: static environment parameters (passed to the helpers).

    Returns:
        The sampled env state, after the mass/inertia recalculation.
    """

    def _pcg_fn(rng, main_val, max_val, mask):
        # Uniform draw between main_val and max_val, computed in floating point.
        low = main_val.astype(float)
        high = max_val.astype(float)
        pcg_val = jax.random.uniform(rng, shape=main_val.shape) * (high - low) + low
        if jnp.issubdtype(main_val.dtype, jnp.integer) or jnp.issubdtype(main_val.dtype, jnp.bool_):
            pcg_val = jnp.round(pcg_val)
        # Always cast back to the leaf's dtype. lax.select requires both branches
        # to share a dtype, and the float math above may not match main_val
        # (e.g. float64 under jax_enable_x64, or a float16 leaf).
        pcg_val = pcg_val.astype(main_val.dtype)
        return jax.lax.select(mask.astype(bool), pcg_val, main_val)

    def _random_split_like_tree(rng, target):
        # One key per leaf, arranged in the same pytree structure as `target`.
        # Uses jax.tree_util: the top-level jax.tree_structure / jax.tree_unflatten
        # aliases are deprecated and removed in newer JAX releases.
        tree_def = jax.tree_util.tree_structure(target)
        rngs = jax.random.split(rng, tree_def.num_leaves)
        return jax.tree_util.tree_unflatten(tree_def, rngs)

    # Same key derivation as before, so the sampled values are reproducible.
    _, leaf_rng = jax.random.split(rng)
    rng_tree = _random_split_like_tree(leaf_rng, pcg_state.env_state)
    sampled_state = jax.tree_util.tree_map(
        _pcg_fn, rng_tree, pcg_state.env_state, pcg_state.env_state_max, pcg_state.env_state_pcg_mask
    )
    sampled_state = _process_tied_together_shapes(pcg_state, sampled_state, static_params)
    sampled_state = recompute_global_joint_positions(sampled_state, static_params)
    env_state = recalculate_mass_and_inertia(
        sampled_state, static_params, sampled_state.polygon_densities, sampled_state.circle_densities
    )
    return env_state
def env_state_to_pcg_state(env_state: EnvState):
    """Wrap a concrete env state in a PCGState that randomises nothing.

    Both the base and max values are ``env_state``, so the sampling range is
    degenerate. The per-leaf sampling mask is all False. The shape tie matrix
    is an all-False square of side (#polygons + #circles), sized from the
    leading dimension of each ``active`` array.
    """
    num_shapes = sum(group.active.shape[0] for group in (env_state.polygon, env_state.circle))
    nothing_sampled = jax.tree_util.tree_map(lambda leaf: jnp.zeros_like(leaf, dtype=bool), env_state)
    no_ties = jnp.zeros((num_shapes, num_shapes), dtype=bool)
    return PCGState(
        env_state=env_state,
        env_state_max=env_state,
        env_state_pcg_mask=nothing_sampled,
        tied_together=no_ties,
    )