from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import requests
import boto3
from dotenv import load_dotenv
import io
import os
import tempfile
from threading import Thread
import uvicorn
from transformers import (
    pipeline,
    AutoModelForCausalLM,
    AutoConfig,
    PreTrainedTokenizerFast,
    TextIteratorStreamer,
)
import safetensors.torch
import torch

load_dotenv()

AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
AWS_REGION = os.getenv("AWS_REGION")
S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
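# Expected .env layout (placeholder values, adjust to your deployment):
#   AWS_ACCESS_KEY_ID=...
#   AWS_SECRET_ACCESS_KEY=...
#   AWS_REGION=us-east-1
#   S3_BUCKET_NAME=my-model-bucket
#   HUGGINGFACE_TOKEN=hf_...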

s3_client = boto3.client(
    's3',
    aws_access_key_id=AWS_ACCESS_KEY_ID,
    aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
    region_name=AWS_REGION
)

app = FastAPI()

class DownloadModelRequest(BaseModel):
    model_name: str
    pipeline_task: str
    input_text: str
    revision: str = "main"
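# Example request body (the model name is just an illustration):
# {
#     "model_name": "gpt2",
#     "pipeline_task": "text-generation",
#     "input_text": "Hello, world",
#     "revision": "main"
# }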

class S3DirectStream:
    """Streams model artifacts between the Hugging Face Hub and S3,
    touching local disk only through short-lived temp files."""

    def __init__(self, bucket_name):
        # Reuse the module-level boto3 client instead of building a second one.
        self.s3_client = s3_client
        self.bucket_name = bucket_name

    def _materialize(self, key, suffix):
        # Several transformers loaders require a filesystem path, so spill
        # the S3 object into a temp file and return that path; callers are
        # responsible for unlinking it.
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            tmp.write(self.stream_from_s3(key).read())
            return tmp.name

    def stream_from_s3(self, key):
        try:
            response = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
            return response['Body']
        except self.s3_client.exceptions.NoSuchKey:
            raise HTTPException(status_code=404, detail=f"File {key} not found in S3")

    def file_exists_in_s3(self, key):
        try:
            self.s3_client.head_object(Bucket=self.bucket_name, Key=key)
            return True
        except self.s3_client.exceptions.ClientError:
            return False

    def load_model_from_stream(self, model_prefix, revision):
        # Serve from the S3 cache when the artifacts are already there;
        # otherwise mirror them from the Hugging Face Hub first. Errors
        # propagate instead of being silently swallowed as None.
        if not (self.file_exists_in_s3(f"{model_prefix}/config.json")
                and (self.file_exists_in_s3(f"{model_prefix}/pytorch_model.bin")
                     or self.file_exists_in_s3(f"{model_prefix}/model.safetensors"))):
            self.download_and_upload_to_s3(model_prefix, revision)
        return self.load_model_from_existing_s3(model_prefix)

    def load_model_from_existing_s3(self, model_prefix):
        # AutoConfig cannot parse a raw byte stream, so load config.json
        # from a temp-file path instead.
        config_path = self._materialize(f"{model_prefix}/config.json", ".json")
        try:
            config = AutoConfig.from_pretrained(config_path)
        finally:
            os.unlink(config_path)

        model = AutoModelForCausalLM.from_config(config)
        if self.file_exists_in_s3(f"{model_prefix}/model.safetensors"):
            # safetensors.torch.load() deserialises a state dict from bytes.
            model_stream = self.stream_from_s3(f"{model_prefix}/model.safetensors")
            model.load_state_dict(safetensors.torch.load(model_stream.read()))
        elif self.file_exists_in_s3(f"{model_prefix}/pytorch_model.bin"):
            # torch.load() needs a seekable file object, which the S3
            # StreamingBody is not, so buffer the bytes first.
            model_stream = self.stream_from_s3(f"{model_prefix}/pytorch_model.bin")
            state_dict = torch.load(io.BytesIO(model_stream.read()), map_location="cpu")
            model.load_state_dict(state_dict)
        else:
            raise EnvironmentError(f"No model file found for {model_prefix} in S3")
        return model

    def load_tokenizer_from_stream(self, model_prefix, revision="main"):
        if not self.file_exists_in_s3(f"{model_prefix}/tokenizer.json"):
            self.download_and_upload_to_s3(model_prefix, revision)
        return self.load_tokenizer_from_existing_s3(model_prefix)

    def load_tokenizer_from_existing_s3(self, model_prefix):
        # AutoTokenizer.from_pretrained expects a directory; a fast tokenizer
        # can instead be built from the single tokenizer.json file (note that
        # special-token settings from tokenizer_config.json are not applied).
        tokenizer_path = self._materialize(f"{model_prefix}/tokenizer.json", ".json")
        try:
            return PreTrainedTokenizerFast(tokenizer_file=tokenizer_path)
        finally:
            os.unlink(tokenizer_path)

    def download_and_upload_to_s3(self, model_prefix, revision="main"):
        base_url = f"https://huggingface.co/{model_prefix}/resolve/{revision}"
        # A repo normally ships only one of the two weight formats, so a 404
        # on either of them is tolerated; config and tokenizer are required.
        self.download_and_upload_to_s3_url(f"{base_url}/pytorch_model.bin",
                                           f"{model_prefix}/pytorch_model.bin", required=False)
        self.download_and_upload_to_s3_url(f"{base_url}/model.safetensors",
                                           f"{model_prefix}/model.safetensors", required=False)
        self.download_and_upload_to_s3_url(f"{base_url}/tokenizer.json",
                                           f"{model_prefix}/tokenizer.json")
        self.download_and_upload_to_s3_url(f"{base_url}/config.json",
                                           f"{model_prefix}/config.json")

    def download_and_upload_to_s3_url(self, url, s3_key, required=True):
        # Authenticate against the Hub when a token is configured, so gated
        # and private repos can be mirrored too.
        headers = {"Authorization": f"Bearer {HUGGINGFACE_TOKEN}"} if HUGGINGFACE_TOKEN else {}
        response = requests.get(url, stream=True, headers=headers)
        if response.status_code == 200:
            # Pipe the HTTP body straight into S3 without touching local disk.
            self.s3_client.upload_fileobj(response.raw, self.bucket_name, s3_key)
        elif response.status_code == 404:
            if required:
                raise HTTPException(status_code=404, detail=f"File not found at {url}")
        else:
            raise HTTPException(status_code=500,
                                detail=f"Error downloading {url} (status {response.status_code})")


@app.post("/predict/")
async def predict(model_request: DownloadModelRequest):
    try:
        model_name = model_request.model_name
        revision = model_request.revision

        streamer = S3DirectStream(S3_BUCKET_NAME)
        model = streamer.load_model_from_stream(model_name, revision)
        tokenizer = streamer.load_tokenizer_from_stream(model_name, revision)

        task = model_request.pipeline_task
        if task not in ("text-generation", "sentiment-analysis", "translation",
                        "fill-mask", "question-answering", "summarization",
                        "zero-shot-classification"):
            raise HTTPException(status_code=400, detail="Unsupported pipeline task")

        if task == "text-generation":
            # TextIteratorStreamer already yields decoded text chunks, and
            # generate() must run on another thread so the response can start
            # streaming while tokens are still being produced.
            text_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True,
                                                 skip_special_tokens=True)
            inputs = tokenizer(model_request.input_text, return_tensors="pt").to(model.device)
            generation_kwargs = dict(inputs, streamer=text_streamer)
            Thread(target=model.generate, kwargs=generation_kwargs).start()
            return StreamingResponse(text_streamer, media_type="text/event-stream")

        # The model is already instantiated, so no device_map is passed here.
        nlp_pipeline = pipeline(task, model=model, tokenizer=tokenizer)
        outputs = nlp_pipeline(model_request.input_text)
        return {"result": outputs}

    except HTTPException:
        # Preserve deliberate 4xx responses instead of masking them as 500s.
        raise
    except Exception as e:
        print(f"Complete Error: {e}")
        raise HTTPException(status_code=500, detail=f"Error processing request: {e}")



if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
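# Example invocation (hypothetical model and prompt), assuming the server
# is reachable on localhost:7860:
#
#   curl -X POST http://localhost:7860/predict/ \
#        -H "Content-Type: application/json" \
#        -d '{"model_name": "gpt2", "pipeline_task": "text-generation", "input_text": "Hello"}'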