Hjgugugjhuhjggg commited on
Commit
757421a
·
verified ·
1 Parent(s): ba80fed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -7
app.py CHANGED
@@ -72,7 +72,8 @@ from fastapi.security import OAuth2PasswordRequestForm, OAuth2PasswordBearer
72
  from jose import JWTError, jwt
73
  from passlib.context import CryptContext
74
  from datetime import datetime, timedelta
75
- from typing import Optional
 
76
 
77
  #setting up logging
78
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(filename)s - %(lineno)d - %(message)s')
@@ -153,15 +154,14 @@ class GenerateRequest(BaseModel):
153
  mask_image: Optional[UploadFile] = None # for image inpainting
154
  low_res_image: Optional[UploadFile] = None # for image super-resolution
155
 
156
-
157
- @validator("task_type")
158
  def validate_task_type(cls, value):
159
  allowed_types = ["text", "image", "audio", "video", "classification", "translation", "question-answering", "speech-to-text", "text-to-speech", "image-segmentation", "feature-extraction", "token-classification", "fill-mask", "image-inpainting", "image-super-resolution", "object-detection", "image-captioning", "audio-transcription", "summarization"]
160
  if value not in allowed_types:
161
  raise ValueError(f"Invalid task_type. Allowed types are: {allowed_types}")
162
  return value
163
 
164
- @root_validator
165
  def check_input(cls, values):
166
  task_type = values.get("task_type")
167
  if task_type == "text" and values.get("input_text") is None:
@@ -182,8 +182,6 @@ class GenerateRequest(BaseModel):
182
  raise ValueError("low_res_image is required for image super-resolution.")
183
  return values
184
 
185
-
186
-
187
  class S3ModelLoader:
188
  def __init__(self, bucket_name, aws_access_key_id, aws_secret_access_key, aws_region):
189
  self.bucket_name = bucket_name
@@ -688,4 +686,4 @@ if __name__ == "__main__":
688
 
689
  create_db_and_table() # Initialize database on startup
690
 
691
- uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True) # replace main with your filename
 
72
  from jose import JWTError, jwt
73
  from passlib.context import CryptContext
74
  from datetime import datetime, timedelta
75
+ from pydantic import BaseModel, field_validator, model_validator, Field, EmailStr, constr, ValidationError
76
+ from typing import Optional, List, Union
77
 
78
  #setting up logging
79
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(filename)s - %(lineno)d - %(message)s')
 
154
  mask_image: Optional[UploadFile] = None # for image inpainting
155
  low_res_image: Optional[UploadFile] = None # for image super-resolution
156
 
157
+ @field_validator('task_type')
 
158
  def validate_task_type(cls, value):
159
  allowed_types = ["text", "image", "audio", "video", "classification", "translation", "question-answering", "speech-to-text", "text-to-speech", "image-segmentation", "feature-extraction", "token-classification", "fill-mask", "image-inpainting", "image-super-resolution", "object-detection", "image-captioning", "audio-transcription", "summarization"]
160
  if value not in allowed_types:
161
  raise ValueError(f"Invalid task_type. Allowed types are: {allowed_types}")
162
  return value
163
 
164
+ @model_validator(mode='after')
165
  def check_input(cls, values):
166
  task_type = values.get("task_type")
167
  if task_type == "text" and values.get("input_text") is None:
 
182
  raise ValueError("low_res_image is required for image super-resolution.")
183
  return values
184
 
 
 
185
  class S3ModelLoader:
186
  def __init__(self, bucket_name, aws_access_key_id, aws_secret_access_key, aws_region):
187
  self.bucket_name = bucket_name
 
686
 
687
  create_db_and_table() # Initialize database on startup
688
 
689
+ uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True)