File size: 4,656 Bytes
1ef9436
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@File    :   users.py
@Time    :   2024/08/30
@Project :   https://github.com/PeterH0323/Streamer-Sales
@Author  :   HinGwenWong
@Version :   1.0
@Desc    :   用户登录和 Token 认证接口
"""

from datetime import datetime, timedelta, timezone

import jwt
from fastapi import APIRouter, Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from loguru import logger
from passlib.context import CryptContext

from ...web_configs import WEB_CONFIGS
from ..database.user_db import get_db_user_info
from ..models.user_model import TokenItem
from ..utils import ResultCode, make_return_data

router = APIRouter(
    prefix="/user",
    tags=["user"],
    responses={404: {"description": "Not found"}},
)

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/user/login")

# 密码加解密
PWD_CONTEXT = CryptContext(schemes=["bcrypt"], deprecated="auto")


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """密码校验

    Args:
        plain_password (str): 明文密码
        hashed_password (str): 加密后的密码,用于对比

    Returns:
        bool: 校验时候通过
    """
    logger.info(f"expect password = {PWD_CONTEXT.hash('123456')}")
    return PWD_CONTEXT.verify(plain_password, hashed_password)


def get_password_hash(password: str) -> str:
    """生成哈希密码

    Args:
        password (str): 明文密码

    Returns:
        str: 加密后的哈希密码
    """
    return PWD_CONTEXT.hash(password)


def authenticate_user(username: str, password: str) -> bool:
    """对用户名和密码进行校验

    Args:
        username (str): 用户名
        password (str): 密码

    Returns:
        bool: 是否检验通过
    """

    # 获取用户信息
    user_info = get_db_user_info(username=username, all_info=True)
    if not user_info:
        # 没有找到用户名
        logger.info(f"Cannot find username = {username}")
        return False

    # 校验密码
    if not verify_password(password, user_info.hashed_password):
        logger.info(f"verify_password fail")
        # 密码校验失败
        return False

    return user_info


def get_current_user_info(token: str = Depends(oauth2_scheme)):
    """在 token 中提取 user id

    Args:
        token (str, optional): token. Defaults to Depends(oauth2_scheme).

    Raises:
        HTTPException: 401 获取失败

    Returns:
        int: 用户 ID
    """
    logger.info(token)
    try:
        token_data = jwt.decode(token, WEB_CONFIGS.TOKEN_JWT_SECURITY_KEY, algorithms=WEB_CONFIGS.TOKEN_JWT_ALGORITHM)
        logger.info(token_data)
        user_id = token_data.get("user_id", None)
    except Exception as e:
        logger.error(e)
        raise HTTPException(status_code=401, detail="Could not validate credentials")

    if not user_id:
        logger.error(f"can not get user_id: {user_id}")
        raise HTTPException(status_code=401, detail="Could not validate credentials")

    # TODO 超时强制重新登录

    logger.info(f"Got user_id: {user_id}")
    return user_id


@router.post("/login", summary="登录接口")
async def login(form_data: OAuth2PasswordRequestForm = Depends()):

    # 校验用户名和密码
    user_info = authenticate_user(form_data.username, form_data.password)

    if not user_info:
        raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Bearer"})

    # 过期时间
    token_expires = datetime.now(timezone.utc) + timedelta(days=7)

    # token 生成包含内容,记录 IP 的原因是防止被其他人拿到用户的 token 进行假冒访问
    token_data = {
        "user_id": user_info.user_id,
        "username": user_info.username,
        "exp": int(token_expires.timestamp()),
        "ip": user_info.ip_address,
        "login_time": int(datetime.now(timezone.utc).timestamp()),
    }
    logger.info(f"token_data = {token_data}")

    # 生成 token
    token = jwt.encode(token_data, WEB_CONFIGS.TOKEN_JWT_SECURITY_KEY, algorithm=WEB_CONFIGS.TOKEN_JWT_ALGORITHM)

    # 返回
    res_json = TokenItem(access_token=token, token_type="bearer")
    logger.info(f"Got token info = {res_json}")
    # return make_return_data(True, ResultCode.SUCCESS, "成功", content)
    return res_json


@router.get("/me", summary="获取用户信息")
async def get_streaming_room_api(user_id: int = Depends(get_current_user_info)):
    """获取用户信息"""
    user_info = get_db_user_info(id=user_id)
    return make_return_data(True, ResultCode.SUCCESS, "成功", user_info)