Spaces:
Runtime error
Runtime error
远兮
commited on
Commit
·
b47ad38
1
Parent(s):
599ee8f
由session改为token
Browse files- redis/test_user_redis.py +106 -29
redis/test_user_redis.py
CHANGED
@@ -3,11 +3,14 @@ import random
|
|
3 |
import string
|
4 |
import uuid
|
5 |
import time
|
6 |
-
|
|
|
|
|
7 |
from redis import Redis
|
8 |
|
|
|
|
|
9 |
app = Flask(__name__)
|
10 |
-
app.secret_key = '333888'
|
11 |
redis = Redis(host='192.168.3.229', port=6379, password='lizhen-redis')
|
12 |
# redis = Redis(host='10.254.13.87', port=6379)
|
13 |
# redis = Redis(host='localhost', port=6379)
|
@@ -15,7 +18,6 @@ redis = Redis(host='192.168.3.229', port=6379, password='lizhen-redis')
|
|
15 |
|
16 |
# 生成验证码
|
17 |
def generate_verification_code():
|
18 |
-
# code = ''.join(random.choices(string.ascii_uppercase + string.digits, k=6))
|
19 |
code = ''.join(random.choices(string.digits, k=6))
|
20 |
return code
|
21 |
|
@@ -60,7 +62,7 @@ def register():
|
|
60 |
# 检查用户名是否已被注册
|
61 |
if redis.hexists('users', username):
|
62 |
return jsonify({'message': 'Username already exists'}), 400
|
63 |
-
|
64 |
# 生成唯一的用户ID
|
65 |
user_id = str(uuid.uuid4())
|
66 |
|
@@ -73,7 +75,10 @@ def register():
|
|
73 |
}
|
74 |
redis.hset('users', username, json.dumps(user_data))
|
75 |
|
76 |
-
return jsonify({
|
|
|
|
|
|
|
77 |
|
78 |
|
79 |
# 用户登录
|
@@ -92,20 +97,24 @@ def login():
|
|
92 |
if password != eval(user_data)['password']:
|
93 |
return jsonify({'code': 400, 'message': 'Invalid password'})
|
94 |
|
95 |
-
#
|
96 |
-
|
97 |
-
|
98 |
-
|
|
|
|
|
|
|
|
|
|
|
99 |
|
100 |
|
101 |
# 需要验证登录状态的接口
|
102 |
@app.route('/protected', methods=['GET'])
|
103 |
def protected():
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
return jsonify({'code': 0, 'message': f'Hello, {username}! This is a protected endpoint.'})
|
109 |
|
110 |
# 如果用户未登录,则返回未授权的响应
|
111 |
return jsonify({'code': 401, 'message': 'Unauthorized'})
|
@@ -114,8 +123,11 @@ def protected():
|
|
114 |
# 用户注销
|
115 |
@app.route('/logout', methods=['POST'])
|
116 |
def logout():
|
117 |
-
|
118 |
-
|
|
|
|
|
|
|
119 |
return jsonify({'code': 0, 'message': 'Logout successful'})
|
120 |
|
121 |
|
@@ -123,18 +135,18 @@ def logout():
|
|
123 |
@app.route('/purchase', methods=['POST'])
|
124 |
def purchase():
|
125 |
package_id = request.json.get('package_id')
|
|
|
|
|
|
|
|
|
|
|
126 |
|
127 |
# 根据套餐ID获取套餐信息
|
128 |
package = get_package_by_id(package_id)
|
129 |
if not package:
|
130 |
return jsonify({'code': 400, 'message': 'Invalid package ID'})
|
131 |
-
|
132 |
-
# 根据用户名查询用户ID
|
133 |
-
username = session.get('username')
|
134 |
-
if not username:
|
135 |
-
return jsonify({'code': 400, 'message': 'User not logged in'})
|
136 |
|
137 |
-
user_id =
|
138 |
if not user_id:
|
139 |
return jsonify({'code': 400, 'message': 'User not found'})
|
140 |
|
@@ -145,7 +157,7 @@ def purchase():
|
|
145 |
# 检查如果用户已经支付了高级套餐,则不能支付比高级套餐更低级的基础套餐
|
146 |
if not is_package_expired(user_id) and has_purchased_advanced_package(user_id) and package_id == '1':
|
147 |
return jsonify({'code': 400, 'message': 'Cannot purchase lower level package'})
|
148 |
-
|
149 |
# 存储用户套餐信息到Redis
|
150 |
store_user_package(user_id, package)
|
151 |
|
@@ -155,13 +167,17 @@ def purchase():
|
|
155 |
# 验证用户聊天次数
|
156 |
@app.route('/validate', methods=['POST'])
|
157 |
def validate():
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
|
|
|
|
|
|
|
|
162 |
if not user_id:
|
163 |
return jsonify({'code': 400, 'message': 'User not found'})
|
164 |
-
|
165 |
# 获取用户套餐信息
|
166 |
package = get_user_package(user_id)
|
167 |
if not package:
|
@@ -174,6 +190,64 @@ def validate():
|
|
174 |
return jsonify({'code': 0, 'message': 'Chat limit not exceeded'})
|
175 |
|
176 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
177 |
# 获取用户ID通过用户名
|
178 |
def get_user_id_by_username(username):
|
179 |
user_data = redis.hget('users', username)
|
@@ -257,12 +331,15 @@ def get_package_expiration(user_id):
|
|
257 |
return expiration
|
258 |
|
259 |
# 检查用户聊天次数是否超过限制
|
|
|
|
|
260 |
def exceeded_chat_limit(user_id, package):
|
261 |
user_basic_chat_key = f'user:{user_id}:basic_chat'
|
262 |
user_advanced_chat_key = f'user:{user_id}:advanced_chat'
|
263 |
|
264 |
basic_chat_limit = int(package.get(b'basic_chat_limit', 0).decode('utf-8'))
|
265 |
-
advanced_chat_limit = int(package.get(
|
|
|
266 |
|
267 |
if basic_chat_limit >= 0 and int(redis.get(user_basic_chat_key) or 0) >= basic_chat_limit:
|
268 |
return True
|
|
|
3 |
import string
|
4 |
import uuid
|
5 |
import time
|
6 |
+
import jwt
|
7 |
+
import datetime
|
8 |
+
from flask import Flask, request, jsonify, Request
|
9 |
from redis import Redis
|
10 |
|
11 |
+
SECERT_KEY = "8U2LL1"
|
12 |
+
|
13 |
app = Flask(__name__)
|
|
|
14 |
redis = Redis(host='192.168.3.229', port=6379, password='lizhen-redis')
|
15 |
# redis = Redis(host='10.254.13.87', port=6379)
|
16 |
# redis = Redis(host='localhost', port=6379)
|
|
|
18 |
|
19 |
# 生成验证码
|
20 |
def generate_verification_code():
|
|
|
21 |
code = ''.join(random.choices(string.digits, k=6))
|
22 |
return code
|
23 |
|
|
|
62 |
# 检查用户名是否已被注册
|
63 |
if redis.hexists('users', username):
|
64 |
return jsonify({'message': 'Username already exists'}), 400
|
65 |
+
|
66 |
# 生成唯一的用户ID
|
67 |
user_id = str(uuid.uuid4())
|
68 |
|
|
|
75 |
}
|
76 |
redis.hset('users', username, json.dumps(user_data))
|
77 |
|
78 |
+
return jsonify({
|
79 |
+
'code': 0,
|
80 |
+
'message': 'Registration successful'
|
81 |
+
})
|
82 |
|
83 |
|
84 |
# 用户登录
|
|
|
97 |
if password != eval(user_data)['password']:
|
98 |
return jsonify({'code': 400, 'message': 'Invalid password'})
|
99 |
|
100 |
+
# 生成令牌
|
101 |
+
token = generate_token(eval(user_data)['user_id'], username)
|
102 |
+
return jsonify({
|
103 |
+
'code': 0,
|
104 |
+
'message': 'Login successful',
|
105 |
+
'data': {
|
106 |
+
'token': token
|
107 |
+
}
|
108 |
+
})
|
109 |
|
110 |
|
111 |
# 需要验证登录状态的接口
|
112 |
@app.route('/protected', methods=['GET'])
|
113 |
def protected():
|
114 |
+
token = parse_token(request)
|
115 |
+
# 验证令牌
|
116 |
+
if not validate_token(token):
|
117 |
+
return jsonify({'code': 401, 'message': 'Invalid token'}), 200
|
|
|
118 |
|
119 |
# 如果用户未登录,则返回未授权的响应
|
120 |
return jsonify({'code': 401, 'message': 'Unauthorized'})
|
|
|
123 |
# 用户注销
|
124 |
@app.route('/logout', methods=['POST'])
|
125 |
def logout():
|
126 |
+
token = parse_token(request)
|
127 |
+
# 验证令牌
|
128 |
+
if not validate_token(token):
|
129 |
+
# 将令牌添加到 Redis 黑名单
|
130 |
+
redis.set(token, 'revoked')
|
131 |
return jsonify({'code': 0, 'message': 'Logout successful'})
|
132 |
|
133 |
|
|
|
135 |
@app.route('/purchase', methods=['POST'])
|
136 |
def purchase():
|
137 |
package_id = request.json.get('package_id')
|
138 |
+
token = parse_token(request)
|
139 |
+
|
140 |
+
# 验证令牌
|
141 |
+
if not validate_token(token):
|
142 |
+
return jsonify({'code': 401, 'message': 'Invalid token'})
|
143 |
|
144 |
# 根据套餐ID获取套餐信息
|
145 |
package = get_package_by_id(package_id)
|
146 |
if not package:
|
147 |
return jsonify({'code': 400, 'message': 'Invalid package ID'})
|
|
|
|
|
|
|
|
|
|
|
148 |
|
149 |
+
user_id = get_user_id_from_token(token)
|
150 |
if not user_id:
|
151 |
return jsonify({'code': 400, 'message': 'User not found'})
|
152 |
|
|
|
157 |
# 检查如果用户已经支付了高级套餐,则不能支付比高级套餐更低级的基础套餐
|
158 |
if not is_package_expired(user_id) and has_purchased_advanced_package(user_id) and package_id == '1':
|
159 |
return jsonify({'code': 400, 'message': 'Cannot purchase lower level package'})
|
160 |
+
|
161 |
# 存储用户套餐信息到Redis
|
162 |
store_user_package(user_id, package)
|
163 |
|
|
|
167 |
# 验证用户聊天次数
|
168 |
@app.route('/validate', methods=['POST'])
|
169 |
def validate():
|
170 |
+
token = parse_token(request)
|
171 |
+
|
172 |
+
# 验证令牌
|
173 |
+
if not validate_token(token):
|
174 |
+
return jsonify({'code': 401, 'message': 'Invalid token'})
|
175 |
+
|
176 |
+
user_id = get_user_id_from_token(token)
|
177 |
+
|
178 |
if not user_id:
|
179 |
return jsonify({'code': 400, 'message': 'User not found'})
|
180 |
+
|
181 |
# 获取用户套餐信息
|
182 |
package = get_user_package(user_id)
|
183 |
if not package:
|
|
|
190 |
return jsonify({'code': 0, 'message': 'Chat limit not exceeded'})
|
191 |
|
192 |
|
193 |
+
def parse_token(request: Request):
|
194 |
+
token_with_bearer = request.headers.get('Authorization')
|
195 |
+
|
196 |
+
if token_with_bearer is not None and token_with_bearer.startswith('Bearer '):
|
197 |
+
token = token_with_bearer.split(' ')[1]
|
198 |
+
else:
|
199 |
+
# 处理未包含 "Bearer" 前缀的情况
|
200 |
+
token = token_with_bearer
|
201 |
+
return token
|
202 |
+
|
203 |
+
|
204 |
+
# 生成令牌
|
205 |
+
def generate_token(user_id, username):
|
206 |
+
# 构造包含用户信息的负载
|
207 |
+
payload = {
|
208 |
+
'user_id': user_id,
|
209 |
+
'username': username,
|
210 |
+
'exp': datetime.datetime.utcnow() + datetime.timedelta(hours=1)
|
211 |
+
}
|
212 |
+
|
213 |
+
# 在这里,您可以使用您的密钥(secret key)来签署令牌
|
214 |
+
# 选择适当的签名算法,并设置适当的过期时间等参数
|
215 |
+
# 仅使用 HS256 算法和过期时间为1小时
|
216 |
+
token = jwt.encode(payload, SECERT_KEY, algorithm='HS256')
|
217 |
+
return token
|
218 |
+
|
219 |
+
|
220 |
+
# 验证令牌
|
221 |
+
def validate_token(token):
|
222 |
+
try:
|
223 |
+
print(token)
|
224 |
+
# 使用密钥进行解码
|
225 |
+
payload = jwt.decode(token, SECERT_KEY, algorithms=['HS256'])
|
226 |
+
print(payload)
|
227 |
+
|
228 |
+
# 检查令牌的过期时间
|
229 |
+
if 'exp' in payload and datetime.datetime.utcnow() > datetime.datetime.fromtimestamp(payload['exp']):
|
230 |
+
return False
|
231 |
+
|
232 |
+
return True
|
233 |
+
except (jwt.DecodeError, jwt.InvalidTokenError):
|
234 |
+
return False
|
235 |
+
|
236 |
+
|
237 |
+
def get_user_id_from_token(token):
|
238 |
+
try:
|
239 |
+
decoded_token = jwt.decode(
|
240 |
+
token, SECERT_KEY, algorithms=['HS256'])
|
241 |
+
user_id = decoded_token.get('user_id')
|
242 |
+
return user_id
|
243 |
+
except jwt.ExpiredSignatureError:
|
244 |
+
# 处理过期的令牌
|
245 |
+
return None
|
246 |
+
except (jwt.DecodeError, jwt.InvalidTokenError):
|
247 |
+
# 处理解码或无效的令牌
|
248 |
+
return None
|
249 |
+
|
250 |
+
|
251 |
# 获取用户ID通过用户名
|
252 |
def get_user_id_by_username(username):
|
253 |
user_data = redis.hget('users', username)
|
|
|
331 |
return expiration
|
332 |
|
333 |
# 检查用户聊天次数是否超过限制
|
334 |
+
|
335 |
+
|
336 |
def exceeded_chat_limit(user_id, package):
|
337 |
user_basic_chat_key = f'user:{user_id}:basic_chat'
|
338 |
user_advanced_chat_key = f'user:{user_id}:advanced_chat'
|
339 |
|
340 |
basic_chat_limit = int(package.get(b'basic_chat_limit', 0).decode('utf-8'))
|
341 |
+
advanced_chat_limit = int(package.get(
|
342 |
+
b'advanced_chat_limit', 0).decode('utf-8'))
|
343 |
|
344 |
if basic_chat_limit >= 0 and int(redis.get(user_basic_chat_key) or 0) >= basic_chat_limit:
|
345 |
return True
|