Robin Rombach commited on
Commit
1fe8a15
β€’
1 Parent(s): 677e3db

stable diffusion

Browse files

Former-commit-id: 2ff270f4e0c884d9684fa038f6d84d8600a94b39

Files changed (43) hide show
  1. LICENSE +6 -13
  2. README.md +105 -144
  3. Stable_Diffusion_v1_Model_Card.md +140 -0
  4. assets/a-painting-of-a-fire.png +0 -0
  5. assets/a-photograph-of-a-fire.png +0 -0
  6. assets/a-shirt-with-a-fire-printed-on-it.png +0 -0
  7. assets/a-shirt-with-the-inscription-'fire'.png +0 -0
  8. assets/a-watercolor-painting-of-a-fire.png +0 -0
  9. assets/birdhouse.png +0 -0
  10. assets/fire.png +0 -0
  11. assets/rdm-preview.jpg +0 -0
  12. assets/stable-samples/img2img/mountains-1.png +0 -0
  13. assets/stable-samples/img2img/mountains-2.png +0 -0
  14. assets/stable-samples/img2img/mountains-3.png +0 -0
  15. assets/stable-samples/img2img/sketch-mountains-input.jpg +0 -0
  16. assets/stable-samples/img2img/upscaling-in.png.REMOVED.git-id +1 -0
  17. assets/stable-samples/img2img/upscaling-out.png.REMOVED.git-id +1 -0
  18. assets/stable-samples/txt2img/000002025.png +0 -0
  19. assets/stable-samples/txt2img/000002035.png +0 -0
  20. assets/stable-samples/txt2img/merged-0005.png.REMOVED.git-id +1 -0
  21. assets/stable-samples/txt2img/merged-0006.png.REMOVED.git-id +1 -0
  22. assets/stable-samples/txt2img/merged-0007.png.REMOVED.git-id +1 -0
  23. assets/the-earth-is-on-fire,-oil-on-canvas.png +0 -0
  24. assets/txt2img-convsample.png +0 -0
  25. assets/txt2img-preview.png.REMOVED.git-id +1 -0
  26. assets/v1-variants-scores.jpg +0 -0
  27. configs/latent-diffusion/cin256-v2.yaml +68 -0
  28. configs/latent-diffusion/txt2img-1p4B-eval.yaml +71 -0
  29. configs/retrieval-augmented-diffusion/768x768.yaml +68 -0
  30. configs/stable-diffusion/v1-inference.yaml +70 -0
  31. data/imagenet_clsidx_to_label.txt +1000 -0
  32. environment.yaml +5 -5
  33. ldm/models/diffusion/ddim.py +62 -6
  34. ldm/models/diffusion/plms.py +236 -0
  35. ldm/modules/diffusionmodules/openaimodel.py +39 -14
  36. ldm/modules/encoders/modules.py +103 -0
  37. ldm/modules/x_transformer.py +1 -1
  38. ldm/util.py +120 -3
  39. scripts/img2img.py +293 -0
  40. scripts/knn2img.py +398 -0
  41. scripts/latent_imagenet_diffusion.ipynb.REMOVED.git-id +1 -0
  42. scripts/train_searcher.py +147 -0
  43. scripts/txt2img.py +279 -0
LICENSE CHANGED
@@ -1,16 +1,9 @@
1
- MIT License
 
 
 
 
2
 
3
- Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
4
-
5
- Permission is hereby granted, free of charge, to any person obtaining a copy
6
- of this software and associated documentation files (the "Software"), to deal
7
- in the Software without restriction, including without limitation the rights
8
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
- copies of the Software, and to permit persons to whom the Software is
10
- furnished to do so, subject to the following conditions:
11
-
12
- The above copyright notice and this permission notice shall be included in all
13
- copies or substantial portions of the Software.
14
 
15
  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
@@ -18,4 +11,4 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
- SOFTWARE.
 
1
+ All rights reserved by the authors.
2
+ You must not distribute the weights provided to you directly or indirectly without explicit consent of the authors.
3
+ You must not distribute harmful, offensive, dehumanizing content or otherwise harmful representations of people or their environments, cultures, religions, etc. produced with the model weights
4
+ or other generated content described in the "Misuse and Malicious Use" section in the model card.
5
+ The model weights are provided for research purposes only.
6
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
9
  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 
11
  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
12
  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
13
  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
