"""
Based on PureJaxRL Implementation of PPO
"""
import os
import sys
import time
import typing
from functools import partial
from typing import NamedTuple
import chex
import hydra
import jax
import jax.experimental
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
from flax.training.train_state import TrainState
from kinetix.environment.ued.ued import (
    make_reset_train_function_with_list_of_levels,
    make_reset_train_function_with_mutations,
    make_vmapped_filtered_level_sampler,
)
from kinetix.util.config import (
generate_ued_params_from_config,
init_wandb,
normalise_config,
generate_params_from_config,
get_eval_level_groups,
)
from jaxued.environments.underspecified_env import EnvParams, EnvState, Observation, UnderspecifiedEnv
from omegaconf import OmegaConf
from PIL import Image
from flax.serialization import to_state_dict
import wandb
from kinetix.environment.env import make_kinetix_env_from_name
from kinetix.environment.wrappers import (
AutoReplayWrapper,
DenseRewardWrapper,
LogWrapper,
UnderspecifiedToGymnaxWrapper,
)
from kinetix.models import make_network_from_config
from kinetix.models.actor_critic import ScannedRNN
from kinetix.render.renderer_pixels import make_render_pixels
from kinetix.util.learning import general_eval, get_eval_levels
from kinetix.util.saving import (
load_train_state_from_wandb_artifact_path,
save_model_to_wandb,
)
sys.path.append("ued")
from flax.traverse_util import flatten_dict, unflatten_dict
from safetensors.flax import load_file, save_file
def save_params(params: typing.Dict, filename: typing.Union[str, os.PathLike]) -> None:
flattened_dict = flatten_dict(params, sep=",")
save_file(flattened_dict, filename)
def load_params(filename: typing.Union[str, os.PathLike]) -> typing.Dict:
flattened_dict = load_file(filename)
return unflatten_dict(flattened_dict, sep=",")
class Transition(NamedTuple):
global_done: jnp.ndarray
done: jnp.ndarray
action: jnp.ndarray
value: jnp.ndarray
reward: jnp.ndarray
log_prob: jnp.ndarray
obs: jnp.ndarray
info: jnp.ndarray
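# Transition records one environment step per parallel env (note: `global_done` is the current
# step's termination flag and `done` is the previous step's); RolloutBatch, defined next, groups
# the post-processed quantities (advantages, value targets, mask) in the same layout.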
class RolloutBatch(NamedTuple):
obs: jnp.ndarray
actions: jnp.ndarray
rewards: jnp.ndarray
dones: jnp.ndarray
log_probs: jnp.ndarray
values: jnp.ndarray
targets: jnp.ndarray
advantages: jnp.ndarray
# carry: jnp.ndarray
mask: jnp.ndarray
def evaluate_rnn(
rng: chex.PRNGKey,
env: UnderspecifiedEnv,
env_params: EnvParams,
train_state: TrainState,
init_hstate: chex.ArrayTree,
init_obs: Observation,
init_env_state: EnvState,
max_episode_length: int,
keep_states=True,
) -> tuple[chex.Array, chex.Array, chex.Array, chex.Array]:
    """Runs the RNN policy on the given levels and returns (states, rewards, episode_lengths, infos).

    Args:
        rng (chex.PRNGKey):
        env (UnderspecifiedEnv):
        env_params (EnvParams):
        train_state (TrainState):
        init_hstate (chex.ArrayTree): Shape (num_levels, )
        init_obs (Observation): Shape (num_levels, )
        init_env_state (EnvState): Shape (num_levels, )
        max_episode_length (int):

    Returns:
        Tuple[chex.Array, chex.Array, chex.Array, chex.Array]: (states, rewards, episode_lengths, infos)
            with shapes ((NUM_STEPS, NUM_LEVELS), (NUM_STEPS, NUM_LEVELS), (NUM_LEVELS,), (NUM_STEPS, NUM_LEVELS))
    """
num_levels = jax.tree_util.tree_flatten(init_obs)[0][0].shape[0]
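    # The scan keeps a per-level `mask` that stays True until the first `done`, so
    # `episode_length` counts steps only for the first episode on each level.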
def step(carry, _):
rng, hstate, obs, state, done, mask, episode_length = carry
rng, rng_action, rng_step = jax.random.split(rng, 3)
x = jax.tree.map(lambda x: x[None, ...], (obs, done))
hstate, pi, _ = train_state.apply_fn(train_state.params, hstate, x)
action = pi.sample(seed=rng_action).squeeze(0)
obs, next_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
jax.random.split(rng_step, num_levels), state, action, env_params
)
next_mask = mask & ~done
episode_length += mask
if keep_states:
return (rng, hstate, obs, next_state, done, next_mask, episode_length), (state, reward, info)
else:
return (rng, hstate, obs, next_state, done, next_mask, episode_length), (None, reward, info)
(_, _, _, _, _, _, episode_lengths), (states, rewards, infos) = jax.lax.scan(
step,
(
rng,
init_hstate,
init_obs,
init_env_state,
jnp.zeros(num_levels, dtype=bool),
jnp.ones(num_levels, dtype=bool),
jnp.zeros(num_levels, dtype=jnp.int32),
),
None,
length=max_episode_length,
)
return states, rewards, episode_lengths, infos
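# Illustrative call (this script uses `general_eval` below instead; the argument names follow
# the signature above and are placeholders):
#   states, rewards, episode_lengths, infos = evaluate_rnn(
#       rng, eval_env, env_params, train_state,
#       ScannedRNN.initialize_carry(num_levels), init_obs, init_env_state,
#       max_episode_length=env_params.max_timesteps,
#   )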
@hydra.main(version_base=None, config_path="../configs", config_name="sfl")
def main(config):
time_start = time.time()
config = OmegaConf.to_container(config)
config = normalise_config(config, "SFL" if config["ued"]["sampled_envs_ratio"] > 0 else "SFL-DR")
env_params, static_env_params = generate_params_from_config(config)
config["env_params"] = to_state_dict(env_params)
config["static_env_params"] = to_state_dict(static_env_params)
run = init_wandb(config, "SFL")
rng = jax.random.PRNGKey(config["seed"])
config["num_envs_from_sampled"] = int(config["num_train_envs"] * config["sampled_envs_ratio"])
config["num_envs_to_generate"] = int(config["num_train_envs"] * (1 - config["sampled_envs_ratio"]))
assert (config["num_envs_from_sampled"] + config["num_envs_to_generate"]) == config["num_train_envs"]
def make_env(static_env_params):
env = make_kinetix_env_from_name(config["env_name"], static_env_params=static_env_params)
env = AutoReplayWrapper(env)
env = UnderspecifiedToGymnaxWrapper(env)
env = DenseRewardWrapper(env, dense_reward_scale=config["dense_reward_scale"])
env = LogWrapper(env)
return env
env = make_env(static_env_params)
if config["train_level_mode"] == "list":
sample_random_level = make_reset_train_function_with_list_of_levels(
config, config["train_levels"], static_env_params, make_pcg_state=False, is_loading_train_levels=True
)
elif config["train_level_mode"] == "random":
sample_random_level = make_reset_train_function_with_mutations(
env.physics_engine, env_params, static_env_params, config, make_pcg_state=False
)
else:
raise ValueError(f"Unknown train_level_mode: {config['train_level_mode']}")
sample_random_levels = make_vmapped_filtered_level_sampler(
sample_random_level, env_params, static_env_params, config, make_pcg_state=False, env=env
)
_, eval_static_env_params = generate_params_from_config(
config["eval_env_size_true"] | {"frame_skip": config["frame_skip"]}
)
eval_env = make_env(eval_static_env_params)
ued_params = generate_ued_params_from_config(config)
def make_render_fn(static_env_params):
render_fn_inner = make_render_pixels(env_params, static_env_params)
render_fn = lambda x: render_fn_inner(x).transpose(1, 0, 2)[::-1]
return render_fn
render_fn = make_render_fn(static_env_params)
render_fn_eval = make_render_fn(eval_static_env_params)
NUM_EVAL_DR_LEVELS = 200
key_to_sample_dr_eval_set = jax.random.PRNGKey(100)
DR_EVAL_LEVELS = sample_random_levels(key_to_sample_dr_eval_set, NUM_EVAL_DR_LEVELS)
print("Hello here num steps is ", config["num_steps"])
print("CONFIG is ", config)
config["total_timesteps"] = config["num_updates"] * config["num_steps"] * config["num_train_envs"]
config["minibatch_size"] = config["num_train_envs"] * config["num_steps"] // config["num_minibatches"]
config["clip_eps"] = config["clip_eps"]
config["env_name"] = config["env_name"]
network = make_network_from_config(env, env_params, config)
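    # Learning-rate schedule used when anneal_lr is set: optax increments `count` once per
    # gradient step, so dividing by (num_minibatches * update_epochs) recovers the PPO update
    # index, and the rate decays linearly to zero over num_updates.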
def linear_schedule(count):
count = count // (config["num_minibatches"] * config["update_epochs"])
frac = 1.0 - count / config["num_updates"]
return config["lr"] * frac
# INIT NETWORK
rng, _rng = jax.random.split(rng)
    train_envs = 32  # Small placeholder batch for parameter initialisation; the actual value only affects memory usage here.
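    # Build a dummy batch of shape (time=256, envs=train_envs) purely so the network parameters
    # are initialised with the right shapes.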
obs, _ = env.reset_to_level(rng, sample_random_level(rng), env_params)
obs = jax.tree.map(
lambda x: jnp.repeat(jnp.repeat(x[None, ...], train_envs, axis=0)[None, ...], 256, axis=0),
obs,
)
init_x = (obs, jnp.zeros((256, train_envs)))
init_hstate = ScannedRNN.initialize_carry(train_envs)
network_params = network.init(_rng, init_hstate, init_x)
if config["anneal_lr"]:
tx = optax.chain(
optax.clip_by_global_norm(config["max_grad_norm"]),
optax.adam(learning_rate=linear_schedule, eps=1e-5),
)
else:
tx = optax.chain(
optax.clip_by_global_norm(config["max_grad_norm"]),
optax.adam(config["lr"], eps=1e-5),
)
train_state = TrainState.create(
apply_fn=network.apply,
params=network_params,
tx=tx,
)
if config["load_from_checkpoint"] != None:
print("LOADING from", config["load_from_checkpoint"], "with only params =", config["load_only_params"])
train_state = load_train_state_from_wandb_artifact_path(
train_state,
config["load_from_checkpoint"],
load_only_params=config["load_only_params"],
legacy=config["load_legacy_checkpoint"],
)
rng, _rng = jax.random.split(rng)
# INIT ENV
rng, _rng, _rng2 = jax.random.split(rng, 3)
rng_reset = jax.random.split(_rng, config["num_train_envs"])
new_levels = sample_random_levels(_rng2, config["num_train_envs"])
obsv, env_state = jax.vmap(env.reset_to_level, in_axes=(0, 0, None))(rng_reset, new_levels, env_params)
start_state = env_state
init_hstate = ScannedRNN.initialize_carry(config["num_train_envs"])
@jax.jit
def log_buffer_learnability(rng, train_state, instances):
BATCH_SIZE = config["num_to_save"]
BATCH_ACTORS = BATCH_SIZE
def _batch_step(unused, rng):
def _env_step(runner_state, unused):
env_state, start_state, last_obs, last_done, hstate, rng = runner_state
# SELECT ACTION
rng, _rng = jax.random.split(rng)
obs_batch = last_obs
ac_in = (
jax.tree.map(lambda x: x[np.newaxis, :], obs_batch),
last_done[np.newaxis, :],
)
hstate, pi, value = network.apply(train_state.params, hstate, ac_in)
action = pi.sample(seed=_rng).squeeze()
log_prob = pi.log_prob(action)
env_act = action
# STEP ENV
rng, _rng = jax.random.split(rng)
rng_step = jax.random.split(_rng, config["num_to_save"])
obsv, env_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
rng_step, env_state, env_act, env_params
)
done_batch = done
transition = Transition(
done,
last_done,
action.squeeze(),
value.squeeze(),
reward,
log_prob.squeeze(),
obs_batch,
info,
)
runner_state = (env_state, start_state, obsv, done_batch, hstate, rng)
return runner_state, transition
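            # Vmapped over levels (axis 1): split each level's rollout at `done` boundaries and
            # average return, success (the "GoalR" info flag) and length over completed episodes.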
@partial(jax.vmap, in_axes=(None, 1, 1, 1))
@partial(jax.jit, static_argnums=(0,))
def _calc_outcomes_by_agent(max_steps: int, dones, returns, info):
idxs = jnp.arange(max_steps)
@partial(jax.vmap, in_axes=(0, 0))
def __ep_outcomes(start_idx, end_idx):
mask = (idxs > start_idx) & (idxs <= end_idx) & (end_idx != max_steps)
r = jnp.sum(returns * mask)
goal_r = info["GoalR"] # (returns > 0) * 1.0
success = jnp.sum(goal_r * mask)
l = end_idx - start_idx
return r, success, l
done_idxs = jnp.argwhere(dones, size=50, fill_value=max_steps).squeeze()
mask_done = jnp.where(done_idxs == max_steps, 0, 1)
ep_return, success, length = __ep_outcomes(
jnp.concatenate([jnp.array([-1]), done_idxs[:-1]]), done_idxs
)
return {
"ep_return": ep_return.mean(where=mask_done),
"num_episodes": mask_done.sum(),
"success_rate": success.mean(where=mask_done),
"ep_len": length.mean(where=mask_done),
}
# sample envs
rng, _rng, _rng2 = jax.random.split(rng, 3)
rng_reset = jax.random.split(_rng, config["num_to_save"])
rng_levels = jax.random.split(_rng2, config["num_to_save"])
# obsv, env_state = jax.vmap(sample_random_level, in_axes=(0,))(reset_rng)
# new_levels = jax.vmap(sample_random_level)(rng_levels)
obsv, env_state = jax.vmap(env.reset_to_level, in_axes=(0, 0, None))(rng_reset, instances, env_params)
# env_instances = new_levels
init_hstate = ScannedRNN.initialize_carry(
BATCH_ACTORS,
)
runner_state = (env_state, env_state, obsv, jnp.zeros((BATCH_ACTORS), dtype=bool), init_hstate, rng)
runner_state, traj_batch = jax.lax.scan(_env_step, runner_state, None, config["rollout_steps"])
done_by_env = traj_batch.done.reshape((-1, config["num_to_save"]))
reward_by_env = traj_batch.reward.reshape((-1, config["num_to_save"]))
# info_by_actor = jax.tree.map(lambda x: x.swapaxes(2, 1).reshape((-1, BATCH_ACTORS)), traj_batch.info)
o = _calc_outcomes_by_agent(config["rollout_steps"], traj_batch.done, traj_batch.reward, traj_batch.info)
success_by_env = o["success_rate"].reshape((1, config["num_to_save"]))
learnability_by_env = (success_by_env * (1 - success_by_env)).sum(axis=0)
return None, (learnability_by_env, success_by_env.sum(axis=0))
rngs = jax.random.split(rng, 1)
_, (learnability, success_by_env) = jax.lax.scan(_batch_step, None, rngs, 1)
return learnability[0], success_by_env[0]
num_eval_levels = len(config["eval_levels"])
all_eval_levels = get_eval_levels(config["eval_levels"], eval_env.static_env_params)
eval_group_indices = get_eval_level_groups(config["eval_levels"])
print("group indices", eval_group_indices)
@jax.jit
def get_learnability_set(rng, network_params):
BATCH_ACTORS = config["batch_size"]
def _batch_step(unused, rng):
def _env_step(runner_state, unused):
env_state, start_state, last_obs, last_done, hstate, rng = runner_state
# SELECT ACTION
rng, _rng = jax.random.split(rng)
obs_batch = last_obs
ac_in = (
jax.tree.map(lambda x: x[np.newaxis, :], obs_batch),
last_done[np.newaxis, :],
)
hstate, pi, value = network.apply(network_params, hstate, ac_in)
action = pi.sample(seed=_rng).squeeze()
log_prob = pi.log_prob(action)
env_act = action
# STEP ENV
rng, _rng = jax.random.split(rng)
rng_step = jax.random.split(_rng, config["batch_size"])
obsv, env_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
rng_step, env_state, env_act, env_params
)
done_batch = done
transition = Transition(
done,
last_done,
action.squeeze(),
value.squeeze(),
reward,
log_prob.squeeze(),
obs_batch,
info,
)
runner_state = (env_state, start_state, obsv, done_batch, hstate, rng)
return runner_state, transition
@partial(jax.vmap, in_axes=(None, 1, 1, 1))
@partial(jax.jit, static_argnums=(0,))
def _calc_outcomes_by_agent(max_steps: int, dones, returns, info):
idxs = jnp.arange(max_steps)
@partial(jax.vmap, in_axes=(0, 0))
def __ep_outcomes(start_idx, end_idx):
mask = (idxs > start_idx) & (idxs <= end_idx) & (end_idx != max_steps)
r = jnp.sum(returns * mask)
goal_r = info["GoalR"] # (returns > 0) * 1.0
success = jnp.sum(goal_r * mask)
l = end_idx - start_idx
return r, success, l
done_idxs = jnp.argwhere(dones, size=50, fill_value=max_steps).squeeze()
mask_done = jnp.where(done_idxs == max_steps, 0, 1)
ep_return, success, length = __ep_outcomes(
jnp.concatenate([jnp.array([-1]), done_idxs[:-1]]), done_idxs
)
return {
"ep_return": ep_return.mean(where=mask_done),
"num_episodes": mask_done.sum(),
"success_rate": success.mean(where=mask_done),
"ep_len": length.mean(where=mask_done),
}
# sample envs
rng, _rng, _rng2 = jax.random.split(rng, 3)
rng_reset = jax.random.split(_rng, config["batch_size"])
new_levels = sample_random_levels(_rng2, config["batch_size"])
obsv, env_state = jax.vmap(env.reset_to_level, in_axes=(0, 0, None))(rng_reset, new_levels, env_params)
env_instances = new_levels
init_hstate = ScannedRNN.initialize_carry(
BATCH_ACTORS,
)
runner_state = (env_state, env_state, obsv, jnp.zeros((BATCH_ACTORS), dtype=bool), init_hstate, rng)
runner_state, traj_batch = jax.lax.scan(_env_step, runner_state, None, config["rollout_steps"])
done_by_env = traj_batch.done.reshape((-1, config["batch_size"]))
reward_by_env = traj_batch.reward.reshape((-1, config["batch_size"]))
# info_by_actor = jax.tree.map(lambda x: x.swapaxes(2, 1).reshape((-1, BATCH_ACTORS)), traj_batch.info)
o = _calc_outcomes_by_agent(config["rollout_steps"], traj_batch.done, traj_batch.reward, traj_batch.info)
success_by_env = o["success_rate"].reshape((1, config["batch_size"]))
learnability_by_env = (success_by_env * (1 - success_by_env)).sum(axis=0)
return None, (learnability_by_env, success_by_env.sum(axis=0), env_instances)
if config["sampled_envs_ratio"] == 0.0:
print("Not doing any rollouts because sampled_envs_ratio is 0.0")
# Here we have zero envs, so we can literally just sample random ones because there is no point.
top_instances = sample_random_levels(_rng, config["num_to_save"])
top_success = top_learn = learnability = success_rates = jnp.zeros(config["num_to_save"])
else:
rngs = jax.random.split(rng, config["num_batches"])
_, (learnability, success_rates, env_instances) = jax.lax.scan(
_batch_step, None, rngs, config["num_batches"]
)
flat_env_instances = jax.tree.map(lambda x: x.reshape((-1,) + x.shape[2:]), env_instances)
learnability = learnability.flatten() + success_rates.flatten() * 0.001
            top_idxs = jnp.argsort(learnability)[-config["num_to_save"] :]
            top_instances = jax.tree.map(lambda x: x.at[top_idxs].get(), flat_env_instances)
            top_learn = learnability.at[top_idxs].get()
            top_success = success_rates.flatten().at[top_idxs].get()
if config["put_eval_levels_in_buffer"]:
top_instances = jax.tree.map(
lambda all, new: jnp.concatenate([all[:-num_eval_levels], new], axis=0),
top_instances,
all_eval_levels.env_state,
)
log = {
"learnability/learnability_sampled_mean": learnability.mean(),
"learnability/learnability_sampled_median": jnp.median(learnability),
"learnability/learnability_sampled_min": learnability.min(),
"learnability/learnability_sampled_max": learnability.max(),
"learnability/learnability_selected_mean": top_learn.mean(),
"learnability/learnability_selected_median": jnp.median(top_learn),
"learnability/learnability_selected_min": top_learn.min(),
"learnability/learnability_selected_max": top_learn.max(),
"learnability/solve_rate_sampled_mean": top_success.mean(),
"learnability/solve_rate_sampled_median": jnp.median(top_success),
"learnability/solve_rate_sampled_min": top_success.min(),
"learnability/solve_rate_sampled_max": top_success.max(),
"learnability/solve_rate_selected_mean": success_rates.mean(),
"learnability/solve_rate_selected_median": jnp.median(success_rates),
"learnability/solve_rate_selected_min": success_rates.min(),
"learnability/solve_rate_selected_max": success_rates.max(),
}
return top_learn, top_instances, log
def eval(rng: chex.PRNGKey, train_state: TrainState, keep_states=True):
"""
This evaluates the current policy on the set of evaluation levels specified by config["eval_levels"].
It returns (states, cum_rewards, episode_lengths), with shapes (num_steps, num_eval_levels, ...), (num_eval_levels,), (num_eval_levels,)
"""
num_levels = len(config["eval_levels"])
# eval_levels = get_eval_levels(config["eval_levels"], eval_env.static_env_params)
return general_eval(
rng,
eval_env,
env_params,
train_state,
all_eval_levels,
env_params.max_timesteps,
num_levels,
keep_states=keep_states,
return_trajectories=True,
)
def eval_on_dr_levels(rng: chex.PRNGKey, train_state: TrainState, keep_states=False):
return general_eval(
rng,
env,
env_params,
train_state,
DR_EVAL_LEVELS,
env_params.max_timesteps,
NUM_EVAL_DR_LEVELS,
keep_states=keep_states,
)
def eval_on_top_learnable_levels(rng: chex.PRNGKey, train_state: TrainState, levels, keep_states=True):
N = 5
return general_eval(
rng,
env,
env_params,
train_state,
jax.tree.map(lambda x: x[:N], levels),
env_params.max_timesteps,
N,
keep_states=keep_states,
)
# TRAIN LOOP
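    # One PPO iteration: roll out num_steps in all training envs, compute GAE, run update_epochs
    # epochs of num_minibatches minibatch updates, log metrics, then reset every env to a fresh
    # mix of generated and buffer-sampled levels.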
def train_step(runner_state_instances, unused):
# COLLECT TRAJECTORIES
runner_state, instances = runner_state_instances
num_env_instances = instances.polygon.position.shape[0]
def _env_step(runner_state, unused):
train_state, env_state, start_state, last_obs, last_done, hstate, update_steps, rng = runner_state
# SELECT ACTION
rng, _rng = jax.random.split(rng)
obs_batch = last_obs
ac_in = (
jax.tree.map(lambda x: x[np.newaxis, :], obs_batch),
last_done[np.newaxis, :],
)
hstate, pi, value = network.apply(train_state.params, hstate, ac_in)
action = pi.sample(seed=_rng).squeeze()
log_prob = pi.log_prob(action)
env_act = action
# STEP ENV
rng, _rng = jax.random.split(rng)
rng_step = jax.random.split(_rng, config["num_train_envs"])
obsv, env_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
rng_step, env_state, env_act, env_params
)
done_batch = done
transition = Transition(
done,
last_done,
action.squeeze(),
value.squeeze(),
reward,
log_prob.squeeze(),
obs_batch,
info,
)
runner_state = (train_state, env_state, start_state, obsv, done_batch, hstate, update_steps, rng)
return runner_state, (transition)
initial_hstate = runner_state[-3]
runner_state, traj_batch = jax.lax.scan(_env_step, runner_state, None, config["num_steps"])
# CALCULATE ADVANTAGE
train_state, env_state, start_state, last_obs, last_done, hstate, update_steps, rng = runner_state
last_obs_batch = last_obs # batchify(last_obs, env.agents, config["num_train_envs"])
ac_in = (
jax.tree.map(lambda x: x[np.newaxis, :], last_obs_batch),
last_done[np.newaxis, :],
)
_, _, last_val = network.apply(train_state.params, hstate, ac_in)
last_val = last_val.squeeze()
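        # Generalised Advantage Estimation via a reverse scan over the trajectory:
        #   delta_t = r_t + gamma * V_{t+1} * (1 - done_t) - V_t
        #   A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
        # with value targets A_t + V_t.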
def _calculate_gae(traj_batch, last_val):
def _get_advantages(gae_and_next_value, transition: Transition):
gae, next_value = gae_and_next_value
done, value, reward = (
transition.global_done,
transition.value,
transition.reward,
)
delta = reward + config["gamma"] * next_value * (1 - done) - value
gae = delta + config["gamma"] * config["gae_lambda"] * (1 - done) * gae
return (gae, value), gae
_, advantages = jax.lax.scan(
_get_advantages,
(jnp.zeros_like(last_val), last_val),
traj_batch,
reverse=True,
unroll=16,
)
return advantages, advantages + traj_batch.value
advantages, targets = _calculate_gae(traj_batch, last_val)
# UPDATE NETWORK
def _update_epoch(update_state, unused):
def _update_minbatch(train_state, batch_info):
init_hstate, traj_batch, advantages, targets = batch_info
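                # PPO clipped-surrogate loss with a clipped value loss, entropy bonus, and
                # approx-KL / clip-fraction diagnostics; advantages are normalised per minibatch.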
def _loss_fn_masked(params, init_hstate, traj_batch, gae, targets):
# RERUN NETWORK
_, pi, value = network.apply(
params,
jax.tree.map(lambda x: x.transpose(), init_hstate),
(traj_batch.obs, traj_batch.done),
)
log_prob = pi.log_prob(traj_batch.action)
# CALCULATE VALUE LOSS
value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip(
-config["clip_eps"], config["clip_eps"]
)
value_losses = jnp.square(value - targets)
value_losses_clipped = jnp.square(value_pred_clipped - targets)
value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped)
critic_loss = config["vf_coef"] * value_loss.mean()
# CALCULATE ACTOR LOSS
logratio = log_prob - traj_batch.log_prob
ratio = jnp.exp(logratio)
# if env.do_sep_reward: gae = gae.sum(axis=-1)
gae = (gae - gae.mean()) / (gae.std() + 1e-8)
loss_actor1 = ratio * gae
loss_actor2 = (
jnp.clip(
ratio,
1.0 - config["clip_eps"],
1.0 + config["clip_eps"],
)
* gae
)
loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
loss_actor = loss_actor.mean()
entropy = pi.entropy().mean()
approx_kl = jax.lax.stop_gradient(((ratio - 1) - logratio).mean())
clipfrac = jax.lax.stop_gradient((jnp.abs(ratio - 1) > config["clip_eps"]).mean())
total_loss = loss_actor + critic_loss - config["ent_coef"] * entropy
return total_loss, (value_loss, loss_actor, entropy, ratio, approx_kl, clipfrac)
grad_fn = jax.value_and_grad(_loss_fn_masked, has_aux=True)
total_loss, grads = grad_fn(train_state.params, init_hstate, traj_batch, advantages, targets)
train_state = train_state.apply_gradients(grads=grads)
return train_state, total_loss
(
train_state,
init_hstate,
traj_batch,
advantages,
targets,
rng,
) = update_state
rng, _rng = jax.random.split(rng)
init_hstate = jax.tree.map(lambda x: jnp.reshape(x, (256, config["num_train_envs"])), init_hstate)
batch = (
init_hstate,
traj_batch,
advantages.squeeze(),
targets.squeeze(),
)
permutation = jax.random.permutation(_rng, config["num_train_envs"])
shuffled_batch = jax.tree_util.tree_map(lambda x: jnp.take(x, permutation, axis=1), batch)
minibatches = jax.tree_util.tree_map(
lambda x: jnp.swapaxes(
jnp.reshape(
x,
[x.shape[0], config["num_minibatches"], -1] + list(x.shape[2:]),
),
1,
0,
),
shuffled_batch,
)
train_state, total_loss = jax.lax.scan(_update_minbatch, train_state, minibatches)
# total_loss = jax.tree.map(lambda x: x.mean(), total_loss)
update_state = (
train_state,
init_hstate,
traj_batch,
advantages,
targets,
rng,
)
return update_state, total_loss
# init_hstate = initial_hstate[None, :].squeeze().transpose()
init_hstate = jax.tree.map(lambda x: x[None, :].squeeze().transpose(), initial_hstate)
update_state = (
train_state,
init_hstate,
traj_batch,
advantages,
targets,
rng,
)
update_state, loss_info = jax.lax.scan(_update_epoch, update_state, None, config["update_epochs"])
train_state = update_state[0]
        metric = jax.tree.map(
            lambda x: x.reshape((config["num_steps"], config["num_train_envs"])),
            traj_batch.info,
        )
rng = update_state[-1]
def callback(metric):
dones = metric["dones"]
wandb.log(
{
"episode_return": (metric["returned_episode_returns"] * dones).sum() / jnp.maximum(1, dones.sum()),
"episode_solved": (metric["returned_episode_solved"] * dones).sum() / jnp.maximum(1, dones.sum()),
"episode_length": (metric["returned_episode_lengths"] * dones).sum() / jnp.maximum(1, dones.sum()),
"timing/num_env_steps": int(
int(metric["update_steps"]) * int(config["num_train_envs"]) * int(config["num_steps"])
),
"timing/num_updates": metric["update_steps"],
**metric["loss_info"],
}
)
loss_info = jax.tree.map(lambda x: x.mean(), loss_info)
metric["loss_info"] = {
"loss/total_loss": loss_info[0],
"loss/value_loss": loss_info[1][0],
"loss/policy_loss": loss_info[1][1],
"loss/entropy_loss": loss_info[1][2],
}
metric["dones"] = traj_batch.done
metric["update_steps"] = update_steps
jax.experimental.io_callback(callback, None, metric)
# SAMPLE NEW ENVS
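        # Every env is reset after the update: `num_envs_to_generate` get brand-new random levels
        # and `num_envs_from_sampled` replay levels drawn uniformly from the learnability buffer.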
rng, _rng, _rng2 = jax.random.split(rng, 3)
rng_reset = jax.random.split(_rng, config["num_envs_to_generate"])
new_levels = sample_random_levels(_rng2, config["num_envs_to_generate"])
obsv_gen, env_state_gen = jax.vmap(env.reset_to_level, in_axes=(0, 0, None))(rng_reset, new_levels, env_params)
rng, _rng, _rng2 = jax.random.split(rng, 3)
sampled_env_instances_idxs = jax.random.randint(_rng, (config["num_envs_from_sampled"],), 0, num_env_instances)
sampled_env_instances = jax.tree.map(lambda x: x.at[sampled_env_instances_idxs].get(), instances)
myrng = jax.random.split(_rng2, config["num_envs_from_sampled"])
obsv_sampled, env_state_sampled = jax.vmap(env.reset_to_level, in_axes=(0, 0))(myrng, sampled_env_instances)
obsv = jax.tree.map(lambda x, y: jnp.concatenate([x, y], axis=0), obsv_gen, obsv_sampled)
env_state = jax.tree.map(lambda x, y: jnp.concatenate([x, y], axis=0), env_state_gen, env_state_sampled)
start_state = env_state
hstate = ScannedRNN.initialize_carry(config["num_train_envs"])
update_steps = update_steps + 1
runner_state = (
train_state,
env_state,
start_state,
obsv,
jnp.zeros((config["num_train_envs"]), dtype=bool),
hstate,
update_steps,
rng,
)
return (runner_state, instances), metric
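    # Renders a grid of buffer levels annotated with their learnability scores for wandb logging.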
def log_buffer(learnability, levels, epoch):
num_samples = levels.polygon.position.shape[0]
states = levels
rows = 2
fig, axes = plt.subplots(rows, int(num_samples / rows), figsize=(20, 10))
axes = axes.flatten()
all_imgs = jax.vmap(render_fn)(states)
for i, ax in enumerate(axes):
# ax.imshow(train_state.plr_buffer.get_sample(i))
score = learnability[i]
ax.imshow(all_imgs[i] / 255.0)
ax.set_xticks([])
ax.set_yticks([])
ax.set_title(f"learnability: {score:.3f}")
ax.set_aspect("equal", "box")
plt.tight_layout()
fig.canvas.draw()
im = Image.frombytes("RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
plt.close()
return {"maps": wandb.Image(im)}
@jax.jit
def train_and_eval_step(runner_state, eval_rng):
learnability_rng, eval_singleton_rng, eval_sampled_rng, _rng = jax.random.split(eval_rng, 4)
# TRAIN
        learnability_scores, instances, test_metrics = get_learnability_set(learnability_rng, runner_state[0].params)
if config["log_learnability_before_after"]:
learn_scores_before, success_score_before = log_buffer_learnability(
learnability_rng, runner_state[0], instances
)
print("instance size", sum(x.size for x in jax.tree_util.tree_leaves(instances)))
runner_state_instances = (runner_state, instances)
runner_state_instances, metrics = jax.lax.scan(train_step, runner_state_instances, None, config["eval_freq"])
if config["log_learnability_before_after"]:
learn_scores_after, success_score_after = log_buffer_learnability(
learnability_rng, runner_state_instances[0][0], instances
)
# EVAL
rng, rng_eval = jax.random.split(eval_singleton_rng)
(states, cum_rewards, _, episode_lengths, eval_infos), (eval_dones, eval_rewards) = jax.vmap(eval, (0, None))(
jax.random.split(rng_eval, config["eval_num_attempts"]), runner_state_instances[0][0]
)
all_eval_eplens = episode_lengths
# Collect Metrics
eval_returns = cum_rewards.mean(axis=0) # (num_eval_levels,)
eval_solves = (eval_infos["returned_episode_solved"] * eval_dones).sum(axis=1) / jnp.maximum(
1, eval_dones.sum(axis=1)
)
eval_solves = eval_solves.mean(axis=0)
# just grab the first run
states, episode_lengths = jax.tree_util.tree_map(
lambda x: x[0], (states, episode_lengths)
) # (num_steps, num_eval_levels, ...), (num_eval_levels,)
# And one attempt
states = jax.tree_util.tree_map(lambda x: x[:, :], states)
episode_lengths = episode_lengths[:]
images = jax.vmap(jax.vmap(render_fn_eval))(
states.env_state.env_state.env_state
) # (num_steps, num_eval_levels, ...)
        frames = images.transpose(0, 1, 4, 2, 3)  # wandb.Video expects (T, C, H, W) frames, i.e. channel before the spatial dims.
test_metrics["update_count"] = runner_state[-2]
test_metrics["eval_returns"] = eval_returns
test_metrics["eval_ep_lengths"] = episode_lengths
test_metrics["eval_animation"] = (frames, episode_lengths)
# Eval on sampled
dr_states, dr_cum_rewards, _, dr_episode_lengths, dr_infos = jax.vmap(eval_on_dr_levels, (0, None))(
jax.random.split(rng_eval, config["eval_num_attempts"]), runner_state_instances[0][0]
)
eval_dr_returns = dr_cum_rewards.mean(axis=0).mean()
eval_dr_eplen = dr_episode_lengths.mean(axis=0).mean()
test_metrics["eval/mean_eval_return_sampled"] = eval_dr_returns
my_eval_dones = dr_infos["returned_episode"]
eval_dr_solves = (dr_infos["returned_episode_solved"] * my_eval_dones).sum(axis=1) / jnp.maximum(
1, my_eval_dones.sum(axis=1)
)
test_metrics["eval/mean_eval_solve_rate_sampled"] = eval_dr_solves
test_metrics["eval/mean_eval_eplen_sampled"] = eval_dr_eplen
# Collect Metrics
eval_returns = cum_rewards.mean(axis=0) # (num_eval_levels,)
log_dict = {}
log_dict["to_remove"] = {
"eval_return": eval_returns,
"eval_solve_rate": eval_solves,
"eval_eplen": all_eval_eplens,
}
for i, name in enumerate(config["eval_levels"]):
log_dict[f"eval_avg_return/{name}"] = eval_returns[i]
log_dict[f"eval_avg_solve_rate/{name}"] = eval_solves[i]
log_dict.update({"eval/mean_eval_return": eval_returns.mean()})
log_dict.update({"eval/mean_eval_solve_rate": eval_solves.mean()})
log_dict.update({"eval/mean_eval_eplen": all_eval_eplens.mean()})
test_metrics.update(log_dict)
runner_state, _ = runner_state_instances
test_metrics["update_count"] = runner_state[-2]
top_instances = jax.tree.map(lambda x: x.at[-5:].get(), instances)
# Eval on top learnable levels
tl_states, tl_cum_rewards, _, tl_episode_lengths, tl_infos = jax.vmap(
eval_on_top_learnable_levels, (0, None, None)
)(jax.random.split(rng_eval, config["eval_num_attempts"]), runner_state_instances[0][0], top_instances)
# just grab the first run
states, episode_lengths = jax.tree_util.tree_map(
lambda x: x[0], (tl_states, tl_episode_lengths)
) # (num_steps, num_eval_levels, ...), (num_eval_levels,)
# And one attempt
states = jax.tree_util.tree_map(lambda x: x[:, :], states)
episode_lengths = episode_lengths[:]
images = jax.vmap(jax.vmap(render_fn))(
states.env_state.env_state.env_state
) # (num_steps, num_eval_levels, ...)
        frames = images.transpose(0, 1, 4, 2, 3)  # Again (T, C, H, W) per level for wandb.Video.
test_metrics["top_learnable_animation"] = (frames, episode_lengths, tl_cum_rewards)
if config["log_learnability_before_after"]:
def single(x, name):
return {
f"{name}_mean": x.mean(),
f"{name}_std": x.std(),
f"{name}_min": x.min(),
f"{name}_max": x.max(),
f"{name}_median": jnp.median(x),
}
test_metrics["learnability_log_v2/"] = {
**single(learn_scores_before, "learnability_before"),
**single(learn_scores_after, "learnability_after"),
**single(success_score_before, "success_score_before"),
**single(success_score_after, "success_score_after"),
}
        return runner_state, (learnability_scores.at[-20:].get(), top_instances), test_metrics
rng, _rng = jax.random.split(rng)
runner_state = (
train_state,
env_state,
start_state,
obsv,
jnp.zeros((config["num_train_envs"]), dtype=bool),
init_hstate,
0,
_rng,
)
def log_eval(stats):
log_dict = {}
to_remove = stats["to_remove"]
del stats["to_remove"]
def _aggregate_per_size(values, name):
to_return = {}
for group_name, indices in eval_group_indices.items():
to_return[f"{name}_{group_name}"] = values[indices].mean()
return to_return
env_steps = stats["update_count"] * config["num_train_envs"] * config["num_steps"]
env_steps_delta = config["eval_freq"] * config["num_train_envs"] * config["num_steps"]
time_now = time.time()
log_dict = {
"timing/num_updates": stats["update_count"],
"timing/num_env_steps": env_steps,
"timing/sps": env_steps_delta / stats["time_delta"],
"timing/sps_agg": env_steps / (time_now - time_start),
}
log_dict.update(_aggregate_per_size(to_remove["eval_return"], "eval_aggregate/return"))
log_dict.update(_aggregate_per_size(to_remove["eval_solve_rate"], "eval_aggregate/solve_rate"))
        for i in range(len(config["eval_levels"])):
frames, episode_length = stats["eval_animation"][0][:, i], stats["eval_animation"][1][i]
frames = np.array(frames[:episode_length])
log_dict.update(
{
f"media/eval_video_{config['eval_levels'][i]}": wandb.Video(
frames.astype(np.uint8), fps=15, caption=f"(len {episode_length})"
)
}
)
for j in range(5):
frames, episode_length, cum_rewards = (
stats["top_learnable_animation"][0][:, j],
stats["top_learnable_animation"][1][j],
stats["top_learnable_animation"][2][:, j],
) # num attempts
rr = "|".join([f"{r:<.2f}" for r in cum_rewards])
frames = np.array(frames[:episode_length])
log_dict.update(
{
f"media/tl_animation_{j}": wandb.Video(
frames.astype(np.uint8), fps=15, caption=f"(len {episode_length})\n{rr}"
)
}
)
stats.update(log_dict)
wandb.log(stats, step=stats["update_count"])
checkpoint_steps = config["checkpoint_save_freq"]
assert config["num_updates"] % config["eval_freq"] == 0, "num_updates must be divisible by eval_freq"
for eval_step in range(int(config["num_updates"] // config["eval_freq"])):
start_time = time.time()
rng, eval_rng = jax.random.split(rng)
runner_state, instances, metrics = train_and_eval_step(runner_state, eval_rng)
curr_time = time.time()
metrics.update(log_buffer(*instances, metrics["update_count"]))
metrics["time_delta"] = curr_time - start_time
metrics["steps_per_section"] = (config["eval_freq"] * config["num_steps"] * config["num_train_envs"]) / metrics[
"time_delta"
]
log_eval(metrics)
if ((eval_step + 1) * config["eval_freq"]) % checkpoint_steps == 0:
if config["save_path"] is not None:
steps = int(metrics["update_count"]) * int(config["num_train_envs"]) * int(config["num_steps"])
# save_params_to_wandb(runner_state[0].params, steps, config)
save_model_to_wandb(runner_state[0], steps, config)
if config["save_path"] is not None:
# save_params_to_wandb(runner_state[0].params, config["total_timesteps"], config)
save_model_to_wandb(runner_state[0], config["total_timesteps"], config)
if __name__ == "__main__":
# with jax.disable_jit():
# main()
main()