Spaces:
Sleeping
Sleeping
Hjgugugjhuhjggg
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -42,6 +42,7 @@ from fastapi.templating import Jinja2Templates
|
|
42 |
from fastapi.middleware.gzip import GZipMiddleware
|
43 |
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
44 |
from starlette.middleware.cors import CORSMiddleware
|
|
|
45 |
|
46 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(filename)s - %(lineno)d - %(message)s')
|
47 |
logger = logging.getLogger(__name__)
|
@@ -54,17 +55,6 @@ ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
|
54 |
|
55 |
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
56 |
|
57 |
-
conn = sqlite3.connect('users.db', check_same_thread=False)
|
58 |
-
cursor = conn.cursor()
|
59 |
-
cursor.execute('''
|
60 |
-
CREATE TABLE IF NOT EXISTS users (
|
61 |
-
username TEXT PRIMARY KEY,
|
62 |
-
email TEXT UNIQUE,
|
63 |
-
hashed_password TEXT
|
64 |
-
)
|
65 |
-
''')
|
66 |
-
conn.commit()
|
67 |
-
|
68 |
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
69 |
API_KEY = os.getenv("API_KEY")
|
70 |
api_key_header = APIKeyHeader(name="X-API-Key")
|
@@ -77,6 +67,7 @@ HUGGINGFACE_HUB_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")
|
|
77 |
TEMP_DIR = "/tmp"
|
78 |
STATIC_DIR = "static"
|
79 |
TEMPLATES = Jinja2Templates(directory="templates")
|
|
|
80 |
|
81 |
app = FastAPI()
|
82 |
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
|
@@ -568,73 +559,56 @@ class Token(BaseModel):
|
|
568 |
access_token: str
|
569 |
token_type: str
|
570 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
571 |
@app.post("/token", response_model=Token)
|
572 |
-
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
|
573 |
-
user = authenticate_user(form_data.username, form_data.password)
|
574 |
if not user:
|
575 |
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect username or password", headers={"WWW-Authenticate": "Bearer"})
|
576 |
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
577 |
access_token = create_access_token(data={"sub": user["username"]}, expires_delta=access_token_expires)
|
578 |
return {"access_token": access_token, "token_type": "bearer"}
|
579 |
|
580 |
-
def authenticate_user(username: str, password: str):
|
581 |
-
cursor.execute("SELECT * FROM users WHERE username = ?", (username,))
|
582 |
-
user = cursor.fetchone()
|
583 |
-
if user and pwd_context.verify(password, user[2]):
|
584 |
-
return {"username": username}
|
585 |
-
return None
|
586 |
-
|
587 |
-
def create_access_token(data: Dict[str, Any], expires_delta: timedelta = None):
|
588 |
-
to_encode = data.copy()
|
589 |
-
if expires_delta:
|
590 |
-
expire = datetime.utcnow() + expires_delta
|
591 |
-
else:
|
592 |
-
expire = datetime.utcnow() + timedelta(minutes=15)
|
593 |
-
to_encode.update({"exp": expire})
|
594 |
-
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
595 |
-
return encoded_jwt
|
596 |
-
|
597 |
|
598 |
-
|
599 |
-
async def read_users_me(current_user: str = Depends(get_current_user)):
|
600 |
-
return {"username": current_user}
|
601 |
-
|
602 |
-
async def get_current_user(token: str = Depends(oauth2_scheme)):
|
603 |
credentials_exception = HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials", headers={"WWW-Authenticate": "Bearer"})
|
604 |
try:
|
605 |
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
606 |
username: str = payload.get("sub")
|
607 |
if username is None:
|
608 |
raise credentials_exception
|
|
|
|
|
|
|
|
|
609 |
except JWTError:
|
610 |
raise credentials_exception
|
611 |
-
cursor.execute("SELECT * FROM users WHERE username = ?", (username,))
|
612 |
-
user = cursor.fetchone()
|
613 |
-
if user is None:
|
614 |
-
raise credentials_exception
|
615 |
-
return username
|
616 |
-
|
617 |
|
618 |
@app.post("/register", response_model=User, status_code=status.HTTP_201_CREATED)
|
619 |
-
async def create_user(user: User):
|
|
|
620 |
try:
|
621 |
-
hashed_password
|
622 |
-
cursor.execute("INSERT INTO users (username, email, hashed_password) VALUES (?, ?, ?)", (user.username, user.email, hashed_password))
|
623 |
-
conn.commit()
|
624 |
return user
|
625 |
-
except
|
626 |
raise HTTPException(status_code=400, detail="Username or email already exists")
|
627 |
-
except Exception as e:
|
628 |
-
logger.error(f"Error creating user: {e}")
|
629 |
-
raise HTTPException(status_code=500, detail=f"Error creating user: {e}")
|
630 |
|
631 |
|
632 |
@app.put("/users/{username}", response_model=User, dependencies=[Depends(get_current_user)])
|
633 |
-
async def update_user_data(username: str, user: User):
|
|
|
634 |
try:
|
635 |
-
hashed_password =
|
636 |
-
cursor.execute("UPDATE users SET email = ?, hashed_password = ? WHERE username = ?", (user.email, hashed_password, username))
|
637 |
-
conn.commit()
|
638 |
return user
|
639 |
except Exception as e:
|
640 |
logger.error(f"Error updating user: {e}")
|
@@ -642,10 +616,9 @@ async def update_user_data(username: str, user: User):
|
|
642 |
|
643 |
|
644 |
@app.delete("/users/{username}", dependencies=[Depends(get_current_user)])
|
645 |
-
async def delete_user_account(username: str):
|
646 |
try:
|
647 |
-
|
648 |
-
conn.commit()
|
649 |
return JSONResponse({"message": "User deleted successfully."}, status_code=200)
|
650 |
except Exception as e:
|
651 |
logger.error(f"Error deleting user: {e}")
|
@@ -653,10 +626,17 @@ async def delete_user_account(username: str):
|
|
653 |
|
654 |
|
655 |
@app.get("/users", dependencies=[Depends(get_current_user)])
|
656 |
-
async def get_all_users_route():
|
657 |
-
|
658 |
-
|
659 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
660 |
|
661 |
|
662 |
@app.exception_handler(RequestValidationError)
|
@@ -667,5 +647,16 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
|
|
667 |
)
|
668 |
|
669 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
670 |
if __name__ == "__main__":
|
671 |
uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True)
|
|
|
42 |
from fastapi.middleware.gzip import GZipMiddleware
|
43 |
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
44 |
from starlette.middleware.cors import CORSMiddleware
|
45 |
+
import asyncpg
|
46 |
|
47 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(filename)s - %(lineno)d - %(message)s')
|
48 |
logger = logging.getLogger(__name__)
|
|
|
55 |
|
56 |
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
57 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
59 |
API_KEY = os.getenv("API_KEY")
|
60 |
api_key_header = APIKeyHeader(name="X-API-Key")
|
|
|
67 |
TEMP_DIR = "/tmp"
|
68 |
STATIC_DIR = "static"
|
69 |
TEMPLATES = Jinja2Templates(directory="templates")
|
70 |
+
DATABASE_URL = os.getenv("DATABASE_URL")
|
71 |
|
72 |
app = FastAPI()
|
73 |
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
|
|
|
559 |
access_token: str
|
560 |
token_type: str
|
561 |
|
562 |
+
async def get_db():
|
563 |
+
async with asyncpg.create_pool(DATABASE_URL) as pool:
|
564 |
+
async with pool.acquire() as conn:
|
565 |
+
yield conn
|
566 |
+
|
567 |
+
async def authenticate_user(username, password, conn):
|
568 |
+
row = await conn.fetchrow("SELECT * FROM users WHERE username = $1", username)
|
569 |
+
if row is not None and pwd_context.verify(password, row["hashed_password"]):
|
570 |
+
return {"username": username}
|
571 |
+
return None
|
572 |
+
|
573 |
@app.post("/token", response_model=Token)
|
574 |
+
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), conn = Depends(get_db)):
|
575 |
+
user = await authenticate_user(form_data.username, form_data.password, conn)
|
576 |
if not user:
|
577 |
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect username or password", headers={"WWW-Authenticate": "Bearer"})
|
578 |
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
579 |
access_token = create_access_token(data={"sub": user["username"]}, expires_delta=access_token_expires)
|
580 |
return {"access_token": access_token, "token_type": "bearer"}
|
581 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
582 |
|
583 |
+
async def get_current_user(token: str = Depends(oauth2_scheme), conn = Depends(get_db)):
|
|
|
|
|
|
|
|
|
584 |
credentials_exception = HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials", headers={"WWW-Authenticate": "Bearer"})
|
585 |
try:
|
586 |
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
587 |
username: str = payload.get("sub")
|
588 |
if username is None:
|
589 |
raise credentials_exception
|
590 |
+
user = await conn.fetchrow("SELECT * FROM users WHERE username = $1", username)
|
591 |
+
if user is None:
|
592 |
+
raise credentials_exception
|
593 |
+
return username
|
594 |
except JWTError:
|
595 |
raise credentials_exception
|
|
|
|
|
|
|
|
|
|
|
|
|
596 |
|
597 |
@app.post("/register", response_model=User, status_code=status.HTTP_201_CREATED)
|
598 |
+
async def create_user(user: User, conn = Depends(get_db)):
|
599 |
+
hashed_password = pwd_context.hash(user.password)
|
600 |
try:
|
601 |
+
await conn.execute("INSERT INTO users (username, email, hashed_password) VALUES ($1, $2, $3)", user.username, user.email, hashed_password)
|
|
|
|
|
602 |
return user
|
603 |
+
except asyncpg.exceptions.UniqueViolationError:
|
604 |
raise HTTPException(status_code=400, detail="Username or email already exists")
|
|
|
|
|
|
|
605 |
|
606 |
|
607 |
@app.put("/users/{username}", response_model=User, dependencies=[Depends(get_current_user)])
|
608 |
+
async def update_user_data(username: str, user: User, conn = Depends(get_db)):
|
609 |
+
hashed_password = pwd_context.hash(user.password)
|
610 |
try:
|
611 |
+
await conn.execute("UPDATE users SET email = $1, hashed_password = $2 WHERE username = $3", user.email, hashed_password, username)
|
|
|
|
|
612 |
return user
|
613 |
except Exception as e:
|
614 |
logger.error(f"Error updating user: {e}")
|
|
|
616 |
|
617 |
|
618 |
@app.delete("/users/{username}", dependencies=[Depends(get_current_user)])
|
619 |
+
async def delete_user_account(username: str, conn = Depends(get_db)):
|
620 |
try:
|
621 |
+
await conn.execute("DELETE FROM users WHERE username = $1", username)
|
|
|
622 |
return JSONResponse({"message": "User deleted successfully."}, status_code=200)
|
623 |
except Exception as e:
|
624 |
logger.error(f"Error deleting user: {e}")
|
|
|
626 |
|
627 |
|
628 |
@app.get("/users", dependencies=[Depends(get_current_user)])
|
629 |
+
async def get_all_users_route(conn = Depends(get_db)):
|
630 |
+
rows = await conn.fetch("SELECT username, email FROM users")
|
631 |
+
return [{"username": row["username"], "email": row["email"]} for row in rows]
|
632 |
+
|
633 |
+
|
634 |
+
@app.get("/users/me", dependencies=[Depends(get_current_user)]) # Requires authentication
|
635 |
+
async def read_users_me(current_user: str = Depends(get_current_user), conn=Depends(get_db)):
|
636 |
+
user = await conn.fetchrow("SELECT username, email FROM users WHERE username = $1", current_user)
|
637 |
+
if user:
|
638 |
+
return {"username": user["username"], "email": user["email"]}
|
639 |
+
raise HTTPException(status_code=404, detail="User not found")
|
640 |
|
641 |
|
642 |
@app.exception_handler(RequestValidationError)
|
|
|
647 |
)
|
648 |
|
649 |
|
650 |
+
def create_access_token(data: Dict[str, Any], expires_delta: timedelta = None):
|
651 |
+
to_encode = data.copy()
|
652 |
+
if expires_delta:
|
653 |
+
expire = datetime.utcnow() + expires_delta
|
654 |
+
else:
|
655 |
+
expire = datetime.utcnow() + timedelta(minutes=15)
|
656 |
+
to_encode.update({"exp": expire})
|
657 |
+
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
658 |
+
return encoded_jwt
|
659 |
+
|
660 |
+
|
661 |
if __name__ == "__main__":
|
662 |
uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True)
|