File size: 2,415 Bytes
fc3e4f0 6e42836 fc3e4f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import asyncio

import httpx
import websockets
from fastapi import BackgroundTasks, FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.responses import Response, StreamingResponse
from fastapi.websockets import WebSocketState
# Port that this reverse proxy listens on.
PROXY_PORT: int = 7860
# Port of the target server on localhost that all traffic is forwarded to.
TARGET_PORT: int = 6860

# Application object; the catch-all HTTP and WebSocket routes below are registered on it.
app: FastAPI = FastAPI()
# Hop-by-hop headers (RFC 2616 section 13.5.1) describe a single connection only and
# must not be relayed by a proxy.
_HOP_BY_HOP_HEADERS = frozenset(
    {
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailers",
        "transfer-encoding",
        "upgrade",
    }
)


def _forwardable_headers(items, skip=frozenset()):
    """Return the ``(name, value)`` pairs from *items* that a proxy may relay.

    Hop-by-hop headers and any names listed in *skip* are dropped (names are
    compared case-insensitively). Duplicate entries, such as several
    ``Set-Cookie`` headers, are kept as separate pairs.
    """
    dropped = _HOP_BY_HOP_HEADERS | {name.lower() for name in skip}
    return [(name, value) for name, value in items if name.lower() not in dropped]


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"])
async def http_proxy(request: Request, path: str):
    """Forward an HTTP request to the target server and stream its response back.

    - The query string is forwarded verbatim, so repeated keys (``?a=1&a=2``) survive.
    - Request and response headers keep duplicates; hop-by-hop headers are dropped.
    - The upstream body is relayed as raw (still content-encoded) bytes, so the
      forwarded ``Content-Encoding`` / ``Content-Length`` match what the client gets.
    - If the target cannot be reached, ``502 Bad Gateway`` is returned.
    """
    target_url = f"http://localhost:{TARGET_PORT}/{path}"
    if request.url.query:
        target_url = f"{target_url}?{request.url.query}"
    # httpx derives Host from the target URL and Content-Length from the body it sends.
    headers = _forwardable_headers(request.headers.items(), skip={"host", "content-length"})
    body = await request.body()

    client = httpx.AsyncClient()
    try:
        upstream = await client.send(
            client.build_request(request.method, target_url, headers=headers, content=body),
            stream=True,
        )
    except httpx.RequestError:
        await client.aclose()
        # Connection refused / timed out talking to the target: a gateway error.
        return Response(status_code=502)
    except BaseException:
        await client.aclose()
        raise

    # The body is streamed after this function returns, so the upstream response and
    # the client may only be closed once streaming has finished.
    cleanup = BackgroundTasks()
    cleanup.add_task(upstream.aclose)
    cleanup.add_task(client.aclose)
    response = StreamingResponse(
        upstream.aiter_raw(),
        status_code=upstream.status_code,
        background=cleanup,
    )
    # multi_items() keeps repeated headers separate, whereas httpx.Headers.items()
    # comma-joins them (which corrupts multiple Set-Cookie headers).
    for name, value in _forwardable_headers(upstream.headers.multi_items()):
        response.headers.append(name, value)
    return response
@app.websocket("/{path:path}")
async def websocket_proxy(websocket: WebSocket, path: str):
    """Relay a WebSocket session between the client and the target server.

    The client connection is accepted first. A connection is then opened to
    ``ws://localhost:TARGET_PORT/<path>``, carrying along the client's query string.
    Frames are pumped in both directions by ``forward_messages`` and
    ``reverse_messages``.

    As soon as either direction finishes, the other is cancelled and both
    connections are closed, so one peer hanging up never leaves the session
    half-open. A normal disconnect by either peer ends the session quietly; any
    other failure in a relay is re-raised.
    """
    await websocket.accept()
    target_url = f"ws://localhost:{TARGET_PORT}/{path}"
    if websocket.url.query:
        target_url = f"{target_url}?{websocket.url.query}"
    try:
        async with websockets.connect(target_url) as target_websocket:
            relays = {
                asyncio.create_task(forward_messages(websocket, target_websocket)),
                asyncio.create_task(reverse_messages(websocket, target_websocket)),
            }
            try:
                done, _ = await asyncio.wait(relays, return_when=asyncio.FIRST_COMPLETED)
            finally:
                # Stop the relay that is still running; waiting for both (as gather()
                # did) would block on a peer that may never send again.
                for task in relays:
                    task.cancel()
                await asyncio.gather(*relays, return_exceptions=True)
            for task in done:
                exc = task.exception()
                # A peer disconnecting is the normal end of a session.
                if exc is not None and not isinstance(
                    exc, (WebSocketDisconnect, websockets.ConnectionClosed)
                ):
                    raise exc
    finally:
        # Close the client side unless it was already closed by either peer.
        if (
            websocket.application_state == WebSocketState.CONNECTED
            and websocket.client_state == WebSocketState.CONNECTED
        ):
            await websocket.close()
async def forward_messages(source: WebSocket, target: websockets.WebSocketClientProtocol):
    """Relay frames from the client (*source*) to the target server (*target*).

    Text frames are sent on as ``str`` and binary frames as ``bytes``. The loop
    ends when either of these happens:

    - The client disconnects. The target connection is then closed too, so the
      opposite relay (blocked in ``target.recv()``) wakes up and finishes.
    - Sending to the target fails because the target connection has closed.
    """
    try:
        while True:
            # receive() (unlike receive_text()) accepts binary frames and reports a
            # disconnect as a message instead of raising WebSocketDisconnect.
            message = await source.receive()
            if message["type"] == "websocket.disconnect":
                await target.close()
                return
            text = message.get("text")
            if text is not None:
                await target.send(text)
            else:
                data = message.get("bytes")
                if data is not None:
                    await target.send(data)
    except websockets.ConnectionClosed:
        # The target went away; there is nothing left to relay to.
        pass
async def reverse_messages(source: WebSocket, target: websockets.WebSocketClientProtocol):
    """Relay frames from the target server (*target*) back to the client (*source*).

    ``target.recv()`` yields ``bytes`` for binary frames and ``str`` for text
    frames. Each frame is sent to the client with the matching frame type.

    When the target connection closes, the client connection is closed as well
    (unless it is already closed). The opposite relay, waiting on the client,
    then sees a disconnect instead of waiting forever.
    """
    try:
        while True:
            message = await target.recv()
            if isinstance(message, bytes):
                await source.send_bytes(message)
            else:
                await source.send_text(message)
    except websockets.ConnectionClosed:
        if (
            source.application_state == WebSocketState.CONNECTED
            and source.client_state == WebSocketState.CONNECTED
        ):
            await source.close()
if __name__ == "__main__":
    # Importing here means uvicorn is only needed when this file is run as a script.
    from uvicorn import run

    # "0.0.0.0" binds every network interface, so the proxy is reachable from other hosts.
    run(app, host="0.0.0.0", port=PROXY_PORT)