Spaces:
Runtime error
Runtime error
Add InstructPix2Pix
Browse filesFormer-commit-id: 3626d699482f2419961432bff2e1763ccf55f6e7
- LICENSE +9 -0
- README.md +158 -0
- configs/generate.yaml +99 -0
- configs/train.yaml +113 -0
- dataset_creation/generate_img_dataset.py +297 -0
- dataset_creation/generate_txt_dataset.py +113 -0
- dataset_creation/prepare_dataset.py +29 -0
- dataset_creation/prepare_for_gpt.py +25 -0
- edit_app.py +269 -0
- edit_cli.py +128 -0
- edit_dataset.py +72 -0
- environment.yaml +37 -0
- imgs/example.jpg +0 -0
- main.py +797 -0
- metrics/clip_similarity.py +47 -0
- prompt_app.py +55 -0
- scripts/download_checkpoints.sh +7 -0
- scripts/download_data.sh +11 -0
- stable_diffusion/ldm/models/diffusion/ddpm_edit.py +1459 -0
- stable_diffusion/ldm/modules/attention.py +16 -2
- stable_diffusion/main.py +1 -1
LICENSE
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Copyright 2023 Timothy Brooks, Aleksander Holynski, Alexei A. Efros
|
2 |
+
|
3 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
4 |
+
|
5 |
+
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
6 |
+
|
7 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
8 |
+
|
9 |
+
Portions of code and models (such as pretrained checkpoints, which are fine-tuned starting from released Stable Diffusion checkpoints) are derived from the Stable Diffusion codebase (https://github.com/CompVis/stable-diffusion). Further restrictions may apply. Please consult the Stable Diffusion license `stable_diffusion/LICENSE`. Modified code is denoted as such in comments at the start of each file.
|
README.md
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# InstructPix2Pix: Learning to Follow Image Editing Instructions
|
2 |
+
### [Project Page](https://www.timothybrooks.com/instruct-pix2pix/) | [Paper](https://arxiv.org/abs/2211.09800) | [Data](http://instruct-pix2pix.eecs.berkeley.edu/)
|
3 |
+
PyTorch implementation of InstructPix2Pix, an instruction-based image editing model, based on the original [CompVis/stable_diffusion](https://github.com/CompVis/stable-diffusion) repo. <br>
|
4 |
+
|
5 |
+
[InstructPix2Pix: Learning to Follow Image Editing Instructions](https://www.timothybrooks.com/instruct-pix2pix/)
|
6 |
+
[Tim Brooks](https://www.timothybrooks.com/)\*,
|
7 |
+
[Aleksander Holynski](https://holynski.org/)\*,
|
8 |
+
[Alexei A. Efros](https://people.eecs.berkeley.edu/~efros/) <br>
|
9 |
+
UC Berkeley <br>
|
10 |
+
\*denotes equal contribution
|
11 |
+
|
12 |
+
<img src='https://instruct-pix2pix.timothybrooks.com/teaser.jpg'/>
|
13 |
+
|
14 |
+
## TL;DR: quickstart
|
15 |
+
|
16 |
+
To setup a conda environment, download a pretrained model, and edit an image:
|
17 |
+
```
|
18 |
+
conda env create -f environment.yaml
|
19 |
+
conda activate ip2p
|
20 |
+
bash scripts/download_checkpoints.sh
|
21 |
+
python edit_cli.py --input imgs/example.jpg --output imgs/output.jpg --edit "turn him into a cyborg"
|
22 |
+
|
23 |
+
# Optionally, you can specify parameters:
|
24 |
+
# python edit_cli.py --steps 100 --resolution 512 --seed 0 --cfg-text 7.5 --cfg-image 1.2 --input imgs/example.jpg --output imgs/output.jpg --edit "turn him into a cyborg"
|
25 |
+
```
|
26 |
+
|
27 |
+
## Setup
|
28 |
+
|
29 |
+
Install all dependencies with:
|
30 |
+
```
|
31 |
+
conda env create -f environment.yaml
|
32 |
+
```
|
33 |
+
|
34 |
+
Download the pretrained models by running:
|
35 |
+
```
|
36 |
+
bash scripts/download_checkpoints.sh
|
37 |
+
```
|
38 |
+
|
39 |
+
## Generated Dataset
|
40 |
+
|
41 |
+
Our image editing model is trained on a generated dataset consisting of 454,445 examples. Each example contains (1) an input image, (2) an editing instruction, and (3) an output edited image. We provide two versions of the dataset, one in which each pair of edited images is generated 100 times, and the best examples are chosen based on CLIP metrics (Section 3.1.2 in the paper) (`clip-filtered-dataset`), and one in which examples are randomly chosen (`random-sample-dataset`).
|
42 |
+
|
43 |
+
For the released version of this dataset, we've additionally filtered prompts and images for NSFW content. After NSFW filtering, the GPT-3 generated dataset contains 451,990 examples. The final image-pair datasets contain:
|
44 |
+
|
45 |
+
| | # of image editing examples | Dataset size |
|
46 |
+
|--|-----------------------|----------------------- |
|
47 |
+
| `random-sample-dataset` |451990|727GB|
|
48 |
+
| `clip-filtered-dataset` |313010|436GB|
|
49 |
+
|
50 |
+
To download one of these datasets, along with the entire NSFW-filtered text data, run the following command with the appropriate dataset name:
|
51 |
+
|
52 |
+
```
|
53 |
+
bash scripts/download_data.sh clip-filtered-dataset
|
54 |
+
```
|
55 |
+
|
56 |
+
|
57 |
+
## Training InstructPix2Pix
|
58 |
+
|
59 |
+
Need to modify configs/instruct-pix2pix/default.yaml to point to the dataset in the right location. Need to also download the Stable Diffusion checkpoint from which to finetune.
|
60 |
+
|
61 |
+
```
|
62 |
+
python stable_diffusion/main.py --name default --base configs/train.yaml --train --gpus 0,1,2,3,4,5,6,7
|
63 |
+
```
|
64 |
+
|
65 |
+
|
66 |
+
## Creating your own dataset
|
67 |
+
|
68 |
+
Our generated dataset of paired images and editing instructions is made in two phases: First, we use GPT-3 to generate text triplets: (a) a caption describing an image, (b) an edit instruction, (c) a caption describing the image after the edit. Then, we turn pairs of captions (before/after the edit) into pairs of images using Stable Diffusion and Prompt-to-Prompt.
|
69 |
+
|
70 |
+
### (1) Generate a dataset of captions and instructions
|
71 |
+
|
72 |
+
We provide our generated dataset of captions and edit instructions [here](https://instruct-pix2pix.eecs.berkeley.edu/gpt-generated-prompts.jsonl). If you plan to use our captions+instructions, skip to step (2). Otherwise, if you would like to create your own text dataset, please follow steps (1.1-1.3) below. Note that generating very large datasets using GPT-3 can be expensive.
|
73 |
+
|
74 |
+
#### (1.1) Manually write a dataset of instructions and captions
|
75 |
+
|
76 |
+
The first step of the process is fine-tuning GPT-3. To do this, we made a dataset of 700 examples broadly covering of edits that we might want our model to be able to perform. Our examples are available here [here](https://instruct-pix2pix.eecs.berkeley.edu/human_written_examples.jsonl). These should be diverse and cover a wide range of possible captions and types of edits. Ideally, they should avoid duplication or significant overlap of captions and instructions. It is also important to be mindful of limitations of Stable Diffusion and Prompt-to-Prompt in writing these examples, such as inability to perform large spatial transformations (e.g., moving the camera, zooming in, swapping object locations).
|
77 |
+
|
78 |
+
Input prompts should closely match the distribution of input prompts used to generate the larger dataset. We sampled the 700 input prompts from LAION Improves Aesthetics 6.5+ dataset and also use this dataset for generating examples. We found this dataset is quite noisy (many of the captions are overly long and contain irrelevant text). For this reason, we also considered MSCOCO and LAION-COCO datasets, but ultimately chose LAION Improves Aesthetics 6.5+ due to its diversity of content, proper nouns, and artistic mediums. If you choose to use another dataset or combination of datasets as input to GPT-3 when generating examples, we recomend you sample the input prompts from the same distribution when manually writing training examples.
|
79 |
+
|
80 |
+
#### (1.2) Finetune GPT-3
|
81 |
+
|
82 |
+
The next step is to finetune a large language model to generate an edit instruction and edited caption from a new input caption. We use GPT-3 Davinci via the OpenAI API, although other language models could be used.
|
83 |
+
|
84 |
+
To prepare training data for GPT-3, one must setup an OpenAI developer account to access the needed APIs. Run the `prompts/prepare_for_gpt.py` script, which forms the prompts into the correct format by concatenating instructions and captions and adding delimiters and stop sequences.
|
85 |
+
|
86 |
+
```bash
|
87 |
+
python dataset_creation/prepare_for_gpt.py prompts/human_written_examples.jsonl prompts/human_written_examples_for_gpt.jsonl
|
88 |
+
```
|
89 |
+
|
90 |
+
Next, finetune GPT-3 via the OpenAI CLI. We provide an example below, although please refer to the official documentation here as best practices may change. We trained the Davinci model for a single epoch. You could experiment with smaller less expensive GPT-3 variants or with open source language models, although this may negatively hurt performance.
|
91 |
+
|
92 |
+
```bash
|
93 |
+
openai api fine_tunes.create -t prompts/human_written_examples_for_gpt.jsonl -m davinci --n_epochs 1 --suffix "instruct-pix2pix"
|
94 |
+
```
|
95 |
+
|
96 |
+
You can test out the finetuned GPT-3 model by launching the provided Gradio app:
|
97 |
+
|
98 |
+
```bash
|
99 |
+
python prompt_app.py OPENAI_MODEL_NAME
|
100 |
+
```
|
101 |
+
|
102 |
+
#### (1.3) Generate a large dataset of captions and instructions
|
103 |
+
|
104 |
+
We now use the finetuned GPT-3 model to generate a large dataset. Our dataset cost thousands of dollars to create. See `prompts/gen_instructions_and_captions.py` for the script which generates these examples. We recommend first generating a small number of examples and gradually increasing the scale to ensure the results are working as desired before increasing scale.
|
105 |
+
|
106 |
+
```bash
|
107 |
+
python dataset_creation/generate_txt_dataset.py OPENAI_MODEL_NAME
|
108 |
+
```
|
109 |
+
|
110 |
+
If you are generating at a very large scale (e.g., 100K+), it will be noteably faster to generate the dataset with multiple processes running in parallel. This can be accomplished by setting `--partitions=N` to a higher number and running multiple processes, setting each `--partition` to the corresponding value.
|
111 |
+
|
112 |
+
```bash
|
113 |
+
python dataset_creation/generate_txt_dataset.py OPENAI_MODEL_NAME --partitions=10 --partition=0
|
114 |
+
```
|
115 |
+
|
116 |
+
### (2) Turn paired captions into paired images
|
117 |
+
|
118 |
+
The next step is to turn pairs of text captions into pairs of images. For this, we need to copy a pre-trained Stable Diffusion model checkpoint to `stable_diffusion/models/ldm/stable-diffusion-v1/`. For our model, we used [checkpoint v1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.ckpt), but other versions may also work. It is also necessary to download a checkpoint for the Stable Diffusion autoencoder. We used the [new autoencoder](https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.ckpt), which should be put in the same directory. Once all checkpoints have been downloaded, we can generate the dataset with the following command:
|
119 |
+
|
120 |
+
```
|
121 |
+
python dataset_creation/generate_img_dataset.py data/instruct-pix2pix-dataset-000 data/gpt_generated_prompts.jsonl
|
122 |
+
```
|
123 |
+
|
124 |
+
This command operates on a single GPU (typically a V100 or A100). To parallelize over many GPUs/machines, set `--n-partitions` to the total number of parallel jobs and `--partition` to the index of each job.
|
125 |
+
|
126 |
+
```
|
127 |
+
python dataset_creation/generate_img_dataset.py data/instruct-pix2pix-dataset-000 data/gpt_generated_prompts.jsonl --n-partitions 100 --partition 0
|
128 |
+
```
|
129 |
+
|
130 |
+
The default parameters match that of our dataset, although in practice you can use a smaller number of steps (e.g., `--steps=25`) to generate high quality data faster. By default, we generate 100 samples per prompt and use CLIP filtering to keep a max of 4 per prompt. You can experiment with fewer samples by setting `--n-samples`. The command below turns off CLIP filtering entirely and is therefore faster:
|
131 |
+
|
132 |
+
```
|
133 |
+
python dataset_creation/generate_img_dataset.py data/instruct-pix2pix-dataset-000 data/gpt_generated_prompts.jsonl --n-samples 4 --clip-threshold 0 --clip-dir-threshold 0 --clip-img-threshold 0 --n-partitions 100 --partition 0
|
134 |
+
```
|
135 |
+
|
136 |
+
After generating all of the dataset examples, run the following command below to create a list of the examples. This is needed for the dataset onject to efficiently be able to sample examples without needing to iterate over the entire dataset directory at the start of each training run.
|
137 |
+
|
138 |
+
```
|
139 |
+
python dataset_creation/prepare_dataset.py data/instruct-pix2pix-dataset-000
|
140 |
+
```
|
141 |
+
|
142 |
+
## Comments
|
143 |
+
|
144 |
+
- Our codebase is based on the [Stable Diffusion codebase](https://github.com/CompVis/stable-diffusion).
|
145 |
+
|
146 |
+
## BibTeX
|
147 |
+
|
148 |
+
```
|
149 |
+
@article{brooks2022instructpix2pix,
|
150 |
+
title={InstructPix2Pix: Learning to Follow Image Editing Instructions},
|
151 |
+
author={Brooks, Tim and Holynski, Aleksander and Efros, Alexei A},
|
152 |
+
journal={arXiv preprint arXiv:2211.09800},
|
153 |
+
year={2022}
|
154 |
+
}
|
155 |
+
```
|
156 |
+
|
157 |
+
|
158 |
+
|
configs/generate.yaml
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
|
2 |
+
# See more details in LICENSE.
|
3 |
+
|
4 |
+
model:
|
5 |
+
base_learning_rate: 1.0e-04
|
6 |
+
target: stable_diffusion.ldm.models.diffusion.ddpm_edit.LatentDiffusion
|
7 |
+
params:
|
8 |
+
linear_start: 0.00085
|
9 |
+
linear_end: 0.0120
|
10 |
+
num_timesteps_cond: 1
|
11 |
+
log_every_t: 200
|
12 |
+
timesteps: 1000
|
13 |
+
first_stage_key: edited
|
14 |
+
cond_stage_key: edit
|
15 |
+
# image_size: 64
|
16 |
+
# image_size: 32
|
17 |
+
image_size: 16
|
18 |
+
channels: 4
|
19 |
+
cond_stage_trainable: false # Note: different from the one we trained before
|
20 |
+
conditioning_key: hybrid
|
21 |
+
monitor: val/loss_simple_ema
|
22 |
+
scale_factor: 0.18215
|
23 |
+
use_ema: true
|
24 |
+
load_ema: true
|
25 |
+
|
26 |
+
scheduler_config: # 10000 warmup steps
|
27 |
+
target: stable_diffusion.ldm.lr_scheduler.LambdaLinearScheduler
|
28 |
+
params:
|
29 |
+
warm_up_steps: [ 0 ]
|
30 |
+
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
31 |
+
f_start: [ 1.e-6 ]
|
32 |
+
f_max: [ 1. ]
|
33 |
+
f_min: [ 1. ]
|
34 |
+
|
35 |
+
unet_config:
|
36 |
+
target: stable_diffusion.ldm.modules.diffusionmodules.openaimodel.UNetModel
|
37 |
+
params:
|
38 |
+
image_size: 32 # unused
|
39 |
+
in_channels: 8
|
40 |
+
out_channels: 4
|
41 |
+
model_channels: 320
|
42 |
+
attention_resolutions: [ 4, 2, 1 ]
|
43 |
+
num_res_blocks: 2
|
44 |
+
channel_mult: [ 1, 2, 4, 4 ]
|
45 |
+
num_heads: 8
|
46 |
+
use_spatial_transformer: True
|
47 |
+
transformer_depth: 1
|
48 |
+
context_dim: 768
|
49 |
+
use_checkpoint: True
|
50 |
+
legacy: False
|
51 |
+
|
52 |
+
first_stage_config:
|
53 |
+
target: stable_diffusion.ldm.models.autoencoder.AutoencoderKL
|
54 |
+
params:
|
55 |
+
embed_dim: 4
|
56 |
+
monitor: val/rec_loss
|
57 |
+
ddconfig:
|
58 |
+
double_z: true
|
59 |
+
z_channels: 4
|
60 |
+
resolution: 256
|
61 |
+
in_channels: 3
|
62 |
+
out_ch: 3
|
63 |
+
ch: 128
|
64 |
+
ch_mult:
|
65 |
+
- 1
|
66 |
+
- 2
|
67 |
+
- 4
|
68 |
+
- 4
|
69 |
+
num_res_blocks: 2
|
70 |
+
attn_resolutions: []
|
71 |
+
dropout: 0.0
|
72 |
+
lossconfig:
|
73 |
+
target: torch.nn.Identity
|
74 |
+
|
75 |
+
cond_stage_config:
|
76 |
+
target: stable_diffusion.ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
77 |
+
|
78 |
+
data:
|
79 |
+
target: main.DataModuleFromConfig
|
80 |
+
params:
|
81 |
+
batch_size: 128
|
82 |
+
num_workers: 1
|
83 |
+
wrap: false
|
84 |
+
validation:
|
85 |
+
target: edit_dataset.EditDataset
|
86 |
+
params:
|
87 |
+
path: /shared/holynski/laion-aesthetics-6.5_edit-model=davinci-laion700-1epoch_samples=10000/laion-aesthetics-6.5_edit-model=davinci-laion700-1epoch_samples=10000
|
88 |
+
cache_dir: /shared/timbrooks/image-edit-data/caches
|
89 |
+
cache_name: davinci10k
|
90 |
+
split: val
|
91 |
+
min_text_sim: 0.2
|
92 |
+
min_image_sim: 0.75
|
93 |
+
min_direction_sim: 0.2
|
94 |
+
max_samples_per_prompt: 1
|
95 |
+
min_resize_res: 512
|
96 |
+
max_resize_res: 512
|
97 |
+
crop_res: 512
|
98 |
+
output_as_edit: False
|
99 |
+
real_input: True
|
configs/train.yaml
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
|
2 |
+
# See more details in LICENSE.
|
3 |
+
|
4 |
+
model:
|
5 |
+
base_learning_rate: 1.0e-04
|
6 |
+
target: stable_diffusion.ldm.models.diffusion.ddpm_edit.LatentDiffusion
|
7 |
+
params:
|
8 |
+
ckpt_path: stable_diffusion/models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
|
9 |
+
linear_start: 0.00085
|
10 |
+
linear_end: 0.0120
|
11 |
+
num_timesteps_cond: 1
|
12 |
+
log_every_t: 200
|
13 |
+
timesteps: 1000
|
14 |
+
first_stage_key: edited
|
15 |
+
cond_stage_key: edit
|
16 |
+
image_size: 32
|
17 |
+
channels: 4
|
18 |
+
cond_stage_trainable: false # Note: different from the one we trained before
|
19 |
+
conditioning_key: hybrid
|
20 |
+
monitor: val/loss_simple_ema
|
21 |
+
scale_factor: 0.18215
|
22 |
+
use_ema: true
|
23 |
+
load_ema: false
|
24 |
+
|
25 |
+
scheduler_config: # 10000 warmup steps
|
26 |
+
target: stable_diffusion.ldm.lr_scheduler.LambdaLinearScheduler
|
27 |
+
params:
|
28 |
+
warm_up_steps: [ 0 ]
|
29 |
+
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
30 |
+
f_start: [ 1.e-6 ]
|
31 |
+
f_max: [ 1. ]
|
32 |
+
f_min: [ 1. ]
|
33 |
+
|
34 |
+
unet_config:
|
35 |
+
target: stable_diffusion.ldm.modules.diffusionmodules.openaimodel.UNetModel
|
36 |
+
params:
|
37 |
+
image_size: 32 # unused
|
38 |
+
in_channels: 8
|
39 |
+
out_channels: 4
|
40 |
+
model_channels: 320
|
41 |
+
attention_resolutions: [ 4, 2, 1 ]
|
42 |
+
num_res_blocks: 2
|
43 |
+
channel_mult: [ 1, 2, 4, 4 ]
|
44 |
+
num_heads: 8
|
45 |
+
use_spatial_transformer: True
|
46 |
+
transformer_depth: 1
|
47 |
+
context_dim: 768
|
48 |
+
use_checkpoint: True
|
49 |
+
legacy: False
|
50 |
+
|
51 |
+
first_stage_config:
|
52 |
+
target: stable_diffusion.ldm.models.autoencoder.AutoencoderKL
|
53 |
+
params:
|
54 |
+
embed_dim: 4
|
55 |
+
monitor: val/rec_loss
|
56 |
+
ddconfig:
|
57 |
+
double_z: true
|
58 |
+
z_channels: 4
|
59 |
+
resolution: 256
|
60 |
+
in_channels: 3
|
61 |
+
out_ch: 3
|
62 |
+
ch: 128
|
63 |
+
ch_mult:
|
64 |
+
- 1
|
65 |
+
- 2
|
66 |
+
- 4
|
67 |
+
- 4
|
68 |
+
num_res_blocks: 2
|
69 |
+
attn_resolutions: []
|
70 |
+
dropout: 0.0
|
71 |
+
lossconfig:
|
72 |
+
target: torch.nn.Identity
|
73 |
+
|
74 |
+
cond_stage_config:
|
75 |
+
target: stable_diffusion.ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
76 |
+
|
77 |
+
data:
|
78 |
+
target: main.DataModuleFromConfig
|
79 |
+
params:
|
80 |
+
batch_size: 32
|
81 |
+
num_workers: 2
|
82 |
+
train:
|
83 |
+
target: edit_dataset.EditDataset
|
84 |
+
params:
|
85 |
+
path: /home/timbrooks/instruct-pix2pix-datasets/20-20-75
|
86 |
+
split: train
|
87 |
+
min_resize_res: 256
|
88 |
+
max_resize_res: 256
|
89 |
+
crop_res: 256
|
90 |
+
flip_prob: 0.5
|
91 |
+
validation:
|
92 |
+
target: edit_dataset.EditDataset
|
93 |
+
params:
|
94 |
+
path: /home/timbrooks/instruct-pix2pix-datasets/20-20-75
|
95 |
+
split: val
|
96 |
+
min_resize_res: 256
|
97 |
+
max_resize_res: 256
|
98 |
+
crop_res: 256
|
99 |
+
|
100 |
+
lightning:
|
101 |
+
callbacks:
|
102 |
+
image_logger:
|
103 |
+
target: main.ImageLogger
|
104 |
+
params:
|
105 |
+
batch_frequency: 2000
|
106 |
+
max_images: 2
|
107 |
+
increase_log_steps: False
|
108 |
+
|
109 |
+
trainer:
|
110 |
+
max_epochs: 2000
|
111 |
+
benchmark: True
|
112 |
+
accumulate_grad_batches: 4
|
113 |
+
check_val_every_n_epoch: 4
|
dataset_creation/generate_img_dataset.py
ADDED
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import k_diffusion
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
from einops import rearrange, repeat
|
10 |
+
from omegaconf import OmegaConf
|
11 |
+
from PIL import Image
|
12 |
+
from pytorch_lightning import seed_everything
|
13 |
+
from tqdm import tqdm
|
14 |
+
|
15 |
+
from stable_diffusion.ldm.modules.attention import CrossAttention
|
16 |
+
from stable_diffusion.ldm.util import instantiate_from_config
|
17 |
+
from metrics.clip_similarity import ClipSimilarity
|
18 |
+
|
19 |
+
|
20 |
+
################################################################################
|
21 |
+
# Modified K-diffusion Euler ancestral sampler with prompt-to-prompt.
|
22 |
+
# https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
|
23 |
+
|
24 |
+
|
25 |
+
def append_dims(x, target_dims):
|
26 |
+
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
|
27 |
+
dims_to_append = target_dims - x.ndim
|
28 |
+
if dims_to_append < 0:
|
29 |
+
raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
|
30 |
+
return x[(...,) + (None,) * dims_to_append]
|
31 |
+
|
32 |
+
|
33 |
+
def to_d(x, sigma, denoised):
|
34 |
+
"""Converts a denoiser output to a Karras ODE derivative."""
|
35 |
+
return (x - denoised) / append_dims(sigma, x.ndim)
|
36 |
+
|
37 |
+
|
38 |
+
def get_ancestral_step(sigma_from, sigma_to):
|
39 |
+
"""Calculates the noise level (sigma_down) to step down to and the amount
|
40 |
+
of noise to add (sigma_up) when doing an ancestral sampling step."""
|
41 |
+
sigma_up = min(sigma_to, (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5)
|
42 |
+
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
|
43 |
+
return sigma_down, sigma_up
|
44 |
+
|
45 |
+
|
46 |
+
def sample_euler_ancestral(model, x, sigmas, prompt2prompt_threshold=0.0, **extra_args):
|
47 |
+
"""Ancestral sampling with Euler method steps."""
|
48 |
+
s_in = x.new_ones([x.shape[0]])
|
49 |
+
for i in range(len(sigmas) - 1):
|
50 |
+
prompt_to_prompt = prompt2prompt_threshold > i / (len(sigmas) - 2)
|
51 |
+
for m in model.modules():
|
52 |
+
if isinstance(m, CrossAttention):
|
53 |
+
m.prompt_to_prompt = prompt_to_prompt
|
54 |
+
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
55 |
+
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
|
56 |
+
d = to_d(x, sigmas[i], denoised)
|
57 |
+
# Euler method
|
58 |
+
dt = sigma_down - sigmas[i]
|
59 |
+
x = x + d * dt
|
60 |
+
if sigmas[i + 1] > 0:
|
61 |
+
# Make noise the same across all samples in batch.
|
62 |
+
x = x + torch.randn_like(x[:1]) * sigma_up
|
63 |
+
return x
|
64 |
+
|
65 |
+
|
66 |
+
################################################################################
|
67 |
+
|
68 |
+
|
69 |
+
def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False):
|
70 |
+
print(f"Loading model from {ckpt}")
|
71 |
+
pl_sd = torch.load(ckpt, map_location="cpu")
|
72 |
+
if "global_step" in pl_sd:
|
73 |
+
print(f"Global Step: {pl_sd['global_step']}")
|
74 |
+
sd = pl_sd["state_dict"]
|
75 |
+
if vae_ckpt is not None:
|
76 |
+
print(f"Loading VAE from {vae_ckpt}")
|
77 |
+
vae_sd = torch.load(vae_ckpt, map_location="cpu")["state_dict"]
|
78 |
+
sd = {
|
79 |
+
k: vae_sd[k[len("first_stage_model.") :]] if k.startswith("first_stage_model.") else v
|
80 |
+
for k, v in sd.items()
|
81 |
+
}
|
82 |
+
model = instantiate_from_config(config.model)
|
83 |
+
m, u = model.load_state_dict(sd, strict=False)
|
84 |
+
if len(m) > 0 and verbose:
|
85 |
+
print("missing keys:")
|
86 |
+
print(m)
|
87 |
+
if len(u) > 0 and verbose:
|
88 |
+
print("unexpected keys:")
|
89 |
+
print(u)
|
90 |
+
return model
|
91 |
+
|
92 |
+
|
93 |
+
class CFGDenoiser(nn.Module):
|
94 |
+
def __init__(self, model):
|
95 |
+
super().__init__()
|
96 |
+
self.inner_model = model
|
97 |
+
|
98 |
+
def forward(self, x, sigma, uncond, cond, cfg_scale):
|
99 |
+
x_in = torch.cat([x] * 2)
|
100 |
+
sigma_in = torch.cat([sigma] * 2)
|
101 |
+
cond_in = torch.cat([uncond, cond])
|
102 |
+
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
|
103 |
+
return uncond + (cond - uncond) * cfg_scale
|
104 |
+
|
105 |
+
|
106 |
+
def to_pil(image: torch.Tensor) -> Image.Image:
|
107 |
+
image = 255.0 * rearrange(image.cpu().numpy(), "c h w -> h w c")
|
108 |
+
image = Image.fromarray(image.astype(np.uint8))
|
109 |
+
return image
|
110 |
+
|
111 |
+
|
112 |
+
def main():
|
113 |
+
parser = argparse.ArgumentParser()
|
114 |
+
parser.add_argument(
|
115 |
+
"out_dir",
|
116 |
+
type=str,
|
117 |
+
help="Path to output dataset directory.",
|
118 |
+
)
|
119 |
+
parser.add_argument(
|
120 |
+
"prompts_file",
|
121 |
+
type=str,
|
122 |
+
help="Path to prompts .jsonl file.",
|
123 |
+
)
|
124 |
+
parser.add_argument(
|
125 |
+
"--steps",
|
126 |
+
type=int,
|
127 |
+
default=100,
|
128 |
+
help="Number of sampling steps.",
|
129 |
+
)
|
130 |
+
parser.add_argument(
|
131 |
+
"--n-samples",
|
132 |
+
type=int,
|
133 |
+
default=100,
|
134 |
+
help="Number of samples to generate per prompt (before CLIP filtering).",
|
135 |
+
)
|
136 |
+
parser.add_argument(
|
137 |
+
"--max-out-samples",
|
138 |
+
type=int,
|
139 |
+
default=4,
|
140 |
+
help="Max number of output samples to save per prompt (after CLIP filtering).",
|
141 |
+
)
|
142 |
+
parser.add_argument(
|
143 |
+
"--n-partitions",
|
144 |
+
type=int,
|
145 |
+
default=1,
|
146 |
+
help="Number of total partitions.",
|
147 |
+
)
|
148 |
+
parser.add_argument(
|
149 |
+
"--partition",
|
150 |
+
type=int,
|
151 |
+
default=0,
|
152 |
+
help="Partition index.",
|
153 |
+
)
|
154 |
+
parser.add_argument(
|
155 |
+
"--min-p2p",
|
156 |
+
type=float,
|
157 |
+
default=0.1,
|
158 |
+
help="Min prompt2prompt threshold (portion of denoising for which to fix self attention maps).",
|
159 |
+
)
|
160 |
+
parser.add_argument(
|
161 |
+
"--max-p2p",
|
162 |
+
type=float,
|
163 |
+
default=0.9,
|
164 |
+
help="Max prompt2prompt threshold (portion of denoising for which to fix self attention maps).",
|
165 |
+
)
|
166 |
+
parser.add_argument(
|
167 |
+
"--min-cfg",
|
168 |
+
type=float,
|
169 |
+
default=7.5,
|
170 |
+
help="Min classifier free guidance scale.",
|
171 |
+
)
|
172 |
+
parser.add_argument(
|
173 |
+
"--max-cfg",
|
174 |
+
type=float,
|
175 |
+
default=15,
|
176 |
+
help="Max classifier free guidance scale.",
|
177 |
+
)
|
178 |
+
parser.add_argument(
|
179 |
+
"--clip-threshold",
|
180 |
+
type=float,
|
181 |
+
default=0.2,
|
182 |
+
help="CLIP threshold for text-image similarity of each image.",
|
183 |
+
)
|
184 |
+
parser.add_argument(
|
185 |
+
"--clip-dir-threshold",
|
186 |
+
type=float,
|
187 |
+
default=0.2,
|
188 |
+
help="Directional CLIP threshold for similarity of change between pairs of text and pairs of images.",
|
189 |
+
)
|
190 |
+
parser.add_argument(
|
191 |
+
"--clip-img-threshold",
|
192 |
+
type=float,
|
193 |
+
default=0.7,
|
194 |
+
help="CLIP threshold for image-image similarity.",
|
195 |
+
)
|
196 |
+
opt = parser.parse_args()
|
197 |
+
|
198 |
+
global_seed = torch.randint(1 << 32, ()).item()
|
199 |
+
print(f"Global seed: {global_seed}")
|
200 |
+
seed_everything(global_seed)
|
201 |
+
|
202 |
+
model = load_model_from_config(
|
203 |
+
OmegaConf.load("configs/stable-diffusion/v1-inference.yaml"),
|
204 |
+
ckpt="models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt",
|
205 |
+
vae_ckpt="models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt",
|
206 |
+
)
|
207 |
+
model.cuda().eval()
|
208 |
+
model_wrap = k_diffusion.external.CompVisDenoiser(model)
|
209 |
+
|
210 |
+
clip_similarity = ClipSimilarity().cuda()
|
211 |
+
|
212 |
+
out_dir = Path(opt.out_dir)
|
213 |
+
out_dir.mkdir(exist_ok=True, parents=True)
|
214 |
+
|
215 |
+
with open(opt.prompts_file) as fp:
|
216 |
+
prompts = [json.loads(line) for line in fp]
|
217 |
+
|
218 |
+
print(f"Partition index {opt.partition} ({opt.partition + 1} / {opt.n_partitions})")
|
219 |
+
prompts = np.array_split(list(enumerate(prompts)), opt.n_partitions)[opt.partition]
|
220 |
+
|
221 |
+
with torch.no_grad(), torch.autocast("cuda"), model.ema_scope():
|
222 |
+
uncond = model.get_learned_conditioning(2 * [""])
|
223 |
+
sigmas = model_wrap.get_sigmas(opt.steps)
|
224 |
+
|
225 |
+
for i, prompt in tqdm(prompts, desc="Prompts"):
|
226 |
+
prompt_dir = out_dir.joinpath(f"{i:07d}")
|
227 |
+
prompt_dir.mkdir(exist_ok=True)
|
228 |
+
|
229 |
+
with open(prompt_dir.joinpath("prompt.json"), "w") as fp:
|
230 |
+
json.dump(prompt, fp)
|
231 |
+
|
232 |
+
cond = model.get_learned_conditioning([prompt["input"], prompt["output"]])
|
233 |
+
results = {}
|
234 |
+
|
235 |
+
with tqdm(total=opt.n_samples, desc="Samples") as progress_bar:
|
236 |
+
|
237 |
+
while len(results) < opt.n_samples:
|
238 |
+
seed = torch.randint(1 << 32, ()).item()
|
239 |
+
if seed in results:
|
240 |
+
continue
|
241 |
+
torch.manual_seed(seed)
|
242 |
+
|
243 |
+
x = torch.randn(1, 4, 512 // 8, 512 // 8, device="cuda") * sigmas[0]
|
244 |
+
x = repeat(x, "1 ... -> n ...", n=2)
|
245 |
+
|
246 |
+
model_wrap_cfg = CFGDenoiser(model_wrap)
|
247 |
+
p2p_threshold = opt.min_p2p + torch.rand(()).item() * (opt.max_p2p - opt.min_p2p)
|
248 |
+
cfg_scale = opt.min_cfg + torch.rand(()).item() * (opt.max_cfg - opt.min_cfg)
|
249 |
+
extra_args = {"cond": cond, "uncond": uncond, "cfg_scale": cfg_scale}
|
250 |
+
samples_ddim = sample_euler_ancestral(model_wrap_cfg, x, sigmas, p2p_threshold, **extra_args)
|
251 |
+
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
252 |
+
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
253 |
+
|
254 |
+
x0 = x_samples_ddim[0]
|
255 |
+
x1 = x_samples_ddim[1]
|
256 |
+
|
257 |
+
clip_sim_0, clip_sim_1, clip_sim_dir, clip_sim_image = clip_similarity(
|
258 |
+
x0[None], x1[None], [prompt["input"]], [prompt["output"]]
|
259 |
+
)
|
260 |
+
|
261 |
+
results[seed] = dict(
|
262 |
+
image_0=to_pil(x0),
|
263 |
+
image_1=to_pil(x1),
|
264 |
+
p2p_threshold=p2p_threshold,
|
265 |
+
cfg_scale=cfg_scale,
|
266 |
+
clip_sim_0=clip_sim_0[0].item(),
|
267 |
+
clip_sim_1=clip_sim_1[0].item(),
|
268 |
+
clip_sim_dir=clip_sim_dir[0].item(),
|
269 |
+
clip_sim_image=clip_sim_image[0].item(),
|
270 |
+
)
|
271 |
+
|
272 |
+
progress_bar.update()
|
273 |
+
|
274 |
+
# CLIP filter to get best samples for each prompt.
|
275 |
+
metadata = [
|
276 |
+
(result["clip_sim_dir"], seed)
|
277 |
+
for seed, result in results.items()
|
278 |
+
if result["clip_sim_image"] >= opt.clip_img_threshold
|
279 |
+
and result["clip_sim_dir"] >= opt.clip_dir_threshold
|
280 |
+
and result["clip_sim_0"] >= opt.clip_threshold
|
281 |
+
and result["clip_sim_1"] >= opt.clip_threshold
|
282 |
+
]
|
283 |
+
metadata.sort(reverse=True)
|
284 |
+
for _, seed in metadata[: opt.max_out_samples]:
|
285 |
+
result = results[seed]
|
286 |
+
image_0 = result.pop("image_0")
|
287 |
+
image_1 = result.pop("image_1")
|
288 |
+
image_0.save(prompt_dir.joinpath(f"{seed}_0.jpg"), quality=100)
|
289 |
+
image_1.save(prompt_dir.joinpath(f"{seed}_1.jpg"), quality=100)
|
290 |
+
with open(prompt_dir.joinpath(f"metadata.jsonl"), "a") as fp:
|
291 |
+
fp.write(f"{json.dumps(dict(seed=seed, **result))}\n")
|
292 |
+
|
293 |
+
print("Done.")
|
294 |
+
|
295 |
+
|
296 |
+
if __name__ == "__main__":
|
297 |
+
main()
|
dataset_creation/generate_txt_dataset.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import json
|
4 |
+
import time
|
5 |
+
from argparse import ArgumentParser
|
6 |
+
from pathlib import Path
|
7 |
+
from typing import Optional
|
8 |
+
|
9 |
+
import datasets
|
10 |
+
import numpy as np
|
11 |
+
import openai
|
12 |
+
from tqdm.auto import tqdm
|
13 |
+
|
14 |
+
|
15 |
+
DELIMITER_0 = "\n##\n"
|
16 |
+
DELIMITER_1 = "\n%%\n"
|
17 |
+
STOP = "\nEND"
|
18 |
+
|
19 |
+
|
20 |
+
def generate(
|
21 |
+
openai_model: str,
|
22 |
+
caption: str,
|
23 |
+
num_retries: int = 3,
|
24 |
+
max_tokens: int = 256,
|
25 |
+
temperature: float = 0.7,
|
26 |
+
top_p: float = 1.0,
|
27 |
+
frequency_penalty: float = 0.1,
|
28 |
+
presence_penalty: float = 0.0,
|
29 |
+
sleep_on_error: float = 1.0,
|
30 |
+
) -> Optional[tuple[str, str]]:
|
31 |
+
for _ in range(1 + num_retries):
|
32 |
+
try:
|
33 |
+
response = openai.Completion.create(
|
34 |
+
model=openai_model,
|
35 |
+
prompt=caption + DELIMITER_0,
|
36 |
+
temperature=temperature,
|
37 |
+
max_tokens=max_tokens,
|
38 |
+
top_p=top_p,
|
39 |
+
frequency_penalty=frequency_penalty,
|
40 |
+
presence_penalty=presence_penalty,
|
41 |
+
stop=[STOP],
|
42 |
+
)
|
43 |
+
except Exception as e:
|
44 |
+
print(e)
|
45 |
+
time.sleep(sleep_on_error)
|
46 |
+
continue
|
47 |
+
output = response["choices"][0]["text"].split(DELIMITER_1)
|
48 |
+
if len(output) == 2:
|
49 |
+
instruction, edited_caption = output
|
50 |
+
results = openai.Moderation.create([instruction, edited_caption])["results"]
|
51 |
+
if results[0]["flagged"] or results[1]["flagged"]:
|
52 |
+
continue
|
53 |
+
if caption.strip().strip(".!?").lower() != edited_caption.strip().strip(".!?").lower():
|
54 |
+
return instruction, edited_caption
|
55 |
+
|
56 |
+
|
57 |
+
def main(openai_model: str, num_samples: int, num_partitions: int, partition: int, seed: int):
|
58 |
+
dataset = datasets.load_dataset("ChristophSchuhmann/improved_aesthetics_6.5plus", split="train")
|
59 |
+
# Other datasets we considered that may be worth trying:
|
60 |
+
# dataset = datasets.load_dataset("ChristophSchuhmann/MS_COCO_2017_URL_TEXT", split="train")
|
61 |
+
# dataset = datasets.load_dataset("laion/laion-coco", split="train")
|
62 |
+
|
63 |
+
np.random.seed(seed)
|
64 |
+
permutation = np.array_split(np.random.permutation(len(dataset)), num_partitions)[partition]
|
65 |
+
dataset = dataset[permutation]
|
66 |
+
captions = dataset["TEXT"]
|
67 |
+
urls = dataset["URL"]
|
68 |
+
output_path = f"prompts/dataset=laion-aesthetics-6.5_model={openai_model}_samples={num_samples}_partition={partition}.jsonl" # fmt: skip
|
69 |
+
print(f"Prompt file path: {output_path}")
|
70 |
+
|
71 |
+
count = 0
|
72 |
+
caption_set = set()
|
73 |
+
url_set = set()
|
74 |
+
|
75 |
+
if Path(output_path).exists():
|
76 |
+
with open(output_path, "r") as f:
|
77 |
+
for line in tqdm(f, desc="Resuming from existing prompts"):
|
78 |
+
prompt = json.loads(line)
|
79 |
+
if prompt["caption"] not in caption_set and prompt["url"] not in url_set:
|
80 |
+
caption_set.add(prompt["caption"])
|
81 |
+
url_set.add(prompt["url"])
|
82 |
+
count += 1
|
83 |
+
|
84 |
+
with open(output_path, "a") as fp:
|
85 |
+
with tqdm(total=num_samples - count, desc="Generating instructions and edited captions") as progress_bar:
|
86 |
+
for caption, url in zip(captions, urls):
|
87 |
+
if caption in caption_set or url in url_set:
|
88 |
+
continue
|
89 |
+
if openai.Moderation.create(caption)["results"][0]["flagged"]:
|
90 |
+
continue
|
91 |
+
edit_output = generate(caption)
|
92 |
+
if edit_output is not None:
|
93 |
+
edit, output = edit_output
|
94 |
+
fp.write(f"{json.dumps(dict(caption=caption, edit=edit, output=output, url=url))}\n")
|
95 |
+
count += 1
|
96 |
+
progress_bar.update()
|
97 |
+
caption_set.add(caption)
|
98 |
+
url_set.add(url)
|
99 |
+
if count == num_samples:
|
100 |
+
break
|
101 |
+
|
102 |
+
|
103 |
+
if __name__ == "__main__":
|
104 |
+
parser = ArgumentParser()
|
105 |
+
parser.add_argument("openai-api-key", type=str)
|
106 |
+
parser.add_argument("openai-model", type=str)
|
107 |
+
parser.add_argument("--num-samples", default=10000, type=int)
|
108 |
+
parser.add_argument("--num-partitions", default=1, type=int)
|
109 |
+
parser.add_argument("--partition", default=0, type=int)
|
110 |
+
parser.add_argument("--seed", default=0, type=int)
|
111 |
+
args = parser.parse_args()
|
112 |
+
openai.api_key = args.openai_api_key
|
113 |
+
main(args.openai_model, args.num_samples, args.num_partitions, args.partition, args.seed)
|
dataset_creation/prepare_dataset.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from argparse import ArgumentParser
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
from tqdm.auto import tqdm
|
6 |
+
|
7 |
+
|
8 |
+
def main():
|
9 |
+
parser = ArgumentParser()
|
10 |
+
parser.add_argument("dataset_dir")
|
11 |
+
args = parser.parse_args()
|
12 |
+
dataset_dir = Path(args.dataset_dir)
|
13 |
+
|
14 |
+
seeds = []
|
15 |
+
with tqdm(desc="Listing dataset image seeds") as progress_bar:
|
16 |
+
for prompt_dir in dataset_dir.iterdir():
|
17 |
+
if prompt_dir.is_dir():
|
18 |
+
prompt_seeds = [image_path.name.split("_")[0] for image_path in sorted(prompt_dir.glob("*_0.jpg"))]
|
19 |
+
if len(prompt_seeds) > 0:
|
20 |
+
seeds.append((prompt_dir.name, prompt_seeds))
|
21 |
+
progress_bar.update()
|
22 |
+
seeds.sort()
|
23 |
+
|
24 |
+
with open(dataset_dir.joinpath("seeds.json"), "w") as f:
|
25 |
+
json.dump(seeds, f)
|
26 |
+
|
27 |
+
|
28 |
+
if __name__ == "__main__":
|
29 |
+
main()
|
dataset_creation/prepare_for_gpt.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from argparse import ArgumentParser
|
3 |
+
|
4 |
+
from .generate_txt_dataset import DELIMITER_0, DELIMITER_1, STOP
|
5 |
+
|
6 |
+
|
7 |
+
def main(input_path: str, output_path: str):
|
8 |
+
with open(input_path) as f:
|
9 |
+
prompts = [json.loads(l) for l in f]
|
10 |
+
|
11 |
+
with open(output_path, "w") as f:
|
12 |
+
for prompt in prompts:
|
13 |
+
prompt_for_gpt = {
|
14 |
+
"prompt": f"{prompt['input']}{DELIMITER_0}",
|
15 |
+
"completion": f"{prompt['edit']}{DELIMITER_1}{prompt['output']}{STOP}",
|
16 |
+
}
|
17 |
+
f.write(f"{json.dumps(prompt_for_gpt)}\n")
|
18 |
+
|
19 |
+
|
20 |
+
if __name__ == "__main__":
|
21 |
+
parser = ArgumentParser()
|
22 |
+
parser.add_argument("input-path", type=str)
|
23 |
+
parser.add_argument("output-path", type=str)
|
24 |
+
args = parser.parse_args()
|
25 |
+
main(args.input_path, args.output_path)
|
edit_app.py
ADDED
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import math
|
4 |
+
import random
|
5 |
+
import sys
|
6 |
+
from argparse import ArgumentParser
|
7 |
+
|
8 |
+
import einops
|
9 |
+
import gradio as gr
|
10 |
+
import k_diffusion as K
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
from einops import rearrange
|
15 |
+
from omegaconf import OmegaConf
|
16 |
+
from PIL import Image, ImageOps
|
17 |
+
from torch import autocast
|
18 |
+
|
19 |
+
sys.path.append("./stable_diffusion")
|
20 |
+
|
21 |
+
from stable_diffusion.ldm.util import instantiate_from_config
|
22 |
+
|
23 |
+
|
24 |
+
help_text = """
|
25 |
+
If you're not getting what you want, there may be a few reasons:
|
26 |
+
1. Is the image not changing enough? Your Image CFG weight may be too high. This value dictates how similar the output should be to the input. It's possible your edit requires larger changes from the original image, and your Image CFG weight isn't allowing that. Alternatively, your Text CFG weight may be too low. This value dictates how much to listen to the text instruction. The default Image CFG of 1.5 and Text CFG of 7.5 are a good starting point, but aren't necessarily optimal for each edit. Try:
|
27 |
+
* Decreasing the Image CFG weight, or
|
28 |
+
* Incerasing the Text CFG weight, or
|
29 |
+
2. Conversely, is the image changing too much, such that the details in the original image aren't preserved? Try:
|
30 |
+
* Increasing the Image CFG weight, or
|
31 |
+
* Decreasing the Text CFG weight
|
32 |
+
3. Try generating results with different random seeds by setting "Randomize Seed" and running generation multiple times. You can also try setting "Randomize CFG" to sample new Text CFG and Image CFG values each time.
|
33 |
+
4. Rephrasing the instruction sometimes improves results (e.g., "turn him into a dog" vs. "make him a dog" vs. "as a dog").
|
34 |
+
5. Increasing the number of steps sometimes improves results.
|
35 |
+
6. Do faces look weird? The Stable Diffusion autoencoder has a hard time with faces that are small in the image. Try:
|
36 |
+
* Cropping the image so the face takes up a larger portion of the frame.
|
37 |
+
"""
|
38 |
+
|
39 |
+
|
40 |
+
example_instructions = [
|
41 |
+
"Make it a picasso painting",
|
42 |
+
"as if it were by modigliani",
|
43 |
+
"convert to a bronze statue",
|
44 |
+
"Turn it into an anime.",
|
45 |
+
"have it look like a graphic novel",
|
46 |
+
"make him gain weight",
|
47 |
+
"what would he look like bald?",
|
48 |
+
"Have him smile",
|
49 |
+
"Put him in a cocktail party.",
|
50 |
+
"move him at the beach.",
|
51 |
+
"add dramatic lighting",
|
52 |
+
"Convert to black and white",
|
53 |
+
"What if it were snowing?",
|
54 |
+
"Give him a leather jacket",
|
55 |
+
"Turn him into a cyborg!",
|
56 |
+
"make him wear a beanie",
|
57 |
+
]
|
58 |
+
|
59 |
+
|
60 |
+
class CFGDenoiser(nn.Module):
|
61 |
+
def __init__(self, model):
|
62 |
+
super().__init__()
|
63 |
+
self.inner_model = model
|
64 |
+
|
65 |
+
def forward(self, z, sigma, cond, uncond, text_cfg_scale, image_cfg_scale):
|
66 |
+
cfg_z = einops.repeat(z, "1 ... -> n ...", n=3)
|
67 |
+
cfg_sigma = einops.repeat(sigma, "1 ... -> n ...", n=3)
|
68 |
+
cfg_cond = {
|
69 |
+
"c_crossattn": [torch.cat([cond["c_crossattn"][0], uncond["c_crossattn"][0], uncond["c_crossattn"][0]])],
|
70 |
+
"c_concat": [torch.cat([cond["c_concat"][0], cond["c_concat"][0], uncond["c_concat"][0]])],
|
71 |
+
}
|
72 |
+
out_cond, out_img_cond, out_uncond = self.inner_model(cfg_z, cfg_sigma, cond=cfg_cond).chunk(3)
|
73 |
+
return out_uncond + text_cfg_scale * (out_cond - out_img_cond) + image_cfg_scale * (out_img_cond - out_uncond)
|
74 |
+
|
75 |
+
|
76 |
+
def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False, cached=False):
|
77 |
+
print(f"Cache: {cached}")
|
78 |
+
print(f"Loading model from {ckpt}")
|
79 |
+
pl_sd = torch.load(ckpt, map_location="cpu")
|
80 |
+
if "global_step" in pl_sd:
|
81 |
+
print(f"Global Step: {pl_sd['global_step']}")
|
82 |
+
sd = pl_sd["state_dict"]
|
83 |
+
if vae_ckpt is not None:
|
84 |
+
print(f"Loading VAE from {vae_ckpt}")
|
85 |
+
vae_sd = torch.load(vae_ckpt, map_location="cpu")["state_dict"]
|
86 |
+
sd = {
|
87 |
+
k: vae_sd[k[len("first_stage_model.") :]] if k.startswith("first_stage_model.") else v
|
88 |
+
for k, v in sd.items()
|
89 |
+
}
|
90 |
+
model = instantiate_from_config(config.model, cached=cached)
|
91 |
+
m, u = model.load_state_dict(sd, strict=False)
|
92 |
+
if len(m) > 0 and verbose:
|
93 |
+
print("missing keys:")
|
94 |
+
print(m)
|
95 |
+
if len(u) > 0 and verbose:
|
96 |
+
print("unexpected keys:")
|
97 |
+
print(u)
|
98 |
+
return model
|
99 |
+
|
100 |
+
|
101 |
+
def main():
|
102 |
+
parser = ArgumentParser()
|
103 |
+
parser.add_argument("--resolution", default=512, type=int)
|
104 |
+
parser.add_argument("--config", default="configs/instruct-pix2pix/generate.yaml", type=str)
|
105 |
+
parser.add_argument("--ckpt", default="checkpoints/instruct-pix2pix-00-20000.ckpt", type=str)
|
106 |
+
parser.add_argument("--vae-ckpt", default=None, type=str)
|
107 |
+
args = parser.parse_args()
|
108 |
+
|
109 |
+
config = OmegaConf.load(args.config)
|
110 |
+
model = load_model_from_config(config, args.ckpt, args.vae_ckpt)
|
111 |
+
model.eval().cuda()
|
112 |
+
model_wrap = K.external.CompVisDenoiser(model)
|
113 |
+
model_wrap_cfg = CFGDenoiser(model_wrap)
|
114 |
+
null_token = model.get_learned_conditioning([""])
|
115 |
+
example_image = Image.open("imgs/example.jpg").convert("RGB")
|
116 |
+
|
117 |
+
def load_example(
|
118 |
+
steps: int,
|
119 |
+
randomize_seed: bool,
|
120 |
+
seed: int,
|
121 |
+
randomize_cfg: bool,
|
122 |
+
text_cfg_scale: float,
|
123 |
+
image_cfg_scale: float,
|
124 |
+
):
|
125 |
+
example_instruction = random.choice(example_instructions)
|
126 |
+
return [example_image, example_instruction] + generate(
|
127 |
+
example_image,
|
128 |
+
example_instruction,
|
129 |
+
steps,
|
130 |
+
randomize_seed,
|
131 |
+
seed,
|
132 |
+
randomize_cfg,
|
133 |
+
text_cfg_scale,
|
134 |
+
image_cfg_scale,
|
135 |
+
)
|
136 |
+
|
137 |
+
def generate(
|
138 |
+
input_image: Image.Image,
|
139 |
+
instruction: str,
|
140 |
+
steps: int,
|
141 |
+
randomize_seed: bool,
|
142 |
+
seed: int,
|
143 |
+
randomize_cfg: bool,
|
144 |
+
text_cfg_scale: float,
|
145 |
+
image_cfg_scale: float,
|
146 |
+
):
|
147 |
+
seed = random.randint(0, 100000) if randomize_seed else seed
|
148 |
+
text_cfg_scale = round(random.uniform(6.0, 9.0), ndigits=2) if randomize_cfg else text_cfg_scale
|
149 |
+
image_cfg_scale = round(random.uniform(1.2, 1.8), ndigits=2) if randomize_cfg else image_cfg_scale
|
150 |
+
|
151 |
+
width, height = input_image.size
|
152 |
+
factor = args.resolution / max(width, height)
|
153 |
+
factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
|
154 |
+
width = int((width * factor) // 64) * 64
|
155 |
+
height = int((height * factor) // 64) * 64
|
156 |
+
input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS)
|
157 |
+
|
158 |
+
if instruction == "":
|
159 |
+
return [input_image, seed]
|
160 |
+
|
161 |
+
with torch.no_grad(), autocast("cuda"), model.ema_scope():
|
162 |
+
cond = {}
|
163 |
+
cond["c_crossattn"] = [model.get_learned_conditioning([instruction])]
|
164 |
+
input_image = 2 * torch.tensor(np.array(input_image)).float() / 255 - 1
|
165 |
+
input_image = rearrange(input_image, "h w c -> 1 c h w").to(model.device)
|
166 |
+
cond["c_concat"] = [model.encode_first_stage(input_image).mode()]
|
167 |
+
|
168 |
+
uncond = {}
|
169 |
+
uncond["c_crossattn"] = [null_token]
|
170 |
+
uncond["c_concat"] = [torch.zeros_like(cond["c_concat"][0])]
|
171 |
+
|
172 |
+
sigmas = model_wrap.get_sigmas(steps)
|
173 |
+
|
174 |
+
extra_args = {
|
175 |
+
"cond": cond,
|
176 |
+
"uncond": uncond,
|
177 |
+
"text_cfg_scale": text_cfg_scale,
|
178 |
+
"image_cfg_scale": image_cfg_scale,
|
179 |
+
}
|
180 |
+
torch.manual_seed(seed)
|
181 |
+
z = torch.randn_like(cond["c_concat"][0]) * sigmas[0]
|
182 |
+
z = K.sampling.sample_euler_ancestral(model_wrap_cfg, z, sigmas, extra_args=extra_args)
|
183 |
+
x = model.decode_first_stage(z)
|
184 |
+
x = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0)
|
185 |
+
x = 255.0 * rearrange(x, "1 c h w -> h w c")
|
186 |
+
edited_image = Image.fromarray(x.type(torch.uint8).cpu().numpy())
|
187 |
+
|
188 |
+
return [seed, text_cfg_scale, image_cfg_scale, edited_image]
|
189 |
+
|
190 |
+
def reset():
|
191 |
+
return [50, "Randomize Seed", random.randint(0, 100000), "Fix CFG", 7.5, 1.5, None]
|
192 |
+
|
193 |
+
with gr.Blocks(css="footer {visibility: hidden}") as demo:
|
194 |
+
with gr.Row():
|
195 |
+
with gr.Column(scale=1, min_width=100):
|
196 |
+
generate_button = gr.Button("Generate")
|
197 |
+
with gr.Column(scale=1, min_width=100):
|
198 |
+
load_button = gr.Button("Load Example")
|
199 |
+
with gr.Column(scale=1, min_width=100):
|
200 |
+
reset_button = gr.Button("Reset")
|
201 |
+
with gr.Column(scale=3):
|
202 |
+
instruction = gr.Textbox(lines=1, label="Edit Instruction", interactive=True)
|
203 |
+
|
204 |
+
with gr.Row():
|
205 |
+
input_image = gr.Image(label="Input Image", type="pil", interactive=True)
|
206 |
+
edited_image = gr.Image(label=f"Edited Image", type="pil", interactive=False)
|
207 |
+
input_image.style(height=512, width=512)
|
208 |
+
edited_image.style(height=512, width=512)
|
209 |
+
|
210 |
+
with gr.Row():
|
211 |
+
steps = gr.Number(value=50, precision=0, label="Steps", interactive=True)
|
212 |
+
randomize_seed = gr.Radio(
|
213 |
+
["Fix Seed", "Randomize Seed"],
|
214 |
+
value="Randomize Seed",
|
215 |
+
type="index",
|
216 |
+
show_label=False,
|
217 |
+
interactive=True,
|
218 |
+
)
|
219 |
+
seed = gr.Number(value=random.randint(0, 100000), precision=0, label="Seed", interactive=True)
|
220 |
+
randomize_cfg = gr.Radio(
|
221 |
+
["Fix CFG", "Randomize CFG"],
|
222 |
+
value="Fix CFG",
|
223 |
+
type="index",
|
224 |
+
show_label=False,
|
225 |
+
interactive=True,
|
226 |
+
)
|
227 |
+
text_cfg_scale = gr.Number(value=7.5, label=f"Text CFG", interactive=True)
|
228 |
+
image_cfg_scale = gr.Number(value=1.5, label=f"Image CFG", interactive=True)
|
229 |
+
|
230 |
+
gr.Markdown(help_text)
|
231 |
+
|
232 |
+
load_button.click(
|
233 |
+
fn=load_example,
|
234 |
+
inputs=[
|
235 |
+
steps,
|
236 |
+
randomize_seed,
|
237 |
+
seed,
|
238 |
+
randomize_cfg,
|
239 |
+
text_cfg_scale,
|
240 |
+
image_cfg_scale,
|
241 |
+
],
|
242 |
+
outputs=[input_image, instruction, seed, text_cfg_scale, image_cfg_scale, edited_image],
|
243 |
+
)
|
244 |
+
generate_button.click(
|
245 |
+
fn=generate,
|
246 |
+
inputs=[
|
247 |
+
input_image,
|
248 |
+
instruction,
|
249 |
+
steps,
|
250 |
+
randomize_seed,
|
251 |
+
seed,
|
252 |
+
randomize_cfg,
|
253 |
+
text_cfg_scale,
|
254 |
+
image_cfg_scale,
|
255 |
+
],
|
256 |
+
outputs=[seed, text_cfg_scale, image_cfg_scale, edited_image],
|
257 |
+
)
|
258 |
+
reset_button.click(
|
259 |
+
fn=reset,
|
260 |
+
inputs=[],
|
261 |
+
outputs=[steps, randomize_seed, seed, randomize_cfg, text_cfg_scale, image_cfg_scale, edited_image],
|
262 |
+
)
|
263 |
+
|
264 |
+
demo.queue(concurrency_count=1)
|
265 |
+
demo.launch(share=True)
|
266 |
+
|
267 |
+
|
268 |
+
if __name__ == "__main__":
|
269 |
+
main()
|
edit_cli.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import math
|
4 |
+
import random
|
5 |
+
import sys
|
6 |
+
from argparse import ArgumentParser
|
7 |
+
|
8 |
+
import einops
|
9 |
+
import k_diffusion as K
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
from einops import rearrange
|
14 |
+
from omegaconf import OmegaConf
|
15 |
+
from PIL import Image, ImageOps
|
16 |
+
from torch import autocast
|
17 |
+
|
18 |
+
sys.path.append("./stable_diffusion")
|
19 |
+
|
20 |
+
from stable_diffusion.ldm.util import instantiate_from_config
|
21 |
+
|
22 |
+
|
23 |
+
class CFGDenoiser(nn.Module):
|
24 |
+
def __init__(self, model):
|
25 |
+
super().__init__()
|
26 |
+
self.inner_model = model
|
27 |
+
|
28 |
+
def forward(self, z, sigma, cond, uncond, text_cfg_scale, image_cfg_scale):
|
29 |
+
cfg_z = einops.repeat(z, "1 ... -> n ...", n=3)
|
30 |
+
cfg_sigma = einops.repeat(sigma, "1 ... -> n ...", n=3)
|
31 |
+
cfg_cond = {
|
32 |
+
"c_crossattn": [torch.cat([cond["c_crossattn"][0], uncond["c_crossattn"][0], uncond["c_crossattn"][0]])],
|
33 |
+
"c_concat": [torch.cat([cond["c_concat"][0], cond["c_concat"][0], uncond["c_concat"][0]])],
|
34 |
+
}
|
35 |
+
out_cond, out_img_cond, out_uncond = self.inner_model(cfg_z, cfg_sigma, cond=cfg_cond).chunk(3)
|
36 |
+
return out_uncond + text_cfg_scale * (out_cond - out_img_cond) + image_cfg_scale * (out_img_cond - out_uncond)
|
37 |
+
|
38 |
+
|
39 |
+
def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False):
|
40 |
+
print(f"Loading model from {ckpt}")
|
41 |
+
pl_sd = torch.load(ckpt, map_location="cpu")
|
42 |
+
if "global_step" in pl_sd:
|
43 |
+
print(f"Global Step: {pl_sd['global_step']}")
|
44 |
+
sd = pl_sd["state_dict"]
|
45 |
+
if vae_ckpt is not None:
|
46 |
+
print(f"Loading VAE from {vae_ckpt}")
|
47 |
+
vae_sd = torch.load(vae_ckpt, map_location="cpu")["state_dict"]
|
48 |
+
sd = {
|
49 |
+
k: vae_sd[k[len("first_stage_model.") :]] if k.startswith("first_stage_model.") else v
|
50 |
+
for k, v in sd.items()
|
51 |
+
}
|
52 |
+
model = instantiate_from_config(config.model)
|
53 |
+
m, u = model.load_state_dict(sd, strict=False)
|
54 |
+
if len(m) > 0 and verbose:
|
55 |
+
print("missing keys:")
|
56 |
+
print(m)
|
57 |
+
if len(u) > 0 and verbose:
|
58 |
+
print("unexpected keys:")
|
59 |
+
print(u)
|
60 |
+
return model
|
61 |
+
|
62 |
+
|
63 |
+
def main():
|
64 |
+
parser = ArgumentParser()
|
65 |
+
parser.add_argument("--resolution", default=512, type=int)
|
66 |
+
parser.add_argument("--steps", default=100, type=int)
|
67 |
+
parser.add_argument("--config", default="configs/generate.yaml", type=str)
|
68 |
+
parser.add_argument("--ckpt", default="checkpoints/instruct-pix2pix-00-20000.ckpt", type=str)
|
69 |
+
parser.add_argument("--vae-ckpt", default=None, type=str)
|
70 |
+
parser.add_argument("--input", required=True, type=str)
|
71 |
+
parser.add_argument("--output", required=True, type=str)
|
72 |
+
parser.add_argument("--edit", required=True, type=str)
|
73 |
+
parser.add_argument("--cfg-text", default=7.5, type=float)
|
74 |
+
parser.add_argument("--cfg-image", default=1.2, type=float)
|
75 |
+
parser.add_argument("--seed", type=int)
|
76 |
+
args = parser.parse_args()
|
77 |
+
|
78 |
+
config = OmegaConf.load(args.config)
|
79 |
+
model = load_model_from_config(config, args.ckpt, args.vae_ckpt)
|
80 |
+
model.eval().cuda()
|
81 |
+
model_wrap = K.external.CompVisDenoiser(model)
|
82 |
+
model_wrap_cfg = CFGDenoiser(model_wrap)
|
83 |
+
null_token = model.get_learned_conditioning([""])
|
84 |
+
|
85 |
+
seed = random.randint(0, 100000) if args.seed is None else args.seed
|
86 |
+
input_image = Image.open(args.input).convert("RGB")
|
87 |
+
width, height = input_image.size
|
88 |
+
factor = args.resolution / max(width, height)
|
89 |
+
factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
|
90 |
+
width = int((width * factor) // 64) * 64
|
91 |
+
height = int((height * factor) // 64) * 64
|
92 |
+
input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS)
|
93 |
+
|
94 |
+
if args.edit == "":
|
95 |
+
input_image.save(args.output)
|
96 |
+
return
|
97 |
+
|
98 |
+
with torch.no_grad(), autocast("cuda"), model.ema_scope():
|
99 |
+
cond = {}
|
100 |
+
cond["c_crossattn"] = [model.get_learned_conditioning([args.edit])]
|
101 |
+
input_image = 2 * torch.tensor(np.array(input_image)).float() / 255 - 1
|
102 |
+
input_image = rearrange(input_image, "h w c -> 1 c h w").to(model.device)
|
103 |
+
cond["c_concat"] = [model.encode_first_stage(input_image).mode()]
|
104 |
+
|
105 |
+
uncond = {}
|
106 |
+
uncond["c_crossattn"] = [null_token]
|
107 |
+
uncond["c_concat"] = [torch.zeros_like(cond["c_concat"][0])]
|
108 |
+
|
109 |
+
sigmas = model_wrap.get_sigmas(args.steps)
|
110 |
+
|
111 |
+
extra_args = {
|
112 |
+
"cond": cond,
|
113 |
+
"uncond": uncond,
|
114 |
+
"text_cfg_scale": args.cfg_text,
|
115 |
+
"image_cfg_scale": args.cfg_image,
|
116 |
+
}
|
117 |
+
torch.manual_seed(seed)
|
118 |
+
z = torch.randn_like(cond["c_concat"][0]) * sigmas[0]
|
119 |
+
z = K.sampling.sample_euler_ancestral(model_wrap_cfg, z, sigmas, extra_args=extra_args)
|
120 |
+
x = model.decode_first_stage(z)
|
121 |
+
x = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0)
|
122 |
+
x = 255.0 * rearrange(x, "1 c h w -> h w c")
|
123 |
+
edited_image = Image.fromarray(x.type(torch.uint8).cpu().numpy())
|
124 |
+
edited_image.save(args.output)
|
125 |
+
|
126 |
+
|
127 |
+
if __name__ == "__main__":
|
128 |
+
main()
|
edit_dataset.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import json
|
4 |
+
import math
|
5 |
+
from pathlib import Path
|
6 |
+
from typing import Any
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import torchvision
|
11 |
+
from einops import rearrange
|
12 |
+
from PIL import Image
|
13 |
+
from torch.utils.data import Dataset
|
14 |
+
|
15 |
+
|
16 |
+
class EditDataset(Dataset):
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
path: str,
|
20 |
+
split: str = "train",
|
21 |
+
splits: tuple[float, float, float] = (0.9, 0.05, 0.05),
|
22 |
+
min_resize_res: int = 256,
|
23 |
+
max_resize_res: int = 256,
|
24 |
+
crop_res: int = 256,
|
25 |
+
flip_prob: float = 0.0,
|
26 |
+
):
|
27 |
+
assert split in ("train", "val", "test")
|
28 |
+
assert sum(splits) == 1
|
29 |
+
self.path = path
|
30 |
+
self.min_resize_res = min_resize_res
|
31 |
+
self.max_resize_res = max_resize_res
|
32 |
+
self.crop_res = crop_res
|
33 |
+
self.flip_prob = flip_prob
|
34 |
+
|
35 |
+
with open(Path(self.path, "seeds.json")) as f:
|
36 |
+
self.seeds = json.load(f)
|
37 |
+
|
38 |
+
split_0, split_1 = {
|
39 |
+
"train": (0.0, splits[0]),
|
40 |
+
"val": (splits[0], splits[0] + splits[1]),
|
41 |
+
"test": (splits[0] + splits[1], 1.0),
|
42 |
+
}[split]
|
43 |
+
|
44 |
+
idx_0 = math.floor(split_0 * len(self.seeds))
|
45 |
+
idx_1 = math.floor(split_1 * len(self.seeds))
|
46 |
+
self.seeds = self.seeds[idx_0:idx_1]
|
47 |
+
|
48 |
+
def __len__(self) -> int:
|
49 |
+
return len(self.seeds)
|
50 |
+
|
51 |
+
def __getitem__(self, i: int) -> dict[str, Any]:
|
52 |
+
name, seeds = self.seeds[i]
|
53 |
+
propt_dir = Path(self.path, name)
|
54 |
+
seed = seeds[torch.randint(0, len(seeds), ()).item()]
|
55 |
+
with open(propt_dir.joinpath("prompt.json")) as fp:
|
56 |
+
prompt = json.load(fp)["edit"]
|
57 |
+
|
58 |
+
image_0 = Image.open(propt_dir.joinpath(f"{seed}_0.jpg"))
|
59 |
+
image_1 = Image.open(propt_dir.joinpath(f"{seed}_1.jpg"))
|
60 |
+
|
61 |
+
reize_res = torch.randint(self.min_resize_res, self.max_resize_res + 1, ()).item()
|
62 |
+
image_0 = image_0.resize((reize_res, reize_res), Image.Resampling.LANCZOS)
|
63 |
+
image_1 = image_1.resize((reize_res, reize_res), Image.Resampling.LANCZOS)
|
64 |
+
|
65 |
+
image_0 = rearrange(2 * torch.tensor(np.array(image_0)).float() / 255 - 1, "h w c -> c h w")
|
66 |
+
image_1 = rearrange(2 * torch.tensor(np.array(image_1)).float() / 255 - 1, "h w c -> c h w")
|
67 |
+
|
68 |
+
crop = torchvision.transforms.RandomCrop(self.crop_res)
|
69 |
+
flip = torchvision.transforms.RandomHorizontalFlip(float(self.flip_prob))
|
70 |
+
image_0, image_1 = flip(crop(torch.cat((image_0, image_1)))).chunk(2)
|
71 |
+
|
72 |
+
return dict(edited=image_1, edit=dict(c_concat=image_0, c_crossattn=prompt))
|
environment.yaml
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
|
2 |
+
# See more details in LICENSE.
|
3 |
+
|
4 |
+
name: ip2p
|
5 |
+
channels:
|
6 |
+
- pytorch
|
7 |
+
- defaults
|
8 |
+
dependencies:
|
9 |
+
- python=3.8.5
|
10 |
+
- pip=20.3
|
11 |
+
- cudatoolkit=11.3
|
12 |
+
- pytorch=1.11.0
|
13 |
+
- torchvision=0.12.0
|
14 |
+
- numpy=1.19.2
|
15 |
+
- pip:
|
16 |
+
- albumentations==0.4.3
|
17 |
+
- diffusers
|
18 |
+
- opencv-python==4.1.2.30
|
19 |
+
- pudb==2019.2
|
20 |
+
- invisible-watermark
|
21 |
+
- imageio==2.9.0
|
22 |
+
- imageio-ffmpeg==0.4.2
|
23 |
+
- pytorch-lightning==1.4.2
|
24 |
+
- omegaconf==2.1.1
|
25 |
+
- test-tube>=0.7.5
|
26 |
+
- streamlit>=0.73.1
|
27 |
+
- einops==0.3.0
|
28 |
+
- torch-fidelity==0.3.0
|
29 |
+
- transformers==4.19.2
|
30 |
+
- torchmetrics==0.6.0
|
31 |
+
- kornia==0.6
|
32 |
+
- -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
|
33 |
+
- -e git+https://github.com/openai/CLIP.git@main#egg=clip
|
34 |
+
- openai
|
35 |
+
- gradio
|
36 |
+
- seaborn
|
37 |
+
- git+https://github.com/crowsonkb/k-diffusion.git
|
imgs/example.jpg
ADDED
main.py
ADDED
@@ -0,0 +1,797 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
|
2 |
+
# See more details in LICENSE.
|
3 |
+
|
4 |
+
import argparse, os, sys, datetime, glob
|
5 |
+
import numpy as np
|
6 |
+
import time
|
7 |
+
import torch
|
8 |
+
import torchvision
|
9 |
+
import pytorch_lightning as pl
|
10 |
+
import json
|
11 |
+
import pickle
|
12 |
+
|
13 |
+
from packaging import version
|
14 |
+
from omegaconf import OmegaConf
|
15 |
+
from torch.utils.data import DataLoader, Dataset
|
16 |
+
from functools import partial
|
17 |
+
from PIL import Image
|
18 |
+
|
19 |
+
import torch.distributed as dist
|
20 |
+
from pytorch_lightning import seed_everything
|
21 |
+
from pytorch_lightning.trainer import Trainer
|
22 |
+
from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
|
23 |
+
from pytorch_lightning.utilities.distributed import rank_zero_only
|
24 |
+
from pytorch_lightning.utilities import rank_zero_info
|
25 |
+
from pytorch_lightning.plugins import DDPPlugin
|
26 |
+
|
27 |
+
sys.path.append("./stable_diffusion")
|
28 |
+
|
29 |
+
from stable_diffusion.ldm.data.base import Txt2ImgIterableBaseDataset
|
30 |
+
from stable_diffusion.ldm.util import instantiate_from_config
|
31 |
+
|
32 |
+
|
33 |
+
def get_parser(**parser_kwargs):
|
34 |
+
def str2bool(v):
|
35 |
+
if isinstance(v, bool):
|
36 |
+
return v
|
37 |
+
if v.lower() in ("yes", "true", "t", "y", "1"):
|
38 |
+
return True
|
39 |
+
elif v.lower() in ("no", "false", "f", "n", "0"):
|
40 |
+
return False
|
41 |
+
else:
|
42 |
+
raise argparse.ArgumentTypeError("Boolean value expected.")
|
43 |
+
|
44 |
+
parser = argparse.ArgumentParser(**parser_kwargs)
|
45 |
+
parser.add_argument(
|
46 |
+
"-n",
|
47 |
+
"--name",
|
48 |
+
type=str,
|
49 |
+
const=True,
|
50 |
+
default="",
|
51 |
+
nargs="?",
|
52 |
+
help="postfix for logdir",
|
53 |
+
)
|
54 |
+
parser.add_argument(
|
55 |
+
"-r",
|
56 |
+
"--resume",
|
57 |
+
type=str,
|
58 |
+
const=True,
|
59 |
+
default="",
|
60 |
+
nargs="?",
|
61 |
+
help="resume from logdir or checkpoint in logdir",
|
62 |
+
)
|
63 |
+
parser.add_argument(
|
64 |
+
"-b",
|
65 |
+
"--base",
|
66 |
+
nargs="*",
|
67 |
+
metavar="base_config.yaml",
|
68 |
+
help="paths to base configs. Loaded from left-to-right. "
|
69 |
+
"Parameters can be overwritten or added with command-line options of the form `--key value`.",
|
70 |
+
default=list(),
|
71 |
+
)
|
72 |
+
parser.add_argument(
|
73 |
+
"-t",
|
74 |
+
"--train",
|
75 |
+
type=str2bool,
|
76 |
+
const=True,
|
77 |
+
default=False,
|
78 |
+
nargs="?",
|
79 |
+
help="train",
|
80 |
+
)
|
81 |
+
parser.add_argument(
|
82 |
+
"--no-test",
|
83 |
+
type=str2bool,
|
84 |
+
const=True,
|
85 |
+
default=False,
|
86 |
+
nargs="?",
|
87 |
+
help="disable test",
|
88 |
+
)
|
89 |
+
parser.add_argument(
|
90 |
+
"-p",
|
91 |
+
"--project",
|
92 |
+
help="name of new or path to existing project"
|
93 |
+
)
|
94 |
+
parser.add_argument(
|
95 |
+
"-d",
|
96 |
+
"--debug",
|
97 |
+
type=str2bool,
|
98 |
+
nargs="?",
|
99 |
+
const=True,
|
100 |
+
default=False,
|
101 |
+
help="enable post-mortem debugging",
|
102 |
+
)
|
103 |
+
parser.add_argument(
|
104 |
+
"-s",
|
105 |
+
"--seed",
|
106 |
+
type=int,
|
107 |
+
default=23,
|
108 |
+
help="seed for seed_everything",
|
109 |
+
)
|
110 |
+
parser.add_argument(
|
111 |
+
"-f",
|
112 |
+
"--postfix",
|
113 |
+
type=str,
|
114 |
+
default="",
|
115 |
+
help="post-postfix for default name",
|
116 |
+
)
|
117 |
+
parser.add_argument(
|
118 |
+
"-l",
|
119 |
+
"--logdir",
|
120 |
+
type=str,
|
121 |
+
default="logs",
|
122 |
+
help="directory for logging dat shit",
|
123 |
+
)
|
124 |
+
parser.add_argument(
|
125 |
+
"--scale_lr",
|
126 |
+
action="store_true",
|
127 |
+
default=False,
|
128 |
+
help="scale base-lr by ngpu * batch_size * n_accumulate",
|
129 |
+
)
|
130 |
+
return parser
|
131 |
+
|
132 |
+
|
133 |
+
def nondefault_trainer_args(opt):
|
134 |
+
parser = argparse.ArgumentParser()
|
135 |
+
parser = Trainer.add_argparse_args(parser)
|
136 |
+
args = parser.parse_args([])
|
137 |
+
return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))
|
138 |
+
|
139 |
+
|
140 |
+
class WrappedDataset(Dataset):
|
141 |
+
"""Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""
|
142 |
+
|
143 |
+
def __init__(self, dataset):
|
144 |
+
self.data = dataset
|
145 |
+
|
146 |
+
def __len__(self):
|
147 |
+
return len(self.data)
|
148 |
+
|
149 |
+
def __getitem__(self, idx):
|
150 |
+
return self.data[idx]
|
151 |
+
|
152 |
+
|
153 |
+
def worker_init_fn(_):
|
154 |
+
worker_info = torch.utils.data.get_worker_info()
|
155 |
+
|
156 |
+
dataset = worker_info.dataset
|
157 |
+
worker_id = worker_info.id
|
158 |
+
|
159 |
+
if isinstance(dataset, Txt2ImgIterableBaseDataset):
|
160 |
+
split_size = dataset.num_records // worker_info.num_workers
|
161 |
+
# reset num_records to the true number to retain reliable length information
|
162 |
+
dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size]
|
163 |
+
current_id = np.random.choice(len(np.random.get_state()[1]), 1)
|
164 |
+
return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
|
165 |
+
else:
|
166 |
+
return np.random.seed(np.random.get_state()[1][0] + worker_id)
|
167 |
+
|
168 |
+
|
169 |
+
class DataModuleFromConfig(pl.LightningDataModule):
|
170 |
+
def __init__(self, batch_size, train=None, validation=None, test=None, predict=None,
|
171 |
+
wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False,
|
172 |
+
shuffle_val_dataloader=False):
|
173 |
+
super().__init__()
|
174 |
+
self.batch_size = batch_size
|
175 |
+
self.dataset_configs = dict()
|
176 |
+
self.num_workers = num_workers if num_workers is not None else batch_size * 2
|
177 |
+
self.use_worker_init_fn = use_worker_init_fn
|
178 |
+
if train is not None:
|
179 |
+
self.dataset_configs["train"] = train
|
180 |
+
self.train_dataloader = self._train_dataloader
|
181 |
+
if validation is not None:
|
182 |
+
self.dataset_configs["validation"] = validation
|
183 |
+
self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader)
|
184 |
+
if test is not None:
|
185 |
+
self.dataset_configs["test"] = test
|
186 |
+
self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader)
|
187 |
+
if predict is not None:
|
188 |
+
self.dataset_configs["predict"] = predict
|
189 |
+
self.predict_dataloader = self._predict_dataloader
|
190 |
+
self.wrap = wrap
|
191 |
+
|
192 |
+
def prepare_data(self):
|
193 |
+
for data_cfg in self.dataset_configs.values():
|
194 |
+
instantiate_from_config(data_cfg)
|
195 |
+
|
196 |
+
def setup(self, stage=None):
|
197 |
+
self.datasets = dict(
|
198 |
+
(k, instantiate_from_config(self.dataset_configs[k]))
|
199 |
+
for k in self.dataset_configs)
|
200 |
+
if self.wrap:
|
201 |
+
for k in self.datasets:
|
202 |
+
self.datasets[k] = WrappedDataset(self.datasets[k])
|
203 |
+
|
204 |
+
def _train_dataloader(self):
|
205 |
+
is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset)
|
206 |
+
if is_iterable_dataset or self.use_worker_init_fn:
|
207 |
+
init_fn = worker_init_fn
|
208 |
+
else:
|
209 |
+
init_fn = None
|
210 |
+
return DataLoader(self.datasets["train"], batch_size=self.batch_size,
|
211 |
+
num_workers=self.num_workers, shuffle=False if is_iterable_dataset else True,
|
212 |
+
worker_init_fn=init_fn, persistent_workers=True)
|
213 |
+
|
214 |
+
def _val_dataloader(self, shuffle=False):
|
215 |
+
if isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
|
216 |
+
init_fn = worker_init_fn
|
217 |
+
else:
|
218 |
+
init_fn = None
|
219 |
+
return DataLoader(self.datasets["validation"],
|
220 |
+
batch_size=self.batch_size,
|
221 |
+
num_workers=self.num_workers,
|
222 |
+
worker_init_fn=init_fn,
|
223 |
+
shuffle=shuffle, persistent_workers=True)
|
224 |
+
|
225 |
+
def _test_dataloader(self, shuffle=False):
|
226 |
+
is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset)
|
227 |
+
if is_iterable_dataset or self.use_worker_init_fn:
|
228 |
+
init_fn = worker_init_fn
|
229 |
+
else:
|
230 |
+
init_fn = None
|
231 |
+
|
232 |
+
# do not shuffle dataloader for iterable dataset
|
233 |
+
shuffle = shuffle and (not is_iterable_dataset)
|
234 |
+
|
235 |
+
return DataLoader(self.datasets["test"], batch_size=self.batch_size,
|
236 |
+
num_workers=self.num_workers, worker_init_fn=init_fn, shuffle=shuffle, persistent_workers=True)
|
237 |
+
|
238 |
+
def _predict_dataloader(self, shuffle=False):
|
239 |
+
if isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
|
240 |
+
init_fn = worker_init_fn
|
241 |
+
else:
|
242 |
+
init_fn = None
|
243 |
+
return DataLoader(self.datasets["predict"], batch_size=self.batch_size,
|
244 |
+
num_workers=self.num_workers, worker_init_fn=init_fn, persistent_workers=True)
|
245 |
+
|
246 |
+
|
247 |
+
class SetupCallback(Callback):
|
248 |
+
def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
|
249 |
+
super().__init__()
|
250 |
+
self.resume = resume
|
251 |
+
self.now = now
|
252 |
+
self.logdir = logdir
|
253 |
+
self.ckptdir = ckptdir
|
254 |
+
self.cfgdir = cfgdir
|
255 |
+
self.config = config
|
256 |
+
self.lightning_config = lightning_config
|
257 |
+
|
258 |
+
def on_keyboard_interrupt(self, trainer, pl_module):
|
259 |
+
if trainer.global_rank == 0:
|
260 |
+
print("Summoning checkpoint.")
|
261 |
+
ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
|
262 |
+
trainer.save_checkpoint(ckpt_path)
|
263 |
+
|
264 |
+
def on_pretrain_routine_start(self, trainer, pl_module):
|
265 |
+
if trainer.global_rank == 0:
|
266 |
+
# Create logdirs and save configs
|
267 |
+
# os.makedirs(self.logdir, exist_ok=True)
|
268 |
+
# os.makedirs(self.ckptdir, exist_ok=True)
|
269 |
+
# os.makedirs(self.cfgdir, exist_ok=True)
|
270 |
+
|
271 |
+
if "callbacks" in self.lightning_config:
|
272 |
+
if 'metrics_over_trainsteps_checkpoint' in self.lightning_config['callbacks']:
|
273 |
+
os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True)
|
274 |
+
print("Project config")
|
275 |
+
print(OmegaConf.to_yaml(self.config))
|
276 |
+
OmegaConf.save(self.config,
|
277 |
+
os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))
|
278 |
+
|
279 |
+
print("Lightning config")
|
280 |
+
print(OmegaConf.to_yaml(self.lightning_config))
|
281 |
+
OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}),
|
282 |
+
os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)))
|
283 |
+
|
284 |
+
def get_world_size():
|
285 |
+
if not dist.is_available():
|
286 |
+
return 1
|
287 |
+
if not dist.is_initialized():
|
288 |
+
return 1
|
289 |
+
return dist.get_world_size()
|
290 |
+
|
291 |
+
def all_gather(data):
|
292 |
+
"""
|
293 |
+
Run all_gather on arbitrary picklable data (not necessarily tensors)
|
294 |
+
Args:
|
295 |
+
data: any picklable object
|
296 |
+
Returns:
|
297 |
+
list[data]: list of data gathered from each rank
|
298 |
+
"""
|
299 |
+
world_size = get_world_size()
|
300 |
+
if world_size == 1:
|
301 |
+
return [data]
|
302 |
+
|
303 |
+
# serialized to a Tensor
|
304 |
+
origin_size = None
|
305 |
+
if not isinstance(data, torch.Tensor):
|
306 |
+
buffer = pickle.dumps(data)
|
307 |
+
storage = torch.ByteStorage.from_buffer(buffer)
|
308 |
+
tensor = torch.ByteTensor(storage).to("cuda")
|
309 |
+
else:
|
310 |
+
origin_size = data.size()
|
311 |
+
tensor = data.reshape(-1)
|
312 |
+
|
313 |
+
tensor_type = tensor.dtype
|
314 |
+
|
315 |
+
# obtain Tensor size of each rank
|
316 |
+
local_size = torch.LongTensor([tensor.numel()]).to("cuda")
|
317 |
+
size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)]
|
318 |
+
dist.all_gather(size_list, local_size)
|
319 |
+
size_list = [int(size.item()) for size in size_list]
|
320 |
+
max_size = max(size_list)
|
321 |
+
|
322 |
+
# receiving Tensor from all ranks
|
323 |
+
# we pad the tensor because torch all_gather does not support
|
324 |
+
# gathering tensors of different shapes
|
325 |
+
tensor_list = []
|
326 |
+
for _ in size_list:
|
327 |
+
tensor_list.append(torch.FloatTensor(size=(max_size,)).cuda().to(tensor_type))
|
328 |
+
if local_size != max_size:
|
329 |
+
padding = torch.FloatTensor(size=(max_size - local_size,)).cuda().to(tensor_type)
|
330 |
+
tensor = torch.cat((tensor, padding), dim=0)
|
331 |
+
dist.all_gather(tensor_list, tensor)
|
332 |
+
|
333 |
+
data_list = []
|
334 |
+
for size, tensor in zip(size_list, tensor_list):
|
335 |
+
if origin_size is None:
|
336 |
+
buffer = tensor.cpu().numpy().tobytes()[:size]
|
337 |
+
data_list.append(pickle.loads(buffer))
|
338 |
+
else:
|
339 |
+
buffer = tensor[:size]
|
340 |
+
data_list.append(buffer)
|
341 |
+
|
342 |
+
if origin_size is not None:
|
343 |
+
new_shape = [-1] + list(origin_size[1:])
|
344 |
+
resized_list = []
|
345 |
+
for data in data_list:
|
346 |
+
# suppose the difference of tensor size exist in first dimension
|
347 |
+
data = data.reshape(new_shape)
|
348 |
+
resized_list.append(data)
|
349 |
+
|
350 |
+
return resized_list
|
351 |
+
else:
|
352 |
+
return data_list
|
353 |
+
|
354 |
+
class ImageLogger(Callback):
|
355 |
+
def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True,
|
356 |
+
rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
|
357 |
+
log_images_kwargs=None):
|
358 |
+
super().__init__()
|
359 |
+
self.rescale = rescale
|
360 |
+
self.batch_freq = batch_frequency
|
361 |
+
self.max_images = max_images
|
362 |
+
self.logger_log_images = {
|
363 |
+
pl.loggers.TestTubeLogger: self._testtube,
|
364 |
+
}
|
365 |
+
self.log_steps = [2 ** n for n in range(6, int(np.log2(self.batch_freq)) + 1)]
|
366 |
+
if not increase_log_steps:
|
367 |
+
self.log_steps = [self.batch_freq]
|
368 |
+
self.clamp = clamp
|
369 |
+
self.disabled = disabled
|
370 |
+
self.log_on_batch_idx = log_on_batch_idx
|
371 |
+
self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
|
372 |
+
self.log_first_step = log_first_step
|
373 |
+
|
374 |
+
@rank_zero_only
|
375 |
+
def _testtube(self, pl_module, images, batch_idx, split):
|
376 |
+
for k in images:
|
377 |
+
grid = torchvision.utils.make_grid(images[k])
|
378 |
+
grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
|
379 |
+
|
380 |
+
tag = f"{split}/{k}"
|
381 |
+
pl_module.logger.experiment.add_image(
|
382 |
+
tag, grid,
|
383 |
+
global_step=pl_module.global_step)
|
384 |
+
|
385 |
+
@rank_zero_only
|
386 |
+
def log_local(self, save_dir, split, images, prompts,
|
387 |
+
global_step, current_epoch, batch_idx):
|
388 |
+
root = os.path.join(save_dir, "images", split)
|
389 |
+
names = {"reals": "before", "inputs": "after", "reconstruction": "before-vq", "samples": "after-gen"}
|
390 |
+
# print(root)
|
391 |
+
for k in images:
|
392 |
+
grid = torchvision.utils.make_grid(images[k], nrow=8)
|
393 |
+
if self.rescale:
|
394 |
+
grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
|
395 |
+
grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
|
396 |
+
grid = grid.numpy()
|
397 |
+
grid = (grid * 255).astype(np.uint8)
|
398 |
+
filename = "gs-{:06}_e-{:06}_b-{:06}_{}.png".format(
|
399 |
+
global_step,
|
400 |
+
current_epoch,
|
401 |
+
batch_idx,
|
402 |
+
names[k])
|
403 |
+
path = os.path.join(root, filename)
|
404 |
+
os.makedirs(os.path.split(path)[0], exist_ok=True)
|
405 |
+
# print(path)
|
406 |
+
Image.fromarray(grid).save(path)
|
407 |
+
|
408 |
+
filename = "gs-{:06}_e-{:06}_b-{:06}_prompt.json".format(
|
409 |
+
global_step,
|
410 |
+
current_epoch,
|
411 |
+
batch_idx)
|
412 |
+
path = os.path.join(root, filename)
|
413 |
+
with open(path, "w") as f:
|
414 |
+
for p in prompts:
|
415 |
+
f.write(f"{json.dumps(p)}\n")
|
416 |
+
|
417 |
+
def log_img(self, pl_module, batch, batch_idx, split="train"):
|
418 |
+
check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
|
419 |
+
if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0
|
420 |
+
hasattr(pl_module, "log_images") and
|
421 |
+
callable(pl_module.log_images) and
|
422 |
+
self.max_images > 0) or (split == "val" and batch_idx == 0):
|
423 |
+
logger = type(pl_module.logger)
|
424 |
+
|
425 |
+
is_train = pl_module.training
|
426 |
+
if is_train:
|
427 |
+
pl_module.eval()
|
428 |
+
|
429 |
+
with torch.no_grad():
|
430 |
+
images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
|
431 |
+
|
432 |
+
prompts = batch["edit"]["c_crossattn"][:self.max_images]
|
433 |
+
prompts = [p for ps in all_gather(prompts) for p in ps]
|
434 |
+
|
435 |
+
for k in images:
|
436 |
+
N = min(images[k].shape[0], self.max_images)
|
437 |
+
images[k] = images[k][:N]
|
438 |
+
images[k] = torch.cat(all_gather(images[k][:N]))
|
439 |
+
if isinstance(images[k], torch.Tensor):
|
440 |
+
images[k] = images[k].detach().cpu()
|
441 |
+
if self.clamp:
|
442 |
+
images[k] = torch.clamp(images[k], -1., 1.)
|
443 |
+
|
444 |
+
self.log_local(pl_module.logger.save_dir, split, images, prompts,
|
445 |
+
pl_module.global_step, pl_module.current_epoch, batch_idx)
|
446 |
+
|
447 |
+
logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
|
448 |
+
logger_log_images(pl_module, images, pl_module.global_step, split)
|
449 |
+
|
450 |
+
if is_train:
|
451 |
+
pl_module.train()
|
452 |
+
|
453 |
+
def check_frequency(self, check_idx):
|
454 |
+
if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and (
|
455 |
+
check_idx > 0 or self.log_first_step):
|
456 |
+
if len(self.log_steps) > 0:
|
457 |
+
self.log_steps.pop(0)
|
458 |
+
return True
|
459 |
+
return False
|
460 |
+
|
461 |
+
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
|
462 |
+
if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
|
463 |
+
self.log_img(pl_module, batch, batch_idx, split="train")
|
464 |
+
|
465 |
+
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
|
466 |
+
if not self.disabled and pl_module.global_step > 0:
|
467 |
+
self.log_img(pl_module, batch, batch_idx, split="val")
|
468 |
+
if hasattr(pl_module, 'calibrate_grad_norm'):
|
469 |
+
if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0:
|
470 |
+
self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
|
471 |
+
|
472 |
+
|
473 |
+
class CUDACallback(Callback):
|
474 |
+
# see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
|
475 |
+
def on_train_epoch_start(self, trainer, pl_module):
|
476 |
+
# Reset the memory use counter
|
477 |
+
torch.cuda.reset_peak_memory_stats(trainer.root_gpu)
|
478 |
+
torch.cuda.synchronize(trainer.root_gpu)
|
479 |
+
self.start_time = time.time()
|
480 |
+
|
481 |
+
def on_train_epoch_end(self, trainer, pl_module, outputs):
|
482 |
+
torch.cuda.synchronize(trainer.root_gpu)
|
483 |
+
max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2 ** 20
|
484 |
+
epoch_time = time.time() - self.start_time
|
485 |
+
|
486 |
+
try:
|
487 |
+
max_memory = trainer.training_type_plugin.reduce(max_memory)
|
488 |
+
epoch_time = trainer.training_type_plugin.reduce(epoch_time)
|
489 |
+
|
490 |
+
rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
|
491 |
+
rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
|
492 |
+
except AttributeError:
|
493 |
+
pass
|
494 |
+
|
495 |
+
|
496 |
+
if __name__ == "__main__":
|
497 |
+
# custom parser to specify config files, train, test and debug mode,
|
498 |
+
# postfix, resume.
|
499 |
+
# `--key value` arguments are interpreted as arguments to the trainer.
|
500 |
+
# `nested.key=value` arguments are interpreted as config parameters.
|
501 |
+
# configs are merged from left-to-right followed by command line parameters.
|
502 |
+
|
503 |
+
# model:
|
504 |
+
# base_learning_rate: float
|
505 |
+
# target: path to lightning module
|
506 |
+
# params:
|
507 |
+
# key: value
|
508 |
+
# data:
|
509 |
+
# target: main.DataModuleFromConfig
|
510 |
+
# params:
|
511 |
+
# batch_size: int
|
512 |
+
# wrap: bool
|
513 |
+
# train:
|
514 |
+
# target: path to train dataset
|
515 |
+
# params:
|
516 |
+
# key: value
|
517 |
+
# validation:
|
518 |
+
# target: path to validation dataset
|
519 |
+
# params:
|
520 |
+
# key: value
|
521 |
+
# test:
|
522 |
+
# target: path to test dataset
|
523 |
+
# params:
|
524 |
+
# key: value
|
525 |
+
# lightning: (optional, has sane defaults and can be specified on cmdline)
|
526 |
+
# trainer:
|
527 |
+
# additional arguments to trainer
|
528 |
+
# logger:
|
529 |
+
# logger to instantiate
|
530 |
+
# modelcheckpoint:
|
531 |
+
# modelcheckpoint to instantiate
|
532 |
+
# callbacks:
|
533 |
+
# callback1:
|
534 |
+
# target: importpath
|
535 |
+
# params:
|
536 |
+
# key: value
|
537 |
+
|
538 |
+
now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
|
539 |
+
|
540 |
+
# add cwd for convenience and to make classes in this file available when
|
541 |
+
# running as `python main.py`
|
542 |
+
# (in particular `main.DataModuleFromConfig`)
|
543 |
+
sys.path.append(os.getcwd())
|
544 |
+
|
545 |
+
parser = get_parser()
|
546 |
+
parser = Trainer.add_argparse_args(parser)
|
547 |
+
|
548 |
+
opt, unknown = parser.parse_known_args()
|
549 |
+
|
550 |
+
assert opt.name
|
551 |
+
cfg_fname = os.path.split(opt.base[0])[-1]
|
552 |
+
cfg_name = os.path.splitext(cfg_fname)[0]
|
553 |
+
nowname = f"{cfg_name}_{opt.name}"
|
554 |
+
logdir = os.path.join(opt.logdir, nowname)
|
555 |
+
ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
|
556 |
+
|
557 |
+
if os.path.isfile(ckpt):
|
558 |
+
opt.resume_from_checkpoint = ckpt
|
559 |
+
base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
|
560 |
+
opt.base = base_configs + opt.base
|
561 |
+
_tmp = logdir.split("/")
|
562 |
+
nowname = _tmp[-1]
|
563 |
+
# By default, when finetuning from Stable Diffusion, we load the EMA-only checkpoint to initialize all weights.
|
564 |
+
# If resuming InstructPix2Pix from a finetuning checkpoint, instead load both EMA and non-EMA weights.
|
565 |
+
opt.model.params.load_ema = True
|
566 |
+
|
567 |
+
ckptdir = os.path.join(logdir, "checkpoints")
|
568 |
+
cfgdir = os.path.join(logdir, "configs")
|
569 |
+
|
570 |
+
os.makedirs(logdir, exist_ok=True)
|
571 |
+
os.makedirs(ckptdir, exist_ok=True)
|
572 |
+
os.makedirs(cfgdir, exist_ok=True)
|
573 |
+
|
574 |
+
try:
|
575 |
+
# init and save configs
|
576 |
+
configs = [OmegaConf.load(cfg) for cfg in opt.base]
|
577 |
+
cli = OmegaConf.from_dotlist(unknown)
|
578 |
+
config = OmegaConf.merge(*configs, cli)
|
579 |
+
lightning_config = config.pop("lightning", OmegaConf.create())
|
580 |
+
# merge trainer cli with config
|
581 |
+
trainer_config = lightning_config.get("trainer", OmegaConf.create())
|
582 |
+
# default to ddp
|
583 |
+
trainer_config["accelerator"] = "ddp"
|
584 |
+
for k in nondefault_trainer_args(opt):
|
585 |
+
trainer_config[k] = getattr(opt, k)
|
586 |
+
if not "gpus" in trainer_config:
|
587 |
+
del trainer_config["accelerator"]
|
588 |
+
cpu = True
|
589 |
+
else:
|
590 |
+
gpuinfo = trainer_config["gpus"]
|
591 |
+
print(f"Running on GPUs {gpuinfo}")
|
592 |
+
cpu = False
|
593 |
+
trainer_opt = argparse.Namespace(**trainer_config)
|
594 |
+
lightning_config.trainer = trainer_config
|
595 |
+
|
596 |
+
# model
|
597 |
+
model = instantiate_from_config(config.model)
|
598 |
+
|
599 |
+
# trainer and callbacks
|
600 |
+
trainer_kwargs = dict()
|
601 |
+
|
602 |
+
# default logger configs
|
603 |
+
default_logger_cfgs = {
|
604 |
+
"wandb": {
|
605 |
+
"target": "pytorch_lightning.loggers.WandbLogger",
|
606 |
+
"params": {
|
607 |
+
"name": nowname,
|
608 |
+
"save_dir": logdir,
|
609 |
+
"id": nowname,
|
610 |
+
}
|
611 |
+
},
|
612 |
+
"testtube": {
|
613 |
+
"target": "pytorch_lightning.loggers.TestTubeLogger",
|
614 |
+
"params": {
|
615 |
+
"name": "testtube",
|
616 |
+
"save_dir": logdir,
|
617 |
+
}
|
618 |
+
},
|
619 |
+
}
|
620 |
+
default_logger_cfg = default_logger_cfgs["wandb"]
|
621 |
+
if "logger" in lightning_config:
|
622 |
+
logger_cfg = lightning_config.logger
|
623 |
+
else:
|
624 |
+
logger_cfg = OmegaConf.create()
|
625 |
+
logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
|
626 |
+
trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
|
627 |
+
|
628 |
+
# modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
|
629 |
+
# specify which metric is used to determine best models
|
630 |
+
default_modelckpt_cfg = {
|
631 |
+
"target": "pytorch_lightning.callbacks.ModelCheckpoint",
|
632 |
+
"params": {
|
633 |
+
"dirpath": ckptdir,
|
634 |
+
"filename": "{epoch:06}",
|
635 |
+
"verbose": True,
|
636 |
+
"save_last": True,
|
637 |
+
}
|
638 |
+
}
|
639 |
+
|
640 |
+
if "modelcheckpoint" in lightning_config:
|
641 |
+
modelckpt_cfg = lightning_config.modelcheckpoint
|
642 |
+
else:
|
643 |
+
modelckpt_cfg = OmegaConf.create()
|
644 |
+
modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
|
645 |
+
print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")
|
646 |
+
if version.parse(pl.__version__) < version.parse('1.4.0'):
|
647 |
+
trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)
|
648 |
+
|
649 |
+
# add callback which sets up log directory
|
650 |
+
default_callbacks_cfg = {
|
651 |
+
"setup_callback": {
|
652 |
+
"target": "main.SetupCallback",
|
653 |
+
"params": {
|
654 |
+
"resume": opt.resume,
|
655 |
+
"now": now,
|
656 |
+
"logdir": logdir,
|
657 |
+
"ckptdir": ckptdir,
|
658 |
+
"cfgdir": cfgdir,
|
659 |
+
"config": config,
|
660 |
+
"lightning_config": lightning_config,
|
661 |
+
}
|
662 |
+
},
|
663 |
+
"image_logger": {
|
664 |
+
"target": "main.ImageLogger",
|
665 |
+
"params": {
|
666 |
+
"batch_frequency": 750,
|
667 |
+
"max_images": 4,
|
668 |
+
"clamp": True
|
669 |
+
}
|
670 |
+
},
|
671 |
+
"learning_rate_logger": {
|
672 |
+
"target": "main.LearningRateMonitor",
|
673 |
+
"params": {
|
674 |
+
"logging_interval": "step",
|
675 |
+
# "log_momentum": True
|
676 |
+
}
|
677 |
+
},
|
678 |
+
"cuda_callback": {
|
679 |
+
"target": "main.CUDACallback"
|
680 |
+
},
|
681 |
+
}
|
682 |
+
if version.parse(pl.__version__) >= version.parse('1.4.0'):
|
683 |
+
default_callbacks_cfg.update({'checkpoint_callback': modelckpt_cfg})
|
684 |
+
|
685 |
+
if "callbacks" in lightning_config:
|
686 |
+
callbacks_cfg = lightning_config.callbacks
|
687 |
+
else:
|
688 |
+
callbacks_cfg = OmegaConf.create()
|
689 |
+
|
690 |
+
print(
|
691 |
+
'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.')
|
692 |
+
default_metrics_over_trainsteps_ckpt_dict = {
|
693 |
+
'metrics_over_trainsteps_checkpoint': {
|
694 |
+
"target": 'pytorch_lightning.callbacks.ModelCheckpoint',
|
695 |
+
'params': {
|
696 |
+
"dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
|
697 |
+
"filename": "{epoch:06}-{step:09}",
|
698 |
+
"verbose": True,
|
699 |
+
'save_top_k': -1,
|
700 |
+
'every_n_train_steps': 1000,
|
701 |
+
'save_weights_only': True
|
702 |
+
}
|
703 |
+
}
|
704 |
+
}
|
705 |
+
default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)
|
706 |
+
|
707 |
+
callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
|
708 |
+
if 'ignore_keys_callback' in callbacks_cfg and hasattr(trainer_opt, 'resume_from_checkpoint'):
|
709 |
+
callbacks_cfg.ignore_keys_callback.params['ckpt_path'] = trainer_opt.resume_from_checkpoint
|
710 |
+
elif 'ignore_keys_callback' in callbacks_cfg:
|
711 |
+
del callbacks_cfg['ignore_keys_callback']
|
712 |
+
|
713 |
+
trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
|
714 |
+
|
715 |
+
trainer = Trainer.from_argparse_args(trainer_opt, plugins=DDPPlugin(find_unused_parameters=False), **trainer_kwargs)
|
716 |
+
trainer.logdir = logdir ###
|
717 |
+
|
718 |
+
# data
|
719 |
+
data = instantiate_from_config(config.data)
|
720 |
+
# NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
|
721 |
+
# calling these ourselves should not be necessary but it is.
|
722 |
+
# lightning still takes care of proper multiprocessing though
|
723 |
+
data.prepare_data()
|
724 |
+
data.setup()
|
725 |
+
print("#### Data #####")
|
726 |
+
for k in data.datasets:
|
727 |
+
print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")
|
728 |
+
|
729 |
+
# configure learning rate
|
730 |
+
bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
|
731 |
+
if not cpu:
|
732 |
+
ngpu = len(lightning_config.trainer.gpus.strip(",").split(','))
|
733 |
+
else:
|
734 |
+
ngpu = 1
|
735 |
+
if 'accumulate_grad_batches' in lightning_config.trainer:
|
736 |
+
accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
|
737 |
+
else:
|
738 |
+
accumulate_grad_batches = 1
|
739 |
+
print(f"accumulate_grad_batches = {accumulate_grad_batches}")
|
740 |
+
lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
|
741 |
+
if opt.scale_lr:
|
742 |
+
model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
|
743 |
+
print(
|
744 |
+
"Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
|
745 |
+
model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
|
746 |
+
else:
|
747 |
+
model.learning_rate = base_lr
|
748 |
+
print("++++ NOT USING LR SCALING ++++")
|
749 |
+
print(f"Setting learning rate to {model.learning_rate:.2e}")
|
750 |
+
|
751 |
+
|
752 |
+
# allow checkpointing via USR1
|
753 |
+
def melk(*args, **kwargs):
|
754 |
+
# run all checkpoint hooks
|
755 |
+
if trainer.global_rank == 0:
|
756 |
+
print("Summoning checkpoint.")
|
757 |
+
ckpt_path = os.path.join(ckptdir, "last.ckpt")
|
758 |
+
trainer.save_checkpoint(ckpt_path)
|
759 |
+
|
760 |
+
|
761 |
+
def divein(*args, **kwargs):
|
762 |
+
if trainer.global_rank == 0:
|
763 |
+
import pudb;
|
764 |
+
pudb.set_trace()
|
765 |
+
|
766 |
+
|
767 |
+
import signal
|
768 |
+
|
769 |
+
signal.signal(signal.SIGUSR1, melk)
|
770 |
+
signal.signal(signal.SIGUSR2, divein)
|
771 |
+
|
772 |
+
# run
|
773 |
+
if opt.train:
|
774 |
+
try:
|
775 |
+
trainer.fit(model, data)
|
776 |
+
except Exception:
|
777 |
+
melk()
|
778 |
+
raise
|
779 |
+
if not opt.no_test and not trainer.interrupted:
|
780 |
+
trainer.test(model, data)
|
781 |
+
except Exception:
|
782 |
+
if opt.debug and trainer.global_rank == 0:
|
783 |
+
try:
|
784 |
+
import pudb as debugger
|
785 |
+
except ImportError:
|
786 |
+
import pdb as debugger
|
787 |
+
debugger.post_mortem()
|
788 |
+
raise
|
789 |
+
finally:
|
790 |
+
# move newly created debug project to debug_runs
|
791 |
+
if opt.debug and not opt.resume and trainer.global_rank == 0:
|
792 |
+
dst, name = os.path.split(logdir)
|
793 |
+
dst = os.path.join(dst, "debug_runs", name)
|
794 |
+
os.makedirs(os.path.split(dst)[0], exist_ok=True)
|
795 |
+
os.rename(logdir, dst)
|
796 |
+
if trainer.global_rank == 0:
|
797 |
+
print(trainer.profiler.summary())
|
metrics/clip_similarity.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import clip
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from einops import rearrange
|
8 |
+
|
9 |
+
|
10 |
+
class ClipSimilarity(nn.Module):
|
11 |
+
def __init__(self, name: str = "ViT-L/14"):
|
12 |
+
super().__init__()
|
13 |
+
assert name in ("RN50", "RN101", "RN50x4", "RN50x16", "RN50x64", "ViT-B/32", "ViT-B/16", "ViT-L/14", "ViT-L/14@336px") # fmt: skip
|
14 |
+
self.size = {"RN50x4": 288, "RN50x16": 384, "RN50x64": 448, "ViT-L/14@336px": 336}.get(name, 224)
|
15 |
+
|
16 |
+
self.model, _ = clip.load(name, device="cpu", download_root="./")
|
17 |
+
self.model.eval().requires_grad_(False)
|
18 |
+
|
19 |
+
self.register_buffer("mean", torch.tensor((0.48145466, 0.4578275, 0.40821073)))
|
20 |
+
self.register_buffer("std", torch.tensor((0.26862954, 0.26130258, 0.27577711)))
|
21 |
+
|
22 |
+
def encode_text(self, text: list[str]) -> torch.Tensor:
|
23 |
+
text = clip.tokenize(text, truncate=True).to(next(self.parameters()).device)
|
24 |
+
text_features = self.model.encode_text(text)
|
25 |
+
text_features = text_features / text_features.norm(dim=1, keepdim=True)
|
26 |
+
return text_features
|
27 |
+
|
28 |
+
def encode_image(self, image: torch.Tensor) -> torch.Tensor: # Input images in range [0, 1].
|
29 |
+
image = F.interpolate(image.float(), size=self.size, mode="bicubic", align_corners=False)
|
30 |
+
image = image - rearrange(self.mean, "c -> 1 c 1 1")
|
31 |
+
image = image / rearrange(self.std, "c -> 1 c 1 1")
|
32 |
+
image_features = self.model.encode_image(image)
|
33 |
+
image_features = image_features / image_features.norm(dim=1, keepdim=True)
|
34 |
+
return image_features
|
35 |
+
|
36 |
+
def forward(
|
37 |
+
self, image_0: torch.Tensor, image_1: torch.Tensor, text_0: list[str], text_1: list[str]
|
38 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
39 |
+
image_features_0 = self.encode_image(image_0)
|
40 |
+
image_features_1 = self.encode_image(image_1)
|
41 |
+
text_features_0 = self.encode_text(text_0)
|
42 |
+
text_features_1 = self.encode_text(text_1)
|
43 |
+
sim_0 = F.cosine_similarity(image_features_0, text_features_0)
|
44 |
+
sim_1 = F.cosine_similarity(image_features_1, text_features_1)
|
45 |
+
sim_direction = F.cosine_similarity(image_features_1 - image_features_0, text_features_1 - text_features_0)
|
46 |
+
sim_image = F.cosine_similarity(image_features_0, image_features_1)
|
47 |
+
return sim_0, sim_1, sim_direction, sim_image
|
prompt_app.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from argparse import ArgumentParser
|
4 |
+
|
5 |
+
import datasets
|
6 |
+
import gradio as gr
|
7 |
+
import numpy as np
|
8 |
+
import openai
|
9 |
+
|
10 |
+
from dataset_creation.generate_txt_dataset import generate
|
11 |
+
|
12 |
+
|
13 |
+
def main(openai_model: str):
|
14 |
+
dataset = datasets.load_dataset("ChristophSchuhmann/improved_aesthetics_6.5plus", split="train")
|
15 |
+
captions = dataset[np.random.permutation(len(dataset))]["TEXT"]
|
16 |
+
index = 0
|
17 |
+
|
18 |
+
def click_random():
|
19 |
+
nonlocal index
|
20 |
+
output = captions[index]
|
21 |
+
index = (index + 1) % len(captions)
|
22 |
+
return output
|
23 |
+
|
24 |
+
def click_generate(input: str):
|
25 |
+
if input == "":
|
26 |
+
raise gr.Error("Input caption is missing!")
|
27 |
+
edit_output = generate(openai_model, input)
|
28 |
+
if edit_output is None:
|
29 |
+
return "Failed :(", "Failed :("
|
30 |
+
return edit_output
|
31 |
+
|
32 |
+
with gr.Blocks(css="footer {visibility: hidden}") as demo:
|
33 |
+
txt_input = gr.Textbox(lines=3, label="Input Caption", interactive=True, placeholder="Type image caption here...") # fmt: skip
|
34 |
+
txt_edit = gr.Textbox(lines=1, label="GPT-3 Instruction", interactive=False)
|
35 |
+
txt_output = gr.Textbox(lines=3, label="GPT3 Edited Caption", interactive=False)
|
36 |
+
|
37 |
+
with gr.Row():
|
38 |
+
clear_btn = gr.Button("Clear")
|
39 |
+
random_btn = gr.Button("Random Input")
|
40 |
+
generate_btn = gr.Button("Generate Instruction + Edited Caption")
|
41 |
+
|
42 |
+
clear_btn.click(fn=lambda: ("", "", ""), inputs=[], outputs=[txt_input, txt_edit, txt_output])
|
43 |
+
random_btn.click(fn=click_random, inputs=[], outputs=[txt_input])
|
44 |
+
generate_btn.click(fn=click_generate, inputs=[txt_input], outputs=[txt_edit, txt_output])
|
45 |
+
|
46 |
+
demo.launch(share=True)
|
47 |
+
|
48 |
+
|
49 |
+
if __name__ == "__main__":
|
50 |
+
parser = ArgumentParser()
|
51 |
+
parser.add_argument("openai-api-key", type=str)
|
52 |
+
parser.add_argument("openai-model", type=str)
|
53 |
+
args = parser.parse_args()
|
54 |
+
openai.api_key = args.openai_api_key
|
55 |
+
main(args.openai_model)
|
scripts/download_checkpoints.sh
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
|
4 |
+
|
5 |
+
mkdir -p $SCRIPT_DIR/../checkpoints
|
6 |
+
|
7 |
+
curl http://instruct-pix2pix.eecs.berkeley.edu/instruct-pix2pix-00-20000.ckpt -o $SCRIPT_DIR/../checkpoints/instruct-pix2pix-00-20000.ckpt
|
scripts/download_data.sh
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
|
4 |
+
|
5 |
+
mkdir -p $SCRIPT_DIR/../data
|
6 |
+
|
7 |
+
wget http://instruct-pix2pix.eecs.berkeley.edu/gpt-generated-prompts.jsonl -O $SCRIPT_DIR/../data/gpt-generated-prompts.jsonl
|
8 |
+
wget http://instruct-pix2pix.eecs.berkeley.edu/human-written-prompts.jsonl -O $SCRIPT_DIR/../data/human-written-prompts.jsonl
|
9 |
+
|
10 |
+
mkdir $SCRIPT_DIR/../data/$1
|
11 |
+
wget -A zip,json -r http://instruct-pix2pix.eecs.berkeley.edu/$1 -nd -P $SCRIPT_DIR/../data/$1
|
stable_diffusion/ldm/models/diffusion/ddpm_edit.py
ADDED
@@ -0,0 +1,1459 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
wild mixture of
|
3 |
+
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
|
4 |
+
https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
|
5 |
+
https://github.com/CompVis/taming-transformers
|
6 |
+
-- merci
|
7 |
+
"""
|
8 |
+
|
9 |
+
# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
|
10 |
+
# See more details in LICENSE.
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
import numpy as np
|
15 |
+
import pytorch_lightning as pl
|
16 |
+
from torch.optim.lr_scheduler import LambdaLR
|
17 |
+
from einops import rearrange, repeat
|
18 |
+
from contextlib import contextmanager
|
19 |
+
from functools import partial
|
20 |
+
from tqdm import tqdm
|
21 |
+
from torchvision.utils import make_grid
|
22 |
+
from pytorch_lightning.utilities.distributed import rank_zero_only
|
23 |
+
|
24 |
+
from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
|
25 |
+
from ldm.modules.ema import LitEma
|
26 |
+
from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
|
27 |
+
from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
|
28 |
+
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
|
29 |
+
from ldm.models.diffusion.ddim import DDIMSampler
|
30 |
+
|
31 |
+
|
32 |
+
__conditioning_keys__ = {'concat': 'c_concat',
|
33 |
+
'crossattn': 'c_crossattn',
|
34 |
+
'adm': 'y'}
|
35 |
+
|
36 |
+
|
37 |
+
def disabled_train(self, mode=True):
|
38 |
+
"""Overwrite model.train with this function to make sure train/eval mode
|
39 |
+
does not change anymore."""
|
40 |
+
return self
|
41 |
+
|
42 |
+
|
43 |
+
def uniform_on_device(r1, r2, shape, device):
|
44 |
+
return (r1 - r2) * torch.rand(*shape, device=device) + r2
|
45 |
+
|
46 |
+
|
47 |
+
class DDPM(pl.LightningModule):
|
48 |
+
# classic DDPM with Gaussian diffusion, in image space
|
49 |
+
def __init__(self,
|
50 |
+
unet_config,
|
51 |
+
timesteps=1000,
|
52 |
+
beta_schedule="linear",
|
53 |
+
loss_type="l2",
|
54 |
+
ckpt_path=None,
|
55 |
+
ignore_keys=[],
|
56 |
+
load_only_unet=False,
|
57 |
+
monitor="val/loss",
|
58 |
+
use_ema=True,
|
59 |
+
first_stage_key="image",
|
60 |
+
image_size=256,
|
61 |
+
channels=3,
|
62 |
+
log_every_t=100,
|
63 |
+
clip_denoised=True,
|
64 |
+
linear_start=1e-4,
|
65 |
+
linear_end=2e-2,
|
66 |
+
cosine_s=8e-3,
|
67 |
+
given_betas=None,
|
68 |
+
original_elbo_weight=0.,
|
69 |
+
v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
|
70 |
+
l_simple_weight=1.,
|
71 |
+
conditioning_key=None,
|
72 |
+
parameterization="eps", # all assuming fixed variance schedules
|
73 |
+
scheduler_config=None,
|
74 |
+
use_positional_encodings=False,
|
75 |
+
learn_logvar=False,
|
76 |
+
logvar_init=0.,
|
77 |
+
load_ema=True,
|
78 |
+
):
|
79 |
+
super().__init__()
|
80 |
+
assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
|
81 |
+
self.parameterization = parameterization
|
82 |
+
print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
|
83 |
+
self.cond_stage_model = None
|
84 |
+
self.clip_denoised = clip_denoised
|
85 |
+
self.log_every_t = log_every_t
|
86 |
+
self.first_stage_key = first_stage_key
|
87 |
+
self.image_size = image_size # try conv?
|
88 |
+
self.channels = channels
|
89 |
+
self.use_positional_encodings = use_positional_encodings
|
90 |
+
self.model = DiffusionWrapper(unet_config, conditioning_key)
|
91 |
+
count_params(self.model, verbose=True)
|
92 |
+
self.use_ema = use_ema
|
93 |
+
|
94 |
+
self.use_scheduler = scheduler_config is not None
|
95 |
+
if self.use_scheduler:
|
96 |
+
self.scheduler_config = scheduler_config
|
97 |
+
|
98 |
+
self.v_posterior = v_posterior
|
99 |
+
self.original_elbo_weight = original_elbo_weight
|
100 |
+
self.l_simple_weight = l_simple_weight
|
101 |
+
|
102 |
+
if monitor is not None:
|
103 |
+
self.monitor = monitor
|
104 |
+
|
105 |
+
if self.use_ema and load_ema:
|
106 |
+
self.model_ema = LitEma(self.model)
|
107 |
+
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
108 |
+
|
109 |
+
if ckpt_path is not None:
|
110 |
+
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
|
111 |
+
|
112 |
+
# If initialing from EMA-only checkpoint, create EMA model after loading.
|
113 |
+
if self.use_ema and not load_ema:
|
114 |
+
self.model_ema = LitEma(self.model)
|
115 |
+
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
116 |
+
|
117 |
+
self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
|
118 |
+
linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
|
119 |
+
|
120 |
+
self.loss_type = loss_type
|
121 |
+
|
122 |
+
self.learn_logvar = learn_logvar
|
123 |
+
self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
|
124 |
+
if self.learn_logvar:
|
125 |
+
self.logvar = nn.Parameter(self.logvar, requires_grad=True)
|
126 |
+
|
127 |
+
|
128 |
+
def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
|
129 |
+
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
130 |
+
if exists(given_betas):
|
131 |
+
betas = given_betas
|
132 |
+
else:
|
133 |
+
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
|
134 |
+
cosine_s=cosine_s)
|
135 |
+
alphas = 1. - betas
|
136 |
+
alphas_cumprod = np.cumprod(alphas, axis=0)
|
137 |
+
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
|
138 |
+
|
139 |
+
timesteps, = betas.shape
|
140 |
+
self.num_timesteps = int(timesteps)
|
141 |
+
self.linear_start = linear_start
|
142 |
+
self.linear_end = linear_end
|
143 |
+
assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
|
144 |
+
|
145 |
+
to_torch = partial(torch.tensor, dtype=torch.float32)
|
146 |
+
|
147 |
+
self.register_buffer('betas', to_torch(betas))
|
148 |
+
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
149 |
+
self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
|
150 |
+
|
151 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
152 |
+
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
|
153 |
+
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
|
154 |
+
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
|
155 |
+
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
|
156 |
+
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
|
157 |
+
|
158 |
+
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
159 |
+
posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
|
160 |
+
1. - alphas_cumprod) + self.v_posterior * betas
|
161 |
+
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
162 |
+
self.register_buffer('posterior_variance', to_torch(posterior_variance))
|
163 |
+
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
164 |
+
self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
|
165 |
+
self.register_buffer('posterior_mean_coef1', to_torch(
|
166 |
+
betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
|
167 |
+
self.register_buffer('posterior_mean_coef2', to_torch(
|
168 |
+
(1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
|
169 |
+
|
170 |
+
if self.parameterization == "eps":
|
171 |
+
lvlb_weights = self.betas ** 2 / (
|
172 |
+
2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
|
173 |
+
elif self.parameterization == "x0":
|
174 |
+
lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
|
175 |
+
else:
|
176 |
+
raise NotImplementedError("mu not supported")
|
177 |
+
# TODO how to choose this term
|
178 |
+
lvlb_weights[0] = lvlb_weights[1]
|
179 |
+
self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
|
180 |
+
assert not torch.isnan(self.lvlb_weights).all()
|
181 |
+
|
182 |
+
@contextmanager
|
183 |
+
def ema_scope(self, context=None):
|
184 |
+
if self.use_ema:
|
185 |
+
self.model_ema.store(self.model.parameters())
|
186 |
+
self.model_ema.copy_to(self.model)
|
187 |
+
if context is not None:
|
188 |
+
print(f"{context}: Switched to EMA weights")
|
189 |
+
try:
|
190 |
+
yield None
|
191 |
+
finally:
|
192 |
+
if self.use_ema:
|
193 |
+
self.model_ema.restore(self.model.parameters())
|
194 |
+
if context is not None:
|
195 |
+
print(f"{context}: Restored training weights")
|
196 |
+
|
197 |
+
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
|
198 |
+
sd = torch.load(path, map_location="cpu")
|
199 |
+
if "state_dict" in list(sd.keys()):
|
200 |
+
sd = sd["state_dict"]
|
201 |
+
keys = list(sd.keys())
|
202 |
+
|
203 |
+
# Our model adds additional channels to the first layer to condition on an input image.
|
204 |
+
# For the first layer, copy existing channel weights and initialize new channel weights to zero.
|
205 |
+
input_keys = [
|
206 |
+
"model.diffusion_model.input_blocks.0.0.weight",
|
207 |
+
"model_ema.diffusion_modelinput_blocks00weight",
|
208 |
+
]
|
209 |
+
|
210 |
+
self_sd = self.state_dict()
|
211 |
+
for input_key in input_keys:
|
212 |
+
if input_key not in sd or input_key not in self_sd:
|
213 |
+
continue
|
214 |
+
|
215 |
+
input_weight = self_sd[input_key]
|
216 |
+
|
217 |
+
if input_weight.size() != sd[input_key].size():
|
218 |
+
print(f"Manual init: {input_key}")
|
219 |
+
input_weight.zero_()
|
220 |
+
input_weight[:, :4, :, :].copy_(sd[input_key])
|
221 |
+
ignore_keys.append(input_key)
|
222 |
+
|
223 |
+
for k in keys:
|
224 |
+
for ik in ignore_keys:
|
225 |
+
if k.startswith(ik):
|
226 |
+
print("Deleting key {} from state_dict.".format(k))
|
227 |
+
del sd[k]
|
228 |
+
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
|
229 |
+
sd, strict=False)
|
230 |
+
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
|
231 |
+
if len(missing) > 0:
|
232 |
+
print(f"Missing Keys: {missing}")
|
233 |
+
if len(unexpected) > 0:
|
234 |
+
print(f"Unexpected Keys: {unexpected}")
|
235 |
+
|
236 |
+
def q_mean_variance(self, x_start, t):
|
237 |
+
"""
|
238 |
+
Get the distribution q(x_t | x_0).
|
239 |
+
:param x_start: the [N x C x ...] tensor of noiseless inputs.
|
240 |
+
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
|
241 |
+
:return: A tuple (mean, variance, log_variance), all of x_start's shape.
|
242 |
+
"""
|
243 |
+
mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
|
244 |
+
variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
|
245 |
+
log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
|
246 |
+
return mean, variance, log_variance
|
247 |
+
|
248 |
+
def predict_start_from_noise(self, x_t, t, noise):
|
249 |
+
return (
|
250 |
+
extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
|
251 |
+
extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
|
252 |
+
)
|
253 |
+
|
254 |
+
def q_posterior(self, x_start, x_t, t):
|
255 |
+
posterior_mean = (
|
256 |
+
extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
|
257 |
+
extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
258 |
+
)
|
259 |
+
posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
|
260 |
+
posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
|
261 |
+
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
262 |
+
|
263 |
+
def p_mean_variance(self, x, t, clip_denoised: bool):
|
264 |
+
model_out = self.model(x, t)
|
265 |
+
if self.parameterization == "eps":
|
266 |
+
x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
|
267 |
+
elif self.parameterization == "x0":
|
268 |
+
x_recon = model_out
|
269 |
+
if clip_denoised:
|
270 |
+
x_recon.clamp_(-1., 1.)
|
271 |
+
|
272 |
+
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
273 |
+
return model_mean, posterior_variance, posterior_log_variance
|
274 |
+
|
275 |
+
@torch.no_grad()
|
276 |
+
def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
|
277 |
+
b, *_, device = *x.shape, x.device
|
278 |
+
model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
|
279 |
+
noise = noise_like(x.shape, device, repeat_noise)
|
280 |
+
# no noise when t == 0
|
281 |
+
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
282 |
+
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
283 |
+
|
284 |
+
@torch.no_grad()
|
285 |
+
def p_sample_loop(self, shape, return_intermediates=False):
|
286 |
+
device = self.betas.device
|
287 |
+
b = shape[0]
|
288 |
+
img = torch.randn(shape, device=device)
|
289 |
+
intermediates = [img]
|
290 |
+
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
|
291 |
+
img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
|
292 |
+
clip_denoised=self.clip_denoised)
|
293 |
+
if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
|
294 |
+
intermediates.append(img)
|
295 |
+
if return_intermediates:
|
296 |
+
return img, intermediates
|
297 |
+
return img
|
298 |
+
|
299 |
+
@torch.no_grad()
|
300 |
+
def sample(self, batch_size=16, return_intermediates=False):
|
301 |
+
image_size = self.image_size
|
302 |
+
channels = self.channels
|
303 |
+
return self.p_sample_loop((batch_size, channels, image_size, image_size),
|
304 |
+
return_intermediates=return_intermediates)
|
305 |
+
|
306 |
+
def q_sample(self, x_start, t, noise=None):
|
307 |
+
noise = default(noise, lambda: torch.randn_like(x_start))
|
308 |
+
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
|
309 |
+
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
|
310 |
+
|
311 |
+
def get_loss(self, pred, target, mean=True):
|
312 |
+
if self.loss_type == 'l1':
|
313 |
+
loss = (target - pred).abs()
|
314 |
+
if mean:
|
315 |
+
loss = loss.mean()
|
316 |
+
elif self.loss_type == 'l2':
|
317 |
+
if mean:
|
318 |
+
loss = torch.nn.functional.mse_loss(target, pred)
|
319 |
+
else:
|
320 |
+
loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
|
321 |
+
else:
|
322 |
+
raise NotImplementedError("unknown loss type '{loss_type}'")
|
323 |
+
|
324 |
+
return loss
|
325 |
+
|
326 |
+
def p_losses(self, x_start, t, noise=None):
|
327 |
+
noise = default(noise, lambda: torch.randn_like(x_start))
|
328 |
+
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
|
329 |
+
model_out = self.model(x_noisy, t)
|
330 |
+
|
331 |
+
loss_dict = {}
|
332 |
+
if self.parameterization == "eps":
|
333 |
+
target = noise
|
334 |
+
elif self.parameterization == "x0":
|
335 |
+
target = x_start
|
336 |
+
else:
|
337 |
+
raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")
|
338 |
+
|
339 |
+
loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
|
340 |
+
|
341 |
+
log_prefix = 'train' if self.training else 'val'
|
342 |
+
|
343 |
+
loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
|
344 |
+
loss_simple = loss.mean() * self.l_simple_weight
|
345 |
+
|
346 |
+
loss_vlb = (self.lvlb_weights[t] * loss).mean()
|
347 |
+
loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
|
348 |
+
|
349 |
+
loss = loss_simple + self.original_elbo_weight * loss_vlb
|
350 |
+
|
351 |
+
loss_dict.update({f'{log_prefix}/loss': loss})
|
352 |
+
|
353 |
+
return loss, loss_dict
|
354 |
+
|
355 |
+
def forward(self, x, *args, **kwargs):
|
356 |
+
# b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
|
357 |
+
# assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
|
358 |
+
t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
|
359 |
+
return self.p_losses(x, t, *args, **kwargs)
|
360 |
+
|
361 |
+
def get_input(self, batch, k):
|
362 |
+
return batch[k]
|
363 |
+
|
364 |
+
def shared_step(self, batch):
|
365 |
+
x = self.get_input(batch, self.first_stage_key)
|
366 |
+
loss, loss_dict = self(x)
|
367 |
+
return loss, loss_dict
|
368 |
+
|
369 |
+
def training_step(self, batch, batch_idx):
|
370 |
+
loss, loss_dict = self.shared_step(batch)
|
371 |
+
|
372 |
+
self.log_dict(loss_dict, prog_bar=True,
|
373 |
+
logger=True, on_step=True, on_epoch=True)
|
374 |
+
|
375 |
+
self.log("global_step", self.global_step,
|
376 |
+
prog_bar=True, logger=True, on_step=True, on_epoch=False)
|
377 |
+
|
378 |
+
if self.use_scheduler:
|
379 |
+
lr = self.optimizers().param_groups[0]['lr']
|
380 |
+
self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
|
381 |
+
|
382 |
+
return loss
|
383 |
+
|
384 |
+
@torch.no_grad()
|
385 |
+
def validation_step(self, batch, batch_idx):
|
386 |
+
_, loss_dict_no_ema = self.shared_step(batch)
|
387 |
+
with self.ema_scope():
|
388 |
+
_, loss_dict_ema = self.shared_step(batch)
|
389 |
+
loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
|
390 |
+
self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
|
391 |
+
self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
|
392 |
+
|
393 |
+
def on_train_batch_end(self, *args, **kwargs):
|
394 |
+
if self.use_ema:
|
395 |
+
self.model_ema(self.model)
|
396 |
+
|
397 |
+
def _get_rows_from_list(self, samples):
|
398 |
+
n_imgs_per_row = len(samples)
|
399 |
+
denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
|
400 |
+
denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
|
401 |
+
denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
|
402 |
+
return denoise_grid
|
403 |
+
|
404 |
+
@torch.no_grad()
|
405 |
+
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
|
406 |
+
log = dict()
|
407 |
+
x = self.get_input(batch, self.first_stage_key)
|
408 |
+
N = min(x.shape[0], N)
|
409 |
+
n_row = min(x.shape[0], n_row)
|
410 |
+
x = x.to(self.device)[:N]
|
411 |
+
log["inputs"] = x
|
412 |
+
|
413 |
+
# get diffusion row
|
414 |
+
diffusion_row = list()
|
415 |
+
x_start = x[:n_row]
|
416 |
+
|
417 |
+
for t in range(self.num_timesteps):
|
418 |
+
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
|
419 |
+
t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
|
420 |
+
t = t.to(self.device).long()
|
421 |
+
noise = torch.randn_like(x_start)
|
422 |
+
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
|
423 |
+
diffusion_row.append(x_noisy)
|
424 |
+
|
425 |
+
log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
|
426 |
+
|
427 |
+
if sample:
|
428 |
+
# get denoise row
|
429 |
+
with self.ema_scope("Plotting"):
|
430 |
+
samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
|
431 |
+
|
432 |
+
log["samples"] = samples
|
433 |
+
log["denoise_row"] = self._get_rows_from_list(denoise_row)
|
434 |
+
|
435 |
+
if return_keys:
|
436 |
+
if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
|
437 |
+
return log
|
438 |
+
else:
|
439 |
+
return {key: log[key] for key in return_keys}
|
440 |
+
return log
|
441 |
+
|
442 |
+
def configure_optimizers(self):
|
443 |
+
lr = self.learning_rate
|
444 |
+
params = list(self.model.parameters())
|
445 |
+
if self.learn_logvar:
|
446 |
+
params = params + [self.logvar]
|
447 |
+
opt = torch.optim.AdamW(params, lr=lr)
|
448 |
+
return opt
|
449 |
+
|
450 |
+
|
451 |
+
class LatentDiffusion(DDPM):
|
452 |
+
"""main class"""
|
453 |
+
def __init__(self,
|
454 |
+
first_stage_config,
|
455 |
+
cond_stage_config,
|
456 |
+
num_timesteps_cond=None,
|
457 |
+
cond_stage_key="image",
|
458 |
+
cond_stage_trainable=False,
|
459 |
+
concat_mode=True,
|
460 |
+
cond_stage_forward=None,
|
461 |
+
conditioning_key=None,
|
462 |
+
scale_factor=1.0,
|
463 |
+
scale_by_std=False,
|
464 |
+
load_ema=True,
|
465 |
+
*args, **kwargs):
|
466 |
+
self.num_timesteps_cond = default(num_timesteps_cond, 1)
|
467 |
+
self.scale_by_std = scale_by_std
|
468 |
+
assert self.num_timesteps_cond <= kwargs['timesteps']
|
469 |
+
# for backwards compatibility after implementation of DiffusionWrapper
|
470 |
+
if conditioning_key is None:
|
471 |
+
conditioning_key = 'concat' if concat_mode else 'crossattn'
|
472 |
+
if cond_stage_config == '__is_unconditional__':
|
473 |
+
conditioning_key = None
|
474 |
+
ckpt_path = kwargs.pop("ckpt_path", None)
|
475 |
+
ignore_keys = kwargs.pop("ignore_keys", [])
|
476 |
+
super().__init__(conditioning_key=conditioning_key, *args, load_ema=load_ema, **kwargs)
|
477 |
+
self.concat_mode = concat_mode
|
478 |
+
self.cond_stage_trainable = cond_stage_trainable
|
479 |
+
self.cond_stage_key = cond_stage_key
|
480 |
+
try:
|
481 |
+
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
|
482 |
+
except:
|
483 |
+
self.num_downs = 0
|
484 |
+
if not scale_by_std:
|
485 |
+
self.scale_factor = scale_factor
|
486 |
+
else:
|
487 |
+
self.register_buffer('scale_factor', torch.tensor(scale_factor))
|
488 |
+
self.instantiate_first_stage(first_stage_config)
|
489 |
+
self.instantiate_cond_stage(cond_stage_config)
|
490 |
+
self.cond_stage_forward = cond_stage_forward
|
491 |
+
self.clip_denoised = False
|
492 |
+
self.bbox_tokenizer = None
|
493 |
+
|
494 |
+
self.restarted_from_ckpt = False
|
495 |
+
if ckpt_path is not None:
|
496 |
+
self.init_from_ckpt(ckpt_path, ignore_keys)
|
497 |
+
self.restarted_from_ckpt = True
|
498 |
+
|
499 |
+
if self.use_ema and not load_ema:
|
500 |
+
self.model_ema = LitEma(self.model)
|
501 |
+
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
502 |
+
|
503 |
+
def make_cond_schedule(self, ):
|
504 |
+
self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
|
505 |
+
ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
|
506 |
+
self.cond_ids[:self.num_timesteps_cond] = ids
|
507 |
+
|
508 |
+
@rank_zero_only
|
509 |
+
@torch.no_grad()
|
510 |
+
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
|
511 |
+
# only for very first batch
|
512 |
+
if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
|
513 |
+
assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
|
514 |
+
# set rescale weight to 1./std of encodings
|
515 |
+
print("### USING STD-RESCALING ###")
|
516 |
+
x = super().get_input(batch, self.first_stage_key)
|
517 |
+
x = x.to(self.device)
|
518 |
+
encoder_posterior = self.encode_first_stage(x)
|
519 |
+
z = self.get_first_stage_encoding(encoder_posterior).detach()
|
520 |
+
del self.scale_factor
|
521 |
+
self.register_buffer('scale_factor', 1. / z.flatten().std())
|
522 |
+
print(f"setting self.scale_factor to {self.scale_factor}")
|
523 |
+
print("### USING STD-RESCALING ###")
|
524 |
+
|
525 |
+
def register_schedule(self,
|
526 |
+
given_betas=None, beta_schedule="linear", timesteps=1000,
|
527 |
+
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
528 |
+
super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
|
529 |
+
|
530 |
+
self.shorten_cond_schedule = self.num_timesteps_cond > 1
|
531 |
+
if self.shorten_cond_schedule:
|
532 |
+
self.make_cond_schedule()
|
533 |
+
|
534 |
+
def instantiate_first_stage(self, config):
|
535 |
+
model = instantiate_from_config(config)
|
536 |
+
self.first_stage_model = model.eval()
|
537 |
+
self.first_stage_model.train = disabled_train
|
538 |
+
for param in self.first_stage_model.parameters():
|
539 |
+
param.requires_grad = False
|
540 |
+
|
541 |
+
def instantiate_cond_stage(self, config):
|
542 |
+
if not self.cond_stage_trainable:
|
543 |
+
if config == "__is_first_stage__":
|
544 |
+
print("Using first stage also as cond stage.")
|
545 |
+
self.cond_stage_model = self.first_stage_model
|
546 |
+
elif config == "__is_unconditional__":
|
547 |
+
print(f"Training {self.__class__.__name__} as an unconditional model.")
|
548 |
+
self.cond_stage_model = None
|
549 |
+
# self.be_unconditional = True
|
550 |
+
else:
|
551 |
+
model = instantiate_from_config(config)
|
552 |
+
self.cond_stage_model = model.eval()
|
553 |
+
self.cond_stage_model.train = disabled_train
|
554 |
+
for param in self.cond_stage_model.parameters():
|
555 |
+
param.requires_grad = False
|
556 |
+
else:
|
557 |
+
assert config != '__is_first_stage__'
|
558 |
+
assert config != '__is_unconditional__'
|
559 |
+
model = instantiate_from_config(config)
|
560 |
+
self.cond_stage_model = model
|
561 |
+
|
562 |
+
def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
|
563 |
+
denoise_row = []
|
564 |
+
for zd in tqdm(samples, desc=desc):
|
565 |
+
denoise_row.append(self.decode_first_stage(zd.to(self.device),
|
566 |
+
force_not_quantize=force_no_decoder_quantization))
|
567 |
+
n_imgs_per_row = len(denoise_row)
|
568 |
+
denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
|
569 |
+
denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
|
570 |
+
denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
|
571 |
+
denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
|
572 |
+
return denoise_grid
|
573 |
+
|
574 |
+
def get_first_stage_encoding(self, encoder_posterior):
|
575 |
+
if isinstance(encoder_posterior, DiagonalGaussianDistribution):
|
576 |
+
z = encoder_posterior.sample()
|
577 |
+
elif isinstance(encoder_posterior, torch.Tensor):
|
578 |
+
z = encoder_posterior
|
579 |
+
else:
|
580 |
+
raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
|
581 |
+
return self.scale_factor * z
|
582 |
+
|
583 |
+
def get_learned_conditioning(self, c):
|
584 |
+
if self.cond_stage_forward is None:
|
585 |
+
if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
|
586 |
+
c = self.cond_stage_model.encode(c)
|
587 |
+
if isinstance(c, DiagonalGaussianDistribution):
|
588 |
+
c = c.mode()
|
589 |
+
else:
|
590 |
+
c = self.cond_stage_model(c)
|
591 |
+
else:
|
592 |
+
assert hasattr(self.cond_stage_model, self.cond_stage_forward)
|
593 |
+
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
|
594 |
+
return c
|
595 |
+
|
596 |
+
def meshgrid(self, h, w):
|
597 |
+
y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
|
598 |
+
x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
|
599 |
+
|
600 |
+
arr = torch.cat([y, x], dim=-1)
|
601 |
+
return arr
|
602 |
+
|
603 |
+
def delta_border(self, h, w):
|
604 |
+
"""
|
605 |
+
:param h: height
|
606 |
+
:param w: width
|
607 |
+
:return: normalized distance to image border,
|
608 |
+
wtith min distance = 0 at border and max dist = 0.5 at image center
|
609 |
+
"""
|
610 |
+
lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
|
611 |
+
arr = self.meshgrid(h, w) / lower_right_corner
|
612 |
+
dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
|
613 |
+
dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
|
614 |
+
edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
|
615 |
+
return edge_dist
|
616 |
+
|
617 |
+
def get_weighting(self, h, w, Ly, Lx, device):
|
618 |
+
weighting = self.delta_border(h, w)
|
619 |
+
weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
|
620 |
+
self.split_input_params["clip_max_weight"], )
|
621 |
+
weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
|
622 |
+
|
623 |
+
if self.split_input_params["tie_braker"]:
|
624 |
+
L_weighting = self.delta_border(Ly, Lx)
|
625 |
+
L_weighting = torch.clip(L_weighting,
|
626 |
+
self.split_input_params["clip_min_tie_weight"],
|
627 |
+
self.split_input_params["clip_max_tie_weight"])
|
628 |
+
|
629 |
+
L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
|
630 |
+
weighting = weighting * L_weighting
|
631 |
+
return weighting
|
632 |
+
|
633 |
+
def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
|
634 |
+
"""
|
635 |
+
:param x: img of size (bs, c, h, w)
|
636 |
+
:return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
|
637 |
+
"""
|
638 |
+
bs, nc, h, w = x.shape
|
639 |
+
|
640 |
+
# number of crops in image
|
641 |
+
Ly = (h - kernel_size[0]) // stride[0] + 1
|
642 |
+
Lx = (w - kernel_size[1]) // stride[1] + 1
|
643 |
+
|
644 |
+
if uf == 1 and df == 1:
|
645 |
+
fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
|
646 |
+
unfold = torch.nn.Unfold(**fold_params)
|
647 |
+
|
648 |
+
fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
|
649 |
+
|
650 |
+
weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
|
651 |
+
normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
|
652 |
+
weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
|
653 |
+
|
654 |
+
elif uf > 1 and df == 1:
|
655 |
+
fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
|
656 |
+
unfold = torch.nn.Unfold(**fold_params)
|
657 |
+
|
658 |
+
fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
|
659 |
+
dilation=1, padding=0,
|
660 |
+
stride=(stride[0] * uf, stride[1] * uf))
|
661 |
+
fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
|
662 |
+
|
663 |
+
weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
|
664 |
+
normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
|
665 |
+
weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
|
666 |
+
|
667 |
+
elif df > 1 and uf == 1:
|
668 |
+
fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
|
669 |
+
unfold = torch.nn.Unfold(**fold_params)
|
670 |
+
|
671 |
+
fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
|
672 |
+
dilation=1, padding=0,
|
673 |
+
stride=(stride[0] // df, stride[1] // df))
|
674 |
+
fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
|
675 |
+
|
676 |
+
weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
|
677 |
+
normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
|
678 |
+
weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
|
679 |
+
|
680 |
+
else:
|
681 |
+
raise NotImplementedError
|
682 |
+
|
683 |
+
return fold, unfold, normalization, weighting
|
684 |
+
|
685 |
+
@torch.no_grad()
|
686 |
+
def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
|
687 |
+
cond_key=None, return_original_cond=False, bs=None, uncond=0.05):
|
688 |
+
x = super().get_input(batch, k)
|
689 |
+
if bs is not None:
|
690 |
+
x = x[:bs]
|
691 |
+
x = x.to(self.device)
|
692 |
+
encoder_posterior = self.encode_first_stage(x)
|
693 |
+
z = self.get_first_stage_encoding(encoder_posterior).detach()
|
694 |
+
cond_key = cond_key or self.cond_stage_key
|
695 |
+
xc = super().get_input(batch, cond_key)
|
696 |
+
if bs is not None:
|
697 |
+
xc["c_crossattn"] = xc["c_crossattn"][:bs]
|
698 |
+
xc["c_concat"] = xc["c_concat"][:bs]
|
699 |
+
cond = {}
|
700 |
+
|
701 |
+
# To support classifier-free guidance, randomly drop out only text conditioning 5%, only image conditioning 5%, and both 5%.
|
702 |
+
random = torch.rand(x.size(0), device=x.device)
|
703 |
+
prompt_mask = rearrange(random < 2 * uncond, "n -> n 1 1")
|
704 |
+
input_mask = 1 - rearrange((random >= uncond).float() * (random < 3 * uncond).float(), "n -> n 1 1 1")
|
705 |
+
|
706 |
+
null_prompt = self.get_learned_conditioning([""])
|
707 |
+
cond["c_crossattn"] = [torch.where(prompt_mask, null_prompt, self.get_learned_conditioning(xc["c_crossattn"]).detach())]
|
708 |
+
cond["c_concat"] = [input_mask * self.encode_first_stage((xc["c_concat"].to(self.device))).mode().detach()]
|
709 |
+
|
710 |
+
out = [z, cond]
|
711 |
+
if return_first_stage_outputs:
|
712 |
+
xrec = self.decode_first_stage(z)
|
713 |
+
out.extend([x, xrec])
|
714 |
+
if return_original_cond:
|
715 |
+
out.append(xc)
|
716 |
+
return out
|
717 |
+
|
718 |
+
@torch.no_grad()
|
719 |
+
def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
|
720 |
+
if predict_cids:
|
721 |
+
if z.dim() == 4:
|
722 |
+
z = torch.argmax(z.exp(), dim=1).long()
|
723 |
+
z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
|
724 |
+
z = rearrange(z, 'b h w c -> b c h w').contiguous()
|
725 |
+
|
726 |
+
z = 1. / self.scale_factor * z
|
727 |
+
|
728 |
+
if hasattr(self, "split_input_params"):
|
729 |
+
if self.split_input_params["patch_distributed_vq"]:
|
730 |
+
ks = self.split_input_params["ks"] # eg. (128, 128)
|
731 |
+
stride = self.split_input_params["stride"] # eg. (64, 64)
|
732 |
+
uf = self.split_input_params["vqf"]
|
733 |
+
bs, nc, h, w = z.shape
|
734 |
+
if ks[0] > h or ks[1] > w:
|
735 |
+
ks = (min(ks[0], h), min(ks[1], w))
|
736 |
+
print("reducing Kernel")
|
737 |
+
|
738 |
+
if stride[0] > h or stride[1] > w:
|
739 |
+
stride = (min(stride[0], h), min(stride[1], w))
|
740 |
+
print("reducing stride")
|
741 |
+
|
742 |
+
fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
|
743 |
+
|
744 |
+
z = unfold(z) # (bn, nc * prod(**ks), L)
|
745 |
+
# 1. Reshape to img shape
|
746 |
+
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
|
747 |
+
|
748 |
+
# 2. apply model loop over last dim
|
749 |
+
if isinstance(self.first_stage_model, VQModelInterface):
|
750 |
+
output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
|
751 |
+
force_not_quantize=predict_cids or force_not_quantize)
|
752 |
+
for i in range(z.shape[-1])]
|
753 |
+
else:
|
754 |
+
|
755 |
+
output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
|
756 |
+
for i in range(z.shape[-1])]
|
757 |
+
|
758 |
+
o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
|
759 |
+
o = o * weighting
|
760 |
+
# Reverse 1. reshape to img shape
|
761 |
+
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
|
762 |
+
# stitch crops together
|
763 |
+
decoded = fold(o)
|
764 |
+
decoded = decoded / normalization # norm is shape (1, 1, h, w)
|
765 |
+
return decoded
|
766 |
+
else:
|
767 |
+
if isinstance(self.first_stage_model, VQModelInterface):
|
768 |
+
return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
|
769 |
+
else:
|
770 |
+
return self.first_stage_model.decode(z)
|
771 |
+
|
772 |
+
else:
|
773 |
+
if isinstance(self.first_stage_model, VQModelInterface):
|
774 |
+
return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
|
775 |
+
else:
|
776 |
+
return self.first_stage_model.decode(z)
|
777 |
+
|
778 |
+
# same as above but without decorator
|
779 |
+
def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
|
780 |
+
if predict_cids:
|
781 |
+
if z.dim() == 4:
|
782 |
+
z = torch.argmax(z.exp(), dim=1).long()
|
783 |
+
z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
|
784 |
+
z = rearrange(z, 'b h w c -> b c h w').contiguous()
|
785 |
+
|
786 |
+
z = 1. / self.scale_factor * z
|
787 |
+
|
788 |
+
if hasattr(self, "split_input_params"):
|
789 |
+
if self.split_input_params["patch_distributed_vq"]:
|
790 |
+
ks = self.split_input_params["ks"] # eg. (128, 128)
|
791 |
+
stride = self.split_input_params["stride"] # eg. (64, 64)
|
792 |
+
uf = self.split_input_params["vqf"]
|
793 |
+
bs, nc, h, w = z.shape
|
794 |
+
if ks[0] > h or ks[1] > w:
|
795 |
+
ks = (min(ks[0], h), min(ks[1], w))
|
796 |
+
print("reducing Kernel")
|
797 |
+
|
798 |
+
if stride[0] > h or stride[1] > w:
|
799 |
+
stride = (min(stride[0], h), min(stride[1], w))
|
800 |
+
print("reducing stride")
|
801 |
+
|
802 |
+
fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
|
803 |
+
|
804 |
+
z = unfold(z) # (bn, nc * prod(**ks), L)
|
805 |
+
# 1. Reshape to img shape
|
806 |
+
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
|
807 |
+
|
808 |
+
# 2. apply model loop over last dim
|
809 |
+
if isinstance(self.first_stage_model, VQModelInterface):
|
810 |
+
output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
|
811 |
+
force_not_quantize=predict_cids or force_not_quantize)
|
812 |
+
for i in range(z.shape[-1])]
|
813 |
+
else:
|
814 |
+
|
815 |
+
output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
|
816 |
+
for i in range(z.shape[-1])]
|
817 |
+
|
818 |
+
o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
|
819 |
+
o = o * weighting
|
820 |
+
# Reverse 1. reshape to img shape
|
821 |
+
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
|
822 |
+
# stitch crops together
|
823 |
+
decoded = fold(o)
|
824 |
+
decoded = decoded / normalization # norm is shape (1, 1, h, w)
|
825 |
+
return decoded
|
826 |
+
else:
|
827 |
+
if isinstance(self.first_stage_model, VQModelInterface):
|
828 |
+
return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
|
829 |
+
else:
|
830 |
+
return self.first_stage_model.decode(z)
|
831 |
+
|
832 |
+
else:
|
833 |
+
if isinstance(self.first_stage_model, VQModelInterface):
|
834 |
+
return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
|
835 |
+
else:
|
836 |
+
return self.first_stage_model.decode(z)
|
837 |
+
|
838 |
+
@torch.no_grad()
|
839 |
+
def encode_first_stage(self, x):
|
840 |
+
if hasattr(self, "split_input_params"):
|
841 |
+
if self.split_input_params["patch_distributed_vq"]:
|
842 |
+
ks = self.split_input_params["ks"] # eg. (128, 128)
|
843 |
+
stride = self.split_input_params["stride"] # eg. (64, 64)
|
844 |
+
df = self.split_input_params["vqf"]
|
845 |
+
self.split_input_params['original_image_size'] = x.shape[-2:]
|
846 |
+
bs, nc, h, w = x.shape
|
847 |
+
if ks[0] > h or ks[1] > w:
|
848 |
+
ks = (min(ks[0], h), min(ks[1], w))
|
849 |
+
print("reducing Kernel")
|
850 |
+
|
851 |
+
if stride[0] > h or stride[1] > w:
|
852 |
+
stride = (min(stride[0], h), min(stride[1], w))
|
853 |
+
print("reducing stride")
|
854 |
+
|
855 |
+
fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
|
856 |
+
z = unfold(x) # (bn, nc * prod(**ks), L)
|
857 |
+
# Reshape to img shape
|
858 |
+
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
|
859 |
+
|
860 |
+
output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
|
861 |
+
for i in range(z.shape[-1])]
|
862 |
+
|
863 |
+
o = torch.stack(output_list, axis=-1)
|
864 |
+
o = o * weighting
|
865 |
+
|
866 |
+
# Reverse reshape to img shape
|
867 |
+
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
|
868 |
+
# stitch crops together
|
869 |
+
decoded = fold(o)
|
870 |
+
decoded = decoded / normalization
|
871 |
+
return decoded
|
872 |
+
|
873 |
+
else:
|
874 |
+
return self.first_stage_model.encode(x)
|
875 |
+
else:
|
876 |
+
return self.first_stage_model.encode(x)
|
877 |
+
|
878 |
+
def shared_step(self, batch, **kwargs):
|
879 |
+
x, c = self.get_input(batch, self.first_stage_key)
|
880 |
+
loss = self(x, c)
|
881 |
+
return loss
|
882 |
+
|
883 |
+
def forward(self, x, c, *args, **kwargs):
|
884 |
+
t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
|
885 |
+
if self.model.conditioning_key is not None:
|
886 |
+
assert c is not None
|
887 |
+
if self.cond_stage_trainable:
|
888 |
+
c = self.get_learned_conditioning(c)
|
889 |
+
if self.shorten_cond_schedule: # TODO: drop this option
|
890 |
+
tc = self.cond_ids[t].to(self.device)
|
891 |
+
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
|
892 |
+
return self.p_losses(x, c, t, *args, **kwargs)
|
893 |
+
|
894 |
+
def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
|
895 |
+
def rescale_bbox(bbox):
|
896 |
+
x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
|
897 |
+
y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
|
898 |
+
w = min(bbox[2] / crop_coordinates[2], 1 - x0)
|
899 |
+
h = min(bbox[3] / crop_coordinates[3], 1 - y0)
|
900 |
+
return x0, y0, w, h
|
901 |
+
|
902 |
+
return [rescale_bbox(b) for b in bboxes]
|
903 |
+
|
904 |
+
def apply_model(self, x_noisy, t, cond, return_ids=False):
|
905 |
+
|
906 |
+
if isinstance(cond, dict):
|
907 |
+
# hybrid case, cond is exptected to be a dict
|
908 |
+
pass
|
909 |
+
else:
|
910 |
+
if not isinstance(cond, list):
|
911 |
+
cond = [cond]
|
912 |
+
key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
|
913 |
+
cond = {key: cond}
|
914 |
+
|
915 |
+
if hasattr(self, "split_input_params"):
|
916 |
+
assert len(cond) == 1 # todo can only deal with one conditioning atm
|
917 |
+
assert not return_ids
|
918 |
+
ks = self.split_input_params["ks"] # eg. (128, 128)
|
919 |
+
stride = self.split_input_params["stride"] # eg. (64, 64)
|
920 |
+
|
921 |
+
h, w = x_noisy.shape[-2:]
|
922 |
+
|
923 |
+
fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)
|
924 |
+
|
925 |
+
z = unfold(x_noisy) # (bn, nc * prod(**ks), L)
|
926 |
+
# Reshape to img shape
|
927 |
+
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
|
928 |
+
z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
|
929 |
+
|
930 |
+
if self.cond_stage_key in ["image", "LR_image", "segmentation",
|
931 |
+
'bbox_img'] and self.model.conditioning_key: # todo check for completeness
|
932 |
+
c_key = next(iter(cond.keys())) # get key
|
933 |
+
c = next(iter(cond.values())) # get value
|
934 |
+
assert (len(c) == 1) # todo extend to list with more than one elem
|
935 |
+
c = c[0] # get element
|
936 |
+
|
937 |
+
c = unfold(c)
|
938 |
+
c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L )
|
939 |
+
|
940 |
+
cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
|
941 |
+
|
942 |
+
elif self.cond_stage_key == 'coordinates_bbox':
|
943 |
+
assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size'
|
944 |
+
|
945 |
+
# assuming padding of unfold is always 0 and its dilation is always 1
|
946 |
+
n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
|
947 |
+
full_img_h, full_img_w = self.split_input_params['original_image_size']
|
948 |
+
# as we are operating on latents, we need the factor from the original image size to the
|
949 |
+
# spatial latent size to properly rescale the crops for regenerating the bbox annotations
|
950 |
+
num_downs = self.first_stage_model.encoder.num_resolutions - 1
|
951 |
+
rescale_latent = 2 ** (num_downs)
|
952 |
+
|
953 |
+
# get top left postions of patches as conforming for the bbbox tokenizer, therefore we
|
954 |
+
# need to rescale the tl patch coordinates to be in between (0,1)
|
955 |
+
tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
|
956 |
+
rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
|
957 |
+
for patch_nr in range(z.shape[-1])]
|
958 |
+
|
959 |
+
# patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
|
960 |
+
patch_limits = [(x_tl, y_tl,
|
961 |
+
rescale_latent * ks[0] / full_img_w,
|
962 |
+
rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
|
963 |
+
# patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]
|
964 |
+
|
965 |
+
# tokenize crop coordinates for the bounding boxes of the respective patches
|
966 |
+
patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
|
967 |
+
for bbox in patch_limits] # list of length l with tensors of shape (1, 2)
|
968 |
+
print(patch_limits_tknzd[0].shape)
|
969 |
+
# cut tknzd crop position from conditioning
|
970 |
+
assert isinstance(cond, dict), 'cond must be dict to be fed into model'
|
971 |
+
cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
|
972 |
+
print(cut_cond.shape)
|
973 |
+
|
974 |
+
adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
|
975 |
+
adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
|
976 |
+
print(adapted_cond.shape)
|
977 |
+
adapted_cond = self.get_learned_conditioning(adapted_cond)
|
978 |
+
print(adapted_cond.shape)
|
979 |
+
adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
|
980 |
+
print(adapted_cond.shape)
|
981 |
+
|
982 |
+
cond_list = [{'c_crossattn': [e]} for e in adapted_cond]
|
983 |
+
|
984 |
+
else:
|
985 |
+
cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient
|
986 |
+
|
987 |
+
# apply model by loop over crops
|
988 |
+
output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
|
989 |
+
assert not isinstance(output_list[0],
|
990 |
+
tuple) # todo cant deal with multiple model outputs check this never happens
|
991 |
+
|
992 |
+
o = torch.stack(output_list, axis=-1)
|
993 |
+
o = o * weighting
|
994 |
+
# Reverse reshape to img shape
|
995 |
+
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
|
996 |
+
# stitch crops together
|
997 |
+
x_recon = fold(o) / normalization
|
998 |
+
|
999 |
+
else:
|
1000 |
+
x_recon = self.model(x_noisy, t, **cond)
|
1001 |
+
|
1002 |
+
if isinstance(x_recon, tuple) and not return_ids:
|
1003 |
+
return x_recon[0]
|
1004 |
+
else:
|
1005 |
+
return x_recon
|
1006 |
+
|
1007 |
+
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
|
1008 |
+
return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
|
1009 |
+
extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
|
1010 |
+
|
1011 |
+
def _prior_bpd(self, x_start):
|
1012 |
+
"""
|
1013 |
+
Get the prior KL term for the variational lower-bound, measured in
|
1014 |
+
bits-per-dim.
|
1015 |
+
This term can't be optimized, as it only depends on the encoder.
|
1016 |
+
:param x_start: the [N x C x ...] tensor of inputs.
|
1017 |
+
:return: a batch of [N] KL values (in bits), one per batch element.
|
1018 |
+
"""
|
1019 |
+
batch_size = x_start.shape[0]
|
1020 |
+
t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
|
1021 |
+
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
|
1022 |
+
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
|
1023 |
+
return mean_flat(kl_prior) / np.log(2.0)
|
1024 |
+
|
1025 |
+
def p_losses(self, x_start, cond, t, noise=None):
|
1026 |
+
noise = default(noise, lambda: torch.randn_like(x_start))
|
1027 |
+
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
|
1028 |
+
model_output = self.apply_model(x_noisy, t, cond)
|
1029 |
+
|
1030 |
+
loss_dict = {}
|
1031 |
+
prefix = 'train' if self.training else 'val'
|
1032 |
+
|
1033 |
+
if self.parameterization == "x0":
|
1034 |
+
target = x_start
|
1035 |
+
elif self.parameterization == "eps":
|
1036 |
+
target = noise
|
1037 |
+
else:
|
1038 |
+
raise NotImplementedError()
|
1039 |
+
|
1040 |
+
loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
|
1041 |
+
loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
|
1042 |
+
|
1043 |
+
logvar_t = self.logvar[t].to(self.device)
|
1044 |
+
loss = loss_simple / torch.exp(logvar_t) + logvar_t
|
1045 |
+
# loss = loss_simple / torch.exp(self.logvar) + self.logvar
|
1046 |
+
if self.learn_logvar:
|
1047 |
+
loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
|
1048 |
+
loss_dict.update({'logvar': self.logvar.data.mean()})
|
1049 |
+
|
1050 |
+
loss = self.l_simple_weight * loss.mean()
|
1051 |
+
|
1052 |
+
loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
|
1053 |
+
loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
|
1054 |
+
loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
|
1055 |
+
loss += (self.original_elbo_weight * loss_vlb)
|
1056 |
+
loss_dict.update({f'{prefix}/loss': loss})
|
1057 |
+
|
1058 |
+
return loss, loss_dict
|
1059 |
+
|
1060 |
+
def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
|
1061 |
+
return_x0=False, score_corrector=None, corrector_kwargs=None):
|
1062 |
+
t_in = t
|
1063 |
+
model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
|
1064 |
+
|
1065 |
+
if score_corrector is not None:
|
1066 |
+
assert self.parameterization == "eps"
|
1067 |
+
model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
|
1068 |
+
|
1069 |
+
if return_codebook_ids:
|
1070 |
+
model_out, logits = model_out
|
1071 |
+
|
1072 |
+
if self.parameterization == "eps":
|
1073 |
+
x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
|
1074 |
+
elif self.parameterization == "x0":
|
1075 |
+
x_recon = model_out
|
1076 |
+
else:
|
1077 |
+
raise NotImplementedError()
|
1078 |
+
|
1079 |
+
if clip_denoised:
|
1080 |
+
x_recon.clamp_(-1., 1.)
|
1081 |
+
if quantize_denoised:
|
1082 |
+
x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
|
1083 |
+
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
1084 |
+
if return_codebook_ids:
|
1085 |
+
return model_mean, posterior_variance, posterior_log_variance, logits
|
1086 |
+
elif return_x0:
|
1087 |
+
return model_mean, posterior_variance, posterior_log_variance, x_recon
|
1088 |
+
else:
|
1089 |
+
return model_mean, posterior_variance, posterior_log_variance
|
1090 |
+
|
1091 |
+
@torch.no_grad()
|
1092 |
+
def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
|
1093 |
+
return_codebook_ids=False, quantize_denoised=False, return_x0=False,
|
1094 |
+
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
|
1095 |
+
b, *_, device = *x.shape, x.device
|
1096 |
+
outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
|
1097 |
+
return_codebook_ids=return_codebook_ids,
|
1098 |
+
quantize_denoised=quantize_denoised,
|
1099 |
+
return_x0=return_x0,
|
1100 |
+
score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
|
1101 |
+
if return_codebook_ids:
|
1102 |
+
raise DeprecationWarning("Support dropped.")
|
1103 |
+
model_mean, _, model_log_variance, logits = outputs
|
1104 |
+
elif return_x0:
|
1105 |
+
model_mean, _, model_log_variance, x0 = outputs
|
1106 |
+
else:
|
1107 |
+
model_mean, _, model_log_variance = outputs
|
1108 |
+
|
1109 |
+
noise = noise_like(x.shape, device, repeat_noise) * temperature
|
1110 |
+
if noise_dropout > 0.:
|
1111 |
+
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
1112 |
+
# no noise when t == 0
|
1113 |
+
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
1114 |
+
|
1115 |
+
if return_codebook_ids:
|
1116 |
+
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
|
1117 |
+
if return_x0:
|
1118 |
+
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
|
1119 |
+
else:
|
1120 |
+
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
1121 |
+
|
1122 |
+
@torch.no_grad()
|
1123 |
+
def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
|
1124 |
+
img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
|
1125 |
+
score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
|
1126 |
+
log_every_t=None):
|
1127 |
+
if not log_every_t:
|
1128 |
+
log_every_t = self.log_every_t
|
1129 |
+
timesteps = self.num_timesteps
|
1130 |
+
if batch_size is not None:
|
1131 |
+
b = batch_size if batch_size is not None else shape[0]
|
1132 |
+
shape = [batch_size] + list(shape)
|
1133 |
+
else:
|
1134 |
+
b = batch_size = shape[0]
|
1135 |
+
if x_T is None:
|
1136 |
+
img = torch.randn(shape, device=self.device)
|
1137 |
+
else:
|
1138 |
+
img = x_T
|
1139 |
+
intermediates = []
|
1140 |
+
if cond is not None:
|
1141 |
+
if isinstance(cond, dict):
|
1142 |
+
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
|
1143 |
+
list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
|
1144 |
+
else:
|
1145 |
+
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
|
1146 |
+
|
1147 |
+
if start_T is not None:
|
1148 |
+
timesteps = min(timesteps, start_T)
|
1149 |
+
iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
|
1150 |
+
total=timesteps) if verbose else reversed(
|
1151 |
+
range(0, timesteps))
|
1152 |
+
if type(temperature) == float:
|
1153 |
+
temperature = [temperature] * timesteps
|
1154 |
+
|
1155 |
+
for i in iterator:
|
1156 |
+
ts = torch.full((b,), i, device=self.device, dtype=torch.long)
|
1157 |
+
if self.shorten_cond_schedule:
|
1158 |
+
assert self.model.conditioning_key != 'hybrid'
|
1159 |
+
tc = self.cond_ids[ts].to(cond.device)
|
1160 |
+
cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
|
1161 |
+
|
1162 |
+
img, x0_partial = self.p_sample(img, cond, ts,
|
1163 |
+
clip_denoised=self.clip_denoised,
|
1164 |
+
quantize_denoised=quantize_denoised, return_x0=True,
|
1165 |
+
temperature=temperature[i], noise_dropout=noise_dropout,
|
1166 |
+
score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
|
1167 |
+
if mask is not None:
|
1168 |
+
assert x0 is not None
|
1169 |
+
img_orig = self.q_sample(x0, ts)
|
1170 |
+
img = img_orig * mask + (1. - mask) * img
|
1171 |
+
|
1172 |
+
if i % log_every_t == 0 or i == timesteps - 1:
|
1173 |
+
intermediates.append(x0_partial)
|
1174 |
+
if callback: callback(i)
|
1175 |
+
if img_callback: img_callback(img, i)
|
1176 |
+
return img, intermediates
|
1177 |
+
|
1178 |
+
@torch.no_grad()
|
1179 |
+
def p_sample_loop(self, cond, shape, return_intermediates=False,
|
1180 |
+
x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
|
1181 |
+
mask=None, x0=None, img_callback=None, start_T=None,
|
1182 |
+
log_every_t=None):
|
1183 |
+
|
1184 |
+
if not log_every_t:
|
1185 |
+
log_every_t = self.log_every_t
|
1186 |
+
device = self.betas.device
|
1187 |
+
b = shape[0]
|
1188 |
+
if x_T is None:
|
1189 |
+
img = torch.randn(shape, device=device)
|
1190 |
+
else:
|
1191 |
+
img = x_T
|
1192 |
+
|
1193 |
+
intermediates = [img]
|
1194 |
+
if timesteps is None:
|
1195 |
+
timesteps = self.num_timesteps
|
1196 |
+
|
1197 |
+
if start_T is not None:
|
1198 |
+
timesteps = min(timesteps, start_T)
|
1199 |
+
iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
|
1200 |
+
range(0, timesteps))
|
1201 |
+
|
1202 |
+
if mask is not None:
|
1203 |
+
assert x0 is not None
|
1204 |
+
assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
|
1205 |
+
|
1206 |
+
for i in iterator:
|
1207 |
+
ts = torch.full((b,), i, device=device, dtype=torch.long)
|
1208 |
+
if self.shorten_cond_schedule:
|
1209 |
+
assert self.model.conditioning_key != 'hybrid'
|
1210 |
+
tc = self.cond_ids[ts].to(cond.device)
|
1211 |
+
cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
|
1212 |
+
|
1213 |
+
img = self.p_sample(img, cond, ts,
|
1214 |
+
clip_denoised=self.clip_denoised,
|
1215 |
+
quantize_denoised=quantize_denoised)
|
1216 |
+
if mask is not None:
|
1217 |
+
img_orig = self.q_sample(x0, ts)
|
1218 |
+
img = img_orig * mask + (1. - mask) * img
|
1219 |
+
|
1220 |
+
if i % log_every_t == 0 or i == timesteps - 1:
|
1221 |
+
intermediates.append(img)
|
1222 |
+
if callback: callback(i)
|
1223 |
+
if img_callback: img_callback(img, i)
|
1224 |
+
|
1225 |
+
if return_intermediates:
|
1226 |
+
return img, intermediates
|
1227 |
+
return img
|
1228 |
+
|
1229 |
+
@torch.no_grad()
|
1230 |
+
def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
|
1231 |
+
verbose=True, timesteps=None, quantize_denoised=False,
|
1232 |
+
mask=None, x0=None, shape=None,**kwargs):
|
1233 |
+
if shape is None:
|
1234 |
+
shape = (batch_size, self.channels, self.image_size, self.image_size)
|
1235 |
+
if cond is not None:
|
1236 |
+
if isinstance(cond, dict):
|
1237 |
+
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
|
1238 |
+
list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
|
1239 |
+
else:
|
1240 |
+
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
|
1241 |
+
return self.p_sample_loop(cond,
|
1242 |
+
shape,
|
1243 |
+
return_intermediates=return_intermediates, x_T=x_T,
|
1244 |
+
verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
|
1245 |
+
mask=mask, x0=x0)
|
1246 |
+
|
1247 |
+
@torch.no_grad()
|
1248 |
+
def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
|
1249 |
+
|
1250 |
+
if ddim:
|
1251 |
+
ddim_sampler = DDIMSampler(self)
|
1252 |
+
shape = (self.channels, self.image_size, self.image_size)
|
1253 |
+
samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
|
1254 |
+
shape,cond,verbose=False,**kwargs)
|
1255 |
+
|
1256 |
+
else:
|
1257 |
+
samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
|
1258 |
+
return_intermediates=True,**kwargs)
|
1259 |
+
|
1260 |
+
return samples, intermediates
|
1261 |
+
|
1262 |
+
|
1263 |
+
@torch.no_grad()
|
1264 |
+
def log_images(self, batch, N=4, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
|
1265 |
+
quantize_denoised=True, inpaint=False, plot_denoise_rows=False, plot_progressive_rows=False,
|
1266 |
+
plot_diffusion_rows=False, **kwargs):
|
1267 |
+
|
1268 |
+
use_ddim = False
|
1269 |
+
|
1270 |
+
log = dict()
|
1271 |
+
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
|
1272 |
+
return_first_stage_outputs=True,
|
1273 |
+
force_c_encode=True,
|
1274 |
+
return_original_cond=True,
|
1275 |
+
bs=N, uncond=0)
|
1276 |
+
N = min(x.shape[0], N)
|
1277 |
+
n_row = min(x.shape[0], n_row)
|
1278 |
+
log["inputs"] = x
|
1279 |
+
log["reals"] = xc["c_concat"]
|
1280 |
+
log["reconstruction"] = xrec
|
1281 |
+
if self.model.conditioning_key is not None:
|
1282 |
+
if hasattr(self.cond_stage_model, "decode"):
|
1283 |
+
xc = self.cond_stage_model.decode(c)
|
1284 |
+
log["conditioning"] = xc
|
1285 |
+
elif self.cond_stage_key in ["caption"]:
|
1286 |
+
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])
|
1287 |
+
log["conditioning"] = xc
|
1288 |
+
elif self.cond_stage_key == 'class_label':
|
1289 |
+
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
|
1290 |
+
log['conditioning'] = xc
|
1291 |
+
elif isimage(xc):
|
1292 |
+
log["conditioning"] = xc
|
1293 |
+
if ismap(xc):
|
1294 |
+
log["original_conditioning"] = self.to_rgb(xc)
|
1295 |
+
|
1296 |
+
if plot_diffusion_rows:
|
1297 |
+
# get diffusion row
|
1298 |
+
diffusion_row = list()
|
1299 |
+
z_start = z[:n_row]
|
1300 |
+
for t in range(self.num_timesteps):
|
1301 |
+
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
|
1302 |
+
t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
|
1303 |
+
t = t.to(self.device).long()
|
1304 |
+
noise = torch.randn_like(z_start)
|
1305 |
+
z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
|
1306 |
+
diffusion_row.append(self.decode_first_stage(z_noisy))
|
1307 |
+
|
1308 |
+
diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
|
1309 |
+
diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
|
1310 |
+
diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
|
1311 |
+
diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
|
1312 |
+
log["diffusion_row"] = diffusion_grid
|
1313 |
+
|
1314 |
+
if sample:
|
1315 |
+
# get denoise row
|
1316 |
+
with self.ema_scope("Plotting"):
|
1317 |
+
samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
|
1318 |
+
ddim_steps=ddim_steps,eta=ddim_eta)
|
1319 |
+
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
|
1320 |
+
x_samples = self.decode_first_stage(samples)
|
1321 |
+
log["samples"] = x_samples
|
1322 |
+
if plot_denoise_rows:
|
1323 |
+
denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
|
1324 |
+
log["denoise_row"] = denoise_grid
|
1325 |
+
|
1326 |
+
if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
|
1327 |
+
self.first_stage_model, IdentityFirstStage):
|
1328 |
+
# also display when quantizing x0 while sampling
|
1329 |
+
with self.ema_scope("Plotting Quantized Denoised"):
|
1330 |
+
samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
|
1331 |
+
ddim_steps=ddim_steps,eta=ddim_eta,
|
1332 |
+
quantize_denoised=True)
|
1333 |
+
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
|
1334 |
+
# quantize_denoised=True)
|
1335 |
+
x_samples = self.decode_first_stage(samples.to(self.device))
|
1336 |
+
log["samples_x0_quantized"] = x_samples
|
1337 |
+
|
1338 |
+
if inpaint:
|
1339 |
+
# make a simple center square
|
1340 |
+
b, h, w = z.shape[0], z.shape[2], z.shape[3]
|
1341 |
+
mask = torch.ones(N, h, w).to(self.device)
|
1342 |
+
# zeros will be filled in
|
1343 |
+
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
|
1344 |
+
mask = mask[:, None, ...]
|
1345 |
+
with self.ema_scope("Plotting Inpaint"):
|
1346 |
+
|
1347 |
+
samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
|
1348 |
+
ddim_steps=ddim_steps, x0=z[:N], mask=mask)
|
1349 |
+
x_samples = self.decode_first_stage(samples.to(self.device))
|
1350 |
+
log["samples_inpainting"] = x_samples
|
1351 |
+
log["mask"] = mask
|
1352 |
+
|
1353 |
+
# outpaint
|
1354 |
+
with self.ema_scope("Plotting Outpaint"):
|
1355 |
+
samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
|
1356 |
+
ddim_steps=ddim_steps, x0=z[:N], mask=mask)
|
1357 |
+
x_samples = self.decode_first_stage(samples.to(self.device))
|
1358 |
+
log["samples_outpainting"] = x_samples
|
1359 |
+
|
1360 |
+
if plot_progressive_rows:
|
1361 |
+
with self.ema_scope("Plotting Progressives"):
|
1362 |
+
img, progressives = self.progressive_denoising(c,
|
1363 |
+
shape=(self.channels, self.image_size, self.image_size),
|
1364 |
+
batch_size=N)
|
1365 |
+
prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
|
1366 |
+
log["progressive_row"] = prog_row
|
1367 |
+
|
1368 |
+
if return_keys:
|
1369 |
+
if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
|
1370 |
+
return log
|
1371 |
+
else:
|
1372 |
+
return {key: log[key] for key in return_keys}
|
1373 |
+
return log
|
1374 |
+
|
1375 |
+
def configure_optimizers(self):
|
1376 |
+
lr = self.learning_rate
|
1377 |
+
params = list(self.model.parameters())
|
1378 |
+
if self.cond_stage_trainable:
|
1379 |
+
print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
|
1380 |
+
params = params + list(self.cond_stage_model.parameters())
|
1381 |
+
if self.learn_logvar:
|
1382 |
+
print('Diffusion model optimizing logvar')
|
1383 |
+
params.append(self.logvar)
|
1384 |
+
opt = torch.optim.AdamW(params, lr=lr)
|
1385 |
+
if self.use_scheduler:
|
1386 |
+
assert 'target' in self.scheduler_config
|
1387 |
+
scheduler = instantiate_from_config(self.scheduler_config)
|
1388 |
+
|
1389 |
+
print("Setting up LambdaLR scheduler...")
|
1390 |
+
scheduler = [
|
1391 |
+
{
|
1392 |
+
'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
|
1393 |
+
'interval': 'step',
|
1394 |
+
'frequency': 1
|
1395 |
+
}]
|
1396 |
+
return [opt], scheduler
|
1397 |
+
return opt
|
1398 |
+
|
1399 |
+
@torch.no_grad()
|
1400 |
+
def to_rgb(self, x):
|
1401 |
+
x = x.float()
|
1402 |
+
if not hasattr(self, "colorize"):
|
1403 |
+
self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
|
1404 |
+
x = nn.functional.conv2d(x, weight=self.colorize)
|
1405 |
+
x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
|
1406 |
+
return x
|
1407 |
+
|
1408 |
+
|
1409 |
+
class DiffusionWrapper(pl.LightningModule):
|
1410 |
+
def __init__(self, diff_model_config, conditioning_key):
|
1411 |
+
super().__init__()
|
1412 |
+
self.diffusion_model = instantiate_from_config(diff_model_config)
|
1413 |
+
self.conditioning_key = conditioning_key
|
1414 |
+
assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']
|
1415 |
+
|
1416 |
+
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
|
1417 |
+
if self.conditioning_key is None:
|
1418 |
+
out = self.diffusion_model(x, t)
|
1419 |
+
elif self.conditioning_key == 'concat':
|
1420 |
+
xc = torch.cat([x] + c_concat, dim=1)
|
1421 |
+
out = self.diffusion_model(xc, t)
|
1422 |
+
elif self.conditioning_key == 'crossattn':
|
1423 |
+
cc = torch.cat(c_crossattn, 1)
|
1424 |
+
out = self.diffusion_model(x, t, context=cc)
|
1425 |
+
elif self.conditioning_key == 'hybrid':
|
1426 |
+
xc = torch.cat([x] + c_concat, dim=1)
|
1427 |
+
cc = torch.cat(c_crossattn, 1)
|
1428 |
+
out = self.diffusion_model(xc, t, context=cc)
|
1429 |
+
elif self.conditioning_key == 'adm':
|
1430 |
+
cc = c_crossattn[0]
|
1431 |
+
out = self.diffusion_model(x, t, y=cc)
|
1432 |
+
else:
|
1433 |
+
raise NotImplementedError()
|
1434 |
+
|
1435 |
+
return out
|
1436 |
+
|
1437 |
+
|
1438 |
+
class Layout2ImgDiffusion(LatentDiffusion):
|
1439 |
+
# TODO: move all layout-specific hacks to this class
|
1440 |
+
def __init__(self, cond_stage_key, *args, **kwargs):
|
1441 |
+
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
|
1442 |
+
super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
|
1443 |
+
|
1444 |
+
def log_images(self, batch, N=8, *args, **kwargs):
|
1445 |
+
logs = super().log_images(batch=batch, N=N, *args, **kwargs)
|
1446 |
+
|
1447 |
+
key = 'train' if self.training else 'validation'
|
1448 |
+
dset = self.trainer.datamodule.datasets[key]
|
1449 |
+
mapper = dset.conditional_builders[self.cond_stage_key]
|
1450 |
+
|
1451 |
+
bbox_imgs = []
|
1452 |
+
map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno))
|
1453 |
+
for tknzd_bbox in batch[self.cond_stage_key][:N]:
|
1454 |
+
bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))
|
1455 |
+
bbox_imgs.append(bboximg)
|
1456 |
+
|
1457 |
+
cond_img = torch.stack(bbox_imgs, dim=0)
|
1458 |
+
logs['bbox_image'] = cond_img
|
1459 |
+
return logs
|
stable_diffusion/ldm/modules/attention.py
CHANGED
@@ -1,3 +1,6 @@
|
|
|
|
|
|
|
|
1 |
from inspect import isfunction
|
2 |
import math
|
3 |
import torch
|
@@ -89,7 +92,7 @@ class LinearAttention(nn.Module):
|
|
89 |
b, c, h, w = x.shape
|
90 |
qkv = self.to_qkv(x)
|
91 |
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
|
92 |
-
k = k.softmax(dim=-1)
|
93 |
context = torch.einsum('bhdn,bhen->bhde', k, v)
|
94 |
out = torch.einsum('bhde,bhdn->bhen', context, q)
|
95 |
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
|
@@ -167,7 +170,11 @@ class CrossAttention(nn.Module):
|
|
167 |
nn.Dropout(dropout)
|
168 |
)
|
169 |
|
|
|
|
|
170 |
def forward(self, x, context=None, mask=None):
|
|
|
|
|
171 |
h = self.heads
|
172 |
|
173 |
q = self.to_q(x)
|
@@ -179,6 +186,13 @@ class CrossAttention(nn.Module):
|
|
179 |
|
180 |
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
181 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
182 |
if exists(mask):
|
183 |
mask = rearrange(mask, 'b ... -> b (...)')
|
184 |
max_neg_value = -torch.finfo(sim.dtype).max
|
@@ -258,4 +272,4 @@ class SpatialTransformer(nn.Module):
|
|
258 |
x = block(x, context=context)
|
259 |
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
|
260 |
x = self.proj_out(x)
|
261 |
-
return x + x_in
|
|
|
1 |
+
# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
|
2 |
+
# See more details in LICENSE.
|
3 |
+
|
4 |
from inspect import isfunction
|
5 |
import math
|
6 |
import torch
|
|
|
92 |
b, c, h, w = x.shape
|
93 |
qkv = self.to_qkv(x)
|
94 |
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
|
95 |
+
k = k.softmax(dim=-1)
|
96 |
context = torch.einsum('bhdn,bhen->bhde', k, v)
|
97 |
out = torch.einsum('bhde,bhdn->bhen', context, q)
|
98 |
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
|
|
|
170 |
nn.Dropout(dropout)
|
171 |
)
|
172 |
|
173 |
+
self.prompt_to_prompt = False
|
174 |
+
|
175 |
def forward(self, x, context=None, mask=None):
|
176 |
+
is_self_attn = context is None
|
177 |
+
|
178 |
h = self.heads
|
179 |
|
180 |
q = self.to_q(x)
|
|
|
186 |
|
187 |
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
188 |
|
189 |
+
if self.prompt_to_prompt and is_self_attn:
|
190 |
+
# Unlike the original Prompt-to-Prompt which uses cross-attention layers, we copy attention maps for self-attention layers.
|
191 |
+
# There must be 4 elements in the batch: {conditional, unconditional} x {prompt 1, prompt 2}
|
192 |
+
assert x.size(0) == 4
|
193 |
+
sims = sim.chunk(4)
|
194 |
+
sim = torch.cat((sims[0], sims[0], sims[2], sims[2]))
|
195 |
+
|
196 |
if exists(mask):
|
197 |
mask = rearrange(mask, 'b ... -> b (...)')
|
198 |
max_neg_value = -torch.finfo(sim.dtype).max
|
|
|
272 |
x = block(x, context=context)
|
273 |
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
|
274 |
x = self.proj_out(x)
|
275 |
+
return x + x_in
|
stable_diffusion/main.py
CHANGED
@@ -738,4 +738,4 @@ if __name__ == "__main__":
|
|
738 |
os.makedirs(os.path.split(dst)[0], exist_ok=True)
|
739 |
os.rename(logdir, dst)
|
740 |
if trainer.global_rank == 0:
|
741 |
-
print(trainer.profiler.summary())
|
|
|
738 |
os.makedirs(os.path.split(dst)[0], exist_ok=True)
|
739 |
os.rename(logdir, dst)
|
740 |
if trainer.global_rank == 0:
|
741 |
+
print(trainer.profiler.summary())
|