|
import json |
|
import os |
|
|
|
import folder_paths |
|
import nodes |
|
from server import PromptServer |
|
|
|
from .libs.utils import TaggedCache, any_typ |
|
|
|
# cache_settings.json is looked up two directory levels above this file.
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
settings_file = os.path.join(root_dir, 'cache_settings.json')


def _load_cache_settings(path):
    """Read the TaggedCache settings from the JSON file at ``path``.

    Loading is best-effort: a missing, unreadable or malformed file is
    reported on stdout and yields ``{}`` so the node pack still imports.

    Returns:
        The decoded JSON object (a dict), or ``{}`` on any failure.
    """
    try:
        with open(path, encoding='utf-8') as f:
            settings = json.load(f)
    except (OSError, ValueError) as e:
        # OSError: missing/unreadable file; ValueError: invalid JSON or encoding.
        print(f"[Inspire Pack] Failed to load cache settings from '{path}': {e}")
        return {}

    if not isinstance(settings, dict):
        # A top-level list/number/string would otherwise be handed straight to
        # TaggedCache(...) below and could break the module import.
        print(f"[Inspire Pack] Ignoring '{path}': expected a JSON object, got {type(settings).__name__}.")
        return {}

    return settings


cache_settings = _load_cache_settings(settings_file)

# Shared backend store. Every writer in this module stores values shaped
# (tag, (is_list, data)). The object is replaced wholesale by
# RemoveBackendData('*') and ShowCachedInfo.set_cache_settings.
cache = TaggedCache(cache_settings)

# Per-key change counter, maintained by update_cache() and read by
# cache_weak_hash() to build IS_CHANGED tokens.
cache_count = {}
|
|
|
|
|
def update_cache(k, tag, v):
    """Store ``v`` under key ``k`` with label ``tag`` and advance its change counter.

    The counter for a key is 0 after its first store and grows by one on
    every later store of the same key.
    """
    cache[k] = (tag, v)

    previous = cache_count.get(k)
    cache_count[k] = 0 if previous is None else previous + 1
|
|
|
|
|
def cache_weak_hash(k):
    """Return the change token ``(k, counter)`` for key ``k``.

    Keys that were never stored report a counter of 0.
    """
    return k, cache_count.get(k, 0)
