kevin-pek commited on
Commit
0f7eddb
1 Parent(s): 91855c2

changed to langchain

Browse files
Files changed (2) hide show
  1. main.py +34 -37
  2. requirements.txt +57 -52
main.py CHANGED
@@ -1,45 +1,42 @@
1
- from haystack.nodes import PreProcessor, PDFToTextConverter, EmbeddingRetriever, TransformersReader
2
- from haystack.document_stores import InMemoryDocumentStore
3
- from haystack.pipelines import DocumentSearchPipeline, ExtractiveQAPipeline
 
 
 
 
 
 
4
  import gradio as gr
5
 
6
- preprocessor = PreProcessor(
7
- clean_empty_lines=True,
8
- clean_whitespace=True,
9
- clean_header_footer=True,
10
- split_by="word",
11
- split_length=100,
12
- split_respect_sentence_boundary=True,
13
- split_overlap=3
14
- )
15
- document_store = InMemoryDocumentStore(embedding_dim=384)
16
- reader = TransformersReader("sentence-transformers/all-MiniLM-L6-v2")
17
- retriever = EmbeddingRetriever(document_store=document_store, embedding_model="sentence-transformers/all-MiniLM-L6-v2")
18
- pipeline = ExtractiveQAPipeline(reader, retriever)
19
- converter = PDFToTextConverter(remove_numeric_tables=True)
20
 
21
- def print_answers(results):
22
- fields = ["answer", "score"] # "context"
23
- answers = results["answers"]
24
- filtered_answers = []
25
- for ans in answers:
26
- filtered_ans = {
27
- field: getattr(ans, field) for field in fields if getattr(ans, field) is not None
28
- }
29
- filtered_answers.append(filtered_ans)
30
- return filtered_answers
31
 
32
- def write_pdf(pdf_file):
33
- document = converter.convert(file_path=pdf_file.name, meta=None)[0]
34
- preprocessed_docs = preprocessor.process(document)
35
- document_store.write_documents(preprocessed_docs)
36
- document_store.update_embeddings(retriever)
 
 
 
 
 
 
 
 
 
 
 
37
 
38
- def predict(question, pdf_file):
39
- write_pdf(pdf_file)
40
- result = pipeline.run(query=question, params={"Retriever": { "top_k": 2 }})
41
- answers = print_answers(result)
42
- return answers
43
 
