Spaces:
Sleeping
Sleeping
Hjgugugjhuhjggg
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -72,7 +72,8 @@ from fastapi.security import OAuth2PasswordRequestForm, OAuth2PasswordBearer
|
|
72 |
from jose import JWTError, jwt
|
73 |
from passlib.context import CryptContext
|
74 |
from datetime import datetime, timedelta
|
75 |
-
from
|
|
|
76 |
|
77 |
#setting up logging
|
78 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(filename)s - %(lineno)d - %(message)s')
|
@@ -153,15 +154,14 @@ class GenerateRequest(BaseModel):
|
|
153 |
mask_image: Optional[UploadFile] = None # for image inpainting
|
154 |
low_res_image: Optional[UploadFile] = None # for image super-resolution
|
155 |
|
156 |
-
|
157 |
-
@validator("task_type")
|
158 |
def validate_task_type(cls, value):
|
159 |
allowed_types = ["text", "image", "audio", "video", "classification", "translation", "question-answering", "speech-to-text", "text-to-speech", "image-segmentation", "feature-extraction", "token-classification", "fill-mask", "image-inpainting", "image-super-resolution", "object-detection", "image-captioning", "audio-transcription", "summarization"]
|
160 |
if value not in allowed_types:
|
161 |
raise ValueError(f"Invalid task_type. Allowed types are: {allowed_types}")
|
162 |
return value
|
163 |
|
164 |
-
@
|
165 |
def check_input(cls, values):
|
166 |
task_type = values.get("task_type")
|
167 |
if task_type == "text" and values.get("input_text") is None:
|
@@ -182,8 +182,6 @@ class GenerateRequest(BaseModel):
|
|
182 |
raise ValueError("low_res_image is required for image super-resolution.")
|
183 |
return values
|
184 |
|
185 |
-
|
186 |
-
|
187 |
class S3ModelLoader:
|
188 |
def __init__(self, bucket_name, aws_access_key_id, aws_secret_access_key, aws_region):
|
189 |
self.bucket_name = bucket_name
|
@@ -688,4 +686,4 @@ if __name__ == "__main__":
|
|
688 |
|
689 |
create_db_and_table() # Initialize database on startup
|
690 |
|
691 |
-
uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True)
|
|
|
72 |
from jose import JWTError, jwt
|
73 |
from passlib.context import CryptContext
|
74 |
from datetime import datetime, timedelta
|
75 |
+
from pydantic import BaseModel, field_validator, model_validator, Field, EmailStr, constr, ValidationError
|
76 |
+
from typing import Optional, List, Union
|
77 |
|
78 |
#setting up logging
|
79 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(filename)s - %(lineno)d - %(message)s')
|
|
|
154 |
mask_image: Optional[UploadFile] = None # for image inpainting
|
155 |
low_res_image: Optional[UploadFile] = None # for image super-resolution
|
156 |
|
157 |
+
@field_validator('task_type')
|
|
|
158 |
def validate_task_type(cls, value):
|
159 |
allowed_types = ["text", "image", "audio", "video", "classification", "translation", "question-answering", "speech-to-text", "text-to-speech", "image-segmentation", "feature-extraction", "token-classification", "fill-mask", "image-inpainting", "image-super-resolution", "object-detection", "image-captioning", "audio-transcription", "summarization"]
|
160 |
if value not in allowed_types:
|
161 |
raise ValueError(f"Invalid task_type. Allowed types are: {allowed_types}")
|
162 |
return value
|
163 |
|
164 |
+
@model_validator(mode='after')
|
165 |
def check_input(cls, values):
|
166 |
task_type = values.get("task_type")
|
167 |
if task_type == "text" and values.get("input_text") is None:
|
|
|
182 |
raise ValueError("low_res_image is required for image super-resolution.")
|
183 |
return values
|
184 |
|
|
|
|
|
185 |
class S3ModelLoader:
|
186 |
def __init__(self, bucket_name, aws_access_key_id, aws_secret_access_key, aws_region):
|
187 |
self.bucket_name = bucket_name
|
|
|
686 |
|
687 |
create_db_and_table() # Initialize database on startup
|
688 |
|
689 |
+
uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True)
|