File size: 2,450 Bytes
9a85c9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import json
from random import shuffle

import tqdm
from text.cleaner import clean_text
from collections import defaultdict
import shutil
# Pipeline stages to run:
#   1 -> normalize transcriptions with clean_text into "<transcription_path>.cleaned"
#   2 -> assign speaker ids and split the cleaned lines into train/val lists
#   3 -> write n_speakers / spk2id into the JSON config (needs stage 2 in the same run)
stage = [1, 2, 3]

# Input transcription list; each line is split as "utt|spk|language|text".
transcription_path = 'filelists/short_character_anno.list'
# Output filelists produced by stage 2.
train_path = 'filelists/train.list'
val_path = 'filelists/val.list'
# JSON config whose "data" section is updated by stage 3.
config_path = "configs/config.json"
# Number of (shuffled) utterances per speaker placed in the validation list.
val_per_spk = 4
# Upper bound on the total validation list size; any surplus is returned to training.
max_val_total = 8

# Stage 1: run clean_text on every transcription line and write
# "<transcription_path>.cleaned" with one line per utterance:
#   utt|spk|language|norm_text|phones|tones|word2ph
# where phones, tones and word2ph are space-separated.
# Lines that fail to parse or clean are reported and skipped (best effort).
if 1 in stage:
    # Read the input inside a `with` so the handle is closed promptly
    # (the previous inline open() was never closed).
    with open(transcription_path, encoding='utf-8') as src:
        raw_lines = src.readlines()
    with open(transcription_path + '.cleaned', 'w', encoding='utf-8') as f:
        for line in tqdm.tqdm(raw_lines):
            try:
                utt, spk, language, text = line.strip().split('|')
                norm_text, phones, tones, word2ph = clean_text(text, language)
            except Exception as e:
                # Report the raw line rather than `utt`: when the split itself fails,
                # `utt` is unbound (or stale from the previous iteration).
                # Catching Exception (not a bare except) lets Ctrl-C still abort the run.
                print("err!", line.strip(), repr(e))
                continue
            f.write('{}|{}|{}|{}|{}|{}|{}\n'.format(utt, spk, language, norm_text, ' '.join(phones),
                                                 " ".join([str(i) for i in tones]),
                                                 " ".join([str(i) for i in word2ph])))

# Stage 2: read the cleaned list, give each speaker an integer id in order of first
# appearance, and split every speaker's utterances into train/val lists.
# Defines module-level `spk_id_map` and `current_sid`, which stage 3 reads.
if 2 in stage:
    cleaned_path = transcription_path + '.cleaned'
    # When True, train_path ends up holding the COMPLETE cleaned list, so validation
    # utterances are trained on too. This reproduces the previous behavior, where an
    # unconditional shutil.copy overwrote the freshly written train split. Set it to
    # False to keep the validation utterances out of training.
    train_on_all = True

    spk_utt_map = defaultdict(list)
    spk_id_map = {}

    with open(cleaned_path, encoding='utf-8') as f:
        for line in f:
            fields = line.strip().split('|')
            # Stage 1 writes exactly 7 '|'-separated fields per line.
            if len(fields) != 7:
                raise ValueError('expected 7 fields in {}, got {}: {!r}'.format(
                    cleaned_path, len(fields), line))
            spk = fields[1]
            spk_utt_map[spk].append(line)
            # New speakers get the next free id (0, 1, 2, ...).
            spk_id_map.setdefault(spk, len(spk_id_map))
    # Number of distinct speakers seen.
    current_sid = len(spk_id_map)

    train_list = []
    val_list = []
    for utts in spk_utt_map.values():
        shuffle(utts)
        val_list += utts[:val_per_spk]
        train_list += utts[val_per_spk:]
    # Cap the total validation size; the surplus goes back into training.
    if len(val_list) > max_val_total:
        train_list += val_list[max_val_total:]
        val_list = val_list[:max_val_total]

    if train_on_all:
        shutil.copy(cleaned_path, train_path)
    else:
        with open(train_path, 'w', encoding='utf-8') as f:
            f.writelines(train_list)

    with open(val_path, 'w', encoding='utf-8') as f:
        f.writelines(val_list)

# Stage 3: record the speaker count and the speaker->id mapping built by stage 2
# in the "data" section of the JSON config, rewriting the file in place.
if 3 in stage:
    # spk_id_map / current_sid only exist if stage 2 ran in this same invocation.
    # An explicit raise (unlike `assert`) is not stripped when Python runs with -O.
    if 2 not in stage:
        raise RuntimeError("stage 3 requires stage 2 to run in the same invocation")
    # Read with the same encoding used for writing below; the platform default
    # encoding can fail on non-ASCII speaker names stored with ensure_ascii=False.
    with open(config_path, encoding='utf-8') as f:
        config = json.load(f)
    config['data']["n_speakers"] = current_sid
    config["data"]['spk2id'] = spk_id_map
    with open(config_path, 'w', encoding='utf-8') as f:
        json.dump(config, f, indent=2, ensure_ascii=False)