File size: 4,377 Bytes
2a35740
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import torch.nn.functional as nnf
import torch
import numpy as np

from tqdm import tqdm
from torch.optim.adam import Adam
from PIL import Image

from generation import load_512
from p2p import register_attention_control


def null_optimization(solver,
                      latents,
                      guidance_scale,
                      num_inner_steps,
                      epsilon):
    """Null-text inversion: tune per-timestep unconditional embeddings.

    For each diffusion timestep, the unconditional ("null") text embedding
    is optimized with Adam so that one classifier-free-guidance denoising
    step from the current latent reproduces the latent recorded during the
    DDIM inversion trajectory.

    Args:
        solver: inversion helper exposing ``context``, ``n_steps``,
            ``model.scheduler.timesteps``, ``get_noise_pred_single``,
            ``get_noise_pred`` and ``prev_step``.
        latents: DDIM-inversion latents ordered from x_0 to x_T; the last
            element is the starting (noisiest) latent.
        guidance_scale: classifier-free guidance weight used while tuning.
        num_inner_steps: maximum Adam iterations per timestep.
        epsilon: base early-stop threshold on the MSE loss; relaxed by
            ``i * 2e-5`` at later timesteps.

    Returns:
        List of optimized unconditional embeddings, one per timestep.
    """
    uncond_embeddings, cond_embeddings = solver.context.chunk(2)
    uncond_embeddings_list = []
    latent_cur = latents[-1]
    bar = tqdm(total=num_inner_steps * solver.n_steps)
    for i in range(solver.n_steps):
        uncond_embeddings = uncond_embeddings.clone().detach()
        uncond_embeddings.requires_grad = True
        # Learning rate decays linearly with the step index; the constant
        # 100 assumes n_steps <= 100 — TODO confirm against callers.
        optimizer = Adam([uncond_embeddings], lr=1e-2 * (1. - i / 100.))
        latent_prev = latents[len(latents) - i - 2]
        t = solver.model.scheduler.timesteps[i]
        # The conditional branch does not depend on the optimized variable,
        # so compute it once per timestep without gradients.
        with torch.no_grad():
            noise_pred_cond = solver.get_noise_pred_single(latent_cur, t, cond_embeddings)
        # Bug fix: `j` must exist even when num_inner_steps == 0, otherwise
        # the progress-bar catch-up below raised UnboundLocalError.
        j = -1
        for j in range(num_inner_steps):
            noise_pred_uncond = solver.get_noise_pred_single(latent_cur, t, uncond_embeddings)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
            latents_prev_rec = solver.prev_step(noise_pred, t, latent_cur)
            loss = nnf.mse_loss(latents_prev_rec, latent_prev)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            bar.update()
            # Early stop once the reconstruction is close enough; the
            # threshold loosens at later timesteps.
            if loss.item() < epsilon + i * 2e-5:
                break
        # Keep the bar total consistent when the inner loop stopped early
        # (single update(n) call replaces the original per-step loop).
        bar.update(num_inner_steps - j - 1)
        uncond_embeddings_list.append(uncond_embeddings[:1].detach())
        # Advance the trajectory one denoising step with the tuned embedding.
        with torch.no_grad():
            context = torch.cat([uncond_embeddings, cond_embeddings])
            noise_pred = solver.get_noise_pred(solver.model, latent_cur, t, guidance_scale, context)
            latent_cur = solver.prev_step(noise_pred, t, latent_cur)
    bar.close()
    return uncond_embeddings_list


def invert(solver, 
           stop_step,
           is_cons_inversion=False,
           inv_guidance_scale=1, 
           nti_guidance_scale=8,
           dynamic_guidance=False,
           tau1=0.4,
           tau2=0.6,
           w_embed_dim=0,
           image_path=None, 
           prompt='',
           offsets=(0, 0, 0, 0),
           do_nti=False,
           do_npi=False,
           num_inner_steps=10, 
           early_stop_epsilon=1e-5,
           seed=0,
           ):
    """Invert a real image into diffusion latents, optionally refining them.

    Runs either consistency inversion or DDIM inversion on *image_path*,
    then (optionally) produces per-step unconditional embeddings via
    null-text optimization (NTI) or negative-prompt inversion (NPI).

    Args:
        solver: inversion helper; must expose ``init_prompt``, ``context``,
            ``model``, ``n_steps``, ``cons_inversion`` and ``ddim_inversion``.
        stop_step: number of DDIM inversion steps (unused on the
            consistency-inversion path).
        is_cons_inversion: use ``solver.cons_inversion`` instead of DDIM.
        inv_guidance_scale: guidance scale during inversion.
        nti_guidance_scale: guidance scale used by null-text optimization.
        dynamic_guidance, tau1, tau2: dynamic-guidance schedule parameters
            forwarded to ``ddim_inversion``.
        w_embed_dim: guidance-embedding dimension forwarded to the solver.
        image_path: a file path, a list of file paths, or an in-memory
            image array (anything ``PIL.Image.fromarray`` accepts).
        prompt: text prompt used to build the solver context.
        offsets: crop offsets forwarded to ``load_512``.
        do_nti: run null-text optimization for the unconditional embeddings.
        do_npi: use the conditional embedding as the negative prompt
            (negative-prompt inversion); ignored when ``do_nti`` is set.
        num_inner_steps, early_stop_epsilon: NTI optimizer settings.
        seed: RNG seed forwarded to ``cons_inversion``.

    Returns:
        Tuple of ``(image_gt, image_rec)``, the final (noisiest) latent,
        and the unconditional embeddings (list for NTI/NPI, else None).
    """
    solver.init_prompt(prompt)
    # Only the conditional half of the context is needed here; the
    # unconditional half is produced below (NTI/NPI) or left as None.
    _, cond_embeddings = solver.context.chunk(2)
    register_attention_control(solver.model, None)
    # Accept a list of paths, a single path, or a raw image array.
    if isinstance(image_path, list):
        image_gt = [load_512(path, *offsets) for path in image_path]
    elif isinstance(image_path, str):
        image_gt = load_512(image_path, *offsets)
    else:
        image_gt = np.array(Image.fromarray(image_path).resize((512, 512)))
    
    if is_cons_inversion:
        image_rec, ddim_latents = solver.cons_inversion(image_gt,
                                                        w_embed_dim=w_embed_dim,
                                                        guidance_scale=inv_guidance_scale,
                                                        seed=seed,)
    else:  
        image_rec, ddim_latents = solver.ddim_inversion(image_gt, 
                                                        n_steps=stop_step,
                                                        guidance_scale=inv_guidance_scale,
                                                        dynamic_guidance=dynamic_guidance,
                                                        tau1=tau1, tau2=tau2, 
                                                        w_embed_dim=w_embed_dim)
    if do_nti:
        print("Null-text optimization...")
        uncond_embeddings = null_optimization(solver,
                                              ddim_latents,
                                              nti_guidance_scale,
                                              num_inner_steps,
                                              early_stop_epsilon)
    elif do_npi:
        # Negative-prompt inversion: reuse the conditional embedding as the
        # "unconditional" one at every step.
        uncond_embeddings = [cond_embeddings] * solver.n_steps
    else:
        uncond_embeddings = None
    return (image_gt, image_rec), ddim_latents[-1], uncond_embeddings