Hjgugugjhuhjggg committed on
Commit
0c65dc8
verified
1 Parent(s): 54fa818

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +122 -111
app.py CHANGED
@@ -1,36 +1,40 @@
1
  import os
2
  import logging
3
- import requests
4
  import threading
5
  from io import BytesIO
6
- from fastapi import FastAPI, HTTPException, Response
 
 
 
 
 
 
 
7
  from fastapi.responses import StreamingResponse
8
- from pydantic import BaseModel
9
  from transformers import (
10
  AutoConfig,
11
  AutoModelForCausalLM,
12
  AutoTokenizer,
 
13
  GenerationConfig,
14
- pipeline
15
  )
16
- import boto3
17
- import torch
18
  import uvicorn
19
 
20
- # Configuración de logging
21
- logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
22
 
23
- # Variables de entorno
24
  AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
25
  AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
26
  AWS_REGION = os.getenv("AWS_REGION")
27
  S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
28
  HUGGINGFACE_HUB_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")
29
 
30
- # Clase para la petición de generación
31
  class GenerateRequest(BaseModel):
32
  model_name: str
33
- input_text: str
34
  task_type: str
35
  temperature: float = 1.0
36
  max_new_tokens: int = 200
@@ -40,11 +44,24 @@ class GenerateRequest(BaseModel):
40
  repetition_penalty: float = 1.0
41
  num_return_sequences: int = 1
42
  do_sample: bool = True
 
 
 
 
43
 
44
- class Config:
45
- protected_namespaces = ()
 
 
 
 
 
 
 
 
 
 
46
 
47
- # Clase para cargar modelos desde S3
48
  class S3ModelLoader:
49
  def __init__(self, bucket_name, s3_client):
50
  self.bucket_name = bucket_name
@@ -53,105 +70,126 @@ class S3ModelLoader:
53
  def _get_s3_uri(self, model_name):
54
  return f"s3://{self.bucket_name}/{model_name.replace('/', '-')}"
55
 
56
- def download_model_from_s3(self, model_name):
 
57
  try:
58
- config = AutoConfig.from_pretrained(f"s3://{self.bucket_name}/{model_name}")
59
- model = AutoModelForCausalLM.from_pretrained(f"s3://{self.bucket_name}/{model_name}", config=config)
60
- tokenizer = AutoTokenizer.from_pretrained(f"s3://{self.bucket_name}/{model_name}")
 
61
 
62
- # Asegurarse de que el `eos_token_id` y `pad_token_id` est茅n definidos
63
- if model.config.eos_token_id is None:
64
- model.config.eos_token_id = tokenizer.eos_token_id
65
- if model.config.pad_token_id is None:
66
- model.config.pad_token_id = tokenizer.pad_token_id
67
 
68
- return model, tokenizer
69
- except Exception:
70
- return None, None
 
 
 
 
 
 
 
 
 
71
 
72
  async def load_model_and_tokenizer(self, model_name):
73
  try:
74
- model, tokenizer = self.download_model_from_s3(model_name)
75
- if model is None or tokenizer is None:
76
- model, tokenizer = await self.download_and_save_model_from_huggingface(model_name)
 
 
 
77
  return model, tokenizer
78
  except Exception as e:
 
79
  raise HTTPException(status_code=500, detail=f"Error loading model: {e}")
80
 
81
- async def download_and_save_model_from_huggingface(self, model_name):
82
  try:
83
- # Descarga del modelo sin tqdm
84
- model = AutoModelForCausalLM.from_pretrained(model_name, use_auth_token=HUGGINGFACE_HUB_TOKEN)
85
- tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=HUGGINGFACE_HUB_TOKEN)
86
-
87
- # Asegurarse de que el `eos_token_id` y `pad_token_id` est茅n definidos
88
- if model.config.eos_token_id is None:
89
- model.config.eos_token_id = tokenizer.eos_token_id
90
- if model.config.pad_token_id is None:
91
- model.config.pad_token_id = tokenizer.pad_token_id
92
-
93
- self.upload_model_to_s3(model_name, model, tokenizer)
94
- return model, tokenizer
95
  except Exception as e:
96
- raise HTTPException(status_code=500, detail=f"Error downloading model from Hugging Face: {e}")
97
 
98
- def upload_model_to_s3(self, model_name, model, tokenizer):
 
99
  try:
