Hjgugugjhuhjggg commited on
Commit
ecffbb4
·
verified ·
1 Parent(s): de3c0e2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -7
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import os
2
  import torch
3
  from fastapi import FastAPI, HTTPException
4
- from fastapi.responses import JSONResponse
5
  from pydantic import BaseModel, field_validator
6
  from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, StoppingCriteriaList
7
  import boto3
@@ -55,16 +55,16 @@ class GenerateRequest(BaseModel):
55
 
56
  @field_validator("max_new_tokens")
57
  def max_new_tokens_must_be_within_limit(cls, v):
58
- if v > 4:
59
- raise ValueError("max_new_tokens cannot be greater than 4.")
60
  return v
61
 
62
  class S3ModelLoader:
63
- def.__init__(self, bucket_name, s3_client):
64
  self.bucket_name = bucket_name
65
  self.s3_client = s3_client
66
 
67
- def._get_s3_uri(self, model_name):
68
  return f"s3://{self.bucket_name}/{model_name.replace('/', '-')}"
69
 
70
  async def load_model_and_tokenizer(self, model_name):
@@ -187,7 +187,10 @@ async def generate_text_to_speech(request: GenerateRequest):
187
  audio = audio_generator(validated_body.input_text)[0]
188
 
189
  audio_byte_arr = BytesIO()
190
- audio.save(audio_byte_arr)
 
 
 
191
  audio_byte_arr.seek(0)
192
 
193
  return StreamingResponse(audio_byte_arr, media_type="audio/wav")
@@ -204,7 +207,10 @@ async def generate_video(request: GenerateRequest):
204
  video = video_generator(validated_body.input_text)[0]
205
 
206
  video_byte_arr = BytesIO()
207
- video.save(video_byte_arr)
 
 
 
208
  video_byte_arr.seek(0)
209
 
210
  return StreamingResponse(video_byte_arr, media_type="video/mp4")
 
1
  import os
2
  import torch
3
  from fastapi import FastAPI, HTTPException
4
+ from fastapi.responses import JSONResponse, StreamingResponse
5
  from pydantic import BaseModel, field_validator
6
  from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, StoppingCriteriaList
7
  import boto3
 
55
 
56
  @field_validator("max_new_tokens")
57
  def max_new_tokens_must_be_within_limit(cls, v):
58
+ if v > 500:
59
+ raise ValueError("max_new_tokens cannot be greater than 500.")
60
  return v
61
 
62
  class S3ModelLoader:
63
+ def __init__(self, bucket_name, s3_client):
64
  self.bucket_name = bucket_name
65
  self.s3_client = s3_client
66
 
67
+ def _get_s3_uri(self, model_name):
68
  return f"s3://{self.bucket_name}/{model_name.replace('/', '-')}"
69
 
70
  async def load_model_and_tokenizer(self, model_name):
 
187
  audio = audio_generator(validated_body.input_text)[0]
188
 
189
  audio_byte_arr = BytesIO()
190
+ # It is expected that the audio is saved as wav.
191
+ # Saving like this will not always work. Please check how your
192
+ # audio_generator model is working.
193
+ audio_generator.save_audio(audio_byte_arr, audio)
194
  audio_byte_arr.seek(0)
195
 
196
  return StreamingResponse(audio_byte_arr, media_type="audio/wav")
 
207
  video = video_generator(validated_body.input_text)[0]
208
 
209
  video_byte_arr = BytesIO()
210
+ # Same as above. Please check how your video model is returning the
211
+ # videos and save them accordingly.
212
+ # It is expected that the video is saved as MP4
213
+ video_generator.save_video(video_byte_arr, video)
214
  video_byte_arr.seek(0)
215
 
216
  return StreamingResponse(video_byte_arr, media_type="video/mp4")