Update app.py
app.py CHANGED
@@ -1,14 +1,18 @@
+import markdown
 import streamlit as st
+from streamlit_chat import message
+from streamlit_extras.colored_header import colored_header
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
 
-
 model_id = "Narrativaai/BioGPT-Large-finetuned-chatdoctor"
 tokenizer = AutoTokenizer.from_pretrained("microsoft/BioGPT-Large")
 model = AutoModelForCausalLM.from_pretrained(model_id)
 
 
-def answer_question(prompt, temperature=0.1, top_p=0.75, top_k=40, num_beams=2, **kwargs):
+def answer_question(
+    prompt, temperature=0.1, top_p=0.75, top_k=40, num_beams=2, **kwargs
+):
     inputs = tokenizer(prompt, return_tensors="pt")
     input_ids = inputs["input_ids"].to("cpu")
     attention_mask = inputs["attention_mask"].to("cpu")
@@ -30,31 +34,99 @@ def answer_question(prompt, temperature=0.1, top_p=0.75, top_k=40, num_beams=2,
     return output.split(" Response:")[1]
 
 
-st.set_page_config(page_title="
-
-
+st.set_page_config(page_title="Talk To Me", page_icon=":ambulance:", layout="wide")
+
+colored_header(
+    label="Talk To Me",
+    description="Talk your way to better health",
+    color_name="violet-70",
+)
+
+# st.title("Talk To Me")
+# st.caption("Talk your way to better health")
+
+# add sidebar
+with open("./sidebar.md", "r") as sidebar_file:
+    sidebar_content = sidebar_file.read()
+
+with open("./styles.md", "r") as styles_file:
+    styles_content = styles_file.read()
+
+
+def add_sbg_from_url():
+    st.markdown(
+        f"""
+        <style>
+        .css-6qob1r {{
+        background-image: url("https://images.unsplash.com/photo-1524169358666-79f22534bc6e?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=3540&q=80");
+        background-attachment: fixed;
+        background-size: cover
+        }}
+        </style>
+        """,
+        unsafe_allow_html=True,
+    )
+
+
+add_sbg_from_url()
+
+
+def add_mbg_from_url():
+    st.markdown(
+        f"""
+        <style>
+        .stApp {{
+        background-image: url("https://images.unsplash.com/photo-1536353602887-521e965eb03f?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=987&q=80");
+        background-attachment: fixed;
+        background-size: cover
+        }}
+        </style>
+        """,
+        unsafe_allow_html=True,
+    )
+
+
+add_mbg_from_url()
 
-# with open("./sidebar.md", "r") as sidebar_file:
-#     sidebar_content = sidebar_file.read()
 
-#
-
+# Display the sidebar content
+st.sidebar.markdown(sidebar_content)
 
-
-# st.sidebar.markdown(sidebar_content)
+st.write(styles_content, unsafe_allow_html=True)
 
-#
+# Initialize session state
+if "chat_history" not in st.session_state:
+    st.session_state.chat_history = []
 
-
-st.
+# display default message if no chat history
+if not st.session_state.chat_history:
+    message("Hi, I'm a medical chat bot. Ask me a question!")
 
-#
-
+# Display the chat history
+for chat in st.session_state.chat_history:
+    if chat["is_user"]:
+        message(chat["message"], is_user=True)
+    else:
+        message(chat["message"])
 
-
-
-
-st.write("")
-st.write("Bot:", bot_response)
+with st.form("user_input_form"):
+    st.write("Please enter your question below:")
+    user_input = st.text_input("You: ")
 
+    # Check if user has submitted a question
+    if st.form_submit_button("Submit") and user_input:
+        with st.spinner('Loading model and generating response...'):
+            # Generate response and update chat history
+            bot_response = answer_question(f"Input: {user_input}\nResponse:")
+            st.session_state.chat_history.append({"message": user_input, "is_user": True})
+            st.session_state.chat_history.append(
+                {"message": bot_response, "is_user": False}
+            )
 
+# Display the latest chat in the chat history
+if st.session_state.chat_history:
+    latest_chat = st.session_state.chat_history[-1]
+    if latest_chat["is_user"]:
+        message(latest_chat["message"], is_user=True)
+    else:
+        message(latest_chat["message"])
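
Note: the hunk context hides the unchanged middle of answer_question (new lines 19-33). Going by the GenerationConfig import, the sampling arguments in the signature, and the split on " Response:" at line 34, that middle is presumably the usual generate-and-decode pattern. The sketch below is a reconstruction under those assumptions, not the file's actual contents; max_new_tokens and the decode step in particular are guesses.

# Hypothetical reconstruction of the full answer_question(); only the
# signature, the first three lines, and the final return appear in the diff.
def answer_question(
    prompt, temperature=0.1, top_p=0.75, top_k=40, num_beams=2, **kwargs
):
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to("cpu")
    attention_mask = inputs["attention_mask"].to("cpu")
    # -- elided in the diff; assumed to be the standard pattern: --
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_beams=num_beams,
        **kwargs,
    )
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            generation_config=generation_config,
            return_dict_in_generate=True,
            max_new_tokens=512,  # assumed cap, not visible in the hunk
        )
    output = tokenizer.decode(generation_output.sequences[0])
    # -- end of assumed section --
    return output.split(" Response:")[1]

The form handler calls this as answer_question(f"Input: {user_input}\nResponse:"), so splitting the decoded text on " Response:" discards the echoed prompt and returns only the model's generated answer.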