100
- s3_uri = self._get_s3_uri(model_name)
101
- model.save_pretrained(s3_uri)
102
- tokenizer.save_pretrained(s3_uri)
 
 
 
 
 
 
103
  except Exception as e:
104
- raise HTTPException(status_code=500, detail=f"Error saving model to S3: {e}")
 
 
 
 
 
 
 
 
105
 
106
- # Crear la instancia de FastAPI
107
  app = FastAPI()
108
 
109
- # Instanciar model_loader aqu铆
110
  s3_client = boto3.client('s3', aws_access_key_id=AWS_ACCESS_KEY_ID, aws_secret_access_key=AWS_SECRET_ACCESS_KEY, region_name=AWS_REGION)
111
  model_loader = S3ModelLoader(S3_BUCKET_NAME, s3_client)
112
 
113
- # Funci贸n de generaci贸n asincr贸nica
114
  @app.post("/generate")
115
- async def generate(body: GenerateRequest):
116
  try:
117
- model, tokenizer = await model_loader.load_model_and_tokenizer(body.model_name)
 
118
  device = "cuda" if torch.cuda.is_available() else "cpu"
119
  model.to(device)
120
 
121
- if body.task_type == "text-to-text":
122
  generation_config = GenerationConfig(
123
- temperature=body.temperature,
124
- max_new_tokens=body.max_new_tokens,
125
- top_p=body.top_p,
126
- top_k=body.top_k,
127
- repetition_penalty=body.repetition_penalty,
128
- do_sample=body.do_sample,
129
- num_return_sequences=body.num_return_sequences
130
  )
131
 
132
  async def stream_text():
133
- input_text = body.input_text
134
- max_length = model.config.max_position_embeddings
135
  generated_text = ""
 
136
 
137
  while True:
138
- inputs = tokenizer(input_text, return_tensors="pt").to(device)
139
- input_length = inputs.input_ids.shape[1]
140
  remaining_tokens = max_length - input_length
141
- if remaining_tokens < body.max_new_tokens:
142
- generation_config.max_new_tokens = remaining_tokens
143
- if remaining_tokens <= 0:
144
- break
145
 
146
- output = model.generate(**inputs, generation_config=generation_config)
 
 
 
 
 
 
 
 
 
147
  chunk = tokenizer.decode(output[0], skip_special_tokens=True)
148
  generated_text += chunk
149
  yield chunk
150
- if len(tokenizer.encode(generated_text)) >= max_length:
151
- break
152
- input_text = chunk
153
 
154
- if body.stream:
155
  return StreamingResponse(stream_text(), media_type="text/plain")
156
  else:
157
  generated_text = ""
@@ -159,24 +197,24 @@ async def generate(body: GenerateRequest):
159
  generated_text += chunk
160
  return {"result": generated_text}
161
 
162
- elif body.task_type == "text-to-image":
163
  generator = pipeline("text-to-image", model=model, tokenizer=tokenizer, device=device)
164
- image = generator(body.input_text)[0]
165
  image_bytes = image.tobytes()
166
  return Response(content=image_bytes, media_type="image/png")
167
 
168
- elif body.task_type == "text-to-speech":
169
  generator = pipeline("text-to-speech", model=model, tokenizer=tokenizer, device=device)
170
- audio = generator(body.input_text)
171
  audio_bytesio = BytesIO()
172
  sf.write(audio_bytesio, audio["sampling_rate"], np.int16(audio["audio"]))
173
  audio_bytes = audio_bytesio.getvalue()
174
  return Response(content=audio_bytes, media_type="audio/wav")
175
 
176
- elif body.task_type == "text-to-video":
177
  try:
178
  generator = pipeline("text-to-video", model=model, tokenizer=tokenizer, device=device)
179
- video = generator(body.input_text)
180
  return Response(content=video, media_type="video/mp4")
181
  except Exception as e:
182
  raise HTTPException(status_code=500, detail=f"Error in text-to-video generation: {e}")
@@ -186,38 +224,11 @@ async def generate(body: GenerateRequest):
186
 
187
  except HTTPException as e:
188
  raise e
 
 
189
  except Exception as e:
190
- raise HTTPException(status_code=500, detail=str(e))
 
191
 
