PranavKeshav commited on
Commit
30e74c0
·
verified ·
1 Parent(s): 2ef59a1

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +58 -0
handler.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
2
+ from langchain.vectorstores import FAISS
3
+ from langchain.embeddings import HuggingFaceEmbeddings
4
+ from langchain.llms import HuggingFacePipeline
5
+ from langchain.chains import RetrievalQA
6
+ import torch
7
+
8
+ class Handler:
9
+ def __init__(self):
10
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
+ print(f"Using device: {self.device}")
12
+
13
+ # Load the fine-tuned model and tokenizer
14
+ print("Loading model and tokenizer...")
15
+ self.model = AutoModelForCausalLM.from_pretrained("PranavKeshav/upf_code_generator").to(self.device)
16
+ self.tokenizer = AutoTokenizer.from_pretrained("PranavKeshav/upf_code_generator").to(self.device)
17
+
18
+ # Load the FAISS index and embeddings
19
+ print("Loading FAISS index and embeddings...")
20
+ self.embeddings = HuggingFaceEmbeddings()
21
+ self.vectorstore = FAISS.load_local("faiss_index", self.embeddings, allow_dangerous_deserialization=True)
22
+
23
+ # Create the Hugging Face pipeline for text generation
24
+ print("Creating Hugging Face pipeline...")
25
+
26
+ def run_inference(prompt: str):
27
+ # Assuming 2048 is the desired max sequence length
28
+ return self.model.generate(
29
+ prompt, temperature=0.7, max_length=2048, top_p=0.95, repetition_penalty=1.15
30
+ )
31
+
32
+ self.hf_pipeline = pipeline(
33
+ "text-generation",
34
+ model=self.model,
35
+ tokenizer=self.tokenizer,
36
+ temperature=0.7,
37
+ max_new_tokens=2048,
38
+ top_p=0.95,
39
+ repetition_penalty=1.15
40
+ )
41
+
42
+ self.hf_pipeline.model.generate = run_inference
43
+ # Wrap the pipeline in LangChain
44
+ self.llm = HuggingFacePipeline(pipeline=self.hf_pipeline)
45
+
46
+ # Create the retriever and pipeline
47
+ self.retriever = self.vectorstore.as_retriever()
48
+ self.qa_chain = RetrievalQA.from_chain_type(llm=self.llm, retriever=self.retriever)
49
+
50
+ def __call__(self, request):
51
+ # Get the prompt from the request
52
+ prompt = request.json.get("prompt")
53
+
54
+ # Generate UPF code using the QA chain
55
+ response = self.qa_chain.run(prompt)
56
+
57
+ # Return the response
58
+ return {"response": response}