|
|
|
|
|
class CacheBackendData:
    """Output node that stores one value in the backend cache under a string key.

    The value is passed through unchanged on the ``data opt`` output.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "key": ("STRING", {"multiline": False, "placeholder": "Input data key (e.g. 'model a', 'chunli lora', 'girl latent 3', ...)"}),
                "tag": ("STRING", {"multiline": False, "placeholder": "Tag: short description"}),
                "data": (any_typ,),
            }
        }

    RETURN_TYPES = (any_typ,)
    RETURN_NAMES = ("data opt",)

    FUNCTION = "doit"

    CATEGORY = "InspirePack/Backend"

    OUTPUT_NODE = True

    def doit(self, key, tag, data):
        """Cache ``data`` as a single (non-list) entry and pass it through.

        Args:
            key: Cache key. ``'*'`` is reserved and is rejected (not cached).
            tag: Short label stored alongside the value.
            data: Any value.

        Returns:
            ``(data,)``, whether or not it was cached.
        """
        if key == '*':
            # '*' is RemoveBackendData's "clear all" key: an entry stored under
            # it could never be removed individually, so refuse to store it.
            print("[Inspire Pack] CacheBackendData: '*' is reserved key. Cannot use that key")
            return (data,)

        # is_list=False: RetrieveBackendData wraps the value in a list on read.
        update_cache(key, tag, (False, data))
        return (data,)
|
|
|
|
|
class CacheBackendDataNumberKey:
    """Output node that stores one value in the backend cache under an integer key.

    The value is passed through unchanged on the ``data opt`` output.
    """

    RETURN_TYPES = (any_typ,)
    RETURN_NAMES = ("data opt",)
    FUNCTION = "doit"
    CATEGORY = "InspirePack/Backend"
    OUTPUT_NODE = True

    @classmethod
    def INPUT_TYPES(cls):
        int_key = ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff})
        tag_field = ("STRING", {"multiline": False, "placeholder": "Tag: short description"})
        return {"required": {"key": int_key, "tag": tag_field, "data": (any_typ,)}}

    def doit(self, key, tag, data):
        """Cache ``data`` as a single (non-list) entry under ``key`` and pass it through."""
        single_entry = (False, data)
        update_cache(key, tag, single_entry)
        return (data,)
|
|
|
|
|
class CacheBackendDataList:
    """List-input output node that stores a whole list in the backend cache under a string key.

    Because ``INPUT_IS_LIST`` is set, every input arrives as a list. Only the
    first ``key``/``tag`` are used; the full ``data`` list is stored and
    passed through.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "key": ("STRING", {"multiline": False, "placeholder": "Input data key (e.g. 'model a', 'chunli lora', 'girl latent 3', ...)"}),
                "tag": ("STRING", {"multiline": False, "placeholder": "Tag: short description"}),
                "data": (any_typ,),
            }
        }

    INPUT_IS_LIST = True

    RETURN_TYPES = (any_typ,)
    RETURN_NAMES = ("data opt",)
    OUTPUT_IS_LIST = (True,)

    FUNCTION = "doit"

    CATEGORY = "InspirePack/Backend"

    OUTPUT_NODE = True

    def doit(self, key, tag, data):
        """Cache the ``data`` list under ``key[0]`` and pass it through.

        Args:
            key: List of keys; only ``key[0]`` is used. ``'*'`` is reserved.
            tag: List of tags; only ``tag[0]`` is used.
            data: List of values, stored as-is.

        Returns:
            ``(data,)``, whether or not it was cached.
        """
        # key is a list here (INPUT_IS_LIST), so compare its first element;
        # comparing the list itself to '*' would never match.
        if key[0] == '*':
            print("[Inspire Pack] CacheBackendDataList: '*' is reserved key. Cannot use that key")
            return (data,)

        # is_list=True: RetrieveBackendData returns the stored list unchanged.
        update_cache(key[0], tag[0], (True, data))
        return (data,)
|
|
|
|
|
class CacheBackendDataNumberKeyList:
    """List-input output node that stores a whole list in the backend cache under an integer key.

    Inputs arrive as lists (``INPUT_IS_LIST``). The first key and tag are
    used, and the entire ``data`` list is cached and passed through.
    """

    INPUT_IS_LIST = True
    OUTPUT_IS_LIST = (True,)
    RETURN_TYPES = (any_typ,)
    RETURN_NAMES = ("data opt",)
    FUNCTION = "doit"
    CATEGORY = "InspirePack/Backend"
    OUTPUT_NODE = True

    @classmethod
    def INPUT_TYPES(cls):
        int_key = ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff})
        tag_field = ("STRING", {"multiline": False, "placeholder": "Tag: short description"})
        return {"required": {"key": int_key, "tag": tag_field, "data": (any_typ,)}}

    def doit(self, key, tag, data):
        """Cache the ``data`` list under ``key[0]`` and pass it through."""
        first_key, first_tag = key[0], tag[0]
        update_cache(first_key, first_tag, (True, data))
        return (data,)
