import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the fine-tuned biomedical model and its tokenizer.
model_name = "BeastGokul/Bio-Mistral-7B-finetuned"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Run on GPU when available, otherwise fall back to CPU; the model and the
# tokenized inputs must live on the same device.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
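
# Memory note: a 7B-parameter model in float32 needs on the order of 28 GB of
# RAM, so the plain load above may not fit on smaller machines. One common
# option is half precision via the torch_dtype argument of from_pretrained
# (a sketch, assuming the hardware supports float16):
#   model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)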

def generate_response(chat_history, max_length, temperature, top_p):
    # Flatten the chat history into one prompt the model can continue.
    conversation = "\n".join(f"User: {user}\nModel: {reply}" for user, reply in chat_history)
    inputs = tokenizer(conversation, return_tensors="pt").to(device)
    # max_new_tokens caps only the generated reply; max_length would also count the prompt.
    output = model.generate(**inputs, max_new_tokens=max_length, temperature=temperature, top_p=top_p, do_sample=True)
    # Keep only the text generated after the last "Model:" marker.
    response_text = tokenizer.decode(output[0], skip_special_tokens=True).split("Model:")[-1].strip()
    # Fill in the reply for the pending user turn instead of appending a duplicate turn.
    chat_history[-1] = (chat_history[-1][0], response_text)
    return chat_history, chat_history
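
# A possible alternative to the manual "User:/Model:" framing above: if the
# tokenizer ships a chat template (many Mistral fine-tunes do; this one may
# not), tokenizer.apply_chat_template formats turns the way the model was
# trained on, which tends to be more reliable. Sketch, where user_message
# stands for one user turn:
#   messages = [{"role": "user", "content": user_message}]
#   prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)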
with gr.Blocks(css="""
.chatbox { max-height: 600px; overflow-y: auto; background-color: #f8f9fa; border: 1px solid #e0e0e0; padding: 10px; border-radius: 8px; }
.message { padding: 8px; margin: 4px 0; border-radius: 6px; }
.user-message { background-color: #cce5ff; text-align: left; }
.model-message { background-color: #e2e3e5; text-align: left; }
""") as interface:
    gr.Markdown(
        """
        <h1 style="text-align:center; color: #2c3e50;">Biomedical AI Chat Interface</h1>
        <p style="text-align:center; color: #34495e;">
        Ask any biomedical or health-related questions to interact with the AI.
        </p>
        """
    )
    # Conversation state: a list of (user_message, model_reply) tuples.
    chat_history = gr.State([])
    with gr.Row():
        user_input = gr.Textbox(
            lines=2,
            placeholder="Type your biomedical query here...",
            label="Your Message",
            elem_id="user-input",
            container=False,
        )
    # gr.Chatbot is the component name; elem_classes hooks into the CSS above.
    chat_display = gr.Chatbot(label="Chat History", elem_id="chatbox", elem_classes="chatbox")
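    # Note: newer Gradio releases prefer gr.Chatbot(type="messages") with
    # role/content dicts; the (user, reply) tuple format used here still
    # works but is the older convention.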
    example_queries = [
        "What are the common symptoms of diabetes?",
        "Explain the function of hemoglobin.",
        "How does insulin work in the body?",
        "What are the side effects of chemotherapy?",
        "Can you explain the process of DNA replication?",
    ]
    # Sampling controls: temperature rescales token probabilities, and top-p
    # restricts sampling to the smallest token set whose probabilities sum to p.
    with gr.Row():
        max_length = gr.Slider(50, 500, value=200, step=10, label="Response Length")
        temperature = gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.1, label="Top-p")
    send_button = gr.Button("Send", elem_id="send-button")
    with gr.Row():
        # Each example button fills the input box and resets the conversation;
        # the q=query default binds the current query to each lambda.
        for query in example_queries:
            gr.Button(query).click(fn=lambda q=query: (q, []), outputs=[user_input, chat_history])

    def add_user_message(user_message, chat_history):
        # Record the user turn with an empty reply slot and clear the textbox.
        chat_history.append((user_message, ""))
        return "", chat_history

    # Chain the two steps with .then() so the user turn is stored before
    # generation reads it; two separate .click() handlers on the same button
    # would not guarantee that ordering.
    send_button.click(
        fn=add_user_message,
        inputs=[user_input, chat_history],
        outputs=[user_input, chat_history],
    ).then(
        fn=generate_response,
        inputs=[chat_history, max_length, temperature, top_p],
        outputs=[chat_display, chat_history],
    )

if __name__ == "__main__":
    interface.launch()
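
# To run locally (a sketch; assumes these package names and that the script
# is saved as app.py):
#   pip install gradio torch transformers
#   python app.py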