"""Inversion utilities for diffusion-based image editing: DDIM / consistency
inversion of a real image, with optional null-text inversion (NTI) or
negative-prompt inversion (NPI) of the unconditional embeddings."""

import numpy as np
import torch
import torch.nn.functional as nnf
from PIL import Image
from torch.optim import Adam
from tqdm import tqdm

from generation import load_512
from p2p import register_attention_control
def null_optimization(solver,
                      latents,
                      guidance_scale,
                      num_inner_steps,
                      epsilon):
    """Null-text inversion (Mokady et al., 2022): optimize a per-timestep
    unconditional ("null") embedding so that classifier-free-guided sampling
    reproduces the stored inversion trajectory `latents`."""
    uncond_embeddings, cond_embeddings = solver.context.chunk(2)
    uncond_embeddings_list = []
    # Start from the noisiest latent and walk the trajectory backwards.
    latent_cur = latents[-1]
    bar = tqdm(total=num_inner_steps * solver.n_steps)
    for i in range(solver.n_steps):
        # Re-initialize the null embedding for this timestep with a decayed LR.
        uncond_embeddings = uncond_embeddings.clone().detach()
        uncond_embeddings.requires_grad = True
        optimizer = Adam([uncond_embeddings], lr=1e-2 * (1. - i / 100.))
        latent_prev = latents[len(latents) - i - 2]
        t = solver.model.scheduler.timesteps[i]
        # The conditional prediction does not depend on the null embedding,
        # so it is computed once per timestep.
        with torch.no_grad():
            noise_pred_cond = solver.get_noise_pred_single(latent_cur, t, cond_embeddings)
        for j in range(num_inner_steps):
            noise_pred_uncond = solver.get_noise_pred_single(latent_cur, t, uncond_embeddings)
            # Classifier-free guidance with the current null embedding.
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
            latents_prev_rec = solver.prev_step(noise_pred, t, latent_cur)
            # Match the reconstructed latent to the stored inversion latent.
            loss = nnf.mse_loss(latents_prev_rec, latent_prev)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_item = loss.item()
            bar.update()
            # Early stopping; the threshold loosens at later timesteps.
            if loss_item < epsilon + i * 2e-5:
                break
        # Keep the progress bar consistent when the inner loop exits early.
        for j in range(j + 1, num_inner_steps):
            bar.update()
        uncond_embeddings_list.append(uncond_embeddings[:1].detach())
        # Advance one denoising step using the optimized null embedding.
        with torch.no_grad():
            context = torch.cat([uncond_embeddings, cond_embeddings])
            noise_pred = solver.get_noise_pred(solver.model, latent_cur, t, guidance_scale, context)
            latent_cur = solver.prev_step(noise_pred, t, latent_cur)
    bar.close()
    return uncond_embeddings_list
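
# A minimal sketch of calling null_optimization directly, assuming a `solver`
# exposing the interface used above (init_prompt, context, n_steps,
# get_noise_pred_single, prev_step) and a precomputed inversion trajectory.
# The prompt and hyperparameters below are illustrative, not part of this module:
#
#   solver.init_prompt("a photo of a cat")                # hypothetical prompt
#   null_embeds = null_optimization(solver,
#                                   latents=ddim_latents,  # z_0 ... z_T from inversion
#                                   guidance_scale=7.5,
#                                   num_inner_steps=10,
#                                   epsilon=1e-5)
#   assert len(null_embeds) == solver.n_steps              # one embedding per step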
def invert(solver,
           stop_step,
           is_cons_inversion=False,
           inv_guidance_scale=1,
           nti_guidance_scale=8,
           dynamic_guidance=False,
           tau1=0.4,
           tau2=0.6,
           w_embed_dim=0,
           image_path=None,
           prompt='',
           offsets=(0, 0, 0, 0),
           do_nti=False,
           do_npi=False,
           num_inner_steps=10,
           early_stop_epsilon=1e-5,
           seed=0,
           ):
    """Invert a real image into the diffusion latent space, optionally
    refining the null embeddings with NTI or replacing them via NPI."""
    solver.init_prompt(prompt)
    uncond_embeddings, cond_embeddings = solver.context.chunk(2)
    register_attention_control(solver.model, None)
    # Accept a list of paths, a single path, or a raw image array.
    if isinstance(image_path, list):
        image_gt = [load_512(path, *offsets) for path in image_path]
    elif isinstance(image_path, str):
        image_gt = load_512(image_path, *offsets)
    else:
        image_gt = np.array(Image.fromarray(image_path).resize((512, 512)))
    # Encode the image into a latent trajectory via consistency or DDIM inversion.
    if is_cons_inversion:
        image_rec, ddim_latents = solver.cons_inversion(image_gt,
                                                        w_embed_dim=w_embed_dim,
                                                        guidance_scale=inv_guidance_scale,
                                                        seed=seed)
    else:
        image_rec, ddim_latents = solver.ddim_inversion(image_gt,
                                                        n_steps=stop_step,
                                                        guidance_scale=inv_guidance_scale,
                                                        dynamic_guidance=dynamic_guidance,
                                                        tau1=tau1, tau2=tau2,
                                                        w_embed_dim=w_embed_dim)
    if do_nti:
        # Null-text inversion: optimize a per-step null embedding.
        print("Null-text optimization...")
        uncond_embeddings = null_optimization(solver,
                                              ddim_latents,
                                              nti_guidance_scale,
                                              num_inner_steps,
                                              early_stop_epsilon)
    elif do_npi:
        # Negative-prompt inversion: reuse the conditional embedding as the
        # null embedding at every step, with no optimization.
        uncond_embeddings = [cond_embeddings] * solver.n_steps
    else:
        uncond_embeddings = None
    return (image_gt, image_rec), ddim_latents[-1], uncond_embeddings
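
# A minimal usage sketch, assuming a `solver` built around a Stable Diffusion
# pipeline as elsewhere in this repo; the image path and prompt are
# placeholders, not part of this module:
#
#   (image_gt, image_rec), z_T, uncond = invert(solver,
#                                               stop_step=50,
#                                               image_path="example.png",
#                                               prompt="a photo of a cat",
#                                               do_nti=True,
#                                               nti_guidance_scale=8)
#   # z_T seeds subsequent editing (e.g. prompt-to-prompt), and `uncond`
#   # replaces the null embedding at each denoising step during sampling.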