"""Flask service: email-verified registration, session login and chat packages.

All state lives in Redis:

* ``<email>``                       -- pending verification code (5 min TTL)
* ``users`` (hash)                  -- username -> JSON-encoded user record
* ``user:<user_id>:package`` (hash) -- purchased package, expires with it
* ``user:<user_id>:basic_chat`` / ``user:<user_id>:advanced_chat``
                                    -- chat counters; only *read* here
                                       (presumably incremented elsewhere)
"""
import hmac
import json
import os
import random  # noqa: F401  kept for compatibility; codes now use ``secrets``
import secrets
import string
import time
import uuid

from flask import Flask, jsonify, request, session
from redis import Redis

app = Flask(__name__)
# NOTE(review): secrets do not belong in source control. The environment
# variables below override the historical hard-coded defaults.
app.secret_key = os.environ.get('FLASK_SECRET_KEY', '333888')

redis = Redis(
    host=os.environ.get('REDIS_HOST', '192.168.3.229'),
    port=int(os.environ.get('REDIS_PORT', '6379')),
    password=os.environ.get('REDIS_PASSWORD', 'lizhen-redis'),
)

VERIFICATION_CODE_TTL = 300  # seconds a verification code stays valid (5 min)
VERIFICATION_CODE_LENGTH = 6
USERS_KEY = 'users'
BASIC_PACKAGE_ID = '1'
ADVANCED_PACKAGE_ID = '2'
PACKAGE_DURATION = 30 * 24 * 60 * 60  # package lifetime: 30 days, in seconds

# Package catalogue. A chat limit of -1 means unlimited.
PACKAGES = {
    BASIC_PACKAGE_ID: {
        'package_id': BASIC_PACKAGE_ID,
        'name': 'Package 1',
        'basic_chat_limit': 10,
        'advanced_chat_limit': 10,
        'price': 10,
        'expiration': PACKAGE_DURATION,
    },
    ADVANCED_PACKAGE_ID: {
        'package_id': ADVANCED_PACKAGE_ID,
        'name': 'Package 2',
        'basic_chat_limit': -1,
        'advanced_chat_limit': -1,
        'price': 100,
        'expiration': PACKAGE_DURATION,
    },
}


# ---------------------------------------------------------------------------
# Small helpers
# ---------------------------------------------------------------------------

def _json_body():
    """Return the request's JSON object, or ``{}`` if it is absent/not an object."""
    data = request.get_json(silent=True)
    return data if isinstance(data, dict) else {}


def _is_valid_email(email):
    """Return True when *email* is a string containing '@'.

    Verification codes are stored under the raw email as the Redis key, so
    requiring '@' also prevents a caller from overwriting this module's other
    keys (``users``, ``user:<id>:package``), none of which contain '@'.
    """
    return isinstance(email, str) and '@' in email


def _is_nonempty_str(value):
    """Return True when *value* is a non-empty string."""
    return isinstance(value, str) and bool(value)


def _safe_equals(provided, expected):
    """Compare two strings in constant time; anything non-str never matches."""
    if not isinstance(provided, str) or not isinstance(expected, str):
        return False
    return hmac.compare_digest(provided.encode('utf-8'),
                               expected.encode('utf-8'))


def _package_key(user_id):
    """Return the Redis key of the hash holding *user_id*'s package."""
    return f'user:{user_id}:package'


# ---------------------------------------------------------------------------
# Verification codes
# ---------------------------------------------------------------------------

def generate_verification_code(length=VERIFICATION_CODE_LENGTH):
    """Return a random numeric verification code of *length* digits.

    Uses :mod:`secrets` instead of :mod:`random` because the code is a
    credential and must not be predictable.
    """
    return ''.join(secrets.choice(string.digits) for _ in range(length))


def send_verification_code(email, code):
    """Deliver *code* to *email*.

    Simulated: the code is only printed. TODO: wire up real email delivery.
    """
    print(f'Sending verification code {code} to {email}...')


