# Hugging Face Space status at time of capture: Runtime error
import datetime
import hmac
import json
import os
import random
import secrets
import string
import time
import uuid

import jwt
from flask import Flask, Request, jsonify, request
from redis import Redis
# JWT signing key and Redis connection settings. These used to be hard-coded
# credentials; they can now be supplied through environment variables and fall
# back to the previous values so existing deployments keep working.
# NOTE(review): the fallback secrets are still committed in source; rotate them
# and drop the defaults once every deployment sets the variables.
# NOTE(review): "SECERT_KEY" is misspelled but referenced throughout this
# module, so the name is kept.
SECERT_KEY = os.environ.get('SECRET_KEY', '8U2LL1')
app = Flask(__name__)
redis = Redis(
    host=os.environ.get('REDIS_HOST', '192.168.3.229'),
    port=int(os.environ.get('REDIS_PORT', '6379')),
    password=os.environ.get('REDIS_PASSWORD', 'lizhen-redis'),
)
# NOTE(review): none of the view functions below carry an @app.route
# decorator as shown here; confirm how they are registered with `app`.
# Generate a verification code
def generate_verification_code(length=6):
    """Return a random numeric verification code of ``length`` digits.

    Uses the ``secrets`` CSPRNG instead of ``random``. ``random`` is a
    Mersenne Twister whose future output can be predicted from observed
    output, which is unsuitable for codes that gate account registration.

    Args:
        length: Number of digits in the code (default 6, as before).

    Returns:
        A string of ``length`` decimal digits (may start with ``0``).
    """
    return ''.join(secrets.choice(string.digits) for _ in range(length))
# Deliver a verification code to a user's email (simulated: prints only)
def send_verification_code(email, code):
    """Pretend to email ``code`` to ``email`` by printing a notice to stdout.

    No email is actually sent; this is a stand-in for a real mail backend.
    """
    notice = 'Sending verification code {} to {}...'.format(code, email)
    print(notice)
# Endpoint: issue a verification code for the email in the request body
def send_verification_code_endpoint():
    """Generate a verification code, cache it in Redis for 5 minutes, and "send" it.

    Expects a JSON body with an ``email`` field. The code is stored in Redis
    under the email address itself, which is where ``register`` reads it.

    Returns:
        JSON ``{'code': 0, ...}`` on success, or ``{'code': 400, ...}`` when the
        email is missing or malformed.
    """
    # silent=True: a missing or non-JSON body yields None instead of raising,
    # so it reaches the explicit 400 below rather than crashing.
    data = request.get_json(silent=True) or {}
    email = data.get('email')
    # The email is used directly as a Redis key. Requiring a non-empty string
    # with '@' rejects None (which redis-py refuses) and stops values such as
    # "users" from overwriting unrelated keys in the same database.
    if not isinstance(email, str) or '@' not in email:
        return jsonify({'code': 400, 'message': 'Invalid email'})
    verification_code = generate_verification_code()
    # Persist before sending so a code is never delivered that cannot be checked.
    redis.setex(email, 300, verification_code)
    send_verification_code(email, verification_code)
    return jsonify({'code': 0, 'message': 'Verification code sent'})
