Spaces:
Runtime error
Runtime error
Update infer/lib/train/process_ckpt.py
Browse files
infer/lib/train/process_ckpt.py
CHANGED
@@ -10,10 +10,53 @@ from i18n.i18n import I18nAuto
|
|
10 |
i18n = I18nAuto()
|
11 |
|
12 |
|
13 |
-
def savee(ckpt, sr, if_f0, name, epoch, version, hps):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
try:
|
15 |
-
|
16 |
-
opt = OrderedDict()
|
17 |
opt["weight"] = {}
|
18 |
for key in ckpt.keys():
|
19 |
if "enc_q" in key:
|
@@ -43,11 +86,39 @@ def savee(ckpt, sr, if_f0, name, epoch, version, hps):
|
|
43 |
opt["sr"] = sr
|
44 |
opt["f0"] = if_f0
|
45 |
opt["version"] = version
|
46 |
-
torch.save(opt, "assets/weights/%s.pth" % name)
|
47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
return "Success."
|
49 |
-
except:
|
50 |
-
|
|
|
51 |
|
52 |
|
53 |
def show_info(path):
|
|
|
10 |
i18n = I18nAuto()
|
11 |
|
12 |
|
13 |
+
# def savee(ckpt, sr, if_f0, name, epoch, version, hps):
|
14 |
+
# try:
|
15 |
+
# opt = OrderedDict()
|
16 |
+
# opt["weight"] = {}
|
17 |
+
# for key in ckpt.keys():
|
18 |
+
# if "enc_q" in key:
|
19 |
+
# continue
|
20 |
+
# opt["weight"][key] = ckpt[key].half()
|
21 |
+
# opt["config"] = [
|
22 |
+
# hps.data.filter_length // 2 + 1,
|
23 |
+
# 32,
|
24 |
+
# hps.model.inter_channels,
|
25 |
+
# hps.model.hidden_channels,
|
26 |
+
# hps.model.filter_channels,
|
27 |
+
# hps.model.n_heads,
|
28 |
+
# hps.model.n_layers,
|
29 |
+
# hps.model.kernel_size,
|
30 |
+
# hps.model.p_dropout,
|
31 |
+
# hps.model.resblock,
|
32 |
+
# hps.model.resblock_kernel_sizes,
|
33 |
+
# hps.model.resblock_dilation_sizes,
|
34 |
+
# hps.model.upsample_rates,
|
35 |
+
# hps.model.upsample_initial_channel,
|
36 |
+
# hps.model.upsample_kernel_sizes,
|
37 |
+
# hps.model.spk_embed_dim,
|
38 |
+
# hps.model.gin_channels,
|
39 |
+
# hps.data.sampling_rate,
|
40 |
+
# ]
|
41 |
+
# opt["info"] = "%sepoch" % epoch
|
42 |
+
# opt["sr"] = sr
|
43 |
+
# opt["f0"] = if_f0
|
44 |
+
# opt["version"] = version
|
45 |
+
# torch.save(opt, "assets/weights/%s.pth" % name)
|
46 |
+
|
47 |
+
# return "Success."
|
48 |
+
# except:
|
49 |
+
# return traceback.format_exc()
|
50 |
+
|
51 |
+
import os
|
52 |
+
import torch
|
53 |
+
from google.oauth2 import service_account
|
54 |
+
from googleapiclient.discovery import build
|
55 |
+
from googleapiclient.http import MediaFileUpload
|
56 |
+
|
57 |
+
def savee(ckpt, sr, if_f0, name, epoch, version, hps, credentials_path):
|
58 |
try:
|
59 |
+
opt = {}
|
|
|
60 |
opt["weight"] = {}
|
61 |
for key in ckpt.keys():
|
62 |
if "enc_q" in key:
|
|
|
86 |
opt["sr"] = sr
|
87 |
opt["f0"] = if_f0
|
88 |
opt["version"] = version
|
|
|
89 |
|
90 |
+
# Save the checkpoint to a local file
|
91 |
+
checkpoint_file = "assets/weights/%s.pth" % name
|
92 |
+
torch.save(opt, checkpoint_file)
|
93 |
+
|
94 |
+
# Set up Google Drive API and upload the file
|
95 |
+
SCOPES = ['https://www.googleapis.com/auth/drive.file']
|
96 |
+
creds = service_account.Credentials.from_service_account_file(credentials_path, scopes=SCOPES)
|
97 |
+
|
98 |
+
drive_service = build('drive', 'v3', credentials=creds)
|
99 |
+
|
100 |
+
# Specify the file name for Google Drive
|
101 |
+
drive_file_name = 'your_drive_file_name_here'
|
102 |
+
|
103 |
+
# Create a MediaFileUpload object to upload the file
|
104 |
+
media = MediaFileUpload(checkpoint_file, resumable=True)
|
105 |
+
|
106 |
+
# Create a file on Google Drive
|
107 |
+
file_metadata = {'name': drive_file_name}
|
108 |
+
file = drive_service.files().create(body=file_metadata, media_body=media, fields='id').execute()
|
109 |
+
|
110 |
+
# Make the uploaded file publicly accessible (optional)
|
111 |
+
drive_service.permissions().create(
|
112 |
+
fileId=file['id'],
|
113 |
+
body={'type': 'anyone', 'role': 'reader'}
|
114 |
+
).execute()
|
115 |
+
|
116 |
+
print(f'File ID: {file.get("id")}')
|
117 |
+
print(file)
|
118 |
return "Success."
|
119 |
+
except Exception as e:
|
120 |
+
print("Error:", str(e))
|
121 |
+
|
122 |
|
123 |
|
124 |
def show_info(path):
|