Hjgugugjhuhjggg committed on
Commit
99862b8
·
verified ·
1 Parent(s): bdfd0ae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -18
app.py CHANGED
@@ -74,28 +74,22 @@ class S3ModelLoader:
74
  logging.info(f"Trying to load {model_name} from S3...")
75
  config = AutoConfig.from_pretrained(s3_uri)
76
  model = AutoModelForCausalLM.from_pretrained(s3_uri, config=config)
77
- tokenizer = AutoTokenizer.from_pretrained(s3_uri)
78
 
79
  if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
80
- if config.pad_token_id is not None:
81
- tokenizer.pad_token_id = config.pad_token_id
82
- else:
83
- tokenizer.pad_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 0
84
 
85
  logging.info(f"Loaded {model_name} from S3 successfully.")
86
  return model, tokenizer
87
  except EnvironmentError:
88
  logging.info(f"Model {model_name} not found in S3. Downloading...")
89
  try:
90
- model = AutoModelForCausalLM.from_pretrained(model_name, token=HUGGINGFACE_HUB_TOKEN)
91
- tokenizer = AutoTokenizer.from_pretrained(model_name, token=HUGGINGFACE_HUB_TOKEN)
92
-
 
93
  if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
94
- config = AutoConfig.from_pretrained(model_name)
95
- if config.pad_token_id is not None:
96
- tokenizer.pad_token_id = config.pad_token_id
97
- else:
98
- tokenizer.pad_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 0
99
 
100
  logging.info(f"Downloaded {model_name} successfully.")
101
  logging.info(f"Saving {model_name} to S3...")
@@ -128,10 +122,7 @@ async def generate(request: Request, body: GenerateRequest):
128
  top_k=validated_body.top_k,
129
  repetition_penalty=validated_body.repetition_penalty,
130
  do_sample=validated_body.do_sample,
131
- num_return_sequences=validated_body.num_return_sequences,
132
- stopping_criteria=StoppingCriteriaList(
133
- [lambda _, outputs: tokenizer.decode(outputs[0][-1], skip_special_tokens=True) in validated_body.stop_sequences] if validated_body.stop_sequences else []
134
- )
135
  )
136
 
137
  async def stream_text():
@@ -149,7 +140,11 @@ async def generate(request: Request, body: GenerateRequest):
149
 
150
  generation_config.max_new_tokens = min(remaining_tokens, validated_body.max_new_tokens)
151
 
152
- output = model.generate(**encoded_input, generation_config=generation_config)
 
 
 
 
153
  chunk = tokenizer.decode(output[0], skip_special_tokens=True)
154
  generated_text += chunk
155
  yield chunk
 
74
  logging.info(f"Trying to load {model_name} from S3...")
75
  config = AutoConfig.from_pretrained(s3_uri)
76
  model = AutoModelForCausalLM.from_pretrained(s3_uri, config=config)
77
+ tokenizer = AutoTokenizer.from_pretrained(s3_uri, config=config)
78
 
79
  if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
80
+ tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
 
 
 
81
 
82
  logging.info(f"Loaded {model_name} from S3 successfully.")
83
  return model, tokenizer
84
  except EnvironmentError:
85
  logging.info(f"Model {model_name} not found in S3. Downloading...")
86
  try:
87
+ config = AutoConfig.from_pretrained(model_name)
88
+ tokenizer = AutoTokenizer.from_pretrained(model_name, config=config)
89
+ model = AutoModelForCausalLM.from_pretrained(model_name, config=config, token=HUGGINGFACE_HUB_TOKEN)
90
+
91
  if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
92
+ tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
 
 
 
 
93
 
94
  logging.info(f"Downloaded {model_name} successfully.")
95
  logging.info(f"Saving {model_name} to S3...")
 
122
  top_k=validated_body.top_k,
123
  repetition_penalty=validated_body.repetition_penalty,
124
  do_sample=validated_body.do_sample,
125
+ num_return_sequences=validated_body.num_return_sequences
 
 
 
126
  )
127
 
128
  async def stream_text():
 
140
 
141
  generation_config.max_new_tokens = min(remaining_tokens, validated_body.max_new_tokens)
142
 
143
+ stopping_criteria = StoppingCriteriaList(
144
+ [lambda _, outputs: tokenizer.decode(outputs[0][-1], skip_special_tokens=True) in validated_body.stop_sequences] if validated_body.stop_sequences else []
145
+ )
146
+
147
+ output = model.generate(**encoded_input, generation_config=generation_config, stopping_criteria=stopping_criteria)
148
  chunk = tokenizer.decode(output[0], skip_special_tokens=True)
149
  generated_text += chunk
150
  yield chunk