r""" | |
Nucleus Sampling was introduced in the paper | |
`The Curious Case of Neural Text Degeneration <https://arxiv.org/abs/1904.09751>`_. | |
If you take it from here, make sure to cite them: | |
.. code-block:: text | |
@inproceedings{, | |
title={The Curious Case of Neural Text Degeneration}, | |
author={Ari Holtzman and Jan Buys and Li Du and Maxwell Forbes and Yejin Choi}, | |
journal={ICLR}, | |
year={2020} | |
} | |
Some core parts of this code are adapted with minor modifications from Thomas Wolf's | |
gist: https://gist.githubusercontent.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 | |
""" | |
from typing import Callable, Tuple

import torch
import torch.nn.functional as F

class AutoRegressiveNucleusSampling(object):
    """
    Implements nucleus sampling for decoding captions. This class only works
    for auto-regressive models (Transformer-like), not recurrent models
    (LSTM-like).

    Parameters
    ----------
    eos_index: int
        The index of the end token (``[EOS]``) in the vocabulary.
    max_steps: int, optional (default = 50)
        The maximum number of decoding steps.
    nucleus_size: float, optional (default = 0.9)
        Cumulative probability mass that constitutes the nucleus (top-p)
        to sample from.
    """

    def __init__(
        self,
        eos_index: int,
        max_steps: int = 50,
        nucleus_size: float = 0.9,
    ):
        super().__init__()
        self._eos_index = eos_index
        self.max_steps = max_steps
        self.nucleus_size = nucleus_size

    def search(
        self, start_predictions: torch.Tensor, step: Callable[..., torch.Tensor]
    ) -> Tuple[torch.Tensor, None]:
        batch_size = start_predictions.size()[0]

        # List of `(batch_size, )` tensors, one for each timestep.
        # This includes the start-of-sentence tokens, unlike the implementation
        # in `AutoregressiveBeamSearch`. We will remove them at the end.
        # Transpose `start_predictions` and make a list when a prompt is provided.
        predictions = [
            start_predictions[:, i] for i in range(start_predictions.size(1))
        ]

        for timestep in range(self.max_steps):
            # Get the predictions from the last timestep (most recent).
            # shape: (batch_size, )
            last_predictions = predictions[-1]

            # If every predicted token from the last step is the end-of-sentence
            # token, then we can stop early.
            if (last_predictions == self._eos_index).all():
                break

            # Combine step predictions made so far into one tensor. This is our
            # "partial" caption input to the transformer.
            # shape: (batch_size, timestep + 1)
            predictions_so_far = torch.stack(predictions).permute(1, 0)

            # Take a step and get the logits for the next timestep.
            # shape: (batch_size, num_classes)
            current_logits = step(predictions_so_far)

            # Sort logits in descending order to determine the nucleus.
            sorted_logits, sorted_idx = torch.sort(current_logits, descending=True)

            # Get cumulative softmax probabilities. For every instance in the
            # batch, a variable number of tokens will constitute the nucleus.
            # shape: (batch_size, num_classes)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            # Determine the (sorted) indices of tokens at the tail of the
            # distribution. These will be removed from the nucleus.
            sorted_idx_to_remove = cumulative_probs > self.nucleus_size

            # Shift the mask to the right so the first token above the threshold
            # stays inside the nucleus.
            sorted_idx_to_remove[..., 1:] = sorted_idx_to_remove[..., :-1].clone()
            sorted_idx_to_remove[..., 0] = 0
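
            # A small worked example (hypothetical numbers), with nucleus_size = 0.9:
            #   cumulative_probs:        [0.60, 0.85, 0.93, 0.99, 1.00]
            #   mask (> 0.9):            [   0,    0,    1,    1,    1]
            #   mask after right-shift:  [   0,    0,    0,    1,    1]
            # The token that pushes the cumulative mass past 0.9 survives, so
            # the nucleus always contains at least one token.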

            # Set tail logits to a large negative value to avoid sampling them.
            # Iterate over the batch of examples.
            for t in range(current_logits.size()[0]):
                idx_to_remove = sorted_idx[t][sorted_idx_to_remove[t]]
                current_logits[t][idx_to_remove] = -1e12

                # Also set the logit of the last predicted token to a large
                # negative value to avoid immediate repetition.
                current_logits[t][last_predictions[t]] = -1e12
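
            # (This per-example loop favors readability; the same masking could
            # be vectorized with `Tensor.scatter` on the sorted mask.)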

            # Sample from the filtered distribution.
            # shape: (batch_size, num_classes)
            current_probs = F.softmax(current_logits, dim=-1)

            # shape: (batch_size, 1) -> (batch_size, )
            current_predictions = torch.multinomial(current_probs, 1)
            current_predictions = current_predictions.view(batch_size)

            # Keep predicting the end-of-sentence token for instances where the
            # last prediction was also the end-of-sentence token.
            current_predictions[last_predictions == self._eos_index] = self._eos_index

            predictions.append(current_predictions)

        # Remove the start-of-sentence token from predictions and collect them
        # together.
        # shape: (batch_size, T) where T <= max_steps (less if decoding stopped early).
        all_predictions = torch.stack(predictions[1:]).permute(1, 0)

        # Unlike `AutoregressiveBeamSearch`, we do not return logprobs of the
        # generated sequence with nucleus sampling.
        return all_predictions, None
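

# A minimal usage sketch, not part of the original module. The `_dummy_step`
# callable and the token indices below are hypothetical stand-ins for a real
# auto-regressive captioning model.
if __name__ == "__main__":

    def _dummy_step(partial_predictions: torch.Tensor) -> torch.Tensor:
        # Return random logits over a toy vocabulary of 100 tokens, ignoring
        # the partial captions; a real model would condition on them (and on
        # image features).
        return torch.randn(partial_predictions.size(0), 100)

    sampler = AutoRegressiveNucleusSampling(eos_index=2, max_steps=10)

    # Start every caption in the batch with a start-of-sentence token (index 1).
    start_predictions = torch.full((4, 1), 1, dtype=torch.long)

    captions, _ = sampler.search(start_predictions, _dummy_step)
    print(captions.shape)  # (4, T) with T <= 10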