# User registration
def register():
    """Create a user account after checking the emailed verification code.

    Expects a JSON body with ``email``, ``username``, ``password`` and
    ``verification_code``. The code must match the one cached in Redis under
    the email address by ``send_verification_code_endpoint``. The user record
    is stored as a JSON string in the ``users`` Redis hash, keyed by username.

    Returns:
        JSON ``{'code': 0, ...}`` on success, otherwise ``{'code': 400, ...}``
        describing the first check that failed.
    """
    # silent=True: a missing or non-JSON body becomes {} instead of raising.
    data = request.get_json(silent=True) or {}
    email = data.get('email')
    username = data.get('username')
    password = data.get('password')
    verification_code = data.get('verification_code')

    # Reject incomplete input before touching Redis. redis-py rejects None
    # keys, and a None password would be stored as JSON null. Requiring '@'
    # also keeps the email from naming unrelated keys such as the 'users' hash.
    fields_ok = all(isinstance(v, str) and v for v in (email, username, password))
    if not fields_ok or '@' not in email:
        return jsonify({'code': 400, 'message': 'Missing or invalid registration fields'})

    stored_code = redis.get(email)
    # str(): a code sent as a JSON number is still compared as text.
    if stored_code is None or str(verification_code) != stored_code.decode('utf-8'):
        return jsonify({'code': 400, 'message': 'Invalid verification code'})

    user_id = str(uuid.uuid4())
    user_data = {
        'user_id': user_id,
        'username': username,
        'email': email,
        'password': password
    }
    # HSETNX checks and writes in one atomic step. A separate hexists + hset
    # could let two concurrent requests both register the same username.
    if not redis.hsetnx('users', username, json.dumps(user_data)):
        return jsonify({'code': 400, 'message': 'Username already exists'})

    # Make the code single-use so it cannot be replayed for another account.
    redis.delete(email)
    # NOTE(review): the password is stored in plain text and login() compares it
    # verbatim; hashing has to be introduced in both places together.
    return jsonify({
        'code': 0,
        'message': 'Registration successful'
    })
# User login
def login():
    """Authenticate a user by username and password and return a JWT.

    Expects a JSON body with ``username`` and ``password``.

    Returns:
        JSON ``{'code': 0, 'data': {'token': ...}}`` on success, or
        ``{'code': 400, ...}`` for an unknown username or a wrong password.
    """
    data = request.get_json(silent=True) or {}
    username = data.get('username')
    password = data.get('password')

    # Skip the lookup for a missing or non-string username (redis-py rejects None).
    if isinstance(username, str) and username:
        user_data = redis.hget('users', username)
    else:
        user_data = None
    if not user_data:
        return jsonify({'code': 400, 'message': 'Invalid username'})

    # register() stores the record with json.dumps(), so parse it with
    # json.loads(). The previous eval() executed whatever text was stored in
    # Redis, and it failed on JSON literals such as null.
    user = json.loads(user_data.decode('utf-8'))
    stored_password = user.get('password')
    # Constant-time comparison avoids leaking how much of the password matched.
    password_ok = (
        isinstance(password, str)
        and isinstance(stored_password, str)
        and hmac.compare_digest(password.encode('utf-8'), stored_password.encode('utf-8'))
    )
    if not password_ok:
        return jsonify({'code': 400, 'message': 'Invalid password'})

    token = generate_token(user['user_id'], username)
    return jsonify({
        'code': 0,
        'message': 'Login successful',
        'data': {
            'token': token
        }
    })
# Example endpoint that requires the caller to be logged in
def protected():
    """Succeed only for requests that carry a valid token.

    Returns:
        JSON ``{'code': 0, 'message': 'Authorized'}`` when the token validates.
        Otherwise JSON ``{'code': 401, 'message': 'Invalid token'}`` with HTTP
        status 200, in line with the other endpoints here, which report
        errors in the body.
    """
    token = parse_token(request)
    if not validate_token(token):
        return jsonify({'code': 401, 'message': 'Invalid token'}), 200
    # Previously the valid-token path fell through to an unconditional
    # 'Unauthorized' response, so even logged-in users were rejected.
    return jsonify({'code': 0, 'message': 'Authorized'})