|
|
|
|
|
class RetrieveBackendData:
    """Output the value cached under a string key, always as a list output."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "key": ("STRING", {"multiline": False, "placeholder": "Input data key (e.g. 'model a', 'chunli lora', 'girl latent 3', ...)"}),
            }
        }

    RETURN_TYPES = (any_typ,)
    RETURN_NAMES = ("data",)
    OUTPUT_IS_LIST = (True,)

    FUNCTION = "doit"

    CATEGORY = "InspirePack/Backend"

    @staticmethod
    def doit(key):
        """Look up ``key`` in the backend cache.

        Returns:
            A one-element tuple holding a list. For entries cached as lists this
            is the stored list itself; for single entries it is ``[data]``. For
            unregistered keys it is ``[None]``.
        """
        v = cache.get(key)

        if v is None:
            print(f"[RetrieveBackendData] '{key}' is unregistered key.")
            # OUTPUT_IS_LIST is (True,), so the output slot must be a list.
            # ComfyUI iterates it when merging list outputs, and a bare None
            # would raise there.
            return ([None],)

        # Cached values are (tag, (is_list, data)).
        is_list, data = v[1]
        return (data,) if is_list else ([data],)

    @staticmethod
    def IS_CHANGED(key):
        return cache_weak_hash(key)
|
|
|
|
|
class RetrieveBackendDataNumberKey(RetrieveBackendData):
    """RetrieveBackendData variant whose key input is an integer instead of a string."""

    @classmethod
    def INPUT_TYPES(cls):
        int_key_spec = {"default": 0, "min": 0, "max": 0xffffffffffffffff}
        return {"required": {"key": ("INT", int_key_spec)}}
|
|
|
|
|
class RemoveBackendData:
    """Output node that deletes one cached entry, or clears the whole cache with key ``'*'``.

    The optional ``signal_opt`` input is passed through so the node can be
    chained for execution ordering.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "key": ("STRING", {"multiline": False, "placeholder": "Input data key ('*' = clear all)"}),
            },
            "optional": {
                "signal_opt": (any_typ,),
            }
        }

    RETURN_TYPES = (any_typ,)
    RETURN_NAMES = ("signal",)

    FUNCTION = "doit"

    CATEGORY = "InspirePack/Backend"

    OUTPUT_NODE = True

    @staticmethod
    def doit(key, signal_opt=None):
        """Remove ``key`` from the cache (or everything for ``'*'``).

        Returns:
            ``(signal_opt,)`` unchanged.
        """
        global cache

        if key == '*':
            # Advance every key's change counter. Nodes whose IS_CHANGED uses
            # cache_weak_hash (e.g. RetrieveBackendData) then see a new token and
            # re-execute, instead of reusing outputs built from cleared entries.
            for k in cache_count:
                cache_count[k] += 1
            # NOTE(review): this rebuilds from the file-loaded cache_settings, so
            # settings applied through ShowCachedInfo.set_cache_settings are
            # reset here. Confirm that is intended.
            cache = TaggedCache(cache_settings)
        elif key in cache:
            del cache[key]
            # Same reason as above: make the removal visible to IS_CHANGED.
            cache_count[key] = cache_count.get(key, 0) + 1
        else:
            print(f"[Inspire Pack] RemoveBackendData: invalid data key {key}")

        return (signal_opt,)
