Hjgugugjhuhjggg committed on
Commit
fcc4b80
·
verified ·
1 Parent(s): 0598c12

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -80
app.py CHANGED
@@ -1,24 +1,29 @@
1
  import os
2
  import logging
3
- import threading
4
- import boto3
5
- from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, StoppingCriteriaList, pipeline
6
- from fastapi import FastAPI, HTTPException, Request
7
- from pydantic import BaseModel, field_validator
8
- from huggingface_hub import hf_hub_download
9
- import requests
10
  import time
11
- import asyncio
12
- from fastapi.responses import StreamingResponse, Response
13
- import torch
14
  from io import BytesIO
15
- import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  import soundfile as sf
 
 
 
17
 
18
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s")
19
 
20
- app = FastAPI()
21
-
22
  AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
23
  AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
24
  AWS_REGION = os.getenv("AWS_REGION")
@@ -40,6 +45,8 @@ class GenerateRequest(BaseModel):
40
  chunk_delay: float = 0.0
41
  stop_sequences: list[str] = []
42
 
 
 
43
  @field_validator("model_name")
44
  def model_name_cannot_be_empty(cls, v):
45
  if not v:
@@ -59,66 +66,42 @@ class S3ModelLoader:
59
  self.s3_client = s3_client
60
 
61
  def _get_s3_uri(self, model_name):
62
- return f"s3://{self.bucket_name}/lilmeaty_garca/{model_name.replace('/', '-')}"
63
-
64
def _download_from_s3(self, model_name):
    """Check that *model_name* has objects in S3 and return its ``s3://`` URI.

    Despite the name, nothing is copied locally: the method lists objects
    under ``lilmeaty_garca/<model_name>`` and, if the listing has any
    ``Contents``, returns the URI built by ``_get_s3_uri``.

    Args:
        model_name: Model identifier, e.g. ``"org/model"``.

    Returns:
        The ``s3://`` URI string for the model.

    Raises:
        HTTPException: status 500 on any failure, including the case where
            no objects exist under the prefix.
    """
    # NOTE(review): the listing prefix keeps the raw model name (with "/"),
    # while the returned URI uses the dash-flattened name -- confirm which
    # key layout the bucket really uses.
    prefix = f"lilmeaty_garca/{model_name}"
    try:
        logging.info(f"Attempting to load model {model_name} from S3...")
        listing = self.s3_client.list_objects_v2(Bucket=self.bucket_name, Prefix=prefix)
        if "Contents" not in listing:
            raise FileNotFoundError(f"Model files not found in S3 for {model_name}")
        # Build the URI through the shared helper so the format lives in one place.
        s3_model_path = self._get_s3_uri(model_name)
        logging.info(f"Model {model_name} found on S3 at {s3_model_path}")
        return s3_model_path
    except Exception as e:
        logging.error(f"Error downloading from S3: {e}")
        # Chain the cause so the original traceback is preserved.
        raise HTTPException(status_code=500, detail=f"Error downloading model from S3: {e}") from e
76
-
77
def download_model_from_huggingface(self, model_name):
    """Download all files of *model_name* from the Hugging Face Hub and upload them to S3.

    Each file is uploaded to ``lilmeaty_garca/<model_name>/<relative path>``
    in this loader's bucket.

    Args:
        model_name: Hub repository id, e.g. ``"org/model"``.

    Raises:
        HTTPException: status 500 if the download or any upload fails.
    """
    # Imported here because the module-level import only provides
    # hf_hub_download; huggingface_hub itself is already a dependency.
    from huggingface_hub import snapshot_download

    try:
        logging.info(f"Downloading model {model_name} from Hugging Face...")
        # hf_hub_download fetches ONE named file (its `filename` argument is
        # required) and returns a file path, so it cannot be listed as a
        # directory. snapshot_download returns a local folder with the repo.
        model_dir = snapshot_download(repo_id=model_name, token=HUGGINGFACE_HUB_TOKEN)
        # Walk the tree so files in sub-folders are uploaded too and
        # directories are never handed to upload_file.
        for root, _dirs, file_names in os.walk(model_dir):
            for file_name in file_names:
                local_path = os.path.join(root, file_name)
                relative_key = os.path.relpath(local_path, model_dir).replace(os.sep, "/")
                s3_path = f"lilmeaty_garca/{model_name}/{relative_key}"
                self.s3_client.upload_file(local_path, self.bucket_name, s3_path)
        logging.info(f"Model {model_name} saved to S3 successfully.")
    except Exception as e:
        logging.error(f"Error downloading model {model_name} from Hugging Face: {e}")
        raise HTTPException(status_code=500, detail=f"Error downloading model from Hugging Face: {e}") from e
