# Source: kinet-test / Kinetix / kinetix/render/renderer_symbolic_entity.py
# Uploaded by tree3po ("Upload 190 files"), commit e0f25ed (verified).
from cmath import rect
from functools import partial
import jax
import jax.numpy as jnp
from flax import struct
from jax2d.engine import get_pairwise_interaction_indices
from kinetix.environment.env_state import EnvState
from kinetix.render.renderer_symbolic_common import (
make_circle_features,
make_joint_features,
make_polygon_features,
make_thruster_features,
make_unified_shape_features,
)
@struct.dataclass
class EntityObservation:
    """Entity-based symbolic observation built by ``make_render_entities``.

    Holds per-entity feature arrays (circles, polygons, joints, thrusters), their
    validity masks, a stacked pairwise attention mask over the shapes, and the
    shape indices that joints/thrusters refer to.
    """

    # Circle node features from make_circle_features, with one extra column
    # (state.gravity[1] / 10) appended by render_entities.
    circles: jnp.ndarray
    # Polygon node features from make_polygon_features, with the same gravity column appended.
    polygons: jnp.ndarray
    # Joint features as returned by make_joint_features.
    joints: jnp.ndarray
    # Thruster features as returned by make_thruster_features.
    thrusters: jnp.ndarray
    # Validity masks returned by the corresponding feature builders.
    circle_mask: jnp.ndarray
    polygon_mask: jnp.ndarray
    joint_mask: jnp.ndarray
    thruster_mask: jnp.ndarray
    # Four stacked (N, N) masks over the N shapes (polygons first, then circles):
    # [0] fully connected, [1] logical_not(state.collision_matrix),
    # [2] joint one-hop links, [3] pairs with an active collision manifold.
    # Rows/columns of invalid shapes are zeroed in every channel.
    attention_mask: jnp.ndarray
    # NOTE: no separate collision mask is emitted; render_entities folds collision
    # information into the attention_mask channels instead.
    # Shape indices used by joints, as returned by make_joint_features; columns 0 and 1
    # index rows/columns of attention_mask channel 2 ((0, 0) marks an invalid joint).
    joint_indexes: jnp.ndarray
    # Shape indices used by thrusters, as returned by make_thruster_features.
    thruster_indexes: jnp.ndarray
