jayparmr commited on
Commit
a05f11d
·
1 Parent(s): 64e2b0f

Update infer/lib/train/process_ckpt.py

Browse files
Files changed (1) hide show
  1. infer/lib/train/process_ckpt.py +77 -6
infer/lib/train/process_ckpt.py CHANGED
@@ -10,10 +10,53 @@ from i18n.i18n import I18nAuto
10
  i18n = I18nAuto()
11
 
12
 
13
- def savee(ckpt, sr, if_f0, name, epoch, version, hps):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  try:
15
- print("ckpt, sr, if_f0, name, epoch, version, hps", ckpt, sr, if_f0, name, epoch, version, hps)
16
- opt = OrderedDict()
17
  opt["weight"] = {}
18
  for key in ckpt.keys():
19
  if "enc_q" in key:
@@ -43,11 +86,39 @@ def savee(ckpt, sr, if_f0, name, epoch, version, hps):
43
  opt["sr"] = sr
44
  opt["f0"] = if_f0
45
  opt["version"] = version
46
- torch.save(opt, "assets/weights/%s.pth" % name)
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  return "Success."
49
- except:
50
- return traceback.format_exc()
 
51
 
52
 
53
  def show_info(path):
 
10
  i18n = I18nAuto()
11
 
12
 
13
+ # def savee(ckpt, sr, if_f0, name, epoch, version, hps):
14
+ # try:
15
+ # opt = OrderedDict()
16
+ # opt["weight"] = {}
17
+ # for key in ckpt.keys():
18
+ # if "enc_q" in key:
19
+ # continue
20
+ # opt["weight"][key] = ckpt[key].half()
21
+ # opt["config"] = [
22
+ # hps.data.filter_length // 2 + 1,
23
+ # 32,
24
+ # hps.model.inter_channels,
25
+ # hps.model.hidden_channels,
26
+ # hps.model.filter_channels,
27
+ # hps.model.n_heads,
28
+ # hps.model.n_layers,
29
+ # hps.model.kernel_size,
30
+ # hps.model.p_dropout,
31
+ # hps.model.resblock,
32
+ # hps.model.resblock_kernel_sizes,
33
+ # hps.model.resblock_dilation_sizes,
34
+ # hps.model.upsample_rates,
35
+ # hps.model.upsample_initial_channel,
36
+ # hps.model.upsample_kernel_sizes,
37
+ # hps.model.spk_embed_dim,
38
+ # hps.model.gin_channels,
39
+ # hps.data.sampling_rate,
40
+ # ]
41
+ # opt["info"] = "%sepoch" % epoch
42
+ # opt["sr"] = sr
43
+ # opt["f0"] = if_f0
44
+ # opt["version"] = version
45
+ # torch.save(opt, "assets/weights/%s.pth" % name)
46
+
47
+ # return "Success."
48
+ # except:
49
+ # return traceback.format_exc()
50
+
51
+ import os
52
+ import torch
53
+ from google.oauth2 import service_account
54
+ from googleapiclient.discovery import build
55
+ from googleapiclient.http import MediaFileUpload
56
+
57
+ def savee(ckpt, sr, if_f0, name, epoch, version, hps, credentials_path):
58
  try:
59
+ opt = {}
 
60
  opt["weight"] = {}
61
  for key in ckpt.keys():
62
  if "enc_q" in key:
 
86
  opt["sr"] = sr
87
  opt["f0"] = if_f0
88
  opt["version"] = version
 
89
 
90
+ # Save the checkpoint to a local file
91
+ checkpoint_file = "assets/weights/%s.pth" % name
92
+ torch.save(opt, checkpoint_file)
93
+
94
+ # Set up Google Drive API and upload the file
95
+ SCOPES = ['https://www.googleapis.com/auth/drive.file']
96
+ creds = service_account.Credentials.from_service_account_file(credentials_path, scopes=SCOPES)
97
+
98
+ drive_service = build('drive', 'v3', credentials=creds)
99
+
100
+ # Specify the file name for Google Drive
101
+ drive_file_name = 'your_drive_file_name_here'
102
+
103
+ # Create a MediaFileUpload object to upload the file
104
+ media = MediaFileUpload(checkpoint_file, resumable=True)
105
+
106
+ # Create a file on Google Drive
107
+ file_metadata = {'name': drive_file_name}
108
+ file = drive_service.files().create(body=file_metadata, media_body=media, fields='id').execute()
109
+
110
+ # Make the uploaded file publicly accessible (optional)
111
+ drive_service.permissions().create(
112
+ fileId=file['id'],
113
+ body={'type': 'anyone', 'role': 'reader'}
114
+ ).execute()
115
+
116
+ print(f'File ID: {file.get("id")}')
117
+ print(file)
118
  return "Success."
119
+ except Exception as e:
120
+ print("Error:", str(e))
121
+
122
 
123
 
124
  def show_info(path):