192
- # Descargar todos los modelos en segundo plano
193
- async def download_all_models_in_background():
194
- models_url = "https://huggingface.co/api/models"
195
- try:
196
- # Se obtiene la lista de modelos
197
- response = requests.get(models_url)
198
- if response.status_code != 200:
199
- raise HTTPException(status_code=500, detail="Error al obtener la lista de modelos.")
200
-
201
- models = response.json()
202
- for model in models:
203
- model_name = model["id"]
204
- # Verifica si ya est谩 en S3 antes de intentar descargarlo
205
- try:
206
- await model_loader.download_and_save_model_from_huggingface(model_name)
207
- except Exception as e:
208
- logging.error(f"Error descargando o guardando el modelo {model_name}: {str(e)}")
209
-
210
- except Exception as e:
211
- logging.error(f"Error al obtener modelos de Hugging Face: {str(e)}")
212
-
213
- # Funci贸n que corre en segundo plano para descargar modelos
214
- def run_in_background():
215
- threading.Thread(target=download_all_models_in_background, daemon=True).start()
216
-
217
- # Si este archivo se ejecuta directamente, inicia el servidor
218
  if __name__ == "__main__":
219
- # Ejecutar la descarga de modelos en segundo plano
220
- run_in_background()
221
-
222
- # Iniciar el servidor FastAPI
223
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  import os
2
  import logging
3
+ import time
4
  import threading
5
  from io import BytesIO
6
+ from typing import Union
7
+ import requests
8
+ import boto3
9
+ import torch
10
+ import safetensors
11
+ import soundfile as sf
12
+ import numpy as np
13
+ from fastapi import FastAPI, HTTPException, Response, Request, UploadFile, File
14
  from fastapi.responses import StreamingResponse
15
+ from pydantic import BaseModel, ValidationError, field_validator
16
  from transformers import (
17
  AutoConfig,
18
  AutoModelForCausalLM,
19
  AutoTokenizer,
20
+ pipeline,
21
  GenerationConfig,
22
+ StoppingCriteriaList
23
  )
24
+ from huggingface_hub import hf_hub_download
 
25
  import uvicorn
26
 
27
# Log format includes filename:lineno so errors raised on background threads
# can be traced back to their source line.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s")

# Deployment configuration, injected via environment variables.
# Each of these is None when unset — nothing validates them here, so a
# missing credential only surfaces later at the first AWS/Hub call.
AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
AWS_REGION = os.getenv("AWS_REGION")
S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
HUGGINGFACE_HUB_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")
34
 
 
35
  class GenerateRequest(BaseModel):
36
  model_name: str
37
+ input_text: str = ""
38
  task_type: str
39
  temperature: float = 1.0
40
  max_new_tokens: int = 200
 
44
  repetition_penalty: float = 1.0
45
  num_return_sequences: int = 1
46
  do_sample: bool = True
47
+ chunk_delay: float = 0.0
48
+ stop_sequences: list[str] = []
49
+
50
+ model_config = {"protected_namespaces": ()}
51
 
52
+ @field_validator("model_name")
53
+ def model_name_cannot_be_empty(cls, v):
54
+ if not v:
55
+ raise ValueError("model_name cannot be empty.")
56
+ return v
57
+
58
+ @field_validator("task_type")
59
+ def task_type_must_be_valid(cls, v):
60
+ valid_types = ["text-to-text", "text-to-image", "text-to-speech", "text-to-video"]
61
+ if v not in valid_types:
62
+ raise ValueError(f"task_type must be one of: {valid_types}")
63
+ return v
64
 
 
65
  class S3ModelLoader:
66
  def __init__(self, bucket_name, s3_client):
67
  self.bucket_name = bucket_name
 
70
  def _get_s3_uri(self, model_name):
71
  return f"s3://{self.bucket_name}/{model_name.replace('/', '-')}"
72
 
