# retrieval_qa_pipeline.py
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Note: on langchain >= 0.2 these imports live in langchain_community instead
# (e.g. from langchain_community.embeddings import HuggingFaceEmbeddings).
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA
from datasets import load_dataset


def load_model_and_tokenizer(model_name: str):
"""
Load the pre-trained model and tokenizer from the Hugging Face Hub.
Args:
model_name (str): The Hugging Face repository name of the model.
Returns:
model: The loaded model.
tokenizer: The loaded tokenizer.
"""
print(f"Loading model and tokenizer from {model_name}...")
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
return model, tokenizer
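

# Optional: on memory-constrained GPUs the model can instead be loaded in half
# precision (a sketch, assuming torch is installed alongside transformers):
#
#   import torch
#   model = AutoModelForCausalLM.from_pretrained(
#       model_name, device_map="auto", torch_dtype=torch.float16
#   )
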
def load_dataset_from_hf(dataset_name: str):
"""
Load the dataset from the Hugging Face Hub.
Args:
dataset_name (str): The Hugging Face repository name of the dataset.
Returns:
texts (list): The text descriptions from the dataset.
metadatas (list): Metadata for each text (e.g., upf_code).
"""
print(f"Loading dataset from {dataset_name}...")
dataset = load_dataset(dataset_name)
texts = dataset["train"]["power_intent_description"]
metadatas = [{"upf_code": code} for code in dataset["train"]["upf_code"]]
return texts, metadatas
def load_faiss_index(faiss_index_path: str):
"""
Load the FAISS index and associated embeddings.
Args:
faiss_index_path (str): Path to the saved FAISS index.
Returns:
vectorstore (FAISS): The FAISS vector store.
"""
print(f"Loading FAISS index from {faiss_index_path}...")
embeddings = HuggingFaceEmbeddings() # Default embeddings
vectorstore = FAISS.load_local(faiss_index_path, embeddings, allow_dangerous_deserialization=True)
return vectorstore
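

# A minimal helper sketch for (re)building the index from the dataset texts when
# no saved index exists yet; it assumes the same default HuggingFaceEmbeddings
# as load_faiss_index so the two stay compatible.
def build_faiss_index(texts, metadatas, faiss_index_path: str):
    """
    Build a FAISS index from the dataset texts and save it to disk.
    Args:
    texts (list): The text descriptions to index.
    metadatas (list): Metadata dicts aligned with texts (e.g., upf_code).
    faiss_index_path (str): Directory to save the index to.
    Returns:
    vectorstore (FAISS): The newly built FAISS vector store.
    """
    print(f"Building FAISS index at {faiss_index_path}...")
    embeddings = HuggingFaceEmbeddings()  # must match load_faiss_index
    vectorstore = FAISS.from_texts(texts, embeddings, metadatas=metadatas)
    vectorstore.save_local(faiss_index_path)
    return vectorstore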
def build_retrieval_qa_pipeline(model, tokenizer, vectorstore):
"""
Build the retrieval-based QA pipeline.
Args:
model: The pre-trained model.
tokenizer: The tokenizer associated with the model.
vectorstore (FAISS): The FAISS vector store for retrieval.
Returns:
qa_chain (RetrievalQA): The retrieval-based QA pipeline.
"""
print("Building the retrieval-based QA pipeline...")
    hf_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_length=2048,  # total budget (prompt + generation); use max_new_tokens to cap output only
        do_sample=True,  # sampling must be enabled for temperature/top_p to take effect
        temperature=0.7,
        top_p=0.95,
        repetition_penalty=1.15
    )
llm = HuggingFacePipeline(pipeline=hf_pipeline)
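    # The default retriever runs a similarity search over the indexed
    # descriptions; pass search_kwargs={"k": ...} to change how many are fetched.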
retriever = vectorstore.as_retriever()
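    # from_chain_type defaults to the "stuff" chain, which packs the retrieved
    # documents into the prompt; return_source_documents=True would also expose the matches.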
qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)
return qa_chain
def main():
# Replace these names with your model and dataset repo names
model_name = "anirudh248/upf_code_generator_final"
dataset_name = "PranavKeshav/upf_dataset"
faiss_index_path = "faiss_index"
print("Starting pipeline setup...")
# Load model and tokenizer
model, tokenizer = load_model_and_tokenizer(model_name)
    # Load dataset (texts/metadatas are needed only if the index must be rebuilt
    # with build_faiss_index; the prebuilt index is loaded from disk below)
texts, metadatas = load_dataset_from_hf(dataset_name)
# Load FAISS index
vectorstore = load_faiss_index(faiss_index_path)
# Build QA pipeline
qa_chain = build_retrieval_qa_pipeline(model, tokenizer, vectorstore)
# Test the pipeline
print("Pipeline is ready! You can now ask questions.")
while True:
        query = input("Enter your query (or type 'exit' to quit): ").strip()
        if query.lower() == "exit":
print("Exiting...")
break
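        # .run() returns the answer string; newer LangChain versions prefer
        # qa_chain.invoke({"query": query})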
response = qa_chain.run(query)
print(f"Response: {response}")
if __name__ == "__main__":
main()