Marc-Alexandre Côté
Add autoscrolling
2d033b9
raw
history blame
6.11 kB
import streamlit as st
import streamlit.components.v1 as components
from textworld_express import TextWorldExpressEnv
# Project links, rendered as Markdown under the page title and again in the sidebar.
description = (
    "\n"
    "[ArXiv Paper](https://arxiv.org/abs/2208.01174)"
    " | "
    "[Github Repo](https://github.com/cognitiveailab/TextWorldExpress)"
    "\n"
)

# Main-page header.
st.title("TextWorldExpress Demo")
st.markdown(description)
# Apply custom CSS: inline the stylesheet into the page through a <style> tag.
# The encoding is pinned so decoding the file does not depend on the host's
# locale default.
with open('style.css', encoding='utf-8') as f:
    st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
# Values that must persist between script runs are kept in st.session_state.

# Game environment: built the first time through, fetched from the session afterwards.
env = st.session_state.get("env")
if env is None:
    st.session_state["env"] = env = TextWorldExpressEnv()

# Latest observation and info dict; both are None until a game has been started.
obs = st.session_state.get("obs")
infos = st.session_state.get("infos")

# Transcript of (action, observation) pairs shown on the page.
history = st.session_state.get("history")
if history is None:
    st.session_state["history"] = history = []
def clear_history():
    """Empty the shared transcript in place.

    The list object itself is kept (it is the one stored in the session
    state); only its contents are dropped. An empty transcript is what makes
    the script start a fresh game.
    """
    del history[:]
def _door_and_inventory_toggles():
    """Render the two checkboxes shared by every configurable game.

    Returns:
        The pair ``(with_doors, limited_inventory)`` as read from the widgets.
    """
    doors = st.checkbox("With doors?", value=True, on_change=clear_history,
                        help="Whether rooms have doors that need to be opened.")
    inventory = st.checkbox("Limit inventory?", value=False, on_change=clear_history,
                            help="Whether the size of the inventory is limited.")
    return doors, inventory


with st.sidebar:
    st.title("TextWorldExpress Demo")
    st.markdown(description)

    # Every widget below clears the transcript when changed, which causes a
    # new game to be generated with the updated configuration.
    game = st.selectbox("Game:", env.getGameNames(), on_change=clear_history)

    with st.expander("Settings"):
        seed = st.number_input("Seed:", 0, 2**16, value=4242, on_change=clear_history)

        # (parameter name, value) pairs for the selected game, in the order
        # they are serialized. Games without dedicated settings leave it empty.
        settings = []

        if game == "cookingworld":
            nb_ingredients = st.number_input(
                "Ingredients:", 1, 5, value=3, on_change=clear_history,
                help="The number of ingredients to use in generating the random recipe.")
            nb_locations = st.number_input(
                "Locations:", 1, 11, value=5, on_change=clear_history,
                help="The number of locations in the environment.")
            nb_distractors = st.number_input(
                "Distractors:", 0, 10, value=5, on_change=clear_history,
                help="The number of distractor ingredients (not used in the recipe) in the environment.")
            with_doors, limited_inventory = _door_and_inventory_toggles()
            settings = [
                ("numLocations", nb_locations),
                ("numIngredients", nb_ingredients),
                ("numDistractorItems", nb_distractors),
                ("includeDoors", int(with_doors)),
                ("limitInventorySize", int(limited_inventory)),
            ]
        elif game == "twc":
            nb_items = st.number_input(
                "Items:", 1, 10, value=3, on_change=clear_history,
                help="The number of items to put away.")
            nb_locations = st.number_input(
                "Locations:", 1, 3, value=3, on_change=clear_history,
                help="The number of locations in the environment.")
            with_doors, limited_inventory = _door_and_inventory_toggles()
            settings = [
                ("numLocations", nb_locations),
                ("numItemsToPutAway", nb_items),
                ("includeDoors", int(with_doors)),
                ("limitInventorySize", int(limited_inventory)),
            ]
        elif game == "coin":
            nb_locations = st.number_input(
                "Locations:", 1, 11, value=3, on_change=clear_history,
                help="The number of locations in the environment.")
            nb_distractors = st.number_input(
                "Distractors:", 0, 10, value=5, on_change=clear_history,
                help="The number of distractor (i.e. non-coin) items in the environment.")
            with_doors, limited_inventory = _door_and_inventory_toggles()
            settings = [
                ("numLocations", nb_locations),
                ("numDistractorItems", nb_distractors),
                ("includeDoors", int(with_doors)),
                ("limitInventorySize", int(limited_inventory)),
            ]

        # Comma-separated "name=value" list handed to env.reset as gameParams.
        params = ",".join(f"{key}={value}" for key, value in settings)
