Hjgugugjhuhjggg committed
Commit cea60ee · verified · 1 Parent(s): c66e8e4

Update app.py

Files changed (1):
  1. app.py +6 -18
app.py CHANGED
@@ -14,7 +14,7 @@ from functools import cached_property
 import base64
 from optimum.onnxruntime import ORTModelForCausalLM
 from optimum.bettertransformer import BetterTransformer
-import bitsandbytes as bnb
+
 
 AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
 AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
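The dropped import was only needed by the quantize path removed in the hunks below. For reference, a minimal sketch of how that 8-bit path would look with the current transformers/bitsandbytes API; the model name is a placeholder, and this is not code from this repo. Notably, bitsandbytes int8 kernels need a CUDA device and transformers rejects .to() on an 8-bit model, which is presumably why the path was dropped from an app whose device() is pinned to torch.device("cpu").

# Hedged sketch of 8-bit loading, assuming a CUDA device and the
# accelerate + bitsandbytes packages are installed.
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quant_config = BitsAndBytesConfig(load_in_8bit=True)  # preferred over the bare load_in_8bit kwarg
model = AutoModelForCausalLM.from_pretrained(
    "gpt2",                            # placeholder model name
    quantization_config=quant_config,
    device_map="auto",                 # int8 weights must land on CUDA
)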
@@ -41,7 +41,6 @@ class GenerateRequest(BaseModel):
     num_return_sequences: int = 1
     do_sample: bool = False
     stop_sequences: list[str] = []
-    quantize: bool = True
     use_onnx: bool = False
     use_bettertransformer: bool = True
     @field_validator("model_name")
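After this hunk the request schema carries only the two remaining runtime toggles. A self-contained sketch of the resulting model, with field names taken from the diff; the validator body is illustrative, since the repo's actual validator falls outside the visible hunk:

from pydantic import BaseModel, field_validator

class GenerateRequest(BaseModel):
    model_name: str
    num_return_sequences: int = 1
    do_sample: bool = False
    stop_sequences: list[str] = []   # pydantic copies mutable defaults per instance
    use_onnx: bool = False
    use_bettertransformer: bool = True

    @field_validator("model_name")
    @classmethod
    def model_name_not_empty(cls, value: str) -> str:
        # Illustrative check only; the repo's validator is not shown in this diff.
        if not value.strip():
            raise ValueError("model_name must not be empty")
        return value

Since quantize is gone from the schema, clients that still send it will have the field silently ignored under pydantic's default config.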
@@ -62,17 +61,12 @@ class S3ModelLoader:
         self.model_cache = {}
     def _get_s3_uri(self, model_name):
         return f"s3://{self.bucket_name}/{model_name.replace('/', '-')}"
-    async def _load_model_and_tokenizer(self, model_name, quantize, use_onnx, use_bettertransformer):
+    async def _load_model_and_tokenizer(self, model_name, use_onnx, use_bettertransformer):
         s3_uri = self._get_s3_uri(model_name)
         try:
             config = AutoConfig.from_pretrained(s3_uri, local_files_only=False)
             if use_onnx:
                 model = ORTModelForCausalLM.from_pretrained(s3_uri, config=config, local_files_only=False).to(self.device)
-            elif quantize:
-                model = AutoModelForCausalLM.from_pretrained(
-                    s3_uri, config=config, local_files_only=False,
-                    load_in_8bit=True
-                ).to(self.device)
             else:
                 model = AutoModelForCausalLM.from_pretrained(s3_uri, config=config, local_files_only=False).to(self.device)
             if use_bettertransformer:
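With the elif quantize arm gone, loading reduces to a two-way choice: ONNX Runtime or plain PyTorch. A runnable sketch of that shape, written as a standalone function rather than the repo's method; the model id is a placeholder standing in for the S3 URI:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_causal_lm(model_id: str, use_onnx: bool = False):
    # Mirrors the simplified branch: ONNX Runtime or plain PyTorch, nothing else.
    device = torch.device("cpu")  # the app pins everything to CPU
    if use_onnx:
        from optimum.onnxruntime import ORTModelForCausalLM  # lazy: optional dependency
        model = ORTModelForCausalLM.from_pretrained(model_id, export=True)
    else:
        model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    return model, tokenizer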
@@ -87,11 +81,6 @@ class S3ModelLoader:
             tokenizer = AutoTokenizer.from_pretrained(model_name, config=config, token=HUGGINGFACE_HUB_TOKEN)
             if use_onnx:
                 model = ORTModelForCausalLM.from_pretrained(model_name, config=config, token=HUGGINGFACE_HUB_TOKEN).to(self.device)
-            elif quantize:
-                model = AutoModelForCausalLM.from_pretrained(
-                    model_name, config=config, token=HUGGINGFACE_HUB_TOKEN,
-                    load_in_8bit=True
-                ).to(self.device)
             else:
                 model = AutoModelForCausalLM.from_pretrained(model_name, config=config, token=HUGGINGFACE_HUB_TOKEN).to(self.device)
             if use_bettertransformer:
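Both the S3 path and this Hub-fallback path end at if use_bettertransformer:, whose body falls outside the visible hunks; presumably it applies optimum's BetterTransformer, which the file imports at the top. A sketch of that real API, under the assumption that this is what the hidden line does:

from optimum.bettertransformer import BetterTransformer

def maybe_accelerate(model, use_bettertransformer: bool):
    # Swaps supported modules for fused fast-path kernels; transform()
    # raises on unsupported architectures, hence the guard stays explicit.
    if use_bettertransformer:
        model = BetterTransformer.transform(model)
    return model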
@@ -104,10 +93,10 @@ class S3ModelLoader:
     @cached_property
     def device(self):
         return torch.device("cpu")
-    async def get_model_and_tokenizer(self, model_name, quantize, use_onnx, use_bettertransformer):
-        key = f"{model_name}-{quantize}-{use_onnx}-{use_bettertransformer}"
+    async def get_model_and_tokenizer(self, model_name, use_onnx, use_bettertransformer):
+        key = f"{model_name}-{use_onnx}-{use_bettertransformer}"
         if key not in self.model_cache:
-            model, tokenizer = await self._load_model_and_tokenizer(model_name, quantize, use_onnx, use_bettertransformer)
+            model, tokenizer = await self._load_model_and_tokenizer(model_name, use_onnx, use_bettertransformer)
             self.model_cache[key] = {"model":model, "tokenizer":tokenizer}
         return self.model_cache[key]["model"], self.model_cache[key]["tokenizer"]
     async def get_pipeline(self, model_name, task_type):
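The cache key shrinks in step with the signature, so requests that previously differed only in quantize now share one cache entry. A sketch of the keying, plus a tuple-based alternative (not the repo's code) that sidesteps any formatting of booleans into strings:

def string_key(model_name: str, use_onnx: bool, use_bettertransformer: bool) -> str:
    # What the diff does: "gpt2-False-True" style keys.
    return f"{model_name}-{use_onnx}-{use_bettertransformer}"

def tuple_key(model_name: str, use_onnx: bool, use_bettertransformer: bool):
    # Alternative sketch: tuples hash directly and keep fields distinct.
    return (model_name, use_onnx, use_bettertransformer)

assert string_key("gpt2", False, True) == "gpt2-False-True"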
@@ -133,10 +122,9 @@ async def generate(request: GenerateRequest):
     num_return_sequences = request.num_return_sequences
     do_sample = request.do_sample
     stop_sequences = request.stop_sequences
-    quantize = request.quantize
     use_onnx = request.use_onnx
     use_bettertransformer = request.use_bettertransformer
-    model, tokenizer = await model_loader.get_model_and_tokenizer(model_name, quantize, use_onnx, use_bettertransformer)
+    model, tokenizer = await model_loader.get_model_and_tokenizer(model_name, use_onnx, use_bettertransformer)
     if "text-to-text" == task_type:
         generation_config = GenerationConfig(temperature=temperature,max_new_tokens=max_new_tokens,top_p=top_p,top_k=top_k,repetition_penalty=repetition_penalty,do_sample=do_sample,num_return_sequences=num_return_sequences,eos_token_id = tokenizer.eos_token_id)
         if stream:
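Net effect for clients: the quantize field disappears from the request body. An example call under stated assumptions; the route path, port, and the full set of accepted fields are not visible in this diff:

import requests  # third-party HTTP client

resp = requests.post(
    "http://localhost:8000/generate",   # assumed route and port
    json={
        "model_name": "gpt2",           # placeholder model
        "task_type": "text-to-text",
        "max_new_tokens": 64,
        "do_sample": True,
        "use_onnx": False,
        "use_bettertransformer": True,
        # "quantize": True  <- no longer part of the schema after this commit
    },
)
print(resp.json())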
 