anirudh248 committed on
Commit
d2334dd
·
verified ·
1 Parent(s): ad009a2

Update retrieval_qa_pipeline.py

Browse files
Files changed (1) hide show
  1. retrieval_qa_pipeline.py +18 -20
retrieval_qa_pipeline.py CHANGED
@@ -1,5 +1,3 @@
1
- # retrieval_qa_pipeline.py
2
-
3
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
4
  from langchain.embeddings import HuggingFaceEmbeddings
5
  from langchain.vectorstores import FAISS
@@ -10,26 +8,26 @@ from datasets import load_dataset
10
  def load_model_and_tokenizer(model_name: str):
11
  """
12
  Load the pre-trained model and tokenizer from the Hugging Face Hub.
13
-
14
  Args:
15
  model_name (str): The Hugging Face repository name of the model.
16
-
17
  Returns:
18
  model: The loaded model.
19
  tokenizer: The loaded tokenizer.
20
  """
21
  print(f"Loading model and tokenizer from {model_name}...")
22
- model = AutoModelForCausalLM.from_pretrained(model_name)
23
  tokenizer = AutoTokenizer.from_pretrained(model_name)
24
  return model, tokenizer
25
 
26
  def load_dataset_from_hf(dataset_name: str):
27
  """
28
  Load the dataset from the Hugging Face Hub.
29
-
30
  Args:
31
  dataset_name (str): The Hugging Face repository name of the dataset.
32
-
33
  Returns:
34
  texts (list): The text descriptions from the dataset.
35
  metadatas (list): Metadata for each text (e.g., upf_code).
@@ -43,27 +41,27 @@ def load_dataset_from_hf(dataset_name: str):
43
  def load_faiss_index(faiss_index_path: str):
44
  """
45
  Load the FAISS index and associated embeddings.
46
-
47
  Args:
48
  faiss_index_path (str): Path to the saved FAISS index.
49
-
50
  Returns:
51
  vectorstore (FAISS): The FAISS vector store.
52
  """
53
  print(f"Loading FAISS index from {faiss_index_path}...")
54
  embeddings = HuggingFaceEmbeddings() # Default embeddings
55
- vectorstore = FAISS.load_local(faiss_index_path, embeddings)
56
  return vectorstore
57
 
58
  def build_retrieval_qa_pipeline(model, tokenizer, vectorstore):
59
  """
60
  Build the retrieval-based QA pipeline.
61
-
62
  Args:
63
  model: The pre-trained model.
64
  tokenizer: The tokenizer associated with the model.
65
  vectorstore (FAISS): The FAISS vector store for retrieval.
66
-
67
  Returns:
68
  qa_chain (RetrievalQA): The retrieval-based QA pipeline.
69
  """
@@ -77,11 +75,11 @@ def build_retrieval_qa_pipeline(model, tokenizer, vectorstore):
77
  top_p=0.95,
78
  repetition_penalty=1.15
79
  )
80
-
81
  llm = HuggingFacePipeline(pipeline=hf_pipeline)
82
  retriever = vectorstore.as_retriever()
83
  qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)
84
-
85
  return qa_chain
86
 
87
  def main():
@@ -89,21 +87,21 @@ def main():
89
  model_name = "anirudh248/upf_code_generator_final"
90
  dataset_name = "PranavKeshav/upf_dataset"
91
  faiss_index_path = "faiss_index"
92
-
93
  print("Starting pipeline setup...")
94
-
95
  # Load model and tokenizer
96
  model, tokenizer = load_model_and_tokenizer(model_name)
97
-
98
  # Load dataset
99
  texts, metadatas = load_dataset_from_hf(dataset_name)
100
-
101
  # Load FAISS index
102
  vectorstore = load_faiss_index(faiss_index_path)
103
-
104
  # Build QA pipeline
105
  qa_chain = build_retrieval_qa_pipeline(model, tokenizer, vectorstore)
106
-
107
  # Test the pipeline
108
  print("Pipeline is ready! You can now ask questions.")
109
  while True:
 
 
 
1
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
2
  from langchain.embeddings import HuggingFaceEmbeddings
3
  from langchain.vectorstores import FAISS
 
8
  def load_model_and_tokenizer(model_name: str):
9
  """
10
  Load the pre-trained model and tokenizer from the Hugging Face Hub.
11
+
12
  Args:
13
  model_name (str): The Hugging Face repository name of the model.
14
+
15
  Returns:
16
  model: The loaded model.
17
  tokenizer: The loaded tokenizer.
18
  """
19
  print(f"Loading model and tokenizer from {model_name}...")
20
+ model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
21
  tokenizer = AutoTokenizer.from_pretrained(model_name)
22
  return model, tokenizer
23
 
24
  def load_dataset_from_hf(dataset_name: str):
25
  """
26
  Load the dataset from the Hugging Face Hub.
27
+
28
  Args:
29
  dataset_name (str): The Hugging Face repository name of the dataset.
30
+
31
  Returns:
32
  texts (list): The text descriptions from the dataset.
33
  metadatas (list): Metadata for each text (e.g., upf_code).
 
41
  def load_faiss_index(faiss_index_path: str):
42
  """
43
  Load the FAISS index and associated embeddings.
44
+
45
  Args:
46
  faiss_index_path (str): Path to the saved FAISS index.
47
+
48
  Returns:
49
  vectorstore (FAISS): The FAISS vector store.
50
  """
51
  print(f"Loading FAISS index from {faiss_index_path}...")
52
  embeddings = HuggingFaceEmbeddings() # Default embeddings
53
+ vectorstore = FAISS.load_local(faiss_index_path, embeddings, allow_dangerous_deserialization=True)
54
  return vectorstore
55
 
56
  def build_retrieval_qa_pipeline(model, tokenizer, vectorstore):
57
  """
58
  Build the retrieval-based QA pipeline.
59
+
60
  Args:
61
  model: The pre-trained model.
62
  tokenizer: The tokenizer associated with the model.
63
  vectorstore (FAISS): The FAISS vector store for retrieval.
64
+
65
  Returns:
66
  qa_chain (RetrievalQA): The retrieval-based QA pipeline.
67
  """
 
75
  top_p=0.95,
76
  repetition_penalty=1.15
77
  )
78
+
79
  llm = HuggingFacePipeline(pipeline=hf_pipeline)
80
  retriever = vectorstore.as_retriever()
81
  qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)
82
+
83
  return qa_chain
84
 
85
  def main():
 
87
  model_name = "anirudh248/upf_code_generator_final"
88
  dataset_name = "PranavKeshav/upf_dataset"
89
  faiss_index_path = "faiss_index"
90
+
91
  print("Starting pipeline setup...")
92
+
93
  # Load model and tokenizer
94
  model, tokenizer = load_model_and_tokenizer(model_name)
95
+
96
  # Load dataset
97
  texts, metadatas = load_dataset_from_hf(dataset_name)
98
+
99
  # Load FAISS index
100
  vectorstore = load_faiss_index(faiss_index_path)
101
+
102
  # Build QA pipeline
103
  qa_chain = build_retrieval_qa_pipeline(model, tokenizer, vectorstore)
104
+
105
  # Test the pipeline
106
  print("Pipeline is ready! You can now ask questions.")
107
  while True: