sgoodfriend's picture
PPO playing QbertNoFrameskip-v4 from https://github.com/sgoodfriend/rl-algo-impls/tree/0511de345b17175b7cf1ea706c3e05981f11761c
460072a
raw
history blame
3.1 kB
from dataclasses import astuple
from typing import Optional
import gym
import numpy as np
from torch.utils.tensorboard.writer import SummaryWriter
from rl_algo_impls.runner.config import Config, EnvHyperparams
from rl_algo_impls.wrappers.action_mask_wrapper import MicrortsMaskWrapper
from rl_algo_impls.wrappers.episode_stats_writer import EpisodeStatsWriter
from rl_algo_impls.wrappers.hwc_to_chw_observation import HwcToChwObservation
from rl_algo_impls.wrappers.is_vector_env import IsVectorEnv
from rl_algo_impls.wrappers.microrts_stats_recorder import MicrortsStatsRecorder
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv
def _select_bot_ais(microrts_ai, bots: Optional[dict], num_bot_envs: int) -> list:
    """Pick the opponent AI for each bot env.

    Args:
        microrts_ai: the ``gym_microrts.microrts_ai`` module. It is passed in so
            the ``gym_microrts`` import stays local to ``make_microrts_env``.
        bots: optional mapping of ``microrts_ai`` attribute name -> number of
            envs that should use that AI. Entries are consumed in iteration
            order until ``num_bot_envs`` AIs have been chosen.
        num_bot_envs: maximum number of AIs to return.

    Returns:
        A list with at most ``num_bot_envs`` entries. When ``bots`` is empty or
        None, every bot env gets ``microrts_ai.randomAI``.

    Raises:
        ValueError: if a name in ``bots`` is needed but is not an attribute of
            ``microrts_ai``.
    """
    if not bots:
        # Previously iterated over the int itself (TypeError); range() is intended.
        return [microrts_ai.randomAI for _ in range(num_bot_envs)]

    ai2s = []
    for ai_name, n in bots.items():
        if len(ai2s) >= num_bot_envs:
            break
        take = min(n, num_bot_envs - len(ai2s))
        if take <= 0:
            # Zero/negative counts contribute nothing; skip without resolving the name.
            continue
        # Default of None so an unknown name produces the intended message
        # instead of a bare AttributeError.
        ai = getattr(microrts_ai, ai_name, None)
        if ai is None:
            raise ValueError(f"{ai_name} not in microrts_ai")
        ai2s.extend([ai] * take)
    # NOTE(review): if sum(bots.values()) < num_bot_envs this returns fewer AIs
    # than bot envs — confirm MicroRTSGridModeVecEnvCompat tolerates that.
    return ai2s


def make_microrts_env(
    config: Config,
    hparams: EnvHyperparams,
    training: bool = True,
    render: bool = False,
    normalize_load_path: Optional[str] = None,
    tb_writer: Optional[SummaryWriter] = None,
) -> VecEnv:
    """Build a wrapped gym-microrts vectorized environment.

    The base ``MicroRTSGridModeVecEnvCompat`` env is wrapped, in order, with
    ``HwcToChwObservation``, ``IsVectorEnv``, ``MicrortsMaskWrapper``,
    ``gym.wrappers.RecordEpisodeStatistics`` and ``MicrortsStatsRecorder``; when
    ``training`` is True, ``EpisodeStatsWriter`` is added last.

    Args:
        config: run configuration; supplies the seed, ``algo_hyperparams``
            (``gamma``, default 0.99) and ``additional_keys_to_log``.
        hparams: environment hyperparameters. Only ``n_envs``, ``make_kwargs``,
            ``rolling_length`` and ``bots`` are used here.
        training: selects the training seed and enables ``EpisodeStatsWriter``.
        render: accepted for interface compatibility; unused in this function.
        normalize_load_path: accepted for interface compatibility; unused here.
        tb_writer: TensorBoard writer; required when ``training`` is True.

    Returns:
        The wrapped vectorized environment.

    Raises:
        ValueError: if ``hparams.bots`` names an AI missing from ``microrts_ai``.
    """
    import gym_microrts  # noqa: F401  presumably imported for side effects — confirm
    from gym_microrts import microrts_ai

    from rl_algo_impls.shared.vec_env.microrts_compat import (
        MicroRTSGridModeVecEnvCompat,
    )

    # Positional unpacking mirrors the EnvHyperparams field order.
    # astuple() recursively copies dicts/lists, so mutating make_kwargs below
    # does not modify hparams.
    (
        _,  # env_type
        n_envs,
        _,  # frame_stack
        make_kwargs,
        _,  # no_reward_timeout_steps
        _,  # no_reward_fire_steps
        _,  # vec_env_class
        _,  # normalize
        _,  # normalize_kwargs,
        rolling_length,
        _,  # train_record_video
        _,  # video_step_interval
        _,  # initial_steps_to_truncate
        _,  # clip_atari_rewards
        _,  # normalize_type
        _,  # mask_actions
        bots,
    ) = astuple(hparams)

    seed = config.seed(training=training)

    make_kwargs = make_kwargs or {}
    make_kwargs.setdefault("num_selfplay_envs", 0)
    # Any env not reserved for self-play is played against a bot.
    if "num_bot_envs" not in make_kwargs:
        make_kwargs["num_bot_envs"] = n_envs - make_kwargs["num_selfplay_envs"]
    if "reward_weight" in make_kwargs:
        make_kwargs["reward_weight"] = np.array(make_kwargs["reward_weight"])
    make_kwargs["ai2s"] = _select_bot_ais(
        microrts_ai, bots, make_kwargs["num_bot_envs"]
    )

    envs = MicroRTSGridModeVecEnvCompat(**make_kwargs)
    envs = HwcToChwObservation(envs)
    envs = IsVectorEnv(envs)
    envs = MicrortsMaskWrapper(envs)

    if seed is not None:
        envs.action_space.seed(seed)
        envs.observation_space.seed(seed)

    envs = gym.wrappers.RecordEpisodeStatistics(envs)
    envs = MicrortsStatsRecorder(envs, config.algo_hyperparams.get("gamma", 0.99))
    if training:
        assert tb_writer
        envs = EpisodeStatsWriter(
            envs,
            tb_writer,
            training=training,
            rolling_length=rolling_length,
            additional_keys_to_log=config.additional_keys_to_log,
        )
    return envs