Hjgugugjhuhjggg committed on
Commit a0b48c5 · verified · 1 Parent(s): 31c0598

Update app.py

Files changed (1)
  1. app.py +13 -11
app.py CHANGED
@@ -14,6 +14,7 @@ from transformers import (
 )
 import boto3
 import uvicorn
+import asyncio
 from transformers import pipeline
 import json
 from huggingface_hub import login
@@ -77,7 +78,7 @@ class S3ModelLoader:
         return f"s3://{self.bucket_name}/" \
                f"{model_name.replace('/', '-')}"
 
-    def load_model_and_tokenizer(self, model_name):
+    async def load_model_and_tokenizer(self, model_name):
         if model_name in model_cache:
             return model_cache[model_name]
 
@@ -147,7 +148,7 @@ class S3ModelLoader:
 model_loader = S3ModelLoader(S3_BUCKET_NAME, s3_client)
 
 @app.post("/generate")
-def generate(request: GenerateRequest):
+async def generate(request: GenerateRequest):
     try:
         model_name = request.model_name
         input_text = request.input_text
@@ -162,7 +163,7 @@ def generate(request: GenerateRequest):
         do_sample = request.do_sample
         stop_sequences = request.stop_sequences
 
-        model, tokenizer = model_loader.load_model_and_tokenizer(model_name)
+        model, tokenizer = await model_loader.load_model_and_tokenizer(model_name)
         device = "cuda" if torch.cuda.is_available() else "cpu"
         model.to(device)
 
@@ -185,7 +186,7 @@ def generate(request: GenerateRequest):
                 media_type="text/plain"
             )
         else:
-            result = generate_text(model, tokenizer, input_text,
+            result = await generate_text(model, tokenizer, input_text,
                                    generation_config, stop_sequences,
                                    device)
             return JSONResponse({"text": result, "is_end": True})
@@ -212,7 +213,7 @@ class StopOnSequences(StoppingCriteria):
                 return True
         return False
 
-def stream_text(model, tokenizer, input_text,
+async def stream_text(model, tokenizer, input_text,
                 generation_config, stop_sequences,
                 device):
 
@@ -228,7 +229,7 @@ def stream_text(model, tokenizer, input_text,
 
     while True:
 
-        outputs = model.generate(
+        outputs = await asyncio.to_thread(model.generate,
             **encoded_input,
             do_sample=generation_config.do_sample,
             max_new_tokens=generation_config.max_new_tokens,
@@ -270,7 +271,7 @@ def stream_text(model, tokenizer, input_text,
         output_text = ""
 
 
-def generate_text(model, tokenizer, input_text,
+async def generate_text(model, tokenizer, input_text,
                   generation_config, stop_sequences,
                   device):
     encoded_input = tokenizer(
@@ -281,7 +282,7 @@ def generate_text(model, tokenizer, input_text,
     stop_criteria = StopOnSequences(stop_sequences, tokenizer)
     stopping_criteria = StoppingCriteriaList([stop_criteria])
 
-    outputs = model.generate(
+    outputs = await asyncio.to_thread(model.generate,
         **encoded_input,
         do_sample=generation_config.do_sample,
         max_new_tokens=generation_config.max_new_tokens,
@@ -302,8 +303,9 @@ def generate_text(model, tokenizer, input_text,
 
     return generated_text
 
+
 @app.post("/generate-image")
-def generate_image(request: GenerateRequest):
+async def generate_image(request: GenerateRequest):
     try:
         validated_body = request
         device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -331,7 +333,7 @@ def generate_image(request: GenerateRequest):
 
 
 @app.post("/generate-text-to-speech")
-def generate_text_to_speech(request: GenerateRequest):
+async def generate_text_to_speech(request: GenerateRequest):
     try:
         validated_body = request
         device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -362,7 +364,7 @@ def generate_text_to_speech(request: GenerateRequest):
 
 
 @app.post("/generate-video")
-def generate_video(request: GenerateRequest):
+async def generate_video(request: GenerateRequest):
     try:
         validated_body = request
         device = "cuda" if torch.cuda.is_available() else "cpu"