Spaces:
Runtime error
import os | |
import hydra | |
from omegaconf import OmegaConf | |
from kinetix.environment.ued.ued import ( | |
make_reset_train_function_with_list_of_levels, | |
make_reset_train_function_with_mutations, | |
) | |
from kinetix.render.renderer_pixels import make_render_pixels | |
from kinetix.util.config import ( | |
get_video_frequency, | |
init_wandb, | |
normalise_config, | |
generate_params_from_config, | |
) | |
# Set before `import wandb` further down; presumably tells wandb not to start its
# background service process — confirm against the installed wandb version's docs.
os.environ.update(WANDB_DISABLE_SERVICE="True")
import sys | |
from typing import Any, NamedTuple | |
import jax | |
import jax.numpy as jnp | |
import numpy as np | |
import optax | |
from flax.training.train_state import TrainState | |
from kinetix.models import make_network_from_config | |
from kinetix.util.learning import general_eval, get_eval_levels | |
from flax.serialization import to_state_dict | |
import wandb | |
from kinetix.environment.env import PixelObservations, make_kinetix_env_from_name | |
from kinetix.environment.wrappers import ( | |
AutoReplayWrapper, | |
AutoResetWrapper, | |
BatchEnvWrapper, | |
DenseRewardWrapper, | |
LogWrapper, | |
UnderspecifiedToGymnaxWrapper, | |
) | |
from kinetix.models.actor_critic import ScannedRNN | |
from kinetix.util.saving import ( | |
load_train_state_from_wandb_artifact_path, | |
save_model_to_wandb, | |
) | |
class Transition(NamedTuple):
    """One environment step recorded during a PPO rollout.

    ``_env_step`` inside ``make_train`` builds one of these per step, and
    ``jax.lax.scan`` stacks them, so every field gains a leading time axis.
    """

    done: jnp.ndarray  # done flag that accompanied ``obs`` (i.e. the previous step's ``done``)
    action: jnp.ndarray  # action sampled from the policy
    value: jnp.ndarray  # critic's value estimate for ``obs``
    reward: jnp.ndarray  # reward returned by ``env.step`` for ``action``
    log_prob: jnp.ndarray  # log-probability of ``action`` under the rollout policy
    obs: Any  # observation pytree the action was chosen from
    info: dict[str, Any]  # info dict from ``env.step``; read by key, e.g. "returned_episode"
