远兮 commited on
Commit
b47ad38
·
1 Parent(s): 599ee8f

由session改为token

Browse files
Files changed (1) hide show
  1. redis/test_user_redis.py +106 -29
redis/test_user_redis.py CHANGED
@@ -3,11 +3,14 @@ import random
3
  import string
4
  import uuid
5
  import time
6
- from flask import Flask, request, jsonify, session
 
 
7
  from redis import Redis
8
 
 
 
9
  app = Flask(__name__)
10
- app.secret_key = '333888'
11
  redis = Redis(host='192.168.3.229', port=6379, password='lizhen-redis')
12
  # redis = Redis(host='10.254.13.87', port=6379)
13
  # redis = Redis(host='localhost', port=6379)
@@ -15,7 +18,6 @@ redis = Redis(host='192.168.3.229', port=6379, password='lizhen-redis')
15
 
16
  # 生成验证码
17
  def generate_verification_code():
18
- # code = ''.join(random.choices(string.ascii_uppercase + string.digits, k=6))
19
  code = ''.join(random.choices(string.digits, k=6))
20
  return code
21
 
@@ -60,7 +62,7 @@ def register():
60
  # 检查用户名是否已被注册
61
  if redis.hexists('users', username):
62
  return jsonify({'message': 'Username already exists'}), 400
63
-
64
  # 生成唯一的用户ID
65
  user_id = str(uuid.uuid4())
66
 
@@ -73,7 +75,10 @@ def register():
73
  }
74
  redis.hset('users', username, json.dumps(user_data))
75
 
76
- return jsonify({'code': 0, 'message': 'Registration successful'})
 
 
 
77
 
78
 
79
  # 用户登录
@@ -92,20 +97,24 @@ def login():
92
  if password != eval(user_data)['password']:
93
  return jsonify({'code': 400, 'message': 'Invalid password'})
94
 
95
- # 登录验证通过,将用户信息存储到会话中
96
- session['username'] = request.json.get('username')
97
-
98
- return jsonify({'code': 0, 'message': 'Login successful'})
 
 
 
 
 
99
 
100
 
101
  # 需要验证登录状态的接口
102
  @app.route('/protected', methods=['GET'])
103
  def protected():
104
- # 检查会话中的用户信息
105
- if 'username' in session:
106
- username = session['username']
107
- # 其他业务逻辑...
108
- return jsonify({'code': 0, 'message': f'Hello, {username}! This is a protected endpoint.'})
109
 
110
  # 如果用户未登录,则返回未授权的响应
111
  return jsonify({'code': 401, 'message': 'Unauthorized'})
@@ -114,8 +123,11 @@ def protected():
114
  # 用户注销
115
  @app.route('/logout', methods=['POST'])
116
  def logout():
117
- # 清除会话中的用户信息
118
- session.pop('username', None)
 
 
 
119
  return jsonify({'code': 0, 'message': 'Logout successful'})
120
 
121
 
@@ -123,18 +135,18 @@ def logout():
123
  @app.route('/purchase', methods=['POST'])
124
  def purchase():
125
  package_id = request.json.get('package_id')
 
 
 
 
 
126
 
127
  # 根据套餐ID获取套餐信息
128
  package = get_package_by_id(package_id)
129
  if not package:
130
  return jsonify({'code': 400, 'message': 'Invalid package ID'})
131
-
132
- # 根据用户名查询用户ID
133
- username = session.get('username')
134
- if not username:
135
- return jsonify({'code': 400, 'message': 'User not logged in'})
136
 
137
- user_id = get_user_id_by_username(username)
138
  if not user_id:
139
  return jsonify({'code': 400, 'message': 'User not found'})
140
 
@@ -145,7 +157,7 @@ def purchase():
145
  # 检查如果用户已经支付了高级套餐,则不能支付比高级套餐更低级的基础套餐
146
  if not is_package_expired(user_id) and has_purchased_advanced_package(user_id) and package_id == '1':
147
  return jsonify({'code': 400, 'message': 'Cannot purchase lower level package'})
148
-
149
  # 存储用户套餐信息到Redis
150
  store_user_package(user_id, package)
151
 
@@ -155,13 +167,17 @@ def purchase():
155
  # 验证用户聊天次数
156
  @app.route('/validate', methods=['POST'])
157
  def validate():
