vtiyyal1 committed on
Commit
05e0faf
·
verified ·
1 Parent(s): 229d228

Upload 3 files

Browse files

fixed rerank to open ai

Files changed (2) hide show
  1. full_chain.py +4 -3
  2. rerank.py +77 -72
full_chain.py CHANGED
@@ -2,13 +2,14 @@ import os
2
  import pandas as pd
3
  from get_keywords import get_keywords
4
  from get_articles import save_solr_articles_full
5
- from rerank import langchain_rerank_answer, langchain_with_sources, crossencoder_rerank_answer, \
6
- crossencoder_rerank_sentencewise, crossencoder_rerank_sentencewise_articles, no_rerank
7
  #from feed_to_llm import feed_articles_to_gpt_with_links
 
8
  from feed_to_llm_v2 import feed_articles_to_gpt_with_links
9
 
10
  def get_response(question, rerank_type="crossencoder", llm_type="chat"):
11
- csv_path = save_solr_articles_full(question, keyword_type="rake", num_articles=10)
12
  reranked_out = crossencoder_rerank_answer(csv_path, question)
13
  return feed_articles_to_gpt_with_links(reranked_out, question)
14
 
 
2
  import pandas as pd
3
  from get_keywords import get_keywords
4
  from get_articles import save_solr_articles_full
5
+ # from rerank import langchain_rerank_answer, langchain_with_sources, crossencoder_rerank_answer, \
6
+ # crossencoder_rerank_sentencewise, crossencoder_rerank_sentencewise_articles, no_rerank
7
  #from feed_to_llm import feed_articles_to_gpt_with_links
8
+ from rerank import crossencoder_rerank_answer
9
  from feed_to_llm_v2 import feed_articles_to_gpt_with_links
10
 
11
  def get_response(question, rerank_type="crossencoder", llm_type="chat"):
12
+ csv_path = save_solr_articles_full(question, keyword_type="rake", num_articles=15)
13
  reranked_out = crossencoder_rerank_answer(csv_path, question)
14
  return feed_articles_to_gpt_with_links(reranked_out, question)
15
 
rerank.py CHANGED
@@ -1,11 +1,16 @@
1
  # reranks the top articles from a given csv file
2
- from langchain_openai import ChatOpenAI
3
- from langchain.chains import RetrievalQA
4
- from langchain_community.document_loaders.csv_loader import CSVLoader
5
- from langchain_community.vectorstores import DocArrayInMemorySearch
6
  from sentence_transformers import CrossEncoder
7
  import pandas as pd
8
  import time
 
 
 
 
 
9
 
10
  """
11
  This function rerank top articles (15 -> 4) from a given csv, then sends to LLM
@@ -24,73 +29,73 @@ Update: Use langchain_RAG instead.
24
  """
25
 
26
 
27
- def langchain_rerank_answer(csv_path, question, source='url', top_n=4):
28
- llm = ChatOpenAI(temperature=0.0)
29
- loader = CSVLoader(csv_path, source_column="url")
30
-
31
- index = VectorstoreIndexCreator(
32
- vectorstore_cls=DocArrayInMemorySearch,
33
- ).from_loaders([loader])
34
-
35
- # prompt_template = """You are an a chatbot that answers tobacco related questions with source. Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
36
- # {context}
37
- # Question: {question}"""
38
- # PROMPT = PromptTemplate(
39
- # template=prompt_template, input_variables=["context", "question"]
40
- # )
41
- # chain_type_kwargs = {"prompt": PROMPT}
42
-
43
- qa = RetrievalQA.from_chain_type(
44
- llm=llm,
45
- chain_type="stuff",
46
- retriever=index.vectorstore.as_retriever(),
47
- verbose=False,
48
- return_source_documents=True,
49
- # chain_type_kwargs=chain_type_kwargs,
50
- # chain_type_kwargs = {
51
- # "document_separator": "<<<<>>>>>"
52
- # },
53
- )
54
-
55
- answer = qa({"query": question})
56
- sources = answer['source_documents']
57
- sources_out = [source.metadata['source'] for source in sources]
58
-
59
- return answer['result'], sources_out
60
-
61
-
62
- """
63
- Langchain with sources.
64
- This function is deprecated. Use langchain_RAG instead.
65
- """
66
-
67
-
68
- def langchain_with_sources(csv_path, question, top_n=4):
69
- llm = ChatOpenAI(temperature=0.0)
70
- loader = CSVLoader(csv_path, source_column="uuid")
71
- index = VectorstoreIndexCreator(
72
- vectorstore_cls=DocArrayInMemorySearch,
73
- ).from_loaders([loader])
74
-
75
- qa = RetrievalQAWithSourcesChain.from_chain_type(
76
- llm=llm,
77
- chain_type="stuff",
78
- retriever=index.vectorstore.as_retriever(),
79
- )
80
- output = qa({"question": question}, return_only_outputs=True)
81
- return output['answer'], output['sources']
82
-
83
-
84
- """
85
- Reranks the top articles using crossencoder.
86
- Uses cross-encoder/ms-marco-MiniLM-L-6-v2 for embedding / reranking.
87
- Input:
88
- csv_path: str
89
- question: str
90
- top_n: int
91
- Output:
92
- out_values: list of [content, uuid, title]
93
- """
94
 
