|
from flask import Flask, request, make_response |
|
import hashlib |
|
import time |
|
import xml.etree.ElementTree as ET |
|
import os |
|
import json |
|
from openai import OpenAI |
|
from dotenv import load_dotenv |
|
from markdown import markdown |
|
import re |
|
import threading |
|
import logging |
|
from datetime import datetime |
|
import asyncio |
|
from concurrent.futures import ThreadPoolExecutor |
|
import queue |
|
import uuid |
|
import base64 |
|
from Crypto.Cipher import AES |
|
import struct |
|
import random |
|
import string |
|
|
|
# Log to both a file and the console. The file handler is given an explicit
# UTF-8 encoding because this service logs Chinese text, which a non-UTF-8
# platform default encoding (e.g. cp1252) cannot represent.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('wechat_service.log', encoding='utf-8'),
        logging.StreamHandler(),
    ],
)

# Load a local .env file (if present) so the os.getenv() calls below see it.
load_dotenv()

app = Flask(__name__)

# WeChat official-account settings (server token, message encryption key, AppID).
TOKEN = os.getenv('TOKEN')
ENCODING_AES_KEY = os.getenv('ENCODING_AES_KEY')
APPID = os.getenv('APPID')

# Credentials / endpoint for the OpenAI-compatible chat API.
API_KEY = os.getenv("API_KEY")
BASE_URL = os.getenv("OPENAI_BASE_URL")

client = OpenAI(api_key=API_KEY, base_url=BASE_URL)

# Worker pool that runs model requests submitted by the webhook handler.
executor = ThreadPoolExecutor(max_workers=10)
|
|
|
class WeChatCrypto:
    """AES-CBC encryption/decryption of WeChat "safe mode" message payloads.

    The plaintext layout handled here is::

        16 random bytes | 4-byte big-endian length | message | app_id | padding

    padded PKCS#7-style to a 32-byte boundary and encrypted with AES-CBC,
    using the decoded key as the AES key and its first 16 bytes as the IV.
    """

    # Padding block size used by the payload format (not AES's 16-byte block).
    BLOCK_SIZE = 32

    def __init__(self, key, app_id):
        """
        Args:
            key: The EncodingAESKey as base64 without its trailing '='
                (one '=' is appended before decoding).
            app_id: AppID embedded in every encrypted message and checked
                on decryption.

        Raises:
            ValueError: if key is missing or does not decode to 32 bytes.
        """
        if not key:
            raise ValueError('ENCODING_AES_KEY is not configured')
        self.key = base64.b64decode(key + '=')
        # A 16/24-byte key would silently select AES-128/192 instead of the
        # AES-256 key this format expects; fail fast on misconfiguration.
        if len(self.key) != 32:
            raise ValueError('EncodingAESKey must decode to 32 bytes')
        self.app_id = app_id

    def encrypt(self, text):
        """Encrypt *text* (str) and return the base64 string for <Encrypt>."""
        import secrets  # local import: CSPRNG for the random prefix

        # The IV is fixed (key[:16]), so this random first block effectively
        # acts as the IV; draw it from a CSPRNG rather than `random`.
        alphabet = string.ascii_letters + string.digits
        random_prefix = ''.join(secrets.choice(alphabet) for _ in range(16)).encode('ascii')
        text_bytes = text.encode('utf-8')

        message = (
            random_prefix
            + struct.pack('>I', len(text_bytes))
            + text_bytes
            + self.app_id.encode('utf-8')
        )

        # PKCS#7 to a 32-byte boundary: always 1..32 bytes, each equal to the count.
        pad_len = self.BLOCK_SIZE - len(message) % self.BLOCK_SIZE
        message += bytes([pad_len]) * pad_len

        cipher = AES.new(self.key, AES.MODE_CBC, self.key[:16])
        return base64.b64encode(cipher.encrypt(message)).decode('utf-8')

    def decrypt(self, encrypted_text):
        """Decrypt a base64 <Encrypt> payload and return the inner XML string.

        Raises:
            ValueError: on bad padding, a truncated/inconsistent payload, or
                an AppID that does not match this instance's app_id.
        """
        encrypted_data = base64.b64decode(encrypted_text)
        cipher = AES.new(self.key, AES.MODE_CBC, self.key[:16])
        decrypted = cipher.decrypt(encrypted_data)
        if not decrypted:
            raise ValueError('Empty ciphertext')

        # Indexing bytes yields an int in Python 3. A pad byte of 0 would make
        # the slice below empty, and > 32 is impossible for valid input.
        pad_len = decrypted[-1]
        if not 1 <= pad_len <= self.BLOCK_SIZE:
            raise ValueError('Invalid padding')

        # Drop the 16-byte random prefix and the padding.
        content = decrypted[16:-pad_len]
        if len(content) < 4:
            raise ValueError('Decrypted message too short')

        msg_len = struct.unpack('>I', content[:4])[0]
        if msg_len > len(content) - 4:
            raise ValueError('Invalid message length')

        xml_content = content[4:4 + msg_len].decode('utf-8')
        app_id = content[4 + msg_len:].decode('utf-8')
        if app_id != self.app_id:
            raise ValueError('Invalid AppID')
        return xml_content
