def _find_outer_instance(target, target_type, max_depth=5):
    """Return the first local variable named *target* that is a *target_type*.

    The search starts at this helper's own frame and walks outward through at
    most *max_depth* frames (this helper, its caller, and further callers).
    Returns None when no matching local is found.
    """
    import inspect

    frame = inspect.currentframe()
    try:
        for _ in range(max_depth):
            if frame is None:
                break
            found = frame.f_locals.get(target, None)
            if found is not None and isinstance(found, target_type):
                return found
            frame = frame.f_back
        return None
    finally:
        # Drop the local reference to the frame object to break the
        # frame <-> local-variable reference cycle (see the inspect docs).
        del frame


def load_torch_file(ckpt, safe_load=False, device=None):
    """Load a checkpoint from disk and return its state dict.

    Args:
        ckpt: Path to the checkpoint file. Paths ending in ``.safetensors``
            (case-insensitive) are read with ``safetensors.torch.load_file``;
            anything else is read with ``torch.load``.
        safe_load: If True, pass ``weights_only=True`` to ``torch.load``. When
            the installed ``torch.load`` has no ``weights_only`` parameter, a
            warning is logged and the file is loaded with
            ``comfy.checkpoint_pickle`` instead. Not used for safetensors files.
        device: Device to load tensors onto. Defaults to CPU; a value that is
            not already a ``torch.device`` is passed through ``torch.device()``.

    Returns:
        The loaded state dict. For ``torch.load`` files that contain a
        ``"state_dict"`` entry, that entry is returned. A dict with exactly two
        entries, a non-None ``"id_encoder"`` and a non-empty ``"lora_weights"``
        dict, is reduced to ``lora_weights`` when one of the inspected caller
        frames has a local ``str`` named ``lora_name``.
    """
    if device is None:
        device = torch.device("cpu")
    elif not isinstance(device, torch.device):
        # Accept device strings / indices such as "cpu" or "cuda:0".
        device = torch.device(device)

    if str(ckpt).lower().endswith(".safetensors"):
        # NOTE(review): only device.type is forwarded, so an explicit index
        # such as "cuda:1" does not reach safetensors here — confirm intended.
        sd = safetensors.torch.load_file(ckpt, device=device.type)
    else:
        if safe_load and 'weights_only' not in torch.load.__code__.co_varnames:
            logging.warning("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
            safe_load = False
        if safe_load:
            pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
        else:
            # Not weights_only (the warning above calls this loading
            # "unsafely"); treat as suitable for trusted files only.
            pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
        if "global_step" in pl_sd:
            logging.debug("Global Step: %s", pl_sd["global_step"])
        sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd

    # Presumably a PhotoMaker-style bundle: exactly two entries, "id_encoder"
    # plus a dict of "lora_weights". It is only unwrapped when a nearby caller
    # frame has a str local "lora_name" (presumably signalling a LoRA load).
    # Checking the container size and value types first avoids evaluating the
    # truthiness of tensors, which raises for multi-element tensors.
    if isinstance(sd, dict) and len(sd) == 2:
        lora_weights = sd.get('lora_weights', None)
        if (sd.get('id_encoder', None) is not None
                and isinstance(lora_weights, dict) and lora_weights
                and _find_outer_instance('lora_name', str) is not None):
            sd = lora_weights
    return sd