import copy
import datetime
import gzip
import json
import os
from collections import defaultdict
from hashlib import md5
from typing import Dict, List

import jax
import jax.numpy as jnp
import numpy as np
import wandb
from omegaconf import OmegaConf

from kinetix.environment.env_state import EnvParams, StaticEnvParams
from kinetix.environment.ued.ued_state import UEDParams
from kinetix.util.saving import load_from_json_file


def get_hash_without_seed(config):
    # Hash the config with the seed zeroed out, so that runs differing only
    # in their seed share a hash (the hash is later used in the log path).
    old_seed = config["seed"]
    config["seed"] = 0
    ans = md5(OmegaConf.to_yaml(config, sort_keys=True).encode()).hexdigest()
    config["seed"] = old_seed
    return ans
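
# A minimal usage sketch (hypothetical config keys, not from this module):
#   cfg = {"seed": 3, "lr": 1e-4}
#   h = get_hash_without_seed(cfg)
#   cfg["seed"] = 99
#   assert get_hash_without_seed(cfg) == h  # the seed does not affect the hash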


def get_date() -> str:
    return datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")


def generate_params_from_config(config):
    if config.get("env_size_type", "predefined") == "custom":
        # Must load the env params from a file.
        _, static_env_params, env_params = load_from_json_file(os.path.join("worlds", config["custom_path"]))
        return env_params, static_env_params.replace(
            frame_skip=config["frame_skip"],
        )

    env_params = EnvParams()
    static_env_params = StaticEnvParams().replace(
        num_polygons=config["num_polygons"],
        num_circles=config["num_circles"],
        num_joints=config["num_joints"],
        num_thrusters=config["num_thrusters"],
        frame_skip=config["frame_skip"],
        num_motor_bindings=config["num_motor_bindings"],
        num_thruster_bindings=config["num_thruster_bindings"],
    )
    return env_params, static_env_params
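
# A hedged usage sketch (all counts below are illustrative, not defaults):
#   env_params, static_env_params = generate_params_from_config({
#       "num_polygons": 12, "num_circles": 4, "num_joints": 6,
#       "num_thrusters": 2, "frame_skip": 1,
#       "num_motor_bindings": 4, "num_thruster_bindings": 2,
#   })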


def generate_ued_params_from_config(config) -> UEDParams:
    ans = UEDParams()
    if config["env_size_name"] == "s":
        ans = ans.replace(add_shape_n_proposals=1)  # otherwise we get a very weird XLA bug
    if "fixate_chance_max" in config:
        print("Changing fixate chance max to", config["fixate_chance_max"])
        ans = ans.replace(fixate_chance_max=config["fixate_chance_max"])
    return ans


def get_eval_level_groups(eval_levels: List[str]) -> Dict[str, np.ndarray]:
    def get_groups(s):
        group_one = s.split("/")[0]  # the size group, e.g. "s", "m" or "l"
        group_two = s.split("/")[1].split("_")[0]
        group_two = "".join([i for i in group_two if not i.isdigit()])
        if group_two == "h":
            group_two = "handmade"
        if group_two == "r":
            group_two = "random"
        return f"{group_one}_all", f"{group_one}_{group_two}"

    indices = defaultdict(list)
    for idx, s in enumerate(eval_levels):
        for group in get_groups(s):
            indices[group].append(idx)

    return {g: np.array(idxs) for g, idxs in indices.items()}
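
# A hedged example (the level names are illustrative):
#   get_eval_level_groups(["s/h0_angrybirds", "s/r3_maze", "m/h1_car"])
# would give index groups along the lines of
#   {"s_all": [0, 1], "s_handmade": [0], "s_random": [1],
#    "m_all": [2], "m_handmade": [2]}
# i.e. one "<size>_all" group per size prefix plus per-type subgroups.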


def normalise_config(config, name, editor_config=False):
    old_config = copy.deepcopy(config)
    # Flatten the nested config sections into a single top-level dict.
    keys = ["env", "learning", "model", "misc", "eval", "ued", "env_size", "train_levels"]
    for k in keys:
        if k not in config:
            config[k] = {}
        small_d = config[k]
        del config[k]
        for kk, vv in small_d.items():
            assert kk not in config, kk
            config[kk] = vv

    if not editor_config:
        config["eval_env_size_true"] = config["eval_env_size"]
        if config["num_train_envs"] == 2048 and "Pixels" in config["env_name"]:
            config["num_train_envs"] = 512
        if "SFL" in name and config["env_size_name"] in ["m", "l"]:
            config["eval_num_attempts"] = 6  # to avoid a very weird XLA bug

        config["hash"] = get_hash_without_seed(config)
        config["random_hash"] = np.random.randint(2**31)
        config["log_save_path"] = f"logs/{config['hash']}/{config['seed']}-{get_date()}"
        os.makedirs(config["log_save_path"], exist_ok=True)
        with open(f"{config['log_save_path']}/config.yaml", "w") as f:
            f.write(OmegaConf.to_yaml(old_config))

        if config["group"] == "auto":
            config["group"] = f"{name}-" + config["group_auto_prefix"] + config["env_name"].replace("Kinetix-", "")
            config["group"] += "-" + str(config["env_size_name"])
        if config["eval_levels"] == ["auto"] or config["eval_levels"] == "auto":
            config["eval_levels"] = config["train_levels_list"]
            print("Using auto eval levels:", config["eval_levels"])
        config["num_eval_levels"] = len(config["eval_levels"])

        # Environment steps consumed per update (PAIRED steps two agents).
        steps = (
            config["num_steps"]
            * config.get("outer_rollout_steps", 1)
            * config["num_train_envs"]
            * (2 if name == "PAIRED" else 1)
        )
        config["num_updates"] = int(config["total_timesteps"]) // steps

        # Human-readable run name: <env>-<algo>-<steps in M or B>-<num envs>.
        nsteps = int(config["total_timesteps"] // 1e6)
        letter = "M"
        if nsteps >= 1000:
            nsteps = nsteps // 1000
            letter = "B"
        config["run_name"] = (
            config["env_name"] + f"-{name}-" + str(nsteps) + letter + "-" + str(config["num_train_envs"])
        )
        if config["checkpoint_save_freq"] >= config["num_updates"]:
            config["checkpoint_save_freq"] = config["num_updates"]
    return config
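
# A hedged sketch of the flattening step above (keys are illustrative): a
# nested config such as
#   {"env": {"env_name": "Kinetix-..."}, "learning": {"lr": 3e-4}}
# is hoisted into a flat {"env_name": "Kinetix-...", "lr": 3e-4}; the assert
# guards against the same key appearing in two sections.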


def get_tags(config, name):
    tags = [name]
    if name in ["PLR", "ACCEL", "DR"]:
        if config["use_accel"]:
            tags.append("ACCEL")
        else:
            tags.append("PLR")
    return tags
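
# A quick sketch of the tagging behaviour:
#   get_tags({"use_accel": True}, "PLR") == ["PLR", "ACCEL"]
#   get_tags({}, "SFL") == ["SFL"]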


def init_wandb(config, name) -> wandb.run:
    run = wandb.init(
        config=config,
        project=config["wandb_project"],
        group=config["group"],
        name=config["run_name"],
        entity=config["wandb_entity"],
        mode=config["wandb_mode"],
        tags=get_tags(config, name),
    )
    # Plot every metric against environment steps rather than wandb's
    # default step counter.
    wandb.define_metric("timing/num_updates")
    wandb.define_metric("timing/num_env_steps")
    wandb.define_metric("*", step_metric="timing/num_env_steps")
    wandb.define_metric("timing/sps", step_metric="timing/num_env_steps")
    return run


def save_data_to_local_file(data_to_save, config):
    if not config.get("save_local_data", False):
        return

    def reverse_in(li, value):
        # True if any element of `li` occurs as a substring of `value`.
        return any(v in value for v in li)

    clean_data = {k: v for k, v in data_to_save.items() if not reverse_in(["media/", "images/"], k)}

    def _clean(x):
        # Convert JAX/numpy leaves to JSON-serialisable Python values.
        if isinstance(x, jnp.ndarray):
            return x.tolist()
        elif isinstance(x, jnp.float32):
            if jnp.isnan(x):
                return -float("inf")
            return round(float(x) * 1000) / 1000  # keep three decimal places
        elif isinstance(x, jnp.int32):
            return int(x)
        return x

    clean_data = jax.tree_util.tree_map(_clean, clean_data)
    print("Saving this data:", clean_data)
    with open(f"{config['log_save_path']}/data.jsonl", "a+") as f:
        f.write(json.dumps(clean_data) + "\n")
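
# A hedged example of the cleaning pass (values are illustrative): a log dict
#   {"eval/return": jnp.float32(1.23456), "media/video": frames}
# is written out as {"eval/return": 1.235} -- keys containing "media/" or
# "images/" are dropped, float32 scalars are rounded to three decimals
# (NaNs become -inf), and arrays are serialised via .tolist().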


def compress_log_files_after_run(config):
    fpath = f"{config['log_save_path']}/data.jsonl"
    with open(fpath, "rb") as f_in, gzip.open(fpath + ".gz", "wb") as f_out:
        f_out.writelines(f_in)


def get_video_frequency(config, update_step):
    # Log videos most often in the early-middle of training: the base
    # frequency applies between 10% and 30% of training, twice the base
    # between 30% and 60%, and four times the base otherwise.
    frac_through_training = update_step / config["num_updates"]
    vid_frequency = (
        config["eval_freq"]
        * config["video_frequency"]
        * jax.lax.select(
            (0.1 <= frac_through_training) & (frac_through_training < 0.3),
            1,
            jax.lax.select(
                (0.3 <= frac_through_training) & (frac_through_training < 0.6),
                2,
                4,
            ),
        )
    )
    return vid_frequency
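
# A hedged example of the resulting schedule, assuming eval_freq=64 and
# video_frequency=5 (illustrative values): videos every 320 updates between
# 10% and 30% of training, every 640 between 30% and 60%, and every 1280
# otherwise (including the first 10%).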