Jacksonnavigator7 committed
Commit
16487f7
1 Parent(s): b133f5f

Update app.py

Files changed (1)
  1. app.py +109 -0
app.py CHANGED
@@ -0,0 +1,109 @@
+ import os
+ import streamlit as st
+ import transformers
+ import torch
+ from langchain.llms import HuggingFacePipeline
+ from langchain.document_loaders import WebBaseLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain.vectorstores import FAISS
+ from langchain.chains import ConversationalRetrievalChain
+ from transformers import StoppingCriteria, StoppingCriteriaList
+
+ # Load the Llama model and set up the conversation pipeline
+ model_id = 'meta-llama/Llama-2-7b-chat-hf'
+ # Read the auth token from the environment instead of hardcoding a secret in source
+ hf_auth = os.environ['HF_AUTH_TOKEN']
+
+ # 4-bit quantization config; it only takes effect when passed to from_pretrained below
+ bnb_config = transformers.BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_quant_type='nf4',
+     bnb_4bit_use_double_quant=True,
+     bnb_4bit_compute_dtype=torch.bfloat16
+ )
+
+ # Load Llama model
+ model_config = transformers.AutoConfig.from_pretrained(model_id, use_auth_token=hf_auth)
+ model = transformers.AutoModelForCausalLM.from_pretrained(
+     model_id,
+     trust_remote_code=True,
+     config=model_config,
+     quantization_config=bnb_config,
+     device_map='auto',
+     use_auth_token=hf_auth
+ )
+
+ # Initialize the tokenizer
+ tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, use_auth_token=hf_auth)
+
+ # Stop generation once one of these sequences appears in the output
+ stop_list = ['\nHuman:', '\n```\n']
+ stop_token_ids = [tokenizer(x)['input_ids'] for x in stop_list]
+ stop_token_ids = [torch.LongTensor(x).to('cuda') for x in stop_token_ids]
+
+ # StoppingCriteria is abstract, so subclass it and test for the stop sequences
+ class StopOnTokens(StoppingCriteria):
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+         for stop_ids in stop_token_ids:
+             if torch.eq(input_ids[0][-len(stop_ids):], stop_ids).all():
+                 return True
+         return False
+
+ stopping_criteria = StoppingCriteriaList([StopOnTokens()])
+
+ # Wrap the model in a transformers text-generation pipeline for LangChain
+ generate_text = transformers.pipeline(
+     model=model,
+     tokenizer=tokenizer,
+     return_full_text=True,  # LangChain expects the full text, not just the completion
+     task='text-generation',
+     stopping_criteria=stopping_criteria,
+     temperature=0.1,
+     max_new_tokens=512,
+     repetition_penalty=1.1
+ )
+
+ llm = HuggingFacePipeline(pipeline=generate_text)
+
+ # Load source documents
+ web_links = [
+     "https://www.techtarget.com/whatis/definition/transistor",
+     "https://en.wikipedia.org/wiki/Transistor",
+     # Add more source links as needed
+ ]
+
+ loader = WebBaseLoader(web_links)
+ documents = loader.load()
+
+ # Split source documents
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=20)
+ all_splits = text_splitter.split_documents(documents)
+
+ # Create embeddings and vector store
+ model_name = "sentence-transformers/all-mpnet-base-v2"
+ model_kwargs = {"device": "cuda"}
+
+ embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)
+ vectorstore = FAISS.from_documents(all_splits, embeddings)
+
+ # Create the conversation retrieval chain
+ chain = ConversationalRetrievalChain.from_llm(llm, vectorstore.as_retriever(), return_source_documents=True)
+
+ # Streamlit app
+ def main():
+     st.title("AI Chatbot")
+
+     user_question = st.text_input("Ask a question:")
+
+     # Placeholder source picker; selected_source is not yet wired into the chain
+     sources = [
+         "Source 1",
+         "Source 2",
+         "Source 3",
+         # Add more sources as needed
+     ]
+     selected_source = st.selectbox("Select a source:", sources)
+
+     if st.button("Get Answer"):
+         # Streamlit reruns the script on every interaction, so this history
+         # starts empty on each click; see the session-state sketch below
+         chat_history = []
+
+         query = user_question
+         result = chain({"question": query, "chat_history": chat_history})
+
+         st.write("Answer:", result["answer"])
+
+         chat_history.append((query, result["answer"]))
+
+         if "source_documents" in result:
+             st.write("Source Documents:")
+             for source_doc in result["source_documents"]:
+                 st.write(source_doc)
+
+ if __name__ == "__main__":
+     main()
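
A note on the diff above: because Streamlit reruns the whole script on every interaction, the chat_history list created inside the button handler starts empty on each click, so the chain never sees earlier turns. Below is a minimal sketch of how the history could instead be persisted with Streamlit's st.session_state; the key name "chat_history" is an arbitrary choice for illustration, not something from this commit, and chain is assumed to be the ConversationalRetrievalChain built in app.py.

import streamlit as st

# Keep conversation turns alive across Streamlit reruns within one session
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

user_question = st.text_input("Ask a question:")
if st.button("Get Answer") and user_question:
    # chain is the ConversationalRetrievalChain constructed in app.py above
    result = chain({
        "question": user_question,
        "chat_history": st.session_state.chat_history,
    })
    st.session_state.chat_history.append((user_question, result["answer"]))
    st.write("Answer:", result["answer"])

As usual for Streamlit apps, the file is launched with "streamlit run app.py", after exporting the auth token (here read from HF_AUTH_TOKEN) in the environment.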