from utils import *
import datetime

from pydub import AudioSegment, effects


def normalizeAudio(file, format):
    # https://stackoverflow.com/questions/42492246/how-to-normalize-the-volume-of-an-audio-file-in-python
    rawsound = AudioSegment.from_file(file, format)
    normalizedsound = effects.normalize(rawsound)
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    output_file = f"normalized_{timestamp}.wav"
    normalizedsound.export(output_file, format="wav")
    return output_file


def mp3_to_wav(mp3_path, tag):
    # Load the MP3 file
    audio = AudioSegment.from_mp3(mp3_path)
    outfile = mp3_path.split(".")[0] + tag + ".wav"
    # Export the audio in WAV format
    audio.export(outfile, format="wav")
    return outfile


def stereo_to_mono(wav_path):
    # Load the stereo WAV file
    audio = AudioSegment.from_wav(wav_path)
    # Convert to mono
    audio_mono = audio.set_channels(1)
    # Export the audio in WAV format (overwrites the input file)
    audio_mono.export(wav_path, format="wav")
    return wav_path


def cutaudio(audiopath, start_time, end_time):
    # start_time and end_time are in milliseconds (pydub slicing convention)
    audio = AudioSegment.from_wav(audiopath)[start_time:end_time]
    exportname = str(start_time) + "_" + str(end_time) + ".wav"
    audio.export(exportname, format="wav")
    return exportname


def compose_audio(audio_files, timestamps, output_file):
    # Example usage:
    # audio_files = ["clip1.wav", "clip2.wav", "clip3.wav"]
    # timestamps = [0, 5000, 10000, 15000]  # clip1 starts at 0s, clip2 at 5s, clip3 at 10s; audio ends at 15s
    # output_file = "composed_audio.wav"
    # compose_audio(audio_files, timestamps, output_file)

    # Check if lengths are consistent
    if len(audio_files) != len(timestamps) - 1:
        raise ValueError("Number of timestamps should be one more than number of audio files")

    # Start with silence up to the first timestamp
    final_audio = AudioSegment.silent(duration=timestamps[0])

    for i, audio_file in enumerate(audio_files):
        # Load the audio clip
        clip = AudioSegment.from_wav(audio_file)  # Change this if you're using a different format

        # Calculate the amount of silence needed after the clip to fill the gap
        silence_duration = timestamps[i + 1] - timestamps[i] - len(clip)  # in milliseconds

        if silence_duration < 0:
            print(
                f"Warning: Clip {audio_file} is longer than the gap between timestamps {i} and {i + 1}. "
                "Trimming the audio."
            )
            clip = clip[: timestamps[i + 1] - timestamps[i]]  # Trim the clip
            silence_duration = 0

        final_audio += clip + AudioSegment.silent(duration=silence_duration)

    # Export final audio
    # timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    # output_file_time = f"{output_file}_{timestamp}.wav"
    final_audio.export(output_file, format="wav")
    return output_file
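
# Example sketch (assumed file names, not part of the original helpers): cut a clip
# out of a longer WAV and re-insert it at the same offset with compose_audio.
# Both helpers work in milliseconds, matching pydub's slicing convention.
def example_cut_and_recompose(wav_path, start_ms, end_ms, output_file="recomposed.wav"):
    clip = cutaudio(wav_path, start_ms, end_ms)
    # compose_audio expects one more timestamp than clips: the last value marks
    # where the composed audio should end.
    return compose_audio([clip], [start_ms, end_ms], output_file)
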
def append_wav_files(filenames, output_filename):
    # Load the first WAV file
    combined = AudioSegment.from_wav(filenames[0])

    # Load each subsequent WAV file and append to the combined segment
    for filename in filenames[1:]:
        audio = AudioSegment.from_wav(filename)
        combined += audio

    # Export the combined audio
    combined.export(output_filename, format="wav")
    return output_filename


# def generateAudio(respuesta, elabs_key):
#     user = ElevenLabsUser(elabs_key)
#     premadeVoice = user.get_voices_by_name("Rachel")[0]
#     playbackOptions = PlaybackOptions(runInBackground=False)
#     generationOptions = GenerationOptions(model_id="eleven_multilingual_v1", stability=0.3,
#                                           similarity_boost=0.7, style=0.6,  # eleven_english_v2
#                                           use_speaker_boost=True)
#     audioData, historyID = premadeVoice.generate_audio_v2(respuesta, generationOptions)
#     # generationData = premadeVoice.generate_play_audio_v2(text, PlaybackOptions(runInBackground=False),
#     #                                                      GenerationOptions(stability=0.4))
#     filename = "output.wav"
#     # Save them to disk, in ogg format (can be any format supported by SoundFile)
#     save_audio_bytes(audioData, filename, outputFormat="wav")
#     return filename


def overlay_audios(audio_paths, output_file):
    # Load all the audios
    audios = [AudioSegment.from_wav(path) for path in audio_paths]  # assuming WAV format

    # Find the length of the longest audio
    max_length = max(audio.duration_seconds for audio in audios)

    # Pad all audios to the length of the longest one
    padded_audios = [
        audio + AudioSegment.silent(duration=(max_length - audio.duration_seconds) * 1000)
        for audio in audios
    ]

    # Start with the first padded audio
    overlay_audio = padded_audios[0]

    # Overlay the rest of the audios on top
    for audio in padded_audios[1:]:
        overlay_audio = overlay_audio.overlay(audio)

    overlay_audio.export(output_file, format="wav")
    return output_file


def total_duration(audiofile):
    audiofile = Path(audiofile)
    format = audiofile.suffix.replace(".", "")
    song = AudioSegment.from_file(audiofile, format=format)
    # song = load_audio_segment(audiofile, audiofile.split(".")[-1])
    n_msecs = len(song)
    return n_msecs


###########################################################################


def separateVoiceInstrumental(audiofile):
    audiofile = Path(audiofile)
    filename = audiofile.stem
    format = audiofile.suffix.replace(".", "")

    song = AudioSegment.from_file(audiofile, format=format)
    # song = load_audio_segment(audiofile, audiofile.split(".")[-1])
    n_secs = round(len(song) / 1000)

    start_time = 0
    end_time = n_secs

    model_name, file_sources = ("htdemucs", ["vocals.mp3", "no_vocals.mp3"])
    out_path = Path("output")
    stem = "vocals"

    separator(
        tracks=[audiofile],
        out=out_path,
        model=model_name,
        shifts=1,
        overlap=0.5,
        stem=stem,
        int24=False,
        float32=False,
        clip_mode="rescale",
        mp3=True,
        mp3_bitrate=320,
        verbose=True,
        start_time=start_time,
        end_time=end_time,
    )

    instrumentalFile = f"output/htdemucs/{filename}/no_vocals.mp3"
    voiceFile = f"output/htdemucs/{filename}/vocals.mp3"

    return instrumentalFile, voiceFile
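
# Example sketch (assumed file name): split a song into its two htdemucs stems and
# overlay them again with the pydub helpers above to get an approximate remix.
# The mp3 stems are converted to WAV first because overlay_audios reads WAV input.
def example_split_and_remix(song_path, remix_path="remix.wav"):
    instrumentalFile, voiceFile = separateVoiceInstrumental(song_path)
    stems = [mp3_to_wav(instrumentalFile, "_stem"), mp3_to_wav(voiceFile, "_stem")]
    return overlay_audios(stems, remix_path)
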
################################################################################

import argparse
import sys
from pathlib import Path
from typing import List
import os

from dora.log import fatal
import torch as th

from demucs.apply import apply_model, BagOfModels
from demucs.audio import save_audio
from demucs.pretrained import get_model_from_args, ModelLoadingError
from demucs.separate import load_track


def separator(
    tracks: List[Path],
    out: Path,
    model: str,
    shifts: int,
    overlap: float,
    stem: str,
    int24: bool,
    float32: bool,
    clip_mode: str,
    mp3: bool,
    mp3_bitrate: int,
    verbose: bool,
    *args,
    **kwargs,
):
    """Separate the sources for the given tracks

    Args:
        tracks (Path): Path to tracks
        out (Path): Folder where to put extracted tracks. A subfolder with the model name will be created.
        model (str): Model name
        shifts (int): Number of random shifts for equivariant stabilization.
            Increases separation time but improves quality for Demucs. 10 was used in the original paper.
        overlap (float): Overlap
        stem (str): Only separate audio into {STEM} and no_{STEM}.
        int24 (bool): Save wav output as 24 bits wav.
        float32 (bool): Save wav output as float32 (2x bigger).
        clip_mode (str): Strategy for avoiding clipping: rescaling entire signal if necessary (rescale)
            or hard clipping (clamp).
        mp3 (bool): Convert the output wavs to mp3.
        mp3_bitrate (int): Bitrate of converted mp3.
        verbose (bool): Verbose
    """

    if os.environ.get("LIMIT_CPU", False):
        th.set_num_threads(1)
        jobs = 1
    else:
        # Number of jobs. This can increase memory usage but will be much faster when
        # multiple cores are available.
        jobs = os.cpu_count()

    if th.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    args = argparse.Namespace()
    args.tracks = tracks
    args.out = out
    args.model = model
    args.device = device
    args.shifts = shifts
    args.overlap = overlap
    args.stem = stem
    args.int24 = int24
    args.float32 = float32
    args.clip_mode = clip_mode
    args.mp3 = mp3
    args.mp3_bitrate = mp3_bitrate
    args.jobs = jobs
    args.verbose = verbose
    args.filename = "{track}/{stem}.{ext}"
    args.split = True
    args.segment = None
    args.name = model
    args.repo = None

    try:
        model = get_model_from_args(args)
    except ModelLoadingError as error:
        fatal(error.args[0])

    if args.segment is not None and args.segment < 8:
        fatal("Segment must be greater than 8.")

    if ".." in args.filename.replace("\\", "/").split("/"):
        fatal('".." must not appear in filename.')

    if isinstance(model, BagOfModels):
        print(
            f"Selected model is a bag of {len(model.models)} models. "
            "You will see that many progress bars per track."
        )
        if args.segment is not None:
            for sub in model.models:
                sub.segment = args.segment
    else:
        if args.segment is not None:
            model.segment = args.segment

    model.cpu()
    model.eval()

    if args.stem is not None and args.stem not in model.sources:
        fatal(
            'error: stem "{stem}" is not in selected model. STEM must be one of {sources}.'.format(
                stem=args.stem, sources=", ".join(model.sources)
            )
        )

    out = args.out / args.name
    out.mkdir(parents=True, exist_ok=True)
    print(f"Separated tracks will be stored in {out.resolve()}")

    for track in args.tracks:
        if not track.exists():
            print(
                f"File {track} does not exist. If the path contains spaces, "
                'please try again after surrounding the entire path with quotes "".',
                file=sys.stderr,
            )
            continue
        print(f"Separating track {track}")

        wav = load_track(track, model.audio_channels, model.samplerate)

        ref = wav.mean(0)
        wav = (wav - ref.mean()) / ref.std()
        sources = apply_model(
            model,
            wav[None],
            device=args.device,
            shifts=args.shifts,
            split=args.split,
            overlap=args.overlap,
            progress=True,
            num_workers=args.jobs,
        )[0]
        sources = sources * ref.std() + ref.mean()

        if args.mp3:
            ext = "mp3"
        else:
            ext = "wav"
        kwargs = {
            "samplerate": model.samplerate,
            "bitrate": args.mp3_bitrate,
            "clip": args.clip_mode,
            "as_float": args.float32,
            "bits_per_sample": 24 if args.int24 else 16,
        }

        if args.stem is None:
            for source, name in zip(sources, model.sources):
                stem = out / args.filename.format(
                    track=track.name.rsplit(".", 1)[0],
                    trackext=track.name.rsplit(".", 1)[-1],
                    stem=name,
                    ext=ext,
                )
                stem.parent.mkdir(parents=True, exist_ok=True)
                save_audio(source, str(stem), **kwargs)
        else:
            sources = list(sources)
            stem = out / args.filename.format(
                track=track.name.rsplit(".", 1)[0],
                trackext=track.name.rsplit(".", 1)[-1],
                stem=args.stem,
                ext=ext,
            )
            stem.parent.mkdir(parents=True, exist_ok=True)
            save_audio(sources.pop(model.sources.index(args.stem)), str(stem), **kwargs)
            # Warning: after popping the stem, the selected stem is no longer in the list 'sources'
            other_stem = th.zeros_like(sources[0])
            for i in sources:
                other_stem += i
            stem = out / args.filename.format(
                track=track.name.rsplit(".", 1)[0],
                trackext=track.name.rsplit(".", 1)[-1],
                stem="no_" + args.stem,
                ext=ext,
            )
            stem.parent.mkdir(parents=True, exist_ok=True)
            save_audio(other_stem, str(stem), **kwargs)
##############################################################################

import os
import logging

import librosa
import numpy as np
import soundfile as sf
import torch
from pydub import AudioSegment

if os.environ.get("LIMIT_CPU", False):
    torch.set_num_threads(1)


def merge_artifacts(y_mask, thres=0.05, min_range=64, fade_size=32):
    if min_range < fade_size * 2:
        raise ValueError("min_range must be >= fade_size * 2")

    idx = np.where(y_mask.min(axis=(0, 1)) > thres)[0]
    start_idx = np.insert(idx[np.where(np.diff(idx) != 1)[0] + 1], 0, idx[0])
    end_idx = np.append(idx[np.where(np.diff(idx) != 1)[0]], idx[-1])
    artifact_idx = np.where(end_idx - start_idx > min_range)[0]
    weight = np.zeros_like(y_mask)
    if len(artifact_idx) > 0:
        start_idx = start_idx[artifact_idx]
        end_idx = end_idx[artifact_idx]
        old_e = None
        for s, e in zip(start_idx, end_idx):
            if old_e is not None and s - old_e < fade_size:
                s = old_e - fade_size * 2

            if s != 0:
                weight[:, :, s : s + fade_size] = np.linspace(0, 1, fade_size)
            else:
                s -= fade_size

            if e != y_mask.shape[2]:
                weight[:, :, e - fade_size : e] = np.linspace(1, 0, fade_size)
            else:
                e += fade_size

            weight[:, :, s + fade_size : e - fade_size] = 1
            old_e = e

    v_mask = 1 - y_mask
    y_mask += weight * v_mask

    return y_mask


def make_padding(width, cropsize, offset):
    left = offset
    roi_size = cropsize - offset * 2
    if roi_size == 0:
        roi_size = cropsize
    right = roi_size - (width % roi_size) + left

    return left, right, roi_size


def wave_to_spectrogram(wave, hop_length, n_fft):
    wave_left = np.asfortranarray(wave[0])
    wave_right = np.asfortranarray(wave[1])

    spec_left = librosa.stft(wave_left, n_fft=n_fft, hop_length=hop_length)
    spec_right = librosa.stft(wave_right, n_fft=n_fft, hop_length=hop_length)
    spec = np.asfortranarray([spec_left, spec_right])

    return spec


def spectrogram_to_wave(spec, hop_length=1024):
    if spec.ndim == 2:
        wave = librosa.istft(spec, hop_length=hop_length)
    elif spec.ndim == 3:
        spec_left = np.asfortranarray(spec[0])
        spec_right = np.asfortranarray(spec[1])

        wave_left = librosa.istft(spec_left, hop_length=hop_length)
        wave_right = librosa.istft(spec_right, hop_length=hop_length)
        wave = np.asfortranarray([wave_left, wave_right])

    return wave
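
# Example sketch (assumed file name): round-trip a signal through the STFT helpers
# above. Reconstruction is only approximate at block boundaries, so this is useful
# mainly as a sanity check of the hop_length / n_fft settings.
def example_stft_roundtrip(wav_path, hop_length=1024, n_fft=2048):
    wave, sr = librosa.load(wav_path, sr=None, mono=False)
    if wave.ndim == 1:
        # wave_to_spectrogram expects two channels
        wave = np.asarray([wave, wave])
    spec = wave_to_spectrogram(wave, hop_length, n_fft)
    return spectrogram_to_wave(spec, hop_length=hop_length), sr
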
class Separator(object):
    def __init__(self, model, device, batchsize, cropsize, postprocess=False, progress_bar=None):
        self.model = model
        self.offset = model.offset
        self.device = device
        self.batchsize = batchsize
        self.cropsize = cropsize
        self.postprocess = postprocess
        self.progress_bar = progress_bar

    def _separate(self, X_mag_pad, roi_size):
        X_dataset = []
        patches = (X_mag_pad.shape[2] - 2 * self.offset) // roi_size
        for i in range(patches):
            start = i * roi_size
            X_mag_crop = X_mag_pad[:, :, start : start + self.cropsize]
            X_dataset.append(X_mag_crop)

        X_dataset = np.asarray(X_dataset)

        self.model.eval()
        with torch.no_grad():
            mask = []
            # To reduce the overhead, dataloader is not used.
            for i in range(0, patches, self.batchsize):
                X_batch = X_dataset[i : i + self.batchsize]
                X_batch = torch.from_numpy(X_batch).to(self.device)

                pred = self.model.predict_mask(X_batch)

                pred = pred.detach().cpu().numpy()
                pred = np.concatenate(pred, axis=2)
                mask.append(pred)

            mask = np.concatenate(mask, axis=2)

        return mask

    def _preprocess(self, X_spec):
        X_mag = np.abs(X_spec)
        X_phase = np.angle(X_spec)

        return X_mag, X_phase

    def _postprocess(self, mask, X_mag, X_phase):
        if self.postprocess:
            mask = merge_artifacts(mask)

        y_spec = mask * X_mag * np.exp(1.0j * X_phase)
        v_spec = (1 - mask) * X_mag * np.exp(1.0j * X_phase)

        return y_spec, v_spec

    def separate(self, X_spec):
        X_mag, X_phase = self._preprocess(X_spec)

        n_frame = X_mag.shape[2]
        pad_l, pad_r, roi_size = make_padding(n_frame, self.cropsize, self.offset)
        X_mag_pad = np.pad(X_mag, ((0, 0), (0, 0), (pad_l, pad_r)), mode="constant")
        X_mag_pad /= X_mag_pad.max()

        mask = self._separate(X_mag_pad, roi_size)
        mask = mask[:, :, :n_frame]

        y_spec, v_spec = self._postprocess(mask, X_mag, X_phase)

        return y_spec, v_spec


def load_model(pretrained_model, n_fft=2048):
    model = CascadedNet(n_fft, 32, 128)
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        model.to(device)
    # elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
    #     device = torch.device("mps")
    #     model.to(device)
    else:
        device = torch.device("cpu")
    model.load_state_dict(torch.load(pretrained_model, map_location=device))
    return model, device


def separate(
    input,
    model,
    device,
    output_dir,
    batchsize=4,
    cropsize=256,
    postprocess=False,
    hop_length=1024,
    n_fft=2048,
    sr=44100,
    progress_bar=None,
    only_no_vocals=False,
):
    X, sr = librosa.load(input, sr=sr, mono=False, dtype=np.float32, res_type="kaiser_fast")
    basename = os.path.splitext(os.path.basename(input))[0]

    if X.ndim == 1:
        # mono to stereo
        X = np.asarray([X, X])

    X_spec = wave_to_spectrogram(X, hop_length, n_fft)

    with torch.no_grad():
        sp = Separator(model, device, batchsize, cropsize, postprocess, progress_bar=progress_bar)
        y_spec, v_spec = sp.separate(X_spec)

    base_dir = f"{output_dir}/vocal_remover/{basename}"
    os.makedirs(base_dir, exist_ok=True)

    wave = spectrogram_to_wave(y_spec, hop_length=hop_length)
    try:
        sf.write(f"{base_dir}/no_vocals.mp3", wave.T, sr)
    except Exception:
        logging.error("Failed to write no_vocals.mp3, trying pydub...")
        pydub_write(wave, f"{base_dir}/no_vocals.mp3", sr)

    if only_no_vocals:
        return

    wave = spectrogram_to_wave(v_spec, hop_length=hop_length)
    try:
        sf.write(f"{base_dir}/vocals.mp3", wave.T, sr)
    except Exception:
        logging.error("Failed to write vocals.mp3, trying pydub...")
        pydub_write(wave, f"{base_dir}/vocals.mp3", sr)
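
# Example sketch (assumed checkpoint path): run the CascadedNet vocal remover end to
# end. "baseline.pth" is a placeholder for whatever pretrained weights the project
# actually ships; the returned paths mirror the layout written by separate() above.
def example_vocal_remover(audio_path, output_dir="output", checkpoint="baseline.pth"):
    model, device = load_model(checkpoint)
    separate(audio_path, model, device, output_dir, only_no_vocals=False)
    basename = os.path.splitext(os.path.basename(audio_path))[0]
    base_dir = f"{output_dir}/vocal_remover/{basename}"
    return f"{base_dir}/vocals.mp3", f"{base_dir}/no_vocals.mp3"
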
def pydub_write(wave, output_path, frame_rate, audio_format="mp3"):
    # Ensure the wave data is in the right format for pydub (mono and 16-bit depth)
    wave_16bit = (wave * 32767).astype(np.int16)

    audio_segment = AudioSegment(
        wave_16bit.tobytes(),
        frame_rate=frame_rate,
        sample_width=wave_16bit.dtype.itemsize,
        channels=1,
    )
    audio_segment.export(output_path, format=audio_format)


#####################################################################################

import torch
from torch import nn
import torch.nn.functional as F


class BaseNet(nn.Module):
    def __init__(self, nin, nout, nin_lstm, nout_lstm, dilations=((4, 2), (8, 4), (12, 6))):
        super(BaseNet, self).__init__()
        self.enc1 = Conv2DBNActiv(nin, nout, 3, 1, 1)
        self.enc2 = Encoder(nout, nout * 2, 3, 2, 1)
        self.enc3 = Encoder(nout * 2, nout * 4, 3, 2, 1)
        self.enc4 = Encoder(nout * 4, nout * 6, 3, 2, 1)
        self.enc5 = Encoder(nout * 6, nout * 8, 3, 2, 1)

        self.aspp = ASPPModule(nout * 8, nout * 8, dilations, dropout=True)

        self.dec4 = Decoder(nout * (6 + 8), nout * 6, 3, 1, 1)
        self.dec3 = Decoder(nout * (4 + 6), nout * 4, 3, 1, 1)
        self.dec2 = Decoder(nout * (2 + 4), nout * 2, 3, 1, 1)
        self.lstm_dec2 = LSTMModule(nout * 2, nin_lstm, nout_lstm)
        self.dec1 = Decoder(nout * (1 + 2) + 1, nout * 1, 3, 1, 1)

    def __call__(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)
        e4 = self.enc4(e3)
        e5 = self.enc5(e4)

        h = self.aspp(e5)

        h = self.dec4(h, e4)
        h = self.dec3(h, e3)
        h = self.dec2(h, e2)
        h = torch.cat([h, self.lstm_dec2(h)], dim=1)
        h = self.dec1(h, e1)

        return h


class CascadedNet(nn.Module):
    def __init__(self, n_fft, nout=32, nout_lstm=128):
        super(CascadedNet, self).__init__()
        self.max_bin = n_fft // 2
        self.output_bin = n_fft // 2 + 1
        self.nin_lstm = self.max_bin // 2
        self.offset = 64

        self.stg1_low_band_net = nn.Sequential(
            BaseNet(2, nout // 2, self.nin_lstm // 2, nout_lstm),
            Conv2DBNActiv(nout // 2, nout // 4, 1, 1, 0),
        )
        self.stg1_high_band_net = BaseNet(2, nout // 4, self.nin_lstm // 2, nout_lstm // 2)

        self.stg2_low_band_net = nn.Sequential(
            BaseNet(nout // 4 + 2, nout, self.nin_lstm // 2, nout_lstm),
            Conv2DBNActiv(nout, nout // 2, 1, 1, 0),
        )
        self.stg2_high_band_net = BaseNet(
            nout // 4 + 2, nout // 2, self.nin_lstm // 2, nout_lstm // 2
        )

        self.stg3_full_band_net = BaseNet(3 * nout // 4 + 2, nout, self.nin_lstm, nout_lstm)

        self.out = nn.Conv2d(nout, 2, 1, bias=False)
        self.aux_out = nn.Conv2d(3 * nout // 4, 2, 1, bias=False)

    def forward(self, x):
        x = x[:, :, : self.max_bin]

        bandw = x.size()[2] // 2
        l1_in = x[:, :, :bandw]
        h1_in = x[:, :, bandw:]
        l1 = self.stg1_low_band_net(l1_in)
        h1 = self.stg1_high_band_net(h1_in)
        aux1 = torch.cat([l1, h1], dim=2)

        l2_in = torch.cat([l1_in, l1], dim=1)
        h2_in = torch.cat([h1_in, h1], dim=1)
        l2 = self.stg2_low_band_net(l2_in)
        h2 = self.stg2_high_band_net(h2_in)
        aux2 = torch.cat([l2, h2], dim=2)

        f3_in = torch.cat([x, aux1, aux2], dim=1)
        f3 = self.stg3_full_band_net(f3_in)

        mask = torch.sigmoid(self.out(f3))
        mask = F.pad(
            input=mask,
            pad=(0, 0, 0, self.output_bin - mask.size()[2]),
            mode="replicate",
        )

        if self.training:
            aux = torch.cat([aux1, aux2], dim=1)
            aux = torch.sigmoid(self.aux_out(aux))
            aux = F.pad(
                input=aux,
                pad=(0, 0, 0, self.output_bin - aux.size()[2]),
                mode="replicate",
            )
            return mask, aux
        else:
            return mask

    def predict_mask(self, x):
        mask = self.forward(x)

        if self.offset > 0:
            mask = mask[:, :, :, self.offset : -self.offset]
            assert mask.size()[3] > 0

        return mask

    def predict(self, x):
        mask = self.forward(x)
        pred_mag = x * mask

        if self.offset > 0:
            pred_mag = pred_mag[:, :, :, self.offset : -self.offset]
            assert pred_mag.size()[3] > 0

        return pred_mag
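
# Example sketch: shape check for CascadedNet. The network consumes magnitude
# spectrogram batches of shape (N, 2, n_fft // 2, frames) and predicts a mask whose
# frequency axis is padded back to n_fft // 2 + 1 bins and whose time axis loses
# `offset` frames on each side in predict_mask().
def example_mask_shape(n_fft=2048, frames=256):
    net = CascadedNet(n_fft, 32, 128).eval()
    dummy = torch.rand(1, 2, n_fft // 2, frames)
    with torch.no_grad():
        mask = net.predict_mask(dummy)
    return mask.shape  # expected: (1, 2, n_fft // 2 + 1, frames - 2 * net.offset)
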
##############################################################################


def crop_center(h1, h2):
    h1_shape = h1.size()
    h2_shape = h2.size()

    if h1_shape[3] == h2_shape[3]:
        return h1
    elif h1_shape[3] < h2_shape[3]:
        raise ValueError("h1_shape[3] must be greater than h2_shape[3]")

    s_time = (h1_shape[3] - h2_shape[3]) // 2
    e_time = s_time + h2_shape[3]
    h1 = h1[:, :, :, s_time:e_time]

    return h1


class Conv2DBNActiv(nn.Module):
    def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU):
        super(Conv2DBNActiv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                nin,
                nout,
                kernel_size=ksize,
                stride=stride,
                padding=pad,
                dilation=dilation,
                bias=False,
            ),
            nn.BatchNorm2d(nout),
            activ(),
        )

    def __call__(self, x):
        return self.conv(x)


class Encoder(nn.Module):
    def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU):
        super(Encoder, self).__init__()
        self.conv1 = Conv2DBNActiv(nin, nout, ksize, stride, pad, activ=activ)
        self.conv2 = Conv2DBNActiv(nout, nout, ksize, 1, pad, activ=activ)

    def __call__(self, x):
        h = self.conv1(x)
        h = self.conv2(h)

        return h


class Decoder(nn.Module):
    def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False):
        super(Decoder, self).__init__()
        self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)
        self.dropout = nn.Dropout2d(0.1) if dropout else None

    def __call__(self, x, skip=None):
        x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)

        if skip is not None:
            skip = crop_center(skip, x)
            x = torch.cat([x, skip], dim=1)

        h = self.conv1(x)
        # h = self.conv2(h)

        if self.dropout is not None:
            h = self.dropout(h)

        return h


class ASPPModule(nn.Module):
    def __init__(self, nin, nout, dilations=(4, 8, 12), activ=nn.ReLU, dropout=False):
        super(ASPPModule, self).__init__()
        self.conv1 = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, None)),
            Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ),
        )
        self.conv2 = Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ)
        self.conv3 = Conv2DBNActiv(nin, nout, 3, 1, dilations[0], dilations[0], activ=activ)
        self.conv4 = Conv2DBNActiv(nin, nout, 3, 1, dilations[1], dilations[1], activ=activ)
        self.conv5 = Conv2DBNActiv(nin, nout, 3, 1, dilations[2], dilations[2], activ=activ)
        self.bottleneck = Conv2DBNActiv(nout * 5, nout, 1, 1, 0, activ=activ)
        self.dropout = nn.Dropout2d(0.1) if dropout else None

    def forward(self, x):
        _, _, h, w = x.size()
        feat1 = F.interpolate(self.conv1(x), size=(h, w), mode="bilinear", align_corners=True)
        feat2 = self.conv2(x)
        feat3 = self.conv3(x)
        feat4 = self.conv4(x)
        feat5 = self.conv5(x)
        out = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1)
        out = self.bottleneck(out)

        if self.dropout is not None:
            out = self.dropout(out)

        return out


class LSTMModule(nn.Module):
    def __init__(self, nin_conv, nin_lstm, nout_lstm):
        super(LSTMModule, self).__init__()
        self.conv = Conv2DBNActiv(nin_conv, 1, 1, 1, 0)
        self.lstm = nn.LSTM(input_size=nin_lstm, hidden_size=nout_lstm // 2, bidirectional=True)
        self.dense = nn.Sequential(
            nn.Linear(nout_lstm, nin_lstm), nn.BatchNorm1d(nin_lstm), nn.ReLU()
        )

    def forward(self, x):
        N, _, nbins, nframes = x.size()
        h = self.conv(x)[:, 0]  # N, nbins, nframes
        h = h.permute(2, 0, 1)  # nframes, N, nbins
        h, _ = self.lstm(h)
        h = self.dense(h.reshape(-1, h.size()[-1]))  # nframes * N, nbins
        h = h.reshape(nframes, N, 1, nbins)
        h = h.permute(1, 2, 3, 0)

        return h
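
# Example sketch: the Encoder halves both spectrogram axes (stride 2) and the
# Decoder's bilinear upsampling restores them, which is what lets BaseNet
# concatenate skip connections without cropping on power-of-two inputs.
def example_encoder_decoder_shapes():
    enc = Encoder(2, 8, 3, 2, 1)
    dec = Decoder(8 + 2, 4, 3, 1, 1)
    x = torch.rand(1, 2, 64, 64)
    h = enc(x)          # (1, 8, 32, 32)
    y = dec(h, skip=x)  # upsampled back to (1, 4, 64, 64) after concatenating x
    return h.shape, y.shape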