Llamachat / app.py
Jacksonnavigator7's picture
Update app.py
16487f7
raw
history blame
No virus
3.43 kB
import os

import streamlit as st
import torch
import transformers
from langchain.chains import ConversationalRetrievalChain
from langchain.document_loaders import WebBaseLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from transformers import StoppingCriteria, StoppingCriteriaList
# Load the Llama model and set up the text-generation pipeline.
model_id = 'meta-llama/Llama-2-7b-chat-hf'
# Hugging Face access token for the gated Llama-2 repo. It is read from the
# environment instead of being hard-coded: a token committed to source is
# exposed to anyone who can read the file (the previously embedded token
# should be revoked). If HF_TOKEN is unset this is None, and the hub client
# falls back to any locally cached login.
hf_auth = os.environ.get('HF_TOKEN')
# Text at which generation should stop (e.g. the model starting a new turn).
stop_list = ['\nHuman:', '\n```\n']


class StopOnStrings(StoppingCriteria):
    """Stop generation once the decoded output ends with any of ``stop_strings``.

    Matching on decoded text instead of raw token ids avoids two pitfalls:
    tokenizing a stop string on its own prepends special/prefix tokens that
    never occur mid-generation, and pre-built id tensors pinned to 'cuda'
    break when the model is placed elsewhere.

    Args:
        stop_strings: Strings that end generation when they appear at the end
            of the output.
        tokenizer: Tokenizer used to decode the most recent token ids.
        lookback_tokens: How many trailing tokens to decode on each step; must
            cover the longest stop string.
    """

    def __init__(self, stop_strings, tokenizer, lookback_tokens=16):
        self.stop_strings = list(stop_strings)
        self.tokenizer = tokenizer
        self.lookback_tokens = lookback_tokens

    def __call__(self, input_ids, scores, **kwargs):
        # Only the first sequence is inspected. This assumes batch size 1,
        # which is how the chain calls the pipeline — TODO confirm if batching.
        tail = self.tokenizer.decode(
            input_ids[0][-self.lookback_tokens:], skip_special_tokens=True
        )
        return any(tail.endswith(stop) for stop in self.stop_strings)


@st.cache_resource
def _load_generation_pipeline():
    """Load Llama-2 in 4-bit NF4 and wrap it in a text-generation pipeline.

    Cached with ``st.cache_resource`` because Streamlit re-executes this whole
    script on every widget interaction. Without the cache, the 7B model would
    be reloaded on each keystroke or button click.

    Returns:
        A ``(model, tokenizer, pipeline)`` tuple.
    """
    # 4-bit quantization settings. They must be passed to from_pretrained
    # (quantization_config) to take effect.
    bnb_config = transformers.BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type='nf4',
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model_config = transformers.AutoConfig.from_pretrained(model_id, use_auth_token=hf_auth)
    # trust_remote_code is deliberately not set. Llama is natively supported
    # by transformers, and enabling it would let the hub repo run arbitrary
    # code locally.
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_id,
        config=model_config,
        quantization_config=bnb_config,
        device_map='auto',
        use_auth_token=hf_auth,
    )
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, use_auth_token=hf_auth)
    stopping_criteria = StoppingCriteriaList([StopOnStrings(stop_list, tokenizer)])
    generate_text = transformers.pipeline(
        model=model,
        tokenizer=tokenizer,
        return_full_text=True,
        task='text-generation',
        stopping_criteria=stopping_criteria,
        # temperature is ignored under greedy decoding, so sampling must be
        # enabled for it to have any effect.
        do_sample=True,
        temperature=0.1,
        max_new_tokens=512,
        repetition_penalty=1.1,
    )
    return model, tokenizer, generate_text


model, tokenizer, generate_text = _load_generation_pipeline()
llm = HuggingFacePipeline(pipeline=generate_text)
# Web pages the chatbot answers from; add more URLs as needed.
web_links = [
    "https://www.techtarget.com/whatis/definition/transistor",
    "https://en.wikipedia.org/wiki/Transistor",
]
# Sentence-embedding model used to index the document chunks.
model_name = "sentence-transformers/all-mpnet-base-v2"
# Fall back to CPU so embedding still works on machines without CUDA.
model_kwargs = {"device": "cuda" if torch.cuda.is_available() else "cpu"}


@st.cache_resource
def _build_vectorstore():
    """Fetch ``web_links``, split them into chunks and index them in FAISS.

    Cached with ``st.cache_resource`` so the pages are downloaded and embedded
    once per server process, not again on every Streamlit rerun.

    Returns:
        A FAISS vector store over the chunked documents.
    """
    loader = WebBaseLoader(web_links)
    documents = loader.load()
    # Chunks of up to 1000 characters, with a 20-character overlap between
    # neighbouring chunks.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=20)
    all_splits = text_splitter.split_documents(documents)
    embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)
    return FAISS.from_documents(all_splits, embeddings)


vectorstore = _build_vectorstore()
# Conversational retrieval chain over the FAISS index. return_source_documents
# makes each result carry the retrieved chunks under "source_documents".
chain = ConversationalRetrievalChain.from_llm(
    llm, vectorstore.as_retriever(), return_source_documents=True
)
# Streamlit app
def main():
    """Render the chatbot UI and answer the submitted question with ``chain``.

    The conversation history is stored in ``st.session_state``. Streamlit
    re-runs this function on every interaction, so a local list would be
    empty again on each click and follow-up questions would lose context.
    """
    st.title("AI Chatbot")
    user_question = st.text_input("Ask a question:")
    sources = [
        "Source 1",
        "Source 2",
        "Source 3",
        # Add more sources as needed
    ]
    # NOTE(review): placeholder choices. The selection is not passed to the
    # chain, so this widget currently has no effect on the answer.
    st.selectbox("Select a source:", sources)

    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []

    if st.button("Get Answer"):
        query = user_question.strip()
        if not query:
            # Avoid spending a full LLM generation on an empty prompt.
            st.warning("Please enter a question.")
            return
        chat_history = st.session_state.chat_history
        result = chain({"question": query, "chat_history": chat_history})
        st.write("Answer:", result["answer"])
        # Record the exchange so the next question is answered in context.
        chat_history.append((query, result["answer"]))
        if "source_documents" in result:
            st.write("Source Documents:")
            for source_doc in result["source_documents"]:
                st.write(source_doc)


if __name__ == "__main__":
    main()