Spaces:
Runtime error
Runtime error
from dataclasses import field | |
import jax.numpy as jnp | |
from flax import struct | |
from jax2d.sim_state import SimState, SimParams, StaticSimParams | |
class EnvState(SimState): | |
thruster_bindings: jnp.ndarray | |
motor_bindings: jnp.ndarray | |
motor_auto: jnp.ndarray | |
polygon_shape_roles: jnp.ndarray | |
circle_shape_roles: jnp.ndarray | |
polygon_highlighted: jnp.ndarray | |
circle_highlighted: jnp.ndarray | |
polygon_densities: jnp.ndarray | |
circle_densities: jnp.ndarray | |
timestep: int = 0 | |
class EnvParams(SimParams): | |
max_timesteps: int = 256 | |
pixels_per_unit: int = 100 | |
dense_reward_scale: float = 0.1 | |
num_shape_roles: int = 4 | |
class StaticEnvParams(StaticSimParams): | |
screen_dim: tuple[int, int] = (500, 500) | |
downscale: int = 4 | |
frame_skip: int = 1 | |
max_shape_size: int = 2 | |
num_motor_bindings: int = 4 | |
num_thruster_bindings: int = 2 | |