bharat-raghunathan commited on
Commit
373220b
·
verified ·
1 Parent(s): 9eda1ce

Added prompt templates to output correct diagnoses

Browse files
Files changed (1) hide show
  1. app.py +183 -170
app.py CHANGED
@@ -1,171 +1,184 @@
1
- import gradio as gr
2
- import os
3
- import requests
4
- from huggingface_hub import InferenceClient
5
- import google.generativeai as genai
6
- import openai
7
-
8
- def api_check_msg(api_key, selected_model):
9
- res = validate_api_key(api_key, selected_model)
10
- return res["message"]
11
-
12
- def validate_api_key(api_key, selected_model):
13
- # Check if the API key is valid for GPT-3.5-Turbo
14
- if "GPT" in selected_model:
15
- url = "https://api.openai.com/v1/models"
16
- headers = {
17
- "Authorization": f"Bearer {api_key}"
18
- }
19
- try:
20
- response = requests.get(url, headers=headers)
21
- if response.status_code == 200:
22
- return {"is_valid": True, "message": '<p style="color: green;">API Key is valid!</p>'}
23
- else:
24
- return {"is_valid": False, "message": f'<p style="color: red;">Invalid OpenAI API Key. Status code: {response.status_code}</p>'}
25
- except requests.exceptions.RequestException as e:
26
- return {"is_valid": False, "message": f'<p style="color: red;">Invalid OpenAI API Key. Error: {e}</p>'}
27
- elif "Llama" in selected_model:
28
- url = "https://huggingface.co/api/whoami-v2"
29
- headers = {
30
- "Authorization": f"Bearer {api_key}"
31
- }
32
- try:
33
- response = requests.get(url, headers=headers)
34
- if response.status_code == 200:
35
- return {"is_valid": True, "message": '<p style="color: green;">API Key is valid!</p>'}
36
- else:
37
- return {"is_valid": False, "message": f'<p style="color: red;">Invalid Hugging Face API Key. Status code: {response.status_code}</p>'}
38
- except requests.exceptions.RequestException as e:
39
- return {"is_valid": False, "message": f'<p style="color: red;">Invalid Hugging Face API Key. Error: {e}</p>'}
40
- elif "Gemini" in selected_model:
41
- try:
42
- genai.configure(api_key=api_key)
43
- model = genai.GenerativeModel("gemini-1.5-flash")
44
- response = model.generate_content("Help me diagnose the patient.")
45
- return {"is_valid": True, "message": '<p style="color: green;">API Key is valid!</p>'}
46
- except Exception as e:
47
- return {"is_valid": False, "message": f'<p style="color: red;">Invalid Google API Key. Error: {e}</p>'}
48
-
49
- def generate_text_chatgpt(key, prompt, temperature, top_p):
50
-
51
- openai.api_key = key
52
-
53
- response = openai.chat.completions.create(
54
- model="gpt-3.5-turbo-1106",
55
- messages=[{"role": "system", "content": "You are a talented diagnostician who is diagnosing a patient."},
56
- {"role": "user", "content": prompt}],
57
- temperature=temperature,
58
- max_tokens=50,
59
- top_p=top_p,
60
- frequency_penalty=0
61
- )
62
-
63
- return response.choices[0].message.content
64
-
65
-
66
- def generate_text_gemini(key, prompt, temperature, top_p):
67
- genai.configure(api_key=key)
68
-
69
- generation_config = genai.GenerationConfig(
70
- max_output_tokens=len(prompt)+50,
71
- temperature=temperature,
72
- top_p=top_p,
73
- )
74
- model = genai.GenerativeModel("gemini-1.5-flash", generation_config=generation_config)
75
- response = model.generate_content(prompt)
76
- return response.text
77
-
78
-
79
- def generate_text_llama(key, prompt, temperature, top_p):
80
- model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
81
- client = InferenceClient(api_key=key)
82
-
83
- messages = [{"role": "system", "content": "You are a talented diagnostician who is diagnosing a patient."},
84
- {"role": "user","content": prompt}]
85
-
86
- completion = client.chat.completions.create(
87
- model=model_name,
88
- messages=messages,
89
- max_tokens=len(prompt)+50,
90
- temperature=temperature,
91
- top_p=top_p
92
- )
93
-
94
- response = completion.choices[0].message.content
95
- if len(response) > len(prompt):
96
- return response[len(prompt):]
97
- return response
98
-
99
-
100
- def diagnose(key, model, top_k, temperature, symptom_prompt):
101
-
102
- model_map = {
103
- "GPT-3.5-Turbo": "GPT",
104
- "Llama-3": "Llama",
105
- "Gemini-1.5": "Gemini"
106
- }
107
- if symptom_prompt:
108
- if "GPT" in model:
109
- message = generate_text_chatgpt(key, symptom_prompt, temperature, top_k)
110
- elif "Llama" in model:
111
- message = generate_text_llama(key, symptom_prompt, temperature, top_k)
112
- elif "Gemini" in model:
113
- message = generate_text_gemini(key, symptom_prompt, temperature, top_k)
114
- else:
115
- message = "Incorrect model, please try again."
116
- else:
117
- message = "Please add the symptoms data"
118
-
119
- return message
120
-
121
- def update_model_components(selected_model):
122
- model_map = {
123
- "GPT-3.5-Turbo": "GPT",
124
- "Llama-3": "Llama",
125
- "Gemini-1.5": "Gemini"
126
- }
127
-
128
- link_map = {
129
- "GPT-3.5-Turbo": "https://platform.openai.com/account/api-keys",
130
- "Llama-3": "https://hf.co/settings/tokens",
131
- "Gemini-1.5": "https://aistudio.google.com/apikey"
132
- }
133
- textbox_label = f"Please input the API key for your {model_map[selected_model]} model"
134
- button_value = f"Don't have an API key? Get one for the {model_map[selected_model]} model here."
135
- button_link = link_map[selected_model]
136
- return gr.update(label=textbox_label), gr.update(value=button_value, link=button_link)
137
-
138
- def toggle_button(symptoms_text, api_key, model):
139
- if symptoms_text.strip() and validate_api_key(api_key, model):
140
- return gr.update(interactive=True)
141
- return gr.update(interactive=False)
142
-
143
-
144
- with gr.Blocks() as ui:
145
-
146
- with gr.Row(equal_height=500):
147
- with gr.Column(scale=1, min_width=300):
148
- model = gr.Radio(label="LLM Selection", value="GPT-3.5-Turbo",
149
- choices=["GPT-3.5-Turbo", "Llama-3", "Gemini-1.5"])
150
- is_valid = False
151
- key = gr.Textbox(label="Please input the API key for your Large Language model", type="password")
152
- status_message = gr.HTML(label="Validation Status")
153
- key.input(fn=api_check_msg, inputs=[key, model], outputs=status_message)
154
- button = gr.Button(value="Don't have an API key? Get one for the GPT model here.", link="https://platform.openai.com/account/api-keys")
155
- model.change(update_model_components, inputs=model, outputs=[key, button])
156
- gr.ClearButton(key, variant="primary")
157
-
158
- with gr.Column(scale=2, min_width=600):
159
- gr.Markdown("## Hello, Welcome to the GUI by Team #9.")
160
- temperature = gr.Slider(0.0, 1.0, value=0.7, step = 0.05, label="Temperature", info="Set the Temperature")
161
- top_p = gr.Slider(0.0, 1.0, value=0.9, step = 0.05, label="top-p value", info="Set the sampling nucleus parameter")
162
- symptoms = gr.Textbox(label="Add the symptom data in the input to receive diagnosis")
163
- llm_btn = gr.Button(value="Diagnose Disease", variant="primary", elem_id="diagnose", interactive=False)
164
- symptoms.input(toggle_button, inputs=[symptoms, key, model], outputs=llm_btn)
165
- key.input(toggle_button, inputs=[symptoms, key, model], outputs=llm_btn)
166
- model.change(toggle_button, inputs=[symptoms, key, model], outputs=llm_btn)
167
- output = gr.Textbox(label="LLM Output Status", interactive=False, placeholder="Output will appear here...")
168
- llm_btn.click(fn=diagnose, inputs=[key, model, top_p, temperature, symptoms], outputs=output, api_name="LLM_Comparator")
169
-
170
-
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  ui.launch(share=True)
 
