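# Gradio chat UI for a biomedical assistant backed by the fine-tuned
# BeastGokul/Bio-Mistral-7B-finetuned model; replies are generated locally via transformers.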
import torch
import gradio as gr
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hosted inference client (currently unused; replies come from the local model below).
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")

model_name = "BeastGokul/Bio-Mistral-7B-finetuned"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Keep the model and the tokenized inputs on the same device.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
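
# Build a prompt from the accumulated chat history, sample a completion, and
# write the reply back into the latest (user, reply) pair.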
def generate_response(chat_history, max_length, temperature, top_p):
    # Flatten the running conversation into a single prompt string.
    conversation = "\n".join([f"User: {msg[0]}\nModel: {msg[1]}" for msg in chat_history])
    inputs = tokenizer(conversation, return_tensors="pt").to(device)
    output = model.generate(**inputs, max_length=max_length, temperature=temperature, top_p=top_p, do_sample=True)
    response_text = tokenizer.decode(output[0], skip_special_tokens=True).split("Model:")[-1].strip()
    # Fill in the reply for the latest user turn instead of appending a duplicate pair.
    chat_history[-1] = (chat_history[-1][0], response_text)
    return chat_history, chat_history
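
# UI: chat display, message box, sampling controls, and example biomedical prompts.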
with gr.Blocks(css="""
.chatbox { max-height: 600px; overflow-y: auto; background-color: #f8f9fa; border: 1px solid #e0e0e0; padding: 10px; border-radius: 8px; }
.message { padding: 8px; margin: 4px 0; border-radius: 6px; }
.user-message { background-color: #cce5ff; text-align: left; }
.model-message { background-color: #e2e3e5; text-align: left; }
""") as interface:
    gr.Markdown(
        """
        <h1 style="text-align:center; color: #2c3e50;">Biomedical AI Chat Interface</h1>
        <p style="text-align:center; color: #34495e;">
            Ask any biomedical or health-related questions to interact with the AI.
        </p>
        """
    )
    chat_history = gr.State([])
    with gr.Row():
        user_input = gr.Textbox(
            lines=2,
            placeholder="Type your biomedical query here...",
            label="Your Message",
            elem_id="user-input",
            container=False,
        )
    chat_display = gr.Chatbot(label="Chat History", elem_id="chatbox", elem_classes="chatbox")
    example_queries = [
        "What are the common symptoms of diabetes?",
        "Explain the function of hemoglobin.",
        "How does insulin work in the body?",
        "What are the side effects of chemotherapy?",
        "Can you explain the process of DNA replication?",
    ]
    with gr.Row():
        max_length = gr.Slider(50, 500, value=200, step=10, label="Response Length")
        temperature = gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.1, label="Top-p")
    send_button = gr.Button("Send", elem_id="send-button")
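
    # Clicking an example button copies its text into the input box and clears the chat history.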
    with gr.Row():
        for query in example_queries:
            gr.Button(query).click(fn=lambda q=query: (q, []), outputs=[user_input, chat_history])
    def add_user_message(user_message, chat_history):
        chat_history.append((user_message, ""))
        return "", chat_history
    send_button.click(
        fn=add_user_message,
        inputs=[user_input, chat_history],
        outputs=[user_input, chat_history],
    ).then(
        fn=generate_response,
        inputs=[chat_history, max_length, temperature, top_p],
        outputs=[chat_display, chat_history],
    )

if __name__ == "__main__":
    interface.launch()