73
    def _download_from_s3(self, model_name):
        """Download every object under *model_name* from the bucket into /tmp.

        Plain synchronous method; it performs blocking network I/O and
        returns the local directory path the files were written to.
        Raises HTTPException(500) on any failure, including a missing prefix.
        """
        # NOTE(review): unused — _get_s3_uri flattens '/' to '-', but the
        # list_objects_v2 call below uses the raw model_name as the prefix.
        # Confirm which key convention the bucket actually uses.
        s3_uri = self._get_s3_uri(model_name)
        try:
            logging.info(f"Attempting to load model {model_name} from S3...")
            model_files = self.s3_client.list_objects_v2(Bucket=self.bucket_name, Prefix=model_name)
            # list_objects_v2 omits the "Contents" key entirely when nothing matches.
            if "Contents" not in model_files:
                raise FileNotFoundError(f"Model files not found in S3 for {model_name}")

            local_dir = f"/tmp/{model_name.replace('/', '-')}"
            os.makedirs(local_dir, exist_ok=True)

            for obj in model_files["Contents"]:
                file_key = obj["Key"]
                # Skip "directory" placeholder keys.
                if file_key.endswith('/'):
                    continue

                # NOTE(review): basename() flattens nested keys — two objects with
                # the same filename in different subfolders would overwrite each other.
                local_file_path = os.path.join(local_dir, os.path.basename(file_key))
                self.s3_client.download_file(self.bucket_name, file_key, local_file_path)

            return local_dir
        except Exception as e:
            # The FileNotFoundError raised above is also funneled into this 500.
            logging.error(f"Error downloading from S3: {e}")
            raise HTTPException(status_code=500, detail=f"Error downloading model from S3: {e}")
96
 
97
  async def load_model_and_tokenizer(self, model_name):
98
  try:
99
+ model_dir = await self._download_from_s3(model_name)
100
+ config = AutoConfig.from_pretrained(model_dir)
101
+ tokenizer = AutoTokenizer.from_pretrained(model_dir, config=config)
102
+ model = AutoModelForCausalLM.from_pretrained(model_dir, config=config)
103
+
104
+ logging.info(f"Model {model_name} loaded from S3 successfully.")
105
  return model, tokenizer
106
  except Exception as e:
107
+ logging.exception(f"Error loading model: {e}")
108
  raise HTTPException(status_code=500, detail=f"Error loading model: {e}")
109
 
110
    def download_model_from_huggingface(self, model_name):
        """Best-effort: download *model_name* from the Hugging Face Hub and copy it to S3.

        Errors are logged and swallowed — callers cannot distinguish success
        from failure (the method returns None either way).
        """
        try:
            logging.info(f"Downloading model {model_name} from Hugging Face...")
            # NOTE(review): hf_hub_download fetches a single file and normally
            # requires a `filename` argument; as written this call likely fails
            # for whole repos — snapshot_download may be the intended API. Confirm.
            model_dir = hf_hub_download(model_name, token=HUGGINGFACE_HUB_TOKEN)
            # NOTE(review): uploads under the raw model_name key, while
            # _get_s3_uri uses a '/'-to-'-' flattened key — confirm the
            # convention matches what _download_from_s3 lists by.
            self.s3_client.upload_file(model_dir, self.bucket_name, model_name)
            logging.info(f"Model {model_name} saved to S3 successfully.")
        except Exception as e:
            logging.error(f"Error downloading model {model_name} from Hugging Face: {e}")
118
 
119
    def download_all_models_in_background(self):
        """Mirror the Hugging Face model listing into S3, one model at a time.

        Intended to run on a daemon thread (see run_in_background).
        """
        models_url = "https://huggingface.co/api/models"
        try:
            # NOTE(review): this endpoint is paginated — a single GET returns only
            # the first page, not "all models". Confirm the intended scope.
            response = requests.get(models_url)
            if response.status_code != 200:
                logging.error("Error getting Hugging Face model list.")
                raise HTTPException(status_code=500, detail="Error getting model list.")

            models = response.json()
            for model in models:
                model_name = model["id"]
                self.download_model_from_huggingface(model_name)
        except Exception as e:
            logging.error(f"Error downloading models in the background: {e}")
            # NOTE(review): raising HTTPException on a background thread has no
            # HTTP request to answer — it only terminates the thread.
            raise HTTPException(status_code=500, detail="Error downloading models in the background.")
134
+
135
+ def run_in_background(self):
136
+ threading.Thread(target=self.download_all_models_in_background, daemon=True).start()
137
+
138
    # NOTE(review): `app` and `model_loader` are defined further down the file,
    # so this decorator raises NameError at import time — the handler (with its
    # decorator) should live below `app = FastAPI()` and `model_loader = ...`.
    # It also appears to sit inside the S3ModelLoader class body; confirm the
    # intended scope (indentation was lost in this paste).
    @app.on_event("startup")
    async def startup_event():
        # Begin mirroring models as soon as the server starts.
        model_loader.run_in_background()
141
 
 
142
# FastAPI application instance; any @app.* decorator must appear after this line.
app = FastAPI()

