Spaces:
Runtime error
Runtime error
Quentin Gallouédec
committed on
Commit
·
1cbc1b7
1
Parent(s):
3f2add7
mujoco
Browse files
- app.py +9 -7
- requirements.txt +2 -0
app.py
CHANGED
@@ -67,7 +67,7 @@ def evaluate(model_id, revision):
|
|
67 |
return None
|
68 |
|
69 |
# Check safety
|
70 |
-
security = next(iter(API.
|
71 |
if security is None or "safe" not in security:
|
72 |
logger.error("Agent safety not available")
|
73 |
return None
|
@@ -100,7 +100,8 @@ def evaluate(model_id, revision):
|
|
100 |
episodic_rewards.append(episodic_reward)
|
101 |
|
102 |
mean_reward = np.mean(episodic_rewards)
|
103 |
-
|
|
|
104 |
return results
|
105 |
|
106 |
|
@@ -195,7 +196,7 @@ def get_leaderboard_df():
|
|
195 |
model_id = report["config"]["model_id"]
|
196 |
row = {"Agent": model_id, "Status": report["status"]}
|
197 |
if report["status"] == "DONE":
|
198 |
-
results = {env_id: result["
|
199 |
row.update(results)
|
200 |
data.append(row)
|
201 |
|
@@ -237,8 +238,7 @@ with gr.Blocks(js=dark_mode_gradio_js) as demo:
|
|
237 |
gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")
|
238 |
with gr.Tabs(elem_classes="tab-buttons") as tabs:
|
239 |
with gr.TabItem("🏅 Leaderboard", elem_id="llm-benchmark-tab-table", id=0):
|
240 |
-
|
241 |
-
hidden_df = gr.components.Dataframe(full_df, visible=False) # hidden dataframe
|
242 |
|
243 |
env_checkboxes = gr.components.CheckboxGroup(
|
244 |
label="Environments",
|
@@ -246,17 +246,19 @@ with gr.Blocks(js=dark_mode_gradio_js) as demo:
|
|
246 |
value=[ALL_ENV_IDS[0]],
|
247 |
interactive=True,
|
248 |
)
|
249 |
-
leaderboard = gr.components.Dataframe(select_column([ALL_ENV_IDS[0]],
|
250 |
|
251 |
# Events
|
252 |
env_checkboxes.change(select_column, [env_checkboxes, hidden_df], leaderboard)
|
|
|
|
|
253 |
|
254 |
with gr.TabItem("📝 About", elem_id="llm-benchmark-tab-table", id=2):
|
255 |
gr.Markdown(ABOUT_TEXT)
|
256 |
|
257 |
|
258 |
scheduler = BackgroundScheduler()
|
259 |
-
scheduler.add_job(func=backend_routine, trigger="interval", seconds=
|
260 |
scheduler.start()
|
261 |
|
262 |
|
|
|
67 |
return None
|
68 |
|
69 |
# Check safety
|
70 |
+
security = next(iter(API.get_paths_info(model_id, "agent.pt", expand=True))).security
|
71 |
if security is None or "safe" not in security:
|
72 |
logger.error("Agent safety not available")
|
73 |
return None
|
|
|
100 |
episodic_rewards.append(episodic_reward)
|
101 |
|
102 |
mean_reward = np.mean(episodic_rewards)
|
103 |
+
std_reward = np.std(episodic_rewards)
|
104 |
+
results[env_id] = {"episodic_return_mean": mean_reward, "episodic_reward_std": std_reward}
|
105 |
return results
|
106 |
|
107 |
|
|
|
196 |
model_id = report["config"]["model_id"]
|
197 |
row = {"Agent": model_id, "Status": report["status"]}
|
198 |
if report["status"] == "DONE":
|
199 |
+
results = {env_id: result["episodic_return_mean"] for env_id, result in report["results"].items()}
|
200 |
row.update(results)
|
201 |
data.append(row)
|
202 |
|
|
|
238 |
gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")
|
239 |
with gr.Tabs(elem_classes="tab-buttons") as tabs:
|
240 |
with gr.TabItem("🏅 Leaderboard", elem_id="llm-benchmark-tab-table", id=0):
|
241 |
+
hidden_df = gr.components.Dataframe(get_leaderboard_df, visible=False, every=60) # hidden dataframe
|
|
|
242 |
|
243 |
env_checkboxes = gr.components.CheckboxGroup(
|
244 |
label="Environments",
|
|
|
246 |
value=[ALL_ENV_IDS[0]],
|
247 |
interactive=True,
|
248 |
)
|
249 |
+
leaderboard = gr.components.Dataframe(select_column([ALL_ENV_IDS[0]], get_leaderboard_df()))
|
250 |
|
251 |
# Events
|
252 |
env_checkboxes.change(select_column, [env_checkboxes, hidden_df], leaderboard)
|
253 |
+
# Update hidden dataframe
|
254 |
+
# hidden_df.change(get_leaderboard_df, [], hidden_df, every=10)
|
255 |
|
256 |
with gr.TabItem("📝 About", elem_id="llm-benchmark-tab-table", id=2):
|
257 |
gr.Markdown(ABOUT_TEXT)
|
258 |
|
259 |
|
260 |
scheduler = BackgroundScheduler()
|
261 |
+
scheduler.add_job(func=backend_routine, trigger="interval", seconds=60)
|
262 |
scheduler.start()
|
263 |
|
264 |
|
requirements.txt
CHANGED
@@ -13,6 +13,8 @@ python-dateutil==2.8.2
|
|
13 |
requests==2.28.2
|
14 |
torch==2.2.2
|
15 |
tqdm==4.65.0
|
|
|
|
|
16 |
|
17 |
# Log Visualizer
|
18 |
BeautifulSoup4==4.12.2
|
|
|
13 |
requests==2.28.2
|
14 |
torch==2.2.2
|
15 |
tqdm==4.65.0
|
16 |
+
cython<3
|
17 |
+
free-mujoco-py
|
18 |
|
19 |
# Log Visualizer
|
20 |
BeautifulSoup4==4.12.2
|