# Uploaded by ehristoforu ("ehristoforu's picture" avatar on the hosting page)
# Commit message: "Upload folder using huggingface_hub"
# Commit 0163a2c (verified)
import sys
import launch
import platform
import os
import shutil
import site
import glob
import re
# Absolute directory containing this launcher. abspath() keeps the path valid
# even when __file__ is relative (e.g. "launch.py" yields dirname "").
dirname = os.path.dirname(os.path.abspath(__file__))
# Location of the sd-scripts clone managed by prepare_environment().
repo_dir = os.path.join(dirname, "kohya_ss")
def prepare_environment():
    """Install everything needed to run kohya sd-scripts.

    Configuration comes from environment variables: ``TORCH_COMMAND``,
    ``SD_SCRIPTS_REPO``, ``SD_SCRIPTS_BRANCH`` and ``REQS_FILE``.

    These flags are pulled out of ``sys.argv`` via ``launch.extract_arg``:
    ``--skip-install``, ``--disable-strict-version``,
    ``--skip-torch-cuda-test``, ``--skip-checkout-repo``, ``--update``,
    ``--reinstall-xformers`` and ``--reinstall-torch``.

    ``--xformers`` and ``--ngrok`` are only tested for presence and are
    left in ``sys.argv``.

    Returns early, right after the flags are processed, when
    ``--skip-install`` is given.
    """
    torch_command = os.environ.get(
        "TORCH_COMMAND",
        "pip install torch==2.0.0+cu118 torchvision==0.15.1+cu118 --extra-index-url https://download.pytorch.org/whl/cu118",
    )
    sd_scripts_repo = os.environ.get("SD_SCRIPTS_REPO", "https://github.com/kohya-ss/sd-scripts.git")
    sd_scripts_branch = os.environ.get("SD_SCRIPTS_BRANCH", "main")
    requirements_file = os.environ.get("REQS_FILE", "requirements.txt")

    sys.argv, skip_install = launch.extract_arg(sys.argv, "--skip-install")
    sys.argv, disable_strict_version = launch.extract_arg(sys.argv, "--disable-strict-version")
    sys.argv, skip_torch_cuda_test = launch.extract_arg(sys.argv, "--skip-torch-cuda-test")
    sys.argv, skip_checkout_repo = launch.extract_arg(sys.argv, "--skip-checkout-repo")
    sys.argv, update = launch.extract_arg(sys.argv, "--update")
    sys.argv, reinstall_xformers = launch.extract_arg(sys.argv, "--reinstall-xformers")
    sys.argv, reinstall_torch = launch.extract_arg(sys.argv, "--reinstall-torch")
    # Only inspected, not extracted. NOTE(review): presumably the UI
    # parses these flags again later; confirm before extracting them.
    xformers = "--xformers" in sys.argv
    ngrok = "--ngrok" in sys.argv

    if skip_install:
        return

    # With --disable-strict-version, torch is never installed here.
    if (
        reinstall_torch
        or not launch.is_installed("torch")
        or not launch.is_installed("torchvision")
    ) and not disable_strict_version:
        launch.run(
            f'"{launch.python}" -m {torch_command}',
            "Installing torch and torchvision",
            "Couldn't install torch",
        )

    if not skip_torch_cuda_test:
        launch.run_python(
            "import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'"
        )

    if xformers and (reinstall_xformers or not launch.is_installed("xformers")):
        launch.run_pip("install xformers --pre", "xformers")

    _sync_repo(
        sd_scripts_repo,
        sd_scripts_branch,
        update=update,
        checkout=not skip_checkout_repo,
    )

    if not launch.is_installed("gradio"):
        launch.run_pip("install gradio==3.16.2", "gradio")
    if ngrok and not launch.is_installed("pyngrok"):
        launch.run_pip("install pyngrok", "ngrok")
    if platform.system() == "Linux" and not launch.is_installed("triton"):
        launch.run_pip("install triton", "triton")

    _install_requirements(requirements_file, relaxed=disable_strict_version)

    if platform.system() == "Windows":
        _patch_bitsandbytes_windows()


