# Forge WebUI script exposing Perturbed-Attention Guidance (PAG) controls.
#
# Everything below the helper is defined only when ``pag_nodes`` imports
# successfully and reports ``BACKEND == "Forge"``.
#
# NOTE(review): this file was reformatted from a copy whose indentation had
# been lost; the nesting of the ``with`` blocks in ``ui`` and of the
# ``if hr_override`` branch was reconstructed from statement order — confirm
# against the intended UI layout.


def infotext_flag_is_set(params, key):
    """Return True when *params* stores a true boolean flag under *key*.

    Accepts both real booleans and their textual form (``"True"`` /
    ``"False"``, any case, surrounding whitespace ignored). A missing key
    counts as False.
    """
    return str(params.get(key, False)).strip().lower() == "true"


try:
    import pag_nodes

    if pag_nodes.BACKEND == "Forge":
        import gradio as gr

        from modules import scripts
        from modules.ui_components import InputAccordion

        # Single shared patcher instance used for every generation.
        opPerturbedAttention = pag_nodes.PerturbedAttention()

        class PerturbedAttentionScript(scripts.Script):
            """Always-visible script that patches the UNet with PAG before sampling."""

            def title(self):
                """Return the label shown for this script in the UI."""
                return "Perturbed-Attention Guidance"

            def show(self, is_img2img):
                """Make the script visible regardless of *is_img2img*."""
                return scripts.AlwaysVisible

            def ui(self, *args, **kwargs):
                """Build the Gradio controls and register infotext bindings.

                Returns the components in the exact order in which
                ``process_before_every_sampling`` and ``post_sample`` unpack
                ``script_args``.
                """
                with gr.Accordion(open=False, label=self.title()):
                    enabled = gr.Checkbox(label="Enabled", value=False)
                    scale = gr.Slider(label="PAG Scale", minimum=0.0, maximum=30.0, step=0.01, value=3.0)
                    with gr.Row():
                        rescale_pag = gr.Slider(label="Rescale PAG", minimum=0.0, maximum=1.0, step=0.01, value=0.0)
                        rescale_mode = gr.Dropdown(choices=["full", "partial"], value="full", label="Rescale Mode")
                    adaptive_scale = gr.Slider(label="Adaptive Scale", minimum=0.0, maximum=1.0, step=0.001, value=0.0)
                    with InputAccordion(False, label="Override for Hires. fix") as hr_override:
                        hr_cfg = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label="CFG Scale", value=7.0)
                        hr_scale = gr.Slider(label="PAG Scale", minimum=0.0, maximum=30.0, step=0.01, value=3.0)
                        with gr.Row():
                            hr_rescale_pag = gr.Slider(label="Rescale PAG", minimum=0.0, maximum=1.0, step=0.01, value=0.0)
                            hr_rescale_mode = gr.Dropdown(choices=["full", "partial"], value="full", label="Rescale Mode")
                        hr_adaptive_scale = gr.Slider(label="Adaptive Scale", minimum=0.0, maximum=1.0, step=0.001, value=0.0)
                    with gr.Row():
                        block = gr.Dropdown(choices=["input", "middle", "output"], value="middle", label="U-Net Block")
                        block_id = gr.Number(label="U-Net Block Id", value=0, precision=0, minimum=0)
                        block_list = gr.Text(label="U-Net Block List")
                    with gr.Row():
                        sigma_start = gr.Number(minimum=-1.0, label="Sigma Start", value=-1.0)
                        sigma_end = gr.Number(minimum=-1.0, label="Sigma End", value=-1.0)

                # Each entry maps a component either to an infotext key or to a
                # callable that derives the component update from the parsed
                # infotext mapping.
                self.infotext_fields = (
                    # "pag_enabled" is only written when the script is enabled,
                    # so its presence alone means "on".
                    (enabled, lambda p: gr.Checkbox.update(value="pag_enabled" in p)),
                    (scale, "pag_scale"),
                    (rescale_pag, "pag_rescale"),
                    (rescale_mode, lambda p: gr.Dropdown.update(value=p.get("pag_rescale_mode", "full"))),
                    (adaptive_scale, "pag_adaptive_scale"),
                    # The flag is saved as "pag_hr_override" and is written
                    # even when False, so the stored value must be inspected
                    # rather than the mere presence of the key.
                    (hr_override, lambda p: gr.Checkbox.update(value=infotext_flag_is_set(p, "pag_hr_override"))),
                    (hr_cfg, "pag_hr_cfg"),
                    (hr_scale, "pag_hr_scale"),
                    (hr_rescale_pag, "pag_hr_rescale"),
                    (hr_rescale_mode, lambda p: gr.Dropdown.update(value=p.get("pag_hr_rescale_mode", "full"))),
                    (hr_adaptive_scale, "pag_hr_adaptive_scale"),
                    (block, lambda p: gr.Dropdown.update(value=p.get("pag_block", "middle"))),
                    (block_id, "pag_block_id"),
                    (block_list, lambda p: gr.Text.update(value=p.get("pag_block_list", ""))),
                    (sigma_start, "pag_sigma_start"),
                    (sigma_end, "pag_sigma_end"),
                )

                return (
                    enabled,
                    scale,
                    rescale_pag,
                    rescale_mode,
                    adaptive_scale,
                    block,
                    block_id,
                    block_list,
                    hr_override,
                    hr_cfg,
                    hr_scale,
                    hr_rescale_pag,
                    hr_rescale_mode,
                    hr_adaptive_scale,
                    sigma_start,
                    sigma_end,
                )

            def process_before_every_sampling(self, p, *script_args, **kwargs):
                """Patch the UNet with PAG and record the settings for infotext.

                ``script_args`` must hold the 16 values returned by ``ui`` in
                the same order. Does nothing when the script is disabled.
                During the hires pass with the override active, ``p.cfg_scale``
                is replaced by ``hr_cfg`` and the previous value is stashed in
                ``p.cfg_scale_before_hr`` so ``post_sample`` can restore it.
                """
                (
                    enabled,
                    scale,
                    rescale_pag,
                    rescale_mode,
                    adaptive_scale,
                    block,
                    block_id,
                    block_list,
                    hr_override,
                    hr_cfg,
                    hr_scale,
                    hr_rescale_pag,
                    hr_rescale_mode,
                    hr_adaptive_scale,
                    sigma_start,
                    sigma_end,
                ) = script_args

                if not enabled:
                    return

                unet = p.sd_model.forge_objects.unet

                # Not every processing object carries ``enable_hr``.
                hr_enabled = getattr(p, "enable_hr", False)

                if hr_enabled and p.is_hr_pass and hr_override:
                    # Hires pass with override: swap in the hires CFG and the
                    # hires-specific PAG settings.
                    p.cfg_scale_before_hr = p.cfg_scale
                    p.cfg_scale = hr_cfg
                    unet = opPerturbedAttention.patch(unet, hr_scale, hr_adaptive_scale, block, block_id, sigma_start, sigma_end, hr_rescale_pag, hr_rescale_mode, block_list)[0]
                else:
                    unet = opPerturbedAttention.patch(unet, scale, adaptive_scale, block, block_id, sigma_start, sigma_end, rescale_pag, rescale_mode, block_list)[0]

                p.sd_model.forge_objects.unet = unet

                p.extra_generation_params.update(
                    {
                        "pag_enabled": enabled,
                        "pag_scale": scale,
                        "pag_rescale": rescale_pag,
                        "pag_rescale_mode": rescale_mode,
                        "pag_adaptive_scale": adaptive_scale,
                        "pag_block": block,
                        "pag_block_id": block_id,
                        "pag_block_list": block_list,
                    }
                )
                if hr_enabled:
                    p.extra_generation_params["pag_hr_override"] = hr_override
                    if hr_override:
                        p.extra_generation_params.update(
                            {
                                "pag_hr_cfg": hr_cfg,
                                "pag_hr_scale": hr_scale,
                                "pag_hr_rescale": hr_rescale_pag,
                                "pag_hr_rescale_mode": hr_rescale_mode,
                                "pag_hr_adaptive_scale": hr_adaptive_scale,
                            }
                        )
                # Sigma bounds are only recorded when at least one of them was
                # moved off the -1.0 default.
                if sigma_start >= 0 or sigma_end >= 0:
                    p.extra_generation_params.update(
                        {
                            "pag_sigma_start": sigma_start,
                            "pag_sigma_end": sigma_end,
                        }
                    )

            def post_sample(self, p, ps, *script_args):
                """Restore the CFG scale that the hires override replaced.

                ``script_args`` must hold the 16 values returned by ``ui`` in
                the same order. Does nothing when the script is disabled.
                """
                (
                    enabled,
                    scale,
                    rescale_pag,
                    rescale_mode,
                    adaptive_scale,
                    block,
                    block_id,
                    block_list,
                    hr_override,
                    hr_cfg,
                    hr_scale,
                    hr_rescale_pag,
                    hr_rescale_mode,
                    hr_adaptive_scale,
                    sigma_start,
                    sigma_end,
                ) = script_args

                if not enabled:
                    return

                hr_enabled = getattr(p, "enable_hr", False)

                if hr_enabled and hr_override:
                    # ``cfg_scale_before_hr`` exists only if the hires pass
                    # reached process_before_every_sampling; otherwise keep the
                    # current value instead of raising AttributeError.
                    p.cfg_scale = getattr(p, "cfg_scale_before_hr", p.cfg_scale)

except ImportError:
    # Best-effort: when ``pag_nodes`` or the WebUI modules cannot be imported
    # the script is simply not defined (presumably because the file is being
    # loaded outside of Forge — confirm).
    pass