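"""FastAPI inference service.

Downloads model configs, weights, and tokenizers from the Hugging Face Hub
(with helpers for streaming files to/from S3) and serves transformers
pipelines, including streamed text generation, over HTTP.
"""
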
from huggingface_hub import HfApi, hf_hub_download
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, field_validator
import requests
import boto3
from dotenv import load_dotenv
import os
import uvicorn
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, AutoConfig, TextIteratorStreamer
import safetensors.torch
import torch
from fastapi.responses import StreamingResponse
from tqdm import tqdm
import logging
from threading import Thread

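# Pull AWS credentials, the S3 bucket name, and the Hugging Face token from a .env file.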
load_dotenv()

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
AWS_REGION = os.getenv("AWS_REGION")
S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")

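# Module-level S3 client (S3DirectStream below constructs its own instance).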
s3_client = boto3.client(
    's3',
    aws_access_key_id=AWS_ACCESS_KEY_ID,
    aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
    region_name=AWS_REGION
)

app = FastAPI()

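# Request body for /predict/: which model to load, which pipeline task to run,
# and the input text to feed it.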
class DownloadModelRequest(BaseModel):
    model_id: str
    pipeline_task: str
    input_text: str

    @field_validator('model_id')
    @classmethod
    def validate_model_id(cls, value):
        if not value:
            raise ValueError("model_id cannot be empty")
        return value

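# Wraps S3 get/head/upload operations plus Hugging Face Hub downloads. Note that
# the load_* methods currently fetch from the Hub via hf_hub_download; the S3
# streaming helpers are available but not yet wired into that flow.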
class S3DirectStream:
    def __init__(self, bucket_name):
        self.s3_client = boto3.client(
            's3',
            aws_access_key_id=AWS_ACCESS_KEY_ID,
            aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
            region_name=AWS_REGION
        )
        self.bucket_name = bucket_name

    def stream_from_s3(self, key):
        try:
            logger.info(f"Downloading {key} from S3...")
            response = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
            logger.info(f"Downloaded {key} from S3 successfully.")
            return response['Body']
        except self.s3_client.exceptions.NoSuchKey:
            logger.error(f"File {key} not found in S3")
            raise HTTPException(status_code=404, detail=f"File {key} not found in S3")

    def file_exists_in_s3(self, key):
        try:
            self.s3_client.head_object(Bucket=self.bucket_name, Key=key)
            logger.info(f"File {key} exists in S3.")
            return True
        except self.s3_client.exceptions.ClientError:
            logger.info(f"File {key} does not exist in S3.")
            return False

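    # Load flow: resolve the latest commit for the repo, fetch config.json, then
    # download every weight shard and merge them into one state dict.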
    def load_model_from_stream(self, model_prefix):
        try:
            logger.info(f"Loading model {model_prefix}...")
            revision = self._get_latest_revision(model_prefix)
            if revision is None:
                logger.error(f"Could not determine revision for {model_prefix}")
                raise ValueError(f"Could not determine revision for {model_prefix}")

            config = self._load_config(model_prefix, revision)
            if config is None:
                logger.error(f"Failed to load config for {model_prefix}")
                raise ValueError(f"Failed to load config for {model_prefix}")

            model = self._load_model(model_prefix, config, revision)
            if model is None:
                logger.error(f"Failed to load model {model_prefix}")
                raise ValueError(f"Failed to load model {model_prefix}")

            return model

        except HTTPException as e:
            logger.error(f"Error loading model: {e}")
            raise
        except Exception as e:
            logger.exception(f"Unexpected error loading model: {e}")
            raise HTTPException(status_code=500, detail="An unexpected error occurred while loading the model.")

    def _load_config(self, model_prefix, revision):
        try:
            logger.info(f"Downloading config for {model_prefix} (revision {revision})...")
            config_path = hf_hub_download(repo_id=model_prefix, filename="config.json",
                                          revision=revision, token=HUGGINGFACE_TOKEN)
            # AutoConfig accepts a path to a config JSON file directly, which keeps
            # the config pinned to the downloaded revision instead of re-fetching it.
            return AutoConfig.from_pretrained(config_path)
        except Exception as e:
            logger.error(f"Error loading config: {e}")
            return None

    def _load_model(self, model_prefix, config, revision):
        try:
            logger.info(f"Downloading model files for {model_prefix} (revision {revision})...")
            model_files = self._get_model_files(model_prefix, revision)
            if not model_files:
                logger.error(f"No model files found for {model_prefix}")
                return None

            state_dict = {}
            for model_file in model_files:
                logger.info(f"Downloading model file: {model_file}")
                file_path = hf_hub_download(repo_id=model_prefix, filename=model_file,
                                            revision=revision, token=HUGGINGFACE_TOKEN)
                # safetensors loads from a path; torch.load needs an open binary handle.
                if model_file.endswith(".safetensors"):
                    shard_state = safetensors.torch.load_file(file_path)
                elif model_file.endswith(".bin"):
                    with open(file_path, "rb") as f:
                        shard_state = torch.load(f, map_location="cpu")
                else:
                    logger.error(f"Unsupported model file type: {model_file}")
                    raise ValueError(f"Unsupported model file type: {model_file}")
                state_dict.update(shard_state)

            model = AutoModelForCausalLM.from_config(config)
            model.load_state_dict(state_dict)
            return model

        except Exception as e:
            logger.exception(f"Error loading model: {e}")
            return None

    def load_tokenizer_from_stream(self, model_prefix):
        try:
            logger.info(f"Loading tokenizer for {model_prefix}...")
            revision = self._get_latest_revision(model_prefix)
            if revision is None:
                logger.error(f"Could not determine revision for {model_prefix}")
                raise ValueError(f"Could not determine revision for {model_prefix}")

            tokenizer = self._load_tokenizer(model_prefix, revision)
            if tokenizer is None:
                logger.error(f"Failed to load tokenizer for {model_prefix}")
                raise ValueError(f"Failed to load tokenizer for {model_prefix}")
            return tokenizer
        except HTTPException as e:
            logger.error(f"Error loading tokenizer: {e}")
            raise
        except Exception as e:
            logger.exception(f"Unexpected error loading tokenizer: {e}")
            raise HTTPException(status_code=500, detail="An unexpected error occurred while loading the tokenizer.")

    def _load_tokenizer(self, model_prefix, revision):
        try:
            logger.info(f"Downloading tokenizer for {model_prefix} (revision {revision})...")
            # AutoTokenizer expects a repo id or a directory, not the path to a
            # single tokenizer.json file, so load straight from the Hub at the
            # pinned revision.
            return AutoTokenizer.from_pretrained(model_prefix, revision=revision, token=HUGGINGFACE_TOKEN)
        except Exception as e:
            logger.error(f"Error loading tokenizer: {e}")
            return None

    def _get_model_files(self, model_prefix, revision="main"):
        try:
            api = HfApi(token=HUGGINGFACE_TOKEN)
            # list_repo_files returns plain filename strings, not objects with an
            # `rfilename` attribute.
            repo_files = api.list_repo_files(model_prefix, revision=revision)
            return [f for f in repo_files if f.endswith((".bin", ".safetensors"))]
        except Exception as e:
            logger.error(f"Error retrieving model files from Hugging Face: {e}")
            raise HTTPException(status_code=500, detail="Error retrieving model files from Hugging Face") from e

    def download_and_upload_to_s3_url(self, url, s3_key):
        logger.info(f"Downloading from {url}...")
        with requests.get(url, stream=True) as response:
            if response.status_code == 200:
                total_size_in_bytes = int(response.headers.get('content-length', 0))
                progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
                logger.info(f"Uploading to S3: {s3_key}...")
                # Stream the HTTP body straight to S3 without buffering it on disk;
                # Callback reports bytes transferred so the progress bar advances.
                self.s3_client.upload_fileobj(
                    response.raw, self.bucket_name, s3_key,
                    Callback=lambda bytes_transferred: progress_bar.update(bytes_transferred))
                progress_bar.close()
                logger.info(f"Uploaded {s3_key} to S3 successfully.")
            elif response.status_code == 404:
                logger.error(f"File not found at {url}")
                raise HTTPException(status_code=404, detail=f"Error downloading file from {url}. File not found.")
            else:
                logger.error(f"Error downloading from {url}: Status code {response.status_code}")
                raise HTTPException(status_code=500, detail=f"Error downloading file from {url}")

    def _get_latest_revision(self, model_prefix):
        try:
            api = HfApi(token=HUGGINGFACE_TOKEN)
            model_info = api.model_info(model_prefix)
            # ModelInfo exposes the latest commit hash as `sha` (there is no
            # `revision` attribute); fall back to "main" when it is missing.
            sha = getattr(model_info, "sha", None)
            if sha:
                return sha
            logger.warning(f"No commit sha found for {model_prefix}, using 'main'")
            return "main"
        except Exception as e:
            logger.error(f"Error getting latest revision for {model_prefix}: {e}")
            return None


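# Single inference endpoint: loads the requested model and tokenizer, then either
# streams text-generation output or runs the input through a transformers pipeline.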
@app.post("/predict/")
async def predict(model_request: DownloadModelRequest):
    try:
        logger.info(f"Received request: Model={model_request.model_id}, Task={model_request.pipeline_task}, Input={model_request.input_text}")
        model_id = model_request.model_id
        task = model_request.pipeline_task
        input_text = model_request.input_text

        streamer = S3DirectStream(S3_BUCKET_NAME)
        logger.info("Loading model and tokenizer...")
        model = streamer.load_model_from_stream(model_id)

        if model is None:
            logger.error(f"Failed to load model {model_id}")
            raise HTTPException(status_code=500, detail=f"Failed to load model {model_id}")

        tokenizer = streamer.load_tokenizer_from_stream(model_id)
        logger.info("Model and tokenizer loaded.")

        if task not in ["text-generation", "sentiment-analysis", "translation", "fill-mask", "question-answering", "summarization", "zero-shot-classification"]:
            raise HTTPException(status_code=400, detail="Unsupported pipeline task")

        if task == "text-generation":
            logger.info("Starting text generation...")
            text_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
            inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
            generation_kwargs = dict(inputs, streamer=text_streamer)
            # generate() blocks until the streamer is drained, so run it in a
            # background thread; the streamer itself yields decoded text chunks,
            # which StreamingResponse can forward as they arrive.
            generation_thread = Thread(target=model.generate, kwargs=generation_kwargs)
            generation_thread.start()
            return StreamingResponse(text_streamer, media_type="text/event-stream")
        else:
            logger.info(f"Starting pipeline task: {task}...")
            # device_map / trust_remote_code only take effect when the pipeline loads
            # the model itself; both model and tokenizer are already instantiated here.
            nlp_pipeline = pipeline(task, model=model, tokenizer=tokenizer)
            outputs = nlp_pipeline(input_text)
            logger.info(f"Pipeline task {task} finished.")
            return {"result": outputs}

    except Exception as e:
        logger.exception(f"Error processing request: {e}")
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
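
# Example request (illustrative values; any Hub repo with compatible weights works):
#   curl -X POST http://localhost:7860/predict/ \
#        -H "Content-Type: application/json" \
#        -d '{"model_id": "gpt2", "pipeline_task": "text-generation", "input_text": "Hello"}'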