44
  interface = gr.Interface(
45
  fn=predict,
 
1
+ import os
2
+ import sys
3
+
4
+ from langchain.chains import RetrievalQA
5
+ from langchain.document_loaders import DirectoryLoader, TextLoader
6
+ from langchain.embeddings import HuggingFaceEmbeddings
7
+ from langchain.indexes import VectorstoreIndexCreator
8
+ from langchain.text_splitter import CharacterTextSplitter
9
+ from langchain.vectorstores import FAISS
10
  import gradio as gr
11
 
12
+ # Enable to cache & reuse the model to disk (for repeated queries on the same data)
13
+ PERSIST = False
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
+ query = sys.argv[1]
16
+ embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
 
 
 
 
 
 
 
 
17
 
18
+ if PERSIST and os.path.exists("persist"):
19
+ print("Reusing index...\n")
20
+ raw_documents = DirectoryLoader("persist").load()
21
+ text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
22
+ documents = text_splitter.split_documents(raw_documents)
23
+ vectorstore = FAISS.from_documents(documents, embedding=embeddings)
24
+ from langchain.indexes.vectorstore import VectorStoreIndexWrapper
25
+ index = VectorStoreIndexWrapper(vectorstore=vectorstore)
26
+ else:
27
+ loader = TextLoader('data.txt')
28
+ # This code can also import folders, including various filetypes like PDFs using the DirectoryLoader.
29
+ # loader = DirectoryLoader(".", glob="*.txt")
30
+ if PERSIST:
31
+ index = VectorstoreIndexCreator(vectorstore_kwargs={"persist_directory":"persist"}).from_loaders([loader])
32
+ else:
33
+ index = VectorstoreIndexCreator().from_loaders([loader])
34
 
35
+ chain = RetrievalQA.from_chain_type(
36
+ llm=,
37
+ retriever=index.vectorstore.as_retriever(search_kwargs={"k": 5}),
38
+ )
39
+ print(chain.run(query))
40
 
41
  interface = gr.Interface(
42
  fn=predict,
requirements.txt CHANGED
@@ -1,73 +1,78 @@
1
- accelerate==0.19.0
2
- appdirs==1.4.4
 
 
 
 
3
  attrs==23.1.0
4
- azure-ai-formrecognizer==3.3.0b1
5
- azure-common==1.1.28
6
- azure-core==1.27.0
7
- backoff==2.2.1
8
- boilerpy3==1.0.6
9
- canals==0.2.2
10
- cattrs==23.1.2
11
  certifi==2023.5.7
12
  charset-normalizer==3.1.0
13
  click==8.1.3
14
- dill==0.3.6
15
- docopt==0.6.2
16
- Events==0.4
17
- farm-haystack==1.17.1
18
- filelock==3.12.0
19
- fsspec==2023.5.0
20
- generalimport==0.3.1
 
 
 
 
 
 
 
 
 
21
  huggingface-hub==0.15.1
22
  idna==3.4
23
- inflect==6.0.4
24
- isodate==0.6.1
25
  Jinja2==3.1.2
26
- joblib==1.2.0
27
  jsonschema==4.17.3
 
 
 
 
 
28
  MarkupSafe==2.1.3
29
- monotonic==1.6
30
- more-itertools==9.1.0
31
- mpmath==1.3.0
32
- msrest==0.7.1
33
- networkx==3.1
34
- nltk==3.8.1
35
- num2words==0.5.12
36
- numpy==1.24.3
37
- oauthlib==3.2.2
 
 
38
  packaging==23.1
39
  pandas==2.0.2
40
  Pillow==9.5.0
41
- posthog==3.0.1
42
- protobuf==3.20.2
43
- psutil==5.9.5
44
- pydantic==1.10.8
45
  pyrsistent==0.19.3
46
  python-dateutil==2.8.2
 
47
  pytz==2023.3
48
  PyYAML==6.0
49
- quantulum3==0.9.0
50
- rank-bm25==0.2.2
51
- regex==2023.6.3
52
  requests==2.31.0
53
- requests-cache==0.9.8
54
- requests-oauthlib==1.3.1
55
- scikit-learn==1.2.2
56
- scipy==1.10.1
57
- sentence-transformers==2.2.2
58
- sentencepiece==0.1.99
59
  six==1.16.0
60
- sseclient-py==1.7.2
61
- sympy==1.12
 
62
  tenacity==8.2.2
63
- threadpoolctl==3.1.0
64
- tiktoken==0.4.0
65
- tokenizers==0.13.3
66
- torch==2.0.1
67
- torchvision==0.15.2
68
  tqdm==4.65.0
69
- transformers==4.29.1
70
- typing_extensions==4.5.0
71
  tzdata==2023.3
72
- url-normalize==1.4.3
73
- urllib3==2.0.2
 
 
 
 
 
1
+ aiofiles==23.1.0
2
+ aiohttp==3.8.4
3
+ aiosignal==1.3.1
4
+ altair==5.0.1
5
+ anyio==3.7.0
6
+ async-timeout==4.0.2
7
  attrs==23.1.0
 
 
 
 
 
 
 
8
  certifi==2023.5.7
9
  charset-normalizer==3.1.0
10
  click==8.1.3
11
+ contourpy==1.1.0
12
+ cycler==0.11.0
13
+ dataclasses-json==0.5.8
14
+ exceptiongroup==1.1.1
15
+ faiss-cpu==1.7.4
16
+ fastapi==0.97.0
17
+ ffmpy==0.3.0
18
+ filelock==3.12.2
19
+ fonttools==4.40.0
20
+ frozenlist==1.3.3
21
+ fsspec==2023.6.0
22
+ gradio==3.35.2
23
+ gradio_client==0.2.7
24
+ h11==0.14.0
25
+ httpcore==0.17.2
26
+ httpx==0.24.1
27
  huggingface-hub==0.15.1
28
  idna==3.4
29
+ importlib-resources==5.12.0
 
30
  Jinja2==3.1.2
 
31
  jsonschema==4.17.3
32
+ kiwisolver==1.4.4
33
+ langchain==0.0.205
34
+ langchainplus-sdk==0.0.16
35
+ linkify-it-py==2.0.2
36
+ markdown-it-py==2.2.0
37
  MarkupSafe==2.1.3
38
+ marshmallow==3.19.0
39
+ marshmallow-enum==1.5.1
40
+ matplotlib==3.7.1
41
+ mdit-py-plugins==0.3.3
42
+ mdurl==0.1.2
43
+ multidict==6.0.4
44
+ mypy-extensions==1.0.0
45
+ numexpr==2.8.4
46
+ numpy==1.25.0
47
+ openapi-schema-pydantic==1.2.4
48
+ orjson==3.9.1
49
  packaging==23.1
50
  pandas==2.0.2
51
  Pillow==9.5.0
52
+ pydantic==1.10.9
53
+ pydub==0.25.1
54
+ Pygments==2.15.1
55
+ pyparsing==3.1.0
56
  pyrsistent==0.19.3
57
  python-dateutil==2.8.2
58
+ python-multipart==0.0.6
59
  pytz==2023.3
60
  PyYAML==6.0
 
 
 
61
  requests==2.31.0
62
+ semantic-version==2.10.0
 
 
 
 
 
63
  six==1.16.0
64
+ sniffio==1.3.0
65
+ SQLAlchemy==2.0.16
66
+ starlette==0.27.0
67
  tenacity==8.2.2
68
+ toolz==0.12.0
 
 
 
 
69
  tqdm==4.65.0
70
+ typing-inspect==0.9.0
71
+ typing_extensions==4.6.3
72
  tzdata==2023.3
73
+ uc-micro-py==1.0.2
74
+ urllib3==2.0.3
75
+ uvicorn==0.22.0
76
+ websockets==11.0.3
77
+ yarl==1.9.2
78
+ zipp==3.15.0