|
|
|
class AsyncResponse:
    """State of one background model request.

    ``status`` starts as "processing"; workers later set ``result`` or
    ``error``. The record counts as expired ``timeout`` seconds after it
    was created.
    """

    def __init__(self):
        self.create_time = time.time()
        self.timeout = 3600  # seconds
        self.status = "processing"
        self.result = None
        self.error = None

    def is_expired(self):
        """Return True once more than ``timeout`` seconds have elapsed since creation."""
        age = time.time() - self.create_time
        return age > self.timeout
|
|
|
class UserSession:
    """Conversation state for one WeChat user.

    Holds the chat history (seeded with a system prompt), chunks of a long
    answer waiting to be sent, and the background tasks keyed by task id.
    """

    def __init__(self):
        system_prompt = "你是HXIAO公众号的智能助手,这一个用来分享与学习人工智能的公众号,我们的目标是专注AI应用的简单研究与实践。致力于分享切实可行的技术方案,希望让复杂的技术变得简单易懂。也喜欢用通俗的语言来解释专业概念,让技术真正服务于每个学习者"
        self.messages = [{"role": "system", "content": system_prompt}]
        self.pending_parts = []
        self.last_active = time.time()
        self.current_task = None
        self.response_queue = {}
        self.session_timeout = 3600  # seconds of inactivity

    def is_expired(self):
        """Return True when the session has been idle longer than ``session_timeout``."""
        idle_for = time.time() - self.last_active
        return idle_for > self.session_timeout

    def cleanup_expired_tasks(self):
        """Drop expired task records, forgetting ``current_task`` if it was one of them."""
        stale_ids = [tid for tid, resp in self.response_queue.items() if resp.is_expired()]
        for tid in stale_ids:
            self.response_queue.pop(tid)
            if tid == self.current_task:
                self.current_task = None
|
|
|
class SessionManager:
    """Lock-protected registry of per-user ``UserSession`` objects.

    Also owns the ``WeChatCrypto`` instance built from ENCODING_AES_KEY and
    APPID, which the request handlers use for encrypted messages.
    """

    def __init__(self):
        self.sessions = {}
        self._lock = threading.Lock()
        self.crypto = WeChatCrypto(ENCODING_AES_KEY, APPID)

    def get_session(self, user_id):
        """Return the session for *user_id*, creating or replacing it as needed.

        A missing or expired session is replaced by a fresh ``UserSession``;
        a live one has its expired tasks pruned. Either way the returned
        session is marked active now.
        """
        with self._lock:
            current_time = time.time()
            session = self.sessions.get(user_id)
            if session is None or session.is_expired():
                session = UserSession()
            else:
                session.cleanup_expired_tasks()
            session.last_active = current_time
            self.sessions[user_id] = session
            return session

    def clear_session(self, user_id):
        """Reset a known user's conversation; unknown user ids are ignored."""
        with self._lock:
            if user_id in self.sessions:
                self.sessions[user_id] = UserSession()

    def cleanup_expired_sessions(self):
        """Delete every session whose idle time exceeds its timeout."""
        with self._lock:
            expired_users = [
                user_id for user_id, session in self.sessions.items()
                if session.is_expired()
            ]
            for user_id in expired_users:
                del self.sessions[user_id]
                logging.info("已清理过期会话: %s", user_id)
