Spaces:
Runtime error
Runtime error
import fnmatch | |
import os | |
from typing import Dict, SupportsFloat | |
import gymnasium as gym | |
import numpy as np | |
import torch | |
from gymnasium import wrappers | |
from huggingface_hub import HfApi | |
from huggingface_hub.utils._errors import EntryNotFoundError | |
from src.logging import setup_logger | |
# Module-level logger and Hub client shared by the functions below.
logger = setup_logger(__name__)
# The Hub token is read from the TOKEN environment variable (None when unset).
API = HfApi(token=os.environ.get("TOKEN"))

logger.info(f"Is CUDA available: {torch.cuda.is_available()}")
# Only query the device name when CUDA is present: torch.cuda.current_device()
# raises on CPU-only hosts, which would crash the module at import time.
if torch.cuda.is_available():
    logger.info(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
# Every environment ID that can be evaluated, grouped by suite.
# The concatenation order below is the order of the resulting list.
ALL_ENV_IDS = (
    # Atari (NoFrameskip variants)
    [
        "AdventureNoFrameskip-v4",
        "AirRaidNoFrameskip-v4",
        "AlienNoFrameskip-v4",
        "AmidarNoFrameskip-v4",
        "AssaultNoFrameskip-v4",
        "AsterixNoFrameskip-v4",
        "AsteroidsNoFrameskip-v4",
        "AtlantisNoFrameskip-v4",
        "BankHeistNoFrameskip-v4",
        "BattleZoneNoFrameskip-v4",
        "BeamRiderNoFrameskip-v4",
        "BerzerkNoFrameskip-v4",
        "BowlingNoFrameskip-v4",
        "BoxingNoFrameskip-v4",
        "BreakoutNoFrameskip-v4",
        "CarnivalNoFrameskip-v4",
        "CentipedeNoFrameskip-v4",
        "ChopperCommandNoFrameskip-v4",
        "CrazyClimberNoFrameskip-v4",
        "DefenderNoFrameskip-v4",
        "DemonAttackNoFrameskip-v4",
        "DoubleDunkNoFrameskip-v4",
        "ElevatorActionNoFrameskip-v4",
        "EnduroNoFrameskip-v4",
        "FishingDerbyNoFrameskip-v4",
        "FreewayNoFrameskip-v4",
        "FrostbiteNoFrameskip-v4",
        "GopherNoFrameskip-v4",
        "GravitarNoFrameskip-v4",
        "HeroNoFrameskip-v4",
        "IceHockeyNoFrameskip-v4",
        "JamesbondNoFrameskip-v4",
        "JourneyEscapeNoFrameskip-v4",
        "KangarooNoFrameskip-v4",
        "KrullNoFrameskip-v4",
        "KungFuMasterNoFrameskip-v4",
        "MontezumaRevengeNoFrameskip-v4",
        "MsPacmanNoFrameskip-v4",
        "NameThisGameNoFrameskip-v4",
        "PhoenixNoFrameskip-v4",
        "PitfallNoFrameskip-v4",
        "PongNoFrameskip-v4",
        "PooyanNoFrameskip-v4",
        "PrivateEyeNoFrameskip-v4",
        "QbertNoFrameskip-v4",
        "RiverraidNoFrameskip-v4",
        "RoadRunnerNoFrameskip-v4",
        "RobotankNoFrameskip-v4",
        "SeaquestNoFrameskip-v4",
        "SkiingNoFrameskip-v4",
        "SolarisNoFrameskip-v4",
        "SpaceInvadersNoFrameskip-v4",
        "StarGunnerNoFrameskip-v4",
        "TennisNoFrameskip-v4",
        "TimePilotNoFrameskip-v4",
        "TutankhamNoFrameskip-v4",
        "UpNDownNoFrameskip-v4",
        "VentureNoFrameskip-v4",
        "VideoPinballNoFrameskip-v4",
        "WizardOfWorNoFrameskip-v4",
        "YarsRevengeNoFrameskip-v4",
        "ZaxxonNoFrameskip-v4",
    ]
    # Box2D
    + [
        "BipedalWalker-v3",
        "BipedalWalkerHardcore-v3",
        "CarRacing-v2",
        "LunarLander-v2",
        "LunarLanderContinuous-v2",
    ]
    # Toy text
    + [
        "Blackjack-v1",
        "CliffWalking-v0",
        "FrozenLake-v1",
        "FrozenLake8x8-v1",
    ]
    # Classic control
    + [
        "Acrobot-v1",
        "CartPole-v1",
        "MountainCar-v0",
        "MountainCarContinuous-v0",
        "Pendulum-v1",
    ]
    # MuJoCo
    + [
        "Ant-v4",
        "HalfCheetah-v4",
        "Hopper-v4",
        "Humanoid-v4",
        "HumanoidStandup-v4",
        "InvertedDoublePendulum-v4",
        "InvertedPendulum-v4",
        "Pusher-v4",
        "Reacher-v4",
        "Swimmer-v4",
        "Walker2d-v4",
    ]
)
class NoopResetEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
    """
    Randomise the initial state by playing a random number of no-ops after each reset.

    The no-op is taken to be action 0; the constructor asserts that the
    wrapped environment reports "NOOP" as the meaning of that action.

    :param env: Environment to wrap
    :param noop_max: Largest number of no-ops that may be played (inclusive)
    """

    def __init__(self, env: gym.Env, noop_max: int = 30) -> None:
        super().__init__(env)
        self.noop_max = noop_max
        # When set, this fixed count is used instead of a random draw.
        self.override_num_noops = None
        self.noop_action = 0
        assert env.unwrapped.get_action_meanings()[0] == "NOOP"  # type: ignore[attr-defined]

    def reset(self, **kwargs):
        """
        Reset the wrapped environment, then step it with the no-op action.

        :param kwargs: Extra keywords passed to env.reset() call
        :return: observation and info produced by the last no-op step
            (or by the reset that followed an episode end)
        """
        self.env.reset(**kwargs)

        if self.override_num_noops is None:
            # Uniform draw from 1..noop_max (upper bound of integers() is exclusive).
            num_noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
        else:
            num_noops = self.override_num_noops
        assert num_noops > 0

        # Placeholders; always overwritten because num_noops is at least 1.
        observation = np.zeros(0)
        info: Dict = {}
        for _step in range(num_noops):
            observation, _reward, terminated, truncated, info = self.env.step(self.noop_action)
            episode_over = terminated or truncated
            if episode_over:
                observation, info = self.env.reset(**kwargs)
        return observation, info
class FireResetEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
    """
    Take action on reset for environments that are fixed until firing.

    The constructor asserts that action 1 means "FIRE" and that at least
    three actions exist, since ``reset`` plays actions 1 and 2.

    :param env: Environment to wrap
    """

    def __init__(self, env: gym.Env) -> None:
        super().__init__(env)
        assert env.unwrapped.get_action_meanings()[1] == "FIRE"  # type: ignore[attr-defined]
        assert len(env.unwrapped.get_action_meanings()) >= 3  # type: ignore[attr-defined]

    def reset(self, **kwargs):
        """
        Reset the environment, then play action 1 (FIRE) followed by action 2.

        If either step ends the episode, the environment is reset again and
        the observation/info of that reset are used, so the caller never
        receives a stale terminal observation.

        :param kwargs: Extra keywords passed to env.reset() call
        :return: observation and info from the last step or reset performed
        """
        obs, info = self.env.reset(**kwargs)
        for action in (1, 2):
            obs, _, terminated, truncated, info = self.env.step(action)
            if terminated or truncated:
                # Keep the fresh observation/info instead of discarding them.
                obs, info = self.env.reset(**kwargs)
        return obs, info
class EpisodicLifeEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
    """
    Treat the loss of a life as the end of an episode, while only truly
    resetting the game once it is actually over.

    Done by DeepMind for the DQN and co. since it helps value estimation.

    :param env: Environment to wrap
    """

    def __init__(self, env: gym.Env) -> None:
        super().__init__(env)
        # Life count seen after the most recent step/reset.
        self.lives = 0
        # Whether the wrapped env itself reported the end of the episode.
        self.was_real_done = True

    def step(self, action: int):
        """
        Step the environment and flag a lost life as termination.

        :param action: the action
        :return: observation, reward, terminated, truncated, information
        """
        obs, reward, terminated, truncated, info = self.env.step(action)
        self.was_real_done = terminated or truncated

        # Compare against the stored count first, then store the new count
        # so that bonus lives are tracked as well.
        current_lives = self.env.unwrapped.ale.lives()  # type: ignore[attr-defined]
        lost_a_life = 0 < current_lives < self.lives
        if lost_a_life:
            # For Qbert the game sometimes stays at lives == 0 for a few
            # frames, so lives > 0 is required here: the real reset then only
            # happens once the environment itself advertises done.
            terminated = True
        self.lives = current_lives
        return obs, reward, terminated, truncated, info

    def reset(self, **kwargs):
        """
        Reset the underlying environment only when the game is really over.

        After a mere life loss, a single no-op step moves the game on instead,
        so all states stay reachable even though lives are episodic.

        :param kwargs: Extra keywords passed to env.reset() call
        :return: the first observation of the environment and its info
        """
        needs_full_reset = self.was_real_done
        if not needs_full_reset:
            # no-op step to advance from terminal/lost life state
            obs, _, terminated, truncated, info = self.env.step(0)
            # That no-op can itself finish the game; a real reset is then
            # required to avoid the monitor.py
            # `RuntimeError: Tried to step environment that needs reset`
            needs_full_reset = terminated or truncated
        if needs_full_reset:
            obs, info = self.env.reset(**kwargs)
        self.lives = self.env.unwrapped.ale.lives()  # type: ignore[attr-defined]
        return obs, info
class MaxAndSkipEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
    """
    Frame-skipping wrapper: repeat each action ``skip`` times and return the
    element-wise maximum of the two most recently buffered frames.

    :param env: Environment to wrap
    :param skip: Number of times the same action is repeated per step
    """

    def __init__(self, env: gym.Env, skip: int = 4) -> None:
        super().__init__(env)
        obs_space = env.observation_space
        assert obs_space.dtype is not None, "No dtype specified for the observation space"
        assert obs_space.shape is not None, "No shape defined for the observation space"
        # Two-slot buffer of raw observations used for max pooling across time steps.
        self._obs_buffer = np.zeros((2, *obs_space.shape), dtype=obs_space.dtype)
        self._skip = skip

    def step(self, action: int):
        """
        Repeat ``action``, accumulate the reward and max-pool the buffered frames.

        :param action: the action
        :return: observation, reward, terminated, truncated, information
        """
        reward_sum = 0.0
        terminated = truncated = False
        penultimate, final = self._skip - 2, self._skip - 1
        for frame_idx in range(self._skip):
            obs, reward, terminated, truncated, info = self.env.step(action)
            # Only the last two frames of a full skip window are buffered.
            if frame_idx == penultimate:
                self._obs_buffer[0] = obs
            elif frame_idx == final:
                self._obs_buffer[1] = obs
            reward_sum += float(reward)
            if terminated or truncated:
                break
        # Note that the observation on the done=True frame doesn't matter.
        pooled_frame = self._obs_buffer.max(axis=0)
        return pooled_frame, reward_sum, terminated, truncated, info
class ClipRewardEnv(gym.RewardWrapper):
    """
    Reward wrapper that replaces each reward by its sign: +1, 0 or -1.

    :param env: Environment to wrap
    """

    def __init__(self, env: gym.Env) -> None:
        super().__init__(env)

    def reward(self, reward: SupportsFloat) -> float:
        """
        Map the reward to its sign.

        :param reward: raw reward from the wrapped environment
        :return: ``np.sign`` of the reward converted to float
        """
        as_float = float(reward)
        return np.sign(as_float)
def make(env_id):
    """
    Build a zero-argument factory that creates the wrapped environment ``env_id``.

    Every environment is wrapped in ``RecordEpisodeStatistics``. IDs containing
    "NoFrameskip" additionally receive the Atari preprocessing stack defined
    in this module plus resize to 84x84, grayscale and a 4-frame stack.

    :param env_id: Gymnasium environment ID
    :return: callable that creates the environment when invoked
    """

    def thunk():
        env = wrappers.RecordEpisodeStatistics(gym.make(env_id))
        if "NoFrameskip" not in env_id:
            return env

        env = NoopResetEnv(env, noop_max=30)
        env = MaxAndSkipEnv(env, skip=4)
        env = EpisodicLifeEnv(env)
        # Only games that expose a FIRE action get the fire-on-reset wrapper.
        action_meanings = env.unwrapped.get_action_meanings()
        if "FIRE" in action_meanings:
            env = FireResetEnv(env)
        env = ClipRewardEnv(env)
        env = wrappers.ResizeObservation(env, (84, 84))
        env = wrappers.GrayScaleObservation(env)
        return wrappers.FrameStack(env, 4)

    return thunk
def pattern_match(patterns, source_list):
    """
    Return the sorted, de-duplicated entries of ``source_list`` that match
    at least one of the given ``fnmatch``-style patterns.

    :param patterns: a single pattern string, an iterable of pattern strings,
        or None (treated as "no patterns")
    :param source_list: candidate strings to match against
    :return: sorted list of the unique matching entries
    """
    if patterns is None:
        # A missing pattern collection matches nothing instead of raising TypeError.
        return []
    if isinstance(patterns, str):
        patterns = [patterns]
    matches = set().union(*(fnmatch.filter(source_list, pattern) for pattern in patterns))
    return sorted(matches)
def evaluate(model_id, revision):
    """
    Evaluate the TorchScript agent ``agent.pt`` of a Hub model on every
    environment whose ID matches one of the model's tags.

    :param model_id: Hub repository ID of the model
    :param revision: revision of the repository to inspect, download and evaluate
    :return: dict mapping each environment ID to ``{"episodic_returns": [...]}``,
        or None when the agent file is missing, is not reported as safe,
        or cannot be loaded
    """
    tags = API.model_info(model_id, revision=revision).tags
    # Extract the environment IDs from the tags (usually only one).
    # `tags or []` guards against a model without any tags.
    env_ids = pattern_match(tags or [], ALL_ENV_IDS)
    logger.info(f"Selected environments: {env_ids}")
    results = {}

    # Check if the agent exists, at the revision that is being evaluated
    try:
        agent_path = API.hf_hub_download(repo_id=model_id, filename="agent.pt", revision=revision)
    except EntryNotFoundError:
        logger.error("Agent not found")
        return None

    # Check safety of the very same revision that was downloaded
    paths_info = API.get_paths_info(model_id, "agent.pt", expand=True, revision=revision)
    security = next(iter(paths_info)).security
    if security is None or "safe" not in security:
        logger.error("Agent safety not available")
        return None
    elif not security["safe"]:
        logger.error("Agent not safe")
        return None

    # Load the agent, falling back to the CPU when no GPU is present
    device = "cuda" if torch.cuda.is_available() else "cpu"
    try:
        agent = torch.jit.load(agent_path, map_location=device)
    except Exception as e:
        logger.error(f"Error loading agent: {e}")
        return None

    # Evaluate the agent on the environments
    for env_id in env_ids:
        envs = gym.vector.SyncVectorEnv([make(env_id) for _ in range(3)])
        try:
            observations, _ = envs.reset()
            episodic_returns = []
            # Collect at least 10 finished episodes across the 3 parallel envs.
            while len(episodic_returns) < 10:
                # Inference only: no autograd graph, so .numpy() cannot fail on grad-tracking outputs.
                with torch.no_grad():
                    actions = agent(torch.tensor(observations, device=device)).cpu().numpy()
                observations, _, _, _, infos = envs.step(actions)
                if "final_info" in infos:
                    for info in infos["final_info"]:
                        if info is None or "episode" not in info:
                            continue
                        episodic_returns.append(float(info["episode"]["r"]))
        finally:
            # Always release the environments, even when the agent raises.
            envs.close()
        results[env_id] = {"episodic_returns": episodic_returns}
        logger.info(f"Environment {env_id}: {np.mean(episodic_returns)} ± {np.std(episodic_returns)}")
    return results