from torch import Tensor
import torch
from PIL import Image
import numpy as np
import os
import sys
import io
from contextlib import contextmanager
import json

import folder_paths

# Add this file's directory to sys.path so sibling modules can be imported.
my_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(my_dir)

# Add the ComfyUI root directory (two levels up) to sys.path.
comfy_dir = os.path.abspath(os.path.join(my_dir, '..', '..'))
sys.path.append(comfy_dir)

import comfy.sd
import comfy.utils
import latent_preview
from comfy.cli_args import args

# Cached model objects, keyed by model type. Each entry ends with the list of
# node ids currently using it.
loaded_objects = {
    "ckpt": [],
    "refn": [],
    "vae": [],
    "lora": []
}

# Most recent KSampler outputs, keyed by output type. Each entry ends with the
# id of the node that produced it.
last_helds = {
    "latent": [],
    "image": [],
    "cnet_img": []
}
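
# Entry shapes, as produced by the loader/store helpers below:
#   loaded_objects["ckpt"/"refn"]: (ckpt_name, model, clip, vae, [ids])
#   loaded_objects["vae"]:         (vae_name, vae, [ids])
#   loaded_objects["lora"]:        (lora_params, ckpt_name, model, clip, [ids])
#   last_helds[key]:               (value, id) or (value, parameters_list, id)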

def load_ksampler_results(key: str, my_unique_id, parameters_list=None):
    """Return the cached result stored for this node id, or None.

    When parameters_list is given, the cached entry must also carry a matching
    parameters_list for it to count as a hit.
    """
    global last_helds
    for data in last_helds[key]:
        id_ = data[-1]
        if id_ == my_unique_id:
            if parameters_list is not None:
                # Entry must be (value, parameters_list, id) with matching parameters.
                if len(data) >= 3 and data[1] == parameters_list:
                    return data[0]
            else:
                return data[0]
    return None

def store_ksampler_results(key: str, my_unique_id, value, parameters_list=None):
    global last_helds

    # Update the existing entry for this node id, if there is one.
    for i, data in enumerate(last_helds[key]):
        id_ = data[-1]
        if id_ == my_unique_id:
            updated_data = (value, parameters_list, id_) if parameters_list is not None else (value, id_)
            last_helds[key][i] = updated_data
            return True

    # Otherwise append a new entry.
    if parameters_list is not None:
        last_helds[key].append((value, parameters_list, my_unique_id))
    else:
        last_helds[key].append((value, my_unique_id))
    return True
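
# Round-trip sketch (hypothetical id and values), showing how the two helpers
# pair up: a load only hits when the stored parameters match.
#   store_ksampler_results("latent", "17", samples, parameters_list=[seed, steps])
#   cached = load_ksampler_results("latent", "17", parameters_list=[seed, steps])
#   # cached is `samples` on a match, otherwise None.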

def tensor2pil(image: torch.Tensor) -> Image.Image:
    return Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))


def pil2tensor(image: Image.Image) -> torch.Tensor:
    return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)
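
# These converters assume ComfyUI's IMAGE convention: float32 tensors in
# [0, 1] with shape [batch, height, width, channels], e.g.:
#   pil_img = tensor2pil(image_tensor[0])   # one image out of a batch
#   batch_of_one = pil2tensor(pil_img)      # back to a [1, H, W, C] tensor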

def quick_resize(source_tensor: torch.Tensor, target_shape: tuple) -> torch.Tensor:
    # target_shape follows the [batch, height, width, channels] layout.
    resized_images = []
    for img in source_tensor:
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is its replacement.
        resized_pil = tensor2pil(img.squeeze(0)).resize((target_shape[2], target_shape[1]), Image.LANCZOS)
        resized_images.append(pil2tensor(resized_pil).squeeze(0))
    return torch.stack(resized_images, dim=0)

import hashlib

def tensor_to_hash(tensor):
    byte_repr = tensor.cpu().numpy().tobytes()
    return hashlib.sha256(byte_repr).hexdigest()
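
# Hashing the raw bytes gives a cheap change-detection key for tensors, e.g.
# for deciding whether an input image changed between runs:
#   if tensor_to_hash(new_image) != tensor_to_hash(old_image): ...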

# ANSI escape codes for colored console output.
MESSAGE_COLOR = "\033[36m"
XYPLOT_COLOR = "\033[35m"
SUCCESS_COLOR = "\033[92m"
WARNING_COLOR = "\033[93m"
ERROR_COLOR = "\033[91m"
INFO_COLOR = "\033[90m"

def format_message(text, color_code):
    RESET_COLOR = "\033[0m"
    return f"{color_code}{text}{RESET_COLOR}"

def message(text):
    return format_message(text, MESSAGE_COLOR)

def warning(text):
    return format_message(text, WARNING_COLOR)

def error(text):
    return format_message(text, ERROR_COLOR)

def success(text):
    return format_message(text, SUCCESS_COLOR)

def xyplot_message(text):
    return format_message(text, XYPLOT_COLOR)

def info(text):
    return format_message(text, INFO_COLOR)
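
# Usage: wrap any log string before printing, e.g.
#   print(warning("Model cache is full; evicting the oldest entry."))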

def extract_node_info(prompt, id, indirect_key=None):
    """Return (class_type, node_id) for a node in the prompt.

    With indirect_key, follow that input's link to the upstream node and
    report its class_type instead.
    """
    id = str(id)
    node_id = None

    if indirect_key:
        if id in prompt and 'inputs' in prompt[id] and indirect_key in prompt[id]['inputs']:
            # A linked input is stored as [upstream_node_id, output_index].
            indirect_id = prompt[id]['inputs'][indirect_key][0]

            if indirect_id in prompt:
                node_id = indirect_id
                return prompt[indirect_id].get('class_type', None), node_id

        return None, None

    return prompt.get(id, {}).get('class_type', None), node_id

def extract_node_value(prompt, id, key):
    return prompt.get(str(id), {}).get('inputs', {}).get(key, None)

def print_loaded_objects_entries(id=None, prompt=None, show_id=False):
    print("-" * 40)
    if id is not None:
        id = str(id)
    if prompt is not None and id is not None:
        node_name, _ = extract_node_info(prompt, id)
        if show_id:
            print(f"\033[36m{node_name} Models Cache: (node_id:{int(id)})\033[0m")
        else:
            print(f"\033[36m{node_name} Models Cache:\033[0m")
    elif id is None:
        print("\033[36mGlobal Models Cache:\033[0m")
    else:
        print(f"\033[36mModels Cache: \nnode_id:{int(id)}\033[0m")

    entries_found = False
    for key in ["ckpt", "refn", "vae", "lora"]:
        entries_with_id = loaded_objects[key] if id is None else [entry for entry in loaded_objects[key] if id in entry[-1]]
        if not entries_with_id:
            continue
        entries_found = True
        print(f"{key.capitalize()}:")
        for i, entry in enumerate(entries_with_id, 1):
            if key == "lora":
                base_ckpt_name = os.path.splitext(os.path.basename(entry[1]))[0]
                if id is None:
                    associated_ids = ', '.join(map(str, entry[-1]))
                    print(f"  [{i}] base_ckpt: {base_ckpt_name} (ids: {associated_ids})")
                else:
                    print(f"  [{i}] base_ckpt: {base_ckpt_name}")
                # List each applied LoRA with its model/clip strengths.
                for name, strength_model, strength_clip in entry[0]:
                    lora_model_info = f"{os.path.splitext(os.path.basename(name))[0]}({round(strength_model, 2)},{round(strength_clip, 2)})"
                    print(f"      lora(mod,clip): {lora_model_info}")
            else:
                name_without_ext = os.path.splitext(os.path.basename(entry[0]))[0]
                if id is None:
                    associated_ids = ', '.join(map(str, entry[-1]))
                    print(f"  [{i}] {name_without_ext} (ids: {associated_ids})")
                else:
                    print(f"  [{i}] {name_without_ext}")
    if not entries_found:
        print("-")

def globals_cleanup(prompt):
    """Drop cache entries whose node ids no longer exist in the current prompt."""
    global loaded_objects
    global last_helds

    # Keep only last_helds entries whose node id is still present.
    for key in list(last_helds.keys()):
        last_helds[key] = [
            (*values, id_)
            for *values, id_ in last_helds[key]
            if str(id_) in prompt.keys()
        ]

    # Prune stale node ids from loaded_objects; drop entries left with no ids.
    for key in list(loaded_objects.keys()):
        for i, tup in enumerate(list(loaded_objects[key])):
            id_array = [id for id in tup[-1] if str(id) in prompt.keys()]
            if len(id_array) != len(tup[-1]):
                if id_array:
                    loaded_objects[key][i] = tup[:-1] + (id_array,)
                else:
                    loaded_objects[key].remove(tup)

def load_checkpoint(ckpt_name, id, output_vae=True, cache=None, cache_overwrite=True, ckpt_type="ckpt"):
    global loaded_objects

    # Copy mutable inputs so cached entries are not mutated by the caller.
    ckpt_name = ckpt_name.copy() if isinstance(ckpt_name, (list, dict, set)) else ckpt_name

    if ckpt_type not in ["ckpt", "refn"]:
        raise ValueError(f"Invalid checkpoint type: {ckpt_type}")

    # Cache hit: reuse the loaded model and refresh the id bookkeeping.
    for entry in loaded_objects[ckpt_type]:
        if entry[0] == ckpt_name:
            _, model, clip, vae, ids = entry
            cache_full = cache and len([entry for entry in loaded_objects[ckpt_type] if id in entry[-1]]) >= cache

            if cache_full:
                clear_cache(id, cache, ckpt_type)
            elif id not in ids:
                ids.append(id)

            return model, clip, vae

    # Cache miss: resolve the path and load the checkpoint from disk.
    if os.path.isabs(ckpt_name):
        ckpt_path = ckpt_name
    else:
        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
    with suppress_output():
        out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))

    model = out[0]
    clip = out[1]
    vae = out[2] if output_vae else None

    if cache:
        cache_list = [entry for entry in loaded_objects[ckpt_type] if id in entry[-1]]
        if len(cache_list) < cache:
            loaded_objects[ckpt_type].append((ckpt_name, model, clip, vae, [id]))
        else:
            clear_cache(id, cache, ckpt_type)
            if cache_overwrite:
                # Detach this id from its current entry and drop the entry if
                # no other node still uses it.
                for e in loaded_objects[ckpt_type]:
                    if id in e[-1]:
                        e[-1].remove(id)
                        if not e[-1]:
                            loaded_objects[ckpt_type].remove(e)
                        break
                loaded_objects[ckpt_type].append((ckpt_name, model, clip, vae, [id]))

    return model, clip, vae

def get_bvae_by_ckpt_name(ckpt_name):
    # Return the VAE baked into a cached checkpoint, if that checkpoint is loaded.
    for ckpt in loaded_objects["ckpt"]:
        if ckpt[0] == ckpt_name:
            return ckpt[3]
    return None

def load_vae(vae_name, id, cache=None, cache_overwrite=False):
    global loaded_objects

    # Copy mutable inputs so cached entries are not mutated by the caller.
    vae_name = vae_name.copy() if isinstance(vae_name, (list, dict, set)) else vae_name

    # Cache hit: reuse the loaded VAE and refresh the id bookkeeping.
    for entry in loaded_objects["vae"]:
        if entry[0] == vae_name:
            vae, ids = entry[1], entry[2]
            if id not in ids:
                if cache and len([entry for entry in loaded_objects["vae"] if id in entry[-1]]) >= cache:
                    return vae
                ids.append(id)
                if cache:
                    clear_cache(id, cache, "vae")
            return vae

    # Cache miss: resolve the path and load the VAE from disk.
    if os.path.isabs(vae_name):
        vae_path = vae_name
    else:
        vae_path = folder_paths.get_full_path("vae", vae_name)

    sd = comfy.utils.load_torch_file(vae_path)
    vae = comfy.sd.VAE(sd=sd)

    if cache:
        if len([entry for entry in loaded_objects["vae"] if id in entry[-1]]) < cache:
            loaded_objects["vae"].append((vae_name, vae, [id]))
        else:
            clear_cache(id, cache, "vae")
            if cache_overwrite:
                # Detach this id from its current entry and drop the entry if
                # no other node still uses it.
                for e in loaded_objects["vae"]:
                    if id in e[-1]:
                        e[-1].remove(id)
                        if not e[-1]:
                            loaded_objects["vae"].remove(e)
                        break
                loaded_objects["vae"].append((vae_name, vae, [id]))

    return vae

def load_lora(lora_params, ckpt_name, id, cache=None, ckpt_cache=None, cache_overwrite=False):
    global loaded_objects

    # Copy mutable inputs so cached entries are not mutated by the caller.
    lora_params = lora_params.copy() if isinstance(lora_params, (list, dict, set)) else lora_params
    ckpt_name = ckpt_name.copy() if isinstance(ckpt_name, (list, dict, set)) else ckpt_name

    # Cache hit: reuse the patched model/clip and refresh the id bookkeeping.
    for entry in loaded_objects["lora"]:
        if set(entry[0]) == set(lora_params) and entry[1] == ckpt_name:
            _, _, lora_model, lora_clip, ids = entry
            cache_full = cache and len([entry for entry in loaded_objects["lora"] if id in entry[-1]]) >= cache

            if cache_full:
                clear_cache(id, cache, "lora")
            elif id not in ids:
                ids.append(id)

            # Keep the base checkpoint's id bookkeeping in sync as well.
            for ckpt_entry in loaded_objects["ckpt"]:
                if ckpt_entry[0] == ckpt_name:
                    _, _, _, _, ckpt_ids = ckpt_entry
                    ckpt_cache_full = ckpt_cache and len(
                        [ckpt_entry for ckpt_entry in loaded_objects["ckpt"] if id in ckpt_entry[-1]]) >= ckpt_cache

                    if ckpt_cache_full:
                        clear_cache(id, ckpt_cache, "ckpt")
                    elif id not in ckpt_ids:
                        ckpt_ids.append(id)

            return lora_model, lora_clip

    def recursive_load_lora(lora_params, ckpt, clip, id, ckpt_cache, cache_overwrite, folder_paths):
        # Apply the first LoRA in the list, then recurse on the remainder.
        if len(lora_params) == 0:
            return ckpt, clip

        lora_name, strength_model, strength_clip = lora_params[0]
        if os.path.isabs(lora_name):
            lora_path = lora_name
        else:
            lora_path = folder_paths.get_full_path("loras", lora_name)

        lora_model, lora_clip = comfy.sd.load_lora_for_models(ckpt, clip, comfy.utils.load_torch_file(lora_path), strength_model, strength_clip)

        return recursive_load_lora(lora_params[1:], lora_model, lora_clip, id, ckpt_cache, cache_overwrite, folder_paths)

    # Cache miss: load the base checkpoint, then apply the LoRA stack on top.
    ckpt, clip, _ = load_checkpoint(ckpt_name, id, cache=ckpt_cache)

    lora_model, lora_clip = recursive_load_lora(lora_params, ckpt, clip, id, ckpt_cache, cache_overwrite, folder_paths)

    if cache:
        if len([entry for entry in loaded_objects["lora"] if id in entry[-1]]) < cache:
            loaded_objects["lora"].append((lora_params, ckpt_name, lora_model, lora_clip, [id]))
        else:
            clear_cache(id, cache, "lora")
            if cache_overwrite:
                # Detach this id from its current entry and drop the entry if
                # no other node still uses it.
                for e in loaded_objects["lora"]:
                    if id in e[-1]:
                        e[-1].remove(id)
                        if not e[-1]:
                            loaded_objects["lora"].remove(e)
                        break
                loaded_objects["lora"].append((lora_params, ckpt_name, lora_model, lora_clip, [id]))

    return lora_model, lora_clip

def clear_cache(id, cache, dict_name):
    """
    Clear the cache for a specific id in a specific dictionary.
    If the cache limit is reached for a specific id, delete the id from the oldest entry.
    If the id array of the entry becomes empty, delete the entry.
    """
    id_associated_entries = [entry for entry in loaded_objects[dict_name] if id in entry[-1]]
    while len(id_associated_entries) > cache:
        # Evict from the oldest entry first.
        older_entry = id_associated_entries[0]
        older_entry[-1].remove(id)

        if not older_entry[-1]:
            loaded_objects[dict_name].remove(older_entry)

        id_associated_entries = [entry for entry in loaded_objects[dict_name] if id in entry[-1]]

def clear_cache_by_exception(node_id, vae_dict=None, ckpt_dict=None, lora_dict=None, refn_dict=None):
    """Detach node_id from every cached entry NOT named in the given keep-lists."""
    global loaded_objects

    dict_mapping = {
        "vae_dict": "vae",
        "ckpt_dict": "ckpt",
        "lora_dict": "lora",
        "refn_dict": "refn"
    }

    for arg_name, arg_val in {"vae_dict": vae_dict, "ckpt_dict": ckpt_dict, "lora_dict": lora_dict, "refn_dict": refn_dict}.items():
        if arg_val is None:
            continue

        dict_name = dict_mapping[arg_name]

        for tuple_item in loaded_objects[dict_name].copy():
            if arg_name == "lora_dict":
                # Keep the entry if any (lora_params, ckpt_name) pair matches it.
                for lora_params, ckpt_name in arg_val:
                    if set(lora_params) == set(tuple_item[0]) and ckpt_name == tuple_item[1]:
                        break
                else:
                    # No match: detach node_id and drop the entry if unused.
                    if node_id in tuple_item[-1]:
                        tuple_item[-1].remove(node_id)
                        if not tuple_item[-1]:
                            loaded_objects[dict_name].remove(tuple_item)
                    continue
            elif tuple_item[0] not in arg_val:
                if node_id in tuple_item[-1]:
                    tuple_item[-1].remove(node_id)
                    if not tuple_item[-1]:
                        loaded_objects[dict_name].remove(tuple_item)

def get_cache_numbers(node_name):
    # Load the per-node cache limits from node_settings.json next to this file.
    my_dir = os.path.dirname(os.path.abspath(__file__))
    settings_file = os.path.join(my_dir, 'node_settings.json')

    with open(settings_file, 'r') as file:
        node_settings = json.load(file)

    model_cache_settings = node_settings.get(node_name, {}).get('model_cache', {})
    vae_cache = int(model_cache_settings.get('vae', 1))
    ckpt_cache = int(model_cache_settings.get('ckpt', 1))
    lora_cache = int(model_cache_settings.get('lora', 1))
    refn_cache = int(model_cache_settings.get('refn', 1))
    return vae_cache, ckpt_cache, lora_cache, refn_cache
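
# Illustrative node_settings.json shape assumed by the lookup above (keys and
# defaults inferred from the .get() chain, not a published schema):
#   { "<node name>": { "model_cache": { "vae": 1, "ckpt": 1, "lora": 1, "refn": 1 } } }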

def print_last_helds(id=None):
    print("\n" + "-" * 40)
    if id is not None:
        id = str(id)
        print(f"Node-specific Last Helds (node_id:{int(id)})")
    else:
        print("Global Last Helds:")
    # Iterate over the keys actually present in last_helds.
    for key in ["latent", "image", "cnet_img"]:
        entries_with_id = last_helds[key] if id is None else [entry for entry in last_helds[key] if id == entry[-1]]
        if not entries_with_id:
            continue
        print(f"{key.capitalize()}:")
        for i, entry in enumerate(entries_with_id, 1):
            if isinstance(entry[0], bool):
                output = entry[0]
            else:
                output = len(entry[0])
            if id is None:
                print(f"  [{i}] Output: {output} (id: {entry[-1]})")
            else:
                print(f"  [{i}] Output: {output}")
    print("-" * 40)
    print("\n")

@contextmanager
def suppress_output():
    """Temporarily redirect stdout/stderr to silence noisy library prints."""
    original_stdout = sys.stdout
    original_stderr = sys.stderr

    sys.stdout = io.StringIO()
    sys.stderr = io.StringIO()

    try:
        yield
    finally:
        sys.stdout = original_stdout
        sys.stderr = original_stderr
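
# Usage, as in load_checkpoint() above:
#   with suppress_output():
#       out = comfy.sd.load_checkpoint_guess_config(...)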

def set_preview_method(method):
    # Map a method name (or its enum repr) onto ComfyUI's global preview setting.
    if method == 'auto' or method == 'LatentPreviewMethod.Auto':
        args.preview_method = latent_preview.LatentPreviewMethod.Auto
    elif method == 'latent2rgb' or method == 'LatentPreviewMethod.Latent2RGB':
        args.preview_method = latent_preview.LatentPreviewMethod.Latent2RGB
    elif method == 'taesd' or method == 'LatentPreviewMethod.TAESD':
        args.preview_method = latent_preview.LatentPreviewMethod.TAESD
    else:
        args.preview_method = latent_preview.LatentPreviewMethod.NoPreviews


def global_preview_method():
    return args.preview_method
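
# Typical pattern (hypothetical values): stash the global setting, switch it
# for a sampling run, then restore it afterwards.
#   saved = global_preview_method()
#   set_preview_method('taesd')
#   ...
#   args.preview_method = saved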

import shutil

# Remove any web extension files this pack previously copied into ComfyUI's
# web/extensions folder, so stale scripts are not served.
destination_dir = os.path.join(comfy_dir, 'web', 'extensions', 'efficiency-nodes-comfyui')

if os.path.exists(destination_dir):
    shutil.rmtree(destination_dir)

class XY_Capsule:
    """Interface for XY Plot value capsules; concrete capsules override these hooks."""

    def pre_define_model(self, model, clip, vae):
        return model, clip, vae

    def set_result(self, image, latent):
        pass

    def get_result(self, model, clip, vae):
        return None

    def set_x_capsule(self, capsule):
        return None

    def getLabel(self):
        return "Unknown"