paloma99 committed on
Commit
12ff379
·
verified ·
1 Parent(s): a06444c

Create chatbot.py

Browse files
Files changed (1) hide show
  1. chatbot.py +86 -0
chatbot.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ sys.path.append('../..')
4
+
5
+ #langchain
6
+ from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter
7
+ from langchain.embeddings import HuggingFaceEmbeddings
8
+ from langchain.prompts import PromptTemplate
9
+ from langchain.chains import RetrievalQA
10
+ from langchain.prompts import ChatPromptTemplate
11
+ from langchain.schema import StrOutputParser
12
+ from langchain.schema.runnable import Runnable
13
+ from langchain.schema.runnable.config import RunnableConfig
14
+ from langchain.chains import (
15
+ LLMChain, ConversationalRetrievalChain)
16
+ from langchain.vectorstores import Chroma
17
+ from langchain.memory import ConversationBufferMemory
18
+ from langchain.chains import LLMChain
19
+ from langchain.prompts.prompt import PromptTemplate
20
+ from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate
21
+ from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate, MessagesPlaceholder
22
+ from langchain.document_loaders import PyPDFDirectoryLoader
23
+
24
+ from langchain_community.llms import HuggingFaceHub
25
+
26
+ from pydantic import BaseModel
27
+ import shutil
28
+
29
+
30
# --- Knowledge base ---------------------------------------------------------
# Load every PDF found in the local 'pdfs' directory.
loader = PyPDFDirectoryLoader('pdfs')
data = loader.load()

# Split the loaded documents into chunks (chunk_size=1000, chunk_overlap=150).
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=150,
)
docs = text_splitter.split_documents(data)

# Embedding model used to vectorise the chunks.
embeddings = HuggingFaceEmbeddings(
    model_name='sentence-transformers/all-MiniLM-l6-v2',
)

# Directory in which the Chroma vector database is persisted.
persist_directory = 'docs/chroma/'

# Throw away whatever a previous run left behind (a missing directory is not
# treated as an error), then build the vector database from the fresh chunks.
shutil.rmtree(persist_directory, ignore_errors=True)
vectordb = Chroma.from_documents(
    persist_directory=persist_directory,
    embedding=embeddings,
    documents=docs,
)

# Expose the vector database as a retriever using the "mmr" search type.
retriever = vectordb.as_retriever(search_type="mmr")
49
# System instructions for the assistant. Only the retrieved context is
# interpolated here. The conversation history and the user's question are
# supplied by the dedicated messages of `qa_prompt` below, so repeating
# `{chat_history}` / `{question}` in this text would insert them into the final
# prompt twice -- and the history copy would be rendered as the repr of a list
# of message objects, because the retrieval chain in this file is configured
# to pass the stored history through unchanged.
template = """Your name is AngryGreta and you are a recycling chatbot created to help people. Use the following pieces of context to answer the question at the end. Answer in the same language of the question. Keep the answer as concise as possible. Always say "thanks for asking!" at the end of the answer.
CONTEXT: {context}"""

# Create the chat prompt templates.
# Message order: system instructions (with context), then the prior turns of
# the conversation, then the current user question as the final message.
# The prompt's input variables remain `context`, `chat_history` and `question`.
system_prompt = SystemMessagePromptTemplate.from_template(template)
qa_prompt = ChatPromptTemplate.from_messages(
    [
        system_prompt,
        MessagesPlaceholder(variable_name="chat_history"),
        HumanMessagePromptTemplate.from_template("{question}"),
    ]
)
65
# --- Language model -----------------------------------------------------------
# Mixtral instruct model accessed through the Hugging Face Hub wrapper, used
# for the "text-generation" task with the generation settings below.
llm = HuggingFaceHub(
    repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
    task="text-generation",
    model_kwargs=dict(
        max_new_tokens=512,
        top_k=30,
        temperature=0.1,
        repetition_penalty=1.03,
    ),
)

# Plain LLM chain over the QA prompt. It is not referenced again in the part
# of the file visible here; kept because other modules may import it.
llm_chain = LLMChain(prompt=qa_prompt, llm=llm)

# Conversation memory: exposes past turns under "chat_history" as message
# objects and records the chain's 'answer' output as the AI turn.
# NOTE(review): `llm=` is passed to ConversationBufferMemory here; a plain
# buffer memory presumably does not need a model -- confirm it is used at all.
memory = ConversationBufferMemory(
    llm=llm,
    memory_key="chat_history",
    output_key='answer',
    return_messages=True,
)

# --- Conversational retrieval chain -------------------------------------------
# Ties together the model, the retriever and the memory. The answering step
# uses `qa_prompt`, and the stored history is handed on unchanged (identity
# function) instead of being converted by the chain's default formatter.
qa_chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=retriever,
    memory=memory,
    combine_docs_chain_kwargs={'prompt': qa_prompt},
    get_chat_history=lambda history: history,
    verbose=True,
)