Hjgugugjhuhjggg committed on
Commit b5fcdec · verified · 1 Parent(s): 7c21718

Update app.py

Files changed (1): app.py (+77 -51)
app.py CHANGED
@@ -1,6 +1,6 @@
 import os
 import torch
-from fastapi import FastAPI, HTTPException
+from fastapi import FastAPI, HTTPException  # keep HTTPException: it is still raised in the handlers below
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel, field_validator
 from transformers import (
@@ -23,7 +23,9 @@ AWS_REGION = os.getenv("AWS_REGION")
 S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
 HUGGINGFACE_HUB_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")
 
-s3_client = boto3.client('s3', aws_access_key_id=AWS_ACCESS_KEY_ID, aws_secret_access_key=AWS_SECRET_ACCESS_KEY, region_name=AWS_REGION)
+s3_client = boto3.client('s3', aws_access_key_id=AWS_ACCESS_KEY_ID,
+                         aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
+                         region_name=AWS_REGION)
 
 app = FastAPI()
 
@@ -88,7 +90,7 @@ class S3ModelLoader:
             return model, tokenizer
         except Exception as e:
             raise HTTPException(status_code=500, detail=f"Error loading model: {e}")
-
+
 model_loader = S3ModelLoader(S3_BUCKET_NAME, s3_client)
 
 @app.post("/generate")
@@ -123,31 +125,47 @@ async def generate(request: GenerateRequest):
         )
 
         return StreamingResponse(
-            stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay),
+            stream_text(model, tokenizer, input_text,
+                        generation_config, stop_sequences,
+                        device, chunk_delay),
             media_type="text/plain"
         )
-
+
     except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
+        raise HTTPException(status_code=500,
+                            detail=f"Internal server error: {str(e)}")
+
 
-async def stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay, max_length=2048):
-    encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
+async def stream_text(model, tokenizer, input_text,
+                      generation_config, stop_sequences,
+                      device, chunk_delay, max_length=2048):
+    encoded_input = tokenizer(input_text,
+                              return_tensors="pt",
+                              truncation=True,
+                              max_length=max_length).to(device)
     input_length = encoded_input["input_ids"].shape[1]
     remaining_tokens = max_length - input_length
 
     if remaining_tokens <= 0:
         yield ""
 
-    generation_config.max_new_tokens = min(remaining_tokens, generation_config.max_new_tokens)
-
-    def stop_criteria(input_ids, scores):
-        decoded_output = tokenizer.decode(int(input_ids[0][-1]), skip_special_tokens=True)
-        return decoded_output in stop_sequences
-
-    stopping_criteria = StoppingCriteriaList([stop_criteria])
+    generation_config.max_new_tokens = min(
+        remaining_tokens, generation_config.max_new_tokens
+    )
+
+    def find_stop(output_text, stop_sequences):
+        for seq in stop_sequences:
+            if seq in output_text:
+                last_index = output_text.rfind(seq)
+                return last_index + len(seq)
+        return -1
 
     output_text = ""
-    outputs = model.generate(
+
+    while True:
+        outputs = model.generate(
             **encoded_input,
             do_sample=generation_config.do_sample,
             max_new_tokens=generation_config.max_new_tokens,
@@ -156,42 +174,42 @@ async def stream_text(model, tokenizer, input_text, generation_config, stop_sequ
             top_k=generation_config.top_k,
             repetition_penalty=generation_config.repetition_penalty,
             num_return_sequences=generation_config.num_return_sequences,
-            stopping_criteria=stopping_criteria,
             output_scores=True,
-            return_dict_in_generate=True
-    )
-
-    for output in outputs.sequences:
-        for token_id in output:
-            token = tokenizer.decode(token_id, skip_special_tokens=True)
-            yield token
-            await asyncio.sleep(chunk_delay)  # Simulate the delay between tokens
+            return_dict_in_generate=True,
+        )
+
+        new_text = tokenizer.decode(outputs.sequences[0][len(encoded_input["input_ids"][0]):], skip_special_tokens=True)
+
+        output_text += new_text
+
+        yield new_text
+        await asyncio.sleep(chunk_delay)
 
-    if stop_sequences and any(stop in output_text for stop in stop_sequences):
-        yield output_text
-        return
+        stop_index = find_stop(output_text, stop_sequences)
+        if stop_index != -1:
+            yield output_text[:stop_index]
+            break
+
+        if len(output_text) >= generation_config.max_new_tokens:
+            break
 
-    outputs = model.generate(
-        **encoded_input,
-        do_sample=generation_config.do_sample,
-        max_new_tokens=generation_config.max_new_tokens,
-        temperature=generation_config.temperature,
-        top_p=generation_config.top_p,
-        top_k=generation_config.top_k,
-        repetition_penalty=generation_config.repetition_penalty,
-        num_return_sequences=generation_config.num_return_sequences,
-        stopping_criteria=stopping_criteria,
-        output_scores=True,
-        return_dict_in_generate=True
-    )
+        encoded_input = tokenizer(output_text,
+                                  return_tensors="pt",
+                                  truncation=True,
+                                  max_length=max_length).to(device)
 
 @app.post("/generate-image")
 async def generate_image(request: GenerateRequest):
     try:
         validated_body = request
         device = "cuda" if torch.cuda.is_available() else "cpu"
-
-        image_generator = pipeline("text-to-image", model=validated_body.model_name, device=device)
+
+        image_generator = pipeline("text-to-image",
+                                   model=validated_body.model_name,
+                                   device=device)
         image = image_generator(validated_body.input_text)[0]
 
         img_byte_arr = BytesIO()
@@ -199,17 +217,20 @@ async def generate_image(request: GenerateRequest):
         img_byte_arr.seek(0)
 
         return StreamingResponse(img_byte_arr, media_type="image/png")
-
+
     except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
-
+        raise HTTPException(status_code=500,
+                            detail=f"Internal server error: {str(e)}")
+
 @app.post("/generate-text-to-speech")
 async def generate_text_to_speech(request: GenerateRequest):
     try:
         validated_body = request
         device = "cuda" if torch.cuda.is_available() else "cpu"
 
-        audio_generator = pipeline("text-to-speech", model=validated_body.model_name, device=device)
+        audio_generator = pipeline("text-to-speech",
+                                   model=validated_body.model_name,
+                                   device=device)
         audio = audio_generator(validated_body.input_text)[0]
 
         audio_byte_arr = BytesIO()
@@ -219,24 +240,29 @@ async def generate_text_to_speech(request: GenerateRequest):
         return StreamingResponse(audio_byte_arr, media_type="audio/wav")
 
     except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
+        raise HTTPException(status_code=500,
+                            detail=f"Internal server error: {str(e)}")
 
 @app.post("/generate-video")
 async def generate_video(request: GenerateRequest):
     try:
         validated_body = request
         device = "cuda" if torch.cuda.is_available() else "cpu"
-        video_generator = pipeline("text-to-video", model=validated_body.model_name, device=device)
+        video_generator = pipeline("text-to-video",
+                                   model=validated_body.model_name,
+                                   device=device)
         video = video_generator(validated_body.input_text)[0]
 
         video_byte_arr = BytesIO()
         video.save(video_byte_arr)
         video_byte_arr.seek(0)
 
-        return StreamingResponse(video_byte_arr, media_type="video/mp4")
-
+        return StreamingResponse(video_byte_arr,
+                                 media_type="video/mp4")
+
     except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
+        raise HTTPException(status_code=500,
+                            detail=f"Internal server error: {str(e)}")
 
 if __name__ == "__main__":
     uvicorn.run(app, host="0.0.0.0", port=7860)
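
A note on consuming the rewritten /generate endpoint: it returns a text/plain StreamingResponse, so clients should read the body incrementally instead of waiting for the full reply. A minimal client sketch, assuming the server runs locally on port 7860 (matching the uvicorn.run call above) and that GenerateRequest accepts at least model_name and input_text; the model id and prompt are placeholder examples:

import httpx

# Field names follow the GenerateRequest usage visible in this diff
# (model_name, input_text); the concrete values are made-up examples.
payload = {
    "model_name": "gpt2",
    "input_text": "Once upon a time",
}

with httpx.stream("POST", "http://localhost:7860/generate",
                  json=payload, timeout=None) as response:
    response.raise_for_status()
    for chunk in response.iter_text():
        # Print each streamed text chunk as soon as it arrives.
        print(chunk, end="", flush=True)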
 
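The new stream_text re-runs model.generate from scratch on every loop iteration and compares a character count (len(output_text)) against a token budget (generation_config.max_new_tokens), so each chunk gets progressively more expensive and the exit condition is only approximate; the yield "" branch on an over-long prompt also falls through instead of returning. If per-token streaming is the goal, transformers ships TextIteratorStreamer for exactly this. A minimal sketch under that assumption, not what this commit does; max_new_tokens=256 is an arbitrary example:

import asyncio
from threading import Thread

from transformers import TextIteratorStreamer

async def stream_text_with_streamer(model, tokenizer, input_text, chunk_delay=0.0):
    encoded_input = tokenizer(input_text, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True,
                                    skip_special_tokens=True)

    # generate() blocks, so run it on a worker thread; the streamer then
    # yields decoded text fragments as tokens are produced.
    thread = Thread(target=model.generate,
                    kwargs=dict(**encoded_input, max_new_tokens=256,
                                streamer=streamer))
    thread.start()

    for new_text in streamer:
        yield new_text
        await asyncio.sleep(chunk_delay)

    thread.join()

Stop sequences can then be layered on top with a StoppingCriteria subclass, which is what the removed stop_criteria closure was reaching for before this commit replaced it with the post-hoc find_stop scan.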
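
One caveat on the media endpoints: "text-to-speech" is a supported transformers pipeline task in recent releases, but "text-to-image" and "text-to-video" generally are not, so the pipeline(...) calls above can be expected to raise for those model types; diffusion checkpoints usually go through the diffusers library instead. A minimal text-to-image sketch under that assumption; the checkpoint id is only an example, not something this commit uses:

import torch
from diffusers import AutoPipelineForText2Image

device = "cuda" if torch.cuda.is_available() else "cpu"

# Example checkpoint; substitute any text-to-image diffusion model id.
pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo").to(device)

# The pipeline returns an object with an .images list of PIL images.
image = pipe("a watercolor fox").images[0]
image.save("fox.png")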