MrD05 committed
Commit 911f092 · 1 Parent(s): 7007e60

Create handler.py

Files changed (1):
  handler.py +69 -0
handler.py ADDED
@@ -0,0 +1,69 @@
+ # Custom handler for Hugging Face Inference Endpoints: wraps an 8-bit
+ # causal LM in a LangChain LLMChain behind a chat-style prompt template.
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+ from langchain.llms import HuggingFacePipeline
+ from langchain import PromptTemplate, LLMChain
+
+ # Chat prompt: persona, greeting/history block, then the latest user turn;
+ # the model completes the line that starts with "{char_name}: ".
+ template = """{char_name}'s Persona: {char_persona}
+ <START>
+ {chat_history}
+ {char_name}: {char_greeting}
+ <END>
+ {user_name}: {user_input}
+ {char_name}: """
+
+ class EndpointHandler:
+
+     def __init__(self, path=""):
+         tokenizer = AutoTokenizer.from_pretrained(path)
+         # Load the weights in 8-bit (requires bitsandbytes) and let
+         # accelerate spread them over the available devices.
+         model = AutoModelForCausalLM.from_pretrained(path, load_in_8bit=True, device_map="auto")
+         local_llm = HuggingFacePipeline(
+             pipeline=pipeline(
+                 "text-generation",
+                 model=model,
+                 tokenizer=tokenizer,
+                 max_length=2048,
+                 do_sample=True,  # sampling must be on for temperature/top_p/top_k to apply
+                 temperature=0.5,
+                 top_p=0.9,
+                 top_k=0,
+                 repetition_penalty=1.1,
+                 pad_token_id=50256,  # end-of-text token of the GPT-2/GPT-J vocabulary, reused for padding
+                 num_return_sequences=1
+             )
+         )
+         prompt_template = PromptTemplate(
+             template=template,
+             input_variables=[
+                 "user_input",
+                 "user_name",
+                 "char_name",
+                 "char_persona",
+                 "char_greeting",
+                 "chat_history"
+             ],
+             validate_template=True
+         )
+         self.llm_engine = LLMChain(
+             llm=local_llm,
+             prompt=prompt_template
+         )
+
+     def __call__(self, data):
+         # Inference Endpoints deliver the request body as {"inputs": {...}}.
+         inputs = data.pop("inputs", data)
+         try:
+             response = self.llm_engine.predict(
+                 user_input=inputs["user_input"],
+                 user_name=inputs["user_name"],
+                 char_name=inputs["char_name"],
+                 char_persona=inputs["char_persona"],
+                 char_greeting=inputs["char_greeting"],
+                 chat_history=inputs["chat_history"]
+             ).split("\n", 1)[0]  # keep only the first generated line: the bot's reply
+             return {
+                 "inputs": inputs,
+                 "text": response
+             }
+         except Exception as e:
+             return {
+                 "inputs": inputs,
+                 "error": str(e)
+             }
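
For reference, a minimal local smoke test of the handler might look like the sketch below. The model path and all payload values are placeholders, not part of this commit; the payload shape simply mirrors what __call__ expects.

    from handler import EndpointHandler

    # Hypothetical example: "path" should point at a causal-LM checkpoint;
    # the 8-bit load requires bitsandbytes and a CUDA device.
    handler = EndpointHandler(path="./model")
    result = handler({
        "inputs": {
            "user_input": "Hi, how are you?",
            "user_name": "User",
            "char_name": "Assistant",
            "char_persona": "A friendly, concise assistant.",
            "char_greeting": "Hello! How can I help you today?",
            "chat_history": ""
        }
    })
    print(result.get("text") or result.get("error"))

On success the handler returns {"inputs": ..., "text": ...}; on failure it returns the error string instead of raising, so the endpoint always responds with JSON.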