Spaces:
Sleeping
Sleeping
Hjgugugjhuhjggg
committed on
Update app.py
Browse files
app.py
CHANGED
@@ -6,7 +6,6 @@ from fastapi import FastAPI, HTTPException
|
|
6 |
from fastapi.responses import JSONResponse
|
7 |
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
8 |
import asyncio
|
9 |
-
import concurrent.futures
|
10 |
|
11 |
logging.basicConfig(level=logging.INFO)
|
12 |
logger = logging.getLogger(__name__)
|
@@ -16,7 +15,7 @@ AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
|
|
16 |
AWS_REGION = os.getenv("AWS_REGION")
|
17 |
S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
|
18 |
|
19 |
-
MAX_TOKENS = 1024
|
20 |
|
21 |
s3_client = boto3.client(
|
22 |
's3',
|
@@ -84,7 +83,6 @@ class S3DirectStream:
|
|
84 |
if not model_files:
|
85 |
raise HTTPException(status_code=404, detail=f"Archivos del modelo {model_name} no encontrados en S3.")
|
86 |
|
87 |
-
# Verificar que existe el archivo config.json
|
88 |
config_stream = await self.stream_from_s3(f"{model_prefix}/config.json")
|
89 |
config_data = config_stream.read()
|
90 |
|
@@ -148,7 +146,7 @@ def continue_generation(input_text, model, tokenizer, max_tokens=MAX_TOKENS):
|
|
148 |
input_text = tokenizer.decode(tokens[:max_tokens])
|
149 |
output = model.generate(input_ids=tokenizer.encode(input_text, return_tensors="pt").input_ids)
|
150 |
generated_text += tokenizer.decode(output[0], skip_special_tokens=True)
|
151 |
-
input_text = input_text[len(input_text):]
|
152 |
return generated_text
|
153 |
|
154 |
@app.post("/predict/")
|
@@ -163,7 +161,7 @@ async def predict(model_request: dict):
|
|
163 |
|
164 |
streamer = S3DirectStream(S3_BUCKET_NAME)
|
165 |
|
166 |
-
await streamer.create_s3_folders(model_name)
|
167 |
|
168 |
model = await streamer.load_model_from_s3(model_name)
|
169 |
tokenizer = await streamer.load_tokenizer_from_s3(model_name)
|
|
|
6 |
from fastapi.responses import JSONResponse
|
7 |
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
8 |
import asyncio
|
|
|
9 |
|
10 |
logging.basicConfig(level=logging.INFO)
|
11 |
logger = logging.getLogger(__name__)
|
|
|
15 |
AWS_REGION = os.getenv("AWS_REGION")
|
16 |
S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
|
17 |
|
18 |
+
MAX_TOKENS = 1024
|
19 |
|
20 |
s3_client = boto3.client(
|
21 |
's3',
|
|
|
83 |
if not model_files:
|
84 |
raise HTTPException(status_code=404, detail=f"Archivos del modelo {model_name} no encontrados en S3.")
|
85 |
|
|
|
86 |
config_stream = await self.stream_from_s3(f"{model_prefix}/config.json")
|
87 |
config_data = config_stream.read()
|
88 |
|
|
|
146 |
input_text = tokenizer.decode(tokens[:max_tokens])
|
147 |
output = model.generate(input_ids=tokenizer.encode(input_text, return_tensors="pt").input_ids)
|
148 |
generated_text += tokenizer.decode(output[0], skip_special_tokens=True)
|
149 |
+
input_text = input_text[len(input_text):]
|
150 |
return generated_text
|
151 |
|
152 |
@app.post("/predict/")
|
|
|
161 |
|
162 |
streamer = S3DirectStream(S3_BUCKET_NAME)
|
163 |
|
164 |
+
await streamer.create_s3_folders(model_name)
|
165 |
|
166 |
model = await streamer.load_model_from_s3(model_name)
|
167 |
tokenizer = await streamer.load_tokenizer_from_s3(model_name)
|