Hjgugugjhuhjggg committed: Update app.py

app.py CHANGED
@@ -14,6 +14,7 @@ from transformers import (
 )
 import boto3
 import uvicorn
+import asyncio
 from transformers import pipeline
 import json
 from huggingface_hub import login
@@ -77,7 +78,7 @@ class S3ModelLoader:
         return f"s3://{self.bucket_name}/" \
                f"{model_name.replace('/', '-')}"
 
-    def load_model_and_tokenizer(self, model_name):
+    async def load_model_and_tokenizer(self, model_name):
         if model_name in model_cache:
             return model_cache[model_name]
 
@@ -147,7 +148,7 @@ class S3ModelLoader:
 model_loader = S3ModelLoader(S3_BUCKET_NAME, s3_client)
 
 @app.post("/generate")
-def generate(request: GenerateRequest):
+async def generate(request: GenerateRequest):
     try:
         model_name = request.model_name
         input_text = request.input_text
@@ -162,7 +163,7 @@ def generate(request: GenerateRequest):
         do_sample = request.do_sample
         stop_sequences = request.stop_sequences
 
-        model, tokenizer = model_loader.load_model_and_tokenizer(model_name)
+        model, tokenizer = await model_loader.load_model_and_tokenizer(model_name)
         device = "cuda" if torch.cuda.is_available() else "cpu"
         model.to(device)
 
@@ -185,7 +186,7 @@ def generate(request: GenerateRequest):
                 media_type="text/plain"
             )
         else:
-            result = generate_text(model, tokenizer, input_text,
+            result = await generate_text(model, tokenizer, input_text,
                                    generation_config, stop_sequences,
                                    device)
             return JSONResponse({"text": result, "is_end": True})
@@ -212,7 +213,7 @@ class StopOnSequences(StoppingCriteria):
                 return True
         return False
 
-def stream_text(model, tokenizer, input_text,
+async def stream_text(model, tokenizer, input_text,
                 generation_config, stop_sequences,
                 device):
 
@@ -228,7 +229,7 @@ def stream_text(model, tokenizer, input_text,
 
     while True:
 
-        outputs = model.generate(
+        outputs = await asyncio.to_thread(model.generate,
             **encoded_input,
             do_sample=generation_config.do_sample,
             max_new_tokens=generation_config.max_new_tokens,
@@ -270,7 +271,7 @@ def stream_text(model, tokenizer, input_text,
             output_text = ""
 
 
-def generate_text(model, tokenizer, input_text,
+async def generate_text(model, tokenizer, input_text,
                   generation_config, stop_sequences,
                   device):
     encoded_input = tokenizer(
@@ -281,7 +282,7 @@ def generate_text(model, tokenizer, input_text,
     stop_criteria = StopOnSequences(stop_sequences, tokenizer)
     stopping_criteria = StoppingCriteriaList([stop_criteria])
 
-    outputs = model.generate(
+    outputs = await asyncio.to_thread(model.generate,
         **encoded_input,
         do_sample=generation_config.do_sample,
         max_new_tokens=generation_config.max_new_tokens,
@@ -302,8 +303,9 @@ def generate_text(model, tokenizer, input_text,
 
     return generated_text
 
+
 @app.post("/generate-image")
-def generate_image(request: GenerateRequest):
+async def generate_image(request: GenerateRequest):
     try:
         validated_body = request
         device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -331,7 +333,7 @@ def generate_image(request: GenerateRequest):
 
 
 @app.post("/generate-text-to-speech")
-def generate_text_to_speech(request: GenerateRequest):
+async def generate_text_to_speech(request: GenerateRequest):
     try:
         validated_body = request
         device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -362,7 +364,7 @@ def generate_text_to_speech(request: GenerateRequest):
 
 
 @app.post("/generate-video")
-def generate_video(request: GenerateRequest):
+async def generate_video(request: GenerateRequest):
     try:
         validated_body = request
         device = "cuda" if torch.cuda.is_available() else "cpu"
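
Every hunk applies the same conversion: the endpoints and helpers become async def, the call sites gain await, and the blocking model.generate calls are moved off the event loop with asyncio.to_thread (hence the new import asyncio). Below is a minimal, self-contained sketch of that pattern; slow_generate is a hypothetical stand-in for the app's model.generate call, not part of the commit itself.

import asyncio
import time

from fastapi import FastAPI

app = FastAPI()


def slow_generate(prompt: str) -> str:
    # Stand-in for a blocking, GPU-bound call such as model.generate(...).
    time.sleep(2)
    return f"echo: {prompt}"


@app.post("/generate")
async def generate(prompt: str):
    # asyncio.to_thread (Python 3.9+) runs the blocking call in a worker
    # thread; the event loop keeps serving other requests in the meantime.
    result = await asyncio.to_thread(slow_generate, prompt)
    return {"text": result, "is_end": True}

Assuming the sketch is saved as sketch.py, it can be served with uvicorn sketch:app and exercised with concurrent POSTs to /generate to observe that requests overlap instead of queueing.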
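A design note on why the offload matters: FastAPI runs plain def endpoints in a worker thread pool, but async def endpoints run directly on the event loop, so a blocking call left inline there (like a long model.generate) would stall every concurrent request until it finished. asyncio.to_thread also forwards its extra positional and keyword arguments to the target function, which is why the existing **encoded_input and generation keyword arguments pass through unchanged.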