|
|
|
# Process-wide session registry shared by all request handlers and the
# background cleanup loop.
session_manager: SessionManager = SessionManager()
|
|
|
def convert_markdown_to_wechat(md_text):
    """Rewrite Markdown into plain text suitable for a WeChat text message.

    Headings, emphasis, inline code, lists, quotes, horizontal rules and
    links are replaced with Unicode markers. Fenced code blocks are wrapped
    in 【代码开始】/【代码结束】 and their contents are left untouched.

    Args:
        md_text: Markdown source. Falsy values (None, '') are returned as-is.

    Returns:
        The converted text.
    """
    if not md_text:
        return md_text

    # Pull fenced code blocks out first. Otherwise the inline-code rule below
    # consumes the ``` fences (so fences are never recognised), and the
    # heading/emphasis rules rewrite code such as "# comment" or "a ** b".
    code_blocks = []

    def _stash(match):
        code_blocks.append(match.group(1).rstrip('\n'))
        return f'\x00CODEBLOCK{len(code_blocks) - 1}\x00'

    md_text = re.sub(r'```[^\n`]*\n(.*?)```', _stash, md_text, flags=re.DOTALL)

    md_text = re.sub(r'^# (.*?)$', r'【标题】\1', md_text, flags=re.MULTILINE)
    md_text = re.sub(r'^## (.*?)$', r'【子标题】\1', md_text, flags=re.MULTILINE)
    md_text = re.sub(r'^### (.*?)$', r'【小标题】\1', md_text, flags=re.MULTILINE)
    # Bold must run before italics so '**' is not consumed as two '*'.
    md_text = re.sub(r'\*\*(.*?)\*\*', r'『\1』', md_text)
    md_text = re.sub(r'\*(.*?)\*', r'「\1」', md_text)
    md_text = re.sub(r'`(.*?)`', r'「\1」', md_text)
    md_text = re.sub(r'^\- ', '• ', md_text, flags=re.MULTILINE)
    # \d+ so that "10. " and beyond are converted too.
    md_text = re.sub(r'^\d+\. ', '○ ', md_text, flags=re.MULTILINE)
    md_text = re.sub(r'^> (.*?)$', r'▎\1', md_text, flags=re.MULTILINE)
    md_text = re.sub(r'^-{3,}$', r'—————————', md_text, flags=re.MULTILINE)
    md_text = re.sub(r'\[(.*?)\]\((.*?)\)', r'\1(\2)', md_text)
    md_text = re.sub(r'\n{3,}', '\n\n', md_text)

    # Restore code blocks last so none of the rules above touch their contents.
    # A callable replacement is used so backslashes in code are not interpreted.
    return re.sub(
        r'\x00CODEBLOCK(\d+)\x00',
        lambda m: f'【代码开始】\n{code_blocks[int(m.group(1))]}\n【代码结束】',
        md_text,
    )
