|
import asyncio
import os

import httpx
import websockets
from fastapi import BackgroundTasks, FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.responses import Response, StreamingResponse
from fastapi.websockets import WebSocketState
|
|
|
|
|
# Port this proxy listens on when run as a script; override with the
# PROXY_PORT environment variable (must be an integer).
PROXY_PORT = int(os.environ.get("PROXY_PORT", "7860"))

# Port of the service on localhost that all HTTP and WebSocket traffic is
# relayed to; override with the TARGET_PORT environment variable.
TARGET_PORT = int(os.environ.get("TARGET_PORT", "6860"))
|
|
|
# The FastAPI application that the proxy routes in this module are registered on.
app: FastAPI = FastAPI()
|
|
|
# Hop-by-hop headers describe a single transport-level connection and must not
# be relayed by a proxy (RFC 9110 section 7.6.1).
_HOP_BY_HOP_HEADERS = frozenset(
    {
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
    }
)
# "host" is also dropped from requests so httpx sets it for the target URL.
_EXCLUDED_REQUEST_HEADERS = _HOP_BY_HOP_HEADERS | {"host"}
# Keep httpx's 5 s default for connect/write/pool, but never time out while
# waiting for response bytes: slow or streaming endpoints would be cut off.
_UPSTREAM_TIMEOUT = httpx.Timeout(5.0, read=None)


def _forwardable(raw_headers, excluded):
    """Yield lower-cased ``(name, value)`` byte pairs whose name is not in *excluded*."""
    for name, value in raw_headers:
        lowered = name.lower()
        if lowered.decode("latin-1") not in excluded:
            yield lowered, value


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"])
async def http_proxy(request: Request, path: str):
    """Relay an HTTP request to the target service and stream the response back.

    The request goes to ``http://localhost:{TARGET_PORT}/{path}`` with the original
    method, raw query string, body and headers (minus ``host`` and hop-by-hop
    headers). The upstream body is relayed undecoded via ``aiter_raw()``, so the
    forwarded ``content-encoding`` / ``content-length`` headers stay accurate.

    Returns:
        A ``StreamingResponse`` mirroring the upstream status and headers, or a
        502 ``Response`` when the target cannot be reached.
    """
    target_url = httpx.URL(
        f"http://localhost:{TARGET_PORT}/{path}",
        # Forward the raw query string so repeated keys (?a=1&a=2) survive;
        # converting query_params to a dict would keep only the last value.
        query=request.url.query.encode("utf-8"),
    )
    body = await request.body()

    client = httpx.AsyncClient(timeout=_UPSTREAM_TIMEOUT)
    try:
        upstream_request = client.build_request(
            method=request.method,
            url=target_url,
            # Raw byte pairs keep repeated headers and avoid re-encoding values.
            headers=list(_forwardable(request.headers.raw, _EXCLUDED_REQUEST_HEADERS)),
            content=body,
        )
        # stream=True leaves the body unread so it is relayed chunk by chunk.
        upstream = await client.send(upstream_request, stream=True)
    except httpx.RequestError:
        await client.aclose()
        return Response(content="Bad Gateway", status_code=502, media_type="text/plain")
    except BaseException:
        await client.aclose()
        raise

    # Release the upstream connection and the client once the response is done.
    cleanup = BackgroundTasks()
    cleanup.add_task(upstream.aclose)
    cleanup.add_task(client.aclose)

    response = StreamingResponse(
        content=upstream.aiter_raw(),
        status_code=upstream.status_code,
        background=cleanup,
    )
    # Extend raw_headers directly: passing httpx Headers as a mapping would
    # comma-merge repeated headers such as set-cookie.
    response.raw_headers.extend(_forwardable(upstream.headers.raw, _HOP_BY_HOP_HEADERS))
    return response
|
|
|
@app.websocket("/{path:path}")
async def websocket_proxy(websocket: WebSocket, path: str):
    """Bridge a client WebSocket to the same path on the target service.

    ``forward_messages`` relays client -> target and ``reverse_messages`` relays
    target -> client. As soon as either relay finishes, the other is cancelled,
    the target connection is closed, and the client is closed if it is still
    connected, so neither side is left waiting on a dead peer.

    Raises:
        Any unexpected exception raised by a relay task; a client
        ``WebSocketDisconnect`` is treated as a normal shutdown.
    """
    await websocket.accept()

    target_url = f"ws://localhost:{TARGET_PORT}/{path}"
    if websocket.url.query:
        # Keep the handshake's query parameters; the path alone would drop them.
        target_url = f"{target_url}?{websocket.url.query}"

    async with websockets.connect(target_url) as target_websocket:
        relays = [
            asyncio.create_task(forward_messages(websocket, target_websocket)),
            asyncio.create_task(reverse_messages(websocket, target_websocket)),
        ]
        try:
            done, _ = await asyncio.wait(relays, return_when=asyncio.FIRST_COMPLETED)
        finally:
            # cancel() is a no-op on finished tasks; gathering guarantees no relay
            # outlives this handler and that every task's outcome is retrieved.
            for task in relays:
                task.cancel()
            await asyncio.gather(*relays, return_exceptions=True)

    # If the client is still connected (e.g. the target closed first), close it.
    if websocket.client_state == WebSocketState.CONNECTED:
        await websocket.close()

    for task in done:
        if task.cancelled():
            continue
        error = task.exception()
        if error is not None and not isinstance(error, WebSocketDisconnect):
            raise error
|
|
|
async def forward_messages(source: WebSocket, target: websockets.WebSocketClientProtocol):
    """Relay frames from the client (``source``) to the target until either side closes.

    Text frames are forwarded as ``str`` (sent as text frames) and binary frames as
    ``bytes`` (sent as binary frames), so the frame type is preserved. A closed
    target connection ends the relay quietly.

    Raises:
        WebSocketDisconnect: when the client disconnects.
    """
    try:
        while True:
            event = await source.receive()
            if event["type"] == "websocket.disconnect":
                raise WebSocketDisconnect(event.get("code", 1000))
            # An ASGI receive event carries its payload under "text" or "bytes".
            text = event.get("text")
            if text is not None:
                await target.send(text)
                continue
            data = event.get("bytes")
            if data is not None:
                await target.send(data)
    except websockets.ConnectionClosed:
        pass
|
|
|
async def reverse_messages(source: WebSocket, target: websockets.WebSocketClientProtocol):
    """Relay frames from the target back to the client (``source``) until the target closes.

    ``target.recv()`` yields ``str`` for text frames and ``bytes`` for binary frames.
    Each is sent to the client as the same frame type, because ASGI requires the
    ``text`` field of a send event to be a ``str``. A closed target connection ends
    the relay quietly.
    """
    try:
        while True:
            message = await target.recv()
            if isinstance(message, bytes):
                await source.send_bytes(message)
            else:
                await source.send_text(message)
    except websockets.ConnectionClosed:
        pass
|
|
|
if __name__ == "__main__":
    # Imported here so uvicorn is only required when this file runs as a script.
    from uvicorn import run as serve_app

    serve_app(app, host="0.0.0.0", port=PROXY_PORT)