uvr5 / download.py
lorneluo's picture
model downloader
83c8e0b
raw
history blame
2.99 kB
import json
import os
import sys
import urllib
import urllib.request
from pprint import pprint

from tqdm import tqdm
import wget

from UVR import DEMUCS_NEWER_REPO_DIR, VR_MODELS_DIR, MDX_MODELS_DIR
from gui_data.constants import DOWNLOAD_CHECKS, NORMAL_REPO, UPDATE_REPO
# Fetch the remote model catalogue once, at import time. The ``with`` block
# closes the HTTP response as soon as the JSON has been parsed instead of
# leaving it open for the lifetime of the process.
with urllib.request.urlopen(DOWNLOAD_CHECKS) as _catalogue_response:
    online_model_data = json.load(_catalogue_response)

# All MDX-family models share one lookup table. Because of dict unpacking
# order, an entry from a later list overrides an identically named entry
# from an earlier one.
mdx_download_list = {
    **online_model_data["mdx_download_list"],
    **online_model_data["mdx23c_download_list"],
    **online_model_data["mdx23_download_list"],
    # VIP lists are currently disabled; uncomment to include them.
    # **online_model_data["mdx_download_vip_list"],
    # **online_model_data["mdx23c_download_vip_list"],
}
vr_download_list = online_model_data["vr_download_list"]
demucs_download_list = online_model_data["demucs_download_list"]
def get_mdx_model_file(model):
    """Return only the local destination path for an MDX-family model.

    Thin wrapper over ``get_mdx_model_filelist``: takes the path half of the
    first ``(path, url)`` pair and discards the URL.
    """
    path, _ = get_mdx_model_filelist(model)[0]
    return path
def get_mdx_model_filelist(model):
    """Return ``[(local_path, url)]`` for an MDX-family model.

    The catalogue entry for *model* is either a filename or a dict. For a
    dict, only its first key is used as the filename and its values are
    ignored here. Any other entry is converted with ``str``.

    Raises:
        KeyError: if *model* is not a key of ``mdx_download_list``.
    """
    entry = mdx_download_list[model]
    model_name = list(entry)[0] if isinstance(entry, dict) else str(entry)
    local_path = os.path.join(MDX_MODELS_DIR, model_name)
    return [(local_path, f"{NORMAL_REPO}{model_name}")]
def get_vr_model_file(model):
    """Return only the local destination path for a VR model.

    Thin wrapper over ``get_vr_model_filelist``: takes the path half of the
    first ``(path, url)`` pair and discards the URL.
    """
    path, _ = get_vr_model_filelist(model)[0]
    return path
def get_vr_model_filelist(model):
    """Return ``[(local_path, url)]`` for a VR model.

    The catalogue value is used directly as the filename, both under
    ``VR_MODELS_DIR`` and appended to ``NORMAL_REPO``.

    Raises:
        KeyError: if *model* is not a key of ``vr_download_list``.
    """
    file_name = vr_download_list[model]
    source_url = f"{NORMAL_REPO}{file_name}"
    return [(os.path.join(VR_MODELS_DIR, file_name), source_url)]
def get_demucs_model_file(model):
    """Return the local path of the ``.yaml`` file among a Demucs model's files.

    The suffix check is case-insensitive. The first match in catalogue order
    wins. Returns ``None`` when none of the model's files ends in ``.yaml``.
    """
    yaml_paths = (
        path
        for path, _ in get_demucs_model_filelist(model)
        if path.lower().endswith('.yaml')
    )
    return next(yaml_paths, None)
def get_demucs_model_filelist(model):
    """Return ``[(local_path, url), ...]`` for every file of a Demucs model.

    The catalogue entry for *model* maps filename -> download URL. Each
    filename is placed under ``DEMUCS_NEWER_REPO_DIR``, in the entry's
    iteration order.

    Raises:
        KeyError: if *model* is not a key of ``demucs_download_list``.
    """
    return [
        (os.path.join(DEMUCS_NEWER_REPO_DIR, name), source_url)
        for name, source_url in demucs_download_list[model].items()
    ]
def get_model_file(model_name):
    """Return the local path of the primary file for *model_name*.

    The name is looked up in the MDX, VR and Demucs catalogues, in that
    order. For Demucs models the returned path is that of the model's
    ``.yaml`` file, which may be ``None`` if the model has no such file.

    Raises:
        FileNotFoundError: if *model_name* is in none of the catalogues.
    """
    if model_name in mdx_download_list:
        return get_mdx_model_file(model_name)
    if model_name in vr_download_list:
        return get_vr_model_file(model_name)
    if model_name in demucs_download_list:
        return get_demucs_model_file(model_name)
    raise FileNotFoundError(f"Can't find model {model_name}")
def download_model(model_name):
    """Download every file of *model_name* that is not already on disk.

    The name is looked up in the MDX, VR and Demucs catalogues, in that
    order. Each ``(local_path, url)`` pair is fetched with ``wget.download``.
    A file that already exists is skipped on its own, so a multi-file model
    whose earlier download was interrupted is completed on the next call.

    Raises:
        FileNotFoundError: if *model_name* is in none of the catalogues.
    """
    if model_name in mdx_download_list:
        filelist = get_mdx_model_filelist(model_name)
    elif model_name in vr_download_list:
        filelist = get_vr_model_filelist(model_name)
    elif model_name in demucs_download_list:
        filelist = get_demucs_model_filelist(model_name)
    else:
        raise FileNotFoundError(f"Can't find model {model_name}")

    for model_path, url in filelist:
        if os.path.isfile(model_path):
            # Skip just this file, not the rest of the model's files.
            continue
        # Make sure the destination directory exists before writing into it.
        target_dir = os.path.dirname(model_path)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        print(f'Downloading from {url} to {model_path}')
        wget.download(url, model_path)
if __name__ == '__main__':
    # Command-line entry point: download the model named by the first
    # argument; any further arguments are ignored. Without an argument,
    # exit with a usage message instead of an IndexError traceback.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {os.path.basename(sys.argv[0])} MODEL_NAME")
    model_name = sys.argv[1]
    download_model(model_name)