|
import torch
|
|
|
|
|
|
|
|
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
    """Register *patch* as the replacement for one transformer block.

    Stores *patch* under ``model_options["transformer_options"]["patches_replace"][name][key]``
    where ``key`` is ``(block_name, number)`` or, when *transformer_index* is
    given, ``(block_name, number, transformer_index)``.

    The dicts along the modified path are shallow-copied (copy-on-write) so
    that other references to the original nested dicts are not mutated;
    *model_options* itself IS updated in place and also returned.
    """
    transformer_options = model_options["transformer_options"].copy()

    # Copy-on-write down the nested structure: copy each level we touch,
    # leaving untouched sibling entries shared with the original dicts.
    replacements = dict(transformer_options.get("patches_replace", {}))
    per_name = dict(replacements.get(name, {}))

    # The block key optionally carries the transformer index.
    key = (block_name, number) if transformer_index is None else (block_name, number, transformer_index)
    per_name[key] = patch

    replacements[name] = per_name
    transformer_options["patches_replace"] = replacements
    model_options["transformer_options"] = transformer_options
    return model_options
|
|
|
|
|
|
|
|
def rescale_pag(pag: torch.Tensor, cond_pred: torch.Tensor, cfg_result: torch.Tensor, rescale=0.0, rescale_mode="full"):
    """Rescale the perturbed-attention-guidance (PAG) delta toward the
    conditional prediction's per-sample standard deviation.

    The PAG delta is scaled by a factor interpolated between 1.0 and
    ``std(cond_pred) / std(reference + pag)``, controlled by *rescale*
    (0.0 = no change, 1.0 = full std matching). *rescale_mode* selects the
    reference the delta is added onto before measuring its std:
    ``"full"`` uses *cfg_result*; any other value uses *cond_pred*.

    Returns the (possibly) rescaled *pag* tensor; inputs are not modified.
    """
    # No rescaling requested: hand the delta back untouched.
    if rescale == 0.0:
        return pag

    # Pick which prediction the delta is measured against.
    if rescale_mode == "full":
        reference = cfg_result + pag
    else:
        reference = cond_pred + pag

    # Per-sample std over the trailing dims; assumes 4D (B, C, H, W)
    # input since dims (1, 2, 3) are hard-coded — TODO confirm with callers.
    cond_std = torch.std(cond_pred, dim=(1, 2, 3), keepdim=True)
    ref_std = torch.std(reference, dim=(1, 2, 3), keepdim=True)

    # Blend the std-matching factor with identity according to `rescale`.
    scale = rescale * (cond_std / ref_std) + (1.0 - rescale)
    return pag * scale
|
|
|