# amanu/audio.py
from utils import *
import datetime
from pydub import AudioSegment, effects
def normalizeAudio(file, format):
#https://stackoverflow.com/questions/42492246/how-to-normalize-the-volume-of-an-audio-file-in-python
rawsound = AudioSegment.from_file(file, format)
normalizedsound = effects.normalize(rawsound)
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
output_file = f"normalized_{timestamp}.wav"
normalizedsound.export(output_file, format="wav")
return output_file
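# Minimal usage sketch for normalizeAudio (illustrative only; "input.mp3" is a hypothetical path):
#   normalized_path = normalizeAudio("input.mp3", "mp3")
#   # -> writes e.g. "normalized_2024-01-01_12-00-00.wav" and returns its filename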
def mp3_to_wav(mp3_path, tag):
# Load the MP3 file
audio = AudioSegment.from_mp3(mp3_path)
outfile = mp3_path.split(".")[0] + tag +".wav"
# Export the audio in WAV format
audio.export(outfile, format="wav")
return outfile
def stereo_to_mono(wav_path):
# Load the stereo WAV file
audio = AudioSegment.from_wav(wav_path)
# Convert to mono
audio_mono = audio.set_channels(1)
# Export the audio in WAV format
audio_mono.export(wav_path, format="wav")
return wav_path
def cutaudio(audiopath, start_time, end_time):
audio = AudioSegment.from_wav(audiopath)[start_time:end_time]
exportname = str(start_time)+"_"+str(end_time)+".wav"
audio.export(exportname, format="wav")
return exportname
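# Illustrative sketch of the conversion/slicing helpers above (file names are hypothetical).
# pydub slices in milliseconds, so cutaudio expects millisecond offsets:
#   wav_path = mp3_to_wav("episode.mp3", "_full")   # -> "episode_full.wav"
#   mono_path = stereo_to_mono(wav_path)            # overwrites the file in place
#   clip_path = cutaudio(mono_path, 0, 30_000)      # first 30 s -> "0_30000.wav"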
def compose_audio(audio_files, timestamps, output_file):
# Example usage:
# audio_files = ["clip1.wav", "clip2.wav", "clip3.wav"]
# timestamps = [0, 5000, 10000, 15000]  # in ms: clip1 starts at 0 s, clip2 at 5 s, clip3 at 10 s; the composition ends at 15 s
# output_file = "composed_audio.wav"
# compose_audio(audio_files, timestamps, output_file)
# Check if lengths are consistent
if len(audio_files) != len(timestamps) - 1:
raise ValueError("Number of timestamps should be one more than number of audio files")
# Start with silence up to the first timestamp
final_audio = AudioSegment.silent(duration=timestamps[0])
for i, audio_file in enumerate(audio_files):
# Load the audio clip
clip = AudioSegment.from_wav(audio_file) # Change this if you're using a different format
# Calculate the amount of silence needed before the clip
silence_duration = (timestamps[i + 1] - timestamps[i] - len(clip) ) # in milliseconds
if silence_duration < 0:
print(f"Warning: Clip {audio_file} is longer than the gap between timestamps {i} and {i + 1}. Trimming the audio.")
clip = clip[:timestamps[i + 1] - timestamps[i]] # Trim the clip
silence_duration = 0
final_audio += clip + AudioSegment.silent(duration=silence_duration)
# Export final audio
#timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
#output_file_time = f"{output_file}_{timestamp}.wav"
final_audio.export(output_file, format="wav")
return output_file
def append_wav_files(filenames, output_filename):
# Load the first WAV file
combined = AudioSegment.from_wav(filenames[0])
# Load each subsequent WAV file and append to the combined segment
for filename in filenames[1:]:
audio = AudioSegment.from_wav(filename)
combined += audio
# Export the combined audio
combined.export(output_filename, format="wav")
return output_filename
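# Sketch of how compose_audio and append_wav_files differ (hypothetical files):
# compose_audio places clips at absolute millisecond offsets and pads the gaps with
# silence, while append_wav_files simply concatenates clips back to back.
#   compose_audio(["a.wav", "b.wav"], [0, 5000, 12000], "timeline.wav")
#   append_wav_files(["a.wav", "b.wav", "c.wav"], "joined.wav")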
# def generateAudio(respuesta, elabs_key):
# user = ElevenLabsUser(elabs_key)
# premadeVoice = user.get_voices_by_name("Rachel")[0]
# playbackOptions = PlaybackOptions(runInBackground=False)
# generationOptions = GenerationOptions(model_id="eleven_multilingual_v1", stability=0.3, similarity_boost=0.7, style=0.6, #eleven_english_v2
# use_speaker_boost=True)
# audioData, historyID = premadeVoice.generate_audio_v2(respuesta, generationOptions)
# #generationData = premadeVoice.generate_play_audio_v2(text, PlaybackOptions(runInBackground=False), GenerationOptions(stability=0.4))
# filename = "output.wav"
# #Save them to disk, in ogg format (can be any format supported by SoundFile)
# save_audio_bytes(audioData, filename, outputFormat="wav")
# return filename
def overlay_audios(audio_paths, output_file):
# Load all the audios
audios = [AudioSegment.from_wav(path) for path in audio_paths] # assuming WAV format
# Find the length of the longest audio
max_length = max(audio.duration_seconds for audio in audios)
# Pad all audios to the length of the longest one
padded_audios = [audio + AudioSegment.silent(duration=(max_length - audio.duration_seconds) * 1000) for audio in audios]
# Start with the first padded audio
overlay_audio = padded_audios[0]
# Overlay the rest of the audios on top
for audio in padded_audios[1:]:
overlay_audio = overlay_audio.overlay(audio)
overlay_audio.export(output_file, format="wav")
return output_file
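# Minimal sketch for overlay_audios (hypothetical paths): every input is padded with
# silence to the length of the longest clip and then mixed into a single track.
#   mixed = overlay_audios(["vocals.wav", "instrumental.wav"], "mix.wav")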
def total_duration(audiofile):
audiofile = Path(audiofile)
format = audiofile.suffix.replace(".","")
song = AudioSegment.from_file(audiofile, format=format)
#song = load_audio_segment(audiofile, audiofile.split(".")[-1])
n_msecs = len(song)
return n_msecs
###########################################################################
def separateVoiceInstrumental(audiofile):
audiofile = Path(audiofile)
filename = audiofile.stem
format = audiofile.suffix.replace(".","")
song = AudioSegment.from_file(audiofile, format=format)
#song = load_audio_segment(audiofile, audiofile.split(".")[-1])
n_secs = round(len(song) / 1000)
start_time = 0
end_time = n_secs
model_name, file_sources = ("htdemucs", ["vocals.mp3", "no_vocals.mp3"])
out_path = Path("output")
stem = "vocals"
separator(
tracks=[audiofile],
out=out_path,
model=model_name,
shifts=1,
overlap=0.5,
stem=stem,
int24=False,
float32=False,
clip_mode="rescale",
mp3=True,
mp3_bitrate=320,
verbose=True,
start_time=start_time,
end_time=end_time,
)
instrumentalFile = f"output/htdemucs/{filename}/no_vocals.mp3"
voiceFile = f"output/htdemucs/{filename}/vocals.mp3"
return instrumentalFile, voiceFile
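# Usage sketch for separateVoiceInstrumental, assuming the Demucs "htdemucs" weights
# are available (the returned paths follow the hard-coded output layout above):
#   instrumental, vocals = separateVoiceInstrumental("song.mp3")
#   # instrumental == "output/htdemucs/song/no_vocals.mp3"
#   # vocals       == "output/htdemucs/song/vocals.mp3"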
################################################################################
import argparse
import sys
from pathlib import Path
from typing import List
import os
from dora.log import fatal
import torch as th
from demucs.apply import apply_model, BagOfModels
from demucs.audio import save_audio
from demucs.pretrained import get_model_from_args, ModelLoadingError
from demucs.separate import load_track
def separator(
tracks: List[Path],
out: Path,
model: str,
shifts: int,
overlap: float,
stem: str,
int24: bool,
float32: bool,
clip_mode: str,
mp3: bool,
mp3_bitrate: int,
verbose: bool,
*args,
**kwargs,
):
"""Separate the sources for the given tracks
Args:
tracks (List[Path]): Paths to the input tracks.
out (Path): Folder where to put extracted tracks. A subfolder with the model name will be
created.
model (str): Model name
shifts (int): Number of random shifts for equivariant stabilization.
Increases separation time but improves quality for Demucs.
10 was used in the original paper.
overlap (float): Overlap between the splits.
stem (str): Only separate audio into {STEM} and no_{STEM}.
int24 (bool): Save wav output as 24 bits wav.
float32 (bool): Save wav output as float32 (2x bigger).
clip_mode (str): Strategy for avoiding clipping: rescaling entire signal if necessary
(rescale) or hard clipping (clamp).
mp3 (bool): Convert the output wavs to mp3.
mp3_bitrate (int): Bitrate of converted mp3.
verbose (bool): Verbose output.
"""
if os.environ.get("LIMIT_CPU", False):
th.set_num_threads(1)
jobs = 1
else:
# Number of jobs. This can increase memory usage but will be much faster when
# multiple cores are available.
jobs = os.cpu_count()
if th.cuda.is_available():
device = "cuda"
else:
device = "cpu"
args = argparse.Namespace()
args.tracks = tracks
args.out = out
args.model = model
args.device = device
args.shifts = shifts
args.overlap = overlap
args.stem = stem
args.int24 = int24
args.float32 = float32
args.clip_mode = clip_mode
args.mp3 = mp3
args.mp3_bitrate = mp3_bitrate
args.jobs = jobs
args.verbose = verbose
args.filename = "{track}/{stem}.{ext}"
args.split = True
args.segment = None
args.name = model
args.repo = None
try:
model = get_model_from_args(args)
except ModelLoadingError as error:
fatal(error.args[0])
if args.segment is not None and args.segment < 8:
fatal("Segment must greater than 8. ")
if ".." in args.filename.replace("\\", "/").split("/"):
fatal('".." must not appear in filename. ')
if isinstance(model, BagOfModels):
print(
f"Selected model is a bag of {len(model.models)} models. "
"You will see that many progress bars per track."
)
if args.segment is not None:
for sub in model.models:
sub.segment = args.segment
else:
if args.segment is not None:
model.segment = args.segment
model.cpu()
model.eval()
if args.stem is not None and args.stem not in model.sources:
fatal(
'error: stem "{stem}" is not in selected model. STEM must be one of {sources}.'.format(
stem=args.stem, sources=", ".join(model.sources)
)
)
out = args.out / args.name
out.mkdir(parents=True, exist_ok=True)
print(f"Separated tracks will be stored in {out.resolve()}")
for track in args.tracks:
if not track.exists():
print(
f"File {track} does not exist. If the path contains spaces, "
'please try again after surrounding the entire path with quotes "".',
file=sys.stderr,
)
continue
print(f"Separating track {track}")
wav = load_track(track, model.audio_channels, model.samplerate)
ref = wav.mean(0)
wav = (wav - ref.mean()) / ref.std()
sources = apply_model(
model,
wav[None],
device=args.device,
shifts=args.shifts,
split=args.split,
overlap=args.overlap,
progress=True,
num_workers=args.jobs,
)[0]
sources = sources * ref.std() + ref.mean()
if args.mp3:
ext = "mp3"
else:
ext = "wav"
kwargs = {
"samplerate": model.samplerate,
"bitrate": args.mp3_bitrate,
"clip": args.clip_mode,
"as_float": args.float32,
"bits_per_sample": 24 if args.int24 else 16,
}
if args.stem is None:
for source, name in zip(sources, model.sources):
stem = out / args.filename.format(
track=track.name.rsplit(".", 1)[0],
trackext=track.name.rsplit(".", 1)[-1],
stem=name,
ext=ext,
)
stem.parent.mkdir(parents=True, exist_ok=True)
save_audio(source, str(stem), **kwargs)
else:
sources = list(sources)
stem = out / args.filename.format(
track=track.name.rsplit(".", 1)[0],
trackext=track.name.rsplit(".", 1)[-1],
stem=args.stem,
ext=ext,
)
stem.parent.mkdir(parents=True, exist_ok=True)
save_audio(sources.pop(model.sources.index(args.stem)), str(stem), **kwargs)
# Warning: after popping the stem, the selected stem is no longer in the list 'sources'
other_stem = th.zeros_like(sources[0])
for i in sources:
other_stem += i
stem = out / args.filename.format(
track=track.name.rsplit(".", 1)[0],
trackext=track.name.rsplit(".", 1)[-1],
stem="no_" + args.stem,
ext=ext,
)
stem.parent.mkdir(parents=True, exist_ok=True)
save_audio(other_stem, str(stem), **kwargs)
##############################################################################
import os
import logging
import librosa
import numpy as np
import soundfile as sf
import torch
from pydub import AudioSegment
if os.environ.get("LIMIT_CPU", False):
torch.set_num_threads(1)
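# Note on merge_artifacts below (descriptive comment, not from the original code):
# y_mask is expected as a (channels, bins, frames) array. Frame ranges whose minimum
# mask value stays above `thres` for longer than `min_range` frames are pushed toward 1,
# with a `fade_size`-frame linear ramp at each boundary to avoid hard edges.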
def merge_artifacts(y_mask, thres=0.05, min_range=64, fade_size=32):
if min_range < fade_size * 2:
raise ValueError("min_range must be >= fade_size * 2")
idx = np.where(y_mask.min(axis=(0, 1)) > thres)[0]
start_idx = np.insert(idx[np.where(np.diff(idx) != 1)[0] + 1], 0, idx[0])
end_idx = np.append(idx[np.where(np.diff(idx) != 1)[0]], idx[-1])
artifact_idx = np.where(end_idx - start_idx > min_range)[0]
weight = np.zeros_like(y_mask)
if len(artifact_idx) > 0:
start_idx = start_idx[artifact_idx]
end_idx = end_idx[artifact_idx]
old_e = None
for s, e in zip(start_idx, end_idx):
if old_e is not None and s - old_e < fade_size:
s = old_e - fade_size * 2
if s != 0:
weight[:, :, s : s + fade_size] = np.linspace(0, 1, fade_size)
else:
s -= fade_size
if e != y_mask.shape[2]:
weight[:, :, e - fade_size : e] = np.linspace(1, 0, fade_size)
else:
e += fade_size
weight[:, :, s + fade_size : e - fade_size] = 1
old_e = e
v_mask = 1 - y_mask
y_mask += weight * v_mask
return y_mask
def make_padding(width, cropsize, offset):
left = offset
roi_size = cropsize - offset * 2
if roi_size == 0:
roi_size = cropsize
right = roi_size - (width % roi_size) + left
return left, right, roi_size
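# Worked example for make_padding with the defaults used further down
# (cropsize=256, offset=64): for a spectrogram 1000 frames wide,
#   left = 64, roi_size = 256 - 128 = 128, right = 128 - (1000 % 128) + 64 = 88,
# so the padded width (64 + 1000 + 88 = 1152) is a whole number of roi_size windows
# plus the two offset margins.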
def wave_to_spectrogram(wave, hop_length, n_fft):
wave_left = np.asfortranarray(wave[0])
wave_right = np.asfortranarray(wave[1])
spec_left = librosa.stft(wave_left, n_fft=n_fft, hop_length=hop_length)
spec_right = librosa.stft(wave_right, n_fft=n_fft, hop_length=hop_length)
spec = np.asfortranarray([spec_left, spec_right])
return spec
def spectrogram_to_wave(spec, hop_length=1024):
if spec.ndim == 2:
wave = librosa.istft(spec, hop_length=hop_length)
elif spec.ndim == 3:
spec_left = np.asfortranarray(spec[0])
spec_right = np.asfortranarray(spec[1])
wave_left = librosa.istft(spec_left, hop_length=hop_length)
wave_right = librosa.istft(spec_right, hop_length=hop_length)
wave = np.asfortranarray([wave_left, wave_right])
return wave
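# Round-trip sketch for the two helpers above (illustrative; "input.wav" is a
# hypothetical file). librosa.load with mono=False yields a (2, n_samples) array for
# stereo input, which is the layout wave_to_spectrogram expects:
#   y, sr = librosa.load("input.wav", sr=44100, mono=False)
#   spec = wave_to_spectrogram(y, hop_length=1024, n_fft=2048)
#   y_rec = spectrogram_to_wave(spec, hop_length=1024)  # approximate reconstruction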
class Separator(object):
def __init__(self, model, device, batchsize, cropsize, postprocess=False, progress_bar=None):
self.model = model
self.offset = model.offset
self.device = device
self.batchsize = batchsize
self.cropsize = cropsize
self.postprocess = postprocess
self.progress_bar = progress_bar
def _separate(self, X_mag_pad, roi_size):
X_dataset = []
patches = (X_mag_pad.shape[2] - 2 * self.offset) // roi_size
for i in range(patches):
start = i * roi_size
X_mag_crop = X_mag_pad[:, :, start : start + self.cropsize]
X_dataset.append(X_mag_crop)
X_dataset = np.asarray(X_dataset)
self.model.eval()
with torch.no_grad():
mask = []
# To reduce overhead, a DataLoader is not used here.
for i in range(0, patches, self.batchsize):
X_batch = X_dataset[i : i + self.batchsize]
X_batch = torch.from_numpy(X_batch).to(self.device)
pred = self.model.predict_mask(X_batch)
pred = pred.detach().cpu().numpy()
pred = np.concatenate(pred, axis=2)
mask.append(pred)
mask = np.concatenate(mask, axis=2)
return mask
def _preprocess(self, X_spec):
X_mag = np.abs(X_spec)
X_phase = np.angle(X_spec)
return X_mag, X_phase
def _postprocess(self, mask, X_mag, X_phase):
if self.postprocess:
mask = merge_artifacts(mask)
y_spec = mask * X_mag * np.exp(1.0j * X_phase)
v_spec = (1 - mask) * X_mag * np.exp(1.0j * X_phase)
return y_spec, v_spec
def separate(self, X_spec):
X_mag, X_phase = self._preprocess(X_spec)
n_frame = X_mag.shape[2]
pad_l, pad_r, roi_size = make_padding(n_frame, self.cropsize, self.offset)
X_mag_pad = np.pad(X_mag, ((0, 0), (0, 0), (pad_l, pad_r)), mode="constant")
X_mag_pad /= X_mag_pad.max()
mask = self._separate(X_mag_pad, roi_size)
mask = mask[:, :, :n_frame]
y_spec, v_spec = self._postprocess(mask, X_mag, X_phase)
return y_spec, v_spec
def load_model(pretrained_model, n_fft=2048):
model = CascadedNet(n_fft, 32, 128)
if torch.cuda.is_available():
device = torch.device("cuda:0")
model.to(device)
# elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
# device = torch.device("mps")
# model.to(device)
else:
device = torch.device("cpu")
model.load_state_dict(torch.load(pretrained_model, map_location=device))
return model, device
def separate(
input,
model,
device,
output_dir,
batchsize=4,
cropsize=256,
postprocess=False,
hop_length=1024,
n_fft=2048,
sr=44100,
progress_bar=None,
only_no_vocals=False,
):
X, sr = librosa.load(input, sr=sr, mono=False, dtype=np.float32, res_type="kaiser_fast")
basename = os.path.splitext(os.path.basename(input))[0]
if X.ndim == 1:
# mono to stereo
X = np.asarray([X, X])
X_spec = wave_to_spectrogram(X, hop_length, n_fft)
with torch.no_grad():
sp = Separator(model, device, batchsize, cropsize, postprocess, progress_bar=progress_bar)
y_spec, v_spec = sp.separate(X_spec)
base_dir = f"{output_dir}/vocal_remover/{basename}"
os.makedirs(base_dir, exist_ok=True)
wave = spectrogram_to_wave(y_spec, hop_length=hop_length)
try:
sf.write(f"{base_dir}/no_vocals.mp3", wave.T, sr)
except Exception:
logging.error("Failed to write no_vocals.mp3, trying pydub...")
pydub_write(wave, f"{base_dir}/no_vocals.mp3", sr)
if only_no_vocals:
return
wave = spectrogram_to_wave(v_spec, hop_length=hop_length)
try:
sf.write(f"{base_dir}/vocals.mp3", wave.T, sr)
except Exception:
logging.error("Failed to write vocals.mp3, trying pydub...")
pydub_write(wave, f"{base_dir}/vocals.mp3", sr)
def pydub_write(wave, output_path, frame_rate, audio_format="mp3"):
# Convert the float samples to 16-bit PCM for pydub (exported as a single channel)
wave_16bit = (wave * 32767).astype(np.int16)
audio_segment = AudioSegment(
wave_16bit.tobytes(),
frame_rate=frame_rate,
sample_width=wave_16bit.dtype.itemsize,
channels=1,
)
audio_segment.export(output_path, format=audio_format)
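# End-to-end sketch for the vocal-remover path above (hypothetical file names;
# "baseline.pth" stands in for whichever pretrained checkpoint the app ships with):
#   model, device = load_model("baseline.pth", n_fft=2048)
#   separate("song.wav", model, device, output_dir="output")
#   # -> output/vocal_remover/song/no_vocals.mp3 and output/vocal_remover/song/vocals.mp3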
#####################################################################################
import torch
from torch import nn
import torch.nn.functional as F
class BaseNet(nn.Module):
def __init__(self, nin, nout, nin_lstm, nout_lstm, dilations=((4, 2), (8, 4), (12, 6))):
super(BaseNet, self).__init__()
self.enc1 = Conv2DBNActiv(nin, nout, 3, 1, 1)
self.enc2 = Encoder(nout, nout * 2, 3, 2, 1)
self.enc3 = Encoder(nout * 2, nout * 4, 3, 2, 1)
self.enc4 = Encoder(nout * 4, nout * 6, 3, 2, 1)
self.enc5 = Encoder(nout * 6, nout * 8, 3, 2, 1)
self.aspp = ASPPModule(nout * 8, nout * 8, dilations, dropout=True)
self.dec4 = Decoder(nout * (6 + 8), nout * 6, 3, 1, 1)
self.dec3 = Decoder(nout * (4 + 6), nout * 4, 3, 1, 1)
self.dec2 = Decoder(nout * (2 + 4), nout * 2, 3, 1, 1)
self.lstm_dec2 = LSTMModule(nout * 2, nin_lstm, nout_lstm)
self.dec1 = Decoder(nout * (1 + 2) + 1, nout * 1, 3, 1, 1)
def __call__(self, x):
e1 = self.enc1(x)
e2 = self.enc2(e1)
e3 = self.enc3(e2)
e4 = self.enc4(e3)
e5 = self.enc5(e4)
h = self.aspp(e5)
h = self.dec4(h, e4)
h = self.dec3(h, e3)
h = self.dec2(h, e2)
h = torch.cat([h, self.lstm_dec2(h)], dim=1)
h = self.dec1(h, e1)
return h
class CascadedNet(nn.Module):
def __init__(self, n_fft, nout=32, nout_lstm=128):
super(CascadedNet, self).__init__()
self.max_bin = n_fft // 2
self.output_bin = n_fft // 2 + 1
self.nin_lstm = self.max_bin // 2
self.offset = 64
self.stg1_low_band_net = nn.Sequential(
BaseNet(2, nout // 2, self.nin_lstm // 2, nout_lstm),
Conv2DBNActiv(nout // 2, nout // 4, 1, 1, 0),
)
self.stg1_high_band_net = BaseNet(2, nout // 4, self.nin_lstm // 2, nout_lstm // 2)
self.stg2_low_band_net = nn.Sequential(
BaseNet(nout // 4 + 2, nout, self.nin_lstm // 2, nout_lstm),
Conv2DBNActiv(nout, nout // 2, 1, 1, 0),
)
self.stg2_high_band_net = BaseNet(
nout // 4 + 2, nout // 2, self.nin_lstm // 2, nout_lstm // 2
)
self.stg3_full_band_net = BaseNet(3 * nout // 4 + 2, nout, self.nin_lstm, nout_lstm)
self.out = nn.Conv2d(nout, 2, 1, bias=False)
self.aux_out = nn.Conv2d(3 * nout // 4, 2, 1, bias=False)
def forward(self, x):
x = x[:, :, : self.max_bin]
bandw = x.size()[2] // 2
l1_in = x[:, :, :bandw]
h1_in = x[:, :, bandw:]
l1 = self.stg1_low_band_net(l1_in)
h1 = self.stg1_high_band_net(h1_in)
aux1 = torch.cat([l1, h1], dim=2)
l2_in = torch.cat([l1_in, l1], dim=1)
h2_in = torch.cat([h1_in, h1], dim=1)
l2 = self.stg2_low_band_net(l2_in)
h2 = self.stg2_high_band_net(h2_in)
aux2 = torch.cat([l2, h2], dim=2)
f3_in = torch.cat([x, aux1, aux2], dim=1)
f3 = self.stg3_full_band_net(f3_in)
mask = torch.sigmoid(self.out(f3))
mask = F.pad(
input=mask,
pad=(0, 0, 0, self.output_bin - mask.size()[2]),
mode="replicate",
)
if self.training:
aux = torch.cat([aux1, aux2], dim=1)
aux = torch.sigmoid(self.aux_out(aux))
aux = F.pad(
input=aux,
pad=(0, 0, 0, self.output_bin - aux.size()[2]),
mode="replicate",
)
return mask, aux
else:
return mask
def predict_mask(self, x):
mask = self.forward(x)
if self.offset > 0:
mask = mask[:, :, :, self.offset : -self.offset]
assert mask.size()[3] > 0
return mask
def predict(self, x):
mask = self.forward(x)
pred_mag = x * mask
if self.offset > 0:
pred_mag = pred_mag[:, :, :, self.offset : -self.offset]
assert pred_mag.size()[3] > 0
return pred_mag
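# Shape note for CascadedNet (a descriptive sketch, not part of the original code):
# predict_mask consumes a batch of 2-channel magnitude spectrograms of shape
# (N, 2, n_fft // 2 + 1, n_frames); forward uses only the first n_fft // 2 bins and
# the returned mask is trimmed by `offset` frames on each side of the time axis.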
##############################################################################
def crop_center(h1, h2):
h1_shape = h1.size()
h2_shape = h2.size()
if h1_shape[3] == h2_shape[3]:
return h1
elif h1_shape[3] < h2_shape[3]:
raise ValueError("h1_shape[3] must be greater than h2_shape[3]")
s_time = (h1_shape[3] - h2_shape[3]) // 2
e_time = s_time + h2_shape[3]
h1 = h1[:, :, :, s_time:e_time]
return h1
class Conv2DBNActiv(nn.Module):
def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU):
super(Conv2DBNActiv, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(
nin,
nout,
kernel_size=ksize,
stride=stride,
padding=pad,
dilation=dilation,
bias=False,
),
nn.BatchNorm2d(nout),
activ(),
)
def __call__(self, x):
return self.conv(x)
class Encoder(nn.Module):
def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU):
super(Encoder, self).__init__()
self.conv1 = Conv2DBNActiv(nin, nout, ksize, stride, pad, activ=activ)
self.conv2 = Conv2DBNActiv(nout, nout, ksize, 1, pad, activ=activ)
def __call__(self, x):
h = self.conv1(x)
h = self.conv2(h)
return h
class Decoder(nn.Module):
def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False):
super(Decoder, self).__init__()
self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)
self.dropout = nn.Dropout2d(0.1) if dropout else None
def __call__(self, x, skip=None):
x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)
if skip is not None:
skip = crop_center(skip, x)
x = torch.cat([x, skip], dim=1)
h = self.conv1(x)
# h = self.conv2(h)
if self.dropout is not None:
h = self.dropout(h)
return h
class ASPPModule(nn.Module):
def __init__(self, nin, nout, dilations=(4, 8, 12), activ=nn.ReLU, dropout=False):
super(ASPPModule, self).__init__()
self.conv1 = nn.Sequential(
nn.AdaptiveAvgPool2d((1, None)),
Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ),
)
self.conv2 = Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ)
self.conv3 = Conv2DBNActiv(nin, nout, 3, 1, dilations[0], dilations[0], activ=activ)
self.conv4 = Conv2DBNActiv(nin, nout, 3, 1, dilations[1], dilations[1], activ=activ)
self.conv5 = Conv2DBNActiv(nin, nout, 3, 1, dilations[2], dilations[2], activ=activ)
self.bottleneck = Conv2DBNActiv(nout * 5, nout, 1, 1, 0, activ=activ)
self.dropout = nn.Dropout2d(0.1) if dropout else None
def forward(self, x):
_, _, h, w = x.size()
feat1 = F.interpolate(self.conv1(x), size=(h, w), mode="bilinear", align_corners=True)
feat2 = self.conv2(x)
feat3 = self.conv3(x)
feat4 = self.conv4(x)
feat5 = self.conv5(x)
out = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1)
out = self.bottleneck(out)
if self.dropout is not None:
out = self.dropout(out)
return out
class LSTMModule(nn.Module):
def __init__(self, nin_conv, nin_lstm, nout_lstm):
super(LSTMModule, self).__init__()
self.conv = Conv2DBNActiv(nin_conv, 1, 1, 1, 0)
self.lstm = nn.LSTM(input_size=nin_lstm, hidden_size=nout_lstm // 2, bidirectional=True)
self.dense = nn.Sequential(
nn.Linear(nout_lstm, nin_lstm), nn.BatchNorm1d(nin_lstm), nn.ReLU()
)
def forward(self, x):
N, _, nbins, nframes = x.size()
h = self.conv(x)[:, 0] # N, nbins, nframes
h = h.permute(2, 0, 1) # nframes, N, nbins
h, _ = self.lstm(h)
h = self.dense(h.reshape(-1, h.size()[-1])) # nframes * N, nbins
h = h.reshape(nframes, N, 1, nbins)
h = h.permute(1, 2, 3, 0)
return h