Spaces:
Runtime error
Runtime error
File size: 1,912 Bytes
12ea223 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
import logging
from typing import List
import torch
from transformers import (
LogitsProcessor,
)
class StopAfterTokenIsGenerated(LogitsProcessor):
    """Force EOS for every sequence whose tail equals one of the stop sequences.

    On each call except the first one after construction or ``reset()``, every
    row of ``input_ids`` is compared against every stop sequence. When the last
    tokens of a row equal a stop sequence, that row of ``scores`` is overwritten
    so that ``eos_token_id`` is the only token with a finite score.

    The first call is passed through untouched, so the tail of ``input_ids`` at
    that point (presumably the end of the prompt -- confirm against the caller)
    is never matched against the stop sequences.
    """

    def __init__(self, stops: List[torch.Tensor], eos_token_id: int):
        """Store the stop sequences and the token id to force on a match.

        Args:
            stops: Stop sequences, one tensor of token ids per sequence. The
                list is stored by reference and only normalised at call time.
            eos_token_id: Vocabulary index that becomes the only allowed token
                once a stop sequence has been matched.
        """
        super().__init__()
        self.stops = stops
        self.eos_token_id = eos_token_id
        logging.info("Stopping criteria words ids: %s", self.stops)
        # True until the first __call__ after construction / reset(); that call
        # returns the scores unchanged.
        self.first_batch = True

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        """Mask the scores of every row that already ends with a stop sequence.

        Args:
            input_ids: Token ids, indexed as ``(row, position)``; each row is
                compared against the stop sequences independently.
            scores: Per-row scores, indexed as ``(row, token_id)``. Matching
                rows are modified in place.

        Returns:
            The same ``scores`` tensor. Rows whose ``input_ids`` end with a stop
            sequence are set to ``-inf`` everywhere except ``eos_token_id``,
            which is set to ``0``. All other rows are left unchanged.
        """
        if self.first_batch:
            self.first_batch = False
            return scores

        stops = self._usable_stops(input_ids)
        for seq_no, seq in enumerate(input_ids):
            seq_len = len(seq)
            for stop in stops:
                stop_len = len(stop)
                if (
                    seq_len >= stop_len
                    and torch.all(stop == seq[-stop_len:]).item()
                ):
                    scores[seq_no, :] = -float("inf")
                    scores[seq_no, self.eos_token_id] = 0
                    # One matching stop sequence is enough for this row.
                    break
        return scores

    def _usable_stops(self, input_ids: torch.Tensor) -> List[torch.Tensor]:
        """Return the stop sequences ready for comparison with ``input_ids``."""
        usable = []
        for stop in self.stops:
            # Rows of input_ids share its device and dtype, so the conversion
            # is done once per call instead of once per row.
            stop = stop.to(device=input_ids.device, dtype=input_ids.dtype)
            if stop.dim() == 0:
                # A single token id stored as a 0-d tensor has no len();
                # treat it as a one-token stop sequence.
                stop = stop.unsqueeze(0)
            if len(stop) == 0:
                # An empty stop would make seq[-0:] the whole row, which either
                # fails to broadcast or "matches" via all() of an empty tensor.
                continue
            usable.append(stop)
        return usable

    def reset(self):
        """Make the next ``__call__`` pass its scores through unchanged again."""
        self.first_batch = True
|