Hjgugugjhuhjggg committed
Update app.py

app.py CHANGED
@@ -1,6 +1,6 @@
 from huggingface_hub import HfApi
 from fastapi import FastAPI, HTTPException
-from pydantic import BaseModel
+from pydantic import BaseModel, field_validator
 import requests
 import boto3
 from dotenv import load_dotenv
@@ -35,11 +35,17 @@ s3_client = boto3.client(
 app = FastAPI()
 
 class DownloadModelRequest(BaseModel):
-    model_name: str
+    model_id: str
     pipeline_task: str
     input_text: str
     revision: str = "main"
 
+    @field_validator('model_id')
+    def validate_model_id(cls, value):
+        if not value:
+            raise ValueError("model_id cannot be empty")
+        return value
+
 class S3DirectStream:
     def __init__(self, bucket_name):
         self.s3_client = boto3.client(
@@ -89,7 +95,7 @@ class S3DirectStream:
         logger.info(f"Loading config for {model_prefix} from S3...")
         config_stream = self.stream_from_s3(f"{model_prefix}/config.json")
         config_dict = json.load(config_stream)
-        config = AutoConfig.
+        config = AutoConfig.from_pretrained(model_prefix, **config_dict)
         logger.info(f"Config loaded for {model_prefix}.")
 
         if self.file_exists_in_s3(f"{model_prefix}/model.safetensors"):
@@ -115,7 +121,7 @@ class S3DirectStream:
         logger.info(f"Loading tokenizer for {model_prefix}...")
         if self.file_exists_in_s3(f"{model_prefix}/tokenizer.json"):
             logger.info(f"Tokenizer for {model_prefix} found in S3. Loading...")
-            return self.load_tokenizer_from_existing_s3(model_prefix, config)
+            return self.load_tokenizer_from_existing_s3(model_prefix, config)
 
         logger.info(f"Tokenizer for {model_prefix} not found in S3. Downloading and uploading...")
         self.download_and_upload_to_s3(model_prefix)
@@ -125,16 +131,13 @@ class S3DirectStream:
         logger.error(f"Error loading tokenizer: {e}")
         return None
 
-
-    def load_tokenizer_from_existing_s3(self, model_prefix, config): # Recieve config
+    def load_tokenizer_from_existing_s3(self, model_prefix, config):
        logger.info(f"Loading tokenizer from S3 for {model_prefix}...")
-
         tokenizer_stream = self.stream_from_s3(f"{model_prefix}/tokenizer.json")
-        tokenizer = AutoTokenizer.from_pretrained(None, config=config)
+        tokenizer = AutoTokenizer.from_pretrained(None, config=config)
         logger.info(f"Tokenizer loaded for {model_prefix}.")
         return tokenizer
 
-
     def download_and_upload_to_s3(self, model_prefix, revision="main"):
         model_url = f"https://huggingface.co/{model_prefix}/resolve/{revision}/pytorch_model.bin"
         safetensors_url = f"https://huggingface.co/{model_prefix}/resolve/{revision}/model.safetensors"
@@ -148,7 +151,6 @@ class S3DirectStream:
         self.download_and_upload_to_s3_url(config_url, f"{model_prefix}/config.json")
         logger.info(f"Finished downloading and uploading model files for {model_prefix}.")
 
-
     def download_and_upload_to_s3_url(self, url, s3_key):
         logger.info(f"Downloading from {url}...")
         with requests.get(url, stream=True) as response:
@@ -163,7 +165,6 @@ class S3DirectStream:
             elif response.status_code == 404:
                 logger.error(f"File not found at {url}")
                 raise HTTPException(status_code=404, detail=f"Error downloading file from {url}. File not found.")
-
             else:
                 logger.error(f"Error downloading from {url}: Status code {response.status_code}")
                 raise HTTPException(status_code=500, detail=f"Error downloading file from {url}")
@@ -172,33 +173,40 @@ class S3DirectStream:
 @app.post("/predict/")
 async def predict(model_request: DownloadModelRequest):
     try:
-        logger.info(f"Received request: Model={model_request.
-
+        logger.info(f"Received request: Model={model_request.model_id}, Task={model_request.pipeline_task}, Input={model_request.input_text}")
+        model_id = model_request.model_id  # Fixed: Use model_id, not model_name
         revision = model_request.revision
+        task = model_request.pipeline_task
+        input_text = model_request.input_text
 
         streamer = S3DirectStream(S3_BUCKET_NAME)
         logger.info("Loading model and tokenizer...")
-        model = streamer.load_model_from_stream(
-
+        model = streamer.load_model_from_stream(model_id, revision)  # Use model_id
+
+        if model is None:
+            logger.error(f"Failed to load model {model_id}")
+            raise HTTPException(status_code=500, detail=f"Failed to load model {model_id}")
+
+        tokenizer = streamer.load_tokenizer_from_stream(model_id)  # Use model_id
+
         logger.info("Model and tokenizer loaded.")
 
-
+
         if task not in ["text-generation", "sentiment-analysis", "translation", "fill-mask", "question-answering", "summarization", "zero-shot-classification"]:
             raise HTTPException(status_code=400, detail="Unsupported pipeline task")
 
         if task == "text-generation":
             logger.info("Starting text generation...")
             text_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-            inputs = tokenizer(
+            inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
             generation_kwargs = dict(inputs, streamer=text_streamer)
             model.generate(**generation_kwargs)
             logger.info("Text generation finished.")
             return StreamingResponse(iter([tokenizer.decode(token) for token in text_streamer]), media_type="text/event-stream")
-
         else:
             logger.info(f"Starting pipeline task: {task}...")
             nlp_pipeline = pipeline(task, model=model, tokenizer=tokenizer, device_map="auto", trust_remote_code=True)
-            outputs = nlp_pipeline(
+            outputs = nlp_pipeline(input_text)
             logger.info(f"Pipeline task {task} finished.")
             return {"result": outputs}
 
@@ -206,6 +214,5 @@ async def predict(model_request: DownloadModelRequest):
         logger.exception(f"Error processing request: {e}")
         raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")
 
-
 if __name__ == "__main__":
     uvicorn.run(app, host="0.0.0.0", port=7860)
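The new validator can be exercised on its own. A minimal sketch, assuming Pydantic v2 (which provides field_validator); the example values are illustrative:

from pydantic import BaseModel, ValidationError, field_validator

class DownloadModelRequest(BaseModel):
    model_id: str
    pipeline_task: str
    input_text: str
    revision: str = "main"

    @field_validator("model_id")
    def validate_model_id(cls, value):
        # Reject empty strings; FastAPI surfaces this as a 422 response.
        if not value:
            raise ValueError("model_id cannot be empty")
        return value

# A populated request body passes:
DownloadModelRequest(model_id="gpt2", pipeline_task="text-generation", input_text="Hello")

# An empty model_id fails validation before the handler runs:
try:
    DownloadModelRequest(model_id="", pipeline_task="text-generation", input_text="Hello")
except ValidationError as exc:
    print(exc)  # ... model_id cannot be empty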
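With this commit, /predict/ expects a non-empty model_id in the JSON body, alongside pipeline_task and input_text. A hypothetical client call against the server started by uvicorn.run(app, host="0.0.0.0", port=7860); the localhost URL and the gpt2 model id are illustrative assumptions:

import requests

resp = requests.post(
    "http://localhost:7860/predict/",
    json={
        "model_id": "gpt2",                  # required; empty values are rejected by the validator
        "pipeline_task": "text-generation",  # must be one of the tasks whitelisted in predict()
        "input_text": "Once upon a time",
        "revision": "main",                  # optional, defaults to "main"
    },
)
print(resp.status_code)
print(resp.text)  # streamed text for text-generation; JSON {"result": ...} for other tasks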
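One thing the diff does not change: load_tokenizer_from_existing_s3 still calls AutoTokenizer.from_pretrained(None, config=config), so the tokenizer.json streamed from S3 is never actually used. A possible variant, sketched under the assumption that stream_from_s3 returns a readable file-like object; this is not part of the commit:

import os
import shutil
import tempfile

from transformers import PreTrainedTokenizerFast

def load_tokenizer_from_existing_s3(self, model_prefix, config):
    # Hypothetical rewrite: materialize the streamed tokenizer.json to disk
    # so the tokenizer is really built from the S3 copy. The config argument
    # is kept only to preserve the method's existing signature.
    tokenizer_stream = self.stream_from_s3(f"{model_prefix}/tokenizer.json")
    with tempfile.TemporaryDirectory() as tmp_dir:
        tokenizer_path = os.path.join(tmp_dir, "tokenizer.json")
        with open(tokenizer_path, "wb") as f:
            shutil.copyfileobj(tokenizer_stream, f)
        # PreTrainedTokenizerFast can load directly from a tokenizer.json file.
        return PreTrainedTokenizerFast(tokenizer_file=tokenizer_path)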