Hjgugugjhuhjggg committed
Commit 78f7e86 · verified · 1 parent: 2e0bd60

Update app.py

Files changed (1)
  1. app.py +109 -75
app.py CHANGED
@@ -17,7 +17,8 @@ from transformers import pipeline
 import json
 from huggingface_hub import login
 import base64
-
+import io
+from PIL import Image
 
 AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
 AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
@@ -25,23 +26,29 @@ AWS_REGION = os.getenv("AWS_REGION")
 S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
 HUGGINGFACE_HUB_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")
 
-
 if HUGGINGFACE_HUB_TOKEN:
-    login(token=HUGGINGFACE_HUB_TOKEN,
-          add_to_git_credential=False)
+    login(token=HUGGINGFACE_HUB_TOKEN, add_to_git_credential=False)
 
-s3_client = boto3.client('s3', aws_access_key_id=AWS_ACCESS_KEY_ID,
-                         aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
-                         region_name=AWS_REGION)
+s3_client = boto3.client(
+    "s3",
+    aws_access_key_id=AWS_ACCESS_KEY_ID,
+    aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
+    region_name=AWS_REGION,
+)
 
 app = FastAPI()
 
+# Global variables for tokenizer tokens
+EOS_TOKEN_ID = None
+PAD_TOKEN_ID = None
+
+
 class GenerateRequest(BaseModel):
     model_name: str
     input_text: str = ""
     task_type: str
     temperature: float = 1.0
-    max_new_tokens: int = 200 # this will be limited to 10
+    max_new_tokens: int = 200 # this will be limited to 10
     stream: bool = True
     top_p: float = 1.0
     top_k: int = 50
@@ -64,6 +71,7 @@ class GenerateRequest(BaseModel):
             raise ValueError(f"task_type must be one of: {valid_types}")
         return v
 
+
 class S3ModelLoader:
     def __init__(self, bucket_name, s3_client):
         self.bucket_name = bucket_name
@@ -74,24 +82,25 @@ class S3ModelLoader:
             f"{model_name.replace('/', '-')}"
 
     async def load_model_and_tokenizer(self, model_name):
+        global EOS_TOKEN_ID, PAD_TOKEN_ID
         s3_uri = self._get_s3_uri(model_name)
         try:
             config = AutoConfig.from_pretrained(
                 s3_uri, local_files_only=False
             )
-
             model = AutoModelForCausalLM.from_pretrained(
                 s3_uri, config=config, local_files_only=False
             )
-
             tokenizer = AutoTokenizer.from_pretrained(
                 s3_uri, config=config, local_files_only=False
             )
 
-            if tokenizer.eos_token_id is not None and \
-                    tokenizer.pad_token_id is None:
-                tokenizer.pad_token_id = config.pad_token_id \
-                    or tokenizer.eos_token_id
+            EOS_TOKEN_ID = tokenizer.eos_token_id
+            PAD_TOKEN_ID = tokenizer.pad_token_id
+
+            if EOS_TOKEN_ID is not None and PAD_TOKEN_ID is None:
+                PAD_TOKEN_ID = config.pad_token_id or EOS_TOKEN_ID
+                tokenizer.pad_token_id = PAD_TOKEN_ID
 
             return model, tokenizer
         except EnvironmentError:
@@ -102,16 +111,16 @@
             tokenizer = AutoTokenizer.from_pretrained(
                 model_name, config=config, token=HUGGINGFACE_HUB_TOKEN
             )
-
+
             model = AutoModelForCausalLM.from_pretrained(
                 model_name, config=config, token=HUGGINGFACE_HUB_TOKEN
            )
+            EOS_TOKEN_ID = tokenizer.eos_token_id
+            PAD_TOKEN_ID = tokenizer.pad_token_id
 
-
-            if tokenizer.eos_token_id is not None and \
-                    tokenizer.pad_token_id is None:
-                tokenizer.pad_token_id = config.pad_token_id \
-                    or tokenizer.eos_token_id
+            if EOS_TOKEN_ID is not None and PAD_TOKEN_ID is None:
+                PAD_TOKEN_ID = config.pad_token_id or EOS_TOKEN_ID
+                tokenizer.pad_token_id = PAD_TOKEN_ID
 
             model.save_pretrained(s3_uri)
             tokenizer.save_pretrained(s3_uri)
@@ -121,8 +130,10 @@
                 status_code=500, detail=f"Error loading model: {e}"
             )
 
+
 model_loader = S3ModelLoader(S3_BUCKET_NAME, s3_client)
 
+
 @app.post("/generate")
 async def generate(request: GenerateRequest):
     try:
@@ -130,7 +141,7 @@ async def generate(request: GenerateRequest):
         input_text = request.input_text
         task_type = request.task_type
         temperature = request.temperature
-        max_new_tokens = request.max_new_tokens #This value will be used to constraint the output
+        max_new_tokens = request.max_new_tokens
         stream = request.stream
         top_p = request.top_p
         top_k = request.top_k
@@ -139,15 +150,14 @@
         do_sample = request.do_sample
         stop_sequences = request.stop_sequences
 
-        model, tokenizer = await model_loader.\
-            load_model_and_tokenizer(model_name)
+        model, tokenizer = await model_loader.load_model_and_tokenizer(model_name)
         device = "cuda" if torch.cuda.is_available() else "cpu"
         model.to(device)
-
+
         if "text-to-text" == task_type:
             generation_config = GenerationConfig(
                 temperature=temperature,
-                max_new_tokens=min(max_new_tokens,10), # Constrain max_new_tokens to 10
+                max_new_tokens=min(max_new_tokens, 10), # Constrain max_new_tokens to 10
                 top_p=top_p,
                 top_k=top_k,
                 repetition_penalty=repetition_penalty,
@@ -156,13 +166,21 @@
             )
 
             return StreamingResponse(
-                stream_text(model, tokenizer, input_text,
-                            generation_config, stop_sequences,
-                            device, max_length=10),
-                media_type="text/plain"
+                stream_text(
+                    model,
+                    tokenizer,
+                    input_text,
+                    generation_config,
+                    stop_sequences,
+                    device,
+                    max_length=10,
+                ),
+                media_type="text/plain",
             )
         else:
-            return HTTPException(status_code=400, detail="Task type not text-to-text")
+            raise HTTPException(
+                status_code=400, detail="Task type not text-to-text"
+            )
 
     except Exception as e:
         raise HTTPException(
@@ -170,12 +188,11 @@ async def generate(request: GenerateRequest):
         )
 
 
-async def stream_text(model, tokenizer, input_text,
-                      generation_config, stop_sequences,
-                      device, max_length):
+async def stream_text(
+    model, tokenizer, input_text, generation_config, stop_sequences, device, max_length
+):
     encoded_input = tokenizer(
-        input_text, return_tensors="pt",
-        truncation=True, max_length=max_length
+        input_text, return_tensors="pt", truncation=True, max_length=max_length
     ).to(device)
     input_length = encoded_input["input_ids"].shape[1]
     remaining_tokens = max_length - input_length
@@ -186,14 +203,12 @@ async def stream_text(model, tokenizer, input_text,
         generation_config.max_new_tokens = min(
             remaining_tokens, generation_config.max_new_tokens
         )
-
 
     def find_stop(output_text, stop_sequences):
         for seq in stop_sequences:
             if seq in output_text:
                 last_index = output_text.rfind(seq)
                 return last_index + len(seq)
-
         return -1
 
     output_text = ""
@@ -214,7 +229,7 @@ async def stream_text(model, tokenizer, input_text,
 
         new_text = tokenizer.decode(
            outputs.sequences[0][len(encoded_input["input_ids"][0]):],
-            skip_special_tokens=True
+            skip_special_tokens=True,
        )
 
        output_text += new_text
@@ -223,8 +238,9 @@ async def stream_text(model, tokenizer, input_text,
 
        if stop_index != -1:
            final_output = output_text[:stop_index]
-            chunked_output = [final_output[i:i+10]
-                              for i in range(0, len(final_output), 10)]
+            chunked_output = [
+                final_output[i: i + 10] for i in range(0, len(final_output), 10)
+            ]
 
            for chunk in chunked_output:
                yield json.dumps({"text": chunk, "is_end": False}) + "\n"
@@ -233,15 +249,17 @@ async def stream_text(model, tokenizer, input_text,
            break
 
        else:
-            chunked_output = [new_text[i:i+10]
-                              for i in range(0, len(new_text), 10)]
+            chunked_output = [
+                new_text[i: i + 10] for i in range(0, len(new_text), 10)
+            ]
 
-            for chunk in chunked_output:
-                yield json.dumps({"text": chunk, "is_end": False}) + "\n"
+            for chunk in chunked_output:
+                yield json.dumps({"text": chunk, "is_end": False}) + "\n"
 
        if len(output_text) >= generation_config.max_new_tokens:
-            chunked_output = [output_text[i:i+10]
-                              for i in range(0, len(output_text), 10)]
+            chunked_output = [
+                output_text[i: i + 10] for i in range(0, len(output_text), 10)
+            ]
 
            for chunk in chunked_output:
                yield json.dumps({"text": chunk, "is_end": False}) + "\n"
@@ -249,10 +267,10 @@ async def stream_text(model, tokenizer, input_text,
            break
 
        encoded_input = tokenizer(
-            output_text, return_tensors="pt",
-            truncation=True, max_length=max_length
+            output_text, return_tensors="pt", truncation=True, max_length=max_length
        ).to(device)
 
+
 @app.post("/generate-image")
 async def generate_image(request: GenerateRequest):
    try:
@@ -260,19 +278,27 @@ async def generate_image(request: GenerateRequest):
        device = "cuda" if torch.cuda.is_available() else "cpu"
 
        image_generator = pipeline(
-            "text-to-image", model=validated_body.model_name,
-            device=device
+            "text-to-image", model=validated_body.model_name, device=device
        )
        image = image_generator(validated_body.input_text)[0]
-
-        image_data = list(image.getdata())
-
-        return json.dumps({"image_data": image_data, "is_end": True})
+
+        async def stream_image():
+            buffered = io.BytesIO()
+            image.save(buffered, format="PNG")
+            image_bytes = buffered.getvalue()
+            image_base64 = base64.b64encode(image_bytes).decode("utf-8")
+            chunk_size = 1000
+            for i in range(0, len(image_base64), chunk_size):
+                chunk = image_base64[i: i + chunk_size]
+                yield json.dumps({"image": chunk, "is_end": False}) + "\n"
+
+            yield json.dumps({"image": "", "is_end": True}) + "\n"
+
+        return StreamingResponse(stream_image(), media_type="text/plain")
 
    except Exception as e:
        raise HTTPException(
-            status_code=500,
-            detail=f"Internal server error: {str(e)}"
+            status_code=500, detail=f"Internal server error: {str(e)}"
        )
 
 
@@ -283,22 +309,25 @@ async def generate_text_to_speech(request: GenerateRequest):
        device = "cuda" if torch.cuda.is_available() else "cpu"
 
        audio_generator = pipeline(
-            "text-to-speech", model=validated_body.model_name,
-            device=device
+            "text-to-speech", model=validated_body.model_name, device=device
        )
        audio = audio_generator(validated_body.input_text)
-
-
        audio_bytes = audio["audio"]
-
-        audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
-
-        return json.dumps({"audio": audio_base64, "is_end": True})
+
+        async def stream_audio():
+            audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
+            chunk_size = 1000
+            for i in range(0, len(audio_base64), chunk_size):
+                chunk = audio_base64[i: i + chunk_size]
+                yield json.dumps({"audio": chunk, "is_end": False}) + "\n"
+
+            yield json.dumps({"audio": "", "is_end": True}) + "\n"
+
+        return StreamingResponse(stream_audio(), media_type="text/plain")
 
    except Exception as e:
        raise HTTPException(
-            status_code=500,
-            detail=f"Internal server error: {str(e)}"
+            status_code=500, detail=f"Internal server error: {str(e)}"
        )
 
 
@@ -308,21 +337,26 @@ async def generate_video(request: GenerateRequest):
        validated_body = request
        device = "cuda" if torch.cuda.is_available() else "cpu"
        video_generator = pipeline(
-            "text-to-video", model=validated_body.model_name,
-            device=device
+            "text-to-video", model=validated_body.model_name, device=device
        )
        video = video_generator(validated_body.input_text)
-
-
-        video_base64 = base64.b64encode(video).decode('utf-8')
-
-        return json.dumps({"video": video_base64, "is_end": True})
+
+        async def stream_video():
+            video_base64 = base64.b64encode(video).decode("utf-8")
+            chunk_size = 1000
+            for i in range(0, len(video_base64), chunk_size):
+                chunk = video_base64[i: i + chunk_size]
+                yield json.dumps({"video": chunk, "is_end": False}) + "\n"
+
+            yield json.dumps({"video": "", "is_end": True}) + "\n"
+        return StreamingResponse(stream_video(), media_type="text/plain")
+
 
    except Exception as e:
        raise HTTPException(
-            status_code=500,
-            detail=f"Internal server error: {str(e)}"
+            status_code=500, detail=f"Internal server error: {str(e)}"
        )
 
+
 if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
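Usage note: the updated /generate endpoint streams newline-delimited JSON, one {"text": ..., "is_end": ...} object per line. Below is a minimal client sketch for consuming that stream; it assumes the server is running locally on port 7860 (as configured in the __main__ block), and the model name and prompt are placeholder values, not part of this commit.

# Hypothetical client for the streaming /generate endpoint (illustration only).
import json

import requests

payload = {
    "model_name": "gpt2",            # placeholder model, not from this commit
    "input_text": "Hello, world",    # placeholder prompt
    "task_type": "text-to-text",
}

with requests.post("http://localhost:7860/generate", json=payload, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines():
        if not line:
            continue
        chunk = json.loads(line)     # each line is {"text": ..., "is_end": ...}
        print(chunk["text"], end="", flush=True)

In this revision the text stream only emits is_end: false chunks, so a client simply reads until the response closes.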
 
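Usage note: the media endpoints now stream their payloads as base64 text in fixed-size chunks, terminated by a chunk with is_end: true. A sketch of reassembling the PNG streamed by /generate-image follows; the URL, model name, prompt, and accepted task_type value are assumptions for illustration. The audio and video generators in this commit emit the same shape under "audio" and "video" keys, so the same loop applies with a different field name.

# Hypothetical client that reassembles the base64 PNG streamed by /generate-image.
import base64
import json

import requests

payload = {
    "model_name": "some/text-to-image-model",      # placeholder
    "input_text": "a watercolor of a lighthouse",  # placeholder
    "task_type": "text-to-image",                  # assumed to pass the request validator
}

pieces = []
with requests.post("http://localhost:7860/generate-image", json=payload, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines():
        if not line:
            continue
        chunk = json.loads(line)       # {"image": <base64 piece>, "is_end": bool}
        if chunk["is_end"]:
            break
        pieces.append(chunk["image"])

with open("output.png", "wb") as f:
    f.write(base64.b64decode("".join(pieces)))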