|
|
|
|
|
""" |
|
@File : users.py |
|
@Time : 2024/08/30 |
|
@Project : https://github.com/PeterH0323/Streamer-Sales |
|
@Author : HinGwenWong |
|
@Version : 1.0 |
|
@Desc : 用户登录和 Token 认证接口 |
|
""" |
|
|
|
from datetime import datetime, timedelta, timezone |
|
|
|
import jwt |
|
from fastapi import APIRouter, Depends, HTTPException |
|
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm |
|
from loguru import logger |
|
from passlib.context import CryptContext |
|
|
|
from ...web_configs import WEB_CONFIGS |
|
from ..database.user_db import get_db_user_info |
|
from ..models.user_model import TokenItem |
|
from ..utils import ResultCode, make_return_data |
|
|
|
router = APIRouter( |
|
prefix="/user", |
|
tags=["user"], |
|
responses={404: {"description": "Not found"}}, |
|
) |
|
|
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/user/login") |
|
|
|
|
|
PWD_CONTEXT = CryptContext(schemes=["bcrypt"], deprecated="auto") |
|
|
|
|
|
def verify_password(plain_password: str, hashed_password: str) -> bool: |
|
"""密码校验 |
|
|
|
Args: |
|
plain_password (str): 明文密码 |
|
hashed_password (str): 加密后的密码,用于对比 |
|
|
|
Returns: |
|
bool: 校验时候通过 |
|
""" |
|
logger.info(f"expect password = {PWD_CONTEXT.hash('123456')}") |
|
return PWD_CONTEXT.verify(plain_password, hashed_password) |
|
|
|
|
|
def get_password_hash(password: str) -> str: |
|
"""生成哈希密码 |
|
|
|
Args: |
|
password (str): 明文密码 |
|
|
|
Returns: |
|
str: 加密后的哈希密码 |
|
""" |
|
return PWD_CONTEXT.hash(password) |
|
|
|
|
|
def authenticate_user(username: str, password: str) -> bool: |
|
"""对用户名和密码进行校验 |
|
|
|
Args: |
|
username (str): 用户名 |
|
password (str): 密码 |
|
|
|
Returns: |
|
bool: 是否检验通过 |
|
""" |
|
|
|
|
|
user_info = get_db_user_info(username=username, all_info=True) |
|
if not user_info: |
|
|
|
logger.info(f"Cannot find username = {username}") |
|
return False |
|
|
|
|
|
if not verify_password(password, user_info.hashed_password): |
|
logger.info(f"verify_password fail") |
|
|
|
return False |
|
|
|
return user_info |
|
|
|
|
|
def get_current_user_info(token: str = Depends(oauth2_scheme)): |
|
"""在 token 中提取 user id |
|
|
|
Args: |
|
token (str, optional): token. Defaults to Depends(oauth2_scheme). |
|
|
|
Raises: |
|
HTTPException: 401 获取失败 |
|
|
|
Returns: |
|
int: 用户 ID |
|
""" |
|
logger.info(token) |
|
try: |
|
token_data = jwt.decode(token, WEB_CONFIGS.TOKEN_JWT_SECURITY_KEY, algorithms=WEB_CONFIGS.TOKEN_JWT_ALGORITHM) |
|
logger.info(token_data) |
|
user_id = token_data.get("user_id", None) |
|
except Exception as e: |
|
logger.error(e) |
|
raise HTTPException(status_code=401, detail="Could not validate credentials") |
|
|
|
if not user_id: |
|
logger.error(f"can not get user_id: {user_id}") |
|
raise HTTPException(status_code=401, detail="Could not validate credentials") |
|
|
|
|
|
|
|
logger.info(f"Got user_id: {user_id}") |
|
return user_id |
|
|
|
|
|
@router.post("/login", summary="登录接口") |
|
async def login(form_data: OAuth2PasswordRequestForm = Depends()): |
|
|
|
|
|
user_info = authenticate_user(form_data.username, form_data.password) |
|
|
|
if not user_info: |
|
raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Bearer"}) |
|
|
|
|
|
token_expires = datetime.now(timezone.utc) + timedelta(days=7) |
|
|
|
|
|
token_data = { |
|
"user_id": user_info.user_id, |
|
"username": user_info.username, |
|
"exp": int(token_expires.timestamp()), |
|
"ip": user_info.ip_address, |
|
"login_time": int(datetime.now(timezone.utc).timestamp()), |
|
} |
|
logger.info(f"token_data = {token_data}") |
|
|
|
|
|
token = jwt.encode(token_data, WEB_CONFIGS.TOKEN_JWT_SECURITY_KEY, algorithm=WEB_CONFIGS.TOKEN_JWT_ALGORITHM) |
|
|
|
|
|
res_json = TokenItem(access_token=token, token_type="bearer") |
|
logger.info(f"Got token info = {res_json}") |
|
|
|
return res_json |
|
|
|
|
|
@router.get("/me", summary="获取用户信息") |
|
async def get_streaming_room_api(user_id: int = Depends(get_current_user_info)): |
|
"""获取用户信息""" |
|
user_info = get_db_user_info(id=user_id) |
|
return make_return_data(True, ResultCode.SUCCESS, "成功", user_info) |
|
|