Hjgugugjhuhjggg committed on
Commit eda5cd9 · verified · 1 Parent(s): 56fc8ae

Update app.py

Files changed (1)
  1. app.py +151 -121
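
In short: the update drops the logging and the explicit hf_hub_download mirroring step, turns /generate into a streaming-only text endpoint, and splits image, speech, and video generation into separate /generate-image, /generate-text-to-speech, and /generate-video endpoints, all sharing the same S3-backed model loader.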
app.py CHANGED
@@ -1,29 +1,19 @@
 import os
-import logging
-import time
-from io import BytesIO
-from typing import Union
-
-from fastapi import FastAPI, HTTPException, Response, Request, UploadFile, File
 from fastapi.responses import StreamingResponse
-from pydantic import BaseModel, ValidationError, field_validator
 from transformers import (
-    AutoConfig,
     AutoModelForCausalLM,
     AutoTokenizer,
-    pipeline,
     GenerationConfig,
     StoppingCriteriaList
 )
 import boto3
-from huggingface_hub import hf_hub_download, HfApi
-import soundfile as sf
-import numpy as np
-import torch
 import uvicorn
-import shutil
-
-logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s")
 
 AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
 AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
@@ -31,13 +21,17 @@ AWS_REGION = os.getenv("AWS_REGION")
 S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
 HUGGINGFACE_HUB_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")
 
 class GenerateRequest(BaseModel):
     model_name: str
     input_text: str = ""
     task_type: str
     temperature: float = 1.0
     max_new_tokens: int = 200
-    stream: bool = False
     top_p: float = 1.0
     top_k: int = 50
     repetition_penalty: float = 1.0
@@ -46,8 +40,6 @@ class GenerateRequest(BaseModel):
     chunk_delay: float = 0.0
     stop_sequences: list[str] = []
 
-    model_config = {"protected_namespaces": ()}
-
     @field_validator("model_name")
     def model_name_cannot_be_empty(cls, v):
         if not v:
@@ -65,7 +57,6 @@ class S3ModelLoader:
     def __init__(self, bucket_name, s3_client):
         self.bucket_name = bucket_name
         self.s3_client = s3_client
-        self.api = HfApi()
 
     def _get_s3_uri(self, model_name):
         return f"s3://{self.bucket_name}/{model_name.replace('/', '-')}"
@@ -73,7 +64,6 @@ class S3ModelLoader:
     async def load_model_and_tokenizer(self, model_name):
         s3_uri = self._get_s3_uri(model_name)
         try:
-            logging.info(f"Trying to load {model_name} from S3...")
             config = AutoConfig.from_pretrained(s3_uri, local_files_only=True)
             model = AutoModelForCausalLM.from_pretrained(s3_uri, config=config, local_files_only=True)
             tokenizer = AutoTokenizer.from_pretrained(s3_uri, config=config, local_files_only=True)
@@ -81,130 +71,170 @@ class S3ModelLoader:
             if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
                 tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
 
-            logging.info(f"Loaded {model_name} from S3 successfully.")
             return model, tokenizer
         except EnvironmentError:
-            logging.info(f"Model {model_name} not found in S3. Downloading...")
             try:
-                model_info = self.api.model_info(model_name)
-                files_to_download = [f.filename for f in self.api.list_repo_files(model_name)]
-
-                temp_dir = "temp_model"
-                os.makedirs(temp_dir, exist_ok=True)
-
-                for file_name in files_to_download:
-                    hf_hub_download(repo_id=model_name, filename=file_name, local_dir=temp_dir, token=HUGGINGFACE_HUB_TOKEN)
-
-                config = AutoConfig.from_pretrained(temp_dir)
-                tokenizer = AutoTokenizer.from_pretrained(temp_dir, config=config)
-                model = AutoModelForCausalLM.from_pretrained(temp_dir, config=config)
 
                 if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
                     tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
 
-                logging.info(f"Downloaded {model_name} successfully.")
-                logging.info(f"Saving {model_name} to S3...")
                 model.save_pretrained(s3_uri)
                 tokenizer.save_pretrained(s3_uri)
-                logging.info(f"Saved {model_name} to S3 successfully.")
-
-                shutil.rmtree(temp_dir)
-
                 return model, tokenizer
             except Exception as e:
-                logging.exception(f"Error downloading/uploading model: {e}")
                 raise HTTPException(status_code=500, detail=f"Error loading model: {e}")
 
-app = FastAPI()
-
-s3_client = boto3.client('s3', aws_access_key_id=AWS_ACCESS_KEY_ID, aws_secret_access_key=AWS_SECRET_ACCESS_KEY, region_name=AWS_REGION)
 model_loader = S3ModelLoader(S3_BUCKET_NAME, s3_client)
 
 @app.post("/generate")
-async def generate(request: Request, body: GenerateRequest):
     try:
-        validated_body = GenerateRequest(**body.model_dump())
-        model, tokenizer = await model_loader.load_model_and_tokenizer(validated_body.model_name)
         device = "cuda" if torch.cuda.is_available() else "cpu"
         model.to(device)
 
-        if validated_body.task_type == "text-to-text":
-            generation_config = GenerationConfig(
-                temperature=validated_body.temperature,
-                max_new_tokens=validated_body.max_new_tokens,
-                top_p=validated_body.top_p,
-                top_k=validated_body.top_k,
-                repetition_penalty=validated_body.repetition_penalty,
-                do_sample=validated_body.do_sample,
-                num_return_sequences=validated_body.num_return_sequences,
-            )
-
-            async def stream_text():
-                input_text = validated_body.input_text
-                generated_text = ""
-                max_length = model.config.max_position_embeddings
-
-                while True:
-                    encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
-                    input_length = encoded_input["input_ids"].shape[1]
-                    remaining_tokens = max_length - input_length
-
-                    if remaining_tokens <= 0:
-                        break
-
-                    generation_config.max_new_tokens = min(remaining_tokens, validated_body.max_new_tokens)
-                    stopping_criteria = StoppingCriteriaList(
-                        [lambda _, outputs: tokenizer.decode(outputs[0][-1], skip_special_tokens=True) in validated_body.stop_sequences] if validated_body.stop_sequences else []
-                    )
-
-                    output = model.generate(**encoded_input, generation_config=generation_config, stopping_criteria=stopping_criteria)
-                    chunk = tokenizer.decode(output[0], skip_special_tokens=True)
-                    generated_text += chunk
-                    yield chunk
-                    time.sleep(validated_body.chunk_delay)
-                    input_text = generated_text
-
-            if validated_body.stream:
-                return StreamingResponse(stream_text(), media_type="text/plain")
-            else:
-                generated_text = ""
-                async for chunk in stream_text():
-                    generated_text += chunk
-                return {"result": generated_text}
-
-        elif validated_body.task_type == "text-to-image":
-            generator = pipeline("text-to-image", model=model, tokenizer=tokenizer, device=device)
-            image = generator(validated_body.input_text)[0]
-            image_bytes = image.tobytes()
-            return Response(content=image_bytes, media_type="image/png")
-
-        elif validated_body.task_type == "text-to-speech":
-            generator = pipeline("text-to-speech", model=model, tokenizer=tokenizer, device=device)
-            audio = generator(validated_body.input_text)
-            audio_bytesio = BytesIO()
-            sf.write(audio_bytesio, audio["sampling_rate"], np.int16(audio["audio"]))
-            audio_bytes = audio_bytesio.getvalue()
-            return Response(content=audio_bytes, media_type="audio/wav")
-
-        elif validated_body.task_type == "text-to-video":
-            try:
-                generator = pipeline("text-to-video", model=model, tokenizer=tokenizer, device=device)
-                video = generator(validated_body.input_text)
-                return Response(content=video, media_type="video/mp4")
-            except Exception as e:
-                raise HTTPException(status_code=500, detail=f"Error in text-to-video generation: {e}")
 
-        else:
-            raise HTTPException(status_code=400, detail="Unsupported task type")
 
-    except HTTPException as e:
-        raise e
-    except ValidationError as e:
-        raise HTTPException(status_code=422, detail=e.errors())
     except Exception as e:
-        logging.exception(f"An unexpected error occurred: {e}")
-        raise HTTPException(status_code=500, detail="An unexpected error occurred.")
 
 
 if __name__ == "__main__":
     uvicorn.run(app, host="0.0.0.0", port=7860)
 
app.py (updated version; added lines are marked with +):

 import os
+import torch
+from fastapi import FastAPI, HTTPException
 from fastapi.responses import StreamingResponse
+from pydantic import BaseModel, field_validator
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     GenerationConfig,
     StoppingCriteriaList
 )
 import boto3
 import uvicorn
+import asyncio
+from io import BytesIO
+from transformers import pipeline
 
 AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
 AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
 S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
 HUGGINGFACE_HUB_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")
 
+s3_client = boto3.client('s3', aws_access_key_id=AWS_ACCESS_KEY_ID, aws_secret_access_key=AWS_SECRET_ACCESS_KEY, region_name=AWS_REGION)
+
+app = FastAPI()
+
 class GenerateRequest(BaseModel):
     model_name: str
     input_text: str = ""
     task_type: str
     temperature: float = 1.0
     max_new_tokens: int = 200
+    stream: bool = True
     top_p: float = 1.0
     top_k: int = 50
     repetition_penalty: float = 1.0
     chunk_delay: float = 0.0
     stop_sequences: list[str] = []
 
     @field_validator("model_name")
     def model_name_cannot_be_empty(cls, v):
         if not v:
 
     def __init__(self, bucket_name, s3_client):
         self.bucket_name = bucket_name
         self.s3_client = s3_client
 
     def _get_s3_uri(self, model_name):
         return f"s3://{self.bucket_name}/{model_name.replace('/', '-')}"
 
     async def load_model_and_tokenizer(self, model_name):
         s3_uri = self._get_s3_uri(model_name)
         try:
             config = AutoConfig.from_pretrained(s3_uri, local_files_only=True)
             model = AutoModelForCausalLM.from_pretrained(s3_uri, config=config, local_files_only=True)
             tokenizer = AutoTokenizer.from_pretrained(s3_uri, config=config, local_files_only=True)
 
             if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
                 tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
 
             return model, tokenizer
         except EnvironmentError:
             try:
+                config = AutoConfig.from_pretrained(model_name)
+                tokenizer = AutoTokenizer.from_pretrained(model_name, config=config)
+                model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
 
                 if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
                     tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
 
                 model.save_pretrained(s3_uri)
                 tokenizer.save_pretrained(s3_uri)
                 return model, tokenizer
             except Exception as e:
                 raise HTTPException(status_code=500, detail=f"Error loading model: {e}")
 
 model_loader = S3ModelLoader(S3_BUCKET_NAME, s3_client)
 
 @app.post("/generate")
+async def generate(request: GenerateRequest):
     try:
+        model_name = request.model_name
+        input_text = request.input_text
+        task_type = request.task_type
+        temperature = request.temperature
+        max_new_tokens = request.max_new_tokens
+        stream = request.stream
+        top_p = request.top_p
+        top_k = request.top_k
+        repetition_penalty = request.repetition_penalty
+        num_return_sequences = request.num_return_sequences
+        do_sample = request.do_sample
+        chunk_delay = request.chunk_delay
+        stop_sequences = request.stop_sequences
+
+        model, tokenizer = await model_loader.load_model_and_tokenizer(model_name)
         device = "cuda" if torch.cuda.is_available() else "cpu"
         model.to(device)
 
+        generation_config = GenerationConfig(
+            temperature=temperature,
+            max_new_tokens=max_new_tokens,
+            top_p=top_p,
+            top_k=top_k,
+            repetition_penalty=repetition_penalty,
+            do_sample=do_sample,
+            num_return_sequences=num_return_sequences,
+        )
+
+        return StreamingResponse(
+            stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay),
+            media_type="text/plain"
+        )
 
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
+
+async def stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay, max_length=2048):
+    encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
+    input_length = encoded_input["input_ids"].shape[1]
+    remaining_tokens = max_length - input_length
+
+    if remaining_tokens <= 0:
+        yield ""
+
+    generation_config.max_new_tokens = min(remaining_tokens, generation_config.max_new_tokens)
+
+    def stop_criteria(input_ids, scores):
+        decoded_output = tokenizer.decode(int(input_ids[0][-1]), skip_special_tokens=True)
+        return decoded_output in stop_sequences
+
+    stopping_criteria = StoppingCriteriaList([stop_criteria])
+
+    output_text = ""
+    outputs = model.generate(
+        **encoded_input,
+        do_sample=generation_config.do_sample,
+        max_new_tokens=generation_config.max_new_tokens,
+        temperature=generation_config.temperature,
+        top_p=generation_config.top_p,
+        top_k=generation_config.top_k,
+        repetition_penalty=generation_config.repetition_penalty,
+        num_return_sequences=generation_config.num_return_sequences,
+        stopping_criteria=stopping_criteria,
+        output_scores=True,
+        return_dict_in_generate=True
+    )
+
+    for output in outputs.sequences:
+        for token_id in output:
+            token = tokenizer.decode(token_id, skip_special_tokens=True)
+            yield token
+            await asyncio.sleep(chunk_delay)  # simulates the delay between tokens
+
+    if stop_sequences and any(stop in output_text for stop in stop_sequences):
+        yield output_text
+        return
+
+    outputs = model.generate(
+        **encoded_input,
+        do_sample=generation_config.do_sample,
+        max_new_tokens=generation_config.max_new_tokens,
+        temperature=generation_config.temperature,
+        top_p=generation_config.top_p,
+        top_k=generation_config.top_k,
+        repetition_penalty=generation_config.repetition_penalty,
+        num_return_sequences=generation_config.num_return_sequences,
+        stopping_criteria=stopping_criteria,
+        output_scores=True,
+        return_dict_in_generate=True
+    )
+
+@app.post("/generate-image")
+async def generate_image(request: GenerateRequest):
+    try:
+        validated_body = request
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        image_generator = pipeline("text-to-image", model=validated_body.model_name, device=device)
+        image = image_generator(validated_body.input_text)[0]
+
+        img_byte_arr = BytesIO()
+        image.save(img_byte_arr, format="PNG")
+        img_byte_arr.seek(0)
+
+        return StreamingResponse(img_byte_arr, media_type="image/png")
+
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
+
+@app.post("/generate-text-to-speech")
+async def generate_text_to_speech(request: GenerateRequest):
+    try:
+        validated_body = request
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        audio_generator = pipeline("text-to-speech", model=validated_body.model_name, device=device)
+        audio = audio_generator(validated_body.input_text)[0]
+
+        audio_byte_arr = BytesIO()
+        audio.save(audio_byte_arr)
+        audio_byte_arr.seek(0)
+
+        return StreamingResponse(audio_byte_arr, media_type="audio/wav")
 
     except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
+@app.post("/generate-video")
+async def generate_video(request: GenerateRequest):
+    try:
+        validated_body = request
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        video_generator = pipeline("text-to-video", model=validated_body.model_name, device=device)
+        video = video_generator(validated_body.input_text)[0]
+
+        video_byte_arr = BytesIO()
+        video.save(video_byte_arr)
+        video_byte_arr.seek(0)
+
+        return StreamingResponse(video_byte_arr, media_type="video/mp4")
+
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
 if __name__ == "__main__":
     uvicorn.run(app, host="0.0.0.0", port=7860)
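
For reference, a minimal client sketch for the streaming /generate endpoint defined above. It assumes the app is reachable on localhost:7860 (the port passed to uvicorn.run) and that the requests package is installed; the model_name value is only a placeholder.

# Hypothetical client for the /generate endpoint in app.py above.
# Assumptions: the server runs locally on port 7860, `requests` is installed,
# and "gpt2" stands in for whichever causal LM the Space actually serves.
import requests

payload = {
    "model_name": "gpt2",
    "input_text": "Once upon a time",
    "task_type": "text-to-text",
    "temperature": 0.8,
    "max_new_tokens": 50,
    "stream": True,
}

# /generate returns a text/plain StreamingResponse, so read the body incrementally.
with requests.post("http://localhost:7860/generate", json=payload, stream=True, timeout=300) as resp:
    resp.raise_for_status()
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        if chunk:
            print(chunk, end="", flush=True)
print()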