import time

from tools.schema import ServeMessage, ServeResponse, ServeStreamResponse
from tools.server.agent.generation_utils import (
    initialize_decode_buffers,
    process_response_tokens,
    send_reset_buffer,
)
from tools.server.agent.pre_generation_utils import (
    create_generation_request,
    send_generation_request,
)


def generate_responses(
    input_queue, tokenizer, config, request, prompt, im_end_id, device
):
    """
    Main generation function: encodes the request, sends it to the model
    worker, and handles decoding/streaming of the response tokens.

    Calling it returns a generator that yields ServeStreamResponse chunks
    while streaming and, once generation finishes, the final
    ServeStreamResponse or ServeResponse objects.
    """
    stats = {}
    start = time.time()
    stats["start_time"] = start
    stats["tokens_count"] = 0

    # Prepare and send the generation request
    req = create_generation_request(prompt, request, im_end_id, device)
    response_queue = send_generation_request(input_queue, req)
    decode_buffer, parts, finished = initialize_decode_buffers(request.num_samples)

    while True:
        response = response_queue.get()

        # Handle abnormal finish or error
        if response in ["stop", "error"]:
            finish_reason = response
            break

        # Process the response tokens
        is_first_token = stats["tokens_count"] == 0
        responses = process_response_tokens(
            response,
            tokenizer,
            config,
            request,
            decode_buffer,
            parts,
            finished,
            im_end_id,
            stats,
            start,
            is_first_token,
        )

        # Yield the responses if streaming
        if request.streaming and responses:
            for r in responses:
                yield r
        stats["tokens_count"] += 1

        # Check if all samples are finished
        if all(finished):
            finish_reason = "stop"
            break

    # Finalize the response
    final_responses = finalize_response(
        request, finished, decode_buffer, tokenizer, parts, stats, finish_reason
    )
    for fr in final_responses:
        yield fr


def finalize_response(
    request, finished, decode_buffer, tokenizer, parts, stats, finish_reason
):
    """
    Finalize the response by flushing the remaining text buffers and
    building the final stream chunks or the full non-streaming response.
    """
    responses = []

    # Flush the remaining text buffers
    for sample_id in range(request.num_samples):
        responses.extend(
            send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request)
        )

    # Calculate the final stats
    stats["total_time"] = (time.time() - stats["start_time"]) * 1000
    stats["total_tokens"] = stats["tokens_count"]

    if request.streaming:
        # If streaming, send a final chunk for each sample not already finished
        for sample_id in range(request.num_samples):
            if finished[sample_id]:
                continue
            responses.append(
                ServeStreamResponse(
                    finish_reason=finish_reason, stats=stats, sample_id=sample_id
                )
            )
    else:
        # If not streaming, send the full messages for each sample
        full_messages = [
            ServeMessage(role="assistant", parts=parts[i])
            for i in range(request.num_samples)
        ]
        responses.append(
            ServeResponse(
                messages=full_messages,
                finish_reason=finish_reason,
                stats=stats,
            )
        )

    return responses
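

# --- Usage sketch (illustrative, not wired into the server) ---
# A minimal consumer of generate_responses, making the generator contract
# explicit. It assumes the caller already holds the worker input_queue,
# tokenizer, config, request, prompt, im_end_id, and device from the
# surrounding server setup; this helper is hypothetical and exists only
# for illustration.
def _consume_responses_sketch(
    input_queue, tokenizer, config, request, prompt, im_end_id, device
):
    for response in generate_responses(
        input_queue, tokenizer, config, request, prompt, im_end_id, device
    ):
        if isinstance(response, ServeStreamResponse):
            # Streaming chunk (or final per-sample chunk): forward it to
            # the client as it arrives.
            pass
        else:
            # Non-streaming: a single ServeResponse carrying the full
            # assistant messages for all samples.
            pass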