import json |
import os |
import platform |
import time |
from pathlib import Path |
import folder_paths |
import nodes |
class Loader:
    """Startup helper for the Allor node pack.

    One public method per responsibility:

    * setup_config     -- create ``config.json`` from ``resources/template.json``
                          or resync its key structure with the template.
    * setup_timestamp  -- make sure ``resources/timestamp.json`` exists.
    * check_updates    -- optionally fetch / self-update the git checkout.
    * setup_rembg      -- point ``U2NET_HOME`` at the models "onnx" folder.
    * setup_paths      -- register the "onnx" and "fonts" folders in the host's
                          ``folder_paths`` registry.
    * setup_override   -- drop built-in nodes from ``nodes.NODE_CLASS_MAPPINGS``
                          according to ``config["override"]``.
    * get_modules      -- import the enabled Allor modules and merge their
                          ``NODE_CLASS_MAPPINGS``.
    """

    def __init__(self):
        pass

    __ROOT_PATH = os.path.dirname(os.path.abspath(__file__))
    __TEMPLATE_PATH = os.path.join(__ROOT_PATH, "resources/template.json")
    __TIMESTAMP_PATH = os.path.join(__ROOT_PATH, "resources/timestamp.json")
    __CONFIG_PATH = os.path.join(__ROOT_PATH, "config.json")
    __GIT_PATH = Path(os.path.join(__ROOT_PATH, ".git"))

    # Minimum time between update checks, in seconds, for the periodic
    # "update_frequency" values. check_updates() needs all three; a "month"
    # is approximated as 30 days.
    __DAY_SECONDS = 24 * 60 * 60
    __WEEK_SECONDS = 7 * __DAY_SECONDS
    __MONTH_SECONDS = 30 * __DAY_SECONDS

    # Allor modules that can be toggled under config["modules"], in the order
    # their NODE_CLASS_MAPPINGS are merged (later entries win on name clashes).
    __MODULE_NAMES = (
        "AlphaChanel",
        "Clamp",
        "ImageBatch",
        "ImageComposite",
        "ImageContainer",
        "ImageDraw",
        "ImageEffects",
        "ImageFilter",
        "ImageNoise",
        "ImageSegmentation",
        "ImageText",
        "ImageTransform",
    )

    # --- console output (ANSI-coloured "[Allor]" prefix) -------------------

    def __log(self, text):
        """Print an informational message with a green prefix."""
        print("\033[92m[Allor]\033[0m: " + text)

    def __error(self, text):
        """Print an error message with a red prefix."""
        print("\033[91m[Allor]\033[0m: " + text)

    def __notification(self, text):
        """Print a notification with a blue prefix."""
        print("\033[94m[Allor]\033[0m: " + text)

    def __new_line(self):
        """Print an empty line."""
        print()

    # --- JSON file helpers --------------------------------------------------

    def __create_config(self):
        """Write a fresh config.json containing the template defaults."""
        with open(self.__CONFIG_PATH, "w", encoding="utf-8") as f:
            json.dump(self.__template(), f, ensure_ascii=False, indent=4)

    def __create_timestamp(self):
        """Write timestamp.json with timestamp 0, so any periodic check is due."""
        with open(self.__TIMESTAMP_PATH, "w", encoding="utf-8") as f:
            json.dump({"timestamp": 0}, f, ensure_ascii=False, indent=4)

    # Readers use utf-8 explicitly: the writers emit utf-8 with
    # ensure_ascii=False, and the platform default encoding may differ.
    def __get_template(self):
        """Return the config template without its documentation-only "__comment" key."""
        with open(self.__TEMPLATE_PATH, "r", encoding="utf-8") as f:
            template = json.load(f)
        template.pop("__comment", None)
        return template

    def __get_config(self):
        """Return the parsed contents of config.json."""
        with open(self.__CONFIG_PATH, "r", encoding="utf-8") as f:
            return json.load(f)

    def __get_timestamp(self):
        """Return the parsed contents of timestamp.json (``{"timestamp": <float>}``)."""
        with open(self.__TIMESTAMP_PATH, "r", encoding="utf-8") as f:
            return json.load(f)

    def __update_config(self, template, source):
        """Rewrite config.json so that its structure matches *template*.

        The result lists keys in template order and drops keys the template
        does not have. Missing keys receive the template default. A user value
        is kept only when its shape agrees with the template (a dict where the
        template has a section, a non-dict where it has a leaf). Otherwise the
        template default replaces it, so the key comparison done by
        setup_config() succeeds on the next run.

        *source* is not modified.
        """

        def merge(tmpl, src):
            merged = {}
            for key, default in tmpl.items():
                if isinstance(default, dict):
                    current = src.get(key)
                    # A section the user replaced with a scalar is rebuilt
                    # from the template defaults instead of crashing.
                    merged[key] = merge(default, current if isinstance(current, dict) else {})
                elif key in src and not isinstance(src[key], dict):
                    merged[key] = src[key]
                else:
                    merged[key] = default
            return merged

        if not isinstance(source, dict):
            source = {}

        with open(self.__CONFIG_PATH, "w", encoding="utf-8") as f:
            json.dump(merge(template, source), f, ensure_ascii=False, indent=4)

    def __update_timestamp(self):
        """Record the current time as the moment of the last update check."""
        with open(self.__TIMESTAMP_PATH, "w", encoding="utf-8") as f:
            json.dump({"timestamp": time.time()}, f, ensure_ascii=False, indent=4)

    # Short aliases: each call re-reads the file from disk.
    __template = __get_template
    __config = __get_config
    __timestamp = __get_timestamp

    # --- fonts ---------------------------------------------------------------

    def __get_fonts_folder_path(self):
        """Return the existing font directories to register under "fonts".

        The configured ``fonts.folder_path`` is interpreted relative to
        ``folder_paths.base_path`` and is created if missing. OS-wide and
        per-user font directories are added when ``system_fonts`` /
        ``user_fonts`` are enabled. Only paths that exist are returned.
        """
        fonts_config = self.__config()["fonts"]
        system = platform.system()
        user_home = os.path.expanduser("~")

        # The configured path may use either separator; split it so
        # os.path.join rebuilds it with the native one.
        relative_parts = fonts_config["folder_path"].replace("\\", "/").split("/")
        config_font_path = os.path.join(folder_paths.base_path, *relative_parts)
        os.makedirs(config_font_path, exist_ok=True)

        paths = [config_font_path]

        if fonts_config["system_fonts"]:
            if system == "Windows":
                # WINDIR may be unset in stripped-down environments.
                paths.append(os.path.join(os.environ.get("WINDIR", r"C:\Windows"), "Fonts"))
            elif system == "Darwin":
                paths.append(os.path.join("/Library", "Fonts"))
            elif system == "Linux":
                paths.append(os.path.join("/usr", "share", "fonts"))
                paths.append(os.path.join("/usr", "local", "share", "fonts"))

        if fonts_config["user_fonts"]:
            if system == "Windows":
                # Per-user font installs land under %LOCALAPPDATA%.
                local_app_data = os.environ.get("LOCALAPPDATA")
                if local_app_data:
                    paths.append(os.path.join(local_app_data, "Microsoft", "Windows", "Fonts"))
            elif system == "Darwin":
                paths.append(os.path.join(user_home, "Library", "Fonts"))
            elif system == "Linux":
                paths.append(os.path.join(user_home, ".fonts"))
                # XDG location: $XDG_DATA_HOME/fonts, defaulting to ~/.local/share/fonts.
                data_home = os.environ.get("XDG_DATA_HOME") or os.path.join(user_home, ".local", "share")
                paths.append(os.path.join(data_home, "fonts"))

        return [path for path in paths if os.path.exists(path)]

    # --- config key comparison --------------------------------------------

    def __get_keys(self, json_obj, prefix=''):
        """Return the set of dotted leaf-key paths of *json_obj*, e.g. ``{"a.b", "c"}``.

        Nested dicts are descended into; an empty nested dict contributes no keys.
        """
        keys = set()
        for k, v in json_obj.items():
            if isinstance(v, dict):
                keys |= self.__get_keys(v, prefix + k + '.')
            else:
                keys.add(prefix + k)
        return keys

    def __check_json_keys(self, json1, json2):
        """Return True when both objects have exactly the same leaf-key paths."""
        return self.__get_keys(json1) == self.__get_keys(json2)

    # --- public setup steps -------------------------------------------------

    def setup_config(self):
        """Create config.json from the template, or resync it when its key set differs."""
        if not os.path.exists(self.__CONFIG_PATH):
            self.__log("Creating config.json")
            self.__create_config()
            return

        template = self.__template()
        config = self.__config()
        if not self.__check_json_keys(template, config):
            self.__log("Updating config.json")
            self.__update_config(template, config)

    def setup_timestamp(self):
        """Create timestamp.json if it does not exist yet."""
        if not os.path.exists(self.__TIMESTAMP_PATH):
            self.__log("Creating timestamp.json")
            self.__create_timestamp()

    def check_updates(self):
        """Check origin for new commits on the configured branch and optionally update.

        A check runs only when ``updates.update_frequency`` says one is due
        (compared with timestamp.json). The timestamp is refreshed after a
        successful check, and after the update when ``auto_update`` applied
        one. On any failure the timestamp is left unchanged, so the next
        launch retries.
        """
        updates_config = self.__config()["updates"]
        branch_name = updates_config["branch_name"]
        update_frequency = updates_config["update_frequency"].lower()
        valid_frequencies = ["always", "day", "week", "month", "never"]

        if update_frequency not in valid_frequencies:
            self.__error(f"Unknown update frequency - {update_frequency}, available: {valid_frequencies}")
            return
        if update_frequency == "never":
            return
        if update_frequency != "always":
            min_interval = {
                "day": self.__DAY_SECONDS,
                "week": self.__WEEK_SECONDS,
                "month": self.__MONTH_SECONDS,
            }[update_frequency]
            if time.time() - self.__timestamp()["timestamp"] < min_interval:
                return

        # is_dir() implies exists(), so exists() alone is the same check.
        if not self.__GIT_PATH.exists():
            self.__error("Root directory of Allor is not a git repository. Update canceled.")
            return

        try:
            import git
            from git import Repo
        except ImportError:
            self.__error("GitPython is not installed.")
            return

        try:
            repo = Repo(self.__ROOT_PATH, odbt=git.db.GitDB)
            current_commit = repo.head.commit.hexsha
            # Network failures or a missing "origin" remote end up here
            # instead of aborting the whole node-pack import.
            repo.remotes.origin.fetch()
        except (git.exc.GitError, AttributeError):
            self.__error("An error occurred while checking for updates.")
            return

        remote_branch = getattr(repo.remotes.origin.refs, branch_name, None)
        if remote_branch is None:
            self.__error(f"Branch {branch_name} was not found on the remote.")
            return
        latest_commit = remote_branch.commit.hexsha

        if current_commit == latest_commit:
            if updates_config["notify_if_no_new_updates"]:
                self.__notification("No new updates.")
        else:
            if updates_config["notify_if_has_new_updates"]:
                self.__notification("New updates are available.")
            if updates_config["auto_update"]:
                update_mode = updates_config["update_mode"].lower()
                if not self.__apply_update(repo, branch_name, update_mode):
                    return

        self.__update_timestamp()

    def __apply_update(self, repo, branch_name, update_mode):
        """Bring *repo* up to date with origin/*branch_name*; return True on success.

        "soft" runs ``git pull`` (may fail on local changes); "hard" runs
        ``git reset --hard origin/<branch>`` and discards local changes.
        The mode is validated before any branch switch happens.
        """
        from git import GitCommandError

        valid_modes = ["soft", "hard"]
        if update_mode not in valid_modes:
            self.__error(f"Unknown update mode - {update_mode}, available: {valid_modes}")
            return False

        try:
            on_target_branch = repo.active_branch.name == branch_name
        except TypeError:
            # GitPython raises TypeError for a detached HEAD (no active branch).
            on_target_branch = False

        if not on_target_branch:
            try:
                repo.git.checkout(branch_name)
            except GitCommandError:
                self.__error(f"An error occurred while switching to the branch {branch_name}.")
                return False

        if update_mode == "soft":
            try:
                repo.git.pull()
            except GitCommandError:
                self.__error("An error occurred during the update. "
                             "It is recommended to use \"hard\" update mode. "
                             "But be careful, it erases all personal changes from Allor repository.")
                return False
        else:
            try:
                repo.git.reset('--hard', 'origin/' + branch_name)
            except GitCommandError:
                self.__error("An error occurred during the update.")
                return False

        self.__notification("Update complete.")
        return True

    def setup_rembg(self):
        """Set U2NET_HOME to the "onnx" folder inside folder_paths.models_dir."""
        os.environ["U2NET_HOME"] = os.path.join(folder_paths.models_dir, "onnx")

    def setup_paths(self):
        """Register the "onnx" model folder and the font folders with folder_paths."""
        fonts_folder_path = self.__get_fonts_folder_path()
        folder_paths.folder_names_and_paths["onnx"] = ([os.path.join(folder_paths.models_dir, "onnx")], {".onnx"})
        folder_paths.folder_names_and_paths["fonts"] = (fonts_folder_path, {".otf", ".ttf"})

    def setup_override(self):
        """Remove built-in node classes from nodes.NODE_CLASS_MAPPINGS per config["override"].

        ``nodes.NODE_CLASS_MAPPINGS`` is rebound to a new, filtered dict for
        each enabled group. The total number of removed entries is logged.
        """
        override_config = self.__config()["override"]
        override_nodes_len = 0

        def override(keep):
            """Keep only mappings for which keep((name, cls)) is true; return how many were removed."""
            start_len = len(nodes.NODE_CLASS_MAPPINGS)
            nodes.NODE_CLASS_MAPPINGS = dict(filter(keep, nodes.NODE_CLASS_MAPPINGS.items()))
            return start_len - len(nodes.NODE_CLASS_MAPPINGS)

        def category(item):
            # Tolerate node classes that do not declare a CATEGORY.
            return getattr(item[1], "CATEGORY", "")

        if override_config["postprocessing"]:
            override_nodes_len += override(lambda item: not category(item).startswith("image/postprocessing"))

        if override_config["transform"]:
            replaced = {"ImageScale", "ImageScaleBy", "ImageInvert"}
            override_nodes_len += override(lambda item: item[0] not in replaced)

        if override_config["debug"]:
            # Recategorised before filtering, so the "_for_testing" filter
            # below cannot remove the tiled VAE nodes.
            nodes.VAEDecodeTiled.CATEGORY = "latent"
            nodes.VAEEncodeTiled.CATEGORY = "latent"
            override_nodes_len += override(lambda item: not category(item).startswith("_for_testing"))

        self.__log(str(override_nodes_len) + " nodes were overridden.")

    def get_modules(self):
        """Import every Allor module enabled in config["modules"] and merge their NODE_CLASS_MAPPINGS.

        Returns the merged mapping; logs the number of enabled config entries
        and the number of node classes loaded.
        """
        from importlib import import_module

        modules_config = self.__config()["modules"]
        modules = {}

        for name in self.__MODULE_NAMES:
            if modules_config.get(name):
                # Same resolution as "from .modules import <name>".
                module = import_module(".modules." + name, __package__)
                modules.update(module.NODE_CLASS_MAPPINGS)

        modules_len = sum(1 for enabled in modules_config.values() if enabled)
        nodes_len = len(modules)
        self.__log(str(modules_len) + " modules were enabled.")
        self.__log(str(nodes_len) + " nodes were loaded.")
        return modules