import time

from tools.schema import (
    ServeStreamDelta,
    ServeStreamResponse,
    ServeTextPart,
    ServeVQPart,
)


def initialize_decode_buffers(num_samples):
    """Initialize the decode buffers for each sample."""
    decode_buffer = [[] for _ in range(num_samples)]
    parts = [[] for _ in range(num_samples)]
    finished = [False for _ in range(num_samples)]
    return decode_buffer, parts, finished


def send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request):
    """Flush the remaining text buffer for a sample and reset it."""
    if len(decode_buffer[sample_id]) == 0:
        return []

    decoded = tokenizer.decode(decode_buffer[sample_id])
    part = ServeTextPart(text=decoded)

    responses = []
    if request.streaming:
        responses.append(ServeStreamResponse(delta=ServeStreamDelta(part=part)))
    else:
        parts[sample_id].append(part)

    decode_buffer[sample_id] = []
    return responses


def handle_semantic_tokens(tokens, config, sample_id, parts, request):
    """Handle the semantic (VQ) tokens returned by the model."""
    responses = []
    # Drop the text-token row; the remaining rows hold one code per codebook.
    _tokens = tokens[1:].clone()

    if not config.share_codebook_embeddings:
        # Undo the per-codebook offset so each code lies in [0, codebook_size).
        for i in range(len(_tokens)):
            _tokens[i] -= config.codebook_size * i

    if request.streaming:
        # If streaming, send the VQ parts directly
        responses.append(
            ServeStreamResponse(
                sample_id=sample_id,
                delta=ServeStreamDelta(part=ServeVQPart(codes=_tokens.tolist())),
            )
        )
    else:
        # If not streaming, accumulate the VQ parts
        if not parts[sample_id] or not isinstance(parts[sample_id][-1], ServeVQPart):
            parts[sample_id].append(ServeVQPart(codes=_tokens.tolist()))
        else:
            # Append the new codes to the last VQ part, codebook by codebook
            for codebook_id, value in enumerate(_tokens):
                parts[sample_id][-1].codes[codebook_id].append(value.item())

    return responses


def process_response_tokens(
    response,
    tokenizer,
    config,
    request,
    decode_buffer,
    parts,
    finished,
    im_end_id,
    stats,
    start,
    is_first_token,
):
    """Process one step of response tokens returned by the model."""
    responses = []
    for sample_id, tokens in enumerate(response):
        if finished[sample_id]:
            continue

        # End of the conversation
        if tokens[0] == im_end_id:
            finished[sample_id] = True
            # Send the remaining text buffer
            responses.extend(
                send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request)
            )
            if request.streaming:
                responses.append(
                    ServeStreamResponse(
                        sample_id=sample_id,
                        finish_reason="stop",
                        stats=stats,
                    )
                )
            continue

        # Check whether the token is a semantic (VQ) token
        is_semantic = (
            tokenizer.semantic_begin_id <= tokens[0] <= tokenizer.semantic_end_id
        )

        if is_semantic:
            # Before emitting the semantic tokens, flush the remaining text buffer
            responses.extend(
                send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request)
            )
            responses.extend(
                handle_semantic_tokens(tokens, config, sample_id, parts, request)
            )
        else:
            # Plain text token: accumulate it for later batch decoding
            decode_buffer[sample_id].append(tokens[0, 0])

    if is_first_token:
        stats["time_to_first_token"] = (time.time() - start) * 1000

    return responses
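

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module).
# A minimal driver showing how the helpers above might be wired into a decode
# loop. `generate_step` is a hypothetical generator yielding one batch of
# per-sample token columns per step; `tokenizer`, `config`, `request`, and
# `im_end_id` stand in for the real server objects with the attributes used
# above.
# ---------------------------------------------------------------------------
def run_decode_loop(generate_step, tokenizer, config, request, im_end_id, num_samples):
    """Drive process_response_tokens until every sample has finished."""
    decode_buffer, parts, finished = initialize_decode_buffers(num_samples)
    stats = {}
    start = time.time()
    is_first_token = True

    for response in generate_step():  # hypothetical: yields token batches
        responses = process_response_tokens(
            response, tokenizer, config, request,
            decode_buffer, parts, finished,
            im_end_id, stats, start, is_first_token,
        )
        is_first_token = False
        # In streaming mode these are ServeStreamResponse objects to forward
        # to the client; in non-streaming mode `parts` accumulates instead.
        yield from responses
        if all(finished):
            break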