File size: 1,555 Bytes
1e3b872
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
def _find_caller_local(name, expected_type, max_frames=5):
    """Return a local variable called *name* from a nearby stack frame.

    Walks at most *max_frames* frames, starting with this helper's own frame
    and moving outward through ``f_back``. The first local that is not
    ``None`` and is an instance of *expected_type* is returned.

    Returns ``None`` when no such local is found within the inspected frames,
    or when the interpreter cannot provide stack frames.
    """
    import inspect

    frame = inspect.currentframe()
    try:
        for _ in range(max_frames):
            if frame is None:
                break
            candidate = frame.f_locals.get(name)
            if candidate is not None and isinstance(candidate, expected_type):
                return candidate
            frame = frame.f_back
        return None
    finally:
        # The inspect documentation recommends dropping frame references
        # explicitly so they cannot prolong the lifetime of caller locals.
        del frame


def load_torch_file(ckpt, safe_load=False, device=None):
    """Load a checkpoint file and return its state dict.

    Args:
        ckpt: Path of the checkpoint. A path ending in ``.safetensors``
            (case-insensitive) is read with ``safetensors.torch.load_file``;
            anything else goes through ``torch.load``.
        safe_load: When true, ``torch.load`` is called with
            ``weights_only=True``. If the installed torch does not expose that
            parameter, a warning is logged and the unsafe loader is used.
        device: ``torch.device`` to map tensors to. Defaults to CPU. Only
            ``device.type`` is forwarded to the safetensors loader.

    Returns:
        The loaded state dict. For ``torch.load`` checkpoints that contain a
        ``"state_dict"`` entry, that entry is returned instead of the wrapper.
        If the result is a two-entry dict whose ``'id_encoder'`` and
        ``'lora_weights'`` values are both non-empty dicts, and a nearby
        caller frame has a string local named ``lora_name``, only the
        ``'lora_weights'`` dict is returned.
    """
    if device is None:
        device = torch.device("cpu")
    if ckpt.lower().endswith(".safetensors"):
        sd = safetensors.torch.load_file(ckpt, device=device.type)
    else:
        if safe_load and 'weights_only' not in torch.load.__code__.co_varnames:
            logging.warning("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
            safe_load = False
        if safe_load:
            pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
        else:
            # SECURITY NOTE: this path unpickles the file through
            # comfy.checkpoint_pickle instead of torch's weights_only mode, so
            # it is only as safe as that pickle module. Prefer safe_load=True
            # for checkpoints from untrusted sources.
            pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
        if "global_step" in pl_sd:
            logging.debug(f"Global Step: {pl_sd['global_step']}")
        sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd

    # Detect a bundle holding exactly an 'id_encoder' and a 'lora_weights'
    # sub-dict. Checking the size and the value types first avoids
    # truth-testing tensors, which raises for multi-element tensors.
    if isinstance(sd, dict) and len(sd) == 2:
        id_encoder = sd.get('id_encoder')
        lora_weights = sd.get('lora_weights')
        is_bundle = (
            isinstance(id_encoder, dict) and bool(id_encoder)
            and isinstance(lora_weights, dict) and bool(lora_weights)
        )
        # Only unwrap to the LoRA weights when some caller up the stack holds
        # a string local named 'lora_name' (presumably a LoRA loader - the
        # callers are not visible from this function).
        if is_bundle and _find_caller_local('lora_name', str) is not None:
            sd = lora_weights
    return sd