import sys
import launch
import platform
import os
import shutil
import site
import glob
import re

# Directory containing this script; the kohya sd-scripts checkout is kept in a
# "kohya_ss" subdirectory next to it.
dirname = os.path.dirname(__file__)
repo_dir = os.path.join(dirname, "kohya_ss")

# First character that ends the bare package name on a requirements line: a
# version operator (==, >=, <, !=, ~=), an environment-marker ";", or
# whitespace (which also precedes inline "# ..." comments).
_REQUIREMENT_NAME_END = re.compile(r"[=<>!~;\s]")


def _strip_version_pins(requirements_text):
    """Return the package names listed in a requirements file, without pins.

    Blank lines, comment lines, pip option lines (anything starting with "-",
    e.g. "-r" or "--extra-index-url") and the local-package line "." are
    skipped. For every other line, everything from the first version operator,
    ";" marker or whitespace onward is dropped, so "diffusers[torch]>=0.10"
    becomes "diffusers[torch]".
    """
    names = []
    for raw_line in requirements_text.splitlines():
        line = raw_line.strip()
        if not line or line == "." or line.startswith(("#", "-")):
            continue
        name = _REQUIREMENT_NAME_END.split(line, maxsplit=1)[0]
        if name:
            names.append(name)
    return names


def _run_git(args):
    """Run ``git <args>`` through launch.run with the cwd set to ``repo_dir``."""
    launch.run(f'cd "{repo_dir}" && {launch.git} {args}')


def _sync_repo(repo_url, branch, update, skip_checkout):
    """Clone or update the sd-scripts checkout at ``repo_dir``.

    Args:
        repo_url: git URL cloned when ``repo_dir`` does not exist yet.
        branch: branch checked out afterwards; with ``update`` the working tree
            is also hard-reset to ``origin/<branch>``, so it must name a branch
            that exists on the remote.
        update: when true and the checkout exists, fetch and hard-reset it
            (local modifications are discarded).
        skip_checkout: when true, no ``git checkout`` is run.
    """
    if update and os.path.exists(repo_dir):
        _run_git("fetch --prune")
        # Reset to the configured branch; the previous hard-coded origin/main
        # moved a non-main branch onto main's commit.
        _run_git(f"reset --hard origin/{branch}")
    elif not os.path.exists(repo_dir):
        launch.run(f'{launch.git} clone "{repo_url}" "{repo_dir}"')

    if not skip_checkout:
        _run_git(f"checkout {branch}")


def _install_requirements(requirements_file, disable_strict_version):
    """Install the packages listed in sd-scripts' requirements file.

    Args:
        requirements_file: path of the requirements file, relative to
            ``repo_dir`` (an absolute path also works).
        disable_strict_version: when true, version pins are stripped and the
            bare package names are installed together with ``repo_dir`` itself
            (standing in for the file's "." line); otherwise
            ``pip install -r <requirements_file>`` is run inside ``repo_dir``.
    """
    if disable_strict_version:
        requirements_path = os.path.join(repo_dir, requirements_file)
        with open(requirements_path, "r", encoding="utf-8") as f:
            packages = _strip_version_pins(f.read())
        # Quote every name separately; quoting the whole space-joined list
        # handed pip a single invalid requirement string.
        package_args = " ".join(f'"{name}"' for name in packages)
        launch.run_pip(
            f'install {package_args} "{repo_dir}"',
            "requirements for kohya sd-scripts",
        )
    else:
        # Honour REQS_FILE here too (was hard-coded to requirements.txt).
        launch.run(
            f'cd "{repo_dir}" && "{launch.python}" -m pip install -r "{requirements_file}"',
            desc="Installing requirements for kohya sd-scripts",
            errdesc="Couldn't install requirements for kohya sd-scripts",
        )


