|
def _find_outer_instance(target: str, target_type, max_frames: int = 5):
    """Search the call stack for a local variable named ``target`` of type ``target_type``.

    Walks at most ``max_frames`` frames, starting with this helper's own frame and
    following ``f_back``. Returns the first matching value, or ``None`` when no
    inspected frame holds such a local (or frame introspection is unavailable).
    """
    import inspect

    frame = inspect.currentframe()
    try:
        for _ in range(max_frames):
            if frame is None:
                break
            found = frame.f_locals.get(target, None)
            if found is not None and isinstance(found, target_type):
                return found
            frame = frame.f_back
        return None
    finally:
        # The inspect docs recommend dropping frame references explicitly so they
        # cannot keep frames (and every local they reference) alive in a cycle.
        del frame


def _bundled_lora_weights(sd):
    """Return ``sd["lora_weights"]`` if ``sd`` is exactly a two-entry dict holding
    non-empty ``"id_encoder"`` and ``"lora_weights"`` dicts; otherwise ``None``."""
    # Cheap structural checks first: ordinary state dicts have many keys and map
    # names to tensors, and bool() on a multi-element tensor raises RuntimeError.
    if not isinstance(sd, dict) or len(sd) != 2:
        return None
    id_encoder = sd.get("id_encoder")
    lora_weights = sd.get("lora_weights")
    if isinstance(id_encoder, dict) and id_encoder and isinstance(lora_weights, dict) and lora_weights:
        return lora_weights
    return None


def load_torch_file(ckpt, safe_load=False, device=None):
    """Load a checkpoint file and return its (state-dict) contents.

    Args:
        ckpt: Path string of the checkpoint. Names ending in ``.safetensors``
            (case-insensitive) are read with ``safetensors.torch.load_file``;
            everything else goes through ``torch.load``.
        safe_load: If true, call ``torch.load`` with ``weights_only=True``. If the
            installed ``torch.load`` has no ``weights_only`` parameter, a warning is
            logged and an unrestricted load is performed instead.
        device: ``torch.device`` or device string to map tensors onto; defaults to
            CPU. For ``.safetensors`` only ``device.type`` is forwarded, so a device
            index such as ``"cuda:1"`` is not passed on.

    Returns:
        For ``torch.load`` results containing a ``"state_dict"`` entry, that entry;
        otherwise the loaded object itself. Special case: a checkpoint consisting
        of exactly non-empty ``"id_encoder"`` and ``"lora_weights"`` dicts is reduced
        to its ``"lora_weights"`` dict when a caller within the inspected stack
        frames has a ``str`` local named ``lora_name`` (presumably a LoRA loader —
        confirm against callers).
    """
    if device is None:
        device = torch.device("cpu")
    elif isinstance(device, str):
        # The safetensors branch needs ``device.type``; accept plain strings too.
        device = torch.device(device)

    if ckpt.lower().endswith(".safetensors"):
        sd = safetensors.torch.load_file(ckpt, device=device.type)
    else:
        if safe_load and "weights_only" not in torch.load.__code__.co_varnames:
            logging.warning("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
            safe_load = False
        if safe_load:
            pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
        else:
            # NOTE(review): an unrestricted unpickle can execute arbitrary code
            # embedded in the file; only use this path with trusted checkpoints.
            pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
        if "global_step" in pl_sd:
            logging.debug(f"Global Step: {pl_sd['global_step']}")
        if "state_dict" in pl_sd:
            sd = pl_sd["state_dict"]
        else:
            sd = pl_sd

    lora_weights = _bundled_lora_weights(sd)
    # Called directly from here so the inspected frames are: the helper itself,
    # this function, then up to three callers.
    if lora_weights is not None and _find_outer_instance("lora_name", str) is not None:
        sd = lora_weights
    return sd
|
|
|
|