from typing import List

import torch
import torch.nn as nn
from torch import Tensor
from transformers import Wav2Vec2Processor, Wav2Vec2Model

SAMPLE_RATE = 16000


class UpstreamExpert(nn.Module):
    def __init__(self, ckpt: str = None, model_config: str = None, **kwargs):
        super().__init__()
        # ckpt / model_config are accepted for interface compatibility but
        # unused: the weights are always loaded from the HuggingFace hub.
        self.processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
        self.model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")

    def get_downsample_rates(self, key: str) -> int:
        # wav2vec 2.0's CNN feature extractor emits one frame per 320 samples
        # (20 ms at 16 kHz).
        return 320

    def forward(self, wavs: List[Tensor]):
        # Augment the batch with silence-padded copies of every waveform.
        # 7 settings in total: the original, plus silence of 1/5, 1/10, and
        # 1/20 of the waveform's length prepended ("front") or appended ("end").
        # Start from a shallow copy so the appends below do not grow the list
        # being iterated (aliasing `wavs` directly would loop forever).
        wavs_silence = list(wavs)
        for position in ("front", "end"):
            for divisor in (5, 10, 20):
                for wav in wavs:
                    silence = torch.zeros(
                        len(wav) // divisor, dtype=wav.dtype, device=wav.device
                    )
                    if position == "front":
                        wavs_silence.append(torch.cat((silence, wav)))
                    else:
                        wavs_silence.append(torch.cat((wav, silence)))
        wavs = wavs_silence

        device = wavs[0].device
        processor_outputs = self.processor(
            [wav.cpu().numpy() for wav in wavs],
            return_tensors="pt",
            sampling_rate=SAMPLE_RATE,
            padding="longest",
        )
        # The base-960h feature extractor may not return an attention mask;
        # only move it to the target device when it exists.
        attention_mask = processor_outputs.get("attention_mask", None)
        if isinstance(attention_mask, torch.Tensor):
            attention_mask = attention_mask.to(device)
        model_outputs = self.model(
            processor_outputs.input_values.to(device),
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        return {
            "last_hidden_state": model_outputs.last_hidden_state,
            "hidden_states": model_outputs.hidden_states,
        }
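

# A minimal smoke-test sketch, assuming the "facebook/wav2vec2-base-960h"
# checkpoint can be fetched from the HuggingFace hub. The two random 1-second
# "waveforms" below are made up purely for illustration; real inputs would be
# 16 kHz mono audio tensors. Since forward() expands each input into 7
# silence-padding settings, the output batch is 7x the input batch size.
if __name__ == "__main__":
    expert = UpstreamExpert()
    expert.eval()

    dummy_wavs = [torch.randn(SAMPLE_RATE), torch.randn(SAMPLE_RATE)]
    with torch.no_grad():
        outputs = expert(dummy_wavs)

    # Expected: batch of 14 (2 inputs x 7 settings) with hidden size 768 for
    # wav2vec2-base, and 13 hidden states (CNN output + 12 transformer layers).
    print(outputs["last_hidden_state"].shape)
    print(len(outputs["hidden_states"]))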