@app.route('/send_verification_code', methods=['POST'])
def send_verification_code_endpoint():
    """Generate a code for the posted ``email``, store it in Redis and send it.

    A new request for the same email overwrites any pending code.
    """
    email = _json_body().get('email')
    if not _is_valid_email(email):
        return jsonify({'message': 'Invalid email'}), 400

    verification_code = generate_verification_code()
    # Store before sending so the user never receives a code that was not saved.
    redis.setex(email, VERIFICATION_CODE_TTL, verification_code)
    send_verification_code(email, verification_code)
    return jsonify({'message': 'Verification code sent'})


# ---------------------------------------------------------------------------
# Accounts and sessions
# ---------------------------------------------------------------------------

@app.route('/register', methods=['POST'])
def register():
    """Create an account after checking the emailed verification code.

    Expects JSON with ``email``, ``username``, ``password`` and
    ``verification_code``. Responds 400 on invalid input, a wrong/expired
    code, or a taken username.
    """
    data = _json_body()
    email = data.get('email')
    username = data.get('username')
    password = data.get('password')
    verification_code = data.get('verification_code')

    if not (_is_valid_email(email) and _is_nonempty_str(username)
            and _is_nonempty_str(password)):
        return jsonify(
            {'message': 'Email, username and password are required'}), 400

    stored_code = redis.get(email)
    if stored_code is None or not _safe_equals(
            verification_code, stored_code.decode('utf-8')):
        return jsonify({'message': 'Invalid verification code'}), 400

    user_data = {
        'user_id': str(uuid.uuid4()),
        'username': username,
        'email': email,
        # NOTE(review): the password is stored in plaintext; it should be
        # hashed (with a migration for existing accounts).
        'password': password,
    }
    # HSETNX performs the "username taken?" check and the insert atomically,
    # so two concurrent registrations cannot both claim the same name.
    if not redis.hsetnx(USERS_KEY, username, json.dumps(user_data)):
        return jsonify({'message': 'Username already exists'}), 400

    # A verification code is single-use.
    redis.delete(email)
    return jsonify({'message': 'Registration successful'})


@app.route('/login', methods=['POST'])
def login():
    """Check ``username``/``password`` and store the username in the session."""
    data = _json_body()
    username = data.get('username')
    password = data.get('password')

    if not _is_nonempty_str(username):
        return jsonify({'message': 'Invalid username'}), 400

    raw_user = redis.hget(USERS_KEY, username)
    if not raw_user:
        return jsonify({'message': 'Invalid username'}), 400

    # Records are written with json.dumps, so read them with json.loads.
    # (The previous eval() executed stored data as Python code.)
    user = json.loads(raw_user.decode('utf-8'))
    if not _safe_equals(password, user.get('password')):
        return jsonify({'message': 'Invalid password'}), 400

    session['username'] = username
    return jsonify({'message': 'Login successful'})


@app.route('/protected', methods=['GET'])
def protected():
    """Example endpoint that requires a logged-in session."""
    username = session.get('username')
    if username is None:
        return jsonify({'message': 'Unauthorized'}), 401
    # Further business logic would go here.
    return jsonify(
        {'message': f'Hello, {username}! This is a protected endpoint.'})


@app.route('/logout', methods=['POST'])
def logout():
    """Remove the username from the session (a no-op if not logged in)."""
    session.pop('username', None)
    return jsonify({'message': 'Logout successful'})


# ---------------------------------------------------------------------------
# Packages and chat limits
# ---------------------------------------------------------------------------

@app.route('/purchase', methods=['POST'])
def purchase():
    """Buy the package named by ``package_id`` for the logged-in user.

    While a package is active, the same package cannot be bought again and
    the advanced package cannot be downgraded to the basic one.
    """
    package = get_package_by_id(_json_body().get('package_id'))
    if not package:
        return jsonify({'message': 'Invalid package ID'}), 400
    package_id = package['package_id']  # canonical string form

    username = session.get('username')
    if not username:
        return jsonify({'message': 'User not logged in'}), 400

    user_id = get_user_id_by_username(username)
    if not user_id:
        return jsonify({'message': 'User not found'}), 400

    if not is_package_expired(user_id):
        if has_purchased_package(user_id, package_id):
            return jsonify({'message': 'Package already purchased'}), 400
        if (package_id == BASIC_PACKAGE_ID
                and has_purchased_advanced_package(user_id)):
            return jsonify(
                {'message': 'Cannot purchase lower level package'}), 400

    store_user_package(user_id, package)
    return jsonify({'message': 'Purchase successful'})