# Log the user out
def logout():
    """Blacklist the caller's token in Redis as ``<token> -> 'revoked'``.

    Logging out with a token that is already invalid is a no-op. The response
    is the same success message either way.
    NOTE(review): this only takes effect if token validation consults the
    blacklist entry written here.
    """
    token = parse_token(request)
    # The original condition was inverted: it blacklisted only tokens that had
    # already failed validation, and never the live token being logged out.
    if validate_token(token):
        # The signature was just verified. Skip the expiry check so a token
        # that expires in between cannot raise here.
        claims = jwt.decode(token, SECERT_KEY, algorithms=['HS256'],
                            options={'verify_exp': False})
        exp = claims.get('exp')
        if exp is None:
            redis.set(token, 'revoked')
        else:
            # Keep the entry only while the token itself could still be used.
            redis.set(token, 'revoked', ex=max(1, int(exp - time.time())))
    return jsonify({'code': 0, 'message': 'Logout successful'})
# Purchase a paid package
def purchase():
    """Record a package purchase for the authenticated user.

    Expects a JSON body with ``package_id`` and an ``Authorization`` token.
    While the user's current package is still active, purchase is refused in
    two cases: buying the same package again, and buying package '1' while
    holding package '2'.

    Returns:
        JSON ``{'code': 0, ...}`` on success, ``{'code': 401, ...}`` for a bad
        token, or ``{'code': 400, ...}`` for any other rejection.
    """
    data = request.get_json(silent=True) or {}
    package_id = data.get('package_id')
    # Package IDs are strings ('1', '2'). Normalize so a client sending a JSON
    # number is neither rejected nor able to slip past the downgrade check.
    if package_id is not None:
        package_id = str(package_id)

    token = parse_token(request)
    if not validate_token(token):
        return jsonify({'code': 401, 'message': 'Invalid token'})

    package = get_package_by_id(package_id)
    if not package:
        return jsonify({'code': 400, 'message': 'Invalid package ID'})

    user_id = get_user_id_from_token(token)
    if not user_id:
        return jsonify({'code': 400, 'message': 'User not found'})

    # Query the TTL once. Both guards below apply only while a package is active.
    package_active = not is_package_expired(user_id)
    if package_active and has_purchased_package(user_id, package_id):
        return jsonify({'code': 400, 'message': 'Package already purchased'})
    # Holding the advanced package blocks buying the lower-tier basic package.
    if package_active and has_purchased_advanced_package(user_id) and package_id == '1':
        return jsonify({'code': 400, 'message': 'Cannot purchase lower level package'})

    store_user_package(user_id, package)
    return jsonify({'code': 0, 'message': 'Purchase successful'})
# Check whether the user still has chats left in their package
def validate():
    """Report whether the authenticated user has exhausted their chat quota.

    JSON responses (``code`` / ``message``):
        401 'Invalid token'                      - token missing or rejected
        400 'User not found'                     - token yields no user_id
        400 'User has not purchased any package' - no package record in Redis
        400 'Chat limit exceeded'                - a package limit is reached
        0   'Chat limit not exceeded'            - otherwise
    """
    token = parse_token(request)
    if not validate_token(token):
        return jsonify({'code': 401, 'message': 'Invalid token'})

    uid = get_user_id_from_token(token)
    if not uid:
        return jsonify({'code': 400, 'message': 'User not found'})

    user_package = get_user_package(uid)
    if not user_package:
        return jsonify({'code': 400, 'message': 'User has not purchased any package'})

    if exceeded_chat_limit(uid, user_package):
        outcome = {'code': 400, 'message': 'Chat limit exceeded'}
    else:
        outcome = {'code': 0, 'message': 'Chat limit not exceeded'}
    return jsonify(outcome)
def parse_token(request: "Request"):
    """Extract the token from the request's ``Authorization`` header.

    ``Bearer <token>`` yields ``<token>``. The scheme is matched
    case-insensitively (RFC 7235) and surrounding whitespace is ignored.
    A header without the Bearer scheme is returned unchanged, so a bare
    token also works.

    Args:
        request: Any object with a ``headers`` mapping (a Flask ``Request``).

    Returns:
        The token string, or ``None`` when the header is absent.
    """
    header = request.headers.get('Authorization')
    if header is None:
        return None
    scheme, _, credentials = header.strip().partition(' ')
    if scheme.lower() == 'bearer':
        return credentials.strip()
    # No "Bearer" prefix: treat the whole header value as the token.
    return header
