Yousefsalem committed on
Commit
8817c45
1 Parent(s): bab8e0c

Upload 4 files

Browse files
Files changed (4) hide show
  1. src/__init__.py +0 -0
  2. src/chatbot.py +72 -0
  3. src/memory.py +77 -0
  4. src/models.py +42 -0
src/__init__.py ADDED
File without changes
src/chatbot.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from langchain.chains import ConversationChain
3
+ from .memory import EnhancedInMemoryHistory
4
+ from .models import route_llm, prompt
5
+
6
# Module-level store mapping session IDs to their conversation memories.
# Fixes a NameError: the original code called get_by_session_id(), which
# was never defined or imported anywhere in this module.
_session_store = {}


def get_by_session_id(session_id):
    """Return the memory for *session_id*, creating one on first use."""
    if session_id not in _session_store:
        _session_store[session_id] = EnhancedInMemoryHistory()
    return _session_store[session_id]


# Function to process input and generate a response
def process_input(user_input, session_id='1'):
    """
    Processes the user input and generates a response using the conversation chain.

    Parameters:
        user_input (str): The user's input message.
        session_id (str): The session ID for the chat (default is "1").

    Returns:
        str: The generated response from the chatbot.
    """
    memory = get_by_session_id(session_id)

    if user_input.lower() == 'exit':
        return "Exiting the chat session."

    # Pick the medical or general LLM based on the input topic.
    llm = route_llm(user_input)

    conversation_chain = ConversationChain(
        llm=llm,
        prompt=prompt,
        memory=memory,
        input_key='input',
        verbose=True
    )

    response = conversation_chain.run({"input": user_input})
    # NOTE(review): ConversationChain persists the exchange through
    # memory.save_context() itself after run(); the original's extra manual
    # save_context() call stored every exchange twice, so it was removed.

    return response
37
+
38
# Function for Gradio interface
def chatbot_interface(user_input, chat_history=None, session_id="1"):
    """
    Bridge between Gradio and the chatbot: obtain a reply for the user's
    message and append the (user, bot) pair to the running history.

    Parameters:
        user_input (str): The user's input message.
        chat_history (list): Previous chat turns as [[user, bot], ...] pairs.
        session_id (str): The session ID for the chat (default is "1").

    Returns:
        list: Updated chat history including the new user and bot messages.
    """
    history = [] if chat_history is None else chat_history

    # An empty message gets a canned greeting instead of invoking the LLM.
    bot_response = (
        "Hi there! How can I help you today?"
        if user_input == ""
        else process_input(user_input, session_id)
    )

    history.append([user_input, bot_response])
    return history
62
+
63
# Gradio launch
def launch_gradio_interface():
    """
    Build and launch the Gradio UI.

    A gr.State component carries the chat history across turns. The original
    passed only the Textbox value, so chatbot_interface always received
    chat_history=None and the displayed transcript reset on every message.
    """

    def _respond(user_input, chat_history):
        # Delegate to the chatbot; return the updated history both as the
        # visible transcript and as the persisted per-session state value.
        updated_history = chatbot_interface(user_input, chat_history)
        return updated_history, updated_history

    gr.Interface(
        fn=_respond,
        inputs=[
            gr.Textbox(lines=7, label="Your input", placeholder="Type your message here..."),
            gr.State([]),
        ],
        outputs=[
            gr.Chatbot(label="Chat History"),
            gr.State(),
        ],
        title="AI Chatbot",
        live=False
    ).launch()