|
|
|
|
|
class RemoveBackendDataNumberKey(RemoveBackendData):
    """RemoveBackendData variant with an integer key (no ``'*'`` clear-all wildcard)."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "key": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
            },
            "optional": {
                "signal_opt": (any_typ,),
            }
        }

    @staticmethod
    def doit(key, signal_opt=None):
        """Remove the entry cached under integer ``key``.

        Returns:
            ``(signal_opt,)`` unchanged.
        """
        if key in cache:
            del cache[key]
            # Advance the key's change counter so nodes whose IS_CHANGED uses
            # cache_weak_hash (e.g. RetrieveBackendDataNumberKey) re-execute
            # instead of reusing output built from the removed entry.
            cache_count[key] = cache_count.get(key, 0) + 1
        else:
            print(f"[Inspire Pack] RemoveBackendDataNumberKey: invalid data key {key}")

        return (signal_opt,)
|
|
|
|
|
class ShowCachedInfo:
    """Output node that reports the cache contents and TaggedCache settings to the UI.

    The report is pushed to this node's ``cache_info`` widget through
    ``PromptServer``. ``set_cache_settings`` parses an edited copy of that
    report back into new per-tag settings.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "cache_info": ("STRING", {"multiline": True, "default": ""}),
                "key": ("STRING", {"multiline": False, "default": ""}),
            },
            "hidden": {"unique_id": "UNIQUE_ID"},
        }

    RETURN_TYPES = ()

    FUNCTION = "doit"

    CATEGORY = "InspirePack/Backend"

    OUTPUT_NODE = True

    SETTINGS_HEADER = "---- [TagCache Settings] ----\n"

    @staticmethod
    def get_data():
        """Build the text report: string-key entries, number-key entries, then tag settings."""
        text1 = "---- [String Key Caches] ----\n"
        text2 = "---- [Number Key Caches] ----\n"
        for k, v in cache.items():
            # Cached values are (tag, (is_list, data)); only the tag is shown.
            tag = 'N/A(tag)' if v[0] == '' else v[0]
            if isinstance(k, str):
                text1 += f'{k}: {tag}\n'
            else:
                text2 += f'{k}: {tag}\n'

        text3 = ShowCachedInfo.SETTINGS_HEADER
        for k, v in cache._tag_settings.items():
            text3 += f'{k}: {v}\n'

        # Tags that have data but no explicit setting report their effective maxsize.
        for k, v in cache._data.items():
            if k not in cache._tag_settings:
                text3 += f'{k}: {v.maxsize}\n'

        return f'{text1}\n{text2}\n{text3}'

    @staticmethod
    def set_cache_settings(data: str):
        """Apply the ``tag: size`` lines after the settings header in ``data``.

        Blank lines are ignored, and each line is split on its last ``:`` so
        tags may themselves contain colons. When the parsed settings differ
        from the current ones, the cache is rebuilt with them and all
        existing entries are copied over.

        Raises:
            ValueError: if a non-blank line is not of the form ``tag: <int>``.
        """
        global cache
        section = data.split(ShowCachedInfo.SETTINGS_HEADER)[-1]

        new_tag_settings = {}
        for line in section.strip().split("\n"):
            line = line.strip()
            if not line:
                continue
            k, sep, v = line.rpartition(":")
            if not sep:
                raise ValueError(f"[Inspire Pack] ShowCachedInfo: invalid cache setting line '{line}' (expected 'tag: size')")
            try:
                new_tag_settings[k.strip()] = int(v.strip())
            except ValueError:
                raise ValueError(f"[Inspire Pack] ShowCachedInfo: cache size must be an integer in line '{line}'") from None

        if new_tag_settings == cache._tag_settings:
            return

        new_cache = TaggedCache(new_tag_settings)
        for k, v in cache.items():
            new_cache[k] = v
        cache = new_cache

    def doit(self, cache_info, key, unique_id):
        text = ShowCachedInfo.get_data()
        PromptServer.instance.send_sync("inspire-node-feedback", {"node_id": unique_id, "widget_name": "cache_info", "type": "text", "data": text})

        return {}

    @classmethod
    def IS_CHANGED(cls, **kwargs):
        # NaN never compares equal to itself, so the node re-runs on every prompt.
        return float("NaN")
