File size: 6,161 Bytes
b8b70ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54fb9ea
 
b8b70ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54fb9ea
b8b70ac
 
 
 
 
 
 
028b554
b8b70ac
 
 
 
54fb9ea
b8b70ac
 
 
 
 
 
 
 
54fb9ea
 
b8b70ac
 
 
d171db8
b8b70ac
 
 
 
725a102
b8b70ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import gradio as gr
import argparse
import torch
import commons
import utils
from models import (
    SynthesizerTrn, )

from text.symbols import symbol_len, lang_to_dict

# we use Kyubyong/g2p for demo instead of our internal g2p
# https://github.com/Kyubyong/g2p
from g2p_en import G2p
import re

# Symbol -> integer-id lookup for the "en_US" symbol set. GradioApp.get_phoneme
# indexes it with stress-stripped g2p_en phonemes and packs the results into a
# LongTensor.
_symbol_to_id = lang_to_dict("en_US")

class GradioApp:
    """Gradio front-end around a PITS ``SynthesizerTrn`` checkpoint.

    On construction it loads hyper-parameters and generator weights, builds a
    g2p_en phonemizer, and creates a ``gr.Interface`` that turns English text
    into audio.
    """

    def __init__(self, args):
        """Build the synthesizer, load its checkpoint and create the UI.

        Args:
            args: Parsed CLI namespace. ``args.config`` is the hparams file
                and ``args.checkpoint_path`` holds the generator weights.
        """
        self.hps = utils.get_hparams_from_file(args.config)
        # The demo always runs on CPU.
        self.device = "cpu"
        self.net_g = SynthesizerTrn(symbol_len(self.hps.data.languages),
                                    self.hps.data.filter_length // 2 + 1,
                                    self.hps.train.segment_size //
                                    self.hps.data.hop_length,
                                    midi_start=-5,
                                    midi_end=75,
                                    octave_range=24,
                                    n_speakers=len(self.hps.data.speakers),
                                    **self.hps.model).to(self.device)
        _ = self.net_g.eval()
        _ = utils.load_checkpoint(args.checkpoint_path, model_g=self.net_g)
        self.g2p = G2p()
        self.interface = self._gradio_interface()

    def get_phoneme(self, text):
        """Convert English text into model input tensors.

        Args:
            text: Raw input text.

        Returns:
            ``(text_norm, tone, phones)``:
            - ``text_norm`` is a LongTensor of symbol ids.
            - ``tone`` is a LongTensor of zeros with the same length.
            - ``phones`` is the list of digit-stripped phoneme strings from g2p.
            When ``hps.data.add_blank`` is set, id 0 is interspersed between
            tokens in both tensors.
        """
        # Strip digits from g2p_en output (the UI description calls this
        # "stress removing").
        phones = [re.sub(r"[0-9]", "", p) for p in self.g2p(text)]
        # Always map symbols to ids: a LongTensor cannot be built from the
        # phoneme strings themselves, regardless of add_blank.
        text_norm = [_symbol_to_id[symbol] for symbol in phones]
        tone = [0] * len(text_norm)
        if self.hps.data.add_blank:
            text_norm = commons.intersperse(text_norm, 0)
            tone = commons.intersperse(tone, 0)
        return torch.LongTensor(text_norm), torch.LongTensor(tone), phones

    def inference(self, text, speaker_id_val, seed, scope_shift, duration):
        """Synthesize ``text`` and return the phonemes and audio for Gradio.

        Args:
            text: Input text.
            speaker_id_val: Speaker index (the Dropdown uses ``type="index"``).
            seed: RNG seed passed to ``torch.manual_seed`` (cast to int).
            scope_shift: Passed to the model as ``scope_shift`` (cast to int).
            duration: Passed to the model as ``length_scale``.

        Returns:
            ``(phones, (sampling_rate, audio))``, where ``audio`` is a NumPy
            array.
        """
        seed = int(seed)
        scope_shift = int(scope_shift)
        torch.manual_seed(seed)
        text_norm, tone, phones = self.get_phoneme(text)
        x_tst = text_norm.to(self.device).unsqueeze(0)
        t_tst = tone.to(self.device).unsqueeze(0)
        x_tst_lengths = torch.LongTensor([text_norm.size(0)]).to(self.device)
        speaker_id = torch.LongTensor([speaker_id_val]).to(self.device)
        # Inference only: no gradients are needed, so skip building the
        # autograd graph and save memory.
        with torch.no_grad():
            decoder_inputs, *_ = self.net_g.infer_pre_decoder(
                x_tst,
                t_tst,
                x_tst_lengths,
                sid=speaker_id,
                noise_scale=0.667,
                noise_scale_w=0.8,
                length_scale=duration,
                scope_shift=scope_shift)
            audio = self.net_g.infer_decode_chunk(
                decoder_inputs, sid=speaker_id)[0, 0].data.cpu().float().numpy()
        del decoder_inputs
        return phones, (self.hps.data.sampling_rate, audio)

    def _gradio_interface(self):
        """Create the ``gr.Interface`` wired to :meth:`inference`.

        The five input components map, in order, to ``inference``'s
        ``text, speaker_id_val, seed, scope_shift, duration`` parameters.
        """
        title = "PITS Demo"
        self.inputs = [
            gr.Textbox(label="Text (150 words limitation)",
                       value="This is demo page.",
                       elem_id="tts-input"),
            gr.Dropdown(list(self.hps.data.speakers),
                        value="p225",
                        label="Speaker Identity",
                        type="index"),
            gr.Slider(0, 65536, value=0, step=1, label="random seed"),
            gr.Slider(-15, 15, value=0, step=1, label="scope-shift"),
            gr.Slider(0.5, 2., value=1., step=0.1,
                      label="duration multiplier"),
        ]
        self.outputs = [
            gr.Textbox(label="Phonemes"),
            gr.Audio(type="numpy", label="Output audio")
        ]
        description = "Welcome to the Gradio demo for PITS: Variational Pitch Inference without Fundamental Frequency for End-to-End Pitch-controllable TTS.\n In this demo, we utilize an open-source G2P library (g2p_en) with stress removing, instead of our internal G2P.\n You can fix the latent z by controlling random seed.\n You can shift the pitch scope, but please note that this is opposite to pitch-shift. In addition, it is cropped from fixed z so please check pitch-controllability by comparing with normal synthesis.\n Thank you for trying out our PITS demo!"
        article = "Github:https://github.com/anonymous-pits/pits \n Our current preprint contains several errors. Please wait for next update."
        # Each example supplies only the text input.
        examples = [["This is a demo page of the PITS."], ["I love hugging face."]]
        return gr.Interface(
            fn=self.inference,
            inputs=self.inputs,
            outputs=self.outputs,
            title=title,
            description=description,
            article=article,
            cache_examples=False,
            examples=examples,
        )

    def launch(self):
        """Start the Gradio server (no public share link)."""
        return self.interface.launch(share=False)


def parsearg(argv=None):
    """Parse command-line options for the demo.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before.

    Returns:
        ``argparse.Namespace`` with ``config``, ``model``, ``checkpoint_path``,
        ``force_resume`` and ``dir`` attributes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-c',
                        '--config',
                        type=str,
                        default="./configs/config_en.yaml",
                        help='Path to configuration file')
    parser.add_argument('-m',
                        '--model',
                        type=str,
                        default='PITS',
                        help='Model name')
    parser.add_argument('-r',
                        '--checkpoint_path',
                        type=str,
                        default='./logs/pits_vctk_AD_3000.pth',
                        help='Path to checkpoint for resume')
    parser.add_argument('-f',
                        '--force_resume',
                        type=str,
                        help='Path to checkpoint for force resume')
    parser.add_argument('-d',
                        '--dir',
                        type=str,
                        default='/DATA/audio/pits_samples',
                        help='root dir')
    return parser.parse_args(argv)

if __name__ == "__main__":
    # Script entry point: read CLI options, load the model, serve the UI.
    args = parsearg()
    (app := GradioApp(args)).launch()