""" | |
Based on PureJaxRL Implementation of PPO | |
""" | |
import os | |
import sys | |
import time | |
import typing | |
from functools import partial | |
from typing import NamedTuple | |
import chex | |
import hydra | |
import jax | |
import jax.experimental | |
import jax.numpy as jnp | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import optax | |
from flax.training.train_state import TrainState | |
from kinetix.environment.ued.ued import make_reset_train_function_with_mutations, make_vmapped_filtered_level_sampler | |
from kinetix.environment.ued.ued import ( | |
make_reset_train_function_with_list_of_levels, | |
make_reset_train_function_with_mutations, | |
) | |
from kinetix.util.config import ( | |
generate_ued_params_from_config, | |
init_wandb, | |
normalise_config, | |
generate_params_from_config, | |
get_eval_level_groups, | |
) | |
from jaxued.environments.underspecified_env import EnvParams, EnvState, Observation, UnderspecifiedEnv | |
from omegaconf import OmegaConf | |
from PIL import Image | |
from flax.serialization import to_state_dict | |
import wandb | |
from kinetix.environment.env import make_kinetix_env_from_name | |
from kinetix.environment.wrappers import ( | |
AutoReplayWrapper, | |
DenseRewardWrapper, | |
LogWrapper, | |
UnderspecifiedToGymnaxWrapper, | |
) | |
from kinetix.models import make_network_from_config | |
from kinetix.models.actor_critic import ScannedRNN | |
from kinetix.render.renderer_pixels import make_render_pixels | |
from kinetix.util.learning import general_eval, get_eval_levels | |
from kinetix.util.saving import ( | |
load_train_state_from_wandb_artifact_path, | |
save_model_to_wandb, | |
) | |
sys.path.append("ued") | |
from flax.traverse_util import flatten_dict, unflatten_dict | |
from safetensors.flax import load_file, save_file | |
def save_params(params: typing.Dict, filename: typing.Union[str, os.PathLike]) -> None:
    flattened_dict = flatten_dict(params, sep=",")
    save_file(flattened_dict, filename)


def load_params(filename: typing.Union[str, os.PathLike]) -> typing.Dict:
    flattened_dict = load_file(filename)
    return unflatten_dict(flattened_dict, sep=",")
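
# Usage sketch (hypothetical path): parameters are flattened into a single-level dict
# before being written with safetensors, and restored with the same separator, e.g.
#   save_params(train_state.params, "checkpoint.safetensors")
#   params = load_params("checkpoint.safetensors")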


class Transition(NamedTuple):
    global_done: jnp.ndarray
    done: jnp.ndarray
    action: jnp.ndarray
    value: jnp.ndarray
    reward: jnp.ndarray
    log_prob: jnp.ndarray
    obs: jnp.ndarray
    info: jnp.ndarray


class RolloutBatch(NamedTuple):
    obs: jnp.ndarray
    actions: jnp.ndarray
    rewards: jnp.ndarray
    dones: jnp.ndarray
    log_probs: jnp.ndarray
    values: jnp.ndarray
    targets: jnp.ndarray
    advantages: jnp.ndarray
    # carry: jnp.ndarray
    mask: jnp.ndarray


def evaluate_rnn(
    rng: chex.PRNGKey,
    env: UnderspecifiedEnv,
    env_params: EnvParams,
    train_state: TrainState,
    init_hstate: chex.ArrayTree,
    init_obs: Observation,
    init_env_state: EnvState,
    max_episode_length: int,
    keep_states=True,
) -> tuple[chex.Array, chex.Array, chex.Array, chex.Array]:
    """Runs the RNN policy on the environment from an initial state and observation.

    Args:
        rng (chex.PRNGKey):
        env (UnderspecifiedEnv):
        env_params (EnvParams):
        train_state (TrainState):
        init_hstate (chex.ArrayTree): Shape (num_levels, )
        init_obs (Observation): Shape (num_levels, )
        init_env_state (EnvState): Shape (num_levels, )
        max_episode_length (int):

    Returns:
        Tuple[chex.Array, chex.Array, chex.Array, chex.Array]: (states, rewards, episode_lengths, infos),
        with shapes ((NUM_STEPS, NUM_LEVELS, ...), (NUM_STEPS, NUM_LEVELS), (NUM_LEVELS,), (NUM_STEPS, NUM_LEVELS)).
    """
    num_levels = jax.tree_util.tree_flatten(init_obs)[0][0].shape[0]

    def step(carry, _):
        rng, hstate, obs, state, done, mask, episode_length = carry
        rng, rng_action, rng_step = jax.random.split(rng, 3)
        x = jax.tree.map(lambda x: x[None, ...], (obs, done))
        hstate, pi, _ = train_state.apply_fn(train_state.params, hstate, x)
        action = pi.sample(seed=rng_action).squeeze(0)
        obs, next_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
            jax.random.split(rng_step, num_levels), state, action, env_params
        )
        next_mask = mask & ~done
        episode_length += mask
        if keep_states:
            return (rng, hstate, obs, next_state, done, next_mask, episode_length), (state, reward, info)
        else:
            return (rng, hstate, obs, next_state, done, next_mask, episode_length), (None, reward, info)

    (_, _, _, _, _, _, episode_lengths), (states, rewards, infos) = jax.lax.scan(
        step,
        (
            rng,
            init_hstate,
            init_obs,
            init_env_state,
            jnp.zeros(num_levels, dtype=bool),
            jnp.ones(num_levels, dtype=bool),
            jnp.zeros(num_levels, dtype=jnp.int32),
        ),
        None,
        length=max_episode_length,
    )
    return states, rewards, episode_lengths, infos


# Hydra entry point; config_path and config_name below are assumed and may need to be
# adjusted to match the repository's actual config layout.
@hydra.main(version_base=None, config_path="configs", config_name="sfl")
def main(config):
    time_start = time.time()
    config = OmegaConf.to_container(config)
    config = normalise_config(config, "SFL" if config["ued"]["sampled_envs_ratio"] > 0 else "SFL-DR")
    env_params, static_env_params = generate_params_from_config(config)
    config["env_params"] = to_state_dict(env_params)
    config["static_env_params"] = to_state_dict(static_env_params)
    run = init_wandb(config, "SFL")

    rng = jax.random.PRNGKey(config["seed"])

    config["num_envs_from_sampled"] = int(config["num_train_envs"] * config["sampled_envs_ratio"])
    config["num_envs_to_generate"] = int(config["num_train_envs"] * (1 - config["sampled_envs_ratio"]))
    assert (config["num_envs_from_sampled"] + config["num_envs_to_generate"]) == config["num_train_envs"]

    def make_env(static_env_params):
        env = make_kinetix_env_from_name(config["env_name"], static_env_params=static_env_params)
        env = AutoReplayWrapper(env)
        env = UnderspecifiedToGymnaxWrapper(env)
        env = DenseRewardWrapper(env, dense_reward_scale=config["dense_reward_scale"])
        env = LogWrapper(env)
        return env
    env = make_env(static_env_params)

    if config["train_level_mode"] == "list":
        sample_random_level = make_reset_train_function_with_list_of_levels(
            config, config["train_levels"], static_env_params, make_pcg_state=False, is_loading_train_levels=True
        )
    elif config["train_level_mode"] == "random":
        sample_random_level = make_reset_train_function_with_mutations(
            env.physics_engine, env_params, static_env_params, config, make_pcg_state=False
        )
    else:
        raise ValueError(f"Unknown train_level_mode: {config['train_level_mode']}")

    sample_random_levels = make_vmapped_filtered_level_sampler(
        sample_random_level, env_params, static_env_params, config, make_pcg_state=False, env=env
    )

    _, eval_static_env_params = generate_params_from_config(
        config["eval_env_size_true"] | {"frame_skip": config["frame_skip"]}
    )
    eval_env = make_env(eval_static_env_params)
    ued_params = generate_ued_params_from_config(config)

    def make_render_fn(static_env_params):
        render_fn_inner = make_render_pixels(env_params, static_env_params)
        render_fn = lambda x: render_fn_inner(x).transpose(1, 0, 2)[::-1]
        return render_fn

    render_fn = make_render_fn(static_env_params)
    render_fn_eval = make_render_fn(eval_static_env_params)

    NUM_EVAL_DR_LEVELS = 200
    key_to_sample_dr_eval_set = jax.random.PRNGKey(100)
    DR_EVAL_LEVELS = sample_random_levels(key_to_sample_dr_eval_set, NUM_EVAL_DR_LEVELS)

    print("Number of steps per rollout:", config["num_steps"])
    print("Config:", config)
    config["total_timesteps"] = config["num_updates"] * config["num_steps"] * config["num_train_envs"]
    config["minibatch_size"] = config["num_train_envs"] * config["num_steps"] // config["num_minibatches"]

    network = make_network_from_config(env, env_params, config)
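
    # Linearly anneal the learning rate from config["lr"] to 0 over training; `count` counts
    # optimiser steps, so it is first converted back into an update index.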
    def linear_schedule(count):
        count = count // (config["num_minibatches"] * config["update_epochs"])
        frac = 1.0 - count / config["num_updates"]
        return config["lr"] * frac

    # INIT NETWORK
    rng, _rng = jax.random.split(rng)
    train_envs = 32  # Use a small number of envs for initialisation to avoid running out of memory; the value itself does not matter.
    obs, _ = env.reset_to_level(rng, sample_random_level(rng), env_params)
    obs = jax.tree.map(
        lambda x: jnp.repeat(jnp.repeat(x[None, ...], train_envs, axis=0)[None, ...], 256, axis=0),
        obs,
    )
    init_x = (obs, jnp.zeros((256, train_envs)))
    init_hstate = ScannedRNN.initialize_carry(train_envs)
    network_params = network.init(_rng, init_hstate, init_x)

    if config["anneal_lr"]:
        tx = optax.chain(
            optax.clip_by_global_norm(config["max_grad_norm"]),
            optax.adam(learning_rate=linear_schedule, eps=1e-5),
        )
    else:
        tx = optax.chain(
            optax.clip_by_global_norm(config["max_grad_norm"]),
            optax.adam(config["lr"], eps=1e-5),
        )
    train_state = TrainState.create(
        apply_fn=network.apply,
        params=network_params,
        tx=tx,
    )

    if config["load_from_checkpoint"] is not None:
        print("LOADING from", config["load_from_checkpoint"], "with only params =", config["load_only_params"])
        train_state = load_train_state_from_wandb_artifact_path(
            train_state,
            config["load_from_checkpoint"],
            load_only_params=config["load_only_params"],
            legacy=config["load_legacy_checkpoint"],
        )

    rng, _rng = jax.random.split(rng)

    # INIT ENV
    rng, _rng, _rng2 = jax.random.split(rng, 3)
    rng_reset = jax.random.split(_rng, config["num_train_envs"])
    new_levels = sample_random_levels(_rng2, config["num_train_envs"])
    obsv, env_state = jax.vmap(env.reset_to_level, in_axes=(0, 0, None))(rng_reset, new_levels, env_params)
    start_state = env_state
    init_hstate = ScannedRNN.initialize_carry(config["num_train_envs"])
    def log_buffer_learnability(rng, train_state, instances):
        BATCH_SIZE = config["num_to_save"]
        BATCH_ACTORS = BATCH_SIZE

        def _batch_step(unused, rng):
            def _env_step(runner_state, unused):
                env_state, start_state, last_obs, last_done, hstate, rng = runner_state

                # SELECT ACTION
                rng, _rng = jax.random.split(rng)
                obs_batch = last_obs
                ac_in = (
                    jax.tree.map(lambda x: x[np.newaxis, :], obs_batch),
                    last_done[np.newaxis, :],
                )
                hstate, pi, value = network.apply(train_state.params, hstate, ac_in)
                action = pi.sample(seed=_rng).squeeze()
                log_prob = pi.log_prob(action)
                env_act = action

                # STEP ENV
                rng, _rng = jax.random.split(rng)
                rng_step = jax.random.split(_rng, config["num_to_save"])
                obsv, env_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
                    rng_step, env_state, env_act, env_params
                )
                done_batch = done
                transition = Transition(
                    done,
                    last_done,
                    action.squeeze(),
                    value.squeeze(),
                    reward,
                    log_prob.squeeze(),
                    obs_batch,
                    info,
                )
                runner_state = (env_state, start_state, obsv, done_batch, hstate, rng)
                return runner_state, transition

            def _calc_outcomes_by_agent(max_steps: int, dones, returns, info):
                idxs = jnp.arange(max_steps)

                def __ep_outcomes(start_idx, end_idx):
                    mask = (idxs > start_idx) & (idxs <= end_idx) & (end_idx != max_steps)
                    r = jnp.sum(returns * mask)
                    goal_r = info["GoalR"]  # (returns > 0) * 1.0
                    success = jnp.sum(goal_r * mask)
                    l = end_idx - start_idx
                    return r, success, l

                done_idxs = jnp.argwhere(dones, size=50, fill_value=max_steps).squeeze()
                mask_done = jnp.where(done_idxs == max_steps, 0, 1)
                ep_return, success, length = __ep_outcomes(
                    jnp.concatenate([jnp.array([-1]), done_idxs[:-1]]), done_idxs
                )
                return {
                    "ep_return": ep_return.mean(where=mask_done),
                    "num_episodes": mask_done.sum(),
                    "success_rate": success.mean(where=mask_done),
                    "ep_len": length.mean(where=mask_done),
                }

            # sample envs
            rng, _rng, _rng2 = jax.random.split(rng, 3)
            rng_reset = jax.random.split(_rng, config["num_to_save"])
            rng_levels = jax.random.split(_rng2, config["num_to_save"])
            # obsv, env_state = jax.vmap(sample_random_level, in_axes=(0,))(reset_rng)
            # new_levels = jax.vmap(sample_random_level)(rng_levels)
            obsv, env_state = jax.vmap(env.reset_to_level, in_axes=(0, 0, None))(rng_reset, instances, env_params)
            # env_instances = new_levels
            init_hstate = ScannedRNN.initialize_carry(
                BATCH_ACTORS,
            )

            runner_state = (env_state, env_state, obsv, jnp.zeros((BATCH_ACTORS), dtype=bool), init_hstate, rng)
            runner_state, traj_batch = jax.lax.scan(_env_step, runner_state, None, config["rollout_steps"])
            done_by_env = traj_batch.done.reshape((-1, config["num_to_save"]))
            reward_by_env = traj_batch.reward.reshape((-1, config["num_to_save"]))
            # info_by_actor = jax.tree.map(lambda x: x.swapaxes(2, 1).reshape((-1, BATCH_ACTORS)), traj_batch.info)
            o = _calc_outcomes_by_agent(config["rollout_steps"], traj_batch.done, traj_batch.reward, traj_batch.info)
            success_by_env = o["success_rate"].reshape((1, config["num_to_save"]))
            learnability_by_env = (success_by_env * (1 - success_by_env)).sum(axis=0)
            return None, (learnability_by_env, success_by_env.sum(axis=0))

        rngs = jax.random.split(rng, 1)
        _, (learnability, success_by_env) = jax.lax.scan(_batch_step, None, rngs, 1)
        return learnability[0], success_by_env[0]

    num_eval_levels = len(config["eval_levels"])
    all_eval_levels = get_eval_levels(config["eval_levels"], eval_env.static_env_params)

    eval_group_indices = get_eval_level_groups(config["eval_levels"])
    print("group indices", eval_group_indices)
    def get_learnability_set(rng, network_params):
        BATCH_ACTORS = config["batch_size"]

        def _batch_step(unused, rng):
            def _env_step(runner_state, unused):
                env_state, start_state, last_obs, last_done, hstate, rng = runner_state

                # SELECT ACTION
                rng, _rng = jax.random.split(rng)
                obs_batch = last_obs
                ac_in = (
                    jax.tree.map(lambda x: x[np.newaxis, :], obs_batch),
                    last_done[np.newaxis, :],
                )
                hstate, pi, value = network.apply(network_params, hstate, ac_in)
                action = pi.sample(seed=_rng).squeeze()
                log_prob = pi.log_prob(action)
                env_act = action

                # STEP ENV
                rng, _rng = jax.random.split(rng)
                rng_step = jax.random.split(_rng, config["batch_size"])
                obsv, env_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
                    rng_step, env_state, env_act, env_params
                )
                done_batch = done
                transition = Transition(
                    done,
                    last_done,
                    action.squeeze(),
                    value.squeeze(),
                    reward,
                    log_prob.squeeze(),
                    obs_batch,
                    info,
                )
                runner_state = (env_state, start_state, obsv, done_batch, hstate, rng)
                return runner_state, transition

            def _calc_outcomes_by_agent(max_steps: int, dones, returns, info):
                idxs = jnp.arange(max_steps)

                def __ep_outcomes(start_idx, end_idx):
                    mask = (idxs > start_idx) & (idxs <= end_idx) & (end_idx != max_steps)
                    r = jnp.sum(returns * mask)
                    goal_r = info["GoalR"]  # (returns > 0) * 1.0
                    success = jnp.sum(goal_r * mask)
                    l = end_idx - start_idx
                    return r, success, l
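
                # Episode boundaries are recovered from the `dones` flags: `argwhere` with a fixed
                # `size` keeps the computation shape-static under jit, and indices padded with
                # `max_steps` are masked out so that only completed episodes contribute.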
                done_idxs = jnp.argwhere(dones, size=50, fill_value=max_steps).squeeze()
                mask_done = jnp.where(done_idxs == max_steps, 0, 1)
                ep_return, success, length = __ep_outcomes(
                    jnp.concatenate([jnp.array([-1]), done_idxs[:-1]]), done_idxs
                )
                return {
                    "ep_return": ep_return.mean(where=mask_done),
                    "num_episodes": mask_done.sum(),
                    "success_rate": success.mean(where=mask_done),
                    "ep_len": length.mean(where=mask_done),
                }

            # sample envs
            rng, _rng, _rng2 = jax.random.split(rng, 3)
            rng_reset = jax.random.split(_rng, config["batch_size"])
            new_levels = sample_random_levels(_rng2, config["batch_size"])
            obsv, env_state = jax.vmap(env.reset_to_level, in_axes=(0, 0, None))(rng_reset, new_levels, env_params)
            env_instances = new_levels
            init_hstate = ScannedRNN.initialize_carry(
                BATCH_ACTORS,
            )

            runner_state = (env_state, env_state, obsv, jnp.zeros((BATCH_ACTORS), dtype=bool), init_hstate, rng)
            runner_state, traj_batch = jax.lax.scan(_env_step, runner_state, None, config["rollout_steps"])
            done_by_env = traj_batch.done.reshape((-1, config["batch_size"]))
            reward_by_env = traj_batch.reward.reshape((-1, config["batch_size"]))
            # info_by_actor = jax.tree.map(lambda x: x.swapaxes(2, 1).reshape((-1, BATCH_ACTORS)), traj_batch.info)
            o = _calc_outcomes_by_agent(config["rollout_steps"], traj_batch.done, traj_batch.reward, traj_batch.info)
            success_by_env = o["success_rate"].reshape((1, config["batch_size"]))
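            # SFL learnability score: p * (1 - p), where p is the per-level success rate. It is
            # maximised at p = 0.5 and vanishes for levels that are always or never solved.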
            learnability_by_env = (success_by_env * (1 - success_by_env)).sum(axis=0)
            return None, (learnability_by_env, success_by_env.sum(axis=0), env_instances)

        if config["sampled_envs_ratio"] == 0.0:
            print("Not doing any rollouts because sampled_envs_ratio is 0.0")
            # With no sampled envs, simply fill the buffer with random levels.
            top_instances = sample_random_levels(_rng, config["num_to_save"])
            top_success = top_learn = learnability = success_rates = jnp.zeros(config["num_to_save"])
        else:
            rngs = jax.random.split(rng, config["num_batches"])
            _, (learnability, success_rates, env_instances) = jax.lax.scan(
                _batch_step, None, rngs, config["num_batches"]
            )

            flat_env_instances = jax.tree.map(lambda x: x.reshape((-1,) + x.shape[2:]), env_instances)
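            # Rank levels by learnability; the success rate acts as a small tie-breaker so that,
            # among equally learnable levels, the ones with higher solve rates are preferred.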
            learnability = learnability.flatten() + success_rates.flatten() * 0.001
            top_1000 = jnp.argsort(learnability)[-config["num_to_save"] :]

            top_1000_instances = jax.tree.map(lambda x: x.at[top_1000].get(), flat_env_instances)
            top_learn, top_instances = learnability.at[top_1000].get(), top_1000_instances
            top_success = success_rates.at[top_1000].get()

        if config["put_eval_levels_in_buffer"]:
            top_instances = jax.tree.map(
                lambda all, new: jnp.concatenate([all[:-num_eval_levels], new], axis=0),
                top_instances,
                all_eval_levels.env_state,
            )

        log = {
            "learnability/learnability_sampled_mean": learnability.mean(),
            "learnability/learnability_sampled_median": jnp.median(learnability),
            "learnability/learnability_sampled_min": learnability.min(),
            "learnability/learnability_sampled_max": learnability.max(),
            "learnability/learnability_selected_mean": top_learn.mean(),
            "learnability/learnability_selected_median": jnp.median(top_learn),
            "learnability/learnability_selected_min": top_learn.min(),
            "learnability/learnability_selected_max": top_learn.max(),
            "learnability/solve_rate_sampled_mean": success_rates.mean(),
            "learnability/solve_rate_sampled_median": jnp.median(success_rates),
            "learnability/solve_rate_sampled_min": success_rates.min(),
            "learnability/solve_rate_sampled_max": success_rates.max(),
            "learnability/solve_rate_selected_mean": top_success.mean(),
            "learnability/solve_rate_selected_median": jnp.median(top_success),
            "learnability/solve_rate_selected_min": top_success.min(),
            "learnability/solve_rate_selected_max": top_success.max(),
        }

        return top_learn, top_instances, log
    def eval(rng: chex.PRNGKey, train_state: TrainState, keep_states=True):
        """
        Evaluates the current policy on the set of evaluation levels specified by config["eval_levels"].

        With `return_trajectories=True`, general_eval returns
        ((states, cum_rewards, ..., episode_lengths, infos), (dones, rewards)), where states are shaped
        (num_steps, num_eval_levels, ...) and cum_rewards / episode_lengths are shaped (num_eval_levels,).
        """
        num_levels = len(config["eval_levels"])
        # eval_levels = get_eval_levels(config["eval_levels"], eval_env.static_env_params)
        return general_eval(
            rng,
            eval_env,
            env_params,
            train_state,
            all_eval_levels,
            env_params.max_timesteps,
            num_levels,
            keep_states=keep_states,
            return_trajectories=True,
        )

    def eval_on_dr_levels(rng: chex.PRNGKey, train_state: TrainState, keep_states=False):
        return general_eval(
            rng,
            env,
            env_params,
            train_state,
            DR_EVAL_LEVELS,
            env_params.max_timesteps,
            NUM_EVAL_DR_LEVELS,
            keep_states=keep_states,
        )

    def eval_on_top_learnable_levels(rng: chex.PRNGKey, train_state: TrainState, levels, keep_states=True):
        N = 5
        return general_eval(
            rng,
            env,
            env_params,
            train_state,
            jax.tree.map(lambda x: x[:N], levels),
            env_params.max_timesteps,
            N,
            keep_states=keep_states,
        )
    # TRAIN LOOP
    def train_step(runner_state_instances, unused):
        # COLLECT TRAJECTORIES
        runner_state, instances = runner_state_instances
        num_env_instances = instances.polygon.position.shape[0]

        def _env_step(runner_state, unused):
            train_state, env_state, start_state, last_obs, last_done, hstate, update_steps, rng = runner_state

            # SELECT ACTION
            rng, _rng = jax.random.split(rng)
            obs_batch = last_obs
            ac_in = (
                jax.tree.map(lambda x: x[np.newaxis, :], obs_batch),
                last_done[np.newaxis, :],
            )
            hstate, pi, value = network.apply(train_state.params, hstate, ac_in)
            action = pi.sample(seed=_rng).squeeze()
            log_prob = pi.log_prob(action)
            env_act = action

            # STEP ENV
            rng, _rng = jax.random.split(rng)
            rng_step = jax.random.split(_rng, config["num_train_envs"])
            obsv, env_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
                rng_step, env_state, env_act, env_params
            )
            done_batch = done
            transition = Transition(
                done,
                last_done,
                action.squeeze(),
                value.squeeze(),
                reward,
                log_prob.squeeze(),
                obs_batch,
                info,
            )
            runner_state = (train_state, env_state, start_state, obsv, done_batch, hstate, update_steps, rng)
            return runner_state, transition

        initial_hstate = runner_state[-3]
        runner_state, traj_batch = jax.lax.scan(_env_step, runner_state, None, config["num_steps"])

        # CALCULATE ADVANTAGE
        train_state, env_state, start_state, last_obs, last_done, hstate, update_steps, rng = runner_state
        last_obs_batch = last_obs  # batchify(last_obs, env.agents, config["num_train_envs"])
        ac_in = (
            jax.tree.map(lambda x: x[np.newaxis, :], last_obs_batch),
            last_done[np.newaxis, :],
        )
        _, _, last_val = network.apply(train_state.params, hstate, ac_in)
        last_val = last_val.squeeze()

        def _calculate_gae(traj_batch, last_val):
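            # Generalised Advantage Estimation, scanned backwards over the trajectory:
            #   delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
            #   A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
            # The critic's value targets are then A_t + V(s_t).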
            def _get_advantages(gae_and_next_value, transition: Transition):
                gae, next_value = gae_and_next_value
                done, value, reward = (
                    transition.global_done,
                    transition.value,
                    transition.reward,
                )
                delta = reward + config["gamma"] * next_value * (1 - done) - value
                gae = delta + config["gamma"] * config["gae_lambda"] * (1 - done) * gae
                return (gae, value), gae

            _, advantages = jax.lax.scan(
                _get_advantages,
                (jnp.zeros_like(last_val), last_val),
                traj_batch,
                reverse=True,
                unroll=16,
            )
            return advantages, advantages + traj_batch.value

        advantages, targets = _calculate_gae(traj_batch, last_val)

        # UPDATE NETWORK
        def _update_epoch(update_state, unused):
            def _update_minbatch(train_state, batch_info):
                init_hstate, traj_batch, advantages, targets = batch_info

                def _loss_fn_masked(params, init_hstate, traj_batch, gae, targets):
                    # RERUN NETWORK
                    _, pi, value = network.apply(
                        params,
                        jax.tree.map(lambda x: x.transpose(), init_hstate),
                        (traj_batch.obs, traj_batch.done),
                    )
                    log_prob = pi.log_prob(traj_batch.action)

                    # CALCULATE VALUE LOSS
                    value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip(
                        -config["clip_eps"], config["clip_eps"]
                    )
                    value_losses = jnp.square(value - targets)
                    value_losses_clipped = jnp.square(value_pred_clipped - targets)
                    value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped)
                    critic_loss = config["vf_coef"] * value_loss.mean()

                    # CALCULATE ACTOR LOSS
                    logratio = log_prob - traj_batch.log_prob
                    ratio = jnp.exp(logratio)
                    # if env.do_sep_reward: gae = gae.sum(axis=-1)
                    gae = (gae - gae.mean()) / (gae.std() + 1e-8)
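                    # PPO clipped surrogate objective: take the elementwise minimum of the
                    # unclipped and clipped importance-weighted advantages,
                    #   min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A),
                    # and maximise it (hence the negation below).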
                    loss_actor1 = ratio * gae
                    loss_actor2 = (
                        jnp.clip(
                            ratio,
                            1.0 - config["clip_eps"],
                            1.0 + config["clip_eps"],
                        )
                        * gae
                    )
                    loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
                    loss_actor = loss_actor.mean()
                    entropy = pi.entropy().mean()

                    approx_kl = jax.lax.stop_gradient(((ratio - 1) - logratio).mean())
                    clipfrac = jax.lax.stop_gradient((jnp.abs(ratio - 1) > config["clip_eps"]).mean())

                    total_loss = loss_actor + critic_loss - config["ent_coef"] * entropy
                    return total_loss, (value_loss, loss_actor, entropy, ratio, approx_kl, clipfrac)

                grad_fn = jax.value_and_grad(_loss_fn_masked, has_aux=True)
                total_loss, grads = grad_fn(train_state.params, init_hstate, traj_batch, advantages, targets)
                train_state = train_state.apply_gradients(grads=grads)
                return train_state, total_loss

            (
                train_state,
                init_hstate,
                traj_batch,
                advantages,
                targets,
                rng,
            ) = update_state
            rng, _rng = jax.random.split(rng)

            init_hstate = jax.tree.map(lambda x: jnp.reshape(x, (256, config["num_train_envs"])), init_hstate)
            batch = (
                init_hstate,
                traj_batch,
                advantages.squeeze(),
                targets.squeeze(),
            )
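            # Shuffle along the environment axis (axis 1, since the data is time-major) and split the
            # batch into `num_minibatches` chunks, keeping whole-environment sequences intact for the RNN.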
            permutation = jax.random.permutation(_rng, config["num_train_envs"])
            shuffled_batch = jax.tree_util.tree_map(lambda x: jnp.take(x, permutation, axis=1), batch)
            minibatches = jax.tree_util.tree_map(
                lambda x: jnp.swapaxes(
                    jnp.reshape(
                        x,
                        [x.shape[0], config["num_minibatches"], -1] + list(x.shape[2:]),
                    ),
                    1,
                    0,
                ),
                shuffled_batch,
            )
            train_state, total_loss = jax.lax.scan(_update_minbatch, train_state, minibatches)
            # total_loss = jax.tree.map(lambda x: x.mean(), total_loss)
            update_state = (
                train_state,
                init_hstate,
                traj_batch,
                advantages,
                targets,
                rng,
            )
            return update_state, total_loss

        # init_hstate = initial_hstate[None, :].squeeze().transpose()
        init_hstate = jax.tree.map(lambda x: x[None, :].squeeze().transpose(), initial_hstate)
        update_state = (
            train_state,
            init_hstate,
            traj_batch,
            advantages,
            targets,
            rng,
        )
        update_state, loss_info = jax.lax.scan(_update_epoch, update_state, None, config["update_epochs"])
        train_state = update_state[0]
        metric = jax.tree.map(
            lambda x: x.reshape((config["num_steps"], config["num_train_envs"])),  # , env.num_agents
            traj_batch.info,
        )
        rng = update_state[-1]

        def callback(metric):
            dones = metric["dones"]
            wandb.log(
                {
                    "episode_return": (metric["returned_episode_returns"] * dones).sum() / jnp.maximum(1, dones.sum()),
                    "episode_solved": (metric["returned_episode_solved"] * dones).sum() / jnp.maximum(1, dones.sum()),
                    "episode_length": (metric["returned_episode_lengths"] * dones).sum() / jnp.maximum(1, dones.sum()),
                    "timing/num_env_steps": int(
                        int(metric["update_steps"]) * int(config["num_train_envs"]) * int(config["num_steps"])
                    ),
                    "timing/num_updates": metric["update_steps"],
                    **metric["loss_info"],
                }
            )

        loss_info = jax.tree.map(lambda x: x.mean(), loss_info)
        metric["loss_info"] = {
            "loss/total_loss": loss_info[0],
            "loss/value_loss": loss_info[1][0],
            "loss/policy_loss": loss_info[1][1],
            "loss/entropy_loss": loss_info[1][2],
        }
        metric["dones"] = traj_batch.done
        metric["update_steps"] = update_steps
        jax.experimental.io_callback(callback, None, metric)

        # SAMPLE NEW ENVS
        rng, _rng, _rng2 = jax.random.split(rng, 3)
        rng_reset = jax.random.split(_rng, config["num_envs_to_generate"])
        new_levels = sample_random_levels(_rng2, config["num_envs_to_generate"])
        obsv_gen, env_state_gen = jax.vmap(env.reset_to_level, in_axes=(0, 0, None))(rng_reset, new_levels, env_params)

        rng, _rng, _rng2 = jax.random.split(rng, 3)
        sampled_env_instances_idxs = jax.random.randint(_rng, (config["num_envs_from_sampled"],), 0, num_env_instances)
        sampled_env_instances = jax.tree.map(lambda x: x.at[sampled_env_instances_idxs].get(), instances)
        myrng = jax.random.split(_rng2, config["num_envs_from_sampled"])
        obsv_sampled, env_state_sampled = jax.vmap(env.reset_to_level, in_axes=(0, 0, None))(
            myrng, sampled_env_instances, env_params
        )

        obsv = jax.tree.map(lambda x, y: jnp.concatenate([x, y], axis=0), obsv_gen, obsv_sampled)
        env_state = jax.tree.map(lambda x, y: jnp.concatenate([x, y], axis=0), env_state_gen, env_state_sampled)
        start_state = env_state
        hstate = ScannedRNN.initialize_carry(config["num_train_envs"])

        update_steps = update_steps + 1
        runner_state = (
            train_state,
            env_state,
            start_state,
            obsv,
            jnp.zeros((config["num_train_envs"]), dtype=bool),
            hstate,
            update_steps,
            rng,
        )
        return (runner_state, instances), metric
    def log_buffer(learnability, levels, epoch):
        num_samples = levels.polygon.position.shape[0]
        states = levels
        rows = 2
        fig, axes = plt.subplots(rows, int(num_samples / rows), figsize=(20, 10))
        axes = axes.flatten()
        all_imgs = jax.vmap(render_fn)(states)
        for i, ax in enumerate(axes):
            # ax.imshow(train_state.plr_buffer.get_sample(i))
            score = learnability[i]
            ax.imshow(all_imgs[i] / 255.0)
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_title(f"learnability: {score:.3f}")
            ax.set_aspect("equal", "box")

        plt.tight_layout()
        fig.canvas.draw()
        im = Image.frombytes("RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
        plt.close()
        return {"maps": wandb.Image(im)}

    def train_and_eval_step(runner_state, eval_rng):
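        """One outer iteration: refresh the learnability level buffer, train for config["eval_freq"]
        PPO updates on a mix of freshly generated and buffered levels, then evaluate and log metrics."""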
        learnability_rng, eval_singleton_rng, eval_sampled_rng, _rng = jax.random.split(eval_rng, 4)

        # TRAIN
        learnability_scores, instances, test_metrics = get_learnability_set(learnability_rng, runner_state[0].params)
        if config["log_learnability_before_after"]:
            learn_scores_before, success_score_before = log_buffer_learnability(
                learnability_rng, runner_state[0], instances
            )
        print("instance size", sum(x.size for x in jax.tree_util.tree_leaves(instances)))
        runner_state_instances = (runner_state, instances)
        runner_state_instances, metrics = jax.lax.scan(train_step, runner_state_instances, None, config["eval_freq"])

        if config["log_learnability_before_after"]:
            learn_scores_after, success_score_after = log_buffer_learnability(
                learnability_rng, runner_state_instances[0][0], instances
            )

        # EVAL
        rng, rng_eval = jax.random.split(eval_singleton_rng)
        (states, cum_rewards, _, episode_lengths, eval_infos), (eval_dones, eval_rewards) = jax.vmap(eval, (0, None))(
            jax.random.split(rng_eval, config["eval_num_attempts"]), runner_state_instances[0][0]
        )
        all_eval_eplens = episode_lengths

        # Collect Metrics
        eval_returns = cum_rewards.mean(axis=0)  # (num_eval_levels,)
        eval_solves = (eval_infos["returned_episode_solved"] * eval_dones).sum(axis=1) / jnp.maximum(
            1, eval_dones.sum(axis=1)
        )
        eval_solves = eval_solves.mean(axis=0)

        # just grab the first attempt
        states, episode_lengths = jax.tree_util.tree_map(
            lambda x: x[0], (states, episode_lengths)
        )  # (num_steps, num_eval_levels, ...), (num_eval_levels,)
        states = jax.tree_util.tree_map(lambda x: x[:, :], states)
        episode_lengths = episode_lengths[:]
        images = jax.vmap(jax.vmap(render_fn_eval))(
            states.env_state.env_state.env_state
        )  # (num_steps, num_eval_levels, ...)
        frames = images.transpose(
            0, 1, 4, 2, 3
        )  # WandB expects the colour channel to come before the spatial dimensions for animations.
        test_metrics["update_count"] = runner_state[-2]
        test_metrics["eval_returns"] = eval_returns
        test_metrics["eval_ep_lengths"] = episode_lengths
        test_metrics["eval_animation"] = (frames, episode_lengths)

        # Eval on sampled (DR) levels
        dr_states, dr_cum_rewards, _, dr_episode_lengths, dr_infos = jax.vmap(eval_on_dr_levels, (0, None))(
            jax.random.split(rng_eval, config["eval_num_attempts"]), runner_state_instances[0][0]
        )
        eval_dr_returns = dr_cum_rewards.mean(axis=0).mean()
        eval_dr_eplen = dr_episode_lengths.mean(axis=0).mean()

        test_metrics["eval/mean_eval_return_sampled"] = eval_dr_returns
        my_eval_dones = dr_infos["returned_episode"]
        eval_dr_solves = (dr_infos["returned_episode_solved"] * my_eval_dones).sum(axis=1) / jnp.maximum(
            1, my_eval_dones.sum(axis=1)
        )
        test_metrics["eval/mean_eval_solve_rate_sampled"] = eval_dr_solves
        test_metrics["eval/mean_eval_eplen_sampled"] = eval_dr_eplen

        # Collect Metrics
        eval_returns = cum_rewards.mean(axis=0)  # (num_eval_levels,)

        log_dict = {}
        log_dict["to_remove"] = {
            "eval_return": eval_returns,
            "eval_solve_rate": eval_solves,
            "eval_eplen": all_eval_eplens,
        }
        for i, name in enumerate(config["eval_levels"]):
            log_dict[f"eval_avg_return/{name}"] = eval_returns[i]
            log_dict[f"eval_avg_solve_rate/{name}"] = eval_solves[i]

        log_dict.update({"eval/mean_eval_return": eval_returns.mean()})
        log_dict.update({"eval/mean_eval_solve_rate": eval_solves.mean()})
        log_dict.update({"eval/mean_eval_eplen": all_eval_eplens.mean()})

        test_metrics.update(log_dict)

        runner_state, _ = runner_state_instances
        test_metrics["update_count"] = runner_state[-2]

        top_instances = jax.tree.map(lambda x: x.at[-5:].get(), instances)

        # Eval on top learnable levels
        tl_states, tl_cum_rewards, _, tl_episode_lengths, tl_infos = jax.vmap(
            eval_on_top_learnable_levels, (0, None, None)
        )(jax.random.split(rng_eval, config["eval_num_attempts"]), runner_state_instances[0][0], top_instances)

        # just grab the first attempt
        states, episode_lengths = jax.tree_util.tree_map(
            lambda x: x[0], (tl_states, tl_episode_lengths)
        )  # (num_steps, num_eval_levels, ...), (num_eval_levels,)
        states = jax.tree_util.tree_map(lambda x: x[:, :], states)
        episode_lengths = episode_lengths[:]
        images = jax.vmap(jax.vmap(render_fn))(
            states.env_state.env_state.env_state
        )  # (num_steps, num_eval_levels, ...)
        frames = images.transpose(
            0, 1, 4, 2, 3
        )  # WandB expects the colour channel to come before the spatial dimensions for animations.
        test_metrics["top_learnable_animation"] = (frames, episode_lengths, tl_cum_rewards)

        if config["log_learnability_before_after"]:

            def single(x, name):
                return {
                    f"{name}_mean": x.mean(),
                    f"{name}_std": x.std(),
                    f"{name}_min": x.min(),
                    f"{name}_max": x.max(),
                    f"{name}_median": jnp.median(x),
                }

            test_metrics["learnability_log_v2/"] = {
                **single(learn_scores_before, "learnability_before"),
                **single(learn_scores_after, "learnability_after"),
                **single(success_score_before, "success_score_before"),
                **single(success_score_after, "success_score_after"),
            }

        return runner_state, (learnability_scores.at[-20:].get(), top_instances), test_metrics
    rng, _rng = jax.random.split(rng)
    runner_state = (
        train_state,
        env_state,
        start_state,
        obsv,
        jnp.zeros((config["num_train_envs"]), dtype=bool),
        init_hstate,
        0,
        _rng,
    )
    def log_eval(stats):
        to_remove = stats["to_remove"]
        del stats["to_remove"]

        def _aggregate_per_size(values, name):
            to_return = {}
            for group_name, indices in eval_group_indices.items():
                to_return[f"{name}_{group_name}"] = values[indices].mean()
            return to_return

        env_steps = stats["update_count"] * config["num_train_envs"] * config["num_steps"]
        env_steps_delta = config["eval_freq"] * config["num_train_envs"] * config["num_steps"]
        time_now = time.time()
        log_dict = {
            "timing/num_updates": stats["update_count"],
            "timing/num_env_steps": env_steps,
            "timing/sps": env_steps_delta / stats["time_delta"],
            "timing/sps_agg": env_steps / (time_now - time_start),
        }
        log_dict.update(_aggregate_per_size(to_remove["eval_return"], "eval_aggregate/return"))
        log_dict.update(_aggregate_per_size(to_remove["eval_solve_rate"], "eval_aggregate/solve_rate"))

        for i in range(len(config["eval_levels"])):
            frames, episode_length = stats["eval_animation"][0][:, i], stats["eval_animation"][1][i]
            frames = np.array(frames[:episode_length])
            log_dict.update(
                {
                    f"media/eval_video_{config['eval_levels'][i]}": wandb.Video(
                        frames.astype(np.uint8), fps=15, caption=f"(len {episode_length})"
                    )
                }
            )

        for j in range(5):
            frames, episode_length, cum_rewards = (
                stats["top_learnable_animation"][0][:, j],
                stats["top_learnable_animation"][1][j],
                stats["top_learnable_animation"][2][:, j],
            )  # num attempts
            rr = "|".join([f"{r:<.2f}" for r in cum_rewards])
            frames = np.array(frames[:episode_length])
            log_dict.update(
                {
                    f"media/tl_animation_{j}": wandb.Video(
                        frames.astype(np.uint8), fps=15, caption=f"(len {episode_length})\n{rr}"
                    )
                }
            )

        stats.update(log_dict)
        wandb.log(stats, step=stats["update_count"])
    checkpoint_steps = config["checkpoint_save_freq"]
    assert config["num_updates"] % config["eval_freq"] == 0, "num_updates must be divisible by eval_freq"
    for eval_step in range(int(config["num_updates"] // config["eval_freq"])):
        start_time = time.time()
        rng, eval_rng = jax.random.split(rng)
        runner_state, instances, metrics = train_and_eval_step(runner_state, eval_rng)
        curr_time = time.time()
        metrics.update(log_buffer(*instances, metrics["update_count"]))
        metrics["time_delta"] = curr_time - start_time
        metrics["steps_per_section"] = (
            config["eval_freq"] * config["num_steps"] * config["num_train_envs"]
        ) / metrics["time_delta"]
        log_eval(metrics)
        if ((eval_step + 1) * config["eval_freq"]) % checkpoint_steps == 0:
            if config["save_path"] is not None:
                steps = int(metrics["update_count"]) * int(config["num_train_envs"]) * int(config["num_steps"])
                # save_params_to_wandb(runner_state[0].params, steps, config)
                save_model_to_wandb(runner_state[0], steps, config)

    if config["save_path"] is not None:
        # save_params_to_wandb(runner_state[0].params, config["total_timesteps"], config)
        save_model_to_wandb(runner_state[0], config["total_timesteps"], config)


if __name__ == "__main__":
    # with jax.disable_jit():
    #     main()
    main()