# Uploaded by gordonchan ("Upload 41 files", commit ca56e6a).
from functools import partial
from typing import Iterator
import anyio
from fastapi import APIRouter, Depends, Request, HTTPException
from loguru import logger
from sse_starlette import EventSourceResponse
from starlette.concurrency import run_in_threadpool
from api.core.default import DefaultEngine
from api.models import GENERATE_ENGINE
from api.utils.compat import model_dump
from api.utils.protocol import ChatCompletionCreateParams, Role
from api.utils.request import (
handle_request,
check_api_key,
get_event_publisher,
)
# Router for the chat endpoints; every route registered on it is served
# under the "/chat" path prefix.
chat_router = APIRouter(
    prefix="/chat",
)
def get_engine():
    """FastAPI dependency that hands the module-level ``GENERATE_ENGINE`` to endpoints.

    Written as a generator, so FastAPI treats it as a yield-dependency.
    """
    engine = GENERATE_ENGINE
    yield engine
@chat_router.post("/completions", dependencies=[Depends(check_api_key)])
async def create_chat_completion(
    request: ChatCompletionCreateParams,
    raw_request: Request,
    engine: DefaultEngine = Depends(get_engine),
):
    """Create a chat completion for the supplied messages.

    Args:
        request: Parsed request body.
        raw_request: The underlying request object; passed to the SSE event
            publisher when the response is streamed.
        engine: Generation engine supplied by the ``get_engine`` dependency.

    Returns:
        An ``EventSourceResponse`` streaming the chunks when the engine returns
        an iterator; otherwise whatever the engine returned, unchanged.

    Raises:
        HTTPException: 400 if ``messages`` is empty or its last entry has the
            assistant role.
    """
    # Nothing to respond to, or the last turn is already an assistant message.
    if (not request.messages) or request.messages[-1]["role"] == Role.ASSISTANT:
        raise HTTPException(status_code=400, detail="Invalid request")

    request = await handle_request(request, engine.stop)
    # Fall back to 1024 when max_tokens is unset (note: an explicit 0 also
    # falls back, because of the ``or``).
    request.max_tokens = request.max_tokens or 1024

    params = model_dump(request, exclude={"messages"})
    params.update(dict(prompt_or_messages=request.messages, echo=False))
    logger.debug(f"==== request ====\n{params}")

    # The engine call is synchronous; run it in a worker thread so it does
    # not block the event loop.
    iterator_or_completion = await run_in_threadpool(
        engine.create_chat_completion, params
    )

    if not isinstance(iterator_or_completion, Iterator):
        # Non-streaming result: return it as-is.
        return iterator_or_completion

    # Pull the first chunk eagerly (EAFP): if the engine fails up front, the
    # exception propagates from this handler before the streaming response
    # is created.
    # ``next`` gets a sentinel default because a StopIteration raised in a
    # worker thread cannot be delivered through an asyncio future. Without
    # the default, an engine that yields nothing would turn into an unrelated
    # error here instead of an empty stream.
    exhausted = object()
    first_response = await run_in_threadpool(
        next, iterator_or_completion, exhausted
    )

    def iterator() -> Iterator:
        # Replay the chunk consumed above (if any), then hand over the rest.
        if first_response is not exhausted:
            yield first_response
        yield from iterator_or_completion

    send_chan, recv_chan = anyio.create_memory_object_stream(10)
    return EventSourceResponse(
        recv_chan,
        data_sender_callable=partial(
            get_event_publisher,
            request=raw_request,
            inner_send_chan=send_chan,
            iterator=iterator(),
        ),
    )