Spaces:
Sleeping
Sleeping
Hjgugugjhuhjggg
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -121,17 +121,16 @@ async def generate(request: GenerateRequest, model_resources: tuple = Depends(ge
|
|
121 |
raise HTTPException(status_code=400, detail="Model mismatch for continuation.")
|
122 |
input_text = previous_data["output"]
|
123 |
|
124 |
-
generation_config = GenerationConfig(
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
)
|
135 |
|
136 |
generated_text = generate_text_internal(model, tokenizer, input_text, generation_config, stop_sequences)
|
137 |
|
@@ -147,19 +146,23 @@ async def generate(request: GenerateRequest, model_resources: tuple = Depends(ge
|
|
147 |
|
148 |
def generate_text_internal(model, tokenizer, input_text, generation_config, stop_sequences):
|
149 |
max_model_length = model.config.max_position_embeddings
|
150 |
-
encoded_input = tokenizer(input_text, return_tensors="pt", max_length=max_model_length, truncation=True)
|
151 |
|
152 |
stopping_criteria = StoppingCriteriaList()
|
153 |
|
154 |
-
class CustomStoppingCriteria(
|
|
|
|
|
|
|
|
|
155 |
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
156 |
-
decoded_output = tokenizer.decode(input_ids[0], skip_special_tokens=True)
|
157 |
-
for stop in stop_sequences:
|
158 |
if decoded_output.endswith(stop):
|
159 |
return True
|
160 |
return False
|
161 |
|
162 |
-
stopping_criteria.append(CustomStoppingCriteria())
|
163 |
|
164 |
outputs = model.generate(
|
165 |
encoded_input.input_ids,
|
@@ -174,7 +177,7 @@ def generate_text_internal(model, tokenizer, input_text, generation_config, stop
|
|
174 |
async def load_pipeline_from_s3(task, model_name):
|
175 |
s3_uri = f"s3://{S3_BUCKET_NAME}/{model_name.replace('/', '-')}"
|
176 |
try:
|
177 |
-
return pipeline(task, model=s3_uri)
|
178 |
except Exception as e:
|
179 |
raise HTTPException(status_code=500, detail=f"Error loading {task} model from S3: {e}")
|
180 |
|
@@ -186,9 +189,11 @@ async def generate_image(request: GenerateRequest):
|
|
186 |
|
187 |
image_generator = await load_pipeline_from_s3("text-to-image", request.model_name)
|
188 |
image = image_generator(request.input_text)[0]
|
|
|
|
|
189 |
new_continuation_id = os.urandom(16).hex()
|
190 |
-
active_generations[new_continuation_id] = {"model_name": request.model_name, "output": "Image
|
191 |
-
return JSONResponse({"url":
|
192 |
|
193 |
except HTTPException as http_err:
|
194 |
raise http_err
|
@@ -202,10 +207,12 @@ async def generate_text_to_speech(request: GenerateRequest):
|
|
202 |
raise HTTPException(status_code=400, detail="Invalid task_type for this endpoint.")
|
203 |
|
204 |
tts_pipeline = await load_pipeline_from_s3("text-to-speech", request.model_name)
|
205 |
-
|
|
|
|
|
206 |
new_continuation_id = os.urandom(16).hex()
|
207 |
-
active_generations[new_continuation_id] = {"model_name": request.model_name, "output": "Audio
|
208 |
-
return JSONResponse({"url":
|
209 |
|
210 |
except HTTPException as http_err:
|
211 |
raise http_err
|
@@ -219,10 +226,12 @@ async def generate_video(request: GenerateRequest):
|
|
219 |
raise HTTPException(status_code=400, detail="Invalid task_type for this endpoint.")
|
220 |
|
221 |
video_pipeline = await load_pipeline_from_s3("text-to-video", request.model_name)
|
222 |
-
|
|
|
|
|
223 |
new_continuation_id = os.urandom(16).hex()
|
224 |
-
active_generations[new_continuation_id] = {"model_name": request.model_name, "output": "Video
|
225 |
-
return JSONResponse({"url":
|
226 |
|
227 |
except HTTPException as http_err:
|
228 |
raise http_err
|
|
|
121 |
raise HTTPException(status_code=400, detail="Model mismatch for continuation.")
|
122 |
input_text = previous_data["output"]
|
123 |
|
124 |
+
generation_config = GenerationConfig.from_pretrained(model_name) # Load default config and override
|
125 |
+
generation_config.temperature = temperature
|
126 |
+
generation_config.max_new_tokens = max_new_tokens
|
127 |
+
generation_config.top_p = top_p
|
128 |
+
generation_config.top_k = top_k
|
129 |
+
generation_config.repetition_penalty = repetition_penalty
|
130 |
+
generation_config.do_sample = do_sample
|
131 |
+
generation_config.num_return_sequences = num_return_sequences
|
132 |
+
generation_config.no_repeat_ngram_size = no_repeat_ngram_size
|
133 |
+
generation_config.pad_token_id = tokenizer.pad_token_id
|
|
|
134 |
|
135 |
generated_text = generate_text_internal(model, tokenizer, input_text, generation_config, stop_sequences)
|
136 |
|
|
|
146 |
|
147 |
def generate_text_internal(model, tokenizer, input_text, generation_config, stop_sequences):
|
148 |
max_model_length = model.config.max_position_embeddings
|
149 |
+
encoded_input = tokenizer(input_text, return_tensors="pt", max_length=max_model_length, truncation=True).to(model.device) # Ensure input is on the same device as the model
|
150 |
|
151 |
stopping_criteria = StoppingCriteriaList()
|
152 |
|
153 |
+
class CustomStoppingCriteria(StoppingCriteria): # Inherit directly from StoppingCriteria
|
154 |
+
def __init__(self, stop_sequences, tokenizer):
|
155 |
+
self.stop_sequences = stop_sequences
|
156 |
+
self.tokenizer = tokenizer
|
157 |
+
|
158 |
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
159 |
+
decoded_output = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
|
160 |
+
for stop in self.stop_sequences:
|
161 |
if decoded_output.endswith(stop):
|
162 |
return True
|
163 |
return False
|
164 |
|
165 |
+
stopping_criteria.append(CustomStoppingCriteria(stop_sequences, tokenizer))
|
166 |
|
167 |
outputs = model.generate(
|
168 |
encoded_input.input_ids,
|
|
|
177 |
async def load_pipeline_from_s3(task, model_name):
|
178 |
s3_uri = f"s3://{S3_BUCKET_NAME}/{model_name.replace('/', '-')}"
|
179 |
try:
|
180 |
+
return pipeline(task, model=s3_uri, token=HUGGINGFACE_HUB_TOKEN) # Include token if needed
|
181 |
except Exception as e:
|
182 |
raise HTTPException(status_code=500, detail=f"Error loading {task} model from S3: {e}")
|
183 |
|
|
|
189 |
|
190 |
image_generator = await load_pipeline_from_s3("text-to-image", request.model_name)
|
191 |
image = image_generator(request.input_text)[0]
|
192 |
+
image_path = f"generated_image_{os.urandom(8).hex()}.png" # Save image locally
|
193 |
+
image.save(image_path)
|
194 |
new_continuation_id = os.urandom(16).hex()
|
195 |
+
active_generations[new_continuation_id] = {"model_name": request.model_name, "output": f"Image saved to {image_path}"} # Return path or upload URL
|
196 |
+
return JSONResponse({"url": image_path, "continuation_id": new_continuation_id, "model_name": request.model_name})
|
197 |
|
198 |
except HTTPException as http_err:
|
199 |
raise http_err
|
|
|
207 |
raise HTTPException(status_code=400, detail="Invalid task_type for this endpoint.")
|
208 |
|
209 |
tts_pipeline = await load_pipeline_from_s3("text-to-speech", request.model_name)
|
210 |
+
audio_output = tts_pipeline(request.input_text)
|
211 |
+
audio_path = f"generated_audio_{os.urandom(8).hex()}.wav"
|
212 |
+
sf.write(audio_path, audio_output["sampling_rate"], audio_output["audio"])
|
213 |
new_continuation_id = os.urandom(16).hex()
|
214 |
+
active_generations[new_continuation_id] = {"model_name": request.model_name, "output": f"Audio saved to {audio_path}"}
|
215 |
+
return JSONResponse({"url": audio_path, "continuation_id": new_continuation_id, "model_name": request.model_name})
|
216 |
|
217 |
except HTTPException as http_err:
|
218 |
raise http_err
|
|
|
226 |
raise HTTPException(status_code=400, detail="Invalid task_type for this endpoint.")
|
227 |
|
228 |
video_pipeline = await load_pipeline_from_s3("text-to-video", request.model_name)
|
229 |
+
video_frames = video_pipeline(request.input_text).frames
|
230 |
+
video_path = f"generated_video_{os.urandom(8).hex()}.mp4"
|
231 |
+
imageio.mimsave(video_path, video_frames, fps=30) # Adjust fps as needed
|
232 |
new_continuation_id = os.urandom(16).hex()
|
233 |
+
active_generations[new_continuation_id] = {"model_name": request.model_name, "output": f"Video saved to {video_path}"}
|
234 |
+
return JSONResponse({"url": video_path, "continuation_id": new_continuation_id, "model_name": request.model_name})
|
235 |
|
236 |
except HTTPException as http_err:
|
237 |
raise http_err
|