Moonfanz commited on
Commit
39bdd55
·
verified ·
1 Parent(s): fcbb4e0

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +27 -130
  2. func.py +152 -0
app.py CHANGED
@@ -4,12 +4,10 @@ import google.generativeai as genai
4
  import json
5
  from datetime import datetime
6
  import os
7
- import base64
8
  from termcolor import colored
9
  import logging
10
- from PIL import Image
11
- from io import BytesIO
12
- import re
13
  os.environ['TZ'] = 'Asia/Shanghai'
14
  app = Flask(__name__)
15
  if 'API_KEYS' not in os.environ:
@@ -63,78 +61,31 @@ current_api_key = key_manager.get_available_key()
63
  logger.info(f"Current API key: {current_api_key}")
64
  # 模型列表
65
  GEMINI_MODELS = [
66
- {"id": "gemini-pro", "object": "model", "created": 1700000000, "owned_by": "google"},
67
- {"id": "gemini-pro-vision", "object": "model", "created": 1700000000, "owned_by": "google"},
68
- {"id": "gemini-1.0-pro", "object": "model", "created": 1700000000, "owned_by": "google"},
69
- {"id": "gemini-1.0-pro-vision", "object": "model", "created": 1700000000, "owned_by": "google"},
70
- {"id": "gemini-1.5-pro-002", "object": "model", "created": 1700000000, "owned_by": "google"},
71
- {"id": "gemini-exp-1114", "object": "model", "created": 1700000000, "owned_by": "google"},
72
- {"id": "gemini-exp-1121", "object": "model", "created": 1700000000, "owned_by": "google"},
73
- {"id": "gemini-exp-1206", "object": "model", "created": 1700000000, "owned_by": "google"},
74
- {"id": "gemini-2.0-flash-exp", "object": "model", "created": 1700000000, "owned_by": "google"},
75
- {"id": "gemini-2.0-exp", "object": "model", "created": 1700000000, "owned_by": "google"},
76
- {"id": "gemini-2.0-pro-exp", "object": "model", "created": 1700000000, "owned_by": "google"},
77
  ]
78
- def authenticate_request(request):
79
- auth_header = request.headers.get('Authorization')
80
- hf_api_key = os.environ.get('HF_API_KEY').split(',')
81
-
82
- if not auth_header:
83
- return False, jsonify({'error': 'Authorization header is missing'}), 401
84
-
85
- try:
86
- auth_type, api_key = auth_header.split(' ', 1)
87
- except ValueError:
88
- return False, jsonify({'error': 'Invalid Authorization header format'}), 401
89
-
90
- if auth_type.lower() != 'bearer':
91
- return False, jsonify({'error': 'Authorization type must be Bearer'}), 401
92
-
93
- if api_key not in hf_api_key:
94
- return False, jsonify({'error': 'Unauthorized'}), 401
95
-
96
- return True, None, None
97
-
98
- def sanitize_request_data(request_data):
99
- """
100
- 从请求数据中删除base64编码的数据。
101
-
102
- Args:
103
- request_data: 包含可能存在base64数据的字典。
104
-
105
- Returns:
106
- 清理后的字典,其中base64数据被替换为"[Base64 Data Omitted]"。
107
- """
108
-
109
-
110
- def replace_base64(match):
111
- # 替换base64数据为提示信息
112
- return '"[Base64 Data Omitted]"'
113
-
114
-
115
- request_data_str = json.dumps(request_data)
116
-
117
- # 使用正则表达式匹配base64数据,并替换为提示信息
118
- sanitized_request_data_str = re.sub(
119
- r'"(data:[^;]+;base64,)[^"]+"',
120
- replace_base64,
121
- request_data_str
122
- )
123
-
124
- return json.loads(sanitized_request_data_str)
125
 
126
  @app.route('/hf/v1/chat/completions', methods=['POST'])
127
  def chat_completions():
128
  global current_api_key
129
- is_authenticated, auth_error, status_code = authenticate_request(request)
130
  if not is_authenticated:
131
  return auth_error if auth_error else jsonify({'error': 'Unauthorized'}), status_code if status_code else 401
132
  try:
133
  request_data = request.get_json()
134
- r_data = sanitize_request_data(request_data)
135
  logger.info(json.dumps(r_data, indent=4, ensure_ascii=False), extra={"color": "green"})
136
  messages = request_data.get('messages', [])
137
- model = request_data.get('model', 'gemini-exp-1206') # 默认改为 gemini-pro-vision 用于处理图文
138
  temperature = request_data.get('temperature', 1)
139
  max_tokens = request_data.get('max_tokens', 8192)
140
  stream = request_data.get('stream', False)
