import queue

from fish_speech.conversation import Conversation, Message
from fish_speech.tokenizer import IM_END_TOKEN
from tools.llama.generate import GenerateRequest


def prepare_messages(request, tokenizer, config):
    """
    Reorganise the provided list of messages into a conversation.
    Encode the conversation for inference.
    """

    # Convert the messages to ConversationMessage objects
    messages = [msg.to_conversation_message() for msg in request.messages]

    if len(messages) < 1:
        raise ValueError("At least one message is required")

    # Check the last message to determine the next step
    last_role = messages[-1].role
    match last_role:
        case "user":
            # The last message is from the user; ask the assistant to respond
            # with a new message
            messages.append(
                Message(role="assistant", parts=[], add_im_end=False, modality="voice")
            )
        case "raw":
            # The last message is raw text; ask the assistant to complete it
            messages[-1].add_im_start = False
            messages[-1].add_im_end = False
            messages[-1].modality = "voice"
        case "assistant":
            # The last message is from the assistant; ask the assistant to continue it
            messages[-1].add_im_end = False
        case _:
            # Any role other than user, raw, or assistant is invalid here
            raise ValueError("The last message must be from the assistant, user or raw")

    # Create a conversation object and encode it for inference
    conv = Conversation(messages=messages)
    prompt = conv.encode_for_inference(
        tokenizer=tokenizer, num_codebooks=config.num_codebooks
    )
    im_end_id = tokenizer.get_token_id(IM_END_TOKEN)

    return prompt, im_end_id


def create_generation_request(prompt, request, im_end_id, device):
    """
    Convert the request into a dictionary that can be sent to the model for generation.
    """
    req = {
        "prompt": prompt.to(device),
        "max_new_tokens": request.max_new_tokens,
        "im_end_id": im_end_id,
        "temperature": request.temperature,
        "top_p": request.top_p,
        "repetition_penalty": request.repetition_penalty,
        "num_samples": request.num_samples,
        "early_stop_threshold": request.early_stop_threshold,
    }
    return req


def send_generation_request(input_queue, req):
    """
    Send the generation request to the model and return a queue to get the response.
    """
    response_queue = queue.Queue()
    input_queue.put(GenerateRequest(req, response_queue))
    return response_queue
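

# Usage sketch: the three helpers above are intended to be chained as below.
# The names `request`, `tokenizer`, `config`, `device`, and `llama_queue` are
# assumptions standing in for objects supplied by the serving process; the
# protocol for reading results back from the returned queue is defined by the
# generation worker and is not shown here.
#
#   prompt, im_end_id = prepare_messages(request, tokenizer, config)
#   req = create_generation_request(prompt, request, im_end_id, device)
#   response_queue = send_generation_request(llama_queue, req)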