3v324v23's picture
lfs
1e3b872
raw
history blame
1.64 kB
import torch
# Copied from https://github.com/comfyanonymous/ComfyUI/blob/719fb2c81d716ce8edd7f1bdc7804ae160a71d3a/comfy/model_patcher.py#L21 for backward compatibility
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
to = model_options["transformer_options"].copy()
if "patches_replace" not in to:
to["patches_replace"] = {}
else:
to["patches_replace"] = to["patches_replace"].copy()
if name not in to["patches_replace"]:
to["patches_replace"][name] = {}
else:
to["patches_replace"][name] = to["patches_replace"][name].copy()
if transformer_index is not None:
block = (block_name, number, transformer_index)
else:
block = (block_name, number)
to["patches_replace"][name][block] = patch
model_options["transformer_options"] = to
return model_options
# Modified 'Algorithm 2 Classifier-Free Guidance with Rescale' from Common Diffusion Noise Schedules and Sample Steps are Flawed (Lin et al.).
def rescale_pag(pag: torch.Tensor, cond_pred: torch.Tensor, cfg_result: torch.Tensor, rescale=0.0, rescale_mode="full"):
if rescale == 0.0:
return pag
match rescale_mode:
case "full":
pag_result = cfg_result + pag
case _:
pag_result = cond_pred + pag
std_cond = torch.std(cond_pred, dim=(1, 2, 3), keepdim=True)
std_pag = torch.std(pag_result, dim=(1, 2, 3), keepdim=True)
factor = std_cond / std_pag
factor = rescale * factor + (1.0 - rescale)
return pag * factor