File size: 12,443 Bytes
b88653d
f56cbc6
42861e8
f56cbc6
 
 
 
 
972e5ee
f56cbc6
972e5ee
f56cbc6
8becaf9
 
f5b9942
f56cbc6
 
 
8becaf9
 
 
f56cbc6
 
 
972e5ee
 
f56cbc6
 
 
 
 
 
 
 
 
 
 
42861e8
f56cbc6
 
 
42861e8
 
 
 
 
 
f56cbc6
 
 
 
 
 
 
 
 
 
 
 
8becaf9
f56cbc6
8becaf9
972e5ee
f56cbc6
8becaf9
972e5ee
f56cbc6
 
 
 
8becaf9
f56cbc6
 
8becaf9
f56cbc6
 
e58c8bb
f56cbc6
e58c8bb
b88653d
 
 
 
 
 
 
 
 
f56cbc6
8becaf9
f56cbc6
 
b88653d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f56cbc6
 
 
8becaf9
b88653d
 
 
 
 
 
 
 
f56cbc6
8becaf9
f56cbc6
 
b88653d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
972e5ee
e58c8bb
44af224
b88653d
44af224
b88653d
 
 
 
 
 
44af224
b88653d
 
 
 
 
44af224
 
972e5ee
8becaf9
 
 
685ddd1
8becaf9
 
 
 
 
 
 
 
 
 
 
 
f56cbc6
e58c8bb
 
 
 
2c59376
 
 
 
 
 
 
f58e444
2c59376
f58e444
e58c8bb
 
 
 
f56cbc6
 
 
 
42861e8
44af224
42861e8
 
f56cbc6
 
f5b9942
e58c8bb
42861e8
 
 
 
 
44af224
f5b9942
 
972e5ee
 
 
 
f5b9942
972e5ee
42861e8
972e5ee
 
f5b9942
972e5ee
 
8becaf9
972e5ee
42861e8
8becaf9
972e5ee
 
f56cbc6
8becaf9
972e5ee
 
f56cbc6
972e5ee
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
import io
import json
import logging
import os
import tempfile
import threading

import boto3
import requests
import safetensors.torch
import torch
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse
from huggingface_hub import HfApi
from pydantic import BaseModel, field_validator
from tqdm import tqdm
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, AutoConfig, TextIteratorStreamer

# Pull variables from a .env file (when one exists) into the process environment.
load_dotenv()

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# Deployment settings read from the environment; each is None when unset.
AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION, S3_BUCKET_NAME, HUGGINGFACE_TOKEN = (
    os.getenv(env_name)
    for env_name in (
        "AWS_ACCESS_KEY_ID",
        "AWS_SECRET_ACCESS_KEY",
        "AWS_REGION",
        "S3_BUCKET_NAME",
        "HUGGINGFACE_TOKEN",
    )
)

# Module-level S3 client built from the settings above.
# NOTE(review): S3DirectStream creates its own client with the same settings.
s3_client = boto3.client(
    's3',
    region_name=AWS_REGION,
    aws_access_key_id=AWS_ACCESS_KEY_ID,
    aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
)

# ASGI application; run with uvicorn in the __main__ guard at the bottom of the file.
app = FastAPI()

class DownloadModelRequest(BaseModel):
    """JSON body accepted by ``POST /predict/``."""

    model_id: str  # Hugging Face repo id; predict() also uses it as the S3 key prefix
    pipeline_task: str  # checked against the supported task list in predict()
    input_text: str  # text handed to the model / pipeline

    @field_validator('model_id')
    @classmethod
    def validate_model_id(cls, value):
        """Reject an empty ``model_id``; otherwise return it unchanged."""
        if value:
            return value
        raise ValueError("model_id cannot be empty")

