from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from fairseq2.assets.card import AssetCard
from fairseq2.data import Collater
from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
from fairseq2.data.text.text_tokenizer import TextTokenizer
from fairseq2.data.typing import StringLike
from fairseq2.generation import SequenceToTextOutput, SequenceGeneratorOptions
from fairseq2.memory import MemoryBlock
from fairseq2.typing import DataType, Device
from torch import Tensor
from enum import Enum, auto

from seamless_communication.models.inference.ngram_repeat_block_processor import (
    NGramRepeatBlockProcessor,
)
from seamless_communication.models.unity import (
    UnitTokenizer,
    UnitYGenerator,
    UnitYModel,
    load_unity_model,
    load_unity_text_tokenizer,
    load_unity_unit_tokenizer,
)
from seamless_communication.models.unity.generator import SequenceToUnitOutput
from seamless_communication.models.vocoder import load_vocoder_model, Vocoder

# from seamless_communication.models.streaming.agents import (
#     SileroVADAgent,
#     TestTimeWaitKS2TVAD,
#     TestTimeWaitKUnityV1M4T,
# )
from seamless_communication.cli.streaming.agents.tt_waitk_unity_s2t_m4t import (
    TestTimeWaitKUnityS2TM4T,
)
from seamless_communication.cli.streaming.dataloader import (
    Fairseq2SpeechToTextDataloader,
)

### From test_pipeline
import math

import numpy as np
import soundfile
from argparse import Namespace, ArgumentParser
from simuleval.data.segments import SpeechSegment, EmptySegment
from simuleval.utils import build_system_from_dir

# NOTE: `evaluate` was used below but never imported in the original.
# SimulEval exposes a programmatic entry point as `simuleval.cli.evaluate`;
# that import is assumed here -- adjust if your SimulEval version differs.
from simuleval.cli import evaluate


class AudioFrontEnd:
    """Reads a wav file and serves it in fixed-size chunks, mirroring the
    front-end logic of simuleval's instance.py."""

    def __init__(self, wav_file, segment_size) -> None:
        self.samples, self.sample_rate = soundfile.read(wav_file)
        self.samples = self.samples.tolist()
        self.segment_size = segment_size  # milliseconds
        self.step = 0

    def send_segment(self):
        num_samples = math.ceil(self.segment_size / 1000 * self.sample_rate)
        print("self.segment_size", self.segment_size)
        print("num_samples is", num_samples)
        print("self.sample_rate is", self.sample_rate)
        if self.step < len(self.samples):
            if self.step + num_samples >= len(self.samples):
                samples = self.samples[self.step :]
                is_finished = True
            else:
                samples = self.samples[self.step : self.step + num_samples]
                is_finished = False
            self.step = min(self.step + num_samples, len(self.samples))
            segment = SpeechSegment(
                index=self.step / self.sample_rate * 1000,
                content=samples,
                sample_rate=self.sample_rate,
                finished=is_finished,
            )
        else:
            # Finished reading this audio.
            segment = EmptySegment(
                index=self.step / self.sample_rate * 1000,
                finished=True,
            )
        return segment


def load_model_for_inference(
    load_model_fn: Callable[..., nn.Module],
    model_name_or_card: Union[str, AssetCard],
    device: Device,
    dtype: DataType,
) -> nn.Module:
    model = load_model_fn(model_name_or_card, device=device, dtype=dtype)
    model.eval()
    return model


def load_model_fairseq2():
    data_configs = dict(
        dataloader="fairseq2_s2t",
        data_file="/large_experiments/seamless/ust/abinesh/data/s2st50_manifests/50-10/simuleval/dev_mtedx_filt_50-10_debug.tsv",
    )
    model_configs = dict(
        model_name="seamlessM4T_v2_large",
        device="cuda:0",
        source_segment_size=320,
        waitk_lagging=7,
        fixed_pre_decision_ratio=2,
        init_target_tokens=" __eng__",
        max_len_a=0,
        max_len_b=200,
        agent_class="seamless_communication.cli.streaming.agents.tt_waitk_unity_s2t_m4t.TestTimeWaitKUnityS2TM4T",
        task="s2st",
        tgt_lang="eng",
    )
    eval_configs = dict(
        latency_metrics="StartOffset EndOffset AL",
        output=f"{TestTimeWaitKUnityS2TM4T.__name__}-wait{model_configs['waitk_lagging']}-debug",
    )
    model = TestTimeWaitKUnityS2TM4T({**data_configs, **model_configs, **eval_configs})
    print("model", model)
    evaluate(
        TestTimeWaitKUnityS2TM4T, {**data_configs, **model_configs, **eval_configs}
    )
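# A minimal, self-contained sketch of how `AudioFrontEnd` chunks audio into
# SpeechSegments, useful for testing the segmenting logic without the full
# pipeline. The synthetic sine-wave file, its path, and the 320 ms segment
# size are illustrative assumptions, not part of the original script.
def _demo_audio_frontend(tmp_wav: str = "/tmp/frontend_demo.wav") -> None:
    sample_rate = 16000
    # Write one second of a 440 Hz tone as the input "speech".
    t = np.linspace(0, 1.0, sample_rate, endpoint=False)
    soundfile.write(tmp_wav, 0.1 * np.sin(2 * np.pi * 440.0 * t), sample_rate)

    frontend = AudioFrontEnd(wav_file=tmp_wav, segment_size=320)
    while True:
        # Each call yields one 320 ms SpeechSegment; the last one (or an
        # EmptySegment on a subsequent call) carries finished=True.
        segment = frontend.send_segment()
        print(type(segment).__name__, "finished =", segment.finished)
        if segment.finished:
            break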
agent_class="seamless_communication.cli.streaming.agents.tt_waitk_unity_s2t_m4t.TestTimeWaitKUnityS2TM4T", task="s2st", tgt_lang="eng", ) eval_configs = dict( latency_metrics="StartOffset EndOffset AL", output=f"{TestTimeWaitKUnityS2TM4T.__name__}-wait{model_configs['waitk_lagging']}-debug", ) model = TestTimeWaitKUnityS2TM4T({**data_configs, **model_configs, **eval_configs}) print("model", model) evaluate( TestTimeWaitKUnityS2TM4T, {**data_configs, **model_configs, **eval_configs} ) class SimulevalTranscoder: # def __init__(self, agent, sample_rate, debug, buffer_limit): def __init__(self): # print("MDUPPES in here", SileroVADAgent, TestTimeWaitKS2TVAD) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") load_model_fairseq2() device = "cpu" print("DEVICE", device) model_name_or_card="seamlessM4T_medium" vocoder_name_or_card="vocoder_36langs" # dtype=torch.float16, # For CPU Mode need to use 32, float16 causes errors downstream dtype=dtype=torch.float32 model: UnitYModel = load_model_for_inference( load_unity_model, model_name_or_card, device, dtype ) print(model, type(model)) parser = ArgumentParser() source_segment_size = 320 # milliseconds audio_frontend = AudioFrontEnd( wav_file="/checkpoint/mduppes/samples/marta.wav", segment_size=source_segment_size, ) # mostly taken from S2S first agent: OnlineFeatureExtractorAgent defaults SHIFT_SIZE = 10 WINDOW_SIZE = 25 SAMPLE_RATE = 16000 FEATURE_DIM = 80 # args and convert to namespace so it can be accesed via . args = { "shift_size": SHIFT_SIZE, "window_size": WINDOW_SIZE, "sample_rate": audio_frontend.sample_rate, "feature_dim": 160, # from Wav2Vec2Frontend "denormalize": False, # not sure.. "global_stats": None, # default file path containing cmvn stats.. } print(args) args = Namespace(**args) pipeline = TestTimeWaitKUnityV1M4T(model, args) system_states = pipeline.build_states() print('system states:') for state in system_states: print(state, vars(state)) input_segment = np.empty(0, dtype=np.int16) segments = [] while True: speech_segment = audio_frontend.send_segment() input_segment = np.concatenate((input_segment, np.array(speech_segment.content))) # Translation happens here output_segment = pipeline.pushpop(speech_segment, system_states) print('pushpop result') print(output_segment) print('system states after pushpop:') for state in system_states: print(state, vars(state)) if output_segment.finished: segments.append(input_segment) input_segment = np.empty(0, dtype=np.int16) print("Resetting states") for state in system_states: state.reset() if speech_segment.finished: break # The VAD-segmented samples from the full input audio for i, seg in enumerate(segments): with soundfile.SoundFile( Path("/checkpoint/mduppes/samples") / f"marta_{i}.wav", mode="w+", format="WAV", samplerate=16000, channels=1, ) as f: f.seek(0, soundfile.SEEK_END) f.write(seg)