Hjgugugjhuhjggg committed
Commit d05ede6 · verified · 1 Parent(s): e079cb9

Update app.py

Files changed (1)
  1. app.py +263 -82
app.py CHANGED
@@ -3,15 +3,23 @@ import torch
 from fastapi import FastAPI, HTTPException
 from fastapi.responses import StreamingResponse, JSONResponse
 from pydantic import BaseModel, field_validator
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, StoppingCriteria, StoppingCriteriaList, pipeline
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    GenerationConfig,
+    StoppingCriteria,
+    StoppingCriteriaList,
+    pipeline,
+)
 import boto3
 import uvicorn
 import asyncio
 import json
 from huggingface_hub import login
 from botocore.exceptions import NoCredentialsError
-from functools import cached_property
 import base64
+

 AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
 AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
@@ -19,79 +25,107 @@ AWS_REGION = os.getenv("AWS_REGION")
 S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
 HUGGINGFACE_HUB_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")

+
 if HUGGINGFACE_HUB_TOKEN:
-    login(token=HUGGINGFACE_HUB_TOKEN,add_to_git_credential=False)
+    login(token=HUGGINGFACE_HUB_TOKEN,
+          add_to_git_credential=False)
+
+s3_client = boto3.client('s3', aws_access_key_id=AWS_ACCESS_KEY_ID,
+                         aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
+                         region_name=AWS_REGION)

-s3_client = boto3.client('s3', aws_access_key_id=AWS_ACCESS_KEY_ID,aws_secret_access_key=AWS_SECRET_ACCESS_KEY,region_name=AWS_REGION)
 app = FastAPI()

 class GenerateRequest(BaseModel):
     model_name: str
     input_text: str = ""
     task_type: str
-    temperature: float = 0.01
-    max_new_tokens: int = 20
+    temperature: float = 1.0
+    max_new_tokens: int = 3
     stream: bool = True
     top_p: float = 1.0
-    top_k: int = 1
+    top_k: int = 50
     repetition_penalty: float = 1.0
     num_return_sequences: int = 1
-    do_sample: bool = False
+    do_sample: bool = True
     stop_sequences: list[str] = []
-    quantize: bool = False
-    use_onnx: bool = False
+
     @field_validator("model_name")
     def model_name_cannot_be_empty(cls, v):
         if not v:
             raise ValueError("model_name cannot be empty.")
         return v
+
     @field_validator("task_type")
     def task_type_must_be_valid(cls, v):
-        valid_types = ["text-to-text", "text-to-image","text-to-speech", "text-to-video"]
+        valid_types = ["text-to-text", "text-to-image",
+                       "text-to-speech", "text-to-video"]
         if v not in valid_types:
             raise ValueError(f"task_type must be one of: {valid_types}")
         return v
+
+model_data = {}  # Global dictionary to store model data
+
 class S3ModelLoader:
     def __init__(self, bucket_name, s3_client):
         self.bucket_name = bucket_name
         self.s3_client = s3_client
-        self.model_cache = {}
+
     def _get_s3_uri(self, model_name):
-        return f"s3://{self.bucket_name}/{model_name.replace('/', '-')}"
-    async def _load_model_and_tokenizer(self, model_name, quantize, use_onnx):
+        return f"s3://{self.bucket_name}/" \
+               f"{model_name.replace('/', '-')}"
+
+    async def load_model_and_tokenizer(self, model_name):
+        if model_name in model_data:
+            return model_data[model_name]["model"], model_data[model_name]["tokenizer"]
+
         s3_uri = self._get_s3_uri(model_name)
         try:
-            config = AutoConfig.from_pretrained(s3_uri, local_files_only=False)
-            model = AutoModelForCausalLM.from_pretrained(s3_uri, config=config, local_files_only=False).to(self.device)
-            tokenizer = AutoTokenizer.from_pretrained(s3_uri, config=config, local_files_only=False)
-            if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
-                tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
+
+            config = AutoConfig.from_pretrained(
+                s3_uri, local_files_only=False
+            )
+
+            model = AutoModelForCausalLM.from_pretrained(
+                s3_uri, config=config, local_files_only=False
+            )
+
+            tokenizer = AutoTokenizer.from_pretrained(
+                s3_uri, config=config, local_files_only=False
+            )
+
+            if tokenizer.eos_token_id is not None and \
+                    tokenizer.pad_token_id is None:
+                tokenizer.pad_token_id = config.pad_token_id \
+                    or tokenizer.eos_token_id
+            model_data[model_name] = {"model":model, "tokenizer":tokenizer}
             return model, tokenizer
         except (EnvironmentError, NoCredentialsError):
             try:
-                config = AutoConfig.from_pretrained(model_name, token=HUGGINGFACE_HUB_TOKEN)
-                tokenizer = AutoTokenizer.from_pretrained(model_name, config=config, token=HUGGINGFACE_HUB_TOKEN)
-                model = AutoModelForCausalLM.from_pretrained(model_name, config=config, token=HUGGINGFACE_HUB_TOKEN).to(self.device)
-                if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
-                    tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
+                config = AutoConfig.from_pretrained(
+                    model_name, token=HUGGINGFACE_HUB_TOKEN
+                )
+                tokenizer = AutoTokenizer.from_pretrained(
+                    model_name, config=config, token=HUGGINGFACE_HUB_TOKEN
+                )
+
+                model = AutoModelForCausalLM.from_pretrained(
+                    model_name, config=config, token=HUGGINGFACE_HUB_TOKEN
+                )
+
+
+                if tokenizer.eos_token_id is not None and \
+                        tokenizer.pad_token_id is None:
+                    tokenizer.pad_token_id = config.pad_token_id \
+                        or tokenizer.eos_token_id
+
+                model_data[model_name] = {"model":model, "tokenizer":tokenizer}
                 return model, tokenizer
             except Exception as e:
-                raise HTTPException(status_code=500, detail=f"Error loading model: {e}")
-    @cached_property
-    def device(self):
-        return torch.device("cpu")
-    async def get_model_and_tokenizer(self, model_name, quantize, use_onnx):
-        key = f"{model_name}-{quantize}-{use_onnx}"
-        if key not in self.model_cache:
-            model, tokenizer = await self._load_model_and_tokenizer(model_name, quantize, use_onnx)
-            self.model_cache[key] = {"model":model, "tokenizer":tokenizer}
-        return self.model_cache[key]["model"], self.model_cache[key]["tokenizer"]
-    async def get_pipeline(self, model_name, task_type):
-        if model_name not in self.model_cache:
-            config = AutoConfig.from_pretrained(model_name, token=HUGGINGFACE_HUB_TOKEN)
-            model = pipeline(task_type, model=model_name,device=self.device, config=config)
-            self.model_cache[model_name] = {"model":model}
-        return self.model_cache[model_name]["model"]
+                raise HTTPException(
+                    status_code=500, detail=f"Error loading model: {e}"
+                )
+
 model_loader = S3ModelLoader(S3_BUCKET_NAME, s3_client)

 @app.post("/generate")
@@ -109,96 +143,243 @@ async def generate(request: GenerateRequest):
         num_return_sequences = request.num_return_sequences
         do_sample = request.do_sample
         stop_sequences = request.stop_sequences
-        quantize = request.quantize
-        use_onnx = request.use_onnx
-        model, tokenizer = await model_loader.get_model_and_tokenizer(model_name, quantize, use_onnx)
+
+        model, tokenizer = await model_loader.load_model_and_tokenizer(model_name)
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        model.to(device)
+
         if "text-to-text" == task_type:
-            generation_config = GenerationConfig(temperature=temperature,max_new_tokens=max_new_tokens,top_p=top_p,top_k=top_k,repetition_penalty=repetition_penalty,do_sample=do_sample,num_return_sequences=num_return_sequences,eos_token_id = tokenizer.eos_token_id)
+            generation_config = GenerationConfig(
+                temperature=temperature,
+                max_new_tokens=max_new_tokens,
+                top_p=top_p,
+                top_k=top_k,
+                repetition_penalty=repetition_penalty,
+                do_sample=do_sample,
+                num_return_sequences=num_return_sequences,
+                eos_token_id = tokenizer.eos_token_id
+            )
             if stream:
-                return StreamingResponse(stream_text(model, tokenizer, input_text,generation_config, stop_sequences),media_type="text/plain")
+                return StreamingResponse(
+                    stream_text(model, tokenizer, input_text,
+                                generation_config, stop_sequences,
+                                device),
+                    media_type="text/plain"
+                )
             else:
-                result = await generate_text(model, tokenizer, input_text,generation_config, stop_sequences)
+                result = await generate_text(model, tokenizer, input_text,
+                                             generation_config, stop_sequences,
+                                             device)
                 return JSONResponse({"text": result, "is_end": True})
         else:
             return HTTPException(status_code=400, detail="Task type not text-to-text")
+
     except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
+        raise HTTPException(
+            status_code=500, detail=f"Internal server error: {str(e)}"
+        )
+
 class StopOnSequences(StoppingCriteria):
     def __init__(self, stop_sequences, tokenizer):
         self.stop_sequences = stop_sequences
         self.tokenizer = tokenizer
         self.stop_ids = [tokenizer.encode(seq, add_special_tokens=False) for seq in stop_sequences]
+
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+
         decoded_text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
+
         for stop_sequence in self.stop_sequences:
             if stop_sequence in decoded_text:
                 return True
         return False
-async def stream_text(model, tokenizer, input_text,generation_config, stop_sequences):
-    encoded_input = tokenizer(input_text, return_tensors="pt",truncation=True).to(model_loader.device)
+
+async def stream_text(model, tokenizer, input_text,
+                      generation_config, stop_sequences,
+                      device):
+
+    encoded_input = tokenizer(
+        input_text, return_tensors="pt",
+        truncation=True
+    ).to(device)
+
     stop_criteria = StopOnSequences(stop_sequences, tokenizer)
     stopping_criteria = StoppingCriteriaList([stop_criteria])
-    async for token in _stream_text(model, encoded_input, tokenizer, generation_config, stop_criteria, stopping_criteria):
-        yield json.dumps({"text":token, "is_end": False}) + "\n"
-    yield json.dumps({"text":"", "is_end": True}) + "\n"
-async def _stream_text(model, encoded_input, tokenizer, generation_config, stop_criteria, stopping_criteria):
+
     output_text = ""
+
     while True:
-        outputs = await asyncio.to_thread(model.generate,**encoded_input,do_sample=generation_config.do_sample,max_new_tokens=generation_config.max_new_tokens,temperature=generation_config.temperature,top_p=generation_config.top_p,top_k=generation_config.top_k,repetition_penalty=generation_config.repetition_penalty,num_return_sequences=generation_config.num_return_sequences,output_scores=True,return_dict_in_generate=True,stopping_criteria=stopping_criteria)
-        new_text = tokenizer.decode(outputs.sequences[0][len(encoded_input["input_ids"][0]):],skip_special_tokens=True)
-        if len(new_text) == 0:
-            if not stop_criteria(outputs.sequences, None):
-                for token in output_text.split():
-                    yield token
-            break
-        output_text += new_text
-        for token in new_text.split():
-            yield token
-        if stop_criteria(outputs.sequences, None):
-            break
-        encoded_input = tokenizer(output_text, return_tensors="pt",truncation=True).to(model_loader.device)
-        output_text=""
-async def generate_text(model, tokenizer, input_text,generation_config, stop_sequences):
-    encoded_input = tokenizer(input_text, return_tensors="pt",truncation=True).to(model_loader.device)
+
+        outputs = await asyncio.to_thread(model.generate,
+                                          **encoded_input,
+                                          do_sample=generation_config.do_sample,
+                                          max_new_tokens=generation_config.max_new_tokens,
+                                          temperature=generation_config.temperature,
+                                          top_p=generation_config.top_p,
+                                          top_k=generation_config.top_k,
+                                          repetition_penalty=generation_config.repetition_penalty,
+                                          num_return_sequences=generation_config.num_return_sequences,
+                                          output_scores=True,
+                                          return_dict_in_generate=True,
+                                          stopping_criteria=stopping_criteria
+                                          )
+
+        new_text = tokenizer.decode(
+            outputs.sequences[0][len(encoded_input["input_ids"][0]):],
+            skip_special_tokens=True
+        )
+
+        if len(new_text) == 0:
+            if not stop_criteria(outputs.sequences, None):
+                for text in output_text.split():
+                    yield json.dumps({"text": text, "is_end": False}) + "\n"
+            yield json.dumps({"text": "", "is_end": True}) + "\n"
+            break
+
+        output_text += new_text
+
+        for text in new_text.split():
+            yield json.dumps({"text": text, "is_end": False}) + "\n"
+
+        if stop_criteria(outputs.sequences, None):
+            yield json.dumps({"text": "", "is_end": True}) + "\n"
+            break
+
+        encoded_input = tokenizer(
+            output_text, return_tensors="pt",
+            truncation=True
+        ).to(device)
+        output_text = ""
+
+
+async def generate_text(model, tokenizer, input_text,
+                        generation_config, stop_sequences,
+                        device):
+    encoded_input = tokenizer(
+        input_text, return_tensors="pt",
+        truncation=True
+    ).to(device)
+
     stop_criteria = StopOnSequences(stop_sequences, tokenizer)
     stopping_criteria = StoppingCriteriaList([stop_criteria])
-    outputs = await asyncio.to_thread(model.generate,**encoded_input,do_sample=generation_config.do_sample,max_new_tokens=generation_config.max_new_tokens,temperature=generation_config.temperature,top_p=generation_config.top_p,top_k=generation_config.top_k,repetition_penalty=generation_config.repetition_penalty,num_return_sequences=num_return_sequences,output_scores=True,return_dict_in_generate=True,stopping_criteria=stopping_criteria)
-    generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
+
+    outputs = await asyncio.to_thread(model.generate,
+                                      **encoded_input,
+                                      do_sample=generation_config.do_sample,
+                                      max_new_tokens=generation_config.max_new_tokens,
+                                      temperature=generation_config.temperature,
+                                      top_p=generation_config.top_p,
+                                      top_k=generation_config.top_k,
+                                      repetition_penalty=generation_config.repetition_penalty,
+                                      num_return_sequences=generation_config.num_return_sequences,
+                                      output_scores=True,
+                                      return_dict_in_generate=True,
+                                      stopping_criteria=stopping_criteria
+                                      )
+
+
+    generated_text = tokenizer.decode(
+        outputs.sequences[0], skip_special_tokens=True
+    )
+
     return generated_text
+
 @app.post("/generate-image")
 async def generate_image(request: GenerateRequest):
     try:
         validated_body = request
-        model = await model_loader.get_pipeline(validated_body.model_name, "text-to-image")
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        if validated_body.model_name not in model_data:
+            config = AutoConfig.from_pretrained(
+                validated_body.model_name, token=HUGGINGFACE_HUB_TOKEN
+            )
+            model = pipeline(
+                "text-to-image", model=validated_body.model_name,
+                device=device, config=config
+            )
+            model_data[validated_body.model_name] = {"model":model}
+        else:
+            model = model_data[validated_body.model_name]["model"]
+
         image = model(validated_body.input_text)[0]
+
         image_data = list(image.getdata())
+
         return json.dumps({"image_data": image_data, "is_end": True})
+
     except Exception as e:
-        raise HTTPException(status_code=500,detail=f"Internal server error: {str(e)}")
+        raise HTTPException(
+            status_code=500,
+            detail=f"Internal server error: {str(e)}"
+        )
+
+
 @app.post("/generate-text-to-speech")
 async def generate_text_to_speech(request: GenerateRequest):
     try:
         validated_body = request
-        audio_generator = await model_loader.get_pipeline(validated_body.model_name, "text-to-speech")
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        if validated_body.model_name not in model_data:
+            config = AutoConfig.from_pretrained(
+                validated_body.model_name, token=HUGGINGFACE_HUB_TOKEN
+            )
+
+            audio_generator = pipeline(
+                "text-to-speech", model=validated_body.model_name,
+                device=device, config=config
+            )
+            model_data[validated_body.model_name] = {"model":audio_generator}
+        else:
+            audio_generator = model_data[validated_body.model_name]["model"]
+
         audio = audio_generator(validated_body.input_text)
+
+
         audio_bytes = audio["audio"]
+
         audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
+
         return json.dumps({"audio": audio_base64, "is_end": True})
+
     except Exception as e:
-        raise HTTPException(status_code=500,detail=f"Internal server error: {str(e)}")
+        raise HTTPException(
+            status_code=500,
+            detail=f"Internal server error: {str(e)}"
+        )
+
+
 @app.post("/generate-video")
 async def generate_video(request: GenerateRequest):
     try:
         validated_body = request
-        video_generator = await model_loader.get_pipeline(validated_body.model_name, "text-to-video")
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        if validated_body.model_name not in model_data:
+            config = AutoConfig.from_pretrained(
+                validated_body.model_name, token=HUGGINGFACE_HUB_TOKEN
+            )
+
+            video_generator = pipeline(
+                "text-to-video", model=validated_body.model_name,
+                device=device, config=config
+            )
+            model_data[validated_body.model_name] = {"model":video_generator}
+        else:
+            video_generator = model_data[validated_body.model_name]["model"]
+
         video = video_generator(validated_body.input_text)
+
+
         video_base64 = base64.b64encode(video).decode('utf-8')
+
         return json.dumps({"video": video_base64, "is_end": True})
+
     except Exception as e:
-        raise HTTPException(status_code=500,detail=f"Internal server error: {str(e)}")
-async def load_all_models():
-    pass
+        raise HTTPException(
+            status_code=500,
+            detail=f"Internal server error: {str(e)}"
+        )
+
 if __name__ == "__main__":
-    import asyncio
-    asyncio.run(load_all_models())
     uvicorn.run(app, host="0.0.0.0", port=7860)
 
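For reference, a minimal client sketch for the streaming /generate endpoint added above. This is an illustration, not part of the commit: it assumes the server is reachable at localhost:7860 (matching the uvicorn.run call) and uses the requests library; "gpt2" is only a placeholder model name.

import json
import requests

payload = {
    "model_name": "gpt2",          # placeholder model, not mandated by the app
    "input_text": "Hello, world",
    "task_type": "text-to-text",
    "stream": True,
}

# The endpoint streams newline-delimited JSON objects of the form
# {"text": "...", "is_end": false}, terminated by {"text": "", "is_end": true}.
with requests.post("http://localhost:7860/generate", json=payload, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines():
        if not line:
            continue
        chunk = json.loads(line)
        if chunk["is_end"]:
            break
        print(chunk["text"], end=" ", flush=True)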
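The StopOnSequences class kept by this commit follows the stock transformers StoppingCriteria pattern: decode everything generated so far and stop as soon as any stop string appears. A self-contained sketch of the same pattern, using gpt2 purely as an illustrative model:

import torch
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          StoppingCriteria, StoppingCriteriaList)

class StopOnSubstrings(StoppingCriteria):
    # Stop generation once any of the given substrings appears in the
    # decoded output. Re-decoding each step is simple but quadratic;
    # fine for short generations like the ones above.
    def __init__(self, stops, tokenizer):
        self.stops = stops
        self.tokenizer = tokenizer

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
        return any(stop in text for stop in self.stops)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The quick brown fox", return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=30,
    pad_token_id=tokenizer.eos_token_id,  # gpt2 ships without a pad token
    stopping_criteria=StoppingCriteriaList([StopOnSubstrings(["."], tokenizer)]),
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))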
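Likewise, the eos_token_id/pad_token_id fallback in load_model_and_tokenizer is the usual idiom for tokenizers that define an EOS token but no PAD token; in isolation (gpt2 again only as an example):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # defines eos_token but no pad_token
if tok.pad_token_id is None and tok.eos_token_id is not None:
    tok.pad_token_id = tok.eos_token_id  # reuse EOS as PAD so generate() can pad
print(tok.pad_token_id == tok.eos_token_id)  # True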