Spaces:
Sleeping
Sleeping
Yousefsalem
commited on
Commit
•
8817c45
1
Parent(s):
bab8e0c
Upload 4 files
Browse files- src/__init__.py +0 -0
- src/chatbot.py +72 -0
- src/memory.py +77 -0
- src/models.py +42 -0
src/__init__.py
ADDED
File without changes
|
src/chatbot.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from langchain.chains import ConversationChain
|
3 |
+
from .memory import EnhancedInMemoryHistory
|
4 |
+
from .models import route_llm, prompt
|
5 |
+
|
6 |
+
# Function to process input and generate a response
|
7 |
+
def process_input(user_input, session_id='1'):
|
8 |
+
"""
|
9 |
+
Processes the user input and generates a response using the conversation chain.
|
10 |
+
|
11 |
+
Parameters:
|
12 |
+
user_input (str): The user's input message.
|
13 |
+
session_id (str): The session ID for the chat (default is "1").
|
14 |
+
|
15 |
+
Returns:
|
16 |
+
str: The generated response from the chatbot.
|
17 |
+
"""
|
18 |
+
memory = get_by_session_id(session_id)
|
19 |
+
|
20 |
+
if user_input.lower() == 'exit':
|
21 |
+
return "Exiting the chat session."
|
22 |
+
|
23 |
+
llm = route_llm(user_input)
|
24 |
+
|
25 |
+
conversation_chain = ConversationChain(
|
26 |
+
llm=llm,
|
27 |
+
prompt=prompt,
|
28 |
+
memory=memory,
|
29 |
+
input_key='input',
|
30 |
+
verbose=True
|
31 |
+
)
|
32 |
+
|
33 |
+
response = conversation_chain.run({"input": user_input})
|
34 |
+
memory.save_context({'input': user_input}, response)
|
35 |
+
|
36 |
+
return response
|
37 |
+
|
38 |
+
# Function for Gradio interface
|
39 |
+
def chatbot_interface(user_input, chat_history=None, session_id="1"):
|
40 |
+
"""
|
41 |
+
Interface function for Gradio to handle input and output between the user and the chatbot.
|
42 |
+
|
43 |
+
Parameters:
|
44 |
+
user_input (str): The user's input message.
|
45 |
+
session_id (str): The session ID for the chat (default is "1").
|
46 |
+
chat_history (list): List of previous chat messages in the format [[user, bot], ...]
|
47 |
+
|
48 |
+
Returns:
|
49 |
+
list: Updated chat history including the new user and bot messages.
|
50 |
+
"""
|
51 |
+
if chat_history is None:
|
52 |
+
chat_history = []
|
53 |
+
|
54 |
+
if user_input == "":
|
55 |
+
bot_response = "Hi there! How can I help you today?"
|
56 |
+
else:
|
57 |
+
bot_response = process_input(user_input, session_id)
|
58 |
+
|
59 |
+
chat_history.append([user_input, bot_response])
|
60 |
+
|
61 |
+
return chat_history
|
62 |
+
|
63 |
+
# Gradio launch
|
64 |
+
def launch_gradio_interface():
|
65 |
+
gr.Interface(
|
66 |
+
fn=chatbot_interface,
|
67 |
+
inputs=[gr.Textbox(lines=7, label="Your input", placeholder="Type your message here...")],
|
68 |
+
outputs=gr.Chatbot(label="Chat History"),
|
69 |
+
title="AI Chatbot",
|
70 |
+
live=False
|
71 |
+
).launch()
|
72 |
+
|
src/memory.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain_core.memory import BaseMemory
|
2 |
+
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage
|
3 |
+
from langchain_core.pydantic_v1 import Field
|
4 |
+
from typing import List
|
5 |
+
from datetime import datetime
|
6 |
+
|
7 |
+
class EnhancedInMemoryHistory(BaseMemory):
|
8 |
+
"""
|
9 |
+
Custom memory class for storing chat history with timestamps.
|
10 |
+
|
11 |
+
Attributes:
|
12 |
+
messages (List[BaseMessage]): A list of messages exchanged between the user and the bot.
|
13 |
+
timestamps (List[datetime]): A list of timestamps when the messages were exchanged.
|
14 |
+
"""
|
15 |
+
|
16 |
+
messages: List[BaseMessage] = Field(default_factory=list)
|
17 |
+
timestamps: List[datetime] = Field(default_factory=list)
|
18 |
+
|
19 |
+
@property
|
20 |
+
def memory_variables(self):
|
21 |
+
"""Returns a list of memory variables (history) used in the conversation chain."""
|
22 |
+
return ["history"]
|
23 |
+
|
24 |
+
def add_messages(self, messages: List[BaseMessage]):
|
25 |
+
"""
|
26 |
+
Adds new messages to the memory and timestamps them.
|
27 |
+
|
28 |
+
Parameters:
|
29 |
+
messages (List[BaseMessage]): A list of messages to add to the memory.
|
30 |
+
"""
|
31 |
+
current_time = datetime.now()
|
32 |
+
self.messages.extend(messages)
|
33 |
+
self.timestamps.extend([current_time] * len(messages))
|
34 |
+
|
35 |
+
def get_recent_messages(self, limit: int = 5):
|
36 |
+
"""
|
37 |
+
Retrieves the most recent messages.
|
38 |
+
|
39 |
+
Parameters:
|
40 |
+
limit (int): Number of recent messages to retrieve (default is 5).
|
41 |
+
|
42 |
+
Returns:
|
43 |
+
List[BaseMessage]: A list of the most recent messages.
|
44 |
+
"""
|
45 |
+
return self.messages[-limit:]
|
46 |
+
|
47 |
+
def clear(self):
|
48 |
+
"""Clears all messages and timestamps from the memory."""
|
49 |
+
self.messages = []
|
50 |
+
self.timestamps = []
|
51 |
+
|
52 |
+
def load_memory_variables(self, inputs: dict):
|
53 |
+
"""
|
54 |
+
Loads memory variables for the conversation chain.
|
55 |
+
|
56 |
+
Parameters:
|
57 |
+
inputs (dict): Input data for the conversation chain.
|
58 |
+
|
59 |
+
Returns:
|
60 |
+
dict: A dictionary with the conversation history.
|
61 |
+
"""
|
62 |
+
return {"history": "\n".join([msg.content for msg in self.messages])}
|
63 |
+
|
64 |
+
def save_context(self, inputs: dict, outputs: str):
|
65 |
+
"""
|
66 |
+
Saves the context of the conversation by storing the user input and bot output.
|
67 |
+
|
68 |
+
Parameters:
|
69 |
+
inputs (dict): The user input.
|
70 |
+
outputs (str): The bot's response.
|
71 |
+
"""
|
72 |
+
self.add_messages([HumanMessage(content=inputs['input']), AIMessage(content=str(outputs))])
|
73 |
+
|
74 |
+
def clear_memory(self):
|
75 |
+
"""Clears the memory (alias for the clear function)."""
|
76 |
+
self.clear()
|
77 |
+
|
src/models.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import pipeline
|
2 |
+
from langchain.llms import CTransformers
|
3 |
+
from langchain.prompts import ChatPromptTemplate
|
4 |
+
|
5 |
+
# Text classification model for routing input
|
6 |
+
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
|
7 |
+
|
8 |
+
# Load LLM models using CTransformers
|
9 |
+
general_llm = CTransformers(
|
10 |
+
model="./llama-2-7b.Q4_K_M.gguf",
|
11 |
+
model_type="llama",
|
12 |
+
config={'max_new_tokens': 512, 'temperature': 0.7}
|
13 |
+
)
|
14 |
+
|
15 |
+
medical_llm = CTransformers(
|
16 |
+
model="./BioMistral-7B.Q4_K_M.gguf",
|
17 |
+
model_type="llama",
|
18 |
+
config={'max_new_tokens': 512, 'temperature': 0.7}
|
19 |
+
)
|
20 |
+
|
21 |
+
# Prompt template for generating responses
|
22 |
+
template = """
|
23 |
+
You are a versatile AI assistant that can provide both medical advice and help users with general concerns...
|
24 |
+
"""
|
25 |
+
|
26 |
+
# Compile the prompt template using LangChain
|
27 |
+
prompt = ChatPromptTemplate.from_template(template)
|
28 |
+
|
29 |
+
def route_llm(user_input):
|
30 |
+
"""
|
31 |
+
Routes user input to the appropriate LLM (medical or general).
|
32 |
+
|
33 |
+
Parameters:
|
34 |
+
user_input (str): The user's input message.
|
35 |
+
|
36 |
+
Returns:
|
37 |
+
CTransformers: The selected LLM model (general or medical).
|
38 |
+
"""
|
39 |
+
result = classifier(user_input, ['medical', 'general'])
|
40 |
+
label = result['labels'][0] if 'labels' in result else 'general'
|
41 |
+
return medical_llm if label == 'medical' else general_llm
|
42 |
+
|