Spaces:
Sleeping
Sleeping
Hjgugugjhuhjggg
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -7,7 +7,7 @@ import boto3
|
|
7 |
import uvicorn
|
8 |
import soundfile as sf
|
9 |
import imageio
|
10 |
-
from typing import Dict
|
11 |
|
12 |
AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
|
13 |
AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
|
@@ -42,7 +42,7 @@ class GenerateRequest(BaseModel):
|
|
42 |
do_sample: bool = True
|
43 |
stop_sequences: list[str] = []
|
44 |
no_repeat_ngram_size: int = 2
|
45 |
-
continuation_id: str = None
|
46 |
|
47 |
@field_validator("model_name")
|
48 |
def model_name_cannot_be_empty(cls, v):
|
@@ -115,8 +115,10 @@ async def generate(request: GenerateRequest, model_resources: tuple = Depends(ge
|
|
115 |
if continuation_id:
|
116 |
if continuation_id not in active_generations:
|
117 |
raise HTTPException(status_code=404, detail="Continuation ID not found.")
|
118 |
-
|
119 |
-
|
|
|
|
|
120 |
|
121 |
generation_config = GenerationConfig(
|
122 |
temperature=temperature,
|
@@ -132,13 +134,10 @@ async def generate(request: GenerateRequest, model_resources: tuple = Depends(ge
|
|
132 |
|
133 |
generated_text = generate_text_internal(model, tokenizer, input_text, generation_config, stop_sequences)
|
134 |
|
135 |
-
if
|
136 |
-
|
137 |
-
active_generations[continuation_id] = {"model_name": model_name, "output": generated_text}
|
138 |
-
else:
|
139 |
-
active_generations[continuation_id]["output"] = generated_text
|
140 |
|
141 |
-
return JSONResponse({"text": generated_text, "continuation_id":
|
142 |
|
143 |
except HTTPException as http_err:
|
144 |
raise http_err
|
@@ -186,9 +185,9 @@ async def generate_image(request: GenerateRequest):
|
|
186 |
|
187 |
image_generator = await load_pipeline_from_s3("text-to-image", request.model_name)
|
188 |
image = image_generator(request.input_text)[0]
|
189 |
-
|
190 |
-
active_generations[
|
191 |
-
return JSONResponse({"url": "Image generated successfully", "continuation_id":
|
192 |
|
193 |
except HTTPException as http_err:
|
194 |
raise http_err
|
@@ -203,9 +202,9 @@ async def generate_text_to_speech(request: GenerateRequest):
|
|
203 |
|
204 |
tts_pipeline = await load_pipeline_from_s3("text-to-speech", request.model_name)
|
205 |
output = tts_pipeline(request.input_text)
|
206 |
-
|
207 |
-
active_generations[
|
208 |
-
return JSONResponse({"url": "Audio generated successfully", "continuation_id":
|
209 |
|
210 |
except HTTPException as http_err:
|
211 |
raise http_err
|
@@ -220,9 +219,9 @@ async def generate_video(request: GenerateRequest):
|
|
220 |
|
221 |
video_pipeline = await load_pipeline_from_s3("text-to-video", request.model_name)
|
222 |
output = video_pipeline(request.input_text)
|
223 |
-
|
224 |
-
active_generations[
|
225 |
-
return JSONResponse({"url": "Video generated successfully", "continuation_id":
|
226 |
|
227 |
except HTTPException as http_err:
|
228 |
raise http_err
|
|
|
7 |
import uvicorn
|
8 |
import soundfile as sf
|
9 |
import imageio
|
10 |
+
from typing import Dict, Optional
|
11 |
|
12 |
AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
|
13 |
AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
|
|
|
42 |
do_sample: bool = True
|
43 |
stop_sequences: list[str] = []
|
44 |
no_repeat_ngram_size: int = 2
|
45 |
+
continuation_id: Optional[str] = None
|
46 |
|
47 |
@field_validator("model_name")
|
48 |
def model_name_cannot_be_empty(cls, v):
|
|
|
115 |
if continuation_id:
|
116 |
if continuation_id not in active_generations:
|
117 |
raise HTTPException(status_code=404, detail="Continuation ID not found.")
|
118 |
+
previous_data = active_generations[continuation_id]
|
119 |
+
if previous_data["model_name"] != model_name:
|
120 |
+
raise HTTPException(status_code=400, detail="Model mismatch for continuation.")
|
121 |
+
input_text = previous_data["output"]
|
122 |
|
123 |
generation_config = GenerationConfig(
|
124 |
temperature=temperature,
|
|
|
134 |
|
135 |
generated_text = generate_text_internal(model, tokenizer, input_text, generation_config, stop_sequences)
|
136 |
|
137 |
+
new_continuation_id = continuation_id if continuation_id else os.urandom(16).hex()
|
138 |
+
active_generations[new_continuation_id] = {"model_name": model_name, "output": generated_text}
|
|
|
|
|
|
|
139 |
|
140 |
+
return JSONResponse({"text": generated_text, "continuation_id": new_continuation_id})
|
141 |
|
142 |
except HTTPException as http_err:
|
143 |
raise http_err
|
|
|
185 |
|
186 |
image_generator = await load_pipeline_from_s3("text-to-image", request.model_name)
|
187 |
image = image_generator(request.input_text)[0]
|
188 |
+
new_continuation_id = os.urandom(16).hex()
|
189 |
+
active_generations[new_continuation_id] = {"model_name": request.model_name, "output": "Image generated successfully"}
|
190 |
+
return JSONResponse({"url": "Image generated successfully", "continuation_id": new_continuation_id})
|
191 |
|
192 |
except HTTPException as http_err:
|
193 |
raise http_err
|
|
|
202 |
|
203 |
tts_pipeline = await load_pipeline_from_s3("text-to-speech", request.model_name)
|
204 |
output = tts_pipeline(request.input_text)
|
205 |
+
new_continuation_id = os.urandom(16).hex()
|
206 |
+
active_generations[new_continuation_id] = {"model_name": request.model_name, "output": "Audio generated successfully"}
|
207 |
+
return JSONResponse({"url": "Audio generated successfully", "continuation_id": new_continuation_id})
|
208 |
|
209 |
except HTTPException as http_err:
|
210 |
raise http_err
|
|
|
219 |
|
220 |
video_pipeline = await load_pipeline_from_s3("text-to-video", request.model_name)
|
221 |
output = video_pipeline(request.input_text)
|
222 |
+
new_continuation_id = os.urandom(16).hex()
|
223 |
+
active_generations[new_continuation_id] = {"model_name": request.model_name, "output": "Video generated successfully"}
|
224 |
+
return JSONResponse({"url": "Video generated successfully", "continuation_id": new_continuation_id})
|
225 |
|
226 |
except HTTPException as http_err:
|
227 |
raise http_err
|