class S3DirectStream:
    """Cache Hugging Face Hub models in S3 and load models/tokenizers from that cache.

    Every artifact is stored in ``bucket_name`` under ``<model_prefix>/<filename>``,
    where ``model_prefix`` is the Hub repo id. On a cache miss the needed files are
    copied from the Hub straight into S3 and then read back from S3.
    """

    # Weight-file extensions the loader knows how to deserialize.
    _WEIGHT_SUFFIXES = (".safetensors", ".bin")
    # Top-level ``.bin`` files that Hub repos often contain but that are not weights.
    _NON_WEIGHT_FILES = frozenset({"training_args.bin"})
    # Tokenizer side files copied when the repo has them (tokenizer.json is required).
    _OPTIONAL_TOKENIZER_FILES = ("tokenizer_config.json", "special_tokens_map.json")
    # Seconds to wait when connecting to / reading from the Hub.
    _HTTP_TIMEOUT = 60

    def __init__(self, bucket_name):
        self.s3_client = boto3.client(
            's3',
            aws_access_key_id=AWS_ACCESS_KEY_ID,
            aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
            region_name=AWS_REGION
        )
        self.bucket_name = bucket_name

    # ------------------------------------------------------------------ S3 helpers

    def stream_from_s3(self, key):
        """Return the streaming body of the S3 object ``key``.

        Raises:
            HTTPException: 404 when the object does not exist.
        """
        try:
            logger.info(f"Downloading {key} from S3...")
            response = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
            logger.info(f"Downloaded {key} from S3 successfully.")
            return response['Body']
        except self.s3_client.exceptions.NoSuchKey:
            logger.error(f"File {key} not found in S3")
            raise HTTPException(status_code=404, detail=f"File {key} not found in S3")

    def file_exists_in_s3(self, key):
        """Return True if ``key`` exists; any ``ClientError`` (not only 404) counts as absent."""
        try:
            self.s3_client.head_object(Bucket=self.bucket_name, Key=key)
            logger.info(f"File {key} exists in S3.")
            return True
        except self.s3_client.exceptions.ClientError:
            logger.info(f"File {key} does not exist in S3.")
            return False

    def _copy_s3_object_to_file(self, key, path):
        """Write the S3 object ``key`` to the local file ``path`` (404 -> HTTPException)."""
        body = self.stream_from_s3(key)
        try:
            data = body.read()
        finally:
            body.close()
        with open(path, "wb") as fh:
            fh.write(data)

    # ----------------------------------------------------------------------- model

    def load_model_from_stream(self, model_prefix):
        """Load the causal-LM ``model_prefix``, filling the S3 cache from the Hub first if needed.

        Returns:
            The model, or ``None`` if an ``HTTPException`` occurred while listing,
            copying or reading the artifacts.
        """
        try:
            logger.info(f"Loading model {model_prefix}...")
            # List the Hub repo once and reuse it for both the cache check and the load.
            model_files = self._get_model_files(model_prefix)
            if self._model_cached_in_s3(model_prefix, model_files):
                logger.info(f"Model {model_prefix} found in S3. Loading...")
            else:
                logger.info(f"Model {model_prefix} not found in S3. Downloading and uploading...")
                self.download_and_upload_to_s3(model_prefix)
                logger.info(f"Downloaded and uploaded {model_prefix}. Loading from S3...")
            # Load directly instead of recursing: a repo whose files never appear in
            # S3 (e.g. no weight files at all) would otherwise re-download forever.
            return self.load_model_from_existing_s3(model_prefix, model_files)
        except HTTPException as e:
            logger.error(f"Error loading model: {e}")
            return None

    def _model_cached_in_s3(self, model_prefix, model_files):
        """Return True only if config.json and *every* weight file are already in S3."""
        if not model_files:
            return False
        keys = [f"{model_prefix}/config.json"] + [f"{model_prefix}/{name}" for name in model_files]
        # all(), not any(): a partially uploaded sharded checkpoint is not usable.
        return all(self.file_exists_in_s3(key) for key in keys)

    def load_model_from_existing_s3(self, model_prefix, model_files=None):
        """Build a causal-LM from the config and weight files cached in S3.

        Args:
            model_prefix: Hub repo id, used as the S3 key prefix.
            model_files: weight file names to load; taken from the Hub listing
                when omitted.

        Raises:
            EnvironmentError: if there are no weight files to load.
            HTTPException: 404 if config.json or a weight file is missing from S3.
        """
        logger.info(f"Loading config for {model_prefix} from S3...")
        config_stream = self.stream_from_s3(f"{model_prefix}/config.json")
        try:
            config_dict = json.load(config_stream)
        finally:
            config_stream.close()
        # Build the config purely from the cached JSON; AutoConfig.from_pretrained(model_prefix)
        # would download config.json from the Hub again.
        config = AutoConfig.for_model(**config_dict)
        logger.info(f"Config loaded for {model_prefix}.")

        if model_files is None:
            model_files = self._get_model_files(model_prefix)
        if not model_files:
            logger.error(f"No model files found for {model_prefix} in S3")
            raise EnvironmentError(f"No model files found for {model_prefix} in S3")

        state_dict = {}
        for model_file in model_files:
            # S3 keys always use "/", independent of the local OS path separator.
            state_dict.update(self._load_weight_file(f"{model_prefix}/{model_file}"))

        # from_pretrained (unlike from_config + a strict load_state_dict) handles
        # base-model key prefixes and tied weights such as lm_head.
        return AutoModelForCausalLM.from_pretrained(None, config=config, state_dict=state_dict)

    def _load_weight_file(self, key):
        """Deserialize the weight file stored at ``key`` into a ``{name: tensor}`` dict.

        Raises:
            ValueError: if the extension is neither ``.safetensors`` nor ``.bin``.
        """
        if not key.endswith(self._WEIGHT_SUFFIXES):
            logger.error(f"Unsupported model file type: {key}")
            raise ValueError(f"Unsupported model file type: {key}")

        logger.info(f"Loading model file: {key}")
        body = self.stream_from_s3(key)
        try:
            payload = body.read()
        finally:
            body.close()
        try:
            if key.endswith(".safetensors"):
                # safetensors.torch.load deserializes from in-memory bytes.
                return safetensors.torch.load(payload)
            # torch.load needs a seekable file object; weights_only=True refuses to
            # unpickle arbitrary objects from files fetched from the Hub.
            return torch.load(io.BytesIO(payload), map_location="cpu", weights_only=True)
        except Exception as e:
            logger.exception(f"Error loading model file {key}: {e}")
            raise

    # ------------------------------------------------------------------- tokenizer

    def load_tokenizer_from_stream(self, model_prefix):
        """Load the tokenizer for ``model_prefix``, copying its files from the Hub if needed.

        Returns:
            The tokenizer, or ``None`` if an ``HTTPException`` occurred (e.g. the
            repo has no tokenizer.json).
        """
        try:
            logger.info(f"Loading tokenizer for {model_prefix}...")
            if self.file_exists_in_s3(f"{model_prefix}/tokenizer.json"):
                logger.info(f"Tokenizer for {model_prefix} found in S3. Loading...")
            else:
                logger.info(f"Tokenizer for {model_prefix} not found in S3. Downloading and uploading...")
                # Only the tokenizer files are needed here, not the model weights.
                self._download_and_upload_tokenizer_files(model_prefix)
                logger.info(f"Downloaded and uploaded tokenizer for {model_prefix}. Loading from S3...")
            return self.load_tokenizer_from_existing_s3(model_prefix)
        except HTTPException as e:
            logger.error(f"Error loading tokenizer: {e}")
            return None

    def load_tokenizer_from_existing_s3(self, model_prefix, config=None):
        """Build a tokenizer from the tokenizer files cached in S3.

        The cached files are written to a temporary directory so that
        ``AutoTokenizer.from_pretrained`` can load them like a local checkpoint.

        Args:
            model_prefix: Hub repo id, used as the S3 key prefix.
            config: optional model config forwarded to ``AutoTokenizer``.

        Raises:
            HTTPException: 404 if tokenizer.json is not in S3.
        """
        logger.info(f"Loading tokenizer from S3 for {model_prefix}...")
        with tempfile.TemporaryDirectory() as local_dir:
            self._copy_s3_object_to_file(
                f"{model_prefix}/tokenizer.json", os.path.join(local_dir, "tokenizer.json")
            )
            # config.json can let AutoTokenizer pick the tokenizer class when
            # tokenizer_config.json does not name one.
            for name in ("config.json",) + self._OPTIONAL_TOKENIZER_FILES:
                key = f"{model_prefix}/{name}"
                if self.file_exists_in_s3(key):
                    self._copy_s3_object_to_file(key, os.path.join(local_dir, name))
            extra_kwargs = {} if config is None else {"config": config}
            tokenizer = AutoTokenizer.from_pretrained(local_dir, **extra_kwargs)
        logger.info(f"Tokenizer loaded for {model_prefix}.")
        return tokenizer

    # ------------------------------------------------------------------ Hub -> S3

    def download_and_upload_to_s3(self, model_prefix, revision="main"):
        """Copy config.json, the weight files and the tokenizer files from the Hub to S3.

        A repo without tokenizer.json is tolerated here (the weights are still
        usable); ``load_tokenizer_from_stream`` reports that miss separately.
        """
        logger.info(f"Downloading and uploading model files for {model_prefix} to S3...")
        self.download_and_upload_to_s3_url(
            self._hub_url(model_prefix, revision, "config.json"), f"{model_prefix}/config.json"
        )

        for model_file in self._get_model_files(model_prefix, revision):
            s3_key = f"{model_prefix}/{model_file}"
            self.download_and_upload_to_s3_url(self._hub_url(model_prefix, revision, model_file), s3_key)
            logger.info(f"Downloaded and uploaded {s3_key}")

        try:
            self._download_and_upload_tokenizer_files(model_prefix, revision)
        except HTTPException as e:
            if e.status_code != 404:
                raise
            logger.warning(f"Tokenizer files not available for {model_prefix}: {e.detail}")

        logger.info(f"Finished downloading and uploading model files for {model_prefix}.")

    def _download_and_upload_tokenizer_files(self, model_prefix, revision="main"):
        """Copy tokenizer.json (required) plus any optional tokenizer files from the Hub to S3.

        config.json is copied too when it is not cached yet, since the tokenizer
        loader uses it to resolve the tokenizer class.

        Raises:
            HTTPException: if tokenizer.json cannot be downloaded, or an optional
                file fails with a status other than 404.
        """
        self.download_and_upload_to_s3_url(
            self._hub_url(model_prefix, revision, "tokenizer.json"), f"{model_prefix}/tokenizer.json"
        )
        optional_files = list(self._OPTIONAL_TOKENIZER_FILES)
        if not self.file_exists_in_s3(f"{model_prefix}/config.json"):
            optional_files.append("config.json")
        for name in optional_files:
            try:
                self.download_and_upload_to_s3_url(
                    self._hub_url(model_prefix, revision, name), f"{model_prefix}/{name}"
                )
            except HTTPException as e:
                if e.status_code != 404:
                    raise
                logger.info(f"{name} not available for {model_prefix}; skipping.")

    def _get_model_files(self, model_prefix, revision="main"):
        """List the weight files of ``model_prefix`` at ``revision`` on the Hugging Face Hub.

        Raises:
            HTTPException: 500 if the repository listing cannot be retrieved.
        """
        try:
            repo_files = HfApi(token=HUGGINGFACE_TOKEN).list_repo_files(model_prefix, revision=revision)
        except Exception as e:  # network, auth and repo-id validation errors alike
            logger.error(f"Error retrieving model index: {e}")
            raise HTTPException(status_code=500, detail="Error retrieving model files from Hugging Face") from e
        logger.info(f"Repository files for {model_prefix}: {repo_files}")
        return self._select_weight_files(repo_files)

    @classmethod
    def _select_weight_files(cls, repo_files):
        """Pick the loadable weight files from a repo file listing.

        Only top-level files are considered. When safetensors weights exist, the
        ``.bin`` copies are dropped so the same tensors are not loaded twice.
        """
        weights = [
            name for name in repo_files
            if "/" not in name
            and name.endswith(cls._WEIGHT_SUFFIXES)
            and name not in cls._NON_WEIGHT_FILES
        ]
        safetensors_files = [name for name in weights if name.endswith(".safetensors")]
        return safetensors_files or weights

    @staticmethod
    def _hub_url(model_prefix, revision, filename):
        """Return the Hub "resolve" URL for ``filename`` in ``model_prefix`` at ``revision``."""
        return f"https://huggingface.co/{model_prefix}/resolve/{revision}/{filename}"

    def download_and_upload_to_s3_url(self, url, s3_key):
        """Stream ``url`` straight into the S3 object ``s3_key`` without touching local disk.

        Raises:
            HTTPException: 404 if the file is missing on the Hub, 500 for any
                other non-200 response.
        """
        logger.info(f"Downloading from {url}...")
        # Needed for gated/private repos; requests drops this header on redirects
        # to another host (the signed CDN URL).
        headers = {"Authorization": f"Bearer {HUGGINGFACE_TOKEN}"} if HUGGINGFACE_TOKEN else {}
        with requests.get(url, stream=True, headers=headers, timeout=self._HTTP_TIMEOUT) as response:
            if response.status_code == 404:
                logger.error(f"File not found at {url}")
                raise HTTPException(status_code=404, detail=f"Error downloading file from {url}. File not found.")
            if response.status_code != 200:
                logger.error(f"Error downloading from {url}: Status code {response.status_code}")
                raise HTTPException(status_code=500, detail=f"Error downloading file from {url}")

            total_size_in_bytes = int(response.headers.get('content-length', 0)) or None
            # response.raw bypasses requests' Content-Encoding handling; decode here
            # so S3 receives the actual file bytes.
            response.raw.decode_content = True
            logger.info(f"Uploading to S3: {s3_key}...")
            with tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) as progress_bar:
                self.s3_client.upload_fileobj(
                    response.raw, self.bucket_name, s3_key, Callback=progress_bar.update
                )
            logger.info(f"Uploaded {s3_key} to S3 successfully.")

    def _get_latest_revision(self, model_prefix):
        """Return the commit sha of ``model_prefix``'s default branch.

        Falls back to ``"main"`` when the Hub reports no sha, and returns ``None``
        when the lookup itself fails.
        """
        try:
            model_info = HfApi(token=HUGGINGFACE_TOKEN).model_info(model_prefix)
        except Exception as e:
            logger.error(f"Error getting latest revision for {model_prefix}: {e}")
            return None
        # ModelInfo exposes the commit hash as ``sha``; it has no ``revision`` attribute.
        revision = getattr(model_info, "sha", None)
        if revision:
            return revision
        logger.warning(f"No revision found for {model_prefix}, using 'main'")
        return "main"


