Hjgugugjhuhjggg committed
Commit d9e405b · verified · 1 Parent(s): 847979a

Update app.py

Files changed (1)
  1. app.py +64 -34
app.py CHANGED
@@ -2,16 +2,18 @@ import os
 import logging
 import time
 from io import BytesIO
+from typing import Union
 
-from fastapi import FastAPI, HTTPException, Response, Request
+from fastapi import FastAPI, HTTPException, Response, Request, UploadFile, File
 from fastapi.responses import StreamingResponse
-from pydantic import BaseModel
+from pydantic import BaseModel, ValidationError, validator
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
     AutoTokenizer,
     pipeline,
-    GenerationConfig
+    GenerationConfig,
+    StoppingCriteriaList
 )
 import boto3
 from huggingface_hub import hf_hub_download
@@ -20,7 +22,7 @@ import numpy as np
 import torch
 import uvicorn
 
-logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s")
 
 AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
 AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
@@ -30,7 +32,7 @@ HUGGINGFACE_HUB_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")
 
 class GenerateRequest(BaseModel):
     model_name: str
-    input_text: str
+    input_text: str = ""
     task_type: str
     temperature: float = 1.0
     max_new_tokens: int = 200
@@ -41,6 +43,20 @@ class GenerateRequest(BaseModel):
     num_return_sequences: int = 1
     do_sample: bool = True
     chunk_delay: float = 0.0
+    stop_sequences: list[str] = []
+
+    @validator("model_name")
+    def model_name_cannot_be_empty(cls, v):
+        if not v:
+            raise ValueError("model_name cannot be empty.")
+        return v
+
+    @validator("task_type")
+    def task_type_must_be_valid(cls, v):
+        valid_types = ["text-to-text", "text-to-image", "text-to-speech", "text-to-video"]
+        if v not in valid_types:
+            raise ValueError(f"task_type must be one of: {valid_types}")
+        return v
 
 class S3ModelLoader:
     def __init__(self, bucket_name, s3_client):
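Note on the validators added above: `validator` is the pydantic v1-style API (pydantic v2 renames it to `field_validator`), while the `model_dump()` call used later in the handler is v2-only, so this commit mixes the two generations; on pydantic v2 the `@validator` decorator still works but emits a deprecation warning. A minimal, self-contained sketch of how the new `task_type` check rejects a bad request (the model name here is illustrative):

# Minimal sketch of the new request validation, using the same
# pydantic v1-style validator imported in this commit.
from pydantic import BaseModel, ValidationError, validator

class DemoRequest(BaseModel):
    model_name: str
    task_type: str

    @validator("task_type")
    def task_type_must_be_valid(cls, v):
        valid_types = ["text-to-text", "text-to-image", "text-to-speech", "text-to-video"]
        if v not in valid_types:
            raise ValueError(f"task_type must be one of: {valid_types}")
        return v

try:
    DemoRequest(model_name="gpt2", task_type="text-to-music")
except ValidationError as e:
    print(e.errors())  # the same payload the handler returns as a 422 detail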
@@ -57,6 +73,13 @@ class S3ModelLoader:
             config = AutoConfig.from_pretrained(s3_uri)
             model = AutoModelForCausalLM.from_pretrained(s3_uri, config=config)
             tokenizer = AutoTokenizer.from_pretrained(s3_uri)
+
+            if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
+                if config.pad_token_id is not None:
+                    tokenizer.pad_token_id = config.pad_token_id
+                else:
+                    tokenizer.pad_token_id = 0
+
             logging.info(f"Loaded {model_name} from S3 successfully.")
             return model, tokenizer
         except EnvironmentError:
@@ -64,6 +87,14 @@ class S3ModelLoader:
         try:
             model = AutoModelForCausalLM.from_pretrained(model_name, token=HUGGINGFACE_HUB_TOKEN)
             tokenizer = AutoTokenizer.from_pretrained(model_name, token=HUGGINGFACE_HUB_TOKEN)
+
+            if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
+                config = AutoConfig.from_pretrained(model_name)
+                if config.pad_token_id is not None:
+                    tokenizer.pad_token_id = config.pad_token_id
+                else:
+                    tokenizer.pad_token_id = 0
+
             logging.info(f"Downloaded {model_name} successfully.")
             logging.info(f"Saving {model_name} to S3...")
             model.save_pretrained(s3_uri)
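The pad-token fallback added in both load paths can be exercised standalone; gpt2 is a convenient example since it ships an eos token but no pad token. A sketch under that assumption:

# Standalone sketch of the pad-token fallback above. gpt2 is illustrative:
# it defines eos_token_id but leaves pad_token_id unset.
from transformers import AutoConfig, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
config = AutoConfig.from_pretrained("gpt2")

if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
    # Prefer the model config's pad token; fall back to 0 as the commit does.
    tokenizer.pad_token_id = config.pad_token_id if config.pad_token_id is not None else 0

print(tokenizer.pad_token_id)

A common alternative is `tokenizer.pad_token = tokenizer.eos_token`, which avoids hard-coding token id 0 for models whose vocabulary assigns it a real meaning.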
@@ -71,7 +102,7 @@ class S3ModelLoader:
             logging.info(f"Saved {model_name} to S3 successfully.")
             return model, tokenizer
         except Exception as e:
-            logging.error(f"Error downloading/uploading model: {e}")
+            logging.exception(f"Error downloading/uploading model: {e}")
             raise HTTPException(status_code=500, detail=f"Error loading model: {e}")
 
 app = FastAPI()
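Side note on the `logging.error` to `logging.exception` change: `logging.exception` is shorthand for `logging.error(..., exc_info=True)` and attaches the active traceback, so it should be called from inside an `except` block, as it is here. A tiny illustration:

import logging

logging.basicConfig(level=logging.INFO)

try:
    raise RuntimeError("boom")
except Exception:
    # Logs the message at ERROR level plus the traceback of the
    # exception currently being handled.
    logging.exception("Error downloading/uploading model")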
@@ -82,23 +113,27 @@ model_loader = S3ModelLoader(S3_BUCKET_NAME, s3_client)
 @app.post("/generate")
 async def generate(request: Request, body: GenerateRequest):
     try:
-        model, tokenizer = await model_loader.load_model_and_tokenizer(body.model_name)
+        validated_body = GenerateRequest(**body.model_dump())
+        model, tokenizer = await model_loader.load_model_and_tokenizer(validated_body.model_name)
         device = "cuda" if torch.cuda.is_available() else "cpu"
         model.to(device)
 
-        if body.task_type == "text-to-text":
+        if validated_body.task_type == "text-to-text":
             generation_config = GenerationConfig(
-                temperature=body.temperature,
-                max_new_tokens=body.max_new_tokens,
-                top_p=body.top_p,
-                top_k=body.top_k,
-                repetition_penalty=body.repetition_penalty,
-                do_sample=body.do_sample,
-                num_return_sequences=body.num_return_sequences
+                temperature=validated_body.temperature,
+                max_new_tokens=validated_body.max_new_tokens,
+                top_p=validated_body.top_p,
+                top_k=validated_body.top_k,
+                repetition_penalty=validated_body.repetition_penalty,
+                do_sample=validated_body.do_sample,
+                num_return_sequences=validated_body.num_return_sequences,
+                stopping_criteria=StoppingCriteriaList(
+                    [lambda _, outputs: tokenizer.decode(outputs[0][-1]) in validated_body.stop_sequences] if validated_body.stop_sequences else []
+                )
             )
 
             async def stream_text():
-                input_text = body.input_text
+                input_text = validated_body.input_text
                 generated_text = ""
                 max_length = model.config.max_position_embeddings
 
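A caveat on the hunk above: `GenerationConfig` is not documented to accept a `stopping_criteria` argument, and `StoppingCriteria` callables receive `(input_ids, scores)` rather than decoded outputs, so the lambda as written would not fire. A hedged sketch of the intended stop-sequence behavior written against the documented `model.generate()` API (the class and variable names here are illustrative):

# Sketch of a stop-sequence check using the StoppingCriteria API.
import torch
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnSequences(StoppingCriteria):
    def __init__(self, stop_sequences, tokenizer):
        self.stop_sequences = stop_sequences
        self.tokenizer = tokenizer

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Decode what has been generated so far; stop once any sequence appears.
        text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
        return any(s in text for s in self.stop_sequences)

# Intended wiring, instead of stuffing the criteria into GenerationConfig:
# criteria = StoppingCriteriaList([StopOnSequences(body.stop_sequences, tokenizer)])
# output = model.generate(**encoded_input, generation_config=generation_config,
#                         stopping_criteria=criteria)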
@@ -110,16 +145,16 @@ async def generate(request: Request, body: GenerateRequest):
                 if remaining_tokens <= 0:
                     break
 
-                generation_config.max_new_tokens = min(remaining_tokens, body.max_new_tokens)
+                generation_config.max_new_tokens = min(remaining_tokens, validated_body.max_new_tokens)
 
                 output = model.generate(**encoded_input, generation_config=generation_config)
                 chunk = tokenizer.decode(output[0], skip_special_tokens=True)
                 generated_text += chunk
                 yield chunk
-                time.sleep(body.chunk_delay)
+                time.sleep(validated_body.chunk_delay)
                 input_text = generated_text
 
-        if body.stream:
+            if validated_body.stream:
                 return StreamingResponse(stream_text(), media_type="text/plain")
             else:
                 generated_text = ""
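Two notes on the streaming loop: `time.sleep(validated_body.chunk_delay)` blocks the event loop inside an `async` generator (`await asyncio.sleep(...)` is the non-blocking equivalent), and each iteration re-feeds the accumulated `generated_text` as the next prompt. A hypothetical client for this endpoint, assuming a local run on port 7860 (typical for HF Spaces; adjust the URL to your deployment):

# Hypothetical streaming client; URL, port, and model name are assumptions.
import requests

payload = {
    "model_name": "gpt2",
    "input_text": "Once upon a time",
    "task_type": "text-to-text",
    "stream": True,  # the field the handler reads as validated_body.stream
}
with requests.post("http://localhost:7860/generate", json=payload, stream=True) as r:
    r.raise_for_status()
    for chunk in r.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)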
@@ -127,32 +162,24 @@ async def generate(request: Request, body: GenerateRequest):
                 generated_text += chunk
             return {"result": generated_text}
 
-        elif body.task_type == "text-to-image":
+        elif validated_body.task_type == "text-to-image":
             generator = pipeline("text-to-image", model=model, tokenizer=tokenizer, device=device)
-            image = generator(body.input_text)[0]
+            image = generator(validated_body.input_text)[0]
             image_bytes = image.tobytes()
             return Response(content=image_bytes, media_type="image/png")
 
-        elif body.task_type == "text-to-speech":
+        elif validated_body.task_type == "text-to-speech":
             generator = pipeline("text-to-speech", model=model, tokenizer=tokenizer, device=device)
-            audio = generator(body.input_text)
-            audio_bytesio = BytesIO()
-            sf.write(audio_bytesio, audio["sampling_rate"], np.int16(audio["audio"]))
-            audio_bytes = audio_bytesio.getvalue()
-            return Response(content=audio_bytes, media_type="audio/wav")
-
-        elif body.task_type == "text-to-audio":
-            generator = pipeline("text-to-audio", model=model, tokenizer=tokenizer, device=device)
-            audio = generator(body.input_text)
+            audio = generator(validated_body.input_text)
             audio_bytesio = BytesIO()
             sf.write(audio_bytesio, audio["sampling_rate"], np.int16(audio["audio"]))
             audio_bytes = audio_bytesio.getvalue()
             return Response(content=audio_bytes, media_type="audio/wav")
 
-        elif body.task_type == "text-to-video":
+        elif validated_body.task_type == "text-to-video":
             try:
                 generator = pipeline("text-to-video", model=model, tokenizer=tokenizer, device=device)
-                video = generator(body.input_text)
+                video = generator(validated_body.input_text)
                 return Response(content=video, media_type="video/mp4")
             except Exception as e:
                 raise HTTPException(status_code=500, detail=f"Error in text-to-video generation: {e}")
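Worth flagging on the branches above: in recent transformers releases, `pipeline()` recognizes "text-to-audio" (with "text-to-speech" as an alias), but "text-to-image" and "text-to-video" are not transformers pipeline tasks (they live in diffusers), so those branches would fail at pipeline construction. One way to check what the installed version supports:

# Lists the task strings the installed transformers build accepts; anything
# not in this dict fails inside pipeline().
from transformers.pipelines import SUPPORTED_TASKS

print(sorted(SUPPORTED_TASKS.keys()))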
@@ -162,8 +189,11 @@ async def generate(request: Request, body: GenerateRequest):
 
     except HTTPException as e:
         raise e
+    except ValidationError as e:
+        raise HTTPException(status_code=422, detail=e.errors())
     except Exception as e:
-        raise HTTPException(status_code=500, detail=str(e))
+        logging.exception(f"An unexpected error occurred: {e}")
+        raise HTTPException(status_code=500, detail="An unexpected error occurred.")
 
 
 if __name__ == "__main__":
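Finally, on the new error handling: FastAPI already parses and validates the body into `GenerateRequest` before the handler runs, so a malformed payload is rejected with a 422 by FastAPI itself, and the in-handler `GenerateRequest(**body.model_dump())` re-validation (and its `ValidationError` branch) is effectively a second line of defense. A hypothetical request exercising the validator, assuming the same local deployment as above:

# task_type fails the task_type_must_be_valid validator, so the request
# never reaches model loading; expect HTTP 422 with the validator's message.
import requests

r = requests.post(
    "http://localhost:7860/generate",
    json={"model_name": "gpt2", "input_text": "hi", "task_type": "not-a-task"},
)
print(r.status_code, r.json())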