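"""
Response generation for the agent server.

Builds a generation request from an encoded prompt, streams decoded tokens
back from the model worker, and packages the results as ServeStreamResponse
chunks (streaming) or a single ServeResponse (non-streaming).

Example (illustrative only; the queue, tokenizer, config, request, prompt,
im_end_id, and device objects are assumed to be prepared by the caller):

    for chunk in generate_responses(
        input_queue, tokenizer, config, request, prompt, im_end_id, device
    ):
        ...  # forward each chunk to the client
"""
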
import time

from tools.schema import ServeMessage, ServeResponse, ServeStreamResponse
from tools.server.agent.generation_utils import (
    initialize_decode_buffers,
    process_response_tokens,
    send_reset_buffer,
)
from tools.server.agent.pre_generation_utils import (
    create_generation_request,
    send_generation_request,
)


def generate_responses(
    input_queue, tokenizer, config, request, prompt, im_end_id, device
):
    """
    Main generation entry point: encodes the request, sends it to the model
    worker, and handles decoding/streaming of the results.
    Returns a generator of ServeResponse or ServeStreamResponse objects.
    """
    stats = {}
    start = time.time()
    stats["start_time"] = start
    stats["tokens_count"] = 0

    # Prepare and send the generation request
    req = create_generation_request(prompt, request, im_end_id, device)
    response_queue = send_generation_request(input_queue, req)
    decode_buffer, parts, finished = initialize_decode_buffers(request.num_samples)

    while True:
        response = response_queue.get()

        # Handle an abnormal finish or error: the worker signals these by
        # putting the literal strings "stop" or "error" on the queue.
        if response in ["stop", "error"]:
            finish_reason = response
            break

        # Process the response tokens
        is_first_token = stats["tokens_count"] == 0
        responses = process_response_tokens(
            response,
            tokenizer,
            config,
            request,
            decode_buffer,
            parts,
            finished,
            im_end_id,
            stats,
            start,
            is_first_token,
        )

        # Yield the responses if streaming
        if request.streaming and responses:
            for r in responses:
                yield r
        stats["tokens_count"] += 1

        # Check if all samples are finished
        if all(finished):
            finish_reason = "stop"
            break

    # Finalize the response (flush buffers, attach stats, emit final chunks)
    final_responses = finalize_response(
        request, finished, decode_buffer, tokenizer, parts, stats, finish_reason
    )
    for fr in final_responses:
        yield fr


def finalize_response(
    request, finished, decode_buffer, tokenizer, parts, stats, finish_reason
):
    """
    Finalize the response: flush the remaining text buffers, compute the
    final stats, and emit the closing chunks or full messages.
    """
    responses = []

    # Send the remaining text buffers
    for sample_id in range(request.num_samples):
        responses.extend(
            send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request)
        )

    # Calculate the final stats (total_time is in milliseconds)
    stats["total_time"] = (time.time() - stats["start_time"]) * 1000
    stats["total_tokens"] = stats["tokens_count"]

    # If streaming, send the final chunk for each still-unfinished sample
    if request.streaming:
        for sample_id in range(request.num_samples):
            if finished[sample_id]:
                continue
            responses.append(
                ServeStreamResponse(
                    finish_reason=finish_reason, stats=stats, sample_id=sample_id
                )
            )
    else:
        # If not streaming, send the full messages for each sample
        full_messages = [
            ServeMessage(role="assistant", parts=parts[i])
            for i in range(request.num_samples)
        ]
        responses.append(
            ServeResponse(
                messages=full_messages,
                finish_reason=finish_reason,
                stats=stats,
            )
        )

    return responses
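

# For orientation: initialize_decode_buffers(num_samples) is consumed above as
# three parallel, per-sample containers. A minimal stand-in consistent with
# that usage (hypothetical; the real helper lives in
# tools.server.agent.generation_utils) would be:
#
#     decode_buffer = [[] for _ in range(num_samples)]  # pending decoded text
#     parts = [[] for _ in range(num_samples)]          # accumulated message parts
#     finished = [False] * num_samples                  # per-sample completion flags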