Hjgugugjhuhjggg committed on
Commit ad7c7d8 (verified)
1 parent: 48f9783

Update app.py

Files changed (1)
  1. app.py +609 -137
app.py CHANGED
@@ -1,39 +1,142 @@
 import os
 import torch
-from fastapi import FastAPI, HTTPException
-from fastapi.responses import StreamingResponse
-from pydantic import BaseModel
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     GenerationConfig,
     StoppingCriteriaList,
-    pipeline
 )
 from io import BytesIO
 import boto3
-from botocore.exceptions import NoCredentialsError
 from huggingface_hub import snapshot_download
-import shutil
-
-# Global configuration
 AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
 AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
 AWS_REGION = os.getenv("AWS_REGION")
 S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
 HUGGINGFACE_HUB_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")

-# Global dictionary of tokens and configurations
-token_dict = {}
-
-# FastAPI application initialization
 app = FastAPI()

-# Request model
 class GenerateRequest(BaseModel):
     model_name: str
-    input_text: str
-    task_type: str
     temperature: float = 1.0
     max_new_tokens: int = 200
     stream: bool = True
@@ -43,13 +146,52 @@ class GenerateRequest(BaseModel):
     num_return_sequences: int = 1
     do_sample: bool = True
     chunk_delay: float = 0.0
-    stop_sequences: list[str] = []

-# Class to load and manage models from S3
 class S3ModelLoader:
-    def __init__(self, bucket_name, aws_access_key_id=None, aws_secret_access_key=None, aws_region=None):
         self.bucket_name = bucket_name
-        self.s3_client = boto3.client(
             's3',
             aws_access_key_id=aws_access_key_id,
             aws_secret_access_key=aws_secret_access_key,
@@ -57,78 +199,110 @@ class S3ModelLoader:
         )

     def _get_s3_uri(self, model_name):
-        return f"s3://{self.bucket_name}/{model_name.replace('/', '-')}"
-
-    def load_model_and_tokenizer(self, model_name):
-        if model_name in token_dict:
-            return token_dict[model_name]
-
         s3_uri = self._get_s3_uri(model_name)
         try:
-            # Check whether the model is already in S3
-            try:
-                self.s3_client.head_object(Bucket=self.bucket_name, Key=f'{model_name}/model')
-                print(f"Model {model_name} already exists in S3.")
-            except self.s3_client.exceptions.ClientError:
-                print(f"Model {model_name} does not exist in S3. Downloading from Hugging Face...")
-
-                # Delete the local Hugging Face cache (if it exists)
-                local_cache_dir = os.path.join(os.getenv("HOME"), ".cache/huggingface/hub/models--")
-                if os.path.exists(local_cache_dir):
-                    shutil.rmtree(local_cache_dir)
-
-            model_path = snapshot_download(model_name, token=HUGGINGFACE_HUB_TOKEN)
-
-            # Load the model and tokenizer
-            model = AutoModelForCausalLM.from_pretrained(model_path)
             tokenizer = AutoTokenizer.from_pretrained(model_path)
-
-            # Assign EOS and PAD tokens if they are not defined
             if tokenizer.eos_token_id is None:
                 tokenizer.eos_token_id = tokenizer.pad_token_id

-            # Store the model and tokenizer in the dictionary
-            token_dict[model_name] = {
-                "model": model,
-                "tokenizer": tokenizer,
-                "pad_token_id": tokenizer.pad_token_id,
-                "eos_token_id": tokenizer.eos_token_id
-            }
-
-            # Upload the model and tokenizer files to S3
-            self.s3_client.upload_file(model_path, self.bucket_name, f'{model_name}/model')
-            self.s3_client.upload_file(f'{model_path}/tokenizer', self.bucket_name, f'{model_name}/tokenizer')
-
-            # Delete the local files after uploading to S3
-            shutil.rmtree(model_path)
-
-            return token_dict[model_name]
-        except NoCredentialsError:
-            raise HTTPException(status_code=500, detail="AWS credentials not found.")
-        except Exception as e:
-            raise HTTPException(status_code=500, detail=f"Error loading model: {e}")
-
-# Model loader instantiation
-model_loader = S3ModelLoader(S3_BUCKET_NAME, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION)
-
-# Streaming text generation function
-async def stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay, max_length=2048):
-    encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
     input_length = encoded_input["input_ids"].shape[1]
     remaining_tokens = max_length - input_length
-
     if remaining_tokens <= 0:
         yield ""
-
     generation_config.max_new_tokens = min(remaining_tokens, generation_config.max_new_tokens)
-
     def stop_criteria(input_ids, scores):
-        decoded_output = tokenizer.decode(int(input_ids[0][-1]), skip_special_tokens=True)
         return decoded_output in stop_sequences
-
     stopping_criteria = StoppingCriteriaList([stop_criteria])
-
-    output_text = ""
     outputs = model.generate(
         **encoded_input,
         do_sample=generation_config.do_sample,
@@ -142,82 +316,380 @@ async def stream_text(model, tokenizer, input_text, generation_config, stop_sequ
         output_scores=True,
         return_dict_in_generate=True
     )
-
     for output in outputs.sequences:
         for token_id in output:
             token = tokenizer.decode(token_id, skip_special_tokens=True)
             yield token
-            await asyncio.sleep(chunk_delay)

-    if stop_sequences and any(stop in output_text for stop in stop_sequences):
-        yield output_text
-        return

-# Text generation endpoint
-@app.post("/generate")
-async def generate(request: GenerateRequest):
-    try:
-        model_name = request.model_name
-        input_text = request.input_text
-        temperature = request.temperature
-        max_new_tokens = request.max_new_tokens
-        stream = request.stream
-        top_p = request.top_p
-        top_k = request.top_k
-        repetition_penalty = request.repetition_penalty
-        num_return_sequences = request.num_return_sequences
-        do_sample = request.do_sample
-        chunk_delay = request.chunk_delay
-        stop_sequences = request.stop_sequences
-
-        # Load the model and tokenizer from S3 if not already present
-        model_data = model_loader.load_model_and_tokenizer(model_name)
-        model = model_data["model"]
-        tokenizer = model_data["tokenizer"]
-        pad_token_id = model_data["pad_token_id"]
-        eos_token_id = model_data["eos_token_id"]

-        device = "cuda" if torch.cuda.is_available() else "cpu"
-        model.to(device)
-
-        generation_config = GenerationConfig(
-            temperature=temperature,
-            max_new_tokens=max_new_tokens,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-            do_sample=do_sample,
-            num_return_sequences=num_return_sequences,
-        )

-        return StreamingResponse(
-            stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay),
-            media_type="text/plain"
-        )
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

-# Image generation endpoint
-@app.post("/generate-image")
-async def generate_image(request: GenerateRequest):
     try:
-        validated_body = request
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-
-        image_generator = pipeline("text-to-image", model=validated_body.model_name, device=device)
-        image = image_generator(validated_body.input_text)[0]

-        img_byte_arr = BytesIO()
-        image.save(img_byte_arr, format="PNG")
-        img_byte_arr.seek(0)

-        return StreamingResponse(img_byte_arr, media_type="image/png")

     except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

-# Run the FastAPI server with Uvicorn
 if __name__ == "__main__":
-    import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=7860)

 import os
 import torch
+from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Depends, BackgroundTasks, Request, Query, APIRouter, Path, Body, status, Response, Header
+from fastapi.responses import StreamingResponse, JSONResponse, FileResponse, HTMLResponse, PlainTextResponse, RedirectResponse
+from pydantic import BaseModel, validator, Field, root_validator, EmailStr, constr, ValidationError
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     GenerationConfig,
     StoppingCriteriaList,
+    pipeline,
+    AutoProcessor,
+    AutoImageProcessor,
+    AutoFeatureExtractor,
+    AutoModelForImageClassification,
+    AutoModelForSeq2SeqLM,
+    AutoModelForQuestionAnswering,
+    AutoModelForSpeechSeq2Seq,
+    AutoModelForTokenClassification,
+    AutoModelForMaskedLM,
+    AutoModelForObjectDetection,
 )
 from io import BytesIO
 import boto3
+from botocore.exceptions import NoCredentialsError, ClientError
 from huggingface_hub import snapshot_download
+import asyncio
+import tempfile
+import hashlib
+from PIL import Image
+import base64
+from typing import Optional, List, Union, Dict, Any
+import uuid
+import subprocess
+import json
+import logging
+import uvicorn
+import numpy as np
+from starlette.middleware.cors import CORSMiddleware
+from starlette.middleware import Middleware
+from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
+from starlette.types import ASGIApp
+from fastapi.staticfiles import StaticFiles
+from fastapi.templating import Jinja2Templates
+from fastapi.middleware.gzip import GZipMiddleware
+from fastapi.exceptions import RequestValidationError
+from fastapi.security import APIKeyHeader, OAuth2PasswordBearer, OAuth2PasswordRequestForm
+from fastapi.security.api_key import APIKeyCookie
+from passlib.context import CryptContext
+from jose import JWTError, jwt
+from datetime import datetime, timedelta
+from database import insert_user, get_user, delete_user, update_user, create_db_and_table, get_all_users
+
+# Logging setup
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(filename)s - %(lineno)d - %(message)s')
+logger = logging.getLogger(__name__)
+
+# JWT settings
+SECRET_KEY = os.getenv("SECRET_KEY")
+if not SECRET_KEY:
+    raise ValueError("SECRET_KEY must be set.")
+ALGORITHM = "HS256"
+ACCESS_TOKEN_EXPIRE_MINUTES = 30
+
+# Password hashing
+pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
+
+# Database connection - replace with your own database setup
+# Example using SQLite
+import sqlite3
+conn = sqlite3.connect('users.db')
+cursor = conn.cursor()
+
+# OAuth2
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
+
+# API key
+API_KEY = os.getenv("API_KEY")
+api_key_header = APIKeyHeader(name="X-API-Key")
+
+# Configuration
 AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
 AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
 AWS_REGION = os.getenv("AWS_REGION")
 S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
 HUGGINGFACE_HUB_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")
+TEMP_DIR = "/tmp"
+STATIC_DIR = "static"
+TEMPLATES = Jinja2Templates(directory="templates")

 app = FastAPI()
+app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
+app.add_middleware(GZipMiddleware)
+
+origins = ["*"]
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=origins,
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+class User(BaseModel):
+    username: constr(min_length=3, max_length=50)
+    email: EmailStr
+    password: constr(min_length=8)

 class GenerateRequest(BaseModel):
     model_name: str
+    input_text: Optional[str] = Field(None, description="Input text for generation.")
+    task_type: str = Field(..., description="Type of generation task (text, image, audio, video, classification, translation, question-answering, speech-to-text, text-to-speech, image-segmentation, feature-extraction, token-classification, fill-mask, image-inpainting, image-super-resolution, object-detection, image-captioning, audio-transcription, summarization).")
     temperature: float = 1.0
     max_new_tokens: int = 200
     stream: bool = True

     num_return_sequences: int = 1
     do_sample: bool = True
     chunk_delay: float = 0.0
+    stop_sequences: List[str] = []
+    image_file: Optional[UploadFile] = None
+    source_language: Optional[str] = None
+    target_language: Optional[str] = None
+    context: Optional[str] = None
+    audio_file: Optional[UploadFile] = None
+    raw_input: Optional[Union[str, bytes]] = None  # for feature extraction
+    masked_text: Optional[str] = None  # for fill-mask
+    mask_image: Optional[UploadFile] = None  # for image inpainting
+    low_res_image: Optional[UploadFile] = None  # for image super-resolution
+
+    @validator("task_type")
+    def validate_task_type(cls, value):
+        allowed_types = ["text", "image", "audio", "video", "classification", "translation", "question-answering", "speech-to-text", "text-to-speech", "image-segmentation", "feature-extraction", "token-classification", "fill-mask", "image-inpainting", "image-super-resolution", "object-detection", "image-captioning", "audio-transcription", "summarization"]
+        if value not in allowed_types:
+            raise ValueError(f"Invalid task_type. Allowed types are: {allowed_types}")
+        return value
+
+    @root_validator
+    def check_input(cls, values):
+        task_type = values.get("task_type")
+        if task_type == "text" and values.get("input_text") is None:
+            raise ValueError("input_text is required for text generation.")
+        elif task_type == "speech-to-text" and values.get("audio_file") is None:
+            raise ValueError("audio_file is required for speech-to-text.")
+        elif task_type == "classification" and values.get("image_file") is None:
+            raise ValueError("image_file is required for image classification.")
+        elif task_type == "image-segmentation" and values.get("image_file") is None:
+            raise ValueError("image_file is required for image segmentation.")
+        elif task_type == "feature-extraction" and values.get("raw_input") is None:
+            raise ValueError("raw_input is required for feature extraction.")
+        elif task_type == "fill-mask" and values.get("masked_text") is None:
+            raise ValueError("masked_text is required for fill-mask.")
+        elif task_type == "image-inpainting" and (values.get("image_file") is None or values.get("mask_image") is None):
+            raise ValueError("image_file and mask_image are required for image inpainting.")
+        elif task_type == "image-super-resolution" and values.get("low_res_image") is None:
+            raise ValueError("low_res_image is required for image super-resolution.")
+        return values

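For readers unfamiliar with Pydantic v1-style validators, here is a minimal, self-contained sketch of the same pattern the request model relies on; the `DemoRequest` model is hypothetical and not part of this commit. A `@validator` rejects unknown task types at parse time, so an invalid payload never reaches the endpoint body:

from pydantic import BaseModel, ValidationError, validator

class DemoRequest(BaseModel):
    task_type: str

    @validator("task_type")
    def validate_task_type(cls, value):
        # Reject anything outside the whitelist before the endpoint runs.
        if value not in ("text", "classification"):
            raise ValueError(f"Invalid task_type: {value}")
        return value

DemoRequest(task_type="text")           # OK
try:
    DemoRequest(task_type="telepathy")  # raises ValidationError
except ValidationError as e:
    print(e)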
 class S3ModelLoader:
+    def __init__(self, bucket_name, aws_access_key_id, aws_secret_access_key, aws_region):
         self.bucket_name = bucket_name
+        self.s3 = boto3.client(
             's3',
             aws_access_key_id=aws_access_key_id,
             aws_secret_access_key=aws_secret_access_key,

         )

     def _get_s3_uri(self, model_name):
+        return f"{self.bucket_name}/{model_name.replace('/', '-')}"
+
+    def load_model_and_tokenizer(self, model_name, task_type):
         s3_uri = self._get_s3_uri(model_name)
         try:
+            self.s3.head_object(Bucket=self.bucket_name, Key=f'{s3_uri}/config.json')
+        except ClientError as e:
+            if e.response['Error']['Code'] == '404':
+                with tempfile.TemporaryDirectory() as tmpdir:
+                    model_path = snapshot_download(model_name, token=HUGGINGFACE_HUB_TOKEN, cache_dir=tmpdir)
+                    self._upload_model_to_s3(model_path, s3_uri)
+            else:
+                raise HTTPException(status_code=500, detail=f"Error accessing S3: {e}")
+        return self._load_from_s3(s3_uri, task_type)
+
+    def _upload_model_to_s3(self, model_path, s3_uri):
+        for root, _, files in os.walk(model_path):
+            for file in files:
+                local_path = os.path.join(root, file)
+                s3_path = os.path.join(s3_uri, os.path.relpath(local_path, model_path))
+                self.s3.upload_file(local_path, self.bucket_name, s3_path)
+
+    def _load_from_s3(self, s3_uri, task_type):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            model_path = os.path.join(tmpdir, s3_uri)
+            os.makedirs(model_path, exist_ok=True)
+            # Download every object under the prefix; from_pretrained() needs the
+            # weights and tokenizer files, not just config.json.
+            paginator = self.s3.get_paginator('list_objects_v2')
+            for page in paginator.paginate(Bucket=self.bucket_name, Prefix=f"{s3_uri}/"):
+                for obj in page.get('Contents', []):
+                    target = os.path.join(model_path, os.path.relpath(obj['Key'], s3_uri))
+                    os.makedirs(os.path.dirname(target), exist_ok=True)
+                    self.s3.download_file(self.bucket_name, obj['Key'], target)
+            if task_type == "text":
+                model = AutoModelForCausalLM.from_pretrained(model_path, load_in_8bit=True)  # requires bitsandbytes
                 tokenizer = AutoTokenizer.from_pretrained(model_path)
                 if tokenizer.eos_token_id is None:
                     tokenizer.eos_token_id = tokenizer.pad_token_id
+                return {"model": model, "tokenizer": tokenizer, "pad_token_id": tokenizer.pad_token_id, "eos_token_id": tokenizer.eos_token_id}
+            elif task_type in ["image", "audio", "video"]:
+                # Note: "image", "audio" and "video" are not standard transformers
+                # pipeline task names and would need to be mapped to real tasks.
+                processor = AutoProcessor.from_pretrained(model_path)
+                pipeline_function = pipeline(task_type, model=model_path, device=0 if torch.cuda.is_available() else -1, processor=processor)
+                return {"pipeline": pipeline_function}
+            elif task_type == "classification":
+                model = AutoModelForImageClassification.from_pretrained(model_path)
+                processor = AutoProcessor.from_pretrained(model_path)
+                return {"model": model, "processor": processor}
+            elif task_type == "translation":
+                model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
+                tokenizer = AutoTokenizer.from_pretrained(model_path)
+                return {"model": model, "tokenizer": tokenizer}
+            elif task_type == "question-answering":
+                model = AutoModelForQuestionAnswering.from_pretrained(model_path)
+                tokenizer = AutoTokenizer.from_pretrained(model_path)
+                return {"model": model, "tokenizer": tokenizer}
+            elif task_type == "speech-to-text":
+                asr = pipeline("automatic-speech-recognition", model=model_path, device=0 if torch.cuda.is_available() else -1)
+                return {"pipeline": asr}
+            elif task_type == "text-to-speech":
+                tts = pipeline("text-to-speech", model=model_path, device=0 if torch.cuda.is_available() else -1)
+                return {"pipeline": tts}
+            elif task_type == "image-segmentation":
+                seg = pipeline("image-segmentation", model=model_path, device=0 if torch.cuda.is_available() else -1)
+                return {"pipeline": seg}
+            elif task_type == "feature-extraction":
+                feature_extractor = AutoFeatureExtractor.from_pretrained(model_path)
+                return {"feature_extractor": feature_extractor}
+            elif task_type == "token-classification":
+                model = AutoModelForTokenClassification.from_pretrained(model_path)
+                tokenizer = AutoTokenizer.from_pretrained(model_path)
+                return {"model": model, "tokenizer": tokenizer}
+            elif task_type == "fill-mask":
+                model = AutoModelForMaskedLM.from_pretrained(model_path)
+                tokenizer = AutoTokenizer.from_pretrained(model_path)
+                return {"model": model, "tokenizer": tokenizer}
+            elif task_type == "image-inpainting":
+                # "image-inpainting" is not a built-in transformers pipeline task;
+                # an inpainting pipeline (e.g. from diffusers) would be needed here.
+                inpaint = pipeline("image-inpainting", model=model_path, device=0 if torch.cuda.is_available() else -1)
+                return {"pipeline": inpaint}
+            elif task_type == "image-super-resolution":
+                # Likewise not a built-in pipeline task; the closest supported
+                # task in transformers is "image-to-image".
+                sr = pipeline("image-super-resolution", model=model_path, device=0 if torch.cuda.is_available() else -1)
+                return {"pipeline": sr}
+            elif task_type == "object-detection":
+                det = pipeline("object-detection", model=model_path, device=0 if torch.cuda.is_available() else -1)
+                return {"pipeline": det}
+            elif task_type == "image-captioning":
+                # Captioning is served by the "image-to-text" pipeline task.
+                captioner = pipeline("image-to-text", model=model_path, device=0 if torch.cuda.is_available() else -1)
+                return {"pipeline": captioner}
+            elif task_type == "audio-transcription":
+                asr = pipeline("automatic-speech-recognition", model=model_path, device=0 if torch.cuda.is_available() else -1)
+                return {"pipeline": asr}
+            elif task_type == "summarization":
+                # The endpoint calls model.generate(), so load a real model here
+                # rather than a pipeline object.
+                model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
+                tokenizer = AutoTokenizer.from_pretrained(model_path)
+                return {"model": model, "tokenizer": tokenizer}
+            else:
+                raise ValueError("Unsupported task type")

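To sanity-check what the loader actually stored, one can list the uploaded objects under a model's prefix. A minimal sketch, assuming the same environment variables as the app; "gpt2" is just an example model name. Note the quirk that `_get_s3_uri()` prepends the bucket name to the key, so objects live under "<bucket>/<model-name>/" inside the bucket:

import os
import boto3

s3 = boto3.client("s3", region_name=os.getenv("AWS_REGION"))
bucket = os.getenv("S3_BUCKET_NAME")
prefix = f"{bucket}/gpt2/"  # example model name

paginator = s3.get_paginator("list_objects_v2")
for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
    for obj in page.get("Contents", []):
        print(obj["Key"], obj["Size"])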
+async def stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay):
+    encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True).to(device)
     input_length = encoded_input["input_ids"].shape[1]
+    max_length = model.config.max_length
     remaining_tokens = max_length - input_length
     if remaining_tokens <= 0:
         yield ""
+        return
     generation_config.max_new_tokens = min(remaining_tokens, generation_config.max_new_tokens)
     def stop_criteria(input_ids, scores):
+        decoded_output = tokenizer.decode(input_ids[0][-1], skip_special_tokens=True)
         return decoded_output in stop_sequences
     stopping_criteria = StoppingCriteriaList([stop_criteria])
     outputs = model.generate(
         **encoded_input,
         do_sample=generation_config.do_sample,

         output_scores=True,
         return_dict_in_generate=True
     )
     for output in outputs.sequences:
         for token_id in output:
             token = tokenizer.decode(token_id, skip_special_tokens=True)
             yield token
+            await asyncio.sleep(chunk_delay)

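The bare `stop_criteria` function above works because `StoppingCriteriaList` simply iterates over callables, but the documented transformers pattern is to subclass `StoppingCriteria`. A self-contained sketch of that alternative, not part of this commit:

import torch
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnSequences(StoppingCriteria):
    """Stop generation once the last decoded token matches a stop sequence."""

    def __init__(self, tokenizer, stop_sequences):
        self.tokenizer = tokenizer
        self.stop_sequences = stop_sequences

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        last_token = self.tokenizer.decode(input_ids[0, -1], skip_special_tokens=True)
        return last_token in self.stop_sequences

# usage:
# model.generate(..., stopping_criteria=StoppingCriteriaList([StopOnSequences(tokenizer, ["\n\n"])]))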
+model_loader = S3ModelLoader(S3_BUCKET_NAME, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION)
+
+def get_model_data(request: GenerateRequest):
+    return model_loader.load_model_and_tokenizer(request.model_name, request.task_type)
+
+async def verify_api_key(api_key: str = Depends(api_key_header)):
+    if api_key != API_KEY:
+        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API Key")

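Every call to the protected endpoints must carry the `X-API-Key` header that `verify_api_key` checks against the `API_KEY` environment variable. A short usage sketch with the `requests` library; host, port, model and key are placeholders:

import requests

resp = requests.post(
    "http://localhost:7860/generate",  # placeholder host/port
    json={"model_name": "gpt2", "task_type": "text", "input_text": "Hello"},
    headers={"X-API-Key": "<your-api-key>"},
)
print(resp.status_code)  # 401 if the key does not match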
+@app.post("/generate", dependencies=[Depends(verify_api_key)])
+async def generate(request: GenerateRequest, background_tasks: BackgroundTasks, model_data=Depends(get_model_data)):
+    try:
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        if request.task_type == "text":
+            model = model_data["model"].to(device)
+            tokenizer = model_data["tokenizer"]
+            generation_config = GenerationConfig(
+                temperature=request.temperature,
+                max_new_tokens=request.max_new_tokens,
+                top_p=request.top_p,
+                top_k=request.top_k,
+                repetition_penalty=request.repetition_penalty,
+                do_sample=request.do_sample,
+                num_return_sequences=request.num_return_sequences,
+            )
+            async def stream_with_tokens():
+                async for token in stream_text(model, tokenizer, request.input_text, generation_config, request.stop_sequences, device, request.chunk_delay):
+                    yield f"Token: {token}\n"
+            return StreamingResponse(stream_with_tokens(), media_type="text/plain")
+        elif request.task_type in ["image", "audio", "video"]:
+            pipe = model_data["pipeline"]  # renamed so it does not shadow transformers.pipeline
+            result = pipe(request.input_text)
+            if request.task_type == "image":
+                image = result[0]
+                img_byte_arr = BytesIO()
+                image.save(img_byte_arr, format="PNG")
+                img_byte_arr.seek(0)
+                return StreamingResponse(img_byte_arr, media_type="image/png")
+            elif request.task_type == "audio":
+                audio = result[0]  # assumes the pipeline output exposes .save()
+                audio_byte_arr = BytesIO()
+                audio.save(audio_byte_arr, format="wav")
+                audio_byte_arr.seek(0)
+                return StreamingResponse(audio_byte_arr, media_type="audio/wav")
+            elif request.task_type == "video":
+                video = result[0]  # assumes the pipeline output exposes .save()
+                video_byte_arr = BytesIO()
+                video.save(video_byte_arr, format="mp4")
+                video_byte_arr.seek(0)
+                return StreamingResponse(video_byte_arr, media_type="video/mp4")
+        elif request.task_type == "classification":
+            if request.image_file is None:
+                raise HTTPException(status_code=400, detail="Image file is required for classification.")
+            contents = await request.image_file.read()
+            image = Image.open(BytesIO(contents)).convert("RGB")
+            model = model_data["model"].to(device)
+            processor = model_data["processor"]
+            inputs = processor(images=image, return_tensors="pt").to(device)
+            with torch.no_grad():
+                outputs = model(**inputs)
+            predicted_class_idx = outputs.logits.argmax().item()
+            predicted_class = model.config.id2label[predicted_class_idx]
+            return JSONResponse({"predicted_class": predicted_class})
+        elif request.task_type == "translation":
+            if request.source_language is None or request.target_language is None:
+                raise HTTPException(status_code=400, detail="Source and target languages are required for translation.")
+            model = model_data["model"].to(device)
+            tokenizer = model_data["tokenizer"]
+            inputs = tokenizer(request.input_text, return_tensors="pt").to(device)
+            with torch.no_grad():
+                outputs = model.generate(**inputs)
+            translation = tokenizer.decode(outputs[0], skip_special_tokens=True)
+            return JSONResponse({"translation": translation})
+        elif request.task_type == "question-answering":
+            if request.context is None:
+                raise HTTPException(status_code=400, detail="Context is required for question answering.")
+            model = model_data["model"].to(device)
+            tokenizer = model_data["tokenizer"]
+            # Tokenizers take question and context positionally, not as keywords.
+            inputs = tokenizer(request.input_text, request.context, return_tensors="pt").to(device)
+            with torch.no_grad():
+                outputs = model(**inputs)
+            answer_start = torch.argmax(outputs.start_logits)
+            answer_end = torch.argmax(outputs.end_logits) + 1
+            answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(inputs["input_ids"][0][answer_start:answer_end]))
+            return JSONResponse({"answer": answer})
+        elif request.task_type == "speech-to-text":
+            if request.audio_file is None:
+                raise HTTPException(status_code=400, detail="Audio file is required for speech-to-text.")
+            contents = await request.audio_file.read()
+            pipe = model_data["pipeline"]
+            try:
+                # The ASR pipeline decodes raw bytes itself and returns a dict.
+                transcription = pipe(contents)["text"]
+                return JSONResponse({"transcription": transcription})
+            except Exception as e:
+                raise HTTPException(status_code=500, detail=f"Error during speech-to-text: {str(e)}")
+        elif request.task_type == "text-to-speech":
+            if not request.input_text:
+                raise HTTPException(status_code=400, detail="Input text is required for text-to-speech.")
+            pipe = model_data["pipeline"]
+            try:
+                audio = pipe(request.input_text)[0]  # assumes the output exposes .save()
+                file_path = os.path.join(TEMP_DIR, f"{uuid.uuid4()}.wav")
+                audio.save(file_path)
+                background_tasks.add_task(os.remove, file_path)
+                return FileResponse(file_path, media_type="audio/wav")
+            except Exception as e:
+                raise HTTPException(status_code=500, detail=f"Error during text-to-speech: {str(e)}")
+        elif request.task_type == "image-segmentation":
+            if request.image_file is None:
+                raise HTTPException(status_code=400, detail="Image file is required for image segmentation.")
+            contents = await request.image_file.read()
+            image = Image.open(BytesIO(contents)).convert("RGB")
+            pipe = model_data["pipeline"]
+            result = pipe(image)
+            mask = result[0]['mask']
+            mask_byte_arr = BytesIO()
+            mask.save(mask_byte_arr, format="PNG")
+            mask_byte_arr.seek(0)
+            return StreamingResponse(mask_byte_arr, media_type="image/png")
+        elif request.task_type == "feature-extraction":
+            if request.raw_input is None:
+                raise HTTPException(status_code=400, detail="raw_input is required for feature extraction.")
+            feature_extractor = model_data["feature_extractor"]
+            try:
+                if isinstance(request.raw_input, str):
+                    inputs = feature_extractor(text=request.raw_input, return_tensors="pt")
+                elif isinstance(request.raw_input, bytes):
+                    image = Image.open(BytesIO(request.raw_input)).convert("RGB")
+                    inputs = feature_extractor(images=image, return_tensors="pt")
+                else:
+                    raise ValueError("Unsupported raw_input type.")
+                features = inputs.pixel_values  # adjust according to your feature extractor
+                return JSONResponse({"features": features.tolist()})
+            except Exception as fe:
+                raise HTTPException(status_code=400, detail=f"Error during feature extraction: {fe}")
+        elif request.task_type == "token-classification":
+            if request.input_text is None:
+                raise HTTPException(status_code=400, detail="Input text is required for token classification.")
+            model = model_data["model"].to(device)
+            tokenizer = model_data["tokenizer"]
+            inputs = tokenizer(request.input_text, return_tensors="pt", padding=True, truncation=True).to(device)
+            with torch.no_grad():
+                outputs = model(**inputs)
+            predictions = outputs.logits.argmax(dim=-1)
+            predicted_labels = [model.config.id2label[label_id] for label_id in predictions[0].tolist()]
+            return JSONResponse({"predicted_labels": predicted_labels})
+        elif request.task_type == "fill-mask":
+            if request.masked_text is None:
+                raise HTTPException(status_code=400, detail="masked_text is required for fill-mask.")
+            model = model_data["model"].to(device)
+            tokenizer = model_data["tokenizer"]
+            inputs = tokenizer(request.masked_text, return_tensors="pt").to(device)
+            with torch.no_grad():
+                outputs = model(**inputs)
+            logits = outputs.logits
+            masked_index = torch.where(inputs.input_ids == tokenizer.mask_token_id)[1]
+            predicted_token_id = torch.argmax(logits[0, masked_index])
+            predicted_token = tokenizer.decode(predicted_token_id)
+            return JSONResponse({"predicted_token": predicted_token})
+        elif request.task_type == "image-inpainting":
+            if request.image_file is None or request.mask_image is None:
+                raise HTTPException(status_code=400, detail="image_file and mask_image are required for image inpainting.")
+            image_contents = await request.image_file.read()
+            mask_contents = await request.mask_image.read()
+            image = Image.open(BytesIO(image_contents)).convert("RGB")
+            mask = Image.open(BytesIO(mask_contents)).convert("L")  # assumes a grayscale mask
+            pipe = model_data["pipeline"]
+            result = pipe(image, mask)
+            inpainted_image = result[0]
+            img_byte_arr = BytesIO()
+            inpainted_image.save(img_byte_arr, format="PNG")
+            img_byte_arr.seek(0)
+            return StreamingResponse(img_byte_arr, media_type="image/png")
+        elif request.task_type == "image-super-resolution":
+            if request.low_res_image is None:
+                raise HTTPException(status_code=400, detail="low_res_image is required for image super-resolution.")
+            contents = await request.low_res_image.read()
+            image = Image.open(BytesIO(contents)).convert("RGB")
+            pipe = model_data["pipeline"]
+            result = pipe(image)
+            upscaled_image = result[0]
+            img_byte_arr = BytesIO()
+            upscaled_image.save(img_byte_arr, format="PNG")
+            img_byte_arr.seek(0)
+            return StreamingResponse(img_byte_arr, media_type="image/png")
+        elif request.task_type == "object-detection":
+            if request.image_file is None:
+                raise HTTPException(status_code=400, detail="Image file is required for object detection.")
+            contents = await request.image_file.read()
+            image = Image.open(BytesIO(contents)).convert("RGB")
+            pipe = model_data["pipeline"]
+            # The pipeline preprocesses the image itself, so no separate
+            # image processor call is needed here.
+            detections = pipe(image)
+            return JSONResponse({"detections": detections})
+        elif request.task_type == "image-captioning":
+            if request.image_file is None:
+                raise HTTPException(status_code=400, detail="Image file is required for image captioning.")
+            contents = await request.image_file.read()
+            image = Image.open(BytesIO(contents)).convert("RGB")
+            pipe = model_data["pipeline"]
+            caption = pipe(image)[0]['generated_text']
+            return JSONResponse({"caption": caption})
+        elif request.task_type == "audio-transcription":
+            if request.audio_file is None:
+                raise HTTPException(status_code=400, detail="Audio file is required for audio transcription.")
+            try:
+                contents = await request.audio_file.read()
+            except Exception as e:
+                raise HTTPException(status_code=500, detail=f"Error during audio transcription (file read): {str(e)}")
+            pipe = model_data["pipeline"]
+            try:
+                transcription = pipe(contents)["text"]
+                return JSONResponse({"transcription": transcription})
+            except Exception as e:
+                raise HTTPException(status_code=500, detail=f"Error during audio transcription (pipeline): {str(e)}")
+        elif request.task_type == "summarization":
+            if request.input_text is None:
+                raise HTTPException(status_code=400, detail="Input text is required for summarization.")
+            model = model_data["model"].to(device)
+            tokenizer = model_data["tokenizer"]
+            inputs = tokenizer(request.input_text, return_tensors="pt", truncation=True, max_length=512).to(device)
+            with torch.no_grad():
+                outputs = model.generate(**inputs)
+            summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
+            return JSONResponse({"summary": summary})
+        else:
+            raise HTTPException(status_code=500, detail="Unsupported task type")
+    except HTTPException:
+        raise  # preserve deliberate 4xx responses instead of wrapping them in a 500
+    except Exception as e:
+        logger.exception(f"Internal server error: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

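For the `text` task the endpoint streams plain text, one `Token: ...` line per decoded token. A sketch of consuming that stream client-side, assuming the server runs locally on port 7860 and "gpt2" as an example model:

import requests

payload = {
    "model_name": "gpt2",
    "task_type": "text",
    "input_text": "Once upon a time",
    "max_new_tokens": 50,
    "stream": True,
}
with requests.post("http://localhost:7860/generate",
                   json=payload,
                   headers={"X-API-Key": "<your-api-key>"},
                   stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        if line:
            print(line)  # e.g. "Token: Once"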
+@app.get("/", response_class=HTMLResponse)
+async def root(request: Request):
+    return TEMPLATES.TemplateResponse("index.html", {"request": request})
+
+@app.get("/health")
+async def health_check():
+    return {"status": "healthy"}

+# Authentication endpoints
+
+# Token and get_current_user are defined before the routes that reference
+# them, since decorator arguments are evaluated at import time.
+class Token(BaseModel):
+    access_token: str
+    token_type: str
+
+def authenticate_user(username: str, password: str):
+    user = get_user(username)
+    if user and pwd_context.verify(password, user.hashed_password):
+        return {"username": user.username}
+    return None
+
+def create_access_token(data: Dict[str, Any], expires_delta: timedelta = None):
+    to_encode = data.copy()
+    if expires_delta:
+        expire = datetime.utcnow() + expires_delta
+    else:
+        expire = datetime.utcnow() + timedelta(minutes=15)
+    to_encode.update({"exp": expire})
+    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
+    return encoded_jwt
+
+async def get_current_user(token: str = Depends(oauth2_scheme)):
+    credentials_exception = HTTPException(
+        status_code=status.HTTP_401_UNAUTHORIZED,
+        detail="Could not validate credentials",
+        headers={"WWW-Authenticate": "Bearer"},
+    )
+    try:
+        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
+        username: str = payload.get("sub")
+        if username is None:
+            raise credentials_exception
+    except JWTError:
+        raise credentials_exception
+    user = get_user(username)
+    if user is None:
+        raise credentials_exception
+    return username
+
+@app.post("/token", response_model=Token)
+async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
+    user = authenticate_user(form_data.username, form_data.password)
+    if not user:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail="Incorrect username or password",
+            headers={"WWW-Authenticate": "Bearer"},
+        )
+    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
+    access_token = create_access_token(data={"sub": user["username"]}, expires_delta=access_token_expires)
+    return {"access_token": access_token, "token_type": "bearer"}
+
+@app.get("/users/me")
+async def read_users_me(current_user: str = Depends(get_current_user)):
+    return {"username": current_user}

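End to end, a client first posts form-encoded credentials to /token, then presents the returned bearer token on protected routes. A hedged sketch with `requests`; credentials and host are placeholders:

import requests

base = "http://localhost:7860"  # placeholder host

# 1. Obtain a JWT; OAuth2PasswordRequestForm expects form fields, not JSON.
tok = requests.post(f"{base}/token",
                    data={"username": "alice", "password": "wonderland123"}).json()

# 2. Present it as a Bearer token on protected endpoints.
me = requests.get(f"{base}/users/me",
                  headers={"Authorization": f"Bearer {tok['access_token']}"})
print(me.json())  # {"username": "alice"}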
+@app.post("/register", response_model=User, status_code=status.HTTP_201_CREATED)
+async def create_user(user: User):
+    try:
+        hashed_password = pwd_context.hash(user.password)
+        new_user = {"username": user.username, "email": user.email, "hashed_password": hashed_password}
+        inserted_user = insert_user(new_user)
+        if inserted_user:
+            return User(**inserted_user)
+        else:
+            raise HTTPException(status_code=500, detail="Failed to create user.")
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.error(f"Error creating user: {e}")
+        raise HTTPException(status_code=500, detail=f"Error creating user: {e}")

+@app.put("/users/{username}", response_model=User, dependencies=[Depends(get_current_user)])
+async def update_user_data(username: str, user: User):
+    try:
+        hashed_password = pwd_context.hash(user.password)
+        updated_user_data = {"email": user.email, "hashed_password": hashed_password}
+        updated_user = update_user(username, updated_user_data)
+        if updated_user:
+            return User(**updated_user)
+        else:
+            raise HTTPException(status_code=404, detail="User not found")
+    except HTTPException:
+        # Re-raise deliberate responses such as the 404 above; HTTPException
+        # is an Exception subclass and would otherwise be wrapped in a 500.
+        raise
+    except Exception as e:
+        logger.error(f"Error updating user: {e}")
+        raise HTTPException(status_code=500, detail="Error updating user.")
+
+@app.delete("/users/{username}", dependencies=[Depends(get_current_user)])
+async def delete_user_account(username: str):
+    try:
+        deleted_user = delete_user(username)
+        if deleted_user:
+            return JSONResponse({"message": "User deleted successfully."}, status_code=200)
+        else:
+            raise HTTPException(status_code=404, detail="User not found")
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.error(f"Error deleting user: {e}")
+        raise HTTPException(status_code=500, detail="Error deleting user.")
+
+@app.get("/users", dependencies=[Depends(get_current_user)])
+async def get_all_users_route():
+    return get_all_users()
+
+@app.exception_handler(RequestValidationError)
+async def validation_exception_handler(request: Request, exc: RequestValidationError):
+    # Pass a plain dict; JSONResponse serializes it itself, so json.dumps()
+    # here would double-encode the payload.
+    return JSONResponse(
+        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+        content={"detail": exc.errors(), "body": exc.body},
+    )

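The `database` module imported at the top is not part of this commit. A minimal sqlite3 sketch of the interface the app expects; table and field names are assumptions inferred from how the app uses the results:

# database.py - hypothetical minimal implementation of the imported interface
import sqlite3

DB = "users.db"

def _conn():
    conn = sqlite3.connect(DB)
    conn.row_factory = sqlite3.Row
    return conn

def create_db_and_table():
    with _conn() as c:
        c.execute("""CREATE TABLE IF NOT EXISTS users (
            username TEXT PRIMARY KEY, email TEXT, hashed_password TEXT)""")

def insert_user(user):
    with _conn() as c:
        c.execute("INSERT INTO users VALUES (:username, :email, :hashed_password)", user)
    return user

def get_user(username):
    # Note: sqlite3.Row supports row["username"] indexing; the app's
    # attribute-style access (user.hashed_password) would need a wrapper.
    return _conn().execute("SELECT * FROM users WHERE username = ?", (username,)).fetchone()

def update_user(username, fields):
    with _conn() as c:
        c.execute("UPDATE users SET email = :email, hashed_password = :hashed_password WHERE username = :u",
                  {**fields, "u": username})
    return get_user(username)

def delete_user(username):
    with _conn() as c:
        cur = c.execute("DELETE FROM users WHERE username = ?", (username,))
    return cur.rowcount > 0

def get_all_users():
    return [dict(r) for r in _conn().execute("SELECT username, email FROM users").fetchall()]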
 if __name__ == "__main__":
+    create_db_and_table()  # initialize the database on startup
+    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)  # this module is app.py