|
|
|
def verify_signature(signature, timestamp, nonce, token):
    """Check the plain (non-encrypted) WeChat request signature.

    The expected value is the hex SHA-1 of the lexicographically sorted
    token, timestamp and nonce concatenated together.

    Args:
        signature: Signature supplied by the caller (query parameter).
        timestamp: Timestamp query parameter.
        nonce: Nonce query parameter.
        token: The shared server token.

    Returns:
        True if the signature matches. False on mismatch or when any
        argument is missing (None/empty); a missing query parameter would
        otherwise make the sort raise TypeError.
    """
    import hmac  # local import: constant-time comparison

    if not all((signature, timestamp, nonce, token)):
        return False
    expected = hashlib.sha1(''.join(sorted((token, timestamp, nonce))).encode('utf-8')).hexdigest()
    # Compare as bytes: compare_digest rejects non-ASCII str, and the
    # signature is attacker-controlled.
    return hmac.compare_digest(expected.encode('utf-8'), signature.encode('utf-8'))
|
|
|
def verify_msg_signature(msg_signature, timestamp, nonce, token, encrypt_msg):
    """Check the ``msg_signature`` that accompanies an encrypted (AES) message.

    The expected value is the hex SHA-1 of the lexicographically sorted
    token, timestamp, nonce and encrypted payload concatenated together.

    Args:
        msg_signature: Signature supplied by the caller (query parameter).
        timestamp: Timestamp query parameter.
        nonce: Nonce query parameter.
        token: The shared server token.
        encrypt_msg: Base64 text of the <Encrypt> element.

    Returns:
        True if the signature matches. False on mismatch or when any
        argument is missing (None/empty) instead of raising TypeError.
    """
    import hmac  # local import: constant-time comparison

    parts = (token, timestamp, nonce, encrypt_msg)
    if not msg_signature or not all(parts):
        return False
    expected = hashlib.sha1(''.join(sorted(parts)).encode('utf-8')).hexdigest()
    # Bytes comparison: compare_digest rejects non-ASCII str input.
    return hmac.compare_digest(expected.encode('utf-8'), msg_signature.encode('utf-8'))
|
|
|
def parse_xml_message(xml_content):
    """Parse a WeChat message document into a flat dict.

    Args:
        xml_content: The ``<xml>...</xml>`` document as str or bytes.

    Returns:
        A dict with keys ``content``, ``from_user``, ``to_user``, ``msg_id``,
        ``create_time`` and ``msg_type``. ``content`` and ``msg_id`` are ''
        when the element is absent or empty, so callers can always
        ``.strip()`` the content.

    Raises:
        xml.etree.ElementTree.ParseError: if the document is malformed.
        ValueError: if FromUserName, ToUserName, CreateTime or MsgType is missing.
    """
    # NOTE(review): ElementTree is not hardened against maliciously crafted
    # XML (see the "XML vulnerabilities" section of the Python docs) and this
    # input comes from the network; consider defusedxml.
    root = ET.fromstring(xml_content)

    def required(tag):
        value = root.findtext(tag)
        if value is None:
            raise ValueError(f'Missing <{tag}> in WeChat message')
        return value

    return {
        # findtext returns '' for a present-but-empty element and the
        # default for a missing one, so content is never None.
        'content': root.findtext('Content', default=''),
        'from_user': required('FromUserName'),
        'to_user': required('ToUserName'),
        'msg_id': root.findtext('MsgId', default=''),
        'create_time': required('CreateTime'),
        'msg_type': required('MsgType'),
    }
