import time

from tools.schema import (
    ServeStreamDelta,
    ServeStreamResponse,
    ServeTextPart,
    ServeVQPart,
)

def initialize_decode_buffers(num_samples):
    """Initialise the decode buffers for each sample."""
    decode_buffer = [[] for _ in range(num_samples)]
    parts = [[] for _ in range(num_samples)]
    finished = [False for _ in range(num_samples)]
    return decode_buffer, parts, finished
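
# Illustrative sketch (not part of the original module): for a batch of two
# samples, every buffer starts empty and no sample is finished.
#
#     decode_buffer, parts, finished = initialize_decode_buffers(2)
#     # decode_buffer == [[], []], parts == [[], []], finished == [False, False]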

def send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request):
    """Flush the remaining text buffer for a sample, then reset it."""
    if len(decode_buffer[sample_id]) == 0:
        return []

    # Decode the buffered token ids into a single text part
    decoded = tokenizer.decode(decode_buffer[sample_id])
    part = ServeTextPart(text=decoded)

    responses = []
    if request.streaming:
        responses.append(ServeStreamResponse(delta=ServeStreamDelta(part=part)))
    else:
        parts[sample_id].append(part)

    # Reset the buffer so the next text run starts clean
    decode_buffer[sample_id] = []
    return responses
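
# Hedged sketch of the flush behaviour (the `tokenizer` and `request` objects
# and the token ids here are stand-ins for whatever the caller passes):
#
#     decode_buffer = [[token_id_1, token_id_2]]  # pending text for sample 0
#     parts = [[]]
#     responses = send_reset_buffer(0, decode_buffer, tokenizer, parts, request)
#     # streaming: `responses` holds one ServeStreamResponse with a text delta
#     # non-streaming: the ServeTextPart lands in parts[0] and `responses` is []
#     # either way, decode_buffer[0] is reset to []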

def handle_semantic_tokens(tokens, config, sample_id, parts, request):
    """Handle the semantic (VQ) tokens returned by the model."""
    responses = []

    # Drop the text-token row; the remaining rows are the codebook tokens
    _tokens = tokens[1:].clone()

    if not config.share_codebook_embeddings:
        # Each codebook owns its own id range; remove the per-codebook offset
        for i in range(len(_tokens)):
            _tokens[i] -= config.codebook_size * i

    # If streaming, send the VQ parts directly
    if request.streaming:
        responses.append(
            ServeStreamResponse(
                sample_id=sample_id,
                delta=ServeStreamDelta(part=ServeVQPart(codes=_tokens.tolist())),
            )
        )
    else:
        # If not streaming, accumulate the VQ parts
        if not parts[sample_id] or not isinstance(parts[sample_id][-1], ServeVQPart):
            parts[sample_id].append(ServeVQPart(codes=_tokens.tolist()))
        else:
            # Accumulate the codes: append one new value per codebook to the
            # last VQ part
            for codebook_id, value in enumerate(_tokens):
                parts[sample_id][-1].codes[codebook_id].append(value.item())

    return responses
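
# Offset removal with hypothetical numbers: if config.codebook_size were 1024
# and embeddings were not shared, a raw step such as [[57], [1081], [2105]]
# would map back to per-codebook codes [[57], [57], [57]], since codebook
# row i is shifted by 1024 * i in the model's combined vocabulary.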

def process_response_tokens(
    response,
    tokenizer,
    config,
    request,
    decode_buffer,
    parts,
    finished,
    im_end_id,
    stats,
    start,
    is_first_token,
):
    """Process the response tokens returned by the model."""
    responses = []
    for sample_id, tokens in enumerate(response):
        if finished[sample_id]:
            continue

        # End of the conversation
        if tokens[0] == im_end_id:
            finished[sample_id] = True
            # Send the remaining text buffer
            responses.extend(
                send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request)
            )
            if request.streaming:
                responses.append(
                    ServeStreamResponse(
                        sample_id=sample_id,
                        finish_reason="stop",
                        stats=stats,
                    )
                )
            continue

        # Check if the token is semantic
        is_semantic = (
            tokenizer.semantic_begin_id <= tokens[0] <= tokenizer.semantic_end_id
        )

        if is_semantic:
            # Before the semantic tokens, send the remaining text buffer
            responses.extend(
                send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request)
            )
            responses.extend(
                handle_semantic_tokens(tokens, config, sample_id, parts, request)
            )
        else:
            # Plain text token: accumulate it in the decode buffer until the
            # next flush (end of conversation or start of semantic tokens)
            decode_buffer[sample_id].append(tokens[0, 0])

    if is_first_token:
        stats["time_to_first_token"] = (time.time() - start) * 1000

    return responses
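
# Hedged sketch of a surrounding decode loop (the `generate_steps` iterator,
# `num_samples`, and `send` are assumptions, not defined in this module): the
# caller owns the shared state and clears `is_first_token` after the first step.
#
#     decode_buffer, parts, finished = initialize_decode_buffers(num_samples)
#     stats, start, is_first_token = {}, time.time(), True
#     for response in generate_steps:
#         for r in process_response_tokens(
#             response, tokenizer, config, request,
#             decode_buffer, parts, finished,
#             im_end_id, stats, start, is_first_token,
#         ):
#             send(r)  # hypothetical transport for streaming responses
#         is_first_token = False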