@@ -142,67 +93,12 @@ def chat_completions():
142
  logger.info(colored(f"\n{model} [r] -> {current_api_key[:11]}...", 'yellow'))
143
 
144
  # 将 OpenAI 格式的消息转换为 Gemini 格式
145
- gemini_history = []
146
- for message in messages:
147
- role = message.get('role')
148
- content = message.get('content')
149
-
150
- if isinstance(content, str): # 纯文本
151
- if role == 'system':
152
- gemini_history.append({"role": "user", "parts": [content]})
153
- elif role == 'user':
154
- gemini_history.append({"role": "user", "parts": [content]})
155
- elif role == 'assistant':
156
- gemini_history.append({"role": "model", "parts": [content]})
157
- elif isinstance(content, list): # 图文
158
- parts = []
159
- for item in content:
160
- if item.get('type') == 'text':
161
- parts.append(item.get('text'))
162
- elif item.get('type') == 'image_url':
163
- image_data = item.get('image_url', {}).get('url', '')
164
- if image_data.startswith('data:image/'): # 修改判断条件
165
- try:
166
- # 提取 base64 编码和图片类型
167
- image_type = image_data.split(';')[0].split('/')[1].upper() # 提取图片类型并转为大写
168
- base64_image = image_data.split(';base64,')[1]
169
-
170
- image = Image.open(BytesIO(base64.b64decode(base64_image)))
171
-
172
- # 将图片转换为 RGB 模式
173
- if image.mode != 'RGB':
174
- image = image.convert('RGB')
175
-
176
- # 压缩图像
177
- if image.width > 2048 or image.height > 2048:
178
- image.thumbnail((2048, 2048))
179
-
180
- output_buffer = BytesIO()
181
- image.save(output_buffer, format=image_type) # 使用原始图片类型保存
182
- output_buffer.seek(0)
183
- parts.append(image) # 直接添加 image 对象
184
- except Exception as e:
185
- logger.error(f"Error processing image: {e}")
186
- return jsonify({'error': 'Invalid image data'}), 400
187
- else:
188
- return jsonify({'error': 'Invalid image URL format'}), 400
189
-
190
-
191
- # 根据 role 添加到 gemini_history
192
- if role in ['user', 'system']:
193
- gemini_history.append({"role": "user", "parts": parts}) # 保持 parts 的原有顺序
194
- elif role == 'assistant':
195
- gemini_history.append({"role": "model", "parts": parts})
196
- else:
197
- return jsonify({'error': f'Invalid role: {role}'}), 400
198
-
199
- # 用户最后一条消息
200
- if gemini_history:
201
- user_message = gemini_history[-1]
202
- gemini_history = gemini_history[:-1] # 历史记录不包含最后一条消息
203
- else:
204
- user_message = {"role": "user", "parts": [""]}
205
-
206
  genai.configure(api_key=current_api_key)
207
 
208
  generation_config = {
@@ -305,10 +201,10 @@ def chat_completions():
305
  }
306
  logger.info(colored(f"Generation Success", 'green'))
307
  return jsonify(response_data)
308
-
309
  except Exception as e:
310
  logger.error(f"Error in chat completions: {str(e)}")