# Generate a login token
def generate_token(user_id, username, expires_in=datetime.timedelta(hours=1)):
    """Create an HS256-signed JWT carrying the user's id and name.

    Args:
        user_id: Value stored in the ``user_id`` claim.
        username: Value stored in the ``username`` claim.
        expires_in: Token lifetime from now (default one hour, as before).

    Returns:
        The encoded token as ``str``.
    """
    # Timezone-aware "now"; datetime.utcnow() is deprecated since Python 3.12.
    now = datetime.datetime.now(datetime.timezone.utc)
    payload = {
        'user_id': user_id,
        'username': username,
        'exp': now + expires_in
    }
    token = jwt.encode(payload, SECERT_KEY, algorithm='HS256')
    # PyJWT < 2.0 returns bytes, which jsonify() cannot serialize.
    if isinstance(token, bytes):
        token = token.decode('utf-8')
    return token
# Validate a login token
def validate_token(token):
    """Return True if ``token`` is a correctly signed, unexpired, unrevoked JWT.

    Args:
        token: The raw token string (may be None when no header was sent).

    Returns:
        ``False`` for any of the following, ``True`` otherwise:
        - a missing token;
        - a bad signature or malformed token;
        - an expired token;
        - a token that ``logout()`` has written to Redis with the value 'revoked'.
    """
    if not token:
        return False
    try:
        # jwt.decode verifies the signature and the 'exp' claim itself and raises
        # ExpiredSignatureError (an InvalidTokenError subclass). The old manual
        # check compared utcnow() with local-time fromtimestamp(), which was
        # wrong on non-UTC hosts. The token/payload prints were dropped because
        # they leaked credentials to the logs.
        jwt.decode(token, SECERT_KEY, algorithms=['HS256'])
    except jwt.InvalidTokenError:
        return False
    # Consult the logout blacklist (key = token, value = 'revoked').
    return redis.get(token) != b'revoked'
def get_user_id_from_token(token):
    """Return the ``user_id`` claim of an HS256 JWT, or None if it can't be read.

    Expired, malformed or otherwise invalid tokens all map to ``None``.
    """
    try:
        claims = jwt.decode(token, SECERT_KEY, algorithms=['HS256'])
    except (jwt.ExpiredSignatureError, jwt.DecodeError, jwt.InvalidTokenError):
        return None
    return claims.get('user_id')
# Look up a user's ID by username
def get_user_id_by_username(username):
    """Return the ``user_id`` stored for ``username`` in the ``users`` hash, or None."""
    raw_record = redis.hget('users', username)
    if not raw_record:
        return None
    record = json.loads(raw_record.decode('utf-8'))
    return record.get('user_id')
# Look up a purchasable package by its ID
def get_package_by_id(package_id):
    """Return the package definition for ``package_id`` ('1' or '2'), or None.

    A chat limit of -1 means unlimited. ``expiration`` is the package lifetime
    in seconds (30 days). Fresh dicts are built on every call.
    """
    thirty_days_in_seconds = 30 * 24 * 60 * 60
    # (id, display name, chat limit used for both tiers, price)
    catalogue_rows = (
        ('1', 'Package 1', 10, 10),
        ('2', 'Package 2', -1, 100),
    )
    catalogue = {
        pid: {
            'package_id': pid,
            'name': name,
            'basic_chat_limit': chat_limit,
            'advanced_chat_limit': chat_limit,
            'price': price,
            'expiration': thirty_days_in_seconds,
        }
        for pid, name, chat_limit, price in catalogue_rows
    }
    return catalogue.get(package_id)
