Spaces:
Runtime error
Runtime error
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline | |
from datetime import datetime as dt | |
import streamlit as st | |
from streamlit_tags import st_tags | |
import beam_search | |
import top_sampling | |
from pprint import pprint | |
import json | |
# Runtime configuration read by the rest of the app: the model id
# (`model_name_or_path`) and the ingredient choices (`first_100`) offered in
# the multiselect. Decode explicitly as UTF-8 so ingredient names do not
# depend on the host's locale encoding.
with open("config.json", encoding="utf-8") as f:
    cfg = json.load(f)

# Streamlit requires page configuration before any other `st.*` output call.
st.set_page_config(layout="wide")
# Streamlit re-executes this whole script on every widget interaction, so an
# uncached loader would reload the tokenizer and model weights on every click.
# Cache the result once per server process. `st.cache_resource` is the modern
# API; older Streamlit releases only provide `st.cache`, where
# `allow_output_mutation=True` skips hashing the returned (unhashable) model.
if hasattr(st, "cache_resource"):
    _cache_model = st.cache_resource(show_spinner=False)
else:
    _cache_model = st.cache(allow_output_mutation=True, show_spinner=False)


@_cache_model
def load_model():
    """Build the text2text-generation pipeline for ``cfg["model_name_or_path"]``.

    Returns:
        A ``(generator, tokenizer)`` tuple. ``generator`` is the transformers
        ``pipeline`` wrapping the seq2seq model, and ``tokenizer`` is the same
        tokenizer the pipeline uses (needed later to decode generated tensors).
    """
    tokenizer = AutoTokenizer.from_pretrained(cfg["model_name_or_path"])
    model = AutoModelForSeq2SeqLM.from_pretrained(cfg["model_name_or_path"])
    generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
    return generator, tokenizer
def sampling_changed(obj):
    """Echo *obj* to stdout.

    Intended as a debugging callback for the sampling-mode widget; nothing in
    this script currently registers it.
    """
    print(obj)
# Show a spinner while the pipeline and its tokenizer are obtained, then
# unpack the pair into the module-level names used by the generate step.
with st.spinner('Loading model...'):
    model_pair = load_model()
    generator, tokenizer = model_pair
# st.image("images/chef-transformer.png", width=400)

# Header emoji are written as escapes: woman cook / man cook, each being
# person + ZERO WIDTH JOINER + cooking. The previous literal had been
# mis-decoded into mojibake ("π©βπ³ / π¨βπ³").
st.header("Chef Transformer \U0001F469\u200D\U0001F373 / \U0001F468\u200D\U0001F373")
st.markdown(
    "This demo uses [T5 trained on recipe-nlg](https://huggingface.co/flax-community/t5-recipe-generation) "
    "to generate recipe from a given set of ingredients"
)

# Sidebar: logo, placeholder recipe presets, and the decoding-mode selector.
# The widget return values are not needed, except the selected mode.
st.sidebar.image("images/chef-transformer-transparent.png", width=310)
st.sidebar.title("Popular recipes:")
st.sidebar.text("Recipe preset(example#1)")
st.sidebar.text("Recipe preset(example#2)")
st.sidebar.title("Mode:")
sampling_mode = st.sidebar.selectbox("select a Mode", index=0, options=["Top Sampling", "Beam Search"])
# Pre-selected ingredients for the demo. `st.multiselect` raises if any default
# is not one of its options, so only keep the defaults that the configured
# option list (`cfg["first_100"]`) actually contains.
_available_ingredients = set(cfg["first_100"])
_default_ingredients = [
    name
    for name in ["parmesan cheese", "fresh oregano", "basil", "whole wheat flour"]
    if name in _available_ingredients
]
original_keywords = st.multiselect(
    "Choose ingredients",
    cfg["first_100"],
    _default_ingredients,
)
# st.write("Add custom ingredients here:") | |
# custom_keywords = st_tags( | |
# label="", | |
# text='Press enter to add more', | |
# value=['salt'], | |
# suggestions=["z"], | |
# maxtags=15, | |
# key='1') | |
def custom_keywords_on_change():
    """Placeholder ``on_change`` hook for the custom-ingredients text box; does nothing."""
    return None
custom_keywords = st.text_input(
    'Add custom ingredients here (separated by `,`): ',
    ", ".join(["salt", "pepper"]),
    key="custom_keywords",
    on_change=custom_keywords_on_change,
    max_chars=1000)

# Split on commas, trim whitespace, and drop empty entries and duplicates.
# `dict.fromkeys` keeps the first-seen order. The previous `list(set(...))`
# produced an arbitrary order (str hashing is randomized per process), which
# reshuffled the ingredient prompt sent to the model between runs.
custom_keywords = list(dict.fromkeys(
    item
    for item in (part.strip() for part in custom_keywords.split(","))
    if item
))
# One comma-separated prompt string: the multiselect picks first, then the
# custom entries. Echo it back to the user before generation.
all_ingredients = ", ".join([*original_keywords, *custom_keywords])
st.markdown(f"**Generate recipe for:** {all_ingredients}")
submit = st.button('Get Recipe!')

if submit:
    with st.spinner('Generating recipe...'):
        # Pick the helper module that supplies `generate_kwargs` and
        # `post_generator` for the selected mode. The sidebar offers exactly
        # "Top Sampling" and "Beam Search". The old code compared against
        # "Top-k Sampling", so the default mode left `outputs` unbound and
        # crashed with NameError.
        if sampling_mode == "Beam Search":
            strategy = beam_search
        else:  # "Top Sampling"
            strategy = top_sampling
        generated = generator(all_ingredients, return_tensors=True, return_text=False,
                              **strategy.generate_kwargs)
        outputs = strategy.post_generator(generated, tokenizer)

    # Render the first generated recipe. The code reads its "title",
    # "ingredients" and "directions" entries; the last two are iterated as
    # sequences of strings.
    output = outputs[0]
    title = " ".join(w.capitalize() for w in output["title"].split())
    md_lines = [f"## {title}", "#### Ingredients:"]
    md_lines.extend(f"- {o}" for o in output["ingredients"])
    md_lines.append("#### Directions:")
    md_lines.extend(f"- {o}" for o in output["directions"])
    st.markdown("\n".join(md_lines) + "\n")