1
+ import gradio as gr
2
+ import os
3
+ import requests
4
+ from huggingface_hub import InferenceClient
5
+ import google.generativeai as genai
6
+ import openai
7
+
8
+ def api_check_msg(api_key, selected_model):
9
+ res = validate_api_key(api_key, selected_model)
10
+ return res["message"]
11
+
12
+ def validate_api_key(api_key, selected_model):
13
+ # Check if the API key is valid for GPT-3.5-Turbo
14
+ if "GPT" in selected_model:
15
+ url = "https://api.openai.com/v1/models"
16
+ headers = {
17
+ "Authorization": f"Bearer {api_key}"
18
+ }
19
+ try:
20
+ response = requests.get(url, headers=headers)
21
+ if response.status_code == 200:
22
+ return {"is_valid": True, "message": '<p style="color: green;">API Key is valid!</p>'}
23
+ else:
24
+ return {"is_valid": False, "message": f'<p style="color: red;">Invalid OpenAI API Key. Status code: {response.status_code}</p>'}
25
+ except requests.exceptions.RequestException as e:
26
+ return {"is_valid": False, "message": f'<p style="color: red;">Invalid OpenAI API Key. Error: {e}</p>'}
27
+ elif "Llama" in selected_model:
28
+ url = "https://huggingface.co/api/whoami-v2"
29
+ headers = {
30
+ "Authorization": f"Bearer {api_key}"
31
+ }
32
+ try:
33
+ response = requests.get(url, headers=headers)
34
+ if response.status_code == 200:
35
+ return {"is_valid": True, "message": '<p style="color: green;">API Key is valid!</p>'}
36
+ else:
37
+ return {"is_valid": False, "message": f'<p style="color: red;">Invalid Hugging Face API Key. Status code: {response.status_code}</p>'}
38
+ except requests.exceptions.RequestException as e:
39
+ return {"is_valid": False, "message": f'<p style="color: red;">Invalid Hugging Face API Key. Error: {e}</p>'}
40
+ elif "Gemini" in selected_model:
41
+ try:
42
+ genai.configure(api_key=api_key)
43
+ model = genai.GenerativeModel("gemini-1.5-flash")
44
+ response = model.generate_content("Help me diagnose the patient.")
45
+ return {"is_valid": True, "message": '<p style="color: green;">API Key is valid!</p>'}
46
+ except Exception as e:
47
+ return {"is_valid": False, "message": f'<p style="color: red;">Invalid Google API Key. Error: {e}</p>'}
48
+
49
+ def generate_text_chatgpt(key, prompt, temperature, top_p):
50
+
51
+ openai.api_key = key
52
+
53
+ prompt_template = f"""
54
+ {prompt} <Choose only one among the words Psoriasis, Arthritis, Bronchial asthma or Cervical spondylosis>
55
+ """
56
+
57
+ response = openai.chat.completions.create(
58
+ model="gpt-3.5-turbo-1106",
59
+ messages=[{"role": "system", "content": "You are a talented diagnostician who is diagnosing a patient."},
60
+ {"role": "user", "content": prompt_template}],
61
+ temperature=temperature,
62
+ max_tokens=50,
63
+ top_p=top_p,
64
+ frequency_penalty=0
65
+ )
66
+
67
+ return response.choices[0].message.content
68
+
69
+
70
+ def generate_text_gemini(key, prompt, temperature, top_p):
71
+ genai.configure(api_key=key)
72
+
73
+ prompt_template = f"""
74
+ {prompt} <Choose only one among the words Psoriasis, Arthritis, Bronchial asthma or Cervical spondylosis>
75
+ """
76
+
77
+ generation_config = genai.GenerationConfig(
78
+ max_output_tokens=len(prompt_template)+50,
79
+ temperature=temperature,
80
+ top_p=top_p,
81
+ )
82
+ model = genai.GenerativeModel("gemini-1.5-flash", generation_config=generation_config)
83
+ response = model.generate_content(prompt_template)
84
+ return response.text
85
+
86
+
87
+ def generate_text_llama(key, prompt, temperature, top_p):
88
+ model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
89
+ client = InferenceClient(api_key=key)
90
+
91
+ prompt_template = f"""
92
+ {prompt} <Choose only one among the words Psoriasis, Arthritis, Bronchial asthma or Cervical spondylosis>
93
+ Do not list the symptoms again in the response. Do not add any additional text. Do not attempt to explain your answer.
94
+ """
95
+
96
+ messages = [{"role": "system", "content": "You are a talented diagnostician who is diagnosing a patient."},
97
+ {"role": "user","content": prompt_template}]
98
+
99
+ completion = client.chat.completions.create(
100
+ model=model_name,
101
+ messages=messages,
102
+ max_tokens=len(prompt_template)+50,
103
+ temperature=temperature,
104
+ top_p=top_p
105
+ )
106
+
107
+ response = completion.choices[0].message.content
108
+ if len(response) > len(prompt_template):
109
+ return response[len(prompt_template):]
110
+ return response
111
+
112
+
113
+ def diagnose(key, model, top_k, temperature, symptom_prompt):
114
+
115
+ model_map = {
116
+ "GPT-3.5-Turbo": "GPT",
117
+ "Llama-3": "Llama",
118
+ "Gemini-1.5": "Gemini"
119
+ }
120
+ if symptom_prompt:
121
+ if "GPT" in model:
122
+ message = generate_text_chatgpt(key, symptom_prompt, temperature, top_k)
123
+ elif "Llama" in model:
124
+ message = generate_text_llama(key, symptom_prompt, temperature, top_k)
125
+ elif "Gemini" in model:
126
+ message = generate_text_gemini(key, symptom_prompt, temperature, top_k)
127
+ else:
128
+ message = "Incorrect model, please try again."
129
+ else:
130
+ message = "Please add the symptoms data"
131
+
132
+ return message
133
+
134
+ def update_model_components(selected_model):
135
+ model_map = {
136
+ "GPT-3.5-Turbo": "GPT",
137
+ "Llama-3": "Llama",
138
+ "Gemini-1.5": "Gemini"
139
+ }
140
+
141
+ link_map = {
142
+ "GPT-3.5-Turbo": "https://platform.openai.com/account/api-keys",
143
+ "Llama-3": "https://hf.co/settings/tokens",
144
+ "Gemini-1.5": "https://aistudio.google.com/apikey"
145
+ }
146
+ textbox_label = f"Please input the API key for your {model_map[selected_model]} model"
147
+ button_value = f"Don't have an API key? Get one for the {model_map[selected_model]} model here."
148
+ button_link = link_map[selected_model]
149
+ return gr.update(label=textbox_label), gr.update(value=button_value, link=button_link)
150
+
151
+ def toggle_button(symptoms_text, api_key, model):
152
+ if symptoms_text.strip() and validate_api_key(api_key, model):
153
+ return gr.update(interactive=True)
154
+ return gr.update(interactive=False)
155
+
156
+
157
+ with gr.Blocks() as ui:
158
+
159
+ with gr.Row(equal_height=500):
160
+ with gr.Column(scale=1, min_width=300):
161
+ model = gr.Radio(label="LLM Selection", value="GPT-3.5-Turbo",
162
+ choices=["GPT-3.5-Turbo", "Llama-3", "Gemini-1.5"])
163
+ is_valid = False
164
+ key = gr.Textbox(label="Please input the API key for your Large Language model", type="password")
165
+ status_message = gr.HTML(label="Validation Status")
166
+ key.input(fn=api_check_msg, inputs=[key, model], outputs=status_message)
167
+ button = gr.Button(value="Don't have an API key? Get one for the GPT model here.", link="https://platform.openai.com/account/api-keys")
168
+ model.change(update_model_components, inputs=model, outputs=[key, button])
169
+ gr.ClearButton(key, variant="primary")
170
+
171
+ with gr.Column(scale=2, min_width=600):
172
+ gr.Markdown("## Hello, Welcome to the GUI by Team #9.")
173
+ temperature = gr.Slider(0.0, 1.0, value=0.7, step = 0.05, label="Temperature", info="Set the Temperature")
174
+ top_p = gr.Slider(0.0, 1.0, value=0.9, step = 0.05, label="top-p value", info="Set the sampling nucleus parameter")
175
+ symptoms = gr.Textbox(label="Add the symptom data in the input to receive diagnosis")
176
+ llm_btn = gr.Button(value="Diagnose Disease", variant="primary", elem_id="diagnose", interactive=False)
177
+ symptoms.input(toggle_button, inputs=[symptoms, key, model], outputs=llm_btn)
178
+ key.input(toggle_button, inputs=[symptoms, key, model], outputs=llm_btn)
179
+ model.change(toggle_button, inputs=[symptoms, key, model], outputs=llm_btn)
180
+ output = gr.Textbox(label="LLM Output Status", interactive=False, placeholder="Output will appear here...")
181
+ llm_btn.click(fn=diagnose, inputs=[key, model, top_p, temperature, symptoms], outputs=output, api_name="LLM_Comparator")
182
+
183
+
184
  ui.launch(share=True)