Hjgugugjhuhjggg committed
Commit 059d70c (verified)
1 Parent(s): a9f88be

Update app.py

Files changed (1)
  1. app.py +3 -5
app.py CHANGED
@@ -6,7 +6,6 @@ from fastapi import FastAPI, HTTPException
 from fastapi.responses import JSONResponse
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 import asyncio
-import concurrent.futures
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -16,7 +15,7 @@ AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
 AWS_REGION = os.getenv("AWS_REGION")
 S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
 
-MAX_TOKENS = 1024  # Token limit per chunk
+MAX_TOKENS = 1024
 
 s3_client = boto3.client(
     's3',
@@ -84,7 +83,6 @@ class S3DirectStream:
         if not model_files:
             raise HTTPException(status_code=404, detail=f"Archivos del modelo {model_name} no encontrados en S3.")
 
-        # Verify that the config.json file exists
         config_stream = await self.stream_from_s3(f"{model_prefix}/config.json")
         config_data = config_stream.read()
 
@@ -148,7 +146,7 @@ def continue_generation(input_text, model, tokenizer, max_tokens=MAX_TOKENS):
         input_text = tokenizer.decode(tokens[:max_tokens])
         output = model.generate(input_ids=tokenizer.encode(input_text, return_tensors="pt").input_ids)
         generated_text += tokenizer.decode(output[0], skip_special_tokens=True)
-        input_text = input_text[len(input_text):]  # If the input is exhausted, there is nothing left to process
+        input_text = input_text[len(input_text):]
     return generated_text
 
 @app.post("/predict/")
@@ -163,7 +161,7 @@ async def predict(model_request: dict):
 
     streamer = S3DirectStream(S3_BUCKET_NAME)
 
-    await streamer.create_s3_folders(model_name)  # Create the folders if they don't exist
+    await streamer.create_s3_folders(model_name)
 
     model = await streamer.load_model_from_s3(model_name)
     tokenizer = await streamer.load_tokenizer_from_s3(model_name)
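Note on the continue_generation hunk: the commit only strips the comment from input_text = input_text[len(input_text):], but that slice always evaluates to an empty string, so any loop built around it can only ever consume the first max_tokens tokens. The following is a minimal sketch, not the repository's code; chunked_generate is a hypothetical name and the loop structure is assumed. It shows how generation could instead advance through the tokenized input chunk by chunk.

# Hypothetical sketch, not the repository's code: advance by token offset
# rather than slicing the text down to "".
def chunked_generate(input_text, model, tokenizer, max_tokens=1024):
    tokens = tokenizer.encode(input_text)
    generated_text = ""
    for start in range(0, len(tokens), max_tokens):
        # Re-decode each max_tokens-sized window and generate against it.
        chunk_text = tokenizer.decode(tokens[start:start + max_tokens])
        input_ids = tokenizer(chunk_text, return_tensors="pt").input_ids
        output = model.generate(input_ids=input_ids)
        generated_text += tokenizer.decode(output[0], skip_special_tokens=True)
    return generated_text

Called as chunked_generate(prompt, model, tokenizer, MAX_TOKENS), each 1024-token window is generated against independently, which is the behaviour the removed comment seems to have intended.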
 
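The load path above pulls config.json straight from S3 via self.stream_from_s3 before loading the model, but the body of S3DirectStream is not part of this diff. A minimal sketch of what such a helper could look like with boto3, assuming the object is small enough to buffer in memory; the standalone function name and signature are hypothetical.

import io

import boto3
from botocore.exceptions import ClientError
from fastapi import HTTPException

# Hypothetical helper, assumed to resemble S3DirectStream.stream_from_s3:
# fetch one object (e.g. "<model_prefix>/config.json") and return it as an
# in-memory stream whose .read() yields the raw bytes.
def stream_from_s3(bucket_name: str, key: str) -> io.BytesIO:
    s3 = boto3.client("s3")
    try:
        obj = s3.get_object(Bucket=bucket_name, Key=key)
    except ClientError:
        # Surface a 404 in the same spirit as the handler in the diff.
        raise HTTPException(status_code=404, detail=f"{key} not found in S3.")
    return io.BytesIO(obj["Body"].read())

Buffering the whole object keeps the sketch simple; larger artifacts such as model weights would presumably be handled differently by the real class.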