from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import requests
import boto3
from dotenv import load_dotenv
import io
import os
import json
import logging
from threading import Thread
import uvicorn
from transformers import pipeline, AutoModelForCausalLM, AutoConfig, PreTrainedTokenizerFast, TextIteratorStreamer
from tokenizers import Tokenizer
import safetensors.torch
import torch
from tqdm import tqdm

load_dotenv()

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

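# Credentials and bucket configuration come from the environment (or a .env file).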
AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
AWS_REGION = os.getenv("AWS_REGION")
S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")

s3_client = boto3.client(
    's3',
    aws_access_key_id=AWS_ACCESS_KEY_ID,
    aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
    region_name=AWS_REGION
)

app = FastAPI()

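# Request schema for /predict/: a Hub repo id, the pipeline task to run,
# the input text, and an optional git revision (defaults to "main").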
class DownloadModelRequest(BaseModel):
    model_name: str
    pipeline_task: str
    input_text: str
    revision: str = "main"

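# Streams model artifacts between the Hugging Face Hub and S3, using the
# bucket as a cache so each model is fetched from the Hub at most once.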
class S3DirectStream:
    def __init__(self, bucket_name):
        self.s3_client = boto3.client(
            's3',
            aws_access_key_id=AWS_ACCESS_KEY_ID,
            aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
            region_name=AWS_REGION
        )
        self.bucket_name = bucket_name


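    # Returns the S3 object's streaming body so callers can read it without
    # writing to local disk.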
    def stream_from_s3(self, key):
        try:
            logger.info(f"Downloading {key} from S3...")
            response = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
            logger.info(f"Downloaded {key} from S3 successfully.")
            return response['Body']
        except self.s3_client.exceptions.NoSuchKey:
            logger.error(f"File {key} not found in S3")
            raise HTTPException(status_code=404, detail=f"File {key} not found in S3")

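    # head_object raises ClientError for a missing key; permission errors are
    # treated the same way here and simply report the file as absent.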
    def file_exists_in_s3(self, key):
        try:
            self.s3_client.head_object(Bucket=self.bucket_name, Key=key)
            logger.info(f"File {key} exists in S3.")
            return True
        except self.s3_client.exceptions.ClientError:
            logger.info(f"File {key} does not exist in S3.")
            return False

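    # Load a model, fetching it from the Hub into S3 first if it is not cached.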
    def load_model_from_stream(self, model_prefix, revision):
        try:
            logger.info(f"Loading model {model_prefix} (revision {revision})...")

            if self.file_exists_in_s3(f"{model_prefix}/config.json") and \
               (self.file_exists_in_s3(f"{model_prefix}/pytorch_model.bin") or self.file_exists_in_s3(f"{model_prefix}/model.safetensors")):
                logger.info(f"Model {model_prefix} found in S3. Loading...")
                return self.load_model_from_existing_s3(model_prefix)

            logger.info(f"Model {model_prefix} not found in S3. Downloading and uploading...")
            self.download_and_upload_to_s3(model_prefix, revision)
            logger.info(f"Downloaded and uploaded {model_prefix}. Loading from S3...")
            return self.load_model_from_stream(model_prefix, revision)
        except HTTPException as e:
            logger.error(f"Error loading model: {e}")
            return None


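    # Rebuild the model from artifacts already in S3, preferring safetensors
    # weights over the legacy pytorch_model.bin format.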
    def load_model_from_existing_s3(self, model_prefix):
        logger.info(f"Loading config for {model_prefix} from S3...")
        config_stream = self.stream_from_s3(f"{model_prefix}/config.json")
        config_dict = json.load(config_stream)
        # AutoConfig.from_pretrained expects a repo id or path, not a dict;
        # AutoConfig.for_model rebuilds the config from the parsed JSON.
        config = AutoConfig.for_model(**config_dict)
        logger.info(f"Config loaded for {model_prefix}.")


        if self.file_exists_in_s3(f"{model_prefix}/model.safetensors"):
            logger.info(f"Loading safetensors model for {model_prefix} from S3...")
            model_stream = self.stream_from_s3(f"{model_prefix}/model.safetensors")
            model = AutoModelForCausalLM.from_config(config)
            # safetensors exposes load()/load_file(), not load_stream(); load()
            # deserializes a state dict from raw bytes.
            model.load_state_dict(safetensors.torch.load(model_stream.read()))
            logger.info(f"Safetensors model loaded for {model_prefix}.")

        elif self.file_exists_in_s3(f"{model_prefix}/pytorch_model.bin"):
            logger.info(f"Loading PyTorch model for {model_prefix} from S3...")
            model_stream = self.stream_from_s3(f"{model_prefix}/pytorch_model.bin")
            model = AutoModelForCausalLM.from_config(config)
            # torch.load needs a seekable file object, so buffer the S3 stream in memory.
            state_dict = torch.load(io.BytesIO(model_stream.read()), map_location="cpu")
            model.load_state_dict(state_dict)
            logger.info(f"PyTorch model loaded for {model_prefix}.")
        else:
            logger.error(f"No model file found for {model_prefix} in S3")
            raise EnvironmentError(f"No model file found for {model_prefix} in S3")

        return model

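    # Load a tokenizer, fetching it from the Hub into S3 first if it is not cached.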
    def load_tokenizer_from_stream(self, model_prefix):
        try:
            logger.info(f"Loading tokenizer for {model_prefix}...")

            if self.file_exists_in_s3(f"{model_prefix}/tokenizer.json"):
                logger.info(f"Tokenizer for {model_prefix} found in S3. Loading...")
                return self.load_tokenizer_from_existing_s3(model_prefix)


            logger.info(f"Tokenizer for {model_prefix} not found in S3. Downloading and uploading...")
            self.download_and_upload_to_s3(model_prefix)
            logger.info(f"Downloaded and uploaded tokenizer for {model_prefix}. Loading from S3...")
            return self.load_tokenizer_from_stream(model_prefix)
        except HTTPException as e:
            logger.error(f"Error loading tokenizer: {e}")
            return None


    def load_tokenizer_from_existing_s3(self, model_prefix):
        logger.info(f"Loading tokenizer from S3 for {model_prefix}...")
        tokenizer_stream = self.stream_from_s3(f"{model_prefix}/tokenizer.json")
        # AutoTokenizer cannot be built from a parsed JSON dict; rebuild the fast
        # tokenizer directly from the serialized tokenizer.json payload instead.
        tokenizer = PreTrainedTokenizerFast(tokenizer_object=Tokenizer.from_str(tokenizer_stream.read().decode("utf-8")))
        logger.info(f"Tokenizer loaded for {model_prefix}.")
        return tokenizer



    def download_and_upload_to_s3(self, model_prefix, revision="main"):
        base_url = f"https://huggingface.co/{model_prefix}/resolve/{revision}"

        logger.info(f"Downloading and uploading model files for {model_prefix} to S3...")
        # Most repositories ship only one of the two weight formats, so a 404 on
        # either weight file is tolerated as long as at least one succeeds.
        uploaded_weights = False
        for filename in ("model.safetensors", "pytorch_model.bin"):
            try:
                self.download_and_upload_to_s3_url(f"{base_url}/{filename}", f"{model_prefix}/{filename}")
                uploaded_weights = True
            except HTTPException as e:
                if e.status_code != 404:
                    raise
                logger.info(f"{filename} not available for {model_prefix}; skipping.")
        if not uploaded_weights:
            raise HTTPException(status_code=404, detail=f"No weight file found for {model_prefix} at revision {revision}")
        self.download_and_upload_to_s3_url(f"{base_url}/tokenizer.json", f"{model_prefix}/tokenizer.json")
        self.download_and_upload_to_s3_url(f"{base_url}/config.json", f"{model_prefix}/config.json")
        logger.info(f"Finished downloading and uploading model files for {model_prefix}.")

    def download_and_upload_to_s3_url(self, url, s3_key):
        logger.info(f"Downloading from {url}...")
        # Send the Hub token when one is configured so gated/private repos resolve.
        headers = {"Authorization": f"Bearer {HUGGINGFACE_TOKEN}"} if HUGGINGFACE_TOKEN else {}

        with requests.get(url, stream=True, headers=headers) as response:
            if response.status_code == 200:
                total_size_in_bytes = int(response.headers.get('content-length', 0))
                progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
                logger.info(f"Uploading to S3: {s3_key}...")

                # Stream the HTTP body straight into S3 without buffering to disk.
                self.s3_client.upload_fileobj(response.raw, self.bucket_name, s3_key, Callback=progress_bar.update)
                progress_bar.close()
                logger.info(f"Uploaded {s3_key} to S3 successfully.")

            elif response.status_code == 404:
                logger.error(f"File not found at {url}")
                raise HTTPException(status_code=404, detail=f"Error downloading file from {url}. File not found.")
            else:
                logger.error(f"Error downloading from {url}: Status code {response.status_code}")
                raise HTTPException(status_code=500, detail=f"Error downloading file from {url}")


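# Single inference endpoint: loads (and caches) the requested model, then either
# streams generated text or returns the pipeline output as JSON.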
@app.post("/predict/")
async def predict(model_request: DownloadModelRequest):
    try:
        logger.info(f"Received request: Model={model_request.model_name}, Task={model_request.pipeline_task}, Input={model_request.input_text}")

        model_name = model_request.model_name
        revision = model_request.revision

        streamer = S3DirectStream(S3_BUCKET_NAME)
        logger.info("Loading model and tokenizer...")
        model = streamer.load_model_from_stream(model_name, revision)
        tokenizer = streamer.load_tokenizer_from_stream(model_name)
        if model is None or tokenizer is None:
            raise HTTPException(status_code=500, detail="Failed to load model or tokenizer")
        logger.info("Model and tokenizer loaded.")


        task = model_request.pipeline_task
        if task not in ["text-generation", "sentiment-analysis", "translation", "fill-mask", "question-answering",  "summarization", "zero-shot-classification"]:
            raise HTTPException(status_code=400, detail="Unsupported pipeline task")


        if task == "text-generation":
            logger.info("Starting text generation...")
            text_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
            inputs = tokenizer(model_request.input_text, return_tensors="pt").to(model.device)
            generation_kwargs = dict(inputs, streamer=text_streamer)
            model.generate(**generation_kwargs)
            logger.info("Text generation finished.")

            return StreamingResponse(iter([tokenizer.decode(token) for token in text_streamer]), media_type="text/event-stream")

        else:
            logger.info(f"Starting pipeline task: {task}...")
            # device_map and trust_remote_code only apply when the pipeline loads
            # the model itself; here both objects are already instantiated.
            nlp_pipeline = pipeline(task, model=model, tokenizer=tokenizer)
            outputs = nlp_pipeline(model_request.input_text)
            logger.info(f"Pipeline task {task} finished.")
            return {"result": outputs}

    except Exception as e:
        logger.exception(f"Error processing request: {e}")
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
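
# Example request against a locally running server (model name is illustrative):
#   curl -X POST http://localhost:7860/predict/ \
#        -H "Content-Type: application/json" \
#        -d '{"model_name": "gpt2", "pipeline_task": "text-generation", "input_text": "Hello"}'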