from dataclasses import dataclass
from typing import Any, Dict, List, Union

import numpy as np
import torch
from transformers.feature_extraction_utils import BatchFeature

from .vits_output import VitsTextEncoderOutput
#.............................................
@dataclass
class DataCollatorTTSWithPadding:
"""
Data collator that will dynamically pad the inputs received.
Args:
tokenizer ([`VitsTokenizer`])
The tokenizer used for processing the data.
feature_extractor ([`VitsFeatureExtractor`])
The tokenizer used for processing the data.
forward_attention_mask (`bool`)
Whether to return attention_mask.
"""
tokenizer: Any
feature_extractor: Any
forward_attention_mask: bool

    def pad_waveform(self, raw_speech):
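        """
        Normalise `raw_speech` to a batch of float32 waveforms and pad them.

        Accepts a single waveform or a batch (a list/tuple of arrays or a 2-D
        array), upcasts float64 input to float32, and returns the padded
        `input_features` tensor produced by `self.feature_extractor.pad`.
        """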
is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
if is_batched_numpy and len(raw_speech.shape) > 2:
raise ValueError(f"Only mono-channel audio is supported for input to {self}")
is_batched = is_batched_numpy or (
isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
)
if is_batched:
raw_speech = [np.asarray([speech], dtype=np.float32).T for speech in raw_speech]
elif not is_batched and not isinstance(raw_speech, np.ndarray):
raw_speech = np.asarray(raw_speech, dtype=np.float32)
elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64):
raw_speech = raw_speech.astype(np.float32)
# always return batch
if not is_batched:
raw_speech = [np.asarray([raw_speech]).T]
batched_speech = BatchFeature({"input_features": raw_speech})
# convert into correct format for padding
padded_inputs = self.feature_extractor.pad(
batched_speech,
padding=True,
return_attention_mask=False,
return_tensors="pt",
)["input_features"]
return padded_inputs

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
# split inputs and labels since they have to be of different lengths and need
# different padding methods
model_input_name = "input_ids"
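        # each feature's "input_ids" is stored with a leading batch dimension; index [0] drops it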
input_ids = [{model_input_name: feature[model_input_name][0]} for feature in features]
# pad input tokens
batch = self.tokenizer.pad(input_ids, return_tensors="pt", return_attention_mask=self.forward_attention_mask)
# pad waveform
waveforms = [np.array(feature["waveform"]) for feature in features]
batch["waveform"] = self.pad_waveform(waveforms)
# pad spectrogram
label_features = [np.array(feature["labels"]) for feature in features]
labels_batch = self.feature_extractor.pad(
{"input_features": [i.T for i in label_features]}, return_tensors="pt", return_attention_mask=True
)
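        # padding runs over the time axis, so transpose back to (batch, num_bins, time)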
labels = labels_batch["input_features"].transpose(1, 2)
batch["labels"] = labels
batch["labels_attention_mask"] = labels_batch["attention_mask"]
# pad mel spectrogram
mel_scaled_input_features = {
"input_features": [np.array(feature["mel_scaled_input_features"]).squeeze().T for feature in features]
}
mel_scaled_input_features = self.feature_extractor.pad(
mel_scaled_input_features, return_tensors="pt", return_attention_mask=True
)["input_features"].transpose(1, 2)
batch["mel_scaled_input_features"] = mel_scaled_input_features
batch["speaker_id"] = (
torch.tensor([feature["speaker_id"] for feature in features]) if "speaker_id" in features[0] else None
)
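        # only the first example's precomputed text-encoder output is collated,
        # which assumes a batch size of 1 (or identical outputs across the batch)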
        batch["text_encoder_output"] = VitsTextEncoderOutput(
            last_hidden_state=torch.tensor(features[0]["text_encoder_output"]["last_hidden_state"]),
            prior_means=torch.tensor(features[0]["text_encoder_output"]["prior_means"]),
            prior_log_variances=torch.tensor(features[0]["text_encoder_output"]["prior_log_variances"]),
        )
# print("DataColl ",batch.keys())
return batch
#.............................................................................................
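# Minimal usage sketch (illustrative only, not part of the training code).
# Assumptions: `VitsTokenizer` comes from transformers; `VitsFeatureExtractor`
# is the local feature extractor this repo defines; `dataset` yields dicts
# with the keys consumed by `__call__` above ("input_ids", "waveform",
# "labels", "mel_scaled_input_features", "text_encoder_output", and
# optionally "speaker_id").
#
#     from torch.utils.data import DataLoader
#     from transformers import VitsTokenizer
#
#     collator = DataCollatorTTSWithPadding(
#         tokenizer=VitsTokenizer.from_pretrained("facebook/mms-tts-eng"),
#         feature_extractor=VitsFeatureExtractor(),  # hypothetical local class
#         forward_attention_mask=True,
#     )
#     loader = DataLoader(dataset, batch_size=1, collate_fn=collator)
#     batch = next(iter(loader))  # padded tensors ready for the model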