from utils import *
import datetime
from pydub import AudioSegment, effects


def normalizeAudio(file, format):
    # https://stackoverflow.com/questions/42492246/how-to-normalize-the-volume-of-an-audio-file-in-python
    rawsound = AudioSegment.from_file(file, format)
    normalizedsound = effects.normalize(rawsound)
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    output_file = f"normalized_{timestamp}.wav"
    normalizedsound.export(output_file, format="wav")
    return output_file


def mp3_to_wav(mp3_path, tag):
    # Load the MP3 file
    audio = AudioSegment.from_mp3(mp3_path)
    outfile = mp3_path.split(".")[0] + tag + ".wav"
    # Export the audio in WAV format
    audio.export(outfile, format="wav")
    return outfile


def stereo_to_mono(wav_path):
    # Load the stereo WAV file
    audio = AudioSegment.from_wav(wav_path)
    # Convert to mono
    audio_mono = audio.set_channels(1)
    # Export the audio in WAV format
    audio_mono.export(wav_path, format="wav")
    return wav_path


def cutaudio(audiopath, start_time, end_time):
    audio = AudioSegment.from_wav(audiopath)[start_time:end_time]
    exportname = str(start_time) + "_" + str(end_time) + ".wav"
    audio.export(exportname, format="wav")
    return exportname


def compose_audio(audio_files, timestamps, output_file):
    # Example usage:
    # audio_files = ["clip1.wav", "clip2.wav", "clip3.wav"]
    # timestamps = [0, 5000, 10000, 15000]  # clip1 starts at 0 s, clip2 at 5 s, clip3 at 10 s; audio ends at 15 s (values in ms)
    # output_file = "composed_audio.wav"
    # compose_audio(audio_files, timestamps, output_file)
    # Check that lengths are consistent: one more timestamp than audio files
    if len(audio_files) != len(timestamps) - 1:
        raise ValueError("Number of timestamps should be one more than the number of audio files")
    # Start with silence up to the first timestamp
    final_audio = AudioSegment.silent(duration=timestamps[0])
    for i, audio_file in enumerate(audio_files):
        # Load the audio clip
        clip = AudioSegment.from_wav(audio_file)  # Change this if you're using a different format
        # Calculate the amount of silence needed after the clip, in milliseconds
        silence_duration = timestamps[i + 1] - timestamps[i] - len(clip)
        if silence_duration < 0:
            print(f"Warning: Clip {audio_file} is longer than the gap between timestamps {i} and {i + 1}. Trimming the audio.")
            clip = clip[:timestamps[i + 1] - timestamps[i]]  # Trim the clip
            silence_duration = 0
        final_audio += clip + AudioSegment.silent(duration=silence_duration)
    # Export the final audio
    # timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    # output_file_time = f"{output_file}_{timestamp}.wav"
    final_audio.export(output_file, format="wav")
    return output_file


def append_wav_files(filenames, output_filename):
    # Load the first WAV file
    combined = AudioSegment.from_wav(filenames[0])
    # Load each subsequent WAV file and append it to the combined segment
    for filename in filenames[1:]:
        audio = AudioSegment.from_wav(filename)
        combined += audio
    # Export the combined audio
    combined.export(output_filename, format="wav")
    return output_filename


# def generateAudio(respuesta, elabs_key):
#     user = ElevenLabsUser(elabs_key)
#     premadeVoice = user.get_voices_by_name("Rachel")[0]
#     playbackOptions = PlaybackOptions(runInBackground=False)
#     generationOptions = GenerationOptions(model_id="eleven_multilingual_v1", stability=0.3, similarity_boost=0.7, style=0.6,  # eleven_english_v2
#                                           use_speaker_boost=True)
#     audioData, historyID = premadeVoice.generate_audio_v2(respuesta, generationOptions)
#     # generationData = premadeVoice.generate_play_audio_v2(text, PlaybackOptions(runInBackground=False), GenerationOptions(stability=0.4))
#     filename = "output.wav"
#     # Save them to disk, in ogg format (can be any format supported by SoundFile)
#     save_audio_bytes(audioData, filename, outputFormat="wav")
#     return filename


def overlay_audios(audio_paths, output_file):
    # Load all the audios (assuming WAV format)
    audios = [AudioSegment.from_wav(path) for path in audio_paths]
    # Find the length of the longest audio, in seconds
    max_length = max(audio.duration_seconds for audio in audios)
    # Pad all audios to the length of the longest one
    padded_audios = [
        audio + AudioSegment.silent(duration=(max_length - audio.duration_seconds) * 1000)
        for audio in audios
    ]
    # Start with the first padded audio
    overlay_audio = padded_audios[0]
    # Overlay the rest of the audios on top
    for audio in padded_audios[1:]:
        overlay_audio = overlay_audio.overlay(audio)
    overlay_audio.export(output_file, format="wav")
    return output_file


def total_duration(audiofile):
    audiofile = Path(audiofile)
    format = audiofile.suffix.replace(".", "")
    song = AudioSegment.from_file(audiofile, format=format)
    # song = load_audio_segment(audiofile, audiofile.split(".")[-1])
    n_msecs = len(song)
    return n_msecs


###########################################################################
def separateVoiceInstrumental(audiofile):
    audiofile = Path(audiofile)
    filename = audiofile.stem
    format = audiofile.suffix.replace(".", "")
    song = AudioSegment.from_file(audiofile, format=format)
    # song = load_audio_segment(audiofile, audiofile.split(".")[-1])
    n_secs = round(len(song) / 1000)
    start_time = 0
    end_time = n_secs
    model_name, file_sources = ("htdemucs", ["vocals.mp3", "no_vocals.mp3"])
    out_path = Path("output")
    stem = "vocals"
    separator(
        tracks=[audiofile],
        out=out_path,
        model=model_name,
        shifts=1,
        overlap=0.5,
        stem=stem,
        int24=False,
        float32=False,
        clip_mode="rescale",
        mp3=True,
        mp3_bitrate=320,
        verbose=True,
        start_time=start_time,
        end_time=end_time,
    )
    instrumentalFile = f"output/htdemucs/{filename}/no_vocals.mp3"
    voiceFile = f"output/htdemucs/{filename}/vocals.mp3"
    return instrumentalFile, voiceFile
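
# Example usage (a minimal sketch; "song.mp3" is a hypothetical input file that must exist
# on disk, and the htdemucs weights are fetched by demucs on first use):
# instrumental_path, vocals_path = separateVoiceInstrumental("song.mp3")
# print(total_duration(instrumental_path), "ms of instrumental audio")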


################################################################################
import argparse
import sys
from pathlib import Path
from typing import List
import os

from dora.log import fatal
import torch as th

from demucs.apply import apply_model, BagOfModels
from demucs.audio import save_audio
from demucs.pretrained import get_model_from_args, ModelLoadingError
from demucs.separate import load_track


def separator(
    tracks: List[Path],
    out: Path,
    model: str,
    shifts: int,
    overlap: float,
    stem: str,
    int24: bool,
    float32: bool,
    clip_mode: str,
    mp3: bool,
    mp3_bitrate: int,
    verbose: bool,
    *args,
    **kwargs,
):
    """Separate the sources for the given tracks.

    Args:
        tracks (Path): Path to tracks.
        out (Path): Folder where to put extracted tracks. A subfolder with the model name will be
            created.
        model (str): Model name.
        shifts (int): Number of random shifts for equivariant stabilization.
            Increases separation time but improves quality for Demucs.
            10 was used in the original paper.
        overlap (float): Overlap between splits.
        stem (str): Only separate audio into {STEM} and no_{STEM}.
        int24 (bool): Save wav output as 24-bit wav.
        float32 (bool): Save wav output as float32 (2x bigger).
        clip_mode (str): Strategy for avoiding clipping: rescaling the entire signal if necessary
            (rescale) or hard clipping (clamp).
        mp3 (bool): Convert the output wavs to mp3.
        mp3_bitrate (int): Bitrate of converted mp3.
        verbose (bool): Verbose output.
    """
    if os.environ.get("LIMIT_CPU", False):
        th.set_num_threads(1)
        jobs = 1
    else:
        # Number of jobs. This can increase memory usage but will be much faster when
        # multiple cores are available.
        jobs = os.cpu_count()
    if th.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    args = argparse.Namespace()
    args.tracks = tracks
    args.out = out
    args.model = model
    args.device = device
    args.shifts = shifts
    args.overlap = overlap
    args.stem = stem
    args.int24 = int24
    args.float32 = float32
    args.clip_mode = clip_mode
    args.mp3 = mp3
    args.mp3_bitrate = mp3_bitrate
    args.jobs = jobs
    args.verbose = verbose
    args.filename = "{track}/{stem}.{ext}"
    args.split = True
    args.segment = None
    args.name = model
    args.repo = None
    try:
        model = get_model_from_args(args)
    except ModelLoadingError as error:
        fatal(error.args[0])
    if args.segment is not None and args.segment < 8:
        fatal("Segment must be greater than 8.")
    if ".." in args.filename.replace("\\", "/").split("/"):
        fatal('".." must not appear in filename.')
    if isinstance(model, BagOfModels):
        print(
            f"Selected model is a bag of {len(model.models)} models. "
            "You will see that many progress bars per track."
        )
        if args.segment is not None:
            for sub in model.models:
                sub.segment = args.segment
    else:
        if args.segment is not None:
            model.segment = args.segment
    model.cpu()
    model.eval()
    if args.stem is not None and args.stem not in model.sources:
        fatal(
            'error: stem "{stem}" is not in selected model. STEM must be one of {sources}.'.format(
                stem=args.stem, sources=", ".join(model.sources)
            )
        )
    out = args.out / args.name
    out.mkdir(parents=True, exist_ok=True)
    print(f"Separated tracks will be stored in {out.resolve()}")
    for track in args.tracks:
        if not track.exists():
            print(
                f"File {track} does not exist. If the path contains spaces, "
                'please try again after surrounding the entire path with quotes "".',
                file=sys.stderr,
            )
            continue
        print(f"Separating track {track}")
        wav = load_track(track, model.audio_channels, model.samplerate)
        ref = wav.mean(0)
        wav = (wav - ref.mean()) / ref.std()
        sources = apply_model(
            model,
            wav[None],
            device=args.device,
            shifts=args.shifts,
            split=args.split,
            overlap=args.overlap,
            progress=True,
            num_workers=args.jobs,
        )[0]
        sources = sources * ref.std() + ref.mean()
        if args.mp3:
            ext = "mp3"
        else:
            ext = "wav"
        kwargs = {
            "samplerate": model.samplerate,
            "bitrate": args.mp3_bitrate,
            "clip": args.clip_mode,
            "as_float": args.float32,
            "bits_per_sample": 24 if args.int24 else 16,
        }
        if args.stem is None:
            for source, name in zip(sources, model.sources):
                stem = out / args.filename.format(
                    track=track.name.rsplit(".", 1)[0],
                    trackext=track.name.rsplit(".", 1)[-1],
                    stem=name,
                    ext=ext,
                )
                stem.parent.mkdir(parents=True, exist_ok=True)
                save_audio(source, str(stem), **kwargs)
        else:
            sources = list(sources)
            stem = out / args.filename.format(
                track=track.name.rsplit(".", 1)[0],
                trackext=track.name.rsplit(".", 1)[-1],
                stem=args.stem,
                ext=ext,
            )
            stem.parent.mkdir(parents=True, exist_ok=True)
            save_audio(sources.pop(model.sources.index(args.stem)), str(stem), **kwargs)
            # Warning: after popping the stem, the selected stem is no longer in the list 'sources'
            other_stem = th.zeros_like(sources[0])
            for i in sources:
                other_stem += i
            stem = out / args.filename.format(
                track=track.name.rsplit(".", 1)[0],
                trackext=track.name.rsplit(".", 1)[-1],
                stem="no_" + args.stem,
                ext=ext,
            )
            stem.parent.mkdir(parents=True, exist_ok=True)
            save_audio(other_stem, str(stem), **kwargs)
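
# Example usage (a minimal sketch; "song.wav" is a hypothetical path). With
# args.filename = "{track}/{stem}.{ext}" and stem="vocals", the outputs land at
# output/htdemucs/song/vocals.mp3 and output/htdemucs/song/no_vocals.mp3:
# separator(
#     tracks=[Path("song.wav")], out=Path("output"), model="htdemucs",
#     shifts=1, overlap=0.5, stem="vocals", int24=False, float32=False,
#     clip_mode="rescale", mp3=True, mp3_bitrate=320, verbose=True,
# )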


##############################################################################
import os
import logging

import librosa
import numpy as np
import soundfile as sf
import torch
from pydub import AudioSegment

if os.environ.get("LIMIT_CPU", False):
    torch.set_num_threads(1)


def merge_artifacts(y_mask, thres=0.05, min_range=64, fade_size=32):
    if min_range < fade_size * 2:
        raise ValueError("min_range must be >= fade_size * 2")
    idx = np.where(y_mask.min(axis=(0, 1)) > thres)[0]
    start_idx = np.insert(idx[np.where(np.diff(idx) != 1)[0] + 1], 0, idx[0])
    end_idx = np.append(idx[np.where(np.diff(idx) != 1)[0]], idx[-1])
    artifact_idx = np.where(end_idx - start_idx > min_range)[0]
    weight = np.zeros_like(y_mask)
    if len(artifact_idx) > 0:
        start_idx = start_idx[artifact_idx]
        end_idx = end_idx[artifact_idx]
        old_e = None
        for s, e in zip(start_idx, end_idx):
            if old_e is not None and s - old_e < fade_size:
                s = old_e - fade_size * 2
            if s != 0:
                weight[:, :, s : s + fade_size] = np.linspace(0, 1, fade_size)
            else:
                s -= fade_size
            if e != y_mask.shape[2]:
                weight[:, :, e - fade_size : e] = np.linspace(1, 0, fade_size)
            else:
                e += fade_size
            weight[:, :, s + fade_size : e - fade_size] = 1
            old_e = e
    v_mask = 1 - y_mask
    y_mask += weight * v_mask
    return y_mask


def make_padding(width, cropsize, offset):
    left = offset
    roi_size = cropsize - offset * 2
    if roi_size == 0:
        roi_size = cropsize
    right = roi_size - (width % roi_size) + left
    return left, right, roi_size
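
# Example (a worked sketch with assumed numbers): with width=1000 spectrogram frames,
# cropsize=256 and offset=64, roi_size = 256 - 2 * 64 = 128, left = 64 and
# right = 128 - (1000 % 128) + 64 = 88. The padded width minus both offsets,
# 64 + 1000 + 88 - 2 * 64 = 1024, is then an exact multiple (8) of roi_size,
# so Separator._separate below can walk it in whole crops.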


def wave_to_spectrogram(wave, hop_length, n_fft):
    wave_left = np.asfortranarray(wave[0])
    wave_right = np.asfortranarray(wave[1])
    spec_left = librosa.stft(wave_left, n_fft=n_fft, hop_length=hop_length)
    spec_right = librosa.stft(wave_right, n_fft=n_fft, hop_length=hop_length)
    spec = np.asfortranarray([spec_left, spec_right])
    return spec


def spectrogram_to_wave(spec, hop_length=1024):
    if spec.ndim == 2:
        wave = librosa.istft(spec, hop_length=hop_length)
    elif spec.ndim == 3:
        spec_left = np.asfortranarray(spec[0])
        spec_right = np.asfortranarray(spec[1])
        wave_left = librosa.istft(spec_left, hop_length=hop_length)
        wave_right = librosa.istft(spec_right, hop_length=hop_length)
        wave = np.asfortranarray([wave_left, wave_right])
    return wave
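
# Example round trip (a minimal sketch with assumed parameters n_fft=2048 and
# hop_length=1024, matching the defaults used in separate() below):
# stereo = np.random.randn(2, 44100).astype(np.float32)            # 1 s of fake stereo audio
# spec = wave_to_spectrogram(stereo, hop_length=1024, n_fft=2048)   # shape (2, 1025, n_frames)
# recon = spectrogram_to_wave(spec, hop_length=1024)                # approximately recovers `stereo`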


class Separator(object):
    def __init__(self, model, device, batchsize, cropsize, postprocess=False, progress_bar=None):
        self.model = model
        self.offset = model.offset
        self.device = device
        self.batchsize = batchsize
        self.cropsize = cropsize
        self.postprocess = postprocess
        self.progress_bar = progress_bar

    def _separate(self, X_mag_pad, roi_size):
        X_dataset = []
        patches = (X_mag_pad.shape[2] - 2 * self.offset) // roi_size
        for i in range(patches):
            start = i * roi_size
            X_mag_crop = X_mag_pad[:, :, start : start + self.cropsize]
            X_dataset.append(X_mag_crop)
        X_dataset = np.asarray(X_dataset)
        self.model.eval()
        with torch.no_grad():
            mask = []
            # To reduce the overhead, a DataLoader is not used.
            for i in range(0, patches, self.batchsize):
                X_batch = X_dataset[i : i + self.batchsize]
                X_batch = torch.from_numpy(X_batch).to(self.device)
                pred = self.model.predict_mask(X_batch)
                pred = pred.detach().cpu().numpy()
                pred = np.concatenate(pred, axis=2)
                mask.append(pred)
            mask = np.concatenate(mask, axis=2)
        return mask

    def _preprocess(self, X_spec):
        X_mag = np.abs(X_spec)
        X_phase = np.angle(X_spec)
        return X_mag, X_phase

    def _postprocess(self, mask, X_mag, X_phase):
        if self.postprocess:
            mask = merge_artifacts(mask)
        y_spec = mask * X_mag * np.exp(1.0j * X_phase)
        v_spec = (1 - mask) * X_mag * np.exp(1.0j * X_phase)
        return y_spec, v_spec

    def separate(self, X_spec):
        X_mag, X_phase = self._preprocess(X_spec)
        n_frame = X_mag.shape[2]
        pad_l, pad_r, roi_size = make_padding(n_frame, self.cropsize, self.offset)
        X_mag_pad = np.pad(X_mag, ((0, 0), (0, 0), (pad_l, pad_r)), mode="constant")
        X_mag_pad /= X_mag_pad.max()
        mask = self._separate(X_mag_pad, roi_size)
        mask = mask[:, :, :n_frame]
        y_spec, v_spec = self._postprocess(mask, X_mag, X_phase)
        return y_spec, v_spec


def load_model(pretrained_model, n_fft=2048):
    model = CascadedNet(n_fft, 32, 128)
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        model.to(device)
    # elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
    #     device = torch.device("mps")
    #     model.to(device)
    else:
        device = torch.device("cpu")
    model.load_state_dict(torch.load(pretrained_model, map_location=device))
    return model, device


def separate(
    input,
    model,
    device,
    output_dir,
    batchsize=4,
    cropsize=256,
    postprocess=False,
    hop_length=1024,
    n_fft=2048,
    sr=44100,
    progress_bar=None,
    only_no_vocals=False,
):
    X, sr = librosa.load(input, sr=sr, mono=False, dtype=np.float32, res_type="kaiser_fast")
    basename = os.path.splitext(os.path.basename(input))[0]
    if X.ndim == 1:
        # mono to stereo
        X = np.asarray([X, X])
    X_spec = wave_to_spectrogram(X, hop_length, n_fft)
    with torch.no_grad():
        sp = Separator(model, device, batchsize, cropsize, postprocess, progress_bar=progress_bar)
        y_spec, v_spec = sp.separate(X_spec)
    base_dir = f"{output_dir}/vocal_remover/{basename}"
    os.makedirs(base_dir, exist_ok=True)
    wave = spectrogram_to_wave(y_spec, hop_length=hop_length)
    try:
        sf.write(f"{base_dir}/no_vocals.mp3", wave.T, sr)
    except Exception:
        logging.error("Failed to write no_vocals.mp3, trying pydub...")
        pydub_write(wave, f"{base_dir}/no_vocals.mp3", sr)
    if only_no_vocals:
        return
    wave = spectrogram_to_wave(v_spec, hop_length=hop_length)
    try:
        sf.write(f"{base_dir}/vocals.mp3", wave.T, sr)
    except Exception:
        logging.error("Failed to write vocals.mp3, trying pydub...")
        pydub_write(wave, f"{base_dir}/vocals.mp3", sr)


def pydub_write(wave, output_path, frame_rate, audio_format="mp3"):
    # Convert the float wave data to 16-bit PCM; pydub expects integer samples
    # (this fallback treats the data as a single mono channel)
    wave_16bit = (wave * 32767).astype(np.int16)
    audio_segment = AudioSegment(
        wave_16bit.tobytes(),
        frame_rate=frame_rate,
        sample_width=wave_16bit.dtype.itemsize,
        channels=1,
    )
    audio_segment.export(output_path, format=audio_format)
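
# Example usage (a minimal sketch; "baseline.pth" and "song.mp3" are hypothetical paths,
# and the checkpoint must match the CascadedNet(n_fft, 32, 128) architecture below):
# model, device = load_model("baseline.pth", n_fft=2048)
# separate("song.mp3", model, device, output_dir="output")
# # -> output/vocal_remover/song/no_vocals.mp3 and output/vocal_remover/song/vocals.mp3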


#####################################################################################
import torch
from torch import nn
import torch.nn.functional as F


class BaseNet(nn.Module):
    def __init__(self, nin, nout, nin_lstm, nout_lstm, dilations=((4, 2), (8, 4), (12, 6))):
        super(BaseNet, self).__init__()
        self.enc1 = Conv2DBNActiv(nin, nout, 3, 1, 1)
        self.enc2 = Encoder(nout, nout * 2, 3, 2, 1)
        self.enc3 = Encoder(nout * 2, nout * 4, 3, 2, 1)
        self.enc4 = Encoder(nout * 4, nout * 6, 3, 2, 1)
        self.enc5 = Encoder(nout * 6, nout * 8, 3, 2, 1)
        self.aspp = ASPPModule(nout * 8, nout * 8, dilations, dropout=True)
        self.dec4 = Decoder(nout * (6 + 8), nout * 6, 3, 1, 1)
        self.dec3 = Decoder(nout * (4 + 6), nout * 4, 3, 1, 1)
        self.dec2 = Decoder(nout * (2 + 4), nout * 2, 3, 1, 1)
        self.lstm_dec2 = LSTMModule(nout * 2, nin_lstm, nout_lstm)
        self.dec1 = Decoder(nout * (1 + 2) + 1, nout * 1, 3, 1, 1)

    def __call__(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)
        e4 = self.enc4(e3)
        e5 = self.enc5(e4)
        h = self.aspp(e5)
        h = self.dec4(h, e4)
        h = self.dec3(h, e3)
        h = self.dec2(h, e2)
        h = torch.cat([h, self.lstm_dec2(h)], dim=1)
        h = self.dec1(h, e1)
        return h


class CascadedNet(nn.Module):
    def __init__(self, n_fft, nout=32, nout_lstm=128):
        super(CascadedNet, self).__init__()
        self.max_bin = n_fft // 2
        self.output_bin = n_fft // 2 + 1
        self.nin_lstm = self.max_bin // 2
        self.offset = 64
        self.stg1_low_band_net = nn.Sequential(
            BaseNet(2, nout // 2, self.nin_lstm // 2, nout_lstm),
            Conv2DBNActiv(nout // 2, nout // 4, 1, 1, 0),
        )
        self.stg1_high_band_net = BaseNet(2, nout // 4, self.nin_lstm // 2, nout_lstm // 2)
        self.stg2_low_band_net = nn.Sequential(
            BaseNet(nout // 4 + 2, nout, self.nin_lstm // 2, nout_lstm),
            Conv2DBNActiv(nout, nout // 2, 1, 1, 0),
        )
        self.stg2_high_band_net = BaseNet(
            nout // 4 + 2, nout // 2, self.nin_lstm // 2, nout_lstm // 2
        )
        self.stg3_full_band_net = BaseNet(3 * nout // 4 + 2, nout, self.nin_lstm, nout_lstm)
        self.out = nn.Conv2d(nout, 2, 1, bias=False)
        self.aux_out = nn.Conv2d(3 * nout // 4, 2, 1, bias=False)

    def forward(self, x):
        x = x[:, :, : self.max_bin]
        bandw = x.size()[2] // 2
        l1_in = x[:, :, :bandw]
        h1_in = x[:, :, bandw:]
        l1 = self.stg1_low_band_net(l1_in)
        h1 = self.stg1_high_band_net(h1_in)
        aux1 = torch.cat([l1, h1], dim=2)
        l2_in = torch.cat([l1_in, l1], dim=1)
        h2_in = torch.cat([h1_in, h1], dim=1)
        l2 = self.stg2_low_band_net(l2_in)
        h2 = self.stg2_high_band_net(h2_in)
        aux2 = torch.cat([l2, h2], dim=2)
        f3_in = torch.cat([x, aux1, aux2], dim=1)
        f3 = self.stg3_full_band_net(f3_in)
        mask = torch.sigmoid(self.out(f3))
        mask = F.pad(
            input=mask,
            pad=(0, 0, 0, self.output_bin - mask.size()[2]),
            mode="replicate",
        )
        if self.training:
            aux = torch.cat([aux1, aux2], dim=1)
            aux = torch.sigmoid(self.aux_out(aux))
            aux = F.pad(
                input=aux,
                pad=(0, 0, 0, self.output_bin - aux.size()[2]),
                mode="replicate",
            )
            return mask, aux
        else:
            return mask

    def predict_mask(self, x):
        mask = self.forward(x)
        if self.offset > 0:
            mask = mask[:, :, :, self.offset : -self.offset]
            assert mask.size()[3] > 0
        return mask

    def predict(self, x):
        mask = self.forward(x)
        pred_mag = x * mask
        if self.offset > 0:
            pred_mag = pred_mag[:, :, :, self.offset : -self.offset]
            assert pred_mag.size()[3] > 0
        return pred_mag
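
# Shape sketch (assumed values, matching load_model above): with n_fft=2048 the network
# takes a batch of magnitude spectrograms shaped (N, 2, n_fft // 2 + 1, n_frames);
# forward() crops the frequency axis to max_bin=1024 and pads the mask back to 1025 bins,
# and predict_mask() trims `offset`=64 frames from each side of the time axis:
# x = torch.rand(1, 2, 1025, 256)
# mask = CascadedNet(2048).eval().predict_mask(x)  # -> shape (1, 2, 1025, 256 - 2 * 64)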


##############################################################################
def crop_center(h1, h2):
    h1_shape = h1.size()
    h2_shape = h2.size()
    if h1_shape[3] == h2_shape[3]:
        return h1
    elif h1_shape[3] < h2_shape[3]:
        raise ValueError("h1_shape[3] must be greater than h2_shape[3]")
    s_time = (h1_shape[3] - h2_shape[3]) // 2
    e_time = s_time + h2_shape[3]
    h1 = h1[:, :, :, s_time:e_time]
    return h1


class Conv2DBNActiv(nn.Module):
    def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU):
        super(Conv2DBNActiv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                nin,
                nout,
                kernel_size=ksize,
                stride=stride,
                padding=pad,
                dilation=dilation,
                bias=False,
            ),
            nn.BatchNorm2d(nout),
            activ(),
        )

    def __call__(self, x):
        return self.conv(x)


class Encoder(nn.Module):
    def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU):
        super(Encoder, self).__init__()
        self.conv1 = Conv2DBNActiv(nin, nout, ksize, stride, pad, activ=activ)
        self.conv2 = Conv2DBNActiv(nout, nout, ksize, 1, pad, activ=activ)

    def __call__(self, x):
        h = self.conv1(x)
        h = self.conv2(h)
        return h


class Decoder(nn.Module):
    def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False):
        super(Decoder, self).__init__()
        self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)
        self.dropout = nn.Dropout2d(0.1) if dropout else None

    def __call__(self, x, skip=None):
        x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)
        if skip is not None:
            skip = crop_center(skip, x)
            x = torch.cat([x, skip], dim=1)
        h = self.conv1(x)
        # h = self.conv2(h)
        if self.dropout is not None:
            h = self.dropout(h)
        return h


class ASPPModule(nn.Module):
    def __init__(self, nin, nout, dilations=(4, 8, 12), activ=nn.ReLU, dropout=False):
        super(ASPPModule, self).__init__()
        self.conv1 = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, None)),
            Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ),
        )
        self.conv2 = Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ)
        self.conv3 = Conv2DBNActiv(nin, nout, 3, 1, dilations[0], dilations[0], activ=activ)
        self.conv4 = Conv2DBNActiv(nin, nout, 3, 1, dilations[1], dilations[1], activ=activ)
        self.conv5 = Conv2DBNActiv(nin, nout, 3, 1, dilations[2], dilations[2], activ=activ)
        self.bottleneck = Conv2DBNActiv(nout * 5, nout, 1, 1, 0, activ=activ)
        self.dropout = nn.Dropout2d(0.1) if dropout else None

    def forward(self, x):
        _, _, h, w = x.size()
        feat1 = F.interpolate(self.conv1(x), size=(h, w), mode="bilinear", align_corners=True)
        feat2 = self.conv2(x)
        feat3 = self.conv3(x)
        feat4 = self.conv4(x)
        feat5 = self.conv5(x)
        out = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1)
        out = self.bottleneck(out)
        if self.dropout is not None:
            out = self.dropout(out)
        return out


class LSTMModule(nn.Module):
    def __init__(self, nin_conv, nin_lstm, nout_lstm):
        super(LSTMModule, self).__init__()
        self.conv = Conv2DBNActiv(nin_conv, 1, 1, 1, 0)
        self.lstm = nn.LSTM(input_size=nin_lstm, hidden_size=nout_lstm // 2, bidirectional=True)
        self.dense = nn.Sequential(
            nn.Linear(nout_lstm, nin_lstm), nn.BatchNorm1d(nin_lstm), nn.ReLU()
        )

    def forward(self, x):
        N, _, nbins, nframes = x.size()
        h = self.conv(x)[:, 0]  # N, nbins, nframes
        h = h.permute(2, 0, 1)  # nframes, N, nbins
        h, _ = self.lstm(h)
        h = self.dense(h.reshape(-1, h.size()[-1]))  # nframes * N, nbins
        h = h.reshape(nframes, N, 1, nbins)
        h = h.permute(1, 2, 3, 0)
        return h