WinterGYC committed on
Commit
feabcc6
·
1 Parent(s): 53131d6

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +7 -23
handler.py CHANGED
@@ -1,7 +1,7 @@
1
  import torch
2
  from typing import Dict, List, Any
3
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
4
- import logging
5
 
# Select the half-precision dtype used when loading the model weights.
# bfloat16 is chosen when a CUDA device is present and its compute capability
# major version is 8 or newer; float16 is used otherwise. Checking
# is_available() first keeps the module importable on hosts without CUDA,
# where get_device_capability() would raise instead of returning a value.
dtype = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
    else torch.float16
)
@@ -9,29 +9,13 @@ dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.
class EndpointHandler:
    """Request handler that runs a text-generation pipeline and a chat turn.

    The tokenizer and causal-LM are loaded from ``path`` with
    ``trust_remote_code=True``. Each call runs the text-generation pipeline
    (its output is only logged) and then the model's ``chat`` method, whose
    response is what gets returned.
    """

    def __init__(self, path=""):
        """Load tokenizer, model and the text-generation pipeline from ``path``.

        Args:
            path: Location passed to ``from_pretrained`` for both the
                tokenizer and the model.
        """
        # load the model
        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            path, device_map="auto", torch_dtype=dtype, trust_remote_code=True
        )
        # create inference pipeline
        self.pipeline = pipeline("text-generation", model=self.model, tokenizer=self.tokenizer)

    def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
        """Generate a reply for one request.

        Args:
            data: Request payload. ``"inputs"`` holds the prompt (when the key
                is absent the payload itself is used as the prompt) and the
                optional ``"parameters"`` mapping is forwarded to the pipeline
                as keyword arguments. Both keys are popped from ``data``.

        Returns:
            ``[[{response: 1.0}]]`` where ``response`` is the value returned
            by ``self.model.chat``.
        """
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None)

        # pass inputs with all kwargs in data
        prediction = self.pipeline(inputs, **(parameters or {}))
        logging.warning("---start---")
        logging.warning(prediction)
        logging.warning("---end---")

        # ignoring parameters! Default to configs in generation_config.json.
        # Reuse the prompt extracted above: "inputs" has already been popped,
        # so popping it again would yield the leftover payload dict instead.
        messages = [{"role": "user", "content": inputs}]
        response = self.model.chat(self.tokenizer, messages)
        logging.warning("---start chat response---")
        logging.warning(response)
        logging.warning("---end chat response---")

        return [[{response: 1.0}]]
 
1
  import torch
2
  from typing import Dict, List, Any
3
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
4
+ from transformers.generation.utils import GenerationConfig
5
 
# Select the half-precision dtype used when loading the model weights.
# bfloat16 is chosen when a CUDA device is present and its compute capability
# major version is 8 or newer; float16 is used otherwise. Checking
# is_available() first keeps the module importable on hosts without CUDA,
# where get_device_capability() would raise instead of returning a value.
dtype = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
    else torch.float16
)
 
class EndpointHandler:
    """Request handler that answers a single user message via the model's chat API.

    The tokenizer, weights and generation config are all loaded from the hub
    repository named by ``MODEL_ID`` with ``trust_remote_code=True``.
    """

    # Hub repository used for the tokenizer, the weights and the generation config.
    MODEL_ID = "baichuan-inc/Baichuan-13B-Chat"

    def __init__(self, path=""):
        """Load the tokenizer, the model and its generation config.

        Args:
            path: Accepted so the constructor signature stays unchanged, but
                not used; everything is loaded from ``MODEL_ID``.
        """
        # load the model
        # State must be bound on ``self``: Python has no implicit ``this``,
        # so the earlier ``this.<attr>`` assignments raised NameError.
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.MODEL_ID, use_fast=False, trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            self.MODEL_ID, device_map="auto", torch_dtype=dtype, trust_remote_code=True
        )
        self.model.generation_config = GenerationConfig.from_pretrained(self.MODEL_ID)

    def __call__(self, data: Any) -> List[Dict[str, Any]]:
        """Produce a chat reply for one request.

        Args:
            data: Request payload. ``"inputs"`` holds the user message and is
                popped from ``data``; when the key is absent the payload
                itself is used as the message content.

        Returns:
            A one-element list ``[{"generated_text": response}]`` where
            ``response`` is whatever ``self.model.chat`` returned.
        """
        inputs = data.pop("inputs", data)

        # ignoring parameters! Default to configs in generation_config.json.
        messages = [{"role": "user", "content": inputs}]
        response = self.model.chat(self.tokenizer, messages)
        return [{'generated_text': response}]