Hjgugugjhuhjggg commited on
Commit
1b3e8da
·
verified ·
1 Parent(s): ab9ad2e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -76
app.py CHANGED
@@ -1,19 +1,16 @@
1
  import os
2
  import logging
3
- import time
4
  from io import BytesIO
5
- from typing import Union
6
 
7
- from fastapi import FastAPI, HTTPException, Response, Request, UploadFile, File
8
  from fastapi.responses import StreamingResponse
9
- from pydantic import BaseModel, ValidationError, field_validator
10
  from transformers import (
11
  AutoConfig,
12
  AutoModelForCausalLM,
13
  AutoTokenizer,
14
  pipeline,
15
- GenerationConfig,
16
- StoppingCriteriaList
17
  )
18
  import boto3
19
  from huggingface_hub import hf_hub_download
@@ -21,8 +18,9 @@ import soundfile as sf
21
  import numpy as np
22
  import torch
23
  import uvicorn
 
24
 
25
- logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s")
26
 
27
  AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
28
  AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
@@ -32,7 +30,7 @@ HUGGINGFACE_HUB_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")
32
 
33
  class GenerateRequest(BaseModel):
34
  model_name: str
35
- input_text: str = ""
36
  task_type: str
37
  temperature: float = 1.0
38
  max_new_tokens: int = 200
@@ -42,23 +40,6 @@ class GenerateRequest(BaseModel):
42
  repetition_penalty: float = 1.0
43
  num_return_sequences: int = 1
44
  do_sample: bool = True
45
- chunk_delay: float = 0.0
46
- stop_sequences: list[str] = []
47
-
48
- model_config = {"protected_namespaces": ()}
49
-
50
- @field_validator("model_name")
51
- def model_name_cannot_be_empty(cls, v):
52
- if not v:
53
- raise ValueError("model_name cannot be empty.")
54
- return v
55
-
56
- @field_validator("task_type")
57
- def task_type_must_be_valid(cls, v):
58
- valid_types = ["text-to-text", "text-to-image", "text-to-speech", "text-to-video"]
59
- if v not in valid_types:
60
- raise ValueError(f"task_type must be one of: {valid_types}")
61
- return v
62
 
63
  class S3ModelLoader:
64
  def __init__(self, bucket_name, s3_client):
@@ -74,23 +55,15 @@ class S3ModelLoader:
74
  logging.info(f"Trying to load {model_name} from S3...")
75
  config = AutoConfig.from_pretrained(s3_uri)
76
  model = AutoModelForCausalLM.from_pretrained(s3_uri, config=config)
77
- tokenizer = AutoTokenizer.from_pretrained(s3_uri, config=config)
78
-
79
- if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
80
- tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
81
-
82
  logging.info(f"Loaded {model_name} from S3 successfully.")
83
  return model, tokenizer
84
  except EnvironmentError:
85
  logging.info(f"Model {model_name} not found in S3. Downloading...")
86
  try:
87
- config = AutoConfig.from_pretrained(model_name)
88
- tokenizer = AutoTokenizer.from_pretrained(model_name, config=config)
89
- model = AutoModelForCausalLM.from_pretrained(model_name, config=config, token=HUGGINGFACE_HUB_TOKEN)
90
-
91
- if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
92
- tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
93
-
94
  logging.info(f"Downloaded {model_name} successfully.")
95
  logging.info(f"Saving {model_name} to S3...")
96
  model.save_pretrained(s3_uri)
@@ -98,7 +71,7 @@ class S3ModelLoader:
98
  logging.info(f"Saved {model_name} to S3 successfully.")
99
  return model, tokenizer
100
  except Exception as e:
101
- logging.exception(f"Error downloading/uploading model: {e}")
102
  raise HTTPException(status_code=500, detail=f"Error loading model: {e}")
103
 
104
  app = FastAPI()
@@ -109,49 +82,44 @@ model_loader = S3ModelLoader(S3_BUCKET_NAME, s3_client)
109
  @app.post("/generate")
110
  async def generate(request: Request, body: GenerateRequest):
111
  try:
112
- validated_body = GenerateRequest(**body.model_dump())
113
- model, tokenizer = await model_loader.load_model_and_tokenizer(validated_body.model_name)
114
  device = "cuda" if torch.cuda.is_available() else "cpu"
115
  model.to(device)
116
 
117
- if validated_body.task_type == "text-to-text":
118
  generation_config = GenerationConfig(
119
- temperature=validated_body.temperature,
120
- max_new_tokens=validated_body.max_new_tokens,
121
- top_p=validated_body.top_p,
122
- top_k=validated_body.top_k,
123
- repetition_penalty=validated_body.repetition_penalty,
124
- do_sample=validated_body.do_sample,
125
- num_return_sequences=validated_body.num_return_sequences
126
  )
