import os

import streamlit as st
import transformers
import torch
from langchain.llms import HuggingFacePipeline
from langchain.document_loaders import WebBaseLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain
from transformers import StoppingCriteria, StoppingCriteriaList
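
# RAG chatbot demo: a quantized Llama 2 chat model answers questions over a small set of
# web pages using LangChain retrieval, with a Streamlit front end.
# Note: everything below runs at module level, so Streamlit re-executes it on every
# interaction; caching the heavy objects (e.g. with st.cache_resource) would avoid
# reloading the model and rebuilding the index each time.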

# Model to load and the Hugging Face access token used to download it
model_id = 'meta-llama/Llama-2-7b-chat-hf'
# Read the access token from the environment instead of hardcoding it in the source
hf_auth = os.environ.get('HF_AUTH_TOKEN')

# 4-bit quantization (bitsandbytes) so the 7B model fits in a single GPU's memory
bnb_config = transformers.BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16
)

# Load the Llama 2 model with the quantization config applied
model_config = transformers.AutoConfig.from_pretrained(model_id, use_auth_token=hf_auth)
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    config=model_config,
    quantization_config=bnb_config,
    device_map='auto',
    use_auth_token=hf_auth
)
model.eval()

# Tokenizer for the same checkpoint
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, use_auth_token=hf_auth)

# Stop generation when the model starts a new "Human:" turn or closes a code fence
stop_list = ['\nHuman:', '\n```\n']
stop_token_ids = [torch.LongTensor(tokenizer(x)['input_ids']).to('cuda') for x in stop_list]

# StoppingCriteria is an abstract base class; subclass it to stop on any of the stop sequences
class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids, scores, **kwargs):
        return any(torch.eq(input_ids[0][-len(ids):], ids).all() for ids in stop_token_ids)

stopping_criteria = StoppingCriteriaList([StopOnTokens()])

generate_text = transformers.pipeline(
    model=model,
    tokenizer=tokenizer,
    return_full_text=True,
    task='text-generation',
    stopping_criteria=stopping_criteria,
    temperature=0.1,
    max_new_tokens=512,
    repetition_penalty=1.1
)

llm = HuggingFacePipeline(pipeline=generate_text)

# Load source documents
web_links = ["https://www.techtarget.com/whatis/definition/transistor",
             "https://en.wikipedia.org/wiki/Transistor",
             # Add more source links as needed
             ]

loader = WebBaseLoader(web_links)
documents = loader.load()

# Split source documents
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=20)
all_splits = text_splitter.split_documents(documents)

# Create embeddings and vector store
model_name = "sentence-transformers/all-mpnet-base-v2"
model_kwargs = {"device": "cuda"}

embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)
vectorstore = FAISS.from_documents(all_splits, embeddings)

# Create the conversational retrieval chain over the FAISS retriever
chain = ConversationalRetrievalChain.from_llm(llm, vectorstore.as_retriever(), return_source_documents=True)

# Streamlit app
def main():
    st.title("AI Chatbot")

    user_question = st.text_input("Ask a question:")

    sources = [
        "Source 1",
        "Source 2",
        "Source 3",
        # Add more sources as needed
    ]
    selected_source = st.selectbox("Select a source:", sources)
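    # Note: selected_source is not wired into retrieval here; the chain always searches
    # the full FAISS index built from web_links above.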

    # Keep the conversation history in Streamlit session state so that follow-up
    # questions retain context across reruns
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []

    if st.button("Get Answer"):
        query = user_question
        result = chain({"question": query, "chat_history": st.session_state.chat_history})

        st.write("Answer:", result["answer"])

        st.session_state.chat_history.append((query, result["answer"]))

        if "source_documents" in result:
            st.write("Source Documents:")
            for source_doc in result["source_documents"]:
                st.write(source_doc)

if __name__ == "__main__":
    main()
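
# Launch the app with: streamlit run <path to this file>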