89
-
90
def download_all_models_in_background(self):
    """Fetch the Hub model list and copy every listed model to S3.

    Requests ``https://huggingface.co/api/models`` and calls
    ``download_model_from_huggingface`` for the ``id`` of each returned
    entry. ``run_in_background`` uses this as its thread target.

    Raises:
        HTTPException: status 500 if the list cannot be fetched or any
            model download fails (the first failure stops the loop).
    """
    models_url = "https://huggingface.co/api/models"
    try:
        # Without a timeout, requests can block forever on a stalled
        # connection and the background thread would hang silently.
        response = requests.get(models_url, timeout=30)
        if response.status_code != 200:
            logging.error("Error getting Hugging Face model list.")
            raise HTTPException(status_code=500, detail="Error getting model list.")
        for model in response.json():
            self.download_model_from_huggingface(model["id"])
    except Exception as e:
        logging.error(f"Error downloading models in the background: {e}")
        raise HTTPException(status_code=500, detail="Error downloading models in the background.") from e
104
-
105
def run_in_background(self):
    """Start ``download_all_models_in_background`` on a daemon thread and return at once."""
    worker = threading.Thread(
        target=self.download_all_models_in_background,
        daemon=True,
    )
    worker.start()
107
-
108
def load_model_and_tokenizer(self, model_name):
    """Load the causal-LM model and tokenizer for *model_name* from its S3 location.

    The location is obtained from ``_download_from_s3`` and passed to the
    ``from_pretrained`` constructors.

    Args:
        model_name: Model identifier, e.g. ``"org/model"``.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        HTTPException: status 500 if the S3 lookup or the loading fails.
    """
    try:
        model_uri = self._download_from_s3(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_uri)
        tokenizer = AutoTokenizer.from_pretrained(model_uri)
        logging.info(f"Model {model_name} loaded successfully from {model_uri}.")
        return model, tokenizer
    except HTTPException:
        # _download_from_s3 already raised a complete HTTP error; wrapping it
        # again would nest the message ("Error loading model ...: 500: ...").
        raise
    except Exception as e:
        logging.error(f"Error loading model {model_name}: {e}")
        raise HTTPException(status_code=500, detail=f"Error loading model {model_name}: {e}") from e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
@app.on_event("startup")
async def startup_event():
    """FastAPI startup hook: start the background model download."""
    loader = model_loader
    loader.run_in_background()
122
 
123
  s3_client = boto3.client('s3', aws_access_key_id=AWS_ACCESS_KEY_ID, aws_secret_access_key=AWS_SECRET_ACCESS_KEY, region_name=AWS_REGION)
124
  model_loader = S3ModelLoader(S3_BUCKET_NAME, s3_client)
@@ -186,21 +169,29 @@ async def generate(request: Request, body: GenerateRequest):
186
  generator = pipeline("text-to-speech", model=model, tokenizer=tokenizer, device=device)
187
  audio = generator(validated_body.input_text)
188
  audio_bytesio = BytesIO()
189
- sf.write(audio_bytesio, audio["samples"], audio["rate"], format="WAV")
190
- audio_bytesio.seek(0)
191
- return StreamingResponse(audio_bytesio, media_type="audio/wav")
192
 
193
  elif validated_body.task_type == "text-to-video":
194
- return {"error": "Text-to-video task type is not yet supported."}
 
 
 
 
 
 
195
  else:
196
- raise HTTPException(status_code=400, detail="Invalid task type")
197
 
 
 
 
 
198
  except Exception as e:
199
- logging.error(f"Error during generation: {e}")
200
- raise HTTPException(status_code=500, detail=f"Internal Server Error: {e}")
201
 
202
- import uvicorn
203
 
204
  if __name__ == "__main__":
205
- uvicorn.run(app, host="0.0.0.0", port=7860)
206
-
 
1
  import os
2
  import logging
 
 
 
 
 
 
 
3
  import time
 
 
 
4
  from io import BytesIO
5
+ from typing import Union
6
+
7
+ from fastapi import FastAPI, HTTPException, Response, Request, UploadFile, File
8
+ from fastapi.responses import StreamingResponse
9
+ from pydantic import BaseModel, ValidationError, field_validator
10
+ from transformers import (
11
+ AutoConfig,
12
+ AutoModelForCausalLM,
13
+ AutoTokenizer,
14
+ pipeline,
15
+ GenerationConfig,
16
+ StoppingCriteriaList
17
+ )
18
+ import boto3
19
+ from huggingface_hub import hf_hub_download
20
  import soundfile as sf
21
+ import numpy as np
22
+ import torch
23
+ import uvicorn
24
 
25
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s")
26
 
 
 
27
  AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
28
  AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
29
  AWS_REGION = os.getenv("AWS_REGION")
 
45
  chunk_delay: float = 0.0
46
  stop_sequences: list[str] = []
47
 
48
+ model_config = {"protected_namespaces": ()}
49
+
50
  @field_validator("model_name")
51
  def model_name_cannot_be_empty(cls, v):
52
  if not v:
 
66
  self.s3_client = s3_client
67
 
