Hjgugugjhuhjggg commited on
Commit
16cb5fc
·
verified ·
1 Parent(s): ca9b40d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -27
app.py CHANGED
@@ -2,17 +2,15 @@ import os
2
  import logging
3
  import threading
4
  import boto3
5
- import torch
6
  from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, StoppingCriteriaList, AutoConfig
7
  from fastapi import FastAPI, HTTPException, Request
8
  from pydantic import BaseModel, field_validator
9
- from io import BytesIO
10
  from huggingface_hub import hf_hub_download
11
  import requests
 
12
  import asyncio
13
- import soundfile as sf
14
- import numpy as np
15
  from fastapi.responses import StreamingResponse, Response
 
16
 
17
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s")
18
 
@@ -67,30 +65,22 @@ class S3ModelLoader:
67
  if "Contents" not in model_files:
68
  raise FileNotFoundError(f"Model files not found in S3 for {model_name}")
69
 
70
- local_dir = f"/tmp/{model_name.replace('/', '-')}"
71
- os.makedirs(local_dir, exist_ok=True)
72
-
73
- for obj in model_files["Contents"]:
74
- file_key = obj["Key"]
75
- if file_key.endswith('/'):
76
- continue
77
-
78
- local_file_path = os.path.join(local_dir, os.path.basename(file_key))
79
- self.s3_client.download_file(self.bucket_name, file_key, local_file_path)
80
-
81
- return local_dir
82
  except Exception as e:
83
  logging.error(f"Error downloading from S3: {e}")
84
  raise HTTPException(status_code=500, detail=f"Error downloading model from S3: {e}")
85
 
86
  async def load_model_and_tokenizer(self, model_name):
87
  try:
88
- model_dir = await asyncio.to_thread(self._download_from_s3, model_name)
89
- config = AutoConfig.from_pretrained(model_dir)
90
- tokenizer = AutoTokenizer.from_pretrained(model_dir, config=config)
91
- model = AutoModelForCausalLM.from_pretrained(model_dir, config=config)
 
92
 
93
- logging.info(f"Model {model_name} loaded from S3 successfully.")
94
  return model, tokenizer
95
  except Exception as e:
96
  logging.exception(f"Error loading model: {e}")
@@ -128,6 +118,7 @@ class S3ModelLoader:
128
  async def startup_event():
129
  model_loader.run_in_background()
130
 
 
131
  s3_client = boto3.client('s3', aws_access_key_id=AWS_ACCESS_KEY_ID, aws_secret_access_key=AWS_SECRET_ACCESS_KEY, region_name=AWS_REGION)
132
  model_loader = S3ModelLoader(S3_BUCKET_NAME, s3_client)
133
 
@@ -204,13 +195,15 @@ async def generate(request: Request, body: GenerateRequest):
204
  video = generator(validated_body.input_text)
205
  return Response(content=video, media_type="video/mp4")
206
  except Exception as e:
207
- raise HTTPException(status_code=500, detail=f"Error in text-to-video generation: {e}")
208
 
209
  else:
210
- raise HTTPException(status_code=400, detail="Unsupported task type")
211
 
212
- except HTTPException as e:
213
- raise e
214
  except Exception as e:
215
- logging.exception(f"An unexpected error occurred: {e}")
216
- raise HTTPException(status_code=500, detail="An unexpected error occurred.")
 
 
 
 
 
2
  import logging
3
  import threading
4
  import boto3
 
5
  from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, StoppingCriteriaList, AutoConfig
6
  from fastapi import FastAPI, HTTPException, Request
7
  from pydantic import BaseModel, field_validator
 
8
  from huggingface_hub import hf_hub_download
9
  import requests
10
+ import time
11
  import asyncio
 
 
12
  from fastapi.responses import StreamingResponse, Response
13
+ import torch
14
 
15
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s")
16
 
 
65
  if "Contents" not in model_files:
66
  raise FileNotFoundError(f"Model files not found in S3 for {model_name}")
67
 
68
+ s3_model_path = f"s3://{self.bucket_name}/lilmeaty_garca/{model_name.replace('/', '-')}"
69
+ logging.info(f"Model {model_name} found on S3 at {s3_model_path}")
70
+ return s3_model_path
 
 
 
 
 
 
 
 
 
71
  except Exception as e:
72
  logging.error(f"Error downloading from S3: {e}")
73
  raise HTTPException(status_code=500, detail=f"Error downloading model from S3: {e}")
74
 
75
  async def load_model_and_tokenizer(self, model_name):
76
  try:
77
+ s3_model_path = await asyncio.to_thread(self._download_from_s3, model_name)
78
+ # Load from S3 directly (no local storage)
79
+ config = AutoConfig.from_pretrained(s3_model_path)
80
+ tokenizer = AutoTokenizer.from_pretrained(s3_model_path, config=config)
81
+ model = AutoModelForCausalLM.from_pretrained(s3_model_path, config=config)
82
 
83
+ logging.info(f"Model {model_name} loaded successfully from S3.")
84
  return model, tokenizer
85
  except Exception as e:
86
  logging.exception(f"Error loading model: {e}")
 
118
  async def startup_event():
119
  model_loader.run_in_background()
120
 
121
+ # Initialize S3 client with boto3
122
  s3_client = boto3.client('s3', aws_access_key_id=AWS_ACCESS_KEY_ID, aws_secret_access_key=AWS_SECRET_ACCESS_KEY, region_name=AWS_REGION)
123
  model_loader = S3ModelLoader(S3_BUCKET_NAME, s3_client)
124
 
 
195
  video = generator(validated_body.input_text)
196
  return Response(content=video, media_type="video/mp4")
197
  except Exception as e:
198
+ raise HTTPException(status_code=500, detail=f"Error generating video: {str(e)}")
199
 
200
  else:
201
+ raise HTTPException(status_code=400, detail="Invalid task type.")
202
 
 
 
203
  except Exception as e:
204
+ logging.error(f"Error processing request: {str(e)}")
205
+ raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
206
+
207
+ if __name__ == "__main__":
208
+ import uvicorn
209
+ uvicorn.run(app, host="0.0.0.0", port=8000)