ThomasFfefefef's picture
Added work in progress generate video live
1c71ebe
import gradio as gr
import os
from moviepy.editor import *
def replay(option):
    """Return the path of a re-encoded replay video for the selected agent.

    Args:
        option: The dropdown label chosen in the UI. It must be one of the
            three labels listed in the lookup table below.

    Returns:
        The string ``'new_filename.mp4'``: the file (relative to the current
        working directory) that the selected recording was re-written to.

    Raises:
        ValueError: If ``option`` is not one of the known labels. Without
            this check an empty path would be handed to ``VideoFileClip``,
            which fails with an error unrelated to the real cause.
    """
    # Dropdown label -> pre-recorded video shipped next to this script.
    video_paths = {
        "LunarLander-v2 πŸš€πŸ‘©β€πŸš€": "./LunarLander-v2.mp4",
        "CartPole-v1 πŸ•ΉοΈ": "./CartPole-v1.mp4",
        "Atari Space Invaders πŸ‘Ύ": "./SpaceInvadersNoFrameskip-v4.mp4",
    }

    path = video_paths.get(option)
    if path is None:
        raise ValueError("Unknown agent option: {!r}".format(option))

    # Workaround kept from the original author (noted as a "Base64 video
    # pb"): the recording is re-encoded into a fresh file and that file is
    # returned instead of the original one.
    videoclip = VideoFileClip(path)
    try:
        videoclip.write_videofile("new_filename.mp4")
    finally:
        # Always release the clip's reader, even if the re-encode fails,
        # so repeated requests do not accumulate open readers.
        videoclip.close()
    return 'new_filename.mp4'
# Labels shown in the dropdown, in display order. Each one is passed
# verbatim to replay() as its `option` argument.
_AGENT_CHOICES = [
    "Atari Space Invaders πŸ‘Ύ",
    "CartPole-v1 πŸ•ΉοΈ",
    "LunarLander-v2 πŸš€πŸ‘©β€πŸš€",
]

# HTML rendered underneath the interface: a short explanation plus links to
# the three model repositories.
_ARTICLE_HTML = '''<div>
<p style="text-align: center">This version of the RL library allows you to load models directly from the Hugging Face Hub</p>
<p style="text-align: center"> Select the trained agent you want to watch perform.
These models are from <a href="https://github.com/araffin/rl-baselines-zoo">Stable Baseline Zoo</a></p>
<p>
There are currently 3 models:
<ul>
<li><a href="https://huggingface.co/ThomasSimonini/stable-baselines3-ppo-SpaceInvadersNoFrameskip-v4">PPO SpaceInvadersNoFrameskip-v4</a></li>
<li><a href="https://huggingface.co/ThomasSimonini/stable-baselines3-ppo-LunarLander-v2">PPO LunarLander-v2</a></li>
<li><a href="https://huggingface.co/ThomasSimonini/stable-baselines3-ppo-CartPole-v1">PPO CartPole-v1</a></li>
</ul>
</div>'''

# One dropdown in, one video out: the chosen label goes to replay(), and the
# file path it returns is shown in the video component.
iface = gr.Interface(
    fn=replay,
    inputs=[gr.inputs.Dropdown(_AGENT_CHOICES)],
    outputs="video",
    title='Stable Baselines 3 with πŸ€—',
    description='',
    article=_ARTICLE_HTML,
)

