Devops-hestabit
committed on
Commit
·
d02a1da
1
Parent(s):
96deb47
Update handler.py
Browse files- handler.py +13 -7
handler.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
2 |
import re
|
|
|
3 |
import torch
|
4 |
|
5 |
template = """Alice Gate's Persona: Alice Gate is a young, computer engineer-nerd with a knack for problem solving and a passion for technology.
|
@@ -26,15 +27,19 @@ class EndpointHandler():
|
|
26 |
self.tokenizer = AutoTokenizer.from_pretrained(path)
|
27 |
self.model = AutoModelForCausalLM.from_pretrained(
|
28 |
path,
|
29 |
-
low_cpu_mem_usage = True,
|
30 |
trust_remote_code = False,
|
31 |
-
torch_dtype = torch.float16
|
32 |
).to('cuda')
|
33 |
|
34 |
def response(self, result, user_name):
|
35 |
result = result.rsplit("Alice Gate:", 1)[1].split(f"{user_name}:",1)[0].strip()
|
36 |
-
|
|
|
37 |
result = " ".join(result.split())
|
|
|
|
|
|
|
38 |
return {
|
39 |
"message": result
|
40 |
}
|
@@ -43,11 +48,12 @@ class EndpointHandler():
|
|
43 |
inputs = data.pop("inputs", data)
|
44 |
user_name = inputs["user_name"]
|
45 |
user_input = "\n".join(inputs["user_input"])
|
|
|
|
|
|
|
|
|
46 |
input_ids = self.tokenizer(
|
47 |
-
|
48 |
-
user_name = user_name,
|
49 |
-
user_input = user_input
|
50 |
-
),
|
51 |
return_tensors = "pt"
|
52 |
).to("cuda")
|
53 |
generator = self.model.generate(
|
|
|
1 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
2 |
import re
|
3 |
+
import time
|
4 |
import torch
|
5 |
|
6 |
template = """Alice Gate's Persona: Alice Gate is a young, computer engineer-nerd with a knack for problem solving and a passion for technology.
|
|
|
27 |
self.tokenizer = AutoTokenizer.from_pretrained(path)
|
28 |
self.model = AutoModelForCausalLM.from_pretrained(
|
29 |
path,
|
30 |
+
low_cpu_mem_usage = True,
|
31 |
trust_remote_code = False,
|
32 |
+
torch_dtype = torch.float16
|
33 |
).to('cuda')
|
34 |
|
35 |
def response(self, result, user_name):
    """Extract and clean Alice Gate's reply from the raw generated text.

    Takes everything after the last "Alice Gate:" marker and before the
    next "{user_name}:" marker, strips *emote* markup, collapses
    whitespace, and truncates at the last sentence-ending punctuation.

    Returns a dict of the form {"message": <cleaned reply>}.
    """
    # Isolate the bot's turn: text between the last "Alice Gate:" marker
    # and the user's next turn marker.
    result = result.rsplit("Alice Gate:", 1)[1].split(f"{user_name}:",1)[0].strip()
    # Drop *emote* spans. Raw string fixes the invalid "\*" escape in the
    # original non-raw literal (SyntaxWarning on modern Python).
    parsed_result = re.sub(r'\*.*?\*', '', result).strip()
    # If stripping emotes removed everything, keep the text minus asterisks.
    result = parsed_result if len(parsed_result) != 0 else result.replace("*","")
    # Collapse whitespace runs (including newlines) to single spaces.
    result = " ".join(result.split())
    # Truncate at the last '.', '!' or '?' if one exists; an empty match
    # list raises IndexError, in which case the text is kept whole.
    try:
        result = result[:[m.start() for m in re.finditer(r'[.!?]', result)][-1]+1]
    except IndexError:
        pass
    return {
        "message": result
    }
|
|
|
48 |
inputs = data.pop("inputs", data)
|
49 |
user_name = inputs["user_name"]
|
50 |
user_input = "\n".join(inputs["user_input"])
|
51 |
+
prompt = template.format(
|
52 |
+
user_name = user_name,
|
53 |
+
user_input = user_input
|
54 |
+
)
|
55 |
input_ids = self.tokenizer(
|
56 |
+
prompt,
|
|
|
|
|
|
|
57 |
return_tensors = "pt"
|
58 |
).to("cuda")
|
59 |
generator = self.model.generate(
|