def make_train(config, env_params, static_env_params):
    """Build a jittable PPO training function for a Kinetix environment.

    ``config`` is mutated in place: ``"num_updates"`` and ``"minibatch_size"``
    are derived from the rollout/batch settings and written back into it.

    The optimiser is gradient-norm clipping followed by one of:
      * ``config["anneal_lr"]``: Adam with a linearly decaying learning rate;
      * otherwise ``config["warmup_lr"]``: AdamW with linear warm-up then cosine decay;
      * otherwise: Adam with a constant ``config["lr"]``.

    Args:
        config: Normalised PPO config dict (reads keys such as ``"num_steps"``,
            ``"num_train_envs"``, ``"train_level_mode"``, ``"lr"``, ``"use_wandb"``).
        env_params: Environment params passed to the envs, renderer and evaluator.
        static_env_params: Static env params used to construct the environments.

    Returns:
        ``train(rng)``, which runs ``config["num_updates"]`` PPO updates with a
        recurrent actor-critic and returns ``{"runner_state": ..., "metric": ...}``.

    Raises:
        ValueError: If ``config["train_level_mode"]`` is neither ``"list"`` nor ``"random"``.
    """
    # Derived sizes, written back into ``config`` so later code can read them.
    config["num_updates"] = config["total_timesteps"] // config["num_steps"] // config["num_train_envs"]
    config["minibatch_size"] = config["num_train_envs"] * config["num_steps"] // config["num_minibatches"]

    env = make_kinetix_env_from_name(config["env_name"], static_env_params=static_env_params)

    # Choose how training levels are produced on reset.
    if config["train_level_mode"] == "list":
        reset_func = make_reset_train_function_with_list_of_levels(
            config, config["train_levels_list"], static_env_params, is_loading_train_levels=True
        )
    elif config["train_level_mode"] == "random":
        reset_func = make_reset_train_function_with_mutations(
            env.physics_engine, env_params, env.static_env_params, config
        )
    else:
        raise ValueError(f"Unknown train_level_mode: {config['train_level_mode']}")

    env = UnderspecifiedToGymnaxWrapper(AutoResetWrapper(env, reset_func))

    eval_env = make_kinetix_env_from_name(config["env_name"], static_env_params=static_env_params)
    eval_env = UnderspecifiedToGymnaxWrapper(AutoReplayWrapper(eval_env))

    env = DenseRewardWrapper(env)
    env = LogWrapper(env)
    env = BatchEnvWrapper(env, num_envs=config["num_train_envs"])

    # The evaluation env is not wrapped in BatchEnvWrapper.
    eval_env_nonbatch = LogWrapper(DenseRewardWrapper(eval_env))

    def linear_schedule(count):
        # ``count`` counts optimiser steps; each PPO update performs
        # num_minibatches * update_epochs of them.
        frac = 1.0 - (count // (config["num_minibatches"] * config["update_epochs"])) / config["num_updates"]
        return config["lr"] * frac

    def linear_warmup_cosine_decay_schedule(count):
        # Fraction of training completed, between 0 and 1.
        frac = (count // (config["num_minibatches"] * config["update_epochs"])) / config["num_updates"]
        delta = config["peak_lr"] - config["initial_lr"]
        frac_diff_max = 1.0 - config["warmup_frac"]
        frac_cosine = (frac - config["warmup_frac"]) / frac_diff_max
        # Linear ramp initial_lr -> peak_lr during warm-up, then cosine decay peak_lr -> 0.
        return jax.lax.select(
            frac < config["warmup_frac"],
            config["initial_lr"] + delta * frac / config["warmup_frac"],
            config["peak_lr"] * jnp.maximum(0.0, 0.5 * (1.0 + jnp.cos(jnp.pi * ((frac_cosine) % 1.0)))),
        )

    def train(rng):
        """Run the full PPO loop from ``rng``; returns final runner state and per-update metrics."""
        # INIT NETWORK
        network = make_network_from_config(env, env_params, config)
        rng, _rng = jax.random.split(rng)
        obsv, env_state = env.reset(_rng, env_params)
        dones = jnp.zeros((config["num_train_envs"]), dtype=jnp.bool_)
        rng, _rng = jax.random.split(rng)
        init_hstate = ScannedRNN.initialize_carry(config["num_train_envs"])
        # Add a length-1 leading time axis: the recurrent network consumes (obs, done) sequences.
        init_x = jax.tree.map(lambda x: x[None, ...], (obsv, dones))
        network_params = network.init(_rng, init_hstate, init_x)

        param_count = sum(x.size for x in jax.tree_util.tree_leaves(network_params))
        obs_size = sum(x.size for x in jax.tree_util.tree_leaves(obsv)) // config["num_train_envs"]
        print("Number of parameters", param_count, "size of obs: ", obs_size)

        if config["anneal_lr"]:
            tx = optax.chain(
                optax.clip_by_global_norm(config["max_grad_norm"]),
                optax.adam(learning_rate=linear_schedule, eps=1e-5),
            )
        elif config["warmup_lr"]:
            tx = optax.chain(
                optax.clip_by_global_norm(config["max_grad_norm"]),
                optax.adamw(learning_rate=linear_warmup_cosine_decay_schedule, eps=1e-5),
            )
        else:
            tx = optax.chain(
                optax.clip_by_global_norm(config["max_grad_norm"]),
                optax.adam(config["lr"], eps=1e-5),
            )
        train_state = TrainState.create(
            apply_fn=network.apply,
            params=network_params,
            tx=tx,
        )

        if config["load_from_checkpoint"] is not None:
            print("LOADING from", config["load_from_checkpoint"], "with only params =", config["load_only_params"])
            train_state = load_train_state_from_wandb_artifact_path(
                train_state, config["load_from_checkpoint"], load_only_params=config["load_only_params"]
            )

        # INIT ENV
        rng, _rng = jax.random.split(rng)
        obsv, env_state = env.reset(_rng, env_params)
        init_hstate = ScannedRNN.initialize_carry(config["num_train_envs"])

        render_static_env_params = env.static_env_params.replace(downscale=1)
        pixel_renderer = jax.jit(make_render_pixels(env_params, render_static_env_params))

        def pixel_render_fn(x):
            # Scale rendered pixels from [0, 255] to [0, 1].
            return pixel_renderer(x) / 255.0

        eval_levels = get_eval_levels(config["eval_levels"], env.static_env_params)

        def _vmapped_eval_step(runner_state, rng):
            """Evaluate ``runner_state[0]`` (the train state) on the eval levels.

            Runs ``config["eval_num_attempts"]`` independent attempts in parallel and
            returns ``(states, done_idxs, returns)`` of the first attempt followed by
            the attempt-averaged returns, episode lengths and solve rates.
            """

            def _single_eval_step(rng):
                return general_eval(
                    rng,
                    eval_env_nonbatch,
                    env_params,
                    runner_state[0],
                    eval_levels,
                    env_params.max_timesteps,
                    config["num_eval_levels"],
                    keep_states=True,
                    return_trajectories=True,
                )

            (states, returns, done_idxs, episode_lengths, eval_infos), (eval_dones, eval_rewards) = jax.vmap(
                _single_eval_step
            )(jax.random.split(rng, config["eval_num_attempts"]))
            # Fraction of finished episodes that were solved, summed over axis 1 (presumably
            # the time axis — confirm); max(1, .) avoids dividing by zero.
            eval_solves = (eval_infos["returned_episode_solved"] * eval_dones).sum(axis=1) / jnp.maximum(
                1, eval_dones.sum(axis=1)
            )
            # Only the first attempt's trajectory is kept for video rendering.
            states_to_plot = jax.tree.map(lambda x: x[0], states)
            return (
                states_to_plot,
                done_idxs[0],
                returns[0],
                returns.mean(axis=0),
                episode_lengths.mean(axis=0),
                eval_solves.mean(axis=0),
            )

        # TRAIN LOOP
        def _update_step(runner_state, unused):
            # COLLECT TRAJECTORIES
            def _env_step(runner_state, unused):
                (
                    train_state,
                    env_state,
                    last_obs,
                    last_done,
                    hstate,
                    rng,
                    update_step,
                ) = runner_state

                # SELECT ACTION
                rng, _rng = jax.random.split(rng)
                # Add a length-1 time axis for the recurrent network, then squeeze it off again.
                ac_in = (jax.tree.map(lambda x: x[np.newaxis, :], last_obs), last_done[np.newaxis, :])
                hstate, pi, value = network.apply(train_state.params, hstate, ac_in)
                action = pi.sample(seed=_rng)
                log_prob = pi.log_prob(action)
                value, action, log_prob = (
                    value.squeeze(0),
                    action.squeeze(0),
                    log_prob.squeeze(0),
                )

                # STEP ENV
                rng, _rng = jax.random.split(rng)
                obsv, env_state, reward, done, info = env.step(_rng, env_state, action, env_params)
                transition = Transition(last_done, action, value, reward, log_prob, last_obs, info)
                runner_state = (
                    train_state,
                    env_state,
                    obsv,
                    done,
                    hstate,
                    rng,
                    update_step,
                )
                return runner_state, transition

            # Hidden state at the start of the rollout; the update replays sequences from it.
            initial_hstate = runner_state[-3]
            runner_state, traj_batch = jax.lax.scan(_env_step, runner_state, None, config["num_steps"])

            # CALCULATE ADVANTAGE
            (
                train_state,
                env_state,
                last_obs,
                last_done,
                hstate,
                rng,
                update_step,
            ) = runner_state
            ac_in = (jax.tree.map(lambda x: x[np.newaxis, :], last_obs), last_done[np.newaxis, :])
            _, _, last_val = network.apply(train_state.params, hstate, ac_in)
            last_val = last_val.squeeze(0)

            def _calculate_gae(traj_batch, last_val, last_done):
                # GAE computed backwards over the rollout. ``transition.done`` is the flag that
                # came with ``transition.obs``, so carrying it back as ``next_done`` stops
                # bootstrapping across episode boundaries.
                def _get_advantages(carry, transition):
                    gae, next_value, next_done = carry
                    done, value, reward = (
                        transition.done,
                        transition.value,
                        transition.reward,
                    )
                    delta = reward + config["gamma"] * next_value * (1 - next_done) - value
                    gae = delta + config["gamma"] * config["gae_lambda"] * (1 - next_done) * gae
                    return (gae, value, done), gae

                _, advantages = jax.lax.scan(
                    _get_advantages,
                    (jnp.zeros_like(last_val), last_val, last_done),
                    traj_batch,
                    reverse=True,
                    unroll=16,
                )
                return advantages, advantages + traj_batch.value

            advantages, targets = _calculate_gae(traj_batch, last_val, last_done)

            # UPDATE NETWORK
            def _update_epoch(update_state, unused):
                def _update_minbatch(train_state, batch_info):
                    init_hstate, traj_batch, advantages, targets = batch_info

                    def _loss_fn(params, init_hstate, traj_batch, gae, targets):
                        # RERUN NETWORK over the stored sequence from the rollout-start hidden
                        # state (index 0 drops the leading axis added before shuffling).
                        _, pi, value = network.apply(params, init_hstate[0], (traj_batch.obs, traj_batch.done))
                        log_prob = pi.log_prob(traj_batch.action)

                        # CALCULATE VALUE LOSS
                        value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip(
                            -config["clip_eps"], config["clip_eps"]
                        )
                        value_losses = jnp.square(value - targets)
                        value_losses_clipped = jnp.square(value_pred_clipped - targets)
                        value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()

                        # CALCULATE ACTOR LOSS
                        ratio = jnp.exp(log_prob - traj_batch.log_prob)
                        # Advantages are normalised per minibatch.
                        gae = (gae - gae.mean()) / (gae.std() + 1e-8)
                        loss_actor1 = ratio * gae
                        loss_actor2 = (
                            jnp.clip(
                                ratio,
                                1.0 - config["clip_eps"],
                                1.0 + config["clip_eps"],
                            )
                            * gae
                        )
                        loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
                        loss_actor = loss_actor.mean()
                        entropy = pi.entropy().mean()

                        total_loss = loss_actor + config["vf_coef"] * value_loss - config["ent_coef"] * entropy
                        return total_loss, (value_loss, loss_actor, entropy)

                    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                    total_loss, grads = grad_fn(train_state.params, init_hstate, traj_batch, advantages, targets)
                    train_state = train_state.apply_gradients(grads=grads)
                    return train_state, total_loss

                (
                    train_state,
                    init_hstate,
                    traj_batch,
                    advantages,
                    targets,
                    rng,
                ) = update_state
                rng, _rng = jax.random.split(rng)
                # Shuffle along axis 1 (length num_train_envs), leaving the time axis (0) intact
                # so each minibatch holds whole rollout sequences for the recurrent network.
                permutation = jax.random.permutation(_rng, config["num_train_envs"])
                batch = (init_hstate, traj_batch, advantages, targets)
                shuffled_batch = jax.tree_util.tree_map(lambda x: jnp.take(x, permutation, axis=1), batch)
                # Split axis 1 into num_minibatches chunks and move the chunk axis first for scan.
                minibatches = jax.tree_util.tree_map(
                    lambda x: jnp.swapaxes(
                        jnp.reshape(
                            x,
                            [x.shape[0], config["num_minibatches"], -1] + list(x.shape[2:]),
                        ),
                        1,
                        0,
                    ),
                    shuffled_batch,
                )
                train_state, total_loss = jax.lax.scan(_update_minbatch, train_state, minibatches)
                update_state = (
                    train_state,
                    init_hstate,
                    traj_batch,
                    advantages,
                    targets,
                    rng,
                )
                return update_state, total_loss

            # Add a leading axis so the env axis is axis 1, matching traj_batch for shuffling.
            init_hstate = initial_hstate[None, :]
            update_state = (
                train_state,
                init_hstate,
                traj_batch,
                advantages,
                targets,
                rng,
            )
            update_state, loss_info = jax.lax.scan(_update_epoch, update_state, None, config["update_epochs"])
            train_state = update_state[0]
            # Average each info entry over steps where an episode ended (0/0, i.e. NaN, if none did).
            metric = jax.tree.map(
                lambda x: (x * traj_batch.info["returned_episode"]).sum() / traj_batch.info["returned_episode"].sum(),
                traj_batch.info,
            )
            rng = update_state[-1]

            if config["use_wandb"]:
                vid_frequency = get_video_frequency(config, update_step)
                rng, _rng = jax.random.split(rng)
                # ``runner_state`` still holds the train state used for this rollout, i.e. the
                # params before this update's gradient steps.
                to_log_videos = _vmapped_eval_step(runner_state, _rng)
                should_log_videos = update_step % vid_frequency == 0
                # Render frames only on video steps; otherwise produce zeros of the same shape,
                # since both lax.cond branches must return matching outputs.
                first = jax.lax.cond(
                    should_log_videos,
                    lambda: jax.vmap(jax.vmap(pixel_render_fn))(to_log_videos[0].env_state.env_state.env_state),
                    lambda: (
                        jnp.zeros(
                            (
                                env_params.max_timesteps,
                                config["num_eval_levels"],
                                *PixelObservations(env_params, render_static_env_params)
                                .observation_space(env_params)
                                .shape,
                            )
                        )
                    ),
                )
                # Replace the raw states with rendered frames and insert the video flag.
                to_log_videos = (first, should_log_videos, *to_log_videos[1:])

                def callback(metric, raw_info, loss_info, update_step, to_log_videos):
                    # Host-side logging invoked through jax.debug.callback.
                    to_log = {}
                    to_log["timing/num_updates"] = update_step
                    to_log["timing/num_env_steps"] = update_step * config["num_steps"] * config["num_train_envs"]
                    (
                        obs_vid,
                        should_log_videos,
                        idx_vid,
                        eval_return_vid,
                        eval_return_mean,
                        eval_eplen_mean,
                        eval_solverate_mean,
                    ) = to_log_videos
                    to_log["eval/mean_eval_return"] = eval_return_mean.mean()
                    to_log["eval/mean_eval_eplen"] = eval_eplen_mean.mean()
                    for i, eval_name in enumerate(config["eval_levels"]):
                        return_on_video = eval_return_vid[i]
                        to_log[f"eval_video/return_{eval_name}"] = return_on_video
                        to_log[f"eval_video/len_{eval_name}"] = idx_vid[i]
                        to_log[f"eval_avg/return_{eval_name}"] = eval_return_mean[i]
                        to_log[f"eval_avg/solve_rate_{eval_name}"] = eval_solverate_mean[i]

                    if should_log_videos:
                        for i, eval_name in enumerate(config["eval_levels"]):
                            # Keep frames up to the episode's end index for this level.
                            obs_to_use = obs_vid[: idx_vid[i], i]
                            # NOTE(review): presumably converts renderer frames to wandb's
                            # (time, channel, height, width) layout with a vertical flip — confirm.
                            obs_to_use = np.asarray(obs_to_use).transpose(0, 3, 2, 1)[:, :, ::-1, :]
                            to_log[f"media/eval_video_{eval_name}"] = wandb.Video((obs_to_use * 255).astype(np.uint8))

                    wandb.log(to_log)

                jax.debug.callback(callback, metric, traj_batch.info, loss_info, update_step, to_log_videos)

            runner_state = (
                train_state,
                env_state,
                last_obs,
                last_done,
                hstate,
                rng,
                update_step + 1,
            )
            return runner_state, metric

        rng, _rng = jax.random.split(rng)
        runner_state = (
            train_state,
            env_state,
            obsv,
            jnp.zeros((config["num_train_envs"]), dtype=bool),
            init_hstate,
            _rng,
            0,
        )
        runner_state, metric = jax.lax.scan(_update_step, runner_state, None, config["num_updates"])
        return {"runner_state": runner_state, "metric": metric}

    return train
# NOTE(review): config_path/config_name follow the upstream Kinetix layout (this script
# under experiments/, Hydra configs in configs/ppo.yaml) — TODO confirm for this deployment.
@hydra.main(version_base=None, config_path="../configs", config_name="ppo")
def main(config):
    """Train PPO on Kinetix from a Hydra config and optionally save the policy to wandb.

    With the ``hydra.main`` decorator, calling ``main()`` with no arguments (as the
    ``__main__`` guard does) lets Hydra compose ``config`` from the command line.
    Without the decorator, that call raised ``TypeError`` for the missing argument.
    A config object can still be passed explicitly.
    """
    config = normalise_config(OmegaConf.to_container(config), "PPO")
    env_params, static_env_params = generate_params_from_config(config)
    # Keep the env params in the config that is handed to init_wandb / save_model_to_wandb.
    config["env_params"] = to_state_dict(env_params)
    config["static_env_params"] = to_state_dict(static_env_params)

    if config["use_wandb"]:
        init_wandb(config, "PPO")

    rng = jax.random.PRNGKey(config["seed"])
    rng, _rng = jax.random.split(rng)
    train_jit = jax.jit(make_train(config, env_params, static_env_params))
    out = train_jit(_rng)

    if config["use_wandb"] and config["save_policy"]:
        # runner_state[0] is the final TrainState.
        train_state = out["runner_state"][0]
        save_model_to_wandb(train_state, config["total_timesteps"], config)


if __name__ == "__main__":
    main()