# Text-to-Audio/utils/topk_sampling.py
# yuancwang
# init
# 5548515
# raw
# history blame
# 3.83 kB
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
# This function is modified from https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
def top_k_top_p_filtering(
    logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
):
    """
    Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751).
    Adapted from: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317

    Args:
        logits (torch.Tensor): Logits with the vocabulary on the last dimension,
            typically of shape (batch size, vocabulary size). The tensor is
            modified in place.
        top_k (int, optional): Keep only top k tokens with highest probability (top-k filtering).
            Set to 0 to disable. Defaults to 0.
        top_p (float, optional): Keep the top tokens with a cumulative probability >= top_p
            (nucleus filtering). Values >= 1.0 disable nucleus filtering. Defaults to 1.0.
        filter_value (float, optional): The value to assign to filtered logits.
            Defaults to -float('Inf').
        min_tokens_to_keep (int, optional): Ensure that at least this number of tokens are
            kept per batch example. Defaults to 1.

    Returns:
        torch.Tensor: The same ``logits`` tensor, with filtered entries set to ``filter_value``.
    """
    if top_k > 0:
        # Never keep fewer than min_tokens_to_keep, nor ask topk for more
        # entries than the vocabulary holds.
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))
        # The k-th largest logit of each row is the cut-off; everything
        # strictly below it is filtered (ties with the cut-off are kept).
        kth_largest = torch.topk(logits, top_k).values[..., -1, None]
        indices_to_remove = logits < kth_largest
        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Mark tokens whose cumulative probability exceeds the top_p threshold.
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Un-mark the leading tokens; combined with the shift below this
            # keeps the first min_tokens_to_keep + 1 sorted tokens, which
            # satisfies the "at least min_tokens_to_keep" guarantee.
            sorted_indices_to_remove[..., :min_tokens_to_keep] = False
        # Shift the mask right by one so the first token that crosses the
        # threshold is kept as well, and always keep the most likely token.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # Map the mask from sorted order back to vocabulary order. The scatter
        # must run on the boolean mask itself (not on the int64 index tensor)
        # so that the result is a boolean mask usable for masked assignment.
        # sorted_indices is a permutation of each row, so every output slot is
        # overwritten. dim=-1 keeps this valid for 1-D and batched inputs.
        indices_to_remove = sorted_indices_to_remove.scatter(
            -1, sorted_indices, sorted_indices_to_remove
        )
        logits[indices_to_remove] = filter_value

    return logits
def topk_sampling(logits, top_k=50, top_p=1.0, temperature=1.0):
    """
    Perform top-k and top-p sampling on logits.

    The caller's ``logits`` tensor is left untouched: filtering is applied to
    a temperature-scaled copy.

    Args:
        logits (torch.Tensor): The logits to sample from, with the vocabulary
            on the last dimension.
        top_k (int, optional): The number of highest probability tokens to keep for
            top-k filtering. Defaults to 50.
        top_p (float, optional): The cumulative probability threshold for nucleus sampling.
            Must be between 0 and 1. Defaults to 1.0.
        temperature (float, optional): The scaling factor to adjust the logits distribution.
            Must be strictly positive. Defaults to 1.0.

    Returns:
        torch.Tensor: The sampled token indices, one sample per row
        (``num_samples=1`` is passed to ``torch.multinomial``).

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """
    if temperature <= 0:
        raise ValueError(
            f"temperature must be strictly positive, got {temperature}"
        )

    # Always divide, even for temperature == 1.0: the division yields a new
    # tensor, so the in-place filtering below never mutates the caller's
    # logits (previously it did, but only when temperature was exactly 1.0).
    scaled_logits = logits / temperature

    # Top-p/top-k filtering
    filtered_logits = top_k_top_p_filtering(scaled_logits, top_k=top_k, top_p=top_p)

    # Sample from the filtered distribution
    token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
    return token