caliex commited on
Commit
5e29726
·
1 Parent(s): 7e0870a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -21
app.py CHANGED
@@ -1,14 +1,18 @@
 
1
  import streamlit as st
 
 
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
4
 
5
-
6
  model_id = "Narrativaai/BioGPT-Large-finetuned-chatdoctor"
7
  tokenizer = AutoTokenizer.from_pretrained("microsoft/BioGPT-Large")
8
  model = AutoModelForCausalLM.from_pretrained(model_id)
9
 
10
 
11
- def answer_question(prompt, temperature=0.1, top_p=0.75, top_k=40, num_beams=2, **kwargs):
 
 
12
  inputs = tokenizer(prompt, return_tensors="pt")
13
  input_ids = inputs["input_ids"].to("cpu")
14
  attention_mask = inputs["attention_mask"].to("cpu")
@@ -30,31 +34,99 @@ def answer_question(prompt, temperature=0.1, top_p=0.75, top_k=40, num_beams=2,
30
  return output.split(" Response:")[1]
31
 
32
 
33
- st.set_page_config(page_title="Medical Chat Bot", page_icon=":ambulance:", layout="wide")
34
- st.title("Medical Chat Bot")
35
- st.caption("Talk your way to better health")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
- # with open("./sidebar.md", "r") as sidebar_file:
38
- # sidebar_content = sidebar_file.read()
39
 
40
- # with open("./styles.md", "r") as styles_file:
41
- # styles_content = styles_file.read()
42
 
43
- # # Display the DDL for the selected table
44
- # st.sidebar.markdown(sidebar_content)
45
 
46
- # st.write(styles_content, unsafe_allow_html=True)
 
 
47
 
48
-
49
- st.write("Please enter your question below:")
 
50
 
51
- # get user input
52
- user_input = st.text_input("You: ")
 
 
 
 
53
 
54
- if user_input:
55
- # generate response
56
- bot_response = answer_question(f"Input: {user_input}\nResponse:")
57
- st.write("")
58
- st.write("Bot:", bot_response)
59
 
 
 
 
 
 
 
 
 
 
60
 
 
 
 
 
 
 
 
 
1
+ import markdown
2
  import streamlit as st
3
+ from streamlit_chat import message
4
+ from streamlit_extras.colored_header import colored_header
5
  import torch
6
  from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
7
 
 
8
  model_id = "Narrativaai/BioGPT-Large-finetuned-chatdoctor"
9
  tokenizer = AutoTokenizer.from_pretrained("microsoft/BioGPT-Large")
10
  model = AutoModelForCausalLM.from_pretrained(model_id)
11
 
12
 
13
+ def answer_question(
14
+ prompt, temperature=0.1, top_p=0.75, top_k=40, num_beams=2, **kwargs
15
+ ):
16
  inputs = tokenizer(prompt, return_tensors="pt")
17
  input_ids = inputs["input_ids"].to("cpu")
18
  attention_mask = inputs["attention_mask"].to("cpu")
 
34
  return output.split(" Response:")[1]
35
 
36
 
37
+ st.set_page_config(page_title="Talk To Me", page_icon=":ambulance:", layout="wide")
38
+
39
+ colored_header(
40
+ label="Talk To Me",
41
+ description="Talk your way to better health",
42
+ color_name="violet-70",
43
+ )
44
+
45
+ # st.title("Talk To Me")
46
+ # st.caption("Talk your way to better health")
47
+
48
+ # add sidebar
49
+ with open("./sidebar.md", "r") as sidebar_file:
50
+ sidebar_content = sidebar_file.read()
51
+
52
+ with open("./styles.md", "r") as styles_file:
53
+ styles_content = styles_file.read()
54
+
55
+
56
+ def add_sbg_from_url():
57
+ st.markdown(
58
+ f"""
59
+ <style>
60
+ .css-6qob1r {{
61
+ background-image: url("https://images.unsplash.com/photo-1524169358666-79f22534bc6e?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=3540&q=80");
62
+ background-attachment: fixed;
63
+ background-size: cover
64
+ }}
65
+ </style>
66
+ """,
67
+ unsafe_allow_html=True,
68
+ )
69
+
70
+
71
+ add_sbg_from_url()
72
+
73
+
74
+ def add_mbg_from_url():
75
+ st.markdown(
76
+ f"""
77
+ <style>
78
+ .stApp {{
79
+ background-image: url("https://images.unsplash.com/photo-1536353602887-521e965eb03f?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=987&q=80");
80
+ background-attachment: fixed;
81
+ background-size: cover
82
+ }}
83
+ </style>
84
+ """,
85
+ unsafe_allow_html=True,
86
+ )
87
+
88
+
89
+ add_mbg_from_url()
90
 
 
 
91
 
92
+ # Display the sidebar content
93
+ st.sidebar.markdown(sidebar_content)
94
 
95
+ st.write(styles_content, unsafe_allow_html=True)
 
96
 
97
+ # Initialize session state
98
+ if "chat_history" not in st.session_state:
99
+ st.session_state.chat_history = []
100
 
101
+ # display default message if no chat history
102
+ if not st.session_state.chat_history:
103
+ message("Hi, I'm a medical chat bot. Ask me a question!")
104
 
105
+ # Display the chat history
106
+ for chat in st.session_state.chat_history:
107
+ if chat["is_user"]:
108
+ message(chat["message"], is_user=True)
109
+ else:
110
+ message(chat["message"])
111
 
112
+ with st.form("user_input_form"):
113
+ st.write("Please enter your question below:")
114
+ user_input = st.text_input("You: ")
 
 
115
 
116
+ # Check if user has submitted a question
117
+ if st.form_submit_button("Submit") and user_input:
118
+ with st.spinner('Loading model and generating response...'):
119
+ # Generate response and update chat history
120
+ bot_response = answer_question(f"Input: {user_input}\nResponse:")
121
+ st.session_state.chat_history.append({"message": user_input, "is_user": True})
122
+ st.session_state.chat_history.append(
123
+ {"message": bot_response, "is_user": False}
124
+ )
125
 
126
+ # Display the latest chat in the chat history
127
+ if st.session_state.chat_history:
128
+ latest_chat = st.session_state.chat_history[-1]
129
+ if latest_chat["is_user"]:
130
+ message(latest_chat["message"], is_user=True)
131
+ else:
132
+ message(latest_chat["message"])