Upload 3 files
Browse files
fixed rerank to open ai
- full_chain.py +4 -3
- rerank.py +77 -72
full_chain.py
CHANGED
@@ -2,13 +2,14 @@ import os
|
|
2 |
import pandas as pd
|
3 |
from get_keywords import get_keywords
|
4 |
from get_articles import save_solr_articles_full
|
5 |
-
from rerank import langchain_rerank_answer, langchain_with_sources, crossencoder_rerank_answer, \
|
6 |
-
|
7 |
#from feed_to_llm import feed_articles_to_gpt_with_links
|
|
|
8 |
from feed_to_llm_v2 import feed_articles_to_gpt_with_links
|
9 |
|
10 |
def get_response(question, rerank_type="crossencoder", llm_type="chat"):
|
11 |
-
csv_path = save_solr_articles_full(question, keyword_type="rake", num_articles=
|
12 |
reranked_out = crossencoder_rerank_answer(csv_path, question)
|
13 |
return feed_articles_to_gpt_with_links(reranked_out, question)
|
14 |
|
|
|
2 |
import pandas as pd
|
3 |
from get_keywords import get_keywords
|
4 |
from get_articles import save_solr_articles_full
|
5 |
+
# from rerank import langchain_rerank_answer, langchain_with_sources, crossencoder_rerank_answer, \
|
6 |
+
# crossencoder_rerank_sentencewise, crossencoder_rerank_sentencewise_articles, no_rerank
|
7 |
#from feed_to_llm import feed_articles_to_gpt_with_links
|
8 |
+
from rerank import crossencoder_rerank_answer
|
9 |
from feed_to_llm_v2 import feed_articles_to_gpt_with_links
|
10 |
|
11 |
def get_response(question, rerank_type="crossencoder", llm_type="chat"):
|
12 |
+
csv_path = save_solr_articles_full(question, keyword_type="rake", num_articles=15)
|
13 |
reranked_out = crossencoder_rerank_answer(csv_path, question)
|
14 |
return feed_articles_to_gpt_with_links(reranked_out, question)
|
15 |
|
rerank.py
CHANGED
@@ -1,11 +1,16 @@
|
|
1 |
# reranks the top articles from a given csv file
|
2 |
-
from langchain_openai import ChatOpenAI
|
3 |
-
from langchain.chains import RetrievalQA
|
4 |
-
from langchain_community.document_loaders.csv_loader import CSVLoader
|
5 |
-
from langchain_community.vectorstores import DocArrayInMemorySearch
|
6 |
from sentence_transformers import CrossEncoder
|
7 |
import pandas as pd
|
8 |
import time
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
"""
|
11 |
This function reranks top articles (15 -> 4) from a given csv, then sends to LLM
|
@@ -24,73 +29,73 @@ Update: Use langchain_RAG instead.
|
|
24 |
"""
|
25 |
|
26 |
|
27 |
-
def langchain_rerank_answer(csv_path, question, source='url', top_n=4):
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
"""
|
63 |
-
|
64 |
-
|
65 |
-
"""
|
66 |
-
|
67 |
-
|
68 |
-
def langchain_with_sources(csv_path, question, top_n=4):
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
"""
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
"""
|
94 |
|
95 |
|
96 |
# returns list of top n similar articles using crossencoder
|
@@ -187,7 +192,7 @@ def crossencoder_rerank_sentencewise_sentence_chunks(csv_path, question, top_n=1
|
|
187 |
new_uuids = []
|
188 |
new_titles = []
|
189 |
new_domains = []
|
190 |
-
|
191 |
for idx in range(len(contents)):
|
192 |
sents = sent_tokenize(contents[idx])
|
193 |
sents_merged = []
|
|
|
1 |
# reranks the top articles from a given csv file
|
2 |
+
# from langchain_openai import ChatOpenAI
|
3 |
+
# from langchain.chains import RetrievalQA
|
4 |
+
# from langchain_community.document_loaders.csv_loader import CSVLoader
|
5 |
+
# from langchain_community.vectorstores import DocArrayInMemorySearch
|
6 |
from sentence_transformers import CrossEncoder
|
7 |
import pandas as pd
|
8 |
import time
|
9 |
+
import nltk
|
10 |
+
nltk.download('stopwords')
|
11 |
+
nltk.download('punkt')
|
12 |
+
from nltk.tokenize import sent_tokenize
|
13 |
+
|
14 |
|
15 |
"""
|
16 |
This function reranks top articles (15 -> 4) from a given csv, then sends to LLM
|
|
|
29 |
"""
|
30 |
|
31 |
|
32 |
+
# def langchain_rerank_answer(csv_path, question, source='url', top_n=4):
|
33 |
+
# llm = ChatOpenAI(temperature=0.0)
|
34 |
+
# loader = CSVLoader(csv_path, source_column="url")
|
35 |
+
|
36 |
+
# index = VectorstoreIndexCreator(
|
37 |
+
# vectorstore_cls=DocArrayInMemorySearch,
|
38 |
+
# ).from_loaders([loader])
|
39 |
+
|
40 |
+
# # prompt_template = """You are an a chatbot that answers tobacco related questions with source. Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
|
41 |
+
# # {context}
|
42 |
+
# # Question: {question}"""
|
43 |
+
# # PROMPT = PromptTemplate(
|
44 |
+
# # template=prompt_template, input_variables=["context", "question"]
|
45 |
+
# # )
|
46 |
+
# # chain_type_kwargs = {"prompt": PROMPT}
|
47 |
+
|
48 |
+
# qa = RetrievalQA.from_chain_type(
|
49 |
+
# llm=llm,
|
50 |
+
# chain_type="stuff",
|
51 |
+
# retriever=index.vectorstore.as_retriever(),
|
52 |
+
# verbose=False,
|
53 |
+
# return_source_documents=True,
|
54 |
+
# # chain_type_kwargs=chain_type_kwargs,
|
55 |
+
# # chain_type_kwargs = {
|
56 |
+
# # "document_separator": "<<<<>>>>>"
|
57 |
+
# # },
|
58 |
+
# )
|
59 |
+
|
60 |
+
# answer = qa({"query": question})
|
61 |
+
# sources = answer['source_documents']
|
62 |
+
# sources_out = [source.metadata['source'] for source in sources]
|
63 |
+
|
64 |
+
# return answer['result'], sources_out
|
65 |
+
|
66 |
+
|
67 |
+
# """
|
68 |
+
# Langchain with sources.
|
69 |
+
# This function is deprecated. Use langchain_RAG instead.
|
70 |
+
# """
|
71 |
+
|
72 |
+
|
73 |
+
# def langchain_with_sources(csv_path, question, top_n=4):
|
74 |
+
# llm = ChatOpenAI(temperature=0.0)
|
75 |
+
# loader = CSVLoader(csv_path, source_column="uuid")
|
76 |
+
# index = VectorstoreIndexCreator(
|
77 |
+
# vectorstore_cls=DocArrayInMemorySearch,
|
78 |
+
# ).from_loaders([loader])
|
79 |
+
|
80 |
+
# qa = RetrievalQAWithSourcesChain.from_chain_type(
|
81 |
+
# llm=llm,
|
82 |
+
# chain_type="stuff",
|
83 |
+
# retriever=index.vectorstore.as_retriever(),
|
84 |
+
# )
|
85 |
+
# output = qa({"question": question}, return_only_outputs=True)
|
86 |
+
# return output['answer'], output['sources']
|
87 |
+
|
88 |
+
|
89 |
+
# """
|
90 |
+
# Reranks the top articles using crossencoder.
|
91 |
+
# Uses cross-encoder/ms-marco-MiniLM-L-6-v2 for embedding / reranking.
|
92 |
+
# Input:
|
93 |
+
# csv_path: str
|
94 |
+
# question: str
|
95 |
+
# top_n: int
|
96 |
+
# Output:
|
97 |
+
# out_values: list of [content, uuid, title]
|
98 |
+
# """
|
99 |
|
100 |
|
101 |
# returns list of top n similar articles using crossencoder
|
|
|
192 |
new_uuids = []
|
193 |
new_titles = []
|
194 |
new_domains = []
|
195 |
+
|
196 |
for idx in range(len(contents)):
|
197 |
sents = sent_tokenize(contents[idx])
|
198 |
sents_merged = []
|