from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer, AutoConfig
from encodec import EncodecModel
import torch

class GPTTTS(PreTrainedModel):
    # Text-to-speech wrapper: a fine-tuned distilgpt2 generates Encodec codes,
    # which the Encodec decoder then turns into a waveform.
    def __init__(self, *model_args, **model_kwargs):
        super().__init__(AutoConfig.from_pretrained("Ekgren/distilgpt2-finetuned-common-voice"), *model_args, **model_kwargs)
        self.model = AutoModelForCausalLM.from_pretrained("Ekgren/distilgpt2-finetuned-common-voice")
        # 24 kHz Encodec at 1.5 kbps bandwidth -> 2 codebooks per frame.
        self.encodec_model = EncodecModel.encodec_model_24khz()
        self.encodec_model.set_target_bandwidth(1.5)
        self.sample_rate = self.encodec_model.sample_rate
    
    def forward(self, input_ids):
        # Sample a continuation; the fine-tuned model emits audio token ids
        # appended after the GPT-2 text vocabulary.
        tokens = self.model.generate(input_ids, do_sample=True, max_length=1024, temperature=1.0, top_k=50, top_p=0.95)[0]

        # GPT-2's vocabulary covers ids 0..50256, so ids >= 50257 are audio
        # tokens; shift them back into Encodec code space. (Using >= rather
        # than > so that audio code 0, id 50257, is not dropped.)
        audio_tokens = [token - 50257 for token in tokens if token >= 50257]

        # The generated stream is interleaved sample-major across codebooks:
        # [s0_c0, s0_c1, s1_c0, s1_c1, ...]. Regroup it into the
        # (batch, codebook, sample) layout the Encodec decoder expects.
        number_of_codebooks = 2
        number_of_samples = len(audio_tokens) // number_of_codebooks
        frame = torch.zeros(1, number_of_codebooks, number_of_samples, dtype=torch.long)
        for sample in range(number_of_samples):
            for codebook in range(number_of_codebooks):
                frame[0, codebook, sample] = audio_tokens[sample * number_of_codebooks + codebook]

        frames = [(frame, None)]  # (codes, scale); scale is None for the 24 kHz model

        with torch.no_grad():
            wav = self.encodec_model.decode(frames)

        return wav[0, :, :]  # drop the batch dimension -> (channels, samples)

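# --- Illustrative sketch (not used by the classes above) ---
# A plain-Python mirror of the two post-processing steps in GPTTTS.forward,
# so the indexing is easy to verify by hand: ids at or above the GPT-2
# vocabulary size (50257) are Encodec codes, and the flat sample-major
# stream is regrouped into one row per codebook. Function names here are
# illustrative, not part of the original code.

def extract_audio_tokens(token_ids, vocab_size=50257):
    # Keep only ids in the extended (audio) range, shifted back to code space.
    return [t - vocab_size for t in token_ids if t >= vocab_size]

def group_by_codebook(audio_tokens, num_codebooks=2):
    # Input is interleaved as [s0_c0, s0_c1, s1_c0, s1_c1, ...];
    # output row/column layout matches frame[0, codebook, sample] above.
    num_samples = len(audio_tokens) // num_codebooks
    return [
        [audio_tokens[s * num_codebooks + c] for s in range(num_samples)]
        for c in range(num_codebooks)
    ]

# Example: extract_audio_tokens([15, 50257, 50300, 50258, 50301, 7]) -> [0, 43, 1, 44]
#          group_by_codebook([0, 43, 1, 44]) -> [[0, 1], [43, 44]]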

class GPTTTSTokenizer(PreTrainedTokenizer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.tokenizer = AutoTokenizer.from_pretrained("anforsm/distilgpt2-finetuned-common-voice")
        # GPT-2 has no pad token; reuse end-of-sequence for padding.
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

    def tokenize(self, text, *args, **kwargs):
        # Wrap the raw text in the prompt format the model was fine-tuned on.
        prompt = f"text: {text}\nsound:"
        return self.tokenizer(prompt, return_tensors="pt")

    def _tokenize(self, *args, **kwargs):
        return self.tokenize(*args, **kwargs)

    def convert_tokens_to_ids(self, tokens):
        # `tokens` is the BatchEncoding returned by tokenize() above.
        return tokens["input_ids"]

    def convert_ids_to_tokens(self, ids):
        return self.tokenizer.decode(ids[0], skip_special_tokens=True)

    def _batch_encode_plus(self, *args, **kwargs):
        return self.tokenize(*args, **kwargs)

    def _encode_plus(self, *args, **kwargs):
        return self.tokenize(*args, **kwargs)

    def save_vocabulary(self, *args, **kwargs):
        return self.tokenizer.save_vocabulary(*args, **kwargs)
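

# --- Hypothetical usage sketch ---
# Assumes the referenced Hugging Face checkpoints can be downloaded; the
# function name and output path are illustrative, not part of the original
# code. torchaudio is imported lazily since the file itself does not need it.

def synthesize(text, out_path="output.wav"):
    import torchaudio

    tokenizer = GPTTTSTokenizer()
    model = GPTTTS()
    inputs = tokenizer.tokenize(text)   # BatchEncoding holding "input_ids"
    wav = model(inputs["input_ids"])    # (channels, samples) tensor at 24 kHz
    torchaudio.save(out_path, wav, model.sample_rate)
    return out_path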