# Store the user's package record in Redis
def store_user_package(user_id, package):
    """Write ``package``'s fields to the user's Redis hash and set it to expire.

    The hash ``user:<user_id>:package`` gets the fields ``package_id``, ``name``,
    ``basic_chat_limit`` and ``advanced_chat_limit``. Its expiry is set to now
    plus ``package['expiration']`` seconds.
    """
    user_package_key = f'user:{user_id}:package'
    expiration = int(time.time()) + package['expiration']
    # MULTI/EXEC pipeline: the fields and the expiry are applied together.
    # Separate calls could fail part-way and leave a hash with no TTL, which
    # would then never expire.
    with redis.pipeline() as pipe:
        pipe.hset(user_package_key, 'package_id', package['package_id'])
        pipe.hset(user_package_key, 'name', package['name'])
        pipe.hset(user_package_key, 'basic_chat_limit',
                  package['basic_chat_limit'])
        pipe.hset(user_package_key, 'advanced_chat_limit',
                  package['advanced_chat_limit'])
        pipe.expireat(user_package_key, expiration)
        pipe.execute()
# Fetch the user's package record
def get_user_package(user_id):
    """Return every field of the user's package hash as a dict.

    Keys and values are bytes (the client is created without response
    decoding). An empty dict means no package record exists.
    """
    return redis.hgetall(f'user:{user_id}:package')
# Check whether the user currently holds the given package
def has_purchased_package(user_id, package_id):
    """Return True if the user's stored package ID equals ``package_id``.

    Returns False when the user has no package record. Previously that case
    raised AttributeError on ``None.decode``.
    """
    user_package_key = f'user:{user_id}:package'
    purchased_package_id = redis.hget(user_package_key, 'package_id')
    if purchased_package_id is None:
        return False
    return purchased_package_id.decode('utf-8') == str(package_id)
# Check whether the user currently holds the advanced package ('2')
def has_purchased_advanced_package(user_id):
    """Return True if the user's stored package ID is '2' (the advanced package).

    Returns False when the user has no package record. Previously that case
    raised AttributeError on ``None.decode``.
    """
    user_package_key = f'user:{user_id}:package'
    purchased_package_id = redis.hget(user_package_key, 'package_id')
    if purchased_package_id is None:
        return False
    return purchased_package_id.decode('utf-8') == '2'
# Check whether the user's package has expired
def is_package_expired(user_id):
    """Return True unless the user's package key has a positive remaining TTL.

    Redis TTL reports -2 for a missing key and -1 for a key with no expiry.
    Both count as expired here.
    """
    remaining_seconds = redis.ttl(f'user:{user_id}:package')
    return not remaining_seconds > 0
# Get the remaining validity of the user's package
def get_package_expiration(user_id):
    """Return the Redis TTL (seconds) of the user's package key, unmodified."""
    package_key = f'user:{user_id}:package'
    return redis.ttl(package_key)
# Check whether the user has reached a chat limit
def exceeded_chat_limit(user_id, package):
    """Return True if either the basic or the advanced chat counter has hit its limit.

    Args:
        user_id: Used to build the counter keys ``user:<id>:basic_chat`` and
            ``user:<id>:advanced_chat``.
        package: The raw package hash with bytes keys and values, as returned
            by ``hgetall``.

    A negative limit (-1) means unlimited. A missing limit field counts as 0.
    The basic counter is checked before the advanced one.
    """
    checks = (
        (b'basic_chat_limit', f'user:{user_id}:basic_chat'),
        (b'advanced_chat_limit', f'user:{user_id}:advanced_chat'),
    )
    for limit_field, counter_key in checks:
        # The default must be bytes like the stored values. The previous int
        # default 0 had no .decode() and raised AttributeError.
        limit = int(package.get(limit_field, b'0').decode('utf-8'))
        if limit >= 0 and int(redis.get(counter_key) or 0) >= limit:
            return True
    return False
if __name__ == '__main__':
    # debug=True enables the Werkzeug interactive debugger, which allows
    # arbitrary code execution from the browser. Make it opt-in via FLASK_DEBUG=1.
    app.run(debug=os.environ.get('FLASK_DEBUG') == '1')