import struct
from functools import partial

import ormsgpack

from tools.server.agent.generate import generate_responses
from tools.server.agent.pre_generation_utils import prepare_messages


def execute_request(input_queue, tokenizer, config, request, device):
    """
    This function prepares the conversation, encodes the request,
    sends the generation request, and handles decoding/streaming.
    It returns a response generator (ServeResponse or ServeStreamResponse).
    """
    prompt, im_end_id = prepare_messages(request, tokenizer, config)
    yield from generate_responses(
        input_queue, tokenizer, config, request, prompt, im_end_id, device
    )


def response_generator(req, llama_queue, tokenizer, config, device):
    """
    Non-streaming response wrapper for the chat endpoint.
    Only returns the final result.
    """
    generator = execute_request(llama_queue, tokenizer, config, req, device)
    return next(generator)


async def streaming_generator(req, llama_queue, tokenizer, config, device, json_mode):
    """
    Streaming response wrapper for the chat endpoint.
    Returns the response in chunks.
    """
    generator = execute_request(llama_queue, tokenizer, config, req, device)
    for i in generator:
        if json_mode:
            body = i.model_dump_json().encode("utf-8")
            yield b"data: " + body + b"\n\n"
        else:
            body = ormsgpack.packb(i, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
            yield struct.pack("I", len(body)) + body


def get_response_generator(
    llama_queue, tokenizer, config, req, device, json_mode
) -> partial:
    """
    Get the correct response generator based on the request.
    """
    if not req.streaming:
        return partial(response_generator, req, llama_queue, tokenizer, config, device)
    else:
        return partial(
            streaming_generator, req, llama_queue, tokenizer, config, device, json_mode
        )
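

# ---------------------------------------------------------------------------
# Illustrative sketch (an assumption, not part of the server code): one way a
# client could decode the frames emitted by streaming_generator when
# json_mode is False, i.e. a 4-byte native-order unsigned-int length prefix
# followed by an ormsgpack payload. The `stream` argument is a hypothetical
# file-like object exposing read(n), such as a socket wrapper or an HTTP
# response body reader.
# ---------------------------------------------------------------------------
def _iter_msgpack_frames(stream):
    """Yield decoded msgpack payloads from a length-prefixed byte stream."""
    while True:
        header = stream.read(4)  # 4-byte unsigned int length prefix
        if len(header) < 4:
            break  # end of stream (or truncated frame)
        (length,) = struct.unpack("I", header)
        payload = stream.read(length)
        yield ormsgpack.unpackb(payload)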