158
- # 根据用户名查询用户ID
159
- username = session.get('username')
160
- user_id = get_user_id_by_username(username)
161
-
 
 
 
 
162
  if not user_id:
163
  return jsonify({'code': 400, 'message': 'User not found'})
164
-
165
  # 获取用户套餐信息
166
  package = get_user_package(user_id)
167
  if not package:
@@ -174,6 +190,64 @@ def validate():
174
  return jsonify({'code': 0, 'message': 'Chat limit not exceeded'})
175
 
176
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  # 获取用户ID通过用户名
178
  def get_user_id_by_username(username):
179
  user_data = redis.hget('users', username)
@@ -257,12 +331,15 @@ def get_package_expiration(user_id):
257
  return expiration
258
 
259
  # 检查用户聊天次数是否超过限制
 
 
260
  def exceeded_chat_limit(user_id, package):
261
  user_basic_chat_key = f'user:{user_id}:basic_chat'
262
  user_advanced_chat_key = f'user:{user_id}:advanced_chat'
263
 
264
  basic_chat_limit = int(package.get(b'basic_chat_limit', 0).decode('utf-8'))
265
- advanced_chat_limit = int(package.get(b'advanced_chat_limit', 0).decode('utf-8'))
 
266
 
267
  if basic_chat_limit >= 0 and int(redis.get(user_basic_chat_key) or 0) >= basic_chat_limit:
268
  return True
 
3
  import string
4
  import uuid
5
  import time
6
+ import jwt
7
+ import datetime
8
+ from flask import Flask, request, jsonify, Request
9
  from redis import Redis
10
 
11
+ SECERT_KEY = "8U2LL1"
12
+
13
  app = Flask(__name__)
 
14
  redis = Redis(host='192.168.3.229', port=6379, password='lizhen-redis')
15
  # redis = Redis(host='10.254.13.87', port=6379)
16
  # redis = Redis(host='localhost', port=6379)
 
18
 
19
  # 生成验证码
20
  def generate_verification_code():
 
21
  code = ''.join(random.choices(string.digits, k=6))
22
  return code
23
 
 
62
  # 检查用户名是否已被注册
63
  if redis.hexists('users', username):
64
  return jsonify({'message': 'Username already exists'}), 400
65
+
66
  # 生成唯一的用户ID
67
  user_id = str(uuid.uuid4())
68
 
 
75
  }
76
  redis.hset('users', username, json.dumps(user_data))
77
 
78
+ return jsonify({
79
+ 'code': 0,
80
+ 'message': 'Registration successful'
81
+ })
82
 
83
 
84
  # 用户登录
 
97
  if password != eval(user_data)['password']:
98
  return jsonify({'code': 400, 'message': 'Invalid password'})
99
 
100
+ # 生成令牌
101
+ token = generate_token(eval(user_data)['user_id'], username)
102
+ return jsonify({
103
+ 'code': 0,
104
+ 'message': 'Login successful',
105
+ 'data': {
106
+ 'token': token
107
+ }
108
+ })
109
 
110
 
111
  # 需要验证登录状态的接口
112
  @app.route('/protected', methods=['GET'])
113
  def protected():
114
+ token = parse_token(request)
115
+ # 验证令牌
116
+ if not validate_token(token):
117
+ return jsonify({'code': 401, 'message': 'Invalid token'}), 200
 
118
 
119
  # 如果用户未登录,则返回未授权的响应
120
  return jsonify({'code': 401, 'message': 'Unauthorized'})
 
123
  # 用户注销
124
  @app.route('/logout', methods=['POST'])
125
  def logout():
126
+ token = parse_token(request)
127
+ # 验证令牌
128
+ if not validate_token(token):
129
+ # 将令牌添加到 Redis 黑名单
130
+ redis.set(token, 'revoked')
131
  return jsonify({'code': 0, 'message': 'Logout successful'})
132
 
133
 
 
135
  @app.route('/purchase', methods=['POST'])
136
  def purchase():
137
  package_id = request.json.get('package_id')
138
+ token = parse_token(request)
139
+
140
+ # 验证令牌
141
+ if not validate_token(token):
142
+ return jsonify({'code': 401, 'message': 'Invalid token'})
143
 
144
  # 根据套餐ID获取套餐信息
145
  package = get_package_by_id(package_id)
146
  if not package:
147
  return jsonify({'code': 400, 'message': 'Invalid package ID'})
 
 
 
 
 
