# Restaurant RAG chatbot script (scraped from a Hugging Face Space page;
# the original page header read "Spaces: Runtime error").
from langchain_community.document_loaders import TextLoader
import os

# Menu data lives in Data/0.txt ... Data/<NUM_MENU_FILES - 1>.txt.
NUM_MENU_FILES = 12
folder_path = "Data"

# One TextLoader per menu file.
loaders = [
    TextLoader(os.path.join(folder_path, "{}.txt".format(i)))
    for i in range(NUM_MENU_FILES)
]

# Flatten every file's contents into a single list of Document objects.
docs = []
for loader in loaders:
    docs.extend(loader.load())
# Fix: import Chroma from langchain_community (the langchain.vectorstores
# path is deprecated), consistent with the other imports in this file.
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings

# Hugging Face Inference API token must be present in the environment.
HF_TOKEN = os.getenv("HF_TOKEN")

# Remote embedding model — no local weight download required.
embeddings = HuggingFaceInferenceAPIEmbeddings(
    api_key=HF_TOKEN,
    model_name="sentence-transformers/all-mpnet-base-v2",
)

# In-memory Chroma index built from the menu documents.
# NOTE(review): no persist_directory — the index is rebuilt on every run.
vectordb = Chroma.from_documents(
    documents=docs,
    embedding=embeddings,
)
from langchain_community.llms import HuggingFaceHub

# Generation settings: low temperature and a small top_k keep answers
# deterministic and focused on the retrieved menu context.
generation_kwargs = {
    "max_new_tokens": 7000,
    "top_k": 5,
    "temperature": 0.1,
    "repetition_penalty": 1.03,
}

# Instruction-tuned Gemma 2B served through the Hugging Face Hub API.
llm = HuggingFaceHub(
    repo_id="google/gemma-1.1-2b-it",
    task="text-generation",
    model_kwargs=generation_kwargs,
    huggingfacehub_api_token=HF_TOKEN,
)
from langchain.prompts import PromptTemplate

# Prompt for the answer-generation step: retrieved menu snippets are
# injected as {context}, the (possibly reformulated) user query as
# {question}.  Fix: corrected the typo "answe" -> "answer".
template = """You are a Chatbot at a Restaurant. Help the customer pick the right dish to order. The items in the context are dishes. The field below the item is the cost of the dish. About is the description of the dish. Use the context below to answer the questions
{context}
Question: {question}
Helpful Answer:"""

QA_CHAIN_PROMPT = PromptTemplate(
    input_variables=["context", "question"],
    template=template,
)
from langchain.memory import ConversationBufferMemory

# Full-transcript conversation buffer; exposes prior turns as message
# objects (return_messages=True) under the "chat_history" key.
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
from langchain.chains import ConversationalRetrievalChain

# Expose the vector store through the standard retriever interface.
retriever = vectordb.as_retriever()

# Memory-backed retrieval chain.
# NOTE(review): `qa` is never referenced later in this chunk — `predict`
# uses the hand-built `rag_chain` below instead; confirm it is still needed.
qa = ConversationalRetrievalChain.from_llm(
    llm,
    retriever=retriever,
    memory=memory,
)
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

# System instruction for the question-rewriting step: fold the chat
# history into a standalone question without answering it.
contextualize_q_system_prompt = """Given a chat history and the latest user question \
which might reference context in the chat history, formulate a standalone question \
which can be understood without the chat history. Do NOT answer the question, \
just reformulate it if needed and otherwise return it as is."""

# Layout: system instruction, then prior turns, then the new user question.
contextualize_q_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", contextualize_q_system_prompt),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{question}"),
    ]
)

# prompt -> LLM -> plain string.
contextualize_q_chain = contextualize_q_prompt | llm | StrOutputParser()
def contextualized_question(input: dict):
    """Route the chain input: rewrite the question only when history exists.

    Returns the question-reformulation chain when `input` carries a
    non-empty "chat_history"; otherwise passes the raw question through
    unchanged.
    """
    if not input.get("chat_history"):
        return input["question"]
    return contextualize_q_chain
# Full RAG pipeline: resolve the standalone question, retrieve matching
# menu documents into {context}, then generate with the QA prompt.
rag_chain = (
    RunnablePassthrough.assign(context=contextualized_question | retriever)
    | QA_CHAIN_PROMPT
    | llm
)
from langchain_core.messages import AIMessage, HumanMessage

# Enable Weights & Biases tracing for LangChain runs.
os.environ["LANGCHAIN_WANDB_TRACING"] = "true"
os.environ["WANDB_PROJECT"] = "Restaurant_ChatBot"

print("Welcome to the Restaurant. How can I help you today?")

# Running transcript shared across predict() calls.
chat_history = []
def predict(message):
    """Answer one customer message and record the turn in chat_history.

    Invokes the RAG chain with the running history, appends the exchange
    as proper message objects, and returns the model output trimmed to
    start at the "Answer" marker when that marker is present.
    """
    ai_msg = rag_chain.invoke({"question": message, "chat_history": chat_history})
    # Fix: store the reply as an AIMessage — the original appended a raw
    # string, which breaks MessagesPlaceholder formatting on later turns.
    chat_history.extend([HumanMessage(content=message), AIMessage(content=ai_msg)])
    # Fix: str.find returns -1 when the marker is absent, and ai_msg[-1:]
    # would return only the final character; fall back to the full reply.
    idx = ai_msg.find("Answer")
    return ai_msg[idx:] if idx != -1 else ai_msg