def _patch_bitsandbytes_windows():
    """Copy sd-scripts' ``bitsandbytes_windows`` files over installed bitsandbytes.

    ``main.py`` is copied into ``bitsandbytes/cuda_setup``; every other file
    goes into the ``bitsandbytes`` package directory. Site-packages directories
    where the target directory does not exist are skipped.
    """
    for source in glob.glob(os.path.join(repo_dir, "bitsandbytes_windows", "*")):
        filename = os.path.basename(source)
        for site_dir in site.getsitepackages():
            if filename == "main.py":
                target = os.path.join(site_dir, "bitsandbytes", "cuda_setup", filename)
            else:
                target = os.path.join(site_dir, "bitsandbytes", filename)
            if os.path.exists(os.path.dirname(target)):
                shutil.copy(source, target)


def prepare_environment():
    """Install torch, helper packages and the kohya sd-scripts checkout.

    Environment variables:
        TORCH_COMMAND: pip command run as ``python -m <TORCH_COMMAND>`` to
            install torch/torchvision.
        SD_SCRIPTS_REPO: git URL of sd-scripts to clone.
        SD_SCRIPTS_BRANCH: branch to check out (and reset to with --update).
        REQS_FILE: requirements file, relative to the sd-scripts checkout.

    Command-line flags read through ``launch.extract_arg`` (the returned argv
    is assigned back to ``sys.argv``): --skip-install,
    --disable-strict-version, --skip-torch-cuda-test, --skip-checkout-repo,
    --update, --reinstall-xformers, --reinstall-torch. ``--xformers`` and
    ``--ngrok`` are only tested for presence and left in ``sys.argv``.
    """
    torch_command = os.environ.get(
        "TORCH_COMMAND",
        "pip install torch==2.0.0+cu118 torchvision==0.15.1+cu118 --extra-index-url https://download.pytorch.org/whl/cu118",
    )
    sd_scripts_repo = os.environ.get("SD_SCRIPTS_REPO", "https://github.com/kohya-ss/sd-scripts.git")
    sd_scripts_branch = os.environ.get("SD_SCRIPTS_BRANCH", "main")
    requirements_file = os.environ.get("REQS_FILE", "requirements.txt")

    # All flags are extracted before the --skip-install early return, so
    # sys.argv is rewritten by extract_arg in every case.
    sys.argv, skip_install = launch.extract_arg(sys.argv, "--skip-install")
    sys.argv, disable_strict_version = launch.extract_arg(sys.argv, "--disable-strict-version")
    sys.argv, skip_torch_cuda_test = launch.extract_arg(sys.argv, "--skip-torch-cuda-test")
    sys.argv, skip_checkout_repo = launch.extract_arg(sys.argv, "--skip-checkout-repo")
    sys.argv, update = launch.extract_arg(sys.argv, "--update")
    sys.argv, reinstall_xformers = launch.extract_arg(sys.argv, "--reinstall-xformers")
    sys.argv, reinstall_torch = launch.extract_arg(sys.argv, "--reinstall-torch")
    xformers = "--xformers" in sys.argv
    ngrok = "--ngrok" in sys.argv

    if skip_install:
        return

    # With --disable-strict-version the pinned TORCH_COMMAND is never run,
    # even when torch/torchvision are missing.
    if (
        reinstall_torch
        or not launch.is_installed("torch")
        or not launch.is_installed("torchvision")
    ) and not disable_strict_version:
        launch.run(
            f'"{launch.python}" -m {torch_command}',
            "Installing torch and torchvision",
            "Couldn't install torch",
        )

    if not skip_torch_cuda_test:
        launch.run_python(
            "import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'"
        )

    if xformers and (reinstall_xformers or not launch.is_installed("xformers")):
        launch.run_pip("install xformers --pre", "xformers")

    _sync_repo(sd_scripts_repo, sd_scripts_branch, update, skip_checkout_repo)

    if not launch.is_installed("gradio"):
        launch.run_pip("install gradio==3.16.2", "gradio")

    if ngrok and not launch.is_installed("pyngrok"):
        launch.run_pip("install pyngrok", "ngrok")

    if platform.system() == "Linux" and not launch.is_installed("triton"):
        launch.run_pip("install triton", "triton")

    _install_requirements(requirements_file, disable_strict_version)

    if platform.system() == "Windows":
        _patch_bitsandbytes_windows()


if __name__ == "__main__":
    prepare_environment()