Hjgugugjhuhjggg committed on
Commit c11082c · verified · 1 Parent(s): 2957fb3

Update app.py

Files changed (1): app.py (+99 -92)
app.py CHANGED
@@ -1,19 +1,17 @@
 import os
 import logging
-import time
+import requests
+import threading
 from io import BytesIO
-from typing import Union
-
-from fastapi import FastAPI, HTTPException, Response, Request, UploadFile, File
+from fastapi import FastAPI, HTTPException, Response, Request
 from fastapi.responses import StreamingResponse
-from pydantic import BaseModel, ValidationError, field_validator
+from pydantic import BaseModel
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
     AutoTokenizer,
     pipeline,
-    GenerationConfig,
-    StoppingCriteriaList
+    GenerationConfig
 )
 import boto3
 from huggingface_hub import hf_hub_download
@@ -21,8 +19,9 @@ import soundfile as sf
 import numpy as np
 import torch
 import uvicorn
+from tqdm import tqdm
 
-logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s")
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 
 AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
 AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
@@ -32,7 +31,7 @@ HUGGINGFACE_HUB_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")
 
 class GenerateRequest(BaseModel):
     model_name: str
-    input_text: str = ""
+    input_text: str
     task_type: str
     temperature: float = 1.0
     max_new_tokens: int = 200
@@ -42,23 +41,6 @@ class GenerateRequest(BaseModel):
     repetition_penalty: float = 1.0
     num_return_sequences: int = 1
     do_sample: bool = True
-    chunk_delay: float = 0.0
-    stop_sequences: list[str] = []
-
-    model_config = {"protected_namespaces": ()}
-
-    @field_validator("model_name")
-    def model_name_cannot_be_empty(cls, v):
-        if not v:
-            raise ValueError("model_name cannot be empty.")
-        return v
-
-    @field_validator("task_type")
-    def task_type_must_be_valid(cls, v):
-        valid_types = ["text-to-text", "text-to-image", "text-to-speech", "text-to-video"]
-        if v not in valid_types:
-            raise ValueError(f"task_type must be one of: {valid_types}")
-        return v
 
 class S3ModelLoader:
     def __init__(self, bucket_name, s3_client):
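
Note: with this change the field validators are gone, so only pydantic's type checks guard the request; an empty model_name or an unknown task_type now reaches the handler. A minimal client-side sketch of the reduced schema (the localhost URL/port and example values are assumptions, not part of the commit):

import requests

# Hypothetical client call; host/port and the example model are assumptions.
payload = {
    "model_name": "gpt2",          # no longer validated for emptiness
    "input_text": "Hello, world",  # now required (the "" default was dropped)
    "task_type": "text-to-text",   # no longer checked against the task list
    "temperature": 0.8,
    "max_new_tokens": 50,
}
resp = requests.post("http://localhost:8000/generate", json=payload)
print(resp.text)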
@@ -67,39 +49,50 @@ class S3ModelLoader:
 
     def _get_s3_uri(self, model_name):
         return f"s3://{self.bucket_name}/{model_name.replace('/', '-')}"
-
-    async def load_model_and_tokenizer(self, model_name):
-        s3_uri = self._get_s3_uri(model_name)
-        try:
-            logging.info(f"Trying to load {model_name} from S3...")
-            config = AutoConfig.from_pretrained(s3_uri)
-            model = AutoModelForCausalLM.from_pretrained(s3_uri, config=config)
-            tokenizer = AutoTokenizer.from_pretrained(s3_uri, config=config)
-
-            if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
-                tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
-
-            logging.info(f"Loaded {model_name} from S3 successfully.")
-            return model, tokenizer
-        except EnvironmentError:
-            logging.info(f"Model {model_name} not found in S3. Downloading...")
-            try:
-                config = AutoConfig.from_pretrained(model_name)
-                tokenizer = AutoTokenizer.from_pretrained(model_name, config=config)
-                model = AutoModelForCausalLM.from_pretrained(model_name, config=config, token=HUGGINGFACE_HUB_TOKEN)
-
-                if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
-                    tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
-
-                logging.info(f"Downloaded {model_name} successfully.")
-                logging.info(f"Saving {model_name} to S3...")
-                model.save_pretrained(s3_uri)
-                tokenizer.save_pretrained(s3_uri)
-                logging.info(f"Saved {model_name} to S3 successfully.")
-                return model, tokenizer
-            except Exception as e:
-                logging.exception(f"Error downloading/uploading model: {e}")
-                raise HTTPException(status_code=500, detail=f"Error loading model: {e}")
+
+    def download_model_from_s3(self, model_name):
+        try:
+            logging.info(f"Trying to load {model_name} from S3...")
+            config = AutoConfig.from_pretrained(f"s3://{self.bucket_name}/{model_name}")
+            model = AutoModelForCausalLM.from_pretrained(f"s3://{self.bucket_name}/{model_name}", config=config)
+            tokenizer = AutoTokenizer.from_pretrained(f"s3://{self.bucket_name}/{model_name}")
+            logging.info(f"Loaded {model_name} from S3 successfully.")
+            return model, tokenizer
+        except Exception as e:
+            logging.error(f"Error loading {model_name} from S3: {e}")
+            return None, None
+
+    async def load_model_and_tokenizer(self, model_name):
+        try:
+            model, tokenizer = self.download_model_from_s3(model_name)
+            if model is None or tokenizer is None:
+                model, tokenizer = await self.download_and_save_model_from_huggingface(model_name)
+            return model, tokenizer
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=f"Error loading model: {e}")
+
+    async def download_and_save_model_from_huggingface(self, model_name):
+        try:
+            logging.info(f"Downloading {model_name} from Hugging Face...")
+            with tqdm(unit="B", unit_scale=True, desc=f"Downloading {model_name}") as t:
+                model = AutoModelForCausalLM.from_pretrained(model_name, token=HUGGINGFACE_HUB_TOKEN, _tqdm=t)
+            tokenizer = AutoTokenizer.from_pretrained(model_name, token=HUGGINGFACE_HUB_TOKEN)
+            logging.info(f"Downloaded {model_name} successfully.")
+            self.upload_model_to_s3(model_name, model, tokenizer)
+            return model, tokenizer
+        except Exception as e:
+            logging.error(f"Error downloading model from Hugging Face: {e}")
+            raise HTTPException(status_code=500, detail=f"Error downloading model from Hugging Face: {e}")
+
+    def upload_model_to_s3(self, model_name, model, tokenizer):
+        try:
+            s3_uri = self._get_s3_uri(model_name)
+            model.save_pretrained(s3_uri)
+            tokenizer.save_pretrained(s3_uri)
+            logging.info(f"Saved {model_name} to S3 successfully.")
+        except Exception as e:
+            logging.error(f"Error saving {model_name} to S3: {e}")
+            raise HTTPException(status_code=500, detail=f"Error saving model to S3: {e}")
 
 app = FastAPI()
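
Note: as far as I can tell, transformers' from_pretrained does not resolve s3:// URIs (it expects a local path or a Hub model ID), so download_model_from_s3 and upload_model_to_s3 would need to stage files on local disk. A sketch of that pattern with boto3, assuming the bucket layout used by _get_s3_uri; the helper name and cache path are illustrative:

import os

def fetch_model_dir_from_s3(s3_client, bucket_name, model_name, cache_root="/tmp/model-cache"):
    # Stage every object under the model's prefix into a local directory,
    # then load it with AutoModelForCausalLM.from_pretrained(local_dir).
    prefix = model_name.replace("/", "-")
    local_dir = os.path.join(cache_root, prefix)
    os.makedirs(local_dir, exist_ok=True)
    paginator = s3_client.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=bucket_name, Prefix=prefix):
        for obj in page.get("Contents", []):
            filename = os.path.basename(obj["Key"])
            if filename:
                s3_client.download_file(bucket_name, obj["Key"], os.path.join(local_dir, filename))
    return local_dir

The same applies in reverse for save_pretrained, which writes to a local directory whose files would then be uploaded object by object. The _tqdm keyword passed to from_pretrained above also does not appear to be a public argument; huggingface_hub shows its own progress bars during download.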
 
@@ -109,49 +102,44 @@ model_loader = S3ModelLoader(S3_BUCKET_NAME, s3_client)
 @app.post("/generate")
 async def generate(request: Request, body: GenerateRequest):
     try:
-        validated_body = GenerateRequest(**body.model_dump())
-        model, tokenizer = await model_loader.load_model_and_tokenizer(validated_body.model_name)
+        model, tokenizer = await model_loader.load_model_and_tokenizer(body.model_name)
         device = "cuda" if torch.cuda.is_available() else "cpu"
         model.to(device)
 
-        if validated_body.task_type == "text-to-text":
+        if body.task_type == "text-to-text":
             generation_config = GenerationConfig(
-                temperature=validated_body.temperature,
-                max_new_tokens=validated_body.max_new_tokens,
-                top_p=validated_body.top_p,
-                top_k=validated_body.top_k,
-                repetition_penalty=validated_body.repetition_penalty,
-                do_sample=validated_body.do_sample,
-                num_return_sequences=validated_body.num_return_sequences
+                temperature=body.temperature,
+                max_new_tokens=body.max_new_tokens,
+                top_p=body.top_p,
+                top_k=body.top_k,
+                repetition_penalty=body.repetition_penalty,
+                do_sample=body.do_sample,
+                num_return_sequences=body.num_return_sequences
             )
 
             async def stream_text():
-                input_text = validated_body.input_text
-                generated_text = ""
+                input_text = body.input_text
                 max_length = model.config.max_position_embeddings
+                generated_text = ""
 
                 while True:
-                    encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
-                    input_length = encoded_input["input_ids"].shape[1]
+                    inputs = tokenizer(input_text, return_tensors="pt").to(device)
+                    input_length = inputs.input_ids.shape[1]
                     remaining_tokens = max_length - input_length
-
-                    if remaining_tokens <= 0:
-                        break
-
-                    generation_config.max_new_tokens = min(remaining_tokens, validated_body.max_new_tokens)
-
-                    stopping_criteria = StoppingCriteriaList(
-                        [lambda _, outputs: tokenizer.decode(outputs[0][-1], skip_special_tokens=True) in validated_body.stop_sequences] if validated_body.stop_sequences else []
-                    )
-
-                    output = model.generate(**encoded_input, generation_config=generation_config, stopping_criteria=stopping_criteria)
+                    if remaining_tokens < body.max_new_tokens:
+                        generation_config.max_new_tokens = remaining_tokens
+                    if remaining_tokens <= 0:
+                        break
+
+                    output = model.generate(**inputs, generation_config=generation_config)
                     chunk = tokenizer.decode(output[0], skip_special_tokens=True)
                     generated_text += chunk
                     yield chunk
-                    time.sleep(validated_body.chunk_delay)
-                    input_text = generated_text
+                    if len(tokenizer.encode(generated_text)) >= max_length:
+                        break
+                    input_text = chunk
 
-            if validated_body.stream:
+            if body.stream:
                 return StreamingResponse(stream_text(), media_type="text/plain")
             else:
                 generated_text = ""
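
Note: the loop above decodes the full output sequence on every pass and feeds chunks back in as the next prompt, so each yielded chunk repeats earlier text. For genuine token-level streaming, transformers provides TextIteratorStreamer; a sketch of how the generator could look (names are illustrative, and it reuses body, tokenizer, model, device, and generation_config from the surrounding code):

from threading import Thread
from transformers import TextIteratorStreamer

async def stream_text():
    inputs = tokenizer(body.input_text, return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # generate() blocks, so run it in a worker thread and drain the streamer here.
    worker = Thread(target=model.generate,
                    kwargs=dict(inputs, generation_config=generation_config, streamer=streamer))
    worker.start()
    for piece in streamer:  # yields newly decoded text as tokens arrive
        yield piece
    worker.join()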
@@ -159,24 +147,24 @@ async def generate(request: Request, body: GenerateRequest):
                 generated_text += chunk
             return {"result": generated_text}
 
-        elif validated_body.task_type == "text-to-image":
+        elif body.task_type == "text-to-image":
             generator = pipeline("text-to-image", model=model, tokenizer=tokenizer, device=device)
-            image = generator(validated_body.input_text)[0]
+            image = generator(body.input_text)[0]
             image_bytes = image.tobytes()
             return Response(content=image_bytes, media_type="image/png")
 
-        elif validated_body.task_type == "text-to-speech":
+        elif body.task_type == "text-to-speech":
             generator = pipeline("text-to-speech", model=model, tokenizer=tokenizer, device=device)
-            audio = generator(validated_body.input_text)
+            audio = generator(body.input_text)
             audio_bytesio = BytesIO()
             sf.write(audio_bytesio, audio["sampling_rate"], np.int16(audio["audio"]))
             audio_bytes = audio_bytesio.getvalue()
             return Response(content=audio_bytes, media_type="audio/wav")
 
-        elif validated_body.task_type == "text-to-video":
+        elif body.task_type == "text-to-video":
             try:
                 generator = pipeline("text-to-video", model=model, tokenizer=tokenizer, device=device)
-                video = generator(validated_body.input_text)
+                video = generator(body.input_text)
                 return Response(content=video, media_type="video/mp4")
             except Exception as e:
                 raise HTTPException(status_code=500, detail=f"Error in text-to-video generation: {e}")
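
Note: two details in these branches are unchanged by the commit but worth flagging. PIL's image.tobytes() returns raw pixel data, not an encoded PNG, and soundfile's signature is write(file, data, samplerate), so the sample rate and samples are passed in swapped order here. A corrected sketch, assuming the generators return a PIL image and a dict with audio/sampling_rate keys as above:

from io import BytesIO
import numpy as np
import soundfile as sf

# PNG: encode via PIL instead of dumping raw pixels.
image_buffer = BytesIO()
image.save(image_buffer, format="PNG")
image_bytes = image_buffer.getvalue()

# WAV: soundfile expects (file, data, samplerate); format is needed for file-like targets.
audio_buffer = BytesIO()
sf.write(audio_buffer, np.asarray(audio["audio"]), audio["sampling_rate"], format="WAV")
audio_bytes = audio_buffer.getvalue()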
@@ -186,12 +174,31 @@ async def generate(request: Request, body: GenerateRequest):
 
     except HTTPException as e:
         raise e
-    except ValidationError as e:
-        raise HTTPException(status_code=422, detail=e.errors())
     except Exception as e:
-        logging.exception(f"An unexpected error occurred: {e}")
-        raise HTTPException(status_code=500, detail="An unexpected error occurred.")
+        raise HTTPException(status_code=500, detail=str(e))
+
+def download_all_models_in_background():
+    models_url = "https://huggingface.co/api/models"
+    try:
+        response = requests.get(models_url)
+        if response.status_code != 200:
+            logging.error("Error fetching the list of models from Hugging Face.")
+            raise HTTPException(status_code=500, detail="Error fetching the list of models.")
+
+        models = response.json()
+        for model in models:
+            model_name = model["id"]
+            model_loader.download_and_save_model_from_huggingface(model_name)
+    except Exception as e:
+        logging.error(f"Error downloading models in the background: {e}")
+        raise HTTPException(status_code=500, detail="Error downloading models in the background.")
+
+def run_in_background():
+    threading.Thread(target=download_all_models_in_background, daemon=True).start()
+
+@app.on_event("startup")
+async def startup_event():
+    run_in_background()
 
 if __name__ == "__main__":
-    uvicorn.run(app, host="0.0.0.0", port=7860)
+    uvicorn.run(app, host="0.0.0.0", port=8000)
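
Note: download_all_models_in_background walks the unpaged /api/models listing, which covers every public model on the Hub, far more than a single bucket or disk should try to mirror on startup, and it calls the async download_and_save_model_from_huggingface from a plain thread without awaiting it, which only creates a coroutine. If the intent is to warm a small cache, a bounded sketch (the limit value is an arbitrary assumption):

import asyncio
import requests

def download_top_models_in_background(limit=10):
    # The Hub API accepts paging parameters; fetch a small, bounded list.
    response = requests.get("https://huggingface.co/api/models", params={"limit": limit})
    response.raise_for_status()
    for model in response.json():
        # The loader method is async in this commit, so drive it explicitly.
        asyncio.run(model_loader.download_and_save_model_from_huggingface(model["id"]))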
 