Hjgugugjhuhjggg committed
Commit 42861e8 · verified · 1 Parent(s): 685ddd1

Update app.py

Files changed (1): app.py (+27 -20)
app.py CHANGED
@@ -1,6 +1,6 @@
 from huggingface_hub import HfApi
 from fastapi import FastAPI, HTTPException
-from pydantic import BaseModel
+from pydantic import BaseModel, field_validator
 import requests
 import boto3
 from dotenv import load_dotenv
@@ -35,11 +35,17 @@ s3_client = boto3.client(
 app = FastAPI()
 
 class DownloadModelRequest(BaseModel):
-    model_name: str
+    model_id: str
     pipeline_task: str
     input_text: str
     revision: str = "main"
 
+    @field_validator('model_id')
+    def validate_model_id(cls, value):
+        if not value:
+            raise ValueError("model_id cannot be empty")
+        return value
+
 class S3DirectStream:
     def __init__(self, bucket_name):
         self.s3_client = boto3.client(
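
For reference, a minimal check of the new validator (hypothetical usage; assumes Pydantic v2, where field_validator lives):

from pydantic import ValidationError

# "gpt2" is only an example model id, not something this app pins.
req = DownloadModelRequest(model_id="gpt2", pipeline_task="text-generation", input_text="Hello")
assert req.revision == "main"  # the default declared on the model

# An empty model_id is rejected by validate_model_id.
try:
    DownloadModelRequest(model_id="", pipeline_task="text-generation", input_text="Hello")
except ValidationError as err:
    print(err)
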
@@ -89,7 +95,7 @@ class S3DirectStream:
         logger.info(f"Loading config for {model_prefix} from S3...")
         config_stream = self.stream_from_s3(f"{model_prefix}/config.json")
         config_dict = json.load(config_stream)
-        config = AutoConfig.from_dict(config_dict)
+        config = AutoConfig.from_pretrained(model_prefix, **config_dict)
         logger.info(f"Config loaded for {model_prefix}.")
 
         if self.file_exists_in_s3(f"{model_prefix}/model.safetensors"):
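
Note: AutoConfig.from_dict (the removed line) is not a public transformers API, which is presumably why it was replaced; the replacement, AutoConfig.from_pretrained(model_prefix, **config_dict), will still contact the Hugging Face Hub even though config.json was already streamed from S3. A sketch of building the config purely from the streamed dict, avoiding the extra Hub round-trip (an assumption-laden alternative, not what this commit does):

from transformers import AutoConfig

# for_model resolves the concrete config class from the "model_type" key,
# so no network access is needed once config_dict is in hand.
model_type = config_dict.pop("model_type")
config = AutoConfig.for_model(model_type, **config_dict)
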
@@ -115,7 +121,7 @@ class S3DirectStream:
         logger.info(f"Loading tokenizer for {model_prefix}...")
         if self.file_exists_in_s3(f"{model_prefix}/tokenizer.json"):
             logger.info(f"Tokenizer for {model_prefix} found in S3. Loading...")
-            return self.load_tokenizer_from_existing_s3(model_prefix, config) # Pass config
+            return self.load_tokenizer_from_existing_s3(model_prefix, config)
 
         logger.info(f"Tokenizer for {model_prefix} not found in S3. Downloading and uploading...")
         self.download_and_upload_to_s3(model_prefix)
@@ -125,16 +131,13 @@ class S3DirectStream:
             logger.error(f"Error loading tokenizer: {e}")
             return None
 
-
-    def load_tokenizer_from_existing_s3(self, model_prefix, config): # Recieve config
+    def load_tokenizer_from_existing_s3(self, model_prefix, config):
         logger.info(f"Loading tokenizer from S3 for {model_prefix}...")
-
         tokenizer_stream = self.stream_from_s3(f"{model_prefix}/tokenizer.json")
-        tokenizer = AutoTokenizer.from_pretrained(None, config=config) # Use config
+        tokenizer = AutoTokenizer.from_pretrained(None, config=config)
         logger.info(f"Tokenizer loaded for {model_prefix}.")
         return tokenizer
 
-
     def download_and_upload_to_s3(self, model_prefix, revision="main"):
         model_url = f"https://huggingface.co/{model_prefix}/resolve/{revision}/pytorch_model.bin"
         safetensors_url = f"https://huggingface.co/{model_prefix}/resolve/{revision}/model.safetensors"
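
As committed, load_tokenizer_from_existing_s3 never uses tokenizer_stream, and AutoTokenizer.from_pretrained(None, config=config) has no model path to load from, so this call should be expected to fail. A minimal working sketch, assuming stream_from_s3 returns a file-like object and the artifact is a fast-tokenizer tokenizer.json (again, not what this commit does):

import os
import tempfile
from transformers import PreTrainedTokenizerFast

with tempfile.TemporaryDirectory() as tmp_dir:
    tmp_path = os.path.join(tmp_dir, "tokenizer.json")
    with open(tmp_path, "wb") as f:
        f.write(tokenizer_stream.read())  # persist the streamed bytes
    tokenizer = PreTrainedTokenizerFast(tokenizer_file=tmp_path)
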
@@ -148,7 +151,6 @@ class S3DirectStream:
         self.download_and_upload_to_s3_url(config_url, f"{model_prefix}/config.json")
         logger.info(f"Finished downloading and uploading model files for {model_prefix}.")
 
-
     def download_and_upload_to_s3_url(self, url, s3_key):
         logger.info(f"Downloading from {url}...")
         with requests.get(url, stream=True) as response:
@@ -163,7 +165,6 @@ class S3DirectStream:
             elif response.status_code == 404:
                 logger.error(f"File not found at {url}")
                 raise HTTPException(status_code=404, detail=f"Error downloading file from {url}. File not found.")
-
             else:
                 logger.error(f"Error downloading from {url}: Status code {response.status_code}")
                 raise HTTPException(status_code=500, detail=f"Error downloading file from {url}")
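
The body of download_and_upload_to_s3_url is mostly collapsed in this hunk. One common way to forward the HTTP body straight into S3 without buffering the whole file in memory (a sketch; it assumes __init__ stores bucket_name on self, which the collapsed code may handle differently):

with requests.get(url, stream=True) as response:
    if response.status_code == 200:
        response.raw.decode_content = True  # transparently decompress if needed
        self.s3_client.upload_fileobj(response.raw, self.bucket_name, s3_key)
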
@@ -172,33 +173,40 @@ class S3DirectStream:
 @app.post("/predict/")
 async def predict(model_request: DownloadModelRequest):
     try:
-        logger.info(f"Received request: Model={model_request.model_name}, Task={model_request.pipeline_task}, Input={model_request.input_text}")
-        model_name = model_request.model_name
+        logger.info(f"Received request: Model={model_request.model_id}, Task={model_request.pipeline_task}, Input={model_request.input_text}")
+        model_id = model_request.model_id # Fixed: Use model_id, not model_name
         revision = model_request.revision
+        task = model_request.pipeline_task
+        input_text = model_request.input_text
 
         streamer = S3DirectStream(S3_BUCKET_NAME)
         logger.info("Loading model and tokenizer...")
-        model = streamer.load_model_from_stream(model_name, revision)
-        tokenizer = streamer.load_tokenizer_from_stream(model_name) # Moved after model loading
+        model = streamer.load_model_from_stream(model_id, revision) # Use model_id
+
+        if model is None:
+            logger.error(f"Failed to load model {model_id}")
+            raise HTTPException(status_code=500, detail=f"Failed to load model {model_id}")
+
+        tokenizer = streamer.load_tokenizer_from_stream(model_id) # Use model_id
+
         logger.info("Model and tokenizer loaded.")
 
-        task = model_request.pipeline_task
+
         if task not in ["text-generation", "sentiment-analysis", "translation", "fill-mask", "question-answering", "summarization", "zero-shot-classification"]:
             raise HTTPException(status_code=400, detail="Unsupported pipeline task")
 
         if task == "text-generation":
             logger.info("Starting text generation...")
             text_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-            inputs = tokenizer(model_request.input_text, return_tensors="pt").to(model.device)
+            inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
             generation_kwargs = dict(inputs, streamer=text_streamer)
             model.generate(**generation_kwargs)
             logger.info("Text generation finished.")
             return StreamingResponse(iter([tokenizer.decode(token) for token in text_streamer]), media_type="text/event-stream")
-
         else:
             logger.info(f"Starting pipeline task: {task}...")
             nlp_pipeline = pipeline(task, model=model, tokenizer=tokenizer, device_map="auto", trust_remote_code=True)
-            outputs = nlp_pipeline(model_request.input_text)
+            outputs = nlp_pipeline(input_text)
             logger.info(f"Pipeline task {task} finished.")
             return {"result": outputs}
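
A note on the text-generation branch above: TextIteratorStreamer is designed to be drained while generate() runs in a worker thread, and it yields already-decoded strings, so model.generate(**generation_kwargs) blocks until generation finishes and the tokenizer.decode(token) in the return is applied to text rather than token ids. The pattern from the transformers documentation looks roughly like this (a sketch, not part of this commit):

from threading import Thread

thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
# The streamer yields decoded text chunks as generate() produces them.
return StreamingResponse(text_streamer, media_type="text/event-stream")
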
 
@@ -206,6 +214,5 @@ async def predict(model_request: DownloadModelRequest):
         logger.exception(f"Error processing request: {e}")
         raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")
 
-
 if __name__ == "__main__":
     uvicorn.run(app, host="0.0.0.0", port=7860)
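
After this change, clients must send model_id instead of model_name. A request against the server started above might look like this ("gpt2" is a placeholder model id):

import requests

resp = requests.post(
    "http://localhost:7860/predict/",
    json={
        "model_id": "gpt2",
        "pipeline_task": "text-generation",
        "input_text": "Hello, world",
    },
)
print(resp.text)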