# Shared boto3 S3 client built from the environment credentials read above.
s3_client = boto3.client('s3', aws_access_key_id=AWS_ACCESS_KEY_ID, aws_secret_access_key=AWS_SECRET_ACCESS_KEY, region_name=AWS_REGION)
# Loader the /generate endpoint uses to fetch models from the S3 bucket.
model_loader = S3ModelLoader(S3_BUCKET_NAME, s3_client)
146
 
 
147
  @app.post("/generate")
148
+ async def generate(request: Request, body: GenerateRequest):
149
  try:
150
+ validated_body = GenerateRequest(**body.model_dump())
151
+ model, tokenizer = await model_loader.load_model_and_tokenizer(validated_body.model_name)
152
  device = "cuda" if torch.cuda.is_available() else "cpu"
153
  model.to(device)
154
 
155
+ if validated_body.task_type == "text-to-text":
156
  generation_config = GenerationConfig(
157
+ temperature=validated_body.temperature,
158
+ max_new_tokens=validated_body.max_new_tokens,
159
+ top_p=validated_body.top_p,
160
+ top_k=validated_body.top_k,
161
+ repetition_penalty=validated_body.repetition_penalty,
162
+ do_sample=validated_body.do_sample,
163
+ num_return_sequences=validated_body.num_return_sequences
164
  )
165
 
166
  async def stream_text():
167
+ input_text = validated_body.input_text
 
168
  generated_text = ""
169
+ max_length = model.config.max_position_embeddings
170
 
171
  while True:
172
+ encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
173
+ input_length = encoded_input["input_ids"].shape[1]
174
  remaining_tokens = max_length - input_length
 
 
 
 
175
 
176
+ if remaining_tokens <= 0:
177
+ break
178
+
179
+ generation_config.max_new_tokens = min(remaining_tokens, validated_body.max_new_tokens)
180
+
181
+ stopping_criteria = StoppingCriteriaList(
182
+ [lambda _, outputs: tokenizer.decode(outputs[0][-1], skip_special_tokens=True) in validated_body.stop_sequences] if validated_body.stop_sequences else []
183
+ )
184
+
185
+ output = model.generate(**encoded_input, generation_config=generation_config, stopping_criteria=stopping_criteria)
186
  chunk = tokenizer.decode(output[0], skip_special_tokens=True)
187
  generated_text += chunk
188
  yield chunk
189
+ time.sleep(validated_body.chunk_delay)
190
+ input_text = generated_text
 
191
 
192
+ if validated_body.stream:
193
  return StreamingResponse(stream_text(), media_type="text/plain")
194
  else:
195
  generated_text = ""
 
197
  generated_text += chunk
198
  return {"result": generated_text}
199
 
200
+ elif validated_body.task_type == "text-to-image":
201
  generator = pipeline("text-to-image", model=model, tokenizer=tokenizer, device=device)
202
+ image = generator(validated_body.input_text)[0]
203
  image_bytes = image.tobytes()
204
  return Response(content=image_bytes, media_type="image/png")
205
 
206
+ elif validated_body.task_type == "text-to-speech":
207
  generator = pipeline("text-to-speech", model=model, tokenizer=tokenizer, device=device)
208
+ audio = generator(validated_body.input_text)
209
  audio_bytesio = BytesIO()
210
  sf.write(audio_bytesio, audio["sampling_rate"], np.int16(audio["audio"]))
211
  audio_bytes = audio_bytesio.getvalue()
212
  return Response(content=audio_bytes, media_type="audio/wav")
213
 
214
+ elif validated_body.task_type == "text-to-video":
215
  try:
216
  generator = pipeline("text-to-video", model=model, tokenizer=tokenizer, device=device)
217
+ video = generator(validated_body.input_text)
218
  return Response(content=video, media_type="video/mp4")
219
  except Exception as e:
220
  raise HTTPException(status_code=500, detail=f"Error in text-to-video generation: {e}")
 
224
 
225
  except HTTPException as e:
226
  raise e
227
+ except ValidationError as e:
228
+ raise HTTPException(status_code=422, detail=e.errors())
229
  except Exception as e:
230
+ logging.exception(f"An unexpected error occurred: {e}")
231
+ raise HTTPException(status_code=500, detail="An unexpected error occurred.")
232
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
if __name__ == "__main__":
    # Serve the app on all interfaces, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)