File size: 3,333 Bytes
8e63ced
 
1ac767c
43a91cc
8e63ced
 
bad24f3
67b0b7b
1ac767c
 
8e63ced
 
ef6c601
e7ed7d1
 
 
 
 
ef6c601
8e63ced
e7ed7d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef6c601
8e63ced
ef6c601
e7ed7d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e63ced
ef6c601
 
 
 
 
e7ed7d1
 
 
 
 
ef6c601
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e63ced
 
e7ed7d1
8e63ced
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import gradio as gr
import torch
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hosted-inference client for the Zephyr chat model.
# NOTE(review): nothing in the visible code uses `client`; it is kept in case
# another part of the project relies on it - confirm before removing.
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")

# Checkpoint that is loaded locally and used to answer chat requests.
model_name = "BeastGokul/Bio-Mistral-7B-finetuned"

# Keep the weights on the GPU when one is available so that the model and the
# tokenized inputs used for generation end up on the same device.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)



def generate_response(chat_history, max_length, temperature, top_p):
    """Generate the model's reply to the most recent user turn.

    Args:
        chat_history: List of ``(user_message, model_reply)`` pairs. The last
            pair is the turn to answer; its reply slot (normally the empty
            placeholder added when the user pressed "Send") is overwritten.
        max_length: Maximum number of *new* tokens to generate for the reply.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling threshold passed to ``model.generate``.

    Returns:
        A ``(chat_history, chat_history)`` tuple: the same updated list twice,
        once for the chat display and once for the stored state.
    """
    # Nothing to answer yet; return the history untouched instead of failing
    # on ``chat_history[-1]``.
    if not chat_history:
        return chat_history, chat_history

    user_message = chat_history[-1][0]

    # Finished turns are rendered in full; the pending turn ends on an open
    # "Model:" cue so the model continues with its own answer.
    turns = [f"User: {user}\nModel: {reply}" for user, reply in chat_history[:-1]]
    turns.append(f"User: {user_message}\nModel:")
    conversation = "\n".join(turns)

    # Send the inputs to wherever the model actually lives, so CPU and GPU
    # deployments both work without a device mismatch.
    inputs = tokenizer(conversation, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]

    output = model.generate(
        **inputs,
        # max_new_tokens bounds only the reply; max_length would also count the
        # prompt and leave no room to answer once the conversation grows.
        max_new_tokens=int(max_length),
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )

    # Decode only the freshly generated tokens rather than re-parsing the prompt.
    reply_tokens = output[0][prompt_length:]
    response_text = tokenizer.decode(reply_tokens, skip_special_tokens=True)
    # The model may keep writing the transcript; keep only its own turn.
    response_text = response_text.split("User:")[0].strip()

    # Fill in the pending turn instead of appending a duplicate of it.
    chat_history[-1] = (user_message, response_text)
    return chat_history, chat_history

# Layout and event wiring for the chat UI. The custom CSS styles the chat
# panel (".chatbox") and declares message classes for user/model bubbles.
with gr.Blocks(css="""
    .chatbox { max-height: 600px; overflow-y: auto; background-color: #f8f9fa; border: 1px solid #e0e0e0; padding: 10px; border-radius: 8px; }
    .message { padding: 8px; margin: 4px 0; border-radius: 6px; }
    .user-message { background-color: #cce5ff; text-align: left; }
    .model-message { background-color: #e2e3e5; text-align: left; }
""") as interface:
    gr.Markdown(
        """
        <h1 style="text-align:center; color: #2c3e50;">Biomedical AI Chat Interface</h1>
        <p style="text-align:center; color: #34495e;">
            Ask any biomedical or health-related questions to interact with the AI.
        </p>
        """
    )

    # Per-session conversation: a list of (user_message, model_reply) pairs.
    chat_history = gr.State([])

    with gr.Row():
        user_input = gr.Textbox(
            lines=2,
            placeholder="Type your biomedical query here...",
            label="Your Message",
            elem_id="user-input",
            container=False
        )
        # The chat component is `gr.Chatbot`; CSS classes are attached through
        # `elem_classes` so the ".chatbox" rule above applies to it.
        chat_display = gr.Chatbot(
            label="Chat History",
            elem_id="chatbox",
            elem_classes=["chatbox"],
        )

    example_queries = [
        "What are the common symptoms of diabetes?",
        "Explain the function of hemoglobin.",
        "How does insulin work in the body?",
        "What are the side effects of chemotherapy?",
        "Can you explain the process of DNA replication?"
    ]

    # Generation settings forwarded to generate_response.
    with gr.Row():
        max_length = gr.Slider(50, 500, value=200, step=10, label="Response Length")
        temperature = gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.1, label="Top-p")

    send_button = gr.Button("Send", elem_id="send-button")

    with gr.Row():
        for query in example_queries:
            # `q=query` binds the current query at definition time (avoids the
            # late-binding closure pitfall). Picking an example starts a fresh
            # conversation, so the stored history and the visible chat are
            # both cleared.
            gr.Button(query).click(
                fn=lambda q=query: (q, [], []),
                outputs=[user_input, chat_history, chat_display],
            )

    def add_user_message(user_message, chat_history):
        """Queue the user's message as a pending turn and clear the textbox.

        Returns the new textbox value (empty) and the updated history, whose
        last entry has an empty reply slot for the model to fill in.
        """
        chat_history.append((user_message, ""))
        return "", chat_history

    # Chain the two steps: generation must only start after the pending turn
    # has been written to the state, otherwise it may run on stale history.
    send_button.click(
        fn=add_user_message,
        inputs=[user_input, chat_history],
        outputs=[user_input, chat_history],
    ).then(
        fn=generate_response,
        inputs=[chat_history, max_length, temperature, top_p],
        outputs=[chat_display, chat_history],
    )

if __name__ == "__main__":
    # Start the server only when executed as a script, and launch the Blocks
    # app that this file actually defines (`interface`).
    interface.launch()