Hjgugugjhuhjggg commited on
Commit
6e229a7
·
verified ·
1 Parent(s): 3ed39a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -28
app.py CHANGED
@@ -7,6 +7,7 @@ from transformers import (
7
  AutoConfig,
8
  pipeline,
9
  AutoModelForSeq2SeqLM,
 
10
  AutoTokenizer,
11
  GenerationConfig,
12
  StoppingCriteriaList
@@ -14,10 +15,11 @@ from transformers import (
14
  import boto3
15
  import uvicorn
16
  import asyncio
17
- from io import BytesIO
18
  from transformers import pipeline
19
  import json
20
  from huggingface_hub import login
 
 
21
 
22
  AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
23
  AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
@@ -80,9 +82,20 @@ class S3ModelLoader:
80
  config = AutoConfig.from_pretrained(
81
  s3_uri, local_files_only=True
82
  )
83
- model = AutoModelForSeq2SeqLM.from_pretrained(
 
 
 
 
 
 
 
 
 
 
84
  s3_uri, config=config, local_files_only=True
85
  )
 
86
  tokenizer = AutoTokenizer.from_pretrained(
87
  s3_uri, config=config, local_files_only=True
88
  )
@@ -101,7 +114,17 @@ class S3ModelLoader:
101
  tokenizer = AutoTokenizer.from_pretrained(
102
  model_name, config=config, token=HUGGINGFACE_HUB_TOKEN
103
  )
104
- model = AutoModelForSeq2SeqLM.from_pretrained(
 
 
 
 
 
 
 
 
 
 
105
  model_name, config=config, token=HUGGINGFACE_HUB_TOKEN
106
  )
107
 
@@ -261,14 +284,10 @@ async def generate_image(request: GenerateRequest):
261
  device=device
262
  )
263
  image = image_generator(validated_body.input_text)[0]
264
-
265
- img_byte_arr = BytesIO()
266
- image.save(img_byte_arr, format="PNG")
267
- img_byte_arr.seek(0)
268
-
269
- return StreamingResponse(
270
- img_byte_arr, media_type="image/png"
271
- )
272
 
273
  except Exception as e:
274
  raise HTTPException(
@@ -287,15 +306,14 @@ async def generate_text_to_speech(request: GenerateRequest):
287
  "text-to-speech", model=validated_body.model_name,
288
  device=device
289
  )
290
- audio = audio_generator(validated_body.input_text)[0]
291
 
292
- audio_byte_arr = BytesIO()
293
- audio.save(audio_byte_arr)
294
- audio_byte_arr.seek(0)
295
-
296
- return StreamingResponse(
297
- audio_byte_arr, media_type="audio/wav"
298
- )
299
 
300
  except Exception as e:
301
  raise HTTPException(
@@ -313,15 +331,12 @@ async def generate_video(request: GenerateRequest):
313
  "text-to-video", model=validated_body.model_name,
314
  device=device
315
  )
316
- video = video_generator(validated_body.input_text)[0]
317
-
318
- video_byte_arr = BytesIO()
319
- video.save(video_byte_arr)
320
- video_byte_arr.seek(0)
321
-
322
- return StreamingResponse(
323
- video_byte_arr, media_type="video/mp4"
324
- )
325
 
326
  except Exception as e:
327
  raise HTTPException(
 
7
  AutoConfig,
8
  pipeline,
9
  AutoModelForSeq2SeqLM,
10
+ AutoModelForCausalLM,
11
  AutoTokenizer,
12
  GenerationConfig,
13
  StoppingCriteriaList
 
15
  import boto3
16
  import uvicorn
17
  import asyncio
 
18
  from transformers import pipeline
19
  import json
20
  from huggingface_hub import login
21
+ import base64
22
+
23
 
24
  AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
25
  AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
 
82
  config = AutoConfig.from_pretrained(
83
  s3_uri, local_files_only=True
84
  )
85
+
86
+ if "llama" in model_name:
87
+ model = AutoModelForCausalLM.from_pretrained(
88
+ s3_uri, config=config, local_files_only=True, rope_scaling = {"type": "linear", "factor": 2.0}
89
+ )
90
+ elif 'qwen' in model_name:
91
+ model = AutoModelForCausalLM.from_pretrained(
92
+ s3_uri, config=config, local_files_only=True
93
+ )
94
+ else:
95
+ model = AutoModelForSeq2SeqLM.from_pretrained(
96
  s3_uri, config=config, local_files_only=True
97
  )
98
+
99
  tokenizer = AutoTokenizer.from_pretrained(
100
  s3_uri, config=config, local_files_only=True
101
  )
 
114
  tokenizer = AutoTokenizer.from_pretrained(
115
  model_name, config=config, token=HUGGINGFACE_HUB_TOKEN
116
  )
117
+
118
+ if "llama" in model_name:
119
+ model = AutoModelForCausalLM.from_pretrained(
120
+ model_name, config=config, token=HUGGINGFACE_HUB_TOKEN, rope_scaling = {"type": "linear", "factor": 2.0}
121
+ )
122
+ elif 'qwen' in model_name:
123
+ model = AutoModelForCausalLM.from_pretrained(
124
+ model_name, config=config, token=HUGGINGFACE_HUB_TOKEN
125
+ )
126
+ else:
127
+ model = AutoModelForSeq2SeqLM.from_pretrained(
128
  model_name, config=config, token=HUGGINGFACE_HUB_TOKEN
129
  )
130
 
 
284
  device=device
285
  )
286
  image = image_generator(validated_body.input_text)[0]
287
+
288
+ image_data = list(image.getdata())
289
+
290
+ return json.dumps({"image_data": image_data, "is_end": True})
 
 
 
 
291
 
292
  except Exception as e:
293
  raise HTTPException(
 
306
  "text-to-speech", model=validated_body.model_name,
307
  device=device
308
  )
309
+ audio = audio_generator(validated_body.input_text)
310
 
311
+
312
+ audio_bytes = audio["audio"]
313
+
314
+ audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
315
+
316
+ return json.dumps({"audio": audio_base64, "is_end": True})
 
317
 
318
  except Exception as e:
319
  raise HTTPException(
 
331
  "text-to-video", model=validated_body.model_name,
332
  device=device
333
  )
334
+ video = video_generator(validated_body.input_text)
335
+
336
+
337
+ video_base64 = base64.b64encode(video).decode('utf-8')
338
+
339
+ return json.dumps({"video": video_base64, "is_end": True})
 
 
 
340
 
341
  except Exception as e:
342
  raise HTTPException(