def make_render_entities(params, static_params):
    """Build a renderer that turns an ``EnvState`` into an ``EntityObservation``.

    Pairwise interaction indices are computed once here from ``static_params`` and
    captured by the returned closure. Shapes are laid out polygons first, then
    circles, so circle indices are shifted by ``static_params.num_polygons``.

    Args:
        params: Environment params, forwarded to the per-entity feature builders.
        static_params: Static environment params; ``num_polygons`` and
            ``num_circles`` are read here and it is forwarded to the feature
            builders and ``get_pairwise_interaction_indices``.

    Returns:
        A function ``render_entities(state) -> EntityObservation``.
    """
    _, _, _, circle_circle_pairs, circle_rect_pairs, rect_rect_pairs = get_pairwise_interaction_indices(static_params)
    # Circles come after polygons in the flat shape ordering used below, so shift the
    # circle index (column 0 of circle/rect pairs, both columns of circle/circle pairs).
    circle_rect_pairs = circle_rect_pairs.at[:, 0].add(static_params.num_polygons)
    circle_circle_pairs = circle_circle_pairs + static_params.num_polygons

    def render_entities(state: EnvState):
        """Render ``state`` into an ``EntityObservation``.

        NaNs in every leaf of ``state`` are replaced via ``jnp.nan_to_num`` before
        any features are computed.
        """
        state = jax.tree_util.tree_map(jnp.nan_to_num, state)

        joint_features, joint_indexes, joint_mask = make_joint_features(state, params, static_params)
        thruster_features, thruster_indexes, thruster_mask = make_thruster_features(state, params, static_params)
        poly_nodes, poly_mask = make_polygon_features(state, params, static_params)
        circle_nodes, circle_mask = make_circle_features(state, params, static_params)

        def _add_grav(nodes):
            # Append gravity[1] / 10 as one extra feature column on every shape's embedding.
            return jnp.concatenate(
                [nodes, jnp.zeros((nodes.shape[0], 1)) + state.gravity[1] / 10], axis=-1
            )

        poly_nodes = _add_grav(poly_nodes)
        circle_nodes = _add_grav(circle_nodes)

        # Validity of every shape, polygons first then circles: shape (NPoly + NCircle,).
        mask_flat_shapes = jnp.concatenate([poly_mask, circle_mask], axis=0)
        num_shapes = static_params.num_polygons + static_params.num_circles

        def make_n_squared_mask(val):
            """Stack four (N, N) pairwise masks over the N shapes flagged by ``val``.

            Channels, in order:
              0. fully connected shapes (plus self-attention on the diagonal),
              1. ``logical_not(state.collision_matrix)``,
              2. joint links ``joint_indexes[:, 0] -> joint_indexes[:, 1]``,
              3. pairs whose accumulated collision manifold is active.
            Every row and column whose ``val`` entry is False is zeroed in all channels.
            """
            N = val.shape[0]
            # Pair (i, j) survives only when both shapes i and j are valid.
            pair_valid = val[:, None] & val[None, :]

            # Make the shapes fully connected; the identity also has things attend to themselves.
            full_mask = jnp.eye(N, N, dtype=bool).at[:num_shapes, :num_shapes].set(
                jnp.ones((num_shapes, num_shapes), dtype=bool)
            )

            # Joint links are written in one direction only (column 0 -> column 1).
            one_hop_connected = jnp.zeros((N, N), dtype=bool)
            one_hop_connected = one_hop_connected.at[joint_indexes[:, 0], joint_indexes[:, 1]].set(True)
            one_hop_connected = one_hop_connected.at[0, 0].set(False)  # invalid joints have indices of (0, 0)

            multi_hop_connected = jnp.logical_not(state.collision_matrix)

            full_mask = full_mask & pair_valid
            multi_hop_connected = multi_hop_connected & pair_valid
            one_hop_connected = one_hop_connected & pair_valid

            # Start from zeros with the shape/dtype that (collision_matrix & pair_valid) would
            # have; the masked collision matrix itself is not needed.
            collision_manifold_mask = jnp.zeros(
                (N, N), dtype=jnp.result_type(state.collision_matrix, pair_valid)
            )

            def _set(mask, pairs, active):
                # Scatter each pair's flag into mask[pairs[:, 0], pairs[:, 1]] (one direction only).
                return mask.at[pairs[:, 0], pairs[:, 1]].set(active)

            collision_manifold_mask = _set(
                collision_manifold_mask,
                rect_rect_pairs,
                # acc_rr_manifolds.active has two entries per pair on its last axis;
                # the pair counts as active if either one is set.
                jnp.logical_or(state.acc_rr_manifolds.active[..., 0], state.acc_rr_manifolds.active[..., 1]),
            )
            collision_manifold_mask = _set(collision_manifold_mask, circle_rect_pairs, state.acc_cr_manifolds.active)
            collision_manifold_mask = _set(collision_manifold_mask, circle_circle_pairs, state.acc_cc_manifolds.active)
            collision_manifold_mask = collision_manifold_mask & pair_valid

            return jnp.concatenate(
                [full_mask[None], multi_hop_connected[None], one_hop_connected[None], collision_manifold_mask[None]],
                axis=0,
            )

        mask_n_squared = make_n_squared_mask(mask_flat_shapes)

        return EntityObservation(
            circles=circle_nodes,
            polygons=poly_nodes,
            joints=joint_features,
            thrusters=thruster_features,
            circle_mask=circle_mask,
            polygon_mask=poly_mask,
            joint_mask=joint_mask,
            thruster_mask=thruster_mask,
            attention_mask=mask_n_squared,
            joint_indexes=joint_indexes,
            thruster_indexes=thruster_indexes,
        )

    return render_entities