# Tasks accepted by POST /predict/; anything else is rejected with HTTP 400.
_SUPPORTED_TASKS = frozenset({
    "text-generation",
    "sentiment-analysis",
    "translation",
    "fill-mask",
    "question-answering",
    "summarization",
    "zero-shot-classification",
})


def _stream_generation(model, tokenizer, input_text):
    """Run ``model.generate`` in a background thread and stream the decoded text.

    The ``TextIteratorStreamer`` is filled by the generation thread and iterated
    by the response, so output reaches the client while generation is running.
    """
    text_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    generation_kwargs = dict(inputs, streamer=text_streamer)

    def _generate():
        try:
            model.generate(**generation_kwargs)
            logger.info("Text generation finished.")
        except Exception:
            logger.exception("Text generation failed.")
            # After an exception generate() may not have signalled end-of-stream;
            # close it so the client response terminates instead of hanging.
            text_streamer.end()

    threading.Thread(target=_generate, daemon=True).start()
    # The streamer already yields decoded text chunks; no further decoding is needed.
    return StreamingResponse(text_streamer, media_type="text/event-stream")


def _run_prediction(model_request):
    """Blocking body of ``predict``: load the model/tokenizer and run the task.

    Raises:
        HTTPException: 400 for an unsupported task, 500 if the model or the
            tokenizer cannot be loaded.
    """
    logger.info(f"Received request: Model={model_request.model_id}, Task={model_request.pipeline_task}, Input={model_request.input_text}")
    model_id = model_request.model_id
    task = model_request.pipeline_task
    input_text = model_request.input_text

    # Reject bad tasks before the (slow) model download/load.
    if task not in _SUPPORTED_TASKS:
        raise HTTPException(status_code=400, detail="Unsupported pipeline task")

    streamer = S3DirectStream(S3_BUCKET_NAME)
    logger.info("Loading model and tokenizer...")
    model = streamer.load_model_from_stream(model_id)
    if model is None:
        logger.error(f"Failed to load model {model_id}")
        raise HTTPException(status_code=500, detail=f"Failed to load model {model_id}")

    tokenizer = streamer.load_tokenizer_from_stream(model_id)
    if tokenizer is None:
        logger.error(f"Failed to load tokenizer for {model_id}")
        raise HTTPException(status_code=500, detail=f"Failed to load tokenizer for {model_id}")
    logger.info("Model and tokenizer loaded.")

    if task == "text-generation":
        logger.info("Starting text generation...")
        return _stream_generation(model, tokenizer, input_text)

    # NOTE(review): the loader always builds an AutoModelForCausalLM, which may lack
    # the heads these other pipeline tasks expect — confirm they work as intended.
    logger.info(f"Starting pipeline task: {task}...")
    nlp_pipeline = pipeline(task, model=model, tokenizer=tokenizer, device_map="auto", trust_remote_code=True)
    outputs = nlp_pipeline(input_text)
    logger.info(f"Pipeline task {task} finished.")
    return {"result": outputs}


@app.post("/predict/")
async def predict(model_request: DownloadModelRequest):
    """Run ``pipeline_task`` with ``model_id`` on ``input_text``.

    Model loading and inference block, so they run in the threadpool to keep
    the event loop free. Deliberate ``HTTPException``s (400/500) pass through
    unchanged; any other error is reported as HTTP 500.
    """
    try:
        return await run_in_threadpool(_run_prediction, model_request)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Error processing request: {e}")
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}") from e

if __name__ == "__main__":
    # Serve the API directly with uvicorn on every interface, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")