Hjgugugjhuhjggg committed on
Commit 05818b6 (verified)
Parent(s): 14bbbee

Update app.py

Files changed (1)
  1. app.py +68 -65
app.py CHANGED
@@ -5,27 +5,31 @@ from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
 from transformers import (
     AutoConfig,
-    pipeline,
     AutoModelForCausalLM,
     AutoTokenizer,
     GenerationConfig,
-    StoppingCriteriaList
+    StoppingCriteriaList,
+    pipeline
 )
 import asyncio
 from io import BytesIO
+from botocore.exceptions import NoCredentialsError
+import boto3

-# Global dictionary to store the tokens
+# Global dictionary to store the tokens and model configurations
 token_dict = {}

-# Setup for accessing models on Hugging Face or S3
+# Configuration for accessing models on Hugging Face or S3
 AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
 AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
 AWS_REGION = os.getenv("AWS_REGION")
 S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
 HUGGINGFACE_HUB_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")

+# FastAPI application initialization
 app = FastAPI()

+# Request model for the API
 class GenerateRequest(BaseModel):
     model_name: str
     input_text: str
@@ -42,14 +46,19 @@ class GenerateRequest(BaseModel)
     stop_sequences: list[str] = []

 class S3ModelLoader:
-    def __init__(self, bucket_name, s3_client):
+    def __init__(self, bucket_name, aws_access_key_id=None, aws_secret_access_key=None, aws_region=None):
         self.bucket_name = bucket_name
-        self.s3_client = s3_client
+        self.s3_client = boto3.client(
+            's3',
+            aws_access_key_id=aws_access_key_id,
+            aws_secret_access_key=aws_secret_access_key,
+            region_name=aws_region
+        )

     def _get_s3_uri(self, model_name):
         return f"s3://{self.bucket_name}/{model_name.replace('/', '-')}"

-    async def load_model_and_tokenizer(self, model_name):
+    def load_model_and_tokenizer(self, model_name):
         if model_name in token_dict:
             return token_dict[model_name]

@@ -69,55 +78,14 @@ class S3ModelLoader:
             }

             return token_dict[model_name]
+        except NoCredentialsError:
+            raise HTTPException(status_code=500, detail="AWS credentials not found.")
         except Exception as e:
             raise HTTPException(status_code=500, detail=f"Error loading model: {e}")

-model_loader = S3ModelLoader(S3_BUCKET_NAME, s3_client)
-
-@app.post("/generate")
-async def generate(request: GenerateRequest):
-    try:
-        model_name = request.model_name
-        input_text = request.input_text
-        temperature = request.temperature
-        max_new_tokens = request.max_new_tokens
-        stream = request.stream
-        top_p = request.top_p
-        top_k = request.top_k
-        repetition_penalty = request.repetition_penalty
-        num_return_sequences = request.num_return_sequences
-        do_sample = request.do_sample
-        chunk_delay = request.chunk_delay
-        stop_sequences = request.stop_sequences
-
-        # Load model and tokenizer from S3
-        model_data = await model_loader.load_model_and_tokenizer(model_name)
-        model = model_data["model"]
-        tokenizer = model_data["tokenizer"]
-        pad_token_id = model_data["pad_token_id"]
-        eos_token_id = model_data["eos_token_id"]
-
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-        model.to(device)
-
-        generation_config = GenerationConfig(
-            temperature=temperature,
-            max_new_tokens=max_new_tokens,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-            do_sample=do_sample,
-            num_return_sequences=num_return_sequences,
-        )
-
-        return StreamingResponse(
-            stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay),
-            media_type="text/plain"
-        )
-
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
+model_loader = S3ModelLoader(S3_BUCKET_NAME, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION)

+# Function to stream text, generating one token at a time
 async def stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay, max_length=2048):
     encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
     input_length = encoded_input["input_ids"].shape[1]
@@ -159,20 +127,52 @@ async def stream_text(model, tokenizer, input_text, generation_config, stop_sequ
             yield output_text
             return

-    outputs = model.generate(
-        **encoded_input,
-        do_sample=generation_config.do_sample,
-        max_new_tokens=generation_config.max_new_tokens,
-        temperature=generation_config.temperature,
-        top_p=generation_config.top_p,
-        top_k=generation_config.top_k,
-        repetition_penalty=generation_config.repetition_penalty,
-        num_return_sequences=generation_config.num_return_sequences,
-        stopping_criteria=stopping_criteria,
-        output_scores=True,
-        return_dict_in_generate=True
+# Endpoint for text generation
+@app.post("/generate")
+async def generate(request: GenerateRequest):
+    try:
+        model_name = request.model_name
+        input_text = request.input_text
+        temperature = request.temperature
+        max_new_tokens = request.max_new_tokens
+        stream = request.stream
+        top_p = request.top_p
+        top_k = request.top_k
+        repetition_penalty = request.repetition_penalty
+        num_return_sequences = request.num_return_sequences
+        do_sample = request.do_sample
+        chunk_delay = request.chunk_delay
+        stop_sequences = request.stop_sequences
+
+        # Load the model and tokenizer from S3
+        model_data = model_loader.load_model_and_tokenizer(model_name)
+        model = model_data["model"]
+        tokenizer = model_data["tokenizer"]
+        pad_token_id = model_data["pad_token_id"]
+        eos_token_id = model_data["eos_token_id"]
+
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        model.to(device)
+
+        generation_config = GenerationConfig(
+            temperature=temperature,
+            max_new_tokens=max_new_tokens,
+            top_p=top_p,
+            top_k=top_k,
+            repetition_penalty=repetition_penalty,
+            do_sample=do_sample,
+            num_return_sequences=num_return_sequences,
+        )
+
+        return StreamingResponse(
+            stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay),
+            media_type="text/plain"
        )

+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
+
+# Endpoint for image generation
 @app.post("/generate-image")
 async def generate_image(request: GenerateRequest):
     try:
@@ -191,6 +191,7 @@ async def generate_image(request: GenerateRequest):
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

+# Endpoint for text-to-speech generation
 @app.post("/generate-text-to-speech")
 async def generate_text_to_speech(request: GenerateRequest):
     try:
@@ -209,6 +210,7 @@ async def generate_text_to_speech(request: GenerateRequest):
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

+# Endpoint for video generation
 @app.post("/generate-video")
 async def generate_video(request: GenerateRequest):
     try:
@@ -226,6 +228,7 @@ async def generate_video(request: GenerateRequest):
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

+# Configuration to run the server
 if __name__ == "__main__":
     import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=7860)
+    uvicorn.run(app, host="0.0.0.0", port=7860)
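
For reference, a minimal client sketch for the streaming /generate endpoint in this commit. It assumes the server is running locally on port 7860 as configured above; the model name, prompt, and sampling values are illustrative placeholders, while the field names match the GenerateRequest model in the diff.

import requests

# Hypothetical request body: "gpt2" stands in for any causal LM that
# S3ModelLoader can resolve from the configured bucket.
payload = {
    "model_name": "gpt2",
    "input_text": "Once upon a time",
    "temperature": 0.7,
    "max_new_tokens": 50,
    "stream": True,
    "top_p": 0.95,
    "top_k": 50,
    "repetition_penalty": 1.1,
    "num_return_sequences": 1,
    "do_sample": True,
    "chunk_delay": 0.0,
    "stop_sequences": [],
}

# stream=True keeps the HTTP connection open so chunks can be printed
# as soon as stream_text yields them on the server side.
with requests.post("http://localhost:7860/generate", json=payload, stream=True) as resp:
    resp.raise_for_status()
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)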