|
|
|
def generate_response_xml(to_user, from_user, content, encrypt_type='aes'):
    """Build the Flask response carrying a passive text reply.

    Args:
        to_user: Recipient of the reply (the route passes the incoming
            FromUserName).
        from_user: Sender of the reply (the route passes the incoming
            ToUserName).
        content: Reply text; Markdown is converted by convert_markdown_to_wechat.
        encrypt_type: 'aes' wraps the reply in an encrypted envelope signed
            with TOKEN; any other value sends the plaintext XML.

    Returns:
        A Flask response with content type ``application/xml``.
    """
    def _cdata(value):
        # A literal "]]>" (easily produced by model output containing code)
        # would end the CDATA section early and corrupt the XML, so split it
        # across two adjacent CDATA sections.
        return '<![CDATA[' + str(value).replace(']]>', ']]]]><![CDATA[>') + ']]>'

    formatted_content = convert_markdown_to_wechat(content)
    timestamp = str(int(time.time()))

    message_xml = (
        '<xml>'
        f'<ToUserName>{_cdata(to_user)}</ToUserName>'
        f'<FromUserName>{_cdata(from_user)}</FromUserName>'
        f'<CreateTime>{timestamp}</CreateTime>'
        f'<MsgType>{_cdata("text")}</MsgType>'
        f'<Content>{_cdata(formatted_content)}</Content>'
        '</xml>'
    )

    if encrypt_type == 'aes':
        nonce = ''.join(random.choices(string.ascii_letters + string.digits, k=10))
        encrypted = session_manager.crypto.encrypt(message_xml)
        # Same scheme as the incoming msg_signature: SHA-1 of the sorted parts.
        signature_parts = sorted([TOKEN, timestamp, nonce, encrypted])
        msg_signature = hashlib.sha1(''.join(signature_parts).encode('utf-8')).hexdigest()
        response_xml = (
            '<xml>'
            f'<Encrypt>{_cdata(encrypted)}</Encrypt>'
            f'<MsgSignature>{_cdata(msg_signature)}</MsgSignature>'
            f'<TimeStamp>{timestamp}</TimeStamp>'
            f'<Nonce>{_cdata(nonce)}</Nonce>'
            '</xml>'
        )
    else:
        response_xml = message_xml

    response = make_response(response_xml)
    response.content_type = 'application/xml'
    return response
|
|
|
def process_long_running_task(messages):
    """Send *messages* to the chat model and return the first choice's content.

    The request uses model "o3-mini" with a 60-second timeout. Any exception
    (including one raised while reading the reply) is logged and re-raised.
    """
    try:
        completion = client.chat.completions.create(
            messages=messages,
            model="o3-mini",
            timeout=60,
        )
        first_choice = completion.choices[0]
        return first_choice.message.content
    except Exception as exc:
        logging.error("API调用错误: %s", exc)
        raise
|
|
|
def handle_async_task(session, task_id, messages):
    """Run one model request (submitted to the executor) and record its outcome.

    Does nothing if the task entry is already gone. On success, the result
    is stored only if the entry still exists and has not expired. On
    failure, the entry is marked "failed" with the error text.

    Args:
        session: The user's session holding ``response_queue``.
        task_id: Key of this task's record in ``session.response_queue``.
        messages: Snapshot of the conversation to send to the model.
    """
    if task_id not in session.response_queue:
        return

    try:
        result = process_long_running_task(messages)
    except Exception as exc:
        # The entry may have been removed concurrently; look it up once.
        entry = session.response_queue.get(task_id)
        if entry is not None:
            entry.error = str(exc)
            entry.status = "failed"
        return

    entry = session.response_queue.get(task_id)
    if entry is not None and not entry.is_expired():
        # Publish the result before flipping the status: a request thread
        # that observes status == "completed" immediately reads `result`.
        entry.result = result
        entry.status = "completed"
|
|
|
def generate_initial_response():
    """Return the placeholder reply sent while the model request is still running."""
    placeholder = "您的请求正在处理中,请回复'查询'获取结果"
    return placeholder
|
|
|
def split_message(message, max_length=500):
    """Split *message* into consecutive chunks of at most *max_length* characters.

    Args:
        message: Text to split.
        max_length: Maximum characters per chunk (must be >= 1).

    Returns:
        The chunks in order; an empty message yields [].

    Raises:
        ValueError: if max_length < 1. A negative step would otherwise
            silently return [] and drop the whole text.
    """
    if max_length < 1:
        raise ValueError(f"max_length must be a positive integer, got {max_length!r}")
    return [message[start:start + max_length] for start in range(0, len(message), max_length)]
