Devops-hestabit committed
Commit d02a1da
Parent(s): 96deb47

Update handler.py

Files changed (1)
  1. handler.py +13 -7
handler.py CHANGED
@@ -1,5 +1,6 @@
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import re
+import time
 import torch

 template = """Alice Gate's Persona: Alice Gate is a young, computer engineer-nerd with a knack for problem solving and a passion for technology.
@@ -26,15 +27,19 @@ class EndpointHandler():
         self.tokenizer = AutoTokenizer.from_pretrained(path)
         self.model = AutoModelForCausalLM.from_pretrained(
             path,
-            low_cpu_mem_usage = True,
+            low_cpu_mem_usage = True,
             trust_remote_code = False,
-            torch_dtype = torch.float16,
+            torch_dtype = torch.float16
         ).to('cuda')

     def response(self, result, user_name):
         result = result.rsplit("Alice Gate:", 1)[1].split(f"{user_name}:",1)[0].strip()
-        result = re.sub('\*.*?\*', '', result)
+        parsed_result = re.sub('\*.*?\*', '', result).strip()
+        result = parsed_result if len(parsed_result) != 0 else result.replace("*","")
         result = " ".join(result.split())
+        try:
+            result = result[:[m.start() for m in re.finditer(r'[.!?]', result)][-1]+1]
+        except Exception: pass
         return {
             "message": result
         }
@@ -43,11 +48,12 @@ class EndpointHandler():
         inputs = data.pop("inputs", data)
         user_name = inputs["user_name"]
         user_input = "\n".join(inputs["user_input"])
+        prompt = template.format(
+            user_name = user_name,
+            user_input = user_input
+        )
         input_ids = self.tokenizer(
-            template.format(
-                user_name = user_name,
-                user_input = user_input
-            ),
+            prompt,
             return_tensors = "pt"
         ).to("cuda")
         generator = self.model.generate(
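The new post-processing in response() can be tried in isolation. The sketch below is not part of the commit; clean_reply and the sample string are invented for illustration, but the body mirrors the added lines: strip *action* markup, fall back to removing bare asterisks if that empties the reply, collapse whitespace, and cut the text at the last sentence-ending punctuation mark.

import re

def clean_reply(result):
    # Remove *emoted actions*; if that leaves nothing, just drop the asterisks.
    parsed_result = re.sub(r'\*.*?\*', '', result).strip()
    result = parsed_result if len(parsed_result) != 0 else result.replace("*", "")
    # Collapse runs of whitespace into single spaces.
    result = " ".join(result.split())
    try:
        # Truncate after the last '.', '!' or '?' so the reply ends on a full sentence.
        # (The commit catches Exception broadly; only an IndexError can occur here.)
        result = result[:[m.start() for m in re.finditer(r'[.!?]', result)][-1] + 1]
    except IndexError:
        # No sentence-ending punctuation at all: leave the reply as-is.
        pass
    return {"message": result}

print(clean_reply("*smiles* Sure, I can help! Let me check the logs and"))
# {'message': 'Sure, I can help!'}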
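For context, the handler expects a payload carrying user_name and a list of user_input lines; the last hunk only moves the template.format(...) call out of the tokenizer invocation into a named prompt variable. A standalone illustration follows, with the payload values and the shortened template invented for the example (the real template is the long persona string at the top of handler.py):

template = """Alice Gate's Persona: Alice Gate is a young, computer engineer-nerd ...
{user_name}: {user_input}
Alice Gate:"""  # stand-in for the real module-level template

data = {
    "inputs": {
        "user_name": "Bob",
        "user_input": [
            "Hi Alice, my build keeps failing.",
            "Any idea what to check first?"
        ]
    }
}

inputs = data.pop("inputs", data)
user_name = inputs["user_name"]
user_input = "\n".join(inputs["user_input"])

# Same construction as the commit: build the prompt first, then tokenize it.
prompt = template.format(
    user_name = user_name,
    user_input = user_input
)
print(prompt)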