Hjgugugjhuhjggg committed
Commit fc5872a · verified · 1 Parent(s): ecffbb4

Update app.py

Files changed (1)
  1. app.py +78 -67
app.py CHANGED
@@ -1,13 +1,13 @@
 import os
-import torch
-from fastapi import FastAPI, HTTPException
-from fastapi.responses import JSONResponse, StreamingResponse
+from fastapi import FastAPI, HTTPException, Depends
+from fastapi.responses import JSONResponse
 from pydantic import BaseModel, field_validator
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, StoppingCriteriaList
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, StoppingCriteriaList, pipeline
 import boto3
 import uvicorn
-from io import BytesIO
-from transformers import pipeline
+import soundfile as sf
+import imageio
+from typing import Dict
 
 AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
 AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
@@ -15,6 +15,9 @@ AWS_REGION = os.getenv("AWS_REGION")
 S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
 HUGGINGFACE_HUB_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")
 
+if not all([AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION, S3_BUCKET_NAME]):
+    raise ValueError("Missing one or more AWS environment variables.")
+
 s3_client = boto3.client('s3', aws_access_key_id=AWS_ACCESS_KEY_ID, aws_secret_access_key=AWS_SECRET_ACCESS_KEY, region_name=AWS_REGION)
 
 app = FastAPI()
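
Note: the new guard fails fast on missing AWS settings, but HUGGINGFACE_HUB_TOKEN is read and never used anywhere in the file. If gated or private Hub models are expected, one option is to forward it to from_pretrained; the token kwarg is standard transformers, but wiring it in here is an assumption, not part of this commit:

# Sketch: forward the Hub token so gated/private models resolve.
# token=None is harmless for public models.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    token=HUGGINGFACE_HUB_TOKEN,
)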
@@ -39,6 +42,7 @@ class GenerateRequest(BaseModel):
     do_sample: bool = True
     stop_sequences: list[str] = []
     no_repeat_ngram_size: int = 2
+    continuation_id: str = None
 
     @field_validator("model_name")
     def model_name_cannot_be_empty(cls, v):
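
Note: continuation_id: str = None only works because Pydantic v2 skips validation of defaults; an explicit Optional annotation states the same intent without relying on that. A minimal sketch, trimmed to the fields visible in this diff:

from typing import Optional
from pydantic import BaseModel

class GenerateRequest(BaseModel):
    model_name: str
    input_text: str
    task_type: str
    # Optional[...] makes the None default type-correct instead of
    # relying on Pydantic not validating defaults.
    continuation_id: Optional[str] = None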
@@ -70,37 +74,33 @@ class S3ModelLoader:
     async def load_model_and_tokenizer(self, model_name):
         s3_uri = self._get_s3_uri(model_name)
         try:
-            config = AutoConfig.from_pretrained(s3_uri, local_files_only=True)
-            model = AutoModelForCausalLM.from_pretrained(s3_uri, config=config, local_files_only=True)
-            tokenizer = AutoTokenizer.from_pretrained(s3_uri, config=config, local_files_only=True)
+            config = AutoConfig.from_pretrained(s3_uri, local_files_only=False)
+            model = AutoModelForCausalLM.from_pretrained(s3_uri, config=config, local_files_only=False)
+            tokenizer = AutoTokenizer.from_pretrained(s3_uri, config=config, local_files_only=False)
             tokenizer.add_special_tokens(SPECIAL_TOKENS)
             model.resize_token_embeddings(len(tokenizer))
             if tokenizer.pad_token_id is None:
                 tokenizer.pad_token_id = tokenizer.eos_token_id
             return model, tokenizer
-        except EnvironmentError:
-            try:
-                config = AutoConfig.from_pretrained(model_name)
-                tokenizer = AutoTokenizer.from_pretrained(model_name, config=config)
-                tokenizer.add_special_tokens(SPECIAL_TOKENS)
-                model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
-                model.resize_token_embeddings(len(tokenizer))
-                if tokenizer.pad_token_id is None:
-                    tokenizer.pad_token_id = tokenizer.eos_token_id
-                model.save_pretrained(s3_uri)
-                tokenizer.save_pretrained(s3_uri)
-                return model, tokenizer
-            except Exception as e:
-                raise HTTPException(status_code=500, detail=f"Error loading model: {e}")
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=f"Error loading model from S3: {e}")
 
 model_loader = S3ModelLoader(S3_BUCKET_NAME, s3_client)
 
+active_generations: Dict[str, Dict] = {}
+
+async def get_model_and_tokenizer(model_name: str):
+    try:
+        return await model_loader.load_model_and_tokenizer(model_name)
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error loading model: {e}")
+
 @app.post("/generate")
-async def generate(request: GenerateRequest):
+async def generate(request: GenerateRequest, model_resources: tuple = Depends(get_model_and_tokenizer)):
+    model, tokenizer = model_resources
     try:
         model_name = request.model_name
         input_text = request.input_text
-        task_type = request.task_type
         temperature = request.temperature
         max_new_tokens = request.max_new_tokens
         top_p = request.top_p
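
Note: because the new get_model_and_tokenizer dependency declares a bare model_name: str parameter, FastAPI reads it from the query string, so /generate now effectively requires ?model_name=... in addition to the JSON body, and the two values can disagree. A hedged sketch of a dependency that shares the body instead (this rewrite is a suggestion, not part of the commit):

# Declaring the Pydantic model in the dependency makes FastAPI read
# model_name from the JSON body; FastAPI parses the shared body model
# once for both the dependency and the endpoint.
async def get_model_and_tokenizer(request: GenerateRequest):
    return await model_loader.load_model_and_tokenizer(request.model_name)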
@@ -110,10 +110,13 @@ async def generate(request: GenerateRequest):
         do_sample = request.do_sample
         stop_sequences = request.stop_sequences
         no_repeat_ngram_size = request.no_repeat_ngram_size
+        continuation_id = request.continuation_id
 
-        model, tokenizer = await model_loader.load_model_and_tokenizer(model_name)
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-        model.to(device)
+        if continuation_id:
+            if continuation_id not in active_generations:
+                raise HTTPException(status_code=404, detail="Continuation ID not found.")
+            previous_output = active_generations[continuation_id]["output"]
+            input_text = previous_output
 
         generation_config = GenerationConfig(
             temperature=temperature,
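
For reference, the intended round trip with continuation_id looks roughly like this hypothetical client sketch (host, port, and model name are assumptions; the required fields follow the GenerateRequest schema above):

import requests

BASE = "http://localhost:8000"  # assumed host/port

# First call: no continuation_id, so the server mints and returns one.
first = requests.post(f"{BASE}/generate", json={
    "model_name": "gpt2",              # hypothetical model in the bucket
    "input_text": "Once upon a time",
    "task_type": "text-generation",
}).json()

# Second call: the server swaps input_text for the stored output.
second = requests.post(f"{BASE}/generate", json={
    "model_name": "gpt2",
    "input_text": "unused",            # ignored when continuation_id is set
    "task_type": "text-generation",
    "continuation_id": first["continuation_id"],
}).json()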
@@ -127,15 +130,24 @@ async def generate(request: GenerateRequest):
             pad_token_id=tokenizer.pad_token_id
         )
 
-        generated_text = generate_text(model, tokenizer, input_text, generation_config, stop_sequences, device)
-        return JSONResponse({"text": generated_text})
+        generated_text = generate_text_internal(model, tokenizer, input_text, generation_config, stop_sequences)
+
+        if not continuation_id:
+            continuation_id = os.urandom(16).hex()
+            active_generations[continuation_id] = {"model_name": model_name, "output": generated_text}
+        else:
+            active_generations[continuation_id]["output"] = generated_text
 
+        return JSONResponse({"text": generated_text, "continuation_id": continuation_id})
+
+    except HTTPException as http_err:
+        raise http_err
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
-def generate_text(model, tokenizer, input_text, generation_config, stop_sequences, device):
+def generate_text_internal(model, tokenizer, input_text, generation_config, stop_sequences):
     max_model_length = model.config.max_position_embeddings
-    encoded_input = tokenizer(input_text, return_tensors="pt", max_length=max_model_length, truncation=True).to(device)
+    encoded_input = tokenizer(input_text, return_tensors="pt", max_length=max_model_length, truncation=True)
 
     stopping_criteria = StoppingCriteriaList()
 
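Note: active_generations is a plain in-process dict, so it grows without bound and is not shared between uvicorn workers. A minimal sketch of a size-capped variant, with a hypothetical helper name and cap:

from collections import OrderedDict

MAX_ACTIVE = 1000  # assumed cap

def remember(store: OrderedDict, key: str, value: dict) -> None:
    # Insert or update, then evict the oldest entries past the cap.
    store[key] = value
    store.move_to_end(key)
    while len(store) > MAX_ACTIVE:
        store.popitem(last=False)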
@@ -159,62 +171,61 @@ def generate_text(model, tokenizer, input_text, generation_config, stop_sequence
     generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
     return generated_text
 
+async def load_pipeline_from_s3(task, model_name):
+    s3_uri = f"s3://{S3_BUCKET_NAME}/{model_name.replace('/', '-')}"
+    try:
+        return pipeline(task, model=s3_uri)
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error loading {task} model from S3: {e}")
+
 @app.post("/generate-image")
 async def generate_image(request: GenerateRequest):
     try:
-        validated_body = request
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-
-        image_generator = pipeline("text-to-image", model=validated_body.model_name, device=device)
-        image = image_generator(validated_body.input_text)[0]
-
-        img_byte_arr = BytesIO()
-        image.save(img_byte_arr, format="PNG")
-        img_byte_arr.seek(0)
+        if request.task_type != "text-to-image":
+            raise HTTPException(status_code=400, detail="Invalid task_type for this endpoint.")
 
-        return StreamingResponse(img_byte_arr, media_type="image/png")
+        image_generator = await load_pipeline_from_s3("text-to-image", request.model_name)
+        image = image_generator(request.input_text)[0]
+        continuation_id = os.urandom(16).hex()
+        active_generations[continuation_id] = {"model_name": request.model_name, "output": "Image generated successfully"}
+        return JSONResponse({"url": "Image generated successfully", "continuation_id": continuation_id})
 
+    except HTTPException as http_err:
+        raise http_err
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
 @app.post("/generate-text-to-speech")
 async def generate_text_to_speech(request: GenerateRequest):
     try:
-        validated_body = request
-        device = "cuda" if torch.cuda.is_available() else "cpu"
+        if request.task_type != "text-to-speech":
+            raise HTTPException(status_code=400, detail="Invalid task_type for this endpoint.")
 
-        audio_generator = pipeline("text-to-speech", model=validated_body.model_name, device=device)
-        audio = audio_generator(validated_body.input_text)[0]
-
-        audio_byte_arr = BytesIO()
-        # It is expected that the audio is saved as wav.
-        # Saving like this will not always work. Please check how your
-        # audio_generator model is working.
-        audio_generator.save_audio(audio_byte_arr, audio)
-        audio_byte_arr.seek(0)
-
-        return StreamingResponse(audio_byte_arr, media_type="audio/wav")
+        tts_pipeline = await load_pipeline_from_s3("text-to-speech", request.model_name)
+        output = tts_pipeline(request.input_text)
+        continuation_id = os.urandom(16).hex()
+        active_generations[continuation_id] = {"model_name": request.model_name, "output": "Audio generated successfully"}
+        return JSONResponse({"url": "Audio generated successfully", "continuation_id": continuation_id})
 
+    except HTTPException as http_err:
+        raise http_err
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
 @app.post("/generate-video")
 async def generate_video(request: GenerateRequest):
     try:
-        validated_body = request
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-        video_generator = pipeline("text-to-video", model=validated_body.model_name, device=device)
-        video = video_generator(validated_body.input_text)[0]
-
-        video_byte_arr = BytesIO()
-        # Same as above. Please check how your video model is returning the
-        # videos and save them accordingly.
-        # It is expected that the video is saved as MP4
-        video_generator.save_video(video_byte_arr, video)
-        video_byte_arr.seek(0)
+        if request.task_type != "text-to-video":
+            raise HTTPException(status_code=400, detail="Invalid task_type for this endpoint.")
 
-        return StreamingResponse(video_byte_arr, media_type="video/mp4")
+        video_pipeline = await load_pipeline_from_s3("text-to-video", request.model_name)
+        output = video_pipeline(request.input_text)
+        continuation_id = os.urandom(16).hex()
+        active_generations[continuation_id] = {"model_name": request.model_name, "output": "Video generated successfully"}
+        return JSONResponse({"url": "Video generated successfully", "continuation_id": continuation_id})
 
+    except HTTPException as http_err:
+        raise http_err
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
231