lferncastro committed
Commit 7426d87 · 1 Parent(s): 5650230

Update app.py

Files changed (1):
  1. app.py (+37 -1)
app.py CHANGED
@@ -1,4 +1,40 @@
- import gradio as gr
+ from langchain.text_splitter import CharacterTextSplitter
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain.vectorstores import Chroma
+ from langchain import HuggingFacePipeline
+ from langchain.chains import RetrievalQA
+ from transformers import AutoTokenizer
+ import pickle
+ import os
+
+ with open('shakespeare.pkl', 'rb') as fp:
+     data = pickle.load(fp)
+
+ bloomz_tokenizer = AutoTokenizer.from_pretrained('bigscience/bloomz-1b7')
+
+ text_splitter = CharacterTextSplitter.from_huggingface_tokenizer(bloomz_tokenizer, chunk_size=100, chunk_overlap=0, separator='\n')
+
+ documents = text_splitter.split_documents(data)
+
+ embeddings = HuggingFaceEmbeddings()
+
+ persist_directory = "vector_db"
+
+ vectordb = Chroma.from_documents(documents=documents, embedding=embeddings, persist_directory=persist_directory)
+
+ vectordb.persist()
+ vectordb = None
+
+ vectordb_persist = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
+
+ llm = HuggingFacePipeline.from_model_id(
+     model_id="bigscience/bloomz-1b7",
+     task="text-generation",
+     model_kwargs={"temperature" : 0, "max_length" : 500})
+
+ doc_retriever = vectordb_persist.as_retriever()
+
+ shakespeare_qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=doc_retriever)
 
  def make_inference(query):
      inference = shakespeare_qa.run(query)
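The diff is cut off inside make_inference, so the rest of the function and the UI wiring are not shown on this page. As a rough sketch only — assuming the Space keeps a simple Gradio front end (the gradio import removed by this commit suggests one existed) and that make_inference simply returns the chain's answer — the remainder of app.py might look something like this:

import gradio as gr

def make_inference(query):
    # Assumption: the function just returns the chain's answer as plain text;
    # only the .run(query) call is visible in the diff above.
    inference = shakespeare_qa.run(query)
    return inference

# Hypothetical Gradio wiring for the Space: one textbox in, one textbox out.
demo = gr.Interface(
    fn=make_inference,
    inputs=gr.Textbox(label="Ask a question about Shakespeare"),
    outputs=gr.Textbox(label="Answer"),
)

if __name__ == "__main__":
    demo.launch()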