# Separators that start a version constraint in a requirement line
# (covers ==, ===, ~=, <, <=, >, >=).
_VERSION_SPEC = re.compile(r"==|~=|<|>")


def _sync_repo(repo_url, branch, *, update, checkout):
    """Clone sd-scripts into ``repo_dir``, or refresh an existing clone.

    A missing clone is created from *repo_url*. When *update* is true and
    the clone exists, it is fetched and hard-reset to ``origin/<branch>``,
    which discards local changes. When *checkout* is true, *branch* is
    checked out.
    """
    in_repo = f'cd "{repo_dir}" && {launch.git}'
    if not os.path.exists(repo_dir):
        launch.run(f'{launch.git} clone {repo_url} "{repo_dir}"')
    elif update:
        launch.run(f"{in_repo} fetch --prune")
        if checkout:
            # Switch first, so the reset below moves the requested branch
            # instead of whichever branch happened to be checked out.
            # -f is consistent with the hard reset: --update throws away
            # local changes anyway, so they must not block the switch.
            launch.run(f"{in_repo} checkout -f {branch}")
        launch.run(f"{in_repo} reset --hard origin/{branch}")
        return
    if checkout:
        launch.run(f"{in_repo} checkout {branch}")


def _unpinned_requirements(text):
    """Return pip arguments from requirements-file *text* with version pins removed.

    Skips blank lines, ``#`` comment lines, and the local package
    (``.`` / ``-e .``), which is installed separately from ``repo_dir``.
    Pip option lines such as ``--extra-index-url URL`` are split into
    their individual tokens.
    """
    args = []
    for raw in text.splitlines():
        line = raw.strip()
        if not line or line.startswith("#"):
            continue
        tokens = line.split()
        if tokens in (["."], ["-e", "."]):
            continue
        if line.startswith("-"):
            args.extend(tokens)
        else:
            args.append(_VERSION_SPEC.split(line, 1)[0].strip())
    return args


def _install_requirements(requirements_file, *, relaxed):
    """Install sd-scripts requirements from *requirements_file* (relative to ``repo_dir``).

    When *relaxed* is true, version constraints are stripped and the
    unpinned packages are installed together with ``repo_dir`` itself.
    Otherwise pip installs the file as-is.
    """
    if relaxed:
        path = os.path.join(repo_dir, requirements_file)
        with open(path, "r", encoding="utf-8") as f:
            args = _unpinned_requirements(f.read())
        # Quote each argument on its own so pip gets separate requirements.
        # Assumes launch.run_pip hands its args to a shell, as the quoted
        # repo_dir already implies. TODO confirm.
        quoted = " ".join(f'"{arg}"' for arg in args)
        launch.run_pip(
            f'install {quoted} "{repo_dir}"',
            "requirements for kohya sd-scripts",
        )
    else:
        launch.run(
            f'cd "{repo_dir}" && "{launch.python}" -m pip install -r "{requirements_file}"',
            desc="Installing requirements for kohya sd-scripts",
            errdesc="Couldn't install requirements for kohya sd-scripts",
        )


def _patch_bitsandbytes_windows():
    """Copy sd-scripts' ``bitsandbytes_windows`` files over an installed bitsandbytes.

    ``main.py`` goes into ``bitsandbytes/cuda_setup``; every other file goes
    into the package root. Site-packages directories without a matching
    target directory are skipped.
    """
    for src in glob.glob(os.path.join(repo_dir, "bitsandbytes_windows", "*")):
        name = os.path.basename(src)
        for site_dir in site.getsitepackages():
            package_dir = os.path.join(site_dir, "bitsandbytes")
            if name == "main.py":
                dest = os.path.join(package_dir, "cuda_setup", name)
            else:
                dest = os.path.join(package_dir, name)
            if os.path.exists(os.path.dirname(dest)):
                shutil.copy(src, dest)
if __name__ == "__main__":
    # Bootstrap dependencies only when run as a script, never on import.
    prepare_environment()