from dataclasses import dataclass
from typing import Any, Dict, List, Union

import numpy as np
import torch
from transformers.feature_extraction_utils import BatchFeature

from .vits_output import VitsTextEncoderOutput

#.............................................
@dataclass
class DataCollatorTTSWithPadding:
    """
    Data collator that dynamically pads the inputs received.

    Args:
        tokenizer ([`VitsTokenizer`])
            The tokenizer used for processing the text inputs.
        feature_extractor ([`VitsFeatureExtractor`])
            The feature extractor used for processing the audio inputs.
        forward_attention_mask (`bool`)
            Whether to return the attention_mask for the input tokens.
    """

    tokenizer: Any
    feature_extractor: Any
    forward_attention_mask: bool
    def pad_waveform(self, raw_speech):
        """Pad a batch of mono waveforms to a common length with the feature extractor."""
        is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
        if is_batched_numpy and len(raw_speech.shape) > 2:
            raise ValueError(f"Only mono-channel audio is supported for input to {self}")
        is_batched = is_batched_numpy or (
            isinstance(raw_speech, (list, tuple)) and isinstance(raw_speech[0], (np.ndarray, tuple, list))
        )

        if is_batched:
            raw_speech = [np.asarray([speech], dtype=np.float32).T for speech in raw_speech]
        elif not is_batched and not isinstance(raw_speech, np.ndarray):
            raw_speech = np.asarray(raw_speech, dtype=np.float32)
        elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64):
            raw_speech = raw_speech.astype(np.float32)

        # always return a batch
        if not is_batched:
            raw_speech = [np.asarray([raw_speech]).T]

        batched_speech = BatchFeature({"input_features": raw_speech})

        # convert into the correct format for padding
        padded_inputs = self.feature_extractor.pad(
            batched_speech,
            padding=True,
            return_attention_mask=False,
            return_tensors="pt",
        )["input_features"]

        return padded_inputs
    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have different lengths and need
        # different padding methods
        model_input_name = "input_ids"
        input_ids = [{model_input_name: feature[model_input_name][0]} for feature in features]

        # pad input tokens
        batch = self.tokenizer.pad(input_ids, return_tensors="pt", return_attention_mask=self.forward_attention_mask)

        # pad waveforms
        waveforms = [np.array(feature["waveform"]) for feature in features]
        batch["waveform"] = self.pad_waveform(waveforms)

        # pad target spectrograms; transpose so padding is applied along the time axis
        label_features = [np.array(feature["labels"]) for feature in features]
        labels_batch = self.feature_extractor.pad(
            {"input_features": [i.T for i in label_features]}, return_tensors="pt", return_attention_mask=True
        )
        labels = labels_batch["input_features"].transpose(1, 2)
        batch["labels"] = labels
        batch["labels_attention_mask"] = labels_batch["attention_mask"]

        # pad mel spectrograms
        mel_scaled_input_features = {
            "input_features": [np.array(feature["mel_scaled_input_features"]).squeeze().T for feature in features]
        }
        mel_scaled_input_features = self.feature_extractor.pad(
            mel_scaled_input_features, return_tensors="pt", return_attention_mask=True
        )["input_features"].transpose(1, 2)
        batch["mel_scaled_input_features"] = mel_scaled_input_features

        batch["speaker_id"] = (
            torch.tensor([feature["speaker_id"] for feature in features]) if "speaker_id" in features[0] else None
        )
        # use the text-encoder output cached with the first feature
        # (note: this assumes the cached output is shared across the batch, e.g. batch size 1)
        batch["text_encoder_output"] = VitsTextEncoderOutput(
            last_hidden_state=torch.tensor(features[0]["text_encoder_output"]["last_hidden_state"]),
            prior_means=torch.tensor(features[0]["text_encoder_output"]["prior_means"]),
            prior_log_variances=torch.tensor(features[0]["text_encoder_output"]["prior_log_variances"]),
        )

        return batch
#.............................................................................................
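
# A minimal usage sketch with dummy data (illustrative only): the MMS checkpoint
# name, the dummy shapes, and the assumed text-encoder hidden size below are
# assumptions, not values taken from this repository's training pipeline.
if __name__ == "__main__":
    from transformers import VitsFeatureExtractor, VitsTokenizer

    tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-eng")
    feature_extractor = VitsFeatureExtractor.from_pretrained("facebook/mms-tts-eng")
    collator = DataCollatorTTSWithPadding(
        tokenizer=tokenizer,
        feature_extractor=feature_extractor,
        forward_attention_mask=True,
    )

    n_mels = feature_extractor.feature_size  # number of mel bins
    hidden = 192  # assumed text-encoder hidden size
    examples = []
    for text, frames, samples in [("hello", 40, 10240), ("world", 60, 15360)]:
        tokens = tokenizer(text, return_tensors="np")["input_ids"]  # (1, seq_len)
        seq_len = tokens.shape[1]
        examples.append(
            {
                "input_ids": tokens,  # the collator indexes [0] to drop the batch dim
                "waveform": np.random.randn(samples).astype(np.float32),
                "labels": np.random.randn(n_mels, frames).astype(np.float32),
                "mel_scaled_input_features": np.random.randn(n_mels, frames).astype(np.float32),
                # dummy cached text-encoder outputs; shapes are illustrative
                "text_encoder_output": {
                    "last_hidden_state": np.random.randn(1, seq_len, hidden).astype(np.float32),
                    "prior_means": np.random.randn(1, seq_len, hidden).astype(np.float32),
                    "prior_log_variances": np.random.randn(1, seq_len, hidden).astype(np.float32),
                },
            }
        )

    batch = collator(examples)
    # padded tensors share a common time axis per key; inspect the resulting shapes
    print({k: (tuple(v.shape) if torch.is_tensor(v) else type(v).__name__) for k, v in batch.items()})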