Spaces:
Sleeping
Sleeping
Hjgugugjhuhjggg
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import os
|
2 |
import torch
|
3 |
from fastapi import FastAPI, HTTPException
|
4 |
-
from fastapi.responses import JSONResponse
|
5 |
from pydantic import BaseModel, field_validator
|
6 |
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, StoppingCriteriaList
|
7 |
import boto3
|
@@ -55,16 +55,16 @@ class GenerateRequest(BaseModel):
|
|
55 |
|
56 |
@field_validator("max_new_tokens")
|
57 |
def max_new_tokens_must_be_within_limit(cls, v):
|
58 |
-
if v >
|
59 |
-
raise ValueError("max_new_tokens cannot be greater than
|
60 |
return v
|
61 |
|
62 |
class S3ModelLoader:
|
63 |
-
def
|
64 |
self.bucket_name = bucket_name
|
65 |
self.s3_client = s3_client
|
66 |
|
67 |
-
def
|
68 |
return f"s3://{self.bucket_name}/{model_name.replace('/', '-')}"
|
69 |
|
70 |
async def load_model_and_tokenizer(self, model_name):
|
@@ -187,7 +187,10 @@ async def generate_text_to_speech(request: GenerateRequest):
|
|
187 |
audio = audio_generator(validated_body.input_text)[0]
|
188 |
|
189 |
audio_byte_arr = BytesIO()
|
190 |
-
audio.
|
|
|
|
|
|
|
191 |
audio_byte_arr.seek(0)
|
192 |
|
193 |
return StreamingResponse(audio_byte_arr, media_type="audio/wav")
|
@@ -204,7 +207,10 @@ async def generate_video(request: GenerateRequest):
|
|
204 |
video = video_generator(validated_body.input_text)[0]
|
205 |
|
206 |
video_byte_arr = BytesIO()
|
207 |
-
video
|
|
|
|
|
|
|
208 |
video_byte_arr.seek(0)
|
209 |
|
210 |
return StreamingResponse(video_byte_arr, media_type="video/mp4")
|
|
|
1 |
import os
|
2 |
import torch
|
3 |
from fastapi import FastAPI, HTTPException
|
4 |
+
from fastapi.responses import JSONResponse, StreamingResponse
|
5 |
from pydantic import BaseModel, field_validator
|
6 |
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, StoppingCriteriaList
|
7 |
import boto3
|
|
|
55 |
|
56 |
@field_validator("max_new_tokens")
|
57 |
def max_new_tokens_must_be_within_limit(cls, v):
|
58 |
+
if v > 500:
|
59 |
+
raise ValueError("max_new_tokens cannot be greater than 500.")
|
60 |
return v
|
61 |
|
62 |
class S3ModelLoader:
|
63 |
+
def __init__(self, bucket_name, s3_client):
|
64 |
self.bucket_name = bucket_name
|
65 |
self.s3_client = s3_client
|
66 |
|
67 |
+
def _get_s3_uri(self, model_name):
|
68 |
return f"s3://{self.bucket_name}/{model_name.replace('/', '-')}"
|
69 |
|
70 |
async def load_model_and_tokenizer(self, model_name):
|
|
|
187 |
audio = audio_generator(validated_body.input_text)[0]
|
188 |
|
189 |
audio_byte_arr = BytesIO()
|
190 |
+
# It is expected that the audio is saved as wav.
|
191 |
+
# Saving like this will not always work. Please check how your
|
192 |
+
# audio_generator model is working.
|
193 |
+
audio_generator.save_audio(audio_byte_arr, audio)
|
194 |
audio_byte_arr.seek(0)
|
195 |
|
196 |
return StreamingResponse(audio_byte_arr, media_type="audio/wav")
|
|
|
207 |
video = video_generator(validated_body.input_text)[0]
|
208 |
|
209 |
video_byte_arr = BytesIO()
|
210 |
+
# Same as above. Please check how your video model is returning the
|
211 |
+
# videos and save them accordingly.
|
212 |
+
# It is expected that the video is saved as MP4
|
213 |
+
video_generator.save_video(video_byte_arr, video)
|
214 |
video_byte_arr.seek(0)
|
215 |
|
216 |
return StreamingResponse(video_byte_arr, media_type="video/mp4")
|