|
import gradio as gr
import torch
from huggingface_hub import InferenceClient
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
# NOTE(review): `client` is not used anywhere in the visible code; kept in case
# another module imports it.
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")

model_name = "BeastGokul/Bio-Mistral-7B-finetuned"

# Load the fine-tuned model once at import time and put its weights on the GPU
# when one is available, so they live on the same device the generation inputs
# are sent to (otherwise generate() fails with a device-mismatch error).
# NOTE(review): weights load in the default dtype; for a 7B model consider
# `torch_dtype=torch.float16` on GPU if memory is tight — confirm before changing.
model = AutoModelForCausalLM.from_pretrained(model_name).to(
    "cuda" if torch.cuda.is_available() else "cpu"
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
|
|
|
|
|
def generate_response(chat_history, max_length, temperature, top_p):
    """Generate the model's reply to the most recent (pending) user turn.

    The conversation is flattened into a plain-text transcript of
    ``User: ...`` / ``Model: ...`` lines and passed to the module-level
    ``tokenizer`` and ``model``. The last entry of ``chat_history`` is expected
    to be a pending ``(user_message, "")`` pair, as appended by the UI's
    ``add_user_message`` handler; its empty reply slot is filled in.

    Args:
        chat_history: List of ``(user_message, model_reply)`` pairs.
        max_length: Maximum number of *new* tokens to generate (the UI labels
            this slider "Response Length").
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        ``(history, history)``: the updated history, once for the chat display
        and once for the session state. When there is no pending turn (empty
        history, or the last turn already has a reply) the input history is
        returned unchanged and the model is not called.
    """
    # Nothing to answer: avoids IndexError on [] and avoids re-generating (and
    # duplicating) a turn that already has a reply.
    if not chat_history or chat_history[-1][1]:
        return chat_history, chat_history

    conversation = "\n".join(
        f"User: {user_msg}\nModel: {reply}" for user_msg, reply in chat_history
    )

    # Follow the model's own device so inputs and weights always match.
    inputs = tokenizer(conversation, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]

    output = model.generate(
        **inputs,
        # max_new_tokens bounds only the reply; max_length would also count the
        # (growing) prompt and eventually leave no room for a response.
        max_new_tokens=int(max_length),
        temperature=float(temperature),
        top_p=float(top_p),
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the newly generated tokens, not the echoed prompt.
    reply = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
    # The model may continue the transcript with an invented next user turn;
    # keep only its own answer.
    reply = reply.split("\nUser:")[0].strip()

    # Fill the pending turn instead of appending a duplicate; build a new list
    # rather than mutating the caller's state object.
    history = list(chat_history)
    history[-1] = (history[-1][0], reply)
    return history, history
|
|
|
# Build the chat UI. Visual tweaks for the chat panel come from the CSS below;
# the `.chatbox` class is attached to the Chatbot via `elem_classes`.
with gr.Blocks(css="""
.chatbox { max-height: 600px; overflow-y: auto; background-color: #f8f9fa; border: 1px solid #e0e0e0; padding: 10px; border-radius: 8px; }
.message { padding: 8px; margin: 4px 0; border-radius: 6px; }
.user-message { background-color: #cce5ff; text-align: left; }
.model-message { background-color: #e2e3e5; text-align: left; }
""") as interface:

    gr.Markdown(
        '<h1 style="text-align:center; color: #2c3e50;">Biomedical AI Chat Interface</h1>\n'
        '<p style="text-align:center; color: #34495e;">\n'
        "Ask any biomedical or health-related questions to interact with the AI.\n"
        "</p>"
    )

    # Per-session list of (user_message, model_reply) pairs.
    chat_history = gr.State([])

    with gr.Row():
        user_input = gr.Textbox(
            lines=2,
            placeholder="Type your biomedical query here...",
            label="Your Message",
            elem_id="user-input",
            container=False,
        )
        # The Gradio component is `Chatbot`; custom CSS classes go through
        # `elem_classes` (there is no `css_class` parameter).
        chat_display = gr.Chatbot(
            label="Chat History", elem_id="chatbox", elem_classes="chatbox"
        )

    example_queries = [
        "What are the common symptoms of diabetes?",
        "Explain the function of hemoglobin.",
        "How does insulin work in the body?",
        "What are the side effects of chemotherapy?",
        "Can you explain the process of DNA replication?",
    ]

    with gr.Row():
        max_length = gr.Slider(50, 500, value=200, step=10, label="Response Length")
        temperature = gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.1, label="Top-p")

    send_button = gr.Button("Send", elem_id="send-button")

    with gr.Row():
        for query in example_queries:
            # `q=query` binds the current value (avoids the late-binding closure
            # pitfall). Clicking an example fills the textbox and starts a fresh
            # chat, clearing both the stored history and the visible transcript
            # so the two stay in sync.
            gr.Button(query).click(
                fn=lambda q=query: (q, [], []),
                outputs=[user_input, chat_history, chat_display],
            )

    def add_user_message(user_message, chat_history):
        """Append a pending ``(user_message, "")`` turn and clear the textbox.

        Returns the new textbox value, the new history for the session state,
        and the same history for the chat display so the question appears
        immediately, before the model has answered.

        Raises:
            gr.Error: if the message is blank; this also stops the chained
                generation step from running.
        """
        if not user_message or not user_message.strip():
            raise gr.Error("Please enter a message.")
        history = list(chat_history) + [(user_message, "")]
        return "", history, history

    # Chain the steps: generation must only run after the user turn has been
    # stored. Two independent .click() listeners would run concurrently, and
    # generate_response could read the history before the new turn exists.
    send_button.click(
        fn=add_user_message,
        inputs=[user_input, chat_history],
        outputs=[user_input, chat_history, chat_display],
    ).success(
        fn=generate_response,
        inputs=[chat_history, max_length, temperature, top_p],
        outputs=[chat_display, chat_history],
    )
|
# Expose the app under `demo` as well: the entry point below used that name
# (it was never defined), and Gradio's reload-mode CLI looks for `demo` by default.
demo = interface

if __name__ == "__main__":
    # Launch exactly once, and only when run as a script, so importing this
    # module does not start a web server (previously launch() ran at import
    # time and the guarded call then failed with NameError).
    demo.launch()
|
|