Hjgugugjhuhjggg committed
Commit 7ece340 · verified · 1 Parent(s): 74cfed2

Update app.py

Files changed (1)
  app.py +34 -25
app.py CHANGED
@@ -121,17 +121,16 @@ async def generate(request: GenerateRequest, model_resources: tuple = Depends(ge
             raise HTTPException(status_code=400, detail="Model mismatch for continuation.")
         input_text = previous_data["output"]
 
-    generation_config = GenerationConfig(
-        temperature=temperature,
-        max_new_tokens=max_new_tokens,
-        top_p=top_p,
-        top_k=top_k,
-        repetition_penalty=repetition_penalty,
-        do_sample=do_sample,
-        num_return_sequences=num_return_sequences,
-        no_repeat_ngram_size=no_repeat_ngram_size,
-        pad_token_id=tokenizer.pad_token_id
-    )
+    generation_config = GenerationConfig.from_pretrained(model_name)  # Load default config and override
+    generation_config.temperature = temperature
+    generation_config.max_new_tokens = max_new_tokens
+    generation_config.top_p = top_p
+    generation_config.top_k = top_k
+    generation_config.repetition_penalty = repetition_penalty
+    generation_config.do_sample = do_sample
+    generation_config.num_return_sequences = num_return_sequences
+    generation_config.no_repeat_ngram_size = no_repeat_ngram_size
+    generation_config.pad_token_id = tokenizer.pad_token_id
 
     generated_text = generate_text_internal(model, tokenizer, input_text, generation_config, stop_sequences)
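
A note on the pattern in this hunk: starting from GenerationConfig.from_pretrained keeps the model's shipped defaults and only overrides the request parameters. A minimal standalone sketch of that pattern (the model id "gpt2" is just an example, not the model this app serves):

    from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    # Load the model's default generation settings, then override selectively.
    generation_config = GenerationConfig.from_pretrained("gpt2")
    generation_config.max_new_tokens = 50
    generation_config.do_sample = True
    generation_config.temperature = 0.7
    generation_config.pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id

    inputs = tokenizer("Hello", return_tensors="pt")
    outputs = model.generate(**inputs, generation_config=generation_config)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))
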
@@ -147,19 +146,23 @@ async def generate(request: GenerateRequest, model_resources: tuple = Depends(ge
 
 def generate_text_internal(model, tokenizer, input_text, generation_config, stop_sequences):
     max_model_length = model.config.max_position_embeddings
-    encoded_input = tokenizer(input_text, return_tensors="pt", max_length=max_model_length, truncation=True)
+    encoded_input = tokenizer(input_text, return_tensors="pt", max_length=max_model_length, truncation=True).to(model.device)  # Ensure input is on the same device as the model
 
     stopping_criteria = StoppingCriteriaList()
 
-    class CustomStoppingCriteria(StoppingCriteriaList):
+    class CustomStoppingCriteria(StoppingCriteria):  # Inherit directly from StoppingCriteria
+        def __init__(self, stop_sequences, tokenizer):
+            self.stop_sequences = stop_sequences
+            self.tokenizer = tokenizer
+
         def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-            decoded_output = tokenizer.decode(input_ids[0], skip_special_tokens=True)
-            for stop in stop_sequences:
+            decoded_output = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
+            for stop in self.stop_sequences:
                 if decoded_output.endswith(stop):
                     return True
             return False
 
-    stopping_criteria.append(CustomStoppingCriteria())
+    stopping_criteria.append(CustomStoppingCriteria(stop_sequences, tokenizer))
 
     outputs = model.generate(
         encoded_input.input_ids,
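
This hunk fixes the base class (StoppingCriteria rather than StoppingCriteriaList) and passes the tokenizer and stop strings in explicitly instead of closing over them. A self-contained sketch of the same idea, assuming a recent transformers release:

    import torch
    from transformers import StoppingCriteria, StoppingCriteriaList

    class StopOnSequences(StoppingCriteria):
        """Stop generation once the decoded text ends with any of the given strings."""
        def __init__(self, stop_sequences, tokenizer):
            self.stop_sequences = stop_sequences
            self.tokenizer = tokenizer

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
            decoded = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
            return any(decoded.endswith(stop) for stop in self.stop_sequences)

    # Usage: stopping_criteria = StoppingCriteriaList([StopOnSequences(["###"], tokenizer)])
    #        model.generate(input_ids, stopping_criteria=stopping_criteria, ...)
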
@@ -174,7 +177,7 @@ def generate_text_internal(model, tokenizer, input_text, generation_config, stop
 async def load_pipeline_from_s3(task, model_name):
     s3_uri = f"s3://{S3_BUCKET_NAME}/{model_name.replace('/', '-')}"
     try:
-        return pipeline(task, model=s3_uri)
+        return pipeline(task, model=s3_uri, token=HUGGINGFACE_HUB_TOKEN)  # Include token if needed
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Error loading {task} model from S3: {e}")
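
One caveat worth flagging on load_pipeline_from_s3: transformers' pipeline() resolves its model argument as a Hub repo id or a local path, so an s3:// URI generally has to be synced to local disk first. A hedged sketch of that fallback using boto3 (bucket, prefix, and the local cache dir are placeholders, not values from this app):

    import os
    import boto3
    from transformers import pipeline

    def load_pipeline_from_s3_dir(task, bucket, prefix, local_dir="/tmp/model-cache"):
        """Download every object under the prefix, then load the pipeline from disk."""
        s3 = boto3.client("s3")
        os.makedirs(local_dir, exist_ok=True)
        paginator = s3.get_paginator("list_objects_v2")
        for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
            for obj in page.get("Contents", []):
                if obj["Key"].endswith("/"):
                    continue  # skip directory markers
                target = os.path.join(local_dir, os.path.relpath(obj["Key"], prefix))
                os.makedirs(os.path.dirname(target), exist_ok=True)
                s3.download_file(bucket, obj["Key"], target)
        return pipeline(task, model=local_dir)

In recent transformers versions the token argument is only used when the model id falls back to a (private) Hub repo.
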
@@ -186,9 +189,11 @@ async def generate_image(request: GenerateRequest):
 
         image_generator = await load_pipeline_from_s3("text-to-image", request.model_name)
         image = image_generator(request.input_text)[0]
+        image_path = f"generated_image_{os.urandom(8).hex()}.png"  # Save image locally
+        image.save(image_path)
         new_continuation_id = os.urandom(16).hex()
-        active_generations[new_continuation_id] = {"model_name": request.model_name, "output": "Image generated successfully"}
-        return JSONResponse({"url": "Image generated successfully", "continuation_id": new_continuation_id, "model_name": request.model_name})
+        active_generations[new_continuation_id] = {"model_name": request.model_name, "output": f"Image saved to {image_path}"}  # Return path or upload URL
+        return JSONResponse({"url": image_path, "continuation_id": new_continuation_id, "model_name": request.model_name})
 
     except HTTPException as http_err:
         raise http_err
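
The image endpoint now writes the generated image to a local file and returns that path in the "url" field. If the caller should receive the file itself rather than a server-side path, one option (sketched here, not part of this commit; route and directory names are illustrative) is a small FastAPI download route:

    import os
    from fastapi import FastAPI, HTTPException
    from fastapi.responses import FileResponse

    app = FastAPI()

    @app.get("/files/{file_name}")
    async def get_generated_file(file_name: str):
        # Hypothetical download route; assumes generated files live in the working directory.
        path = os.path.join(".", os.path.basename(file_name))
        if not os.path.exists(path):
            raise HTTPException(status_code=404, detail="File not found.")
        return FileResponse(path)
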
@@ -202,10 +207,12 @@ async def generate_text_to_speech(request: GenerateRequest):
             raise HTTPException(status_code=400, detail="Invalid task_type for this endpoint.")
 
         tts_pipeline = await load_pipeline_from_s3("text-to-speech", request.model_name)
-        output = tts_pipeline(request.input_text)
+        audio_output = tts_pipeline(request.input_text)
+        audio_path = f"generated_audio_{os.urandom(8).hex()}.wav"
+        sf.write(audio_path, audio_output["audio"], audio_output["sampling_rate"])
         new_continuation_id = os.urandom(16).hex()
-        active_generations[new_continuation_id] = {"model_name": request.model_name, "output": "Audio generated successfully"}
-        return JSONResponse({"url": "Audio generated successfully", "continuation_id": new_continuation_id, "model_name": request.model_name})
+        active_generations[new_continuation_id] = {"model_name": request.model_name, "output": f"Audio saved to {audio_path}"}
+        return JSONResponse({"url": audio_path, "continuation_id": new_continuation_id, "model_name": request.model_name})
 
     except HTTPException as http_err:
         raise http_err
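
Note the argument order in the new sf.write call: soundfile's write() takes (file, data, samplerate), and the text-to-speech pipeline returns a dict with "audio" and "sampling_rate" keys. A minimal sketch, assuming a recent transformers release and using "suno/bark-small" purely as an example model:

    import soundfile as sf
    from transformers import pipeline

    tts = pipeline("text-to-speech", model="suno/bark-small")
    out = tts("Hello there")                      # {"audio": ndarray, "sampling_rate": int}
    sf.write("speech.wav", out["audio"].squeeze(), out["sampling_rate"])
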
@@ -219,10 +226,12 @@ async def generate_video(request: GenerateRequest):
             raise HTTPException(status_code=400, detail="Invalid task_type for this endpoint.")
 
         video_pipeline = await load_pipeline_from_s3("text-to-video", request.model_name)
-        output = video_pipeline(request.input_text)
+        video_frames = video_pipeline(request.input_text).frames
+        video_path = f"generated_video_{os.urandom(8).hex()}.mp4"
+        imageio.mimsave(video_path, video_frames, fps=30)  # Adjust fps as needed
         new_continuation_id = os.urandom(16).hex()
-        active_generations[new_continuation_id] = {"model_name": request.model_name, "output": "Video generated successfully"}
-        return JSONResponse({"url": "Video generated successfully", "continuation_id": new_continuation_id, "model_name": request.model_name})
+        active_generations[new_continuation_id] = {"model_name": request.model_name, "output": f"Video saved to {video_path}"}
+        return JSONResponse({"url": video_path, "continuation_id": new_continuation_id, "model_name": request.model_name})
 
     except HTTPException as http_err:
         raise http_err
 
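
For the video endpoint, imageio.mimsave writes a sequence of frames to disk; writing .mp4 needs the imageio-ffmpeg backend, and fps controls playback speed. A small sketch using imageio's v2-style API with synthetic frames (the random frames are placeholders, not pipeline output):

    import numpy as np
    import imageio

    # 30 random 256x256 RGB frames, written as a one-second clip at 30 fps.
    frames = [np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8) for _ in range(30)]
    imageio.mimsave("clip.mp4", frames, fps=30)
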