# PPO playing QbertNoFrameskip-v4 from
# https://github.com/sgoodfriend/rl-algo-impls/tree/0511de345b17175b7cf1ea706c3e05981f11761c
# (commit 460072a)
from dataclasses import astuple
from typing import Optional

import gym
import numpy as np
from torch.utils.tensorboard.writer import SummaryWriter

from rl_algo_impls.runner.config import Config, EnvHyperparams
from rl_algo_impls.wrappers.action_mask_wrapper import MicrortsMaskWrapper
from rl_algo_impls.wrappers.episode_stats_writer import EpisodeStatsWriter
from rl_algo_impls.wrappers.hwc_to_chw_observation import HwcToChwObservation
from rl_algo_impls.wrappers.is_vector_env import IsVectorEnv
from rl_algo_impls.wrappers.microrts_stats_recorder import MicrortsStatsRecorder
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv
def make_microrts_env(
    config: Config,
    hparams: EnvHyperparams,
    training: bool = True,
    render: bool = False,
    normalize_load_path: Optional[str] = None,
    tb_writer: Optional[SummaryWriter] = None,
) -> VecEnv:
    """Build a wrapped, vectorized gym-microrts environment.

    Args:
        config: Run configuration; supplies the seed, ``algo_hyperparams``
            (``gamma``) and ``additional_keys_to_log``.
        hparams: Environment hyperparameters; only ``n_envs``, ``make_kwargs``,
            ``rolling_length`` and ``bots`` are used by this factory.
        training: When True, wraps the env in an ``EpisodeStatsWriter``
            (``tb_writer`` is then required).
        render: Unused here — kept for interface parity with the other
            ``make_*_env`` factories.
        normalize_load_path: Unused here — kept for interface parity.
        tb_writer: TensorBoard writer for episode stats; required when
            ``training`` is True.

    Returns:
        The fully wrapped vectorized environment.
    """
    # Deferred imports: gym_microrts is an optional dependency only needed
    # for this env family. The bare import is presumably kept for its
    # registration side effects — TODO confirm.
    import gym_microrts  # noqa: F401
    from gym_microrts import microrts_ai

    from rl_algo_impls.shared.vec_env.microrts_compat import (
        MicroRTSGridModeVecEnvCompat,
    )

    # EnvHyperparams is a (named) tuple-like dataclass; positions not used
    # by MicroRTS are discarded.
    (
        _,  # env_type
        n_envs,
        _,  # frame_stack
        make_kwargs,
        _,  # no_reward_timeout_steps
        _,  # no_reward_fire_steps
        _,  # vec_env_class
        _,  # normalize
        _,  # normalize_kwargs
        rolling_length,
        _,  # train_record_video
        _,  # video_step_interval
        _,  # initial_steps_to_truncate
        _,  # clip_atari_rewards
        _,  # normalize_type
        _,  # mask_actions
        bots,
    ) = astuple(hparams)

    seed = config.seed(training=training)
    make_kwargs = make_kwargs or {}
    if "num_selfplay_envs" not in make_kwargs:
        make_kwargs["num_selfplay_envs"] = 0
    if "num_bot_envs" not in make_kwargs:
        # Every env that isn't self-play is played against a scripted bot.
        make_kwargs["num_bot_envs"] = n_envs - make_kwargs["num_selfplay_envs"]
    if "reward_weight" in make_kwargs:
        make_kwargs["reward_weight"] = np.array(make_kwargs["reward_weight"])

    num_bot_envs = make_kwargs["num_bot_envs"]
    if bots:
        # Expand the {ai_name: count} mapping into a flat list of bot AIs,
        # capped at the number of bot envs.
        ai2s = []
        for ai_name, n in bots.items():
            remaining = num_bot_envs - len(ai2s)
            if remaining <= 0:
                break
            # BUGFIX: getattr without a default raised AttributeError before
            # the assert could report its friendlier message.
            ai = getattr(microrts_ai, ai_name, None)
            assert ai, f"{ai_name} not in microrts_ai"
            ai2s.extend(ai for _ in range(min(n, remaining)))
    else:
        # BUGFIX: the original iterated `for _ in make_kwargs["num_bot_envs"]`,
        # which raises TypeError on an int; it must be range(...).
        ai2s = [microrts_ai.randomAI for _ in range(num_bot_envs)]
    make_kwargs["ai2s"] = ai2s

    envs = MicroRTSGridModeVecEnvCompat(**make_kwargs)
    envs = HwcToChwObservation(envs)  # channels-last -> channels-first obs
    envs = IsVectorEnv(envs)
    envs = MicrortsMaskWrapper(envs)

    if seed is not None:
        envs.action_space.seed(seed)
        envs.observation_space.seed(seed)

    envs = gym.wrappers.RecordEpisodeStatistics(envs)
    envs = MicrortsStatsRecorder(envs, config.algo_hyperparams.get("gamma", 0.99))
    if training:
        assert tb_writer
        envs = EpisodeStatsWriter(
            envs,
            tb_writer,
            training=training,
            rolling_length=rolling_length,
            additional_keys_to_log=config.additional_keys_to_log,
        )
    return envs