# DreamVoice: dreamvoice/src/vc_wrapper.py
import yaml
import torch
from diffusers import DDIMScheduler
from .model.model import DiffVC
from .model.model_cross import DiffVC_Cross
from .utils import scale_shift_re, rescale_noise_cfg


class ReDiffVC(object):
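    """Diffusion-based voice conversion conditioned on a reference speaker
    embedding, using the base model (DiffVC).
    """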
def __init__(self,
config_path='configs/diffvc_base.yaml',
ckpt_path='../ckpts/dreamvc_base.pt',
device='cpu'):
with open(config_path, 'r') as fp:
config = yaml.safe_load(fp)
self.device = device
self.model = DiffVC(config['model']).to(device)
        self.model.load_state_dict(torch.load(ckpt_path, map_location=device)['model'])
self.model.eval()
        # DDIM sampler: v-prediction with zero-terminal-SNR betas and
        # trailing timestep spacing
        noise_scheduler = DDIMScheduler(num_train_timesteps=config['scheduler']['num_train_steps'],
                                        beta_start=config['scheduler']['beta_start'],
                                        beta_end=config['scheduler']['beta_end'],
                                        rescale_betas_zero_snr=True,
                                        timestep_spacing="trailing",
                                        clip_sample=False,
                                        prediction_type='v_prediction')
self.noise_scheduler = noise_scheduler
self.scale = config['scheduler']['scale']
self.shift = config['scheduler']['shift']
        self.melshape = config['model']['unet']['sample_size'][0]  # number of mel bins

@torch.no_grad()
def inference(self,
spk_embed, content_clip, f0_clip=None,
guidance_scale=3, guidance_rescale=0.7,
ddim_steps=50, eta=1, random_seed=2023):
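        """Sample a mel-spectrogram with DDIM, conditioned on a speaker embedding.

        Args:
            spk_embed: target speaker embedding (conditioning signal).
            content_clip: content features; its second-to-last axis sets the
                number of generated frames.
            f0_clip: optional pitch (F0) conditioning.
            guidance_scale: classifier-free guidance weight; falsy disables CFG.
            guidance_rescale: strength of guidance rescaling (0 disables it).
            ddim_steps: number of DDIM denoising steps.
            eta: DDIM stochasticity (0 = deterministic).
            random_seed: seed for reproducible sampling; None draws a fresh seed.
        """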
self.model.eval()
        # seed the sampler for reproducibility; otherwise draw a fresh seed
        if random_seed is not None:
            generator = torch.Generator(device=self.device).manual_seed(random_seed)
        else:
            generator = torch.Generator(device=self.device)
            generator.seed()
self.noise_scheduler.set_timesteps(ddim_steps)
        # start from pure Gaussian noise: (batch, channel, mel_bins, frames)
        gen_shape = (1, 1, self.melshape, content_clip.shape[-2])
        latents = torch.randn(gen_shape, generator=generator, device=self.device)
for t in self.noise_scheduler.timesteps:
latents = self.noise_scheduler.scale_model_input(latents, t)
            if guidance_scale:
                # classifier-free guidance: blend the speaker-conditioned
                # prediction with one whose speaker condition is dropped
                output_cond = self.model(latents, t, content_clip, spk_embed, f0_clip, train_cfg=False)
                output_uncond = self.model(latents, t, content_clip, spk_embed, f0_clip, train_cfg=True,
                                           speaker_cfg=1.0, pitch_cfg=0.0)
                output_pred = output_uncond + guidance_scale * (output_cond - output_uncond)
                if guidance_rescale > 0.0:
                    # temper the over-saturation strong CFG can introduce
                    output_pred = rescale_noise_cfg(output_pred, output_cond,
                                                    guidance_rescale=guidance_rescale)
else:
output_pred = self.model(latents, t, content_clip, spk_embed, f0_clip, train_cfg=False)
latents = self.noise_scheduler.step(model_output=output_pred, timestep=t, sample=latents,
eta=eta, generator=generator).prev_sample
        # map latents back from the normalized diffusion range to mel scale
        pred = scale_shift_re(latents, scale=1 / self.scale, shift=self.shift)
        return pred
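
# A minimal usage sketch (assumption: `spk_embed` and `content_clip` come from
# the speaker/content encoders elsewhere in this repo; names, shapes, and the
# vocoder step are illustrative, not defined in this file):
#
#     vc = ReDiffVC(config_path='configs/diffvc_base.yaml',
#                   ckpt_path='../ckpts/dreamvc_base.pt',
#                   device='cuda')
#     mel = vc.inference(spk_embed, content_clip,
#                        guidance_scale=3, ddim_steps=50, random_seed=2023)
#     # `mel` has shape (1, 1, n_mels, n_frames); a vocoder turns it into audio.
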
class DreamVC(object):
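    """Diffusion-based voice conversion driven by a text prompt describing the
    target voice, using the cross-conditioned model (DiffVC_Cross).
    """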
def __init__(self,
config_path='configs/diffvc_cross.yaml',
ckpt_path='../ckpts/dreamvc_cross.pt',
device='cpu'):
with open(config_path, 'r') as fp:
config = yaml.safe_load(fp)
self.device = device
self.model = DiffVC_Cross(config['model']).to(device)
        self.model.load_state_dict(torch.load(ckpt_path, map_location=device)['model'])
self.model.eval()
        # same DDIM configuration as ReDiffVC
        noise_scheduler = DDIMScheduler(num_train_timesteps=config['scheduler']['num_train_steps'],
                                        beta_start=config['scheduler']['beta_start'],
                                        beta_end=config['scheduler']['beta_end'],
                                        rescale_betas_zero_snr=True,
                                        timestep_spacing="trailing",
                                        clip_sample=False,
                                        prediction_type='v_prediction')
self.noise_scheduler = noise_scheduler
self.scale = config['scheduler']['scale']
self.shift = config['scheduler']['shift']
        self.melshape = config['model']['unet']['sample_size'][0]  # number of mel bins

@torch.no_grad()
def inference(self,
text, content_clip, f0_clip=None,
guidance_scale=3, guidance_rescale=0.7,
ddim_steps=50, eta=1, random_seed=2023):
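        """Sample a mel-spectrogram with DDIM, conditioned on a text prompt.

        Args:
            text: tuple of (text_embeddings, text_mask) from the prompt encoder.
            content_clip: content features; its second-to-last axis sets the
                number of generated frames.
            f0_clip: optional pitch (F0) conditioning.
            guidance_scale: classifier-free guidance weight; falsy disables CFG.
            guidance_rescale: strength of guidance rescaling (0 disables it).
            ddim_steps: number of DDIM denoising steps.
            eta: DDIM stochasticity (0 = deterministic).
            random_seed: seed for reproducible sampling; None draws a fresh seed.
        """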
        text, text_mask = text  # `text` arrives as an (embeddings, mask) pair
self.model.eval()
if random_seed is not None:
generator = torch.Generator(device=self.device).manual_seed(random_seed)
else:
generator = torch.Generator(device=self.device)
generator.seed()
self.noise_scheduler.set_timesteps(ddim_steps)
        # start from pure Gaussian noise: (batch, channel, mel_bins, frames)
        gen_shape = (1, 1, self.melshape, content_clip.shape[-2])
        latents = torch.randn(gen_shape, generator=generator, device=self.device)
for t in self.noise_scheduler.timesteps:
latents = self.noise_scheduler.scale_model_input(latents, t)
            if guidance_scale:
                # classifier-free guidance on the text-prompt condition
                output_text = self.model(latents, t, content_clip, text, text_mask, f0_clip, train_cfg=False)
                output_uncond = self.model(latents, t, content_clip, text, text_mask, f0_clip, train_cfg=True,
                                           speaker_cfg=1.0, pitch_cfg=0.0)
                output_pred = output_uncond + guidance_scale * (output_text - output_uncond)
                if guidance_rescale > 0.0:
                    # temper the over-saturation strong CFG can introduce
                    output_pred = rescale_noise_cfg(output_pred, output_text,
                                                    guidance_rescale=guidance_rescale)
else:
output_pred = self.model(latents, t, content_clip, text, text_mask, f0_clip, train_cfg=False)
latents = self.noise_scheduler.step(model_output=output_pred, timestep=t, sample=latents,
eta=eta, generator=generator).prev_sample
        # map latents back from the normalized diffusion range to mel scale
        pred = scale_shift_re(latents, scale=1 / self.scale, shift=self.shift)
        return pred
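
# A minimal usage sketch (assumption: `text_emb`/`text_mask` come from the
# prompt encoder and `content_clip` from the content encoder elsewhere in
# this repo; names are illustrative, not defined in this file):
#
#     vc = DreamVC(config_path='configs/diffvc_cross.yaml',
#                  ckpt_path='../ckpts/dreamvc_cross.pt',
#                  device='cuda')
#     mel = vc.inference((text_emb, text_mask), content_clip,
#                        guidance_scale=3, ddim_steps=50, random_seed=2023)
#     # `mel` has shape (1, 1, n_mels, n_frames); a vocoder turns it into audio.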