|
|
|
def append_status_message(content, has_pending_parts=False, is_processing=False):
    """Append the footer of command hints to a reply.

    Replies that already contain the "processing" notice get only the
    new-chat hint. Other replies get a separator, then at most one contextual
    hint (``is_processing`` takes precedence over ``has_pending_parts``),
    then the new-chat hint.
    """
    if "您的请求正在处理中" in content:
        return content + "\n\n-------------------\n发送'新对话'开始新的对话"

    footer = ["\n\n-------------------"]
    if is_processing:
        footer.append("\n请回复'查询'获取结果")
    elif has_pending_parts:
        footer.append("\n当前消息已截断,发送'继续'查看后续内容")
    footer.append("\n发送'新对话'开始新的对话")
    return content + "".join(footer)
|
|
|
def _read_incoming_message(encrypt_type):
    """Authenticate the current POST request and parse its message.

    In AES mode, ``msg_signature`` over the <Encrypt> field is checked and
    the payload decrypted. Otherwise the plain ``signature`` query parameter
    is checked, so forged plaintext posts are rejected too.

    Returns:
        The dict from parse_xml_message, or None when signature fields are
        missing or verification fails.

    Raises:
        Whatever XML parsing or decryption raises for a malformed body.
    """
    timestamp = request.args.get('timestamp')
    nonce = request.args.get('nonce')

    if encrypt_type == 'aes':
        msg_signature = request.args.get('msg_signature')
        encrypted_text = ET.fromstring(request.data).findtext('Encrypt')
        # Check presence first: the signature helpers sort their inputs, so
        # a None would raise TypeError instead of failing verification.
        if not all((msg_signature, timestamp, nonce, TOKEN, encrypted_text)):
            return None
        if not verify_msg_signature(msg_signature, timestamp, nonce, TOKEN, encrypted_text):
            return None
        return parse_xml_message(session_manager.crypto.decrypt(encrypted_text))

    signature = request.args.get('signature')
    if not all((signature, timestamp, nonce, TOKEN)):
        return None
    if not verify_signature(signature, timestamp, nonce, TOKEN):
        return None
    return parse_xml_message(request.data)


def _reply(message_data, encrypt_type, text, has_pending_parts=False, is_processing=False):
    """Send *text* plus the status footer back to the sender of *message_data*."""
    return generate_response_xml(
        message_data['from_user'],  # the reply goes to the user who wrote...
        message_data['to_user'],    # ...and comes from the official account
        append_status_message(text, has_pending_parts, is_processing),
        encrypt_type,
    )


def _answer_continue(session):
    """Handle '继续': pop the next buffered chunk. Returns (text, has_pending_parts, is_processing)."""
    if session.pending_parts:
        next_part = session.pending_parts.pop(0)
        return next_part, bool(session.pending_parts), False
    return '没有更多内容了。请继续您的问题。', False, False


def _answer_query(session):
    """Handle '查询': report on the session's current task. Returns (text, has_pending_parts, is_processing)."""
    task_id = session.current_task
    task_response = session.response_queue.get(task_id) if task_id else None
    if task_response is None:
        # Also forget a dangling task id whose record was already cleaned up.
        session.current_task = None
        return '没有正在处理的请求。', False, False

    if task_response.is_expired():
        session.response_queue.pop(task_id, None)
        session.current_task = None
        return '请求已过期,请重新提问。', False, False

    # Require the result to be visible as well: a worker that sets status
    # before result would otherwise hand us None here.
    if task_response.status == "completed" and task_response.result is not None:
        response = task_response.result
        session.response_queue.pop(task_id, None)
        session.current_task = None
        session.messages.append({"role": "assistant", "content": response})

        if len(response) > 500:
            first_part, *rest = split_message(response)
            session.pending_parts = rest
            return first_part, True, False
        # Drop chunks left over from an earlier long answer so '继续' cannot replay them.
        session.pending_parts = []
        return response, False, False

    if task_response.status == "failed":
        session.response_queue.pop(task_id, None)
        session.current_task = None
        return '处理过程中出现错误,请重新提问。', False, False

    return '正在处理中,请稍后再次查询。', False, True


