BeastGokul commited on
Commit
d410c01
·
verified ·
1 Parent(s): 17b719a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -79
app.py CHANGED
@@ -9,84 +9,17 @@ base_model = AutoModelForCausalLM.from_pretrained("BioMistral/BioMistral-7B")
9
  base_model.resize_token_embeddings(len(tokenizer))
10
  model = PeftModel.from_pretrained(base_model, "BeastGokul/Bio-Mistral-7B-finetuned")
11
 
12
-
13
-
14
- def generate_response(chat_history, max_length, temperature, top_p):
15
- conversation = "\n".join([f"User: {msg[0]}\nModel: {msg[1]}" for msg in chat_history])
16
- inputs = tokenizer(conversation, return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")
17
- output = model.generate(**inputs, max_length=max_length, temperature=temperature, top_p=top_p, do_sample=True)
18
- response_text = tokenizer.decode(output[0], skip_special_tokens=True).split("Model:")[-1].strip()
19
- chat_history.append((chat_history[-1][0], response_text))
20
- return chat_history, chat_history
21
-
22
- with gr.Blocks(css="""
23
- .chatbox { max-height: 600px; overflow-y: auto; background-color: #f8f9fa; border: 1px solid #e0e0e0; padding: 10px; border-radius: 8px; }
24
- .message { padding: 8px; margin: 4px 0; border-radius: 6px; }
25
- .user-message { background-color: #cce5ff; text-align: left; }
26
- .model-message { background-color: #e2e3e5; text-align: left; }
27
- """) as interface:
28
- gr.Markdown(
29
- """
30
- <h1 style="text-align:center; color: #2c3e50;">Biomedical AI Chat Interface</h1>
31
- <p style="text-align:center; color: #34495e;">
32
- Ask any biomedical or health-related questions to interact with the AI.
33
- </p>
34
- """
35
- )
36
-
37
- chat_history = gr.State([])
38
-
39
- with gr.Row():
40
- user_input = gr.Textbox(
41
- lines=2,
42
- placeholder="Type your biomedical query here...",
43
- label="Your Message",
44
- elem_id="user-input",
45
- container=False
46
- )
47
- chat_display = gr.Chatbot(label="Chat History", elem_id="chatbox", elem_classes=["chatbox"])
48
-
49
-
50
- example_queries = [
51
- "What are the common symptoms of diabetes?",
52
- "Explain the function of hemoglobin.",
53
- "How does insulin work in the body?",
54
- "What are the side effects of chemotherapy?",
55
- "Can you explain the process of DNA replication?"
56
- ]
57
  user_input = gr.Textbox(placeholder="Enter your biomedical query...", label="Your Query")
 
 
 
58
 
59
-
60
- with gr.Row():
61
- max_length = gr.Slider(50, 500, value=200, step=10, label="Response Length")
62
- temperature = gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="Temperature")
63
- top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.1, label="Top-p")
64
-
65
- send_button = gr.Button("Send", elem_id="send-button")
66
-
67
- with gr.Row():
68
- for query in example_queries:
69
- gr.Button(query).click(fn=lambda q=query: (q, []), outputs=[user_input, chat_history])
70
-
71
- def add_user_message(user_message, chat_history):
72
- chat_history.append((user_message, ""))
73
- return "", chat_history
74
-
75
- send_button.click(
76
- fn=add_user_message,
77
- inputs=[user_input, chat_history],
78
- outputs=[user_input, chat_history],
79
- )
80
-
81
- send_button.click(
82
- fn=generate_response,
83
- inputs=[chat_history, max_length, temperature, top_p],
84
- outputs=[chat_display, chat_history],
85
- )
86
-
87
- interface.launch()
88
-
89
-
90
-
91
- if __name__ == "__main__":
92
- demo.launch()
 
9
  base_model.resize_token_embeddings(len(tokenizer))
10
  model = PeftModel.from_pretrained(base_model, "BeastGokul/Bio-Mistral-7B-finetuned")
11
 
12
+ def generate_response(user_query):
13
+ inputs = tokenizer(user_query, return_tensors="pt")
14
+ outputs = model.generate(**inputs, max_length=100)
15
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
16
+ return response
17
+
18
+ # Define the Gradio interface
19
+ with gr.Blocks() as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  user_input = gr.Textbox(placeholder="Enter your biomedical query...", label="Your Query")
21
+ response = gr.Textbox(label="Response", interactive=False)
22
+
23
+ user_input.submit(fn=generate_response, inputs=user_input, outputs=response)
24
 
25
+ demo.launch()