# Start serving the interface.
iface.launch()
# Disabled draft of the next version of this app. Everything between the
# triple quotes below is a bare string literal that is never assigned or
# called, so none of it executes. Per its own TODO line, the draft is meant
# to generate the video live: it loads a PPO model with
# PPO.load_from_huggingface, wraps the environment in Recorder, steps the
# agent until `done`, and returns whatever env.play() produces.
#
# NOTE(review), before enabling this draft:
#   - replay_atari loads the model twice with identical arguments and
#     imports gym twice.
#   - evaluate_policy is imported in both helpers but never used.
#   - the LunarLander branch still contains a print("TEST") and sets
#     env_name / agent_name without using them.
#   - it redefines replay and iface, so it is presumably meant to replace
#     the code above rather than run alongside it. TODO confirm.
"""
TODO: Next version with live video generation
import gradio as gr
import os
from Recorder import Recorder
from stable_baselines3 import PPO
#The Agent plays and we generate the video
def replay(option):
video_path = ""
# Get the correct model
if (option == "LunarLander-v2 πŸš€πŸ‘©β€πŸš€"):
env_name = "Lunar Lander v2"
agent_name = "PPO"
print("TEST")
hf_model_filename = "LunarLander-v2"
hf_model_id = "ThomasSimonini/stable-baselines3-ppo-LunarLander-v2"
video_path = replay_gym(hf_model_filename, hf_model_id)
elif(option == "CartPole-v1 πŸ•ΉοΈ"):
hf_model_filename = "CartPole-v1"
hf_model_id = "ThomasSimonini/stable-baselines3-ppo-CartPole-v1"
video_path = replay_gym(hf_model_filename, hf_model_id)
elif(option == "Atari Space Invaders πŸ‘Ύ"):
hf_model_filename = "SpaceInvadersNoFrameskip-v4"
hf_model_id = "ThomasSimonini/stable-baselines3-ppo-SpaceInvadersNoFrameskip-v4"
video_path = replay_atari(hf_model_filename, hf_model_id)
#video_path = "./SpaceInvadersNoFrameskip-v4.mp4"
return video_path
def replay_gym(hf_model_filename, hf_model_id):
import gym
from stable_baselines3.common.evaluation import evaluate_policy
model = PPO.load_from_huggingface(hf_model_id,hf_model_filename)
eval_env = gym.make(hf_model_filename)
directory = './video'
env = Recorder(eval_env, directory)
obs = env.reset()
done = False
while not done:
action, _state = model.predict(obs)
obs, reward, done, info = env.step(action)
clip = env.play()
return clip
def replay_atari(hf_model_filename, hf_model_id):
os.system("python -m atari_py.import_roms \"content/atari_roms\"")
import gym
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3.common.evaluation import evaluate_policy
model = PPO.load_from_huggingface(hf_model_id, hf_model_filename)
eval_env = make_atari_env(hf_model_filename, n_envs=1, seed=0)
eval_env = VecFrameStack(eval_env, n_stack=4)
model = PPO.load_from_huggingface(hf_model_id, hf_model_filename)
import gym
directory = './video'
env = Recorder(eval_env, directory)
obs = env.reset()
done = False
while not done:
action, _state = model.predict(obs)
obs, reward, done, info = env.step(action)
clip = env.play()
return clip
iface = gr.Interface(
replay,
[
gr.inputs.Dropdown(["Atari Space Invaders πŸ‘Ύ", "CartPole-v1 πŸ•ΉοΈ", "LunarLander-v2 πŸš€πŸ‘©β€πŸš€"]),
],
"video",
title = 'Stable Baselines 3 with πŸ€—',
description = '',
article =
'''<div>
<p style="text-align: center">This version of the RL library allows you to load models directly from the Hugging Face Hub</p>
<p style="text-align: center"> Select the trained agent you want to watch perform. We record your agent playing.
<p style="text-align: center"> Don't forget to <b>click on clear between each record.</b> </p>
These models are from <a href="https://github.com/araffin/rl-baselines-zoo">Stable Baseline Zoo</a></p>
<p>
There are currently 3 models:
<ul>
<li><a href="https://huggingface.co/ThomasSimonini/stable-baselines3-ppo-SpaceInvadersNoFrameskip-v4">PPO SpaceInvadersNoFrameskip-v4</a></li>
<li><a href="https://huggingface.co/ThomasSimonini/stable-baselines3-ppo-LunarLander-v2">PPO LunarLander-v2</a></li>
<li><a href="https://huggingface.co/ThomasSimonini/stable-baselines3-ppo-CartPole-v1">PPO CartPole-v1</a></li>
</ul>
</div>'''
)
iface.launch()
"""