Hjgugugjhuhjggg commited on
Commit
e31f7ec
·
verified ·
1 Parent(s): 6a7d8ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -12
app.py CHANGED
@@ -1,6 +1,8 @@
1
  import os
 
 
2
  import torch
3
- from fastapi import FastAPI, HTTPException, UploadFile, File, Depends, BackgroundTasks, Request
4
  from fastapi.responses import StreamingResponse, JSONResponse, FileResponse, HTMLResponse
5
  from pydantic import BaseModel, validator, Field, root_validator, EmailStr, constr
6
  from transformers import (
@@ -31,7 +33,7 @@ from PIL import Image
31
  from typing import Optional, List, Union, Dict, Any
32
  import uuid
33
  import logging
34
- import sqlite3
35
  from passlib.context import CryptContext
36
  from jose import JWTError, jwt
37
  from datetime import datetime, timedelta
@@ -41,7 +43,6 @@ from fastapi.middleware.gzip import GZipMiddleware
41
  from fastapi.security import APIKeyHeader, OAuth2PasswordBearer, OAuth2PasswordRequestForm
42
  from starlette.middleware.cors import CORSMiddleware
43
 
44
-
45
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(filename)s - %(lineno)d - %(message)s')
46
  logger = logging.getLogger(__name__)
47
 
@@ -53,7 +54,7 @@ ACCESS_TOKEN_EXPIRE_MINUTES = 30
53
 
54
  pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
55
 
56
- conn = sqlite3.connect('users.db')
57
  cursor = conn.cursor()
58
  cursor.execute('''
59
  CREATE TABLE IF NOT EXISTS users (
@@ -80,7 +81,13 @@ TEMPLATES = Jinja2Templates(directory="templates")
80
  app = FastAPI()
81
  app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
82
  app.add_middleware(GZipMiddleware)
83
- app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"])
 
 
 
 
 
 
84
 
85
 
86
  class User(BaseModel):
@@ -112,14 +119,34 @@ class GenerateRequest(BaseModel):
112
  mask_image: Optional[UploadFile] = None
113
  low_res_image: Optional[UploadFile] = None
114
 
115
- @field_validator('task_type')
116
  def validate_task_type(cls, value):
117
- allowed_types = ["text", "image", "audio", "video", "classification", "translation", "question-answering", "speech-to-text", "text-to-speech", "image-segmentation", "feature-extraction", "token-classification", "fill-mask", "image-inpainting", "image-super-resolution", "object-detection", "image-captioning", "audio-transcription", "summarization"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  if value not in allowed_types:
119
  raise ValueError(f"Invalid task_type. Allowed types are: {allowed_types}")
120
  return value
121
 
122
- @model_validator(mode='after')
123
  def check_input(cls, values):
124
  task_type = values.get("task_type")
125
  if task_type == "text" and values.get("input_text") is None:
@@ -287,7 +314,6 @@ async def verify_api_key(api_key: str = Depends(api_key_header)):
287
  if api_key != API_KEY:
288
  raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API Key")
289
 
290
-
291
  @app.post("/generate", dependencies=[Depends(verify_api_key)])
292
  async def generate(request: GenerateRequest, background_tasks: BackgroundTasks, model_data=Depends(get_model_data)):
293
  try:
@@ -373,7 +399,8 @@ async def generate(request: GenerateRequest, background_tasks: BackgroundTasks,
373
  transcription = pipeline_func(contents, sampling_rate=16000)[0]["text"]
374
  return JSONResponse({"transcription": transcription})
375
  except Exception as e:
376
- raise HTTPException(status_code=500, detail=f"Error during speech-to-text: {str(e)}")
 
377
  elif request.task_type == "text-to-speech":
378
  if not request.input_text:
379
  raise HTTPException(status_code=400, detail="Input text is required for text-to-speech.")
@@ -634,7 +661,11 @@ async def get_all_users_route():
634
 
635
  @app.exception_handler(RequestValidationError)
636
  async def validation_exception_handler(request: Request, exc: RequestValidationError):
637
- return JSONResponse(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, content=json.dumps({"detail": exc.errors(), "body": exc.body}))
 
 
 
 
638
 
639
  if __name__ == "__main__":
640
- uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True)
 
1
  import os
2
+ import json
3
+ import uvicorn
4
  import torch
5
+ from fastapi import FastAPI, HTTPException, UploadFile, File, Depends, BackgroundTasks, Request, status
6
  from fastapi.responses import StreamingResponse, JSONResponse, FileResponse, HTMLResponse
7
  from pydantic import BaseModel, validator, Field, root_validator, EmailStr, constr
8
  from transformers import (
 
33
  from typing import Optional, List, Union, Dict, Any
34
  import uuid
35
  import logging
36
+ from fastapi.exceptions import RequestValidationError
37
  from passlib.context import CryptContext
38
  from jose import JWTError, jwt
39
  from datetime import datetime, timedelta
 
43
  from fastapi.security import APIKeyHeader, OAuth2PasswordBearer, OAuth2PasswordRequestForm
44
  from starlette.middleware.cors import CORSMiddleware
45
 
 
46
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(filename)s - %(lineno)d - %(message)s')
47
  logger = logging.getLogger(__name__)
48
 
 
54
 
55
  pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
56
 
57
+ conn = sqlite3.connect('users.db', check_same_thread=False)
58
  cursor = conn.cursor()
59
  cursor.execute('''
60
  CREATE TABLE IF NOT EXISTS users (
 
81
  app = FastAPI()
82
  app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
83
  app.add_middleware(GZipMiddleware)
84
+ app.add_middleware(
85
+ CORSMiddleware,
86
+ allow_origins=["*"],
87
+ allow_credentials=True,
88
+ allow_methods=["*"],
89
+ allow_headers=["*"],
90
+ )
91
 
92
 
93
  class User(BaseModel):
 
119
  mask_image: Optional[UploadFile] = None
120
  low_res_image: Optional[UploadFile] = None
121
 
122
+ @validator('task_type')
123
  def validate_task_type(cls, value):
124
+ allowed_types = [
125
+ "text",
126
+ "image",
127
+ "audio",
128
+ "video",
129
+ "classification",
130
+ "translation",
131
+ "question-answering",
132
+ "speech-to-text",
133
+ "text-to-speech",
134
+ "image-segmentation",
135
+ "feature-extraction",
136
+ "token-classification",
137
+ "fill-mask",
138
+ "image-inpainting",
139
+ "image-super-resolution",
140
+ "object-detection",
141
+ "image-captioning",
142
+ "audio-transcription",
143
+ "summarization",
144
+ ]
145
  if value not in allowed_types:
146
  raise ValueError(f"Invalid task_type. Allowed types are: {allowed_types}")
147
  return value
148
 
149
+ @root_validator(pre=True)
150
  def check_input(cls, values):
151
  task_type = values.get("task_type")
152
  if task_type == "text" and values.get("input_text") is None:
 
314
  if api_key != API_KEY:
315
  raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API Key")
316
 
 
317
  @app.post("/generate", dependencies=[Depends(verify_api_key)])
318
  async def generate(request: GenerateRequest, background_tasks: BackgroundTasks, model_data=Depends(get_model_data)):
319
  try:
 
399
  transcription = pipeline_func(contents, sampling_rate=16000)[0]["text"]
400
  return JSONResponse({"transcription": transcription})
401
  except Exception as e:
402
+ logger.exception(f"Error during speech-to-text: {e}")
403
+ raise HTTPException(status_code=500, detail=f"Error during speech-to-text: {str(e)}") from e
404
  elif request.task_type == "text-to-speech":
405
  if not request.input_text:
406
  raise HTTPException(status_code=400, detail="Input text is required for text-to-speech.")
 
661
 
662
  @app.exception_handler(RequestValidationError)
663
  async def validation_exception_handler(request: Request, exc: RequestValidationError):
664
+ return JSONResponse(
665
+ status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
666
+ content=json.dumps({"detail": exc.errors(), "body": exc.body}),
667
+ )
668
+
669
 
670
  if __name__ == "__main__":
671
+ uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True)