Hjgugugjhuhjggg committed: Update app.py
app.py
CHANGED
@@ -1,13 +1,13 @@
 import os
-import
-from fastapi import FastAPI, HTTPException
-from fastapi.responses import JSONResponse, StreamingResponse
+from fastapi import FastAPI, HTTPException, Depends
+from fastapi.responses import JSONResponse
 from pydantic import BaseModel, field_validator
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, StoppingCriteriaList
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, StoppingCriteriaList, pipeline
 import boto3
 import uvicorn
-
-
+import soundfile as sf
+import imageio
+from typing import Dict
 
 AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
 AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
@@ -15,6 +15,9 @@ AWS_REGION = os.getenv("AWS_REGION")
 S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
 HUGGINGFACE_HUB_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")
 
+if not all([AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION, S3_BUCKET_NAME]):
+    raise ValueError("Missing one or more AWS environment variables.")
+
 s3_client = boto3.client('s3', aws_access_key_id=AWS_ACCESS_KEY_ID, aws_secret_access_key=AWS_SECRET_ACCESS_KEY, region_name=AWS_REGION)
 
 app = FastAPI()
@@ -39,6 +42,7 @@ class GenerateRequest(BaseModel):
     do_sample: bool = True
     stop_sequences: list[str] = []
     no_repeat_ngram_size: int = 2
+    continuation_id: str = None
 
     @field_validator("model_name")
     def model_name_cannot_be_empty(cls, v):
@@ -70,37 +74,33 @@ class S3ModelLoader:
     async def load_model_and_tokenizer(self, model_name):
         s3_uri = self._get_s3_uri(model_name)
         try:
-            config = AutoConfig.from_pretrained(s3_uri, local_files_only=True)
-            model = AutoModelForCausalLM.from_pretrained(s3_uri, config=config, local_files_only=True)
-            tokenizer = AutoTokenizer.from_pretrained(s3_uri, config=config, local_files_only=True)
+            config = AutoConfig.from_pretrained(s3_uri, local_files_only=False)
+            model = AutoModelForCausalLM.from_pretrained(s3_uri, config=config, local_files_only=False)
+            tokenizer = AutoTokenizer.from_pretrained(s3_uri, config=config, local_files_only=False)
             tokenizer.add_special_tokens(SPECIAL_TOKENS)
             model.resize_token_embeddings(len(tokenizer))
             if tokenizer.pad_token_id is None:
                 tokenizer.pad_token_id = tokenizer.eos_token_id
             return model, tokenizer
-        except
-
-            config = AutoConfig.from_pretrained(model_name)
-            tokenizer = AutoTokenizer.from_pretrained(model_name, config=config)
-            tokenizer.add_special_tokens(SPECIAL_TOKENS)
-            model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
-            model.resize_token_embeddings(len(tokenizer))
-            if tokenizer.pad_token_id is None:
-                tokenizer.pad_token_id = tokenizer.eos_token_id
-            model.save_pretrained(s3_uri)
-            tokenizer.save_pretrained(s3_uri)
-            return model, tokenizer
-        except Exception as e:
-            raise HTTPException(status_code=500, detail=f"Error loading model: {e}")
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=f"Error loading model from S3: {e}")
 
 model_loader = S3ModelLoader(S3_BUCKET_NAME, s3_client)
 
+active_generations: Dict[str, Dict] = {}
+
+async def get_model_and_tokenizer(model_name: str):
+    try:
+        return await model_loader.load_model_and_tokenizer(model_name)
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error loading model: {e}")
+
 @app.post("/generate")
-async def generate(request: GenerateRequest):
+async def generate(request: GenerateRequest, model_resources: tuple = Depends(get_model_and_tokenizer)):
+    model, tokenizer = model_resources
     try:
         model_name = request.model_name
         input_text = request.input_text
-        task_type = request.task_type
         temperature = request.temperature
         max_new_tokens = request.max_new_tokens
        top_p = request.top_p
@@ -110,10 +110,13 @@ async def generate(request: GenerateRequest):
        do_sample = request.do_sample
        stop_sequences = request.stop_sequences
        no_repeat_ngram_size = request.no_repeat_ngram_size
+        continuation_id = request.continuation_id
 
-
-
-
+        if continuation_id:
+            if continuation_id not in active_generations:
+                raise HTTPException(status_code=404, detail="Continuation ID not found.")
+            previous_output = active_generations[continuation_id]["output"]
+            input_text = previous_output
 
         generation_config = GenerationConfig(
             temperature=temperature,
@@ -127,15 +130,24 @@ async def generate(request: GenerateRequest):
             pad_token_id=tokenizer.pad_token_id
         )
 
-        generated_text = generate_text(model, tokenizer, input_text, generation_config, stop_sequences)
-
+        generated_text = generate_text_internal(model, tokenizer, input_text, generation_config, stop_sequences)
+
+        if not continuation_id:
+            continuation_id = os.urandom(16).hex()
+            active_generations[continuation_id] = {"model_name": model_name, "output": generated_text}
+        else:
+            active_generations[continuation_id]["output"] = generated_text
 
+        return JSONResponse({"text": generated_text, "continuation_id": continuation_id})
+
+    except HTTPException as http_err:
+        raise http_err
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
-def generate_text(model, tokenizer, input_text, generation_config, stop_sequences):
+def generate_text_internal(model, tokenizer, input_text, generation_config, stop_sequences):
     max_model_length = model.config.max_position_embeddings
-    encoded_input = tokenizer(input_text, return_tensors="pt", max_length=max_model_length, truncation=True)
+    encoded_input = tokenizer(input_text, return_tensors="pt", max_length=max_model_length, truncation=True)
 
     stopping_criteria = StoppingCriteriaList()
 
@@ -159,62 +171,61 @@ def generate_text(model, tokenizer, input_text, generation_config, stop_sequences):
     generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
     return generated_text
 
+async def load_pipeline_from_s3(task, model_name):
+    s3_uri = f"s3://{S3_BUCKET_NAME}/{model_name.replace('/', '-')}"
+    try:
+        return pipeline(task, model=s3_uri)
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error loading {task} model from S3: {e}")
+
 @app.post("/generate-image")
 async def generate_image(request: GenerateRequest):
     try:
-
-
-
-        image_generator = pipeline("text-to-image", model=validated_body.model_name, device=device)
-        image = image_generator(validated_body.input_text)[0]
-
-        img_byte_arr = BytesIO()
-        image.save(img_byte_arr, format="PNG")
-        img_byte_arr.seek(0)
+        if request.task_type != "text-to-image":
+            raise HTTPException(status_code=400, detail="Invalid task_type for this endpoint.")
 
-
+        image_generator = await load_pipeline_from_s3("text-to-image", request.model_name)
+        image = image_generator(request.input_text)[0]
+        continuation_id = os.urandom(16).hex()
+        active_generations[continuation_id] = {"model_name": request.model_name, "output": "Image generated successfully"}
+        return JSONResponse({"url": "Image generated successfully", "continuation_id": continuation_id})
 
+    except HTTPException as http_err:
+        raise http_err
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
 @app.post("/generate-text-to-speech")
 async def generate_text_to_speech(request: GenerateRequest):
     try:
-
-
+        if request.task_type != "text-to-speech":
+            raise HTTPException(status_code=400, detail="Invalid task_type for this endpoint.")
 
-
-
-
-
-
-        # Saving like this will not always work. Please check how your
-        # audio_generator model is working.
-        audio_generator.save_audio(audio_byte_arr, audio)
-        audio_byte_arr.seek(0)
-
-        return StreamingResponse(audio_byte_arr, media_type="audio/wav")
+        tts_pipeline = await load_pipeline_from_s3("text-to-speech", request.model_name)
+        output = tts_pipeline(request.input_text)
+        continuation_id = os.urandom(16).hex()
+        active_generations[continuation_id] = {"model_name": request.model_name, "output": "Audio generated successfully"}
+        return JSONResponse({"url": "Audio generated successfully", "continuation_id": continuation_id})
 
+    except HTTPException as http_err:
+        raise http_err
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
 @app.post("/generate-video")
 async def generate_video(request: GenerateRequest):
     try:
-
-
-        video_generator = pipeline("text-to-video", model=validated_body.model_name, device=device)
-        video = video_generator(validated_body.input_text)[0]
-
-        video_byte_arr = BytesIO()
-        # Same as above. Please check how your video model is returning the
-        # videos and save them accordingly.
-        # It is expected that the video is saved as MP4
-        video_generator.save_video(video_byte_arr, video)
-        video_byte_arr.seek(0)
+        if request.task_type != "text-to-video":
+            raise HTTPException(status_code=400, detail="Invalid task_type for this endpoint.")
 
-
+        video_pipeline = await load_pipeline_from_s3("text-to-video", request.model_name)
+        output = video_pipeline(request.input_text)
+        continuation_id = os.urandom(16).hex()
+        active_generations[continuation_id] = {"model_name": request.model_name, "output": "Video generated successfully"}
+        return JSONResponse({"url": "Video generated successfully", "continuation_id": continuation_id})
 
+    except HTTPException as http_err:
+        raise http_err
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
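A minimal sketch of how a client might exercise the new continuation flow in /generate. Everything not shown in the diff is an assumption: the base URL and model name are placeholders (Spaces commonly serve on port 7860), and the task_type value is a guess since /generate no longer checks it. Note that because the new Depends(get_model_and_tokenizer) dependency declares model_name as a plain parameter, FastAPI reads it from the query string, so it is passed there as well as in the body.

import requests

BASE_URL = "http://localhost:7860"  # placeholder; substitute the Space's URL
MODEL = "gpt2"                      # hypothetical model stored in the S3 bucket

# First request: no continuation_id, so per the diff the server mints one
# with os.urandom(16).hex() and stores the generated text under it.
first = requests.post(
    f"{BASE_URL}/generate",
    params={"model_name": MODEL},   # consumed by the Depends(...) dependency
    json={"model_name": MODEL, "input_text": "Once upon a time", "task_type": "text-generation"},
).json()

# Follow-up request: sending the returned continuation_id makes the server
# replace input_text with the stored output and continue from it.
second = requests.post(
    f"{BASE_URL}/generate",
    params={"model_name": MODEL},
    json={
        "model_name": MODEL,
        "input_text": "",           # ignored when continuation_id is set
        "task_type": "text-generation",
        "continuation_id": first["continuation_id"],
    },
).json()
print(second["text"])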
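The media endpoints keep the same request shape, but each now rejects a mismatched task_type with a 400 before loading its pipeline from S3. A sketch of an image request under the same placeholder assumptions:

import requests

resp = requests.post(
    "http://localhost:7860/generate-image",       # placeholder base URL
    json={
        "model_name": "some-text-to-image-model", # hypothetical; expected under s3://<bucket>/ with '/' replaced by '-'
        "input_text": "a watercolor lighthouse at dusk",
        "task_type": "text-to-image",             # must match, or the endpoint returns 400
    },
)
print(resp.json())  # per the diff: {"url": "Image generated successfully", "continuation_id": "..."}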