Hjgugugjhuhjggg committed on
Commit 7c21718 · verified · 1 Parent(s): cc33a18

Create app.py

Files changed (1)
  1. app.py +242 -0
app.py ADDED
@@ -0,0 +1,242 @@
+ import os
+ import torch
+ from fastapi import FastAPI, HTTPException
+ from fastapi.responses import StreamingResponse
+ from pydantic import BaseModel, field_validator
+ from transformers import (
+     AutoConfig,
+     pipeline,
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     GenerationConfig,
+     StoppingCriteriaList
+ )
+ import boto3
+ import uvicorn
+ import asyncio
+ from io import BytesIO
+
+ # Credentials and bucket configuration are read from environment variables.
+ AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
+ AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
+ AWS_REGION = os.getenv("AWS_REGION")
+ S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
+ HUGGINGFACE_HUB_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")
+
+ s3_client = boto3.client('s3', aws_access_key_id=AWS_ACCESS_KEY_ID, aws_secret_access_key=AWS_SECRET_ACCESS_KEY, region_name=AWS_REGION)
+
+ app = FastAPI()
+
+ class GenerateRequest(BaseModel):
+     model_name: str
+     input_text: str = ""
+     task_type: str
+     temperature: float = 1.0
+     max_new_tokens: int = 200
+     stream: bool = True
+     top_p: float = 1.0
+     top_k: int = 50
+     repetition_penalty: float = 1.0
+     num_return_sequences: int = 1
+     do_sample: bool = True
+     chunk_delay: float = 0.0
+     stop_sequences: list[str] = []
+
+     @field_validator("model_name")
+     def model_name_cannot_be_empty(cls, v):
+         if not v:
+             raise ValueError("model_name cannot be empty.")
+         return v
+
+     @field_validator("task_type")
+     def task_type_must_be_valid(cls, v):
+         valid_types = ["text-to-text", "text-to-image", "text-to-speech", "text-to-video"]
+         if v not in valid_types:
+             raise ValueError(f"task_type must be one of: {valid_types}")
+         return v
+
+ class S3ModelLoader:
+     def __init__(self, bucket_name, s3_client):
+         self.bucket_name = bucket_name
+         self.s3_client = s3_client
+
+     def _get_s3_uri(self, model_name):
+         return f"s3://{self.bucket_name}/{model_name.replace('/', '-')}"
+
+     async def load_model_and_tokenizer(self, model_name):
+         # Try the bucket-backed copy first; if it is not available, fall back to
+         # downloading from the Hugging Face Hub and save a copy under the same URI.
+         s3_uri = self._get_s3_uri(model_name)
+         try:
+             config = AutoConfig.from_pretrained(s3_uri, local_files_only=True)
+             model = AutoModelForCausalLM.from_pretrained(s3_uri, config=config, local_files_only=True)
+             tokenizer = AutoTokenizer.from_pretrained(s3_uri, config=config, local_files_only=True)
+
+             if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
+                 tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
+
+             return model, tokenizer
+         except EnvironmentError:
+             try:
+                 config = AutoConfig.from_pretrained(model_name)
+                 tokenizer = AutoTokenizer.from_pretrained(model_name, config=config)
+                 model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
+
+                 if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
+                     tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
+
+                 model.save_pretrained(s3_uri)
+                 tokenizer.save_pretrained(s3_uri)
+                 return model, tokenizer
+             except Exception as e:
+                 raise HTTPException(status_code=500, detail=f"Error loading model: {e}")
+
+ model_loader = S3ModelLoader(S3_BUCKET_NAME, s3_client)
+
+ @app.post("/generate")
+ async def generate(request: GenerateRequest):
+     try:
+         model_name = request.model_name
+         input_text = request.input_text
+         task_type = request.task_type
+         temperature = request.temperature
+         max_new_tokens = request.max_new_tokens
+         stream = request.stream
+         top_p = request.top_p
+         top_k = request.top_k
+         repetition_penalty = request.repetition_penalty
+         num_return_sequences = request.num_return_sequences
+         do_sample = request.do_sample
+         chunk_delay = request.chunk_delay
+         stop_sequences = request.stop_sequences
+
+         model, tokenizer = await model_loader.load_model_and_tokenizer(model_name)
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         model.to(device)
+
+         generation_config = GenerationConfig(
+             temperature=temperature,
+             max_new_tokens=max_new_tokens,
+             top_p=top_p,
+             top_k=top_k,
+             repetition_penalty=repetition_penalty,
+             do_sample=do_sample,
+             num_return_sequences=num_return_sequences,
+         )
+
+         return StreamingResponse(
+             stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay),
+             media_type="text/plain"
+         )
+
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
+
+ async def stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay, max_length=2048):
+     encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
+     input_length = encoded_input["input_ids"].shape[1]
+     remaining_tokens = max_length - input_length
+
+     if remaining_tokens <= 0:
+         yield ""
+         return
+
+     generation_config.max_new_tokens = min(remaining_tokens, generation_config.max_new_tokens)
+
+     def stop_criteria(input_ids, scores):
+         decoded_output = tokenizer.decode(int(input_ids[0][-1]), skip_special_tokens=True)
+         return decoded_output in stop_sequences
+
+     stopping_criteria = StoppingCriteriaList([stop_criteria])
+
+     outputs = model.generate(
+         **encoded_input,
+         do_sample=generation_config.do_sample,
+         max_new_tokens=generation_config.max_new_tokens,
+         temperature=generation_config.temperature,
+         top_p=generation_config.top_p,
+         top_k=generation_config.top_k,
+         repetition_penalty=generation_config.repetition_penalty,
+         num_return_sequences=generation_config.num_return_sequences,
+         stopping_criteria=stopping_criteria,
+         output_scores=True,
+         return_dict_in_generate=True
+     )
+
+     # Decode the generated sequences and emit them token by token.
+     output_text = ""
+     for output in outputs.sequences:
+         for token_id in output:
+             token = tokenizer.decode(token_id, skip_special_tokens=True)
+             output_text += token
+             yield token
+             await asyncio.sleep(chunk_delay)  # simulates the delay between tokens
+
+             # Stop as soon as any of the requested stop sequences appears in the output.
+             if stop_sequences and any(stop in output_text for stop in stop_sequences):
+                 return
+
+ @app.post("/generate-image")
+ async def generate_image(request: GenerateRequest):
+     try:
+         validated_body = request
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+
+         image_generator = pipeline("text-to-image", model=validated_body.model_name, device=device)
+         image = image_generator(validated_body.input_text)[0]
+
+         img_byte_arr = BytesIO()
+         image.save(img_byte_arr, format="PNG")
+         img_byte_arr.seek(0)
+
+         return StreamingResponse(img_byte_arr, media_type="image/png")
+
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
+
+ @app.post("/generate-text-to-speech")
+ async def generate_text_to_speech(request: GenerateRequest):
+     try:
+         validated_body = request
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+
+         audio_generator = pipeline("text-to-speech", model=validated_body.model_name, device=device)
+         audio = audio_generator(validated_body.input_text)[0]
+
+         audio_byte_arr = BytesIO()
+         audio.save(audio_byte_arr)
+         audio_byte_arr.seek(0)
+
+         return StreamingResponse(audio_byte_arr, media_type="audio/wav")
+
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
+
+ @app.post("/generate-video")
+ async def generate_video(request: GenerateRequest):
+     try:
+         validated_body = request
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+
+         video_generator = pipeline("text-to-video", model=validated_body.model_name, device=device)
+         video = video_generator(validated_body.input_text)[0]
+
+         video_byte_arr = BytesIO()
+         video.save(video_byte_arr)
+         video_byte_arr.seek(0)
+
+         return StreamingResponse(video_byte_arr, media_type="video/mp4")
+
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
+
+ if __name__ == "__main__":
+     uvicorn.run(app, host="0.0.0.0", port=7860)
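
For reference, a minimal client sketch for the streaming /generate endpoint added above might look like the following. It assumes the server is running locally on the port configured at the bottom of app.py (7860); the "gpt2" model name and the use of the requests library are illustrative assumptions, not part of this commit.

import requests

# Hypothetical client: stream plain-text tokens from the /generate endpoint.
payload = {
    "model_name": "gpt2",          # example checkpoint, not part of the commit
    "input_text": "Once upon a time",
    "task_type": "text-to-text",
    "max_new_tokens": 50,
}

with requests.post("http://localhost:7860/generate", json=payload, stream=True) as resp:
    resp.raise_for_status()
    # Print chunks as they arrive from the StreamingResponse.
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)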