Spaces:
Sleeping
Sleeping
Hjgugugjhuhjggg
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1 |
import os
|
|
|
|
|
2 |
import torch
|
3 |
-
from fastapi import FastAPI, HTTPException, UploadFile, File, Depends, BackgroundTasks, Request
|
4 |
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse, HTMLResponse
|
5 |
from pydantic import BaseModel, validator, Field, root_validator, EmailStr, constr
|
6 |
from transformers import (
|
@@ -31,7 +33,7 @@ from PIL import Image
|
|
31 |
from typing import Optional, List, Union, Dict, Any
|
32 |
import uuid
|
33 |
import logging
|
34 |
-
import
|
35 |
from passlib.context import CryptContext
|
36 |
from jose import JWTError, jwt
|
37 |
from datetime import datetime, timedelta
|
@@ -41,7 +43,6 @@ from fastapi.middleware.gzip import GZipMiddleware
|
|
41 |
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
42 |
from starlette.middleware.cors import CORSMiddleware
|
43 |
|
44 |
-
|
45 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(filename)s - %(lineno)d - %(message)s')
|
46 |
logger = logging.getLogger(__name__)
|
47 |
|
@@ -53,7 +54,7 @@ ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
|
53 |
|
54 |
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
55 |
|
56 |
-
conn = sqlite3.connect('users.db')
|
57 |
cursor = conn.cursor()
|
58 |
cursor.execute('''
|
59 |
CREATE TABLE IF NOT EXISTS users (
|
@@ -80,7 +81,13 @@ TEMPLATES = Jinja2Templates(directory="templates")
|
|
80 |
app = FastAPI()
|
81 |
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
|
82 |
app.add_middleware(GZipMiddleware)
|
83 |
-
app.add_middleware(
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
|
85 |
|
86 |
class User(BaseModel):
|
@@ -112,14 +119,34 @@ class GenerateRequest(BaseModel):
|
|
112 |
mask_image: Optional[UploadFile] = None
|
113 |
low_res_image: Optional[UploadFile] = None
|
114 |
|
115 |
-
@
|
116 |
def validate_task_type(cls, value):
|
117 |
-
allowed_types = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
if value not in allowed_types:
|
119 |
raise ValueError(f"Invalid task_type. Allowed types are: {allowed_types}")
|
120 |
return value
|
121 |
|
122 |
-
@
|
123 |
def check_input(cls, values):
|
124 |
task_type = values.get("task_type")
|
125 |
if task_type == "text" and values.get("input_text") is None:
|
@@ -287,7 +314,6 @@ async def verify_api_key(api_key: str = Depends(api_key_header)):
|
|
287 |
if api_key != API_KEY:
|
288 |
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API Key")
|
289 |
|
290 |
-
|
291 |
@app.post("/generate", dependencies=[Depends(verify_api_key)])
|
292 |
async def generate(request: GenerateRequest, background_tasks: BackgroundTasks, model_data=Depends(get_model_data)):
|
293 |
try:
|
@@ -373,7 +399,8 @@ async def generate(request: GenerateRequest, background_tasks: BackgroundTasks,
|
|
373 |
transcription = pipeline_func(contents, sampling_rate=16000)[0]["text"]
|
374 |
return JSONResponse({"transcription": transcription})
|
375 |
except Exception as e:
|
376 |
-
|
|
|
377 |
elif request.task_type == "text-to-speech":
|
378 |
if not request.input_text:
|
379 |
raise HTTPException(status_code=400, detail="Input text is required for text-to-speech.")
|
@@ -634,7 +661,11 @@ async def get_all_users_route():
|
|
634 |
|
635 |
@app.exception_handler(RequestValidationError)
|
636 |
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
637 |
-
return JSONResponse(
|
|
|
|
|
|
|
|
|
638 |
|
639 |
if __name__ == "__main__":
|
640 |
-
uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True)
|
|
|
1 |
import os
|
2 |
+
import json
|
3 |
+
import uvicorn
|
4 |
import torch
|
5 |
+
from fastapi import FastAPI, HTTPException, UploadFile, File, Depends, BackgroundTasks, Request, status
|
6 |
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse, HTMLResponse
|
7 |
from pydantic import BaseModel, validator, Field, root_validator, EmailStr, constr
|
8 |
from transformers import (
|
|
|
33 |
from typing import Optional, List, Union, Dict, Any
|
34 |
import uuid
|
35 |
import logging
|
36 |
+
from fastapi.exceptions import RequestValidationError
|
37 |
from passlib.context import CryptContext
|
38 |
from jose import JWTError, jwt
|
39 |
from datetime import datetime, timedelta
|
|
|
43 |
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
44 |
from starlette.middleware.cors import CORSMiddleware
|
45 |
|
|
|
46 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(filename)s - %(lineno)d - %(message)s')
|
47 |
logger = logging.getLogger(__name__)
|
48 |
|
|
|
54 |
|
55 |
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
56 |
|
57 |
+
conn = sqlite3.connect('users.db', check_same_thread=False)
|
58 |
cursor = conn.cursor()
|
59 |
cursor.execute('''
|
60 |
CREATE TABLE IF NOT EXISTS users (
|
|
|
81 |
app = FastAPI()
|
82 |
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
|
83 |
app.add_middleware(GZipMiddleware)
|
84 |
+
app.add_middleware(
|
85 |
+
CORSMiddleware,
|
86 |
+
allow_origins=["*"],
|
87 |
+
allow_credentials=True,
|
88 |
+
allow_methods=["*"],
|
89 |
+
allow_headers=["*"],
|
90 |
+
)
|
91 |
|
92 |
|
93 |
class User(BaseModel):
|
|
|
119 |
mask_image: Optional[UploadFile] = None
|
120 |
low_res_image: Optional[UploadFile] = None
|
121 |
|
122 |
+
@validator('task_type')
|
123 |
def validate_task_type(cls, value):
|
124 |
+
allowed_types = [
|
125 |
+
"text",
|
126 |
+
"image",
|
127 |
+
"audio",
|
128 |
+
"video",
|
129 |
+
"classification",
|
130 |
+
"translation",
|
131 |
+
"question-answering",
|
132 |
+
"speech-to-text",
|
133 |
+
"text-to-speech",
|
134 |
+
"image-segmentation",
|
135 |
+
"feature-extraction",
|
136 |
+
"token-classification",
|
137 |
+
"fill-mask",
|
138 |
+
"image-inpainting",
|
139 |
+
"image-super-resolution",
|
140 |
+
"object-detection",
|
141 |
+
"image-captioning",
|
142 |
+
"audio-transcription",
|
143 |
+
"summarization",
|
144 |
+
]
|
145 |
if value not in allowed_types:
|
146 |
raise ValueError(f"Invalid task_type. Allowed types are: {allowed_types}")
|
147 |
return value
|
148 |
|
149 |
+
@root_validator(pre=True)
|
150 |
def check_input(cls, values):
|
151 |
task_type = values.get("task_type")
|
152 |
if task_type == "text" and values.get("input_text") is None:
|
|
|
314 |
if api_key != API_KEY:
|
315 |
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API Key")
|
316 |
|
|
|
317 |
@app.post("/generate", dependencies=[Depends(verify_api_key)])
|
318 |
async def generate(request: GenerateRequest, background_tasks: BackgroundTasks, model_data=Depends(get_model_data)):
|
319 |
try:
|
|
|
399 |
transcription = pipeline_func(contents, sampling_rate=16000)[0]["text"]
|
400 |
return JSONResponse({"transcription": transcription})
|
401 |
except Exception as e:
|
402 |
+
logger.exception(f"Error during speech-to-text: {e}")
|
403 |
+
raise HTTPException(status_code=500, detail=f"Error during speech-to-text: {str(e)}") from e
|
404 |
elif request.task_type == "text-to-speech":
|
405 |
if not request.input_text:
|
406 |
raise HTTPException(status_code=400, detail="Input text is required for text-to-speech.")
|
|
|
661 |
|
662 |
@app.exception_handler(RequestValidationError)
|
663 |
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
664 |
+
return JSONResponse(
|
665 |
+
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
666 |
+
content=json.dumps({"detail": exc.errors(), "body": exc.body}),
|
667 |
+
)
|
668 |
+
|
669 |
|
670 |
if __name__ == "__main__":
|
671 |
+
uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True)
|