def _submit_question(session, user_content):
    """Record a new question and start answering it in the background. Returns (text, has_pending_parts, is_processing)."""
    session.messages.append({"role": "user", "content": user_content})
    # Chunks of a previous answer no longer apply to the new question.
    session.pending_parts = []

    task_id = str(uuid.uuid4())
    session.current_task = task_id
    # Register the record before submitting: handle_async_task exits early
    # when the task id is not in response_queue.
    session.response_queue[task_id] = AsyncResponse()
    executor.submit(handle_async_task, session, task_id, session.messages.copy())

    return generate_initial_response(), False, True


@app.route('/api/wx', methods=['GET', 'POST'])
def wechatai():
    """WeChat callback endpoint.

    GET: server verification; echoes ``echostr`` when the signature is valid.
    POST: authenticates and parses the message, then handles the chat
    commands '新对话' (reset), '继续' (next chunk) and '查询' (poll the
    background answer). Any other text is submitted as a new question.
    """
    if request.method == 'GET':
        signature = request.args.get('signature')
        timestamp = request.args.get('timestamp')
        nonce = request.args.get('nonce')
        echostr = request.args.get('echostr', '')
        if all((signature, timestamp, nonce, TOKEN)) and verify_signature(signature, timestamp, nonce, TOKEN):
            return echostr
        return 'error', 403

    encrypt_type = request.args.get('encrypt_type', '')
    # Stays None until parsing succeeds, so the error path knows whether a
    # reply can be addressed at all.
    message_data = None
    try:
        message_data = _read_incoming_message(encrypt_type)
        if message_data is None:
            return 'Invalid signature', 403

        if message_data['msg_type'] != 'text':
            # Events and media messages carry no text; don't send an empty
            # prompt to the model. NOTE(review): WeChat documents a plain
            # "success" body as "acknowledged, no reply" — confirm.
            logging.info(f"忽略用户({message_data['from_user']})的非文本消息: {message_data['msg_type']}")
            return 'success'

        user_content = message_data['content'].strip()
        from_user = message_data['from_user']
        logging.info(f"收到用户({from_user})消息: {user_content}")
        session = session_manager.get_session(from_user)

        if user_content == '新对话':
            session_manager.clear_session(from_user)
            answer = ('已开始新的对话。请描述您的问题。', False, False)
        elif user_content == '继续':
            answer = _answer_continue(session)
        elif user_content == '查询':
            answer = _answer_query(session)
        else:
            answer = _submit_question(session, user_content)

        return _reply(message_data, encrypt_type, *answer)

    except Exception as e:
        logging.exception(f"处理请求时出错: {str(e)}")
        if message_data is None:
            # The request could not be parsed or decrypted, so there is no
            # user to address a reply to.
            return 'error', 400
        return _reply(message_data, encrypt_type, '抱歉,系统暂时出现问题,请稍后重试。')
|
|
|
def cleanup_sessions():
    """Loop forever, sweeping expired sessions once per hour.

    Errors are logged and the loop continues, so one failure does not stop
    future sweeps.
    """
    interval_seconds = 3600
    while True:
        time.sleep(interval_seconds)
        try:
            session_manager.cleanup_expired_sessions()
        except Exception as exc:
            logging.error("清理会话时出错: %s", exc)
|
|
|
if __name__ == '__main__':
    # The session sweeper runs only when this file is executed directly.
    cleanup_thread = threading.Thread(target=cleanup_sessions, daemon=True)
    cleanup_thread.start()

    # The Werkzeug debugger allows arbitrary code execution from the browser,
    # so it must never be on by default for a server bound to all interfaces.
    # Opt in with FLASK_DEBUG=1 for local development only.
    debug_mode = os.getenv('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes')
    app.run(host='0.0.0.0', port=7860, debug=debug_mode)