311
-
312
  return jsonify({
313
  'error': {
314
  'message': str(e),
@@ -319,10 +215,11 @@ def chat_completions():
319
  current_api_key = key_manager.get_available_key()
320
  logger.info(colored(f"API KEY Switched -> {current_api_key[:11]}...", 'aqua'))
321
 
322
-
323
-
324
  @app.route('/hf/v1/models', methods=['GET'])
325
  def list_models():
 
 
 
326
  response = {"object": "list", "data": GEMINI_MODELS}
327
  return jsonify(response)
328
 
 
4
  import json
5
  from datetime import datetime
6
  import os
 
7
  from termcolor import colored
8
  import logging
9
+ import func
10
+
 
11
  os.environ['TZ'] = 'Asia/Shanghai'
12
  app = Flask(__name__)
13
  if 'API_KEYS' not in os.environ:
 
61
  logger.info(f"Current API key: {current_api_key}")
62
  # 模型列表
63
  GEMINI_MODELS = [
64
+ {"id": "gemini-pro"},
65
+ {"id": "gemini-pro-vision"},
66
+ {"id": "gemini-1.0-pro"},
67
+ {"id": "gemini-1.0-pro-vision"},
68
+ {"id": "gemini-1.5-pro-002"},
69
+ {"id": "gemini-exp-1114"},
70
+ {"id": "gemini-exp-1121"},
71
+ {"id": "gemini-exp-1206"},
72
+ {"id": "gemini-2.0-flash-exp"},
73
+ {"id": "gemini-2.0-exp"},
74
+ {"id": "gemini-2.0-pro-exp"},
75
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  @app.route('/hf/v1/chat/completions', methods=['POST'])
78
  def chat_completions():
79
  global current_api_key
80
+ is_authenticated, auth_error, status_code = func.authenticate_request(request)
81
  if not is_authenticated:
82
  return auth_error if auth_error else jsonify({'error': 'Unauthorized'}), status_code if status_code else 401
83
  try:
84
  request_data = request.get_json()
85
+ r_data = func.sanitize_request_data(request_data)
86
  logger.info(json.dumps(r_data, indent=4, ensure_ascii=False), extra={"color": "green"})
87
  messages = request_data.get('messages', [])
88
+ model = request_data.get('model', 'gemini-exp-1206')
89
  temperature = request_data.get('temperature', 1)
90
  max_tokens = request_data.get('max_tokens', 8192)
91
  stream = request_data.get('stream', False)
 
93
  logger.info(colored(f"\n{model} [r] -> {current_api_key[:11]}...", 'yellow'))
94
 
95
  # 将 OpenAI 格式的消息转换为 Gemini 格式
96
+ gemini_history, user_message, error_response = func.process_messages_for_gemini(messages)
97
+
98
+ if error_response:
99
+ # 处理错误
100
+ print(error_response)
101
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  genai.configure(api_key=current_api_key)
103
 
104
  generation_config = {
 
201
  }
202
  logger.info(colored(f"Generation Success", 'green'))
203
  return jsonify(response_data)
204
+
205
  except Exception as e:
206
  logger.error(f"Error in chat completions: {str(e)}")
207
+
208
  return jsonify({
209
  'error': {
210
  'message': str(e),
 
215
  current_api_key = key_manager.get_available_key()
216
  logger.info(colored(f"API KEY Switched -> {current_api_key[:11]}...", 'aqua'))
217
 
 
 
218
  @app.route('/hf/v1/models', methods=['GET'])
219
  def list_models():
220
+ is_authenticated, auth_error, status_code = func.authenticate_request(request)
221
+ if not is_authenticated:
222
+ return auth_error if auth_error else jsonify({'error': 'Unauthorized'}), status_code if status_code else 401
223
  response = {"object": "list", "data": GEMINI_MODELS}
224
  return jsonify(response)
225
 
func.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from io import BytesIO
2
+ import base64
3
+ from PIL import Image
4
+ from flask import jsonify
5
+ import logging
6
+ import json
7
+ import re
8
+ import os
9
+ GEMINI_MODELS = [
10
+ {"id": "gemini-pro", "object": "model", "created": 1700000000, "owned_by": "google"},
11
+ {"id": "gemini-pro-vision", "object": "model", "created": 1700000000, "owned_by": "google"},
12
+ {"id": "gemini-1.0-pro", "object": "model", "created": 1700000000, "owned_by": "google"},
13
+ {"id": "gemini-1.0-pro-vision", "object": "model", "created": 1700000000, "owned_by": "google"},
14
+ {"id": "gemini-1.5-pro-002", "object": "model", "created": 1700000000, "owned_by": "google"},
15
+ {"id": "gemini-exp-1114", "object": "model", "created": 1700000000, "owned_by": "google"},
16
+ {"id": "gemini-exp-1121", "object": "model", "created": 1700000000, "owned_by": "google"},
17
+ {"id": "gemini-exp-1206", "object": "model", "created": 1700000000, "owned_by": "google"},
18
+ {"id": "gemini-2.0-flash-exp", "object": "model", "created": 1700000000, "owned_by": "google"},
19
+ {"id": "gemini-2.0-exp", "object": "model", "created": 1700000000, "owned_by": "google"},
20
+ {"id": "gemini-2.0-pro-exp", "object": "model", "created": 1700000000, "owned_by": "google"},
21
+ ]
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+ def authenticate_request(request):
26
+ auth_header = request.headers.get('Authorization')
27
+ hf_api_key = os.environ.get('HF_API_KEY').split(',')
28
+
29
+ if not auth_header:
30
+ return False, jsonify({'error': 'Authorization header is missing'}), 401
31
+
32
+ try:
33
+ auth_type, api_key = auth_header.split(' ', 1)
34
+ except ValueError:
35
+ return False, jsonify({'error': 'Invalid Authorization header format'}), 401
36
+
37
+ if auth_type.lower() != 'bearer':
38
+ return False, jsonify({'error': 'Authorization type must be Bearer'}), 401
39
+
40
+ if api_key not in hf_api_key:
41
+ return False, jsonify({'error': 'Unauthorized'}), 401
42
+
43
+ return True, None, None
44
+
45
+ def sanitize_request_data(request_data):
46
+ """
47
+ 从请求数据中删除base64编码的数据。
48
+
49
+ Args:
50
+ request_data: 包含可能存在base64数据的字典。
51
+
52
+ Returns:
53
+ 清理后的字典,其中base64数据被替换为"[Base64 Data Omitted]"。
54
+ """
55
+
56
+
57
+ def replace_base64(match):
58
+ # 替换base64数据为提示信息
59
+ return '"[Base64 Data Omitted]"'
60
+
61
+
62
+ request_data_str = json.dumps(request_data)
63
+
64
+ # 使用正则表达式匹配base64数据,并替换为提示信息
65
+ sanitized_request_data_str = re.sub(
66
+ r'"(data:[^;]+;base64,)[^"]+"',
67
+ replace_base64,
68
+ request_data_str
69
+ )
70
+
71
+ return json.loads(sanitized_request_data_str)
72
+
73
+ def process_messages_for_gemini(messages):
74
+ """
75
+ Processes a list of messages to construct the history format expected by Gemini,
76
+ handling both text and image content.
77
+
78
+ Args:
79
+ messages: A list of message dictionaries. Each dictionary should have a 'role' key
80
+ ('system', 'user', or 'assistant') and a 'content' key.
81
+ The 'content' can be either a string (for text messages) or a list
82
+ (for multi-modal messages).
83
+ For multi-modal messages, the list should contain dictionaries with 'type'
84
+ ('text' or 'image_url') and corresponding 'text' or 'image_url' keys.
85
+
86
+ Returns:
87
+ A tuple containing:
88
+ - gemini_history: A list of dictionaries formatted for the Gemini API, excluding the last user message.
89
+ - user_message: The last user message from the input, or a default empty user message if the input is empty.
90
+ - error_response: A Flask error response (jsonify, status_code) if an error occurred, otherwise None.
91
+ """
92
+ gemini_history = []
93
+ for message in messages:
94
+ role = message.get('role')
95
+ content = message.get('content')
96
+
97
+ if isinstance(content, str): # 纯文本
98
+ if role == 'system':
99
+ gemini_history.append({"role": "user", "parts": [content]})
100
+ elif role == 'user':
101
+ gemini_history.append({"role": "user", "parts": [content]})
102
+ elif role == 'assistant':
103
+ gemini_history.append({"role": "model", "parts": [content]})
104
+ elif isinstance(content, list): # 图文
105
+ parts = []
106
+ for item in content:
107
+ if item.get('type') == 'text':
108
+ parts.append(item.get('text'))
109
+ elif item.get('type') == 'image_url':
110
+ image_data = item.get('image_url', {}).get('url', '')
111
+ if image_data.startswith('data:image/'):
112
+ try:
113
+ # 提取 base64 编码和图片类型
114
+ image_type = image_data.split(';')[0].split('/')[1].upper() # 提取图片类型并转为大写
115
+ base64_image = image_data.split(';base64,')[1]
116
+
117
+ image = Image.open(BytesIO(base64.b64decode(base64_image)))
118
+
119
+ # 将图片转换为 RGB 模式
120
+ if image.mode != 'RGB':
121
+ image = image.convert('RGB')
122
+
123
+ # 压缩图像
124
+ if image.width > 2048 or image.height > 2048:
125
+ image.thumbnail((2048, 2048))
126
+
127
+ output_buffer = BytesIO()
128
+ image.save(output_buffer, format=image_type) # 使用原始图片类型保存
129
+ output_buffer.seek(0)
130
+ parts.append(image)
131
+ except Exception as e:
132
+ logger.error(f"Error processing image: {e}")
133
+ return [], None, (jsonify({'error': 'Invalid image data'}), 400)
134
+ else:
135
+ return [], None, (jsonify({'error': 'Invalid image URL format'}), 400)
136
+
137
+ # 根据 role 添加到 gemini_history
138
+ if role in ['user', 'system']:
139
+ gemini_history.append({"role": "user", "parts": parts})
140
+ elif role == 'assistant':
141
+ gemini_history.append({"role": "model", "parts": parts})
142
+ else:
143
+ return [], None, (jsonify({'error': f'Invalid role: {role}'}), 400)
144
+
145
+ # 用户最后一条消息
146
+ if gemini_history:
147
+ user_message = gemini_history[-1]
148
+ gemini_history = gemini_history[:-1] # 历史记录不包含最后一条消息
149
+ else:
150
+ user_message = {"role": "user", "parts": [""]}
151
+
152
+ return gemini_history, user_message, None