import transformers
import torch
from typing import Optional, Tuple, Union

from transformers.modeling_outputs import Seq2SeqLMOutput
from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor
from transformers.models.whisper.tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE


class WhisperForAudioCaptioning(transformers.WhisperForConditionalGeneration):
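    """Whisper conditional-generation model adapted for audio captioning.

    Behaves like `WhisperForConditionalGeneration`, except that `generate`
    accepts `forced_ac_decoder_ids`: caption-prefix tokens that are appended
    to the special-token prompt (start token, language, task, timestamp
    marker) and forced as the beginning of the generated caption.
    """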

    def __init__(self, config):
        super().__init__(config)

    @property
    def task_mapping(self):
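        # Inverse of `named_task_mapping`: values become keys and vice versa.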
        return {v: k for k, v in self.config.task_mapping.items()}

    @property
    def named_task_mapping(self):
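        # The task mapping exactly as stored in `config.task_mapping`.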
        return self.config.task_mapping

    def forward(
        self,
        input_features: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_position_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        forced_ac_decoder_ids: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
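        # `forced_ac_decoder_ids` is accepted here (e.g. so a data collator can
        # pass it along with regular inputs) but deliberately not forwarded:
        # it is only consumed by `generate`.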
        return super().forward(
            input_features=input_features,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_position_ids=decoder_position_ids,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
            decoder_inputs_embeds=decoder_inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        forced_ac_decoder_ids: Optional[torch.Tensor] = None,
        generation_config=None,
        logits_processor=None,
        stopping_criteria=None,
        prefix_allowed_tokens_fn=None,
        synced_gpus=False,
        return_timestamps=None,
        task="transcribe",
        language="english",
        **kwargs,
    ):
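        # `forced_ac_decoder_ids` holds per-example caption-prefix tokens of
        # shape (batch_size, prefix_len); they are appended to the special-token
        # prompt and forced as the start of the generated caption.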
        if generation_config is None:
            generation_config = self.generation_config

        if return_timestamps is not None:
            if not hasattr(generation_config, "no_timestamps_token_id"):
                raise ValueError(
                    "You are trying to return timestamps, but the generation config is not properly set. "
                    "Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. "
                    "For more details on how to generate the appropriate config, refer to "
                    "https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
                )

            generation_config.return_timestamps = return_timestamps
        else:
            generation_config.return_timestamps = False

        if language is not None:
            generation_config.language = language
        if task is not None:
            generation_config.task = task

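        # Build the special-token prompt as (position, token_id) pairs:
        # position 1 = language token, position 2 = task token, optionally
        # followed by the <|notimestamps|> token.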
        forced_decoder_ids = []
        if task is not None or language is not None:
            if hasattr(generation_config, "language"):
                if generation_config.language in generation_config.lang_to_id.keys():
                    language_token = generation_config.language
                elif generation_config.language in TO_LANGUAGE_CODE.keys():
                    language_token = f"<|{TO_LANGUAGE_CODE[generation_config.language]}|>"
                else:
                    raise ValueError(
                        f"Unsupported language: {generation_config.language}. Language should be one of:"
                        f" {list(TO_LANGUAGE_CODE.keys())}."
                    )
                forced_decoder_ids.append((1, generation_config.lang_to_id[language_token]))
            else:
                forced_decoder_ids.append((1, None))

            if hasattr(generation_config, "task"):
                if generation_config.task in TASK_IDS:
                    forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task]))
                else:
                    raise ValueError(
                        f"The `{generation_config.task}` task is not supported. The task should be one of `{TASK_IDS}`."
                    )
            else:
                forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"]))

            if (
                hasattr(generation_config, "no_timestamps_token_id")
                and not generation_config.return_timestamps
            ):
                idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
                forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))

        elif (
            hasattr(self.config, "forced_decoder_ids")
            and self.config.forced_decoder_ids is not None
        ):
            forced_decoder_ids = self.config.forced_decoder_ids
        elif (
            hasattr(self.generation_config, "forced_decoder_ids")
            and self.generation_config.forced_decoder_ids is not None
        ):
            forced_decoder_ids = self.generation_config.forced_decoder_ids

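        # Timestamp generation needs Whisper's dedicated logits processor; note
        # that this replaces any `logits_processor` the caller passed in.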
        if generation_config.return_timestamps:
            logits_processor = [WhisperTimeStampLogitsProcessor(generation_config)]

        decoder_input_ids = None

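        # Materialize the (position, token_id) prompt as explicit decoder input
        # ids and append the per-example caption prefix to it.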
        if len(forced_decoder_ids) > 0:
            forced_decoder_ids.sort()
            if min(forced_decoder_ids)[0] != 0:
                forced_decoder_ids = [(0, self.config.decoder_start_token_id)] + forced_decoder_ids

            position_indices, decoder_input_ids = zip(*forced_decoder_ids)
            assert tuple(position_indices) == tuple(range(len(position_indices))), (
                "forced_decoder_ids is not a (contiguous) prefix, we can't handle that"
            )

            device = self.get_decoder().device
            if forced_ac_decoder_ids is None:
                forced_ac_decoder_ids = torch.tensor([[]], device=device, dtype=torch.long)

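            # "Fluff" is the special-token prompt; broadcast it over the batch,
            # then append each example's caption prefix.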
            batch_size = forced_ac_decoder_ids.shape[0]
            fluff_len = len(decoder_input_ids)
            decoder_input_ids = torch.tensor(decoder_input_ids, device=device, dtype=torch.long)
            decoder_input_ids = decoder_input_ids.expand((batch_size, fluff_len))
            decoder_input_ids = torch.cat([decoder_input_ids, forced_ac_decoder_ids], dim=1)

            generation_config.forced_decoder_ids = forced_decoder_ids

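        # super(WhisperPreTrainedModel, self) skips WhisperForConditionalGeneration.generate
        # in the MRO and resolves to the generic GenerationMixin.generate, presumably so
        # Whisper's own prompt handling does not re-apply forced ids on top of the
        # decoder prompt assembled above.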
        return super(transformers.WhisperPreTrainedModel, self).generate(
            inputs,
            generation_config,
            logits_processor,
            stopping_criteria,
            prefix_allowed_tokens_fn,
            synced_gpus,
            decoder_input_ids=decoder_input_ids,
            **kwargs,
        )
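

# A minimal usage sketch. The checkpoint name, the "caption: " prefix, and the
# preprocessing details below are illustrative assumptions, not part of this
# module; any Whisper-style audio-captioning checkpoint exposing this class
# would be used the same way.
#
#   checkpoint = "your/whisper-audio-captioning-checkpoint"  # hypothetical
#   processor = transformers.WhisperProcessor.from_pretrained(checkpoint)
#   model = WhisperForAudioCaptioning.from_pretrained(checkpoint)
#
#   features = processor(
#       audio_array, sampling_rate=16_000, return_tensors="pt"
#   ).input_features
#   prefix = processor.tokenizer(
#       "caption: ", add_special_tokens=False, return_tensors="pt"
#   ).input_ids
#
#   ids = model.generate(inputs=features, forced_ac_decoder_ids=prefix, max_length=100)
#   print(processor.batch_decode(ids, skip_special_tokens=True)[0])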