72
+
src/memory.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.memory import BaseMemory
2
+ from langchain_core.messages import BaseMessage, AIMessage, HumanMessage
3
+ from langchain_core.pydantic_v1 import Field
4
+ from typing import List
5
+ from datetime import datetime
6
+
7
class EnhancedInMemoryHistory(BaseMemory):
    """
    Custom memory class for storing chat history with timestamps.

    Attributes:
        messages (List[BaseMessage]): Messages exchanged between user and bot.
        timestamps (List[datetime]): When each message was recorded;
            kept index-aligned with ``messages``.
    """

    messages: List[BaseMessage] = Field(default_factory=list)
    timestamps: List[datetime] = Field(default_factory=list)

    @property
    def memory_variables(self):
        """Return the list of variable names this memory exposes to chains."""
        return ["history"]

    def add_messages(self, messages: List[BaseMessage]):
        """
        Append messages to the history, stamping each with the current time.

        Parameters:
            messages (List[BaseMessage]): Messages to add to the memory.
        """
        current_time = datetime.now()
        self.messages.extend(messages)
        # One timestamp per message, all sharing the same capture instant.
        self.timestamps.extend([current_time] * len(messages))

    def get_recent_messages(self, limit: int = 5):
        """
        Retrieve the most recent messages.

        Parameters:
            limit (int): Number of recent messages to retrieve (default is 5).

        Returns:
            List[BaseMessage]: The most recent messages, oldest first.
        """
        return self.messages[-limit:]

    def clear(self):
        """Remove all stored messages and timestamps."""
        self.messages = []
        self.timestamps = []

    def load_memory_variables(self, inputs: dict):
        """
        Load memory variables for the conversation chain.

        Parameters:
            inputs (dict): Input data for the conversation chain (unused).

        Returns:
            dict: {"history": newline-joined contents of all messages}.
        """
        return {"history": "\n".join(msg.content for msg in self.messages)}

    def save_context(self, inputs: dict, outputs):
        """
        Save one exchange: the user's input and the bot's output.

        Parameters:
            inputs (dict): Chain inputs; the user text is under 'input'.
            outputs (dict | str): Chain outputs. LangChain chains pass a dict
                (e.g. {'response': ...}); callers in this project may pass the
                response string directly.
        """
        # Bug fix: the original did str(outputs) unconditionally, so when a
        # chain passed its outputs dict the literal "{'response': ...}" repr
        # was stored in the history instead of the response text.
        if isinstance(outputs, dict):
            text = outputs.get('response', next(iter(outputs.values()), ''))
        else:
            text = str(outputs)
        self.add_messages([HumanMessage(content=inputs['input']), AIMessage(content=text)])

    def clear_memory(self):
        """Clears the memory (alias for the clear function)."""
        self.clear()
77
+
src/models.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline
2
+ from langchain.llms import CTransformers
3
+ from langchain.prompts import ChatPromptTemplate
4
+
5
# Text classification model for routing input.
# NOTE: runs eagerly at import time — importing this module loads (and, if
# not cached, downloads) the BART-large-MNLI zero-shot classifier.
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

# Load LLM models using CTransformers.
# General-purpose model: quantized Llama 2 7B GGUF file expected in the
# current working directory — TODO confirm the relative path at deploy time.
general_llm = CTransformers(
    model="./llama-2-7b.Q4_K_M.gguf",
    model_type="llama",
    config={'max_new_tokens': 512, 'temperature': 0.7}
)

# Medical-domain model: quantized BioMistral 7B GGUF file, same path caveat.
medical_llm = CTransformers(
    model="./BioMistral-7B.Q4_K_M.gguf",
    model_type="llama",
    config={'max_new_tokens': 512, 'temperature': 0.7}
)

# Prompt template for generating responses.
template = """
You are a versatile AI assistant that can provide both medical advice and help users with general concerns...
"""

# Compile the prompt template using LangChain.
prompt = ChatPromptTemplate.from_template(template)
28
+
29
def route_llm(user_input):
    """
    Route user input to the appropriate LLM (medical or general).

    Parameters:
        user_input (str): The user's input message.

    Returns:
        CTransformers: The selected LLM model (general or medical).
    """
    # Zero-shot classify the message against the two routing labels.
    classification = classifier(user_input, ['medical', 'general'])

    # Defensive: fall back to 'general' when no labels are present.
    if 'labels' in classification:
        top_label = classification['labels'][0]
    else:
        top_label = 'general'

    if top_label == 'medical':
        return medical_llm
    return general_llm
42
+