68
  def _get_s3_uri(self, model_name):
69
+ return f"s3://{self.bucket_name}/{model_name.replace('/', '-')}"
70
+
71
async def load_model_and_tokenizer(self, model_name):
    """Load the model and tokenizer for *model_name*, trying S3 first, then the Hub.

    The S3 location from ``_get_s3_uri`` is tried first. If that raises
    ``EnvironmentError``, the model is loaded by name with the Hub token and
    then saved back to the S3 location.

    Args:
        model_name: Hub repository id, e.g. ``"org/model"``.

    Returns:
        A ``(model, tokenizer)`` tuple. If the tokenizer has an EOS token
        but no pad token, a pad token id is filled in.

    Raises:
        HTTPException: status 500 if the Hub download or the save fails.
    """

    def ensure_pad_token(tokenizer, config):
        # Only fill in a pad id when the tokenizer has an EOS id but no pad id.
        if tokenizer.eos_token_id is None or tokenizer.pad_token_id is not None:
            return
        configured_pad = config.pad_token_id
        # Compare against None explicitly: 0 is a valid pad token id and
        # must not be replaced by the EOS id.
        tokenizer.pad_token_id = (
            tokenizer.eos_token_id if configured_pad is None else configured_pad
        )

    s3_uri = self._get_s3_uri(model_name)

    # NOTE(review): from_pretrained/save_pretrained are documented to take a
    # Hub repo id or a local path. Confirm an "s3://" URI is resolved in this
    # deployment; if not, this branch always falls through to the Hub and
    # save_pretrained below writes to a local directory, not the bucket.
    try:
        logging.info(f"Trying to load {model_name} from S3...")
        config = AutoConfig.from_pretrained(s3_uri)
        model = AutoModelForCausalLM.from_pretrained(s3_uri, config=config)
        tokenizer = AutoTokenizer.from_pretrained(s3_uri, config=config)
        ensure_pad_token(tokenizer, config)
        logging.info(f"Loaded {model_name} from S3 successfully.")
        return model, tokenizer
    except EnvironmentError:
        logging.info(f"Model {model_name} not found in S3. Downloading...")

    try:
        # Pass the token to every Hub call: gated or private repos reject the
        # config and tokenizer downloads too, not just the weights.
        config = AutoConfig.from_pretrained(model_name, token=HUGGINGFACE_HUB_TOKEN)
        tokenizer = AutoTokenizer.from_pretrained(
            model_name, config=config, token=HUGGINGFACE_HUB_TOKEN
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_name, config=config, token=HUGGINGFACE_HUB_TOKEN
        )
        ensure_pad_token(tokenizer, config)

        logging.info(f"Downloaded {model_name} successfully.")
        logging.info(f"Saving {model_name} to S3...")
        model.save_pretrained(s3_uri)
        tokenizer.save_pretrained(s3_uri)
        logging.info(f"Saved {model_name} to S3 successfully.")
        return model, tokenizer
    except Exception as e:
        logging.exception(f"Error downloading/uploading model: {e}")
        raise HTTPException(status_code=500, detail=f"Error loading model: {e}") from e
103
 
104
+ app = FastAPI()
 
 
105
 
106
  s3_client = boto3.client('s3', aws_access_key_id=AWS_ACCESS_KEY_ID, aws_secret_access_key=AWS_SECRET_ACCESS_KEY, region_name=AWS_REGION)
107
  model_loader = S3ModelLoader(S3_BUCKET_NAME, s3_client)
 
169
  generator = pipeline("text-to-speech", model=model, tokenizer=tokenizer, device=device)
170
  audio = generator(validated_body.input_text)
171
  audio_bytesio = BytesIO()
172
+ sf.write(audio_bytesio, audio["sampling_rate"], np.int16(audio["audio"]))
173
+ audio_bytes = audio_bytesio.getvalue()
174
+ return Response(content=audio_bytes, media_type="audio/wav")
175
 
176
  elif validated_body.task_type == "text-to-video":
177
+ try:
178
+ generator = pipeline("text-to-video", model=model, tokenizer=tokenizer, device=device)
179
+ video = generator(validated_body.input_text)
180
+ return Response(content=video, media_type="video/mp4")
181
+ except Exception as e:
182
+ raise HTTPException(status_code=500, detail=f"Error in text-to-video generation: {e}")
183
+
184
  else:
185
+ raise HTTPException(status_code=400, detail="Unsupported task type")
186
 
187
+ except HTTPException as e:
188
+ raise e
189
+ except ValidationError as e:
190
+ raise HTTPException(status_code=422, detail=e.errors())
191
  except Exception as e:
192
+ logging.exception(f"An unexpected error occurred: {e}")
193
+ raise HTTPException(status_code=500, detail="An unexpected error occurred.")
194
 
 
195
 
196
  if __name__ == "__main__":
197
+ uvicorn.run(app, host="0.0.0.0", port=7860)