95
 
96
  # returns list of top n similar articles using crossencoder
@@ -187,7 +192,7 @@ def crossencoder_rerank_sentencewise_sentence_chunks(csv_path, question, top_n=1
187
  new_uuids = []
188
  new_titles = []
189
  new_domains = []
190
-
191
  for idx in range(len(contents)):
192
  sents = sent_tokenize(contents[idx])
193
  sents_merged = []
 
1
  # reranks the top articles from a given csv file
2
+ # from langchain_openai import ChatOpenAI
3
+ # from langchain.chains import RetrievalQA
4
+ # from langchain_community.document_loaders.csv_loader import CSVLoader
5
+ # from langchain_community.vectorstores import DocArrayInMemorySearch
6
  from sentence_transformers import CrossEncoder
7
  import pandas as pd
8
  import time
9
+ import nltk
10
+ nltk.download('stopwords')
11
+ nltk.download('punkt')
12
+ from nltk.tokenize import sent_tokenize
13
+
14
 
15
  """
16
  This function rerank top articles (15 -> 4) from a given csv, then sends to LLM
 
29
  """
30
 
31
 
32
+ # def langchain_rerank_answer(csv_path, question, source='url', top_n=4):
33
+ # llm = ChatOpenAI(temperature=0.0)
34
+ # loader = CSVLoader(csv_path, source_column="url")
35
+
36
+ # index = VectorstoreIndexCreator(
37
+ # vectorstore_cls=DocArrayInMemorySearch,
38
+ # ).from_loaders([loader])
39
+
40
+ # # prompt_template = """You are an a chatbot that answers tobacco related questions with source. Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
41
+ # # {context}
42
+ # # Question: {question}"""
43
+ # # PROMPT = PromptTemplate(
44
+ # # template=prompt_template, input_variables=["context", "question"]
45
+ # # )
46
+ # # chain_type_kwargs = {"prompt": PROMPT}
47
+
48
+ # qa = RetrievalQA.from_chain_type(
49
+ # llm=llm,
50
+ # chain_type="stuff",
51
+ # retriever=index.vectorstore.as_retriever(),
52
+ # verbose=False,
53
+ # return_source_documents=True,
54
+ # # chain_type_kwargs=chain_type_kwargs,
55
+ # # chain_type_kwargs = {
56
+ # # "document_separator": "<<<<>>>>>"
57
+ # # },
58
+ # )
59
+
60
+ # answer = qa({"query": question})
61
+ # sources = answer['source_documents']
62
+ # sources_out = [source.metadata['source'] for source in sources]
63
+
64
+ # return answer['result'], sources_out
65
+
66
+
67
+ # """
68
+ # Langchain with sources.
69
+ # This function is deprecated. Use langchain_RAG instead.
70
+ # """
71
+
72
+
73
+ # def langchain_with_sources(csv_path, question, top_n=4):
74
+ # llm = ChatOpenAI(temperature=0.0)
75
+ # loader = CSVLoader(csv_path, source_column="uuid")
76
+ # index = VectorstoreIndexCreator(
77
+ # vectorstore_cls=DocArrayInMemorySearch,
78
+ # ).from_loaders([loader])
79
+
80
+ # qa = RetrievalQAWithSourcesChain.from_chain_type(
81
+ # llm=llm,
82
+ # chain_type="stuff",
83
+ # retriever=index.vectorstore.as_retriever(),
84
+ # )
85
+ # output = qa({"question": question}, return_only_outputs=True)
86
+ # return output['answer'], output['sources']
87
+
88
+
89
+ # """
90
+ # Reranks the top articles using crossencoder.
91
+ # Uses cross-encoder/ms-marco-MiniLM-L-6-v2 for embedding / reranking.
92
+ # Input:
93
+ # csv_path: str
94
+ # question: str
95
+ # top_n: int
96
+ # Output:
97
+ # out_values: list of [content, uuid, title]
98
+ # """
99
 
100
 
101
  # returns list of top n similar articles using crossencoder
 
192
  new_uuids = []
193
  new_titles = []
194
  new_domains = []
195
+
196
  for idx in range(len(contents)):
197
  sents = sent_tokenize(contents[idx])
198
  sents_merged = []