Spaces:
Sleeping
Sleeping
Update pdfchatbot.py
Browse files
- pdfchatbot.py +6 -37
pdfchatbot.py
CHANGED
@@ -6,16 +6,15 @@ import weaviate
|
|
6 |
import os
|
7 |
from PIL import Image
|
8 |
from config import MODEL_CONFIG
|
9 |
-
from
|
10 |
from langchain_weaviate.vectorstores import WeaviateVectorStore
|
11 |
from langchain.text_splitter import CharacterTextSplitter
|
12 |
-
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
|
13 |
from langchain.chains import ConversationalRetrievalChain
|
14 |
from langchain_community.document_loaders import PyPDFLoader
|
15 |
from langchain.prompts import PromptTemplate
|
16 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
17 |
|
18 |
os.environ["HUGGINGFACE_API_TOKEN"] = os.getenv("HUGGINGFACE_API_TOKEN")
|
|
|
19 |
|
20 |
class PDFChatBot:
|
21 |
def __init__(self):
|
@@ -66,7 +65,7 @@ class PDFChatBot:
|
|
66 |
"""
|
67 |
Load embeddings from Hugging Face and set in the config file.
|
68 |
"""
|
69 |
-
self.embeddings =
|
70 |
|
71 |
def load_vectordb(self):
|
72 |
"""
|
@@ -82,42 +81,15 @@ class PDFChatBot:
|
|
82 |
|
83 |
self.vectordb = WeaviateVectorStore.from_documents(docs, self.embeddings, client=weaviate_client)
|
84 |
|
85 |
-
def load_tokenizer(self):
|
86 |
-
"""
|
87 |
-
Load the tokenizer from Hugging Face and set in the config file.
|
88 |
-
"""
|
89 |
-
self.tokenizer = AutoTokenizer.from_pretrained(MODEL_CONFIG.AUTO_TOKENIZER, token=os.getenv("HUGGINGFACE_API_TOKEN"))
|
90 |
-
|
91 |
-
def load_model(self):
|
92 |
-
"""
|
93 |
-
Load the causal language model from Hugging Face and set in the config file.
|
94 |
-
"""
|
95 |
-
self.model = AutoModelForCausalLM.from_pretrained(
|
96 |
-
MODEL_CONFIG.MODEL_LLM,
|
97 |
-
device_map='auto',
|
98 |
-
torch_dtype=torch.float32,
|
99 |
-
token=os.getenv("HUGGINGFACE_API_TOKEN"),
|
100 |
-
load_in_8bit=False
|
101 |
-
)
|
102 |
-
|
103 |
-
def create_pipeline(self):
|
104 |
-
"""
|
105 |
-
Create a pipeline for text generation using the loaded model and tokenizer.
|
106 |
-
"""
|
107 |
-
pipe = pipeline(
|
108 |
-
model=self.model,
|
109 |
-
task='text-generation',
|
110 |
-
tokenizer=self.tokenizer,
|
111 |
-
max_new_tokens=200
|
112 |
-
)
|
113 |
-
self.pipeline = HuggingFacePipeline(pipeline=pipe)
|
114 |
|
115 |
def create_chain(self):
|
116 |
"""
|
117 |
Create a Conversational Retrieval Chain
|
118 |
"""
|
|
|
|
|
119 |
self.chain = ConversationalRetrievalChain.from_llm(
|
120 |
-
|
121 |
chain_type="stuff",
|
122 |
retriever=self.vectordb.as_retriever(search_kwargs={"k": 1}),
|
123 |
condense_question_prompt=self.prompt,
|
@@ -135,9 +107,6 @@ class PDFChatBot:
|
|
135 |
self.documents = PyPDFLoader(file.name).load()
|
136 |
self.load_embeddings()
|
137 |
self.load_vectordb()
|
138 |
-
self.load_tokenizer()
|
139 |
-
self.load_model()
|
140 |
-
self.create_pipeline()
|
141 |
self.create_chain()
|
142 |
|
143 |
def generate_response(self, history, query, file):
|
|
|
6 |
import os
|
7 |
from PIL import Image
|
8 |
from config import MODEL_CONFIG
|
9 |
+
from langchain_openai import OpenAIEmbeddings
|
10 |
from langchain_weaviate.vectorstores import WeaviateVectorStore
|
11 |
from langchain.text_splitter import CharacterTextSplitter
|
|
|
12 |
from langchain.chains import ConversationalRetrievalChain
|
13 |
from langchain_community.document_loaders import PyPDFLoader
|
14 |
from langchain.prompts import PromptTemplate
|
|
|
15 |
|
16 |
os.environ["HUGGINGFACE_API_TOKEN"] = os.getenv("HUGGINGFACE_API_TOKEN")
|
17 |
+
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
|
18 |
|
19 |
class PDFChatBot:
|
20 |
def __init__(self):
|
|
|
65 |
"""
|
66 |
Load embeddings from Hugging Face and set in the config file.
|
67 |
"""
|
68 |
+
self.embeddings = OpenAIEmbeddings(model=MODEL_CONFIG.MODEL_EMBEDDINGS)
|
69 |
|
70 |
def load_vectordb(self):
|
71 |
"""
|
|
|
81 |
|
82 |
self.vectordb = WeaviateVectorStore.from_documents(docs, self.embeddings, client=weaviate_client)
|
83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
|
85 |
def create_chain(self):
|
86 |
"""
|
87 |
Create a Conversational Retrieval Chain
|
88 |
"""
|
89 |
+
llm = OpenAI(openai_api_key=os.getenv("OPENAI_API_KEY"))
|
90 |
+
|
91 |
self.chain = ConversationalRetrievalChain.from_llm(
|
92 |
+
llm,
|
93 |
chain_type="stuff",
|
94 |
retriever=self.vectordb.as_retriever(search_kwargs={"k": 1}),
|
95 |
condense_question_prompt=self.prompt,
|
|
|
107 |
self.documents = PyPDFLoader(file.name).load()
|
108 |
self.load_embeddings()
|
109 |
self.load_vectordb()
|
|
|
|
|
|
|
110 |
self.create_chain()
|
111 |
|
112 |
def generate_response(self, history, query, file):
|