Hjgugugjhuhjggg committed
Commit e6982de · verified · 1 Parent(s): 277e316

Update app.py

Files changed (1): app.py (+115 -74)
app.py CHANGED
@@ -17,6 +17,7 @@ import asyncio
 from io import BytesIO
 from transformers import pipeline
 import json
+from huggingface_hub import login
 
 AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
 AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
@@ -24,6 +25,11 @@ AWS_REGION = os.getenv("AWS_REGION")
 S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
 HUGGINGFACE_HUB_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")
 
+
+if HUGGINGFACE_HUB_TOKEN:
+    login(token=HUGGINGFACE_HUB_TOKEN,
+          add_to_git_credential=False)
+
 s3_client = boto3.client('s3', aws_access_key_id=AWS_ACCESS_KEY_ID,
                          aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
                          region_name=AWS_REGION)
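
Note: the new login() call authenticates this process with the Hugging Face Hub so gated or private models can be downloaded. A minimal sketch of the same startup pattern; the whoami() check is an illustration added here, not part of the commit:

    import os
    from huggingface_hub import login, whoami

    token = os.getenv("HUGGINGFACE_HUB_TOKEN")
    if token:
        # Authenticate the process; skip git credential storage on a server.
        login(token=token, add_to_git_credential=False)
        print(f"Hub auth OK: {whoami()['name']}")
    else:
        print("HUGGINGFACE_HUB_TOKEN not set; only public models will load.")
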
@@ -53,7 +59,8 @@ class GenerateRequest(BaseModel):
 
     @field_validator("task_type")
     def task_type_must_be_valid(cls, v):
-        valid_types = ["text-to-text", "text-to-image", "text-to-speech", "text-to-video"]
+        valid_types = ["text-to-text", "text-to-image",
+                       "text-to-speech", "text-to-video"]
         if v not in valid_types:
             raise ValueError(f"task_type must be one of: {valid_types}")
         return v
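
Note: with Pydantic v2, field_validator runs while the request body is being parsed, so an invalid task_type is rejected (FastAPI returns a 422) before the endpoint body ever executes. A self-contained sketch of that behavior; DemoRequest is a stub for illustration only:

    from pydantic import BaseModel, ValidationError, field_validator

    class DemoRequest(BaseModel):
        task_type: str

        @field_validator("task_type")
        def task_type_must_be_valid(cls, v):
            valid_types = ["text-to-text", "text-to-image",
                           "text-to-speech", "text-to-video"]
            if v not in valid_types:
                raise ValueError(f"task_type must be one of: {valid_types}")
            return v

    DemoRequest(task_type="text-to-text")    # passes
    try:
        DemoRequest(task_type="text-to-3d")  # rejected
    except ValidationError as exc:
        print(exc)
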
@@ -64,34 +71,51 @@ class S3ModelLoader:
         self.s3_client = s3_client
 
     def _get_s3_uri(self, model_name):
-        return f"s3://{self.bucket_name}/{model_name.replace('/', '-')}"
+        return f"s3://{self.bucket_name}/" \
+               f"{model_name.replace('/', '-')}"
 
     async def load_model_and_tokenizer(self, model_name):
         s3_uri = self._get_s3_uri(model_name)
         try:
-            config = AutoConfig.from_pretrained(s3_uri, local_files_only=True)
-            model = AutoModelForSeq2SeqLM.from_pretrained(s3_uri, config=config, local_files_only=True)
-            tokenizer = AutoTokenizer.from_pretrained(s3_uri, config=config, local_files_only=True)
-
-            if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
-                tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
+            config = AutoConfig.from_pretrained(
+                s3_uri, local_files_only=True
+            )
+            model = AutoModelForSeq2SeqLM.from_pretrained(
+                s3_uri, config=config, local_files_only=True
+            )
+            tokenizer = AutoTokenizer.from_pretrained(
+                s3_uri, config=config, local_files_only=True
+            )
+
+            if tokenizer.eos_token_id is not None and \
+               tokenizer.pad_token_id is None:
+                tokenizer.pad_token_id = config.pad_token_id \
+                    or tokenizer.eos_token_id
 
             return model, tokenizer
         except EnvironmentError:
             try:
                 config = AutoConfig.from_pretrained(model_name)
-                tokenizer = AutoTokenizer.from_pretrained(model_name, config=config)
-                model = AutoModelForSeq2SeqLM.from_pretrained(model_name, config=config)
-
-                if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
-                    tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
+                tokenizer = AutoTokenizer.from_pretrained(
+                    model_name, config=config
+                )
+                model = AutoModelForSeq2SeqLM.from_pretrained(
+                    model_name, config=config
+                )
+
+                if tokenizer.eos_token_id is not None and \
+                   tokenizer.pad_token_id is None:
+                    tokenizer.pad_token_id = config.pad_token_id \
+                        or tokenizer.eos_token_id
 
                 model.save_pretrained(s3_uri)
                 tokenizer.save_pretrained(s3_uri)
                 return model, tokenizer
             except Exception as e:
-                raise HTTPException(status_code=500, detail=f"Error loading model: {e}")
-
+                raise HTTPException(
+                    status_code=500, detail=f"Error loading model: {e}"
+                )
+
 model_loader = S3ModelLoader(S3_BUCKET_NAME, s3_client)
 
 @app.post("/generate")
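
Note: as far as I can tell, transformers' from_pretrained()/save_pretrained() do not resolve s3:// URIs on their own, so the cache-or-fallback pattern above typically goes through a local staging directory. A minimal sketch of that variant, under the assumption that model files live under s3://<bucket>/<model-name-with-dashes>/ (none of this is part of the commit):

    import os
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    def download_prefix(s3_client, bucket, prefix, local_dir):
        # Copy every object under the S3 prefix into local_dir.
        paginator = s3_client.get_paginator("list_objects_v2")
        for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
            for obj in page.get("Contents", []):
                rel = obj["Key"][len(prefix):].lstrip("/")
                dest = os.path.join(local_dir, rel)
                os.makedirs(os.path.dirname(dest), exist_ok=True)
                s3_client.download_file(bucket, obj["Key"], dest)

    def upload_prefix(s3_client, bucket, prefix, local_dir):
        # Mirror local_dir back to the S3 prefix.
        for root, _, files in os.walk(local_dir):
            for name in files:
                path = os.path.join(root, name)
                key = f"{prefix}/{os.path.relpath(path, local_dir)}"
                s3_client.upload_file(path, bucket, key)

    def load_with_s3_cache(model_name, bucket, s3_client,
                           cache_root="/tmp/models"):
        prefix = model_name.replace("/", "-")
        local_dir = os.path.join(cache_root, prefix)
        try:
            download_prefix(s3_client, bucket, prefix, local_dir)
            model = AutoModelForSeq2SeqLM.from_pretrained(
                local_dir, local_files_only=True)
            tokenizer = AutoTokenizer.from_pretrained(
                local_dir, local_files_only=True)
        except Exception:
            # Cache miss: pull from the Hub, then persist back to S3.
            model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model.save_pretrained(local_dir)
            tokenizer.save_pretrained(local_dir)
            upload_prefix(s3_client, bucket, prefix, local_dir)
        return model, tokenizer
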
@@ -111,7 +135,8 @@ async def generate(request: GenerateRequest):
         chunk_delay = request.chunk_delay
         stop_sequences = request.stop_sequences
 
-        model, tokenizer = await model_loader.load_model_and_tokenizer(model_name)
+        model, tokenizer = await model_loader.\
+            load_model_and_tokenizer(model_name)
         device = "cuda" if torch.cuda.is_available() else "cpu"
         model.to(device)
 
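
Note: load_model_and_tokenizer() runs on every POST /generate, so each request pays the full load cost again. A possible refinement (an assumption on my part, not something this commit does) is a small in-process cache keyed by model name:

    # Sketch: reuse already-loaded models across requests.
    _model_cache = {}

    async def get_cached_model(model_name):
        if model_name not in _model_cache:
            _model_cache[model_name] = \
                await model_loader.load_model_and_tokenizer(model_name)
        return _model_cache[model_name]
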
@@ -131,19 +156,20 @@ async def generate(request: GenerateRequest):
                         device, chunk_delay),
             media_type="text/plain"
         )
-
+
     except Exception as e:
-        raise HTTPException(status_code=500,
-                            detail=f"Internal server error: {str(e)}")
+        raise HTTPException(
+            status_code=500, detail=f"Internal server error: {str(e)}"
+        )
 
 
 async def stream_text(model, tokenizer, input_text,
                       generation_config, stop_sequences,
                       device, chunk_delay, max_length=2048):
-    encoded_input = tokenizer(input_text,
-                              return_tensors="pt",
-                              truncation=True,
-                              max_length=max_length).to(device)
+    encoded_input = tokenizer(
+        input_text, return_tensors="pt",
+        truncation=True, max_length=max_length
+    ).to(device)
     input_length = encoded_input["input_ids"].shape[1]
     remaining_tokens = max_length - input_length
 
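
Note: the endpoint declares media_type="text/plain", but as the hunks below show, stream_text yields newline-delimited JSON events of the form {"text": ..., "is_end": ...}. A minimal client sketch for consuming that stream, assuming a local server on port 7860; the model name is an example value, not taken from this commit:

    import json
    import requests

    payload = {
        "model_name": "google/flan-t5-small",  # example value
        "input_text": "Translate to German: hello",
        "task_type": "text-to-text",
    }
    with requests.post("http://localhost:7860/generate",
                       json=payload, stream=True) as resp:
        for line in resp.iter_lines():
            if not line:
                continue
            event = json.loads(line)
            if event["is_end"]:
                break
            print(event["text"], end="", flush=True)
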
@@ -153,7 +179,7 @@ async def stream_text(model, tokenizer, input_text,
     generation_config.max_new_tokens = min(
         remaining_tokens, generation_config.max_new_tokens
     )
-
+
     def find_stop(output_text, stop_sequences):
         for seq in stop_sequences:
             if seq in output_text:
@@ -161,9 +187,9 @@ async def stream_text(model, tokenizer, input_text,
                 return last_index + len(seq)
 
         return -1
-
+
     output_text = ""
-
+
     while True:
         outputs = model.generate(
             **encoded_input,
@@ -177,51 +203,50 @@ async def stream_text(model, tokenizer, input_text,
             output_scores=True,
             return_dict_in_generate=True,
         )
-
-        new_text = tokenizer.decode(outputs.sequences[0][len(encoded_input["input_ids"][0]):], skip_special_tokens=True)
-
+
+        new_text = tokenizer.decode(
+            outputs.sequences[0][len(encoded_input["input_ids"][0]):],
+            skip_special_tokens=True
+        )
+
         output_text += new_text
-
-
-
+
         stop_index = find_stop(output_text, stop_sequences)
-
+
         if stop_index != -1:
             final_output = output_text[:stop_index]
-
-
-
-            chunked_output = [final_output[i:i+10] for i in range(0, len(final_output), 10)]
-
+
+            chunked_output = [final_output[i:i+10]
+                              for i in range(0, len(final_output), 10)]
+
             for chunk in chunked_output:
-
                 yield json.dumps({"text": chunk, "is_end": False}) + "\n"
                 await asyncio.sleep(chunk_delay)
-
+
             yield json.dumps({"text": "", "is_end": True}) + "\n"
             break
-
+
         else:
-            chunked_output = [new_text[i:i+10] for i in range(0, len(new_text), 10)]
+            chunked_output = [new_text[i:i+10]
+                              for i in range(0, len(new_text), 10)]
             for chunk in chunked_output:
                 yield json.dumps({"text": chunk, "is_end": False}) + "\n"
                 await asyncio.sleep(chunk_delay)
-
-
+
         if len(output_text) >= generation_config.max_new_tokens:
-
-            chunked_output = [output_text[i:i+10] for i in range(0, len(output_text), 10)]
-
+            chunked_output = [output_text[i:i+10]
+                              for i in range(0, len(output_text), 10)]
+
             for chunk in chunked_output:
                 yield json.dumps({"text": chunk, "is_end": False}) + "\n"
                 await asyncio.sleep(chunk_delay)
             yield json.dumps({"text": "", "is_end": True}) + "\n"
             break
 
-        encoded_input = tokenizer(output_text,
-                                  return_tensors="pt",
-                                  truncation=True,
-                                  max_length=max_length).to(device)
+        encoded_input = tokenizer(
+            output_text, return_tensors="pt",
+            truncation=True, max_length=max_length
+        ).to(device)
 
 @app.post("/generate-image")
 async def generate_image(request: GenerateRequest):
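
Note: in the loop above, len(output_text) counts characters while generation_config.max_new_tokens counts tokens, so the cutoff fires later than a strict token budget would. A hedged alternative is to compare token counts instead (not part of this commit):

    def reached_token_budget(tokenizer, output_text, generation_config):
        # Compare generated token count, not character count, to the budget.
        num_generated = len(tokenizer(output_text)["input_ids"])
        return num_generated >= generation_config.max_new_tokens
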
@@ -229,62 +254,78 @@ async def generate_image(request: GenerateRequest):
         validated_body = request
         device = "cuda" if torch.cuda.is_available() else "cpu"
 
-        image_generator = pipeline("text-to-image",
-                                   model=validated_body.model_name,
-                                   device=device)
+        image_generator = pipeline(
+            "text-to-image", model=validated_body.model_name,
+            device=device
+        )
         image = image_generator(validated_body.input_text)[0]
 
         img_byte_arr = BytesIO()
         image.save(img_byte_arr, format="PNG")
         img_byte_arr.seek(0)
 
-        return StreamingResponse(img_byte_arr, media_type="image/png")
-
+        return StreamingResponse(
+            img_byte_arr, media_type="image/png"
+        )
+
     except Exception as e:
-        raise HTTPException(status_code=500,
-                            detail=f"Internal server error: {str(e)}")
-
+        raise HTTPException(
+            status_code=500,
+            detail=f"Internal server error: {str(e)}"
+        )
+
+
 @app.post("/generate-text-to-speech")
 async def generate_text_to_speech(request: GenerateRequest):
     try:
         validated_body = request
         device = "cuda" if torch.cuda.is_available() else "cpu"
-
-        audio_generator = pipeline("text-to-speech",
-                                   model=validated_body.model_name,
-                                   device=device)
+
+        audio_generator = pipeline(
+            "text-to-speech", model=validated_body.model_name,
+            device=device
+        )
         audio = audio_generator(validated_body.input_text)[0]
 
         audio_byte_arr = BytesIO()
         audio.save(audio_byte_arr)
         audio_byte_arr.seek(0)
 
-        return StreamingResponse(audio_byte_arr, media_type="audio/wav")
+        return StreamingResponse(
+            audio_byte_arr, media_type="audio/wav"
+        )
 
     except Exception as e:
-        raise HTTPException(status_code=500,
-                            detail=f"Internal server error: {str(e)}")
+        raise HTTPException(
+            status_code=500,
+            detail=f"Internal server error: {str(e)}"
+        )
+
 
 @app.post("/generate-video")
 async def generate_video(request: GenerateRequest):
     try:
         validated_body = request
         device = "cuda" if torch.cuda.is_available() else "cpu"
-        video_generator = pipeline("text-to-video",
-                                   model=validated_body.model_name,
-                                   device=device)
+        video_generator = pipeline(
+            "text-to-video", model=validated_body.model_name,
+            device=device
+        )
         video = video_generator(validated_body.input_text)[0]
 
         video_byte_arr = BytesIO()
         video.save(video_byte_arr)
         video_byte_arr.seek(0)
 
-        return StreamingResponse(video_byte_arr,
-                                 media_type="video/mp4")
-
+        return StreamingResponse(
+            video_byte_arr, media_type="video/mp4"
+        )
+
     except Exception as e:
-        raise HTTPException(status_code=500,
-                            detail=f"Internal server error: {str(e)}")
+        raise HTTPException(
+            status_code=500,
+            detail=f"Internal server error: {str(e)}"
+        )
 
 if __name__ == "__main__":
     uvicorn.run(app, host="0.0.0.0", port=7860)
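
Note: "text-to-speech" is a valid transformers pipeline task (an alias of "text-to-audio"), but to my knowledge "text-to-image" and "text-to-video" are not standard transformers pipeline task names, and the pipeline outputs here are dicts or arrays rather than objects with a .save() method, so these endpoints may fail at runtime as written. For reference, a minimal client sketch for the image route, assuming a local server on port 7860 and a placeholder model name:

    import requests

    resp = requests.post(
        "http://localhost:7860/generate-image",
        json={"model_name": "some/text-to-image-model",  # placeholder
              "input_text": "a watercolor fox",
              "task_type": "text-to-image"},
        timeout=300,
    )
    resp.raise_for_status()
    with open("out.png", "wb") as f:
        f.write(resp.content)
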
 