Hjgugugjhuhjggg committed: Update app.py
app.py CHANGED
@@ -1,39 +1,142 @@
 import os
 import torch
-from fastapi import FastAPI, HTTPException
-from fastapi.responses import StreamingResponse
-from pydantic import BaseModel
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     GenerationConfig,
     StoppingCriteriaList,
-    pipeline
 )
 from io import BytesIO
 import boto3
-from botocore.exceptions import NoCredentialsError
 from huggingface_hub import snapshot_download
-import asyncio
-
-
 AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
 AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
 AWS_REGION = os.getenv("AWS_REGION")
 S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
 HUGGINGFACE_HUB_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")

-# Global dictionary of tokens and configurations
-token_dict = {}
-
-# FastAPI application initialization
 app = FastAPI()

-# Request model
 class GenerateRequest(BaseModel):
     model_name: str
-    input_text: str
-    task_type: str
     temperature: float = 1.0
     max_new_tokens: int = 200
     stream: bool = True
@@ -43,13 +146,52 @@ class GenerateRequest(BaseModel):
     num_return_sequences: int = 1
     do_sample: bool = True
     chunk_delay: float = 0.0
-    stop_sequences:

-# Class for loading and managing models from S3
 class S3ModelLoader:
-    def __init__(self, bucket_name, aws_access_key_id, aws_secret_access_key, aws_region):
         self.bucket_name = bucket_name
-        self.s3_client = boto3.client(
             's3',
             aws_access_key_id=aws_access_key_id,
             aws_secret_access_key=aws_secret_access_key,
@@ -57,78 +199,110 @@ class S3ModelLoader:
         )

     def _get_s3_uri(self, model_name):
-        return f"
-
-    def load_model_and_tokenizer(self, model_name):
-        if model_name in token_dict:
-            return token_dict[model_name]
-
         s3_uri = self._get_s3_uri(model_name)
         try:
-
         tokenizer = AutoTokenizer.from_pretrained(model_path)
-
-            # Assign EOS and PAD tokens if they are not defined
         if tokenizer.eos_token_id is None:
             tokenizer.eos_token_id = tokenizer.pad_token_id

-            token_dict[model_name] = {
-                "model": model,
-                "tokenizer": tokenizer,
-                "pad_token_id": tokenizer.pad_token_id,
-                "eos_token_id": tokenizer.eos_token_id
-            }
-
-            # Upload the model and tokenizer files to S3
-            self.s3_client.upload_file(model_path, self.bucket_name, f'{model_name}/model')
-            self.s3_client.upload_file(f'{model_path}/tokenizer', self.bucket_name, f'{model_name}/tokenizer')
-
-            # Delete the local files after uploading to S3
-            shutil.rmtree(model_path)
-
-            return token_dict[model_name]
-        except NoCredentialsError:
-            raise HTTPException(status_code=500, detail="AWS credentials not found.")
-        except Exception as e:
-            raise HTTPException(status_code=500, detail=f"Error loading model: {e}")

-# Instantiate the model loader
-model_loader = S3ModelLoader(S3_BUCKET_NAME, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION)

-# Streaming text generation function
-async def stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay, max_length=2048):
-    encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
     input_length = encoded_input["input_ids"].shape[1]
     remaining_tokens = max_length - input_length
-
     if remaining_tokens <= 0:
         yield ""
-
     generation_config.max_new_tokens = min(remaining_tokens, generation_config.max_new_tokens)
-
     def stop_criteria(input_ids, scores):
-        decoded_output = tokenizer.decode(
         return decoded_output in stop_sequences
-
     stopping_criteria = StoppingCriteriaList([stop_criteria])
-
-    output_text = ""
     outputs = model.generate(
         **encoded_input,
         do_sample=generation_config.do_sample,
@@ -142,82 +316,380 @@ async def stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay, max_length=2048):
         output_scores=True,
         return_dict_in_generate=True
     )
-
     for output in outputs.sequences:
         for token_id in output:
             token = tokenizer.decode(token_id, skip_special_tokens=True)
             yield token
-            await asyncio.sleep(chunk_delay)

-    if stop_sequences and any(stop in output_text for stop in stop_sequences):
-        yield output_text
-        return

-
-@app.post("/generate")
-async def generate(request: GenerateRequest):
-    try:
-        model_name = request.model_name
-        input_text = request.input_text
-        temperature = request.temperature
-        max_new_tokens = request.max_new_tokens
-        stream = request.stream
-        top_p = request.top_p
-        top_k = request.top_k
-        repetition_penalty = request.repetition_penalty
-        num_return_sequences = request.num_return_sequences
-        do_sample = request.do_sample
-        chunk_delay = request.chunk_delay
-        stop_sequences = request.stop_sequences
-
-        # Load the model and tokenizer from S3 if not already loaded
-        model_data = model_loader.load_model_and_tokenizer(model_name)
-        model = model_data["model"]
-        tokenizer = model_data["tokenizer"]
-        pad_token_id = model_data["pad_token_id"]
-        eos_token_id = model_data["eos_token_id"]

-        device = "cuda" if torch.cuda.is_available() else "cpu"
-        generation_config = GenerationConfig(
-            temperature=temperature,
-            max_new_tokens=max_new_tokens,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-            do_sample=do_sample,
-            num_return_sequences=num_return_sequences,
-        )

-        return StreamingResponse(
-            stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay),
-            media_type="text/plain"
-        )

     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

-
-@app.
-async def
     try:
-

-        img_byte_arr = BytesIO()
-        image.save(img_byte_arr, format="PNG")
-        img_byte_arr.seek(0)

     except Exception as e:
-

-# Run the FastAPI server with Uvicorn
 if __name__ == "__main__":
-
-
@@ -1,39 +1,142 @@
 import os
 import torch
+from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Depends, BackgroundTasks, Request, Query, APIRouter, Path, Body, status, Response, Header
+from fastapi.responses import StreamingResponse, JSONResponse, FileResponse, HTMLResponse, PlainTextResponse, RedirectResponse
+from pydantic import BaseModel, validator, Field, root_validator, EmailStr, constr, ValidationError
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     GenerationConfig,
     StoppingCriteriaList,
+    pipeline,
+    AutoProcessor,
+    AutoImageProcessor,
+    AutoModelForImageClassification,
+    AutoModelForSeq2SeqLM,
+    AutoModelForQuestionAnswering,
+    AutoModelForSpeechSeq2Seq,
+    AutoFeatureExtractor,
+    AutoModelForTokenClassification,
+    AutoModelForMaskedLM,
+    AutoModelForObjectDetection,
 )
 from io import BytesIO
 import boto3
+from botocore.exceptions import NoCredentialsError, ClientError
 from huggingface_hub import snapshot_download
+import asyncio
+import tempfile
+import hashlib
+from PIL import Image
+import base64
+from typing import Optional, List, Union, Dict, Any
+import uuid
+import subprocess
+import json
+import logging
+import numpy as np
+import uvicorn
+from starlette.middleware.cors import CORSMiddleware
+from starlette.middleware import Middleware
+from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
+from starlette.requests import Request
+from starlette.responses import StreamingResponse
+from starlette.types import ASGIApp
+from fastapi.staticfiles import StaticFiles
+from fastapi.templating import Jinja2Templates
+from fastapi.middleware.gzip import GZipMiddleware
+from fastapi.security import APIKeyHeader, OAuth2PasswordBearer, OAuth2PasswordRequestForm
+from fastapi.security.api_key import APIKeyCookie
+from fastapi.exceptions import RequestValidationError
+from passlib.context import CryptContext
+from jose import JWTError, jwt
+from datetime import datetime, timedelta
+from database import insert_user, get_user, delete_user, update_user, create_db_and_table, get_all_users
+
+# Set up logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(filename)s - %(lineno)d - %(message)s')
+logger = logging.getLogger(__name__)
+
+# JWT settings
+SECRET_KEY = os.getenv("SECRET_KEY")
+if not SECRET_KEY:
+    raise ValueError("SECRET_KEY must be set.")
+ALGORITHM = "HS256"
+ACCESS_TOKEN_EXPIRE_MINUTES = 30
+
+# Password hashing
+pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
+
+# Database connection - replace with your database setup
+# Example using SQLite
+import sqlite3
+conn = sqlite3.connect('users.db')
+cursor = conn.cursor()
+
+# OAuth2
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
+
+# API key
+API_KEY = os.getenv("API_KEY")
+api_key_header = APIKeyHeader(name="X-API-Key")
+
+# Configuration
 AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
 AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
 AWS_REGION = os.getenv("AWS_REGION")
 S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
 HUGGINGFACE_HUB_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")
+TEMP_DIR = "/tmp"
+STATIC_DIR = "static"
+TEMPLATES = Jinja2Templates(directory="templates")

 app = FastAPI()
+app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
+app.add_middleware(GZipMiddleware)
+
+origins = ["*"]
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=origins,
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+class User(BaseModel):
+    username: constr(min_length=3, max_length=50)
+    email: EmailStr
+    password: constr(min_length=8)

 class GenerateRequest(BaseModel):
     model_name: str
+    input_text: Optional[str] = Field(None, description="Input text for generation.")
+    task_type: str = Field(..., description="Type of generation task (text, image, audio, video, classification, translation, question-answering, speech-to-text, text-to-speech, image-segmentation, feature-extraction, token-classification, fill-mask, image-inpainting, image-super-resolution, object-detection, image-captioning, audio-transcription, summarization).")
     temperature: float = 1.0
     max_new_tokens: int = 200
     stream: bool = True
@@ -43,13 +146,52 @@ class GenerateRequest(BaseModel):
     num_return_sequences: int = 1
     do_sample: bool = True
     chunk_delay: float = 0.0
+    stop_sequences: List[str] = []
+    image_file: Optional[UploadFile] = None
+    source_language: Optional[str] = None
+    target_language: Optional[str] = None
+    context: Optional[str] = None
+    audio_file: Optional[UploadFile] = None
+    raw_input: Optional[Union[str, bytes]] = None  # for feature extraction
+    masked_text: Optional[str] = None  # for fill-mask
+    mask_image: Optional[UploadFile] = None  # for image inpainting
+    low_res_image: Optional[UploadFile] = None  # for image super-resolution
+
+    @validator("task_type")
+    def validate_task_type(cls, value):
+        allowed_types = ["text", "image", "audio", "video", "classification", "translation", "question-answering", "speech-to-text", "text-to-speech", "image-segmentation", "feature-extraction", "token-classification", "fill-mask", "image-inpainting", "image-super-resolution", "object-detection", "image-captioning", "audio-transcription", "summarization"]
+        if value not in allowed_types:
+            raise ValueError(f"Invalid task_type. Allowed types are: {allowed_types}")
+        return value
+
+    @root_validator
+    def check_input(cls, values):
+        task_type = values.get("task_type")
+        if task_type == "text" and values.get("input_text") is None:
+            raise ValueError("input_text is required for text generation.")
+        elif task_type == "speech-to-text" and values.get("audio_file") is None:
+            raise ValueError("audio_file is required for speech-to-text.")
+        elif task_type == "classification" and values.get("image_file") is None:
+            raise ValueError("image_file is required for image classification.")
+        elif task_type == "image-segmentation" and values.get("image_file") is None:
+            raise ValueError("image_file is required for image segmentation.")
+        elif task_type == "feature-extraction" and values.get("raw_input") is None:
+            raise ValueError("raw_input is required for feature extraction.")
+        elif task_type == "fill-mask" and values.get("masked_text") is None:
+            raise ValueError("masked_text is required for fill-mask.")
+        elif task_type == "image-inpainting" and (values.get("image_file") is None or values.get("mask_image") is None):
+            raise ValueError("image_file and mask_image are required for image inpainting.")
+        elif task_type == "image-super-resolution" and values.get("low_res_image") is None:
+            raise ValueError("low_res_image is required for image super-resolution.")
+        return values

 class S3ModelLoader:
+    def __init__(self, bucket_name, aws_access_key_id, aws_secret_access_key, aws_region):
         self.bucket_name = bucket_name
+        self.s3 = boto3.client(
             's3',
             aws_access_key_id=aws_access_key_id,
             aws_secret_access_key=aws_secret_access_key,
@@ -57,78 +199,110 @@ class S3ModelLoader:
         )

     def _get_s3_uri(self, model_name):
+        return f"{self.bucket_name}/{model_name.replace('/', '-')}"
+
+    def load_model_and_tokenizer(self, model_name, task_type):
         s3_uri = self._get_s3_uri(model_name)
         try:
+            self.s3.head_object(Bucket=self.bucket_name, Key=f'{s3_uri}/config.json')
+        except ClientError as e:
+            if e.response['Error']['Code'] == '404':
+                with tempfile.TemporaryDirectory() as tmpdir:
+                    model_path = snapshot_download(model_name, token=HUGGINGFACE_HUB_TOKEN, cache_dir=tmpdir)
+                    self._upload_model_to_s3(model_path, s3_uri)
+            else:
+                raise HTTPException(status_code=500, detail=f"Error accessing S3: {e}")
+        return self._load_from_s3(s3_uri, task_type)
+
+    def _upload_model_to_s3(self, model_path, s3_uri):
+        for root, _, files in os.walk(model_path):
+            for file in files:
+                local_path = os.path.join(root, file)
+                s3_path = os.path.join(s3_uri, os.path.relpath(local_path, model_path))
+                self.s3.upload_file(local_path, self.bucket_name, s3_path)
+
+    def _load_from_s3(self, s3_uri, task_type):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            model_path = os.path.join(tmpdir, s3_uri)
+            os.makedirs(model_path, exist_ok=True)
+            # Download every object under the model's S3 prefix (weights, tokenizer
+            # and config files), so from_pretrained can find all of them locally.
+            paginator = self.s3.get_paginator("list_objects_v2")
+            for page in paginator.paginate(Bucket=self.bucket_name, Prefix=s3_uri):
+                for obj in page.get("Contents", []):
+                    local_file = os.path.join(tmpdir, obj["Key"])
+                    os.makedirs(os.path.dirname(local_file), exist_ok=True)
+                    self.s3.download_file(self.bucket_name, obj["Key"], local_file)
+            if task_type == "text":
+                model = AutoModelForCausalLM.from_pretrained(model_path, load_in_8bit=True)
                 tokenizer = AutoTokenizer.from_pretrained(model_path)
                 if tokenizer.eos_token_id is None:
                     tokenizer.eos_token_id = tokenizer.pad_token_id
+                return {"model": model, "tokenizer": tokenizer, "pad_token_id": tokenizer.pad_token_id, "eos_token_id": tokenizer.eos_token_id}
+            elif task_type in ["image", "audio", "video"]:
+                processor = AutoProcessor.from_pretrained(model_path)
+                pipeline_function = pipeline(task_type, model=model_path, device=0 if torch.cuda.is_available() else -1, processor=processor)
+                return {"pipeline": pipeline_function}
+            elif task_type == "classification":
+                model = AutoModelForImageClassification.from_pretrained(model_path)
+                processor = AutoProcessor.from_pretrained(model_path)
+                return {"model": model, "processor": processor}
+            elif task_type == "translation":
+                model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
+                tokenizer = AutoTokenizer.from_pretrained(model_path)
+                return {"model": model, "tokenizer": tokenizer}
+            elif task_type == "question-answering":
+                model = AutoModelForQuestionAnswering.from_pretrained(model_path)
+                tokenizer = AutoTokenizer.from_pretrained(model_path)
+                return {"model": model, "tokenizer": tokenizer}
+            elif task_type == "speech-to-text":
+                model = pipeline("automatic-speech-recognition", model=model_path, device=0 if torch.cuda.is_available() else -1)
+                return {"pipeline": model}
+            elif task_type == "text-to-speech":
+                model = pipeline("text-to-speech", model=model_path, device=0 if torch.cuda.is_available() else -1)
+                return {"pipeline": model}
+            elif task_type == "image-segmentation":
+                model = pipeline("image-segmentation", model=model_path, device=0 if torch.cuda.is_available() else -1)
+                return {"pipeline": model}
+            elif task_type == "feature-extraction":
+                feature_extractor = AutoFeatureExtractor.from_pretrained(model_path)
+                return {"feature_extractor": feature_extractor}
+            elif task_type == "token-classification":
+                model = AutoModelForTokenClassification.from_pretrained(model_path)
+                tokenizer = AutoTokenizer.from_pretrained(model_path)
+                return {"model": model, "tokenizer": tokenizer}
+            elif task_type == "fill-mask":
+                model = AutoModelForMaskedLM.from_pretrained(model_path)
+                tokenizer = AutoTokenizer.from_pretrained(model_path)
+                return {"model": model, "tokenizer": tokenizer}
+            elif task_type == "image-inpainting":
+                model = pipeline("image-inpainting", model=model_path, device=0 if torch.cuda.is_available() else -1)
+                return {"pipeline": model}
+            elif task_type == "image-super-resolution":
+                model = pipeline("image-super-resolution", model=model_path, device=0 if torch.cuda.is_available() else -1)
+                return {"pipeline": model}
+            elif task_type == "object-detection":
+                model = pipeline("object-detection", model=model_path, device=0 if torch.cuda.is_available() else -1)
+                image_processor = AutoImageProcessor.from_pretrained(model_path)
+                return {"pipeline": model, "image_processor": image_processor}
+            elif task_type == "image-captioning":
+                model = pipeline("image-captioning", model=model_path, device=0 if torch.cuda.is_available() else -1)
+                return {"pipeline": model}
+            elif task_type == "audio-transcription":
+                model = pipeline("automatic-speech-recognition", model=model_path, device=0 if torch.cuda.is_available() else -1)
+                return {"pipeline": model}
+            elif task_type == "summarization":
+                # Load a real seq2seq model here; the /generate handler calls
+                # .to(device) and .generate() on it, which a pipeline object lacks.
+                model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
+                tokenizer = AutoTokenizer.from_pretrained(model_path)
+                return {"model": model, "tokenizer": tokenizer}
+            else:
+                raise ValueError("Unsupported task type")

+async def stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay):
+    encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True).to(device)
     input_length = encoded_input["input_ids"].shape[1]
+    max_length = model.config.max_length
     remaining_tokens = max_length - input_length
     if remaining_tokens <= 0:
         yield ""
+        return
     generation_config.max_new_tokens = min(remaining_tokens, generation_config.max_new_tokens)
     def stop_criteria(input_ids, scores):
+        decoded_output = tokenizer.decode(input_ids[0][-1], skip_special_tokens=True)
         return decoded_output in stop_sequences
     stopping_criteria = StoppingCriteriaList([stop_criteria])
     outputs = model.generate(
         **encoded_input,
         do_sample=generation_config.do_sample,
@@ -142,82 +316,380 @@ async def stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay):
         output_scores=True,
         return_dict_in_generate=True
     )
     for output in outputs.sequences:
         for token_id in output:
             token = tokenizer.decode(token_id, skip_special_tokens=True)
             yield token
+            await asyncio.sleep(chunk_delay)  # pace the stream; chunk_delay is otherwise unused

+model_loader = S3ModelLoader(S3_BUCKET_NAME, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION)
+
+def get_model_data(request: GenerateRequest):
+    return model_loader.load_model_and_tokenizer(request.model_name, request.task_type)
+
+async def verify_api_key(api_key: str = Depends(api_key_header)):
+    if api_key != API_KEY:
+        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API Key")

@app.post("/generate", dependencies=[Depends(verify_api_key)])
|
336 |
+
async def generate(request: GenerateRequest, background_tasks: BackgroundTasks, model_data = Depends(get_model_data)):
|
337 |
+
try:
|
338 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
339 |
+
if request.task_type == "text":
|
340 |
+
model = model_data["model"].to(device)
|
341 |
+
tokenizer = model_data["tokenizer"]
|
342 |
+
generation_config = GenerationConfig(
|
343 |
+
temperature=request.temperature,
|
344 |
+
max_new_tokens=request.max_new_tokens,
|
345 |
+
top_p=request.top_p,
|
346 |
+
top_k=request.top_k,
|
347 |
+
repetition_penalty=request.repetition_penalty,
|
348 |
+
do_sample=request.do_sample,
|
349 |
+
num_return_sequences=request.num_return_sequences,
|
350 |
+
)
|
351 |
+
async def stream_with_tokens():
|
352 |
+
async for token in stream_text(model, tokenizer, request.input_text, generation_config, request.stop_sequences, device, request.chunk_delay):
|
353 |
+
yield f"Token: {token}\n"
|
354 |
+
return StreamingResponse(stream_with_tokens(), media_type="text/plain")
|
355 |
+
elif request.task_type in ["image", "audio", "video"]:
|
356 |
+
pipeline = model_data["pipeline"]
|
357 |
+
result = pipeline(request.input_text)
|
358 |
+
if request.task_type == "image":
|
359 |
+
image = result[0]
|
360 |
+
img_byte_arr = BytesIO()
|
361 |
+
image.save(img_byte_arr, format="PNG")
|
362 |
+
img_byte_arr.seek(0)
|
363 |
+
return StreamingResponse(img_byte_arr, media_type="image/png")
|
364 |
+
elif request.task_type == "audio":
|
365 |
+
audio = result[0]
|
366 |
+
audio_byte_arr = BytesIO()
|
367 |
+
audio.save(audio_byte_arr, format="wav")
|
368 |
+
audio_byte_arr.seek(0)
|
369 |
+
return StreamingResponse(audio_byte_arr, media_type="audio/wav")
|
370 |
+
elif request.task_type == "video":
|
371 |
+
video = result[0]
|
372 |
+
video_byte_arr = BytesIO()
|
373 |
+
video.save(video_byte_arr, format="mp4")
|
374 |
+
video_byte_arr.seek(0)
|
375 |
+
return StreamingResponse(video_byte_arr, media_type="video/mp4")
|
376 |
+
elif request.task_type == "classification":
|
377 |
+
if request.image_file is None:
|
378 |
+
raise HTTPException(status_code=400, detail="Image file is required for classification.")
|
379 |
+
contents = await request.image_file.read()
|
380 |
+
image = Image.open(BytesIO(contents)).convert("RGB")
|
381 |
+
model = model_data["model"].to(device)
|
382 |
+
processor = model_data["processor"]
|
383 |
+
inputs = processor(images=image, return_tensors="pt").to(device)
|
384 |
+
with torch.no_grad():
|
385 |
+
outputs = model(**inputs)
|
386 |
+
predicted_class_idx = outputs.logits.argmax().item()
|
387 |
+
predicted_class = model.config.id2label[predicted_class_idx]
|
388 |
+
return JSONResponse({"predicted_class": predicted_class})
|
389 |
+
elif request.task_type == "translation":
|
390 |
+
if request.source_language is None or request.target_language is None:
|
391 |
+
raise HTTPException(status_code=400, detail="Source and target languages are required for translation.")
|
392 |
+
model = model_data["model"].to(device)
|
393 |
+
tokenizer = model_data["tokenizer"]
|
394 |
+
inputs = tokenizer(request.input_text, return_tensors="pt").to(device)
|
395 |
+
with torch.no_grad():
|
396 |
+
outputs = model.generate(**inputs)
|
397 |
+
translation = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
398 |
+
return JSONResponse({"translation": translation})
|
399 |
+
elif request.task_type == "question-answering":
|
400 |
+
if request.context is None:
|
401 |
+
raise HTTPException(status_code=400, detail="Context is required for question answering.")
|
402 |
+
model = model_data["model"].to(device)
|
403 |
+
tokenizer = model_data["tokenizer"]
|
404 |
+
inputs = tokenizer(question=request.input_text, context=request.context, return_tensors="pt").to(device)
|
405 |
+
with torch.no_grad():
|
406 |
+
outputs = model(**inputs)
|
407 |
+
answer_start = torch.argmax(outputs.start_logits)
|
408 |
+
answer_end = torch.argmax(outputs.end_logits) + 1
|
409 |
+
answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(inputs["input_ids"][0][answer_start:answer_end]))
|
410 |
+
return JSONResponse({"answer": answer})
|
411 |
+
elif request.task_type == "speech-to-text":
|
412 |
+
if request.audio_file is None:
|
413 |
+
raise HTTPException(status_code=400, detail="Audio file is required for speech-to-text.")
|
414 |
+
contents = await request.audio_file.read()
|
415 |
+
pipeline = model_data["pipeline"]
|
416 |
+
try:
|
417 |
+
transcription = pipeline(contents, sampling_rate=16000)[0]["text"] # Assuming 16kHz sampling rate
|
418 |
+
return JSONResponse({"transcription": transcription})
|
419 |
+
except Exception as e:
|
420 |
+
raise HTTPException(status_code=500, detail=f"Error during speech-to-text: {str(e)}")
|
421 |
+
|
422 |
+
elif request.task_type == "text-to-speech":
|
423 |
+
if not request.input_text:
|
424 |
+
raise HTTPException(status_code=400, detail="Input text is required for text-to-speech.")
|
425 |
+
pipeline = model_data["pipeline"]
|
426 |
+
try:
|
427 |
+
audio = pipeline(request.input_text)[0]
|
428 |
+
file_path = os.path.join(TEMP_DIR, f"{uuid.uuid4()}.wav")
|
429 |
+
audio.save(file_path)
|
430 |
+
background_tasks.add_task(os.remove, file_path)
|
431 |
+
return FileResponse(file_path, media_type="audio/wav")
|
432 |
+
except Exception as e:
|
433 |
+
raise HTTPException(status_code=500, detail=f"Error during text-to-speech: {str(e)}")
|
434 |
+
|
435 |
+
elif request.task_type == "image-segmentation":
|
436 |
+
if request.image_file is None:
|
437 |
+
raise HTTPException(status_code=400, detail="Image file is required for image segmentation.")
|
438 |
+
contents = await request.image_file.read()
|
439 |
+
image = Image.open(BytesIO(contents)).convert("RGB")
|
440 |
+
pipeline = model_data["pipeline"]
|
441 |
+
result = pipeline(image)
|
442 |
+
mask = result[0]['mask']
|
443 |
+
mask_byte_arr = BytesIO()
|
444 |
+
mask.save(mask_byte_arr, format="PNG")
|
445 |
+
mask_byte_arr.seek(0)
|
446 |
+
return StreamingResponse(mask_byte_arr, media_type="image/png")
|
447 |
+
elif request.task_type == "feature-extraction":
|
448 |
+
if request.raw_input is None:
|
449 |
+
raise HTTPException(status_code=400, detail="raw_input is required for feature extraction.")
|
450 |
+
feature_extractor = model_data["feature_extractor"]
|
451 |
+
try:
|
452 |
+
if isinstance(request.raw_input, str):
|
453 |
+
inputs = feature_extractor(text=request.raw_input, return_tensors="pt")
|
454 |
+
elif isinstance(request.raw_input, bytes):
|
455 |
+
image = Image.open(BytesIO(request.raw_input)).convert("RGB")
|
456 |
+
inputs = feature_extractor(images=image, return_tensors="pt")
|
457 |
+
else:
|
458 |
+
raise ValueError("Unsupported raw_input type.")
|
459 |
+
features = inputs.pixel_values # Adjust according to your feature extractor
|
460 |
+
return JSONResponse({"features": features.tolist()})
|
461 |
+
except Exception as fe:
|
462 |
+
raise HTTPException(status_code=400, detail=f"Error during feature extraction: {fe}")
|
463 |
+
elif request.task_type == "token-classification":
|
464 |
+
if request.input_text is None:
|
465 |
+
raise HTTPException(status_code=400, detail="Input text is required for token classification.")
|
466 |
+
model = model_data["model"].to(device)
|
467 |
+
tokenizer = model_data["tokenizer"]
|
468 |
+
inputs = tokenizer(request.input_text, return_tensors="pt", padding=True, truncation=True)
|
469 |
+
with torch.no_grad():
|
470 |
+
outputs = model(**inputs)
|
471 |
+
predictions = outputs.logits.argmax(dim=-1)
|
472 |
+
predicted_labels = [model.config.id2label[label_id] for label_id in predictions[0].tolist()]
|
473 |
+
return JSONResponse({"predicted_labels": predicted_labels})
|
474 |
+
elif request.task_type == "fill-mask":
|
475 |
+
if request.masked_text is None:
|
476 |
+
raise HTTPException(status_code=400, detail="masked_text is required for fill-mask.")
|
477 |
+
model = model_data["model"].to(device)
|
478 |
+
tokenizer = model_data["tokenizer"]
|
479 |
+
inputs = tokenizer(request.masked_text, return_tensors="pt")
|
480 |
+
with torch.no_grad():
|
481 |
+
outputs = model(**inputs)
|
482 |
+
logits = outputs.logits
|
483 |
+
masked_index = torch.where(inputs.input_ids == tokenizer.mask_token_id)[1]
|
484 |
+
predicted_token_id = torch.argmax(logits[0, masked_index])
|
485 |
+
predicted_token = tokenizer.decode(predicted_token_id)
|
486 |
+
return JSONResponse({"predicted_token": predicted_token})
|
487 |
+
elif request.task_type == "image-inpainting":
|
488 |
+
if request.image_file is None or request.mask_image is None:
|
489 |
+
raise HTTPException(status_code=400, detail="image_file and mask_image are required for image inpainting.")
|
490 |
+
image_contents = await request.image_file.read()
|
491 |
+
mask_contents = await request.mask_image.read()
|
492 |
+
image = Image.open(BytesIO(image_contents)).convert("RGB")
|
493 |
+
mask = Image.open(BytesIO(mask_contents)).convert("L") # Assuming mask is grayscale
|
494 |
+
pipeline = model_data["pipeline"]
|
495 |
+
result = pipeline(image, mask)
|
496 |
+
inpainted_image = result[0]
|
497 |
+
img_byte_arr = BytesIO()
|
498 |
+
inpainted_image.save(img_byte_arr, format="PNG")
|
499 |
+
img_byte_arr.seek(0)
|
500 |
+
return StreamingResponse(img_byte_arr, media_type="image/png")
|
501 |
+
elif request.task_type == "image-super-resolution":
|
502 |
+
if request.low_res_image is None:
|
503 |
+
raise HTTPException(status_code=400, detail="low_res_image is required for image super-resolution.")
|
504 |
+
contents = await request.low_res_image.read()
|
505 |
+
image = Image.open(BytesIO(contents)).convert("RGB")
|
506 |
+
pipeline = model_data["pipeline"]
|
507 |
+
result = pipeline(image)
|
508 |
+
upscaled_image = result[0]
|
509 |
+
img_byte_arr = BytesIO()
|
510 |
+
upscaled_image.save(img_byte_arr, format="PNG")
|
511 |
+
img_byte_arr.seek(0)
|
512 |
+
return StreamingResponse(img_byte_arr, media_type="image/png")
|
513 |
+
elif request.task_type == "object-detection":
|
514 |
+
if request.image_file is None:
|
515 |
+
raise HTTPException(status_code=400, detail="Image file is required for object detection.")
|
516 |
+
contents = await request.image_file.read()
|
517 |
+
image = Image.open(BytesIO(contents)).convert("RGB")
|
518 |
+
pipeline = model_data["pipeline"]
|
519 |
+
image_processor = model_data["image_processor"]
|
520 |
+
inputs = image_processor(images=image, return_tensors="pt")
|
521 |
+
with torch.no_grad():
|
522 |
+
outputs = pipeline(image)
|
523 |
+
detections = outputs
|
524 |
+
return JSONResponse({"detections": detections})
|
525 |
+
elif request.task_type == "image-captioning":
|
526 |
+
if request.image_file is None:
|
527 |
+
raise HTTPException(status_code=400, detail="Image file is required for image captioning.")
|
528 |
+
contents = await request.image_file.read()
|
529 |
+
image = Image.open(BytesIO(contents)).convert("RGB")
|
530 |
+
pipeline = model_data["pipeline"]
|
531 |
+
caption = pipeline(image)[0]['generated_text']
|
532 |
+
return JSONResponse({"caption": caption})
|
533 |
+
elif request.task_type == "audio-transcription":
|
534 |
+
if request.audio_file is None:
|
535 |
+
raise HTTPException(status_code=400, detail="Audio file is required for audio transcription.")
|
536 |
+
try:
|
537 |
+
contents = await request.audio_file.read()
|
538 |
+
pipeline = model_data["pipeline"]
|
539 |
+
try:
|
540 |
+
transcription = pipeline(contents, sampling_rate=16000)[0]["text"] # Assuming 16kHz sampling rate
|
541 |
+
return JSONResponse({"transcription": transcription})
|
542 |
+
except Exception as e:
|
543 |
+
raise HTTPException(status_code=500, detail=f"Error during audio transcription (pipeline): {str(e)}")
|
544 |
+
except Exception as e:
|
545 |
+
raise HTTPException(status_code=500, detail=f"Error during audio transcription (file read): {str(e)}")
|
546 |
+
elif request.task_type == "summarization":
|
547 |
+
if request.input_text is None:
|
548 |
+
raise HTTPException(status_code=400, detail="Input text is required for summarization.")
|
549 |
+
model = model_data["model"].to(device)
|
550 |
+
tokenizer = model_data["tokenizer"]
|
551 |
+
inputs = tokenizer(request.input_text, return_tensors="pt", truncation=True, max_length=512) # added max_length for summarization
|
552 |
+
with torch.no_grad():
|
553 |
+
outputs = model.generate(**inputs)
|
554 |
+
summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
555 |
+
return JSONResponse({"summary": summary})
|
556 |
+
|
557 |
+
else:
|
558 |
+
raise HTTPException(status_code=500, detail=f"Unsupported task type")
|
559 |
except Exception as e:
|
560 |
+
logger.exception(f"Internal server error: {str(e)}")
|
561 |
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
562 |
|
563 |
+
|
564 |
+
@app.get("/", response_class=HTMLResponse)
|
565 |
+
async def root(request: Request):
|
566 |
+
return TEMPLATES.TemplateResponse("index.html", {"request": request})
|
567 |
+
|
568 |
+
@app.get("/health")
|
569 |
+
async def health_check():
|
570 |
+
return {"status": "healthy"}
|
571 |
+
|
572 |
+
# Authentication Endpoints
|
573 |
+
|
574 |
+
@app.post("/token", response_model=Token)
|
575 |
+
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
|
576 |
+
user = authenticate_user(form_data.username, form_data.password)
|
577 |
+
if not user:
|
578 |
+
raise HTTPException(
|
579 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
580 |
+
detail="Incorrect username or password",
|
581 |
+
headers={"WWW-Authenticate": "Bearer"},
|
582 |
+
)
|
583 |
+
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
584 |
+
access_token = create_access_token(data={"sub": user["username"]}, expires_delta=access_token_expires)
|
585 |
+
return {"access_token": access_token, "token_type": "bearer"}
|
586 |
+
|
587 |
+
def authenticate_user(username: str, password: str):
|
588 |
+
user = get_user(username)
|
589 |
+
if user and pwd_context.verify(password, user.hashed_password):
|
590 |
+
return {"username": user.username}
|
591 |
+
return None
|
592 |
+
|
593 |
+
def create_access_token(data: Dict[str, Any], expires_delta: timedelta = None):
|
594 |
+
to_encode = data.copy()
|
595 |
+
if expires_delta:
|
596 |
+
expire = datetime.utcnow() + expires_delta
|
597 |
+
else:
|
598 |
+
expire = datetime.utcnow() + timedelta(minutes=15)
|
599 |
+
to_encode.update({"exp": expire})
|
600 |
+
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
601 |
+
return encoded_jwt
|
602 |
+
|
603 |
+
class Token(BaseModel):
|
604 |
+
access_token: str
|
605 |
+
token_type: str
|
606 |
+
|
607 |
+
|
608 |
+
@app.get("/users/me")
|
609 |
+
async def read_users_me(current_user: str = Depends(get_current_user)):
|
610 |
+
return {"username": current_user}
|
611 |
+
|
612 |
+
async def get_current_user(token: str = Depends(oauth2_scheme)):
|
613 |
+
credentials_exception = HTTPException(
|
614 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
615 |
+
detail="Could not validate credentials",
|
616 |
+
headers={"WWW-Authenticate": "Bearer"},
|
617 |
+
)
|
618 |
try:
|
619 |
+
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
620 |
+
username: str = payload.get("sub")
|
621 |
+
if username is None:
|
622 |
+
raise credentials_exception
|
623 |
+
token_data = {"username": username, "token": token}
|
624 |
+
except JWTError:
|
625 |
+
raise credentials_exception
|
626 |
+
user = get_user(username)
|
627 |
+
if user is None:
|
628 |
+
raise credentials_exception
|
629 |
+
return username
|
630 |
+
|
631 |
+
|
632 |
+
@app.post("/register", response_model=User, status_code=status.HTTP_201_CREATED)
|
633 |
+
async def create_user(user: User):
|
634 |
+
try:
|
635 |
+
hashed_password = pwd_context.hash(user.password)
|
636 |
+
new_user = {"username": user.username, "email": user.email, "hashed_password": hashed_password}
|
637 |
+
inserted_user = insert_user(new_user)
|
638 |
+
if inserted_user:
|
639 |
+
return User(**inserted_user)
|
640 |
+
else:
|
641 |
+
raise HTTPException(status_code=500, detail="Failed to create user.")
|
642 |
+
except Exception as e:
|
643 |
+
logger.error(f"Error creating user: {e}")
|
644 |
+
raise HTTPException(status_code=500, detail=f"Error creating user: {e}")
|
645 |
|
|
|
|
|
|
|
646 |
|
647 |
+
@app.put("/users/{username}", response_model=User, dependencies=[Depends(get_current_user)])
|
648 |
+
async def update_user_data(username: str, user: User):
|
649 |
+
try:
|
650 |
+
hashed_password = pwd_context.hash(user.password)
|
651 |
+
updated_user_data = {"email": user.email, "hashed_password": hashed_password}
|
652 |
+
updated_user = update_user(username, updated_user_data)
|
653 |
+
if updated_user:
|
654 |
+
return User(**updated_user)
|
655 |
+
else:
|
656 |
+
raise HTTPException(status_code=404, detail="User not found")
|
657 |
|
658 |
except Exception as e:
|
659 |
+
logger.error(f"Error updating user: {e}")
|
660 |
+
raise HTTPException(status_code=500, detail="Error updating user.")
|
661 |
+
|
662 |
+
|
663 |
+
|
664 |
+
@app.delete("/users/{username}", dependencies=[Depends(get_current_user)])
|
665 |
+
async def delete_user_account(username: str):
|
666 |
+
try:
|
667 |
+
deleted_user = delete_user(username)
|
668 |
+
if deleted_user:
|
669 |
+
return JSONResponse({"message": "User deleted successfully."}, status_code=200)
|
670 |
+
else:
|
671 |
+
raise HTTPException(status_code=404, detail="User not found")
|
672 |
+
except Exception as e:
|
673 |
+
logger.error(f"Error deleting user: {e}")
|
674 |
+
raise HTTPException(status_code=500, detail="Error deleting user.")
|
675 |
+
|
676 |
+
|
677 |
+
@app.get("/users", dependencies=[Depends(get_current_user)])
|
678 |
+
async def get_all_users_route():
|
679 |
+
return get_all_users()
|
680 |
+
|
681 |
+
|
682 |
+
|
683 |
+
@app.exception_handler(RequestValidationError)
|
684 |
+
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
685 |
+
return JSONResponse(
|
686 |
+
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
687 |
+
content=json.dumps({"detail": exc.errors(), "body": exc.body}),
|
688 |
+
)
|
689 |
+
|
690 |
|
|
|
691 |
if __name__ == "__main__":
|
692 |
+
|
693 |
+
create_db_and_table() # Initialize database on startup
|
694 |
+
|
695 |
+
uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True) # replace main with your filename
|
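The commit itself ships no client example; the sketch below is a minimal, hypothetical way to exercise the new /generate endpoint. It assumes the Space is running locally on port 7860, that the API_KEY environment variable is set to "my-secret-key", and that "gpt2" is available as a model name; none of these values come from the diff above.

# Hypothetical client for the new /generate endpoint (not part of the commit).
import requests

response = requests.post(
    "http://localhost:7860/generate",
    headers={"X-API-Key": "my-secret-key"},  # checked by verify_api_key
    json={
        "model_name": "gpt2",          # placeholder model name
        "task_type": "text",
        "input_text": "Hello, world",
        "max_new_tokens": 50,
    },
    stream=True,  # the text task streams one "Token: ..." line per token
)
for line in response.iter_lines():
    if line:
        print(line.decode())

User-management routes would be exercised the same way, except that they expect a Bearer token from /token rather than the X-API-Key header.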