# An empty transcript means no game is in progress, so generate one from the
# current sidebar configuration.
if not history:
    obs, infos = env.reset(int(seed), gameFold="train", gameName=str(game), gameParams=params)

    # Take an initial "look around" step; its observation and info dict
    # replace the ones returned by reset.
    obs, reward, done, infos = env.step("look around")
    st.session_state["obs"], st.session_state["infos"] = obs, infos

    # Seed the transcript with the task description followed by the first look.
    history.extend([
        ("", env.getTaskDescription()),
        ("look around", obs),
    ])
def step():
    """Apply the action chosen in the "Next action:" selectbox.

    Used as the selectbox ``on_change`` callback; the chosen value is read
    from ``st.session_state.action``.

    * A blank choice does nothing.
    * ``"reset"`` is a UI-only pseudo action offered once the episode is over.
      It clears the transcript (so a new game is started) and is not forwarded
      to the environment.
    * Any other action is sent to the environment; the resulting observation
      and info dict are stored in the session state and the
      (action, observation) pair is appended to the transcript.
    """
    act = st.session_state.action
    if not act:
        return

    # Put the selectbox back on its blank entry. Otherwise it keeps showing the
    # action just taken, and picking that same action again would not count as
    # a change and would never trigger this callback.
    st.session_state.action = ""

    if act == "reset":
        # Handle the pseudo action before touching the environment, so that a
        # finished episode is not stepped with a command it does not know.
        clear_history()
        return

    obs, _reward, _done, infos = env.step(act)
    history.append((act, obs))
    st.session_state["obs"] = obs
    st.session_state["infos"] = infos
# Running score, shown in the sidebar.
with st.sidebar:
    st.success(f"Score: {infos['score']}")

# Choices for the action picker: a blank placeholder followed by the
# environment's valid actions in sorted order. Once the episode is done the
# only choice offered is "reset".
valid_actions = ["", *sorted(infos["validActions"])]
if infos["done"]:
    valid_actions = ["", "reset"]
# Replay the transcript: each action is echoed as a "> " prompt line and each
# observation is shown in an info box.
for past_action, past_obs in history:
    if past_action:
        st.write("> " + past_action)

    if past_obs:
        # Turn lines indented with a space or a tab into "- " list items.
        as_list = past_obs.replace('\n ', '\n- ')
        as_list = as_list.replace('\n\t', '\n- ')
        st.info(as_list)

# Action picker; selecting an entry runs the step callback.
act = st.selectbox('Next action:', options=valid_actions, index=0, on_change=step, key="action")

st.warning(f"Current score: {infos['score']} out of 1.0")
# Pick the animation, message box and text matching the episode outcome, if
# there is one yet.
if infos['tasksuccess']:
    outcome = (st.balloons, st.success, "Congratulations! You have completed the task.")
elif infos['taskfailure']:
    outcome = (st.snow, st.error, "You have failed the task.")
else:
    outcome = None

if outcome is not None:
    animate, announce, message = outcome
    # NOTE(review): the message is rendered in the sidebar together with the
    # animation; confirm that this is the intended placement.
    with st.sidebar:
        animate()
        announce(message)
# Auto-scroll the main pane to the bottom of the page.
# The snippet reaches through window.parent to find the app's main section,
# looks it up once, and does nothing if the selector does not match.
# NOTE(review): the latest observation is presumably embedded so that the
# markup differs after every step and the script runs again; confirm before
# removing it.
components.html(
    f"""
    <p>{st.session_state.obs}</p>
    <script>
        const main = window.parent.document.querySelector('section.main');
        if (main) {{
            main.scrollTo(0, main.scrollHeight);
        }}
    </script>
    """,
    height=0
)