import yaml
import torch
from diffusers import DDIMScheduler

from .model.model import DiffVC
from .model.model_cross import DiffVC_Cross
from .utils import scale_shift, scale_shift_re, rescale_noise_cfg


class ReDiffVC(object):
    def __init__(self,
                 config_path='configs/diffvc_base.yaml',
                 ckpt_path='../ckpts/dreamvc_base.pt',
                 device='cpu'):

        with open(config_path, 'r') as fp:
            config = yaml.safe_load(fp)

        self.device = device
        self.model = DiffVC(config['model']).to(device)
        # map_location keeps CPU-only hosts from failing on GPU-saved checkpoints
        self.model.load_state_dict(torch.load(ckpt_path, map_location=device)['model'])
        self.model.eval()

        # v-prediction DDIM with zero-terminal-SNR betas and "trailing" timestep
        # spacing, following the scheduler section of the config
        noise_scheduler = DDIMScheduler(num_train_timesteps=config['scheduler']['num_train_steps'],
                                        beta_start=config['scheduler']['beta_start'],
                                        beta_end=config['scheduler']['beta_end'],
                                        rescale_betas_zero_snr=True,
                                        timestep_spacing="trailing",
                                        clip_sample=False,
                                        prediction_type='v_prediction')
        self.noise_scheduler = noise_scheduler
        self.scale = config['scheduler']['scale']
        self.shift = config['scheduler']['shift']
        self.melshape = config['model']['unet']['sample_size'][0]
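
    # The YAML config is expected to carry at least the fields read above
    # (a sketch; the values shown are placeholders, not the shipped defaults):
    #
    #     scheduler:
    #       num_train_steps: 1000
    #       beta_start: 0.0001
    #       beta_end: 0.02
    #       scale: 1.0
    #       shift: 0.0
    #     model:
    #       unet:
    #         sample_size: [n_mels, T]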

    @torch.no_grad()
    def inference(self,
                  spk_embed, content_clip, f0_clip=None,
                  guidance_scale=3, guidance_rescale=0.7,
                  ddim_steps=50, eta=1, random_seed=2023):

        self.model.eval()
        if random_seed is not None:
            generator = torch.Generator(device=self.device).manual_seed(random_seed)
        else:
            generator = torch.Generator(device=self.device)
            generator.seed()

        self.noise_scheduler.set_timesteps(ddim_steps)

        # start from pure Gaussian noise shaped (B, C, n_mels, T)
        gen_shape = (1, 1, self.melshape, content_clip.shape[-2])
        noise = torch.randn(gen_shape, generator=generator, device=self.device)
        latents = noise

        for t in self.noise_scheduler.timesteps:
            latents = self.noise_scheduler.scale_model_input(latents, t)

            if guidance_scale:
                # classifier-free guidance: one speaker-conditioned pass and one
                # pass with the speaker condition dropped (speaker_cfg=1.0)
                output_cond = self.model(latents, t, content_clip, spk_embed, f0_clip,
                                         train_cfg=False)
                output_uncond = self.model(latents, t, content_clip, spk_embed, f0_clip,
                                           train_cfg=True, speaker_cfg=1.0, pitch_cfg=0.0)

                output_pred = output_uncond + guidance_scale * (output_cond - output_uncond)
                if guidance_rescale > 0.0:
                    # temper CFG over-exposure (Lin et al., "Common Diffusion Noise
                    # Schedules and Sample Steps are Flawed")
                    output_pred = rescale_noise_cfg(output_pred, output_cond,
                                                    guidance_rescale=guidance_rescale)
            else:
                output_pred = self.model(latents, t, content_clip, spk_embed, f0_clip,
                                         train_cfg=False)

            latents = self.noise_scheduler.step(model_output=output_pred, timestep=t,
                                                sample=latents, eta=eta,
                                                generator=generator).prev_sample

        # undo the scale/shift normalization applied to mels during training
        pred = scale_shift_re(latents, scale=1/self.scale, shift=self.shift)
        return pred
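
# A minimal usage sketch for ReDiffVC (illustrative only: the speaker-embedding
# and content-feature shapes below are assumptions about the upstream encoders,
# not something this module defines):
#
#     revc = ReDiffVC(device='cuda')
#     spk_embed = torch.randn(1, 1, 256, device='cuda')        # assumed (B, 1, D)
#     content_clip = torch.randn(1, 200, 768, device='cuda')   # assumed (B, T, D)
#     mel = revc.inference(spk_embed, content_clip, guidance_scale=3, ddim_steps=50)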


class DreamVC(object):
    def __init__(self,
                 config_path='configs/diffvc_cross.yaml',
                 ckpt_path='../ckpts/dreamvc_cross.pt',
                 device='cpu'):

        with open(config_path, 'r') as fp:
            config = yaml.safe_load(fp)

        self.device = device
        self.model = DiffVC_Cross(config['model']).to(device)
        self.model.load_state_dict(torch.load(ckpt_path, map_location=device)['model'])
        self.model.eval()

        # same scheduler setup as ReDiffVC: v-prediction DDIM with
        # zero-terminal-SNR betas and "trailing" timestep spacing
        noise_scheduler = DDIMScheduler(num_train_timesteps=config['scheduler']['num_train_steps'],
                                        beta_start=config['scheduler']['beta_start'],
                                        beta_end=config['scheduler']['beta_end'],
                                        rescale_betas_zero_snr=True,
                                        timestep_spacing="trailing",
                                        clip_sample=False,
                                        prediction_type='v_prediction')
        self.noise_scheduler = noise_scheduler
        self.scale = config['scheduler']['scale']
        self.shift = config['scheduler']['shift']
        self.melshape = config['model']['unet']['sample_size'][0]

    @torch.no_grad()
    def inference(self,
                  text, content_clip, f0_clip=None,
                  guidance_scale=3, guidance_rescale=0.7,
                  ddim_steps=50, eta=1, random_seed=2023):

        # `text` is a (prompt_embeddings, prompt_mask) pair from the text encoder
        text, text_mask = text
        self.model.eval()
        if random_seed is not None:
            generator = torch.Generator(device=self.device).manual_seed(random_seed)
        else:
            generator = torch.Generator(device=self.device)
            generator.seed()

        self.noise_scheduler.set_timesteps(ddim_steps)

        # start from pure Gaussian noise shaped (B, C, n_mels, T)
        gen_shape = (1, 1, self.melshape, content_clip.shape[-2])
        noise = torch.randn(gen_shape, generator=generator, device=self.device)
        latents = noise

        for t in self.noise_scheduler.timesteps:
            latents = self.noise_scheduler.scale_model_input(latents, t)

            if guidance_scale:
                # classifier-free guidance: one text-conditioned pass and one
                # pass with the condition dropped (speaker_cfg=1.0)
                output_text = self.model(latents, t, content_clip, text, text_mask, f0_clip,
                                         train_cfg=False)
                output_uncond = self.model(latents, t, content_clip, text, text_mask, f0_clip,
                                           train_cfg=True, speaker_cfg=1.0, pitch_cfg=0.0)

                output_pred = output_uncond + guidance_scale * (output_text - output_uncond)
                if guidance_rescale > 0.0:
                    output_pred = rescale_noise_cfg(output_pred, output_text,
                                                    guidance_rescale=guidance_rescale)
            else:
                output_pred = self.model(latents, t, content_clip, text, text_mask, f0_clip,
                                         train_cfg=False)

            latents = self.noise_scheduler.step(model_output=output_pred, timestep=t,
                                                sample=latents, eta=eta,
                                                generator=generator).prev_sample

        # undo the scale/shift normalization applied to mels during training
        pred = scale_shift_re(latents, scale=1/self.scale, shift=self.shift)
        return pred
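

# A minimal smoke-test sketch, assuming the default config/checkpoint paths
# exist. All tensor shapes below are illustrative assumptions: in practice a
# text encoder supplies `text`/`text_mask` and a content encoder supplies
# `content_clip`; neither is defined in this module.
if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    dreamvc = DreamVC(device=device)

    text = torch.randn(1, 32, 768, device=device)            # assumed (B, L, D) prompt embeddings
    text_mask = torch.ones(1, 32, device=device)             # assumed (B, L) prompt mask
    content_clip = torch.randn(1, 200, 768, device=device)   # assumed (B, T, D) content features

    mel = dreamvc.inference((text, text_mask), content_clip,
                            guidance_scale=3, ddim_steps=50)
    print(mel.shape)  # (1, 1, n_mels, T)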