Hjgugugjhuhjggg commited on
Commit
27720cf
·
verified ·
1 Parent(s): 54d55f2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -19
app.py CHANGED
@@ -7,7 +7,7 @@ import boto3
7
  import uvicorn
8
  import soundfile as sf
9
  import imageio
10
- from typing import Dict
11
 
12
  AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
13
  AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
@@ -42,7 +42,7 @@ class GenerateRequest(BaseModel):
42
  do_sample: bool = True
43
  stop_sequences: list[str] = []
44
  no_repeat_ngram_size: int = 2
45
- continuation_id: str = None
46
 
47
  @field_validator("model_name")
48
  def model_name_cannot_be_empty(cls, v):
@@ -115,8 +115,10 @@ async def generate(request: GenerateRequest, model_resources: tuple = Depends(ge
115
  if continuation_id:
116
  if continuation_id not in active_generations:
117
  raise HTTPException(status_code=404, detail="Continuation ID not found.")
118
- previous_output = active_generations[continuation_id]["output"]
119
- input_text = previous_output
 
 
120
 
121
  generation_config = GenerationConfig(
122
  temperature=temperature,
@@ -132,13 +134,10 @@ async def generate(request: GenerateRequest, model_resources: tuple = Depends(ge
132
 
133
  generated_text = generate_text_internal(model, tokenizer, input_text, generation_config, stop_sequences)
134
 
135
- if not continuation_id:
136
- continuation_id = os.urandom(16).hex()
137
- active_generations[continuation_id] = {"model_name": model_name, "output": generated_text}
138
- else:
139
- active_generations[continuation_id]["output"] = generated_text
140
 
141
- return JSONResponse({"text": generated_text, "continuation_id": continuation_id})
142
 
143
  except HTTPException as http_err:
144
  raise http_err
@@ -186,9 +185,9 @@ async def generate_image(request: GenerateRequest):
186
 
187
  image_generator = await load_pipeline_from_s3("text-to-image", request.model_name)
188
  image = image_generator(request.input_text)[0]
189
- continuation_id = os.urandom(16).hex()
190
- active_generations[continuation_id] = {"model_name": request.model_name, "output": "Image generated successfully"}
191
- return JSONResponse({"url": "Image generated successfully", "continuation_id": continuation_id})
192
 
193
  except HTTPException as http_err:
194
  raise http_err
@@ -203,9 +202,9 @@ async def generate_text_to_speech(request: GenerateRequest):
203
 
204
  tts_pipeline = await load_pipeline_from_s3("text-to-speech", request.model_name)
205
  output = tts_pipeline(request.input_text)
206
- continuation_id = os.urandom(16).hex()
207
- active_generations[continuation_id] = {"model_name": request.model_name, "output": "Audio generated successfully"}
208
- return JSONResponse({"url": "Audio generated successfully", "continuation_id": continuation_id})
209
 
210
  except HTTPException as http_err:
211
  raise http_err
@@ -220,9 +219,9 @@ async def generate_video(request: GenerateRequest):
220
 
221
  video_pipeline = await load_pipeline_from_s3("text-to-video", request.model_name)
222
  output = video_pipeline(request.input_text)
223
- continuation_id = os.urandom(16).hex()
224
- active_generations[continuation_id] = {"model_name": request.model_name, "output": "Video generated successfully"}
225
- return JSONResponse({"url": "Video generated successfully", "continuation_id": continuation_id})
226
 
227
  except HTTPException as http_err:
228
  raise http_err
 
7
  import uvicorn
8
  import soundfile as sf
9
  import imageio
10
+ from typing import Dict, Optional
11
 
12
  AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
13
  AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
 
42
  do_sample: bool = True
43
  stop_sequences: list[str] = []
44
  no_repeat_ngram_size: int = 2
45
+ continuation_id: Optional[str] = None
46
 
47
  @field_validator("model_name")
48
  def model_name_cannot_be_empty(cls, v):
 
115
  if continuation_id:
116
  if continuation_id not in active_generations:
117
  raise HTTPException(status_code=404, detail="Continuation ID not found.")
118
+ previous_data = active_generations[continuation_id]
119
+ if previous_data["model_name"] != model_name:
120
+ raise HTTPException(status_code=400, detail="Model mismatch for continuation.")
121
+ input_text = previous_data["output"]
122
 
123
  generation_config = GenerationConfig(
124
  temperature=temperature,
 
134
 
135
  generated_text = generate_text_internal(model, tokenizer, input_text, generation_config, stop_sequences)
136
 
137
+ new_continuation_id = continuation_id if continuation_id else os.urandom(16).hex()
138
+ active_generations[new_continuation_id] = {"model_name": model_name, "output": generated_text}
 
 
 
139
 
140
+ return JSONResponse({"text": generated_text, "continuation_id": new_continuation_id})
141
 
142
  except HTTPException as http_err:
143
  raise http_err
 
185
 
186
  image_generator = await load_pipeline_from_s3("text-to-image", request.model_name)
187
  image = image_generator(request.input_text)[0]
188
+ new_continuation_id = os.urandom(16).hex()
189
+ active_generations[new_continuation_id] = {"model_name": request.model_name, "output": "Image generated successfully"}
190
+ return JSONResponse({"url": "Image generated successfully", "continuation_id": new_continuation_id})
191
 
192
  except HTTPException as http_err:
193
  raise http_err
 
202
 
203
  tts_pipeline = await load_pipeline_from_s3("text-to-speech", request.model_name)
204
  output = tts_pipeline(request.input_text)
205
+ new_continuation_id = os.urandom(16).hex()
206
+ active_generations[new_continuation_id] = {"model_name": request.model_name, "output": "Audio generated successfully"}
207
+ return JSONResponse({"url": "Audio generated successfully", "continuation_id": new_continuation_id})
208
 
209
  except HTTPException as http_err:
210
  raise http_err
 
219
 
220
  video_pipeline = await load_pipeline_from_s3("text-to-video", request.model_name)
221
  output = video_pipeline(request.input_text)
222
+ new_continuation_id = os.urandom(16).hex()
223
+ active_generations[new_continuation_id] = {"model_name": request.model_name, "output": "Video generated successfully"}
224
+ return JSONResponse({"url": "Video generated successfully", "continuation_id": new_continuation_id})
225
 
226
  except HTTPException as http_err:
227
  raise http_err