|
|
|
|
|
class CheckpointLoaderSimpleShared(nodes.CheckpointLoaderSimple):
    """CheckpointLoaderSimple that shares loaded checkpoints through the backend cache.

    Modes:
      * 'Auto': reuse the entry cached under the key; load and cache on a miss.
      * 'Override Cache': always load ``ckpt_name`` and overwrite the entry.
      * 'Read Only': requires an explicit ``key_opt``; reuses the cached entry
        and otherwise behaves like 'Auto'.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
                    "key_opt": ("STRING", {"multiline": False, "placeholder": "If empty, use 'ckpt_name' as the key."}),
                },
                "optional": {
                    "mode": (['Auto', 'Override Cache', 'Read Only'],),
                }}

    RETURN_TYPES = ("MODEL", "CLIP", "VAE", "STRING")
    RETURN_NAMES = ("model", "clip", "vae", "cache key")
    FUNCTION = "doit"

    CATEGORY = "InspirePack/Backend"

    @staticmethod
    def _resolve_key(ckpt_name, key_opt, mode):
        """Return the cache key: stripped ``key_opt``, or ``ckpt_name`` when it is blank.

        Raises:
            Exception: if ``key_opt`` is blank in 'Read Only' mode.
        """
        key = key_opt.strip()
        if key:
            return key
        if mode == 'Read Only':
            raise Exception("[CheckpointLoaderSimpleShared] key_opt cannot be omit if mode is 'Read Only'")
        return ckpt_name

    def doit(self, ckpt_name, key_opt, mode='Auto'):
        """Return ``(model, clip, vae, key)``, loading the checkpoint only when needed."""
        key = self._resolve_key(ckpt_name, key_opt, mode)

        if key not in cache or mode == 'Override Cache':
            # NOTE(review): in 'Read Only' mode a missing key still falls through
            # to loading ckpt_name and caching it. Confirm this fallback is intended.
            res = self.load_checkpoint(ckpt_name)
            update_cache(key, "ckpt", (False, res))
            cache_kind = 'ckpt'
            print(f"[Inspire Pack] CheckpointLoaderSimpleShared: Ckpt '{ckpt_name}' is cached to '{key}'.")
        else:
            cache_kind, (_, res) = cache[key]
            print(f"[Inspire Pack] CheckpointLoaderSimpleShared: Cached ckpt '{key}' is loaded. (Loading skip)")

        # 'unclip_ckpt' entries (cached by StableCascade_CheckpointLoader) carry
        # an extra clip_vision element that this node does not output.
        if cache_kind == 'ckpt':
            model, clip, vae = res
        elif cache_kind == 'unclip_ckpt':
            model, clip, vae, _ = res
        else:
            raise Exception(f"[CheckpointLoaderSimpleShared] Unexpected cache_kind '{cache_kind}'")

        return model, clip, vae, key

    @staticmethod
    def IS_CHANGED(ckpt_name, key_opt, mode='Auto'):
        key = CheckpointLoaderSimpleShared._resolve_key(ckpt_name, key_opt, mode)

        if mode == 'Override Cache':
            return (ckpt_name, key)

        # 'Auto' and 'Read Only' re-run whenever the cached entry for key changes.
        return (None, cache_weak_hash(key))
|
|
|
|
|
class StableCascade_CheckpointLoader:
    """Load the Stable Cascade stage B and stage C checkpoints, optionally via the shared cache.

    ``cache_mode`` selects which stage(s) go through the backend cache
    ("none", "stage_b", "stage_c" or "all"). Uncached stages are loaded
    fresh on every execution.
    """

    @staticmethod
    def _default_ckpt(ckpts, stage):
        """Pick the default checkpoint name for ``stage`` ('stage_b' or 'stage_c').

        Preference order (case-insensitive): names containing both 'cascade'
        and ``stage``, then any name containing ``stage``, then the first
        checkpoint. Returns '' when ``ckpts`` is empty.
        """
        candidates = [x for x in ckpts if 'cascade' in x.lower() and stage in x.lower()]
        if not candidates:
            candidates = [x for x in ckpts if stage in x.lower()]
        if not candidates:
            candidates = ckpts
        return candidates[0] if candidates else ''

    @staticmethod
    def _resolve_key(key_opt, ckpt_name):
        """Return the stripped ``key_opt``, or ``ckpt_name`` when it is blank."""
        key = key_opt.strip()
        return key if key else ckpt_name

    @staticmethod
    def _load_cached(key, ckpt_name, kind, loader_cls):
        """Return ``(cache_kind, result)`` for ``key``, loading and caching ``ckpt_name`` on a miss."""
        if key not in cache:
            res = loader_cls().load_checkpoint(ckpt_name=ckpt_name)
            update_cache(key, kind, (False, res))
            print(f"[Inspire Pack] StableCascade_CheckpointLoader: Ckpt '{ckpt_name}' is cached to '{key}'.")
            return kind, res

        cached_kind, (_, res) = cache[key]
        print(f"[Inspire Pack] StableCascade_CheckpointLoader: Cached ckpt '{key}' is loaded. (Loading skip)")
        return cached_kind, res

    @classmethod
    def INPUT_TYPES(s):
        ckpts = folder_paths.get_filename_list("checkpoints")

        return {"required": {
                    "stage_b": (ckpts, {'default': s._default_ckpt(ckpts, 'stage_b')}),
                    "key_opt_b": ("STRING", {"multiline": False, "placeholder": "If empty, use 'stage_b' as the key."}),
                    "stage_c": (ckpts, {'default': s._default_ckpt(ckpts, 'stage_c')}),
                    "key_opt_c": ("STRING", {"multiline": False, "placeholder": "If empty, use 'stage_c' as the key."}),
                    "cache_mode": (["none", "stage_b", "stage_c", "all"], {"default": "none"}),
                }}

    RETURN_TYPES = ("MODEL", "VAE", "MODEL", "VAE", "CLIP_VISION", "CLIP", "STRING", "STRING")
    RETURN_NAMES = ("b_model", "b_vae", "c_model", "c_vae", "c_clip_vision", "clip", "key_b", "key_c")
    FUNCTION = "doit"

    CATEGORY = "InspirePack/Backend"

    def doit(self, stage_b, key_opt_b, stage_c, key_opt_c, cache_mode):
        """Return stage B and C models/VAEs, stage C's clip_vision, stage B's clip and both cache keys.

        Raises:
            Exception: if a cached stage C entry is a plain 'ckpt' entry,
                which has no clip_vision.
        """
        key_b = self._resolve_key(key_opt_b, stage_b)
        key_c = self._resolve_key(key_opt_c, stage_c)

        # Stage B: plain checkpoint -> (model, clip, vae).
        if cache_mode in ('stage_b', 'all'):
            kind_b, res_b = self._load_cached(key_b, stage_b, "ckpt", nodes.CheckpointLoaderSimple)
        else:
            kind_b, res_b = "ckpt", nodes.CheckpointLoaderSimple().load_checkpoint(ckpt_name=stage_b)

        if kind_b == 'unclip_ckpt':
            # Same tolerance as CheckpointLoaderSimpleShared: ignore clip_vision.
            b_model, clip, b_vae, _ = res_b
        else:
            b_model, clip, b_vae = res_b

        # Stage C: unCLIP checkpoint -> (model, clip, vae, clip_vision).
        if cache_mode in ('stage_c', 'all'):
            kind_c, res_c = self._load_cached(key_c, stage_c, "unclip_ckpt", nodes.unCLIPCheckpointLoader)
        else:
            kind_c, res_c = "unclip_ckpt", nodes.unCLIPCheckpointLoader().load_checkpoint(ckpt_name=stage_c)

        if kind_c == 'ckpt':
            raise Exception(f"[StableCascade_CheckpointLoader] Cache key '{key_c}' holds a 'ckpt' entry without clip_vision; stage_c needs an 'unclip_ckpt' entry.")
        c_model, _, c_vae, clip_vision = res_c

        return b_model, b_vae, c_model, c_vae, clip_vision, clip, key_b, key_c
|
|
|
|
|
# Registration table: (node id, implementing class, display name) for every
# node this module exposes. Both ComfyUI mappings are derived from it so the
# ids can never drift apart.
_INSPIRE_BACKEND_NODES = (
    ("CacheBackendData //Inspire", CacheBackendData, "Cache Backend Data (Inspire)"),
    ("CacheBackendDataNumberKey //Inspire", CacheBackendDataNumberKey, "Cache Backend Data [NumberKey] (Inspire)"),
    ("CacheBackendDataList //Inspire", CacheBackendDataList, "Cache Backend Data List (Inspire)"),
    ("CacheBackendDataNumberKeyList //Inspire", CacheBackendDataNumberKeyList, "Cache Backend Data List [NumberKey] (Inspire)"),
    ("RetrieveBackendData //Inspire", RetrieveBackendData, "Retrieve Backend Data (Inspire)"),
    ("RetrieveBackendDataNumberKey //Inspire", RetrieveBackendDataNumberKey, "Retrieve Backend Data [NumberKey] (Inspire)"),
    ("RemoveBackendData //Inspire", RemoveBackendData, "Remove Backend Data (Inspire)"),
    ("RemoveBackendDataNumberKey //Inspire", RemoveBackendDataNumberKey, "Remove Backend Data [NumberKey] (Inspire)"),
    ("ShowCachedInfo //Inspire", ShowCachedInfo, "Show Cached Info (Inspire)"),
    ("CheckpointLoaderSimpleShared //Inspire", CheckpointLoaderSimpleShared, "Shared Checkpoint Loader (Inspire)"),
    ("StableCascade_CheckpointLoader //Inspire", StableCascade_CheckpointLoader, "Stable Cascade Checkpoint Loader (Inspire)"),
)

NODE_CLASS_MAPPINGS = {node_id: node_cls for node_id, node_cls, _ in _INSPIRE_BACKEND_NODES}

NODE_DISPLAY_NAME_MAPPINGS = {node_id: title for node_id, _, title in _INSPIRE_BACKEND_NODES}
|
|