@app.route('/validate', methods=['POST'])
def validate():
    """Report whether the logged-in user is still within their chat limits."""
    username = session.get('username')
    if not username:
        # Previously None reached Redis here, which raises.
        return jsonify({'message': 'User not logged in'}), 400

    user_id = get_user_id_by_username(username)
    if not user_id:
        return jsonify({'message': 'User not found'}), 400

    package = get_user_package(user_id)
    if not package:
        return jsonify({'message': 'User has not purchased any package'}), 400

    if exceeded_chat_limit(user_id, package):
        return jsonify({'message': 'Chat limit exceeded'}), 400
    return jsonify({'message': 'Chat limit not exceeded'})


def get_user_id_by_username(username):
    """Return the ``user_id`` stored for *username*, or None if unknown."""
    if not username:
        return None
    raw_user = redis.hget(USERS_KEY, username)
    if not raw_user:
        return None
    return json.loads(raw_user.decode('utf-8')).get('user_id')


def get_package_by_id(package_id):
    """Return a copy of the catalogue entry for *package_id*, or None.

    Numeric IDs (e.g. ``1`` from JSON) are accepted as well as strings.
    """
    package = PACKAGES.get(str(package_id))
    return dict(package) if package else None


def store_user_package(user_id, package):
    """Save *package* for *user_id* and make it expire after its duration.

    The fields and the expiry are sent in one MULTI/EXEC transaction, so the
    hash is never left without an expiry.
    """
    key = _package_key(user_id)
    expire_at = int(time.time()) + package['expiration']
    with redis.pipeline() as pipe:
        pipe.hset(key, 'package_id', package['package_id'])
        pipe.hset(key, 'name', package['name'])
        pipe.hset(key, 'basic_chat_limit', package['basic_chat_limit'])
        pipe.hset(key, 'advanced_chat_limit', package['advanced_chat_limit'])
        pipe.expireat(key, expire_at)
        pipe.execute()


def get_user_package(user_id):
    """Return the user's package hash (bytes keys/values); empty if none."""
    return redis.hgetall(_package_key(user_id))


def has_purchased_package(user_id, package_id):
    """Return True if the user's stored package ID equals *package_id*."""
    purchased = redis.hget(_package_key(user_id), 'package_id')
    return purchased is not None and purchased.decode('utf-8') == str(package_id)


def has_purchased_advanced_package(user_id):
    """Return True if the user's stored package is the advanced package."""
    return has_purchased_package(user_id, ADVANCED_PACKAGE_ID)


def is_package_expired(user_id):
    """Return True if the user's package key has no positive TTL.

    Redis TTL is negative both when the key is missing and when it has no
    expiry, so either case counts as "expired" here.
    """
    return redis.ttl(_package_key(user_id)) <= 0


def get_package_expiration(user_id):
    """Return the remaining TTL (seconds) of the user's package key."""
    return redis.ttl(_package_key(user_id))


def _limit_reached(limit, counter_key):
    """Return True if a non-negative *limit* is reached by the counter.

    A limit of -1 means unlimited. The counter is read only when needed.
    """
    return limit >= 0 and int(redis.get(counter_key) or 0) >= limit


def exceeded_chat_limit(user_id, package):
    """Return True if the basic or the advanced chat limit is reached.

    *package* is the HGETALL result, so its keys and values are bytes. A
    missing limit is treated as 0.
    """
    basic_limit = int(package.get(b'basic_chat_limit', b'0').decode('utf-8'))
    advanced_limit = int(
        package.get(b'advanced_chat_limit', b'0').decode('utf-8'))
    return (_limit_reached(basic_limit, f'user:{user_id}:basic_chat')
            or _limit_reached(advanced_limit, f'user:{user_id}:advanced_chat'))


if __name__ == '__main__':
    # NOTE(review): debug=True enables the interactive Werkzeug debugger;
    # never expose it outside local development.
    app.run(debug=True)