127
 
128
  async def stream_text():
129
- input_text = validated_body.input_text
130
- generated_text = ""
131
  max_length = model.config.max_position_embeddings
 
132
 
133
  while True:
134
- encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
135
- input_length = encoded_input["input_ids"].shape[1]
136
  remaining_tokens = max_length - input_length
 
 
 
 
137
 
138
- if remaining_tokens <= 0:
139
- break
140
-
141
- generation_config.max_new_tokens = min(remaining_tokens, validated_body.max_new_tokens)
142
-
143
- stopping_criteria = StoppingCriteriaList(
144
- [lambda _, outputs: tokenizer.decode(outputs[0][-1], skip_special_tokens=True) in validated_body.stop_sequences] if validated_body.stop_sequences else []
145
- )
146
-
147
- output = model.generate(**encoded_input, generation_config=generation_config, stopping_criteria=stopping_criteria)
148
  chunk = tokenizer.decode(output[0], skip_special_tokens=True)
149
  generated_text += chunk
150
  yield chunk
151
- time.sleep(validated_body.chunk_delay)
152
- input_text = generated_text
 
153
 
154
- if validated_body.stream:
155
  return StreamingResponse(stream_text(), media_type="text/plain")
156
  else:
157
  generated_text = ""
@@ -159,24 +127,24 @@ async def generate(request: Request, body: GenerateRequest):
159
  generated_text += chunk
160
  return {"result": generated_text}
161
 
162
- elif validated_body.task_type == "text-to-image":
163
  generator = pipeline("text-to-image", model=model, tokenizer=tokenizer, device=device)
164
- image = generator(validated_body.input_text)[0]
165
  image_bytes = image.tobytes()
166
  return Response(content=image_bytes, media_type="image/png")
167
 
168
- elif validated_body.task_type == "text-to-speech":
169
  generator = pipeline("text-to-speech", model=model, tokenizer=tokenizer, device=device)
170
- audio = generator(validated_body.input_text)
171
  audio_bytesio = BytesIO()
172
  sf.write(audio_bytesio, audio["sampling_rate"], np.int16(audio["audio"]))
173
  audio_bytes = audio_bytesio.getvalue()
174
  return Response(content=audio_bytes, media_type="audio/wav")
175
 
176
- elif validated_body.task_type == "text-to-video":
177
  try:
178
  generator = pipeline("text-to-video", model=model, tokenizer=tokenizer, device=device)
179
- video = generator(validated_body.input_text)
180
  return Response(content=video, media_type="video/mp4")
181
  except Exception as e:
182
  raise HTTPException(status_code=500, detail=f"Error in text-to-video generation: {e}")
@@ -186,12 +154,9 @@ async def generate(request: Request, body: GenerateRequest):
186
 
187
  except HTTPException as e:
188
  raise e
189
- except ValidationError as e:
190
- raise HTTPException(status_code=422, detail=e.errors())
191
  except Exception as e:
192
- logging.exception(f"An unexpected error occurred: {e}")
193
- raise HTTPException(status_code=500, detail="An unexpected error occurred.")
194
 
195
 
196
  if __name__ == "__main__":
197
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  import os
2
  import logging
 
3
  from io import BytesIO
 
4
 
5
+ from fastapi import FastAPI, HTTPException, Response, Request
6
  from fastapi.responses import StreamingResponse
7
+ from pydantic import BaseModel
8
  from transformers import (
9
  AutoConfig,
10
  AutoModelForCausalLM,
11
  AutoTokenizer,
12
  pipeline,
13
+ GenerationConfig
 
14
  )
15
  import boto3
16
  from huggingface_hub import hf_hub_download
 
18
  import numpy as np
19
  import torch
20
  import uvicorn
21
+ from tqdm import tqdm
22
 
23
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
24
 
25
  AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
26
  AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
 
30
 
31
  class GenerateRequest(BaseModel):
32
  model_name: str
33
+ input_text: str
34
  task_type: str
35
  temperature: float = 1.0
36
  max_new_tokens: int = 200
 
40
  repetition_penalty: float = 1.0
41
  num_return_sequences: int = 1
42
  do_sample: bool = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  class S3ModelLoader:
45
  def __init__(self, bucket_name, s3_client):
 
55
  logging.info(f"Trying to load {model_name} from S3...")
56
  config = AutoConfig.from_pretrained(s3_uri)
57
  model = AutoModelForCausalLM.from_pretrained(s3_uri, config=config)