14
+ SOFTWARE.
README.md CHANGED
@@ -1,11 +1,5 @@
1
- # Latent Diffusion Models
2
- [arXiv](https://arxiv.org/abs/2112.10752) | [BibTeX](#bibtex)
3
-
4
- <p align="center">
5
- <img src=assets/results.gif />
6
- </p>
7
-
8
-
9
 
10
  [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)<br/>
11
  [Robin Rombach](https://github.com/rromb)\*,
@@ -13,12 +7,19 @@
13
  [Dominik Lorenz](https://github.com/qp-qp)\,
14
  [Patrick Esser](https://github.com/pesser),
15
  [BjΓΆrn Ommer](https://hci.iwr.uni-heidelberg.de/Staff/bommer)<br/>
16
- \* equal contribution
17
 
18
- <p align="center">
19
- <img src=assets/modelfigure.png />
20
- </p>
 
 
 
 
 
 
 
21
 
 
22
  ## Requirements
23
  A suitable [conda](https://conda.io/) environment named `ldm` can be created
24
  and activated with:
@@ -28,176 +29,135 @@ conda env create -f environment.yaml
28
  conda activate ldm
29
  ```
30
 
31
- # Model Zoo
32
-
33
- ## Pretrained Autoencoding Models
34
- ![rec2](assets/reconstruction2.png)
35
-
36
- All models were trained until convergence (no further substantial improvement in rFID).
37
-
38
- | Model | rFID vs val | train steps |PSNR | PSIM | Link | Comments
39
- |-------------------------|------------|----------------|----------------|---------------|-------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------|
40
- | f=4, VQ (Z=8192, d=3) | 0.58 | 533066 | 27.43 +/- 4.26 | 0.53 +/- 0.21 | https://ommer-lab.com/files/latent-diffusion/vq-f4.zip | |
41
- | f=4, VQ (Z=8192, d=3) | 1.06 | 658131 | 25.21 +/- 4.17 | 0.72 +/- 0.26 | https://heibox.uni-heidelberg.de/f/9c6681f64bb94338a069/?dl=1 | no attention |
42
- | f=8, VQ (Z=16384, d=4) | 1.14 | 971043 | 23.07 +/- 3.99 | 1.17 +/- 0.36 | https://ommer-lab.com/files/latent-diffusion/vq-f8.zip | |
43
- | f=8, VQ (Z=256, d=4) | 1.49 | 1608649 | 22.35 +/- 3.81 | 1.26 +/- 0.37 | https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip |
44
- | f=16, VQ (Z=16384, d=8) | 5.15 | 1101166 | 20.83 +/- 3.61 | 1.73 +/- 0.43 | https://heibox.uni-heidelberg.de/f/0e42b04e2e904890a9b6/?dl=1 | |
45
- | | | | | | | |
46
- | f=4, KL | 0.27 | 176991 | 27.53 +/- 4.54 | 0.55 +/- 0.24 | https://ommer-lab.com/files/latent-diffusion/kl-f4.zip | |
47
- | f=8, KL | 0.90 | 246803 | 24.19 +/- 4.19 | 1.02 +/- 0.35 | https://ommer-lab.com/files/latent-diffusion/kl-f8.zip | |
48
- | f=16, KL (d=16) | 0.87 | 442998 | 24.08 +/- 4.22 | 1.07 +/- 0.36 | https://ommer-lab.com/files/latent-diffusion/kl-f16.zip | |
49
- | f=32, KL (d=64) | 2.04 | 406763 | 22.27 +/- 3.93 | 1.41 +/- 0.40 | https://ommer-lab.com/files/latent-diffusion/kl-f32.zip | |
50
 
51
- ### Get the models
52
-
53
- Running the following script downloads und extracts all available pretrained autoencoding models.
54
- ```shell script
55
- bash scripts/download_first_stages.sh
56
  ```
 
 
 
 
57
 
58
- The first stage models can then be found in `models/first_stage_models/<model_spec>`
59
 
 
60
 
 
 
 
 
61
 
62
- ## Pretrained LDMs
63
- | Datset | Task | Model | FID | IS | Prec | Recall | Link | Comments
64
- |---------------------------------|------|--------------|---------------|-----------------|------|------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------------------------|
65
- | CelebA-HQ | Unconditional Image Synthesis | LDM-VQ-4 (200 DDIM steps, eta=0)| 5.11 (5.11) | 3.29 | 0.72 | 0.49 | https://ommer-lab.com/files/latent-diffusion/celeba.zip | |
66
- | FFHQ | Unconditional Image Synthesis | LDM-VQ-4 (200 DDIM steps, eta=1)| 4.98 (4.98) | 4.50 (4.50) | 0.73 | 0.50 | https://ommer-lab.com/files/latent-diffusion/ffhq.zip | |
67
- | LSUN-Churches | Unconditional Image Synthesis | LDM-KL-8 (400 DDIM steps, eta=0)| 4.02 (4.02) | 2.72 | 0.64 | 0.52 | https://ommer-lab.com/files/latent-diffusion/lsun_churches.zip | |
68
- | LSUN-Bedrooms | Unconditional Image Synthesis | LDM-VQ-4 (200 DDIM steps, eta=1)| 2.95 (3.0) | 2.22 (2.23)| 0.66 | 0.48 | https://ommer-lab.com/files/latent-diffusion/lsun_bedrooms.zip | |
69
- | ImageNet | Class-conditional Image Synthesis | LDM-VQ-8 (200 DDIM steps, eta=1) | 7.77(7.76)* /15.82** | 201.56(209.52)* /78.82** | 0.84* / 0.65** | 0.35* / 0.63** | https://ommer-lab.com/files/latent-diffusion/cin.zip | *: w/ guiding, classifier_scale 10 **: w/o guiding, scores in bracket calculated with script provided by [ADM](https://github.com/openai/guided-diffusion) |
70
- | Conceptual Captions | Text-conditional Image Synthesis | LDM-VQ-f4 (100 DDIM steps, eta=0) | 16.79 | 13.89 | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/text2img.zip | finetuned from LAION |
71
- | OpenImages | Super-resolution | LDM-VQ-4 | N/A | N/A | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip | BSR image degradation |
72
- | OpenImages | Layout-to-Image Synthesis | LDM-VQ-4 (200 DDIM steps, eta=0) | 32.02 | 15.92 | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip | |
73
- | Landscapes | Semantic Image Synthesis | LDM-VQ-4 | N/A | N/A | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip | |
74
- | Landscapes | Semantic Image Synthesis | LDM-VQ-4 | N/A | N/A | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip | finetuned on resolution 512x512 |
75
 
 
76
 
77
- ### Get the models
78
 
79
- The LDMs listed above can jointly be downloaded and extracted via
 
80
 
81
- ```shell script
82
- bash scripts/download_models.sh
83
- ```
 
 
 
84
 
85
- The models can then be found in `models/ldm/<model_spec>`.
 
 
 
86
 
87
- ### Sampling with unconditional models
88
 
89
- We provide a first script for sampling from our unconditional models. Start it via
90
 
91
- ```shell script
92
- CUDA_VISIBLE_DEVICES=<GPU_ID> python scripts/sample_diffusion.py -r models/ldm/<model_spec>/model.ckpt -l <logdir> -n <\#samples> --batch_size <batch_size> -c <\#ddim steps> -e <\#eta>
93
- ```
94
 
95
- # Inpainting
96
- ![inpainting](assets/inpainting.png)
97
 
98
- Download the pre-trained weights
99
  ```
100
- wget -O models/ldm/inpainting_big/last.ckpt https://heibox.uni-heidelberg.de/f/4d9ac7ea40c64582b7c9/?dl=1
 
101
  ```
102
-
103
  and sample with
104
  ```
105
- python scripts/inpaint.py --indir data/inpainting_examples/ --outdir outputs/inpainting_results
106
  ```
107
- `indir` should contain images `*.png` and masks `<image_fname>_mask.png` like
108
- the examples provided in `data/inpainting_examples`.
109
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
- # Train your own LDMs
112
-
113
- ## Data preparation
114
-
115
- ### Faces
116
- For downloading the CelebA-HQ and FFHQ datasets, proceed as described in the [taming-transformers](https://github.com/CompVis/taming-transformers#celeba-hq)
117
- repository.
118
-
119
- ### LSUN
120
-
121
- The LSUN datasets can be conveniently downloaded via the script available [here](https://github.com/fyu/lsun).
122
- We performed a custom split into training and validation images, and provide the corresponding filenames
123
- at [https://ommer-lab.com/files/lsun.zip](https://ommer-lab.com/files/lsun.zip).
124
- After downloading, extract them to `./data/lsun`. The beds/cats/churches subsets should
125
- also be placed/symlinked at `./data/lsun/bedrooms`/`./data/lsun/cats`/`./data/lsun/churches`, respectively.
126
-
127
- ### ImageNet
128
- The code will try to download (through [Academic
129
- Torrents](http://academictorrents.com/)) and prepare ImageNet the first time it
130
- is used. However, since ImageNet is quite large, this requires a lot of disk
131
- space and time. If you already have ImageNet on your disk, you can speed things
132
- up by putting the data into
133
- `${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/` (which defaults to
134
- `~/.cache/autoencoders/data/ILSVRC2012_{split}/data/`), where `{split}` is one
135
- of `train`/`validation`. It should have the following structure:
136
-
137
- ```
138
- ${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/
139
- β”œβ”€β”€ n01440764
140
- β”‚ β”œβ”€β”€ n01440764_10026.JPEG
141
- β”‚ β”œβ”€β”€ n01440764_10027.JPEG
142
- β”‚ β”œβ”€β”€ ...
143
- β”œβ”€β”€ n01443537
144
- β”‚ β”œβ”€β”€ n01443537_10007.JPEG
145
- β”‚ β”œβ”€β”€ n01443537_10014.JPEG
146
- β”‚ β”œβ”€β”€ ...
147
- β”œβ”€β”€ ...
148
  ```
 
 
 
 
149
 
150
- If you haven't extracted the data, you can also place
151
- `ILSVRC2012_img_train.tar`/`ILSVRC2012_img_val.tar` (or symlinks to them) into
152
- `${XDG_CACHE}/autoencoders/data/ILSVRC2012_train/` /
153
- `${XDG_CACHE}/autoencoders/data/ILSVRC2012_validation/`, which will then be
154
- extracted into above structure without downloading it again. Note that this
155
- will only happen if neither a folder
156
- `${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/` nor a file
157
- `${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/.ready` exist. Remove them
158
- if you want to force running the dataset preparation again.
159
 
 
160
 
161
- ## Model Training
 
 
162
 
163
- Logs and checkpoints for trained models are saved to `logs/<START_DATE_AND_TIME>_<config_spec>`.
164
-
165
- ### Training autoencoder models
166
-
167
- Configs for training a KL-regularized autoencoder on ImageNet are provided at `configs/autoencoder`.
168
- Training can be started by running
169
  ```
170
- CUDA_VISIBLE_DEVICES=<GPU_ID> python main.py --base configs/autoencoder/<config_spec>.yaml -t --gpus 0,
171
  ```
172
- where `config_spec` is one of {`autoencoder_kl_8x8x64`(f=32, d=64), `autoencoder_kl_16x16x16`(f=16, d=16),
173
- `autoencoder_kl_32x32x4`(f=8, d=4), `autoencoder_kl_64x64x3`(f=4, d=3)}.
174
-
175
- For training VQ-regularized models, see the [taming-transformers](https://github.com/CompVis/taming-transformers)
176
- repository.
177
 
178
- ### Training LDMs
179
 
180
- In ``configs/latent-diffusion/`` we provide configs for training LDMs on the LSUN-, CelebA-HQ, FFHQ and ImageNet datasets.
181
- Training can be started by running
182
 
183
- ```shell script
184
- CUDA_VISIBLE_DEVICES=<GPU_ID> python main.py --base configs/latent-diffusion/<config_spec>.yaml -t --gpus 0,
185
- ```
186
-
187
- where ``<config_spec>`` is one of {`celebahq-ldm-vq-4`(f=4, VQ-reg. autoencoder, spatial size 64x64x3),`ffhq-ldm-vq-4`(f=4, VQ-reg. autoencoder, spatial size 64x64x3),
188
- `lsun_bedrooms-ldm-vq-4`(f=4, VQ-reg. autoencoder, spatial size 64x64x3),
189
- `lsun_churches-ldm-vq-4`(f=8, KL-reg. autoencoder, spatial size 32x32x4),`cin-ldm-vq-8`(f=8, VQ-reg. autoencoder, spatial size 32x32x4)}.
190
 
191
- ## Coming Soon...
 
192
 
193
- * More inference scripts for conditional LDMs.
194
- * In the meantime, you can play with our colab notebook https://colab.research.google.com/drive/1xqzUi2iXQXDqXBHQGP9Mqt2YrYW6cx-J?usp=sharing
195
- * We will also release some further pretrained models.
196
 
197
 
198
  ## Comments
199
 
200
- - Our codebase for the diffusion models builds heavily on [OpenAI's codebase](https://github.com/openai/guided-diffusion)
201
  and [https://github.com/lucidrains/denoising-diffusion-pytorch](https://github.com/lucidrains/denoising-diffusion-pytorch).
202
  Thanks for open-sourcing!
203
 
@@ -215,6 +175,7 @@ Thanks for open-sourcing!
215
  archivePrefix={arXiv},
216
  primaryClass={cs.CV}
217
  }
 
218
  ```
219
 
220
 
 
1
+ # Stable Diffusion
2
+ *Stable Diffusion was made possible thanks to a collaboration with [Stability AI](https://stability.ai/) and [Runway](https://runwayml.com/) and builds upon our previous work:*
 
 
 
 
 
 
3
 
4
  [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)<br/>
5
  [Robin Rombach](https://github.com/rromb)\*,
 
7
  [Dominik Lorenz](https://github.com/qp-qp)\,
8
  [Patrick Esser](https://github.com/pesser),
9
  [BjΓΆrn Ommer](https://hci.iwr.uni-heidelberg.de/Staff/bommer)<br/>
 
10
 
11
+ which is available on [GitHub](https://github.com/CompVis/latent-diffusion).
12
+
13
+ ![txt2img-stable2](assets/stable-samples/txt2img/merged-0006.png)
14
+ [Stable Diffusion](#stable-diffusion-v1) is a latent text-to-image diffusion
15
+ model.
16
+ Thanks to a generous compute donation from [Stability AI](https://stability.ai/) and support from [LAION](https://laion.ai/), we were able to train a Latent Diffusion Model on 512x512 images from a subset of the [LAION-5B](https://laion.ai/blog/laion-5b/) database.
17
+ Similar to Google's [Imagen](https://arxiv.org/abs/2205.11487),
18
+ this model uses a frozen CLIP ViT-L/14 text encoder to condition the model on text prompts.
19
+ With its 860M UNet and 123M text encoder, the model is relatively lightweight and runs on a GPU with at least 10GB VRAM.
20
+ See [this section](#stable-diffusion-v1) below and the [model card](https://huggingface.co/CompVis/stable-diffusion).
21
 
22
+
23
  ## Requirements
24
  A suitable [conda](https://conda.io/) environment named `ldm` can be created
25
  and activated with:
 
29
  conda activate ldm
30
  ```
31
 
32
+ You can also update an existing [latent diffusion](https://github.com/CompVis/latent-diffusion) environment by running
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
 
 
 
 
 
34
  ```
35
+ conda install pytorch torchvision -c pytorch
36
+ pip install transformers==4.19.2
37
+ pip install -e .
38
+ ```
39
 
 
40
 
41
+ ## Stable Diffusion v1
42
 
43
+ Stable Diffusion v1 refers to a specific configuration of the model
44
+ architecture that uses a downsampling-factor 8 autoencoder with an 860M UNet
45
+ and CLIP ViT-L/14 text encoder for the diffusion model. The model was pretrained on 256x256 images and
46
+ then finetuned on 512x512 images.
47
 
48
+ *Note: Stable Diffusion v1 is a general text-to-image diffusion model and therefore mirrors biases and (mis-)conceptions that are present
49
+ in its training data.
50
+ Details on the training procedure and data, as well as the intended use of the model can be found in the corresponding [model card](https://huggingface.co/CompVis/stable-diffusion).
51
+ Research into the safe deployment of general text-to-image models is an ongoing effort. To prevent misuse and harm, we currently provide access to the checkpoints only for [academic research purposes upon request](TODO).
52
+ **This is an experiment in safe and community-driven publication of a capable and general text-to-image model. We are working on a public release with a more permissive license that also incorporates ethical considerations.***
 
 
 
 
 
 
 
 
53
 
54
+ [Request access to Stable Diffusion v1 checkpoints for academic research](TODO)
55
 
56
+ ### Weights
57
 
58
+ We currently provide three checkpoints, `sd-v1-1.ckpt`, `sd-v1-2.ckpt` and `sd-v1-3.ckpt`,
59
+ which were trained as follows,
60
 
61
+ - `sd-v1-1.ckpt`: 237k steps at resolution `256x256` on [laion2B-en](https://huggingface.co/datasets/laion/laion2B-en).
62
+ 194k steps at resolution `512x512` on [laion-high-resolution](https://huggingface.co/datasets/laion/laion-high-resolution) (170M examples from LAION-5B with resolution `>= 1024x1024`).
63
+ - `sd-v1-2.ckpt`: Resumed from `sd-v1-1.ckpt`.
64
+ 515k steps at resolution `512x512` on "laion-improved-aesthetics" (a subset of laion2B-en,
65
+ filtered to images with an original size `>= 512x512`, estimated aesthetics score `> 5.0`, and an estimated watermark probability `< 0.5`. The watermark estimate is from the LAION-5B metadata, the aesthetics score is estimated using an [improved aesthetics estimator](https://github.com/christophschuhmann/improved-aesthetic-predictor)).
66
+ - `sd-v1-3.ckpt`: Resumed from `sd-v1-2.ckpt`. 195k steps at resolution `512x512` on "laion-improved-aesthetics" and 10\% dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598).
67
 
68
+ Evaluations with different classifier-free guidance scales (1.5, 2.0, 3.0, 4.0,
69
+ 5.0, 6.0, 7.0, 8.0) and 50 PLMS sampling
70
+ steps show the relative improvements of the checkpoints:
71
+ ![sd evaluation results](assets/v1-variants-scores.jpg)
72
 
 
73
 
 
74
 
75
+ ### Text-to-Image with Stable Diffusion
76
+ ![txt2img-stable2](assets/stable-samples/txt2img/merged-0005.png)
77
+ ![txt2img-stable2](assets/stable-samples/txt2img/merged-0007.png)
78
 
79
+ Stable Diffusion is a latent diffusion model conditioned on the (non-pooled) text embeddings of a CLIP ViT-L/14 text encoder.
 
80
 
81
+ After [obtaining the weights](#weights), link them
82
  ```
83
+ mkdir -p models/ldm/stable-diffusion-v1/
84
+ ln -s <path/to/model.ckpt> models/ldm/stable-diffusion-v1/model.ckpt
85
  ```
 
86
  and sample with
87
  ```
88
+ python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" --plms
89
  ```
90
+ By default, this uses a guidance scale of `--scale 7.5`, [Katherine Crowson's implementation](https://github.com/CompVis/latent-diffusion/pull/51) of the [PLMS](https://arxiv.org/abs/2202.09778) sampler,
91
+ and renders images of size 512x512 (which it was trained on) in 50 steps. All supported arguments are listed below (type `python scripts/txt2img.py --help`).
92
+
93
+ ```commandline
94
+ usage: txt2img.py [-h] [--prompt [PROMPT]] [--outdir [OUTDIR]] [--skip_grid] [--skip_save] [--ddim_steps DDIM_STEPS] [--plms] [--laion400m] [--fixed_code] [--ddim_eta DDIM_ETA] [--n_iter N_ITER] [--H H] [--W W] [--C C] [--f F] [--n_samples N_SAMPLES] [--n_rows N_ROWS]
95
+ [--scale SCALE] [--from-file FROM_FILE] [--config CONFIG] [--ckpt CKPT] [--seed SEED] [--precision {full,autocast}]
96
+
97
+ optional arguments:
98
+ -h, --help show this help message and exit
99
+ --prompt [PROMPT] the prompt to render
100
+ --outdir [OUTDIR] dir to write results to
101
+ --skip_grid do not save a grid, only individual samples. Helpful when evaluating lots of samples
102
+ --skip_save do not save individual samples. For speed measurements.
103
+ --ddim_steps DDIM_STEPS
104
+ number of ddim sampling steps
105
+ --plms use plms sampling
106
+ --laion400m uses the LAION400M model
107
+ --fixed_code if enabled, uses the same starting code across samples
108
+ --ddim_eta DDIM_ETA ddim eta (eta=0.0 corresponds to deterministic sampling
109
+ --n_iter N_ITER sample this often
110
+ --H H image height, in pixel space
111
+ --W W image width, in pixel space
112
+ --C C latent channels
113
+ --f F downsampling factor
114
+ --n_samples N_SAMPLES
115
+ how many samples to produce for each given prompt. A.k.a. batch size
116
+ --n_rows N_ROWS rows in the grid (default: n_samples)
117
+ --scale SCALE unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))
118
+ --from-file FROM_FILE
119
+ if specified, load prompts from this file
120
+ --config CONFIG path to config which constructs model
121
+ --ckpt CKPT path to checkpoint of model
122
+ --seed SEED the seed (for reproducible sampling)
123
+ --precision {full,autocast}
124
+ evaluate at this precision
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  ```
127
+ Note: The inference config for all v1 versions is designed to be used with EMA-only checkpoints.
128
+ For this reason `use_ema=False` is set in the configuration, otherwise the code will try to switch from
129
+ non-EMA to EMA weights. If you want to examine the effect of EMA vs no EMA, we provide "full" checkpoints
130
+ which contain both types of weights. For these, `use_ema=False` will load and use the non-EMA weights.
131
 
 
 
 
 
 
 
 
 
 
132
 
133
+ ### Image Modification with Stable Diffusion
134
 
135
+ By using a diffusion-denoising mechanism as first proposed by [SDEdit](https://arxiv.org/abs/2108.01073), the model can be used for different
136
+ tasks such as text-guided image-to-image translation and upscaling. Similar to the txt2img sampling script,
137
+ we provide a script to perform image modification with Stable Diffusion.
138
 
139
+ The following describes an example where a rough sketch made in [Pinta](https://www.pinta-project.com/) is converted into a detailed artwork.
 
 
 
 
 
140
  ```
141
+ python scripts/img2img.py --prompt "A fantasy landscape, trending on artstation" --init-img <path-to-img.jpg> --strength 0.8
142
  ```
143
+ Here, strength is a value between 0.0 and 1.0, that controls the amount of noise that is added to the input image.
144
+ Values that approach 1.0 allow for lots of variations but will also produce images that are not semantically consistent with the input. See the following example.
 
 
 
145
 
146
+ **Input**
147
 
148
+ ![sketch-in](assets/stable-samples/img2img/sketch-mountains-input.jpg)
 
149
 
150
+ **Outputs**
 
 
 
 
 
 
151
 
152
+ ![out3](assets/stable-samples/img2img/mountains-3.png)
153
+ ![out2](assets/stable-samples/img2img/mountains-2.png)
154
 
155
+ This procedure can, for example, also be used to upscale samples from the base model.
 
 
156
 
157
 
158
  ## Comments
159
 
160
+ - Our codebase for the diffusion models builds heavily on [OpenAI's ADM codebase](https://github.com/openai/guided-diffusion)
161
  and [https://github.com/lucidrains/denoising-diffusion-pytorch](https://github.com/lucidrains/denoising-diffusion-pytorch).
162
  Thanks for open-sourcing!
163
 
 
175
  archivePrefix={arXiv},
176
  primaryClass={cs.CV}
177
  }
178
+
179
  ```
180
 
181
 
Stable_Diffusion_v1_Model_Card.md ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Stable Diffusion v1 Model Card
2
+ This model card focuses on the model associated with the Stable Diffusion model, available [here](https://github.com/CompVis/stable-diffusion).
3
+
4
+ ## Model Details
5
+ - **Developed by:** Robin Rombach, Patrick Esser
6
+ - **Model type:** Diffusion-based text-to-image generation model
7
+ - **Language(s):** English
8
+ - **License:** [Proprietary](LICENSE)
9
+ - **Model Description:** This is a model that can be used to generate and modify images based on text prompts. It is a [Latent Diffusion Model](https://arxiv.org/abs/2112.10752) that uses a fixed, pretrained text encoder ([CLIP ViT-L/14](https://arxiv.org/abs/2103.00020)) as suggested in the [Imagen paper](https://arxiv.org/abs/2205.11487).
10
+ - **Resources for more information:** [GitHub Repository](https://github.com/CompVis/stable-diffusion), [Paper](https://arxiv.org/abs/2112.10752).
11
+ - **Cite as:**
12
+
13
+ @InProceedings{Rombach_2022_CVPR,
14
+ author = {Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj\"orn},
15
+ title = {High-Resolution Image Synthesis With Latent Diffusion Models},
16
+ booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
17
+ month = {June},
18
+ year = {2022},
19
+ pages = {10684-10695}
20
+ }
21
+
22
+ # Uses
23
+
24
+ ## Direct Use
25
+ The model is intended for research purposes only. Possible research areas and
26
+ tasks include
27
+
28
+ - Safe deployment of models which have the potential to generate harmful content.
29
+ - Probing and understanding the limitations and biases of generative models.
30
+ - Generation of artworks and use in design and other artistic processes.
31
+ - Applications in educational or creative tools.
32
+ - Research on generative models.
33
+
34
+ Excluded uses are described below.
35
+
36
+ ### Misuse, Malicious Use, and Out-of-Scope Use
37
+ _Note: This section is taken from the [DALLE-MINI model card](https://huggingface.co/dalle-mini/dalle-mini), but applies in the same way to Stable Diffusion v1_.
38
+
39
+
40
+ The model should not be used to intentionally create or disseminate images that create hostile or alienating environments for people. This includes generating images that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes.
41
+ #### Out-of-Scope Use
42
+ The model was not trained to be factual or true representations of people or events, and therefore using the model to generate such content is out-of-scope for the abilities of this model.
43
+ #### Misuse and Malicious Use
44
+ Using the model to generate content that is cruel to individuals is a misuse of this model. This includes, but is not limited to:
45
+
46
+ - Generating demeaning, dehumanizing, or otherwise harmful representations of people or their environments, cultures, religions, etc.
47
+ - Intentionally promoting or propagating discriminatory content or harmful stereotypes.
48
+ - Impersonating individuals without their consent.
49
+ - Sexual content without consent of the people who might see it.
50
+ - Mis- and disinformation
51
+ - Representations of egregious violence and gore
52
+ - Sharing of copyrighted or licensed material in violation of its terms of use.
53
+ - Sharing content that is an alteration of copyrighted or licensed material in violation of its terms of use.
54
+
55
+ ## Limitations and Bias
56
+
57
+ ### Limitations
58
+
59
+ - The model does not achieve perfect photorealism
60
+ - The model cannot render legible text
61
+ - The model does not perform well on more difficult tasks which involve compositionality, such as rendering an image corresponding to β€œA red cube on top of a blue sphere”
62
+ - Faces and people in general may not be generated properly.
63
+ - The model was trained mainly with English captions and will not work as well in other languages.
64
+ - The autoencoding part of the model is lossy
65
+ - The model was trained on a large-scale dataset
66
+ [LAION-5B](https://laion.ai/blog/laion-5b/) which contains adult material
67
+ and is not fit for product use without additional safety mechanisms and
68
+ considerations.
69
+
70
+ ### Bias
71
+ While the capabilities of image generation models are impressive, they can also reinforce or exacerbate social biases.
72
+ Stable Diffusion v1 was trained on subsets of [LAION-2B(en)](https://laion.ai/blog/laion-5b/),
73
+ which consists of images that are primarily limited to English descriptions.
74
+ Texts and images from communities and cultures that use other languages are likely to be insufficiently accounted for.
75
+ This affects the overall output of the model, as white and western cultures are often set as the default. Further, the
76
+ ability of the model to generate content with non-English prompts is significantly worse than with English-language prompts.
77
+
78
+
79
+ ## Training
80
+
81
+ **Training Data**
82
+ The model developers used the following dataset for training the model:
83
+
84
+ - LAION-2B (en) and subsets thereof (see next section)
85
+
86
+ **Training Procedure**
87
+ Stable Diffusion v1 is a latent diffusion model which combines an autoencoder with a diffusion model that is trained in the latent space of the autoencoder. During training,
88
+
89
+ - Images are encoded through an encoder, which turns images into latent representations. The autoencoder uses a relative downsampling factor of 8 and maps images of shape H x W x 3 to latents of shape H/f x W/f x 4
90
+ - Text prompts are encoded through a ViT-L/14 text-encoder.
91
+ - The non-pooled output of the text encoder is fed into the UNet backbone of the latent diffusion model via cross-attention.
92
+ - The loss is a reconstruction objective between the noise that was added to the latent and the prediction made by the UNet.
93
+
94
+ We currently provide three checkpoints, `sd-v1-1.ckpt`, `sd-v1-2.ckpt` and `sd-v1-3.ckpt`,
95
+ which were trained as follows,
96
+
97
+ - `sd-v1-1.ckpt`: 237k steps at resolution `256x256` on [laion2B-en](https://huggingface.co/datasets/laion/laion2B-en).
98
+ 194k steps at resolution `512x512` on [laion-high-resolution](https://huggingface.co/datasets/laion/laion-high-resolution) (170M examples from LAION-5B with resolution `>= 1024x1024`).
99
+ - `sd-v1-2.ckpt`: Resumed from `sd-v1-1.ckpt`.
100
+ 515k steps at resolution `512x512` on "laion-improved-aesthetics" (a subset of laion2B-en,
101
+ filtered to images with an original size `>= 512x512`, estimated aesthetics score `> 5.0`, and an estimated watermark probability `< 0.5`. The watermark estimate is from the LAION-5B metadata, the aesthetics score is estimated using an [improved aesthetics estimator](https://github.com/christophschuhmann/improved-aesthetic-predictor)).
102
+ - `sd-v1-3.ckpt`: Resumed from `sd-v1-2.ckpt`. 195k steps at resolution `512x512` on "laion-improved-aesthetics" and 10\% dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598).
103
+
104
+
105
+ - **Hardware:** 32 x 8 x A100 GPUs
106
+ - **Optimizer:** AdamW
107
+ - **Gradient Accumulations**: 2
108
+ - **Batch:** 32 x 8 x 2 x 4 = 2048
109
+ - **Learning rate:** warmup to 0.0001 for 10,000 steps and then kept constant
110
+
111
+ ## Evaluation Results
112
+ Evaluations with different classifier-free guidance scales (1.5, 2.0, 3.0, 4.0,
113
+ 5.0, 6.0, 7.0, 8.0) and 50 PLMS sampling
114
+ steps show the relative improvements of the checkpoints:
115
+
116
+ ![pareto](assets/v1-variants-scores.jpg)
117
+
118
+ Evaluated using 50 PLMS steps and 10000 random prompts from the COCO2017 validation set, evaluated at 512x512 resolution. Not optimized for FID scores.
119
+ ## Environmental Impact
120
+
121
+ **Stable Diffusion v1** **Estimated Emissions**
122
+ Based on that information, we estimate the following CO2 emissions using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700). The hardware, runtime, cloud provider, and compute region were utilized to estimate the carbon impact.
123
+
124
+ - **Hardware Type:** A100 PCIe 40GB
125
+ - **Hours used:** 150000
126
+ - **Cloud Provider:** AWS
127
+ - **Compute Region:** US-east
128
+ - **Carbon Emitted (Power consumption x Time x Carbon produced based on location of power grid):** 11250 kg CO2 eq.
129
+ ## Citation
130
+ @InProceedings{Rombach_2022_CVPR,
131
+ author = {Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj\"orn},
132
+ title = {High-Resolution Image Synthesis With Latent Diffusion Models},
133
+ booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
134
+ month = {June},
135
+ year = {2022},
136
+ pages = {10684-10695}
137
+ }
138
+
139
+ *This model card was written by: Robin Rombach and Patrick Esser and is based on the [DALL-E Mini model card](https://huggingface.co/dalle-mini/dalle-mini).*
140
+
assets/a-painting-of-a-fire.png ADDED
assets/a-photograph-of-a-fire.png ADDED
assets/a-shirt-with-a-fire-printed-on-it.png ADDED
assets/a-shirt-with-the-inscription-'fire'.png ADDED
assets/a-watercolor-painting-of-a-fire.png ADDED
assets/birdhouse.png ADDED
assets/fire.png ADDED
assets/rdm-preview.jpg ADDED
assets/stable-samples/img2img/mountains-1.png ADDED
assets/stable-samples/img2img/mountains-2.png ADDED
assets/stable-samples/img2img/mountains-3.png ADDED
assets/stable-samples/img2img/sketch-mountains-input.jpg ADDED
assets/stable-samples/img2img/upscaling-in.png.REMOVED.git-id ADDED
@@ -0,0 +1 @@
 
 
1
+ 501c31c21751664957e69ce52cad1818b6d2f4ce
assets/stable-samples/img2img/upscaling-out.png.REMOVED.git-id ADDED
@@ -0,0 +1 @@
 
 
1
+ 1c4bb25a779f34d86b2d90e584ac67af91bb1303
assets/stable-samples/txt2img/000002025.png ADDED
assets/stable-samples/txt2img/000002035.png ADDED
assets/stable-samples/txt2img/merged-0005.png.REMOVED.git-id ADDED
@@ -0,0 +1 @@
 
 
1
+ ca0a1af206555f0f208a1ab879e95efedc1b1c5b
assets/stable-samples/txt2img/merged-0006.png.REMOVED.git-id ADDED
@@ -0,0 +1 @@
 
 
1
+ 999f3703230580e8c89e9081abd6a1f8f50896d4
assets/stable-samples/txt2img/merged-0007.png.REMOVED.git-id ADDED
@@ -0,0 +1 @@
 
 
1
+ af390acaf601283782d6f479d4cade4d78e30b26
assets/the-earth-is-on-fire,-oil-on-canvas.png ADDED
assets/txt2img-convsample.png ADDED
assets/txt2img-preview.png.REMOVED.git-id ADDED
@@ -0,0 +1 @@
 
 
1
+ 51ee1c235dfdc63d4c41de7d303d03730e43c33c
assets/v1-variants-scores.jpg ADDED
configs/latent-diffusion/cin256-v2.yaml ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 0.0001
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.0015
6
+ linear_end: 0.0195
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: image
11
+ cond_stage_key: class_label
12
+ image_size: 64
13
+ channels: 3
14
+ cond_stage_trainable: true
15
+ conditioning_key: crossattn
16
+ monitor: val/loss
17
+ use_ema: False
18
+
19
+ unet_config:
20
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
21
+ params:
22
+ image_size: 64
23
+ in_channels: 3
24
+ out_channels: 3
25
+ model_channels: 192
26
+ attention_resolutions:
27
+ - 8
28
+ - 4
29
+ - 2
30
+ num_res_blocks: 2
31
+ channel_mult:
32
+ - 1
33
+ - 2
34
+ - 3
35
+ - 5
36
+ num_heads: 1
37
+ use_spatial_transformer: true
38
+ transformer_depth: 1
39
+ context_dim: 512
40
+
41
+ first_stage_config:
42
+ target: ldm.models.autoencoder.VQModelInterface
43
+ params:
44
+ embed_dim: 3
45
+ n_embed: 8192
46
+ ddconfig:
47
+ double_z: false
48
+ z_channels: 3
49
+ resolution: 256
50
+ in_channels: 3
51
+ out_ch: 3
52
+ ch: 128
53
+ ch_mult:
54
+ - 1
55
+ - 2
56
+ - 4
57
+ num_res_blocks: 2
58
+ attn_resolutions: []
59
+ dropout: 0.0
60
+ lossconfig:
61
+ target: torch.nn.Identity
62
+
63
+ cond_stage_config:
64
+ target: ldm.modules.encoders.modules.ClassEmbedder
65
+ params:
66
+ n_classes: 1001
67
+ embed_dim: 512
68
+ key: class_label
configs/latent-diffusion/txt2img-1p4B-eval.yaml ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 5.0e-05
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.012
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: image
11
+ cond_stage_key: caption
12
+ image_size: 32
13
+ channels: 4
14
+ cond_stage_trainable: true
15
+ conditioning_key: crossattn
16
+ monitor: val/loss_simple_ema
17
+ scale_factor: 0.18215
18
+ use_ema: False
19
+
20
+ unet_config:
21
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
22
+ params:
23
+ image_size: 32
24
+ in_channels: 4
25
+ out_channels: 4
26
+ model_channels: 320
27
+ attention_resolutions:
28
+ - 4
29
+ - 2
30
+ - 1
31
+ num_res_blocks: 2
32
+ channel_mult:
33
+ - 1
34
+ - 2
35
+ - 4
36
+ - 4
37
+ num_heads: 8
38
+ use_spatial_transformer: true
39
+ transformer_depth: 1
40
+ context_dim: 1280
41
+ use_checkpoint: true
42
+ legacy: False
43
+
44
+ first_stage_config:
45
+ target: ldm.models.autoencoder.AutoencoderKL
46
+ params:
47
+ embed_dim: 4
48
+ monitor: val/rec_loss
49
+ ddconfig:
50
+ double_z: true
51
+ z_channels: 4
52
+ resolution: 256
53
+ in_channels: 3
54
+ out_ch: 3
55
+ ch: 128
56
+ ch_mult:
57
+ - 1
58
+ - 2
59
+ - 4
60
+ - 4
61
+ num_res_blocks: 2
62
+ attn_resolutions: []
63
+ dropout: 0.0
64
+ lossconfig:
65
+ target: torch.nn.Identity
66
+
67
+ cond_stage_config:
68
+ target: ldm.modules.encoders.modules.BERTEmbedder
69
+ params:
70
+ n_embed: 1280
71
+ n_layer: 32
configs/retrieval-augmented-diffusion/768x768.yaml ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 0.0001
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.0015
6
+ linear_end: 0.015
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: jpg
11
+ cond_stage_key: nix
12
+ image_size: 48
13
+ channels: 16
14
+ cond_stage_trainable: false
15
+ conditioning_key: crossattn
16
+ monitor: val/loss_simple_ema
17
+ scale_by_std: false
18
+ scale_factor: 0.22765929
19
+ unet_config:
20
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
21
+ params:
22
+ image_size: 48
23
+ in_channels: 16
24
+ out_channels: 16
25
+ model_channels: 448
26
+ attention_resolutions:
27
+ - 4
28
+ - 2
29
+ - 1
30
+ num_res_blocks: 2
31
+ channel_mult:
32
+ - 1
33
+ - 2
34
+ - 3
35
+ - 4
36
+ use_scale_shift_norm: false
37
+ resblock_updown: false
38
+ num_head_channels: 32
39
+ use_spatial_transformer: true
40
+ transformer_depth: 1
41
+ context_dim: 768
42
+ use_checkpoint: true
43
+ first_stage_config:
44
+ target: ldm.models.autoencoder.AutoencoderKL
45
+ params:
46
+ monitor: val/rec_loss
47
+ embed_dim: 16
48
+ ddconfig:
49
+ double_z: true
50
+ z_channels: 16
51
+ resolution: 256
52
+ in_channels: 3
53
+ out_ch: 3
54
+ ch: 128
55
+ ch_mult:
56
+ - 1
57
+ - 1
58
+ - 2
59
+ - 2
60
+ - 4
61
+ num_res_blocks: 2
62
+ attn_resolutions:
63
+ - 16
64
+ dropout: 0.0
65
+ lossconfig:
66
+ target: torch.nn.Identity
67
+ cond_stage_config:
68
+ target: torch.nn.Identity
configs/stable-diffusion/v1-inference.yaml ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-04
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.0120
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: "jpg"
11
+ cond_stage_key: "txt"
12
+ image_size: 64
13
+ channels: 4
14
+ cond_stage_trainable: false # Note: different from the one we trained before
15
+ conditioning_key: crossattn
16
+ monitor: val/loss_simple_ema
17
+ scale_factor: 0.18215
18
+ use_ema: False
19
+
20
+ scheduler_config: # 10000 warmup steps
21
+ target: ldm.lr_scheduler.LambdaLinearScheduler
22
+ params:
23
+ warm_up_steps: [ 10000 ]
24
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25
+ f_start: [ 1.e-6 ]
26
+ f_max: [ 1. ]
27
+ f_min: [ 1. ]
28
+
29
+ unet_config:
30
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31
+ params:
32
+ image_size: 32 # unused
33
+ in_channels: 4
34
+ out_channels: 4
35
+ model_channels: 320
36
+ attention_resolutions: [ 4, 2, 1 ]
37
+ num_res_blocks: 2
38
+ channel_mult: [ 1, 2, 4, 4 ]
39
+ num_heads: 8
40
+ use_spatial_transformer: True
41
+ transformer_depth: 1
42
+ context_dim: 768
43
+ use_checkpoint: True
44
+ legacy: False
45
+
46
+ first_stage_config:
47
+ target: ldm.models.autoencoder.AutoencoderKL
48
+ params:
49
+ embed_dim: 4
50
+ monitor: val/rec_loss
51
+ ddconfig:
52
+ double_z: true
53
+ z_channels: 4
54
+ resolution: 256
55
+ in_channels: 3
56
+ out_ch: 3
57
+ ch: 128
58
+ ch_mult:
59
+ - 1
60
+ - 2
61
+ - 4
62
+ - 4
63
+ num_res_blocks: 2
64
+ attn_resolutions: []
65
+ dropout: 0.0
66
+ lossconfig:
67
+ target: torch.nn.Identity
68
+
69
+ cond_stage_config:
70
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
data/imagenet_clsidx_to_label.txt ADDED
@@ -0,0 +1,1000 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 0: 'tench, Tinca tinca',
2
+ 1: 'goldfish, Carassius auratus',
3
+ 2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias',
4
+ 3: 'tiger shark, Galeocerdo cuvieri',
5
+ 4: 'hammerhead, hammerhead shark',
6
+ 5: 'electric ray, crampfish, numbfish, torpedo',
7
+ 6: 'stingray',
8
+ 7: 'cock',
9
+ 8: 'hen',
10
+ 9: 'ostrich, Struthio camelus',
11
+ 10: 'brambling, Fringilla montifringilla',
12
+ 11: 'goldfinch, Carduelis carduelis',
13
+ 12: 'house finch, linnet, Carpodacus mexicanus',
14
+ 13: 'junco, snowbird',
15
+ 14: 'indigo bunting, indigo finch, indigo bird, Passerina cyanea',
16
+ 15: 'robin, American robin, Turdus migratorius',
17
+ 16: 'bulbul',
18
+ 17: 'jay',
19
+ 18: 'magpie',
20
+ 19: 'chickadee',
21
+ 20: 'water ouzel, dipper',
22
+ 21: 'kite',
23
+ 22: 'bald eagle, American eagle, Haliaeetus leucocephalus',
24
+ 23: 'vulture',
25
+ 24: 'great grey owl, great gray owl, Strix nebulosa',
26
+ 25: 'European fire salamander, Salamandra salamandra',
27
+ 26: 'common newt, Triturus vulgaris',
28
+ 27: 'eft',
29
+ 28: 'spotted salamander, Ambystoma maculatum',
30
+ 29: 'axolotl, mud puppy, Ambystoma mexicanum',
31
+ 30: 'bullfrog, Rana catesbeiana',
32
+ 31: 'tree frog, tree-frog',
33
+ 32: 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui',
34
+ 33: 'loggerhead, loggerhead turtle, Caretta caretta',
35
+ 34: 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea',
36
+ 35: 'mud turtle',
37
+ 36: 'terrapin',
38
+ 37: 'box turtle, box tortoise',
39
+ 38: 'banded gecko',
40
+ 39: 'common iguana, iguana, Iguana iguana',
41
+ 40: 'American chameleon, anole, Anolis carolinensis',
42
+ 41: 'whiptail, whiptail lizard',
43
+ 42: 'agama',
44
+ 43: 'frilled lizard, Chlamydosaurus kingi',
45
+ 44: 'alligator lizard',
46
+ 45: 'Gila monster, Heloderma suspectum',
47
+ 46: 'green lizard, Lacerta viridis',
48
+ 47: 'African chameleon, Chamaeleo chamaeleon',
49
+ 48: 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis',
50
+ 49: 'African crocodile, Nile crocodile, Crocodylus niloticus',
51
+ 50: 'American alligator, Alligator mississipiensis',
52
+ 51: 'triceratops',
53
+ 52: 'thunder snake, worm snake, Carphophis amoenus',
54
+ 53: 'ringneck snake, ring-necked snake, ring snake',
55
+ 54: 'hognose snake, puff adder, sand viper',
56
+ 55: 'green snake, grass snake',
57
+ 56: 'king snake, kingsnake',
58
+ 57: 'garter snake, grass snake',
59
+ 58: 'water snake',
60
+ 59: 'vine snake',
61
+ 60: 'night snake, Hypsiglena torquata',
62
+ 61: 'boa constrictor, Constrictor constrictor',
63
+ 62: 'rock python, rock snake, Python sebae',
64
+ 63: 'Indian cobra, Naja naja',
65
+ 64: 'green mamba',
66
+ 65: 'sea snake',
67
+ 66: 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus',
68
+ 67: 'diamondback, diamondback rattlesnake, Crotalus adamanteus',
69
+ 68: 'sidewinder, horned rattlesnake, Crotalus cerastes',
70
+ 69: 'trilobite',
71
+ 70: 'harvestman, daddy longlegs, Phalangium opilio',
72
+ 71: 'scorpion',
73
+ 72: 'black and gold garden spider, Argiope aurantia',
74
+ 73: 'barn spider, Araneus cavaticus',
75
+ 74: 'garden spider, Aranea diademata',
76
+ 75: 'black widow, Latrodectus mactans',
77
+ 76: 'tarantula',
78
+ 77: 'wolf spider, hunting spider',
79
+ 78: 'tick',
80
+ 79: 'centipede',
81
+ 80: 'black grouse',
82
+ 81: 'ptarmigan',
83
+ 82: 'ruffed grouse, partridge, Bonasa umbellus',
84
+ 83: 'prairie chicken, prairie grouse, prairie fowl',
85
+ 84: 'peacock',
86
+ 85: 'quail',
87
+ 86: 'partridge',
88
+ 87: 'African grey, African gray, Psittacus erithacus',
89
+ 88: 'macaw',
90
+ 89: 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita',
91
+ 90: 'lorikeet',
92
+ 91: 'coucal',
93
+ 92: 'bee eater',
94
+ 93: 'hornbill',
95
+ 94: 'hummingbird',
96
+ 95: 'jacamar',
97
+ 96: 'toucan',
98
+ 97: 'drake',
99
+ 98: 'red-breasted merganser, Mergus serrator',
100
+ 99: 'goose',
101
+ 100: 'black swan, Cygnus atratus',
102
+ 101: 'tusker',
103
+ 102: 'echidna, spiny anteater, anteater',
104
+ 103: 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus',
105
+ 104: 'wallaby, brush kangaroo',
106
+ 105: 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus',
107
+ 106: 'wombat',
108
+ 107: 'jellyfish',
109
+ 108: 'sea anemone, anemone',
110
+ 109: 'brain coral',
111
+ 110: 'flatworm, platyhelminth',
112
+ 111: 'nematode, nematode worm, roundworm',
113
+ 112: 'conch',
114
+ 113: 'snail',
115
+ 114: 'slug',
116
+ 115: 'sea slug, nudibranch',
117
+ 116: 'chiton, coat-of-mail shell, sea cradle, polyplacophore',
118
+ 117: 'chambered nautilus, pearly nautilus, nautilus',
119
+ 118: 'Dungeness crab, Cancer magister',
120
+ 119: 'rock crab, Cancer irroratus',
121
+ 120: 'fiddler crab',
122
+ 121: 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica',
123
+ 122: 'American lobster, Northern lobster, Maine lobster, Homarus americanus',
124
+ 123: 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish',
125
+ 124: 'crayfish, crawfish, crawdad, crawdaddy',
126
+ 125: 'hermit crab',
127
+ 126: 'isopod',
128
+ 127: 'white stork, Ciconia ciconia',
129
+ 128: 'black stork, Ciconia nigra',
130
+ 129: 'spoonbill',
131
+ 130: 'flamingo',
132
+ 131: 'little blue heron, Egretta caerulea',
133
+ 132: 'American egret, great white heron, Egretta albus',
134
+ 133: 'bittern',
135
+ 134: 'crane',
136
+ 135: 'limpkin, Aramus pictus',
137
+ 136: 'European gallinule, Porphyrio porphyrio',
138
+ 137: 'American coot, marsh hen, mud hen, water hen, Fulica americana',
139
+ 138: 'bustard',
140
+ 139: 'ruddy turnstone, Arenaria interpres',
141
+ 140: 'red-backed sandpiper, dunlin, Erolia alpina',
142
+ 141: 'redshank, Tringa totanus',
143
+ 142: 'dowitcher',
144
+ 143: 'oystercatcher, oyster catcher',
145
+ 144: 'pelican',
146
+ 145: 'king penguin, Aptenodytes patagonica',
147
+ 146: 'albatross, mollymawk',
148
+ 147: 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus',
149
+ 148: 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca',
150
+ 149: 'dugong, Dugong dugon',
151
+ 150: 'sea lion',
152
+ 151: 'Chihuahua',
153
+ 152: 'Japanese spaniel',
154
+ 153: 'Maltese dog, Maltese terrier, Maltese',
155
+ 154: 'Pekinese, Pekingese, Peke',
156
+ 155: 'Shih-Tzu',
157
+ 156: 'Blenheim spaniel',
158
+ 157: 'papillon',
159
+ 158: 'toy terrier',
160
+ 159: 'Rhodesian ridgeback',
161
+ 160: 'Afghan hound, Afghan',
162
+ 161: 'basset, basset hound',
163
+ 162: 'beagle',
164
+ 163: 'bloodhound, sleuthhound',
165
+ 164: 'bluetick',
166
+ 165: 'black-and-tan coonhound',
167
+ 166: 'Walker hound, Walker foxhound',
168
+ 167: 'English foxhound',
169
+ 168: 'redbone',
170
+ 169: 'borzoi, Russian wolfhound',
171
+ 170: 'Irish wolfhound',
172
+ 171: 'Italian greyhound',
173
+ 172: 'whippet',
174
+ 173: 'Ibizan hound, Ibizan Podenco',
175
+ 174: 'Norwegian elkhound, elkhound',
176
+ 175: 'otterhound, otter hound',
177
+ 176: 'Saluki, gazelle hound',
178
+ 177: 'Scottish deerhound, deerhound',
179
+ 178: 'Weimaraner',
180
+ 179: 'Staffordshire bullterrier, Staffordshire bull terrier',
181
+ 180: 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier',
182
+ 181: 'Bedlington terrier',
183
+ 182: 'Border terrier',
184
+ 183: 'Kerry blue terrier',
185
+ 184: 'Irish terrier',
186
+ 185: 'Norfolk terrier',
187
+ 186: 'Norwich terrier',
188
+ 187: 'Yorkshire terrier',
189
+ 188: 'wire-haired fox terrier',
190
+ 189: 'Lakeland terrier',
191
+ 190: 'Sealyham terrier, Sealyham',
192
+ 191: 'Airedale, Airedale terrier',
193
+ 192: 'cairn, cairn terrier',
194
+ 193: 'Australian terrier',
195
+ 194: 'Dandie Dinmont, Dandie Dinmont terrier',
196
+ 195: 'Boston bull, Boston terrier',
197
+ 196: 'miniature schnauzer',
198
+ 197: 'giant schnauzer',
199
+ 198: 'standard schnauzer',
200
+ 199: 'Scotch terrier, Scottish terrier, Scottie',
201
+ 200: 'Tibetan terrier, chrysanthemum dog',
202
+ 201: 'silky terrier, Sydney silky',
203
+ 202: 'soft-coated wheaten terrier',
204
+ 203: 'West Highland white terrier',
205
+ 204: 'Lhasa, Lhasa apso',
206
+ 205: 'flat-coated retriever',
207
+ 206: 'curly-coated retriever',
208
+ 207: 'golden retriever',
209
+ 208: 'Labrador retriever',
210
+ 209: 'Chesapeake Bay retriever',
211
+ 210: 'German short-haired pointer',
212
+ 211: 'vizsla, Hungarian pointer',
213
+ 212: 'English setter',
214
+ 213: 'Irish setter, red setter',
215
+ 214: 'Gordon setter',
216
+ 215: 'Brittany spaniel',
217
+ 216: 'clumber, clumber spaniel',
218
+ 217: 'English springer, English springer spaniel',
219
+ 218: 'Welsh springer spaniel',
220
+ 219: 'cocker spaniel, English cocker spaniel, cocker',
221
+ 220: 'Sussex spaniel',
222
+ 221: 'Irish water spaniel',
223
+ 222: 'kuvasz',
224
+ 223: 'schipperke',
225
+ 224: 'groenendael',
226
+ 225: 'malinois',
227
+ 226: 'briard',
228
+ 227: 'kelpie',
229
+ 228: 'komondor',
230
+ 229: 'Old English sheepdog, bobtail',
231
+ 230: 'Shetland sheepdog, Shetland sheep dog, Shetland',
232
+ 231: 'collie',
233
+ 232: 'Border collie',
234
+ 233: 'Bouvier des Flandres, Bouviers des Flandres',
235
+ 234: 'Rottweiler',
236
+ 235: 'German shepherd, German shepherd dog, German police dog, alsatian',
237
+ 236: 'Doberman, Doberman pinscher',
238
+ 237: 'miniature pinscher',
239
+ 238: 'Greater Swiss Mountain dog',
240
+ 239: 'Bernese mountain dog',
241
+ 240: 'Appenzeller',
242
+ 241: 'EntleBucher',
243
+ 242: 'boxer',
244
+ 243: 'bull mastiff',
245
+ 244: 'Tibetan mastiff',
246
+ 245: 'French bulldog',
247
+ 246: 'Great Dane',
248
+ 247: 'Saint Bernard, St Bernard',
249
+ 248: 'Eskimo dog, husky',
250
+ 249: 'malamute, malemute, Alaskan malamute',
251
+ 250: 'Siberian husky',
252
+ 251: 'dalmatian, coach dog, carriage dog',
253
+ 252: 'affenpinscher, monkey pinscher, monkey dog',
254
+ 253: 'basenji',
255
+ 254: 'pug, pug-dog',
256
+ 255: 'Leonberg',
257
+ 256: 'Newfoundland, Newfoundland dog',
258
+ 257: 'Great Pyrenees',
259
+ 258: 'Samoyed, Samoyede',
260
+ 259: 'Pomeranian',
261
+ 260: 'chow, chow chow',
262
+ 261: 'keeshond',
263
+ 262: 'Brabancon griffon',
264
+ 263: 'Pembroke, Pembroke Welsh corgi',
265
+ 264: 'Cardigan, Cardigan Welsh corgi',
266
+ 265: 'toy poodle',
267
+ 266: 'miniature poodle',
268
+ 267: 'standard poodle',
269
+ 268: 'Mexican hairless',
270
+ 269: 'timber wolf, grey wolf, gray wolf, Canis lupus',
271
+ 270: 'white wolf, Arctic wolf, Canis lupus tundrarum',
272
+ 271: 'red wolf, maned wolf, Canis rufus, Canis niger',
273
+ 272: 'coyote, prairie wolf, brush wolf, Canis latrans',
274
+ 273: 'dingo, warrigal, warragal, Canis dingo',
275
+ 274: 'dhole, Cuon alpinus',
276
+ 275: 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus',
277
+ 276: 'hyena, hyaena',
278
+ 277: 'red fox, Vulpes vulpes',
279
+ 278: 'kit fox, Vulpes macrotis',
280
+ 279: 'Arctic fox, white fox, Alopex lagopus',
281
+ 280: 'grey fox, gray fox, Urocyon cinereoargenteus',
282
+ 281: 'tabby, tabby cat',
283
+ 282: 'tiger cat',
284
+ 283: 'Persian cat',
285
+ 284: 'Siamese cat, Siamese',
286
+ 285: 'Egyptian cat',
287
+ 286: 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor',
288
+ 287: 'lynx, catamount',
289
+ 288: 'leopard, Panthera pardus',
290
+ 289: 'snow leopard, ounce, Panthera uncia',
291
+ 290: 'jaguar, panther, Panthera onca, Felis onca',
292
+ 291: 'lion, king of beasts, Panthera leo',
293
+ 292: 'tiger, Panthera tigris',
294
+ 293: 'cheetah, chetah, Acinonyx jubatus',
295
+ 294: 'brown bear, bruin, Ursus arctos',
296
+ 295: 'American black bear, black bear, Ursus americanus, Euarctos americanus',
297
+ 296: 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus',
298
+ 297: 'sloth bear, Melursus ursinus, Ursus ursinus',
299
+ 298: 'mongoose',
300
+ 299: 'meerkat, mierkat',
301
+ 300: 'tiger beetle',
302
+ 301: 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle',
303
+ 302: 'ground beetle, carabid beetle',
304
+ 303: 'long-horned beetle, longicorn, longicorn beetle',
305
+ 304: 'leaf beetle, chrysomelid',
306
+ 305: 'dung beetle',
307
+ 306: 'rhinoceros beetle',
308
+ 307: 'weevil',
309
+ 308: 'fly',
310
+ 309: 'bee',
311
+ 310: 'ant, emmet, pismire',
312
+ 311: 'grasshopper, hopper',
313
+ 312: 'cricket',
314
+ 313: 'walking stick, walkingstick, stick insect',
315
+ 314: 'cockroach, roach',
316
+ 315: 'mantis, mantid',
317
+ 316: 'cicada, cicala',
318
+ 317: 'leafhopper',
319
+ 318: 'lacewing, lacewing fly',
320
+ 319: "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
321
+ 320: 'damselfly',
322
+ 321: 'admiral',
323
+ 322: 'ringlet, ringlet butterfly',
324
+ 323: 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus',
325
+ 324: 'cabbage butterfly',
326
+ 325: 'sulphur butterfly, sulfur butterfly',
327
+ 326: 'lycaenid, lycaenid butterfly',
328
+ 327: 'starfish, sea star',
329
+ 328: 'sea urchin',
330
+ 329: 'sea cucumber, holothurian',
331
+ 330: 'wood rabbit, cottontail, cottontail rabbit',
332
+ 331: 'hare',
333
+ 332: 'Angora, Angora rabbit',
334
+ 333: 'hamster',
335
+ 334: 'porcupine, hedgehog',
336
+ 335: 'fox squirrel, eastern fox squirrel, Sciurus niger',
337
+ 336: 'marmot',
338
+ 337: 'beaver',
339
+ 338: 'guinea pig, Cavia cobaya',
340
+ 339: 'sorrel',
341
+ 340: 'zebra',
342
+ 341: 'hog, pig, grunter, squealer, Sus scrofa',
343
+ 342: 'wild boar, boar, Sus scrofa',
344
+ 343: 'warthog',
345
+ 344: 'hippopotamus, hippo, river horse, Hippopotamus amphibius',
346
+ 345: 'ox',
347
+ 346: 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis',
348
+ 347: 'bison',
349
+ 348: 'ram, tup',
350
+ 349: 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis',
351
+ 350: 'ibex, Capra ibex',
352
+ 351: 'hartebeest',
353
+ 352: 'impala, Aepyceros melampus',
354
+ 353: 'gazelle',
355
+ 354: 'Arabian camel, dromedary, Camelus dromedarius',
356
+ 355: 'llama',
357
+ 356: 'weasel',
358
+ 357: 'mink',
359
+ 358: 'polecat, fitch, foulmart, foumart, Mustela putorius',
360
+ 359: 'black-footed ferret, ferret, Mustela nigripes',
361
+ 360: 'otter',
362
+ 361: 'skunk, polecat, wood pussy',
363
+ 362: 'badger',
364
+ 363: 'armadillo',
365
+ 364: 'three-toed sloth, ai, Bradypus tridactylus',
366
+ 365: 'orangutan, orang, orangutang, Pongo pygmaeus',
367
+ 366: 'gorilla, Gorilla gorilla',
368
+ 367: 'chimpanzee, chimp, Pan troglodytes',
369
+ 368: 'gibbon, Hylobates lar',
370
+ 369: 'siamang, Hylobates syndactylus, Symphalangus syndactylus',
371
+ 370: 'guenon, guenon monkey',
372
+ 371: 'patas, hussar monkey, Erythrocebus patas',
373
+ 372: 'baboon',
374
+ 373: 'macaque',
375
+ 374: 'langur',
376
+ 375: 'colobus, colobus monkey',
377
+ 376: 'proboscis monkey, Nasalis larvatus',
378
+ 377: 'marmoset',
379
+ 378: 'capuchin, ringtail, Cebus capucinus',
380
+ 379: 'howler monkey, howler',
381
+ 380: 'titi, titi monkey',
382
+ 381: 'spider monkey, Ateles geoffroyi',
383
+ 382: 'squirrel monkey, Saimiri sciureus',
384
+ 383: 'Madagascar cat, ring-tailed lemur, Lemur catta',
385
+ 384: 'indri, indris, Indri indri, Indri brevicaudatus',
386
+ 385: 'Indian elephant, Elephas maximus',
387
+ 386: 'African elephant, Loxodonta africana',
388
+ 387: 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens',
389
+ 388: 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca',
390
+ 389: 'barracouta, snoek',
391
+ 390: 'eel',
392
+ 391: 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch',
393
+ 392: 'rock beauty, Holocanthus tricolor',
394
+ 393: 'anemone fish',
395
+ 394: 'sturgeon',
396
+ 395: 'gar, garfish, garpike, billfish, Lepisosteus osseus',
397
+ 396: 'lionfish',
398
+ 397: 'puffer, pufferfish, blowfish, globefish',
399
+ 398: 'abacus',
400
+ 399: 'abaya',
401
+ 400: "academic gown, academic robe, judge's robe",
402
+ 401: 'accordion, piano accordion, squeeze box',
403
+ 402: 'acoustic guitar',
404
+ 403: 'aircraft carrier, carrier, flattop, attack aircraft carrier',
405
+ 404: 'airliner',
406
+ 405: 'airship, dirigible',
407
+ 406: 'altar',
408
+ 407: 'ambulance',
409
+ 408: 'amphibian, amphibious vehicle',
410
+ 409: 'analog clock',
411
+ 410: 'apiary, bee house',
412
+ 411: 'apron',
413
+ 412: 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin',
414
+ 413: 'assault rifle, assault gun',
415
+ 414: 'backpack, back pack, knapsack, packsack, rucksack, haversack',
416
+ 415: 'bakery, bakeshop, bakehouse',
417
+ 416: 'balance beam, beam',
418
+ 417: 'balloon',
419
+ 418: 'ballpoint, ballpoint pen, ballpen, Biro',
420
+ 419: 'Band Aid',
421
+ 420: 'banjo',
422
+ 421: 'bannister, banister, balustrade, balusters, handrail',
423
+ 422: 'barbell',
424
+ 423: 'barber chair',
425
+ 424: 'barbershop',
426
+ 425: 'barn',
427
+ 426: 'barometer',
428
+ 427: 'barrel, cask',
429
+ 428: 'barrow, garden cart, lawn cart, wheelbarrow',
430
+ 429: 'baseball',
431
+ 430: 'basketball',
432
+ 431: 'bassinet',
433
+ 432: 'bassoon',
434
+ 433: 'bathing cap, swimming cap',
435
+ 434: 'bath towel',
436
+ 435: 'bathtub, bathing tub, bath, tub',
437
+ 436: 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon',
438
+ 437: 'beacon, lighthouse, beacon light, pharos',
439
+ 438: 'beaker',
440
+ 439: 'bearskin, busby, shako',
441
+ 440: 'beer bottle',
442
+ 441: 'beer glass',
443
+ 442: 'bell cote, bell cot',
444
+ 443: 'bib',
445
+ 444: 'bicycle-built-for-two, tandem bicycle, tandem',
446
+ 445: 'bikini, two-piece',
447
+ 446: 'binder, ring-binder',
448
+ 447: 'binoculars, field glasses, opera glasses',
449
+ 448: 'birdhouse',
450
+ 449: 'boathouse',
451
+ 450: 'bobsled, bobsleigh, bob',
452
+ 451: 'bolo tie, bolo, bola tie, bola',
453
+ 452: 'bonnet, poke bonnet',
454
+ 453: 'bookcase',
455
+ 454: 'bookshop, bookstore, bookstall',
456
+ 455: 'bottlecap',
457
+ 456: 'bow',
458
+ 457: 'bow tie, bow-tie, bowtie',
459
+ 458: 'brass, memorial tablet, plaque',
460
+ 459: 'brassiere, bra, bandeau',
461
+ 460: 'breakwater, groin, groyne, mole, bulwark, seawall, jetty',
462
+ 461: 'breastplate, aegis, egis',
463
+ 462: 'broom',
464
+ 463: 'bucket, pail',
465
+ 464: 'buckle',
466
+ 465: 'bulletproof vest',
467
+ 466: 'bullet train, bullet',
468
+ 467: 'butcher shop, meat market',
469
+ 468: 'cab, hack, taxi, taxicab',
470
+ 469: 'caldron, cauldron',
471
+ 470: 'candle, taper, wax light',
472
+ 471: 'cannon',
473
+ 472: 'canoe',
474
+ 473: 'can opener, tin opener',
475
+ 474: 'cardigan',
476
+ 475: 'car mirror',
477
+ 476: 'carousel, carrousel, merry-go-round, roundabout, whirligig',
478
+ 477: "carpenter's kit, tool kit",
479
+ 478: 'carton',
480
+ 479: 'car wheel',
481
+ 480: 'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM',
482
+ 481: 'cassette',
483
+ 482: 'cassette player',
484
+ 483: 'castle',
485
+ 484: 'catamaran',
486
+ 485: 'CD player',
487
+ 486: 'cello, violoncello',
488
+ 487: 'cellular telephone, cellular phone, cellphone, cell, mobile phone',
489
+ 488: 'chain',
490
+ 489: 'chainlink fence',
491
+ 490: 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour',
492
+ 491: 'chain saw, chainsaw',
493
+ 492: 'chest',
494
+ 493: 'chiffonier, commode',
495
+ 494: 'chime, bell, gong',
496
+ 495: 'china cabinet, china closet',
497
+ 496: 'Christmas stocking',
498
+ 497: 'church, church building',
499
+ 498: 'cinema, movie theater, movie theatre, movie house, picture palace',
500
+ 499: 'cleaver, meat cleaver, chopper',
501
+ 500: 'cliff dwelling',
502
+ 501: 'cloak',
503
+ 502: 'clog, geta, patten, sabot',
504
+ 503: 'cocktail shaker',
505
+ 504: 'coffee mug',
506
+ 505: 'coffeepot',
507
+ 506: 'coil, spiral, volute, whorl, helix',
508
+ 507: 'combination lock',
509
+ 508: 'computer keyboard, keypad',
510
+ 509: 'confectionery, confectionary, candy store',
511
+ 510: 'container ship, containership, container vessel',
512
+ 511: 'convertible',
513
+ 512: 'corkscrew, bottle screw',
514
+ 513: 'cornet, horn, trumpet, trump',
515
+ 514: 'cowboy boot',
516
+ 515: 'cowboy hat, ten-gallon hat',
517
+ 516: 'cradle',
518
+ 517: 'crane',
519
+ 518: 'crash helmet',
520
+ 519: 'crate',
521
+ 520: 'crib, cot',
522
+ 521: 'Crock Pot',
523
+ 522: 'croquet ball',
524
+ 523: 'crutch',
525
+ 524: 'cuirass',
526
+ 525: 'dam, dike, dyke',
527
+ 526: 'desk',
528
+ 527: 'desktop computer',
529
+ 528: 'dial telephone, dial phone',
530
+ 529: 'diaper, nappy, napkin',
531
+ 530: 'digital clock',
532
+ 531: 'digital watch',
533
+ 532: 'dining table, board',
534
+ 533: 'dishrag, dishcloth',
535
+ 534: 'dishwasher, dish washer, dishwashing machine',
536
+ 535: 'disk brake, disc brake',
537
+ 536: 'dock, dockage, docking facility',
538
+ 537: 'dogsled, dog sled, dog sleigh',
539
+ 538: 'dome',
540
+ 539: 'doormat, welcome mat',
541
+ 540: 'drilling platform, offshore rig',
542
+ 541: 'drum, membranophone, tympan',
543
+ 542: 'drumstick',
544
+ 543: 'dumbbell',
545
+ 544: 'Dutch oven',
546
+ 545: 'electric fan, blower',
547
+ 546: 'electric guitar',
548
+ 547: 'electric locomotive',
549
+ 548: 'entertainment center',
550
+ 549: 'envelope',
551
+ 550: 'espresso maker',
552
+ 551: 'face powder',
553
+ 552: 'feather boa, boa',
554
+ 553: 'file, file cabinet, filing cabinet',
555
+ 554: 'fireboat',
556
+ 555: 'fire engine, fire truck',
557
+ 556: 'fire screen, fireguard',
558
+ 557: 'flagpole, flagstaff',
559
+ 558: 'flute, transverse flute',
560
+ 559: 'folding chair',
561
+ 560: 'football helmet',
562
+ 561: 'forklift',
563
+ 562: 'fountain',
564
+ 563: 'fountain pen',
565
+ 564: 'four-poster',
566
+ 565: 'freight car',
567
+ 566: 'French horn, horn',
568
+ 567: 'frying pan, frypan, skillet',
569
+ 568: 'fur coat',
570
+ 569: 'garbage truck, dustcart',
571
+ 570: 'gasmask, respirator, gas helmet',
572
+ 571: 'gas pump, gasoline pump, petrol pump, island dispenser',
573
+ 572: 'goblet',
574
+ 573: 'go-kart',
575
+ 574: 'golf ball',
576
+ 575: 'golfcart, golf cart',
577
+ 576: 'gondola',
578
+ 577: 'gong, tam-tam',
579
+ 578: 'gown',
580
+ 579: 'grand piano, grand',
581
+ 580: 'greenhouse, nursery, glasshouse',
582
+ 581: 'grille, radiator grille',
583
+ 582: 'grocery store, grocery, food market, market',
584
+ 583: 'guillotine',
585
+ 584: 'hair slide',
586
+ 585: 'hair spray',
587
+ 586: 'half track',
588
+ 587: 'hammer',
589
+ 588: 'hamper',
590
+ 589: 'hand blower, blow dryer, blow drier, hair dryer, hair drier',
591
+ 590: 'hand-held computer, hand-held microcomputer',
592
+ 591: 'handkerchief, hankie, hanky, hankey',
593
+ 592: 'hard disc, hard disk, fixed disk',
594
+ 593: 'harmonica, mouth organ, harp, mouth harp',
595
+ 594: 'harp',
596
+ 595: 'harvester, reaper',
597
+ 596: 'hatchet',
598
+ 597: 'holster',
599
+ 598: 'home theater, home theatre',
600
+ 599: 'honeycomb',
601
+ 600: 'hook, claw',
602
+ 601: 'hoopskirt, crinoline',
603
+ 602: 'horizontal bar, high bar',
604
+ 603: 'horse cart, horse-cart',
605
+ 604: 'hourglass',
606
+ 605: 'iPod',
607
+ 606: 'iron, smoothing iron',
608
+ 607: "jack-o'-lantern",
609
+ 608: 'jean, blue jean, denim',
610
+ 609: 'jeep, landrover',
611
+ 610: 'jersey, T-shirt, tee shirt',
612
+ 611: 'jigsaw puzzle',
613
+ 612: 'jinrikisha, ricksha, rickshaw',
614
+ 613: 'joystick',
615
+ 614: 'kimono',
616
+ 615: 'knee pad',
617
+ 616: 'knot',
618
+ 617: 'lab coat, laboratory coat',
619
+ 618: 'ladle',
620
+ 619: 'lampshade, lamp shade',
621
+ 620: 'laptop, laptop computer',
622
+ 621: 'lawn mower, mower',
623
+ 622: 'lens cap, lens cover',
624
+ 623: 'letter opener, paper knife, paperknife',
625
+ 624: 'library',
626
+ 625: 'lifeboat',
627
+ 626: 'lighter, light, igniter, ignitor',
628
+ 627: 'limousine, limo',
629
+ 628: 'liner, ocean liner',
630
+ 629: 'lipstick, lip rouge',
631
+ 630: 'Loafer',
632
+ 631: 'lotion',
633
+ 632: 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system',
634
+ 633: "loupe, jeweler's loupe",
635
+ 634: 'lumbermill, sawmill',
636
+ 635: 'magnetic compass',
637
+ 636: 'mailbag, postbag',
638
+ 637: 'mailbox, letter box',
639
+ 638: 'maillot',
640
+ 639: 'maillot, tank suit',
641
+ 640: 'manhole cover',
642
+ 641: 'maraca',
643
+ 642: 'marimba, xylophone',
644
+ 643: 'mask',
645
+ 644: 'matchstick',
646
+ 645: 'maypole',
647
+ 646: 'maze, labyrinth',
648
+ 647: 'measuring cup',
649
+ 648: 'medicine chest, medicine cabinet',
650
+ 649: 'megalith, megalithic structure',
651
+ 650: 'microphone, mike',
652
+ 651: 'microwave, microwave oven',
653
+ 652: 'military uniform',
654
+ 653: 'milk can',
655
+ 654: 'minibus',
656
+ 655: 'miniskirt, mini',
657
+ 656: 'minivan',
658
+ 657: 'missile',
659
+ 658: 'mitten',
660
+ 659: 'mixing bowl',
661
+ 660: 'mobile home, manufactured home',
662
+ 661: 'Model T',
663
+ 662: 'modem',
664
+ 663: 'monastery',
665
+ 664: 'monitor',
666
+ 665: 'moped',
667
+ 666: 'mortar',
668
+ 667: 'mortarboard',
669
+ 668: 'mosque',
670
+ 669: 'mosquito net',
671
+ 670: 'motor scooter, scooter',
672
+ 671: 'mountain bike, all-terrain bike, off-roader',
673
+ 672: 'mountain tent',
674
+ 673: 'mouse, computer mouse',
675
+ 674: 'mousetrap',
676
+ 675: 'moving van',
677
+ 676: 'muzzle',
678
+ 677: 'nail',
679
+ 678: 'neck brace',
680
+ 679: 'necklace',
681
+ 680: 'nipple',
682
+ 681: 'notebook, notebook computer',
683
+ 682: 'obelisk',
684
+ 683: 'oboe, hautboy, hautbois',
685
+ 684: 'ocarina, sweet potato',
686
+ 685: 'odometer, hodometer, mileometer, milometer',
687
+ 686: 'oil filter',
688
+ 687: 'organ, pipe organ',
689
+ 688: 'oscilloscope, scope, cathode-ray oscilloscope, CRO',
690
+ 689: 'overskirt',
691
+ 690: 'oxcart',
692
+ 691: 'oxygen mask',
693
+ 692: 'packet',
694
+ 693: 'paddle, boat paddle',
695
+ 694: 'paddlewheel, paddle wheel',
696
+ 695: 'padlock',
697
+ 696: 'paintbrush',
698
+ 697: "pajama, pyjama, pj's, jammies",
699
+ 698: 'palace',
700
+ 699: 'panpipe, pandean pipe, syrinx',
701
+ 700: 'paper towel',
702
+ 701: 'parachute, chute',
703
+ 702: 'parallel bars, bars',
704
+ 703: 'park bench',
705
+ 704: 'parking meter',
706
+ 705: 'passenger car, coach, carriage',
707
+ 706: 'patio, terrace',
708
+ 707: 'pay-phone, pay-station',
709
+ 708: 'pedestal, plinth, footstall',
710
+ 709: 'pencil box, pencil case',
711
+ 710: 'pencil sharpener',
712
+ 711: 'perfume, essence',
713
+ 712: 'Petri dish',
714
+ 713: 'photocopier',
715
+ 714: 'pick, plectrum, plectron',
716
+ 715: 'pickelhaube',
717
+ 716: 'picket fence, paling',
718
+ 717: 'pickup, pickup truck',
719
+ 718: 'pier',
720
+ 719: 'piggy bank, penny bank',
721
+ 720: 'pill bottle',
722
+ 721: 'pillow',
723
+ 722: 'ping-pong ball',
724
+ 723: 'pinwheel',
725
+ 724: 'pirate, pirate ship',
726
+ 725: 'pitcher, ewer',
727
+ 726: "plane, carpenter's plane, woodworking plane",
728
+ 727: 'planetarium',
729
+ 728: 'plastic bag',
730
+ 729: 'plate rack',
731
+ 730: 'plow, plough',
732
+ 731: "plunger, plumber's helper",
733
+ 732: 'Polaroid camera, Polaroid Land camera',
734
+ 733: 'pole',
735
+ 734: 'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria',
736
+ 735: 'poncho',
737
+ 736: 'pool table, billiard table, snooker table',
738
+ 737: 'pop bottle, soda bottle',
739
+ 738: 'pot, flowerpot',
740
+ 739: "potter's wheel",
741
+ 740: 'power drill',
742
+ 741: 'prayer rug, prayer mat',
743
+ 742: 'printer',
744
+ 743: 'prison, prison house',
745
+ 744: 'projectile, missile',
746
+ 745: 'projector',
747
+ 746: 'puck, hockey puck',
748
+ 747: 'punching bag, punch bag, punching ball, punchball',
749
+ 748: 'purse',
750
+ 749: 'quill, quill pen',
751
+ 750: 'quilt, comforter, comfort, puff',
752
+ 751: 'racer, race car, racing car',
753
+ 752: 'racket, racquet',
754
+ 753: 'radiator',
755
+ 754: 'radio, wireless',
756
+ 755: 'radio telescope, radio reflector',
757
+ 756: 'rain barrel',
758
+ 757: 'recreational vehicle, RV, R.V.',
759
+ 758: 'reel',
760
+ 759: 'reflex camera',
761
+ 760: 'refrigerator, icebox',
762
+ 761: 'remote control, remote',
763
+ 762: 'restaurant, eating house, eating place, eatery',
764
+ 763: 'revolver, six-gun, six-shooter',
765
+ 764: 'rifle',
766
+ 765: 'rocking chair, rocker',
767
+ 766: 'rotisserie',
768
+ 767: 'rubber eraser, rubber, pencil eraser',
769
+ 768: 'rugby ball',
770
+ 769: 'rule, ruler',
771
+ 770: 'running shoe',
772
+ 771: 'safe',
773
+ 772: 'safety pin',
774
+ 773: 'saltshaker, salt shaker',
775
+ 774: 'sandal',
776
+ 775: 'sarong',
777
+ 776: 'sax, saxophone',
778
+ 777: 'scabbard',
779
+ 778: 'scale, weighing machine',
780
+ 779: 'school bus',
781
+ 780: 'schooner',
782
+ 781: 'scoreboard',
783
+ 782: 'screen, CRT screen',
784
+ 783: 'screw',
785
+ 784: 'screwdriver',
786
+ 785: 'seat belt, seatbelt',
787
+ 786: 'sewing machine',
788
+ 787: 'shield, buckler',
789
+ 788: 'shoe shop, shoe-shop, shoe store',
790
+ 789: 'shoji',
791
+ 790: 'shopping basket',
792
+ 791: 'shopping cart',
793
+ 792: 'shovel',
794
+ 793: 'shower cap',
795
+ 794: 'shower curtain',
796
+ 795: 'ski',
797
+ 796: 'ski mask',
798
+ 797: 'sleeping bag',
799
+ 798: 'slide rule, slipstick',
800
+ 799: 'sliding door',
801
+ 800: 'slot, one-armed bandit',
802
+ 801: 'snorkel',
803
+ 802: 'snowmobile',
804
+ 803: 'snowplow, snowplough',
805
+ 804: 'soap dispenser',
806
+ 805: 'soccer ball',
807
+ 806: 'sock',
808
+ 807: 'solar dish, solar collector, solar furnace',
809
+ 808: 'sombrero',
810
+ 809: 'soup bowl',
811
+ 810: 'space bar',
812
+ 811: 'space heater',
813
+ 812: 'space shuttle',
814
+ 813: 'spatula',
815
+ 814: 'speedboat',
816
+ 815: "spider web, spider's web",
817
+ 816: 'spindle',
818
+ 817: 'sports car, sport car',
819
+ 818: 'spotlight, spot',
820
+ 819: 'stage',
821
+ 820: 'steam locomotive',
822
+ 821: 'steel arch bridge',
823
+ 822: 'steel drum',
824
+ 823: 'stethoscope',
825
+ 824: 'stole',
826
+ 825: 'stone wall',
827
+ 826: 'stopwatch, stop watch',
828
+ 827: 'stove',
829
+ 828: 'strainer',
830
+ 829: 'streetcar, tram, tramcar, trolley, trolley car',
831
+ 830: 'stretcher',
832
+ 831: 'studio couch, day bed',
833
+ 832: 'stupa, tope',
834
+ 833: 'submarine, pigboat, sub, U-boat',
835
+ 834: 'suit, suit of clothes',
836
+ 835: 'sundial',
837
+ 836: 'sunglass',
838
+ 837: 'sunglasses, dark glasses, shades',
839
+ 838: 'sunscreen, sunblock, sun blocker',
840
+ 839: 'suspension bridge',
841
+ 840: 'swab, swob, mop',
842
+ 841: 'sweatshirt',
843
+ 842: 'swimming trunks, bathing trunks',
844
+ 843: 'swing',
845
+ 844: 'switch, electric switch, electrical switch',
846
+ 845: 'syringe',
847
+ 846: 'table lamp',
848
+ 847: 'tank, army tank, armored combat vehicle, armoured combat vehicle',
849
+ 848: 'tape player',
850
+ 849: 'teapot',
851
+ 850: 'teddy, teddy bear',
852
+ 851: 'television, television system',
853
+ 852: 'tennis ball',
854
+ 853: 'thatch, thatched roof',
855
+ 854: 'theater curtain, theatre curtain',
856
+ 855: 'thimble',
857
+ 856: 'thresher, thrasher, threshing machine',
858
+ 857: 'throne',
859
+ 858: 'tile roof',
860
+ 859: 'toaster',
861
+ 860: 'tobacco shop, tobacconist shop, tobacconist',
862
+ 861: 'toilet seat',
863
+ 862: 'torch',
864
+ 863: 'totem pole',
865
+ 864: 'tow truck, tow car, wrecker',
866
+ 865: 'toyshop',
867
+ 866: 'tractor',
868
+ 867: 'trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi',
869
+ 868: 'tray',
870
+ 869: 'trench coat',
871
+ 870: 'tricycle, trike, velocipede',
872
+ 871: 'trimaran',
873
+ 872: 'tripod',
874
+ 873: 'triumphal arch',
875
+ 874: 'trolleybus, trolley coach, trackless trolley',
876
+ 875: 'trombone',
877
+ 876: 'tub, vat',
878
+ 877: 'turnstile',
879
+ 878: 'typewriter keyboard',
880
+ 879: 'umbrella',
881
+ 880: 'unicycle, monocycle',
882
+ 881: 'upright, upright piano',
883
+ 882: 'vacuum, vacuum cleaner',
884
+ 883: 'vase',
885
+ 884: 'vault',
886
+ 885: 'velvet',
887
+ 886: 'vending machine',
888
+ 887: 'vestment',
889
+ 888: 'viaduct',
890
+ 889: 'violin, fiddle',
891
+ 890: 'volleyball',
892
+ 891: 'waffle iron',
893
+ 892: 'wall clock',
894
+ 893: 'wallet, billfold, notecase, pocketbook',
895
+ 894: 'wardrobe, closet, press',
896
+ 895: 'warplane, military plane',
897
+ 896: 'washbasin, handbasin, washbowl, lavabo, wash-hand basin',
898
+ 897: 'washer, automatic washer, washing machine',
899
+ 898: 'water bottle',
900
+ 899: 'water jug',
901
+ 900: 'water tower',
902
+ 901: 'whiskey jug',
903
+ 902: 'whistle',
904
+ 903: 'wig',
905
+ 904: 'window screen',
906
+ 905: 'window shade',
907
+ 906: 'Windsor tie',
908
+ 907: 'wine bottle',
909
+ 908: 'wing',
910
+ 909: 'wok',
911
+ 910: 'wooden spoon',
912
+ 911: 'wool, woolen, woollen',
913
+ 912: 'worm fence, snake fence, snake-rail fence, Virginia fence',
914
+ 913: 'wreck',
915
+ 914: 'yawl',
916
+ 915: 'yurt',
917
+ 916: 'web site, website, internet site, site',
918
+ 917: 'comic book',
919
+ 918: 'crossword puzzle, crossword',
920
+ 919: 'street sign',
921
+ 920: 'traffic light, traffic signal, stoplight',
922
+ 921: 'book jacket, dust cover, dust jacket, dust wrapper',
923
+ 922: 'menu',
924
+ 923: 'plate',
925
+ 924: 'guacamole',
926
+ 925: 'consomme',
927
+ 926: 'hot pot, hotpot',
928
+ 927: 'trifle',
929
+ 928: 'ice cream, icecream',
930
+ 929: 'ice lolly, lolly, lollipop, popsicle',
931
+ 930: 'French loaf',
932
+ 931: 'bagel, beigel',
933
+ 932: 'pretzel',
934
+ 933: 'cheeseburger',
935
+ 934: 'hotdog, hot dog, red hot',
936
+ 935: 'mashed potato',
937
+ 936: 'head cabbage',
938
+ 937: 'broccoli',
939
+ 938: 'cauliflower',
940
+ 939: 'zucchini, courgette',
941
+ 940: 'spaghetti squash',
942
+ 941: 'acorn squash',
943
+ 942: 'butternut squash',
944
+ 943: 'cucumber, cuke',
945
+ 944: 'artichoke, globe artichoke',
946
+ 945: 'bell pepper',
947
+ 946: 'cardoon',
948
+ 947: 'mushroom',
949
+ 948: 'Granny Smith',
950
+ 949: 'strawberry',
951
+ 950: 'orange',
952
+ 951: 'lemon',
953
+ 952: 'fig',
954
+ 953: 'pineapple, ananas',
955
+ 954: 'banana',
956
+ 955: 'jackfruit, jak, jack',
957
+ 956: 'custard apple',
958
+ 957: 'pomegranate',
959
+ 958: 'hay',
960
+ 959: 'carbonara',
961
+ 960: 'chocolate sauce, chocolate syrup',
962
+ 961: 'dough',
963
+ 962: 'meat loaf, meatloaf',
964
+ 963: 'pizza, pizza pie',
965
+ 964: 'potpie',
966
+ 965: 'burrito',
967
+ 966: 'red wine',
968
+ 967: 'espresso',
969
+ 968: 'cup',
970
+ 969: 'eggnog',
971
+ 970: 'alp',
972
+ 971: 'bubble',
973
+ 972: 'cliff, drop, drop-off',
974
+ 973: 'coral reef',
975
+ 974: 'geyser',
976
+ 975: 'lakeside, lakeshore',
977
+ 976: 'promontory, headland, head, foreland',
978
+ 977: 'sandbar, sand bar',
979
+ 978: 'seashore, coast, seacoast, sea-coast',
980
+ 979: 'valley, vale',
981
+ 980: 'volcano',
982
+ 981: 'ballplayer, baseball player',
983
+ 982: 'groom, bridegroom',
984
+ 983: 'scuba diver',
985
+ 984: 'rapeseed',
986
+ 985: 'daisy',
987
+ 986: "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
988
+ 987: 'corn',
989
+ 988: 'acorn',
990
+ 989: 'hip, rose hip, rosehip',
991
+ 990: 'buckeye, horse chestnut, conker',
992
+ 991: 'coral fungus',
993
+ 992: 'agaric',
994
+ 993: 'gyromitra',
995
+ 994: 'stinkhorn, carrion fungus',
996
+ 995: 'earthstar',
997
+ 996: 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa',
998
+ 997: 'bolete',
999
+ 998: 'ear, spike, capitulum',
1000
+ 999: 'toilet tissue, toilet paper, bathroom tissue'
environment.yaml CHANGED
@@ -5,9 +5,9 @@ channels:
5
  dependencies:
6
  - python=3.8.5
7
  - pip=20.3
8
- - cudatoolkit=11.0
9
- - pytorch=1.7.0
10
- - torchvision=0.8.1
11
  - numpy=1.19.2
12
  - pip:
13
  - albumentations==0.4.3
@@ -21,7 +21,7 @@ dependencies:
21
  - streamlit>=0.73.1
22
  - einops==0.3.0
23
  - torch-fidelity==0.3.0
24
- - transformers==4.3.1
25
  - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
26
  - -e git+https://github.com/openai/CLIP.git@main#egg=clip
27
- - -e .
 
5
  dependencies:
6
  - python=3.8.5
7
  - pip=20.3
8
+ - cudatoolkit=11.3
9
+ - pytorch=1.11.0
10
+ - torchvision=0.12.0
11
  - numpy=1.19.2
12
  - pip:
13
  - albumentations==0.4.3
 
21
  - streamlit>=0.73.1
22
  - einops==0.3.0
23
  - torch-fidelity==0.3.0
24
+ - transformers==4.19.2
25
  - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
26
  - -e git+https://github.com/openai/CLIP.git@main#egg=clip
27
+ - -e .
ldm/models/diffusion/ddim.py CHANGED
@@ -5,7 +5,8 @@ import numpy as np
5
  from tqdm import tqdm
6
  from functools import partial
7
 
8
- from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
 
9
 
10
 
11
  class DDIMSampler(object):
@@ -72,6 +73,9 @@ class DDIMSampler(object):
72
  verbose=True,
73
  x_T=None,
74
  log_every_t=100,
 
 
 
75
  **kwargs
76
  ):
77
  if conditioning is not None:
@@ -100,7 +104,9 @@ class DDIMSampler(object):
100
  score_corrector=score_corrector,
101
  corrector_kwargs=corrector_kwargs,
102
  x_T=x_T,
103
- log_every_t=log_every_t
 
 
104
  )
105
  return samples, intermediates
106
 
@@ -109,7 +115,8 @@ class DDIMSampler(object):
109
  x_T=None, ddim_use_original_steps=False,
110
  callback=None, timesteps=None, quantize_denoised=False,
111
  mask=None, x0=None, img_callback=None, log_every_t=100,
112
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
 
113
  device = self.model.betas.device
114
  b = shape[0]
115
  if x_T is None:
@@ -142,7 +149,9 @@ class DDIMSampler(object):
142
  outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
143
  quantize_denoised=quantize_denoised, temperature=temperature,
144
  noise_dropout=noise_dropout, score_corrector=score_corrector,
145
- corrector_kwargs=corrector_kwargs)
 
 
146
  img, pred_x0 = outs
147
  if callback: callback(i)
148
  if img_callback: img_callback(pred_x0, i)
@@ -155,9 +164,19 @@ class DDIMSampler(object):
155
 
156
  @torch.no_grad()
157
  def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
158
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
 
159
  b, *_, device = *x.shape, x.device
160
- e_t = self.model.apply_model(x, t, c)
 
 
 
 
 
 
 
 
 
161
  if score_corrector is not None:
162
  assert self.model.parameterization == "eps"
163
  e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
@@ -183,3 +202,40 @@ class DDIMSampler(object):
183
  noise = torch.nn.functional.dropout(noise, p=noise_dropout)
184
  x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
185
  return x_prev, pred_x0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  from tqdm import tqdm
6
  from functools import partial
7
 
8
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
9
+ extract_into_tensor
10
 
11
 
12
  class DDIMSampler(object):
 
73
  verbose=True,
74
  x_T=None,
75
  log_every_t=100,
76
+ unconditional_guidance_scale=1.,
77
+ unconditional_conditioning=None,
78
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
79
  **kwargs
80
  ):
81
  if conditioning is not None:
 
104
  score_corrector=score_corrector,
105
  corrector_kwargs=corrector_kwargs,
106
  x_T=x_T,
107
+ log_every_t=log_every_t,
108
+ unconditional_guidance_scale=unconditional_guidance_scale,
109
+ unconditional_conditioning=unconditional_conditioning,
110
  )
111
  return samples, intermediates
112
 
 
115
  x_T=None, ddim_use_original_steps=False,
116
  callback=None, timesteps=None, quantize_denoised=False,
117
  mask=None, x0=None, img_callback=None, log_every_t=100,
118
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
119
+ unconditional_guidance_scale=1., unconditional_conditioning=None,):
120
  device = self.model.betas.device
121
  b = shape[0]
122
  if x_T is None:
 
149
  outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
150
  quantize_denoised=quantize_denoised, temperature=temperature,
151
  noise_dropout=noise_dropout, score_corrector=score_corrector,
152
+ corrector_kwargs=corrector_kwargs,
153
+ unconditional_guidance_scale=unconditional_guidance_scale,
154
+ unconditional_conditioning=unconditional_conditioning)
155
  img, pred_x0 = outs
156
  if callback: callback(i)
157
  if img_callback: img_callback(pred_x0, i)
 
164
 
165
  @torch.no_grad()
166
  def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
167
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
168
+ unconditional_guidance_scale=1., unconditional_conditioning=None):
169
  b, *_, device = *x.shape, x.device
170
+
171
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
172
+ e_t = self.model.apply_model(x, t, c)
173
+ else:
174
+ x_in = torch.cat([x] * 2)
175
+ t_in = torch.cat([t] * 2)
176
+ c_in = torch.cat([unconditional_conditioning, c])
177
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
178
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
179
+
180
  if score_corrector is not None:
181
  assert self.model.parameterization == "eps"
182
  e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
 
202
  noise = torch.nn.functional.dropout(noise, p=noise_dropout)
203
  x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
204
  return x_prev, pred_x0
205
+
206
+ @torch.no_grad()
207
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
208
+ # fast, but does not allow for exact reconstruction
209
+ # t serves as an index to gather the correct alphas
210
+ if use_original_steps:
211
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
212
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
213
+ else:
214
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
215
+ sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
216
+
217
+ if noise is None:
218
+ noise = torch.randn_like(x0)
219
+ return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
220
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
221
+
222
+ @torch.no_grad()
223
+ def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
224
+ use_original_steps=False):
225
+
226
+ timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
227
+ timesteps = timesteps[:t_start]
228
+
229
+ time_range = np.flip(timesteps)
230
+ total_steps = timesteps.shape[0]
231
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
232
+
233
+ iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
234
+ x_dec = x_latent
235
+ for i, step in enumerate(iterator):
236
+ index = total_steps - i - 1
237
+ ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
238
+ x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
239
+ unconditional_guidance_scale=unconditional_guidance_scale,
240
+ unconditional_conditioning=unconditional_conditioning)
241
+ return x_dec
ldm/models/diffusion/plms.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from functools import partial
7
+
8
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9
+
10
+
11
+ class PLMSSampler(object):
12
+ def __init__(self, model, schedule="linear", **kwargs):
13
+ super().__init__()
14
+ self.model = model
15
+ self.ddpm_num_timesteps = model.num_timesteps
16
+ self.schedule = schedule
17
+
18
+ def register_buffer(self, name, attr):
19
+ if type(attr) == torch.Tensor:
20
+ if attr.device != torch.device("cuda"):
21
+ attr = attr.to(torch.device("cuda"))
22
+ setattr(self, name, attr)
23
+
24
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
25
+ if ddim_eta != 0:
26
+ raise ValueError('ddim_eta must be 0 for PLMS')
27
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
28
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
29
+ alphas_cumprod = self.model.alphas_cumprod
30
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
31
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
32
+
33
+ self.register_buffer('betas', to_torch(self.model.betas))
34
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
35
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
36
+
37
+ # calculations for diffusion q(x_t | x_{t-1}) and others
38
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
39
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
40
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
41
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
42
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
43
+
44
+ # ddim sampling parameters
45
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
46
+ ddim_timesteps=self.ddim_timesteps,
47
+ eta=ddim_eta,verbose=verbose)
48
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
49
+ self.register_buffer('ddim_alphas', ddim_alphas)
50
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
51
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
52
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
53
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
54
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
55
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
56
+
57
+ @torch.no_grad()
58
+ def sample(self,
59
+ S,
60
+ batch_size,
61
+ shape,
62
+ conditioning=None,
63
+ callback=None,
64
+ normals_sequence=None,
65
+ img_callback=None,
66
+ quantize_x0=False,
67
+ eta=0.,
68
+ mask=None,
69
+ x0=None,
70
+ temperature=1.,
71
+ noise_dropout=0.,
72
+ score_corrector=None,
73
+ corrector_kwargs=None,
74
+ verbose=True,
75
+ x_T=None,
76
+ log_every_t=100,
77
+ unconditional_guidance_scale=1.,
78
+ unconditional_conditioning=None,
79
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
80
+ **kwargs
81
+ ):
82
+ if conditioning is not None:
83
+ if isinstance(conditioning, dict):
84
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
85
+ if cbs != batch_size:
86
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
87
+ else:
88
+ if conditioning.shape[0] != batch_size:
89
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
90
+
91
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
92
+ # sampling
93
+ C, H, W = shape
94
+ size = (batch_size, C, H, W)
95
+ print(f'Data shape for PLMS sampling is {size}')
96
+
97
+ samples, intermediates = self.plms_sampling(conditioning, size,
98
+ callback=callback,
99
+ img_callback=img_callback,
100
+ quantize_denoised=quantize_x0,
101
+ mask=mask, x0=x0,
102
+ ddim_use_original_steps=False,
103
+ noise_dropout=noise_dropout,
104
+ temperature=temperature,
105
+ score_corrector=score_corrector,
106
+ corrector_kwargs=corrector_kwargs,
107
+ x_T=x_T,
108
+ log_every_t=log_every_t,
109
+ unconditional_guidance_scale=unconditional_guidance_scale,
110
+ unconditional_conditioning=unconditional_conditioning,
111
+ )
112
+ return samples, intermediates
113
+
114
+ @torch.no_grad()
115
+ def plms_sampling(self, cond, shape,
116
+ x_T=None, ddim_use_original_steps=False,
117
+ callback=None, timesteps=None, quantize_denoised=False,
118
+ mask=None, x0=None, img_callback=None, log_every_t=100,
119
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
120
+ unconditional_guidance_scale=1., unconditional_conditioning=None,):
121
+ device = self.model.betas.device
122
+ b = shape[0]
123
+ if x_T is None:
124
+ img = torch.randn(shape, device=device)
125
+ else:
126
+ img = x_T
127
+
128
+ if timesteps is None:
129
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
130
+ elif timesteps is not None and not ddim_use_original_steps:
131
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
132
+ timesteps = self.ddim_timesteps[:subset_end]
133
+
134
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
135
+ time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
136
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
137
+ print(f"Running PLMS Sampling with {total_steps} timesteps")
138
+
139
+ iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
140
+ old_eps = []
141
+
142
+ for i, step in enumerate(iterator):
143
+ index = total_steps - i - 1
144
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
145
+ ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
146
+
147
+ if mask is not None:
148
+ assert x0 is not None
149
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
150
+ img = img_orig * mask + (1. - mask) * img
151
+
152
+ outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
153
+ quantize_denoised=quantize_denoised, temperature=temperature,
154
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
155
+ corrector_kwargs=corrector_kwargs,
156
+ unconditional_guidance_scale=unconditional_guidance_scale,
157
+ unconditional_conditioning=unconditional_conditioning,
158
+ old_eps=old_eps, t_next=ts_next)
159
+ img, pred_x0, e_t = outs
160
+ old_eps.append(e_t)
161
+ if len(old_eps) >= 4:
162
+ old_eps.pop(0)
163
+ if callback: callback(i)
164
+ if img_callback: img_callback(pred_x0, i)
165
+
166
+ if index % log_every_t == 0 or index == total_steps - 1:
167
+ intermediates['x_inter'].append(img)
168
+ intermediates['pred_x0'].append(pred_x0)
169
+
170
+ return img, intermediates
171
+
172
+ @torch.no_grad()
173
+ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
174
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
175
+ unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
176
+ b, *_, device = *x.shape, x.device
177
+
178
+ def get_model_output(x, t):
179
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
180
+ e_t = self.model.apply_model(x, t, c)
181
+ else:
182
+ x_in = torch.cat([x] * 2)
183
+ t_in = torch.cat([t] * 2)
184
+ c_in = torch.cat([unconditional_conditioning, c])
185
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
186
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
187
+
188
+ if score_corrector is not None:
189
+ assert self.model.parameterization == "eps"
190
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
191
+
192
+ return e_t
193
+
194
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
195
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
196
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
197
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
198
+
199
+ def get_x_prev_and_pred_x0(e_t, index):
200
+ # select parameters corresponding to the currently considered timestep
201
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
202
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
203
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
204
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
205
+
206
+ # current prediction for x_0
207
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
208
+ if quantize_denoised:
209
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
210
+ # direction pointing to x_t
211
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
212
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
213
+ if noise_dropout > 0.:
214
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
215
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
216
+ return x_prev, pred_x0
217
+
218
+ e_t = get_model_output(x, t)
219
+ if len(old_eps) == 0:
220
+ # Pseudo Improved Euler (2nd order)
221
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
222
+ e_t_next = get_model_output(x_prev, t_next)
223
+ e_t_prime = (e_t + e_t_next) / 2
224
+ elif len(old_eps) == 1:
225
+ # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
226
+ e_t_prime = (3 * e_t - old_eps[-1]) / 2
227
+ elif len(old_eps) == 2:
228
+ # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
229
+ e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
230
+ elif len(old_eps) >= 3:
231
+ # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
232
+ e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
233
+
234
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
235
+
236
+ return x_prev, pred_x0, e_t
ldm/modules/diffusionmodules/openaimodel.py CHANGED
@@ -455,7 +455,7 @@ class UNetModel(nn.Module):
455
  num_classes=None,
456
  use_checkpoint=False,
457
  use_fp16=False,
458
- num_heads=1,
459
  num_head_channels=-1,
460
  num_heads_upsample=-1,
461
  use_scale_shift_norm=False,
@@ -464,21 +464,28 @@ class UNetModel(nn.Module):
464
  use_spatial_transformer=False, # custom transformer support
465
  transformer_depth=1, # custom transformer support
466
  context_dim=None, # custom transformer support
467
- n_embed=None # custom support for prediction of discrete ids into codebook of first stage vq model
 
468
  ):
469
  super().__init__()
470
-
471
  if use_spatial_transformer:
472
  assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
473
 
474
  if context_dim is not None:
475
  assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
476
-
477
-
 
478
 
479
  if num_heads_upsample == -1:
480
  num_heads_upsample = num_heads
481
 
 
 
 
 
 
 
482
  self.image_size = image_size
483
  self.in_channels = in_channels
484
  self.model_channels = model_channels
@@ -532,13 +539,20 @@ class UNetModel(nn.Module):
532
  ]
533
  ch = mult * model_channels
534
  if ds in attention_resolutions:
535
- dim_head = ch // num_heads
 
 
 
 
 
 
 
536
  layers.append(
537
  AttentionBlock(
538
  ch,
539
  use_checkpoint=use_checkpoint,
540
  num_heads=num_heads,
541
- num_head_channels=num_head_channels,
542
  use_new_attention_order=use_new_attention_order,
543
  ) if not use_spatial_transformer else SpatialTransformer(
544
  ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
@@ -572,7 +586,14 @@ class UNetModel(nn.Module):
572
  ds *= 2
573
  self._feature_size += ch
574
 
575
- dim_head = ch // num_heads
 
 
 
 
 
 
 
576
  self.middle_block = TimestepEmbedSequential(
577
  ResBlock(
578
  ch,
@@ -586,7 +607,7 @@ class UNetModel(nn.Module):
586
  ch,
587
  use_checkpoint=use_checkpoint,
588
  num_heads=num_heads,
589
- num_head_channels=num_head_channels,
590
  use_new_attention_order=use_new_attention_order,
591
  ) if not use_spatial_transformer else SpatialTransformer(
592
  ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
@@ -619,13 +640,20 @@ class UNetModel(nn.Module):
619
  ]
620
  ch = model_channels * mult
621
  if ds in attention_resolutions:
622
- dim_head = ch // num_heads
 
 
 
 
 
 
 
623
  layers.append(
624
  AttentionBlock(
625
  ch,
626
  use_checkpoint=use_checkpoint,
627
  num_heads=num_heads_upsample,
628
- num_head_channels=num_head_channels,
629
  use_new_attention_order=use_new_attention_order,
630
  ) if not use_spatial_transformer else SpatialTransformer(
631
  ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
@@ -691,7 +719,6 @@ class UNetModel(nn.Module):
691
  assert (y is not None) == (
692
  self.num_classes is not None
693
  ), "must specify y if and only if the model is class-conditional"
694
- assert timesteps is not None, 'need to implement no-timestep usage'
695
  hs = []
696
  t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
697
  emb = self.time_embed(t_emb)
@@ -710,14 +737,12 @@ class UNetModel(nn.Module):
710
  h = module(h, emb, context)
711
  h = h.type(x.dtype)
712
  if self.predict_codebook_ids:
713
- #return self.out(h), self.id_predictor(h)
714
  return self.id_predictor(h)
715
  else:
716
  return self.out(h)
717
 
718
 
719
  class EncoderUNetModel(nn.Module):
720
- # TODO: do we use it ?
721
  """
722
  The half UNet model with attention and timestep embedding.
723
  For usage, see UNet.
 
455
  num_classes=None,
456
  use_checkpoint=False,
457
  use_fp16=False,
458
+ num_heads=-1,
459
  num_head_channels=-1,
460
  num_heads_upsample=-1,
461
  use_scale_shift_norm=False,
 
464
  use_spatial_transformer=False, # custom transformer support
465
  transformer_depth=1, # custom transformer support
466
  context_dim=None, # custom transformer support
467
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
468
+ legacy=True,
469
  ):
470
  super().__init__()
 
471
  if use_spatial_transformer:
472
  assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
473
 
474
  if context_dim is not None:
475
  assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
476
+ from omegaconf.listconfig import ListConfig
477
+ if type(context_dim) == ListConfig:
478
+ context_dim = list(context_dim)
479
 
480
  if num_heads_upsample == -1:
481
  num_heads_upsample = num_heads
482
 
483
+ if num_heads == -1:
484
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
485
+
486
+ if num_head_channels == -1:
487
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
488
+
489
  self.image_size = image_size
490
  self.in_channels = in_channels
491
  self.model_channels = model_channels
 
539
  ]
540
  ch = mult * model_channels
541
  if ds in attention_resolutions:
542
+ if num_head_channels == -1:
543
+ dim_head = ch // num_heads
544
+ else:
545
+ num_heads = ch // num_head_channels
546
+ dim_head = num_head_channels
547
+ if legacy:
548
+ #num_heads = 1
549
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
550
  layers.append(
551
  AttentionBlock(
552
  ch,
553
  use_checkpoint=use_checkpoint,
554
  num_heads=num_heads,
555
+ num_head_channels=dim_head,
556
  use_new_attention_order=use_new_attention_order,
557
  ) if not use_spatial_transformer else SpatialTransformer(
558
  ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
 
586
  ds *= 2
587
  self._feature_size += ch
588
 
589
+ if num_head_channels == -1:
590
+ dim_head = ch // num_heads
591
+ else:
592
+ num_heads = ch // num_head_channels
593
+ dim_head = num_head_channels
594
+ if legacy:
595
+ #num_heads = 1
596
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
597
  self.middle_block = TimestepEmbedSequential(
598
  ResBlock(
599
  ch,
 
607
  ch,
608
  use_checkpoint=use_checkpoint,
609
  num_heads=num_heads,
610
+ num_head_channels=dim_head,
611
  use_new_attention_order=use_new_attention_order,
612
  ) if not use_spatial_transformer else SpatialTransformer(
613
  ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
 
640
  ]
641
  ch = model_channels * mult
642
  if ds in attention_resolutions:
643
+ if num_head_channels == -1:
644
+ dim_head = ch // num_heads
645
+ else:
646
+ num_heads = ch // num_head_channels
647
+ dim_head = num_head_channels
648
+ if legacy:
649
+ #num_heads = 1
650
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
651
  layers.append(
652
  AttentionBlock(
653
  ch,
654
  use_checkpoint=use_checkpoint,
655
  num_heads=num_heads_upsample,
656
+ num_head_channels=dim_head,
657
  use_new_attention_order=use_new_attention_order,
658
  ) if not use_spatial_transformer else SpatialTransformer(
659
  ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
 
719
  assert (y is not None) == (
720
  self.num_classes is not None
721
  ), "must specify y if and only if the model is class-conditional"
 
722
  hs = []
723
  t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
724
  emb = self.time_embed(t_emb)
 
737
  h = module(h, emb, context)
738
  h = h.type(x.dtype)
739
  if self.predict_codebook_ids:
 
740
  return self.id_predictor(h)
741
  else:
742
  return self.out(h)
743
 
744
 
745
  class EncoderUNetModel(nn.Module):
 
746
  """
747
  The half UNet model with attention and timestep embedding.
748
  For usage, see UNet.
ldm/modules/encoders/modules.py CHANGED
@@ -1,6 +1,10 @@
1
  import torch
2
  import torch.nn as nn
3
  from functools import partial
 
 
 
 
4
 
5
  from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
6
 
@@ -129,3 +133,102 @@ class SpatialRescaler(nn.Module):
129
 
130
  def encode(self, x):
131
  return self(x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  import torch.nn as nn
3
  from functools import partial
4
+ import clip
5
+ from einops import rearrange, repeat
6
+ from transformers import CLIPTokenizer, CLIPTextModel
7
+ import kornia
8
 
9
  from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
10
 
 
133
 
134
  def encode(self, x):
135
  return self(x)
136
+
137
+ class FrozenCLIPEmbedder(AbstractEncoder):
138
+ """Uses the CLIP transformer encoder for text (from Hugging Face)"""
139
+ def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
140
+ super().__init__()
141
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
142
+ self.transformer = CLIPTextModel.from_pretrained(version)
143
+ self.device = device
144
+ self.max_length = max_length
145
+ self.freeze()
146
+
147
+ def freeze(self):
148
+ self.transformer = self.transformer.eval()
149
+ for param in self.parameters():
150
+ param.requires_grad = False
151
+
152
+ def forward(self, text):
153
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
154
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
155
+ tokens = batch_encoding["input_ids"].to(self.device)
156
+ outputs = self.transformer(input_ids=tokens)
157
+
158
+ z = outputs.last_hidden_state
159
+ return z
160
+
161
+ def encode(self, text):
162
+ return self(text)
163
+
164
+
165
+ class FrozenCLIPTextEmbedder(nn.Module):
166
+ """
167
+ Uses the CLIP transformer encoder for text.
168
+ """
169
+ def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
170
+ super().__init__()
171
+ self.model, _ = clip.load(version, jit=False, device="cpu")
172
+ self.device = device
173
+ self.max_length = max_length
174
+ self.n_repeat = n_repeat
175
+ self.normalize = normalize
176
+
177
+ def freeze(self):
178
+ self.model = self.model.eval()
179
+ for param in self.parameters():
180
+ param.requires_grad = False
181
+
182
+ def forward(self, text):
183
+ tokens = clip.tokenize(text).to(self.device)
184
+ z = self.model.encode_text(tokens)
185
+ if self.normalize:
186
+ z = z / torch.linalg.norm(z, dim=1, keepdim=True)
187
+ return z
188
+
189
+ def encode(self, text):
190
+ z = self(text)
191
+ if z.ndim==2:
192
+ z = z[:, None, :]
193
+ z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
194
+ return z
195
+
196
+
197
+ class FrozenClipImageEmbedder(nn.Module):
198
+ """
199
+ Uses the CLIP image encoder.
200
+ """
201
+ def __init__(
202
+ self,
203
+ model,
204
+ jit=False,
205
+ device='cuda' if torch.cuda.is_available() else 'cpu',
206
+ antialias=False,
207
+ ):
208
+ super().__init__()
209
+ self.model, _ = clip.load(name=model, device=device, jit=jit)
210
+
211
+ self.antialias = antialias
212
+
213
+ self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
214
+ self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
215
+
216
+ def preprocess(self, x):
217
+ # normalize to [0,1]
218
+ x = kornia.geometry.resize(x, (224, 224),
219
+ interpolation='bicubic',align_corners=True,
220
+ antialias=self.antialias)
221
+ x = (x + 1.) / 2.
222
+ # renormalize according to clip
223
+ x = kornia.enhance.normalize(x, self.mean, self.std)
224
+ return x
225
+
226
+ def forward(self, x):
227
+ # x is assumed to be in range [-1,1]
228
+ return self.model.encode_image(self.preprocess(x))
229
+
230
+
231
+ if __name__ == "__main__":
232
+ from ldm.util import count_params
233
+ model = FrozenCLIPEmbedder()
234
+ count_params(model, verbose=True)
ldm/modules/x_transformer.py CHANGED
@@ -407,7 +407,7 @@ class AttentionLayers(nn.Module):
407
  self.rotary_pos_emb = always(None)
408
 
409
  assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
410
- self.rel_pos = always(None)
411
 
412
  self.pre_norm = pre_norm
413
 
 
407
  self.rotary_pos_emb = always(None)
408
 
409
  assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
410
+ self.rel_pos = None
411
 
412
  self.pre_norm = pre_norm
413
 
ldm/util.py CHANGED
@@ -2,6 +2,13 @@ import importlib
2
 
3
  import torch
4
  import numpy as np
 
 
 
 
 
 
 
5
 
6
  from inspect import isfunction
7
  from PIL import Image, ImageDraw, ImageFont
@@ -38,7 +45,7 @@ def ismap(x):
38
 
39
 
40
  def isimage(x):
41
- if not isinstance(x,torch.Tensor):
42
  return False
43
  return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
44
 
@@ -64,7 +71,7 @@ def mean_flat(tensor):
64
  def count_params(model, verbose=False):
65
  total_params = sum(p.numel() for p in model.parameters())
66
  if verbose:
67
- print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
68
  return total_params
69
 
70
 
@@ -83,4 +90,114 @@ def get_obj_from_str(string, reload=False):
83
  if reload:
84
  module_imp = importlib.import_module(module)
85
  importlib.reload(module_imp)
86
- return getattr(importlib.import_module(module, package=None), cls)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  import torch
4
  import numpy as np
5
+ from collections import abc
6
+ from einops import rearrange
7
+ from functools import partial
8
+
9
+ import multiprocessing as mp
10
+ from threading import Thread
11
+ from queue import Queue
12
 
13
  from inspect import isfunction
14
  from PIL import Image, ImageDraw, ImageFont
 
45
 
46
 
47
  def isimage(x):
48
+ if not isinstance(x, torch.Tensor):
49
  return False
50
  return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
51
 
 
71
  def count_params(model, verbose=False):
72
  total_params = sum(p.numel() for p in model.parameters())
73
  if verbose:
74
+ print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
75
  return total_params
76
 
77
 
 
90
  if reload:
91
  module_imp = importlib.import_module(module)
92
  importlib.reload(module_imp)
93
+ return getattr(importlib.import_module(module, package=None), cls)
94
+
95
+
96
+ def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
97
+ # create dummy dataset instance
98
+
99
+ # run prefetching
100
+ if idx_to_fn:
101
+ res = func(data, worker_id=idx)
102
+ else:
103
+ res = func(data)
104
+ Q.put([idx, res])
105
+ Q.put("Done")
106
+
107
+
108
+ def parallel_data_prefetch(
109
+ func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
110
+ ):
111
+ # if target_data_type not in ["ndarray", "list"]:
112
+ # raise ValueError(
113
+ # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
114
+ # )
115
+ if isinstance(data, np.ndarray) and target_data_type == "list":
116
+ raise ValueError("list expected but function got ndarray.")
117
+ elif isinstance(data, abc.Iterable):
118
+ if isinstance(data, dict):
119
+ print(
120
+ f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
121
+ )
122
+ data = list(data.values())
123
+ if target_data_type == "ndarray":
124
+ data = np.asarray(data)
125
+ else:
126
+ data = list(data)
127
+ else:
128
+ raise TypeError(
129
+ f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
130
+ )
131
+
132
+ if cpu_intensive:
133
+ Q = mp.Queue(1000)
134
+ proc = mp.Process
135
+ else:
136
+ Q = Queue(1000)
137
+ proc = Thread
138
+ # spawn processes
139
+ if target_data_type == "ndarray":
140
+ arguments = [
141
+ [func, Q, part, i, use_worker_id]
142
+ for i, part in enumerate(np.array_split(data, n_proc))
143
+ ]
144
+ else:
145
+ step = (
146
+ int(len(data) / n_proc + 1)
147
+ if len(data) % n_proc != 0
148
+ else int(len(data) / n_proc)
149
+ )
150
+ arguments = [
151
+ [func, Q, part, i, use_worker_id]
152
+ for i, part in enumerate(
153
+ [data[i: i + step] for i in range(0, len(data), step)]
154
+ )
155
+ ]
156
+ processes = []
157
+ for i in range(n_proc):
158
+ p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
159
+ processes += [p]
160
+
161
+ # start processes
162
+ print(f"Start prefetching...")
163
+ import time
164
+
165
+ start = time.time()
166
+ gather_res = [[] for _ in range(n_proc)]
167
+ try:
168
+ for p in processes:
169
+ p.start()
170
+
171
+ k = 0
172
+ while k < n_proc:
173
+ # get result
174
+ res = Q.get()
175
+ if res == "Done":
176
+ k += 1
177
+ else:
178
+ gather_res[res[0]] = res[1]
179
+
180
+ except Exception as e:
181
+ print("Exception: ", e)
182
+ for p in processes:
183
+ p.terminate()
184
+
185
+ raise e
186
+ finally:
187
+ for p in processes:
188
+ p.join()
189
+ print(f"Prefetching complete. [{time.time() - start} sec.]")
190
+
191
+ if target_data_type == 'ndarray':
192
+ if not isinstance(gather_res[0], np.ndarray):
193
+ return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
194
+
195
+ # order outputs
196
+ return np.concatenate(gather_res, axis=0)
197
+ elif target_data_type == 'list':
198
+ out = []
199
+ for r in gather_res:
200
+ out.extend(r)
201
+ return out
202
+ else:
203
+ return gather_res
scripts/img2img.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """make variations of input image"""
2
+
3
+ import argparse, os, sys, glob
4
+ import PIL
5
+ import torch
6
+ import numpy as np
7
+ from omegaconf import OmegaConf
8
+ from PIL import Image
9
+ from tqdm import tqdm, trange
10
+ from itertools import islice
11
+ from einops import rearrange, repeat
12
+ from torchvision.utils import make_grid
13
+ from torch import autocast
14
+ from contextlib import nullcontext
15
+ import time
16
+ from pytorch_lightning import seed_everything
17
+
18
+ from ldm.util import instantiate_from_config
19
+ from ldm.models.diffusion.ddim import DDIMSampler
20
+ from ldm.models.diffusion.plms import PLMSSampler
21
+
22
+
23
+ def chunk(it, size):
24
+ it = iter(it)
25
+ return iter(lambda: tuple(islice(it, size)), ())
26
+
27
+
28
+ def load_model_from_config(config, ckpt, verbose=False):
29
+ print(f"Loading model from {ckpt}")
30
+ pl_sd = torch.load(ckpt, map_location="cpu")
31
+ if "global_step" in pl_sd:
32
+ print(f"Global Step: {pl_sd['global_step']}")
33
+ sd = pl_sd["state_dict"]
34
+ model = instantiate_from_config(config.model)
35
+ m, u = model.load_state_dict(sd, strict=False)
36
+ if len(m) > 0 and verbose:
37
+ print("missing keys:")
38
+ print(m)
39
+ if len(u) > 0 and verbose:
40
+ print("unexpected keys:")
41
+ print(u)
42
+
43
+ model.cuda()
44
+ model.eval()
45
+ return model
46
+
47
+
48
+ def load_img(path):
49
+ image = Image.open(path).convert("RGB")
50
+ w, h = image.size
51
+ print(f"loaded input image of size ({w}, {h}) from {path}")
52
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
53
+ image = image.resize((w, h), resample=PIL.Image.LANCZOS)
54
+ image = np.array(image).astype(np.float32) / 255.0
55
+ image = image[None].transpose(0, 3, 1, 2)
56
+ image = torch.from_numpy(image)
57
+ return 2.*image - 1.
58
+
59
+
60
+ def main():
61
+ parser = argparse.ArgumentParser()
62
+
63
+ parser.add_argument(
64
+ "--prompt",
65
+ type=str,
66
+ nargs="?",
67
+ default="a painting of a virus monster playing guitar",
68
+ help="the prompt to render"
69
+ )
70
+
71
+ parser.add_argument(
72
+ "--init-img",
73
+ type=str,
74
+ nargs="?",
75
+ help="path to the input image"
76
+ )
77
+
78
+ parser.add_argument(
79
+ "--outdir",
80
+ type=str,
81
+ nargs="?",
82
+ help="dir to write results to",
83
+ default="outputs/img2img-samples"
84
+ )
85
+
86
+ parser.add_argument(
87
+ "--skip_grid",
88
+ action='store_true',
89
+ help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
90
+ )
91
+
92
+ parser.add_argument(
93
+ "--skip_save",
94
+ action='store_true',
95
+ help="do not save indiviual samples. For speed measurements.",
96
+ )
97
+
98
+ parser.add_argument(
99
+ "--ddim_steps",
100
+ type=int,
101
+ default=50,
102
+ help="number of ddim sampling steps",
103
+ )
104
+
105
+ parser.add_argument(
106
+ "--plms",
107
+ action='store_true',
108
+ help="use plms sampling",
109
+ )
110
+ parser.add_argument(
111
+ "--fixed_code",
112
+ action='store_true',
113
+ help="if enabled, uses the same starting code across all samples ",
114
+ )
115
+
116
+ parser.add_argument(
117
+ "--ddim_eta",
118
+ type=float,
119
+ default=0.0,
120
+ help="ddim eta (eta=0.0 corresponds to deterministic sampling",
121
+ )
122
+ parser.add_argument(
123
+ "--n_iter",
124
+ type=int,
125
+ default=1,
126
+ help="sample this often",
127
+ )
128
+ parser.add_argument(
129
+ "--C",
130
+ type=int,
131
+ default=4,
132
+ help="latent channels",
133
+ )
134
+ parser.add_argument(
135
+ "--f",
136
+ type=int,
137
+ default=8,
138
+ help="downsampling factor, most often 8 or 16",
139
+ )
140
+ parser.add_argument(
141
+ "--n_samples",
142
+ type=int,
143
+ default=2,
144
+ help="how many samples to produce for each given prompt. A.k.a batch size",
145
+ )
146
+ parser.add_argument(
147
+ "--n_rows",
148
+ type=int,
149
+ default=0,
150
+ help="rows in the grid (default: n_samples)",
151
+ )
152
+ parser.add_argument(
153
+ "--scale",
154
+ type=float,
155
+ default=5.0,
156
+ help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
157
+ )
158
+
159
+ parser.add_argument(
160
+ "--strength",
161
+ type=float,
162
+ default=0.75,
163
+ help="strength for noising/unnoising. 1.0 corresponds to full destruction of information in init image",
164
+ )
165
+ parser.add_argument(
166
+ "--from-file",
167
+ type=str,
168
+ help="if specified, load prompts from this file",
169
+ )
170
+ parser.add_argument(
171
+ "--config",
172
+ type=str,
173
+ default="configs/stable-diffusion/v1-inference.yaml",
174
+ help="path to config which constructs model",
175
+ )
176
+ parser.add_argument(
177
+ "--ckpt",
178
+ type=str,
179
+ default="models/ldm/stable-diffusion-v1/model.ckpt",
180
+ help="path to checkpoint of model",
181
+ )
182
+ parser.add_argument(
183
+ "--seed",
184
+ type=int,
185
+ default=42,
186
+ help="the seed (for reproducible sampling)",
187
+ )
188
+ parser.add_argument(
189
+ "--precision",
190
+ type=str,
191
+ help="evaluate at this precision",
192
+ choices=["full", "autocast"],
193
+ default="autocast"
194
+ )
195
+
196
+ opt = parser.parse_args()
197
+ seed_everything(opt.seed)
198
+
199
+ config = OmegaConf.load(f"{opt.config}")
200
+ model = load_model_from_config(config, f"{opt.ckpt}")
201
+
202
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
203
+ model = model.to(device)
204
+
205
+ if opt.plms:
206
+ raise NotImplementedError("PLMS sampler not (yet) supported")
207
+ sampler = PLMSSampler(model)
208
+ else:
209
+ sampler = DDIMSampler(model)
210
+
211
+ os.makedirs(opt.outdir, exist_ok=True)
212
+ outpath = opt.outdir
213
+
214
+ batch_size = opt.n_samples
215
+ n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
216
+ if not opt.from_file:
217
+ prompt = opt.prompt
218
+ assert prompt is not None
219
+ data = [batch_size * [prompt]]
220
+
221
+ else:
222
+ print(f"reading prompts from {opt.from_file}")
223
+ with open(opt.from_file, "r") as f:
224
+ data = f.read().splitlines()
225
+ data = list(chunk(data, batch_size))
226
+
227
+ sample_path = os.path.join(outpath, "samples")
228
+ os.makedirs(sample_path, exist_ok=True)
229
+ base_count = len(os.listdir(sample_path))
230
+ grid_count = len(os.listdir(outpath)) - 1
231
+
232
+ assert os.path.isfile(opt.init_img)
233
+ init_image = load_img(opt.init_img).to(device)
234
+ init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
235
+ init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space
236
+
237
+ sampler.make_schedule(ddim_num_steps=opt.ddim_steps, ddim_eta=opt.ddim_eta, verbose=False)
238
+
239
+ assert 0. <= opt.strength <= 1., 'can only work with strength in [0.0, 1.0]'
240
+ t_enc = int(opt.strength * opt.ddim_steps)
241
+ print(f"target t_enc is {t_enc} steps")
242
+
243
+ precision_scope = autocast if opt.precision == "autocast" else nullcontext
244
+ with torch.no_grad():
245
+ with precision_scope("cuda"):
246
+ with model.ema_scope():
247
+ tic = time.time()
248
+ all_samples = list()
249
+ for n in trange(opt.n_iter, desc="Sampling"):
250
+ for prompts in tqdm(data, desc="data"):
251
+ uc = None
252
+ if opt.scale != 1.0:
253
+ uc = model.get_learned_conditioning(batch_size * [""])
254
+ if isinstance(prompts, tuple):
255
+ prompts = list(prompts)
256
+ c = model.get_learned_conditioning(prompts)
257
+
258
+ # encode (scaled latent)
259
+ z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))
260
+ # decode it
261
+ samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt.scale,
262
+ unconditional_conditioning=uc,)
263
+
264
+ x_samples = model.decode_first_stage(samples)
265
+ x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
266
+
267
+ if not opt.skip_save:
268
+ for x_sample in x_samples:
269
+ x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
270
+ Image.fromarray(x_sample.astype(np.uint8)).save(
271
+ os.path.join(sample_path, f"{base_count:05}.png"))
272
+ base_count += 1
273
+ all_samples.append(x_samples)
274
+
275
+ if not opt.skip_grid:
276
+ # additionally, save as grid
277
+ grid = torch.stack(all_samples, 0)
278
+ grid = rearrange(grid, 'n b c h w -> (n b) c h w')
279
+ grid = make_grid(grid, nrow=n_rows)
280
+
281
+ # to image
282
+ grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
283
+ Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
284
+ grid_count += 1
285
+
286
+ toc = time.time()
287
+
288
+ print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
289
+ f" \nEnjoy.")
290
+
291
+
292
+ if __name__ == "__main__":
293
+ main()
scripts/knn2img.py ADDED
@@ -0,0 +1,398 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse, os, sys, glob
2
+ import clip
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ from omegaconf import OmegaConf
7
+ from PIL import Image
8
+ from tqdm import tqdm, trange
9
+ from itertools import islice
10
+ from einops import rearrange, repeat
11
+ from torchvision.utils import make_grid
12
+ import scann
13
+ import time
14
+ from multiprocessing import cpu_count
15
+
16
+ from ldm.util import instantiate_from_config, parallel_data_prefetch
17
+ from ldm.models.diffusion.ddim import DDIMSampler
18
+ from ldm.models.diffusion.plms import PLMSSampler
19
+ from ldm.modules.encoders.modules import FrozenClipImageEmbedder, FrozenCLIPTextEmbedder
20
+
21
+ DATABASES = [
22
+ "openimages",
23
+ "artbench-art_nouveau",
24
+ "artbench-baroque",
25
+ "artbench-expressionism",
26
+ "artbench-impressionism",
27
+ "artbench-post_impressionism",
28
+ "artbench-realism",
29
+ "artbench-romanticism",
30
+ "artbench-renaissance",
31
+ "artbench-surrealism",
32
+ "artbench-ukiyo_e",
33
+ ]
34
+
35
+
36
+ def chunk(it, size):
37
+ it = iter(it)
38
+ return iter(lambda: tuple(islice(it, size)), ())
39
+
40
+
41
+ def load_model_from_config(config, ckpt, verbose=False):
42
+ print(f"Loading model from {ckpt}")
43
+ pl_sd = torch.load(ckpt, map_location="cpu")
44
+ if "global_step" in pl_sd:
45
+ print(f"Global Step: {pl_sd['global_step']}")
46
+ sd = pl_sd["state_dict"]
47
+ model = instantiate_from_config(config.model)
48
+ m, u = model.load_state_dict(sd, strict=False)
49
+ if len(m) > 0 and verbose:
50
+ print("missing keys:")
51
+ print(m)
52
+ if len(u) > 0 and verbose:
53
+ print("unexpected keys:")
54
+ print(u)
55
+
56
+ model.cuda()
57
+ model.eval()
58
+ return model
59
+
60
+
61
+ class Searcher(object):
62
+ def __init__(self, database, retriever_version='ViT-L/14'):
63
+ assert database in DATABASES
64
+ # self.database = self.load_database(database)
65
+ self.database_name = database
66
+ self.searcher_savedir = f'data/rdm/searchers/{self.database_name}'
67
+ self.database_path = f'data/rdm/retrieval_databases/{self.database_name}'
68
+ self.retriever = self.load_retriever(version=retriever_version)
69
+ self.database = {'embedding': [],
70
+ 'img_id': [],
71
+ 'patch_coords': []}
72
+ self.load_database()
73
+ self.load_searcher()
74
+
75
+ def train_searcher(self, k,
76
+ metric='dot_product',
77
+ searcher_savedir=None):
78
+
79
+ print('Start training searcher')
80
+ searcher = scann.scann_ops_pybind.builder(self.database['embedding'] /
81
+ np.linalg.norm(self.database['embedding'], axis=1)[:, np.newaxis],
82
+ k, metric)
83
+ self.searcher = searcher.score_brute_force().build()
84
+ print('Finish training searcher')
85
+
86
+ if searcher_savedir is not None:
87
+ print(f'Save trained searcher under "{searcher_savedir}"')
88
+ os.makedirs(searcher_savedir, exist_ok=True)
89
+ self.searcher.serialize(searcher_savedir)
90
+
91
+ def load_single_file(self, saved_embeddings):
92
+ compressed = np.load(saved_embeddings)
93
+ self.database = {key: compressed[key] for key in compressed.files}
94
+ print('Finished loading of clip embeddings.')
95
+
96
+ def load_multi_files(self, data_archive):
97
+ out_data = {key: [] for key in self.database}
98
+ for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'):
99
+ for key in d.files:
100
+ out_data[key].append(d[key])
101
+
102
+ return out_data
103
+
104
+ def load_database(self):
105
+
106
+ print(f'Load saved patch embedding from "{self.database_path}"')
107
+ file_content = glob.glob(os.path.join(self.database_path, '*.npz'))
108
+
109
+ if len(file_content) == 1:
110
+ self.load_single_file(file_content[0])
111
+ elif len(file_content) > 1:
112
+ data = [np.load(f) for f in file_content]
113
+ prefetched_data = parallel_data_prefetch(self.load_multi_files, data,
114
+ n_proc=min(len(data), cpu_count()), target_data_type='dict')
115
+
116
+ self.database = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in
117
+ self.database}
118
+ else:
119
+ raise ValueError(f'No npz-files in specified path "{self.database_path}" is this directory existing?')
120
+
121
+ print(f'Finished loading of retrieval database of length {self.database["embedding"].shape[0]}.')
122
+
123
+ def load_retriever(self, version='ViT-L/14', ):
124
+ model = FrozenClipImageEmbedder(model=version)
125
+ if torch.cuda.is_available():
126
+ model.cuda()
127
+ model.eval()
128
+ return model
129
+
130
+ def load_searcher(self):
131
+ print(f'load searcher for database {self.database_name} from {self.searcher_savedir}')
132
+ self.searcher = scann.scann_ops_pybind.load_searcher(self.searcher_savedir)
133
+ print('Finished loading searcher.')
134
+
135
+ def search(self, x, k):
136
+ if self.searcher is None and self.database['embedding'].shape[0] < 2e4:
137
+ self.train_searcher(k) # quickly fit searcher on the fly for small databases
138
+ assert self.searcher is not None, 'Cannot search with uninitialized searcher'
139
+ if isinstance(x, torch.Tensor):
140
+ x = x.detach().cpu().numpy()
141
+ if len(x.shape) == 3:
142
+ x = x[:, 0]
143
+ query_embeddings = x / np.linalg.norm(x, axis=1)[:, np.newaxis]
144
+
145
+ start = time.time()
146
+ nns, distances = self.searcher.search_batched(query_embeddings, final_num_neighbors=k)
147
+ end = time.time()
148
+
149
+ out_embeddings = self.database['embedding'][nns]
150
+ out_img_ids = self.database['img_id'][nns]
151
+ out_pc = self.database['patch_coords'][nns]
152
+
153
+ out = {'nn_embeddings': out_embeddings / np.linalg.norm(out_embeddings, axis=-1)[..., np.newaxis],
154
+ 'img_ids': out_img_ids,
155
+ 'patch_coords': out_pc,
156
+ 'queries': x,
157
+ 'exec_time': end - start,
158
+ 'nns': nns,
159
+ 'q_embeddings': query_embeddings}
160
+
161
+ return out
162
+
163
+ def __call__(self, x, n):
164
+ return self.search(x, n)
165
+
166
+
167
+ if __name__ == "__main__":
168
+ parser = argparse.ArgumentParser()
169
+ # TODO: add n_neighbors and modes (text-only, text-image-retrieval, image-image retrieval etc)
170
+ # TODO: add 'image variation' mode when knn=0 but a single image is given instead of a text prompt?
171
+ parser.add_argument(
172
+ "--prompt",
173
+ type=str,
174
+ nargs="?",
175
+ default="a painting of a virus monster playing guitar",
176
+ help="the prompt to render"
177
+ )
178
+
179
+ parser.add_argument(
180
+ "--outdir",
181
+ type=str,
182
+ nargs="?",
183
+ help="dir to write results to",
184
+ default="outputs/txt2img-samples"
185
+ )
186
+
187
+ parser.add_argument(
188
+ "--skip_grid",
189
+ action='store_true',
190
+ help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
191
+ )
192
+
193
+ parser.add_argument(
194
+ "--ddim_steps",
195
+ type=int,
196
+ default=50,
197
+ help="number of ddim sampling steps",
198
+ )
199
+
200
+ parser.add_argument(
201
+ "--n_repeat",
202
+ type=int,
203
+ default=1,
204
+ help="number of repeats in CLIP latent space",
205
+ )
206
+
207
+ parser.add_argument(
208
+ "--plms",
209
+ action='store_true',
210
+ help="use plms sampling",
211
+ )
212
+
213
+ parser.add_argument(
214
+ "--ddim_eta",
215
+ type=float,
216
+ default=0.0,
217
+ help="ddim eta (eta=0.0 corresponds to deterministic sampling",
218
+ )
219
+ parser.add_argument(
220
+ "--n_iter",
221
+ type=int,
222
+ default=1,
223
+ help="sample this often",
224
+ )
225
+
226
+ parser.add_argument(
227
+ "--H",
228
+ type=int,
229
+ default=768,
230
+ help="image height, in pixel space",
231
+ )
232
+
233
+ parser.add_argument(
234
+ "--W",
235
+ type=int,
236
+ default=768,
237
+ help="image width, in pixel space",
238
+ )
239
+
240
+ parser.add_argument(
241
+ "--n_samples",
242
+ type=int,
243
+ default=3,
244
+ help="how many samples to produce for each given prompt. A.k.a batch size",
245
+ )
246
+
247
+ parser.add_argument(
248
+ "--n_rows",
249
+ type=int,
250
+ default=0,
251
+ help="rows in the grid (default: n_samples)",
252
+ )
253
+
254
+ parser.add_argument(
255
+ "--scale",
256
+ type=float,
257
+ default=5.0,
258
+ help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
259
+ )
260
+
261
+ parser.add_argument(
262
+ "--from-file",
263
+ type=str,
264
+ help="if specified, load prompts from this file",
265
+ )
266
+
267
+ parser.add_argument(
268
+ "--config",
269
+ type=str,
270
+ default="configs/retrieval-augmented-diffusion/768x768.yaml",
271
+ help="path to config which constructs model",
272
+ )
273
+
274
+ parser.add_argument(
275
+ "--ckpt",
276
+ type=str,
277
+ default="models/rdm/rdm768x768/model.ckpt",
278
+ help="path to checkpoint of model",
279
+ )
280
+
281
+ parser.add_argument(
282
+ "--clip_type",
283
+ type=str,
284
+ default="ViT-L/14",
285
+ help="which CLIP model to use for retrieval and NN encoding",
286
+ )
287
+ parser.add_argument(
288
+ "--database",
289
+ type=str,
290
+ default='artbench-surrealism',
291
+ choices=DATABASES,
292
+ help="The database used for the search, only applied when --use_neighbors=True",
293
+ )
294
+ parser.add_argument(
295
+ "--use_neighbors",
296
+ default=False,
297
+ action='store_true',
298
+ help="Include neighbors in addition to text prompt for conditioning",
299
+ )
300
+ parser.add_argument(
301
+ "--knn",
302
+ default=10,
303
+ type=int,
304
+ help="The number of included neighbors, only applied when --use_neighbors=True",
305
+ )
306
+
307
+ opt = parser.parse_args()
308
+
309
+ config = OmegaConf.load(f"{opt.config}")
310
+ model = load_model_from_config(config, f"{opt.ckpt}")
311
+
312
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
313
+ model = model.to(device)
314
+
315
+ clip_text_encoder = FrozenCLIPTextEmbedder(opt.clip_type).to(device)
316
+
317
+ if opt.plms:
318
+ sampler = PLMSSampler(model)
319
+ else:
320
+ sampler = DDIMSampler(model)
321
+
322
+ os.makedirs(opt.outdir, exist_ok=True)
323
+ outpath = opt.outdir
324
+
325
+ batch_size = opt.n_samples
326
+ n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
327
+ if not opt.from_file:
328
+ prompt = opt.prompt
329
+ assert prompt is not None
330
+ data = [batch_size * [prompt]]
331
+
332
+ else:
333
+ print(f"reading prompts from {opt.from_file}")
334
+ with open(opt.from_file, "r") as f:
335
+ data = f.read().splitlines()
336
+ data = list(chunk(data, batch_size))
337
+
338
+ sample_path = os.path.join(outpath, "samples")
339
+ os.makedirs(sample_path, exist_ok=True)
340
+ base_count = len(os.listdir(sample_path))
341
+ grid_count = len(os.listdir(outpath)) - 1
342
+
343
+ print(f"sampling scale for cfg is {opt.scale:.2f}")
344
+
345
+ searcher = None
346
+ if opt.use_neighbors:
347
+ searcher = Searcher(opt.database)
348
+
349
+ with torch.no_grad():
350
+ with model.ema_scope():
351
+ for n in trange(opt.n_iter, desc="Sampling"):
352
+ all_samples = list()
353
+ for prompts in tqdm(data, desc="data"):
354
+ print("sampling prompts:", prompts)
355
+ if isinstance(prompts, tuple):
356
+ prompts = list(prompts)
357
+ c = clip_text_encoder.encode(prompts)
358
+ uc = None
359
+ if searcher is not None:
360
+ nn_dict = searcher(c, opt.knn)
361
+ c = torch.cat([c, torch.from_numpy(nn_dict['nn_embeddings']).cuda()], dim=1)
362
+ if opt.scale != 1.0:
363
+ uc = torch.zeros_like(c)
364
+ if isinstance(prompts, tuple):
365
+ prompts = list(prompts)
366
+ shape = [16, opt.H // 16, opt.W // 16] # note: currently hardcoded for f16 model
367
+ samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
368
+ conditioning=c,
369
+ batch_size=c.shape[0],
370
+ shape=shape,
371
+ verbose=False,
372
+ unconditional_guidance_scale=opt.scale,
373
+ unconditional_conditioning=uc,
374
+ eta=opt.ddim_eta,
375
+ )
376
+
377
+ x_samples_ddim = model.decode_first_stage(samples_ddim)
378
+ x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
379
+
380
+ for x_sample in x_samples_ddim:
381
+ x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
382
+ Image.fromarray(x_sample.astype(np.uint8)).save(
383
+ os.path.join(sample_path, f"{base_count:05}.png"))
384
+ base_count += 1
385
+ all_samples.append(x_samples_ddim)
386
+
387
+ if not opt.skip_grid:
388
+ # additionally, save as grid
389
+ grid = torch.stack(all_samples, 0)
390
+ grid = rearrange(grid, 'n b c h w -> (n b) c h w')
391
+ grid = make_grid(grid, nrow=n_rows)
392
+
393
+ # to image
394
+ grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
395
+ Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
396
+ grid_count += 1
397
+
398
+ print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.")
scripts/latent_imagenet_diffusion.ipynb.REMOVED.git-id ADDED
@@ -0,0 +1 @@
 
 
1
+ 607f94fc7d3ef6d8d1627017215476d9dfc7ddc4
scripts/train_searcher.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+ import numpy as np
3
+ import scann
4
+ import argparse
5
+ import glob
6
+ from multiprocessing import cpu_count
7
+ from tqdm import tqdm
8
+
9
+ from ldm.util import parallel_data_prefetch
10
+
11
+
12
+ def search_bruteforce(searcher):
13
+ return searcher.score_brute_force().build()
14
+
15
+
16
+ def search_partioned_ah(searcher, dims_per_block, aiq_threshold, reorder_k,
17
+ partioning_trainsize, num_leaves, num_leaves_to_search):
18
+ return searcher.tree(num_leaves=num_leaves,
19
+ num_leaves_to_search=num_leaves_to_search,
20
+ training_sample_size=partioning_trainsize). \
21
+ score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build()
22
+
23
+
24
+ def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k):
25
+ return searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(
26
+ reorder_k).build()
27
+
28
+ def load_datapool(dpath):
29
+
30
+
31
+ def load_single_file(saved_embeddings):
32
+ compressed = np.load(saved_embeddings)
33
+ database = {key: compressed[key] for key in compressed.files}
34
+ return database
35
+
36
+ def load_multi_files(data_archive):
37
+ database = {key: [] for key in data_archive[0].files}
38
+ for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'):
39
+ for key in d.files:
40
+ database[key].append(d[key])
41
+
42
+ return database
43
+
44
+ print(f'Load saved patch embedding from "{dpath}"')
45
+ file_content = glob.glob(os.path.join(dpath, '*.npz'))
46
+
47
+ if len(file_content) == 1:
48
+ data_pool = load_single_file(file_content[0])
49
+ elif len(file_content) > 1:
50
+ data = [np.load(f) for f in file_content]
51
+ prefetched_data = parallel_data_prefetch(load_multi_files, data,
52
+ n_proc=min(len(data), cpu_count()), target_data_type='dict')
53
+
54
+ data_pool = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()}
55
+ else:
56
+ raise ValueError(f'No npz-files in specified path "{dpath}" is this directory existing?')
57
+
58
+ print(f'Finished loading of retrieval database of length {data_pool["embedding"].shape[0]}.')
59
+ return data_pool
60
+
61
+
62
+ def train_searcher(opt,
63
+ metric='dot_product',
64
+ partioning_trainsize=None,
65
+ reorder_k=None,
66
+ # todo tune
67
+ aiq_thld=0.2,
68
+ dims_per_block=2,
69
+ num_leaves=None,
70
+ num_leaves_to_search=None,):
71
+
72
+ data_pool = load_datapool(opt.database)
73
+ k = opt.knn
74
+
75
+ if not reorder_k:
76
+ reorder_k = 2 * k
77
+
78
+ # normalize
79
+ # embeddings =
80
+ searcher = scann.scann_ops_pybind.builder(data_pool['embedding'] / np.linalg.norm(data_pool['embedding'], axis=1)[:, np.newaxis], k, metric)
81
+ pool_size = data_pool['embedding'].shape[0]
82
+
83
+ print(*(['#'] * 100))
84
+ print('Initializing scaNN searcher with the following values:')
85
+ print(f'k: {k}')
86
+ print(f'metric: {metric}')
87
+ print(f'reorder_k: {reorder_k}')
88
+ print(f'anisotropic_quantization_threshold: {aiq_thld}')
89
+ print(f'dims_per_block: {dims_per_block}')
90
+ print(*(['#'] * 100))
91
+ print('Start training searcher....')
92
+ print(f'N samples in pool is {pool_size}')
93
+
94
+ # this reflects the recommended design choices proposed at
95
+ # https://github.com/google-research/google-research/blob/aca5f2e44e301af172590bb8e65711f0c9ee0cfd/scann/docs/algorithms.md
96
+ if pool_size < 2e4:
97
+ print('Using brute force search.')
98
+ searcher = search_bruteforce(searcher)
99
+ elif 2e4 <= pool_size and pool_size < 1e5:
100
+ print('Using asymmetric hashing search and reordering.')
101
+ searcher = search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
102
+ else:
103
+ print('Using using partioning, asymmetric hashing search and reordering.')
104
+
105
+ if not partioning_trainsize:
106
+ partioning_trainsize = data_pool['embedding'].shape[0] // 10
107
+ if not num_leaves:
108
+ num_leaves = int(np.sqrt(pool_size))
109
+
110
+ if not num_leaves_to_search:
111
+ num_leaves_to_search = max(num_leaves // 20, 1)
112
+
113
+ print('Partitioning params:')
114
+ print(f'num_leaves: {num_leaves}')
115
+ print(f'num_leaves_to_search: {num_leaves_to_search}')
116
+ # self.searcher = self.search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
117
+ searcher = search_partioned_ah(searcher, dims_per_block, aiq_thld, reorder_k,
118
+ partioning_trainsize, num_leaves, num_leaves_to_search)
119
+
120
+ print('Finish training searcher')
121
+ searcher_savedir = opt.target_path
122
+ os.makedirs(searcher_savedir, exist_ok=True)
123
+ searcher.serialize(searcher_savedir)
124
+ print(f'Saved trained searcher under "{searcher_savedir}"')
125
+
126
+ if __name__ == '__main__':
127
+ sys.path.append(os.getcwd())
128
+ parser = argparse.ArgumentParser()
129
+ parser.add_argument('--database',
130
+ '-d',
131
+ default='data/rdm/retrieval_databases/openimages',
132
+ type=str,
133
+ help='path to folder containing the clip feature of the database')
134
+ parser.add_argument('--target_path',
135
+ '-t',
136
+ default='data/rdm/searchers/openimages',
137
+ type=str,
138
+ help='path to the target folder where the searcher shall be stored.')
139
+ parser.add_argument('--knn',
140
+ '-k',
141
+ default=20,
142
+ type=int,
143
+ help='number of nearest neighbors, for which the searcher shall be optimized')
144
+
145
+ opt, _ = parser.parse_known_args()
146
+
147
+ train_searcher(opt,)
scripts/txt2img.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse, os, sys, glob
2
+ import torch
3
+ import numpy as np
4
+ from omegaconf import OmegaConf
5
+ from PIL import Image
6
+ from tqdm import tqdm, trange
7
+ from itertools import islice
8
+ from einops import rearrange
9
+ from torchvision.utils import make_grid
10
+ import time
11
+ from pytorch_lightning import seed_everything
12
+ from torch import autocast
13
+ from contextlib import contextmanager, nullcontext
14
+
15
+ from ldm.util import instantiate_from_config
16
+ from ldm.models.diffusion.ddim import DDIMSampler
17
+ from ldm.models.diffusion.plms import PLMSSampler
18
+
19
+
20
+ def chunk(it, size):
21
+ it = iter(it)
22
+ return iter(lambda: tuple(islice(it, size)), ())
23
+
24
+
25
+ def load_model_from_config(config, ckpt, verbose=False):
26
+ print(f"Loading model from {ckpt}")
27
+ pl_sd = torch.load(ckpt, map_location="cpu")
28
+ if "global_step" in pl_sd:
29
+ print(f"Global Step: {pl_sd['global_step']}")
30
+ sd = pl_sd["state_dict"]
31
+ model = instantiate_from_config(config.model)
32
+ m, u = model.load_state_dict(sd, strict=False)
33
+ if len(m) > 0 and verbose:
34
+ print("missing keys:")
35
+ print(m)
36
+ if len(u) > 0 and verbose:
37
+ print("unexpected keys:")
38
+ print(u)
39
+
40
+ model.cuda()
41
+ model.eval()
42
+ return model
43
+
44
+
45
+ def main():
46
+ parser = argparse.ArgumentParser()
47
+
48
+ parser.add_argument(
49
+ "--prompt",
50
+ type=str,
51
+ nargs="?",
52
+ default="a painting of a virus monster playing guitar",
53
+ help="the prompt to render"
54
+ )
55
+ parser.add_argument(
56
+ "--outdir",
57
+ type=str,
58
+ nargs="?",
59
+ help="dir to write results to",
60
+ default="outputs/txt2img-samples"
61
+ )
62
+ parser.add_argument(
63
+ "--skip_grid",
64
+ action='store_true',
65
+ help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
66
+ )
67
+ parser.add_argument(
68
+ "--skip_save",
69
+ action='store_true',
70
+ help="do not save individual samples. For speed measurements.",
71
+ )
72
+ parser.add_argument(
73
+ "--ddim_steps",
74
+ type=int,
75
+ default=50,
76
+ help="number of ddim sampling steps",
77
+ )
78
+ parser.add_argument(
79
+ "--plms",
80
+ action='store_true',
81
+ help="use plms sampling",
82
+ )
83
+ parser.add_argument(
84
+ "--laion400m",
85
+ action='store_true',
86
+ help="uses the LAION400M model",
87
+ )
88
+ parser.add_argument(
89
+ "--fixed_code",
90
+ action='store_true',
91
+ help="if enabled, uses the same starting code across samples ",
92
+ )
93
+ parser.add_argument(
94
+ "--ddim_eta",
95
+ type=float,
96
+ default=0.0,
97
+ help="ddim eta (eta=0.0 corresponds to deterministic sampling",
98
+ )
99
+ parser.add_argument(
100
+ "--n_iter",
101
+ type=int,
102
+ default=2,
103
+ help="sample this often",
104
+ )
105
+ parser.add_argument(
106
+ "--H",
107
+ type=int,
108
+ default=512,
109
+ help="image height, in pixel space",
110
+ )
111
+ parser.add_argument(
112
+ "--W",
113
+ type=int,
114
+ default=512,
115
+ help="image width, in pixel space",
116
+ )
117
+ parser.add_argument(
118
+ "--C",
119
+ type=int,
120
+ default=4,
121
+ help="latent channels",
122
+ )
123
+ parser.add_argument(
124
+ "--f",
125
+ type=int,
126
+ default=8,
127
+ help="downsampling factor",
128
+ )
129
+ parser.add_argument(
130
+ "--n_samples",
131
+ type=int,
132
+ default=3,
133
+ help="how many samples to produce for each given prompt. A.k.a. batch size",
134
+ )
135
+ parser.add_argument(
136
+ "--n_rows",
137
+ type=int,
138
+ default=0,
139
+ help="rows in the grid (default: n_samples)",
140
+ )
141
+ parser.add_argument(
142
+ "--scale",
143
+ type=float,
144
+ default=7.5,
145
+ help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
146
+ )
147
+ parser.add_argument(
148
+ "--from-file",
149
+ type=str,
150
+ help="if specified, load prompts from this file",
151
+ )
152
+ parser.add_argument(
153
+ "--config",
154
+ type=str,
155
+ default="configs/stable-diffusion/v1-inference.yaml",
156
+ help="path to config which constructs model",
157
+ )
158
+ parser.add_argument(
159
+ "--ckpt",
160
+ type=str,
161
+ default="models/ldm/stable-diffusion-v1/model.ckpt",
162
+ help="path to checkpoint of model",
163
+ )
164
+ parser.add_argument(
165
+ "--seed",
166
+ type=int,
167
+ default=42,
168
+ help="the seed (for reproducible sampling)",
169
+ )
170
+ parser.add_argument(
171
+ "--precision",
172
+ type=str,
173
+ help="evaluate at this precision",
174
+ choices=["full", "autocast"],
175
+ default="autocast"
176
+ )
177
+ opt = parser.parse_args()
178
+
179
+ if opt.laion400m:
180
+ print("Falling back to LAION 400M model...")
181
+ opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
182
+ opt.ckpt = "models/ldm/text2img-large/model.ckpt"
183
+ opt.outdir = "outputs/txt2img-samples-laion400m"
184
+
185
+ seed_everything(opt.seed)
186
+
187
+ config = OmegaConf.load(f"{opt.config}")
188
+ model = load_model_from_config(config, f"{opt.ckpt}")
189
+
190
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
191
+ model = model.to(device)
192
+
193
+ if opt.plms:
194
+ sampler = PLMSSampler(model)
195
+ else:
196
+ sampler = DDIMSampler(model)
197
+
198
+ os.makedirs(opt.outdir, exist_ok=True)
199
+ outpath = opt.outdir
200
+
201
+ batch_size = opt.n_samples
202
+ n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
203
+ if not opt.from_file:
204
+ prompt = opt.prompt
205
+ assert prompt is not None
206
+ data = [batch_size * [prompt]]
207
+
208
+ else:
209
+ print(f"reading prompts from {opt.from_file}")
210
+ with open(opt.from_file, "r") as f:
211
+ data = f.read().splitlines()
212
+ data = list(chunk(data, batch_size))
213
+
214
+ sample_path = os.path.join(outpath, "samples")
215
+ os.makedirs(sample_path, exist_ok=True)
216
+ base_count = len(os.listdir(sample_path))
217
+ grid_count = len(os.listdir(outpath)) - 1
218
+
219
+ start_code = None
220
+ if opt.fixed_code:
221
+ start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
222
+
223
+ precision_scope = autocast if opt.precision=="autocast" else nullcontext
224
+ with torch.no_grad():
225
+ with precision_scope("cuda"):
226
+ with model.ema_scope():
227
+ tic = time.time()
228
+ all_samples = list()
229
+ for n in trange(opt.n_iter, desc="Sampling"):
230
+ for prompts in tqdm(data, desc="data"):
231
+ uc = None
232
+ if opt.scale != 1.0:
233
+ uc = model.get_learned_conditioning(batch_size * [""])
234
+ if isinstance(prompts, tuple):
235
+ prompts = list(prompts)
236
+ c = model.get_learned_conditioning(prompts)
237
+ shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
238
+ samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
239
+ conditioning=c,
240
+ batch_size=opt.n_samples,
241
+ shape=shape,
242
+ verbose=False,
243
+ unconditional_guidance_scale=opt.scale,
244
+ unconditional_conditioning=uc,
245
+ eta=opt.ddim_eta,
246
+ x_T=start_code)
247
+
248
+ x_samples_ddim = model.decode_first_stage(samples_ddim)
249
+ x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
250
+
251
+ if not opt.skip_save:
252
+ for x_sample in x_samples_ddim:
253
+ x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
254
+ Image.fromarray(x_sample.astype(np.uint8)).save(
255
+ os.path.join(sample_path, f"{base_count:05}.png"))
256
+ base_count += 1
257
+
258
+ if not opt.skip_grid:
259
+ all_samples.append(x_samples_ddim)
260
+
261
+ if not opt.skip_grid:
262
+ # additionally, save as grid
263
+ grid = torch.stack(all_samples, 0)
264
+ grid = rearrange(grid, 'n b c h w -> (n b) c h w')
265
+ grid = make_grid(grid, nrow=n_rows)
266
+
267
+ # to image
268
+ grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
269
+ Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
270
+ grid_count += 1
271
+
272
+ toc = time.time()
273
+
274
+ print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
275
+ f" \nEnjoy.")
276
+
277
+
278
+ if __name__ == "__main__":
279
+ main()