148
 
149
+ user_id = get_user_id_from_token(token)
150
  if not user_id:
151
  return jsonify({'code': 400, 'message': 'User not found'})
152
 
 
157
  # 检查如果用户已经支付了高级套餐,则不能支付比高级套餐更低级的基础套餐
158
  if not is_package_expired(user_id) and has_purchased_advanced_package(user_id) and package_id == '1':
159
  return jsonify({'code': 400, 'message': 'Cannot purchase lower level package'})
160
+
161
  # 存储用户套餐信息到Redis
162
  store_user_package(user_id, package)
163
 
 
167
  # 验证用户聊天次数
168
  @app.route('/validate', methods=['POST'])
169
  def validate():
170
+ token = parse_token(request)
171
+
172
+ # 验证令牌
173
+ if not validate_token(token):
174
+ return jsonify({'code': 401, 'message': 'Invalid token'})
175
+
176
+ user_id = get_user_id_from_token(token)
177
+
178
  if not user_id:
179
  return jsonify({'code': 400, 'message': 'User not found'})
180
+
181
  # 获取用户套餐信息
182
  package = get_user_package(user_id)
183
  if not package:
 
190
  return jsonify({'code': 0, 'message': 'Chat limit not exceeded'})
191
 
192
 
193
+ def parse_token(request: Request):
194
+ token_with_bearer = request.headers.get('Authorization')
195
+
196
+ if token_with_bearer is not None and token_with_bearer.startswith('Bearer '):
197
+ token = token_with_bearer.split(' ')[1]
198
+ else:
199
+ # 处理未包含 "Bearer" 前缀的情况
200
+ token = token_with_bearer
201
+ return token
202
+
203
+
204
+ # 生成令牌
205
+ def generate_token(user_id, username):
206
+ # 构造包含用户信息的负载
207
+ payload = {
208
+ 'user_id': user_id,
209
+ 'username': username,
210
+ 'exp': datetime.datetime.utcnow() + datetime.timedelta(hours=1)
211
+ }
212
+
213
+ # 在这里,您可以使用您的密钥(secret key)来签署令牌
214
+ # 选择适当的签名算法,并设置适当的过期时间等参数
215
+ # 仅使用 HS256 算法和过期时间为1小时
216
+ token = jwt.encode(payload, SECERT_KEY, algorithm='HS256')
217
+ return token
218
+
219
+
220
+ # 验证令牌
221
+ def validate_token(token):
222
+ try:
223
+ print(token)
224
+ # 使用密钥进行解码
225
+ payload = jwt.decode(token, SECERT_KEY, algorithms=['HS256'])
226
+ print(payload)
227
+
228
+ # 检查令牌的过期时间
229
+ if 'exp' in payload and datetime.datetime.utcnow() > datetime.datetime.fromtimestamp(payload['exp']):
230
+ return False
231
+
232
+ return True
233
+ except (jwt.DecodeError, jwt.InvalidTokenError):
234
+ return False
235
+
236
+
237
+ def get_user_id_from_token(token):
238
+ try:
239
+ decoded_token = jwt.decode(
240
+ token, SECERT_KEY, algorithms=['HS256'])
241
+ user_id = decoded_token.get('user_id')
242
+ return user_id
243
+ except jwt.ExpiredSignatureError:
244
+ # 处理过期的令牌
245
+ return None
246
+ except (jwt.DecodeError, jwt.InvalidTokenError):
247
+ # 处理解码或无效的令牌
248
+ return None
249
+
250
+
251
  # 获取用户ID通过用户名
252
  def get_user_id_by_username(username):
253
  user_data = redis.hget('users', username)
 
331
  return expiration
332
 
333
  # 检查用户聊天次数是否超过限制
334
+
335
+
336
  def exceeded_chat_limit(user_id, package):
337
  user_basic_chat_key = f'user:{user_id}:basic_chat'
338
  user_advanced_chat_key = f'user:{user_id}:advanced_chat'
339
 
340
  basic_chat_limit = int(package.get(b'basic_chat_limit', 0).decode('utf-8'))
341
+ advanced_chat_limit = int(package.get(
342
+ b'advanced_chat_limit', 0).decode('utf-8'))
343
 
344
  if basic_chat_limit >= 0 and int(redis.get(user_basic_chat_key) or 0) >= basic_chat_limit:
345
  return True