58
+ tokenizer = AutoTokenizer.from_pretrained(s3_uri)
 
 
 
 
59
  logging.info(f"Loaded {model_name} from S3 successfully.")
60
  return model, tokenizer
61
  except EnvironmentError:
62
  logging.info(f"Model {model_name} not found in S3. Downloading...")
63
  try:
64
+ with tqdm(unit="B", unit_scale=True, desc=f"Downloading {model_name}") as t:
65
+ model = AutoModelForCausalLM.from_pretrained(model_name, token=HUGGINGFACE_HUB_TOKEN, _tqdm=t)
66
+ tokenizer = AutoTokenizer.from_pretrained(model_name, token=HUGGINGFACE_HUB_TOKEN)
 
 
 
 
67
  logging.info(f"Downloaded {model_name} successfully.")
68
  logging.info(f"Saving {model_name} to S3...")
69
  model.save_pretrained(s3_uri)
 
71
  logging.info(f"Saved {model_name} to S3 successfully.")
72
  return model, tokenizer
73
  except Exception as e:
74
+ logging.error(f"Error downloading/uploading model: {e}")
75
  raise HTTPException(status_code=500, detail=f"Error loading model: {e}")
76
 
77
  app = FastAPI()
 
82
  @app.post("/generate")
83
  async def generate(request: Request, body: GenerateRequest):
84
  try:
85
+ model, tokenizer = await model_loader.load_model_and_tokenizer(body.model_name)
 
86
  device = "cuda" if torch.cuda.is_available() else "cpu"
87
  model.to(device)
88
 
89
+ if body.task_type == "text-to-text":
90
  generation_config = GenerationConfig(
91
+ temperature=body.temperature,
92
+ max_new_tokens=body.max_new_tokens,
93
+ top_p=body.top_p,
94
+ top_k=body.top_k,
95
+ repetition_penalty=body.repetition_penalty,
96
+ do_sample=body.do_sample,
97
+ num_return_sequences=body.num_return_sequences
98
  )
99
 
100
  async def stream_text():
101
+ input_text = body.input_text
 
102
  max_length = model.config.max_position_embeddings
103
+ generated_text = ""
104
 
105
  while True:
106
+ inputs = tokenizer(input_text, return_tensors="pt").to(device)
107
+ input_length = inputs.input_ids.shape[1]
108
  remaining_tokens = max_length - input_length
109
+ if remaining_tokens < body.max_new_tokens:
110
+ generation_config.max_new_tokens = remaining_tokens
111
+ if remaining_tokens <= 0:
112
+ break
113
 
114
+ output = model.generate(**inputs, generation_config=generation_config)
 
 
 
 
 
 
 
 
 
115
  chunk = tokenizer.decode(output[0], skip_special_tokens=True)
116
  generated_text += chunk
117
  yield chunk
118
+ if len(tokenizer.encode(generated_text)) >= max_length:
119
+ break
120
+ input_text = chunk
121
 
122
+ if body.stream:
123
  return StreamingResponse(stream_text(), media_type="text/plain")
124
  else:
125
  generated_text = ""
 
127
  generated_text += chunk
128
  return {"result": generated_text}
129
 
130
+ elif body.task_type == "text-to-image":
131
  generator = pipeline("text-to-image", model=model, tokenizer=tokenizer, device=device)
132
+ image = generator(body.input_text)[0]
133
  image_bytes = image.tobytes()
134
  return Response(content=image_bytes, media_type="image/png")
135
 
136
+ elif body.task_type == "text-to-speech":
137
  generator = pipeline("text-to-speech", model=model, tokenizer=tokenizer, device=device)
138
+ audio = generator(body.input_text)
139
  audio_bytesio = BytesIO()
140
  sf.write(audio_bytesio, audio["sampling_rate"], np.int16(audio["audio"]))
141
  audio_bytes = audio_bytesio.getvalue()
142
  return Response(content=audio_bytes, media_type="audio/wav")
143
 
144
+ elif body.task_type == "text-to-video":
145
  try:
146
  generator = pipeline("text-to-video", model=model, tokenizer=tokenizer, device=device)
147
+ video = generator(body.input_text)
148
  return Response(content=video, media_type="video/mp4")
149
  except Exception as e:
150
  raise HTTPException(status_code=500, detail=f"Error in text-to-video generation: {e}")
 
154
 
155
  except HTTPException as e:
156
  raise e
 
 
157
  except Exception as e:
158
+ raise HTTPException(status_code=500, detail=str(e))
 
159
 
160
 
161
  if __name__ == "__main__":
162
+ uvicorn.run(app, host="0.0.0.0", port=8000)