Moonfanz commited on
Commit
0af7a49
·
verified ·
1 Parent(s): f314156

Upload 4 files

Browse files
Files changed (2) hide show
  1. app.py +7 -12
  2. func.py +21 -0
app.py CHANGED
@@ -8,7 +8,6 @@ import func
8
  from apscheduler.schedulers.background import BackgroundScheduler
9
  import requests
10
  import time
11
- from flask_httpauth import HTTPBasicAuth
12
 
13
  os.environ['TZ'] = 'Asia/Shanghai'
14
  app = Flask(__name__)
@@ -25,12 +24,6 @@ handler.setFormatter(formatter)
25
 
26
  logger.addHandler(handler)
27
 
28
- auth = HTTPBasicAuth()
29
-
30
- @auth.verify_password
31
- def verify_password(username, password):
32
- return password == os.environ.get('password')
33
-
34
  safety_settings = [
35
  {
36
  "category": "HARM_CATEGORY_HARASSMENT",
@@ -120,9 +113,11 @@ function copyLink(event) {
120
  return render_template_string(html_template, main_content=main_content)
121
 
122
  @app.route('/hf/v1/chat/completions', methods=['POST'])
123
- @auth.login_required
124
  def chat_completions():
125
  global current_api_key
 
 
 
126
  try:
127
  request_data = request.get_json()
128
  messages = request_data.get('messages', [])
@@ -250,11 +245,11 @@ def chat_completions():
250
  }
251
  }), 500
252
 
253
-
254
-
255
  @app.route('/hf/v1/models', methods=['GET'])
256
- @auth.login_required
257
  def list_models():
 
 
 
258
  response = {"object": "list", "data": GEMINI_MODELS}
259
  return jsonify(response)
260
 
@@ -269,6 +264,6 @@ def keep_alive():
269
  if __name__ == '__main__':
270
  scheduler = BackgroundScheduler()
271
  scheduler.add_job(keep_alive, 'interval', hours = 12)
272
-
273
  scheduler.start()
274
  app.run(debug=True, host='0.0.0.0', port=int(os.environ.get('PORT', 7860)))
 
8
  from apscheduler.schedulers.background import BackgroundScheduler
9
  import requests
10
  import time
 
11
 
12
  os.environ['TZ'] = 'Asia/Shanghai'
13
  app = Flask(__name__)
 
24
 
25
  logger.addHandler(handler)
26
 
 
 
 
 
 
 
27
  safety_settings = [
28
  {
29
  "category": "HARM_CATEGORY_HARASSMENT",
 
113
  return render_template_string(html_template, main_content=main_content)
114
 
115
  @app.route('/hf/v1/chat/completions', methods=['POST'])
 
116
  def chat_completions():
117
  global current_api_key
118
+ is_authenticated, auth_error, status_code = func.authenticate_request(request)
119
+ if not is_authenticated:
120
+ return auth_error if auth_error else jsonify({'error': '未授权'}), status_code if status_code else 401
121
  try:
122
  request_data = request.get_json()
123
  messages = request_data.get('messages', [])
 
245
  }
246
  }), 500
247
 
 
 
248
  @app.route('/hf/v1/models', methods=['GET'])
 
249
  def list_models():
250
+ is_authenticated, auth_error, status_code = func.authenticate_request(request)
251
+ if not is_authenticated:
252
+ return auth_error if auth_error else jsonify({'error': '未授权'}), status_code if status_code else 401
253
  response = {"object": "list", "data": GEMINI_MODELS}
254
  return jsonify(response)
255
 
 
264
  if __name__ == '__main__':
265
  scheduler = BackgroundScheduler()
266
  scheduler.add_job(keep_alive, 'interval', hours = 12)
267
+
268
  scheduler.start()
269
  app.run(debug=True, host='0.0.0.0', port=int(os.environ.get('PORT', 7860)))
func.py CHANGED
@@ -8,6 +8,27 @@ import re
8
  import os
9
  logger = logging.getLogger(__name__)
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  def process_messages_for_gemini(messages):
12
 
13
  gemini_history = []
 
8
  import os
9
  logger = logging.getLogger(__name__)
10
 
11
+ password = os.environ['password']
12
+
13
+ def authenticate_request(request):
14
+ auth_header = request.headers.get('Authorization')
15
+
16
+ if not auth_header:
17
+ return False, jsonify({'error': '缺少Authorization请求头'}), 401
18
+
19
+ try:
20
+ auth_type, pass_word = auth_header.split(' ', 1)
21
+ except ValueError:
22
+ return False, jsonify({'error': 'Authorization请求头格式错误'}), 401
23
+
24
+ if auth_type.lower() != 'bearer':
25
+ return False, jsonify({'error': 'Authorization类型必须为Bearer'}), 401
26
+
27
+ if pass_word != password:
28
+ return False, jsonify({'error': '未授权'}), 401
29
+
30
+ return True, None, None
31
+
32
  def process_messages_for_gemini(messages):
33
 
34
  gemini_history = []