+
+# TME adaptation of diffusers
+- Rebased onto diffusers as of 2023-11-24, commit id `3003ff4947ea43fb56aa0df3da61c85652f24c69`
+
+# Original README
+🤗 Diffusers is the go-to library for state-of-the-art pretrained diffusion models for generating images, audio, and even 3D structures of molecules. Whether you're looking for a simple inference solution or training your own diffusion models, 🤗 Diffusers is a modular toolbox that supports both. Our library is designed with a focus on [usability over performance](https://huggingface.co/docs/diffusers/conceptual/philosophy#usability-over-performance), [simple over easy](https://huggingface.co/docs/diffusers/conceptual/philosophy#simple-over-easy), and [customizability over abstractions](https://huggingface.co/docs/diffusers/conceptual/philosophy#tweakable-contributorfriendly-over-abstraction).
+
+🤗 Diffusers offers three core components:
+
+- State-of-the-art [diffusion pipelines](https://huggingface.co/docs/diffusers/api/pipelines/overview) that can be run in inference with just a few lines of code.
+- Interchangeable noise [schedulers](https://huggingface.co/docs/diffusers/api/schedulers/overview) for different diffusion speeds and output quality (a short example of swapping schedulers follows this list).
+- Pretrained [models](https://huggingface.co/docs/diffusers/api/models/overview) that can be used as building blocks, and combined with schedulers, for creating your own end-to-end diffusion systems.
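+
+As a small illustration of the interchangeable schedulers mentioned above, here is a minimal sketch of swapping a loaded pipeline's scheduler for another one built from the same configuration (the checkpoint is the same one used in the Quickstart below):
+
+```python
+from diffusers import DiffusionPipeline, EulerDiscreteScheduler
+
+pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+# rebuild a different scheduler from the current scheduler's configuration
+pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
+```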
+
+## Installation
+
+We recommend installing 🤗 Diffusers in a virtual environment from PyPI or Conda. For more details about installing [PyTorch](https://pytorch.org/get-started/locally/) and [Flax](https://flax.readthedocs.io/en/latest/#installation), please refer to their official documentation.
+
+### PyTorch
+
+With `pip` (official package):
+
+```bash
+pip install --upgrade diffusers[torch]
+```
+
+With `conda` (maintained by the community):
+
+```sh
+conda install -c conda-forge diffusers
+```
+
+### Flax
+
+With `pip` (official package):
+
+```bash
+pip install --upgrade diffusers[flax]
+```
+
+### Apple Silicon (M1/M2) support
+
+Please refer to the [How to use Stable Diffusion in Apple Silicon](https://huggingface.co/docs/diffusers/optimization/mps) guide.
+
+## Quickstart
+
+Generating outputs is super easy with 🤗 Diffusers. To generate an image from text, use the `from_pretrained` method to load any pretrained diffusion model (browse the [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) for 15000+ checkpoints):
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+pipeline.to("cuda")
+pipeline("An image of a squirrel in Picasso style").images[0]
+```
+
+You can also dig into the models and schedulers toolbox to build your own diffusion system:
+
+```python
+from diffusers import DDPMScheduler, UNet2DModel
+from PIL import Image
+import torch
+
+scheduler = DDPMScheduler.from_pretrained("google/ddpm-cat-256")
+model = UNet2DModel.from_pretrained("google/ddpm-cat-256").to("cuda")
+scheduler.set_timesteps(50)
+
+sample_size = model.config.sample_size
+noise = torch.randn((1, 3, sample_size, sample_size), device="cuda")
+input = noise
+
+for t in scheduler.timesteps:
+ with torch.no_grad():
+ noisy_residual = model(input, t).sample
+ prev_noisy_sample = scheduler.step(noisy_residual, t, input).prev_sample
+ input = prev_noisy_sample
+
+image = (input / 2 + 0.5).clamp(0, 1)
+image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
+image = Image.fromarray((image * 255).round().astype("uint8"))
+image
+```
+
+Check out the [Quickstart](https://huggingface.co/docs/diffusers/quicktour) to launch your diffusion journey today!
+
+## How to navigate the documentation
+
+| **Documentation** | **What can I learn?** |
+|---------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| [Tutorial](https://huggingface.co/docs/diffusers/tutorials/tutorial_overview) | A basic crash course for learning how to use the library's most important features like using models and schedulers to build your own diffusion system, and training your own diffusion model. |
+| [Loading](https://huggingface.co/docs/diffusers/using-diffusers/loading_overview) | Guides for how to load and configure all the components (pipelines, models, and schedulers) of the library, as well as how to use different schedulers. |
+| [Pipelines for inference](https://huggingface.co/docs/diffusers/using-diffusers/pipeline_overview) | Guides for how to use pipelines for different inference tasks, batched generation, controlling generated outputs and randomness, and how to contribute a pipeline to the library. |
+| [Optimization](https://huggingface.co/docs/diffusers/optimization/opt_overview) | Guides for how to optimize your diffusion model to run faster and consume less memory. |
+| [Training](https://huggingface.co/docs/diffusers/training/overview) | Guides for how to train a diffusion model for different tasks with different training techniques. |
+
+## Contribution
+
+We ❤️ contributions from the open-source community!
+If you want to contribute to this library, please check out our [Contribution guide](https://github.com/huggingface/diffusers/blob/main/CONTRIBUTING.md).
+You can look out for [issues](https://github.com/huggingface/diffusers/issues) you'd like to tackle to contribute to the library.
+- See [Good first issues](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22) for general opportunities to contribute
+- See [New model/pipeline](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+pipeline%2Fmodel%22) to contribute exciting new diffusion models / diffusion pipelines
+- See [New scheduler](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+scheduler%22)
+
+Also, say 👋 in our public Discord channel. We discuss the hottest trends about diffusion models, help each other with contributions, personal projects or just hang out ☕.
+
+
+## Popular Tasks & Pipelines
+
+
+
+## Popular libraries using 🧨 Diffusers
+
+- https://github.com/microsoft/TaskMatrix
+- https://github.com/invoke-ai/InvokeAI
+- https://github.com/apple/ml-stable-diffusion
+- https://github.com/Sanster/lama-cleaner
+- https://github.com/IDEA-Research/Grounded-Segment-Anything
+- https://github.com/ashawkey/stable-dreamfusion
+- https://github.com/deep-floyd/IF
+- https://github.com/bentoml/BentoML
+- https://github.com/bmaltais/kohya_ss
+- +6000 other amazing GitHub repositories 💪
+
+Thank you for using us ❤️.
+
+## Credits
+
+This library concretizes previous work by many different authors and would not have been possible without their great research and implementations. We'd like to thank, in particular, the following implementations which have helped us in our development and without which the API could not have been as polished today:
+
+- @CompVis' latent diffusion models library, available [here](https://github.com/CompVis/latent-diffusion)
+- @hojonathanho original DDPM implementation, available [here](https://github.com/hojonathanho/diffusion) as well as the extremely useful translation into PyTorch by @pesser, available [here](https://github.com/pesser/pytorch_diffusion)
+- @ermongroup's DDIM implementation, available [here](https://github.com/ermongroup/ddim)
+- @yang-song's Score-VE and Score-VP implementations, available [here](https://github.com/yang-song/score_sde_pytorch)
+
+We also want to thank @heejkoo for the very helpful overview of papers, code and resources on diffusion models, available [here](https://github.com/heejkoo/Awesome-Diffusion-Models) as well as @crowsonkb and @rromb for useful discussions and insights.
+
+## Citation
+
+```bibtex
+@misc{von-platen-etal-2022-diffusers,
+ author = {Patrick von Platen and Suraj Patil and Anton Lozhkov and Pedro Cuenca and Nathan Lambert and Kashif Rasul and Mishig Davaadorj and Thomas Wolf},
+ title = {Diffusers: State-of-the-art diffusion models},
+ year = {2022},
+ publisher = {GitHub},
+ journal = {GitHub repository},
+ howpublished = {\url{https://github.com/huggingface/diffusers}}
+}
+```
diff --git a/diffusers/_typos.toml b/diffusers/_typos.toml
new file mode 100644
index 0000000000000000000000000000000000000000..551099f981e7885fbda9ed28e297bace0e13407b
--- /dev/null
+++ b/diffusers/_typos.toml
@@ -0,0 +1,13 @@
+# Files for typos
+# Instruction: https://github.com/marketplace/actions/typos-action#getting-started
+
+[default.extend-identifiers]
+
+[default.extend-words]
+NIN="NIN" # NIN is used in scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py
+nd="np" # nd may be np (numpy)
+parms="parms" # parms is used in scripts/convert_original_stable_diffusion_to_diffusers.py
+
+
+[files]
+extend-exclude = ["_typos.toml"]
diff --git a/diffusers/docker/diffusers-flax-cpu/Dockerfile b/diffusers/docker/diffusers-flax-cpu/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..57a9c1ec742200b48f8c2f906d1152e85e60584a
--- /dev/null
+++ b/diffusers/docker/diffusers-flax-cpu/Dockerfile
@@ -0,0 +1,44 @@
+FROM ubuntu:20.04
+LABEL maintainer="Hugging Face"
+LABEL repository="diffusers"
+
+ENV DEBIAN_FRONTEND=noninteractive
+
+RUN apt update && \
+ apt install -y bash \
+ build-essential \
+ git \
+ git-lfs \
+ curl \
+ ca-certificates \
+ libsndfile1-dev \
+ python3.8 \
+ python3-pip \
+ python3.8-venv && \
+ rm -rf /var/lib/apt/lists
+
+# make sure to use venv
+RUN python3 -m venv /opt/venv
+ENV PATH="/opt/venv/bin:$PATH"
+
+# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
+# follow the instructions here: https://cloud.google.com/tpu/docs/run-in-container#train_a_jax_model_in_a_docker_container
+RUN python3 -m pip install --no-cache-dir --upgrade pip && \
+ python3 -m pip install --upgrade --no-cache-dir \
+ clu \
+ "jax[cpu]>=0.2.16,!=0.3.2" \
+ "flax>=0.4.1" \
+ "jaxlib>=0.1.65" && \
+ python3 -m pip install --no-cache-dir \
+ accelerate \
+ datasets \
+ hf-doc-builder \
+ huggingface-hub \
+ Jinja2 \
+ librosa \
+ numpy \
+ scipy \
+ tensorboard \
+ transformers
+
+CMD ["/bin/bash"]
\ No newline at end of file
diff --git a/diffusers/docker/diffusers-flax-tpu/Dockerfile b/diffusers/docker/diffusers-flax-tpu/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..2517da586d74b43c4c94a0eca4651f047345ec4d
--- /dev/null
+++ b/diffusers/docker/diffusers-flax-tpu/Dockerfile
@@ -0,0 +1,46 @@
+FROM ubuntu:20.04
+LABEL maintainer="Hugging Face"
+LABEL repository="diffusers"
+
+ENV DEBIAN_FRONTEND=noninteractive
+
+RUN apt update && \
+ apt install -y bash \
+ build-essential \
+ git \
+ git-lfs \
+ curl \
+ ca-certificates \
+ libsndfile1-dev \
+ python3.8 \
+ python3-pip \
+ python3.8-venv && \
+ rm -rf /var/lib/apt/lists
+
+# make sure to use venv
+RUN python3 -m venv /opt/venv
+ENV PATH="/opt/venv/bin:$PATH"
+
+# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
+# follow the instructions here: https://cloud.google.com/tpu/docs/run-in-container#train_a_jax_model_in_a_docker_container
+RUN python3 -m pip install --no-cache-dir --upgrade pip && \
+ python3 -m pip install --no-cache-dir \
+ "jax[tpu]>=0.2.16,!=0.3.2" \
+ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html && \
+ python3 -m pip install --upgrade --no-cache-dir \
+ clu \
+ "flax>=0.4.1" \
+ "jaxlib>=0.1.65" && \
+ python3 -m pip install --no-cache-dir \
+ accelerate \
+ datasets \
+ hf-doc-builder \
+ huggingface-hub \
+ Jinja2 \
+ librosa \
+ numpy \
+ scipy \
+ tensorboard \
+ transformers
+
+CMD ["/bin/bash"]
\ No newline at end of file
diff --git a/diffusers/docker/diffusers-onnxruntime-cpu/Dockerfile b/diffusers/docker/diffusers-onnxruntime-cpu/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..75f45be87a033e9476c4038218c9c2fd2f1255a5
--- /dev/null
+++ b/diffusers/docker/diffusers-onnxruntime-cpu/Dockerfile
@@ -0,0 +1,44 @@
+FROM ubuntu:20.04
+LABEL maintainer="Hugging Face"
+LABEL repository="diffusers"
+
+ENV DEBIAN_FRONTEND=noninteractive
+
+RUN apt update && \
+ apt install -y bash \
+ build-essential \
+ git \
+ git-lfs \
+ curl \
+ ca-certificates \
+ libsndfile1-dev \
+ python3.8 \
+ python3-pip \
+ python3.8-venv && \
+ rm -rf /var/lib/apt/lists
+
+# make sure to use venv
+RUN python3 -m venv /opt/venv
+ENV PATH="/opt/venv/bin:$PATH"
+
+# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
+RUN python3 -m pip install --no-cache-dir --upgrade pip && \
+ python3 -m pip install --no-cache-dir \
+ torch \
+ torchvision \
+ torchaudio \
+ onnxruntime \
+ --extra-index-url https://download.pytorch.org/whl/cpu && \
+ python3 -m pip install --no-cache-dir \
+ accelerate \
+ datasets \
+ hf-doc-builder \
+ huggingface-hub \
+ Jinja2 \
+ librosa \
+ numpy \
+ scipy \
+ tensorboard \
+ transformers
+
+CMD ["/bin/bash"]
\ No newline at end of file
diff --git a/diffusers/docker/diffusers-onnxruntime-cuda/Dockerfile b/diffusers/docker/diffusers-onnxruntime-cuda/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..2129dbcaf68c57755485e1e54e867af05b937336
--- /dev/null
+++ b/diffusers/docker/diffusers-onnxruntime-cuda/Dockerfile
@@ -0,0 +1,44 @@
+FROM nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04
+LABEL maintainer="Hugging Face"
+LABEL repository="diffusers"
+
+ENV DEBIAN_FRONTEND=noninteractive
+
+RUN apt update && \
+ apt install -y bash \
+ build-essential \
+ git \
+ git-lfs \
+ curl \
+ ca-certificates \
+ libsndfile1-dev \
+ python3.8 \
+ python3-pip \
+ python3.8-venv && \
+ rm -rf /var/lib/apt/lists
+
+# make sure to use venv
+RUN python3 -m venv /opt/venv
+ENV PATH="/opt/venv/bin:$PATH"
+
+# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
+RUN python3 -m pip install --no-cache-dir --upgrade pip && \
+ python3 -m pip install --no-cache-dir \
+ torch \
+ torchvision \
+ torchaudio \
+ "onnxruntime-gpu>=1.13.1" \
+ --extra-index-url https://download.pytorch.org/whl/cu117 && \
+ python3 -m pip install --no-cache-dir \
+ accelerate \
+ datasets \
+ hf-doc-builder \
+ huggingface-hub \
+ Jinja2 \
+ librosa \
+ numpy \
+ scipy \
+ tensorboard \
+ transformers
+
+CMD ["/bin/bash"]
\ No newline at end of file
diff --git a/diffusers/docker/diffusers-pytorch-compile-cuda/Dockerfile b/diffusers/docker/diffusers-pytorch-compile-cuda/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..da9f372bd664ef2e2d7f2db0c1e529f82a050513
--- /dev/null
+++ b/diffusers/docker/diffusers-pytorch-compile-cuda/Dockerfile
@@ -0,0 +1,46 @@
+FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04
+LABEL maintainer="Hugging Face"
+LABEL repository="diffusers"
+
+ENV DEBIAN_FRONTEND=noninteractive
+
+RUN apt update && \
+ apt install -y bash \
+ build-essential \
+ git \
+ git-lfs \
+ curl \
+ ca-certificates \
+ libsndfile1-dev \
+ libgl1 \
+ python3.9 \
+ python3.9-dev \
+ python3-pip \
+ python3.9-venv && \
+ rm -rf /var/lib/apt/lists
+
+# make sure to use venv
+RUN python3.9 -m venv /opt/venv
+ENV PATH="/opt/venv/bin:$PATH"
+
+# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
+RUN python3.9 -m pip install --no-cache-dir --upgrade pip && \
+ python3.9 -m pip install --no-cache-dir \
+ torch \
+ torchvision \
+ torchaudio \
+ invisible_watermark && \
+ python3.9 -m pip install --no-cache-dir \
+ accelerate \
+ datasets \
+ hf-doc-builder \
+ huggingface-hub \
+ Jinja2 \
+ librosa \
+ numpy \
+ scipy \
+ tensorboard \
+ transformers \
+ omegaconf
+
+CMD ["/bin/bash"]
diff --git a/diffusers/docker/diffusers-pytorch-cpu/Dockerfile b/diffusers/docker/diffusers-pytorch-cpu/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..127c61a719c5f43cf10561e1e64123799ce62402
--- /dev/null
+++ b/diffusers/docker/diffusers-pytorch-cpu/Dockerfile
@@ -0,0 +1,45 @@
+FROM ubuntu:20.04
+LABEL maintainer="Hugging Face"
+LABEL repository="diffusers"
+
+ENV DEBIAN_FRONTEND=noninteractive
+
+RUN apt update && \
+ apt install -y bash \
+ build-essential \
+ git \
+ git-lfs \
+ curl \
+ ca-certificates \
+ libsndfile1-dev \
+ python3.8 \
+ python3-pip \
+ libgl1 \
+ python3.8-venv && \
+ rm -rf /var/lib/apt/lists
+
+# make sure to use venv
+RUN python3 -m venv /opt/venv
+ENV PATH="/opt/venv/bin:$PATH"
+
+# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
+RUN python3 -m pip install --no-cache-dir --upgrade pip && \
+ python3 -m pip install --no-cache-dir \
+ torch \
+ torchvision \
+ torchaudio \
+ invisible_watermark \
+ --extra-index-url https://download.pytorch.org/whl/cpu && \
+ python3 -m pip install --no-cache-dir \
+ accelerate \
+ datasets \
+ hf-doc-builder \
+ huggingface-hub \
+ Jinja2 \
+ librosa \
+ numpy \
+ scipy \
+ tensorboard \
+ transformers
+
+CMD ["/bin/bash"]
diff --git a/diffusers/docker/diffusers-pytorch-cuda/Dockerfile b/diffusers/docker/diffusers-pytorch-cuda/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..877bc6840e6b90be8d62f7819a5ccf2f48e4741f
--- /dev/null
+++ b/diffusers/docker/diffusers-pytorch-cuda/Dockerfile
@@ -0,0 +1,46 @@
+FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04
+LABEL maintainer="Hugging Face"
+LABEL repository="diffusers"
+
+ENV DEBIAN_FRONTEND=noninteractive
+
+RUN apt update && \
+ apt install -y bash \
+ build-essential \
+ git \
+ git-lfs \
+ curl \
+ ca-certificates \
+ libsndfile1-dev \
+ libgl1 \
+ python3.8 \
+ python3-pip \
+ python3.8-venv && \
+ rm -rf /var/lib/apt/lists
+
+# make sure to use venv
+RUN python3 -m venv /opt/venv
+ENV PATH="/opt/venv/bin:$PATH"
+
+# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
+RUN python3 -m pip install --no-cache-dir --upgrade pip && \
+ python3 -m pip install --no-cache-dir \
+ torch \
+ torchvision \
+ torchaudio \
+ invisible_watermark && \
+ python3 -m pip install --no-cache-dir \
+ accelerate \
+ datasets \
+ hf-doc-builder \
+ huggingface-hub \
+ Jinja2 \
+ librosa \
+ numpy \
+ scipy \
+ tensorboard \
+ transformers \
+ omegaconf \
+ pytorch-lightning
+
+CMD ["/bin/bash"]
diff --git a/diffusers/docker/diffusers-pytorch-xformers-cuda/Dockerfile b/diffusers/docker/diffusers-pytorch-xformers-cuda/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..003f8e1165a1ab3f54a8c763db7af3ab475867ac
--- /dev/null
+++ b/diffusers/docker/diffusers-pytorch-xformers-cuda/Dockerfile
@@ -0,0 +1,46 @@
+FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04
+LABEL maintainer="Hugging Face"
+LABEL repository="diffusers"
+
+ENV DEBIAN_FRONTEND=noninteractive
+
+RUN apt update && \
+ apt install -y bash \
+ build-essential \
+ git \
+ git-lfs \
+ curl \
+ ca-certificates \
+ libsndfile1-dev \
+ libgl1 \
+ python3.8 \
+ python3-pip \
+ python3.8-venv && \
+ rm -rf /var/lib/apt/lists
+
+# make sure to use venv
+RUN python3 -m venv /opt/venv
+ENV PATH="/opt/venv/bin:$PATH"
+
+# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
+RUN python3 -m pip install --no-cache-dir --upgrade pip && \
+ python3 -m pip install --no-cache-dir \
+ torch \
+ torchvision \
+ torchaudio \
+ invisible_watermark && \
+ python3 -m pip install --no-cache-dir \
+ accelerate \
+ datasets \
+ hf-doc-builder \
+ huggingface-hub \
+ Jinja2 \
+ librosa \
+ numpy \
+ scipy \
+ tensorboard \
+ transformers \
+ omegaconf \
+ xformers
+
+CMD ["/bin/bash"]
diff --git a/diffusers/docs/README.md b/diffusers/docs/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f85032c68931ec0faeda4b81c2638a29ad7964ea
--- /dev/null
+++ b/diffusers/docs/README.md
@@ -0,0 +1,268 @@
+
+
+# Generating the documentation
+
+To generate the documentation, you first have to build it. Several packages are necessary to build the doc,
+you can install them with the following command, at the root of the code repository:
+
+```bash
+pip install -e ".[docs]"
+```
+
+Then you need to install our open source documentation builder tool:
+
+```bash
+pip install git+https://github.com/huggingface/doc-builder
+```
+
+---
+**NOTE**
+
+You only need to generate the documentation to inspect it locally (if you're planning changes and want to
+check how they look before committing for instance). You don't have to commit the built documentation.
+
+---
+
+## Previewing the documentation
+
+To preview the docs, first install the `watchdog` module with:
+
+```bash
+pip install watchdog
+```
+
+Then run the following command:
+
+```bash
+doc-builder preview {package_name} {path_to_docs}
+```
+
+For example:
+
+```bash
+doc-builder preview diffusers docs/source/en
+```
+
+The docs will be viewable at [http://localhost:3000](http://localhost:3000). You can also preview the docs once you have opened a PR. You will see a bot add a comment to a link where the documentation with your changes lives.
+
+---
+**NOTE**
+
+The `preview` command only works with existing doc files. When you add a completely new file, you need to update `_toctree.yml` & restart `preview` command (`ctrl-c` to stop it & call `doc-builder preview ...` again).
+
+---
+
+## Adding a new element to the navigation bar
+
+Accepted files are Markdown (.md).
+
+Create a file with its extension and put it in the source directory. You can then link it to the toc-tree by putting
+the filename without the extension in the [`_toctree.yml`](https://github.com/huggingface/diffusers/blob/main/docs/source/en/_toctree.yml) file.
+
+## Renaming section headers and moving sections
+
+It helps to keep the old links working when renaming the section header and/or moving sections from one document to another. This is because the old links are likely to be used in Issues, Forums, and Social media and it'd make for a much more superior user experience if users reading those months later could still easily navigate to the originally intended information.
+
+Therefore, we simply keep a little map of moved sections at the end of the document where the original section was. The key is to preserve the original anchor.
+
+So if you renamed a section from: "Section A" to "Section B", then you can add at the end of the file:
+
+```md
+Sections that were moved:
+
+[ <a href="#section-b">Section A</a><a id="section-a"></a> ]
+```
+and of course, if you moved it to another file, then:
+
+```md
+Sections that were moved:
+
+[ <a href="../new-file#section-b">Section A</a><a id="section-a"></a> ]
+```
+
+Use the relative style to link to the new file so that the versioned docs continue to work.
+
+For an example of a rich moved section set please see the very end of [the transformers Trainer doc](https://github.com/huggingface/transformers/blob/main/docs/source/en/main_classes/trainer.md).
+
+
+## Writing Documentation - Specification
+
+The `huggingface/diffusers` documentation follows the
+[Google documentation](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html) style for docstrings,
+although we can write them directly in Markdown.
+
+### Adding a new tutorial
+
+Adding a new tutorial or section is done in two steps:
+
+- Add a new Markdown (.md) file under `docs/source/<languageCode>`.
+- Link that file in `docs/source/<languageCode>/_toctree.yml` on the correct toc-tree.
+
+Make sure to put your new file under the proper section. It's unlikely to go in the first section (*Get Started*), so
+depending on the intended targets (beginners, more advanced users, or researchers) it should go in sections two, three, or four.
+
+### Adding a new pipeline/scheduler
+
+When adding a new pipeline:
+
+- Create a file `xxx.md` under `docs/source/<languageCode>/api/pipelines` (don't hesitate to copy an existing file as template).
+- Link that file in (*Diffusers Summary*) section in `docs/source/api/pipelines/overview.md`, along with the link to the paper, and a colab notebook (if available).
+- Write a short overview of the diffusion model:
+ - Overview with paper & authors
+ - Paper abstract
+ - Tips and tricks and how to use it best
+ - Possible an end-to-end example of how to use it
+- Add all the pipeline classes that should be linked in the diffusion model. These classes should be added using our Markdown syntax. By default as follows:
+
+```
+[[autodoc]] XXXPipeline
+ - all
+ - __call__
+```
+
+This will include every public method of the pipeline that is documented, as well as the `__call__` method that is not documented by default. If you just want to add additional methods that are not documented, you can put the list of all methods to add in a list that contains `all`.
+
+```
+[[autodoc]] XXXPipeline
+ - all
+ - __call__
+ - enable_attention_slicing
+ - disable_attention_slicing
+ - enable_xformers_memory_efficient_attention
+ - disable_xformers_memory_efficient_attention
+```
+
+You can follow the same process to create a new scheduler under the `docs/source/<languageCode>/api/schedulers` folder.
+
+### Writing source documentation
+
+Values that should be put in `code` should either be surrounded by backticks: \`like so\`. Note that argument names
+and objects like True, None, or any strings should usually be put in `code`.
+
+When mentioning a class, function, or method, it is recommended to use our syntax for internal links so that our tool
+adds a link to its documentation with this syntax: \[\`XXXClass\`\] or \[\`function\`\]. This requires the class or
+function to be in the main package.
+
+If you want to create a link to some internal class or function, you need to
+provide its path. For instance: \[\`pipelines.ImagePipelineOutput\`\]. This will be converted into a link with
+`pipelines.ImagePipelineOutput` in the description. To get rid of the path and only keep the name of the object you are
+linking to in the description, add a ~: \[\`~pipelines.ImagePipelineOutput\`\] will generate a link with `ImagePipelineOutput` in the description.
+
+The same works for methods so you can either use \[\`XXXClass.method\`\] or \[\`~XXXClass.method\`\].
+
+#### Defining arguments in a method
+
+Arguments should be defined with the `Args:` (or `Arguments:` or `Parameters:`) prefix, followed by a line return and
+an indentation. The argument should be followed by its type, with its shape if it is a tensor, a colon, and its
+description:
+
+```
+ Args:
+ n_layers (`int`): The number of layers of the model.
+```
+
+If the description is too long to fit in one line, another indentation is necessary before writing the description
+after the argument.
+
+Here's an example showcasing everything so far:
+
+```
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AlbertTokenizer`]. See [`~PreTrainedTokenizer.encode`] and
+ [`~PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+```
+
+For optional arguments or arguments with defaults we follow the following syntax: imagine we have a function with the
+following signature:
+
+```py
+def my_function(x: str = None, a: float = 3.14):
+    ...
+```
+
+then its documentation should look like this:
+
+```
+ Args:
+ x (`str`, *optional*):
+ This argument controls ...
+ a (`float`, *optional*, defaults to `3.14`):
+ This argument is used to ...
+```
+
+Note that we always omit the "defaults to \`None\`" when None is the default for any argument. Also note that even
+if the first line describing your argument type and its default gets long, you can't break it on several lines. You can
+however write as many lines as you want in the indented description (see the example above with `input_ids`).
+
+#### Writing a multi-line code block
+
+Multi-line code blocks can be useful for displaying examples. They are done between two lines of three backticks as usual in Markdown:
+
+
+````
+```
+# first line of code
+# second line
+# etc
+```
+````
+
+#### Writing a return block
+
+The return block should be introduced with the `Returns:` prefix, followed by a line return and an indentation.
+The first line should be the type of the return, followed by a line return. No need to indent further for the elements
+building the return.
+
+Here's an example of a single value return:
+
+```
+ Returns:
+ `List[int]`: A list of integers in the range [0, 1] --- 1 for a special token, 0 for a sequence token.
+```
+
+Here's an example of a tuple return, comprising several objects:
+
+```
+ Returns:
+ `tuple(torch.FloatTensor)` comprising various elements depending on the configuration ([`BertConfig`]) and inputs:
+      - **loss** (*optional*, returned when `masked_lm_labels` is provided) `torch.FloatTensor` of shape `(1,)` --
+ Total loss is the sum of the masked language modeling loss and the next sequence prediction (classification) loss.
+ - **prediction_scores** (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`) --
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+```
+
+#### Adding an image
+
+Due to the rapidly growing repository, it is important to make sure that no files that would significantly weigh down the repository are added. This includes images, videos, and other non-text files. We prefer to leverage a hf.co hosted `dataset` like
+the ones hosted on [`hf-internal-testing`](https://huggingface.co/hf-internal-testing) in which to place these files and reference
+them by URL. We recommend putting them in the following dataset: [huggingface/documentation-images](https://huggingface.co/datasets/huggingface/documentation-images).
+If an external contribution, feel free to add the images to your PR and ask a Hugging Face member to migrate your images
+to this dataset.
+
+## Styling the docstring
+
+We have an automatic script running with the `make style` command that will make sure that:
+- the docstrings fully take advantage of the line width
+- all code examples are formatted using black, like the code of the Transformers library
+
+This script may have some weird failures if you made a syntax mistake or if you uncover a bug. Therefore, it's
+recommended to commit your changes before running `make style`, so you can revert the changes done by that script
+easily.
diff --git a/diffusers/docs/TRANSLATING.md b/diffusers/docs/TRANSLATING.md
new file mode 100644
index 0000000000000000000000000000000000000000..b64ac9fd8d688aa47bd83d3d08b9d70e27a2871f
--- /dev/null
+++ b/diffusers/docs/TRANSLATING.md
@@ -0,0 +1,69 @@
+
+
+### Translating the Diffusers documentation into your language
+
+As part of our mission to democratize machine learning, we'd love to make the Diffusers library available in many more languages! Follow the steps below if you want to help translate the documentation into your language 🙏.
+
+**🗞️ Open an issue**
+
+To get started, navigate to the [Issues](https://github.com/huggingface/diffusers/issues) page of this repo and check if anyone else has opened an issue for your language. If not, open a new issue by selecting the "🌐 Translating a New Language?" from the "New issue" button.
+
+Once an issue exists, post a comment to indicate which chapters you'd like to work on, and we'll add your name to the list.
+
+
+**🍴 Fork the repository**
+
+First, you'll need to [fork the Diffusers repo](https://docs.github.com/en/get-started/quickstart/fork-a-repo). You can do this by clicking on the **Fork** button on the top-right corner of this repo's page.
+
+Once you've forked the repo, you'll want to get the files on your local machine for editing. You can do that by cloning the fork with Git as follows:
+
+```bash
+git clone https://github.com/<YOUR-USERNAME>/diffusers.git
+```
+
+**📋 Copy-paste the English version with a new language code**
+
+The documentation files all live in one main directory:
+
+- [`docs/source`](https://github.com/huggingface/diffusers/tree/main/docs/source): All the documentation materials are organized here by language.
+
+You'll only need to copy the files in the [`docs/source/en`](https://github.com/huggingface/diffusers/tree/main/docs/source/en) directory, so first navigate to your fork of the repo and run the following:
+
+```bash
+cd ~/path/to/diffusers/docs
+cp -r source/en source/<LANG-ID>
+```
+
+Here, `<LANG-ID>` should be one of the ISO 639-1 or ISO 639-2 language codes -- see [here](https://www.loc.gov/standards/iso639-2/php/code_list.php) for a handy table.
+
+**✍️ Start translating**
+
+The fun part comes - translating the text!
+
+The first thing we recommend is translating the part of the `_toctree.yml` file that corresponds to your doc chapter. This file is used to render the table of contents on the website.
+
+> 🙋 If the `_toctree.yml` file doesn't yet exist for your language, you can create one by copy-pasting from the English version and deleting the sections unrelated to your chapter. Just make sure it exists in the `docs/source/<LANG-ID>/` directory!
+
+The fields you should add are `local` (with the name of the file containing the translation; e.g. `autoclass_tutorial`), and `title` (with the title of the doc in your language; e.g. `Load pretrained instances with an AutoClass`) -- as a reference, here is the `_toctree.yml` for [English](https://github.com/huggingface/diffusers/blob/main/docs/source/en/_toctree.yml):
+
+```yaml
+- sections:
+ - local: pipeline_tutorial # Do not change this! Use the same name for your .md file
+ title: Pipelines for inference # Translate this!
+ ...
+ title: Tutorials # Translate this!
+```
+
+Once you have translated the `_toctree.yml` file, you can start translating the [MDX](https://mdxjs.com/) files associated with your docs chapter.
+
+> 🙋 If you'd like others to help you with the translation, you should [open an issue](https://github.com/huggingface/diffusers/issues) and tag @patrickvonplaten.
diff --git a/diffusers/docs/source/_config.py b/diffusers/docs/source/_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d0d73dcb951ea5b8b91e255d79b893a2a103ed3
--- /dev/null
+++ b/diffusers/docs/source/_config.py
@@ -0,0 +1,9 @@
+# docstyle-ignore
+INSTALL_CONTENT = """
+# Diffusers installation
+! pip install diffusers transformers datasets accelerate
+# To install from source instead of the last release, comment the command above and uncomment the following one.
+# ! pip install git+https://github.com/huggingface/diffusers.git
+"""
+
+notebook_first_cells = [{"type": "code", "content": INSTALL_CONTENT}]
diff --git a/diffusers/docs/source/en/_toctree.yml b/diffusers/docs/source/en/_toctree.yml
new file mode 100644
index 0000000000000000000000000000000000000000..d2583121418ea054eaa61b76cf529ea6b1408d12
--- /dev/null
+++ b/diffusers/docs/source/en/_toctree.yml
@@ -0,0 +1,434 @@
+- sections:
+ - local: index
+ title: 🧨 Diffusers
+ - local: quicktour
+ title: Quicktour
+ - local: stable_diffusion
+ title: Effective and efficient diffusion
+ - local: installation
+ title: Installation
+ title: Get started
+- sections:
+ - local: tutorials/tutorial_overview
+ title: Overview
+ - local: using-diffusers/write_own_pipeline
+ title: Understanding pipelines, models and schedulers
+ - local: tutorials/autopipeline
+ title: AutoPipeline
+ - local: tutorials/basic_training
+ title: Train a diffusion model
+ - local: tutorials/using_peft_for_inference
+ title: Inference with PEFT
+ title: Tutorials
+- sections:
+ - sections:
+ - local: using-diffusers/loading_overview
+ title: Overview
+ - local: using-diffusers/loading
+ title: Load pipelines, models, and schedulers
+ - local: using-diffusers/schedulers
+ title: Load and compare different schedulers
+ - local: using-diffusers/custom_pipeline_overview
+ title: Load community pipelines and components
+ - local: using-diffusers/using_safetensors
+ title: Load safetensors
+ - local: using-diffusers/other-formats
+ title: Load different Stable Diffusion formats
+ - local: using-diffusers/loading_adapters
+ title: Load adapters
+ - local: using-diffusers/push_to_hub
+ title: Push files to the Hub
+ title: Loading & Hub
+ - sections:
+ - local: using-diffusers/pipeline_overview
+ title: Overview
+ - local: using-diffusers/unconditional_image_generation
+ title: Unconditional image generation
+ - local: using-diffusers/conditional_image_generation
+ title: Text-to-image
+ - local: using-diffusers/img2img
+ title: Image-to-image
+ - local: using-diffusers/inpaint
+ title: Inpainting
+ - local: using-diffusers/depth2img
+ title: Depth-to-image
+ title: Tasks
+ - sections:
+ - local: using-diffusers/textual_inversion_inference
+ title: Textual inversion
+ - local: training/distributed_inference
+ title: Distributed inference with multiple GPUs
+ - local: using-diffusers/reusing_seeds
+ title: Improve image quality with deterministic generation
+ - local: using-diffusers/control_brightness
+ title: Control image brightness
+ - local: using-diffusers/weighted_prompts
+ title: Prompt weighting
+ - local: using-diffusers/freeu
+ title: Improve generation quality with FreeU
+ title: Techniques
+ - sections:
+ - local: using-diffusers/pipeline_overview
+ title: Overview
+ - local: using-diffusers/sdxl
+ title: Stable Diffusion XL
+ - local: using-diffusers/kandinsky
+ title: Kandinsky
+ - local: using-diffusers/controlnet
+ title: ControlNet
+ - local: using-diffusers/shap-e
+ title: Shap-E
+ - local: using-diffusers/diffedit
+ title: DiffEdit
+ - local: using-diffusers/distilled_sd
+ title: Distilled Stable Diffusion inference
+ - local: using-diffusers/callback
+ title: Pipeline callbacks
+ - local: using-diffusers/reproducibility
+ title: Create reproducible pipelines
+ - local: using-diffusers/custom_pipeline_examples
+ title: Community pipelines
+ - local: using-diffusers/contribute_pipeline
+ title: Contribute a community pipeline
+ - local: using-diffusers/inference_with_lcm_lora
+ title: Latent Consistency Model-LoRA
+ - local: using-diffusers/inference_with_lcm
+ title: Latent Consistency Model
+ title: Specific pipeline examples
+ - sections:
+ - local: training/overview
+ title: Overview
+ - local: training/create_dataset
+ title: Create a dataset for training
+ - local: training/adapt_a_model
+ title: Adapt a model to a new task
+ - sections:
+ - local: training/unconditional_training
+ title: Unconditional image generation
+ - local: training/text2image
+ title: Text-to-image
+ - local: training/sdxl
+ title: Stable Diffusion XL
+ - local: training/kandinsky
+ title: Kandinsky 2.2
+ - local: training/wuerstchen
+ title: Wuerstchen
+ - local: training/controlnet
+ title: ControlNet
+ - local: training/t2i_adapters
+ title: T2I-Adapters
+ - local: training/instructpix2pix
+ title: InstructPix2Pix
+ title: Models
+ - sections:
+ - local: training/text_inversion
+ title: Textual Inversion
+ - local: training/dreambooth
+ title: DreamBooth
+ - local: training/lora
+ title: LoRA
+ - local: training/custom_diffusion
+ title: Custom Diffusion
+ - local: training/ddpo
+ title: Reinforcement learning training with DDPO
+ title: Methods
+ title: Training
+ - sections:
+ - local: using-diffusers/other-modalities
+ title: Other Modalities
+ title: Taking Diffusers Beyond Images
+ title: Using Diffusers
+- sections:
+ - local: optimization/opt_overview
+ title: Overview
+ - sections:
+ - local: optimization/fp16
+ title: Speed up inference
+ - local: optimization/memory
+ title: Reduce memory usage
+ - local: optimization/torch2.0
+ title: PyTorch 2.0
+ - local: optimization/xformers
+ title: xFormers
+ - local: optimization/tome
+ title: Token merging
+ title: General optimizations
+ - sections:
+ - local: using-diffusers/stable_diffusion_jax_how_to
+ title: JAX/Flax
+ - local: optimization/onnx
+ title: ONNX
+ - local: optimization/open_vino
+ title: OpenVINO
+ - local: optimization/coreml
+ title: Core ML
+ title: Optimized model types
+ - sections:
+ - local: optimization/mps
+ title: Metal Performance Shaders (MPS)
+ - local: optimization/habana
+ title: Habana Gaudi
+ title: Optimized hardware
+ title: Optimization
+- sections:
+ - local: conceptual/philosophy
+ title: Philosophy
+ - local: using-diffusers/controlling_generation
+ title: Controlled generation
+ - local: conceptual/contribution
+ title: How to contribute?
+ - local: conceptual/ethical_guidelines
+ title: Diffusers' Ethical Guidelines
+ - local: conceptual/evaluation
+ title: Evaluating Diffusion Models
+ title: Conceptual Guides
+- sections:
+ - sections:
+ - local: api/configuration
+ title: Configuration
+ - local: api/logging
+ title: Logging
+ - local: api/outputs
+ title: Outputs
+ title: Main Classes
+ - sections:
+ - local: api/loaders/lora
+ title: LoRA
+ - local: api/loaders/single_file
+ title: Single files
+ - local: api/loaders/textual_inversion
+ title: Textual Inversion
+ - local: api/loaders/unet
+ title: UNet
+ title: Loaders
+ - sections:
+ - local: api/models/overview
+ title: Overview
+ - local: api/models/unet
+ title: UNet1DModel
+ - local: api/models/unet2d
+ title: UNet2DModel
+ - local: api/models/unet2d-cond
+ title: UNet2DConditionModel
+ - local: api/models/unet3d-cond
+ title: UNet3DConditionModel
+ - local: api/models/unet-motion
+ title: UNetMotionModel
+ - local: api/models/vq
+ title: VQModel
+ - local: api/models/autoencoderkl
+ title: AutoencoderKL
+ - local: api/models/asymmetricautoencoderkl
+ title: AsymmetricAutoencoderKL
+ - local: api/models/autoencoder_tiny
+ title: Tiny AutoEncoder
+ - local: api/models/consistency_decoder_vae
+ title: ConsistencyDecoderVAE
+ - local: api/models/transformer2d
+ title: Transformer2D
+ - local: api/models/transformer_temporal
+ title: Transformer Temporal
+ - local: api/models/prior_transformer
+ title: Prior Transformer
+ - local: api/models/controlnet
+ title: ControlNet
+ title: Models
+ - sections:
+ - local: api/pipelines/overview
+ title: Overview
+ - local: api/pipelines/alt_diffusion
+ title: AltDiffusion
+ - local: api/pipelines/animatediff
+ title: AnimateDiff
+ - local: api/pipelines/attend_and_excite
+ title: Attend-and-Excite
+ - local: api/pipelines/audio_diffusion
+ title: Audio Diffusion
+ - local: api/pipelines/audioldm
+ title: AudioLDM
+ - local: api/pipelines/audioldm2
+ title: AudioLDM 2
+ - local: api/pipelines/auto_pipeline
+ title: AutoPipeline
+ - local: api/pipelines/blip_diffusion
+ title: BLIP-Diffusion
+ - local: api/pipelines/consistency_models
+ title: Consistency Models
+ - local: api/pipelines/controlnet
+ title: ControlNet
+ - local: api/pipelines/controlnet_sdxl
+ title: ControlNet with Stable Diffusion XL
+ - local: api/pipelines/cycle_diffusion
+ title: Cycle Diffusion
+ - local: api/pipelines/dance_diffusion
+ title: Dance Diffusion
+ - local: api/pipelines/ddim
+ title: DDIM
+ - local: api/pipelines/ddpm
+ title: DDPM
+ - local: api/pipelines/deepfloyd_if
+ title: DeepFloyd IF
+ - local: api/pipelines/diffedit
+ title: DiffEdit
+ - local: api/pipelines/dit
+ title: DiT
+ - local: api/pipelines/pix2pix
+ title: InstructPix2Pix
+ - local: api/pipelines/kandinsky
+ title: Kandinsky 2.1
+ - local: api/pipelines/kandinsky_v22
+ title: Kandinsky 2.2
+ - local: api/pipelines/latent_consistency_models
+ title: Latent Consistency Models
+ - local: api/pipelines/latent_diffusion
+ title: Latent Diffusion
+ - local: api/pipelines/panorama
+ title: MultiDiffusion
+ - local: api/pipelines/musicldm
+ title: MusicLDM
+ - local: api/pipelines/paint_by_example
+ title: Paint by Example
+ - local: api/pipelines/paradigms
+ title: Parallel Sampling of Diffusion Models
+ - local: api/pipelines/pix2pix_zero
+ title: Pix2Pix Zero
+ - local: api/pipelines/pixart
+ title: PixArt-α
+ - local: api/pipelines/pndm
+ title: PNDM
+ - local: api/pipelines/repaint
+ title: RePaint
+ - local: api/pipelines/score_sde_ve
+ title: Score SDE VE
+ - local: api/pipelines/self_attention_guidance
+ title: Self-Attention Guidance
+ - local: api/pipelines/semantic_stable_diffusion
+ title: Semantic Guidance
+ - local: api/pipelines/shap_e
+ title: Shap-E
+ - local: api/pipelines/spectrogram_diffusion
+ title: Spectrogram Diffusion
+ - sections:
+ - local: api/pipelines/stable_diffusion/overview
+ title: Overview
+ - local: api/pipelines/stable_diffusion/text2img
+ title: Text-to-image
+ - local: api/pipelines/stable_diffusion/img2img
+ title: Image-to-image
+ - local: api/pipelines/stable_diffusion/inpaint
+ title: Inpainting
+ - local: api/pipelines/stable_diffusion/depth2img
+ title: Depth-to-image
+ - local: api/pipelines/stable_diffusion/image_variation
+ title: Image variation
+ - local: api/pipelines/stable_diffusion/stable_diffusion_safe
+ title: Safe Stable Diffusion
+ - local: api/pipelines/stable_diffusion/stable_diffusion_2
+ title: Stable Diffusion 2
+ - local: api/pipelines/stable_diffusion/stable_diffusion_xl
+ title: Stable Diffusion XL
+ - local: api/pipelines/stable_diffusion/latent_upscale
+ title: Latent upscaler
+ - local: api/pipelines/stable_diffusion/upscale
+ title: Super-resolution
+ - local: api/pipelines/stable_diffusion/ldm3d_diffusion
+ title: LDM3D Text-to-(RGB, Depth)
+ - local: api/pipelines/stable_diffusion/adapter
+ title: Stable Diffusion T2I-Adapter
+ - local: api/pipelines/stable_diffusion/gligen
+ title: GLIGEN (Grounded Language-to-Image Generation)
+ title: Stable Diffusion
+ - local: api/pipelines/stable_unclip
+ title: Stable unCLIP
+ - local: api/pipelines/stochastic_karras_ve
+ title: Stochastic Karras VE
+ - local: api/pipelines/model_editing
+ title: Text-to-image model editing
+ - local: api/pipelines/text_to_video
+ title: Text-to-video
+ - local: api/pipelines/text_to_video_zero
+ title: Text2Video-Zero
+ - local: api/pipelines/unclip
+ title: unCLIP
+ - local: api/pipelines/latent_diffusion_uncond
+ title: Unconditional Latent Diffusion
+ - local: api/pipelines/unidiffuser
+ title: UniDiffuser
+ - local: api/pipelines/value_guided_sampling
+ title: Value-guided sampling
+ - local: api/pipelines/versatile_diffusion
+ title: Versatile Diffusion
+ - local: api/pipelines/vq_diffusion
+ title: VQ Diffusion
+ - local: api/pipelines/wuerstchen
+ title: Wuerstchen
+ title: Pipelines
+ - sections:
+ - local: api/schedulers/overview
+ title: Overview
+ - local: api/schedulers/cm_stochastic_iterative
+ title: CMStochasticIterativeScheduler
+ - local: api/schedulers/consistency_decoder
+ title: ConsistencyDecoderScheduler
+ - local: api/schedulers/ddim_inverse
+ title: DDIMInverseScheduler
+ - local: api/schedulers/ddim
+ title: DDIMScheduler
+ - local: api/schedulers/ddpm
+ title: DDPMScheduler
+ - local: api/schedulers/deis
+ title: DEISMultistepScheduler
+ - local: api/schedulers/multistep_dpm_solver_inverse
+ title: DPMSolverMultistepInverse
+ - local: api/schedulers/multistep_dpm_solver
+ title: DPMSolverMultistepScheduler
+ - local: api/schedulers/dpm_sde
+ title: DPMSolverSDEScheduler
+ - local: api/schedulers/singlestep_dpm_solver
+ title: DPMSolverSinglestepScheduler
+ - local: api/schedulers/euler_ancestral
+ title: EulerAncestralDiscreteScheduler
+ - local: api/schedulers/euler
+ title: EulerDiscreteScheduler
+ - local: api/schedulers/heun
+ title: HeunDiscreteScheduler
+ - local: api/schedulers/ipndm
+ title: IPNDMScheduler
+ - local: api/schedulers/stochastic_karras_ve
+ title: KarrasVeScheduler
+ - local: api/schedulers/dpm_discrete_ancestral
+ title: KDPM2AncestralDiscreteScheduler
+ - local: api/schedulers/dpm_discrete
+ title: KDPM2DiscreteScheduler
+ - local: api/schedulers/lcm
+ title: LCMScheduler
+ - local: api/schedulers/lms_discrete
+ title: LMSDiscreteScheduler
+ - local: api/schedulers/pndm
+ title: PNDMScheduler
+ - local: api/schedulers/repaint
+ title: RePaintScheduler
+ - local: api/schedulers/score_sde_ve
+ title: ScoreSdeVeScheduler
+ - local: api/schedulers/score_sde_vp
+ title: ScoreSdeVpScheduler
+ - local: api/schedulers/unipc
+ title: UniPCMultistepScheduler
+ - local: api/schedulers/vq_diffusion
+ title: VQDiffusionScheduler
+ title: Schedulers
+ - sections:
+ - local: api/internal_classes_overview
+ title: Overview
+ - local: api/attnprocessor
+ title: Attention Processor
+ - local: api/activations
+ title: Custom activation functions
+ - local: api/normalization
+ title: Custom normalization layers
+ - local: api/utilities
+ title: Utilities
+ - local: api/image_processor
+ title: VAE Image Processor
+ title: Internal classes
+ title: API
diff --git a/diffusers/docs/source/en/api/activations.md b/diffusers/docs/source/en/api/activations.md
new file mode 100644
index 0000000000000000000000000000000000000000..e4f4567caca0ad90e1de9cf8ccff60b9879b26f6
--- /dev/null
+++ b/diffusers/docs/source/en/api/activations.md
@@ -0,0 +1,27 @@
+
+
+# Activation functions
+
+Customized activation functions for supporting various models in 🤗 Diffusers.
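+
+As a quick, minimal sketch (the tensor shapes below are illustrative, not prescribed by the API), these modules can be used like any other `torch.nn.Module`:
+
+```python
+import torch
+from diffusers.models.activations import GEGLU
+
+# GEGLU projects to twice the output width and gates one half with GELU of the other
+geglu = GEGLU(dim_in=64, dim_out=128)
+hidden_states = torch.randn(2, 16, 64)
+out = geglu(hidden_states)  # -> shape (2, 16, 128)
+```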
+
+## GELU
+
+[[autodoc]] models.activations.GELU
+
+## GEGLU
+
+[[autodoc]] models.activations.GEGLU
+
+## ApproximateGELU
+
+[[autodoc]] models.activations.ApproximateGELU
diff --git a/diffusers/docs/source/en/api/attnprocessor.md b/diffusers/docs/source/en/api/attnprocessor.md
new file mode 100644
index 0000000000000000000000000000000000000000..f6ee09f124be92a0e6e66c1e50cef374aa22690a
--- /dev/null
+++ b/diffusers/docs/source/en/api/attnprocessor.md
@@ -0,0 +1,57 @@
+
+
+# Attention Processor
+
+An attention processor is a class for applying different types of attention mechanisms.
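+
+For example, a processor can be set on every attention layer of a model with `set_attn_processor` (a minimal sketch; the checkpoint name is just the Stable Diffusion v1-5 repository used elsewhere in these docs):
+
+```python
+from diffusers import UNet2DConditionModel
+from diffusers.models.attention_processor import AttnProcessor2_0
+
+unet = UNet2DConditionModel.from_pretrained(
+    "runwayml/stable-diffusion-v1-5", subfolder="unet"
+)
+# switch every attention layer to the PyTorch 2.0 scaled-dot-product-attention processor
+unet.set_attn_processor(AttnProcessor2_0())
+```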
+
+## AttnProcessor
+[[autodoc]] models.attention_processor.AttnProcessor
+
+## AttnProcessor2_0
+[[autodoc]] models.attention_processor.AttnProcessor2_0
+
+## LoRAAttnProcessor
+[[autodoc]] models.attention_processor.LoRAAttnProcessor
+
+## LoRAAttnProcessor2_0
+[[autodoc]] models.attention_processor.LoRAAttnProcessor2_0
+
+## CustomDiffusionAttnProcessor
+[[autodoc]] models.attention_processor.CustomDiffusionAttnProcessor
+
+## CustomDiffusionAttnProcessor2_0
+[[autodoc]] models.attention_processor.CustomDiffusionAttnProcessor2_0
+
+## AttnAddedKVProcessor
+[[autodoc]] models.attention_processor.AttnAddedKVProcessor
+
+## AttnAddedKVProcessor2_0
+[[autodoc]] models.attention_processor.AttnAddedKVProcessor2_0
+
+## LoRAAttnAddedKVProcessor
+[[autodoc]] models.attention_processor.LoRAAttnAddedKVProcessor
+
+## XFormersAttnProcessor
+[[autodoc]] models.attention_processor.XFormersAttnProcessor
+
+## LoRAXFormersAttnProcessor
+[[autodoc]] models.attention_processor.LoRAXFormersAttnProcessor
+
+## CustomDiffusionXFormersAttnProcessor
+[[autodoc]] models.attention_processor.CustomDiffusionXFormersAttnProcessor
+
+## SlicedAttnProcessor
+[[autodoc]] models.attention_processor.SlicedAttnProcessor
+
+## SlicedAttnAddedKVProcessor
+[[autodoc]] models.attention_processor.SlicedAttnAddedKVProcessor
diff --git a/diffusers/docs/source/en/api/configuration.md b/diffusers/docs/source/en/api/configuration.md
new file mode 100644
index 0000000000000000000000000000000000000000..a10e348acdefedafd67e670f05413fe845b78c20
--- /dev/null
+++ b/diffusers/docs/source/en/api/configuration.md
@@ -0,0 +1,30 @@
+
+
+# Configuration
+
+Schedulers from [`~schedulers.scheduling_utils.SchedulerMixin`] and models from [`ModelMixin`] inherit from [`ConfigMixin`] which stores all the parameters that are passed to their respective `__init__` methods in a JSON-configuration file.
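+
+For instance, a scheduler's stored configuration can be loaded, used to instantiate the class, and saved again (a minimal sketch; the checkpoint and output directory are illustrative):
+
+```python
+from diffusers import DDPMScheduler
+
+# load_config returns the JSON configuration that was saved with the scheduler
+config = DDPMScheduler.load_config("google/ddpm-cat-256")
+
+# from_config instantiates the scheduler from that configuration
+scheduler = DDPMScheduler.from_config(config)
+
+# save_config writes the configuration JSON to the given directory
+scheduler.save_config("./my-scheduler")
+```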
+
+
+
+To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with `huggingface-cli login`.
+
+
+
+## ConfigMixin
+
+[[autodoc]] ConfigMixin
+ - load_config
+ - from_config
+ - save_config
+ - to_json_file
+ - to_json_string
diff --git a/diffusers/docs/source/en/api/image_processor.md b/diffusers/docs/source/en/api/image_processor.md
new file mode 100644
index 0000000000000000000000000000000000000000..fb446c944c3a062a342af5f65471e29160ccd27d
--- /dev/null
+++ b/diffusers/docs/source/en/api/image_processor.md
@@ -0,0 +1,27 @@
+
+
+# VAE Image Processor
+
+The [`VaeImageProcessor`] provides a unified API for [`StableDiffusionPipeline`]s to prepare image inputs for VAE encoding and post-processing outputs once they're decoded. This includes transformations such as resizing, normalization, and conversion between PIL Image, PyTorch, and NumPy arrays.
+
+All pipelines with [`VaeImageProcessor`] accept PIL Image, PyTorch tensor, or NumPy arrays as image inputs and return outputs based on the `output_type` argument by the user. You can pass encoded image latents directly to the pipeline and return latents from the pipeline as a specific output with the `output_type` argument (for example `output_type="latent"`). This allows you to take the generated latents from one pipeline and pass it to another pipeline as input without leaving the latent space. It also makes it much easier to use multiple pipelines together by passing PyTorch tensors directly between different pipelines.
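+
+For instance, the Stable Diffusion XL base and refiner pipelines can exchange latents this way (a minimal sketch; the checkpoints are the public SDXL base and refiner weights on the Hub, not something defined in this document):
+
+```python
+import torch
+from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
+
+base = StableDiffusionXLPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+).to("cuda")
+refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16
+).to("cuda")
+
+prompt = "a majestic lion jumping from a big stone at night"
+
+# output_type="latent" keeps the result in latent space instead of decoding to PIL images
+latents = base(prompt=prompt, output_type="latent").images
+# the refiner consumes those latents directly as its image input
+image = refiner(prompt=prompt, image=latents).images[0]
+```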
+
+## VaeImageProcessor
+
+[[autodoc]] image_processor.VaeImageProcessor
+
+## VaeImageProcessorLDM3D
+
+The [`VaeImageProcessorLDM3D`] accepts RGB and depth inputs and returns RGB and depth outputs.
+
+[[autodoc]] image_processor.VaeImageProcessorLDM3D
diff --git a/diffusers/docs/source/en/api/internal_classes_overview.md b/diffusers/docs/source/en/api/internal_classes_overview.md
new file mode 100644
index 0000000000000000000000000000000000000000..5c8d2cc0e38741ad11f44248aa5165898b92032a
--- /dev/null
+++ b/diffusers/docs/source/en/api/internal_classes_overview.md
@@ -0,0 +1,15 @@
+
+
+# Overview
+
+The APIs in this section are more experimental and prone to breaking changes. Most of them are used internally for development, but they may also be useful to you if you're interested in building a diffusion model with some custom parts or if you're interested in some of our helper utilities for working with 🤗 Diffusers.
diff --git a/diffusers/docs/source/en/api/loaders/lora.md b/diffusers/docs/source/en/api/loaders/lora.md
new file mode 100644
index 0000000000000000000000000000000000000000..05ff11afc5d41b7719cca292e873a112686d0e9c
--- /dev/null
+++ b/diffusers/docs/source/en/api/loaders/lora.md
@@ -0,0 +1,32 @@
+
+
+# LoRA
+
+LoRA is a fast and lightweight training method that inserts and trains a significantly smaller number of parameters instead of all the model parameters. This produces a smaller file (~100 MBs) and makes it easier to quickly train a model to learn a new concept. LoRA weights are typically loaded into the UNet, text encoder or both. There are two classes for loading LoRA weights:
+
+- [`LoraLoaderMixin`] provides functions for loading and unloading, fusing and unfusing, enabling and disabling, and more functions for managing LoRA weights. This class can be used with any model.
+- [`StableDiffusionXLLoraLoaderMixin`] is a [Stable Diffusion (SDXL)](../../api/pipelines/stable_diffusion/stable_diffusion_xl) version of the [`LoraLoaderMixin`] class for loading and saving LoRA weights. It can only be used with the SDXL model.
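+
+A minimal sketch of the typical workflow (the LoRA path below is a placeholder for a local folder or Hub repository containing weights saved as `pytorch_lora_weights.safetensors`):
+
+```python
+import torch
+from diffusers import StableDiffusionPipeline
+
+pipe = StableDiffusionPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
+).to("cuda")
+
+pipe.load_lora_weights("path/to/lora")  # placeholder repo or folder
+pipe.fuse_lora()    # optionally fuse the LoRA weights into the base weights
+pipe.unfuse_lora()  # ...and undo the fusion again
+```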
+
+
+
+To learn more about how to load LoRA weights, see the [LoRA](../../using-diffusers/loading_adapters#lora) loading guide.
+
+
+
+## LoraLoaderMixin
+
+[[autodoc]] loaders.lora.LoraLoaderMixin
+
+## StableDiffusionXLLoraLoaderMixin
+
+[[autodoc]] loaders.lora.StableDiffusionXLLoraLoaderMixin
\ No newline at end of file
diff --git a/diffusers/docs/source/en/api/loaders/single_file.md b/diffusers/docs/source/en/api/loaders/single_file.md
new file mode 100644
index 0000000000000000000000000000000000000000..52e44606455bbdc99d31c851d55495cc62ea1246
--- /dev/null
+++ b/diffusers/docs/source/en/api/loaders/single_file.md
@@ -0,0 +1,37 @@
+
+
+# Single files
+
+Diffusers supports loading pretrained pipeline (or model) weights stored in a single file, such as a `ckpt` or `safetensors` file. These single file types are typically produced from community trained models. There are three classes for loading single file weights:
+
+- [`FromSingleFileMixin`] supports loading pretrained pipeline weights stored in a single file, which can either be a `ckpt` or `safetensors` file.
+- [`FromOriginalVAEMixin`] supports loading a pretrained [`AutoencoderKL`] from pretrained ControlNet weights stored in a single file, which can either be a `ckpt` or `safetensors` file.
+- [`FromOriginalControlnetMixin`] supports loading pretrained ControlNet weights stored in a single file, which can either be a `ckpt` or `safetensors` file.
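+
+A minimal sketch of loading an original-format checkpoint into a pipeline (the path below is a placeholder for a local `ckpt`/`safetensors` file or a Hub file URL):
+
+```python
+from diffusers import StableDiffusionPipeline
+
+pipe = StableDiffusionPipeline.from_single_file("path/to/model.safetensors")  # placeholder path
+```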
+
+
+
+To learn more about how to load single file weights, see the [Load different Stable Diffusion formats](../../using-diffusers/other-formats) loading guide.
+
+
+
+## FromSingleFileMixin
+
+[[autodoc]] loaders.single_file.FromSingleFileMixin
+
+## FromOriginalVAEMixin
+
+[[autodoc]] loaders.single_file.FromOriginalVAEMixin
+
+## FromOriginalControlnetMixin
+
+[[autodoc]] loaders.single_file.FromOriginalControlnetMixin
\ No newline at end of file
diff --git a/diffusers/docs/source/en/api/loaders/textual_inversion.md b/diffusers/docs/source/en/api/loaders/textual_inversion.md
new file mode 100644
index 0000000000000000000000000000000000000000..28d38ddb5bf2896a9e6b50f82a0956bb04f5c89c
--- /dev/null
+++ b/diffusers/docs/source/en/api/loaders/textual_inversion.md
@@ -0,0 +1,27 @@
+
+
+# Textual Inversion
+
+Textual Inversion is a training method for personalizing models by learning new text embeddings from a few example images. The file produced from training is extremely small (a few KBs) and the new embeddings can be loaded into the text encoder.
+
+[`TextualInversionLoaderMixin`] provides a function for loading Textual Inversion embeddings from Diffusers and Automatic1111 into the text encoder and loading a special token to activate the embeddings.
+
+
+
+To learn more about how to load Textual Inversion embeddings, see the [Textual Inversion](../../using-diffusers/loading_adapters#textual-inversion) loading guide.
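+
+As a minimal sketch, an embedding can be loaded with [`~TextualInversionLoaderMixin.load_textual_inversion`] and activated through its token (the `sd-concepts-library/cat-toy` embedding is used as an example):
+
+```python
+import torch
+from diffusers import StableDiffusionPipeline
+
+pipe = StableDiffusionPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
+).to("cuda")
+
+# download and register the learned embedding plus its activation token <cat-toy>
+pipe.load_textual_inversion("sd-concepts-library/cat-toy")
+
+image = pipe("A <cat-toy> train", num_inference_steps=50).images[0]
+```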
+
+
+
+## TextualInversionLoaderMixin
+
+[[autodoc]] loaders.textual_inversion.TextualInversionLoaderMixin
\ No newline at end of file
diff --git a/diffusers/docs/source/en/api/loaders/unet.md b/diffusers/docs/source/en/api/loaders/unet.md
new file mode 100644
index 0000000000000000000000000000000000000000..df896a065eb3221d79aea9da06d6a72e26e92c6e
--- /dev/null
+++ b/diffusers/docs/source/en/api/loaders/unet.md
@@ -0,0 +1,27 @@
+
+
+# UNet
+
+Some training methods - like LoRA and Custom Diffusion - typically target the UNet's attention layers, but these training methods can also target other non-attention layers. Instead of training all of a model's parameters, only a subset of the parameters is trained, which is faster and more efficient. This class is useful if you're *only* loading weights into a UNet. If you need to load weights into the text encoder, or into both the text encoder and the UNet, use the [`~loaders.LoraLoaderMixin.load_lora_weights`] function instead.
+
+The [`UNet2DConditionLoadersMixin`] class provides functions for loading and saving weights, fusing and unfusing LoRAs, disabling and enabling LoRAs, and setting and deleting adapters.
+
+
+
+To learn more about how to load LoRA weights, see the [LoRA](../../using-diffusers/loading_adapters#lora) loading guide.
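+
+As a minimal sketch, LoRA weights that only target the UNet can be loaded with [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`] (the repository id below is a placeholder):
+
+```python
+from diffusers import StableDiffusionPipeline
+
+pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+
+# load attention processor (e.g. LoRA) weights into the UNet only;
+# "path/to/unet-lora" is a placeholder for a Hub id or local directory
+pipe.unet.load_attn_procs("path/to/unet-lora", weight_name="pytorch_lora_weights.safetensors")
+```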
+
+
+
+## UNet2DConditionLoadersMixin
+
+[[autodoc]] loaders.unet.UNet2DConditionLoadersMixin
\ No newline at end of file
diff --git a/diffusers/docs/source/en/api/logging.md b/diffusers/docs/source/en/api/logging.md
new file mode 100644
index 0000000000000000000000000000000000000000..b31b7c11755ee8942f26e4916c33e6eb6cc794d6
--- /dev/null
+++ b/diffusers/docs/source/en/api/logging.md
@@ -0,0 +1,96 @@
+
+
+# Logging
+
+🤗 Diffusers has a centralized logging system to easily manage the verbosity of the library. The default verbosity is set to `WARNING`.
+
+To change the verbosity level, use one of the direct setters. For instance, to change the verbosity to the `INFO` level:
+
+```python
+import diffusers
+
+diffusers.logging.set_verbosity_info()
+```
+
+You can also use the environment variable `DIFFUSERS_VERBOSITY` to override the default verbosity. You can set it
+to one of the following: `debug`, `info`, `warning`, `error`, `critical`. For example:
+
+```bash
+DIFFUSERS_VERBOSITY=error ./myprogram.py
+```
+
+Additionally, some `warnings` can be disabled by setting the environment variable
+`DIFFUSERS_NO_ADVISORY_WARNINGS` to a true value, like `1`. This disables any warning logged by
+[`logger.warning_advice`]. For example:
+
+```bash
+DIFFUSERS_NO_ADVISORY_WARNINGS=1 ./myprogram.py
+```
+
+Here is an example of how to use the same logger as the library in your own module or script:
+
+```python
+from diffusers.utils import logging
+
+logging.set_verbosity_info()
+logger = logging.get_logger("diffusers")
+logger.info("INFO")
+logger.warning("WARN")
+```
+
+
+All methods of the logging module are documented below. The main methods are
+[`logging.get_verbosity`] to get the current level of verbosity in the logger and
+[`logging.set_verbosity`] to set the verbosity to the level of your choice.
+
+In order from the least verbose to the most verbose:
+
+| Method | Integer value | Description |
+|----------------------------------------------------------:|--------------:|----------------------------------------------------:|
+| `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL` | 50 | only report the most critical errors |
+| `diffusers.logging.ERROR` | 40 | only report errors |
+| `diffusers.logging.WARNING` or `diffusers.logging.WARN` | 30 | only report errors and warnings (default) |
+| `diffusers.logging.INFO` | 20 | only report errors, warnings, and basic information |
+| `diffusers.logging.DEBUG` | 10 | report all information |
+
+By default, `tqdm` progress bars are displayed during model download. [`logging.disable_progress_bar`] and [`logging.enable_progress_bar`] are used to enable or disable this behavior.
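+
+For example:
+
+```python
+from diffusers.utils import logging
+
+# hide the tqdm progress bars shown while downloading model weights
+logging.disable_progress_bar()
+
+# ... and turn them back on later
+logging.enable_progress_bar()
+```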
+
+## Base setters
+
+[[autodoc]] utils.logging.set_verbosity_error
+
+[[autodoc]] utils.logging.set_verbosity_warning
+
+[[autodoc]] utils.logging.set_verbosity_info
+
+[[autodoc]] utils.logging.set_verbosity_debug
+
+## Other functions
+
+[[autodoc]] utils.logging.get_verbosity
+
+[[autodoc]] utils.logging.set_verbosity
+
+[[autodoc]] utils.logging.get_logger
+
+[[autodoc]] utils.logging.enable_default_handler
+
+[[autodoc]] utils.logging.disable_default_handler
+
+[[autodoc]] utils.logging.enable_explicit_format
+
+[[autodoc]] utils.logging.reset_format
+
+[[autodoc]] utils.logging.enable_progress_bar
+
+[[autodoc]] utils.logging.disable_progress_bar
diff --git a/diffusers/docs/source/en/api/models/asymmetricautoencoderkl.md b/diffusers/docs/source/en/api/models/asymmetricautoencoderkl.md
new file mode 100644
index 0000000000000000000000000000000000000000..1e102943c5e4ce71876c16c7958a649f0186a55c
--- /dev/null
+++ b/diffusers/docs/source/en/api/models/asymmetricautoencoderkl.md
@@ -0,0 +1,60 @@
+
+
+# AsymmetricAutoencoderKL
+
+Improved larger variational autoencoder (VAE) model with KL loss for the inpainting task: [Designing a Better Asymmetric VQGAN for StableDiffusion](https://arxiv.org/abs/2306.04632) by Zixin Zhu, Xuelu Feng, Dongdong Chen, Jianmin Bao, Le Wang, Yinpeng Chen, Lu Yuan, Gang Hua.
+
+The abstract from the paper is:
+
+*StableDiffusion is a revolutionary text-to-image generator that is causing a stir in the world of image generation and editing. Unlike traditional methods that learn a diffusion model in pixel space, StableDiffusion learns a diffusion model in the latent space via a VQGAN, ensuring both efficiency and quality. It not only supports image generation tasks, but also enables image editing for real images, such as image inpainting and local editing. However, we have observed that the vanilla VQGAN used in StableDiffusion leads to significant information loss, causing distortion artifacts even in non-edited image regions. To this end, we propose a new asymmetric VQGAN with two simple designs. Firstly, in addition to the input from the encoder, the decoder contains a conditional branch that incorporates information from task-specific priors, such as the unmasked image region in inpainting. Secondly, the decoder is much heavier than the encoder, allowing for more detailed recovery while only slightly increasing the total inference cost. The training cost of our asymmetric VQGAN is cheap, and we only need to retrain a new asymmetric decoder while keeping the vanilla VQGAN encoder and StableDiffusion unchanged. Our asymmetric VQGAN can be widely used in StableDiffusion-based inpainting and local editing methods. Extensive experiments demonstrate that it can significantly improve the inpainting and editing performance, while maintaining the original text-to-image capability. The code is available at https://github.com/buxiangzhiren/Asymmetric_VQGAN*
+
+Evaluation results can be found in section 4.1 of the original paper.
+
+## Available checkpoints
+
+* [https://huggingface.co/cross-attention/asymmetric-autoencoder-kl-x-1-5](https://huggingface.co/cross-attention/asymmetric-autoencoder-kl-x-1-5)
+* [https://huggingface.co/cross-attention/asymmetric-autoencoder-kl-x-2](https://huggingface.co/cross-attention/asymmetric-autoencoder-kl-x-2)
+
+## Example Usage
+
+```python
+from diffusers import AsymmetricAutoencoderKL, StableDiffusionInpaintPipeline
+from diffusers.utils import load_image, make_image_grid
+
+
+prompt = "a photo of a person with beard"
+img_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/celeba_hq_256.png"
+mask_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/mask_256.png"
+
+original_image = load_image(img_url).resize((512, 512))
+mask_image = load_image(mask_url).resize((512, 512))
+
+pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
+pipe.vae = AsymmetricAutoencoderKL.from_pretrained("cross-attention/asymmetric-autoencoder-kl-x-1-5")
+pipe.to("cuda")
+
+image = pipe(prompt=prompt, image=original_image, mask_image=mask_image).images[0]
+make_image_grid([original_image, mask_image, image], rows=1, cols=3)
+```
+
+## AsymmetricAutoencoderKL
+
+[[autodoc]] models.autoencoder_asym_kl.AsymmetricAutoencoderKL
+
+## AutoencoderKLOutput
+
+[[autodoc]] models.autoencoder_kl.AutoencoderKLOutput
+
+## DecoderOutput
+
+[[autodoc]] models.vae.DecoderOutput
diff --git a/diffusers/docs/source/en/api/models/autoencoder_tiny.md b/diffusers/docs/source/en/api/models/autoencoder_tiny.md
new file mode 100644
index 0000000000000000000000000000000000000000..1d19539bffe88a7111d2a0db30181644b65b7c2c
--- /dev/null
+++ b/diffusers/docs/source/en/api/models/autoencoder_tiny.md
@@ -0,0 +1,57 @@
+
+
+# Tiny AutoEncoder
+
+Tiny AutoEncoder for Stable Diffusion (TAESD) was introduced in [madebyollin/taesd](https://github.com/madebyollin/taesd) by Ollin Boer Bohan. It is a tiny distilled version of Stable Diffusion's VAE that can decode the latents in a [`StableDiffusionPipeline`] or [`StableDiffusionXLPipeline`] almost instantly.
+
+To use it with Stable Diffusion v2.1:
+
+```python
+import torch
+from diffusers import DiffusionPipeline, AutoencoderTiny
+
+pipe = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-2-1-base", torch_dtype=torch.float16
+)
+pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=torch.float16)
+pipe = pipe.to("cuda")
+
+prompt = "slice of delicious New York-style berry cheesecake"
+image = pipe(prompt, num_inference_steps=25).images[0]
+image
+```
+
+To use it with Stable Diffusion XL 1.0:
+
+```python
+import torch
+from diffusers import DiffusionPipeline, AutoencoderTiny
+
+pipe = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+)
+pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.float16)
+pipe = pipe.to("cuda")
+
+prompt = "slice of delicious New York-style berry cheesecake"
+image = pipe(prompt, num_inference_steps=25).images[0]
+image
+```
+
+## AutoencoderTiny
+
+[[autodoc]] AutoencoderTiny
+
+## AutoencoderTinyOutput
+
+[[autodoc]] models.autoencoder_tiny.AutoencoderTinyOutput
diff --git a/diffusers/docs/source/en/api/models/autoencoderkl.md b/diffusers/docs/source/en/api/models/autoencoderkl.md
new file mode 100644
index 0000000000000000000000000000000000000000..f42a4d2941ddeaf04dc5975fd73104132877ca1f
--- /dev/null
+++ b/diffusers/docs/source/en/api/models/autoencoderkl.md
@@ -0,0 +1,55 @@
+
+
+# AutoencoderKL
+
+The variational autoencoder (VAE) model with KL loss was introduced in [Auto-Encoding Variational Bayes](https://arxiv.org/abs/1312.6114v11) by Diederik P. Kingma and Max Welling. The model is used in 🤗 Diffusers to encode images into latents and to decode latent representations into images.
+
+The abstract from the paper is:
+
+*How can we perform efficient inference and learning in directed probabilistic models, in the presence of continuous latent variables with intractable posterior distributions, and large datasets? We introduce a stochastic variational inference and learning algorithm that scales to large datasets and, under some mild differentiability conditions, even works in the intractable case. Our contributions are two-fold. First, we show that a reparameterization of the variational lower bound yields a lower bound estimator that can be straightforwardly optimized using standard stochastic gradient methods. Second, we show that for i.i.d. datasets with continuous latent variables per datapoint, posterior inference can be made especially efficient by fitting an approximate inference model (also called a recognition model) to the intractable posterior using the proposed lower bound estimator. Theoretical advantages are reflected in experimental results.*
+
+## Loading from the original format
+
+By default the [`AutoencoderKL`] should be loaded with [`~ModelMixin.from_pretrained`], but it can also be loaded
+from the original format using [`FromOriginalVAEMixin.from_single_file`] as follows:
+
+```py
+from diffusers import AutoencoderKL
+
+url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors" # can also be a local file
+model = AutoencoderKL.from_single_file(url)
+```
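+
+A minimal sketch of round-tripping a tensor through the VAE (the shapes and the scaling by `vae.config.scaling_factor` follow the Stable Diffusion convention):
+
+```python
+import torch
+from diffusers import AutoencoderKL
+
+vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
+
+# stand-in for a preprocessed image batch scaled to [-1, 1]
+image = torch.randn(1, 3, 512, 512)
+
+with torch.no_grad():
+    latents = vae.encode(image).latent_dist.sample() * vae.config.scaling_factor
+    reconstruction = vae.decode(latents / vae.config.scaling_factor).sample
+```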
+
+## AutoencoderKL
+
+[[autodoc]] AutoencoderKL
+
+## AutoencoderKLOutput
+
+[[autodoc]] models.autoencoder_kl.AutoencoderKLOutput
+
+## DecoderOutput
+
+[[autodoc]] models.vae.DecoderOutput
+
+## FlaxAutoencoderKL
+
+[[autodoc]] FlaxAutoencoderKL
+
+## FlaxAutoencoderKLOutput
+
+[[autodoc]] models.vae_flax.FlaxAutoencoderKLOutput
+
+## FlaxDecoderOutput
+
+[[autodoc]] models.vae_flax.FlaxDecoderOutput
diff --git a/diffusers/docs/source/en/api/models/consistency_decoder_vae.md b/diffusers/docs/source/en/api/models/consistency_decoder_vae.md
new file mode 100644
index 0000000000000000000000000000000000000000..b45f7fa059dc9339bc815e99411ab5d01579c954
--- /dev/null
+++ b/diffusers/docs/source/en/api/models/consistency_decoder_vae.md
@@ -0,0 +1,18 @@
+# Consistency Decoder
+
+The consistency decoder can be used to decode the latents from the denoising UNet in the [`StableDiffusionPipeline`]. This decoder was introduced in the [DALL-E 3 technical report](https://openai.com/dall-e-3).
+
+The original codebase can be found at [openai/consistencydecoder](https://github.com/openai/consistencydecoder).
+
+
+
+Inference is only supported for 2 iterations as of now.
+
+
+
+The pipeline could not have been contributed without the help of [madebyollin](https://github.com/madebyollin) and [mrsteyk](https://github.com/mrsteyk) from [this issue](https://github.com/openai/consistencydecoder/issues/1).
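+
+A minimal sketch of plugging the consistency decoder into a Stable Diffusion pipeline, using the [`openai/consistency-decoder`](https://huggingface.co/openai/consistency-decoder) checkpoint:
+
+```python
+import torch
+from diffusers import StableDiffusionPipeline, ConsistencyDecoderVAE
+
+vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
+pipe = StableDiffusionPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5", vae=vae, torch_dtype=torch.float16
+).to("cuda")
+
+image = pipe("horse", generator=torch.manual_seed(0)).images[0]
+```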
+
+## ConsistencyDecoderVAE
+[[autodoc]] ConsistencyDecoderVAE
+ - all
+ - decode
diff --git a/diffusers/docs/source/en/api/models/controlnet.md b/diffusers/docs/source/en/api/models/controlnet.md
new file mode 100644
index 0000000000000000000000000000000000000000..12bc0110f208d60320451351fd1042b8f2a2d515
--- /dev/null
+++ b/diffusers/docs/source/en/api/models/controlnet.md
@@ -0,0 +1,50 @@
+
+
+# ControlNet
+
+The ControlNet model was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, Maneesh Agrawala. It provides a greater degree of control over text-to-image generation by conditioning the model on additional inputs such as edge maps, depth maps, segmentation maps, and keypoints for pose detection.
+
+The abstract from the paper is:
+
+*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with "zero convolutions" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*
+
+## Loading from the original format
+
+By default the [`ControlNetModel`] should be loaded with [`~ModelMixin.from_pretrained`], but it can also be loaded
+from the original format using [`FromOriginalControlnetMixin.from_single_file`] as follows:
+
+```py
+from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
+
+url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" # can also be a local path
+controlnet = ControlNetModel.from_single_file(url)
+
+url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors" # can also be a local path
+pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=controlnet)
+```
+
+## ControlNetModel
+
+[[autodoc]] ControlNetModel
+
+## ControlNetOutput
+
+[[autodoc]] models.controlnet.ControlNetOutput
+
+## FlaxControlNetModel
+
+[[autodoc]] FlaxControlNetModel
+
+## FlaxControlNetOutput
+
+[[autodoc]] models.controlnet_flax.FlaxControlNetOutput
diff --git a/diffusers/docs/source/en/api/models/overview.md b/diffusers/docs/source/en/api/models/overview.md
new file mode 100644
index 0000000000000000000000000000000000000000..ab8d9d4e78397be9a93bfd3376d21cd141b4e47e
--- /dev/null
+++ b/diffusers/docs/source/en/api/models/overview.md
@@ -0,0 +1,28 @@
+
+
+# Models
+
+🤗 Diffusers provides pretrained models for popular algorithms and modules to create custom diffusion systems. The primary function of models is to denoise an input sample as modeled by the distribution \\(p_{\theta}(x_{t-1}|x_{t})\\).
+
+All models are built from the base [`ModelMixin`] class which is a [`torch.nn.Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html) providing basic functionality for saving and loading models, locally and from the Hugging Face Hub.
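+
+For instance, a minimal sketch of loading a pretrained model from the Hub and saving it locally (the UNet of Stable Diffusion v1-5 is used as an example):
+
+```python
+from diffusers import UNet2DConditionModel
+
+# load a single component of a pipeline repository via its subfolder
+unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
+
+# save the weights and config locally; load them back with from_pretrained("./sd15-unet")
+unet.save_pretrained("./sd15-unet")
+```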
+
+## ModelMixin
+[[autodoc]] ModelMixin
+
+## FlaxModelMixin
+
+[[autodoc]] FlaxModelMixin
+
+## PushToHubMixin
+
+[[autodoc]] utils.PushToHubMixin
diff --git a/diffusers/docs/source/en/api/models/prior_transformer.md b/diffusers/docs/source/en/api/models/prior_transformer.md
new file mode 100644
index 0000000000000000000000000000000000000000..0b849c300662253ef759f3e7fd99bea5d8b41d98
--- /dev/null
+++ b/diffusers/docs/source/en/api/models/prior_transformer.md
@@ -0,0 +1,27 @@
+
+
+# Prior Transformer
+
+The Prior Transformer was originally introduced in [Hierarchical Text-Conditional Image Generation with CLIP Latents](https://huggingface.co/papers/2204.06125) by Ramesh et al. It is used to predict CLIP image embeddings from CLIP text embeddings; image embeddings are predicted through a denoising diffusion process.
+
+The abstract from the paper is:
+
+*Contrastive models like CLIP have been shown to learn robust representations of images that capture both semantics and style. To leverage these representations for image generation, we propose a two-stage model: a prior that generates a CLIP image embedding given a text caption, and a decoder that generates an image conditioned on the image embedding. We show that explicitly generating image representations improves image diversity with minimal loss in photorealism and caption similarity. Our decoders conditioned on image representations can also produce variations of an image that preserve both its semantics and style, while varying the non-essential details absent from the image representation. Moreover, the joint embedding space of CLIP enables language-guided image manipulations in a zero-shot fashion. We use diffusion models for the decoder and experiment with both autoregressive and diffusion models for the prior, finding that the latter are computationally more efficient and produce higher-quality samples.*
+
+## PriorTransformer
+
+[[autodoc]] PriorTransformer
+
+## PriorTransformerOutput
+
+[[autodoc]] models.prior_transformer.PriorTransformerOutput
diff --git a/diffusers/docs/source/en/api/models/transformer2d.md b/diffusers/docs/source/en/api/models/transformer2d.md
new file mode 100644
index 0000000000000000000000000000000000000000..0f891edd754a551c2fde843c936d4e40486d9e52
--- /dev/null
+++ b/diffusers/docs/source/en/api/models/transformer2d.md
@@ -0,0 +1,41 @@
+
+
+# Transformer2D
+
+A Transformer model for image-like data from [CompVis](https://huggingface.co/CompVis) that is based on the [Vision Transformer](https://huggingface.co/papers/2010.11929) introduced by Dosovitskiy et al. The [`Transformer2DModel`] accepts discrete (classes of vector embeddings) or continuous (actual embeddings) inputs.
+
+When the input is **continuous**:
+
+1. Project the input and reshape it to `(batch_size, sequence_length, feature_dimension)`.
+2. Apply the Transformer blocks in the standard way.
+3. Reshape to image.
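+
+A minimal sketch of this continuous path (the layer sizes below are arbitrary and only meant to illustrate the input/output shapes):
+
+```python
+import torch
+from diffusers.models import Transformer2DModel
+
+model = Transformer2DModel(
+    num_attention_heads=2,
+    attention_head_dim=32,
+    in_channels=32,
+)
+
+# (batch, channels, height, width) feature map, e.g. from a UNet block
+hidden_states = torch.randn(1, 32, 16, 16)
+
+with torch.no_grad():
+    out = model(hidden_states).sample  # same shape as the input
+```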
+
+When the input is **discrete**:
+
+
+
+It is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised image don't contain a prediction for the masked pixel because the unnoised image cannot be masked.
+
+
+
+1. Convert input (classes of latent pixels) to embeddings and apply positional embeddings.
+2. Apply the Transformer blocks in the standard way.
+3. Predict classes of unnoised image.
+
+## Transformer2DModel
+
+[[autodoc]] Transformer2DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.transformer_2d.Transformer2DModelOutput
diff --git a/diffusers/docs/source/en/api/models/transformer_temporal.md b/diffusers/docs/source/en/api/models/transformer_temporal.md
new file mode 100644
index 0000000000000000000000000000000000000000..c936270b79274f40012bb550243227531c5f4be3
--- /dev/null
+++ b/diffusers/docs/source/en/api/models/transformer_temporal.md
@@ -0,0 +1,23 @@
+
+
+# Transformer Temporal
+
+A Transformer model for video-like data.
+
+## TransformerTemporalModel
+
+[[autodoc]] models.transformer_temporal.TransformerTemporalModel
+
+## TransformerTemporalModelOutput
+
+[[autodoc]] models.transformer_temporal.TransformerTemporalModelOutput
diff --git a/diffusers/docs/source/en/api/models/unet-motion.md b/diffusers/docs/source/en/api/models/unet-motion.md
new file mode 100644
index 0000000000000000000000000000000000000000..cbc8c30ff64f66f0ed0b79ee9de03bf0809302a1
--- /dev/null
+++ b/diffusers/docs/source/en/api/models/unet-motion.md
@@ -0,0 +1,25 @@
+
+
+# UNetMotionModel
+
+The [UNet](https://huggingface.co/papers/1505.04597) model was originally introduced by Ronneberger et al. for biomedical image segmentation, but it is also commonly used in 🤗 Diffusers because it outputs images that are the same size as the input. It is one of the most important components of a diffusion system because it facilitates the actual diffusion process. There are several variants of the UNet model in 🤗 Diffusers, depending on its number of dimensions and whether it is a conditional model or not. This is a 2D UNet model extended with motion modules for video generation.
+
+The abstract from the paper is:
+
+*There is large consent that successful training of deep networks requires many thousand annotated training samples. In this paper, we present a network and training strategy that relies on the strong use of data augmentation to use the available annotated samples more efficiently. The architecture consists of a contracting path to capture context and a symmetric expanding path that enables precise localization. We show that such a network can be trained end-to-end from very few images and outperforms the prior best method (a sliding-window convolutional network) on the ISBI challenge for segmentation of neuronal structures in electron microscopic stacks. Using the same network trained on transmitted light microscopy images (phase contrast and DIC) we won the ISBI cell tracking challenge 2015 in these categories by a large margin. Moreover, the network is fast. Segmentation of a 512x512 image takes less than a second on a recent GPU. The full implementation (based on Caffe) and the trained networks are available at http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net.*
+
+## UNetMotionModel
+[[autodoc]] UNetMotionModel
+
+## UNet3DConditionOutput
+[[autodoc]] models.unet_3d_condition.UNet3DConditionOutput
diff --git a/diffusers/docs/source/en/api/models/unet.md b/diffusers/docs/source/en/api/models/unet.md
new file mode 100644
index 0000000000000000000000000000000000000000..66508b469a60d635ae96675c5efdbd7ddb596f6f
--- /dev/null
+++ b/diffusers/docs/source/en/api/models/unet.md
@@ -0,0 +1,25 @@
+
+
+# UNet1DModel
+
+The [UNet](https://huggingface.co/papers/1505.04597) model was originally introduced by Ronneberger et al. for biomedical image segmentation, but it is also commonly used in 🤗 Diffusers because it outputs images that are the same size as the input. It is one of the most important components of a diffusion system because it facilitates the actual diffusion process. There are several variants of the UNet model in 🤗 Diffusers, depending on its number of dimensions and whether it is a conditional model or not. This is a 1D UNet model.
+
+The abstract from the paper is:
+
+*There is large consent that successful training of deep networks requires many thousand annotated training samples. In this paper, we present a network and training strategy that relies on the strong use of data augmentation to use the available annotated samples more efficiently. The architecture consists of a contracting path to capture context and a symmetric expanding path that enables precise localization. We show that such a network can be trained end-to-end from very few images and outperforms the prior best method (a sliding-window convolutional network) on the ISBI challenge for segmentation of neuronal structures in electron microscopic stacks. Using the same network trained on transmitted light microscopy images (phase contrast and DIC) we won the ISBI cell tracking challenge 2015 in these categories by a large margin. Moreover, the network is fast. Segmentation of a 512x512 image takes less than a second on a recent GPU. The full implementation (based on Caffe) and the trained networks are available at http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net.*
+
+## UNet1DModel
+[[autodoc]] UNet1DModel
+
+## UNet1DOutput
+[[autodoc]] models.unet_1d.UNet1DOutput
diff --git a/diffusers/docs/source/en/api/models/unet2d-cond.md b/diffusers/docs/source/en/api/models/unet2d-cond.md
new file mode 100644
index 0000000000000000000000000000000000000000..ea385ff92426a0e588a08613c730235075357dcf
--- /dev/null
+++ b/diffusers/docs/source/en/api/models/unet2d-cond.md
@@ -0,0 +1,31 @@
+
+
+# UNet2DConditionModel
+
+The [UNet](https://huggingface.co/papers/1505.04597) model was originally introduced by Ronneberger et al. for biomedical image segmentation, but it is also commonly used in 🤗 Diffusers because it outputs images that are the same size as the input. It is one of the most important components of a diffusion system because it facilitates the actual diffusion process. There are several variants of the UNet model in 🤗 Diffusers, depending on its number of dimensions and whether it is a conditional model or not. This is a 2D UNet conditional model.
+
+The abstract from the paper is:
+
+*There is large consent that successful training of deep networks requires many thousand annotated training samples. In this paper, we present a network and training strategy that relies on the strong use of data augmentation to use the available annotated samples more efficiently. The architecture consists of a contracting path to capture context and a symmetric expanding path that enables precise localization. We show that such a network can be trained end-to-end from very few images and outperforms the prior best method (a sliding-window convolutional network) on the ISBI challenge for segmentation of neuronal structures in electron microscopic stacks. Using the same network trained on transmitted light microscopy images (phase contrast and DIC) we won the ISBI cell tracking challenge 2015 in these categories by a large margin. Moreover, the network is fast. Segmentation of a 512x512 image takes less than a second on a recent GPU. The full implementation (based on Caffe) and the trained networks are available at http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net.*
+
+## UNet2DConditionModel
+[[autodoc]] UNet2DConditionModel
+
+## UNet2DConditionOutput
+[[autodoc]] models.unet_2d_condition.UNet2DConditionOutput
+
+## FlaxUNet2DConditionModel
+[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionModel
+
+## FlaxUNet2DConditionOutput
+[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionOutput
diff --git a/diffusers/docs/source/en/api/models/unet2d.md b/diffusers/docs/source/en/api/models/unet2d.md
new file mode 100644
index 0000000000000000000000000000000000000000..7669d4a5d75ae781dcec0f8f4ca52d73839699d9
--- /dev/null
+++ b/diffusers/docs/source/en/api/models/unet2d.md
@@ -0,0 +1,25 @@
+
+
+# UNet2DModel
+
+The [UNet](https://huggingface.co/papers/1505.04597) model was originally introduced by Ronneberger et al. for biomedical image segmentation, but it is also commonly used in 🤗 Diffusers because it outputs images that are the same size as the input. It is one of the most important components of a diffusion system because it facilitates the actual diffusion process. There are several variants of the UNet model in 🤗 Diffusers, depending on its number of dimensions and whether it is a conditional model or not. This is a 2D UNet model.
+
+The abstract from the paper is:
+
+*There is large consent that successful training of deep networks requires many thousand annotated training samples. In this paper, we present a network and training strategy that relies on the strong use of data augmentation to use the available annotated samples more efficiently. The architecture consists of a contracting path to capture context and a symmetric expanding path that enables precise localization. We show that such a network can be trained end-to-end from very few images and outperforms the prior best method (a sliding-window convolutional network) on the ISBI challenge for segmentation of neuronal structures in electron microscopic stacks. Using the same network trained on transmitted light microscopy images (phase contrast and DIC) we won the ISBI cell tracking challenge 2015 in these categories by a large margin. Moreover, the network is fast. Segmentation of a 512x512 image takes less than a second on a recent GPU. The full implementation (based on Caffe) and the trained networks are available at http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net.*
+
+## UNet2DModel
+[[autodoc]] UNet2DModel
+
+## UNet2DOutput
+[[autodoc]] models.unet_2d.UNet2DOutput
diff --git a/diffusers/docs/source/en/api/models/unet3d-cond.md b/diffusers/docs/source/en/api/models/unet3d-cond.md
new file mode 100644
index 0000000000000000000000000000000000000000..4eea0a6d1cd2248ba8940484e54c19838c61f2be
--- /dev/null
+++ b/diffusers/docs/source/en/api/models/unet3d-cond.md
@@ -0,0 +1,25 @@
+
+
+# UNet3DConditionModel
+
+The [UNet](https://huggingface.co/papers/1505.04597) model was originally introduced by Ronneberger et al. for biomedical image segmentation, but it is also commonly used in 🤗 Diffusers because it outputs images that are the same size as the input. It is one of the most important components of a diffusion system because it facilitates the actual diffusion process. There are several variants of the UNet model in 🤗 Diffusers, depending on its number of dimensions and whether it is a conditional model or not. This is a 3D UNet conditional model.
+
+The abstract from the paper is:
+
+*There is large consent that successful training of deep networks requires many thousand annotated training samples. In this paper, we present a network and training strategy that relies on the strong use of data augmentation to use the available annotated samples more efficiently. The architecture consists of a contracting path to capture context and a symmetric expanding path that enables precise localization. We show that such a network can be trained end-to-end from very few images and outperforms the prior best method (a sliding-window convolutional network) on the ISBI challenge for segmentation of neuronal structures in electron microscopic stacks. Using the same network trained on transmitted light microscopy images (phase contrast and DIC) we won the ISBI cell tracking challenge 2015 in these categories by a large margin. Moreover, the network is fast. Segmentation of a 512x512 image takes less than a second on a recent GPU. The full implementation (based on Caffe) and the trained networks are available at http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net.*
+
+## UNet3DConditionModel
+[[autodoc]] UNet3DConditionModel
+
+## UNet3DConditionOutput
+[[autodoc]] models.unet_3d_condition.UNet3DConditionOutput
diff --git a/diffusers/docs/source/en/api/models/vq.md b/diffusers/docs/source/en/api/models/vq.md
new file mode 100644
index 0000000000000000000000000000000000000000..c288b163b28f68ed4cddf4c580b9300297c8c8c9
--- /dev/null
+++ b/diffusers/docs/source/en/api/models/vq.md
@@ -0,0 +1,27 @@
+
+
+# VQModel
+
+The VQ-VAE model was introduced in [Neural Discrete Representation Learning](https://huggingface.co/papers/1711.00937) by Aaron van den Oord, Oriol Vinyals and Koray Kavukcuoglu. The model is used in 🤗 Diffusers to decode latent representations into images. Unlike [`AutoencoderKL`], the [`VQModel`] works in a quantized latent space.
+
+The abstract from the paper is:
+
+*Learning useful representations without supervision remains a key challenge in machine learning. In this paper, we propose a simple yet powerful generative model that learns such discrete representations. Our model, the Vector Quantised-Variational AutoEncoder (VQ-VAE), differs from VAEs in two key ways: the encoder network outputs discrete, rather than continuous, codes; and the prior is learnt rather than static. In order to learn a discrete latent representation, we incorporate ideas from vector quantisation (VQ). Using the VQ method allows the model to circumvent issues of "posterior collapse" -- where the latents are ignored when they are paired with a powerful autoregressive decoder -- typically observed in the VAE framework. Pairing these representations with an autoregressive prior, the model can generate high quality images, videos, and speech as well as doing high quality speaker conversion and unsupervised learning of phonemes, providing further evidence of the utility of the learnt representations.*
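+
+A minimal sketch of decoding latents with a VQ-VAE (the [`CompVis/ldm-celebahq-256`](https://huggingface.co/CompVis/ldm-celebahq-256) checkpoint ships its VQ-VAE in a `vqvae` subfolder; the latent shape below matches that checkpoint):
+
+```python
+import torch
+from diffusers import VQModel
+
+vqvae = VQModel.from_pretrained("CompVis/ldm-celebahq-256", subfolder="vqvae")
+
+# random latents with this checkpoint's latent shape (3 channels, 64x64 for 256x256 images)
+latents = torch.randn(1, 3, 64, 64)
+
+with torch.no_grad():
+    image = vqvae.decode(latents).sample
+```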
+
+## VQModel
+
+[[autodoc]] VQModel
+
+## VQEncoderOutput
+
+[[autodoc]] models.vq_model.VQEncoderOutput
diff --git a/diffusers/docs/source/en/api/normalization.md b/diffusers/docs/source/en/api/normalization.md
new file mode 100644
index 0000000000000000000000000000000000000000..ccc643ac5e31b122890e6c0ec5377391f28240b8
--- /dev/null
+++ b/diffusers/docs/source/en/api/normalization.md
@@ -0,0 +1,31 @@
+
+
+# Normalization layers
+
+Customized normalization layers for supporting various models in 🤗 Diffusers.
+
+## AdaLayerNorm
+
+[[autodoc]] models.normalization.AdaLayerNorm
+
+## AdaLayerNormZero
+
+[[autodoc]] models.normalization.AdaLayerNormZero
+
+## AdaLayerNormSingle
+
+[[autodoc]] models.normalization.AdaLayerNormSingle
+
+## AdaGroupNorm
+
+[[autodoc]] models.normalization.AdaGroupNorm
diff --git a/diffusers/docs/source/en/api/outputs.md b/diffusers/docs/source/en/api/outputs.md
new file mode 100644
index 0000000000000000000000000000000000000000..30bad5646e919d4b17a30f49a64d1c806605f378
--- /dev/null
+++ b/diffusers/docs/source/en/api/outputs.md
@@ -0,0 +1,67 @@
+
+
+# Outputs
+
+All model outputs are subclasses of [`~utils.BaseOutput`], data structures containing all the information returned by the model. The outputs can also be used as tuples or dictionaries.
+
+For example:
+
+```python
+from diffusers import DDIMPipeline
+
+pipeline = DDIMPipeline.from_pretrained("google/ddpm-cifar10-32")
+outputs = pipeline()
+```
+
+The `outputs` object is a [`~pipelines.ImagePipelineOutput`], which means it has an `images` attribute.
+
+You can access each attribute as you normally would or with a keyword lookup, and if that attribute is not returned by the model, you will get `None`:
+
+```python
+outputs.images
+outputs["images"]
+```
+
+When considering the `outputs` object as a tuple, it only considers the attributes that don't have `None` values.
+For instance, retrieving an image by indexing into it returns the one-element tuple `(outputs.images,)`:
+
+```python
+outputs[:1]
+```
+
+
+
+To check a specific pipeline or model output, refer to its corresponding API documentation.
+
+
+
+## BaseOutput
+
+[[autodoc]] utils.BaseOutput
+ - to_tuple
+
+## ImagePipelineOutput
+
+[[autodoc]] pipelines.ImagePipelineOutput
+
+## FlaxImagePipelineOutput
+
+[[autodoc]] pipelines.pipeline_flax_utils.FlaxImagePipelineOutput
+
+## AudioPipelineOutput
+
+[[autodoc]] pipelines.AudioPipelineOutput
+
+## ImageTextPipelineOutput
+
+[[autodoc]] ImageTextPipelineOutput
diff --git a/diffusers/docs/source/en/api/pipelines/alt_diffusion.md b/diffusers/docs/source/en/api/pipelines/alt_diffusion.md
new file mode 100644
index 0000000000000000000000000000000000000000..d0326affbb638e4d8fa6a6b17abfe68f06f0c2ab
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/alt_diffusion.md
@@ -0,0 +1,47 @@
+
+
+# AltDiffusion
+
+AltDiffusion was proposed in [AltCLIP: Altering the Language Encoder in CLIP for Extended Language Capabilities](https://huggingface.co/papers/2211.06679) by Zhongzhi Chen, Guang Liu, Bo-Wen Zhang, Fulong Ye, Qinghong Yang, Ledell Wu.
+
+The abstract from the paper is:
+
+*In this work, we present a conceptually simple and effective method to train a strong bilingual/multilingual multimodal representation model. Starting from the pre-trained multimodal representation model CLIP released by OpenAI, we altered its text encoder with a pre-trained multilingual text encoder XLM-R, and aligned both languages and image representations by a two-stage training schema consisting of teacher learning and contrastive learning. We validate our method through evaluations of a wide range of tasks. We set new state-of-the-art performances on a bunch of tasks including ImageNet-CN, Flicker30k-CN, COCO-CN and XTD. Further, we obtain very close performances with CLIP on almost all tasks, suggesting that one can simply alter the text encoder in CLIP for extended capabilities such as multilingual understanding. Our models and code are available at [this https URL](https://github.com/FlagAI-Open/FlagAI).*
+
+## Tips
+
+`AltDiffusion` is conceptually the same as [Stable Diffusion](./stable_diffusion/overview).
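+
+A minimal sketch of running the pipeline with the multilingual [`BAAI/AltDiffusion-m9`](https://huggingface.co/BAAI/AltDiffusion-m9) checkpoint:
+
+```python
+import torch
+from diffusers import AltDiffusionPipeline
+
+pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion-m9", torch_dtype=torch.float16).to("cuda")
+
+# AltDiffusion accepts prompts in multiple languages, e.g. Chinese
+image = pipe("黑暗精灵公主,非常详细,幻想,非常详细,数字绘画").images[0]
+```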
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## AltDiffusionPipeline
+
+[[autodoc]] AltDiffusionPipeline
+ - all
+ - __call__
+
+## AltDiffusionImg2ImgPipeline
+
+[[autodoc]] AltDiffusionImg2ImgPipeline
+ - all
+ - __call__
+
+## AltDiffusionPipelineOutput
+
+[[autodoc]] pipelines.alt_diffusion.AltDiffusionPipelineOutput
+ - all
+ - __call__
diff --git a/diffusers/docs/source/en/api/pipelines/animatediff.md b/diffusers/docs/source/en/api/pipelines/animatediff.md
new file mode 100644
index 0000000000000000000000000000000000000000..422d345b90578f22a1583eda9d5907b495ad15a6
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/animatediff.md
@@ -0,0 +1,234 @@
+
+
+# Text-to-Video Generation with AnimateDiff
+
+## Overview
+
+[AnimateDiff: Animate Your Personalized Text-to-Image Diffusion Models without Specific Tuning](https://arxiv.org/abs/2307.04725) by Yuwei Guo, Ceyuan Yang, Anyi Rao, Yaohui Wang, Yu Qiao, Dahua Lin, Bo Dai.
+
+The abstract of the paper is the following:
+
+*With the advance of text-to-image models (e.g., Stable Diffusion) and corresponding personalization techniques such as DreamBooth and LoRA, everyone can manifest their imagination into high-quality images at an affordable cost. Subsequently, there is a great demand for image animation techniques to further combine generated static images with motion dynamics. In this report, we propose a practical framework to animate most of the existing personalized text-to-image models once and for all, saving efforts in model-specific tuning. At the core of the proposed framework is to insert a newly initialized motion modeling module into the frozen text-to-image model and train it on video clips to distill reasonable motion priors. Once trained, by simply injecting this motion modeling module, all personalized versions derived from the same base T2I readily become text-driven models that produce diverse and personalized animated images. We conduct our evaluation on several public representative personalized text-to-image models across anime pictures and realistic photographs, and demonstrate that our proposed framework helps these models generate temporally smooth animation clips while preserving the domain and diversity of their outputs. Code and pre-trained weights will be publicly available at [this https URL](https://animatediff.github.io/).*
+
+## Available Pipelines
+
+| Pipeline | Tasks | Demo |
+|---|---|:---:|
+| [AnimateDiffPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff.py) | *Text-to-Video Generation with AnimateDiff* | |
+
+## Available checkpoints
+
+Motion Adapter checkpoints can be found under [guoyww](https://huggingface.co/guoyww/). These checkpoints are meant to work with any model based on Stable Diffusion 1.4/1.5.
+
+## Usage example
+
+AnimateDiff works with a MotionAdapter checkpoint and a Stable Diffusion model checkpoint. The MotionAdapter is a collection of Motion Modules that are responsible for adding coherent motion across image frames. These modules are applied after the ResNet and attention blocks in the Stable Diffusion UNet.
+
+The following example demonstrates how to use a *MotionAdapter* checkpoint with Diffusers for inference based on StableDiffusion-1.4/1.5.
+
+```python
+import torch
+from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler
+from diffusers.utils import export_to_gif
+
+# Load the motion adapter
+adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
+# load SD 1.5 based finetuned model
+model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
+pipe = AnimateDiffPipeline.from_pretrained(model_id, motion_adapter=adapter)
+scheduler = DDIMScheduler.from_pretrained(
+ model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1
+)
+pipe.scheduler = scheduler
+
+# enable memory savings
+pipe.enable_vae_slicing()
+pipe.enable_model_cpu_offload()
+
+output = pipe(
+ prompt=(
+ "masterpiece, bestquality, highlydetailed, ultradetailed, sunset, "
+ "orange sky, warm lighting, fishing boats, ocean waves seagulls, "
+ "rippling water, wharf, silhouette, serene atmosphere, dusk, evening glow, "
+ "golden hour, coastal landscape, seaside scenery"
+ ),
+ negative_prompt="bad quality, worse quality",
+ num_frames=16,
+ guidance_scale=7.5,
+ num_inference_steps=25,
+ generator=torch.Generator("cpu").manual_seed(42),
+)
+frames = output.frames[0]
+export_to_gif(frames, "animation.gif")
+```
+
+Here are some sample outputs:
+
+Sample output for the prompt *masterpiece, bestquality, sunset.*
+
+AnimateDiff tends to work better with finetuned Stable Diffusion models. If you plan on using a scheduler that can clip samples, make sure to disable it by setting `clip_sample=False` in the scheduler as this can also have an adverse effect on generated samples.
+
+
+
+## Using Motion LoRAs
+
+Motion LoRAs are a collection of LoRAs that work with the `guoyww/animatediff-motion-adapter-v1-5-2` checkpoint. These LoRAs are responsible for adding specific types of motion to the animations.
+
+```python
+import torch
+from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler
+from diffusers.utils import export_to_gif
+
+# Load the motion adapter
+adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
+# load SD 1.5 based finetuned model
+model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
+pipe = AnimateDiffPipeline.from_pretrained(model_id, motion_adapter=adapter)
+pipe.load_lora_weights("guoyww/animatediff-motion-lora-zoom-out", adapter_name="zoom-out")
+
+scheduler = DDIMScheduler.from_pretrained(
+ model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1
+)
+pipe.scheduler = scheduler
+
+# enable memory savings
+pipe.enable_vae_slicing()
+pipe.enable_model_cpu_offload()
+
+output = pipe(
+ prompt=(
+ "masterpiece, bestquality, highlydetailed, ultradetailed, sunset, "
+ "orange sky, warm lighting, fishing boats, ocean waves seagulls, "
+ "rippling water, wharf, silhouette, serene atmosphere, dusk, evening glow, "
+ "golden hour, coastal landscape, seaside scenery"
+ ),
+ negative_prompt="bad quality, worse quality",
+ num_frames=16,
+ guidance_scale=7.5,
+ num_inference_steps=25,
+ generator=torch.Generator("cpu").manual_seed(42),
+)
+frames = output.frames[0]
+export_to_gif(frames, "animation.gif")
+```
+
+
+Sample output for the prompt *masterpiece, bestquality, sunset.* with the zoom-out Motion LoRA.
+
+## Using Motion LoRAs with PEFT
+
+You can also leverage the [PEFT](https://github.com/huggingface/peft) backend to combine Motion LoRAs and create more complex animations.
+
+First install PEFT with
+
+```shell
+pip install peft
+```
+
+Then you can use the following code to combine Motion LoRAs.
+
+```python
+import torch
+from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler
+from diffusers.utils import export_to_gif
+
+# Load the motion adapter
+adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
+# load SD 1.5 based finetuned model
+model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
+pipe = AnimateDiffPipeline.from_pretrained(model_id, motion_adapter=adapter)
+
+pipe.load_lora_weights("diffusers/animatediff-motion-lora-zoom-out", adapter_name="zoom-out")
+pipe.load_lora_weights("diffusers/animatediff-motion-lora-pan-left", adapter_name="pan-left")
+pipe.set_adapters(["zoom-out", "pan-left"], adapter_weights=[1.0, 1.0])
+
+scheduler = DDIMScheduler.from_pretrained(
+ model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1
+)
+pipe.scheduler = scheduler
+
+# enable memory savings
+pipe.enable_vae_slicing()
+pipe.enable_model_cpu_offload()
+
+output = pipe(
+ prompt=(
+ "masterpiece, bestquality, highlydetailed, ultradetailed, sunset, "
+ "orange sky, warm lighting, fishing boats, ocean waves seagulls, "
+ "rippling water, wharf, silhouette, serene atmosphere, dusk, evening glow, "
+ "golden hour, coastal landscape, seaside scenery"
+ ),
+ negative_prompt="bad quality, worse quality",
+ num_frames=16,
+ guidance_scale=7.5,
+ num_inference_steps=25,
+ generator=torch.Generator("cpu").manual_seed(42),
+)
+frames = output.frames[0]
+export_to_gif(frames, "animation.gif")
+```
+
+
+Sample output for the prompt *masterpiece, bestquality, sunset.* with the combined zoom-out and pan-left Motion LoRAs.
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## AnimateDiffPipeline
+
+[[autodoc]] AnimateDiffPipeline
+ - all
+ - __call__
+ - enable_freeu
+ - disable_freeu
+ - enable_vae_slicing
+ - disable_vae_slicing
+ - enable_vae_tiling
+ - disable_vae_tiling
+
+## AnimateDiffPipelineOutput
+
+[[autodoc]] pipelines.animatediff.AnimateDiffPipelineOutput
diff --git a/diffusers/docs/source/en/api/pipelines/attend_and_excite.md b/diffusers/docs/source/en/api/pipelines/attend_and_excite.md
new file mode 100644
index 0000000000000000000000000000000000000000..94f33cf1d0b634795464fea3d3d66631cf838b56
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/attend_and_excite.md
@@ -0,0 +1,37 @@
+
+
+# Attend-and-Excite
+
+Attend-and-Excite for Stable Diffusion was proposed in [Attend-and-Excite: Attention-Based Semantic Guidance for Text-to-Image Diffusion Models](https://attendandexcite.github.io/Attend-and-Excite/) and provides textual attention control over image generation.
+
+The abstract from the paper is:
+
+*Recent text-to-image generative models have demonstrated an unparalleled ability to generate diverse and creative imagery guided by a target text prompt. While revolutionary, current state-of-the-art diffusion models may still fail in generating images that fully convey the semantics in the given text prompt. We analyze the publicly available Stable Diffusion model and assess the existence of catastrophic neglect, where the model fails to generate one or more of the subjects from the input prompt. Moreover, we find that in some cases the model also fails to correctly bind attributes (e.g., colors) to their corresponding subjects. To help mitigate these failure cases, we introduce the concept of Generative Semantic Nursing (GSN), where we seek to intervene in the generative process on the fly during inference time to improve the faithfulness of the generated images. Using an attention-based formulation of GSN, dubbed Attend-and-Excite, we guide the model to refine the cross-attention units to attend to all subject tokens in the text prompt and strengthen - or excite - their activations, encouraging the model to generate all subjects described in the text prompt. We compare our approach to alternative approaches and demonstrate that it conveys the desired concepts more faithfully across a range of text prompts.*
+
+You can find additional information about Attend-and-Excite on the [project page](https://attendandexcite.github.io/Attend-and-Excite/), the [original codebase](https://github.com/AttendAndExcite/Attend-and-Excite), or try it out in a [demo](https://huggingface.co/spaces/AttendAndExcite/Attend-and-Excite).
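+
+A minimal sketch of using the pipeline (`token_indices` points at the subject tokens, here "cat" and "frog", in the tokenized prompt):
+
+```python
+import torch
+from diffusers import StableDiffusionAttendAndExcitePipeline
+
+pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained(
+    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
+).to("cuda")
+
+prompt = "a cat and a frog"
+# pipe.get_indices(prompt) can be used to look up the token indices
+image = pipe(prompt, token_indices=[2, 5], guidance_scale=7.5, num_inference_steps=50).images[0]
+```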
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## StableDiffusionAttendAndExcitePipeline
+
+[[autodoc]] StableDiffusionAttendAndExcitePipeline
+ - all
+ - __call__
+
+## StableDiffusionPipelineOutput
+
+[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
diff --git a/diffusers/docs/source/en/api/pipelines/audio_diffusion.md b/diffusers/docs/source/en/api/pipelines/audio_diffusion.md
new file mode 100644
index 0000000000000000000000000000000000000000..3d140fe202a6bf6ec0bd487b84c767d18ab58e8e
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/audio_diffusion.md
@@ -0,0 +1,35 @@
+
+
+# Audio Diffusion
+
+[Audio Diffusion](https://github.com/teticio/audio-diffusion) is by Robert Dargavel Smith, and it leverages the recent advances in image generation from diffusion models by converting audio samples to and from Mel spectrogram images.
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## AudioDiffusionPipeline
+[[autodoc]] AudioDiffusionPipeline
+ - all
+ - __call__
+
+## AudioPipelineOutput
+[[autodoc]] pipelines.AudioPipelineOutput
+
+## ImagePipelineOutput
+[[autodoc]] pipelines.ImagePipelineOutput
+
+## Mel
+[[autodoc]] Mel
diff --git a/diffusers/docs/source/en/api/pipelines/audioldm.md b/diffusers/docs/source/en/api/pipelines/audioldm.md
new file mode 100644
index 0000000000000000000000000000000000000000..43fb0f1a3bf4cfe606e443daec73c6cc52c59ca8
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/audioldm.md
@@ -0,0 +1,50 @@
+
+
+# AudioLDM
+
+AudioLDM was proposed in [AudioLDM: Text-to-Audio Generation with Latent Diffusion Models](https://huggingface.co/papers/2301.12503) by Haohe Liu et al. Inspired by [Stable Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview), AudioLDM
+is a text-to-audio _latent diffusion model (LDM)_ that learns continuous audio representations from [CLAP](https://huggingface.co/docs/transformers/main/model_doc/clap)
+latents. AudioLDM takes a text prompt as input and predicts the corresponding audio. It can generate text-conditional
+sound effects, human speech and music.
+
+The abstract from the paper is:
+
+*Text-to-audio (TTA) system has recently gained attention for its ability to synthesize general audio based on text descriptions. However, previous studies in TTA have limited generation quality with high computational costs. In this study, we propose AudioLDM, a TTA system that is built on a latent space to learn the continuous audio representations from contrastive language-audio pretraining (CLAP) latents. The pretrained CLAP models enable us to train LDMs with audio embedding while providing text embedding as a condition during sampling. By learning the latent representations of audio signals and their compositions without modeling the cross-modal relationship, AudioLDM is advantageous in both generation quality and computational efficiency. Trained on AudioCaps with a single GPU, AudioLDM achieves state-of-the-art TTA performance measured by both objective and subjective metrics (e.g., frechet distance). Moreover, AudioLDM is the first TTA system that enables various text-guided audio manipulations (e.g., style transfer) in a zero-shot fashion. Our implementation and demos are available at [this https URL](https://audioldm.github.io/).*
+
+The original codebase can be found at [haoheliu/AudioLDM](https://github.com/haoheliu/AudioLDM).
+
+## Tips
+
+When constructing a prompt, keep in mind:
+
+* Descriptive prompt inputs work best; you can use adjectives to describe the sound (for example, "high quality" or "clear") and make the prompt context specific (for example, "water stream in a forest" instead of "stream").
+* It's best to use general terms like "cat" or "dog" instead of specific names or abstract objects the model may not be familiar with.
+
+During inference:
+
+* The _quality_ of the predicted audio sample can be controlled by the `num_inference_steps` argument; higher steps give higher quality audio at the expense of slower inference.
+* The _length_ of the predicted audio sample can be controlled by varying the `audio_length_in_s` argument.
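+
+A minimal sketch putting these arguments together (the [`cvssp/audioldm-s-full-v2`](https://huggingface.co/cvssp/audioldm-s-full-v2) checkpoint generates 16 kHz audio):
+
+```python
+import torch
+from scipy.io import wavfile
+from diffusers import AudioLDMPipeline
+
+pipe = AudioLDMPipeline.from_pretrained("cvssp/audioldm-s-full-v2", torch_dtype=torch.float16).to("cuda")
+
+prompt = "Techno music with a strong, upbeat tempo and high melodic riffs"
+audio = pipe(prompt, num_inference_steps=10, audio_length_in_s=5.0).audios[0]
+
+# the pipeline returns a waveform sampled at 16 kHz
+wavfile.write("techno.wav", rate=16000, data=audio)
+```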
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## AudioLDMPipeline
+[[autodoc]] AudioLDMPipeline
+ - all
+ - __call__
+
+## AudioPipelineOutput
+[[autodoc]] pipelines.AudioPipelineOutput
diff --git a/diffusers/docs/source/en/api/pipelines/audioldm2.md b/diffusers/docs/source/en/api/pipelines/audioldm2.md
new file mode 100644
index 0000000000000000000000000000000000000000..89bb6b8cc922855b2df2817c2093e86ee958e564
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/audioldm2.md
@@ -0,0 +1,78 @@
+
+
+# AudioLDM 2
+
+AudioLDM 2 was proposed in [AudioLDM 2: Learning Holistic Audio Generation with Self-supervised Pretraining](https://arxiv.org/abs/2308.05734) by Haohe Liu et al. AudioLDM 2 takes a text prompt as input and predicts the corresponding audio. It can generate text-conditional sound effects, human speech and music.
+
+Inspired by [Stable Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview), AudioLDM 2 is a text-to-audio _latent diffusion model (LDM)_ that learns continuous audio representations from text embeddings. Two text encoder models are used to compute the text embeddings from a prompt input: the text-branch of [CLAP](https://huggingface.co/docs/transformers/main/en/model_doc/clap) and the encoder of [Flan-T5](https://huggingface.co/docs/transformers/main/en/model_doc/flan-t5). These text embeddings are then projected to a shared embedding space by an [AudioLDM2ProjectionModel](https://huggingface.co/docs/diffusers/main/api/pipelines/audioldm2#diffusers.AudioLDM2ProjectionModel). A [GPT2](https://huggingface.co/docs/transformers/main/en/model_doc/gpt2) _language model (LM)_ is used to auto-regressively predict eight new embedding vectors, conditional on the projected CLAP and Flan-T5 embeddings. The generated embedding vectors and Flan-T5 text embeddings are used as cross-attention conditioning in the LDM. The [UNet](https://huggingface.co/docs/diffusers/main/en/api/pipelines/audioldm2#diffusers.AudioLDM2UNet2DConditionModel) of AudioLDM 2 is unique in the sense that it takes **two** cross-attention embeddings, as opposed to one cross-attention conditioning, as in most other LDMs.
+
+The abstract of the paper is the following:
+
+*Although audio generation shares commonalities across different types of audio, such as speech, music, and sound effects, designing models for each type requires careful consideration of specific objectives and biases that can significantly differ from those of other types. To bring us closer to a unified perspective of audio generation, this paper proposes a framework that utilizes the same learning method for speech, music, and sound effect generation. Our framework introduces a general representation of audio, called "language of audio" (LOA). Any audio can be translated into LOA based on AudioMAE, a self-supervised pre-trained representation learning model. In the generation process, we translate any modalities into LOA by using a GPT-2 model, and we perform self-supervised audio generation learning with a latent diffusion model conditioned on LOA. The proposed framework naturally brings advantages such as in-context learning abilities and reusable self-supervised pretrained AudioMAE and latent diffusion models. Experiments on the major benchmarks of text-to-audio, text-to-music, and text-to-speech demonstrate state-of-the-art or competitive performance against previous approaches. Our code, pretrained model, and demo are available at [this https URL](https://audioldm.github.io/audioldm2).*
+
+This pipeline was contributed by [sanchit-gandhi](https://huggingface.co/sanchit-gandhi). The original codebase can be found at [haoheliu/audioldm2](https://github.com/haoheliu/audioldm2).
+
+## Tips
+
+### Choosing a checkpoint
+
+AudioLDM2 comes in three variants. Two of these checkpoints are applicable to the general task of text-to-audio generation. The third checkpoint is trained exclusively on text-to-music generation.
+
+All checkpoints share the same model size for the text encoders and VAE. They differ in the size and depth of the UNet.
+See the table below for details on the three checkpoints:
+
+| Checkpoint | Task | UNet Model Size | Total Model Size | Training Data / h |
+|-----------------------------------------------------------------|---------------|-----------------|------------------|-------------------|
+| [audioldm2](https://huggingface.co/cvssp/audioldm2) | Text-to-audio | 350M | 1.1B | 1150k |
+| [audioldm2-large](https://huggingface.co/cvssp/audioldm2-large) | Text-to-audio | 750M | 1.5B | 1150k |
+| [audioldm2-music](https://huggingface.co/cvssp/audioldm2-music) | Text-to-music | 350M | 1.1B | 665k |
+
+### Constructing a prompt
+
+* Descriptive prompt inputs work best: use adjectives to describe the sound (e.g. "high quality" or "clear") and make the prompt context specific (e.g. "water stream in a forest" instead of "stream").
+* It's best to use general terms like "cat" or "dog" instead of specific names or abstract objects the model may not be familiar with.
+* Using a **negative prompt** can significantly improve the quality of the generated waveform, by guiding the generation away from terms that correspond to poor quality audio. Try using a negative prompt of "Low quality."
+
+### Controlling inference
+
+* The _quality_ of the predicted audio sample can be controlled by the `num_inference_steps` argument; higher steps give higher quality audio at the expense of slower inference.
+* The _length_ of the predicted audio sample can be controlled by varying the `audio_length_in_s` argument.
+
+### Evaluating generated waveforms
+
+* The quality of the generated waveforms can vary significantly based on the seed. Try generating with different seeds until you find a satisfactory generation.
+* Multiple waveforms can be generated in one go: set `num_waveforms_per_prompt` to a value greater than 1. Automatic scoring will be performed between the generated waveforms and prompt text, and the audios ranked from best to worst accordingly.
+
+For an end-to-end example that applies these tips to music generation, see the [example](https://huggingface.co/docs/diffusers/main/en/api/pipelines/audioldm2#diffusers.AudioLDM2Pipeline.__call__.example) in the pipeline API documentation.
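+
+As a quick illustration, a minimal text-to-audio sketch combining a descriptive prompt, a negative prompt, and waveform ranking (the prompt, seed, and step count below are illustrative):
+
+```py
+import torch
+from diffusers import AudioLDM2Pipeline
+from scipy.io import wavfile
+
+pipe = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2", torch_dtype=torch.float16).to("cuda")
+
+prompt = "The sound of a hammer hitting a wooden surface, high quality, clear"
+negative_prompt = "Low quality."
+
+generator = torch.Generator("cuda").manual_seed(0)
+audio = pipe(
+    prompt,
+    negative_prompt=negative_prompt,
+    num_inference_steps=200,
+    audio_length_in_s=10.0,
+    num_waveforms_per_prompt=3,  # generated waveforms are scored against the prompt and ranked
+    generator=generator,
+).audios
+
+# The first waveform is the highest-ranked candidate; AudioLDM 2 outputs 16 kHz audio
+wavfile.write("output.wav", rate=16000, data=audio[0].astype("float32"))
+```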
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## AudioLDM2Pipeline
+[[autodoc]] AudioLDM2Pipeline
+ - all
+ - __call__
+
+## AudioLDM2ProjectionModel
+[[autodoc]] AudioLDM2ProjectionModel
+ - forward
+
+## AudioLDM2UNet2DConditionModel
+[[autodoc]] AudioLDM2UNet2DConditionModel
+ - forward
+
+## AudioPipelineOutput
+[[autodoc]] pipelines.AudioPipelineOutput
diff --git a/diffusers/docs/source/en/api/pipelines/auto_pipeline.md b/diffusers/docs/source/en/api/pipelines/auto_pipeline.md
new file mode 100644
index 0000000000000000000000000000000000000000..e9b932f33dd299c451f92b17a7c339d9ceae52bc
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/auto_pipeline.md
@@ -0,0 +1,71 @@
+
+
+# AutoPipeline
+
+`AutoPipeline` is designed to:
+
+1. make it easy for you to load a checkpoint for a task without knowing the specific pipeline class to use
+2. use multiple pipelines in your workflow
+
+Based on the task, the `AutoPipeline` class automatically retrieves the relevant pipeline given the name or path to the pretrained weights with the `from_pretrained()` method.
+
+To seamlessly switch between tasks with the same checkpoint without reallocating additional memory, use the `from_pipe()` method to transfer the components from the original pipeline to the new one.
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+
+pipeline = AutoPipelineForText2Image.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
+).to("cuda")
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+
+image = pipeline(prompt, num_inference_steps=25).images[0]
+```
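+
+Continuing from the snippet above, a minimal sketch of the `from_pipe()` workflow that reuses the already loaded components for image-to-image (the input image URL is illustrative):
+
+```py
+from diffusers import AutoPipelineForImage2Image
+from diffusers.utils import load_image
+
+# Reuses the components of `pipeline`; no additional memory is allocated for the new pipeline
+pipeline_img2img = AutoPipelineForImage2Image.from_pipe(pipeline)
+
+init_image = load_image(
+    "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+)
+image = pipeline_img2img(prompt, image=init_image, strength=0.75).images[0]
+```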
+
+
+
+Check out the [AutoPipeline](../../tutorials/autopipeline) tutorial to learn how to use this API!
+
+
+
+`AutoPipeline` supports text-to-image, image-to-image, and inpainting for the following diffusion models:
+
+- [Stable Diffusion](./stable_diffusion/overview)
+- [ControlNet](./controlnet)
+- [Stable Diffusion XL (SDXL)](./stable_diffusion/stable_diffusion_xl)
+- [DeepFloyd IF](./deepfloyd_if)
+- [Kandinsky 2.1](./kandinsky)
+- [Kandinsky 2.2](./kandinsky_v22)
+
+
+## AutoPipelineForText2Image
+
+[[autodoc]] AutoPipelineForText2Image
+ - all
+ - from_pretrained
+ - from_pipe
+
+## AutoPipelineForImage2Image
+
+[[autodoc]] AutoPipelineForImage2Image
+ - all
+ - from_pretrained
+ - from_pipe
+
+## AutoPipelineForInpainting
+
+[[autodoc]] AutoPipelineForInpainting
+ - all
+ - from_pretrained
+ - from_pipe
diff --git a/diffusers/docs/source/en/api/pipelines/blip_diffusion.md b/diffusers/docs/source/en/api/pipelines/blip_diffusion.md
new file mode 100644
index 0000000000000000000000000000000000000000..b2fa5de2508c500b248fa4e9c765974f4b2ffb64
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/blip_diffusion.md
@@ -0,0 +1,41 @@
+
+
+# BLIP-Diffusion
+
+BLIP-Diffusion was proposed in [BLIP-Diffusion: Pre-trained Subject Representation for Controllable Text-to-Image Generation and Editing](https://arxiv.org/abs/2305.14720). It enables zero-shot subject-driven generation and control-guided zero-shot generation.
+
+
+The abstract from the paper is:
+
+*Subject-driven text-to-image generation models create novel renditions of an input subject based on text prompts. Existing models suffer from lengthy fine-tuning and difficulties preserving the subject fidelity. To overcome these limitations, we introduce BLIP-Diffusion, a new subject-driven image generation model that supports multimodal control which consumes inputs of subject images and text prompts. Unlike other subject-driven generation models, BLIP-Diffusion introduces a new multimodal encoder which is pre-trained to provide subject representation. We first pre-train the multimodal encoder following BLIP-2 to produce visual representation aligned with the text. Then we design a subject representation learning task which enables a diffusion model to leverage such visual representation and generates new subject renditions. Compared with previous methods such as DreamBooth, our model enables zero-shot subject-driven generation, and efficient fine-tuning for customized subject with up to 20x speedup. We also demonstrate that BLIP-Diffusion can be flexibly combined with existing techniques such as ControlNet and prompt-to-prompt to enable novel subject-driven generation and editing applications. Project page at [this https URL](https://dxli94.github.io/BLIP-Diffusion-website/).*
+
+The original codebase can be found at [salesforce/LAVIS](https://github.com/salesforce/LAVIS/tree/main/projects/blip-diffusion). You can find the official BLIP-Diffusion checkpoints under the [hf.co/SalesForce](https://hf.co/SalesForce) organization.
+
+`BlipDiffusionPipeline` and `BlipDiffusionControlNetPipeline` were contributed by [`ayushtues`](https://github.com/ayushtues/).
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+
+## BlipDiffusionPipeline
+[[autodoc]] BlipDiffusionPipeline
+ - all
+ - __call__
+
+## BlipDiffusionControlNetPipeline
+[[autodoc]] BlipDiffusionControlNetPipeline
+ - all
+ - __call__
diff --git a/diffusers/docs/source/en/api/pipelines/consistency_models.md b/diffusers/docs/source/en/api/pipelines/consistency_models.md
new file mode 100644
index 0000000000000000000000000000000000000000..afdee2c0c8e9127b8f6fa10db563cd595bbee1a8
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/consistency_models.md
@@ -0,0 +1,56 @@
+
+
+# Consistency Models
+
+Consistency Models were proposed in [Consistency Models](https://huggingface.co/papers/2303.01469) by Yang Song, Prafulla Dhariwal, Mark Chen, and Ilya Sutskever.
+
+The abstract from the paper is:
+
+*Diffusion models have significantly advanced the fields of image, audio, and video generation, but they depend on an iterative sampling process that causes slow generation. To overcome this limitation, we propose consistency models, a new family of models that generate high quality samples by directly mapping noise to data. They support fast one-step generation by design, while still allowing multistep sampling to trade compute for sample quality. They also support zero-shot data editing, such as image inpainting, colorization, and super-resolution, without requiring explicit training on these tasks. Consistency models can be trained either by distilling pre-trained diffusion models, or as standalone generative models altogether. Through extensive experiments, we demonstrate that they outperform existing distillation techniques for diffusion models in one- and few-step sampling, achieving the new state-of-the-art FID of 3.55 on CIFAR-10 and 6.20 on ImageNet 64x64 for one-step generation. When trained in isolation, consistency models become a new family of generative models that can outperform existing one-step, non-adversarial generative models on standard benchmarks such as CIFAR-10, ImageNet 64x64 and LSUN 256x256.*
+
+The original codebase can be found at [openai/consistency_models](https://github.com/openai/consistency_models), and additional checkpoints are available at [openai](https://huggingface.co/openai).
+
+The pipeline was contributed by [dg845](https://github.com/dg845) and [ayushtues](https://huggingface.co/ayushtues). ❤️
+
+## Tips
+
+For an additional speed-up, use `torch.compile` to generate multiple images in <1 second:
+
+```diff
+ import torch
+ from diffusers import ConsistencyModelPipeline
+
+ device = "cuda"
+ # Load the cd_bedroom256_lpips checkpoint.
+ model_id_or_path = "openai/diffusers-cd_bedroom256_lpips"
+ pipe = ConsistencyModelPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
+ pipe.to(device)
+
++ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+
+ # Multistep sampling
+ # Timesteps can be explicitly specified; the particular timesteps below are from the original GitHub repo:
+ # https://github.com/openai/consistency_models/blob/main/scripts/launch.sh#L83
+ for _ in range(10):
+ image = pipe(timesteps=[17, 0]).images[0]
+ image.show()
+```
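+
+For one-step generation (the default for consistency models), a minimal sketch with the same checkpoint (the output filename is illustrative):
+
+```py
+import torch
+from diffusers import ConsistencyModelPipeline
+
+pipe = ConsistencyModelPipeline.from_pretrained(
+    "openai/diffusers-cd_bedroom256_lpips", torch_dtype=torch.float16
+).to("cuda")
+
+# Single-step sampling: the model maps noise directly to an image
+image = pipe(num_inference_steps=1).images[0]
+image.save("cd_bedroom256_lpips_onestep.png")
+```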
+
+
+## ConsistencyModelPipeline
+[[autodoc]] ConsistencyModelPipeline
+ - all
+ - __call__
+
+## ImagePipelineOutput
+[[autodoc]] pipelines.ImagePipelineOutput
diff --git a/diffusers/docs/source/en/api/pipelines/controlnet.md b/diffusers/docs/source/en/api/pipelines/controlnet.md
new file mode 100644
index 0000000000000000000000000000000000000000..0f636af79b773f2ad4de241e9d4aa55b6efafcca
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/controlnet.md
@@ -0,0 +1,78 @@
+
+
+# ControlNet
+
+ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala.
+
+With a ControlNet model, you can provide an additional control image to condition and control Stable Diffusion generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process.
+
+The abstract from the paper is:
+
+*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with "zero convolutions" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*
+
+This model was contributed by [takuma104](https://huggingface.co/takuma104). ❤️
+
+The original codebase can be found at [lllyasviel/ControlNet](https://github.com/lllyasviel/ControlNet), and you can find official ControlNet checkpoints on [lllyasviel's](https://huggingface.co/lllyasviel) Hub profile.
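+
+A minimal text-to-image sketch conditioned on Canny edges (the checkpoints, input image URL, and prompt below are illustrative; any control image matching the ControlNet's conditioning type works):
+
+```py
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+from diffusers import ControlNetModel, StableDiffusionControlNetPipeline, UniPCMultistepScheduler
+from diffusers.utils import load_image
+
+# Prepare a Canny edge map to use as the control image (requires opencv-python)
+image = load_image(
+    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
+)
+edges = cv2.Canny(np.array(image), 100, 200)
+canny_image = Image.fromarray(np.stack([edges] * 3, axis=-1))
+
+controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
+pipe = StableDiffusionControlNetPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
+)
+pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+pipe.enable_model_cpu_offload()
+
+# The edge map constrains the layout; the prompt controls content and style
+image = pipe("a portrait of a robot, best quality", image=canny_image, num_inference_steps=20).images[0]
+```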
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## StableDiffusionControlNetPipeline
+[[autodoc]] StableDiffusionControlNetPipeline
+ - all
+ - __call__
+ - enable_attention_slicing
+ - disable_attention_slicing
+ - enable_vae_slicing
+ - disable_vae_slicing
+ - enable_xformers_memory_efficient_attention
+ - disable_xformers_memory_efficient_attention
+ - load_textual_inversion
+
+## StableDiffusionControlNetImg2ImgPipeline
+[[autodoc]] StableDiffusionControlNetImg2ImgPipeline
+ - all
+ - __call__
+ - enable_attention_slicing
+ - disable_attention_slicing
+ - enable_vae_slicing
+ - disable_vae_slicing
+ - enable_xformers_memory_efficient_attention
+ - disable_xformers_memory_efficient_attention
+ - load_textual_inversion
+
+## StableDiffusionControlNetInpaintPipeline
+[[autodoc]] StableDiffusionControlNetInpaintPipeline
+ - all
+ - __call__
+ - enable_attention_slicing
+ - disable_attention_slicing
+ - enable_vae_slicing
+ - disable_vae_slicing
+ - enable_xformers_memory_efficient_attention
+ - disable_xformers_memory_efficient_attention
+ - load_textual_inversion
+
+## StableDiffusionPipelineOutput
+[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
+
+## FlaxStableDiffusionControlNetPipeline
+[[autodoc]] FlaxStableDiffusionControlNetPipeline
+ - all
+ - __call__
+
+## FlaxStableDiffusionControlNetPipelineOutput
+[[autodoc]] pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput
diff --git a/diffusers/docs/source/en/api/pipelines/controlnet_sdxl.md b/diffusers/docs/source/en/api/pipelines/controlnet_sdxl.md
new file mode 100644
index 0000000000000000000000000000000000000000..755f18341d2082b9652cbece9febe656f745afbc
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/controlnet_sdxl.md
@@ -0,0 +1,55 @@
+
+
+# ControlNet with Stable Diffusion XL
+
+ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala.
+
+With a ControlNet model, you can provide an additional control image to condition and control Stable Diffusion generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process.
+
+The abstract from the paper is:
+
+*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with "zero convolutions" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*
+
+You can find additional smaller Stable Diffusion XL (SDXL) ControlNet checkpoints from the 🤗 [Diffusers](https://huggingface.co/diffusers) Hub organization, and browse [community-trained](https://huggingface.co/models?other=stable-diffusion-xl&other=controlnet) checkpoints on the Hub.
+
+
+
+🧪 Many of the SDXL ControlNet checkpoints are experimental, and there is a lot of room for improvement. Feel free to open an [Issue](https://github.com/huggingface/diffusers/issues/new/choose) and leave us feedback on how we can improve!
+
+
+
+If you don't see a checkpoint you're interested in, you can train your own SDXL ControlNet with our [training script](../../../../../examples/controlnet/README_sdxl).
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## StableDiffusionXLControlNetPipeline
+[[autodoc]] StableDiffusionXLControlNetPipeline
+ - all
+ - __call__
+
+## StableDiffusionXLControlNetImg2ImgPipeline
+[[autodoc]] StableDiffusionXLControlNetImg2ImgPipeline
+ - all
+ - __call__
+
+## StableDiffusionXLControlNetInpaintPipeline
+[[autodoc]] StableDiffusionXLControlNetInpaintPipeline
+ - all
+ - __call__
+
+## StableDiffusionPipelineOutput
+[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
diff --git a/diffusers/docs/source/en/api/pipelines/cycle_diffusion.md b/diffusers/docs/source/en/api/pipelines/cycle_diffusion.md
new file mode 100644
index 0000000000000000000000000000000000000000..13ada0594a6a603e77a94d1ac119cebefe26aeb9
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/cycle_diffusion.md
@@ -0,0 +1,33 @@
+
+
+# Cycle Diffusion
+
+Cycle Diffusion is a text-guided image-to-image generation model proposed in [Unifying Diffusion Models' Latent Space, with Applications to CycleDiffusion and Guidance](https://huggingface.co/papers/2210.05559) by Chen Henry Wu and Fernando De la Torre.
+
+The abstract from the paper is:
+
+*Diffusion models have achieved unprecedented performance in generative modeling. The commonly-adopted formulation of the latent code of diffusion models is a sequence of gradually denoised samples, as opposed to the simpler (e.g., Gaussian) latent space of GANs, VAEs, and normalizing flows. This paper provides an alternative, Gaussian formulation of the latent space of various diffusion models, as well as an invertible DPM-Encoder that maps images into the latent space. While our formulation is purely based on the definition of diffusion models, we demonstrate several intriguing consequences. (1) Empirically, we observe that a common latent space emerges from two diffusion models trained independently on related domains. In light of this finding, we propose CycleDiffusion, which uses DPM-Encoder for unpaired image-to-image translation. Furthermore, applying CycleDiffusion to text-to-image diffusion models, we show that large-scale text-to-image diffusion models can be used as zero-shot image-to-image editors. (2) One can guide pre-trained diffusion models and GANs by controlling the latent codes in a unified, plug-and-play formulation based on energy-based models. Using the CLIP model and a face recognition model as guidance, we demonstrate that diffusion models have better coverage of low-density sub-populations and individuals than GANs. The code is publicly available at [this https URL](https://github.com/ChenWu98/cycle-diffusion).*
+
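+A minimal image-to-image editing sketch (the checkpoint, input image URL, and prompts below are illustrative; a DDIM scheduler is required because CycleDiffusion relies on DDIM inversion via the DPM-Encoder):
+
+```py
+import torch
+from diffusers import CycleDiffusionPipeline, DDIMScheduler
+from diffusers.utils import load_image
+
+model_id = "CompVis/stable-diffusion-v1-4"
+scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
+pipe = CycleDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, torch_dtype=torch.float16).to("cuda")
+
+init_image = load_image(
+    "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+).resize((512, 512))
+
+# `source_prompt` describes the input image; `prompt` describes the desired edit
+image = pipe(
+    prompt="An oil painting of snowy mountains",
+    source_prompt="A pencil sketch of mountains",
+    image=init_image,
+    num_inference_steps=100,
+    strength=0.8,
+    guidance_scale=2,
+    source_guidance_scale=1,
+    eta=0.1,
+).images[0]
+```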
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## CycleDiffusionPipeline
+[[autodoc]] CycleDiffusionPipeline
+ - all
+ - __call__
+
+## StableDiffusionPipelineOutput
+[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
diff --git a/diffusers/docs/source/en/api/pipelines/dance_diffusion.md b/diffusers/docs/source/en/api/pipelines/dance_diffusion.md
new file mode 100644
index 0000000000000000000000000000000000000000..fcf52a1ec0811e2b7ceed8fdcb110d4565d7203f
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/dance_diffusion.md
@@ -0,0 +1,32 @@
+
+
+# Dance Diffusion
+
+[Dance Diffusion](https://github.com/Harmonai-org/sample-generator) is by Zach Evans.
+
+Dance Diffusion is the first in a suite of generative audio tools for producers and musicians released by [Harmonai](https://github.com/Harmonai-org).
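+
+A minimal unconditional audio generation sketch (the `harmonai/maestro-150k` checkpoint and the output handling below are illustrative):
+
+```py
+from diffusers import DanceDiffusionPipeline
+from scipy.io import wavfile
+
+pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k").to("cuda")
+
+# Generate roughly 4 seconds of audio; the output has shape (channels, samples)
+audio = pipe(audio_length_in_s=4.0).audios[0]
+
+# The sampling rate is stored in the UNet config for Dance Diffusion checkpoints
+wavfile.write("maestro_sample.wav", rate=pipe.unet.config.sample_rate, data=audio.T.astype("float32"))
+```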
+
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## DanceDiffusionPipeline
+[[autodoc]] DanceDiffusionPipeline
+ - all
+ - __call__
+
+## AudioPipelineOutput
+[[autodoc]] pipelines.AudioPipelineOutput
diff --git a/diffusers/docs/source/en/api/pipelines/ddim.md b/diffusers/docs/source/en/api/pipelines/ddim.md
new file mode 100644
index 0000000000000000000000000000000000000000..5c876806f6001907fcd50a97407b06dded5fc9f2
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/ddim.md
@@ -0,0 +1,29 @@
+
+
+# DDIM
+
+[Denoising Diffusion Implicit Models](https://huggingface.co/papers/2010.02502) (DDIM) by Jiaming Song, Chenlin Meng and Stefano Ermon.
+
+The abstract from the paper is:
+
+*Denoising diffusion probabilistic models (DDPMs) have achieved high quality image generation without adversarial training, yet they require simulating a Markov chain for many steps to produce a sample. To accelerate sampling, we present denoising diffusion implicit models (DDIMs), a more efficient class of iterative implicit probabilistic models with the same training procedure as DDPMs. In DDPMs, the generative process is defined as the reverse of a Markovian diffusion process. We construct a class of non-Markovian diffusion processes that lead to the same training objective, but whose reverse process can be much faster to sample from. We empirically demonstrate that DDIMs can produce high quality samples 10× to 50× faster in terms of wall-clock time compared to DDPMs, allow us to trade off computation for sample quality, and can perform semantically meaningful image interpolation directly in the latent space.*
+
+The original codebase can be found at [ermongroup/ddim](https://github.com/ermongroup/ddim).
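+
+A minimal unconditional sampling sketch (the `google/ddpm-cifar10-32` checkpoint and step count are illustrative):
+
+```py
+from diffusers import DDIMPipeline
+
+pipe = DDIMPipeline.from_pretrained("google/ddpm-cifar10-32").to("cuda")
+
+# DDIM can use far fewer steps than DDPM; eta=0.0 makes sampling deterministic
+image = pipe(num_inference_steps=50, eta=0.0).images[0]
+image.save("ddim_generated_image.png")
+```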
+
+## DDIMPipeline
+[[autodoc]] DDIMPipeline
+ - all
+ - __call__
+
+## ImagePipelineOutput
+[[autodoc]] pipelines.ImagePipelineOutput
diff --git a/diffusers/docs/source/en/api/pipelines/ddpm.md b/diffusers/docs/source/en/api/pipelines/ddpm.md
new file mode 100644
index 0000000000000000000000000000000000000000..c12fbcf088dfc4dde906fd030a0165ba38d8eea9
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/ddpm.md
@@ -0,0 +1,35 @@
+
+
+# DDPM
+
+[Denoising Diffusion Probabilistic Models](https://huggingface.co/papers/2006.11239) (DDPM) by Jonathan Ho, Ajay Jain and Pieter Abbeel proposes a diffusion-based model of the same name. In the 🤗 Diffusers library, DDPM refers to the *discrete denoising scheduler* from the paper as well as the pipeline.
+
+The abstract from the paper is:
+
+*We present high quality image synthesis results using diffusion probabilistic models, a class of latent variable models inspired by considerations from nonequilibrium thermodynamics. Our best results are obtained by training on a weighted variational bound designed according to a novel connection between diffusion probabilistic models and denoising score matching with Langevin dynamics, and our models naturally admit a progressive lossy decompression scheme that can be interpreted as a generalization of autoregressive decoding. On the unconditional CIFAR10 dataset, we obtain an Inception score of 9.46 and a state-of-the-art FID score of 3.17. On 256x256 LSUN, we obtain sample quality similar to ProgressiveGAN.*
+
+The original codebase can be found at [hohonathanho/diffusion](https://github.com/hojonathanho/diffusion).
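+
+A minimal unconditional sampling sketch (the `google/ddpm-cat-256` checkpoint and output filename are illustrative):
+
+```py
+from diffusers import DDPMPipeline
+
+pipe = DDPMPipeline.from_pretrained("google/ddpm-cat-256").to("cuda")
+
+# Ancestral sampling over the full diffusion chain (1000 steps by default)
+image = pipe().images[0]
+image.save("ddpm_generated_cat.png")
+```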
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## DDPMPipeline
+[[autodoc]] DDPMPipeline
+ - all
+ - __call__
+
+## ImagePipelineOutput
+[[autodoc]] pipelines.ImagePipelineOutput
diff --git a/diffusers/docs/source/en/api/pipelines/deepfloyd_if.md b/diffusers/docs/source/en/api/pipelines/deepfloyd_if.md
new file mode 100644
index 0000000000000000000000000000000000000000..8168c65779792935a3efe529d67c3faa96888ee9
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/deepfloyd_if.md
@@ -0,0 +1,506 @@
+
+
+# DeepFloyd IF
+
+## Overview
+
+DeepFloyd IF is a novel state-of-the-art open-source text-to-image model with a high degree of photorealism and language understanding.
+The model is modular, composed of a frozen text encoder and three cascaded pixel diffusion modules:
+- Stage 1: a base model that generates a 64x64 px image from a text prompt,
+- Stage 2: a 64x64 px => 256x256 px super-resolution model, and
+- Stage 3: a 256x256 px => 1024x1024 px super-resolution model.
+
+Stage 1 and Stage 2 utilize a frozen text encoder based on the T5 transformer to extract text embeddings, which are then fed into a UNet architecture enhanced with cross-attention and attention pooling.
+Stage 3 is [Stability AI's x4 Upscaling model](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler).
+The result is a highly efficient model that outperforms current state-of-the-art models, achieving a zero-shot FID score of 6.66 on the COCO dataset.
+Our work underscores the potential of larger UNet architectures in the first stage of cascaded diffusion models and depicts a promising future for text-to-image synthesis.
+
+## Usage
+
+Before you can use IF, you need to accept its usage conditions. To do so:
+1. Make sure to have a [Hugging Face account](https://huggingface.co/join) and be logged in.
+2. Accept the license on the model card of [DeepFloyd/IF-I-XL-v1.0](https://huggingface.co/DeepFloyd/IF-I-XL-v1.0). Accepting the license on the stage I model card automatically accepts it for the other IF models.
+3. Make sure to log in locally. Install `huggingface_hub`:
+```sh
+pip install huggingface_hub --upgrade
+```
+
+run the login function in a Python shell:
+
+```py
+from huggingface_hub import login
+
+login()
+```
+
+and enter your [Hugging Face Hub access token](https://huggingface.co/docs/hub/security-tokens#what-are-user-access-tokens).
+
+Next we install `diffusers` and dependencies:
+
+```sh
+pip install -q diffusers accelerate transformers
+```
+
+The following sections give more detailed examples of how to use IF. Specifically:
+
+- [Text-to-Image Generation](#text-to-image-generation)
+- [Image-to-Image Generation](#text-guided-image-to-image-generation)
+- [Inpainting](#text-guided-inpainting-generation)
+- [Reusing model weights](#converting-between-different-pipelines)
+- [Speed optimization](#optimizing-for-speed)
+- [Memory optimization](#optimizing-for-memory)
+
+**Available checkpoints**
+- *Stage-1*
+ - [DeepFloyd/IF-I-XL-v1.0](https://huggingface.co/DeepFloyd/IF-I-XL-v1.0)
+ - [DeepFloyd/IF-I-L-v1.0](https://huggingface.co/DeepFloyd/IF-I-L-v1.0)
+ - [DeepFloyd/IF-I-M-v1.0](https://huggingface.co/DeepFloyd/IF-I-M-v1.0)
+
+- *Stage-2*
+ - [DeepFloyd/IF-II-L-v1.0](https://huggingface.co/DeepFloyd/IF-II-L-v1.0)
+ - [DeepFloyd/IF-II-M-v1.0](https://huggingface.co/DeepFloyd/IF-II-M-v1.0)
+
+- *Stage-3*
+ - [stabilityai/stable-diffusion-x4-upscaler](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler)
+
+
+**Google Colab**
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/deepfloyd_if_free_tier_google_colab.ipynb)
+
+### Text-to-Image Generation
+
+By default, Diffusers uses [model CPU offloading](../../optimization/memory#model-offloading) to run the whole IF pipeline with as little as 14 GB of VRAM.
+
+```python
+from diffusers import DiffusionPipeline
+from diffusers.utils import pt_to_pil, make_image_grid
+import torch
+
+# stage 1
+stage_1 = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
+stage_1.enable_model_cpu_offload()
+
+# stage 2
+stage_2 = DiffusionPipeline.from_pretrained(
+ "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16
+)
+stage_2.enable_model_cpu_offload()
+
+# stage 3
+safety_modules = {
+ "feature_extractor": stage_1.feature_extractor,
+ "safety_checker": stage_1.safety_checker,
+ "watermarker": stage_1.watermarker,
+}
+stage_3 = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-x4-upscaler", **safety_modules, torch_dtype=torch.float16
+)
+stage_3.enable_model_cpu_offload()
+
+prompt = 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says "very deep learning"'
+generator = torch.manual_seed(1)
+
+# text embeds
+prompt_embeds, negative_embeds = stage_1.encode_prompt(prompt)
+
+# stage 1
+stage_1_output = stage_1(
+ prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, output_type="pt"
+).images
+#pt_to_pil(stage_1_output)[0].save("./if_stage_I.png")
+
+# stage 2
+stage_2_output = stage_2(
+ image=stage_1_output,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_embeds,
+ generator=generator,
+ output_type="pt",
+).images
+#pt_to_pil(stage_2_output)[0].save("./if_stage_II.png")
+
+# stage 3
+stage_3_output = stage_3(prompt=prompt, image=stage_2_output, noise_level=100, generator=generator).images
+#stage_3_output[0].save("./if_stage_III.png")
+make_image_grid([pt_to_pil(stage_1_output)[0], pt_to_pil(stage_2_output)[0], stage_3_output[0]], rows=1, cols=3)
+```
+
+### Text Guided Image-to-Image Generation
+
+The same IF model weights can be used for text-guided image-to-image translation or image variation.
+In this case just make sure to load the weights using the [`IFImg2ImgPipeline`] and [`IFImg2ImgSuperResolutionPipeline`] pipelines.
+
+**Note**: You can also directly move the weights of the text-to-image pipelines to the image-to-image pipelines
+without loading them twice by making use of the [`~DiffusionPipeline.components`] attribute as explained [here](#converting-between-different-pipelines).
+
+```python
+from diffusers import IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, DiffusionPipeline
+from diffusers.utils import pt_to_pil, load_image, make_image_grid
+import torch
+
+# download image
+url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+original_image = load_image(url)
+original_image = original_image.resize((768, 512))
+
+# stage 1
+stage_1 = IFImg2ImgPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
+stage_1.enable_model_cpu_offload()
+
+# stage 2
+stage_2 = IFImg2ImgSuperResolutionPipeline.from_pretrained(
+ "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16
+)
+stage_2.enable_model_cpu_offload()
+
+# stage 3
+safety_modules = {
+ "feature_extractor": stage_1.feature_extractor,
+ "safety_checker": stage_1.safety_checker,
+ "watermarker": stage_1.watermarker,
+}
+stage_3 = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-x4-upscaler", **safety_modules, torch_dtype=torch.float16
+)
+stage_3.enable_model_cpu_offload()
+
+prompt = "A fantasy landscape in style minecraft"
+generator = torch.manual_seed(1)
+
+# text embeds
+prompt_embeds, negative_embeds = stage_1.encode_prompt(prompt)
+
+# stage 1
+stage_1_output = stage_1(
+ image=original_image,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_embeds,
+ generator=generator,
+ output_type="pt",
+).images
+#pt_to_pil(stage_1_output)[0].save("./if_stage_I.png")
+
+# stage 2
+stage_2_output = stage_2(
+ image=stage_1_output,
+ original_image=original_image,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_embeds,
+ generator=generator,
+ output_type="pt",
+).images
+#pt_to_pil(stage_2_output)[0].save("./if_stage_II.png")
+
+# stage 3
+stage_3_output = stage_3(prompt=prompt, image=stage_2_output, generator=generator, noise_level=100).images
+#stage_3_output[0].save("./if_stage_III.png")
+make_image_grid([original_image, pt_to_pil(stage_1_output)[0], pt_to_pil(stage_2_output)[0], stage_3_output[0]], rows=1, cols=4)
+```
+
+### Text Guided Inpainting Generation
+
+The same IF model weights can also be used for text-guided inpainting.
+In this case just make sure to load the weights using the [`IFInpaintingPipeline`] and [`IFInpaintingSuperResolutionPipeline`] pipelines.
+
+**Note**: You can also directly move the weights of the text-to-image pipelines to the inpainting pipelines
+without loading them twice by making use of the [`~DiffusionPipeline.components`] attribute as explained [here](#converting-between-different-pipelines).
+
+```python
+from diffusers import IFInpaintingPipeline, IFInpaintingSuperResolutionPipeline, DiffusionPipeline
+from diffusers.utils import pt_to_pil, load_image, make_image_grid
+import torch
+
+# download image
+url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/if/person.png"
+original_image = load_image(url)
+
+# download mask
+url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/if/glasses_mask.png"
+mask_image = load_image(url)
+
+# stage 1
+stage_1 = IFInpaintingPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
+stage_1.enable_model_cpu_offload()
+
+# stage 2
+stage_2 = IFInpaintingSuperResolutionPipeline.from_pretrained(
+ "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16
+)
+stage_2.enable_model_cpu_offload()
+
+# stage 3
+safety_modules = {
+ "feature_extractor": stage_1.feature_extractor,
+ "safety_checker": stage_1.safety_checker,
+ "watermarker": stage_1.watermarker,
+}
+stage_3 = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-x4-upscaler", **safety_modules, torch_dtype=torch.float16
+)
+stage_3.enable_model_cpu_offload()
+
+prompt = "blue sunglasses"
+generator = torch.manual_seed(1)
+
+# text embeds
+prompt_embeds, negative_embeds = stage_1.encode_prompt(prompt)
+
+# stage 1
+stage_1_output = stage_1(
+ image=original_image,
+ mask_image=mask_image,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_embeds,
+ generator=generator,
+ output_type="pt",
+).images
+#pt_to_pil(stage_1_output)[0].save("./if_stage_I.png")
+
+# stage 2
+stage_2_output = stage_2(
+ image=stage_1_output,
+ original_image=original_image,
+ mask_image=mask_image,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_embeds,
+ generator=generator,
+ output_type="pt",
+).images
+#pt_to_pil(stage_2_output)[0].save("./if_stage_II.png")
+
+# stage 3
+stage_3_output = stage_3(prompt=prompt, image=stage_2_output, generator=generator, noise_level=100).images
+#stage_3_output[0].save("./if_stage_III.png")
+make_image_grid([original_image, mask_image, pt_to_pil(stage_1_output)[0], pt_to_pil(stage_2_output)[0], stage_3_output[0]], rows=1, cols=5)
+```
+
+### Converting between different pipelines
+
+In addition to being loaded with `from_pretrained`, pipelines can also be instantiated directly from each other.
+
+```python
+from diffusers import IFPipeline, IFSuperResolutionPipeline
+
+pipe_1 = IFPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0")
+pipe_2 = IFSuperResolutionPipeline.from_pretrained("DeepFloyd/IF-II-L-v1.0")
+
+
+from diffusers import IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline
+
+pipe_1 = IFImg2ImgPipeline(**pipe_1.components)
+pipe_2 = IFImg2ImgSuperResolutionPipeline(**pipe_2.components)
+
+
+from diffusers import IFInpaintingPipeline, IFInpaintingSuperResolutionPipeline
+
+pipe_1 = IFInpaintingPipeline(**pipe_1.components)
+pipe_2 = IFInpaintingSuperResolutionPipeline(**pipe_2.components)
+```
+
+### Optimizing for speed
+
+The simplest optimization to run IF faster is to move all model components to the GPU.
+
+```py
+pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
+pipe.to("cuda")
+```
+
+You can also run the diffusion process for fewer timesteps.
+
+This can either be done with the `num_inference_steps` argument:
+
+```py
+pipe("", num_inference_steps=30)
+```
+
+Or with the `timesteps` argument:
+
+```py
+from diffusers.pipelines.deepfloyd_if import fast27_timesteps
+
+pipe("", timesteps=fast27_timesteps)
+```
+
+When doing image variation or inpainting, you can also decrease the number of timesteps
+with the `strength` argument. The `strength` argument controls how much noise is added to the input image, which also determines how many steps to run in the denoising process.
+A smaller value varies the image less but runs faster.
+
+```py
+pipe = IFImg2ImgPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
+pipe.to("cuda")
+
+image = pipe(image=image, prompt="", strength=0.3).images
+```
+
+You can also use [`torch.compile`](../../optimization/torch2.0). Note that we have not exhaustively tested `torch.compile`
+with IF and it might not give expected results.
+
+```py
+from diffusers import DiffusionPipeline
+import torch
+
+pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
+pipe.to("cuda")
+
+pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)
+pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+```
+
+### Optimizing for memory
+
+When optimizing for GPU memory, you can use the standard Diffusers CPU offloading APIs.
+
+Either model-based CPU offloading,
+
+```py
+pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
+pipe.enable_model_cpu_offload()
+```
+
+or the more aggressive layer-based (sequential) CPU offloading.
+
+```py
+pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
+pipe.enable_sequential_cpu_offload()
+```
+
+Additionally, the T5 text encoder can be loaded in 8-bit precision:
+
+```py
+from transformers import T5EncoderModel
+
+text_encoder = T5EncoderModel.from_pretrained(
+ "DeepFloyd/IF-I-XL-v1.0", subfolder="text_encoder", device_map="auto", load_in_8bit=True, variant="8bit"
+)
+
+from diffusers import DiffusionPipeline
+
+pipe = DiffusionPipeline.from_pretrained(
+ "DeepFloyd/IF-I-XL-v1.0",
+ text_encoder=text_encoder, # pass the previously instantiated 8bit text encoder
+ unet=None,
+ device_map="auto",
+)
+
+prompt_embeds, negative_embeds = pipe.encode_prompt("")
+```
+
+For CPU RAM-constrained machines like the free tier of Google Colab, where all the model components can't be loaded to the CPU at once, you can manually load the pipeline with
+only the text encoder or UNet when the respective model component is needed.
+
+```py
+from diffusers import DiffusionPipeline, IFPipeline, IFSuperResolutionPipeline
+import torch
+import gc
+from transformers import T5EncoderModel
+from diffusers.utils import pt_to_pil, make_image_grid
+
+text_encoder = T5EncoderModel.from_pretrained(
+ "DeepFloyd/IF-I-XL-v1.0", subfolder="text_encoder", device_map="auto", load_in_8bit=True, variant="8bit"
+)
+
+# text to image
+pipe = DiffusionPipeline.from_pretrained(
+ "DeepFloyd/IF-I-XL-v1.0",
+ text_encoder=text_encoder, # pass the previously instantiated 8bit text encoder
+ unet=None,
+ device_map="auto",
+)
+
+prompt = 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says "very deep learning"'
+prompt_embeds, negative_embeds = pipe.encode_prompt(prompt)
+
+# Remove the text encoder and pipeline so we can re-load the pipeline with the UNet
+del text_encoder
+del pipe
+gc.collect()
+torch.cuda.empty_cache()
+
+pipe = IFPipeline.from_pretrained(
+ "DeepFloyd/IF-I-XL-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16, device_map="auto"
+)
+
+generator = torch.Generator().manual_seed(0)
+stage_1_output = pipe(
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_embeds,
+ output_type="pt",
+ generator=generator,
+).images
+
+#pt_to_pil(stage_1_output)[0].save("./if_stage_I.png")
+
+# Remove the pipeline so we can load the super-resolution pipeline
+del pipe
+gc.collect()
+torch.cuda.empty_cache()
+
+# First super resolution
+
+pipe = IFSuperResolutionPipeline.from_pretrained(
+ "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16, device_map="auto"
+)
+
+generator = torch.Generator().manual_seed(0)
+stage_2_output = pipe(
+ image=stage_1_output,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_embeds,
+ output_type="pt",
+ generator=generator,
+).images
+
+#pt_to_pil(stage_2_output)[0].save("./if_stage_II.png")
+make_image_grid([pt_to_pil(stage_1_output)[0], pt_to_pil(stage_2_output)[0]], rows=1, cols=2)
+```
+
+## Available Pipelines:
+
+| Pipeline | Tasks | Colab |
+|---|---|:---:|
+| [pipeline_if.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py) | *Text-to-Image Generation* | - |
+| [pipeline_if_superresolution.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py) | *Text-to-Image Generation* | - |
+| [pipeline_if_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py) | *Image-to-Image Generation* | - |
+| [pipeline_if_img2img_superresolution.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py) | *Image-to-Image Generation* | - |
+| [pipeline_if_inpainting.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py) | *Inpainting* | - |
+| [pipeline_if_inpainting_superresolution.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py) | *Inpainting* | - |
+
+## IFPipeline
+[[autodoc]] IFPipeline
+ - all
+ - __call__
+
+## IFSuperResolutionPipeline
+[[autodoc]] IFSuperResolutionPipeline
+ - all
+ - __call__
+
+## IFImg2ImgPipeline
+[[autodoc]] IFImg2ImgPipeline
+ - all
+ - __call__
+
+## IFImg2ImgSuperResolutionPipeline
+[[autodoc]] IFImg2ImgSuperResolutionPipeline
+ - all
+ - __call__
+
+## IFInpaintingPipeline
+[[autodoc]] IFInpaintingPipeline
+ - all
+ - __call__
+
+## IFInpaintingSuperResolutionPipeline
+[[autodoc]] IFInpaintingSuperResolutionPipeline
+ - all
+ - __call__
diff --git a/diffusers/docs/source/en/api/pipelines/diffedit.md b/diffusers/docs/source/en/api/pipelines/diffedit.md
new file mode 100644
index 0000000000000000000000000000000000000000..7ab6ab2391e9437f2112cc9a8452abd8ae1734de
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/diffedit.md
@@ -0,0 +1,55 @@
+
+
+# DiffEdit
+
+[DiffEdit: Diffusion-based semantic image editing with mask guidance](https://huggingface.co/papers/2210.11427) is by Guillaume Couairon, Jakob Verbeek, Holger Schwenk, and Matthieu Cord.
+
+The abstract from the paper is:
+
+*Image generation has recently seen tremendous advances, with diffusion models allowing to synthesize convincing images for a large variety of text prompts. In this article, we propose DiffEdit, a method to take advantage of text-conditioned diffusion models for the task of semantic image editing, where the goal is to edit an image based on a text query. Semantic image editing is an extension of image generation, with the additional constraint that the generated image should be as similar as possible to a given input image. Current editing methods based on diffusion models usually require to provide a mask, making the task much easier by treating it as a conditional inpainting task. In contrast, our main contribution is able to automatically generate a mask highlighting regions of the input image that need to be edited, by contrasting predictions of a diffusion model conditioned on different text prompts. Moreover, we rely on latent inference to preserve content in those regions of interest and show excellent synergies with mask-based diffusion. DiffEdit achieves state-of-the-art editing performance on ImageNet. In addition, we evaluate semantic image editing in more challenging settings, using images from the COCO dataset as well as text-based generated images.*
+
+The original codebase can be found at [Xiang-cd/DiffEdit-stable-diffusion](https://github.com/Xiang-cd/DiffEdit-stable-diffusion), and you can try it out in this [demo](https://blog.problemsolversguild.com/technical/research/2022/11/02/DiffEdit-Implementation.html).
+
+This pipeline was contributed by [clarencechen](https://github.com/clarencechen). ❤️
+
+## Tips
+
+* The pipeline can generate masks that can be fed into other inpainting pipelines.
+* In order to generate an image using this pipeline, both an image mask (generated with [`~StableDiffusionDiffEditPipeline.generate_mask`] from manually specified or generated source and target prompts)
+and a set of partially inverted latents (generated using [`~StableDiffusionDiffEditPipeline.invert`]) _must_ be provided as arguments when calling the pipeline to generate the final edited image.
+* The function [`~StableDiffusionDiffEditPipeline.generate_mask`] exposes two prompt arguments, `source_prompt` and `target_prompt`
+that let you control the locations of the semantic edits in the final image to be generated. Let's say,
+you wanted to translate from "cat" to "dog". In this case, the edit direction will be "cat -> dog". To reflect
+this in the generated mask, you simply have to set the embeddings related to the phrases including "cat" to
+`source_prompt` and "dog" to `target_prompt`.
+* When generating partially inverted latents using `invert`, assign a caption or text embedding describing the
+overall image to the `prompt` argument to help guide the inverse latent sampling process. In most cases, the
+source concept is sufficiently descriptive to yield good results, but feel free to explore alternatives.
+* When calling the pipeline to generate the final edited image, assign the source concept to `negative_prompt`
+and the target concept to `prompt`. Taking the above example, you simply have to set the embeddings related to
+the phrases including "cat" to `negative_prompt` and "dog" to `prompt`.
+* If you wanted to reverse the direction in the example above, i.e., "dog -> cat", then it's recommended to:
+ * Swap the `source_prompt` and `target_prompt` in the arguments to `generate_mask`.
+ * Change the input prompt in [`~StableDiffusionDiffEditPipeline.invert`] to include "dog".
+ * Swap the `prompt` and `negative_prompt` in the arguments to call the pipeline to generate the final edited image.
+* The source and target prompts, or their corresponding embeddings, can also be automatically generated. Please refer to the [DiffEdit](../../using-diffusers/diffedit) guide for more details.
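+
+The tips above map onto three calls: `generate_mask`, `invert`, and the pipeline itself. A minimal sketch (the checkpoint, image URL, and prompts below are illustrative):
+
+```py
+import torch
+from diffusers import DDIMInverseScheduler, DDIMScheduler, StableDiffusionDiffEditPipeline
+from diffusers.utils import load_image
+
+pipe = StableDiffusionDiffEditPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
+).to("cuda")
+pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)
+
+raw_image = load_image(
+    "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png"
+).resize((768, 768))
+source_prompt = "a bowl of fruits"
+target_prompt = "a bowl of pears"
+
+# 1. Contrast predictions for the source and target prompts to obtain an edit mask
+mask_image = pipe.generate_mask(image=raw_image, source_prompt=source_prompt, target_prompt=target_prompt)
+
+# 2. Partially invert the input image, guided by a caption of its original content
+inv_latents = pipe.invert(prompt=source_prompt, image=raw_image).latents
+
+# 3. Denoise towards the target concept while preserving the unmasked regions
+image = pipe(
+    prompt=target_prompt,
+    mask_image=mask_image,
+    image_latents=inv_latents,
+    negative_prompt=source_prompt,
+).images[0]
+```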
+
+## StableDiffusionDiffEditPipeline
+[[autodoc]] StableDiffusionDiffEditPipeline
+ - all
+ - generate_mask
+ - invert
+ - __call__
+
+## StableDiffusionPipelineOutput
+[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
diff --git a/diffusers/docs/source/en/api/pipelines/dit.md b/diffusers/docs/source/en/api/pipelines/dit.md
new file mode 100644
index 0000000000000000000000000000000000000000..9e49a3bd68e734046aa334cd3a0e6ec7065e39cb
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/dit.md
@@ -0,0 +1,35 @@
+
+
+# DiT
+
+[Scalable Diffusion Models with Transformers](https://huggingface.co/papers/2212.09748) (DiT) is by William Peebles and Saining Xie.
+
+The abstract from the paper is:
+
+*We explore a new class of diffusion models based on the transformer architecture. We train latent diffusion models of images, replacing the commonly-used U-Net backbone with a transformer that operates on latent patches. We analyze the scalability of our Diffusion Transformers (DiTs) through the lens of forward pass complexity as measured by Gflops. We find that DiTs with higher Gflops -- through increased transformer depth/width or increased number of input tokens -- consistently have lower FID. In addition to possessing good scalability properties, our largest DiT-XL/2 models outperform all prior diffusion models on the class-conditional ImageNet 512x512 and 256x256 benchmarks, achieving a state-of-the-art FID of 2.27 on the latter.*
+
+The original codebase can be found at [facebookresearch/dit](https://github.com/facebookresearch/dit).
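+
+A minimal class-conditional sampling sketch (the checkpoint, labels, and scheduler swap below are illustrative):
+
+```py
+import torch
+from diffusers import DiTPipeline, DPMSolverMultistepScheduler
+
+pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", torch_dtype=torch.float16).to("cuda")
+pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
+
+# DiT is class-conditional on ImageNet; map human-readable labels to class ids
+class_ids = pipe.get_label_ids(["white shark", "golden retriever"])
+
+generator = torch.manual_seed(33)
+images = pipe(class_labels=class_ids, num_inference_steps=25, generator=generator).images
+```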
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## DiTPipeline
+[[autodoc]] DiTPipeline
+ - all
+ - __call__
+
+## ImagePipelineOutput
+[[autodoc]] pipelines.ImagePipelineOutput
diff --git a/diffusers/docs/source/en/api/pipelines/kandinsky.md b/diffusers/docs/source/en/api/pipelines/kandinsky.md
new file mode 100644
index 0000000000000000000000000000000000000000..12073d4a14e79f74fa87982a05ebe53cb0e9d135
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/kandinsky.md
@@ -0,0 +1,73 @@
+
+
+# Kandinsky 2.1
+
+Kandinsky 2.1 is created by [Arseniy Shakhmatov](https://github.com/cene555), [Anton Razzhigaev](https://github.com/razzant), [Aleksandr Nikolich](https://github.com/AlexWortega), [Vladimir Arkhipkin](https://github.com/oriBetelgeuse), [Igor Pavlov](https://github.com/boomb0om), [Andrey Kuznetsov](https://github.com/kuznetsoffandrey), and [Denis Dimitrov](https://github.com/denndimitrov).
+
+The description from its GitHub page is:
+
+*Kandinsky 2.1 inherits best practicies from Dall-E 2 and Latent diffusion, while introducing some new ideas. As text and image encoder it uses CLIP model and diffusion image prior (mapping) between latent spaces of CLIP modalities. This approach increases the visual performance of the model and unveils new horizons in blending images and text-guided image manipulation.*
+
+The original codebase can be found at [ai-forever/Kandinsky-2](https://github.com/ai-forever/Kandinsky-2).
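+
+A minimal two-stage text-to-image sketch showing the diffusion image prior described above followed by the decoder (the checkpoints and prompt are illustrative; the [`KandinskyCombinedPipeline`] wraps both stages in a single call):
+
+```py
+import torch
+from diffusers import KandinskyPriorPipeline, KandinskyPipeline
+
+prompt = "A portrait of a red cat wearing a top hat, 4k photo"
+
+# Stage 1: the prior maps the text prompt to a CLIP image embedding
+prior = KandinskyPriorPipeline.from_pretrained(
+    "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16
+).to("cuda")
+image_embeds, negative_image_embeds = prior(prompt).to_tuple()
+
+# Stage 2: the decoder turns the image embedding (plus the text) into pixels
+pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16).to("cuda")
+image = pipe(
+    prompt,
+    image_embeds=image_embeds,
+    negative_image_embeds=negative_image_embeds,
+    height=768,
+    width=768,
+).images[0]
+```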
+
+
+
+Check out the [Kandinsky Community](https://huggingface.co/kandinsky-community) organization on the Hub for the official model checkpoints for tasks like text-to-image, image-to-image, and inpainting.
+
+
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## KandinskyPriorPipeline
+
+[[autodoc]] KandinskyPriorPipeline
+ - all
+ - __call__
+ - interpolate
+
+## KandinskyPipeline
+
+[[autodoc]] KandinskyPipeline
+ - all
+ - __call__
+
+## KandinskyCombinedPipeline
+
+[[autodoc]] KandinskyCombinedPipeline
+ - all
+ - __call__
+
+## KandinskyImg2ImgPipeline
+
+[[autodoc]] KandinskyImg2ImgPipeline
+ - all
+ - __call__
+
+## KandinskyImg2ImgCombinedPipeline
+
+[[autodoc]] KandinskyImg2ImgCombinedPipeline
+ - all
+ - __call__
+
+## KandinskyInpaintPipeline
+
+[[autodoc]] KandinskyInpaintPipeline
+ - all
+ - __call__
+
+## KandinskyInpaintCombinedPipeline
+
+[[autodoc]] KandinskyInpaintCombinedPipeline
+ - all
+ - __call__
diff --git a/diffusers/docs/source/en/api/pipelines/kandinsky_v22.md b/diffusers/docs/source/en/api/pipelines/kandinsky_v22.md
new file mode 100644
index 0000000000000000000000000000000000000000..3a32eb42412a30c0a4c0108c7a579762abb3ccb1
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/kandinsky_v22.md
@@ -0,0 +1,92 @@
+
+
+# Kandinsky 2.2
+
+Kandinsky 2.2 is created by [Arseniy Shakhmatov](https://github.com/cene555), [Anton Razzhigaev](https://github.com/razzant), [Aleksandr Nikolich](https://github.com/AlexWortega), [Vladimir Arkhipkin](https://github.com/oriBetelgeuse), [Igor Pavlov](https://github.com/boomb0om), [Andrey Kuznetsov](https://github.com/kuznetsoffandrey), and [Denis Dimitrov](https://github.com/denndimitrov).
+
+The description from its GitHub page is:
+
+*Kandinsky 2.2 brings substantial improvements upon its predecessor, Kandinsky 2.1, by introducing a new, more powerful image encoder - CLIP-ViT-G and the ControlNet support. The switch to CLIP-ViT-G as the image encoder significantly increases the model's capability to generate more aesthetic pictures and better understand text, thus enhancing the model's overall performance. The addition of the ControlNet mechanism allows the model to effectively control the process of generating images. This leads to more accurate and visually appealing outputs and opens new possibilities for text-guided image manipulation.*
+
+The original codebase can be found at [ai-forever/Kandinsky-2](https://github.com/ai-forever/Kandinsky-2).
+
+
+
+Check out the [Kandinsky Community](https://huggingface.co/kandinsky-community) organization on the Hub for the official model checkpoints for tasks like text-to-image, image-to-image, and inpainting.
+
+
+
+
+
+Make sure to check out the schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## KandinskyV22PriorPipeline
+
+[[autodoc]] KandinskyV22PriorPipeline
+ - all
+ - __call__
+ - interpolate
+
+## KandinskyV22Pipeline
+
+[[autodoc]] KandinskyV22Pipeline
+ - all
+ - __call__
+
+## KandinskyV22CombinedPipeline
+
+[[autodoc]] KandinskyV22CombinedPipeline
+ - all
+ - __call__
+
+## KandinskyV22ControlnetPipeline
+
+[[autodoc]] KandinskyV22ControlnetPipeline
+ - all
+ - __call__
+
+## KandinskyV22PriorEmb2EmbPipeline
+
+[[autodoc]] KandinskyV22PriorEmb2EmbPipeline
+ - all
+ - __call__
+ - interpolate
+
+## KandinskyV22Img2ImgPipeline
+
+[[autodoc]] KandinskyV22Img2ImgPipeline
+ - all
+ - __call__
+
+## KandinskyV22Img2ImgCombinedPipeline
+
+[[autodoc]] KandinskyV22Img2ImgCombinedPipeline
+ - all
+ - __call__
+
+## KandinskyV22ControlnetImg2ImgPipeline
+
+[[autodoc]] KandinskyV22ControlnetImg2ImgPipeline
+ - all
+ - __call__
+
+## KandinskyV22InpaintPipeline
+
+[[autodoc]] KandinskyV22InpaintPipeline
+ - all
+ - __call__
+
+## KandinskyV22InpaintCombinedPipeline
+
+[[autodoc]] KandinskyV22InpaintCombinedPipeline
+ - all
+ - __call__
diff --git a/diffusers/docs/source/en/api/pipelines/latent_consistency_models.md b/diffusers/docs/source/en/api/pipelines/latent_consistency_models.md
new file mode 100644
index 0000000000000000000000000000000000000000..e5d4beba2bed19eeb7667fc2f310de3ffa2736af
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/latent_consistency_models.md
@@ -0,0 +1,52 @@
+
+
+# Latent Consistency Models
+
+Latent Consistency Models (LCMs) were proposed in [Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://huggingface.co/papers/2310.04378) by Simian Luo, Yiqin Tan, Longbo Huang, Jian Li, and Hang Zhao.
+
+The abstract of the paper is as follows:
+
+*Latent Diffusion models (LDMs) have achieved remarkable results in synthesizing high-resolution images. However, the iterative sampling process is computationally intensive and leads to slow generation. Inspired by Consistency Models (song et al.), we propose Latent Consistency Models (LCMs), enabling swift inference with minimal steps on any pre-trained LDMs, including Stable Diffusion (rombach et al). Viewing the guided reverse diffusion process as solving an augmented probability flow ODE (PF-ODE), LCMs are designed to directly predict the solution of such ODE in latent space, mitigating the need for numerous iterations and allowing rapid, high-fidelity sampling. Efficiently distilled from pre-trained classifier-free guided diffusion models, a high-quality 768 x 768 2~4-step LCM takes only 32 A100 GPU hours for training. Furthermore, we introduce Latent Consistency Fine-tuning (LCF), a novel method that is tailored for fine-tuning LCMs on customized image datasets. Evaluation on the LAION-5B-Aesthetics dataset demonstrates that LCMs achieve state-of-the-art text-to-image generation performance with few-step inference. Project Page: [this https URL](https://latent-consistency-models.github.io/).*
+
+A demo for the [SimianLuo/LCM_Dreamshaper_v7](https://huggingface.co/SimianLuo/LCM_Dreamshaper_v7) checkpoint can be found [here](https://huggingface.co/spaces/SimianLuo/Latent_Consistency_Model).
+
+The pipelines were contributed by [luosiallen](https://luosiallen.github.io/), [nagolinc](https://github.com/nagolinc), and [dg845](https://github.com/dg845).
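+
+Because LCMs are distilled for few-step sampling, inference typically uses only a handful of steps. A minimal text-to-image sketch with the checkpoint above (the prompt and settings are illustrative):
+
+```python
+import torch
+from diffusers import LatentConsistencyModelPipeline
+
+pipe = LatentConsistencyModelPipeline.from_pretrained(
+    "SimianLuo/LCM_Dreamshaper_v7", torch_dtype=torch.float16
+).to("cuda")
+
+prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
+
+# LCMs are designed to generate good samples in very few steps (typically 1-8)
+image = pipe(prompt, num_inference_steps=4, guidance_scale=8.0).images[0]
+```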
+
+
+## LatentConsistencyModelPipeline
+
+[[autodoc]] LatentConsistencyModelPipeline
+ - all
+ - __call__
+ - enable_freeu
+ - disable_freeu
+ - enable_vae_slicing
+ - disable_vae_slicing
+ - enable_vae_tiling
+ - disable_vae_tiling
+
+## LatentConsistencyModelImg2ImgPipeline
+
+[[autodoc]] LatentConsistencyModelImg2ImgPipeline
+ - all
+ - __call__
+ - enable_freeu
+ - disable_freeu
+ - enable_vae_slicing
+ - disable_vae_slicing
+ - enable_vae_tiling
+ - disable_vae_tiling
+
+## StableDiffusionPipelineOutput
+
+[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
diff --git a/diffusers/docs/source/en/api/pipelines/latent_diffusion.md b/diffusers/docs/source/en/api/pipelines/latent_diffusion.md
new file mode 100644
index 0000000000000000000000000000000000000000..de6f96bea19a8d666b5355c6137edc2233f1a13c
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/latent_diffusion.md
@@ -0,0 +1,40 @@
+
+
+# Latent Diffusion
+
+Latent Diffusion was proposed in [High-Resolution Image Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) by Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, Björn Ommer.
+
+The abstract from the paper is:
+
+*By decomposing the image formation process into a sequential application of denoising autoencoders, diffusion models (DMs) achieve state-of-the-art synthesis results on image data and beyond. Additionally, their formulation allows for a guiding mechanism to control the image generation process without retraining. However, since these models typically operate directly in pixel space, optimization of powerful DMs often consumes hundreds of GPU days and inference is expensive due to sequential evaluations. To enable DM training on limited computational resources while retaining their quality and flexibility, we apply them in the latent space of powerful pretrained autoencoders. In contrast to previous work, training diffusion models on such a representation allows for the first time to reach a near-optimal point between complexity reduction and detail preservation, greatly boosting visual fidelity. By introducing cross-attention layers into the model architecture, we turn diffusion models into powerful and flexible generators for general conditioning inputs such as text or bounding boxes and high-resolution synthesis becomes possible in a convolutional manner. Our latent diffusion models (LDMs) achieve a new state of the art for image inpainting and highly competitive performance on various tasks, including unconditional image generation, semantic scene synthesis, and super-resolution, while significantly reducing computational requirements compared to pixel-based DMs.*
+
+The original codebase can be found at [CompVis/latent-diffusion](https://github.com/CompVis/latent-diffusion).
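+
+A minimal text-to-image sketch, assuming the `CompVis/ldm-text2im-large-256` checkpoint from the Hub (the prompt and sampling settings are illustrative):
+
+```python
+from diffusers import LDMTextToImagePipeline
+
+pipe = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
+pipe = pipe.to("cuda")
+
+prompt = "an oil painting of a lighthouse at sunset"
+image = pipe(prompt, num_inference_steps=50, guidance_scale=6.0, eta=0.3).images[0]
+```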
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## LDMTextToImagePipeline
+[[autodoc]] LDMTextToImagePipeline
+ - all
+ - __call__
+
+## LDMSuperResolutionPipeline
+[[autodoc]] LDMSuperResolutionPipeline
+ - all
+ - __call__
+
+## ImagePipelineOutput
+[[autodoc]] pipelines.ImagePipelineOutput
diff --git a/diffusers/docs/source/en/api/pipelines/latent_diffusion_uncond.md b/diffusers/docs/source/en/api/pipelines/latent_diffusion_uncond.md
new file mode 100644
index 0000000000000000000000000000000000000000..54835c2115b9567cc841f4b9f3f0a59c261a8b1c
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/latent_diffusion_uncond.md
@@ -0,0 +1,35 @@
+
+
+# Unconditional Latent Diffusion
+
+Unconditional Latent Diffusion was proposed in [High-Resolution Image Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) by Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, Björn Ommer.
+
+The abstract from the paper is:
+
+*By decomposing the image formation process into a sequential application of denoising autoencoders, diffusion models (DMs) achieve state-of-the-art synthesis results on image data and beyond. Additionally, their formulation allows for a guiding mechanism to control the image generation process without retraining. However, since these models typically operate directly in pixel space, optimization of powerful DMs often consumes hundreds of GPU days and inference is expensive due to sequential evaluations. To enable DM training on limited computational resources while retaining their quality and flexibility, we apply them in the latent space of powerful pretrained autoencoders. In contrast to previous work, training diffusion models on such a representation allows for the first time to reach a near-optimal point between complexity reduction and detail preservation, greatly boosting visual fidelity. By introducing cross-attention layers into the model architecture, we turn diffusion models into powerful and flexible generators for general conditioning inputs such as text or bounding boxes and high-resolution synthesis becomes possible in a convolutional manner. Our latent diffusion models (LDMs) achieve a new state of the art for image inpainting and highly competitive performance on various tasks, including unconditional image generation, semantic scene synthesis, and super-resolution, while significantly reducing computational requirements compared to pixel-based DMs.*
+
+The original codebase can be found at [CompVis/latent-diffusion](https://github.com/CompVis/latent-diffusion).
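+
+A minimal unconditional sampling sketch, assuming the `CompVis/ldm-celebahq-256` checkpoint from the Hub (the number of steps is illustrative):
+
+```python
+from diffusers import LDMPipeline
+
+pipe = LDMPipeline.from_pretrained("CompVis/ldm-celebahq-256")
+pipe = pipe.to("cuda")
+
+# Unconditional generation: no prompt, the model simply samples a face image
+image = pipe(num_inference_steps=200).images[0]
+```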
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## LDMPipeline
+[[autodoc]] LDMPipeline
+ - all
+ - __call__
+
+## ImagePipelineOutput
+[[autodoc]] pipelines.ImagePipelineOutput
diff --git a/diffusers/docs/source/en/api/pipelines/model_editing.md b/diffusers/docs/source/en/api/pipelines/model_editing.md
new file mode 100644
index 0000000000000000000000000000000000000000..2d94a50e4355992946ce2c7ff452015b7e046dea
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/model_editing.md
@@ -0,0 +1,35 @@
+
+
+# Text-to-image model editing
+
+[Editing Implicit Assumptions in Text-to-Image Diffusion Models](https://huggingface.co/papers/2303.08084) is by Hadas Orgad, Bahjat Kawar, and Yonatan Belinkov. This pipeline enables editing diffusion model weights, such that its assumptions of a given concept are changed. The resulting change is expected to take effect in all prompt generations related to the edited concept.
+
+The abstract from the paper is:
+
+*Text-to-image diffusion models often make implicit assumptions about the world when generating images. While some assumptions are useful (e.g., the sky is blue), they can also be outdated, incorrect, or reflective of social biases present in the training data. Thus, there is a need to control these assumptions without requiring explicit user input or costly re-training. In this work, we aim to edit a given implicit assumption in a pre-trained diffusion model. Our Text-to-Image Model Editing method, TIME for short, receives a pair of inputs: a "source" under-specified prompt for which the model makes an implicit assumption (e.g., "a pack of roses"), and a "destination" prompt that describes the same setting, but with a specified desired attribute (e.g., "a pack of blue roses"). TIME then updates the model's cross-attention layers, as these layers assign visual meaning to textual tokens. We edit the projection matrices in these layers such that the source prompt is projected close to the destination prompt. Our method is highly efficient, as it modifies a mere 2.2% of the model's parameters in under one second. To evaluate model editing approaches, we introduce TIMED (TIME Dataset), containing 147 source and destination prompt pairs from various domains. Our experiments (using Stable Diffusion) show that TIME is successful in model editing, generalizes well for related prompts unseen during editing, and imposes minimal effect on unrelated generations.*
+
+You can find additional information about model editing on the [project page](https://time-diffusion.github.io/), [original codebase](https://github.com/bahjat-kawar/time-diffusion), and try it out in a [demo](https://huggingface.co/spaces/bahjat-kawar/time-diffusion).
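+
+A minimal sketch of the editing workflow, reusing the roses example from the abstract (the checkpoint and prompts are illustrative):
+
+```python
+from diffusers import StableDiffusionModelEditingPipeline
+
+pipe = StableDiffusionModelEditingPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
+pipe = pipe.to("cuda")
+
+# Edit the implicit assumption: roses generated by the model should now be blue
+pipe.edit_model("A pack of roses", "A pack of blue roses")
+
+# Prompts related to the edited concept reflect the new assumption
+image = pipe("A field of roses").images[0]
+```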
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## StableDiffusionModelEditingPipeline
+[[autodoc]] StableDiffusionModelEditingPipeline
+ - __call__
+ - all
+
+## StableDiffusionPipelineOutput
+[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
diff --git a/diffusers/docs/source/en/api/pipelines/musicldm.md b/diffusers/docs/source/en/api/pipelines/musicldm.md
new file mode 100644
index 0000000000000000000000000000000000000000..896f707c76d7599694706c7bc740d278fdd4768b
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/musicldm.md
@@ -0,0 +1,52 @@
+
+
+# MusicLDM
+
+MusicLDM was proposed in [MusicLDM: Enhancing Novelty in Text-to-Music Generation Using Beat-Synchronous Mixup Strategies](https://huggingface.co/papers/2308.01546) by Ke Chen, Yusong Wu, Haohe Liu, Marianna Nezhurina, Taylor Berg-Kirkpatrick, Shlomo Dubnov.
+MusicLDM takes a text prompt as input and predicts the corresponding music sample.
+
+Inspired by [Stable Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview) and [AudioLDM](https://huggingface.co/docs/diffusers/api/pipelines/audioldm),
+MusicLDM is a text-to-music _latent diffusion model (LDM)_ that learns continuous audio representations from [CLAP](https://huggingface.co/docs/transformers/main/model_doc/clap)
+latents.
+
+MusicLDM is trained on a corpus of 466 hours of music data. Beat-synchronous data augmentation strategies are applied to the music samples, both in the time domain and in the latent space. Using beat-synchronous data augmentation strategies encourages the model to interpolate between the training samples, but stay within the domain of the training data. The result is generated music that is more diverse while staying faithful to the corresponding style.
+
+The abstract of the paper is the following:
+
+*Diffusion models have shown promising results in cross-modal generation tasks, including text-to-image and text-to-audio generation. However, generating music, as a special type of audio, presents unique challenges due to limited availability of music data and sensitive issues related to copyright and plagiarism. In this paper, to tackle these challenges, we first construct a state-of-the-art text-to-music model, MusicLDM, that adapts Stable Diffusion and AudioLDM architectures to the music domain. We achieve this by retraining the contrastive language-audio pretraining model (CLAP) and the Hifi-GAN vocoder, as components of MusicLDM, on a collection of music data samples. Then, to address the limitations of training data and to avoid plagiarism, we leverage a beat tracking model and propose two different mixup strategies for data augmentation: beat-synchronous audio mixup and beat-synchronous latent mixup, which recombine training audio directly or via a latent embeddings space, respectively. Such mixup strategies encourage the model to interpolate between musical training samples and generate new music within the convex hull of the training data, making the generated music more diverse while still staying faithful to the corresponding style. In addition to popular evaluation metrics, we design several new evaluation metrics based on CLAP score to demonstrate that our proposed MusicLDM and beat-synchronous mixup strategies improve both the quality and novelty of generated music, as well as the correspondence between input text and generated music.*
+
+This pipeline was contributed by [sanchit-gandhi](https://huggingface.co/sanchit-gandhi).
+
+## Tips
+
+When constructing a prompt, keep in mind:
+
+* Descriptive prompt inputs work best; use adjectives to describe the sound (for example, "high quality" or "clear") and make the prompt context specific where possible (e.g. "melodic techno with a fast beat and synths" works better than "techno").
+* Using a *negative prompt* can significantly improve the quality of the generated audio. Try using a negative prompt of "low quality, average quality".
+
+During inference:
+
+* The _quality_ of the generated audio sample can be controlled by the `num_inference_steps` argument; more steps give higher-quality audio at the expense of slower inference.
+* Multiple waveforms can be generated in one go by setting `num_waveforms_per_prompt` to a value greater than 1. Automatic scoring is performed between the generated waveforms and the prompt text, and the audio clips are ranked from best to worst accordingly.
+* The _length_ of the generated audio sample can be controlled by varying the `audio_length_in_s` argument.
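+
+Putting these tips together, a minimal text-to-music sketch might look like the following (the `ucsd-reach/musicldm` checkpoint and the generation settings are illustrative):
+
+```python
+import torch
+import scipy.io.wavfile
+from diffusers import MusicLDMPipeline
+
+pipe = MusicLDMPipeline.from_pretrained("ucsd-reach/musicldm", torch_dtype=torch.float16)
+pipe = pipe.to("cuda")
+
+prompt = "Techno music with a strong, upbeat tempo and high melodic riffs"
+audio = pipe(
+    prompt,
+    negative_prompt="low quality, average quality",
+    num_inference_steps=200,
+    audio_length_in_s=10.0,
+).audios[0]
+
+# MusicLDM generates audio at a 16 kHz sampling rate
+scipy.io.wavfile.write("techno.wav", rate=16000, data=audio)
+```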
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## MusicLDMPipeline
+[[autodoc]] MusicLDMPipeline
+ - all
+ - __call__
diff --git a/diffusers/docs/source/en/api/pipelines/overview.md b/diffusers/docs/source/en/api/pipelines/overview.md
new file mode 100644
index 0000000000000000000000000000000000000000..a7f4f477ef82a8a96e68fdc69611dc678f218fbf
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/overview.md
@@ -0,0 +1,101 @@
+
+
+# Pipelines
+
+Pipelines provide a simple way to run state-of-the-art diffusion models in inference by bundling all of the necessary components (multiple independently-trained models, schedulers, and processors) into a single end-to-end class. Pipelines are flexible and they can be adapted to use different schedulers or even model components.
+
+All pipelines are built from the base [`DiffusionPipeline`] class which provides basic functionality for loading, downloading, and saving all the components. Specific pipeline types (for example [`StableDiffusionPipeline`]) loaded with [`~DiffusionPipeline.from_pretrained`] are automatically detected and the pipeline components are loaded and passed to the `__init__` function of the pipeline.
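+
+For example, the components of a loaded pipeline can be reused to assemble a different pipeline without reloading any weights. A minimal sketch, assuming a Stable Diffusion checkpoint:
+
+```python
+from diffusers import DiffusionPipeline, StableDiffusionImg2ImgPipeline
+
+# Load a text-to-image pipeline, then reuse its models and scheduler for image-to-image
+text2img = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+img2img = StableDiffusionImg2ImgPipeline(**text2img.components)
+```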
+
+
+
+You shouldn't use the [`DiffusionPipeline`] class for training. Individual components (for example, [`UNet2DModel`] and [`UNet2DConditionModel`]) of diffusion pipelines are usually trained individually, so we suggest directly working with them instead.
+
+
+
+Pipelines do not offer any training functionality. You'll notice PyTorch's autograd is disabled by decorating the [`~DiffusionPipeline.__call__`] method with a [`torch.no_grad`](https://pytorch.org/docs/stable/generated/torch.no_grad.html) decorator because pipelines should not be used for training. If you're interested in training, please take a look at the [Training](../../training/overview) guides instead!
+
+
+
+The table below lists all the pipelines currently available in 🤗 Diffusers and the tasks they support. Click on a pipeline to view its abstract and published paper.
+
+| Pipeline | Tasks |
+|---|---|
+| [AltDiffusion](alt_diffusion) | image2image |
+| [AnimateDiff](animatediff) | text2video |
+| [Attend-and-Excite](attend_and_excite) | text2image |
+| [Audio Diffusion](audio_diffusion) | image2audio |
+| [AudioLDM](audioldm) | text2audio |
+| [AudioLDM2](audioldm2) | text2audio |
+| [BLIP Diffusion](blip_diffusion) | text2image |
+| [Consistency Models](consistency_models) | unconditional image generation |
+| [ControlNet](controlnet) | text2image, image2image, inpainting |
+| [ControlNet with Stable Diffusion XL](controlnet_sdxl) | text2image |
+| [Cycle Diffusion](cycle_diffusion) | image2image |
+| [Dance Diffusion](dance_diffusion) | unconditional audio generation |
+| [DDIM](ddim) | unconditional image generation |
+| [DDPM](ddpm) | unconditional image generation |
+| [DeepFloyd IF](deepfloyd_if) | text2image, image2image, inpainting, super-resolution |
+| [DiffEdit](diffedit) | inpainting |
+| [DiT](dit) | text2image |
+| [GLIGEN](stable_diffusion/gligen) | text2image |
+| [InstructPix2Pix](pix2pix) | image editing |
+| [Kandinsky 2.1](kandinsky) | text2image, image2image, inpainting, interpolation |
+| [Kandinsky 2.2](kandinsky_v22) | text2image, image2image, inpainting |
+| [Latent Consistency Models](latent_consistency_models) | text2image |
+| [Latent Diffusion](latent_diffusion) | text2image, super-resolution |
+| [LDM3D](stable_diffusion/ldm3d_diffusion) | text2image, text-to-3D |
+| [MultiDiffusion](panorama) | text2image |
+| [MusicLDM](musicldm) | text2audio |
+| [Paint by Example](paint_by_example) | inpainting |
+| [ParaDiGMS](paradigms) | text2image |
+| [Pix2Pix Zero](pix2pix_zero) | image editing |
+| [PixArt-α](pixart) | text2image |
+| [PNDM](pndm) | unconditional image generation |
+| [RePaint](repaint) | inpainting |
+| [Score SDE VE](score_sde_ve) | unconditional image generation |
+| [Self-Attention Guidance](self_attention_guidance) | text2image |
+| [Semantic Guidance](semantic_stable_diffusion) | text2image |
+| [Shap-E](shap_e) | text-to-3D, image-to-3D |
+| [Spectrogram Diffusion](spectrogram_diffusion) | |
+| [Stable Diffusion](stable_diffusion/overview) | text2image, image2image, depth2image, inpainting, image variation, latent upscaler, super-resolution |
+| [Stable Diffusion Model Editing](model_editing) | model editing |
+| [Stable Diffusion XL](stable_diffusion/stable_diffusion_xl) | text2image, image2image, inpainting |
+| [Stable unCLIP](stable_unclip) | text2image, image variation |
+| [Stochastic Karras VE](stochastic_karras_ve) | unconditional image generation |
+| [T2I-Adapter](stable_diffusion/adapter) | text2image |
+| [Text2Video](text_to_video) | text2video, video2video |
+| [Text2Video-Zero](text_to_video_zero) | text2video |
+| [unCLIP](unclip) | text2image, image variation |
+| [Unconditional Latent Diffusion](latent_diffusion_uncond) | unconditional image generation |
+| [UniDiffuser](unidiffuser) | text2image, image2text, image variation, text variation, unconditional image generation, unconditional audio generation |
+| [Value-guided planning](value_guided_sampling) | value guided sampling |
+| [Versatile Diffusion](versatile_diffusion) | text2image, image variation |
+| [VQ Diffusion](vq_diffusion) | text2image |
+| [Wuerstchen](wuerstchen) | text2image |
+
+## DiffusionPipeline
+
+[[autodoc]] DiffusionPipeline
+ - all
+ - __call__
+ - device
+ - to
+ - components
+
+## FlaxDiffusionPipeline
+
+[[autodoc]] pipelines.pipeline_flax_utils.FlaxDiffusionPipeline
+
+## PushToHubMixin
+
+[[autodoc]] utils.PushToHubMixin
diff --git a/diffusers/docs/source/en/api/pipelines/paint_by_example.md b/diffusers/docs/source/en/api/pipelines/paint_by_example.md
new file mode 100644
index 0000000000000000000000000000000000000000..b89e80cbb254423074de3baadb53e11e1534bf51
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/paint_by_example.md
@@ -0,0 +1,39 @@
+
+
+# Paint by Example
+
+[Paint by Example: Exemplar-based Image Editing with Diffusion Models](https://huggingface.co/papers/2211.13227) is by Binxin Yang, Shuyang Gu, Bo Zhang, Ting Zhang, Xuejin Chen, Xiaoyan Sun, Dong Chen, Fang Wen.
+
+The abstract from the paper is:
+
+*Language-guided image editing has achieved great success recently. In this paper, for the first time, we investigate exemplar-guided image editing for more precise control. We achieve this goal by leveraging self-supervised training to disentangle and re-organize the source image and the exemplar. However, the naive approach will cause obvious fusing artifacts. We carefully analyze it and propose an information bottleneck and strong augmentations to avoid the trivial solution of directly copying and pasting the exemplar image. Meanwhile, to ensure the controllability of the editing process, we design an arbitrary shape mask for the exemplar image and leverage the classifier-free guidance to increase the similarity to the exemplar image. The whole framework involves a single forward of the diffusion model without any iterative optimization. We demonstrate that our method achieves an impressive performance and enables controllable editing on in-the-wild images with high fidelity.*
+
+The original codebase can be found at [Fantasy-Studio/Paint-by-Example](https://github.com/Fantasy-Studio/Paint-by-Example), and you can try it out in a [demo](https://huggingface.co/spaces/Fantasy-Studio/Paint-by-Example).
+
+## Tips
+
+Paint by Example is supported by the official [Fantasy-Studio/Paint-by-Example](https://huggingface.co/Fantasy-Studio/Paint-by-Example) checkpoint. The checkpoint is warm-started from [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) to inpaint partly masked images conditioned on example and reference images.
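+
+A minimal inpainting sketch with this checkpoint; the image, mask, and example-image paths below are placeholders for your own inputs:
+
+```python
+import torch
+from diffusers import PaintByExamplePipeline
+from diffusers.utils import load_image
+
+pipe = PaintByExamplePipeline.from_pretrained(
+    "Fantasy-Studio/Paint-by-Example", torch_dtype=torch.float16
+)
+pipe = pipe.to("cuda")
+
+# Placeholder inputs: the image to edit, a mask marking the region to repaint,
+# and an example image describing what should appear in the masked region
+init_image = load_image("image.png").resize((512, 512))
+mask_image = load_image("mask.png").resize((512, 512))
+example_image = load_image("example.png").resize((512, 512))
+
+image = pipe(image=init_image, mask_image=mask_image, example_image=example_image).images[0]
+```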
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## PaintByExamplePipeline
+[[autodoc]] PaintByExamplePipeline
+ - all
+ - __call__
+
+## StableDiffusionPipelineOutput
+[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
diff --git a/diffusers/docs/source/en/api/pipelines/panorama.md b/diffusers/docs/source/en/api/pipelines/panorama.md
new file mode 100644
index 0000000000000000000000000000000000000000..8aa86112aa89a9ee827c5b15c6dad679628c06f5
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/panorama.md
@@ -0,0 +1,50 @@
+
+
+# MultiDiffusion
+
+[MultiDiffusion: Fusing Diffusion Paths for Controlled Image Generation](https://huggingface.co/papers/2302.08113) is by Omer Bar-Tal, Lior Yariv, Yaron Lipman, and Tali Dekel.
+
+The abstract from the paper is:
+
+*Recent advances in text-to-image generation with diffusion models present transformative capabilities in image quality. However, user controllability of the generated image, and fast adaptation to new tasks still remains an open challenge, currently mostly addressed by costly and long re-training and fine-tuning or ad-hoc adaptations to specific image generation tasks. In this work, we present MultiDiffusion, a unified framework that enables versatile and controllable image generation, using a pre-trained text-to-image diffusion model, without any further training or finetuning. At the center of our approach is a new generation process, based on an optimization task that binds together multiple diffusion generation processes with a shared set of parameters or constraints. We show that MultiDiffusion can be readily applied to generate high quality and diverse images that adhere to user-provided controls, such as desired aspect ratio (e.g., panorama), and spatial guiding signals, ranging from tight segmentation masks to bounding boxes.*
+
+You can find additional information about MultiDiffusion on the [project page](https://multidiffusion.github.io/), [original codebase](https://github.com/omerbt/MultiDiffusion), and try it out in a [demo](https://huggingface.co/spaces/weizmannscience/MultiDiffusion).
+
+## Tips
+
+When calling [`StableDiffusionPanoramaPipeline`], it's possible to set the `view_batch_size` parameter to a value greater than 1.
+On high-performance GPUs, this can speed up the generation process at the cost of increased VRAM usage.
+
+To generate panorama-like images, make sure you pass the `width` parameter accordingly. We recommend a width value of 2048, which is the default.
+
+Circular padding is applied when working with panoramas to avoid stitching artifacts and ensure a seamless transition from the rightmost part of the image to the leftmost part. By enabling circular padding (set `circular_padding=True`), the operation applies additional crops after the rightmost point of the image, allowing the model to "see" the transition from the rightmost part to the leftmost part. This helps maintain visual consistency in a 360-degree sense and creates a proper "panorama" that can be viewed with 360-degree panorama viewers. When decoding latents in Stable Diffusion, circular padding is applied to ensure that the decoded latents match in the RGB space.
+
+For example, without circular padding, there is a stitching artifact (default):
+![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/indoor_%20no_circular_padding.png)
+
+But with circular padding, the right and the left parts are matching (`circular_padding=True`):
+![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/indoor_%20circular_padding.png)
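+
+Putting these tips together, a panorama sketch might look like the following (the Stable Diffusion 2 base checkpoint and the prompt are illustrative):
+
+```python
+import torch
+from diffusers import StableDiffusionPanoramaPipeline, DDIMScheduler
+
+model_ckpt = "stabilityai/stable-diffusion-2-base"
+scheduler = DDIMScheduler.from_pretrained(model_ckpt, subfolder="scheduler")
+pipe = StableDiffusionPanoramaPipeline.from_pretrained(
+    model_ckpt, scheduler=scheduler, torch_dtype=torch.float16
+)
+pipe = pipe.to("cuda")
+
+prompt = "a photo of the dolomites"
+# width=2048 produces a wide panorama; circular_padding avoids the left/right seam
+image = pipe(prompt, width=2048, circular_padding=True).images[0]
+```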
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## StableDiffusionPanoramaPipeline
+[[autodoc]] StableDiffusionPanoramaPipeline
+ - __call__
+ - all
+
+## StableDiffusionPipelineOutput
+[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
diff --git a/diffusers/docs/source/en/api/pipelines/paradigms.md b/diffusers/docs/source/en/api/pipelines/paradigms.md
new file mode 100644
index 0000000000000000000000000000000000000000..ca2fedc796df373366a203fd54dfc238c17f6fb2
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/paradigms.md
@@ -0,0 +1,51 @@
+
+
+# Parallel Sampling of Diffusion Models
+
+[Parallel Sampling of Diffusion Models](https://huggingface.co/papers/2305.16317) is by Andy Shih, Suneel Belkhale, Stefano Ermon, Dorsa Sadigh, Nima Anari.
+
+The abstract from the paper is:
+
+*Diffusion models are powerful generative models but suffer from slow sampling, often taking 1000 sequential denoising steps for one sample. As a result, considerable efforts have been directed toward reducing the number of denoising steps, but these methods hurt sample quality. Instead of reducing the number of denoising steps (trading quality for speed), in this paper we explore an orthogonal approach: can we run the denoising steps in parallel (trading compute for speed)? In spite of the sequential nature of the denoising steps, we show that surprisingly it is possible to parallelize sampling via Picard iterations, by guessing the solution of future denoising steps and iteratively refining until convergence. With this insight, we present ParaDiGMS, a novel method to accelerate the sampling of pretrained diffusion models by denoising multiple steps in parallel. ParaDiGMS is the first diffusion sampling method that enables trading compute for speed and is even compatible with existing fast sampling techniques such as DDIM and DPMSolver. Using ParaDiGMS, we improve sampling speed by 2-4x across a range of robotics and image generation models, giving state-of-the-art sampling speeds of 0.2s on 100-step DiffusionPolicy and 14.6s on 1000-step StableDiffusion-v2 with no measurable degradation of task reward, FID score, or CLIP score.*
+
+The original codebase can be found at [AndyShih12/paradigms](https://github.com/AndyShih12/paradigms), and the pipeline was contributed by [AndyShih12](https://github.com/AndyShih12). ❤️
+
+## Tips
+
+This pipeline improves sampling speed by running denoising steps in parallel, at the cost of increased total FLOPs.
+Therefore, it is better to call this pipeline when running on multiple GPUs. Otherwise, without enough GPU bandwidth,
+sampling may be even slower than sequential sampling.
+
+The two parameters to play with are `parallel` (batch size) and `tolerance`.
+- If it fits in memory, for a 1000-step DDPM you can aim for a batch size of around 100 (for example, 8 GPUs and `batch_per_device=12` to get `parallel=96`). A higher batch size may not fit in memory, and lower batch size gives less parallelism.
+- For tolerance, using a higher tolerance may get better speedups but can risk sample quality degradation. If there is quality degradation with the default tolerance, then use a lower tolerance like `0.001`.
+
+For a 1000-step DDPM on 8 A100 GPUs, you can expect around a 3x speedup from [`StableDiffusionParadigmsPipeline`] compared to the [`StableDiffusionPipeline`]
+by setting `parallel=80` and `tolerance=0.1`.
+
+🤗 Diffusers offers [distributed inference support](../../training/distributed_inference) for generating multiple prompts
+in parallel on multiple GPUs. But [`StableDiffusionParadigmsPipeline`] is designed for speeding up sampling of a single prompt by using multiple GPUs.
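+
+A sketch of parallel sampling across the visible GPUs, following the guidance above (the scheduler settings and batch size are illustrative):
+
+```python
+import torch
+from diffusers import DDPMParallelScheduler, StableDiffusionParadigmsPipeline
+
+model_id = "runwayml/stable-diffusion-v1-5"
+scheduler = DDPMParallelScheduler.from_pretrained(
+    model_id, subfolder="scheduler", timestep_spacing="trailing"
+)
+pipe = StableDiffusionParadigmsPipeline.from_pretrained(
+    model_id, scheduler=scheduler, torch_dtype=torch.float16
+)
+pipe = pipe.to("cuda")
+
+# Spread the parallel denoising batch across all visible GPUs
+ngpu, batch_per_device = torch.cuda.device_count(), 5
+pipe.wrapped_unet = torch.nn.DataParallel(pipe.unet, device_ids=list(range(ngpu)))
+
+prompt = "a photo of an astronaut riding a horse on mars"
+image = pipe(prompt, parallel=ngpu * batch_per_device, num_inference_steps=1000).images[0]
+```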
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## StableDiffusionParadigmsPipeline
+[[autodoc]] StableDiffusionParadigmsPipeline
+ - __call__
+ - all
+
+## StableDiffusionPipelineOutput
+[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
diff --git a/diffusers/docs/source/en/api/pipelines/pix2pix.md b/diffusers/docs/source/en/api/pipelines/pix2pix.md
new file mode 100644
index 0000000000000000000000000000000000000000..4fd76cfb56c26222f79aa7d0102b2b52345655de
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/pix2pix.md
@@ -0,0 +1,40 @@
+
+
+# InstructPix2Pix
+
+[InstructPix2Pix: Learning to Follow Image Editing Instructions](https://huggingface.co/papers/2211.09800) is by Tim Brooks, Aleksander Holynski and Alexei A. Efros.
+
+The abstract from the paper is:
+
+*We propose a method for editing images from human instructions: given an input image and a written instruction that tells the model what to do, our model follows these instructions to edit the image. To obtain training data for this problem, we combine the knowledge of two large pretrained models -- a language model (GPT-3) and a text-to-image model (Stable Diffusion) -- to generate a large dataset of image editing examples. Our conditional diffusion model, InstructPix2Pix, is trained on our generated data, and generalizes to real images and user-written instructions at inference time. Since it performs edits in the forward pass and does not require per example fine-tuning or inversion, our model edits images quickly, in a matter of seconds. We show compelling editing results for a diverse collection of input images and written instructions.*
+
+You can find additional information about InstructPix2Pix on the [project page](https://www.timothybrooks.com/instruct-pix2pix), [original codebase](https://github.com/timothybrooks/instruct-pix2pix), and try it out in a [demo](https://huggingface.co/spaces/timbrooks/instruct-pix2pix).
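+
+A minimal editing sketch with the `timbrooks/instruct-pix2pix` checkpoint; the input image URL is a placeholder, and `image_guidance_scale` controls how closely the result follows the input image:
+
+```python
+import torch
+from diffusers import StableDiffusionInstructPix2PixPipeline
+from diffusers.utils import load_image
+
+pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
+    "timbrooks/instruct-pix2pix", torch_dtype=torch.float16
+)
+pipe = pipe.to("cuda")
+
+# Placeholder URL: replace with the image you want to edit
+image = load_image("https://example.com/mountain.png")
+
+edited = pipe(
+    "make the mountains snowy",
+    image=image,
+    num_inference_steps=20,
+    image_guidance_scale=1.5,
+).images[0]
+```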
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## StableDiffusionInstructPix2PixPipeline
+[[autodoc]] StableDiffusionInstructPix2PixPipeline
+ - __call__
+ - all
+ - load_textual_inversion
+ - load_lora_weights
+ - save_lora_weights
+
+## StableDiffusionXLInstructPix2PixPipeline
+[[autodoc]] StableDiffusionXLInstructPix2PixPipeline
+ - __call__
+ - all
diff --git a/diffusers/docs/source/en/api/pipelines/pix2pix_zero.md b/diffusers/docs/source/en/api/pipelines/pix2pix_zero.md
new file mode 100644
index 0000000000000000000000000000000000000000..6d7b9fb31471565492e297f245144df2261a3619
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/pix2pix_zero.md
@@ -0,0 +1,289 @@
+
+
+# Pix2Pix Zero
+
+[Zero-shot Image-to-Image Translation](https://huggingface.co/papers/2302.03027) is by Gaurav Parmar, Krishna Kumar Singh, Richard Zhang, Yijun Li, Jingwan Lu, and Jun-Yan Zhu.
+
+The abstract from the paper is:
+
+*Large-scale text-to-image generative models have shown their remarkable ability to synthesize diverse and high-quality images. However, it is still challenging to directly apply these models for editing real images for two reasons. First, it is hard for users to come up with a perfect text prompt that accurately describes every visual detail in the input image. Second, while existing models can introduce desirable changes in certain regions, they often dramatically alter the input content and introduce unexpected changes in unwanted regions. In this work, we propose pix2pix-zero, an image-to-image translation method that can preserve the content of the original image without manual prompting. We first automatically discover editing directions that reflect desired edits in the text embedding space. To preserve the general content structure after editing, we further propose cross-attention guidance, which aims to retain the cross-attention maps of the input image throughout the diffusion process. In addition, our method does not need additional training for these edits and can directly use the existing pre-trained text-to-image diffusion model. We conduct extensive experiments and show that our method outperforms existing and concurrent works for both real and synthetic image editing.*
+
+You can find additional information about Pix2Pix Zero on the [project page](https://pix2pixzero.github.io/), [original codebase](https://github.com/pix2pixzero/pix2pix-zero), and try it out in a [demo](https://huggingface.co/spaces/pix2pix-zero-library/pix2pix-zero-demo).
+
+## Tips
+
+* The pipeline can be conditioned on real input images. Check out the code examples below to know more.
+* The pipeline exposes two arguments namely `source_embeds` and `target_embeds`
+that let you control the direction of the semantic edits in the final image to be generated. Let's say,
+you wanted to translate from "cat" to "dog". In this case, the edit direction will be "cat -> dog". To reflect
+this in the pipeline, you simply have to set the embeddings related to the phrases including "cat" to
+`source_embeds` and "dog" to `target_embeds`. Refer to the code example below for more details.
+* When you're using this pipeline from a prompt, specify the _source_ concept in the prompt. Taking
+the above example, a valid input prompt would be: "a high resolution painting of a **cat** in the style of van gogh".
+* If you wanted to reverse the direction in the example above, i.e., "dog -> cat", then it's recommended to:
+ * Swap the `source_embeds` and `target_embeds`.
+ * Change the input prompt to include "dog".
+* To learn more about how the source and target embeddings are generated, refer to the [original paper](https://arxiv.org/abs/2302.03027). Below, we also provide some directions on how to generate the embeddings.
+* Note that the quality of the outputs generated with this pipeline is dependent on how good the `source_embeds` and `target_embeds` are. Please refer to [this discussion](#generating-source-and-target-embeddings) for some suggestions on the topic.
+
+## Available Pipelines:
+
+| Pipeline | Tasks | Demo |
+|---|---|:---:|
+| [StableDiffusionPix2PixZeroPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py) | *Text-Based Image Editing* | [🤗 Space](https://huggingface.co/spaces/pix2pix-zero-library/pix2pix-zero-demo) |
+
+
+
+## Usage example
+
+### Based on an image generated with the input prompt
+
+```python
+import requests
+import torch
+
+from diffusers import DDIMScheduler, StableDiffusionPix2PixZeroPipeline
+
+
+def download(embedding_url, local_filepath):
+ r = requests.get(embedding_url)
+ with open(local_filepath, "wb") as f:
+ f.write(r.content)
+
+
+model_ckpt = "CompVis/stable-diffusion-v1-4"
+pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained(
+ model_ckpt, conditions_input_image=False, torch_dtype=torch.float16
+)
+pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
+pipeline.to("cuda")
+
+prompt = "a high resolution painting of a cat in the style of van gogh"
+src_embs_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/embeddings_sd_1.4/cat.pt"
+target_embs_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/embeddings_sd_1.4/dog.pt"
+
+for url in [src_embs_url, target_embs_url]:
+ download(url, url.split("/")[-1])
+
+src_embeds = torch.load(src_embs_url.split("/")[-1])
+target_embeds = torch.load(target_embs_url.split("/")[-1])
+
+image = pipeline(
+ prompt,
+ source_embeds=src_embeds,
+ target_embeds=target_embeds,
+ num_inference_steps=50,
+ cross_attention_guidance_amount=0.15,
+).images[0]
+image
+```
+
+### Based on an input image
+
+When the pipeline is conditioned on an input image, we first obtain an inverted
+noise from it using a `DDIMInverseScheduler` with the help of a generated caption. Then the inverted noise is used to start the generation process.
+
+First, let's load our pipeline:
+
+```py
+import torch
+from transformers import BlipForConditionalGeneration, BlipProcessor
+from diffusers import DDIMScheduler, DDIMInverseScheduler, StableDiffusionPix2PixZeroPipeline
+
+captioner_id = "Salesforce/blip-image-captioning-base"
+processor = BlipProcessor.from_pretrained(captioner_id)
+model = BlipForConditionalGeneration.from_pretrained(captioner_id, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+
+sd_model_ckpt = "CompVis/stable-diffusion-v1-4"
+pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained(
+ sd_model_ckpt,
+ caption_generator=model,
+ caption_processor=processor,
+ torch_dtype=torch.float16,
+ safety_checker=None,
+)
+pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
+pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config)
+pipeline.enable_model_cpu_offload()
+```
+
+Then, we load an input image for conditioning and obtain a suitable caption for it:
+
+```py
+from diffusers.utils import load_image
+
+img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png"
+raw_image = load_image(img_url).resize((512, 512))
+caption = pipeline.generate_caption(raw_image)
+caption
+```
+
+Then we employ the generated caption and the input image to get the inverted noise:
+
+```py
+generator = torch.manual_seed(0)
+inv_latents = pipeline.invert(caption, image=raw_image, generator=generator).latents
+```
+
+Now, generate the image with edit directions:
+
+```py
+# See the "Generating source and target embeddings" section below to
+# automate the generation of these captions with a pre-trained model like Flan-T5 as explained below.
+source_prompts = ["a cat sitting on the street", "a cat playing in the field", "a face of a cat"]
+target_prompts = ["a dog sitting on the street", "a dog playing in the field", "a face of a dog"]
+
+source_embeds = pipeline.get_embeds(source_prompts, batch_size=2)
+target_embeds = pipeline.get_embeds(target_prompts, batch_size=2)
+
+
+image = pipeline(
+ caption,
+ source_embeds=source_embeds,
+ target_embeds=target_embeds,
+ num_inference_steps=50,
+ cross_attention_guidance_amount=0.15,
+ generator=generator,
+ latents=inv_latents,
+ negative_prompt=caption,
+).images[0]
+image
+```
+
+## Generating source and target embeddings
+
+The authors originally used the [GPT-3 API](https://openai.com/api/) to generate the source and target captions for discovering
+edit directions. However, we can also leverage open source and public models for the same purpose.
+Below, we provide an end-to-end example with the [Flan-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5) model
+for generating captions and [CLIP](https://huggingface.co/docs/transformers/model_doc/clip) for
+computing embeddings on the generated captions.
+
+**1. Load the generation model**:
+
+```py
+import torch
+from transformers import AutoTokenizer, T5ForConditionalGeneration
+
+tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl")
+model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-xl", device_map="auto", torch_dtype=torch.float16)
+```
+
+**2. Construct a starting prompt**:
+
+```py
+source_concept = "cat"
+target_concept = "dog"
+
+source_text = f"Provide a caption for images containing a {source_concept}. "
+"The captions should be in English and should be no longer than 150 characters."
+
+target_text = f"Provide a caption for images containing a {target_concept}. "
+"The captions should be in English and should be no longer than 150 characters."
+```
+
+Here, we're interested in the "cat -> dog" direction.
+
+**3. Generate captions**:
+
+We can use a utility like so for this purpose.
+
+```py
+def generate_captions(input_prompt):
+ input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids.to("cuda")
+
+ outputs = model.generate(
+ input_ids, temperature=0.8, num_return_sequences=16, do_sample=True, max_new_tokens=128, top_k=10
+ )
+ return tokenizer.batch_decode(outputs, skip_special_tokens=True)
+```
+
+And then we just call it to generate our captions:
+
+```py
+source_captions = generate_captions(source_text)
+target_captions = generate_captions(target_text)
+print(source_captions, target_captions, sep='\n')
+```
+
+We encourage you to play around with the different parameters supported by the
+`generate()` method ([documentation](https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.generation_tf_utils.TFGenerationMixin.generate)) for the generation quality you are looking for.
+
+**4. Load the embedding model**:
+
+Here, we need to use the same text encoder model used by the subsequent Stable Diffusion model.
+
+```py
+from diffusers import StableDiffusionPix2PixZeroPipeline
+
+pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
+)
+pipeline = pipeline.to("cuda")
+tokenizer = pipeline.tokenizer
+text_encoder = pipeline.text_encoder
+```
+
+**5. Compute embeddings**:
+
+```py
+import torch
+
+def embed_captions(sentences, tokenizer, text_encoder, device="cuda"):
+ with torch.no_grad():
+ embeddings = []
+ for sent in sentences:
+ text_inputs = tokenizer(
+ sent,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=None)[0]
+ embeddings.append(prompt_embeds)
+ return torch.concatenate(embeddings, dim=0).mean(dim=0).unsqueeze(0)
+
+source_embeddings = embed_captions(source_captions, tokenizer, text_encoder)
+target_embeddings = embed_captions(target_captions, tokenizer, text_encoder)
+```
+
+And you're done! [Here](https://colab.research.google.com/drive/1tz2C1EdfZYAPlzXXbTnf-5PRBiR8_R1F?usp=sharing) is a Colab Notebook that you can use to interact with the entire process.
+
+Now, you can use these embeddings directly while calling the pipeline:
+
+```py
+from diffusers import DDIMScheduler
+
+pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
+
+# the prompt should mention the source concept ("cat")
+prompt = "a high resolution painting of a cat in the style of van gogh"
+image = pipeline(
+    prompt,
+ source_embeds=source_embeddings,
+ target_embeds=target_embeddings,
+ num_inference_steps=50,
+ cross_attention_guidance_amount=0.15,
+).images[0]
+image
+```
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## StableDiffusionPix2PixZeroPipeline
+[[autodoc]] StableDiffusionPix2PixZeroPipeline
+ - __call__
+ - all
diff --git a/diffusers/docs/source/en/api/pipelines/pixart.md b/diffusers/docs/source/en/api/pipelines/pixart.md
new file mode 100644
index 0000000000000000000000000000000000000000..6fa44cd508e40d924452e3ecd1f4c1fa2131a7b5
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/pixart.md
@@ -0,0 +1,43 @@
+
+
+# PixArt-α
+
+![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/pixart/header_collage.png)
+
+[PixArt-α: Fast Training of Diffusion Transformer for Photorealistic Text-to-Image Synthesis](https://huggingface.co/papers/2310.00426) is by Junsong Chen, Jincheng Yu, Chongjian Ge, Lewei Yao, Enze Xie, Yue Wu, Zhongdao Wang, James Kwok, Ping Luo, Huchuan Lu, and Zhenguo Li.
+
+The abstract from the paper is:
+
+*The most advanced text-to-image (T2I) models require significant training costs (e.g., millions of GPU hours), seriously hindering the fundamental innovation for the AIGC community while increasing CO2 emissions. This paper introduces PIXART-α, a Transformer-based T2I diffusion model whose image generation quality is competitive with state-of-the-art image generators (e.g., Imagen, SDXL, and even Midjourney), reaching near-commercial application standards. Additionally, it supports high-resolution image synthesis up to 1024px resolution with low training cost, as shown in Figure 1 and 2. To achieve this goal, three core designs are proposed: (1) Training strategy decomposition: We devise three distinct training steps that separately optimize pixel dependency, text-image alignment, and image aesthetic quality; (2) Efficient T2I Transformer: We incorporate cross-attention modules into Diffusion Transformer (DiT) to inject text conditions and streamline the computation-intensive class-condition branch; (3) High-informative data: We emphasize the significance of concept density in text-image pairs and leverage a large Vision-Language model to auto-label dense pseudo-captions to assist text-image alignment learning. As a result, PIXART-α's training speed markedly surpasses existing large-scale T2I models, e.g., PIXART-α only takes 10.8% of Stable Diffusion v1.5's training time (675 vs. 6,250 A100 GPU days), saving nearly $300,000 ($26,000 vs. $320,000) and reducing 90% CO2 emissions. Moreover, compared with a larger SOTA model, RAPHAEL, our training cost is merely 1%. Extensive experiments demonstrate that PIXART-α excels in image quality, artistry, and semantic control. We hope PIXART-α will provide new insights to the AIGC community and startups to accelerate building their own high-quality yet low-cost generative models from scratch.*
+
+You can find the original codebase at [PixArt-alpha/PixArt-alpha](https://github.com/PixArt-alpha/PixArt-alpha) and all the available checkpoints at [PixArt-alpha](https://huggingface.co/PixArt-alpha).
+
+Some notes about this pipeline:
+
+* It uses a Transformer backbone (instead of a UNet) for denoising. As such, it has a similar architecture to [DiT](./dit).
+* It was trained using text conditions computed from T5. This aspect makes the pipeline better at following complex text prompts with intricate details.
+* It is good at producing high-resolution images at different aspect ratios. To get the best results, the authors recommend some size brackets which can be found [here](https://github.com/PixArt-alpha/PixArt-alpha/blob/08fbbd281ec96866109bdd2cdb75f2f58fb17610/diffusion/data/datasets/utils.py).
+* It rivals the quality of state-of-the-art text-to-image generation systems (as of this writing) such as Stable Diffusion XL, Imagen, and DALL-E 2, while being more efficient than them.
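+
+A minimal text-to-image sketch, assuming the 1024px `PixArt-alpha/PixArt-XL-2-1024-MS` checkpoint (the prompt is illustrative):
+
+```python
+import torch
+from diffusers import PixArtAlphaPipeline
+
+pipe = PixArtAlphaPipeline.from_pretrained(
+    "PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16
+)
+pipe = pipe.to("cuda")
+
+prompt = "A small cactus with a happy face in the Sahara desert"
+image = pipe(prompt).images[0]
+```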
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## PixArtAlphaPipeline
+
+[[autodoc]] PixArtAlphaPipeline
+ - all
+ - __call__
+
\ No newline at end of file
diff --git a/diffusers/docs/source/en/api/pipelines/pndm.md b/diffusers/docs/source/en/api/pipelines/pndm.md
new file mode 100644
index 0000000000000000000000000000000000000000..162e7934dc22620029f12809335a1a03600bd758
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/pndm.md
@@ -0,0 +1,35 @@
+
+
+# PNDM
+
+[Pseudo Numerical Methods for Diffusion Models on Manifolds](https://huggingface.co/papers/2202.09778) (PNDM) is by Luping Liu, Yi Ren, Zhijie Lin and Zhou Zhao.
+
+The abstract from the paper is:
+
+*Denoising Diffusion Probabilistic Models (DDPMs) can generate high-quality samples such as image and audio samples. However, DDPMs require hundreds to thousands of iterations to produce final samples. Several prior works have successfully accelerated DDPMs through adjusting the variance schedule (e.g., Improved Denoising Diffusion Probabilistic Models) or the denoising equation (e.g., Denoising Diffusion Implicit Models (DDIMs)). However, these acceleration methods cannot maintain the quality of samples and even introduce new noise at a high speedup rate, which limit their practicability. To accelerate the inference process while keeping the sample quality, we provide a fresh perspective that DDPMs should be treated as solving differential equations on manifolds. Under such a perspective, we propose pseudo numerical methods for diffusion models (PNDMs). Specifically, we figure out how to solve differential equations on manifolds and show that DDIMs are simple cases of pseudo numerical methods. We change several classical numerical methods to corresponding pseudo numerical methods and find that the pseudo linear multi-step method is the best in most situations. According to our experiments, by directly using pre-trained models on Cifar10, CelebA and LSUN, PNDMs can generate higher quality synthetic images with only 50 steps compared with 1000-step DDIMs (20x speedup), significantly outperform DDIMs with 250 steps (by around 0.4 in FID) and have good generalization on different variance schedules.*
+
+The original codebase can be found at [luping-liu/PNDM](https://github.com/luping-liu/PNDM).
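+
+Because PNDM only changes the sampling procedure, a pretrained DDPM UNet can be reused with a [`PNDMScheduler`]. A minimal sketch, assuming the default scheduler configuration matches the checkpoint's training setup:
+
+```python
+from diffusers import PNDMPipeline, PNDMScheduler, UNet2DModel
+
+# Reuse a pretrained DDPM UNet and swap in the pseudo numerical (PNDM) sampler
+unet = UNet2DModel.from_pretrained("google/ddpm-cifar10-32")
+scheduler = PNDMScheduler()
+pipe = PNDMPipeline(unet=unet, scheduler=scheduler).to("cuda")
+
+# 50 steps is typically enough with PNDM
+image = pipe(num_inference_steps=50).images[0]
+```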
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## PNDMPipeline
+[[autodoc]] PNDMPipeline
+ - all
+ - __call__
+
+## ImagePipelineOutput
+[[autodoc]] pipelines.ImagePipelineOutput
diff --git a/diffusers/docs/source/en/api/pipelines/repaint.md b/diffusers/docs/source/en/api/pipelines/repaint.md
new file mode 100644
index 0000000000000000000000000000000000000000..1be69a3f9a46202ca5f3f29d12fcebee36230083
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/repaint.md
@@ -0,0 +1,37 @@
+
+
+# RePaint
+
+[RePaint: Inpainting using Denoising Diffusion Probabilistic Models](https://huggingface.co/papers/2201.09865) is by Andreas Lugmayr, Martin Danelljan, Andres Romero, Fisher Yu, Radu Timofte, Luc Van Gool.
+
+The abstract from the paper is:
+
+*Free-form inpainting is the task of adding new content to an image in the regions specified by an arbitrary binary mask. Most existing approaches train for a certain distribution of masks, which limits their generalization capabilities to unseen mask types. Furthermore, training with pixel-wise and perceptual losses often leads to simple textural extensions towards the missing areas instead of semantically meaningful generation. In this work, we propose RePaint: A Denoising Diffusion Probabilistic Model (DDPM) based inpainting approach that is applicable to even extreme masks. We employ a pretrained unconditional DDPM as the generative prior. To condition the generation process, we only alter the reverse diffusion iterations by sampling the unmasked regions using the given image information. Since this technique does not modify or condition the original DDPM network itself, the model produces high-quality and diverse output images for any inpainting form. We validate our method for both faces and general-purpose image inpainting using standard and extreme masks.
+RePaint outperforms state-of-the-art Autoregressive, and GAN approaches for at least five out of six mask distributions.*
+
+The original codebase can be found at [andreas128/RePaint](https://github.com/andreas128/RePaint).
+
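+A minimal inference sketch follows. The `google/ddpm-ema-celebahq-256` checkpoint and the image/mask URLs are assumptions for illustration; any unconditional 256x256 DDPM checkpoint with a matching image and mask pair should work similarly.
+
+```py
+import torch
+from diffusers import RePaintPipeline, RePaintScheduler
+from diffusers.utils import load_image
+
+# example face image and mask (assumed URLs; substitute your own 256x256 pair)
+img_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/celeba_hq_256.png"
+mask_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/mask_256.png"
+original_image = load_image(img_url).resize((256, 256))
+mask_image = load_image(mask_url).resize((256, 256))
+
+# RePaint reuses a pretrained unconditional DDPM as the generative prior
+scheduler = RePaintScheduler.from_pretrained("google/ddpm-ema-celebahq-256", subfolder="scheduler")
+pipe = RePaintPipeline.from_pretrained("google/ddpm-ema-celebahq-256", scheduler=scheduler).to("cuda")
+
+generator = torch.Generator(device="cuda").manual_seed(0)
+output = pipe(
+    image=original_image,
+    mask_image=mask_image,
+    num_inference_steps=250,
+    eta=0.0,
+    jump_length=10,    # resampling jump length from the paper
+    jump_n_sample=10,  # number of resampling repetitions
+    generator=generator,
+)
+inpainted_image = output.images[0]
+```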
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+
+## RePaintPipeline
+[[autodoc]] RePaintPipeline
+ - all
+ - __call__
+
+## ImagePipelineOutput
+[[autodoc]] pipelines.ImagePipelineOutput
diff --git a/diffusers/docs/source/en/api/pipelines/score_sde_ve.md b/diffusers/docs/source/en/api/pipelines/score_sde_ve.md
new file mode 100644
index 0000000000000000000000000000000000000000..cc9c8574f92def493501e18336d8326207fd1669
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/score_sde_ve.md
@@ -0,0 +1,35 @@
+
+
+# Score SDE VE
+
+[Score-Based Generative Modeling through Stochastic Differential Equations](https://huggingface.co/papers/2011.13456) (Score SDE) is by Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon and Ben Poole. This pipeline implements the variance exploding (VE) variant of the stochastic differential equation method.
+
+The abstract from the paper is:
+
+*Creating noise from data is easy; creating data from noise is generative modeling. We present a stochastic differential equation (SDE) that smoothly transforms a complex data distribution to a known prior distribution by slowly injecting noise, and a corresponding reverse-time SDE that transforms the prior distribution back into the data distribution by slowly removing the noise. Crucially, the reverse-time SDE depends only on the time-dependent gradient field (\aka, score) of the perturbed data distribution. By leveraging advances in score-based generative modeling, we can accurately estimate these scores with neural networks, and use numerical SDE solvers to generate samples. We show that this framework encapsulates previous approaches in score-based generative modeling and diffusion probabilistic modeling, allowing for new sampling procedures and new modeling capabilities. In particular, we introduce a predictor-corrector framework to correct errors in the evolution of the discretized reverse-time SDE. We also derive an equivalent neural ODE that samples from the same distribution as the SDE, but additionally enables exact likelihood computation, and improved sampling efficiency. In addition, we provide a new way to solve inverse problems with score-based models, as demonstrated with experiments on class-conditional generation, image inpainting, and colorization. Combined with multiple architectural improvements, we achieve record-breaking performance for unconditional image generation on CIFAR-10 with an Inception score of 9.89 and FID of 2.20, a competitive likelihood of 2.99 bits/dim, and demonstrate high fidelity generation of 1024 x 1024 images for the first time from a score-based generative model.*
+
+The original codebase can be found at [yang-song/score_sde_pytorch](https://github.com/yang-song/score_sde_pytorch).
+
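+A minimal unconditional sampling sketch, assuming the `google/ncsnpp-celebahq-256` checkpoint (a 256x256 NCSN++ model) and a CUDA device:
+
+```py
+from diffusers import ScoreSdeVePipeline
+
+# NCSN++ model trained with the variance exploding SDE (assumed checkpoint)
+pipe = ScoreSdeVePipeline.from_pretrained("google/ncsnpp-celebahq-256").to("cuda")
+
+# the VE sampler typically runs many predictor-corrector steps, so sampling is slow
+image = pipe(num_inference_steps=2000).images[0]
+image.save("sde_ve_sample.png")
+```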
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## ScoreSdeVePipeline
+[[autodoc]] ScoreSdeVePipeline
+ - all
+ - __call__
+
+## ImagePipelineOutput
+[[autodoc]] pipelines.ImagePipelineOutput
diff --git a/diffusers/docs/source/en/api/pipelines/self_attention_guidance.md b/diffusers/docs/source/en/api/pipelines/self_attention_guidance.md
new file mode 100644
index 0000000000000000000000000000000000000000..408e62daf988f1885f1d8df20baae39f838c157f
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/self_attention_guidance.md
@@ -0,0 +1,35 @@
+
+
+# Self-Attention Guidance
+
+[Improving Sample Quality of Diffusion Models Using Self-Attention Guidance](https://huggingface.co/papers/2210.00939) is by Susung Hong et al.
+
+The abstract from the paper is:
+
+*Denoising diffusion models (DDMs) have attracted attention for their exceptional generation quality and diversity. This success is largely attributed to the use of class- or text-conditional diffusion guidance methods, such as classifier and classifier-free guidance. In this paper, we present a more comprehensive perspective that goes beyond the traditional guidance methods. From this generalized perspective, we introduce novel condition- and training-free strategies to enhance the quality of generated images. As a simple solution, blur guidance improves the suitability of intermediate samples for their fine-scale information and structures, enabling diffusion models to generate higher quality samples with a moderate guidance scale. Improving upon this, Self-Attention Guidance (SAG) uses the intermediate self-attention maps of diffusion models to enhance their stability and efficacy. Specifically, SAG adversarially blurs only the regions that diffusion models attend to at each iteration and guides them accordingly. Our experimental results show that our SAG improves the performance of various diffusion models, including ADM, IDDPM, Stable Diffusion, and DiT. Moreover, combining SAG with conventional guidance methods leads to further improvement.*
+
+You can find additional information about Self-Attention Guidance on the [project page](https://ku-cvlab.github.io/Self-Attention-Guidance), [original codebase](https://github.com/KU-CVLAB/Self-Attention-Guidance), and try it out in a [demo](https://huggingface.co/spaces/susunghong/Self-Attention-Guidance) or [notebook](https://colab.research.google.com/github/SusungHong/Self-Attention-Guidance/blob/main/SAG_Stable.ipynb).
+
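+A minimal sketch of applying SAG on top of a Stable Diffusion checkpoint; the prompt and the `sag_scale` value are illustrative (a `sag_scale` of 0 disables self-attention guidance):
+
+```py
+import torch
+from diffusers import StableDiffusionSAGPipeline
+
+pipe = StableDiffusionSAGPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
+).to("cuda")
+
+# sag_scale controls the strength of self-attention guidance and can be
+# combined with the usual classifier-free guidance_scale
+image = pipe(
+    "a photo of an astronaut riding a horse on mars",
+    guidance_scale=7.5,
+    sag_scale=0.75,
+).images[0]
+```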
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## StableDiffusionSAGPipeline
+[[autodoc]] StableDiffusionSAGPipeline
+ - all
+ - __call__
+
+## StableDiffusionPipelineOutput
+[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
diff --git a/diffusers/docs/source/en/api/pipelines/semantic_stable_diffusion.md b/diffusers/docs/source/en/api/pipelines/semantic_stable_diffusion.md
new file mode 100644
index 0000000000000000000000000000000000000000..d7b393447cf8603475abd410009342d7cb8e7371
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/semantic_stable_diffusion.md
@@ -0,0 +1,35 @@
+
+
+# Semantic Guidance
+
+Semantic Guidance for Diffusion Models was proposed in [SEGA: Instructing Text-to-Image Models using Semantic Guidance](https://huggingface.co/papers/2301.12247) and provides strong semantic control over image generation.
+Small changes to the text prompt usually result in entirely different output images. SEGA, however, enables a variety of changes to the image that can be controlled easily and intuitively, while staying true to the original image composition.
+
+The abstract from the paper is:
+
+*Text-to-image diffusion models have recently received a lot of interest for their astonishing ability to produce high-fidelity images from text only. However, achieving one-shot generation that aligns with the user's intent is nearly impossible, yet small changes to the input prompt often result in very different images. This leaves the user with little semantic control. To put the user in control, we show how to interact with the diffusion process to flexibly steer it along semantic directions. This semantic guidance (SEGA) generalizes to any generative architecture using classifier-free guidance. More importantly, it allows for subtle and extensive edits, changes in composition and style, as well as optimizing the overall artistic conception. We demonstrate SEGA's effectiveness on both latent and pixel-based diffusion models such as Stable Diffusion, Paella, and DeepFloyd-IF using a variety of tasks, thus providing strong evidence for its versatility, flexibility, and improvements over existing methods.*
+
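+A minimal sketch of applying a single semantic edit direction on top of a Stable Diffusion checkpoint; the prompt, edit text, and parameter values are illustrative:
+
+```py
+import torch
+from diffusers import SemanticStableDiffusionPipeline
+
+pipe = SemanticStableDiffusionPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
+).to("cuda")
+
+out = pipe(
+    prompt="a photo of the face of a woman",
+    editing_prompt=["smiling, smile"],  # semantic direction to add
+    reverse_editing_direction=[False],  # set True to remove the concept instead
+    edit_warmup_steps=[10],             # diffusion steps before the edit kicks in
+    edit_guidance_scale=[6],
+    edit_threshold=[0.99],
+    edit_momentum_scale=0.5,
+    edit_mom_beta=0.6,
+    generator=torch.Generator("cuda").manual_seed(21),
+)
+image = out.images[0]
+```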
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## SemanticStableDiffusionPipeline
+[[autodoc]] SemanticStableDiffusionPipeline
+ - all
+ - __call__
+
+## SemanticStableDiffusionPipelineOutput
+[[autodoc]] pipelines.semantic_stable_diffusion.pipeline_output.SemanticStableDiffusionPipelineOutput
+ - all
diff --git a/diffusers/docs/source/en/api/pipelines/shap_e.md b/diffusers/docs/source/en/api/pipelines/shap_e.md
new file mode 100644
index 0000000000000000000000000000000000000000..bbf904afb5c8c37bd82c030725ebfbb2f92bf275
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/shap_e.md
@@ -0,0 +1,37 @@
+
+
+# Shap-E
+
+The Shap-E model was proposed in [Shap-E: Generating Conditional 3D Implicit Functions](https://huggingface.co/papers/2305.02463) by Alex Nichol and Heewoo Jun from [OpenAI](https://github.com/openai).
+
+The abstract from the paper is:
+
+*We present Shap-E, a conditional generative model for 3D assets. Unlike recent work on 3D generative models which produce a single output representation, Shap-E directly generates the parameters of implicit functions that can be rendered as both textured meshes and neural radiance fields. We train Shap-E in two stages: first, we train an encoder that deterministically maps 3D assets into the parameters of an implicit function; second, we train a conditional diffusion model on outputs of the encoder. When trained on a large dataset of paired 3D and text data, our resulting models are capable of generating complex and diverse 3D assets in a matter of seconds. When compared to Point-E, an explicit generative model over point clouds, Shap-E converges faster and reaches comparable or better sample quality despite modeling a higher-dimensional, multi-representation output space.*
+
+The original codebase can be found at [openai/shap-e](https://github.com/openai/shap-e).
+
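+A minimal text-to-3D sketch, assuming the `openai/shap-e` checkpoint; the rendered view frames are exported as a GIF:
+
+```py
+import torch
+from diffusers import ShapEPipeline
+from diffusers.utils import export_to_gif
+
+pipe = ShapEPipeline.from_pretrained("openai/shap-e", torch_dtype=torch.float16).to("cuda")
+
+# each generated implicit function is rendered into a list of view frames
+frames = pipe(
+    "a firecracker",
+    guidance_scale=15.0,
+    num_inference_steps=64,
+    frame_size=256,
+).images[0]
+export_to_gif(frames, "firecracker_3d.gif")
+```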
+
+
+See the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## ShapEPipeline
+[[autodoc]] ShapEPipeline
+ - all
+ - __call__
+
+## ShapEImg2ImgPipeline
+[[autodoc]] ShapEImg2ImgPipeline
+ - all
+ - __call__
+
+## ShapEPipelineOutput
+[[autodoc]] pipelines.shap_e.pipeline_shap_e.ShapEPipelineOutput
diff --git a/diffusers/docs/source/en/api/pipelines/spectrogram_diffusion.md b/diffusers/docs/source/en/api/pipelines/spectrogram_diffusion.md
new file mode 100644
index 0000000000000000000000000000000000000000..cc9ff3e45646315be303b75fc34ea61cc82609d3
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/spectrogram_diffusion.md
@@ -0,0 +1,37 @@
+
+
+# Spectrogram Diffusion
+
+[Spectrogram Diffusion](https://huggingface.co/papers/2206.05408) is by Curtis Hawthorne, Ian Simon, Adam Roberts, Neil Zeghidour, Josh Gardner, Ethan Manilow, and Jesse Engel.
+
+The abstract from the paper is:
+
+*An ideal music synthesizer should be both interactive and expressive, generating high-fidelity audio in realtime for arbitrary combinations of instruments and notes. Recent neural synthesizers have exhibited a tradeoff between domain-specific models that offer detailed control of only specific instruments, or raw waveform models that can train on any music but with minimal control and slow generation. In this work, we focus on a middle ground of neural synthesizers that can generate audio from MIDI sequences with arbitrary combinations of instruments in realtime. This enables training on a wide range of transcription datasets with a single model, which in turn offers note-level control of composition and instrumentation across a wide range of instruments. We use a simple two-stage process: MIDI to spectrograms with an encoder-decoder Transformer, then spectrograms to audio with a generative adversarial network (GAN) spectrogram inverter. We compare training the decoder as an autoregressive model and as a Denoising Diffusion Probabilistic Model (DDPM) and find that the DDPM approach is superior both qualitatively and as measured by audio reconstruction and Fréchet distance metrics. Given the interactivity and generality of this approach, we find this to be a promising first step towards interactive and expressive neural synthesis for arbitrary combinations of instruments and notes.*
+
+The original codebase can be found at [magenta/music-spectrogram-diffusion](https://github.com/magenta/music-spectrogram-diffusion).
+
+![img](https://storage.googleapis.com/music-synthesis-with-spectrogram-diffusion/architecture.png)
+
+As depicted above, the model takes a MIDI file as input and tokenizes it into a sequence of 5-second intervals. Each tokenized interval, together with positional encodings, is passed through the Note Encoder, and its representation is concatenated with the previous window's generated spectrogram representation obtained via the Context Encoder. For the initial 5-second window this context is set to zero. The resulting context is then used as conditioning to sample the denoised spectrogram for the current MIDI window; this spectrogram is both appended to the final output and used as the context for the next MIDI window. The process repeats until all MIDI inputs have been processed. Finally, a MelGAN decoder converts the (potentially long) spectrogram to audio, which is the final output of this pipeline.
+
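+A minimal sketch of rendering a MIDI file to audio; it assumes the `google/music-spectrogram-diffusion` checkpoint, that the optional `note_seq` dependency is installed for MIDI parsing, and that a MIDI file (illustrative filename) is available locally:
+
+```py
+from diffusers import MidiProcessor, SpectrogramDiffusionPipeline
+
+# requires `pip install note_seq`
+pipe = SpectrogramDiffusionPipeline.from_pretrained("google/music-spectrogram-diffusion").to("cuda")
+processor = MidiProcessor()
+
+# path to any local MIDI file (assumed filename)
+output = pipe(processor("beethoven_hammerklavier_2.mid"))
+audio = output.audios[0]
+```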
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## SpectrogramDiffusionPipeline
+[[autodoc]] SpectrogramDiffusionPipeline
+ - all
+ - __call__
+
+## AudioPipelineOutput
+[[autodoc]] pipelines.AudioPipelineOutput
diff --git a/diffusers/docs/source/en/api/pipelines/stable_diffusion/adapter.md b/diffusers/docs/source/en/api/pipelines/stable_diffusion/adapter.md
new file mode 100644
index 0000000000000000000000000000000000000000..0e2e7fd250fceb81662ea4862f03eb705411afc6
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/stable_diffusion/adapter.md
@@ -0,0 +1,259 @@
+
+
+# Text-to-Image Generation with Adapter Conditioning
+
+## Overview
+
+[T2I-Adapter: Learning Adapters to Dig out More Controllable Ability for Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.08453) is by Chong Mou, Xintao Wang, Liangbin Xie, Jian Zhang, Zhongang Qi, Ying Shan, Xiaohu Qie.
+
+Using the pretrained models we can provide control images (for example, a depth map) to control Stable Diffusion text-to-image generation so that it follows the structure of the depth image and fills in the details.
+
+The abstract of the paper is the following:
+
+*The incredible generative ability of large-scale text-to-image (T2I) models has demonstrated strong power of learning complex structures and meaningful semantics. However, relying solely on text prompts cannot fully take advantage of the knowledge learned by the model, especially when flexible and accurate controlling (e.g., color and structure) is needed. In this paper, we aim to "dig out" the capabilities that T2I models have implicitly learned, and then explicitly use them to control the generation more granularly. Specifically, we propose to learn simple and lightweight T2I-Adapters to align internal knowledge in T2I models with external control signals, while freezing the original large T2I models. In this way, we can train various adapters according to different conditions, achieving rich control and editing effects in the color and structure of the generation results. Further, the proposed T2I-Adapters have attractive properties of practical value, such as composability and generalization ability. Extensive experiments demonstrate that our T2I-Adapter has promising generation quality and a wide range of applications.*
+
+This model was contributed by the community contributor [HimariO](https://github.com/HimariO) ❤️ .
+
+## Available Pipelines:
+
+| Pipeline | Tasks | Demo |
+|---|---|:---:|
+| [StableDiffusionAdapterPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py) | *Text-to-Image Generation with T2I-Adapter Conditioning* | - |
+| [StableDiffusionXLAdapterPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py) | *Text-to-Image Generation with T2I-Adapter Conditioning on StableDiffusion-XL* | - |
+
+## Usage example with the base model of StableDiffusion-1.4/1.5
+
+In the following we give a simple example of how to use a *T2I-Adapter* checkpoint with Diffusers for inference based on StableDiffusion-1.4/1.5.
+All adapters use the same pipeline.
+
+ 1. Images are first converted into the appropriate *control image* format.
+ 2. The *control image* and *prompt* are passed to the [`StableDiffusionAdapterPipeline`].
+
+Let's have a look at a simple example using the [Color Adapter](https://huggingface.co/TencentARC/t2iadapter_color_sd14v1).
+
+```python
+from diffusers.utils import load_image, make_image_grid
+
+image = load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_ref.png")
+```
+
+![img](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_ref.png)
+
+
+Then we can create our color palette by simply resizing the image to 8x8 pixels and then scaling it back to the original size.
+
+```python
+from PIL import Image
+
+color_palette = image.resize((8, 8))
+color_palette = color_palette.resize((512, 512), resample=Image.Resampling.NEAREST)
+```
+
+Let's take a look at the processed image.
+
+![img](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_palette.png)
+
+
+Next, create the adapter pipeline
+
+```py
+import torch
+from diffusers import StableDiffusionAdapterPipeline, T2IAdapter
+
+adapter = T2IAdapter.from_pretrained("TencentARC/t2iadapter_color_sd14v1", torch_dtype=torch.float16)
+pipe = StableDiffusionAdapterPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ adapter=adapter,
+ torch_dtype=torch.float16,
+)
+pipe.to("cuda")
+```
+
+Finally, pass the prompt and control image to the pipeline
+
+```py
+# fix the random seed, so you will get the same result as the example
+generator = torch.Generator("cuda").manual_seed(7)
+
+out_image = pipe(
+ "At night, glowing cubes in front of the beach",
+ image=color_palette,
+ generator=generator,
+).images[0]
+make_image_grid([image, color_palette, out_image], rows=1, cols=3)
+```
+
+![img](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_output.png)
+
+## Usage example with the base model of StableDiffusion-XL
+
+In the following we give a simple example of how to use a *T2I-Adapter* checkpoint with Diffusers for inference based on StableDiffusion-XL.
+All adapters use the same pipeline.
+
+ 1. Images are first downloaded and converted into the appropriate *control image* format.
+ 2. The *control image* and *prompt* are passed to the [`StableDiffusionXLAdapterPipeline`].
+
+Let's have a look at a simple example using the [Sketch Adapter](https://huggingface.co/Adapter/t2iadapter/tree/main/sketch_sdxl_1.0).
+
+```python
+from diffusers.utils import load_image, make_image_grid
+
+sketch_image = load_image("https://huggingface.co/Adapter/t2iadapter/resolve/main/sketch.png").convert("L")
+```
+
+![img](https://huggingface.co/Adapter/t2iadapter/resolve/main/sketch.png)
+
+Then, create the adapter pipeline
+
+```py
+import torch
+from diffusers import (
+ T2IAdapter,
+ StableDiffusionXLAdapterPipeline,
+ DDPMScheduler
+)
+
+model_id = "stabilityai/stable-diffusion-xl-base-1.0"
+adapter = T2IAdapter.from_pretrained("Adapter/t2iadapter", subfolder="sketch_sdxl_1.0", torch_dtype=torch.float16, adapter_type="full_adapter_xl")
+scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
+
+pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
+ model_id, adapter=adapter, safety_checker=None, torch_dtype=torch.float16, variant="fp16", scheduler=scheduler
+)
+
+pipe.to("cuda")
+```
+
+Finally, pass the prompt and control image to the pipeline
+
+```py
+# fix the random seed, so you will get the same result as the example
+generator = torch.Generator().manual_seed(42)
+
+sketch_image_out = pipe(
+ prompt="a photo of a dog in real world, high quality",
+ negative_prompt="extra digit, fewer digits, cropped, worst quality, low quality",
+ image=sketch_image,
+ generator=generator,
+ guidance_scale=7.5
+).images[0]
+make_image_grid([sketch_image, sketch_image_out], rows=1, cols=2)
+```
+
+![img](https://huggingface.co/Adapter/t2iadapter/resolve/main/sketch_output.png)
+
+## Available checkpoints
+
+Non-diffusers checkpoints can be found under [TencentARC/T2I-Adapter](https://huggingface.co/TencentARC/T2I-Adapter/tree/main/models).
+
+### T2I-Adapter with Stable Diffusion 1.4
+
+| Model Name | Control Image Overview | Control Image Example | Generated Image Example |
+|---|---|---|---|
+|[TencentARC/t2iadapter_color_sd14v1](https://huggingface.co/TencentARC/t2iadapter_color_sd14v1) *Trained with spatial color palette* | An image with 8x8 color palette. | | |
+|[TencentARC/t2iadapter_canny_sd14v1](https://huggingface.co/TencentARC/t2iadapter_canny_sd14v1) *Trained with canny edge detection* | A monochrome image with white edges on a black background. | | |
+|[TencentARC/t2iadapter_sketch_sd14v1](https://huggingface.co/TencentARC/t2iadapter_sketch_sd14v1) *Trained with [PidiNet](https://github.com/zhuoinoulu/pidinet) edge detection* | A hand-drawn monochrome image with white outlines on a black background. | | |
+|[TencentARC/t2iadapter_depth_sd14v1](https://huggingface.co/TencentARC/t2iadapter_depth_sd14v1) *Trained with Midas depth estimation* | A grayscale image with black representing deep areas and white representing shallow areas. | | |
+|[TencentARC/t2iadapter_openpose_sd14v1](https://huggingface.co/TencentARC/t2iadapter_openpose_sd14v1) *Trained with OpenPose bone image* | An [OpenPose bone](https://github.com/CMU-Perceptual-Computing-Lab/openpose) image. | | |
+|[TencentARC/t2iadapter_keypose_sd14v1](https://huggingface.co/TencentARC/t2iadapter_keypose_sd14v1) *Trained with mmpose skeleton image* | An [mmpose skeleton](https://github.com/open-mmlab/mmpose) image. | | |
+|[TencentARC/t2iadapter_seg_sd14v1](https://huggingface.co/TencentARC/t2iadapter_seg_sd14v1) *Trained with semantic segmentation* | A [custom](https://github.com/TencentARC/T2I-Adapter/discussions/25) segmentation protocol image. | | |
+|[TencentARC/t2iadapter_canny_sd15v2](https://huggingface.co/TencentARC/t2iadapter_canny_sd15v2) | | | |
+|[TencentARC/t2iadapter_depth_sd15v2](https://huggingface.co/TencentARC/t2iadapter_depth_sd15v2) | | | |
+|[TencentARC/t2iadapter_sketch_sd15v2](https://huggingface.co/TencentARC/t2iadapter_sketch_sd15v2) | | | |
+|[TencentARC/t2iadapter_zoedepth_sd15v1](https://huggingface.co/TencentARC/t2iadapter_zoedepth_sd15v1) | | | |
+|[Adapter/t2iadapter, subfolder='sketch_sdxl_1.0'](https://huggingface.co/Adapter/t2iadapter/tree/main/sketch_sdxl_1.0) | | | |
+|[Adapter/t2iadapter, subfolder='canny_sdxl_1.0'](https://huggingface.co/Adapter/t2iadapter/tree/main/canny_sdxl_1.0) | | | |
+|[Adapter/t2iadapter, subfolder='openpose_sdxl_1.0'](https://huggingface.co/Adapter/t2iadapter/tree/main/openpose_sdxl_1.0) | | | |
+
+## Combining multiple adapters
+
+[`MultiAdapter`] can be used for applying multiple conditionings at once.
+
+Here we use the keypose adapter for the character posture and the depth adapter for creating the scene.
+
+```py
+from diffusers.utils import load_image, make_image_grid
+
+cond_keypose = load_image(
+ "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/keypose_sample_input.png"
+)
+cond_depth = load_image(
+ "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/depth_sample_input.png"
+)
+cond = [cond_keypose, cond_depth]
+
+prompt = ["A man walking in an office room with a nice view"]
+```
+
+The two control images look as such:
+
+![img](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/keypose_sample_input.png)
+![img](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/depth_sample_input.png)
+
+
+`MultiAdapter` combines keypose and depth adapters.
+
+`adapter_conditioning_scale` balances the relative influence of the different adapters.
+
+```py
+import torch
+from diffusers import StableDiffusionAdapterPipeline, MultiAdapter, T2IAdapter
+
+adapters = MultiAdapter(
+ [
+ T2IAdapter.from_pretrained("TencentARC/t2iadapter_keypose_sd14v1"),
+ T2IAdapter.from_pretrained("TencentARC/t2iadapter_depth_sd14v1"),
+ ]
+)
+adapters = adapters.to(torch.float16)
+
+pipe = StableDiffusionAdapterPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ torch_dtype=torch.float16,
+ adapter=adapters,
+).to("cuda")
+
+image = pipe(prompt, cond, adapter_conditioning_scale=[0.8, 0.8]).images[0]
+make_image_grid([cond_keypose, cond_depth, image], rows=1, cols=3)
+```
+
+![img](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/keypose_depth_sample_output.png)
+
+
+## T2I-Adapter vs ControlNet
+
+T2I-Adapter is similar to [ControlNet](https://huggingface.co/docs/diffusers/main/en/api/pipelines/controlnet).
+T2I-Adapter uses a smaller auxiliary network which is only run once for the entire diffusion process.
+However, T2I-Adapter performs slightly worse than ControlNet.
+
+## StableDiffusionAdapterPipeline
+[[autodoc]] StableDiffusionAdapterPipeline
+ - all
+ - __call__
+ - enable_attention_slicing
+ - disable_attention_slicing
+ - enable_vae_slicing
+ - disable_vae_slicing
+ - enable_xformers_memory_efficient_attention
+ - disable_xformers_memory_efficient_attention
+
+## StableDiffusionXLAdapterPipeline
+[[autodoc]] StableDiffusionXLAdapterPipeline
+ - all
+ - __call__
+ - enable_attention_slicing
+ - disable_attention_slicing
+ - enable_vae_slicing
+ - disable_vae_slicing
+ - enable_xformers_memory_efficient_attention
+ - disable_xformers_memory_efficient_attention
diff --git a/diffusers/docs/source/en/api/pipelines/stable_diffusion/depth2img.md b/diffusers/docs/source/en/api/pipelines/stable_diffusion/depth2img.md
new file mode 100644
index 0000000000000000000000000000000000000000..f7c8f2de942018d1ce9ce29e60265734e7d81ddb
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/stable_diffusion/depth2img.md
@@ -0,0 +1,40 @@
+
+
+# Depth-to-image
+
+The Stable Diffusion model can also infer depth based on an image using [MiDaS](https://github.com/isl-org/MiDaS). This allows you to pass a text prompt and an initial image to condition the generation of new images as well as a `depth_map` to preserve the image structure.
+
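+A minimal sketch with the `stabilityai/stable-diffusion-2-depth` checkpoint (the same one used in the Stable Diffusion 2 examples); when no `depth_map` is passed, the pipeline estimates one from the initial image with MiDaS:
+
+```py
+import torch
+from diffusers import StableDiffusionDepth2ImgPipeline
+from diffusers.utils import load_image
+
+pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-2-depth", torch_dtype=torch.float16
+).to("cuda")
+
+init_image = load_image("http://images.cocodataset.org/val2017/000000039769.jpg")
+
+# depth_map is optional; it is estimated from init_image when omitted
+image = pipe(
+    prompt="two tigers",
+    image=init_image,
+    negative_prompt="bad, deformed, ugly, bad anatomy",
+    strength=0.7,
+).images[0]
+```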
+
+
+Make sure to check out the Stable Diffusion [Tips](overview#tips) section to learn how to explore the tradeoff between scheduler speed and quality, and how to reuse pipeline components efficiently!
+
+If you're interested in using one of the official checkpoints for a task, explore the [CompVis](https://huggingface.co/CompVis), [Runway](https://huggingface.co/runwayml), and [Stability AI](https://huggingface.co/stabilityai) Hub organizations!
+
+
+
+## StableDiffusionDepth2ImgPipeline
+
+[[autodoc]] StableDiffusionDepth2ImgPipeline
+ - all
+ - __call__
+ - enable_attention_slicing
+ - disable_attention_slicing
+ - enable_xformers_memory_efficient_attention
+ - disable_xformers_memory_efficient_attention
+ - load_textual_inversion
+ - load_lora_weights
+ - save_lora_weights
+
+## StableDiffusionPipelineOutput
+
+[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
diff --git a/diffusers/docs/source/en/api/pipelines/stable_diffusion/gligen.md b/diffusers/docs/source/en/api/pipelines/stable_diffusion/gligen.md
new file mode 100644
index 0000000000000000000000000000000000000000..d981e892c053928c297b90c4309b10de6c10f91c
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/stable_diffusion/gligen.md
@@ -0,0 +1,59 @@
+
+
+# GLIGEN (Grounded Language-to-Image Generation)
+
+The GLIGEN model was created by researchers and engineers from [University of Wisconsin-Madison, Columbia University, and Microsoft](https://github.com/gligen/GLIGEN). The [`StableDiffusionGLIGENPipeline`] and [`StableDiffusionGLIGENTextImagePipeline`] can generate photorealistic images conditioned on grounding inputs. Along with text and bounding boxes with [`StableDiffusionGLIGENPipeline`], if input images are given, [`StableDiffusionGLIGENTextImagePipeline`] can insert objects described by text at the region defined by bounding boxes. Otherwise, it'll generate an image described by the caption/prompt and insert objects described by text at the region defined by bounding boxes. It's trained on COCO2014D and COCO2014CD datasets, and the model uses a frozen CLIP ViT-L/14 text encoder to condition itself on grounding inputs.
+
+The abstract from the [paper](https://huggingface.co/papers/2301.07093) is:
+
+*Large-scale text-to-image diffusion models have made amazing advances. However, the status quo is to use text input alone, which can impede controllability. In this work, we propose GLIGEN, Grounded-Language-to-Image Generation, a novel approach that builds upon and extends the functionality of existing pre-trained text-to-image diffusion models by enabling them to also be conditioned on grounding inputs. To preserve the vast concept knowledge of the pre-trained model, we freeze all of its weights and inject the grounding information into new trainable layers via a gated mechanism. Our model achieves open-world grounded text2img generation with caption and bounding box condition inputs, and the grounding ability generalizes well to novel spatial configurations and concepts. GLIGEN’s zeroshot performance on COCO and LVIS outperforms existing supervised layout-to-image baselines by a large margin.*
+
+
+
+Make sure to check out the Stable Diffusion [Tips](https://huggingface.co/docs/diffusers/en/api/pipelines/stable_diffusion/overview#tips) section to learn how to explore the tradeoff between scheduler speed and quality and how to reuse pipeline components efficiently!
+
+If you want to use one of the official checkpoints for a task, explore the [gligen](https://huggingface.co/gligen) Hub organization!
+
+
+
+[`StableDiffusionGLIGENPipeline`] was contributed by [Nikhil Gajendrakumar](https://github.com/nikhil-masterful) and [`StableDiffusionGLIGENTextImagePipeline`] was contributed by [Nguyễn Công Tú Anh](https://github.com/tuanh123789).
+
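+A minimal text-grounded generation sketch; the checkpoint name, phrases, and normalized `[xmin, ymin, xmax, ymax]` bounding boxes are illustrative assumptions:
+
+```py
+import torch
+from diffusers import StableDiffusionGLIGENPipeline
+
+# assumed grounded-generation checkpoint
+pipe = StableDiffusionGLIGENPipeline.from_pretrained(
+    "masterful/gligen-1-4-generation-text-box", torch_dtype=torch.float16
+).to("cuda")
+
+prompt = "a waterfall and a modern high speed train in a beautiful forest with fall foliage"
+boxes = [[0.1387, 0.2051, 0.4277, 0.7090], [0.4980, 0.4355, 0.8516, 0.7266]]
+phrases = ["a waterfall", "a modern high speed train"]
+
+image = pipe(
+    prompt=prompt,
+    gligen_phrases=phrases,           # text describing each grounded object
+    gligen_boxes=boxes,               # one bounding box per phrase
+    gligen_scheduled_sampling_beta=1,
+    num_inference_steps=50,
+).images[0]
+```
+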
+## StableDiffusionGLIGENPipeline
+
+[[autodoc]] StableDiffusionGLIGENPipeline
+ - all
+ - __call__
+ - enable_vae_slicing
+ - disable_vae_slicing
+ - enable_vae_tiling
+ - disable_vae_tiling
+ - enable_model_cpu_offload
+ - prepare_latents
+ - enable_fuser
+
+## StableDiffusionGLIGENTextImagePipeline
+
+[[autodoc]] StableDiffusionGLIGENTextImagePipeline
+ - all
+ - __call__
+ - enable_vae_slicing
+ - disable_vae_slicing
+ - enable_vae_tiling
+ - disable_vae_tiling
+ - enable_model_cpu_offload
+ - prepare_latents
+ - enable_fuser
+
+## StableDiffusionPipelineOutput
+
+[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
diff --git a/diffusers/docs/source/en/api/pipelines/stable_diffusion/image_variation.md b/diffusers/docs/source/en/api/pipelines/stable_diffusion/image_variation.md
new file mode 100644
index 0000000000000000000000000000000000000000..4895ababf5bd19fdd02578647ecec6f4885423f5
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/stable_diffusion/image_variation.md
@@ -0,0 +1,37 @@
+
+
+# Image variation
+
+The Stable Diffusion model can also generate variations from an input image. It uses a fine-tuned version of a Stable Diffusion model by [Justin Pinkney](https://www.justinpinkney.com/) from [Lambda](https://lambdalabs.com/).
+
+The original codebase can be found at [LambdaLabsML/lambda-diffusers](https://github.com/LambdaLabsML/lambda-diffusers#stable-diffusion-image-variations) and additional official checkpoints for image variation can be found at [lambdalabs/sd-image-variations-diffusers](https://huggingface.co/lambdalabs/sd-image-variations-diffusers).
+
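+A minimal sketch that generates variations of an input image; the input URL is borrowed from the depth-to-image example elsewhere in these docs, and the `v2.0` revision is an assumption based on the model card:
+
+```py
+from diffusers import StableDiffusionImageVariationPipeline
+from diffusers.utils import load_image
+
+pipe = StableDiffusionImageVariationPipeline.from_pretrained(
+    "lambdalabs/sd-image-variations-diffusers", revision="v2.0"
+).to("cuda")
+
+init_image = load_image("http://images.cocodataset.org/val2017/000000039769.jpg")
+
+# no text prompt: the CLIP image embedding of init_image conditions the generation
+out = pipe(init_image, num_images_per_prompt=3, guidance_scale=15)
+out.images[0].save("variation.jpg")
+```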
+
+
+Make sure to check out the Stable Diffusion [Tips](./overview#tips) section to learn how to explore the tradeoff between scheduler speed and quality, and how to reuse pipeline components efficiently!
+
+
+
+## StableDiffusionImageVariationPipeline
+
+[[autodoc]] StableDiffusionImageVariationPipeline
+ - all
+ - __call__
+ - enable_attention_slicing
+ - disable_attention_slicing
+ - enable_xformers_memory_efficient_attention
+ - disable_xformers_memory_efficient_attention
+
+## StableDiffusionPipelineOutput
+
+[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
diff --git a/diffusers/docs/source/en/api/pipelines/stable_diffusion/img2img.md b/diffusers/docs/source/en/api/pipelines/stable_diffusion/img2img.md
new file mode 100644
index 0000000000000000000000000000000000000000..b3de84c0f4eb72f3fb2871e5d78d80a812de548f
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/stable_diffusion/img2img.md
@@ -0,0 +1,55 @@
+
+
+# Image-to-image
+
+The Stable Diffusion model can also be applied to image-to-image generation by passing a text prompt and an initial image to condition the generation of new images.
+
+The [`StableDiffusionImg2ImgPipeline`] uses the diffusion-denoising mechanism proposed in [SDEdit: Guided Image Synthesis and Editing with Stochastic Differential Equations](https://huggingface.co/papers/2108.01073) by Chenlin Meng, Yutong He, Yang Song, Jiaming Song, Jiajun Wu, Jun-Yan Zhu, Stefano Ermon.
+
+The abstract from the paper is:
+
+*Guided image synthesis enables everyday users to create and edit photo-realistic images with minimum effort. The key challenge is balancing faithfulness to the user input (e.g., hand-drawn colored strokes) and realism of the synthesized image. Existing GAN-based methods attempt to achieve such balance using either conditional GANs or GAN inversions, which are challenging and often require additional training data or loss functions for individual applications. To address these issues, we introduce a new image synthesis and editing method, Stochastic Differential Editing (SDEdit), based on a diffusion model generative prior, which synthesizes realistic images by iteratively denoising through a stochastic differential equation (SDE). Given an input image with user guide of any type, SDEdit first adds noise to the input, then subsequently denoises the resulting image through the SDE prior to increase its realism. SDEdit does not require task-specific training or inversions and can naturally achieve the balance between realism and faithfulness. SDEdit significantly outperforms state-of-the-art GAN-based methods by up to 98.09% on realism and 91.72% on overall satisfaction scores, according to a human perception study, on multiple tasks, including stroke-based image synthesis and editing as well as image compositing.*
+
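+A minimal SDEdit-style sketch; the sketch-input URL is an illustrative assumption, and `strength` controls how much noise is added to the initial image (higher values deviate further from it):
+
+```py
+import torch
+from diffusers import StableDiffusionImg2ImgPipeline
+from diffusers.utils import load_image, make_image_grid
+
+pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
+).to("cuda")
+
+# rough sketch used as the initial image (assumed URL)
+url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+init_image = load_image(url).resize((768, 512))
+
+prompt = "A fantasy landscape, trending on artstation"
+image = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images[0]
+make_image_grid([init_image, image], rows=1, cols=2)
+```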
+
+
+Make sure to check out the Stable Diffusion [Tips](overview#tips) section to learn how to explore the tradeoff between scheduler speed and quality, and how to reuse pipeline components efficiently!
+
+
+
+## StableDiffusionImg2ImgPipeline
+
+[[autodoc]] StableDiffusionImg2ImgPipeline
+ - all
+ - __call__
+ - enable_attention_slicing
+ - disable_attention_slicing
+ - enable_xformers_memory_efficient_attention
+ - disable_xformers_memory_efficient_attention
+ - load_textual_inversion
+ - from_single_file
+ - load_lora_weights
+ - save_lora_weights
+
+## StableDiffusionPipelineOutput
+
+[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
+
+## FlaxStableDiffusionImg2ImgPipeline
+
+[[autodoc]] FlaxStableDiffusionImg2ImgPipeline
+ - all
+ - __call__
+
+## FlaxStableDiffusionPipelineOutput
+
+[[autodoc]] pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput
diff --git a/diffusers/docs/source/en/api/pipelines/stable_diffusion/inpaint.md b/diffusers/docs/source/en/api/pipelines/stable_diffusion/inpaint.md
new file mode 100644
index 0000000000000000000000000000000000000000..362ad325ac8580136e85bd9a528e80e88bedde54
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/stable_diffusion/inpaint.md
@@ -0,0 +1,57 @@
+
+
+# Inpainting
+
+The Stable Diffusion model can also be applied to inpainting, which lets you edit specific parts of an image by providing a mask and a text prompt.
+
+## Tips
+
+It is recommended to use this pipeline with checkpoints that have been specifically fine-tuned for inpainting, such
+as [runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting). Default
+text-to-image Stable Diffusion checkpoints, such as
+[runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5), are also compatible but they might be less performant.
+
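+A minimal sketch with the inpainting checkpoint recommended above; the image and mask URLs are the same ones used in the Stable Diffusion 2 inpainting example:
+
+```py
+import torch
+from diffusers import StableDiffusionInpaintPipeline
+from diffusers.utils import load_image, make_image_grid
+
+pipe = StableDiffusionInpaintPipeline.from_pretrained(
+    "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
+).to("cuda")
+
+img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+init_image = load_image(img_url).resize((512, 512))
+mask_image = load_image(mask_url).resize((512, 512))
+
+# white pixels in the mask are repainted, black pixels are kept
+prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
+image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
+make_image_grid([init_image, mask_image, image], rows=1, cols=3)
+```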
+
+
+Make sure to check out the Stable Diffusion [Tips](overview#tips) section to learn how to explore the tradeoff between scheduler speed and quality, and how to reuse pipeline components efficiently!
+
+If you're interested in using one of the official checkpoints for a task, explore the [CompVis](https://huggingface.co/CompVis), [Runway](https://huggingface.co/runwayml), and [Stability AI](https://huggingface.co/stabilityai) Hub organizations!
+
+
+
+## StableDiffusionInpaintPipeline
+
+[[autodoc]] StableDiffusionInpaintPipeline
+ - all
+ - __call__
+ - enable_attention_slicing
+ - disable_attention_slicing
+ - enable_xformers_memory_efficient_attention
+ - disable_xformers_memory_efficient_attention
+ - load_textual_inversion
+ - load_lora_weights
+ - save_lora_weights
+
+## StableDiffusionPipelineOutput
+
+[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
+
+## FlaxStableDiffusionInpaintPipeline
+
+[[autodoc]] FlaxStableDiffusionInpaintPipeline
+ - all
+ - __call__
+
+## FlaxStableDiffusionPipelineOutput
+
+[[autodoc]] pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput
diff --git a/diffusers/docs/source/en/api/pipelines/stable_diffusion/latent_upscale.md b/diffusers/docs/source/en/api/pipelines/stable_diffusion/latent_upscale.md
new file mode 100644
index 0000000000000000000000000000000000000000..bdb113f6e465afea891c214ab239ab57661fb167
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/stable_diffusion/latent_upscale.md
@@ -0,0 +1,38 @@
+
+
+# Latent upscaler
+
+The Stable Diffusion latent upscaler model was created by [Katherine Crowson](https://github.com/crowsonkb/k-diffusion) in collaboration with [Stability AI](https://stability.ai/). It is used to enhance the output image resolution by a factor of 2 (see this demo [notebook](https://colab.research.google.com/drive/1o1qYJcFeywzCIdkfKJy7cTpgZTCM2EI4) for a demonstration of the original implementation).
+
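+A minimal sketch that chains a text-to-image pipeline with the latent upscaler (checkpoint names assumed); keeping the first pipeline's output in latent space via `output_type="latent"` lets the upscaler consume the latents directly:
+
+```py
+import torch
+from diffusers import StableDiffusionLatentUpscalePipeline, StableDiffusionPipeline
+
+pipeline = StableDiffusionPipeline.from_pretrained(
+    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
+).to("cuda")
+upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(
+    "stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16
+).to("cuda")
+
+prompt = "a photo of an astronaut, high resolution, unreal engine, ultra realistic"
+generator = torch.Generator("cuda").manual_seed(33)
+
+# stay in latent space so the upscaler can work on the latents directly
+low_res_latents = pipeline(prompt, generator=generator, output_type="latent").images
+
+upscaled_image = upscaler(
+    prompt=prompt,
+    image=low_res_latents,
+    num_inference_steps=20,
+    guidance_scale=0,
+    generator=generator,
+).images[0]
+```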
+
+
+Make sure to check out the Stable Diffusion [Tips](overview#tips) section to learn how to explore the tradeoff between scheduler speed and quality, and how to reuse pipeline components efficiently!
+
+If you're interested in using one of the official checkpoints for a task, explore the [CompVis](https://huggingface.co/CompVis), [Runway](https://huggingface.co/runwayml), and [Stability AI](https://huggingface.co/stabilityai) Hub organizations!
+
+
+
+## StableDiffusionLatentUpscalePipeline
+
+[[autodoc]] StableDiffusionLatentUpscalePipeline
+ - all
+ - __call__
+ - enable_sequential_cpu_offload
+ - enable_attention_slicing
+ - disable_attention_slicing
+ - enable_xformers_memory_efficient_attention
+ - disable_xformers_memory_efficient_attention
+
+## StableDiffusionPipelineOutput
+
+[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
diff --git a/diffusers/docs/source/en/api/pipelines/stable_diffusion/ldm3d_diffusion.md b/diffusers/docs/source/en/api/pipelines/stable_diffusion/ldm3d_diffusion.md
new file mode 100644
index 0000000000000000000000000000000000000000..2e489c0eeb7c3fcd8c72ac62a4d917e98eeeecbb
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/stable_diffusion/ldm3d_diffusion.md
@@ -0,0 +1,37 @@
+
+
+# Text-to-(RGB, depth)
+
+LDM3D was proposed in [LDM3D: Latent Diffusion Model for 3D](https://huggingface.co/papers/2305.10853) by Gabriela Ben Melech Stan, Diana Wofk, Scottie Fox, Alex Redden, Will Saxton, Jean Yu, Estelle Aflalo, Shao-Yen Tseng, Fabio Nonato, Matthias Muller, and Vasudev Lal. Unlike existing text-to-image diffusion models such as [Stable Diffusion](./overview), which only generate an image, LDM3D generates both an image and a depth map from a given text prompt. With almost the same number of parameters, LDM3D manages to create a latent space that can compress both the RGB images and the depth maps.
+
+The abstract from the paper is:
+
+*This research paper proposes a Latent Diffusion Model for 3D (LDM3D) that generates both image and depth map data from a given text prompt, allowing users to generate RGBD images from text prompts. The LDM3D model is fine-tuned on a dataset of tuples containing an RGB image, depth map and caption, and validated through extensive experiments. We also develop an application called DepthFusion, which uses the generated RGB images and depth maps to create immersive and interactive 360-degree-view experiences using TouchDesigner. This technology has the potential to transform a wide range of industries, from entertainment and gaming to architecture and design. Overall, this paper presents a significant contribution to the field of generative AI and computer vision, and showcases the potential of LDM3D and DepthFusion to revolutionize content creation and digital experiences. A short video summarizing the approach can be found at [this url](https://t.ly/tdi2).*
+
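+A minimal sketch, assuming the `Intel/ldm3d-4c` checkpoint; the pipeline output carries parallel `rgb` and `depth` lists:
+
+```py
+import torch
+from diffusers import StableDiffusionLDM3DPipeline
+
+pipe = StableDiffusionLDM3DPipeline.from_pretrained(
+    "Intel/ldm3d-4c", torch_dtype=torch.float16
+).to("cuda")
+
+output = pipe("A picture of some lemons on a table")
+rgb_image, depth_image = output.rgb[0], output.depth[0]
+rgb_image.save("lemons_rgb.jpg")
+depth_image.save("lemons_depth.png")
+```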
+
+
+Make sure to check out the Stable Diffusion [Tips](overview#tips) section to learn how to explore the tradeoff between scheduler speed and quality, and how to reuse pipeline components efficiently!
+
+
+
+## StableDiffusionLDM3DPipeline
+
+[[autodoc]] StableDiffusionLDM3DPipeline
+ - all
+ - __call__
+
+## LDM3DPipelineOutput
+
+[[autodoc]] pipelines.stable_diffusion.pipeline_stable_diffusion_ldm3d.LDM3DPipelineOutput
+ - all
+ - __call__
diff --git a/diffusers/docs/source/en/api/pipelines/stable_diffusion/overview.md b/diffusers/docs/source/en/api/pipelines/stable_diffusion/overview.md
new file mode 100644
index 0000000000000000000000000000000000000000..fb4f2739dd2ba0200dd6b4fed871fc3782e017ca
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/stable_diffusion/overview.md
@@ -0,0 +1,168 @@
+
+
+# Stable Diffusion pipelines
+
+Stable Diffusion is a text-to-image latent diffusion model created by the researchers and engineers from [CompVis](https://github.com/CompVis), [Stability AI](https://stability.ai/) and [LAION](https://laion.ai/). Latent diffusion applies the diffusion process over a lower dimensional latent space to reduce memory and compute complexity. This specific type of diffusion model was proposed in [High-Resolution Image Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) by Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, Björn Ommer.
+
+Stable Diffusion is trained on 512x512 images from a subset of the LAION-5B dataset. This model uses a frozen CLIP ViT-L/14 text encoder to condition the model on text prompts. With its 860M UNet and 123M text encoder, the model is relatively lightweight and can run on consumer GPUs.
+
+For more technical details about how Stable Diffusion works and how it differs from the base latent diffusion model, take a look at the Stability AI [announcement](https://stability.ai/blog/stable-diffusion-announcement) and our own [blog post](https://huggingface.co/blog/stable_diffusion#how-does-stable-diffusion-work).
+
+You can find the original codebase for Stable Diffusion v1.0 at [CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion) and Stable Diffusion v2.0 at [Stability-AI/stablediffusion](https://github.com/Stability-AI/stablediffusion) as well as their original scripts for various tasks. Additional official checkpoints for the different Stable Diffusion versions and tasks can be found on the [CompVis](https://huggingface.co/CompVis), [Runway](https://huggingface.co/runwayml), and [Stability AI](https://huggingface.co/stabilityai) Hub organizations. Explore these organizations to find the best checkpoint for your use-case!
+
+The pipelines in this section cover the different tasks Stable Diffusion supports, such as text-to-image, image-to-image, inpainting, depth-to-image, image variation, and upscaling.
+
+
+
+## Tips
+
+To help you get the most out of the Stable Diffusion pipelines, here are a few tips for improving performance and usability. These tips are applicable to all Stable Diffusion pipelines.
+
+### Explore tradeoff between speed and quality
+
+[`StableDiffusionPipeline`] uses the [`PNDMScheduler`] by default, but 🤗 Diffusers provides many other schedulers (some of which are faster or output better quality) that are compatible. For example, if you want to use the [`EulerDiscreteScheduler`] instead of the default:
+
+```py
+from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
+
+pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
+pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
+
+# or
+euler_scheduler = EulerDiscreteScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
+pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=euler_scheduler)
+```
+
+### Reuse pipeline components to save memory
+
+To save memory and use the same components across multiple pipelines, use the `.components` method to avoid loading weights into RAM more than once.
+
+```py
+from diffusers import (
+ StableDiffusionPipeline,
+ StableDiffusionImg2ImgPipeline,
+ StableDiffusionInpaintPipeline,
+)
+
+text2img = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
+img2img = StableDiffusionImg2ImgPipeline(**text2img.components)
+inpaint = StableDiffusionInpaintPipeline(**text2img.components)
+
+# now you can use text2img(...), img2img(...), inpaint(...) just like the call methods of each respective pipeline
+```
diff --git a/diffusers/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_2.md b/diffusers/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_2.md
new file mode 100644
index 0000000000000000000000000000000000000000..75f36ba335a63f6702d91fb2731f97144ba41046
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_2.md
@@ -0,0 +1,125 @@
+
+
+# Stable Diffusion 2
+
+Stable Diffusion 2 is a text-to-image _latent diffusion_ model built upon the work of the original [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release), and it was led by Robin Rombach and Katherine Crowson from [Stability AI](https://stability.ai/) and [LAION](https://laion.ai/).
+
+*The Stable Diffusion 2.0 release includes robust text-to-image models trained using a brand new text encoder (OpenCLIP), developed by LAION with support from Stability AI, which greatly improves the quality of the generated images compared to earlier V1 releases. The text-to-image models in this release can generate images with default resolutions of both 512x512 pixels and 768x768 pixels.
+These models are trained on an aesthetic subset of the [LAION-5B dataset](https://laion.ai/blog/laion-5b/) created by the DeepFloyd team at Stability AI, which is then further filtered to remove adult content using [LAION’s NSFW filter](https://openreview.net/forum?id=M3Y74vmsMcY).*
+
+For more details about how Stable Diffusion 2 works and how it differs from the original Stable Diffusion, please refer to the official [announcement post](https://stability.ai/blog/stable-diffusion-v2-release).
+
+The architecture of Stable Diffusion 2 is more or less identical to the original [Stable Diffusion model](./text2img), so check out its API documentation for how to use Stable Diffusion 2. We recommend using the [`DPMSolverMultistepScheduler`] as it gives a reasonable speed/quality trade-off and can be run with as little as 20 steps.
+
+Stable Diffusion 2 is available for tasks like text-to-image, inpainting, super-resolution, and depth-to-image:
+
+| Task | Repository |
+|-------------------------|---------------------------------------------------------------------------------------------------------------|
+| text-to-image (512x512) | [stabilityai/stable-diffusion-2-base](https://huggingface.co/stabilityai/stable-diffusion-2-base) |
+| text-to-image (768x768) | [stabilityai/stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) |
+| inpainting | [stabilityai/stable-diffusion-2-inpainting](https://huggingface.co/stabilityai/stable-diffusion-2-inpainting) |
+| super-resolution        | [stabilityai/stable-diffusion-x4-upscaler](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler)   |
+| depth-to-image | [stabilityai/stable-diffusion-2-depth](https://huggingface.co/stabilityai/stable-diffusion-2-depth) |
+
+Here are some examples for how to use Stable Diffusion 2 for each task:
+
+
+
+Make sure to check out the Stable Diffusion [Tips](overview#tips) section to learn how to explore the tradeoff between scheduler speed and quality, and how to reuse pipeline components efficiently!
+
+If you're interested in using one of the official checkpoints for a task, explore the [CompVis](https://huggingface.co/CompVis), [Runway](https://huggingface.co/runwayml), and [Stability AI](https://huggingface.co/stabilityai) Hub organizations!
+
+
+
+## Text-to-image
+
+```py
+from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
+import torch
+
+repo_id = "stabilityai/stable-diffusion-2-base"
+pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, revision="fp16")
+
+pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
+pipe = pipe.to("cuda")
+
+prompt = "High quality photo of an astronaut riding a horse in space"
+image = pipe(prompt, num_inference_steps=25).images[0]
+image
+```
+
+## Inpainting
+
+```py
+import torch
+from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
+from diffusers.utils import load_image, make_image_grid
+
+img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+init_image = load_image(img_url).resize((512, 512))
+mask_image = load_image(mask_url).resize((512, 512))
+
+repo_id = "stabilityai/stable-diffusion-2-inpainting"
+pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, revision="fp16")
+
+pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
+pipe = pipe.to("cuda")
+
+prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
+image = pipe(prompt=prompt, image=init_image, mask_image=mask_image, num_inference_steps=25).images[0]
+make_image_grid([init_image, mask_image, image], rows=1, cols=3)
+```
+
+## Super-resolution
+
+```py
+from diffusers import StableDiffusionUpscalePipeline
+from diffusers.utils import load_image, make_image_grid
+import torch
+
+# load model and scheduler
+model_id = "stabilityai/stable-diffusion-x4-upscaler"
+pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_id, torch_dtype=torch.float16)
+pipeline = pipeline.to("cuda")
+
+# let's download an image
+url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale/low_res_cat.png"
+low_res_img = load_image(url)
+low_res_img = low_res_img.resize((128, 128))
+prompt = "a white cat"
+upscaled_image = pipeline(prompt=prompt, image=low_res_img).images[0]
+make_image_grid([low_res_img.resize((512, 512)), upscaled_image.resize((512, 512))], rows=1, cols=2)
+```
+
+## Depth-to-image
+
+```py
+import torch
+from diffusers import StableDiffusionDepth2ImgPipeline
+from diffusers.utils import load_image, make_image_grid
+
+pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-2-depth",
+ torch_dtype=torch.float16,
+).to("cuda")
+
+
+url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+init_image = load_image(url)
+prompt = "two tigers"
+negative_prompt = "bad, deformed, ugly, bad anatomy"
+image = pipe(prompt=prompt, image=init_image, negative_prompt=negative_prompt, strength=0.7).images[0]
+make_image_grid([init_image, image], rows=1, cols=2)
+```
diff --git a/diffusers/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_safe.md b/diffusers/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_safe.md
new file mode 100644
index 0000000000000000000000000000000000000000..217434c6b6698462d1bc5db0f7c9f6d8590121b9
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_safe.md
@@ -0,0 +1,61 @@
+
+
+# Safe Stable Diffusion
+
+Safe Stable Diffusion was proposed in [Safe Latent Diffusion: Mitigating Inappropriate Degeneration in Diffusion Models](https://huggingface.co/papers/2211.05105) and mitigates inappropriate degeneration from Stable Diffusion models because they're trained on unfiltered web-crawled datasets. For instance Stable Diffusion may unexpectedly generate nudity, violence, images depicting self-harm, and otherwise offensive content. Safe Stable Diffusion is an extension of Stable Diffusion that drastically reduces this type of content.
+
+The abstract from the paper is:
+
+*Text-conditioned image generation models have recently achieved astonishing results in image quality and text alignment and are consequently employed in a fast-growing number of applications. Since they are highly data-driven, relying on billion-sized datasets randomly scraped from the internet, they also suffer, as we demonstrate, from degenerated and biased human behavior. In turn, they may even reinforce such biases. To help combat these undesired side effects, we present safe latent diffusion (SLD). Specifically, to measure the inappropriate degeneration due to unfiltered and imbalanced training sets, we establish a novel image generation test bed-inappropriate image prompts (I2P)-containing dedicated, real-world image-to-text prompts covering concepts such as nudity and violence. As our exhaustive empirical evaluation demonstrates, the introduced SLD removes and suppresses inappropriate image parts during the diffusion process, with no additional training required and no adverse effect on overall image quality or text alignment.*
+
+## Tips
+
+Use the `safety_concept` property of [`StableDiffusionPipelineSafe`] to check and edit the current safety concept:
+
+```python
+>>> from diffusers import StableDiffusionPipelineSafe
+
+>>> pipeline = StableDiffusionPipelineSafe.from_pretrained("AIML-TUDA/stable-diffusion-safe")
+>>> pipeline.safety_concept
+'an image showing hate, harassment, violence, suffering, humiliation, harm, suicide, sexual, nudity, bodily fluids, blood, obscene gestures, illegal activity, drug use, theft, vandalism, weapons, child abuse, brutality, cruelty'
+```
+For each image generation the active concept is also contained in [`StableDiffusionSafePipelineOutput`].
+
+There are 4 configurations (`SafetyConfig.WEAK`, `SafetyConfig.MEDIUM`, `SafetyConfig.STRONG`, and `SafetyConfig.MAX`) that can be applied:
+
+```python
+>>> from diffusers import StableDiffusionPipelineSafe
+>>> from diffusers.pipelines.stable_diffusion_safe import SafetyConfig
+
+>>> pipeline = StableDiffusionPipelineSafe.from_pretrained("AIML-TUDA/stable-diffusion-safe")
+>>> prompt = "the four horsewomen of the apocalypse, painting by tom of finland, gaston bussiere, craig mullins, j. c. leyendecker"
+>>> out = pipeline(prompt=prompt, **SafetyConfig.MAX)
+```
+
+
+
+Make sure to check out the Stable Diffusion [Tips](overview#tips) section to learn how to explore the tradeoff between scheduler speed and quality, and how to reuse pipeline components efficiently!
+
+
+
+## StableDiffusionPipelineSafe
+
+[[autodoc]] StableDiffusionPipelineSafe
+ - all
+ - __call__
+
+## StableDiffusionSafePipelineOutput
+
+[[autodoc]] pipelines.stable_diffusion_safe.StableDiffusionSafePipelineOutput
+ - all
+ - __call__
diff --git a/diffusers/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_xl.md b/diffusers/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_xl.md
new file mode 100644
index 0000000000000000000000000000000000000000..74f4cba0835423486ee00f9319d862efcc8d1527
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_xl.md
@@ -0,0 +1,55 @@
+
+
+# Stable Diffusion XL
+
+Stable Diffusion XL (SDXL) was proposed in [SDXL: Improving Latent Diffusion Models for High-Resolution Image Synthesis](https://huggingface.co/papers/2307.01952) by Dustin Podell, Zion English, Kyle Lacey, Andreas Blattmann, Tim Dockhorn, Jonas Müller, Joe Penna, and Robin Rombach.
+
+The abstract from the paper is:
+
+*We present SDXL, a latent diffusion model for text-to-image synthesis. Compared to previous versions of Stable Diffusion, SDXL leverages a three times larger UNet backbone: The increase of model parameters is mainly due to more attention blocks and a larger cross-attention context as SDXL uses a second text encoder. We design multiple novel conditioning schemes and train SDXL on multiple aspect ratios. We also introduce a refinement model which is used to improve the visual fidelity of samples generated by SDXL using a post-hoc image-to-image technique. We demonstrate that SDXL shows drastically improved performance compared to previous versions of Stable Diffusion and achieves results competitive with those of black-box state-of-the-art image generators.*
+
+## Tips
+
+- Using SDXL with a DPM++ scheduler for less than 50 steps is known to produce [visual artifacts](https://github.com/huggingface/diffusers/issues/5433) because the solver becomes numerically unstable. To fix this issue, take a look at this [PR](https://github.com/huggingface/diffusers/pull/5541), which recommends the following for ODE/SDE solvers (see the sketch after this list):
+ - set `use_karras_sigmas=True` or `lu_lambdas=True` to improve image quality
+ - set `euler_at_final=True` if you're using a solver with uniform step sizes (DPM++2M or DPM++2M SDE)
+- Most SDXL checkpoints work best with an image size of 1024x1024. Image sizes of 768x768 and 512x512 are also supported, but the results aren't as good. Anything below 512x512 is not recommended and likely won't work well with default checkpoints like [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0).
+- SDXL can pass a different prompt for each of the text encoders it was trained on. We can even pass different parts of the same prompt to the text encoders.
+- SDXL output images can be improved by making use of a refiner model in an image-to-image setting.
+- SDXL offers `negative_original_size`, `negative_crops_coords_top_left`, and `negative_target_size` to negatively condition the model on image resolution and cropping parameters.
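+
+As a minimal sketch of the scheduler recommendation above (it assumes a diffusers build where `DPMSolverMultistepScheduler` exposes the `use_karras_sigmas` and `euler_at_final` options referenced in the linked PR), the solver can be reconfigured like this:
+
+```python
+import torch
+from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler
+
+pipe = StableDiffusionXLPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16"
+).to("cuda")
+
+# swap in DPM++ 2M with the settings recommended for short (< 50 step) runs
+pipe.scheduler = DPMSolverMultistepScheduler.from_config(
+    pipe.scheduler.config, use_karras_sigmas=True, euler_at_final=True
+)
+
+image = pipe("an astronaut riding a green horse", num_inference_steps=25).images[0]
+```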
+
+
+
+To learn how to use SDXL for various tasks, how to optimize performance, and other usage examples, take a look at the [Stable Diffusion XL](../../../using-diffusers/sdxl) guide.
+
+Check out the [Stability AI](https://huggingface.co/stabilityai) Hub organization for the official base and refiner model checkpoints!
+
+
+
+## StableDiffusionXLPipeline
+
+[[autodoc]] StableDiffusionXLPipeline
+ - all
+ - __call__
+
+## StableDiffusionXLImg2ImgPipeline
+
+[[autodoc]] StableDiffusionXLImg2ImgPipeline
+ - all
+ - __call__
+
+## StableDiffusionXLInpaintPipeline
+
+[[autodoc]] StableDiffusionXLInpaintPipeline
+ - all
+ - __call__
diff --git a/diffusers/docs/source/en/api/pipelines/stable_diffusion/text2img.md b/diffusers/docs/source/en/api/pipelines/stable_diffusion/text2img.md
new file mode 100644
index 0000000000000000000000000000000000000000..75d0b305d22f9f4a6f179963cac3b42ab0b8cd68
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/stable_diffusion/text2img.md
@@ -0,0 +1,59 @@
+
+
+# Text-to-image
+
+The Stable Diffusion model was created by researchers and engineers from [CompVis](https://github.com/CompVis), [Stability AI](https://stability.ai/), [Runway](https://github.com/runwayml), and [LAION](https://laion.ai/). The [`StableDiffusionPipeline`] is capable of generating photorealistic images given any text input. It's trained on 512x512 images from a subset of the LAION-5B dataset. This model uses a frozen CLIP ViT-L/14 text encoder to condition the model on text prompts. With its 860M UNet and 123M text encoder, the model is relatively lightweight and can run on consumer GPUs. Latent diffusion is the research on top of which Stable Diffusion was built. It was proposed in [High-Resolution Image Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) by Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, Björn Ommer.
+
+The abstract from the paper is:
+
+*By decomposing the image formation process into a sequential application of denoising autoencoders, diffusion models (DMs) achieve state-of-the-art synthesis results on image data and beyond. Additionally, their formulation allows for a guiding mechanism to control the image generation process without retraining. However, since these models typically operate directly in pixel space, optimization of powerful DMs often consumes hundreds of GPU days and inference is expensive due to sequential evaluations. To enable DM training on limited computational resources while retaining their quality and flexibility, we apply them in the latent space of powerful pretrained autoencoders. In contrast to previous work, training diffusion models on such a representation allows for the first time to reach a near-optimal point between complexity reduction and detail preservation, greatly boosting visual fidelity. By introducing cross-attention layers into the model architecture, we turn diffusion models into powerful and flexible generators for general conditioning inputs such as text or bounding boxes and high-resolution synthesis becomes possible in a convolutional manner. Our latent diffusion models (LDMs) achieve a new state of the art for image inpainting and highly competitive performance on various tasks, including unconditional image generation, semantic scene synthesis, and super-resolution, while significantly reducing computational requirements compared to pixel-based DMs. Code is available at https://github.com/CompVis/latent-diffusion.*
+
+
+
+Make sure to check out the Stable Diffusion [Tips](overview#tips) section to learn how to explore the tradeoff between scheduler speed and quality, and how to reuse pipeline components efficiently!
+
+If you're interested in using one of the official checkpoints for a task, explore the [CompVis](https://huggingface.co/CompVis), [Runway](https://huggingface.co/runwayml), and [Stability AI](https://huggingface.co/stabilityai) Hub organizations!
+
+
+
+## StableDiffusionPipeline
+
+[[autodoc]] StableDiffusionPipeline
+ - all
+ - __call__
+ - enable_attention_slicing
+ - disable_attention_slicing
+ - enable_vae_slicing
+ - disable_vae_slicing
+ - enable_xformers_memory_efficient_attention
+ - disable_xformers_memory_efficient_attention
+ - enable_vae_tiling
+ - disable_vae_tiling
+ - load_textual_inversion
+ - from_single_file
+ - load_lora_weights
+ - save_lora_weights
+
+## StableDiffusionPipelineOutput
+
+[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
+
+## FlaxStableDiffusionPipeline
+
+[[autodoc]] FlaxStableDiffusionPipeline
+ - all
+ - __call__
+
+## FlaxStableDiffusionPipelineOutput
+
+[[autodoc]] pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput
diff --git a/diffusers/docs/source/en/api/pipelines/stable_diffusion/upscale.md b/diffusers/docs/source/en/api/pipelines/stable_diffusion/upscale.md
new file mode 100644
index 0000000000000000000000000000000000000000..d8df718d9d3641102eaea4bc0a8978af8c723a01
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/stable_diffusion/upscale.md
@@ -0,0 +1,37 @@
+
+
+# Super-resolution
+
+The Stable Diffusion upscaler diffusion model was created by the researchers and engineers from [CompVis](https://github.com/CompVis), [Stability AI](https://stability.ai/), and [LAION](https://laion.ai/). It is used to enhance the resolution of input images by a factor of 4.
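+
+A minimal usage sketch is shown below; the `stabilityai/stable-diffusion-x4-upscaler` checkpoint and the example image URL follow the commonly used x4-upscaler example and are assumptions rather than requirements of the pipeline:
+
+```python
+import torch
+from diffusers import StableDiffusionUpscalePipeline
+from diffusers.utils import load_image
+
+pipe = StableDiffusionUpscalePipeline.from_pretrained(
+    "stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16
+).to("cuda")
+
+# a small input image; the upscaler quadruples its resolution
+url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale/low_res_cat.png"
+low_res_img = load_image(url).resize((128, 128))
+
+upscaled = pipe(prompt="a white cat", image=low_res_img).images[0]
+upscaled.save("upscaled_cat.png")
+```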
+
+
+
+Make sure to check out the Stable Diffusion [Tips](overview#tips) section to learn how to explore the tradeoff between scheduler speed and quality, and how to reuse pipeline components efficiently!
+
+If you're interested in using one of the official checkpoints for a task, explore the [CompVis](https://huggingface.co/CompVis), [Runway](https://huggingface.co/runwayml), and [Stability AI](https://huggingface.co/stabilityai) Hub organizations!
+
+
+
+## StableDiffusionUpscalePipeline
+
+[[autodoc]] StableDiffusionUpscalePipeline
+ - all
+ - __call__
+ - enable_attention_slicing
+ - disable_attention_slicing
+ - enable_xformers_memory_efficient_attention
+ - disable_xformers_memory_efficient_attention
+
+## StableDiffusionPipelineOutput
+
+[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
diff --git a/diffusers/docs/source/en/api/pipelines/stable_unclip.md b/diffusers/docs/source/en/api/pipelines/stable_unclip.md
new file mode 100644
index 0000000000000000000000000000000000000000..2942cefec4a9024cf74e478011b6d43801c11356
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/stable_unclip.md
@@ -0,0 +1,129 @@
+
+
+# Stable unCLIP
+
+Stable unCLIP checkpoints are finetuned from [Stable Diffusion 2.1](./stable_diffusion/stable_diffusion_2) checkpoints to condition on CLIP image embeddings.
+Stable unCLIP still conditions on text embeddings. Given the two separate conditionings, stable unCLIP can be used
+for text-guided image variation. When combined with an unCLIP prior, it can also be used for full text-to-image generation.
+
+The abstract from the paper is:
+
+*Contrastive models like CLIP have been shown to learn robust representations of images that capture both semantics and style. To leverage these representations for image generation, we propose a two-stage model: a prior that generates a CLIP image embedding given a text caption, and a decoder that generates an image conditioned on the image embedding. We show that explicitly generating image representations improves image diversity with minimal loss in photorealism and caption similarity. Our decoders conditioned on image representations can also produce variations of an image that preserve both its semantics and style, while varying the non-essential details absent from the image representation. Moreover, the joint embedding space of CLIP enables language-guided image manipulations in a zero-shot fashion. We use diffusion models for the decoder and experiment with both autoregressive and diffusion models for the prior, finding that the latter are computationally more efficient and produce higher-quality samples.*
+
+## Tips
+
+Stable unCLIP takes `noise_level` as input during inference which determines how much noise is added to the image embeddings. A higher `noise_level` increases variation in the final un-noised images. By default, we do not add any additional noise to the image embeddings (`noise_level = 0`).
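+
+As a small sketch of how `noise_level` is passed at call time (the value below is purely illustrative, and the checkpoint matches the image variation example further down):
+
+```python
+import torch
+from diffusers import StableUnCLIPImg2ImgPipeline
+from diffusers.utils import load_image
+
+pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-2-1-unclip", torch_dtype=torch.float16, variant="fp16"
+).to("cuda")
+
+url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_unclip/tarsila_do_amaral.png"
+init_image = load_image(url)
+
+# noise_level=0 (the default) stays close to the input; larger values add more variation
+image = pipe(init_image, noise_level=100).images[0]
+```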
+
+### Text-to-Image Generation
+Stable unCLIP can be leveraged for text-to-image generation by pipelining it with the prior model of KakaoBrain's open source DALL-E 2 replication [Karlo](https://huggingface.co/kakaobrain/karlo-v1-alpha):
+
+```python
+import torch
+from diffusers import UnCLIPScheduler, DDPMScheduler, StableUnCLIPPipeline
+from diffusers.models import PriorTransformer
+from transformers import CLIPTokenizer, CLIPTextModelWithProjection
+
+prior_model_id = "kakaobrain/karlo-v1-alpha"
+data_type = torch.float16
+prior = PriorTransformer.from_pretrained(prior_model_id, subfolder="prior", torch_dtype=data_type)
+
+prior_text_model_id = "openai/clip-vit-large-patch14"
+prior_tokenizer = CLIPTokenizer.from_pretrained(prior_text_model_id)
+prior_text_model = CLIPTextModelWithProjection.from_pretrained(prior_text_model_id, torch_dtype=data_type)
+prior_scheduler = UnCLIPScheduler.from_pretrained(prior_model_id, subfolder="prior_scheduler")
+prior_scheduler = DDPMScheduler.from_config(prior_scheduler.config)
+
+stable_unclip_model_id = "stabilityai/stable-diffusion-2-1-unclip-small"
+
+pipe = StableUnCLIPPipeline.from_pretrained(
+ stable_unclip_model_id,
+ torch_dtype=data_type,
+ variant="fp16",
+ prior_tokenizer=prior_tokenizer,
+ prior_text_encoder=prior_text_model,
+ prior=prior,
+ prior_scheduler=prior_scheduler,
+)
+
+pipe = pipe.to("cuda")
+wave_prompt = "dramatic wave, the Oceans roar, Strong wave spiral across the oceans as the waves unfurl into roaring crests; perfect wave form; perfect wave shape; dramatic wave shape; wave shape unbelievable; wave; wave shape spectacular"
+
+image = pipe(prompt=wave_prompt).images[0]
+image
+```
+
+
+For text-to-image we use `stabilityai/stable-diffusion-2-1-unclip-small` as it was trained on CLIP ViT-L/14 embedding, the same as the Karlo model prior. [stabilityai/stable-diffusion-2-1-unclip](https://hf.co/stabilityai/stable-diffusion-2-1-unclip) was trained on OpenCLIP ViT-H, so we don't recommend its use.
+
+
+
+### Text-guided Image-to-Image Variation
+
+```python
+from diffusers import StableUnCLIPImg2ImgPipeline
+from diffusers.utils import load_image
+import torch
+
+pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-2-1-unclip", torch_dtype=torch.float16, variant="fp16"
+)
+pipe = pipe.to("cuda")
+
+url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_unclip/tarsila_do_amaral.png"
+init_image = load_image(url)
+
+images = pipe(init_image).images
+images[0].save("variation_image.png")
+```
+
+Optionally, you can also pass a prompt to `pipe` such as:
+
+```python
+prompt = "A fantasy landscape, trending on artstation"
+
+image = pipe(init_image, prompt=prompt).images[0]
+image
+```
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## StableUnCLIPPipeline
+
+[[autodoc]] StableUnCLIPPipeline
+ - all
+ - __call__
+ - enable_attention_slicing
+ - disable_attention_slicing
+ - enable_vae_slicing
+ - disable_vae_slicing
+ - enable_xformers_memory_efficient_attention
+ - disable_xformers_memory_efficient_attention
+
+## StableUnCLIPImg2ImgPipeline
+
+[[autodoc]] StableUnCLIPImg2ImgPipeline
+ - all
+ - __call__
+ - enable_attention_slicing
+ - disable_attention_slicing
+ - enable_vae_slicing
+ - disable_vae_slicing
+ - enable_xformers_memory_efficient_attention
+ - disable_xformers_memory_efficient_attention
+
+## ImagePipelineOutput
+[[autodoc]] pipelines.ImagePipelineOutput
diff --git a/diffusers/docs/source/en/api/pipelines/stochastic_karras_ve.md b/diffusers/docs/source/en/api/pipelines/stochastic_karras_ve.md
new file mode 100644
index 0000000000000000000000000000000000000000..0e3f1a5b833352887b5c6fc23964f0a4332d4e09
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/stochastic_karras_ve.md
@@ -0,0 +1,33 @@
+
+
+# Stochastic Karras VE
+
+[Elucidating the Design Space of Diffusion-Based Generative Models](https://huggingface.co/papers/2206.00364) is by Tero Karras, Miika Aittala, Timo Aila and Samuli Laine. This pipeline implements the stochastic sampling tailored to variance-exploding (VE) models.
+
+The abstract from the paper:
+
+*We argue that the theory and practice of diffusion-based generative models are currently unnecessarily convoluted and seek to remedy the situation by presenting a design space that clearly separates the concrete design choices. This lets us identify several changes to both the sampling and training processes, as well as preconditioning of the score networks. Together, our improvements yield new state-of-the-art FID of 1.79 for CIFAR-10 in a class-conditional setting and 1.97 in an unconditional setting, with much faster sampling (35 network evaluations per image) than prior designs. To further demonstrate their modular nature, we show that our design changes dramatically improve both the efficiency and quality obtainable with pre-trained score networks from previous work, including improving the FID of a previously trained ImageNet-64 model from 2.07 to near-SOTA 1.55, and after re-training with our proposed improvements to a new SOTA of 1.36.*
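+
+The pipeline pairs an unconditional [`UNet2DModel`] with a [`KarrasVeScheduler`]. The construction sketch below is only illustrative: the `google/ncsnpp-celebahq-256` checkpoint and its `unet` subfolder layout are assumptions, and any unconditional VE-style UNet checkpoint should slot in the same way:
+
+```python
+import torch
+from diffusers import KarrasVePipeline, KarrasVeScheduler, UNet2DModel
+
+# assumed checkpoint layout: an unconditional UNet2DModel stored in a "unet" subfolder
+unet = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", subfolder="unet")
+scheduler = KarrasVeScheduler()
+
+pipe = KarrasVePipeline(unet=unet, scheduler=scheduler).to("cuda")
+image = pipe(num_inference_steps=50).images[0]
+image.save("karras_ve_sample.png")
+```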
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## KarrasVePipeline
+[[autodoc]] KarrasVePipeline
+ - all
+ - __call__
+
+## ImagePipelineOutput
+[[autodoc]] pipelines.ImagePipelineOutput
diff --git a/diffusers/docs/source/en/api/pipelines/text_to_video.md b/diffusers/docs/source/en/api/pipelines/text_to_video.md
new file mode 100644
index 0000000000000000000000000000000000000000..244bb2e43b74e7415afe66c1c0dc90f580a8f510
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/text_to_video.md
@@ -0,0 +1,187 @@
+
+
+
+
+🧪 This pipeline is for research purposes only.
+
+
+
+# Text-to-video
+
+[ModelScope Text-to-Video Technical Report](https://arxiv.org/abs/2308.06571) is by Jiuniu Wang, Hangjie Yuan, Dayou Chen, Yingya Zhang, Xiang Wang, Shiwei Zhang.
+
+The abstract from the paper is:
+
+*This paper introduces ModelScopeT2V, a text-to-video synthesis model that evolves from a text-to-image synthesis model (i.e., Stable Diffusion). ModelScopeT2V incorporates spatio-temporal blocks to ensure consistent frame generation and smooth movement transitions. The model could adapt to varying frame numbers during training and inference, rendering it suitable for both image-text and video-text datasets. ModelScopeT2V brings together three components (i.e., VQGAN, a text encoder, and a denoising UNet), totally comprising 1.7 billion parameters, in which 0.5 billion parameters are dedicated to temporal capabilities. The model demonstrates superior performance over state-of-the-art methods across three evaluation metrics. The code and an online demo are available at https://modelscope.cn/models/damo/text-to-video-synthesis/summary.*
+
+You can find additional information about Text-to-Video on the [project page](https://modelscope.cn/models/damo/text-to-video-synthesis/summary), [original codebase](https://github.com/modelscope/modelscope/), and try it out in a [demo](https://huggingface.co/spaces/damo-vilab/modelscope-text-to-video-synthesis). Official checkpoints can be found at [damo-vilab](https://huggingface.co/damo-vilab) and [cerspense](https://huggingface.co/cerspense).
+
+## Usage example
+
+### `text-to-video-ms-1.7b`
+
+Let's start by generating a short video with the default length of 16 frames (2s at 8 fps):
+
+```python
+import torch
+from diffusers import DiffusionPipeline
+from diffusers.utils import export_to_video
+
+pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16")
+pipe = pipe.to("cuda")
+
+prompt = "Spiderman is surfing"
+video_frames = pipe(prompt).frames
+video_path = export_to_video(video_frames)
+video_path
+```
+
+Diffusers supports different optimization techniques to improve the latency
+and memory footprint of a pipeline. Since videos are often more memory-heavy than images,
+we can enable CPU offloading and VAE slicing to keep the memory footprint at bay.
+
+Let's generate a video of 8 seconds (64 frames) on the same GPU using CPU offloading and VAE slicing:
+
+```python
+import torch
+from diffusers import DiffusionPipeline
+from diffusers.utils import export_to_video
+
+pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16")
+pipe.enable_model_cpu_offload()
+
+# memory optimization
+pipe.enable_vae_slicing()
+
+prompt = "Darth Vader surfing a wave"
+video_frames = pipe(prompt, num_frames=64).frames
+video_path = export_to_video(video_frames)
+video_path
+```
+
+It takes just **7 GB of GPU memory** to generate the 64 video frames using PyTorch 2.0, "fp16" precision, and the techniques mentioned above.
+
+We can also use a different scheduler easily, using the same method we'd use for Stable Diffusion:
+
+```python
+import torch
+from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
+from diffusers.utils import export_to_video
+
+pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16")
+pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
+pipe.enable_model_cpu_offload()
+
+prompt = "Spiderman is surfing"
+video_frames = pipe(prompt, num_inference_steps=25).frames
+video_path = export_to_video(video_frames)
+video_path
+```
+
+Here are some sample outputs:
+
+
+
+
+ An astronaut riding a horse.
+
+
+
+
+ Darth vader surfing in waves.
+
+
+
+
+
+
+### `cerspense/zeroscope_v2_576w` & `cerspense/zeroscope_v2_XL`
+
+Zeroscope models are watermark-free and have been trained on specific sizes such as `576x320` and `1024x576`.
+First generate a video using the lower-resolution checkpoint [`cerspense/zeroscope_v2_576w`](https://huggingface.co/cerspense/zeroscope_v2_576w) with [`TextToVideoSDPipeline`],
+which can then be upscaled using [`VideoToVideoSDPipeline`] and [`cerspense/zeroscope_v2_XL`](https://huggingface.co/cerspense/zeroscope_v2_XL).
+
+
+```py
+import torch
+from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
+from diffusers.utils import export_to_video
+from PIL import Image
+
+pipe = DiffusionPipeline.from_pretrained("cerspense/zeroscope_v2_576w", torch_dtype=torch.float16)
+pipe.enable_model_cpu_offload()
+
+# memory optimization
+pipe.unet.enable_forward_chunking(chunk_size=1, dim=1)
+pipe.enable_vae_slicing()
+
+prompt = "Darth Vader surfing a wave"
+video_frames = pipe(prompt, num_frames=24).frames
+video_path = export_to_video(video_frames)
+video_path
+```
+
+Now the video can be upscaled:
+
+```py
+pipe = DiffusionPipeline.from_pretrained("cerspense/zeroscope_v2_XL", torch_dtype=torch.float16)
+pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
+pipe.enable_model_cpu_offload()
+
+# memory optimization
+pipe.unet.enable_forward_chunking(chunk_size=1, dim=1)
+pipe.enable_vae_slicing()
+
+video = [Image.fromarray(frame).resize((1024, 576)) for frame in video_frames]
+
+video_frames = pipe(prompt, video=video, strength=0.6).frames
+video_path = export_to_video(video_frames)
+video_path
+```
+
+Here are some sample outputs:
+
+
+
+
+ Darth vader surfing in waves.
+
+
+
+
+
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## TextToVideoSDPipeline
+[[autodoc]] TextToVideoSDPipeline
+ - all
+ - __call__
+
+## VideoToVideoSDPipeline
+[[autodoc]] VideoToVideoSDPipeline
+ - all
+ - __call__
+
+## TextToVideoSDPipelineOutput
+[[autodoc]] pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput
diff --git a/diffusers/docs/source/en/api/pipelines/text_to_video_zero.md b/diffusers/docs/source/en/api/pipelines/text_to_video_zero.md
new file mode 100644
index 0000000000000000000000000000000000000000..626e75f94936daa035c595cec7e8f4843534de3d
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/text_to_video_zero.md
@@ -0,0 +1,257 @@
+
+
+# Text2Video-Zero
+
+[Text2Video-Zero: Text-to-Image Diffusion Models are Zero-Shot Video Generators](https://huggingface.co/papers/2303.13439) is by Levon Khachatryan, Andranik Movsisyan, Vahram Tadevosyan, Roberto Henschel, [Zhangyang Wang](https://www.ece.utexas.edu/people/faculty/atlas-wang), Shant Navasardyan, [Humphrey Shi](https://www.humphreyshi.com).
+
+Text2Video-Zero enables zero-shot video generation using either:
+1. A textual prompt
+2. A prompt combined with guidance from poses or edges
+3. Video Instruct-Pix2Pix (instruction-guided video editing)
+
+Results are temporally consistent and closely follow the guidance and textual prompts.
+
+![teaser-img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/t2v_zero_teaser.png)
+
+The abstract from the paper is:
+
+*Recent text-to-video generation approaches rely on computationally heavy training and require large-scale video datasets. In this paper, we introduce a new task of zero-shot text-to-video generation and propose a low-cost approach (without any training or optimization) by leveraging the power of existing text-to-image synthesis methods (e.g., Stable Diffusion), making them suitable for the video domain.
+Our key modifications include (i) enriching the latent codes of the generated frames with motion dynamics to keep the global scene and the background time consistent; and (ii) reprogramming frame-level self-attention using a new cross-frame attention of each frame on the first frame, to preserve the context, appearance, and identity of the foreground object.
+Experiments show that this leads to low overhead, yet high-quality and remarkably consistent video generation. Moreover, our approach is not limited to text-to-video synthesis but is also applicable to other tasks such as conditional and content-specialized video generation, and Video Instruct-Pix2Pix, i.e., instruction-guided video editing.
+As experiments show, our method performs comparably or sometimes better than recent approaches, despite not being trained on additional video data.*
+
+You can find additional information about Text2Video-Zero on the [project page](https://text2video-zero.github.io/), [paper](https://arxiv.org/abs/2303.13439), and [original codebase](https://github.com/Picsart-AI-Research/Text2Video-Zero).
+
+## Usage example
+
+### Text-To-Video
+
+To generate a video from a prompt, run the following Python code:
+```python
+import imageio
+import torch
+from diffusers import TextToVideoZeroPipeline
+
+model_id = "runwayml/stable-diffusion-v1-5"
+pipe = TextToVideoZeroPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+
+prompt = "A panda is playing guitar on times square"
+result = pipe(prompt=prompt).images
+result = [(r * 255).astype("uint8") for r in result]
+imageio.mimsave("video.mp4", result, fps=4)
+```
+You can change these parameters in the pipeline call:
+* Motion field strength (see the [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1):
+ * `motion_field_strength_x` and `motion_field_strength_y`. Default: `motion_field_strength_x=12`, `motion_field_strength_y=12`
+* `T` and `T'` (see the [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1)
+ * `t0` and `t1` in the range `{0, ..., num_inference_steps}`. Default: `t0=45`, `t1=48`
+* Video length:
+  * `video_length`, the number of frames to generate. Default: `video_length=8` (see the sketch after this list)
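+
+As a minimal sketch with these arguments spelled out (the values shown simply repeat the defaults listed above):
+
+```python
+import imageio
+import torch
+from diffusers import TextToVideoZeroPipeline
+
+pipe = TextToVideoZeroPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
+
+result = pipe(
+    prompt="A panda is playing guitar on times square",
+    video_length=8,              # number of generated frames
+    motion_field_strength_x=12,  # global motion strength along x
+    motion_field_strength_y=12,  # global motion strength along y
+    t0=45,                       # see Sect. 3.3.1 of the paper
+    t1=48,                       # see Sect. 3.3.1 of the paper
+).images
+result = [(r * 255).astype("uint8") for r in result]
+imageio.mimsave("video.mp4", result, fps=4)
+```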
+
+We can also generate longer videos by doing the processing in a chunk-by-chunk manner:
+```python
+import imageio
+import torch
+from diffusers import TextToVideoZeroPipeline
+import numpy as np
+
+model_id = "runwayml/stable-diffusion-v1-5"
+pipe = TextToVideoZeroPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+seed = 0
+video_length = 24 #24 ÷ 4fps = 6 seconds
+chunk_size = 8
+prompt = "A panda is playing guitar on times square"
+
+# Generate the video chunk-by-chunk
+result = []
+chunk_ids = np.arange(0, video_length, chunk_size - 1)
+generator = torch.Generator(device="cuda")
+for i in range(len(chunk_ids)):
+ print(f"Processing chunk {i + 1} / {len(chunk_ids)}")
+ ch_start = chunk_ids[i]
+ ch_end = video_length if i == len(chunk_ids) - 1 else chunk_ids[i + 1]
+ # Attach the first frame for Cross Frame Attention
+ frame_ids = [0] + list(range(ch_start, ch_end))
+ # Fix the seed for the temporal consistency
+ generator.manual_seed(seed)
+ output = pipe(prompt=prompt, video_length=len(frame_ids), generator=generator, frame_ids=frame_ids)
+ result.append(output.images[1:])
+
+# Concatenate chunks and save
+result = np.concatenate(result)
+result = [(r * 255).astype("uint8") for r in result]
+imageio.mimsave("video.mp4", result, fps=4)
+```
+
+
+### Text-To-Video with Pose Control
+To generate a video from a prompt with additional pose control:
+
+1. Download a demo video
+
+ ```python
+ from huggingface_hub import hf_hub_download
+
+ filename = "__assets__/poses_skeleton_gifs/dance1_corr.mp4"
+ repo_id = "PAIR/Text2Video-Zero"
+ video_path = hf_hub_download(repo_type="space", repo_id=repo_id, filename=filename)
+ ```
+
+
+2. Read video containing extracted pose images
+ ```python
+ from PIL import Image
+ import imageio
+
+ reader = imageio.get_reader(video_path, "ffmpeg")
+ frame_count = 8
+ pose_images = [Image.fromarray(reader.get_data(i)) for i in range(frame_count)]
+ ```
+    To extract poses from an actual video, read the [ControlNet documentation](controlnet).
+
+3. Run `StableDiffusionControlNetPipeline` with our custom attention processor
+
+ ```python
+ import torch
+ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
+ from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero import CrossFrameAttnProcessor
+
+ model_id = "runwayml/stable-diffusion-v1-5"
+ controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16)
+ pipe = StableDiffusionControlNetPipeline.from_pretrained(
+ model_id, controlnet=controlnet, torch_dtype=torch.float16
+ ).to("cuda")
+
+ # Set the attention processor
+ pipe.unet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))
+ pipe.controlnet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))
+
+ # fix latents for all frames
+ latents = torch.randn((1, 4, 64, 64), device="cuda", dtype=torch.float16).repeat(len(pose_images), 1, 1, 1)
+
+ prompt = "Darth Vader dancing in a desert"
+ result = pipe(prompt=[prompt] * len(pose_images), image=pose_images, latents=latents).images
+ imageio.mimsave("video.mp4", result, fps=4)
+ ```
+
+
+### Text-To-Video with Edge Control
+
+To generate a video from prompt with additional Canny edge control, follow the same steps described above for pose-guided generation using [Canny edge ControlNet model](https://huggingface.co/lllyasviel/sd-controlnet-canny).
+
+
+### Video Instruct-Pix2Pix
+
+To perform text-guided video editing (with [InstructPix2Pix](pix2pix)):
+
+1. Download a demo video
+
+ ```python
+ from huggingface_hub import hf_hub_download
+
+ filename = "__assets__/pix2pix video/camel.mp4"
+ repo_id = "PAIR/Text2Video-Zero"
+ video_path = hf_hub_download(repo_type="space", repo_id=repo_id, filename=filename)
+ ```
+
+2. Read video from path
+ ```python
+ from PIL import Image
+ import imageio
+
+ reader = imageio.get_reader(video_path, "ffmpeg")
+ frame_count = 8
+ video = [Image.fromarray(reader.get_data(i)) for i in range(frame_count)]
+ ```
+
+3. Run `StableDiffusionInstructPix2PixPipeline` with our custom attention processor
+ ```python
+ import torch
+ from diffusers import StableDiffusionInstructPix2PixPipeline
+ from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero import CrossFrameAttnProcessor
+
+ model_id = "timbrooks/instruct-pix2pix"
+ pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+ pipe.unet.set_attn_processor(CrossFrameAttnProcessor(batch_size=3))
+
+ prompt = "make it Van Gogh Starry Night style"
+ result = pipe(prompt=[prompt] * len(video), image=video).images
+ imageio.mimsave("edited_video.mp4", result, fps=4)
+ ```
+
+
+### DreamBooth specialization
+
+Methods **Text-To-Video**, **Text-To-Video with Pose Control** and **Text-To-Video with Edge Control**
+can run with custom [DreamBooth](../../training/dreambooth) models, as shown below for
+[Canny edge ControlNet model](https://huggingface.co/lllyasviel/sd-controlnet-canny) and
+[Avatar style DreamBooth](https://huggingface.co/PAIR/text2video-zero-controlnet-canny-avatar) model:
+
+1. Download a demo video
+
+ ```python
+ from huggingface_hub import hf_hub_download
+
+ filename = "__assets__/canny_videos_mp4/girl_turning.mp4"
+ repo_id = "PAIR/Text2Video-Zero"
+ video_path = hf_hub_download(repo_type="space", repo_id=repo_id, filename=filename)
+ ```
+
+2. Read video from path
+ ```python
+ from PIL import Image
+ import imageio
+
+ reader = imageio.get_reader(video_path, "ffmpeg")
+ frame_count = 8
+ canny_edges = [Image.fromarray(reader.get_data(i)) for i in range(frame_count)]
+ ```
+
+3. Run `StableDiffusionControlNetPipeline` with custom trained DreamBooth model
+ ```python
+ import torch
+ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
+ from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero import CrossFrameAttnProcessor
+
+ # set model id to custom model
+ model_id = "PAIR/text2video-zero-controlnet-canny-avatar"
+ controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
+ pipe = StableDiffusionControlNetPipeline.from_pretrained(
+ model_id, controlnet=controlnet, torch_dtype=torch.float16
+ ).to("cuda")
+
+ # Set the attention processor
+ pipe.unet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))
+ pipe.controlnet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))
+
+ # fix latents for all frames
+ latents = torch.randn((1, 4, 64, 64), device="cuda", dtype=torch.float16).repeat(len(canny_edges), 1, 1, 1)
+
+ prompt = "oil painting of a beautiful girl avatar style"
+ result = pipe(prompt=[prompt] * len(canny_edges), image=canny_edges, latents=latents).images
+ imageio.mimsave("video.mp4", result, fps=4)
+ ```
+
+You can find available DreamBooth-trained models with [this link](https://huggingface.co/models?search=dreambooth).
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## TextToVideoZeroPipeline
+[[autodoc]] TextToVideoZeroPipeline
+ - all
+ - __call__
+
+## TextToVideoPipelineOutput
+[[autodoc]] pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.TextToVideoPipelineOutput
diff --git a/diffusers/docs/source/en/api/pipelines/unclip.md b/diffusers/docs/source/en/api/pipelines/unclip.md
new file mode 100644
index 0000000000000000000000000000000000000000..da076ae8320ccbcd7042874e4574a0021d004538
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/unclip.md
@@ -0,0 +1,37 @@
+
+
+# unCLIP
+
+[Hierarchical Text-Conditional Image Generation with CLIP Latents](https://huggingface.co/papers/2204.06125) is by Aditya Ramesh, Prafulla Dhariwal, Alex Nichol, Casey Chu, Mark Chen. The unCLIP model in 🤗 Diffusers comes from kakaobrain's [karlo](https://github.com/kakaobrain/karlo).
+
+The abstract from the paper is:
+
+*Contrastive models like CLIP have been shown to learn robust representations of images that capture both semantics and style. To leverage these representations for image generation, we propose a two-stage model: a prior that generates a CLIP image embedding given a text caption, and a decoder that generates an image conditioned on the image embedding. We show that explicitly generating image representations improves image diversity with minimal loss in photorealism and caption similarity. Our decoders conditioned on image representations can also produce variations of an image that preserve both its semantics and style, while varying the non-essential details absent from the image representation. Moreover, the joint embedding space of CLIP enables language-guided image manipulations in a zero-shot fashion. We use diffusion models for the decoder and experiment with both autoregressive and diffusion models for the prior, finding that the latter are computationally more efficient and produce higher-quality samples.*
+
+You can find lucidrains' DALL-E 2 recreation at [lucidrains/DALLE2-pytorch](https://github.com/lucidrains/DALLE2-pytorch).
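+
+A minimal text-to-image sketch, assuming the [kakaobrain/karlo-v1-alpha](https://huggingface.co/kakaobrain/karlo-v1-alpha) checkpoint (the prompt is illustrative):
+
+```python
+import torch
+from diffusers import UnCLIPPipeline
+
+pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha", torch_dtype=torch.float16).to("cuda")
+
+prompt = "a high-resolution photograph of a big red frog on a green leaf"
+# the prior, decoder, and super-resolution stages are run internally by the pipeline
+image = pipe(prompt).images[0]
+image.save("unclip_sample.png")
+```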
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## UnCLIPPipeline
+[[autodoc]] UnCLIPPipeline
+ - all
+ - __call__
+
+## UnCLIPImageVariationPipeline
+[[autodoc]] UnCLIPImageVariationPipeline
+ - all
+ - __call__
+
+## ImagePipelineOutput
+[[autodoc]] pipelines.ImagePipelineOutput
diff --git a/diffusers/docs/source/en/api/pipelines/unidiffuser.md b/diffusers/docs/source/en/api/pipelines/unidiffuser.md
new file mode 100644
index 0000000000000000000000000000000000000000..5da194e320ccf5d0062dcfadfcf4110cbc45dada
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/unidiffuser.md
@@ -0,0 +1,205 @@
+
+
+# UniDiffuser
+
+The UniDiffuser model was proposed in [One Transformer Fits All Distributions in Multi-Modal Diffusion at Scale](https://huggingface.co/papers/2303.06555) by Fan Bao, Shen Nie, Kaiwen Xue, Chongxuan Li, Shi Pu, Yaole Wang, Gang Yue, Yue Cao, Hang Su, Jun Zhu.
+
+The abstract from the paper is:
+
+*This paper proposes a unified diffusion framework (dubbed UniDiffuser) to fit all distributions relevant to a set of multi-modal data in one model. Our key insight is -- learning diffusion models for marginal, conditional, and joint distributions can be unified as predicting the noise in the perturbed data, where the perturbation levels (i.e. timesteps) can be different for different modalities. Inspired by the unified view, UniDiffuser learns all distributions simultaneously with a minimal modification to the original diffusion model -- perturbs data in all modalities instead of a single modality, inputs individual timesteps in different modalities, and predicts the noise of all modalities instead of a single modality. UniDiffuser is parameterized by a transformer for diffusion models to handle input types of different modalities. Implemented on large-scale paired image-text data, UniDiffuser is able to perform image, text, text-to-image, image-to-text, and image-text pair generation by setting proper timesteps without additional overhead. In particular, UniDiffuser is able to produce perceptually realistic samples in all tasks and its quantitative results (e.g., the FID and CLIP score) are not only superior to existing general-purpose models but also comparable to the bespoken models (e.g., Stable Diffusion and DALL-E 2) in representative tasks (e.g., text-to-image generation).*
+
+You can find the original codebase at [thu-ml/unidiffuser](https://github.com/thu-ml/unidiffuser) and additional checkpoints at [thu-ml](https://huggingface.co/thu-ml).
+
+
+
+There is currently an issue on PyTorch 1.X where the output images are all black or the pixel values become `NaNs`. This issue can be mitigated by switching to PyTorch 2.X.
+
+
+
+This pipeline was contributed by [dg845](https://github.com/dg845). ❤️
+
+## Usage Examples
+
+Because the UniDiffuser model is trained to model the joint distribution of (image, text) pairs, it is capable of performing a diverse range of generation tasks:
+
+### Unconditional Image and Text Generation
+
+Unconditional generation (where we start from only latents sampled from a standard Gaussian prior) from a [`UniDiffuserPipeline`] will produce an (image, text) pair:
+
+```python
+import torch
+
+from diffusers import UniDiffuserPipeline
+
+device = "cuda"
+model_id_or_path = "thu-ml/unidiffuser-v1"
+pipe = UniDiffuserPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
+pipe.to(device)
+
+# Unconditional image and text generation. The generation task is automatically inferred.
+sample = pipe(num_inference_steps=20, guidance_scale=8.0)
+image = sample.images[0]
+text = sample.text[0]
+image.save("unidiffuser_joint_sample_image.png")
+print(text)
+```
+
+This is also called "joint" generation in the UniDiffuser paper, since we are sampling from the joint image-text distribution.
+
+Note that the generation task is inferred from the inputs used when calling the pipeline.
+It is also possible to specify the unconditional generation task ("mode") manually with [`UniDiffuserPipeline.set_joint_mode`]:
+
+```python
+# Equivalent to the above.
+pipe.set_joint_mode()
+sample = pipe(num_inference_steps=20, guidance_scale=8.0)
+```
+
+When the mode is set manually, subsequent calls to the pipeline will use the set mode without attempting to infer the mode.
+You can reset the mode with [`UniDiffuserPipeline.reset_mode`], after which the pipeline will once again infer the mode.
+
+You can also generate only an image or only text (which the UniDiffuser paper calls "marginal" generation since we sample from the marginal distribution of images and text, respectively):
+
+```python
+# Unlike other generation tasks, image-only and text-only generation don't use classifier-free guidance
+# Image-only generation
+pipe.set_image_mode()
+sample_image = pipe(num_inference_steps=20).images[0]
+# Text-only generation
+pipe.set_text_mode()
+sample_text = pipe(num_inference_steps=20).text[0]
+```
+
+### Text-to-Image Generation
+
+UniDiffuser is also capable of sampling from conditional distributions; that is, the distribution of images conditioned on a text prompt or the distribution of texts conditioned on an image.
+Here is an example of sampling from the conditional image distribution (text-to-image generation or text-conditioned image generation):
+
+```python
+import torch
+
+from diffusers import UniDiffuserPipeline
+
+device = "cuda"
+model_id_or_path = "thu-ml/unidiffuser-v1"
+pipe = UniDiffuserPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
+pipe.to(device)
+
+# Text-to-image generation
+prompt = "an elephant under the sea"
+
+sample = pipe(prompt=prompt, num_inference_steps=20, guidance_scale=8.0)
+t2i_image = sample.images[0]
+t2i_image
+```
+
+The `text2img` mode requires that either an input `prompt` or `prompt_embeds` be supplied. You can set the `text2img` mode manually with [`UniDiffuserPipeline.set_text_to_image_mode`].
+
+### Image-to-Text Generation
+
+Similarly, UniDiffuser can also produce text samples given an image (image-to-text or image-conditioned text generation):
+
+```python
+import torch
+
+from diffusers import UniDiffuserPipeline
+from diffusers.utils import load_image
+
+device = "cuda"
+model_id_or_path = "thu-ml/unidiffuser-v1"
+pipe = UniDiffuserPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
+pipe.to(device)
+
+# Image-to-text generation
+image_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/unidiffuser/unidiffuser_example_image.jpg"
+init_image = load_image(image_url).resize((512, 512))
+
+sample = pipe(image=init_image, num_inference_steps=20, guidance_scale=8.0)
+i2t_text = sample.text[0]
+print(i2t_text)
+```
+
+The `img2text` mode requires that an input `image` be supplied. You can set the `img2text` mode manually with [`UniDiffuserPipeline.set_image_to_text_mode`].
+
+### Image Variation
+
+The UniDiffuser authors suggest performing image variation through a "round-trip" generation method, where given an input image, we first perform an image-to-text generation, and then perform a text-to-image generation on the outputs of the first generation.
+This produces a new image which is semantically similar to the input image:
+
+```python
+import torch
+
+from diffusers import UniDiffuserPipeline
+from diffusers.utils import load_image
+
+device = "cuda"
+model_id_or_path = "thu-ml/unidiffuser-v1"
+pipe = UniDiffuserPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
+pipe.to(device)
+
+# Image variation can be performed with an image-to-text generation followed by a text-to-image generation:
+# 1. Image-to-text generation
+image_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/unidiffuser/unidiffuser_example_image.jpg"
+init_image = load_image(image_url).resize((512, 512))
+
+sample = pipe(image=init_image, num_inference_steps=20, guidance_scale=8.0)
+i2t_text = sample.text[0]
+print(i2t_text)
+
+# 2. Text-to-image generation
+sample = pipe(prompt=i2t_text, num_inference_steps=20, guidance_scale=8.0)
+final_image = sample.images[0]
+final_image.save("unidiffuser_image_variation_sample.png")
+```
+
+### Text Variation
+
+Similarly, text variation can be performed on an input prompt with a text-to-image generation followed by an image-to-text generation:
+
+```python
+import torch
+
+from diffusers import UniDiffuserPipeline
+
+device = "cuda"
+model_id_or_path = "thu-ml/unidiffuser-v1"
+pipe = UniDiffuserPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
+pipe.to(device)
+
+# Text variation can be performed with a text-to-image generation followed by an image-to-text generation:
+# 1. Text-to-image generation
+prompt = "an elephant under the sea"
+
+sample = pipe(prompt=prompt, num_inference_steps=20, guidance_scale=8.0)
+t2i_image = sample.images[0]
+t2i_image.save("unidiffuser_text2img_sample_image.png")
+
+# 2. Image-to-text generation
+sample = pipe(image=t2i_image, num_inference_steps=20, guidance_scale=8.0)
+final_prompt = sample.text[0]
+print(final_prompt)
+```
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## UniDiffuserPipeline
+[[autodoc]] UniDiffuserPipeline
+ - all
+ - __call__
+
+## ImageTextPipelineOutput
+[[autodoc]] pipelines.ImageTextPipelineOutput
diff --git a/diffusers/docs/source/en/api/pipelines/value_guided_sampling.md b/diffusers/docs/source/en/api/pipelines/value_guided_sampling.md
new file mode 100644
index 0000000000000000000000000000000000000000..01b7717f49f82506591fa973852a0a6910560aea
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/value_guided_sampling.md
@@ -0,0 +1,38 @@
+
+
+# Value-guided planning
+
+
+
+🧪 This is an experimental pipeline for reinforcement learning!
+
+
+
+This pipeline is based on the [Planning with Diffusion for Flexible Behavior Synthesis](https://huggingface.co/papers/2205.09991) paper by Michael Janner, Yilun Du, Joshua B. Tenenbaum, Sergey Levine.
+
+The abstract from the paper is:
+
+*Model-based reinforcement learning methods often use learning only for the purpose of estimating an approximate dynamics model, offloading the rest of the decision-making work to classical trajectory optimizers. While conceptually simple, this combination has a number of empirical shortcomings, suggesting that learned models may not be well-suited to standard trajectory optimization. In this paper, we consider what it would look like to fold as much of the trajectory optimization pipeline as possible into the modeling problem, such that sampling from the model and planning with it become nearly identical. The core of our technical approach lies in a diffusion probabilistic model that plans by iteratively denoising trajectories. We show how classifier-guided sampling and image inpainting can be reinterpreted as coherent planning strategies, explore the unusual and useful properties of diffusion-based planning methods, and demonstrate the effectiveness of our framework in control settings that emphasize long-horizon decision-making and test-time flexibility.*
+
+You can find additional information about the model on the [project page](https://diffusion-planning.github.io/), the [original codebase](https://github.com/jannerm/diffuser), or try it out in a demo [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/reinforcement_learning_with_diffusers.ipynb).
+
+The script to run the model is available [here](https://github.com/huggingface/diffusers/tree/main/examples/reinforcement_learning).
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## ValueGuidedRLPipeline
+[[autodoc]] diffusers.experimental.ValueGuidedRLPipeline
diff --git a/diffusers/docs/source/en/api/pipelines/versatile_diffusion.md b/diffusers/docs/source/en/api/pipelines/versatile_diffusion.md
new file mode 100644
index 0000000000000000000000000000000000000000..953f4822486aeafc600be9bdd6cf69c089ad5cb6
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/versatile_diffusion.md
@@ -0,0 +1,54 @@
+
+
+# Versatile Diffusion
+
+Versatile Diffusion was proposed in [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://huggingface.co/papers/2211.08332) by Xingqian Xu, Zhangyang Wang, Eric Zhang, Kai Wang, Humphrey Shi.
+
+The abstract from the paper is:
+
+*Recent advances in diffusion models have set an impressive milestone in many generation tasks, and trending works such as DALL-E2, Imagen, and Stable Diffusion have attracted great interest. Despite the rapid landscape changes, recent new approaches focus on extensions and performance rather than capacity, thus requiring separate models for separate tasks. In this work, we expand the existing single-flow diffusion pipeline into a multi-task multimodal network, dubbed Versatile Diffusion (VD), that handles multiple flows of text-to-image, image-to-text, and variations in one unified model. The pipeline design of VD instantiates a unified multi-flow diffusion framework, consisting of sharable and swappable layer modules that enable the crossmodal generality beyond images and text. Through extensive experiments, we demonstrate that VD successfully achieves the following: a) VD outperforms the baseline approaches and handles all its base tasks with competitive quality; b) VD enables novel extensions such as disentanglement of style and semantics, dual- and multi-context blending, etc.; c) The success of our multi-flow multimodal framework over images and text may inspire further diffusion-based universal AI research.*
+
+## Tips
+
+You can load the more memory-intensive "all-in-one" [`VersatileDiffusionPipeline`] that supports all the tasks, or use the individual pipelines, which are more memory-efficient (a minimal sketch follows the table below).
+
+| **Pipeline** | **Supported tasks** |
+|------------------------------------------------------|-----------------------------------|
+| [`VersatileDiffusionPipeline`] | all of the below |
+| [`VersatileDiffusionTextToImagePipeline`] | text-to-image |
+| [`VersatileDiffusionImageVariationPipeline`] | image variation |
+| [`VersatileDiffusionDualGuidedPipeline`] | image-text dual guided generation |
+
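+As a minimal sketch of one of the individual pipelines, assuming the public [shi-labs/versatile-diffusion](https://huggingface.co/shi-labs/versatile-diffusion) checkpoint (prompt and seed are illustrative):
+
+```python
+import torch
+from diffusers import VersatileDiffusionTextToImagePipeline
+
+pipe = VersatileDiffusionTextToImagePipeline.from_pretrained(
+    "shi-labs/versatile-diffusion", torch_dtype=torch.float16
+).to("cuda")
+
+generator = torch.Generator(device="cuda").manual_seed(0)
+image = pipe("an astronaut riding a horse on mars", generator=generator).images[0]
+image.save("versatile_text2img.png")
+```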
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## VersatileDiffusionPipeline
+[[autodoc]] VersatileDiffusionPipeline
+
+## VersatileDiffusionTextToImagePipeline
+[[autodoc]] VersatileDiffusionTextToImagePipeline
+ - all
+ - __call__
+
+## VersatileDiffusionImageVariationPipeline
+[[autodoc]] VersatileDiffusionImageVariationPipeline
+ - all
+ - __call__
+
+## VersatileDiffusionDualGuidedPipeline
+[[autodoc]] VersatileDiffusionDualGuidedPipeline
+ - all
+ - __call__
diff --git a/diffusers/docs/source/en/api/pipelines/vq_diffusion.md b/diffusers/docs/source/en/api/pipelines/vq_diffusion.md
new file mode 100644
index 0000000000000000000000000000000000000000..f2b0db71612333017a270dd745aa41fcd61c06a0
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/vq_diffusion.md
@@ -0,0 +1,35 @@
+
+
+# VQ Diffusion
+
+[Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://huggingface.co/papers/2111.14822) is by Shuyang Gu, Dong Chen, Jianmin Bao, Fang Wen, Bo Zhang, Dongdong Chen, Lu Yuan, Baining Guo.
+
+The abstract from the paper is:
+
+*We present the vector quantized diffusion (VQ-Diffusion) model for text-to-image generation. This method is based on a vector quantized variational autoencoder (VQ-VAE) whose latent space is modeled by a conditional variant of the recently developed Denoising Diffusion Probabilistic Model (DDPM). We find that this latent-space method is well-suited for text-to-image generation tasks because it not only eliminates the unidirectional bias with existing methods but also allows us to incorporate a mask-and-replace diffusion strategy to avoid the accumulation of errors, which is a serious problem with existing methods. Our experiments show that the VQ-Diffusion produces significantly better text-to-image generation results when compared with conventional autoregressive (AR) models with similar numbers of parameters. Compared with previous GAN-based text-to-image methods, our VQ-Diffusion can handle more complex scenes and improve the synthesized image quality by a large margin. Finally, we show that the image generation computation in our method can be made highly efficient by reparameterization. With traditional AR methods, the text-to-image generation time increases linearly with the output image resolution and hence is quite time consuming even for normal size images. The VQ-Diffusion allows us to achieve a better trade-off between quality and speed. Our experiments indicate that the VQ-Diffusion model with the reparameterization is fifteen times faster than traditional AR methods while achieving a better image quality.*
+
+The original codebase can be found at [microsoft/VQ-Diffusion](https://github.com/microsoft/VQ-Diffusion).
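+
+A minimal sketch, assuming the [microsoft/vq-diffusion-ithq](https://huggingface.co/microsoft/vq-diffusion-ithq) checkpoint (prompt and step count are illustrative):
+
+```python
+import torch
+from diffusers import VQDiffusionPipeline
+
+pipe = VQDiffusionPipeline.from_pretrained("microsoft/vq-diffusion-ithq", torch_dtype=torch.float16).to("cuda")
+
+image = pipe(prompt="teddy bear playing in the pool", num_inference_steps=50).images[0]
+image.save("vq_diffusion_sample.png")
+```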
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## VQDiffusionPipeline
+[[autodoc]] VQDiffusionPipeline
+ - all
+ - __call__
+
+## ImagePipelineOutput
+[[autodoc]] pipelines.ImagePipelineOutput
diff --git a/diffusers/docs/source/en/api/pipelines/wuerstchen.md b/diffusers/docs/source/en/api/pipelines/wuerstchen.md
new file mode 100644
index 0000000000000000000000000000000000000000..127c6df9413eb002477c0c5d5e6fc316648c516a
--- /dev/null
+++ b/diffusers/docs/source/en/api/pipelines/wuerstchen.md
@@ -0,0 +1,163 @@
+
+
+# Würstchen
+
+
+
+[Wuerstchen: An Efficient Architecture for Large-Scale Text-to-Image Diffusion Models](https://huggingface.co/papers/2306.00637) is by Pablo Pernias, Dominic Rampas, Mats L. Richter, Christopher Pal, and Marc Aubreville.
+
+The abstract from the paper is:
+
+*We introduce Würstchen, a novel architecture for text-to-image synthesis that combines competitive performance with unprecedented cost-effectiveness for large-scale text-to-image diffusion models. A key contribution of our work is to develop a latent diffusion technique in which we learn a detailed but extremely compact semantic image representation used to guide the diffusion process. This highly compressed representation of an image provides much more detailed guidance compared to latent representations of language and this significantly reduces the computational requirements to achieve state-of-the-art results. Our approach also improves the quality of text-conditioned image generation based on our user preference study. The training requirements of our approach consists of 24,602 A100-GPU hours - compared to Stable Diffusion 2.1's 200,000 GPU hours. Our approach also requires less training data to achieve these results. Furthermore, our compact latent representations allows us to perform inference over twice as fast, slashing the usual costs and carbon footprint of a state-of-the-art (SOTA) diffusion model significantly, without compromising the end performance. In a broader comparison against SOTA models our approach is substantially more efficient and compares favorably in terms of image quality. We believe that this work motivates more emphasis on the prioritization of both performance and computational accessibility.*
+
+## Würstchen Overview
+Würstchen is a diffusion model whose text-conditional component works in a highly compressed latent space of images. Why is this important? Compressing data can reduce computational costs for both training and inference by orders of magnitude; training on 1024x1024 images is far more expensive than training on 32x32. Usually, other works make use of a relatively small compression, in the range of 4x - 8x spatial compression. Würstchen takes this to an extreme: through its novel design, we achieve a 42x spatial compression, which was previously unseen, as common methods fail to faithfully reconstruct detailed images after 16x spatial compression. Würstchen employs a two-stage compression, which we call Stage A and Stage B. Stage A is a VQGAN, and Stage B is a Diffusion Autoencoder (more details can be found in the [paper](https://huggingface.co/papers/2306.00637)). A third model, Stage C, is learned in that highly compressed latent space. This training requires a fraction of the compute used for current top-performing models, while also allowing cheaper and faster inference.
+
+## Würstchen v2 comes to Diffusers
+
+After the initial paper release, we have improved numerous things in the architecture, training, and sampling, making Würstchen competitive with current state-of-the-art models in many ways. We are excited to release this new version together with Diffusers. Here is a list of the improvements:
+
+- Higher resolution (1024x1024 up to 2048x2048)
+- Faster inference
+- Multi Aspect Resolution Sampling
+- Better quality
+
+
+We are releasing 3 checkpoints for the text-conditional image generation model (Stage C). Those are:
+
+- v2-base
+- v2-aesthetic
+- **(default)** v2-interpolated (50% interpolation between v2-base and v2-aesthetic)
+
+We recommend using v2-interpolated, as it has a nice touch of both photorealism and aesthetics. Use v2-base for finetuning, as it does not have a style bias, and use v2-aesthetic for very artistic generations.
+A comparison can be seen here:
+
+
+
+## Text-to-Image Generation
+
+For the sake of usability, Würstchen is exposed through a single pipeline, which can be used as follows:
+
+```python
+import torch
+from diffusers import AutoPipelineForText2Image
+from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
+
+pipe = AutoPipelineForText2Image.from_pretrained("warp-ai/wuerstchen", torch_dtype=torch.float16).to("cuda")
+
+caption = "Anthropomorphic cat dressed as a fire fighter"
+images = pipe(
+ caption,
+ width=1024,
+ height=1536,
+ prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
+ prior_guidance_scale=4.0,
+ num_images_per_prompt=2,
+).images
+```
+
+For explanation purposes, we can also initialize the two main pipelines of Würstchen individually. Würstchen consists of 3 stages: Stage C, Stage B, and Stage A. They all have different jobs and only work together. When generating text-conditional images, Stage C first generates the latents in a very compressed latent space. This is what happens in the `prior_pipeline`. Afterwards, the generated latents are passed to Stage B, which decompresses them into the larger latent space of a VQGAN. These latents can then be decoded by Stage A, which is a VQGAN, into pixel space. Stage B and Stage A are both encapsulated in the `decoder_pipeline`. For more details, take a look at the [paper](https://huggingface.co/papers/2306.00637).
+
+```python
+import torch
+from diffusers import WuerstchenDecoderPipeline, WuerstchenPriorPipeline
+from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
+
+device = "cuda"
+dtype = torch.float16
+num_images_per_prompt = 2
+
+prior_pipeline = WuerstchenPriorPipeline.from_pretrained(
+ "warp-ai/wuerstchen-prior", torch_dtype=dtype
+).to(device)
+decoder_pipeline = WuerstchenDecoderPipeline.from_pretrained(
+ "warp-ai/wuerstchen", torch_dtype=dtype
+).to(device)
+
+caption = "Anthropomorphic cat dressed as a fire fighter"
+negative_prompt = ""
+
+prior_output = prior_pipeline(
+ prompt=caption,
+ height=1024,
+ width=1536,
+ timesteps=DEFAULT_STAGE_C_TIMESTEPS,
+ negative_prompt=negative_prompt,
+ guidance_scale=4.0,
+ num_images_per_prompt=num_images_per_prompt,
+)
+decoder_output = decoder_pipeline(
+ image_embeddings=prior_output.image_embeddings,
+ prompt=caption,
+ negative_prompt=negative_prompt,
+ guidance_scale=0.0,
+ output_type="pil",
+).images[0]
+decoder_output
+```
+
+## Speed-Up Inference
+You can make use of the `torch.compile` function and gain a speed-up of about 2-3x:
+
+```python
+prior_pipeline.prior = torch.compile(prior_pipeline.prior, mode="reduce-overhead", fullgraph=True)
+decoder_pipeline.decoder = torch.compile(decoder_pipeline.decoder, mode="reduce-overhead", fullgraph=True)
+```
+
+## Limitations
+
+- Due to the high compression employed by Würstchen, generations can lack a good amount
+of detail. To the human eye, this is especially noticeable in faces, hands, etc.
+- **Images can only be generated in 128-pixel steps**, e.g. the next higher resolution
+after 1024x1024 is 1152x1152
+- The model lacks the ability to render correct text in images
+- The model often does not achieve photorealism
+- Difficult compositional prompts are hard for the model
+
+The original codebase, as well as experimental ideas, can be found at [dome272/Wuerstchen](https://github.com/dome272/Wuerstchen).
+
+
+## WuerstchenCombinedPipeline
+
+[[autodoc]] WuerstchenCombinedPipeline
+ - all
+ - __call__
+
+## WuerstchenPriorPipeline
+
+[[autodoc]] WuerstchenPriorPipeline
+ - all
+ - __call__
+
+## WuerstchenPriorPipelineOutput
+
+[[autodoc]] pipelines.wuerstchen.pipeline_wuerstchen_prior.WuerstchenPriorPipelineOutput
+
+## WuerstchenDecoderPipeline
+
+[[autodoc]] WuerstchenDecoderPipeline
+ - all
+ - __call__
+
+## Citation
+
+```bibtex
+ @misc{pernias2023wuerstchen,
+ title={Wuerstchen: An Efficient Architecture for Large-Scale Text-to-Image Diffusion Models},
+ author={Pablo Pernias and Dominic Rampas and Mats L. Richter and Christopher J. Pal and Marc Aubreville},
+ year={2023},
+ eprint={2306.00637},
+ archivePrefix={arXiv},
+ primaryClass={cs.CV}
+ }
+```
diff --git a/diffusers/docs/source/en/api/schedulers/cm_stochastic_iterative.md b/diffusers/docs/source/en/api/schedulers/cm_stochastic_iterative.md
new file mode 100644
index 0000000000000000000000000000000000000000..c112c89a12fc3881da1f25d5044d1a8e49edc708
--- /dev/null
+++ b/diffusers/docs/source/en/api/schedulers/cm_stochastic_iterative.md
@@ -0,0 +1,27 @@
+
+
+# CMStochasticIterativeScheduler
+
+[Consistency Models](https://huggingface.co/papers/2303.01469) by Yang Song, Prafulla Dhariwal, Mark Chen, and Ilya Sutskever introduced a multistep and onestep scheduler (Algorithm 1) that is capable of generating good samples in one or a small number of steps.
+
+The abstract from the paper is:
+
+*Diffusion models have significantly advanced the fields of image, audio, and video generation, but they depend on an iterative sampling process that causes slow generation. To overcome this limitation, we propose consistency models, a new family of models that generate high quality samples by directly mapping noise to data. They support fast one-step generation by design, while still allowing multistep sampling to trade compute for sample quality. They also support zero-shot data editing, such as image inpainting, colorization, and super-resolution, without requiring explicit training on these tasks. Consistency models can be trained either by distilling pre-trained diffusion models, or as standalone generative models altogether. Through extensive experiments, we demonstrate that they outperform existing distillation techniques for diffusion models in one- and few-step sampling, achieving the new state-of-the-art FID of 3.55 on CIFAR-10 and 6.20 on ImageNet 64x64 for one-step generation. When trained in isolation, consistency models become a new family of generative models that can outperform existing one-step, non-adversarial generative models on standard benchmarks such as CIFAR-10, ImageNet 64x64 and LSUN 256x256.*
+
+The original codebase can be found at [openai/consistency_models](https://github.com/openai/consistency_models).
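+
+[`ConsistencyModelPipeline`] uses this scheduler by default. A minimal one-step sampling sketch (the checkpoint id is only an example of a converted consistency model):
+
+```python
+import torch
+from diffusers import ConsistencyModelPipeline
+
+pipe = ConsistencyModelPipeline.from_pretrained("openai/diffusers-cd_imagenet64_l2", torch_dtype=torch.float16)
+pipe = pipe.to("cuda")
+
+# distilled consistency models can sample in a single step
+image = pipe(num_inference_steps=1).images[0]
+```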
+
+## CMStochasticIterativeScheduler
+[[autodoc]] CMStochasticIterativeScheduler
+
+## CMStochasticIterativeSchedulerOutput
+[[autodoc]] schedulers.scheduling_consistency_models.CMStochasticIterativeSchedulerOutput
diff --git a/diffusers/docs/source/en/api/schedulers/consistency_decoder.md b/diffusers/docs/source/en/api/schedulers/consistency_decoder.md
new file mode 100644
index 0000000000000000000000000000000000000000..6c937b9132795fdedb999c72c5b98a9abd24e650
--- /dev/null
+++ b/diffusers/docs/source/en/api/schedulers/consistency_decoder.md
@@ -0,0 +1,9 @@
+# ConsistencyDecoderScheduler
+
+This scheduler is a part of the [`ConsistencyDecoderPipeline`] and was introduced in [DALL-E 3](https://openai.com/dall-e-3).
+
+The original codebase can be found at [openai/consistency_models](https://github.com/openai/consistency_models).
+
+
+## ConsistencyDecoderScheduler
+[[autodoc]] schedulers.scheduling_consistency_decoder.ConsistencyDecoderScheduler
\ No newline at end of file
diff --git a/diffusers/docs/source/en/api/schedulers/ddim.md b/diffusers/docs/source/en/api/schedulers/ddim.md
new file mode 100644
index 0000000000000000000000000000000000000000..422b74cff3a957baba18d5aa2e295f4f113d84be
--- /dev/null
+++ b/diffusers/docs/source/en/api/schedulers/ddim.md
@@ -0,0 +1,82 @@
+
+
+# DDIMScheduler
+
+[Denoising Diffusion Implicit Models](https://huggingface.co/papers/2010.02502) (DDIM) by Jiaming Song, Chenlin Meng and Stefano Ermon.
+
+The abstract from the paper is:
+
+*Denoising diffusion probabilistic models (DDPMs) have achieved high quality image generation without adversarial training, yet they require simulating a Markov chain for many steps to produce a sample.
+To accelerate sampling, we present denoising diffusion implicit models (DDIMs), a more efficient class of iterative implicit probabilistic models
+with the same training procedure as DDPMs. In DDPMs, the generative process is defined as the reverse of a Markovian diffusion process.
+We construct a class of non-Markovian diffusion processes that lead to the same training objective, but whose reverse process can be much faster to sample from.
+We empirically demonstrate that DDIMs can produce high quality samples 10× to 50× faster in terms of wall-clock time compared to DDPMs, allow us to trade off computation for sample quality, and can perform semantically meaningful image interpolation directly in the latent space.*
+
+The original codebase of this paper can be found at [ermongroup/ddim](https://github.com/ermongroup/ddim), and you can contact the author on [tsong.me](https://tsong.me/).
+
+## Tips
+
+The paper [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) claims that a mismatch between the training and inference settings leads to suboptimal inference generation results for Stable Diffusion. To fix this, the authors propose:
+
+
+
+🧪 This is an experimental feature!
+
+
+
+1. rescale the noise schedule to enforce zero terminal signal-to-noise ratio (SNR)
+
+```py
+pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, rescale_betas_zero_snr=True)
+```
+
+2. train a model with `v_prediction` (add the following argument to the [train_text_to_image.py](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) or [train_text_to_image_lora.py](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py) scripts)
+
+```bash
+--prediction_type="v_prediction"
+```
+
+3. change the sampler to always start from the last timestep
+
+```py
+pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
+```
+
+4. rescale classifier-free guidance to prevent over-exposure
+
+```py
+image = pipe(prompt, guidance_rescale=0.7).images[0]
+```
+
+For example:
+
+```py
+from diffusers import DiffusionPipeline, DDIMScheduler
+import torch
+
+pipe = DiffusionPipeline.from_pretrained("ptx0/pseudo-journey-v2", torch_dtype=torch.float16)
+pipe.scheduler = DDIMScheduler.from_config(
+ pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
+)
+pipe.to("cuda")
+
+prompt = "A lion in galaxies, spirals, nebulae, stars, smoke, iridescent, intricate detail, octane render, 8k"
+image = pipe(prompt, guidance_rescale=0.7).images[0]
+image
+```
+
+## DDIMScheduler
+[[autodoc]] DDIMScheduler
+
+## DDIMSchedulerOutput
+[[autodoc]] schedulers.scheduling_ddim.DDIMSchedulerOutput
diff --git a/diffusers/docs/source/en/api/schedulers/ddim_inverse.md b/diffusers/docs/source/en/api/schedulers/ddim_inverse.md
new file mode 100644
index 0000000000000000000000000000000000000000..9b28b9dc59500a7fa8ccc8f1971774bde22b1e6d
--- /dev/null
+++ b/diffusers/docs/source/en/api/schedulers/ddim_inverse.md
@@ -0,0 +1,19 @@
+
+
+# DDIMInverseScheduler
+
+`DDIMInverseScheduler` is the inverted scheduler from [Denoising Diffusion Implicit Models](https://huggingface.co/papers/2010.02502) (DDIM) by Jiaming Song, Chenlin Meng and Stefano Ermon.
+The implementation is mostly based on the DDIM inversion definition from [Null-text Inversion for Editing Real Images using Guided Diffusion Models](https://huggingface.co/papers/2211.09794).
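+
+A typical use is as the `inverse_scheduler` of a pipeline that needs to map an image back into latents, for example [`StableDiffusionDiffEditPipeline`]. A minimal sketch (the checkpoint is only an example):
+
+```python
+import torch
+from diffusers import DDIMInverseScheduler, DDIMScheduler, StableDiffusionDiffEditPipeline
+
+pipe = StableDiffusionDiffEditPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
+).to("cuda")
+
+# the forward (denoising) scheduler and its inverse share the same config
+pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)
+```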
+
+## DDIMInverseScheduler
+[[autodoc]] DDIMInverseScheduler
diff --git a/diffusers/docs/source/en/api/schedulers/ddpm.md b/diffusers/docs/source/en/api/schedulers/ddpm.md
new file mode 100644
index 0000000000000000000000000000000000000000..5402d8863df6ac3308c7bd93213b36df2129db42
--- /dev/null
+++ b/diffusers/docs/source/en/api/schedulers/ddpm.md
@@ -0,0 +1,25 @@
+
+
+# DDPMScheduler
+
+[Denoising Diffusion Probabilistic Models](https://huggingface.co/papers/2006.11239) (DDPM) by Jonathan Ho, Ajay Jain and Pieter Abbeel proposes a diffusion based model of the same name. In the context of the 🤗 Diffusers library, DDPM refers to the discrete denoising scheduler from the paper as well as the pipeline.
+
+The abstract from the paper is:
+
+*We present high quality image synthesis results using diffusion probabilistic models, a class of latent variable models inspired by considerations from nonequilibrium thermodynamics. Our best results are obtained by training on a weighted variational bound designed according to a novel connection between diffusion probabilistic models and denoising score matching with Langevin dynamics, and our models naturally admit a progressive lossy decompression scheme that can be interpreted as a generalization of autoregressive decoding. On the unconditional CIFAR10 dataset, we obtain an Inception score of 9.46 and a state-of-the-art FID score of 3.17. On 256x256 LSUN, we obtain sample quality similar to ProgressiveGAN. Our implementation is available at [this https URL](https://github.com/hojonathanho/diffusion).*
+
+## DDPMScheduler
+[[autodoc]] DDPMScheduler
+
+## DDPMSchedulerOutput
+[[autodoc]] schedulers.scheduling_ddpm.DDPMSchedulerOutput
diff --git a/diffusers/docs/source/en/api/schedulers/deis.md b/diffusers/docs/source/en/api/schedulers/deis.md
new file mode 100644
index 0000000000000000000000000000000000000000..fc05dd39ee6131b42f0a375fba36e0fa8c6d3f43
--- /dev/null
+++ b/diffusers/docs/source/en/api/schedulers/deis.md
@@ -0,0 +1,34 @@
+
+
+# DEISMultistepScheduler
+
+Diffusion Exponential Integrator Sampler (DEIS) is proposed in [Fast Sampling of Diffusion Models with Exponential Integrator](https://huggingface.co/papers/2204.13902) by Qinsheng Zhang and Yongxin Chen. `DEISMultistepScheduler` is a fast high-order solver for diffusion ordinary differential equations (ODEs).
+
+This implementation fits the polynomial in log-rho space instead of the original linear `t` space used in the DEIS paper. The modification enjoys closed-form coefficients for the exponential multistep update instead of relying on a numerical solver.
+
+The abstract from the paper is:
+
+*The past few years have witnessed the great success of Diffusion models~(DMs) in generating high-fidelity samples in generative modeling tasks. A major limitation of the DM is its notoriously slow sampling procedure which normally requires hundreds to thousands of time discretization steps of the learned diffusion process to reach the desired accuracy. Our goal is to develop a fast sampling method for DMs with a much less number of steps while retaining high sample quality. To this end, we systematically analyze the sampling procedure in DMs and identify key factors that affect the sample quality, among which the method of discretization is most crucial. By carefully examining the learned diffusion process, we propose Diffusion Exponential Integrator Sampler~(DEIS). It is based on the Exponential Integrator designed for discretizing ordinary differential equations (ODEs) and leverages a semilinear structure of the learned diffusion process to reduce the discretization error. The proposed method can be applied to any DMs and can generate high-fidelity samples in as few as 10 steps. In our experiments, it takes about 3 minutes on one A6000 GPU to generate 50k images from CIFAR10. Moreover, by directly using pre-trained DMs, we achieve the state-of-art sampling performance when the number of score function evaluation~(NFE) is limited, e.g., 4.17 FID with 10 NFEs, 3.37 FID, and 9.74 IS with only 15 NFEs on CIFAR10. Code is available at [this https URL](https://github.com/qsh-zh/deis).*
+
+## Tips
+
+It is recommended to set `solver_order` to 2 or 3, while `solver_order=1` is equivalent to [`DDIMScheduler`].
+
+Dynamic thresholding from [Imagen](https://huggingface.co/papers/2205.11487) is supported, and for pixel-space
+diffusion models, you can set `thresholding=True` to use the dynamic thresholding.
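+
+A minimal sketch of plugging the scheduler into an existing pipeline (the checkpoint is only an example):
+
+```python
+import torch
+from diffusers import DiffusionPipeline, DEISMultistepScheduler
+
+pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+# solver_order=2 or 3 is recommended; solver_order=1 behaves like DDIM
+pipe.scheduler = DEISMultistepScheduler.from_config(pipe.scheduler.config, solver_order=2)
+pipe = pipe.to("cuda")
+```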
+
+## DEISMultistepScheduler
+[[autodoc]] DEISMultistepScheduler
+
+## SchedulerOutput
+[[autodoc]] schedulers.scheduling_utils.SchedulerOutput
diff --git a/diffusers/docs/source/en/api/schedulers/dpm_discrete.md b/diffusers/docs/source/en/api/schedulers/dpm_discrete.md
new file mode 100644
index 0000000000000000000000000000000000000000..eea09915c68a38a23ac52047314cc30f6fbd9c40
--- /dev/null
+++ b/diffusers/docs/source/en/api/schedulers/dpm_discrete.md
@@ -0,0 +1,23 @@
+
+
+# KDPM2DiscreteScheduler
+
+The `KDPM2DiscreteScheduler` is inspired by the [Elucidating the Design Space of Diffusion-Based Generative Models](https://huggingface.co/papers/2206.00364) paper. The scheduler was created by [Katherine Crowson](https://github.com/crowsonkb/) and is ported from the [k-diffusion](https://github.com/crowsonkb/k-diffusion) library.
+
+The original codebase can be found at [crowsonkb/k-diffusion](https://github.com/crowsonkb/k-diffusion).
+
+## KDPM2DiscreteScheduler
+[[autodoc]] KDPM2DiscreteScheduler
+
+## SchedulerOutput
+[[autodoc]] schedulers.scheduling_utils.SchedulerOutput
diff --git a/diffusers/docs/source/en/api/schedulers/dpm_discrete_ancestral.md b/diffusers/docs/source/en/api/schedulers/dpm_discrete_ancestral.md
new file mode 100644
index 0000000000000000000000000000000000000000..5f8ae193c5a7caf911e89af0585e250af91869cc
--- /dev/null
+++ b/diffusers/docs/source/en/api/schedulers/dpm_discrete_ancestral.md
@@ -0,0 +1,23 @@
+
+
+# KDPM2AncestralDiscreteScheduler
+
+The `KDPM2AncestralDiscreteScheduler` is the `KDPM2DiscreteScheduler` with ancestral sampling and is inspired by the [Elucidating the Design Space of Diffusion-Based Generative Models](https://huggingface.co/papers/2206.00364) paper. The scheduler was created by [Katherine Crowson](https://github.com/crowsonkb/) and is ported from the [k-diffusion](https://github.com/crowsonkb/k-diffusion) library.
+
+The original codebase can be found at [crowsonkb/k-diffusion](https://github.com/crowsonkb/k-diffusion).
+
+## KDPM2AncestralDiscreteScheduler
+[[autodoc]] KDPM2AncestralDiscreteScheduler
+
+## SchedulerOutput
+[[autodoc]] schedulers.scheduling_utils.SchedulerOutput
diff --git a/diffusers/docs/source/en/api/schedulers/dpm_sde.md b/diffusers/docs/source/en/api/schedulers/dpm_sde.md
new file mode 100644
index 0000000000000000000000000000000000000000..1486ba3d275ed16807bade66892b987ed28ab58c
--- /dev/null
+++ b/diffusers/docs/source/en/api/schedulers/dpm_sde.md
@@ -0,0 +1,21 @@
+
+
+# DPMSolverSDEScheduler
+
+The `DPMSolverSDEScheduler` is inspired by the stochastic sampler from the [Elucidating the Design Space of Diffusion-Based Generative Models](https://huggingface.co/papers/2206.00364) paper. The scheduler was created by [Katherine Crowson](https://github.com/crowsonkb/) and is ported from the [k-diffusion](https://github.com/crowsonkb/k-diffusion) library.
+
+## DPMSolverSDEScheduler
+[[autodoc]] DPMSolverSDEScheduler
+
+## SchedulerOutput
+[[autodoc]] schedulers.scheduling_utils.SchedulerOutput
diff --git a/diffusers/docs/source/en/api/schedulers/euler.md b/diffusers/docs/source/en/api/schedulers/euler.md
new file mode 100644
index 0000000000000000000000000000000000000000..92743283370d77e9e05e472c6061ae1f4808e895
--- /dev/null
+++ b/diffusers/docs/source/en/api/schedulers/euler.md
@@ -0,0 +1,22 @@
+
+
+# EulerDiscreteScheduler
+
+The Euler scheduler (Algorithm 2) is from the [Elucidating the Design Space of Diffusion-Based Generative Models](https://huggingface.co/papers/2206.00364) paper by Karras et al. This is a fast scheduler which can often generate good outputs in 20-30 steps. The scheduler is based on the original [k-diffusion](https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L51) implementation by [Katherine Crowson](https://github.com/crowsonkb/).
+
+
+## EulerDiscreteScheduler
+[[autodoc]] EulerDiscreteScheduler
+
+## EulerDiscreteSchedulerOutput
+[[autodoc]] schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput
diff --git a/diffusers/docs/source/en/api/schedulers/euler_ancestral.md b/diffusers/docs/source/en/api/schedulers/euler_ancestral.md
new file mode 100644
index 0000000000000000000000000000000000000000..c78a407d2eb2e92d470844fb09241072fa1c0b87
--- /dev/null
+++ b/diffusers/docs/source/en/api/schedulers/euler_ancestral.md
@@ -0,0 +1,21 @@
+
+
+# EulerAncestralDiscreteScheduler
+
+A scheduler that uses ancestral sampling with Euler method steps. This is a fast scheduler which can often generate good outputs in 20-30 steps. The scheduler is based on the original [k-diffusion](https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72) implementation by [Katherine Crowson](https://github.com/crowsonkb/).
+
+## EulerAncestralDiscreteScheduler
+[[autodoc]] EulerAncestralDiscreteScheduler
+
+## EulerAncestralDiscreteSchedulerOutput
+[[autodoc]] schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput
diff --git a/diffusers/docs/source/en/api/schedulers/heun.md b/diffusers/docs/source/en/api/schedulers/heun.md
new file mode 100644
index 0000000000000000000000000000000000000000..abfde24a1678a2f015f1ad9f0b3e90a935791546
--- /dev/null
+++ b/diffusers/docs/source/en/api/schedulers/heun.md
@@ -0,0 +1,21 @@
+
+
+# HeunDiscreteScheduler
+
+The Heun scheduler (Algorithm 1) is from the [Elucidating the Design Space of Diffusion-Based Generative Models](https://huggingface.co/papers/2206.00364) paper by Karras et al. The scheduler is ported from the [k-diffusion](https://github.com/crowsonkb/k-diffusion) library and created by [Katherine Crowson](https://github.com/crowsonkb/).
+
+## HeunDiscreteScheduler
+[[autodoc]] HeunDiscreteScheduler
+
+## SchedulerOutput
+[[autodoc]] schedulers.scheduling_utils.SchedulerOutput
diff --git a/diffusers/docs/source/en/api/schedulers/ipndm.md b/diffusers/docs/source/en/api/schedulers/ipndm.md
new file mode 100644
index 0000000000000000000000000000000000000000..b81206493494d8952d8d2b1b8d3830432b23f920
--- /dev/null
+++ b/diffusers/docs/source/en/api/schedulers/ipndm.md
@@ -0,0 +1,21 @@
+
+
+# IPNDMScheduler
+
+`IPNDMScheduler` is a fourth-order Improved Pseudo Linear Multistep scheduler. The original implementation can be found at [crowsonkb/v-diffusion-pytorch](https://github.com/crowsonkb/v-diffusion-pytorch/blob/987f8985e38208345c1959b0ea767a625831cc9b/diffusion/sampling.py#L296).
+
+## IPNDMScheduler
+[[autodoc]] IPNDMScheduler
+
+## SchedulerOutput
+[[autodoc]] schedulers.scheduling_utils.SchedulerOutput
diff --git a/diffusers/docs/source/en/api/schedulers/lcm.md b/diffusers/docs/source/en/api/schedulers/lcm.md
new file mode 100644
index 0000000000000000000000000000000000000000..5223072fd153c6732ea1a0287b8a11724cd09ea3
--- /dev/null
+++ b/diffusers/docs/source/en/api/schedulers/lcm.md
@@ -0,0 +1,21 @@
+
+
+# Latent Consistency Model Multistep Scheduler
+
+## Overview
+
+Multistep and onestep scheduler (Algorithm 3) introduced alongside latent consistency models in the paper [Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://arxiv.org/abs/2310.04378) by Simian Luo, Yiqin Tan, Longbo Huang, Jian Li, and Hang Zhao.
+This scheduler should be able to generate good samples from [`LatentConsistencyModelPipeline`] in 1-8 steps.
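+
+A minimal few-step sampling sketch (the checkpoint id is only an example of a latent consistency model on the Hub):
+
+```python
+import torch
+from diffusers import DiffusionPipeline, LCMScheduler
+
+pipe = DiffusionPipeline.from_pretrained("SimianLuo/LCM_Dreamshaper_v7", torch_dtype=torch.float16)
+pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+pipe = pipe.to("cuda")
+
+# latent consistency models typically only need 4-8 inference steps
+image = pipe("a photo of an astronaut riding a horse", num_inference_steps=4, guidance_scale=8.0).images[0]
+```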
+
+## LCMScheduler
+[[autodoc]] LCMScheduler
diff --git a/diffusers/docs/source/en/api/schedulers/lms_discrete.md b/diffusers/docs/source/en/api/schedulers/lms_discrete.md
new file mode 100644
index 0000000000000000000000000000000000000000..46d95da5fcd99c2a7ff9abcae4d50d3b37cd10a2
--- /dev/null
+++ b/diffusers/docs/source/en/api/schedulers/lms_discrete.md
@@ -0,0 +1,21 @@
+
+
+# LMSDiscreteScheduler
+
+`LMSDiscreteScheduler` is a linear multistep scheduler for discrete beta schedules. The scheduler was created by [Katherine Crowson](https://github.com/crowsonkb/), and the original implementation can be found at [crowsonkb/k-diffusion](https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181).
+
+## LMSDiscreteScheduler
+[[autodoc]] LMSDiscreteScheduler
+
+## LMSDiscreteSchedulerOutput
+[[autodoc]] schedulers.scheduling_lms_discrete.LMSDiscreteSchedulerOutput
diff --git a/diffusers/docs/source/en/api/schedulers/multistep_dpm_solver.md b/diffusers/docs/source/en/api/schedulers/multistep_dpm_solver.md
new file mode 100644
index 0000000000000000000000000000000000000000..ce6bde55446359e661fd10045db8ba2f4d04aec4
--- /dev/null
+++ b/diffusers/docs/source/en/api/schedulers/multistep_dpm_solver.md
@@ -0,0 +1,35 @@
+
+
+# DPMSolverMultistepScheduler
+
+`DPMSolverMultistep` is a multistep scheduler from [DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps](https://huggingface.co/papers/2206.00927) and [DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic Models](https://huggingface.co/papers/2211.01095) by Cheng Lu, Yuhao Zhou, Fan Bao, Jianfei Chen, Chongxuan Li, and Jun Zhu.
+
+DPMSolver (and the improved version DPMSolver++) is a fast dedicated high-order solver for diffusion ODEs with convergence order guarantee. Empirically, DPMSolver sampling with only 20 steps can generate high-quality
+samples, and it can generate quite good samples even in 10 steps.
+
+## Tips
+
+It is recommended to set `solver_order` to 2 for guided sampling, and `solver_order=3` for unconditional sampling.
+
+Dynamic thresholding from [Imagen](https://huggingface.co/papers/2205.11487) is supported, and for pixel-space
+diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic
+thresholding. This thresholding method is unsuitable for latent-space diffusion models such as
+Stable Diffusion.
+
+The SDE variant of DPMSolver and DPM-Solver++ is also supported, but only for the first and second-order solvers. This is a fast SDE solver for the reverse diffusion SDE. It is recommended to use the second-order `sde-dpmsolver++`.
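+
+A minimal sketch of configuring the scheduler on an existing pipeline (the checkpoint is only an example):
+
+```python
+import torch
+from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
+
+pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16)
+# second-order SDE variant, as recommended above
+pipe.scheduler = DPMSolverMultistepScheduler.from_config(
+    pipe.scheduler.config, algorithm_type="sde-dpmsolver++", solver_order=2
+)
+pipe = pipe.to("cuda")
+```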
+
+## DPMSolverMultistepScheduler
+[[autodoc]] DPMSolverMultistepScheduler
+
+## SchedulerOutput
+[[autodoc]] schedulers.scheduling_utils.SchedulerOutput
diff --git a/diffusers/docs/source/en/api/schedulers/multistep_dpm_solver_inverse.md b/diffusers/docs/source/en/api/schedulers/multistep_dpm_solver_inverse.md
new file mode 100644
index 0000000000000000000000000000000000000000..6a286f3d0ce14b0a8ca68437f192c83d600f7832
--- /dev/null
+++ b/diffusers/docs/source/en/api/schedulers/multistep_dpm_solver_inverse.md
@@ -0,0 +1,30 @@
+
+
+# DPMSolverMultistepInverse
+
+`DPMSolverMultistepInverse` is the inverted scheduler from [DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps](https://huggingface.co/papers/2206.00927) and [DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic Models](https://huggingface.co/papers/2211.01095) by Cheng Lu, Yuhao Zhou, Fan Bao, Jianfei Chen, Chongxuan Li, and Jun Zhu.
+
+The implementation is mostly based on the DDIM inversion definition of [Null-text Inversion for Editing Real Images using Guided Diffusion Models](https://huggingface.co/papers/2211.09794) and notebook implementation of the [`DiffEdit`] latent inversion from [Xiang-cd/DiffEdit-stable-diffusion](https://github.com/Xiang-cd/DiffEdit-stable-diffusion/blob/main/diffedit.ipynb).
+
+## Tips
+
+Dynamic thresholding from [Imagen](https://huggingface.co/papers/2205.11487) is supported, and for pixel-space
+diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic
+thresholding. This thresholding method is unsuitable for latent-space diffusion models such as
+Stable Diffusion.
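+
+A minimal sketch of instantiating the scheduler directly with the settings above (illustrative only; thresholding is intended for pixel-space models):
+
+```python
+from diffusers import DPMSolverMultistepInverseScheduler
+
+inverse_scheduler = DPMSolverMultistepInverseScheduler(
+    algorithm_type="dpmsolver++", thresholding=True
+)
+```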
+
+## DPMSolverMultistepInverseScheduler
+[[autodoc]] DPMSolverMultistepInverseScheduler
+
+## SchedulerOutput
+[[autodoc]] schedulers.scheduling_utils.SchedulerOutput
diff --git a/diffusers/docs/source/en/api/schedulers/overview.md b/diffusers/docs/source/en/api/schedulers/overview.md
new file mode 100644
index 0000000000000000000000000000000000000000..ef17e43e7217d14a0f8739723737cb954038004c
--- /dev/null
+++ b/diffusers/docs/source/en/api/schedulers/overview.md
@@ -0,0 +1,64 @@
+
+
+# Schedulers
+
+🤗 Diffusers provides many scheduler functions for the diffusion process. A scheduler takes a model's output (the sample which the diffusion process is iterating on) and a timestep to return a denoised sample. The timestep is important because it dictates where in the diffusion process the step is; data is generated by iterating forward *n* timesteps and inference occurs by propagating backward through the timesteps. Based on the timestep, a scheduler may be *discrete* in which case the timestep is an `int` or *continuous* in which case the timestep is a `float`.
+
+Depending on the context, a scheduler defines how to iteratively add noise to an image or how to update a sample based on a model's output:
+
+- during *training*, a scheduler adds noise (there are different algorithms for how to add noise) to a sample to train a diffusion model
+- during *inference*, a scheduler defines how to update a sample based on a pretrained model's output
+
+Many schedulers are implemented from the [k-diffusion](https://github.com/crowsonkb/k-diffusion) library by [Katherine Crowson](https://github.com/crowsonkb/), and they're also widely used in A1111. To help you map the schedulers from k-diffusion and A1111 to the schedulers in 🤗 Diffusers, take a look at the table below (a short example of applying one of these mappings is shown after the table):
+
+| A1111/k-diffusion | 🤗 Diffusers | Usage |
+|---------------------|-------------------------------------|---------------------------------------------------------------------------------------------------------------|
+| DPM++ 2M | [`DPMSolverMultistepScheduler`] | |
+| DPM++ 2M Karras | [`DPMSolverMultistepScheduler`] | init with `use_karras_sigmas=True` |
+| DPM++ 2M SDE | [`DPMSolverMultistepScheduler`] | init with `algorithm_type="sde-dpmsolver++"` |
+| DPM++ 2M SDE Karras | [`DPMSolverMultistepScheduler`] | init with `use_karras_sigmas=True` and `algorithm_type="sde-dpmsolver++"` |
+| DPM++ 2S a | N/A | very similar to `DPMSolverSinglestepScheduler` |
+| DPM++ 2S a Karras | N/A | very similar to `DPMSolverSinglestepScheduler(use_karras_sigmas=True, ...)` |
+| DPM++ SDE | [`DPMSolverSinglestepScheduler`] | |
+| DPM++ SDE Karras | [`DPMSolverSinglestepScheduler`] | init with `use_karras_sigmas=True` |
+| DPM2 | [`KDPM2DiscreteScheduler`] | |
+| DPM2 Karras | [`KDPM2DiscreteScheduler`] | init with `use_karras_sigmas=True` |
+| DPM2 a | [`KDPM2AncestralDiscreteScheduler`] | |
+| DPM2 a Karras | [`KDPM2AncestralDiscreteScheduler`] | init with `use_karras_sigmas=True` |
+| DPM adaptive | N/A | |
+| DPM fast | N/A | |
+| Euler | [`EulerDiscreteScheduler`] | |
+| Euler a | [`EulerAncestralDiscreteScheduler`] | |
+| Heun | [`HeunDiscreteScheduler`] | |
+| LMS | [`LMSDiscreteScheduler`] | |
+| LMS Karras | [`LMSDiscreteScheduler`] | init with `use_karras_sigmas=True` |
+| N/A | [`DEISMultistepScheduler`] | |
+| N/A | [`UniPCMultistepScheduler`] | |
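+
+For example, to reproduce the "DPM++ 2M Karras" setting from the table on an already loaded pipeline (a minimal sketch; the checkpoint is only an example):
+
+```python
+import torch
+from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
+
+pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+# "DPM++ 2M Karras" maps to DPMSolverMultistepScheduler initialized with Karras sigmas
+pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
+```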
+
+All schedulers are built from the base [`SchedulerMixin`] class which implements low level utilities shared by all schedulers.
+
+## SchedulerMixin
+[[autodoc]] SchedulerMixin
+
+## SchedulerOutput
+[[autodoc]] schedulers.scheduling_utils.SchedulerOutput
+
+## KarrasDiffusionSchedulers
+
+[`KarrasDiffusionSchedulers`] are a broad generalization of schedulers in 🤗 Diffusers. The schedulers in this class are distinguished at a high level by their noise sampling strategy, the type of network and scaling, the training strategy, and how the loss is weighed.
+
+The different schedulers in this class, depending on the ordinary differential equations (ODE) solver type, fall into the above taxonomy and provide a good abstraction for the design of the main schedulers implemented in 🤗 Diffusers. The schedulers in this class are given [here](https://github.com/huggingface/diffusers/blob/a69754bb879ed55b9b6dc9dd0b3cf4fa4124c765/src/diffusers/schedulers/scheduling_utils.py#L32).
+
+## PushToHubMixin
+
+[[autodoc]] utils.PushToHubMixin
diff --git a/diffusers/docs/source/en/api/schedulers/pndm.md b/diffusers/docs/source/en/api/schedulers/pndm.md
new file mode 100644
index 0000000000000000000000000000000000000000..33717662ae3fc1d1b4f92730be8aa74cea208503
--- /dev/null
+++ b/diffusers/docs/source/en/api/schedulers/pndm.md
@@ -0,0 +1,21 @@
+
+
+# PNDMScheduler
+
+`PNDMScheduler`, or pseudo numerical methods for diffusion models, uses more advanced ODE integration techniques, such as the Runge-Kutta and linear multistep methods. The original implementation can be found at [crowsonkb/k-diffusion](https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181).
+
+## PNDMScheduler
+[[autodoc]] PNDMScheduler
+
+## SchedulerOutput
+[[autodoc]] schedulers.scheduling_utils.SchedulerOutput
diff --git a/diffusers/docs/source/en/api/schedulers/repaint.md b/diffusers/docs/source/en/api/schedulers/repaint.md
new file mode 100644
index 0000000000000000000000000000000000000000..b3910ad71056f103980b649ba5ac0c638d20b790
--- /dev/null
+++ b/diffusers/docs/source/en/api/schedulers/repaint.md
@@ -0,0 +1,27 @@
+
+
+# RePaintScheduler
+
+`RePaintScheduler` is a DDPM-based inpainting scheduler for unsupervised inpainting with extreme masks. It is designed to be used with the [`RePaintPipeline`], and it is based on the paper [RePaint: Inpainting using Denoising Diffusion Probabilistic Models](https://huggingface.co/papers/2201.09865) by Andreas Lugmayr et al.
+
+The abstract from the paper is:
+
+*Free-form inpainting is the task of adding new content to an image in the regions specified by an arbitrary binary mask. Most existing approaches train for a certain distribution of masks, which limits their generalization capabilities to unseen mask types. Furthermore, training with pixel-wise and perceptual losses often leads to simple textural extensions towards the missing areas instead of semantically meaningful generation. In this work, we propose RePaint: A Denoising Diffusion Probabilistic Model (DDPM) based inpainting approach that is applicable to even extreme masks. We employ a pretrained unconditional DDPM as the generative prior. To condition the generation process, we only alter the reverse diffusion iterations by sampling the unmasked regions using the given image information. Since this technique does not modify or condition the original DDPM network itself, the model produces high-quality and diverse output images for any inpainting form. We validate our method for both faces and general-purpose image inpainting using standard and extreme masks. RePaint outperforms state-of-the-art Autoregressive, and GAN approaches for at least five out of six mask distributions. GitHub Repository: [this http URL](http://git.io/RePaint).*
+
+The original implementation can be found at [andreas128/RePaint](https://github.com/andreas128/).
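+
+A minimal sketch of pairing the scheduler with [`RePaintPipeline`] (the checkpoint and image URLs are only examples):
+
+```python
+import torch
+from diffusers import RePaintPipeline, RePaintScheduler
+from diffusers.utils import load_image
+
+original_image = load_image("https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/celeba_hq_256.png")
+mask_image = load_image("https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/mask_256.png")
+
+scheduler = RePaintScheduler.from_pretrained("google/ddpm-ema-celebahq-256")
+pipe = RePaintPipeline.from_pretrained("google/ddpm-ema-celebahq-256", scheduler=scheduler).to("cuda")
+
+# resampling parameters (jump_length, jump_n_sample) follow the values suggested in the paper
+output = pipe(image=original_image, mask_image=mask_image, num_inference_steps=250, eta=0.0, jump_length=10, jump_n_sample=10)
+image = output.images[0]
+```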
+
+## RePaintScheduler
+[[autodoc]] RePaintScheduler
+
+## RePaintSchedulerOutput
+[[autodoc]] schedulers.scheduling_repaint.RePaintSchedulerOutput
diff --git a/diffusers/docs/source/en/api/schedulers/score_sde_ve.md b/diffusers/docs/source/en/api/schedulers/score_sde_ve.md
new file mode 100644
index 0000000000000000000000000000000000000000..5b930f192d93199fee8d41332725a975970e775e
--- /dev/null
+++ b/diffusers/docs/source/en/api/schedulers/score_sde_ve.md
@@ -0,0 +1,25 @@
+
+
+# ScoreSdeVeScheduler
+
+`ScoreSdeVeScheduler` is a variance exploding stochastic differential equation (SDE) scheduler. It was introduced in the [Score-Based Generative Modeling through Stochastic Differential Equations](https://huggingface.co/papers/2011.13456) paper by Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon, Ben Poole.
+
+The abstract from the paper is:
+
+*Creating noise from data is easy; creating data from noise is generative modeling. We present a stochastic differential equation (SDE) that smoothly transforms a complex data distribution to a known prior distribution by slowly injecting noise, and a corresponding reverse-time SDE that transforms the prior distribution back into the data distribution by slowly removing the noise. Crucially, the reverse-time SDE depends only on the time-dependent gradient field (\aka, score) of the perturbed data distribution. By leveraging advances in score-based generative modeling, we can accurately estimate these scores with neural networks, and use numerical SDE solvers to generate samples. We show that this framework encapsulates previous approaches in score-based generative modeling and diffusion probabilistic modeling, allowing for new sampling procedures and new modeling capabilities. In particular, we introduce a predictor-corrector framework to correct errors in the evolution of the discretized reverse-time SDE. We also derive an equivalent neural ODE that samples from the same distribution as the SDE, but additionally enables exact likelihood computation, and improved sampling efficiency. In addition, we provide a new way to solve inverse problems with score-based models, as demonstrated with experiments on class-conditional generation, image inpainting, and colorization. Combined with multiple architectural improvements, we achieve record-breaking performance for unconditional image generation on CIFAR-10 with an Inception score of 9.89 and FID of 2.20, a competitive likelihood of 2.99 bits/dim, and demonstrate high fidelity generation of 1024 x 1024 images for the first time from a score-based generative model.*
+
+## ScoreSdeVeScheduler
+[[autodoc]] ScoreSdeVeScheduler
+
+## SdeVeOutput
+[[autodoc]] schedulers.scheduling_sde_ve.SdeVeOutput
diff --git a/diffusers/docs/source/en/api/schedulers/score_sde_vp.md b/diffusers/docs/source/en/api/schedulers/score_sde_vp.md
new file mode 100644
index 0000000000000000000000000000000000000000..204cba877722a6580e89e766a0877382175ac692
--- /dev/null
+++ b/diffusers/docs/source/en/api/schedulers/score_sde_vp.md
@@ -0,0 +1,28 @@
+
+
+# ScoreSdeVpScheduler
+
+`ScoreSdeVpScheduler` is a variance preserving stochastic differential equation (SDE) scheduler. It was introduced in the [Score-Based Generative Modeling through Stochastic Differential Equations](https://huggingface.co/papers/2011.13456) paper by Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon, Ben Poole.
+
+The abstract from the paper is:
+
+*Creating noise from data is easy; creating data from noise is generative modeling. We present a stochastic differential equation (SDE) that smoothly transforms a complex data distribution to a known prior distribution by slowly injecting noise, and a corresponding reverse-time SDE that transforms the prior distribution back into the data distribution by slowly removing the noise. Crucially, the reverse-time SDE depends only on the time-dependent gradient field (\aka, score) of the perturbed data distribution. By leveraging advances in score-based generative modeling, we can accurately estimate these scores with neural networks, and use numerical SDE solvers to generate samples. We show that this framework encapsulates previous approaches in score-based generative modeling and diffusion probabilistic modeling, allowing for new sampling procedures and new modeling capabilities. In particular, we introduce a predictor-corrector framework to correct errors in the evolution of the discretized reverse-time SDE. We also derive an equivalent neural ODE that samples from the same distribution as the SDE, but additionally enables exact likelihood computation, and improved sampling efficiency. In addition, we provide a new way to solve inverse problems with score-based models, as demonstrated with experiments on class-conditional generation, image inpainting, and colorization. Combined with multiple architectural improvements, we achieve record-breaking performance for unconditional image generation on CIFAR-10 with an Inception score of 9.89 and FID of 2.20, a competitive likelihood of 2.99 bits/dim, and demonstrate high fidelity generation of 1024 x 1024 images for the first time from a score-based generative model.*
+
+
+
+🚧 This scheduler is under construction!
+
+
+
+## ScoreSdeVpScheduler
+[[autodoc]] schedulers.scheduling_sde_vp.ScoreSdeVpScheduler
diff --git a/diffusers/docs/source/en/api/schedulers/singlestep_dpm_solver.md b/diffusers/docs/source/en/api/schedulers/singlestep_dpm_solver.md
new file mode 100644
index 0000000000000000000000000000000000000000..8962a3e40d9ac1599f2cd9a0cbde76ad593a6b95
--- /dev/null
+++ b/diffusers/docs/source/en/api/schedulers/singlestep_dpm_solver.md
@@ -0,0 +1,35 @@
+
+
+# DPMSolverSinglestepScheduler
+
+`DPMSolverSinglestepScheduler` is a single step scheduler from [DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps](https://huggingface.co/papers/2206.00927) and [DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic Models](https://huggingface.co/papers/2211.01095) by Cheng Lu, Yuhao Zhou, Fan Bao, Jianfei Chen, Chongxuan Li, and Jun Zhu.
+
+DPMSolver (and the improved version DPMSolver++) is a fast dedicated high-order solver for diffusion ODEs with convergence order guarantee. Empirically, DPMSolver sampling with only 20 steps can generate high-quality
+samples, and it can generate quite good samples even in 10 steps.
+
+The original implementation can be found at [LuChengTHU/dpm-solver](https://github.com/LuChengTHU/dpm-solver).
+
+## Tips
+
+It is recommended to set `solver_order` to 2 for guided sampling, and `solver_order=3` for unconditional sampling.
+
+Dynamic thresholding from [Imagen](https://huggingface.co/papers/2205.11487) is supported, and for pixel-space
+diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use dynamic
+thresholding. This thresholding method is unsuitable for latent-space diffusion models such as
+Stable Diffusion.
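+
+A minimal sketch of swapping the scheduler into an existing pipeline (the checkpoint is only an example):
+
+```python
+import torch
+from diffusers import DiffusionPipeline, DPMSolverSinglestepScheduler
+
+pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+# solver_order=2 for guided (classifier-free guidance) sampling
+pipe.scheduler = DPMSolverSinglestepScheduler.from_config(pipe.scheduler.config, solver_order=2)
+pipe = pipe.to("cuda")
+```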
+
+## DPMSolverSinglestepScheduler
+[[autodoc]] DPMSolverSinglestepScheduler
+
+## SchedulerOutput
+[[autodoc]] schedulers.scheduling_utils.SchedulerOutput
diff --git a/diffusers/docs/source/en/api/schedulers/stochastic_karras_ve.md b/diffusers/docs/source/en/api/schedulers/stochastic_karras_ve.md
new file mode 100644
index 0000000000000000000000000000000000000000..eb954d7e5e7b3fbb8642d238790838adb46258ce
--- /dev/null
+++ b/diffusers/docs/source/en/api/schedulers/stochastic_karras_ve.md
@@ -0,0 +1,21 @@
+
+
+# KarrasVeScheduler
+
+`KarrasVeScheduler` is a stochastic sampler tailored to variance exploding (VE) models. It is based on the [Elucidating the Design Space of Diffusion-Based Generative Models](https://huggingface.co/papers/2206.00364) and [Score-based generative modeling through stochastic differential equations](https://huggingface.co/papers/2011.13456) papers.
+
+## KarrasVeScheduler
+[[autodoc]] KarrasVeScheduler
+
+## KarrasVeOutput
+[[autodoc]] schedulers.scheduling_karras_ve.KarrasVeOutput
diff --git a/diffusers/docs/source/en/api/schedulers/unipc.md b/diffusers/docs/source/en/api/schedulers/unipc.md
new file mode 100644
index 0000000000000000000000000000000000000000..df514ca4a61cd28854ded2f3ac0989257b52b502
--- /dev/null
+++ b/diffusers/docs/source/en/api/schedulers/unipc.md
@@ -0,0 +1,35 @@
+
+
+# UniPCMultistepScheduler
+
+`UniPCMultistepScheduler` is a training-free framework designed for fast sampling of diffusion models. It was introduced in [UniPC: A Unified Predictor-Corrector Framework for Fast Sampling of Diffusion Models](https://huggingface.co/papers/2302.04867) by Wenliang Zhao, Lujia Bai, Yongming Rao, Jie Zhou, Jiwen Lu.
+
+It consists of a corrector (UniC) and a predictor (UniP) that share a unified analytical form and support arbitrary orders.
+UniPC is by design model-agnostic, supporting pixel-space/latent-space DPMs on unconditional/conditional sampling. It can also be applied to both noise prediction and data prediction models. The corrector UniC can also be applied after any off-the-shelf solver to increase the order of accuracy.
+
+The abstract from the paper is:
+
+*Diffusion probabilistic models (DPMs) have demonstrated a very promising ability in high-resolution image synthesis. However, sampling from a pre-trained DPM is time-consuming due to the multiple evaluations of the denoising network, making it more and more important to accelerate the sampling of DPMs. Despite recent progress in designing fast samplers, existing methods still cannot generate satisfying images in many applications where fewer steps (e.g., <10) are favored. In this paper, we develop a unified corrector (UniC) that can be applied after any existing DPM sampler to increase the order of accuracy without extra model evaluations, and derive a unified predictor (UniP) that supports arbitrary order as a byproduct. Combining UniP and UniC, we propose a unified predictor-corrector framework called UniPC for the fast sampling of DPMs, which has a unified analytical form for any order and can significantly improve the sampling quality over previous methods, especially in extremely few steps. We evaluate our methods through extensive experiments including both unconditional and conditional sampling using pixel-space and latent-space DPMs. Our UniPC can achieve 3.87 FID on CIFAR10 (unconditional) and 7.51 FID on ImageNet 256×256 (conditional) with only 10 function evaluations. Code is available at [this https URL](https://github.com/wl-zhao/UniPC).*
+
+## Tips
+
+It is recommended to set `solver_order` to 2 for guide sampling, and `solver_order=3` for unconditional sampling.
+
+Dynamic thresholding from [Imagen](https://huggingface.co/papers/2205.11487) is supported, and for pixel-space
+diffusion models, you can set both `predict_x0=True` and `thresholding=True` to use dynamic thresholding. This thresholding method is unsuitable for latent-space diffusion models such as Stable Diffusion.
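+
+A minimal sketch of enabling the scheduler on an existing pipeline (the checkpoint is only an example):
+
+```python
+import torch
+from diffusers import DiffusionPipeline, UniPCMultistepScheduler
+
+pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16)
+pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+pipe = pipe.to("cuda")
+
+# UniPC keeps good quality even with very few steps
+image = pipe("a photo of an astronaut riding a horse", num_inference_steps=10).images[0]
+```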
+
+## UniPCMultistepScheduler
+[[autodoc]] UniPCMultistepScheduler
+
+## SchedulerOutput
+[[autodoc]] schedulers.scheduling_utils.SchedulerOutput
diff --git a/diffusers/docs/source/en/api/schedulers/vq_diffusion.md b/diffusers/docs/source/en/api/schedulers/vq_diffusion.md
new file mode 100644
index 0000000000000000000000000000000000000000..09928583f67022c5cbc49ad30d5d0808222d2111
--- /dev/null
+++ b/diffusers/docs/source/en/api/schedulers/vq_diffusion.md
@@ -0,0 +1,25 @@
+
+
+# VQDiffusionScheduler
+
+`VQDiffusionScheduler` converts the transformer model's output into a sample for the unnoised image at the previous diffusion timestep. It was introduced in [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://huggingface.co/papers/2111.14822) by Shuyang Gu, Dong Chen, Jianmin Bao, Fang Wen, Bo Zhang, Dongdong Chen, Lu Yuan, Baining Guo.
+
+The abstract from the paper is:
+
+*We present the vector quantized diffusion (VQ-Diffusion) model for text-to-image generation. This method is based on a vector quantized variational autoencoder (VQ-VAE) whose latent space is modeled by a conditional variant of the recently developed Denoising Diffusion Probabilistic Model (DDPM). We find that this latent-space method is well-suited for text-to-image generation tasks because it not only eliminates the unidirectional bias with existing methods but also allows us to incorporate a mask-and-replace diffusion strategy to avoid the accumulation of errors, which is a serious problem with existing methods. Our experiments show that the VQ-Diffusion produces significantly better text-to-image generation results when compared with conventional autoregressive (AR) models with similar numbers of parameters. Compared with previous GAN-based text-to-image methods, our VQ-Diffusion can handle more complex scenes and improve the synthesized image quality by a large margin. Finally, we show that the image generation computation in our method can be made highly efficient by reparameterization. With traditional AR methods, the text-to-image generation time increases linearly with the output image resolution and hence is quite time consuming even for normal size images. The VQ-Diffusion allows us to achieve a better trade-off between quality and speed. Our experiments indicate that the VQ-Diffusion model with the reparameterization is fifteen times faster than traditional AR methods while achieving a better image quality.*
+
+## VQDiffusionScheduler
+[[autodoc]] VQDiffusionScheduler
+
+## VQDiffusionSchedulerOutput
+[[autodoc]] schedulers.scheduling_vq_diffusion.VQDiffusionSchedulerOutput
diff --git a/diffusers/docs/source/en/api/utilities.md b/diffusers/docs/source/en/api/utilities.md
new file mode 100644
index 0000000000000000000000000000000000000000..77ada0834808e379b5b72eef0abede665242a961
--- /dev/null
+++ b/diffusers/docs/source/en/api/utilities.md
@@ -0,0 +1,39 @@
+
+
+# Utilities
+
+Utility and helper functions for working with 🤗 Diffusers.
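+
+A minimal sketch combining a couple of these helpers (the image URL is only an example):
+
+```python
+from diffusers.utils import load_image, make_image_grid
+
+image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png")
+# tile the same image into a 1x2 grid
+grid = make_image_grid([image, image], rows=1, cols=2)
+```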
+
+## numpy_to_pil
+
+[[autodoc]] utils.numpy_to_pil
+
+## pt_to_pil
+
+[[autodoc]] utils.pt_to_pil
+
+## load_image
+
+[[autodoc]] utils.load_image
+
+## export_to_gif
+
+[[autodoc]] utils.export_to_gif
+
+## export_to_video
+
+[[autodoc]] utils.export_to_video
+
+## make_image_grid
+
+[[autodoc]] utils.make_image_grid
diff --git a/diffusers/docs/source/en/conceptual/contribution.md b/diffusers/docs/source/en/conceptual/contribution.md
new file mode 100644
index 0000000000000000000000000000000000000000..dc942a24c42e5c2d6b28215e478869066d74803c
--- /dev/null
+++ b/diffusers/docs/source/en/conceptual/contribution.md
@@ -0,0 +1,505 @@
+
+
+# How to contribute to Diffusers 🧨
+
+We ❤️ contributions from the open-source community! Everyone is welcome, and all types of participation, not just code, are valued and appreciated. Answering questions, helping others, reaching out, and improving the documentation are all immensely valuable to the community, so don't be afraid to get involved if you're up for it!
+
+Everyone is encouraged to start by saying 👋 in our public Discord channel. We discuss the latest trends in diffusion models, ask questions, show off personal projects, help each other with contributions, or just hang out ☕.
+
+Whichever way you choose to contribute, we strive to be part of an open, welcoming, and kind community. Please, read our [code of conduct](https://github.com/huggingface/diffusers/blob/main/CODE_OF_CONDUCT.md) and be mindful to respect it during your interactions. We also recommend you become familiar with the [ethical guidelines](https://huggingface.co/docs/diffusers/conceptual/ethical_guidelines) that guide our project and ask you to adhere to the same principles of transparency and responsibility.
+
+We enormously value feedback from the community, so please do not be afraid to speak up if you believe you have valuable feedback that can help improve the library - every message, comment, issue, and pull request (PR) is read and considered.
+
+## Overview
+
+You can contribute in many ways ranging from answering questions on issues to adding new diffusion models to
+the core library.
+
+In the following, we give an overview of different ways to contribute, ranked by difficulty in ascending order. All of them are valuable to the community.
+
+* 1. Asking and answering questions on [the Diffusers discussion forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers) or on [Discord](https://discord.gg/G7tWnz98XR).
+* 2. Opening new issues on [the GitHub Issues tab](https://github.com/huggingface/diffusers/issues/new/choose).
+* 3. Answering issues on [the GitHub Issues tab](https://github.com/huggingface/diffusers/issues).
+* 4. Fixing a simple issue, marked by the "Good first issue" label, see [here](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22).
+* 5. Contributing to the [documentation](https://github.com/huggingface/diffusers/tree/main/docs/source).
+* 6. Contributing a [Community Pipeline](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3Acommunity-examples).
+* 7. Contributing to the [examples](https://github.com/huggingface/diffusers/tree/main/examples).
+* 8. Fixing a more difficult issue, marked by the "Good second issue" label, see [here](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22Good+second+issue%22).
+* 9. Adding a new pipeline, model, or scheduler, see ["New Pipeline/Model"](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+pipeline%2Fmodel%22) and ["New scheduler"](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+scheduler%22) issues. For this contribution, please have a look at the [Design Philosophy](https://github.com/huggingface/diffusers/blob/main/PHILOSOPHY.md).
+
+As said before, **all contributions are valuable to the community**.
+In the following, we will explain each contribution a bit more in detail.
+
+For all contributions 4 - 9, you will need to open a PR. It is explained in detail how to do so in [Opening a pull request](#how-to-open-a-pr).
+
+### 1. Asking and answering questions on the Diffusers discussion forum or on the Diffusers Discord
+
+Any question or comment related to the Diffusers library can be asked on the [discussion forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/) or on [Discord](https://discord.gg/G7tWnz98XR). Such questions and comments include (but are not limited to):
+- Reports of training or inference experiments in an attempt to share knowledge
+- Presentation of personal projects
+- Questions about non-official training examples
+- Project proposals
+- General feedback
+- Paper summaries
+- Asking for help on personal projects that build on top of the Diffusers library
+- General questions
+- Ethical questions regarding diffusion models
+- ...
+
+Every question that is asked on the forum or on Discord actively encourages the community to publicly
+share knowledge and might very well help a beginner in the future who has the same question you're
+having. Please do pose any questions you might have.
+In the same spirit, you are of immense help to the community by answering such questions because this way you are publicly documenting knowledge for everybody to learn from.
+
+**Please** keep in mind that the more effort you put into asking or answering a question, the higher
+the quality of the publicly documented knowledge. In the same way, well-posed and well-answered questions create a high-quality knowledge database accessible to everybody, while badly posed questions or answers reduce the overall quality of the public knowledge database.
+In short, a high-quality question or answer is *precise*, *concise*, *relevant*, *easy-to-understand*, *accessible*, and *well-formatted/well-posed*. For more information, please have a look through the [How to write a good issue](#how-to-write-a-good-issue) section.
+
+**NOTE about channels**:
+[*The forum*](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) is much better indexed by search engines, such as Google. Posts are ranked by popularity rather than chronologically. Hence, it's easier to look up questions and answers that were posted some time ago.
+In addition, questions and answers posted in the forum can easily be linked to.
+In contrast, *Discord* has a chat-like format that invites fast back-and-forth communication.
+While it will most likely take less time for you to get an answer to your question on Discord, your
+question won't be visible anymore over time. Also, it's much harder to find information that was posted a while back on Discord. We therefore strongly recommend using the forum for high-quality questions and answers in an attempt to create long-lasting knowledge for the community. If discussions on Discord lead to very interesting answers and conclusions, we recommend posting the results on the forum to make the information more available for future readers.
+
+### 2. Opening new issues on the GitHub issues tab
+
+The 🧨 Diffusers library is robust and reliable thanks to the users who notify us of
+the problems they encounter. So thank you for reporting an issue.
+
+Remember, GitHub issues are reserved for technical questions directly related to the Diffusers library, bug reports, feature requests, or feedback on the library design.
+
+In a nutshell, this means that everything that is **not** related to the **code of the Diffusers library** (including the documentation) should **not** be asked on GitHub, but rather on either the [forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) or [Discord](https://discord.gg/G7tWnz98XR).
+
+**Please consider the following guidelines when opening a new issue**:
+- Make sure you have searched whether your issue has already been asked before (use the search bar on GitHub under Issues).
+- Please never report a new issue on another (related) issue. If another issue is highly related, please
+open a new issue nevertheless and link to the related issue.
+- Make sure your issue is written in English. Please use one of the great, free online translation services, such as [DeepL](https://www.deepl.com/translator) to translate from your native language to English if you are not comfortable in English.
+- Check whether your issue might be solved by updating to the newest Diffusers version. Before posting your issue, please make sure that `python -c "import diffusers; print(diffusers.__version__)"` matches or is higher than the latest Diffusers version.
+- Remember that the more effort you put into opening a new issue, the higher the quality of your answer will be and the better the overall quality of the Diffusers issues.
+
+New issues usually include the following.
+
+#### 2.1. Reproducible, minimal bug reports
+
+A bug report should always have a reproducible code snippet and be as minimal and concise as possible.
+This means in more detail:
+- Narrow the bug down as much as you can, **do not just dump your whole code file**.
+- Format your code.
+- Do not include any external libraries unless Diffusers depends on them.
+- **Always** provide all necessary information about your environment; for this, you can run: `diffusers-cli env` in your shell and copy-paste the displayed information to the issue.
+- Explain the issue. If the reader doesn't know what the issue is and why it is an issue, she cannot solve it.
+- **Always** make sure the reader can reproduce your issue with as little effort as possible. If your code snippet cannot be run because of missing libraries or undefined variables, the reader cannot help you. Make sure your reproducible code snippet is as minimal as possible and can be copy-pasted into a simple Python shell.
+- If in order to reproduce your issue a model and/or dataset is required, make sure the reader has access to that model or dataset. You can always upload your model or dataset to the [Hub](https://huggingface.co) to make it easily downloadable. Try to keep your model and dataset as small as possible, to make the reproduction of your issue as effortless as possible.
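+
+For illustration, a minimal, self-contained reproduction might look like the sketch below (a hypothetical example; the checkpoint and prompt are placeholders, not taken from a real report). Paste the output of `diffusers-cli env` and the full traceback next to it.
+
+```python
+# Minimal reproduction sketch (hypothetical): everything needed to trigger the error in one snippet.
+import torch
+from diffusers import DiffusionPipeline
+
+pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+pipe.to("cuda")
+
+# The single call that triggers the error; copy the full traceback below the snippet in the issue.
+image = pipe("a photo of an astronaut riding a horse").images[0]
+```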
+
+For more information, please have a look through the [How to write a good issue](#how-to-write-a-good-issue) section.
+
+You can open a bug report [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=bug&projects=&template=bug-report.yml).
+
+#### 2.2. Feature requests
+
+A world-class feature request addresses the following points:
+
+1. Motivation first:
+* Is it related to a problem/frustration with the library? If so, please explain
+why. Providing a code snippet that demonstrates the problem is best.
+* Is it related to something you would need for a project? We'd love to hear
+about it!
+* Is it something you worked on and think could benefit the community?
+Awesome! Tell us what problem it solved for you.
+2. Write a *full paragraph* describing the feature;
+3. Provide a **code snippet** that demonstrates its future use;
+4. In case this is related to a paper, please attach a link;
+5. Attach any additional information (drawings, screenshots, etc.) you think may help.
+
+You can open a feature request [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feature_request.md&title=).
+
+#### 2.3 Feedback
+
+Feedback about the library design and why it is good or not good immensely helps the core maintainers build a user-friendly library. To understand the philosophy behind the current design, please have a look [here](https://huggingface.co/docs/diffusers/conceptual/philosophy). If you feel like a certain design choice does not fit with the current design philosophy, please explain why and how it should be changed. If a certain design choice follows the design philosophy too strictly, hence restricting use cases, explain why and how it should be changed.
+If a certain design choice is very useful for you, please also leave a note as this is great feedback for future design decisions.
+
+You can open an issue about feedback [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=).
+
+#### 2.4 Technical questions
+
+Technical questions are mainly about why certain code of the library was written in a certain way, or what a certain part of the code does. Please make sure to link to the code in question and please provide details on
+why this part of the code is difficult to understand.
+
+You can open an issue about a technical question [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=bug&template=bug-report.yml).
+
+#### 2.5 Proposal to add a new model, scheduler, or pipeline
+
+If the diffusion model community released a new model, pipeline, or scheduler that you would like to see in the Diffusers library, please provide the following information:
+
+* Short description of the diffusion pipeline, model, or scheduler and link to the paper or public release.
+* Link to any of its open-source implementation(s).
+* Link to the model weights if they are available.
+
+If you are willing to contribute the model yourself, let us know so we can best guide you. Also, don't forget
+to tag the original author of the component (model, scheduler, pipeline, etc.) by their GitHub handle if you can find it.
+
+You can open a request for a model/pipeline/scheduler [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=New+model%2Fpipeline%2Fscheduler&template=new-model-addition.yml).
+
+### 3. Answering issues on the GitHub issues tab
+
+Answering issues on GitHub might require some technical knowledge of Diffusers, but we encourage everybody to give it a try even if you are not 100% certain that your answer is correct.
+Some tips to give a high-quality answer to an issue:
+- Be as concise and minimal as possible.
+- Stay on topic. An answer to the issue should concern the issue and only the issue.
+- Provide links to code, papers, or other sources that prove or encourage your point.
+- Answer in code. If a simple code snippet is the answer to the issue or shows how the issue can be solved, please provide a fully reproducible code snippet.
+
+Also, many issues tend to be simply off-topic, duplicates of other issues, or irrelevant. It is of great
+help to the maintainers if you can answer such issues, encouraging the author of the issue to be
+more precise, providing the link to the duplicate issue, or redirecting them to [the forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) or [Discord](https://discord.gg/G7tWnz98XR).
+
+If you have verified that a reported bug is correct and requires a fix in the source code,
+please have a look at the next sections.
+
+For all of the following contributions, you will need to open a PR. It is explained in detail how to do so in the [Opening a pull request](#how-to-open-a-pr) section.
+
+### 4. Fixing a "Good first issue"
+
+*Good first issues* are marked by the [Good first issue](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22) label. Usually, the issue already
+explains how a potential solution should look so that it is easier to fix.
+If the issue hasn't been closed and you would like to try to fix this issue, you can just leave a message "I would like to try this issue.". There are usually three scenarios:
+- a.) The issue description already proposes a fix. In this case and if the solution makes sense to you, you can open a PR or draft PR to fix it.
+- b.) The issue description does not propose a fix. In this case, you can ask what a proposed fix could look like and someone from the Diffusers team should answer shortly. If you have a good idea of how to fix it, feel free to directly open a PR.
+- c.) There is already an open PR to fix the issue, but the issue hasn't been closed yet. If the PR has gone stale, you can simply open a new PR and link to the stale PR. PRs often go stale if the original contributor who wanted to fix the issue suddenly cannot find the time anymore to proceed. This often happens in open-source and is very normal. In this case, the community will be very happy if you give it a new try and leverage the knowledge of the existing PR. If there is already a PR and it is active, you can help the author by giving suggestions, reviewing the PR or even asking whether you can contribute to the PR.
+
+
+### 5. Contribute to the documentation
+
+A good library **always** has good documentation! The official documentation is often one of the first points of contact for new users of the library, and therefore contributing to the documentation is a **highly
+valuable contribution**.
+
+Contributing to the documentation can take many forms:
+
+- Correcting spelling or grammatical errors.
+- Correcting incorrect docstring formatting. If you see that the official documentation is weirdly displayed or a link is broken, we would be very happy if you took some time to correct it.
+- Correcting the shape or dimensions of a docstring input or output tensor.
+- Clarifying documentation that is hard to understand or incorrect.
+- Updating outdated code examples.
+- Translating the documentation to another language.
+
+Anything displayed on [the official Diffusers doc page](https://huggingface.co/docs/diffusers/index) is part of the official documentation and can be corrected or adjusted in the corresponding [documentation source](https://github.com/huggingface/diffusers/tree/main/docs/source).
+
+Please have a look at [this page](https://github.com/huggingface/diffusers/tree/main/docs) on how to verify changes made to the documentation locally.
+
+
+### 6. Contribute a community pipeline
+
+[Pipelines](https://huggingface.co/docs/diffusers/api/pipelines/overview) are usually the first point of contact between the Diffusers library and the user.
+Pipelines are examples of how to use Diffusers [models](https://huggingface.co/docs/diffusers/api/models/overview) and [schedulers](https://huggingface.co/docs/diffusers/api/schedulers/overview).
+We support two types of pipelines:
+
+- Official Pipelines
+- Community Pipelines
+
+Both official and community pipelines follow the same design and consist of the same type of components.
+
+Official pipelines are tested and maintained by the core maintainers of Diffusers. Their code
+resides in [src/diffusers/pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines).
+In contrast, community pipelines are contributed and maintained purely by the **community** and are **not** tested.
+They reside in [examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community) and while they can be accessed via the [PyPI diffusers package](https://pypi.org/project/diffusers/), their code is not part of the PyPI distribution.
+
+The reason for the distinction is that the core maintainers of the Diffusers library cannot maintain and test all
+possible ways diffusion models can be used for inference, but some of them may be of interest to the community.
+Officially released diffusion pipelines,
+such as Stable Diffusion, are added to the core `src/diffusers/pipelines` package, which ensures
+high-quality maintenance, no backward-breaking code changes, and testing.
+More bleeding-edge pipelines should be added as community pipelines. If usage for a community pipeline is high, the pipeline can be moved to the official pipelines upon request from the community. This is one of the ways we strive to be a community-driven library.
+
+To add a community pipeline, one should add a .py file to [examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community) and adapt the [examples/community/README.md](https://github.com/huggingface/diffusers/tree/main/examples/community/README.md) to include an example of the new pipeline.
+
+An example can be seen [here](https://github.com/huggingface/diffusers/pull/2400).
+
+Community pipeline PRs are only checked at a superficial level and ideally they should be maintained by their original authors.
+
+Contributing a community pipeline is a great way to understand how Diffusers models and schedulers work. Having contributed a community pipeline is usually the first stepping stone to contributing an official pipeline to the
+core package.
+
+### 7. Contribute to training examples
+
+Diffusers examples are a collection of training scripts that reside in [examples](https://github.com/huggingface/diffusers/tree/main/examples).
+
+We support two types of training examples:
+
+- Official training examples
+- Research training examples
+
+Research training examples are located in [examples/research_projects](https://github.com/huggingface/diffusers/tree/main/examples/research_projects) whereas official training examples include all folders under [examples](https://github.com/huggingface/diffusers/tree/main/examples) except the `research_projects` and `community` folders.
+The official training examples are maintained by the Diffusers' core maintainers whereas the research training examples are maintained by the community.
+This is because of the same reasons put forward in [6. Contribute a community pipeline](#6-contribute-a-community-pipeline) for official pipelines vs. community pipelines: It is not feasible for the core maintainers to maintain all possible training methods for diffusion models.
+If the Diffusers core maintainers and the community consider a certain training paradigm to be too experimental or not popular enough, the corresponding training code should be put in the `research_projects` folder and maintained by the author.
+
+Both official training and research examples consist of a directory that contains one or more training scripts, a requirements.txt file, and a README.md file. In order for the user to make use of the
+training examples, it is required to clone the repository:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+```
+
+as well as to install all additional dependencies required for training:
+
+```bash
+pip install -r /examples/<your-example-folder>/requirements.txt
+```
+
+Therefore when adding an example, the `requirements.txt` file shall define all pip dependencies required for your training example so that once all those are installed, the user can run the example's training script. See, for example, the [DreamBooth `requirements.txt` file](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/requirements.txt).
+
+Training examples of the Diffusers library should adhere to the following philosophy:
+- All the code necessary to run the examples should be found in a single Python file.
+- One should be able to run the example from the command line with `python <your-example>.py --args`.
+- Examples should be kept simple and serve as **an example** on how to use Diffusers for training. The purpose of example scripts is **not** to create state-of-the-art diffusion models, but rather to reproduce known training schemes without adding too much custom logic. As a byproduct of this point, our examples also strive to serve as good educational materials.
+
+To contribute an example, it is highly recommended to look at already existing examples such as [dreambooth](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py) to get an idea of what they should look like.
+We strongly advise contributors to make use of the [Accelerate library](https://github.com/huggingface/accelerate) as it's tightly integrated
+with Diffusers.
+Once an example script works, please make sure to add a comprehensive `README.md` that states how to use the example exactly. This README should include:
+- An example command on how to run the example script as shown [here](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth#running-locally-with-pytorch).
+- A link to some training results (logs, models, etc.) that show what the user can expect as shown [here](https://api.wandb.ai/report/patrickvonplaten/xm6cd5q5).
+- If you are adding a non-official/research training example, **please don't forget** to add a sentence that you are maintaining this training example which includes your git handle as shown [here](https://github.com/huggingface/diffusers/tree/main/examples/research_projects/intel_opts#diffusers-examples-with-intel-optimizations).
+
+If you are contributing to the official training examples, please also make sure to add a test to [examples/test_examples.py](https://github.com/huggingface/diffusers/blob/main/examples/test_examples.py). This is not necessary for non-official training examples.
+
+### 8. Fixing a "Good second issue"
+
+*Good second issues* are marked by the [Good second issue](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22Good+second+issue%22) label. Good second issues are
+usually more complicated to solve than [Good first issues](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22).
+The issue description usually gives less guidance on how to fix the issue and requires
+a decent understanding of the library by the interested contributor.
+If you are interested in tackling a good second issue, feel free to open a PR to fix it and link the PR to the issue. If you see that a PR has already been opened for this issue but did not get merged, have a look to understand why it wasn't merged and try to open an improved PR.
+Good second issues are usually more difficult to get merged compared to good first issues, so don't hesitate to ask for help from the core maintainers. If your PR is almost finished the core maintainers can also jump into your PR and commit to it in order to get it merged.
+
+### 9. Adding pipelines, models, schedulers
+
+Pipelines, models, and schedulers are the most important pieces of the Diffusers library.
+They provide easy access to state-of-the-art diffusion technologies and thus allow the community to
+build powerful generative AI applications.
+
+By adding a new model, pipeline, or scheduler you might enable a new powerful use case for any of the user interfaces relying on Diffusers which can be of immense value for the whole generative AI ecosystem.
+
+Diffusers has a couple of open feature requests for all three components - feel free to look them over
+if you don't yet know which specific component you would like to add:
+- [Model or pipeline](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+pipeline%2Fmodel%22)
+- [Scheduler](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+scheduler%22)
+
+Before adding any of the three components, it is strongly recommended that you give the [Philosophy guide](philosophy) a read to better understand the design of any of the three components. Please be aware that
+we cannot merge model, scheduler, or pipeline additions that strongly diverge from our design philosophy
+as it will lead to API inconsistencies. If you fundamentally disagree with a design choice, please
+open a [Feedback issue](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=) instead so that it can be discussed whether a certain design
+pattern/design choice shall be changed everywhere in the library and whether we shall update our design philosophy. Consistency across the library is very important for us.
+
+Please make sure to add links to the original codebase/paper to the PR and ideally also ping the
+original author directly on the PR so that they can follow the progress and potentially help with questions.
+
+If you are unsure or stuck in the PR, don't hesitate to leave a message to ask for a first review or help.
+
+## How to write a good issue
+
+**The better your issue is written, the higher the chances that it will be quickly resolved.**
+
+1. Make sure that you've used the correct template for your issue. You can pick between *Bug Report*, *Feature Request*, *Feedback about API Design*, *New model/pipeline/scheduler addition*, *Forum*, or a blank issue. Make sure to pick the correct one when opening [a new issue](https://github.com/huggingface/diffusers/issues/new/choose).
+2. **Be precise**: Give your issue a fitting title. Try to formulate your issue description as simply as possible. The more precise you are when submitting an issue, the less time it takes to understand the issue and potentially solve it. Make sure to open an issue for one issue only and not for multiple issues. If you found multiple issues, simply open multiple issues. If your issue is a bug, try to be as precise as possible about what bug it is - you should not just write "Error in diffusers".
+3. **Reproducibility**: No reproducible code snippet == no solution. If you encounter a bug, maintainers **have to be able to reproduce** it. Make sure that you include a code snippet that can be copy-pasted into a Python interpreter to reproduce the issue. Make sure that your code snippet works, *i.e.* that there are no missing imports or missing links to images, ... Your issue should contain an error message **and** a code snippet that can be copy-pasted without any changes to reproduce the exact same error message. If your issue is using local model weights or local data that cannot be accessed by the reader, the issue cannot be solved. If you cannot share your data or model, try to make a dummy model or dummy data.
+4. **Minimalistic**: Try to help the reader as much as you can to understand the issue as quickly as possible by staying as concise as possible. Remove all code / all information that is irrelevant to the issue. If you have found a bug, try to create the easiest code example you can to demonstrate your issue, do not just dump your whole workflow into the issue as soon as you have found a bug. E.g., if you train a model and get an error at some point during the training, you should first try to understand what part of the training code is responsible for the error and try to reproduce it with a couple of lines. Try to use dummy data instead of full datasets.
+5. Add links. If you are referring to a certain naming, method, or model make sure to provide a link so that the reader can better understand what you mean. If you are referring to a specific PR or issue, make sure to link it to your issue. Do not assume that the reader knows what you are talking about. The more links you add to your issue the better.
+6. Formatting. Make sure to nicely format your issue by formatting code into Python code syntax, and error messages into normal code syntax. See the [official GitHub formatting docs](https://docs.github.com/en/get-started/writing-on-github/getting-started-with-writing-and-formatting-on-github/basic-writing-and-formatting-syntax) for more information.
+7. Think of your issue not as a ticket to be solved, but rather as a beautiful entry to a well-written encyclopedia. Every added issue is a contribution to publicly available knowledge. By adding a nicely written issue you not only make it easier for maintainers to solve your issue, but you are helping the whole community to better understand a certain aspect of the library.
+
+## How to write a good PR
+
+1. Be a chameleon. Understand existing design patterns and syntax and make sure your code additions flow seamlessly into the existing code base. Pull requests that significantly diverge from existing design patterns or user interfaces will not be merged.
+2. Be laser focused. A pull request should solve one problem and one problem only. Make sure to not fall into the trap of "also fixing another problem while we're adding it". It is much more difficult to review pull requests that solve multiple, unrelated problems at once.
+3. If helpful, try to add a code snippet that displays an example of how your addition can be used.
+4. The title of your pull request should be a summary of its contribution.
+5. If your pull request addresses an issue, please mention the issue number in
+the pull request description to make sure they are linked (and people
+consulting the issue know you are working on it);
+6. To indicate a work in progress please prefix the title with `[WIP]`. These
+are useful to avoid duplicated work, and to differentiate it from PRs ready
+to be merged;
+7. Try to formulate and format your text as explained in [How to write a good issue](#how-to-write-a-good-issue).
+8. Make sure existing tests pass;
+9. Add high-coverage tests. No quality testing = no merge.
+- If you are adding new `@slow` tests, make sure they pass using
+`RUN_SLOW=1 python -m pytest tests/test_my_new_model.py`.
+CircleCI does not run the slow tests, but GitHub Actions does every night!
+10. All public methods must have informative docstrings that work nicely with markdown. See [`pipeline_latent_diffusion.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py) for an example.
+11. Due to the rapidly growing repository, it is important to make sure that no files that would significantly weigh down the repository are added. This includes images, videos, and other non-text files. We prefer to leverage a hf.co hosted `dataset` like
+[`hf-internal-testing`](https://huggingface.co/hf-internal-testing) or [huggingface/documentation-images](https://huggingface.co/datasets/huggingface/documentation-images) to place these files.
+If yours is an external contribution, feel free to add the images to your PR and ask a Hugging Face member to migrate them
+to this dataset.
+
+## How to open a PR
+
+Before writing code, we strongly advise you to search through the existing PRs or
+issues to make sure that nobody is already working on the same thing. If you are
+unsure, it is always a good idea to open an issue to get some feedback.
+
+You will need basic `git` proficiency to be able to contribute to
+🧨 Diffusers. `git` is not the easiest tool to use but it has the greatest
+manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro
+Git](https://git-scm.com/book/en/v2) is a very good reference.
+
+Follow these steps to start contributing ([supported Python versions](https://github.com/huggingface/diffusers/blob/main/setup.py#L244)):
+
+1. Fork the [repository](https://github.com/huggingface/diffusers) by
+clicking on the 'Fork' button on the repository's page. This creates a copy of the code
+under your GitHub user account.
+
+2. Clone your fork to your local disk, and add the base repository as a remote:
+
+ ```bash
+ $ git clone git@github.com:<your-github-handle>/diffusers.git
+ $ cd diffusers
+ $ git remote add upstream https://github.com/huggingface/diffusers.git
+ ```
+
+3. Create a new branch to hold your development changes:
+
+ ```bash
+ $ git checkout -b a-descriptive-name-for-my-changes
+ ```
+
+**Do not** work on the `main` branch.
+
+4. Set up a development environment by running the following command in a virtual environment:
+
+ ```bash
+ $ pip install -e ".[dev]"
+ ```
+
+If you have already cloned the repo, you might need to `git pull` to get the most recent changes in the
+library.
+
+5. Develop the features on your branch.
+
+As you work on the features, you should make sure that the test suite
+passes. You should run the tests impacted by your changes like this:
+
+ ```bash
+ $ pytest tests/<TEST_TO_RUN>.py
+ ```
+
+Before you run the tests, please make sure you install the dependencies required for testing. You can do so
+with this command:
+
+ ```bash
+ $ pip install -e ".[test]"
+ ```
+
+You can also run the full test suite with the following command, but it takes
+a beefy machine to produce a result in a decent amount of time now that
+Diffusers has grown a lot. Here is the command for it:
+
+ ```bash
+ $ make test
+ ```
+
+🧨 Diffusers relies on `black` and `isort` to format its source code
+consistently. After you make changes, apply automatic style corrections and code verifications
+that can't be automated in one go with:
+
+ ```bash
+ $ make style
+ ```
+
+🧨 Diffusers also uses `ruff` and a few custom scripts to check for coding mistakes. Quality
+control runs in CI, however, you can also run the same checks with:
+
+ ```bash
+ $ make quality
+ ```
+
+Once you're happy with your changes, add changed files using `git add` and
+make a commit with `git commit` to record your changes locally:
+
+ ```bash
+ $ git add modified_file.py
+ $ git commit -m "A descriptive message about your changes."
+ ```
+
+It is a good idea to sync your copy of the code with the original
+repository regularly. This way you can quickly account for changes:
+
+ ```bash
+ $ git pull upstream main
+ ```
+
+Push the changes to your account using:
+
+ ```bash
+ $ git push -u origin a-descriptive-name-for-my-changes
+ ```
+
+6. Once you are satisfied, go to the
+webpage of your fork on GitHub. Click on 'Pull request' to send your changes
+to the project maintainers for review.
+
+7. It's OK if maintainers ask you for changes. It happens to core contributors
+too! So everyone can see the changes in the Pull request, work in your local
+branch and push the changes to your fork. They will automatically appear in
+the pull request.
+
+### Tests
+
+An extensive test suite is included to test the library behavior and several examples. Library tests can be found in
+the [tests folder](https://github.com/huggingface/diffusers/tree/main/tests).
+
+We like `pytest` and `pytest-xdist` because it's faster. From the root of the
+repository, here's how to run tests with `pytest` for the library:
+
+```bash
+$ python -m pytest -n auto --dist=loadfile -s -v ./tests/
+```
+
+In fact, that's how `make test` is implemented!
+
+You can specify a smaller set of tests in order to test only the feature
+you're working on.
+
+By default, slow tests are skipped. Set the `RUN_SLOW` environment variable to
+`yes` to run them. This will download many gigabytes of models — make sure you
+have enough disk space and a good Internet connection, or a lot of patience!
+
+```bash
+$ RUN_SLOW=yes python -m pytest -n auto --dist=loadfile -s -v ./tests/
+```
+
+`unittest` is fully supported, here's how to run tests with it:
+
+```bash
+$ python -m unittest discover -s tests -t . -v
+$ python -m unittest discover -s examples -t examples -v
+```
+
+### Syncing forked main with upstream (HuggingFace) main
+
+To avoid pinging the upstream repository which adds reference notes to each upstream PR and sends unnecessary notifications to the developers involved in these PRs,
+when syncing the main branch of a forked repository, please, follow these steps:
+1. When possible, avoid syncing with the upstream using a branch and PR on the forked repository. Instead, merge directly into the forked main.
+2. If a PR is absolutely necessary, use the following steps after checking out your branch:
+```bash
+$ git checkout -b your-branch-for-syncing
+$ git pull --squash --no-commit upstream main
+$ git commit -m '<your message without GitHub references>'
+$ git push --set-upstream origin your-branch-for-syncing
+```
+
+### Style guide
+
+For documentation strings, 🧨 Diffusers follows the [Google style](https://google.github.io/styleguide/pyguide.html).
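+
+As a rough illustration, a Google-style docstring for a hypothetical helper (the function and its arguments are made up for this sketch, not part of the library) could look like:
+
+```python
+def rescale_noise(noise, scale: float = 1.0):
+    """Rescales a noise tensor by a constant factor (hypothetical example).
+
+    Args:
+        noise (`torch.Tensor`): The noise tensor to rescale.
+        scale (`float`, *optional*, defaults to 1.0): The multiplicative rescaling factor.
+
+    Returns:
+        `torch.Tensor`: The rescaled noise tensor.
+    """
+    return noise * scale
+```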
diff --git a/diffusers/docs/source/en/conceptual/ethical_guidelines.md b/diffusers/docs/source/en/conceptual/ethical_guidelines.md
new file mode 100644
index 0000000000000000000000000000000000000000..86176bcaa33ec85929ec8ba3570e884d2f360d23
--- /dev/null
+++ b/diffusers/docs/source/en/conceptual/ethical_guidelines.md
@@ -0,0 +1,63 @@
+
+
+# 🧨 Diffusers’ Ethical Guidelines
+
+## Preamble
+
+[Diffusers](https://huggingface.co/docs/diffusers/index) provides pre-trained diffusion models and serves as a modular toolbox for inference and training.
+
+Given its real case applications in the world and potential negative impacts on society, we think it is important to provide the project with ethical guidelines to guide the development, users’ contributions, and usage of the Diffusers library.
+
+The risks associated with using this technology are still being examined, but to name a few: copyright issues for artists; deep-fake exploitation; sexual content generation in inappropriate contexts; non-consensual impersonation; harmful social biases perpetuating the oppression of marginalized groups.
+We will keep tracking risks and adapt the following guidelines based on the community's responsiveness and valuable feedback.
+
+
+## Scope
+
+The Diffusers community will apply the following ethical guidelines to the project’s development and help coordinate how the community will integrate the contributions, especially concerning sensitive topics related to ethical concerns.
+
+
+## Ethical guidelines
+
+The following ethical guidelines apply generally, but we will primarily implement them when dealing with ethically sensitive issues while making a technical choice. Furthermore, we commit to adapting those ethical principles over time following emerging harms related to the state of the art of the technology in question.
+
+- **Transparency**: we are committed to being transparent in managing PRs, explaining our choices to users, and making technical decisions.
+
+- **Consistency**: we are committed to guaranteeing our users the same level of attention in project management, keeping it technically stable and consistent.
+
+- **Simplicity**: with a desire to make it easy to use and exploit the Diffusers library, we are committed to keeping the project’s goals lean and coherent.
+
+- **Accessibility**: the Diffusers project helps lower the entry bar for contributors who can help run it even without technical expertise. Doing so makes research artifacts more accessible to the community.
+
+- **Reproducibility**: we aim to be transparent about the reproducibility of upstream code, models, and datasets when made available through the Diffusers library.
+
+- **Responsibility**: as a community and through teamwork, we hold a collective responsibility to our users by anticipating and mitigating this technology's potential risks and dangers.
+
+
+## Examples of implementations: Safety features and Mechanisms
+
+The team works daily to make the technical and non-technical tools available to deal with the potential ethical and social risks associated with diffusion technology. Moreover, the community's input is invaluable in ensuring these features' implementation and raising awareness with us.
+
+- [**Community tab**](https://huggingface.co/docs/hub/repositories-pull-requests-discussions): it enables the community to discuss and better collaborate on a project.
+
+- **Bias exploration and evaluation**: the Hugging Face team provides a [space](https://huggingface.co/spaces/society-ethics/DiffusionBiasExplorer) to demonstrate the biases in Stable Diffusion interactively. In this sense, we support and encourage bias explorers and evaluations.
+
+- **Encouraging safety in deployment**
+
+ - [**Safe Stable Diffusion**](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_safe): It mitigates the well-known issue that models, like Stable Diffusion, that are trained on unfiltered, web-crawled datasets tend to suffer from inappropriate degeneration. Related paper: [Safe Latent Diffusion: Mitigating Inappropriate Degeneration in Diffusion Models](https://arxiv.org/abs/2211.05105).
+
+ - [**Safety Checker**](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py): It checks and compares the class probability of a set of hard-coded harmful concepts in the embedding space against an image after it has been generated. The harmful concepts are intentionally hidden to prevent reverse engineering of the checker.
+
+- **Staged releases on the Hub**: in particularly sensitive situations, access to some repositories should be restricted. This staged release is an intermediary step that allows the repository’s authors to have more control over its use.
+
+- **Licensing**: [OpenRAILs](https://huggingface.co/blog/open_rail), a new type of licensing, allow us to ensure free access while having a set of restrictions that ensure more responsible use.
diff --git a/diffusers/docs/source/en/conceptual/evaluation.md b/diffusers/docs/source/en/conceptual/evaluation.md
new file mode 100644
index 0000000000000000000000000000000000000000..848eec8620cda4fcc8aa1efda04341b4fcd7c1fe
--- /dev/null
+++ b/diffusers/docs/source/en/conceptual/evaluation.md
@@ -0,0 +1,567 @@
+
+
+# Evaluating Diffusion Models
+
+
+
+
+
+Evaluation of generative models like [Stable Diffusion](https://huggingface.co/docs/diffusers/stable_diffusion) is subjective in nature. But as practitioners and researchers, we often have to make careful choices amongst many different possibilities. So, when working with different generative models (like GANs, Diffusion, etc.), how do we choose one over the other?
+
+Qualitative evaluation of such models can be error-prone and might incorrectly influence a decision.
+However, quantitative metrics don't necessarily correspond to image quality. So, usually, a combination
+of both qualitative and quantitative evaluations provides a stronger signal when choosing one model
+over the other.
+
+In this document, we provide a non-exhaustive overview of qualitative and quantitative methods to evaluate Diffusion models. For quantitative methods, we specifically focus on how to implement them alongside `diffusers`.
+
+The methods shown in this document can also be used to evaluate different [noise schedulers](https://huggingface.co/docs/diffusers/main/en/api/schedulers/overview) keeping the underlying generation model fixed.
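+
+For instance, a minimal sketch of swapping the scheduler while keeping the checkpoint fixed (the checkpoint and scheduler below are just illustrative choices) could look like this; the same evaluation is then rerun per scheduler:
+
+```python
+import torch
+from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
+
+pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
+
+# Swap in a different scheduler while the model weights stay fixed,
+# then compute the same metrics (CLIP score, FID, ...) for each scheduler variant.
+pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+```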
+
+## Scenarios
+
+We cover Diffusion models with the following pipelines:
+
+- Text-guided image generation (such as the [`StableDiffusionPipeline`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/text2img)).
+- Text-guided image generation, additionally conditioned on an input image (such as the [`StableDiffusionImg2ImgPipeline`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/img2img) and [`StableDiffusionInstructPix2PixPipeline`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/pix2pix)).
+- Class-conditioned image generation models (such as the [`DiTPipeline`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/dit)).
+
+## Qualitative Evaluation
+
+Qualitative evaluation typically involves human assessment of generated images. Quality is measured across aspects such as compositionality, image-text alignment, and spatial relations. Common prompts provide a degree of uniformity for subjective metrics.
+DrawBench and PartiPrompts are prompt datasets used for qualitative benchmarking; they were introduced by [Imagen](https://imagen.research.google/) and [Parti](https://parti.research.google/), respectively.
+
+From the [official Parti website](https://parti.research.google/):
+
+> PartiPrompts (P2) is a rich set of over 1600 prompts in English that we release as part of this work. P2 can be used to measure model capabilities across various categories and challenge aspects.
+
+![parti-prompts](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/parti-prompts.png)
+
+PartiPrompts has the following columns:
+
+- Prompt
+- Category of the prompt (such as “Abstract”, “World Knowledge”, etc.)
+- Challenge reflecting the difficulty (such as “Basic”, “Complex”, “Writing & Symbols”, etc.)
+
+These benchmarks allow for side-by-side human evaluation of different image generation models.
+
+For this, the 🧨 Diffusers team has built **Open Parti Prompts**, which is a community-driven qualitative benchmark based on Parti Prompts to compare state-of-the-art open-source diffusion models:
+- [Open Parti Prompts Game](https://huggingface.co/spaces/OpenGenAI/open-parti-prompts): For 10 parti prompts, 4 generated images are shown and the user selects the image that suits the prompt best.
+- [Open Parti Prompts Leaderboard](https://huggingface.co/spaces/OpenGenAI/parti-prompts-leaderboard): The leaderboard comparing the currently best open-sourced diffusion models to each other.
+
+To manually compare images, let’s see how we can use `diffusers` on a couple of PartiPrompts.
+
+Below we show some prompts sampled across different challenges: Basic, Complex, Linguistic Structures, Imagination, and Writing & Symbols. Here we are using PartiPrompts as a [dataset](https://huggingface.co/datasets/nateraw/parti-prompts).
+
+```python
+from datasets import load_dataset
+
+# prompts = load_dataset("nateraw/parti-prompts", split="train")
+# prompts = prompts.shuffle()
+# sample_prompts = [prompts[i]["Prompt"] for i in range(5)]
+
+# Fixing these sample prompts in the interest of reproducibility.
+sample_prompts = [
+ "a corgi",
+ "a hot air balloon with a yin-yang symbol, with the moon visible in the daytime sky",
+ "a car with no windows",
+ "a cube made of porcupine",
+ 'The saying "BE EXCELLENT TO EACH OTHER" written on a red brick wall with a graffiti image of a green alien wearing a tuxedo. A yellow fire hydrant is on a sidewalk in the foreground.',
+]
+```
+
+Now we can use these prompts to generate some images using Stable Diffusion ([v1-4 checkpoint](https://huggingface.co/CompVis/stable-diffusion-v1-4)):
+
+```python
+import torch
+from diffusers import StableDiffusionPipeline
+
+# Load the Stable Diffusion v1-4 checkpoint (the same one used in the CLIP score section below).
+sd_pipeline = StableDiffusionPipeline.from_pretrained(
+    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
+).to("cuda")
+
+seed = 0
+generator = torch.manual_seed(seed)
+
+images = sd_pipeline(sample_prompts, num_images_per_prompt=1, generator=generator).images
+```
+
+![parti-prompts-14](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/parti-prompts-14.png)
+
+We can also set `num_images_per_prompt` to a higher value to compare several generated images for the same prompt (see the short sketch below). Running the same pipeline but with a different checkpoint ([v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5)) yields:
+
+![parti-prompts-15](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/parti-prompts-15.png)
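+
+As a quick sketch of the `num_images_per_prompt` idea mentioned above (the number of images is an arbitrary choice for illustration):
+
+```python
+# Generate 4 candidate images per prompt with a fixed seed for a side-by-side comparison.
+generator = torch.manual_seed(0)
+images = sd_pipeline(sample_prompts, num_images_per_prompt=4, generator=generator).images
+# `images` is a flat list; images for the same prompt appear consecutively.
+```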
+
+Once several images are generated from all the prompts using multiple models (under evaluation), these results are presented to human evaluators for scoring. For
+more details on the DrawBench and PartiPrompts benchmarks, refer to their respective papers.
+
+
+
+It is useful to look at some inference samples while a model is training to measure the
+training progress. In our [training scripts](https://github.com/huggingface/diffusers/tree/main/examples/), we support this utility with additional support for
+logging to TensorBoard and Weights & Biases.
+
+
+
+## Quantitative Evaluation
+
+In this section, we will walk you through how to evaluate three different diffusion pipelines using:
+
+- CLIP score
+- CLIP directional similarity
+- FID
+
+### Text-guided image generation
+
+[CLIP score](https://arxiv.org/abs/2104.08718) measures the compatibility of image-caption pairs. Higher CLIP scores imply higher compatibility 🔼. The CLIP score is a quantitative measurement of the qualitative concept "compatibility". Image-caption pair compatibility can also be thought of as the semantic similarity between the image and the caption. CLIP score was found to have high correlation with human judgement.
+
+Let's first load a [`StableDiffusionPipeline`]:
+
+```python
+from diffusers import StableDiffusionPipeline
+import torch
+
+model_ckpt = "CompVis/stable-diffusion-v1-4"
+sd_pipeline = StableDiffusionPipeline.from_pretrained(model_ckpt, torch_dtype=torch.float16).to("cuda")
+```
+
+Generate some images with multiple prompts:
+
+```python
+prompts = [
+ "a photo of an astronaut riding a horse on mars",
+ "A high tech solarpunk utopia in the Amazon rainforest",
+ "A pikachu fine dining with a view to the Eiffel Tower",
+ "A mecha robot in a favela in expressionist style",
+ "an insect robot preparing a delicious meal",
+ "A small cabin on top of a snowy mountain in the style of Disney, artstation",
+]
+
+images = sd_pipeline(prompts, num_images_per_prompt=1, output_type="np").images
+
+print(images.shape)
+# (6, 512, 512, 3)
+```
+
+And then, we calculate the CLIP score.
+
+```python
+from torchmetrics.functional.multimodal import clip_score
+from functools import partial
+
+clip_score_fn = partial(clip_score, model_name_or_path="openai/clip-vit-base-patch16")
+
+def calculate_clip_score(images, prompts):
+ images_int = (images * 255).astype("uint8")
+ clip_score = clip_score_fn(torch.from_numpy(images_int).permute(0, 3, 1, 2), prompts).detach()
+ return round(float(clip_score), 4)
+
+sd_clip_score = calculate_clip_score(images, prompts)
+print(f"CLIP score: {sd_clip_score}")
+# CLIP score: 35.7038
+```
+
+In the above example, we generated one image per prompt. If we generated multiple images per prompt, we would have to take the average score over the images generated for each prompt, as sketched below.
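+
+A minimal sketch of that averaging, reusing `sd_pipeline`, `prompts`, and `calculate_clip_score` from above (the number of images per prompt is an arbitrary choice here):
+
+```python
+num_images_per_prompt = 4
+images = sd_pipeline(prompts, num_images_per_prompt=num_images_per_prompt, output_type="np").images
+
+# Repeat each prompt so every generated image is paired with its prompt,
+# then let the CLIP score helper average over all image-prompt pairs.
+expanded_prompts = [p for p in prompts for _ in range(num_images_per_prompt)]
+avg_clip_score = calculate_clip_score(images, expanded_prompts)
+print(f"Average CLIP score: {avg_clip_score}")
+```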
+
+Now, if we wanted to compare two checkpoints compatible with the [`StableDiffusionPipeline`], we should pass a generator while calling the pipeline. First, we generate images with a
+fixed seed with the [v1-4 Stable Diffusion checkpoint](https://huggingface.co/CompVis/stable-diffusion-v1-4):
+
+```python
+seed = 0
+generator = torch.manual_seed(seed)
+
+images = sd_pipeline(prompts, num_images_per_prompt=1, generator=generator, output_type="np").images
+```
+
+Then we load the [v1-5 checkpoint](https://huggingface.co/runwayml/stable-diffusion-v1-5) to generate images:
+
+```python
+model_ckpt_1_5 = "runwayml/stable-diffusion-v1-5"
+sd_pipeline_1_5 = StableDiffusionPipeline.from_pretrained(model_ckpt_1_5, torch_dtype=torch.float16).to("cuda")
+
+images_1_5 = sd_pipeline_1_5(prompts, num_images_per_prompt=1, generator=generator, output_type="np").images
+```
+
+And finally, we compare their CLIP scores:
+
+```python
+sd_clip_score_1_4 = calculate_clip_score(images, prompts)
+print(f"CLIP Score with v-1-4: {sd_clip_score_1_4}")
+# CLIP Score with v-1-4: 34.9102
+
+sd_clip_score_1_5 = calculate_clip_score(images_1_5, prompts)
+print(f"CLIP Score with v-1-5: {sd_clip_score_1_5}")
+# CLIP Score with v-1-5: 36.2137
+```
+
+It seems like the [v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) checkpoint performs better than its predecessor. Note, however, that the number of prompts we used to compute the CLIP scores is quite low. For a more practical evaluation, this number should be way higher, and the prompts should be diverse.
+
+
+
+By construction, there are some limitations in this score. The captions in the training dataset
+were crawled from the web and extracted from `alt` and similar tags associated with an image on the internet.
+They are not necessarily representative of what a human being would use to describe an image. Hence we
+had to "engineer" some prompts here.
+
+
+
+### Image-conditioned text-to-image generation
+
+In this case, we condition the generation pipeline with an input image as well as a text prompt. Let's take the [`StableDiffusionInstructPix2PixPipeline`] as an example. It takes an edit instruction as an input prompt and an input image to be edited.
+
+Here is one example:
+
+![edit-instruction](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/edit-instruction.png)
+
+One strategy to evaluate such a model is to measure the consistency of the change between the two images (in [CLIP](https://huggingface.co/docs/transformers/model_doc/clip) space) with the change between the two image captions (as shown in [CLIP-Guided Domain Adaptation of Image Generators](https://arxiv.org/abs/2108.00946)). This is referred to as the "**CLIP directional similarity**".
+
+- Caption 1 corresponds to the input image (image 1) that is to be edited.
+- Caption 2 corresponds to the edited image (image 2). It should reflect the edit instruction.
+
+Following is a pictorial overview:
+
+![edit-consistency](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/edit-consistency.png)
+
+We have prepared a mini dataset to implement this metric. Let's first load the dataset.
+
+```python
+from datasets import load_dataset
+
+dataset = load_dataset("sayakpaul/instructpix2pix-demo", split="train")
+dataset.features
+```
+
+```bash
+{'input': Value(dtype='string', id=None),
+ 'edit': Value(dtype='string', id=None),
+ 'output': Value(dtype='string', id=None),
+ 'image': Image(decode=True, id=None)}
+```
+
+Here we have:
+
+- `input` is a caption corresponding to the `image`.
+- `edit` denotes the edit instruction.
+- `output` denotes the modified caption reflecting the `edit` instruction.
+
+Let's take a look at a sample.
+
+```python
+idx = 0
+print(f"Original caption: {dataset[idx]['input']}")
+print(f"Edit instruction: {dataset[idx]['edit']}")
+print(f"Modified caption: {dataset[idx]['output']}")
+```
+
+```bash
+Original caption: 2. FAROE ISLANDS: An archipelago of 18 mountainous isles in the North Atlantic Ocean between Norway and Iceland, the Faroe Islands has 'everything you could hope for', according to Big 7 Travel. It boasts 'crystal clear waterfalls, rocky cliffs that seem to jut out of nowhere and velvety green hills'
+Edit instruction: make the isles all white marble
+Modified caption: 2. WHITE MARBLE ISLANDS: An archipelago of 18 mountainous white marble isles in the North Atlantic Ocean between Norway and Iceland, the White Marble Islands has 'everything you could hope for', according to Big 7 Travel. It boasts 'crystal clear waterfalls, rocky cliffs that seem to jut out of nowhere and velvety green hills'
+```
+
+And here is the image:
+
+```python
+dataset[idx]["image"]
+```
+
+![edit-dataset](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/edit-dataset.png)
+
+We will first edit the images of our dataset with the edit instruction and compute the directional similarity.
+
+Let's first load the [`StableDiffusionInstructPix2PixPipeline`]:
+
+```python
+import torch
+from diffusers import StableDiffusionInstructPix2PixPipeline
+
+device = "cuda"
+instruct_pix2pix_pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
+    "timbrooks/instruct-pix2pix", torch_dtype=torch.float16
+).to(device)
+```
+
+Now, we perform the edits:
+
+```python
+import numpy as np
+
+
+def edit_image(input_image, instruction):
+ image = instruct_pix2pix_pipeline(
+ instruction,
+ image=input_image,
+ output_type="np",
+ generator=generator,
+ ).images[0]
+ return image
+
+input_images = []
+original_captions = []
+modified_captions = []
+edited_images = []
+
+for idx in range(len(dataset)):
+ input_image = dataset[idx]["image"]
+ edit_instruction = dataset[idx]["edit"]
+ edited_image = edit_image(input_image, edit_instruction)
+
+ input_images.append(np.array(input_image))
+ original_captions.append(dataset[idx]["input"])
+ modified_captions.append(dataset[idx]["output"])
+ edited_images.append(edited_image)
+```
+
+To measure the directional similarity, we first load CLIP's image and text encoders:
+
+```python
+from transformers import (
+ CLIPTokenizer,
+ CLIPTextModelWithProjection,
+ CLIPVisionModelWithProjection,
+ CLIPImageProcessor,
+)
+
+clip_id = "openai/clip-vit-large-patch14"
+tokenizer = CLIPTokenizer.from_pretrained(clip_id)
+text_encoder = CLIPTextModelWithProjection.from_pretrained(clip_id).to(device)
+image_processor = CLIPImageProcessor.from_pretrained(clip_id)
+image_encoder = CLIPVisionModelWithProjection.from_pretrained(clip_id).to(device)
+```
+
+Notice that we are using a particular CLIP checkpoint, i.e., `openai/clip-vit-large-patch14`. This is because the Stable Diffusion pre-training was performed with this CLIP variant. For more details, refer to the [documentation](https://huggingface.co/docs/transformers/model_doc/clip).
+
+Next, we prepare a PyTorch `nn.Module` to compute directional similarity:
+
+```python
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class DirectionalSimilarity(nn.Module):
+ def __init__(self, tokenizer, text_encoder, image_processor, image_encoder):
+ super().__init__()
+ self.tokenizer = tokenizer
+ self.text_encoder = text_encoder
+ self.image_processor = image_processor
+ self.image_encoder = image_encoder
+
+ def preprocess_image(self, image):
+ image = self.image_processor(image, return_tensors="pt")["pixel_values"]
+ return {"pixel_values": image.to(device)}
+
+ def tokenize_text(self, text):
+ inputs = self.tokenizer(
+ text,
+ max_length=self.tokenizer.model_max_length,
+ padding="max_length",
+ truncation=True,
+ return_tensors="pt",
+ )
+ return {"input_ids": inputs.input_ids.to(device)}
+
+ def encode_image(self, image):
+ preprocessed_image = self.preprocess_image(image)
+ image_features = self.image_encoder(**preprocessed_image).image_embeds
+ image_features = image_features / image_features.norm(dim=1, keepdim=True)
+ return image_features
+
+ def encode_text(self, text):
+ tokenized_text = self.tokenize_text(text)
+ text_features = self.text_encoder(**tokenized_text).text_embeds
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
+ return text_features
+
+ def compute_directional_similarity(self, img_feat_one, img_feat_two, text_feat_one, text_feat_two):
+ sim_direction = F.cosine_similarity(img_feat_two - img_feat_one, text_feat_two - text_feat_one)
+ return sim_direction
+
+ def forward(self, image_one, image_two, caption_one, caption_two):
+ img_feat_one = self.encode_image(image_one)
+ img_feat_two = self.encode_image(image_two)
+ text_feat_one = self.encode_text(caption_one)
+ text_feat_two = self.encode_text(caption_two)
+ directional_similarity = self.compute_directional_similarity(
+ img_feat_one, img_feat_two, text_feat_one, text_feat_two
+ )
+ return directional_similarity
+```
+
+Let's put `DirectionalSimilarity` to use now.
+
+```python
+dir_similarity = DirectionalSimilarity(tokenizer, text_encoder, image_processor, image_encoder)
+scores = []
+
+for i in range(len(input_images)):
+ original_image = input_images[i]
+ original_caption = original_captions[i]
+ edited_image = edited_images[i]
+ modified_caption = modified_captions[i]
+
+ similarity_score = dir_similarity(original_image, edited_image, original_caption, modified_caption)
+ scores.append(float(similarity_score.detach().cpu()))
+
+print(f"CLIP directional similarity: {np.mean(scores)}")
+# CLIP directional similarity: 0.0797976553440094
+```
+
+Like the CLIP Score, the higher the CLIP directional similarity, the better it is.
+
+It should be noted that the `StableDiffusionInstructPix2PixPipeline` exposes two arguments, namely, `image_guidance_scale` and `guidance_scale` that let you control the quality of the final edited image. We encourage you to experiment with these two arguments and see the impact of that on the directional similarity.
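+
+For example, a sketch of such an experiment on a single sample (the scale values below are arbitrary starting points, not recommendations):
+
+```python
+# Re-run one edit with explicit guidance settings and recompute the directional similarity.
+edited_image = instruct_pix2pix_pipeline(
+    dataset[0]["edit"],
+    image=dataset[0]["image"],
+    guidance_scale=7.5,
+    image_guidance_scale=1.5,
+    output_type="np",
+    generator=generator,
+).images[0]
+score = dir_similarity(np.array(dataset[0]["image"]), edited_image, dataset[0]["input"], dataset[0]["output"])
+print(float(score.detach().cpu()))
+```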
+
+We can extend the idea of this metric to measure how similar the original image and edited version are. To do that, we can just do `F.cosine_similarity(img_feat_two, img_feat_one)`. For these kinds of edits, we would still want the primary semantics of the images to be preserved as much as possible, i.e., a high similarity score.
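+
+A minimal sketch of that image-image similarity, reusing the encoders wrapped in `DirectionalSimilarity` above:
+
+```python
+# Cosine similarity between the CLIP image embeddings of the original and edited images.
+img_feat_one = dir_similarity.encode_image(input_images[0])
+img_feat_two = dir_similarity.encode_image(edited_images[0])
+image_similarity = F.cosine_similarity(img_feat_two, img_feat_one)
+print(float(image_similarity.detach().cpu()))
+```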
+
+We can use these metrics for similar pipelines such as the [`StableDiffusionPix2PixZeroPipeline`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/pix2pix_zero#diffusers.StableDiffusionPix2PixZeroPipeline).
+
+Note that both the CLIP score and the CLIP directional similarity rely on the CLIP model, which can make the evaluations biased.
+
+***Extending metrics like IS, FID (discussed later), or KID can be difficult*** when the model under evaluation was pre-trained on a large image-captioning dataset (such as the [LAION-5B dataset](https://laion.ai/blog/laion-5b/)). This is because underlying these metrics is an InceptionNet (pre-trained on the ImageNet-1k dataset) used for extracting intermediate image features. The pre-training dataset of Stable Diffusion may have limited overlap with the pre-training dataset of InceptionNet, so it is not a good candidate here for feature extraction.
+
+***The above metrics do, however, help evaluate class-conditioned models such as [DiT](https://huggingface.co/docs/diffusers/main/en/api/pipelines/dit), which was pre-trained conditioned on the ImageNet-1k classes.***
+
+### Class-conditioned image generation
+
+Class-conditioned generative models are usually pre-trained on a class-labeled dataset such as [ImageNet-1k](https://huggingface.co/datasets/imagenet-1k). Popular metrics for evaluating these models include Fréchet Inception Distance (FID), Kernel Inception Distance (KID), and Inception Score (IS). In this document, we focus on FID ([Heusel et al.](https://arxiv.org/abs/1706.08500)). We show how to compute it with the [`DiTPipeline`](https://huggingface.co/docs/diffusers/api/pipelines/dit), which uses the [DiT model](https://arxiv.org/abs/2212.09748) under the hood.
+
+FID aims to measure how similar two datasets of images are. As per [this resource](https://mmgeneration.readthedocs.io/en/latest/quick_run.html#fid):
+
+> Fréchet Inception Distance is a measure of similarity between two datasets of images. It was shown to correlate well with the human judgment of visual quality and is most often used to evaluate the quality of samples of Generative Adversarial Networks. FID is calculated by computing the Fréchet distance between two Gaussians fitted to feature representations of the Inception network.
+
+These two datasets are essentially the dataset of real images and the dataset of fake images (generated images in our case). FID is usually calculated with two large datasets. However, for this document, we will work with two mini datasets.
+
+Let's first download a few images from the ImageNet-1k training set:
+
+```python
+from zipfile import ZipFile
+import requests
+
+
+def download(url, local_filepath):
+ r = requests.get(url)
+ with open(local_filepath, "wb") as f:
+ f.write(r.content)
+ return local_filepath
+
+dummy_dataset_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/sample-imagenet-images.zip"
+local_filepath = download(dummy_dataset_url, dummy_dataset_url.split("/")[-1])
+
+with ZipFile(local_filepath, "r") as zipper:
+ zipper.extractall(".")
+```
+
+```python
+from PIL import Image
+import os
+
+dataset_path = "sample-imagenet-images"
+image_paths = sorted([os.path.join(dataset_path, x) for x in os.listdir(dataset_path)])
+
+real_images = [np.array(Image.open(path).convert("RGB")) for path in image_paths]
+```
+
+These are 10 images from the following ImageNet-1k classes: "cassette_player", "chain_saw" (x2), "church", "gas_pump" (x3), "parachute" (x2), and "tench".
+
+*Real images.*
+
+Now that the images are loaded, let's apply some lightweight pre-processing on them to use them for FID calculation.
+
+```python
+from torchvision.transforms import functional as F
+
+
+def preprocess_image(image):
+ image = torch.tensor(image).unsqueeze(0)
+ image = image.permute(0, 3, 1, 2) / 255.0
+ return F.center_crop(image, (256, 256))
+
+real_images = torch.cat([preprocess_image(image) for image in real_images])
+print(real_images.shape)
+# torch.Size([10, 3, 256, 256])
+```
+
+We now load the [`DiTPipeline`](https://huggingface.co/docs/diffusers/api/pipelines/dit) to generate images conditioned on the above-mentioned classes.
+
+```python
+from diffusers import DiTPipeline, DPMSolverMultistepScheduler
+
+dit_pipeline = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", torch_dtype=torch.float16)
+dit_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(dit_pipeline.scheduler.config)
+dit_pipeline = dit_pipeline.to("cuda")
+
+words = [
+ "cassette player",
+ "chainsaw",
+ "chainsaw",
+ "church",
+ "gas pump",
+ "gas pump",
+ "gas pump",
+ "parachute",
+ "parachute",
+ "tench",
+]
+
+class_ids = dit_pipeline.get_label_ids(words)
+output = dit_pipeline(class_labels=class_ids, generator=generator, output_type="np")
+
+fake_images = output.images
+fake_images = torch.tensor(fake_images)
+fake_images = fake_images.permute(0, 3, 1, 2)
+print(fake_images.shape)
+# torch.Size([10, 3, 256, 256])
+```
+
+Now, we can compute the FID using [`torchmetrics`](https://torchmetrics.readthedocs.io/).
+
+```python
+from torchmetrics.image.fid import FrechetInceptionDistance
+
+fid = FrechetInceptionDistance(normalize=True)
+fid.update(real_images, real=True)
+fid.update(fake_images, real=False)
+
+print(f"FID: {float(fid.compute())}")
+# FID: 177.7147216796875
+```
+
+The lower the FID, the better it is. Several things can influence FID here:
+
+- Number of images (both real and fake)
+- Randomness induced in the diffusion process
+- Number of inference steps in the diffusion process
+- The scheduler being used in the diffusion process
+
+It is therefore good practice, especially for the last two points, to run the evaluation across different seeds and inference steps and then report an average result.
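+
+A minimal sketch (not part of the original walkthrough) of averaging FID over a few seeds, reusing `real_images`, `class_ids`, and `dit_pipeline` from above:
+
+```python
+import numpy as np
+import torch
+from torchmetrics.image.fid import FrechetInceptionDistance
+
+fid_scores = []
+for seed in (0, 1, 2):
+    generator = torch.manual_seed(seed)
+    output = dit_pipeline(class_labels=class_ids, generator=generator, output_type="np")
+    fake = torch.tensor(output.images).permute(0, 3, 1, 2)
+
+    # A fresh metric instance per run so feature statistics don't mix across seeds.
+    fid = FrechetInceptionDistance(normalize=True)
+    fid.update(real_images, real=True)
+    fid.update(fake, real=False)
+    fid_scores.append(float(fid.compute()))
+
+print(f"Mean FID over seeds: {np.mean(fid_scores):.2f}")
+```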
+
+FID results tend to be fragile as they depend on many factors:
+
+* The specific Inception model used during computation.
+* The implementation accuracy of the computation.
+* The image format (not the same if we start from PNGs vs JPGs).
+
+Keeping that in mind, FID is often most useful when comparing similar runs, but it is
+hard to reproduce paper results unless the authors carefully disclose the FID
+measurement code.
+
+These points apply to other related metrics too, such as KID and IS.
+
+As a final step, let's visually inspect the `fake_images`.
+
+*Fake images.*
+
diff --git a/diffusers/docs/source/en/conceptual/philosophy.md b/diffusers/docs/source/en/conceptual/philosophy.md
new file mode 100644
index 0000000000000000000000000000000000000000..c7b96abd7f1114705a77ca691e2de95b1aab6aa6
--- /dev/null
+++ b/diffusers/docs/source/en/conceptual/philosophy.md
@@ -0,0 +1,110 @@
+
+
+# Philosophy
+
+🧨 Diffusers provides **state-of-the-art** pretrained diffusion models across multiple modalities.
+Its purpose is to serve as a **modular toolbox** for both inference and training.
+
+We aim at building a library that stands the test of time and therefore take API design very seriously.
+
+In a nutshell, Diffusers is built to be a natural extension of PyTorch. Therefore, most of our design choices are based on [PyTorch's Design Principles](https://pytorch.org/docs/stable/community/design.html#pytorch-design-philosophy). Let's go over the most important ones:
+
+## Usability over Performance
+
+- While Diffusers has many built-in performance-enhancing features (see [Memory and Speed](https://huggingface.co/docs/diffusers/optimization/fp16)), models are always loaded with the highest precision and lowest optimization. Therefore, by default diffusion pipelines are always instantiated on CPU with float32 precision if not otherwise defined by the user. This ensures usability across different platforms and accelerators and means that no complex installations are required to run the library.
+- Diffusers aims to be a **light-weight** package and therefore has very few required dependencies, but many soft dependencies that can improve performance (such as `accelerate`, `safetensors`, `onnx`, etc...). We strive to keep the library as lightweight as possible so that it can be added without much concern as a dependency on other packages.
+- Diffusers prefers simple, self-explainable code over condensed, magic code. This means that short-hand code syntaxes such as lambda functions, and advanced PyTorch operators are often not desired.
+
+## Simple over easy
+
+As PyTorch states, **explicit is better than implicit** and **simple is better than complex**. This design philosophy is reflected in multiple parts of the library:
+- We follow PyTorch's API with methods like [`DiffusionPipeline.to`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.to) to let the user handle device management.
+- Raising concise error messages is preferred to silently correct erroneous input. Diffusers aims at teaching the user, rather than making the library as easy to use as possible.
+- Complex model vs. scheduler logic is exposed instead of magically handled inside. Schedulers/Samplers are separated from diffusion models with minimal dependencies on each other. This forces the user to write the unrolled denoising loop. However, the separation allows for easier debugging and gives the user more control over adapting the denoising process or switching out diffusion models or schedulers.
+- Separately trained components of the diffusion pipeline, *e.g.* the text encoder, the unet, and the variational autoencoder, each have their own model class. This forces the user to handle the interaction between the different model components, and the serialization format separates the model components into different files. However, this allows for easier debugging and customization. DreamBooth or Textual Inversion training
+is very simple thanks to Diffusers' ability to separate single components of the diffusion pipeline.
+
+## Tweakable, contributor-friendly over abstraction
+
+For large parts of the library, Diffusers adopts an important design principle of the [Transformers library](https://github.com/huggingface/transformers), which is to prefer copy-pasted code over hasty abstractions. This design principle is very opinionated and stands in stark contrast to popular design principles such as [Don't repeat yourself (DRY)](https://en.wikipedia.org/wiki/Don%27t_repeat_yourself).
+In short, just like Transformers does for modeling files, Diffusers prefers to keep an extremely low level of abstraction and very self-contained code for pipelines and schedulers.
+Functions, long code blocks, and even classes can be copied across multiple files which at first can look like a bad, sloppy design choice that makes the library unmaintainable.
+**However**, this design has proven to be extremely successful for Transformers and makes a lot of sense for community-driven, open-source machine learning libraries because:
+- Machine Learning is an extremely fast-moving field in which paradigms, model architectures, and algorithms are changing rapidly, which therefore makes it very difficult to define long-lasting code abstractions.
+- Machine Learning practitioners like to be able to quickly tweak existing code for ideation and research and therefore prefer self-contained code over one that contains many abstractions.
+- Open-source libraries rely on community contributions and therefore must build a library that is easy to contribute to. The more abstract the code, the more dependencies, the harder to read, and the harder to contribute to. Contributors simply stop contributing to very abstract libraries out of fear of breaking vital functionality. If contributing to a library cannot break other fundamental code, not only is it more inviting for potential new contributors, but it is also easier to review and contribute to multiple parts in parallel.
+
+At Hugging Face, we call this design the **single-file policy** which means that almost all of the code of a certain class should be written in a single, self-contained file. To read more about the philosophy, you can have a look
+at [this blog post](https://huggingface.co/blog/transformers-design-philosophy).
+
+In Diffusers, we follow this philosophy for both pipelines and schedulers, but only partly for diffusion models. The reason we don't follow this design fully for diffusion models is because almost all diffusion pipelines, such
+as [DDPM](https://huggingface.co/docs/diffusers/api/pipelines/ddpm), [Stable Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview#stable-diffusion-pipelines), [unCLIP (DALL·E 2)](https://huggingface.co/docs/diffusers/api/pipelines/unclip) and [Imagen](https://imagen.research.google/) all rely on the same diffusion model, the [UNet](https://huggingface.co/docs/diffusers/api/models/unet2d-cond).
+
+Great, now you should have generally understood why 🧨 Diffusers is designed the way it is 🤗.
+We try to apply these design principles consistently across the library. Nevertheless, there are some minor exceptions to the philosophy or some unlucky design choices. If you have feedback regarding the design, we would ❤️ to hear it [directly on GitHub](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=).
+
+## Design Philosophy in Details
+
+Now, let's look a bit into the nitty-gritty details of the design philosophy. Diffusers essentially consists of three major classes: [pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines), [models](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models), and [schedulers](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers).
+Let's walk through more in-detail design decisions for each class.
+
+### Pipelines
+
+Pipelines are designed to be easy to use (therefore do not follow [*Simple over easy*](#simple-over-easy) 100%), are not feature complete, and should loosely be seen as examples of how to use [models](#models) and [schedulers](#schedulers) for inference.
+
+The following design principles are followed:
+- Pipelines follow the single-file policy. All pipelines can be found in individual directories under src/diffusers/pipelines. One pipeline folder corresponds to one diffusion paper/project/release. Multiple pipeline files can be gathered in one pipeline folder, as it’s done for [`src/diffusers/pipelines/stable_diffusion`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/stable_diffusion). If pipelines share similar functionality, one can make use of the [#Copied from mechanism](https://github.com/huggingface/diffusers/blob/125d783076e5bd9785beb05367a2d2566843a271/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py#L251).
+- Pipelines all inherit from [`DiffusionPipeline`].
+- Every pipeline consists of different model and scheduler components that are documented in the [`model_index.json` file](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/model_index.json), are accessible under the same names as attributes of the pipeline, and can be shared between pipelines with the [`DiffusionPipeline.components`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.components) function (see the sketch after this list).
+- Every pipeline should be loadable via the [`DiffusionPipeline.from_pretrained`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained) function.
+- Pipelines should be used **only** for inference.
+- Pipelines should be very readable, self-explanatory, and easy to tweak.
+- Pipelines should be designed to build on top of each other and be easy to integrate into higher-level APIs.
+- Pipelines are **not** intended to be feature-complete user interfaces. For future complete user interfaces one should rather have a look at [InvokeAI](https://github.com/invoke-ai/InvokeAI), [Diffuzers](https://github.com/abhishekkrthakur/diffuzers), and [lama-cleaner](https://github.com/Sanster/lama-cleaner).
+- Every pipeline should have one and only one way to run it via a `__call__` method. The naming of the `__call__` arguments should be shared across all pipelines.
+- Pipelines should be named after the task they are intended to solve.
+- In almost all cases, novel diffusion pipelines shall be implemented in a new pipeline folder/file.
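+
+As a small illustration of the component-sharing point above (the model id is just an example), components loaded once can be reused to assemble a second pipeline without loading the weights twice:
+
+```python
+from diffusers import StableDiffusionImg2ImgPipeline, StableDiffusionPipeline
+
+# Build an image-to-image pipeline from the components of a text-to-image pipeline.
+text2img = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+img2img = StableDiffusionImg2ImgPipeline(**text2img.components)
+```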
+
+### Models
+
+Models are designed as configurable toolboxes that are natural extensions of [PyTorch's Module class](https://pytorch.org/docs/stable/generated/torch.nn.Module.html). They only partly follow the **single-file policy**.
+
+The following design principles are followed:
+- Models correspond to **a type of model architecture**. *E.g.* the [`UNet2DConditionModel`] class is used for all UNet variations that expect 2D image inputs and are conditioned on some context.
+- All models can be found in [`src/diffusers/models`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models) and every model architecture shall be defined in its file, e.g. [`unet_2d_condition.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py), [`transformer_2d.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformer_2d.py), etc...
+- Models **do not** follow the single-file policy and should make use of smaller model building blocks, such as [`attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py), [`resnet.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py), [`embeddings.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py), etc... **Note**: This is in stark contrast to Transformers' modeling files and shows that models do not really follow the single-file policy.
+- Models intend to expose complexity, just like PyTorch's `Module` class, and give clear error messages.
+- Models all inherit from `ModelMixin` and `ConfigMixin`.
+- Models can be optimized for performance when it doesn’t demand major code changes, keeps backward compatibility, and gives significant memory or compute gain.
+- Models should by default have the highest precision and lowest performance setting.
+- To integrate new model checkpoints whose general architecture can be classified as an architecture that already exists in Diffusers, the existing model architecture shall be adapted to make it work with the new checkpoint. One should only create a new file if the model architecture is fundamentally different.
+- Models should be designed to be easily extendable to future changes. This can be achieved by limiting public function arguments, configuration arguments, and "foreseeing" future changes, *e.g.* it is usually better to add `string` "...type" arguments that can easily be extended to new future types instead of boolean `is_..._type` arguments (see the sketch after this list). Only the minimum amount of changes shall be made to existing architectures to make a new model checkpoint work.
+- The model design is a difficult trade-off between keeping code readable and concise and supporting many model checkpoints. For most parts of the modeling code, classes shall be adapted for new model checkpoints, while there are some exceptions where it is preferred to add new classes to make sure the code is kept concise and
+readable long-term, such as [UNet blocks](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py) and [Attention processors](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
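+
+To make the `"...type"` argument point above concrete, here is a purely illustrative sketch; the class and argument names are hypothetical and not actual Diffusers APIs:
+
+```python
+# A string "norm_type" argument can later grow to support new variants
+# (e.g. "group_norm", "rms_norm") without changing the signature,
+# unlike a boolean flag such as `use_group_norm`.
+class ExampleBlock:
+    def __init__(self, norm_type: str = "layer_norm"):
+        if norm_type not in {"layer_norm", "group_norm"}:
+            raise ValueError(f"Unsupported norm_type: {norm_type}")
+        self.norm_type = norm_type
+```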
+
+### Schedulers
+
+Schedulers are responsible to guide the denoising process for inference as well as to define a noise schedule for training. They are designed as individual classes with loadable configuration files and strongly follow the **single-file policy**.
+
+The following design principles are followed:
+- All schedulers are found in [`src/diffusers/schedulers`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers).
+- Schedulers are **not** allowed to import from large utils files and shall be kept very self-contained.
+- One scheduler Python file corresponds to one scheduler algorithm (as might be defined in a paper).
+- If schedulers share similar functionalities, we can make use of the `#Copied from` mechanism.
+- Schedulers all inherit from `SchedulerMixin` and `ConfigMixin`.
+- Schedulers can be easily swapped out with the [`ConfigMixin.from_config`](https://huggingface.co/docs/diffusers/main/en/api/configuration#diffusers.ConfigMixin.from_config) method as explained in detail [here](../using-diffusers/schedulers.md) (a one-line swap is sketched after this list).
+- Every scheduler has to have a `set_timesteps` (which takes the number of inference steps) and a `step` function. `set_timesteps(...)` has to be called before every denoising process, *i.e.* before `step(...)` is called.
+- Every scheduler exposes the timesteps to be "looped over" via a `timesteps` attribute, which is an array of timesteps the model will be called upon.
+- The `step(...)` function takes a predicted model output and the "current" sample (x_t) and returns the "previous", slightly more denoised sample (x_t-1).
+- Given the complexity of diffusion schedulers, the `step` function does not expose all the complexity and can be a bit of a "black box".
+- In almost all cases, novel schedulers shall be implemented in a new scheduling file.
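+
+As a tiny illustration of the swap-out point above (the scheduler and model ids are only examples):
+
+```python
+from diffusers import DiffusionPipeline, EulerDiscreteScheduler
+
+# Replace the pipeline's scheduler in place, reusing the existing scheduler's config.
+pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
+```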
diff --git a/diffusers/docs/source/en/imgs/access_request.png b/diffusers/docs/source/en/imgs/access_request.png
new file mode 100644
index 0000000000000000000000000000000000000000..33c6abc88dfb226e929b44c30c173c787b407045
Binary files /dev/null and b/diffusers/docs/source/en/imgs/access_request.png differ
diff --git a/diffusers/docs/source/en/imgs/diffusers_library.jpg b/diffusers/docs/source/en/imgs/diffusers_library.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..07ba9c6571a3f070d9d10b78dccfd4d4537dd539
Binary files /dev/null and b/diffusers/docs/source/en/imgs/diffusers_library.jpg differ
diff --git a/diffusers/docs/source/en/index.md b/diffusers/docs/source/en/index.md
new file mode 100644
index 0000000000000000000000000000000000000000..ce6e79ee44d11cc58f4fb18e3dd40af6a560432f
--- /dev/null
+++ b/diffusers/docs/source/en/index.md
@@ -0,0 +1,48 @@
+
+# Diffusers
+
+🤗 Diffusers is the go-to library for state-of-the-art pretrained diffusion models for generating images, audio, and even 3D structures of molecules. Whether you're looking for a simple inference solution or want to train your own diffusion model, 🤗 Diffusers is a modular toolbox that supports both. Our library is designed with a focus on [usability over performance](conceptual/philosophy#usability-over-performance), [simple over easy](conceptual/philosophy#simple-over-easy), and [customizability over abstractions](conceptual/philosophy#tweakable-contributorfriendly-over-abstraction).
+
+The library has three main components:
+
+- State-of-the-art diffusion pipelines for inference with just a few lines of code. There are many pipelines in 🤗 Diffusers, check out the table in the pipeline [overview](api/pipelines/overview) for a complete list of available pipelines and the task they solve.
+- Interchangeable [noise schedulers](api/schedulers/overview) for balancing trade-offs between generation speed and quality.
+- Pretrained [models](api/models) that can be used as building blocks, and combined with schedulers, for creating your own end-to-end diffusion systems.
+
+
diff --git a/diffusers/docs/source/en/installation.md b/diffusers/docs/source/en/installation.md
new file mode 100644
index 0000000000000000000000000000000000000000..3bf1d46fd0c7e05ccf9c0a066e72191da4c31b9e
--- /dev/null
+++ b/diffusers/docs/source/en/installation.md
@@ -0,0 +1,162 @@
+
+
+# Installation
+
+🤗 Diffusers is tested on Python 3.8+, PyTorch 1.7.0+, and Flax. Follow the installation instructions below for the deep learning library you are using:
+
+- [PyTorch](https://pytorch.org/get-started/locally/) installation instructions
+- [Flax](https://flax.readthedocs.io/en/latest/) installation instructions
+
+## Install with pip
+
+You should install 🤗 Diffusers in a [virtual environment](https://docs.python.org/3/library/venv.html).
+If you're unfamiliar with Python virtual environments, take a look at this [guide](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
+A virtual environment makes it easier to manage different projects and avoid compatibility issues between dependencies.
+
+Start by creating a virtual environment in your project directory:
+
+```bash
+python -m venv .env
+```
+
+Activate the virtual environment:
+
+```bash
+source .env/bin/activate
+```
+
+You should also install 🤗 Transformers because 🤗 Diffusers relies on its models:
+
+For PyTorch:
+
+```bash
+pip install diffusers["torch"] transformers
+```
+
+For Flax:
+
+```bash
+pip install diffusers["flax"] transformers
+```
+
+## Install with conda
+
+After activating your virtual environment, install 🤗 Diffusers with `conda` (maintained by the community):
+
+```bash
+conda install -c conda-forge diffusers
+```
+
+## Install from source
+
+Before installing 🤗 Diffusers from source, make sure you have PyTorch and 🤗 Accelerate installed.
+
+To install 🤗 Accelerate:
+
+```bash
+pip install accelerate
+```
+
+Then install 🤗 Diffusers from source:
+
+```bash
+pip install git+https://github.com/huggingface/diffusers
+```
+
+This command installs the bleeding edge `main` version rather than the latest `stable` version.
+The `main` version is useful for staying up-to-date with the latest developments.
+For instance, if a bug has been fixed since the last official release but a new release hasn't been rolled out yet.
+However, this means the `main` version may not always be stable.
+We strive to keep the `main` version operational, and most issues are usually resolved within a few hours or a day.
+If you run into a problem, please open an [Issue](https://github.com/huggingface/diffusers/issues/new/choose) so we can fix it even sooner!
+
+## Editable install
+
+You will need an editable install if you'd like to:
+
+* Use the `main` version of the source code.
+* Contribute to 🤗 Diffusers and need to test changes in the code.
+
+Clone the repository and install 🤗 Diffusers with the following commands:
+
+```bash
+git clone https://github.com/huggingface/diffusers.git
+cd diffusers
+```
+
+For PyTorch:
+
+```bash
+pip install -e ".[torch]"
+```
+
+For Flax:
+
+```bash
+pip install -e ".[flax]"
+```
+
+These commands link the cloned repository folder to your Python library paths.
+Python will now look inside the folder you cloned to in addition to the normal library paths.
+For example, if your Python packages are typically installed in `~/anaconda3/envs/main/lib/python3.8/site-packages/`, Python will also search the `~/diffusers/` folder you cloned to.
+
+
+
+You must keep the `diffusers` folder if you want to keep using the library.
+
+
+
+Now you can easily update your clone to the latest version of 🤗 Diffusers with the following command:
+
+```bash
+cd ~/diffusers/
+git pull
+```
+
+Your Python environment will find the `main` version of 🤗 Diffusers on the next run.
+
+## Cache
+
+Model weights and files are downloaded from the Hub to a cache, which is usually your home directory. You can change the cache location by specifying the `HF_HOME` or `HUGGINGFACE_HUB_CACHE` environment variables or configuring the `cache_dir` parameter in methods like [`~DiffusionPipeline.from_pretrained`].
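+
+For example, a minimal sketch of pointing the cache at a custom directory (the path is illustrative):
+
+```python
+from diffusers import DiffusionPipeline
+
+# Downloads (or reuses) the checkpoint under ./my_diffusers_cache instead of the default cache.
+pipeline = DiffusionPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5", cache_dir="./my_diffusers_cache"
+)
+```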
+
+Cached files allow you to run 🤗 Diffusers offline. To prevent 🤗 Diffusers from connecting to the internet, set the `HF_HUB_OFFLINE` environment variable to `True` and 🤗 Diffusers will only load previously downloaded files in the cache.
+
+```shell
+export HF_HUB_OFFLINE=True
+```
+
+For more details about managing and cleaning the cache, take a look at the [caching](https://huggingface.co/docs/huggingface_hub/guides/manage-cache) guide.
+
+## Telemetry logging
+
+Our library gathers telemetry information during [`~DiffusionPipeline.from_pretrained`] requests.
+The data gathered includes the version of 🤗 Diffusers and PyTorch/Flax, the requested model or pipeline class,
+and the path to a pretrained checkpoint if it is hosted on the Hugging Face Hub.
+This usage data helps us debug issues and prioritize new features.
+Telemetry is only sent when loading models and pipelines from the Hub,
+and it is not collected if you're loading local files.
+
+We understand that not everyone wants to share additional information, and we respect your privacy.
+You can disable telemetry collection by setting the `DISABLE_TELEMETRY` environment variable from your terminal:
+
+On Linux/MacOS:
+```bash
+export DISABLE_TELEMETRY=YES
+```
+
+On Windows:
+```bash
+set DISABLE_TELEMETRY=YES
+```
diff --git a/diffusers/docs/source/en/optimization/coreml.md b/diffusers/docs/source/en/optimization/coreml.md
new file mode 100644
index 0000000000000000000000000000000000000000..62809305bfb02c77e0306b5574c7cba8c7722d93
--- /dev/null
+++ b/diffusers/docs/source/en/optimization/coreml.md
@@ -0,0 +1,164 @@
+
+
+# How to run Stable Diffusion with Core ML
+
+[Core ML](https://developer.apple.com/documentation/coreml) is the model format and machine learning library supported by Apple frameworks. If you are interested in running Stable Diffusion models inside your macOS or iOS/iPadOS apps, this guide will show you how to convert existing PyTorch checkpoints into the Core ML format and use them for inference with Python or Swift.
+
+Core ML models can leverage all the compute engines available in Apple devices: the CPU, the GPU, and the Apple Neural Engine (or ANE, a tensor-optimized accelerator available in Apple Silicon Macs and modern iPhones/iPads). Depending on the model and the device it's running on, Core ML can mix and match compute engines too, so some portions of the model may run on the CPU while others run on GPU, for example.
+
+
+
+You can also run the `diffusers` Python codebase on Apple Silicon Macs using the `mps` accelerator built into PyTorch. This approach is explained in depth in [the mps guide](mps), but it is not compatible with native apps.
+
+
+
+## Stable Diffusion Core ML Checkpoints
+
+Stable Diffusion weights (or checkpoints) are stored in the PyTorch format, so you need to convert them to the Core ML format before you can use them inside native apps.
+
+Thankfully, Apple engineers developed [a conversion tool](https://github.com/apple/ml-stable-diffusion#-converting-models-to-core-ml) based on `diffusers` to convert the PyTorch checkpoints to Core ML.
+
+Before you convert a model, though, take a moment to explore the Hugging Face Hub – chances are the model you're interested in is already available in Core ML format:
+
+- the [Apple](https://huggingface.co/apple) organization includes Stable Diffusion versions 1.4, 1.5, 2.0 base, and 2.1 base
+- [coreml community](https://huggingface.co/coreml-community) includes custom finetuned models
+- use this [filter](https://huggingface.co/models?pipeline_tag=text-to-image&library=coreml&p=2&sort=likes) to return all available Core ML checkpoints
+
+If you can't find the model you're interested in, we recommend you follow the instructions for [Converting Models to Core ML](https://github.com/apple/ml-stable-diffusion#-converting-models-to-core-ml) by Apple.
+
+## Selecting the Core ML Variant to Use
+
+Stable Diffusion models can be converted to different Core ML variants intended for different purposes:
+
+- The type of attention blocks used. The attention operation is used to "pay attention" to the relationship between different areas in the image representations and to understand how the image and text representations are related. Attention is compute- and memory-intensive, so different implementations exist that consider the hardware characteristics of different devices. For Core ML Stable Diffusion models, there are two attention variants:
+ * `split_einsum` ([introduced by Apple](https://machinelearning.apple.com/research/neural-engine-transformers)) is optimized for the ANE, which is available in modern iPhones, iPads, and M-series Macs.
+ * The "original" attention (the base implementation used in `diffusers`) is only compatible with CPU/GPU and not the ANE. It can be *faster* to run your model on CPU + GPU using `original` attention than on the ANE. See [this performance benchmark](https://huggingface.co/blog/fast-mac-diffusers#performance-benchmarks) as well as some [additional measures provided by the community](https://github.com/huggingface/swift-coreml-diffusers/issues/31) for additional details.
+
+- The supported inference framework.
+ * `packages` are suitable for Python inference. This can be used to test converted Core ML models before attempting to integrate them inside native apps, or if you want to explore Core ML performance but don't need to support native apps. For example, an application with a web UI could perfectly use a Python Core ML backend.
+ * `compiled` models are required for Swift code. The `compiled` models in the Hub split the large UNet model weights into several files for compatibility with iOS and iPadOS devices. This corresponds to the [`--chunk-unet` conversion option](https://github.com/apple/ml-stable-diffusion#-converting-models-to-core-ml). If you want to support native apps, then you need to select the `compiled` variant.
+
+The official Core ML Stable Diffusion [models](https://huggingface.co/apple/coreml-stable-diffusion-v1-4/tree/main) include these variants, but the community ones may vary:
+
+```
+coreml-stable-diffusion-v1-4
+├── README.md
+├── original
+│ ├── compiled
+│ └── packages
+└── split_einsum
+ ├── compiled
+ └── packages
+```
+
+You can download and use the variant you need as shown below.
+
+## Core ML Inference in Python
+
+Install the following libraries to run Core ML inference in Python:
+
+```bash
+pip install huggingface_hub
+pip install git+https://github.com/apple/ml-stable-diffusion
+```
+
+### Download the Model Checkpoints
+
+To run inference in Python, use one of the versions stored in the `packages` folders because the `compiled` ones are only compatible with Swift. You may choose whether you want to use `original` or `split_einsum` attention.
+
+This is how you'd download the `original` attention variant from the Hub to a directory called `models`:
+
+```Python
+from huggingface_hub import snapshot_download
+from pathlib import Path
+
+repo_id = "apple/coreml-stable-diffusion-v1-4"
+variant = "original/packages"
+
+model_path = Path("./models") / (repo_id.split("/")[-1] + "_" + variant.replace("/", "_"))
+snapshot_download(repo_id, allow_patterns=f"{variant}/*", local_dir=model_path, local_dir_use_symlinks=False)
+print(f"Model downloaded at {model_path}")
+```
+
+### Inference[[python-inference]]
+
+Once you have downloaded a snapshot of the model, you can test it using Apple's Python script.
+
+```shell
+python -m python_coreml_stable_diffusion.pipeline --prompt "a photo of an astronaut riding a horse on mars" -i models/coreml-stable-diffusion-v1-4_original_packages -o <output-path> --compute-unit CPU_AND_GPU --seed 93
+```
+
+Pass the path of the downloaded checkpoint with the `-i` flag to the script. `--compute-unit` indicates the hardware you want to allow for inference. It must be one of the following options: `ALL`, `CPU_AND_GPU`, `CPU_ONLY`, `CPU_AND_NE`. You may also provide an optional output path and a seed for reproducibility.
+
+The inference script assumes you're using the original version of the Stable Diffusion model, `CompVis/stable-diffusion-v1-4`. If you use another model, you *have* to specify its Hub id in the inference command line, using the `--model-version` option. This works for models already supported and custom models you trained or fine-tuned yourself.
+
+For example, if you want to use [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5):
+
+```shell
+python -m python_coreml_stable_diffusion.pipeline --prompt "a photo of an astronaut riding a horse on mars" --compute-unit ALL -o output --seed 93 -i models/coreml-stable-diffusion-v1-5_original_packages --model-version runwayml/stable-diffusion-v1-5
+```
+
+## Core ML inference in Swift
+
+Running inference in Swift is slightly faster than in Python because the models are already compiled in the `mlmodelc` format. This is noticeable on app startup when the model is loaded but shouldn’t be noticeable if you run several generations afterward.
+
+### Download
+
+To run inference in Swift on your Mac, you need one of the `compiled` checkpoint versions. We recommend you download them locally using Python code similar to the previous example, but with one of the `compiled` variants:
+
+```Python
+from huggingface_hub import snapshot_download
+from pathlib import Path
+
+repo_id = "apple/coreml-stable-diffusion-v1-4"
+variant = "original/compiled"
+
+model_path = Path("./models") / (repo_id.split("/")[-1] + "_" + variant.replace("/", "_"))
+snapshot_download(repo_id, allow_patterns=f"{variant}/*", local_dir=model_path, local_dir_use_symlinks=False)
+print(f"Model downloaded at {model_path}")
+```
+
+### Inference[[swift-inference]]
+
+To run inference, please clone Apple's repo:
+
+```bash
+git clone https://github.com/apple/ml-stable-diffusion
+cd ml-stable-diffusion
+```
+
+And then use Apple's command line tool, [Swift Package Manager](https://www.swift.org/package-manager/#):
+
+```bash
+swift run StableDiffusionSample --resource-path models/coreml-stable-diffusion-v1-4_original_compiled --compute-units all "a photo of an astronaut riding a horse on mars"
+```
+
+You have to pass one of the checkpoints downloaded in the previous step to `--resource-path`, so please make sure it contains compiled Core ML bundles with the extension `.mlmodelc`. The `--compute-units` option has to be one of these values: `all`, `cpuOnly`, `cpuAndGPU`, `cpuAndNeuralEngine`.
+
+For more details, please refer to the [instructions in Apple's repo](https://github.com/apple/ml-stable-diffusion).
+
+## Supported Diffusers Features
+
+The Core ML models and inference code don't support many of the features, options, and flexibility of 🧨 Diffusers. These are some of the limitations to keep in mind:
+
+- Core ML models are only suitable for inference. They can't be used for training or fine-tuning.
+- Only two schedulers have been ported to Swift, the default one used by Stable Diffusion and `DPMSolverMultistepScheduler`, which we ported to Swift from our `diffusers` implementation. We recommend you use `DPMSolverMultistepScheduler`, since it produces the same quality in about half the steps.
+- Negative prompts, classifier-free guidance scale, and image-to-image tasks are available in the inference code. Advanced features such as depth guidance, ControlNet, and latent upscalers are not available yet.
+
+Apple's [conversion and inference repo](https://github.com/apple/ml-stable-diffusion) and our own [swift-coreml-diffusers](https://github.com/huggingface/swift-coreml-diffusers) repos are intended as technology demonstrators to enable other developers to build upon.
+
+If you feel strongly about any missing features, please feel free to open a feature request or, better yet, a contribution PR 🙂.
+
+## Native Diffusers Swift app
+
+One easy way to run Stable Diffusion on your own Apple hardware is to use [our open-source Swift repo](https://github.com/huggingface/swift-coreml-diffusers), based on `diffusers` and Apple's conversion and inference repo. You can study the code, compile it with [Xcode](https://developer.apple.com/xcode/) and adapt it for your own needs. For your convenience, there's also a [standalone Mac app in the App Store](https://apps.apple.com/app/diffusers/id1666309574), so you can play with it without having to deal with the code or IDE. If you are a developer and have determined that Core ML is the best solution to build your Stable Diffusion app, then you can use the rest of this guide to get started with your project. We can't wait to see what you'll build 🙂.
diff --git a/diffusers/docs/source/en/optimization/fp16.md b/diffusers/docs/source/en/optimization/fp16.md
new file mode 100644
index 0000000000000000000000000000000000000000..61bc5569c53c35223fb99c49331f9167dadc9651
--- /dev/null
+++ b/diffusers/docs/source/en/optimization/fp16.md
@@ -0,0 +1,68 @@
+
+
+# Speed up inference
+
+There are several ways to optimize 🤗 Diffusers for inference speed. As a general rule of thumb, we recommend using either [xFormers](xformers) or `torch.nn.functional.scaled_dot_product_attention` in PyTorch 2.0 for their memory-efficient attention.
+
+
+
+In many cases, optimizing for speed or memory leads to improved performance in the other, so you should try to optimize for both whenever you can. This guide focuses on inference speed, but you can learn more about preserving memory in the [Reduce memory usage](memory) guide.
+
+
+
+The results below are obtained from generating a single 512x512 image from the prompt `a photo of an astronaut riding a horse on mars` with 50 DDIM steps on an Nvidia Titan RTX, demonstrating the speed-up you can expect.
+
+| | latency | speed-up |
+| ---------------- | ------- | ------- |
+| original | 9.50s | x1 |
+| fp16 | 3.61s | x2.63 |
+| channels last | 3.30s | x2.88 |
+| traced UNet | 3.21s | x2.96 |
+| memory efficient attention | 2.63s | x3.61 |
+
+## Use TensorFloat-32
+
+On Ampere and later CUDA devices, matrix multiplications and convolutions can use the [TensorFloat-32 (TF32)](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/) mode for faster, but slightly less accurate computations. By default, PyTorch enables TF32 mode for convolutions but not matrix multiplications. Unless your network requires full float32 precision, we recommend enabling TF32 for matrix multiplications. It can significantly speed up computations with typically negligible loss in numerical accuracy.
+
+```python
+import torch
+
+torch.backends.cuda.matmul.allow_tf32 = True
+```
+
+You can learn more about TF32 in the [Mixed precision training](https://huggingface.co/docs/transformers/en/perf_train_gpu_one#tf32) guide.
+
+## Half-precision weights
+
+To save GPU memory and get more speed, try loading and running the model weights directly in half-precision or float16:
+
+```Python
+import torch
+from diffusers import DiffusionPipeline
+
+pipe = DiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+ torch_dtype=torch.float16,
+ use_safetensors=True,
+)
+pipe = pipe.to("cuda")
+
+prompt = "a photo of an astronaut riding a horse on mars"
+image = pipe(prompt).images[0]
+```
+
+
+
+Don't use [`torch.autocast`](https://pytorch.org/docs/stable/amp.html#torch.autocast) in any of the pipelines as it can lead to black images and is always slower than pure float16 precision.
+
+
diff --git a/diffusers/docs/source/en/optimization/habana.md b/diffusers/docs/source/en/optimization/habana.md
new file mode 100644
index 0000000000000000000000000000000000000000..8a06210996f30fc28bdffb63bd756e94d039daef
--- /dev/null
+++ b/diffusers/docs/source/en/optimization/habana.md
@@ -0,0 +1,76 @@
+
+
+# Habana Gaudi
+
+🤗 Diffusers is compatible with Habana Gaudi through 🤗 [Optimum](https://huggingface.co/docs/optimum/habana/usage_guides/stable_diffusion). Follow the [installation](https://docs.habana.ai/en/latest/Installation_Guide/index.html) guide to install the SynapseAI and Gaudi drivers, and then install Optimum Habana:
+
+```bash
+python -m pip install --upgrade-strategy eager optimum[habana]
+```
+
+To generate images with Stable Diffusion 1 and 2 on Gaudi, you need to instantiate two instances:
+
+- [`~optimum.habana.diffusers.GaudiStableDiffusionPipeline`], a pipeline for text-to-image generation.
+- [`~optimum.habana.diffusers.GaudiDDIMScheduler`], a Gaudi-optimized scheduler.
+
+When you initialize the pipeline, you have to specify `use_habana=True` to deploy it on HPUs. To get the fastest possible generation, enable **HPU graphs** with `use_hpu_graphs=True`.
+
+Finally, specify a [`~optimum.habana.GaudiConfig`] which can be downloaded from the [Habana](https://huggingface.co/Habana) organization on the Hub.
+
+```python
+from optimum.habana import GaudiConfig
+from optimum.habana.diffusers import GaudiDDIMScheduler, GaudiStableDiffusionPipeline
+
+model_name = "stabilityai/stable-diffusion-2-base"
+scheduler = GaudiDDIMScheduler.from_pretrained(model_name, subfolder="scheduler")
+pipeline = GaudiStableDiffusionPipeline.from_pretrained(
+ model_name,
+ scheduler=scheduler,
+ use_habana=True,
+ use_hpu_graphs=True,
+ gaudi_config="Habana/stable-diffusion-2",
+)
+```
+
+Now you can call the pipeline to generate images by batches from one or several prompts:
+
+```python
+outputs = pipeline(
+ prompt=[
+ "High quality photo of an astronaut riding a horse in space",
+ "Face of a yellow cat, high resolution, sitting on a park bench",
+ ],
+ num_images_per_prompt=10,
+ batch_size=4,
+)
+```
+
+For more information, check out 🤗 Optimum Habana's [documentation](https://huggingface.co/docs/optimum/habana/usage_guides/stable_diffusion) and the [example](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion) provided in the official GitHub repository.
+
+## Benchmark
+
+We benchmarked Habana's first-generation Gaudi and Gaudi2 with the [Habana/stable-diffusion](https://huggingface.co/Habana/stable-diffusion) and [Habana/stable-diffusion-2](https://huggingface.co/Habana/stable-diffusion-2) Gaudi configurations (mixed precision bf16/fp32) to demonstrate their performance.
+
+For [Stable Diffusion v1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5) on 512x512 images:
+
+| | Latency (batch size = 1) | Throughput |
+| ---------------------- |:------------------------:|:---------------------------:|
+| first-generation Gaudi | 3.80s | 0.308 images/s (batch size = 8) |
+| Gaudi2 | 1.33s | 1.081 images/s (batch size = 8) |
+
+For [Stable Diffusion v2.1](https://huggingface.co/stabilityai/stable-diffusion-2-1) on 768x768 images:
+
+| | Latency (batch size = 1) | Throughput |
+| ---------------------- |:------------------------:|:-------------------------------:|
+| first-generation Gaudi | 10.2s | 0.108 images/s (batch size = 4) |
+| Gaudi2 | 3.17s | 0.379 images/s (batch size = 8) |
diff --git a/diffusers/docs/source/en/optimization/memory.md b/diffusers/docs/source/en/optimization/memory.md
new file mode 100644
index 0000000000000000000000000000000000000000..42a1bcea8fb5f8199ec3cb1a66f34e32567ac73d
--- /dev/null
+++ b/diffusers/docs/source/en/optimization/memory.md
@@ -0,0 +1,332 @@
+
+
+# Reduce memory usage
+
+A barrier to using diffusion models is the large amount of memory required. To overcome this challenge, there are several memory-reducing techniques you can use to run even some of the largest models on free-tier or consumer GPUs. Some of these techniques can even be combined to further reduce memory usage.
+
+
+
+In many cases, optimizing for memory or speed leads to improved performance in the other, so you should try to optimize for both whenever you can. This guide focuses on minimizing memory usage, but you can also learn more about how to [Speed up inference](fp16).
+
+
+
+The results below are obtained from generating a single 512x512 image from the prompt `a photo of an astronaut riding a horse on mars` with 50 DDIM steps on an Nvidia Titan RTX, demonstrating the speed-up you can expect as a result of reduced memory consumption.
+
+| | latency | speed-up |
+| ---------------- | ------- | ------- |
+| original | 9.50s | x1 |
+| fp16 | 3.61s | x2.63 |
+| channels last | 3.30s | x2.88 |
+| traced UNet | 3.21s | x2.96 |
+| memory-efficient attention | 2.63s | x3.61 |
+
+## Sliced VAE
+
+Sliced VAE enables decoding large batches of images with limited VRAM or batches with 32 images or more by decoding the batches of latents one image at a time. You'll likely want to couple this with [`~ModelMixin.enable_xformers_memory_efficient_attention`] to reduce memory use further if you have xFormers installed.
+
+To use sliced VAE, call [`~StableDiffusionPipeline.enable_vae_slicing`] on your pipeline before inference:
+
+```python
+import torch
+from diffusers import StableDiffusionPipeline
+
+pipe = StableDiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+ torch_dtype=torch.float16,
+ use_safetensors=True,
+)
+pipe = pipe.to("cuda")
+
+prompt = "a photo of an astronaut riding a horse on mars"
+pipe.enable_vae_slicing()
+#pipe.enable_xformers_memory_efficient_attention()
+images = pipe([prompt] * 32).images
+```
+
+You may see a small performance boost in VAE decoding on multi-image batches, and there should be no performance impact on single-image batches.
+
+## Tiled VAE
+
+Tiled VAE processing also enables working with large images on limited VRAM (for example, generating 4k images on 8GB of VRAM) by splitting the image into overlapping tiles, decoding the tiles, and then blending the outputs together to compose the final image. You should also use tiled VAE with [`~ModelMixin.enable_xformers_memory_efficient_attention`] to reduce memory use further if you have xFormers installed.
+
+To use tiled VAE processing, call [`~StableDiffusionPipeline.enable_vae_tiling`] on your pipeline before inference:
+
+```python
+import torch
+from diffusers import StableDiffusionPipeline, UniPCMultistepScheduler
+
+pipe = StableDiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+ torch_dtype=torch.float16,
+ use_safetensors=True,
+)
+pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+pipe = pipe.to("cuda")
+prompt = "a beautiful landscape photograph"
+pipe.enable_vae_tiling()
+#pipe.enable_xformers_memory_efficient_attention()
+
+image = pipe([prompt], width=3840, height=2224, num_inference_steps=20).images[0]
+```
+
+The output image has some tile-to-tile tone variation because the tiles are decoded separately, but you shouldn't see any sharp and obvious seams between the tiles. Tiling is turned off for images that are 512x512 or smaller.
+
+## CPU offloading
+
+Offloading the weights to the CPU and only loading them on the GPU when performing the forward pass can also save memory. Often, this technique can reduce memory consumption to less than 3GB.
+
+To perform CPU offloading, call [`~StableDiffusionPipeline.enable_sequential_cpu_offload`]:
+
+```Python
+import torch
+from diffusers import StableDiffusionPipeline
+
+pipe = StableDiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+ torch_dtype=torch.float16,
+ use_safetensors=True,
+)
+
+prompt = "a photo of an astronaut riding a horse on mars"
+pipe.enable_sequential_cpu_offload()
+image = pipe(prompt).images[0]
+```
+
+CPU offloading works on submodules rather than whole models. This is the best way to minimize memory consumption, but inference is much slower due to the iterative nature of the diffusion process. The UNet component of the pipeline runs several times (as many as `num_inference_steps`); each time, the different UNet submodules are sequentially onloaded and offloaded as needed, resulting in a large number of memory transfers.
+
+
+
+Consider using [model offloading](#model-offloading) if you want to optimize for speed because it is much faster. The tradeoff is your memory savings won't be as large.
+
+
+
+
+
+When using [`~StableDiffusionPipeline.enable_sequential_cpu_offload`], don't move the pipeline to CUDA beforehand or else the gain in memory consumption will only be minimal (see this [issue](https://github.com/huggingface/diffusers/issues/1934) for more information).
+
+[`~StableDiffusionPipeline.enable_sequential_cpu_offload`] is a stateful operation that installs hooks on the models.
+
+
+
+## Model offloading
+
+
+
+Model offloading requires 🤗 Accelerate version 0.17.0 or higher.
+
+
+
+[Sequential CPU offloading](#cpu-offloading) preserves a lot of memory but it makes inference slower because submodules are moved to GPU as needed, and they're immediately returned to the CPU when a new module runs.
+
+Full-model offloading is an alternative that moves whole models to the GPU, instead of handling each model's constituent *submodules*. There is a negligible impact on inference time (compared with moving the pipeline to `cuda`), and it still provides some memory savings.
+
+During model offloading, only one of the main components of the pipeline (typically the text encoder, UNet and VAE)
+is placed on the GPU while the others wait on the CPU. Components like the UNet that run for multiple iterations stay on the GPU until they're no longer needed.
+
+Enable model offloading by calling [`~StableDiffusionPipeline.enable_model_cpu_offload`] on the pipeline:
+
+```Python
+import torch
+from diffusers import StableDiffusionPipeline
+
+pipe = StableDiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+ torch_dtype=torch.float16,
+ use_safetensors=True,
+)
+
+prompt = "a photo of an astronaut riding a horse on mars"
+pipe.enable_model_cpu_offload()
+image = pipe(prompt).images[0]
+```
+
+
+
+In order to properly offload models after they're called, it is required to run the entire pipeline and models are called in the pipeline's expected order. Exercise caution if models are reused outside the context of the pipeline after hooks have been installed. See [Removing Hooks](https://huggingface.co/docs/accelerate/en/package_reference/big_modeling#accelerate.hooks.remove_hook_from_module) for more information.
+
+[`~StableDiffusionPipeline.enable_model_cpu_offload`] is a stateful operation that installs hooks on the models and state on the pipeline.
+
+
+
+## Channels-last memory format
+
+The channels-last memory format is an alternative way of ordering NCHW tensors in memory to preserve dimension ordering. Channels-last tensors are ordered in such a way that the channels become the densest dimension (storing images pixel-per-pixel). Since not all operators currently support the channels-last format, it may result in worse performance, but you should still try and see if it works for your model.
+
+For example, to set the pipeline's UNet to use the channels-last format:
+
+```python
+print(pipe.unet.conv_out.state_dict()["weight"].stride()) # (2880, 9, 3, 1)
+pipe.unet.to(memory_format=torch.channels_last) # in-place operation
+print(
+ pipe.unet.conv_out.state_dict()["weight"].stride()
+) # (2880, 1, 960, 320) having a stride of 1 for the 2nd dimension proves that it works
+```
+
+## Tracing
+
+Tracing runs an example input tensor through the model and captures the operations that are performed on it as that input makes its way through the model's layers. The executable or `ScriptFunction` that is returned is optimized with just-in-time compilation.
+
+To trace a UNet:
+
+```python
+import time
+import torch
+from diffusers import StableDiffusionPipeline
+import functools
+
+# torch disable grad
+torch.set_grad_enabled(False)
+
+# set variables
+n_experiments = 2
+unet_runs_per_experiment = 50
+
+
+# load inputs
+def generate_inputs():
+ sample = torch.randn((2, 4, 64, 64), device="cuda", dtype=torch.float16)
+ timestep = torch.rand(1, device="cuda", dtype=torch.float16) * 999
+ encoder_hidden_states = torch.randn((2, 77, 768), device="cuda", dtype=torch.float16)
+ return sample, timestep, encoder_hidden_states
+
+
+pipe = StableDiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+ torch_dtype=torch.float16,
+ use_safetensors=True,
+).to("cuda")
+unet = pipe.unet
+unet.eval()
+unet.to(memory_format=torch.channels_last) # use channels_last memory format
+unet.forward = functools.partial(unet.forward, return_dict=False) # set return_dict=False as default
+
+# warmup
+for _ in range(3):
+ with torch.inference_mode():
+ inputs = generate_inputs()
+ orig_output = unet(*inputs)
+
+# trace
+print("tracing..")
+unet_traced = torch.jit.trace(unet, inputs)
+unet_traced.eval()
+print("done tracing")
+
+
+# warmup and optimize graph
+for _ in range(5):
+ with torch.inference_mode():
+ inputs = generate_inputs()
+ orig_output = unet_traced(*inputs)
+
+
+# benchmarking
+with torch.inference_mode():
+ for _ in range(n_experiments):
+ torch.cuda.synchronize()
+ start_time = time.time()
+ for _ in range(unet_runs_per_experiment):
+ orig_output = unet_traced(*inputs)
+ torch.cuda.synchronize()
+ print(f"unet traced inference took {time.time() - start_time:.2f} seconds")
+ for _ in range(n_experiments):
+ torch.cuda.synchronize()
+ start_time = time.time()
+ for _ in range(unet_runs_per_experiment):
+ orig_output = unet(*inputs)
+ torch.cuda.synchronize()
+ print(f"unet inference took {time.time() - start_time:.2f} seconds")
+
+# save the model
+unet_traced.save("unet_traced.pt")
+```
+
+Replace the `unet` attribute of the pipeline with the traced model:
+
+```python
+from diffusers import StableDiffusionPipeline
+import torch
+from dataclasses import dataclass
+
+
+@dataclass
+class UNet2DConditionOutput:
+ sample: torch.FloatTensor
+
+
+pipe = StableDiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+ torch_dtype=torch.float16,
+ use_safetensors=True,
+).to("cuda")
+
+# use jitted unet
+unet_traced = torch.jit.load("unet_traced.pt")
+
+
+# del pipe.unet
+class TracedUNet(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.in_channels = pipe.unet.config.in_channels
+ self.device = pipe.unet.device
+
+ def forward(self, latent_model_input, t, encoder_hidden_states):
+ sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
+ return UNet2DConditionOutput(sample=sample)
+
+
+pipe.unet = TracedUNet()
+
+# example prompt; replace with your own
+prompt = "a photo of an astronaut riding a horse on mars"
+
+with torch.inference_mode():
+    image = pipe([prompt] * 1, num_inference_steps=50).images[0]
+```
+
+## Memory-efficient attention
+
+Recent work on optimizing bandwidth in the attention block has generated huge speed-ups and reductions in GPU memory usage. The most recent type of memory-efficient attention is [Flash Attention](https://arxiv.org/abs/2205.14135) (you can check out the original code at [HazyResearch/flash-attention](https://github.com/HazyResearch/flash-attention)).
+
+
+
+If you have PyTorch >= 2.0 installed, you should not expect a speed-up for inference when enabling `xformers`.
+
+
+
+To use Flash Attention, install the following:
+
+- PyTorch > 1.12
+- CUDA available
+- [xFormers](xformers)
+
+Then call [`~ModelMixin.enable_xformers_memory_efficient_attention`] on the pipeline:
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+pipe = DiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+ torch_dtype=torch.float16,
+ use_safetensors=True,
+).to("cuda")
+
+pipe.enable_xformers_memory_efficient_attention()
+
+with torch.inference_mode():
+ sample = pipe("a small cat")
+
+# optional: You can disable it via
+# pipe.disable_xformers_memory_efficient_attention()
+```
+
+The iteration speed when using `xformers` should match the iteration speed of PyTorch 2.0 as described [here](torch2.0).
diff --git a/diffusers/docs/source/en/optimization/mps.md b/diffusers/docs/source/en/optimization/mps.md
new file mode 100644
index 0000000000000000000000000000000000000000..f5ce3332fc90c7637fa18842a760ccd0b300ca0f
--- /dev/null
+++ b/diffusers/docs/source/en/optimization/mps.md
@@ -0,0 +1,74 @@
+
+
+# Metal Performance Shaders (MPS)
+
+🤗 Diffusers is compatible with Apple silicon (M1/M2 chips) using the PyTorch [`mps`](https://pytorch.org/docs/stable/notes/mps.html) device, which uses the Metal framework to leverage the GPU on macOS devices. You'll need to have:
+
+- macOS computer with Apple silicon (M1/M2) hardware
+- macOS 12.6 or later (13.0 or later recommended)
+- arm64 version of Python
+- [PyTorch 2.0](https://pytorch.org/get-started/locally/) (recommended) or 1.13 (minimum version supported for `mps`)
+
+The `mps` backend uses PyTorch's `.to()` interface to move the Stable Diffusion pipeline onto your M1 or M2 device:
+
+```python
+from diffusers import DiffusionPipeline
+
+pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+pipe = pipe.to("mps")
+
+# Recommended if your computer has < 64 GB of RAM
+pipe.enable_attention_slicing()
+
+prompt = "a photo of an astronaut riding a horse on mars"
+image = pipe(prompt).images[0]
+image
+```
+
+
+
+Generating multiple prompts in a batch can [crash](https://github.com/huggingface/diffusers/issues/363) or fail to work reliably. We believe this is related to the [`mps`](https://github.com/pytorch/pytorch/issues/84039) backend in PyTorch. While this is being investigated, you should iterate instead of batching.
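+
+For example, you can loop over the prompts one at a time instead of passing them as a single batched call (a sketch reusing the pipeline above; the prompts are just placeholders):
+
+```python
+prompts = [
+    "a photo of an astronaut riding a horse on mars",
+    "a watercolor of a lighthouse at dusk",
+]
+# generate one image per call instead of batching the prompts
+images = [pipe(p).images[0] for p in prompts]
+```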
+
+
+
+If you're using **PyTorch 1.13**, you need to "prime" the pipeline with an additional one-time pass through it. This is a temporary workaround for an issue where the first inference pass produces slightly different results than subsequent ones. You only need to do this pass once, and after just one inference step you can discard the result.
+
+```diff
+ from diffusers import DiffusionPipeline
+
+ pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("mps")
+ pipe.enable_attention_slicing()
+
+ prompt = "a photo of an astronaut riding a horse on mars"
+ # First-time "warmup" pass if PyTorch version is 1.13
++ _ = pipe(prompt, num_inference_steps=1)
+
+ # Results match those from the CPU device after the warmup pass.
+ image = pipe(prompt).images[0]
+```
+
+## Troubleshoot
+
+M1/M2 performance is very sensitive to memory pressure. When memory pressure builds up, the system automatically swaps if it needs to, which significantly degrades performance.
+
+To prevent this from happening, we recommend *attention slicing* to reduce memory pressure during inference and prevent swapping. This is especially relevant if your computer has less than 64GB of system RAM, or if you generate images at non-standard resolutions larger than 512×512 pixels. Call the [`~DiffusionPipeline.enable_attention_slicing`] function on your pipeline:
+
+```py
+from diffusers import DiffusionPipeline
+import torch
+
+pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True).to("mps")
+pipeline.enable_attention_slicing()
+```
+
+Attention slicing performs the costly attention operation in multiple steps instead of all at once. It usually has a performance cost of ~20% on computers without universal memory, but we've observed *better performance* on most Apple silicon computers unless you have 64GB of RAM or more.
diff --git a/diffusers/docs/source/en/optimization/onnx.md b/diffusers/docs/source/en/optimization/onnx.md
new file mode 100644
index 0000000000000000000000000000000000000000..4d352480a007f2859128a3da162fa67c9946a92b
--- /dev/null
+++ b/diffusers/docs/source/en/optimization/onnx.md
@@ -0,0 +1,86 @@
+
+
+# ONNX Runtime
+
+🤗 [Optimum](https://github.com/huggingface/optimum) provides a Stable Diffusion pipeline compatible with ONNX Runtime. You'll need to install 🤗 Optimum with the following command for ONNX Runtime support:
+
+```bash
+pip install -q optimum["onnxruntime"]
+```
+
+This guide will show you how to use the Stable Diffusion and Stable Diffusion XL (SDXL) pipelines with ONNX Runtime.
+
+## Stable Diffusion
+
+To load and run inference, use the [`~optimum.onnxruntime.ORTStableDiffusionPipeline`]. If you want to load a PyTorch model and convert it to the ONNX format on-the-fly, set `export=True`:
+
+```python
+from optimum.onnxruntime import ORTStableDiffusionPipeline
+
+model_id = "runwayml/stable-diffusion-v1-5"
+pipeline = ORTStableDiffusionPipeline.from_pretrained(model_id, export=True)
+prompt = "sailing ship in storm by Leonardo da Vinci"
+image = pipeline(prompt).images[0]
+pipeline.save_pretrained("./onnx-stable-diffusion-v1-5")
+```
+
+
+
+Generating multiple prompts in a batch seems to take too much memory. While we look into it, you may need to iterate instead of batching.
+
+
+
+To export the pipeline in the ONNX format offline and use it later for inference,
+use the [`optimum-cli export`](https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#exporting-a-model-to-onnx-using-the-cli) command:
+
+```bash
+optimum-cli export onnx --model runwayml/stable-diffusion-v1-5 sd_v15_onnx/
+```
+
+Then to perform inference (you don't have to specify `export=True` again):
+
+```python
+from optimum.onnxruntime import ORTStableDiffusionPipeline
+
+model_id = "sd_v15_onnx"
+pipeline = ORTStableDiffusionPipeline.from_pretrained(model_id)
+prompt = "sailing ship in storm by Leonardo da Vinci"
+image = pipeline(prompt).images[0]
+```
+
+
+
+
+
+You can find more examples in 🤗 Optimum [documentation](https://huggingface.co/docs/optimum/), and Stable Diffusion is supported for text-to-image, image-to-image, and inpainting.
+
+## Stable Diffusion XL
+
+To load and run inference with SDXL, use the [`~optimum.onnxruntime.ORTStableDiffusionXLPipeline`]:
+
+```python
+from optimum.onnxruntime import ORTStableDiffusionXLPipeline
+
+model_id = "stabilityai/stable-diffusion-xl-base-1.0"
+pipeline = ORTStableDiffusionXLPipeline.from_pretrained(model_id)
+prompt = "sailing ship in storm by Leonardo da Vinci"
+image = pipeline(prompt).images[0]
+```
+
+To export the pipeline in the ONNX format and use it later for inference, use the [`optimum-cli export`](https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#exporting-a-model-to-onnx-using-the-cli) command:
+
+```bash
+optimum-cli export onnx --model stabilityai/stable-diffusion-xl-base-1.0 --task stable-diffusion-xl sd_xl_onnx/
+```
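+
+As with the Stable Diffusion export above, the exported folder can then be loaded for inference without specifying `export=True` again (a short sketch):
+
+```python
+from optimum.onnxruntime import ORTStableDiffusionXLPipeline
+
+pipeline = ORTStableDiffusionXLPipeline.from_pretrained("sd_xl_onnx/")
+prompt = "sailing ship in storm by Leonardo da Vinci"
+image = pipeline(prompt).images[0]
+```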
+
+SDXL in the ONNX format is supported for text-to-image and image-to-image.
diff --git a/diffusers/docs/source/en/optimization/open_vino.md b/diffusers/docs/source/en/optimization/open_vino.md
new file mode 100644
index 0000000000000000000000000000000000000000..29299786118adbb8e7ef95c4790026a67afae418
--- /dev/null
+++ b/diffusers/docs/source/en/optimization/open_vino.md
@@ -0,0 +1,80 @@
+
+
+# OpenVINO
+
+🤗 [Optimum](https://github.com/huggingface/optimum-intel) provides Stable Diffusion pipelines compatible with OpenVINO to perform inference on a variety of Intel processors (see the [full list](https://docs.openvino.ai/latest/openvino_docs_OV_UG_supported_plugins_Supported_Devices.html) of supported devices).
+
+You'll need to install 🤗 Optimum Intel with the `--upgrade-strategy eager` option to ensure [`optimum-intel`](https://github.com/huggingface/optimum-intel) is using the latest version:
+
+```bash
+pip install --upgrade-strategy eager optimum["openvino"]
+```
+
+This guide will show you how to use the Stable Diffusion and Stable Diffusion XL (SDXL) pipelines with OpenVINO.
+
+## Stable Diffusion
+
+To load and run inference, use the [`~optimum.intel.OVStableDiffusionPipeline`]. If you want to load a PyTorch model and convert it to the OpenVINO format on-the-fly, set `export=True`:
+
+```python
+from optimum.intel import OVStableDiffusionPipeline
+
+model_id = "runwayml/stable-diffusion-v1-5"
+pipeline = OVStableDiffusionPipeline.from_pretrained(model_id, export=True)
+prompt = "sailing ship in storm by Rembrandt"
+image = pipeline(prompt).images[0]
+
+# Don't forget to save the exported model
+pipeline.save_pretrained("openvino-sd-v1-5")
+```
+
+To further speed up inference, statically reshape the model. If you change any parameters, such as the output height or width, you'll need to statically reshape the model again.
+
+```python
+# Define the shapes related to the inputs and desired outputs
+batch_size, num_images, height, width = 1, 1, 512, 512
+
+# Statically reshape the model
+pipeline.reshape(batch_size, height, width, num_images)
+# Compile the model before inference
+pipeline.compile()
+
+image = pipeline(
+ prompt,
+ height=height,
+ width=width,
+ num_images_per_prompt=num_images,
+).images[0]
+```
+
+
+
+
+You can find more examples in the 🤗 Optimum [documentation](https://huggingface.co/docs/optimum/intel/inference#stable-diffusion), and Stable Diffusion is supported for text-to-image, image-to-image, and inpainting.
+
+## Stable Diffusion XL
+
+To load and run inference with SDXL, use the [`~optimum.intel.OVStableDiffusionXLPipeline`]:
+
+```python
+from optimum.intel import OVStableDiffusionXLPipeline
+
+model_id = "stabilityai/stable-diffusion-xl-base-1.0"
+pipeline = OVStableDiffusionXLPipeline.from_pretrained(model_id)
+prompt = "sailing ship in storm by Rembrandt"
+image = pipeline(prompt).images[0]
+```
+
+To further speed up inference, [statically reshape](#stable-diffusion) the model as shown in the Stable Diffusion section.
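+
+For example, here is a sketch that mirrors the Stable Diffusion snippet above, assuming the SDXL pipeline exposes the same `reshape()` and `compile()` methods, and using SDXL's default 1024x1024 resolution:
+
+```python
+batch_size, num_images, height, width = 1, 1, 1024, 1024
+
+# statically reshape for a fixed input/output shape, then compile before inference
+pipeline.reshape(batch_size, height, width, num_images)
+pipeline.compile()
+
+image = pipeline(
+    prompt,
+    height=height,
+    width=width,
+    num_images_per_prompt=num_images,
+).images[0]
+```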
+
+You can find more examples in the 🤗 Optimum [documentation](https://huggingface.co/docs/optimum/intel/inference#stable-diffusion-xl), and running SDXL in OpenVINO is supported for text-to-image and image-to-image.
diff --git a/diffusers/docs/source/en/optimization/opt_overview.md b/diffusers/docs/source/en/optimization/opt_overview.md
new file mode 100644
index 0000000000000000000000000000000000000000..3a458291ce5b960fb728758e0db52181ffa4539d
--- /dev/null
+++ b/diffusers/docs/source/en/optimization/opt_overview.md
@@ -0,0 +1,17 @@
+
+
+# Overview
+
+Generating high-quality outputs is computationally intensive, especially during each iterative step where you go from a noisy output to a less noisy output. One of 🤗 Diffusers' goals is to make this technology widely accessible to everyone, which includes enabling fast inference on consumer and specialized hardware.
+
+This section will cover tips and tricks - like half-precision weights and sliced attention - for optimizing inference speed and reducing memory consumption. You'll also learn how to speed up your PyTorch code with [`torch.compile`](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) or [ONNX Runtime](https://onnxruntime.ai/docs/), and enable memory-efficient attention with [xFormers](https://facebookresearch.github.io/xformers/). There are also guides for running inference on specific hardware like Apple Silicon, and Intel or Habana processors.
diff --git a/diffusers/docs/source/en/optimization/tome.md b/diffusers/docs/source/en/optimization/tome.md
new file mode 100644
index 0000000000000000000000000000000000000000..34726a4c79c26c0e394bc583d080a4e92ecc5b40
--- /dev/null
+++ b/diffusers/docs/source/en/optimization/tome.md
@@ -0,0 +1,96 @@
+
+
+# Token merging
+
+[Token merging](https://huggingface.co/papers/2303.17604) (ToMe) merges redundant tokens/patches progressively in the forward pass of a Transformer-based network, which can reduce the inference latency of [`StableDiffusionPipeline`].
+
+Install ToMe from `pip`:
+
+```bash
+pip install tomesd
+```
+
+You can use ToMe from the [`tomesd`](https://github.com/dbolya/tomesd) library with the [`apply_patch`](https://github.com/dbolya/tomesd?tab=readme-ov-file#usage) function:
+
+```diff
+ from diffusers import StableDiffusionPipeline
+ import torch
+ import tomesd
+
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True,
+ ).to("cuda")
++ tomesd.apply_patch(pipeline, ratio=0.5)
+
+ image = pipeline("a photo of an astronaut riding a horse on mars").images[0]
+```
+
+The `apply_patch` function exposes a number of [arguments](https://github.com/dbolya/tomesd#usage) to help strike a balance between pipeline inference speed and the quality of the generated images. The most important argument is `ratio`, which controls the number of tokens that are merged during the forward pass.
+
+As reported in the [paper](https://huggingface.co/papers/2303.17604), ToMe can greatly preserve the quality of the generated images while boosting inference speed. By increasing the `ratio`, you can speed-up inference even further, but at the cost of some degraded image quality.
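+
+For instance, a more aggressive `ratio` trades some image quality for extra speed, and the patch can be removed to restore the original pipeline (a sketch reusing the pipeline from the snippet above; exact behavior depends on your `tomesd` version):
+
+```python
+# merge more tokens for a larger speed-up, at some cost in image quality
+tomesd.apply_patch(pipeline, ratio=0.75)
+image = pipeline("a photo of an astronaut riding a horse on mars").images[0]
+
+# undo the patch to get the unmodified pipeline back
+tomesd.remove_patch(pipeline)
+```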
+
+To test the quality of the generated images, we sampled a few prompts from [Parti Prompts](https://parti.research.google/) and performed inference with the [`StableDiffusionPipeline`] with the following settings:
+
+
+
+
+
+We didn’t notice any significant decrease in the quality of the generated samples, and you can check out the generated samples in this [WandB report](https://wandb.ai/sayakpaul/tomesd-results/runs/23j4bj3i?workspace=). If you're interested in reproducing this experiment, use this [script](https://gist.github.com/sayakpaul/8cac98d7f22399085a060992f411ecbd).
+
+## Benchmarks
+
+We also benchmarked the impact of `tomesd` on the [`StableDiffusionPipeline`] with [xFormers](https://huggingface.co/docs/diffusers/optimization/xformers) enabled across several image resolutions. The results are obtained from A100 and V100 GPUs in the following development environment:
+
+```bash
+- `diffusers` version: 0.15.1
+- Python version: 3.8.16
+- PyTorch version (GPU?): 1.13.1+cu116 (True)
+- Huggingface_hub version: 0.13.2
+- Transformers version: 4.27.2
+- Accelerate version: 0.18.0
+- xFormers version: 0.0.16
+- tomesd version: 0.1.2
+```
+
+To reproduce this benchmark, feel free to use this [script](https://gist.github.com/sayakpaul/27aec6bca7eb7b0e0aa4112205850335). The results are reported in seconds, and where applicable we report the speed-up percentage over the vanilla pipeline when using ToMe and ToMe + xFormers.
+
+| **GPU** | **Resolution** | **Batch size** | **Vanilla** | **ToMe** | **ToMe + xFormers** |
+|----------|----------------|----------------|-------------|----------------|---------------------|
+| **A100** | 512 | 10 | 6.88 | 5.26 (+23.55%) | 4.69 (+31.83%) |
+| | 768 | 10 | OOM | 14.71 | 11 |
+| | | 8 | OOM | 11.56 | 8.84 |
+| | | 4 | OOM | 5.98 | 4.66 |
+| | | 2 | 4.99 | 3.24 (+35.07%) | 2.1 (+37.88%) |
+| | | 1 | 3.29 | 2.24 (+31.91%) | 2.03 (+38.3%) |
+| | 1024 | 10 | OOM | OOM | OOM |
+| | | 8 | OOM | OOM | OOM |
+| | | 4 | OOM | 12.51 | 9.09 |
+| | | 2 | OOM | 6.52 | 4.96 |
+| | | 1 | 6.4 | 3.61 (+43.59%) | 2.81 (+56.09%) |
+| **V100** | 512 | 10 | OOM | 10.03 | 9.29 |
+| | | 8 | OOM | 8.05 | 7.47 |
+| | | 4 | 5.7 | 4.3 (+24.56%) | 3.98 (+30.18%) |
+| | | 2 | 3.14 | 2.43 (+22.61%) | 2.27 (+27.71%) |
+| | | 1 | 1.88 | 1.57 (+16.49%) | 1.57 (+16.49%) |
+| | 768 | 10 | OOM | OOM | 23.67 |
+| | | 8 | OOM | OOM | 18.81 |
+| | | 4 | OOM | 11.81 | 9.7 |
+| | | 2 | OOM | 6.27 | 5.2 |
+| | | 1 | 5.43 | 3.38 (+37.75%) | 2.82 (+48.07%) |
+| | 1024 | 10 | OOM | OOM | OOM |
+| | | 8 | OOM | OOM | OOM |
+| | | 4 | OOM | OOM | 19.35 |
+| | | 2 | OOM | 13 | 10.78 |
+| | | 1 | OOM | 6.66 | 5.54 |
+
+As seen in the tables above, the speed-up from `tomesd` becomes more pronounced for larger image resolutions. It is also interesting to note that with `tomesd`, it is possible to run the pipeline on a higher resolution like 1024x1024. You may be able to speed-up inference even more with [`torch.compile`](torch2.0).
diff --git a/diffusers/docs/source/en/optimization/torch2.0.md b/diffusers/docs/source/en/optimization/torch2.0.md
new file mode 100644
index 0000000000000000000000000000000000000000..4775fda0fcf91f1e7aed8cc21109952e8f1f79a7
--- /dev/null
+++ b/diffusers/docs/source/en/optimization/torch2.0.md
@@ -0,0 +1,418 @@
+
+
+# PyTorch 2.0
+
+🤗 Diffusers supports the latest optimizations from [PyTorch 2.0](https://pytorch.org/get-started/pytorch-2.0/) which include:
+
+1. A memory-efficient attention implementation, scaled dot product attention, without requiring any extra dependencies such as xFormers.
+2. [`torch.compile`](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html), a just-in-time (JIT) compiler to provide an extra performance boost when individual models are compiled.
+
+Both of these optimizations require PyTorch 2.0 or later and 🤗 Diffusers > 0.13.0.
+
+```bash
+pip install --upgrade torch diffusers
+```
+
+## Scaled dot product attention
+
+[`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention) (SDPA) is an optimized and memory-efficient attention (similar to xFormers) that automatically enables several other optimizations depending on the model inputs and GPU type. SDPA is enabled by default if you're using PyTorch 2.0 and the latest version of 🤗 Diffusers, so you don't need to add anything to your code.
+
+However, if you want to explicitly enable it, you can set a [`DiffusionPipeline`] to use [`~models.attention_processor.AttnProcessor2_0`]:
+
+```diff
+ import torch
+ from diffusers import DiffusionPipeline
++ from diffusers.models.attention_processor import AttnProcessor2_0
+
+ pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True).to("cuda")
++ pipe.unet.set_attn_processor(AttnProcessor2_0())
+
+ prompt = "a photo of an astronaut riding a horse on mars"
+ image = pipe(prompt).images[0]
+```
+
+SDPA should be as fast and memory efficient as `xFormers`; check the [benchmark](#benchmark) for more details.
+
+In some cases - such as making the pipeline more deterministic or converting it to other formats - it may be helpful to use the vanilla attention processor, [`~models.attention_processor.AttnProcessor`]. To revert to [`~models.attention_processor.AttnProcessor`], call the [`~UNet2DConditionModel.set_default_attn_processor`] function on the pipeline:
+
+```diff
+ import torch
+ from diffusers import DiffusionPipeline
+
+ pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True).to("cuda")
++ pipe.unet.set_default_attn_processor()
+
+ prompt = "a photo of an astronaut riding a horse on mars"
+ image = pipe(prompt).images[0]
+```
+
+## torch.compile
+
+The `torch.compile` function can often provide an additional speed-up to your PyTorch code. In 🤗 Diffusers, it is usually best to wrap the UNet with `torch.compile` because it does most of the heavy lifting in the pipeline.
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True).to("cuda")
+pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+
+prompt = "a photo of an astronaut riding a horse on mars"
+steps, batch_size = 30, 1
+image = pipe(prompt, num_inference_steps=steps, num_images_per_prompt=batch_size).images[0]
+```
+
+Depending on GPU type, `torch.compile` can provide an *additional speed-up* of **5-300x** on top of SDPA! If you're using more recent GPU architectures such as Ampere (A100, 3090), Ada (4090), and Hopper (H100), `torch.compile` is able to squeeze even more performance out of these GPUs.
+
+Compilation requires some time to complete, so it is best suited for situations where you prepare your pipeline once and then perform the same type of inference operations multiple times. For example, calling the compiled pipeline on a different image size triggers compilation again which can be expensive.
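+
+For example, keeping the output resolution fixed lets the compiled UNet be reused across calls instead of being recompiled (a sketch continuing from the snippet above; the prompts are just placeholders):
+
+```python
+# same height/width on every call, so the compiled graph is reused instead of recompiled
+for p in ["a red fox in a snowy forest", "a lighthouse at dusk, oil painting"]:
+    image = pipe(p, height=512, width=512, num_inference_steps=30).images[0]
+```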
+
+For more information and different options about `torch.compile`, refer to the [`torch_compile`](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) tutorial.
+
+## Benchmark
+
+We conducted a comprehensive benchmark with PyTorch 2.0's efficient attention implementation and `torch.compile` across different GPUs and batch sizes for five of our most used pipelines. The code is benchmarked on 🤗 Diffusers v0.17.0.dev0 to optimize `torch.compile` usage (see [here](https://github.com/huggingface/diffusers/pull/3313) for more details).
+
+Expand the dropdown below to find the code used to benchmark each pipeline:
+
+
+
+### Stable Diffusion text-to-image
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+path = "runwayml/stable-diffusion-v1-5"
+
+run_compile = True # Set True / False
+
+pipe = DiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16, use_safetensors=True)
+pipe = pipe.to("cuda")
+pipe.unet.to(memory_format=torch.channels_last)
+
+if run_compile:
+ print("Run torch compile")
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+
+prompt = "ghibli style, a fantasy landscape with castles"
+
+for _ in range(3):
+ images = pipe(prompt=prompt).images
+```
+
+### Stable Diffusion image-to-image
+
+```python
+from diffusers import StableDiffusionImg2ImgPipeline
+from diffusers.utils import load_image
+import torch
+
+url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+
+init_image = load_image(url)
+init_image = init_image.resize((512, 512))
+
+path = "runwayml/stable-diffusion-v1-5"
+
+run_compile = True # Set True / False
+
+pipe = StableDiffusionImg2ImgPipeline.from_pretrained(path, torch_dtype=torch.float16, use_safetensors=True)
+pipe = pipe.to("cuda")
+pipe.unet.to(memory_format=torch.channels_last)
+
+if run_compile:
+ print("Run torch compile")
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+
+prompt = "ghibli style, a fantasy landscape with castles"
+
+for _ in range(3):
+ image = pipe(prompt=prompt, image=init_image).images[0]
+```
+
+### Stable Diffusion inpainting
+
+```python
+from diffusers import StableDiffusionInpaintPipeline
+from diffusers.utils import load_image
+import torch
+
+img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+init_image = load_image(img_url).resize((512, 512))
+mask_image = load_image(mask_url).resize((512, 512))
+
+path = "runwayml/stable-diffusion-inpainting"
+
+run_compile = True # Set True / False
+
+pipe = StableDiffusionInpaintPipeline.from_pretrained(path, torch_dtype=torch.float16, use_safetensors=True)
+pipe = pipe.to("cuda")
+pipe.unet.to(memory_format=torch.channels_last)
+
+if run_compile:
+ print("Run torch compile")
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+
+prompt = "ghibli style, a fantasy landscape with castles"
+
+for _ in range(3):
+ image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
+```
+
+### ControlNet
+
+```python
+from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
+from diffusers.utils import load_image
+import torch
+
+url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+
+init_image = load_image(url)
+init_image = init_image.resize((512, 512))
+
+path = "runwayml/stable-diffusion-v1-5"
+
+run_compile = True # Set True / False
+controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16, use_safetensors=True)
+pipe = StableDiffusionControlNetPipeline.from_pretrained(
+ path, controlnet=controlnet, torch_dtype=torch.float16, use_safetensors=True
+)
+
+pipe = pipe.to("cuda")
+pipe.unet.to(memory_format=torch.channels_last)
+pipe.controlnet.to(memory_format=torch.channels_last)
+
+if run_compile:
+ print("Run torch compile")
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+ pipe.controlnet = torch.compile(pipe.controlnet, mode="reduce-overhead", fullgraph=True)
+
+prompt = "ghibli style, a fantasy landscape with castles"
+
+for _ in range(3):
+ image = pipe(prompt=prompt, image=init_image).images[0]
+```
+
+### DeepFloyd IF text-to-image + upscaling
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+run_compile = True # Set True / False
+
+pipe_1 = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-M-v1.0", variant="fp16", text_encoder=None, torch_dtype=torch.float16, use_safetensors=True)
+pipe_1.to("cuda")
+pipe_2 = DiffusionPipeline.from_pretrained("DeepFloyd/IF-II-M-v1.0", variant="fp16", text_encoder=None, torch_dtype=torch.float16, use_safetensors=True)
+pipe_2.to("cuda")
+pipe_3 = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16, use_safetensors=True)
+pipe_3.to("cuda")
+
+
+pipe_1.unet.to(memory_format=torch.channels_last)
+pipe_2.unet.to(memory_format=torch.channels_last)
+pipe_3.unet.to(memory_format=torch.channels_last)
+
+if run_compile:
+ pipe_1.unet = torch.compile(pipe_1.unet, mode="reduce-overhead", fullgraph=True)
+ pipe_2.unet = torch.compile(pipe_2.unet, mode="reduce-overhead", fullgraph=True)
+ pipe_3.unet = torch.compile(pipe_3.unet, mode="reduce-overhead", fullgraph=True)
+
+prompt = "the blue hulk"
+
+prompt_embeds = torch.randn((1, 2, 4096), dtype=torch.float16)
+neg_prompt_embeds = torch.randn((1, 2, 4096), dtype=torch.float16)
+
+for _ in range(3):
+ image_1 = pipe_1(prompt_embeds=prompt_embeds, negative_prompt_embeds=neg_prompt_embeds, output_type="pt").images
+ image_2 = pipe_2(image=image_1, prompt_embeds=prompt_embeds, negative_prompt_embeds=neg_prompt_embeds, output_type="pt").images
+ image_3 = pipe_3(prompt=prompt, image=image_1, noise_level=100).images
+```
+
+
+The graph below highlights the relative speed-ups for the [`StableDiffusionPipeline`] across five GPU families with PyTorch 2.0 and `torch.compile` enabled. The benchmarks for the following graphs are measured in *number of iterations/second*.
+
+![t2i_speedup](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/pt2_benchmarks/t2i_speedup.png)
+
+To give you an even better idea of how this speed-up holds for the other pipelines, consider the following
+graph for an A100 with PyTorch 2.0 and `torch.compile`:
+
+![a100_numbers](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/pt2_benchmarks/a100_numbers.png)
+
+In the following tables, we report our findings in terms of the *number of iterations/second*.
+
+### A100 (batch size: 1)
+
+| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
+|:---:|:---:|:---:|:---:|:---:|
+| SD - txt2img | 21.66 | 23.13 | 44.03 | 49.74 |
+| SD - img2img | 21.81 | 22.40 | 43.92 | 46.32 |
+| SD - inpaint | 22.24 | 23.23 | 43.76 | 49.25 |
+| SD - controlnet | 15.02 | 15.82 | 32.13 | 36.08 |
+| IF | 20.21 / 13.84 / 24.00 | 20.12 / 13.70 / 24.03 | ❌ | 97.34 / 27.23 / 111.66 |
+| SDXL - txt2img | 8.64 | 9.9 | - | - |
+
+### A100 (batch size: 4)
+
+| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
+|:---:|:---:|:---:|:---:|:---:|
+| SD - txt2img | 11.6 | 13.12 | 14.62 | 17.27 |
+| SD - img2img | 11.47 | 13.06 | 14.66 | 17.25 |
+| SD - inpaint | 11.67 | 13.31 | 14.88 | 17.48 |
+| SD - controlnet | 8.28 | 9.38 | 10.51 | 12.41 |
+| IF | 25.02 | 18.04 | ❌ | 48.47 |
+| SDXL - txt2img | 2.44 | 2.74 | - | - |
+
+### A100 (batch size: 16)
+
+| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
+|:---:|:---:|:---:|:---:|:---:|
+| SD - txt2img | 3.04 | 3.6 | 3.83 | 4.68 |
+| SD - img2img | 2.98 | 3.58 | 3.83 | 4.67 |
+| SD - inpaint | 3.04 | 3.66 | 3.9 | 4.76 |
+| SD - controlnet | 2.15 | 2.58 | 2.74 | 3.35 |
+| IF | 8.78 | 9.82 | ❌ | 16.77 |
+| SDXL - txt2img | 0.64 | 0.72 | - | - |
+
+### V100 (batch size: 1)
+
+| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
+|:---:|:---:|:---:|:---:|:---:|
+| SD - txt2img | 18.99 | 19.14 | 20.95 | 22.17 |
+| SD - img2img | 18.56 | 19.18 | 20.95 | 22.11 |
+| SD - inpaint | 19.14 | 19.06 | 21.08 | 22.20 |
+| SD - controlnet | 13.48 | 13.93 | 15.18 | 15.88 |
+| IF | 20.01 / 9.08 / 23.34 | 19.79 / 8.98 / 24.10 | ❌ | 55.75 / 11.57 / 57.67 |
+
+### V100 (batch size: 4)
+
+| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
+|:---:|:---:|:---:|:---:|:---:|
+| SD - txt2img | 5.96 | 5.89 | 6.83 | 6.86 |
+| SD - img2img | 5.90 | 5.91 | 6.81 | 6.82 |
+| SD - inpaint | 5.99 | 6.03 | 6.93 | 6.95 |
+| SD - controlnet | 4.26 | 4.29 | 4.92 | 4.93 |
+| IF | 15.41 | 14.76 | ❌ | 22.95 |
+
+### V100 (batch size: 16)
+
+| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
+|:---:|:---:|:---:|:---:|:---:|
+| SD - txt2img | 1.66 | 1.66 | 1.92 | 1.90 |
+| SD - img2img | 1.65 | 1.65 | 1.91 | 1.89 |
+| SD - inpaint | 1.69 | 1.69 | 1.95 | 1.93 |
+| SD - controlnet | 1.19 | 1.19 | OOM after warmup | 1.36 |
+| IF | 5.43 | 5.29 | ❌ | 7.06 |
+
+### T4 (batch size: 1)
+
+| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
+|:---:|:---:|:---:|:---:|:---:|
+| SD - txt2img | 6.9 | 6.95 | 7.3 | 7.56 |
+| SD - img2img | 6.84 | 6.99 | 7.04 | 7.55 |
+| SD - inpaint | 6.91 | 6.7 | 7.01 | 7.37 |
+| SD - controlnet | 4.89 | 4.86 | 5.35 | 5.48 |
+| IF | 17.42 / 2.47 / 18.52 | 16.96 / 2.45 / 18.69 | ❌ | 24.63 / 2.47 / 23.39 |
+| SDXL - txt2img | 1.15 | 1.16 | - | - |
+
+### T4 (batch size: 4)
+
+| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
+|:---:|:---:|:---:|:---:|:---:|
+| SD - txt2img | 1.79 | 1.79 | 2.03 | 1.99 |
+| SD - img2img | 1.77 | 1.77 | 2.05 | 2.04 |
+| SD - inpaint | 1.81 | 1.82 | 2.09 | 2.09 |
+| SD - controlnet | 1.34 | 1.27 | 1.47 | 1.46 |
+| IF | 5.79 | 5.61 | ❌ | 7.39 |
+| SDXL - txt2img | 0.288 | 0.289 | - | - |
+
+### T4 (batch size: 16)
+
+| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
+|:---:|:---:|:---:|:---:|:---:|
+| SD - txt2img | 2.34s | 2.30s | OOM after 2nd iteration | 1.99s |
+| SD - img2img | 2.35s | 2.31s | OOM after warmup | 2.00s |
+| SD - inpaint | 2.30s | 2.26s | OOM after 2nd iteration | 1.95s |
+| SD - controlnet | OOM after 2nd iteration | OOM after 2nd iteration | OOM after warmup | OOM after warmup |
+| IF * | 1.44 | 1.44 | ❌ | 1.94 |
+| SDXL - txt2img | OOM | OOM | - | - |
+
+### RTX 3090 (batch size: 1)
+
+| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
+|:---:|:---:|:---:|:---:|:---:|
+| SD - txt2img | 22.56 | 22.84 | 23.84 | 25.69 |
+| SD - img2img | 22.25 | 22.61 | 24.1 | 25.83 |
+| SD - inpaint | 22.22 | 22.54 | 24.26 | 26.02 |
+| SD - controlnet | 16.03 | 16.33 | 17.38 | 18.56 |
+| IF | 27.08 / 9.07 / 31.23 | 26.75 / 8.92 / 31.47 | ❌ | 68.08 / 11.16 / 65.29 |
+
+### RTX 3090 (batch size: 4)
+
+| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
+|:---:|:---:|:---:|:---:|:---:|
+| SD - txt2img | 6.46 | 6.35 | 7.29 | 7.3 |
+| SD - img2img | 6.33 | 6.27 | 7.31 | 7.26 |
+| SD - inpaint | 6.47 | 6.4 | 7.44 | 7.39 |
+| SD - controlnet | 4.59 | 4.54 | 5.27 | 5.26 |
+| IF | 16.81 | 16.62 | ❌ | 21.57 |
+
+### RTX 3090 (batch size: 16)
+
+| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
+|:---:|:---:|:---:|:---:|:---:|
+| SD - txt2img | 1.7 | 1.69 | 1.93 | 1.91 |
+| SD - img2img | 1.68 | 1.67 | 1.93 | 1.9 |
+| SD - inpaint | 1.72 | 1.71 | 1.97 | 1.94 |
+| SD - controlnet | 1.23 | 1.22 | 1.4 | 1.38 |
+| IF | 5.01 | 5.00 | ❌ | 6.33 |
+
+### RTX 4090 (batch size: 1)
+
+| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
+|:---:|:---:|:---:|:---:|:---:|
+| SD - txt2img | 40.5 | 41.89 | 44.65 | 49.81 |
+| SD - img2img | 40.39 | 41.95 | 44.46 | 49.8 |
+| SD - inpaint | 40.51 | 41.88 | 44.58 | 49.72 |
+| SD - controlnet | 29.27 | 30.29 | 32.26 | 36.03 |
+| IF | 69.71 / 18.78 / 85.49 | 69.13 / 18.80 / 85.56 | ❌ | 124.60 / 26.37 / 138.79 |
+| SDXL - txt2img | 6.8 | 8.18 | - | - |
+
+### RTX 4090 (batch size: 4)
+
+| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
+|:---:|:---:|:---:|:---:|:---:|
+| SD - txt2img | 12.62 | 12.84 | 15.32 | 15.59 |
+| SD - img2img | 12.61 | 12.79 | 15.35 | 15.66 |
+| SD - inpaint | 12.65 | 12.81 | 15.3 | 15.58 |
+| SD - controlnet | 9.1 | 9.25 | 11.03 | 11.22 |
+| IF | 31.88 | 31.14 | ❌ | 43.92 |
+| SDXL - txt2img | 2.19 | 2.35 | - | - |
+
+### RTX 4090 (batch size: 16)
+
+| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
+|:---:|:---:|:---:|:---:|:---:|
+| SD - txt2img | 3.17 | 3.2 | 3.84 | 3.85 |
+| SD - img2img | 3.16 | 3.2 | 3.84 | 3.85 |
+| SD - inpaint | 3.17 | 3.2 | 3.85 | 3.85 |
+| SD - controlnet | 2.23 | 2.3 | 2.7 | 2.75 |
+| IF | 9.26 | 9.2 | ❌ | 13.31 |
+| SDXL - txt2img | 0.52 | 0.53 | - | - |
+
+## Notes
+
+* Follow this [PR](https://github.com/huggingface/diffusers/pull/3313) for more details on the environment used for conducting the benchmarks.
+* For the DeepFloyd IF pipelines with batch sizes > 1, we only used a batch size > 1 in the first IF pipeline (text-to-image generation) and NOT for upscaling; that means the two upscaling pipelines received a batch size of 1.
+
+*Thanks to [Horace He](https://github.com/Chillee) from the PyTorch team for their help in improving our support of `torch.compile()` in Diffusers.*
diff --git a/diffusers/docs/source/en/optimization/xformers.md b/diffusers/docs/source/en/optimization/xformers.md
new file mode 100644
index 0000000000000000000000000000000000000000..e5aa4d106ad2e211c0484acb7ebe588ea8e72ec9
--- /dev/null
+++ b/diffusers/docs/source/en/optimization/xformers.md
@@ -0,0 +1,35 @@
+
+
+# xFormers
+
+We recommend [xFormers](https://github.com/facebookresearch/xformers) for both inference and training. In our tests, the optimizations performed in the attention blocks allow for both faster speed and reduced memory consumption.
+
+Install xFormers from `pip`:
+
+```bash
+pip install xformers
+```
+
+
+
+The xFormers `pip` package requires the latest version of PyTorch. If you need to use a previous version of PyTorch, then we recommend [installing xFormers from the source](https://github.com/facebookresearch/xformers#installing-xformers).
+
+
+
+After xFormers is installed, you can use `enable_xformers_memory_efficient_attention()` for faster inference and reduced memory consumption as shown in this [section](memory#memory-efficient-attention).
+
+
+
+According to this [issue](https://github.com/huggingface/diffusers/issues/2234#issuecomment-1416931212), xFormers `v0.0.16` cannot be used for training (fine-tune or DreamBooth) in some GPUs. If you observe this problem, please install a development version as indicated in the issue comments.
+
+
diff --git a/diffusers/docs/source/en/quicktour.md b/diffusers/docs/source/en/quicktour.md
new file mode 100644
index 0000000000000000000000000000000000000000..89792d5c05b330b0a2dbea5067940ae8a785ab4b
--- /dev/null
+++ b/diffusers/docs/source/en/quicktour.md
@@ -0,0 +1,320 @@
+
+
+[[open-in-colab]]
+
+# Quicktour
+
+Diffusion models are trained to denoise random Gaussian noise step-by-step to generate a sample of interest, such as an image or audio. This has sparked a tremendous amount of interest in generative AI, and you have probably seen examples of diffusion-generated images on the internet. 🧨 Diffusers is a library aimed at making diffusion models widely accessible to everyone.
+
+Whether you're a developer or an everyday user, this quicktour will introduce you to 🧨 Diffusers and help you get up and generating quickly! There are three main components of the library to know about:
+
+* The [`DiffusionPipeline`] is a high-level end-to-end class designed to rapidly generate samples from pretrained diffusion models for inference.
+* Popular pretrained [model](./api/models) architectures and modules that can be used as building blocks for creating diffusion systems.
+* Many different [schedulers](./api/schedulers/overview) - algorithms that control how noise is added for training, and how to generate denoised images during inference.
+
+The quicktour will show you how to use the [`DiffusionPipeline`] for inference, and then walk you through how to combine a model and scheduler to replicate what's happening inside the [`DiffusionPipeline`].
+
+
+
+The quicktour is a simplified version of the introductory 🧨 Diffusers [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/diffusers_intro.ipynb) to help you get started quickly. If you want to learn more about 🧨 Diffusers' goal, design philosophy, and additional details about its core API, check out the notebook!
+
+
+
+Before you begin, make sure you have all the necessary libraries installed:
+
+```py
+# uncomment to install the necessary libraries in Colab
+#!pip install --upgrade diffusers accelerate transformers
+```
+
+- [🤗 Accelerate](https://huggingface.co/docs/accelerate/index) speeds up model loading for inference and training.
+- [🤗 Transformers](https://huggingface.co/docs/transformers/index) is required to run the most popular diffusion models, such as [Stable Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview).
+
+## DiffusionPipeline
+
+The [`DiffusionPipeline`] is the easiest way to use a pretrained diffusion system for inference. It is an end-to-end system containing the model and the scheduler. You can use the [`DiffusionPipeline`] out-of-the-box for many tasks. Take a look at the table below for some supported tasks, and for a complete list of supported tasks, check out the [🧨 Diffusers Summary](./api/pipelines/overview#diffusers-summary) table.
+
+| **Task** | **Description** | **Pipeline** |
+|------------------------------|--------------------------------------------------------------------------------------------------------------|-----------------|
+| Unconditional Image Generation | generate an image from Gaussian noise | [unconditional_image_generation](./using-diffusers/unconditional_image_generation) |
+| Text-Guided Image Generation | generate an image given a text prompt | [conditional_image_generation](./using-diffusers/conditional_image_generation) |
+| Text-Guided Image-to-Image Translation | adapt an image guided by a text prompt | [img2img](./using-diffusers/img2img) |
+| Text-Guided Image-Inpainting | fill the masked part of an image given the image, the mask and a text prompt | [inpaint](./using-diffusers/inpaint) |
+| Text-Guided Depth-to-Image Translation | adapt parts of an image guided by a text prompt while preserving structure via depth estimation | [depth2img](./using-diffusers/depth2img) |
+
+Start by creating an instance of a [`DiffusionPipeline`] and specify which pipeline checkpoint you would like to download.
+You can use the [`DiffusionPipeline`] for any [checkpoint](https://huggingface.co/models?library=diffusers&sort=downloads) stored on the Hugging Face Hub.
+In this quicktour, you'll load the [`stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) checkpoint for text-to-image generation.
+
+
+
+For [Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion) models, please carefully read the [license](https://huggingface.co/spaces/CompVis/stable-diffusion-license) first before running the model. 🧨 Diffusers implements a [`safety_checker`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py) to prevent offensive or harmful content, but the model's improved image generation capabilities can still produce potentially harmful content.
+
+
+
+Load the model with the [`~DiffusionPipeline.from_pretrained`] method:
+
+```python
+>>> from diffusers import DiffusionPipeline
+
+>>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_safetensors=True)
+```
+
+The [`DiffusionPipeline`] downloads and caches all modeling, tokenization, and scheduling components. You'll see that the Stable Diffusion pipeline is composed of the [`UNet2DConditionModel`] and [`PNDMScheduler`] among other things:
+
+```py
+>>> pipeline
+StableDiffusionPipeline {
+ "_class_name": "StableDiffusionPipeline",
+ "_diffusers_version": "0.21.4",
+ ...,
+ "scheduler": [
+ "diffusers",
+ "PNDMScheduler"
+ ],
+ ...,
+ "unet": [
+ "diffusers",
+ "UNet2DConditionModel"
+ ],
+ "vae": [
+ "diffusers",
+ "AutoencoderKL"
+ ]
+}
+```
+
+We strongly recommend running the pipeline on a GPU because the model consists of roughly 1.4 billion parameters.
+You can move the generator object to a GPU, just like you would in PyTorch:
+
+```python
+>>> pipeline.to("cuda")
+```
+
+Now you can pass a text prompt to the `pipeline` to generate an image, and then access the denoised image. By default, the image output is wrapped in a [`PIL.Image`](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class) object.
+
+```python
+>>> image = pipeline("An image of a squirrel in Picasso style").images[0]
+>>> image
+```
+
+
+
+
+
+Save the image by calling `save`:
+
+```python
+>>> image.save("image_of_squirrel_painting.png")
+```
+
+### Local pipeline
+
+You can also use the pipeline locally. The only difference is you need to download the weights first:
+
+```bash
+!git lfs install
+!git clone https://huggingface.co/runwayml/stable-diffusion-v1-5
+```
+
+Then load the saved weights into the pipeline:
+
+```python
+>>> pipeline = DiffusionPipeline.from_pretrained("./stable-diffusion-v1-5", use_safetensors=True)
+```
+
+Now, you can run the pipeline as you would in the section above.
+
+### Swapping schedulers
+
+Different schedulers come with different denoising speeds and quality trade-offs. The best way to find out which one works best for you is to try them out! One of the main features of 🧨 Diffusers is to allow you to easily switch between schedulers. For example, to replace the default [`PNDMScheduler`] with the [`EulerDiscreteScheduler`], load it with the [`~diffusers.ConfigMixin.from_config`] method:
+
+```py
+>>> from diffusers import EulerDiscreteScheduler
+
+>>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_safetensors=True)
+>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
+```
+
+Try generating an image with the new scheduler and see if you notice a difference!
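+
+For example (a short sketch; remember to move the pipeline to a GPU first):
+
+```py
+>>> pipeline.to("cuda")
+>>> image = pipeline("An image of a squirrel in Picasso style").images[0]
+>>> image
+```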
+
+In the next section, you'll take a closer look at the components - the model and scheduler - that make up the [`DiffusionPipeline`] and learn how to use these components to generate an image of a cat.
+
+## Models
+
+Most models take a noisy sample and, at each timestep, predict the *noise residual* (other models learn to predict the previous sample directly, or the velocity or [`v-prediction`](https://github.com/huggingface/diffusers/blob/5e5ce13e2f89ac45a0066cb3f369462a3cf1d9ef/src/diffusers/schedulers/scheduling_ddim.py#L110)), which is the difference between a less noisy image and the input image. You can mix and match models to create other diffusion systems.
+
+Models are initiated with the [`~ModelMixin.from_pretrained`] method which also locally caches the model weights so it is faster the next time you load the model. For the quicktour, you'll load the [`UNet2DModel`], a basic unconditional image generation model with a checkpoint trained on cat images:
+
+```py
+>>> from diffusers import UNet2DModel
+
+>>> repo_id = "google/ddpm-cat-256"
+>>> model = UNet2DModel.from_pretrained(repo_id, use_safetensors=True)
+```
+
+To access the model parameters, call `model.config`:
+
+```py
+>>> model.config
+```
+
+The model configuration is a 🧊 frozen 🧊 dictionary, which means those parameters can't be changed after the model is created. This is intentional and ensures that the parameters used to define the model architecture at the start remain the same, while other parameters can still be adjusted during inference.
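+
+For instance, individual entries can be read from the frozen config as attributes; the values shown below are for this checkpoint and are only meant as an illustration:
+
+```py
+>>> model.config.sample_size
+256
+>>> model.config.in_channels
+3
+```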
+
+Some of the most important parameters are:
+
+* `sample_size`: the height and width dimension of the input sample.
+* `in_channels`: the number of input channels of the input sample.
+* `down_block_types` and `up_block_types`: the type of down- and upsampling blocks used to create the UNet architecture.
+* `block_out_channels`: the number of output channels of the downsampling blocks; also used in reverse order for the number of input channels of the upsampling blocks.
+* `layers_per_block`: the number of ResNet blocks present in each UNet block.
+
+To use the model for inference, create a random Gaussian noise sample with the shape of the desired image. It should have a `batch` axis because the model can receive multiple random noises, a `channel` axis corresponding to the number of input channels, and a `sample_size` axis for the height and width of the image:
+
+```py
+>>> import torch
+
+>>> torch.manual_seed(0)
+
+>>> noisy_sample = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
+>>> noisy_sample.shape
+torch.Size([1, 3, 256, 256])
+```
+
+For inference, pass the noisy image and a `timestep` to the model. The `timestep` indicates how noisy the input image is, with more noise at the beginning and less at the end. This helps the model determine its position in the diffusion process, whether it is closer to the start or the end. Access the `sample` attribute of the returned output to get the model's prediction:
+
+```py
+>>> with torch.no_grad():
+... noisy_residual = model(sample=noisy_sample, timestep=2).sample
+```
+
+To generate actual examples though, you'll need a scheduler to guide the denoising process. In the next section, you'll learn how to couple a model with a scheduler.
+
+## Schedulers
+
+Schedulers manage going from a noisy sample to a less noisy sample given the model output - in this case, it is the `noisy_residual`.
+
+
+
+🧨 Diffusers is a toolbox for building diffusion systems. While the [`DiffusionPipeline`] is a convenient way to get started with a pre-built diffusion system, you can also choose your own model and scheduler components separately to build a custom diffusion system.
+
+
+
+For the quicktour, you'll instantiate the [`DDPMScheduler`] with its [`~diffusers.ConfigMixin.from_config`] method:
+
+```py
+>>> from diffusers import DDPMScheduler
+
+>>> scheduler = DDPMScheduler.from_pretrained(repo_id)
+>>> scheduler
+DDPMScheduler {
+ "_class_name": "DDPMScheduler",
+ "_diffusers_version": "0.21.4",
+ "beta_end": 0.02,
+ "beta_schedule": "linear",
+ "beta_start": 0.0001,
+ "clip_sample": true,
+ "clip_sample_range": 1.0,
+ "dynamic_thresholding_ratio": 0.995,
+ "num_train_timesteps": 1000,
+ "prediction_type": "epsilon",
+ "sample_max_value": 1.0,
+ "steps_offset": 0,
+ "thresholding": false,
+ "timestep_spacing": "leading",
+ "trained_betas": null,
+ "variance_type": "fixed_small"
+}
+```
+
+
+
+💡 Unlike a model, a scheduler does not have trainable weights and is parameter-free!
+
+
+
+Some of the most important parameters are:
+
+* `num_train_timesteps`: the length of the denoising process or, in other words, the number of timesteps required to process random Gaussian noise into a data sample.
+* `beta_schedule`: the type of noise schedule to use for inference and training.
+* `beta_start` and `beta_end`: the start and end noise values for the noise schedule.
+
+To predict a slightly less noisy image, pass the following to the scheduler's [`~diffusers.DDPMScheduler.step`] method: model output, `timestep`, and current `sample`.
+
+```py
+>>> less_noisy_sample = scheduler.step(model_output=noisy_residual, timestep=2, sample=noisy_sample).prev_sample
+>>> less_noisy_sample.shape
+torch.Size([1, 3, 256, 256])
+```
+
+The `less_noisy_sample` can be passed to the next `timestep` where it'll get even less noisy! Let's bring it all together now and visualize the entire denoising process.
+
+First, create a function that postprocesses and displays the denoised image as a `PIL.Image`:
+
+```py
+>>> import PIL.Image
+>>> import numpy as np
+
+
+>>> def display_sample(sample, i):
+... image_processed = sample.cpu().permute(0, 2, 3, 1)
+... image_processed = (image_processed + 1.0) * 127.5
+... image_processed = image_processed.numpy().astype(np.uint8)
+
+... image_pil = PIL.Image.fromarray(image_processed[0])
+... display(f"Image at step {i}")  # display() is available in notebook environments such as Jupyter or Colab
+... display(image_pil)
+```
+
+To speed up the denoising process, move the input and model to a GPU:
+
+```py
+>>> model.to("cuda")
+>>> noisy_sample = noisy_sample.to("cuda")
+```
+
+Now create a denoising loop that predicts the residual of the less noisy sample, and computes the less noisy sample with the scheduler:
+
+```py
+>>> import tqdm
+
+>>> sample = noisy_sample
+
+>>> for i, t in enumerate(tqdm.tqdm(scheduler.timesteps)):
+... # 1. predict noise residual
+... with torch.no_grad():
+... residual = model(sample, t).sample
+
+... # 2. compute less noisy image and set x_t -> x_t-1
+... sample = scheduler.step(residual, t, sample).prev_sample
+
+... # 3. optionally look at image
+... if (i + 1) % 50 == 0:
+... display_sample(sample, i + 1)
+```
+
+Sit back and watch as a cat is generated from nothing but noise! 😻
+
+
+
+
+
+## Next steps
+
+Hopefully, you generated some cool images with 🧨 Diffusers in this quicktour! For your next steps, you can:
+
+* Train or finetune a model to generate your own images in the [training](./tutorials/basic_training) tutorial.
+* See example official and community [training or finetuning scripts](https://github.com/huggingface/diffusers/tree/main/examples#-diffusers-examples) for a variety of use cases.
+* Learn more about loading, accessing, changing, and comparing schedulers in the [Using different Schedulers](./using-diffusers/schedulers) guide.
+* Explore prompt engineering, speed and memory optimizations, and tips and tricks for generating higher-quality images with the [Stable Diffusion](./stable_diffusion) guide.
+* Dive deeper into speeding up 🧨 Diffusers with guides on [optimized PyTorch on a GPU](./optimization/fp16), and inference guides for running [Stable Diffusion on Apple Silicon (M1/M2)](./optimization/mps) and [ONNX Runtime](./optimization/onnx).
diff --git a/diffusers/docs/source/en/stable_diffusion.md b/diffusers/docs/source/en/stable_diffusion.md
new file mode 100644
index 0000000000000000000000000000000000000000..c0298eeeb3c1d5cd6669aed9a796e63b9ea7ae9b
--- /dev/null
+++ b/diffusers/docs/source/en/stable_diffusion.md
@@ -0,0 +1,261 @@
+
+
+# Effective and efficient diffusion
+
+[[open-in-colab]]
+
+Getting the [`DiffusionPipeline`] to generate images in a certain style or include what you want can be tricky. Oftentimes, you have to run the [`DiffusionPipeline`] several times before you end up with an image you're happy with. But generating something out of nothing is a computationally intensive process, especially if you're running inference over and over again.
+
+This is why it's important to get the most *computational* (speed) and *memory* (GPU vRAM) efficiency from the pipeline to reduce the time between inference cycles so you can iterate faster.
+
+This tutorial walks you through how to generate faster and better with the [`DiffusionPipeline`].
+
+Begin by loading the [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) model:
+
+```python
+from diffusers import DiffusionPipeline
+
+model_id = "runwayml/stable-diffusion-v1-5"
+pipeline = DiffusionPipeline.from_pretrained(model_id, use_safetensors=True)
+```
+
+The example prompt you'll use is a portrait of an old warrior chief, but feel free to use your own prompt:
+
+```python
+prompt = "portrait photo of a old warrior chief"
+```
+
+## Speed
+
+
+
+💡 If you don't have access to a GPU, you can use one for free from a GPU provider like [Colab](https://colab.research.google.com/)!
+
+
+
+One of the simplest ways to speed up inference is to place the pipeline on a GPU the same way you would with any PyTorch module:
+
+```python
+pipeline = pipeline.to("cuda")
+```
+
+To make sure you can use the same image and improve on it, use a [`Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) and set a seed for [reproducibility](./using-diffusers/reproducibility):
+
+```python
+import torch
+
+generator = torch.Generator("cuda").manual_seed(0)
+```
+
+Now you can generate an image:
+
+```python
+image = pipeline(prompt, generator=generator).images[0]
+image
+```
+
+
+
+
+
+This process took ~30 seconds on a T4 GPU (it might be faster if your allocated GPU is better than a T4). By default, the [`DiffusionPipeline`] runs inference with full `float32` precision for 50 inference steps. You can speed this up by switching to a lower precision like `float16` or running fewer inference steps.
+
+Let's start by loading the model in `float16` and generate an image:
+
+```python
+import torch
+
+pipeline = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, use_safetensors=True)
+pipeline = pipeline.to("cuda")
+generator = torch.Generator("cuda").manual_seed(0)
+image = pipeline(prompt, generator=generator).images[0]
+image
+```
+
+
+
+
+
+This time, it only took ~11 seconds to generate the image, which is almost 3x faster than before!
+
+
+
+💡 We strongly suggest always running your pipelines in `float16`, and so far, we've rarely seen any degradation in output quality.
+
+
+
+Another option is to reduce the number of inference steps. Choosing a more efficient scheduler could help decrease the number of steps without sacrificing output quality. You can find which schedulers are compatible with the current model in the [`DiffusionPipeline`] by calling the `compatibles` method:
+
+```python
+pipeline.scheduler.compatibles
+[
+ diffusers.schedulers.scheduling_lms_discrete.LMSDiscreteScheduler,
+ diffusers.schedulers.scheduling_unipc_multistep.UniPCMultistepScheduler,
+ diffusers.schedulers.scheduling_k_dpm_2_discrete.KDPM2DiscreteScheduler,
+ diffusers.schedulers.scheduling_deis_multistep.DEISMultistepScheduler,
+ diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler,
+ diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler,
+ diffusers.schedulers.scheduling_ddpm.DDPMScheduler,
+ diffusers.schedulers.scheduling_dpmsolver_singlestep.DPMSolverSinglestepScheduler,
+ diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteScheduler,
+ diffusers.utils.dummy_torch_and_torchsde_objects.DPMSolverSDEScheduler,
+ diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler,
+ diffusers.schedulers.scheduling_pndm.PNDMScheduler,
+ diffusers.schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteScheduler,
+ diffusers.schedulers.scheduling_ddim.DDIMScheduler,
+]
+```
+
+The Stable Diffusion model uses the [`PNDMScheduler`] by default, which usually requires ~50 inference steps, but more performant schedulers like the [`DPMSolverMultistepScheduler`] require only ~20 or 25 inference steps. Use the [`~ConfigMixin.from_config`] method to load a new scheduler:
+
+```python
+from diffusers import DPMSolverMultistepScheduler
+
+pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+```
+
+Now set the `num_inference_steps` to 20:
+
+```python
+generator = torch.Generator("cuda").manual_seed(0)
+image = pipeline(prompt, generator=generator, num_inference_steps=20).images[0]
+image
+```
+
+
+
+
+
+Great, you've managed to cut the inference time to just 4 seconds! ⚡️
+
+## Memory
+
+The other key to improving pipeline performance is consuming less memory, which indirectly implies more speed, since you're often trying to maximize the number of images generated per second. The easiest way to see how many images you can generate at once is to try out different batch sizes until you get an `OutOfMemoryError` (OOM).
+
+Create a function that'll generate a batch of images from a list of prompts and `Generators`. Make sure to assign each `Generator` a seed so you can reuse it if it produces a good result.
+
+```python
+def get_inputs(batch_size=1):
+ generator = [torch.Generator("cuda").manual_seed(i) for i in range(batch_size)]
+ prompts = batch_size * [prompt]
+ num_inference_steps = 20
+
+ return {"prompt": prompts, "generator": generator, "num_inference_steps": num_inference_steps}
+```
+
+Start with `batch_size=4` and see how much memory you've consumed:
+
+```python
+from diffusers.utils import make_image_grid
+
+images = pipeline(**get_inputs(batch_size=4)).images
+make_image_grid(images, 2, 2)
+```
+
+Unless you have a GPU with more vRAM, the code above probably returned an `OOM` error! Most of the memory is taken up by the cross-attention layers. Instead of running this operation in a batch, you can run it sequentially to save a significant amount of memory. All you have to do is configure the pipeline to use the [`~DiffusionPipeline.enable_attention_slicing`] function:
+
+```python
+pipeline.enable_attention_slicing()
+```
+
+Now try increasing the `batch_size` to 8!
+
+```python
+images = pipeline(**get_inputs(batch_size=8)).images
+make_image_grid(images, rows=2, cols=4)
+```
+
+
+
+
+
+Whereas before you couldn't even generate a batch of 4 images, now you can generate a batch of 8 images at ~3.5 seconds per image! This is probably the fastest you can go on a T4 GPU without sacrificing quality.
+
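+If you want to automate this search for the largest workable batch size, a rough probing loop along the following lines can help. This is a sketch rather than part of the original tutorial: it assumes the `pipeline` and `get_inputs` defined above, and it assumes a PyTorch version that exposes `torch.cuda.OutOfMemoryError` (on older versions, catch `RuntimeError` instead).
+
+```python
+batch_size = 16
+while batch_size >= 1:
+    try:
+        # reuse the helper from this tutorial to build a batch of prompts and generators
+        images = pipeline(**get_inputs(batch_size=batch_size)).images
+        print(f"batch_size={batch_size} fits in memory")
+        break
+    except torch.cuda.OutOfMemoryError:
+        # free cached blocks before retrying with a smaller batch
+        torch.cuda.empty_cache()
+        batch_size //= 2
+```
+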
+## Quality
+
+In the last two sections, you learned how to optimize the speed of your pipeline by using `fp16`, reducing the number of inference steps by using a more performant scheduler, and enabling attention slicing to reduce memory consumption. Now you're going to focus on how to improve the quality of generated images.
+
+### Better checkpoints
+
+The most obvious step is to use better checkpoints. The Stable Diffusion model is a good starting point, and since its official launch, several improved versions have also been released. However, using a newer version doesn't automatically mean you'll get better results. You'll still have to experiment with different checkpoints yourself, and do a little research (such as using [negative prompts](https://minimaxir.com/2022/11/stable-diffusion-negative-prompt/)) to get the best results.
+
+As the field grows, there are more and more high-quality checkpoints finetuned to produce certain styles. Try exploring the [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) and [Diffusers Gallery](https://huggingface.co/spaces/huggingface-projects/diffusers-gallery) to find one you're interested in!
+
+### Better pipeline components
+
+You can also try replacing the current pipeline components with a newer version. Let's try loading the latest [autoencoder](https://huggingface.co/stabilityai/stable-diffusion-2-1/tree/main/vae) from Stability AI into the pipeline, and generate some images:
+
+```python
+from diffusers import AutoencoderKL
+
+vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16).to("cuda")
+pipeline.vae = vae
+images = pipeline(**get_inputs(batch_size=8)).images
+make_image_grid(images, rows=2, cols=4)
+```
+
+
+
+
+
+### Better prompt engineering
+
+The text prompt you use to generate an image is super important, so much so that it is called *prompt engineering*. Some considerations to keep in mind during prompt engineering are:
+
+- How are images similar to the one I want to generate stored on the internet?
+- What additional detail can I give that steers the model towards the style I want?
+
+With this in mind, let's improve the prompt to include color and higher quality details:
+
+```python
+prompt += ", tribal panther make up, blue on red, side profile, looking away, serious eyes"
+prompt += " 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta"
+```
+
+Generate a batch of images with the new prompt:
+
+```python
+images = pipeline(**get_inputs(batch_size=8)).images
+make_image_grid(images, rows=2, cols=4)
+```
+
+
+
+
+
+Pretty impressive! Let's tweak the second image - corresponding to the `Generator` with a seed of `1` - a bit more by adding some text about the age of the subject:
+
+```python
+prompts = [
+ "portrait photo of the oldest warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
+ "portrait photo of a old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
+ "portrait photo of a warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
+ "portrait photo of a young warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
+]
+
+generator = [torch.Generator("cuda").manual_seed(1) for _ in range(len(prompts))]
+images = pipeline(prompt=prompts, generator=generator, num_inference_steps=25).images
+make_image_grid(images, 2, 2)
+```
+
+
+
+
+
+## Next steps
+
+In this tutorial, you learned how to optimize a [`DiffusionPipeline`] for computational and memory efficiency as well as improving the quality of generated outputs. If you're interested in making your pipeline even faster, take a look at the following resources:
+
+- Learn how [PyTorch 2.0](./optimization/torch2.0) and [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html) can yield 5 - 300% faster inference speed. On an A100 GPU, inference can be up to 50% faster!
+- If you can't use PyTorch 2, we recommend you install [xFormers](./optimization/xformers). Its memory-efficient attention mechanism works great with PyTorch 1.13.1 for faster speed and reduced memory consumption.
+- Other optimization techniques, such as model offloading, are covered in [this guide](./optimization/fp16). A brief sketch combining a couple of these options is shown below.
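+
+As a brief sketch of a couple of these options (not part of the original tutorial, and actual speedups depend on your GPU): with PyTorch 2.0 you can compile the UNet, and on PyTorch 1.13.x you could enable xFormers attention instead.
+
+```python
+import torch
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
+).to("cuda")
+
+# PyTorch 2.0: compile the UNet; the first call is slower while the graph is compiled
+pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
+
+# PyTorch 1.13.x alternative (requires xformers to be installed):
+# pipeline.enable_xformers_memory_efficient_attention()
+
+image = pipeline("portrait photo of a old warrior chief").images[0]
+```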
diff --git a/diffusers/docs/source/en/training/adapt_a_model.md b/diffusers/docs/source/en/training/adapt_a_model.md
new file mode 100644
index 0000000000000000000000000000000000000000..57bc1a37e05be78149810c73586e63a393b6e341
--- /dev/null
+++ b/diffusers/docs/source/en/training/adapt_a_model.md
@@ -0,0 +1,47 @@
+# Adapt a model to a new task
+
+Many diffusion systems share the same components, allowing you to adapt a pretrained model for one task to an entirely different task.
+
+This guide will show you how to adapt a pretrained text-to-image model for inpainting by initializing and modifying the architecture of a pretrained [`UNet2DConditionModel`].
+
+## Configure UNet2DConditionModel parameters
+
+A [`UNet2DConditionModel`] by default accepts 4 channels in the [input sample](https://huggingface.co/docs/diffusers/v0.16.0/en/api/models#diffusers.UNet2DConditionModel.in_channels). For example, load a pretrained text-to-image model like [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) and take a look at the number of `in_channels`:
+
+```py
+from diffusers import StableDiffusionPipeline
+
+pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_safetensors=True)
+pipeline.unet.config["in_channels"]
+4
+```
+
+Inpainting requires 9 channels in the input sample. You can check this value in a pretrained inpainting model like [`runwayml/stable-diffusion-inpainting`](https://huggingface.co/runwayml/stable-diffusion-inpainting):
+
+```py
+from diffusers import StableDiffusionPipeline
+
+pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-inpainting", use_safetensors=True)
+pipeline.unet.config["in_channels"]
+9
+```
+
+To adapt your text-to-image model for inpainting, you'll need to change the number of `in_channels` from 4 to 9.
+
+Initialize a [`UNet2DConditionModel`] with the pretrained text-to-image model weights, and change `in_channels` to 9. Changing the number of `in_channels` means you need to set `ignore_mismatched_sizes=True` and `low_cpu_mem_usage=False` to avoid a size mismatch error because the shape is different now.
+
+```py
+from diffusers import UNet2DConditionModel
+
+model_id = "runwayml/stable-diffusion-v1-5"
+unet = UNet2DConditionModel.from_pretrained(
+ model_id,
+ subfolder="unet",
+ in_channels=9,
+ low_cpu_mem_usage=False,
+ ignore_mismatched_sizes=True,
+ use_safetensors=True,
+)
+```
+
+The pretrained weights of the other components from the text-to-image model are initialized from their checkpoints, but the input channel weights (`conv_in.weight`) of the `unet` are randomly initialized. It is important to finetune the model for inpainting because otherwise the model returns noise.
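+
+As a quick sanity check (not part of the original guide), you can inspect the first convolution of the adapted UNet. The `320` below assumes the default `block_out_channels` of the Stable Diffusion v1-5 UNet:
+
+```py
+# the input conv now expects 9 channels; its weights are randomly initialized
+print(unet.conv_in.weight.shape)
+# torch.Size([320, 9, 3, 3])
+```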
diff --git a/diffusers/docs/source/en/training/controlnet.md b/diffusers/docs/source/en/training/controlnet.md
new file mode 100644
index 0000000000000000000000000000000000000000..4be2cbc932528b6a83a2eaa0d0e6be90af7075d8
--- /dev/null
+++ b/diffusers/docs/source/en/training/controlnet.md
@@ -0,0 +1,366 @@
+
+
+# ControlNet
+
+[ControlNet](https://hf.co/papers/2302.05543) models are adapters trained on top of another pretrained model. It allows for a greater degree of control over image generation by conditioning the model with an additional input image. The input image can be a canny edge, depth map, human pose, and many more.
+
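+For example, a canny edge conditioning image can be produced with OpenCV along the following lines. This is only an illustrative sketch and not part of the training script; the file names and thresholds are placeholders:
+
+```py
+import cv2
+import numpy as np
+from PIL import Image
+
+image = np.array(Image.open("input.png").convert("RGB"))
+edges = cv2.Canny(image, 100, 200)  # placeholder low/high thresholds
+canny_image = Image.fromarray(np.stack([edges] * 3, axis=-1))  # replicate edges to 3 channels
+canny_image.save("conditioning.png")
+```
+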
+If you're training on a GPU with limited vRAM, you should try enabling the `gradient_checkpointing`, `gradient_accumulation_steps`, and `mixed_precision` parameters in the training command. You can also reduce your memory footprint by using memory-efficient attention with [xFormers](../optimization/xformers). JAX/Flax training is also supported for efficient training on TPUs and GPUs, but it doesn't support gradient checkpointing or xFormers. You should have a GPU with >30GB of memory if you want to train faster with Flax.
+
+This guide will explore the [train_controlnet.py](https://github.com/huggingface/diffusers/blob/main/examples/controlnet/train_controlnet.py) training script to help you become familiar with it and with how you can adapt it for your own use case.
+
+Before running the script, make sure you install the library from source:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Then navigate to the example folder containing the training script and install the required dependencies for the script you're using:
+
+
+
+```bash
+cd examples/controlnet
+pip install -r requirements.txt
+```
+
+
+
+If you have access to a TPU, the Flax training script runs even faster! Let's run the training script on the [Google Cloud TPU VM](https://cloud.google.com/tpu/docs/run-calculation-jax). Create a single TPU v4-8 VM and connect to it:
+
+```bash
+ZONE=us-central2-b
+TPU_TYPE=v4-8
+VM_NAME=hg_flax
+
+gcloud alpha compute tpus tpu-vm create $VM_NAME \
+ --zone $ZONE \
+ --accelerator-type $TPU_TYPE \
+ --version tpu-vm-v4-base
+
+gcloud alpha compute tpus tpu-vm ssh $VM_NAME --zone $ZONE
+```
+
+Install JAX 0.4.5:
+
+```bash
+pip install "jax[tpu]==0.4.5" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+```
+
+Then install the required dependencies for the Flax script:
+
+```bash
+cd examples/controlnet
+pip install -r requirements_flax.txt
+```
+
+
+
+
+
+
+🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.
+
+
+
+Initialize an 🤗 Accelerate environment:
+
+```bash
+accelerate config
+```
+
+To set up a default 🤗 Accelerate environment without choosing any configurations:
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell, like a notebook, you can use:
+
+```py
+from accelerate.utils import write_basic_config
+
+write_basic_config()
+```
+
+Lastly, if you want to train a model on your own dataset, take a look at the [Create a dataset for training](create_dataset) guide to learn how to create a dataset that works with the training script.
+
+
+
+The following sections highlight parts of the training script that are important for understanding how to modify it, but they don't cover every aspect of the script in detail. If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/controlnet/train_controlnet.py) and let us know if you have any questions or concerns.
+
+
+
+## Script parameters
+
+The training script provides many parameters to help you customize your training run. All of the parameters and their descriptions are found in the [`parse_args()`](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/controlnet/train_controlnet.py#L231) function. This function provides default values for each parameter, such as the training batch size and learning rate, but you can also set your own values in the training command if you'd like.
+
+For example, to speedup training with mixed precision using the fp16 format, add the `--mixed_precision` parameter to the training command:
+
+```bash
+accelerate launch train_controlnet.py \
+ --mixed_precision="fp16"
+```
+
+Many of the basic and important parameters are described in the [Text-to-image](text2image#script-parameters) training guide, so this guide just focuses on the relevant parameters for ControlNet:
+
+- `--max_train_samples`: the number of training samples; this can be lowered for faster training, but if you want to stream really large datasets, you'll need to include this parameter and the `--streaming` parameter in your training command
+- `--gradient_accumulation_steps`: number of update steps to accumulate before the backward pass; this allows you to train with a bigger batch size than your GPU memory can typically handle
+
+### Min-SNR weighting
+
+The [Min-SNR](https://huggingface.co/papers/2303.09556) weighting strategy can help with training by rebalancing the loss to achieve faster convergence. The training script supports predicting `epsilon` (noise) or `v_prediction`, but Min-SNR is compatible with both prediction types. This weighting strategy is only supported by PyTorch and is unavailable in the Flax training script.
+
+Add the `--snr_gamma` parameter and set it to the recommended value of 5.0:
+
+```bash
+accelerate launch train_controlnet.py \
+ --snr_gamma=5.0
+```
+
+## Training script
+
+As with the script parameters, a general walkthrough of the training script is provided in the [Text-to-image](text2image#training-script) training guide. Instead, this guide takes a look at the relevant parts of the ControlNet script.
+
+The training script has a [`make_train_dataset`](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/controlnet/train_controlnet.py#L582) function for preprocessing the dataset with image transforms and caption tokenization. You'll see that in addition to the usual caption tokenization and image transforms, the script also includes transforms for the conditioning image.
+
+
+
+If you're streaming a dataset on a TPU, performance may be bottlenecked by the 🤗 Datasets library which is not optimized for images. To ensure maximum throughput, you're encouraged to explore other dataset formats like [WebDataset](https://webdataset.github.io/webdataset/), [TorchData](https://github.com/pytorch/data), and [TensorFlow Datasets](https://www.tensorflow.org/datasets/tfless_tfds).
+
+
+
+```py
+conditioning_image_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution),
+ transforms.ToTensor(),
+ ]
+)
+```
+
+Within the [`main()`](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/controlnet/train_controlnet.py#L713) function, you'll find the code for loading the tokenizer, text encoder, scheduler and models. This is also where the ControlNet model is loaded either from existing weights or randomly initialized from a UNet:
+
+```py
+if args.controlnet_model_name_or_path:
+ logger.info("Loading existing controlnet weights")
+ controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path)
+else:
+ logger.info("Initializing controlnet weights from unet")
+ controlnet = ControlNetModel.from_unet(unet)
+```
+
+The [optimizer](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/controlnet/train_controlnet.py#L871) is set up to update the ControlNet parameters:
+
+```py
+params_to_optimize = controlnet.parameters()
+optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+)
+```
+
+Finally, in the [training loop](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/controlnet/train_controlnet.py#L943), the conditioning text embeddings and image are passed to the down and mid-blocks of the ControlNet model:
+
+```py
+encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
+
+down_block_res_samples, mid_block_res_sample = controlnet(
+ noisy_latents,
+ timesteps,
+ encoder_hidden_states=encoder_hidden_states,
+ controlnet_cond=controlnet_image,
+ return_dict=False,
+)
+```
+
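+The residuals returned by the ControlNet are then added to the frozen UNet's down and mid blocks when predicting the noise. A minimal sketch of that next step, paraphrased from the same training loop (variable names as above), looks like this:
+
+```py
+model_pred = unet(
+    noisy_latents,
+    timesteps,
+    encoder_hidden_states=encoder_hidden_states,
+    down_block_additional_residuals=[
+        sample.to(dtype=weight_dtype) for sample in down_block_res_samples
+    ],
+    mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
+).sample
+```
+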
+If you want to learn more about how the training loop works, check out the [Understanding pipelines, models and schedulers](../using-diffusers/write_own_pipeline) tutorial which breaks down the basic pattern of the denoising process.
+
+## Launch the script
+
+Now you're ready to launch the training script! 🚀
+
+This guide uses the [fusing/fill50k](https://huggingface.co/datasets/fusing/fill50k) dataset, but remember, you can create and use your own dataset if you want (see the [Create a dataset for training](create_dataset) guide).
+
+Set the environment variable `MODEL_NAME` to a model id on the Hub or a path to a local model and `OUTPUT_DIR` to where you want to save the model.
+
+Download the following images to condition your training with:
+
+```bash
+wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png
+wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
+```
+
+One more thing before you launch the script! Depending on the GPU you have, you may need to enable certain optimizations to train a ControlNet. The default configuration in this script requires ~38GB of vRAM. If you're training on more than one GPU, add the `--multi_gpu` parameter to the `accelerate launch` command.
+
+
+
+
+On a 16GB GPU, you can use the bitsandbytes 8-bit optimizer and gradient checkpointing to optimize your training run. Install bitsandbytes:
+
+```bash
+pip install bitsandbytes
+```
+
+Then, add the following parameter to your training command:
+
+```bash
+accelerate launch train_controlnet.py \
+ --gradient_checkpointing \
+ --use_8bit_adam \
+```
+
+
+
+
+On a 12GB GPU, you'll need the bitsandbytes 8-bit optimizer, gradient checkpointing, xFormers, and you'll need to set the gradients to `None` instead of zero to reduce memory usage.
+
+```bash
+accelerate launch train_controlnet.py \
+ --use_8bit_adam \
+ --gradient_checkpointing \
+ --enable_xformers_memory_efficient_attention \
+ --set_grads_to_none \
+```
+
+
+
+
+On an 8GB GPU, you'll need to use [DeepSpeed](https://www.deepspeed.ai/) to offload some of the tensors from vRAM to either the CPU or NVMe to allow training with less GPU memory.
+
+Run the following command to configure your 🤗 Accelerate environment:
+
+```bash
+accelerate config
+```
+
+During configuration, confirm that you want to use DeepSpeed stage 2. Now it should be possible to train on under 8GB vRAM by combining DeepSpeed stage 2, fp16 mixed precision, and offloading the model parameters and the optimizer state to the CPU. The drawback is that this requires more system RAM (~25 GB). See the [DeepSpeed documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more configuration options. Your configuration file should look something like:
+
+```yaml
+compute_environment: LOCAL_MACHINE
+deepspeed_config:
+ gradient_accumulation_steps: 4
+ offload_optimizer_device: cpu
+ offload_param_device: cpu
+ zero3_init_flag: false
+ zero_stage: 2
+distributed_type: DEEPSPEED
+```
+
+You should also change the default Adam optimizer to DeepSpeed’s optimized version of Adam [`deepspeed.ops.adam.DeepSpeedCPUAdam`](https://deepspeed.readthedocs.io/en/latest/optimizers.html#adam-cpu) for a substantial speedup. Enabling `DeepSpeedCPUAdam` requires your system’s CUDA toolchain version to be the same as the one installed with PyTorch.
+
+bitsandbytes 8-bit optimizers don’t seem to be compatible with DeepSpeed at the moment.
+
+That's it! You don't need to add any additional parameters to your training command.
+
+
+
+
+
+
+
+```bash
+export MODEL_DIR="runwayml/stable-diffusion-v1-5"
+export OUTPUT_DIR="path/to/save/model"
+
+accelerate launch train_controlnet.py \
+ --pretrained_model_name_or_path=$MODEL_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --dataset_name=fusing/fill50k \
+ --resolution=512 \
+ --learning_rate=1e-5 \
+ --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
+ --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --push_to_hub
+```
+
+
+
+
+With Flax, you can [profile your code](https://jax.readthedocs.io/en/latest/profiling.html) by adding the `--profile_steps=5` parameter to your training command. Install the Tensorboard profile plugin:
+
+```bash
+pip install tensorflow tensorboard-plugin-profile
+tensorboard --logdir runs/fill-circle-100steps-20230411_165612/
+```
+
+Then you can inspect the profile at [http://localhost:6006/#profile](http://localhost:6006/#profile).
+
+
+
+If you run into version conflicts with the plugin, try uninstalling and reinstalling all versions of TensorFlow and Tensorboard. The debugging functionality of the profile plugin is still experimental, and not all views are fully functional. The `trace_viewer` cuts off events after 1M, which can result in all your device traces getting lost if, for example, you profile the compilation step by accident.
+
+
+
+```bash
+python3 train_controlnet_flax.py \
+ --pretrained_model_name_or_path=$MODEL_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --dataset_name=fusing/fill50k \
+ --resolution=512 \
+ --learning_rate=1e-5 \
+ --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
+ --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
+ --validation_steps=1000 \
+ --train_batch_size=2 \
+ --revision="non-ema" \
+ --from_pt \
+ --report_to="wandb" \
+ --tracker_project_name=$HUB_MODEL_ID \
+ --num_train_epochs=11 \
+ --push_to_hub \
+ --hub_model_id=$HUB_MODEL_ID
+```
+
+
+
+
+Once training is complete, you can use your newly trained model for inference!
+
+```py
+from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
+from diffusers.utils import load_image
+import torch
+
+controlnet = ControlNetModel.from_pretrained("path/to/controlnet", torch_dtype=torch.float16)
+pipeline = StableDiffusionControlNetPipeline.from_pretrained(
+ "path/to/base/model", controlnet=controlnet, torch_dtype=torch.float16
+).to("cuda")
+
+control_image = load_image("./conditioning_image_1.png")
+prompt = "pale golden rod circle with old lace background"
+
+generator = torch.manual_seed(0)
+image = pipeline(prompt, num_inference_steps=20, generator=generator, image=control_image).images[0]
+image.save("./output.png")
+```
+
+## Stable Diffusion XL
+
+Stable Diffusion XL (SDXL) is a powerful text-to-image model that generates high-resolution images, and it adds a second text-encoder to its architecture. Use the [`train_controlnet_sdxl.py`](https://github.com/huggingface/diffusers/blob/main/examples/controlnet/train_controlnet_sdxl.py) script to train a ControlNet adapter for the SDXL model.
+
+The SDXL training script is discussed in more detail in the [SDXL training](sdxl) guide.
+
+## Next steps
+
+Congratulations on training your own ControlNet! To learn more about how to use your new model, the following guides may be helpful:
+
+- Learn how to [use a ControlNet](../using-diffusers/controlnet) for inference on a variety of tasks.
\ No newline at end of file
diff --git a/diffusers/docs/source/en/training/create_dataset.md b/diffusers/docs/source/en/training/create_dataset.md
new file mode 100644
index 0000000000000000000000000000000000000000..f215d3eb2c1b58fd442e525543b136118a8c0f70
--- /dev/null
+++ b/diffusers/docs/source/en/training/create_dataset.md
@@ -0,0 +1,90 @@
+# Create a dataset for training
+
+There are many datasets on the [Hub](https://huggingface.co/datasets?task_categories=task_categories:text-to-image&sort=downloads) to train a model on, but if you can't find one you're interested in or want to use your own, you can create a dataset with the 🤗 [Datasets](https://hf.co/docs/datasets) library. The dataset structure depends on the task you want to train your model on. The most basic dataset structure is a directory of images for tasks like unconditional image generation. Another dataset structure may be a directory of images and a text file containing their corresponding text captions for tasks like text-to-image generation.
+
+This guide will show you two ways to create a dataset to finetune on:
+
+- provide a folder of images to the `--train_data_dir` argument
+- upload a dataset to the Hub and pass the dataset repository id to the `--dataset_name` argument
+
+
+
+💡 Learn more about how to create an image dataset for training in the [Create an image dataset](https://huggingface.co/docs/datasets/image_dataset) guide.
+
+
+
+## Provide a dataset as a folder
+
+For unconditional generation, you can provide your own dataset as a folder of images. The training script uses the [`ImageFolder`](https://huggingface.co/docs/datasets/en/image_dataset#imagefolder) builder from 🤗 Datasets to automatically build a dataset from the folder. Your directory structure should look like:
+
+```bash
+data_dir/xxx.png
+data_dir/xxy.png
+data_dir/[...]/xxz.png
+```
+
+Pass the path to the dataset directory to the `--train_data_dir` argument, and then you can start training:
+
+```bash
+accelerate launch train_unconditional.py \
+  --train_data_dir <path-to-train-directory> \
+  <other-arguments>
+```
+
+## Upload your data to the Hub
+
+
+
+💡 For more details and context about creating and uploading a dataset to the Hub, take a look at the [Image search with 🤗 Datasets](https://huggingface.co/blog/image-search-datasets) post.
+
+
+
+Start by creating a dataset with the [`ImageFolder`](https://huggingface.co/docs/datasets/image_load#imagefolder) feature, which creates an `image` column containing the PIL-encoded images.
+
+You can use the `data_dir` or `data_files` parameters to specify the location of the dataset. The `data_files` parameter supports mapping specific files to dataset splits like `train` or `test`:
+
+```python
+from datasets import load_dataset
+
+# example 1: local folder
+dataset = load_dataset("imagefolder", data_dir="path_to_your_folder")
+
+# example 2: local files (supported formats are tar, gzip, zip, xz, rar, zstd)
+dataset = load_dataset("imagefolder", data_files="path_to_zip_file")
+
+# example 3: remote files (supported formats are tar, gzip, zip, xz, rar, zstd)
+dataset = load_dataset(
+ "imagefolder",
+ data_files="https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_3367a.zip",
+)
+
+# example 4: providing several splits
+dataset = load_dataset(
+ "imagefolder", data_files={"train": ["path/to/file1", "path/to/file2"], "test": ["path/to/file3", "path/to/file4"]}
+)
+```
+
+Then use the [`~datasets.Dataset.push_to_hub`] method to upload the dataset to the Hub:
+
+```python
+# assuming you have run the huggingface-cli login command in a terminal
+dataset.push_to_hub("name_of_your_dataset")
+
+# if you want to push to a private repo, simply pass private=True:
+dataset.push_to_hub("name_of_your_dataset", private=True)
+```
+
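+To double-check the upload, you can reload the dataset from the Hub. The repository id below is a placeholder; use the one printed by `push_to_hub` (usually prefixed with your username):
+
+```python
+from datasets import load_dataset
+
+dataset = load_dataset("your-username/name_of_your_dataset", split="train")
+dataset[0]["image"]  # the PIL-encoded image created by the ImageFolder builder
+```
+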
+Now the dataset is available for training by passing the dataset name to the `--dataset_name` argument:
+
+```bash
+accelerate launch --mixed_precision="fp16" train_text_to_image.py \
+ --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
+ --dataset_name="name_of_your_dataset" \
+  <other-arguments>
+```
+
+## Next steps
+
+Now that you've created a dataset, you can plug it into the `train_data_dir` (if your dataset is local) or `dataset_name` (if your dataset is on the Hub) arguments of a training script.
+
+For your next steps, feel free to try and use your dataset to train a model for [unconditional generation](unconditional_training) or [text-to-image generation](text2image)!
\ No newline at end of file
diff --git a/diffusers/docs/source/en/training/custom_diffusion.md b/diffusers/docs/source/en/training/custom_diffusion.md
new file mode 100644
index 0000000000000000000000000000000000000000..6601a7a93284acaf30fab10cc666b70f8ed07bb3
--- /dev/null
+++ b/diffusers/docs/source/en/training/custom_diffusion.md
@@ -0,0 +1,363 @@
+
+
+# Custom Diffusion
+
+[Custom Diffusion](https://huggingface.co/papers/2212.04488) is a training technique for personalizing image generation models. Like Textual Inversion, DreamBooth, and LoRA, Custom Diffusion only requires a few (~4-5) example images. This technique works by only training weights in the cross-attention layers, and it uses a special word to represent the newly learned concept. Custom Diffusion is unique because it can also learn multiple concepts at the same time.
+
+If you're training on a GPU with limited vRAM, you should try enabling xFormers with `--enable_xformers_memory_efficient_attention` for faster training with lower vRAM requirements (16GB). To save even more memory, add `--set_grads_to_none` to the training command to set the gradients to `None` instead of zero (this option can cause some issues, so if you experience any, try removing this parameter).
+
+This guide will explore the [train_custom_diffusion.py](https://github.com/huggingface/diffusers/blob/main/examples/custom_diffusion/train_custom_diffusion.py) script to help you become more familiar with it and with how you can adapt it for your own use case.
+
+Before running the script, make sure you install the library from source:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Navigate to the example folder with the training script and install the required dependencies:
+
+```bash
+cd examples/custom_diffusion
+pip install -r requirements.txt
+pip install clip-retrieval
+```
+
+
+
+🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.
+
+
+
+Initialize an 🤗 Accelerate environment:
+
+```bash
+accelerate config
+```
+
+To set up a default 🤗 Accelerate environment without choosing any configurations:
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell, like a notebook, you can use:
+
+```py
+from accelerate.utils import write_basic_config
+
+write_basic_config()
+```
+
+Lastly, if you want to train a model on your own dataset, take a look at the [Create a dataset for training](create_dataset) guide to learn how to create a dataset that works with the training script.
+
+
+
+The following sections highlight parts of the training script that are important for understanding how to modify it, but they don't cover every aspect of the script in detail. If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/custom_diffusion/train_custom_diffusion.py) and let us know if you have any questions or concerns.
+
+
+
+## Script parameters
+
+The training script contains all the parameters to help you customize your training run. These are found in the [`parse_args()`](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/custom_diffusion/train_custom_diffusion.py#L319) function. The function comes with default values, but you can also set your own values in the training command if you'd like.
+
+For example, to change the resolution of the input image:
+
+```bash
+accelerate launch train_custom_diffusion.py \
+ --resolution=256
+```
+
+Many of the basic parameters are described in the [DreamBooth](dreambooth#script-parameters) training guide, so this guide focuses on the parameters unique to Custom Diffusion:
+
+- `--freeze_model`: freezes the key and value parameters in the cross-attention layer; the default is `crossattn_kv`, but you can set it to `crossattn` to train all the parameters in the cross-attention layer
+- `--concepts_list`: to learn multiple concepts, provide a path to a JSON file containing the concepts
+- `--modifier_token`: a special word used to represent the learned concept
+- `--initializer_token`: an existing token used to initialize the embeddings of the `modifier_token`
+
+### Prior preservation loss
+
+Prior preservation loss is a method that uses a model's own generated samples to help it learn how to generate more diverse images. Because these generated sample images belong to the same class as the images you provided, they help the model retain what it has learned about the class and how it can use what it already knows about the class to make new compositions.
+
+Many of the parameters for prior preservation loss are described in the [DreamBooth](dreambooth#prior-preservation-loss) training guide.
+
+### Regularization
+
+Custom Diffusion trains the target images together with a small set of real images to prevent overfitting. As you can imagine, it's easy to overfit when you're only training on a few images! Download 200 real images with `clip_retrieval`. The `class_prompt` should be the same category as the target images. These images are stored in `class_data_dir`.
+
+```bash
+python retrieve.py --class_prompt cat --class_data_dir real_reg/samples_cat --num_class_images 200
+```
+
+To enable regularization, add the following parameters:
+
+- `--with_prior_preservation`: whether to use prior preservation loss
+- `--prior_loss_weight`: controls the influence of the prior preservation loss on the model
+- `--real_prior`: whether to use a small set of real images to prevent overfitting
+
+```bash
+accelerate launch train_custom_diffusion.py \
+ --with_prior_preservation \
+ --prior_loss_weight=1.0 \
+ --class_data_dir="./real_reg/samples_cat" \
+ --class_prompt="cat" \
+ --real_prior=True \
+```
+
+## Training script
+
+
+
+A lot of the code in the Custom Diffusion training script is similar to the [DreamBooth](dreambooth#training-script) script. This guide instead focuses on the code that is relevant to Custom Diffusion.
+
+
+
+The Custom Diffusion training script has two dataset classes:
+
+- [`CustomDiffusionDataset`](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/custom_diffusion/train_custom_diffusion.py#L165): preprocesses the images, class images, and prompts for training
+- [`PromptDataset`](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/custom_diffusion/train_custom_diffusion.py#L148): prepares the prompts for generating class images
+
+Next, the `modifier_token` is [added to the tokenizer](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/custom_diffusion/train_custom_diffusion.py#L811), converted to token ids, and the token embeddings are resized to account for the new `modifier_token`. Then the `modifier_token` embeddings are initialized with the embeddings of the `initializer_token`. All parameters in the text encoder are frozen, except for the token embeddings since this is what the model is trying to learn to associate with the concepts.
+
+```py
+params_to_freeze = itertools.chain(
+ text_encoder.text_model.encoder.parameters(),
+ text_encoder.text_model.final_layer_norm.parameters(),
+ text_encoder.text_model.embeddings.position_embedding.parameters(),
+)
+freeze_params(params_to_freeze)
+```
+
+Now you'll need to add the [Custom Diffusion weights](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/custom_diffusion/train_custom_diffusion.py#L911C3-L911C3) to the attention layers. This is a really important step for getting the shape and size of the attention weights correct, and for setting the appropriate number of attention processors in each UNet block.
+
+```py
+st = unet.state_dict()
+for name, _ in unet.attn_processors.items():
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+ if name.startswith("mid_block"):
+ hidden_size = unet.config.block_out_channels[-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = unet.config.block_out_channels[block_id]
+ layer_name = name.split(".processor")[0]
+ weights = {
+ "to_k_custom_diffusion.weight": st[layer_name + ".to_k.weight"],
+ "to_v_custom_diffusion.weight": st[layer_name + ".to_v.weight"],
+ }
+ if train_q_out:
+ weights["to_q_custom_diffusion.weight"] = st[layer_name + ".to_q.weight"]
+ weights["to_out_custom_diffusion.0.weight"] = st[layer_name + ".to_out.0.weight"]
+ weights["to_out_custom_diffusion.0.bias"] = st[layer_name + ".to_out.0.bias"]
+ if cross_attention_dim is not None:
+ custom_diffusion_attn_procs[name] = attention_class(
+ train_kv=train_kv,
+ train_q_out=train_q_out,
+ hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim,
+ ).to(unet.device)
+ custom_diffusion_attn_procs[name].load_state_dict(weights)
+ else:
+ custom_diffusion_attn_procs[name] = attention_class(
+ train_kv=False,
+ train_q_out=False,
+ hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim,
+ )
+del st
+unet.set_attn_processor(custom_diffusion_attn_procs)
+custom_diffusion_layers = AttnProcsLayers(unet.attn_processors)
+```
+
+The [optimizer](https://github.com/huggingface/diffusers/blob/84cd9e8d01adb47f046b1ee449fc76a0c32dc4e2/examples/custom_diffusion/train_custom_diffusion.py#L982) is initialized to update the cross-attention layer parameters:
+
+```py
+optimizer = optimizer_class(
+ itertools.chain(text_encoder.get_input_embeddings().parameters(), custom_diffusion_layers.parameters())
+ if args.modifier_token is not None
+ else custom_diffusion_layers.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+)
+```
+
+In the [training loop](https://github.com/huggingface/diffusers/blob/84cd9e8d01adb47f046b1ee449fc76a0c32dc4e2/examples/custom_diffusion/train_custom_diffusion.py#L1048), it is important to only update the embeddings for the concept you're trying to learn. This means setting the gradients of all the other token embeddings to zero:
+
+```py
+if args.modifier_token is not None:
+ if accelerator.num_processes > 1:
+ grads_text_encoder = text_encoder.module.get_input_embeddings().weight.grad
+ else:
+ grads_text_encoder = text_encoder.get_input_embeddings().weight.grad
+ index_grads_to_zero = torch.arange(len(tokenizer)) != modifier_token_id[0]
+ for i in range(len(modifier_token_id[1:])):
+ index_grads_to_zero = index_grads_to_zero & (
+ torch.arange(len(tokenizer)) != modifier_token_id[i]
+ )
+ grads_text_encoder.data[index_grads_to_zero, :] = grads_text_encoder.data[
+ index_grads_to_zero, :
+ ].fill_(0)
+```
+
+## Launch the script
+
+Once you’ve made all your changes or you’re okay with the default configuration, you’re ready to launch the training script! 🚀
+
+In this guide, you'll download and use these example [cat images](https://www.cs.cmu.edu/~custom-diffusion/assets/data.zip). You can also create and use your own dataset if you want (see the [Create a dataset for training](create_dataset) guide).
+
+Set the environment variable `MODEL_NAME` to a model id on the Hub or a path to a local model, `INSTANCE_DIR` to the path where you just downloaded the cat images to, and `OUTPUT_DIR` to where you want to save the model. You'll use `<new1>` as the special word to tie the newly learned embeddings to. The script creates and saves model checkpoints and a `pytorch_custom_diffusion_weights.bin` file to your repository.
+
+To monitor training progress with Weights and Biases, add the `--report_to=wandb` parameter to the training command and specify a validation prompt with `--validation_prompt`. This is useful for debugging and saving intermediate results.
+
+
+
+If you're training on human faces, the Custom Diffusion team has found the following parameters to work well:
+
+- `--learning_rate=5e-6`
+- `--max_train_steps` can be anywhere between 1000 and 2000
+- `--freeze_model=crossattn`
+- use at least 15-20 images to train with
+
+
+
+
+
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export OUTPUT_DIR="path-to-save-model"
+export INSTANCE_DIR="./data/cat"
+
+accelerate launch train_custom_diffusion.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --class_data_dir=./real_reg/samples_cat/ \
+ --with_prior_preservation \
+ --real_prior \
+ --prior_loss_weight=1.0 \
+ --class_prompt="cat" \
+ --num_class_images=200 \
+  --instance_prompt="photo of a <new1> cat" \
+ --resolution=512 \
+ --train_batch_size=2 \
+ --learning_rate=1e-5 \
+ --lr_warmup_steps=0 \
+ --max_train_steps=250 \
+ --scale_lr \
+ --hflip \
+  --modifier_token "<new1>" \
+  --validation_prompt="<new1> cat sitting in a bucket" \
+ --report_to="wandb" \
+ --push_to_hub
+```
+
+
+
+
+Custom Diffusion can also learn multiple concepts if you provide a [JSON](https://github.com/adobe-research/custom-diffusion/blob/main/assets/concept_list.json) file with some details about each concept it should learn.
+
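+For illustration, a concepts file can be written from Python along the following lines. This is a sketch, and the keys (`instance_prompt`, `class_prompt`, `instance_data_dir`, `class_data_dir`) are assumptions modeled on the linked example; refer to the linked JSON file and the training script for the authoritative format.
+
+```py
+import json
+
+concepts_list = [
+    {
+        "instance_prompt": "photo of a <new1> cat",
+        "class_prompt": "cat",
+        "instance_data_dir": "./data/cat",
+        "class_data_dir": "./real_reg/samples_cat",
+    },
+    {
+        "instance_prompt": "photo of a <new2> wooden pot",
+        "class_prompt": "wooden pot",
+        "instance_data_dir": "./data/wooden_pot",
+        "class_data_dir": "./real_reg/samples_wooden_pot",
+    },
+]
+
+with open("concept_list.json", "w") as f:
+    json.dump(concepts_list, f, indent=4)
+```
+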
+Run clip-retrieval to collect some real images to use for regularization:
+
+```bash
+pip install clip-retrieval
+python retrieve.py --class_prompt {} --class_data_dir {} --num_class_images 200
+```
+
+Then you can launch the script:
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_custom_diffusion.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --output_dir=$OUTPUT_DIR \
+ --concepts_list=./concept_list.json \
+ --with_prior_preservation \
+ --real_prior \
+ --prior_loss_weight=1.0 \
+ --resolution=512 \
+ --train_batch_size=2 \
+ --learning_rate=1e-5 \
+ --lr_warmup_steps=0 \
+ --max_train_steps=500 \
+ --num_class_images=200 \
+ --scale_lr \
+ --hflip \
+  --modifier_token "<new1>+<new2>" \
+ --push_to_hub
+```
+
+
+
+
+Once training is finished, you can use your new Custom Diffusion model for inference.
+
+
+
+
+```py
+import torch
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16,
+).to("cuda")
+pipeline.unet.load_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin")
+pipeline.load_textual_inversion("path-to-save-model", weight_name="<new1>.bin")
+
+image = pipeline(
+    "<new1> cat sitting in a bucket",
+ num_inference_steps=100,
+ guidance_scale=6.0,
+ eta=1.0,
+).images[0]
+image.save("cat.png")
+```
+
+
+
+
+```py
+import torch
+from huggingface_hub.repocard import RepoCard
+from diffusers import DiffusionPipeline
+
+model_id = "sayakpaul/custom-diffusion-cat-wooden-pot"
+card = RepoCard.load(model_id)
+base_model_id = card.data.to_dict()["base_model"]
+
+pipeline = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to("cuda")
+pipeline.unet.load_attn_procs(model_id, weight_name="pytorch_custom_diffusion_weights.bin")
+pipeline.load_textual_inversion(model_id, weight_name="<new1>.bin")
+pipeline.load_textual_inversion(model_id, weight_name="<new2>.bin")
+
+image = pipeline(
+    "the <new1> cat sculpture in the style of a <new2> wooden pot",
+ num_inference_steps=100,
+ guidance_scale=6.0,
+ eta=1.0,
+).images[0]
+image.save("multi-subject.png")
+```
+
+
+
+
+## Next steps
+
+Congratulations on training a model with Custom Diffusion! 🎉 To learn more:
+
+- Read the [Multi-Concept Customization of Text-to-Image Diffusion](https://www.cs.cmu.edu/~custom-diffusion/) blog post to learn more details about the experimental results from the Custom Diffusion team.
\ No newline at end of file
diff --git a/diffusers/docs/source/en/training/ddpo.md b/diffusers/docs/source/en/training/ddpo.md
new file mode 100644
index 0000000000000000000000000000000000000000..1ec961dfdd04accf6afd386d5de657ef3a75139f
--- /dev/null
+++ b/diffusers/docs/source/en/training/ddpo.md
@@ -0,0 +1,17 @@
+
+
+# Reinforcement learning training with DDPO
+
+You can fine-tune Stable Diffusion on a reward function via reinforcement learning with the 🤗 TRL library and 🤗 Diffusers. This is done with the Denoising Diffusion Policy Optimization (DDPO) algorithm introduced by Black et al. in [Training Diffusion Models with Reinforcement Learning](https://arxiv.org/abs/2305.13301), which is implemented in 🤗 TRL with the [`~trl.DDPOTrainer`].
+
+For more information, check out the [`~trl.DDPOTrainer`] API reference and the [Finetune Stable Diffusion Models with DDPO via TRL](https://huggingface.co/blog/trl-ddpo) blog post.
\ No newline at end of file
diff --git a/diffusers/docs/source/en/training/distributed_inference.md b/diffusers/docs/source/en/training/distributed_inference.md
new file mode 100644
index 0000000000000000000000000000000000000000..72bb5f5fd7fe048ad7fa6f742d902db75447462b
--- /dev/null
+++ b/diffusers/docs/source/en/training/distributed_inference.md
@@ -0,0 +1,108 @@
+
+
+# Distributed inference with multiple GPUs
+
+On distributed setups, you can run inference across multiple GPUs with 🤗 [Accelerate](https://huggingface.co/docs/accelerate/index) or [PyTorch Distributed](https://pytorch.org/tutorials/beginner/dist_overview.html), which is useful for generating with multiple prompts in parallel.
+
+This guide will show you how to use 🤗 Accelerate and PyTorch Distributed for distributed inference.
+
+## 🤗 Accelerate
+
+🤗 [Accelerate](https://huggingface.co/docs/accelerate/index) is a library designed to make it easy to train or run inference across distributed setups. It simplifies the process of setting up the distributed environment, allowing you to focus on your PyTorch code.
+
+To begin, create a Python file and initialize an [`accelerate.PartialState`] to create a distributed environment; your setup is automatically detected so you don't need to explicitly define the `rank` or `world_size`. Move the [`DiffusionPipeline`] to `distributed_state.device` to assign a GPU to each process.
+
+Now use the [`~accelerate.PartialState.split_between_processes`] utility as a context manager to automatically distribute the prompts between the number of processes.
+
+```py
+import torch
+from accelerate import PartialState
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
+)
+distributed_state = PartialState()
+pipeline.to(distributed_state.device)
+
+with distributed_state.split_between_processes(["a dog", "a cat"]) as prompt:
+ result = pipeline(prompt).images[0]
+ result.save(f"result_{distributed_state.process_index}.png")
+```
+
+Use the `--num_processes` argument to specify the number of GPUs to use, and call `accelerate launch` to run the script:
+
+```bash
+accelerate launch run_distributed.py --num_processes=2
+```
+
+
+
+To learn more, take a look at the [Distributed Inference with 🤗 Accelerate](https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference#distributed-inference-with-accelerate) guide.
+
+
+
+## PyTorch Distributed
+
+PyTorch supports [`DistributedDataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html) which enables data parallelism.
+
+To start, create a Python file and import `torch.distributed` and `torch.multiprocessing` to set up the distributed process group and to spawn the processes for inference on each GPU. You should also initialize a [`DiffusionPipeline`]:
+
+```py
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+from diffusers import DiffusionPipeline
+
+sd = DiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
+)
+```
+
+You'll want to create a function to run inference; [`init_process_group`](https://pytorch.org/docs/stable/distributed.html?highlight=init_process_group#torch.distributed.init_process_group) handles creating a distributed environment with the type of backend to use, the `rank` of the current process, and the `world_size` or the number of processes participating. If you're running inference in parallel over 2 GPUs, then the `world_size` is 2.
+
+Move the [`DiffusionPipeline`] to `rank` and use `get_rank` to assign a GPU to each process, where each process handles a different prompt:
+
+```py
+def run_inference(rank, world_size):
+ dist.init_process_group("nccl", rank=rank, world_size=world_size)
+
+ sd.to(rank)
+
+ if torch.distributed.get_rank() == 0:
+ prompt = "a dog"
+ elif torch.distributed.get_rank() == 1:
+ prompt = "a cat"
+
+ image = sd(prompt).images[0]
+    image.save(f"./{prompt.replace(' ', '_')}.png")
+```
+
+To run the distributed inference, call [`mp.spawn`](https://pytorch.org/docs/stable/multiprocessing.html#torch.multiprocessing.spawn) to run the `run_inference` function on the number of GPUs defined in `world_size`:
+
+```py
+def main():
+ world_size = 2
+ mp.spawn(run_inference, args=(world_size,), nprocs=world_size, join=True)
+
+
+if __name__ == "__main__":
+ main()
+```
+
+Once you've completed the inference script, use the `--nproc_per_node` argument to specify the number of GPUs to use and call `torchrun` to run the script:
+
+```bash
+torchrun --nproc_per_node=2 run_distributed.py
+```
diff --git a/diffusers/docs/source/en/training/dreambooth.md b/diffusers/docs/source/en/training/dreambooth.md
new file mode 100644
index 0000000000000000000000000000000000000000..e71d2ea7bbe7de1bfef4e3d98aa930af1a8f5566
--- /dev/null
+++ b/diffusers/docs/source/en/training/dreambooth.md
@@ -0,0 +1,447 @@
+
+
+# DreamBooth
+
+[DreamBooth](https://huggingface.co/papers/2208.12242) is a training technique that updates the entire diffusion model by training on just a few images of a subject or style. It works by associating a special word in the prompt with the example images.
+
+If you're training on a GPU with limited vRAM, you should try enabling the `gradient_checkpointing` and `mixed_precision` parameters in the training command. You can also reduce your memory footprint by using memory-efficient attention with [xFormers](../optimization/xformers). JAX/Flax training is also supported for efficient training on TPUs and GPUs, but it doesn't support gradient checkpointing or xFormers. You should have a GPU with >30GB of memory if you want to train faster with Flax.
+
+This guide will explore the [train_dreambooth.py](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py) script to help you become more familiar with it and with how you can adapt it for your own use case.
+
+Before running the script, make sure you install the library from source:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Navigate to the example folder with the training script and install the required dependencies for the script you're using:
+
+
+
+
+```bash
+cd examples/dreambooth
+pip install -r requirements.txt
+```
+
+
+
+
+```bash
+cd examples/dreambooth
+pip install -r requirements_flax.txt
+```
+
+
+
+
+
+
+🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.
+
+
+
+Initialize an 🤗 Accelerate environment:
+
+```bash
+accelerate config
+```
+
+To set up a default 🤗 Accelerate environment without choosing any configurations:
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell, like a notebook, you can use:
+
+```py
+from accelerate.utils import write_basic_config
+
+write_basic_config()
+```
+
+Lastly, if you want to train a model on your own dataset, take a look at the [Create a dataset for training](create_dataset) guide to learn how to create a dataset that works with the training script.
+
+
+
+The following sections highlight parts of the training script that are important for understanding how to modify it, but they don't cover every aspect of the script in detail. If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py) and let us know if you have any questions or concerns.
+
+
+
+## Script parameters
+
+
+
+DreamBooth is very sensitive to training hyperparameters, and it is easy to overfit. Read the [Training Stable Diffusion with Dreambooth using 🧨 Diffusers](https://huggingface.co/blog/dreambooth) blog post for recommended settings for different subjects to help you choose the appropriate hyperparameters.
+
+
+
+The training script offers many parameters for customizing your training run. All of the parameters and their descriptions are found in the [`parse_args()`](https://github.com/huggingface/diffusers/blob/072e00897a7cf4302c347a63ec917b4b8add16d4/examples/dreambooth/train_dreambooth.py#L228) function. The parameters are set with default values that should work pretty well out-of-the-box, but you can also set your own values in the training command if you'd like.
+
+For example, to train in the bf16 format:
+
+```bash
+accelerate launch train_dreambooth.py \
+ --mixed_precision="bf16"
+```
+
+Some basic and important parameters to know and specify are:
+
+- `--pretrained_model_name_or_path`: the name of the model on the Hub or a local path to the pretrained model
+- `--instance_data_dir`: path to a folder containing the training dataset (example images)
+- `--instance_prompt`: the text prompt that contains the special word for the example images
+- `--train_text_encoder`: whether to also train the text encoder
+- `--output_dir`: where to save the trained model
+- `--push_to_hub`: whether to push the trained model to the Hub
+- `--checkpointing_steps`: frequency of saving a checkpoint as the model trains; if training is interrupted for any reason, you can resume from that checkpoint by adding `--resume_from_checkpoint` to your training command
+
+### Min-SNR weighting
+
+The [Min-SNR](https://huggingface.co/papers/2303.09556) weighting strategy can help with training by rebalancing the loss to achieve faster convergence. The training script supports predicting `epsilon` (noise) or `v_prediction`, and Min-SNR is compatible with both prediction types. This weighting strategy is only supported by PyTorch and is unavailable in the Flax training script.
+
+Add the `--snr_gamma` parameter and set it to the recommended value of 5.0:
+
+```bash
+accelerate launch train_dreambooth.py \
+ --snr_gamma=5.0
+```
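+
+To make the rebalancing concrete, here is a minimal, illustrative sketch of the weighting idea in its epsilon-prediction form (the tensor values below are made up; the training script derives the SNR from the noise scheduler):
+
+```py
+import torch
+
+def min_snr_weights(snr: torch.Tensor, snr_gamma: float = 5.0) -> torch.Tensor:
+    # Clamp the signal-to-noise ratio at snr_gamma and normalize by the SNR so
+    # easy (high-SNR) timesteps contribute less to the overall loss.
+    return torch.clamp(snr, max=snr_gamma) / snr
+
+snr = torch.tensor([0.5, 2.0, 20.0])            # per-sample SNR at the sampled timesteps
+per_sample_mse = torch.tensor([0.9, 0.4, 0.1])  # per-sample MSE between prediction and target
+loss = (min_snr_weights(snr) * per_sample_mse).mean()
+```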
+
+### Prior preservation loss
+
+Prior preservation loss is a method that uses a model's own generated samples to help it learn how to generate more diverse images. Because these generated sample images belong to the same class as the images you provided, they help the model retain what it has learned about the class and how it can use what it already knows about the class to make new compositions.
+
+- `--with_prior_preservation`: whether to use prior preservation loss
+- `--prior_loss_weight`: controls the influence of the prior preservation loss on the model
+- `--class_data_dir`: path to a folder containing the generated class sample images
+- `--class_prompt`: the text prompt describing the class of the generated sample images
+
+```bash
+accelerate launch train_dreambooth.py \
+ --with_prior_preservation \
+ --prior_loss_weight=1.0 \
+ --class_data_dir="path/to/class/images" \
+ --class_prompt="text prompt describing class"
+```
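+
+As a rough sketch of how the two terms are combined (toy tensors stand in for the UNet's predictions and targets), the batch is split into an instance half and a class half, and the weighted class loss is added to the instance loss:
+
+```py
+import torch
+import torch.nn.functional as F
+
+# Toy stand-ins for the UNet's noise predictions and targets on a batch that
+# stacks instance images (first half) and generated class images (second half).
+model_pred = torch.randn(4, 4, 64, 64)
+target = torch.randn(4, 4, 64, 64)
+prior_loss_weight = 1.0  # --prior_loss_weight
+
+pred_instance, pred_prior = model_pred.chunk(2)
+target_instance, target_prior = target.chunk(2)
+
+instance_loss = F.mse_loss(pred_instance.float(), target_instance.float(), reduction="mean")
+prior_loss = F.mse_loss(pred_prior.float(), target_prior.float(), reduction="mean")
+loss = instance_loss + prior_loss_weight * prior_loss
+```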
+
+### Train text encoder
+
+To improve the quality of the generated outputs, you can also train the text encoder in addition to the UNet. This requires additional memory and you'll need a GPU with at least 24GB of vRAM. If you have the necessary hardware, then training the text encoder produces better results, especially when generating images of faces. Enable this option by:
+
+```bash
+accelerate launch train_dreambooth.py \
+ --train_text_encoder
+```
+
+## Training script
+
+DreamBooth comes with its own dataset classes:
+
+- [`DreamBoothDataset`](https://github.com/huggingface/diffusers/blob/072e00897a7cf4302c347a63ec917b4b8add16d4/examples/dreambooth/train_dreambooth.py#L604): preprocesses the images and class images, and tokenizes the prompts for training
+- [`PromptDataset`](https://github.com/huggingface/diffusers/blob/072e00897a7cf4302c347a63ec917b4b8add16d4/examples/dreambooth/train_dreambooth.py#L738): generates the prompt embeddings to generate the class images
+
+If you enabled [prior preservation loss](https://github.com/huggingface/diffusers/blob/072e00897a7cf4302c347a63ec917b4b8add16d4/examples/dreambooth/train_dreambooth.py#L842), the class images are generated here:
+
+```py
+sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+sample_dataloader = accelerator.prepare(sample_dataloader)
+pipeline.to(accelerator.device)
+
+for example in tqdm(
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
+):
+ images = pipeline(example["prompt"]).images
+```
+
+Next is the [`main()`](https://github.com/huggingface/diffusers/blob/072e00897a7cf4302c347a63ec917b4b8add16d4/examples/dreambooth/train_dreambooth.py#L799) function which handles setting up the dataset for training and the training loop itself. The script loads the [tokenizer](https://github.com/huggingface/diffusers/blob/072e00897a7cf4302c347a63ec917b4b8add16d4/examples/dreambooth/train_dreambooth.py#L898), [scheduler and models](https://github.com/huggingface/diffusers/blob/072e00897a7cf4302c347a63ec917b4b8add16d4/examples/dreambooth/train_dreambooth.py#L912C1-L912C1):
+
+```py
+# Load the tokenizer
+if args.tokenizer_name:
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
+elif args.pretrained_model_name_or_path:
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ use_fast=False,
+ )
+
+# Load scheduler and models
+noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+text_encoder = text_encoder_cls.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+)
+
+if model_has_vae(args):
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
+ )
+else:
+ vae = None
+
+unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+)
+```
+
+Then, it's time to [create the training dataset](https://github.com/huggingface/diffusers/blob/072e00897a7cf4302c347a63ec917b4b8add16d4/examples/dreambooth/train_dreambooth.py#L1073) and DataLoader from `DreamBoothDataset`:
+
+```py
+train_dataset = DreamBoothDataset(
+ instance_data_root=args.instance_data_dir,
+ instance_prompt=args.instance_prompt,
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
+ class_prompt=args.class_prompt,
+ class_num=args.num_class_images,
+ tokenizer=tokenizer,
+ size=args.resolution,
+ center_crop=args.center_crop,
+ encoder_hidden_states=pre_computed_encoder_hidden_states,
+ class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states,
+ tokenizer_max_length=args.tokenizer_max_length,
+)
+
+train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
+ num_workers=args.dataloader_num_workers,
+)
+```
+
+Lastly, the [training loop](https://github.com/huggingface/diffusers/blob/072e00897a7cf4302c347a63ec917b4b8add16d4/examples/dreambooth/train_dreambooth.py#L1151) takes care of the remaining steps such as converting images to latent space, adding noise to the input, predicting the noise residual, and calculating the loss.
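+
+As a rough, self-contained sketch of what one such step boils down to (simplified to the epsilon-prediction case, with a random tensor standing in for a training image and without prior preservation, mixed precision, or gradient accumulation):
+
+```py
+import torch
+import torch.nn.functional as F
+from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
+from transformers import CLIPTextModel, CLIPTokenizer
+
+model_id = "runwayml/stable-diffusion-v1-5"
+vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
+unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
+text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
+tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
+noise_scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
+
+pixel_values = torch.randn(1, 3, 512, 512)  # stands in for one preprocessed training image
+input_ids = tokenizer(
+    "a photo of sks dog", padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
+).input_ids
+
+# encode to latents, add noise at a random timestep, predict the noise, and compute the loss
+latents = vae.encode(pixel_values).latent_dist.sample() * vae.config.scaling_factor
+noise = torch.randn_like(latents)
+timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],))
+noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+encoder_hidden_states = text_encoder(input_ids)[0]
+model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")
+```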
+
+If you want to learn more about how the training loop works, check out the [Understanding pipelines, models and schedulers](../using-diffusers/write_own_pipeline) tutorial which breaks down the basic pattern of the denoising process.
+
+## Launch the script
+
+You're now ready to launch the training script! 🚀
+
+For this guide, you'll download some images of a [dog](https://huggingface.co/datasets/diffusers/dog-example) and store them in a directory. But remember, you can create and use your own dataset if you want (see the [Create a dataset for training](create_dataset) guide).
+
+```py
+from huggingface_hub import snapshot_download
+
+local_dir = "./dog"
+snapshot_download(
+ "diffusers/dog-example",
+ local_dir=local_dir,
+ repo_type="dataset",
+ ignore_patterns=".gitattributes",
+)
+```
+
+Set the environment variable `MODEL_NAME` to a model id on the Hub or a path to a local model, `INSTANCE_DIR` to the path where you just downloaded the dog images to, and `OUTPUT_DIR` to where you want to save the model. You'll use `sks` as the special word to tie the training to.
+
+If you're interested in following along with the training process, you can periodically save generated images as training progresses. Add the following parameters to the training command:
+
+```bash
+--validation_prompt="a photo of a sks dog"
+--num_validation_images=4
+--validation_steps=100
+```
+
+One more thing before you launch the script! Depending on the GPU you have, you may need to enable certain optimizations to train DreamBooth.
+
+
+
+
+On a 16GB GPU, you can use the bitsandbytes 8-bit optimizer and gradient checkpointing to help you train a DreamBooth model. Install bitsandbytes:
+
+```bash
+pip install bitsandbytes
+```
+
+Then, add the following parameters to your training command:
+
+```bash
+accelerate launch train_dreambooth.py \
+ --gradient_checkpointing \
+ --use_8bit_adam \
+```
+
+
+
+
+On a 12GB GPU, you'll need the bitsandbytes 8-bit optimizer, gradient checkpointing, xFormers, and to set the gradients to `None` instead of zero to reduce your memory-usage.
+
+```bash
+accelerate launch train_dreambooth.py \
+ --use_8bit_adam \
+ --gradient_checkpointing \
+ --enable_xformers_memory_efficient_attention \
+ --set_grads_to_none \
+```
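+
+For reference, `--set_grads_to_none` corresponds to PyTorch's `optimizer.zero_grad(set_to_none=True)`, which frees the gradient buffers instead of filling them with zeros. A tiny standalone illustration:
+
+```py
+import torch
+
+model = torch.nn.Linear(8, 8)
+optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
+
+loss = model(torch.randn(2, 8)).sum()
+loss.backward()
+optimizer.step()
+
+# Passing set_to_none=True releases the gradient tensors entirely,
+# which is slightly more memory-friendly than zeroing them in place.
+optimizer.zero_grad(set_to_none=True)
+```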
+
+
+
+
+On an 8GB GPU, you'll need [DeepSpeed](https://www.deepspeed.ai/) to offload some of the tensors from the vRAM to either the CPU or NVMe to allow training with less GPU memory.
+
+Run the following command to configure your 🤗 Accelerate environment:
+
+```bash
+accelerate config
+```
+
+During configuration, confirm that you want to use DeepSpeed. Now it should be possible to train on under 8GB vRAM by combining DeepSpeed stage 2, fp16 mixed precision, and offloading the model parameters and the optimizer state to the CPU. The drawback is that this requires more system RAM (~25 GB). See the [DeepSpeed documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more configuration options.
+
+You should also change the default Adam optimizer to DeepSpeed’s optimized version of Adam [`deepspeed.ops.adam.DeepSpeedCPUAdam`](https://deepspeed.readthedocs.io/en/latest/optimizers.html#adam-cpu) for a substantial speedup. Enabling `DeepSpeedCPUAdam` requires your system’s CUDA toolchain version to be the same as the one installed with PyTorch.
+
+bitsandbytes 8-bit optimizers don’t seem to be compatible with DeepSpeed at the moment.
+
+That's it! You don't need to add any additional parameters to your training command.
+
+
+
+
+
+
+
+```bash
+export MODEL_NAME="runwayml/stable-diffusion-v1-5"
+export INSTANCE_DIR="./dog"
+export OUTPUT_DIR="path_to_saved_model"
+
+accelerate launch train_dreambooth.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --instance_prompt="a photo of sks dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=1 \
+ --learning_rate=5e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --max_train_steps=400 \
+ --push_to_hub
+```
+
+
+
+
+```bash
+export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
+export INSTANCE_DIR="./dog"
+export OUTPUT_DIR="path-to-save-model"
+
+python train_dreambooth_flax.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --instance_prompt="a photo of sks dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --learning_rate=5e-6 \
+ --max_train_steps=400 \
+ --push_to_hub
+```
+
+
+
+
+Once training is complete, you can use your newly trained model for inference!
+
+
+
+Can't wait to try your model for inference before training is complete? 🤭 Make sure you have the latest version of 🤗 Accelerate installed.
+
+```py
+from diffusers import DiffusionPipeline, UNet2DConditionModel
+from transformers import CLIPTextModel
+import torch
+
+unet = UNet2DConditionModel.from_pretrained("path/to/model/checkpoint-100/unet")
+
+# if you have trained with `--train_text_encoder` make sure to also load the text encoder
+text_encoder = CLIPTextModel.from_pretrained("path/to/model/checkpoint-100/text_encoder")
+
+pipeline = DiffusionPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5", unet=unet, text_encoder=text_encoder, torch_dtype=torch.float16,
+).to("cuda")
+
+image = pipeline("A photo of sks dog in a bucket", num_inference_steps=50, guidance_scale=7.5).images[0]
+image.save("dog-bucket.png")
+```
+
+
+
+
+
+
+```py
+from diffusers import DiffusionPipeline
+import torch
+
+pipeline = DiffusionPipeline.from_pretrained("path_to_saved_model", torch_dtype=torch.float16, use_safetensors=True).to("cuda")
+image = pipeline("A photo of sks dog in a bucket", num_inference_steps=50, guidance_scale=7.5).images[0]
+image.save("dog-bucket.png")
+```
+
+
+
+
+```py
+import jax
+import numpy as np
+from flax.jax_utils import replicate
+from flax.training.common_utils import shard
+from diffusers import FlaxStableDiffusionPipeline
+
+pipeline, params = FlaxStableDiffusionPipeline.from_pretrained("path-to-your-trained-model", dtype=jax.numpy.bfloat16)
+
+prompt = "A photo of sks dog in a bucket"
+prng_seed = jax.random.PRNGKey(0)
+num_inference_steps = 50
+
+num_samples = jax.device_count()
+prompt = num_samples * [prompt]
+prompt_ids = pipeline.prepare_inputs(prompt)
+
+# shard inputs and rng
+params = replicate(params)
+prng_seed = jax.random.split(prng_seed, jax.device_count())
+prompt_ids = shard(prompt_ids)
+
+images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
+images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
+images[0].save("dog-bucket.png")
+```
+
+
+
+
+## LoRA
+
+LoRA is a training technique for significantly reducing the number of trainable parameters. As a result, training is faster and it is easier to store the resulting weights because they are a lot smaller (~100MBs). Use the [train_dreambooth_lora.py](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora.py) script to train with LoRA.
+
+The LoRA training script is discussed in more detail in the [LoRA training](lora) guide.
+
+## Stable Diffusion XL
+
+Stable Diffusion XL (SDXL) is a powerful text-to-image model that generates high-resolution images, and it adds a second text encoder to its architecture. Use the [train_dreambooth_lora_sdxl.py](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_sdxl.py) script to train an SDXL model with LoRA.
+
+The SDXL training script is discussed in more detail in the [SDXL training](sdxl) guide.
+
+## Next steps
+
+Congratulations on training your DreamBooth model! To learn more about how to use your new model, the following guide may be helpful:
+
+- Learn how to [load a DreamBooth](../using-diffusers/loading_adapters) model for inference if you trained your model with LoRA.
\ No newline at end of file
diff --git a/diffusers/docs/source/en/training/instructpix2pix.md b/diffusers/docs/source/en/training/instructpix2pix.md
new file mode 100644
index 0000000000000000000000000000000000000000..7e17af2cd988ce66da53fc5ab4acc8137585e98e
--- /dev/null
+++ b/diffusers/docs/source/en/training/instructpix2pix.md
@@ -0,0 +1,252 @@
+
+
+# InstructPix2Pix
+
+[InstructPix2Pix](https://hf.co/papers/2211.09800) is a Stable Diffusion model trained to edit images from human-provided instructions. For example, your prompt can be "turn the clouds rainy" and the model will edit the input image accordingly. This model is conditioned on the text prompt (or editing instruction) and the input image.
+
+This guide will explore the [train_instruct_pix2pix.py](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py) training script to help you become familiar with it, and how you can adapt it for your own use-case.
+
+Before running the script, make sure you install the library from source:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Then navigate to the example folder containing the training script and install the required dependencies for the script you're using:
+
+```bash
+cd examples/instruct_pix2pix
+pip install -r requirements.txt
+```
+
+
+
+🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.
+
+
+
+Initialize an 🤗 Accelerate environment:
+
+```bash
+accelerate config
+```
+
+To set up a default 🤗 Accelerate environment without choosing any configurations:
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell, like a notebook, you can use:
+
+```py
+from accelerate.utils import write_basic_config
+
+write_basic_config()
+```
+
+Lastly, if you want to train a model on your own dataset, take a look at the [Create a dataset for training](create_dataset) guide to learn how to create a dataset that works with the training script.
+
+
+
+The following sections highlight parts of the training script that are important for understanding how to modify it, but it doesn't cover every aspect of the script in detail. If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py) and let us know if you have any questions or concerns.
+
+
+
+## Script parameters
+
+The training script has many parameters to help you customize your training run. All of the parameters and their descriptions are found in the [`parse_args()`](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L65) function. Default values are provided for most parameters that work pretty well, but you can also set your own values in the training command if you'd like.
+
+For example, to increase the resolution of the input image:
+
+```bash
+accelerate launch train_instruct_pix2pix.py \
+ --resolution=512 \
+```
+
+Many of the basic and important parameters are described in the [Text-to-image](text2image#script-parameters) training guide, so this guide just focuses on the relevant parameters for InstructPix2Pix:
+
+- `--original_image_column`: the original image before the edits are made
+- `--edited_image_column`: the image after the edits are made
+- `--edit_prompt_column`: the instructions to edit the image
+- `--conditioning_dropout_prob`: the dropout probability for the edited image and edit prompts during training which enables classifier-free guidance (CFG) for one or both conditioning inputs
+
+## Training script
+
+The dataset preprocessing code and training loop are found in the [`main()`](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L374) function. This is where you'll make your changes to the training script to adapt it for your own use-case.
+
+As with the script parameters, a walkthrough of the training script is provided in the [Text-to-image](text2image#training-script) training guide, so this guide focuses on the parts of the script that are specific to InstructPix2Pix.
+
+The script begins by modifying the [number of input channels](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L445) in the first convolutional layer of the UNet to account for InstructPix2Pix's additional conditioning image:
+
+```py
+in_channels = 8
+out_channels = unet.conv_in.out_channels
+unet.register_to_config(in_channels=in_channels)
+
+with torch.no_grad():
+ new_conv_in = nn.Conv2d(
+ in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding
+ )
+ new_conv_in.weight.zero_()
+ new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
+ unet.conv_in = new_conv_in
+```
+
+These UNet parameters are [updated](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L545C1-L551C6) by the optimizer:
+
+```py
+optimizer = optimizer_cls(
+ unet.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+)
+```
+
+Next, the edited images and edit instructions are [preprocessed](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L624) and [tokenized](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L610C24-L610C24). It is important that the same image transformations are applied to the original and edited images.
+
+```py
+def preprocess_train(examples):
+ preprocessed_images = preprocess_images(examples)
+
+ original_images, edited_images = preprocessed_images.chunk(2)
+ original_images = original_images.reshape(-1, 3, args.resolution, args.resolution)
+ edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution)
+
+ examples["original_pixel_values"] = original_images
+ examples["edited_pixel_values"] = edited_images
+
+ captions = list(examples[edit_prompt_column])
+ examples["input_ids"] = tokenize_captions(captions)
+ return examples
+```
+
+Finally, in the [training loop](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L730), it starts by encoding the edited images into latent space:
+
+```py
+latents = vae.encode(batch["edited_pixel_values"].to(weight_dtype)).latent_dist.sample()
+latents = latents * vae.config.scaling_factor
+```
+
+Then, the script applies dropout to the original image and edit instruction embeddings to support CFG. This is what enables the model to modulate the influence of the edit instruction and original image on the edited image.
+
+```py
+encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+original_image_embeds = vae.encode(batch["original_pixel_values"].to(weight_dtype)).latent_dist.mode()
+
+if args.conditioning_dropout_prob is not None:
+ random_p = torch.rand(bsz, device=latents.device, generator=generator)
+ prompt_mask = random_p < 2 * args.conditioning_dropout_prob
+ prompt_mask = prompt_mask.reshape(bsz, 1, 1)
+ null_conditioning = text_encoder(tokenize_captions([""]).to(accelerator.device))[0]
+ encoder_hidden_states = torch.where(prompt_mask, null_conditioning, encoder_hidden_states)
+
+ image_mask_dtype = original_image_embeds.dtype
+ image_mask = 1 - (
+ (random_p >= args.conditioning_dropout_prob).to(image_mask_dtype)
+ * (random_p < 3 * args.conditioning_dropout_prob).to(image_mask_dtype)
+ )
+ image_mask = image_mask.reshape(bsz, 1, 1, 1)
+ original_image_embeds = image_mask * original_image_embeds
+```
+
+That's pretty much it! Aside from the differences described here, the rest of the script is very similar to the [Text-to-image](text2image#training-script) training script, so feel free to check it out for more details. If you want to learn more about how the training loop works, check out the [Understanding pipelines, models and schedulers](../using-diffusers/write_own_pipeline) tutorial which breaks down the basic pattern of the denoising process.
+
+## Launch the script
+
+Once you're happy with the changes to your script or if you're okay with the default configuration, you're ready to launch the training script! 🚀
+
+This guide uses the [fusing/instructpix2pix-1000-samples](https://huggingface.co/datasets/fusing/instructpix2pix-1000-samples) dataset, which is a smaller version of the [original dataset](https://huggingface.co/datasets/timbrooks/instructpix2pix-clip-filtered). You can also create and use your own dataset if you'd like (see the [Create a dataset for training](create_dataset) guide).
+
+Set the `MODEL_NAME` environment variable to the name of the model (can be a model id on the Hub or a path to a local model), and the `DATASET_ID` to the name of the dataset on the Hub. The script creates and saves all the components (feature extractor, scheduler, text encoder, UNet, etc.) to a subfolder in your repository.
+
+
+
+For better results, try longer training runs with a larger dataset. We've only tested this training script on a smaller-scale dataset.
+
+
+
+To monitor training progress with Weights and Biases, add the `--report_to=wandb` parameter to the training command and specify a validation image with `--val_image_url` and a validation prompt with `--validation_prompt`. This can be really useful for debugging the model.
+
+
+
+If you’re training on more than one GPU, add the `--multi_gpu` parameter to the `accelerate launch` command.
+
+```bash
+accelerate launch --mixed_precision="fp16" train_instruct_pix2pix.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --dataset_name=$DATASET_ID \
+ --enable_xformers_memory_efficient_attention \
+ --resolution=256 \
+ --random_flip \
+ --train_batch_size=4 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --max_train_steps=15000 \
+ --checkpointing_steps=5000 \
+ --checkpoints_total_limit=1 \
+ --learning_rate=5e-05 \
+ --max_grad_norm=1 \
+ --lr_warmup_steps=0 \
+ --conditioning_dropout_prob=0.05 \
+ --mixed_precision=fp16 \
+ --seed=42 \
+ --push_to_hub
+```
+
+After training is finished, you can use your new InstructPix2Pix for inference:
+
+```py
+import PIL
+import requests
+import torch
+from diffusers import StableDiffusionInstructPix2PixPipeline
+from diffusers.utils import load_image
+
+pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained("your_cool_model", torch_dtype=torch.float16).to("cuda")
+generator = torch.Generator("cuda").manual_seed(0)
+
+image = load_image("https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/test_pix2pix_4.png")
+prompt = "add some ducks to the lake"
+num_inference_steps = 20
+image_guidance_scale = 1.5
+guidance_scale = 10
+
+edited_image = pipeline(
+ prompt,
+ image=image,
+ num_inference_steps=num_inference_steps,
+ image_guidance_scale=image_guidance_scale,
+ guidance_scale=guidance_scale,
+ generator=generator,
+).images[0]
+edited_image.save("edited_image.png")
+```
+
+You should experiment with different `num_inference_steps`, `image_guidance_scale`, and `guidance_scale` values to see how they affect inference speed and quality. The guidance scale parameters are especially impactful because they control how much the original image and edit instructions affect the edited image.
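+
+For example, a small (illustrative) grid search over the two guidance scales could look like this, reusing the `pipeline`, `image`, and `prompt` from the snippet above; the value grids are arbitrary starting points:
+
+```py
+for image_guidance_scale in (1.0, 1.5, 2.0):
+    for guidance_scale in (5.0, 7.5, 10.0):
+        edited = pipeline(
+            prompt,
+            image=image,
+            num_inference_steps=20,
+            image_guidance_scale=image_guidance_scale,
+            guidance_scale=guidance_scale,
+            generator=torch.Generator("cuda").manual_seed(0),  # same seed for a fair comparison
+        ).images[0]
+        edited.save(f"edited_igs{image_guidance_scale}_gs{guidance_scale}.png")
+```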
+
+## Stable Diffusion XL
+
+Stable Diffusion XL (SDXL) is a powerful text-to-image model that generates high-resolution images, and it adds a second text encoder to its architecture. Use the [`train_instruct_pix2pix_sdxl.py`](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py) script to train an SDXL model to follow image editing instructions.
+
+The SDXL training script is discussed in more detail in the [SDXL training](sdxl) guide.
+
+## Next steps
+
+Congratulations on training your own InstructPix2Pix model! 🥳 To learn more about the model, it may be helpful to:
+
+- Read the [Instruction-tuning Stable Diffusion with InstructPix2Pix](https://huggingface.co/blog/instruction-tuning-sd) blog post to learn more about some experiments we've done with InstructPix2Pix, dataset preparation, and results for different instructions.
\ No newline at end of file
diff --git a/diffusers/docs/source/en/training/kandinsky.md b/diffusers/docs/source/en/training/kandinsky.md
new file mode 100644
index 0000000000000000000000000000000000000000..b2174996be0de1e50dfcbe06fd2c212d236159c1
--- /dev/null
+++ b/diffusers/docs/source/en/training/kandinsky.md
@@ -0,0 +1,327 @@
+
+
+# Kandinsky 2.2
+
+
+
+This script is experimental, and it's easy to overfit and run into issues like catastrophic forgetting. Try exploring different hyperparameters to get the best results on your dataset.
+
+
+
+Kandinsky 2.2 is a multilingual text-to-image model capable of producing more photorealistic images. The model includes an image prior model for creating image embeddings from text prompts, and a decoder model that generates images based on the prior model's embeddings. That's why you'll find two separate scripts in Diffusers for Kandinsky 2.2, one for training the prior model and one for training the decoder model. You can train both models separately, but to get the best results, you should train both the prior and decoder models.
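+
+To make the two-stage design concrete, here is a rough inference sketch with the public Kandinsky 2.2 checkpoints (once your own prior and decoder are trained and saved, you can point the pipelines at those paths instead):
+
+```py
+import torch
+from diffusers import KandinskyV22Pipeline, KandinskyV22PriorPipeline
+
+# stage 1: the prior maps a text prompt to CLIP image embeddings
+prior = KandinskyV22PriorPipeline.from_pretrained(
+    "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
+).to("cuda")
+image_embeds, negative_image_embeds = prior("A robot pokemon, 4k photo").to_tuple()
+
+# stage 2: the decoder turns those embeddings into an image
+decoder = KandinskyV22Pipeline.from_pretrained(
+    "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
+).to("cuda")
+image = decoder(
+    image_embeds=image_embeds, negative_image_embeds=negative_image_embeds, height=768, width=768
+).images[0]
+```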
+
+Depending on your GPU, you may need to enable `gradient_checkpointing` (⚠️ not supported for the prior model!), `mixed_precision`, and `gradient_accumulation_steps` to help fit the model into memory and to speedup training. You can reduce your memory-usage even more by enabling memory-efficient attention with [xFormers](../optimization/xformers) (version [v0.0.16](https://github.com/huggingface/diffusers/issues/2234#issuecomment-1416931212) fails for training on some GPUs so you may need to install a development version instead).
+
+This guide explores the [train_text_to_image_prior.py](https://github.com/huggingface/diffusers/blob/main/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py) and the [train_text_to_image_decoder.py](https://github.com/huggingface/diffusers/blob/main/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py) scripts to help you become more familiar with them, and how you can adapt them for your own use-case.
+
+Before running the scripts, make sure you install the library from source:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Then navigate to the example folder containing the training script and install the required dependencies for the script you're using:
+
+```bash
+cd examples/kandinsky2_2/text_to_image
+pip install -r requirements.txt
+```
+
+
+
+🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.
+
+
+
+Initialize an 🤗 Accelerate environment:
+
+```bash
+accelerate config
+```
+
+To set up a default 🤗 Accelerate environment without choosing any configurations:
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell, like a notebook, you can use:
+
+```py
+from accelerate.utils import write_basic_config
+
+write_basic_config()
+```
+
+Lastly, if you want to train a model on your own dataset, take a look at the [Create a dataset for training](create_dataset) guide to learn how to create a dataset that works with the training script.
+
+
+
+The following sections highlight parts of the training scripts that are important for understanding how to modify them, but they don't cover every aspect of the scripts in detail. If you're interested in learning more, feel free to read through the scripts and let us know if you have any questions or concerns.
+
+
+
+## Script parameters
+
+The training scripts provide many parameters to help you customize your training run. All of the parameters and their descriptions are found in the [`parse_args()`](https://github.com/huggingface/diffusers/blob/6e68c71503682c8693cb5b06a4da4911dfd655ee/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py#L190) function. The training scripts provide default values for each parameter, such as the training batch size and learning rate, but you can also set your own values in the training command if you'd like.
+
+For example, to speedup training with mixed precision using the fp16 format, add the `--mixed_precision` parameter to the training command:
+
+```bash
+accelerate launch train_text_to_image_prior.py \
+ --mixed_precision="fp16"
+```
+
+Most of the parameters are identical to the parameters in the [Text-to-image](text2image#script-parameters) training guide, so let's get straight to a walkthrough of the Kandinsky training scripts!
+
+### Min-SNR weighting
+
+The [Min-SNR](https://huggingface.co/papers/2303.09556) weighting strategy can help with training by rebalancing the loss to achieve faster convergence. The training script supports predicting `epsilon` (noise) or `v_prediction`, and Min-SNR is compatible with both prediction types. This weighting strategy is only supported by PyTorch and is unavailable in the Flax training script.
+
+Add the `--snr_gamma` parameter and set it to the recommended value of 5.0:
+
+```bash
+accelerate launch train_text_to_image_prior.py \
+ --snr_gamma=5.0
+```
+
+## Training script
+
+The training script is also similar to the [Text-to-image](text2image#training-script) training guide, but it's been modified to support training the prior and decoder models. This guide focuses on the code that is unique to the Kandinsky 2.2 training scripts.
+
+
+
+
+The [`main()`](https://github.com/huggingface/diffusers/blob/6e68c71503682c8693cb5b06a4da4911dfd655ee/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py#L441) function contains the code for preparing the dataset and training the model.
+
+One of the main differences you'll notice right away is that the training script also loads a [`~transformers.CLIPImageProcessor`] - in addition to a scheduler and tokenizer - for preprocessing images and a [`~transformers.CLIPVisionModelWithProjection`] model for encoding the images:
+
+```py
+noise_scheduler = DDPMScheduler(beta_schedule="squaredcos_cap_v2", prediction_type="sample")
+image_processor = CLIPImageProcessor.from_pretrained(
+ args.pretrained_prior_model_name_or_path, subfolder="image_processor"
+)
+tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="tokenizer")
+
+with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+ args.pretrained_prior_model_name_or_path, subfolder="image_encoder", torch_dtype=weight_dtype
+ ).eval()
+ text_encoder = CLIPTextModelWithProjection.from_pretrained(
+ args.pretrained_prior_model_name_or_path, subfolder="text_encoder", torch_dtype=weight_dtype
+ ).eval()
+```
+
+Kandinsky uses a [`PriorTransformer`] to generate the image embeddings, so you'll want to set up the optimizer to learn the prior model's parameters.
+
+```py
+prior = PriorTransformer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")
+prior.train()
+optimizer = optimizer_cls(
+ prior.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+)
+```
+
+Next, the input captions are tokenized, and images are [preprocessed](https://github.com/huggingface/diffusers/blob/6e68c71503682c8693cb5b06a4da4911dfd655ee/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py#L632) by the [`~transformers.CLIPImageProcessor`]:
+
+```py
+def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ examples["clip_pixel_values"] = image_processor(images, return_tensors="pt").pixel_values
+ examples["text_input_ids"], examples["text_mask"] = tokenize_captions(examples)
+ return examples
+```
+
+Finally, the [training loop](https://github.com/huggingface/diffusers/blob/6e68c71503682c8693cb5b06a4da4911dfd655ee/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py#L718) converts the input images into latents, adds noise to the image embeddings, and makes a prediction:
+
+```py
+model_pred = prior(
+ noisy_latents,
+ timestep=timesteps,
+ proj_embedding=prompt_embeds,
+ encoder_hidden_states=text_encoder_hidden_states,
+ attention_mask=text_mask,
+).predicted_image_embedding
+```
+
+If you want to learn more about how the training loop works, check out the [Understanding pipelines, models and schedulers](../using-diffusers/write_own_pipeline) tutorial which breaks down the basic pattern of the denoising process.
+
+
+
+
+The [`main()`](https://github.com/huggingface/diffusers/blob/6e68c71503682c8693cb5b06a4da4911dfd655ee/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py#L440) function contains the code for preparing the dataset and training the model.
+
+Unlike the prior model, the decoder initializes a [`VQModel`] to decode the latents into images and it uses a [`UNet2DConditionModel`]:
+
+```py
+with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
+ vae = VQModel.from_pretrained(
+ args.pretrained_decoder_model_name_or_path, subfolder="movq", torch_dtype=weight_dtype
+ ).eval()
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+ args.pretrained_prior_model_name_or_path, subfolder="image_encoder", torch_dtype=weight_dtype
+ ).eval()
+unet = UNet2DConditionModel.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder="unet")
+```
+
+Next, the script includes several image transforms and a [preprocessing](https://github.com/huggingface/diffusers/blob/6e68c71503682c8693cb5b06a4da4911dfd655ee/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py#L622) function for applying the transforms to the images and returning the pixel values:
+
+```py
+def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ examples["pixel_values"] = [train_transforms(image) for image in images]
+ examples["clip_pixel_values"] = image_processor(images, return_tensors="pt").pixel_values
+ return examples
+```
+
+Lastly, the [training loop](https://github.com/huggingface/diffusers/blob/6e68c71503682c8693cb5b06a4da4911dfd655ee/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py#L706) handles converting the images to latents, adding noise, and predicting the noise residual.
+
+If you want to learn more about how the training loop works, check out the [Understanding pipelines, models and schedulers](../using-diffusers/write_own_pipeline) tutorial which breaks down the basic pattern of the denoising process.
+
+```py
+model_pred = unet(noisy_latents, timesteps, None, added_cond_kwargs=added_cond_kwargs).sample[:, :4]
+```
+
+
+
+
+## Launch the script
+
+Once you’ve made all your changes or you’re okay with the default configuration, you’re ready to launch the training script! 🚀
+
+You'll train on the [Pokémon BLIP captions](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions) dataset to generate your own Pokémon, but you can also create and train on your own dataset by following the [Create a dataset for training](create_dataset) guide. Set the environment variable `DATASET_NAME` to the name of the dataset on the Hub or if you're training on your own files, set the environment variable `TRAIN_DIR` to a path to your dataset.
+
+If you’re training on more than one GPU, add the `--multi_gpu` parameter to the `accelerate launch` command.
+
+
+
+To monitor training progress with Weights & Biases, add the `--report_to=wandb` parameter to the training command. You’ll also need to add the `--validation_prompt` to the training command to keep track of results. This can be really useful for debugging the model and viewing intermediate results.
+
+
+
+
+
+
+```bash
+export DATASET_NAME="lambdalabs/pokemon-blip-captions"
+
+accelerate launch --mixed_precision="fp16" train_text_to_image_prior.py \
+ --dataset_name=$DATASET_NAME \
+ --resolution=768 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --max_train_steps=15000 \
+ --learning_rate=1e-05 \
+ --max_grad_norm=1 \
+ --checkpoints_total_limit=3 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --validation_prompts="A robot pokemon, 4k photo" \
+ --report_to="wandb" \
+ --push_to_hub \
+ --output_dir="kandi2-prior-pokemon-model"
+```
+
+
+
+
+```bash
+export DATASET_NAME="lambdalabs/pokemon-blip-captions"
+
+accelerate launch --mixed_precision="fp16" train_text_to_image_decoder.py \
+ --dataset_name=$DATASET_NAME \
+ --resolution=768 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --max_train_steps=15000 \
+ --learning_rate=1e-05 \
+ --max_grad_norm=1 \
+ --checkpoints_total_limit=3 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --validation_prompts="A robot pokemon, 4k photo" \
+ --report_to="wandb" \
+ --push_to_hub \
+ --output_dir="kandi2-decoder-pokemon-model"
+```
+
+
+
+
+Once training is finished, you can use your newly trained model for inference!
+
+
+
+
+```py
+from diffusers import AutoPipelineForText2Image, DiffusionPipeline
+import torch
+
+prior_pipeline = DiffusionPipeline.from_pretrained("path/to/saved/prior/model", torch_dtype=torch.float16)
+prior_components = {"prior_" + k: v for k, v in prior_pipeline.components.items()}
+pipeline = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", **prior_components, torch_dtype=torch.float16)
+pipeline.enable_model_cpu_offload()
+
+prompt = "A robot pokemon, 4k photo"
+negative_prompt = "low quality, bad quality"
+image = pipeline(prompt=prompt, negative_prompt=negative_prompt).images[0]
+```
+
+
+
+Feel free to replace `kandinsky-community/kandinsky-2-2-decoder` with your own trained decoder checkpoint!
+
+
+
+
+
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+
+pipeline = AutoPipelineForText2Image.from_pretrained("path/to/saved/model", torch_dtype=torch.float16)
+pipeline.enable_model_cpu_offload()
+
+prompt="A robot pokemon, 4k photo"
+image = pipeline(prompt=prompt).images[0]
+```
+
+For the decoder model, you can also perform inference from a saved checkpoint which can be useful for viewing intermediate results. In this case, load the checkpoint into the UNet:
+
+```py
+from diffusers import AutoPipelineForText2Image, UNet2DConditionModel
+import torch
+
+unet = UNet2DConditionModel.from_pretrained("path/to/saved/model/checkpoint-<N>/unet")
+
+pipeline = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", unet=unet, torch_dtype=torch.float16)
+pipeline.enable_model_cpu_offload()
+
+image = pipeline(prompt="A robot pokemon, 4k photo").images[0]
+```
+
+
+
+
+## Next steps
+
+Congratulations on training a Kandinsky 2.2 model! To learn more about how to use your new model, the following guides may be helpful:
+
+- Read the [Kandinsky](../using-diffusers/kandinsky) guide to learn how to use it for a variety of different tasks (text-to-image, image-to-image, inpainting, interpolation), and how it can be combined with a ControlNet.
+- Check out the [DreamBooth](dreambooth) and [LoRA](lora) training guides to learn how to train a personalized Kandinsky model with just a few example images. These two training techniques can even be combined!
diff --git a/diffusers/docs/source/en/training/lora.md b/diffusers/docs/source/en/training/lora.md
new file mode 100644
index 0000000000000000000000000000000000000000..9ad088917dbca8a4d22d7f62b0dbe205024aa0a2
--- /dev/null
+++ b/diffusers/docs/source/en/training/lora.md
@@ -0,0 +1,217 @@
+
+
+# LoRA
+
+
+
+This is experimental and the API may change in the future.
+
+
+
+[LoRA (Low-Rank Adaptation of Large Language Models)](https://hf.co/papers/2106.09685) is a popular and lightweight training technique that significantly reduces the number of trainable parameters. It works by inserting a smaller number of new weights into the model and only these are trained. This makes training with LoRA much faster, memory-efficient, and produces smaller model weights (a few hundred MBs), which are easier to store and share. LoRA can also be combined with other training techniques like DreamBooth to speedup training.
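+
+As a rough illustration of the idea (this is not the training script's implementation), a LoRA update keeps the pretrained weight frozen and learns a small low-rank pair of matrices on top of it:
+
+```py
+import torch
+import torch.nn as nn
+
+class LoRALinear(nn.Module):
+    """Minimal sketch: y = W x + (alpha / rank) * B(A(x)), with W frozen."""
+
+    def __init__(self, base: nn.Linear, rank: int = 4, alpha: float = 4.0):
+        super().__init__()
+        self.base = base
+        for p in self.base.parameters():
+            p.requires_grad_(False)         # the pretrained weight stays frozen
+        self.lora_a = nn.Linear(base.in_features, rank, bias=False)
+        self.lora_b = nn.Linear(rank, base.out_features, bias=False)
+        nn.init.zeros_(self.lora_b.weight)  # start as a no-op so training begins from the base model
+        self.scale = alpha / rank
+
+    def forward(self, x):
+        return self.base(x) + self.scale * self.lora_b(self.lora_a(x))
+
+layer = LoRALinear(nn.Linear(768, 768), rank=8)
+out = layer(torch.randn(1, 77, 768))        # only lora_a and lora_b receive gradients
+```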
+
+
+
+LoRA is very versatile and supported for [DreamBooth](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora.py), [Kandinsky 2.2](https://github.com/huggingface/diffusers/blob/main/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py), [Stable Diffusion XL](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora_sdxl.py), [text-to-image](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py), and [Wuerstchen](https://github.com/huggingface/diffusers/blob/main/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py).
+
+
+
+This guide will explore the [train_text_to_image_lora.py](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py) script to help you become more familiar with it, and how you can adapt it for your own use-case.
+
+Before running the script, make sure you install the library from source:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Navigate to the example folder with the training script and install the required dependencies for the script you're using:
+
+For PyTorch:
+
+```bash
+cd examples/text_to_image
+pip install -r requirements.txt
+```
+
+For Flax:
+
+```bash
+cd examples/text_to_image
+pip install -r requirements_flax.txt
+```
+
+
+
+
+
+
+🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.
+
+
+
+Initialize an 🤗 Accelerate environment:
+
+```bash
+accelerate config
+```
+
+To set up a default 🤗 Accelerate environment without choosing any configurations:
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell, like a notebook, you can use:
+
+```py
+from accelerate.utils import write_basic_config
+
+write_basic_config()
+```
+
+Lastly, if you want to train a model on your own dataset, take a look at the [Create a dataset for training](create_dataset) guide to learn how to create a dataset that works with the training script.
+
+
+
+The following sections highlight parts of the training script that are important for understanding how to modify it, but it doesn't cover every aspect of the script in detail. If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py) and let us know if you have any questions or concerns.
+
+
+
+## Script parameters
+
+The training script has many parameters to help you customize your training run. All of the parameters and their descriptions are found in the [`parse_args()`](https://github.com/huggingface/diffusers/blob/dd9a5caf61f04d11c0fa9f3947b69ab0010c9a0f/examples/text_to_image/train_text_to_image_lora.py#L85) function. Default values are provided for most parameters that work pretty well, but you can also set your own values in the training command if you'd like.
+
+For example, to increase the number of epochs to train:
+
+```bash
+accelerate launch train_text_to_image_lora.py \
+ --num_train_epochs=150 \
+```
+
+Many of the basic and important parameters are described in the [Text-to-image](text2image#script-parameters) training guide, so this guide just focuses on the LoRA relevant parameters:
+
+- `--rank`: the inner dimension of the low-rank matrices to train; a higher rank means more trainable parameters
+- `--learning_rate`: the default learning rate is 1e-4, but with LoRA, you can use a higher learning rate
+
+## Training script
+
+The dataset preprocessing code and training loop are found in the [`main()`](https://github.com/huggingface/diffusers/blob/dd9a5caf61f04d11c0fa9f3947b69ab0010c9a0f/examples/text_to_image/train_text_to_image_lora.py#L371) function, and if you need to adapt the training script, this is where you'll make your changes.
+
+As with the script parameters, a walkthrough of the training script is provided in the [Text-to-image](text2image#training-script) training guide, so this guide focuses on the parts of the script that are specific to LoRA.
+
+The script begins by adding the [new LoRA weights](https://github.com/huggingface/diffusers/blob/dd9a5caf61f04d11c0fa9f3947b69ab0010c9a0f/examples/text_to_image/train_text_to_image_lora.py#L447) to the attention layers. This involves correctly configuring the weight size for each block in the UNet. You'll see the `rank` parameter is used to create the [`~models.attention_processor.LoRAAttnProcessor`]:
+
+```py
+lora_attn_procs = {}
+for name in unet.attn_processors.keys():
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+ if name.startswith("mid_block"):
+ hidden_size = unet.config.block_out_channels[-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = unet.config.block_out_channels[block_id]
+
+ lora_attn_procs[name] = LoRAAttnProcessor(
+ hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim,
+ rank=args.rank,
+ )
+
+unet.set_attn_processor(lora_attn_procs)
+lora_layers = AttnProcsLayers(unet.attn_processors)
+```
+
+The [optimizer](https://github.com/huggingface/diffusers/blob/dd9a5caf61f04d11c0fa9f3947b69ab0010c9a0f/examples/text_to_image/train_text_to_image_lora.py#L519) is initialized with the `lora_layers` because these are the only weights that'll be optimized:
+
+```py
+optimizer = optimizer_cls(
+ lora_layers.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+)
+```
+
+Aside from setting up the LoRA layers, the training script is more or less the same as train_text_to_image.py!
+
+## Launch the script
+
+Once you've made all your changes or you're okay with the default configuration, you're ready to launch the training script! 🚀
+
+Let's train on the [Pokémon BLIP captions](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions) dataset to generate your own Pokémon. Set the environment variables `MODEL_NAME` and `DATASET_NAME` to the model and dataset respectively. You should also specify where to save the model in `OUTPUT_DIR`, and the name of the model to save to on the Hub with `HUB_MODEL_ID`. The script creates and saves the following files to your repository:
+
+- saved model checkpoints
+- `pytorch_lora_weights.safetensors` (the trained LoRA weights)
+
+If you're training on more than one GPU, add the `--multi_gpu` parameter to the `accelerate launch` command.
+
+
+
+A full training run takes ~5 hours on a 2080 Ti GPU with 11GB of VRAM.
+
+
+
+```bash
+export MODEL_NAME="runwayml/stable-diffusion-v1-5"
+export OUTPUT_DIR="/sddata/finetune/lora/pokemon"
+export HUB_MODEL_ID="pokemon-lora"
+export DATASET_NAME="lambdalabs/pokemon-blip-captions"
+
+accelerate launch --mixed_precision="fp16" train_text_to_image_lora.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --dataset_name=$DATASET_NAME \
+ --dataloader_num_workers=8 \
+  --resolution=512 \
+ --center_crop \
+ --random_flip \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --max_train_steps=15000 \
+ --learning_rate=1e-04 \
+ --max_grad_norm=1 \
+ --lr_scheduler="cosine" \
+ --lr_warmup_steps=0 \
+ --output_dir=${OUTPUT_DIR} \
+ --push_to_hub \
+ --hub_model_id=${HUB_MODEL_ID} \
+ --report_to=wandb \
+ --checkpointing_steps=500 \
+ --validation_prompt="A pokemon with blue eyes." \
+ --seed=1337
+```
+
+Once training has been completed, you can use your model for inference:
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+
+pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
+pipeline.load_lora_weights("path/to/lora/model", weight_name="pytorch_lora_weights.safetensors")
+image = pipeline("A pokemon with blue eyes").images[0]
+```
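+
+You can also dial the LoRA's influence up or down at inference time with `cross_attention_kwargs`; the `0.7` below is just an example value (1.0 applies the full LoRA, 0.0 ignores it):
+
+```py
+image = pipeline("A pokemon with blue eyes", cross_attention_kwargs={"scale": 0.7}).images[0]
+```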
+
+## Next steps
+
+Congratulations on training a new model with LoRA! To learn more about how to use your new model, the following guides may be helpful:
+
+- Learn how to [load different LoRA formats](../using-diffusers/loading_adapters#LoRA) trained using community trainers like Kohya and TheLastBen.
+- Learn how to use and [combine multiple LoRAs](../tutorials/using_peft_for_inference) with PEFT for inference.
\ No newline at end of file
diff --git a/diffusers/docs/source/en/training/overview.md b/diffusers/docs/source/en/training/overview.md
new file mode 100644
index 0000000000000000000000000000000000000000..50a9417972a049dfe2207e6f26cfb8969733ff52
--- /dev/null
+++ b/diffusers/docs/source/en/training/overview.md
@@ -0,0 +1,63 @@
+
+
+# Overview
+
+🤗 Diffusers provides a collection of training scripts for you to train your own diffusion models. You can find all of our training scripts in [diffusers/examples](https://github.com/huggingface/diffusers/tree/main/examples).
+
+Each training script is:
+
+- **Self-contained**: the training script does not depend on any local files, and all packages required to run the script are installed from the `requirements.txt` file.
+- **Easy-to-tweak**: the training scripts are an example of how to train a diffusion model for a specific task and won't work out-of-the-box for every training scenario. You'll likely need to adapt the training script for your specific use-case. To help you with that, we've fully exposed the data preprocessing code and the training loop so you can modify it for your own use.
+- **Beginner-friendly**: the training scripts are designed to be beginner-friendly and easy to understand, rather than including the latest state-of-the-art methods to get the best and most competitive results. Any training methods we consider too complex are purposefully left out.
+- **Single-purpose**: each training script is expressly designed for only one task to keep it readable and understandable.
+
+Our current collection of training scripts include:
+
+| Training | SDXL-support | LoRA-support | Flax-support |
+|---|---|---|---|
+| [unconditional image generation](https://github.com/huggingface/diffusers/tree/main/examples/unconditional_image_generation) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) | | | |
+| [text-to-image](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image) | 👍 | 👍 | 👍 |
+| [textual inversion](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb) | | | 👍 |
+| [DreamBooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_dreambooth_training.ipynb) | 👍 | 👍 | 👍 |
+| [ControlNet](https://github.com/huggingface/diffusers/tree/main/examples/controlnet) | 👍 | | 👍 |
+| [InstructPix2Pix](https://github.com/huggingface/diffusers/tree/main/examples/instruct_pix2pix) | 👍 | | |
+| [Custom Diffusion](https://github.com/huggingface/diffusers/tree/main/examples/custom_diffusion) | | | |
+| [T2I-Adapters](https://github.com/huggingface/diffusers/tree/main/examples/t2i_adapter) | 👍 | | |
+| [Kandinsky 2.2](https://github.com/huggingface/diffusers/tree/main/examples/kandinsky2_2/text_to_image) | | 👍 | |
+| [Wuerstchen](https://github.com/huggingface/diffusers/tree/main/examples/wuerstchen/text_to_image) | | 👍 | |
+
+These examples are **actively** maintained, so please feel free to open an issue if they aren't working as expected. If you feel like another training example should be included, you're more than welcome to start a [Feature Request](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feature_request.md&title=) to discuss your feature idea with us and whether it meets our criteria of being self-contained, easy-to-tweak, beginner-friendly, and single-purpose.
+
+## Install
+
+Make sure you can successfully run the latest versions of the example scripts by installing the library from source in a new virtual environment:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Then navigate to the folder of the training script (for example, [DreamBooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth)) and install the `requirements.txt` file. Some training scripts have a specific requirement file for SDXL, LoRA or Flax. If you're using one of these scripts, make sure you install its corresponding requirements file.
+
+```bash
+cd examples/dreambooth
+pip install -r requirements.txt
+# to train SDXL with DreamBooth
+pip install -r requirements_sdxl.txt
+```
+
+To speed up training and reduce memory usage, we recommend:
+
+- using PyTorch 2.0 or higher to automatically use [scaled dot product attention](../optimization/torch2.0#scaled-dot-product-attention) during training (you don't need to make any changes to the training code)
+- installing [xFormers](../optimization/xformers) to enable memory-efficient attention
\ No newline at end of file
diff --git a/diffusers/docs/source/en/training/sdxl.md b/diffusers/docs/source/en/training/sdxl.md
new file mode 100644
index 0000000000000000000000000000000000000000..eebb614e907b44b6c22ac80333cefd38e82a43db
--- /dev/null
+++ b/diffusers/docs/source/en/training/sdxl.md
@@ -0,0 +1,266 @@
+
+
+# Stable Diffusion XL
+
+
+
+This script is experimental, and it's easy to overfit and run into issues like catastrophic forgetting. Try exploring different hyperparameters to get the best results on your dataset.
+
+
+
+[Stable Diffusion XL (SDXL)](https://hf.co/papers/2307.01952) is a larger and more powerful iteration of the Stable Diffusion model, capable of producing higher resolution images.
+
+SDXL's UNet is 3x larger and the model adds a second text encoder to the architecture. Depending on the hardware available to you, this can be very computationally intensive and it may not run on a consumer GPU like a Tesla T4. To help fit this larger model into memory and to speed up training, try enabling `gradient_checkpointing`, `mixed_precision`, and `gradient_accumulation_steps`. You can reduce your memory usage even more by enabling memory-efficient attention with [xFormers](../optimization/xformers) and using [bitsandbytes'](https://github.com/TimDettmers/bitsandbytes) 8-bit optimizer.
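+
+As a minimal sketch of how these options translate into code (assuming `bitsandbytes` and `xformers` are installed; the training script exposes the equivalent behavior through the command-line flags shown later in this guide):
+
+```py
+# Minimal sketch, not the training script itself: enable gradient checkpointing,
+# memory-efficient attention, and an 8-bit optimizer on the SDXL UNet.
+import bitsandbytes as bnb
+from diffusers import UNet2DConditionModel
+
+unet = UNet2DConditionModel.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
+)
+unet.enable_gradient_checkpointing()  # trade compute for memory
+unet.enable_xformers_memory_efficient_attention()  # memory-efficient attention
+optimizer = bnb.optim.AdamW8bit(unet.parameters(), lr=1e-6)  # 8-bit optimizer states
+```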
+
+This guide will explore the [train_text_to_image_sdxl.py](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_sdxl.py) training script to help you become more familiar with it, and how you can adapt it for your own use-case.
+
+Before running the script, make sure you install the library from source:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Then navigate to the example folder containing the training script and install the required dependencies for the script you're using:
+
+```bash
+cd examples/text_to_image
+pip install -r requirements_sdxl.txt
+```
+
+
+
+🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.
+
+
+
+Initialize an 🤗 Accelerate environment:
+
+```bash
+accelerate config
+```
+
+To setup a default 🤗 Accelerate environment without choosing any configurations:
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell, like a notebook, you can use:
+
+```py
+from accelerate.utils import write_basic_config
+
+write_basic_config()
+```
+
+Lastly, if you want to train a model on your own dataset, take a look at the [Create a dataset for training](create_dataset) guide to learn how to create a dataset that works with the training script.
+
+## Script parameters
+
+
+
+The following sections highlight parts of the training script that are important for understanding how to modify it, but it doesn't cover every aspect of the script in detail. If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_sdxl.py) and let us know if you have any questions or concerns.
+
+
+
+The training script provides many parameters to help you customize your training run. All of the parameters and their descriptions are found in the [`parse_args()`](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/text_to_image/train_text_to_image_sdxl.py#L129) function. This function provides default values for each parameter, such as the training batch size and learning rate, but you can also set your own values in the training command if you'd like.
+
+For example, to speed up training with mixed precision using the bf16 format, add the `--mixed_precision` parameter to the training command:
+
+```bash
+accelerate launch train_text_to_image_sdxl.py \
+ --mixed_precision="bf16"
+```
+
+Most of the parameters are identical to the parameters in the [Text-to-image](text2image#script-parameters) training guide, so this guide focuses on the parameters that are relevant to training SDXL:
+
+- `--pretrained_vae_model_name_or_path`: path to a pretrained VAE; the SDXL VAE is known to suffer from numerical instability, so this parameter allows you to specify a better [VAE](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)
+- `--proportion_empty_prompts`: the proportion of image prompts to replace with empty strings
+- `--timestep_bias_strategy`: where (earlier vs. later) in the timestep to apply a bias, which can encourage the model to either learn low or high frequency details
+- `--timestep_bias_multiplier`: the weight of the bias to apply to the timestep
+- `--timestep_bias_begin`: the timestep to begin applying the bias
+- `--timestep_bias_end`: the timestep to end applying the bias
+- `--timestep_bias_portion`: the proportion of timesteps to apply the bias to
+
+### Min-SNR weighting
+
+The [Min-SNR](https://huggingface.co/papers/2303.09556) weighting strategy can help with training by rebalancing the loss to achieve faster convergence. The training script supports predicting either `epsilon` (noise) or `v_prediction`, and Min-SNR is compatible with both prediction types. This weighting strategy is only supported by PyTorch and is unavailable in the Flax training script.
+
+Add the `--snr_gamma` parameter and set it to the recommended value of 5.0:
+
+```bash
+accelerate launch train_text_to_image_sdxl.py \
+ --snr_gamma=5.0
+```
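+
+For intuition, here is an illustrative sketch (not the script's exact code) of the Min-SNR weighting for `epsilon` prediction, where the per-timestep SNR is derived from the scheduler's `alphas_cumprod`:
+
+```py
+# Illustrative sketch of Min-SNR-gamma weighting: the per-sample MSE loss is
+# rescaled by min(SNR_t, gamma) / SNR_t for epsilon prediction.
+import torch
+from diffusers import DDPMScheduler
+
+noise_scheduler = DDPMScheduler.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler"
+)
+alphas_cumprod = noise_scheduler.alphas_cumprod
+
+timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (4,))
+snr = alphas_cumprod[timesteps] / (1.0 - alphas_cumprod[timesteps])
+
+snr_gamma = 5.0
+mse_loss_weights = torch.minimum(snr, torch.full_like(snr, snr_gamma)) / snr
+```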
+
+## Training script
+
+The training script is also similar to the [Text-to-image](text2image#training-script) training guide, but it's been modified to support SDXL training. This guide will focus on the code that is unique to the SDXL training script.
+
+It starts by creating functions to [tokenize the prompts](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/text_to_image/train_text_to_image_sdxl.py#L478) to calculate the prompt embeddings, and to compute the image embeddings with the [VAE](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/text_to_image/train_text_to_image_sdxl.py#L519). Next, you'll find a function to [generate the timestep weights](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/text_to_image/train_text_to_image_sdxl.py#L531) depending on the number of timesteps and the timestep bias strategy to apply.
+
+Within the [`main()`](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/text_to_image/train_text_to_image_sdxl.py#L572) function, in addition to loading a tokenizer, the script loads a second tokenizer and text encoder because the SDXL architecture uses two of each:
+
+```py
+tokenizer_one = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
+)
+tokenizer_two = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
+)
+
+text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision
+)
+text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+)
+```
+
+The [prompt and image embeddings](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/text_to_image/train_text_to_image_sdxl.py#L857) are computed first and kept in memory, which isn't typically an issue for a smaller dataset, but for larger datasets it can lead to memory problems. If this is the case, you should save the pre-computed embeddings to disk separately and load them into memory during the training process (see this [PR](https://github.com/huggingface/diffusers/pull/4505) for more discussion about this topic).
+
+```py
+text_encoders = [text_encoder_one, text_encoder_two]
+tokenizers = [tokenizer_one, tokenizer_two]
+compute_embeddings_fn = functools.partial(
+ encode_prompt,
+ text_encoders=text_encoders,
+ tokenizers=tokenizers,
+ proportion_empty_prompts=args.proportion_empty_prompts,
+ caption_column=args.caption_column,
+)
+
+train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)
+train_dataset = train_dataset.map(
+ compute_vae_encodings_fn,
+ batched=True,
+ batch_size=args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps,
+ new_fingerprint=new_fingerprint_for_vae,
+)
+```
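+
+If memory does become an issue, one option is to persist the mapped dataset (which now contains the precomputed embeddings) with 🤗 Datasets and reload it in a later run. A rough sketch, assuming `train_dataset` is the mapped dataset from above and `"precomputed_sdxl_embeddings"` is a path of your choosing:
+
+```py
+# Rough sketch: cache the precomputed embeddings on disk instead of keeping them in memory.
+train_dataset.save_to_disk("precomputed_sdxl_embeddings")
+
+# In a later run, skip the expensive encoding step and reload the cached dataset.
+from datasets import load_from_disk
+
+train_dataset = load_from_disk("precomputed_sdxl_embeddings")
+```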
+
+After calculating the embeddings, the text encoder, VAE, and tokenizer are deleted to free up some memory:
+
+```py
+del text_encoders, tokenizers, vae
+gc.collect()
+torch.cuda.empty_cache()
+```
+
+Finally, the [training loop](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/text_to_image/train_text_to_image_sdxl.py#L943) takes care of the rest. If you chose to apply a timestep bias strategy, you'll see the timestep weights are calculated and used to sample the timesteps at which noise is added to the model input:
+
+```py
+weights = generate_timestep_weights(args, noise_scheduler.config.num_train_timesteps).to(
+    model_input.device
+)
+timesteps = torch.multinomial(weights, bsz, replacement=True).long()
+
+noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
+```
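+
+As a toy illustration of what a "later" bias could look like (illustrative only; the script's `generate_timestep_weights` function implements the strategies selected by the `--timestep_bias_*` parameters):
+
+```py
+# Toy example of biasing timestep sampling toward later (noisier) timesteps.
+import torch
+
+num_train_timesteps = 1000
+weights = torch.ones(num_train_timesteps)
+
+# Hypothetical settings: upweight the last 25% of timesteps by a factor of 2.
+bias_portion, bias_multiplier = 0.25, 2.0
+num_biased = int(bias_portion * num_train_timesteps)
+weights[-num_biased:] *= bias_multiplier
+weights /= weights.sum()  # normalize into a sampling distribution
+
+timesteps = torch.multinomial(weights, num_samples=4, replacement=True).long()
+```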
+
+If you want to learn more about how the training loop works, check out the [Understanding pipelines, models and schedulers](../using-diffusers/write_own_pipeline) tutorial which breaks down the basic pattern of the denoising process.
+
+## Launch the script
+
+Once you’ve made all your changes or you’re okay with the default configuration, you’re ready to launch the training script! 🚀
+
+Let’s train on the [Pokémon BLIP captions](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions) dataset to generate your own Pokémon. Set the environment variables `MODEL_NAME` and `DATASET_NAME` to the model and the dataset (either from the Hub or a local path). You should also specify a VAE other than the SDXL VAE (either from the Hub or a local path) with `VAE_NAME` to avoid numerical instabilities.
+
+
+
+To monitor training progress with Weights & Biases, add the `--report_to=wandb` parameter to the training command. You'll also need to add the `--validation_prompt` and `--validation_epochs` parameters to the training command to keep track of results. This can be really useful for debugging the model and viewing intermediate results.
+
+
+
+```bash
+export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
+export VAE_NAME="madebyollin/sdxl-vae-fp16-fix"
+export DATASET_NAME="lambdalabs/pokemon-blip-captions"
+
+accelerate launch train_text_to_image_sdxl.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --pretrained_vae_model_name_or_path=$VAE_NAME \
+ --dataset_name=$DATASET_NAME \
+ --enable_xformers_memory_efficient_attention \
+ --resolution=512 \
+ --center_crop \
+ --random_flip \
+ --proportion_empty_prompts=0.2 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --max_train_steps=10000 \
+ --use_8bit_adam \
+ --learning_rate=1e-06 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --mixed_precision="fp16" \
+ --report_to="wandb" \
+ --validation_prompt="a cute Sundar Pichai creature" \
+ --validation_epochs 5 \
+ --checkpointing_steps=5000 \
+ --output_dir="sdxl-pokemon-model" \
+ --push_to_hub
+```
+
+After you've finished training, you can use your newly trained SDXL model for inference!
+
+
+
+
+```py
+from diffusers import DiffusionPipeline
+import torch
+
+pipeline = DiffusionPipeline.from_pretrained("path/to/your/model", torch_dtype=torch.float16).to("cuda")
+
+prompt = "A pokemon with green eyes and red legs."
+image = pipeline(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
+image.save("pokemon.png")
+```
+
+
+
+
+[PyTorch XLA](https://pytorch.org/xla) allows you to run PyTorch on XLA devices such as TPUs, which can be faster. The initial warmup step takes longer because the model needs to be compiled and optimized. However, subsequent calls to the pipeline on an input **with the same length** as the original prompt are much faster because it can reuse the optimized graph.
+
+```py
+from time import time
+
+from diffusers import DiffusionPipeline
+import torch
+import torch_xla.core.xla_model as xm
+
+device = xm.xla_device()
+pipeline = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0").to(device)
+
+prompt = "A pokemon with green eyes and red legs."
+inference_steps = 30  # keep this fixed so later calls can reuse the compiled graph
+
+start = time()
+image = pipeline(prompt, num_inference_steps=inference_steps).images[0]
+print(f'Compilation time is {time()-start} sec')
+image.save("pokemon.png")
+
+start = time()
+image = pipeline(prompt, num_inference_steps=inference_steps).images[0]
+print(f'Inference time is {time()-start} sec after compilation')
+```
+
+
+
+
+## Next steps
+
+Congratulations on training an SDXL model! To learn more about how to use your new model, the following guides may be helpful:
+
+- Read the [Stable Diffusion XL](../using-diffusers/sdxl) guide to learn how to use it for a variety of different tasks (text-to-image, image-to-image, inpainting), how to use its refiner model, and the different types of micro-conditionings.
+- Check out the [DreamBooth](dreambooth) and [LoRA](lora) training guides to learn how to train a personalized SDXL model with just a few example images. These two training techniques can even be combined!
\ No newline at end of file
diff --git a/diffusers/docs/source/en/training/t2i_adapters.md b/diffusers/docs/source/en/training/t2i_adapters.md
new file mode 100644
index 0000000000000000000000000000000000000000..9d4f292b1d3fe7f11da3e53359aa99e28aaef92c
--- /dev/null
+++ b/diffusers/docs/source/en/training/t2i_adapters.md
@@ -0,0 +1,227 @@
+
+
+# T2I-Adapter
+
+[T2I-Adapter](https://hf.co/papers/2302.08453) is a lightweight adapter model that provides an additional conditioning input image (line art, canny, sketch, depth, pose) to better control image generation. It is similar to a ControlNet, but it is a lot smaller (~77M parameters and ~300MB file size) because it only inserts weights into the UNet instead of copying and training it.
+
+The T2I-Adapter is only available for training with the Stable Diffusion XL (SDXL) model.
+
+This guide will explore the [train_t2i_adapter_sdxl.py](https://github.com/huggingface/diffusers/blob/main/examples/t2i_adapter/train_t2i_adapter_sdxl.py) training script to help you become familiar with it, and how you can adapt it for your own use-case.
+
+Before running the script, make sure you install the library from source:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Then navigate to the example folder containing the training script and install the required dependencies for the script you're using:
+
+```bash
+cd examples/t2i_adapter
+pip install -r requirements.txt
+```
+
+
+
+🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.
+
+
+
+Initialize an 🤗 Accelerate environment:
+
+```bash
+accelerate config
+```
+
+To setup a default 🤗 Accelerate environment without choosing any configurations:
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell, like a notebook, you can use:
+
+```py
+from accelerate.utils import write_basic_config
+
+write_basic_config()
+```
+
+Lastly, if you want to train a model on your own dataset, take a look at the [Create a dataset for training](create_dataset) guide to learn how to create a dataset that works with the training script.
+
+
+
+The following sections highlight parts of the training script that are important for understanding how to modify it, but it doesn't cover every aspect of the script in detail. If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/t2i_adapter/train_t2i_adapter_sdxl.py) and let us know if you have any questions or concerns.
+
+
+
+## Script parameters
+
+The training script provides many parameters to help you customize your training run. All of the parameters and their descriptions are found in the [`parse_args()`](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L233) function. It provides default values for each parameter, such as the training batch size and learning rate, but you can also set your own values in the training command if you'd like.
+
+For example, to activate gradient accumulation, add the `--gradient_accumulation_steps` parameter to the training command:
+
+```bash
+accelerate launch train_t2i_adapter_sdxl.py \
+  --gradient_accumulation_steps=4
+```
+
+Many of the basic and important parameters are described in the [Text-to-image](text2image#script-parameters) training guide, so this guide just focuses on the relevant T2I-Adapter parameters:
+
+- `--pretrained_vae_model_name_or_path`: path to a pretrained VAE; the SDXL VAE is known to suffer from numerical instability, so this parameter allows you to specify a better [VAE](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)
+- `--crops_coords_top_left_h` and `--crops_coords_top_left_w`: height and width coordinates to include in SDXL's crop coordinate embeddings
+- `--conditioning_image_column`: the column of the conditioning images in the dataset
+- `--proportion_empty_prompts`: the proportion of image prompts to replace with empty strings
+
+## Training script
+
+As with the script parameters, a walkthrough of the training script is provided in the [Text-to-image](text2image#training-script) training guide, so this guide focuses on the parts of the script that are relevant to the T2I-Adapter.
+
+The training script begins by preparing the dataset. This includes [tokenizing](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L674) the prompt and [applying transforms](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L714) to the images and conditioning images.
+
+```py
+conditioning_image_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution),
+ transforms.ToTensor(),
+ ]
+)
+```
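+
+As a quick sanity check, you could apply these transforms to one of the conditioning images used later in this guide (illustrative; assumes `conditioning_image_transforms` was built as above with `args.resolution` set to 1024):
+
+```py
+# Quick sanity check: run the conditioning transforms on a single image.
+from diffusers.utils import load_image
+
+cond_image = load_image(
+    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png"
+)
+cond_tensor = conditioning_image_transforms(cond_image)
+print(cond_tensor.shape)  # e.g. torch.Size([3, 1024, 1024]) for a square input
+```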
+
+Within the [`main()`](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L770) function, the T2I-Adapter is either loaded from a pretrained adapter or it is randomly initialized:
+
+```py
+if args.adapter_model_name_or_path:
+ logger.info("Loading existing adapter weights.")
+ t2iadapter = T2IAdapter.from_pretrained(args.adapter_model_name_or_path)
+else:
+ logger.info("Initializing t2iadapter weights.")
+ t2iadapter = T2IAdapter(
+ in_channels=3,
+ channels=(320, 640, 1280, 1280),
+ num_res_blocks=2,
+ downscale_factor=16,
+ adapter_type="full_adapter_xl",
+ )
+```
+
+The [optimizer](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L952) is initialized for the T2I-Adapter parameters:
+
+```py
+params_to_optimize = t2iadapter.parameters()
+optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+)
+```
+
+Lastly, in the [training loop](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1086), the adapter conditioning image and the text embeddings are passed to the UNet to predict the noise residual:
+
+```py
+t2iadapter_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
+down_block_additional_residuals = t2iadapter(t2iadapter_image)
+down_block_additional_residuals = [
+ sample.to(dtype=weight_dtype) for sample in down_block_additional_residuals
+]
+
+model_pred = unet(
+ inp_noisy_latents,
+ timesteps,
+ encoder_hidden_states=batch["prompt_ids"],
+ added_cond_kwargs=batch["unet_added_conditions"],
+ down_block_additional_residuals=down_block_additional_residuals,
+).sample
+```
+
+If you want to learn more about how the training loop works, check out the [Understanding pipelines, models and schedulers](../using-diffusers/write_own_pipeline) tutorial which breaks down the basic pattern of the denoising process.
+
+## Launch the script
+
+Now you’re ready to launch the training script! 🚀
+
+For this example training, you'll use the [fusing/fill50k](https://huggingface.co/datasets/fusing/fill50k) dataset. You can also create and use your own dataset if you want (see the [Create a dataset for training](create_dataset) guide).
+
+Set the environment variable `MODEL_DIR` to a model id on the Hub or a path to a local model and `OUTPUT_DIR` to where you want to save the model.
+
+Download the following images to condition your training with:
+
+```bash
+wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png
+wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
+```
+
+
+
+To monitor training progress with Weights & Biases, add the `--report_to=wandb` parameter to the training command. You'll also need to add the `--validation_image`, `--validation_prompt`, and `--validation_steps` parameters to the training command to keep track of results. This can be really useful for debugging the model and viewing intermediate results.
+
+
+
+```bash
+export MODEL_DIR="stabilityai/stable-diffusion-xl-base-1.0"
+export OUTPUT_DIR="path to save model"
+
+accelerate launch train_t2i_adapter_sdxl.py \
+ --pretrained_model_name_or_path=$MODEL_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --dataset_name=fusing/fill50k \
+ --mixed_precision="fp16" \
+ --resolution=1024 \
+ --learning_rate=1e-5 \
+ --max_train_steps=15000 \
+ --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
+ --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
+ --validation_steps=100 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --report_to="wandb" \
+ --seed=42 \
+ --push_to_hub
+```
+
+Once training is complete, you can use your T2I-Adapter for inference:
+
+```py
+from diffusers import StableDiffusionXLAdapterPipeline, T2IAdapter, EulerAncestralDiscreteScheduler
+from diffusers.utils import load_image
+import torch
+
+adapter = T2IAdapter.from_pretrained("path/to/adapter", torch_dtype=torch.float16)
+pipeline = StableDiffusionXLAdapterPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", adapter=adapter, torch_dtype=torch.float16
+)
+
+pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
+pipeline.enable_xformers_memory_efficient_attention()
+pipeline.enable_model_cpu_offload()
+
+control_image = load_image("./conditioning_image_1.png")
+prompt = "pale golden rod circle with old lace background"
+
+generator = torch.manual_seed(0)
+image = pipeline(
+ prompt, image=control_image, generator=generator
+).images[0]
+image.save("./output.png")
+```
+
+## Next steps
+
+Congratulations on training a T2I-Adapter model! 🎉 To learn more:
+
+- Read the [Efficient Controllable Generation for SDXL with T2I-Adapters](https://huggingface.co/blog/t2i-sdxl-adapters) blog post to learn more details about the experimental results from the T2I-Adapter team.
diff --git a/diffusers/docs/source/en/training/text2image.md b/diffusers/docs/source/en/training/text2image.md
new file mode 100644
index 0000000000000000000000000000000000000000..9fa353ae31223d10cf4cdea40c57c60f20122505
--- /dev/null
+++ b/diffusers/docs/source/en/training/text2image.md
@@ -0,0 +1,275 @@
+
+
+# Text-to-image
+
+
+
+The text-to-image script is experimental, and it's easy to overfit and run into issues like catastrophic forgetting. Try exploring different hyperparameters to get the best results on your dataset.
+
+
+
+Text-to-image models like Stable Diffusion are conditioned to generate images given a text prompt.
+
+Training a model can be taxing on your hardware, but if you enable `gradient_checkpointing` and `mixed_precision`, it is possible to train a model on a single 24GB GPU. If you're training with larger batch sizes or want to train faster, it's better to use GPUs with more than 30GB of memory. You can reduce your memory footprint by enabling memory-efficient attention with [xFormers](../optimization/xformers). JAX/Flax training is also supported for efficient training on TPUs and GPUs, but it doesn't support gradient checkpointing, gradient accumulation or xFormers. A GPU with at least 30GB of memory or a TPU v3 is recommended for training with Flax.
+
+This guide will explore the [train_text_to_image.py](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) training script to help you become familiar with it, and how you can adapt it for your own use-case.
+
+Before running the script, make sure you install the library from source:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Then navigate to the example folder containing the training script and install the required dependencies for the script you're using:
+
+
+
+```bash
+cd examples/text_to_image
+pip install -r requirements.txt
+```
+
+
+```bash
+cd examples/text_to_image
+pip install -r requirements_flax.txt
+```
+
+
+
+
+
+🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.
+
+
+
+Initialize an 🤗 Accelerate environment:
+
+```bash
+accelerate config
+```
+
+To setup a default 🤗 Accelerate environment without choosing any configurations:
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell, like a notebook, you can use:
+
+```py
+from accelerate.utils import write_basic_config
+
+write_basic_config()
+```
+
+Lastly, if you want to train a model on your own dataset, take a look at the [Create a dataset for training](create_dataset) guide to learn how to create a dataset that works with the training script.
+
+## Script parameters
+
+
+
+The following sections highlight parts of the training script that are important for understanding how to modify it, but it doesn't cover every aspect of the script in detail. If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) and let us know if you have any questions or concerns.
+
+
+
+The training script provides many parameters to help you customize your training run. All of the parameters and their descriptions are found in the [`parse_args()`](https://github.com/huggingface/diffusers/blob/8959c5b9dec1c94d6ba482c94a58d2215c5fd026/examples/text_to_image/train_text_to_image.py#L193) function. This function provides default values for each parameter, such as the training batch size and learning rate, but you can also set your own values in the training command if you'd like.
+
+For example, to speed up training with mixed precision using the fp16 format, add the `--mixed_precision` parameter to the training command:
+
+```bash
+accelerate launch train_text_to_image.py \
+ --mixed_precision="fp16"
+```
+
+Some basic and important parameters include:
+
+- `--pretrained_model_name_or_path`: the name of the model on the Hub or a local path to the pretrained model
+- `--dataset_name`: the name of the dataset on the Hub or a local path to the dataset to train on
+- `--image_column`: the name of the image column in the dataset to train on
+- `--caption_column`: the name of the text column in the dataset to train on
+- `--output_dir`: where to save the trained model
+- `--push_to_hub`: whether to push the trained model to the Hub
+- `--checkpointing_steps`: frequency of saving a checkpoint as the model trains; this is useful because if training is interrupted for some reason, you can continue training from that checkpoint by adding `--resume_from_checkpoint` to your training command
+
+### Min-SNR weighting
+
+The [Min-SNR](https://huggingface.co/papers/2303.09556) weighting strategy can help with training by rebalancing the loss to achieve faster convergence. The training script supports predicting `epsilon` (noise) or `v_prediction`, and Min-SNR is compatible with both prediction types. This weighting strategy is only supported by PyTorch and is unavailable in the Flax training script.
+
+Add the `--snr_gamma` parameter and set it to the recommended value of 5.0:
+
+```bash
+accelerate launch train_text_to_image.py \
+ --snr_gamma=5.0
+```
+
+You can compare the loss surfaces for different `snr_gamma` values in this [Weights and Biases](https://wandb.ai/sayakpaul/text2image-finetune-minsnr) report. For smaller datasets, the effects of Min-SNR may not be as obvious compared to larger datasets.
+
+## Training script
+
+The dataset preprocessing code and training loop are found in the [`main()`](https://github.com/huggingface/diffusers/blob/8959c5b9dec1c94d6ba482c94a58d2215c5fd026/examples/text_to_image/train_text_to_image.py#L490) function. If you need to adapt the training script, this is where you'll need to make your changes.
+
+The `train_text_to_image` script starts by [loading a scheduler](https://github.com/huggingface/diffusers/blob/8959c5b9dec1c94d6ba482c94a58d2215c5fd026/examples/text_to_image/train_text_to_image.py#L543) and tokenizer. You can choose to use a different scheduler here if you want:
+
+```py
+noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+tokenizer = CLIPTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
+)
+```
+
+Then the script [loads the UNet](https://github.com/huggingface/diffusers/blob/8959c5b9dec1c94d6ba482c94a58d2215c5fd026/examples/text_to_image/train_text_to_image.py#L619) model:
+
+```py
+load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+model.register_to_config(**load_model.config)
+
+model.load_state_dict(load_model.state_dict())
+```
+
+Next, the text and image columns of the dataset need to be preprocessed. The [`tokenize_captions`](https://github.com/huggingface/diffusers/blob/8959c5b9dec1c94d6ba482c94a58d2215c5fd026/examples/text_to_image/train_text_to_image.py#L724) function handles tokenizing the inputs, and the [`train_transforms`](https://github.com/huggingface/diffusers/blob/8959c5b9dec1c94d6ba482c94a58d2215c5fd026/examples/text_to_image/train_text_to_image.py#L742) function specifies the type of transforms to apply to the image. Both of these functions are bundled into `preprocess_train`:
+
+```py
+def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ examples["pixel_values"] = [train_transforms(image) for image in images]
+ examples["input_ids"] = tokenize_captions(examples)
+ return examples
+```
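+
+The preprocessed examples are then wired into a dataloader, roughly like this (a sketch, assuming `dataset` was loaded with 🤗 Datasets earlier in the script; the batch size is an arbitrary example value):
+
+```py
+# Sketch of how the preprocessing typically feeds a dataloader (not the script verbatim).
+import torch
+
+train_dataset = dataset["train"].with_transform(preprocess_train)
+
+def collate_fn(examples):
+    pixel_values = torch.stack([example["pixel_values"] for example in examples]).float()
+    input_ids = torch.stack([example["input_ids"] for example in examples])
+    return {"pixel_values": pixel_values, "input_ids": input_ids}
+
+train_dataloader = torch.utils.data.DataLoader(
+    train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=16
+)
+```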
+
+Lastly, the [training loop](https://github.com/huggingface/diffusers/blob/8959c5b9dec1c94d6ba482c94a58d2215c5fd026/examples/text_to_image/train_text_to_image.py#L878) handles everything else. It encodes images into latent space, adds noise to the latents, computes the text embeddings to condition on, updates the model parameters, and saves and pushes the model to the Hub. If you want to learn more about how the training loop works, check out the [Understanding pipelines, models and schedulers](../using-diffusers/write_own_pipeline) tutorial which breaks down the basic pattern of the denoising process.
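+
+In condensed form, one step of that loop looks roughly like this (a sketch, not the script's exact code; it assumes the components loaded earlier in the script, such as `vae`, `text_encoder`, `unet`, `noise_scheduler`, and `optimizer`, plus a `batch` from the dataloader):
+
+```py
+# Condensed sketch of a single denoising training step.
+import torch
+import torch.nn.functional as F
+
+latents = vae.encode(batch["pixel_values"]).latent_dist.sample() * vae.config.scaling_factor
+noise = torch.randn_like(latents)
+timesteps = torch.randint(
+    0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device
+).long()
+noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+loss = F.mse_loss(model_pred.float(), noise.float())  # epsilon-prediction target
+loss.backward()
+optimizer.step()
+optimizer.zero_grad()
+```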
+
+## Launch the script
+
+Once you've made all your changes or you're okay with the default configuration, you're ready to launch the training script! 🚀
+
+
+
+
+Let's train on the [Pokémon BLIP captions](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions) dataset to generate your own Pokémon. Set the environment variables `MODEL_NAME` and `dataset_name` to the model and the dataset (either from the Hub or a local path). If you're training on more than one GPU, add the `--multi_gpu` parameter to the `accelerate launch` command.
+
+
+
+To train on a local dataset, set the `TRAIN_DIR` and `OUTPUT_DIR` environment variables to the path of the dataset and where to save the model to.
+
+
+
+```bash
+export MODEL_NAME="runwayml/stable-diffusion-v1-5"
+export dataset_name="lambdalabs/pokemon-blip-captions"
+
+accelerate launch --mixed_precision="fp16" train_text_to_image.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --dataset_name=$dataset_name \
+ --use_ema \
+ --resolution=512 --center_crop --random_flip \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --max_train_steps=15000 \
+ --learning_rate=1e-05 \
+ --max_grad_norm=1 \
+  --enable_xformers_memory_efficient_attention \
+ --lr_scheduler="constant" --lr_warmup_steps=0 \
+ --output_dir="sd-pokemon-model" \
+ --push_to_hub
+```
+
+
+
+
+Training with Flax can be faster on TPUs and GPUs thanks to [@duongna21](https://github.com/duongna21). Flax is more efficient on a TPU, but GPU performance is also great.
+
+Set the environment variables `MODEL_NAME` and `dataset_name` to the model and the dataset (either from the Hub or a local path).
+
+
+
+To train on a local dataset, set the `TRAIN_DIR` and `OUTPUT_DIR` environment variables to the path of the dataset and where to save the model to.
+
+
+
+```bash
+export MODEL_NAME="runwayml/stable-diffusion-v1-5"
+export dataset_name="lambdalabs/pokemon-blip-captions"
+
+python train_text_to_image_flax.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --dataset_name=$dataset_name \
+ --resolution=512 --center_crop --random_flip \
+ --train_batch_size=1 \
+ --max_train_steps=15000 \
+ --learning_rate=1e-05 \
+ --max_grad_norm=1 \
+ --output_dir="sd-pokemon-model" \
+ --push_to_hub
+```
+
+
+
+
+Once training is complete, you can use your newly trained model for inference:
+
+
+
+
+```py
+from diffusers import StableDiffusionPipeline
+import torch
+
+pipeline = StableDiffusionPipeline.from_pretrained("path/to/saved_model", torch_dtype=torch.float16, use_safetensors=True).to("cuda")
+
+image = pipeline(prompt="yoda").images[0]
+image.save("yoda-pokemon.png")
+```
+
+
+
+
+```py
+import jax
+import numpy as np
+from flax.jax_utils import replicate
+from flax.training.common_utils import shard
+from diffusers import FlaxStableDiffusionPipeline
+
+pipeline, params = FlaxStableDiffusionPipeline.from_pretrained("path/to/saved_model", dtype=jax.numpy.bfloat16)
+
+prompt = "yoda pokemon"
+prng_seed = jax.random.PRNGKey(0)
+num_inference_steps = 50
+
+num_samples = jax.device_count()
+prompt = num_samples * [prompt]
+prompt_ids = pipeline.prepare_inputs(prompt)
+
+# shard inputs and rng
+params = replicate(params)
+prng_seed = jax.random.split(prng_seed, jax.device_count())
+prompt_ids = shard(prompt_ids)
+
+images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
+images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
+image.save("yoda-pokemon.png")
+```
+
+
+
+
+## Next steps
+
+Congratulations on training your own text-to-image model! To learn more about how to use your new model, the following guides may be helpful:
+
+- Learn how to [load LoRA weights](../using-diffusers/loading_adapters#LoRA) for inference if you trained your model with LoRA.
+- Learn more about how certain parameters like guidance scale or techniques such as prompt weighting can help you control inference in the [Text-to-image](../using-diffusers/conditional_image_generation) task guide.
diff --git a/diffusers/docs/source/en/training/text_inversion.md b/diffusers/docs/source/en/training/text_inversion.md
new file mode 100644
index 0000000000000000000000000000000000000000..025dd457c55ace17922fbe90941d1240c2ab7d43
--- /dev/null
+++ b/diffusers/docs/source/en/training/text_inversion.md
@@ -0,0 +1,298 @@
+
+
+# Textual Inversion
+
+[Textual Inversion](https://hf.co/papers/2208.01618) is a training technique for personalizing image generation models with just a few example images of what you want it to learn. This technique works by learning and updating the text embeddings (the new embeddings are tied to a special word you must use in the prompt) to match the example images you provide.
+
+If you're training on a GPU with limited vRAM, you should try enabling the `gradient_checkpointing` and `mixed_precision` parameters in the training command. You can also reduce your memory footprint by using memory-efficient attention with [xFormers](../optimization/xformers). JAX/Flax training is also supported for efficient training on TPUs and GPUs, but it doesn't support gradient checkpointing or xFormers. With the same configuration and setup as PyTorch, the Flax training script should be at least ~70% faster!
+
+This guide will explore the [textual_inversion.py](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py) script to help you become more familiar with it, and how you can adapt it for your own use-case.
+
+Before running the script, make sure you install the library from source:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Navigate to the example folder with the training script and install the required dependencies for the script you're using:
+
+
+
+
+```bash
+cd examples/textual_inversion
+pip install -r requirements.txt
+```
+
+
+
+
+```bash
+cd examples/textual_inversion
+pip install -r requirements_flax.txt
+```
+
+
+
+
+
+
+🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.
+
+
+
+Initialize an 🤗 Accelerate environment:
+
+```bash
+accelerate config
+```
+
+To setup a default 🤗 Accelerate environment without choosing any configurations:
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell, like a notebook, you can use:
+
+```py
+from accelerate.utils import write_basic_config
+
+write_basic_config()
+```
+
+Lastly, if you want to train a model on your own dataset, take a look at the [Create a dataset for training](create_dataset) guide to learn how to create a dataset that works with the training script.
+
+
+
+The following sections highlight parts of the training script that are important for understanding how to modify it, but it doesn't cover every aspect of the script in detail. If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py) and let us know if you have any questions or concerns.
+
+
+
+## Script parameters
+
+The training script has many parameters to help you tailor the training run to your needs. All of the parameters and their descriptions are listed in the [`parse_args()`](https://github.com/huggingface/diffusers/blob/839c2a5ece0af4e75530cb520d77bc7ed8acf474/examples/textual_inversion/textual_inversion.py#L176) function. Where applicable, Diffusers provides default values for each parameter such as the training batch size and learning rate, but feel free to change these values in the training command if you'd like.
+
+For example, to increase the number of gradient accumulation steps above the default value of 1:
+
+```bash
+accelerate launch textual_inversion.py \
+ --gradient_accumulation_steps=4
+```
+
+Some other basic and important parameters to specify include:
+
+- `--pretrained_model_name_or_path`: the name of the model on the Hub or a local path to the pretrained model
+- `--train_data_dir`: path to a folder containing the training dataset (example images)
+- `--output_dir`: where to save the trained model
+- `--push_to_hub`: whether to push the trained model to the Hub
+- `--checkpointing_steps`: frequency of saving a checkpoint as the model trains; this is useful because if training is interrupted for some reason, you can continue training from that checkpoint by adding `--resume_from_checkpoint` to your training command
+- `--num_vectors`: the number of vectors to learn the embeddings with; increasing this parameter helps the model learn better but it comes with increased training costs
+- `--placeholder_token`: the special word to tie the learned embeddings to (you must use the word in your prompt for inference)
+- `--initializer_token`: a single word that roughly describes the object or style you're trying to train on
+- `--learnable_property`: whether you're training the model to learn a new "style" (for example, Van Gogh's painting style) or "object" (for example, your dog)
+
+## Training script
+
+Unlike some of the other training scripts, textual_inversion.py has a custom dataset class, [`TextualInversionDataset`](https://github.com/huggingface/diffusers/blob/b81c69e489aad3a0ba73798c459a33990dc4379c/examples/textual_inversion/textual_inversion.py#L487) for creating a dataset. You can customize the image size, placeholder token, interpolation method, whether to crop the image, and more. If you need to change how the dataset is created, you can modify `TextualInversionDataset`.
+
+Next, you'll find the dataset preprocessing code and training loop in the [`main()`](https://github.com/huggingface/diffusers/blob/839c2a5ece0af4e75530cb520d77bc7ed8acf474/examples/textual_inversion/textual_inversion.py#L573) function.
+
+The script starts by loading the [tokenizer](https://github.com/huggingface/diffusers/blob/b81c69e489aad3a0ba73798c459a33990dc4379c/examples/textual_inversion/textual_inversion.py#L616), [scheduler and model](https://github.com/huggingface/diffusers/blob/b81c69e489aad3a0ba73798c459a33990dc4379c/examples/textual_inversion/textual_inversion.py#L622):
+
+```py
+# Load tokenizer
+if args.tokenizer_name:
+ tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+elif args.pretrained_model_name_or_path:
+ tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
+
+# Load scheduler and models
+noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+)
+vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
+unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+)
+```
+
+The special [placeholder token](https://github.com/huggingface/diffusers/blob/b81c69e489aad3a0ba73798c459a33990dc4379c/examples/textual_inversion/textual_inversion.py#L632) is added next to the tokenizer, and the embedding is readjusted to account for the new token.
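+
+Conceptually, that step looks something like this (a sketch, not the script's exact code; it also initializes the new embedding from the initializer token so training starts from a sensible point):
+
+```py
+# Sketch of adding the placeholder token and resizing the embedding matrix
+# (the script does the equivalent, plus handling --num_vectors > 1).
+from transformers import CLIPTextModel, CLIPTokenizer
+
+tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
+text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
+
+tokenizer.add_tokens("<cat-toy>")  # the placeholder token used later in this guide
+text_encoder.resize_token_embeddings(len(tokenizer))
+
+# Copy the initializer token's ("toy") embedding into the new row.
+token_embeds = text_encoder.get_input_embeddings().weight.data
+initializer_id = tokenizer.encode("toy", add_special_tokens=False)[0]
+placeholder_id = tokenizer.convert_tokens_to_ids("<cat-toy>")
+token_embeds[placeholder_id] = token_embeds[initializer_id].clone()
+```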
+
+Then, the script [creates a dataset](https://github.com/huggingface/diffusers/blob/b81c69e489aad3a0ba73798c459a33990dc4379c/examples/textual_inversion/textual_inversion.py#L716) from the `TextualInversionDataset`:
+
+```py
+train_dataset = TextualInversionDataset(
+ data_root=args.train_data_dir,
+ tokenizer=tokenizer,
+ size=args.resolution,
+ placeholder_token=(" ".join(tokenizer.convert_ids_to_tokens(placeholder_token_ids))),
+ repeats=args.repeats,
+ learnable_property=args.learnable_property,
+ center_crop=args.center_crop,
+ set="train",
+)
+train_dataloader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
+)
+```
+
+Finally, the [training loop](https://github.com/huggingface/diffusers/blob/b81c69e489aad3a0ba73798c459a33990dc4379c/examples/textual_inversion/textual_inversion.py#L784) handles everything else from predicting the noisy residual to updating the embedding weights of the special placeholder token.
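+
+The distinguishing detail is that only the placeholder token's embedding row actually changes; a rough sketch of the idea (the script uses an equivalent index mask; `text_encoder`, `tokenizer`, and `placeholder_token_ids` are assumed to exist as in the script):
+
+```py
+# Rough sketch: keep every embedding row except the placeholder token's frozen
+# by restoring the original values after each optimizer step.
+import torch
+
+orig_embeds = text_encoder.get_input_embeddings().weight.data.clone()
+
+index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
+index_no_updates[placeholder_token_ids] = False
+
+with torch.no_grad():
+    text_encoder.get_input_embeddings().weight[index_no_updates] = orig_embeds[index_no_updates]
+```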
+
+If you want to learn more about how the training loop works, check out the [Understanding pipelines, models and schedulers](../using-diffusers/write_own_pipeline) tutorial which breaks down the basic pattern of the denoising process.
+
+## Launch the script
+
+Once you've made all your changes or you're okay with the default configuration, you're ready to launch the training script! 🚀
+
+For this guide, you'll download some images of a [cat toy](https://huggingface.co/datasets/diffusers/cat_toy_example) and store them in a directory. But remember, you can create and use your own dataset if you want (see the [Create a dataset for training](create_dataset) guide).
+
+```py
+from huggingface_hub import snapshot_download
+
+local_dir = "./cat"
+snapshot_download(
+ "diffusers/cat_toy_example", local_dir=local_dir, repo_type="dataset", ignore_patterns=".gitattributes"
+)
+```
+
+Set the environment variable `MODEL_NAME` to a model id on the Hub or a path to a local model, and `DATA_DIR` to the path where you just downloaded the cat images to. The script creates and saves the following files to your repository:
+
+- `learned_embeds.bin`: the learned embedding vectors corresponding to your example images
+- `token_identifier.txt`: the special placeholder token
+- `type_of_concept.txt`: the type of concept you're training on (either "object" or "style")
+
+
+
+A full training run takes ~1 hour on a single V100 GPU.
+
+
+
+One more thing before you launch the script. If you're interested in following along with the training process, you can periodically save generated images as training progresses. Add the following parameters to the training command:
+
+```bash
+--validation_prompt="A <cat-toy> train"
+--num_validation_images=4
+--validation_steps=100
+```
+
+
+
+
+```bash
+export MODEL_NAME="runwayml/stable-diffusion-v1-5"
+export DATA_DIR="./cat"
+
+accelerate launch textual_inversion.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_data_dir=$DATA_DIR \
+ --learnable_property="object" \
+  --placeholder_token="<cat-toy>" \
+ --initializer_token="toy" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --max_train_steps=3000 \
+ --learning_rate=5.0e-04 \
+ --scale_lr \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --output_dir="textual_inversion_cat" \
+ --push_to_hub
+```
+
+
+
+
+```bash
+export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
+export DATA_DIR="./cat"
+
+python textual_inversion_flax.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_data_dir=$DATA_DIR \
+ --learnable_property="object" \
+  --placeholder_token="<cat-toy>" \
+ --initializer_token="toy" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --max_train_steps=3000 \
+ --learning_rate=5.0e-04 \
+ --scale_lr \
+ --output_dir="textual_inversion_cat" \
+ --push_to_hub
+```
+
+
+
+
+After training is complete, you can use your newly trained model for inference like:
+
+
+
+
+```py
+from diffusers import StableDiffusionPipeline
+import torch
+
+pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
+pipeline.load_textual_inversion("sd-concepts-library/cat-toy")
+image = pipeline("A <cat-toy> train", num_inference_steps=50).images[0]
+image.save("cat-train.png")
+```
+
+
+
+
+Flax doesn't support the [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] method, but the textual_inversion_flax.py script [saves](https://github.com/huggingface/diffusers/blob/c0f058265161178f2a88849e92b37ffdc81f1dcc/examples/textual_inversion/textual_inversion_flax.py#L636C2-L636C2) the learned embeddings as a part of the model after training. This means you can use the model for inference like any other Flax model:
+
+```py
+import jax
+import numpy as np
+from flax.jax_utils import replicate
+from flax.training.common_utils import shard
+from diffusers import FlaxStableDiffusionPipeline
+
+model_path = "path-to-your-trained-model"
+pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(model_path, dtype=jax.numpy.bfloat16)
+
+prompt = "A <cat-toy> train"
+prng_seed = jax.random.PRNGKey(0)
+num_inference_steps = 50
+
+num_samples = jax.device_count()
+prompt = num_samples * [prompt]
+prompt_ids = pipeline.prepare_inputs(prompt)
+
+# shard inputs and rng
+params = replicate(params)
+prng_seed = jax.random.split(prng_seed, jax.device_count())
+prompt_ids = shard(prompt_ids)
+
+images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
+images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
+image.save("cat-train.png")
+```
+
+
+
+
+## Next steps
+
+Congratulations on training your own Textual Inversion model! 🎉 To learn more about how to use your new model, the following guides may be helpful:
+
+- Learn how to [load Textual Inversion embeddings](../using-diffusers/loading_adapters) and also use them as negative embeddings.
+- Learn how to use [Textual Inversion](textual_inversion_inference) for inference with Stable Diffusion 1/2 and Stable Diffusion XL.
\ No newline at end of file
diff --git a/diffusers/docs/source/en/training/unconditional_training.md b/diffusers/docs/source/en/training/unconditional_training.md
new file mode 100644
index 0000000000000000000000000000000000000000..97b644883caed113c82572fae75fc3707a6059e0
--- /dev/null
+++ b/diffusers/docs/source/en/training/unconditional_training.md
@@ -0,0 +1,207 @@
+
+
+# Unconditional image generation
+
+Unconditional image generation models are not conditioned on text or images during training. They only generate images that resemble their training data distribution.
+
+This guide will explore the [train_unconditional.py](https://github.com/huggingface/diffusers/blob/main/examples/unconditional_image_generation/train_unconditional.py) training script to help you become familiar with it, and how you can adapt it for your own use-case.
+
+Before running the script, make sure you install the library from source:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Then navigate to the example folder containing the training script and install the required dependencies:
+
+```bash
+cd examples/unconditional_image_generation
+pip install -r requirements.txt
+```
+
+
+
+🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.
+
+
+
+Initialize an 🤗 Accelerate environment:
+
+```bash
+accelerate config
+```
+
+To setup a default 🤗 Accelerate environment without choosing any configurations:
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell like a notebook, you can use:
+
+```py
+from accelerate.utils import write_basic_config
+
+write_basic_config()
+```
+
+Lastly, if you want to train a model on your own dataset, take a look at the [Create a dataset for training](create_dataset) guide to learn how to create a dataset that works with the training script.
+
+## Script parameters
+
+
+
+The following sections highlight parts of the training script that are important for understanding how to modify it, but it doesn't cover every aspect of the script in detail. If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/unconditional_image_generation/train_unconditional.py) and let us know if you have any questions or concerns.
+
+
+
+The training script provides many parameters to help you customize your training run. All of the parameters and their descriptions are found in the [`parse_args()`](https://github.com/huggingface/diffusers/blob/096f84b05f9514fae9f185cbec0a4d38fbad9919/examples/unconditional_image_generation/train_unconditional.py#L55) function. It provides default values for each parameter, such as the training batch size and learning rate, but you can also set your own values in the training command if you'd like.
+
+For example, to speed up training with mixed precision using the bf16 format, add the `--mixed_precision` parameter to the training command:
+
+```bash
+accelerate launch train_unconditional.py \
+ --mixed_precision="bf16"
+```
+
+Some basic and important parameters to specify include:
+
+- `--dataset_name`: the name of the dataset on the Hub or a local path to the dataset to train on
+- `--output_dir`: where to save the trained model
+- `--push_to_hub`: whether to push the trained model to the Hub
+- `--checkpointing_steps`: frequency of saving a checkpoint as the model trains; this is useful because if training is interrupted, you can continue training from that checkpoint by adding `--resume_from_checkpoint` to your training command
+
+Bring your dataset, and let the training script handle everything else!
+
+## Training script
+
+The code for preprocessing the dataset and the training loop is found in the [`main()`](https://github.com/huggingface/diffusers/blob/096f84b05f9514fae9f185cbec0a4d38fbad9919/examples/unconditional_image_generation/train_unconditional.py#L275) function. If you need to adapt the training script, this is where you'll need to make your changes.
+
+The `train_unconditional` script [initializes a `UNet2DModel`](https://github.com/huggingface/diffusers/blob/096f84b05f9514fae9f185cbec0a4d38fbad9919/examples/unconditional_image_generation/train_unconditional.py#L356) if you don't provide a model configuration. You can configure the UNet here if you'd like:
+
+```py
+model = UNet2DModel(
+ sample_size=args.resolution,
+ in_channels=3,
+ out_channels=3,
+ layers_per_block=2,
+ block_out_channels=(128, 128, 256, 256, 512, 512),
+ down_block_types=(
+ "DownBlock2D",
+ "DownBlock2D",
+ "DownBlock2D",
+ "DownBlock2D",
+ "AttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ up_block_types=(
+ "UpBlock2D",
+ "AttnUpBlock2D",
+ "UpBlock2D",
+ "UpBlock2D",
+ "UpBlock2D",
+ "UpBlock2D",
+ ),
+)
+```
+
+Next, the script initializes a [scheduler](https://github.com/huggingface/diffusers/blob/096f84b05f9514fae9f185cbec0a4d38fbad9919/examples/unconditional_image_generation/train_unconditional.py#L418) and [optimizer](https://github.com/huggingface/diffusers/blob/096f84b05f9514fae9f185cbec0a4d38fbad9919/examples/unconditional_image_generation/train_unconditional.py#L429):
+
+```py
+# Initialize the scheduler
+accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
+if accepts_prediction_type:
+ noise_scheduler = DDPMScheduler(
+ num_train_timesteps=args.ddpm_num_steps,
+ beta_schedule=args.ddpm_beta_schedule,
+ prediction_type=args.prediction_type,
+ )
+else:
+ noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)
+
+# Initialize the optimizer
+optimizer = torch.optim.AdamW(
+ model.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+)
+```
+
+Then it [loads a dataset](https://github.com/huggingface/diffusers/blob/096f84b05f9514fae9f185cbec0a4d38fbad9919/examples/unconditional_image_generation/train_unconditional.py#L451) and you can specify how to [preprocess](https://github.com/huggingface/diffusers/blob/096f84b05f9514fae9f185cbec0a4d38fbad9919/examples/unconditional_image_generation/train_unconditional.py#L455) it:
+
+```py
+dataset = load_dataset("imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train")
+
+augmentations = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
+ transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+)
+```
+
+Finally, the [training loop](https://github.com/huggingface/diffusers/blob/096f84b05f9514fae9f185cbec0a4d38fbad9919/examples/unconditional_image_generation/train_unconditional.py#L540) handles everything else such as adding noise to the images, predicting the noise residual, calculating the loss, saving checkpoints at specified steps, and saving and pushing the model to the Hub. If you want to learn more about how the training loop works, check out the [Understanding pipelines, models and schedulers](../using-diffusers/write_own_pipeline) tutorial which breaks down the basic pattern of the denoising process.
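+
+As a rough orientation, the core of each training step boils down to something like the following sketch (variable names are illustrative; `model` and `noise_scheduler` are the objects created above, and `clean_images` is a batch from the dataloader):
+
+```py
+import torch
+import torch.nn.functional as F
+
+# sample noise and a random timestep for every image in the batch
+noise = torch.randn_like(clean_images)
+timesteps = torch.randint(
+    0, noise_scheduler.config.num_train_timesteps, (clean_images.shape[0],), device=clean_images.device
+).long()
+
+# forward diffusion: add noise to the clean images at the sampled timesteps
+noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
+
+# predict the noise residual and regress it against the true noise
+model_output = model(noisy_images, timesteps).sample
+loss = F.mse_loss(model_output, noise)
+```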
+
+## Launch the script
+
+Once you've made all your changes or you're okay with the default configuration, you're ready to launch the training script! 🚀
+
+A full training run takes 2 hours on 4xV100 GPUs.
+
+```bash
+accelerate launch train_unconditional.py \
+ --dataset_name="huggan/flowers-102-categories" \
+ --output_dir="ddpm-ema-flowers-64" \
+ --mixed_precision="fp16" \
+ --push_to_hub
+```
+
+If you're training with more than one GPU, add the `--multi_gpu` parameter to the training command:
+
+```bash
+accelerate launch --mixed_precision="fp16" --multi_gpu train_unconditional.py \
+ --dataset_name="huggan/flowers-102-categories" \
+ --output_dir="ddpm-ema-flowers-64" \
+ --mixed_precision="fp16" \
+ --push_to_hub
+```
+
+The training script creates and saves a checkpoint file in your repository. Now you can load and use your trained model for inference:
+
+```py
+from diffusers import DiffusionPipeline
+import torch
+
+pipeline = DiffusionPipeline.from_pretrained("anton-l/ddpm-butterflies-128").to("cuda")
+image = pipeline().images[0]
+```
diff --git a/diffusers/docs/source/en/training/wuerstchen.md b/diffusers/docs/source/en/training/wuerstchen.md
new file mode 100644
index 0000000000000000000000000000000000000000..9f04c8556a75f42ed3d7136b802fd22c5dbc6d2e
--- /dev/null
+++ b/diffusers/docs/source/en/training/wuerstchen.md
@@ -0,0 +1,189 @@
+
+
+# Wuerstchen
+
+The [Wuerstchen](https://hf.co/papers/2306.00637) model drastically reduces computational costs by compressing the latent space by 42x, which speeds up inference without compromising image quality. During training, Wuerstchen uses two models (a VQGAN and an autoencoder) to compress the latents, and then a third model (a text-conditioned latent diffusion model) is conditioned on this highly compressed space to generate an image.
+
+To fit the prior model into GPU memory and to speed up training, try enabling gradient accumulation, gradient checkpointing, and mixed precision with the `--gradient_accumulation_steps`, `--gradient_checkpointing`, and `--mixed_precision` parameters.
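+
+For example (these flags also show up in the full launch command later in this guide):
+
+```bash
+accelerate launch train_text_to_image_prior.py \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --mixed_precision="fp16"
+```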
+
+This guide explores the [train_text_to_image_prior.py](https://github.com/huggingface/diffusers/blob/main/examples/wuerstchen/text_to_image/train_text_to_image_prior.py) script to help you become more familiar with it, and how you can adapt it for your own use-case.
+
+Before running the script, make sure you install the library from source:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Then navigate to the example folder containing the training script and install the required dependencies for the script you're using:
+
+```bash
+cd examples/wuerstchen/text_to_image
+pip install -r requirements.txt
+```
+
+
+
+🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.
+
+
+
+Initialize an 🤗 Accelerate environment:
+
+```bash
+accelerate config
+```
+
+To setup a default 🤗 Accelerate environment without choosing any configurations:
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell, like a notebook, you can use:
+
+```py
+from accelerate.utils import write_basic_config
+
+write_basic_config()
+```
+
+Lastly, if you want to train a model on your own dataset, take a look at the [Create a dataset for training](create_dataset) guide to learn how to create a dataset that works with the training script.
+
+
+
+The following sections highlight the parts of the training script that are important for understanding how to modify it, but they don't cover every aspect of the [script](https://github.com/huggingface/diffusers/blob/main/examples/wuerstchen/text_to_image/train_text_to_image_prior.py) in detail. If you're interested in learning more, feel free to read through the script and let us know if you have any questions or concerns.
+
+
+
+## Script parameters
+
+The training script provides many parameters to help you customize your training run. All of the parameters and their descriptions are found in the [`parse_args()`](https://github.com/huggingface/diffusers/blob/6e68c71503682c8693cb5b06a4da4911dfd655ee/examples/wuerstchen/text_to_image/train_text_to_image_prior.py#L192) function. It provides default values for each parameter, such as the training batch size and learning rate, but you can also set your own values in the training command if you'd like.
+
+For example, to speed up training with mixed precision in the fp16 format, add the `--mixed_precision` parameter to the training command:
+
+```bash
+accelerate launch train_text_to_image_prior.py \
+ --mixed_precision="fp16"
+```
+
+Most of the parameters are identical to the parameters in the [Text-to-image](text2image#script-parameters) training guide, so let's dive right into the Wuerstchen training script!
+
+## Training script
+
+The training script is also similar to the [Text-to-image](text2image#training-script) training guide, but it's been modified to support Wuerstchen. This guide focuses on the code that is unique to the Wuerstchen training script.
+
+The [`main()`](https://github.com/huggingface/diffusers/blob/6e68c71503682c8693cb5b06a4da4911dfd655ee/examples/wuerstchen/text_to_image/train_text_to_image_prior.py#L441) function starts by initializing the image encoder - an [EfficientNet](https://github.com/huggingface/diffusers/blob/main/examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py) - in addition to the usual scheduler and tokenizer.
+
+```py
+with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
+ pretrained_checkpoint_file = hf_hub_download("dome272/wuerstchen", filename="model_v2_stage_b.pt")
+ state_dict = torch.load(pretrained_checkpoint_file, map_location="cpu")
+ image_encoder = EfficientNetEncoder()
+ image_encoder.load_state_dict(state_dict["effnet_state_dict"])
+ image_encoder.eval()
+```
+
+You'll also load the [`WuerstchenPrior`] model for optimization.
+
+```py
+prior = WuerstchenPrior.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")
+
+optimizer = optimizer_cls(
+ prior.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+)
+```
+
+Next, you'll apply some [transforms](https://github.com/huggingface/diffusers/blob/65ef7a0c5c594b4f84092e328fbdd73183613b30/examples/wuerstchen/text_to_image/train_text_to_image_prior.py#L656) to the images and [tokenize](https://github.com/huggingface/diffusers/blob/65ef7a0c5c594b4f84092e328fbdd73183613b30/examples/wuerstchen/text_to_image/train_text_to_image_prior.py#L637) the captions:
+
+```py
+def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ examples["effnet_pixel_values"] = [effnet_transforms(image) for image in images]
+ examples["text_input_ids"], examples["text_mask"] = tokenize_captions(examples)
+ return examples
+```
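+
+For context, `tokenize_captions` simply encodes the caption column with the tokenizer loaded earlier; a minimal sketch of it (the `caption_column` and `tokenizer` names come from the surrounding script, and the exact helper may differ slightly) looks like:
+
+```py
+def tokenize_captions(examples):
+    # encode the text prompts into token ids and an attention mask
+    captions = list(examples[caption_column])
+    inputs = tokenizer(
+        captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+    )
+    return inputs.input_ids, inputs.attention_mask
+```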
+
+Finally, the [training loop](https://github.com/huggingface/diffusers/blob/65ef7a0c5c594b4f84092e328fbdd73183613b30/examples/wuerstchen/text_to_image/train_text_to_image_prior.py#L656) handles compressing the images to latent space with the `EfficientNetEncoder`, adding noise to the latents, and predicting the noise residual with the [`WuerstchenPrior`] model.
+
+```py
+pred_noise = prior(noisy_latents, timesteps, prompt_embeds)
+```
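+
+Expanded slightly, one training step looks roughly like this simplified sketch (illustrative only, not the script's exact code; `batch`, `text_encoder`, and `noise_scheduler` refer to objects set up elsewhere in the script):
+
+```py
+import torch
+import torch.nn.functional as F
+
+with torch.no_grad():
+    # compress the images into the highly compressed EfficientNet latent space
+    latents = image_encoder(batch["effnet_pixel_values"])
+    # encode the captions into prompt embeddings
+    prompt_embeds = text_encoder(batch["text_input_ids"], attention_mask=batch["text_mask"])[0]
+
+# randomly sample a timestep for each example (the script's exact sampling may differ)
+timesteps = torch.rand((latents.shape[0],), device=latents.device)
+
+# add noise to the latents at the sampled timesteps
+noise = torch.randn_like(latents)
+noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+# predict the noise residual with the prior and regress it against the true noise
+pred_noise = prior(noisy_latents, timesteps, prompt_embeds)
+loss = F.mse_loss(pred_noise, noise)
+```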
+
+If you want to learn more about how the training loop works, check out the [Understanding pipelines, models and schedulers](../using-diffusers/write_own_pipeline) tutorial which breaks down the basic pattern of the denoising process.
+
+## Launch the script
+
+Once you’ve made all your changes or you’re okay with the default configuration, you’re ready to launch the training script! 🚀
+
+Set the `DATASET_NAME` environment variable to the dataset name from the Hub. This guide uses the [Pokémon BLIP captions](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions) dataset, but you can create and train on your own datasets as well (see the [Create a dataset for training](create_dataset) guide).
+
+
+
+To monitor training progress with Weights & Biases, add the `--report_to=wandb` parameter to the training command. You'll also need to add the `--validation_prompts` parameter to the training command to keep track of results. This can be really useful for debugging the model and viewing intermediate results.
+
+
+
+```bash
+export DATASET_NAME="lambdalabs/pokemon-blip-captions"
+
+accelerate launch train_text_to_image_prior.py \
+ --mixed_precision="fp16" \
+ --dataset_name=$DATASET_NAME \
+ --resolution=768 \
+ --train_batch_size=4 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --dataloader_num_workers=4 \
+ --max_train_steps=15000 \
+ --learning_rate=1e-05 \
+ --max_grad_norm=1 \
+ --checkpoints_total_limit=3 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --validation_prompts="A robot pokemon, 4k photo" \
+ --report_to="wandb" \
+ --push_to_hub \
+ --output_dir="wuerstchen-prior-pokemon-model"
+```
+
+Once training is complete, you can use your newly trained model for inference!
+
+```py
+import torch
+from diffusers import AutoPipelineForText2Image
+from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
+
+pipeline = AutoPipelineForText2Image.from_pretrained("path/to/saved/model", torch_dtype=torch.float16).to("cuda")
+
+caption = "A cute bird pokemon holding a shield"
+images = pipeline(
+ caption,
+ width=1024,
+ height=1536,
+ prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
+ prior_guidance_scale=4.0,
+ num_images_per_prompt=2,
+).images
+```
+
+## Next steps
+
+Congratulations on training a Wuerstchen model! To learn more about how to use your new model, the following may be helpful:
+
+- Take a look at the [Wuerstchen](../api/pipelines/wuerstchen#text-to-image-generation) API documentation to learn more about how to use the pipeline for text-to-image generation and its limitations.
diff --git a/diffusers/docs/source/en/tutorials/autopipeline.md b/diffusers/docs/source/en/tutorials/autopipeline.md
new file mode 100644
index 0000000000000000000000000000000000000000..4f17760e8bc0abeeac2512caf1047585774dec67
--- /dev/null
+++ b/diffusers/docs/source/en/tutorials/autopipeline.md
@@ -0,0 +1,170 @@
+
+
+# AutoPipeline
+
+🤗 Diffusers is able to complete many different tasks, and you can often reuse the same pretrained weights for multiple tasks such as text-to-image, image-to-image, and inpainting. If you're new to the library and diffusion models though, it may be difficult to know which pipeline to use for a task. For example, if you're using the [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) checkpoint for text-to-image, you might not know that you could also use it for image-to-image and inpainting by loading the checkpoint with the [`StableDiffusionImg2ImgPipeline`] and [`StableDiffusionInpaintPipeline`] classes respectively.
+
+The `AutoPipeline` class is designed to simplify the variety of pipelines in 🤗 Diffusers. It is a generic, *task-first* pipeline that lets you focus on the task. The `AutoPipeline` automatically detects the correct pipeline class to use, which makes it easier to load a checkpoint for a task without knowing the specific pipeline class name.
+
+
+
+Take a look at the [AutoPipeline](../api/pipelines/auto_pipeline) reference to see which tasks are supported. Currently, it supports text-to-image, image-to-image, and inpainting.
+
+
+
+This tutorial shows you how to use an `AutoPipeline` to automatically infer the pipeline class to load for a specific task, given the pretrained weights.
+
+## Choose an AutoPipeline for your task
+
+Start by picking a checkpoint. For example, if you're interested in text-to-image with the [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) checkpoint, use [`AutoPipelineForText2Image`]:
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+
+pipeline = AutoPipelineForText2Image.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
+).to("cuda")
+prompt = "peasant and dragon combat, wood cutting style, viking era, bevel with rune"
+
+image = pipeline(prompt, num_inference_steps=25).images[0]
+image
+```
+
+
+
+
+
+Under the hood, [`AutoPipelineForText2Image`]:
+
+1. automatically detects a `"stable-diffusion"` class from the [`model_index.json`](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/model_index.json) file
+2. loads the corresponding text-to-image [`StableDiffusionPipeline`] based on the `"stable-diffusion"` class name
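+
+To get a feel for that detection step, here's a minimal sketch that reads the pipeline class name straight out of the checkpoint's `model_index.json` (illustrative only; [`AutoPipelineForText2Image`] handles this for you):
+
+```py
+import json
+
+from huggingface_hub import hf_hub_download
+
+config_file = hf_hub_download("runwayml/stable-diffusion-v1-5", "model_index.json")
+with open(config_file) as f:
+    config = json.load(f)
+
+print(config["_class_name"])  # "StableDiffusionPipeline"
+```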
+
+Likewise, for image-to-image, [`AutoPipelineForImage2Image`] detects a `"stable-diffusion"` checkpoint from the `model_index.json` file and it'll load the corresponding [`StableDiffusionImg2ImgPipeline`] behind the scenes. You can also pass any additional arguments specific to the pipeline class such as `strength`, which determines the amount of noise or variation added to an input image:
+
+```py
+from diffusers import AutoPipelineForImage2Image
+import torch
+import requests
+from PIL import Image
+from io import BytesIO
+
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+ torch_dtype=torch.float16,
+ use_safetensors=True,
+).to("cuda")
+prompt = "a portrait of a dog wearing a pearl earring"
+
+url = "https://upload.wikimedia.org/wikipedia/commons/thumb/0/0f/1665_Girl_with_a_Pearl_Earring.jpg/800px-1665_Girl_with_a_Pearl_Earring.jpg"
+
+response = requests.get(url)
+image = Image.open(BytesIO(response.content)).convert("RGB")
+image.thumbnail((768, 768))
+
+image = pipeline(prompt, image, num_inference_steps=200, strength=0.75, guidance_scale=10.5).images[0]
+image
+```
+
+
+
+
+
+And if you want to do inpainting, then [`AutoPipelineForInpainting`] loads the underlying [`StableDiffusionInpaintPipeline`] class in the same way:
+
+```py
+from diffusers import AutoPipelineForInpainting
+from diffusers.utils import load_image
+import torch
+
+pipeline = AutoPipelineForInpainting.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, use_safetensors=True
+).to("cuda")
+
+img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+init_image = load_image(img_url).convert("RGB")
+mask_image = load_image(mask_url).convert("RGB")
+
+prompt = "A majestic tiger sitting on a bench"
+image = pipeline(prompt, image=init_image, mask_image=mask_image, num_inference_steps=50, strength=0.80).images[0]
+image
+```
+
+
+
+
+
+If you try to load an unsupported checkpoint, it'll throw an error:
+
+```py
+from diffusers import AutoPipelineForImage2Image
+import torch
+
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "openai/shap-e-img2img", torch_dtype=torch.float16, use_safetensors=True
+)
+"ValueError: AutoPipeline can't find a pipeline linked to ShapEImg2ImgPipeline for None"
+```
+
+## Use multiple pipelines
+
+For some workflows or if you're loading many pipelines, it is more memory-efficient to reuse the same components from a checkpoint instead of reloading them which would unnecessarily consume additional memory. For example, if you're using a checkpoint for text-to-image and you want to use it again for image-to-image, use the [`~AutoPipelineForImage2Image.from_pipe`] method. This method creates a new pipeline from the components of a previously loaded pipeline at no additional memory cost.
+
+The [`~AutoPipelineForImage2Image.from_pipe`] method detects the original pipeline class and maps it to the new pipeline class corresponding to the task you want to do. For example, if you load a `"stable-diffusion"` class pipeline for text-to-image:
+
+```py
+from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image
+import torch
+
+pipeline_text2img = AutoPipelineForText2Image.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
+)
+print(type(pipeline_text2img))
+""
+```
+
+Then [`~AutoPipelineForImage2Image.from_pipe`] maps the original `"stable-diffusion"` pipeline class to [`StableDiffusionImg2ImgPipeline`]:
+
+```py
+pipeline_img2img = AutoPipelineForImage2Image.from_pipe(pipeline_text2img)
+print(type(pipeline_img2img))
+""
+```
+
+If you passed an optional argument - like disabling the safety checker - to the original pipeline, this argument is also passed on to the new pipeline:
+
+```py
+from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image
+import torch
+
+pipeline_text2img = AutoPipelineForText2Image.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+ torch_dtype=torch.float16,
+ use_safetensors=True,
+ requires_safety_checker=False,
+).to("cuda")
+
+pipeline_img2img = AutoPipelineForImage2Image.from_pipe(pipeline_text2img)
+print(pipeline_img2img.config.requires_safety_checker)
+"False"
+```
+
+You can overwrite any of the arguments and even configuration from the original pipeline if you want to change the behavior of the new pipeline. For example, to turn the safety checker back on and add the `strength` argument:
+
+```py
+pipeline_img2img = AutoPipelineForImage2Image.from_pipe(pipeline_text2img, requires_safety_checker=True, strength=0.3)
+print(pipeline_img2img.config.requires_safety_checker)
+"True"
+```
diff --git a/diffusers/docs/source/en/tutorials/basic_training.md b/diffusers/docs/source/en/tutorials/basic_training.md
new file mode 100644
index 0000000000000000000000000000000000000000..ba7e0f01bf8df728a0844c7b87cd8ae6f5755a90
--- /dev/null
+++ b/diffusers/docs/source/en/tutorials/basic_training.md
@@ -0,0 +1,403 @@
+
+
+[[open-in-colab]]
+
+# Train a diffusion model
+
+Unconditional image generation is a popular application of diffusion models that generates images that look like those in the dataset used for training. Typically, the best results are obtained from finetuning a pretrained model on a specific dataset. You can find many of these checkpoints on the [Hub](https://huggingface.co/search/full-text?q=unconditional-image-generation&type=model), but if you can't find one you like, you can always train your own!
+
+This tutorial will teach you how to train a [`UNet2DModel`] from scratch on a subset of the [Smithsonian Butterflies](https://huggingface.co/datasets/huggan/smithsonian_butterflies_subset) dataset to generate your own 🦋 butterflies 🦋.
+
+
+
+💡 This training tutorial is based on the [Training with 🧨 Diffusers](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) notebook. For additional details and context about diffusion models like how they work, check out the notebook!
+
+
+
+Before you begin, make sure you have 🤗 Datasets installed to load and preprocess image datasets, and 🤗 Accelerate, to simplify training on any number of GPUs. The following command will also install [TensorBoard](https://www.tensorflow.org/tensorboard) to visualize training metrics (you can also use [Weights & Biases](https://docs.wandb.ai/) to track your training).
+
+```py
+# uncomment to install the necessary libraries in Colab
+#!pip install diffusers[training]
+```
+
+We encourage you to share your model with the community, and in order to do that, you'll need to log in to your Hugging Face account (create one [here](https://hf.co/join) if you don't already have one!). You can log in from a notebook and enter your token when prompted. Make sure your token has the write role.
+
+```py
+>>> from huggingface_hub import notebook_login
+
+>>> notebook_login()
+```
+
+Or log in from the terminal:
+
+```bash
+huggingface-cli login
+```
+
+Since the model checkpoints are quite large, install [Git-LFS](https://git-lfs.com/) to version these large files:
+
+```bash
+!sudo apt -qq install git-lfs
+!git config --global credential.helper store
+```
+
+## Training configuration
+
+For convenience, create a `TrainingConfig` class containing the training hyperparameters (feel free to adjust them):
+
+```py
+>>> from dataclasses import dataclass
+
+>>> @dataclass
+... class TrainingConfig:
+... image_size = 128 # the generated image resolution
+... train_batch_size = 16
+... eval_batch_size = 16 # how many images to sample during evaluation
+... num_epochs = 50
+... gradient_accumulation_steps = 1
+... learning_rate = 1e-4
+... lr_warmup_steps = 500
+... save_image_epochs = 10
+... save_model_epochs = 30
+... mixed_precision = "fp16" # `no` for float32, `fp16` for automatic mixed precision
+... output_dir = "ddpm-butterflies-128" # the model name locally and on the HF Hub
+
+... push_to_hub = True # whether to upload the saved model to the HF Hub
+... hub_model_id = "<your-username>/<my-awesome-model>" # the name of the repository to create on the HF Hub
+... hub_private_repo = False
+... overwrite_output_dir = True # overwrite the old model when re-running the notebook
+... seed = 0
+
+
+>>> config = TrainingConfig()
+```
+
+## Load the dataset
+
+You can easily load the [Smithsonian Butterflies](https://huggingface.co/datasets/huggan/smithsonian_butterflies_subset) dataset with the 🤗 Datasets library:
+
+```py
+>>> from datasets import load_dataset
+
+>>> config.dataset_name = "huggan/smithsonian_butterflies_subset"
+>>> dataset = load_dataset(config.dataset_name, split="train")
+```
+
+
+
+💡 You can find additional datasets from the [HugGan Community Event](https://huggingface.co/huggan) or you can use your own dataset by creating a local [`ImageFolder`](https://huggingface.co/docs/datasets/image_dataset#imagefolder). Set `config.dataset_name` to the repository id of the dataset if it is from the HugGan Community Event, or `imagefolder` if you're using your own images.
+
+
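+If you go the local-images route, the call looks roughly like this (the folder path below is just a placeholder):
+
+```py
+>>> from datasets import load_dataset
+
+>>> config.dataset_name = "imagefolder"
+>>> dataset = load_dataset(config.dataset_name, data_dir="path/to/your/images", split="train")
+```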
+
+🤗 Datasets uses the [`~datasets.Image`] feature to automatically decode the image data and load it as a [`PIL.Image`](https://pillow.readthedocs.io/en/stable/reference/Image.html) which we can visualize:
+
+```py
+>>> import matplotlib.pyplot as plt
+
+>>> fig, axs = plt.subplots(1, 4, figsize=(16, 4))
+>>> for i, image in enumerate(dataset[:4]["image"]):
+... axs[i].imshow(image)
+... axs[i].set_axis_off()
+>>> fig.show()
+```
+
+
+
+
+
+The images are all different sizes though, so you'll need to preprocess them first:
+
+* `Resize` changes the image size to the one defined in `config.image_size`.
+* `RandomHorizontalFlip` augments the dataset by randomly mirroring the images.
+* `Normalize` is important to rescale the pixel values into a [-1, 1] range, which is what the model expects.
+
+```py
+>>> from torchvision import transforms
+
+>>> preprocess = transforms.Compose(
+... [
+... transforms.Resize((config.image_size, config.image_size)),
+... transforms.RandomHorizontalFlip(),
+... transforms.ToTensor(),
+... transforms.Normalize([0.5], [0.5]),
+... ]
+... )
+```
+
+Use 🤗 Datasets' [`~datasets.Dataset.set_transform`] method to apply the `preprocess` function on the fly during training:
+
+```py
+>>> def transform(examples):
+... images = [preprocess(image.convert("RGB")) for image in examples["image"]]
+... return {"images": images}
+
+
+>>> dataset.set_transform(transform)
+```
+
+Feel free to visualize the images again to confirm that they've been resized. Now you're ready to wrap the dataset in a [DataLoader](https://pytorch.org/docs/stable/data#torch.utils.data.DataLoader) for training!
+
+```py
+>>> import torch
+
+>>> train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)
+```
+
+## Create a UNet2DModel
+
+Pretrained models in 🧨 Diffusers are easily created from their model class with the parameters you want. For example, to create a [`UNet2DModel`]:
+
+```py
+>>> from diffusers import UNet2DModel
+
+>>> model = UNet2DModel(
+... sample_size=config.image_size, # the target image resolution
+... in_channels=3, # the number of input channels, 3 for RGB images
+... out_channels=3, # the number of output channels
+... layers_per_block=2, # how many ResNet layers to use per UNet block
+... block_out_channels=(128, 128, 256, 256, 512, 512), # the number of output channels for each UNet block
+... down_block_types=(
+... "DownBlock2D", # a regular ResNet downsampling block
+... "DownBlock2D",
+... "DownBlock2D",
+... "DownBlock2D",
+... "AttnDownBlock2D", # a ResNet downsampling block with spatial self-attention
+... "DownBlock2D",
+... ),
+... up_block_types=(
+... "UpBlock2D", # a regular ResNet upsampling block
+... "AttnUpBlock2D", # a ResNet upsampling block with spatial self-attention
+... "UpBlock2D",
+... "UpBlock2D",
+... "UpBlock2D",
+... "UpBlock2D",
+... ),
+... )
+```
+
+It is often a good idea to quickly check the sample image shape matches the model output shape:
+
+```py
+>>> sample_image = dataset[0]["images"].unsqueeze(0)
+>>> print("Input shape:", sample_image.shape)
+Input shape: torch.Size([1, 3, 128, 128])
+
+>>> print("Output shape:", model(sample_image, timestep=0).sample.shape)
+Output shape: torch.Size([1, 3, 128, 128])
+```
+
+Great! Next, you'll need a scheduler to add some noise to the image.
+
+## Create a scheduler
+
+The scheduler behaves differently depending on whether you're using the model for training or inference. During inference, the scheduler generates an image from the noise. During training, the scheduler takes a model output - or a sample - from a specific point in the diffusion process and applies noise to the image according to a *noise schedule* and an *update rule*.
+
+Let's take a look at the [`DDPMScheduler`] and use the `add_noise` method to add some random noise to the `sample_image` from before:
+
+```py
+>>> import torch
+>>> from PIL import Image
+>>> from diffusers import DDPMScheduler
+
+>>> noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
+>>> noise = torch.randn(sample_image.shape)
+>>> timesteps = torch.LongTensor([50])
+>>> noisy_image = noise_scheduler.add_noise(sample_image, noise, timesteps)
+
+>>> Image.fromarray(((noisy_image.permute(0, 2, 3, 1) + 1.0) * 127.5).type(torch.uint8).numpy()[0])
+```
+
+
+
+
+
+The training objective of the model is to predict the noise added to the image. The loss at this step can be calculated by:
+
+```py
+>>> import torch.nn.functional as F
+
+>>> noise_pred = model(noisy_image, timesteps).sample
+>>> loss = F.mse_loss(noise_pred, noise)
+```
+
+## Train the model
+
+By now, you have most of the pieces to start training the model and all that's left is putting everything together.
+
+First, you'll need an optimizer and a learning rate scheduler:
+
+```py
+>>> from diffusers.optimization import get_cosine_schedule_with_warmup
+
+>>> optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
+>>> lr_scheduler = get_cosine_schedule_with_warmup(
+... optimizer=optimizer,
+... num_warmup_steps=config.lr_warmup_steps,
+... num_training_steps=(len(train_dataloader) * config.num_epochs),
+... )
+```
+
+Then, you'll need a way to evaluate the model. For evaluation, you can use the [`DDPMPipeline`] to generate a batch of sample images and save it as a grid:
+
+```py
+>>> from diffusers import DDPMPipeline
+>>> from diffusers.utils import make_image_grid
+>>> import os
+
+>>> def evaluate(config, epoch, pipeline):
+... # Sample some images from random noise (this is the backward diffusion process).
+... # The default pipeline output type is `List[PIL.Image]`
+... images = pipeline(
+... batch_size=config.eval_batch_size,
+... generator=torch.manual_seed(config.seed),
+... ).images
+
+... # Make a grid out of the images
+... image_grid = make_image_grid(images, rows=4, cols=4)
+
+... # Save the images
+... test_dir = os.path.join(config.output_dir, "samples")
+... os.makedirs(test_dir, exist_ok=True)
+... image_grid.save(f"{test_dir}/{epoch:04d}.png")
+```
+
+Now you can wrap all these components together in a training loop with 🤗 Accelerate for easy TensorBoard logging, gradient accumulation, and mixed precision training. To upload the model to the Hub, write a function to get your repository name and information and then push it to the Hub.
+
+
+
+💡 The training loop below may look intimidating and long, but it'll be worth it later when you launch your training in just one line of code! If you can't wait and want to start generating images, feel free to copy and run the code below. You can always come back and examine the training loop more closely later, like when you're waiting for your model to finish training. 🤗
+
+
+
+```py
+>>> from accelerate import Accelerator
+>>> from huggingface_hub import create_repo, upload_folder
+>>> from tqdm.auto import tqdm
+>>> from pathlib import Path
+>>> import os
+
+>>> def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
+... # Initialize accelerator and tensorboard logging
+... accelerator = Accelerator(
+... mixed_precision=config.mixed_precision,
+... gradient_accumulation_steps=config.gradient_accumulation_steps,
+... log_with="tensorboard",
+... project_dir=os.path.join(config.output_dir, "logs"),
+... )
+... if accelerator.is_main_process:
+... if config.output_dir is not None:
+... os.makedirs(config.output_dir, exist_ok=True)
+... if config.push_to_hub:
+... repo_id = create_repo(
+... repo_id=config.hub_model_id or Path(config.output_dir).name, exist_ok=True
+... ).repo_id
+... accelerator.init_trackers("train_example")
+
+... # Prepare everything
+... # There is no specific order to remember, you just need to unpack the
+... # objects in the same order you gave them to the prepare method.
+... model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+... model, optimizer, train_dataloader, lr_scheduler
+... )
+
+... global_step = 0
+
+... # Now you train the model
+... for epoch in range(config.num_epochs):
+... progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
+... progress_bar.set_description(f"Epoch {epoch}")
+
+... for step, batch in enumerate(train_dataloader):
+... clean_images = batch["images"]
+... # Sample noise to add to the images
+... noise = torch.randn(clean_images.shape, device=clean_images.device)
+... bs = clean_images.shape[0]
+
+... # Sample a random timestep for each image
+... timesteps = torch.randint(
+... 0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device,
+... dtype=torch.int64
+... )
+
+... # Add noise to the clean images according to the noise magnitude at each timestep
+... # (this is the forward diffusion process)
+... noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
+
+... with accelerator.accumulate(model):
+... # Predict the noise residual
+... noise_pred = model(noisy_images, timesteps, return_dict=False)[0]
+... loss = F.mse_loss(noise_pred, noise)
+... accelerator.backward(loss)
+
+... accelerator.clip_grad_norm_(model.parameters(), 1.0)
+... optimizer.step()
+... lr_scheduler.step()
+... optimizer.zero_grad()
+
+... progress_bar.update(1)
+... logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
+... progress_bar.set_postfix(**logs)
+... accelerator.log(logs, step=global_step)
+... global_step += 1
+
+... # After each epoch you optionally sample some demo images with evaluate() and save the model
+... if accelerator.is_main_process:
+... pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
+
+... if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:
+... evaluate(config, epoch, pipeline)
+
+... if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
+... if config.push_to_hub:
+... upload_folder(
+... repo_id=repo_id,
+... folder_path=config.output_dir,
+... commit_message=f"Epoch {epoch}",
+... ignore_patterns=["step_*", "epoch_*"],
+... )
+... else:
+... pipeline.save_pretrained(config.output_dir)
+```
+
+Phew, that was quite a bit of code! But you're finally ready to launch the training with 🤗 Accelerate's [`~accelerate.notebook_launcher`] function. Pass the function the training loop, all the training arguments, and the number of processes (you can change this value to the number of GPUs available to you) to use for training:
+
+```py
+>>> from accelerate import notebook_launcher
+
+>>> args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)
+
+>>> notebook_launcher(train_loop, args, num_processes=1)
+```
+
+Once training is complete, take a look at the final 🦋 images 🦋 generated by your diffusion model!
+
+```py
+>>> import glob
+
+>>> sample_images = sorted(glob.glob(f"{config.output_dir}/samples/*.png"))
+>>> Image.open(sample_images[-1])
+```
+
+
+
+
+
+## Next steps
+
+Unconditional image generation is one example of a task that can be trained. You can explore other tasks and training techniques by visiting the [🧨 Diffusers Training Examples](../training/overview) page. Here are some examples of what you can learn:
+
+* [Textual Inversion](../training/text_inversion), an algorithm that teaches a model a specific visual concept and integrates it into the generated image.
+* [DreamBooth](../training/dreambooth), a technique for generating personalized images of a subject given several input images of the subject.
+* [Guide](../training/text2image) to finetuning a Stable Diffusion model on your own dataset.
+* [Guide](../training/lora) to using LoRA, a memory-efficient technique for finetuning really large models faster.
diff --git a/diffusers/docs/source/en/tutorials/tutorial_overview.md b/diffusers/docs/source/en/tutorials/tutorial_overview.md
new file mode 100644
index 0000000000000000000000000000000000000000..ee7a49e43851d6f49476d80279106f9bd7a588f4
--- /dev/null
+++ b/diffusers/docs/source/en/tutorials/tutorial_overview.md
@@ -0,0 +1,23 @@
+
+
+# Overview
+
+Welcome to 🧨 Diffusers! If you're new to diffusion models and generative AI, and want to learn more, then you've come to the right place. These beginner-friendly tutorials are designed to provide a gentle introduction to diffusion models and help you understand the library fundamentals - the core components and how 🧨 Diffusers is meant to be used.
+
+You'll learn how to use a pipeline for inference to rapidly generate things, and then deconstruct that pipeline to really understand how to use the library as a modular toolbox for building your own diffusion systems. In the next lesson, you'll learn how to train your own diffusion model to generate what you want.
+
+After completing the tutorials, you'll have gained the necessary skills to start exploring the library on your own and see how to use it for your own projects and applications.
+
+Feel free to join our community on [Discord](https://discord.com/invite/JfAtkvEtRb) or the [forums](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) to connect and collaborate with other users and developers!
+
+Let's start diffusing! 🧨
diff --git a/diffusers/docs/source/en/tutorials/using_peft_for_inference.md b/diffusers/docs/source/en/tutorials/using_peft_for_inference.md
new file mode 100644
index 0000000000000000000000000000000000000000..6f317a7610b222c3e70fe6c8dd59ba72ef7189ff
--- /dev/null
+++ b/diffusers/docs/source/en/tutorials/using_peft_for_inference.md
@@ -0,0 +1,185 @@
+
+
+[[open-in-colab]]
+
+# Load LoRAs for inference
+
+There are many adapters (with LoRAs being the most common type) trained in different styles to achieve different effects. You can even combine multiple adapters to create new and unique images. With the 🤗 [PEFT](https://huggingface.co/docs/peft/index) integration in 🤗 Diffusers, it is really easy to load and manage adapters for inference. In this guide, you'll learn how to use different adapters with [Stable Diffusion XL (SDXL)](../api/pipelines/stable_diffusion/stable_diffusion_xl) for inference.
+
+Throughout this guide, you'll use LoRA as the main adapter technique, so we'll use the terms LoRA and adapter interchangeably. You should have some familiarity with LoRA, and if you don't, we welcome you to check out the [LoRA guide](https://huggingface.co/docs/peft/conceptual_guides/lora).
+
+Let's first install all the required libraries.
+
+```bash
+!pip install -q transformers accelerate
+!pip install peft
+!pip install diffusers
+```
+
+Now, let's load a pipeline with an SDXL checkpoint:
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+pipe_id = "stabilityai/stable-diffusion-xl-base-1.0"
+pipe = DiffusionPipeline.from_pretrained(pipe_id, torch_dtype=torch.float16).to("cuda")
+```
+
+
+Next, load a LoRA checkpoint with the [`~diffusers.loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] method.
+
+With the 🤗 PEFT integration, you can assign a specific `adapter_name` to the checkpoint, which lets you easily switch between different LoRA checkpoints. Let's call this adapter `"toy"`.
+
+```python
+pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
+```
+
+And then perform inference:
+
+```python
+prompt = "toy_face of a hacker with a hoodie"
+
+lora_scale = 0.9
+image = pipe(
+ prompt, num_inference_steps=30, cross_attention_kwargs={"scale": lora_scale}, generator=torch.manual_seed(0)
+).images[0]
+image
+```
+
+![toy-face](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_8_1.png)
+
+
+With the `adapter_name` parameter, it is really easy to use another adapter for inference! Load the [nerijs/pixel-art-xl](https://huggingface.co/nerijs/pixel-art-xl) adapter that has been fine-tuned to generate pixel art images, and let's call it `"pixel"`.
+
+The pipeline automatically sets the first loaded adapter (`"toy"`) as the active adapter. But you can activate the `"pixel"` adapter with the [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`] method as shown below:
+
+```python
+pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+pipe.set_adapters("pixel")
+```
+
+Let's now generate an image with the second adapter and check the result:
+
+```python
+prompt = "a hacker with a hoodie, pixel art"
+image = pipe(
+ prompt, num_inference_steps=30, cross_attention_kwargs={"scale": lora_scale}, generator=torch.manual_seed(0)
+).images[0]
+image
+```
+
+![pixel-art](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_12_1.png)
+
+## Combine multiple adapters
+
+You can also perform multi-adapter inference where you combine different adapter checkpoints for inference.
+
+Once again, use the [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`] method to activate two LoRA checkpoints and specify the weight for how the checkpoints should be combined.
+
+```python
+pipe.set_adapters(["pixel", "toy"], adapter_weights=[0.5, 1.0])
+```
+
+Now that we have set these two adapters, let's generate an image from the combined adapters!
+
+
+
+LoRA checkpoints in the diffusion community are almost always obtained with [DreamBooth](https://huggingface.co/docs/diffusers/main/en/training/dreambooth). DreamBooth training often relies on "trigger" words in the input text prompts in order for the generation results to look as expected. When you combine multiple LoRA checkpoints, it's important to ensure the trigger words for the corresponding LoRA checkpoints are present in the input text prompts.
+
+
+
+The trigger words for [CiroN2022/toy-face](https://hf.co/CiroN2022/toy-face) and [nerijs/pixel-art-xl](https://hf.co/nerijs/pixel-art-xl) are found in their repositories.
+
+
+```python
+# Notice how the prompt is constructed.
+prompt = "toy_face of a hacker with a hoodie, pixel art"
+image = pipe(
+ prompt, num_inference_steps=30, cross_attention_kwargs={"scale": 1.0}, generator=torch.manual_seed(0)
+).images[0]
+image
+```
+
+![toy-face-pixel-art](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_16_1.png)
+
+Impressive! As you can see, the model was able to generate an image that mixes the characteristics of both adapters.
+
+If you want to go back to using only one adapter, use the [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`] method to activate the `"toy"` adapter:
+
+```python
+# First, set the adapter.
+pipe.set_adapters("toy")
+
+# Then, run inference.
+prompt = "toy_face of a hacker with a hoodie"
+lora_scale = 0.9
+image = pipe(
+ prompt, num_inference_steps=30, cross_attention_kwargs={"scale": lora_scale}, generator=torch.manual_seed(0)
+).images[0]
+image
+```
+
+![toy-face-again](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_18_1.png)
+
+
+If you want to switch to only the base model, disable all LoRAs with the [`~diffusers.loaders.UNet2DConditionLoadersMixin.disable_lora`] method.
+
+
+```python
+pipe.disable_lora()
+
+prompt = "toy_face of a hacker with a hoodie"
+lora_scale = 0.9
+image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
+image
+```
+
+![no-lora](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_20_1.png)
+
+## Monitoring active adapters
+
+You have attached multiple adapters in this tutorial, and if you're feeling a bit lost on what adapters have been attached to the pipeline's components, you can easily check the list of active adapters using the [`~diffusers.loaders.LoraLoaderMixin.get_active_adapters`] method:
+
+```py
+active_adapters = pipe.get_active_adapters()
+active_adapters
+["toy", "pixel"]
+```
+
+You can also get the active adapters of each pipeline component with [`~diffusers.loaders.LoraLoaderMixin.get_list_adapters`]:
+
+```py
+list_adapters_component_wise = pipe.get_list_adapters()
+list_adapters_component_wise
+{"text_encoder": ["toy", "pixel"], "unet": ["toy", "pixel"], "text_encoder_2": ["toy", "pixel"]}
+```
+
+## Fusing adapters into the model
+
+You can use PEFT to easily fuse/unfuse multiple adapters directly into the model weights (both UNet and text encoder) using the [`~diffusers.loaders.LoraLoaderMixin.fuse_lora`] method, which can lead to a speed-up in inference and lower VRAM usage.
+
+```py
+pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
+
+pipe.set_adapters(["pixel", "toy"], adapter_weights=[0.5, 1.0])
+# Fuses the LoRAs into the Unet
+pipe.fuse_lora()
+
+prompt = "toy_face of a hacker with a hoodie, pixel art"
+image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
+
+# Gets the Unet back to the original state
+pipe.unfuse_lora()
+```
diff --git a/diffusers/docs/source/en/using-diffusers/callback.md b/diffusers/docs/source/en/using-diffusers/callback.md
new file mode 100644
index 0000000000000000000000000000000000000000..690d86c17a54e46aa7708ee9427c5dcb5dfe269a
--- /dev/null
+++ b/diffusers/docs/source/en/using-diffusers/callback.md
@@ -0,0 +1,65 @@
+
+
+# Pipeline callbacks
+
+The denoising loop of a pipeline can be modified with custom defined functions using the `callback_on_step_end` parameter. This can be really useful for *dynamically* adjusting certain pipeline attributes, or modifying tensor variables. The flexibility of callbacks opens up some interesting use-cases such as changing the prompt embeddings at each timestep, assigning different weights to the prompt embeddings, and editing the guidance scale.
+
+This guide will show you how to use the `callback_on_step_end` parameter to disable classifier-free guidance (CFG) after 40% of the inference steps to save compute with minimal cost to performance.
+
+The callback function should have the following arguments:
+
+* `pipe` (or the pipeline instance) provides access to useful properties such as `num_timestep` and `guidance_scale`. You can modify these properties by updating the underlying attributes. For this example, you'll disable CFG by setting `pipe._guidance_scale=0.0`.
+* `step_index` and `timestep` tell you where you are in the denoising loop. Use `step_index` to turn off CFG after reaching 40% of `num_timestep`.
+* `callback_kwargs` is a dict that contains tensor variables you can modify during the denoising loop. It only includes variables specified in the `callback_on_step_end_tensor_inputs` argument, which is passed to the pipeline's `__call__` method. Different pipelines may use different sets of variables, so please check a pipeline's `_callback_tensor_inputs` attribute for the list of variables you can modify. Some common variables include `latents` and `prompt_embeds`. For this function, change the batch size of `prompt_embeds` after setting `guidance_scale=0.0` in order for it to work properly.
+
+Your callback function should look something like this:
+
+```python
+def callback_dynamic_cfg(pipe, step_index, timestep, callback_kwargs):
+ # adjust the batch_size of prompt_embeds according to guidance_scale
+ if step_index == int(pipe.num_timestep * 0.4):
+ prompt_embeds = callback_kwargs["prompt_embeds"]
+ prompt_embeds = prompt_embeds.chunk(2)[-1]
+
+ # update guidance_scale and prompt_embeds
+ pipe._guidance_scale = 0.0
+ callback_kwargs["prompt_embeds"] = prompt_embeds
+ return callback_kwargs
+```
+
+Now, you can pass the callback function to the `callback_on_step_end` parameter and the `prompt_embeds` to `callback_on_step_end_tensor_inputs`.
+
+```py
+import torch
+from diffusers import StableDiffusionPipeline
+
+pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+pipe = pipe.to("cuda")
+
+prompt = "a photo of an astronaut riding a horse on mars"
+
+generator = torch.Generator(device="cuda").manual_seed(1)
+out = pipe(prompt, generator=generator, callback_on_step_end=callback_dynamic_cfg, callback_on_step_end_tensor_inputs=['prompt_embeds'])
+
+out.images[0].save("out_custom_cfg.png")
+```
+
+The callback function is executed at the end of each denoising step, and modifies the pipeline attributes and tensor variables for the next denoising step.
+
+With callbacks, you can implement features such as dynamic CFG without having to modify the underlying code at all!
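+
+As another illustration, a callback that only inspects the intermediate latents at each step could be as simple as the sketch below (`log_latents` is not part of the library; `latents` is one of the tensors listed in the pipeline's `_callback_tensor_inputs`):
+
+```py
+def log_latents(pipe, step_index, timestep, callback_kwargs):
+    # read the current latents and print a quick summary, then hand them back unchanged
+    latents = callback_kwargs["latents"]
+    print(f"step {step_index} (t={timestep}): latents mean={latents.mean().item():.4f}")
+    return callback_kwargs
+
+out = pipe(prompt, callback_on_step_end=log_latents, callback_on_step_end_tensor_inputs=["latents"])
+```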
+
+
+
+🤗 Diffusers currently only supports `callback_on_step_end`, but feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you have a cool use-case and require a callback function with a different execution point!
+
+
diff --git a/diffusers/docs/source/en/using-diffusers/conditional_image_generation.md b/diffusers/docs/source/en/using-diffusers/conditional_image_generation.md
new file mode 100644
index 0000000000000000000000000000000000000000..eaca038d59fe7b5eae4283f62c252e63dedb36f0
--- /dev/null
+++ b/diffusers/docs/source/en/using-diffusers/conditional_image_generation.md
@@ -0,0 +1,316 @@
+
+
+# Text-to-image
+
+[[open-in-colab]]
+
+When you think of diffusion models, text-to-image is usually one of the first things that come to mind. Text-to-image generates an image from a text description (for example, "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k") which is also known as a *prompt*.
+
+From a very high level, a diffusion model takes a prompt and some random initial noise, and iteratively removes the noise to construct an image. The *denoising* process is guided by the prompt, and once the denoising process ends after a predetermined number of time steps, the image representation is decoded into an image.
+
+
+
+Read the [How does Stable Diffusion work?](https://huggingface.co/blog/stable_diffusion#how-does-stable-diffusion-work) blog post to learn more about how a latent diffusion model works.
+
+
+
+You can generate images from a prompt in 🤗 Diffusers in two steps:
+
+1. Load a checkpoint into the [`AutoPipelineForText2Image`] class, which automatically detects the appropriate pipeline class to use based on the checkpoint:
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+
+pipeline = AutoPipelineForText2Image.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16"
+).to("cuda")
+```
+
+2. Pass a prompt to the pipeline to generate an image:
+
+```py
+image = pipeline(
+ "stained glass of darth vader, backlight, centered composition, masterpiece, photorealistic, 8k"
+).images[0]
+image
+```
+
+
+
+
+
+## Popular models
+
+The most common text-to-image models are [Stable Diffusion v1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5), [Stable Diffusion XL (SDXL)](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0), and [Kandinsky 2.2](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder). There are also ControlNet models or adapters that can be used with text-to-image models for more direct control in generating images. The results from each model are slightly different because of their architecture and training process, but no matter which model you choose, their usage is more or less the same. Let's use the same prompt for each model and compare their results.
+
+### Stable Diffusion v1.5
+
+[Stable Diffusion v1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5) is a latent diffusion model initialized from [Stable Diffusion v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4), and finetuned for 595K steps on 512x512 images from the LAION-Aesthetics V2 dataset. You can use this model like:
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+
+pipeline = AutoPipelineForText2Image.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16"
+).to("cuda")
+generator = torch.Generator("cuda").manual_seed(31)
+image = pipeline("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", generator=generator).images[0]
+image
+```
+
+### Stable Diffusion XL
+
+SDXL is a much larger version of the previous Stable Diffusion models, and involves a two-stage model process that adds even more details to an image. It also includes some additional *micro-conditionings* to generate high-quality images of centered subjects. Take a look at the more comprehensive [SDXL](sdxl) guide to learn more about how to use it. In general, you can use SDXL like:
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+
+pipeline = AutoPipelineForText2Image.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16"
+).to("cuda")
+generator = torch.Generator("cuda").manual_seed(31)
+image = pipeline("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", generator=generator).images[0]
+image
+```
+
+### Kandinsky 2.2
+
+The Kandinsky model is a bit different from the Stable Diffusion models because it also uses an image prior model to create embeddings that are used to better align text and images in the diffusion model.
+
+The easiest way to use Kandinsky 2.2 is:
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+
+pipeline = AutoPipelineForText2Image.from_pretrained(
+ "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
+).to("cuda")
+generator = torch.Generator("cuda").manual_seed(31)
+image = pipeline("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", generator=generator).images[0]
+image
+```
+
+### ControlNet
+
+ControlNet models are auxiliary models or adapters that are finetuned on top of text-to-image models, such as [Stable Diffusion v1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5). Using ControlNet models in combination with text-to-image models offers diverse options for more explicit control over how to generate an image. With ControlNet, you add an additional conditioning input image to the model. For example, if you provide an image of a human pose (usually represented as multiple keypoints that are connected into a skeleton) as a conditioning input, the model generates an image that follows the pose of the image. Check out the more in-depth [ControlNet](controlnet) guide to learn more about other conditioning inputs and how to use them.
+
+In this example, let's condition the ControlNet with a human pose estimation image. Load the ControlNet model pretrained on human pose estimations:
+
+```py
+from diffusers import ControlNetModel, AutoPipelineForText2Image
+from diffusers.utils import load_image
+import torch
+
+controlnet = ControlNetModel.from_pretrained(
+ "lllyasviel/control_v11p_sd15_openpose", torch_dtype=torch.float16, variant="fp16"
+).to("cuda")
+pose_image = load_image("https://huggingface.co/lllyasviel/control_v11p_sd15_openpose/resolve/main/images/control.png")
+```
+
+Pass the `controlnet` to the [`AutoPipelineForText2Image`], and provide the prompt and pose estimation image:
+
+```py
+pipeline = AutoPipelineForText2Image.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16, variant="fp16"
+).to("cuda")
+generator = torch.Generator("cuda").manual_seed(31)
+image = pipeline("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", image=pose_image, generator=generator).images[0]
+image
+```
+
+(Image comparison: Stable Diffusion v1.5, Stable Diffusion XL, Kandinsky 2.2, and ControlNet with pose conditioning, all generated from the same prompt.)
+
+## Configure pipeline parameters
+
+There are a number of parameters that can be configured in the pipeline that affect how an image is generated. You can change the image's output size, specify a negative prompt to improve image quality, and more. This section dives deeper into how to use these parameters.
+
+### Height and width
+
+The `height` and `width` parameters control the height and width (in pixels) of the generated image. By default, the Stable Diffusion v1.5 model outputs 512x512 images, but you can change this to any size that is a multiple of 8. For example, to create a rectangular image:
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+
+pipeline = AutoPipelineForText2Image.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16"
+).to("cuda")
+image = pipeline(
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", height=768, width=512
+).images[0]
+image
+```
+
+
+
+
+
+
+
+Other models may have different default image sizes depending on the image sizes in the training dataset. For example, SDXL's default image size is 1024x1024 and using lower `height` and `width` values may result in lower quality images. Make sure you check the model's API reference first!
+
+
+
+### Guidance scale
+
+The `guidance_scale` parameter affects how much the prompt influences image generation. A lower value gives the model "creativity" to generate images that are more loosely related to the prompt. Higher `guidance_scale` values push the model to follow the prompt more closely, and if this value is too high, you may observe some artifacts in the generated image.
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+
+pipeline = AutoPipelineForText2Image.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
+).to("cuda")
+image = pipeline(
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", guidance_scale=3.5
+).images[0]
+image
+```
+
+[Images: guidance_scale = 2.5 | guidance_scale = 7.5 | guidance_scale = 10.5]
+
+### Negative prompt
+
+Just like how a prompt guides generation, a *negative prompt* steers the model away from things you don't want the model to generate. This is commonly used to improve overall image quality by removing poor or bad image features such as "low resolution" or "bad details". You can also use a negative prompt to remove or modify the content and style of an image.
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+
+pipeline = AutoPipelineForText2Image.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
+).to("cuda")
+image = pipeline(
+ prompt="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+ negative_prompt="ugly, deformed, disfigured, poor details, bad anatomy",
+).images[0]
+image
+```
+
+
+
+### Generator
+
+A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html#generator) object enables reproducibility in a pipeline by setting a manual seed. You can use a `Generator` to generate batches of images and iteratively improve on an image generated from a seed as detailed in the [Improve image quality with deterministic generation](reusing_seeds) guide.
+
+You can set a seed and `Generator` as shown below. Creating an image with a `Generator` should return the same result each time instead of randomly generating a new image.
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+
+pipeline = AutoPipelineForText2Image.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
+).to("cuda")
+generator = torch.Generator(device="cuda").manual_seed(30)
+image = pipeline(
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+ generator=generator,
+).images[0]
+image
+```
+
+## Control image generation
+
+There are several ways to exert more control over how an image is generated outside of configuring a pipeline's parameters, such as prompt weighting and ControlNet models.
+
+### Prompt weighting
+
+Prompt weighting is a technique for increasing or decreasing the importance of concepts in a prompt to emphasize or minimize certain features in an image. We recommend using the [Compel](https://github.com/damian0815/compel) library to help you generate the weighted prompt embeddings.
+
+
+
+Learn how to create the prompt embeddings in the [Prompt weighting](weighted_prompts) guide. This example focuses on how to use the prompt embeddings in the pipeline.
+
+
+
+Once you've created the embeddings, you can pass them to the `prompt_embeds` (and `negative_prompt_embeds` if you're using a negative prompt) parameter in the pipeline.
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+
+pipeline = AutoPipelineForText2Image.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
+).to("cuda")
+image = pipeline(
+ prompt_embeds=prompt_embeds, # generated from Compel
+ negative_prompt_embeds=negative_prompt_embeds, # generated from Compel
+).images[0]
+```
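+
+As a rough sketch, the embeddings used above could be produced with Compel from the same `pipeline`'s tokenizer and text encoder. The prompt and weighting below are only illustrative; see the [Prompt weighting](weighted_prompts) guide for the full syntax:
+
+```py
+from compel import Compel
+
+# build a Compel processor from the pipeline's tokenizer and text encoder
+compel_proc = Compel(tokenizer=pipeline.tokenizer, text_encoder=pipeline.text_encoder)
+
+# "++" upweights the token it follows
+prompt_embeds = compel_proc("Astronaut in a jungle++, cold color palette, muted colors, detailed, 8k")
+negative_prompt_embeds = compel_proc("ugly, deformed, disfigured, poor details, bad anatomy")
+```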
+
+### ControlNet
+
+As you saw in the [ControlNet](#controlnet) section, these models offer a more flexible and accurate way to generate images by incorporating an additional conditioning image input. Each ControlNet model is pretrained on a particular type of conditioning image to generate new images that resemble it. For example, if you take a ControlNet model pretrained on depth maps, you can give the model a depth map as a conditioning input and it'll generate an image that preserves the spatial information in it. This is quicker and easier than specifying the depth information in a prompt. You can even combine multiple conditioning inputs with a [MultiControlNet](controlnet#multicontrolnet)!
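+
+A minimal sketch of the depth-map example, assuming you already have a precomputed depth map saved locally (the file path below is hypothetical):
+
+```py
+from diffusers import ControlNetModel, AutoPipelineForText2Image
+from diffusers.utils import load_image
+import torch
+
+# depth-conditioned ControlNet paired with the Stable Diffusion v1.5 base model
+controlnet = ControlNetModel.from_pretrained(
+    "lllyasviel/control_v11f1p_sd15_depth", torch_dtype=torch.float16
+)
+pipeline = AutoPipelineForText2Image.from_pretrained(
+    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
+).to("cuda")
+
+depth_image = load_image("path/to/depth_map.png")  # hypothetical path to a precomputed depth map
+image = pipeline("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", image=depth_image).images[0]
+```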
+
+There are many types of conditioning inputs you can use, and 🤗 Diffusers supports ControlNet for Stable Diffusion and SDXL models. Take a look at the more comprehensive [ControlNet](controlnet) guide to learn how you can use these models.
+
+## Optimize
+
+Diffusion models are large, and the iterative nature of denoising an image is computationally intensive. But this doesn't mean you need access to powerful - or even many - GPUs to use them. There are many optimization techniques for running diffusion models on consumer and free-tier resources. For example, you can load model weights in half-precision to save GPU memory and increase speed, or offload model components to the CPU (moving them to the GPU only while they're needed) to save even more memory.
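+
+For example, a minimal sketch combining half-precision weights with model CPU offloading (both techniques are covered in the guides linked at the end of this section):
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+
+pipeline = AutoPipelineForText2Image.from_pretrained(
+    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16"
+)
+# components stay on the CPU and are moved to the GPU only while they're needed
+pipeline.enable_model_cpu_offload()
+```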
+
+PyTorch 2.0 also supports a more memory-efficient attention mechanism called [*scaled dot product attention*](../optimization/torch2.0#scaled-dot-product-attention) that is automatically enabled if you're using PyTorch 2.0. You can combine this with [`torch.compile`](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) to speed your code up even more:
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+
+pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16").to("cuda")
+pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
+```
+
+For more tips on how to optimize your code to save memory and speed up inference, read the [Memory and speed](../optimization/fp16) and [Torch 2.0](../optimization/torch2.0) guides.
diff --git a/diffusers/docs/source/en/using-diffusers/contribute_pipeline.md b/diffusers/docs/source/en/using-diffusers/contribute_pipeline.md
new file mode 100644
index 0000000000000000000000000000000000000000..ea0ec51721f294657532a1e552807664c330dd20
--- /dev/null
+++ b/diffusers/docs/source/en/using-diffusers/contribute_pipeline.md
@@ -0,0 +1,184 @@
+
+
+# Contribute a community pipeline
+
+
+
+💡 Take a look at GitHub Issue [#841](https://github.com/huggingface/diffusers/issues/841) for more context about why we're adding community pipelines to help everyone easily share their work without being slowed down.
+
+
+
+Community pipelines allow you to add any additional features you'd like on top of the [`DiffusionPipeline`]. The main benefit of building on top of the `DiffusionPipeline` is anyone can load and use your pipeline by only adding one more argument, making it super easy for the community to access.
+
+This guide will show you how to create a community pipeline and explain how they work. To keep things simple, you'll create a "one-step" pipeline where the `UNet` does a single forward pass and calls the scheduler once.
+
+## Initialize the pipeline
+
+You should start by creating a `one_step_unet.py` file for your community pipeline. In this file, create a pipeline class that inherits from the [`DiffusionPipeline`] to be able to load model weights and the scheduler configuration from the Hub. The one-step pipeline needs a `UNet` and a scheduler, so you'll need to add these as arguments to the `__init__` function:
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
+ def __init__(self, unet, scheduler):
+ super().__init__()
+```
+
+To ensure your pipeline and its components (`unet` and `scheduler`) can be saved with [`~DiffusionPipeline.save_pretrained`], add them to the `register_modules` function:
+
+```diff
+ from diffusers import DiffusionPipeline
+ import torch
+
+ class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
+ def __init__(self, unet, scheduler):
+ super().__init__()
+
++ self.register_modules(unet=unet, scheduler=scheduler)
+```
+
+Cool, the `__init__` step is done and you can move to the forward pass now! 🔥
+
+## Define the forward pass
+
+In the forward pass, which we recommend defining as `__call__`, you have complete creative freedom to add whatever feature you'd like. For our amazing one-step pipeline, create a random image and only call the `unet` and `scheduler` once by setting `timestep=1`:
+
+```diff
+ from diffusers import DiffusionPipeline
+ import torch
+
+ class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
+ def __init__(self, unet, scheduler):
+ super().__init__()
+
+ self.register_modules(unet=unet, scheduler=scheduler)
+
++ def __call__(self):
++ image = torch.randn(
++ (1, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),
++ )
++ timestep = 1
+
++ model_output = self.unet(image, timestep).sample
++ scheduler_output = self.scheduler.step(model_output, timestep, image).prev_sample
+
++ return scheduler_output
+```
+
+That's it! 🚀 You can now run this pipeline by passing a `unet` and `scheduler` to it:
+
+```python
+from diffusers import DDPMScheduler, UNet2DModel
+
+scheduler = DDPMScheduler()
+unet = UNet2DModel()
+
+pipeline = UnetSchedulerOneForwardPipeline(unet=unet, scheduler=scheduler)
+
+output = pipeline()
+```
+
+But what's even better is you can load pre-existing weights into the pipeline if the pipeline structure is identical. For example, you can load the [`google/ddpm-cifar10-32`](https://huggingface.co/google/ddpm-cifar10-32) weights into the one-step pipeline:
+
+```python
+pipeline = UnetSchedulerOneForwardPipeline.from_pretrained("google/ddpm-cifar10-32", use_safetensors=True)
+
+output = pipeline()
+```
+
+## Share your pipeline
+
+Open a Pull Request on the 🧨 Diffusers [repository](https://github.com/huggingface/diffusers) to add your awesome pipeline in `one_step_unet.py` to the [examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community) subfolder.
+
+Once it is merged, anyone with `diffusers >= 0.4.0` installed can use this pipeline magically 🪄 by specifying it in the `custom_pipeline` argument:
+
+```python
+from diffusers import DiffusionPipeline
+
+pipe = DiffusionPipeline.from_pretrained(
+ "google/ddpm-cifar10-32", custom_pipeline="one_step_unet", use_safetensors=True
+)
+pipe()
+```
+
+Another way to share your community pipeline is to upload the `one_step_unet.py` file directly to your preferred [model repository](https://huggingface.co/docs/hub/models-uploading) on the Hub. Instead of specifying the `one_step_unet.py` file, pass the model repository id to the `custom_pipeline` argument:
+
+```python
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "google/ddpm-cifar10-32", custom_pipeline="stevhliu/one_step_unet", use_safetensors=True
+)
+```
+
+Take a look at the following table to compare the two sharing workflows to help you decide the best option for you:
+
+| | GitHub community pipeline | HF Hub community pipeline |
+|----------------|------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------|
+| usage | same | same |
+| review process | open a Pull Request on GitHub and undergo a review process from the Diffusers team before merging; may be slower | upload directly to a Hub repository without any review; this is the fastest workflow |
+| visibility | included in the official Diffusers repository and documentation | included on your HF Hub profile and relies on your own usage/promotion to gain visibility |
+
+
+
+💡 You can use whatever package you want in your community pipeline file - as long as the user has it installed, everything will work fine. Make sure you have one and only one pipeline class that inherits from `DiffusionPipeline` because this is automatically detected.
+
+
+
+## How do community pipelines work?
+
+A community pipeline is a class that inherits from [`DiffusionPipeline`] which means:
+
+- It can be loaded with the [`custom_pipeline`] argument.
+- The model weights and scheduler configuration are loaded from [`pretrained_model_name_or_path`].
+- The code that implements a feature in the community pipeline is defined in a `pipeline.py` file.
+
+Sometimes you can't load all of the pipeline components' weights from an official repository. In this case, the other components should be passed directly to the pipeline:
+
+```python
+from diffusers import DiffusionPipeline, DDIMScheduler
+from transformers import CLIPImageProcessor, CLIPModel
+import torch
+
+model_id = "CompVis/stable-diffusion-v1-4"
+clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
+
+feature_extractor = CLIPImageProcessor.from_pretrained(clip_model_id)
+clip_model = CLIPModel.from_pretrained(clip_model_id, torch_dtype=torch.float16)
+
+# the custom pipeline also needs a scheduler; any compatible scheduler works, e.g. DDIM loaded from the base model
+scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
+
+pipeline = DiffusionPipeline.from_pretrained(
+ model_id,
+ custom_pipeline="clip_guided_stable_diffusion",
+ clip_model=clip_model,
+ feature_extractor=feature_extractor,
+ scheduler=scheduler,
+ torch_dtype=torch.float16,
+ use_safetensors=True,
+)
+```
+
+The magic behind community pipelines is contained in the following code. It allows the community pipeline to be loaded from GitHub or the Hub, and it'll be available to all 🧨 Diffusers packages.
+
+```python
+# 2. Load the pipeline class, if using custom module then load it from the Hub
+# if we load from explicit class, let's use it
+if custom_pipeline is not None:
+ pipeline_class = get_class_from_dynamic_module(
+ custom_pipeline, module_file=CUSTOM_PIPELINE_FILE_NAME, cache_dir=custom_pipeline
+ )
+elif cls != DiffusionPipeline:
+ pipeline_class = cls
+else:
+ diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
+ pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
+```
diff --git a/diffusers/docs/source/en/using-diffusers/control_brightness.md b/diffusers/docs/source/en/using-diffusers/control_brightness.md
new file mode 100644
index 0000000000000000000000000000000000000000..c5f9870776dc29b53839e0b095e03c4104334ac8
--- /dev/null
+++ b/diffusers/docs/source/en/using-diffusers/control_brightness.md
@@ -0,0 +1,58 @@
+
+
+# Control image brightness
+
+The Stable Diffusion pipeline is mediocre at generating images that are either very bright or dark as explained in the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) paper. The solutions proposed in the paper are currently implemented in the [`DDIMScheduler`] which you can use to improve the lighting in your images.
+
+
+
+💡 Take a look at the paper linked above for more details about the proposed solutions!
+
+
+
+One of the solutions is to train a model with *v prediction* and *v loss*. Add the following flag to the [`train_text_to_image.py`](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) or [`train_text_to_image_lora.py`](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py) scripts to enable `v_prediction`:
+
+```bash
+--prediction_type="v_prediction"
+```
+
+For example, let's use the [`ptx0/pseudo-journey-v2`](https://huggingface.co/ptx0/pseudo-journey-v2) checkpoint which has been finetuned with `v_prediction`.
+
+Next, configure the following parameters in the [`DDIMScheduler`]:
+
+1. `rescale_betas_zero_snr=True`, rescales the noise schedule to zero terminal signal-to-noise ratio (SNR)
+2. `timestep_spacing="trailing"`, starts sampling from the last timestep
+
+```py
+from diffusers import DiffusionPipeline, DDIMScheduler
+
+pipeline = DiffusionPipeline.from_pretrained("ptx0/pseudo-journey-v2", use_safetensors=True)
+
+# switch the scheduler in the pipeline to use the DDIMScheduler
+pipeline.scheduler = DDIMScheduler.from_config(
+ pipeline.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
+)
+pipeline.to("cuda")
+```
+
+Finally, in your call to the pipeline, set `guidance_rescale` to prevent overexposure:
+
+```py
+prompt = "A lion in galaxies, spirals, nebulae, stars, smoke, iridescent, intricate detail, octane render, 8k"
+image = pipeline(prompt, guidance_rescale=0.7).images[0]
+image
+```
+
diff --git a/diffusers/docs/source/en/using-diffusers/controlling_generation.md b/diffusers/docs/source/en/using-diffusers/controlling_generation.md
new file mode 100644
index 0000000000000000000000000000000000000000..34ce0584bdbbf6b502b325f54eb817d3576cd2a4
--- /dev/null
+++ b/diffusers/docs/source/en/using-diffusers/controlling_generation.md
@@ -0,0 +1,217 @@
+
+
+# Controlled generation
+
+Controlling outputs generated by diffusion models has long been pursued by the community and is now an active research topic. In many popular diffusion models, subtle changes in inputs, both images and text prompts, can drastically change outputs. In an ideal world we want to be able to control how semantics are preserved and changed.
+
+Most examples of preserving semantics reduce to being able to accurately map a change in input to a change in output. For example, adding an adjective to a subject in a prompt preserves the entire image, only modifying the changed subject. Or, image variation of a particular subject preserves the subject's pose.
+
+Additionally, there are qualities of generated images that we would like to influence beyond semantic preservation. In general, we would like our outputs to be of good quality, adhere to a particular style, or be realistic.
+
+We will document some of the techniques `diffusers` supports to control generation of diffusion models. Much of it is cutting-edge research and can be quite nuanced. If something needs clarifying or you have a suggestion, don't hesitate to open a discussion on the [forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) or a [GitHub issue](https://github.com/huggingface/diffusers/issues).
+
+We provide a high-level explanation of how the generation can be controlled as well as a snippet of the technical details. For more in-depth explanations of the technical details, the original papers, which are linked from the pipelines, are always the best resources.
+
+Depending on the use case, one should choose a technique accordingly. In many cases, these techniques can be combined. For example, one can combine Textual Inversion with SEGA to provide more semantic guidance to the outputs generated using Textual Inversion.
+
+Unless otherwise mentioned, these are techniques that work with existing models and don't require their own weights.
+
+1. [InstructPix2Pix](#instruct-pix2pix)
+2. [Pix2Pix Zero](#pix2pix-zero)
+3. [Attend and Excite](#attend-and-excite)
+4. [Semantic Guidance](#semantic-guidance-sega)
+5. [Self-attention Guidance](#self-attention-guidance-sag)
+6. [Depth2Image](#depth2image)
+7. [MultiDiffusion Panorama](#multidiffusion-panorama)
+8. [DreamBooth](#dreambooth)
+9. [Textual Inversion](#textual-inversion)
+10. [ControlNet](#controlnet)
+11. [Prompt Weighting](#prompt-weighting)
+12. [Custom Diffusion](#custom-diffusion)
+13. [Model Editing](#model-editing)
+14. [DiffEdit](#diffedit)
+15. [T2I-Adapter](#t2i-adapter)
+16. [FABRIC](#fabric)
+
+For convenience, we provide a table to denote which methods are inference-only and which require fine-tuning/training.
+
+| **Method** | **Inference only** | **Requires training / fine-tuning** | **Comments** |
+| :-------------------------------------------------: | :----------------: | :-------------------------------------: | :---------------------------------------------------------------------------------------------: |
+| [InstructPix2Pix](#instruct-pix2pix) | ✅ | ❌ | Can additionally be fine-tuned for better performance on specific edit instructions. |
+| [Pix2Pix Zero](#pix2pix-zero) | ✅ | ❌ | |
+| [Attend and Excite](#attend-and-excite) | ✅ | ❌ | |
+| [Semantic Guidance](#semantic-guidance-sega) | ✅ | ❌ | |
+| [Self-attention Guidance](#self-attention-guidance-sag) | ✅ | ❌ | |
+| [Depth2Image](#depth2image) | ✅ | ❌ | |
+| [MultiDiffusion Panorama](#multidiffusion-panorama) | ✅ | ❌ | |
+| [DreamBooth](#dreambooth) | ❌ | ✅ | |
+| [Textual Inversion](#textual-inversion) | ❌ | ✅ | |
+| [ControlNet](#controlnet) | ✅ | ❌ | A ControlNet can be trained/fine-tuned on a custom conditioning. |
+| [Prompt Weighting](#prompt-weighting) | ✅ | ❌ | |
+| [Custom Diffusion](#custom-diffusion) | ❌ | ✅ | |
+| [Model Editing](#model-editing) | ✅ | ❌ | |
+| [DiffEdit](#diffedit) | ✅ | ❌ | |
+| [T2I-Adapter](#t2i-adapter) | ✅ | ❌ | |
+| [Fabric](#fabric) | ✅ | ❌ | |
+
+## InstructPix2Pix
+
+[Paper](https://arxiv.org/abs/2211.09800)
+
+[InstructPix2Pix](../api/pipelines/pix2pix) is fine-tuned from Stable Diffusion to support editing input images. It takes as inputs an image and a prompt describing an edit, and it outputs the edited image.
+InstructPix2Pix has been explicitly trained to work well with [InstructGPT](https://openai.com/blog/instruction-following/)-like prompts.
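+
+As a rough usage sketch (the checkpoint and parameter values below are common choices, not prescriptions):
+
+```py
+from diffusers import StableDiffusionInstructPix2PixPipeline
+from diffusers.utils import load_image
+import torch
+
+pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
+    "timbrooks/instruct-pix2pix", torch_dtype=torch.float16
+).to("cuda")
+
+image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")
+# image_guidance_scale controls how closely the output should stay to the input image
+edited_image = pipeline("turn her into a cyborg", image=image, image_guidance_scale=1.5).images[0]
+```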
+
+## Pix2Pix Zero
+
+[Paper](https://arxiv.org/abs/2302.03027)
+
+[Pix2Pix Zero](../api/pipelines/pix2pix_zero) allows modifying an image so that one concept or subject is translated to another one while preserving general image semantics.
+
+The denoising process is guided from one conceptual embedding towards another conceptual embedding. The intermediate latents are optimized during the denoising process to push the attention maps towards reference attention maps. The reference attention maps are from the denoising process of the input image and are used to encourage semantic preservation.
+
+Pix2Pix Zero can be used both to edit synthetic images as well as real images.
+
+- To edit synthetic images, one first generates an image given a caption.
+ Next, we generate image captions for the concept that shall be edited and for the new target concept. We can use a model like [Flan-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5) for this purpose. Then, "mean" prompt embeddings for both the source and target concepts are created via the text encoder. Finally, the pix2pix-zero algorithm is used to edit the synthetic image.
+- To edit a real image, one first generates an image caption using a model like [BLIP](https://huggingface.co/docs/transformers/model_doc/blip). Then one applies DDIM inversion on the prompt and image to generate "inverse" latents. Similar to before, "mean" prompt embeddings for both source and target concepts are created and finally the pix2pix-zero algorithm in combination with the "inverse" latents is used to edit the image.
+
+
+
+Pix2Pix Zero is the first model that allows "zero-shot" image editing. This means that the model
+can edit an image in less than a minute on a consumer GPU as shown [here](../api/pipelines/pix2pix_zero#usage-example).
+
+
+
+As mentioned above, Pix2Pix Zero includes optimizing the latents (and not any of the UNet, VAE, or the text encoder) to steer the generation toward a specific concept. This means that the overall
+pipeline might require more memory than a standard [StableDiffusionPipeline](../api/pipelines/stable_diffusion/text2img).
+
+
+
+An important distinction between methods like InstructPix2Pix and Pix2Pix Zero is that the former
+involves fine-tuning the pre-trained weights while the latter does not. This means that you can
+apply Pix2Pix Zero to any of the available Stable Diffusion models.
+
+
+
+## Attend and Excite
+
+[Paper](https://arxiv.org/abs/2301.13826)
+
+[Attend and Excite](../api/pipelines/attend_and_excite) allows subjects in the prompt to be faithfully represented in the final image.
+
+A set of token indices are given as input, corresponding to the subjects in the prompt that need to be present in the image. During denoising, each token index is guaranteed to have a minimum attention threshold for at least one patch of the image. The intermediate latents are iteratively optimized during the denoising process to strengthen the attention of the most neglected subject token until the attention threshold is passed for all subject tokens.
+
+Like Pix2Pix Zero, Attend and Excite also involves a mini optimization loop (leaving the pre-trained weights untouched) in its pipeline and can require more memory than the usual [StableDiffusionPipeline](../api/pipelines/stable_diffusion/text2img).
+
+## Semantic Guidance (SEGA)
+
+[Paper](https://arxiv.org/abs/2301.12247)
+
+[SEGA](../api/pipelines/semantic_stable_diffusion) allows applying or removing one or more concepts from an image. The strength of the concept can also be controlled. For example, the smile concept can be used to incrementally increase or decrease the smile of a portrait.
+
+Similar to how classifier-free guidance provides guidance via empty prompt inputs, SEGA provides guidance on conceptual prompts. Multiple of these conceptual prompts can be applied simultaneously. Each conceptual prompt can either add or remove its concept depending on whether the guidance is applied positively or negatively.
+
+Unlike Pix2Pix Zero or Attend and Excite, SEGA directly interacts with the diffusion process instead of performing any explicit gradient-based optimization.
+
+## Self-attention Guidance (SAG)
+
+[Paper](https://arxiv.org/abs/2210.00939)
+
+[Self-attention Guidance](../api/pipelines/self_attention_guidance) improves the general quality of images.
+
+SAG provides guidance from predictions not conditioned on high-frequency details to fully conditioned images. The high frequency details are extracted out of the UNet self-attention maps.
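+
+A minimal usage sketch (the `sag_scale` value below is only a typical starting point):
+
+```py
+from diffusers import StableDiffusionSAGPipeline
+import torch
+
+pipeline = StableDiffusionSAGPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
+).to("cuda")
+
+# sag_scale controls the strength of the self-attention guidance
+image = pipeline("a photo of an astronaut riding a horse on mars", sag_scale=0.75).images[0]
+```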
+
+## Depth2Image
+
+[Project](https://huggingface.co/stabilityai/stable-diffusion-2-depth)
+
+[Depth2Image](../api/pipelines/stable_diffusion/depth2img) is fine-tuned from Stable Diffusion to better preserve semantics for text guided image variation.
+
+It conditions on a monocular depth estimate of the original image.
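+
+A minimal usage sketch (the prompt and `strength` value are only illustrative; if no depth map is passed, the pipeline estimates one from the input image):
+
+```py
+from diffusers import StableDiffusionDepth2ImgPipeline
+from diffusers.utils import load_image
+import torch
+
+pipeline = StableDiffusionDepth2ImgPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-2-depth", torch_dtype=torch.float16
+).to("cuda")
+
+init_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")
+image = pipeline(prompt="an oil painting of a woman, highly detailed", image=init_image, strength=0.7).images[0]
+```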
+
+## MultiDiffusion Panorama
+
+[Paper](https://arxiv.org/abs/2302.08113)
+
+[MultiDiffusion Panorama](../api/pipelines/panorama) defines a new generation process over a pre-trained diffusion model. This process binds together multiple diffusion generation methods that can be readily applied to generate high quality and diverse images. Results adhere to user-provided controls, such as desired aspect ratio (e.g., panorama), and spatial guiding signals, ranging from tight segmentation masks to bounding boxes.
+MultiDiffusion Panorama allows generating high-quality images at arbitrary aspect ratios (e.g., panoramas).
+
+## Fine-tuning your own models
+
+In addition to pre-trained models, Diffusers has training scripts for fine-tuning models on user-provided data.
+
+## DreamBooth
+
+[Project](https://dreambooth.github.io/)
+
+[DreamBooth](../training/dreambooth) fine-tunes a model to teach it about a new subject. For example, a few pictures of a person can be used to generate images of that person in different styles.
+
+## Textual Inversion
+
+[Paper](https://arxiv.org/abs/2208.01618)
+
+[Textual Inversion](../training/text_inversion) fine-tunes a model to teach it about a new concept. For example, a few pictures of a style of artwork can be used to generate images in that style.
+
+## ControlNet
+
+[Paper](https://arxiv.org/abs/2302.05543)
+
+[ControlNet](../api/pipelines/controlnet) is an auxiliary network which adds an extra condition.
+There are 8 canonical pre-trained ControlNets trained on different conditionings such as edge detection, scribbles,
+depth maps, and semantic segmentations.
+
+## Prompt Weighting
+
+[Prompt weighting](../using-diffusers/weighted_prompts) is a simple technique that puts more attention weight on certain parts of the text
+input.
+
+## Custom Diffusion
+
+[Paper](https://arxiv.org/abs/2212.04488)
+
+[Custom Diffusion](../training/custom_diffusion) only fine-tunes the cross-attention maps of a pre-trained
+text-to-image diffusion model. It also allows for additionally performing Textual Inversion. It supports
+multi-concept training by design. Like DreamBooth and Textual Inversion, Custom Diffusion is also used to
+teach a pre-trained text-to-image diffusion model about new concepts to generate outputs involving the
+concept(s) of interest.
+
+## Model Editing
+
+[Paper](https://arxiv.org/abs/2303.08084)
+
+The [text-to-image model editing pipeline](../api/pipelines/model_editing) helps you mitigate some of the incorrect implicit assumptions a pre-trained text-to-image
+diffusion model might make about the subjects present in the input prompt. For example, if you prompt Stable Diffusion to generate images for "A pack of roses", the roses in the generated images
+are more likely to be red. This pipeline helps you change that assumption.
+
+## DiffEdit
+
+[Paper](https://arxiv.org/abs/2210.11427)
+
+[DiffEdit](../api/pipelines/diffedit) allows for semantic editing of input images along with
+input prompts while preserving the original input images as much as possible.
+
+## T2I-Adapter
+
+[Paper](https://arxiv.org/abs/2302.08453)
+
+[T2I-Adapter](../api/pipelines/stable_diffusion/adapter) is an auxiliary network which adds an extra condition.
+There are 8 canonical pre-trained adapters trained on different conditionings such as edge detection, sketch,
+depth maps, and semantic segmentations.
+
+## Fabric
+
+[Paper](https://arxiv.org/abs/2307.10159)
+
+[Fabric](https://github.com/huggingface/diffusers/tree/442017ccc877279bcf24fbe92f92d3d0def191b6/examples/community#stable-diffusion-fabric-pipeline) is a training-free
+approach applicable to a wide range of popular diffusion models, which exploits
+the self-attention layer present in the most widely used architectures to condition
+the diffusion process on a set of feedback images.
diff --git a/diffusers/docs/source/en/using-diffusers/controlnet.md b/diffusers/docs/source/en/using-diffusers/controlnet.md
new file mode 100644
index 0000000000000000000000000000000000000000..c50d2e96e8ed40f58890a05b1bc8f0a07551c54a
--- /dev/null
+++ b/diffusers/docs/source/en/using-diffusers/controlnet.md
@@ -0,0 +1,566 @@
+
+
+# ControlNet
+
+ControlNet is a type of model for controlling image diffusion models by conditioning the model with an additional input image. There are many types of conditioning inputs (canny edge, user sketching, human pose, depth, and more) you can use to control a diffusion model. This is hugely useful because it affords you greater control over image generation, making it easier to generate specific images without experimenting with different text prompts or denoising values as much.
+
+
+
+Check out Section 3.5 of the [ControlNet](https://huggingface.co/papers/2302.05543) paper v1 for a list of ControlNet implementations on various conditioning inputs. You can find the official Stable Diffusion ControlNet conditioned models on [lllyasviel](https://huggingface.co/lllyasviel)'s Hub profile, and more [community-trained](https://huggingface.co/models?other=stable-diffusion&other=controlnet) ones on the Hub.
+
+For Stable Diffusion XL (SDXL) ControlNet models, you can find them on the 🤗 [Diffusers](https://huggingface.co/diffusers) Hub organization, or you can browse [community-trained](https://huggingface.co/models?other=stable-diffusion-xl&other=controlnet) ones on the Hub.
+
+
+
+A ControlNet model has two sets of weights (or blocks) connected by a zero-convolution layer:
+
+- a *locked copy* keeps everything a large pretrained diffusion model has learned
+- a *trainable copy* is trained on the additional conditioning input
+
+Since the locked copy preserves the pretrained model, training and implementing a ControlNet on a new conditioning input is as fast as finetuning any other model because you aren't training the model from scratch.
+
+This guide will show you how to use ControlNet for text-to-image, image-to-image, inpainting, and more! There are many types of ControlNet conditioning inputs to choose from, but in this guide we'll only focus on several of them. Feel free to experiment with other conditioning inputs!
+
+Before you begin, make sure you have the following libraries installed:
+
+```py
+# uncomment to install the necessary libraries in Colab
+#!pip install -q diffusers transformers accelerate opencv-python
+```
+
+## Text-to-image
+
+For text-to-image, you normally pass a text prompt to the model. But with ControlNet, you can specify an additional conditioning input. Let's condition the model with a canny image, a white outline of an image on a black background. This way, the ControlNet can use the canny image as a control to guide the model to generate an image with the same outline.
+
+Load an image and use the [opencv-python](https://github.com/opencv/opencv-python) library to extract the canny image:
+
+```py
+from diffusers.utils import load_image, make_image_grid
+from PIL import Image
+import cv2
+import numpy as np
+
+original_image = load_image(
+ "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
+)
+
+image = np.array(original_image)
+
+low_threshold = 100
+high_threshold = 200
+
+image = cv2.Canny(image, low_threshold, high_threshold)
+image = image[:, :, None]
+image = np.concatenate([image, image, image], axis=2)
+canny_image = Image.fromarray(image)
+```
+
+
+[Images: original image | canny image]
+
+Next, load a ControlNet model conditioned on canny edge detection and pass it to the [`StableDiffusionControlNetPipeline`]. Use the faster [`UniPCMultistepScheduler`] and enable model offloading to speed up inference and reduce memory usage.
+
+```py
+from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
+import torch
+
+controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16, use_safetensors=True)
+pipe = StableDiffusionControlNetPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16, use_safetensors=True
+)
+
+pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+pipe.enable_model_cpu_offload()
+```
+
+Now pass your prompt and canny image to the pipeline:
+
+```py
+output = pipe(
+ "the mona lisa", image=canny_image
+).images[0]
+make_image_grid([original_image, canny_image, output], rows=1, cols=3)
+```
+
+
+## Image-to-image
+
+For image-to-image, you'd typically pass an initial image and a prompt to the pipeline to generate a new image. With ControlNet, you can pass an additional conditioning input to guide the model. Let's condition the model with a depth map, an image which contains spatial information. This way, the ControlNet can use the depth map as a control to guide the model to generate an image that preserves spatial information.
+
+You'll use the [`StableDiffusionControlNetImg2ImgPipeline`] for this task, which is different from the [`StableDiffusionControlNetPipeline`] because it allows you to pass an initial image as the starting point for the image generation process.
+
+Load an image and use the `depth-estimation` [`~transformers.Pipeline`] from 🤗 Transformers to extract the depth map of an image:
+
+```py
+import torch
+import numpy as np
+
+from transformers import pipeline
+from diffusers.utils import load_image, make_image_grid
+
+image = load_image(
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet-img2img.jpg"
+)
+
+def get_depth_map(image, depth_estimator):
+ image = depth_estimator(image)["depth"]
+ image = np.array(image)
+ image = image[:, :, None]
+ image = np.concatenate([image, image, image], axis=2)
+ detected_map = torch.from_numpy(image).float() / 255.0
+ depth_map = detected_map.permute(2, 0, 1)
+ return depth_map
+
+depth_estimator = pipeline("depth-estimation")
+depth_map = get_depth_map(image, depth_estimator).unsqueeze(0).half().to("cuda")
+```
+
+Next, load a ControlNet model conditioned on depth maps and pass it to the [`StableDiffusionControlNetImg2ImgPipeline`]. Use the faster [`UniPCMultistepScheduler`] and enable model offloading to speed up inference and reduce memory usage.
+
+```py
+from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, UniPCMultistepScheduler
+import torch
+
+controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11f1p_sd15_depth", torch_dtype=torch.float16, use_safetensors=True)
+pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16, use_safetensors=True
+)
+
+pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+pipe.enable_model_cpu_offload()
+```
+
+Now pass your prompt, initial image, and depth map to the pipeline:
+
+```py
+output = pipe(
+ "lego batman and robin", image=image, control_image=depth_map,
+).images[0]
+make_image_grid([image, output], rows=1, cols=2)
+```
+
+[Images: original image | generated image]
+
+## Inpainting
+
+For inpainting, you need an initial image, a mask image, and a prompt describing what to replace the mask with. ControlNet models allow you to add another control image to condition a model with. Let’s condition the model with an inpainting mask. This way, the ControlNet can use the inpainting mask as a control to guide the model to generate an image within the mask area.
+
+Load an initial image and a mask image:
+
+```py
+from diffusers.utils import load_image, make_image_grid
+
+init_image = load_image(
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet-inpaint.jpg"
+)
+init_image = init_image.resize((512, 512))
+
+mask_image = load_image(
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet-inpaint-mask.jpg"
+)
+mask_image = mask_image.resize((512, 512))
+make_image_grid([init_image, mask_image], rows=1, cols=2)
+```
+
+Create a function to prepare the control image from the initial and mask images. This'll create a tensor to mark the pixels in `init_image` as masked if the corresponding pixel in `mask_image` is over a certain threshold.
+
+```py
+import numpy as np
+import torch
+
+def make_inpaint_condition(image, image_mask):
+ image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
+ image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0
+
+    assert image.shape[0:2] == image_mask.shape[0:2], "image and image_mask must have the same height and width"
+    image[image_mask > 0.5] = -1.0  # set as masked pixel (a value outside the [0, 1] range marks it as masked)
+ image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ return image
+
+control_image = make_inpaint_condition(init_image, mask_image)
+```
+
+[Images: original image | mask image]
+
+Load a ControlNet model conditioned on inpainting and pass it to the [`StableDiffusionControlNetInpaintPipeline`]. Use the faster [`UniPCMultistepScheduler`] and enable model offloading to speed up inference and reduce memory usage.
+
+```py
+from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, UniPCMultistepScheduler
+
+controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16, use_safetensors=True)
+pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16, use_safetensors=True
+)
+
+pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+pipe.enable_model_cpu_offload()
+```
+
+Now pass your prompt, initial image, mask image, and control image to the pipeline:
+
+```py
+output = pipe(
+ "corgi face with large ears, detailed, pixar, animated, disney",
+ num_inference_steps=20,
+ eta=1.0,
+ image=init_image,
+ mask_image=mask_image,
+ control_image=control_image,
+).images[0]
+make_image_grid([init_image, mask_image, output], rows=1, cols=3)
+```
+
+
+## Guess mode
+
+[Guess mode](https://github.com/lllyasviel/ControlNet/discussions/188) does not require supplying a prompt to a ControlNet at all! This forces the ControlNet encoder to do its best to "guess" the contents of the input control map (depth map, pose estimation, canny edge, etc.).
+
+Guess mode adjusts the scale of the output residuals from a ControlNet by a fixed ratio depending on the block depth. The shallowest `DownBlock` corresponds to 0.1, and as the blocks get deeper, the scale increases exponentially such that the scale of the `MidBlock` output becomes 1.0.
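+
+As an illustration of that schedule (assuming 12 down-block residuals plus one mid-block residual, as in the standard Stable Diffusion UNet), the scales are exponentially spaced between 0.1 and 1.0:
+
+```py
+import torch
+
+# 13 exponentially spaced values, from 0.1 for the shallowest DownBlock to 1.0 for the MidBlock
+scales = torch.logspace(-1, 0, 13)
+print(scales[0].item(), scales[-1].item())  # ~0.1 ... 1.0
+```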
+
+
+
+Guess mode does not have any impact on prompt conditioning and you can still provide a prompt if you want.
+
+
+
+Set `guess_mode=True` in the pipeline, and it is [recommended](https://github.com/lllyasviel/ControlNet#guess-mode--non-prompt-mode) to set the `guidance_scale` value between 3.0 and 5.0.
+
+```py
+from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
+from diffusers.utils import load_image, make_image_grid
+import numpy as np
+import torch
+from PIL import Image
+import cv2
+
+controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", use_safetensors=True)
+pipe = StableDiffusionControlNetPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", controlnet=controlnet, use_safetensors=True).to("cuda")
+
+original_image = load_image("https://huggingface.co/takuma104/controlnet_dev/resolve/main/bird_512x512.png")
+
+image = np.array(original_image)
+
+low_threshold = 100
+high_threshold = 200
+
+image = cv2.Canny(image, low_threshold, high_threshold)
+image = image[:, :, None]
+image = np.concatenate([image, image, image], axis=2)
+canny_image = Image.fromarray(image)
+
+image = pipe("", image=canny_image, guess_mode=True, guidance_scale=3.0).images[0]
+make_image_grid([original_image, canny_image, image], rows=1, cols=3)
+```
+
+[Images: regular mode with prompt | guess mode without prompt]
+
+## ControlNet with Stable Diffusion XL
+
+There aren't too many ControlNet models compatible with Stable Diffusion XL (SDXL) at the moment, but we've trained two full-sized ControlNet models for SDXL conditioned on canny edge detection and depth maps. We're also experimenting with creating smaller versions of these SDXL-compatible ControlNet models so it is easier to run on resource-constrained hardware. You can find these checkpoints on the [🤗 Diffusers Hub organization](https://huggingface.co/diffusers)!
+
+Let's use an SDXL ControlNet conditioned on canny images to generate an image. Start by loading an image and preparing the canny image:
+
+```py
+from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, AutoencoderKL
+from diffusers.utils import load_image, make_image_grid
+from PIL import Image
+import cv2
+import numpy as np
+import torch
+
+original_image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
+)
+
+image = np.array(original_image)
+
+low_threshold = 100
+high_threshold = 200
+
+image = cv2.Canny(image, low_threshold, high_threshold)
+image = image[:, :, None]
+image = np.concatenate([image, image, image], axis=2)
+canny_image = Image.fromarray(image)
+make_image_grid([original_image, canny_image], rows=1, cols=2)
+```
+
+[Images: original image | canny image]
+
+Load an SDXL ControlNet model conditioned on canny edge detection and pass it to the [`StableDiffusionXLControlNetPipeline`]. You can also enable model offloading to reduce memory usage.
+
+```py
+controlnet = ControlNetModel.from_pretrained(
+ "diffusers/controlnet-canny-sdxl-1.0",
+ torch_dtype=torch.float16,
+ use_safetensors=True
+)
+vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True)
+pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ controlnet=controlnet,
+ vae=vae,
+ torch_dtype=torch.float16,
+ use_safetensors=True
+)
+pipe.enable_model_cpu_offload()
+```
+
+Now pass your prompt (and optionally a negative prompt if you're using one) and canny image to the pipeline:
+
+
+
+The [`controlnet_conditioning_scale`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/controlnet#diffusers.StableDiffusionControlNetPipeline.__call__.controlnet_conditioning_scale) parameter determines how much weight to assign to the conditioning inputs. A value of 0.5 is recommended for good generalization, but feel free to experiment with this number!
+
+
+
+```py
+prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
+negative_prompt = 'low quality, bad quality, sketches'
+
+image = pipe(
+ prompt,
+ negative_prompt=negative_prompt,
+ image=canny_image,
+ controlnet_conditioning_scale=0.5,
+).images[0]
+make_image_grid([original_image, canny_image, image], rows=1, cols=3)
+```
+
+
+You can use [`StableDiffusionXLControlNetPipeline`] in guess mode as well by setting the parameter to `True`:
+
+```py
+from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, AutoencoderKL
+from diffusers.utils import load_image, make_image_grid
+import numpy as np
+import torch
+import cv2
+from PIL import Image
+
+prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
+negative_prompt = "low quality, bad quality, sketches"
+
+original_image = load_image(
+ "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
+)
+
+controlnet = ControlNetModel.from_pretrained(
+ "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16, use_safetensors=True
+)
+vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True)
+pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, vae=vae, torch_dtype=torch.float16, use_safetensors=True
+)
+pipe.enable_model_cpu_offload()
+
+image = np.array(original_image)
+image = cv2.Canny(image, 100, 200)
+image = image[:, :, None]
+image = np.concatenate([image, image, image], axis=2)
+canny_image = Image.fromarray(image)
+
+image = pipe(
+ prompt, negative_prompt=negative_prompt, controlnet_conditioning_scale=0.5, image=canny_image, guess_mode=True,
+).images[0]
+make_image_grid([original_image, canny_image, image], rows=1, cols=3)
+```
+
+### MultiControlNet
+
+
+
+Replace the SDXL model with a model like [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) to use multiple conditioning inputs with Stable Diffusion models.
+
+
+
+You can compose multiple ControlNet conditionings from different image inputs to create a *MultiControlNet*. To get better results, it is often helpful to:
+
+1. mask conditionings such that they don't overlap (for example, mask the area of a canny image where the pose conditioning is located)
+2. experiment with the [`controlnet_conditioning_scale`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/controlnet#diffusers.StableDiffusionControlNetPipeline.__call__.controlnet_conditioning_scale) parameter to determine how much weight to assign to each conditioning input
+
+In this example, you'll combine a canny image and a human pose estimation image to generate a new image.
+
+Prepare the canny image conditioning:
+
+```py
+from diffusers.utils import load_image, make_image_grid
+from PIL import Image
+import numpy as np
+import cv2
+
+original_image = load_image(
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/landscape.png"
+)
+image = np.array(original_image)
+
+low_threshold = 100
+high_threshold = 200
+
+image = cv2.Canny(image, low_threshold, high_threshold)
+
+# zero out middle columns of image where pose will be overlaid
+zero_start = image.shape[1] // 4
+zero_end = zero_start + image.shape[1] // 2
+image[:, zero_start:zero_end] = 0
+
+image = image[:, :, None]
+image = np.concatenate([image, image, image], axis=2)
+canny_image = Image.fromarray(image)
+make_image_grid([original_image, canny_image], rows=1, cols=2)
+```
+
+[Images: original image | canny image]
+
+For human pose estimation, install [controlnet_aux](https://github.com/patrickvonplaten/controlnet_aux):
+
+```py
+# uncomment to install the necessary library in Colab
+#!pip install -q controlnet-aux
+```
+
+Prepare the human pose estimation conditioning:
+
+```py
+from controlnet_aux import OpenposeDetector
+
+openpose = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
+original_image = load_image(
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/person.png"
+)
+openpose_image = openpose(original_image)
+make_image_grid([original_image, openpose_image], rows=1, cols=2)
+```
+
+[Images: original image | human pose image]
+
+Load a list of ControlNet models that correspond to each conditioning, and pass them to the [`StableDiffusionXLControlNetPipeline`]. Use the faster [`UniPCMultistepScheduler`] and enable model offloading to reduce memory usage.
+
+```py
+from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, AutoencoderKL, UniPCMultistepScheduler
+import torch
+
+controlnets = [
+ ControlNetModel.from_pretrained(
+ "thibaud/controlnet-openpose-sdxl-1.0", torch_dtype=torch.float16
+ ),
+ ControlNetModel.from_pretrained(
+ "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16, use_safetensors=True
+ ),
+]
+
+vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True)
+pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnets, vae=vae, torch_dtype=torch.float16, use_safetensors=True
+)
+pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+pipe.enable_model_cpu_offload()
+```
+
+Now you can pass your prompt (and optionally a negative prompt if you're using one), canny image, and pose image to the pipeline:
+
+```py
+prompt = "a giant standing in a fantasy landscape, best quality"
+negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+
+generator = torch.manual_seed(1)
+
+images = [openpose_image.resize((1024, 1024)), canny_image.resize((1024, 1024))]
+
+images = pipe(
+ prompt,
+ image=images,
+ num_inference_steps=25,
+ generator=generator,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=3,
+ controlnet_conditioning_scale=[1.0, 0.8],
+).images
+make_image_grid([original_image, canny_image, openpose_image,
+ images[0].resize((512, 512)), images[1].resize((512, 512)), images[2].resize((512, 512))], rows=2, cols=3)
+```
+
+
+
+
diff --git a/diffusers/docs/source/en/using-diffusers/custom_pipeline_examples.md b/diffusers/docs/source/en/using-diffusers/custom_pipeline_examples.md
new file mode 100644
index 0000000000000000000000000000000000000000..e0d3182f3e8a75c98333cd7960e2b0037e603dc1
--- /dev/null
+++ b/diffusers/docs/source/en/using-diffusers/custom_pipeline_examples.md
@@ -0,0 +1,119 @@
+
+
+# Community pipelines
+
+[[open-in-colab]]
+
+
+
+For more context about the design choices behind community pipelines, please have a look at [this issue](https://github.com/huggingface/diffusers/issues/841).
+
+
+
+Community pipelines allow you to get creative and build your own unique pipelines to share with the community. You can find all community pipelines in the [diffusers/examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community) folder along with inference and training examples for how to use them. This guide showcases some of the community pipelines and hopefully it'll inspire you to create your own (feel free to open a PR with your own pipeline and we will merge it!).
+
+To load a community pipeline, use the `custom_pipeline` argument in [`DiffusionPipeline`] to specify one of the files in [diffusers/examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community):
+
+```py
+from diffusers import DiffusionPipeline
+
+pipe = DiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4", custom_pipeline="filename_in_the_community_folder", use_safetensors=True
+)
+```
+
+If a community pipeline doesn't work as expected, please open a GitHub issue and mention the author.
+
+You can learn more about community pipelines in the how to [load community pipelines](custom_pipeline_overview) and how to [contribute a community pipeline](contribute_pipeline) guides.
+
+## Multilingual Stable Diffusion
+
+The multilingual Stable Diffusion pipeline uses a pretrained [XLM-RoBERTa](https://huggingface.co/papluca/xlm-roberta-base-language-detection) to identify a language and the [mBART-large-50](https://huggingface.co/facebook/mbart-large-50-many-to-one-mmt) model to handle the translation. This allows you to generate images from text in 20 languages.
+
+```py
+import torch
+from diffusers import DiffusionPipeline
+from diffusers.utils import make_image_grid
+from transformers import (
+ pipeline,
+ MBart50TokenizerFast,
+ MBartForConditionalGeneration,
+)
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+device_dict = {"cuda": 0, "cpu": -1}
+
+# add language detection pipeline
+language_detection_model_ckpt = "papluca/xlm-roberta-base-language-detection"
+language_detection_pipeline = pipeline("text-classification",
+ model=language_detection_model_ckpt,
+ device=device_dict[device])
+
+# add model for language translation
+translation_tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-one-mmt")
+translation_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-one-mmt").to(device)
+
+diffuser_pipeline = DiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ custom_pipeline="multilingual_stable_diffusion",
+ detection_pipeline=language_detection_pipeline,
+ translation_model=translation_model,
+ translation_tokenizer=translation_tokenizer,
+ torch_dtype=torch.float16,
+)
+
+diffuser_pipeline.enable_attention_slicing()
+diffuser_pipeline = diffuser_pipeline.to(device)
+
+prompt = ["a photograph of an astronaut riding a horse",
+ "Una casa en la playa",
+ "Ein Hund, der Orange isst",
+ "Un restaurant parisien"]
+
+images = diffuser_pipeline(prompt).images
+make_image_grid(images, rows=2, cols=2)
+```
+
+## MagicMix
+
+[MagicMix](https://huggingface.co/papers/2210.16056) is a pipeline that can mix an image and text prompt to generate a new image that preserves the image structure. The `mix_factor` determines how much influence the prompt has on the layout generation, `kmin` controls the number of steps during the content generation process, and `kmax` determines how much information is kept in the layout of the original image.
+
+```py
+from diffusers import DiffusionPipeline, DDIMScheduler
+from diffusers.utils import load_image, make_image_grid
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ custom_pipeline="magic_mix",
+ scheduler=DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler"),
+).to('cuda')
+
+img = load_image("https://user-images.githubusercontent.com/59410571/209578593-141467c7-d831-4792-8b9a-b17dc5e47816.jpg")
+mix_img = pipeline(img, prompt="bed", kmin=0.3, kmax=0.5, mix_factor=0.5)
+make_image_grid([img, mix_img], rows=1, cols=2)
+```
+
+[Images: original image | image and text prompt mix]
+
diff --git a/diffusers/docs/source/en/using-diffusers/custom_pipeline_overview.md b/diffusers/docs/source/en/using-diffusers/custom_pipeline_overview.md
new file mode 100644
index 0000000000000000000000000000000000000000..0f842c1b5b5004b859979370ed7c82ba633e5c80
--- /dev/null
+++ b/diffusers/docs/source/en/using-diffusers/custom_pipeline_overview.md
@@ -0,0 +1,189 @@
+
+
+# Load community pipelines and components
+
+[[open-in-colab]]
+
+## Community pipelines
+
+Community pipelines are any [`DiffusionPipeline`] class that is different from the original implementation as specified in its paper (for example, the [`StableDiffusionControlNetPipeline`] corresponds to the [Text-to-Image Generation with ControlNet Conditioning](https://arxiv.org/abs/2302.05543) paper). They provide additional functionality or extend the original implementation of a pipeline.
+
+There are many cool community pipelines like [Speech to Image](https://github.com/huggingface/diffusers/tree/main/examples/community#speech-to-image) or [Composable Stable Diffusion](https://github.com/huggingface/diffusers/tree/main/examples/community#composable-stable-diffusion), and you can find all the official community pipelines [here](https://github.com/huggingface/diffusers/tree/main/examples/community).
+
+To load any community pipeline on the Hub, pass the repository id of the community pipeline to the `custom_pipeline` argument along with the model repository where you'd like to load the pipeline weights and components from. For example, the code below loads a dummy pipeline from [`hf-internal-testing/diffusers-dummy-pipeline`](https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py) and the pipeline weights and components from [`google/ddpm-cifar10-32`](https://huggingface.co/google/ddpm-cifar10-32):
+
+
+
+🔒 By loading a community pipeline from the Hugging Face Hub, you are trusting that the code you are loading is safe. Make sure to inspect the code online before loading and running it automatically!
+
+
+
+```py
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline", use_safetensors=True
+)
+```
+
+Loading an official community pipeline is similar, except you can mix loading the weights from an official repository id with passing pipeline components directly. The example below loads the community [CLIP Guided Stable Diffusion](https://github.com/huggingface/diffusers/tree/main/examples/community#clip-guided-stable-diffusion) pipeline and passes the CLIP model components directly to it:
+
+```py
+from diffusers import DiffusionPipeline
+from transformers import CLIPImageProcessor, CLIPModel
+
+clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
+
+feature_extractor = CLIPImageProcessor.from_pretrained(clip_model_id)
+clip_model = CLIPModel.from_pretrained(clip_model_id)
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+ custom_pipeline="clip_guided_stable_diffusion",
+ clip_model=clip_model,
+ feature_extractor=feature_extractor,
+ use_safetensors=True,
+)
+```
+
+For more information about community pipelines, take a look at the [Community pipelines](custom_pipeline_examples) guide for how to use them, and if you're interested in adding a community pipeline, check out the [How to contribute a community pipeline](contribute_pipeline) guide!
+
+## Community components
+
+Community components let you build pipelines with customized components that are not part of Diffusers. If your pipeline has custom components that Diffusers doesn't already support, you need to provide their implementations as Python modules. These customized components could be a VAE, UNet, or scheduler. In most cases, the text encoder is imported from the Transformers library. The pipeline code itself can also be customized.
+
+This section shows how to use community components to build a community pipeline.
+
+You'll use the [showlab/show-1-base](https://huggingface.co/showlab/show-1-base) pipeline checkpoint as an example, so let's start by loading the components:
+
+1. Import and load the text encoder from Transformers:
+
+```python
+from transformers import T5Tokenizer, T5EncoderModel
+
+pipe_id = "showlab/show-1-base"
+tokenizer = T5Tokenizer.from_pretrained(pipe_id, subfolder="tokenizer")
+text_encoder = T5EncoderModel.from_pretrained(pipe_id, subfolder="text_encoder")
+```
+
+2. Load a scheduler:
+
+```python
+from diffusers import DPMSolverMultistepScheduler
+
+scheduler = DPMSolverMultistepScheduler.from_pretrained(pipe_id, subfolder="scheduler")
+```
+
+3. Load an image processor:
+
+```python
+from transformers import CLIPFeatureExtractor
+
+feature_extractor = CLIPFeatureExtractor.from_pretrained(pipe_id, subfolder="feature_extractor")
+```
+
+
+
+In steps 4 and 5, the custom [UNet](https://github.com/showlab/Show-1/blob/main/showone/models/unet_3d_condition.py) and [pipeline](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/pipeline_t2v_base_pixel.py) implementations must match the format shown in their files for this example to work.
+
+
+
+4. Now you'll load a [custom UNet](https://github.com/showlab/Show-1/blob/main/showone/models/unet_3d_condition.py), which in this example has already been implemented in the `showone_unet_3d_condition.py` [script](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/unet/showone_unet_3d_condition.py) for your convenience. You'll notice the `UNet3DConditionModel` class name is changed to `ShowOneUNet3DConditionModel` because [`UNet3DConditionModel`] already exists in Diffusers. Any components needed for the `ShowOneUNet3DConditionModel` class should be placed in the `showone_unet_3d_condition.py` script.
+
+Once this is done, you can initialize the UNet:
+
+```python
+from showone_unet_3d_condition import ShowOneUNet3DConditionModel
+
+unet = ShowOneUNet3DConditionModel.from_pretrained(pipe_id, subfolder="unet")
+```
+
+5. Finally, you'll load the custom pipeline code. For this example, it has already been created for you in the `pipeline_t2v_base_pixel.py` [script](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/pipeline_t2v_base_pixel.py). This script contains a custom `TextToVideoIFPipeline` class for generating videos from text. Just like the custom UNet, any code needed for the custom pipeline to work should go in the `pipeline_t2v_base_pixel.py` script.
+
+Once everything is in place, you can initialize the `TextToVideoIFPipeline` with the `ShowOneUNet3DConditionModel`:
+
+```python
+from pipeline_t2v_base_pixel import TextToVideoIFPipeline
+import torch
+
+pipeline = TextToVideoIFPipeline(
+ unet=unet,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ scheduler=scheduler,
+ feature_extractor=feature_extractor
+)
+pipeline = pipeline.to(device="cuda")
+pipeline.torch_dtype = torch.float16
+```
+
+Push the pipeline to the Hub to share with the community!
+
+```python
+pipeline.push_to_hub("custom-t2v-pipeline")
+```
+
+After the pipeline is successfully pushed, you need to make a couple of changes:
+
+1. Change the `_class_name` attribute in [`model_index.json`](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/model_index.json#L2) to `"pipeline_t2v_base_pixel"` and `"TextToVideoIFPipeline"`.
+2. Upload `showone_unet_3d_condition.py` to the `unet` [directory](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/unet/showone_unet_3d_condition.py).
+3. Upload `pipeline_t2v_base_pixel.py` to the root of the pipeline [repository](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/pipeline_t2v_base_pixel.py). Both uploads can also be scripted, as shown in the sketch below.
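+
+The two uploads can be scripted with the `huggingface_hub` client. The snippet below is only a sketch: it assumes you are already logged in (for example via `huggingface-cli login`), and the repository id is a placeholder for the repository created by `push_to_hub` above.
+
+```python
+from huggingface_hub import upload_file
+
+repo_id = "<your-username>/custom-t2v-pipeline"  # placeholder, use your own repository id
+
+# the custom UNet implementation goes next to the UNet weights
+upload_file(
+    path_or_fileobj="showone_unet_3d_condition.py",
+    path_in_repo="unet/showone_unet_3d_condition.py",
+    repo_id=repo_id,
+)
+
+# the custom pipeline implementation goes in the root of the repository
+upload_file(
+    path_or_fileobj="pipeline_t2v_base_pixel.py",
+    path_in_repo="pipeline_t2v_base_pixel.py",
+    repo_id=repo_id,
+)
+```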
+
+To run inference, add the `trust_remote_code` argument when initializing the pipeline; it tells Diffusers to fetch and run the custom pipeline and UNet code from the repository.
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+pipeline = DiffusionPipeline.from_pretrained(
+    "<your-username>/custom-t2v-pipeline", trust_remote_code=True, torch_dtype=torch.float16  # placeholder, use the repository you pushed above
+).to("cuda")
+
+prompt = "hello"
+
+# Text embeds
+prompt_embeds, negative_embeds = pipeline.encode_prompt(prompt)
+
+# Keyframes generation (8x64x40, 2fps)
+video_frames = pipeline(
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_embeds,
+ num_frames=8,
+ height=40,
+ width=64,
+ num_inference_steps=2,
+ guidance_scale=9.0,
+ output_type="pt"
+).frames
+```
+
+As an additional reference example, you can refer to the repository structure of [stabilityai/japanese-stable-diffusion-xl](https://huggingface.co/stabilityai/japanese-stable-diffusion-xl/), that makes use of the `trust_remote_code` feature:
+
+```python
+
+from diffusers import DiffusionPipeline
+import torch
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/japanese-stable-diffusion-xl", trust_remote_code=True
+)
+pipeline.to("cuda")
+
+# if using torch < 2.0
+# pipeline.enable_xformers_memory_efficient_attention()
+
+prompt = "柴犬、カラフルアート"
+
+image = pipeline(prompt=prompt).images[0]
+
+```
\ No newline at end of file
diff --git a/diffusers/docs/source/en/using-diffusers/depth2img.md b/diffusers/docs/source/en/using-diffusers/depth2img.md
new file mode 100644
index 0000000000000000000000000000000000000000..84c613b0dade1da18d7fd0892f919c0d6e6b4f4b
--- /dev/null
+++ b/diffusers/docs/source/en/using-diffusers/depth2img.md
@@ -0,0 +1,46 @@
+
+
+# Text-guided depth-to-image generation
+
+[[open-in-colab]]
+
+The [`StableDiffusionDepth2ImgPipeline`] lets you pass a text prompt and an initial image to condition the generation of new images. In addition, you can also pass a `depth_map` to preserve the image structure. If no `depth_map` is provided, the pipeline automatically predicts the depth via an integrated [depth-estimation model](https://github.com/isl-org/MiDaS).
+
+Start by creating an instance of the [`StableDiffusionDepth2ImgPipeline`]:
+
+```python
+import torch
+from diffusers import StableDiffusionDepth2ImgPipeline
+from diffusers.utils import load_image, make_image_grid
+
+pipeline = StableDiffusionDepth2ImgPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-2-depth",
+ torch_dtype=torch.float16,
+ use_safetensors=True,
+).to("cuda")
+```
+
+Now pass your prompt to the pipeline. You can also pass a `negative_prompt` to prevent certain words from guiding how an image is generated:
+
+```python
+url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+init_image = load_image(url)
+prompt = "two tigers"
+negative_prompt = "bad, deformed, ugly, bad anatomy"
+image = pipeline(prompt=prompt, image=init_image, negative_prompt=negative_prompt, strength=0.7).images[0]
+make_image_grid([init_image, image], rows=1, cols=2)
+```
+
+| Input | Output |
+|-------|--------|
+| *(initial image)* | *(generated image)* |
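+
+If you want more control over the structure, you can also compute the depth map yourself and pass it through the `depth_map` argument. The snippet below is a rough sketch using the 🤗 Transformers depth-estimation pipeline with the `Intel/dpt-hybrid-midas` checkpoint; the exact tensor shape expected by `depth_map` (a batched single-channel map) is an assumption, so adjust it if needed.
+
+```python
+from transformers import pipeline as transformers_pipeline
+
+# estimate a depth map for the initial image (assumed to return a torch tensor)
+depth_estimator = transformers_pipeline("depth-estimation", model="Intel/dpt-hybrid-midas")
+depth = depth_estimator(init_image)["predicted_depth"]
+if depth.ndim == 2:
+    depth = depth.unsqueeze(0)  # add a batch dimension if the estimator returns (H, W)
+
+image = pipeline(
+    prompt=prompt,
+    image=init_image,
+    negative_prompt=negative_prompt,
+    depth_map=depth,
+    strength=0.7,
+).images[0]
+make_image_grid([init_image, image], rows=1, cols=2)
+```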
diff --git a/diffusers/docs/source/en/using-diffusers/diffedit.md b/diffusers/docs/source/en/using-diffusers/diffedit.md
new file mode 100644
index 0000000000000000000000000000000000000000..1c3793177ce1a87b15ed3204a969ce5e4346d9ff
--- /dev/null
+++ b/diffusers/docs/source/en/using-diffusers/diffedit.md
@@ -0,0 +1,285 @@
+
+
+# DiffEdit
+
+[[open-in-colab]]
+
+Image editing typically requires providing a mask of the area to be edited. DiffEdit automatically generates the mask for you based on a text query, making it easier overall to create a mask without image editing software. The DiffEdit algorithm works in three steps:
+
+1. the diffusion model denoises an image conditioned on some query text and reference text which produces different noise estimates for different areas of the image; the difference is used to infer a mask to identify which area of the image needs to be changed to match the query text
+2. the input image is encoded into latent space with DDIM
+3. the latents are decoded with the diffusion model conditioned on the text query, using the mask as a guide such that pixels outside the mask remain the same as in the input image
+
+This guide will show you how to use DiffEdit to edit images without manually creating a mask.
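+
+In Diffusers, these three steps roughly map onto three calls on the [`StableDiffusionDiffEditPipeline`]. The compact sketch below is just an outline; the rest of this guide fills in the details:
+
+```py
+# 1. infer a mask from the difference between source- and target-conditioned noise estimates
+mask_image = pipeline.generate_mask(image=raw_image, source_prompt=source_prompt, target_prompt=target_prompt)
+
+# 2. invert the input image into partially noised latents with the DDIM inverse scheduler
+inv_latents = pipeline.invert(prompt=source_prompt, image=raw_image).latents
+
+# 3. denoise conditioned on the target prompt, keeping pixels outside the mask unchanged
+image = pipeline(
+    prompt=target_prompt,
+    mask_image=mask_image,
+    image_latents=inv_latents,
+    negative_prompt=source_prompt,
+).images[0]
+```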
+
+Before you begin, make sure you have the following libraries installed:
+
+```py
+# uncomment to install the necessary libraries in Colab
+#!pip install -q diffusers transformers accelerate
+```
+
+The [`StableDiffusionDiffEditPipeline`] requires an image mask and a set of partially inverted latents. The image mask is generated from the [`~StableDiffusionDiffEditPipeline.generate_mask`] function, and includes two parameters, `source_prompt` and `target_prompt`. These parameters determine what to edit in the image. For example, if you want to change a bowl of *fruits* to a bowl of *pears*, then:
+
+```py
+source_prompt = "a bowl of fruits"
+target_prompt = "a bowl of pears"
+```
+
+The partially inverted latents are generated from the [`~StableDiffusionDiffEditPipeline.invert`] function, and it is generally a good idea to include a `prompt` or *caption* describing the image to help guide the inverse latent sampling process. The caption can often be your `source_prompt`, but feel free to experiment with other text descriptions!
+
+Let's load the pipeline, scheduler, inverse scheduler, and enable some optimizations to reduce memory usage:
+
+```py
+import torch
+from diffusers import DDIMScheduler, DDIMInverseScheduler, StableDiffusionDiffEditPipeline
+
+pipeline = StableDiffusionDiffEditPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-2-1",
+ torch_dtype=torch.float16,
+ safety_checker=None,
+ use_safetensors=True,
+)
+pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
+pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config)
+pipeline.enable_model_cpu_offload()
+pipeline.enable_vae_slicing()
+```
+
+Load the image to edit:
+
+```py
+from diffusers.utils import load_image, make_image_grid
+
+img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png"
+raw_image = load_image(img_url).resize((768, 768))
+raw_image
+```
+
+Use the [`~StableDiffusionDiffEditPipeline.generate_mask`] function to generate the image mask. You'll need to pass it the `source_prompt` and `target_prompt` to specify what to edit in the image:
+
+```py
+from PIL import Image
+
+source_prompt = "a bowl of fruits"
+target_prompt = "a basket of pears"
+mask_image = pipeline.generate_mask(
+ image=raw_image,
+ source_prompt=source_prompt,
+ target_prompt=target_prompt,
+)
+Image.fromarray((mask_image.squeeze()*255).astype("uint8"), "L").resize((768, 768))
+```
+
+Next, create the inverted latents and pass it a caption describing the image:
+
+```py
+inv_latents = pipeline.invert(prompt=source_prompt, image=raw_image).latents
+```
+
+Finally, pass the image mask and inverted latents to the pipeline. The `target_prompt` becomes the `prompt` now, and the `source_prompt` is used as the `negative_prompt`:
+
+```py
+output_image = pipeline(
+ prompt=target_prompt,
+ mask_image=mask_image,
+ image_latents=inv_latents,
+ negative_prompt=source_prompt,
+).images[0]
+mask_image = Image.fromarray((mask_image.squeeze()*255).astype("uint8"), "L").resize((768, 768))
+make_image_grid([raw_image, mask_image, output_image], rows=1, cols=3)
+```
+
+Left: original image. Right: edited image.
+
+## Generate source and target embeddings
+
+The source and target embeddings can be automatically generated with the [Flan-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5) model instead of creating them manually.
+
+Load the Flan-T5 model and tokenizer from the 🤗 Transformers library:
+
+```py
+import torch
+from transformers import AutoTokenizer, T5ForConditionalGeneration
+
+tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
+model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large", device_map="auto", torch_dtype=torch.float16)
+```
+
+Provide some initial text to prompt the model to generate the source and target prompts.
+
+```py
+source_concept = "bowl"
+target_concept = "basket"
+
+source_text = (
+    f"Provide a caption for images containing a {source_concept}. "
+    "The captions should be in English and should be no longer than 150 characters."
+)
+
+target_text = (
+    f"Provide a caption for images containing a {target_concept}. "
+    "The captions should be in English and should be no longer than 150 characters."
+)
+```
+
+Next, create a utility function to generate the prompts:
+
+```py
+@torch.no_grad()
+def generate_prompts(input_prompt):
+ input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids.to("cuda")
+
+ outputs = model.generate(
+ input_ids, temperature=0.8, num_return_sequences=16, do_sample=True, max_new_tokens=128, top_k=10
+ )
+ return tokenizer.batch_decode(outputs, skip_special_tokens=True)
+
+source_prompts = generate_prompts(source_text)
+target_prompts = generate_prompts(target_text)
+print(source_prompts)
+print(target_prompts)
+```
+
+
+
+Check out the [generation strategy](https://huggingface.co/docs/transformers/main/en/generation_strategies) guide if you're interested in learning more about strategies for generating different quality text.
+
+
+
+Load the text encoder model used by the [`StableDiffusionDiffEditPipeline`] to encode the text. You'll use the text encoder to compute the text embeddings:
+
+```py
+import torch
+from diffusers import StableDiffusionDiffEditPipeline
+
+pipeline = StableDiffusionDiffEditPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16, use_safetensors=True
+)
+pipeline.enable_model_cpu_offload()
+pipeline.enable_vae_slicing()
+
+@torch.no_grad()
+def embed_prompts(sentences, tokenizer, text_encoder, device="cuda"):
+ embeddings = []
+ for sent in sentences:
+ text_inputs = tokenizer(
+ sent,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=None)[0]
+ embeddings.append(prompt_embeds)
+ return torch.concatenate(embeddings, dim=0).mean(dim=0).unsqueeze(0)
+
+source_embeds = embed_prompts(source_prompts, pipeline.tokenizer, pipeline.text_encoder)
+target_embeds = embed_prompts(target_prompts, pipeline.tokenizer, pipeline.text_encoder)
+```
+
+Finally, pass the embeddings to the [`~StableDiffusionDiffEditPipeline.generate_mask`] and [`~StableDiffusionDiffEditPipeline.invert`] functions, and pipeline to generate the image:
+
+```diff
+ from diffusers import DDIMInverseScheduler, DDIMScheduler
+ from diffusers.utils import load_image, make_image_grid
+ from PIL import Image
+
+ pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
+ pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config)
+
+ img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png"
+ raw_image = load_image(img_url).resize((768, 768))
+
+ mask_image = pipeline.generate_mask(
+ image=raw_image,
+- source_prompt=source_prompt,
+- target_prompt=target_prompt,
++ source_prompt_embeds=source_embeds,
++ target_prompt_embeds=target_embeds,
+ )
+
+ inv_latents = pipeline.invert(
+- prompt=source_prompt,
++ prompt_embeds=source_embeds,
+ image=raw_image,
+ ).latents
+
+ output_image = pipeline(
+ mask_image=mask_image,
+ image_latents=inv_latents,
+- prompt=target_prompt,
+- negative_prompt=source_prompt,
++ prompt_embeds=target_embeds,
++ negative_prompt_embeds=source_embeds,
+ ).images[0]
+ mask_image = Image.fromarray((mask_image.squeeze()*255).astype("uint8"), "L")
+ make_image_grid([raw_image, mask_image, output_image], rows=1, cols=3)
+```
+
+## Generate a caption for inversion
+
+While you can use the `source_prompt` as a caption to help generate the partially inverted latents, you can also use the [BLIP](https://huggingface.co/docs/transformers/model_doc/blip) model to automatically generate a caption.
+
+Load the BLIP model and processor from the 🤗 Transformers library:
+
+```py
+import torch
+from transformers import BlipForConditionalGeneration, BlipProcessor
+
+processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float16, low_cpu_mem_usage=True)
+```
+
+Create a utility function to generate a caption from the input image:
+
+```py
+@torch.no_grad()
+def generate_caption(images, caption_generator, caption_processor):
+ text = "a photograph of"
+
+ inputs = caption_processor(images, text, return_tensors="pt").to(device="cuda", dtype=caption_generator.dtype)
+ caption_generator.to("cuda")
+ outputs = caption_generator.generate(**inputs, max_new_tokens=128)
+
+ # offload caption generator
+ caption_generator.to("cpu")
+
+ caption = caption_processor.batch_decode(outputs, skip_special_tokens=True)[0]
+ return caption
+```
+
+Load an input image and generate a caption for it using the `generate_caption` function:
+
+```py
+from diffusers.utils import load_image
+
+img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png"
+raw_image = load_image(img_url).resize((768, 768))
+caption = generate_caption(raw_image, model, processor)
+```
+
+
+Generated caption: "a photograph of a bowl of fruit on a table"
+
+Now you can drop the caption into the [`~StableDiffusionDiffEditPipeline.invert`] function to generate the partially inverted latents!
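+
+For example, reusing the pipeline and image from earlier, the generated caption simply takes the place of the manual prompt:
+
+```py
+inv_latents = pipeline.invert(prompt=caption, image=raw_image).latents
+```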
diff --git a/diffusers/docs/source/en/using-diffusers/distilled_sd.md b/diffusers/docs/source/en/using-diffusers/distilled_sd.md
new file mode 100644
index 0000000000000000000000000000000000000000..2dd96d98861d4476b00dde841821ee5abc038387
--- /dev/null
+++ b/diffusers/docs/source/en/using-diffusers/distilled_sd.md
@@ -0,0 +1,133 @@
+
+
+# Distilled Stable Diffusion inference
+
+[[open-in-colab]]
+
+Stable Diffusion inference can be a computationally intensive process because it must iteratively denoise the latents to generate an image. To reduce the computational burden, you can use a *distilled* version of the Stable Diffusion model from [Nota AI](https://huggingface.co/nota-ai). The distilled version of their Stable Diffusion model eliminates some of the residual and attention blocks from the UNet, reducing the model size by 51% and improving latency on CPU/GPU by 43%.
+
+
+
+Read this [blog post](https://huggingface.co/blog/sd_distillation) to learn more about how knowledge distillation training works to produce a faster, smaller, and cheaper generative model.
+
+
+
+Let's load the distilled Stable Diffusion model and compare it against the original Stable Diffusion model:
+
+```py
+from diffusers import StableDiffusionPipeline
+import torch
+
+distilled = StableDiffusionPipeline.from_pretrained(
+ "nota-ai/bk-sdm-small", torch_dtype=torch.float16, use_safetensors=True,
+).to("cuda")
+
+original = StableDiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16, use_safetensors=True,
+).to("cuda")
+```
+
+Given a prompt, get the inference time for the original model:
+
+```py
+import time
+
+seed = 2023
+generator = torch.manual_seed(seed)
+
+NUM_ITERS_TO_RUN = 3
+NUM_INFERENCE_STEPS = 25
+NUM_IMAGES_PER_PROMPT = 4
+
+prompt = "a golden vase with different flowers"
+
+start = time.time_ns()
+for _ in range(NUM_ITERS_TO_RUN):
+ images = original(
+ prompt,
+ num_inference_steps=NUM_INFERENCE_STEPS,
+ generator=generator,
+ num_images_per_prompt=NUM_IMAGES_PER_PROMPT
+ ).images
+end = time.time_ns()
+original_sd = f"{(end - start) / 1e6:.1f}"
+
+print(f"Execution time -- {original_sd} ms\n")
+"Execution time -- 45781.5 ms"
+```
+
+Time the distilled model inference:
+
+```py
+start = time.time_ns()
+for _ in range(NUM_ITERS_TO_RUN):
+ images = distilled(
+ prompt,
+ num_inference_steps=NUM_INFERENCE_STEPS,
+ generator=generator,
+ num_images_per_prompt=NUM_IMAGES_PER_PROMPT
+ ).images
+end = time.time_ns()
+
+distilled_sd = f"{(end - start) / 1e6:.1f}"
+print(f"Execution time -- {distilled_sd} ms\n")
+"Execution time -- 29884.2 ms"
+```
+
+Left: original Stable Diffusion (45781.5 ms). Right: distilled Stable Diffusion (29884.2 ms).
+
+## Tiny AutoEncoder
+
+To speed inference up even more, use a tiny distilled version of the Stable Diffusion VAE ([Tiny AutoEncoder](https://huggingface.co/sayakpaul/taesd-diffusers)) to decode the latents into images. Replace the VAE in the distilled Stable Diffusion model with the tiny VAE:
+
+```py
+from diffusers import AutoencoderTiny
+
+distilled.vae = AutoencoderTiny.from_pretrained(
+ "sayakpaul/taesd-diffusers", torch_dtype=torch.float16, use_safetensors=True,
+).to("cuda")
+```
+
+Time the distilled model and distilled VAE inference:
+
+```py
+start = time.time_ns()
+for _ in range(NUM_ITERS_TO_RUN):
+ images = distilled(
+ prompt,
+ num_inference_steps=NUM_INFERENCE_STEPS,
+ generator=generator,
+ num_images_per_prompt=NUM_IMAGES_PER_PROMPT
+ ).images
+end = time.time_ns()
+
+distilled_tiny_sd = f"{(end - start) / 1e6:.1f}"
+print(f"Execution time -- {distilled_tiny_sd} ms\n")
+"Execution time -- 27165.7 ms"
+```
+
+
diff --git a/diffusers/docs/source/en/using-diffusers/freeu.md b/diffusers/docs/source/en/using-diffusers/freeu.md
new file mode 100644
index 0000000000000000000000000000000000000000..6e8f5773cd75356754a5d548409c53aaeae6f424
--- /dev/null
+++ b/diffusers/docs/source/en/using-diffusers/freeu.md
@@ -0,0 +1,135 @@
+
+
+# Improve generation quality with FreeU
+
+[[open-in-colab]]
+
+The UNet is responsible for denoising during the reverse diffusion process, and there are two distinct features in its architecture:
+
+1. Backbone features primarily contribute to the denoising process
+2. Skip features mainly introduce high-frequency features into the decoder module and can make the network overlook the semantics in the backbone features
+
+However, the skip connection can sometimes introduce unnatural image details. [FreeU](https://hf.co/papers/2309.11497) is a technique for improving image quality by rebalancing the contributions from the UNet’s skip connections and backbone feature maps.
+
+FreeU is applied during inference and it does not require any additional training. The technique works for different tasks such as text-to-image, image-to-image, and text-to-video.
+
+In this guide, you will apply FreeU to the [`StableDiffusionPipeline`], [`StableDiffusionXLPipeline`], and [`TextToVideoSDPipeline`]. You need to install Diffusers from source to run the examples below.
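+
+For reference, installing from source typically looks like this (shown with the commented pip style used elsewhere in these guides):
+
+```py
+# uncomment to install 🤗 Diffusers from source in Colab
+#!pip install git+https://github.com/huggingface/diffusers.git
+```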
+
+## StableDiffusionPipeline
+
+Load the pipeline:
+
+```py
+from diffusers import DiffusionPipeline
+import torch
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, safety_checker=None
+).to("cuda")
+```
+
+Then enable the FreeU mechanism with the FreeU-specific hyperparameters. These values are scaling factors for the backbone and skip features.
+
+```py
+pipeline.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
+```
+
+The values above are from the official FreeU [code repository](https://github.com/ChenyangSi/FreeU) where you can also find [reference hyperparameters](https://github.com/ChenyangSi/FreeU#range-for-more-parameters) for different models.
+
+
+
+Disable the FreeU mechanism by calling `disable_freeu()` on a pipeline.
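+
+For example, to switch back to the default behavior on the same pipeline object:
+
+```py
+pipeline.disable_freeu()
+```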
+
+
+
+And then run inference:
+
+```py
+prompt = "A squirrel eating a burger"
+seed = 2023
+image = pipeline(prompt, generator=torch.manual_seed(seed)).images[0]
+image
+```
+
+The figure below compares the results without and with FreeU for the same `prompt` and `seed` used above:
+
+![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/freeu/sdv1_5_freeu.jpg)
+
+
+Let's see how Stable Diffusion 2 results are impacted:
+
+```py
+from diffusers import DiffusionPipeline
+import torch
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16, safety_checker=None
+).to("cuda")
+
+prompt = "A squirrel eating a burger"
+seed = 2023
+
+pipeline.enable_freeu(s1=0.9, s2=0.2, b1=1.1, b2=1.2)
+image = pipeline(prompt, generator=torch.manual_seed(seed)).images[0]
+image
+```
+
+![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/freeu/sdv2_1_freeu.jpg)
+
+## Stable Diffusion XL
+
+Finally, let's take a look at how FreeU affects Stable Diffusion XL results:
+
+```py
+from diffusers import DiffusionPipeline
+import torch
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16,
+).to("cuda")
+
+prompt = "A squirrel eating a burger"
+seed = 2023
+
+# Comes from
+# https://wandb.ai/nasirk24/UNET-FreeU-SDXL/reports/FreeU-SDXL-Optimal-Parameters--Vmlldzo1NDg4NTUw
+pipeline.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
+image = pipeline(prompt, generator=torch.manual_seed(seed)).images[0]
+image
+```
+
+![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/freeu/sdxl_freeu.jpg)
+
+## Text-to-video generation
+
+FreeU can also be used to improve video quality:
+
+```python
+from diffusers import DiffusionPipeline
+from diffusers.utils import export_to_video
+import torch
+
+model_id = "cerspense/zeroscope_v2_576w"
+pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+
+prompt = "an astronaut riding a horse on mars"
+seed = 2023
+
+# The values come from
+# https://github.com/lyn-rgb/FreeU_Diffusers#video-pipelines
+pipe.enable_freeu(b1=1.2, b2=1.4, s1=0.9, s2=0.2)
+video_frames = pipe(prompt, height=320, width=576, num_frames=30, generator=torch.manual_seed(seed)).frames
+export_to_video(video_frames, "astronaut_rides_horse.mp4")
+```
+
+Thanks to [kadirnar](https://github.com/kadirnar/) for helping to integrate the feature, and to [justindujardin](https://github.com/justindujardin) for the helpful discussions.
diff --git a/diffusers/docs/source/en/using-diffusers/img2img.md b/diffusers/docs/source/en/using-diffusers/img2img.md
new file mode 100644
index 0000000000000000000000000000000000000000..6014d87b79064fd307a150cd379a4e3568822f9a
--- /dev/null
+++ b/diffusers/docs/source/en/using-diffusers/img2img.md
@@ -0,0 +1,605 @@
+
+
+# Image-to-image
+
+[[open-in-colab]]
+
+Image-to-image is similar to [text-to-image](conditional_image_generation), but in addition to a prompt, you can also pass an initial image as a starting point for the diffusion process. The initial image is encoded to latent space and noise is added to it. Then the latent diffusion model takes a prompt and the noisy latent image, predicts the added noise, and removes the predicted noise from the initial latent image to get the new latent image. Lastly, a decoder decodes the new latent image back into an image.
+
+With 🤗 Diffusers, this is as easy as 1-2-3:
+
+1. Load a checkpoint into the [`AutoPipelineForImage2Image`] class; this pipeline automatically handles loading the correct pipeline class based on the checkpoint:
+
+```py
+import torch
+from diffusers import AutoPipelineForImage2Image
+from diffusers.utils import load_image, make_image_grid
+
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, use_safetensors=True
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+```
+
+
+
+You'll notice throughout the guide, we use [`~DiffusionPipeline.enable_model_cpu_offload`] and [`~DiffusionPipeline.enable_xformers_memory_efficient_attention`], to save memory and increase inference speed. If you're using PyTorch 2.0, then you don't need to call [`~DiffusionPipeline.enable_xformers_memory_efficient_attention`] on your pipeline because it'll already be using PyTorch 2.0's native [scaled-dot product attention](../optimization/torch2.0#scaled-dot-product-attention).
+
+
+
+2. Load an image to pass to the pipeline:
+
+```py
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png")
+```
+
+3. Pass a prompt and image to the pipeline to generate an image:
+
+```py
+prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k"
+image = pipeline(prompt, image=init_image).images[0]
+make_image_grid([init_image, image], rows=1, cols=2)
+```
+
+Left: initial image. Right: generated image.
+
+## Popular models
+
+The most popular image-to-image models are [Stable Diffusion v1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5), [Stable Diffusion XL (SDXL)](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0), and [Kandinsky 2.2](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder). The results from the Stable Diffusion and Kandinsky models vary due to their architecture differences and training process; you can generally expect SDXL to produce higher quality images than Stable Diffusion v1.5. Let's take a quick look at how to use each of these models and compare their results.
+
+### Stable Diffusion v1.5
+
+Stable Diffusion v1.5 is a latent diffusion model initialized from an earlier checkpoint, and further finetuned for 595K steps on 512x512 images. To use this pipeline for image-to-image, you'll need to prepare an initial image to pass to the pipeline. Then you can pass a prompt and the image to the pipeline to generate a new image:
+
+```py
+import torch
+from diffusers import AutoPipelineForImage2Image
+from diffusers.utils import make_image_grid, load_image
+
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+# prepare image
+url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
+init_image = load_image(url)
+
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+
+# pass prompt and image to pipeline
+image = pipeline(prompt, image=init_image).images[0]
+make_image_grid([init_image, image], rows=1, cols=2)
+```
+
+Left: initial image. Right: generated image.
+
+### Stable Diffusion XL (SDXL)
+
+SDXL is a more powerful version of the Stable Diffusion model. It uses a larger base model, and an additional refiner model to increase the quality of the base model's output. Read the [SDXL](sdxl) guide for a more detailed walkthrough of how to use this model, and other techniques it uses to produce high quality images.
+
+```py
+import torch
+from diffusers import AutoPipelineForImage2Image
+from diffusers.utils import make_image_grid, load_image
+
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+# prepare image
+url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-sdxl-init.png"
+init_image = load_image(url)
+
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+
+# pass prompt and image to pipeline
+image = pipeline(prompt, image=init_image, strength=0.5).images[0]
+make_image_grid([init_image, image], rows=1, cols=2)
+```
+
+Left: initial image. Right: generated image.
+
+### Kandinsky 2.2
+
+The Kandinsky model is different from the Stable Diffusion models because it uses an image prior model to create image embeddings. The embeddings help create a better alignment between text and images, allowing the latent diffusion model to generate better images.
+
+The simplest way to use Kandinsky 2.2 is:
+
+```py
+import torch
+from diffusers import AutoPipelineForImage2Image
+from diffusers.utils import make_image_grid, load_image
+
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, use_safetensors=True
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+# prepare image
+url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
+init_image = load_image(url)
+
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+
+# pass prompt and image to pipeline
+image = pipeline(prompt, image=init_image).images[0]
+make_image_grid([init_image, image], rows=1, cols=2)
+```
+
+Left: initial image. Right: generated image.
+
+## Configure pipeline parameters
+
+There are several important parameters you can configure in the pipeline that'll affect the image generation process and image quality. Let's take a closer look at what these parameters do and how changing them affects the output.
+
+### Strength
+
+`strength` is one of the most important parameters to consider and it'll have a huge impact on your generated image. It determines how much the generated image resembles the initial image. In other words:
+
+- 📈 a higher `strength` value gives the model more "creativity" to generate an image that's different from the initial image; a `strength` value of 1.0 means the initial image is more or less ignored
+- 📉 a lower `strength` value means the generated image is more similar to the initial image
+
+The `strength` and `num_inference_steps` parameters are related because `strength` determines the number of noise steps to add. For example, if the `num_inference_steps` is 50 and `strength` is 0.8, then this means adding 40 (50 * 0.8) steps of noise to the initial image and then denoising for 40 steps to get the newly generated image.
+
+```py
+import torch
+from diffusers import AutoPipelineForImage2Image
+from diffusers.utils import make_image_grid, load_image
+
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+# prepare image
+url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
+init_image = load_image(url)
+
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+
+# pass prompt and image to pipeline
+image = pipeline(prompt, image=init_image, strength=0.8).images[0]
+make_image_grid([init_image, image], rows=1, cols=2)
+```
+
+Left to right: strength = 0.4, strength = 0.6, strength = 1.0.
+
+### Guidance scale
+
+The `guidance_scale` parameter is used to control how closely aligned the generated image and text prompt are. A higher `guidance_scale` value means your generated image is more aligned with the prompt, while a lower `guidance_scale` value means your generated image has more space to deviate from the prompt.
+
+You can combine `guidance_scale` with `strength` for even more precise control over how expressive the model is. For example, combine a high `strength + guidance_scale` for maximum creativity or use a combination of low `strength` and low `guidance_scale` to generate an image that resembles the initial image but is not as strictly bound to the prompt.
+
+```py
+import torch
+from diffusers import AutoPipelineForImage2Image
+from diffusers.utils import make_image_grid, load_image
+
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+# prepare image
+url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
+init_image = load_image(url)
+
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+
+# pass prompt and image to pipeline
+image = pipeline(prompt, image=init_image, guidance_scale=8.0).images[0]
+make_image_grid([init_image, image], rows=1, cols=2)
+```
+
+Left to right: guidance_scale = 0.1, guidance_scale = 5.0, guidance_scale = 10.0.
+
+### Negative prompt
+
+A negative prompt conditions the model to *not* include things in an image, and it can be used to improve image quality or modify an image. For example, you can improve image quality by including negative prompts like "poor details" or "blurry" to encourage the model to generate a higher quality image. Or you can modify an image by specifying things to exclude from an image.
+
+```py
+import torch
+from diffusers import AutoPipelineForImage2Image
+from diffusers.utils import make_image_grid, load_image
+
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+# prepare image
+url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
+init_image = load_image(url)
+
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+negative_prompt = "ugly, deformed, disfigured, poor details, bad anatomy"
+
+# pass prompt and image to pipeline
+image = pipeline(prompt, negative_prompt=negative_prompt, image=init_image).images[0]
+make_image_grid([init_image, image], rows=1, cols=2)
+```
+
+
+
+## Chained image-to-image pipelines
+
+There are some other interesting ways you can use an image-to-image pipeline aside from just generating an image (although that is pretty cool too). You can take it a step further and chain it with other pipelines.
+
+### Text-to-image-to-image
+
+Chaining a text-to-image and image-to-image pipeline allows you to generate an image from text and use the generated image as the initial image for the image-to-image pipeline. This is useful if you want to generate an image entirely from scratch. For example, let's chain a Stable Diffusion and a Kandinsky model.
+
+Start by generating an image with the text-to-image pipeline:
+
+```py
+from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image
+import torch
+from diffusers.utils import make_image_grid
+
+pipeline = AutoPipelineForText2Image.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+text2image = pipeline("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k").images[0]
+text2image
+```
+
+Now you can pass this generated image to the image-to-image pipeline:
+
+```py
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, use_safetensors=True
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+image2image = pipeline("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", image=text2image).images[0]
+make_image_grid([text2image, image2image], rows=1, cols=2)
+```
+
+### Image-to-image-to-image
+
+You can also chain multiple image-to-image pipelines together to create more interesting images. This can be useful for iteratively performing style transfer on an image, generating short GIFs, restoring color to an image, or restoring missing areas of an image.
+
+Start by generating an image:
+
+```py
+import torch
+from diffusers import AutoPipelineForImage2Image
+from diffusers.utils import make_image_grid, load_image
+
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+# prepare image
+url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
+init_image = load_image(url)
+
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+
+# pass prompt and image to pipeline
+image = pipeline(prompt, image=init_image, output_type="latent").images[0]
+```
+
+
+
+It is important to specify `output_type="latent"` in the pipeline to keep all the outputs in latent space to avoid an unnecessary decode-encode step. This only works if the chained pipelines are using the same VAE.
+
+
+
+Pass the latent output from this pipeline to the next pipeline to generate an image in a [comic book art style](https://huggingface.co/ogkalu/Comic-Diffusion):
+
+```py
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "ogkalu/Comic-Diffusion", torch_dtype=torch.float16
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+# need to include the token "charliebo artstyle" in the prompt to use this checkpoint
+image = pipeline("Astronaut in a jungle, charliebo artstyle", image=image, output_type="latent").images[0]
+```
+
+Repeat one more time to generate the final image in a [pixel art style](https://huggingface.co/kohbanye/pixel-art-style):
+
+```py
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "kohbanye/pixel-art-style", torch_dtype=torch.float16
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+# need to include the token "pixelartstyle" in the prompt to use this checkpoint
+image = pipeline("Astronaut in a jungle, pixelartstyle", image=image).images[0]
+make_image_grid([init_image, image], rows=1, cols=2)
+```
+
+### Image-to-upscaler-to-super-resolution
+
+Another way you can chain your image-to-image pipeline is with an upscaler and super-resolution pipeline to really increase the level of detail in an image.
+
+Start with an image-to-image pipeline:
+
+```py
+import torch
+from diffusers import AutoPipelineForImage2Image
+from diffusers.utils import make_image_grid, load_image
+
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+# prepare image
+url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
+init_image = load_image(url)
+
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+
+# pass prompt and image to pipeline
+image_1 = pipeline(prompt, image=init_image, output_type="latent").images[0]
+```
+
+
+
+It is important to specify `output_type="latent"` in the pipeline to keep all the outputs in *latent* space to avoid an unnecessary decode-encode step. This only works if the chained pipelines are using the same VAE.
+
+
+
+Chain it to an upscaler pipeline to increase the image resolution:
+
+```py
+from diffusers import StableDiffusionLatentUpscalePipeline
+
+upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(
+ "stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+)
+upscaler.enable_model_cpu_offload()
+upscaler.enable_xformers_memory_efficient_attention()
+
+image_2 = upscaler(prompt, image=image_1, output_type="latent").images[0]
+```
+
+Finally, chain it to a super-resolution pipeline to further enhance the resolution:
+
+```py
+from diffusers import StableDiffusionUpscalePipeline
+
+super_res = StableDiffusionUpscalePipeline.from_pretrained(
+ "stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+)
+super_res.enable_model_cpu_offload()
+super_res.enable_xformers_memory_efficient_attention()
+
+image_3 = super_res(prompt, image=image_2).images[0]
+make_image_grid([init_image, image_3.resize((512, 512))], rows=1, cols=2)
+```
+
+## Control image generation
+
+Trying to generate an image that looks exactly the way you want can be difficult, which is why controlled generation techniques and models are so useful. While you can use the `negative_prompt` to partially control image generation, there are more robust methods like prompt weighting and ControlNets.
+
+### Prompt weighting
+
+Prompt weighting allows you to scale the representation of each concept in a prompt. For example, in a prompt like "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", you can choose to increase or decrease the embeddings of "astronaut" and "jungle". The [Compel](https://github.com/damian0815/compel) library provides a simple syntax for adjusting prompt weights and generating the embeddings. You can learn how to create the embeddings in the [Prompt weighting](weighted_prompts) guide.
+
+[`AutoPipelineForImage2Image`] has a `prompt_embeds` parameter (and `negative_prompt_embeds` if you're using a negative prompt) where you can pass the embeddings; it replaces the `prompt` parameter.
+
+```py
+from diffusers import AutoPipelineForImage2Image
+import torch
+
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+image = pipeline(prompt_embeds=prompt_embeds, # generated from Compel
+ negative_prompt_embeds=negative_prompt_embeds, # generated from Compel
+ image=init_image,
+).images[0]
+```
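+
+The `prompt_embeds` and `negative_prompt_embeds` above are assumed to come from Compel, with `init_image` loaded as in the earlier examples. The snippet below is only a sketch of how they could be produced; the `Compel` constructor arguments and the `++`/`--` weighting syntax are based on the Compel library, so check its documentation for the exact usage.
+
+```py
+from compel import Compel
+
+compel = Compel(tokenizer=pipeline.tokenizer, text_encoder=pipeline.text_encoder)
+
+# "++" upweights a term and "--" downweights it in Compel's syntax
+prompt_embeds = compel("Astronaut++ in a jungle--, cold color palette, muted colors, detailed, 8k")
+negative_prompt_embeds = compel("ugly, deformed, disfigured, poor details, bad anatomy")
+```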
+
+### ControlNet
+
+ControlNets provide a more flexible and accurate way to control image generation because you can use an additional conditioning image. The conditioning image can be a canny image, depth map, image segmentation, and even scribbles! Whatever type of conditioning image you choose, the ControlNet generates an image that preserves the information in it.
+
+For example, let's condition an image with a depth map to keep the spatial information in the image.
+
+```py
+from diffusers.utils import load_image, make_image_grid
+
+# prepare image
+url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
+init_image = load_image(url)
+init_image = init_image.resize((958, 960)) # resize to depth image dimensions
+depth_image = load_image("https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth/resolve/main/images/control.png")
+make_image_grid([init_image, depth_image], rows=1, cols=2)
+```
+
+Load a ControlNet model conditioned on depth maps and the [`AutoPipelineForImage2Image`]:
+
+```py
+from diffusers import ControlNetModel, AutoPipelineForImage2Image
+import torch
+
+controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11f1p_sd15_depth", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+```
+
+Now generate a new image conditioned on the depth map, initial image, and prompt:
+
+```py
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+image_control_net = pipeline(prompt, image=init_image, control_image=depth_image).images[0]
+make_image_grid([init_image, depth_image, image_control_net], rows=1, cols=3)
+```
+
+Left to right: initial image, depth image, ControlNet image.
+
+Let's apply a new [style](https://huggingface.co/nitrosocke/elden-ring-diffusion) to the image generated from the ControlNet by chaining it with an image-to-image pipeline:
+
+```py
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "nitrosocke/elden-ring-diffusion", torch_dtype=torch.float16,
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+prompt = "elden ring style astronaut in a jungle" # include the token "elden ring style" in the prompt
+negative_prompt = "ugly, deformed, disfigured, poor details, bad anatomy"
+
+image_elden_ring = pipeline(prompt, negative_prompt=negative_prompt, image=image_control_net, strength=0.45, guidance_scale=10.5).images[0]
+make_image_grid([init_image, depth_image, image_control_net, image_elden_ring], rows=2, cols=2)
+```
+
+
+
+
+
+## Optimize
+
+Running diffusion models is computationally expensive and intensive, but with a few optimization tricks, it is entirely possible to run them on consumer and free-tier GPUs. For example, you can use a more memory-efficient form of attention such as PyTorch 2.0's [scaled-dot product attention](../optimization/torch2.0#scaled-dot-product-attention) or [xFormers](../optimization/xformers) (you can use one or the other, but there's no need to use both). You can also offload the model to the GPU while the other pipeline components wait on the CPU.
+
+```diff
++ pipeline.enable_model_cpu_offload()
++ pipeline.enable_xformers_memory_efficient_attention()
+```
+
+With [`torch.compile`](../optimization/torch2.0#torchcompile), you can boost your inference speed even more by wrapping your UNet with it:
+
+```py
+pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
+```
+
+To learn more, take a look at the [Reduce memory usage](../optimization/memory) and [Torch 2.0](../optimization/torch2.0) guides.
diff --git a/diffusers/docs/source/en/using-diffusers/inference_with_lcm.md b/diffusers/docs/source/en/using-diffusers/inference_with_lcm.md
new file mode 100644
index 0000000000000000000000000000000000000000..36b3c6c810fc38ffbf3182997342e34eb1b77a17
--- /dev/null
+++ b/diffusers/docs/source/en/using-diffusers/inference_with_lcm.md
@@ -0,0 +1,274 @@
+
+
+[[open-in-colab]]
+
+# Latent Consistency Model
+
+Latent Consistency Models (LCMs) enable high-quality image generation in typically 2-4 steps, making it possible to use diffusion models in almost real-time settings.
+
+From the [official website](https://latent-consistency-models.github.io/):
+
+> LCMs can be distilled from any pre-trained Stable Diffusion (SD) in only 4,000 training steps (~32 A100 GPU Hours) for generating high quality 768 x 768 resolution images in 2~4 steps or even one step, significantly accelerating text-to-image generation. We employ LCM to distill the Dreamshaper-V7 version of SD in just 4,000 training iterations.
+
+For a more technical overview of LCMs, refer to [the paper](https://huggingface.co/papers/2310.04378).
+
+LCM distilled models are available for [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5), [stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0), and the [SSD-1B](https://huggingface.co/segmind/SSD-1B) model. All the checkpoints can be found in this [collection](https://huggingface.co/collections/latent-consistency/latent-consistency-models-weights-654ce61a95edd6dffccef6a8).
+
+This guide shows how to perform inference with LCMs for
+- text-to-image
+- image-to-image
+- combined with style LoRAs
+- ControlNet/T2I-Adapter
+
+## Text-to-image
+
+You'll use the [`StableDiffusionXLPipeline`] with the [`LCMScheduler`] and then load the LCM-distilled UNet. Together, the distilled UNet and the scheduler enable a fast inference workflow, overcoming the slow iterative nature of diffusion models.
+
+```python
+from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, LCMScheduler
+import torch
+
+unet = UNet2DConditionModel.from_pretrained(
+ "latent-consistency/lcm-sdxl",
+ torch_dtype=torch.float16,
+ variant="fp16",
+)
+pipe = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", unet=unet, torch_dtype=torch.float16, variant="fp16",
+).to("cuda")
+pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+
+prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
+
+generator = torch.manual_seed(0)
+image = pipe(
+ prompt=prompt, num_inference_steps=4, generator=generator, guidance_scale=8.0
+).images[0]
+```
+
+![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lcm/lcm_full_sdxl_t2i.png)
+
+Notice that we use only 4 steps for generation, which is far fewer than what's typically used for standard SDXL.
+
+Some details to keep in mind:
+
+* To perform classifier-free guidance, batch size is usually doubled inside the pipeline. LCM, however, applies guidance using guidance embeddings, so the batch size does not have to be doubled in this case. This leads to a faster inference time, with the drawback that negative prompts don't have any effect on the denoising process.
+* The UNet was trained using guidance scale values in the [3., 13.] range, so that is the ideal range for `guidance_scale`. However, disabling classifier-free guidance by setting `guidance_scale` to 1.0 is also effective in most cases (see the short example below).
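+
+For instance, reusing the pipeline from above, guidance can effectively be disabled like this:
+
+```python
+generator = torch.manual_seed(0)
+image = pipe(
+    prompt=prompt, num_inference_steps=4, generator=generator, guidance_scale=1.0
+).images[0]
+```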
+
+
+## Image-to-image
+
+LCMs can be applied to image-to-image tasks too. For this example, we'll use the [LCM_Dreamshaper_v7](https://huggingface.co/SimianLuo/LCM_Dreamshaper_v7) model, but the same steps can be applied to other LCM models as well.
+
+```python
+import torch
+from diffusers import AutoPipelineForImage2Image, UNet2DConditionModel, LCMScheduler
+from diffusers.utils import make_image_grid, load_image
+
+unet = UNet2DConditionModel.from_pretrained(
+ "SimianLuo/LCM_Dreamshaper_v7",
+ subfolder="unet",
+ torch_dtype=torch.float16,
+)
+
+pipe = AutoPipelineForImage2Image.from_pretrained(
+ "Lykon/dreamshaper-7",
+ unet=unet,
+ torch_dtype=torch.float16,
+ variant="fp16",
+).to("cuda")
+pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+
+# prepare image
+url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
+init_image = load_image(url)
+prompt = "Astronauts in a jungle, cold color palette, muted colors, detailed, 8k"
+
+# pass prompt and image to pipeline
+generator = torch.manual_seed(0)
+image = pipe(
+ prompt,
+ image=init_image,
+ num_inference_steps=4,
+ guidance_scale=7.5,
+ strength=0.5,
+ generator=generator
+).images[0]
+make_image_grid([init_image, image], rows=1, cols=2)
+```
+
+![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lcm/lcm_full_sdv1-5_i2i.png)
+
+
+
+
+You can get different results based on your prompt and the image you provide. To get the best results, we recommend trying different values for the `num_inference_steps`, `strength`, and `guidance_scale` parameters and choosing the one that works best.
+
+
+
+
+## Combine with style LoRAs
+
+LCMs can be combined with style LoRAs to generate styled images in very few steps (4-8). In the following example, we'll use the [papercut LoRA](https://huggingface.co/TheLastBen/Papercut_SDXL).
+
+```python
+from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, LCMScheduler
+import torch
+
+unet = UNet2DConditionModel.from_pretrained(
+ "latent-consistency/lcm-sdxl",
+ torch_dtype=torch.float16,
+ variant="fp16",
+)
+pipe = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", unet=unet, torch_dtype=torch.float16, variant="fp16",
+).to("cuda")
+pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+
+pipe.load_lora_weights("TheLastBen/Papercut_SDXL", weight_name="papercut.safetensors", adapter_name="papercut")
+
+prompt = "papercut, a cute fox"
+
+generator = torch.manual_seed(0)
+image = pipe(
+ prompt=prompt, num_inference_steps=4, generator=generator, guidance_scale=8.0
+).images[0]
+image
+```
+
+![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lcm/lcm_full_sdx_lora_mix.png)
+
+
+## ControlNet/T2I-Adapter
+
+Let's look at how we can perform inference with ControlNet/T2I-Adapter and an LCM.
+
+### ControlNet
+For this example, we'll use the [LCM_Dreamshaper_v7](https://huggingface.co/SimianLuo/LCM_Dreamshaper_v7) model with canny ControlNet, but the same steps can be applied to other LCM models as well.
+
+```python
+import torch
+import cv2
+import numpy as np
+from PIL import Image
+
+from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, LCMScheduler
+from diffusers.utils import load_image, make_image_grid
+
+image = load_image(
+ "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
+).resize((512, 512))
+
+image = np.array(image)
+
+low_threshold = 100
+high_threshold = 200
+
+image = cv2.Canny(image, low_threshold, high_threshold)
+image = image[:, :, None]
+image = np.concatenate([image, image, image], axis=2)
+canny_image = Image.fromarray(image)
+
+controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
+pipe = StableDiffusionControlNetPipeline.from_pretrained(
+ "SimianLuo/LCM_Dreamshaper_v7",
+ controlnet=controlnet,
+ torch_dtype=torch.float16,
+ safety_checker=None,
+).to("cuda")
+
+# set scheduler
+pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+
+generator = torch.manual_seed(0)
+image = pipe(
+ "the mona lisa",
+ image=canny_image,
+ num_inference_steps=4,
+ generator=generator,
+).images[0]
+make_image_grid([canny_image, image], rows=1, cols=2)
+```
+
+![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lcm/lcm_full_sdv1-5_controlnet.png)
+
+
+
+The inference parameters in this example might not work for all inputs, so we recommend trying different values for the `num_inference_steps`, `guidance_scale`, `controlnet_conditioning_scale`, and `cross_attention_kwargs` parameters and choosing the ones that work best.
+
+
+### T2I-Adapter
+
+This example shows how to use `lcm-sdxl` with the [Canny T2I-Adapter](https://huggingface.co/TencentARC/t2i-adapter-canny-sdxl-1.0).
+
+```python
+import torch
+import cv2
+import numpy as np
+from PIL import Image
+
+from diffusers import StableDiffusionXLAdapterPipeline, UNet2DConditionModel, T2IAdapter, LCMScheduler
+from diffusers.utils import load_image, make_image_grid
+
+# Prepare image
+# Detect the canny map in low resolution to avoid high-frequency details
+image = load_image(
+ "https://huggingface.co/Adapter/t2iadapter/resolve/main/figs_SDXLV1.0/org_canny.jpg"
+).resize((384, 384))
+
+image = np.array(image)
+
+low_threshold = 100
+high_threshold = 200
+
+image = cv2.Canny(image, low_threshold, high_threshold)
+image = image[:, :, None]
+image = np.concatenate([image, image, image], axis=2)
+canny_image = Image.fromarray(image).resize((1024, 1216))
+
+# load adapter
+adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-canny-sdxl-1.0", torch_dtype=torch.float16, variant="fp16").to("cuda")
+
+unet = UNet2DConditionModel.from_pretrained(
+ "latent-consistency/lcm-sdxl",
+ torch_dtype=torch.float16,
+ variant="fp16",
+)
+pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ unet=unet,
+ adapter=adapter,
+ torch_dtype=torch.float16,
+ variant="fp16",
+).to("cuda")
+
+pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+
+prompt = "Mystical fairy in real, magic, 4k picture, high quality"
+negative_prompt = "extra digit, fewer digits, cropped, worst quality, low quality, glitch, deformed, mutated, ugly, disfigured"
+
+generator = torch.manual_seed(0)
+image = pipe(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ image=canny_image,
+ num_inference_steps=4,
+ guidance_scale=5,
+ adapter_conditioning_scale=0.8,
+ adapter_conditioning_factor=1,
+ generator=generator,
+).images[0]
+make_image_grid([canny_image, image], rows=1, cols=2)
+```
+
+![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lcm/lcm_full_sdxl_t2iadapter.png)
diff --git a/diffusers/docs/source/en/using-diffusers/inference_with_lcm_lora.md b/diffusers/docs/source/en/using-diffusers/inference_with_lcm_lora.md
new file mode 100644
index 0000000000000000000000000000000000000000..554e5fda2c2a12971d93ba8c8f70424a52efbe2e
--- /dev/null
+++ b/diffusers/docs/source/en/using-diffusers/inference_with_lcm_lora.md
@@ -0,0 +1,422 @@
+
+
+[[open-in-colab]]
+
+# Performing inference with LCM-LoRA
+
+Latent Consistency Models (LCMs) enable high-quality image generation in typically 2-4 steps, making it possible to use diffusion models in almost real-time settings.
+
+From the [official website](https://latent-consistency-models.github.io/):
+
+> LCMs can be distilled from any pre-trained Stable Diffusion (SD) in only 4,000 training steps (~32 A100 GPU Hours) for generating high quality 768 x 768 resolution images in 2~4 steps or even one step, significantly accelerating text-to-image generation. We employ LCM to distill the Dreamshaper-V7 version of SD in just 4,000 training iterations.
+
+For a more technical overview of LCMs, refer to [the paper](https://huggingface.co/papers/2310.04378).
+
+However, each model needs to be distilled separately with this approach. The core idea of LCM-LoRA is to train just a few adapter layers, the adapter being a LoRA in this case.
+This way, we don't have to train the full model and can keep the number of trainable parameters manageable. The resulting LoRAs can then be applied to any fine-tuned version of the model without having to distill them separately.
+Additionally, the LoRAs can be applied to image-to-image, ControlNet/T2I-Adapter, inpainting, AnimateDiff, and more.
+The LCM-LoRA can also be combined with other LoRAs to generate styled images in very few steps (4-8).
+
+LCM-LoRAs are available for [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5), [stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0), and the [SSD-1B](https://huggingface.co/segmind/SSD-1B) model. All the checkpoints can be found in this [collection](https://huggingface.co/collections/latent-consistency/latent-consistency-models-loras-654cdd24e111e16f0865fba6).
+
+For more details about LCM-LoRA, refer to [the technical report](https://huggingface.co/papers/2311.05556).
+
+This guide shows how to perform inference with LCM-LoRAs for
+- text-to-image
+- image-to-image
+- combined with styled LoRAs
+- ControlNet/T2I-Adapter
+- inpainting
+- AnimateDiff
+
+Before going through the examples in this guide, let's take a look at the general workflow for performing inference with LCM-LoRAs (a minimal sketch follows the list below).
+LCM-LoRAs are similar to other Stable Diffusion LoRAs, so they can be used with any [`DiffusionPipeline`] that supports LoRAs.
+
+- Load the task specific pipeline and model.
+- Set the scheduler to [`LCMScheduler`].
+- Load the LCM-LoRA weights for the model.
+- Reduce `guidance_scale` to a value between 1.0 and 2.0 and set `num_inference_steps` between 4 and 8.
+- Perform inference with the pipeline using the usual parameters.
+
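+Putting these steps together, a minimal text-to-image sketch looks like this (it uses the SDXL checkpoint and LCM-LoRA shown later in this guide; the task-specific sections below give complete examples):
+
+```python
+import torch
+from diffusers import DiffusionPipeline, LCMScheduler
+
+# 1. load the task-specific pipeline and model
+pipe = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16"
+).to("cuda")
+
+# 2. set the scheduler to LCMScheduler
+pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+
+# 3. load the LCM-LoRA weights
+pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl")
+
+# 4. use a low guidance_scale (1.0-2.0) and few steps (4-8)
+image = pipe("a photo of a cat", num_inference_steps=4, guidance_scale=1.0).images[0]
+```
+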
+Let's look at how we can perform inference with LCM-LoRAs for different tasks.
+
+First, make sure you have [peft](https://github.com/huggingface/peft) installed for better LoRA support.
+
+```bash
+pip install -U peft
+```
+
+## Text-to-image
+
+You'll use the [`StableDiffusionXLPipeline`] with the [`LCMScheduler`] and then load the LCM-LoRA. Together with the LCM-LoRA and the scheduler, the pipeline enables a fast inference workflow, overcoming the slow iterative nature of diffusion models.
+
+```python
+import torch
+from diffusers import DiffusionPipeline, LCMScheduler
+
+pipe = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ variant="fp16",
+ torch_dtype=torch.float16
+).to("cuda")
+
+# set scheduler
+pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+
+# load LCM-LoRA
+pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl")
+
+prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
+
+generator = torch.manual_seed(42)
+image = pipe(
+ prompt=prompt, num_inference_steps=4, generator=generator, guidance_scale=1.0
+).images[0]
+```
+
+![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lcm/lcm_sdxl_t2i.png)
+
+Notice that we use only 4 steps for generation, which is far fewer than what's typically required for standard SDXL.
+
+
+
+You may have noticed that we set `guidance_scale=1.0`, which disables classifier-free guidance. This is because the LCM-LoRA is trained with guidance, so the batch size does not have to be doubled in this case. This leads to a faster inference time, with the drawback that negative prompts don't have any effect on the denoising process.
+
+You can also use guidance with LCM-LoRA, but because of how it is trained, the model is very sensitive to the `guidance_scale` value; high values can lead to artifacts in the generated images. In our experiments, we found that the best values are in the range of 1.0 to 2.0, as in the sketch below.
+
+
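+For example, here is a sketch of enabling guidance with a conservative scale so a negative prompt takes effect (it reuses the pipeline and prompt from the example above; the scale of 1.5 and the negative prompt are illustrative choices):
+
+```python
+generator = torch.manual_seed(42)
+image = pipe(
+ prompt=prompt,
+ negative_prompt="blurry, low quality",
+ num_inference_steps=4,
+ generator=generator,
+ guidance_scale=1.5,
+).images[0]
+```
+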
+
+### Inference with a fine-tuned model
+
+As mentioned above, the LCM-LoRA can be applied to any fine-tuned version of the model without having to distill them separately. Let's look at how we can perform inference with a fine-tuned model. In this example, we'll use the [animagine-xl](https://huggingface.co/Linaqruf/animagine-xl) model, which is a fine-tuned version of the SDXL model for generating anime.
+
+```python
+import torch
+from diffusers import DiffusionPipeline, LCMScheduler
+
+pipe = DiffusionPipeline.from_pretrained(
+ "Linaqruf/animagine-xl",
+ variant="fp16",
+ torch_dtype=torch.float16
+).to("cuda")
+
+# set scheduler
+pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+
+# load LCM-LoRA
+pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl")
+
+prompt = "face focus, cute, masterpiece, best quality, 1girl, green hair, sweater, looking at viewer, upper body, beanie, outdoors, night, turtleneck"
+
+generator = torch.manual_seed(0)
+image = pipe(
+ prompt=prompt, num_inference_steps=4, generator=generator, guidance_scale=1.0
+).images[0]
+```
+
+![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lcm/lcm_sdxl_t2i_finetuned.png)
+
+
+## Image-to-image
+
+LCM-LoRA can be applied to image-to-image tasks too. Let's look at how we can perform image-to-image generation with LCMs. For this example, we'll use the [dreamshaper-7](https://huggingface.co/Lykon/dreamshaper-7) model and the LCM-LoRA for `stable-diffusion-v1-5`.
+
+```python
+import torch
+from diffusers import AutoPipelineForImage2Image, LCMScheduler
+from diffusers.utils import make_image_grid, load_image
+
+pipe = AutoPipelineForImage2Image.from_pretrained(
+ "Lykon/dreamshaper-7",
+ torch_dtype=torch.float16,
+ variant="fp16",
+).to("cuda")
+
+# set scheduler
+pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+
+# load LCM-LoRA
+pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
+
+# prepare image
+url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
+init_image = load_image(url)
+prompt = "Astronauts in a jungle, cold color palette, muted colors, detailed, 8k"
+
+# pass prompt and image to pipeline
+generator = torch.manual_seed(0)
+image = pipe(
+ prompt,
+ image=init_image,
+ num_inference_steps=4,
+ guidance_scale=1,
+ strength=0.6,
+ generator=generator
+).images[0]
+make_image_grid([init_image, image], rows=1, cols=2)
+```
+
+![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lcm/lcm_sdv1-5_i2i.png)
+
+
+
+
+You can get different results based on your prompt and the image you provide. To get the best results, we recommend trying different values for the `num_inference_steps`, `strength`, and `guidance_scale` parameters and choosing the one that works best.
+
+
+
+
+## Combine with styled LoRAs
+
+LCM-LoRA can be combined with other LoRAs to generate styled images in very few steps (4-8). In the following example, we'll use the LCM-LoRA with the [papercut LoRA](https://huggingface.co/TheLastBen/Papercut_SDXL).
+To learn more about how to combine LoRAs, refer to [this guide](https://huggingface.co/docs/diffusers/tutorials/using_peft_for_inference#combine-multiple-adapters).
+
+```python
+import torch
+from diffusers import DiffusionPipeline, LCMScheduler
+
+pipe = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ variant="fp16",
+ torch_dtype=torch.float16
+).to("cuda")
+
+# set scheduler
+pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+
+# load LoRAs
+pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl", adapter_name="lcm")
+pipe.load_lora_weights("TheLastBen/Papercut_SDXL", weight_name="papercut.safetensors", adapter_name="papercut")
+
+# Combine LoRAs
+pipe.set_adapters(["lcm", "papercut"], adapter_weights=[1.0, 0.8])
+
+prompt = "papercut, a cute fox"
+generator = torch.manual_seed(0)
+image = pipe(prompt, num_inference_steps=4, guidance_scale=1, generator=generator).images[0]
+image
+```
+
+![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lcm/lcm_sdx_lora_mix.png)
+
+
+## ControlNet/T2I-Adapter
+
+Let's look at how we can perform inference with ControlNet/T2I-Adapter and LCM-LoRA.
+
+### ControlNet
+For this example, we'll use the SD-v1-5 model and the LCM-LoRA for SD-v1-5 with canny ControlNet.
+
+```python
+import torch
+import cv2
+import numpy as np
+from PIL import Image
+
+from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, LCMScheduler
+from diffusers.utils import load_image, make_image_grid
+
+image = load_image(
+ "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
+).resize((512, 512))
+
+image = np.array(image)
+
+low_threshold = 100
+high_threshold = 200
+
+image = cv2.Canny(image, low_threshold, high_threshold)
+image = image[:, :, None]
+image = np.concatenate([image, image, image], axis=2)
+canny_image = Image.fromarray(image)
+
+controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
+pipe = StableDiffusionControlNetPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+ controlnet=controlnet,
+ torch_dtype=torch.float16,
+ safety_checker=None,
+ variant="fp16"
+).to("cuda")
+
+# set scheduler
+pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+
+# load LCM-LoRA
+pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
+
+generator = torch.manual_seed(0)
+image = pipe(
+ "the mona lisa",
+ image=canny_image,
+ num_inference_steps=4,
+ guidance_scale=1.5,
+ controlnet_conditioning_scale=0.8,
+ cross_attention_kwargs={"scale": 1},
+ generator=generator,
+).images[0]
+make_image_grid([canny_image, image], rows=1, cols=2)
+```
+
+![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lcm/lcm_sdv1-5_controlnet.png)
+
+
+
+The inference parameters in this example might not work for all inputs, so we recommend trying different values for the `num_inference_steps`, `guidance_scale`, `controlnet_conditioning_scale`, and `cross_attention_kwargs` parameters and choosing the ones that work best.
+
+
+### T2I-Adapter
+
+This example shows how to use the LCM-LoRA with the [Canny T2I-Adapter](https://huggingface.co/TencentARC/t2i-adapter-canny-sdxl-1.0) and SDXL.
+
+```python
+import torch
+import cv2
+import numpy as np
+from PIL import Image
+
+from diffusers import StableDiffusionXLAdapterPipeline, T2IAdapter, LCMScheduler
+from diffusers.utils import load_image, make_image_grid
+
+# Prepare image
+# Detect the canny map in low resolution to avoid high-frequency details
+image = load_image(
+ "https://huggingface.co/Adapter/t2iadapter/resolve/main/figs_SDXLV1.0/org_canny.jpg"
+).resize((384, 384))
+
+image = np.array(image)
+
+low_threshold = 100
+high_threshold = 200
+
+image = cv2.Canny(image, low_threshold, high_threshold)
+image = image[:, :, None]
+image = np.concatenate([image, image, image], axis=2)
+canny_image = Image.fromarray(image).resize((1024, 1024))
+
+# load adapter
+adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-canny-sdxl-1.0", torch_dtype=torch.float16, variant="fp16").to("cuda")
+
+pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ adapter=adapter,
+ torch_dtype=torch.float16,
+ variant="fp16",
+).to("cuda")
+
+# set scheduler
+pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+
+# load LCM-LoRA
+pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl")
+
+prompt = "Mystical fairy in real, magic, 4k picture, high quality"
+negative_prompt = "extra digit, fewer digits, cropped, worst quality, low quality, glitch, deformed, mutated, ugly, disfigured"
+
+generator = torch.manual_seed(0)
+image = pipe(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ image=canny_image,
+ num_inference_steps=4,
+ guidance_scale=1.5,
+ adapter_conditioning_scale=0.8,
+ adapter_conditioning_factor=1,
+ generator=generator,
+).images[0]
+make_image_grid([canny_image, image], rows=1, cols=2)
+```
+
+![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lcm/lcm_sdxl_t2iadapter.png)
+
+
+## Inpainting
+
+LCM-LoRA can be used for inpainting as well.
+
+```python
+import torch
+from diffusers import AutoPipelineForInpainting, LCMScheduler
+from diffusers.utils import load_image, make_image_grid
+
+pipe = AutoPipelineForInpainting.from_pretrained(
+ "runwayml/stable-diffusion-inpainting",
+ torch_dtype=torch.float16,
+ variant="fp16",
+).to("cuda")
+
+# set scheduler
+pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+
+# load LCM-LoRA
+pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
+
+# load base and mask image
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png")
+
+prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
+generator = torch.manual_seed(0)
+image = pipe(
+ prompt=prompt,
+ image=init_image,
+ mask_image=mask_image,
+ generator=generator,
+ num_inference_steps=4,
+ guidance_scale=4,
+).images[0]
+make_image_grid([init_image, mask_image, image], rows=1, cols=3)
+```
+
+![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lcm/lcm_sdv1-5_inpainting.png)
+
+
+## AnimateDiff
+
+[`AnimateDiff`] allows you to animate images using Stable Diffusion models. To get good results, we need to generate multiple frames (16-24), and doing this with standard SD models can be very slow.
+LCM-LoRA can be used to speed up the process significantly, as you just need to do 4-8 steps for each frame. Let's look at how we can perform animation with LCM-LoRA and AnimateDiff.
+
+```python
+import torch
+from diffusers import MotionAdapter, AnimateDiffPipeline, LCMScheduler
+from diffusers.utils import export_to_gif
+
+adapter = MotionAdapter.from_pretrained("diffusers/animatediff-motion-adapter-v1-5")
+pipe = AnimateDiffPipeline.from_pretrained(
+ "frankjoshua/toonyou_beta6",
+ motion_adapter=adapter,
+).to("cuda")
+
+# set scheduler
+pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+
+# load LCM-LoRA
+pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5", adapter_name="lcm")
+pipe.load_lora_weights("guoyww/animatediff-motion-lora-zoom-in", weight_name="diffusion_pytorch_model.safetensors", adapter_name="motion-lora")
+
+pipe.set_adapters(["lcm", "motion-lora"], adapter_weights=[0.55, 1.2])
+
+prompt = "best quality, masterpiece, 1girl, looking at viewer, blurry background, upper body, contemporary, dress"
+generator = torch.manual_seed(0)
+frames = pipe(
+ prompt=prompt,
+ num_inference_steps=5,
+ guidance_scale=1.25,
+ cross_attention_kwargs={"scale": 1},
+ num_frames=24,
+ generator=generator
+).frames[0]
+export_to_gif(frames, "animation.gif")
+```
+
+![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lcm/lcm_sdv1-5_animatediff.gif)
\ No newline at end of file
diff --git a/diffusers/docs/source/en/using-diffusers/inpaint.md b/diffusers/docs/source/en/using-diffusers/inpaint.md
new file mode 100644
index 0000000000000000000000000000000000000000..e6b1010f13b0f91eeb3b581f57cf5248118f683d
--- /dev/null
+++ b/diffusers/docs/source/en/using-diffusers/inpaint.md
@@ -0,0 +1,752 @@
+
+
+# Inpainting
+
+[[open-in-colab]]
+
+Inpainting replaces or edits specific areas of an image. This makes it a useful tool for image restoration like removing defects and artifacts, or even replacing an image area with something entirely new. Inpainting relies on a mask to determine which regions of an image to fill in; the area to inpaint is represented by white pixels and the area to keep is represented by black pixels. The white pixels are filled in by the prompt.
+
+With 🤗 Diffusers, here is how you can do inpainting:
+
+1. Load an inpainting checkpoint with the [`AutoPipelineForInpainting`] class. This'll automatically detect the appropriate pipeline class to load based on the checkpoint:
+
+```py
+import torch
+from diffusers import AutoPipelineForInpainting
+from diffusers.utils import load_image, make_image_grid
+
+pipeline = AutoPipelineForInpainting.from_pretrained(
+ "kandinsky-community/kandinsky-2-2-decoder-inpaint", torch_dtype=torch.float16
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+```
+
+
+
+You'll notice throughout the guide that we use [`~DiffusionPipeline.enable_model_cpu_offload`] and [`~DiffusionPipeline.enable_xformers_memory_efficient_attention`] to save memory and increase inference speed. If you're using PyTorch 2.0, it's not necessary to call [`~DiffusionPipeline.enable_xformers_memory_efficient_attention`] on your pipeline because it'll already be using PyTorch 2.0's native [scaled-dot product attention](../optimization/torch2.0#scaled-dot-product-attention).
+
+
+
+2. Load the base and mask images:
+
+```py
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png")
+```
+
+3. Create a prompt to inpaint the image with and pass it to the pipeline with the base and mask images:
+
+```py
+prompt = "a black cat with glowing eyes, cute, adorable, disney, pixar, highly detailed, 8k"
+negative_prompt = "bad anatomy, deformed, ugly, disfigured"
+image = pipeline(prompt=prompt, negative_prompt=negative_prompt, image=init_image, mask_image=mask_image).images[0]
+make_image_grid([init_image, mask_image, image], rows=1, cols=3)
+```
+
+
+
+
+ base image
+
+
+
+ mask image
+
+
+
+ generated image
+
+
+
+## Create a mask image
+
+Throughout this guide, the mask image is provided in all of the code examples for convenience. You can inpaint on your own images, but you'll need to create a mask image for it. Use the Space below to easily create a mask image.
+
+Upload a base image to inpaint on and use the sketch tool to draw a mask. Once you're done, click **Run** to generate and download the mask image.
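+
+Alternatively, you can create a mask programmatically. Here is a minimal sketch with Pillow that paints the region to inpaint in white on a black background (the rectangle coordinates are placeholders for your own region):
+
+```py
+from PIL import Image, ImageDraw
+
+# start from an all-black mask the same size as the base image
+mask = Image.new("L", init_image.size, 0)
+draw = ImageDraw.Draw(mask)
+# paint the area to inpaint in white (placeholder coordinates)
+draw.rectangle((100, 100, 400, 400), fill=255)
+mask.save("mask.png")
+```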
+
+
+
+## Popular models
+
+[Stable Diffusion Inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting), [Stable Diffusion XL (SDXL) Inpainting](https://huggingface.co/diffusers/stable-diffusion-xl-1.0-inpainting-0.1), and [Kandinsky 2.2 Inpainting](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder-inpaint) are among the most popular models for inpainting. SDXL typically produces higher resolution images than Stable Diffusion v1.5, and Kandinsky 2.2 is also capable of generating high-quality images.
+
+### Stable Diffusion Inpainting
+
+Stable Diffusion Inpainting is a latent diffusion model finetuned on 512x512 images for inpainting. It is a good starting point because it is relatively fast and generates good quality images. To use this model for inpainting, you'll need to pass a prompt, a base image, and a mask image to the pipeline:
+
+```py
+import torch
+from diffusers import AutoPipelineForInpainting
+from diffusers.utils import load_image, make_image_grid
+
+pipeline = AutoPipelineForInpainting.from_pretrained(
+ "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+# load base and mask image
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png")
+
+generator = torch.Generator("cuda").manual_seed(92)
+prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
+image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, generator=generator).images[0]
+make_image_grid([init_image, mask_image, image], rows=1, cols=3)
+```
+
+### Stable Diffusion XL (SDXL) Inpainting
+
+SDXL is a larger and more powerful version of Stable Diffusion v1.5. This model can follow a two-stage process (though each model can also be used alone): the base model generates an image, and a refiner model takes that image and further enhances its details and quality. Take a look at the [SDXL](sdxl) guide for a more comprehensive look at how to use SDXL and configure its parameters.
+
+```py
+import torch
+from diffusers import AutoPipelineForInpainting
+from diffusers.utils import load_image, make_image_grid
+
+pipeline = AutoPipelineForInpainting.from_pretrained(
+ "diffusers/stable-diffusion-xl-1.0-inpainting-0.1", torch_dtype=torch.float16, variant="fp16"
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+# load base and mask image
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png")
+
+generator = torch.Generator("cuda").manual_seed(92)
+prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
+image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, generator=generator).images[0]
+make_image_grid([init_image, mask_image, image], rows=1, cols=3)
+```
+
+### Kandinsky 2.2 Inpainting
+
+The Kandinsky model family is similar to SDXL because it uses two models as well; the image prior model creates image embeddings, and the diffusion model generates images from them. You can load the image prior and diffusion model separately, but the easiest way to use Kandinsky 2.2 is to load it into the [`AutoPipelineForInpainting`] class which uses the [`KandinskyV22InpaintCombinedPipeline`] under the hood.
+
+```py
+import torch
+from diffusers import AutoPipelineForInpainting
+from diffusers.utils import load_image, make_image_grid
+
+pipeline = AutoPipelineForInpainting.from_pretrained(
+ "kandinsky-community/kandinsky-2-2-decoder-inpaint", torch_dtype=torch.float16
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+# load base and mask image
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png")
+
+generator = torch.Generator("cuda").manual_seed(92)
+prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
+image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, generator=generator).images[0]
+make_image_grid([init_image, mask_image, image], rows=1, cols=3)
+```
+
+
+
+
+ base image
+
+
+
+ Stable Diffusion Inpainting
+
+
+
+ Stable Diffusion XL Inpainting
+
+
+
+ Kandinsky 2.2 Inpainting
+
+
+
+## Non-inpaint specific checkpoints
+
+So far, this guide has used inpaint specific checkpoints such as [runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting). But you can also use regular checkpoints like [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5). Let's compare the results of the two checkpoints.
+
+The image on the left is generated from a regular checkpoint, and the image on the right is from an inpaint checkpoint. You'll immediately notice the image on the left is not as clean, and you can still see the outline of the area the model is supposed to inpaint. The image on the right is much cleaner and the inpainted area appears more natural.
+
+
+
+
+```py
+import torch
+from diffusers import AutoPipelineForInpainting
+from diffusers.utils import load_image, make_image_grid
+
+pipeline = AutoPipelineForInpainting.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16"
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+# load base and mask image
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png")
+
+generator = torch.Generator("cuda").manual_seed(92)
+prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
+image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, generator=generator).images[0]
+make_image_grid([init_image, image], rows=1, cols=2)
+```
+
+
+
+
+```py
+import torch
+from diffusers import AutoPipelineForInpainting
+from diffusers.utils import load_image, make_image_grid
+
+pipeline = AutoPipelineForInpainting.from_pretrained(
+ "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+# load base and mask image
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png")
+
+generator = torch.Generator("cuda").manual_seed(92)
+prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
+image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, generator=generator).images[0]
+make_image_grid([init_image, image], rows=1, cols=2)
+```
+
+
+
+
+
+
+
+ runwayml/stable-diffusion-v1-5
+
+
+
+ runwayml/stable-diffusion-inpainting
+
+
+
+However, for more basic tasks like erasing an object from an image (like the rocks in the road, for example), a regular checkpoint yields pretty good results. There isn't as noticeable a difference between the regular and inpaint checkpoints.
+
+
+
+
+```py
+import torch
+from diffusers import AutoPipelineForInpainting
+from diffusers.utils import load_image, make_image_grid
+
+pipeline = AutoPipelineForInpainting.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16"
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+# load base and mask image
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/road-mask.png")
+
+image = pipeline(prompt="road", image=init_image, mask_image=mask_image).images[0]
+make_image_grid([init_image, image], rows=1, cols=2)
+```
+
+
+
+
+```py
+import torch
+from diffusers import AutoPipelineForInpainting
+from diffusers.utils import load_image, make_image_grid
+
+pipeline = AutoPipelineForInpainting.from_pretrained(
+ "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+# load base and mask image
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/road-mask.png")
+
+image = pipeline(prompt="road", image=init_image, mask_image=mask_image).images[0]
+make_image_grid([init_image, image], rows=1, cols=2)
+```
+
+
+
+
+
+
+
+ runwayml/stable-diffusion-v1-5
+
+
+
+ runwayml/stable-diffusion-inpainting
+
+
+
+The trade-off of using a non-inpaint specific checkpoint is that the overall image quality may be lower, but it generally tends to preserve the mask area (that is why you can see the mask outline). The inpaint specific checkpoints are intentionally trained to generate higher quality inpainted images, and that includes creating a more natural transition between the masked and unmasked areas. As a result, these checkpoints are more likely to change your unmasked area.
+
+If preserving the unmasked area is important for your task, you can use the code below to force the unmasked area of an image to remain the same at the expense of some more unnatural transitions between the masked and unmasked areas.
+
+```py
+import PIL
+import numpy as np
+import torch
+
+from diffusers import AutoPipelineForInpainting
+from diffusers.utils import load_image, make_image_grid
+
+device = "cuda"
+pipeline = AutoPipelineForInpainting.from_pretrained(
+ "runwayml/stable-diffusion-inpainting",
+ torch_dtype=torch.float16,
+)
+pipeline = pipeline.to(device)
+
+img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+init_image = load_image(img_url).resize((512, 512))
+mask_image = load_image(mask_url).resize((512, 512))
+
+prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
+repainted_image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
+repainted_image.save("repainted_image.png")
+
+# Convert mask to grayscale NumPy array
+mask_image_arr = np.array(mask_image.convert("L"))
+# Add a channel dimension to the end of the grayscale mask
+mask_image_arr = mask_image_arr[:, :, None]
+# Binarize the mask: 1s correspond to the pixels which are repainted
+mask_image_arr = mask_image_arr.astype(np.float32) / 255.0
+mask_image_arr[mask_image_arr < 0.5] = 0
+mask_image_arr[mask_image_arr >= 0.5] = 1
+
+# Take the masked pixels from the repainted image and the unmasked pixels from the initial image
+unmasked_unchanged_image_arr = (1 - mask_image_arr) * init_image + mask_image_arr * repainted_image
+unmasked_unchanged_image = PIL.Image.fromarray(unmasked_unchanged_image_arr.round().astype("uint8"))
+unmasked_unchanged_image.save("force_unmasked_unchanged.png")
+make_image_grid([init_image, mask_image, repainted_image, unmasked_unchanged_image], rows=2, cols=2)
+```
+
+## Configure pipeline parameters
+
+Image features - like quality and "creativity" - are dependent on pipeline parameters. Knowing what these parameters do is important for getting the results you want. Let's take a look at the most important parameters and see how changing them affects the output.
+
+### Strength
+
+`strength` is a measure of how much noise is added to the base image, which influences how similar the output is to the base image.
+
+* 📈 a high `strength` value means more noise is added to an image and the denoising process takes longer, but you'll get higher quality images that differ more from the base image
+* 📉 a low `strength` value means less noise is added to an image and the denoising process is faster, but the image quality may not be as great and the generated image resembles the base image more
+
+```py
+import torch
+from diffusers import AutoPipelineForInpainting
+from diffusers.utils import load_image, make_image_grid
+
+pipeline = AutoPipelineForInpainting.from_pretrained(
+ "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+# load base and mask image
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png")
+
+prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
+image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, strength=0.6).images[0]
+make_image_grid([init_image, mask_image, image], rows=1, cols=3)
+```
+
+
+
+
+ strength = 0.6
+
+
+
+ strength = 0.8
+
+
+
+ strength = 1.0
+
+
+
+### Guidance scale
+
+`guidance_scale` affects how aligned the text prompt and generated image are.
+
+* 📈 a high `guidance_scale` value means the prompt and generated image are closely aligned, so the output is a stricter interpretation of the prompt
+* 📉 a low `guidance_scale` value means the prompt and generated image are more loosely aligned, so the output may be more varied from the prompt
+
+You can use `strength` and `guidance_scale` together for more control over how expressive the model is. For example, a combination of high `strength` and `guidance_scale` values gives the model the most creative freedom.
+
+```py
+import torch
+from diffusers import AutoPipelineForInpainting
+from diffusers.utils import load_image, make_image_grid
+
+pipeline = AutoPipelineForInpainting.from_pretrained(
+ "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+# load base and mask image
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png")
+
+prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
+image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, guidance_scale=2.5).images[0]
+make_image_grid([init_image, mask_image, image], rows=1, cols=3)
+```
+
+
+
+
+ guidance_scale = 2.5
+
+
+
+ guidance_scale = 7.5
+
+
+
+ guidance_scale = 12.5
+
+
+
+### Negative prompt
+
+A negative prompt assumes the opposite role of a prompt; it guides the model away from generating certain things in an image. This is useful for quickly improving image quality and preventing the model from generating things you don't want.
+
+```py
+import torch
+from diffusers import AutoPipelineForInpainting
+from diffusers.utils import load_image, make_image_grid
+
+pipeline = AutoPipelineForInpainting.from_pretrained(
+ "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+# load base and mask image
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png")
+
+prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
+negative_prompt = "bad architecture, unstable, poor details, blurry"
+image = pipeline(prompt=prompt, negative_prompt=negative_prompt, image=init_image, mask_image=mask_image).images[0]
+make_image_grid([init_image, mask_image, image], rows=1, cols=3)
+```
+
+
+
+## Chained inpainting pipelines
+
+[`AutoPipelineForInpainting`] can be chained with other 🤗 Diffusers pipelines to edit their outputs. This is often useful for improving the output quality from your other diffusion pipelines, and if you're using multiple pipelines, it can be more memory-efficient to chain them together to keep the outputs in latent space and reuse the same pipeline components.
+
+### Text-to-image-to-inpaint
+
+Chaining a text-to-image and inpainting pipeline allows you to inpaint the generated image, and you don't have to provide a base image to begin with. This makes it convenient to edit your favorite text-to-image outputs without having to generate an entirely new image.
+
+Start with the text-to-image pipeline to create a castle:
+
+```py
+import torch
+from diffusers import AutoPipelineForText2Image, AutoPipelineForInpainting
+from diffusers.utils import load_image, make_image_grid
+
+pipeline = AutoPipelineForText2Image.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+text2image = pipeline("concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k").images[0]
+```
+
+Load the mask image of the output from above:
+
+```py
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_text-chain-mask.png")
+```
+
+And let's inpaint the masked area with a waterfall:
+
+```py
+pipeline = AutoPipelineForInpainting.from_pretrained(
+ "kandinsky-community/kandinsky-2-2-decoder-inpaint", torch_dtype=torch.float16
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+prompt = "digital painting of a fantasy waterfall, cloudy"
+image = pipeline(prompt=prompt, image=text2image, mask_image=mask_image).images[0]
+make_image_grid([text2image, mask_image, image], rows=1, cols=3)
+```
+
+
+
+
+ text-to-image
+
+
+
+ inpaint
+
+
+
+### Inpaint-to-image-to-image
+
+You can also chain an inpainting pipeline before another pipeline like image-to-image or an upscaler to improve the quality.
+
+Begin by inpainting an image:
+
+```py
+import torch
+from diffusers import AutoPipelineForInpainting, AutoPipelineForImage2Image
+from diffusers.utils import load_image, make_image_grid
+
+pipeline = AutoPipelineForInpainting.from_pretrained(
+ "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+# load base and mask image
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png")
+
+prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
+image_inpainting = pipeline(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
+
+# resize image to 1024x1024 for SDXL
+image_inpainting = image_inpainting.resize((1024, 1024))
+```
+
+Now let's pass the image to another inpainting pipeline with SDXL's refiner model to enhance the image details and quality:
+
+```py
+pipeline = AutoPipelineForInpainting.from_pretrained(
+ "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16"
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+image = pipeline(prompt=prompt, image=image_inpainting, mask_image=mask_image, output_type="latent").images[0]
+```
+
+
+
+It is important to specify `output_type="latent"` in the pipeline to keep all the outputs in latent space to avoid an unnecessary decode-encode step. This only works if the chained pipelines are using the same VAE. For example, in the [Text-to-image-to-inpaint](#text-to-image-to-inpaint) section, Kandinsky 2.2 uses a different VAE class than the Stable Diffusion model so it won't work. But if you use Stable Diffusion v1.5 for both pipelines, then you can keep everything in latent space because they both use [`AutoencoderKL`].
+
+
+
+Finally, you can pass this image to an image-to-image pipeline to put the finishing touches on it. It is more efficient to use the [`~AutoPipelineForImage2Image.from_pipe`] method to reuse the existing pipeline components, and avoid unnecessarily loading all the pipeline components into memory again.
+
+```py
+pipeline = AutoPipelineForImage2Image.from_pipe(pipeline)
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+image = pipeline(prompt=prompt, image=image).images[0]
+make_image_grid([init_image, mask_image, image_inpainting, image], rows=2, cols=2)
+```
+
+
+
+
+ initial image
+
+
+
+ inpaint
+
+
+
+ image-to-image
+
+
+
+Image-to-image and inpainting are actually very similar tasks. Image-to-image generates a new image that resembles the provided image. Inpainting does the same thing, but it only transforms the image area defined by the mask and leaves the rest of the image unchanged. You can think of inpainting as a more precise tool for making specific changes, while image-to-image has a broader scope for making more sweeping changes.
+
+## Control image generation
+
+Getting an image to look exactly the way you want is challenging because the denoising process is random. While you can control certain aspects of generation by configuring parameters like `negative_prompt`, there are better and more efficient methods for controlling image generation.
+
+### Prompt weighting
+
+Prompt weighting provides a quantifiable way to scale the representation of concepts in a prompt. You can use it to increase or decrease the magnitude of the text embedding vector for each concept in the prompt, which subsequently determines how much of each concept is generated. The [Compel](https://github.com/damian0815/compel) library offers an intuitive syntax for scaling the prompt weights and generating the embeddings. Learn how to create the embeddings in the [Prompt weighting](../using-diffusers/weighted_prompts) guide.
+
+Once you've generated the embeddings, pass them to the `prompt_embeds` (and `negative_prompt_embeds` if you're using a negative prompt) parameter in the [`AutoPipelineForInpainting`]. The embeddings replace the `prompt` parameter:
+
+```py
+import torch
+from diffusers import AutoPipelineForInpainting
+from diffusers.utils import make_image_grid
+
+pipeline = AutoPipelineForInpainting.from_pretrained(
+ "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16,
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+image = pipeline(prompt_embeds=prompt_embeds, # generated from Compel
+ negative_prompt_embeds=negative_prompt_embeds, # generated from Compel
+ image=init_image,
+ mask_image=mask_image
+).images[0]
+make_image_grid([init_image, mask_image, image], rows=1, cols=3)
+```
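+
+For reference, the `prompt_embeds` and `negative_prompt_embeds` above could be generated with Compel along these lines (a sketch; the `++` weighting syntax and the prompt text are only illustrative):
+
+```py
+from compel import Compel
+
+# build a Compel processor from the pipeline's tokenizer and text encoder
+compel_proc = Compel(tokenizer=pipeline.tokenizer, text_encoder=pipeline.text_encoder)
+prompt_embeds = compel_proc("concept art digital painting of an elven castle++, highly detailed")
+negative_prompt_embeds = compel_proc("bad architecture, deformed, disfigured, poor details")
+```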
+
+### ControlNet
+
+ControlNet models are used with other diffusion models like Stable Diffusion, and they provide an even more flexible and accurate way to control how an image is generated. A ControlNet accepts an additional conditioning image input that guides the diffusion model to preserve the features in it.
+
+For example, let's condition an image with a ControlNet pretrained on inpaint images:
+
+```py
+import torch
+import numpy as np
+import PIL
+from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
+from diffusers.utils import load_image, make_image_grid
+
+# load ControlNet
+controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16, variant="fp16")
+
+# pass ControlNet to the pipeline
+pipeline = StableDiffusionControlNetInpaintPipeline.from_pretrained(
+ "runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16, variant="fp16"
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+# load base and mask image
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png")
+
+# prepare control image
+def make_inpaint_condition(init_image, mask_image):
+ init_image = np.array(init_image.convert("RGB")).astype(np.float32) / 255.0
+ mask_image = np.array(mask_image.convert("L")).astype(np.float32) / 255.0
+
+ assert init_image.shape[0:2] == mask_image.shape[0:2], "image and image_mask must have the same image size"
+ init_image[mask_image > 0.5] = -1.0 # set as masked pixel
+ init_image = np.expand_dims(init_image, 0).transpose(0, 3, 1, 2)
+ init_image = torch.from_numpy(init_image)
+ return init_image
+
+control_image = make_inpaint_condition(init_image, mask_image)
+```
+
+Now generate an image from the base, mask and control images. You'll notice features of the base image are strongly preserved in the generated image.
+
+```py
+prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
+image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, control_image=control_image).images[0]
+make_image_grid([init_image, mask_image, PIL.Image.fromarray(np.uint8(control_image[0][0])).convert('RGB'), image], rows=2, cols=2)
+```
+
+You can take this a step further and chain it with an image-to-image pipeline to apply a new [style](https://huggingface.co/nitrosocke/elden-ring-diffusion):
+
+```py
+from diffusers import AutoPipelineForImage2Image
+
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "nitrosocke/elden-ring-diffusion", torch_dtype=torch.float16,
+)
+pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
+pipeline.enable_xformers_memory_efficient_attention()
+
+prompt = "elden ring style castle" # include the token "elden ring style" in the prompt
+negative_prompt = "bad architecture, deformed, disfigured, poor details"
+
+image_elden_ring = pipeline(prompt, negative_prompt=negative_prompt, image=image).images[0]
+make_image_grid([init_image, mask_image, image, image_elden_ring], rows=2, cols=2)
+```
+
+
+
+
+ initial image
+
+
+
+ ControlNet inpaint
+
+
+
+ image-to-image
+
+
+
+## Optimize
+
+It can be difficult and slow to run diffusion models if you're resource constrained, but it doesn't have to be with a few optimization tricks. One of the biggest (and easiest) optimizations you can enable is switching to memory-efficient attention. If you're using PyTorch 2.0, [scaled-dot product attention](../optimization/torch2.0#scaled-dot-product-attention) is automatically enabled and you don't need to do anything else. For non-PyTorch 2.0 users, you can install and use [xFormers](../optimization/xformers)'s implementation of memory-efficient attention. Both options reduce memory usage and accelerate inference.
+
+You can also offload the model to the CPU to save even more memory:
+
+```diff
++ pipeline.enable_xformers_memory_efficient_attention()
++ pipeline.enable_model_cpu_offload()
+```
+
+To speed up your inference code even more, use [`torch.compile`](../optimization/torch2.0#torchcompile). You should wrap it around the most intensive component in the pipeline, which is typically the UNet:
+
+```py
+pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
+```
+
+Learn more in the [Reduce memory usage](../optimization/memory) and [Torch 2.0](../optimization/torch2.0) guides.
diff --git a/diffusers/docs/source/en/using-diffusers/kandinsky.md b/diffusers/docs/source/en/using-diffusers/kandinsky.md
new file mode 100644
index 0000000000000000000000000000000000000000..05be2e1ee289b5b55cd6dfc155936349064ad861
--- /dev/null
+++ b/diffusers/docs/source/en/using-diffusers/kandinsky.md
@@ -0,0 +1,723 @@
+
+
+# Kandinsky
+
+[[open-in-colab]]
+
+The Kandinsky models are a series of multilingual text-to-image generation models. The Kandinsky 2.0 model uses two multilingual text encoders and concatenates those results for the UNet.
+
+[Kandinsky 2.1](../api/pipelines/kandinsky) changes the architecture to include an image prior model ([`CLIP`](https://huggingface.co/docs/transformers/model_doc/clip)) to generate a mapping between text and image embeddings. The mapping provides better text-image alignment and it is used with the text embeddings during training, leading to higher quality results. Finally, Kandinsky 2.1 uses a [Modulating Quantized Vectors (MoVQ)](https://huggingface.co/papers/2209.09002) decoder - which adds a spatial conditional normalization layer to increase photorealism - to decode the latents into images.
+
+[Kandinsky 2.2](../api/pipelines/kandinsky_v22) improves on the previous model by replacing the image encoder of the image prior model with a larger CLIP-ViT-G model to improve quality. The image prior model was also retrained on images with different resolutions and aspect ratios to generate higher-resolution images and different image sizes.
+
+This guide will show you how to use the Kandinsky models for text-to-image, image-to-image, inpainting, interpolation, and more.
+
+Before you begin, make sure you have the following libraries installed:
+
+```py
+# uncomment to install the necessary libraries in Colab
+#!pip install -q diffusers transformers accelerate
+```
+
+
+
+Kandinsky 2.1 and 2.2 usage is very similar! The only difference is Kandinsky 2.2 doesn't accept `prompt` as an input when decoding the latents. Instead, Kandinsky 2.2 only accepts `image_embeds` during decoding.
+
+
+
+## Text-to-image
+
+To use the Kandinsky models for any task, you always start by setting up the prior pipeline to encode the prompt and generate the image embeddings. The prior pipeline also generates `negative_image_embeds` that correspond to the negative prompt `""`. For better results, you can pass an actual `negative_prompt` to the prior pipeline, but this'll increase the effective batch size of the prior pipeline by 2x.
+
+
+
+
+```py
+from diffusers import KandinskyPriorPipeline, KandinskyPipeline
+import torch
+
+prior_pipeline = KandinskyPriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16).to("cuda")
+pipeline = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16).to("cuda")
+
+prompt = "A alien cheeseburger creature eating itself, claymation, cinematic, moody lighting"
+negative_prompt = "low quality, bad quality" # optional to include a negative prompt, but results are usually better
+image_embeds, negative_image_embeds = prior_pipeline(prompt, negative_prompt, guidance_scale=1.0).to_tuple()
+```
+
+Now pass all the prompts and embeddings to the [`KandinskyPipeline`] to generate an image:
+
+```py
+image = pipeline(prompt, image_embeds=image_embeds, negative_prompt=negative_prompt, negative_image_embeds=negative_image_embeds, height=768, width=768).images[0]
+image
+```
+
+
+
+
+
+
+
+
+```py
+from diffusers import KandinskyV22PriorPipeline, KandinskyV22Pipeline
+import torch
+
+prior_pipeline = KandinskyV22PriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16).to("cuda")
+pipeline = KandinskyV22Pipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16).to("cuda")
+
+prompt = "A alien cheeseburger creature eating itself, claymation, cinematic, moody lighting"
+negative_prompt = "low quality, bad quality" # optional to include a negative prompt, but results are usually better
+image_embeds, negative_image_embeds = prior_pipeline(prompt, negative_prompt, guidance_scale=1.0).to_tuple()
+```
+
+Pass the `image_embeds` and `negative_image_embeds` to the [`KandinskyV22Pipeline`] to generate an image:
+
+```py
+image = pipeline(image_embeds=image_embeds, negative_image_embeds=negative_image_embeds, height=768, width=768).images[0]
+image
+```
+
+
+
+
+
+
+
+
+🤗 Diffusers also provides an end-to-end API with the [`KandinskyCombinedPipeline`] and [`KandinskyV22CombinedPipeline`], meaning you don't have to separately load the prior and text-to-image pipeline. The combined pipeline automatically loads both the prior model and the decoder. You can still set different values for the prior pipeline with the `prior_guidance_scale` and `prior_num_inference_steps` parameters if you want.
+
+Use the [`AutoPipelineForText2Image`] to automatically call the combined pipelines under the hood:
+
+
+
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+
+pipeline = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16)
+pipeline.enable_model_cpu_offload()
+
+prompt = "A alien cheeseburger creature eating itself, claymation, cinematic, moody lighting"
+negative_prompt = "low quality, bad quality"
+
+image = pipeline(prompt=prompt, negative_prompt=negative_prompt, prior_guidance_scale=1.0, guidance_scale=4.0, height=768, width=768).images[0]
+image
+```
+
+
+
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+
+pipeline = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16)
+pipeline.enable_model_cpu_offload()
+
+prompt = "A alien cheeseburger creature eating itself, claymation, cinematic, moody lighting"
+negative_prompt = "low quality, bad quality"
+
+image = pipeline(prompt=prompt, negative_prompt=negative_prompt, prior_guidance_scale=1.0, guidance_scale=4.0, height=768, width=768).images[0]
+image
+```
+
+
+
+
+## Image-to-image
+
+For image-to-image, pass the initial image and text prompt to condition the image to the pipeline. Start by loading the prior pipeline:
+
+
+
+
+```py
+import torch
+from diffusers import KandinskyImg2ImgPipeline, KandinskyPriorPipeline
+
+prior_pipeline = KandinskyPriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16, use_safetensors=True).to("cuda")
+pipeline = KandinskyImg2ImgPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16, use_safetensors=True).to("cuda")
+```
+
+
+
+
+```py
+import torch
+from diffusers import KandinskyV22Img2ImgPipeline, KandinskyPriorPipeline
+
+prior_pipeline = KandinskyPriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16, use_safetensors=True).to("cuda")
+pipeline = KandinskyV22Img2ImgPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, use_safetensors=True).to("cuda")
+```
+
+
+
+
+Download an image to condition on:
+
+```py
+from diffusers.utils import load_image
+
+# download image
+url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+original_image = load_image(url)
+original_image = original_image.resize((768, 512))
+```
+
+
+
+
+
+Generate the `image_embeds` and `negative_image_embeds` with the prior pipeline:
+
+```py
+prompt = "A fantasy landscape, Cinematic lighting"
+negative_prompt = "low quality, bad quality"
+
+image_embeds, negative_image_embeds = prior_pipeline(prompt, negative_prompt).to_tuple()
+```
+
+Now pass the original image, and all the prompts and embeddings to the pipeline to generate an image:
+
+
+
+
+```py
+from diffusers.utils import make_image_grid
+
+image = pipeline(prompt, negative_prompt=negative_prompt, image=original_image, image_embeds=image_embeds, negative_image_embeds=negative_image_embeds, height=768, width=768, strength=0.3).images[0]
+make_image_grid([original_image.resize((512, 512)), image.resize((512, 512))], rows=1, cols=2)
+```
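+
+If you loaded the Kandinsky 2.2 pipelines above, the decoder doesn't take a `prompt` or `negative_prompt`; only the embeddings are passed. A minimal sketch of the equivalent [`KandinskyV22Img2ImgPipeline`] call:
+
+```py
+from diffusers.utils import make_image_grid
+
+# Kandinsky 2.2 decodes from the image embeddings alone (no prompt argument)
+image = pipeline(image=original_image, image_embeds=image_embeds, negative_image_embeds=negative_image_embeds, height=768, width=768, strength=0.3).images[0]
+make_image_grid([original_image.resize((512, 512)), image.resize((512, 512))], rows=1, cols=2)
+```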
+
+
+
+
+
+
+🤗 Diffusers also provides an end-to-end API with the [`KandinskyImg2ImgCombinedPipeline`] and [`KandinskyV22Img2ImgCombinedPipeline`], meaning you don't have to separately load the prior and image-to-image pipeline. The combined pipeline automatically loads both the prior model and the decoder. You can still set different values for the prior pipeline with the `prior_guidance_scale` and `prior_num_inference_steps` parameters if you want.
+
+Use the [`AutoPipelineForImage2Image`] to automatically call the combined pipelines under the hood:
+
+
+
+
+```py
+from diffusers import AutoPipelineForImage2Image
+from diffusers.utils import make_image_grid, load_image
+import torch
+
+pipeline = AutoPipelineForImage2Image.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16, use_safetensors=True)
+pipeline.enable_model_cpu_offload()
+
+prompt = "A fantasy landscape, Cinematic lighting"
+negative_prompt = "low quality, bad quality"
+
+url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+original_image = load_image(url)
+
+original_image.thumbnail((768, 768))
+
+image = pipeline(prompt=prompt, negative_prompt=negative_prompt, image=original_image, strength=0.3).images[0]
+make_image_grid([original_image.resize((512, 512)), image.resize((512, 512))], rows=1, cols=2)
+```
+
+
+
+
+```py
+from diffusers import AutoPipelineForImage2Image
+from diffusers.utils import make_image_grid, load_image
+import torch
+
+pipeline = AutoPipelineForImage2Image.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16)
+pipeline.enable_model_cpu_offload()
+
+prompt = "A fantasy landscape, Cinematic lighting"
+negative_prompt = "low quality, bad quality"
+
+url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+original_image = load_image(url)
+
+original_image.thumbnail((768, 768))
+
+image = pipeline(prompt=prompt, negative_prompt=negative_prompt, image=original_image, strength=0.3).images[0]
+make_image_grid([original_image.resize((512, 512)), image.resize((512, 512))], rows=1, cols=2)
+```
+
+
+
+
+## Inpainting
+
+
+
+⚠️ The Kandinsky models now use ⬜️ **white pixels** to represent the masked area instead of black pixels. If you are using [`KandinskyInpaintPipeline`] in production, you need to change the mask to use white pixels:
+
+```py
+# For PIL input
+import PIL.ImageOps
+mask = PIL.ImageOps.invert(mask)
+
+# For PyTorch and NumPy input
+mask = 1 - mask
+```
+
+
+
+For inpainting, you'll need the original image, a mask of the area to replace in the original image, and a text prompt of what to inpaint. Load the prior pipeline:
+
+
+
+
+```py
+from diffusers import KandinskyInpaintPipeline, KandinskyPriorPipeline
+from diffusers.utils import load_image, make_image_grid
+import torch
+import numpy as np
+from PIL import Image
+
+prior_pipeline = KandinskyPriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16, use_safetensors=True).to("cuda")
+pipeline = KandinskyInpaintPipeline.from_pretrained("kandinsky-community/kandinsky-2-1-inpaint", torch_dtype=torch.float16, use_safetensors=True).to("cuda")
+```
+
+
+
+
+```py
+from diffusers import KandinskyV22InpaintPipeline, KandinskyV22PriorPipeline
+from diffusers.utils import load_image, make_image_grid
+import torch
+import numpy as np
+from PIL import Image
+
+prior_pipeline = KandinskyV22PriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16, use_safetensors=True).to("cuda")
+pipeline = KandinskyV22InpaintPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder-inpaint", torch_dtype=torch.float16, use_safetensors=True).to("cuda")
+```
+
+
+
+
+Load an initial image and create a mask:
+
+```py
+init_image = load_image("https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png")
+mask = np.zeros((768, 768), dtype=np.float32)
+# mask area above cat's head
+mask[:250, 250:-250] = 1
+```
+
+Generate the embeddings with the prior pipeline:
+
+```py
+prompt = "a hat"
+prior_output = prior_pipeline(prompt)
+```
+
+Now pass the initial image, mask, and prompt and embeddings to the pipeline to generate an image:
+
+
+
+
+```py
+output_image = pipeline(prompt, image=init_image, mask_image=mask, **prior_output, height=768, width=768, num_inference_steps=150).images[0]
+mask = Image.fromarray((mask*255).astype('uint8'), 'L')
+make_image_grid([init_image, mask, output_image], rows=1, cols=3)
+```
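+
+If you're using the Kandinsky 2.2 pipelines instead, drop the `prompt` argument and pass only the embeddings. A minimal sketch with [`KandinskyV22InpaintPipeline`]:
+
+```py
+# Kandinsky 2.2 only accepts image_embeds/negative_image_embeds during decoding
+output_image = pipeline(image=init_image, mask_image=mask, **prior_output, height=768, width=768, num_inference_steps=150).images[0]
+mask = Image.fromarray((mask * 255).astype("uint8"), "L")
+make_image_grid([init_image, mask, output_image], rows=1, cols=3)
+```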
+
+
+
+
+
+
+You can also use the end-to-end [`KandinskyInpaintCombinedPipeline`] and [`KandinskyV22InpaintCombinedPipeline`] to call the prior and decoder pipelines together under the hood. Use the [`AutoPipelineForInpainting`] for this:
+
+
+
+
+```py
+import torch
+import numpy as np
+from PIL import Image
+from diffusers import AutoPipelineForInpainting
+from diffusers.utils import load_image, make_image_grid
+
+pipe = AutoPipelineForInpainting.from_pretrained("kandinsky-community/kandinsky-2-1-inpaint", torch_dtype=torch.float16)
+pipe.enable_model_cpu_offload()
+
+init_image = load_image("https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png")
+mask = np.zeros((768, 768), dtype=np.float32)
+# mask area above cat's head
+mask[:250, 250:-250] = 1
+prompt = "a hat"
+
+output_image = pipe(prompt=prompt, image=init_image, mask_image=mask).images[0]
+mask = Image.fromarray((mask*255).astype('uint8'), 'L')
+make_image_grid([init_image, mask, output_image], rows=1, cols=3)
+```
+
+
+
+
+```py
+import torch
+import numpy as np
+from PIL import Image
+from diffusers import AutoPipelineForInpainting
+from diffusers.utils import load_image, make_image_grid
+
+pipe = AutoPipelineForInpainting.from_pretrained("kandinsky-community/kandinsky-2-2-decoder-inpaint", torch_dtype=torch.float16)
+pipe.enable_model_cpu_offload()
+
+init_image = load_image("https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png")
+mask = np.zeros((768, 768), dtype=np.float32)
+# mask area above cat's head
+mask[:250, 250:-250] = 1
+prompt = "a hat"
+
+output_image = pipe(prompt=prompt, image=init_image, mask_image=mask).images[0]
+mask = Image.fromarray((mask*255).astype('uint8'), 'L')
+make_image_grid([init_image, mask, output_image], rows=1, cols=3)
+```
+
+
+
+
+## Interpolation
+
+Interpolation allows you to explore the latent space between the image and text embeddings, which is a cool way to see some of the prior model's intermediate outputs. Load the prior pipeline and two images you'd like to interpolate:
+
+
+
+
+```py
+from diffusers import KandinskyPriorPipeline, KandinskyPipeline
+from diffusers.utils import load_image, make_image_grid
+import torch
+
+prior_pipeline = KandinskyPriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16, use_safetensors=True).to("cuda")
+img_1 = load_image("https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png")
+img_2 = load_image("https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/starry_night.jpeg")
+make_image_grid([img_1.resize((512,512)), img_2.resize((512,512))], rows=1, cols=2)
+```
+
+
+
+
+```py
+from diffusers import KandinskyV22PriorPipeline, KandinskyV22Pipeline
+from diffusers.utils import load_image, make_image_grid
+import torch
+
+prior_pipeline = KandinskyV22PriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16, use_safetensors=True).to("cuda")
+img_1 = load_image("https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png")
+img_2 = load_image("https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/starry_night.jpeg")
+make_image_grid([img_1.resize((512,512)), img_2.resize((512,512))], rows=1, cols=2)
+```
+
+*The two images to interpolate: a cat and Van Gogh's Starry Night painting.*
+
+Specify the text or images to interpolate, and set the weights for each text or image. Experiment with the weights to see how they affect the interpolation!
+
+```py
+images_texts = ["a cat", img_1, img_2]
+weights = [0.3, 0.3, 0.4]
+```
+
+Call the `interpolate` function to generate the embeddings, and then pass them to the pipeline to generate the image:
+
+
+
+
+```py
+# prompt can be left empty
+prompt = ""
+prior_out = prior_pipeline.interpolate(images_texts, weights)
+
+pipeline = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16, use_safetensors=True).to("cuda")
+
+image = pipeline(prompt, **prior_out, height=768, width=768).images[0]
+image
+```
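+
+With Kandinsky 2.2, decode the interpolated embeddings with [`KandinskyV22Pipeline`] instead; no prompt is passed. A minimal sketch:
+
+```py
+prior_out = prior_pipeline.interpolate(images_texts, weights)
+
+pipeline = KandinskyV22Pipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, use_safetensors=True).to("cuda")
+
+# the interpolate() output carries image_embeds and negative_image_embeds
+image = pipeline(**prior_out, height=768, width=768).images[0]
+image
+```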
+
+
+
+
+
+
+## ControlNet
+
+
+
+⚠️ ControlNet is only supported for Kandinsky 2.2!
+
+
+
+ControlNet enables conditioning large pretrained diffusion models with additional inputs such as a depth map or edge detection. For example, you can condition Kandinsky 2.2 with a depth map so the model understands and preserves the structure of the depth image.
+
+Let's load an image and extract its depth map:
+
+```py
+from diffusers.utils import load_image
+
+img = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinskyv22/cat.png"
+).resize((768, 768))
+img
+```
+
+
+
+
+
+Then you can use the `depth-estimation` [`~transformers.Pipeline`] from 🤗 Transformers to process the image and retrieve the depth map:
+
+```py
+import torch
+import numpy as np
+
+from transformers import pipeline
+
+def make_hint(image, depth_estimator):
+ image = depth_estimator(image)["depth"]
+ image = np.array(image)
+ image = image[:, :, None]
+ image = np.concatenate([image, image, image], axis=2)
+ detected_map = torch.from_numpy(image).float() / 255.0
+ hint = detected_map.permute(2, 0, 1)
+ return hint
+
+depth_estimator = pipeline("depth-estimation")
+hint = make_hint(img, depth_estimator).unsqueeze(0).half().to("cuda")
+```
+
+### Text-to-image [[controlnet-text-to-image]]
+
+Load the prior pipeline and the [`KandinskyV22ControlnetPipeline`]:
+
+```py
+from diffusers import KandinskyV22PriorPipeline, KandinskyV22ControlnetPipeline
+
+prior_pipeline = KandinskyV22PriorPipeline.from_pretrained(
+ "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16, use_safetensors=True
+).to("cuda")
+
+pipeline = KandinskyV22ControlnetPipeline.from_pretrained(
+ "kandinsky-community/kandinsky-2-2-controlnet-depth", torch_dtype=torch.float16
+).to("cuda")
+```
+
+Generate the image embeddings from a prompt and negative prompt:
+
+```py
+prompt = "A robot, 4k photo"
+negative_prior_prompt = "lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature"
+
+generator = torch.Generator(device="cuda").manual_seed(43)
+
+image_emb, zero_image_emb = prior_pipeline(
+ prompt=prompt, negative_prompt=negative_prior_prompt, generator=generator
+).to_tuple()
+```
+
+Finally, pass the image embeddings and the depth image to the [`KandinskyV22ControlnetPipeline`] to generate an image:
+
+```py
+image = pipeline(image_embeds=image_emb, negative_image_embeds=zero_image_emb, hint=hint, num_inference_steps=50, generator=generator, height=768, width=768).images[0]
+image
+```
+
+
+
+
+
+### Image-to-image [[controlnet-image-to-image]]
+
+For image-to-image with ControlNet, you'll need to use:
+
+- [`KandinskyV22PriorEmb2EmbPipeline`] to generate the image embeddings from a text prompt and an image
+- [`KandinskyV22ControlnetImg2ImgPipeline`] to generate an image from the initial image and the image embeddings
+
+Process and extract a depth map of an initial image of a cat with the `depth-estimation` [`~transformers.Pipeline`] from 🤗 Transformers:
+
+```py
+import torch
+import numpy as np
+
+from diffusers import KandinskyV22PriorEmb2EmbPipeline, KandinskyV22ControlnetImg2ImgPipeline
+from diffusers.utils import load_image, make_image_grid
+from transformers import pipeline
+
+img = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinskyv22/cat.png"
+).resize((768, 768))
+
+def make_hint(image, depth_estimator):
+ image = depth_estimator(image)["depth"]
+ image = np.array(image)
+ image = image[:, :, None]
+ image = np.concatenate([image, image, image], axis=2)
+ detected_map = torch.from_numpy(image).float() / 255.0
+ hint = detected_map.permute(2, 0, 1)
+ return hint
+
+depth_estimator = pipeline("depth-estimation")
+hint = make_hint(img, depth_estimator).unsqueeze(0).half().to("cuda")
+```
+
+Load the prior pipeline and the [`KandinskyV22ControlnetImg2ImgPipeline`]:
+
+```py
+prior_pipeline = KandinskyV22PriorEmb2EmbPipeline.from_pretrained(
+ "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16, use_safetensors=True
+).to("cuda")
+
+pipeline = KandinskyV22ControlnetImg2ImgPipeline.from_pretrained(
+ "kandinsky-community/kandinsky-2-2-controlnet-depth", torch_dtype=torch.float16
+).to("cuda")
+```
+
+Pass a text prompt and the initial image to the prior pipeline to generate the image embeddings:
+
+```py
+prompt = "A robot, 4k photo"
+negative_prior_prompt = "lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature"
+
+generator = torch.Generator(device="cuda").manual_seed(43)
+
+img_emb = prior_pipeline(prompt=prompt, image=img, strength=0.85, generator=generator)
+negative_emb = prior_pipeline(prompt=negative_prior_prompt, image=img, strength=1, generator=generator)
+```
+
+Now you can run the [`KandinskyV22ControlnetImg2ImgPipeline`] to generate an image from the initial image and the image embeddings:
+
+```py
+image = pipeline(image=img, strength=0.5, image_embeds=img_emb.image_embeds, negative_image_embeds=negative_emb.image_embeds, hint=hint, num_inference_steps=50, generator=generator, height=768, width=768).images[0]
+make_image_grid([img.resize((512, 512)), image.resize((512, 512))], rows=1, cols=2)
+```
+
+
+
+
+
+## Optimizations
+
+Kandinsky is unique because it requires a prior pipeline to generate the mappings, and a second pipeline to decode the latents into an image. Optimization efforts should be focused on the second pipeline because that is where the bulk of the computation is done. Here are some tips to speed up Kandinsky and reduce its memory usage during inference.
+
+1. Enable [xFormers](../optimization/xformers) if you're using PyTorch < 2.0:
+
+```diff
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipe = DiffusionPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16)
++ pipe.enable_xformers_memory_efficient_attention()
+```
+
+2. Enable `torch.compile` if you're using PyTorch >= 2.0 to automatically use scaled dot-product attention (SDPA):
+
+```diff
+ pipe.unet.to(memory_format=torch.channels_last)
++ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+```
+
+This is the same as explicitly setting the attention processor to use [`~models.attention_processor.AttnAddedKVProcessor2_0`]:
+
+```py
+from diffusers.models.attention_processor import AttnAddedKVProcessor2_0
+
+pipe.unet.set_attn_processor(AttnAddedKVProcessor2_0())
+```
+
+3. Offload the model to the CPU with [`~KandinskyPriorPipeline.enable_model_cpu_offload`] to avoid out-of-memory errors:
+
+```diff
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipe = DiffusionPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16)
++ pipe.enable_model_cpu_offload()
+```
+
+4. By default, the text-to-image pipeline uses the [`DDIMScheduler`] but you can replace it with another scheduler like [`DDPMScheduler`] to see how that affects the tradeoff between inference speed and image quality:
+
+```py
+from diffusers import DDPMScheduler
+from diffusers import DiffusionPipeline
+
+scheduler = DDPMScheduler.from_pretrained("kandinsky-community/kandinsky-2-1", subfolder="ddpm_scheduler")
+pipe = DiffusionPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", scheduler=scheduler, torch_dtype=torch.float16, use_safetensors=True).to("cuda")
+```
diff --git a/diffusers/docs/source/en/using-diffusers/loading.md b/diffusers/docs/source/en/using-diffusers/loading.md
new file mode 100644
index 0000000000000000000000000000000000000000..d9e19a5bdd2a327ad94c112daaff0ffc37d9661d
--- /dev/null
+++ b/diffusers/docs/source/en/using-diffusers/loading.md
@@ -0,0 +1,485 @@
+
+
+# Load pipelines, models, and schedulers
+
+[[open-in-colab]]
+
+Having an easy way to use a diffusion system for inference is essential to 🧨 Diffusers. Diffusion systems often consist of multiple components like parameterized models, tokenizers, and schedulers that interact in complex ways. That is why we designed the [`DiffusionPipeline`] to wrap the complexity of the entire diffusion system into an easy-to-use API, while remaining flexible enough to be adapted for other use cases, such as loading each component individually as building blocks to assemble your own diffusion system.
+
+Everything you need for inference or training is accessible with the `from_pretrained()` method.
+
+This guide will show you how to load:
+
+- pipelines from the Hub and locally
+- different components into a pipeline
+- checkpoint variants such as different floating point types or non-exponential mean averaged (non-EMA) weights
+- models and schedulers
+
+## Diffusion Pipeline
+
+
+
+💡 Skip to the [DiffusionPipeline explained](#diffusionpipeline-explained) section if you are interested in learning in more detail about how the [`DiffusionPipeline`] class works.
+
+
+
+The [`DiffusionPipeline`] class is the simplest and most generic way to load the latest trending diffusion model from the [Hub](https://huggingface.co/models?library=diffusers&sort=trending). The [`DiffusionPipeline.from_pretrained`] method automatically detects the correct pipeline class from the checkpoint, downloads, and caches all the required configuration and weight files, and returns a pipeline instance ready for inference.
+
+```python
+from diffusers import DiffusionPipeline
+
+repo_id = "runwayml/stable-diffusion-v1-5"
+pipe = DiffusionPipeline.from_pretrained(repo_id, use_safetensors=True)
+```
+
+You can also load a checkpoint with its specific pipeline class. The example above loaded a Stable Diffusion model; to get the same result, use the [`StableDiffusionPipeline`] class:
+
+```python
+from diffusers import StableDiffusionPipeline
+
+repo_id = "runwayml/stable-diffusion-v1-5"
+pipe = StableDiffusionPipeline.from_pretrained(repo_id, use_safetensors=True)
+```
+
+A checkpoint (such as [`CompVis/stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4) or [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5)) may also be used for more than one task, like text-to-image or image-to-image. To differentiate what task you want to use the checkpoint for, you have to load it directly with its corresponding task-specific pipeline class:
+
+```python
+from diffusers import StableDiffusionImg2ImgPipeline
+
+repo_id = "runwayml/stable-diffusion-v1-5"
+pipe = StableDiffusionImg2ImgPipeline.from_pretrained(repo_id)
+```
+
+### Local pipeline
+
+To load a diffusion pipeline locally, use [`git-lfs`](https://git-lfs.github.com/) to manually download the checkpoint (in this case, [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5)) to your local disk. This creates a local folder, `./stable-diffusion-v1-5`, on your disk:
+
+```bash
+git-lfs install
+git clone https://huggingface.co/runwayml/stable-diffusion-v1-5
+```
+
+Then pass the local path to [`~DiffusionPipeline.from_pretrained`]:
+
+```python
+from diffusers import DiffusionPipeline
+
+repo_id = "./stable-diffusion-v1-5"
+stable_diffusion = DiffusionPipeline.from_pretrained(repo_id, use_safetensors=True)
+```
+
+The [`~DiffusionPipeline.from_pretrained`] method won't download any files from the Hub when it detects a local path, but this also means it won't download and cache the latest changes to a checkpoint.
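+
+If you prefer to keep using the Hub id while still guaranteeing no network access, the `local_files_only` flag gives similar behavior (a minimal sketch; it raises an error if the files aren't already cached):
+
+```python
+from diffusers import DiffusionPipeline
+
+# only read from the local cache; never download
+pipe = DiffusionPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5", use_safetensors=True, local_files_only=True
+)
+```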
+
+### Swap components in a pipeline
+
+You can customize the default components of any pipeline with another compatible component. Customization is important because:
+
+- Changing the scheduler is important for exploring the trade-off between generation speed and quality.
+- Different components of a model are typically trained independently and you can swap out a component with a better-performing one.
+- During finetuning, usually only some components - like the UNet or text encoder - are trained.
+
+To find out which schedulers are compatible for customization, you can use the `compatibles` method:
+
+```py
+from diffusers import DiffusionPipeline
+
+repo_id = "runwayml/stable-diffusion-v1-5"
+stable_diffusion = DiffusionPipeline.from_pretrained(repo_id, use_safetensors=True)
+stable_diffusion.scheduler.compatibles
+```
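+
+Any of the compatible classes can be swapped in place by reusing the current scheduler's configuration. A minimal sketch with `from_config`:
+
+```python
+from diffusers import EulerDiscreteScheduler
+
+# build a new scheduler from the existing scheduler's config and swap it in
+stable_diffusion.scheduler = EulerDiscreteScheduler.from_config(stable_diffusion.scheduler.config)
+```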
+
+Let's use the [`SchedulerMixin.from_pretrained`] method to replace the default [`PNDMScheduler`] with a more performant scheduler, [`EulerDiscreteScheduler`]. The `subfolder="scheduler"` argument is required to load the scheduler configuration from the correct [subfolder](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/scheduler) of the pipeline repository.
+
+Then you can pass the new [`EulerDiscreteScheduler`] instance to the `scheduler` argument in [`DiffusionPipeline`]:
+
+```python
+from diffusers import DiffusionPipeline, EulerDiscreteScheduler
+
+repo_id = "runwayml/stable-diffusion-v1-5"
+scheduler = EulerDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
+stable_diffusion = DiffusionPipeline.from_pretrained(repo_id, scheduler=scheduler, use_safetensors=True)
+```
+
+### Safety checker
+
+Diffusion models like Stable Diffusion can generate harmful content, which is why 🧨 Diffusers has a [safety checker](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py) to check generated outputs against known hardcoded NSFW content. If you'd like to disable the safety checker for whatever reason, pass `None` to the `safety_checker` argument:
+
+```python
+from diffusers import DiffusionPipeline
+
+repo_id = "runwayml/stable-diffusion-v1-5"
+stable_diffusion = DiffusionPipeline.from_pretrained(repo_id, safety_checker=None, use_safetensors=True)
+"""
+You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling it only for use cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .
+"""
+```
+
+### Reuse components across pipelines
+
+You can also reuse the same components in multiple pipelines to avoid loading the weights into RAM twice. Use the [`~DiffusionPipeline.components`] method to save the components:
+
+```python
+from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
+
+model_id = "runwayml/stable-diffusion-v1-5"
+stable_diffusion_txt2img = StableDiffusionPipeline.from_pretrained(model_id, use_safetensors=True)
+
+components = stable_diffusion_txt2img.components
+```
+
+Then you can pass the `components` to another pipeline without reloading the weights into RAM:
+
+```py
+stable_diffusion_img2img = StableDiffusionImg2ImgPipeline(**components)
+```
+
+You can also pass the components individually to the pipeline if you want more flexibility over which components to reuse or disable. For example, to reuse the text-to-image pipeline's components in the image-to-image pipeline, except for the safety checker and feature extractor:
+
+```py
+from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
+
+model_id = "runwayml/stable-diffusion-v1-5"
+stable_diffusion_txt2img = StableDiffusionPipeline.from_pretrained(model_id, use_safetensors=True)
+stable_diffusion_img2img = StableDiffusionImg2ImgPipeline(
+ vae=stable_diffusion_txt2img.vae,
+ text_encoder=stable_diffusion_txt2img.text_encoder,
+ tokenizer=stable_diffusion_txt2img.tokenizer,
+ unet=stable_diffusion_txt2img.unet,
+ scheduler=stable_diffusion_txt2img.scheduler,
+ safety_checker=None,
+ feature_extractor=None,
+ requires_safety_checker=False,
+)
+```
+
+## Checkpoint variants
+
+A checkpoint variant is usually a checkpoint whose weights are:
+
+- Stored in a different floating point type for lower precision and lower storage, such as [`torch.float16`](https://pytorch.org/docs/stable/tensors.html#data-types), because it only requires half the bandwidth and storage to download. You can't use this variant if you're continuing training or using a CPU.
+- Non-exponential mean averaged (non-EMA) weights, which shouldn't be used for inference. You should use these to continue fine-tuning a model.
+
+
+
+💡 When the checkpoints have identical model structures, but they were trained on different datasets and with a different training setup, they should be stored in separate repositories instead of variations (for example, [`stable-diffusion-v1-4`] and [`stable-diffusion-v1-5`]).
+
+
+
+Otherwise, a variant is **identical** to the original checkpoint. They have exactly the same serialization format (like [Safetensors](./using_safetensors)), model structure, and weights that have identical tensor shapes.
+
+| **checkpoint type** | **weight name** | **argument for loading weights** |
+|---------------------|-------------------------------------|----------------------------------|
+| original | diffusion_pytorch_model.bin | |
+| floating point | diffusion_pytorch_model.fp16.bin | `variant`, `torch_dtype` |
+| non-EMA | diffusion_pytorch_model.non_ema.bin | `variant` |
+
+There are two important arguments to know for loading variants:
+
+- `torch_dtype` defines the floating point precision of the loaded checkpoints. For example, if you want to save bandwidth by loading a `fp16` variant, you should specify `torch_dtype=torch.float16` to *convert the weights* to `fp16`. Otherwise, the `fp16` weights are converted to the default `fp32` precision. You can also load the original checkpoint without defining the `variant` argument, and convert it to `fp16` with `torch_dtype=torch.float16`. In this case, the default `fp32` weights are downloaded first, and then they're converted to `fp16` after loading.
+
+- `variant` defines which files should be loaded from the repository. For example, if you want to load a `non_ema` variant from the [`diffusers/stable-diffusion-variants`](https://huggingface.co/diffusers/stable-diffusion-variants/tree/main/unet) repository, you should specify `variant="non_ema"` to download the `non_ema` files.
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+# load fp16 variant
+stable_diffusion = DiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", variant="fp16", torch_dtype=torch.float16, use_safetensors=True
+)
+# load non_ema variant
+stable_diffusion = DiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", variant="non_ema", use_safetensors=True
+)
+```
+
+To save a checkpoint stored in a different floating-point type or as a non-EMA variant, use the [`DiffusionPipeline.save_pretrained`] method and specify the `variant` argument. You should try to save a variant to the same folder as the original checkpoint, so you can load both from the same folder:
+
+```python
+from diffusers import DiffusionPipeline
+
+# save as fp16 variant
+stable_diffusion.save_pretrained("runwayml/stable-diffusion-v1-5", variant="fp16")
+# save as non-ema variant
+stable_diffusion.save_pretrained("runwayml/stable-diffusion-v1-5", variant="non_ema")
+```
+
+If you don't save the variant to an existing folder, you must specify the `variant` argument; otherwise, it'll throw an `Exception` because it can't find the original checkpoint:
+
+```python
+# 👎 this won't work
+stable_diffusion = DiffusionPipeline.from_pretrained(
+ "./stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
+)
+# 👍 this works
+stable_diffusion = DiffusionPipeline.from_pretrained(
+ "./stable-diffusion-v1-5", variant="fp16", torch_dtype=torch.float16, use_safetensors=True
+)
+```
+
+
+
+## Models
+
+Models are loaded from the [`ModelMixin.from_pretrained`] method, which downloads and caches the latest version of the model weights and configurations. If the latest files are available in the local cache, [`~ModelMixin.from_pretrained`] reuses files in the cache instead of re-downloading them.
+
+Models can be loaded from a subfolder with the `subfolder` argument. For example, the model weights for `runwayml/stable-diffusion-v1-5` are stored in the [`unet`](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/unet) subfolder:
+
+```python
+from diffusers import UNet2DConditionModel
+
+repo_id = "runwayml/stable-diffusion-v1-5"
+model = UNet2DConditionModel.from_pretrained(repo_id, subfolder="unet", use_safetensors=True)
+```
+
+Or directly from a repository's [directory](https://huggingface.co/google/ddpm-cifar10-32/tree/main):
+
+```python
+from diffusers import UNet2DModel
+
+repo_id = "google/ddpm-cifar10-32"
+model = UNet2DModel.from_pretrained(repo_id, use_safetensors=True)
+```
+
+You can also load and save model variants by specifying the `variant` argument in [`ModelMixin.from_pretrained`] and [`ModelMixin.save_pretrained`]:
+
+```python
+from diffusers import UNet2DConditionModel
+
+model = UNet2DConditionModel.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", subfolder="unet", variant="non_ema", use_safetensors=True
+)
+model.save_pretrained("./local-unet", variant="non_ema")
+```
+
+## Schedulers
+
+Schedulers are loaded from the [`SchedulerMixin.from_pretrained`] method, and unlike models, schedulers are **not parameterized** or **trained**; they are defined by a configuration file.
+
+Loading schedulers does not consume any significant amount of memory and the same configuration file can be used for a variety of different schedulers.
+For example, the following schedulers are compatible with [`StableDiffusionPipeline`], which means you can load the same scheduler configuration file in any of these classes:
+
+```python
+from diffusers import StableDiffusionPipeline
+from diffusers import (
+ DDPMScheduler,
+ DDIMScheduler,
+ PNDMScheduler,
+ LMSDiscreteScheduler,
+ EulerAncestralDiscreteScheduler,
+ EulerDiscreteScheduler,
+ DPMSolverMultistepScheduler,
+)
+
+repo_id = "runwayml/stable-diffusion-v1-5"
+
+ddpm = DDPMScheduler.from_pretrained(repo_id, subfolder="scheduler")
+ddim = DDIMScheduler.from_pretrained(repo_id, subfolder="scheduler")
+pndm = PNDMScheduler.from_pretrained(repo_id, subfolder="scheduler")
+lms = LMSDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
+euler_anc = EulerAncestralDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
+euler = EulerDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
+dpm = DPMSolverMultistepScheduler.from_pretrained(repo_id, subfolder="scheduler")
+
+# replace `dpm` with any of `ddpm`, `ddim`, `pndm`, `lms`, `euler_anc`, `euler`
+pipeline = StableDiffusionPipeline.from_pretrained(repo_id, scheduler=dpm, use_safetensors=True)
+```
+
+## DiffusionPipeline explained
+
+As a class method, [`DiffusionPipeline.from_pretrained`] is responsible for two things:
+
+- Download the latest version of the folder structure required for inference and cache it. If the latest folder structure is available in the local cache, [`DiffusionPipeline.from_pretrained`] reuses the cache and won't redownload the files.
+- Load the cached weights into the correct pipeline [class](../api/pipelines/overview#diffusers-summary) - retrieved from the `model_index.json` file - and return an instance of it.
+
+The pipelines' underlying folder structure corresponds directly with their class instances. For example, the [`StableDiffusionPipeline`] corresponds to the folder structure in [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5).
+
+```python
+from diffusers import DiffusionPipeline
+
+repo_id = "runwayml/stable-diffusion-v1-5"
+pipeline = DiffusionPipeline.from_pretrained(repo_id, use_safetensors=True)
+print(pipeline)
+```
+
+You'll see the pipeline is an instance of [`StableDiffusionPipeline`], which consists of seven components:
+
+- `"feature_extractor"`: a [`~transformers.CLIPImageProcessor`] from 🤗 Transformers.
+- `"safety_checker"`: a [component](https://github.com/huggingface/diffusers/blob/e55687e1e15407f60f32242027b7bb8170e58266/src/diffusers/pipelines/stable_diffusion/safety_checker.py#L32) for screening against harmful content.
+- `"scheduler"`: an instance of [`PNDMScheduler`].
+- `"text_encoder"`: a [`~transformers.CLIPTextModel`] from 🤗 Transformers.
+- `"tokenizer"`: a [`~transformers.CLIPTokenizer`] from 🤗 Transformers.
+- `"unet"`: an instance of [`UNet2DConditionModel`].
+- `"vae"`: an instance of [`AutoencoderKL`].
+
+```json
+StableDiffusionPipeline {
+ "feature_extractor": [
+ "transformers",
+ "CLIPImageProcessor"
+ ],
+ "safety_checker": [
+ "stable_diffusion",
+ "StableDiffusionSafetyChecker"
+ ],
+ "scheduler": [
+ "diffusers",
+ "PNDMScheduler"
+ ],
+ "text_encoder": [
+ "transformers",
+ "CLIPTextModel"
+ ],
+ "tokenizer": [
+ "transformers",
+ "CLIPTokenizer"
+ ],
+ "unet": [
+ "diffusers",
+ "UNet2DConditionModel"
+ ],
+ "vae": [
+ "diffusers",
+ "AutoencoderKL"
+ ]
+}
+```
+
+Compare the components of the pipeline instance to the [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main) folder structure, and you'll see there is a separate folder for each of the components in the repository:
+
+```
+.
+├── feature_extractor
+│   └── preprocessor_config.json
+├── model_index.json
+├── safety_checker
+│   ├── config.json
+│   ├── model.fp16.safetensors
+│   ├── model.safetensors
+│   ├── pytorch_model.bin
+│   └── pytorch_model.fp16.bin
+├── scheduler
+│   └── scheduler_config.json
+├── text_encoder
+│   ├── config.json
+│   ├── model.fp16.safetensors
+│   ├── model.safetensors
+│   ├── pytorch_model.bin
+│   └── pytorch_model.fp16.bin
+├── tokenizer
+│   ├── merges.txt
+│   ├── special_tokens_map.json
+│   ├── tokenizer_config.json
+│   └── vocab.json
+├── unet
+│   ├── config.json
+│   ├── diffusion_pytorch_model.bin
+│   ├── diffusion_pytorch_model.fp16.bin
+│   ├── diffusion_pytorch_model.fp16.safetensors
+│   ├── diffusion_pytorch_model.non_ema.bin
+│   ├── diffusion_pytorch_model.non_ema.safetensors
+│   └── diffusion_pytorch_model.safetensors
+└── vae
+    ├── config.json
+    ├── diffusion_pytorch_model.bin
+    ├── diffusion_pytorch_model.fp16.bin
+    ├── diffusion_pytorch_model.fp16.safetensors
+    └── diffusion_pytorch_model.safetensors
+```
+
+You can access each of the components of the pipeline as an attribute to view its configuration:
+
+```py
+pipeline.tokenizer
+CLIPTokenizer(
+ name_or_path="/root/.cache/huggingface/hub/models--runwayml--stable-diffusion-v1-5/snapshots/39593d5650112b4cc580433f6b0435385882d819/tokenizer",
+ vocab_size=49408,
+ model_max_length=77,
+ is_fast=False,
+ padding_side="right",
+ truncation_side="right",
+ special_tokens={
+ "bos_token": AddedToken("<|startoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True),
+ "eos_token": AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True),
+ "unk_token": AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True),
+ "pad_token": "<|endoftext|>",
+ },
+ clean_up_tokenization_spaces=True
+)
+```
+
+Every pipeline expects a [`model_index.json`](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/model_index.json) file that tells the [`DiffusionPipeline`]:
+
+- which pipeline class to load from `_class_name`
+- which version of 🧨 Diffusers was used to create the model in `_diffusers_version`
+- what components from which library are stored in the subfolders (`name` corresponds to the component and subfolder name, `library` corresponds to the name of the library to load the class from, and `class` corresponds to the class name)
+
+```json
+{
+ "_class_name": "StableDiffusionPipeline",
+ "_diffusers_version": "0.6.0",
+ "feature_extractor": [
+ "transformers",
+ "CLIPImageProcessor"
+ ],
+ "safety_checker": [
+ "stable_diffusion",
+ "StableDiffusionSafetyChecker"
+ ],
+ "scheduler": [
+ "diffusers",
+ "PNDMScheduler"
+ ],
+ "text_encoder": [
+ "transformers",
+ "CLIPTextModel"
+ ],
+ "tokenizer": [
+ "transformers",
+ "CLIPTokenizer"
+ ],
+ "unet": [
+ "diffusers",
+ "UNet2DConditionModel"
+ ],
+ "vae": [
+ "diffusers",
+ "AutoencoderKL"
+ ]
+}
+```
diff --git a/diffusers/docs/source/en/using-diffusers/loading_adapters.md b/diffusers/docs/source/en/using-diffusers/loading_adapters.md
new file mode 100644
index 0000000000000000000000000000000000000000..c14b38a9dd89e31f9356928130210ded757b6f46
--- /dev/null
+++ b/diffusers/docs/source/en/using-diffusers/loading_adapters.md
@@ -0,0 +1,637 @@
+
+
+# Load adapters
+
+[[open-in-colab]]
+
+There are several [training](../training/overview) techniques for personalizing diffusion models to generate images of a specific subject or images in certain styles. Each of these training methods produces a different type of adapter. Some of the adapters generate an entirely new model, while other adapters only modify a smaller set of embeddings or weights. This means the loading process for each adapter is also different.
+
+This guide will show you how to load DreamBooth, textual inversion, and LoRA weights.
+
+
+
+Feel free to browse the [Stable Diffusion Conceptualizer](https://huggingface.co/spaces/sd-concepts-library/stable-diffusion-conceptualizer), [LoRA the Explorer](https://huggingface.co/spaces/multimodalart/LoraTheExplorer), and the [Diffusers Models Gallery](https://huggingface.co/spaces/huggingface-projects/diffusers-gallery) for checkpoints and embeddings to use.
+
+
+
+## DreamBooth
+
+[DreamBooth](https://dreambooth.github.io/) finetunes an *entire diffusion model* on just several images of a subject to generate images of that subject in new styles and settings. This method works by using a special word in the prompt that the model learns to associate with the subject image. Of all the training methods, DreamBooth produces the largest file size (usually a few GBs) because it is a full checkpoint model.
+
+Let's load the [herge_style](https://huggingface.co/sd-dreambooth-library/herge-style) checkpoint, which is trained on just 10 images drawn by Hergé, to generate images in that style. For it to work, you need to include the special word `herge_style` in your prompt to trigger the checkpoint:
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+
+pipeline = AutoPipelineForText2Image.from_pretrained("sd-dreambooth-library/herge-style", torch_dtype=torch.float16).to("cuda")
+prompt = "A cute herge_style brown bear eating a slice of pizza, stunning color scheme, masterpiece, illustration"
+image = pipeline(prompt).images[0]
+image
+```
+
+
+
+
+
+## Textual inversion
+
+[Textual inversion](https://textual-inversion.github.io/) is very similar to DreamBooth and it can also personalize a diffusion model to generate certain concepts (styles, objects) from just a few images. This method works by training and finding new embeddings that represent the images you provide with a special word in the prompt. As a result, the diffusion model weights stay the same and the training process produces a relatively tiny (a few KBs) file.
+
+Because textual inversion creates embeddings, it cannot be used on its own like DreamBooth and requires another model.
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+
+pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
+```
+
+Now you can load the textual inversion embeddings with the [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] method and generate some images. Let's load the [sd-concepts-library/gta5-artwork](https://huggingface.co/sd-concepts-library/gta5-artwork) embeddings and you'll need to include the special word `<gta5-artwork>` in your prompt to trigger it:
+
+```py
+pipeline.load_textual_inversion("sd-concepts-library/gta5-artwork")
+prompt = "A cute brown bear eating a slice of pizza, stunning color scheme, masterpiece, illustration, style"
+image = pipeline(prompt).images[0]
+image
+```
+
+
+
+
+
+Textual inversion can also be trained on undesirable things to create *negative embeddings* that discourage a model from generating images with those undesirable traits, like blurry images or extra fingers on a hand. This can be an easy way to quickly improve your prompt. You'll also load the embeddings with [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`], but this time, you'll need two more parameters:
+
+- `weight_name`: specifies the weight file to load if the file was saved in the 🤗 Diffusers format with a specific name or if the file is stored in the A1111 format
+- `token`: specifies the special word to use in the prompt to trigger the embeddings
+
+Let's load the [sayakpaul/EasyNegative-test](https://huggingface.co/sayakpaul/EasyNegative-test) embeddings:
+
+```py
+pipeline.load_textual_inversion(
+ "sayakpaul/EasyNegative-test", weight_name="EasyNegative.safetensors", token="EasyNegative"
+)
+```
+
+Now you can use the `token` to generate an image with the negative embeddings:
+
+```py
+prompt = "A cute brown bear eating a slice of pizza, stunning color scheme, masterpiece, illustration, EasyNegative"
+negative_prompt = "EasyNegative"
+
+image = pipeline(prompt, negative_prompt=negative_prompt, num_inference_steps=50).images[0]
+image
+```
+
+
+
+
+
+## LoRA
+
+[Low-Rank Adaptation (LoRA)](https://huggingface.co/papers/2106.09685) is a popular training technique because it is fast and generates smaller file sizes (a couple hundred MBs). Like the other methods in this guide, LoRA can train a model to learn new styles from just a few images. It works by inserting new weights into the diffusion model and then only the new weights are trained instead of the entire model. This makes LoRAs faster to train and easier to store.
+
+
+
+LoRA is a very general training technique that can be used with other training methods. For example, it is common to train a model with DreamBooth and LoRA.
+
+
+
+LoRAs also need to be used with another model:
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+
+pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
+```
+
+Then use the [`~loaders.LoraLoaderMixin.load_lora_weights`] method to load the [ostris/super-cereal-sdxl-lora](https://huggingface.co/ostris/super-cereal-sdxl-lora) weights and specify the weights filename from the repository:
+
+```py
+pipeline.load_lora_weights("ostris/super-cereal-sdxl-lora", weight_name="cereal_box_sdxl_v1.safetensors")
+prompt = "bears, pizza bites"
+image = pipeline(prompt).images[0]
+image
+```
+
+
+
+
+
+The [`~loaders.LoraLoaderMixin.load_lora_weights`] method loads LoRA weights into both the UNet and text encoder. It is the preferred way to load LoRAs because it can handle cases where:
+
+- the LoRA weights don't have separate identifiers for the UNet and text encoder
+- the LoRA weights have separate identifiers for the UNet and text encoder
+
+But if you only need to load LoRA weights into the UNet, then you can use the [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`] method. Let's load the [jbilcke-hf/sdxl-cinematic-1](https://huggingface.co/jbilcke-hf/sdxl-cinematic-1) LoRA:
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+
+pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
+pipeline.unet.load_attn_procs("jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors")
+
+# use cnmt in the prompt to trigger the LoRA
+prompt = "A cute cnmt eating a slice of pizza, stunning color scheme, masterpiece, illustration"
+image = pipeline(prompt).images[0]
+image
+```
+
+
+
+
+
+
+
+For both [`~loaders.LoraLoaderMixin.load_lora_weights`] and [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`], you can pass the `cross_attention_kwargs={"scale": 0.5}` parameter to adjust how much of the LoRA weights to use. A value of `0` is the same as only using the base model weights, and a value of `1` is equivalent to using the fully finetuned LoRA.
+
+
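+
+For example, a minimal sketch of dialing the LoRA influence down to half strength with the pipeline loaded above:
+
+```py
+# 0.0 is roughly the base model only, 1.0 is the full LoRA effect
+image = pipeline(
+    "A cute cnmt eating a slice of pizza, stunning color scheme, masterpiece, illustration",
+    cross_attention_kwargs={"scale": 0.5},
+).images[0]
+```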
+
+To unload the LoRA weights, use the [`~loaders.LoraLoaderMixin.unload_lora_weights`] method to discard the LoRA weights and restore the model to its original weights:
+
+```py
+pipeline.unload_lora_weights()
+```
+
+### Load multiple LoRAs
+
+It can be fun to use multiple LoRAs together to create something entirely new and unique. The [`~loaders.LoraLoaderMixin.fuse_lora`] method allows you to fuse the LoRA weights with the original weights of the underlying model.
+
+
+
+Fusing the weights can lead to a speedup in inference latency because you don't need to separately load the base model and LoRA! You can save your fused pipeline with [`~DiffusionPipeline.save_pretrained`] to avoid loading and fusing the weights every time you want to use the model.
+
+
+
+Load an initial model:
+
+```py
+from diffusers import StableDiffusionXLPipeline, AutoencoderKL
+import torch
+
+vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
+pipeline = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ vae=vae,
+ torch_dtype=torch.float16,
+).to("cuda")
+```
+
+Next, load the LoRA checkpoint and fuse it with the original weights. The `lora_scale` parameter controls how much the output is scaled by the LoRA weights. It is important to make the `lora_scale` adjustments in the [`~loaders.LoraLoaderMixin.fuse_lora`] method because it won't work if you try to pass `scale` to the `cross_attention_kwargs` in the pipeline.
+
+If you need to reset the original model weights for any reason (for example, to use a different `lora_scale`), you should use the [`~loaders.LoraLoaderMixin.unfuse_lora`] method.
+
+```py
+pipeline.load_lora_weights("ostris/ikea-instructions-lora-sdxl")
+pipeline.fuse_lora(lora_scale=0.7)
+
+# to unfuse the LoRA weights
+pipeline.unfuse_lora()
+```
+
+Then fuse this pipeline with the next set of LoRA weights:
+
+```py
+pipeline.load_lora_weights("ostris/super-cereal-sdxl-lora")
+pipeline.fuse_lora(lora_scale=0.7)
+```
+
+
+
+You can't unfuse multiple LoRA checkpoints, so if you need to reset the model to its original weights, you'll need to reload it.
+
+
+
+Now you can generate an image that uses the weights from both LoRAs:
+
+```py
+prompt = "A cute brown bear eating a slice of pizza, stunning color scheme, masterpiece, illustration"
+image = pipeline(prompt).images[0]
+image
+```
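+
+Because both LoRAs are now fused into the base weights, you can optionally persist the pipeline so you don't have to re-fuse on every startup (a minimal sketch; the output directory name is arbitrary):
+
+```py
+# save the fused weights to a local folder and reload it later with from_pretrained
+pipeline.save_pretrained("./sdxl-fused-loras")
+```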
+
+### 🤗 PEFT
+
+
+
+Read the [Inference with 🤗 PEFT](../tutorials/using_peft_for_inference) tutorial to learn more about its integration with 🤗 Diffusers and how you can easily work with and juggle multiple adapters. You'll need to install 🤗 Diffusers and PEFT from source to run the example in this section.
+
+
+
+Another way you can load and use multiple LoRAs is to specify the `adapter_name` parameter in [`~loaders.LoraLoaderMixin.load_lora_weights`]. This method takes advantage of the 🤗 PEFT integration. For example, load and name both LoRA weights:
+
+```py
+from diffusers import DiffusionPipeline
+import torch
+
+pipeline = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
+pipeline.load_lora_weights("ostris/ikea-instructions-lora-sdxl", weight_name="ikea_instructions_xl_v1_5.safetensors", adapter_name="ikea")
+pipeline.load_lora_weights("ostris/super-cereal-sdxl-lora", weight_name="cereal_box_sdxl_v1.safetensors", adapter_name="cereal")
+```
+
+Now use the [`~loaders.UNet2DConditionLoadersMixin.set_adapters`] method to activate both LoRAs, and configure how much weight each LoRA should have on the output:
+
+```py
+pipeline.set_adapters(["ikea", "cereal"], adapter_weights=[0.7, 0.5])
+```
+
+Then, generate an image:
+
+```py
+prompt = "A cute brown bear eating a slice of pizza, stunning color scheme, masterpiece, illustration"
+image = pipeline(prompt, num_inference_steps=30, cross_attention_kwargs={"scale": 1.0}).images[0]
+image
+```
+
+### Kohya and TheLastBen
+
+Other popular LoRA trainers from the community include those by [Kohya](https://github.com/kohya-ss/sd-scripts/) and [TheLastBen](https://github.com/TheLastBen/fast-stable-diffusion). These trainers create different LoRA checkpoints than those trained by 🤗 Diffusers, but they can still be loaded in the same way.
+
+Let's download the [Blueprintify SD XL 1.0](https://civitai.com/models/150986/blueprintify-sd-xl-10) checkpoint from [Civitai](https://civitai.com/):
+
+```sh
+!wget https://civitai.com/api/download/models/168776 -O blueprintify-sd-xl-10.safetensors
+```
+
+Load the LoRA checkpoint with the [`~loaders.LoraLoaderMixin.load_lora_weights`] method, and specify the filename in the `weight_name` parameter:
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+
+pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
+pipeline.load_lora_weights("path/to/weights", weight_name="blueprintify-sd-xl-10.safetensors")
+```
+
+Generate an image:
+
+```py
+# use bl3uprint in the prompt to trigger the LoRA
+prompt = "bl3uprint, a highly detailed blueprint of the eiffel tower, explaining how to build all parts, many txt, blueprint grid backdrop"
+image = pipeline(prompt).images[0]
+image
+```
+
+
+
+Some limitations of using Kohya LoRAs with 🤗 Diffusers include:
+
+- Images may not look like those generated by UIs - like ComfyUI - for multiple reasons, which are explained [here](https://github.com/huggingface/diffusers/pull/4287/#issuecomment-1655110736).
+- [LyCORIS checkpoints](https://github.com/KohakuBlueleaf/LyCORIS) aren't fully supported. The [`~loaders.LoraLoaderMixin.load_lora_weights`] method loads LyCORIS checkpoints with LoRA and LoCon modules, but Hada and LoKR are not supported.
+
+
+
+Loading a checkpoint from TheLastBen is very similar. For example, to load the [TheLastBen/William_Eggleston_Style_SDXL](https://huggingface.co/TheLastBen/William_Eggleston_Style_SDXL) checkpoint:
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+
+pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
+pipeline.load_lora_weights("TheLastBen/William_Eggleston_Style_SDXL", weight_name="wegg.safetensors")
+
+# use by william eggleston in the prompt to trigger the LoRA
+prompt = "a house by william eggleston, sunrays, beautiful, sunlight, sunrays, beautiful"
+image = pipeline(prompt=prompt).images[0]
+image
+```
+
+## IP-Adapter
+
+[IP-Adapter](https://ip-adapter.github.io/) is an effective and lightweight adapter that adds image prompting capabilities to a diffusion model. This adapter works by decoupling the cross-attention layers of the image and text features. All the other model components are frozen and only the embedded image features in the UNet are trained. As a result, IP-Adapter files are typically only ~100MB.
+
+IP-Adapter works with most of our pipelines, including Stable Diffusion, Stable Diffusion XL (SDXL), ControlNet, T2I-Adapter, and AnimateDiff, and you can use any custom model finetuned from the same base model. It also works with LCM-LoRA out of the box.
+
+
+
+
+You can find official IP-Adapter checkpoints in [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter).
+
+IP-Adapter was contributed by [okotaku](https://github.com/okotaku).
+
+
+
+Let's first create a Stable Diffusion pipeline.
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+from diffusers.utils import load_image
+
+
+pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
+```
+
+Now load the [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter) weights with the [`~loaders.IPAdapterMixin.load_ip_adapter`] method.
+
+```py
+pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
+```
+
+
+IP-Adapter relies on an image encoder to generate the image features. If your IP-Adapter weights folder contains an `image_encoder` subfolder, the image encoder is automatically loaded and registered to the pipeline. Otherwise, you can explicitly load a [`~transformers.CLIPVisionModelWithProjection`] model and pass it to a Stable Diffusion pipeline when you create it.
+
+```py
+from diffusers import AutoPipelineForText2Image
+from transformers import CLIPVisionModelWithProjection
+import torch
+
+image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+ "h94/IP-Adapter",
+ subfolder="models/image_encoder",
+ torch_dtype=torch.float16,
+).to("cuda")
+
+pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, torch_dtype=torch.float16).to("cuda")
+```
+
+
+IP-Adapter allows you to use both image and text to condition the image generation process. For example, let's use the bear image from the [Textual Inversion](#textual-inversion) section as the image prompt (`ip_adapter_image`) along with a text prompt to add "sunglasses". 😎
+
+```py
+pipeline.set_ip_adapter_scale(0.6)
+image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_neg_embed.png")
+generator = torch.Generator(device="cpu").manual_seed(33)
+images = pipeline(
+ prompt='best quality, high quality, wearing sunglasses',
+ ip_adapter_image=image,
+ negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
+ num_inference_steps=50,
+ generator=generator,
+).images
+images[0]
+```
+
+
+
+
+
+
+
+You can use the [`~loaders.IPAdapterMixin.set_ip_adapter_scale`] method to adjust the ratio between the text prompt and image prompt conditioning. If you're only using the image prompt, set the scale to `1.0`. Lowering the scale gives more generation diversity, but the output is less aligned with the image prompt. When you use both text and image prompts, `scale=0.5` achieves good results in most cases.
+
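+For example, a minimal sketch that reuses the pipeline and `image` from above and relies only on the image prompt:
+
+```py
+# rely fully on the image prompt
+pipeline.set_ip_adapter_scale(1.0)
+images = pipeline(
+    prompt="",  # empty text prompt; the IP-Adapter image drives the generation
+    ip_adapter_image=image,
+    num_inference_steps=50,
+    generator=torch.Generator(device="cpu").manual_seed(33),
+).images
+images[0]
+```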
+
+IP-Adapter also works great with image-to-image and inpainting pipelines. See the examples below for how to use it with each.
+
+
+
+
+```py
+from diffusers import AutoPipelineForImage2Image
+import torch
+from diffusers.utils import load_image
+
+pipeline = AutoPipelineForImage2Image.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
+
+image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/vermeer.jpg")
+ip_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/river.png")
+
+pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
+generator = torch.Generator(device="cpu").manual_seed(33)
+images = pipeline(
+ prompt='best quality, high quality',
+ image = image,
+ ip_adapter_image=ip_image,
+ num_inference_steps=50,
+ generator=generator,
+ strength=0.6,
+).images
+images[0]
+```
+
+
+
+
+```py
+from diffusers import AutoPipelineForInpaint
+import torch
+from diffusers.utils import load_image
+
+pipeline = AutoPipelineForInpaint.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
+
+image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/inpaint_image.png")
+mask = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/mask.png")
+ip_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/girl.png")
+
+image = image.resize((512, 768))
+mask = mask.resize((512, 768))
+
+pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
+
+generator = torch.Generator(device="cpu").manual_seed(33)
+images = pipeline(
+ prompt='best quality, high quality',
+ image = image,
+ mask_image = mask,
+ ip_adapter_image=ip_image,
+ negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
+ num_inference_steps=50,
+ generator=generator,
+ strength=0.5,
+).images
+images[0]
+```
+
+
+
+
+IP-Adapter can also be used with [SDXL](../api/pipelines/stable_diffusion/stable_diffusion_xl.md).
+
+```python
+from diffusers import AutoPipelineForText2Image
+from diffusers.utils import load_image
+import torch
+
+pipeline = AutoPipelineForText2Image.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16
+).to("cuda")
+
+image = load_image("https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/watercolor_painting.jpeg")
+
+pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
+
+generator = torch.Generator(device="cpu").manual_seed(33)
+image = pipeline(
+ prompt="best quality, high quality",
+ ip_adapter_image=image,
+ negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
+ num_inference_steps=25,
+ generator=generator,
+).images[0]
+image.save("sdxl_t2i.png")
+```
+
+*(figure: input image vs. adapted image)*
+
+
+### LCM-LoRA
+
+You can use IP-Adapter with LCM-LoRA to achieve "instant fine-tuning" with custom images. Note that you need to load the IP-Adapter weights before loading the LCM-LoRA weights.
+
+```py
+from diffusers import DiffusionPipeline, LCMScheduler
+import torch
+from diffusers.utils import load_image
+
+model_id = "sd-dreambooth-library/herge-style"
+lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"
+
+pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
+
+pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
+pipe.load_lora_weights(lcm_lora_id)
+pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+pipe.enable_model_cpu_offload()
+
+prompt = "best quality, high quality"
+image = load_image("https://user-images.githubusercontent.com/24734142/266492875-2d50d223-8475-44f0-a7c6-08b51cb53572.png")
+images = pipe(
+ prompt=prompt,
+ ip_adapter_image=image,
+ num_inference_steps=4,
+ guidance_scale=1,
+).images[0]
+```
+
+### Other pipelines
+
+IP-Adapter is compatible with any pipeline that (1) uses a text prompt and (2) uses a Stable Diffusion or Stable Diffusion XL checkpoint. To use IP-Adapter with a different pipeline, all you need to do is call the `load_ip_adapter()` method after you create the pipeline, and then pass your image to the pipeline as `ip_adapter_image`.
+
+
+
+🤗 Diffusers currently only supports using IP-Adapter with some of the most popular pipelines. Feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you have a cool use case and need to integrate IP-Adapter with a pipeline that doesn't support it yet!
+
+
+
+Below are examples of how to use IP-Adapter with ControlNet and AnimateDiff.
+
+
+
+
+```py
+from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
+import torch
+from diffusers.utils import load_image
+
+controlnet_model_path = "lllyasviel/control_v11f1p_sd15_depth"
+controlnet = ControlNetModel.from_pretrained(controlnet_model_path, torch_dtype=torch.float16)
+
+pipeline = StableDiffusionControlNetPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16)
+pipeline.to("cuda")
+
+image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/statue.png")
+depth_map = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/depth.png")
+
+pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
+
+generator = torch.Generator(device="cpu").manual_seed(33)
+images = pipeline(
+ prompt='best quality, high quality',
+ image=depth_map,
+ ip_adapter_image=image,
+ negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
+ num_inference_steps=50,
+ generator=generator,
+).images
+images[0]
+```
+
+
+*(figure: input image vs. adapted image)*
+
+
+
+```py
+# animate diff + ip adapter
+import torch
+from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler
+from diffusers.utils import export_to_gif, load_image
+
+# Load the motion adapter
+adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
+# load SD 1.5 based finetuned model
+model_id = "Lykon/DreamShaper"
+pipe = AnimateDiffPipeline.from_pretrained(model_id, motion_adapter=adapter, torch_dtype=torch.float16)
+
+# scheduler
+scheduler = DDIMScheduler(
+ clip_sample=False,
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="linear",
+ timestep_spacing="trailing",
+ steps_offset=1
+)
+pipe.scheduler = scheduler
+
+# enable memory savings
+pipe.enable_vae_slicing()
+pipe.enable_model_cpu_offload()
+
+# load ip_adapter
+pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
+
+# load motion LoRAs
+pipe.load_lora_weights("guoyww/animatediff-motion-lora-zoom-out", adapter_name="zoom-out")
+pipe.load_lora_weights("guoyww/animatediff-motion-lora-tilt-up", adapter_name="tilt-up")
+pipe.load_lora_weights("guoyww/animatediff-motion-lora-pan-left", adapter_name="pan-left")
+
+seed = 42
+image = load_image("https://user-images.githubusercontent.com/24734142/266492875-2d50d223-8475-44f0-a7c6-08b51cb53572.png")
+images = [image] * 3
+prompts = ["best quality, high quality"] * 3
+negative_prompt = "bad quality, worst quality"
+adapter_weights = [[0.75, 0.0, 0.0], [0.0, 0.0, 0.75], [0.0, 0.75, 0.75]]
+
+# generate
+output_frames = []
+for prompt, image, adapter_weight in zip(prompts, images, adapter_weights):
+ pipe.set_adapters(["zoom-out", "tilt-up", "pan-left"], adapter_weights=adapter_weight)
+ output = pipe(
+ prompt= prompt,
+ num_frames=16,
+ guidance_scale=7.5,
+ num_inference_steps=30,
+ ip_adapter_image = image,
+ generator=torch.Generator("cpu").manual_seed(seed),
+ )
+ frames = output.frames[0]
+ output_frames.extend(frames)
+
+export_to_gif(output_frames, "test_out_animation.gif")
+```
+
+
+
+
diff --git a/diffusers/docs/source/en/using-diffusers/loading_overview.md b/diffusers/docs/source/en/using-diffusers/loading_overview.md
new file mode 100644
index 0000000000000000000000000000000000000000..b36fdb77e6ddef7be7d2f8b6590744196f02a36e
--- /dev/null
+++ b/diffusers/docs/source/en/using-diffusers/loading_overview.md
@@ -0,0 +1,17 @@
+
+
+# Overview
+
+🧨 Diffusers offers many pipelines, models, and schedulers for generative tasks. To make loading these components as simple as possible, we provide a single and unified method - `from_pretrained()` - that loads any of these components from either the Hugging Face [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) or your local machine. Whenever you load a pipeline or model, the latest files are automatically downloaded and cached so you can quickly reuse them next time without redownloading the files.
+
+This section will show you everything you need to know about loading pipelines, how to load different components in a pipeline, how to load checkpoint variants, and how to load community pipelines. You'll also learn how to load schedulers and compare the speed and quality trade-offs of using different schedulers. Finally, you'll see how to convert and load KerasCV checkpoints so you can use them in PyTorch with 🧨 Diffusers.
diff --git a/diffusers/docs/source/en/using-diffusers/other-formats.md b/diffusers/docs/source/en/using-diffusers/other-formats.md
new file mode 100644
index 0000000000000000000000000000000000000000..6f8e00d1e39674abdd931ce08523576b85566621
--- /dev/null
+++ b/diffusers/docs/source/en/using-diffusers/other-formats.md
@@ -0,0 +1,176 @@
+
+
+# Load different Stable Diffusion formats
+
+[[open-in-colab]]
+
+Stable Diffusion models are available in different formats depending on the framework they're trained and saved with, and where you download them from. Converting these formats for use in 🤗 Diffusers allows you to use all the features supported by the library, such as [using different schedulers](schedulers) for inference, [building your custom pipeline](write_own_pipeline), and a variety of techniques and methods for [optimizing inference speed](../optimization/opt_overview).
+
+
+
+We highly recommend using the `.safetensors` format because it is more secure than traditional pickled files which are vulnerable and can be exploited to execute any code on your machine (learn more in the [Load safetensors](using_safetensors) guide).
+
+
+
+This guide will show you how to convert other Stable Diffusion formats to be compatible with 🤗 Diffusers.
+
+## PyTorch .ckpt
+
+The checkpoint - or `.ckpt` - format is commonly used to store and save models. The `.ckpt` file contains the entire model and is typically several GBs in size. While you can load and use a `.ckpt` file directly with the [`~StableDiffusionPipeline.from_single_file`] method, it is generally better to convert the `.ckpt` file to 🤗 Diffusers so both formats are available.
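+
+For reference, loading a `.ckpt` file directly looks roughly like this (the file path is a placeholder):
+
+```py
+from diffusers import StableDiffusionPipeline
+import torch
+
+# load the single-file checkpoint without converting it first
+pipeline = StableDiffusionPipeline.from_single_file(
+    "path/to/model.ckpt", torch_dtype=torch.float16
+).to("cuda")
+```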
+
+There are two options for converting a `.ckpt` file: use a Space to convert the checkpoint or convert the `.ckpt` file with a script.
+
+### Convert with a Space
+
+The easiest and most convenient way to convert a `.ckpt` file is to use the [SD to Diffusers](https://huggingface.co/spaces/diffusers/sd-to-diffusers) Space. You can follow the instructions on the Space to convert the `.ckpt` file.
+
+This approach works well for basic models, but it may struggle with more customized models. You'll know the Space failed if it returns an empty pull request or error. In this case, you can try converting the `.ckpt` file with a script.
+
+### Convert with a script
+
+🤗 Diffusers provides a [conversion script](https://github.com/huggingface/diffusers/blob/main/scripts/convert_original_stable_diffusion_to_diffusers.py) for converting `.ckpt` files. This approach is more reliable than the Space above.
+
+Before you start, make sure you have a local clone of 🤗 Diffusers to run the script and log in to your Hugging Face account so you can open pull requests and push your converted model to the Hub.
+
+```bash
+huggingface-cli login
+```
+
+To use the script:
+
+1. Git clone the repository containing the `.ckpt` file you want to convert. For this example, let's convert this [TemporalNet](https://huggingface.co/CiaraRowles/TemporalNet) `.ckpt` file:
+
+```bash
+git lfs install
+git clone https://huggingface.co/CiaraRowles/TemporalNet
+```
+
+2. Open a pull request on the repository where you're converting the checkpoint from:
+
+```bash
+cd TemporalNet && git fetch origin refs/pr/13:pr/13
+git checkout pr/13
+```
+
+3. There are several input arguments to configure in the conversion script, but the most important ones are:
+
+ - `checkpoint_path`: the path to the `.ckpt` file to convert.
+ - `original_config_file`: a YAML file defining the configuration of the original architecture. If you can't find this file, try searching for the YAML file in the GitHub repository where you found the `.ckpt` file.
+ - `dump_path`: the path to the converted model.
+
+ For example, you can take the `cldm_v15.yaml` file from the [ControlNet](https://github.com/lllyasviel/ControlNet/tree/main/models) repository because the TemporalNet model is a Stable Diffusion v1.5 and ControlNet model.
+
+4. Now you can run the script to convert the `.ckpt` file:
+
+```bash
+python ../diffusers/scripts/convert_original_stable_diffusion_to_diffusers.py --checkpoint_path temporalnetv3.ckpt --original_config_file cldm_v15.yaml --dump_path ./ --controlnet
+```
+
+5. Once the conversion is done, upload your converted model and test out the resulting [pull request](https://huggingface.co/CiaraRowles/TemporalNet/discussions/13)!
+
+```bash
+git push origin pr/13:refs/pr/13
+```
+
+## Keras .pb or .h5
+
+
+
+🧪 This is an experimental feature. Only Stable Diffusion v1 checkpoints are supported by the Convert KerasCV Space at the moment.
+
+
+
+[KerasCV](https://keras.io/keras_cv/) supports training for [Stable Diffusion](https://github.com/keras-team/keras-cv/blob/master/keras_cv/models/stable_diffusion) v1 and v2. However, it offers limited support for experimenting with Stable Diffusion models for inference and deployment whereas 🤗 Diffusers has a more complete set of features for this purpose, such as different [noise schedulers](https://huggingface.co/docs/diffusers/using-diffusers/schedulers), [flash attention](https://huggingface.co/docs/diffusers/optimization/xformers), and [other
+optimization techniques](https://huggingface.co/docs/diffusers/optimization/fp16).
+
+The [Convert KerasCV](https://huggingface.co/spaces/sayakpaul/convert-kerascv-sd-diffusers) Space converts `.pb` or `.h5` files to PyTorch, and then wraps them in a [`StableDiffusionPipeline`] so it is ready for inference. The converted checkpoint is stored in a repository on the Hugging Face Hub.
+
+For this example, let's convert the [`sayakpaul/textual-inversion-kerasio`](https://huggingface.co/sayakpaul/textual-inversion-kerasio/tree/main) checkpoint which was trained with Textual Inversion. It uses the special token `` to personalize images with cats.
+
+The Convert KerasCV Space allows you to input the following:
+
+* Your Hugging Face token.
+* Paths to download the UNet and text encoder weights from. Depending on how the model was trained, you don't necessarily need to provide the paths to both the UNet and text encoder. For example, Textual Inversion only requires the embeddings from the text encoder and a text-to-image model only requires the UNet weights.
+* The placeholder token (only applicable for textual inversion models).
+* The `output_repo_prefix` is the name of the repository where the converted model is stored.
+
+Click the **Submit** button to automatically convert the KerasCV checkpoint! Once the checkpoint is successfully converted, you'll see a link to the new repository containing the converted checkpoint. Follow the link to the new repository, and you'll see the Convert KerasCV Space generated a model card with an inference widget to try out the converted model.
+
+If you prefer to run inference with code, click on the **Use in Diffusers** button in the upper right corner of the model card to copy and paste the code snippet:
+
+```py
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "sayakpaul/textual-inversion-cat-kerascv_sd_diffusers_pipeline", use_safetensors=True
+)
+```
+
+Then, you can generate an image like:
+
+```py
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "sayakpaul/textual-inversion-cat-kerascv_sd_diffusers_pipeline", use_safetensors=True
+)
+pipeline.to("cuda")
+
+placeholder_token = ""
+prompt = f"two {placeholder_token} getting married, photorealistic, high quality"
+image = pipeline(prompt, num_inference_steps=50).images[0]
+```
+
+## A1111 LoRA files
+
+[Automatic1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) (A1111) is a popular web UI for Stable Diffusion that supports model sharing platforms like [Civitai](https://civitai.com/). Models trained with the Low-Rank Adaptation (LoRA) technique are especially popular because they're fast to train and have a much smaller file size than a fully finetuned model. 🤗 Diffusers supports loading A1111 LoRA checkpoints with [`~loaders.LoraLoaderMixin.load_lora_weights`]:
+
+```py
+from diffusers import StableDiffusionXLPipeline
+import torch
+
+pipeline = StableDiffusionXLPipeline.from_pretrained(
+ "Lykon/dreamshaper-xl-1-0", torch_dtype=torch.float16, variant="fp16"
+).to("cuda")
+```
+
+Download a LoRA checkpoint from Civitai; this example uses the [Blueprintify SD XL 1.0](https://civitai.com/models/150986/blueprintify-sd-xl-10) checkpoint, but feel free to try out any LoRA checkpoint!
+
+```py
+# uncomment to download the safetensor weights
+#!wget https://civitai.com/api/download/models/168776 -O blueprintify.safetensors
+```
+
+Load the LoRA checkpoint into the pipeline with the [`~loaders.LoraLoaderMixin.load_lora_weights`] method:
+
+```py
+pipeline.load_lora_weights(".", weight_name="blueprintify.safetensors")
+```
+
+Now you can use the pipeline to generate images:
+
+```py
+prompt = "bl3uprint, a highly detailed blueprint of the empire state building, explaining how to build all parts, many txt, blueprint grid backdrop"
+negative_prompt = "lowres, cropped, worst quality, low quality, normal quality, artifacts, signature, watermark, username, blurry, more than one bridge, bad architecture"
+
+image = pipeline(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ generator=torch.manual_seed(0),
+).images[0]
+image
+```
+
+
+
+
diff --git a/diffusers/docs/source/en/using-diffusers/other-modalities.md b/diffusers/docs/source/en/using-diffusers/other-modalities.md
new file mode 100644
index 0000000000000000000000000000000000000000..ec879c49b1060c7ade1a0eb7e82de87c95d1b957
--- /dev/null
+++ b/diffusers/docs/source/en/using-diffusers/other-modalities.md
@@ -0,0 +1,21 @@
+
+
+# Using Diffusers with other modalities
+
+Diffusers is in the process of expanding to modalities other than images.
+
+Example type | Colab | Pipeline |
+:-------------------------:|:-------------------------:|:-------------------------:|
+[Molecule conformation](https://www.nature.com/subjects/molecular-conformation#:~:text=Definition,to%20changes%20in%20their%20environment.) generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/geodiff_molecule_conformation.ipynb) | ❌
+
+More coming soon!
\ No newline at end of file
diff --git a/diffusers/docs/source/en/using-diffusers/pipeline_overview.md b/diffusers/docs/source/en/using-diffusers/pipeline_overview.md
new file mode 100644
index 0000000000000000000000000000000000000000..292ce51d322ae0616de66270825fd1debb4629fe
--- /dev/null
+++ b/diffusers/docs/source/en/using-diffusers/pipeline_overview.md
@@ -0,0 +1,17 @@
+
+
+# Overview
+
+A pipeline is an end-to-end class that provides a quick and easy way to use a diffusion system for inference by bundling independently trained models and schedulers together. Certain combinations of models and schedulers define specific pipeline types, like [`StableDiffusionXLPipeline`] or [`StableDiffusionControlNetPipeline`], with specific capabilities. All pipeline types inherit from the base [`DiffusionPipeline`] class; pass it any checkpoint, and it'll automatically detect the pipeline type and load the necessary components.
+
+This section demonstrates how to use specific pipelines such as Stable Diffusion XL, ControlNet, and DiffEdit. You'll also learn how to use a distilled version of the Stable Diffusion model to speed up inference, how to create reproducible pipelines, and how to use and contribute community pipelines.
diff --git a/diffusers/docs/source/en/using-diffusers/push_to_hub.md b/diffusers/docs/source/en/using-diffusers/push_to_hub.md
new file mode 100644
index 0000000000000000000000000000000000000000..58598c3bc443c5965baacaca13dc866f38e744ac
--- /dev/null
+++ b/diffusers/docs/source/en/using-diffusers/push_to_hub.md
@@ -0,0 +1,183 @@
+
+
+# Push files to the Hub
+
+[[open-in-colab]]
+
+🤗 Diffusers provides a [`~diffusers.utils.PushToHubMixin`] for uploading your model, scheduler, or pipeline to the Hub. It is an easy way to store your files on the Hub, and also allows you to share your work with others. Under the hood, the [`~diffusers.utils.PushToHubMixin`]:
+
+1. creates a repository on the Hub
+2. saves your model, scheduler, or pipeline files so they can be reloaded later
+3. uploads the folder containing these files to the Hub
+
+This guide will show you how to use the [`~diffusers.utils.PushToHubMixin`] to upload your files to the Hub.
+
+You'll need to log in to your Hub account with your access [token](https://huggingface.co/settings/tokens) first:
+
+```py
+from huggingface_hub import notebook_login
+
+notebook_login()
+```
+
+## Models
+
+To push a model to the Hub, call [`~diffusers.utils.PushToHubMixin.push_to_hub`] and specify the repository id of the model to be stored on the Hub:
+
+```py
+from diffusers import ControlNetModel
+
+controlnet = ControlNetModel(
+ block_out_channels=(32, 64),
+ layers_per_block=2,
+ in_channels=4,
+ down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+ cross_attention_dim=32,
+ conditioning_embedding_out_channels=(16, 32),
+)
+controlnet.push_to_hub("my-controlnet-model")
+```
+
+For models, you can also specify the [*variant*](loading#checkpoint-variants) of the weights to push to the Hub. For example, to push `fp16` weights:
+
+```py
+controlnet.push_to_hub("my-controlnet-model", variant="fp16")
+```
+
+The [`~diffusers.utils.PushToHubMixin.push_to_hub`] function saves the model's `config.json` file and the weights are automatically saved in the `safetensors` format.
+
+Now you can reload the model from your repository on the Hub:
+
+```py
+model = ControlNetModel.from_pretrained("your-namespace/my-controlnet-model")
+```
+
+## Scheduler
+
+To push a scheduler to the Hub, call [`~diffusers.utils.PushToHubMixin.push_to_hub`] and specify the repository id of the scheduler to be stored on the Hub:
+
+```py
+from diffusers import DDIMScheduler
+
+scheduler = DDIMScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="scaled_linear",
+ clip_sample=False,
+ set_alpha_to_one=False,
+)
+scheduler.push_to_hub("my-controlnet-scheduler")
+```
+
+The [`~diffusers.utils.PushToHubMixin.push_to_hub`] function saves the scheduler's `scheduler_config.json` file to the specified repository.
+
+Now you can reload the scheduler from your repository on the Hub:
+
+```py
+scheduler = DDIMScheduler.from_pretrained("your-namespace/my-controlnet-scheduler")
+```
+
+## Pipeline
+
+You can also push an entire pipeline with all its components to the Hub. For example, initialize the components of a [`StableDiffusionPipeline`] with the parameters you want:
+
+```py
+from diffusers import (
+ UNet2DConditionModel,
+ AutoencoderKL,
+ DDIMScheduler,
+ StableDiffusionPipeline,
+)
+from transformers import CLIPTextModel, CLIPTextConfig, CLIPTokenizer
+
+unet = UNet2DConditionModel(
+ block_out_channels=(32, 64),
+ layers_per_block=2,
+ sample_size=32,
+ in_channels=4,
+ out_channels=4,
+ down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+ up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+ cross_attention_dim=32,
+)
+
+scheduler = DDIMScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="scaled_linear",
+ clip_sample=False,
+ set_alpha_to_one=False,
+)
+
+vae = AutoencoderKL(
+ block_out_channels=[32, 64],
+ in_channels=3,
+ out_channels=3,
+ down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+ latent_channels=4,
+)
+
+text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=32,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+)
+text_encoder = CLIPTextModel(text_encoder_config)
+tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+```
+
+Pass all of the components to the [`StableDiffusionPipeline`] and call [`~diffusers.utils.PushToHubMixin.push_to_hub`] to push the pipeline to the Hub:
+
+```py
+components = {
+ "unet": unet,
+ "scheduler": scheduler,
+ "vae": vae,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "safety_checker": None,
+ "feature_extractor": None,
+}
+
+pipeline = StableDiffusionPipeline(**components)
+pipeline.push_to_hub("my-pipeline")
+```
+
+The [`~diffusers.utils.PushToHubMixin.push_to_hub`] function saves each component to a subfolder in the repository. Now you can reload the pipeline from your repository on the Hub:
+
+```py
+pipeline = StableDiffusionPipeline.from_pretrained("your-namespace/my-pipeline")
+```
+
+## Privacy
+
+Set `private=True` in the [`~diffusers.utils.PushToHubMixin.push_to_hub`] function to keep your model, scheduler, or pipeline files private:
+
+```py
+controlnet.push_to_hub("my-controlnet-model-private", private=True)
+```
+
+Private repositories are only visible to you; other users won't be able to clone the repository, and it won't appear in search results. Even if a user has the URL to your private repository, they'll receive a `404 - Sorry, we can't find the page you are looking for` error.
+
+To load a model, scheduler, or pipeline from private or gated repositories, set `use_auth_token=True`:
+
+```py
+model = ControlNetModel.from_pretrained("your-namespace/my-controlnet-model-private", use_auth_token=True)
+```
diff --git a/diffusers/docs/source/en/using-diffusers/reproducibility.md b/diffusers/docs/source/en/using-diffusers/reproducibility.md
new file mode 100644
index 0000000000000000000000000000000000000000..5bc1d02b14d4273d20aa3d39174bfbadc8180008
--- /dev/null
+++ b/diffusers/docs/source/en/using-diffusers/reproducibility.md
@@ -0,0 +1,191 @@
+
+
+# Create reproducible pipelines
+
+[[open-in-colab]]
+
+Reproducibility is important for testing, replicating results, and can even be used to [improve image quality](reusing_seeds). However, the randomness in diffusion models is a desired property because it allows the pipeline to generate different images every time it is run. While you can't expect to get the exact same results across platforms, you can expect results to be reproducible across releases and platforms within a certain tolerance range. Even then, tolerance varies depending on the diffusion pipeline and checkpoint.
+
+This is why it's important to understand how to control sources of randomness in diffusion models or use deterministic algorithms.
+
+
+
+💡 We strongly recommend reading PyTorch's [statement about reproducibility](https://pytorch.org/docs/stable/notes/randomness.html):
+
+> Completely reproducible results are not guaranteed across PyTorch releases, individual commits, or different platforms. Furthermore, results may not be reproducible between CPU and GPU executions, even when using identical seeds.
+
+
+
+## Control randomness
+
+During inference, pipelines rely heavily on random sampling operations which include creating the
+Gaussian noise tensors to denoise and adding noise to the scheduling step.
+
+Take a look at the tensor values in the [`DDIMPipeline`] after two inference steps:
+
+```python
+from diffusers import DDIMPipeline
+import numpy as np
+
+model_id = "google/ddpm-cifar10-32"
+
+# load model and scheduler
+ddim = DDIMPipeline.from_pretrained(model_id, use_safetensors=True)
+
+# run pipeline for just two steps and return numpy tensor
+image = ddim(num_inference_steps=2, output_type="np").images
+print(np.abs(image).sum())
+```
+
+Running the code above prints one value, but if you run it again you get a different value. What is going on here?
+
+Every time the pipeline is run, [`torch.randn`](https://pytorch.org/docs/stable/generated/torch.randn.html) uses a different random seed to create Gaussian noise which is denoised stepwise. This leads to a different result each time it is run, which is great for diffusion pipelines since it generates a different random image each time.
+
+But if you need to reliably generate the same image, that'll depend on whether you're running the pipeline on a CPU or GPU.
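+
+The effect is easy to see with `torch.randn` itself; only a seeded `Generator` makes the draw repeatable:
+
+```python
+import torch
+
+print(torch.randn(2))  # different values on every run
+
+g = torch.Generator().manual_seed(0)
+print(torch.randn(2, generator=g))  # same values on every run
+```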
+
+### CPU
+
+To generate reproducible results on a CPU, you'll need to use a PyTorch [`Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) and set a seed:
+
+```python
+import torch
+from diffusers import DDIMPipeline
+import numpy as np
+
+model_id = "google/ddpm-cifar10-32"
+
+# load model and scheduler
+ddim = DDIMPipeline.from_pretrained(model_id, use_safetensors=True)
+
+# create a generator for reproducibility
+generator = torch.Generator(device="cpu").manual_seed(0)
+
+# run pipeline for just two steps and return numpy tensor
+image = ddim(num_inference_steps=2, output_type="np", generator=generator).images
+print(np.abs(image).sum())
+```
+
+Now when you run the code above, it always prints a value of `1491.1711` no matter what because the `Generator` object with the seed is passed to all the random functions of the pipeline.
+
+If you run this code example on your specific hardware and PyTorch version, you should get a similar, if not the same, result.
+
+
+
+💡 It might be a bit unintuitive at first to pass `Generator` objects to the pipeline instead of
+just integer values representing the seed, but this is the recommended design when dealing with
+probabilistic models in PyTorch, as `Generator`s are *random states* that can be
+passed to multiple pipelines in a sequence.
+
+
+
+### GPU
+
+Writing a reproducible pipeline on a GPU is a bit trickier, and full reproducibility across different hardware is not guaranteed because matrix multiplication - which diffusion pipelines require a lot of - is less deterministic on a GPU than a CPU. For example, if you run the same code example above on a GPU:
+
+```python
+import torch
+from diffusers import DDIMPipeline
+import numpy as np
+
+model_id = "google/ddpm-cifar10-32"
+
+# load model and scheduler
+ddim = DDIMPipeline.from_pretrained(model_id, use_safetensors=True)
+ddim.to("cuda")
+
+# create a generator for reproducibility
+generator = torch.Generator(device="cuda").manual_seed(0)
+
+# run pipeline for just two steps and return numpy tensor
+image = ddim(num_inference_steps=2, output_type="np", generator=generator).images
+print(np.abs(image).sum())
+```
+
+The result is not the same even though you're using an identical seed because the GPU uses a different random number generator than the CPU.
+
+To circumvent this problem, 🧨 Diffusers has a [`~diffusers.utils.torch_utils.randn_tensor`] function for creating random noise on the CPU, and then moving the tensor to a GPU if necessary. The `randn_tensor` function is used everywhere inside the pipeline, allowing the user to **always** pass a CPU `Generator` even if the pipeline is run on a GPU.
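+
+As a rough sketch of what happens internally (the shape here is arbitrary), the noise is drawn with a CPU `Generator` and only then placed on the requested device:
+
+```python
+import torch
+from diffusers.utils.torch_utils import randn_tensor
+
+# sampled on the CPU generator, then moved to the GPU
+generator = torch.manual_seed(0)
+noise = randn_tensor((1, 3, 64, 64), generator=generator, device=torch.device("cuda"), dtype=torch.float16)
+```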
+
+You'll see the results are much closer now!
+
+```python
+import torch
+from diffusers import DDIMPipeline
+import numpy as np
+
+model_id = "google/ddpm-cifar10-32"
+
+# load model and scheduler
+ddim = DDIMPipeline.from_pretrained(model_id, use_safetensors=True)
+ddim.to("cuda")
+
+# create a generator for reproducibility; notice you don't place it on the GPU!
+generator = torch.manual_seed(0)
+
+# run pipeline for just two steps and return numpy tensor
+image = ddim(num_inference_steps=2, output_type="np", generator=generator).images
+print(np.abs(image).sum())
+```
+
+
+
+💡 If reproducibility is important, we recommend always passing a CPU generator.
+The performance loss is often negligible, and you'll generate much more similar
+values than if the pipeline had been run on a GPU.
+
+
+
+Finally, for more complex pipelines such as [`UnCLIPPipeline`], these are often extremely
+susceptible to precision error propagation. Don't expect similar results across
+different GPU hardware or PyTorch versions. In this case, you'll need to run
+exactly the same hardware and PyTorch version for full reproducibility.
+
+## Deterministic algorithms
+
+You can also configure PyTorch to use deterministic algorithms to create a reproducible pipeline. However, you should be aware that deterministic algorithms may be slower than nondeterministic ones and you may observe a decrease in performance. But if reproducibility is important to you, then this is the way to go!
+
+Nondeterministic behavior occurs when operations are launched in more than one CUDA stream. To avoid this, set the environment variable [`CUBLAS_WORKSPACE_CONFIG`](https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility) to `:16:8` to only use one buffer size during runtime.
+
+PyTorch typically benchmarks multiple algorithms to select the fastest one, but if you want reproducibility, you should disable this feature because the benchmark may select different algorithms each time. Lastly, pass `True` to [`torch.use_deterministic_algorithms`](https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html) to enable deterministic algorithms.
+
+```py
+import os
+import torch
+
+os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+
+torch.backends.cudnn.benchmark = False
+torch.use_deterministic_algorithms(True)
+```
+
+Now when you run the same pipeline twice, you'll get identical results.
+
+```py
+import torch
+from diffusers import DDIMScheduler, StableDiffusionPipeline
+
+model_id = "runwayml/stable-diffusion-v1-5"
+pipe = StableDiffusionPipeline.from_pretrained(model_id, use_safetensors=True).to("cuda")
+pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+g = torch.Generator(device="cuda")
+
+prompt = "A bear is playing a guitar on Times Square"
+
+g.manual_seed(0)
+result1 = pipe(prompt=prompt, num_inference_steps=50, generator=g, output_type="latent").images
+
+g.manual_seed(0)
+result2 = pipe(prompt=prompt, num_inference_steps=50, generator=g, output_type="latent").images
+
+print("L_inf dist =", abs(result1 - result2).max())
+"L_inf dist = tensor(0., device='cuda:0')"
+```
diff --git a/diffusers/docs/source/en/using-diffusers/reusing_seeds.md b/diffusers/docs/source/en/using-diffusers/reusing_seeds.md
new file mode 100644
index 0000000000000000000000000000000000000000..d2638b469e302623bf9a1b0a7a7e784cab1a6d63
--- /dev/null
+++ b/diffusers/docs/source/en/using-diffusers/reusing_seeds.md
@@ -0,0 +1,67 @@
+
+
+# Improve image quality with deterministic generation
+
+[[open-in-colab]]
+
+A common way to improve the quality of generated images is with *deterministic batch generation*: generate a batch of images and select one image to improve with a more detailed prompt in a second round of inference. The key is to pass a list of [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html#generator)s to the pipeline for batched image generation, and tie each `Generator` to a seed so you can reuse it for a specific image.
+
+Let's use [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) for example, and generate several versions of the following prompt:
+
+```py
+prompt = "Labrador in the style of Vermeer"
+```
+
+Instantiate a pipeline with [`DiffusionPipeline.from_pretrained`] and place it on a GPU (if available):
+
+```python
+import torch
+from diffusers import DiffusionPipeline
+from diffusers.utils import make_image_grid
+
+pipe = DiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
+)
+pipe = pipe.to("cuda")
+```
+
+Now, define four different `Generator`s and assign each `Generator` a seed (`0` to `3`) so you can reuse a `Generator` later for a specific image:
+
+```python
+generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(4)]
+```
+
+Generate the images and have a look:
+
+```python
+images = pipe(prompt, generator=generator, num_images_per_prompt=4).images
+make_image_grid(images, rows=2, cols=2)
+```
+
+![img](https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/reusabe_seeds.jpg)
+
+In this example, you'll improve upon the first image - but in reality, you can use any image you want (even the image with double sets of eyes!). The first image used the `Generator` with seed `0`, so you'll reuse that `Generator` for the second round of inference. To improve the quality of the image, add some additional text to the prompt:
+
+```python
+prompt = [prompt + t for t in [", highly realistic", ", artsy", ", trending", ", colorful"]]
+generator = [torch.Generator(device="cuda").manual_seed(0) for i in range(4)]
+```
+
+Create four generators with seed `0`, and generate another batch of images, all of which should look like the first image from the previous round!
+
+```python
+images = pipe(prompt, generator=generator).images
+make_image_grid(images, rows=2, cols=2)
+```
+
+![img](https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/reusabe_seeds_2.jpg)
diff --git a/diffusers/docs/source/en/using-diffusers/schedulers.md b/diffusers/docs/source/en/using-diffusers/schedulers.md
new file mode 100644
index 0000000000000000000000000000000000000000..6b5d8da465d89a15779dd628653689a424e5d13f
--- /dev/null
+++ b/diffusers/docs/source/en/using-diffusers/schedulers.md
@@ -0,0 +1,331 @@
+
+
+# Schedulers
+
+[[open-in-colab]]
+
+Diffusion pipelines are inherently a collection of diffusion models and schedulers that are partly independent from each other. This means that one is able to switch out parts of the pipeline to better customize
+a pipeline to one's use case. The best example of this is the [Schedulers](../api/schedulers/overview).
+
+Whereas diffusion models usually simply define the forward pass from noise to a less noisy sample,
+schedulers define the whole denoising process, *i.e.*:
+- How many denoising steps?
+- Stochastic or deterministic?
+- What algorithm to use to find the denoised sample?
+
+They can be quite complex and often define a trade-off between **denoising speed** and **denoising quality**.
+It is extremely difficult to measure quantitatively which scheduler works best for a given diffusion pipeline, so it is often recommended to simply try out which works best.
+
+The following paragraphs show how to do so with the 🧨 Diffusers library.
+
+## Load pipeline
+
+Let's start by loading the [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) model in the [`DiffusionPipeline`]:
+
+```python
+from huggingface_hub import login
+from diffusers import DiffusionPipeline
+import torch
+
+login()
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
+)
+```
+
+Next, we move it to GPU:
+
+```python
+pipeline.to("cuda")
+```
+
+## Access the scheduler
+
+The scheduler is always one of the components of the pipeline and is usually called `"scheduler"`.
+So it can be accessed via the `"scheduler"` property.
+
+```python
+pipeline.scheduler
+```
+
+**Output**:
+```
+PNDMScheduler {
+ "_class_name": "PNDMScheduler",
+ "_diffusers_version": "0.21.4",
+ "beta_end": 0.012,
+ "beta_schedule": "scaled_linear",
+ "beta_start": 0.00085,
+ "clip_sample": false,
+ "num_train_timesteps": 1000,
+ "set_alpha_to_one": false,
+ "skip_prk_steps": true,
+ "steps_offset": 1,
+ "timestep_spacing": "leading",
+ "trained_betas": null
+}
+```
+
+We can see that the scheduler is of type [`PNDMScheduler`].
+Cool, now let's compare the scheduler in its performance to other schedulers.
+First we define a prompt on which we will test all the different schedulers:
+
+```python
+prompt = "A photograph of an astronaut riding a horse on Mars, high resolution, high definition."
+```
+
+Next, we create a generator from a fixed seed so we can compare similar images across schedulers, and run the pipeline:
+
+```python
+generator = torch.Generator(device="cuda").manual_seed(8)
+image = pipeline(prompt, generator=generator).images[0]
+image
+```
+
+
+
+
+
+
+
+
+## Changing the scheduler
+
+Now we show how easy it is to change the scheduler of a pipeline. Every scheduler has a property [`~SchedulerMixin.compatibles`]
+which defines all compatible schedulers. You can take a look at all available, compatible schedulers for the Stable Diffusion pipeline as follows.
+
+```python
+pipeline.scheduler.compatibles
+```
+
+**Output**:
+```
+[diffusers.utils.dummy_torch_and_torchsde_objects.DPMSolverSDEScheduler,
+ diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler,
+ diffusers.schedulers.scheduling_lms_discrete.LMSDiscreteScheduler,
+ diffusers.schedulers.scheduling_ddim.DDIMScheduler,
+ diffusers.schedulers.scheduling_ddpm.DDPMScheduler,
+ diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler,
+ diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler,
+ diffusers.schedulers.scheduling_deis_multistep.DEISMultistepScheduler,
+ diffusers.schedulers.scheduling_pndm.PNDMScheduler,
+ diffusers.schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteScheduler,
+ diffusers.schedulers.scheduling_unipc_multistep.UniPCMultistepScheduler,
+ diffusers.schedulers.scheduling_k_dpm_2_discrete.KDPM2DiscreteScheduler,
+ diffusers.schedulers.scheduling_dpmsolver_singlestep.DPMSolverSinglestepScheduler,
+ diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteScheduler]
+```
+
+Cool, lots of schedulers to look at. Feel free to have a look at their respective class definitions:
+
+- [`EulerDiscreteScheduler`],
+- [`LMSDiscreteScheduler`],
+- [`DDIMScheduler`],
+- [`DDPMScheduler`],
+- [`HeunDiscreteScheduler`],
+- [`DPMSolverMultistepScheduler`],
+- [`DEISMultistepScheduler`],
+- [`PNDMScheduler`],
+- [`EulerAncestralDiscreteScheduler`],
+- [`UniPCMultistepScheduler`],
+- [`KDPM2DiscreteScheduler`],
+- [`DPMSolverSinglestepScheduler`],
+- [`KDPM2AncestralDiscreteScheduler`].
+
+We will now compare the input prompt with all other schedulers. To change the scheduler of the pipeline you can make use of the
+convenient [`~ConfigMixin.config`] property in combination with the [`~ConfigMixin.from_config`] function.
+
+```python
+pipeline.scheduler.config
+```
+
+returns a dictionary of the configuration of the scheduler:
+
+**Output**:
+```py
+FrozenDict([('num_train_timesteps', 1000),
+ ('beta_start', 0.00085),
+ ('beta_end', 0.012),
+ ('beta_schedule', 'scaled_linear'),
+ ('trained_betas', None),
+ ('skip_prk_steps', True),
+ ('set_alpha_to_one', False),
+ ('prediction_type', 'epsilon'),
+ ('timestep_spacing', 'leading'),
+ ('steps_offset', 1),
+ ('_use_default_values', ['timestep_spacing', 'prediction_type']),
+ ('_class_name', 'PNDMScheduler'),
+ ('_diffusers_version', '0.21.4'),
+ ('clip_sample', False)])
+```
+
+This configuration can then be used to instantiate a scheduler
+of a different class that is compatible with the pipeline. Here,
+we change the scheduler to the [`DDIMScheduler`].
+
+```python
+from diffusers import DDIMScheduler
+
+pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
+```
+
+Cool, now we can run the pipeline again to compare the generation quality.
+
+```python
+generator = torch.Generator(device="cuda").manual_seed(8)
+image = pipeline(prompt, generator=generator).images[0]
+image
+```
+
+
+
+
+
+
+
+If you are a JAX/Flax user, please check [this section](#changing-the-scheduler-in-flax) instead.
+
+## Compare schedulers
+
+So far we have tried running the Stable Diffusion pipeline with two schedulers: [`PNDMScheduler`] and [`DDIMScheduler`].
+A number of better schedulers have been released that can be run with far fewer steps; let's compare them here:
+
+[`LMSDiscreteScheduler`] usually leads to better results:
+
+```python
+from diffusers import LMSDiscreteScheduler
+
+pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
+
+generator = torch.Generator(device="cuda").manual_seed(8)
+image = pipeline(prompt, generator=generator).images[0]
+image
+```
+
+
+
+
+
+
+
+
+[`EulerDiscreteScheduler`] and [`EulerAncestralDiscreteScheduler`] can generate high quality results with as few as 30 steps.
+
+```python
+from diffusers import EulerDiscreteScheduler
+
+pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
+
+generator = torch.Generator(device="cuda").manual_seed(8)
+image = pipeline(prompt, generator=generator, num_inference_steps=30).images[0]
+image
+```
+
+
+
+
+[`DPMSolverMultistepScheduler`] gives a reasonable speed/quality trade-off and can be run with as few as 20 steps.
+
+```python
+from diffusers import DPMSolverMultistepScheduler
+
+pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+
+generator = torch.Generator(device="cuda").manual_seed(8)
+image = pipeline(prompt, generator=generator, num_inference_steps=20).images[0]
+image
+```
+
+
+
+
+
+
+
+As you can see, most images look very similar and are arguably of very similar quality. It often really depends on the specific use case which scheduler to choose. A good approach is always to run multiple different
+schedulers to compare results.
+
+## Changing the Scheduler in Flax
+
+If you are a JAX/Flax user, you can also change the default pipeline scheduler. This is a complete example of how to run inference using the Flax Stable Diffusion pipeline and the super-fast [DPM-Solver++ scheduler](../api/schedulers/multistep_dpm_solver):
+
+```Python
+import jax
+import numpy as np
+from flax.jax_utils import replicate
+from flax.training.common_utils import shard
+
+from diffusers import FlaxStableDiffusionPipeline, FlaxDPMSolverMultistepScheduler
+
+model_id = "runwayml/stable-diffusion-v1-5"
+scheduler, scheduler_state = FlaxDPMSolverMultistepScheduler.from_pretrained(
+ model_id,
+ subfolder="scheduler"
+)
+pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
+ model_id,
+ scheduler=scheduler,
+ revision="bf16",
+ dtype=jax.numpy.bfloat16,
+)
+params["scheduler"] = scheduler_state
+
+# Generate 1 image per parallel device (8 on TPUv2-8 or TPUv3-8)
+prompt = "a photo of an astronaut riding a horse on mars"
+num_samples = jax.device_count()
+prompt_ids = pipeline.prepare_inputs([prompt] * num_samples)
+
+prng_seed = jax.random.PRNGKey(0)
+num_inference_steps = 25
+
+# shard inputs and rng
+params = replicate(params)
+prng_seed = jax.random.split(prng_seed, jax.device_count())
+prompt_ids = shard(prompt_ids)
+
+images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
+images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
+```
+
+
+
+The following Flax schedulers are _not yet compatible_ with the Flax Stable Diffusion Pipeline:
+
+- `FlaxLMSDiscreteScheduler`
+- `FlaxDDPMScheduler`
+
+
diff --git a/diffusers/docs/source/en/using-diffusers/sdxl.md b/diffusers/docs/source/en/using-diffusers/sdxl.md
new file mode 100644
index 0000000000000000000000000000000000000000..25b581fc6f6fccd95eed07c6a9eebd41d6e5b321
--- /dev/null
+++ b/diffusers/docs/source/en/using-diffusers/sdxl.md
@@ -0,0 +1,451 @@
+
+
+# Stable Diffusion XL
+
+[[open-in-colab]]
+
+[Stable Diffusion XL](https://huggingface.co/papers/2307.01952) (SDXL) is a powerful text-to-image generation model that iterates on the previous Stable Diffusion models in three key ways:
+
+1. the UNet is 3x larger and SDXL combines a second text encoder (OpenCLIP ViT-bigG/14) with the original text encoder to significantly increase the number of parameters
+2. introduces size and crop-conditioning to preserve training data from being discarded and gain more control over how a generated image should be cropped
+3. introduces a two-stage model process; the *base* model (can also be run as a standalone model) generates an image as an input to the *refiner* model which adds additional high-quality details
+
+This guide will show you how to use SDXL for text-to-image, image-to-image, and inpainting.
+
+Before you begin, make sure you have the following libraries installed:
+
+```py
+# uncomment to install the necessary libraries in Colab
+#!pip install -q diffusers transformers accelerate omegaconf invisible-watermark>=0.2.0
+```
+
+
+
+We recommend installing the [invisible-watermark](https://pypi.org/project/invisible-watermark/) library to help identify images that are generated. If the invisible-watermark library is installed, it is used by default. To disable the watermarker:
+
+```py
+pipeline = StableDiffusionXLPipeline.from_pretrained(..., add_watermarker=False)
+```
+
+
+
+## Load model checkpoints
+
+Model weights may be stored in separate subfolders on the Hub or locally, in which case, you should use the [`~StableDiffusionXLPipeline.from_pretrained`] method:
+
+```py
+from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
+import torch
+
+pipeline = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+).to("cuda")
+
+refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
+).to("cuda")
+```
+
+You can also use the [`~StableDiffusionXLPipeline.from_single_file`] method to load a model checkpoint stored in a single file format (`.ckpt` or `.safetensors`) from the Hub or locally:
+
+```py
+from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
+import torch
+
+pipeline = StableDiffusionXLPipeline.from_single_file(
+ "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+).to("cuda")
+
+refiner = StableDiffusionXLImg2ImgPipeline.from_single_file(
+ "https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/blob/main/sd_xl_refiner_1.0.safetensors", torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
+).to("cuda")
+```
+
+## Text-to-image
+
+For text-to-image, pass a text prompt. By default, SDXL generates a 1024x1024 image for the best results. You can try setting the `height` and `width` parameters to 768x768 or 512x512, but anything below 512x512 is not likely to work.
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+
+pipeline_text2image = AutoPipelineForText2Image.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+).to("cuda")
+
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+image = pipeline_text2image(prompt=prompt).images[0]
+image
+```
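+
+If you want a different resolution, pass `height` and `width` explicitly. Here is a minimal sketch reusing the `pipeline_text2image` object from above (expect lower fidelity than the native 1024x1024):
+
+```py
+# 768x768 instead of the default 1024x1024; anything below 512x512 is unlikely to work well
+image = pipeline_text2image(prompt=prompt, height=768, width=768).images[0]
+image
+```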
+
+
+
+
+
+## Image-to-image
+
+For image-to-image, SDXL works especially well with image sizes between 768x768 and 1024x1024. Pass an initial image, and a text prompt to condition the image with:
+
+```py
+from diffusers import AutoPipelineForImage2Image
+from diffusers.utils import load_image, make_image_grid
+
+# use from_pipe to avoid consuming additional memory when loading a checkpoint
+pipeline = AutoPipelineForImage2Image.from_pipe(pipeline_text2image).to("cuda")
+
+url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png"
+init_image = load_image(url)
+prompt = "a dog catching a frisbee in the jungle"
+image = pipeline(prompt, image=init_image, strength=0.8, guidance_scale=10.5).images[0]
+make_image_grid([init_image, image], rows=1, cols=2)
+```
+
+
+
+
+
+## Inpainting
+
+For inpainting, you'll need the original image and a mask of what you want to replace in the original image. Create a prompt to describe what you want to replace the masked area with.
+
+```py
+from diffusers import AutoPipelineForInpainting
+from diffusers.utils import load_image, make_image_grid
+
+# use from_pipe to avoid consuming additional memory when loading a checkpoint
+pipeline = AutoPipelineForInpainting.from_pipe(pipeline_text2image).to("cuda")
+
+img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png"
+mask_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-inpaint-mask.png"
+
+init_image = load_image(img_url)
+mask_image = load_image(mask_url)
+
+prompt = "A deep sea diver floating"
+image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, strength=0.85, guidance_scale=12.5).images[0]
+make_image_grid([init_image, mask_image, image], rows=1, cols=3)
+```
+
+
+
+
+
+## Refine image quality
+
+SDXL includes a [refiner model](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0) specialized in denoising low-noise stage images to generate higher-quality images from the base model. There are two ways to use the refiner:
+
+1. use the base and refiner models together to produce a refined image
+2. use the base model to produce an image, and subsequently use the refiner model to add more details to the image (this is how SDXL was originally trained)
+
+### Base + refiner model
+
+When you use the base and refiner model together to generate an image, this is known as an [*ensemble of expert denoisers*](https://research.nvidia.com/labs/dir/eDiff-I/). The ensemble of expert denoisers approach requires fewer overall denoising steps versus passing the base model's output to the refiner model, so it should be significantly faster to run. However, you won't be able to inspect the base model's output because it still contains a large amount of noise.
+
+As an ensemble of expert denoisers, the base model serves as the expert during the high-noise diffusion stage and the refiner model serves as the expert during the low-noise diffusion stage. Load the base and refiner model:
+
+```py
+from diffusers import DiffusionPipeline
+import torch
+
+base = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+).to("cuda")
+
+refiner = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-refiner-1.0",
+ text_encoder_2=base.text_encoder_2,
+ vae=base.vae,
+ torch_dtype=torch.float16,
+ use_safetensors=True,
+ variant="fp16",
+).to("cuda")
+```
+
+To use this approach, you need to define the number of timesteps for each model to run through their respective stages. For the base model, this is controlled by the [`denoising_end`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLPipeline.__call__.denoising_end) parameter and for the refiner model, it is controlled by the [`denoising_start`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLImg2ImgPipeline.__call__.denoising_start) parameter.
+
+
+
+The `denoising_end` and `denoising_start` parameters should be a float between 0 and 1. These parameters are represented as a proportion of discrete timesteps as defined by the scheduler. If you're also using the `strength` parameter, it'll be ignored because the number of denoising steps is determined by the discrete timesteps the model is trained on and the declared fractional cutoff.
+
+
+
+Let's set `denoising_end=0.8` so the base model performs the first 80% of denoising the **high-noise** timesteps and set `denoising_start=0.8` so the refiner model performs the last 20% of denoising the **low-noise** timesteps. The base model output should be in **latent** space instead of a PIL image.
+
+```py
+prompt = "A majestic lion jumping from a big stone at night"
+
+image = base(
+ prompt=prompt,
+ num_inference_steps=40,
+ denoising_end=0.8,
+ output_type="latent",
+).images
+image = refiner(
+ prompt=prompt,
+ num_inference_steps=40,
+ denoising_start=0.8,
+ image=image,
+).images[0]
+image
+```
+
+(figure: comparison of the default base model output vs. the ensemble of expert denoisers output)
+
+The refiner model can also be used for inpainting in the [`StableDiffusionXLInpaintPipeline`]:
+
+```py
+from diffusers import StableDiffusionXLInpaintPipeline
+from diffusers.utils import load_image, make_image_grid
+import torch
+
+base = StableDiffusionXLInpaintPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+).to("cuda")
+
+refiner = StableDiffusionXLInpaintPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-refiner-1.0",
+ text_encoder_2=base.text_encoder_2,
+ vae=base.vae,
+ torch_dtype=torch.float16,
+ use_safetensors=True,
+ variant="fp16",
+).to("cuda")
+
+img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+init_image = load_image(img_url)
+mask_image = load_image(mask_url)
+
+prompt = "A majestic tiger sitting on a bench"
+num_inference_steps = 75
+high_noise_frac = 0.7
+
+image = base(
+ prompt=prompt,
+ image=init_image,
+ mask_image=mask_image,
+ num_inference_steps=num_inference_steps,
+ denoising_end=high_noise_frac,
+ output_type="latent",
+).images
+image = refiner(
+ prompt=prompt,
+ image=image,
+ mask_image=mask_image,
+ num_inference_steps=num_inference_steps,
+ denoising_start=high_noise_frac,
+).images[0]
+make_image_grid([init_image, mask_image, image.resize((512, 512))], rows=1, cols=3)
+```
+
+This ensemble of expert denoisers method works well for all available schedulers!
+
+### Base to refiner model
+
+SDXL gets a boost in image quality by using the refiner model to add additional high-quality details to the fully-denoised image from the base model, in an image-to-image setting.
+
+Load the base and refiner models:
+
+```py
+from diffusers import DiffusionPipeline
+import torch
+
+base = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+).to("cuda")
+
+refiner = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-refiner-1.0",
+ text_encoder_2=base.text_encoder_2,
+ vae=base.vae,
+ torch_dtype=torch.float16,
+ use_safetensors=True,
+ variant="fp16",
+).to("cuda")
+```
+
+Generate an image from the base model, and set the model output to **latent** space:
+
+```py
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+
+image = base(prompt=prompt, output_type="latent").images[0]
+```
+
+Pass the generated image to the refiner model:
+
+```py
+image = refiner(prompt=prompt, image=image[None, :]).images[0]
+```
+
+
+(figure: "base model" vs. "base model + refiner model")
+
+For inpainting, load the base and the refiner model in the [`StableDiffusionXLInpaintPipeline`], remove the `denoising_end` and `denoising_start` parameters, and choose a smaller number of inference steps for the refiner.
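+
+A minimal sketch of that setup is shown below, reusing the `base` and `refiner` inpainting pipelines and the `init_image`/`mask_image` from the example above (the step counts and `strength` value are illustrative choices, not a prescribed recipe):
+
+```py
+prompt = "A majestic tiger sitting on a bench"
+
+# fully denoise with the base model (no denoising_end)
+image = base(
+    prompt=prompt,
+    image=init_image,
+    mask_image=mask_image,
+    num_inference_steps=75,
+).images[0]
+
+# refine with fewer steps and a low strength so only details change (no denoising_start)
+image = refiner(
+    prompt=prompt,
+    image=image,
+    mask_image=mask_image,
+    strength=0.3,
+    num_inference_steps=30,
+).images[0]
+make_image_grid([init_image, mask_image, image.resize((512, 512))], rows=1, cols=3)
+```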
+
+## Micro-conditioning
+
+SDXL training involves several additional conditioning techniques, which are referred to as *micro-conditioning*. These include original image size, target image size, and cropping parameters. The micro-conditionings can be used at inference time to create high-quality, centered images.
+
+
+
+You can use both micro-conditioning and negative micro-conditioning parameters thanks to classifier-free guidance. They are available in the [`StableDiffusionXLPipeline`], [`StableDiffusionXLImg2ImgPipeline`], [`StableDiffusionXLInpaintPipeline`], and [`StableDiffusionXLControlNetPipeline`].
+
+
+
+### Size conditioning
+
+There are two types of size conditioning:
+
+- [`original_size`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLPipeline.__call__.original_size) conditioning comes from upscaled images in the training batch (because it would be wasteful to discard the smaller images which make up almost 40% of the total training data). This way, SDXL learns that upscaling artifacts are not supposed to be present in high-resolution images. During inference, you can use `original_size` to indicate the original image resolution. Using the default value of `(1024, 1024)` produces higher-quality images that resemble the 1024x1024 images in the dataset. If you choose to use a lower resolution, such as `(256, 256)`, the model still generates 1024x1024 images, but they'll look like the low resolution images (simpler patterns, blurring) in the dataset.
+
+- [`target_size`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLPipeline.__call__.target_size) conditioning comes from finetuning SDXL to support different image aspect ratios. During inference, if you use the default value of `(1024, 1024)`, you'll get an image that resembles the composition of square images in the dataset. We recommend using the same value for `target_size` and `original_size`, but feel free to experiment with other options!
+
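+For example, you can pass both parameters directly at call time. This is a minimal sketch where the conditioning values are only illustrative (a large `original_size` asks for the crispness of high-resolution training images):
+
+```py
+from diffusers import StableDiffusionXLPipeline
+import torch
+
+pipeline = StableDiffusionXLPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+).to("cuda")
+
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+image = pipeline(prompt=prompt, original_size=(4096, 4096), target_size=(1024, 1024)).images[0]
+image
+```
+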
+🤗 Diffusers also lets you specify negative conditions about an image's size to steer generation away from certain image resolutions:
+
+```py
+from diffusers import StableDiffusionXLPipeline
+import torch
+
+pipe = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+).to("cuda")
+
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+image = pipe(
+ prompt=prompt,
+ negative_original_size=(512, 512),
+ negative_target_size=(1024, 1024),
+).images[0]
+```
+
+(figure: images negatively conditioned on image resolutions of (128, 128), (256, 256), and (512, 512))
+
+### Crop conditioning
+
+Images generated by previous Stable Diffusion models may sometimes appear to be cropped. This is because images are actually cropped during training so that all the images in a batch have the same size. By conditioning on crop coordinates, SDXL *learns* that no cropping - coordinates `(0, 0)` - usually correlates with centered subjects and complete faces (this is the default value in 🤗 Diffusers). You can experiment with different coordinates if you want to generate off-centered compositions!
+
+```py
+from diffusers import StableDiffusionXLPipeline
+import torch
+
+pipeline = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+).to("cuda")
+
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+image = pipeline(prompt=prompt, crops_coords_top_left=(256, 0)).images[0]
+image
+```
+
+
+
+
+
+You can also specify negative cropping coordinates to steer generation away from certain cropping parameters:
+
+```py
+from diffusers import StableDiffusionXLPipeline
+import torch
+
+pipe = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+).to("cuda")
+
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+image = pipe(
+ prompt=prompt,
+ negative_original_size=(512, 512),
+ negative_crops_coords_top_left=(0, 0),
+ negative_target_size=(1024, 1024),
+).images[0]
+image
+```
+
+## Use a different prompt for each text-encoder
+
+SDXL uses two text-encoders, so it is possible to pass a different prompt to each text-encoder, which can [improve quality](https://github.com/huggingface/diffusers/issues/4004#issuecomment-1627764201). Pass your original prompt to `prompt` and the second prompt to `prompt_2` (use `negative_prompt` and `negative_prompt_2` if you're using negative prompts):
+
+```py
+from diffusers import StableDiffusionXLPipeline
+import torch
+
+pipeline = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+).to("cuda")
+
+# prompt is passed to OAI CLIP-ViT/L-14
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+# prompt_2 is passed to OpenCLIP-ViT/bigG-14
+prompt_2 = "Van Gogh painting"
+image = pipeline(prompt=prompt, prompt_2=prompt_2).images[0]
+image
+```
+
+
+
+
+
+The dual text-encoders also support textual inversion embeddings that need to be loaded separately as explained in the [SDXL textual inversion](textual_inversion_inference#stable-diffusion-xl) section.
+
+## Optimizations
+
+SDXL is a large model, and you may need to optimize memory to get it to run on your hardware. Here are some tips to save memory and speed up inference.
+
+1. Offload the model to the CPU with [`~StableDiffusionXLPipeline.enable_model_cpu_offload`] to avoid out-of-memory errors:
+
+```diff
+- base.to("cuda")
+- refiner.to("cuda")
++ base.enable_model_cpu_offload()
++ refiner.enable_model_cpu_offload()
+```
+
+2. Use `torch.compile` for ~20% speed-up (you need `torch>=2.0`):
+
+```diff
++ base.unet = torch.compile(base.unet, mode="reduce-overhead", fullgraph=True)
++ refiner.unet = torch.compile(refiner.unet, mode="reduce-overhead", fullgraph=True)
+```
+
+3. Enable [xFormers](../optimization/xformers) to run SDXL if `torch<2.0`:
+
+```diff
++ base.enable_xformers_memory_efficient_attention()
++ refiner.enable_xformers_memory_efficient_attention()
+```
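+
+Putting these together, a minimal sketch might look like the following (pick either CPU offloading or moving the pipelines to the GPU, not both; `torch.compile` assumes `torch>=2.0`):
+
+```py
+from diffusers import DiffusionPipeline
+import torch
+
+base = DiffusionPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+)
+refiner = DiffusionPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-refiner-1.0",
+    text_encoder_2=base.text_encoder_2,
+    vae=base.vae,
+    torch_dtype=torch.float16,
+    use_safetensors=True,
+    variant="fp16",
+)
+
+# option 1: offload submodules to the CPU as needed to avoid out-of-memory errors
+base.enable_model_cpu_offload()
+refiner.enable_model_cpu_offload()
+
+# option 2 (instead of offloading): keep everything on the GPU and compile the UNets
+# base.to("cuda")
+# refiner.to("cuda")
+# base.unet = torch.compile(base.unet, mode="reduce-overhead", fullgraph=True)
+# refiner.unet = torch.compile(refiner.unet, mode="reduce-overhead", fullgraph=True)
+```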
+
+## Other resources
+
+If you're interested in experimenting with a minimal version of the [`UNet2DConditionModel`] used in SDXL, take a look at the [minSDXL](https://github.com/cloneofsimo/minSDXL) implementation which is written in PyTorch and directly compatible with 🤗 Diffusers.
diff --git a/diffusers/docs/source/en/using-diffusers/shap-e.md b/diffusers/docs/source/en/using-diffusers/shap-e.md
new file mode 100644
index 0000000000000000000000000000000000000000..f0ce977584a5aa5ab8a8f4790c2d4ce21524d05a
--- /dev/null
+++ b/diffusers/docs/source/en/using-diffusers/shap-e.md
@@ -0,0 +1,192 @@
+
+
+# Shap-E
+
+[[open-in-colab]]
+
+Shap-E is a conditional model for generating 3D assets which could be used for video game development, interior design, and architecture. It is trained on a large dataset of 3D assets, and post-processed to render more views of each object and produce 16K instead of 4K point clouds. The Shap-E model is trained in two steps:
+
+1. an encoder accepts the point clouds and rendered views of a 3D asset and outputs the parameters of implicit functions that represent the asset
+2. a diffusion model is trained on the latents produced by the encoder to generate either neural radiance fields (NeRFs) or a textured 3D mesh, making it easier to render and use the 3D asset in downstream applications
+
+This guide will show you how to use Shap-E to start generating your own 3D assets!
+
+Before you begin, make sure you have the following libraries installed:
+
+```py
+# uncomment to install the necessary libraries in Colab
+#!pip install -q diffusers transformers accelerate trimesh
+```
+
+## Text-to-3D
+
+To generate a gif of a 3D object, pass a text prompt to the [`ShapEPipeline`]. The pipeline generates a list of image frames which are used to create the 3D object.
+
+```py
+import torch
+from diffusers import ShapEPipeline
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+pipe = ShapEPipeline.from_pretrained("openai/shap-e", torch_dtype=torch.float16, variant="fp16")
+pipe = pipe.to(device)
+
+guidance_scale = 15.0
+prompt = ["A firecracker", "A birthday cupcake"]
+
+images = pipe(
+ prompt,
+ guidance_scale=guidance_scale,
+ num_inference_steps=64,
+ frame_size=256,
+).images
+```
+
+Now use the [`~utils.export_to_gif`] function to turn the list of image frames into a gif of the 3D object.
+
+```py
+from diffusers.utils import export_to_gif
+
+export_to_gif(images[0], "firecracker_3d.gif")
+export_to_gif(images[1], "cake_3d.gif")
+```
+
+(figure: generated gifs for prompt = "A firecracker" and prompt = "A birthday cupcake")
+
+## Image-to-3D
+
+To generate a 3D object from another image, use the [`ShapEImg2ImgPipeline`]. You can use an existing image or generate an entirely new one. Let's use the [Kandinsky 2.1](../api/pipelines/kandinsky) model to generate a new image.
+
+```py
+from diffusers import DiffusionPipeline
+import torch
+
+prior_pipeline = DiffusionPipeline.from_pretrained("kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16, use_safetensors=True).to("cuda")
+pipeline = DiffusionPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16, use_safetensors=True).to("cuda")
+
+prompt = "A cheeseburger, white background"
+
+image_embeds, negative_image_embeds = prior_pipeline(prompt, guidance_scale=1.0).to_tuple()
+image = pipeline(
+ prompt,
+ image_embeds=image_embeds,
+ negative_image_embeds=negative_image_embeds,
+).images[0]
+
+image.save("burger.png")
+```
+
+Pass the cheeseburger to the [`ShapEImg2ImgPipeline`] to generate a 3D representation of it.
+
+```py
+from PIL import Image
+from diffusers import ShapEImg2ImgPipeline
+from diffusers.utils import export_to_gif
+
+pipe = ShapEImg2ImgPipeline.from_pretrained("openai/shap-e-img2img", torch_dtype=torch.float16, variant="fp16").to("cuda")
+
+guidance_scale = 3.0
+image = Image.open("burger.png").resize((256, 256))
+
+images = pipe(
+ image,
+ guidance_scale=guidance_scale,
+ num_inference_steps=64,
+ frame_size=256,
+).images
+
+gif_path = export_to_gif(images[0], "burger_3d.gif")
+```
+
+(figure: the generated cheeseburger image and the 3D cheeseburger gif)
+
+
+## Generate mesh
+
+Shap-E is a flexible model that can also generate textured mesh outputs to be rendered for downstream applications. In this example, you'll convert the output into a `glb` file because the 🤗 Datasets library supports mesh visualization of `glb` files which can be rendered by the [Dataset viewer](https://huggingface.co/docs/hub/datasets-viewer#dataset-preview).
+
+You can generate mesh outputs for both the [`ShapEPipeline`] and [`ShapEImg2ImgPipeline`] by specifying the `output_type` parameter as `"mesh"`:
+
+```py
+import torch
+from diffusers import ShapEPipeline
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+pipe = ShapEPipeline.from_pretrained("openai/shap-e", torch_dtype=torch.float16, variant="fp16")
+pipe = pipe.to(device)
+
+guidance_scale = 15.0
+prompt = "A birthday cupcake"
+
+images = pipe(prompt, guidance_scale=guidance_scale, num_inference_steps=64, frame_size=256, output_type="mesh").images
+```
+
+Use the [`~utils.export_to_ply`] function to save the mesh output as a `ply` file:
+
+
+
+You can optionally save the mesh output as an `obj` file with the [`~utils.export_to_obj`] function. The ability to save the mesh output in a variety of formats makes it more flexible for downstream usage!
+
+
+
+```py
+from diffusers.utils import export_to_ply
+
+ply_path = export_to_ply(images[0], "3d_cake.ply")
+print(f"Saved to: {ply_path}")
+```
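+
+If you'd rather work with an `obj` file, the sketch below assumes [`~utils.export_to_obj`] follows the same call pattern as `export_to_ply`:
+
+```py
+from diffusers.utils import export_to_obj
+
+obj_path = export_to_obj(images[0], "3d_cake.obj")
+```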
+
+Then you can convert the `ply` file to a `glb` file with the trimesh library:
+
+```py
+import trimesh
+
+mesh = trimesh.load("3d_cake.ply")
+mesh_export = mesh.export("3d_cake.glb", file_type="glb")
+```
+
+By default, the mesh output is focused from the bottom viewpoint but you can change the default viewpoint by applying a rotation transform:
+
+```py
+import trimesh
+import numpy as np
+
+mesh = trimesh.load("3d_cake.ply")
+rot = trimesh.transformations.rotation_matrix(-np.pi / 2, [1, 0, 0])
+mesh = mesh.apply_transform(rot)
+mesh_export = mesh.export("3d_cake.glb", file_type="glb")
+```
+
+Upload the mesh file to your dataset repository to visualize it with the Dataset viewer!
+
+
+
+
diff --git a/diffusers/docs/source/en/using-diffusers/stable_diffusion_jax_how_to.md b/diffusers/docs/source/en/using-diffusers/stable_diffusion_jax_how_to.md
new file mode 100644
index 0000000000000000000000000000000000000000..6f75ba2c399977e0ac8934e4661982bd35f5a8db
--- /dev/null
+++ b/diffusers/docs/source/en/using-diffusers/stable_diffusion_jax_how_to.md
@@ -0,0 +1,219 @@
+
+
+# JAX/Flax
+
+[[open-in-colab]]
+
+🤗 Diffusers supports Flax for super fast inference on Google TPUs, such as those available in Colab, Kaggle or Google Cloud Platform. This guide shows you how to run inference with Stable Diffusion using JAX/Flax.
+
+Before you begin, make sure you have the necessary libraries installed:
+
+```py
+# uncomment to install the necessary libraries in Colab
+#!pip install -q jax==0.3.25 jaxlib==0.3.25 flax transformers ftfy
+#!pip install -q diffusers
+```
+
+You should also make sure you're using a TPU backend. While JAX does not run exclusively on TPUs, you'll get the best performance on a TPU because each server has 8 TPU accelerators working in parallel.
+
+If you are running this guide in Colab, select *Runtime* in the menu above, select the option *Change runtime type*, and then select *TPU* under the *Hardware accelerator* setting. Import JAX and quickly check whether you're using a TPU:
+
+```python
+import jax
+import jax.tools.colab_tpu
+jax.tools.colab_tpu.setup_tpu()
+
+num_devices = jax.device_count()
+device_type = jax.devices()[0].device_kind
+
+print(f"Found {num_devices} JAX devices of type {device_type}.")
+assert "TPU" in device_type, (
+    "Available device is not a TPU, please select TPU from Runtime > Change runtime type > Hardware accelerator"
+)
+# Found 8 JAX devices of type Cloud TPU.
+```
+
+Great, now you can import the rest of the dependencies you'll need:
+
+```python
+import jax.numpy as jnp
+from jax import pmap
+from flax.jax_utils import replicate
+from flax.training.common_utils import shard
+
+from diffusers import FlaxStableDiffusionPipeline
+```
+
+## Load a model
+
+Flax is a functional framework, so models are stateless and parameters are stored outside of them. Loading a pretrained Flax pipeline returns *both* the pipeline and the model weights (or parameters). In this guide, you'll use `bfloat16`, a more efficient half-float type that is supported by TPUs (you can also use `float32` for full precision if you want).
+
+```python
+dtype = jnp.bfloat16
+pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ revision="bf16",
+ dtype=dtype,
+)
+```
+
+## Inference
+
+TPUs usually have 8 devices working in parallel, so let's use the same prompt for each device. This means you can perform inference on 8 devices at once, with each device generating one image. As a result, you'll get 8 images in the same amount of time it takes for one chip to generate a single image!
+
+
+
+Learn more details in the [How does parallelization work?](#how-does-parallelization-work) section.
+
+
+
+After replicating the prompt, get the tokenized text ids by calling the `prepare_inputs` function on the pipeline. The length of the tokenized text is set to 77 tokens as required by the configuration of the underlying CLIP text model.
+
+```python
+prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic"
+prompt = [prompt] * jax.device_count()
+prompt_ids = pipeline.prepare_inputs(prompt)
+prompt_ids.shape
+# (8, 77)
+```
+
+Model parameters and inputs have to be replicated across the 8 parallel devices. The parameters dictionary is replicated with [`flax.jax_utils.replicate`](https://flax.readthedocs.io/en/latest/api_reference/flax.jax_utils.html#flax.jax_utils.replicate) which traverses the dictionary and changes the shape of the weights so they are repeated 8 times. Arrays are replicated using `shard`.
+
+```python
+# parameters
+p_params = replicate(params)
+
+# arrays
+prompt_ids = shard(prompt_ids)
+prompt_ids.shape
+# (8, 1, 77)
+```
+
+This shape means each one of the 8 devices receives as an input a `jnp` array with shape `(1, 77)`, where `1` is the batch size per device. On TPUs with sufficient memory, you could have a batch size larger than `1` if you want to generate multiple images (per chip) at once.
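+
+For example, here is a sketch of what a per-device batch of 2 would look like; it is only meant to illustrate the shapes, and the rest of the guide keeps one image per device:
+
+```python
+# illustrative only: 2 prompts per device on an 8-device host
+batch_per_device = 2
+batched_prompt_ids = pipeline.prepare_inputs([prompt[0]] * jax.device_count() * batch_per_device)
+batched_prompt_ids = shard(batched_prompt_ids)
+batched_prompt_ids.shape
+# (8, 2, 77)
+```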
+
+Next, create a random number generator to pass to the generation function. This is standard procedure in Flax, which is very serious and opinionated about random numbers. All functions that deal with random numbers are expected to receive a generator to ensure reproducibility, even when you're training across multiple distributed devices.
+
+The helper function below uses a seed to initialize a random number generator. As long as you use the same seed, you'll get the exact same results. Feel free to use different seeds when exploring results later in the guide.
+
+```python
+def create_key(seed=0):
+ return jax.random.PRNGKey(seed)
+```
+
+The `rng` returned by the helper function is split 8 times so each device receives a different generator and generates a different image.
+
+```python
+rng = create_key(0)
+rng = jax.random.split(rng, jax.device_count())
+```
+
+To take advantage of JAX's optimized speed on a TPU, pass `jit=True` to the pipeline to compile the JAX code into an efficient representation and to ensure the model runs in parallel across the 8 devices.
+
+
+
+You need to ensure all your inputs have the same shape in subsequent calls, otherwise JAX will need to recompile the code which is slower.
+
+
+
+The first inference run takes more time because it needs to compile the code, but subsequent calls (even with different inputs) are much faster. For example, it took more than a minute to compile on a TPU v2-8, but then it takes about **7s** on subsequent inference runs!
+
+```py
+%%time
+images = pipeline(prompt_ids, p_params, rng, jit=True)[0]
+
+# CPU times: user 56.2 s, sys: 42.5 s, total: 1min 38s
+# Wall time: 1min 29s
+```
+
+The returned array has shape `(8, 1, 512, 512, 3)` which should be reshaped to remove the second dimension and get 8 images of `512 × 512 × 3`. Then you can use the [`~utils.numpy_to_pil`] function to convert the arrays into images.
+
+```python
+from diffusers.utils import make_image_grid
+
+images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
+images = pipeline.numpy_to_pil(images)
+make_image_grid(images, rows=2, cols=4)
+```
+
+![img](https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/stable_diffusion_jax_how_to_cell_38_output_0.jpeg)
+
+## Using different prompts
+
+You don't necessarily have to use the same prompt on all devices. For example, to generate 8 different prompts:
+
+```python
+prompts = [
+ "Labrador in the style of Hokusai",
+ "Painting of a squirrel skating in New York",
+ "HAL-9000 in the style of Van Gogh",
+ "Times Square under water, with fish and a dolphin swimming around",
+ "Ancient Roman fresco showing a man working on his laptop",
+ "Close-up photograph of young black woman against urban background, high quality, bokeh",
+ "Armchair in the shape of an avocado",
+ "Clown astronaut in space, with Earth in the background",
+]
+
+prompt_ids = pipeline.prepare_inputs(prompts)
+prompt_ids = shard(prompt_ids)
+
+images = pipeline(prompt_ids, p_params, rng, jit=True).images
+images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
+images = pipeline.numpy_to_pil(images)
+
+make_image_grid(images, 2, 4)
+```
+
+![img](https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/stable_diffusion_jax_how_to_cell_43_output_0.jpeg)
+
+## How does parallelization work?
+
+The Flax pipeline in 🤗 Diffusers automatically compiles the model and runs it in parallel on all available devices. Let's take a closer look at how that process works.
+
+JAX parallelization can be done in multiple ways. The easiest one revolves around using the [`jax.pmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html) function to achieve single-program multiple-data (SPMD) parallelization. It means running several copies of the same code, each on different data inputs. More sophisticated approaches are possible, and you can go over to the JAX [documentation](https://jax.readthedocs.io/en/latest/index.html) to explore this topic in more detail if you are interested!
+
+`jax.pmap` does two things:
+
+1. Compiles (or "`jit`s") the code, similar to `jax.jit()`. Compilation doesn't happen when you call `pmap`; it only happens the first time the `pmap`ped function is called.
+2. Ensures the compiled code runs in parallel on all available devices.
+
+To demonstrate, call `pmap` on the pipeline's `_generate` method (this is a private method that generates images and may be renamed or removed in future releases of 🤗 Diffusers):
+
+```python
+p_generate = pmap(pipeline._generate)
+```
+
+After calling `pmap`, the prepared function `p_generate` will:
+
+1. Make a copy of the underlying function, `pipeline._generate`, on each device.
+2. Send each device a different portion of the input arguments (this is why it's necessary to call the *shard* function). In this case, `prompt_ids` has shape `(8, 1, 77)` so the array is split into 8 and each copy of `_generate` receives an input with shape `(1, 77)`.
+
+The most important thing to pay attention to here is the batch size (1 in this example), and the input dimensions that make sense for your code. You don't have to change anything else to make the code work in parallel.
+
+The first time you call the pipeline takes more time, but the calls afterward are much faster. The `block_until_ready` function is used to correctly measure inference time because JAX uses asynchronous dispatch and returns control to the Python loop as soon as it can. You don't need to use that in your code; blocking occurs automatically when you want to use the result of a computation that has not yet been materialized.
+
+```py
+%%time
+images = p_generate(prompt_ids, p_params, rng)
+images = images.block_until_ready()
+
+# CPU times: user 1min 15s, sys: 18.2 s, total: 1min 34s
+# Wall time: 1min 15s
+```
+
+Check your image dimensions to see if they're correct:
+
+```python
+images.shape
+# (8, 1, 512, 512, 3)
+```
diff --git a/diffusers/docs/source/en/using-diffusers/textual_inversion_inference.md b/diffusers/docs/source/en/using-diffusers/textual_inversion_inference.md
new file mode 100644
index 0000000000000000000000000000000000000000..084101c06ba326ba87407c7c1ed558f66b33a8c2
--- /dev/null
+++ b/diffusers/docs/source/en/using-diffusers/textual_inversion_inference.md
@@ -0,0 +1,118 @@
+
+
+# Textual inversion
+
+[[open-in-colab]]
+
+The [`StableDiffusionPipeline`] supports textual inversion, a technique that enables a model like Stable Diffusion to learn a new concept from just a few sample images. This gives you more control over the generated images and allows you to tailor the model towards specific concepts. You can get started quickly with a collection of community created concepts in the [Stable Diffusion Conceptualizer](https://huggingface.co/spaces/sd-concepts-library/stable-diffusion-conceptualizer).
+
+This guide will show you how to run inference with textual inversion using a pre-learned concept from the Stable Diffusion Conceptualizer. If you're interested in teaching a model new concepts with textual inversion, take a look at the [Textual Inversion](../training/text_inversion) training guide.
+
+Import the necessary libraries:
+
+```py
+import torch
+from diffusers import StableDiffusionPipeline
+from diffusers.utils import make_image_grid
+```
+
+## Stable Diffusion 1 and 2
+
+Pick a Stable Diffusion checkpoint and a pre-learned concept from the [Stable Diffusion Conceptualizer](https://huggingface.co/spaces/sd-concepts-library/stable-diffusion-conceptualizer):
+
+```py
+pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"
+repo_id_embeds = "sd-concepts-library/cat-toy"
+```
+
+Now you can load a pipeline, and pass the pre-learned concept to it:
+
+```py
+pipeline = StableDiffusionPipeline.from_pretrained(
+ pretrained_model_name_or_path, torch_dtype=torch.float16, use_safetensors=True
+).to("cuda")
+
+pipeline.load_textual_inversion(repo_id_embeds)
+```
+
+Create a prompt with the pre-learned concept by using the special placeholder token `<cat-toy>`, and choose the number of samples and rows of images you'd like to generate:
+
+```py
+prompt = "a grafitti in a favela wall with a <cat-toy> on it"
+
+num_samples_per_row = 2
+num_rows = 2
+```
+
+Then run the pipeline (feel free to adjust the parameters like `num_inference_steps` and `guidance_scale` to see how they affect image quality), save the generated images and visualize them with the `make_image_grid` helper function you imported at the beginning:
+
+```py
+all_images = []
+for _ in range(num_rows):
+ images = pipeline(prompt, num_images_per_prompt=num_samples_per_row, num_inference_steps=50, guidance_scale=7.5).images
+ all_images.extend(images)
+
+grid = make_image_grid(all_images, num_rows, num_samples_per_row)
+grid
+```
+
+
+
+
+
+## Stable Diffusion XL
+
+Stable Diffusion XL (SDXL) can also use textual inversion vectors for inference. In contrast to Stable Diffusion 1 and 2, SDXL has two text encoders so you'll need two textual inversion embeddings - one for each text encoder model.
+
+Let's download the SDXL textual inversion embeddings and have a closer look at their structure:
+
+```py
+from huggingface_hub import hf_hub_download
+from safetensors.torch import load_file
+
+file = hf_hub_download("dn118/unaestheticXL", filename="unaestheticXLv31.safetensors")
+state_dict = load_file(file)
+state_dict
+```
+
+```
+{'clip_g': tensor([[ 0.0077, -0.0112, 0.0065, ..., 0.0195, 0.0159, 0.0275],
+ ...,
+ [-0.0170, 0.0213, 0.0143, ..., -0.0302, -0.0240, -0.0362]],
+ 'clip_l': tensor([[ 0.0023, 0.0192, 0.0213, ..., -0.0385, 0.0048, -0.0011],
+ ...,
+ [ 0.0475, -0.0508, -0.0145, ..., 0.0070, -0.0089, -0.0163]],
+```
+
+There are two tensors, `"clip_g"` and `"clip_l"`.
+`"clip_g"` corresponds to the bigger text encoder in SDXL and refers to
+`pipe.text_encoder_2` and `"clip_l"` refers to `pipe.text_encoder`.
+
+Now you can load each tensor separately by passing them along with the correct text encoder and tokenizer
+to [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`]:
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+
+pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", variant="fp16", torch_dtype=torch.float16)
+pipe.to("cuda")
+
+pipe.load_textual_inversion(state_dict["clip_g"], token="unaestheticXLv31", text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
+pipe.load_textual_inversion(state_dict["clip_l"], token="unaestheticXLv31", text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
+
+# the embedding should be used as a negative embedding, so we pass it as a negative prompt
+generator = torch.Generator().manual_seed(33)
+image = pipe("a woman standing in front of a mountain", negative_prompt="unaestheticXLv31", generator=generator).images[0]
+image
+```
diff --git a/diffusers/docs/source/en/using-diffusers/unconditional_image_generation.md b/diffusers/docs/source/en/using-diffusers/unconditional_image_generation.md
new file mode 100644
index 0000000000000000000000000000000000000000..1983f6981e8fe98cb8e52cb09ceda07ae3d20cef
--- /dev/null
+++ b/diffusers/docs/source/en/using-diffusers/unconditional_image_generation.md
@@ -0,0 +1,68 @@
+
+
+# Unconditional image generation
+
+[[open-in-colab]]
+
+Unconditional image generation is a relatively straightforward task. The model only generates images - without any additional context like text or an image - resembling the data it was trained on.
+
+The [`DiffusionPipeline`] is the easiest way to use a pre-trained diffusion system for inference.
+
+Start by creating an instance of [`DiffusionPipeline`] and specify which pipeline checkpoint you would like to download.
+You can use any of the 🧨 Diffusers [checkpoints](https://huggingface.co/models?library=diffusers&sort=downloads) from the Hub (the checkpoint you'll use generates images of butterflies).
+
+
+
+💡 Want to train your own unconditional image generation model? Take a look at the training [guide](../training/unconditional_training) to learn how to generate your own images.
+
+
+
+In this guide, you'll use [`DiffusionPipeline`] for unconditional image generation with [DDPM](https://arxiv.org/abs/2006.11239):
+
+```python
+from diffusers import DiffusionPipeline
+
+generator = DiffusionPipeline.from_pretrained("anton-l/ddpm-butterflies-128", use_safetensors=True)
+```
+
+The [`DiffusionPipeline`] downloads and caches all modeling, tokenization, and scheduling components.
+Diffusion model inference can be compute-intensive, so we strongly recommend running the pipeline on a GPU.
+You can move the generator object to a GPU, just like you would in PyTorch:
+
+```python
+generator.to("cuda")
+```
+
+Now you can use the `generator` to generate an image:
+
+```python
+image = generator().images[0]
+image
+```
+
+The output is by default wrapped into a [`PIL.Image`](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class) object.
+
+You can save the image by calling:
+
+```python
+image.save("generated_image.png")
+```
+
+Try out the pipeline yourself, and feel free to play around with the `num_inference_steps` parameter to see how it affects the image quality!
+
+
diff --git a/diffusers/docs/source/en/using-diffusers/using_safetensors.md b/diffusers/docs/source/en/using-diffusers/using_safetensors.md
new file mode 100644
index 0000000000000000000000000000000000000000..3e89e7eed9a014b1846e0d6799c41cf0a8311455
--- /dev/null
+++ b/diffusers/docs/source/en/using-diffusers/using_safetensors.md
@@ -0,0 +1,84 @@
+
+
+# Load safetensors
+
+[[open-in-colab]]
+
+[safetensors](https://github.com/huggingface/safetensors) is a safe and fast file format for storing and loading tensors. Typically, PyTorch model weights are saved or *pickled* into a `.bin` file with Python's [`pickle`](https://docs.python.org/3/library/pickle.html) utility. However, `pickle` is not secure and pickled files may contain malicious code that can be executed. safetensors is a secure alternative to `pickle`, making it ideal for sharing model weights.
+
+This guide will show you how to load `.safetensors` files, and how to convert Stable Diffusion model weights stored in other formats to `.safetensors`. Before you start, make sure you have safetensors installed:
+
+```py
+# uncomment to install the necessary libraries in Colab
+#!pip install safetensors
+```
+
+If you look at the [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main) repository, you'll see weights inside the `text_encoder`, `unet` and `vae` subfolders are stored in the `.safetensors` format. By default, 🤗 Diffusers automatically loads these `.safetensors` files from their subfolders if they're available in the model repository.
+
+For more explicit control, you can optionally set `use_safetensors=True` (if `safetensors` is not installed, you'll get an error message asking you to install it):
+
+```py
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_safetensors=True)
+```
+
+However, model weights are not necessarily stored in separate subfolders like in the example above. Sometimes, all the weights are stored in a single `.safetensors` file. In this case, if the weights are Stable Diffusion weights, you can load the file directly with the [`~diffusers.loaders.FromSingleFileMixin.from_single_file`] method:
+
+```py
+from diffusers import StableDiffusionPipeline
+
+pipeline = StableDiffusionPipeline.from_single_file(
+ "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
+)
+```
+
+## Convert to safetensors
+
+Not all weights on the Hub are available in the `.safetensors` format, and you may encounter weights stored as `.bin`. In this case, use the [Convert Space](https://huggingface.co/spaces/diffusers/convert) to convert the weights to `.safetensors`. The Convert Space downloads the pickled weights, converts them, and opens a Pull Request to upload the newly converted `.safetensors` file on the Hub. This way, if there is any malicious code contained in the pickled files, they're uploaded to the Hub - which has a [security scanner](https://huggingface.co/docs/hub/security-pickle#hubs-security-scanner) to detect unsafe files and suspicious pickle imports - instead of your computer.
+
+You can use the model with the new `.safetensors` weights by specifying the reference to the Pull Request in the `revision` parameter (you can also test it in this [Check PR](https://huggingface.co/spaces/diffusers/check_pr) Space on the Hub), for example `refs/pr/22`:
+
+```py
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-2-1", revision="refs/pr/22", use_safetensors=True
+)
+```
+
+## Why use safetensors?
+
+There are several reasons for using safetensors:
+
+- Safety is the number one reason for using safetensors. As open-source and model distribution grows, it is important to be able to trust the model weights you downloaded don't contain any malicious code. The current size of the header in safetensors prevents parsing extremely large JSON files.
+- Loading speed when switching between models is another reason to use safetensors, which performs zero-copy loading of the tensors. It is especially fast compared to `pickle` if you're loading the weights to CPU (the default case), and just as fast if not faster when loading the weights directly to the GPU. You'll only notice the performance difference if the model is already cached locally, and not if you're downloading the weights or loading the model for the first time.
+
+ The time it takes to load the entire pipeline:
+
+ ```py
+ from diffusers import StableDiffusionPipeline
+
+ pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", use_safetensors=True)
+ "Loaded in safetensors 0:00:02.033658"
+ "Loaded in PyTorch 0:00:02.663379"
+ ```
+
+ But the actual time it takes to load 500MB of the model weights is only:
+
+ ```bash
+ safetensors: 3.4873ms
+ PyTorch: 172.7537ms
+ ```
+
+- Lazy loading is also supported in safetensors, which is useful in distributed settings to only load some of the tensors. This format allowed the [BLOOM](https://huggingface.co/bigscience/bloom) model to be loaded in 45 seconds on 8 GPUs instead of 10 minutes with regular PyTorch weights.
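+
+For example, safetensors' `safe_open` lets you inspect a file and read individual tensors without loading everything into memory. This is a minimal sketch; `"model.safetensors"` stands in for any local safetensors file:
+
+```py
+from safetensors import safe_open
+
+with safe_open("model.safetensors", framework="pt", device="cpu") as f:
+    print(f.keys())  # tensor names, read from the small header only
+    first_key = next(iter(f.keys()))
+    tensor = f.get_tensor(first_key)  # only this tensor is read from disk
+```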
diff --git a/diffusers/docs/source/en/using-diffusers/weighted_prompts.md b/diffusers/docs/source/en/using-diffusers/weighted_prompts.md
new file mode 100644
index 0000000000000000000000000000000000000000..947d18b86ec8b31c4472adda0330f3112b637bbc
--- /dev/null
+++ b/diffusers/docs/source/en/using-diffusers/weighted_prompts.md
@@ -0,0 +1,271 @@
+
+
+# Prompt weighting
+
+[[open-in-colab]]
+
+Prompt weighting provides a way to emphasize or de-emphasize certain parts of a prompt, allowing for more control over the generated image. A prompt can include several concepts, which get turned into contextualized text embeddings. The embeddings are used by the model to condition its cross-attention layers to generate an image (read the Stable Diffusion [blog post](https://huggingface.co/blog/stable_diffusion) to learn more about how it works).
+
+Prompt weighting works by increasing or decreasing the scale of the text embedding vector that corresponds to its concept in the prompt because you may not necessarily want the model to focus on all concepts equally. The easiest way to prepare the prompt-weighted embeddings is to use [Compel](https://github.com/damian0815/compel), a text prompt-weighting and blending library. Once you have the prompt-weighted embeddings, you can pass them to any pipeline that has a [`prompt_embeds`](https://huggingface.co/docs/diffusers/en/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline.__call__.prompt_embeds) (and optionally [`negative_prompt_embeds`](https://huggingface.co/docs/diffusers/en/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline.__call__.negative_prompt_embeds)) parameter, such as [`StableDiffusionPipeline`], [`StableDiffusionControlNetPipeline`], and [`StableDiffusionXLPipeline`].
+
+
+
+If your favorite pipeline doesn't have a `prompt_embeds` parameter, please open an [issue](https://github.com/huggingface/diffusers/issues/new/choose) so we can add it!
+
+
+
+This guide will show you how to weight and blend your prompts with Compel in 🤗 Diffusers.
+
+Before you begin, make sure you have the latest version of Compel installed:
+
+```py
+# uncomment to install in Colab
+#!pip install compel --upgrade
+```
+
+For this guide, let's generate an image with the prompt `"a red cat playing with a ball"` using the [`StableDiffusionPipeline`]:
+
+```py
+from diffusers import StableDiffusionPipeline, UniPCMultistepScheduler
+import torch
+
+pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_safetensors=True)
+pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+pipe.to("cuda")
+
+prompt = "a red cat playing with a ball"
+
+generator = torch.Generator(device="cpu").manual_seed(33)
+
+image = pipe(prompt, generator=generator, num_inference_steps=20).images[0]
+image
+```
+
+
+
+
+
+## Weighting
+
+You'll notice there is no "ball" in the image! Let's use compel to upweight the concept of "ball" in the prompt. Create a [`Compel`](https://github.com/damian0815/compel/blob/main/doc/compel.md#compel-objects) object, and pass it a tokenizer and text encoder:
+
+```py
+from compel import Compel
+
+compel_proc = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder)
+```
+
+Compel uses `+` or `-` to increase or decrease the weight of a word in the prompt. To increase the weight of "ball":
+
+
+
+`+` corresponds to the value `1.1`, `++` corresponds to `1.1^2`, and so on. Similarly, `-` corresponds to `0.9` and `--` corresponds to `0.9^2`. Feel free to experiment with adding more `+` or `-` in your prompt!
+
+
+
+```py
+prompt = "a red cat playing with a ball++"
+```
+
+Pass the prompt to `compel_proc` to create the new prompt embeddings which are passed to the pipeline:
+
+```py
+prompt_embeds = compel_proc(prompt)
+generator = torch.manual_seed(33)
+
+image = pipe(prompt_embeds=prompt_embeds, generator=generator, num_inference_steps=20).images[0]
+image
+```
+
+
+
+
+
+To downweight parts of the prompt, use the `-` suffix:
+
+```py
+prompt = "a red------- cat playing with a ball"
+prompt_embeds = compel_proc(prompt)
+
+generator = torch.manual_seed(33)
+
+image = pipe(prompt_embeds=prompt_embeds, generator=generator, num_inference_steps=20).images[0]
+image
+```
+
+
+
+
+
+You can even up or downweight multiple concepts in the same prompt:
+
+```py
+prompt = "a red cat++ playing with a ball----"
+prompt_embeds = compel_proc(prompt)
+
+generator = torch.manual_seed(33)
+
+image = pipe(prompt_embeds=prompt_embeds, generator=generator, num_inference_steps=20).images[0]
+image
+```
+
+
+
+
+
+## Blending
+
+You can also create a weighted *blend* of prompts by adding `.blend()` to a list of prompts and passing it some weights. Your blend may not always produce the result you expect because it breaks some assumptions about how the text encoder functions, so just have fun and experiment with it!
+
+```py
+prompt_embeds = compel_proc('("a red cat playing with a ball", "jungle").blend(0.7, 0.8)')
+generator = torch.Generator(device="cuda").manual_seed(33)
+
+image = pipe(prompt_embeds=prompt_embeds, generator=generator, num_inference_steps=20).images[0]
+image
+```
+
+
+
+
+
+## Conjunction
+
+A conjunction diffuses each prompt independently and concatenates their results by their weighted sum. Add `.and()` to the end of a list of prompts to create a conjunction:
+
+```py
+prompt_embeds = compel_proc('["a red cat", "playing with a", "ball"].and()')
+generator = torch.Generator(device="cuda").manual_seed(55)
+
+image = pipe(prompt_embeds=prompt_embeds, generator=generator, num_inference_steps=20).images[0]
+image
+```
+
+
+
+
+
+## Textual inversion
+
+[Textual inversion](../training/text_inversion) is a technique for learning a specific concept from some images which you can use to generate new images conditioned on that concept.
+
+Create a pipeline and use the [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] function to load the textual inversion embeddings (feel free to browse the [Stable Diffusion Conceptualizer](https://huggingface.co/spaces/sd-concepts-library/stable-diffusion-conceptualizer) for 100+ trained concepts):
+
+```py
+import torch
+from diffusers import StableDiffusionPipeline
+from compel import Compel, DiffusersTextualInversionManager
+
+pipe = StableDiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16,
+ use_safetensors=True, variant="fp16").to("cuda")
+pipe.load_textual_inversion("sd-concepts-library/midjourney-style")
+```
+
+Compel provides a `DiffusersTextualInversionManager` class to simplify prompt weighting with textual inversion. Instantiate `DiffusersTextualInversionManager` and pass it to the `Compel` class:
+
+```py
+textual_inversion_manager = DiffusersTextualInversionManager(pipe)
+compel_proc = Compel(
+ tokenizer=pipe.tokenizer,
+ text_encoder=pipe.text_encoder,
+ textual_inversion_manager=textual_inversion_manager)
+```
+
+Incorporate the concept into the prompt to condition it, using the `<concept>` syntax:
+
+```py
+prompt_embeds = compel_proc('("A red cat++ playing with a ball <midjourney-style>")')
+
+image = pipe(prompt_embeds=prompt_embeds).images[0]
+image
+```
+
+
+
+
+
+## DreamBooth
+
+[DreamBooth](../training/dreambooth) is a technique for generating contextualized images of a subject given just a few images of the subject to train on. It is similar to textual inversion, but DreamBooth trains the full model whereas textual inversion only fine-tunes the text embeddings. This means you should use [`~DiffusionPipeline.from_pretrained`] to load the DreamBooth model (feel free to browse the [Stable Diffusion Dreambooth Concepts Library](https://huggingface.co/sd-dreambooth-library) for 100+ trained models):
+
+```py
+import torch
+from diffusers import DiffusionPipeline, UniPCMultistepScheduler
+from compel import Compel
+
+pipe = DiffusionPipeline.from_pretrained("sd-dreambooth-library/dndcoverart-v1", torch_dtype=torch.float16).to("cuda")
+pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+```
+
+Create a `Compel` class with a tokenizer and text encoder, and pass your prompt to it. Depending on the model you use, you'll need to incorporate the model's unique identifier into your prompt. For example, the `dndcoverart-v1` model uses the identifier `dndcoverart`:
+
+```py
+compel_proc = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder)
+prompt_embeds = compel_proc('("magazine cover of a dndcoverart dragon, high quality, intricate details, larry elmore art style").and()')
+image = pipe(prompt_embeds=prompt_embeds).images[0]
+image
+```
+
+
+
+
+
+## Stable Diffusion XL
+
+Stable Diffusion XL (SDXL) has two tokenizers and text encoders so its usage is a bit different. To address this, you should pass both tokenizers and encoders to the `Compel` class:
+
+```py
+from compel import Compel, ReturnedEmbeddingsType
+from diffusers import DiffusionPipeline
+from diffusers.utils import make_image_grid
+import torch
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ variant="fp16",
+ use_safetensors=True,
+ torch_dtype=torch.float16
+).to("cuda")
+
+compel = Compel(
+ tokenizer=[pipeline.tokenizer, pipeline.tokenizer_2] ,
+ text_encoder=[pipeline.text_encoder, pipeline.text_encoder_2],
+ returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
+ requires_pooled=[False, True]
+)
+```
+
+This time, let's upweight "ball" by a factor of 1.5 for the first prompt, and downweight "ball" by 0.6 for the second prompt. The [`StableDiffusionXLPipeline`] also requires [`pooled_prompt_embeds`](https://huggingface.co/docs/diffusers/en/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLInpaintPipeline.__call__.pooled_prompt_embeds) (and optionally [`negative_pooled_prompt_embeds`](https://huggingface.co/docs/diffusers/en/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLInpaintPipeline.__call__.negative_pooled_prompt_embeds)) so you should pass those to the pipeline along with the conditioning tensors:
+
+```py
+# apply weights
+prompt = ["a red cat playing with a (ball)1.5", "a red cat playing with a (ball)0.6"]
+conditioning, pooled = compel(prompt)
+
+# generate image
+generator = [torch.Generator().manual_seed(33) for _ in range(len(prompt))]
+images = pipeline(prompt_embeds=conditioning, pooled_prompt_embeds=pooled, generator=generator, num_inference_steps=30).images
+make_image_grid(images, rows=1, cols=2)
+```
+
+(figure: results for "a red cat playing with a (ball)1.5" and "a red cat playing with a (ball)0.6")
+
diff --git a/diffusers/docs/source/en/using-diffusers/write_own_pipeline.md b/diffusers/docs/source/en/using-diffusers/write_own_pipeline.md
new file mode 100644
index 0000000000000000000000000000000000000000..4ca3fe33223bd30bf2ca8ed401083efa5a7f3c5a
--- /dev/null
+++ b/diffusers/docs/source/en/using-diffusers/write_own_pipeline.md
@@ -0,0 +1,294 @@
+
+
+# Understanding pipelines, models and schedulers
+
+[[open-in-colab]]
+
+🧨 Diffusers is designed to be a user-friendly and flexible toolbox for building diffusion systems tailored to your use-case. At the core of the toolbox are models and schedulers. While the [`DiffusionPipeline`] bundles these components together for convenience, you can also unbundle the pipeline and use the models and schedulers separately to create new diffusion systems.
+
+In this tutorial, you'll learn how to use models and schedulers to assemble a diffusion system for inference, starting with a basic pipeline and then progressing to the Stable Diffusion pipeline.
+
+## Deconstruct a basic pipeline
+
+A pipeline is a quick and easy way to run a model for inference, requiring no more than four lines of code to generate an image:
+
+```py
+>>> from diffusers import DDPMPipeline
+
+>>> ddpm = DDPMPipeline.from_pretrained("google/ddpm-cat-256", use_safetensors=True).to("cuda")
+>>> image = ddpm(num_inference_steps=25).images[0]
+>>> image
+```
+
+
+
+
+
+That was super easy, but how did the pipeline do that? Let's break down the pipeline and take a look at what's happening under the hood.
+
+In the example above, the pipeline contains a [`UNet2DModel`] model and a [`DDPMScheduler`]. The pipeline denoises an image by taking random noise the size of the desired output and passing it through the model several times. At each timestep, the model predicts the *noise residual* and the scheduler uses it to predict a less noisy image. The pipeline repeats this process until it reaches the end of the specified number of inference steps.
+
+To recreate the pipeline with the model and scheduler separately, let's write our own denoising process.
+
+1. Load the model and scheduler:
+
+```py
+>>> from diffusers import DDPMScheduler, UNet2DModel
+
+>>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cat-256")
+>>> model = UNet2DModel.from_pretrained("google/ddpm-cat-256", use_safetensors=True).to("cuda")
+```
+
+2. Set the number of timesteps to run the denoising process for:
+
+```py
+>>> scheduler.set_timesteps(50)
+```
+
+3. Setting the scheduler timesteps creates a tensor with evenly spaced elements in it, 50 in this example. Each element corresponds to a timestep at which the model denoises an image. When you create the denoising loop later, you'll iterate over this tensor to denoise an image:
+
+```py
+>>> scheduler.timesteps
+tensor([980, 960, 940, 920, 900, 880, 860, 840, 820, 800, 780, 760, 740, 720,
+ 700, 680, 660, 640, 620, 600, 580, 560, 540, 520, 500, 480, 460, 440,
+ 420, 400, 380, 360, 340, 320, 300, 280, 260, 240, 220, 200, 180, 160,
+ 140, 120, 100, 80, 60, 40, 20, 0])
+```
+
+4. Create some random noise with the same shape as the desired output:
+
+```py
+>>> import torch
+
+>>> sample_size = model.config.sample_size
+>>> noise = torch.randn((1, 3, sample_size, sample_size), device="cuda")
+```
+
+5. Now write a loop to iterate over the timesteps. At each timestep, the model does a [`UNet2DModel.forward`] pass and returns the noisy residual. The scheduler's [`~DDPMScheduler.step`] method takes the noisy residual, timestep, and input and it predicts the image at the previous timestep. This output becomes the next input to the model in the denoising loop, and it'll repeat until it reaches the end of the `timesteps` array.
+
+```py
+>>> input = noise
+
+>>> for t in scheduler.timesteps:
+... with torch.no_grad():
+... noisy_residual = model(input, t).sample
+... previous_noisy_sample = scheduler.step(noisy_residual, t, input).prev_sample
+... input = previous_noisy_sample
+```
+
+This is the entire denoising process, and you can use this same pattern to write any diffusion system.
+
+6. The last step is to convert the denoised output into an image:
+
+```py
+>>> from PIL import Image
+>>> import numpy as np
+
+>>> image = (input / 2 + 0.5).clamp(0, 1).squeeze()
+>>> image = (image.permute(1, 2, 0) * 255).round().to(torch.uint8).cpu().numpy()
+>>> image = Image.fromarray(image)
+>>> image
+```
+
+In the next section, you'll put your skills to the test and break down the more complex Stable Diffusion pipeline. The steps are more or less the same: you'll initialize the necessary components and set the number of timesteps to create a `timesteps` array. The `timesteps` array is used in the denoising loop, and for each element in this array, the model predicts a less noisy image. The denoising loop iterates over the timesteps, and at each timestep, it outputs a noisy residual that the scheduler uses to predict a less noisy image at the previous timestep. This process is repeated until you reach the end of the `timesteps` array.
+
+Let's try it out!
+
+## Deconstruct the Stable Diffusion pipeline
+
+Stable Diffusion is a text-to-image *latent diffusion* model. It is called a latent diffusion model because it works with a lower-dimensional representation of the image instead of the actual pixel space, which makes it more memory efficient. The encoder compresses the image into a smaller representation, and a decoder converts the compressed representation back into an image. For text-to-image models, you'll need a tokenizer and a text encoder to generate text embeddings. From the previous example, you already know you need a UNet model and a scheduler.
+
+As you can see, this is already more complex than the DDPM pipeline which only contains a UNet model. The Stable Diffusion model has three separate pretrained models.
+
+
+
+💡 Read the [How does Stable Diffusion work?](https://huggingface.co/blog/stable_diffusion#how-does-stable-diffusion-work) blog for more details about how the VAE, UNet, and text encoder models work.
+
+
+
+Now that you know what you need for the Stable Diffusion pipeline, load all these components with the [`~ModelMixin.from_pretrained`] method. You can find them in the pretrained [`CompVis/stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4) checkpoint, and each component is stored in a separate subfolder:
+
+```py
+>>> from PIL import Image
+>>> import torch
+>>> from transformers import CLIPTextModel, CLIPTokenizer
+>>> from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
+
+>>> vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae", use_safetensors=True)
+>>> tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")
+>>> text_encoder = CLIPTextModel.from_pretrained(
+... "CompVis/stable-diffusion-v1-4", subfolder="text_encoder", use_safetensors=True
+... )
+>>> unet = UNet2DConditionModel.from_pretrained(
+... "CompVis/stable-diffusion-v1-4", subfolder="unet", use_safetensors=True
+... )
+```
+
+Instead of the default [`PNDMScheduler`], exchange it for the [`UniPCMultistepScheduler`] to see how easy it is to plug a different scheduler in:
+
+```py
+>>> from diffusers import UniPCMultistepScheduler
+
+>>> scheduler = UniPCMultistepScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
+```
+
+To speed up inference, move the models to a GPU since, unlike the scheduler, they have trainable weights:
+
+```py
+>>> torch_device = "cuda"
+>>> vae.to(torch_device)
+>>> text_encoder.to(torch_device)
+>>> unet.to(torch_device)
+```
+
+### Create text embeddings
+
+The next step is to tokenize the text to generate embeddings. The text is used to condition the UNet model and steer the diffusion process towards something that resembles the input prompt.
+
+
+
+💡 The `guidance_scale` parameter determines how much weight should be given to the prompt when generating an image.
+
+
+
+Feel free to choose any prompt you like if you want to generate something else!
+
+```py
+>>> prompt = ["a photograph of an astronaut riding a horse"]
+>>> height = 512 # default height of Stable Diffusion
+>>> width = 512 # default width of Stable Diffusion
+>>> num_inference_steps = 25 # Number of denoising steps
+>>> guidance_scale = 7.5 # Scale for classifier-free guidance
+>>> generator = torch.manual_seed(0) # Seed generator to create the initial latent noise
+>>> batch_size = len(prompt)
+```
+
+Tokenize the text and generate the embeddings from the prompt:
+
+```py
+>>> text_input = tokenizer(
+... prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
+... )
+
+>>> with torch.no_grad():
+... text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
+```
+
+You'll also need to generate the *unconditional text embeddings* which are the embeddings for the padding token. These need to have the same shape (`batch_size` and `seq_length`) as the conditional `text_embeddings`:
+
+```py
+>>> max_length = text_input.input_ids.shape[-1]
+>>> uncond_input = tokenizer([""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt")
+>>> uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
+```
+
+Let's concatenate the conditional and unconditional embeddings into a batch to avoid doing two forward passes:
+
+```py
+>>> text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+```
+
+### Create random noise
+
+Next, generate some initial random noise as a starting point for the diffusion process. This is the latent representation of the image, and it'll be gradually denoised. At this point, the latent image is smaller than the final image size, but that's okay because the model will transform it into the final 512x512 image dimensions later.
+
+
+
+💡 The height and width are divided by 8 because the `vae` model has 3 down-sampling layers. You can check by running the following:
+
+```py
+2 ** (len(vae.config.block_out_channels) - 1) == 8
+```
+
+
+
+```py
+>>> latents = torch.randn(
+... (batch_size, unet.config.in_channels, height // 8, width // 8),
+... generator=generator,
+... device=torch_device,
+... )
+```
+
+### Denoise the image
+
+Start by scaling the input with the initial noise distribution, *sigma*, the noise scale value, which is required for improved schedulers like [`UniPCMultistepScheduler`]:
+
+```py
+>>> latents = latents * scheduler.init_noise_sigma
+```
+
+The last step is to create the denoising loop that'll progressively transform the pure noise in `latents` to an image described by your prompt. Remember, the denoising loop needs to do three things:
+
+1. Set the scheduler's timesteps to use during denoising.
+2. Iterate over the timesteps.
+3. At each timestep, call the UNet model to predict the noise residual and pass it to the scheduler to compute the previous noisy sample.
+
+```py
+>>> from tqdm.auto import tqdm
+
+>>> scheduler.set_timesteps(num_inference_steps)
+
+>>> for t in tqdm(scheduler.timesteps):
+... # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
+... latent_model_input = torch.cat([latents] * 2)
+
+... latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)
+
+... # predict the noise residual
+... with torch.no_grad():
+... noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+... # perform guidance
+... noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+... noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+... # compute the previous noisy sample x_t -> x_t-1
+... latents = scheduler.step(noise_pred, t, latents).prev_sample
+```
+
+### Decode the image
+
+The final step is to use the `vae` to decode the latent representation into an image and get the decoded output with `sample`:
+
+```py
+# scale and decode the image latents with vae
+latents = 1 / 0.18215 * latents
+with torch.no_grad():
+ image = vae.decode(latents).sample
+```
+
+Lastly, convert the image to a `PIL.Image` to see your generated image!
+
+```py
+>>> image = (image / 2 + 0.5).clamp(0, 1).squeeze()
+>>> image = (image.permute(1, 2, 0) * 255).round().to(torch.uint8).cpu().numpy()
+>>> image = Image.fromarray(image)
+>>> image
+```
+
+
+
+
+
+## Next steps
+
+From basic to complex pipelines, you've seen that all you really need to write your own diffusion system is a denoising loop. The loop should set the scheduler's timesteps, iterate over them, and alternate between calling the UNet model to predict the noise residual and passing it to the scheduler to compute the previous noisy sample.
+
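+To make the pattern concrete, here is a minimal sketch of that loop wrapped in a reusable helper. It assumes an unconditional model such as [`UNet2DModel`] whose forward pass only takes the sample and the timestep; for a conditional UNet you would also pass the text embeddings and apply classifier-free guidance inside the loop, as shown above:
+
+```py
+import torch
+
+
+def denoise(model, scheduler, sample, num_inference_steps=50):
+    """Minimal denoising loop: set the timesteps, predict the noise residual, step the scheduler."""
+    scheduler.set_timesteps(num_inference_steps)
+    for t in scheduler.timesteps:
+        with torch.no_grad():
+            noise_pred = model(sample, t).sample  # the model predicts the noise residual
+        sample = scheduler.step(noise_pred, t, sample).prev_sample  # compute the previous, less noisy sample
+    return sample
+```
+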
+This is really what 🧨 Diffusers is designed for: to make it intuitive and easy to write your own diffusion system using models and schedulers.
+
+For your next steps, feel free to:
+
+* Learn how to [build and contribute a pipeline](../using-diffusers/contribute_pipeline) to 🧨 Diffusers. We can't wait to see what you'll come up with!
+* Explore [existing pipelines](../api/pipelines/overview) in the library, and see if you can deconstruct and build a pipeline from scratch using the models and schedulers separately.
diff --git a/diffusers/docs/source/ja/_toctree.yml b/diffusers/docs/source/ja/_toctree.yml
new file mode 100644
index 0000000000000000000000000000000000000000..7af1f9f2b28dc46306c7308693edfa1c74a0c35f
--- /dev/null
+++ b/diffusers/docs/source/ja/_toctree.yml
@@ -0,0 +1,10 @@
+- sections:
+ - local: index
+ title: 🧨 Diffusers
+ - local: quicktour
+ title: 簡単な案内
+ - local: stable_diffusion
+ title: 効果的で効率的な拡散モデル
+ - local: installation
+ title: インストール
+ title: はじめに
\ No newline at end of file
diff --git a/diffusers/docs/source/ja/index.md b/diffusers/docs/source/ja/index.md
new file mode 100644
index 0000000000000000000000000000000000000000..6e8ba78dd55f015b35e6e4b486a0755a17961000
--- /dev/null
+++ b/diffusers/docs/source/ja/index.md
@@ -0,0 +1,98 @@
+
+
+
+
+### より良いプロンプト・エンジニアリング
+
+画像を生成するために使用する文章は、*プロンプトエンジニアリング*と呼ばれる分野が生まれるほど、非常に重要です。プロンプト・エンジニアリングで考慮すべき点は以下の通りです:
+
+- 生成したい画像やその類似画像は、インターネット上にどのように保存されているか?
+- 私が望むスタイルにモデルを誘導するために、どのような追加詳細を与えるべきか?
+
+このことを念頭に置いて、プロンプトに色やより質の高いディテールを含めるように改良してみましょう:
+
+```python
+prompt += ", tribal panther make up, blue on red, side profile, looking away, serious eyes"
+prompt += " 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta"
+```
+
+新しいプロンプトで画像のバッチを生成しましょう:
+
+```python
+images = pipeline(**get_inputs(batch_size=8)).images
+make_image_grid(images, rows=2, cols=4)
+```
+
+
+
+
+
+かなりいいです!シードが`1`の`Generator`に対応する2番目の画像に、被写体の年齢に関するテキストを追加して、もう少し手を加えてみましょう:
+
+```python
+prompts = [
+ "portrait photo of the oldest warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
+ "portrait photo of a old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
+ "portrait photo of a warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
+ "portrait photo of a young warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
+]
+
+generator = [torch.Generator("cuda").manual_seed(1) for _ in range(len(prompts))]
+images = pipeline(prompt=prompts, generator=generator, num_inference_steps=25).images
+make_image_grid(images, 2, 2)
+```
+
+
+
+
+
+## 次のステップ
+
+このチュートリアルでは、[`DiffusionPipeline`]を最適化して計算効率とメモリ効率を向上させ、生成される出力の品質を向上させる方法を学びました。パイプラインをさらに高速化することに興味があれば、以下のリソースを参照してください:
+
+- [PyTorch 2.0](./optimization/torch2.0)と[`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html)がどのように生成速度を5-300%高速化できるかを学んでください。A100 GPUの場合、画像生成は最大50%速くなります!(このリストの下に簡単な例を示します)
+- PyTorch 2が使えない場合は、[xFormers](./optimization/xformers)をインストールすることをお勧めします。このライブラリのメモリ効率の良いアテンションメカニズムは PyTorch 1.13.1 と相性が良く、高速化とメモリ消費量の削減を同時に実現します。
+- モデルのオフロードなど、その他の最適化テクニックは[こちらのガイド](./optimization/fp16)で紹介されています。
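+
+参考までに、上で使った `pipeline` に `torch.compile` を適用する場合の最小限のスケッチを以下に示します(PyTorch 2.0 以降がインストールされていることを前提とした例です):
+
+```python
+import torch
+
+# PyTorch 2.0 以降が前提: UNet をコンパイルして、2 回目以降の推論を高速化します
+pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
+
+# 以降の呼び出しは通常どおり(初回呼び出しはコンパイルのため時間がかかります)
+image = pipeline(prompt).images[0]
+```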
diff --git a/diffusers/docs/source/ko/_toctree.yml b/diffusers/docs/source/ko/_toctree.yml
new file mode 100644
index 0000000000000000000000000000000000000000..c63fe3d9718d95a983aeb5752aa686fe3b826d4a
--- /dev/null
+++ b/diffusers/docs/source/ko/_toctree.yml
@@ -0,0 +1,132 @@
+- sections:
+ - local: index
+ title: "🧨 Diffusers"
+ - local: quicktour
+ title: "훑어보기"
+ - local: stable_diffusion
+ title: Stable Diffusion
+ - local: installation
+ title: "설치"
+ title: "시작하기"
+- sections:
+ - local: tutorials/tutorial_overview
+ title: 개요
+ - local: using-diffusers/write_own_pipeline
+ title: 모델과 스케줄러 이해하기
+ - local: in_translation
+ title: AutoPipeline
+ - local: tutorials/basic_training
+ title: Diffusion 모델 학습하기
+ title: Tutorials
+- sections:
+ - sections:
+ - local: using-diffusers/loading_overview
+ title: 개요
+ - local: using-diffusers/loading
+ title: 파이프라인, 모델, 스케줄러 불러오기
+ - local: using-diffusers/schedulers
+ title: 다른 스케줄러들을 가져오고 비교하기
+ - local: using-diffusers/custom_pipeline_overview
+ title: 커뮤니티 파이프라인 불러오기
+ - local: using-diffusers/using_safetensors
+ title: 세이프텐서 불러오기
+ - local: using-diffusers/other-formats
+ title: 다른 형식의 Stable Diffusion 불러오기
+ - local: in_translation
+ title: Hub에 파일 push하기
+ title: 불러오기 & 허브
+ - sections:
+ - local: using-diffusers/pipeline_overview
+ title: 개요
+ - local: using-diffusers/unconditional_image_generation
+ title: Unconditional 이미지 생성
+ - local: using-diffusers/conditional_image_generation
+ title: Text-to-image 생성
+ - local: using-diffusers/img2img
+ title: Text-guided image-to-image
+ - local: using-diffusers/inpaint
+ title: Text-guided 이미지 인페인팅
+ - local: using-diffusers/depth2img
+ title: Text-guided depth-to-image
+ - local: using-diffusers/textual_inversion_inference
+ title: Textual inversion
+ - local: training/distributed_inference
+ title: 여러 GPU를 사용한 분산 추론
+ - local: in_translation
+ title: Distilled Stable Diffusion 추론
+ - local: using-diffusers/reusing_seeds
+ title: Deterministic 생성으로 이미지 퀄리티 높이기
+ - local: using-diffusers/control_brightness
+ title: 이미지 밝기 조정하기
+ - local: using-diffusers/reproducibility
+ title: 재현 가능한 파이프라인 생성하기
+ - local: using-diffusers/custom_pipeline_examples
+ title: 커뮤니티 파이프라인들
+ - local: using-diffusers/contribute_pipeline
+      title: 커뮤니티 파이프라인에 기여하는 방법
+ - local: using-diffusers/stable_diffusion_jax_how_to
+ title: JAX/Flax에서의 Stable Diffusion
+ - local: using-diffusers/weighted_prompts
+ title: Weighting Prompts
+ title: 추론을 위한 파이프라인
+ - sections:
+ - local: training/overview
+ title: 개요
+ - local: training/create_dataset
+ title: 학습을 위한 데이터셋 생성하기
+ - local: training/adapt_a_model
+ title: 새로운 태스크에 모델 적용하기
+ - local: training/unconditional_training
+ title: Unconditional 이미지 생성
+ - local: training/text_inversion
+ title: Textual Inversion
+ - local: training/dreambooth
+ title: DreamBooth
+ - local: training/text2image
+ title: Text-to-image
+ - local: training/lora
+ title: Low-Rank Adaptation of Large Language Models (LoRA)
+ - local: training/controlnet
+ title: ControlNet
+ - local: training/instructpix2pix
+ title: InstructPix2Pix 학습
+ - local: training/custom_diffusion
+ title: Custom Diffusion
+ title: Training
+ title: Diffusers 사용하기
+- sections:
+ - local: optimization/opt_overview
+ title: 개요
+ - local: optimization/fp16
+ title: 메모리와 속도
+ - local: optimization/torch2.0
+ title: Torch2.0 지원
+ - local: optimization/xformers
+ title: xFormers
+ - local: optimization/onnx
+ title: ONNX
+ - local: optimization/open_vino
+ title: OpenVINO
+ - local: optimization/coreml
+ title: Core ML
+ - local: optimization/mps
+ title: MPS
+ - local: optimization/habana
+ title: Habana Gaudi
+ - local: optimization/tome
+ title: Token Merging
+ title: 최적화/특수 하드웨어
+- sections:
+ - local: using-diffusers/controlling_generation
+ title: 제어된 생성
+ - local: in_translation
+ title: Diffusion Models 평가하기
+ title: 개념 가이드
+- sections:
+ - sections:
+ - sections:
+ - local: api/pipelines/stable_diffusion/stable_diffusion_xl
+ title: Stable Diffusion XL
+ title: Stable Diffusion
+ title: Pipelines
+ title: API
\ No newline at end of file
diff --git a/diffusers/docs/source/ko/api/pipelines/stable_diffusion/stable_diffusion_xl.md b/diffusers/docs/source/ko/api/pipelines/stable_diffusion/stable_diffusion_xl.md
new file mode 100644
index 0000000000000000000000000000000000000000..ab5a03ae81a0fc0f0da7b6105ccc3886f537b64c
--- /dev/null
+++ b/diffusers/docs/source/ko/api/pipelines/stable_diffusion/stable_diffusion_xl.md
@@ -0,0 +1,400 @@
+
+
+# Stable diffusion XL
+
+Stable Diffusion XL은 Dustin Podell, Zion English, Kyle Lacey, Andreas Blattmann, Tim Dockhorn, Jonas Müller, Joe Penna, Robin Rombach에 의해 [SDXL: Improving Latent Diffusion Models for High-Resolution Image Synthesis](https://arxiv.org/abs/2307.01952)에서 제안되었습니다.
+
+논문 초록은 다음과 같습니다:
+
+*text-to-image latent diffusion 모델인 SDXL을 소개합니다. 이전 버전의 Stable Diffusion과 비교하면, SDXL은 세 배 더 큰 UNet 백본을 사용합니다: 모델 파라미터가 늘어난 것은 주로 더 많은 attention 블록과, 두 번째 텍스트 인코더를 사용함에 따라 더 커진 cross-attention context 때문입니다. 여러 가지 새로운 conditioning 기법을 설계하고 다양한 종횡비(aspect ratio)로 SDXL을 학습시켰습니다. 또한 사후(post-hoc) image-to-image 기법을 통해 SDXL이 생성한 샘플의 시각적 품질을 개선하는 refinement 모델도 소개합니다. SDXL은 이전 버전의 Stable Diffusion보다 성능이 크게 향상되었으며, black-box 방식의 최신 이미지 생성 모델들과 견줄 만한 결과를 달성했습니다.*
+
+## 팁
+
+- Stable Diffusion XL은 특히 768에서 1024 사이 크기의 이미지에서 잘 작동합니다.
+- Stable Diffusion XL은 아래와 같이 학습된 각 텍스트 인코더에 대해 서로 다른 프롬프트를 전달할 수 있습니다. 동일한 프롬프트의 다른 부분을 텍스트 인코더에 전달할 수도 있습니다.
+- Stable Diffusion XL 결과 이미지는 아래에 보여지듯이 정제기(refiner)를 사용함으로써 향상될 수 있습니다.
+
+### 이용가능한 체크포인트:
+
+- *Text-to-Image (1024x1024 해상도)*: [`StableDiffusionXLPipeline`]을 사용한 [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+- *Image-to-Image / 정제기(refiner) (1024x1024 해상도)*: [`StableDiffusionXLImg2ImgPipeline`]를 사용한 [stabilityai/stable-diffusion-xl-refiner-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0)
+
+## 사용 예시
+
+SDXL을 사용하기 전에 `transformers`, `accelerate`, `safetensors` 와 `invisible_watermark`를 설치하세요.
+다음과 같이 라이브러리를 설치할 수 있습니다:
+
+```
+pip install transformers
+pip install accelerate
+pip install safetensors
+pip install invisible-watermark>=0.2.0
+```
+
+### 워터마커
+
+Stable Diffusion XL로 이미지를 생성할 때는 보이지 않는 워터마크를 추가하는 것을 권장합니다. 이는 다운스트림(downstream) 애플리케이션에서 이미지가 기계로 합성된 것인지 식별하는 데 도움이 될 수 있습니다. 그렇게 하려면 [invisible_watermark 라이브러리](https://pypi.org/project/invisible-watermark/)를 설치해주세요:
+
+
+```
+pip install invisible-watermark>=0.2.0
+```
+
+`invisible-watermark` 라이브러리가 설치되면 워터마커가 **기본적으로** 사용될 것입니다.
+
+생성 또는 안전하게 이미지를 배포하기 위해 다른 규정이 있다면, 다음과 같이 워터마커를 비활성화할 수 있습니다:
+
+```py
+pipe = StableDiffusionXLPipeline.from_pretrained(..., add_watermarker=False)
+```
+
+### Text-to-Image
+
+*text-to-image*를 위해 다음과 같이 SDXL을 사용할 수 있습니다:
+
+```py
+from diffusers import StableDiffusionXLPipeline
+import torch
+
+pipe = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+)
+pipe.to("cuda")
+
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+image = pipe(prompt=prompt).images[0]
+```
+
+### Image-to-image
+
+*image-to-image*를 위해 다음과 같이 SDXL을 사용할 수 있습니다:
+
+```py
+import torch
+from diffusers import StableDiffusionXLImg2ImgPipeline
+from diffusers.utils import load_image
+
+pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+)
+pipe = pipe.to("cuda")
+url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png"
+
+init_image = load_image(url).convert("RGB")
+prompt = "a photo of an astronaut riding a horse on mars"
+image = pipe(prompt, image=init_image).images[0]
+```
+
+### 인페인팅
+
+*inpainting*를 위해 다음과 같이 SDXL을 사용할 수 있습니다:
+
+```py
+import torch
+from diffusers import StableDiffusionXLInpaintPipeline
+from diffusers.utils import load_image
+
+pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+)
+pipe.to("cuda")
+
+img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+init_image = load_image(img_url).convert("RGB")
+mask_image = load_image(mask_url).convert("RGB")
+
+prompt = "A majestic tiger sitting on a bench"
+image = pipe(prompt=prompt, image=init_image, mask_image=mask_image, num_inference_steps=50, strength=0.80).images[0]
+```
+
+### 이미지 결과물을 정제하기
+
+[base 모델 체크포인트](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)에 더해, Stable Diffusion XL에는 낮은 노이즈 단계의 이미지를 디노이징하는 데 특화되어 고주파(high-frequency) 디테일이 향상된 이미지를 생성하는 [refiner 체크포인트](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0)도 포함되어 있습니다. 이 refiner 체크포인트는 base 체크포인트를 실행한 뒤 이미지 품질을 높이기 위한 "두 번째 단계" 파이프라인으로 사용할 수 있습니다.
+
+refiner는 두 가지 방식으로 사용할 수 있습니다:
+- 1.) base 모델과 refiner를 [eDiff-I](https://research.nvidia.com/labs/dir/eDiff-I/)에서 처음 제안된 방식과 같이 *Denoisers의 앙상블*로 함께 사용하거나,
+- 2.) base 모델을 거친 후 [SDEdit](https://arxiv.org/abs/2108.01073) 방식으로 단순하게 refiner를 실행할 수 있습니다.
+
+**참고**: SD-XL base와 refiner를 앙상블로 사용하는 아이디어는 커뮤니티 기여자들이 처음 제안한 것으로, 이들은 이 아이디어를 `diffusers`에 구현하는 데에도 도움을 주었습니다:
+- [SytanSD](https://github.com/SytanSD)
+- [bghira](https://github.com/bghira)
+- [Birch-san](https://github.com/Birch-san)
+- [AmericanPresidentJimmyCarter](https://github.com/AmericanPresidentJimmyCarter)
+
+#### 1.) Denoisers의 앙상블
+
+base와 refiner 모델을 denoiser의 앙상블로 사용할 때, base 모델은 높은 노이즈(high-noise) diffusion 단계를 담당하는 전문가 역할을 하고, refiner는 낮은 노이즈(low-noise) diffusion 단계를 담당하는 전문가 역할을 합니다.
+
+2.)에 비해 1.)의 장점은 전체적으로 필요한 노이즈 제거 단계 수가 적어 속도가 훨씬 빨라진다는 점입니다. 단점은 base 모델의 출력이 아직 노이즈가 많이 남아 있는 상태이므로 그 결과를 중간에 검사할 수 없다는 것입니다.
+
+base 모델과 refiner를 denoiser의 앙상블로 사용하려면, 각 모델이 노이즈를 제거할 타임스텝 구간, 즉 높은 노이즈 구간(*즉* base 모델)과 낮은 노이즈 구간(*즉* refiner 모델)을 정의해야 합니다.
+base 모델의 [`denoising_end`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLPipeline.__call__.denoising_end)와 refiner 모델의 [`denoising_start`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLImg2ImgPipeline.__call__.denoising_start)를 사용해 간격을 정합니다.
+
+`denoising_end`와 `denoising_start` 모두 0과 1사이의 실수 값으로 전달되어야 합니다.
+전달되면 노이즈 제거의 끝과 시작은 모델 스케줄에 의해 정의된 이산적(discrete) 시간 간격의 비율로 정의됩니다.
+노이즈 제거 단계의 수는 모델이 학습된 이산적인(discrete) 타임스텝과 여기서 지정한 fractional cutoff에 의해 결정되므로, `strength`를 함께 지정한 경우에는 이 값이 `strength`보다 우선합니다.
+
+예시를 들어보겠습니다.
+우선, 두 개의 파이프라인을 가져옵니다. 텍스트 인코더와 variational autoencoder는 동일하므로 refiner를 위해 다시 불러오지 않아도 됩니다.
+
+```py
+from diffusers import DiffusionPipeline
+import torch
+
+base = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+)
+base.to("cuda")
+
+refiner = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-refiner-1.0",
+ text_encoder_2=base.text_encoder_2,
+ vae=base.vae,
+ torch_dtype=torch.float16,
+ use_safetensors=True,
+ variant="fp16",
+)
+refiner.to("cuda")
+```
+
+이제 추론 단계의 수와 고노이즈에서 노이즈를 제거하는 단계(*즉* base 모델)를 거쳐 실행되는 지점을 정의합니다.
+
+```py
+n_steps = 40
+high_noise_frac = 0.8
+```
+
+Stable Diffusion XL base 모델은 타임스텝 0-999에서 학습되었고, Stable Diffusion XL refiner는 이 base 모델로부터 낮은 노이즈 타임스텝 0-199에서 파인튜닝되었습니다. 따라서 처음 800 타임스텝(높은 노이즈)에는 base 모델을 사용하고, 마지막 200 타임스텝(낮은 노이즈)에는 refiner를 사용합니다. 즉, `high_noise_frac`를 0.8로 설정하면 999-200 스텝(노이즈 제거 타임스텝의 처음 80%)은 base 모델이 수행하고, 199-0 스텝(노이즈 제거 타임스텝의 마지막 20%)은 refiner 모델이 수행합니다.
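+
+참고로, `num_inference_steps`와 `high_noise_frac` 값에 따라 base와 refiner가 각각 몇 단계를 수행하게 되는지는 다음과 같이 간단히 계산해볼 수 있습니다(파이프라인 내부 구현이 아니라 비율을 이해하기 위한 스케치입니다):
+
+```py
+n_steps = 40
+high_noise_frac = 0.8
+
+base_steps = int(n_steps * high_noise_frac)  # 처음 80% (높은 노이즈) -> base 모델
+refiner_steps = n_steps - base_steps         # 마지막 20% (낮은 노이즈) -> refiner 모델
+
+print(f"base: {base_steps} steps, refiner: {refiner_steps} steps")  # base: 32 steps, refiner: 8 steps
+```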
+
+기억하세요, 노이즈 제거 절차는 **높은 값**(높은 노이즈) 타임스텝에서 시작되고, **낮은 값** (낮은 노이즈) 타임스텝에서 끝납니다.
+
+이제 두 파이프라인을 실행해봅시다. `denoising_end`과 `denoising_start`를 같은 값으로 설정하고 `num_inference_steps`는 상수로 유지합니다. 또한 base 모델의 출력은 잠재 공간에 있어야 한다는 점을 기억하세요:
+
+```py
+prompt = "A majestic lion jumping from a big stone at night"
+
+image = base(
+ prompt=prompt,
+ num_inference_steps=n_steps,
+ denoising_end=high_noise_frac,
+ output_type="latent",
+).images
+image = refiner(
+ prompt=prompt,
+ num_inference_steps=n_steps,
+ denoising_start=high_noise_frac,
+ image=image,
+).images[0]
+```
+
+이미지를 살펴보겠습니다.
+
+| 원래의 이미지 | Denoiser들의 앙상블 |
+|---|---|
+| ![lion_base](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lion_base.png) | ![lion_ref](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lion_refined.png) |
+
+동일한 40 단계에서 base 모델을 실행한다면, 이미지의 디테일(예: 사자의 눈과 코)이 떨어졌을 것입니다:
+
+
+
+앙상블 방식은 사용 가능한 모든 스케줄러에서 잘 작동합니다!
+
+
+
+#### 2.) 노이즈가 완전히 제거된 기본 이미지에서 이미지 출력을 정제하기
+
+일반적인 [`StableDiffusionImg2ImgPipeline`] 방식대로, base 모델에서 생성되어 완전히 노이즈가 제거된 이미지를 [refiner 체크포인트](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0)를 사용해 한층 더 개선할 수 있습니다.
+
+이를 위해, 보통의 "base" text-to-image 파이프라인을 수행 후에 image-to-image 파이프라인으로써 refiner를 실행시킬 수 있습니다. base 모델의 출력을 잠재 공간에 남겨둘 수 있습니다.
+
+```py
+from diffusers import DiffusionPipeline
+import torch
+
+pipe = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+)
+pipe.to("cuda")
+
+refiner = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-refiner-1.0",
+ text_encoder_2=pipe.text_encoder_2,
+ vae=pipe.vae,
+ torch_dtype=torch.float16,
+ use_safetensors=True,
+ variant="fp16",
+)
+refiner.to("cuda")
+
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+
+image = pipe(prompt=prompt, output_type="latent").images[0]
+image = refiner(prompt=prompt, image=image[None, :]).images[0]
+```
+
+| 원래의 이미지 | 정제된 이미지 |
+|---|---|
+| ![](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/sd_xl/init_image.png) | ![](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/sd_xl/refined_image.png) |
+
+
+
+refiner는 또한 인페인팅 설정에 잘 사용될 수 있습니다. 아래에 보여지듯이 [`StableDiffusionXLInpaintPipeline`] 클래스를 사용해서 만들어보세요.
+
+
+
+Denoiser 앙상블 설정에서 인페인팅에 refiner를 사용하려면 다음을 수행하면 됩니다:
+
+```py
+from diffusers import StableDiffusionXLInpaintPipeline
+from diffusers.utils import load_image
+
+pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+)
+pipe.to("cuda")
+
+refiner = StableDiffusionXLInpaintPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-refiner-1.0",
+ text_encoder_2=pipe.text_encoder_2,
+ vae=pipe.vae,
+ torch_dtype=torch.float16,
+ use_safetensors=True,
+ variant="fp16",
+)
+refiner.to("cuda")
+
+img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+init_image = load_image(img_url).convert("RGB")
+mask_image = load_image(mask_url).convert("RGB")
+
+prompt = "A majestic tiger sitting on a bench"
+num_inference_steps = 75
+high_noise_frac = 0.7
+
+image = pipe(
+ prompt=prompt,
+ image=init_image,
+ mask_image=mask_image,
+ num_inference_steps=num_inference_steps,
+ denoising_start=high_noise_frac,
+ output_type="latent",
+).images
+image = refiner(
+ prompt=prompt,
+ image=image,
+ mask_image=mask_image,
+ num_inference_steps=num_inference_steps,
+ denoising_start=high_noise_frac,
+).images[0]
+```
+
+일반적인 SDE 설정에서 인페인팅에 refiner를 사용하기 위해, `denoising_end`와 `denoising_start`를 제거하고 refiner의 추론 단계의 수를 적게 선택하세요.
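+
+예를 들어, 위 예시를 일반적인 방식으로 바꾸면 대략 다음과 같은 형태가 됩니다(위에서 정의한 `pipe`, `refiner`, `init_image`, `mask_image`, `prompt`를 그대로 사용하고, refiner의 추론 단계 수 30은 임의로 고른 예시 값입니다):
+
+```py
+# denoising_end / denoising_start 없이 base 파이프라인을 끝까지 실행합니다
+image = pipe(
+    prompt=prompt,
+    image=init_image,
+    mask_image=mask_image,
+    num_inference_steps=75,
+).images[0]
+
+# 완전히 노이즈가 제거된 결과를 refiner에 전달하고, 더 적은 추론 단계를 사용합니다
+image = refiner(
+    prompt=prompt,
+    image=image,
+    mask_image=mask_image,
+    num_inference_steps=30,
+).images[0]
+```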
+
+### 단독 체크포인트 파일 / 원래의 파일 형식으로 불러오기
+
+[`~diffusers.loaders.FromSingleFileMixin.from_single_file`]를 사용함으로써 원래의 파일 형식을 `diffusers` 형식으로 불러올 수 있습니다:
+
+```py
+from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
+import torch
+
+pipe = StableDiffusionXLPipeline.from_single_file(
+ "./sd_xl_base_1.0.safetensors", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+)
+pipe.to("cuda")
+
+refiner = StableDiffusionXLImg2ImgPipeline.from_single_file(
+ "./sd_xl_refiner_1.0.safetensors", torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
+)
+refiner.to("cuda")
+```
+
+### 모델 offloading을 통해 메모리 최적화하기
+
+out-of-memory 에러가 난다면, [`StableDiffusionXLPipeline.enable_model_cpu_offload`]을 사용하는 것을 권장합니다.
+
+```diff
+- pipe.to("cuda")
++ pipe.enable_model_cpu_offload()
+```
+
+그리고
+
+```diff
+- refiner.to("cuda")
++ refiner.enable_model_cpu_offload()
+```
+
+### `torch.compile`로 추론 속도를 올리기
+
+`torch.compile`를 사용함으로써 추론 속도를 올릴 수 있으며, 약 **20%** 정도의 속도 향상을 기대할 수 있습니다.
+
+```diff
++ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
++ refiner.unet = torch.compile(refiner.unet, mode="reduce-overhead", fullgraph=True)
+```
+
+### `torch < 2.0`일 때 실행하기
+
+**참고**: `torch` 2.0 미만 버전에서 Stable Diffusion XL을 실행하고 싶다면, xformers 어텐션을 사용해주세요:
+
+```
+pip install xformers
+```
+
+```diff
++pipe.enable_xformers_memory_efficient_attention()
++refiner.enable_xformers_memory_efficient_attention()
+```
+
+## StableDiffusionXLPipeline
+
+[[autodoc]] StableDiffusionXLPipeline
+ - all
+ - __call__
+
+## StableDiffusionXLImg2ImgPipeline
+
+[[autodoc]] StableDiffusionXLImg2ImgPipeline
+ - all
+ - __call__
+
+## StableDiffusionXLInpaintPipeline
+
+[[autodoc]] StableDiffusionXLInpaintPipeline
+ - all
+ - __call__
+
+### 각 텍스트 인코더에 다른 프롬프트를 전달하기
+
+Stable Diffusion XL은 두 개의 텍스트 인코더로 학습되었습니다. 기본 동작은 각 텍스트 인코더에 동일한 프롬프트를 전달하는 것입니다. 그러나 [일부 사용자](https://github.com/huggingface/diffusers/issues/4004#issuecomment-1627764201)가 품질을 향상시킬 수 있다고 지적한 것처럼 텍스트 인코더마다 다른 프롬프트를 전달할 수 있습니다. 그렇게 하려면 `prompt`와 `negative_prompt`에 더해 `prompt_2`와 `negative_prompt_2`를 전달하면 됩니다. 이렇게 하면 원래의 프롬프트(`prompt`)와 부정 프롬프트(`negative_prompt`)는 `text_encoder`(공식 SDXL 0.9/1.0의 [OpenAI CLIP-ViT/L-14](https://huggingface.co/openai/clip-vit-large-patch14))에 전달되고, `prompt_2`와 `negative_prompt_2`는 `text_encoder_2`(공식 SDXL 0.9/1.0의 [OpenCLIP-ViT/bigG-14](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k))에 전달됩니다.
+
+```py
+from diffusers import StableDiffusionXLPipeline
+import torch
+
+pipe = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+)
+pipe.to("cuda")
+
+# OAI CLIP-ViT/L-14에 prompt가 전달됩니다
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+# OpenCLIP-ViT/bigG-14에 prompt_2가 전달됩니다
+prompt_2 = "monet painting"
+image = pipe(prompt=prompt, prompt_2=prompt_2).images[0]
+```
\ No newline at end of file
diff --git a/diffusers/docs/source/ko/in_translation.md b/diffusers/docs/source/ko/in_translation.md
new file mode 100644
index 0000000000000000000000000000000000000000..518be0c03b7c8cf0e8e9b2b083f08ccbb62bfad6
--- /dev/null
+++ b/diffusers/docs/source/ko/in_translation.md
@@ -0,0 +1,16 @@
+
+
+# 번역중
+
+열심히 번역을 진행중입니다. 조금만 기다려주세요.
+감사합니다!
\ No newline at end of file
diff --git a/diffusers/docs/source/ko/index.md b/diffusers/docs/source/ko/index.md
new file mode 100644
index 0000000000000000000000000000000000000000..a83dd0d0b29e5eee20b3d66b950d1b064aa9e964
--- /dev/null
+++ b/diffusers/docs/source/ko/index.md
@@ -0,0 +1,97 @@
+
+
+
+
+
+
+
+
+
+# Diffusers
+
+🤗 Diffusers는 이미지, 오디오, 심지어 분자의 3D 구조를 생성하기 위한 최첨단 사전 훈련된 diffusion 모델을 위한 라이브러리입니다. 간단한 추론 솔루션을 찾고 있든, 자체 diffusion 모델을 훈련하고 싶든, 🤗 Diffusers는 두 가지 모두를 지원하는 모듈식 툴박스입니다. 저희 라이브러리는 [성능보다 사용성](conceptual/philosophy#usability-over-performance), [간편함보다 단순함](conceptual/philosophy#simple-over-easy), 그리고 [추상화보다 사용자 지정 가능성](conceptual/philosophy#tweakable-contributorfriendly-over-abstraction)에 중점을 두고 설계되었습니다.
+
+이 라이브러리에는 세 가지 주요 구성 요소가 있습니다:
+
+- 몇 줄의 코드만으로 추론할 수 있는 최첨단 [diffusion 파이프라인](api/pipelines/overview).
+- 생성 속도와 품질 간의 균형을 맞추기 위해 상호교환적으로 사용할 수 있는 [노이즈 스케줄러](api/schedulers/overview).
+- 빌딩 블록으로 사용할 수 있고 스케줄러와 결합하여 자체적인 end-to-end diffusion 시스템을 만들 수 있는 사전 학습된 [모델](api/models).
+
+
+
+## Supported pipelines
+
+| Pipeline | Paper/Repository | Tasks |
+|---|---|:---:|
+| [alt_diffusion](./api/pipelines/alt_diffusion) | [AltCLIP: Altering the Language Encoder in CLIP for Extended Language Capabilities](https://arxiv.org/abs/2211.06679) | Image-to-Image Text-Guided Generation |
+| [audio_diffusion](./api/pipelines/audio_diffusion) | [Audio Diffusion](https://github.com/teticio/audio-diffusion.git) | Unconditional Audio Generation |
+| [controlnet](./api/pipelines/stable_diffusion/controlnet) | [Adding Conditional Control to Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.05543) | Image-to-Image Text-Guided Generation |
+| [cycle_diffusion](./api/pipelines/cycle_diffusion) | [Unifying Diffusion Models' Latent Space, with Applications to CycleDiffusion and Guidance](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation |
+| [dance_diffusion](./api/pipelines/dance_diffusion) | [Dance Diffusion](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
+| [ddpm](./api/pipelines/ddpm) | [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
+| [ddim](./api/pipelines/ddim) | [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation |
+| [if](./if) | [**IF**](./api/pipelines/if) | Image Generation |
+| [if_img2img](./if) | [**IF**](./api/pipelines/if) | Image-to-Image Generation |
+| [if_inpainting](./if) | [**IF**](./api/pipelines/if) | Image-to-Image Generation |
+| [latent_diffusion](./api/pipelines/latent_diffusion) | [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation |
+| [latent_diffusion](./api/pipelines/latent_diffusion) | [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)| Super Resolution Image-to-Image |
+| [latent_diffusion_uncond](./api/pipelines/latent_diffusion_uncond) | [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation |
+| [paint_by_example](./api/pipelines/paint_by_example) | [Paint by Example: Exemplar-based Image Editing with Diffusion Models](https://arxiv.org/abs/2211.13227) | Image-Guided Image Inpainting |
+| [pndm](./api/pipelines/pndm) | [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation |
+| [score_sde_ve](./api/pipelines/score_sde_ve) | [Score-Based Generative Modeling through Stochastic Differential Equations](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
+| [score_sde_vp](./api/pipelines/score_sde_vp) | [Score-Based Generative Modeling through Stochastic Differential Equations](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
+| [semantic_stable_diffusion](./api/pipelines/semantic_stable_diffusion) | [Semantic Guidance](https://arxiv.org/abs/2301.12247) | Text-Guided Generation |
+| [stable_diffusion_text2img](./api/pipelines/stable_diffusion/text2img) | [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation |
+| [stable_diffusion_img2img](./api/pipelines/stable_diffusion/img2img) | [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation |
+| [stable_diffusion_inpaint](./api/pipelines/stable_diffusion/inpaint) | [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting |
+| [stable_diffusion_panorama](./api/pipelines/stable_diffusion/panorama) | [MultiDiffusion](https://multidiffusion.github.io/) | Text-to-Panorama Generation |
+| [stable_diffusion_pix2pix](./api/pipelines/stable_diffusion/pix2pix) | [InstructPix2Pix: Learning to Follow Image Editing Instructions](https://arxiv.org/abs/2211.09800) | Text-Guided Image Editing|
+| [stable_diffusion_pix2pix_zero](./api/pipelines/stable_diffusion/pix2pix_zero) | [Zero-shot Image-to-Image Translation](https://pix2pixzero.github.io/) | Text-Guided Image Editing |
+| [stable_diffusion_attend_and_excite](./api/pipelines/stable_diffusion/attend_and_excite) | [Attend-and-Excite: Attention-Based Semantic Guidance for Text-to-Image Diffusion Models](https://arxiv.org/abs/2301.13826) | Text-to-Image Generation |
+| [stable_diffusion_self_attention_guidance](./api/pipelines/stable_diffusion/self_attention_guidance) | [Improving Sample Quality of Diffusion Models Using Self-Attention Guidance](https://arxiv.org/abs/2210.00939) | Text-to-Image Generation Unconditional Image Generation |
+| [stable_diffusion_image_variation](./stable_diffusion/image_variation) | [Stable Diffusion Image Variations](https://github.com/LambdaLabsML/lambda-diffusers#stable-diffusion-image-variations) | Image-to-Image Generation |
+| [stable_diffusion_latent_upscale](./stable_diffusion/latent_upscale) | [Stable Diffusion Latent Upscaler](https://twitter.com/StabilityAI/status/1590531958815064065) | Text-Guided Super Resolution Image-to-Image |
+| [stable_diffusion_model_editing](./api/pipelines/stable_diffusion/model_editing) | [Editing Implicit Assumptions in Text-to-Image Diffusion Models](https://time-diffusion.github.io/) | Text-to-Image Model Editing |
+| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [Stable Diffusion 2](https://stability.ai/blog/stable-diffusion-v2-release) | Text-to-Image Generation |
+| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [Stable Diffusion 2](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Image Inpainting |
+| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [Depth-Conditional Stable Diffusion](https://github.com/Stability-AI/stablediffusion#depth-conditional-stable-diffusion) | Depth-to-Image Generation |
+| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [Stable Diffusion 2](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Super Resolution Image-to-Image |
+| [stable_diffusion_safe](./api/pipelines/stable_diffusion_safe) | [Safe Stable Diffusion](https://arxiv.org/abs/2211.05105) | Text-Guided Generation |
+| [stable_unclip](./stable_unclip) | Stable unCLIP | Text-to-Image Generation |
+| [stable_unclip](./stable_unclip) | Stable unCLIP | Image-to-Image Text-Guided Generation |
+| [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
+| [text_to_video_sd](./api/pipelines/text_to_video) | [Modelscope's Text-to-video-synthesis Model in Open Domain](https://modelscope.cn/models/damo/text-to-video-synthesis/summary) | Text-to-Video Generation |
+| [unclip](./api/pipelines/unclip) | [Hierarchical Text-Conditional Image Generation with CLIP Latents](https://arxiv.org/abs/2204.06125)(implementation by [kakaobrain](https://github.com/kakaobrain/karlo)) | Text-to-Image Generation |
+| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Text-to-Image Generation |
+| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Image Variations Generation |
+| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Dual Image and Text Guided Generation |
+| [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation |
diff --git a/diffusers/docs/source/ko/installation.md b/diffusers/docs/source/ko/installation.md
new file mode 100644
index 0000000000000000000000000000000000000000..4a9146a22620699a7faabb45844809be581a4d7a
--- /dev/null
+++ b/diffusers/docs/source/ko/installation.md
@@ -0,0 +1,142 @@
+
+
+# 설치
+
+사용하시는 라이브러리에 맞는 🤗 Diffusers를 설치하세요.
+
+🤗 Diffusers는 Python 3.8+, PyTorch 1.7.0+ 및 flax에서 테스트되었습니다. 사용중인 딥러닝 라이브러리에 대한 아래의 설치 안내를 따르세요.
+
+- [PyTorch 설치 안내](https://pytorch.org/get-started/locally/)
+- [Flax 설치 안내](https://flax.readthedocs.io/en/latest/)
+
+## pip를 이용한 설치
+
+[가상 환경](https://docs.python.org/3/library/venv.html)에 🤗 Diffusers를 설치해야 합니다.
+Python 가상 환경에 익숙하지 않은 경우 [가상환경 pip 설치 가이드](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)를 살펴보세요.
+가상 환경을 사용하면 서로 다른 프로젝트를 더 쉽게 관리하고, 종속성간의 호환성 문제를 피할 수 있습니다.
+
+프로젝트 디렉토리에 가상 환경을 생성하는 것으로 시작하세요:
+
+```bash
+python -m venv .env
+```
+
+그리고 가상 환경을 활성화합니다:
+
+```bash
+source .env/bin/activate
+```
+
+이제 다음의 명령어로 🤗 Diffusers를 설치할 준비가 되었습니다:
+
+**PyTorch의 경우**
+
+```bash
+pip install diffusers["torch"]
+```
+
+**Flax의 경우**
+
+```bash
+pip install diffusers["flax"]
+```
+
+## 소스로부터 설치
+
+소스에서 `diffusers`를 설치하기 전에, `torch` 및 `accelerate`이 설치되어 있는지 확인하세요.
+
+`torch` 설치에 대해서는 [torch docs](https://pytorch.org/get-started/locally/#start-locally)를 참고하세요.
+
+다음과 같이 `accelerate`을 설치하세요.
+
+```bash
+pip install accelerate
+```
+
+다음 명령어를 사용하여 소스에서 🤗 Diffusers를 설치하세요:
+
+```bash
+pip install git+https://github.com/huggingface/diffusers
+```
+
+이 명령어는 최신 `stable` 버전이 아닌 최첨단 `main` 버전을 설치합니다.
+`main` 버전은 최신 개발 정보를 최신 상태로 유지하는 데 유용합니다.
+예를 들어 마지막 공식 릴리즈 이후 버그가 수정되었지만, 새 릴리즈가 아직 출시되지 않은 경우입니다.
+그러나 이는 `main` 버전이 항상 안정적이지 않을 수 있음을 의미합니다.
+우리는 `main` 버전이 지속적으로 작동하도록 노력하고 있으며, 대부분의 문제는 보통 몇 시간 또는 하루 안에 해결됩니다.
+문제가 발생하면 더 빨리 해결할 수 있도록 [Issue](https://github.com/huggingface/diffusers/issues)를 열어주세요!
+
+
+## 편집가능한 설치
+
+다음을 수행하려면 편집가능한 설치가 필요합니다:
+
+* 소스 코드의 `main` 버전을 사용
+* 🤗 Diffusers에 기여 (코드의 변경 사항을 테스트하기 위해 필요)
+
+저장소를 복제하고 다음 명령어를 사용하여 🤗 Diffusers를 설치합니다:
+
+```bash
+git clone https://github.com/huggingface/diffusers.git
+cd diffusers
+```
+
+**PyTorch의 경우**
+
+```
+pip install -e ".[torch]"
+```
+
+**Flax의 경우**
+
+```
+pip install -e ".[flax]"
+```
+
+이러한 명령어들은 저장소를 복제한 폴더와 Python 라이브러리 경로를 연결합니다.
+Python은 이제 일반 라이브러리 경로에 더하여 복제한 폴더 내부를 살펴봅니다.
+예를 들어 Python 패키지가 `~/anaconda3/envs/main/lib/python3.8/site-packages/`에 설치되어 있는 경우 Python은 복제한 폴더인 `~/diffusers/`도 검색합니다.
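+
+편집가능한 설치가 실제로 복제한 폴더를 가리키는지 확인하고 싶다면, 예를 들어 다음과 같이 확인해볼 수 있습니다(출력되는 경로는 환경에 따라 달라지는 예시입니다):
+
+```py
+import diffusers
+
+# 편집가능한 설치라면 복제한 폴더 내부의 경로가 출력됩니다 (예: ~/diffusers/src/diffusers/__init__.py)
+print(diffusers.__file__)
+```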
+
+
+
+라이브러리를 계속 사용하려면 `diffusers` 폴더를 유지해야 합니다.
+
+
+
+이제 다음 명령어를 사용하여 최신 버전의 🤗 Diffusers로 쉽게 업데이트할 수 있습니다:
+
+```bash
+cd ~/diffusers/
+git pull
+```
+
+이렇게 하면, 다음에 실행할 때 Python 환경이 🤗 Diffusers의 `main` 버전을 찾게 됩니다.
+
+## 텔레메트리 로깅에 대한 알림
+
+우리 라이브러리는 `from_pretrained()` 요청 중에 텔레메트리 정보를 원격으로 수집합니다.
+이 데이터에는 Diffusers 및 PyTorch/Flax의 버전, 요청된 모델 또는 파이프라인 클래스, 그리고 허브에서 호스팅되는 경우 사전학습된 체크포인트에 대한 경로를 포함합니다.
+이 사용 데이터는 문제를 디버깅하고 새로운 기능의 우선순위를 지정하는데 도움이 됩니다.
+텔레메트리는 HuggingFace 허브에서 모델과 파이프라인을 불러올 때만 전송되며, 로컬 사용 중에는 수집되지 않습니다.
+
+우리는 추가 정보를 공유하지 않기를 원하는 사람이 있다는 것을 이해하고 개인 정보를 존중하므로, 터미널에서 `DISABLE_TELEMETRY` 환경 변수를 설정하여 텔레메트리 수집을 비활성화할 수 있습니다.
+
+Linux/MacOS에서:
+```bash
+export DISABLE_TELEMETRY=YES
+```
+
+Windows에서:
+```bash
+set DISABLE_TELEMETRY=YES
+```
\ No newline at end of file
diff --git a/diffusers/docs/source/ko/optimization/coreml.md b/diffusers/docs/source/ko/optimization/coreml.md
new file mode 100644
index 0000000000000000000000000000000000000000..5ce81a20889bafb00228c7f6bc31f263c5cc4c1f
--- /dev/null
+++ b/diffusers/docs/source/ko/optimization/coreml.md
@@ -0,0 +1,168 @@
+
+
+# Core ML로 Stable Diffusion을 실행하는 방법
+
+[Core ML](https://developer.apple.com/documentation/coreml)은 Apple 프레임워크에서 지원하는 모델 형식 및 머신 러닝 라이브러리입니다. macOS 또는 iOS/iPadOS 앱 내에서 Stable Diffusion 모델을 실행하는 데 관심이 있는 경우, 이 가이드에서는 기존 PyTorch 체크포인트를 Core ML 형식으로 변환하고 이를 Python 또는 Swift로 추론에 사용하는 방법을 설명합니다.
+
+Core ML 모델은 Apple 기기에서 사용할 수 있는 모든 컴퓨팅 엔진들, 즉 CPU, GPU, Apple Neural Engine(또는 Apple Silicon Mac 및 최신 iPhone/iPad에서 사용할 수 있는 텐서 최적화 가속기인 ANE)을 활용할 수 있습니다. 모델과 실행 중인 기기에 따라 Core ML은 컴퓨팅 엔진도 혼합하여 사용할 수 있으므로, 예를 들어 모델의 일부가 CPU에서 실행되는 반면 다른 부분은 GPU에서 실행될 수 있습니다.
+
+
+
+PyTorch에 내장된 `mps` 가속기를 사용하여 Apple Silicon Mac에서 `diffusers` Python 코드베이스를 실행할 수도 있습니다. 이 방법은 [mps 가이드](mps)에 자세히 설명되어 있지만, 네이티브 앱과는 호환되지 않습니다.
+
+
+
+## Stable Diffusion Core ML 체크포인트
+
+Stable Diffusion 가중치(또는 체크포인트)는 PyTorch 형식으로 저장되기 때문에 네이티브 앱에서 사용하기 위해서는 Core ML 형식으로 변환해야 합니다.
+
+다행히도 Apple 엔지니어들이 `diffusers`를 기반으로 한 [변환 툴](https://github.com/apple/ml-stable-diffusion#-converting-models-to-core-ml)을 개발하여 PyTorch 체크포인트를 Core ML로 변환할 수 있습니다.
+
+모델을 변환하기 전에 잠시 시간을 내어 Hugging Face Hub를 살펴보세요. 관심 있는 모델이 이미 Core ML 형식으로 제공되고 있을 가능성이 높습니다:
+
+- [Apple](https://huggingface.co/apple) organization에는 Stable Diffusion 버전 1.4, 1.5, 2.0 base 및 2.1 base가 포함되어 있습니다.
+- [coreml](https://huggingface.co/coreml) organization에는 커스텀 DreamBooth가 적용되거나, 파인튜닝된 모델이 포함되어 있습니다.
+- 이 [필터](https://huggingface.co/models?pipeline_tag=text-to-image&library=coreml&p=2&sort=likes)를 사용하여 사용 가능한 모든 Core ML 체크포인트들을 반환합니다.
+
+원하는 모델을 찾을 수 없는 경우 Apple의 [모델을 Core ML로 변환하기](https://github.com/apple/ml-stable-diffusion#-converting-models-to-core-ml) 지침을 따르는 것이 좋습니다.
+
+## 사용할 Core ML 변형(Variant) 선택하기
+
+Stable Diffusion 모델은 다양한 목적에 따라 다른 Core ML 변형으로 변환할 수 있습니다:
+
+- 사용되는 어텐션 블록 유형. 어텐션 연산은 이미지 표현의 여러 영역 간의 관계에 '주의를 기울이고' 이미지와 텍스트 표현이 어떻게 연관되어 있는지 이해하는 데 사용됩니다. 어텐션 연산은 컴퓨팅 및 메모리 집약적이므로 다양한 장치의 하드웨어 특성을 고려한 다양한 구현이 존재합니다. Core ML Stable Diffusion 모델의 경우 두 가지 주의 변형이 있습니다:
+  * `split_einsum` ([Apple에서 도입](https://machinelearning.apple.com/research/neural-engine-transformers))은 최신 iPhone, iPad 및 M 시리즈 컴퓨터에서 사용할 수 있는 ANE 장치에 최적화되어 있습니다.
+ * "원본" 어텐션(`diffusers`에 사용되는 기본 구현)는 CPU/GPU와만 호환되며 ANE와는 호환되지 않습니다. "원본" 어텐션을 사용하여 CPU + GPU에서 모델을 실행하는 것이 ANE보다 *더* 빠를 수 있습니다. 자세한 내용은 [이 성능 벤치마크](https://huggingface.co/blog/fast-mac-diffusers#performance-benchmarks)와 커뮤니티에서 제공하는 일부 [추가 측정](https://github.com/huggingface/swift-coreml-diffusers/issues/31)을 참조하십시오.
+
+- 지원되는 추론 프레임워크
+ * `packages`는 Python 추론에 적합합니다. 네이티브 앱에 통합하기 전에 변환된 Core ML 모델을 테스트하거나, Core ML 성능을 알고 싶지만 네이티브 앱을 지원할 필요는 없는 경우에 사용할 수 있습니다. 예를 들어, 웹 UI가 있는 애플리케이션은 Python Core ML 백엔드를 완벽하게 사용할 수 있습니다.
+ * Swift 코드에는 `컴파일된` 모델이 필요합니다. Hub의 `컴파일된` 모델은 iOS 및 iPadOS 기기와의 호환성을 위해 큰 UNet 모델 가중치를 여러 파일로 분할합니다. 이는 [`--chunk-unet` 변환 옵션](https://github.com/apple/ml-stable-diffusion#-converting-models-to-core-ml)에 해당합니다. 네이티브 앱을 지원하려면 `컴파일된` 변형을 선택해야 합니다.
+
+공식 Core ML Stable Diffusion [모델](https://huggingface.co/apple/coreml-stable-diffusion-v1-4/tree/main)에는 이러한 변형이 포함되어 있지만 커뮤니티 버전은 다를 수 있습니다:
+
+```
+coreml-stable-diffusion-v1-4
+├── README.md
+├── original
+│ ├── compiled
+│ └── packages
+└── split_einsum
+ ├── compiled
+ └── packages
+```
+
+아래와 같이 필요한 변형을 다운로드하여 사용할 수 있습니다.
+
+## Python에서 Core ML 추론
+
+Python에서 Core ML 추론을 실행하려면 다음 라이브러리를 설치하세요:
+
+```bash
+pip install huggingface_hub
+pip install git+https://github.com/apple/ml-stable-diffusion
+```
+
+### 모델 체크포인트 다운로드하기
+
+`컴파일된` 버전은 Swift와만 호환되므로 Python에서 추론을 실행하려면 `packages` 폴더에 저장된 버전 중 하나를 사용하세요. `원본` 또는 `split_einsum` 어텐션 중 어느 것을 사용할지 선택할 수 있습니다.
+
+다음은 Hub에서 'models'라는 디렉토리로 'original' 어텐션 변형을 다운로드하는 방법입니다:
+
+```Python
+from huggingface_hub import snapshot_download
+from pathlib import Path
+
+repo_id = "apple/coreml-stable-diffusion-v1-4"
+variant = "original/packages"
+
+model_path = Path("./models") / (repo_id.split("/")[-1] + "_" + variant.replace("/", "_"))
+snapshot_download(repo_id, allow_patterns=f"{variant}/*", local_dir=model_path, local_dir_use_symlinks=False)
+print(f"Model downloaded at {model_path}")
+```
+
+
+### 추론[[python-inference]]
+
+모델의 snapshot을 다운로드한 후에는 Apple의 Python 스크립트를 사용하여 테스트할 수 있습니다.
+
+```shell
+python -m python_coreml_stable_diffusion.pipeline --prompt "a photo of an astronaut riding a horse on mars" -i models/coreml-stable-diffusion-v1-4_original_packages -o <output-dir> --compute-unit CPU_AND_GPU --seed 93
+```
+
+`-i` 플래그는 위 단계에서 다운로드한 체크포인트를 가리켜야 하며, `--compute-unit`은 추론에 사용할 하드웨어를 나타냅니다. 이는 다음 옵션 중 하나여야 합니다: `ALL`, `CPU_AND_GPU`, `CPU_ONLY`, `CPU_AND_NE`. 선택적으로 출력 경로(`-o`)와 재현성을 위한 시드를 지정할 수도 있습니다.
+
+추론 스크립트에서는 Stable Diffusion 모델의 원래 버전인 `CompVis/stable-diffusion-v1-4`를 사용한다고 가정합니다. 다른 모델을 사용하는 경우 추론 명령줄에서 `--model-version` 옵션을 사용하여 해당 허브 ID를 *지정*해야 합니다. 이는 이미 지원되는 모델과 사용자가 직접 학습하거나 파인튜닝한 사용자 지정 모델에 적용됩니다.
+
+예를 들어, [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5)를 사용하려는 경우입니다:
+
+```shell
+python -m python_coreml_stable_diffusion.pipeline --prompt "a photo of an astronaut riding a horse on mars" --compute-unit ALL -o output --seed 93 -i models/coreml-stable-diffusion-v1-5_original_packages --model-version runwayml/stable-diffusion-v1-5
+```
+
+
+## Swift에서 Core ML 추론하기
+
+Swift에서 추론을 실행하는 것은 모델이 이미 `mlmodelc` 형식으로 컴파일되어 있기 때문에 Python보다 약간 빠릅니다. 이 차이는 모델을 불러와야 하는 앱 시작 시점에는 눈에 띄지만, 이후 여러 번 생성을 실행하면 거의 느껴지지 않습니다.
+
+### 다운로드
+
+Mac에서 Swift로 추론을 실행하려면 `컴파일된` 체크포인트 버전 중 하나가 필요합니다. 이전 예제와 비슷한 Python 코드를 사용하되, `컴파일된` 변형 중 하나를 로컬로 다운로드하는 것을 권장합니다:
+
+```Python
+from huggingface_hub import snapshot_download
+from pathlib import Path
+
+repo_id = "apple/coreml-stable-diffusion-v1-4"
+variant = "original/compiled"
+
+model_path = Path("./models") / (repo_id.split("/")[-1] + "_" + variant.replace("/", "_"))
+snapshot_download(repo_id, allow_patterns=f"{variant}/*", local_dir=model_path, local_dir_use_symlinks=False)
+print(f"Model downloaded at {model_path}")
+```
+
+### 추론[[swift-inference]]
+
+추론을 실행하기 위해서, Apple의 리포지토리를 복제하세요:
+
+```bash
+git clone https://github.com/apple/ml-stable-diffusion
+cd ml-stable-diffusion
+```
+
+그 다음 Apple의 명령어 도구인 [Swift 패키지 관리자](https://www.swift.org/package-manager/#)를 사용합니다:
+
+```bash
+swift run StableDiffusionSample --resource-path models/coreml-stable-diffusion-v1-4_original_compiled --compute-units all "a photo of an astronaut riding a horse on mars"
+```
+
+`--resource-path`에 이전 단계에서 다운로드한 체크포인트 중 하나를 지정해야 하므로 확장자가 `.mlmodelc`인 컴파일된 Core ML 번들이 포함되어 있는지 확인하시기 바랍니다. `--compute-units`는 다음 값 중 하나이어야 합니다: `all`, `cpuOnly`, `cpuAndGPU`, `cpuAndNeuralEngine`.
+
+자세한 내용은 [Apple의 리포지토리 안의 지침](https://github.com/apple/ml-stable-diffusion)을 참고하시기 바랍니다.
+
+
+## 지원되는 Diffusers 기능
+
+Core ML 모델과 추론 코드는 🧨 Diffusers의 많은 기능, 옵션 및 유연성을 지원하지 않습니다. 다음은 유의해야 할 몇 가지 제한 사항입니다:
+
+- Core ML 모델은 추론에만 적합합니다. 학습이나 파인튜닝에는 사용할 수 없습니다.
+- Swift에 포팅된 스케줄러는 Stable Diffusion에서 사용하는 기본 스케줄러와 `diffusers` 구현에서 Swift로 포팅한 `DPMSolverMultistepScheduler` 두 개뿐입니다. 이들 중 약 절반의 스텝으로 동일한 품질을 생성하는 `DPMSolverMultistepScheduler`를 사용하는 것이 좋습니다.
+- 추론 코드에서 네거티브 프롬프트, classifier-free guidance scale 및 image-to-image 작업을 사용할 수 있습니다. depth guidance, ControlNet, latent upscalers와 같은 고급 기능은 아직 사용할 수 없습니다.
+
+Apple의 [변환 및 추론 리포지토리](https://github.com/apple/ml-stable-diffusion)와 자체 [swift-coreml-diffusers](https://github.com/huggingface/swift-coreml-diffusers) 리포지토리는 다른 개발자들이 구축할 수 있는 기술적인 데모입니다.
+
+누락된 기능이 있다고 생각되면 언제든지 기능을 요청하거나, 더 좋은 방법은 기여 PR을 열어주세요. :)
+
+
+## 네이티브 Diffusers Swift 앱
+
+자체 Apple 하드웨어에서 Stable Diffusion을 실행하는 쉬운 방법 중 하나는 `diffusers`와 Apple의 변환 및 추론 리포지토리를 기반으로 하는 [자체 오픈 소스 Swift 리포지토리](https://github.com/huggingface/swift-coreml-diffusers)를 사용하는 것입니다. 코드를 공부하고 [Xcode](https://developer.apple.com/xcode/)로 컴파일하여 필요에 맞게 조정할 수 있습니다. 편의를 위해 앱스토어에 [독립형 Mac 앱](https://apps.apple.com/app/diffusers/id1666309574)도 있으므로 코드나 IDE를 다루지 않고도 사용할 수 있습니다. 개발자로서 Core ML이 Stable Diffusion 앱을 구축하는 데 가장 적합한 솔루션이라고 판단했다면, 이 가이드의 나머지 부분을 사용하여 프로젝트를 시작할 수 있습니다. 여러분이 무엇을 빌드할지 기대됩니다. :)
\ No newline at end of file
diff --git a/diffusers/docs/source/ko/optimization/fp16.md b/diffusers/docs/source/ko/optimization/fp16.md
new file mode 100644
index 0000000000000000000000000000000000000000..0f2c487a75ce45b384b64249abf689d5832f13d5
--- /dev/null
+++ b/diffusers/docs/source/ko/optimization/fp16.md
@@ -0,0 +1,410 @@
+
+
+# 메모리와 속도
+
+메모리 또는 속도에 대해 🤗 Diffusers *추론*을 최적화하기 위한 몇 가지 기술과 아이디어를 제시합니다.
+일반적으로, memory-efficient attention을 위해 [xFormers](https://github.com/facebookresearch/xformers) 사용을 추천하기 때문에, 추천하는 [설치 방법](xformers)을 보고 설치해 보세요.
+
+다음 설정이 성능과 메모리에 미치는 영향에 대해 설명합니다.
+
+| | 지연시간 | 속도 향상 |
+| ---------------- | ------- | ------- |
+| 별도 설정 없음 | 9.50s | x1 |
+| cuDNN auto-tuner | 9.37s | x1.01 |
+| fp16 | 3.61s | x2.63 |
+| Channels Last 메모리 형식 | 3.30s | x2.88 |
+| traced UNet | 3.21s | x2.96 |
+| memory-efficient attention | 2.63s | x3.61 |
+
+
+ NVIDIA TITAN RTX에서 50 DDIM 스텝의 "a photo of an astronaut riding a horse on mars" 프롬프트로 512x512 크기의 단일 이미지를 생성하였습니다.
+
+
+## cuDNN auto-tuner 활성화하기
+
+[NVIDIA cuDNN](https://developer.nvidia.com/cudnn)은 컨볼루션을 계산하는 많은 알고리즘을 지원합니다. Autotuner는 짧은 벤치마크를 실행하고 주어진 입력 크기에 대해 주어진 하드웨어에서 최고의 성능을 가진 커널을 선택합니다.
+
+**컨볼루션 네트워크**를 활용하고 있기 때문에 (다른 유형들은 현재 지원되지 않음), 다음 설정을 통해 추론 전에 cuDNN autotuner를 활성화할 수 있습니다:
+
+```python
+import torch
+
+torch.backends.cudnn.benchmark = True
+```
+
+### fp32 대신 tf32 사용하기 (Ampere 및 이후 CUDA 장치들에서)
+
+Ampere 및 이후 CUDA 장치에서 행렬곱 및 컨볼루션은 TensorFloat32(TF32) 모드를 사용하여 더 빠르지만 약간 덜 정확할 수 있습니다.
+기본적으로 PyTorch는 컨볼루션에 대해 TF32 모드를 활성화하지만 행렬 곱셈은 활성화하지 않습니다.
+네트워크에 완전한 float32 정밀도가 필요한 경우가 아니면 행렬 곱셈에 대해서도 이 설정을 활성화하는 것이 좋습니다.
+이는 일반적으로 무시할 수 있는 수치의 정확도 손실이 있지만, 계산 속도를 크게 높일 수 있습니다.
+그것에 대해 [여기](https://huggingface.co/docs/transformers/v4.18.0/en/performance#tf32)서 더 읽을 수 있습니다.
+추론하기 전에 다음을 추가하기만 하면 됩니다:
+
+```python
+import torch
+
+torch.backends.cuda.matmul.allow_tf32 = True
+```
+
+## 반정밀도 가중치
+
+더 많은 GPU 메모리를 절약하고 더 빠른 속도를 얻기 위해 모델 가중치를 반정밀도(half precision)로 직접 불러오고 실행할 수 있습니다.
+여기에는 `fp16`이라는 브랜치에 저장된 float16 버전의 가중치를 불러오고, 그 때 `float16` 유형을 사용하도록 PyTorch에 지시하는 작업이 포함됩니다.
+
+```Python
+import torch
+from diffusers import StableDiffusionPipeline
+
+pipe = StableDiffusionPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5",
+    torch_dtype=torch.float16,
+)
+pipe = pipe.to("cuda")
+
+prompt = "a photo of an astronaut riding a horse on mars"
+image = pipe(prompt).images[0]
+```
+
+
+ 어떤 파이프라인에서든 [`torch.autocast`](https://pytorch.org/docs/stable/amp.html#torch.autocast)를 사용하면 검은색 이미지가 생성될 수 있고, 순수한 float16 정밀도를 사용하는 것보다 항상 느리기 때문에 사용하지 않는 것이 좋습니다.
+
+
+## 추가 메모리 절약을 위한 슬라이스 어텐션
+
+추가 메모리 절약을 위해, 한 번에 모두 계산하는 대신 단계적으로 계산을 수행하는 슬라이스 버전의 어텐션(attention)을 사용할 수 있습니다.
+
+
+ Attention slicing은 모델이 하나 이상의 어텐션 헤드를 사용하는 한, 배치 크기가 1인 경우에도 유용합니다.
+ 하나 이상의 어텐션 헤드가 있는 경우 *QK^T* 어텐션 행렬을 각 헤드에 대해 순차적으로 계산할 수 있어 상당한 양의 메모리를 절약할 수 있습니다.
+
+
+각 헤드에 대해 순차적으로 어텐션 계산을 수행하려면, 다음과 같이 추론 전에 파이프라인에서 [`~StableDiffusionPipeline.enable_attention_slicing`]를 호출하면 됩니다:
+
+```Python
+import torch
+from diffusers import StableDiffusionPipeline
+
+pipe = StableDiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+
+ torch_dtype=torch.float16,
+)
+pipe = pipe.to("cuda")
+
+prompt = "a photo of an astronaut riding a horse on mars"
+pipe.enable_attention_slicing()
+image = pipe(prompt).images[0]
+```
+
+추론 시간이 약 10% 느려지는 약간의 성능 저하가 있지만 이 방법을 사용하면 3.2GB 정도의 작은 VRAM으로도 Stable Diffusion을 사용할 수 있습니다!
+
+
+## 더 큰 배치를 위한 sliced VAE 디코드
+
+제한된 VRAM에서 대규모 이미지 배치를 디코딩하거나 32개 이상의 이미지가 포함된 배치를 활성화하기 위해, 배치의 latent 이미지를 한 번에 하나씩 디코딩하는 슬라이스 VAE 디코드를 사용할 수 있습니다.
+
+이를 [`~StableDiffusionPipeline.enable_attention_slicing`] 또는 [`~StableDiffusionPipeline.enable_xformers_memory_efficient_attention`]과 결합하여 메모리 사용을 추가로 최소화할 수 있습니다.
+
+VAE 디코드를 한 번에 하나씩 수행하려면 추론 전에 파이프라인에서 [`~StableDiffusionPipeline.enable_vae_slicing`]을 호출합니다. 예를 들어:
+
+```Python
+import torch
+from diffusers import StableDiffusionPipeline
+
+pipe = StableDiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+
+ torch_dtype=torch.float16,
+)
+pipe = pipe.to("cuda")
+
+prompt = "a photo of an astronaut riding a horse on mars"
+pipe.enable_vae_slicing()
+images = pipe([prompt] * 32).images
+```
+
+다중 이미지 배치에서는 VAE 디코드 성능이 약간 향상됩니다. 단일 이미지 배치에서는 성능에 영향이 없습니다.
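+
+위에서 언급한 대로 sliced VAE 디코드는 memory-efficient attention과 결합할 수 있습니다. 아래는 두 기능을 함께 활성화하는 간단한 예시 스케치입니다(xFormers가 설치되어 있다고 가정합니다):
+
+```Python
+import torch
+from diffusers import StableDiffusionPipeline
+
+pipe = StableDiffusionPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5",
+    torch_dtype=torch.float16,
+)
+pipe = pipe.to("cuda")
+
+# sliced VAE 디코드와 xFormers memory-efficient attention을 함께 사용
+pipe.enable_vae_slicing()
+pipe.enable_xformers_memory_efficient_attention()
+
+prompt = "a photo of an astronaut riding a horse on mars"
+images = pipe([prompt] * 32).images
+```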
+
+
+
+## 메모리 절약을 위해 가속 기능을 사용하여 CPU로 오프로딩
+
+추가 메모리 절약을 위해 가중치를 CPU로 오프로드하고 순방향 전달을 수행할 때만 GPU로 로드할 수 있습니다.
+
+CPU 오프로딩을 수행하려면 [`~StableDiffusionPipeline.enable_sequential_cpu_offload`]를 호출하기만 하면 됩니다:
+
+```Python
+import torch
+from diffusers import StableDiffusionPipeline
+
+pipe = StableDiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+
+ torch_dtype=torch.float16,
+)
+
+prompt = "a photo of an astronaut riding a horse on mars"
+pipe.enable_sequential_cpu_offload()
+image = pipe(prompt).images[0]
+```
+
+그러면 메모리 소비를 3GB 미만으로 줄일 수 있습니다.
+
+참고로 이 방법은 전체 모델이 아닌 서브모듈 수준에서 작동합니다. 이는 메모리 소비를 최소화하는 가장 좋은 방법이지만 프로세스의 반복적 특성으로 인해 추론 속도가 훨씬 느립니다. 파이프라인의 UNet 구성 요소는 여러 번(`num_inference_steps` 만큼) 실행됩니다. 매번 UNet의 서로 다른 서브모듈이 순차적으로 온로드된 다음 필요에 따라 오프로드되므로 메모리 이동 횟수가 많습니다.
+
+
+또 다른 최적화 방법인 모델 오프로딩을 사용하는 것을 고려하십시오. 이는 훨씬 빠르지만 메모리 절약이 크지는 않습니다.
+
+
+또한 attention slicing과 결합하면 최소한의 메모리(2GB 미만)로도 동작할 수 있습니다.
+
+
+```Python
+import torch
+from diffusers import StableDiffusionPipeline
+
+pipe = StableDiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+
+ torch_dtype=torch.float16,
+)
+
+prompt = "a photo of an astronaut riding a horse on mars"
+pipe.enable_sequential_cpu_offload()
+pipe.enable_attention_slicing(1)
+
+image = pipe(prompt).images[0]
+```
+
+**참고**: `enable_sequential_cpu_offload()`를 사용할 때, 미리 파이프라인을 CUDA로 이동하지 **않는** 것이 중요합니다. 그렇지 않으면 메모리 소비의 이득이 최소화됩니다. 더 많은 정보를 위해 [이 이슈](https://github.com/huggingface/diffusers/issues/1934)를 보세요.
+
+
+## 빠른 추론과 메모리 절약을 위한 모델 오프로딩
+
+[순차적 CPU 오프로딩](#sequential_offloading)은 이전 섹션에서 설명한 것처럼 많은 메모리를 보존하지만 필요에 따라 서브모듈을 GPU로 이동하고 새 모듈이 실행될 때 즉시 CPU로 반환되기 때문에 추론 속도가 느려집니다.
+
+전체 모델 오프로딩은 각 모델의 구성 요소인 _modules_을 처리하는 대신, 전체 모델을 GPU로 이동하는 대안입니다. 이로 인해 추론 시간에 미치는 영향은 미미하지만(파이프라인을 'cuda'로 이동하는 것과 비교하여) 여전히 약간의 메모리를 절약할 수 있습니다.
+
+이 시나리오에서는 파이프라인의 주요 구성 요소 중 하나만(일반적으로 텍스트 인코더, unet 및 vae) GPU에 있고, 나머지는 CPU에서 대기할 것입니다.
+여러 반복을 위해 실행되는 UNet과 같은 구성 요소는 더 이상 필요하지 않을 때까지 GPU에 남아 있습니다.
+
+이 기능은 아래와 같이 파이프라인에서 `enable_model_cpu_offload()`를 호출하여 활성화할 수 있습니다.
+
+```Python
+import torch
+from diffusers import StableDiffusionPipeline
+
+pipe = StableDiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+ torch_dtype=torch.float16,
+)
+
+prompt = "a photo of an astronaut riding a horse on mars"
+pipe.enable_model_cpu_offload()
+image = pipe(prompt).images[0]
+```
+
+이는 추가적인 메모리 절약을 위한 attention slicing과도 호환됩니다.
+
+```Python
+import torch
+from diffusers import StableDiffusionPipeline
+
+pipe = StableDiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+ torch_dtype=torch.float16,
+)
+
+prompt = "a photo of an astronaut riding a horse on mars"
+pipe.enable_model_cpu_offload()
+pipe.enable_attention_slicing(1)
+
+image = pipe(prompt).images[0]
+```
+
+
+이 기능을 사용하려면 `accelerate` 버전 0.17.0 이상이 필요합니다.
+
+
+## Channels Last 메모리 형식 사용하기
+
+Channels Last 메모리 형식은 NCHW 텐서를 메모리에 배치하는 또 다른 방식으로, 차원 순서 정보를 보존합니다.
+Channels Last 텐서는 채널이 가장 조밀한 차원이 되도록 정렬됩니다(이미지를 픽셀 단위로 저장하는 방식).
+현재 모든 연산자가 Channels Last 형식을 지원하는 것은 아니어서 성능이 저하될 수 있으므로, 직접 사용해 보고 모델에 잘 작동하는지 확인하는 것이 좋습니다.
+
+
+예를 들어 파이프라인의 UNet 모델이 Channels Last 형식을 사용하도록 설정하려면 다음을 사용할 수 있습니다:
+
+```python
+print(pipe.unet.conv_out.state_dict()["weight"].stride()) # (2880, 9, 3, 1)
+pipe.unet.to(memory_format=torch.channels_last) # in-place 연산
+# 2번째 차원에서 스트라이드 1을 가지는 (2880, 1, 960, 320)로, 연산이 작동함을 증명합니다.
+print(pipe.unet.conv_out.state_dict()["weight"].stride())
+```
+
+## 추적(tracing)
+
+추적(tracing)은 예제 입력 텐서를 모델에 통과시키면서, 해당 입력이 모델의 레이어를 거칠 때 호출되는 연산을 캡처하여 실행 파일 또는 `ScriptFunction`을 반환하며, 이는 just-in-time 컴파일로 최적화됩니다.
+
+UNet 모델을 추적하기 위해 다음을 사용할 수 있습니다:
+
+```python
+import time
+import torch
+from diffusers import StableDiffusionPipeline
+import functools
+
+# torch 기울기 비활성화
+torch.set_grad_enabled(False)
+
+# 변수 설정
+n_experiments = 2
+unet_runs_per_experiment = 50
+
+
+# 입력 불러오기
+def generate_inputs():
+ sample = torch.randn((2, 4, 64, 64), device="cuda", dtype=torch.float16)
+ timestep = torch.rand(1, device="cuda", dtype=torch.float16) * 999
+ encoder_hidden_states = torch.randn((2, 77, 768), device="cuda", dtype=torch.float16)
+ return sample, timestep, encoder_hidden_states
+
+
+pipe = StableDiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+ torch_dtype=torch.float16,
+).to("cuda")
+unet = pipe.unet
+unet.eval()
+unet.to(memory_format=torch.channels_last) # Channels Last 메모리 형식 사용
+unet.forward = functools.partial(unet.forward, return_dict=False) # return_dict=False을 기본값으로 설정
+
+# 워밍업
+for _ in range(3):
+ with torch.inference_mode():
+ inputs = generate_inputs()
+ orig_output = unet(*inputs)
+
+# 추적
+print("tracing..")
+unet_traced = torch.jit.trace(unet, inputs)
+unet_traced.eval()
+print("done tracing")
+
+
+# 워밍업 및 그래프 최적화
+for _ in range(5):
+ with torch.inference_mode():
+ inputs = generate_inputs()
+ orig_output = unet_traced(*inputs)
+
+
+# 벤치마킹
+with torch.inference_mode():
+ for _ in range(n_experiments):
+ torch.cuda.synchronize()
+ start_time = time.time()
+ for _ in range(unet_runs_per_experiment):
+ orig_output = unet_traced(*inputs)
+ torch.cuda.synchronize()
+ print(f"unet traced inference took {time.time() - start_time:.2f} seconds")
+ for _ in range(n_experiments):
+ torch.cuda.synchronize()
+ start_time = time.time()
+ for _ in range(unet_runs_per_experiment):
+ orig_output = unet(*inputs)
+ torch.cuda.synchronize()
+ print(f"unet inference took {time.time() - start_time:.2f} seconds")
+
+# 모델 저장
+unet_traced.save("unet_traced.pt")
+```
+
+그 다음, 파이프라인의 `unet` 특성을 다음과 같이 추적된 모델로 바꿀 수 있습니다.
+
+```python
+from diffusers import StableDiffusionPipeline
+import torch
+from dataclasses import dataclass
+
+
+@dataclass
+class UNet2DConditionOutput:
+ sample: torch.FloatTensor
+
+
+pipe = StableDiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+ torch_dtype=torch.float16,
+).to("cuda")
+
+# jitted unet 사용
+unet_traced = torch.jit.load("unet_traced.pt")
+
+
+# pipe.unet 삭제
+class TracedUNet(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.in_channels = pipe.unet.in_channels
+ self.device = pipe.unet.device
+
+ def forward(self, latent_model_input, t, encoder_hidden_states):
+ sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
+ return UNet2DConditionOutput(sample=sample)
+
+
+pipe.unet = TracedUNet()
+
+prompt = "a photo of an astronaut riding a horse on mars"
+
+with torch.inference_mode():
+    image = pipe([prompt] * 1, num_inference_steps=50).images[0]
+```
+
+
+## Memory-efficient attention
+
+어텐션 블록의 대역폭을 최적화하는 최근 연구를 통해 속도가 크게 향상되고 GPU 메모리 사용량이 크게 줄었습니다.
+@tridao의 가장 최근의 플래시 어텐션: [code](https://github.com/HazyResearch/flash-attention), [paper](https://arxiv.org/pdf/2205.14135.pdf).
+
+배치 크기 1(프롬프트 1개)의 512x512 크기로 추론을 실행할 때 몇 가지 Nvidia GPU에서 얻은 속도 향상은 다음과 같습니다:
+
+| GPU | 기준 어텐션 FP16 | 메모리 효율적인 어텐션 FP16 |
+|------------------ |--------------------- |--------------------------------- |
+| NVIDIA Tesla T4 | 3.5it/s | 5.5it/s |
+| NVIDIA 3060 RTX | 4.6it/s | 7.8it/s |
+| NVIDIA A10G | 8.88it/s | 15.6it/s |
+| NVIDIA RTX A6000 | 11.7it/s | 21.09it/s |
+| NVIDIA TITAN RTX | 12.51it/s | 18.22it/s |
+| A100-SXM4-40GB | 18.6it/s | 29it/s |
+| A100-SXM-80GB | 18.7it/s | 29.5it/s |
+
+이를 활용하려면 다음을 만족해야 합니다:
+ - PyTorch > 1.12
+ - Cuda 사용 가능
+ - [xformers 라이브러리를 설치함](xformers)
+
+```python
+from diffusers import StableDiffusionPipeline
+import torch
+
+pipe = StableDiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+ torch_dtype=torch.float16,
+).to("cuda")
+
+pipe.enable_xformers_memory_efficient_attention()
+
+with torch.inference_mode():
+ sample = pipe("a small cat")
+
+# 선택: 이를 비활성화 하기 위해 다음을 사용할 수 있습니다.
+# pipe.disable_xformers_memory_efficient_attention()
+```
diff --git a/diffusers/docs/source/ko/optimization/habana.md b/diffusers/docs/source/ko/optimization/habana.md
new file mode 100644
index 0000000000000000000000000000000000000000..0f076245fb1c69b83026a36b820105d5de15c85a
--- /dev/null
+++ b/diffusers/docs/source/ko/optimization/habana.md
@@ -0,0 +1,71 @@
+
+
+# Habana Gaudi에서 Stable Diffusion을 사용하는 방법
+
+🤗 Diffusers는 🤗 [Optimum Habana](https://huggingface.co/docs/optimum/habana/usage_guides/stable_diffusion)를 통해서 Habana Gaudi와 호환됩니다.
+
+## 요구 사항
+
+- Optimum Habana 1.4 이상. 설치 방법은 [여기](https://huggingface.co/docs/optimum/habana/installation)를 참고하세요.
+- SynapseAI 1.8.
+
+
+## 추론 파이프라인
+
+Gaudi에서 Stable Diffusion 1 및 2로 이미지를 생성하려면 두 가지 구성 요소를 인스턴스화해야 합니다:
+- [`GaudiStableDiffusionPipeline`](https://huggingface.co/docs/optimum/habana/package_reference/stable_diffusion_pipeline)이 포함된 파이프라인. 이 파이프라인은 *텍스트-이미지 생성*을 지원합니다.
+- [`GaudiDDIMScheduler`](https://huggingface.co/docs/optimum/habana/package_reference/stable_diffusion_pipeline#optimum.habana.diffusers.GaudiDDIMScheduler)이 포함된 스케줄러. 이 스케줄러는 Habana Gaudi에 최적화되어 있습니다.
+
+파이프라인을 초기화할 때, HPU에 배포하기 위해 `use_habana=True`를 지정해야 합니다.
+또한 가능한 가장 빠른 생성을 위해 `use_hpu_graphs=True`로 **HPU 그래프**를 활성화해야 합니다.
+마지막으로, [Hugging Face Hub](https://huggingface.co/Habana)에서 다운로드할 수 있는 [Gaudi configuration](https://huggingface.co/docs/optimum/habana/package_reference/gaudi_config)을 지정해야 합니다.
+
+```python
+from optimum.habana import GaudiConfig
+from optimum.habana.diffusers import GaudiDDIMScheduler, GaudiStableDiffusionPipeline
+
+model_name = "stabilityai/stable-diffusion-2-base"
+scheduler = GaudiDDIMScheduler.from_pretrained(model_name, subfolder="scheduler")
+pipeline = GaudiStableDiffusionPipeline.from_pretrained(
+ model_name,
+ scheduler=scheduler,
+ use_habana=True,
+ use_hpu_graphs=True,
+ gaudi_config="Habana/stable-diffusion",
+)
+```
+
+파이프라인을 호출하여 하나 이상의 프롬프트에서 배치별로 이미지를 생성할 수 있습니다.
+
+```python
+outputs = pipeline(
+ prompt=[
+ "High quality photo of an astronaut riding a horse in space",
+ "Face of a yellow cat, high resolution, sitting on a park bench",
+ ],
+ num_images_per_prompt=10,
+ batch_size=4,
+)
+```
+
+더 많은 정보를 얻기 위해, Optimum Habana의 [문서](https://huggingface.co/docs/optimum/habana/usage_guides/stable_diffusion)와 공식 Github 저장소에 제공된 [예시](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion)를 확인하세요.
+
+
+## 벤치마크
+
+다음은 [Habana/stable-diffusion](https://huggingface.co/Habana/stable-diffusion) Gaudi 구성(혼합 정밀도 bf16/fp32)을 사용하는 Habana first-generation Gaudi 및 Gaudi2의 지연 시간입니다:
+
+| | Latency (배치 크기 = 1) | Throughput (배치 크기 = 8) |
+| ---------------------- |:------------------------:|:---------------------------:|
+| first-generation Gaudi | 4.29s | 0.283 images/s |
+| Gaudi2 | 1.54s | 0.904 images/s |
diff --git a/diffusers/docs/source/ko/optimization/mps.md b/diffusers/docs/source/ko/optimization/mps.md
new file mode 100644
index 0000000000000000000000000000000000000000..cd04d6d1103d5ecd83d7c983a99110928eb85c7e
--- /dev/null
+++ b/diffusers/docs/source/ko/optimization/mps.md
@@ -0,0 +1,71 @@
+
+
+# Apple Silicon (M1/M2)에서 Stable Diffusion을 사용하는 방법
+
+Diffusers는 Stable Diffusion 추론을 위해 PyTorch `mps`를 사용해 Apple 실리콘과 호환됩니다. 다음은 M1 또는 M2 컴퓨터에서 Stable Diffusion을 사용하기 위해 따라야 하는 단계입니다.
+
+## 요구 사항
+
+- Apple silicon (M1/M2) 하드웨어의 Mac 컴퓨터.
+- macOS 12.6 또는 이후 (13.0 또는 이후 추천).
+- Python arm64 버전
+- PyTorch 2.0(추천) 또는 1.13(`mps`를 지원하는 최소 버전). [공식 지침](https://pytorch.org/get-started/locally/)에 따라 `pip` 또는 `conda`로 설치할 수 있습니다.
+
+
+## 추론 파이프라인
+
+아래 코드는 익숙한 `to()` 인터페이스를 사용하여 `mps` 백엔드로 Stable Diffusion 파이프라인을 M1 또는 M2 장치로 이동하는 방법을 보여줍니다.
+
+
+
+
+**PyTorch 1.13을 사용 중일 때는** 추가 일회성 전달을 사용하여 파이프라인을 "프라이밍"하는 것을 추천합니다. 이것은 저희가 발견한 이상한 문제에 대한 임시 해결 방법입니다. 첫 번째 추론 전달은 후속 전달과 약간 다른 결과를 생성합니다. 이 전달은 한 번만 수행하면 되며, 추론 단계를 한 번만 사용하고 결과를 폐기해도 됩니다.
+
+
+
+이전 팁에서 설명한 것들을 포함한 여러 문제를 해결하므로 PyTorch 2 이상을 사용하는 것이 좋습니다.
+
+
+```python
+# `huggingface-cli login`에 로그인되어 있음을 확인
+from diffusers import DiffusionPipeline
+
+pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+pipe = pipe.to("mps")
+
+# 컴퓨터의 RAM이 64GB 이하일 때 추천
+pipe.enable_attention_slicing()
+
+prompt = "a photo of an astronaut riding a horse on mars"
+
+# 처음 "워밍업" 전달 (위 설명을 보세요)
+_ = pipe(prompt, num_inference_steps=1)
+
+# 결과는 워밍업 전달 후의 CPU 장치의 결과와 일치합니다.
+image = pipe(prompt).images[0]
+```
+
+## 성능 추천
+
+M1/M2 성능은 메모리 압력에 매우 민감합니다. 시스템은 필요한 경우 자동으로 스왑되지만 스왑할 때 성능이 크게 저하됩니다.
+
+
+특히 컴퓨터의 시스템 RAM이 64GB 미만이거나 512 × 512픽셀보다 큰 비표준 해상도에서 이미지를 생성하는 경우, 추론 중에 메모리 압력을 줄이고 스와핑을 방지하기 위해 *어텐션 슬라이싱*을 사용하는 것이 좋습니다. 어텐션 슬라이싱은 비용이 많이 드는 어텐션 작업을 한 번에 모두 수행하는 대신 여러 단계로 수행합니다. 일반적으로 범용 메모리(universal memory)가 없는 컴퓨터에서는 ~20%의 성능 저하가 있지만, 64GB 이상의 메모리를 가진 경우가 아니라면 대부분의 Apple Silicon 컴퓨터에서 오히려 *더 나은 성능*이 관찰되었습니다.
+
+```python
+pipe.enable_attention_slicing()
+```
+
+## Known Issues
+
+- 여러 프롬프트를 배치로 생성하는 것은 [충돌이 발생하거나 안정적으로 작동하지 않습니다](https://github.com/huggingface/diffusers/issues/363). 우리는 이것이 [PyTorch의 `mps` 백엔드](https://github.com/pytorch/pytorch/issues/84039)와 관련이 있다고 생각합니다. 이 문제는 해결되고 있지만 지금은 배치 대신 반복 방법을 사용하는 것이 좋습니다.
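+
+이 문제가 해결될 때까지의 임시 대안으로, 아래는 배치 대신 프롬프트를 하나씩 반복 처리하는 간단한 예시 스케치입니다(프롬프트 목록은 예시로 가정한 값입니다):
+
+```python
+from diffusers import DiffusionPipeline
+
+pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+pipe = pipe.to("mps")
+pipe.enable_attention_slicing()
+
+prompts = [
+    "a photo of an astronaut riding a horse on mars",
+    "a photo of a cat sitting on a park bench",
+]
+
+# 배치로 전달하는 대신 프롬프트를 하나씩 처리합니다
+images = [pipe(prompt).images[0] for prompt in prompts]
+```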
\ No newline at end of file
diff --git a/diffusers/docs/source/ko/optimization/onnx.md b/diffusers/docs/source/ko/optimization/onnx.md
new file mode 100644
index 0000000000000000000000000000000000000000..d52110b8c1fbd4b09614ce5b76e79e136b71e959
--- /dev/null
+++ b/diffusers/docs/source/ko/optimization/onnx.md
@@ -0,0 +1,65 @@
+
+
+
+# 추론을 위해 ONNX 런타임을 사용하는 방법
+
+🤗 Diffusers는 ONNX Runtime과 호환되는 Stable Diffusion 파이프라인을 제공합니다. 이를 통해 ONNX를 지원하는 모든 하드웨어(CPU 포함)에서, PyTorch의 가속 버전을 사용할 수 없는 경우에도 Stable Diffusion을 실행할 수 있습니다.
+
+## 설치
+
+다음 명령어로 ONNX Runtime를 지원하는 🤗 Optimum를 설치합니다:
+
+```bash
+pip install optimum["onnxruntime"]
+```
+
+## Stable Diffusion 추론
+
+아래 코드는 ONNX 런타임을 사용하는 방법을 보여줍니다. `StableDiffusionPipeline` 대신 `ORTStableDiffusionPipeline`을 사용해야 합니다.
+PyTorch 모델을 불러오고 즉시 ONNX 형식으로 변환하려는 경우 `export=True`로 설정합니다.
+
+```python
+from optimum.onnxruntime import ORTStableDiffusionPipeline
+
+model_id = "runwayml/stable-diffusion-v1-5"
+pipe = ORTStableDiffusionPipeline.from_pretrained(model_id, export=True)
+prompt = "a photo of an astronaut riding a horse on mars"
+images = pipe(prompt).images[0]
+pipe.save_pretrained("./onnx-stable-diffusion-v1-5")
+```
+
+파이프라인을 ONNX 형식으로 오프라인으로 내보내고 나중에 추론에 사용하려는 경우,
+[`optimum-cli export`](https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#exporting-a-model-to-onnx-using-the-cli) 명령어를 사용할 수 있습니다:
+
+```bash
+optimum-cli export onnx --model runwayml/stable-diffusion-v1-5 sd_v15_onnx/
+```
+
+그 다음 추론을 수행합니다:
+
+```python
+from optimum.onnxruntime import ORTStableDiffusionPipeline
+
+model_id = "sd_v15_onnx"
+pipe = ORTStableDiffusionPipeline.from_pretrained(model_id)
+prompt = "a photo of an astronaut riding a horse on mars"
+images = pipe(prompt).images[0]
+```
+
+위 예시에서는 `export=True`를 지정할 필요가 없었다는 점에 유의하세요. 이미 ONNX 형식으로 내보낸 파이프라인을 불러오기 때문입니다.
+
+[Optimum 문서](https://huggingface.co/docs/optimum/)에서 더 많은 예시를 찾을 수 있습니다.
+
+## 알려진 이슈들
+
+- 여러 프롬프트를 배치로 생성하면 너무 많은 메모리가 사용되는 것 같습니다. 이를 조사하는 동안, 배치 대신 반복 방법이 필요할 수도 있습니다.
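+
+문제가 해결될 때까지의 임시 대안으로, 아래는 배치 대신 프롬프트를 하나씩 처리하는 간단한 예시 스케치입니다(프롬프트 목록과 파일 이름은 예시로 가정한 값입니다):
+
+```python
+from optimum.onnxruntime import ORTStableDiffusionPipeline
+
+pipe = ORTStableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", export=True)
+
+prompts = [
+    "a photo of an astronaut riding a horse on mars",
+    "a photo of a cat sitting on a park bench",
+]
+
+# 배치 대신 프롬프트를 하나씩 처리하여 메모리 사용량을 낮게 유지합니다
+for i, prompt in enumerate(prompts):
+    image = pipe(prompt).images[0]
+    image.save(f"onnx_image_{i}.png")
+```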
diff --git a/diffusers/docs/source/ko/optimization/open_vino.md b/diffusers/docs/source/ko/optimization/open_vino.md
new file mode 100644
index 0000000000000000000000000000000000000000..cb279909f61840c3e7c4b99e4f6edda132cd563b
--- /dev/null
+++ b/diffusers/docs/source/ko/optimization/open_vino.md
@@ -0,0 +1,39 @@
+
+
+# 추론을 위한 OpenVINO 사용 방법
+
+🤗 [Optimum](https://github.com/huggingface/optimum-intel)은 OpenVINO와 호환되는 Stable Diffusion 파이프라인을 제공합니다.
+이제 다양한 Intel 프로세서에서 OpenVINO Runtime으로 쉽게 추론을 수행할 수 있습니다. ([여기](https://docs.openvino.ai/latest/openvino_docs_OV_UG_supported_plugins_Supported_Devices.html)서 지원되는 전 기기 목록을 확인하세요).
+
+## 설치
+
+다음 명령어로 🤗 Optimum을 설치합니다:
+
+```bash
+pip install optimum["openvino"]
+```
+
+## Stable Diffusion 추론
+
+OpenVINO 모델을 불러오고 OpenVINO 런타임으로 추론을 실행하려면 `StableDiffusionPipeline`을 `OVStableDiffusionPipeline`으로 교체해야 합니다. PyTorch 모델을 불러오고 즉시 OpenVINO 형식으로 변환하려는 경우 `export=True`로 설정합니다.
+
+```python
+from optimum.intel.openvino import OVStableDiffusionPipeline
+
+model_id = "runwayml/stable-diffusion-v1-5"
+pipe = OVStableDiffusionPipeline.from_pretrained(model_id, export=True)
+prompt = "a photo of an astronaut riding a horse on mars"
+images = pipe(prompt).images[0]
+```
+
+[Optimum 문서](https://huggingface.co/docs/optimum/intel/inference#export-and-inference-of-stable-diffusion-models)에서 (정적 reshaping과 모델 컴파일 등의) 더 많은 예시들을 찾을 수 있습니다.
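+
+예를 들어, 아래는 정적 reshaping과 모델 컴파일을 적용하는 간단한 스케치입니다. 메서드 이름과 인자는 Optimum Intel 문서를 따른다고 가정하며, 해상도와 배치 크기는 예시 값입니다:
+
+```python
+from optimum.intel.openvino import OVStableDiffusionPipeline
+
+model_id = "runwayml/stable-diffusion-v1-5"
+pipe = OVStableDiffusionPipeline.from_pretrained(model_id, export=True)
+
+# 입력 크기를 고정(정적 reshaping)한 뒤 미리 컴파일하면 첫 추론의 오버헤드를 줄일 수 있습니다
+pipe.reshape(batch_size=1, height=512, width=512, num_images_per_prompt=1)
+pipe.compile()
+
+prompt = "a photo of an astronaut riding a horse on mars"
+image = pipe(prompt, height=512, width=512).images[0]
+```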
diff --git a/diffusers/docs/source/ko/optimization/opt_overview.md b/diffusers/docs/source/ko/optimization/opt_overview.md
new file mode 100644
index 0000000000000000000000000000000000000000..c322ee3156d325e27b57fd1587d61b00e66fe306
--- /dev/null
+++ b/diffusers/docs/source/ko/optimization/opt_overview.md
@@ -0,0 +1,17 @@
+
+
+# 개요
+
+고품질의 출력을 생성하는 것은 많은 계산이 필요합니다. 특히 노이즈가 많은 출력에서 노이즈가 적은 출력으로 나아가는 반복적인 각 스텝에서 그렇습니다. 🧨 Diffusers의 목표 중 하나는 모든 사람이 이 기술을 널리 이용할 수 있도록 하는 것이며, 여기에는 소비자 및 특수 하드웨어에서 빠른 추론을 가능하게 하는 것이 포함됩니다.
+
+이 섹션에서는 추론 속도를 최적화하고 메모리 소비를 줄이기 위한 반정밀(half-precision) 가중치 및 sliced attention과 같은 팁과 요령을 다룹니다. 또한 [`torch.compile`](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) 또는 [ONNX Runtime](https://onnxruntime.ai/docs/)을 사용하여 PyTorch 코드의 속도를 높이고, [xFormers](https://facebookresearch.github.io/xformers/)를 사용하여 memory-efficient attention을 활성화하는 방법을 배울 수 있습니다. Apple Silicon, Intel 또는 Habana 프로세서와 같은 특정 하드웨어에서 추론을 실행하기 위한 가이드도 있습니다.
\ No newline at end of file
diff --git a/diffusers/docs/source/ko/optimization/tome.md b/diffusers/docs/source/ko/optimization/tome.md
new file mode 100644
index 0000000000000000000000000000000000000000..43c59968d55ea5ca0a122de7c36c87a49a6403ea
--- /dev/null
+++ b/diffusers/docs/source/ko/optimization/tome.md
@@ -0,0 +1,121 @@
+
+
+# Token Merging (토큰 병합)
+
+Token Merging([Token Merging: Your ViT But Faster](https://arxiv.org/abs/2210.09461)에서 소개됨)은 트랜스포머 기반 네트워크의 forward pass에서 중복 토큰이나 패치를 점진적으로 병합하는 방식으로 작동합니다. 이를 통해 기반 네트워크의 추론 지연 시간을 단축할 수 있습니다.
+
+Token Merging(ToMe)이 출시된 후, 저자들은 [Fast Stable Diffusion을 위한 토큰 병합](https://arxiv.org/abs/2303.17604)을 발표하여 Stable Diffusion과 더 잘 호환되는 ToMe 버전을 소개했습니다. ToMe를 사용하면 [`DiffusionPipeline`]의 추론 지연 시간을 부드럽게 단축할 수 있습니다. 이 문서에서는 ToMe를 [`StableDiffusionPipeline`]에 적용하는 방법, 예상되는 속도 향상, [`StableDiffusionPipeline`]에서 ToMe를 사용할 때의 질적 측면에 대해 설명합니다.
+
+## ToMe 사용하기
+
+ToMe의 저자들은 [`tomesd`](https://github.com/dbolya/tomesd)라는 편리한 Python 라이브러리를 공개했는데, 이 라이브러리를 이용하면 [`DiffusionPipeline`]에 ToMe를 다음과 같이 적용할 수 있습니다:
+
+```diff
+from diffusers import StableDiffusionPipeline
+import torch
+import tomesd
+
+pipeline = StableDiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
+).to("cuda")
++ tomesd.apply_patch(pipeline, ratio=0.5)
+
+image = pipeline("a photo of an astronaut riding a horse on mars").images[0]
+```
+
+이것이 다입니다!
+
+`tomesd.apply_patch()`는 파이프라인 추론 속도와 생성된 토큰의 품질 사이의 균형을 맞출 수 있도록 [여러 개의 인자](https://github.com/dbolya/tomesd#usage)를 노출합니다. 이러한 인자 중 가장 중요한 것은 `ratio`(비율)입니다. `ratio`는 forward pass 중에 병합될 토큰의 수를 제어합니다. `tomesd`에 대한 자세한 내용은 [해당 리포지토리](https://github.com/dbolya/tomesd)와 [논문](https://arxiv.org/abs/2303.17604)을 참고하시기 바랍니다.
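+
+아래는 `ratio` 값을 조정하고 적용한 패치를 되돌리는 방법을 보여주는 간단한 예시 스케치입니다(`tomesd.remove_patch()` 사용은 `tomesd` 리포지토리의 사용법을 따른다고 가정합니다):
+
+```python
+import torch
+import tomesd
+from diffusers import StableDiffusionPipeline
+
+pipeline = StableDiffusionPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
+).to("cuda")
+
+# ratio를 높일수록 더 많은 토큰이 병합되어 속도는 빨라지지만 품질이 저하될 수 있습니다
+tomesd.apply_patch(pipeline, ratio=0.75)
+image = pipeline("a photo of an astronaut riding a horse on mars").images[0]
+
+# 패치를 제거하여 원래 동작으로 되돌립니다
+tomesd.remove_patch(pipeline)
+```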
+
+## `StableDiffusionPipeline`으로 `tomesd` 벤치마킹하기
+
+다양한 이미지 해상도에서 [xformers](https://huggingface.co/docs/diffusers/optimization/xformers)를 적용한 상태에서, [`StableDiffusionPipeline`]에 `tomesd`를 사용했을 때의 영향을 벤치마킹했습니다. 테스트 GPU 장치로 A100과 V100을 사용했으며 개발 환경은 다음과 같습니다(Python 3.8.5 사용):
+
+```bash
+- `diffusers` version: 0.15.1
+- Python version: 3.8.16
+- PyTorch version (GPU?): 1.13.1+cu116 (True)
+- Huggingface_hub version: 0.13.2
+- Transformers version: 4.27.2
+- Accelerate version: 0.18.0
+- xFormers version: 0.0.16
+- tomesd version: 0.1.2
+```
+
+벤치마킹에는 다음 스크립트를 사용했습니다: [https://gist.github.com/sayakpaul/27aec6bca7eb7b0e0aa4112205850335](https://gist.github.com/sayakpaul/27aec6bca7eb7b0e0aa4112205850335). 결과는 다음과 같습니다:
+
+### A100
+
+| 해상도 | 배치 크기 | Vanilla | ToMe | ToMe + xFormers | ToMe 속도 향상 (%) | ToMe + xFormers 속도 향상 (%) |
+| --- | --- | --- | --- | --- | --- | --- |
+| 512 | 10 | 6.88 | 5.26 | 4.69 | 23.54651163 | 31.83139535 |
+| | | | | | | |
+| 768 | 10 | OOM | 14.71 | 11 | | |
+| | 8 | OOM | 11.56 | 8.84 | | |
+| | 4 | OOM | 5.98 | 4.66 | | |
+| | 2 | 4.99 | 3.24 | 3.1 | 35.07014028 | 37.8757515 |
+| | 1 | 3.29 | 2.24 | 2.03 | 31.91489362 | 38.29787234 |
+| | | | | | | |
+| 1024 | 10 | OOM | OOM | OOM | | |
+| | 8 | OOM | OOM | OOM | | |
+| | 4 | OOM | 12.51 | 9.09 | | |
+| | 2 | OOM | 6.52 | 4.96 | | |
+| | 1 | 6.4 | 3.61 | 2.81 | 43.59375 | 56.09375 |
+
+***결과는 초 단위입니다. 속도 향상은 `Vanilla`과 비교해 계산됩니다.***
+
+### V100
+
+| 해상도 | 배치 크기 | Vanilla | ToMe | ToMe + xFormers | ToMe 속도 향상 (%) | ToMe + xFormers 속도 향상 (%) |
+| --- | --- | --- | --- | --- | --- | --- |
+| 512 | 10 | OOM | 10.03 | 9.29 | | |
+| | 8 | OOM | 8.05 | 7.47 | | |
+| | 4 | 5.7 | 4.3 | 3.98 | 24.56140351 | 30.1754386 |
+| | 2 | 3.14 | 2.43 | 2.27 | 22.61146497 | 27.70700637 |
+| | 1 | 1.88 | 1.57 | 1.57 | 16.4893617 | 16.4893617 |
+| | | | | | | |
+| 768 | 10 | OOM | OOM | 23.67 | | |
+| | 8 | OOM | OOM | 18.81 | | |
+| | 4 | OOM | 11.81 | 9.7 | | |
+| | 2 | OOM | 6.27 | 5.2 | | |
+| | 1 | 5.43 | 3.38 | 2.82 | 37.75322284 | 48.06629834 |
+| | | | | | | |
+| 1024 | 10 | OOM | OOM | OOM | | |
+| | 8 | OOM | OOM | OOM | | |
+| | 4 | OOM | OOM | 19.35 | | |
+| | 2 | OOM | 13 | 10.78 | | |
+| | 1 | OOM | 6.66 | 5.54 | | |
+
+위의 표에서 볼 수 있듯이, 이미지 해상도가 높을수록 `tomesd`를 사용한 속도 향상이 더욱 두드러집니다. 또한 `tomesd`를 사용하면 1024x1024와 같은 더 높은 해상도에서 파이프라인을 실행할 수 있다는 점도 흥미롭습니다.
+
+[`torch.compile()`](https://huggingface.co/docs/diffusers/optimization/torch2.0)을 사용하면 추론 속도를 더욱 높일 수 있습니다.
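+
+아래는 `tomesd`와 `torch.compile()`을 함께 사용하는 간단한 예시 스케치입니다. 두 기법의 조합이 모든 환경에서 동작한다고 보장할 수는 없으며, 설정 값은 예시로 가정한 것입니다:
+
+```python
+import torch
+import tomesd
+from diffusers import StableDiffusionPipeline
+
+pipeline = StableDiffusionPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
+).to("cuda")
+
+tomesd.apply_patch(pipeline, ratio=0.5)
+# UNet만 컴파일하여 추가적인 속도 향상을 시도합니다
+pipeline.unet = torch.compile(pipeline.unet)
+
+image = pipeline("a photo of an astronaut riding a horse on mars").images[0]
+```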
+
+## 품질
+
+[논문](https://arxiv.org/abs/2303.17604)에 보고된 바와 같이, ToMe는 생성된 이미지의 품질을 상당 부분 보존하면서 추론 속도를 높일 수 있습니다. `ratio`을 높이면 추론 속도를 더 높일 수 있지만, 이미지 품질이 저하될 수 있습니다.
+
+해당 설정을 사용하여 생성된 샘플의 품질을 테스트하기 위해, "Parti 프롬프트"([Parti](https://parti.research.google/)에서 소개)에서 몇 가지 프롬프트를 샘플링하고 다음 설정에서 [`StableDiffusionPipeline`]을 사용하여 추론을 수행했습니다:
+
+- Vanilla [`StableDiffusionPipeline`]
+- [`StableDiffusionPipeline`] + ToMe
+- [`StableDiffusionPipeline`] + ToMe + xformers
+
+생성된 샘플의 품질이 크게 저하되는 것을 발견하지 못했습니다. 다음은 샘플입니다:
+
+![tome-samples](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/tome/tome_samples.png)
+
+생성된 샘플은 [여기](https://wandb.ai/sayakpaul/tomesd-results/runs/23j4bj3i?workspace=)에서 확인할 수 있습니다. 이 실험을 수행하기 위해 [이 스크립트](https://gist.github.com/sayakpaul/8cac98d7f22399085a060992f411ecbd)를 사용했습니다.
\ No newline at end of file
diff --git a/diffusers/docs/source/ko/optimization/torch2.0.md b/diffusers/docs/source/ko/optimization/torch2.0.md
new file mode 100644
index 0000000000000000000000000000000000000000..0d0f1043d00be2fe1f05e9c58c5210f3faede48c
--- /dev/null
+++ b/diffusers/docs/source/ko/optimization/torch2.0.md
@@ -0,0 +1,445 @@
+
+
+# Diffusers에서의 PyTorch 2.0 가속화 지원
+
+`0.13.0` 버전부터 Diffusers는 [PyTorch 2.0](https://pytorch.org/get-started/pytorch-2.0/)에서의 최신 최적화를 지원합니다. 이는 다음을 포함합니다:
+1. memory-efficient attention을 사용한 가속화된 트랜스포머 지원 - `xformers`같은 추가적인 dependencies 필요 없음
+2. 추가 성능 향상을 위한 개별 모델에 대한 컴파일 기능 [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) 지원
+
+
+## 설치
+가속화된 어텐션 구현과 `torch.compile()`을 사용하기 위해, pip로 최신 버전의 PyTorch 2.0이 설치되어 있고 diffusers가 0.13.0 버전 이상인지 확인하세요. 아래 설명된 바와 같이, PyTorch 2.0이 활성화되어 있을 때 diffusers는 최적화된 어텐션 프로세서([`AttnProcessor2_0`](https://github.com/huggingface/diffusers/blob/1a5797c6d4491a879ea5285c4efc377664e0332d/src/diffusers/models/attention_processor.py#L798))를 사용합니다.
+
+```bash
+pip install --upgrade torch diffusers
+```
+
+## 가속화된 트랜스포머와 `torch.compile` 사용하기
+
+
+1. **가속화된 트랜스포머 구현**
+
+ PyTorch 2.0에는 [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention) 함수를 통해 최적화된 memory-efficient attention의 구현이 포함되어 있습니다. 이는 입력 및 GPU 유형에 따라 여러 최적화를 자동으로 활성화합니다. 이는 [xFormers](https://github.com/facebookresearch/xformers)의 `memory_efficient_attention`과 유사하지만 기본적으로 PyTorch에 내장되어 있습니다.
+
+ 이러한 최적화는 PyTorch 2.0이 설치되어 있고 `torch.nn.functional.scaled_dot_product_attention`을 사용할 수 있는 경우 Diffusers에서 기본적으로 활성화됩니다. 이를 사용하려면 `torch 2.0`을 설치하고 파이프라인을 사용하기만 하면 됩니다. 예를 들어:
+
+ ```Python
+ import torch
+ from diffusers import DiffusionPipeline
+
+ pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+ pipe = pipe.to("cuda")
+
+ prompt = "a photo of an astronaut riding a horse on mars"
+ image = pipe(prompt).images[0]
+ ```
+
+ 이를 명시적으로 활성화하려면(필수는 아님) 아래와 같이 수행할 수 있습니다.
+
+ ```diff
+ import torch
+ from diffusers import DiffusionPipeline
+ + from diffusers.models.attention_processor import AttnProcessor2_0
+
+ pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
+ + pipe.unet.set_attn_processor(AttnProcessor2_0())
+
+ prompt = "a photo of an astronaut riding a horse on mars"
+ image = pipe(prompt).images[0]
+ ```
+
+ 이 실행 과정은 `xFormers`만큼 빠르고 메모리적으로 효율적이어야 합니다. 자세한 내용은 [벤치마크](#benchmark)에서 확인하세요.
+
+ 파이프라인을 보다 deterministic으로 만들거나 파인 튜닝된 모델을 [Core ML](https://huggingface.co/docs/diffusers/v0.16.0/en/optimization/coreml#how-to-run-stable-diffusion-with-core-ml)과 같은 다른 형식으로 변환해야 하는 경우 바닐라 어텐션 프로세서 ([`AttnProcessor`](https://github.com/huggingface/diffusers/blob/1a5797c6d4491a879ea5285c4efc377664e0332d/src/diffusers/models/attention_processor.py#L402))로 되돌릴 수 있습니다. 일반 어텐션 프로세서를 사용하려면 [`~diffusers.UNet2DConditionModel.set_default_attn_processor`] 함수를 사용할 수 있습니다:
+
+ ```Python
+ import torch
+ from diffusers import DiffusionPipeline
+ from diffusers.models.attention_processor import AttnProcessor
+
+ pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
+ pipe.unet.set_default_attn_processor()
+
+ prompt = "a photo of an astronaut riding a horse on mars"
+ image = pipe(prompt).images[0]
+ ```
+
+2. **torch.compile**
+
+ 추가적인 속도 향상을 위해 새로운 `torch.compile` 기능을 사용할 수 있습니다. 파이프라인의 UNet은 일반적으로 계산 비용이 가장 크기 때문에 나머지 하위 모델(텍스트 인코더와 VAE)은 그대로 두고 `unet`을 `torch.compile`로 래핑합니다. 자세한 내용과 다른 옵션은 [torch 컴파일 문서](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html)를 참조하세요.
+
+ ```python
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+ images = pipe(prompt, num_inference_steps=steps, num_images_per_prompt=batch_size).images
+ ```
+
+ GPU 유형에 따라 `compile()`은 가속화된 트랜스포머 최적화를 통해 **5% - 300%**의 _추가 성능 향상_을 얻을 수 있습니다. 그러나 컴파일은 Ampere(A100, 3090), Ada(4090) 및 Hopper(H100)와 같은 최신 GPU 아키텍처에서 더 많은 성능 향상을 가져올 수 있음을 참고하세요.
+
+ 컴파일은 완료하는 데 약간의 시간이 걸리므로, 파이프라인을 한 번 준비한 다음 동일한 유형의 추론 작업을 여러 번 수행해야 하는 상황에 가장 적합합니다. 다른 이미지 크기에서 컴파일된 파이프라인을 호출하면 시간적 비용이 많이 들 수 있는 컴파일 작업이 다시 트리거됩니다.
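+
+ 예를 들어, 아래는 동일한 해상도로 반복 호출하여 재컴파일을 피하는 간단한 예시 스케치입니다(해상도와 체크포인트는 예시로 가정한 값입니다):
+
+ ```python
+ import torch
+ from diffusers import DiffusionPipeline
+
+ pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+ pipe = pipe.to("cuda")
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+
+ prompt = "a photo of an astronaut riding a horse on mars"
+
+ # 첫 호출에서 컴파일이 수행되므로 시간이 걸립니다
+ _ = pipe(prompt, height=512, width=512).images[0]
+
+ # 동일한 해상도로 다시 호출하면 컴파일이 다시 트리거되지 않습니다
+ image = pipe(prompt, height=512, width=512).images[0]
+ ```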
+
+
+## 벤치마크
+
+PyTorch 2.0의 효율적인 어텐션 구현과 `torch.compile`을 사용하여 가장 많이 사용되는 5개의 파이프라인에 대해 다양한 GPU와 배치 크기에 걸쳐 포괄적인 벤치마크를 수행했습니다. 여기서는 [`torch.compile()`이 최적으로 활용되도록 하는](https://github.com/huggingface/diffusers/pull/3313) `diffusers 0.17.0.dev0`을 사용했습니다.
+
+### 벤치마킹 코드
+
+#### Stable Diffusion text-to-image
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+path = "runwayml/stable-diffusion-v1-5"
+
+run_compile = True # Set True / False
+
+pipe = DiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16)
+pipe = pipe.to("cuda")
+pipe.unet.to(memory_format=torch.channels_last)
+
+if run_compile:
+ print("Run torch compile")
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+
+prompt = "ghibli style, a fantasy landscape with castles"
+
+for _ in range(3):
+ images = pipe(prompt=prompt).images
+```
+
+#### Stable Diffusion image-to-image
+
+```python
+from diffusers import StableDiffusionImg2ImgPipeline
+import requests
+import torch
+from PIL import Image
+from io import BytesIO
+
+url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+
+response = requests.get(url)
+init_image = Image.open(BytesIO(response.content)).convert("RGB")
+init_image = init_image.resize((512, 512))
+
+path = "runwayml/stable-diffusion-v1-5"
+
+run_compile = True # Set True / False
+
+pipe = StableDiffusionImg2ImgPipeline.from_pretrained(path, torch_dtype=torch.float16)
+pipe = pipe.to("cuda")
+pipe.unet.to(memory_format=torch.channels_last)
+
+if run_compile:
+ print("Run torch compile")
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+
+prompt = "ghibli style, a fantasy landscape with castles"
+
+for _ in range(3):
+ image = pipe(prompt=prompt, image=init_image).images[0]
+```
+
+#### Stable Diffusion - inpainting
+
+```python
+from diffusers import StableDiffusionInpaintPipeline
+import requests
+import torch
+from PIL import Image
+from io import BytesIO
+
+url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+
+def download_image(url):
+ response = requests.get(url)
+ return Image.open(BytesIO(response.content)).convert("RGB")
+
+
+img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+init_image = download_image(img_url).resize((512, 512))
+mask_image = download_image(mask_url).resize((512, 512))
+
+path = "runwayml/stable-diffusion-inpainting"
+
+run_compile = True # Set True / False
+
+pipe = StableDiffusionInpaintPipeline.from_pretrained(path, torch_dtype=torch.float16)
+pipe = pipe.to("cuda")
+pipe.unet.to(memory_format=torch.channels_last)
+
+if run_compile:
+ print("Run torch compile")
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+
+prompt = "ghibli style, a fantasy landscape with castles"
+
+for _ in range(3):
+ image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
+```
+
+#### ControlNet
+
+```python
+from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
+import requests
+import torch
+from PIL import Image
+from io import BytesIO
+
+url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+
+response = requests.get(url)
+init_image = Image.open(BytesIO(response.content)).convert("RGB")
+init_image = init_image.resize((512, 512))
+
+path = "runwayml/stable-diffusion-v1-5"
+
+run_compile = True # Set True / False
+controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
+pipe = StableDiffusionControlNetPipeline.from_pretrained(
+ path, controlnet=controlnet, torch_dtype=torch.float16
+)
+
+pipe = pipe.to("cuda")
+pipe.unet.to(memory_format=torch.channels_last)
+pipe.controlnet.to(memory_format=torch.channels_last)
+
+if run_compile:
+ print("Run torch compile")
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+ pipe.controlnet = torch.compile(pipe.controlnet, mode="reduce-overhead", fullgraph=True)
+
+prompt = "ghibli style, a fantasy landscape with castles"
+
+for _ in range(3):
+ image = pipe(prompt=prompt, image=init_image).images[0]
+```
+
+#### IF text-to-image + upscaling
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+run_compile = True # Set True / False
+
+pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-M-v1.0", variant="fp16", text_encoder=None, torch_dtype=torch.float16)
+pipe.to("cuda")
+pipe_2 = DiffusionPipeline.from_pretrained("DeepFloyd/IF-II-M-v1.0", variant="fp16", text_encoder=None, torch_dtype=torch.float16)
+pipe_2.to("cuda")
+pipe_3 = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16)
+pipe_3.to("cuda")
+
+
+pipe.unet.to(memory_format=torch.channels_last)
+pipe_2.unet.to(memory_format=torch.channels_last)
+pipe_3.unet.to(memory_format=torch.channels_last)
+
+if run_compile:
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+ pipe_2.unet = torch.compile(pipe_2.unet, mode="reduce-overhead", fullgraph=True)
+ pipe_3.unet = torch.compile(pipe_3.unet, mode="reduce-overhead", fullgraph=True)
+
+prompt = "the blue hulk"
+
+prompt_embeds = torch.randn((1, 2, 4096), dtype=torch.float16)
+neg_prompt_embeds = torch.randn((1, 2, 4096), dtype=torch.float16)
+
+for _ in range(3):
+ image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=neg_prompt_embeds, output_type="pt").images
+ image_2 = pipe_2(image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=neg_prompt_embeds, output_type="pt").images
+ image_3 = pipe_3(prompt=prompt, image=image, noise_level=100).images
+```
+
+PyTorch 2.0 및 `torch.compile()`로 얻을 수 있는 가능한 속도 향상에 대해, [Stable Diffusion text-to-image pipeline](StableDiffusionPipeline)에 대한 상대적인 속도 향상을 보여주는 차트를 5개의 서로 다른 GPU 제품군(배치 크기 4)에 대해 나타냅니다:
+
+![t2i_speedup](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/pt2_benchmarks/t2i_speedup.png)
+
+이 속도 향상이 위에 제시된 다른 파이프라인에 대해서도 어떻게 유지되는지 더 잘 이해할 수 있도록, 세 가지 배치 크기에 걸친 A100에서의 벤치마킹(PyTorch 2.0 nightly 및 `torch.compile()` 사용) 수치를 보여주는 차트를 제시합니다:
+
+![a100_numbers](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/pt2_benchmarks/a100_numbers.png)
+
+_(위 차트의 벤치마크 메트릭은 **초당 iteration 수(iterations/second)**입니다)_
+
+그러나 투명성을 위해 모든 벤치마킹 수치를 공개합니다!
+
+다음 표들에서는, **_초당 처리되는 iteration_** 수 측면에서의 결과를 보여줍니다.
+
+### A100 (batch size: 1)
+
+| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
+|:---:|:---:|:---:|:---:|:---:|
+| SD - txt2img | 21.66 | 23.13 | 44.03 | 49.74 |
+| SD - img2img | 21.81 | 22.40 | 43.92 | 46.32 |
+| SD - inpaint | 22.24 | 23.23 | 43.76 | 49.25 |
+| SD - controlnet | 15.02 | 15.82 | 32.13 | 36.08 |
+| IF | 20.21 / 13.84 / 24.00 | 20.12 / 13.70 / 24.03 | ❌ | 97.34 / 27.23 / 111.66 |
+
+### A100 (batch size: 4)
+
+| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
+|:---:|:---:|:---:|:---:|:---:|
+| SD - txt2img | 11.6 | 13.12 | 14.62 | 17.27 |
+| SD - img2img | 11.47 | 13.06 | 14.66 | 17.25 |
+| SD - inpaint | 11.67 | 13.31 | 14.88 | 17.48 |
+| SD - controlnet | 8.28 | 9.38 | 10.51 | 12.41 |
+| IF | 25.02 | 18.04 | ❌ | 48.47 |
+
+### A100 (batch size: 16)
+
+| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
+|:---:|:---:|:---:|:---:|:---:|
+| SD - txt2img | 3.04 | 3.6 | 3.83 | 4.68 |
+| SD - img2img | 2.98 | 3.58 | 3.83 | 4.67 |
+| SD - inpaint | 3.04 | 3.66 | 3.9 | 4.76 |
+| SD - controlnet | 2.15 | 2.58 | 2.74 | 3.35 |
+| IF | 8.78 | 9.82 | ❌ | 16.77 |
+
+### V100 (batch size: 1)
+
+| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
+|:---:|:---:|:---:|:---:|:---:|
+| SD - txt2img | 18.99 | 19.14 | 20.95 | 22.17 |
+| SD - img2img | 18.56 | 19.18 | 20.95 | 22.11 |
+| SD - inpaint | 19.14 | 19.06 | 21.08 | 22.20 |
+| SD - controlnet | 13.48 | 13.93 | 15.18 | 15.88 |
+| IF | 20.01 / 9.08 / 23.34 | 19.79 / 8.98 / 24.10 | ❌ | 55.75 / 11.57 / 57.67 |
+
+### V100 (batch size: 4)
+
+| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
+|:---:|:---:|:---:|:---:|:---:|
+| SD - txt2img | 5.96 | 5.89 | 6.83 | 6.86 |
+| SD - img2img | 5.90 | 5.91 | 6.81 | 6.82 |
+| SD - inpaint | 5.99 | 6.03 | 6.93 | 6.95 |
+| SD - controlnet | 4.26 | 4.29 | 4.92 | 4.93 |
+| IF | 15.41 | 14.76 | ❌ | 22.95 |
+
+### V100 (batch size: 16)
+
+| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
+|:---:|:---:|:---:|:---:|:---:|
+| SD - txt2img | 1.66 | 1.66 | 1.92 | 1.90 |
+| SD - img2img | 1.65 | 1.65 | 1.91 | 1.89 |
+| SD - inpaint | 1.69 | 1.69 | 1.95 | 1.93 |
+| SD - controlnet | 1.19 | 1.19 | OOM after warmup | 1.36 |
+| IF | 5.43 | 5.29 | ❌ | 7.06 |
+
+### T4 (batch size: 1)
+
+| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
+|:---:|:---:|:---:|:---:|:---:|
+| SD - txt2img | 6.9 | 6.95 | 7.3 | 7.56 |
+| SD - img2img | 6.84 | 6.99 | 7.04 | 7.55 |
+| SD - inpaint | 6.91 | 6.7 | 7.01 | 7.37 |
+| SD - controlnet | 4.89 | 4.86 | 5.35 | 5.48 |
+| IF | 17.42 / 2.47 / 18.52 | 16.96 / 2.45 / 18.69 | ❌ | 24.63 / 2.47 / 23.39 |
+
+### T4 (batch size: 4)
+
+| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
+|:---:|:---:|:---:|:---:|:---:|
+| SD - txt2img | 1.79 | 1.79 | 2.03 | 1.99 |
+| SD - img2img | 1.77 | 1.77 | 2.05 | 2.04 |
+| SD - inpaint | 1.81 | 1.82 | 2.09 | 2.09 |
+| SD - controlnet | 1.34 | 1.27 | 1.47 | 1.46 |
+| IF | 5.79 | 5.61 | ❌ | 7.39 |
+
+### T4 (batch size: 16)
+
+| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
+|:---:|:---:|:---:|:---:|:---:|
+| SD - txt2img | 2.34s | 2.30s | OOM after 2nd iteration | 1.99s |
+| SD - img2img | 2.35s | 2.31s | OOM after warmup | 2.00s |
+| SD - inpaint | 2.30s | 2.26s | OOM after 2nd iteration | 1.95s |
+| SD - controlnet | OOM after 2nd iteration | OOM after 2nd iteration | OOM after warmup | OOM after warmup |
+| IF * | 1.44 | 1.44 | ❌ | 1.94 |
+
+### RTX 3090 (batch size: 1)
+
+| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
+|:---:|:---:|:---:|:---:|:---:|
+| SD - txt2img | 22.56 | 22.84 | 23.84 | 25.69 |
+| SD - img2img | 22.25 | 22.61 | 24.1 | 25.83 |
+| SD - inpaint | 22.22 | 22.54 | 24.26 | 26.02 |
+| SD - controlnet | 16.03 | 16.33 | 17.38 | 18.56 |
+| IF | 27.08 / 9.07 / 31.23 | 26.75 / 8.92 / 31.47 | ❌ | 68.08 / 11.16 / 65.29 |
+
+### RTX 3090 (batch size: 4)
+
+| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
+|:---:|:---:|:---:|:---:|:---:|
+| SD - txt2img | 6.46 | 6.35 | 7.29 | 7.3 |
+| SD - img2img | 6.33 | 6.27 | 7.31 | 7.26 |
+| SD - inpaint | 6.47 | 6.4 | 7.44 | 7.39 |
+| SD - controlnet | 4.59 | 4.54 | 5.27 | 5.26 |
+| IF | 16.81 | 16.62 | ❌ | 21.57 |
+
+### RTX 3090 (batch size: 16)
+
+| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
+|:---:|:---:|:---:|:---:|:---:|
+| SD - txt2img | 1.7 | 1.69 | 1.93 | 1.91 |
+| SD - img2img | 1.68 | 1.67 | 1.93 | 1.9 |
+| SD - inpaint | 1.72 | 1.71 | 1.97 | 1.94 |
+| SD - controlnet | 1.23 | 1.22 | 1.4 | 1.38 |
+| IF | 5.01 | 5.00 | ❌ | 6.33 |
+
+### RTX 4090 (batch size: 1)
+
+| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
+|:---:|:---:|:---:|:---:|:---:|
+| SD - txt2img | 40.5 | 41.89 | 44.65 | 49.81 |
+| SD - img2img | 40.39 | 41.95 | 44.46 | 49.8 |
+| SD - inpaint | 40.51 | 41.88 | 44.58 | 49.72 |
+| SD - controlnet | 29.27 | 30.29 | 32.26 | 36.03 |
+| IF | 69.71 / 18.78 / 85.49 | 69.13 / 18.80 / 85.56 | ❌ | 124.60 / 26.37 / 138.79 |
+
+### RTX 4090 (batch size: 4)
+
+| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
+|:---:|:---:|:---:|:---:|:---:|
+| SD - txt2img | 12.62 | 12.84 | 15.32 | 15.59 |
+| SD - img2img | 12.61 | 12.79 | 15.35 | 15.66 |
+| SD - inpaint | 12.65 | 12.81 | 15.3 | 15.58 |
+| SD - controlnet | 9.1 | 9.25 | 11.03 | 11.22 |
+| IF | 31.88 | 31.14 | ❌ | 43.92 |
+
+### RTX 4090 (batch size: 16)
+
+| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
+|:---:|:---:|:---:|:---:|:---:|
+| SD - txt2img | 3.17 | 3.2 | 3.84 | 3.85 |
+| SD - img2img | 3.16 | 3.2 | 3.84 | 3.85 |
+| SD - inpaint | 3.17 | 3.2 | 3.85 | 3.85 |
+| SD - controlnet | 2.23 | 2.3 | 2.7 | 2.75 |
+| IF | 9.26 | 9.2 | ❌ | 13.31 |
+
+## 참고
+
+* 벤치마크 수행에 사용된 환경에 대한 자세한 내용은 [이 PR](https://github.com/huggingface/diffusers/pull/3313)을 참조하세요.
+* IF 파이프라인과 배치 크기 > 1의 경우, text-to-image 생성을 위한 첫 번째 IF 파이프라인에서만 배치 크기 > 1을 사용했으며 업스케일링에는 사용하지 않았습니다. 즉, 두 개의 업스케일링 파이프라인은 배치 크기 1을 사용했습니다.
+
+*Diffusers에서 `torch.compile()` 지원을 개선하는 데 도움을 준 PyTorch 팀의 [Horace He](https://github.com/Chillee)에게 감사드립니다.*
\ No newline at end of file
diff --git a/diffusers/docs/source/ko/optimization/xformers.md b/diffusers/docs/source/ko/optimization/xformers.md
new file mode 100644
index 0000000000000000000000000000000000000000..a8b9408fbe50b07e9cb1e566a0678e2e8ca52ea2
--- /dev/null
+++ b/diffusers/docs/source/ko/optimization/xformers.md
@@ -0,0 +1,36 @@
+
+
+# xFormers 설치하기
+
+추론과 학습 모두에 [xFormers](https://github.com/facebookresearch/xformers)를 사용하는 것이 좋습니다.
+자체 테스트에서 어텐션 블록에 적용된 최적화가 더 빠른 속도와 더 적은 메모리 소비로 이어지는 것을 확인했습니다.
+
+2023년 1월에 출시된 xFormers 버전 `0.0.16`부터 사전 빌드된 pip wheel을 사용하여 쉽게 설치할 수 있습니다:
+
+```bash
+pip install xformers
+```
+
+
+
+xFormers pip 패키지에는 최신 버전의 PyTorch가 필요합니다(xFormers 0.0.16의 경우 PyTorch 1.13.1). 이전 버전의 PyTorch를 사용해야 하는 경우 [프로젝트 지침](https://github.com/facebookresearch/xformers#installing-xformers)에 따라 소스에서 xFormers를 설치하는 것이 좋습니다.
+
+
+
+xFormers를 설치하면, [여기](fp16#memory-efficient-attention)서 설명한 것처럼 `enable_xformers_memory_efficient_attention()`을 사용하여 추론 속도를 높이고 메모리 소비를 줄일 수 있습니다.
+
+
+
+[이 이슈](https://github.com/huggingface/diffusers/issues/2234#issuecomment-1416931212)에 따르면 xFormers `v0.0.16`에서는 일부 GPU에서 학습(파인 튜닝 또는 Dreambooth)을 할 수 없습니다. 해당 문제가 발생하면 해당 코멘트를 참고해 development 버전을 설치하세요.
+
+
diff --git a/diffusers/docs/source/ko/quicktour.md b/diffusers/docs/source/ko/quicktour.md
new file mode 100644
index 0000000000000000000000000000000000000000..e256f6c932233c793e463bf968056c449bf65a32
--- /dev/null
+++ b/diffusers/docs/source/ko/quicktour.md
@@ -0,0 +1,313 @@
+
+[[open-in-colab]]
+
+# 훑어보기
+
+Diffusion 모델은 이미지나 오디오와 같은 관심 샘플들을 생성하기 위해 랜덤 가우시안 노이즈를 단계별로 제거하도록 학습됩니다. 이로 인해 생성 AI에 대한 관심이 매우 높아졌으며, 인터넷에서 diffusion 생성 이미지의 예를 본 적이 있을 것입니다. 🧨 Diffusers는 누구나 diffusion 모델들을 널리 이용할 수 있도록 하기 위한 라이브러리입니다.
+
+개발자든 일반 사용자든, 이 훑어보기는 🧨 Diffusers를 소개하고 빠르게 생성을 시작할 수 있도록 도와드립니다! 알아야 할 라이브러리의 주요 구성 요소는 크게 세 가지입니다:
+
+* [`DiffusionPipeline`]은 추론을 위해 사전 학습된 diffusion 모델에서 샘플을 빠르게 생성하도록 설계된 높은 수준의 엔드투엔드 클래스입니다.
+* Diffusion 시스템 생성을 위한 빌딩 블록으로 사용할 수 있는 널리 사용되는 사전 학습된 [model](./api/models) 아키텍처 및 모듈.
+* 다양한 [schedulers](./api/schedulers/overview) - 학습을 위해 노이즈를 추가하는 방법과 추론 중에 노이즈 제거된 이미지를 생성하는 방법을 제어하는 알고리즘입니다.
+
+훑어보기에서는 추론을 위해 [`DiffusionPipeline`]을 사용하는 방법을 보여준 다음, 모델과 스케줄러를 결합하여 [`DiffusionPipeline`] 내부에서 일어나는 일을 복제하는 방법을 안내합니다.
+
+
+
+이 훑어보기는 🧨 Diffusers 소개 [노트북](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/diffusers_intro.ipynb)의 간결한 버전으로, 빠르게 시작할 수 있도록 도와드립니다. Diffusers의 목표, 디자인 철학, 핵심 API에 대한 추가 세부 정보를 자세히 알아보려면 노트북을 확인하세요!
+
+
+
+시작하기 전에 필요한 라이브러리가 모두 설치되어 있는지 확인하세요:
+
+```py
+# 주석 풀어서 Colab에 필요한 라이브러리 설치하기.
+#!pip install --upgrade diffusers accelerate transformers
+```
+
+- [🤗 Accelerate](https://huggingface.co/docs/accelerate/index)는 추론 및 학습을 위한 모델 로딩 속도를 높여줍니다.
+- [🤗 Transformers](https://huggingface.co/docs/transformers/index)는 [Stable Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview)과 같이 가장 많이 사용되는 diffusion 모델을 실행하는 데 필요합니다.
+
+## DiffusionPipeline
+
+[`DiffusionPipeline`] 은 추론을 위해 사전 학습된 diffusion 시스템을 사용하는 가장 쉬운 방법입니다. 모델과 스케줄러를 포함하는 엔드 투 엔드 시스템입니다. 다양한 작업에 [`DiffusionPipeline`]을 바로 사용할 수 있습니다. 아래 표에서 지원되는 몇 가지 작업을 살펴보고, 지원되는 작업의 전체 목록은 [🧨 Diffusers Summary](./api/pipelines/overview#diffusers-summary) 표에서 확인할 수 있습니다.
+
+| **Task** | **Description** | **Pipeline**
+|------------------------------|--------------------------------------------------------------------------------------------------------------|-----------------|
+| Unconditional Image Generation | generate an image from Gaussian noise | [unconditional_image_generation](./using-diffusers/unconditional_image_generation) |
+| Text-Guided Image Generation | generate an image given a text prompt | [conditional_image_generation](./using-diffusers/conditional_image_generation) |
+| Text-Guided Image-to-Image Translation | adapt an image guided by a text prompt | [img2img](./using-diffusers/img2img) |
+| Text-Guided Image-Inpainting | fill the masked part of an image given the image, the mask and a text prompt | [inpaint](./using-diffusers/inpaint) |
+| Text-Guided Depth-to-Image Translation | adapt parts of an image guided by a text prompt while preserving structure via depth estimation | [depth2img](./using-diffusers/depth2img) |
+
+먼저 [`DiffusionPipeline`]의 인스턴스를 생성하고 다운로드할 파이프라인 체크포인트를 지정합니다.
+허깅페이스 허브에 저장된 모든 [checkpoint](https://huggingface.co/models?library=diffusers&sort=downloads)에 대해 [`DiffusionPipeline`]을 사용할 수 있습니다.
+이 훑어보기에서는 text-to-image 생성을 위한 [`stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) 체크포인트를 로드합니다.
+
+
+
+[Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion) 모델의 경우, 모델을 실행하기 전에 [라이선스](https://huggingface.co/spaces/CompVis/stable-diffusion-license)를 먼저 주의 깊게 읽어주세요. 🧨 Diffusers는 불쾌하거나 유해한 콘텐츠를 방지하기 위해 [`safety_checker`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py)를 구현하고 있지만, 모델의 향상된 이미지 생성 기능으로 인해 여전히 잠재적으로 유해한 콘텐츠가 생성될 수 있습니다.
+
+
+
+[`~DiffusionPipeline.from_pretrained`] 메서드로 모델 로드하기:
+
+```python
+>>> from diffusers import DiffusionPipeline
+
+>>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+```
+
+[`DiffusionPipeline`]은 모든 모델링, 토큰화, 스케줄링 컴포넌트를 다운로드하고 캐시합니다. Stable Diffusion 파이프라인은 무엇보다도 [`UNet2DConditionModel`]과 [`PNDMScheduler`]로 구성되어 있음을 알 수 있습니다:
+
+```py
+>>> pipeline
+StableDiffusionPipeline {
+ "_class_name": "StableDiffusionPipeline",
+ "_diffusers_version": "0.13.1",
+ ...,
+ "scheduler": [
+ "diffusers",
+ "PNDMScheduler"
+ ],
+ ...,
+ "unet": [
+ "diffusers",
+ "UNet2DConditionModel"
+ ],
+ "vae": [
+ "diffusers",
+ "AutoencoderKL"
+ ]
+}
+```
+
+이 모델은 약 14억 개의 파라미터로 구성되어 있으므로 GPU에서 파이프라인을 실행할 것을 강력히 권장합니다.
+PyTorch에서와 마찬가지로 제너레이터 객체를 GPU로 이동할 수 있습니다:
+
+```python
+>>> pipeline.to("cuda")
+```
+
+이제 `pipeline`에 텍스트 프롬프트를 전달하여 이미지를 생성한 다음 노이즈가 제거된 이미지에 액세스할 수 있습니다. 기본적으로 이미지 출력은 [`PIL.Image`](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class) 객체로 감싸집니다.
+
+```python
+>>> image = pipeline("An image of a squirrel in Picasso style").images[0]
+>>> image
+```
+
+
+
+
+
+`save`를 호출하여 이미지를 저장합니다:
+
+```python
+>>> image.save("image_of_squirrel_painting.png")
+```
+
+### 로컬 파이프라인
+
+파이프라인을 로컬에서 사용할 수도 있습니다. 유일한 차이점은 가중치를 먼저 다운로드해야 한다는 점입니다:
+
+```bash
+!git lfs install
+!git clone https://huggingface.co/runwayml/stable-diffusion-v1-5
+```
+
+그런 다음 저장된 가중치를 파이프라인에 로드합니다:
+
+```python
+>>> pipeline = DiffusionPipeline.from_pretrained("./stable-diffusion-v1-5")
+```
+
+이제 위 섹션에서와 같이 파이프라인을 실행할 수 있습니다.
+
+### 스케줄러 교체
+
+스케줄러마다 노이즈 제거 속도와 품질이 서로 다릅니다. 자신에게 가장 적합한 스케줄러를 찾는 가장 좋은 방법은 직접 사용해 보는 것입니다! 🧨 Diffusers의 주요 기능 중 하나는 스케줄러 간에 쉽게 전환이 가능하다는 것입니다. 예를 들어, 기본 스케줄러인 [`PNDMScheduler`]를 [`EulerDiscreteScheduler`]로 바꾸려면, [`~diffusers.ConfigMixin.from_config`] 메서드를 사용하여 로드하세요:
+
+```py
+>>> from diffusers import EulerDiscreteScheduler
+
+>>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
+```
+
+새 스케줄러로 이미지를 생성해보고 어떤 차이가 있는지 확인해 보세요!
+
+다음 섹션에서는 모델과 스케줄러라는 [`DiffusionPipeline`]을 구성하는 컴포넌트를 자세히 살펴보고 이러한 컴포넌트를 사용하여 고양이 이미지를 생성하는 방법을 배워보겠습니다.
+
+## 모델
+
+대부분의 모델은 노이즈가 있는 샘플을 가져와 각 시간 간격마다 노이즈가 적은 이미지와 입력 이미지 사이의 차이인 *노이즈 잔차*(다른 모델은 이전 샘플을 직접 예측하거나 속도 또는 [`v-prediction`](https://github.com/huggingface/diffusers/blob/5e5ce13e2f89ac45a0066cb3f369462a3cf1d9ef/src/diffusers/schedulers/scheduling_ddim.py#L110)을 예측하는 학습을 합니다)을 예측합니다. 모델을 믹스 앤 매치하여 다른 diffusion 시스템을 만들 수 있습니다.
+
+모델은 [`~ModelMixin.from_pretrained`] 메서드로 초기화되며, 이 메서드는 모델 가중치를 로컬에 캐시하여 다음에 모델을 로드할 때 더 빠르게 로드할 수 있습니다. 훑어보기에서는 고양이 이미지에 대해 학습된 체크포인트가 있는 기본적인 unconditional 이미지 생성 모델인 [`UNet2DModel`]을 로드합니다:
+
+```py
+>>> from diffusers import UNet2DModel
+
+>>> repo_id = "google/ddpm-cat-256"
+>>> model = UNet2DModel.from_pretrained(repo_id)
+```
+
+모델 매개변수에 액세스하려면 `model.config`를 호출합니다:
+
+```py
+>>> model.config
+```
+
+모델 구성은 🧊 고정된 🧊 딕셔너리로, 모델이 생성된 후에는 해당 매개 변수들을 변경할 수 없습니다. 이는 의도적인 것으로, 처음에 모델 아키텍처를 정의하는 데 사용된 매개변수는 동일하게 유지하면서 다른 매개변수는 추론 중에 조정할 수 있도록 하기 위한 것입니다.
+
+가장 중요한 매개변수들은 다음과 같습니다:
+
+* `sample_size`: 입력 샘플의 높이 및 너비 치수입니다.
+* `in_channels`: 입력 샘플의 입력 채널 수입니다.
+* `down_block_types` 및 `up_block_types`: UNet 아키텍처를 생성하는 데 사용되는 다운 및 업샘플링 블록의 유형.
+* `block_out_channels`: 다운샘플링 블록의 출력 채널 수. 업샘플링 블록의 입력 채널 수에 역순으로 사용되기도 합니다.
+* `layers_per_block`: 각 UNet 블록에 존재하는 ResNet 블록의 수입니다.
+
+추론에 모델을 사용하려면 랜덤 가우시안 노이즈로 이미지 모양을 만듭니다. 모델이 여러 개의 무작위 노이즈를 수신할 수 있으므로 `batch` 축, 입력 채널 수에 해당하는 `channel` 축, 이미지의 높이와 너비를 나타내는 `sample_size` 축이 있어야 합니다:
+
+```py
+>>> import torch
+
+>>> torch.manual_seed(0)
+
+>>> noisy_sample = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
+>>> noisy_sample.shape
+torch.Size([1, 3, 256, 256])
+```
+
+추론을 위해 모델에 노이즈가 있는 이미지와 `timestep`을 전달합니다. `timestep`은 입력 이미지의 노이즈 정도를 나타내며, 시작 부분에 더 많은 노이즈가 있고 끝 부분에 더 적은 노이즈가 있습니다. 이를 통해 모델이 diffusion 과정에서 시작 또는 끝에 더 가까운 위치를 결정할 수 있습니다. `sample` 메서드를 사용하여 모델 출력을 얻습니다:
+
+```py
+>>> with torch.no_grad():
+... noisy_residual = model(sample=noisy_sample, timestep=2).sample
+```
+
+하지만 실제 예를 생성하려면 노이즈 제거 프로세스를 안내할 스케줄러가 필요합니다. 다음 섹션에서는 모델을 스케줄러와 결합하는 방법에 대해 알아봅니다.
+
+## 스케줄러
+
+스케줄러는 모델 출력(이 경우 `noisy_residual`)이 주어졌을 때, 노이즈가 많은 샘플에서 노이즈가 적은 샘플로 전환하는 것을 관리합니다.
+
+
+
+🧨 Diffusers는 Diffusion 시스템을 구축하기 위한 툴박스입니다. [`DiffusionPipeline`]을 사용하면 미리 만들어진 Diffusion 시스템을 편리하게 시작할 수 있지만, 모델과 스케줄러 구성 요소를 개별적으로 선택하여 사용자 지정 Diffusion 시스템을 구축할 수도 있습니다.
+
+
+
+훑어보기의 경우, [`~diffusers.ConfigMixin.from_config`] 메서드를 사용하여 [`DDPMScheduler`]를 인스턴스화합니다:
+
+```py
+>>> from diffusers import DDPMScheduler
+
+>>> scheduler = DDPMScheduler.from_config(repo_id)
+>>> scheduler
+DDPMScheduler {
+ "_class_name": "DDPMScheduler",
+ "_diffusers_version": "0.13.1",
+ "beta_end": 0.02,
+ "beta_schedule": "linear",
+ "beta_start": 0.0001,
+ "clip_sample": true,
+ "clip_sample_range": 1.0,
+ "num_train_timesteps": 1000,
+ "prediction_type": "epsilon",
+ "trained_betas": null,
+ "variance_type": "fixed_small"
+}
+```
+
+
+
+💡 스케줄러가 구성에서 어떻게 인스턴스화되는지 주목하세요. 모델과 달리 스케줄러에는 학습 가능한 가중치가 없으며 매개변수도 없습니다!
+
+
+
+가장 중요한 매개변수는 다음과 같습니다:
+
+* `num_train_timesteps`: 노이즈 제거 프로세스의 길이, 즉 랜덤 가우스 노이즈를 데이터 샘플로 처리하는 데 필요한 타임스텝 수입니다.
+* `beta_schedule`: 추론 및 학습에 사용할 노이즈 스케줄 유형입니다.
+* `beta_start` 및 `beta_end`: 노이즈 스케줄의 시작 및 종료 노이즈 값입니다.
+
+노이즈가 약간 적은 이미지를 예측하려면 스케줄러의 [`~diffusers.DDPMScheduler.step`] 메서드에 모델 출력, `timestep`, 현재 `sample`을 전달하세요.
+
+```py
+>>> less_noisy_sample = scheduler.step(model_output=noisy_residual, timestep=2, sample=noisy_sample).prev_sample
+>>> less_noisy_sample.shape
+```
+
+`less_noisy_sample`을 다음 `timestep`으로 넘기면 노이즈가 더 줄어듭니다! 이제 이 모든 것을 한데 모아 전체 노이즈 제거 과정을 시각화해 보겠습니다.
+
+먼저 노이즈 제거된 이미지를 후처리하여 `PIL.Image`로 표시하는 함수를 만듭니다:
+
+```py
+>>> import PIL.Image
+>>> import numpy as np
+
+
+>>> def display_sample(sample, i):
+... image_processed = sample.cpu().permute(0, 2, 3, 1)
+... image_processed = (image_processed + 1.0) * 127.5
+... image_processed = image_processed.numpy().astype(np.uint8)
+
+... image_pil = PIL.Image.fromarray(image_processed[0])
+... display(f"Image at step {i}")
+... display(image_pil)
+```
+
+노이즈 제거 프로세스의 속도를 높이려면 입력과 모델을 GPU로 옮기세요:
+
+```py
+>>> model.to("cuda")
+>>> noisy_sample = noisy_sample.to("cuda")
+```
+
+이제 노이즈가 적은 샘플의 잔차를 예측하고 스케줄러로 노이즈가 적은 샘플을 계산하는 노이즈 제거 루프를 생성합니다:
+
+```py
+>>> import tqdm
+
+>>> sample = noisy_sample
+
+>>> for i, t in enumerate(tqdm.tqdm(scheduler.timesteps)):
+... # 1. predict noise residual
+... with torch.no_grad():
+... residual = model(sample, t).sample
+
+... # 2. compute less noisy image and set x_t -> x_t-1
+... sample = scheduler.step(residual, t, sample).prev_sample
+
+... # 3. optionally look at image
+... if (i + 1) % 50 == 0:
+... display_sample(sample, i + 1)
+```
+
+가만히 앉아서 노이즈만으로 고양이가 생성되는 것을 지켜보세요!😻
+
+
+
+
+
+## 다음 단계
+
+이번 훑어보기에서 🧨 Diffusers로 멋진 이미지를 만들어 보셨기를 바랍니다! 다음 단계로 넘어가세요:
+
+* [training](./tutorials/basic_training) 튜토리얼에서 모델을 학습하거나 파인튜닝하여 나만의 이미지를 생성할 수 있습니다.
+* 다양한 사용 사례는 공식 및 커뮤니티 [학습 또는 파인튜닝 스크립트](https://github.com/huggingface/diffusers/tree/main/examples#-diffusers-examples) 예시를 참조하세요.
+* 스케줄러 로드, 액세스, 변경 및 비교에 대한 자세한 내용은 [다른 스케줄러 사용](./using-diffusers/schedulers) 가이드에서 확인하세요.
+* [Stable Diffusion](./stable_diffusion) 가이드에서 프롬프트 엔지니어링, 속도 및 메모리 최적화, 고품질 이미지 생성을 위한 팁과 요령을 살펴보세요.
+* [GPU에서 파이토치 최적화](./optimization/fp16) 가이드와 [애플 실리콘(M1/M2)에서의 Stable Diffusion](./optimization/mps) 및 [ONNX 런타임](./optimization/onnx) 실행에 대한 추론 가이드를 통해 🧨 Diffuser 속도를 높이는 방법을 더 자세히 알아보세요.
\ No newline at end of file
diff --git a/diffusers/docs/source/ko/stable_diffusion.md b/diffusers/docs/source/ko/stable_diffusion.md
new file mode 100644
index 0000000000000000000000000000000000000000..65575700e77e7813ea01c302630743065376faf3
--- /dev/null
+++ b/diffusers/docs/source/ko/stable_diffusion.md
@@ -0,0 +1,279 @@
+
+
+# 효과적이고 효율적인 Diffusion
+
+[[open-in-colab]]
+
+특정 스타일로 이미지를 생성하거나 원하는 내용을 포함하도록 [`DiffusionPipeline`]을 설정하는 것은 까다로울 수 있습니다. 만족스러운 이미지를 얻기까지 [`DiffusionPipeline`]을 여러 번 실행해야 하는 경우가 많습니다. 그러나 무에서 유를 창조하는 것은, 특히 추론을 반복해서 실행하는 경우, 계산 집약적인 프로세스입니다.
+
+그렇기 때문에 파이프라인의 *계산*(속도) 및 *메모리*(GPU RAM) 효율성을 극대화하여 추론 주기 사이의 시간을 단축하고 더 빠르게 반복할 수 있도록 하는 것이 중요합니다.
+
+이 튜토리얼에서는 [`DiffusionPipeline`]을 사용하여 더 빠르고 효과적으로 생성하는 방법을 안내합니다.
+
+[`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) 모델을 불러와서 시작합니다:
+
+```python
+from diffusers import DiffusionPipeline
+
+model_id = "runwayml/stable-diffusion-v1-5"
+pipeline = DiffusionPipeline.from_pretrained(model_id)
+```
+
+예제 프롬프트는 "portrait of an old warrior chief" 이지만, 자유롭게 자신만의 프롬프트를 사용해도 됩니다:
+
+```python
+prompt = "portrait photo of a old warrior chief"
+```
+
+## 속도
+
+
+
+💡 GPU를 사용할 수 없는 경우, [Colab](https://colab.research.google.com/)과 같은 GPU 제공업체에서 무료로 GPU를 사용할 수 있습니다!
+
+
+
+추론 속도를 높이는 가장 간단한 방법 중 하나는, 다른 PyTorch 모듈과 마찬가지로 파이프라인을 GPU에 올리는 것입니다:
+
+```python
+pipeline = pipeline.to("cuda")
+```
+
+동일한 이미지를 사용하고 개선할 수 있는지 확인하려면 [`Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html)를 사용하고 [재현성](./using-diffusers/reproducibility)에 대한 시드를 설정하세요:
+
+```python
+import torch
+
+generator = torch.Generator("cuda").manual_seed(0)
+```
+
+이제 이미지를 생성할 수 있습니다:
+
+```python
+image = pipeline(prompt, generator=generator).images[0]
+image
+```
+
+
+
+
+
+이 프로세스는 T4 GPU에서 약 30초가 소요되었습니다(할당된 GPU가 T4보다 나은 경우 더 빠를 수 있음). 기본적으로 [`DiffusionPipeline`]은 50개의 추론 단계에 대해 전체 `float32` 정밀도로 추론을 실행합니다. `float16`과 같은 더 낮은 정밀도로 전환하거나 추론 단계를 더 적게 실행하여 속도를 높일 수 있습니다.
+
+`float16`으로 모델을 로드하고 이미지를 생성해 보겠습니다:
+
+
+```python
+import torch
+
+pipeline = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
+pipeline = pipeline.to("cuda")
+generator = torch.Generator("cuda").manual_seed(0)
+image = pipeline(prompt, generator=generator).images[0]
+image
+```
+
+
+
+
+
+이번에는 이미지를 생성하는 데 약 11초밖에 걸리지 않아 이전보다 3배 가까이 빨라졌습니다!
+
+
+
+💡 파이프라인은 항상 `float16`에서 실행할 것을 강력히 권장하며, 지금까지 출력 품질이 저하되는 경우는 거의 없었습니다.
+
+
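+참고로, fp16 가중치 변형(variant)을 제공하는 체크포인트라면 처음부터 반정밀도 가중치만 내려받을 수도 있습니다. 아래는 해당 리포지토리에 fp16 variant가 있다고 가정한 간단한 스케치입니다:
+
+```python
+# 가정: 리포지토리에 fp16 variant 가중치가 올라와 있는 경우의 스케치입니다
+pipeline = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, variant="fp16")
+pipeline = pipeline.to("cuda")
+```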
+
+또 다른 옵션은 추론 단계의 수를 줄이는 것입니다. 보다 효율적인 스케줄러를 선택하면 출력 품질 저하 없이 단계 수를 줄이는 데 도움이 될 수 있습니다. 현재 모델과 호환되는 스케줄러는 `compatibles` 메서드를 호출하여 [`DiffusionPipeline`]에서 찾을 수 있습니다:
+
+```python
+pipeline.scheduler.compatibles
+[
+ diffusers.schedulers.scheduling_lms_discrete.LMSDiscreteScheduler,
+ diffusers.schedulers.scheduling_unipc_multistep.UniPCMultistepScheduler,
+ diffusers.schedulers.scheduling_k_dpm_2_discrete.KDPM2DiscreteScheduler,
+ diffusers.schedulers.scheduling_deis_multistep.DEISMultistepScheduler,
+ diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler,
+ diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler,
+ diffusers.schedulers.scheduling_ddpm.DDPMScheduler,
+ diffusers.schedulers.scheduling_dpmsolver_singlestep.DPMSolverSinglestepScheduler,
+ diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteScheduler,
+ diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler,
+ diffusers.schedulers.scheduling_pndm.PNDMScheduler,
+ diffusers.schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteScheduler,
+ diffusers.schedulers.scheduling_ddim.DDIMScheduler,
+]
+```
+
+Stable Diffusion 모델은 일반적으로 약 50개의 추론 단계가 필요한 [`PNDMScheduler`]를 기본으로 사용하지만, [`DPMSolverMultistepScheduler`]와 같이 성능이 더 뛰어난 스케줄러는 약 20개 또는 25개의 추론 단계만 필요로 합니다. 새 스케줄러를 로드하려면 [`ConfigMixin.from_config`] 메서드를 사용합니다:
+
+```python
+from diffusers import DPMSolverMultistepScheduler
+
+pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+```
+
+`num_inference_steps`를 20으로 설정합니다:
+
+```python
+generator = torch.Generator("cuda").manual_seed(0)
+image = pipeline(prompt, generator=generator, num_inference_steps=20).images[0]
+image
+```
+
+
+
+
+
+추론 시간을 4초로 단축할 수 있었습니다! ⚡️
+
+## 메모리
+
+파이프라인 성능 향상의 또 다른 핵심은 메모리 사용량을 줄이는 것입니다. 보통 초당 생성되는 이미지 수를 최대화하려고 하기 때문에, 메모리 사용량을 줄이면 간접적으로 속도도 빨라집니다. 한 번에 생성할 수 있는 이미지 수를 확인하는 가장 쉬운 방법은 `OutOfMemoryError`(OOM)가 발생할 때까지 다양한 배치 크기를 시도해 보는 것입니다.
+
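+예를 들어, 아래와 같이 `OutOfMemoryError`를 잡아내면서 배치 크기를 늘려볼 수 있습니다. 간단한 스케치이며, 앞에서 정의한 `pipeline`과 `prompt`를 그대로 사용한다고 가정합니다:
+
+```python
+import torch
+
+# OOM이 발생할 때까지 배치 크기를 늘려봅니다
+for batch_size in (1, 2, 4, 8, 16):
+    try:
+        pipeline(batch_size * [prompt], num_inference_steps=20)
+        print(f"batch_size={batch_size} OK")
+    except torch.cuda.OutOfMemoryError:
+        print(f"batch_size={batch_size}에서 OOM이 발생했습니다")
+        break
+```
+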
+프롬프트 목록과 `Generators`에서 이미지 배치를 생성하는 함수를 만듭니다. 좋은 결과를 생성하는 경우 재사용할 수 있도록 각 `Generator`에 시드를 할당해야 합니다.
+
+```python
+def get_inputs(batch_size=1):
+    generator = [torch.Generator("cuda").manual_seed(i) for i in range(batch_size)]
+    prompts = batch_size * [prompt]
+    num_inference_steps = 20
+
+    return {"prompt": prompts, "generator": generator, "num_inference_steps": num_inference_steps}
+```
+
+또한 각 이미지 배치를 표시하는 함수도 필요합니다:
+
+```python
+from PIL import Image
+
+
+def image_grid(imgs, rows=2, cols=2):
+    w, h = imgs[0].size
+    grid = Image.new("RGB", size=(cols * w, rows * h))
+
+    for i, img in enumerate(imgs):
+        grid.paste(img, box=(i % cols * w, i // cols * h))
+    return grid
+```
+
+`batch_size=4`부터 시작해 얼마나 많은 메모리를 소비했는지 확인합니다:
+
+```python
+images = pipeline(**get_inputs(batch_size=4)).images
+image_grid(images)
+```
+
+RAM이 더 많은 GPU가 아니라면 위의 코드는 `OOM` 오류를 반환했을 것입니다! 메모리 대부분은 cross-attention 레이어가 차지합니다. 이 연산을 배치로 한꺼번에 실행하는 대신 순차적으로 실행하면 상당한 양의 메모리를 절약할 수 있습니다. [`~DiffusionPipeline.enable_attention_slicing`] 함수를 호출하도록 파이프라인을 구성하기만 하면 됩니다:
+
+
+```python
+pipeline.enable_attention_slicing()
+```
+
+이제 `batch_size`를 8로 늘려보세요!
+
+```python
+images = pipeline(**get_inputs(batch_size=8)).images
+image_grid(images, rows=2, cols=4)
+```
+
+
+
+
+
+이전에는 4개의 이미지를 배치로 생성할 수도 없었지만, 이제는 이미지당 약 3.5초 만에 8개의 이미지를 배치로 생성할 수 있습니다! 이는 아마도 품질 저하 없이 T4 GPU에서 가장 빠른 속도일 것입니다.
+
+## 품질
+
+지난 두 섹션에서는 `fp16`을 사용하여 파이프라인의 속도를 최적화하고, 더 성능이 좋은 스케줄러를 사용하여 추론 단계의 수를 줄이고, attention slicing을 활성화하여 메모리 소비를 줄이는 방법을 배웠습니다. 이제 생성된 이미지의 품질을 개선하는 방법에 대해 집중적으로 알아보겠습니다.
+
+
+### 더 나은 체크포인트
+
+가장 확실한 단계는 더 나은 체크포인트를 사용하는 것입니다. Stable Diffusion 모델은 좋은 출발점이며, 공식 출시 이후 몇 가지 개선된 버전도 출시되었습니다. 하지만 최신 버전을 사용한다고 해서 자동으로 더 나은 결과를 얻을 수 있는 것은 아닙니다. 여전히 다양한 체크포인트를 직접 실험해보고, [negative prompts](https://minimaxir.com/2022/11/stable-diffusion-negative-prompt/) 사용 등 약간의 조사를 통해 최상의 결과를 얻어야 합니다.
+
+이 분야가 성장함에 따라 특정 스타일을 연출할 수 있도록 세밀하게 조정된 고품질 체크포인트가 점점 더 많아지고 있습니다. [Hub](https://huggingface.co/models?library=diffusers&sort=downloads)와 [Diffusers Gallery](https://huggingface.co/spaces/huggingface-projects/diffusers-gallery)를 둘러보고 관심 있는 것을 찾아보세요!
+
+
+### 더 나은 파이프라인 구성 요소
+
+현재 파이프라인 구성 요소를 최신 버전으로 교체해 볼 수도 있습니다. Stability AI의 최신 [오토인코더(VAE)](https://huggingface.co/stabilityai/stable-diffusion-2-1/tree/main/vae)를 파이프라인에 로드하고 몇 가지 이미지를 생성해 보겠습니다:
+
+
+```python
+from diffusers import AutoencoderKL
+
+vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16).to("cuda")
+pipeline.vae = vae
+images = pipeline(**get_inputs(batch_size=8)).images
+image_grid(images, rows=2, cols=4)
+```
+
+
+
+
+
+### 더 나은 프롬프트 엔지니어링
+
+이미지를 생성하는 데 사용하는 텍스트 프롬프트는 *prompt engineering*이라고 할 정도로 매우 중요합니다. 프롬프트 엔지니어링 시 고려해야 할 몇 가지 사항은 다음과 같습니다:
+
+- 생성하려는 이미지 또는 유사한 이미지가 인터넷에 어떻게 저장되어 있는가?
+- 내가 원하는 스타일로 모델을 유도하기 위해 어떤 추가 세부 정보를 제공할 수 있는가?
+
+이를 염두에 두고 색상과 더 높은 품질의 디테일을 포함하도록 프롬프트를 개선해 봅시다:
+
+
+```python
+prompt += ", tribal panther make up, blue on red, side profile, looking away, serious eyes"
+prompt += " 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta"
+```
+
+새로운 프롬프트로 이미지 배치를 생성합니다:
+
+```python
+images = pipeline(**get_inputs(batch_size=8)).images
+image_grid(images, rows=2, cols=4)
+```
+
+
+
+
+
+꽤 인상적입니다! `1`의 시드를 가진 `Generator`에 해당하는 두 번째 이미지에 피사체의 나이에 대한 텍스트를 추가하여 조금 더 조정해 보겠습니다:
+
+```python
+prompts = [
+ "portrait photo of the oldest warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
+ "portrait photo of a old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
+ "portrait photo of a warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
+ "portrait photo of a young warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
+]
+
+generator = [torch.Generator("cuda").manual_seed(1) for _ in range(len(prompts))]
+images = pipeline(prompt=prompts, generator=generator, num_inference_steps=25).images
+image_grid(images)
+```
+
+
+
+
+
+## 다음 단계
+
+이 튜토리얼에서는 계산 및 메모리 효율을 높이고 생성된 출력의 품질을 개선하기 위해 [`DiffusionPipeline`]을 최적화하는 방법을 배웠습니다. 파이프라인을 더 빠르게 만드는 데 관심이 있다면 다음 리소스를 살펴보세요:
+
+- [PyTorch 2.0](./optimization/torch2.0) 및 [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html)이 어떻게 추론 속도를 5~300% 향상시킬 수 있는지 알아보세요. A100 GPU에서는 추론 속도가 최대 50%까지 빨라질 수 있습니다!
+- PyTorch 2를 사용할 수 없는 경우, [xFormers](./optimization/xformers)를 설치하는 것이 좋습니다. 메모리 효율적인 어텐션 메커니즘은 PyTorch 1.13.1과 함께 사용하면 속도가 빨라지고 메모리 소비가 줄어듭니다.
+- 모델 오프로딩과 같은 다른 최적화 기법은 [이 가이드](./optimization/fp16)에서 다루고 있습니다.
\ No newline at end of file
diff --git a/diffusers/docs/source/ko/training/adapt_a_model.md b/diffusers/docs/source/ko/training/adapt_a_model.md
new file mode 100644
index 0000000000000000000000000000000000000000..2b035a449c1d1119b48774949c2cfd330e1d77c9
--- /dev/null
+++ b/diffusers/docs/source/ko/training/adapt_a_model.md
@@ -0,0 +1,54 @@
+
+
+# 새로운 작업에 대한 모델을 적용하기
+
+많은 diffusion 시스템은 같은 구성 요소들을 공유하므로 한 작업에 대해 사전학습된 모델을 완전히 다른 작업에 적용할 수 있습니다.
+
+이 가이드는 사전학습된 [`UNet2DConditionModel`]의 아키텍처를 초기화하고 수정하여, 사전학습된 text-to-image 모델을 인페인팅에 적용하는 방법을 보여줍니다.
+
+## UNet2DConditionModel 파라미터 구성
+
+[`UNet2DConditionModel`]은 [input sample](https://huggingface.co/docs/diffusers/v0.16.0/en/api/models#diffusers.UNet2DConditionModel.in_channels)에서 4개의 채널을 기본적으로 허용합니다. 예를 들어, [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5)와 같은 사전학습된 text-to-image 모델을 불러오고 `in_channels`의 수를 확인합니다:
+
+```py
+from diffusers import StableDiffusionPipeline
+
+pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+pipeline.unet.config["in_channels"]
+4
+```
+
+인페인팅은 입력 샘플에 9개의 채널이 필요합니다. [`runwayml/stable-diffusion-inpainting`](https://huggingface.co/runwayml/stable-diffusion-inpainting)와 같은 사전학습된 인페인팅 모델에서 이 값을 확인할 수 있습니다:
+
+```py
+from diffusers import StableDiffusionPipeline
+
+pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
+pipeline.unet.config["in_channels"]
+9
+```
+
+text-to-image 모델을 인페인팅에 적용하려면 `in_channels` 수를 4에서 9로 수정해야 합니다.
+
+사전학습된 text-to-image 모델의 가중치로 [`UNet2DConditionModel`]을 초기화하되, `in_channels`를 9로 수정하세요. `in_channels`의 수를 바꾸면 가중치 크기가 달라지므로, 크기가 맞지 않는다는 오류를 피하기 위해 `ignore_mismatched_sizes=True`와 `low_cpu_mem_usage=False`를 설정해야 합니다.
+
+```py
+from diffusers import UNet2DConditionModel
+
+model_id = "runwayml/stable-diffusion-v1-5"
+unet = UNet2DConditionModel.from_pretrained(
+ model_id, subfolder="unet", in_channels=9, low_cpu_mem_usage=False, ignore_mismatched_sizes=True
+)
+```
+
+Text-to-image 모델의 다른 구성 요소 가중치는 체크포인트에서 그대로 초기화되지만, `unet`의 입력 채널 가중치(`conv_in.weight`)는 랜덤하게 초기화됩니다. 그대로 두면 모델이 노이즈만 반환하므로, 인페인팅을 위해 모델을 파인튜닝하는 것이 중요합니다.
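+
+새 아키텍처가 의도대로 초기화되었는지는 간단히 확인해볼 수 있습니다. 아래는 Stable Diffusion v1-5의 UNet 구성을 가정한 간단한 예시입니다:
+
+```py
+# 입력 컨볼루션이 9채널 입력을 받도록 바뀌었는지 확인합니다
+unet.conv_in.weight.shape  # torch.Size([320, 9, 3, 3])가 될 것으로 예상됩니다
+unet.config["in_channels"]  # 9
+```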
diff --git a/diffusers/docs/source/ko/training/controlnet.md b/diffusers/docs/source/ko/training/controlnet.md
new file mode 100644
index 0000000000000000000000000000000000000000..46632fb8d18d2dbaa73b7690c1da212114d61a67
--- /dev/null
+++ b/diffusers/docs/source/ko/training/controlnet.md
@@ -0,0 +1,331 @@
+
+
+# ControlNet
+
+[Adding Conditional Control to Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.05543) (ControlNet)은 Lvmin Zhang과 Maneesh Agrawala에 의해 쓰여졌습니다.
+
+이 예시는 [원본 ControlNet 리포지토리의 학습 예시](https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md)에 기반합니다. 여기서는 원을 색으로 채우는 [작은 합성 데이터셋](https://huggingface.co/datasets/fusing/fill50k)을 사용해 ControlNet을 학습합니다.
+
+## 의존성 설치하기
+
+아래의 스크립트를 실행하기 전에, 라이브러리의 학습 의존성을 설치해야 합니다.
+
+
+
+가장 최신 버전의 예시 스크립트를 성공적으로 실행하려면, 소스에서 설치하고 설치를 최신 상태로 유지하는 것을 강력히 추천합니다. 예시 스크립트는 자주 업데이트되며 예시별 요구사항을 별도로 설치해야 하기 때문입니다.
+
+
+
+위 사항을 만족시키기 위해서, 새로운 가상환경에서 다음 일련의 스텝을 실행하세요:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+그 다음에는 [예시 폴더](https://github.com/huggingface/diffusers/tree/main/examples/controlnet)로 이동합니다.
+
+```bash
+cd examples/controlnet
+```
+
+이제 실행하세요:
+
+```bash
+pip install -r requirements.txt
+```
+
+[🤗Accelerate](https://github.com/huggingface/accelerate/) 환경을 초기화 합니다:
+
+```bash
+accelerate config
+```
+
+혹은 여러분의 환경이 무엇인지 몰라도 기본적인 🤗Accelerate 구성으로 초기화할 수 있습니다:
+
+```bash
+accelerate config default
+```
+
+혹은 당신의 환경이 노트북 같은 상호작용하는 쉘을 지원하지 않는다면, 아래의 코드로 초기화 할 수 있습니다:
+
+```python
+from accelerate.utils import write_basic_config
+
+write_basic_config()
+```
+
+## 원을 채우는 데이터셋
+
+원본 데이터셋은 ControlNet [repo](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip)에 올라와 있지만, 학습 스크립트에서 데이터 로딩을 처리할 수 있도록 🤗 Datasets과 호환되는 형태로 [여기](https://huggingface.co/datasets/fusing/fill50k)에 다시 올려두었습니다.
+
+우리의 학습 예시는 원래 ControlNet 학습에 쓰였던 [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5)를 사용합니다. 그렇지만 ControlNet은 [`CompVis/stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4)나 [`stabilityai/stable-diffusion-2-1`](https://huggingface.co/stabilityai/stable-diffusion-2-1) 같은 어떤 Stable Diffusion 모델이든 보강(augment)하도록 학습할 수 있습니다.
+
+자체 데이터셋을 사용하기 위해서는 [학습을 위한 데이터셋 생성하기](create_dataset) 가이드를 확인하세요.
+
+## 학습
+
+이 학습에 사용될 다음 이미지들을 다운로드하세요:
+
+```sh
+wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png
+
+wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
+```
+
+`MODEL_NAME` 환경 변수 (Hub 모델 리포지토리 아이디 혹은 모델 가중치가 있는 디렉토리로 가는 주소)를 명시하고 [`pretrained_model_name_or_path`](https://huggingface.co/docs/diffusers/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path) 인자로 환경변수를 보냅니다.
+
+학습 스크립트는 당신의 리포지토리에 `diffusion_pytorch_model.bin` 파일을 생성하고 저장합니다.
+
+```bash
+export MODEL_DIR="runwayml/stable-diffusion-v1-5"
+export OUTPUT_DIR="path to save model"
+
+accelerate launch train_controlnet.py \
+ --pretrained_model_name_or_path=$MODEL_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --dataset_name=fusing/fill50k \
+ --resolution=512 \
+ --learning_rate=1e-5 \
+ --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
+ --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
+ --train_batch_size=4 \
+ --push_to_hub
+```
+
+이 기본적인 설정으로는 ~38GB VRAM이 필요합니다.
+
+기본적으로 학습 스크립트는 결과를 텐서보드에 기록합니다. Weights & Biases에 기록하려면 `--report_to wandb`를 전달하세요.
+
+더 작은 batch(배치) 크기로 gradient accumulation(기울기 누적)을 하면 학습 요구사항을 ~20 GB VRAM으로 줄일 수 있습니다.
+
+```bash
+export MODEL_DIR="runwayml/stable-diffusion-v1-5"
+export OUTPUT_DIR="path to save model"
+
+accelerate launch train_controlnet.py \
+ --pretrained_model_name_or_path=$MODEL_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --dataset_name=fusing/fill50k \
+ --resolution=512 \
+ --learning_rate=1e-5 \
+ --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
+ --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --push_to_hub
+```
+
+## 여러개 GPU로 학습하기
+
+`accelerate`를 사용하면 여러 GPU에서 원활하게(seamless) 학습할 수 있습니다. `accelerate`로 분산 학습을 실행하려면 [여기](https://huggingface.co/docs/accelerate/basic_tutorials/launch)의 설명을 확인하세요. 아래는 예시 명령어입니다:
+
+```bash
+export MODEL_DIR="runwayml/stable-diffusion-v1-5"
+export OUTPUT_DIR="path to save model"
+
+accelerate launch --mixed_precision="fp16" --multi_gpu train_controlnet.py \
+ --pretrained_model_name_or_path=$MODEL_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --dataset_name=fusing/fill50k \
+ --resolution=512 \
+ --learning_rate=1e-5 \
+ --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
+ --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
+ --train_batch_size=4 \
+ --mixed_precision="fp16" \
+ --tracker_project_name="controlnet-demo" \
+ --report_to=wandb \
+ --push_to_hub
+```
+
+## 예시 결과
+
+#### 배치 사이즈 8로 300 스텝 이후:
+
+| | |
+|-------------------|:-------------------------:|
+| | 푸른 배경과 빨간 원 |
+![conditioning image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png) | ![푸른 배경과 빨간 원](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/red_circle_with_blue_background_300_steps.png) |
+| | 갈색 꽃 배경과 청록색 원 |
+![conditioning image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png) | ![갈색 꽃 배경과 청록색 원](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/cyan_circle_with_brown_floral_background_300_steps.png) |
+
+#### 배치 사이즈 8로 6000 스텝 이후:
+
+| | |
+|-------------------|:-------------------------:|
+| | 푸른 배경과 빨간 원 |
+![conditioning image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png) | ![푸른 배경과 빨간 원](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/red_circle_with_blue_background_6000_steps.png) |
+| | 갈색 꽃 배경과 청록색 원 |
+![conditioning image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png) | ![갈색 꽃 배경과 청록색 원](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/cyan_circle_with_brown_floral_background_6000_steps.png) |
+
+## 16GB GPU에서 학습하기
+
+16GB GPU에서 학습하기 위해 다음의 최적화를 진행하세요:
+
+- gradient checkpointing(기울기 체크포인팅) 활성화
+- bitsandbytes의 [8-bit optimizer](https://github.com/TimDettmers/bitsandbytes#requirements--installation) (설치되어 있지 않다면 링크의 설명서를 따라 설치하세요)
+
+이제 학습 스크립트를 시작할 수 있습니다:
+
+```bash
+export MODEL_DIR="runwayml/stable-diffusion-v1-5"
+export OUTPUT_DIR="path to save model"
+
+accelerate launch train_controlnet.py \
+ --pretrained_model_name_or_path=$MODEL_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --dataset_name=fusing/fill50k \
+ --resolution=512 \
+ --learning_rate=1e-5 \
+ --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
+ --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --use_8bit_adam \
+ --push_to_hub
+```
+
+## 12GB GPU에서 학습하기
+
+12GB GPU에서 실행하기 위해 다음의 최적화를 진행하세요:
+
+- gradient checkpointing(기울기 체크포인팅) 활성화
+- bitsandbytes의 8-bit [optimizer](https://github.com/TimDettmers/bitsandbytes#requirements--installation) (설치되어 있지 않다면 링크의 설명서를 따라 설치하세요)
+- [xFormers](https://huggingface.co/docs/diffusers/training/optimization/xformers) (설치되어 있지 않다면 링크의 설명서를 따라 설치하세요)
+- 기울기를 `None`으로 설정
+
+```bash
+export MODEL_DIR="runwayml/stable-diffusion-v1-5"
+export OUTPUT_DIR="path to save model"
+
+accelerate launch train_controlnet.py \
+ --pretrained_model_name_or_path=$MODEL_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --dataset_name=fusing/fill50k \
+ --resolution=512 \
+ --learning_rate=1e-5 \
+ --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
+ --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --use_8bit_adam \
+ --enable_xformers_memory_efficient_attention \
+ --set_grads_to_none \
+ --push_to_hub
+```
+
+`enable_xformers_memory_efficient_attention`을 사용하려면 `pip install xformers`로 `xformers`가 설치되어 있는지 확인하세요.
+
+## 8GB GPU에서 학습하기
+
+우리는 ControlNet에 대한 DeepSpeed 지원을 철저하게 테스트하지 않았습니다. 아래 구성이 메모리를 절약해 주기는 하지만, 이 구성으로 학습이 성공적으로 진행되는지는 확인하지 못했습니다.
+학습을 성공적으로 실행하려면 구성을 변경해야 할 가능성이 높습니다.
+
+8GB GPU에서 실행하기 위해 다음의 최적화를 진행하세요:
+
+- gradient checkpointing(기울기 체크포인팅) 활성화
+- bitsandbytes의 8-bit [optimizer](https://github.com/TimDettmers/bitsandbytes#requirements--installation) (설치되어 있지 않다면 링크의 설명서를 따라 설치하세요)
+- [xFormers](https://huggingface.co/docs/diffusers/training/optimization/xformers) (설치되어 있지 않다면 링크의 설명서를 따라 설치하세요)
+- 기울기를 `None`으로 설정
+- 파라미터와 옵티마이저를 오프로딩하는 DeepSpeed stage 2
+- fp16 혼합 정밀도(precision)
+
+[DeepSpeed](https://www.deepspeed.ai/)는 텐서를 VRAM에서 CPU 또는 NVMe로 오프로드할 수 있습니다.
+이를 위해서는 훨씬 더 많은 시스템 RAM(약 25GB)이 필요합니다.
+
+DeepSpeed stage 2를 활성화하려면 `accelerate config`로 환경을 구성해야 합니다.
+
+구성(configuration) 파일은 이런 모습이어야 합니다:
+
+```yaml
+compute_environment: LOCAL_MACHINE
+deepspeed_config:
+ gradient_accumulation_steps: 4
+ offload_optimizer_device: cpu
+ offload_param_device: cpu
+ zero3_init_flag: false
+ zero_stage: 2
+distributed_type: DEEPSPEED
+```
+
+<Tip>
+
+더 많은 DeepSpeed 구성 옵션은 [문서](https://huggingface.co/docs/accelerate/usage_guides/deepspeed)를 참조하세요.
+
+</Tip>
+
+기본 Adam 옵티마이저를 DeepSpeed의 최적화된 Adam 구현인 `deepspeed.ops.adam.DeepSpeedCPUAdam`으로 바꾸면 상당한 속도 향상을 얻을 수 있지만,
+PyTorch와 같은 버전의 CUDA toolchain이 필요합니다. 8-bit 옵티마이저는 현재 DeepSpeed와 호환되지 않는 것 같습니다.
+
+```bash
+export MODEL_DIR="runwayml/stable-diffusion-v1-5"
+export OUTPUT_DIR="path to save model"
+
+accelerate launch train_controlnet.py \
+ --pretrained_model_name_or_path=$MODEL_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --dataset_name=fusing/fill50k \
+ --resolution=512 \
+ --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
+ --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --enable_xformers_memory_efficient_attention \
+ --set_grads_to_none \
+ --mixed_precision fp16 \
+ --push_to_hub
+```
+
+## 추론
+
+학습된 모델은 [`StableDiffusionControlNetPipeline`]과 함께 실행될 수 있습니다.
+`base_model_path`와 `controlnet_path`에는 학습 스크립트에서 각각 `--pretrained_model_name_or_path`와 `--output_dir`에 지정했던 값을 넣으세요.
+
+```py
+from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
+from diffusers.utils import load_image
+import torch
+
+base_model_path = "path to model"
+controlnet_path = "path to controlnet"
+
+controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
+pipe = StableDiffusionControlNetPipeline.from_pretrained(
+ base_model_path, controlnet=controlnet, torch_dtype=torch.float16
+)
+
+# 더 빠른 스케줄러와 메모리 최적화로 diffusion 프로세스 속도 올리기
+pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+# xformers가 설치되지 않으면 아래 줄을 삭제하기
+pipe.enable_xformers_memory_efficient_attention()
+
+pipe.enable_model_cpu_offload()
+
+control_image = load_image("./conditioning_image_1.png")
+prompt = "pale golden rod circle with old lace background"
+
+# 이미지 생성하기
+generator = torch.manual_seed(0)
+image = pipe(prompt, num_inference_steps=20, generator=generator, image=control_image).images[0]
+
+image.save("./output.png")
+```
diff --git a/diffusers/docs/source/ko/training/create_dataset.md b/diffusers/docs/source/ko/training/create_dataset.md
new file mode 100644
index 0000000000000000000000000000000000000000..0e5f5018f4c5b7ad3e397afb99ad1821d6a1492a
--- /dev/null
+++ b/diffusers/docs/source/ko/training/create_dataset.md
@@ -0,0 +1,98 @@
+# 학습을 위한 데이터셋 만들기
+
+[Hub](https://huggingface.co/datasets?task_categories=task_categories:text-to-image&sort=downloads) 에는 모델 교육을 위한 많은 데이터셋이 있지만,
+관심이 있거나 사용하고 싶은 데이터셋을 찾을 수 없는 경우 🤗 [Datasets](hf.co/docs/datasets) 라이브러리를 사용하여 데이터셋을 만들 수 있습니다.
+데이터셋 구조는 모델을 학습하려는 작업에 따라 달라집니다.
+가장 기본적인 데이터셋 구조는 unconditional 이미지 생성과 같은 작업을 위한 이미지 디렉토리입니다.
+또 다른 데이터셋 구조는 이미지 디렉토리와 text-to-image 생성과 같은 작업에 해당하는 텍스트 캡션이 포함된 텍스트 파일일 수 있습니다.
+
+이 가이드에는 파인 튜닝할 데이터셋을 만드는 두 가지 방법을 소개합니다:
+
+- 이미지 폴더를 `--train_data_dir` 인수에 제공합니다.
+- 데이터셋을 Hub에 업로드하고 데이터셋 리포지토리 id를 `--dataset_name` 인수에 전달합니다.
+
+
+
+💡 학습에 사용할 이미지 데이터셋을 만드는 방법에 대한 자세한 내용은 [이미지 데이터셋 만들기](https://huggingface.co/docs/datasets/image_dataset) 가이드를 참고하세요.
+
+
+
+## 폴더 형태로 데이터셋 구축하기
+
+Unconditional 생성을 위해 이미지 폴더로 자신의 데이터셋을 구축할 수 있습니다.
+학습 스크립트는 🤗 Datasets의 [ImageFolder](https://huggingface.co/docs/datasets/en/image_dataset#imagefolder) 빌더를 사용하여
+자동으로 폴더에서 데이터셋을 구축합니다. 디렉토리 구조는 다음과 같아야 합니다 :
+
+```bash
+data_dir/xxx.png
+data_dir/xxy.png
+data_dir/[...]/xxz.png
+```
+
+데이터셋 디렉터리의 경로를 `--train_data_dir` 인수로 전달한 다음 학습을 시작할 수 있습니다:
+
+```bash
+# argument로 학습 데이터 폴더를 지정합니다
+accelerate launch train_unconditional.py \
+    --train_data_dir <path-to-train-directory> \
+    <other-arguments>
+```
+
+## Hub에 데이터 올리기
+
+
+
+💡 데이터셋을 만들고 Hub에 업로드하는 것에 대한 자세한 내용은 [🤗 Datasets을 사용한 이미지 검색](https://huggingface.co/blog/image-search-datasets) 게시물을 참고하세요.
+
+
+
+PIL 인코딩된 이미지가 담긴 `image` 열을 생성하는 [이미지 폴더](https://huggingface.co/docs/datasets/image_load#imagefolder) 기능을 사용하여 데이터셋 생성을 시작합니다.
+
+`data_dir` 또는 `data_files` 매개 변수를 사용하여 데이터셋의 위치를 지정할 수 있습니다.
+`data_files` 매개변수는 특정 파일을 `train` 이나 `test` 로 분리한 데이터셋에 매핑하는 것을 지원합니다:
+
+```python
+from datasets import load_dataset
+
+# 예시 1: 로컬 폴더
+dataset = load_dataset("imagefolder", data_dir="path_to_your_folder")
+
+# 예시 2: 로컬 파일 (지원 포맷 : tar, gzip, zip, xz, rar, zstd)
+dataset = load_dataset("imagefolder", data_files="path_to_zip_file")
+
+# 예시 3: 원격 파일 (지원 포맷 : tar, gzip, zip, xz, rar, zstd)
+dataset = load_dataset(
+ "imagefolder",
+ data_files="https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_3367a.zip",
+)
+
+# 예시 4: 여러개로 분할
+dataset = load_dataset(
+ "imagefolder", data_files={"train": ["path/to/file1", "path/to/file2"], "test": ["path/to/file3", "path/to/file4"]}
+)
+```
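+
+Hub에 업로드하기 전에, 데이터셋이 예상대로 불러와졌는지 간단히 확인해볼 수 있습니다. 아래는 기본 `train` 분할을 가정한 예시입니다:
+
+```python
+# 분할(split) 구성과 예시 개수를 확인합니다
+print(dataset)
+
+# 각 예시는 PIL 이미지가 담긴 "image" 열을 포함합니다
+print(dataset["train"][0]["image"])
+```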
+
+[push_to_hub](https://huggingface.co/docs/datasets/v2.13.1/en/package_reference/main_classes#datasets.Dataset.push_to_hub)를 사용해서 Hub에 데이터셋을 업로드합니다:
+
+```python
+# 터미널에서 huggingface-cli login 커맨드를 이미 실행했다고 가정합니다
+dataset.push_to_hub("name_of_your_dataset")
+
+# 개인 repo로 push 하고 싶다면, `private=True` 을 추가하세요:
+dataset.push_to_hub("name_of_your_dataset", private=True)
+```
+
+이제 데이터셋 이름을 `--dataset_name` 인수에 전달하여 데이터셋을 학습에 사용할 수 있습니다:
+
+```bash
+accelerate launch --mixed_precision="fp16" train_text_to_image.py \
+    --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
+    --dataset_name="name_of_your_dataset" \
+    <other-arguments>
+```
+
+## 다음 단계
+
+데이터셋을 생성했으니 이제 학습 스크립트의 `train_data_dir` (데이터셋이 로컬이면) 혹은 `dataset_name` (Hub에 데이터셋을 올렸으면) 인수에 연결할 수 있습니다.
+
+다음 단계에서는 데이터셋을 사용하여 [unconditional 생성](https://huggingface.co/docs/diffusers/v0.18.2/en/training/unconditional_training) 또는 [텍스트-이미지 생성](https://huggingface.co/docs/diffusers/training/text2image)을 위한 모델을 학습시켜보세요!
diff --git a/diffusers/docs/source/ko/training/custom_diffusion.md b/diffusers/docs/source/ko/training/custom_diffusion.md
new file mode 100644
index 0000000000000000000000000000000000000000..0923c046cc6f6ab66edd0ee6cc3920f87cdc82b7
--- /dev/null
+++ b/diffusers/docs/source/ko/training/custom_diffusion.md
@@ -0,0 +1,300 @@
+
+
+# 커스텀 Diffusion 학습 예제
+
+[커스텀 Diffusion](https://arxiv.org/abs/2212.04488)은 피사체의 이미지 몇 장(4~5장)만 주어지면 Stable Diffusion처럼 text-to-image 모델을 커스터마이징하는 방법입니다.
+`train_custom_diffusion.py` 스크립트는 학습 과정을 구현하고 이를 Stable Diffusion에 맞게 조정하는 방법을 보여줍니다.
+
+이 학습 예제는 Custom Diffusion의 저자 중 한 명인 [Nupur Kumari](https://nupurkmr9.github.io/)가 제공했습니다.
+
+## 로컬에서 PyTorch로 실행하기
+
+### Dependencies 설치하기
+
+스크립트를 실행하기 전에 라이브러리의 학습 dependencies를 설치해야 합니다:
+
+**중요**
+
+예제 스크립트의 최신 버전을 성공적으로 실행하려면 **소스로부터 설치**하는 것을 매우 권장하며, 예제 스크립트를 자주 업데이트하는 만큼 일부 예제별 요구 사항을 설치하고 설치를 최신 상태로 유지하는 것이 좋습니다. 이를 위해 새 가상 환경에서 다음 단계를 실행하세요:
+
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+[example folder](https://github.com/huggingface/diffusers/tree/main/examples/custom_diffusion)로 cd하여 이동하세요.
+
+```bash
+cd examples/custom_diffusion
+```
+
+이제 실행
+
+```bash
+pip install -r requirements.txt
+pip install clip-retrieval
+```
+
+그리고 [🤗Accelerate](https://github.com/huggingface/accelerate/) 환경을 초기화:
+
+```bash
+accelerate config
+```
+
+또는 사용자 환경에 대한 질문에 답하지 않고 기본 가속 구성을 사용하려면 다음과 같이 하세요.
+
+```bash
+accelerate config default
+```
+
+또는 사용 중인 환경이 대화형 셸을 지원하지 않는 경우(예: jupyter notebook)
+
+```python
+from accelerate.utils import write_basic_config
+
+write_basic_config()
+```
+### 고양이 예제 😺
+
+이제 데이터셋을 가져옵니다. [여기](https://www.cs.cmu.edu/~custom-diffusion/assets/data.zip)에서 데이터셋을 다운로드하고 압축을 풉니다. 직접 데이터셋을 사용하려면 [학습용 데이터셋 생성하기](create_dataset) 가이드를 참고하세요.
+
+또한 'clip-retrieval'을 사용하여 200개의 실제 이미지를 수집하고, regularization으로서 이를 학습 데이터셋의 타겟 이미지와 결합합니다. 이렇게 하면 주어진 타겟 이미지에 대한 과적합을 방지할 수 있습니다. 다음 플래그를 사용하면 `prior_loss_weight=1.`로 `prior_preservation`, `real_prior` regularization을 활성화할 수 있습니다.
+`class_prompt`는 대상 이미지와 동일한 카테고리 이름이어야 합니다. 수집된 실제 이미지에는 `class_prompt`와 유사한 텍스트 캡션이 붙어 있습니다. 검색된 이미지는 `class_data_dir`에 저장됩니다. 생성된 이미지를 regularization에 사용하려면 `real_prior`를 비활성화하면 됩니다. 실제 이미지를 수집하려면 학습 전에 먼저 이 명령을 실행하세요.
+
+```bash
+pip install clip-retrieval
+python retrieve.py --class_prompt cat --class_data_dir real_reg/samples_cat --num_class_images 200
+```
+
+**___참고: [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 모델을 사용하는 경우 `--resolution`을 768로 변경하세요.___**
+
+스크립트는 모델 체크포인트와 `pytorch_custom_diffusion_weights.bin` 파일을 생성하여 저장소에 저장합니다.
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export OUTPUT_DIR="path-to-save-model"
+export INSTANCE_DIR="./data/cat"
+
+accelerate launch train_custom_diffusion.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --class_data_dir=./real_reg/samples_cat/ \
+ --with_prior_preservation --real_prior --prior_loss_weight=1.0 \
+ --class_prompt="cat" --num_class_images=200 \
+ --instance_prompt="photo of a cat" \
+ --resolution=512 \
+ --train_batch_size=2 \
+ --learning_rate=1e-5 \
+ --lr_warmup_steps=0 \
+ --max_train_steps=250 \
+ --scale_lr --hflip \
+ --modifier_token "<new1>" \
+ --push_to_hub
+```
+
+**더 낮은 VRAM 요구 사항(GPU당 16GB)으로 더 빠르게 훈련하려면 `--enable_xformers_memory_efficient_attention`을 사용하세요. 설치 방법은 [가이드](https://github.com/facebookresearch/xformers)를 따르세요.**
+
+가중치 및 편향(`wandb`)을 사용하여 실험을 추적하고 중간 결과를 저장하려면(강력히 권장합니다) 다음 단계를 따르세요:
+
+* `wandb` 설치: `pip install wandb`.
+* 로그인 : `wandb login`.
+* 그런 다음 트레이닝을 시작하는 동안 `validation_prompt`를 지정하고 `report_to`를 `wandb`로 설정합니다. 다음과 같은 관련 인수를 구성할 수도 있습니다:
+ * `num_validation_images`
+ * `validation_steps`
+
+```bash
+accelerate launch train_custom_diffusion.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --class_data_dir=./real_reg/samples_cat/ \
+ --with_prior_preservation --real_prior --prior_loss_weight=1.0 \
+ --class_prompt="cat" --num_class_images=200 \
+ --instance_prompt="photo of a cat" \
+ --resolution=512 \
+ --train_batch_size=2 \
+ --learning_rate=1e-5 \
+ --lr_warmup_steps=0 \
+ --max_train_steps=250 \
+ --scale_lr --hflip \
+ --modifier_token "<new1>" \
+ --validation_prompt="<new1> cat sitting in a bucket" \
+ --report_to="wandb" \
+ --push_to_hub
+```
+
+다음은 [Weights and Biases page](https://wandb.ai/sayakpaul/custom-diffusion/runs/26ghrcau)의 예시이며, 여러 학습 세부 정보와 함께 중간 결과들을 확인할 수 있습니다.
+
+`--push_to_hub`를 지정하면 학습된 파라미터가 허깅 페이스 허브의 리포지토리에 푸시됩니다. 다음은 [예제 리포지토리](https://huggingface.co/sayakpaul/custom-diffusion-cat)입니다.
+
+### 멀티 컨셉에 대한 학습 🐱🪵
+
+[this](https://github.com/ShivamShrirao/diffusers/blob/main/examples/dreambooth/train_dreambooth.py)와 유사하게 각 컨셉에 대한 정보가 포함된 [json](https://github.com/adobe-research/custom-diffusion/blob/main/assets/concept_list.json) 파일을 제공합니다.
+
+실제 이미지를 수집하려면 json 파일의 각 컨셉에 대해 이 명령을 실행합니다.
+
+```bash
+pip install clip-retrieval
+python retrieve.py --class_prompt {} --class_data_dir {} --num_class_images 200
+```
+
+그럼 우리는 학습시킬 준비가 되었습니다!
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_custom_diffusion.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --output_dir=$OUTPUT_DIR \
+ --concepts_list=./concept_list.json \
+ --with_prior_preservation --real_prior --prior_loss_weight=1.0 \
+ --resolution=512 \
+ --train_batch_size=2 \
+ --learning_rate=1e-5 \
+ --lr_warmup_steps=0 \
+ --max_train_steps=500 \
+ --num_class_images=200 \
+ --scale_lr --hflip \
+ --modifier_token "<new1>+<new2>" \
+ --push_to_hub
+```
+
+다음은 [Weights and Biases page](https://wandb.ai/sayakpaul/custom-diffusion/runs/3990tzkg)의 예시이며, 다른 학습 세부 정보와 함께 중간 결과들을 확인할 수 있습니다.
+
+### 사람 얼굴에 대한 학습
+
+사람 얼굴에 대한 파인튜닝을 위해 다음과 같은 설정이 더 효과적이라는 것을 확인했습니다: `learning_rate=5e-6`, `max_train_steps=1000 to 2000`, `freeze_model=crossattn`을 최소 15~20개의 이미지로 설정합니다.
+
+실제 이미지를 수집하려면 훈련 전에 이 명령을 먼저 사용하십시오.
+
+```bash
+pip install clip-retrieval
+python retrieve.py --class_prompt person --class_data_dir real_reg/samples_person --num_class_images 200
+```
+
+이제 학습을 시작하세요!
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export OUTPUT_DIR="path-to-save-model"
+export INSTANCE_DIR="path-to-images"
+
+accelerate launch train_custom_diffusion.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --class_data_dir=./real_reg/samples_person/ \
+ --with_prior_preservation --real_prior --prior_loss_weight=1.0 \
+ --class_prompt="person" --num_class_images=200 \
+ --instance_prompt="photo of a person" \
+ --resolution=512 \
+ --train_batch_size=2 \
+ --learning_rate=5e-6 \
+ --lr_warmup_steps=0 \
+ --max_train_steps=1000 \
+ --scale_lr --hflip --noaug \
+ --freeze_model crossattn \
+ --modifier_token "<new1>" \
+ --enable_xformers_memory_efficient_attention \
+ --push_to_hub
+```
+
+## 추론
+
+위 프롬프트를 사용하여 모델을 학습시킨 후에는 아래 프롬프트를 사용하여 추론을 실행할 수 있습니다. 프롬프트에 'modifier token'(예: 위 예제에서는 `<new1>`)을 반드시 포함해야 합니다.
+
+```python
+import torch
+from diffusers import DiffusionPipeline
+
+pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16).to("cuda")
+pipe.unet.load_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin")
+pipe.load_textual_inversion("path-to-save-model", weight_name=".bin")
+
+image = pipe(
+ " cat sitting in a bucket",
+ num_inference_steps=100,
+ guidance_scale=6.0,
+ eta=1.0,
+).images[0]
+image.save("cat.png")
+```
+
+허브 리포지토리에서 이러한 매개변수를 직접 로드할 수 있습니다:
+
+```python
+import torch
+from huggingface_hub.repocard import RepoCard
+from diffusers import DiffusionPipeline
+
+model_id = "sayakpaul/custom-diffusion-cat"
+card = RepoCard.load(model_id)
+base_model_id = card.data.to_dict()["base_model"]
+
+pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to("cuda")
+pipe.unet.load_attn_procs(model_id, weight_name="pytorch_custom_diffusion_weights.bin")
+pipe.load_textual_inversion(model_id, weight_name="<new1>.bin")
+
+image = pipe(
+ " cat sitting in a bucket",
+ num_inference_steps=100,
+ guidance_scale=6.0,
+ eta=1.0,
+).images[0]
+image.save("cat.png")
+```
+
+다음은 여러 컨셉으로 추론을 수행하는 예제입니다:
+
+```python
+import torch
+from huggingface_hub.repocard import RepoCard
+from diffusers import DiffusionPipeline
+
+model_id = "sayakpaul/custom-diffusion-cat-wooden-pot"
+card = RepoCard.load(model_id)
+base_model_id = card.data.to_dict()["base_model"]
+
+pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to("cuda")
+pipe.unet.load_attn_procs(model_id, weight_name="pytorch_custom_diffusion_weights.bin")
+pipe.load_textual_inversion(model_id, weight_name="<new1>.bin")
+pipe.load_textual_inversion(model_id, weight_name="<new2>.bin")
+
+image = pipe(
+ "the cat sculpture in the style of a wooden pot",
+ num_inference_steps=100,
+ guidance_scale=6.0,
+ eta=1.0,
+).images[0]
+image.save("multi-subject.png")
+```
+
+여기서 '고양이'와 '나무 냄비'는 여러 컨셉을 말합니다.
+
+### 학습된 체크포인트에서 추론하기
+
+`--checkpointing_steps` 인수를 사용한 경우 학습 과정에서 저장된 전체 체크포인트 중 하나에서 추론을 수행할 수도 있습니다.
+
+## Grads를 None으로 설정
+
+더 많은 메모리를 절약하려면 스크립트에 `--set_grads_to_none` 인수를 전달하세요. 이렇게 하면 기울기(gradient)가 0이 아니라 `None`으로 설정됩니다. 다만 이로 인해 특정 동작이 변경되므로, 문제가 발생하면 이 인수를 제거하세요.
+
+자세한 정보: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html
+
+## 실험 결과
+
+실험에 대한 자세한 내용은 [당사 웹페이지](https://www.cs.cmu.edu/~custom-diffusion/)를 참조하세요.
\ No newline at end of file
diff --git a/diffusers/docs/source/ko/training/distributed_inference.md b/diffusers/docs/source/ko/training/distributed_inference.md
new file mode 100644
index 0000000000000000000000000000000000000000..826a7bbff352ee87f252d1e2ffeb0060a5269cf6
--- /dev/null
+++ b/diffusers/docs/source/ko/training/distributed_inference.md
@@ -0,0 +1,92 @@
+# 여러 GPU를 사용한 분산 추론
+
+분산 설정에서는 여러 개의 프롬프트를 동시에 생성할 때 유용한 🤗 [Accelerate](https://huggingface.co/docs/accelerate/index) 또는 [PyTorch Distributed](https://pytorch.org/tutorials/beginner/dist_overview.html)를 사용하여 여러 GPU에서 추론을 실행할 수 있습니다.
+
+이 가이드에서는 분산 추론을 위해 🤗 Accelerate와 PyTorch Distributed를 사용하는 방법을 보여드립니다.
+
+## 🤗 Accelerate
+
+🤗 [Accelerate](https://huggingface.co/docs/accelerate/index)는 분산 설정에서 학습이나 추론을 쉽게 실행할 수 있도록 설계된 라이브러리입니다. 분산 환경 설정 프로세스를 간소화하여 PyTorch 코드에 집중할 수 있도록 해줍니다.
+
+시작하려면 Python 파일을 생성하고 [`accelerate.PartialState`]를 초기화하여 분산 환경을 생성하세요. 설정이 자동으로 감지되므로 `rank`나 `world_size`를 명시적으로 정의할 필요가 없습니다. [`DiffusionPipeline`]을 `distributed_state.device`로 이동하여 각 프로세스에 GPU를 할당합니다.
+
+이제 컨텍스트 관리자로 [`~accelerate.PartialState.split_between_processes`] 유틸리티를 사용하여 프로세스 수에 따라 프롬프트를 자동으로 분배합니다.
+
+
+```py
+import torch
+from accelerate import PartialState
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+distributed_state = PartialState()
+pipeline.to(distributed_state.device)
+
+with distributed_state.split_between_processes(["a dog", "a cat"]) as prompt:
+    result = pipeline(prompt).images[0]
+    result.save(f"result_{distributed_state.process_index}.png")
+```
+
+사용할 GPU 수는 `--num_processes` 인수로 지정하고, `accelerate launch`를 호출하여 스크립트를 실행합니다:
+
+```bash
+accelerate launch --num_processes=2 run_distributed.py
+```
+
+자세한 내용은 [🤗 Accelerate를 사용한 분산 추론](https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference#distributed-inference-with-accelerate) 가이드를 참조하세요.
+
+
+
+## PyTorch 분산
+
+PyTorch는 데이터 병렬 처리를 가능하게 하는 [`DistributedDataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html)을 지원합니다.
+
+시작하려면 Python 파일을 생성하고 `torch.distributed` 및 `torch.multiprocessing`을 임포트하여 분산 프로세스 그룹을 설정하고 각 GPU에서 추론용 프로세스를 생성합니다. 그리고 [`DiffusionPipeline`]도 초기화해야 합니다:
+
+확산 파이프라인을 `rank`로 이동하고 `get_rank`를 사용하여 각 프로세스에 GPU를 할당하면 각 프로세스가 다른 프롬프트를 처리합니다:
+
+```py
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+from diffusers import DiffusionPipeline
+
+sd = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+```
+
+추론을 실행할 함수를 만들어야 합니다. [`init_process_group`]은 사용할 백엔드 유형, 현재 프로세스의 `rank`, 그리고 참여하는 프로세스 수인 `world_size`를 받아 분산 환경 생성을 처리합니다.
+
+2개의 GPU에서 추론을 병렬로 실행하는 경우 `world_size`는 2입니다.
+
+```py
+def run_inference(rank, world_size):
+    dist.init_process_group("nccl", rank=rank, world_size=world_size)
+
+    sd.to(rank)
+
+    if torch.distributed.get_rank() == 0:
+        prompt = "a dog"
+    elif torch.distributed.get_rank() == 1:
+        prompt = "a cat"
+
+    image = sd(prompt).images[0]
+    # 프롬프트의 공백을 밑줄로 바꿔 파일 이름으로 사용합니다
+    image.save(f"./{'_'.join(prompt.split())}.png")
+```
+
+분산 추론을 실행하려면 [`mp.spawn`](https://pytorch.org/docs/stable/multiprocessing.html#torch.multiprocessing.spawn)을 호출하여 `world_size`에 정의된 GPU 수에 대해 `run_inference` 함수를 실행합니다:
+
+```py
+def main():
+    world_size = 2
+    mp.spawn(run_inference, args=(world_size,), nprocs=world_size, join=True)
+
+
+if __name__ == "__main__":
+    main()
+```
+
+추론 스크립트를 완료했으면 `--nproc_per_node` 인수를 사용하여 사용할 GPU 수를 지정하고 `torchrun`을 호출하여 스크립트를 실행합니다:
+
+```bash
+torchrun --nproc_per_node=2 run_distributed.py
+```
\ No newline at end of file
diff --git a/diffusers/docs/source/ko/training/dreambooth.md b/diffusers/docs/source/ko/training/dreambooth.md
new file mode 100644
index 0000000000000000000000000000000000000000..5d76731933abafacaf7f35a637c8d8f222b9dd98
--- /dev/null
+++ b/diffusers/docs/source/ko/training/dreambooth.md
@@ -0,0 +1,474 @@
+
+
+# DreamBooth
+
+[DreamBooth](https://arxiv.org/abs/2208.12242)는 한 주제에 대한 적은 이미지(3~5개)만으로도 stable diffusion과 같이 text-to-image 모델을 개인화할 수 있는 방법입니다. 이를 통해 모델은 다양한 장면, 포즈 및 장면(뷰)에서 피사체에 대해 맥락화(contextualized)된 이미지를 생성할 수 있습니다.
+
+![프로젝트 블로그에서의 DreamBooth 예시](https://dreambooth.github.io/DreamBooth_files/teaser_static.jpg)
+[프로젝트 블로그](https://dreambooth.github.io)에서의 DreamBooth 예시.
+
+
+이 가이드는 다양한 GPU 크기와 Flax에 대해 [`CompVis/stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4) 모델로 DreamBooth를 파인튜닝하는 방법을 보여줍니다. 더 깊이 파고들어 작동 방식을 확인하는 데 관심이 있다면, 이 가이드에 사용된 DreamBooth의 모든 학습 스크립트를 [여기](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth)에서 찾을 수 있습니다.
+
+스크립트를 실행하기 전에 라이브러리의 학습에 필요한 dependencies를 설치해야 합니다. 또한 `main` GitHub 브랜치에서 🧨 Diffusers를 설치하는 것이 좋습니다.
+
+```bash
+pip install git+https://github.com/huggingface/diffusers
+pip install -U -r diffusers/examples/dreambooth/requirements.txt
+```
+
+xFormers는 학습에 필요한 요구 사항은 아니지만, 가능하면 [설치](../optimization/xformers)하는 것이 좋습니다. 학습 속도를 높이고 메모리 사용량을 줄일 수 있기 때문입니다.
+
+모든 dependencies를 설정한 후, 다음 명령으로 [🤗 Accelerate](https://github.com/huggingface/accelerate/) 환경을 초기화합니다:
+
+```bash
+accelerate config
+```
+
+별도 설정 없이 기본 🤗 Accelerate 환경을 설치하려면 다음을 실행합니다:
+
+```bash
+accelerate config default
+```
+
+또는 현재 환경이 노트북과 같은 대화형 셸을 지원하지 않는 경우 다음을 사용할 수 있습니다:
+
+```py
+from accelerate.utils import write_basic_config
+
+write_basic_config()
+```
+
+## 파인튜닝
+
+
+
+DreamBooth 파인튜닝은 하이퍼파라미터에 매우 민감하고 과적합되기 쉽습니다. 적절한 하이퍼파라미터를 선택하는 데 도움이 되도록 다양한 권장 설정이 포함된 [심층 분석](https://huggingface.co/blog/dreambooth)을 살펴보는 것이 좋습니다.
+
+
+
+
+
+[몇 장의 강아지 이미지들](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ)로 DreamBooth를 시도해봅시다.
+이를 다운로드해 디렉터리에 저장한 다음 `INSTANCE_DIR` 환경 변수를 해당 경로로 설정합니다:
+
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export INSTANCE_DIR="path_to_training_images"
+export OUTPUT_DIR="path_to_saved_model"
+```
+
+그런 다음, 다음 명령을 사용하여 학습 스크립트를 실행할 수 있습니다 (전체 학습 스크립트는 [여기](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py)에서 찾을 수 있습니다):
+
+```bash
+accelerate launch train_dreambooth.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --instance_prompt="a photo of sks dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=1 \
+ --learning_rate=5e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --max_train_steps=400
+```
+
+
+
+TPU에 액세스할 수 있거나 더 빠르게 훈련하고 싶다면 [Flax 학습 스크립트](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_flax.py)를 사용해 볼 수 있습니다. Flax 학습 스크립트는 gradient checkpointing 또는 gradient accumulation을 지원하지 않으므로, 메모리가 30GB 이상인 GPU가 필요합니다.
+
+스크립트를 실행하기 전에 요구 사항이 설치되어 있는지 확인하십시오.
+
+```bash
+pip install -U -r requirements.txt
+```
+
+그러면 다음 명령어로 학습 스크립트를 실행시킬 수 있습니다:
+
+```bash
+export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
+export INSTANCE_DIR="path-to-instance-images"
+export OUTPUT_DIR="path-to-save-model"
+
+python train_dreambooth_flax.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --instance_prompt="a photo of sks dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --learning_rate=5e-6 \
+ --max_train_steps=400
+```
+
+
+
+### Prior-preserving(사전 보존) loss를 사용한 파인튜닝
+
+과적합과 language drift를 방지하기 위해 사전 보존이 사용됩니다(관심이 있는 경우 [논문](https://arxiv.org/abs/2208.12242)을 참조하세요). 사전 보존을 위해 동일한 클래스의 다른 이미지를 학습 프로세스의 일부로 사용합니다. 좋은 점은 Stable Diffusion 모델 자체를 사용하여 이러한 이미지를 생성할 수 있다는 것입니다! 학습 스크립트는 생성된 이미지를 우리가 지정한 로컬 경로에 저장합니다.
+
+저자들에 따르면 사전 보존을 위해 `num_epochs * num_samples`개의 이미지를 생성하는 것이 좋습니다. 200-300개에서 대부분 잘 작동합니다.
+
+
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export INSTANCE_DIR="path_to_training_images"
+export CLASS_DIR="path_to_class_images"
+export OUTPUT_DIR="path_to_saved_model"
+
+accelerate launch train_dreambooth.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --class_data_dir=$CLASS_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --with_prior_preservation --prior_loss_weight=1.0 \
+ --instance_prompt="a photo of sks dog" \
+ --class_prompt="a photo of dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=1 \
+ --learning_rate=5e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --num_class_images=200 \
+ --max_train_steps=800
+```
+
+
+```bash
+export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
+export INSTANCE_DIR="path-to-instance-images"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+python train_dreambooth_flax.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --class_data_dir=$CLASS_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --with_prior_preservation --prior_loss_weight=1.0 \
+ --instance_prompt="a photo of sks dog" \
+ --class_prompt="a photo of dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --learning_rate=5e-6 \
+ --num_class_images=200 \
+ --max_train_steps=800
+```
+
+
+
+## 텍스트 인코더와 UNet로 파인튜닝하기
+
+해당 스크립트를 사용하면 `unet`과 함께 `text_encoder`도 파인튜닝할 수 있습니다. 실험 결과(자세한 내용은 [🧨 Diffusers를 사용해 DreamBooth로 Stable Diffusion 학습하기](https://huggingface.co/blog/dreambooth) 게시물을 확인하세요), 특히 얼굴 이미지를 생성할 때 훨씬 더 나은 결과를 얻을 수 있습니다.
+
+
+
+텍스트 인코더를 학습시키려면 추가 메모리가 필요해 16GB GPU로는 동작하지 않습니다. 이 옵션을 사용하려면 최소 24GB VRAM이 필요합니다.
+
+
+
+`--train_text_encoder` 인수를 학습 스크립트에 전달하여 `text_encoder` 및 `unet`을 파인튜닝할 수 있습니다:
+
+
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export INSTANCE_DIR="path_to_training_images"
+export CLASS_DIR="path_to_class_images"
+export OUTPUT_DIR="path_to_saved_model"
+
+accelerate launch train_dreambooth.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_text_encoder \
+ --instance_data_dir=$INSTANCE_DIR \
+ --class_data_dir=$CLASS_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --with_prior_preservation --prior_loss_weight=1.0 \
+ --instance_prompt="a photo of sks dog" \
+ --class_prompt="a photo of dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --use_8bit_adam \
+ --gradient_checkpointing \
+ --learning_rate=2e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --num_class_images=200 \
+ --max_train_steps=800
+```
+
+
+```bash
+export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
+export INSTANCE_DIR="path-to-instance-images"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+python train_dreambooth_flax.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_text_encoder \
+ --instance_data_dir=$INSTANCE_DIR \
+ --class_data_dir=$CLASS_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --with_prior_preservation --prior_loss_weight=1.0 \
+ --instance_prompt="a photo of sks dog" \
+ --class_prompt="a photo of dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --learning_rate=2e-6 \
+ --num_class_images=200 \
+ --max_train_steps=800
+```
+
+
+
+## LoRA로 파인튜닝하기
+
+DreamBooth에서 대규모 모델의 학습을 가속화하기 위한 파인튜닝 기술인 LoRA(Low-Rank Adaptation of Large Language Models)를 사용할 수 있습니다. 자세한 내용은 [LoRA 학습](./lora#dreambooth) 가이드를 참조하세요.
+
+### 학습 중 체크포인트 저장하기
+
+Dreambooth로 훈련하는 동안 과적합하기 쉬우므로, 때때로 학습 중에 정기적인 체크포인트를 저장하는 것이 유용합니다. 중간 체크포인트 중 하나가 최종 모델보다 더 잘 작동할 수 있습니다! 체크포인트 저장 기능을 활성화하려면 학습 스크립트에 다음 인수를 전달해야 합니다:
+
+```bash
+ --checkpointing_steps=500
+```
+
+이렇게 하면 `output_dir`의 하위 폴더에 전체 학습 상태가 저장됩니다. 하위 폴더 이름은 접두사 `checkpoint-`로 시작하고 지금까지 수행된 step 수입니다. 예시로 `checkpoint-1500`은 1500 학습 step 후에 저장된 체크포인트입니다.
+
+#### 저장된 체크포인트에서 훈련 재개하기
+
+저장된 체크포인트에서 훈련을 재개하려면, `--resume_from_checkpoint` 인수를 전달한 다음 사용할 체크포인트의 이름을 지정하면 됩니다. 특수 문자열 `"latest"`를 사용하여 저장된 마지막 체크포인트(즉, step 수가 가장 많은 체크포인트)에서 재개할 수도 있습니다. 예를 들어 다음은 1500 step 후에 저장된 체크포인트에서부터 학습을 재개합니다:
+
+```bash
+ --resume_from_checkpoint="checkpoint-1500"
+```
+
+원하는 경우 일부 하이퍼파라미터를 조정할 수 있습니다.
+
+#### 저장된 체크포인트를 사용하여 추론 수행하기
+
+저장된 체크포인트는 훈련 재개에 적합한 형식으로 저장됩니다. 여기에는 모델 가중치뿐만 아니라 옵티마이저, 데이터 로더 및 학습률의 상태도 포함됩니다.
+
+**`"accelerate>=0.16.0"`**이 설치된 경우 다음 코드를 사용하여 중간 체크포인트에서 추론을 실행합니다.
+
+```python
+from diffusers import DiffusionPipeline, UNet2DConditionModel
+from transformers import CLIPTextModel
+import torch
+
+# 학습에 사용된 것과 동일한 인수(model, revision)로 파이프라인을 불러옵니다.
+model_id = "CompVis/stable-diffusion-v1-4"
+
+unet = UNet2DConditionModel.from_pretrained("/sddata/dreambooth/daruma-v2-1/checkpoint-100/unet")
+
+# `args.train_text_encoder`로 학습한 경우면 텍스트 인코더를 꼭 불러오세요
+text_encoder = CLIPTextModel.from_pretrained("/sddata/dreambooth/daruma-v2-1/checkpoint-100/text_encoder")
+
+pipeline = DiffusionPipeline.from_pretrained(model_id, unet=unet, text_encoder=text_encoder, dtype=torch.float16)
+pipeline.to("cuda")
+
+# 추론을 수행하거나 저장하거나, 허브에 푸시합니다.
+pipeline.save_pretrained("dreambooth-pipeline")
+```
+
+**`"accelerate<0.16.0"`**이 설치되어 있는 경우, 먼저 추론 파이프라인으로 변환해야 합니다:
+
+```python
+from accelerate import Accelerator
+from diffusers import DiffusionPipeline
+
+# 학습에 사용된 것과 동일한 인수(model, revision)로 파이프라인을 불러옵니다.
+model_id = "CompVis/stable-diffusion-v1-4"
+pipeline = DiffusionPipeline.from_pretrained(model_id)
+
+accelerator = Accelerator()
+
+# 초기 학습에 `--train_text_encoder`가 사용된 경우 text_encoder를 사용합니다.
+unet, text_encoder = accelerator.prepare(pipeline.unet, pipeline.text_encoder)
+
+# 체크포인트 경로로부터 상태를 복원합니다. 여기서는 절대 경로를 사용해야 합니다.
+accelerator.load_state("/sddata/dreambooth/daruma-v2-1/checkpoint-100")
+
+# unwrapped 모델로 파이프라인을 다시 빌드합니다.(.unet and .text_encoder로의 할당도 작동해야 합니다)
+pipeline = DiffusionPipeline.from_pretrained(
+ model_id,
+ unet=accelerator.unwrap_model(unet),
+ text_encoder=accelerator.unwrap_model(text_encoder),
+)
+
+# 추론을 수행하거나 저장하거나, 허브에 푸시합니다.
+pipeline.save_pretrained("dreambooth-pipeline")
+```
+
+## 각 GPU 용량에서의 최적화
+
+하드웨어에 따라 16GB에서 8GB까지 GPU에서 DreamBooth를 최적화하는 몇 가지 방법이 있습니다!
+
+### xFormers
+
+[xFormers](https://github.com/facebookresearch/xformers)는 Transformers를 최적화하기 위한 toolbox이며, 🧨 Diffusers에서 사용되는 [memory-efficient attention](https://facebookresearch.github.io/xformers/components/ops.html#module-xformers.ops) 메커니즘을 포함하고 있습니다. [xFormers를 설치](./optimization/xformers)한 다음 학습 스크립트에 다음 인수를 추가합니다:
+
+```bash
+ --enable_xformers_memory_efficient_attention
+```
+
+xFormers는 Flax에서 사용할 수 없습니다.
+
+### 그래디언트를 None으로 설정하기
+
+메모리 사용량을 줄일 수 있는 또 다른 방법은 [기울기 설정](https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html)을 0 대신 `None`으로 하는 것입니다. 그러나 이로 인해 특정 동작이 변경될 수 있으므로 문제가 발생하면 이 인수를 제거해 보십시오. 학습 스크립트에 다음 인수를 추가하여 그래디언트를 `None`으로 설정합니다.
+
+```bash
+ --set_grads_to_none
+```
+
+### 16GB GPU
+
+Gradient checkpointing과 [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)의 8비트 옵티마이저의 도움으로, 16GB GPU에서 dreambooth를 훈련할 수 있습니다. bitsandbytes가 설치되어 있는지 확인하세요:
+
+```bash
+pip install bitsandbytes
+```
+
+그 다음, 학습 스크립트에 `--use_8bit_adam` 옵션을 명시합니다:
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export INSTANCE_DIR="path_to_training_images"
+export CLASS_DIR="path_to_class_images"
+export OUTPUT_DIR="path_to_saved_model"
+
+accelerate launch train_dreambooth.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --class_data_dir=$CLASS_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --with_prior_preservation --prior_loss_weight=1.0 \
+ --instance_prompt="a photo of sks dog" \
+ --class_prompt="a photo of dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=2 --gradient_checkpointing \
+ --use_8bit_adam \
+ --learning_rate=5e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --num_class_images=200 \
+ --max_train_steps=800
+```
+
+### 12GB GPU
+
+12GB GPU에서 DreamBooth를 실행하려면 gradient checkpointing, 8비트 옵티마이저, xFormers를 활성화하고 그래디언트를 `None`으로 설정해야 합니다.
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export INSTANCE_DIR="path-to-instance-images"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_dreambooth.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --class_data_dir=$CLASS_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --with_prior_preservation --prior_loss_weight=1.0 \
+ --instance_prompt="a photo of sks dog" \
+ --class_prompt="a photo of dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=1 --gradient_checkpointing \
+ --use_8bit_adam \
+ --enable_xformers_memory_efficient_attention \
+ --set_grads_to_none \
+ --learning_rate=2e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --num_class_images=200 \
+ --max_train_steps=800
+```
+
+### 8GB GPU에서 학습하기
+
+8GB GPU에 대해서는 [DeepSpeed](https://www.deepspeed.ai/)를 사용해 일부 텐서를 VRAM에서 CPU 또는 NVME로 오프로드하여 더 적은 GPU 메모리로 학습할 수도 있습니다.
+
+🤗 Accelerate 환경을 구성하려면 다음 명령을 실행하세요:
+
+```bash
+accelerate config
+```
+
+환경 구성 중에 DeepSpeed를 사용할 것을 확인하세요.
+그러면 DeepSpeed stage 2, fp16 혼합 정밀도를 결합하고 모델 매개변수와 옵티마이저 상태를 모두 CPU로 오프로드하면 8GB VRAM 미만에서 학습할 수 있습니다.
+단점은 더 많은 시스템 RAM(약 25GB)이 필요하다는 것입니다. 추가 구성 옵션은 [DeepSpeed 문서](https://huggingface.co/docs/accelerate/usage_guides/deepspeed)를 참조하세요.
+
+또한 기본 Adam 옵티마이저를 DeepSpeed의 최적화된 Adam 버전으로 변경해야 합니다.
+이는 상당한 속도 향상을 위한 Adam인 [`deepspeed.ops.adam.DeepSpeedCPUAdam`](https://deepspeed.readthedocs.io/en/latest/optimizers.html#adam-cpu)입니다.
+`DeepSpeedCPUAdam`을 활성화하려면 시스템의 CUDA toolchain 버전이 PyTorch와 함께 설치된 것과 동일해야 합니다.
+
+8비트 옵티마이저는 현재 DeepSpeed와 호환되지 않는 것 같습니다.
+
+다음 명령으로 학습을 시작합니다:
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export INSTANCE_DIR="path_to_training_images"
+export CLASS_DIR="path_to_class_images"
+export OUTPUT_DIR="path_to_saved_model"
+
+accelerate launch train_dreambooth.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --class_data_dir=$CLASS_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --with_prior_preservation --prior_loss_weight=1.0 \
+ --instance_prompt="a photo of sks dog" \
+ --class_prompt="a photo of dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --sample_batch_size=1 \
+ --gradient_accumulation_steps=1 --gradient_checkpointing \
+ --learning_rate=5e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --num_class_images=200 \
+ --max_train_steps=800 \
+ --mixed_precision=fp16
+```
+
+## 추론
+
+모델을 학습한 후에는, 모델이 저장된 경로를 지정해 [`StableDiffusionPipeline`]로 추론을 수행할 수 있습니다. 프롬프트에 학습에 사용된 특수 `식별자`(이전 예시의 `sks`)가 포함되어 있는지 확인하세요.
+
+**`"accelerate>=0.16.0"`**이 설치되어 있는 경우 다음 코드를 사용하여 중간 체크포인트에서 추론을 실행할 수 있습니다:
+
+```python
+from diffusers import StableDiffusionPipeline
+import torch
+
+model_id = "path_to_saved_model"
+pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+
+prompt = "A photo of sks dog in a bucket"
+image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
+
+image.save("dog-bucket.png")
+```
+
+[저장된 학습 체크포인트](#inference-from-a-saved-checkpoint)에서 추론을 실행할 수도 있습니다.
diff --git a/diffusers/docs/source/ko/training/instructpix2pix.md b/diffusers/docs/source/ko/training/instructpix2pix.md
new file mode 100644
index 0000000000000000000000000000000000000000..7d80ef6328fc7355dbd9c19108f2aef3ece6ea4f
--- /dev/null
+++ b/diffusers/docs/source/ko/training/instructpix2pix.md
@@ -0,0 +1,211 @@
+
+
+# InstructPix2Pix
+
+[InstructPix2Pix](https://arxiv.org/abs/2211.09800)는 입력 이미지에 대한 편집 지시(instruction)를 따를 수 있도록 text-conditioned diffusion 모델을 파인튜닝하는 방법입니다. 이 방법으로 파인튜닝된 모델은 다음을 입력으로 받습니다:
+
+
+
+
+
+출력은 입력 이미지에 편집 지시가 반영된 "수정된" 이미지입니다:
+
+
+
+
+
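+참고로, 이렇게 파인튜닝된 모델은 대략 아래와 같이 추론에 사용할 수 있습니다. 공개 체크포인트 `timbrooks/instruct-pix2pix`를 예시로 가정한 간단한 스케치입니다:
+
+```python
+import torch
+from diffusers import StableDiffusionInstructPix2PixPipeline
+from diffusers.utils import load_image
+
+# 가정: timbrooks/instruct-pix2pix는 공개된 InstructPix2Pix 체크포인트입니다
+pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
+    "timbrooks/instruct-pix2pix", torch_dtype=torch.float16
+).to("cuda")
+
+image = load_image("https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png")
+edited_image = pipe("make the mountains snowy", image=image, num_inference_steps=20).images[0]
+edited_image.save("edited_mountain.png")
+```
+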
+`train_instruct_pix2pix.py` 스크립트([여기](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py)에서 찾을 수 있습니다.)는 학습 절차를 설명하고 Stable Diffusion에 적용할 수 있는 방법을 보여줍니다.
+
+
+*** `train_instruct_pix2pix.py`는 [원래 구현](https://github.com/timothybrooks/instruct-pix2pix)에 충실하면서 InstructPix2Pix 학습 절차를 구현하고 있지만, [소규모 데이터셋](https://huggingface.co/datasets/fusing/instructpix2pix-1000-samples)에서만 테스트를 했습니다. 이는 최종 결과에 영향을 끼칠 수 있습니다. 더 나은 결과를 위해, 더 큰 데이터셋에서 더 길게 학습하는 것을 권장합니다. [여기](https://huggingface.co/datasets/timbrooks/instructpix2pix-clip-filtered)에서 InstructPix2Pix 학습을 위해 큰 데이터셋을 찾을 수 있습니다.
+***
+
+## PyTorch로 로컬에서 실행하기
+
+### 종속성(dependencies) 설치하기
+
+이 스크립트를 실행하기 전에, 라이브러리의 학습 종속성을 설치하세요:
+
+**중요**
+
+최신 버전의 예제 스크립트를 성공적으로 실행하려면 **소스로부터 설치**하고 설치를 최신 상태로 유지하는 것을 권장합니다. 예제 스크립트는 자주 업데이트되며 예제별 요구사항을 설치해야 하기 때문입니다. 이를 위해, 새로운 가상 환경에서 다음 단계를 실행하세요:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+cd 명령어로 예제 폴더로 이동하세요.
+```bash
+cd examples/instruct_pix2pix
+```
+
+이제 실행하세요.
+```bash
+pip install -r requirements.txt
+```
+
+그리고 [🤗Accelerate](https://github.com/huggingface/accelerate/) 환경을 초기화하세요:
+
+```bash
+accelerate config
+```
+
+혹은 환경에 대한 질문 없이 기본적인 accelerate 구성을 사용하려면 다음을 실행하세요.
+
+```bash
+accelerate config default
+```
+
+혹은 사용 중인 환경이 notebook과 같은 대화형 셸을 지원하지 않는 경우에는 다음을 실행하세요:
+
+```python
+from accelerate.utils import write_basic_config
+
+write_basic_config()
+```
+
+### 예시
+
+이전에 언급했듯이, 학습을 위해 [작은 데이터셋](https://huggingface.co/datasets/fusing/instructpix2pix-1000-samples)을 사용할 것입니다. 그 데이터셋은 InstructPix2Pix 논문에서 사용된 [원래의 데이터셋](https://huggingface.co/datasets/timbrooks/instructpix2pix-clip-filtered)보다 작은 버전입니다. 자신의 데이터셋을 사용하기 위해, [학습을 위한 데이터셋 만들기](create_dataset) 가이드를 참고하세요.
+
+`MODEL_NAME` 환경 변수(허브 모델 레포지토리 또는 모델 가중치가 포함된 폴더 경로)를 지정하고 [`pretrained_model_name_or_path`](https://huggingface.co/docs/diffusers/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path) 인수에 전달합니다. `DATASET_ID`에 데이터셋 이름을 지정해야 합니다:
+
+
+```bash
+export MODEL_NAME="runwayml/stable-diffusion-v1-5"
+export DATASET_ID="fusing/instructpix2pix-1000-samples"
+```
+
+이제 학습을 실행할 수 있습니다. 스크립트는 모든 구성요소(`feature_extractor`, `scheduler`, `text_encoder`, `unet` 등)를 리포지토리의 하위 폴더에 저장합니다.
+
+```bash
+accelerate launch --mixed_precision="fp16" train_instruct_pix2pix.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --dataset_name=$DATASET_ID \
+ --enable_xformers_memory_efficient_attention \
+ --resolution=256 --random_flip \
+ --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \
+ --max_train_steps=15000 \
+ --checkpointing_steps=5000 --checkpoints_total_limit=1 \
+ --learning_rate=5e-05 --max_grad_norm=1 --lr_warmup_steps=0 \
+ --conditioning_dropout_prob=0.05 \
+ --mixed_precision=fp16 \
+ --seed=42 \
+ --push_to_hub
+```
+
+
+추가적으로, Weights & Biases를 통해 학습 진행 상황을 모니터링하면서 검증 추론을 수행하는 것을 지원합니다. `report_to="wandb"` 인자로 이 기능을 활성화할 수 있습니다:
+
+```bash
+accelerate launch --mixed_precision="fp16" train_instruct_pix2pix.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --dataset_name=$DATASET_ID \
+ --enable_xformers_memory_efficient_attention \
+ --resolution=256 --random_flip \
+ --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \
+ --max_train_steps=15000 \
+ --checkpointing_steps=5000 --checkpoints_total_limit=1 \
+ --learning_rate=5e-05 --max_grad_norm=1 --lr_warmup_steps=0 \
+ --conditioning_dropout_prob=0.05 \
+ --mixed_precision=fp16 \
+ --val_image_url="https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png" \
+ --validation_prompt="make the mountains snowy" \
+ --seed=42 \
+ --report_to=wandb \
+ --push_to_hub
+```
+
+모델 디버깅에 유용하므로 이러한 검증 방식을 권장합니다. 이를 사용하려면 `wandb`가 설치되어 있어야 하며, `pip install wandb`를 실행해 설치할 수 있습니다.
+
+[여기](https://wandb.ai/sayakpaul/instruct-pix2pix/runs/ctr3kovq)에서 검증 결과와 학습 파라미터를 포함한 예시 실행을 볼 수 있습니다.
+
+***참고: 원본 논문에서 저자들은 256x256 해상도로 학습한 모델이 512x512와 같은 더 큰 해상도에서도 잘 일반화되는 것을 확인했습니다. 이는 학습에 사용한 대규모 데이터셋 덕분입니다.***
+
+## 다수의 GPU로 학습하기
+
+`accelerate`를 사용하면 여러 GPU로 원활하게 학습할 수 있습니다. `accelerate`로 분산 학습을 실행하려면 [여기](https://huggingface.co/docs/accelerate/basic_tutorials/launch)의 설명을 따라주세요. 다음은 예시 명령어입니다:
+
+
+```bash
+accelerate launch --mixed_precision="fp16" --multi_gpu train_instruct_pix2pix.py \
+ --pretrained_model_name_or_path=runwayml/stable-diffusion-v1-5 \
+ --dataset_name=sayakpaul/instructpix2pix-1000-samples \
+ --use_ema \
+ --enable_xformers_memory_efficient_attention \
+ --resolution=512 --random_flip \
+ --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \
+ --max_train_steps=15000 \
+ --checkpointing_steps=5000 --checkpoints_total_limit=1 \
+ --learning_rate=5e-05 --lr_warmup_steps=0 \
+ --conditioning_dropout_prob=0.05 \
+ --mixed_precision=fp16 \
+ --seed=42 \
+ --push_to_hub
+```
+
+## 추론하기
+
+학습이 완료되면 추론을 수행할 수 있습니다:
+
+```python
+import PIL
+import requests
+import torch
+from diffusers import StableDiffusionInstructPix2PixPipeline
+
+model_id = "your_model_id" # <- 이를 수정하세요.
+pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+generator = torch.Generator("cuda").manual_seed(0)
+
+url = "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/test_pix2pix_4.png"
+
+
+def download_image(url):
+ image = PIL.Image.open(requests.get(url, stream=True).raw)
+ image = PIL.ImageOps.exif_transpose(image)
+ image = image.convert("RGB")
+ return image
+
+
+image = download_image(url)
+prompt = "wipe out the lake"
+num_inference_steps = 20
+image_guidance_scale = 1.5
+guidance_scale = 10
+
+edited_image = pipe(
+ prompt,
+ image=image,
+ num_inference_steps=num_inference_steps,
+ image_guidance_scale=image_guidance_scale,
+ guidance_scale=guidance_scale,
+ generator=generator,
+).images[0]
+edited_image.save("edited_image.png")
+```
+
+학습 스크립트를 사용해 얻은 예시의 모델 레포지토리는 여기 [sayakpaul/instruct-pix2pix](https://huggingface.co/sayakpaul/instruct-pix2pix)에서 확인할 수 있습니다.
+
+추론 시 속도와 품질을 제어하기 위해 다음 세 가지 파라미터를 조절해 보는 것이 좋습니다:
+
+* `num_inference_steps`
+* `image_guidance_scale`
+* `guidance_scale`
+
+특히, `image_guidance_scale`와 `guidance_scale`는 생성된("수정된") 이미지에 큰 영향을 미칠 수 있습니다. ([여기](https://twitter.com/RisingSayak/status/1628392199196151808?s=20)에서 예시를 참고해주세요.)
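+
+예를 들어, 다음은 위 추론 예시의 `pipe`, `image`, `prompt`를 그대로 재사용한다고 가정하고, 두 파라미터 조합이 결과에 미치는 영향을 비교해 보는 간단한 스케치입니다:
+
+```python
+# 가정: 위에서 정의한 pipe, image, prompt를 재사용합니다.
+for image_guidance_scale in (1.0, 1.5, 2.0):
+    for guidance_scale in (5.0, 10.0):
+        edited = pipe(
+            prompt,
+            image=image,
+            num_inference_steps=20,
+            image_guidance_scale=image_guidance_scale,  # 입력 이미지를 얼마나 유지할지
+            guidance_scale=guidance_scale,              # 편집 지시를 얼마나 강하게 따를지
+        ).images[0]
+        edited.save(f"edited_igs{image_guidance_scale}_gs{guidance_scale}.png")
+```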
+
+
+InstructPix2Pix 학습 방법을 활용한 몇 가지 흥미로운 사례를 찾고 있다면, 블로그 게시물 [Instruction-tuning Stable Diffusion with InstructPix2Pix](https://huggingface.co/blog/instruction-tuning-sd)를 확인해주세요.
\ No newline at end of file
diff --git a/diffusers/docs/source/ko/training/lora.md b/diffusers/docs/source/ko/training/lora.md
new file mode 100644
index 0000000000000000000000000000000000000000..7a6320d6b1564896bfdff7acb68fc70a657ed0aa
--- /dev/null
+++ b/diffusers/docs/source/ko/training/lora.md
@@ -0,0 +1,128 @@
+
+
+# Low-Rank Adaptation of Large Language Models (LoRA)
+
+[[open-in-colab]]
+
+
+
+현재 LoRA는 [`UNet2DConditionModel`]의 어텐션 레이어에서만 지원됩니다.
+
+
+
+[LoRA(Low-Rank Adaptation of Large Language Models)](https://arxiv.org/abs/2106.09685)는 메모리를 적게 사용하면서 대규모 모델의 학습을 가속화하는 학습 방법입니다. 이는 기존 가중치에 rank-decomposition weight 행렬 쌍(**업데이트 행렬**이라고 함)을 추가하고 새로 추가된 가중치**만** 학습합니다. 여기에는 몇 가지 장점이 있습니다.
+
+- 이전에 미리 학습된 가중치는 고정된 상태로 유지되므로 모델이 [치명적인 망각](https://www.pnas.org/doi/10.1073/pnas.1611835114)을 겪을 가능성이 적습니다.
+- Rank-decomposition 행렬은 원래 모델보다 파라미터 수가 훨씬 적으므로 학습된 LoRA 가중치를 쉽게 이식해 사용할 수 있습니다.
+- LoRA 매트릭스는 일반적으로 원본 모델의 어텐션 레이어에 추가됩니다. 🧨 Diffusers는 [`~diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs`] 메서드를 제공하여 LoRA 가중치를 모델의 어텐션 레이어로 불러옵니다. `scale` 매개변수를 통해 모델이 새로운 학습 이미지에 맞게 조정되는 범위를 제어할 수 있습니다.
+- 메모리 효율성이 향상되어 Tesla T4, RTX 3080 또는 RTX 2080 Ti와 같은 소비자용 GPU에서 파인튜닝을 실행할 수 있습니다! T4와 같은 GPU는 무료이며 Kaggle 또는 Google Colab 노트북에서 쉽게 액세스할 수 있습니다.
+
+
+
+
+💡 LoRA는 어텐션 레이어에만 한정되지는 않습니다. 저자는 언어 모델의 어텐션 레이어를 수정하는 것만으로도 매우 효율적으로 좋은 성능을 얻기에 충분하다는 것을 발견했습니다. 이것이 LoRA 가중치를 모델의 어텐션 레이어에 추가하는 것이 일반적인 이유입니다. LoRA 작동 방식에 대한 자세한 내용은 [Using LoRA for effective Stable Diffusion fine-tuning](https://huggingface.co/blog/lora) 블로그를 확인하세요!
+
+
+
+[cloneofsimo](https://github.com/cloneofsimo)는 인기 있는 [lora](https://github.com/cloneofsimo/lora) GitHub 리포지토리에서 Stable Diffusion을 위한 LoRA 학습을 최초로 시도했습니다. 🧨 Diffusers는 [text-to-image 생성](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image#training-with-lora) 및 [DreamBooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth#training-with-low-rank-adaptation-of-large-language-models-lora)을 지원합니다. 이 가이드는 두 가지를 모두 수행하는 방법을 보여줍니다.
+
+모델을 저장하거나 커뮤니티와 공유하려면 Hugging Face 계정에 로그인하세요(아직 계정이 없는 경우 [생성](https://hf.co/join)하세요):
+
+```bash
+huggingface-cli login
+```
+
+## Text-to-image
+
+수십억 개의 파라미터가 있는 Stable Diffusion과 같은 모델을 파인튜닝하는 것은 느리고 어려울 수 있습니다. LoRA를 사용하면 diffusion 모델을 파인튜닝하는 것이 훨씬 쉽고 빠르며, 8비트 옵티마이저와 같은 트릭에 의존하지 않고도 11GB의 GPU RAM만 있는 하드웨어에서 실행할 수 있습니다.
+
+
+### 학습[[text-to-image-training]]
+
+[Pokémon BLIP 캡션](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions) 데이터셋으로 [`stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5)를 파인튜닝해 나만의 포켓몬을 생성해 보겠습니다.
+
+시작하려면 `MODEL_NAME` 및 `DATASET_NAME` 환경 변수가 설정되어 있는지 확인하십시오. `OUTPUT_DIR` 및 `HUB_MODEL_ID` 변수는 선택 사항이며 허브에 모델을 저장할 위치를 지정합니다.
+
+```bash
+export MODEL_NAME="runwayml/stable-diffusion-v1-5"
+export OUTPUT_DIR="/sddata/finetune/lora/pokemon"
+export HUB_MODEL_ID="pokemon-lora"
+export DATASET_NAME="lambdalabs/pokemon-blip-captions"
+```
+
+학습을 시작하기 전에 알아야 할 몇 가지 플래그가 있습니다.
+
+* `--push_to_hub`를 명시하면 학습된 LoRA 임베딩을 허브에 저장합니다.
+* `--report_to=wandb`는 학습 결과를 가중치 및 편향 대시보드에 보고하고 기록합니다(예를 들어, 이 [보고서](https://wandb.ai/pcuenq/text2image-fine-tune/run/b4k1w0tn?workspace=user-pcuenq)를 참조하세요).
+* `--learning_rate=1e-04`, LoRA를 사용하면 일반적으로 사용하는 것보다 더 높은 학습률을 사용할 수 있습니다.
+
+이제 학습을 시작할 준비가 되었습니다 (전체 학습 스크립트는 [여기](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py)에서 찾을 수 있습니다).
+
+```bash
+accelerate launch --mixed_precision="fp16" train_text_to_image_lora.py \
+  --pretrained_model_name_or_path=$MODEL_NAME \
+  --dataset_name=$DATASET_NAME \
+  --dataloader_num_workers=8 \
+  --resolution=512 --center_crop --random_flip \
+  --train_batch_size=1 \
+  --gradient_accumulation_steps=4 \
+  --max_train_steps=15000 \
+  --learning_rate=1e-04 \
+  --max_grad_norm=1 \
+  --lr_scheduler="cosine" --lr_warmup_steps=0 \
+  --output_dir=${OUTPUT_DIR} \
+  --hub_model_id=${HUB_MODEL_ID} \
+  --checkpointing_steps=500 \
+  --validation_prompt="A pokemon with blue eyes." \
+  --report_to="wandb" \
+  --seed=1337 \
+  --push_to_hub
+```
+
+### 추론[[text-to-image-inference]]
+
+이제 [`StableDiffusionPipeline`]에서 기본 모델을 불러와 추론을 위해 모델을 사용할 수 있습니다:
+
+```py
+>>> import torch
+>>> from diffusers import StableDiffusionPipeline
+
+>>> model_base = "runwayml/stable-diffusion-v1-5"
+
+>>> pipe = StableDiffusionPipeline.from_pretrained(model_base, torch_dtype=torch.float16)
+```
+
+*기본 모델의 가중치 위에* 파인튜닝된 모델의 LoRA 가중치를 불러온 다음, 더 빠른 추론을 위해 파이프라인을 GPU로 이동합니다. LoRA 가중치를 프리징된 사전 학습 모델 가중치와 병합할 때, 선택적으로 `scale` 매개변수로 어느 정도의 가중치를 병합할지 조절할 수 있습니다:
+
+
+
+💡 `0`의 `scale` 값은 LoRA 가중치를 사용하지 않아 원래 모델의 가중치만 사용한 것과 같고, `1`의 `scale` 값은 파인튜닝된 LoRA 가중치만 사용함을 의미합니다. 0과 1 사이의 값들은 두 결과들 사이로 보간됩니다.
+
+
+
+```py
+>>> model_path = "path_to_saved_model"  # 예: 위의 OUTPUT_DIR 경로 또는 허브의 HUB_MODEL_ID
+>>> pipe.unet.load_attn_procs(model_path)
+>>> pipe.to("cuda")
+# LoRA 파인튜닝된 모델의 가중치 절반과 기본 모델의 가중치 절반 사용
+
+>>> image = pipe(
+...     "A pokemon with blue eyes.",
+...     num_inference_steps=25,
+...     guidance_scale=7.5,
+...     cross_attention_kwargs={"scale": 0.5},
+... ).images[0]
+# 완전히 파인튜닝된 LoRA 모델의 가중치 사용
+
+>>> image = pipe("A pokemon with blue eyes.", num_inference_steps=25, guidance_scale=7.5).images[0]
+>>> image.save("blue-pokemon.png")
+```
\ No newline at end of file
diff --git a/diffusers/docs/source/ko/training/overview.md b/diffusers/docs/source/ko/training/overview.md
new file mode 100644
index 0000000000000000000000000000000000000000..3516151342360ba856e266a1e056b9a8a3e9554c
--- /dev/null
+++ b/diffusers/docs/source/ko/training/overview.md
@@ -0,0 +1,73 @@
+
+
+# 🧨 Diffusers 학습 예시
+
+이번 챕터에서는 다양한 유즈케이스들에 대한 예제 코드를 통해 어떻게 하면 효과적으로 `diffusers` 라이브러리를 사용할 수 있을지 알아보도록 하겠습니다.
+
+**Note**: 혹시 오피셜한 예시코드를 찾고 있다면, [여기](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines)를 참고해보세요!
+
+여기서 다룰 예시들은 다음을 지향합니다.
+
+- **손쉬운 디펜던시 설치** (Self-contained) : 여기서 사용될 예시 코드들의 디펜던시 패키지들은 전부 `pip install` 명령어를 통해 설치 가능한 패키지들입니다. 또한 친절하게 `requirements.txt` 파일에 해당 패키지들이 명시되어 있어, `pip install -r requirements.txt`로 간편하게 해당 디펜던시들을 설치할 수 있습니다. 예시: [train_unconditional.py](https://github.com/huggingface/diffusers/blob/main/examples/unconditional_image_generation/train_unconditional.py), [requirements.txt](https://github.com/huggingface/diffusers/blob/main/examples/unconditional_image_generation/requirements.txt)
+- **손쉬운 수정** (Easy-to-tweak) : 저희는 가능하면 많은 유즈 케이스들을 제공하고자 합니다. 하지만 예시는 결국 그저 예시라는 점을 기억해주세요. 여기서 제공되는 예시 코드들을 단순히 복사-붙여넣기하는 식으로는 여러분이 마주한 문제들을 손쉽게 해결할 순 없을 것입니다. 다시 말해 어느 정도는 여러분의 상황과 니즈에 맞춰 코드를 일정 부분 고쳐나가야 할 것입니다. 따라서 대부분의 학습 예시들은 데이터의 전처리 과정과 학습 과정에 대한 코드들을 함께 제공함으로써, 사용자가 니즈에 맞게 손쉽게 수정할 수 있도록 돕고 있습니다.
+- **입문자 친화적인** (Beginner-friendly) : 이번 챕터는 diffusion 모델과 `diffusers` 라이브러리에 대한 전반적인 이해를 돕기 위해 작성되었습니다. 따라서 diffusion 모델에 대한 최신 SOTA (state-of-the-art) 방법론들 가운데서도, 입문자에게는 많이 어려울 수 있다고 판단되면, 해당 방법론들은 여기서 다루지 않으려고 합니다.
+- **하나의 태스크만 포함할 것**(One-purpose-only): 여기서 다룰 예시들은 하나의 태스크만 포함하고 있어야 합니다. 물론 이미지 초해상화(super-resolution)와 이미지 보정(modification)과 같은 유사한 모델링 프로세스를 갖는 태스크들이 존재하겠지만, 하나의 예제에 하나의 태스크만을 담는 것이 더 이해하기 용이하다고 판단했기 때문입니다.
+
+
+
+저희는 diffusion 모델의 대표적인 태스크들을 다루는 공식 예제를 제공하고 있습니다. *공식* 예제는 현재 진행형으로 `diffusers` 관리자들(maintainers)에 의해 관리되고 있습니다. 또한 저희는 앞서 정의한 저희의 철학을 엄격하게 따르고자 노력하고 있습니다. 혹시 여러분께서 이러한 예시가 반드시 필요하다고 생각되신다면, 언제든지 [Feature Request](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feature_request.md&title=) 혹은 직접 [Pull Request](https://github.com/huggingface/diffusers/compare)를 주시기 바랍니다. 저희는 언제나 환영입니다!
+
+학습 예시들은 다양한 태스크들에 대해 diffusion 모델을 사전학습(pretrain)하거나 파인튜닝(fine-tuning)하는 법을 보여줍니다. 현재 다음과 같은 예제들을 지원하고 있습니다.
+
+- [Unconditional Training](./unconditional_training)
+- [Text-to-Image Training](./text2image)
+- [Text Inversion](./text_inversion)
+- [Dreambooth](./dreambooth)
+
+memory-efficient attention 연산을 수행하기 위해, 가능하면 [xFormers](../optimization/xformers)를 설치해주시기 바랍니다. 이를 통해 학습 속도를 늘리고 메모리에 대한 부담을 줄일 수 있습니다.
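+
+참고로, 학습 스크립트에서는 `--enable_xformers_memory_efficient_attention` 플래그로 활성화하며, 추론 파이프라인에서는 다음과 같이 켤 수 있습니다(xFormers가 이미 설치되어 있다고 가정한 스케치입니다):
+
+```python
+import torch
+from diffusers import DiffusionPipeline
+
+pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
+
+# memory-efficient attention 활성화 (xformers 패키지 필요)
+pipe.enable_xformers_memory_efficient_attention()
+```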
+
+| Task | 🤗 Accelerate | 🤗 Datasets | Colab
+|---|---|:---:|:---:|
+| [**Unconditional Image Generation**](./unconditional_training) | ✅ | ✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
+| [**Text-to-Image fine-tuning**](./text2image) | ✅ | ✅ |
+| [**Textual Inversion**](./text_inversion) | ✅ | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb)
+| [**Dreambooth**](./dreambooth) | ✅ | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_dreambooth_training.ipynb)
+| [**Training with LoRA**](./lora) | ✅ | - | - |
+| [**ControlNet**](./controlnet) | ✅ | ✅ | - |
+| [**InstructPix2Pix**](./instructpix2pix) | ✅ | ✅ | - |
+| [**Custom Diffusion**](./custom_diffusion) | ✅ | ✅ | - |
+
+
+## 커뮤니티
+
+공식 예제 외에도 **커뮤니티 예제** 역시 제공하고 있습니다. 해당 예제들은 우리의 커뮤니티에 의해 관리됩니다. 커뮤니티 예제는 학습 예시나 추론 파이프라인으로 구성될 수 있습니다. 이러한 커뮤니티 예시들의 경우, 앞서 정의했던 철학들을 좀 더 관대하게 적용하고 있습니다. 또한 이러한 커뮤니티 예시들의 경우, 모든 이슈들에 대한 유지보수를 보장할 수는 없습니다.
+
+유용하긴 하지만, 아직은 대중적이지 못하거나 저희의 철학에 부합하지 않는 예제들은 [community examples](https://github.com/huggingface/diffusers/tree/main/examples/community) 폴더에 담기게 됩니다.
+
+**Note**: 커뮤니티 예제는 `diffusers`에 기여(contribution)를 희망하는 분들에게 [아주 좋은 기여 수단](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)이 될 수 있습니다.
+
+## 주목할 사항들
+
+최신 버전의 예시 코드들의 성공적인 구동을 보장하기 위해서는, 반드시 **소스코드를 통해 `diffusers`를 설치해야 하며,** 해당 예시 코드들이 요구하는 디펜던시들 역시 설치해야 합니다. 이를 위해 새로운 가상 환경을 구축하고 다음의 명령어를 실행해야 합니다.
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+그 다음 `cd` 명령어를 통해 해당 예제 디렉토리에 접근해서 다음 명령어를 실행하면 됩니다.
+
+```bash
+pip install -r requirements.txt
+```
\ No newline at end of file
diff --git a/diffusers/docs/source/ko/training/text2image.md b/diffusers/docs/source/ko/training/text2image.md
new file mode 100644
index 0000000000000000000000000000000000000000..069388603124bc6f02b3c11f9b2dbe630909f0ec
--- /dev/null
+++ b/diffusers/docs/source/ko/training/text2image.md
@@ -0,0 +1,224 @@
+
+
+
+# Text-to-image
+
+
+
+text-to-image 파인튜닝 스크립트는 experimental 상태입니다. 과적합하기 쉽고 치명적인 망각과 같은 문제에 부딪히기 쉽습니다. 자체 데이터셋에서 최상의 결과를 얻으려면 다양한 하이퍼파라미터를 탐색하는 것이 좋습니다.
+
+
+
+Stable Diffusion과 같은 text-to-image 모델은 텍스트 프롬프트에서 이미지를 생성합니다. 이 가이드는 PyTorch 및 Flax를 사용하여 자체 데이터셋으로 [`CompVis/stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4) 모델을 파인튜닝하는 방법을 보여줍니다. 이 가이드에 사용된 text-to-image 파인튜닝을 위한 모든 학습 스크립트에 관심이 있는 경우 이 [리포지토리](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image)에서 자세히 찾을 수 있습니다.
+
+스크립트를 실행하기 전에, 라이브러리의 학습 dependency들을 설치해야 합니다:
+
+```bash
+pip install git+https://github.com/huggingface/diffusers.git
+pip install -U -r requirements.txt
+```
+
+그리고 [🤗Accelerate](https://github.com/huggingface/accelerate/) 환경을 초기화합니다:
+
+```bash
+accelerate config
+```
+
+리포지토리를 이미 복제한 경우, 이 단계를 수행할 필요가 없습니다. 대신, 로컬 체크아웃 경로를 학습 스크립트에 명시할 수 있으며 거기에서 로드됩니다.
+
+### 하드웨어 요구 사항
+
+`gradient_checkpointing` 및 `mixed_precision`을 사용하면 단일 24GB GPU에서 모델을 파인튜닝할 수 있습니다. 더 높은 `batch_size`와 더 빠른 훈련을 위해서는 GPU 메모리가 30GB 이상인 GPU를 사용하는 것이 좋습니다. TPU 또는 GPU에서 파인튜닝을 위해 JAX나 Flax를 사용할 수도 있습니다. 자세한 내용은 [아래](#flax-jax-finetuning)를 참조하세요.
+
+xFormers로 memory efficient attention을 활성화하여 메모리 사용량을 훨씬 더 줄일 수 있습니다. [xFormers가 설치](./optimization/xformers)되어 있는지 확인하고 `--enable_xformers_memory_efficient_attention` 인자를 학습 스크립트에 명시하세요.
+
+xFormers는 Flax에 사용할 수 없습니다.
+
+## Hub에 모델 업로드하기
+
+학습 스크립트에 다음 인수를 추가하여 모델을 허브에 저장합니다:
+
+```bash
+ --push_to_hub
+```
+
+
+## 체크포인트 저장 및 불러오기
+
+학습 중 발생할 수 있는 일에 대비하여 정기적으로 체크포인트를 저장해 두는 것이 좋습니다. 체크포인트를 저장하려면 학습 스크립트에 다음 인수를 명시합니다.
+
+```bash
+ --checkpointing_steps=500
+```
+
+500스텝마다 전체 학습 state가 `output_dir`의 하위 폴더에 저장됩니다. 체크포인트 이름은 `checkpoint-`에 지금까지 학습된 step 수를 붙인 형태입니다. 예를 들어 `checkpoint-1500`은 1500 학습 step 후에 저장된 체크포인트입니다.
+
+학습을 재개하기 위해 체크포인트를 불러오려면 '--resume_from_checkpoint' 인수를 학습 스크립트에 명시하고 재개할 체크포인트를 지정하십시오. 예를 들어 다음 인수는 1500개의 학습 step 후에 저장된 체크포인트에서부터 훈련을 재개합니다.
+
+```bash
+ --resume_from_checkpoint="checkpoint-1500"
+```
+
+## 파인튜닝
+
+
+
+다음과 같이 [Pokémon BLIP 캡션](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions) 데이터셋에서 파인튜닝 실행을 위해 [PyTorch 학습 스크립트](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py)를 실행합니다:
+
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export dataset_name="lambdalabs/pokemon-blip-captions"
+
+accelerate launch train_text_to_image.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --dataset_name=$dataset_name \
+ --use_ema \
+ --resolution=512 --center_crop --random_flip \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --mixed_precision="fp16" \
+ --max_train_steps=15000 \
+ --learning_rate=1e-05 \
+ --max_grad_norm=1 \
+ --lr_scheduler="constant" --lr_warmup_steps=0 \
+ --output_dir="sd-pokemon-model"
+```
+
+자체 데이터셋으로 파인튜닝하려면 🤗 [Datasets](https://huggingface.co/docs/datasets/index)에서 요구하는 형식에 따라 데이터셋을 준비하세요. [데이터셋을 허브에 업로드](https://huggingface.co/docs/datasets/image_dataset#upload-dataset-to-the-hub)하거나 [파일들이 있는 로컬 폴더를 준비](https://huggingface.co/docs/datasets/image_dataset#imagefolder)할 수 있습니다.
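+
+예를 들어, 로컬 폴더를 사용할 경우 🤗 Datasets의 `imagefolder` 빌더로 데이터가 올바르게 읽히는지 미리 확인해 볼 수 있습니다(폴더 구조와 `metadata.jsonl`의 캡션 컬럼은 예시로 가정한 것입니다):
+
+```python
+from datasets import load_dataset
+
+# 가정: TRAIN_DIR 안에 이미지 파일들과 캡션이 담긴 metadata.jsonl이 있다고 가정합니다.
+dataset = load_dataset("imagefolder", data_dir="path_to_your_dataset", split="train")
+print(dataset)
+print(dataset[0])
+```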
+
+사용자 커스텀 loading logic을 사용하려면 스크립트를 수정하세요. 도움이 되도록 코드의 적절한 위치에 포인터를 남겨두었습니다. 🤗 아래 예제 스크립트는 `TRAIN_DIR`의 로컬 데이터셋으로 파인튜닝하는 방법과 `OUTPUT_DIR`에 모델을 저장하는 방법을 보여줍니다:
+
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export TRAIN_DIR="path_to_your_dataset"
+export OUTPUT_DIR="path_to_save_model"
+
+accelerate launch train_text_to_image.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_data_dir=$TRAIN_DIR \
+ --use_ema \
+ --resolution=512 --center_crop --random_flip \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --mixed_precision="fp16" \
+ --max_train_steps=15000 \
+ --learning_rate=1e-05 \
+ --max_grad_norm=1 \
+ --lr_scheduler="constant" --lr_warmup_steps=0 \
+ --output_dir=${OUTPUT_DIR}
+```
+
+
+
+[@duongna211](https://github.com/duongna21)의 기여로, Flax를 사용해 TPU 및 GPU에서 Stable Diffusion 모델을 더 빠르게 학습할 수 있습니다. 이는 TPU 하드웨어에서 매우 효율적이지만 GPU에서도 훌륭하게 작동합니다. Flax 학습 스크립트는 gradient checkpointing나 gradient accumulation과 같은 기능을 아직 지원하지 않으므로 메모리가 30GB 이상인 GPU 또는 TPU v3가 필요합니다.
+
+스크립트를 실행하기 전에 요구 사항이 설치되어 있는지 확인하십시오:
+
+```bash
+pip install -U -r requirements_flax.txt
+```
+
+그러면 다음과 같이 [Flax 학습 스크립트](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_flax.py)를 실행할 수 있습니다.
+
+```bash
+export MODEL_NAME="runwayml/stable-diffusion-v1-5"
+export dataset_name="lambdalabs/pokemon-blip-captions"
+
+python train_text_to_image_flax.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --dataset_name=$dataset_name \
+ --resolution=512 --center_crop --random_flip \
+ --train_batch_size=1 \
+ --max_train_steps=15000 \
+ --learning_rate=1e-05 \
+ --max_grad_norm=1 \
+ --output_dir="sd-pokemon-model"
+```
+
+자체 데이터셋으로 파인튜닝하려면 🤗 [Datasets](https://huggingface.co/docs/datasets/index)에서 요구하는 형식에 따라 데이터셋을 준비하세요. [데이터셋을 허브에 업로드](https://huggingface.co/docs/datasets/image_dataset#upload-dataset-to-the-hub)하거나 [파일들이 있는 로컬 폴더를 준비](https://huggingface.co/docs/datasets/image_dataset#imagefolder)할 수 있습니다.
+
+사용자 커스텀 loading logic을 사용하려면 스크립트를 수정하세요. 도움이 되도록 코드의 적절한 위치에 포인터를 남겨두었습니다. 🤗 아래 예제 스크립트는 `TRAIN_DIR`의 로컬 데이터셋으로 파인튜닝하는 방법을 보여줍니다:
+
+```bash
+export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
+export TRAIN_DIR="path_to_your_dataset"
+
+python train_text_to_image_flax.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_data_dir=$TRAIN_DIR \
+ --resolution=512 --center_crop --random_flip \
+ --train_batch_size=1 \
+ --mixed_precision="fp16" \
+ --max_train_steps=15000 \
+ --learning_rate=1e-05 \
+ --max_grad_norm=1 \
+ --output_dir="sd-pokemon-model"
+```
+
+
+
+## LoRA
+
+Text-to-image 모델 파인튜닝을 위해, 대규모 모델 학습을 가속화하기 위한 파인튜닝 기술인 LoRA(Low-Rank Adaptation of Large Language Models)를 사용할 수 있습니다. 자세한 내용은 [LoRA 학습](lora#text-to-image) 가이드를 참조하세요.
+
+## 추론
+
+허브의 모델 경로 또는 모델 이름을 [`StableDiffusionPipeline`]에 전달하여 추론을 위해 파인 튜닝된 모델을 불러올 수 있습니다:
+
+
+
+```python
+from diffusers import StableDiffusionPipeline
+import torch
+
+model_path = "path_to_saved_model"
+pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
+pipe.to("cuda")
+
+image = pipe(prompt="yoda").images[0]
+image.save("yoda-pokemon.png")
+```
+
+
+```python
+import jax
+import numpy as np
+from flax.jax_utils import replicate
+from flax.training.common_utils import shard
+from diffusers import FlaxStableDiffusionPipeline
+
+model_path = "path_to_saved_model"
+pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(model_path, dtype=jax.numpy.bfloat16)
+
+prompt = "yoda pokemon"
+prng_seed = jax.random.PRNGKey(0)
+num_inference_steps = 50
+
+num_samples = jax.device_count()
+prompt = num_samples * [prompt]
+prompt_ids = pipeline.prepare_inputs(prompt)
+
+# shard inputs and rng
+params = replicate(params)
+prng_seed = jax.random.split(prng_seed, jax.device_count())
+prompt_ids = shard(prompt_ids)
+
+images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
+images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
+image.save("yoda-pokemon.png")
+```
+
+
\ No newline at end of file
diff --git a/diffusers/docs/source/ko/training/text_inversion.md b/diffusers/docs/source/ko/training/text_inversion.md
new file mode 100644
index 0000000000000000000000000000000000000000..948127bc09b93839f4717253d64d0a50da6b1c3d
--- /dev/null
+++ b/diffusers/docs/source/ko/training/text_inversion.md
@@ -0,0 +1,275 @@
+
+
+
+
+# Textual-Inversion
+
+[[open-in-colab]]
+
+[textual-inversion](https://arxiv.org/abs/2208.01618)은 소수의 예시 이미지에서 새로운 콘셉트를 포착하는 기법입니다. 이 기술은 원래 [Latent Diffusion](https://github.com/CompVis/latent-diffusion)에서 시연되었지만, 이후 [Stable Diffusion](https://huggingface.co/docs/diffusers/main/en/conceptual/stable_diffusion)과 같은 유사한 다른 모델에도 적용되었습니다. 학습된 콘셉트는 text-to-image 파이프라인에서 생성된 이미지를 더 잘 제어하는 데 사용할 수 있습니다. 이 모델은 텍스트 인코더의 임베딩 공간에서 새로운 '단어'를 학습하여 개인화된 이미지 생성을 위한 텍스트 프롬프트 내에서 사용됩니다.
+
+![Textual Inversion example](https://textual-inversion.github.io/static/images/editing/colorful_teapot.JPG)
+By using just 3-5 images you can teach new concepts to a model such as Stable Diffusion for personalized image generation (image source).
+
+이 가이드에서는 textual-inversion으로 [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) 모델을 학습하는 방법을 설명합니다. 이 가이드에서 사용된 모든 textual-inversion 학습 스크립트는 [여기](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion)에서 확인할 수 있습니다. 내부적으로 어떻게 작동하는지 자세히 살펴보고 싶으시다면 해당 링크를 참조해주시기 바랍니다.
+
+
+
+[Stable Diffusion Textual Inversion Concepts Library](https://huggingface.co/sd-concepts-library)에는 커뮤니티에서 제작한 학습된 textual-inversion 모델들이 있습니다. 시간이 지남에 따라 더 많은 콘셉트들이 추가되어 유용한 리소스로 성장할 것입니다!
+
+
+
+시작하기 전에 학습을 위한 의존성 라이브러리들을 설치해야 합니다:
+
+```bash
+pip install diffusers accelerate transformers
+```
+
+의존성 라이브러리들의 설치가 완료되면, [🤗Accelerate](https://github.com/huggingface/accelerate/) 환경을 초기화시킵니다.
+
+```bash
+accelerate config
+```
+
+별도의 설정없이, 기본 🤗Accelerate 환경을 설정하려면 다음과 같이 하세요:
+
+```bash
+accelerate config default
+```
+
+또는 사용 중인 환경이 노트북과 같은 대화형 셸을 지원하지 않는다면, 다음과 같이 사용할 수 있습니다:
+
+```py
+from accelerate.utils import write_basic_config
+
+write_basic_config()
+```
+
+마지막으로, Memory-Efficient Attention을 통해 메모리 사용량을 줄이기 위해 [xFormers](https://huggingface.co/docs/diffusers/main/en/optimization/xformers)를 설치합니다. xFormers를 설치한 후, 학습 스크립트에 `--enable_xformers_memory_efficient_attention` 인자를 추가합니다. xFormers는 Flax에서 지원되지 않습니다.
+
+## 허브에 모델 업로드하기
+
+모델을 허브에 저장하려면, 학습 스크립트에 다음 인자를 추가해야 합니다.
+
+```bash
+--push_to_hub
+```
+
+## 체크포인트 저장 및 불러오기
+
+학습중에 모델의 체크포인트를 정기적으로 저장하는 것이 좋습니다. 이렇게 하면 어떤 이유로든 학습이 중단된 경우 저장된 체크포인트에서 학습을 다시 시작할 수 있습니다. 학습 스크립트에 다음 인자를 전달하면 500단계마다 전체 학습 상태가 `output_dir`의 하위 폴더에 체크포인트로서 저장됩니다.
+
+```bash
+--checkpointing_steps=500
+```
+
+저장된 체크포인트에서 학습을 재개하려면, 학습 스크립트와 재개할 특정 체크포인트에 다음 인자를 전달하세요.
+
+```bash
+--resume_from_checkpoint="checkpoint-1500"
+```
+
+## 파인 튜닝
+
+학습용 데이터셋으로 [고양이 장난감 데이터셋](https://huggingface.co/datasets/diffusers/cat_toy_example)을 다운로드하여 디렉토리에 저장하세요. 여러분만의 고유한 데이터셋을 사용하고자 한다면, [학습용 데이터셋 만들기](https://huggingface.co/docs/diffusers/training/create_dataset) 가이드를 살펴보시기 바랍니다.
+
+```py
+from huggingface_hub import snapshot_download
+
+local_dir = "./cat"
+snapshot_download(
+ "diffusers/cat_toy_example", local_dir=local_dir, repo_type="dataset", ignore_patterns=".gitattributes"
+)
+```
+
+모델의 리포지토리 ID(또는 모델 가중치가 포함된 디렉터리 경로)를 `MODEL_NAME` 환경 변수에 할당하고, 해당 값을 [`pretrained_model_name_or_path`](https://huggingface.co/docs/diffusers/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path) 인자에 전달합니다. 그리고 이미지가 포함된 디렉터리 경로를 `DATA_DIR` 환경 변수에 할당합니다.
+
+이제 [학습 스크립트](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py)를 실행할 수 있습니다. 스크립트는 다음 파일을 생성하고 리포지토리에 저장합니다.
+
+- `learned_embeds.bin`
+- `token_identifier.txt`
+- `type_of_concept.txt`.
+
+
+
+💡V100 GPU 1개를 기준으로 전체 학습에는 최대 1시간이 걸립니다. 학습이 완료되기를 기다리는 동안 궁금한 점이 있으면 아래 섹션에서 [textual-inversion이 어떻게 작동하는지](https://huggingface.co/docs/diffusers/training/text_inversion#how-it-works) 자유롭게 확인하세요 !
+
+
+
+
+
+```bash
+export MODEL_NAME="runwayml/stable-diffusion-v1-5"
+export DATA_DIR="./cat"
+
+accelerate launch textual_inversion.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_data_dir=$DATA_DIR \
+ --learnable_property="object" \
+ --placeholder_token="<cat-toy>" --initializer_token="toy" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --max_train_steps=3000 \
+ --learning_rate=5.0e-04 --scale_lr \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --output_dir="textual_inversion_cat" \
+ --push_to_hub
+```
+
+
+
+💡학습 성능을 올리기 위해, 플레이스홀더 토큰(`<cat-toy>`)을 (단일한 임베딩 벡터가 아닌) 복수의 임베딩 벡터로 표현하는 것 역시 고려할 수 있습니다. 이러한 트릭이 모델이 보다 복잡한 이미지의 스타일(앞서 말한 콘셉트)을 더 잘 캡처하는 데 도움이 될 수 있습니다. 복수의 임베딩 벡터 학습을 활성화하려면 다음 옵션을 전달하십시오.
+
+```bash
+--num_vectors=5
+```
+
+
+
+
+
+TPU에 액세스할 수 있는 경우, [Flax 학습 스크립트](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion_flax.py)를 사용하여 더 빠르게 모델을 학습시켜보세요. (물론 GPU에서도 작동합니다.) 동일한 설정에서 Flax 학습 스크립트는 PyTorch 학습 스크립트보다 최소 70% 더 빨라야 합니다! ⚡️
+
+시작하기 앞서 Flax에 대한 의존성 라이브러리들을 설치해야 합니다.
+
+```bash
+pip install -U -r requirements_flax.txt
+```
+
+모델의 리포지토리 ID(또는 모델 가중치가 포함된 디렉터리 경로)를 `MODEL_NAME` 환경 변수에 할당하고, 해당 값을 [`pretrained_model_name_or_path`](https://huggingface.co/docs/diffusers/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path) 인자에 전달합니다.
+
+그런 다음 [학습 스크립트](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion_flax.py)를 시작할 수 있습니다.
+
+```bash
+export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
+export DATA_DIR="./cat"
+
+python textual_inversion_flax.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_data_dir=$DATA_DIR \
+ --learnable_property="object" \
+ --placeholder_token="<cat-toy>" --initializer_token="toy" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --max_train_steps=3000 \
+ --learning_rate=5.0e-04 --scale_lr \
+ --output_dir="textual_inversion_cat" \
+ --push_to_hub
+```
+
+
+
+### 중간 로깅
+
+모델의 학습 진행 상황을 추적하는 데 관심이 있는 경우, 학습 과정에서 생성된 이미지를 저장할 수 있습니다. 학습 스크립트에 다음 인수를 추가하여 중간 로깅을 활성화합니다.
+
+- `validation_prompt` : 샘플을 생성하는 데 사용되는 프롬프트(기본값은 `None`으로 설정되며, 이 때 중간 로깅은 비활성화됨)
+- `num_validation_images` : 생성할 샘플 이미지 수
+- `validation_steps` : `validation_prompt`로부터 샘플 이미지를 생성하기 전 스텝의 수
+
+```bash
+--validation_prompt="A <cat-toy> backpack"
+--num_validation_images=4
+--validation_steps=100
+```
+
+## 추론
+
+모델을 학습한 후에는, 해당 모델을 [`StableDiffusionPipeline`]을 사용하여 추론에 사용할 수 있습니다.
+
+textual-inversion 스크립트는 기본적으로 textual-inversion을 통해 얻어진 임베딩 벡터만을 저장합니다. 해당 임베딩 벡터들은 텍스트 인코더의 임베딩 행렬에 추가되어 있습니다.
+
+
+
+
+
+💡 커뮤니티는 [sd-concepts-library](https://huggingface.co/sd-concepts-library) 라는 대규모의 textual-inversion 임베딩 벡터 라이브러리를 만들었습니다. textual-inversion 임베딩을 밑바닥부터 학습하는 대신, 해당 라이브러리에 본인이 찾는 textual-inversion 임베딩이 이미 추가되어 있지 않은지를 확인하는 것도 좋은 방법이 될 것 같습니다.
+
+
+
+textual-inversion 임베딩 벡터를 불러오기 위해서는, 먼저 해당 임베딩 벡터를 학습할 때 사용한 모델을 불러와야 합니다. 여기서는 [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) 모델이 사용되었다고 가정하고 불러오겠습니다.
+
+```python
+from diffusers import StableDiffusionPipeline
+import torch
+
+model_id = "runwayml/stable-diffusion-v1-5"
+pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+```
+
+다음으로 `TextualInversionLoaderMixin.load_textual_inversion` 함수를 통해 textual-inversion 임베딩 벡터를 불러와야 합니다. 여기서는 이전의 `<cat-toy>` 예제의 임베딩을 불러올 것입니다.
+
+```python
+pipe.load_textual_inversion("sd-concepts-library/cat-toy")
+```
+
+이제 플레이스홀더 토큰(`<cat-toy>`)이 잘 동작하는지 확인하는 파이프라인을 실행할 수 있습니다.
+
+```python
+prompt = "A backpack"
+
+image = pipe(prompt, num_inference_steps=50).images[0]
+image.save("cat-backpack.png")
+```
+
+`TextualInversionLoaderMixin.load_textual_inversion`은 Diffusers 형식으로 저장된 텍스트 임베딩 벡터를 로드할 수 있을 뿐만 아니라, [Automatic1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) 형식으로 저장된 임베딩 벡터도 로드할 수 있습니다. 이렇게 하려면, 먼저 [civitAI](https://civitai.com/models/3036?modelVersionId=8387)에서 임베딩 벡터를 다운로드한 다음 로컬에서 불러와야 합니다.
+
+```python
+pipe.load_textual_inversion("./charturnerv2.pt")
+```
+
+
+
+현재 Flax에 대한 `load_textual_inversion` 함수는 없습니다. 따라서 학습 후 textual-inversion 임베딩 벡터가 모델의 일부로서 저장되었는지를 확인해야 합니다. 그런 다음은 다른 Flax 모델과 마찬가지로 실행할 수 있습니다.
+
+```python
+import jax
+import numpy as np
+from flax.jax_utils import replicate
+from flax.training.common_utils import shard
+from diffusers import FlaxStableDiffusionPipeline
+
+model_path = "path-to-your-trained-model"
+pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(model_path, dtype=jax.numpy.bfloat16)
+
+prompt = "A backpack"
+prng_seed = jax.random.PRNGKey(0)
+num_inference_steps = 50
+
+num_samples = jax.device_count()
+prompt = num_samples * [prompt]
+prompt_ids = pipeline.prepare_inputs(prompt)
+
+# shard inputs and rng
+params = replicate(params)
+prng_seed = jax.random.split(prng_seed, jax.device_count())
+prompt_ids = shard(prompt_ids)
+
+images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
+images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
+image.save("cat-backpack.png")
+```
+
+
+
+## 작동 방식
+
+![Diagram from the paper showing overview](https://textual-inversion.github.io/static/images/training/training.JPG)
+Architecture overview from the Textual Inversion blog post.
+
+일반적으로 텍스트 프롬프트는 모델에 전달되기 전에 임베딩으로 토큰화됩니다. textual-inversion은 비슷한 작업을 수행하지만, 위 다이어그램의 특수 토큰 `S*`로부터 새로운 토큰 임베딩 `v*`를 학습합니다. 모델의 아웃풋은 디퓨전 모델을 조정하는 데 사용되며, 디퓨전 모델이 단 몇 개의 예제 이미지만으로 새로운 콘셉트를 신속하게 이해하도록 돕습니다.
+
+이를 위해 textual-inversion은 제너레이터 모델과 학습용 이미지의 노이즈 버전을 사용합니다. 제너레이터는 노이즈가 적은 버전의 이미지를 예측하려고 시도하며, 토큰 임베딩 `v*`은 제너레이터의 성능에 따라 최적화됩니다. 토큰 임베딩이 새로운 콘셉트를 성공적으로 포착하면 디퓨전 모델에 더 유용한 정보를 제공하고 노이즈가 적은 더 선명한 이미지를 생성하는 데 도움이 됩니다. 이러한 최적화 프로세스는 일반적으로 다양한 프롬프트와 이미지에 수천 번 노출됨으로써 이루어집니다.
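+
+아래는 이 아이디어의 핵심(새 토큰을 추가하고 해당 토큰의 임베딩만 학습 대상으로 삼는 것)을 보여주는 최소한의 스케치입니다. 실제 `textual_inversion.py` 스크립트는 이보다 훨씬 많은 처리를 수행하며, 모델 ID와 토큰 이름은 위 예시를 따른 가정입니다:
+
+```python
+from transformers import CLIPTextModel, CLIPTokenizer
+
+model_id = "runwayml/stable-diffusion-v1-5"
+tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
+text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
+
+# 새 플레이스홀더 토큰을 추가하고 텍스트 인코더의 임베딩 행렬을 확장합니다.
+tokenizer.add_tokens("<cat-toy>")
+text_encoder.resize_token_embeddings(len(tokenizer))
+
+# 새 토큰의 임베딩을 초기화 토큰("toy")의 임베딩으로 초기화합니다.
+token_id = tokenizer.convert_tokens_to_ids("<cat-toy>")
+init_id = tokenizer.encode("toy", add_special_tokens=False)[0]
+embeddings = text_encoder.get_input_embeddings().weight.data
+embeddings[token_id] = embeddings[init_id].clone()
+
+# 학습 시에는 이 임베딩 행렬에서 새 토큰에 해당하는 행만 최적화 대상(v*)이 됩니다.
+```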
+
diff --git a/diffusers/docs/source/ko/training/unconditional_training.md b/diffusers/docs/source/ko/training/unconditional_training.md
new file mode 100644
index 0000000000000000000000000000000000000000..62c846311114a08d15b05994a6694ad44d16542e
--- /dev/null
+++ b/diffusers/docs/source/ko/training/unconditional_training.md
@@ -0,0 +1,144 @@
+
+
+# Unconditional 이미지 생성
+
+unconditional 이미지 생성은 text-to-image 또는 image-to-image 모델과 달리 텍스트나 이미지에 대한 조건이 없이 학습 데이터 분포와 유사한 이미지만을 생성합니다.
+
+
+
+
+이 가이드에서는 기존에 존재하던 데이터셋과 자신만의 커스텀 데이터셋에 대해 unconditional image generation 모델을 훈련하는 방법을 설명합니다. 훈련 세부 사항에 대해 더 자세히 알고 싶다면 unconditional image generation을 위한 모든 학습 스크립트를 [여기](https://github.com/huggingface/diffusers/tree/main/examples/unconditional_image_generation)에서 확인할 수 있습니다.
+
+스크립트를 실행하기 전, 먼저 의존성 라이브러리들을 설치해야 합니다.
+
+```bash
+pip install diffusers[training] accelerate datasets
+```
+
+그 다음 🤗 [Accelerate](https://github.com/huggingface/accelerate/) 환경을 초기화합니다.
+
+```bash
+accelerate config
+```
+
+별도의 설정 없이 기본 설정으로 🤗 [Accelerate](https://github.com/huggingface/accelerate/) 환경을 초기화해봅시다.
+
+```bash
+accelerate config default
+```
+
+노트북과 같은 대화형 쉘을 지원하지 않는 환경의 경우, 다음과 같이 사용해볼 수도 있습니다.
+
+```py
+from accelerate.utils import write_basic_config
+
+write_basic_config()
+```
+
+## 모델을 허브에 업로드하기
+
+학습 스크립트에 다음 인자를 추가하여 허브에 모델을 업로드할 수 있습니다.
+
+```bash
+--push_to_hub
+```
+
+## 체크포인트 저장하고 불러오기
+
+훈련 중 문제가 발생할 경우를 대비하여 체크포인트를 정기적으로 저장하는 것이 좋습니다. 체크포인트를 저장하려면 학습 스크립트에 다음 인자를 전달합니다:
+
+```bash
+--checkpointing_steps=500
+```
+
+전체 훈련 상태는 500스텝마다 `output_dir`의 하위 폴더에 저장되며, 학습 스크립트에 `--resume_from_checkpoint` 인자를 전달함으로써 체크포인트를 불러오고 훈련을 재개할 수 있습니다.
+
+```bash
+--resume_from_checkpoint="checkpoint-1500"
+```
+
+## 파인튜닝
+
+이제 학습 스크립트를 시작할 준비가 되었습니다! `--dataset_name` 인자에 파인튜닝할 데이터셋 이름을 지정한 다음, `--output_dir` 인자에 지정된 경로에 저장합니다. 본인만의 데이터셋을 사용하려면, [학습용 데이터셋 만들기](create_dataset) 가이드를 참조하세요.
+
+학습 스크립트는 `diffusion_pytorch_model.bin` 파일을 생성하고, 그것을 당신의 리포지토리에 저장합니다.
+
+
+
+💡 전체 학습은 V100 GPU 4개를 사용할 경우, 2시간이 소요됩니다.
+
+
+
+예를 들어, [Oxford Flowers](https://huggingface.co/datasets/huggan/flowers-102-categories) 데이터셋을 사용해 파인튜닝할 경우:
+
+```bash
+accelerate launch train_unconditional.py \
+ --dataset_name="huggan/flowers-102-categories" \
+ --resolution=64 \
+ --output_dir="ddpm-ema-flowers-64" \
+ --train_batch_size=16 \
+ --num_epochs=100 \
+ --gradient_accumulation_steps=1 \
+ --learning_rate=1e-4 \
+ --lr_warmup_steps=500 \
+ --mixed_precision=no \
+ --push_to_hub
+```
+
+
+
+### 여러개의 GPU로 훈련하기
+
+`accelerate`을 사용하면 원활한 다중 GPU 훈련이 가능합니다. `accelerate`을 사용하여 분산 훈련을 실행하려면 [여기](https://huggingface.co/docs/accelerate/basic_tutorials/launch) 지침을 따르세요. 다음은 명령어 예제입니다.
+
+```bash
+accelerate launch --mixed_precision="fp16" --multi_gpu train_unconditional.py \
+ --dataset_name="huggan/pokemon" \
+ --resolution=64 --center_crop --random_flip \
+ --output_dir="ddpm-ema-pokemon-64" \
+ --train_batch_size=16 \
+ --num_epochs=100 \
+ --gradient_accumulation_steps=1 \
+ --use_ema \
+ --learning_rate=1e-4 \
+ --lr_warmup_steps=500 \
+ --mixed_precision="fp16" \
+ --logger="wandb" \
+ --push_to_hub
+```
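+
+학습이 끝나면 저장된 파이프라인으로 이미지를 생성해 볼 수 있습니다. 다음은 위 예시의 출력 경로를 그대로 사용한다고 가정한 간단한 추론 스케치입니다:
+
+```python
+from diffusers import DDPMPipeline
+
+# 가정: 위 학습 명령의 --output_dir(또는 허브에 업로드된 리포지토리 ID)를 사용합니다.
+pipeline = DDPMPipeline.from_pretrained("ddpm-ema-flowers-64").to("cuda")
+
+image = pipeline().images[0]
+image.save("flowers-sample.png")
+```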
diff --git a/diffusers/docs/source/ko/tutorials/basic_training.md b/diffusers/docs/source/ko/tutorials/basic_training.md
new file mode 100644
index 0000000000000000000000000000000000000000..1cc82d2b8ce6a63ba6e7008aecbda80c0839bd87
--- /dev/null
+++ b/diffusers/docs/source/ko/tutorials/basic_training.md
@@ -0,0 +1,402 @@
+
+
+[[open-in-colab]]
+
+
+# Diffusion 모델을 학습하기
+
+Unconditional 이미지 생성은 학습에 사용된 데이터셋과 유사한 이미지를 생성하는 diffusion 모델에서 인기 있는 어플리케이션입니다. 일반적으로, 가장 좋은 결과는 특정 데이터셋에 사전 훈련된 모델을 파인튜닝하는 것으로 얻을 수 있습니다. 이 [허브](https://huggingface.co/search/full-text?q=unconditional-image-generation&type=model)에서 이러한 많은 체크포인트를 찾을 수 있지만, 만약 마음에 드는 체크포인트를 찾지 못했다면, 언제든지 스스로 학습할 수 있습니다!
+
+이 튜토리얼은 나만의 🦋 나비 🦋를 생성하기 위해 [Smithsonian Butterflies](https://huggingface.co/datasets/huggan/smithsonian_butterflies_subset) 데이터셋의 하위 집합에서 [`UNet2DModel`] 모델을 학습하는 방법을 가르쳐줄 것입니다.
+
+
+
+💡 이 학습 튜토리얼은 [Training with 🧨 Diffusers](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) 노트북 기반으로 합니다. Diffusion 모델의 작동 방식 및 자세한 내용은 노트북을 확인하세요!
+
+
+
+시작 전에, 데이터셋을 불러오고 전처리하기 위한 🤗 Datasets와 다수의 GPU에서 학습을 간소화해 주는 🤗 Accelerate가 설치되어 있는지 확인하세요. 그 후 학습 메트릭을 시각화하기 위해 [TensorBoard](https://www.tensorflow.org/tensorboard)도 설치하세요. (학습 추적을 위해 [Weights & Biases](https://docs.wandb.ai/)를 사용할 수도 있습니다.)
+
+```bash
+!pip install diffusers[training]
+```
+
+커뮤니티에 모델을 공유할 것을 권장하며, 이를 위해서 Hugging Face 계정에 로그인을 해야 합니다. (계정이 없다면 [여기](https://hf.co/join)에서 만들 수 있습니다.) 노트북에서 로그인할 수 있으며 메시지가 표시되면 토큰을 입력할 수 있습니다.
+
+```py
+>>> from huggingface_hub import notebook_login
+
+>>> notebook_login()
+```
+
+또는 터미널로 로그인할 수 있습니다:
+
+```bash
+huggingface-cli login
+```
+
+모델 체크포인트가 상당히 크기 때문에, 이러한 대용량 파일의 버전 관리를 위해 [Git-LFS](https://git-lfs.com/)를 설치하세요:
+
+```bash
+!sudo apt -qq install git-lfs
+!git config --global credential.helper store
+```
+
+
+## 학습 구성
+
+편의를 위해 학습 파라미터들을 포함한 `TrainingConfig` 클래스를 생성합니다 (자유롭게 조정 가능):
+
+```py
+>>> from dataclasses import dataclass
+
+
+>>> @dataclass
+... class TrainingConfig:
+... image_size = 128 # 생성되는 이미지 해상도
+... train_batch_size = 16
+... eval_batch_size = 16 # 평가 동안에 샘플링할 이미지 수
+... num_epochs = 50
+... gradient_accumulation_steps = 1
+... learning_rate = 1e-4
+... lr_warmup_steps = 500
+... save_image_epochs = 10
+... save_model_epochs = 30
+... mixed_precision = "fp16" # `no`는 float32, 자동 혼합 정밀도를 위한 `fp16`
+... output_dir = "ddpm-butterflies-128" # 로컬 및 HF Hub에 저장되는 모델명
+
+... push_to_hub = True # 저장된 모델을 HF Hub에 업로드할지 여부
+... hub_private_repo = False
+... overwrite_output_dir = True # 노트북을 다시 실행할 때 이전 모델에 덮어씌울지
+... seed = 0
+
+
+>>> config = TrainingConfig()
+```
+
+
+## 데이터셋 불러오기
+
+🤗 Datasets 라이브러리와 [Smithsonian Butterflies](https://huggingface.co/datasets/huggan/smithsonian_butterflies_subset) 데이터셋을 쉽게 불러올 수 있습니다.
+
+```py
+>>> from datasets import load_dataset
+
+>>> config.dataset_name = "huggan/smithsonian_butterflies_subset"
+>>> dataset = load_dataset(config.dataset_name, split="train")
+```
+
+💡[HugGan Community Event](https://huggingface.co/huggan) 에서 추가의 데이터셋을 찾거나 로컬의 [`ImageFolder`](https://huggingface.co/docs/datasets/image_dataset#imagefolder)를 만듦으로써 나만의 데이터셋을 사용할 수 있습니다. HugGan Community Event 에 가져온 데이터셋의 경우 리포지토리의 id로 `config.dataset_name` 을 설정하고, 나만의 이미지를 사용하는 경우 `imagefolder` 를 설정합니다.
+
+🤗 Datasets은 [`~datasets.Image`] 기능을 사용해 자동으로 이미지 데이터를 디코딩하고 [`PIL.Image`](https://pillow.readthedocs.io/en/stable/reference/Image.html)로 불러옵니다. 이를 시각화 해보면:
+
+```py
+>>> import matplotlib.pyplot as plt
+
+>>> fig, axs = plt.subplots(1, 4, figsize=(16, 4))
+>>> for i, image in enumerate(dataset[:4]["image"]):
+... axs[i].imshow(image)
+... axs[i].set_axis_off()
+>>> fig.show()
+```
+
+![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/butterflies_ds.png)
+
+이미지는 모두 다른 사이즈이기 때문에, 우선 전처리가 필요합니다:
+
+- `Resize` 는 `config.image_size` 에 정의된 이미지 사이즈로 변경합니다.
+- `RandomHorizontalFlip` 은 랜덤적으로 이미지를 미러링하여 데이터셋을 보강합니다.
+- `Normalize` 는 모델이 예상하는 [-1, 1] 범위로 픽셀 값을 재조정 하는데 중요합니다.
+
+```py
+>>> from torchvision import transforms
+
+>>> preprocess = transforms.Compose(
+... [
+... transforms.Resize((config.image_size, config.image_size)),
+... transforms.RandomHorizontalFlip(),
+... transforms.ToTensor(),
+... transforms.Normalize([0.5], [0.5]),
+... ]
+... )
+```
+
+학습 도중에 `preprocess` 함수를 적용하려면 🤗 Datasets의 [`~datasets.Dataset.set_transform`] 메서드를 사용합니다.
+
+```py
+>>> def transform(examples):
+... images = [preprocess(image.convert("RGB")) for image in examples["image"]]
+... return {"images": images}
+
+
+>>> dataset.set_transform(transform)
+```
+
+이미지의 크기가 조정되었는지 확인하기 위해 이미지를 다시 시각화해보세요. 이제 [DataLoader](https://pytorch.org/docs/stable/data#torch.utils.data.DataLoader)에 데이터셋을 포함해 학습할 준비가 되었습니다!
+
+```py
+>>> import torch
+
+>>> train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)
+```
+
+
+## UNet2DModel 생성하기
+
+🧨 Diffusers에 사전학습된 모델들은 모델 클래스에서 원하는 파라미터로 쉽게 생성할 수 있습니다. 예를 들어, [`UNet2DModel`]를 생성하려면:
+
+```py
+>>> from diffusers import UNet2DModel
+
+>>> model = UNet2DModel(
+... sample_size=config.image_size, # 타겟 이미지 해상도
+... in_channels=3, # 입력 채널 수, RGB 이미지에서 3
+... out_channels=3, # 출력 채널 수
+... layers_per_block=2, # UNet 블럭당 몇 개의 ResNet 레이어가 사용되는지
+... block_out_channels=(128, 128, 256, 256, 512, 512), # 각 UNet 블럭을 위한 출력 채널 수
+... down_block_types=(
+... "DownBlock2D", # 일반적인 ResNet 다운샘플링 블럭
+... "DownBlock2D",
+... "DownBlock2D",
+... "DownBlock2D",
+... "AttnDownBlock2D", # spatial self-attention이 포함된 일반적인 ResNet 다운샘플링 블럭
+... "DownBlock2D",
+... ),
+... up_block_types=(
+... "UpBlock2D", # 일반적인 ResNet 업샘플링 블럭
+... "AttnUpBlock2D", # spatial self-attention이 포함된 일반적인 ResNet 업샘플링 블럭
+... "UpBlock2D",
+... "UpBlock2D",
+... "UpBlock2D",
+... "UpBlock2D",
+... ),
+... )
+```
+
+샘플 이미지의 크기와 모델 출력 크기가 일치하는지 빠르게 확인해 보는 것이 좋습니다:
+
+```py
+>>> sample_image = dataset[0]["images"].unsqueeze(0)
+>>> print("Input shape:", sample_image.shape)
+Input shape: torch.Size([1, 3, 128, 128])
+
+>>> print("Output shape:", model(sample_image, timestep=0).sample.shape)
+Output shape: torch.Size([1, 3, 128, 128])
+```
+
+훌륭해요! 다음, 이미지에 약간의 노이즈를 더하기 위해 스케줄러가 필요합니다.
+
+
+## 스케줄러 생성하기
+
+스케줄러는 모델을 학습 또는 추론에 사용하는지에 따라 다르게 작동합니다. 추론시에, 스케줄러는 노이즈로부터 이미지를 생성합니다. 학습시 스케줄러는 diffusion 과정에서의 특정 포인트로부터 모델의 출력 또는 샘플을 가져와 *노이즈 스케줄* 과 *업데이트 규칙*에 따라 이미지에 노이즈를 적용합니다.
+
+`DDPMScheduler`를 살펴보고, `add_noise` 메서드를 사용해 앞서 만든 `sample_image`에 랜덤한 노이즈를 더해 보겠습니다:
+
+```py
+>>> import torch
+>>> from PIL import Image
+>>> from diffusers import DDPMScheduler
+
+>>> noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
+>>> noise = torch.randn(sample_image.shape)
+>>> timesteps = torch.LongTensor([50])
+>>> noisy_image = noise_scheduler.add_noise(sample_image, noise, timesteps)
+
+>>> Image.fromarray(((noisy_image.permute(0, 2, 3, 1) + 1.0) * 127.5).type(torch.uint8).numpy()[0])
+```
+
+![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/noisy_butterfly.png)
+
+모델의 학습 목적은 이미지에 더해진 노이즈를 예측하는 것입니다. 이 단계에서 손실은 다음과 같이 계산될 수 있습니다:
+
+```py
+>>> import torch.nn.functional as F
+
+>>> noise_pred = model(noisy_image, timesteps).sample
+>>> loss = F.mse_loss(noise_pred, noise)
+```
+
+## 모델 학습하기
+
+지금까지, 모델 학습을 시작하기 위해 많은 부분을 갖추었으며 이제 남은 것은 모든 것을 조합하는 것입니다.
+
+우선 옵티마이저(optimizer)와 학습률 스케줄러(learning rate scheduler)가 필요할 것입니다:
+
+```py
+>>> from diffusers.optimization import get_cosine_schedule_with_warmup
+
+>>> optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
+>>> lr_scheduler = get_cosine_schedule_with_warmup(
+... optimizer=optimizer,
+... num_warmup_steps=config.lr_warmup_steps,
+... num_training_steps=(len(train_dataloader) * config.num_epochs),
+... )
+```
+
+그 후, 모델을 평가하는 방법이 필요합니다. 평가를 위해, `DDPMPipeline`을 사용해 배치의 이미지 샘플들을 생성하고 그리드 형태로 저장할 수 있습니다:
+
+```py
+>>> from diffusers import DDPMPipeline
+>>> import math
+>>> import os
+
+
+>>> def make_grid(images, rows, cols):
+... w, h = images[0].size
+... grid = Image.new("RGB", size=(cols * w, rows * h))
+... for i, image in enumerate(images):
+... grid.paste(image, box=(i % cols * w, i // cols * h))
+... return grid
+
+
+>>> def evaluate(config, epoch, pipeline):
+...     # 랜덤한 노이즈로부터 이미지를 추출합니다. (이는 역방향 diffusion 과정입니다.)
+... # 기본 파이프라인 출력 형태는 `List[PIL.Image]` 입니다.
+... images = pipeline(
+... batch_size=config.eval_batch_size,
+... generator=torch.manual_seed(config.seed),
+... ).images
+
+... # 이미지들을 그리드로 만들어줍니다.
+... image_grid = make_grid(images, rows=4, cols=4)
+
+... # 이미지들을 저장합니다.
+... test_dir = os.path.join(config.output_dir, "samples")
+... os.makedirs(test_dir, exist_ok=True)
+... image_grid.save(f"{test_dir}/{epoch:04d}.png")
+```
+
+이제 TensorBoard 로깅, 그래디언트 누적, 혼합 정밀도 학습을 쉽게 수행할 수 있도록 🤗 Accelerate를 사용해 앞서 만든 모든 구성 요소를 하나의 학습 루프로 묶을 수 있습니다. 허브에 모델을 업로드하려면, 리포지토리 이름과 정보를 가져오는 함수를 작성한 뒤 허브에 업로드하면 됩니다.
+
+💡아래의 학습 루프는 어렵고 길어 보일 수 있지만, 나중에 한 줄의 코드로 학습을 한다면 그만한 가치가 있을 것입니다! 만약 기다리지 못하고 이미지를 생성하고 싶다면, 아래 코드를 자유롭게 붙여넣고 작동시키면 됩니다. 🤗
+
+```py
+>>> from accelerate import Accelerator
+>>> from huggingface_hub import create_repo, upload_folder
+>>> from tqdm.auto import tqdm
+>>> from pathlib import Path
+>>> import os
+
+
+>>> def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
+... # Initialize accelerator and tensorboard logging
+... accelerator = Accelerator(
+... mixed_precision=config.mixed_precision,
+... gradient_accumulation_steps=config.gradient_accumulation_steps,
+... log_with="tensorboard",
+... project_dir=os.path.join(config.output_dir, "logs"),
+... )
+... if accelerator.is_main_process:
+... if config.output_dir is not None:
+... os.makedirs(config.output_dir, exist_ok=True)
+... if config.push_to_hub:
+... repo_id = create_repo(
+... repo_id=config.hub_model_id or Path(config.output_dir).name, exist_ok=True
+... ).repo_id
+... accelerator.init_trackers("train_example")
+
+... # 모든 것이 준비되었습니다.
+... # 기억해야 할 특정한 순서는 없으며 준비한 방법에 제공한 것과 동일한 순서로 객체의 압축을 풀면 됩니다.
+... model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+... model, optimizer, train_dataloader, lr_scheduler
+... )
+
+... global_step = 0
+
+... # 이제 모델을 학습합니다.
+... for epoch in range(config.num_epochs):
+... progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
+... progress_bar.set_description(f"Epoch {epoch}")
+
+... for step, batch in enumerate(train_dataloader):
+... clean_images = batch["images"]
+... # 이미지에 더할 노이즈를 샘플링합니다.
+... noise = torch.randn(clean_images.shape, device=clean_images.device)
+... bs = clean_images.shape[0]
+
+... # 각 이미지를 위한 랜덤한 타임스텝(timestep)을 샘플링합니다.
+... timesteps = torch.randint(
+... 0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device,
+... dtype=torch.int64
+... )
+
+...             # 각 타임스텝의 노이즈 크기에 따라 깨끗한 이미지에 노이즈를 추가합니다.
+...             # (이는 forward diffusion 과정입니다.)
+... noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
+
+... with accelerator.accumulate(model):
+... # 노이즈를 반복적으로 예측합니다.
+...                 # 노이즈 잔차(residual)를 예측합니다.
+... loss = F.mse_loss(noise_pred, noise)
+... accelerator.backward(loss)
+
+... accelerator.clip_grad_norm_(model.parameters(), 1.0)
+... optimizer.step()
+... lr_scheduler.step()
+... optimizer.zero_grad()
+
+... progress_bar.update(1)
+... logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
+... progress_bar.set_postfix(**logs)
+... accelerator.log(logs, step=global_step)
+... global_step += 1
+
+... # 각 에포크가 끝난 후 evaluate()와 몇 가지 데모 이미지를 선택적으로 샘플링하고 모델을 저장합니다.
+... if accelerator.is_main_process:
+... pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
+
+... if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:
+... evaluate(config, epoch, pipeline)
+
+... if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
+... if config.push_to_hub:
+... upload_folder(
+... repo_id=repo_id,
+... folder_path=config.output_dir,
+... commit_message=f"Epoch {epoch}",
+... ignore_patterns=["step_*", "epoch_*"],
+... )
+... else:
+... pipeline.save_pretrained(config.output_dir)
+```
+
+휴, 코드가 꽤 많았네요! 하지만 이제 🤗 Accelerate의 [`~accelerate.notebook_launcher`] 함수로 학습을 시작할 준비가 되었습니다. 함수에 학습 루프, 모든 학습 인수, 학습에 사용할 프로세스 수(사용 가능한 GPU 수로 변경할 수 있음)를 전달합니다:
+
+```py
+>>> from accelerate import notebook_launcher
+
+>>> args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)
+
+>>> notebook_launcher(train_loop, args, num_processes=1)
+```
+
+한번 학습이 완료되면, diffusion 모델로 생성된 최종 🦋이미지🦋를 확인해보길 바랍니다!
+
+```py
+>>> import glob
+
+>>> sample_images = sorted(glob.glob(f"{config.output_dir}/samples/*.png"))
+>>> Image.open(sample_images[-1])
+```
+
+![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/butterflies_final.png)
+
+## 다음 단계
+
+Unconditional 이미지 생성은 학습될 수 있는 작업 중 하나의 예시입니다. 다른 작업과 학습 방법은 [🧨 Diffusers 학습 예시](../training/overview) 페이지에서 확인할 수 있습니다. 다음은 학습할 수 있는 몇 가지 예시입니다:
+
+- [Textual Inversion](../training/text_inversion), 특정 시각적 개념을 학습시켜 생성된 이미지에 통합시키는 알고리즘입니다.
+- [DreamBooth](../training/dreambooth), 주제에 대한 몇 가지 입력 이미지들이 주어지면 주제에 대한 개인화된 이미지를 생성하기 위한 기술입니다.
+- [Guide](../training/text2image) 데이터셋에 Stable Diffusion 모델을 파인튜닝하는 방법입니다.
+- [Guide](../training/lora) LoRA를 사용해 매우 큰 모델을 빠르게 파인튜닝하기 위한 메모리 효율적인 기술입니다.
diff --git a/diffusers/docs/source/ko/tutorials/tutorial_overview.md b/diffusers/docs/source/ko/tutorials/tutorial_overview.md
new file mode 100644
index 0000000000000000000000000000000000000000..bf9cf39f64e6206c3a10d24f004b0b0368df4028
--- /dev/null
+++ b/diffusers/docs/source/ko/tutorials/tutorial_overview.md
@@ -0,0 +1,23 @@
+
+
+# Overview
+
+🧨 Diffusers에 오신 걸 환영합니다! 여러분이 diffusion 모델과 생성 AI를 처음 접하고 더 많은 걸 배우고 싶다면 제대로 찾아오셨습니다. 이 튜토리얼은 diffusion 모델을 차근차근 소개하고, 라이브러리의 기본 사항(핵심 구성요소와 🧨 Diffusers 사용법)을 이해하는 데 도움이 되도록 설계되었습니다.
+
+여러분은 이 튜토리얼을 통해 빠르게 생성하기 위해선 추론 파이프라인을 어떻게 사용해야 하는지, 그리고 라이브러리를 modular toolbox처럼 이용해서 여러분만의 diffusion system을 구축할 수 있도록 파이프라인을 분해하는 법을 배울 수 있습니다. 다음 단원에서는 여러분이 원하는 것을 생성하기 위해 자신만의 diffusion model을 학습하는 방법을 배우게 됩니다.
+
+튜토리얼을 완료한다면 여러분은 라이브러리를 직접 탐색하고, 자신의 프로젝트와 애플리케이션에 적용할 스킬들을 습득할 수 있을 겁니다.
+
+[Discord](https://discord.com/invite/JfAtkvEtRb)나 [포럼](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) 커뮤니티에 자유롭게 참여해서 다른 사용자와 개발자들과 교류하고 협업해 보세요!
+
+자 지금부터 diffusing을 시작해 보겠습니다! 🧨
\ No newline at end of file
diff --git a/diffusers/docs/source/ko/using-diffusers/conditional_image_generation.md b/diffusers/docs/source/ko/using-diffusers/conditional_image_generation.md
new file mode 100644
index 0000000000000000000000000000000000000000..5525ac990ca457bc5040c313e0a3d9aad0abdc46
--- /dev/null
+++ b/diffusers/docs/source/ko/using-diffusers/conditional_image_generation.md
@@ -0,0 +1,60 @@
+
+
+# 조건부 이미지 생성
+
+[[open-in-colab]]
+
+조건부 이미지 생성을 사용하면 텍스트 프롬프트에서 이미지를 생성할 수 있습니다. 텍스트는 임베딩으로 변환되며, 임베딩은 노이즈에서 이미지를 생성하도록 모델을 조건화하는 데 사용됩니다.
+
+[`DiffusionPipeline`]은 추론을 위해 사전 훈련된 diffusion 시스템을 사용하는 가장 쉬운 방법입니다.
+
+먼저 [`DiffusionPipeline`]의 인스턴스를 생성하고 다운로드할 파이프라인 [체크포인트](https://huggingface.co/models?library=diffusers&sort=downloads)를 지정합니다.
+
+이 가이드에서는 [잠재 Diffusion](https://huggingface.co/CompVis/ldm-text2im-large-256)과 함께 텍스트-이미지 생성에 [`DiffusionPipeline`]을 사용합니다:
+
+```python
+>>> from diffusers import DiffusionPipeline
+
+>>> generator = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
+```
+
+[`DiffusionPipeline`]은 모든 모델링, 토큰화, 스케줄링 구성 요소를 다운로드하고 캐시합니다.
+이 모델은 약 14억 개의 파라미터로 구성되어 있기 때문에 GPU에서 실행할 것을 강력히 권장합니다.
+PyTorch에서와 마찬가지로 생성기 객체를 GPU로 이동할 수 있습니다:
+
+```python
+>>> generator.to("cuda")
+```
+
+이제 텍스트 프롬프트에 대해 `generator`를 사용할 수 있습니다:
+
+```python
+>>> image = generator("An image of a squirrel in Picasso style").images[0]
+```
+
+출력값은 기본적으로 [`PIL.Image`](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class) 객체로 래핑됩니다.
+
+다음을 호출하여 이미지를 저장할 수 있습니다:
+
+```python
+>>> image.save("image_of_squirrel_painting.png")
+```
+
+스페이스(Space) 데모를 사용해 보고 guidance scale 매개변수를 자유롭게 조정하여 이미지 품질에 어떤 영향을 미치는지 확인해 보세요!
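+
+예를 들어, 다음은 시드를 고정하고 `guidance_scale` 값을 조정해 보는 간단한 스케치입니다(위에서 만든 `generator` 파이프라인을 재사용한다고 가정합니다):
+
+```python
+import torch
+
+# 시드를 고정하면 동일한 프롬프트로 결과를 재현할 수 있습니다.
+seed = torch.Generator("cuda").manual_seed(0)
+
+image = generator(
+    "An image of a squirrel in Picasso style",
+    num_inference_steps=50,
+    guidance_scale=7.5,  # 값이 클수록 프롬프트를 더 충실히 따르는 경향이 있습니다.
+    generator=seed,
+).images[0]
+image.save("image_of_squirrel_painting_seeded.png")
+```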
+
+
\ No newline at end of file
diff --git a/diffusers/docs/source/ko/using-diffusers/contribute_pipeline.md b/diffusers/docs/source/ko/using-diffusers/contribute_pipeline.md
new file mode 100644
index 0000000000000000000000000000000000000000..415d3da1a10d4ed5bd2ad261287c5d761c865a15
--- /dev/null
+++ b/diffusers/docs/source/ko/using-diffusers/contribute_pipeline.md
@@ -0,0 +1,182 @@
+
+
+# 커뮤니티 파이프라인에 기여하는 방법
+
+
+
+💡 모든 사람이 속도 저하 없이 쉽게 작업을 공유할 수 있도록 커뮤니티 파이프라인을 추가하는 이유에 대한 자세한 내용은 GitHub 이슈 [#841](https://github.com/huggingface/diffusers/issues/841)를 참조하세요.
+
+
+
+커뮤니티 파이프라인을 사용하면 [`DiffusionPipeline`] 위에 원하는 추가 기능을 추가할 수 있습니다. `DiffusionPipeline` 위에 구축할 때의 가장 큰 장점은 누구나 인수를 하나만 추가하면 파이프라인을 로드하고 사용할 수 있어 커뮤니티가 매우 쉽게 접근할 수 있다는 것입니다.
+
+이번 가이드에서는 커뮤니티 파이프라인을 생성하는 방법과 작동 원리를 설명합니다.
+간단하게 설명하기 위해 `UNet`이 단일 forward pass를 수행하고 스케줄러를 한 번 호출하는 "one-step" 파이프라인을 만들겠습니다.
+
+## 파이프라인 초기화
+
+커뮤니티 파이프라인을 위한 `one_step_unet.py` 파일을 생성하는 것으로 시작합니다. 이 파일에서, Hub에서 모델 가중치와 스케줄러 구성을 로드할 수 있도록 [`DiffusionPipeline`]을 상속하는 파이프라인 클래스를 생성합니다. one-step 파이프라인에는 `UNet`과 스케줄러가 필요하므로 이를 `__init__` 함수에 인수로 추가해야합니다:
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+
+class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
+ def __init__(self, unet, scheduler):
+ super().__init__()
+```
+
+파이프라인과 그 구성요소(`unet` and `scheduler`)를 [`~DiffusionPipeline.save_pretrained`]으로 저장할 수 있도록 하려면 `register_modules` 함수에 추가하세요:
+
+```diff
+ from diffusers import DiffusionPipeline
+ import torch
+
+ class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
+ def __init__(self, unet, scheduler):
+ super().__init__()
+
++ self.register_modules(unet=unet, scheduler=scheduler)
+```
+
+이제 '초기화' 단계가 완료되었으니 forward pass로 이동할 수 있습니다! 🔥
+
+## Forward pass 정의
+
+Forward pass 에서는(`__call__`로 정의하는 것이 좋습니다) 원하는 기능을 추가할 수 있는 완전한 창작 자유가 있습니다. 우리의 놀라운 one-step 파이프라인의 경우, 임의의 이미지를 생성하고 `timestep=1`을 설정하여 `unet`과 `scheduler`를 한 번만 호출합니다:
+
+```diff
+ from diffusers import DiffusionPipeline
+ import torch
+
+
+ class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
+ def __init__(self, unet, scheduler):
+ super().__init__()
+
+ self.register_modules(unet=unet, scheduler=scheduler)
+
++ def __call__(self):
++ image = torch.randn(
++ (1, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),
++ )
++ timestep = 1
+
++ model_output = self.unet(image, timestep).sample
++ scheduler_output = self.scheduler.step(model_output, timestep, image).prev_sample
+
++ return scheduler_output
+```
+
+끝났습니다! 🚀 이제 이 파이프라인에 `unet`과 `scheduler`를 전달하여 실행할 수 있습니다:
+
+```python
+from diffusers import DDPMScheduler, UNet2DModel
+
+scheduler = DDPMScheduler()
+unet = UNet2DModel()
+
+pipeline = UnetSchedulerOneForwardPipeline(unet=unet, scheduler=scheduler)
+
+output = pipeline()
+```
+
+하지만 파이프라인 구조가 동일한 경우 기존 가중치를 파이프라인에 로드할 수 있다는 장점이 있습니다. 예를 들어 one-step 파이프라인에 [`google/ddpm-cifar10-32`](https://huggingface.co/google/ddpm-cifar10-32) 가중치를 로드할 수 있습니다:
+
+```python
+pipeline = UnetSchedulerOneForwardPipeline.from_pretrained("google/ddpm-cifar10-32")
+
+output = pipeline()
+```
+
+## 파이프라인 공유
+
+🧨Diffusers [리포지토리](https://github.com/huggingface/diffusers)에서 Pull Request를 열어 [examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community) 하위 폴더에 `one_step_unet.py`의 멋진 파이프라인을 추가하세요.
+
+병합이 되면, `diffusers >= 0.4.0`이 설치된 사용자라면 누구나 `custom_pipeline` 인수에 지정하여 이 파이프라인을 마술처럼 🪄 사용할 수 있습니다:
+
+```python
+from diffusers import DiffusionPipeline
+
+pipe = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32", custom_pipeline="one_step_unet")
+pipe()
+```
+
+커뮤니티 파이프라인을 공유하는 또 다른 방법은 Hub 에서 선호하는 [모델 리포지토리](https://huggingface.co/docs/hub/models-uploading)에 직접 `one_step_unet.py` 파일을 업로드하는 것입니다. `one_step_unet.py` 파일을 지정하는 대신 모델 저장소 id를 `custom_pipeline` 인수에 전달하세요:
+
+```python
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32", custom_pipeline="stevhliu/one_step_unet")
+```
+
+다음 표에서 두 가지 공유 워크플로우를 비교하여 자신에게 가장 적합한 옵션을 결정하는 데 도움이 되는 정보를 확인하세요:
+
+| | GitHub 커뮤니티 파이프라인 | HF Hub 커뮤니티 파이프라인 |
+|----------------|------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------|
+| 사용법 | 동일 | 동일 |
+| 리뷰 과정 | 병합하기 전에 GitHub에서 Pull Request를 열고 Diffusers 팀의 검토 과정을 거칩니다. 속도가 느릴 수 있습니다. | 검토 없이 Hub 저장소에 바로 업로드합니다. 가장 빠른 워크플로우 입니다. |
+| 가시성 | 공식 Diffusers 저장소 및 문서에 포함되어 있습니다. | HF 허브 프로필에 포함되며 가시성을 확보하기 위해 자신의 사용량/프로모션에 의존합니다. |
+
+
+
+💡 커뮤니티 파이프라인 파일에 원하는 패키지를 사용할 수 있습니다. 사용자가 패키지를 설치하기만 하면 모든 것이 정상적으로 작동합니다. 파이프라인이 자동으로 감지되므로 `DiffusionPipeline`에서 상속하는 파이프라인 클래스가 하나만 있는지 확인하세요.
+
+
+
+## 커뮤니티 파이프라인은 어떻게 작동하나요?
+
+커뮤니티 파이프라인은 [`DiffusionPipeline`]을 상속하는 클래스입니다:
+
+- [`custom_pipeline`] 인수로 로드할 수 있습니다.
+- 모델 가중치 및 스케줄러 구성은 [`pretrained_model_name_or_path`]에서 로드됩니다.
+- 커뮤니티 파이프라인에서 기능을 구현하는 코드는 `pipeline.py` 파일에 정의되어 있습니다.
+
+공식 저장소에서 모든 파이프라인 구성 요소 가중치를 로드할 수 없는 경우가 있습니다. 이 경우 다른 구성 요소는 파이프라인에 직접 전달해야 합니다:
+
+```python
+import torch
+
+from diffusers import DiffusionPipeline, DDIMScheduler
+from transformers import CLIPFeatureExtractor, CLIPModel
+
+model_id = "CompVis/stable-diffusion-v1-4"
+clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
+
+feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id)
+clip_model = CLIPModel.from_pretrained(clip_model_id, torch_dtype=torch.float16)
+# 예시로 DDIM 스케줄러를 불러와 파이프라인에 직접 전달합니다
+scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
+
+pipeline = DiffusionPipeline.from_pretrained(
+ model_id,
+ custom_pipeline="clip_guided_stable_diffusion",
+ clip_model=clip_model,
+ feature_extractor=feature_extractor,
+ scheduler=scheduler,
+ torch_dtype=torch.float16,
+)
+```
+
+커뮤니티 파이프라인의 마법은 다음 코드에 담겨 있습니다. 이 코드를 통해 커뮤니티 파이프라인을 GitHub 또는 Hub에서 로드할 수 있으며, 모든 🧨 Diffusers 패키지에서 사용할 수 있습니다.
+
+```python
+# 2. 파이프라인 클래스를 로드합니다. 사용자 지정 모듈을 사용하는 경우 Hub에서 로드합니다
+# 명시적 클래스에서 로드하는 경우, 이를 사용해 보겠습니다.
+if custom_pipeline is not None:
+ pipeline_class = get_class_from_dynamic_module(
+ custom_pipeline, module_file=CUSTOM_PIPELINE_FILE_NAME, cache_dir=custom_pipeline
+ )
+elif cls != DiffusionPipeline:
+ pipeline_class = cls
+else:
+ diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
+ pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
+```
diff --git a/diffusers/docs/source/ko/using-diffusers/control_brightness.md b/diffusers/docs/source/ko/using-diffusers/control_brightness.md
new file mode 100644
index 0000000000000000000000000000000000000000..522da736ec64c69cfcd1a0f40d6a2ea832f37321
--- /dev/null
+++ b/diffusers/docs/source/ko/using-diffusers/control_brightness.md
@@ -0,0 +1,45 @@
+# 이미지 밝기 조절하기
+
+Stable Diffusion 파이프라인은 [일반적인 디퓨전 노이즈 스케줄과 샘플 단계에 결함이 있음](https://huggingface.co/papers/2305.08891) 논문에서 설명한 것처럼 매우 밝거나 어두운 이미지를 생성하는 데는 성능이 평범합니다. 이 논문에서 제안한 솔루션은 현재 [`DDIMScheduler`]에 구현되어 있으며 이미지의 밝기를 개선하는 데 사용할 수 있습니다.
+
+
+
+💡 제안된 솔루션에 대한 자세한 내용은 위에 링크된 논문을 참고하세요!
+
+
+
+해결책 중 하나는 *v 예측값*과 *v 로스*로 모델을 훈련하는 것입니다. 다음 flag를 [`train_text_to_image.py`](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) 또는 [`train_text_to_image_lora.py`](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py) 스크립트에 추가하여 `v_prediction`을 활성화합니다:
+
+```bash
+--prediction_type="v_prediction"
+```
+
+예를 들어, `v_prediction`으로 미세 조정된 [`ptx0/pseudo-journey-v2`](https://huggingface.co/ptx0/pseudo-journey-v2) 체크포인트를 사용해 보겠습니다.
+
+다음으로 [`DDIMScheduler`]에서 다음 파라미터를 설정합니다:
+
+1. `rescale_betas_zero_snr=True`: 노이즈 스케줄을 제로 터미널 신호 대 잡음비(SNR)로 재조정합니다.
+2. `timestep_spacing="trailing"`: 마지막 타임스텝부터 샘플링을 시작합니다.
+
+```py
+>>> from diffusers import DiffusionPipeline, DDIMScheduler
+
+>>> pipeline = DiffusionPipeline.from_pretrained("ptx0/pseudo-journey-v2")
+# switch the scheduler in the pipeline to use the DDIMScheduler
+
+>>> pipeline.scheduler = DDIMScheduler.from_config(
+... pipeline.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
+... )
+>>> pipeline.to("cuda")
+```
+
+마지막으로 파이프라인에 대한 호출에서 `guidance_rescale`을 설정하여 과다 노출을 방지합니다:
+
+```py
+prompt = "A lion in galaxies, spirals, nebulae, stars, smoke, iridescent, intricate detail, octane render, 8k"
+image = pipeline(prompt, guidance_rescale=0.7).images[0]
+```
+
+
+
+
\ No newline at end of file
diff --git a/diffusers/docs/source/ko/using-diffusers/controlling_generation.md b/diffusers/docs/source/ko/using-diffusers/controlling_generation.md
new file mode 100644
index 0000000000000000000000000000000000000000..b018aab9b970a9a47fde1861f00fdbc555571615
--- /dev/null
+++ b/diffusers/docs/source/ko/using-diffusers/controlling_generation.md
@@ -0,0 +1,226 @@
+
+
+# 제어된 생성
+
+Diffusion 모델에 의해 생성된 출력을 제어하는 것은 커뮤니티에서 오랫동안 추구해 왔으며 현재 활발한 연구 주제입니다. 널리 사용되는 많은 diffusion 모델에서는 이미지와 텍스트 프롬프트 등 입력의 미묘한 변화로 인해 출력이 크게 달라질 수 있습니다. 이상적인 세계에서는 의미가 유지되고 변경되는 방식을 제어할 수 있기를 원합니다.
+
+의미 보존의 대부분의 예는 입력의 변화를 출력의 변화에 정확하게 매핑하는 것으로 축소됩니다. 즉, 프롬프트에서 피사체에 형용사를 추가하면 전체 이미지가 보존되고 변경된 피사체만 수정됩니다. 또는 특정 피사체의 이미지를 변형하면 피사체의 포즈가 유지됩니다.
+
+의미 보존 외에도, 생성된 이미지에서 영향을 미치고자 하는 품질 요소들이 있습니다. 즉, 일반적으로 결과물의 품질이 좋거나, 특정 스타일을 고수하거나, 사실적이기를 원합니다.
+
+diffusion 모델 생성을 제어하기 위해 `diffusers`가 지원하는 몇 가지 기술을 문서화합니다. 많은 부분이 최첨단 연구이며 미묘한 차이가 있을 수 있습니다. 명확한 설명이 필요하거나 제안 사항이 있으면 주저하지 마시고 [포럼](https://discuss.huggingface.co/) 또는 [GitHub 이슈](https://github.com/huggingface/diffusers/issues)에서 토론을 시작하세요.
+
+생성 제어 방법에 대한 개략적인 설명과 기술 개요를 제공합니다. 기술에 대한 자세한 설명은 파이프라인에서 링크된 원본 논문을 참조하는 것이 가장 좋습니다.
+
+사용 사례에 따라 적절한 기술을 선택해야 합니다. 많은 경우 이러한 기법을 결합할 수 있습니다. 예를 들어, 텍스트 반전과 SEGA를 결합하여 텍스트 반전을 사용하여 생성된 출력에 더 많은 의미적 지침을 제공할 수 있습니다.
+
+별도의 언급이 없는 한, 이러한 기법은 기존 모델과 함께 작동하며 자체 가중치가 필요하지 않은 기법입니다.
+
+1. [Instruct Pix2Pix](#instruct-pix2pix)
+2. [Pix2Pix Zero](#pix2pixzero)
+3. [Attend and Excite](#attend-and-excite)
+4. [Semantic Guidance](#semantic-guidance)
+5. [Self-attention Guidance](#self-attention-guidance)
+6. [Depth2Image](#depth2image)
+7. [MultiDiffusion Panorama](#multidiffusion-panorama)
+8. [DreamBooth](#dreambooth)
+9. [Textual Inversion](#textual-inversion)
+10. [ControlNet](#controlnet)
+11. [Prompt Weighting](#prompt-weighting)
+12. [Custom Diffusion](#custom-diffusion)
+13. [Model Editing](#model-editing)
+14. [DiffEdit](#diffedit)
+15. [T2I-Adapter](#t2i-adapter)
+
+편의를 위해, 추론만 하거나 파인튜닝/학습하는 방법에 대한 표를 제공합니다.
+
+| **Method** | **Inference only** | **Requires training / fine-tuning** | **Comments** |
+| :-------------------------------------------------: | :----------------: | :-------------------------------------: | :---------------------------------------------------------------------------------------------: |
+| [Instruct Pix2Pix](#instruct-pix2pix) | ✅ | ❌ | Can additionally be fine-tuned for better performance on specific edit instructions. |
+| [Pix2Pix Zero](#pix2pixzero) | ✅ | ❌ | |
+| [Attend and Excite](#attend-and-excite) | ✅ | ❌ | |
+| [Semantic Guidance](#semantic-guidance) | ✅ | ❌ | |
+| [Self-attention Guidance](#self-attention-guidance) | ✅ | ❌ | |
+| [Depth2Image](#depth2image) | ✅ | ❌ | |
+| [MultiDiffusion Panorama](#multidiffusion-panorama) | ✅ | ❌ | |
+| [DreamBooth](#dreambooth) | ❌ | ✅ | |
+| [Textual Inversion](#textual-inversion) | ❌ | ✅ | |
+| [ControlNet](#controlnet) | ✅ | ❌ | A ControlNet can be trained/fine-tuned on a custom conditioning. |
+| [Prompt Weighting](#prompt-weighting) | ✅ | ❌ | |
+| [Custom Diffusion](#custom-diffusion) | ❌ | ✅ | |
+| [Model Editing](#model-editing) | ✅ | ❌ | |
+| [DiffEdit](#diffedit) | ✅ | ❌ | |
+| [T2I-Adapter](#t2i-adapter) | ✅ | ❌ | |
+
+## Instruct Pix2Pix
+
+[Paper](https://arxiv.org/abs/2211.09800)
+
+[Instruct Pix2Pix](../api/pipelines/stable_diffusion/pix2pix) 는 입력 이미지 편집을 지원하기 위해 stable diffusion에서 미세-조정되었습니다. 이미지와 편집을 설명하는 프롬프트를 입력으로 받아 편집된 이미지를 출력합니다.
+Instruct Pix2Pix는 [InstructGPT](https://openai.com/blog/instruction-following/)와 같은 프롬프트와 잘 작동하도록 명시적으로 훈련되었습니다.
+
+사용 방법에 대한 자세한 내용은 [여기](../api/pipelines/stable_diffusion/pix2pix)를 참조하세요.
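+
+아래는 최소한의 사용 예시 스케치입니다. `timbrooks/instruct-pix2pix` 체크포인트를 사용하며, 입력 이미지 경로(`input.png`)와 프롬프트는 예시를 위한 가정입니다:
+
+```py
+import torch
+from PIL import Image
+from diffusers import StableDiffusionInstructPix2PixPipeline
+
+# Instruct Pix2Pix 파이프라인 불러오기
+pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
+    "timbrooks/instruct-pix2pix", torch_dtype=torch.float16
+).to("cuda")
+
+# 편집할 입력 이미지 (예시용 로컬 경로)
+init_image = Image.open("input.png").convert("RGB")
+
+# 편집 지시 프롬프트와 함께 호출. image_guidance_scale은 원본 이미지 보존 정도를 조절합니다
+edited = pipe(
+    "make the sky look like a sunset",
+    image=init_image,
+    num_inference_steps=20,
+    image_guidance_scale=1.5,
+).images[0]
+```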
+
+## Pix2Pix Zero
+
+[Paper](https://arxiv.org/abs/2302.03027)
+
+[Pix2Pix Zero](../api/pipelines/stable_diffusion/pix2pix_zero)를 사용하면 일반적인 이미지 의미를 유지하면서 한 개념이나 피사체가 다른 개념이나 피사체로 변환되도록 이미지를 수정할 수 있습니다.
+
+노이즈 제거 프로세스는 한 개념적 임베딩에서 다른 개념적 임베딩으로 안내됩니다. 중간 latents는 디노이징(denoising) 프로세스 중에 최적화되어 참조 attention map을 향해 나아갑니다. 참조 attention map은 입력 이미지의 노이즈 제거 프로세스에서 나온 것으로, 의미 보존을 장려하는 데 사용됩니다.
+
+Pix2Pix Zero는 합성 이미지와 실제 이미지를 편집하는 데 모두 사용할 수 있습니다.
+
+- 합성 이미지를 편집하려면 먼저 캡션이 지정된 이미지를 생성합니다.
+ 다음으로 편집할 컨셉과 새로운 타겟 컨셉에 대한 이미지 캡션을 생성합니다. 이를 위해 [Flan-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5)와 같은 모델을 사용할 수 있습니다. 그런 다음 텍스트 인코더를 통해 소스 개념과 대상 개념 모두에 대한 "평균" 프롬프트 임베딩을 생성합니다. 마지막으로, 합성 이미지를 편집하기 위해 pix2pix-zero 알고리즘을 사용합니다.
+- 실제 이미지를 편집하려면 먼저 [BLIP](https://huggingface.co/docs/transformers/model_doc/blip)과 같은 모델을 사용하여 이미지 캡션을 생성합니다. 그런 다음 프롬프트와 이미지에 ddim 반전을 적용하여 "역(inverse)" latents을 생성합니다. 이전과 마찬가지로 소스 및 대상 개념 모두에 대한 "평균(mean)" 프롬프트 임베딩이 생성되고 마지막으로 "역(inverse)" latents와 결합된 pix2pix-zero 알고리즘이 이미지를 편집하는 데 사용됩니다.
+
+
+
+Pix2Pix Zero는 '제로 샷(zero-shot)' 이미지 편집이 가능한 최초의 모델입니다.
+즉, [사용 예시](../api/pipelines/stable_diffusion/pix2pix_zero#usage-example)에서 볼 수 있듯이 일반 소비자용 GPU에서 1분 이내에 이미지를 편집할 수 있습니다.
+
+
+
+위에서 언급했듯이 Pix2Pix Zero에는 (UNet, VAE, 텍스트 인코더가 아닌) latents를 최적화하여 생성을 특정 개념으로 유도하는 기능이 포함되어 있습니다. 즉, 전체 파이프라인에 표준 [StableDiffusionPipeline](../api/pipelines/stable_diffusion/text2img)보다 더 많은 메모리가 필요할 수 있습니다.
+
+사용 방법에 대한 자세한 내용은 [여기](../api/pipelines/stable_diffusion/pix2pix_zero)를 참조하세요.
+
+## Attend and Excite
+
+[Paper](https://arxiv.org/abs/2301.13826)
+
+[Attend and Excite](../api/pipelines/stable_diffusion/attend_and_excite)를 사용하면 프롬프트의 피사체가 최종 이미지에 충실하게 표현되도록 할 수 있습니다.
+
+이미지에 존재해야 하는 프롬프트의 피사체에 해당하는 일련의 토큰 인덱스가 입력으로 제공됩니다. 노이즈 제거 중에 각 토큰 인덱스는 이미지의 최소 한 패치 이상에서 일정 수준 이상의 attention 임계값을 갖도록 보장됩니다. 모든 피사체 토큰이 attention 임계값을 넘을 때까지, 노이즈 제거 과정에서 중간 latents가 반복적으로 최적화되어 가장 소홀히 취급되는 피사체 토큰의 attention을 강화합니다.
+
+Pix2Pix Zero와 마찬가지로 Attend and Excite 역시 파이프라인에 미니 최적화 루프(사전 학습된 가중치를 그대로 둔 채)가 포함되며, 일반적인 'StableDiffusionPipeline'보다 더 많은 메모리가 필요할 수 있습니다.
+
+사용 방법에 대한 자세한 내용은 [여기](../api/pipelines/stable_diffusion/attend_and_excite)를 참조하세요.
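+
+아래는 프롬프트 내 피사체 토큰의 인덱스를 지정하는 방식에 대한 대략적인 스케치입니다. `token_indices`에 넣은 인덱스 값은 예시를 위한 가정이므로, 정확한 사용법은 위 링크의 문서를 따르세요:
+
+```py
+import torch
+from diffusers import StableDiffusionAttendAndExcitePipeline
+
+pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained(
+    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
+).to("cuda")
+
+prompt = "a cat and a frog"
+# 프롬프트에서 강조할 피사체 토큰의 인덱스 (예시 값: "cat"과 "frog")
+image = pipe(prompt, token_indices=[2, 5], guidance_scale=7.5).images[0]
+```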
+
+## Semantic Guidance (SEGA)
+
+[Paper](https://arxiv.org/abs/2301.12247)
+
+의미유도(SEGA)를 사용하면 이미지에서 하나 이상의 컨셉을 적용하거나 제거할 수 있습니다. 컨셉의 강도도 조절할 수 있습니다. 즉, 스마일 컨셉을 사용하여 인물 사진의 스마일을 점진적으로 늘리거나 줄일 수 있습니다.
+
+분류기 없는 가이던스(classifier-free guidance)가 빈 프롬프트 입력을 통해 가이던스를 제공하는 방식과 유사하게, SEGA는 개념 프롬프트에 대한 가이던스를 제공합니다. 이러한 개념 프롬프트는 여러 개를 동시에 적용할 수 있으며, 각 개념 프롬프트는 가이던스가 긍정적으로 적용되는지 부정적으로 적용되는지에 따라 해당 개념을 추가하거나 제거할 수 있습니다.
+
+Pix2Pix Zero나 Attend and Excite와 달리, SEGA는 명시적인 그래디언트(gradient) 기반 최적화를 수행하는 대신 diffusion 프로세스와 직접 상호 작용합니다.
+
+사용 방법에 대한 자세한 내용은 [여기](../api/pipelines/semantic_stable_diffusion)를 참조하세요.
+
+## Self-attention Guidance (SAG)
+
+[Paper](https://arxiv.org/abs/2210.00939)
+
+[자기 주의 안내](../api/pipelines/stable_diffusion/self_attention_guidance)는 이미지의 전반적인 품질을 개선합니다.
+
+SAG는 고빈도 세부 정보를 기반으로 하지 않은 예측에서 완전히 조건화된 이미지에 이르기까지 가이드를 제공합니다. 고빈도 디테일은 UNet 자기 주의 맵에서 추출됩니다.
+
+사용 방법에 대한 자세한 내용은 [여기](../api/pipelines/stable_diffusion/self_attention_guidance)를 참조하세요.
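+
+아래는 `sag_scale` 인수로 SAG의 강도를 조절하는 최소 예시 스케치입니다(체크포인트와 값은 예시용 가정입니다):
+
+```py
+import torch
+from diffusers import StableDiffusionSAGPipeline
+
+pipe = StableDiffusionSAGPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
+).to("cuda")
+
+# sag_scale로 self-attention guidance의 강도를 조절합니다 (예시 값)
+image = pipe("a photo of an astronaut riding a horse", sag_scale=0.75).images[0]
+```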
+
+## Depth2Image
+
+[Project](https://huggingface.co/stabilityai/stable-diffusion-2-depth)
+
+[Depth2Image](../pipelines/stable_diffusion_2#depthtoimage)는 텍스트 기반 이미지 변형(text-guided image variation)에서 의미를 더 잘 보존하도록 Stable Diffusion에서 미세 조정되었습니다.
+
+원본 이미지의 단안(monocular) 깊이 추정치를 조건으로 합니다.
+
+사용 방법에 대한 자세한 내용은 [여기](../api/pipelines/stable_diffusion_2#depthtoimage)를 참조하세요.
+
+
+
+InstructPix2Pix와 Pix2Pix Zero 같은 방법의 중요한 차이점은, 전자는 사전 학습된 가중치를 미세 조정하는 반면 후자는 그렇지 않다는 것입니다. 즉, 사용 가능한 모든 Stable Diffusion 모델에 Pix2Pix Zero를 바로 적용할 수 있습니다.
+
+
+
+## MultiDiffusion Panorama
+
+[Paper](https://arxiv.org/abs/2302.08113)
+
+MultiDiffusion은 사전 학습된 diffusion model을 통해 새로운 생성 프로세스를 정의합니다. 이 프로세스는 고품질의 다양한 이미지를 생성하는 데 쉽게 적용할 수 있는 여러 diffusion 생성 방법을 하나로 묶습니다. 결과는 원하는 종횡비(예: 파노라마) 및 타이트한 분할 마스크에서 바운딩 박스에 이르는 공간 안내 신호와 같은 사용자가 제공한 제어를 준수합니다.
+[MultiDiffusion 파노라마](../api/pipelines/stable_diffusion/panorama)를 사용하면 임의의 종횡비(예: 파노라마)로 고품질 이미지를 생성할 수 있습니다.
+
+파노라마 이미지를 생성하는 데 사용하는 방법에 대한 자세한 내용은 [여기](../api/pipelines/stable_diffusion/panorama)를 참조하세요.
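+
+아래는 DDIM 스케줄러와 함께 파노라마 파이프라인을 사용하는 최소 예시 스케치입니다(체크포인트, 프롬프트, 해상도는 예시용 가정입니다):
+
+```py
+import torch
+from diffusers import StableDiffusionPanoramaPipeline, DDIMScheduler
+
+model_id = "stabilityai/stable-diffusion-2-base"
+scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
+pipe = StableDiffusionPanoramaPipeline.from_pretrained(
+    model_id, scheduler=scheduler, torch_dtype=torch.float16
+).to("cuda")
+
+# 가로로 긴 종횡비(예: 512x2048)의 파노라마 이미지 생성
+image = pipe("a photo of the dolomites", height=512, width=2048).images[0]
+```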
+
+## 나만의 모델 파인튜닝
+
+사전 학습된 모델 외에도 Diffusers는 사용자가 제공한 데이터에 대해 모델을 파인튜닝할 수 있는 학습 스크립트가 있습니다.
+
+## DreamBooth
+
+[DreamBooth](../training/dreambooth)는 모델을 파인튜닝하여 새로운 주제에 대해 가르칩니다. 즉, 한 사람의 사진 몇 장을 사용하여 다양한 스타일로 그 사람의 이미지를 생성할 수 있습니다.
+
+사용 방법에 대한 자세한 내용은 [여기](../training/dreambooth)를 참조하세요.
+
+## Textual Inversion
+
+[Textual Inversion](../training/text_inversion)은 모델을 파인튜닝하여 새로운 개념에 대해 학습시킵니다. 즉, 특정 스타일의 아트웍 사진 몇 장을 사용하여 해당 스타일의 이미지를 생성할 수 있습니다.
+
+사용 방법에 대한 자세한 내용은 [여기](../training/text_inversion)를 참조하세요.
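+
+학습된 임베딩은 `load_textual_inversion`으로 파이프라인에 불러와 사용할 수 있습니다. 아래는 최소 예시 스케치이며, 사용한 리포지토리(`sd-concepts-library/cat-toy`)와 토큰(`<cat-toy>`)은 예시를 위한 가정입니다:
+
+```py
+import torch
+from diffusers import StableDiffusionPipeline
+
+pipe = StableDiffusionPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
+).to("cuda")
+
+# 학습된 textual inversion 임베딩 불러오기 (리포지토리와 토큰은 예시 가정)
+pipe.load_textual_inversion("sd-concepts-library/cat-toy")
+image = pipe("A <cat-toy> sitting on a chair").images[0]
+```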
+
+## ControlNet
+
+[Paper](https://arxiv.org/abs/2302.05543)
+
+[ControlNet](../api/pipelines/stable_diffusion/controlnet)은 추가 조건을 추가하는 보조 네트워크입니다.
+가장자리 감지, 낙서(scribble), depth map, semantic segmentation과 같은 다양한 조건으로 학습된 8개의 표준 사전 학습 ControlNet이 있습니다.
+
+사용 방법에 대한 자세한 내용은 [여기](../api/pipelines/stable_diffusion/controlnet)를 참조하세요.
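+
+아래는 canny edge 조건을 사용하는 최소 예시 스케치입니다. 미리 만들어 둔 canny edge 이미지 경로(`canny_edges.png`)는 예시를 위한 가정입니다:
+
+```py
+import torch
+from PIL import Image
+from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
+
+# canny edge 조건으로 학습된 ControlNet 불러오기
+controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
+pipe = StableDiffusionControlNetPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
+).to("cuda")
+
+# 미리 준비한 canny edge 이미지 (예시용 로컬 경로)
+canny_image = Image.open("canny_edges.png")
+image = pipe("a futuristic city at night", image=canny_image).images[0]
+```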
+
+## Prompt Weighting
+
+프롬프트 가중치(prompt weighting)는 입력 텍스트의 특정 부분에 더 많은 attention 가중치를 부여하는 간단한 기법입니다.
+
+자세한 설명과 예시는 [여기](../using-diffusers/weighted_prompts)를 참조하세요.
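+
+아래는 서드파티 라이브러리인 compel이 설치되어 있다고 가정한 대략적인 스케치입니다. `++` 문법과 프롬프트는 예시를 위한 가정입니다:
+
+```py
+import torch
+from compel import Compel  # 프롬프트 가중치를 위한 서드파티 라이브러리 (설치 가정)
+from diffusers import StableDiffusionPipeline
+
+pipe = StableDiffusionPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
+).to("cuda")
+
+compel = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder)
+# "++"는 해당 단어의 가중치를 높입니다 (compel 문법)
+prompt_embeds = compel("a red cat++ playing with a ball")
+image = pipe(prompt_embeds=prompt_embeds).images[0]
+```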
+
+## Custom Diffusion
+
+[Custom Diffusion](../training/custom_diffusion)은 사전 학습된 text-to-image diffusion 모델의 cross-attention 맵만 미세 조정합니다.
+또한 textual inversion을 추가로 수행할 수 있으며, 설계상 다중 개념 학습을 지원합니다.
+DreamBooth 및 Textual Inversion과 마찬가지로, Custom Diffusion 역시 사전 학습된 text-to-image diffusion 모델에 새로운 개념을 학습시켜 관심 있는 개념과 관련된 출력을 생성하는 데 사용됩니다.
+
+자세한 설명은 [공식 문서](../training/custom_diffusion)를 참조하세요.
+
+## Model Editing
+
+[Paper](https://arxiv.org/abs/2303.08084)
+
+[텍스트-이미지 모델 편집 파이프라인](../api/pipelines/model_editing)을 사용하면 사전학습된 text-to-image diffusion 모델이 입력 프롬프트에 있는 피사체에 대해 내릴 수 있는 잘못된 암시적 가정을 완화하는 데 도움이 됩니다.
+예를 들어, Stable Diffusion에 "A pack of roses" 이미지를 생성하라는 프롬프트를 주면, 생성된 이미지의 장미는 빨간색일 가능성이 높습니다. 이 파이프라인은 이러한 가정을 변경하는 데 도움이 됩니다.
+
+자세한 설명은 [공식 문서](../api/pipelines/model_editing)를 참조하세요.
+
+## DiffEdit
+
+[Paper](https://arxiv.org/abs/2210.11427)
+
+[DiffEdit](../api/pipelines/diffedit)를 사용하면 원본 입력 이미지를 최대한 보존하면서 입력 프롬프트와 함께 입력 이미지의 의미론적 편집이 가능합니다.
+
+
+자세한 설명은 [공식 문서](../api/pipelines/diffedit)를 참조하세요.
+
+## T2I-Adapter
+
+[Paper](https://arxiv.org/abs/2302.08453)
+
+[T2I-Adapter](../api/pipelines/stable_diffusion/adapter)는 추가적인 조건을 더해 주는 보조(auxiliary) 네트워크입니다.
+가장자리 감지, 스케치, depth map, semantic segmentation과 같은 다양한 조건에 대해 학습된 8개의 표준 사전 학습 adapter가 있습니다.
+
+사용 방법에 대한 자세한 내용은 [공식 문서](../api/pipelines/stable_diffusion/adapter)를 참조하세요.
\ No newline at end of file
diff --git a/diffusers/docs/source/ko/using-diffusers/custom_pipeline_examples.md b/diffusers/docs/source/ko/using-diffusers/custom_pipeline_examples.md
new file mode 100644
index 0000000000000000000000000000000000000000..b32e731ea34fcdc773ca18d11b41cc9549611e82
--- /dev/null
+++ b/diffusers/docs/source/ko/using-diffusers/custom_pipeline_examples.md
@@ -0,0 +1,275 @@
+
+
+# 커뮤니티 파이프라인
+
+> **커뮤니티 파이프라인에 대한 자세한 내용은 [이 이슈](https://github.com/huggingface/diffusers/issues/841)를 참조하세요.**
+
+**커뮤니티** 예제는 커뮤니티에서 추가한 추론 및 훈련 예제로 구성되어 있습니다.
+다음 표를 참조하여 모든 커뮤니티 예제에 대한 개요를 확인하시기 바랍니다. **코드 예제**를 클릭하면 복사하여 붙여넣기할 수 있는 코드 예제를 확인할 수 있습니다.
+커뮤니티 파이프라인이 예상대로 작동하지 않는 경우 이슈를 열고 작성자에게 핑을 보내주세요.
+
+| 예 | 설명 | 코드 예제 | 콜랩 |저자 |
+|:---------------------------------------|:------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------:|
+| CLIP Guided Stable Diffusion | CLIP 가이드 기반의 Stable Diffusion으로 텍스트에서 이미지로 생성하기 | [CLIP Guided Stable Diffusion](#clip-guided-stable-diffusion) | [![콜랩에서 열기](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/CLIP_Guided_Stable_diffusion_with_diffusers.ipynb) | [Suraj Patil](https://github.com/patil-suraj/) |
+| One Step U-Net (Dummy) | 커뮤니티 파이프라인을 어떻게 사용해야 하는지에 대한 예시(참고 https://github.com/huggingface/diffusers/issues/841) | [One Step U-Net](#one-step-unet) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
+| Stable Diffusion Interpolation | 서로 다른 프롬프트/시드 간 Stable Diffusion의 latent space 보간 | [Stable Diffusion Interpolation](#stable-diffusion-interpolation) | - | [Nate Raw](https://github.com/nateraw/) |
+| Stable Diffusion Mega | 모든 기능을 갖춘 **하나의** Stable Diffusion 파이프라인 [Text2Image](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py), [Image2Image](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py) and [Inpainting](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py) | [Stable Diffusion Mega](#stable-diffusion-mega) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
+| Long Prompt Weighting Stable Diffusion | 토큰 길이 제한이 없고 프롬프트에서 파싱 가중치 지원을 하는 **하나의** Stable Diffusion 파이프라인, | [Long Prompt Weighting Stable Diffusion](#long-prompt-weighting-stable-diffusion) |- | [SkyTNT](https://github.com/SkyTNT) |
+| Speech to Image | 자동 음성 인식을 사용하여 텍스트를 작성하고 Stable Diffusion을 사용하여 이미지를 생성합니다. | [Speech to Image](#speech-to-image) | - | [Mikail Duzenli](https://github.com/MikailINTech) |
+
+커스텀 파이프라인을 불러오려면 `diffusers/examples/community`에 있는 파일 중 하나로서 `custom_pipeline` 인수를 `DiffusionPipeline`에 전달하기만 하면 됩니다. 자신만의 파이프라인이 있는 PR을 보내주시면 빠르게 병합해드리겠습니다.
+```py
+from diffusers import DiffusionPipeline
+
+pipe = DiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4", custom_pipeline="filename_in_the_community_folder"
+)
+```
+
+## 사용 예시
+
+### CLIP 가이드 기반의 Stable Diffusion
+
+모든 노이즈 제거 단계에서 추가 CLIP 모델을 통해 Stable Diffusion을 가이드함으로써, CLIP 가이드 기반의 Stable Diffusion은 보다 사실적인 이미지를 생성할 수 있습니다.
+
+다음 코드는 약 12GB의 GPU RAM이 필요합니다.
+
+```python
+from diffusers import DiffusionPipeline
+from transformers import CLIPImageProcessor, CLIPModel
+import torch
+
+
+feature_extractor = CLIPImageProcessor.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K")
+clip_model = CLIPModel.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.float16)
+
+
+guided_pipeline = DiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ custom_pipeline="clip_guided_stable_diffusion",
+ clip_model=clip_model,
+ feature_extractor=feature_extractor,
+ torch_dtype=torch.float16,
+)
+guided_pipeline.enable_attention_slicing()
+guided_pipeline = guided_pipeline.to("cuda")
+
+prompt = "fantasy book cover, full moon, fantasy forest landscape, golden vector elements, fantasy magic, dark light night, intricate, elegant, sharp focus, illustration, highly detailed, digital painting, concept art, matte, art by WLOP and Artgerm and Albert Bierstadt, masterpiece"
+
+generator = torch.Generator(device="cuda").manual_seed(0)
+images = []
+for i in range(4):
+ image = guided_pipeline(
+ prompt,
+ num_inference_steps=50,
+ guidance_scale=7.5,
+ clip_guidance_scale=100,
+ num_cutouts=4,
+ use_cutouts=False,
+ generator=generator,
+ ).images[0]
+ images.append(image)
+
+# 이미지 로컬에 저장하기
+for i, img in enumerate(images):
+ img.save(f"./clip_guided_sd/image_{i}.png")
+```
+
+`images` 목록에는 로컬에 저장하거나 구글 콜랩에 직접 표시할 수 있는 PIL 이미지들이 담겨 있습니다. 생성된 이미지는 기본 Stable Diffusion을 사용할 때보다 품질이 높은 경향이 있습니다. 예를 들어 위의 스크립트는 다음과 같은 이미지를 생성합니다:
+
+![clip_guidance](https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/clip_guidance/merged_clip_guidance.jpg).
+
+### One Step Unet
+
+예시 "one-step-unet"는 다음과 같이 실행할 수 있습니다.
+
+```python
+from diffusers import DiffusionPipeline
+
+pipe = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32", custom_pipeline="one_step_unet")
+pipe()
+```
+
+**참고**: 이 커뮤니티 파이프라인은 기능으로 유용하지 않으며 커뮤니티 파이프라인을 추가할 수 있는 방법의 예시일 뿐입니다(https://github.com/huggingface/diffusers/issues/841 참조).
+
+### Stable Diffusion Interpolation
+
+다음 코드는 최소 8GB VRAM의 GPU에서 실행할 수 있으며 약 5분 정도 소요됩니다.
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+pipe = DiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ torch_dtype=torch.float16,
+ safety_checker=None, # Very important for videos...lots of false positives while interpolating
+ custom_pipeline="interpolate_stable_diffusion",
+).to("cuda")
+pipe.enable_attention_slicing()
+
+frame_filepaths = pipe.walk(
+ prompts=["a dog", "a cat", "a horse"],
+ seeds=[42, 1337, 1234],
+ num_interpolation_steps=16,
+ output_dir="./dreams",
+ batch_size=4,
+ height=512,
+ width=512,
+ guidance_scale=8.5,
+ num_inference_steps=50,
+)
+```
+
+`walk(...)` 함수는 `output_dir`에 정의된 폴더에 저장된 이미지들의 경로 목록을 반환합니다. 이 이미지들을 사용하여 Stable Diffusion 동영상을 만들 수 있습니다.
+
+> Stable Diffusion을 이용한 동영상 제작 방법과 더 많은 기능에 대한 자세한 내용은 https://github.com/nateraw/stable-diffusion-videos 에서 확인하시기 바랍니다.
+
+### Stable Diffusion Mega
+
+Stable Diffusion Mega 파이프라인을 사용하면 Stable Diffusion 파이프라인의 주요 사용 사례를 단일 클래스에서 모두 사용할 수 있습니다.
+
+```python
+#!/usr/bin/env python3
+from diffusers import DiffusionPipeline
+import PIL
+import requests
+from io import BytesIO
+import torch
+
+
+def download_image(url):
+ response = requests.get(url)
+ return PIL.Image.open(BytesIO(response.content)).convert("RGB")
+
+
+pipe = DiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ custom_pipeline="stable_diffusion_mega",
+ torch_dtype=torch.float16,
+)
+pipe.to("cuda")
+pipe.enable_attention_slicing()
+
+
+### Text-to-Image
+
+images = pipe.text2img("An astronaut riding a horse").images
+
+### Image-to-Image
+
+init_image = download_image(
+ "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+)
+
+prompt = "A fantasy landscape, trending on artstation"
+
+images = pipe.img2img(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images
+
+### Inpainting
+
+img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+init_image = download_image(img_url).resize((512, 512))
+mask_image = download_image(mask_url).resize((512, 512))
+
+prompt = "a cat sitting on a bench"
+images = pipe.inpaint(prompt=prompt, image=init_image, mask_image=mask_image, strength=0.75).images
+```
+
+위에 표시된 것처럼 하나의 파이프라인에서 '텍스트-이미지 변환', '이미지-이미지 변환', '인페인팅'을 모두 실행할 수 있습니다.
+
+### Long Prompt Weighting Stable Diffusion
+
+파이프라인을 사용하면 77개의 토큰 길이 제한 없이 프롬프트를 입력할 수 있습니다. 또한 "()"를 사용하여 단어 가중치를 높이거나 "[]"를 사용하여 단어 가중치를 낮출 수 있습니다.
+또한 파이프라인을 사용하면 단일 클래스에서 Stable Diffusion 파이프라인의 주요 사용 사례를 사용할 수 있습니다.
+
+#### pytorch
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+pipe = DiffusionPipeline.from_pretrained(
+ "hakurei/waifu-diffusion", custom_pipeline="lpw_stable_diffusion", torch_dtype=torch.float16
+)
+pipe = pipe.to("cuda")
+
+prompt = "best_quality (1girl:1.3) bow bride brown_hair closed_mouth frilled_bow frilled_hair_tubes frills (full_body:1.3) fox_ear hair_bow hair_tubes happy hood japanese_clothes kimono long_sleeves red_bow smile solo tabi uchikake white_kimono wide_sleeves cherry_blossoms"
+neg_prompt = "lowres, bad_anatomy, error_body, error_hair, error_arm, error_hands, bad_hands, error_fingers, bad_fingers, missing_fingers, error_legs, bad_legs, multiple_legs, missing_legs, error_lighting, error_shadow, error_reflection, text, error, extra_digit, fewer_digits, cropped, worst_quality, low_quality, normal_quality, jpeg_artifacts, signature, watermark, username, blurry"
+
+pipe.text2img(prompt, negative_prompt=neg_prompt, width=512, height=512, max_embeddings_multiples=3).images[0]
+```
+
+#### onnxruntime
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+pipe = DiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ custom_pipeline="lpw_stable_diffusion_onnx",
+ revision="onnx",
+ provider="CUDAExecutionProvider",
+)
+
+prompt = "a photo of an astronaut riding a horse on mars, best quality"
+neg_prompt = "lowres, bad anatomy, error body, error hair, error arm, error hands, bad hands, error fingers, bad fingers, missing fingers, error legs, bad legs, multiple legs, missing legs, error lighting, error shadow, error reflection, text, error, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry"
+
+pipe.text2img(prompt, negative_prompt=neg_prompt, width=512, height=512, max_embeddings_multiples=3).images[0]
+```
+
+`Token indices sequence length is longer than the specified maximum sequence length for this model (*** > 77). Running this sequence through the model will result in indexing errors`와 같은 경고가 표시될 수 있지만, 정상적인 현상이니 걱정하지 마세요.
+
+### Speech to Image
+
+다음 코드는 사전 학습된 OpenAI whisper-small과 Stable Diffusion을 사용하여 오디오 샘플로부터 이미지를 생성합니다.
+
+```python
+import torch
+
+import matplotlib.pyplot as plt
+from datasets import load_dataset
+from diffusers import DiffusionPipeline
+from transformers import (
+ WhisperForConditionalGeneration,
+ WhisperProcessor,
+)
+
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+
+audio_sample = ds[3]
+
+text = audio_sample["text"].lower()
+speech_data = audio_sample["audio"]["array"]
+
+model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small").to(device)
+processor = WhisperProcessor.from_pretrained("openai/whisper-small")
+
+diffuser_pipeline = DiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ custom_pipeline="speech_to_image_diffusion",
+ speech_model=model,
+ speech_processor=processor,
+
+ torch_dtype=torch.float16,
+)
+
+diffuser_pipeline.enable_attention_slicing()
+diffuser_pipeline = diffuser_pipeline.to(device)
+
+output = diffuser_pipeline(speech_data)
+plt.imshow(output.images[0])
+```
+
+위 예시는 다음과 같은 결과 이미지를 생성합니다.
+
+![image](https://user-images.githubusercontent.com/45072645/196901736-77d9c6fc-63ee-4072-90b0-dc8b903d63e3.png)
\ No newline at end of file
diff --git a/diffusers/docs/source/ko/using-diffusers/custom_pipeline_overview.md b/diffusers/docs/source/ko/using-diffusers/custom_pipeline_overview.md
new file mode 100644
index 0000000000000000000000000000000000000000..0361e7b9edd5ad6ea1a071d9b32d9a032450cae3
--- /dev/null
+++ b/diffusers/docs/source/ko/using-diffusers/custom_pipeline_overview.md
@@ -0,0 +1,56 @@
+
+
+# 커스텀 파이프라인 불러오기
+
+[[open-in-colab]]
+
+커뮤니티 파이프라인은 논문에 명시된 원래의 구현체와 다른 형태로 구현된 모든 [`DiffusionPipeline`] 클래스를 의미합니다. (예를 들어, [`StableDiffusionControlNetPipeline`]은 ["Text-to-Image Generation with ControlNet Conditioning"](https://arxiv.org/abs/2302.05543) 논문에 해당합니다.) 이들은 추가 기능을 제공하거나 파이프라인의 원래 구현을 확장합니다.
+
+[Speech to Image](https://github.com/huggingface/diffusers/tree/main/examples/community#speech-to-image) 또는 [Composable Stable Diffusion](https://github.com/huggingface/diffusers/tree/main/examples/community#composable-stable-diffusion) 과 같은 멋진 커뮤니티 파이프라인이 많이 있으며 [여기에서](https://github.com/huggingface/diffusers/tree/main/examples/community) 모든 공식 커뮤니티 파이프라인을 찾을 수 있습니다.
+
+허브에서 커뮤니티 파이프라인을 로드하려면, 커뮤니티 파이프라인의 리포지토리 ID와 (파이프라인 가중치 및 구성 요소를 로드하려는) 모델의 리포지토리 ID를 인자로 전달해야 합니다. 예를 들어, 아래 예시에서는 `hf-internal-testing/diffusers-dummy-pipeline`에서 더미 파이프라인을 불러오고, `google/ddpm-cifar10-32`에서 파이프라인의 가중치와 컴포넌트들을 로드합니다.
+
+
+
+🔒 허깅 페이스 허브에서 커뮤니티 파이프라인을 불러오는 것은 곧 해당 코드가 안전하다고 신뢰한다는 뜻입니다. 코드를 자동으로 불러와 실행하기에 앞서, 반드시 온라인에서 해당 코드를 확인하세요!
+
+
+
+```py
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline"
+)
+```
+
+공식 커뮤니티 파이프라인을 불러오는 것은 비슷하지만, 공식 리포지토리 ID에서 가중치를 불러오는 것과 더불어 해당 파이프라인 내의 컴포넌트를 직접 지정하는 것 역시 가능합니다. 아래 예제를 보면 커뮤니티 [CLIP Guided Stable Diffusion](https://github.com/huggingface/diffusers/tree/main/examples/community#clip-guided-stable-diffusion) 파이프라인을 로드할 때, 해당 파이프라인에서 사용할 `clip_model` 컴포넌트와 `feature_extractor` 컴포넌트를 직접 설정하는 것을 확인할 수 있습니다.
+
+```py
+from diffusers import DiffusionPipeline
+from transformers import CLIPImageProcessor, CLIPModel
+
+clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
+
+feature_extractor = CLIPImageProcessor.from_pretrained(clip_model_id)
+clip_model = CLIPModel.from_pretrained(clip_model_id)
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+ custom_pipeline="clip_guided_stable_diffusion",
+ clip_model=clip_model,
+ feature_extractor=feature_extractor,
+)
+```
+
+커뮤니티 파이프라인에 대한 자세한 내용은 [커뮤니티 파이프라인](https://github.com/huggingface/diffusers/blob/main/docs/source/en/using-diffusers/custom_pipeline_examples) 가이드를 살펴보세요. 커뮤니티 파이프라인 등록에 관심이 있는 경우 [커뮤니티 파이프라인에 기여하는 방법](https://github.com/huggingface/diffusers/blob/main/docs/source/en/using-diffusers/contribute_pipeline)에 대한 가이드를 확인하세요 !
\ No newline at end of file
diff --git a/diffusers/docs/source/ko/using-diffusers/depth2img.md b/diffusers/docs/source/ko/using-diffusers/depth2img.md
new file mode 100644
index 0000000000000000000000000000000000000000..b5602e3081daa6089265e002cc4df1cd8473a1e3
--- /dev/null
+++ b/diffusers/docs/source/ko/using-diffusers/depth2img.md
@@ -0,0 +1,57 @@
+
+
+# Text-guided depth-to-image 생성
+
+[[open-in-colab]]
+
+[`StableDiffusionDepth2ImgPipeline`]을 사용하면 텍스트 프롬프트와 초기 이미지를 전달하여 새 이미지의 생성을 조절할 수 있습니다. 또한 이미지 구조를 보존하기 위해 `depth_map`을 전달할 수도 있습니다. `depth_map`이 제공되지 않으면 파이프라인은 통합된 [depth-estimation model](https://github.com/isl-org/MiDaS)을 통해 자동으로 깊이를 예측합니다.
+
+
+먼저 [`StableDiffusionDepth2ImgPipeline`]의 인스턴스를 생성합니다:
+
+```python
+import torch
+import requests
+from PIL import Image
+
+from diffusers import StableDiffusionDepth2ImgPipeline
+
+pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-2-depth",
+ torch_dtype=torch.float16,
+).to("cuda")
+```
+
+이제 프롬프트를 파이프라인에 전달합니다. 특정 단어가 이미지 생성을 가이드 하는것을 방지하기 위해 `negative_prompt`를 전달할 수도 있습니다:
+
+```python
+url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+init_image = Image.open(requests.get(url, stream=True).raw)
+prompt = "two tigers"
+n_prompt = "bad, deformed, ugly, bad anatomy"
+image = pipe(prompt=prompt, image=init_image, negative_prompt=n_prompt, strength=0.7).images[0]
+image
+```
+
+| Input | Output |
+|---------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------|
+| | |
+
+아래의 Spaces를 가지고 놀며 depth map이 있는 이미지와 없는 이미지의 차이가 있는지 확인해 보세요!
+
+
diff --git a/diffusers/docs/source/ko/using-diffusers/img2img.md b/diffusers/docs/source/ko/using-diffusers/img2img.md
new file mode 100644
index 0000000000000000000000000000000000000000..d99d803339f1f1b00113f977710cc9bd1e246ec7
--- /dev/null
+++ b/diffusers/docs/source/ko/using-diffusers/img2img.md
@@ -0,0 +1,100 @@
+
+
+# 텍스트 기반 image-to-image 생성
+
+[[open-in-colab]]
+
+[`StableDiffusionImg2ImgPipeline`]을 사용하면 텍스트 프롬프트와 시작 이미지를 전달하여 새 이미지 생성의 조건을 지정할 수 있습니다.
+
+시작하기 전에 필요한 라이브러리가 모두 설치되어 있는지 확인하세요:
+
+```bash
+!pip install diffusers transformers ftfy accelerate
+```
+
+[`nitrosocke/Ghibli-Diffusion`](https://huggingface.co/nitrosocke/Ghibli-Diffusion)과 같은 사전학습된 stable diffusion 모델로 [`StableDiffusionImg2ImgPipeline`]을 생성하여 시작하세요.
+
+
+```python
+import torch
+import requests
+from PIL import Image
+from io import BytesIO
+from diffusers import StableDiffusionImg2ImgPipeline
+
+device = "cuda"
+pipe = StableDiffusionImg2ImgPipeline.from_pretrained("nitrosocke/Ghibli-Diffusion", torch_dtype=torch.float16).to(
+ device
+)
+```
+
+초기 이미지를 다운로드하고 사전 처리하여 파이프라인에 전달할 수 있습니다:
+
+```python
+url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+
+response = requests.get(url)
+init_image = Image.open(BytesIO(response.content)).convert("RGB")
+init_image.thumbnail((768, 768))
+init_image
+```
+
+
+
+
+
+
+
+💡 `strength`는 입력 이미지에 추가되는 노이즈의 양을 제어하는 0.0에서 1.0 사이의 값입니다. 1.0에 가까운 값은 다양한 변형을 허용하지만 입력 이미지와 의미적으로 일치하지 않는 이미지를 생성합니다.
+
+
+
+프롬프트를 정의하고(지브리 스타일(Ghibli-style)에 맞게 조정된 이 체크포인트의 경우 프롬프트 앞에 `ghibli style` 토큰을 붙여야 합니다) 파이프라인을 실행합니다:
+
+```python
+prompt = "ghibli style, a fantasy landscape with castles"
+generator = torch.Generator(device=device).manual_seed(1024)
+image = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5, generator=generator).images[0]
+image
+```
+
+
+
+
+
+다른 스케줄러로 실험하여 출력에 어떤 영향을 미치는지 확인할 수도 있습니다:
+
+```python
+from diffusers import LMSDiscreteScheduler
+
+lms = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
+pipe.scheduler = lms
+generator = torch.Generator(device=device).manual_seed(1024)
+image = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5, generator=generator).images[0]
+image
+```
+
+
+
+
+
+아래 Space에서 `strength` 값을 다르게 설정하여 이미지를 생성해 보세요. `strength`를 낮게 설정하면 원본 이미지와 더 유사한 이미지가 생성되는 것을 확인할 수 있습니다. 예를 들어 바로 아래 스케치처럼 여러 `strength` 값을 반복 실행해 비교해 볼 수도 있습니다.
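+
+다음은 위에서 만든 `pipe`, `prompt`, `init_image`, `device`를 그대로 재사용한다고 가정한 최소 스케치입니다:
+
+```python
+# 여러 strength 값으로 같은 프롬프트와 초기 이미지를 실험해 봅니다
+for strength in [0.3, 0.5, 0.75]:
+    generator = torch.Generator(device=device).manual_seed(1024)
+    image = pipe(prompt=prompt, image=init_image, strength=strength, guidance_scale=7.5, generator=generator).images[0]
+    image.save(f"ghibli_strength_{strength}.png")
+```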
+
+자유롭게 스케줄러를 [`LMSDiscreteScheduler`]로 전환하여 출력에 어떤 영향을 미치는지 확인해 보세요.
+
+
\ No newline at end of file
diff --git a/diffusers/docs/source/ko/using-diffusers/inpaint.md b/diffusers/docs/source/ko/using-diffusers/inpaint.md
new file mode 100644
index 0000000000000000000000000000000000000000..c817a8fa80dd6c06c7fe6e9ef763b4874bd0b2e1
--- /dev/null
+++ b/diffusers/docs/source/ko/using-diffusers/inpaint.md
@@ -0,0 +1,75 @@
+
+
+# Text-guided 이미지 인페인팅(inpainting)
+
+[[open-in-colab]]
+
+[`StableDiffusionInpaintPipeline`]은 마스크와 텍스트 프롬프트를 제공하여 이미지의 특정 부분을 편집할 수 있도록 합니다. 이 기능은 인페인팅 작업을 위해 특별히 훈련된 [`runwayml/stable-diffusion-inpainting`](https://huggingface.co/runwayml/stable-diffusion-inpainting)과 같은 Stable Diffusion 버전을 사용합니다.
+
+먼저 [`StableDiffusionInpaintPipeline`] 인스턴스를 불러옵니다:
+
+```python
+import PIL
+import requests
+import torch
+from io import BytesIO
+
+from diffusers import StableDiffusionInpaintPipeline
+
+pipeline = StableDiffusionInpaintPipeline.from_pretrained(
+ "runwayml/stable-diffusion-inpainting",
+ torch_dtype=torch.float16,
+)
+pipeline = pipeline.to("cuda")
+```
+
+나중에 교체할 강아지 이미지와 마스크를 다운로드하세요:
+
+```python
+def download_image(url):
+ response = requests.get(url)
+ return PIL.Image.open(BytesIO(response.content)).convert("RGB")
+
+
+img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+init_image = download_image(img_url).resize((512, 512))
+mask_image = download_image(mask_url).resize((512, 512))
+```
+
+이제 마스크를 다른 것으로 교체하라는 프롬프트를 만들 수 있습니다:
+
+```python
+prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
+image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
+```
+
+`image` | `mask_image` | `prompt` | output |
+:-------------------------:|:-------------------------:|:-------------------------:|-------------------------:|
+ | | ***Face of a yellow cat, high resolution, sitting on a park bench*** | |
+
+
+
+이전의 실험적인 인페인팅 구현에서는 품질이 낮은 다른 프로세스를 사용했습니다. 이전 버전과의 호환성을 보장하기 위해 새 모델이 포함되지 않은 사전학습된 파이프라인을 불러오면 이전 인페인팅 방법이 계속 적용됩니다.
+
+
+
+아래 Space에서 이미지 인페인팅을 직접 해보세요!
+
+
diff --git a/diffusers/docs/source/ko/using-diffusers/loading.md b/diffusers/docs/source/ko/using-diffusers/loading.md
new file mode 100644
index 0000000000000000000000000000000000000000..8b21ed4478b134ab696c1eb9b77eb0d9db25b293
--- /dev/null
+++ b/diffusers/docs/source/ko/using-diffusers/loading.md
@@ -0,0 +1,442 @@
+
+
+
+
+# 파이프라인, 모델, 스케줄러 불러오기
+
+기본적으로 diffusion 모델은 다양한 컴포넌트들(모델, 토크나이저, 스케줄러) 간의 복잡한 상호작용을 기반으로 동작합니다. 디퓨저스(Diffusers)는 이러한 diffusion 모델을 보다 쉽고 간편한 API로 제공하는 것을 목표로 설계되었습니다. [`DiffusionPipeline`]은 diffusion 모델이 갖는 복잡성을 하나의 파이프라인 API로 통합하고, 동시에 이를 구성하는 각각의 컴포넌트들을 태스크에 맞춰 유연하게 커스터마이징할 수 있도록 지원하고 있습니다.
+
+diffusion 모델의 훈련과 추론에 필요한 모든 것은 [`DiffusionPipeline.from_pretrained`] 메서드를 통해 접근할 수 있습니다. (이 말의 의미는 다음 단락에서 보다 자세하게 다뤄보도록 하겠습니다.)
+
+이 문서에서는 설명할 내용은 다음과 같습니다.
+
+* 허브를 통해 혹은 로컬로 파이프라인을 불러오는 법
+
+* 파이프라인에 다른 컴포넌트들을 적용하는 법
+* 오리지널 체크포인트가 아닌 variant를 불러오는 법 (variant란 기본으로 설정된 `fp32`가 아닌 다른 부동 소수점 타입(예: `fp16`)을 사용하거나 Non-EMA 가중치를 사용하는 체크포인트들을 의미합니다.)
+* 모델과 스케줄러를 불러오는 법
+
+
+
+## Diffusion 파이프라인
+
+
+
+💡 [`DiffusionPipeline`] 클래스가 동작하는 방식에 보다 자세한 내용이 궁금하다면, [DiffusionPipeline explained](#diffusionpipeline에-대해-알아보기) 섹션을 확인해보세요.
+
+
+
+[`DiffusionPipeline`] 클래스는 diffusion 모델을 [허브](https://huggingface.co/models?library=diffusers)로부터 불러오는 가장 심플하면서 보편적인 방식입니다. [`DiffusionPipeline.from_pretrained`] 메서드는 적합한 파이프라인 클래스를 자동으로 탐지하고, 필요한 구성요소(configuration)와 가중치(weight) 파일들을 다운로드하고 캐싱한 다음, 해당 파이프라인 인스턴스를 반환합니다.
+
+```python
+from diffusers import DiffusionPipeline
+
+repo_id = "runwayml/stable-diffusion-v1-5"
+pipe = DiffusionPipeline.from_pretrained(repo_id)
+```
+
+물론 [`DiffusionPipeline`] 클래스를 사용하지 않고, 명시적으로 직접 해당 파이프라인 클래스를 불러오는 것도 가능합니다. 아래 예시 코드는 위 예시와 동일한 인스턴스를 반환합니다.
+
+```python
+from diffusers import StableDiffusionPipeline
+
+repo_id = "runwayml/stable-diffusion-v1-5"
+pipe = StableDiffusionPipeline.from_pretrained(repo_id)
+```
+
+[CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4)이나 [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) 같은 체크포인트들의 경우, 하나 이상의 다양한 태스크에 활용될 수 있습니다. (예를 들어 위의 두 체크포인트의 경우, text-to-image와 image-to-image에 모두 활용될 수 있습니다.) 만약 이러한 체크포인트들을 기본 설정 태스크가 아닌 다른 태스크에 활용하고자 한다면, 해당 태스크에 대응되는 파이프라인(task-specific pipeline)을 사용해야 합니다.
+
+```python
+from diffusers import StableDiffusionImg2ImgPipeline
+
+repo_id = "runwayml/stable-diffusion-v1-5"
+pipe = StableDiffusionImg2ImgPipeline.from_pretrained(repo_id)
+```
+
+
+
+### 로컬 파이프라인
+
+파이프라인을 로컬로 불러오고자 한다면, `git-lfs`를 사용하여 직접 체크포인트를 로컬 디스크에 다운로드 받아야 합니다. 아래의 명령어를 실행하면 `./stable-diffusion-v1-5`란 이름으로 폴더가 로컬디스크에 생성됩니다.
+
+```bash
+git lfs install
+git clone https://huggingface.co/runwayml/stable-diffusion-v1-5
+```
+
+그런 다음 해당 로컬 경로를 [`~DiffusionPipeline.from_pretrained`] 메서드에 전달합니다.
+
+```python
+from diffusers import DiffusionPipeline
+
+repo_id = "./stable-diffusion-v1-5"
+stable_diffusion = DiffusionPipeline.from_pretrained(repo_id)
+```
+
+위의 예시코드처럼 만약 `repo_id`가 로컬 패스(local path)라면, [`~DiffusionPipeline.from_pretrained`] 메서드는 이를 자동으로 감지하여 허브에서 파일을 다운로드하지 않습니다. 만약 로컬 디스크에 저장된 파이프라인 체크포인트가 최신 버전이 아닐 경우에도, 최신 버전을 다운로드하지 않고 기존 로컬 디스크에 저장된 체크포인트를 사용한다는 것을 의미합니다.
+
+
+
+### 파이프라인 내부의 컴포넌트 교체하기
+
+파이프라인 내부의 컴포넌트들은 호환 가능한 다른 컴포넌트로 교체될 수 있습니다. 이와 같은 컴포넌트 교체가 중요한 이유는 다음과 같습니다.
+
+- 어떤 스케줄러를 사용할 것인가는 생성속도와 생성품질 간의 트레이드오프를 정의하는 중요한 요소입니다.
+- diffusion 모델 내부의 컴포넌트들은 일반적으로 각각 독립적으로 훈련되기 때문에, 더 좋은 성능을 보여주는 컴포넌트가 있다면 그걸로 교체하는 식으로 성능을 향상시킬 수 있습니다.
+- 파인 튜닝 단계에서는 일반적으로 UNet 혹은 텍스트 인코더와 같은 일부 컴포넌트들만 훈련하게 됩니다.
+
+어떤 스케줄러들이 호환가능한지는 `compatibles` 속성을 통해 확인할 수 있습니다.
+
+```python
+from diffusers import DiffusionPipeline
+
+repo_id = "runwayml/stable-diffusion-v1-5"
+stable_diffusion = DiffusionPipeline.from_pretrained(repo_id)
+stable_diffusion.scheduler.compatibles
+```
+
+이번에는 [`SchedulerMixin.from_pretrained`] 메서드를 사용해서, 기존 기본 스케줄러였던 [`PNDMScheduler`]를 보다 우수한 성능의 [`EulerDiscreteScheduler`]로 바꿔봅시다. 스케줄러를 로드할 때는 `subfolder` 인자를 통해, 해당 파이프라인의 리포지토리에서 [스케줄러에 관한 하위폴더](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/scheduler)를 명시해주어야 합니다.
+
+그 다음 새롭게 생성한 [`EulerDiscreteScheduler`] 인스턴스를 [`DiffusionPipeline`]의 `scheduler` 인자에 전달합니다.
+
+```python
+from diffusers import DiffusionPipeline, EulerDiscreteScheduler, DPMSolverMultistepScheduler
+
+repo_id = "runwayml/stable-diffusion-v1-5"
+
+scheduler = EulerDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
+
+stable_diffusion = DiffusionPipeline.from_pretrained(repo_id, scheduler=scheduler)
+```
+
+### 세이프티 체커
+
+스테이블 diffusion과 같은 diffusion 모델들은 유해한 이미지를 생성할 수도 있습니다. 이를 예방하기 위해 디퓨저스는 생성된 이미지의 유해성을 판단하는 [세이프티 체커(safety checker)](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py) 기능을 지원하고 있습니다. 만약 세이프티 체커의 사용을 원하지 않는다면, `safety_checker` 인자에 `None`을 전달해주시면 됩니다.
+
+```python
+from diffusers import DiffusionPipeline
+
+repo_id = "runwayml/stable-diffusion-v1-5"
+stable_diffusion = DiffusionPipeline.from_pretrained(repo_id, safety_checker=None)
+```
+
+### 컴포넌트 재사용
+
+복수의 파이프라인에서 동일한 모델을 반복적으로 사용한다면, 굳이 해당 모델의 동일한 가중치를 중복으로 RAM에 불러올 필요는 없습니다. [`~DiffusionPipeline.components`] 속성을 통해 파이프라인 내부의 컴포넌트들을 참조할 수 있는데, 이번 단락에서는 이를 통해 동일한 모델 가중치를 RAM에 중복으로 불러오는 것을 방지하는 법에 대해 알아보겠습니다.
+
+```python
+from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
+
+model_id = "runwayml/stable-diffusion-v1-5"
+stable_diffusion_txt2img = StableDiffusionPipeline.from_pretrained(model_id)
+
+components = stable_diffusion_txt2img.components
+```
+
+그 다음 위 예시 코드에서 선언한 `components` 변수를 다른 파이프라인에 전달함으로써, 모델의 가중치를 중복으로 RAM에 로딩하지 않고, 동일한 컴포넌트를 재사용할 수 있습니다.
+
+```python
+stable_diffusion_img2img = StableDiffusionImg2ImgPipeline(**components)
+```
+
+물론 각각의 컴포넌트들을 따로 따로 파이프라인에 전달할 수도 있습니다. 예를 들어 `stable_diffusion_txt2img` 파이프라인 안의 컴포넌트들 가운데서 세이프티 체커(`safety_checker`)와 피쳐 익스트랙터(`feature_extractor`)를 제외한 컴포넌트들만 `stable_diffusion_img2img` 파이프라인에서 재사용하는 방식 역시 가능합니다.
+
+```python
+from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
+
+model_id = "runwayml/stable-diffusion-v1-5"
+stable_diffusion_txt2img = StableDiffusionPipeline.from_pretrained(model_id)
+stable_diffusion_img2img = StableDiffusionImg2ImgPipeline(
+ vae=stable_diffusion_txt2img.vae,
+ text_encoder=stable_diffusion_txt2img.text_encoder,
+ tokenizer=stable_diffusion_txt2img.tokenizer,
+ unet=stable_diffusion_txt2img.unet,
+ scheduler=stable_diffusion_txt2img.scheduler,
+ safety_checker=None,
+ feature_extractor=None,
+ requires_safety_checker=False,
+)
+```
+
+## Checkpoint variants
+
+Variant란 일반적으로 다음과 같은 체크포인트들을 의미합니다.
+
+- `torch.float16`과 같이 정밀도는 더 낮지만, 용량 역시 더 작은 부동소수점 타입의 가중치를 사용하는 체크포인트. *(다만 이와 같은 variant의 경우, 추가적인 훈련과 CPU환경에서의 구동이 불가능합니다.)*
+- Non-EMA 가중치를 사용하는 체크포인트. *(Non-EMA 가중치의 경우, 파인 튜닝 단계에서 사용하는 것이 권장되는데, 추론 단계에선 사용하지 않는 것이 권장됩니다.)*
+
+
+
+💡 모델 구조는 동일하지만 서로 다른 학습 환경에서 서로 다른 데이터셋으로 학습된 체크포인트들이 있을 경우, 해당 체크포인트들은 variant 단계가 아닌 리포지토리 단계에서 분리되어 관리되어야 합니다. (즉, 해당 체크포인트들은 서로 다른 리포지토리에서 따로 관리되어야 합니다. 예시: [`stable-diffusion-v1-4`], [`stable-diffusion-v1-5`]).
+
+
+
+| **checkpoint type** | **weight name** | **argument for loading weights** |
+| ------------------- | ----------------------------------- | -------------------------------- |
+| original | diffusion_pytorch_model.bin | |
+| floating point | diffusion_pytorch_model.fp16.bin | `variant`, `torch_dtype` |
+| non-EMA | diffusion_pytorch_model.non_ema.bin | `variant` |
+
+variant를 로드할 때 2개의 중요한 argument가 있습니다.
+
+* `torch_dtype`은 불러올 체크포인트의 부동소수점 타입을 정의합니다. 예를 들어 `torch_dtype=torch.float16`을 명시함으로써 가중치의 부동소수점 타입을 `fp16`으로 변환할 수 있습니다. (따로 설정하지 않을 경우, 기본값으로 `fp32` 타입의 가중치가 로딩됩니다.) 또한 `variant` 인자를 명시하지 않은 채로 체크포인트를 불러온 다음, 해당 체크포인트를 `torch_dtype=torch.float16` 인자를 통해 `fp16` 타입으로 변환하는 것 역시 가능합니다. 이 경우 기본으로 설정된 `fp32` 가중치가 먼저 다운로드된 후, 불러오는 시점에 `fp16` 타입으로 변환됩니다.
+* `variant` 인자는 리포지토리에서 어떤 variant를 불러올 것인가를 정의합니다. 가령 [`diffusers/stable-diffusion-variants`](https://huggingface.co/diffusers/stable-diffusion-variants/tree/main/unet) 리포지토리로부터 `non_ema` 체크포인트를 불러오고자 한다면, `variant="non_ema"` 인자를 전달해야 합니다.
+
+```python
+import torch
+
+from diffusers import DiffusionPipeline
+
+# load fp16 variant
+stable_diffusion = DiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", variant="fp16", torch_dtype=torch.float16
+)
+# load non_ema variant
+stable_diffusion = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", variant="non_ema")
+```
+
+다른 부동소수점 타입의 가중치 혹은 non-EMA 가중치를 사용하는 체크포인트를 저장하기 위해서는, [`DiffusionPipeline.save_pretrained`] 메서드를 사용해야 하며, 이 때 `variant` 인자를 명시해줘야 합니다. 원래의 체크포인트와 동일한 폴더에 variant를 저장해야 하며, 이렇게 하면 동일한 폴더에서 오리지널 체크포인트과 variant를 모두 불러올 수 있습니다.
+
+```python
+from diffusers import DiffusionPipeline
+
+# save as fp16 variant
+stable_diffusion.save_pretrained("runwayml/stable-diffusion-v1-5", variant="fp16")
+# save as non-ema variant
+stable_diffusion.save_pretrained("runwayml/stable-diffusion-v1-5", variant="non_ema")
+```
+
+만약 variant를 기존 폴더에 저장하지 않을 경우, `variant` 인자를 반드시 명시해야 합니다. 그렇게 하지 않을 경우 원래의 오리지널 체크포인트를 찾을 수 없게 되기 때문에 에러가 발생합니다.
+
+```python
+# 👎 this won't work
+stable_diffusion = DiffusionPipeline.from_pretrained("./stable-diffusion-v1-5", torch_dtype=torch.float16)
+# 👍 this works
+stable_diffusion = DiffusionPipeline.from_pretrained(
+ "./stable-diffusion-v1-5", variant="fp16", torch_dtype=torch.float16
+)
+```
+
+### 모델 불러오기
+
+모델들은 [`ModelMixin.from_pretrained`] 메서드를 통해 불러올 수 있습니다. 해당 메서드는 최신 버전의 모델 가중치 파일과 설정 파일(configurations)을 다운로드하고 캐싱합니다. 만약 이러한 파일들이 최신 버전으로 로컬 캐시에 저장되어 있다면, [`ModelMixin.from_pretrained`]는 굳이 해당 파일들을 다시 다운로드하지 않으며, 그저 캐시에 있는 최신 파일들을 재사용합니다.
+
+모델은 `subfolder` 인자에 명시된 하위 폴더로부터 로드됩니다. 예를 들어 `runwayml/stable-diffusion-v1-5`의 UNet 모델의 가중치는 [`unet`](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/unet) 폴더에 저장되어 있습니다.
+
+```python
+from diffusers import UNet2DConditionModel
+
+repo_id = "runwayml/stable-diffusion-v1-5"
+model = UNet2DConditionModel.from_pretrained(repo_id, subfolder="unet")
+```
+
+혹은 [해당 모델의 리포지토리](https://huggingface.co/google/ddpm-cifar10-32/tree/main)로부터 다이렉트로 가져오는 것 역시 가능합니다.
+
+```python
+from diffusers import UNet2DModel
+
+repo_id = "google/ddpm-cifar10-32"
+model = UNet2DModel.from_pretrained(repo_id)
+```
+
+또한 앞서 봤던 `variant` 인자를 명시함으로써, Non-EMA나 `fp16`의 가중치를 가져오는 것 역시 가능합니다.
+
+```python
+from diffusers import UNet2DConditionModel
+
+model = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet", variant="non-ema")
+model.save_pretrained("./local-unet", variant="non-ema")
+```
+
+### 스케줄러
+
+스케줄러들은 [`SchedulerMixin.from_pretrained`] 메서드를 통해 불러올 수 있습니다. 모델과 달리 스케줄러는 별도의 가중치를 갖지 않으며, 따라서 별도의 학습 과정도 필요하지 않습니다. 이러한 스케줄러들은 (해당 스케줄러 하위폴더의) configuration 파일을 통해 정의됩니다.
+
+여러 개의 스케줄러를 불러온다고 해서 많은 메모리를 소모하는 것은 아니며, 다양한 스케줄러들에 동일한 스케줄러 configuration을 적용하는 것 역시 가능합니다. 다음 예시 코드에서 불러오는 스케줄러들은 모두 [`StableDiffusionPipeline`]과 호환되는데, 이는 곧 해당 스케줄러들에 동일한 스케줄러 configuration 파일을 적용할 수 있음을 의미합니다.
+
+```python
+from diffusers import StableDiffusionPipeline
+from diffusers import (
+ DDPMScheduler,
+ DDIMScheduler,
+ PNDMScheduler,
+ LMSDiscreteScheduler,
+ EulerDiscreteScheduler,
+ EulerAncestralDiscreteScheduler,
+ DPMSolverMultistepScheduler,
+)
+
+repo_id = "runwayml/stable-diffusion-v1-5"
+
+ddpm = DDPMScheduler.from_pretrained(repo_id, subfolder="scheduler")
+ddim = DDIMScheduler.from_pretrained(repo_id, subfolder="scheduler")
+pndm = PNDMScheduler.from_pretrained(repo_id, subfolder="scheduler")
+lms = LMSDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
+euler_anc = EulerAncestralDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
+euler = EulerDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
+dpm = DPMSolverMultistepScheduler.from_pretrained(repo_id, subfolder="scheduler")
+
+# replace `dpm` with any of `ddpm`, `ddim`, `pndm`, `lms`, `euler_anc`, `euler`
+pipeline = StableDiffusionPipeline.from_pretrained(repo_id, scheduler=dpm)
+```
+
+### DiffusionPipeline에 대해 알아보기
+
+클래스 메서드로서 [`DiffusionPipeline.from_pretrained`]은 2가지를 담당합니다.
+
+- 첫째로, `from_pretrained` 메서드는 최신 버전의 파이프라인을 다운로드하고, 캐시에 저장합니다. 이미 로컬 캐시에 최신 버전의 파이프라인이 저장되어 있다면, [`DiffusionPipeline.from_pretrained`]은 해당 파일들을 다시 다운로드하지 않고, 로컬 캐시에 저장되어 있는 파이프라인을 불러옵니다.
+- `model_index.json` 파일을 통해 체크포인트에 대응되는 적합한 파이프라인 클래스로 불러옵니다.
+
+파이프라인의 폴더 구조는 해당 파이프라인 클래스의 구조와 직접적으로 일치합니다. 예를 들어 [`StableDiffusionPipeline`] 클래스는 [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) 리포지토리와 대응되는 구조를 갖습니다.
+
+```python
+from diffusers import DiffusionPipeline
+
+repo_id = "runwayml/stable-diffusion-v1-5"
+pipeline = DiffusionPipeline.from_pretrained(repo_id)
+print(pipeline)
+```
+
+위의 코드 출력 결과를 확인해보면, `pipeline`은 [`StableDiffusionPipeline`]의 인스턴스이며, 다음과 같이 총 7개의 컴포넌트로 구성된다는 것을 알 수 있습니다.
+
+- `"feature_extractor"`: [`~transformers.CLIPImageProcessor`]의 인스턴스
+- `"safety_checker"`: 유해한 컨텐츠를 스크리닝하기 위한 [컴포넌트](https://github.com/huggingface/diffusers/blob/e55687e1e15407f60f32242027b7bb8170e58266/src/diffusers/pipelines/stable_diffusion/safety_checker.py#L32)
+- `"scheduler"`: [`PNDMScheduler`]의 인스턴스
+- `"text_encoder"`: [`~transformers.CLIPTextModel`]의 인스턴스
+- `"tokenizer"`: [`~transformers.CLIPTokenizer`]의 인스턴스
+- `"unet"`: [`UNet2DConditionModel`]의 인스턴스
+- `"vae"`: [`AutoencoderKL`]의 인스턴스
+
+```json
+StableDiffusionPipeline {
+ "feature_extractor": [
+ "transformers",
+ "CLIPImageProcessor"
+ ],
+ "safety_checker": [
+ "stable_diffusion",
+ "StableDiffusionSafetyChecker"
+ ],
+ "scheduler": [
+ "diffusers",
+ "PNDMScheduler"
+ ],
+ "text_encoder": [
+ "transformers",
+ "CLIPTextModel"
+ ],
+ "tokenizer": [
+ "transformers",
+ "CLIPTokenizer"
+ ],
+ "unet": [
+ "diffusers",
+ "UNet2DConditionModel"
+ ],
+ "vae": [
+ "diffusers",
+ "AutoencoderKL"
+ ]
+}
+```
+
+파이프라인 인스턴스의 컴포넌트들을 [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5)의 폴더 구조와 비교해볼 경우, 각각의 컴포넌트마다 별도의 폴더가 있음을 확인할 수 있습니다.
+
+```
+.
+├── feature_extractor
+│ └── preprocessor_config.json
+├── model_index.json
+├── safety_checker
+│ ├── config.json
+│ └── pytorch_model.bin
+├── scheduler
+│ └── scheduler_config.json
+├── text_encoder
+│ ├── config.json
+│ └── pytorch_model.bin
+├── tokenizer
+│ ├── merges.txt
+│ ├── special_tokens_map.json
+│ ├── tokenizer_config.json
+│ └── vocab.json
+├── unet
+│ ├── config.json
+│ ├── diffusion_pytorch_model.bin
+└── vae
+ ├── config.json
+ ├── diffusion_pytorch_model.bin
+```
+
+또한 각각의 컴포넌트들을 파이프라인 인스턴스의 속성으로써 참조할 수 있습니다.
+
+```py
+pipeline.tokenizer
+```
+
+```python
+CLIPTokenizer(
+ name_or_path="/root/.cache/huggingface/hub/models--runwayml--stable-diffusion-v1-5/snapshots/39593d5650112b4cc580433f6b0435385882d819/tokenizer",
+ vocab_size=49408,
+ model_max_length=77,
+ is_fast=False,
+ padding_side="right",
+ truncation_side="right",
+ special_tokens={
+ "bos_token": AddedToken("<|startoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True),
+ "eos_token": AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True),
+ "unk_token": AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True),
+ "pad_token": "<|endoftext|>",
+ },
+)
+```
+
+모든 파이프라인은 `model_index.json` 파일을 통해 [`DiffusionPipeline`]에 다음과 같은 정보를 전달합니다.
+
+- `_class_name` 는 어떤 파이프라인 클래스를 사용해야 하는지에 대해 알려줍니다.
+- `_diffusers_version`는 어떤 버전의 디퓨저스로 파이프라인 안의 모델들이 만들어졌는지를 알려줍니다.
+- 그 다음은 각각의 컴포넌트들이 어떤 라이브러리의 어떤 클래스로 만들어졌는지에 대해 알려줍니다. (아래 예시에서 `"feature_extractor" : ["transformers", "CLIPImageProcessor"]`의 경우, `feature_extractor` 컴포넌트는 `transformers` 라이브러리의 `CLIPImageProcessor` 클래스를 통해 만들어졌다는 것을 의미합니다.)
+
+```json
+{
+ "_class_name": "StableDiffusionPipeline",
+ "_diffusers_version": "0.6.0",
+ "feature_extractor": [
+ "transformers",
+ "CLIPImageProcessor"
+ ],
+ "safety_checker": [
+ "stable_diffusion",
+ "StableDiffusionSafetyChecker"
+ ],
+ "scheduler": [
+ "diffusers",
+ "PNDMScheduler"
+ ],
+ "text_encoder": [
+ "transformers",
+ "CLIPTextModel"
+ ],
+ "tokenizer": [
+ "transformers",
+ "CLIPTokenizer"
+ ],
+ "unet": [
+ "diffusers",
+ "UNet2DConditionModel"
+ ],
+ "vae": [
+ "diffusers",
+ "AutoencoderKL"
+ ]
+}
+```
+
diff --git a/diffusers/docs/source/ko/using-diffusers/loading_overview.md b/diffusers/docs/source/ko/using-diffusers/loading_overview.md
new file mode 100644
index 0000000000000000000000000000000000000000..a99c6b04c8f6ec26669918b85f6937fea9afb5d0
--- /dev/null
+++ b/diffusers/docs/source/ko/using-diffusers/loading_overview.md
@@ -0,0 +1,18 @@
+
+
+# Overview
+
+🧨 Diffusers는 생성 작업을 위한 다양한 파이프라인, 모델, 스케줄러를 제공합니다. 이러한 컴포넌트를 최대한 간단하게 로드할 수 있도록 단일 통합 메서드인 `from_pretrained()`를 제공하여 Hugging Face [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) 또는 로컬 머신에서 이러한 컴포넌트를 불러올 수 있습니다. 파이프라인이나 모델을 로드할 때마다, 최신 파일이 자동으로 다운로드되고 캐시되므로, 다음에 파일을 다시 다운로드하지 않고도 빠르게 재사용할 수 있습니다.
+
+이 섹션은 파이프라인 로딩, 파이프라인에서 다양한 컴포넌트를 로드하는 방법, 체크포인트 variants를 불러오는 방법, 그리고 커뮤니티 파이프라인을 불러오는 방법에 대해 알아야 할 모든 것들을 다룹니다. 또한 스케줄러를 불러오는 방법과 서로 다른 스케줄러를 사용할 때 발생하는 속도와 품질간의 트레이드 오프를 비교하는 방법 역시 다룹니다. 그리고 마지막으로 🧨 Diffusers와 함께 파이토치에서 사용할 수 있도록 KerasCV 체크포인트를 변환하고 불러오는 방법을 살펴봅니다.
+
diff --git a/diffusers/docs/source/ko/using-diffusers/other-formats.md b/diffusers/docs/source/ko/using-diffusers/other-formats.md
new file mode 100644
index 0000000000000000000000000000000000000000..b0aab5b0cc9f80c319fb39e8a1ec08b46ebd4320
--- /dev/null
+++ b/diffusers/docs/source/ko/using-diffusers/other-formats.md
@@ -0,0 +1,191 @@
+
+
+# 다양한 Stable Diffusion 포맷 불러오기
+
+Stable Diffusion 모델들은 학습 및 저장된 프레임워크와 다운로드 위치에 따라 다양한 형식으로 제공됩니다. 이러한 형식을 🤗 Diffusers에서 사용할 수 있도록 변환하면 추론을 위한 [다양한 스케줄러 사용](schedulers), 사용자 지정 파이프라인 구축, 추론 속도 최적화를 위한 다양한 기법과 방법 등 라이브러리에서 지원하는 모든 기능을 사용할 수 있습니다.
+
+
+
+우리는 `.safetensors` 형식을 추천합니다. 기존의 pickled 파일은 취약하여 머신에서 코드를 실행하는 데 악용될 수 있는 반면, `.safetensors`는 훨씬 더 안전하기 때문입니다. (safetensors 불러오기 가이드에서 자세히 알아보세요.)
+
+
+
+이 가이드에서는 다른 Stable Diffusion 형식을 🤗 Diffusers와 호환되도록 변환하는 방법을 설명합니다.
+
+## PyTorch .ckpt
+
+체크포인트 또는 `.ckpt` 형식은 일반적으로 모델을 저장하는 데 사용됩니다. `.ckpt` 파일은 전체 모델을 포함하며 일반적으로 크기가 몇 GB입니다. `.ckpt` 파일을 [`~StableDiffusionPipeline.from_ckpt`] 메서드를 사용하여 직접 불러와 사용할 수도 있지만, 두 형식을 모두 활용할 수 있도록 `.ckpt` 파일을 🤗 Diffusers 형식으로 변환하는 것이 일반적으로 더 좋습니다.
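+
+예를 들어, 위에서 언급한 [`~StableDiffusionPipeline.from_ckpt`] 메서드로 로컬 `.ckpt` 파일을 직접 불러오는 최소 스케치는 다음과 같습니다(파일 경로는 예시용 가정입니다):
+
+```py
+from diffusers import StableDiffusionPipeline
+
+# 로컬 .ckpt 파일을 별도 변환 없이 바로 불러오기 (경로는 예시)
+pipeline = StableDiffusionPipeline.from_ckpt("./models/my_model.ckpt")
+```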
+
+`.ckpt` 파일을 변환하는 두 가지 옵션이 있습니다. Space를 사용하여 체크포인트를 변환하거나 스크립트를 사용하여 `.ckpt` 파일을 변환합니다.
+
+### Space로 변환하기
+
+`.ckpt` 파일을 변환하는 가장 쉽고 편리한 방법은 SD to Diffusers Space를 사용하는 것입니다. Space의 지침에 따라 `.ckpt` 파일을 변환할 수 있습니다.
+
+이 접근 방식은 기본 모델에서는 잘 작동하지만, 커스터마이징이 많이 된 모델에서는 어려움을 겪을 수 있습니다. Space가 빈 pull request나 오류를 반환한다면 변환에 실패한 것입니다.
+이 경우 스크립트를 사용하여 `.ckpt` 파일을 변환해 볼 수 있습니다.
+
+### 스크립트로 변환하기
+
+🤗 Diffusers는 `.ckpt` 파일 변환을 위한 변환 스크립트를 제공합니다. 이 접근 방식은 위의 Space보다 더 안정적입니다.
+
+시작하기 전에 스크립트를 실행할 🤗 Diffusers의 로컬 클론(clone)이 있는지 확인하고 Hugging Face 계정에 로그인하여 pull request를 열고 변환된 모델을 허브에 푸시할 수 있도록 하세요.
+
+```bash
+huggingface-cli login
+```
+
+스크립트를 사용하려면:
+
+1. 변환하려는 `.ckpt` 파일이 포함된 리포지토리를 Git으로 클론(clone)합니다.
+
+이 예제에서는 TemporalNet .ckpt 파일을 변환해 보겠습니다:
+
+```bash
+git lfs install
+git clone https://huggingface.co/CiaraRowles/TemporalNet
+```
+
+2. 체크포인트를 변환할 리포지토리에서 pull request를 엽니다:
+
+```bash
+cd TemporalNet && git fetch origin refs/pr/13:pr/13
+git checkout pr/13
+```
+
+3. 변환 스크립트에서 구성할 입력 인수는 여러 가지가 있지만 가장 중요한 인수는 다음과 같습니다:
+
+- `checkpoint_path`: 변환할 `.ckpt` 파일의 경로를 입력합니다.
+- `original_config_file`: 원래 아키텍처의 구성을 정의하는 YAML 파일입니다. 이 파일을 찾을 수 없는 경우 `.ckpt` 파일을 찾은 GitHub 리포지토리에서 YAML 파일을 검색해 보세요.
+- `dump_path`: 변환된 모델의 경로
+
+예를 들어, TemporalNet 모델은 Stable Diffusion v1.5 및 ControlNet 모델이기 때문에 ControlNet 리포지토리에서 cldm_v15.yaml 파일을 가져올 수 있습니다.
+
+4. 이제 스크립트를 실행하여 .ckpt 파일을 변환할 수 있습니다:
+
+```bash
+python ../diffusers/scripts/convert_original_stable_diffusion_to_diffusers.py --checkpoint_path temporalnetv3.ckpt --original_config_file cldm_v15.yaml --dump_path ./ --controlnet
+```
+
+5. 변환이 완료되면, 변환된 모델을 업로드하고 그 결과물인 [pull request](https://huggingface.co/CiaraRowles/TemporalNet/discussions/13)를 테스트하세요!
+
+```bash
+git push origin pr/13:refs/pr/13
+```
+
+## Keras .pb or .h5
+
+🧪 이 기능은 실험적인 기능입니다. 현재 Convert KerasCV Space에서는 Stable Diffusion v1 체크포인트만 지원됩니다.
+
+[KerasCV](https://keras.io/keras_cv/)는 [Stable Diffusion](https://github.com/keras-team/keras-cv/blob/master/keras_cv/models/stable_diffusion) v1 및 v2 학습을 지원합니다. 그러나 KerasCV는 추론 및 배포를 위한 Stable Diffusion 모델 실험을 제한적으로만 지원하는 반면, 🤗 Diffusers는 다양한 [noise schedulers](https://huggingface.co/docs/diffusers/using-diffusers/schedulers), [flash attention](https://huggingface.co/docs/diffusers/optimization/xformers), [기타 최적화 기법](https://huggingface.co/docs/diffusers/optimization/fp16) 등 이러한 목적을 위한 보다 완벽한 기능을 갖추고 있습니다.
+
+[Convert KerasCV](https://huggingface.co/spaces/sayakpaul/convert-kerascv-sd-diffusers) Space는 `.pb` 또는 `.h5` 파일을 PyTorch로 변환한 다음, 추론할 수 있도록 [`StableDiffusionPipeline`]으로 감싸서 준비합니다. 변환된 체크포인트는 Hugging Face Hub의 리포지토리에 저장됩니다.
+
+예제로, textual inversion으로 학습된 [`sayakpaul/textual-inversion-kerasio`](https://huggingface.co/sayakpaul/textual-inversion-kerasio/tree/main) 체크포인트를 변환해 보겠습니다. 이 체크포인트는 특수 placeholder 토큰을 사용하여 고양이 이미지로 개인화되어 있습니다.
+
+KerasCV Space 변환에서는 다음을 입력할 수 있습니다:
+
+- Hugging Face 토큰.
+- UNet과 텍스트 인코더(text encoder) 가중치를 다운로드할 경로입니다. 모델을 학습한 방식에 따라, UNet과 텍스트 인코더 경로를 반드시 모두 제공할 필요는 없습니다. 예를 들어, textual inversion에는 텍스트 인코더의 임베딩만 필요하고 text-to-image 모델 변환에는 UNet 가중치만 필요합니다.
+- Placeholder 토큰은 textual-inversion 모델에만 적용됩니다.
+- `output_repo_prefix`는 변환된 모델이 저장되는 리포지토리의 이름입니다.
+
+**Submit** (제출) 버튼을 클릭하면 KerasCV 체크포인트가 자동으로 변환됩니다! 체크포인트가 성공적으로 변환되면, 변환된 체크포인트가 포함된 새 리포지토리로 연결되는 링크가 표시됩니다. 링크를 따라가면, 변환된 모델을 바로 사용해 볼 수 있는 추론 위젯이 포함된 모델 카드를 확인할 수 있습니다.
+
+코드를 사용하여 추론을 실행하려면 모델 카드의 오른쪽 상단 모서리에 있는 **Use in Diffusers** 버튼을 클릭하여 예시 코드를 복사하여 붙여넣습니다:
+
+```py
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained("sayakpaul/textual-inversion-cat-kerascv_sd_diffusers_pipeline")
+```
+
+그러면 다음과 같은 이미지를 생성할 수 있습니다:
+
+```py
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained("sayakpaul/textual-inversion-cat-kerascv_sd_diffusers_pipeline")
+pipeline.to("cuda")
+
+placeholder_token = ""  # 학습 시 사용한 placeholder 토큰으로 교체하세요
+prompt = f"two {placeholder_token} getting married, photorealistic, high quality"
+image = pipeline(prompt, num_inference_steps=50).images[0]
+```
+
+## **A1111 LoRA files**
+
+[Automatic1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) (A1111)은 Stable Diffusion을 위해 널리 사용되는 웹 UI로, [Civitai](https://civitai.com/) 와 같은 모델 공유 플랫폼을 지원합니다. 특히 LoRA 기법으로 학습된 모델은 학습 속도가 빠르고 완전히 파인튜닝된 모델보다 파일 크기가 훨씬 작기 때문에 인기가 높습니다.
+
+🤗 Diffusers는 [`~loaders.LoraLoaderMixin.load_lora_weights`]를 사용하여 A1111 LoRA 체크포인트 불러오기를 지원합니다:
+
+```py
+from diffusers import DiffusionPipeline, UniPCMultistepScheduler
+import torch
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "andite/anything-v4.0", torch_dtype=torch.float16, safety_checker=None
+).to("cuda")
+pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
+```
+
+Civitai에서 LoRA 체크포인트를 다운로드하세요; 이 예제에서는 [Howls Moving Castle,Interior/Scenery LoRA (Ghibli Stlye)](https://civitai.com/models/14605?modelVersionId=19998) 체크포인트를 사용했지만, 어떤 LoRA 체크포인트든 자유롭게 사용해 보세요!
+
+```bash
+!wget https://civitai.com/api/download/models/19998 -O howls_moving_castle.safetensors
+```
+
+[`~loaders.LoraLoaderMixin.load_lora_weights`] 메서드를 사용하여 파이프라인에 LoRA 체크포인트를 불러옵니다:
+
+```py
+pipeline.load_lora_weights(".", weight_name="howls_moving_castle.safetensors")
+```
+
+이제 파이프라인을 사용하여 이미지를 생성할 수 있습니다:
+
+```py
+prompt = "masterpiece, illustration, ultra-detailed, cityscape, san francisco, golden gate bridge, california, bay area, in the snow, beautiful detailed starry sky"
+negative_prompt = "lowres, cropped, worst quality, low quality, normal quality, artifacts, signature, watermark, username, blurry, more than one bridge, bad architecture"
+
+images = pipeline(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ width=512,
+ height=512,
+ num_inference_steps=25,
+ num_images_per_prompt=4,
+ generator=torch.manual_seed(0),
+).images
+```
+
+마지막으로, 디스플레이에 이미지를 표시하는 헬퍼 함수를 만듭니다:
+
+```py
+from PIL import Image
+
+
+def image_grid(imgs, rows=2, cols=2):
+ w, h = imgs[0].size
+ grid = Image.new("RGB", size=(cols * w, rows * h))
+
+ for i, img in enumerate(imgs):
+ grid.paste(img, box=(i % cols * w, i // cols * h))
+ return grid
+
+
+image_grid(images)
+```
+
+
+
+
diff --git a/diffusers/docs/source/ko/using-diffusers/pipeline_overview.md b/diffusers/docs/source/ko/using-diffusers/pipeline_overview.md
new file mode 100644
index 0000000000000000000000000000000000000000..da39e738325fcf074a66215f1ecc27c8972ba8f5
--- /dev/null
+++ b/diffusers/docs/source/ko/using-diffusers/pipeline_overview.md
@@ -0,0 +1,17 @@
+
+
+# Overview
+
+파이프라인은 독립적으로 훈련된 모델과 스케줄러를 함께 모아서 추론을 위해 diffusion 시스템을 빠르고 쉽게 사용할 수 있는 방법을 제공하는 end-to-end 클래스입니다. 모델과 스케줄러의 특정 조합은 특수한 기능과 함께 [`StableDiffusionPipeline`] 또는 [`StableDiffusionControlNetPipeline`]과 같은 특정 파이프라인 유형을 정의합니다. 모든 파이프라인 유형은 기본 [`DiffusionPipeline`] 클래스에서 상속됩니다. 어떤 체크포인트를 전달하든, 파이프라인 유형을 자동으로 감지하고 필요한 구성 요소들을 불러옵니다.
+
+이 섹션에서는 unconditional 이미지 생성, text-to-image 생성 등 파이프라인이 지원하는 작업들과 그에 쓰이는 다양한 테크닉을 소개합니다. 또한 재현성을 위해 시드를 고정하는 방법과, 프롬프트에 가중치를 부여해 특정 단어가 출력에 미치는 영향을 조절함으로써 생성 프로세스를 더 잘 제어하는 방법을 배울 수 있습니다. 마지막으로 음성에서 이미지를 생성하는 것과 같은 커스텀 작업을 위한 커뮤니티 파이프라인을 만드는 방법도 알 수 있습니다.
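+
+예를 들어, 위에서 설명한 자동 감지는 아래와 같이 간단히 확인해볼 수 있습니다. Stable Diffusion 체크포인트를 전달하면 [`DiffusionPipeline`]이 알맞은 파이프라인 클래스를 골라 줍니다(모델 식별자는 예시이며, 출력되는 클래스 이름은 설치된 diffusers 버전에 따라 다를 수 있습니다).
+
+```py
+from diffusers import DiffusionPipeline
+
+# 체크포인트만 전달하면 파이프라인 유형이 자동으로 감지됩니다.
+pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+
+# Stable Diffusion 체크포인트이므로 StableDiffusionPipeline 인스턴스가 반환됩니다.
+print(type(pipeline).__name__)
+```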
diff --git a/diffusers/docs/source/ko/using-diffusers/reproducibility.md b/diffusers/docs/source/ko/using-diffusers/reproducibility.md
new file mode 100644
index 0000000000000000000000000000000000000000..fdbfa036caa870e080a857a8626596f2f7a9f2b7
--- /dev/null
+++ b/diffusers/docs/source/ko/using-diffusers/reproducibility.md
@@ -0,0 +1,201 @@
+
+
+# 재현 가능한 파이프라인 생성하기
+
+[[open-in-colab]]
+
+재현성은 테스트, 결과 재현, 그리고 [이미지 퀄리티 높이기](reusing_seeds)에서 중요합니다.
+그러나 무작위성은 실행할 때마다 파이프라인이 서로 다른 이미지를 생성할 수 있게 해 주기 때문에 diffusion 모델에 꼭 필요한 요소이기도 합니다.
+플랫폼 간에 정확하게 동일한 결과를 얻을 수는 없지만, 특정 허용 범위 내에서 릴리스 및 플랫폼 간에 결과를 재현할 수는 있습니다.
+그럼에도 diffusion 파이프라인과 체크포인트에 따라 허용 오차가 달라집니다.
+
+diffusion 모델에서 무작위성의 원천을 제어하거나 결정론적 알고리즘을 사용하는 방법을 이해하는 것이 중요한 이유입니다.
+
+
+
+💡 PyTorch의 [재현성에 관한 문서](https://pytorch.org/docs/stable/notes/randomness.html)를 꼭 읽어보길 추천합니다:
+
+> 완전하게 재현 가능한 결과는 PyTorch 릴리스, 개별 커밋, 혹은 서로 다른 플랫폼 간에 보장되지 않습니다.
+> 또한, 같은 seed를 사용하더라도 CPU와 GPU 실행 간에는 결과가 재현되지 않을 수 있습니다.
+
+
+
+## 무작위성 제어하기
+
+추론에서 파이프라인은 노이즈를 제거할 가우시안 노이즈를 생성하거나 스케줄링 단계에 노이즈를 더하는 등의 랜덤 샘플링에 크게 의존합니다.
+
+[DDIMPipeline](https://huggingface.co/docs/diffusers/v0.18.0/en/api/pipelines/ddim#diffusers.DDIMPipeline)에서 두 추론 단계 이후의 텐서 값을 살펴보세요:
+
+```python
+from diffusers import DDIMPipeline
+import numpy as np
+
+model_id = "google/ddpm-cifar10-32"
+
+# 모델과 스케줄러를 불러오기
+ddim = DDIMPipeline.from_pretrained(model_id)
+
+# 두 개의 단계에 대해서 파이프라인을 실행하고 numpy tensor로 값을 반환하기
+image = ddim(num_inference_steps=2, output_type="np").images
+print(np.abs(image).sum())
+```
+
+위의 코드를 실행하면 하나의 값이 나오지만, 다시 실행하면 다른 값이 나옵니다. 무슨 일이 일어나고 있는 걸까요?
+
+파이프라인이 실행될 때마다, [torch.randn](https://pytorch.org/docs/stable/generated/torch.randn.html)은
+단계적으로 노이즈가 제거되는 가우시안 노이즈를 생성하기 위해 매번 다른 랜덤 seed를 사용합니다.
+
+그러나 동일한 이미지를 안정적으로 생성해야 하는 경우에는 CPU에서 파이프라인을 실행하는지 GPU에서 실행하는지에 따라 달라집니다.
+
+### CPU
+
+CPU에서 재현 가능한 결과를 생성하려면, PyTorch [Generator](https://pytorch.org/docs/stable/generated/torch.randn.html)로 seed를 고정합니다:
+
+```python
+import torch
+from diffusers import DDIMPipeline
+import numpy as np
+
+model_id = "google/ddpm-cifar10-32"
+
+# 모델과 스케줄러 불러오기
+ddim = DDIMPipeline.from_pretrained(model_id)
+
+# 재현성을 위해 generator 만들기
+generator = torch.Generator(device="cpu").manual_seed(0)
+
+# 두 개의 단계에 대해서 파이프라인을 실행하고 numpy tensor로 값을 반환하기
+image = ddim(num_inference_steps=2, output_type="np", generator=generator).images
+print(np.abs(image).sum())
+```
+
+이제 위의 코드를 실행하면 seed를 가진 `Generator` 객체가 파이프라인의 모든 랜덤 함수에 전달되므로 항상 `1491.1711` 값이 출력됩니다.
+
+특정 하드웨어 및 PyTorch 버전에서 이 코드 예제를 실행하면 동일하지는 않더라도 유사한 결과를 얻을 수 있습니다.
+
+
+
+💡 처음에는 시드를 나타내는 정수값 대신에 `Generator` 개체를 파이프라인에 전달하는 것이 약간 비직관적일 수 있지만,
+`Generator`는 순차적으로 여러 파이프라인에 전달될 수 있는 *랜덤 상태(random state)*이기 때문에 PyTorch에서 확률론적 모델을 다룰 때 권장되는 설계입니다.
+
+
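+아래는 이 *랜덤 상태*가 어떻게 동작하는지 보여주는 간단한 스케치입니다(위에서 불러온 `ddim` 파이프라인과 `torch`를 그대로 사용한다고 가정합니다).
+
+```python
+# Generator의 내부 상태는 사용할 때마다 앞으로 진행됩니다.
+# 따라서 같은 Generator로 연속해서 호출하면 서로 다른 결과가 나오지만,
+# 시드를 다시 고정하고 처음부터 반복하면 전체 시퀀스는 재현됩니다.
+generator = torch.Generator(device="cpu").manual_seed(0)
+
+first = ddim(num_inference_steps=2, output_type="np", generator=generator).images
+second = ddim(num_inference_steps=2, output_type="np", generator=generator).images  # first와 다른 값
+```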
+
+### GPU
+
+예를 들면, GPU 상에서 같은 코드 예시를 실행하면:
+
+```python
+import torch
+from diffusers import DDIMPipeline
+import numpy as np
+
+model_id = "google/ddpm-cifar10-32"
+
+# 모델과 스케줄러 불러오기
+ddim = DDIMPipeline.from_pretrained(model_id)
+ddim.to("cuda")
+
+# 재현성을 위한 generator 만들기
+generator = torch.Generator(device="cuda").manual_seed(0)
+
+# 두 개의 단계에 대해서 파이프라인을 실행하고 numpy tensor로 값을 반환하기
+image = ddim(num_inference_steps=2, output_type="np", generator=generator).images
+print(np.abs(image).sum())
+```
+
+GPU가 CPU와 다른 난수 생성기를 사용하기 때문에 동일한 시드를 사용하더라도 결과가 같지 않습니다.
+
+이 문제를 피하기 위해 🧨 Diffusers는 CPU에서 임의의 노이즈를 생성한 다음 필요에 따라 텐서를 GPU로 이동시키는
+[randn_tensor()](https://huggingface.co/docs/diffusers/v0.18.0/en/api/utilities#diffusers.utils.randn_tensor) 함수를 가지고 있습니다.
+`randn_tensor` 함수는 파이프라인 내부 어디에서나 사용되므로, 파이프라인이 GPU에서 실행되더라도 **항상** CPU `Generator`를 전달할 수 있습니다.
+
+이제 결과가 훨씬 더 가까워졌습니다!
+
+```python
+import torch
+from diffusers import DDIMPipeline
+import numpy as np
+
+model_id = "google/ddpm-cifar10-32"
+
+# 모델과 스케줄러 불러오기
+ddim = DDIMPipeline.from_pretrained(model_id)
+ddim.to("cuda")
+
+#재현성을 위한 generator 만들기 (GPU에 올리지 않도록 조심한다!)
+generator = torch.manual_seed(0)
+
+# 두 개의 단계에 대해서 파이프라인을 실행하고 numpy tensor로 값을 반환하기
+image = ddim(num_inference_steps=2, output_type="np", generator=generator).images
+print(np.abs(image).sum())
+```
+
+
+
+💡 재현성이 중요한 경우에는 항상 CPU generator를 전달하는 것이 좋습니다.
+성능 손실은 무시해도 될 정도인 경우가 많으며, 파이프라인을 GPU에서 실행했을 때보다 훨씬 더 비슷한 값을 생성할 수 있습니다.
+
+
+
+마지막으로 [UnCLIPPipeline](https://huggingface.co/docs/diffusers/v0.18.0/en/api/pipelines/unclip#diffusers.UnCLIPPipeline)과 같은
+더 복잡한 파이프라인의 경우, 이들은 종종 정밀 오차 전파에 극도로 취약합니다.
+다른 GPU 하드웨어 또는 PyTorch 버전에서 유사한 결과를 기대하지 마세요.
+이 경우 완전한 재현성을 위해 완전히 동일한 하드웨어 및 PyTorch 버전을 실행해야 합니다.
+
+## 결정론적 알고리즘
+
+결정론적 알고리즘을 사용하여 재현 가능한 파이프라인을 생성하도록 PyTorch를 구성할 수도 있습니다.
+그러나 결정론적 알고리즘은 비결정론적 알고리즘보다 느리고 성능이 저하될 수 있습니다.
+하지만 재현성이 중요하다면, 이것이 최선의 방법입니다!
+
+둘 이상의 CUDA 스트림에서 작업이 시작될 때 비결정론적 동작이 발생합니다.
+이 문제를 방지하려면 환경 변수 [CUBLAS_WORKSPACE_CONFIG](https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility)를 `:16:8`로 설정해서
+런타임 중에 오직 하나의 버퍼 크기만 사용하도록 설정합니다.
+
+PyTorch는 일반적으로 가장 빠른 알고리즘을 선택하기 위해 여러 알고리즘을 벤치마킹합니다.
+하지만 재현성을 원하는 경우, 벤치마크가 매 순간 다른 알고리즘을 선택할 수 있기 때문에 이 기능을 사용하지 않도록 설정해야 합니다.
+마지막으로, [torch.use_deterministic_algorithms](https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html)에
+`True`를 전달하여 결정론적 알고리즘이 활성화되도록 합니다.
+
+```py
+import os
+import torch
+
+os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+
+torch.backends.cudnn.benchmark = False
+torch.use_deterministic_algorithms(True)
+```
+
+이제 동일한 파이프라인을 두번 실행하면 동일한 결과를 얻을 수 있습니다.
+
+```py
+import torch
+from diffusers import DDIMScheduler, StableDiffusionPipeline
+import numpy as np
+
+model_id = "runwayml/stable-diffusion-v1-5"
+pipe = StableDiffusionPipeline.from_pretrained(model_id).to("cuda")
+pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+g = torch.Generator(device="cuda")
+
+prompt = "A bear is playing a guitar on Times Square"
+
+g.manual_seed(0)
+result1 = pipe(prompt=prompt, num_inference_steps=50, generator=g, output_type="latent").images
+
+g.manual_seed(0)
+result2 = pipe(prompt=prompt, num_inference_steps=50, generator=g, output_type="latent").images
+
+print("L_inf dist = ", abs(result1 - result2).max())
+"L_inf dist = tensor(0., device='cuda:0')"
+```
diff --git a/diffusers/docs/source/ko/using-diffusers/reusing_seeds.md b/diffusers/docs/source/ko/using-diffusers/reusing_seeds.md
new file mode 100644
index 0000000000000000000000000000000000000000..9ad27c3f2ac7f3bcda29f344420efef2c7588cd9
--- /dev/null
+++ b/diffusers/docs/source/ko/using-diffusers/reusing_seeds.md
@@ -0,0 +1,63 @@
+
+
+# Deterministic(결정적) 생성을 통한 이미지 품질 개선
+
+생성된 이미지의 품질을 개선하는 일반적인 방법은 *결정적 batch(배치) 생성*을 사용하는 것입니다. 이 방법은 이미지 batch(배치)를 생성하고 두 번째 추론 라운드에서 더 자세한 프롬프트와 함께 개선할 이미지 하나를 선택하는 것입니다. 핵심은 일괄 이미지 생성을 위해 파이프라인에 [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html#generator) 목록을 전달하고, 각 `Generator`를 시드에 연결하여 이미지에 재사용할 수 있도록 하는 것입니다.
+
+예를 들어 [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5)를 사용하여 다음 프롬프트의 여러 버전을 생성해 봅시다.
+
+```py
+prompt = "Labrador in the style of Vermeer"
+```
+
+(가능하다면) 파이프라인을 [`DiffusionPipeline.from_pretrained`]로 인스턴스화하여 GPU에 배치합니다.
+
+```python
+>>> import torch
+>>> from diffusers import DiffusionPipeline
+
+>>> pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+>>> pipe = pipe.to("cuda")
+```
+
+이제 네 개의 서로 다른 `Generator`를 정의하고 각 `Generator`에 시드(`0` ~ `3`)를 할당하여 나중에 특정 이미지에 대해 `Generator`를 재사용할 수 있도록 합니다.
+
+```python
+>>> import torch
+
+>>> generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(4)]
+```
+
+이미지를 생성하고 살펴봅니다.
+
+```python
+>>> images = pipe(prompt, generator=generator, num_images_per_prompt=4).images
+>>> images
+```
+
+![img](https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/reusabe_seeds.jpg)
+
+이 예제에서는 첫 번째 이미지를 개선했지만, 실제로는 원하는 어떤 이미지든 사용할 수 있습니다(심지어 눈이 두 쌍인 이미지도!). 첫 번째 이미지는 시드가 `0`인 `Generator`로 생성했기 때문에, 두 번째 추론 라운드에서도 해당 `Generator`를 재사용합니다. 이미지의 품질을 개선하려면 프롬프트에 몇 가지 텍스트를 추가합니다:
+
+```python
+prompt = [prompt + t for t in [", highly realistic", ", artsy", ", trending", ", colorful"]]
+generator = [torch.Generator(device="cuda").manual_seed(0) for i in range(4)]
+```
+
+시드가 `0`인 제너레이터 4개를 생성하고, 이전 라운드의 첫 번째 이미지처럼 보이는 다른 이미지 batch(배치)를 생성합니다!
+
+```python
+>>> images = pipe(prompt, generator=generator).images
+>>> images
+```
+
+![img](https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/reusabe_seeds_2.jpg)
diff --git a/diffusers/docs/source/ko/using-diffusers/schedulers.md b/diffusers/docs/source/ko/using-diffusers/schedulers.md
new file mode 100644
index 0000000000000000000000000000000000000000..6a8864fbe8f35a5d265cd8992c5726911cdb0d2d
--- /dev/null
+++ b/diffusers/docs/source/ko/using-diffusers/schedulers.md
@@ -0,0 +1,329 @@
+
+
+# 스케줄러
+
+diffusion 파이프라인은 diffusion 모델, 스케줄러 등의 컴포넌트들로 구성됩니다. 그리고 파이프라인 안의 일부 컴포넌트를 다른 컴포넌트로 교체하는 식의 커스터마이징 역시 가능합니다. 이와 같은 컴포넌트 커스터마이징의 가장 대표적인 예시가 바로 [스케줄러](../api/schedulers/overview.md)를 교체하는 것입니다.
+
+
+
+스케줄러는 다음과 같이 diffusion 시스템의 전반적인 디노이징 프로세스를 정의합니다.
+
+- 디노이징 스텝을 얼마나 가져가야 할까?
+- 확률적으로(stochastic) 혹은 확정적으로(deterministic)?
+- 디노이징 된 샘플을 찾아내기 위해 어떤 알고리즘을 사용해야 할까?
+
+이러한 프로세스는 다소 난해하고, 디노이징 속도와 디노이징 퀄리티 사이의 트레이드 오프를 정의해야 하는 문제가 될 수 있습니다. 주어진 파이프라인에 어떤 스케줄러가 가장 적합한지를 정량적으로 판단하는 것은 매우 어려운 일입니다. 이로 인해 일단 해당 스케줄러를 직접 사용하여, 생성되는 이미지를 직접 눈으로 보며, 정성적으로 성능을 판단해보는 것이 추천되곤 합니다.
+
+
+
+
+
+## 파이프라인 불러오기
+
+먼저 스테이블 diffusion 파이프라인을 불러오도록 해보겠습니다. 물론 스테이블 diffusion을 사용하기 위해서는, 허깅페이스 허브에 등록된 사용자여야 하며, 관련 [라이센스](https://huggingface.co/runwayml/stable-diffusion-v1-5)에 동의해야 한다는 점을 잊지 말아주세요.
+
+*역자 주: 다만, 현재 신규로 생성한 허깅페이스 계정에 대해서는 라이센스 동의를 요구하지 않는 것으로 보입니다!*
+
+```python
+from huggingface_hub import login
+from diffusers import DiffusionPipeline
+import torch
+
+# first we need to login with our access token
+login()
+
+# Now we can download the pipeline
+pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+```
+
+다음으로, GPU로 이동합니다.
+
+```python
+pipeline.to("cuda")
+```
+
+
+
+
+
+## 스케줄러 액세스
+
+스케줄러는 언제나 파이프라인의 컴포넌트로서 존재하며, 일반적으로 파이프라인 인스턴스 내에 `scheduler`라는 이름의 속성(property)으로 정의되어 있습니다.
+
+```python
+pipeline.scheduler
+```
+
+**Output**:
+
+```
+PNDMScheduler {
+ "_class_name": "PNDMScheduler",
+ "_diffusers_version": "0.8.0.dev0",
+ "beta_end": 0.012,
+ "beta_schedule": "scaled_linear",
+ "beta_start": 0.00085,
+ "clip_sample": false,
+ "num_train_timesteps": 1000,
+ "set_alpha_to_one": false,
+ "skip_prk_steps": true,
+ "steps_offset": 1,
+ "trained_betas": null
+}
+```
+
+출력 결과를 통해, 우리는 해당 스케줄러가 [`PNDMScheduler`]의 인스턴스라는 것을 알 수 있습니다. 이제 [`PNDMScheduler`]와 다른 스케줄러들의 성능을 비교해보도록 하겠습니다. 먼저 테스트에 사용할 프롬프트를 다음과 같이 정의해보도록 하겠습니다.
+
+```python
+prompt = "A photograph of an astronaut riding a horse on Mars, high resolution, high definition."
+```
+
+다음으로 유사한 이미지 생성을 보장하기 위해서, 다음과 같이 랜덤시드를 고정해주도록 하겠습니다.
+
+```python
+generator = torch.Generator(device="cuda").manual_seed(8)
+image = pipeline(prompt, generator=generator).images[0]
+image
+```
+
+
+
+
+
+
+
+
+
+
+## 스케줄러 교체하기
+
+다음으로 파이프라인의 스케줄러를 다른 스케줄러로 교체하는 방법에 대해 알아보겠습니다. 모든 스케줄러는 [`SchedulerMixin.compatibles`]라는 속성(property)을 갖고 있습니다. 해당 속성은 **호환 가능한** 스케줄러들에 대한 정보를 담고 있습니다.
+
+```python
+pipeline.scheduler.compatibles
+```
+
+**Output**:
+
+```
+[diffusers.schedulers.scheduling_lms_discrete.LMSDiscreteScheduler,
+ diffusers.schedulers.scheduling_ddim.DDIMScheduler,
+ diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler,
+ diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler,
+ diffusers.schedulers.scheduling_pndm.PNDMScheduler,
+ diffusers.schedulers.scheduling_ddpm.DDPMScheduler,
+ diffusers.schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteScheduler]
+```
+
+호환되는 스케줄러들을 살펴보면 아래와 같습니다.
+
+- [`LMSDiscreteScheduler`],
+- [`DDIMScheduler`],
+- [`DPMSolverMultistepScheduler`],
+- [`EulerDiscreteScheduler`],
+- [`PNDMScheduler`],
+- [`DDPMScheduler`],
+- [`EulerAncestralDiscreteScheduler`].
+
+앞서 정의했던 프롬프트를 사용해서 각각의 스케줄러들을 비교해보도록 하겠습니다.
+
+먼저 파이프라인 안의 스케줄러를 바꾸기 위해 [`ConfigMixin.config`] 속성과 [`ConfigMixin.from_config`] 메서드를 활용해보려고 합니다.
+
+
+
+```python
+pipeline.scheduler.config
+```
+
+**Output**:
+
+```
+FrozenDict([('num_train_timesteps', 1000),
+ ('beta_start', 0.00085),
+ ('beta_end', 0.012),
+ ('beta_schedule', 'scaled_linear'),
+ ('trained_betas', None),
+ ('skip_prk_steps', True),
+ ('set_alpha_to_one', False),
+ ('steps_offset', 1),
+ ('_class_name', 'PNDMScheduler'),
+ ('_diffusers_version', '0.8.0.dev0'),
+ ('clip_sample', False)])
+```
+
+기존 스케줄러의 config를 호환 가능한 다른 스케줄러에 이식하는 것 역시 가능합니다.
+
+다음 예시는 기존 스케줄러(`pipeline.scheduler`)를 다른 종류의 스케줄러(`DDIMScheduler`)로 바꾸는 코드입니다. 기존 스케줄러가 갖고 있던 config를 `.from_config` 메서드의 인자로 전달하는 것을 확인할 수 있습니다.
+
+```python
+from diffusers import DDIMScheduler
+
+pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
+```
+
+
+
+이제 파이프라인을 실행해서 두 스케줄러 사이의 생성된 이미지의 퀄리티를 비교해봅시다.
+
+```python
+generator = torch.Generator(device="cuda").manual_seed(8)
+image = pipeline(prompt, generator=generator).images[0]
+image
+```
+
+
+
+
+
+
+
+
+
+
+## 스케줄러들 비교해보기
+
+지금까지는 [`PNDMScheduler`]와 [`DDIMScheduler`] 스케줄러를 실행해보았습니다. 아직 비교해볼 스케줄러들이 더 많이 남아있으니 계속 비교해보도록 하겠습니다.
+
+
+
+[`LMSDiscreteScheduler`]은 일반적으로 더 좋은 결과를 보여줍니다.
+
+```python
+from diffusers import LMSDiscreteScheduler
+
+pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
+
+generator = torch.Generator(device="cuda").manual_seed(8)
+image = pipeline(prompt, generator=generator).images[0]
+image
+```
+
+
+
+
+
+
+
+
+[`EulerDiscreteScheduler`]와 [`EulerAncestralDiscreteScheduler`]는 고작 30번의 inference step만으로도 높은 퀄리티의 이미지를 생성하는 것을 알 수 있습니다.
+
+```python
+from diffusers import EulerDiscreteScheduler
+
+pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
+
+generator = torch.Generator(device="cuda").manual_seed(8)
+image = pipeline(prompt, generator=generator, num_inference_steps=30).images[0]
+image
+```
+
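+앞 문단에서 함께 언급된 [`EulerAncestralDiscreteScheduler`]도 같은 패턴으로 교체할 수 있습니다. 아래는 동일한 설정을 가정한 간단한 스케치입니다.
+
+```python
+from diffusers import EulerAncestralDiscreteScheduler
+
+# 기존 스케줄러의 config를 그대로 이식해서 ancestral 샘플링 버전으로 교체합니다.
+pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
+
+generator = torch.Generator(device="cuda").manual_seed(8)
+image = pipeline(prompt, generator=generator, num_inference_steps=30).images[0]
+image
+```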
+
+
+
+지금 이 문서를 작성하는 현시점 기준에선, [`DPMSolverMultistepScheduler`]가 시간 대비 가장 좋은 품질의 이미지를 생성하는 것 같습니다. 20번 정도의 스텝만으로도 실행될 수 있습니다.
+
+
+
+```python
+from diffusers import DPMSolverMultistepScheduler
+
+pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+
+generator = torch.Generator(device="cuda").manual_seed(8)
+image = pipeline(prompt, generator=generator, num_inference_steps=20).images[0]
+image
+```
+
+
+
+
+
+
+
+
+보시다시피 생성된 이미지들은 매우 비슷하고, 비슷한 퀄리티를 보이는 것 같습니다. 실제로 어떤 스케줄러를 선택할 것인가는 종종 특정 이용 사례에 기반해서 결정되곤 합니다. 결국 여러 종류의 스케줄러를 직접 실행시켜보고 눈으로 직접 비교해서 판단하는 게 좋은 선택일 것 같습니다.
+
+
+
+## Flax에서 스케줄러 교체하기
+
+JAX/Flax 사용자인 경우 기본 파이프라인 스케줄러를 변경할 수도 있습니다. 다음은 Flax Stable Diffusion 파이프라인과 초고속 [DPM-Solver++ 스케줄러](../api/schedulers/multistep_dpm_solver)를 사용하여 추론을 실행하는 방법에 대한 예시입니다.
+
+```Python
+import jax
+import numpy as np
+from flax.jax_utils import replicate
+from flax.training.common_utils import shard
+
+from diffusers import FlaxStableDiffusionPipeline, FlaxDPMSolverMultistepScheduler
+
+model_id = "runwayml/stable-diffusion-v1-5"
+scheduler, scheduler_state = FlaxDPMSolverMultistepScheduler.from_pretrained(
+ model_id,
+ subfolder="scheduler"
+)
+pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
+ model_id,
+ scheduler=scheduler,
+ revision="bf16",
+ dtype=jax.numpy.bfloat16,
+)
+params["scheduler"] = scheduler_state
+
+# Generate 1 image per parallel device (8 on TPUv2-8 or TPUv3-8)
+prompt = "a photo of an astronaut riding a horse on mars"
+num_samples = jax.device_count()
+prompt_ids = pipeline.prepare_inputs([prompt] * num_samples)
+
+prng_seed = jax.random.PRNGKey(0)
+num_inference_steps = 25
+
+# shard inputs and rng
+params = replicate(params)
+prng_seed = jax.random.split(prng_seed, jax.device_count())
+prompt_ids = shard(prompt_ids)
+
+images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
+images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
+```
+
+
+
+다음 Flax 스케줄러는 *아직* Flax Stable Diffusion 파이프라인과 호환되지 않습니다.
+
+- `FlaxLMSDiscreteScheduler`
+- `FlaxDDPMScheduler`
+
+
+
diff --git a/diffusers/docs/source/ko/using-diffusers/stable_diffusion_jax_how_to.md b/diffusers/docs/source/ko/using-diffusers/stable_diffusion_jax_how_to.md
new file mode 100644
index 0000000000000000000000000000000000000000..e5785374413ce07ec02edfe420edeb3a4f82cf8f
--- /dev/null
+++ b/diffusers/docs/source/ko/using-diffusers/stable_diffusion_jax_how_to.md
@@ -0,0 +1,264 @@
+
+
+# JAX / Flax에서의 🧨 Stable Diffusion!
+
+[[open-in-colab]]
+
+🤗 Hugging Face [Diffusers](https://github.com/huggingface/diffusers)는 버전 0.5.1부터 Flax를 지원합니다! 이를 통해 Colab, Kaggle, Google Cloud Platform 등에서 사용할 수 있는 Google TPU에서 초고속 추론이 가능합니다.
+
+이 노트북은 JAX / Flax를 사용해 추론을 실행하는 방법을 보여줍니다. Stable Diffusion의 작동 방식에 대한 자세한 내용을 알고 싶거나 GPU에서 실행하려면 이 [노트북](https://huggingface.co/docs/diffusers/stable_diffusion)을 참조하세요.
+
+먼저, TPU 백엔드를 사용하고 있는지 확인합니다. Colab에서 이 노트북을 실행하는 경우, 메뉴에서 런타임을 선택하고 "런타임 유형 변경" 옵션을 선택한 뒤 하드웨어 가속기 설정에서 TPU를 선택합니다.
+
+JAX는 TPU 전용은 아니지만 각 TPU 서버에는 8개의 TPU 가속기가 병렬로 작동하기 때문에 해당 하드웨어에서 더 빛을 발한다는 점은 알아두세요.
+
+
+## Setup
+
+먼저 diffusers가 설치되어 있는지 확인합니다.
+
+```bash
+!pip install jax==0.3.25 jaxlib==0.3.25 flax transformers ftfy
+!pip install diffusers
+```
+
+```python
+import jax.tools.colab_tpu
+
+jax.tools.colab_tpu.setup_tpu()
+import jax
+```
+
+```python
+num_devices = jax.device_count()
+device_type = jax.devices()[0].device_kind
+
+print(f"Found {num_devices} JAX devices of type {device_type}.")
+assert (
+ "TPU" in device_type
+), "Available device is not a TPU, please select TPU from Edit > Notebook settings > Hardware accelerator"
+```
+
+```python out
+Found 8 JAX devices of type Cloud TPU.
+```
+
+그런 다음 모든 dependencies를 가져옵니다.
+
+```python
+import numpy as np
+import jax
+import jax.numpy as jnp
+
+from pathlib import Path
+from jax import pmap
+from flax.jax_utils import replicate
+from flax.training.common_utils import shard
+from PIL import Image
+
+from huggingface_hub import notebook_login
+from diffusers import FlaxStableDiffusionPipeline
+```
+
+## 모델 불러오기
+
+TPU 장치는 효율적인 half-float 유형인 bfloat16을 지원합니다. 테스트에는 이 유형을 사용하지만 대신 float32를 사용하여 전체 정밀도(full precision)를 사용할 수도 있습니다.
+
+```python
+dtype = jnp.bfloat16
+```
+
+Flax는 함수형 프레임워크이므로 모델은 무상태(stateless)형이며 매개변수는 모델 외부에 저장됩니다. 사전학습된 Flax 파이프라인을 불러오면 파이프라인 자체와 모델 가중치(또는 매개변수)가 모두 반환됩니다. 저희는 bf16 버전의 가중치를 사용하고 있으므로 유형 경고가 표시되지만 무시해도 됩니다.
+
+```python
+pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ revision="bf16",
+ dtype=dtype,
+)
+```
+
+## 추론
+
+TPU에는 일반적으로 8개의 디바이스가 병렬로 작동하므로 보유한 디바이스 수만큼 프롬프트를 복제합니다. 그런 다음 각각 하나의 이미지 생성을 담당하는 8개의 디바이스에서 한 번에 추론을 수행합니다. 따라서 하나의 칩이 하나의 이미지를 생성하는 데 걸리는 시간과 동일한 시간에 8개의 이미지를 얻을 수 있습니다.
+
+프롬프트를 복제하고 나면 파이프라인의 `prepare_inputs` 함수를 호출하여 토큰화된 텍스트 ID를 얻습니다. 토큰화된 텍스트의 길이는 기본 CLIP 텍스트 모델의 구성에 따라 77토큰으로 설정됩니다.
+
+```python
+prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic"
+prompt = [prompt] * jax.device_count()
+prompt_ids = pipeline.prepare_inputs(prompt)
+prompt_ids.shape
+```
+
+```python out
+(8, 77)
+```
+
+### 복사(Replication) 및 병렬화(parallelization)
+
+모델 매개변수와 입력값은 우리가 보유한 8개의 병렬 장치에 복사(Replication)되어야 합니다. 매개변수 딕셔너리는 `flax.jax_utils.replicate`(딕셔너리를 순회하며 가중치의 모양을 변경하여 8번 반복하는 함수)를 사용하여 복사됩니다. 배열은 `shard`를 사용하여 복제됩니다.
+
+```python
+p_params = replicate(params)
+```
+
+```python
+prompt_ids = shard(prompt_ids)
+prompt_ids.shape
+```
+
+```python out
+(8, 1, 77)
+```
+
+이 shape은 8개의 디바이스 각각이 shape `(1, 77)`의 jnp 배열을 입력값으로 받는다는 의미입니다. 즉 1은 디바이스당 batch(배치) 크기입니다. 메모리가 충분한 TPU에서는 한 번에 여러 이미지(칩당)를 생성하려는 경우 1보다 클 수 있습니다.
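+
+예를 들어 디바이스당 batch 크기를 2로 쓰고 싶다면 아래처럼 입력을 준비할 수 있습니다. 메모리가 충분한 TPU를 가정한 간단한 스케치이며, `prompts_16` 같은 변수 이름은 설명을 위한 것입니다.
+
+```python
+# 기존 8개 프롬프트 리스트(prompt)를 두 배로 복제해 16개를 만듭니다.
+prompts_16 = prompt * 2
+prompt_ids_16 = pipeline.prepare_inputs(prompts_16)  # shape: (16, 77)
+prompt_ids_16 = shard(prompt_ids_16)                 # shape: (8, 2, 77) → 디바이스당 batch 크기 2
+```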
+
+이미지를 생성할 준비가 거의 완료되었습니다! 이제 생성 함수에 전달할 난수 생성기만 만들면 됩니다. 이것은 난수를 다루는 모든 함수에 난수 생성기가 있어야 한다는, 난수에 대해 매우 진지하고 독단적인 Flax의 표준 절차입니다. 이렇게 하면 여러 분산된 기기에서 훈련할 때에도 재현성이 보장됩니다.
+
+아래 헬퍼 함수는 시드를 사용하여 난수 생성기를 초기화합니다. 동일한 시드를 사용하는 한 정확히 동일한 결과를 얻을 수 있습니다. 나중에 노트북에서 결과를 탐색할 때엔 다른 시드를 자유롭게 사용하세요.
+
+```python
+def create_key(seed=0):
+ return jax.random.PRNGKey(seed)
+```
+
+rng를 얻은 다음 8번 '분할'하여 각 디바이스가 다른 제너레이터를 수신하도록 합니다. 따라서 각 디바이스마다 다른 이미지가 생성되며 전체 프로세스를 재현할 수 있습니다.
+
+```python
+rng = create_key(0)
+rng = jax.random.split(rng, jax.device_count())
+```
+
+JAX 코드는 매우 빠르게 실행되는 효율적인 표현으로 컴파일할 수 있습니다. 하지만 후속 호출에서 모든 입력이 동일한 모양을 갖도록 해야 하며, 그렇지 않으면 JAX가 코드를 다시 컴파일해야 하므로 최적화된 속도를 활용할 수 없습니다.
+
+`jit = True`를 인수로 전달하면 Flax 파이프라인이 코드를 컴파일할 수 있습니다. 또한 모델이 사용 가능한 8개의 디바이스에서 병렬로 실행되도록 보장합니다.
+
+다음 셀을 처음 실행하면 컴파일하는 데 시간이 오래 걸리지만 이후 호출(입력이 다른 경우에도)은 훨씬 빨라집니다. 예를 들어, 테스트했을 때 TPU v2-8에서 컴파일하는 데 1분 이상 걸리지만 이후 추론 실행에는 약 7초가 걸립니다.
+
+```
+%%time
+images = pipeline(prompt_ids, p_params, rng, jit=True)[0]
+```
+
+```python out
+CPU times: user 56.2 s, sys: 42.5 s, total: 1min 38s
+Wall time: 1min 29s
+```
+
+반환된 배열의 shape은 `(8, 1, 512, 512, 3)`입니다. 이를 재구성하여 두 번째 차원을 제거하고 512 × 512 × 3의 이미지 8개를 얻은 다음 PIL로 변환합니다.
+
+```python
+images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
+images = pipeline.numpy_to_pil(images)
+```
+
+### 시각화
+
+이미지를 그리드에 표시하는 도우미 함수를 만들어 보겠습니다.
+
+```python
+def image_grid(imgs, rows, cols):
+ w, h = imgs[0].size
+ grid = Image.new("RGB", size=(cols * w, rows * h))
+ for i, img in enumerate(imgs):
+ grid.paste(img, box=(i % cols * w, i // cols * h))
+ return grid
+```
+
+```python
+image_grid(images, 2, 4)
+```
+
+![img](https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/stable_diffusion_jax_how_to_cell_38_output_0.jpeg)
+
+
+## 다른 프롬프트 사용
+
+모든 디바이스에서 동일한 프롬프트를 복제할 필요는 없습니다. 프롬프트 2개를 각각 4번씩 생성하거나 한 번에 8개의 서로 다른 프롬프트를 생성하는 등 원하는 것은 무엇이든 할 수 있습니다. 한번 해보세요!
+
+먼저 입력 준비 코드를 편리한 함수로 리팩터링하겠습니다:
+
+```python
+prompts = [
+ "Labrador in the style of Hokusai",
+ "Painting of a squirrel skating in New York",
+ "HAL-9000 in the style of Van Gogh",
+ "Times Square under water, with fish and a dolphin swimming around",
+ "Ancient Roman fresco showing a man working on his laptop",
+ "Close-up photograph of young black woman against urban background, high quality, bokeh",
+ "Armchair in the shape of an avocado",
+ "Clown astronaut in space, with Earth in the background",
+]
+```
+
+```python
+prompt_ids = pipeline.prepare_inputs(prompts)
+prompt_ids = shard(prompt_ids)
+
+images = pipeline(prompt_ids, p_params, rng, jit=True).images
+images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
+images = pipeline.numpy_to_pil(images)
+
+image_grid(images, 2, 4)
+```
+
+![img](https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/stable_diffusion_jax_how_to_cell_43_output_0.jpeg)
+
+
+## 병렬화(parallelization)는 어떻게 작동하는가?
+
+앞서 `diffusers` Flax 파이프라인이 모델을 자동으로 컴파일하고 사용 가능한 모든 기기에서 병렬로 실행한다고 말씀드렸습니다. 이제 그 프로세스를 간략하게 살펴보고 작동 방식을 보여드리겠습니다.
+
+JAX 병렬화는 여러 가지 방법으로 수행할 수 있습니다. 가장 쉬운 방법은 jax.pmap 함수를 사용하여 단일 프로그램, 다중 데이터(SPMD) 병렬화를 달성하는 것입니다. 즉, 동일한 코드의 복사본을 각각 다른 데이터 입력에 대해 여러 개 실행하는 것입니다. 더 정교한 접근 방식도 가능하므로 관심이 있으시다면 [JAX 문서](https://jax.readthedocs.io/en/latest/index.html)와 [`pjit` 페이지](https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html?highlight=pjit)에서 이 주제를 살펴보시기 바랍니다!
+
+`jax.pmap`은 두 가지 기능을 수행합니다:
+
+- `jax.jit()`를 호출한 것처럼 코드를 컴파일(또는 `jit`)합니다. 이 작업은 `pmap`을 호출할 때가 아니라 pmapped 함수가 처음 호출될 때 수행됩니다.
+- 컴파일된 코드가 사용 가능한 모든 기기에서 병렬로 실행되도록 합니다.
+
+작동 방식을 보여드리기 위해 이미지 생성을 실행하는 비공개 메서드인 파이프라인의 `_generate` 메서드를 `pmap`합니다. 이 메서드는 향후 `Diffusers` 릴리스에서 이름이 변경되거나 제거될 수 있다는 점에 유의하세요.
+
+```python
+p_generate = pmap(pipeline._generate)
+```
+
+`pmap`을 사용한 후 준비된 함수 `p_generate`는 개념적으로 다음을 수행합니다:
+* 각 장치에서 기본 함수 `pipeline._generate`의 복사본을 호출합니다.
+* 각 장치에 입력 인수의 다른 부분을 보냅니다. 이것이 바로 샤딩이 사용되는 이유입니다. 이 경우 `prompt_ids`의 shape은 `(8, 1, 77)`입니다. 이 배열은 8개로 분할되고 `_generate`의 각 복사본은 shape `(1, 77)`의 입력을 받게 됩니다.
+
+병렬로 호출된다는 사실을 완전히 무시하고 `_generate`를 코딩할 수 있습니다. batch(배치) 크기(이 예제에서는 `1`)와 코드에 적합한 차원만 신경 쓰면 되며, 병렬로 작동하기 위해 아무것도 변경할 필요가 없습니다.
+
+파이프라인 호출을 사용할 때와 마찬가지로, 다음 셀을 처음 실행할 때는 시간이 걸리지만 그 이후에는 훨씬 빨라집니다.
+
+```
+%%time
+images = p_generate(prompt_ids, p_params, rng)
+images = images.block_until_ready()
+images.shape
+```
+
+```python out
+CPU times: user 1min 15s, sys: 18.2 s, total: 1min 34s
+Wall time: 1min 15s
+```
+
+```python
+images.shape
+```
+
+```python out
+(8, 1, 512, 512, 3)
+```
+
+JAX는 비동기 디스패치를 사용하고 가능한 한 빨리 제어권을 Python 루프에 반환하기 때문에 추론 시간을 정확하게 측정하기 위해 `block_until_ready()`를 사용합니다. 아직 구체화되지 않은 계산 결과를 사용하려는 경우 자동으로 차단이 수행되므로 코드에서 이 함수를 사용할 필요가 없습니다.
\ No newline at end of file
diff --git a/diffusers/docs/source/ko/using-diffusers/textual_inversion_inference.md b/diffusers/docs/source/ko/using-diffusers/textual_inversion_inference.md
new file mode 100644
index 0000000000000000000000000000000000000000..1b52fee923b3dbacb16766d20d05b519a08d3516
--- /dev/null
+++ b/diffusers/docs/source/ko/using-diffusers/textual_inversion_inference.md
@@ -0,0 +1,80 @@
+# Textual inversion
+
+[[open-in-colab]]
+
+[`StableDiffusionPipeline`]은 textual-inversion을 지원하는데, 이는 몇 개의 샘플 이미지만으로 stable diffusion과 같은 모델이 새로운 컨셉을 학습할 수 있도록 하는 기법입니다. 이를 통해 생성된 이미지를 더 잘 제어하고 특정 컨셉에 맞게 모델을 조정할 수 있습니다. 커뮤니티에서 만들어진 컨셉들의 컬렉션은 [Stable Diffusion Conceptualizer](https://huggingface.co/spaces/sd-concepts-library/stable-diffusion-conceptualizer)를 통해 빠르게 사용해볼 수 있습니다.
+
+이 가이드에서는 Stable Diffusion Conceptualizer에서 사전학습한 컨셉을 사용하여 textual-inversion으로 추론을 실행하는 방법을 보여드립니다. textual-inversion으로 모델에 새로운 컨셉을 학습시키는 데 관심이 있으시다면, [Textual Inversion](./training/text_inversion) 훈련 가이드를 참조하세요.
+
+Hugging Face 계정으로 로그인하세요:
+
+```py
+from huggingface_hub import notebook_login
+
+notebook_login()
+```
+
+필요한 라이브러리를 불러오고 생성된 이미지를 시각화하기 위한 도우미 함수 `image_grid`를 만듭니다:
+
+```py
+import os
+import torch
+
+import PIL
+from PIL import Image
+
+from diffusers import StableDiffusionPipeline
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
+
+def image_grid(imgs, rows, cols):
+ assert len(imgs) == rows * cols
+
+ w, h = imgs[0].size
+ grid = Image.new("RGB", size=(cols * w, rows * h))
+ grid_w, grid_h = grid.size
+
+ for i, img in enumerate(imgs):
+ grid.paste(img, box=(i % cols * w, i // cols * h))
+ return grid
+```
+
+Stable Diffusion과 [Stable Diffusion Conceptualizer](https://huggingface.co/spaces/sd-concepts-library/stable-diffusion-conceptualizer)에서 사전학습된 컨셉을 선택합니다:
+
+```py
+pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"
+repo_id_embeds = "sd-concepts-library/cat-toy"
+```
+
+이제 파이프라인을 로드하고 사전학습된 컨셉을 파이프라인에 전달할 수 있습니다:
+
+```py
+pipeline = StableDiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16).to("cuda")
+
+pipeline.load_textual_inversion(repo_id_embeds)
+```
+
+특별한 placeholder token `<cat-toy>`를 사용하여 사전학습된 컨셉으로 프롬프트를 만들고, 생성할 샘플의 수와 이미지 행의 수를 선택합니다:
+
+```py
+prompt = "a grafitti in a favela wall with a on it"
+
+num_samples = 2
+num_rows = 2
+```
+
+그런 다음 파이프라인을 실행하고, 생성된 이미지들을 저장합니다. 그리고 처음에 만들었던 도우미 함수 `image_grid`를 사용하여 생성 결과들을 시각화합니다. 이 때 `num_inference_steps`와 `guidance_scale`과 같은 매개 변수들을 조정하여, 이것들이 이미지 품질에 어떠한 영향을 미치는지를 자유롭게 확인해보시기 바랍니다.
+
+```py
+all_images = []
+for _ in range(num_rows):
+ images = pipeline(prompt, num_images_per_prompt=num_samples, num_inference_steps=50, guidance_scale=7.5).images
+ all_images.extend(images)
+
+grid = image_grid(all_images, num_samples, num_rows)
+grid
+```
+
+
+
+
diff --git a/diffusers/docs/source/ko/using-diffusers/unconditional_image_generation.md b/diffusers/docs/source/ko/using-diffusers/unconditional_image_generation.md
new file mode 100644
index 0000000000000000000000000000000000000000..45b66bd86efec4d623b33b1c780ddd31aec143f5
--- /dev/null
+++ b/diffusers/docs/source/ko/using-diffusers/unconditional_image_generation.md
@@ -0,0 +1,60 @@
+
+
+# Unconditional 이미지 생성
+
+[[open-in-colab]]
+
+Unconditional 이미지 생성은 비교적 간단한 작업입니다. 모델이 텍스트나 이미지와 같은 추가 조건 없이 이미 학습된 학습 데이터와 유사한 이미지만 생성합니다.
+
+[`DiffusionPipeline`]은 추론을 위해 미리 학습된 diffusion 시스템을 사용하는 가장 쉬운 방법입니다.
+
+먼저 [`DiffusionPipeline`]의 인스턴스를 생성하고 다운로드할 파이프라인의 [체크포인트](https://huggingface.co/models?library=diffusers&sort=downloads)를 지정합니다. 허브의 🧨 diffusion 체크포인트 중 하나를 사용할 수 있습니다(사용할 체크포인트는 나비 이미지를 생성합니다).
+
+
+
+💡 나만의 unconditional 이미지 생성 모델을 학습시키고 싶으신가요? 학습 가이드를 살펴보고 나만의 이미지를 생성하는 방법을 알아보세요.
+
+
+
+
+이 가이드에서는 unconditional 이미지 생성에 [`DiffusionPipeline`]과 [DDPM](https://arxiv.org/abs/2006.11239)을 사용합니다:
+
+```python
+ >>> from diffusers import DiffusionPipeline
+
+ >>> generator = DiffusionPipeline.from_pretrained("anton-l/ddpm-butterflies-128")
+```
+
+[`DiffusionPipeline`]은 모든 모델링, 토큰화, 스케줄링 구성 요소를 다운로드하고 캐시합니다. 이 모델은 약 14억 개의 파라미터로 구성되어 있기 때문에 GPU에서 실행할 것을 강력히 권장합니다. PyTorch에서와 마찬가지로 제너레이터 객체를 GPU로 옮길 수 있습니다:
+
+```python
+ >>> generator.to("cuda")
+```
+
+이제 제너레이터를 사용하여 이미지를 생성할 수 있습니다:
+
+```python
+ >>> image = generator().images[0]
+```
+
+출력은 기본적으로 [PIL.Image](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class) 객체로 감싸집니다.
+
+다음을 호출하여 이미지를 저장할 수 있습니다:
+
+```python
+ >>> image.save("generated_image.png")
+```
+
+아래 스페이스(데모 링크)를 이용해 보고, 추론 단계의 매개변수를 자유롭게 조절하여 이미지 품질에 어떤 영향을 미치는지 확인해 보세요!
+
+
diff --git a/diffusers/docs/source/ko/using-diffusers/using_safetensors.md b/diffusers/docs/source/ko/using-diffusers/using_safetensors.md
new file mode 100644
index 0000000000000000000000000000000000000000..4e1c6758e13fcc1597584c6386e0105154b80e59
--- /dev/null
+++ b/diffusers/docs/source/ko/using-diffusers/using_safetensors.md
@@ -0,0 +1,67 @@
+# 세이프텐서 로드
+
+[safetensors](https://github.com/huggingface/safetensors)는 텐서를 저장하고 로드하기 위한 안전하고 빠른 파일 형식입니다. 일반적으로 PyTorch 모델 가중치는 Python의 [`pickle`](https://docs.python.org/3/library/pickle.html) 유틸리티를 사용하여 `.bin` 파일로 저장(피클, pickle)됩니다. 그러나 pickle은 안전하지 않으며, 피클된 파일에는 실행될 수 있는 악성 코드가 포함될 수 있습니다. 세이프텐서는 pickle의 안전한 대안으로, 모델 가중치를 공유하는 데 이상적입니다.
+
+이 가이드에서는 `.safetensors` 파일을 로드하는 방법과 다른 형식으로 저장된 Stable Diffusion 모델 가중치를 `.safetensors`로 변환하는 방법을 보여드리겠습니다. 시작하기 전에 safetensors가 설치되어 있는지 확인하세요:
+
+```bash
+!pip install safetensors
+```
+
+[`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main) 리포지토리를 보면 `text_encoder`, `unet` 및 `vae` 하위 폴더에 가중치가 `.safetensors` 형식으로 저장되어 있는 것을 볼 수 있습니다. 기본적으로 🤗 Diffusers는 모델 리포지토리에서 사용할 수 있는 경우 해당 하위 폴더에서 이러한 `.safetensors` 파일을 자동으로 로드합니다.
+
+보다 명시적인 제어를 위해 선택적으로 `use_safetensors=True`를 설정할 수 있습니다(`safetensors`가 설치되지 않은 경우 설치하라는 오류 메시지가 표시됩니다):
+
+```py
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_safetensors=True)
+```
+
+그러나 모델 가중치가 위의 예시처럼 반드시 별도의 하위 폴더에 저장되는 것은 아닙니다. 모든 가중치가 하나의 `.safetensors` 파일에 저장되는 경우도 있습니다. 가중치가 Stable Diffusion 가중치라면 [`~diffusers.loaders.FromCkptMixin.from_ckpt`] 메서드를 사용하여 파일을 직접 로드할 수 있습니다:
+
+```py
+from diffusers import StableDiffusionPipeline
+
+pipeline = StableDiffusionPipeline.from_ckpt(
+ "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
+)
+```
+
+## 세이프텐서로 변환
+
+허브의 모든 가중치를 `.safetensors` 형식으로 사용할 수 있는 것은 아니며, `.bin`으로 저장된 가중치가 있을 수 있습니다. 이 경우 [Convert Space](https://huggingface.co/spaces/diffusers/convert)를 사용하여 가중치를 `.safetensors`로 변환하세요. Convert Space는 피클된 가중치를 다운로드하여 변환한 후 풀 리퀘스트를 열어 새로 변환된 `.safetensors` 파일을 허브에 업로드합니다. 이렇게 하면 피클된 파일에 악성 코드가 포함되어 있더라도, 개별 컴퓨터가 아니라 안전하지 않은 파일과 의심스러운 pickle import를 탐지하는 [보안 스캐너](https://huggingface.co/docs/hub/security-pickle#hubs-security-scanner)가 있는 허브에 업로드되는 것입니다.
+
+`revision` 매개변수에 풀 리퀘스트에 대한 참조(예: `refs/pr/22`)를 지정하여 새로운 `.safetensors` 가중치가 적용된 모델을 사용할 수 있습니다(허브의 [Check PR](https://huggingface.co/spaces/diffusers/check_pr) Space에서 테스트할 수도 있습니다):
+
+```py
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", revision="refs/pr/22")
+```
+
+## 세이프텐서를 사용하는 이유는 무엇인가요?
+
+세이프텐서를 사용하는 데에는 여러 가지 이유가 있습니다:
+
+- 세이프텐서를 사용하는 가장 큰 이유는 안전입니다. 오픈 소스화와 모델 배포가 증가함에 따라 다운로드한 모델 가중치에 악성 코드가 포함되어 있지 않다고 신뢰할 수 있는 것이 중요해졌습니다. 세이프텐서의 현재 헤더 크기 제한은 매우 큰 JSON 파일의 구문 분석을 방지합니다.
+- 모델을 전환할 때의 로딩 속도는 텐서의 제로 카피(zero-copy)를 수행하는 세이프텐서를 사용해야 하는 또 다른 이유입니다. 가중치를 CPU(기본값)로 로드하는 경우 `pickle`에 비해 특히 빠르며, 가중치를 GPU로 직접 로드하는 경우에도 비슷하게 빠르거나 더 빠릅니다. 모델이 이미 로드된 경우에만 성능 차이를 체감할 수 있으며, 가중치를 다운로드하거나 모델을 처음 로드하는 경우에는 차이를 느끼지 못할 것입니다.
+
+ 전체 파이프라인을 로드하는 데 걸리는 시간입니다:
+
+ ```py
+ from diffusers import StableDiffusionPipeline
+
+ pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1")
+ "Loaded in safetensors 0:00:02.033658"
+ "Loaded in PyTorch 0:00:02.663379"
+ ```
+
+ 하지만 실제로 500MB의 모델 가중치를 로드하는 데 걸리는 시간은 얼마 되지 않습니다:
+
+ ```bash
+ safetensors: 3.4873ms
+ PyTorch: 172.7537ms
+ ```
+
+지연 로딩은 세이프텐서에서도 지원되며, 이는 분산 설정에서 일부 텐서만 로드하는 데 유용합니다. 이 형식을 사용하면 [BLOOM](https://huggingface.co/bigscience/bloom) 모델을 일반 PyTorch 가중치를 사용하여 10분이 걸리던 것을 8개의 GPU에서 45초 만에 로드할 수 있습니다.
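+
+아래는 `safetensors` 라이브러리의 `safe_open` API를 사용한 지연 로딩의 간단한 스케치입니다. 파일 경로 `model.safetensors`는 설명을 위한 가정입니다.
+
+```py
+from safetensors import safe_open
+
+# 파일 전체를 메모리에 올리지 않고, 필요한 텐서만 골라서 읽습니다.
+with safe_open("model.safetensors", framework="pt", device="cpu") as f:
+    names = list(f.keys())           # 저장된 텐서 이름 목록
+    tensor = f.get_tensor(names[0])  # 특정 텐서 하나만 로드
+    print(names[0], tensor.shape)
+```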
diff --git a/diffusers/docs/source/ko/using-diffusers/weighted_prompts.md b/diffusers/docs/source/ko/using-diffusers/weighted_prompts.md
new file mode 100644
index 0000000000000000000000000000000000000000..ce08f4949555618dbfe14b94f3964118d0fc6df3
--- /dev/null
+++ b/diffusers/docs/source/ko/using-diffusers/weighted_prompts.md
@@ -0,0 +1,115 @@
+
+
+# 프롬프트에 가중치 부여하기
+
+[[open-in-colab]]
+
+텍스트 가이드 기반의 diffusion 모델은 주어진 텍스트 프롬프트를 기반으로 이미지를 생성합니다.
+텍스트 프롬프트에는 모델이 생성해야 하는 여러 개념이 포함될 수 있으며 프롬프트의 특정 부분에 가중치를 부여하는 것이 바람직한 경우가 많습니다.
+
+Diffusion 모델은 문맥화된 텍스트 임베딩으로 diffusion 모델의 cross attention 레이어를 조절함으로써 작동합니다.
+(더 많은 정보는 [Stable Diffusion Guide](https://huggingface.co/docs/optimum-neuron/main/en/package_reference/modeling#stable-diffusion)를 참고하세요).
+따라서 프롬프트의 특정 부분을 강조하는(또는 강조하지 않는) 간단한 방법은 프롬프트의 관련 부분에 해당하는 텍스트 임베딩 벡터의 크기를 늘리거나 줄이는 것입니다.
+이것을 "프롬프트 가중치 부여(prompt weighting)"라고 하며, 커뮤니티에서 가장 많이 요청된 기능 중 하나입니다([여기](https://github.com/huggingface/diffusers/issues/2431)의 issue를 참고하세요).
+
+## Diffusers에서 프롬프트 가중치 부여하는 방법
+
+우리는 `diffusers`가 다른 프로젝트를 가능하게 하는 필수적인 기능을 제공하는 toolbox 역할을 한다고 생각합니다.
+이를 기반으로 [InvokeAI](https://github.com/invoke-ai/InvokeAI)나 [diffuzers](https://github.com/abhishekkrthakur/diffuzers) 같은 강력한 UI를 구축할 수 있습니다.
+프롬프트를 조작하는 방법을 지원하기 위해, `diffusers` 는
+[StableDiffusionPipeline](https://huggingface.co/docs/diffusers/v0.18.2/en/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline)와 같은
+많은 파이프라인에 [prompt_embeds](https://huggingface.co/docs/diffusers/v0.14.0/en/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline.__call__.prompt_embeds)
+인수를 노출시켜, "prompt-weighted"/스케일된 텍스트 임베딩을 파이프라인에 바로 전달할 수 있게 합니다.
+
+[Compel 라이브러리](https://github.com/damian0815/compel)는 프롬프트의 일부를 강조하거나 강조하지 않을 수 있는 쉬운 방법을 제공합니다.
+임베딩을 직접 준비하는 것 대신 이 방법을 사용하는 것을 강력히 추천합니다.
+
+간단한 예제를 살펴보겠습니다.
+다음과 같이 `"공을 갖고 노는 붉은색 고양이"` 이미지를 생성하고 싶습니다:
+
+```py
+import torch
+
+from diffusers import StableDiffusionPipeline, UniPCMultistepScheduler
+
+pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
+pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+
+prompt = "a red cat playing with a ball"
+
+generator = torch.Generator(device="cpu").manual_seed(33)
+
+image = pipe(prompt, generator=generator, num_inference_steps=20).images[0]
+image
+```
+
+생성된 이미지:
+
+![img](https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/compel/forest_0.png)
+
+사진에서 알 수 있듯이, "공"은 이미지에 없습니다. 이 부분을 강조해 볼까요!
+
+먼저 `compel` 라이브러리를 설치해야합니다:
+
+```
+pip install compel
+```
+
+그런 다음에는 `Compel` 오브젝트를 생성합니다:
+
+```py
+from compel import Compel
+
+compel_proc = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder)
+```
+
+이제 `"++"` 를 사용해서 "공" 을 강조해 봅시다:
+
+```py
+prompt = "a red cat playing with a ball++"
+```
+
+그리고 이 프롬프트를 파이프라인에 바로 전달하지 않고, `compel_proc` 를 사용하여 처리해야합니다:
+
+```py
+prompt_embeds = compel_proc(prompt)
+```
+
+파이프라인에 `prompt_embeds` 를 바로 전달할 수 있습니다:
+
+```py
+generator = torch.Generator(device="cpu").manual_seed(33)
+
+image = pipe(prompt_embeds=prompt_embeds, generator=generator, num_inference_steps=20).images[0]
+image
+```
+
+이제 "공"이 있는 그림을 출력할 수 있습니다!
+
+![img](https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/compel/forest_1.png)
+
+마찬가지로 `--` 접미사를 단어에 사용하여 문장의 일부를 강조하지 않을 수 있습니다. 한번 시도해 보세요!
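+
+예를 들어 아래는 `--`로 "cat"을 약화시키고 `++`로 "ball"을 강조하는 간단한 스케치입니다(결과는 시드와 모델에 따라 달라질 수 있습니다).
+
+```py
+# `-` 접미사는 해당 단어의 가중치를 줄이고, 반복하면(`--`) 더 많이 줄입니다.
+prompt = "a red cat-- playing with a ball++"
+prompt_embeds = compel_proc(prompt)
+
+generator = torch.Generator(device="cpu").manual_seed(33)
+image = pipe(prompt_embeds=prompt_embeds, generator=generator, num_inference_steps=20).images[0]
+image
+```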
+
+즐겨찾는 파이프라인에 `prompt_embeds` 입력이 없는 경우 issue를 새로 만들어주세요.
+Diffusers 팀은 최대한 대응하려고 노력합니다.
+
+Compel 1.1.6부터는 textual inversion 사용을 단순화하는 유틸리티 클래스가 추가되었습니다.
+`DiffusersTextualInversionManager`를 인스턴스화 한 후 이를 Compel init에 전달합니다:
+
+```py
+from compel import Compel, DiffusersTextualInversionManager
+
+textual_inversion_manager = DiffusersTextualInversionManager(pipe)
+compel = Compel(
+ tokenizer=pipe.tokenizer,
+ text_encoder=pipe.text_encoder,
+ textual_inversion_manager=textual_inversion_manager)
+```
+
+더 많은 정보를 얻고 싶다면 [compel](https://github.com/damian0815/compel) 라이브러리 문서를 참고하세요.
diff --git a/diffusers/docs/source/ko/using-diffusers/write_own_pipeline.md b/diffusers/docs/source/ko/using-diffusers/write_own_pipeline.md
new file mode 100644
index 0000000000000000000000000000000000000000..787c8113bf0d44a3943a55d9bd829278a3c5f71c
--- /dev/null
+++ b/diffusers/docs/source/ko/using-diffusers/write_own_pipeline.md
@@ -0,0 +1,290 @@
+
+
+# 파이프라인, 모델 및 스케줄러 이해하기
+
+[[open-in-colab]]
+
+🧨 Diffusers는 사용자 친화적이며 유연한 도구 상자로, 사용사례에 맞게 diffusion 시스템을 구축 할 수 있도록 설계되었습니다. 이 도구 상자의 핵심은 모델과 스케줄러입니다. [`DiffusionPipeline`]은 편의를 위해 이러한 구성 요소를 번들로 제공하지만, 파이프라인을 분리하고 모델과 스케줄러를 개별적으로 사용해 새로운 diffusion 시스템을 만들 수도 있습니다.
+
+이 튜토리얼에서는 기본 파이프라인부터 시작해 Stable Diffusion 파이프라인까지 진행하며 모델과 스케줄러를 사용해 추론을 위한 diffusion 시스템을 조립하는 방법을 배웁니다.
+
+## 기본 파이프라인 해체하기
+
+파이프라인은 추론을 위해 모델을 실행하는 빠르고 쉬운 방법으로, 이미지를 생성하는 데 코드가 4줄 이상 필요하지 않습니다:
+
+```py
+>>> from diffusers import DDPMPipeline
+
+>>> ddpm = DDPMPipeline.from_pretrained("google/ddpm-cat-256").to("cuda")
+>>> image = ddpm(num_inference_steps=25).images[0]
+>>> image
+```
+
+
+
+
+
+정말 쉽습니다. 그런데 파이프라인은 어떻게 이렇게 할 수 있었을까요? 파이프라인을 세분화하여 내부에서 어떤 일이 일어나고 있는지 살펴보겠습니다.
+
+위 예시에서 파이프라인에는 [`UNet2DModel`] 모델과 [`DDPMScheduler`]가 포함되어 있습니다. 파이프라인은 원하는 출력 크기의 랜덤 노이즈를 받아 모델을 여러번 통과시켜 이미지의 노이즈를 제거합니다. 각 timestep에서 모델은 *noise residual*을 예측하고 스케줄러는 이를 사용하여 노이즈가 적은 이미지를 예측합니다. 파이프라인은 지정된 추론 스텝수에 도달할 때까지 이 과정을 반복합니다.
+
+모델과 스케줄러를 별도로 사용하여 파이프라인을 다시 생성하기 위해 자체적인 노이즈 제거 프로세스를 작성해 보겠습니다.
+
+1. 모델과 스케줄러를 불러옵니다:
+
+ ```py
+ >>> from diffusers import DDPMScheduler, UNet2DModel
+
+ >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cat-256")
+ >>> model = UNet2DModel.from_pretrained("google/ddpm-cat-256").to("cuda")
+ ```
+
+2. 노이즈 제거 프로세스를 실행할 timestep 수를 설정합니다:
+
+ ```py
+ >>> scheduler.set_timesteps(50)
+ ```
+
+3. 스케줄러의 timestep을 설정하면 균등한 간격의 구성 요소를 가진 텐서가 생성됩니다.(이 예시에서는 50개) 각 요소는 모델이 이미지의 노이즈를 제거하는 시간 간격에 해당합니다. 나중에 노이즈 제거 루프를 만들 때 이 텐서를 반복하여 이미지의 노이즈를 제거합니다:
+
+ ```py
+ >>> scheduler.timesteps
+ tensor([980, 960, 940, 920, 900, 880, 860, 840, 820, 800, 780, 760, 740, 720,
+ 700, 680, 660, 640, 620, 600, 580, 560, 540, 520, 500, 480, 460, 440,
+ 420, 400, 380, 360, 340, 320, 300, 280, 260, 240, 220, 200, 180, 160,
+ 140, 120, 100, 80, 60, 40, 20, 0])
+ ```
+
+4. 원하는 출력과 같은 모양을 가진 랜덤 노이즈를 생성합니다:
+
+ ```py
+ >>> import torch
+
+ >>> sample_size = model.config.sample_size
+ >>> noise = torch.randn((1, 3, sample_size, sample_size), device="cuda")
+ ```
+
+5. 이제 timestep을 반복하는 루프를 작성합니다. 각 timestep에서 모델은 [`UNet2DModel.forward`]를 통해 noisy residual을 반환합니다. 스케줄러의 [`~DDPMScheduler.step`] 메서드는 noisy residual, timestep, 그리고 입력을 받아 이전 timestep에서 이미지를 예측합니다. 이 출력은 노이즈 제거 루프의 모델에 대한 다음 입력이 되며, `timesteps` 배열의 끝에 도달할 때까지 반복됩니다.
+
+ ```py
+ >>> input = noise
+
+ >>> for t in scheduler.timesteps:
+ ... with torch.no_grad():
+ ... noisy_residual = model(input, t).sample
+ ... previous_noisy_sample = scheduler.step(noisy_residual, t, input).prev_sample
+ ... input = previous_noisy_sample
+ ```
+
+ 이것이 전체 노이즈 제거 프로세스이며, 동일한 패턴을 사용해 모든 diffusion 시스템을 작성할 수 있습니다.
+
+6. 마지막 단계는 노이즈가 제거된 출력을 이미지로 변환하는 것입니다:
+
+ ```py
+ >>> from PIL import Image
+ >>> import numpy as np
+
+ >>> image = (input / 2 + 0.5).clamp(0, 1)
+ >>> image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
+ >>> image = Image.fromarray((image * 255).round().astype("uint8"))
+ >>> image
+ ```
+
+다음 섹션에서는 여러분의 기술을 시험해보고 좀 더 복잡한 Stable Diffusion 파이프라인을 분석해 보겠습니다. 방법은 거의 동일합니다. 필요한 구성요소들을 초기화하고 timestep수를 설정하여 `timestep` 배열을 생성합니다. 노이즈 제거 루프에서 `timestep` 배열이 사용되며, 이 배열의 각 요소에 대해 모델은 노이즈가 적은 이미지를 예측합니다. 노이즈 제거 루프는 `timestep`을 반복하고 각 timestep에서 noise residual을 출력하고 스케줄러는 이를 사용하여 이전 timestep에서 노이즈가 덜한 이미지를 예측합니다. 이 프로세스는 `timestep` 배열의 끝에 도달할 때까지 반복됩니다.
+
+한번 사용해 봅시다!
+
+## Stable Diffusion 파이프라인 해체하기
+
+Stable Diffusion 은 text-to-image *latent diffusion* 모델입니다. latent diffusion 모델이라고 불리는 이유는 실제 픽셀 공간 대신 이미지의 저차원의 표현으로 작업하기 때문이고, 메모리 효율이 더 높습니다. 인코더는 이미지를 더 작은 표현으로 압축하고, 디코더는 압축된 표현을 다시 이미지로 변환합니다. text-to-image 모델의 경우 텍스트 임베딩을 생성하기 위해 tokenizer와 인코더가 필요합니다. 이전 예제에서 이미 UNet 모델과 스케줄러가 필요하다는 것은 알고 계셨을 것입니다.
+
+보시다시피, 이것은 UNet 모델만 포함된 DDPM 파이프라인보다 더 복잡합니다. Stable Diffusion 모델에는 세 개의 개별 사전학습된 모델이 있습니다.
+
+
+
+💡 VAE, UNet 및 텍스트 인코더 모델의 작동방식에 대한 자세한 내용은 [How does Stable Diffusion work?](https://huggingface.co/blog/stable_diffusion#how-does-stable-diffusion-work) 블로그를 참조하세요.
+
+
+
+이제 Stable Diffusion 파이프라인에 필요한 구성요소들이 무엇인지 알았으니, [`~ModelMixin.from_pretrained`] 메서드를 사용해 모든 구성요소를 불러옵니다. 구성요소들은 아래 예시에서 사용하는 사전학습된 체크포인트 [`CompVis/stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4)의 별도 하위 폴더에 각각 저장되어 있습니다:
+
+```py
+>>> from PIL import Image
+>>> import torch
+>>> from transformers import CLIPTextModel, CLIPTokenizer
+>>> from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
+
+>>> vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
+>>> tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")
+>>> text_encoder = CLIPTextModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="text_encoder")
+>>> unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
+```
+
+기본 [`PNDMScheduler`] 대신, [`UniPCMultistepScheduler`]로 교체하여 다른 스케줄러를 얼마나 쉽게 연결할 수 있는지 확인합니다:
+
+```py
+>>> from diffusers import UniPCMultistepScheduler
+
+>>> scheduler = UniPCMultistepScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
+```
+
+추론 속도를 높이려면 모델을 GPU로 옮기세요. 스케줄러와 달리 모델에는 학습 가능한 가중치가 있기 때문입니다:
+
+```py
+>>> torch_device = "cuda"
+>>> vae.to(torch_device)
+>>> text_encoder.to(torch_device)
+>>> unet.to(torch_device)
+```
+
+### 텍스트 임베딩 생성하기
+
+다음 단계는 임베딩을 생성하기 위해 텍스트를 토큰화하는 것입니다. 이 텍스트는 UNet 모델에서 condition으로 사용되고 입력 프롬프트와 유사한 방향으로 diffusion 프로세스를 조정하는 데 사용됩니다.
+
+
+
+💡 `guidance_scale` 매개변수는 이미지를 생성할 때 프롬프트에 얼마나 많은 가중치를 부여할지 결정합니다.
+
+
+
+다른 프롬프트를 생성하고 싶다면 원하는 프롬프트를 자유롭게 선택하세요!
+
+```py
+>>> prompt = ["a photograph of an astronaut riding a horse"]
+>>> height = 512 # Stable Diffusion의 기본 높이
+>>> width = 512 # Stable Diffusion의 기본 너비
+>>> num_inference_steps = 25 # 노이즈 제거 스텝 수
+>>> guidance_scale = 7.5 # classifier-free guidance를 위한 scale
+>>> generator = torch.manual_seed(0) # 초기 잠재 노이즈를 생성하는 seed generator
+>>> batch_size = len(prompt)
+```
+
+텍스트를 토큰화하고 프롬프트에서 임베딩을 생성합니다:
+
+```py
+>>> text_input = tokenizer(
+... prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
+... )
+
+>>> with torch.no_grad():
+... text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
+```
+
+또한 패딩 토큰의 임베딩인 *unconditional 텍스트 임베딩*을 생성해야 합니다. 이 임베딩은 조건부 `text_embeddings`과 동일한 shape(`batch_size` 그리고 `seq_length`)을 가져야 합니다:
+
+```py
+>>> max_length = text_input.input_ids.shape[-1]
+>>> uncond_input = tokenizer([""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt")
+>>> uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
+```
+
+두번의 forward pass를 피하기 위해 conditional 임베딩과 unconditional 임베딩을 배치(batch)로 연결하겠습니다:
+
+```py
+>>> text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+```
+
+### 랜덤 노이즈 생성
+
+그다음 diffusion 프로세스의 시작점으로 초기 랜덤 노이즈를 생성합니다. 이것이 이미지의 잠재적 표현이며 점차적으로 노이즈가 제거됩니다. 이 시점에서 `latent` 이미지는 최종 이미지 크기보다 작지만 나중에 모델이 이를 512x512 이미지 크기로 변환하므로 괜찮습니다.
+
+
+
+💡 `vae` 모델에는 3개의 다운 샘플링 레이어가 있기 때문에 높이와 너비가 8로 나뉩니다. 다음을 실행하여 확인할 수 있습니다:
+
+```py
+2 ** (len(vae.config.block_out_channels) - 1) == 8
+```
+
+
+
+```py
+>>> latents = torch.randn(
+... (batch_size, unet.in_channels, height // 8, width // 8),
+... generator=generator,
+... device=torch_device,
+... )
+```
+
+### 이미지 노이즈 제거
+
+먼저 [`UniPCMultistepScheduler`]와 같은 향상된 스케줄러에 필요한 노이즈 스케일 값인 초기 노이즈 분포 *sigma* 로 입력을 스케일링 하는 것부터 시작합니다:
+
+```py
+>>> latents = latents * scheduler.init_noise_sigma
+```
+
+마지막 단계는 `latent`의 순수한 노이즈를 점진적으로 프롬프트에 설명된 이미지로 변환하는 노이즈 제거 루프를 생성하는 것입니다. 노이즈 제거 루프는 세 가지 작업을 수행해야 한다는 점을 기억하세요:
+
+1. 노이즈 제거 중에 사용할 스케줄러의 timesteps를 설정합니다.
+2. timestep을 따라 반복합니다.
+3. 각 timestep에서 UNet 모델을 호출하여 noise residual을 예측하고 스케줄러에 전달하여 이전 노이즈 샘플을 계산합니다.
+
+```py
+>>> from tqdm.auto import tqdm
+
+>>> scheduler.set_timesteps(num_inference_steps)
+
+>>> for t in tqdm(scheduler.timesteps):
+... # classifier-free guidance를 수행하는 경우 두번의 forward pass를 수행하지 않도록 latent를 확장.
+... latent_model_input = torch.cat([latents] * 2)
+
+... latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)
+
+... # noise residual 예측
+... with torch.no_grad():
+... noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+... # guidance 수행
+... noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+... noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+... # 이전 노이즈 샘플을 계산 x_t -> x_t-1
+... latents = scheduler.step(noise_pred, t, latents).prev_sample
+```
+
+### 이미지 디코딩
+
+마지막 단계는 `vae`를 이용하여 잠재 표현을 이미지로 디코딩하고 `sample`과 함께 디코딩된 출력을 얻는 것입니다:
+
+```py
+# latent를 스케일링하고 vae로 이미지 디코딩
+latents = 1 / 0.18215 * latents
+with torch.no_grad():
+ image = vae.decode(latents).sample
+```
+
+마지막으로 이미지를 `PIL.Image`로 변환하면 생성된 이미지를 확인할 수 있습니다!
+
+```py
+>>> image = (image / 2 + 0.5).clamp(0, 1)
+>>> image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
+>>> images = (image * 255).round().astype("uint8")
+>>> pil_images = [Image.fromarray(image) for image in images]
+>>> pil_images[0]
+```
+
+
+
+
+
+## 다음 단계
+
+기본 파이프라인부터 복잡한 파이프라인까지, 자신만의 diffusion 시스템을 작성하는 데 필요한 것은 노이즈 제거 루프뿐이라는 것을 알 수 있었습니다. 이 루프는 스케줄러의 timesteps를 설정하고, 이를 반복하며, UNet 모델을 호출하여 noise residual을 예측하고 스케줄러에 전달하여 이전 노이즈 샘플을 계산하는 과정을 번갈아 가며 수행해야 합니다.
+
+이것이 바로 🧨 Diffusers가 설계된 목적입니다: 모델과 스케줄러를 사용해 자신만의 diffusion 시스템을 직관적이고 쉽게 작성할 수 있도록 하기 위해서입니다.
+
+다음 단계를 자유롭게 진행하세요:
+
+* 🧨 Diffusers에 [파이프라인 구축 및 기여](using-diffusers/#contribute_pipeline)하는 방법을 알아보세요. 여러분이 어떤 아이디어를 내놓을지 기대됩니다!
+* 라이브러리에서 [기본 파이프라인](./api/pipelines/overview)을 살펴보고, 모델과 스케줄러를 별도로 사용하여 파이프라인을 처음부터 해체하고 빌드할 수 있는지 확인해 보세요.
diff --git a/diffusers/docs/source/pt/_toctree.yml b/diffusers/docs/source/pt/_toctree.yml
new file mode 100644
index 0000000000000000000000000000000000000000..c34297a4743f7d07380cab9ed8bbae64bd378e17
--- /dev/null
+++ b/diffusers/docs/source/pt/_toctree.yml
@@ -0,0 +1,8 @@
+- sections:
+ - local: index
+ title: 🧨 Diffusers
+ - local: quicktour
+ title: Tour rápido
+ - local: installation
+ title: Instalação
+ title: Primeiros passos
diff --git a/diffusers/docs/source/pt/index.md b/diffusers/docs/source/pt/index.md
new file mode 100644
index 0000000000000000000000000000000000000000..b524fa6449e9771aacde0061e14484f884493f55
--- /dev/null
+++ b/diffusers/docs/source/pt/index.md
@@ -0,0 +1,48 @@
+
+
+
+
+
+
+
+
+# Diffusers
+
+🤗 Diffusers é uma biblioteca de modelos de difusão de última geração para geração de imagens, áudio e até mesmo estruturas 3D de moléculas. Se você está procurando uma solução de geração simples ou quer treinar seu próprio modelo de difusão, 🤗 Diffusers é uma caixa de ferramentas modular que suporta ambos. Nossa biblioteca é desenhada com foco em [usabilidade em vez de desempenho](conceptual/philosophy#usability-over-performance), [simples em vez de fácil](conceptual/philosophy#simple-over-easy) e [customizável em vez de abstrações](conceptual/philosophy#tweakable-contributorfriendly-over-abstraction).
+
+A Biblioteca tem três componentes principais:
+
+- Pipelines de última geração para a geração em poucas linhas de código. Há muitos pipelines no 🤗 Diffusers, veja a tabela no pipeline [Visão geral](api/pipelines/overview) para uma lista completa de pipelines disponíveis e as tarefas que eles resolvem.
+- Intercambiáveis [agendadores de ruído](api/schedulers/overview) para balancear as compensações entre velocidade e qualidade de geração.
+- [Modelos](api/models) pré-treinados que podem ser usados como se fossem blocos de construção, e combinados com agendadores, para criar seu próprio sistema de difusão de ponta a ponta.
+
+
diff --git a/diffusers/docs/source/pt/installation.md b/diffusers/docs/source/pt/installation.md
new file mode 100644
index 0000000000000000000000000000000000000000..52ea243ee034b1a8045d40e368cc1824f42adc1d
--- /dev/null
+++ b/diffusers/docs/source/pt/installation.md
@@ -0,0 +1,156 @@
+
+
+# Instalação
+
+🤗 Diffusers é testado no Python 3.8+, PyTorch 1.7.0+, e Flax. Siga as instruções de instalação abaixo para a biblioteca de deep learning que você está utilizando:
+
+- [PyTorch](https://pytorch.org/get-started/locally/) instruções de instalação
+- [Flax](https://flax.readthedocs.io/en/latest/) instruções de instalação
+
+## Instalação com pip
+
+Recomenda-se instalar 🤗 Diffusers em um [ambiente virtual](https://docs.python.org/3/library/venv.html).
+Se você não está familiarizado com ambientes virtuais, veja o [guia](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
+Um ambiente virtual deixa mais fácil gerenciar diferentes projetos e evitar problemas de compatibilidade entre dependências.
+
+Comece criando um ambiente virtual no diretório do projeto:
+
+```bash
+python -m venv .env
+```
+
+Ative o ambiente virtual:
+
+```bash
+source .env/bin/activate
+```
+
+Recomenda-se a instalação do 🤗 Transformers porque 🤗 Diffusers depende de seus modelos:
+
+
+
+```bash
+pip install diffusers["torch"] transformers
+```
+
+
+```bash
+pip install diffusers["flax"] transformers
+```
+
+
+
+## Instalação a partir do código fonte
+
+Antes da instalação do 🤗 Diffusers a partir do código fonte, certifique-se de ter o PyTorch e o 🤗 Accelerate instalados.
+
+Para instalar o 🤗 Accelerate:
+
+```bash
+pip install accelerate
+```
+
+então instale o 🤗 Diffusers do código fonte:
+
+```bash
+pip install git+https://github.com/huggingface/diffusers
+```
+
+Esse comando instala a última versão em desenvolvimento `main` em vez da última versão estável `stable`.
+A versão `main` é útil para se manter atualizado com os últimos desenvolvimentos.
+Por exemplo, se um bug foi corrigido desde o último lançamento estável, mas uma nova versão ainda não foi publicada.
+No entanto, isso significa que a versão `main` pode não ser sempre estável.
+Nós nos esforçamos para manter a versão `main` operacional, e a maioria dos problemas geralmente são resolvidos em algumas horas ou um dia.
+Se você encontrar um problema, por favor abra uma [Issue](https://github.com/huggingface/diffusers/issues/new/choose), assim conseguimos arrumar o quanto antes!
+
+## Instalação editável
+
+Você precisará de uma instalação editável se você:
+
+- Usar a versão `main` do código fonte.
+- Contribuir para o 🤗 Diffusers e precisa testar mudanças no código.
+
+Clone o repositório e instale o 🤗 Diffusers com os seguintes comandos:
+
+```bash
+git clone https://github.com/huggingface/diffusers.git
+cd diffusers
+```
+
+
+
+```bash
+pip install -e ".[torch]"
+```
+
+
+```bash
+pip install -e ".[flax]"
+```
+
+
+
+Esses comandos irão vincular a pasta para a qual você clonou o repositório aos caminhos das suas bibliotecas Python.
+Python então irá procurar dentro da pasta que você clonou além dos caminhos normais das bibliotecas.
+Por exemplo, se o pacote python for tipicamente instalado no `~/anaconda3/envs/main/lib/python3.8/site-packages/`, o Python também irá procurar na pasta `~/diffusers/` que você clonou.
+
+
+
+Você deve manter a pasta `diffusers` se quiser continuar usando a biblioteca.
+
+
+
+Agora você pode facilmente atualizar seu clone para a última versão do 🤗 Diffusers com o seguinte comando:
+
+```bash
+cd ~/diffusers/
+git pull
+```
+
+Seu ambiente Python vai encontrar a versão `main` do 🤗 Diffusers na próxima execução.
+
+## Cache
+
+Os pesos e os arquivos dos modelos são baixados do Hub para o cache que geralmente é o seu diretório home. Você pode mudar a localização do cache especificando as variáveis de ambiente `HF_HOME` ou `HUGGINGFACE_HUB_CACHE` ou configurando o parâmetro `cache_dir` em métodos como [`~DiffusionPipeline.from_pretrained`].
+
+Arquivos em cache permitem que você rode 🤗 Diffusers offline. Para prevenir que o 🤗 Diffusers se conecte à internet, defina a variável de ambiente `HF_HUB_OFFLINE` para `True` e o 🤗 Diffusers irá apenas carregar arquivos previamente baixados em cache.
+
+```shell
+export HF_HUB_OFFLINE=True
+```
+
+Para mais detalhes de como gerenciar e limpar o cache, olhe o guia de [caching](https://huggingface.co/docs/huggingface_hub/guides/manage-cache).
+
+## Telemetria
+
+Nossa biblioteca coleta informações de telemetria durante as requisições [`~DiffusionPipeline.from_pretrained`].
+O dado coletado inclui a versão do 🤗 Diffusers e PyTorch/Flax, o modelo ou classe de pipeline requisitado,
+e o caminho para um checkpoint pré-treinado se ele estiver hospedado no Hugging Face Hub.
+Esse dado de uso nos ajuda a debugar problemas e priorizar novas funcionalidades.
+Telemetria é enviada apenas quando é carregado modelos e pipelines do Hub,
+e não é coletado se você estiver carregando arquivos locais.
+
+Nós entendemos que nem todo mundo quer compartilhar informações adicionais, e nós respeitamos a sua privacidade.
+Você pode desabilitar a coleta de telemetria definindo a variável de ambiente `DISABLE_TELEMETRY` do seu terminal:
+
+No Linux/MacOS:
+
+```bash
+export DISABLE_TELEMETRY=YES
+```
+
+No Windows:
+
+```bash
+set DISABLE_TELEMETRY=YES
+```
diff --git a/diffusers/docs/source/pt/quicktour.md b/diffusers/docs/source/pt/quicktour.md
new file mode 100644
index 0000000000000000000000000000000000000000..1fae0e45e39f8371f8a339e297fee9a1c27c04fe
--- /dev/null
+++ b/diffusers/docs/source/pt/quicktour.md
@@ -0,0 +1,314 @@
+
+
+[[open-in-colab]]
+
+# Tour rápido
+
+Modelos de difusão são treinados para remover o ruído Gaussiano aleatório passo a passo para gerar uma amostra de interesse, como uma imagem ou áudio. Isso despertou um tremendo interesse em IA generativa, e você provavelmente já viu exemplos de imagens geradas por difusão na internet. 🧨 Diffusers é uma biblioteca que visa tornar os modelos de difusão amplamente acessíveis a todos.
+
+Seja você um desenvolvedor ou um usuário, esse tour rápido irá introduzir você ao 🧨 Diffusers e ajudar você a começar a gerar rapidamente! Há três componentes principais da biblioteca para conhecer:
+
+- O [`DiffusionPipeline`] é uma classe de alto nível de ponta a ponta desenhada para gerar rapidamente amostras de modelos de difusão pré-treinados para inferência.
+- [Modelos](./api/models) pré-treinados populares e módulos que podem ser usados como blocos de construção para criar sistemas de difusão.
+- Vários [Agendadores](./api/schedulers/overview) diferentes - algoritmos que controlam como o ruído é adicionado para treinamento, e como gerar imagens sem o ruído durante a inferência.
+
+Esse tour rápido mostrará como usar o [`DiffusionPipeline`] para inferência, e então mostrará como combinar um modelo e um agendador para replicar o que está acontecendo dentro do [`DiffusionPipeline`].
+
+
+
+Esse tour rápido é uma versão simplificada da introdução 🧨 Diffusers [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/diffusers_intro.ipynb) para ajudar você a começar rápido. Se você quer aprender mais sobre o objetivo do 🧨 Diffusers, filosofia de design, e detalhes adicionais sobre a API principal, veja o notebook!
+
+
+
+Antes de começar, certifique-se de ter todas as bibliotecas necessárias instaladas:
+
+```py
+# uncomment to install the necessary libraries in Colab
+#!pip install --upgrade diffusers accelerate transformers
+```
+
+- [🤗 Accelerate](https://huggingface.co/docs/accelerate/index) acelera o carregamento do modelo para geração e treinamento.
+- [🤗 Transformers](https://huggingface.co/docs/transformers/index) é necessário para executar os modelos mais populares de difusão, como o [Stable Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview).
+
+## DiffusionPipeline
+
+O [`DiffusionPipeline`] é a forma mais fácil de usar um sistema de difusão pré-treinado para geração. É um sistema de ponta a ponta contendo o modelo e o agendador. Você pode usar o [`DiffusionPipeline`] pronto para muitas tarefas. Dê uma olhada na tabela abaixo para algumas tarefas suportadas, e para uma lista completa de tarefas suportadas, veja a tabela [Resumo do 🧨 Diffusers](./api/pipelines/overview#diffusers-summary).
+
+| **Tarefa** | **Descrição** | **Pipeline** |
+| -------------------------------------- | ------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------- |
+| Unconditional Image Generation | gera uma imagem a partir do ruído Gaussiano | [unconditional_image_generation](./using-diffusers/unconditional_image_generation) |
+| Text-Guided Image Generation | gera uma imagem a partir de um prompt de texto | [conditional_image_generation](./using-diffusers/conditional_image_generation) |
+| Text-Guided Image-to-Image Translation | adapta uma imagem guiada por um prompt de texto | [img2img](./using-diffusers/img2img) |
+| Text-Guided Image-Inpainting | preenche a parte da máscara da imagem, dado a imagem, a máscara e o prompt de texto | [inpaint](./using-diffusers/inpaint) |
+| Text-Guided Depth-to-Image Translation | adapta as partes de uma imagem guiada por um prompt de texto enquanto preserva a estrutura por estimativa de profundidade | [depth2img](./using-diffusers/depth2img) |
+
+Comece criando uma instância do [`DiffusionPipeline`] e especifique qual checkpoint do pipeline você gostaria de baixar.
+Você pode usar o [`DiffusionPipeline`] para qualquer [checkpoint](https://huggingface.co/models?library=diffusers&sort=downloads) armazenado no Hugging Face Hub.
+Neste tour rápido, você carregará o checkpoint [`stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) para geração de texto para imagem.
+
+
+
+Para os modelos de [Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion), por favor leia cuidadosamente a [licença](https://huggingface.co/spaces/CompVis/stable-diffusion-license) primeiro antes de rodar o modelo. 🧨 Diffusers implementa uma verificação de segurança: [`safety_checker`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py) para prevenir conteúdo ofensivo ou nocivo, mas as capacidades de geração de imagem aprimorada do modelo podem ainda produzir conteúdo potencialmente nocivo.
+
+
+
+Para carregar o modelo com o método [`~DiffusionPipeline.from_pretrained`]:
+
+```python
+>>> from diffusers import DiffusionPipeline
+
+>>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_safetensors=True)
+```
+
+O [`DiffusionPipeline`] baixa e armazena em cache todos os componentes de modelagem, tokenização, e agendamento. Você verá que o pipeline do Stable Diffusion é composto pelo [`UNet2DConditionModel`] e [`PNDMScheduler`] entre outras coisas:
+
+```py
+>>> pipeline
+StableDiffusionPipeline {
+ "_class_name": "StableDiffusionPipeline",
+ "_diffusers_version": "0.13.1",
+ ...,
+ "scheduler": [
+ "diffusers",
+ "PNDMScheduler"
+ ],
+ ...,
+ "unet": [
+ "diffusers",
+ "UNet2DConditionModel"
+ ],
+ "vae": [
+ "diffusers",
+ "AutoencoderKL"
+ ]
+}
+```
+
+Nós recomendamos fortemente rodar o pipeline em uma placa de vídeo, pois o modelo consiste em aproximadamente 1,4 bilhão de parâmetros.
+Você pode mover o objeto gerador para uma placa de vídeo, assim como você faria no PyTorch:
+
+```python
+>>> pipeline.to("cuda")
+```
+
+Agora você pode passar o prompt de texto para o `pipeline` para gerar uma imagem, e então acessar a imagem sem ruído. Por padrão, a saída da imagem é encapsulada em um objeto [`PIL.Image`](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class).
+
+```python
+>>> image = pipeline("An image of a squirrel in Picasso style").images[0]
+>>> image
+```
+
+
+
+
+
+Salve a imagem chamando o `save`:
+
+```python
+>>> image.save("image_of_squirrel_painting.png")
+```
+
+### Pipeline local
+
+Você também pode utilizar o pipeline localmente. A única diferença é que você precisa baixar os pesos primeiro:
+
+```bash
+!git lfs install
+!git clone https://huggingface.co/runwayml/stable-diffusion-v1-5
+```
+
+Em seguida, carregue os pesos salvos no pipeline:
+
+```python
+>>> pipeline = DiffusionPipeline.from_pretrained("./stable-diffusion-v1-5", use_safetensors=True)
+```
+
+Agora você pode rodar o pipeline como você faria na seção acima.
+
+### Troca dos agendadores
+
+Agendadores diferentes têm diferentes velocidades de remoção de ruído e compensações de qualidade. A melhor forma de descobrir qual funciona melhor para você é testá-los! Uma das principais características do 🧨 Diffusers é permitir que você troque facilmente entre agendadores. Por exemplo, para substituir o [`PNDMScheduler`] padrão pelo [`EulerDiscreteScheduler`], carregue-o com o método [`~diffusers.ConfigMixin.from_config`]:
+
+```py
+>>> from diffusers import EulerDiscreteScheduler
+
+>>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_safetensors=True)
+>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
+```
+
+Tente gerar uma imagem com o novo agendador e veja se você nota alguma diferença!
+
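+Por exemplo, um esboço mínimo (assumindo uma GPU disponível e reutilizando o `pipeline` e o prompt mostrados acima):
+
+```py
+>>> pipeline.to("cuda")
+>>> image = pipeline("An image of a squirrel in Picasso style").images[0]
+>>> image.save("image_of_squirrel_euler.png")
+```
+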
+Na próxima seção, você irá dar uma olhada mais de perto nos componentes - o modelo e o agendador - que compõem o [`DiffusionPipeline`] e aprender como usar esses componentes para gerar uma imagem de um gato.
+
+## Modelos
+
+A maioria dos modelos recebe uma amostra com ruído e, em cada _timestep_, prevê o _noise residual_ (outros modelos aprendem a prever a amostra anterior diretamente, ou a velocidade, ou a [`v-prediction`](https://github.com/huggingface/diffusers/blob/5e5ce13e2f89ac45a0066cb3f369462a3cf1d9ef/src/diffusers/schedulers/scheduling_ddim.py#L110)), que é a diferença entre uma imagem com menos ruído e a imagem de entrada. Você pode misturar e combinar modelos para criar outros sistemas de difusão.
+
+Modelos são inicializados com o método [`~ModelMixin.from_pretrained`] que também armazena em cache localmente os pesos do modelo para que seja mais rápido na próxima vez que você carregar o modelo. Para o tour rápido, você irá carregar o [`UNet2DModel`], um modelo básico de geração de imagem incondicional com um checkpoint treinado em imagens de gato:
+
+```py
+>>> from diffusers import UNet2DModel
+
+>>> repo_id = "google/ddpm-cat-256"
+>>> model = UNet2DModel.from_pretrained(repo_id, use_safetensors=True)
+```
+
+Para acessar os parâmetros do modelo, chame `model.config`:
+
+```py
+>>> model.config
+```
+
+A configuração do modelo é um dicionário 🧊 congelado 🧊, o que significa que esses parâmetros não podem ser mudados depois que o modelo é criado. Isso é intencional e garante que os parâmetros usados para definir a arquitetura do modelo no início permaneçam os mesmos, enquanto outros parâmetros ainda podem ser ajustados durante a geração.
+
+Alguns dos parâmetros mais importantes são (veja um exemplo de como inspecioná-los logo após a lista):
+
+- `sample_size`: a dimensão da altura e largura da amostra de entrada.
+- `in_channels`: o número de canais de entrada da amostra de entrada.
+- `down_block_types` e `up_block_types`: o tipo de blocos de downsampling e upsampling usados para criar a arquitetura UNet.
+- `block_out_channels`: o número de canais de saída dos blocos de downsampling; também utilizado, em ordem reversa, como o número de canais de entrada dos blocos de upsampling.
+- `layers_per_block`: o número de blocos ResNet presentes em cada bloco UNet.
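+
+Um trecho ilustrativo (assumindo o `model` carregado acima) para inspecionar alguns desses valores:
+
+```py
+>>> model.config.sample_size
+256
+>>> model.config.in_channels
+3
+```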
+
+Para usar o modelo para geração, crie uma amostra de ruído Gaussiano aleatório com o formato da imagem desejada. Ela deve ter um eixo `batch` porque o modelo pode receber múltiplos ruídos aleatórios, um eixo `channel` correspondente ao número de canais de entrada, e eixos `sample_size` para a altura e a largura da imagem:
+
+```py
+>>> import torch
+
+>>> torch.manual_seed(0)
+
+>>> noisy_sample = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
+>>> noisy_sample.shape
+torch.Size([1, 3, 256, 256])
+```
+
+Para geração, passe a imagem com ruído para o modelo e um `timestep`. O `timestep` indica o quão ruidosa a imagem de entrada é, com mais ruído no início e menos no final. Isso ajuda o modelo a determinar sua posição no processo de difusão, se está mais perto do início ou do final. Use o método `sample` para obter a saída do modelo:
+
+```py
+>>> with torch.no_grad():
+...     noisy_residual = model(sample=noisy_sample, timestep=2).sample
+```
+
+Para geração de exemplos reais, você precisará de um agendador para guiar o processo de retirada do ruído. Na próxima seção, você irá aprender como acoplar um modelo com um agendador.
+
+## Agendadores
+
+Agendadores gerenciam a transformação de uma amostra ruidosa em uma amostra menos ruidosa dada a saída do modelo - nesse caso, o `noisy_residual`.
+
+
+
+🧨 Diffusers é uma caixa de ferramentas para construir sistemas de difusão. Enquanto o [`DiffusionPipeline`] é uma forma conveniente de começar com um sistema de difusão pré-construído, você também pode escolher seus próprios modelos e agendadores separadamente para construir um sistema de difusão personalizado.
+
+
+
+Para o tour rápido, você irá instanciar o [`DDPMScheduler`] com o método [`~diffusers.ConfigMixin.from_config`]:
+
+```py
+>>> from diffusers import DDPMScheduler
+
+>>> scheduler = DDPMScheduler.from_config(repo_id)
+>>> scheduler
+DDPMScheduler {
+ "_class_name": "DDPMScheduler",
+ "_diffusers_version": "0.13.1",
+ "beta_end": 0.02,
+ "beta_schedule": "linear",
+ "beta_start": 0.0001,
+ "clip_sample": true,
+ "clip_sample_range": 1.0,
+ "num_train_timesteps": 1000,
+ "prediction_type": "epsilon",
+ "trained_betas": null,
+ "variance_type": "fixed_small"
+}
+```
+
+
+
+💡 Perceba como o agendador é instanciado de uma configuração. Diferentemente de um modelo, um agendador não tem pesos treináveis e é livre de parâmetros!
+
+
+
+Alguns dos parâmetros mais importantes são (veja um exemplo logo após a lista):
+
+- `num_train_timesteps`: o comprimento do processo de remoção de ruído ou, em outras palavras, o número de _timesteps_ necessários para transformar ruído Gaussiano aleatório em uma amostra de dados.
+- `beta_schedule`: o tipo de agendamento de ruído a ser usado para geração e treinamento.
+- `beta_start` e `beta_end`: os valores inicial e final de ruído para o agendamento de ruído.
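+
+Um trecho ilustrativo (assumindo o `scheduler` instanciado acima) para conferir alguns desses valores:
+
+```py
+>>> scheduler.config.num_train_timesteps
+1000
+>>> scheduler.config.beta_schedule
+'linear'
+```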
+
+Para predizer uma imagem com um pouco menos de ruído, passe a saída do modelo, o `timestep` e a amostra atual (`sample`) para o método [`~diffusers.DDPMScheduler.step`] do agendador:
+
+```py
+>>> less_noisy_sample = scheduler.step(model_output=noisy_residual, timestep=2, sample=noisy_sample).prev_sample
+>>> less_noisy_sample.shape
+```
+
+O `less_noisy_sample` pode ser passado para o próximo `timestep` onde ele ficará ainda com menos ruído! Vamos juntar tudo agora e visualizar o processo inteiro de retirada de ruído.
+
+Comece criando uma função que faça o pós-processamento e mostre a imagem sem ruído como uma `PIL.Image`:
+
+```py
+>>> import PIL.Image
+>>> import numpy as np
+
+
+>>> def display_sample(sample, i):
+...     image_processed = sample.cpu().permute(0, 2, 3, 1)
+...     image_processed = (image_processed + 1.0) * 127.5
+...     image_processed = image_processed.numpy().astype(np.uint8)
+
+...     image_pil = PIL.Image.fromarray(image_processed[0])
+...     display(f"Image at step {i}")
+...     display(image_pil)
+```
+
+Para acelerar o processo de retirada de ruído, mova a entrada e o modelo para uma GPU:
+
+```py
+>>> model.to("cuda")
+>>> noisy_sample = noisy_sample.to("cuda")
+```
+
+Agora, crie um loop de remoção de ruído que prediz o residual de ruído da amostra atual e computa a amostra menos ruidosa com o agendador:
+
+```py
+>>> import tqdm
+
+>>> sample = noisy_sample
+
+>>> for i, t in enumerate(tqdm.tqdm(scheduler.timesteps)):
+...     # 1. predict noise residual
+...     with torch.no_grad():
+...         residual = model(sample, t).sample
+
+...     # 2. compute less noisy image and set x_t -> x_t-1
+...     sample = scheduler.step(residual, t, sample).prev_sample
+
+...     # 3. optionally look at image
+...     if (i + 1) % 50 == 0:
+...         display_sample(sample, i + 1)
+```
+
+Sente-se e assista ao gato ser gerado do nada além de ruído! 😻
+
+
+
+
+
+## Próximos passos
+
+Esperamos que você tenha gerado algumas imagens legais com o 🧨 Diffusers neste tour rápido! Como próximos passos, você pode:
+
+- Treinar ou fazer o ajuste fino de um modelo para gerar suas próprias imagens no tutorial de [treinamento](./tutorials/basic_training).
+- Ver exemplos oficiais e da comunidade de [scripts de treinamento ou ajuste fino](https://github.com/huggingface/diffusers/tree/main/examples#-diffusers-examples) para os mais variados casos de uso.
+- Aprender como carregar, acessar, mudar e comparar agendadores no guia [Usando diferentes agendadores](./using-diffusers/schedulers).
+- Explorar engenharia de prompt, otimizações de velocidade e memória, e dicas e truques para gerar imagens de maior qualidade com o guia [Stable Diffusion](./stable_diffusion).
+- Aprofundar-se em como acelerar o 🧨 Diffusers com guias sobre [PyTorch otimizado em uma GPU](./optimization/fp16), e guias de inferência para rodar [Stable Diffusion em Apple Silicon (M1/M2)](./optimization/mps) e [ONNX Runtime](./optimization/onnx).
diff --git a/diffusers/docs/source/zh/_toctree.yml b/diffusers/docs/source/zh/_toctree.yml
new file mode 100644
index 0000000000000000000000000000000000000000..41d5e95a42305f9562926fc3e4e9a28337f2a176
--- /dev/null
+++ b/diffusers/docs/source/zh/_toctree.yml
@@ -0,0 +1,10 @@
+- sections:
+ - local: index
+ title: 🧨 Diffusers
+ - local: quicktour
+ title: 快速入门
+ - local: stable_diffusion
+ title: 有效和高效的扩散
+ - local: installation
+ title: 安装
+ title: 开始
diff --git a/diffusers/docs/source/zh/index.md b/diffusers/docs/source/zh/index.md
new file mode 100644
index 0000000000000000000000000000000000000000..e1a2a3971d87ce823e4668662d65c2b55602b87f
--- /dev/null
+++ b/diffusers/docs/source/zh/index.md
@@ -0,0 +1,101 @@
+
+
+
+
+非常令人印象深刻！让我们调整一下第二张图像 - 把 `Generator` 的种子设置为 `1` - 并添加一些关于年龄的主题文本：
+
+```python
+prompts = [
+ "portrait photo of the oldest warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
+ "portrait photo of a old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
+ "portrait photo of a warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
+ "portrait photo of a young warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
+]
+
+generator = [torch.Generator("cuda").manual_seed(1) for _ in range(len(prompts))]
+images = pipeline(prompt=prompts, generator=generator, num_inference_steps=25).images
+make_image_grid(images, 2, 2)
+```
+
+
+
+
+
+## 最后
+
+在本教程中，你学习了如何优化 [`DiffusionPipeline`] 以提高计算和内存效率，以及提高生成输出的质量。如果你有兴趣让你的 pipeline 更快，可以看一看以下资源：
+
+- 学习 [PyTorch 2.0](./optimization/torch2.0) 和 [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html)，可以让推理速度提高 5 - 300%。在 A100 GPU 上，推理速度可以提高 50%！（见列表下方的示例）
+- 如果你没法使用 PyTorch 2，我们建议你安装 [xFormers](./optimization/xformers)。它的内存高效注意力机制（*memory-efficient attention mechanism*）与 PyTorch 1.13.1 配合使用，速度更快，内存消耗更少。
+- 其他优化技术，例如模型卸载（*model offloading*），请参阅[这份指南](./optimization/fp16)。
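+
+作为参考，下面是一个最小示例（仅作演示，检查点以 `runwayml/stable-diffusion-v1-5` 为例），展示如何用 `torch.compile` 编译 pipeline 的 UNet 来加速推理：
+
+```python
+import torch
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
+).to("cuda")
+
+# 编译 UNet：首次调用会有编译开销，之后的推理会更快
+pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
+
+image = pipeline("portrait photo of a old warrior chief").images[0]
+```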
diff --git a/diffusers/examples/README.md b/diffusers/examples/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f0d8a6bb57f0f94a633cc05b578b2f316ae7155b
--- /dev/null
+++ b/diffusers/examples/README.md
@@ -0,0 +1,72 @@
+
+
+# 🧨 Diffusers Examples
+
+Diffusers examples are a collection of scripts to demonstrate how to effectively use the `diffusers` library
+for a variety of use cases involving training or fine-tuning.
+
+**Note**: If you are looking for **official** examples on how to use `diffusers` for inference,
+please have a look at [src/diffusers/pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines).
+
+Our examples aspire to be **self-contained**, **easy-to-tweak**, **beginner-friendly** and for **one-purpose-only**.
+More specifically, this means:
+
+- **Self-contained**: An example script shall only depend on "pip-install-able" Python packages that can be found in a `requirements.txt` file. Example scripts shall **not** depend on any local files. This means that one can simply download an example script, *e.g.* [train_unconditional.py](https://github.com/huggingface/diffusers/blob/main/examples/unconditional_image_generation/train_unconditional.py), install the required dependencies, *e.g.* [requirements.txt](https://github.com/huggingface/diffusers/blob/main/examples/unconditional_image_generation/requirements.txt) and execute the example script.
+- **Easy-to-tweak**: While we strive to present as many use cases as possible, the example scripts are just that - examples. It is expected that they won't work out-of-the box on your specific problem and that you will be required to change a few lines of code to adapt them to your needs. To help you with that, most of the examples fully expose the preprocessing of the data and the training loop to allow you to tweak and edit them as required.
+- **Beginner-friendly**: We do not aim for providing state-of-the-art training scripts for the newest models, but rather examples that can be used as a way to better understand diffusion models and how to use them with the `diffusers` library. We often purposefully leave out certain state-of-the-art methods if we consider them too complex for beginners.
+- **One-purpose-only**: Examples should show one task and one task only. Even if a task is from a modeling
+point of view very similar, *e.g.* image super-resolution and image modification tend to use the same model and training method, we want examples to showcase only one task to keep them as readable and easy-to-understand as possible.
+
+We provide **official** examples that cover the most popular tasks of diffusion models.
+*Official* examples are **actively** maintained by the `diffusers` maintainers and we try to rigorously follow our example philosophy as defined above.
+If you feel like another important example should exist, we are more than happy to welcome a [Feature Request](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feature_request.md&title=) or directly a [Pull Request](https://github.com/huggingface/diffusers/compare) from you!
+
+Training examples show how to pretrain or fine-tune diffusion models for a variety of tasks. Currently we support:
+
+| Task | 🤗 Accelerate | 🤗 Datasets | Colab
+|---|---|:---:|:---:|
+| [**Unconditional Image Generation**](./unconditional_image_generation) | ✅ | ✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
+| [**Text-to-Image fine-tuning**](./text_to_image) | ✅ | ✅ |
+| [**Textual Inversion**](./textual_inversion) | ✅ | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb)
+| [**Dreambooth**](./dreambooth) | ✅ | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_dreambooth_training.ipynb)
+| [**ControlNet**](./controlnet) | ✅ | ✅ | -
+| [**InstructPix2Pix**](./instruct_pix2pix) | ✅ | ✅ | -
+| [**Reinforcement Learning for Control**](https://github.com/huggingface/diffusers/blob/main/examples/reinforcement_learning/run_diffusers_locomotion.py) | - | - | coming soon.
+
+## Community
+
+In addition, we provide **community** examples, which are examples added and maintained by our community.
+Community examples can consist of both *training* examples or *inference* pipelines.
+For such examples, we are more lenient regarding the philosophy defined above and also cannot guarantee to provide maintenance for every issue.
+Examples that are useful for the community, but are either not yet deemed popular or not yet following our above philosophy should go into the [community examples](https://github.com/huggingface/diffusers/tree/main/examples/community) folder. The community folder therefore includes training examples and inference pipelines.
+**Note**: Community examples can be a [great first contribution](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22) to show to the community how you like to use `diffusers` 🪄.
+
+## Research Projects
+
+We also provide **research_projects** examples that are maintained by the community as defined in the respective research project folders. These examples are useful and offer the extended capabilities which are complementary to the official examples. You may refer to [research_projects](https://github.com/huggingface/diffusers/tree/main/examples/research_projects) for details.
+
+## Important note
+
+To make sure you can successfully run the latest versions of the example scripts, you have to **install the library from source** and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+Then cd into the example folder of your choice and run
+```bash
+pip install -r requirements.txt
+```
diff --git a/diffusers/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/diffusers/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
new file mode 100644
index 0000000000000000000000000000000000000000..f032634a11f0e03b0500facbaa8716663e5780d0
--- /dev/null
+++ b/diffusers/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
@@ -0,0 +1,1968 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import gc
+import hashlib
+import itertools
+import logging
+import math
+import os
+import shutil
+import warnings
+from pathlib import Path
+from typing import List, Optional
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+# imports of the TokenEmbeddingsHandler class
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from safetensors.torch import save_file
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ DPMSolverMultistepScheduler,
+ StableDiffusionXLPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.loaders import LoraLoaderMixin
+from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import compute_snr, unet_lora_state_dict
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.24.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def save_model_card(
+ repo_id: str,
+ images=None,
+ base_model=str,
+ train_text_encoder=False,
+ instance_prompt=str,
+ validation_prompt=str,
+ repo_folder=None,
+ vae_path=None,
+):
+ img_str = "widget:\n" if images else ""
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ img_str += f"""
+ - text: '{validation_prompt if validation_prompt else ' ' }'
+ output:
+ url: >-
+ "image_{i}.png"
+ """
+
+ yaml = f"""
+---
+tags:
+- stable-diffusion-xl
+- stable-diffusion-xl-diffusers
+- text-to-image
+- diffusers
+- lora
+- template:sd-lora
+widget:
+{img_str}
+---
+base_model: {base_model}
+instance_prompt: {instance_prompt}
+license: openrail++
+---
+ """
+
+ model_card = f"""
+# SDXL LoRA DreamBooth - {repo_id}
+
+
+
+## Model description
+
+These are {repo_id} LoRA adaption weights for {base_model}.
+The weights were trained using [DreamBooth](https://dreambooth.github.io/).
+LoRA for the text encoder was enabled: {train_text_encoder}.
+Special VAE used for training: {vae_path}.
+
+## Trigger words
+
+You should use {instance_prompt} to trigger the image generation.
+
+## Download model
+
+Weights for this model are available in Safetensors format.
+
+[Download]({repo_id}/tree/main) them in the Files & versions tab.
+
+"""
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
+ f.write(yaml + model_card)
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ help=("A folder containing the training data. "),
+ )
+
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+
+ parser.add_argument(
+ "--image_column",
+ type=str,
+ default="image",
+ help="The column of the dataset containing the target image. By "
+ "default, the standard Image Dataset maps out 'file_name' "
+ "to 'image'.",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default=None,
+ help="The column of the dataset containing the instance prompt for each image",
+ )
+
+ parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
+
+ parser.add_argument(
+ "--class_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of class images.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default=None,
+ required=True,
+ help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
+ )
+ parser.add_argument(
+ "--token_abstraction",
+ default="TOK",
+ help="identifier specifying the instance(or instances) as used in instance_prompt, validation prompt, "
+ "captions - e.g. TOK",
+ )
+
+ parser.add_argument(
+ "--num_new_tokens_per_abstraction",
+ default=2,
+ help="number of new tokens inserted to the tokenizers per token_abstraction value when "
+ "--train_text_encoder_ti = True. By default, each --token_abstraction (e.g. TOK) is mapped to 2 new "
+ "tokens - ",
+ )
+
+ parser.add_argument(
+ "--class_prompt",
+ type=str,
+ default=None,
+ help="The prompt to specify images in the same class as provided instance images.",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=50,
+ help=(
+ "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--with_prior_preservation",
+ default=False,
+ action="store_true",
+ help="Flag to add prior preservation loss.",
+ )
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+ parser.add_argument(
+ "--num_class_images",
+ type=int,
+ default=100,
+ help=(
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
+ " class_data_dir, additional images will be sampled with class_prompt."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="lora-dreambooth-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=1024,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--crops_coords_top_left_h",
+ type=int,
+ default=0,
+ help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
+ )
+ parser.add_argument(
+ "--crops_coords_top_left_w",
+ type=int,
+ default=0,
+ help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--train_text_encoder",
+ action="store_true",
+ help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+
+ parser.add_argument(
+ "--text_encoder_lr",
+ type=float,
+ default=5e-6,
+ help="Text encoder learning rate to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+
+ parser.add_argument(
+ "--train_text_encoder_ti",
+ action="store_true",
+ help=("Whether to use textual inversion"),
+ )
+
+ parser.add_argument(
+ "--train_text_encoder_ti_frac",
+ type=float,
+ default=0.5,
+ help=("The percentage of epochs to perform textual inversion"),
+ )
+
+ parser.add_argument(
+ "--train_text_encoder_frac",
+ type=float,
+ default=0.5,
+ help=("The percentage of epochs to perform text encoder tuning"),
+ )
+
+ parser.add_argument(
+ "--optimizer",
+ type=str,
+ default="adamW",
+ help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
+ )
+
+ parser.add_argument(
+ "--use_8bit_adam",
+ action="store_true",
+ help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
+ )
+
+ parser.add_argument(
+ "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--prodigy_beta3",
+ type=float,
+ default=None,
+ help="coefficients for computing the Prodidy stepsize using running averages. If set to None, "
+ "uses the value of square root of beta2. Ignored if optimizer is adamW",
+ )
+ parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
+ parser.add_argument(
+ "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
+ )
+
+ parser.add_argument(
+ "--adam_epsilon",
+ type=float,
+ default=1e-08,
+ help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
+ )
+
+ parser.add_argument(
+ "--prodigy_use_bias_correction",
+ type=bool,
+ default=True,
+ help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW",
+ )
+ parser.add_argument(
+ "--prodigy_safeguard_warmup",
+ type=bool,
+ default=True,
+ help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
+ "Ignored if optimizer is adamW",
+ )
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--prior_generation_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp32", "fp16", "bf16"],
+ help=(
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if args.dataset_name is None and args.instance_data_dir is None:
+ raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`")
+
+ if args.dataset_name is not None and args.instance_data_dir is not None:
+ raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`")
+
+ if args.train_text_encoder and args.train_text_encoder_ti:
+ raise ValueError(
+ "Specify only one of `--train_text_encoder` or `--train_text_encoder_ti. "
+ "For full LoRA text encoder training check --train_text_encoder, for textual "
+ "inversion training check `--train_text_encoder_ti`"
+ )
+
+ if args.train_text_encoder_ti:
+ if isinstance(args.token_abstraction, str):
+ args.token_abstraction = [args.token_abstraction]
+ elif isinstance(args.token_abstraction, List):
+ args.token_abstraction = args.token_abstraction
+ else:
+ raise ValueError(
+ f"Unsupported type for --args.token_abstraction: {type(args.token_abstraction)}. "
+ f"Supported types are: str (for a single instance identifier) or List[str] (for multiple concepts)"
+ )
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.with_prior_preservation:
+ if args.class_data_dir is None:
+ raise ValueError("You must specify a data directory for class images.")
+ if args.class_prompt is None:
+ raise ValueError("You must specify prompt for class images.")
+ else:
+ # logger is not available yet
+ if args.class_data_dir is not None:
+ warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
+ if args.class_prompt is not None:
+ warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
+
+ return args
+
+
+# Taken from https://github.com/replicate/cog-sdxl/blob/main/dataset_and_utils.py
+class TokenEmbeddingsHandler:
+ def __init__(self, text_encoders, tokenizers):
+ self.text_encoders = text_encoders
+ self.tokenizers = tokenizers
+
+ self.train_ids: Optional[torch.Tensor] = None
+ self.inserting_toks: Optional[List[str]] = None
+ self.embeddings_settings = {}
+
+ def initialize_new_tokens(self, inserting_toks: List[str]):
+ idx = 0
+ for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
+ assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings."
+ assert all(
+ isinstance(tok, str) for tok in inserting_toks
+ ), "All elements in inserting_toks should be strings."
+
+ self.inserting_toks = inserting_toks
+ special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
+ tokenizer.add_special_tokens(special_tokens_dict)
+ text_encoder.resize_token_embeddings(len(tokenizer))
+
+ self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)
+
+ # random initialization of new tokens
+ std_token_embedding = text_encoder.text_model.embeddings.token_embedding.weight.data.std()
+
+ print(f"{idx} text encodedr's std_token_embedding: {std_token_embedding}")
+
+ text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] = (
+ torch.randn(len(self.train_ids), text_encoder.text_model.config.hidden_size)
+ .to(device=self.device)
+ .to(dtype=self.dtype)
+ * std_token_embedding
+ )
+ self.embeddings_settings[
+ f"original_embeddings_{idx}"
+ ] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
+ self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
+
+ inu = torch.ones((len(tokenizer),), dtype=torch.bool)
+ inu[self.train_ids] = False
+
+ self.embeddings_settings[f"index_no_updates_{idx}"] = inu
+
+ print(self.embeddings_settings[f"index_no_updates_{idx}"].shape)
+
+ idx += 1
+
+ def save_embeddings(self, file_path: str):
+ assert self.train_ids is not None, "Initialize new tokens before saving embeddings."
+ tensors = {}
+ for idx, text_encoder in enumerate(self.text_encoders):
+ assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len(
+ self.tokenizers[0]
+ ), "Tokenizers should be the same."
+ new_token_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids]
+ tensors[f"text_encoders_{idx}"] = new_token_embeddings
+
+ save_file(tensors, file_path)
+
+ @property
+ def dtype(self):
+ return self.text_encoders[0].dtype
+
+ @property
+ def device(self):
+ return self.text_encoders[0].device
+
+ # def _load_embeddings(self, loaded_embeddings, tokenizer, text_encoder):
+ # # Assuming new tokens are of the format
+ # self.inserting_toks = [f"" for i in range(loaded_embeddings.shape[0])]
+ # special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
+ # tokenizer.add_special_tokens(special_tokens_dict)
+ # text_encoder.resize_token_embeddings(len(tokenizer))
+ #
+ # self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)
+ # assert self.train_ids is not None, "New tokens could not be converted to IDs."
+ # text_encoder.text_model.embeddings.token_embedding.weight.data[
+ # self.train_ids
+ # ] = loaded_embeddings.to(device=self.device).to(dtype=self.dtype)
+
+ @torch.no_grad()
+ def retract_embeddings(self):
+ for idx, text_encoder in enumerate(self.text_encoders):
+ index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
+ text_encoder.text_model.embeddings.token_embedding.weight.data[index_no_updates] = (
+ self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]
+ .to(device=text_encoder.device)
+ .to(dtype=text_encoder.dtype)
+ )
+
+ # for the parts that were updated, we need to normalize them
+ # to have the same std as before
+ std_token_embedding = self.embeddings_settings[f"std_token_embedding_{idx}"]
+
+ index_updates = ~index_no_updates
+ new_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates]
+ off_ratio = std_token_embedding / new_embeddings.std()
+
+ new_embeddings = new_embeddings * (off_ratio**0.1)
+ text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] = new_embeddings
+
+ # def load_embeddings(self, file_path: str):
+ # with safe_open(file_path, framework="pt", device=self.device.type) as f:
+ # for idx in range(len(self.text_encoders)):
+ # text_encoder = self.text_encoders[idx]
+ # tokenizer = self.tokenizers[idx]
+ #
+ # loaded_embeddings = f.get_tensor(f"text_encoders_{idx}")
+ # self._load_embeddings(loaded_embeddings, tokenizer, text_encoder)
+
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+ It pre-processes the images.
+ """
+
+ def __init__(
+ self,
+ instance_data_root,
+ instance_prompt,
+ class_prompt,
+ class_data_root=None,
+ class_num=None,
+ token_abstraction_dict=None, # token mapping for textual inversion
+ size=1024,
+ repeats=1,
+ center_crop=False,
+ ):
+ self.size = size
+ self.center_crop = center_crop
+
+ self.instance_prompt = instance_prompt
+ self.custom_instance_prompts = None
+ self.class_prompt = class_prompt
+ self.token_abstraction_dict = token_abstraction_dict
+
+ # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
+ # we load the training data using load_dataset
+ if args.dataset_name is not None:
+ try:
+ from datasets import load_dataset
+ except ImportError:
+ raise ImportError(
+ "You are trying to load your data using the datasets library. If you wish to train using custom "
+ "captions please install the datasets library: `pip install datasets`. If you wish to load a "
+ "local folder containing images only, specify --instance_data_dir instead."
+ )
+ # Downloading and loading a dataset from the hub.
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ # Preprocessing the datasets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ if args.image_column is None:
+ image_column = column_names[0]
+ logger.info(f"image column defaulting to {image_column}")
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ instance_images = dataset["train"][image_column]
+
+ if args.caption_column is None:
+ logger.info(
+ "No caption column provided, defaulting to instance_prompt for all images. If your dataset "
+ "contains captions/prompts for the images, make sure to specify the "
+ "column as --caption_column"
+ )
+ self.custom_instance_prompts = None
+ else:
+ if args.caption_column not in column_names:
+ raise ValueError(
+ f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ custom_instance_prompts = dataset["train"][args.caption_column]
+ # create final list of captions according to --repeats
+ self.custom_instance_prompts = []
+ for caption in custom_instance_prompts:
+ self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))
+ else:
+ self.instance_data_root = Path(instance_data_root)
+ if not self.instance_data_root.exists():
+ raise ValueError("Instance images root doesn't exists.")
+
+ instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]
+ self.custom_instance_prompts = None
+
+ self.instance_images = []
+ for img in instance_images:
+ self.instance_images.extend(itertools.repeat(img, repeats))
+ self.num_instance_images = len(self.instance_images)
+ self._length = self.num_instance_images
+
+ if class_data_root is not None:
+ self.class_data_root = Path(class_data_root)
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
+ self.class_images_path = list(self.class_data_root.iterdir())
+ if class_num is not None:
+ self.num_class_images = min(len(self.class_images_path), class_num)
+ else:
+ self.num_class_images = len(self.class_images_path)
+ self._length = max(self.num_class_images, self.num_instance_images)
+ else:
+ self.class_data_root = None
+
+ self.image_transforms = transforms.Compose(
+ [
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image = self.instance_images[index % self.num_instance_images]
+ instance_image = exif_transpose(instance_image)
+
+ if not instance_image.mode == "RGB":
+ instance_image = instance_image.convert("RGB")
+ example["instance_images"] = self.image_transforms(instance_image)
+
+ if self.custom_instance_prompts:
+ caption = self.custom_instance_prompts[index % self.num_instance_images]
+ if caption:
+ if args.train_text_encoder_ti:
+ # replace instances of --token_abstraction in caption with the new tokens: "" etc.
+ for token_abs, token_replacement in self.token_abstraction_dict.items():
+ caption = caption.replace(token_abs, "".join(token_replacement))
+ example["instance_prompt"] = caption
+ else:
+ example["instance_prompt"] = self.instance_prompt
+
+ else: # custom prompts were provided, but length does not match size of image dataset
+ example["instance_prompt"] = self.instance_prompt
+
+ if self.class_data_root:
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
+ class_image = exif_transpose(class_image)
+
+ if not class_image.mode == "RGB":
+ class_image = class_image.convert("RGB")
+ example["class_images"] = self.image_transforms(class_image)
+ example["class_prompt"] = self.class_prompt
+
+ return example
+
+
+def collate_fn(examples, with_prior_preservation=False):
+ pixel_values = [example["instance_images"] for example in examples]
+ prompts = [example["instance_prompt"] for example in examples]
+
+ # Concat class and instance examples for prior preservation.
+ # We do this to avoid doing two forward passes.
+ if with_prior_preservation:
+ pixel_values += [example["class_images"] for example in examples]
+ prompts += [example["class_prompt"] for example in examples]
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ batch = {"pixel_values": pixel_values, "prompts": prompts}
+ return batch
+
+
+class PromptDataset(Dataset):
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
+
+ def __init__(self, prompt, num_samples):
+ self.prompt = prompt
+ self.num_samples = num_samples
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ example = {}
+ example["prompt"] = self.prompt
+ example["index"] = index
+ return example
+
+
+def tokenize_prompt(tokenizer, prompt, add_special_tokens=False):
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ add_special_tokens=add_special_tokens,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ return text_input_ids
+
+
+# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
+def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
+ prompt_embeds_list = []
+
+ for i, text_encoder in enumerate(text_encoders):
+ if tokenizers is not None:
+ tokenizer = tokenizers[i]
+ text_input_ids = tokenize_prompt(tokenizer, prompt)
+ else:
+ assert text_input_ids_list is not None
+ text_input_ids = text_input_ids_list[i]
+
+ prompt_embeds = text_encoder(
+ text_input_ids.to(text_encoder.device),
+ output_hidden_states=True,
+ )
+
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
+ return prompt_embeds, pooled_prompt_embeds
+
+
+def main(args):
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ kwargs_handlers=[kwargs],
+ )
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+ import wandb
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Generate class images if prior preservation is enabled.
+ if args.with_prior_preservation:
+ class_images_dir = Path(args.class_data_dir)
+ if not class_images_dir.exists():
+ class_images_dir.mkdir(parents=True)
+ cur_class_images = len(list(class_images_dir.iterdir()))
+
+ if cur_class_images < args.num_class_images:
+ torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+ if args.prior_generation_precision == "fp32":
+ torch_dtype = torch.float32
+ elif args.prior_generation_precision == "fp16":
+ torch_dtype = torch.float16
+ elif args.prior_generation_precision == "bf16":
+ torch_dtype = torch.bfloat16
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ torch_dtype=torch_dtype,
+ revision=args.revision,
+ )
+ pipeline.set_progress_bar_config(disable=True)
+
+ num_new_images = args.num_class_images - cur_class_images
+ logger.info(f"Number of class images to sample: {num_new_images}.")
+
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+ sample_dataloader = accelerator.prepare(sample_dataloader)
+ pipeline.to(accelerator.device)
+
+ for example in tqdm(
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
+ ):
+ images = pipeline(example["prompt"]).images
+
+ for i, image in enumerate(images):
+ hash_image = hashlib.sha1(image.tobytes()).hexdigest()
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+ image.save(image_filename)
+
+ del pipeline
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizers
+ tokenizer_one = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
+ )
+ tokenizer_two = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
+ )
+
+ # import correct text encoder classes
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+ )
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
+ )
+ vae_path = (
+ args.pretrained_model_name_or_path
+ if args.pretrained_vae_model_name_or_path is None
+ else args.pretrained_vae_model_name_or_path
+ )
+ vae = AutoencoderKL.from_pretrained(
+ vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision
+ )
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+ )
+
+ if args.train_text_encoder_ti:
+ token_abstraction_dict = {}
+ token_idx = 0
+ for i, token in enumerate(args.token_abstraction):
+ token_abstraction_dict[token] = [
+ f"<s{token_idx + i + j}>" for j in range(args.num_new_tokens_per_abstraction)
+ ]
+ token_idx += args.num_new_tokens_per_abstraction - 1
+
+ # replace instances of --token_abstraction in --instance_prompt with the new tokens: "<si><si+1>" etc.
+ for token_abs, token_replacement in token_abstraction_dict.items():
+ args.instance_prompt = args.instance_prompt.replace(token_abs, "".join(token_replacement))
+ if args.with_prior_preservation:
+ args.class_prompt = args.class_prompt.replace(token_abs, "".join(token_replacement))
+
+ # initialize the new tokens for textual inversion
+ embedding_handler = TokenEmbeddingsHandler(
+ [text_encoder_one, text_encoder_two], [tokenizer_one, tokenizer_two]
+ )
+ inserting_toks = []
+ for new_tok in token_abstraction_dict.values():
+ inserting_toks.extend(new_tok)
+ embedding_handler.initialize_new_tokens(inserting_toks=inserting_toks)
+
+ # We only train the additional adapter LoRA layers
+ vae.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+ unet.requires_grad_(False)
+
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision,
+ # as these weights are only used for inference; keeping them in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ unet.to(accelerator.device, dtype=weight_dtype)
+
+ # The VAE is always in float32 to avoid NaN losses.
+ vae.to(accelerator.device, dtype=torch.float32)
+
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warn(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, "
+ "please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+ if args.train_text_encoder:
+ text_encoder_one.gradient_checkpointing_enable()
+ text_encoder_two.gradient_checkpointing_enable()
+
+ # now we will add new LoRA weights to the attention layers
+ # Set correct lora layers
+ unet_lora_parameters = []
+ for attn_processor_name, attn_processor in unet.attn_processors.items():
+ # Parse the attention module.
+ attn_module = unet
+ for n in attn_processor_name.split(".")[:-1]:
+ attn_module = getattr(attn_module, n)
+
+ # Set the `lora_layer` attribute of the attention-related matrices.
+ attn_module.to_q.set_lora_layer(
+ LoRALinearLayer(
+ in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
+ )
+ )
+ attn_module.to_k.set_lora_layer(
+ LoRALinearLayer(
+ in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
+ )
+ )
+ attn_module.to_v.set_lora_layer(
+ LoRALinearLayer(
+ in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
+ )
+ )
+ attn_module.to_out[0].set_lora_layer(
+ LoRALinearLayer(
+ in_features=attn_module.to_out[0].in_features,
+ out_features=attn_module.to_out[0].out_features,
+ rank=args.rank,
+ )
+ )
+
+ # Accumulate the LoRA params to optimize.
+ unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
+ unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
+ unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
+ unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
+
+ # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
+ # So, instead, we monkey-patch the forward calls of its attention-blocks.
+ if args.train_text_encoder:
+ # ensure that dtype is float32, even if the rest of the model (which isn't trained) is loaded in fp16
+ text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(
+ text_encoder_one, dtype=torch.float32, rank=args.rank
+ )
+ text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(
+ text_encoder_two, dtype=torch.float32, rank=args.rank
+ )
+
+ # if we use textual inversion, we freeze all parameters except for the token embeddings
+ # in text encoder
+ elif args.train_text_encoder_ti:
+ text_lora_parameters_one = []
+ for name, param in text_encoder_one.named_parameters():
+ if "token_embedding" in name:
+ param.requires_grad = True
+ text_lora_parameters_one.append(param)
+ else:
+ param.requires_grad = False
+ text_lora_parameters_two = []
+ for name, param in text_encoder_two.named_parameters():
+ if "token_embedding" in name:
+ param.requires_grad = True
+ text_lora_parameters_two.append(param)
+ else:
+ param.requires_grad = False
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ # there are only two options here: either just the unet attn processor layers
+ # or both the unet and text encoder attn layers
+ unet_lora_layers_to_save = None
+ text_encoder_one_lora_layers_to_save = None
+ text_encoder_two_lora_layers_to_save = None
+
+ for model in models:
+ if isinstance(model, type(accelerator.unwrap_model(unet))):
+ unet_lora_layers_to_save = unet_lora_state_dict(model)
+ elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
+ text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
+ elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
+ text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ StableDiffusionXLPipeline.save_lora_weights(
+ output_dir,
+ unet_lora_layers=unet_lora_layers_to_save,
+ text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
+ text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
+ )
+
+ def load_model_hook(models, input_dir):
+ unet_ = None
+ text_encoder_one_ = None
+ text_encoder_two_ = None
+
+ while len(models) > 0:
+ model = models.pop()
+
+ if isinstance(model, type(accelerator.unwrap_model(unet))):
+ unet_ = model
+ elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
+ text_encoder_one_ = model
+ elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
+ text_encoder_two_ = model
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
+ LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
+
+ text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
+ LoraLoaderMixin.load_lora_into_text_encoder(
+ text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
+ )
+
+ text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
+ LoraLoaderMixin.load_lora_into_text_encoder(
+ text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
+ )
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # If neither --train_text_encoder nor --train_text_encoder_ti, text_encoders remain frozen during training
+ freeze_text_encoder = not (args.train_text_encoder or args.train_text_encoder_ti)
+
+ # Optimization parameters
+ unet_lora_parameters_with_lr = {"params": unet_lora_parameters, "lr": args.learning_rate}
+ if not freeze_text_encoder:
+ # different learning rate for text encoder and unet
+ text_lora_parameters_one_with_lr = {
+ "params": text_lora_parameters_one,
+ "weight_decay": args.adam_weight_decay_text_encoder,
+ "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
+ }
+ text_lora_parameters_two_with_lr = {
+ "params": text_lora_parameters_two,
+ "weight_decay": args.adam_weight_decay_text_encoder,
+ "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
+ }
+ params_to_optimize = [
+ unet_lora_parameters_with_lr,
+ text_lora_parameters_one_with_lr,
+ text_lora_parameters_two_with_lr,
+ ]
+ else:
+ params_to_optimize = [unet_lora_parameters_with_lr]
+
+ # Optimizer creation
+ if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
+ logger.warn(
+ f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
+ "Defaulting to adamW"
+ )
+ args.optimizer = "adamw"
+
+ if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
+ logger.warn(
+ f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
+ f"set to {args.optimizer.lower()}"
+ )
+
+ if args.optimizer.lower() == "adamw":
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ if args.optimizer.lower() == "prodigy":
+ try:
+ import prodigyopt
+ except ImportError:
+ raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
+
+ optimizer_class = prodigyopt.Prodigy
+
+ if args.learning_rate <= 0.1:
+ logger.warn(
+ "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
+ )
+ if args.train_text_encoder and args.text_encoder_lr:
+ logger.warn(
+ f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:"
+ f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
+ f"When using prodigy only learning_rate is used as the initial learning rate."
+ )
+ # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be
+ # --learning_rate
+ params_to_optimize[1]["lr"] = args.learning_rate
+ params_to_optimize[2]["lr"] = args.learning_rate
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ beta3=args.prodigy_beta3,
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ decouple=args.prodigy_decouple,
+ use_bias_correction=args.prodigy_use_bias_correction,
+ safeguard_warmup=args.prodigy_safeguard_warmup,
+ )
+
+ # Dataset and DataLoaders creation:
+ train_dataset = DreamBoothDataset(
+ instance_data_root=args.instance_data_dir,
+ instance_prompt=args.instance_prompt,
+ class_prompt=args.class_prompt,
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
+ token_abstraction_dict=token_abstraction_dict if args.train_text_encoder_ti else None,
+ class_num=args.num_class_images,
+ size=args.resolution,
+ repeats=args.repeats,
+ center_crop=args.center_crop,
+ )
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Computes additional embeddings/ids required by the SDXL UNet.
+ # regular text embeddings (when `train_text_encoder` is not True)
+ # pooled text embeddings
+ # time ids
+
+ def compute_time_ids():
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
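+ # SDXL micro-conditioning: the UNet also receives the original size, the crop top-left
+ # coordinates and the target size of each image, packed together as "time ids".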
+ original_size = (args.resolution, args.resolution)
+ target_size = (args.resolution, args.resolution)
+ crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_time_ids = torch.tensor([add_time_ids])
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
+ return add_time_ids
+
+ if not args.train_text_encoder:
+ tokenizers = [tokenizer_one, tokenizer_two]
+ text_encoders = [text_encoder_one, text_encoder_two]
+
+ def compute_text_embeddings(prompt, text_encoders, tokenizers):
+ with torch.no_grad():
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)
+ prompt_embeds = prompt_embeds.to(accelerator.device)
+ pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
+ return prompt_embeds, pooled_prompt_embeds
+
+ # Handle instance prompt.
+ instance_time_ids = compute_time_ids()
+
+ # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
+ # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
+ # the redundant encoding.
+ if freeze_text_encoder and not train_dataset.custom_instance_prompts:
+ instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
+ args.instance_prompt, text_encoders, tokenizers
+ )
+
+ # Handle class prompt for prior-preservation.
+ if args.with_prior_preservation:
+ class_time_ids = compute_time_ids()
+ if freeze_text_encoder:
+ class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(
+ args.class_prompt, text_encoders, tokenizers
+ )
+
+ # Clear the memory here
+ if freeze_text_encoder and not train_dataset.custom_instance_prompts:
+ del tokenizers, text_encoders
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
+ # pack the statically computed variables appropriately here. This is so that we don't
+ # have to pass them to the dataloader.
+ add_time_ids = instance_time_ids
+ if args.with_prior_preservation:
+ add_time_ids = torch.cat([add_time_ids, class_time_ids], dim=0)
+
+ # if --train_text_encoder_ti we need add_special_tokens to be True for textual inversion
+ add_special_tokens = True if args.train_text_encoder_ti else False
+
+ if not train_dataset.custom_instance_prompts:
+ if freeze_text_encoder:
+ prompt_embeds = instance_prompt_hidden_states
+ unet_add_text_embeds = instance_pooled_prompt_embeds
+ if args.with_prior_preservation:
+ prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
+ unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0)
+ # if we're optimizing the text encoder (whether the instance prompt is used for all images or custom prompts are provided) we need to
+ # tokenize and encode the batch prompts on all training steps
+ else:
+ tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt, add_special_tokens)
+ tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt, add_special_tokens)
+ if args.with_prior_preservation:
+ class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt, add_special_tokens)
+ class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt, add_special_tokens)
+ tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
+ tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ if not freeze_text_encoder:
+ unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler
+ )
+ else:
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers are initialized automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("dreambooth-lora-sd-xl", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ if args.train_text_encoder:
+ num_train_epochs_text_encoder = int(args.train_text_encoder_frac * args.num_train_epochs)
+ elif args.train_text_encoder_ti: # args.train_text_encoder_ti
+ num_train_epochs_text_encoder = int(args.train_text_encoder_ti_frac * args.num_train_epochs)
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ # if performing any kind of optimization of text_encoder params
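+ # note: the text encoder params (or, with --train_text_encoder_ti, just the new token
+ # embeddings) are only optimized for the first num_train_epochs_text_encoder epochs;
+ # after the pivot below, the optimizer is re-created with the unet LoRA params only.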
+ if args.train_text_encoder or args.train_text_encoder_ti:
+ if epoch == num_train_epochs_text_encoder:
+ print("PIVOT HALFWAY", epoch)
+ # stopping optimization of text_encoder params
+ params_to_optimize = params_to_optimize[:1]
+ # reinitializing the optimizer to optimize only on unet params
+ if args.optimizer.lower() == "prodigy":
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ beta3=args.prodigy_beta3,
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ decouple=args.prodigy_decouple,
+ use_bias_correction=args.prodigy_use_bias_correction,
+ safeguard_warmup=args.prodigy_safeguard_warmup,
+ )
+ else: # AdamW or 8-bit-AdamW
+ optimizer = optimizer_class(
+ params_to_optimize,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+ else:
+ # still optimizing the text encoder
+ text_encoder_one.train()
+ text_encoder_two.train()
+ # set top parameter requires_grad = True so that gradient checkpointing works
+ if args.train_text_encoder:
+ text_encoder_one.text_model.embeddings.requires_grad_(True)
+ text_encoder_two.text_model.embeddings.requires_grad_(True)
+
+ unet.train()
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+ prompts = batch["prompts"]
+ print(prompts)
+ # encode batch prompts when custom prompts are provided for each image -
+ if train_dataset.custom_instance_prompts:
+ if freeze_text_encoder:
+ prompt_embeds, unet_add_text_embeds = compute_text_embeddings(
+ prompts, text_encoders, tokenizers
+ )
+
+ else:
+ tokens_one = tokenize_prompt(tokenizer_one, prompts, add_special_tokens)
+ tokens_two = tokenize_prompt(tokenizer_two, prompts, add_special_tokens)
+
+ # Convert images to latent space
+ model_input = vae.encode(pixel_values).latent_dist.sample()
+ model_input = model_input * vae.config.scaling_factor
+ if args.pretrained_vae_model_name_or_path is None:
+ model_input = model_input.to(weight_dtype)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(model_input)
+ bsz = model_input.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
+ )
+ timesteps = timesteps.long()
+
+ # Add noise to the model input according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
+
+ # Calculate the elements to repeat depending on the use of prior-preservation and custom captions.
+ if not train_dataset.custom_instance_prompts:
+ elems_to_repeat_text_embeds = bsz // 2 if args.with_prior_preservation else bsz
+ elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz
+
+ else:
+ elems_to_repeat_text_embeds = 1
+ elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz
+
+ # Predict the noise residual
+ if freeze_text_encoder:
+ unet_added_conditions = {
+ "time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1),
+ "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1),
+ }
+ prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
+ model_pred = unet(
+ noisy_model_input,
+ timesteps,
+ prompt_embeds_input,
+ added_cond_kwargs=unet_added_conditions,
+ ).sample
+ else:
+ unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)}
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
+ text_encoders=[text_encoder_one, text_encoder_two],
+ tokenizers=None,
+ prompt=None,
+ text_input_ids_list=[tokens_one, tokens_two],
+ )
+ unet_added_conditions.update(
+ {"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)}
+ )
+ prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
+ model_pred = unet(
+ noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions
+ ).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(model_input, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ if args.with_prior_preservation:
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+ target, target_prior = torch.chunk(target, 2, dim=0)
+
+ # Compute prior loss
+ prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+
+ if args.snr_gamma is None:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+ # This is discussed in Section 4.2 of the same paper.
+ snr = compute_snr(noise_scheduler, timesteps)
+ base_weight = (
+ torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
+ )
+
+ if noise_scheduler.config.prediction_type == "v_prediction":
+ # Velocity objective needs to be floored to an SNR weight of one.
+ mse_loss_weights = base_weight + 1
+ else:
+ # Epsilon and sample both use the same loss weights.
+ mse_loss_weights = base_weight
+
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+
+ if args.with_prior_preservation:
+ # Add the prior loss to the instance loss.
+ loss = loss + args.prior_loss_weight * prior_loss
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = (
+ itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two)
+ if (args.train_text_encoder or args.train_text_encoder_ti)
+ else unet_lora_parameters
+ )
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # every step, we reset the embeddings to the original embeddings.
+ if args.train_text_encoder_ti:
+ for idx, text_encoder in enumerate(text_encoders):
+ embedding_handler.retract_embeddings()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ # create pipeline
+ if not args.train_text_encoder:
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
+ )
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ text_encoder=accelerator.unwrap_model(text_encoder_one),
+ text_encoder_2=accelerator.unwrap_model(text_encoder_two),
+ unet=accelerator.unwrap_model(unet),
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
+ scheduler_args = {}
+
+ if "variance_type" in pipeline.scheduler.config:
+ variance_type = pipeline.scheduler.config.variance_type
+
+ if variance_type in ["learned", "learned_range"]:
+ variance_type = "fixed_small"
+
+ scheduler_args["variance_type"] = variance_type
+
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
+ pipeline.scheduler.config, **scheduler_args
+ )
+
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ pipeline_args = {"prompt": args.validation_prompt}
+
+ with torch.cuda.amp.autocast():
+ images = [
+ pipeline(**pipeline_args, generator=generator).images[0]
+ for _ in range(args.num_validation_images)
+ ]
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = accelerator.unwrap_model(unet)
+ unet = unet.to(torch.float32)
+ unet_lora_layers = unet_lora_state_dict(unet)
+
+ if args.train_text_encoder:
+ text_encoder_one = accelerator.unwrap_model(text_encoder_one)
+ text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one.to(torch.float32))
+ text_encoder_two = accelerator.unwrap_model(text_encoder_two)
+ text_encoder_2_lora_layers = text_encoder_lora_state_dict(text_encoder_two.to(torch.float32))
+ else:
+ text_encoder_lora_layers = None
+ text_encoder_2_lora_layers = None
+
+ StableDiffusionXLPipeline.save_lora_weights(
+ save_directory=args.output_dir,
+ unet_lora_layers=unet_lora_layers,
+ text_encoder_lora_layers=text_encoder_lora_layers,
+ text_encoder_2_lora_layers=text_encoder_2_lora_layers,
+ )
+
+ # Final inference
+ # Load previous pipeline
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path, vae=vae, revision=args.revision, torch_dtype=weight_dtype
+ )
+
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
+ scheduler_args = {}
+
+ if "variance_type" in pipeline.scheduler.config:
+ variance_type = pipeline.scheduler.config.variance_type
+
+ if variance_type in ["learned", "learned_range"]:
+ variance_type = "fixed_small"
+
+ scheduler_args["variance_type"] = variance_type
+
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+
+ # load attention processors
+ pipeline.load_lora_weights(args.output_dir)
+
+ # run inference
+ images = []
+ if args.validation_prompt and args.num_validation_images > 0:
+ pipeline = pipeline.to(accelerator.device)
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ images = [
+ pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+ for _ in range(args.num_validation_images)
+ ]
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "test": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ if args.push_to_hub:
+ if args.train_text_encoder_ti:
+ embedding_handler.save_embeddings(
+ f"{args.output_dir}/embeddings.safetensors",
+ )
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ train_text_encoder=args.train_text_encoder,
+ instance_prompt=args.instance_prompt,
+ validation_prompt=args.validation_prompt,
+ repo_folder=args.output_dir,
+ vae_path=args.pretrained_vae_model_name_or_path,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/community/README.md b/diffusers/examples/community/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..96d5304129793ffa42ff23671cdb2a73d2ee0ab9
--- /dev/null
+++ b/diffusers/examples/community/README.md
@@ -0,0 +1,2483 @@
+# Community Examples
+
+> **For more information about community pipelines, please have a look at [this issue](https://github.com/huggingface/diffusers/issues/841).**
+
+**Community** examples consist of both inference and training examples that have been added by the community.
+Please have a look at the following table to get an overview of all community examples. Click on the **Code Example** to get a copy-and-paste ready code example that you can try out.
+If a community pipeline doesn't work as expected, please open an issue and ping the author on it.
+
+| Example | Description | Code Example | Colab | Author |
+|:--------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------:|
+| LLM-grounded Diffusion (LMD+) | LMD greatly improves the prompt following ability of text-to-image generation models by introducing an LLM as a front-end prompt parser and layout planner. [Project page.](https://llm-grounded-diffusion.github.io/) [See our full codebase (also with diffusers).](https://github.com/TonyLianLong/LLM-groundedDiffusion) | [LLM-grounded Diffusion (LMD+)](#llm-grounded-diffusion) | [Huggingface Demo](https://huggingface.co/spaces/longlian/llm-grounded-diffusion) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1SXzMSeAB-LJYISb2yrUOdypLz4OYWUKj) | [Long (Tony) Lian](https://tonylian.com/) |
+| CLIP Guided Stable Diffusion | Doing CLIP guidance for text to image generation with Stable Diffusion | [CLIP Guided Stable Diffusion](#clip-guided-stable-diffusion) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/CLIP_Guided_Stable_diffusion_with_diffusers.ipynb) | [Suraj Patil](https://github.com/patil-suraj/) |
+| One Step U-Net (Dummy) | Example showcasing how to use Community Pipelines (see https://github.com/huggingface/diffusers/issues/841) | [One Step U-Net](#one-step-unet) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
+| Stable Diffusion Interpolation | Interpolate the latent space of Stable Diffusion between different prompts/seeds | [Stable Diffusion Interpolation](#stable-diffusion-interpolation) | - | [Nate Raw](https://github.com/nateraw/) |
+| Stable Diffusion Mega | **One** Stable Diffusion Pipeline with all functionalities of [Text2Image](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py), [Image2Image](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py) and [Inpainting](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py) | [Stable Diffusion Mega](#stable-diffusion-mega) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
+| Long Prompt Weighting Stable Diffusion | **One** Stable Diffusion Pipeline without tokens length limit, and support parsing weighting in prompt. | [Long Prompt Weighting Stable Diffusion](#long-prompt-weighting-stable-diffusion) | - | [SkyTNT](https://github.com/SkyTNT) |
+| Speech to Image | Using automatic-speech-recognition to transcribe text and Stable Diffusion to generate images | [Speech to Image](#speech-to-image) | - | [Mikail Duzenli](https://github.com/MikailINTech)
+| Wild Card Stable Diffusion | Stable Diffusion Pipeline that supports prompts that contain wildcard terms (indicated by surrounding double underscores), with values instantiated randomly from a corresponding txt file or a dictionary of possible values | [Wildcard Stable Diffusion](#wildcard-stable-diffusion) | - | [Shyam Sudhakaran](https://github.com/shyamsn97) |
+| [Composable Stable Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/) | Stable Diffusion Pipeline that supports prompts that contain "|" in prompts (as an AND condition) and weights (separated by "|" as well) to positively / negatively weight prompts. | [Composable Stable Diffusion](#composable-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) |
+| Seed Resizing Stable Diffusion | Stable Diffusion Pipeline that supports resizing an image and retaining the concepts of the 512 by 512 generation. | [Seed Resizing](#seed-resizing) | - | [Mark Rich](https://github.com/MarkRich) |
+| Imagic Stable Diffusion | Stable Diffusion Pipeline that enables writing a text prompt to edit an existing image | [Imagic Stable Diffusion](#imagic-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) |
+| Multilingual Stable Diffusion | Stable Diffusion Pipeline that supports prompts in 50 different languages. | [Multilingual Stable Diffusion](#multilingual-stable-diffusion-pipeline) | - | [Juan Carlos Piñeros](https://github.com/juancopi81) |
+| Image to Image Inpainting Stable Diffusion | Stable Diffusion Pipeline that enables the overlaying of two images and subsequent inpainting | [Image to Image Inpainting Stable Diffusion](#image-to-image-inpainting-stable-diffusion) | - | [Alex McKinney](https://github.com/vvvm23) |
+| Text Based Inpainting Stable Diffusion | Stable Diffusion Inpainting Pipeline that enables passing a text prompt to generate the mask for inpainting | [Text Based Inpainting Stable Diffusion](#image-to-image-inpainting-stable-diffusion) | - | [Dhruv Karan](https://github.com/unography) |
+| Bit Diffusion | Diffusion on discrete data | [Bit Diffusion](#bit-diffusion) | - | [Stuti R.](https://github.com/kingstut) |
+| K-Diffusion Stable Diffusion | Run Stable Diffusion with any of [K-Diffusion's samplers](https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py) | [Stable Diffusion with K Diffusion](#stable-diffusion-with-k-diffusion) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
+| Checkpoint Merger Pipeline | Diffusion Pipeline that enables merging of saved model checkpoints | [Checkpoint Merger Pipeline](#checkpoint-merger-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
+| Stable Diffusion v1.1-1.4 Comparison | Run all 4 model checkpoints for Stable Diffusion and compare their results together | [Stable Diffusion Comparison](#stable-diffusion-comparisons) | - | [Suvaditya Mukherjee](https://github.com/suvadityamuk) |
+| MagicMix | Diffusion Pipeline for semantic mixing of an image and a text prompt | [MagicMix](#magic-mix) | - | [Partho Das](https://github.com/daspartho) |
+| Stable UnCLIP | Diffusion Pipeline for combining prior model (generate clip image embedding from text, UnCLIPPipeline `"kakaobrain/karlo-v1-alpha"`) and decoder pipeline (decode clip image embedding to image, StableDiffusionImageVariationPipeline `"lambdalabs/sd-image-variations-diffusers"` ). | [Stable UnCLIP](#stable-unclip) | - | [Ray Wang](https://wrong.wang) |
+| UnCLIP Text Interpolation Pipeline | Diffusion Pipeline that allows passing two prompts and produces images while interpolating between the text-embeddings of the two prompts | [UnCLIP Text Interpolation Pipeline](#unclip-text-interpolation-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
+| UnCLIP Image Interpolation Pipeline | Diffusion Pipeline that allows passing two images/image_embeddings and produces images while interpolating between their image-embeddings | [UnCLIP Image Interpolation Pipeline](#unclip-image-interpolation-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
+| DDIM Noise Comparative Analysis Pipeline | Investigating how the diffusion models learn visual concepts from each noise level (which is a contribution of [P2 weighting (CVPR 2022)](https://arxiv.org/abs/2204.00227)) | [DDIM Noise Comparative Analysis Pipeline](#ddim-noise-comparative-analysis-pipeline) | - | [Aengus (Duc-Anh)](https://github.com/aengusng8) |
+| CLIP Guided Img2Img Stable Diffusion Pipeline | Doing CLIP guidance for image to image generation with Stable Diffusion | [CLIP Guided Img2Img Stable Diffusion](#clip-guided-img2img-stable-diffusion) | - | [Nipun Jindal](https://github.com/nipunjindal/) |
+| TensorRT Stable Diffusion Text to Image Pipeline | Accelerates the Stable Diffusion Text2Image Pipeline using TensorRT | [TensorRT Stable Diffusion Text to Image Pipeline](#tensorrt-text2image-stable-diffusion-pipeline) | - | [Asfiya Baig](https://github.com/asfiyab-nvidia) |
+| EDICT Image Editing Pipeline | Diffusion pipeline for text-guided image editing | [EDICT Image Editing Pipeline](#edict-image-editing-pipeline) | - | [Joqsan Azocar](https://github.com/Joqsan) |
+| Stable Diffusion RePaint | Stable Diffusion pipeline using [RePaint](https://arxiv.org/abs/2201.0986) for inpainting. | [Stable Diffusion RePaint](#stable-diffusion-repaint ) | - | [Markus Pobitzer](https://github.com/Markus-Pobitzer) |
+| TensorRT Stable Diffusion Image to Image Pipeline | Accelerates the Stable Diffusion Image2Image Pipeline using TensorRT | [TensorRT Stable Diffusion Image to Image Pipeline](#tensorrt-image2image-stable-diffusion-pipeline) | - | [Asfiya Baig](https://github.com/asfiyab-nvidia) |
+| Stable Diffusion IPEX Pipeline | Accelerate Stable Diffusion inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [Stable Diffusion on IPEX](#stable-diffusion-on-ipex) | - | [Yingjie Han](https://github.com/yingjie-han/) |
+| CLIP Guided Images Mixing Stable Diffusion Pipeline | Combine images using usual diffusion models. | [CLIP Guided Images Mixing Using Stable Diffusion](#clip-guided-images-mixing-with-stable-diffusion) | - | [Karachev Denis](https://github.com/TheDenk) |
+| TensorRT Stable Diffusion Inpainting Pipeline | Accelerates the Stable Diffusion Inpainting Pipeline using TensorRT | [TensorRT Stable Diffusion Inpainting Pipeline](#tensorrt-inpainting-stable-diffusion-pipeline) | - | [Asfiya Baig](https://github.com/asfiyab-nvidia) |
+| IADB Pipeline | Implementation of [Iterative α-(de)Blending: a Minimalist Deterministic Diffusion Model](https://arxiv.org/abs/2305.03486) | [IADB Pipeline](#iadb-pipeline) | - | [Thomas Chambon](https://github.com/tchambon)
+| Zero1to3 Pipeline | Implementation of [Zero-1-to-3: Zero-shot One Image to 3D Object](https://arxiv.org/abs/2303.11328) | [Zero1to3 Pipeline](#Zero1to3-pipeline) | - | [Xin Kong](https://github.com/kxhit) |
+| Stable Diffusion XL Long Weighted Prompt Pipeline | A pipeline that supports unlimited-length prompts and negative prompts and uses A1111-style prompt weighting | [Stable Diffusion XL Long Weighted Prompt Pipeline](#stable-diffusion-xl-long-weighted-prompt-pipeline) | - | [Andrew Zhu](https://xhinker.medium.com/) |
+| FABRIC - Stable Diffusion with feedback Pipeline | Pipeline that supports feedback from liked and disliked images | [Stable Diffusion Fabric Pipeline](#stable-diffusion-fabric-pipeline) | - | [Shauray Singh](https://shauray8.github.io/about_shauray/) |
+| sketch inpaint - Inpainting with non-inpaint Stable Diffusion | Sketch inpainting, much like in automatic1111 | [Masked Im2Im Stable Diffusion Pipeline](#stable-diffusion-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |
+| prompt-to-prompt | Change parts of a prompt and retain image structure (see [paper page](https://prompt-to-prompt.github.io/)) | [Prompt2Prompt Pipeline](#prompt2prompt-pipeline) | - | [Umer H. Adil](https://twitter.com/UmerHAdil) |
+| Latent Consistency Pipeline | Implementation of [Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://arxiv.org/abs/2310.04378) | [Latent Consistency Pipeline](#latent-consistency-pipeline) | - | [Simian Luo](https://github.com/luosiallen) |
+| Latent Consistency Img2img Pipeline | Img2img pipeline for Latent Consistency Models | [Latent Consistency Img2Img Pipeline](#latent-consistency-img2img-pipeline) | - | [Logan Zoellner](https://github.com/nagolinc) |
+| Latent Consistency Interpolation Pipeline | Interpolate the latent space of Latent Consistency Models with multiple prompts | [Latent Consistency Interpolation Pipeline](#latent-consistency-interpolation-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1pK3NrLWJSiJsBynLns1K1-IDTW9zbPvl?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
+
+
+To load a custom pipeline, pass the name of one of the files in `diffusers/examples/community` as the `custom_pipeline` argument to `DiffusionPipeline`. Feel free to send a PR with your own pipelines; we will merge them quickly.
+```py
+pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", custom_pipeline="filename_in_the_community_folder")
+```
+
+## Example usages
+
+### LLM-grounded Diffusion
+
+LMD and LMD+ greatly improve the prompt understanding ability of text-to-image generation models by introducing an LLM as a front-end prompt parser and layout planner. It improves spatial reasoning, the understanding of negation, attribute binding, generative numeracy, etc. in a unified manner without explicitly aiming for each. LMD is completely training-free (i.e., it uses the SD model off-the-shelf). LMD+ takes in additional adapters for better control. This is a reproduction of the LMD+ model used in our work. [Project page.](https://llm-grounded-diffusion.github.io/) [See our full codebase (also with diffusers).](https://github.com/TonyLianLong/LLM-groundedDiffusion)
+
+![Main Image](https://llm-grounded-diffusion.github.io/main_figure.jpg)
+![Visualizations: Enhanced Prompt Understanding](https://llm-grounded-diffusion.github.io/visualizations.jpg)
+
+This pipeline can be used with an LLM or on its own. We provide a parser that parses LLM outputs to the layouts. You can obtain the prompt to input to the LLM for layout generation [here](https://github.com/TonyLianLong/LLM-groundedDiffusion/blob/main/prompt.py). After feeding the prompt to an LLM (e.g., GPT-4 on ChatGPT website), you can feed the LLM response into our pipeline.
+
+The following code has been tested on 1x RTX 4090, but it should also support GPUs with lower GPU memory.
+
+#### Use this pipeline with an LLM
+```python
+import torch
+from diffusers import DiffusionPipeline
+
+pipe = DiffusionPipeline.from_pretrained(
+ "longlian/lmd_plus",
+ custom_pipeline="llm_grounded_diffusion",
+ variant="fp16", torch_dtype=torch.float16
+)
+pipe.enable_model_cpu_offload()
+
+# Generate directly from a text prompt and an LLM response
+prompt = "a waterfall and a modern high speed train in a beautiful forest with fall foliage"
+phrases, boxes, bg_prompt, neg_prompt = pipe.parse_llm_response("""
+[('a waterfall', [71, 105, 148, 258]), ('a modern high speed train', [255, 223, 181, 149])]
+Background prompt: A beautiful forest with fall foliage
+Negative prompt:
+""")
+
+images = pipe(
+ prompt=prompt,
+ negative_prompt=neg_prompt,
+ phrases=phrases,
+ boxes=boxes,
+ gligen_scheduled_sampling_beta=0.4,
+ output_type="pil",
+ num_inference_steps=50,
+ lmd_guidance_kwargs={}
+).images
+
+images[0].save("./lmd_plus_generation.jpg")
+```
+
+#### Use this pipeline on its own for layout generation
+```python
+import torch
+from diffusers import DiffusionPipeline
+
+pipe = DiffusionPipeline.from_pretrained(
+ "longlian/lmd_plus",
+ custom_pipeline="llm_grounded_diffusion",
+ variant="fp16", torch_dtype=torch.float16
+)
+pipe.enable_model_cpu_offload()
+
+# Generate an image described by the prompt and
+# insert objects described by text at the region defined by bounding boxes
+prompt = "a waterfall and a modern high speed train in a beautiful forest with fall foliage"
+boxes = [[0.1387, 0.2051, 0.4277, 0.7090], [0.4980, 0.4355, 0.8516, 0.7266]]
+phrases = ["a waterfall", "a modern high speed train"]
+
+images = pipe(
+ prompt=prompt,
+ phrases=phrases,
+ boxes=boxes,
+ gligen_scheduled_sampling_beta=0.4,
+ output_type="pil",
+ num_inference_steps=50,
+ lmd_guidance_kwargs={}
+).images
+
+images[0].save("./lmd_plus_generation.jpg")
+```
+
+### CLIP Guided Stable Diffusion
+
+CLIP guided stable diffusion can help to generate more realistic images
+by guiding stable diffusion at every denoising step with an additional CLIP model.
+
+The following code requires roughly 12GB of GPU RAM.
+
+```python
+from diffusers import DiffusionPipeline
+from transformers import CLIPImageProcessor, CLIPModel
+import torch
+
+
+feature_extractor = CLIPImageProcessor.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K")
+clip_model = CLIPModel.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.float16)
+
+
+guided_pipeline = DiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+ custom_pipeline="clip_guided_stable_diffusion",
+ clip_model=clip_model,
+ feature_extractor=feature_extractor,
+
+ torch_dtype=torch.float16,
+)
+guided_pipeline.enable_attention_slicing()
+guided_pipeline = guided_pipeline.to("cuda")
+
+prompt = "fantasy book cover, full moon, fantasy forest landscape, golden vector elements, fantasy magic, dark light night, intricate, elegant, sharp focus, illustration, highly detailed, digital painting, concept art, matte, art by WLOP and Artgerm and Albert Bierstadt, masterpiece"
+
+generator = torch.Generator(device="cuda").manual_seed(0)
+images = []
+for i in range(4):
+ image = guided_pipeline(
+ prompt,
+ num_inference_steps=50,
+ guidance_scale=7.5,
+ clip_guidance_scale=100,
+ num_cutouts=4,
+ use_cutouts=False,
+ generator=generator,
+ ).images[0]
+ images.append(image)
+
+# save images locally
+for i, img in enumerate(images):
+ img.save(f"./clip_guided_sd/image_{i}.png")
+```
+
+The `images` list contains a list of PIL images that can be saved locally or displayed directly in a Google Colab.
+Generated images tend to be of higher quality than when using Stable Diffusion natively. E.g. the above script generates the following images:
+
+![clip_guidance](https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/clip_guidance/merged_clip_guidance.jpg)
+
+### One Step Unet
+
+The dummy "one-step-unet" can be run as follows:
+
+```python
+from diffusers import DiffusionPipeline
+
+pipe = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32", custom_pipeline="one_step_unet")
+pipe()
+```
+
+**Note**: This community pipeline is not useful as a feature, but rather just serves as an example of how community pipelines can be added (see https://github.com/huggingface/diffusers/issues/841).
+
+### Stable Diffusion Interpolation
+
+The following code can be run on a GPU of at least 8GB VRAM and should take approximately 5 minutes.
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+pipe = DiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ revision='fp16',
+ torch_dtype=torch.float16,
+ safety_checker=None, # Very important for videos...lots of false positives while interpolating
+ custom_pipeline="interpolate_stable_diffusion",
+).to('cuda')
+pipe.enable_attention_slicing()
+
+frame_filepaths = pipe.walk(
+ prompts=['a dog', 'a cat', 'a horse'],
+ seeds=[42, 1337, 1234],
+ num_interpolation_steps=16,
+ output_dir='./dreams',
+ batch_size=4,
+ height=512,
+ width=512,
+ guidance_scale=8.5,
+ num_inference_steps=50,
+)
+```
+
+The `walk(...)` function returns a list of file paths of the images saved under the folder defined in `output_dir`. You can use these images to create videos of stable diffusion, for example as sketched below.
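+
+As a rough illustration (not part of the pipeline itself), the saved frames can be stitched into an animated GIF with Pillow; this assumes `frame_filepaths` is the list of image paths returned by `walk(...)` above:
+
+```python
+from PIL import Image
+
+# Load the interpolation frames returned by `walk(...)`
+frames = [Image.open(path) for path in frame_filepaths]
+
+# Assemble them into an animated GIF (125 ms per frame, looping forever)
+frames[0].save(
+    "dreams.gif",
+    save_all=True,
+    append_images=frames[1:],
+    duration=125,
+    loop=0,
+)
+```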
+
+> **Please have a look at https://github.com/nateraw/stable-diffusion-videos for more in-detail information on how to create videos using stable diffusion as well as more feature-complete functionality.**
+
+### Stable Diffusion Mega
+
+The Stable Diffusion Mega Pipeline lets you use the main use cases of the stable diffusion pipeline in a single class.
+
+```python
+#!/usr/bin/env python3
+from diffusers import DiffusionPipeline
+import PIL
+import requests
+from io import BytesIO
+import torch
+
+
+def download_image(url):
+ response = requests.get(url)
+ return PIL.Image.open(BytesIO(response.content)).convert("RGB")
+
+pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", custom_pipeline="stable_diffusion_mega", torch_dtype=torch.float16, revision="fp16")
+pipe.to("cuda")
+pipe.enable_attention_slicing()
+
+
+### Text-to-Image
+
+images = pipe.text2img("An astronaut riding a horse").images
+
+### Image-to-Image
+
+init_image = download_image("https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg")
+
+prompt = "A fantasy landscape, trending on artstation"
+
+images = pipe.img2img(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images
+
+### Inpainting
+
+img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+init_image = download_image(img_url).resize((512, 512))
+mask_image = download_image(mask_url).resize((512, 512))
+
+prompt = "a cat sitting on a bench"
+images = pipe.inpaint(prompt=prompt, image=init_image, mask_image=mask_image, strength=0.75).images
+```
+
+As shown above, this single pipeline can run "text-to-image", "image-to-image", and "inpainting" in one class.
+
+### Long Prompt Weighting Stable Diffusion
+Features of this custom pipeline:
+- Input a prompt without the 77 token length limit.
+- Includes text2img, img2img, and inpainting pipelines.
+- Emphasize/weigh parts of your prompt with parentheses, like so: `a baby deer with (big eyes)`
+- De-emphasize parts of your prompt, like so: `a [baby] deer with big eyes`
+- Precisely weigh parts of your prompt, like so: `a baby deer with (big eyes:1.3)`
+
+Prompt weighting equivalents (see the sketch after this list):
+- `a baby deer with` == `(a baby deer with:1.0)`
+- `(big eyes)` == `(big eyes:1.1)`
+- `((big eyes))` == `(big eyes:1.21)`
+- `[big eyes]` == `(big eyes:0.91)`
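+
+To make the equivalences above concrete, here is a small illustrative sketch (not part of the pipeline): each `(` multiplies the weight by 1.1 and each `[` divides it by 1.1.
+
+```python
+def bracket_weight(num_parens: int = 0, num_brackets: int = 0) -> float:
+    """Effective A1111-style weight for a phrase wrapped in `num_parens` pairs of
+    parentheses and `num_brackets` pairs of square brackets."""
+    return round(1.1 ** num_parens / 1.1 ** num_brackets, 2)
+
+
+print(bracket_weight(1))     # (big eyes)   -> 1.1
+print(bracket_weight(2))     # ((big eyes)) -> 1.21
+print(bracket_weight(0, 1))  # [big eyes]   -> 0.91
+```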
+
+You can run this custom pipeline like so:
+
+#### pytorch
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+pipe = DiffusionPipeline.from_pretrained(
+ 'hakurei/waifu-diffusion',
+ custom_pipeline="lpw_stable_diffusion",
+
+ torch_dtype=torch.float16
+)
+pipe=pipe.to("cuda")
+
+prompt = "best_quality (1girl:1.3) bow bride brown_hair closed_mouth frilled_bow frilled_hair_tubes frills (full_body:1.3) fox_ear hair_bow hair_tubes happy hood japanese_clothes kimono long_sleeves red_bow smile solo tabi uchikake white_kimono wide_sleeves cherry_blossoms"
+neg_prompt = "lowres, bad_anatomy, error_body, error_hair, error_arm, error_hands, bad_hands, error_fingers, bad_fingers, missing_fingers, error_legs, bad_legs, multiple_legs, missing_legs, error_lighting, error_shadow, error_reflection, text, error, extra_digit, fewer_digits, cropped, worst_quality, low_quality, normal_quality, jpeg_artifacts, signature, watermark, username, blurry"
+
+pipe.text2img(prompt, negative_prompt=neg_prompt, width=512,height=512,max_embeddings_multiples=3).images[0]
+
+```
+
+#### onnxruntime
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+pipe = DiffusionPipeline.from_pretrained(
+ 'CompVis/stable-diffusion-v1-4',
+ custom_pipeline="lpw_stable_diffusion_onnx",
+ revision="onnx",
+ provider="CUDAExecutionProvider"
+)
+
+prompt = "a photo of an astronaut riding a horse on mars, best quality"
+neg_prompt = "lowres, bad anatomy, error body, error hair, error arm, error hands, bad hands, error fingers, bad fingers, missing fingers, error legs, bad legs, multiple legs, missing legs, error lighting, error shadow, error reflection, text, error, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry"
+
+pipe.text2img(prompt,negative_prompt=neg_prompt, width=512, height=512, max_embeddings_multiples=3).images[0]
+
+```
+
+If you see the warning `Token indices sequence length is longer than the specified maximum sequence length for this model ( *** > 77 ) . Running this sequence through the model will result in indexing errors`, do not worry; it is expected behavior for this pipeline.
+
+### Speech to Image
+
+The following code can generate an image from an audio sample using the pre-trained OpenAI whisper-small model and Stable Diffusion.
+
+```Python
+import torch
+
+import matplotlib.pyplot as plt
+from datasets import load_dataset
+from diffusers import DiffusionPipeline
+from transformers import (
+ WhisperForConditionalGeneration,
+ WhisperProcessor,
+)
+
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+
+audio_sample = ds[3]
+
+text = audio_sample["text"].lower()
+speech_data = audio_sample["audio"]["array"]
+
+model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small").to(device)
+processor = WhisperProcessor.from_pretrained("openai/whisper-small")
+
+diffuser_pipeline = DiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ custom_pipeline="speech_to_image_diffusion",
+ speech_model=model,
+ speech_processor=processor,
+
+ torch_dtype=torch.float16,
+)
+
+diffuser_pipeline.enable_attention_slicing()
+diffuser_pipeline = diffuser_pipeline.to(device)
+
+output = diffuser_pipeline(speech_data)
+plt.imshow(output.images[0])
+```
+This example produces the following image:
+
+![image](https://user-images.githubusercontent.com/45072645/196901736-77d9c6fc-63ee-4072-90b0-dc8b903d63e3.png)
+
+### Wildcard Stable Diffusion
+Following the great examples from https://github.com/jtkelm2/stable-diffusion-webui-1/blob/master/scripts/wildcards.py and https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts#wildcards, here's a minimal implementation that allows users to add "wildcards", denoted by `__wildcard__`, to prompts. Wildcards act as placeholders for values randomly sampled from either a dictionary or a `.txt` file. For example:
+
+Say we have a prompt:
+
+```
+prompt = "__animal__ sitting on a __object__ wearing a __clothing__"
+```
+
+We can then define the possible values to be sampled for `animal`, `object`, and `clothing`. These can come from a `.txt` file with the same name as the category.
+
+The possible values can also be defined / combined by using a dictionary like: `{"animal": ["dog", "cat", "mouse"]}`.
+
+The actual pipeline works just like `StableDiffusionPipeline`, except the `__call__` method additionally takes in:
+
+- `wildcard_files`: list of file paths for wildcard replacement
+- `wildcard_option_dict`: dict with the wildcard as key and a list of possible replacements as values
+- `num_prompt_samples`: number of prompts to sample, uniformly sampling wildcards
+
+A full example:
+
+create `animal.txt`, with contents like:
+
+```
+dog
+cat
+mouse
+```
+
+create `object.txt`, with contents like:
+
+```
+chair
+sofa
+bench
+```
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+pipe = DiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ custom_pipeline="wildcard_stable_diffusion",
+
+ torch_dtype=torch.float16,
+)
+prompt = "__animal__ sitting on a __object__ wearing a __clothing__"
+out = pipe(
+ prompt,
+ wildcard_option_dict={
+ "clothing":["hat", "shirt", "scarf", "beret"]
+ },
+ wildcard_files=["object.txt", "animal.txt"],
+ num_prompt_samples=1
+)
+```
+
+### Composable Stable Diffusion
+
+[Composable Stable Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/) proposes conjunction and negation (negative prompts) operators for compositional generation with conditional diffusion models.
+
+```python
+import torch as th
+import numpy as np
+import torchvision.utils as tvu
+
+from diffusers import DiffusionPipeline
+
+import argparse
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--prompt", type=str, default="mystical trees | A magical pond | dark",
+ help="use '|' as the delimiter to compose separate sentences.")
+parser.add_argument("--steps", type=int, default=50)
+parser.add_argument("--scale", type=float, default=7.5)
+parser.add_argument("--weights", type=str, default="7.5 | 7.5 | -7.5")
+parser.add_argument("--seed", type=int, default=2)
+parser.add_argument("--model_path", type=str, default="CompVis/stable-diffusion-v1-4")
+parser.add_argument("--num_images", type=int, default=1)
+args = parser.parse_args()
+
+has_cuda = th.cuda.is_available()
+device = th.device('cpu' if not has_cuda else 'cuda')
+
+prompt = args.prompt
+scale = args.scale
+steps = args.steps
+
+pipe = DiffusionPipeline.from_pretrained(
+ args.model_path,
+ custom_pipeline="composable_stable_diffusion",
+).to(device)
+
+pipe.safety_checker = None
+
+images = []
+generator = th.Generator("cuda").manual_seed(args.seed)
+for i in range(args.num_images):
+ image = pipe(prompt, guidance_scale=scale, num_inference_steps=steps,
+ weights=args.weights, generator=generator).images[0]
+ images.append(th.from_numpy(np.array(image)).permute(2, 0, 1) / 255.)
+grid = tvu.make_grid(th.stack(images, dim=0), nrow=4, padding=0)
+tvu.save_image(grid, f'{prompt}_{args.weights}' + '.png')
+
+```
+
+### Imagic Stable Diffusion
+Allows you to edit an image using stable diffusion.
+
+```python
+import requests
+from PIL import Image
+from io import BytesIO
+import torch
+import os
+from diffusers import DiffusionPipeline, DDIMScheduler
+has_cuda = torch.cuda.is_available()
+device = torch.device('cpu' if not has_cuda else 'cuda')
+pipe = DiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ safety_checker=None,
+ use_auth_token=True,
+ custom_pipeline="imagic_stable_diffusion",
+ scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
+).to(device)
+generator = torch.Generator("cuda").manual_seed(0)
+seed = 0
+prompt = "A photo of Barack Obama smiling with a big grin"
+url = 'https://www.dropbox.com/s/6tlwzr73jd1r9yk/obama.png?dl=1'
+response = requests.get(url)
+init_image = Image.open(BytesIO(response.content)).convert("RGB")
+init_image = init_image.resize((512, 512))
+res = pipe.train(
+ prompt,
+ image=init_image,
+ generator=generator)
+res = pipe(alpha=1, guidance_scale=7.5, num_inference_steps=50)
+os.makedirs("imagic", exist_ok=True)
+image = res.images[0]
+image.save('./imagic/imagic_image_alpha_1.png')
+res = pipe(alpha=1.5, guidance_scale=7.5, num_inference_steps=50)
+image = res.images[0]
+image.save('./imagic/imagic_image_alpha_1_5.png')
+res = pipe(alpha=2, guidance_scale=7.5, num_inference_steps=50)
+image = res.images[0]
+image.save('./imagic/imagic_image_alpha_2.png')
+```
+
+### Seed Resizing
+Test seed resizing. First generate an image at 512 by 512, then generate an image at 512 by 592 with the same seed using seed resizing. Finally, generate a 512 by 592 image using the original stable diffusion pipeline for comparison.
+
+```python
+import torch as th
+import numpy as np
+from diffusers import DiffusionPipeline
+
+has_cuda = th.cuda.is_available()
+device = th.device('cpu' if not has_cuda else 'cuda')
+
+pipe = DiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ use_auth_token=True,
+ custom_pipeline="seed_resize_stable_diffusion"
+).to(device)
+
+def dummy(images, **kwargs):
+ return images, False
+
+pipe.safety_checker = dummy
+
+
+images = []
+th.manual_seed(0)
+generator = th.Generator("cuda").manual_seed(0)
+
+seed = 0
+prompt = "A painting of a futuristic cop"
+
+width = 512
+height = 512
+
+res = pipe(
+ prompt,
+ guidance_scale=7.5,
+ num_inference_steps=50,
+ height=height,
+ width=width,
+ generator=generator)
+image = res.images[0]
+image.save('./seed_resize/seed_resize_{w}_{h}_image.png'.format(w=width, h=height))
+
+
+th.manual_seed(0)
+generator = th.Generator("cuda").manual_seed(0)
+
+pipe = DiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ use_auth_token=True,
+ custom_pipeline="/home/mark/open_source/diffusers/examples/community/"
+).to(device)
+
+width = 512
+height = 592
+
+res = pipe(
+ prompt,
+ guidance_scale=7.5,
+ num_inference_steps=50,
+ height=height,
+ width=width,
+ generator=generator)
+image = res.images[0]
+image.save('./seed_resize/seed_resize_{w}_{h}_image.png'.format(w=width, h=height))
+
+# Original stable diffusion pipeline for comparison
+pipe_compare = DiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ use_auth_token=True,
+).to(device)
+
+res = pipe_compare(
+ prompt,
+ guidance_scale=7.5,
+ num_inference_steps=50,
+ height=height,
+ width=width,
+ generator=generator
+)
+
+image = res.images[0]
+image.save('./seed_resize/seed_resize_{w}_{h}_image_compare.png'.format(w=width, h=height))
+```
+
+### Multilingual Stable Diffusion Pipeline
+
+The following code can generate images from texts in different languages using the pre-trained [mBART-50 many-to-one multilingual machine translation model](https://huggingface.co/facebook/mbart-large-50-many-to-one-mmt) and Stable Diffusion.
+
+```python
+from PIL import Image
+
+import torch
+
+from diffusers import DiffusionPipeline
+from transformers import (
+ pipeline,
+ MBart50TokenizerFast,
+ MBartForConditionalGeneration,
+)
+device = "cuda" if torch.cuda.is_available() else "cpu"
+device_dict = {"cuda": 0, "cpu": -1}
+
+# helper function taken from: https://huggingface.co/blog/stable_diffusion
+def image_grid(imgs, rows, cols):
+    assert len(imgs) == rows * cols
+
+    w, h = imgs[0].size
+    grid = Image.new('RGB', size=(cols * w, rows * h))
+    grid_w, grid_h = grid.size
+
+    for i, img in enumerate(imgs):
+        grid.paste(img, box=(i % cols * w, i // cols * h))
+    return grid
+
+# Add language detection pipeline
+language_detection_model_ckpt = "papluca/xlm-roberta-base-language-detection"
+language_detection_pipeline = pipeline("text-classification",
+ model=language_detection_model_ckpt,
+ device=device_dict[device])
+
+# Add model for language translation
+trans_tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-one-mmt")
+trans_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-one-mmt").to(device)
+
+diffuser_pipeline = DiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ custom_pipeline="multilingual_stable_diffusion",
+ detection_pipeline=language_detection_pipeline,
+ translation_model=trans_model,
+ translation_tokenizer=trans_tokenizer,
+
+ torch_dtype=torch.float16,
+)
+
+diffuser_pipeline.enable_attention_slicing()
+diffuser_pipeline = diffuser_pipeline.to(device)
+
+prompt = ["a photograph of an astronaut riding a horse",
+ "Una casa en la playa",
+ "Ein Hund, der Orange isst",
+ "Un restaurant parisien"]
+
+output = diffuser_pipeline(prompt)
+
+images = output.images
+
+grid = image_grid(images, rows=2, cols=2)
+```
+
+This example produces the following images:
+![image](https://user-images.githubusercontent.com/4313860/198328706-295824a4-9856-4ce5-8e66-278ceb42fd29.png)
+
+### Image to Image Inpainting Stable Diffusion
+
+Similar to the standard stable diffusion inpainting example, except with the addition of an `inner_image` argument.
+
+`image`, `inner_image`, and `mask` should have the same dimensions. `inner_image` should have an alpha (transparency) channel.
+
+The aim is to overlay two images, then mask out the boundary between `image` and `inner_image` to allow stable diffusion to make the connection more seamless.
+For example, this could be used to place a logo on a shirt and make it blend seamlessly.
+
+```python
+import PIL
+import torch
+
+from diffusers import DiffusionPipeline
+
+image_path = "./path-to-image.png"
+inner_image_path = "./path-to-inner-image.png"
+mask_path = "./path-to-mask.png"
+
+init_image = PIL.Image.open(image_path).convert("RGB").resize((512, 512))
+inner_image = PIL.Image.open(inner_image_path).convert("RGBA").resize((512, 512))
+mask_image = PIL.Image.open(mask_path).convert("RGB").resize((512, 512))
+
+pipe = DiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-inpainting",
+ custom_pipeline="img2img_inpainting",
+
+ torch_dtype=torch.float16
+)
+pipe = pipe.to("cuda")
+
+prompt = "Your prompt here!"
+image = pipe(prompt=prompt, image=init_image, inner_image=inner_image, mask_image=mask_image).images[0]
+```
+
+![2 by 2 grid demonstrating image to image inpainting.](https://user-images.githubusercontent.com/44398246/203506577-ec303be4-887e-4ebd-a773-c83fcb3dd01a.png)
+
+### Text Based Inpainting Stable Diffusion
+
+Use a text prompt to generate the mask for the area to be inpainted.
+It currently uses the CLIPSeg model for mask generation and then calls the standard Stable Diffusion inpainting pipeline to perform the inpainting.
+
+```python
+from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
+from diffusers import DiffusionPipeline
+
+from PIL import Image
+import requests
+
+processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
+model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
+
+pipe = DiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-inpainting",
+ custom_pipeline="text_inpainting",
+ segmentation_model=model,
+ segmentation_processor=processor
+)
+pipe = pipe.to("cuda")
+
+
+url = "https://github.com/timojl/clipseg/blob/master/example_image.jpg?raw=true"
+image = Image.open(requests.get(url, stream=True).raw).resize((512, 512))
+text = "a glass" # will mask out this text
+prompt = "a cup" # the masked out region will be replaced with this
+
+image = pipe(image=image, text=text, prompt=prompt).images[0]
+```
+
+### Bit Diffusion
+Based on https://arxiv.org/abs/2208.04202, this is used for diffusion on discrete data, e.g., discrete image data or DNA sequence data. An unconditional discrete image can be generated like this:
+
+```python
+from diffusers import DiffusionPipeline
+pipe = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32", custom_pipeline="bit_diffusion")
+image = pipe().images[0]
+
+```
+
+### Stable Diffusion with K Diffusion
+
+Make sure you have @crowsonkb's https://github.com/crowsonkb/k-diffusion installed:
+
+```bash
+pip install k-diffusion
+```
+
+You can use the community pipeline as follows:
+
+```python
+import torch
+from diffusers import DiffusionPipeline
+
+seed = 33
+
+pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="sd_text2img_k_diffusion")
+pipe = pipe.to("cuda")
+
+prompt = "an astronaut riding a horse on mars"
+pipe.set_scheduler("sample_heun")
+generator = torch.Generator(device="cuda").manual_seed(seed)
+image = pipe(prompt, generator=generator, num_inference_steps=20).images[0]
+
+image.save("./astronaut_heun_k_diffusion.png")
+```
+
+To make sure that K Diffusion and `diffusers` yield the same results:
+
+**Diffusers**:
+```python
+import torch
+from diffusers import DiffusionPipeline, EulerDiscreteScheduler
+
+seed = 33
+prompt = "an astronaut riding a horse on mars"
+
+pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
+pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
+pipe = pipe.to("cuda")
+
+generator = torch.Generator(device="cuda").manual_seed(seed)
+image = pipe(prompt, generator=generator, num_inference_steps=50).images[0]
+```
+
+![diffusers_euler](https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/k_diffusion/astronaut_euler.png)
+
+**K Diffusion**:
+```python
+import torch
+from diffusers import DiffusionPipeline, EulerDiscreteScheduler
+
+seed = 33
+prompt = "an astronaut riding a horse on mars"
+
+pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="sd_text2img_k_diffusion")
+pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
+pipe = pipe.to("cuda")
+
+pipe.set_scheduler("sample_euler")
+generator = torch.Generator(device="cuda").manual_seed(seed)
+image = pipe(prompt, generator=generator, num_inference_steps=50).images[0]
+```
+
+![k_diffusion_euler](https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/k_diffusion/astronaut_euler_k_diffusion.png)
+
+### Checkpoint Merger Pipeline
+Based on the AUTOMATIC1111/webui checkpoint merging. This is a custom pipeline that merges up to 3 pretrained model checkpoints as long as they are in the Hugging Face model_index.json format.
+
+The checkpoint merging is currently memory intensive as it modifies the weights of a DiffusionPipeline object in place. Expect at least 13GB of RAM usage on Kaggle GPU kernels; on Colab you might run out of the 12GB memory even while merging two checkpoints.
+
+Usage:
+```python
+from diffusers import DiffusionPipeline
+
+# Return a CheckpointMergerPipeline class that allows you to merge checkpoints.
+# The checkpoint passed here is ignored, but you still need to pass one of the checkpoints
+# you plan to merge for convenience.
+pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="checkpoint_merger")
+
+# There are multiple possible scenarios:
+# The pipeline with the merged checkpoints is returned in all the scenarios.
+
+# Compatible checkpoints, i.e. matching model_index.json files. The meta attributes in model_index.json (attrs prefixed with _) are ignored during comparison.
+merged_pipe = pipe.merge(["CompVis/stable-diffusion-v1-4", "CompVis/stable-diffusion-v1-2"], interp="sigmoid", alpha=0.4)
+
+# Incompatible checkpoints according to model_index.json, but merging might still be possible. Use force=True to ignore model_index.json compatibility.
+merged_pipe_1 = pipe.merge(["CompVis/stable-diffusion-v1-4", "hakurei/waifu-diffusion"], force=True, interp="sigmoid", alpha=0.4)
+
+# Three-checkpoint merging. Only the "add_difference" method actually uses all three checkpoints. Using any other option will ignore the 3rd checkpoint.
+merged_pipe_2 = pipe.merge(["CompVis/stable-diffusion-v1-4", "hakurei/waifu-diffusion", "prompthero/openjourney"], force=True, interp="add_difference", alpha=0.4)
+
+prompt = "An astronaut riding a horse on Mars"
+
+image = merged_pipe(prompt).images[0]
+
+```
+Some examples along with the merge details:
+
+1. "CompVis/stable-diffusion-v1-4" + "hakurei/waifu-diffusion" ; Sigmoid interpolation; alpha = 0.8
+
+![Stable plus Waifu Sigmoid 0.8](https://huggingface.co/datasets/NagaSaiAbhinay/CheckpointMergerSamples/resolve/main/stability_v1_4_waifu_sig_0.8.png)
+
+2. "hakurei/waifu-diffusion" + "prompthero/openjourney" ; Inverse Sigmoid interpolation; alpha = 0.8
+
+![Waifu plus OpenJourney Inverse Sigmoid 0.8](https://huggingface.co/datasets/NagaSaiAbhinay/CheckpointMergerSamples/resolve/main/waifu_openjourney_inv_sig_0.8.png)
+
+
+3. "CompVis/stable-diffusion-v1-4" + "hakurei/waifu-diffusion" + "prompthero/openjourney"; Add Difference interpolation; alpha = 0.5
+
+![Stable plus Waifu plus openjourney add_diff 0.5](https://huggingface.co/datasets/NagaSaiAbhinay/CheckpointMergerSamples/resolve/main/stable_waifu_openjourney_add_diff_0.5.png)
+
+
+### Stable Diffusion Comparisons
+
+This Community Pipeline enables the comparison between the 4 checkpoints that exist for Stable Diffusion. They can be found through the following links:
+1. [Stable Diffusion v1.1](https://huggingface.co/CompVis/stable-diffusion-v1-1)
+2. [Stable Diffusion v1.2](https://huggingface.co/CompVis/stable-diffusion-v1-2)
+3. [Stable Diffusion v1.3](https://huggingface.co/CompVis/stable-diffusion-v1-3)
+4. [Stable Diffusion v1.4](https://huggingface.co/CompVis/stable-diffusion-v1-4)
+
+```python
+from diffusers import DiffusionPipeline
+import matplotlib.pyplot as plt
+
+pipe = DiffusionPipeline.from_pretrained('CompVis/stable-diffusion-v1-4', custom_pipeline='suvadityamuk/StableDiffusionComparison')
+pipe.enable_attention_slicing()
+pipe = pipe.to('cuda')
+prompt = "an astronaut riding a horse on mars"
+output = pipe(prompt)
+
+plt.subplot(2, 2, 1)
+plt.imshow(output.images[0])
+plt.title('Stable Diffusion v1.1')
+plt.axis('off')
+plt.subplot(2, 2, 2)
+plt.imshow(output.images[1])
+plt.title('Stable Diffusion v1.2')
+plt.axis('off')
+plt.subplot(2, 2, 3)
+plt.imshow(output.images[2])
+plt.title('Stable Diffusion v1.3')
+plt.axis('off')
+plt.subplot(2, 2, 4)
+plt.imshow(output.images[3])
+plt.title('Stable Diffusion v1.4')
+plt.axis('off')
+
+plt.show()
+```
+
+As a result, you get a grid of all 4 generated images shown together, which captures how the training advanced between the 4 checkpoints.
+
+### Magic Mix
+
+Implementation of the [MagicMix: Semantic Mixing with Diffusion Models](https://arxiv.org/abs/2210.16056) paper. This is a Diffusion Pipeline for semantic mixing of an image and a text prompt to create a new concept while preserving the spatial layout and geometry of the subject in the image. The pipeline takes an image that provides the layout semantics and a prompt that provides the content semantics for the mixing process.
+
+There are 3 parameters for the method:
+- `mix_factor`: the interpolation constant used in the layout generation phase. The greater the value of `mix_factor`, the greater the influence of the prompt on the layout generation process.
+- `kmax` and `kmin`: these determine the range for the layout and content generation process. A higher value of `kmax` results in more information about the layout of the original image being lost, and a higher value of `kmin` results in more steps being spent on the content generation process.
+
+Here is an example usage:
+
+```python
+from diffusers import DiffusionPipeline, DDIMScheduler
+from PIL import Image
+
+pipe = DiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ custom_pipeline="magic_mix",
+ scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler"),
+).to('cuda')
+
+img = Image.open('phone.jpg')
+mix_img = pipe(
+ img,
+ prompt = 'bed',
+ kmin = 0.3,
+ kmax = 0.5,
+ mix_factor = 0.5,
+ )
+mix_img.save('phone_bed_mix.jpg')
+```
+`mix_img` is a PIL image that can be saved locally or displayed directly in a Google Colab. The generated image is a mix of the layout semantics of the given image and the content semantics of the prompt.
+
+E.g. the above script generates the following image:
+
+`phone.jpg`
+
+![206903102-34e79b9f-9ed2-4fac-bb38-82871343c655](https://user-images.githubusercontent.com/59410571/209578593-141467c7-d831-4792-8b9a-b17dc5e47816.jpg)
+
+`phone_bed_mix.jpg`
+
+![206903104-913a671d-ef53-4ae4-919d-64c3059c8f67](https://user-images.githubusercontent.com/59410571/209578602-70f323fa-05b7-4dd6-b055-e40683e37914.jpg)
+
+For more example generations check out this [demo notebook](https://github.com/daspartho/MagicMix/blob/main/demo.ipynb).
+
+
+### Stable UnCLIP
+
+UnCLIPPipeline("kakaobrain/karlo-v1-alpha") provide a prior model that can generate clip image embedding from text.
+StableDiffusionImageVariationPipeline("lambdalabs/sd-image-variations-diffusers") provide a decoder model than can generate images from clip image embedding.
+
+```python
+import torch
+from diffusers import DiffusionPipeline
+
+device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "kakaobrain/karlo-v1-alpha",
+ torch_dtype=torch.float16,
+ custom_pipeline="stable_unclip",
+ decoder_pipe_kwargs=dict(
+ image_encoder=None,
+ ),
+)
+pipeline.to(device)
+
+prompt = "a shiba inu wearing a beret and black turtleneck"
+random_generator = torch.Generator(device=device).manual_seed(1000)
+output = pipeline(
+ prompt=prompt,
+ width=512,
+ height=512,
+ generator=random_generator,
+ prior_guidance_scale=4,
+ prior_num_inference_steps=25,
+ decoder_guidance_scale=8,
+ decoder_num_inference_steps=50,
+)
+
+image = output.images[0]
+image.save("./shiba-inu.jpg")
+
+# debug
+
+# `pipeline.decoder_pipe` is a regular StableDiffusionImageVariationPipeline instance.
+# It is used to convert the CLIP image embedding to latents, which are then fed into the VAE decoder.
+print(pipeline.decoder_pipe.__class__)
+#
+
+# This pipeline only uses the prior module in "kakaobrain/karlo-v1-alpha".
+# It is used to convert the CLIP text embedding to a CLIP image embedding.
+print(pipeline)
+# StableUnCLIPPipeline {
+# "_class_name": "StableUnCLIPPipeline",
+# "_diffusers_version": "0.12.0.dev0",
+# "prior": [
+# "diffusers",
+# "PriorTransformer"
+# ],
+# "prior_scheduler": [
+# "diffusers",
+# "UnCLIPScheduler"
+# ],
+# "text_encoder": [
+# "transformers",
+# "CLIPTextModelWithProjection"
+# ],
+# "tokenizer": [
+# "transformers",
+# "CLIPTokenizer"
+# ]
+# }
+
+# pipeline.prior_scheduler is the scheduler used for prior in UnCLIP.
+print(pipeline.prior_scheduler)
+# UnCLIPScheduler {
+# "_class_name": "UnCLIPScheduler",
+# "_diffusers_version": "0.12.0.dev0",
+# "clip_sample": true,
+# "clip_sample_range": 5.0,
+# "num_train_timesteps": 1000,
+# "prediction_type": "sample",
+# "variance_type": "fixed_small_log"
+# }
+```
+
+
+`shiba-inu.jpg`
+
+
+![shiba-inu](https://user-images.githubusercontent.com/16448529/209185639-6e5ec794-ce9d-4883-aa29-bd6852a2abad.jpg)
+
+### UnCLIP Text Interpolation Pipeline
+
+This Diffusion Pipeline takes two prompts and interpolates between them using spherical interpolation (slerp). The input prompts are converted to text embeddings by the pipeline's text_encoder, and the interpolation is done on the resulting text embeddings over the specified number of steps (5 by default).
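+
+For illustration only, here is a minimal sketch of spherical interpolation between two embedding vectors; it is not the pipeline's own implementation, just the kind of interpolation it performs on the text embeddings:
+
+```python
+import torch
+
+
+def slerp(v0: torch.Tensor, v1: torch.Tensor, t: float, eps: float = 1e-7) -> torch.Tensor:
+    """Spherical interpolation between two flattened embedding vectors."""
+    v0_unit = v0 / v0.norm()
+    v1_unit = v1 / v1.norm()
+    # Angle between the two vectors, clamped for numerical stability
+    theta = torch.acos((v0_unit * v1_unit).sum().clamp(-1 + eps, 1 - eps))
+    return (torch.sin((1 - t) * theta) * v0 + torch.sin(t * theta) * v1) / torch.sin(theta)
+
+
+# Interpolate between two dummy 768-dim embeddings in 6 steps, as in the example below
+start, end = torch.randn(768), torch.randn(768)
+interpolated = [slerp(start, end, t) for t in torch.linspace(0, 1, 6)]
+```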
+
+```python
+import torch
+from diffusers import DiffusionPipeline
+
+device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
+
+pipe = DiffusionPipeline.from_pretrained(
+ "kakaobrain/karlo-v1-alpha",
+ torch_dtype=torch.float16,
+ custom_pipeline="unclip_text_interpolation"
+)
+pipe.to(device)
+
+start_prompt = "A photograph of an adult lion"
+end_prompt = "A photograph of a lion cub"
+#For best results keep the prompts close in length to each other. Of course, feel free to try out with differing lengths.
+generator = torch.Generator(device=device).manual_seed(42)
+
+output = pipe(start_prompt, end_prompt, steps = 6, generator = generator, enable_sequential_cpu_offload=False)
+
+for i, image in enumerate(output.images):
+ image.save('result%s.jpg' % i)
+```
+
+The resulting images in order:
+
+![result_0](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPTextInterpolationSamples/resolve/main/lion_to_cub_0.png)
+![result_1](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPTextInterpolationSamples/resolve/main/lion_to_cub_1.png)
+![result_2](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPTextInterpolationSamples/resolve/main/lion_to_cub_2.png)
+![result_3](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPTextInterpolationSamples/resolve/main/lion_to_cub_3.png)
+![result_4](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPTextInterpolationSamples/resolve/main/lion_to_cub_4.png)
+![result_5](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPTextInterpolationSamples/resolve/main/lion_to_cub_5.png)
+
+### UnCLIP Image Interpolation Pipeline
+
+This Diffusion Pipeline takes two images or an `image_embeddings` tensor of size 2 and interpolates between their embeddings using spherical interpolation (slerp). The input images/image embeddings are converted to image embeddings by the pipeline's image_encoder, and the interpolation is done on the resulting image embeddings over the specified number of steps (5 by default).
+
+```python
+import torch
+from diffusers import DiffusionPipeline
+from PIL import Image
+
+device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
+dtype = torch.float16 if torch.cuda.is_available() else torch.bfloat16
+
+pipe = DiffusionPipeline.from_pretrained(
+ "kakaobrain/karlo-v1-alpha-image-variations",
+ torch_dtype=dtype,
+ custom_pipeline="unclip_image_interpolation"
+)
+pipe.to(device)
+
+images = [Image.open('./starry_night.jpg'), Image.open('./flowers.jpg')]
+generator = torch.Generator(device=device).manual_seed(42)
+
+output = pipe(image=images, steps=6, generator=generator)
+
+for i,image in enumerate(output.images):
+ image.save('starry_to_flowers_%s.jpg' % i)
+```
+The original images:
+
+![starry](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPImageInterpolationSamples/resolve/main/starry_night.jpg)
+![flowers](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPImageInterpolationSamples/resolve/main/flowers.jpg)
+
+The resulting images in order:
+
+![result0](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPImageInterpolationSamples/resolve/main/starry_to_flowers_0.png)
+![result1](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPImageInterpolationSamples/resolve/main/starry_to_flowers_1.png)
+![result2](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPImageInterpolationSamples/resolve/main/starry_to_flowers_2.png)
+![result3](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPImageInterpolationSamples/resolve/main/starry_to_flowers_3.png)
+![result4](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPImageInterpolationSamples/resolve/main/starry_to_flowers_4.png)
+![result5](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPImageInterpolationSamples/resolve/main/starry_to_flowers_5.png)
+
+### DDIM Noise Comparative Analysis Pipeline
+#### **Research question: What visual concepts do the diffusion models learn from each noise level during training?**
+The [P2 weighting (CVPR 2022)](https://arxiv.org/abs/2204.00227) paper proposed an approach to answer the above question, which is their second contribution.
+The approach consists of the following steps:
+
+1. The input is an image x0.
+2. Perturb it to xt using a diffusion process q(xt|x0).
+ - `strength` is a value between 0.0 and 1.0 that controls the amount of noise added to the input image. Values approaching 1.0 allow for lots of variation but will also produce images that are not semantically consistent with the input.
+3. Reconstruct the image with the learned denoising process pθ(x̂0|xt).
+4. Compare x0 and x̂0 among various t to show how each step contributes to the sample.
+The authors used the [openai/guided-diffusion](https://github.com/openai/guided-diffusion) model to denoise images from the FFHQ dataset. This pipeline extends their second contribution by investigating DDIM on any input image.
+
+```python
+import torch
+from PIL import Image
+import numpy as np
+from diffusers import DiffusionPipeline
+
+image_path = "path/to/your/image" # images from CelebA-HQ might be better
+image_pil = Image.open(image_path)
+image_name = image_path.split("/")[-1].split(".")[0]
+
+device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
+pipe = DiffusionPipeline.from_pretrained(
+ "google/ddpm-ema-celebahq-256",
+ custom_pipeline="ddim_noise_comparative_analysis",
+)
+pipe = pipe.to(device)
+
+for strength in np.linspace(0.1, 1, 25):
+ denoised_image, latent_timestep = pipe(
+ image_pil, strength=strength, return_dict=False
+ )
+ denoised_image = denoised_image[0]
+ denoised_image.save(
+ f"noise_comparative_analysis_{image_name}_{latent_timestep}.png"
+ )
+```
+
+Here is the result of this pipeline (which is DDIM) on the CelebA-HQ dataset.
+
+![noise-comparative-analysis](https://user-images.githubusercontent.com/67547213/224677066-4474b2ed-56ab-4c27-87c6-de3c0255eb9c.jpeg)
+
+### CLIP Guided Img2Img Stable Diffusion
+
+CLIP guided Img2Img stable diffusion can help to generate more realistic images with an initial image
+by guiding stable diffusion at every denoising step with an additional CLIP model.
+
+The following code requires roughly 12GB of GPU RAM.
+
+```python
+from io import BytesIO
+import requests
+import torch
+from diffusers import DiffusionPipeline
+from PIL import Image
+from transformers import CLIPFeatureExtractor, CLIPModel
+from IPython.display import display
+feature_extractor = CLIPFeatureExtractor.from_pretrained(
+ "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
+)
+clip_model = CLIPModel.from_pretrained(
+ "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.float16
+)
+guided_pipeline = DiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ custom_pipeline="clip_guided_stable_diffusion_img2img",
+ clip_model=clip_model,
+ feature_extractor=feature_extractor,
+ torch_dtype=torch.float16,
+)
+guided_pipeline.enable_attention_slicing()
+guided_pipeline = guided_pipeline.to("cuda")
+prompt = "fantasy book cover, full moon, fantasy forest landscape, golden vector elements, fantasy magic, dark light night, intricate, elegant, sharp focus, illustration, highly detailed, digital painting, concept art, matte, art by WLOP and Artgerm and Albert Bierstadt, masterpiece"
+url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+response = requests.get(url)
+init_image = Image.open(BytesIO(response.content)).convert("RGB")
+image = guided_pipeline(
+ prompt=prompt,
+ num_inference_steps=30,
+ image=init_image,
+ strength=0.75,
+ guidance_scale=7.5,
+ clip_guidance_scale=100,
+ num_cutouts=4,
+ use_cutouts=False,
+).images[0]
+display(image)
+```
+
+Init Image
+
+![img2img_init_clip_guidance](https://huggingface.co/datasets/njindal/images/resolve/main/clip_guided_img2img_init.jpg)
+
+Output Image
+
+![img2img_clip_guidance](https://huggingface.co/datasets/njindal/images/resolve/main/clip_guided_img2img.jpg)
+
+### TensorRT Text2Image Stable Diffusion Pipeline
+
+The TensorRT Pipeline can be used to accelerate the Text2Image Stable Diffusion Inference run.
+
+NOTE: The ONNX conversions and TensorRT engine build may take up to 30 minutes.
+
+```python
+import torch
+from diffusers import DDIMScheduler
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
+
+# Use the DDIMScheduler scheduler here instead
+scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1",
+ subfolder="scheduler")
+
+pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1",
+ custom_pipeline="stable_diffusion_tensorrt_txt2img",
+ revision='fp16',
+ torch_dtype=torch.float16,
+ scheduler=scheduler,)
+
+# re-use cached folder to save ONNX models and TensorRT Engines
+pipe.set_cached_folder("stabilityai/stable-diffusion-2-1", revision='fp16',)
+
+pipe = pipe.to("cuda")
+
+prompt = "a beautiful photograph of Mt. Fuji during cherry blossom"
+image = pipe(prompt).images[0]
+image.save('tensorrt_mt_fuji.png')
+```
+
+### EDICT Image Editing Pipeline
+
+This pipeline implements the text-guided image editing approach from the paper [EDICT: Exact Diffusion Inversion via Coupled Transformations](https://arxiv.org/abs/2211.12446). You have to pass:
+- (`PIL`) `image`: the image you want to edit.
+- `base_prompt`: the text prompt describing the current image (before editing).
+- `target_prompt`: the text prompt describing the desired image (after the edits).
+
+```python
+from diffusers import DiffusionPipeline, DDIMScheduler
+from transformers import CLIPTextModel
+import torch, PIL, requests
+from io import BytesIO
+from IPython.display import display
+
+def center_crop_and_resize(im):
+
+ width, height = im.size
+ d = min(width, height)
+ left = (width - d) / 2
+ upper = (height - d) / 2
+ right = (width + d) / 2
+ lower = (height + d) / 2
+
+ return im.crop((left, upper, right, lower)).resize((512, 512))
+
+torch_dtype = torch.float16
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+# scheduler and text_encoder param values as in the paper
+scheduler = DDIMScheduler(
+ num_train_timesteps=1000,
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="scaled_linear",
+ set_alpha_to_one=False,
+ clip_sample=False,
+)
+
+text_encoder = CLIPTextModel.from_pretrained(
+ pretrained_model_name_or_path="openai/clip-vit-large-patch14",
+ torch_dtype=torch_dtype,
+)
+
+# initialize pipeline
+pipeline = DiffusionPipeline.from_pretrained(
+ pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4",
+ custom_pipeline="edict_pipeline",
+ revision="fp16",
+ scheduler=scheduler,
+ text_encoder=text_encoder,
+ leapfrog_steps=True,
+ torch_dtype=torch_dtype,
+).to(device)
+
+# download image
+image_url = "https://huggingface.co/datasets/Joqsan/images/resolve/main/imagenet_dog_1.jpeg"
+response = requests.get(image_url)
+image = PIL.Image.open(BytesIO(response.content))
+
+# preprocess it
+cropped_image = center_crop_and_resize(image)
+
+# define the prompts
+base_prompt = "A dog"
+target_prompt = "A golden retriever"
+
+# run the pipeline
+result_image = pipeline(
+ base_prompt=base_prompt,
+ target_prompt=target_prompt,
+ image=cropped_image,
+)
+
+display(result_image)
+```
+
+Init Image
+
+![img2img_init_edict_text_editing](https://huggingface.co/datasets/Joqsan/images/resolve/main/imagenet_dog_1.jpeg)
+
+Output Image
+
+![img2img_edict_text_editing](https://huggingface.co/datasets/Joqsan/images/resolve/main/imagenet_dog_1_cropped_generated.png)
+
+### Stable Diffusion RePaint
+
+This pipeline uses the [RePaint](https://arxiv.org/abs/2201.09865) logic on the latent space of stable diffusion. It can
+be used similarly to other image inpainting pipelines but does not rely on a specific inpainting model. This means you can use
+models that are not specifically created for inpainting.
+
+Make sure to use the `RePaintScheduler` as shown in the example below.
+
+Disclaimer: The mask gets transferred into latent space, which may lead to unexpected changes near the edge of the masked region.
+The inference time is also a lot slower.
+
+```py
+import PIL
+import requests
+import torch
+from io import BytesIO
+from diffusers import StableDiffusionPipeline, RePaintScheduler
+def download_image(url):
+ response = requests.get(url)
+ return PIL.Image.open(BytesIO(response.content)).convert("RGB")
+img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+init_image = download_image(img_url).resize((512, 512))
+mask_image = download_image(mask_url).resize((512, 512))
+mask_image = PIL.ImageOps.invert(mask_image)
+pipe = StableDiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16, custom_pipeline="stable_diffusion_repaint",
+)
+pipe.scheduler = RePaintScheduler.from_config(pipe.scheduler.config)
+pipe = pipe.to("cuda")
+prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
+image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
+```
+
+### TensorRT Image2Image Stable Diffusion Pipeline
+
+The TensorRT Pipeline can be used to accelerate the Image2Image Stable Diffusion Inference run.
+
+NOTE: The ONNX conversions and TensorRT engine build may take up to 30 minutes.
+
+```python
+import requests
+from io import BytesIO
+from PIL import Image
+import torch
+from diffusers import DDIMScheduler
+from diffusers.pipelines.stable_diffusion import StableDiffusionImg2ImgPipeline
+
+# Use the DDIMScheduler scheduler here instead
+scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1",
+ subfolder="scheduler")
+
+
+pipe = StableDiffusionImg2ImgPipeline.from_pretrained("stabilityai/stable-diffusion-2-1",
+ custom_pipeline="stable_diffusion_tensorrt_img2img",
+ revision='fp16',
+ torch_dtype=torch.float16,
+ scheduler=scheduler,)
+
+# re-use cached folder to save ONNX models and TensorRT Engines
+pipe.set_cached_folder("stabilityai/stable-diffusion-2-1", revision='fp16',)
+
+pipe = pipe.to("cuda")
+
+url = "https://pajoca.com/wp-content/uploads/2022/09/tekito-yamakawa-1.png"
+response = requests.get(url)
+input_image = Image.open(BytesIO(response.content)).convert("RGB")
+
+prompt = "photorealistic new zealand hills"
+image = pipe(prompt, image=input_image, strength=0.75,).images[0]
+image.save('tensorrt_img2img_new_zealand_hills.png')
+```
+
+### Stable Diffusion Reference
+
+This pipeline uses Reference Control. Refer to the [sd-webui-controlnet discussion: Reference-only Control](https://github.com/Mikubill/sd-webui-controlnet/discussions/1236) and [sd-webui-controlnet discussion: Reference-adain Control](https://github.com/Mikubill/sd-webui-controlnet/discussions/1280).
+
+Based on [this issue](https://github.com/huggingface/diffusers/issues/3566):
+- `EulerAncestralDiscreteScheduler` gives poor results.
+
+```py
+import torch
+from diffusers import UniPCMultistepScheduler
+from diffusers.utils import load_image
+
+input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")
+
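+# StableDiffusionReferencePipeline lives in the community examples (examples/community/stable_diffusion_reference.py);
+# it can also be loaded with DiffusionPipeline.from_pretrained(..., custom_pipeline="stable_diffusion_reference").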
+pipe = StableDiffusionReferencePipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+ safety_checker=None,
+ torch_dtype=torch.float16
+ ).to('cuda:0')
+
+pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+
+result_img = pipe(ref_image=input_image,
+ prompt="1girl",
+ num_inference_steps=20,
+ reference_attn=True,
+ reference_adain=True).images[0]
+```
+
+Reference Image
+
+![reference_image](https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png)
+
+Output Image of `reference_attn=True` and `reference_adain=False`
+
+![output_image](https://github.com/huggingface/diffusers/assets/24734142/813b5c6a-6d89-46ba-b7a4-2624e240eea5)
+
+Output Image of `reference_attn=False` and `reference_adain=True`
+
+![output_image](https://github.com/huggingface/diffusers/assets/24734142/ffc90339-9ef0-4c4d-a544-135c3e5644da)
+
+Output Image of `reference_attn=True` and `reference_adain=True`
+
+![output_image](https://github.com/huggingface/diffusers/assets/24734142/3c5255d6-867d-4d35-b202-8dfd30cc6827)
+
+### Stable Diffusion ControlNet Reference
+
+This pipeline uses Reference Control with ControlNet. Refer to the [sd-webui-controlnet discussion: Reference-only Control](https://github.com/Mikubill/sd-webui-controlnet/discussions/1236) and [sd-webui-controlnet discussion: Reference-adain Control](https://github.com/Mikubill/sd-webui-controlnet/discussions/1280).
+
+Based on [this issue](https://github.com/huggingface/diffusers/issues/3566):
+- `EulerAncestralDiscreteScheduler` gives poor results.
+- `guess_mode=True` works well for ControlNet v1.1.
+
+```py
+import cv2
+import torch
+import numpy as np
+from PIL import Image
+from diffusers import ControlNetModel, UniPCMultistepScheduler
+from diffusers.utils import load_image
+
+input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")
+
+# get canny image
+image = cv2.Canny(np.array(input_image), 100, 200)
+image = image[:, :, None]
+image = np.concatenate([image, image, image], axis=2)
+canny_image = Image.fromarray(image)
+
+controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
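+# StableDiffusionControlNetReferencePipeline lives in the community examples
+# (examples/community/stable_diffusion_controlnet_reference.py).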
+pipe = StableDiffusionControlNetReferencePipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+ controlnet=controlnet,
+ safety_checker=None,
+ torch_dtype=torch.float16
+ ).to('cuda:0')
+
+pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+
+result_img = pipe(ref_image=input_image,
+ prompt="1girl",
+ image=canny_image,
+ num_inference_steps=20,
+ reference_attn=True,
+ reference_adain=True).images[0]
+```
+
+Reference Image
+
+![reference_image](https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png)
+
+Output Image
+
+![output_image](https://github.com/huggingface/diffusers/assets/24734142/7b9a5830-f173-4b92-b0cf-73d0e9c01d60)
+
+
+### Stable Diffusion on IPEX
+
+This diffusion pipeline aims to accelerate the inference of Stable Diffusion on Intel Xeon CPUs with BF16/FP32 precision using [IPEX](https://github.com/intel/intel-extension-for-pytorch).
+
+To use this pipeline, you need to:
+1. Install [IPEX](https://github.com/intel/intel-extension-for-pytorch)
+
+**Note:** For each PyTorch release, there is a corresponding IPEX release. Here is the mapping relationship. It is recommended to install PyTorch/IPEX 2.0 to get the best performance.
+
+|PyTorch Version|IPEX Version|
+|--|--|
+|[v2.0.\*](https://github.com/pytorch/pytorch/tree/v2.0.1 "v2.0.1")|[v2.0.\*](https://github.com/intel/intel-extension-for-pytorch/tree/v2.0.100+cpu)|
+|[v1.13.\*](https://github.com/pytorch/pytorch/tree/v1.13.0 "v1.13.0")|[v1.13.\*](https://github.com/intel/intel-extension-for-pytorch/tree/v1.13.100+cpu)|
+
+You can simply use pip to install the latest version of IPEX.
+```bash
+python -m pip install intel_extension_for_pytorch
+```
+**Note:** To install a specific version, run the following command:
+```bash
+python -m pip install intel_extension_for_pytorch==<version_name> -f https://developer.intel.com/ipex-whl-stable-cpu
+```
+
+2. After pipeline initialization, `prepare_for_ipex()` should be called to enable IPEX acceleration. Supported inference datatypes are Float32 and BFloat16.
+
+**Note:** The generated image height/width set for `prepare_for_ipex()` should be the same as the height/width used at pipeline inference time.
+```python
+pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", custom_pipeline="stable_diffusion_ipex")
+# For Float32
+pipe.prepare_for_ipex(prompt, dtype=torch.float32, height=512, width=512) #value of image height/width should be consistent with the pipeline inference
+# For BFloat16
+pipe.prepare_for_ipex(prompt, dtype=torch.bfloat16, height=512, width=512) #value of image height/width should be consistent with the pipeline inference
+```
+
+Then you can use the ipex pipeline in a similar way to the default stable diffusion pipeline.
+```python
+# For Float32
+image = pipe(prompt, num_inference_steps=20, height=512, width=512).images[0] #value of image height/width should be consistent with 'prepare_for_ipex()'
+# For BFloat16
+with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
+ image = pipe(prompt, num_inference_steps=20, height=512, width=512).images[0] #value of image height/width should be consistent with 'prepare_for_ipex()'
+```
+
+The following code compares the performance of the original stable diffusion pipeline with the ipex-optimized pipeline.
+
+```python
+import torch
+import intel_extension_for_pytorch as ipex
+from diffusers import DiffusionPipeline, StableDiffusionPipeline
+import time
+
+prompt = "sailing ship in storm by Rembrandt"
+model_id = "runwayml/stable-diffusion-v1-5"
+# Helper function for time evaluation
+def elapsed_time(pipeline, nb_pass=3, num_inference_steps=20):
+    # warmup
+    for _ in range(2):
+        images = pipeline(prompt, num_inference_steps=num_inference_steps, height=512, width=512).images
+    # time evaluation
+    start = time.time()
+    for _ in range(nb_pass):
+        pipeline(prompt, num_inference_steps=num_inference_steps, height=512, width=512)
+    end = time.time()
+    return (end - start) / nb_pass
+
+############## bf16 inference performance ###############
+
+# 1. IPEX Pipeline initialization
+pipe = DiffusionPipeline.from_pretrained(model_id, custom_pipeline="stable_diffusion_ipex")
+pipe.prepare_for_ipex(prompt, dtype=torch.bfloat16, height=512, width=512)
+
+# 2. Original Pipeline initialization
+pipe2 = StableDiffusionPipeline.from_pretrained(model_id)
+
+# 3. Compare performance between Original Pipeline and IPEX Pipeline
+with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
+ latency = elapsed_time(pipe)
+ print("Latency of StableDiffusionIPEXPipeline--bf16", latency)
+ latency = elapsed_time(pipe2)
+ print("Latency of StableDiffusionPipeline--bf16",latency)
+
+############## fp32 inference performance ###############
+
+# 1. IPEX Pipeline initialization
+pipe3 = DiffusionPipeline.from_pretrained(model_id, custom_pipeline="stable_diffusion_ipex")
+pipe3.prepare_for_ipex(prompt, dtype=torch.float32, height=512, width=512)
+
+# 2. Original Pipeline initialization
+pipe4 = StableDiffusionPipeline.from_pretrained(model_id)
+
+# 3. Compare performance between Original Pipeline and IPEX Pipeline
+latency = elapsed_time(pipe3)
+print("Latency of StableDiffusionIPEXPipeline--fp32", latency)
+latency = elapsed_time(pipe4)
+print("Latency of StableDiffusionPipeline--fp32",latency)
+
+```
+
+### CLIP Guided Images Mixing With Stable Diffusion
+
+![clip_guided_images_mixing_examples](https://huggingface.co/datasets/TheDenk/images_mixing/resolve/main/main.png)
+
+The CLIP guided stable diffusion images mixing pipeline allows you to combine two images using standard diffusion models.
+This approach uses an (optional) CoCa model to avoid having to write an image description. A full example is given in the "Example Images Mixing (with CoCa)" section below.
+[More code examples](https://github.com/TheDenk/images_mixing)
+
+
+### Stable Diffusion XL Long Weighted Prompt Pipeline
+
+This SDXL pipeline supports unlimited-length prompts and negative prompts, compatible with the A1111 prompt weighting style.
+
+You can provide both `prompt` and `prompt_2`. If only one prompt is provided, `prompt_2` will be a copy of the provided `prompt`. Here is sample code using this pipeline.
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+pipe = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0"
+ , torch_dtype = torch.float16
+ , use_safetensors = True
+ , variant = "fp16"
+ , custom_pipeline = "lpw_stable_diffusion_xl",
+)
+
+prompt = "photo of a cute (white) cat running on the grass"*20
+prompt2 = "chasing (birds:1.5)"*20
+prompt = f"{prompt},{prompt2}"
+neg_prompt = "blur, low quality, carton, animate"
+
+pipe.to("cuda")
+images = pipe(
+ prompt = prompt
+ , negative_prompt = neg_prompt
+).images[0]
+
+pipe.to("cpu")
+torch.cuda.empty_cache()
+images
+```
+
+In the above code, `prompt2` is appended to `prompt`, so the combined prompt is longer than 77 tokens. The "birds" still show up in the result.
+![Stable Diffusion XL Long Weighted Prompt Pipeline sample](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_long_weighted_prompt.png)
+
+## Example Images Mixing (with CoCa)
+```python
+import requests
+from io import BytesIO
+
+import PIL
+import torch
+import open_clip
+from open_clip import SimpleTokenizer
+from diffusers import DiffusionPipeline
+from transformers import CLIPFeatureExtractor, CLIPModel
+
+
+def download_image(url):
+ response = requests.get(url)
+ return PIL.Image.open(BytesIO(response.content)).convert("RGB")
+
+# Loading additional models
+feature_extractor = CLIPFeatureExtractor.from_pretrained(
+ "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
+)
+clip_model = CLIPModel.from_pretrained(
+ "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.float16
+)
+coca_model = open_clip.create_model('coca_ViT-L-14', pretrained='laion2B-s13B-b90k').to('cuda')
+coca_model.dtype = torch.float16
+coca_transform = open_clip.image_transform(
+ coca_model.visual.image_size,
+ is_train = False,
+ mean = getattr(coca_model.visual, 'image_mean', None),
+ std = getattr(coca_model.visual, 'image_std', None),
+)
+coca_tokenizer = SimpleTokenizer()
+
+# Pipeline creating
+mixing_pipeline = DiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ custom_pipeline="clip_guided_images_mixing_stable_diffusion",
+ clip_model=clip_model,
+ feature_extractor=feature_extractor,
+ coca_model=coca_model,
+ coca_tokenizer=coca_tokenizer,
+ coca_transform=coca_transform,
+ torch_dtype=torch.float16,
+)
+mixing_pipeline.enable_attention_slicing()
+mixing_pipeline = mixing_pipeline.to("cuda")
+
+# Pipeline running
+generator = torch.Generator(device="cuda").manual_seed(17)
+
+content_image = download_image("https://huggingface.co/datasets/TheDenk/images_mixing/resolve/main/boromir.jpg")
+style_image = download_image("https://huggingface.co/datasets/TheDenk/images_mixing/resolve/main/gigachad.jpg")
+
+pipe_images = mixing_pipeline(
+ num_inference_steps=50,
+ content_image=content_image,
+ style_image=style_image,
+ noise_strength=0.65,
+ slerp_latent_style_strength=0.9,
+ slerp_prompt_style_strength=0.1,
+ slerp_clip_image_style_strength=0.1,
+ guidance_scale=9.0,
+ batch_size=1,
+ clip_guidance_scale=100,
+ generator=generator,
+).images
+```
+
+![image_mixing_result](https://huggingface.co/datasets/TheDenk/images_mixing/resolve/main/boromir_gigachad.png)
+
+### Stable Diffusion Mixture Tiling
+
+This pipeline uses the Mixture of Diffusers approach to tile several prompts into a single image. Refer to the [Mixture of Diffusers](https://arxiv.org/abs/2302.02412) paper for more details.
+
+```python
+from diffusers import LMSDiscreteScheduler, DiffusionPipeline
+
+# Create scheduler and model (similar to StableDiffusionPipeline)
+scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
+pipeline = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=scheduler, custom_pipeline="mixture_tiling")
+pipeline.to("cuda")
+
+# Mixture of Diffusers generation
+image = pipeline(
+ prompt=[[
+ "A charming house in the countryside, by jakub rozalski, sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece",
+ "A dirt road in the countryside crossing pastures, by jakub rozalski, sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece",
+ "An old and rusty giant robot lying on a dirt road, by jakub rozalski, dark sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece"
+ ]],
+ tile_height=640,
+ tile_width=640,
+ tile_row_overlap=0,
+ tile_col_overlap=256,
+ guidance_scale=8,
+ seed=7178915308,
+ num_inference_steps=50,
+)["images"][0]
+```
+![mixture_tiling_results](https://huggingface.co/datasets/kadirnar/diffusers_readme_images/resolve/main/mixture_tiling.png)
+
+### TensorRT Inpainting Stable Diffusion Pipeline
+
+The TensorRT pipeline can be used to accelerate inference for the inpainting Stable Diffusion pipeline.
+
+NOTE: The ONNX conversions and TensorRT engine build may take up to 30 minutes.
+
+```python
+import requests
+from io import BytesIO
+from PIL import Image
+import torch
+from diffusers import PNDMScheduler
+from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline
+
+# Use the PNDMScheduler scheduler here instead
+scheduler = PNDMScheduler.from_pretrained("stabilityai/stable-diffusion-2-inpainting", subfolder="scheduler")
+
+
+pipe = StableDiffusionInpaintPipeline.from_pretrained("stabilityai/stable-diffusion-2-inpainting",
+ custom_pipeline="stable_diffusion_tensorrt_inpaint",
+ revision='fp16',
+ torch_dtype=torch.float16,
+ scheduler=scheduler,
+ )
+
+# re-use cached folder to save ONNX models and TensorRT Engines
+pipe.set_cached_folder("stabilityai/stable-diffusion-2-inpainting", revision='fp16',)
+
+pipe = pipe.to("cuda")
+
+url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+response = requests.get(url)
+input_image = Image.open(BytesIO(response.content)).convert("RGB")
+
+mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+response = requests.get(mask_url)
+mask_image = Image.open(BytesIO(response.content)).convert("RGB")
+
+prompt = "a mecha robot sitting on a bench"
+image = pipe(prompt, image=input_image, mask_image=mask_image, strength=0.75,).images[0]
+image.save('tensorrt_inpaint_mecha_robot.png')
+```
+
+### Stable Diffusion Mixture Canvas
+
+This pipeline uses the Mixture of Diffusers approach to compose an image from multiple regions. Refer to the [Mixture of Diffusers](https://arxiv.org/abs/2302.02412) paper for more details.
+
+```python
+from PIL import Image
+from diffusers import LMSDiscreteScheduler, DiffusionPipeline
+# Image2ImageRegion, Text2ImageRegion and preprocess_image are defined in examples/community/mixture_canvas.py
+from mixture_canvas import Image2ImageRegion, Text2ImageRegion, preprocess_image
+
+
+# Load and preprocess guide image
+iic_image = preprocess_image(Image.open("input_image.png").convert("RGB"))
+
+# Create scheduler and model (similar to StableDiffusionPipeline)
+scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
+pipeline = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=scheduler, custom_pipeline="mixture_canvas")
+pipeline.to("cuda")
+
+# Mixture of Diffusers generation
+output = pipeline(
+ canvas_height=800,
+ canvas_width=352,
+ regions=[
+ Text2ImageRegion(0, 800, 0, 352, guidance_scale=8,
+ prompt=f"best quality, masterpiece, WLOP, sakimichan, art contest winner on pixiv, 8K, intricate details, wet effects, rain drops, ethereal, mysterious, futuristic, UHD, HDR, cinematic lighting, in a beautiful forest, rainy day, award winning, trending on artstation, beautiful confident cheerful young woman, wearing a futuristic sleeveless dress, ultra beautiful detailed eyes, hyper-detailed face, complex, perfect, model, textured, chiaroscuro, professional make-up, realistic, figure in frame, "),
+        Image2ImageRegion(800-352, 800, 0, 352, reference_image=iic_image, strength=1.0),
+ ],
+ num_inference_steps=100,
+ seed=5525475061,
+)["images"][0]
+```
+![Input_Image](https://huggingface.co/datasets/kadirnar/diffusers_readme_images/resolve/main/input_image.png)
+![mixture_canvas_results](https://huggingface.co/datasets/kadirnar/diffusers_readme_images/resolve/main/canvas.png)
+
+
+### IADB pipeline
+
+This pipeline is the implementation of the [α-(de)Blending: a Minimalist Deterministic Diffusion Model](https://arxiv.org/abs/2305.03486) paper.
+It is a simple and minimalist diffusion model.
+
+The following code shows how to use the IADB pipeline to generate images using a pretrained celebahq-256 model.
+
+```python
+import matplotlib.pyplot as plt
+from diffusers import DiffusionPipeline
+
+pipeline_iadb = DiffusionPipeline.from_pretrained("thomasc4/iadb-celebahq-256", custom_pipeline="iadb")
+pipeline_iadb = pipeline_iadb.to("cuda")
+
+output = pipeline_iadb(batch_size=4, num_inference_steps=128)
+for i in range(len(output[0])):
+    plt.imshow(output[0][i])
+    plt.show()
+```
+
+Sampling with the IADB formulation is easy, and can be done in a few lines (the pipeline already implements it):
+
+```python
+import torch
+
+def sample_iadb(model, x0, nb_step):
+    x_alpha = x0
+    for t in range(nb_step):
+        alpha = t / nb_step
+        alpha_next = (t + 1) / nb_step
+
+        d = model(x_alpha, torch.tensor(alpha, device=x_alpha.device))["sample"]
+        x_alpha = x_alpha + (alpha_next - alpha) * d
+
+    return x_alpha
+```
+
+The training loop is also straightforward:
+
+```python
+
+# Training loop
+while True:
+ x0 = sample_noise()
+ x1 = sample_dataset()
+
+ alpha = torch.rand(batch_size)
+
+ # Blend
+ x_alpha = (1-alpha) * x0 + alpha * x1
+
+ # Loss
+ loss = torch.sum((D(x_alpha, alpha)- (x1-x0))**2)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+```
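+
+The loop above is a sketch: `sample_noise`, `sample_dataset`, `D`, `optimizer`, and `batch_size` are left undefined. A minimal self-contained version of the same α-(de)blending objective, using random tensors in place of a real dataset and a hypothetical tiny conv net standing in for the denoiser `D`, could look like this:
+
+```python
+import torch
+
+# Hypothetical stand-ins: a tiny conv net as D and random tensors as "data".
+batch_size, channels, size = 8, 3, 32
+D = torch.nn.Conv2d(channels, channels, kernel_size=3, padding=1)
+optimizer = torch.optim.Adam(D.parameters(), lr=1e-4)
+
+def model_forward(x_alpha, alpha):
+    # Crude conditioning on alpha; a real model embeds the blending level properly.
+    return D(x_alpha + alpha.view(-1, 1, 1, 1))
+
+for step in range(10):
+    x0 = torch.randn(batch_size, channels, size, size)         # noise sample
+    x1 = torch.rand(batch_size, channels, size, size) * 2 - 1  # "dataset" sample in [-1, 1]
+    alpha = torch.rand(batch_size)
+    a = alpha.view(-1, 1, 1, 1)
+
+    # Blend noise and data at level alpha
+    x_alpha = (1 - a) * x0 + a * x1
+
+    # Regress the blending direction (x1 - x0)
+    loss = torch.mean((model_forward(x_alpha, alpha) - (x1 - x0)) ** 2)
+    optimizer.zero_grad()
+    loss.backward()
+    optimizer.step()
+```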
+
+### Zero1to3 pipeline
+
+This pipeline is the implementation of the [Zero-1-to-3: Zero-shot One Image to 3D Object](https://arxiv.org/abs/2303.11328) paper.
+The original PyTorch Lightning implementation can be found in this [repo](https://github.com/cvlab-columbia/zero123), and a diffusers port in this [repo](https://github.com/kxhit/zero123-hf).
+
+The following code shows how to use the Zero1to3 pipeline to generate novel view synthesis images using a pretrained stable diffusion model.
+
+```python
+import os
+import torch
+from pipeline_zero1to3 import Zero1to3StableDiffusionPipeline
+from diffusers.utils import load_image
+
+model_id = "kxic/zero123-165000" # zero123-105000, zero123-165000, zero123-xl
+
+pipe = Zero1to3StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
+
+pipe.enable_xformers_memory_efficient_attention()
+pipe.enable_vae_tiling()
+pipe.enable_attention_slicing()
+pipe = pipe.to("cuda")
+
+num_images_per_prompt = 4
+
+# test inference pipeline
+# Each query pose is [polar angle (vertical rotation in degrees), azimuth angle (horizontal rotation in degrees), zoom (relative distance from center)]
+query_pose1 = [-75.0, 100.0, 0.0]
+query_pose2 = [-20.0, 125.0, 0.0]
+query_pose3 = [-55.0, 90.0, 0.0]
+
+# load image
+# H, W = (256, 256) # H, W = (512, 512) # zero123 training is 256,256
+
+# for batch input
+input_image1 = load_image("./demo/4_blackarm.png") #load_image("https://cvlab-zero123-live.hf.space/file=/home/user/app/configs/4_blackarm.png")
+input_image2 = load_image("./demo/8_motor.png") #load_image("https://cvlab-zero123-live.hf.space/file=/home/user/app/configs/8_motor.png")
+input_image3 = load_image("./demo/7_london.png") #load_image("https://cvlab-zero123-live.hf.space/file=/home/user/app/configs/7_london.png")
+input_images = [input_image1, input_image2, input_image3]
+query_poses = [query_pose1, query_pose2, query_pose3]
+
+# # for single input
+# H, W = (256, 256)
+# input_images = [input_image2.resize((H, W), PIL.Image.NEAREST)]
+# query_poses = [query_pose2]
+
+
+# better do preprocessing
+from gradio_new import preprocess_image, create_carvekit_interface
+import numpy as np
+import PIL.Image as Image
+
+pre_images = []
+models = dict()
+print('Instantiating Carvekit HiInterface...')
+models['carvekit'] = create_carvekit_interface()
+if not isinstance(input_images, list):
+ input_images = [input_images]
+for raw_im in input_images:
+ input_im = preprocess_image(models, raw_im, True)
+ H, W = input_im.shape[:2]
+ pre_images.append(Image.fromarray((input_im * 255.0).astype(np.uint8)))
+input_images = pre_images
+
+# infer pipeline, in original zero123 num_inference_steps=76
+images = pipe(input_imgs=input_images, prompt_imgs=input_images, poses=query_poses, height=H, width=W,
+ guidance_scale=3.0, num_images_per_prompt=num_images_per_prompt, num_inference_steps=50).images
+
+
+# save imgs
+log_dir = "logs"
+os.makedirs(log_dir, exist_ok=True)
+bs = len(input_images)
+i = 0
+for obj in range(bs):
+ for idx in range(num_images_per_prompt):
+ images[i].save(os.path.join(log_dir,f"obj{obj}_{idx}.jpg"))
+ i += 1
+
+```
+
+### Stable Diffusion XL Reference
+
+This pipeline applies the reference-only control technique to SDXL. Refer to the [stable_diffusion_reference](https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#stable-diffusion-reference) community pipeline for more details.
+
+
+```py
+import torch
+from PIL import Image
+from diffusers.utils import load_image
+from diffusers import DiffusionPipeline
+from diffusers.schedulers import UniPCMultistepScheduler
+input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")
+
+pipe = DiffusionPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0",
+    custom_pipeline="stable_diffusion_xl_reference",
+    torch_dtype=torch.float16,
+    use_safetensors=True,
+    variant="fp16").to('cuda:0')
+
+# Alternatively, import StableDiffusionXLReferencePipeline directly from the community
+# pipeline file and construct it with the same arguments:
+# pipe = StableDiffusionXLReferencePipeline.from_pretrained(
+#     "stabilityai/stable-diffusion-xl-base-1.0",
+#     torch_dtype=torch.float16,
+#     use_safetensors=True,
+#     variant="fp16").to('cuda:0')
+
+pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+
+result_img = pipe(ref_image=input_image,
+ prompt="1girl",
+ num_inference_steps=20,
+ reference_attn=True,
+ reference_adain=True).images[0]
+```
+
+Reference Image
+
+![reference_image](https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png)
+
+Output Image
+
+`prompt: 1 girl`
+
+`reference_attn=True, reference_adain=True, num_inference_steps=20`
+![Output_image](https://github.com/zideliu/diffusers/assets/34944964/743848da-a215-48f9-ae39-b5e2ae49fb13)
+
+Reference Image
+![reference_image](https://github.com/huggingface/diffusers/assets/34944964/449bdab6-e744-4fb2-9620-d4068d9a741b)
+
+
+Output Image
+
+`prompt: A dog`
+
+`reference_attn=True, reference_adain=False, num_inference_steps=20`
+![Output_image](https://github.com/huggingface/diffusers/assets/34944964/fff2f16f-6e91-434b-abcc-5259d866c31e)
+
+Reference Image
+![reference_image](https://github.com/huggingface/diffusers/assets/34944964/077ed4fe-2991-4b79-99a1-009f056227d1)
+
+Output Image
+
+`prompt: An astronaut riding a lion`
+
+`reference_attn=True, reference_adain=True, num_inference_steps=20`
+![output_image](https://github.com/huggingface/diffusers/assets/34944964/9b2f1aca-886f-49c3-89ec-d2031c8e3670)
+
+### Stable diffusion fabric pipeline
+
+FABRIC is an approach applicable to a wide range of popular diffusion models. It exploits
+the self-attention layers present in the most widely used architectures to condition
+the diffusion process on a set of feedback images.
+
+
+```python
+import requests
+import torch
+from PIL import Image
+from io import BytesIO
+
+from diffusers import DiffusionPipeline
+
+# load the pipeline
+# make sure you're logged in with `huggingface-cli login`
+model_id_or_path = "runwayml/stable-diffusion-v1-5"
+#can also be used with dreamlike-art/dreamlike-photoreal-2.0
+pipe = DiffusionPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16, custom_pipeline="pipeline_fabric").to("cuda")
+
+# let's specify a prompt
+prompt = "An astronaut riding an elephant"
+negative_prompt = "lowres, cropped"
+
+# call the pipeline
+image = pipe(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ num_inference_steps=20,
+ generator=torch.manual_seed(12)
+).images[0]
+
+image.save("horse_to_elephant.jpg")
+
+# let's try another example with feedback
+url = "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/A%20black%20colored%20car.png"
+response = requests.get(url)
+init_image = Image.open(BytesIO(response.content)).convert("RGB")
+
+prompt = "photo, A blue colored car, fish eye"
+liked = [init_image]
+## same goes with disliked
+
+# call the pipeline
+torch.manual_seed(0)
+image = pipe(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+    liked=liked,
+ num_inference_steps=20,
+).images[0]
+
+image.save("black_to_blue.png")
+```
+
+*With enough feedback you can create very similar high-quality images.*
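+
+As a rough sketch (not part of the original example), the feedback loop can be iterated by appending generated images you like back into `liked` and re-running the pipeline, reusing `pipe`, `prompt`, and `negative_prompt` from the code above:
+
+```python
+# Hypothetical iterative-feedback loop: keep the images you like and feed them back in.
+liked = []
+for round_idx in range(3):
+    image = pipe(
+        prompt=prompt,
+        negative_prompt=negative_prompt,
+        liked=liked,
+        num_inference_steps=20,
+    ).images[0]
+    image.save(f"fabric_round_{round_idx}.png")
+    # In practice you would inspect the image first and only append the ones you actually like.
+    liked.append(image)
+```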
+
+The original codebase can be found at [sd-fabric/fabric](https://github.com/sd-fabric/fabric), and available checkpoints are [dreamlike-art/dreamlike-photoreal-2.0](https://huggingface.co/dreamlike-art/dreamlike-photoreal-2.0), [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5), and [stabilityai/stable-diffusion-2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1) (may give unexpected results).
+
+Let's have a look at the images (*512×512*):
+
+| Without Feedback | With Feedback (1st image) |
+|---------------------|---------------------|
+| ![Image 1](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/fabric_wo_feedback.jpg) | ![Feedback Image 1](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/fabric_w_feedback.png) |
+
+
+### Masked Im2Im Stable Diffusion Pipeline
+
+This pipeline reimplements the sketch-inpaint feature from A1111 for non-inpainting models. The following code reads two images, the original and a copy with a mask painted over it. It computes the mask as the difference between the two images and inpaints the area defined by that mask.
+
+```python
+import numpy
+import PIL.Image
+import torch
+from diffusers import EulerAncestralDiscreteScheduler
+# MaskedStableDiffusionImg2ImgPipeline is defined in examples/community/masked_stable_diffusion_img2img.py;
+# adjust this import to wherever you keep that file.
+from masked_stable_diffusion_img2img import MaskedStableDiffusionImg2ImgPipeline
+
+# read the original image and the same image with the mask painted over it
+img = PIL.Image.open("./mech.png")
+img_paint = PIL.Image.open("./mech_painted.png")
+
+# the mask is the per-pixel difference between the two images
+neq = numpy.any(numpy.array(img) != numpy.array(img_paint), axis=-1)
+mask = neq / neq.max()
+
+pipeline = MaskedStableDiffusionImg2ImgPipeline.from_pretrained("frankjoshua/icbinpICantBelieveIts_v8")
+
+# works best with EulerAncestralDiscreteScheduler
+pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
+generator = torch.Generator(device="cpu").manual_seed(4)
+
+prompt = "a man wearing a mask"
+result = pipeline(prompt=prompt, image=img_paint, mask=mask, strength=0.75,
+                  generator=generator)
+result.images[0].save("result.png")
+```
+
+original image mech.png
+
+
+
+image with mask mech_painted.png
+
+
+
+result:
+
+
+
+
+### Prompt2Prompt Pipeline
+
+Prompt2Prompt allows the following edits:
+- ReplaceEdit (change words in prompt)
+- ReplaceEdit with local blend (change words in prompt, keep image part unrelated to changes constant)
+- RefineEdit (add words to prompt)
+- RefineEdit with local blend (add words to prompt, keep image part unrelated to changes constant)
+- ReweightEdit (modulate importance of words)
+
+Here's a full example for `ReplaceEdit`:
+
+```python
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+from diffusers.pipelines import Prompt2PromptPipeline
+
+pipe = Prompt2PromptPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to("cuda")
+
+prompts = ["A turtle playing with a ball",
+ "A monkey playing with a ball"]
+
+cross_attention_kwargs = {
+ "edit_type": "replace",
+ "cross_replace_steps": 0.4,
+ "self_replace_steps": 0.4
+}
+
+outputs = pipe(prompt=prompts, height=512, width=512, num_inference_steps=50, cross_attention_kwargs=cross_attention_kwargs)
+```
+
+And abbreviated examples for the other edits:
+
+`ReplaceEdit with local blend`
+```python
+prompts = ["A turtle playing with a ball",
+ "A monkey playing with a ball"]
+
+cross_attention_kwargs = {
+ "edit_type": "replace",
+ "cross_replace_steps": 0.4,
+ "self_replace_steps": 0.4,
+ "local_blend_words": ["turtle", "monkey"]
+}
+```
+
+`RefineEdit`
+```python
+prompts = ["A turtle",
+ "A turtle in a forest"]
+
+cross_attention_kwargs = {
+ "edit_type": "refine",
+ "cross_replace_steps": 0.4,
+ "self_replace_steps": 0.4,
+}
+```
+
+`RefineEdit with local blend`
+```python
+prompts = ["A turtle",
+ "A turtle in a forest"]
+
+cross_attention_kwargs = {
+ "edit_type": "refine",
+ "cross_replace_steps": 0.4,
+ "self_replace_steps": 0.4,
+ "local_blend_words": ["in", "a" , "forest"]
+}
+```
+
+`ReweightEdit`
+```python
+prompts = ["A smiling turtle"] * 2
+
+cross_attention_kwargs = {
+ "edit_type": "reweight",
+ "cross_replace_steps": 0.4,
+ "self_replace_steps": 0.4,
+ "equalizer_words": ["smiling"],
+ "equalizer_strengths": [5]
+}
+```
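+
+Each of these abbreviated configurations is passed to the pipeline in the same way as the full `ReplaceEdit` example, i.e. through `cross_attention_kwargs`:
+
+```python
+# Same call as in the full example; only `prompts` and `cross_attention_kwargs` change.
+outputs = pipe(prompt=prompts, height=512, width=512, num_inference_steps=50, cross_attention_kwargs=cross_attention_kwargs)
+```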
+
+Side note: See [this GitHub gist](https://gist.github.com/UmerHA/b65bb5fb9626c9c73f3ade2869e36164) if you want to visualize the attention maps.
+
+### Latent Consistency Pipeline
+
+Latent Consistency Models was proposed in [Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://arxiv.org/abs/2310.04378) by *Simian Luo, Yiqin Tan, Longbo Huang, Jian Li, Hang Zhao* from Tsinghua University.
+
+The abstract of the paper reads as follows:
+
+*Latent Diffusion models (LDMs) have achieved remarkable results in synthesizing high-resolution images. However, the iterative sampling process is computationally intensive and leads to slow generation. Inspired by Consistency Models (song et al.), we propose Latent Consistency Models (LCMs), enabling swift inference with minimal steps on any pre-trained LDMs, including Stable Diffusion (rombach et al). Viewing the guided reverse diffusion process as solving an augmented probability flow ODE (PF-ODE), LCMs are designed to directly predict the solution of such ODE in latent space, mitigating the need for numerous iterations and allowing rapid, high-fidelity sampling. Efficiently distilled from pre-trained classifier-free guided diffusion models, a high-quality 768 x 768 2~4-step LCM takes only 32 A100 GPU hours for training. Furthermore, we introduce Latent Consistency Fine-tuning (LCF), a novel method that is tailored for fine-tuning LCMs on customized image datasets. Evaluation on the LAION-5B-Aesthetics dataset demonstrates that LCMs achieve state-of-the-art text-to-image generation performance with few-step inference. Project Page: [this https URL](https://latent-consistency-models.github.io/)*
+
+The model can be used with `diffusers` as follows:
+
+- 1. Load the model from the community pipeline:
+
+```py
+from diffusers import DiffusionPipeline
+import torch
+
+pipe = DiffusionPipeline.from_pretrained("SimianLuo/LCM_Dreamshaper_v7", custom_pipeline="latent_consistency_txt2img", custom_revision="main")
+
+# To save GPU memory, torch.float16 can be used, but it may compromise image quality.
+pipe.to(torch_device="cuda", torch_dtype=torch.float32)
+```
+
+- 2. Run inference with as little as 4 steps:
+
+```py
+prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
+
+# Can be set to 1~50 steps. LCM supports fast inference with as few as 4 steps. Recommended: 1~8 steps.
+num_inference_steps = 4
+
+images = pipe(prompt=prompt, num_inference_steps=num_inference_steps, guidance_scale=8.0, lcm_origin_steps=50, output_type="pil").images
+```
+
+For any questions or feedback, feel free to reach out to [Simian Luo](https://github.com/luosiallen).
+
+You can also try this pipeline directly in the [🚀 official spaces](https://huggingface.co/spaces/SimianLuo/Latent_Consistency_Model).
+
+
+
+### Latent Consistency Img2img Pipeline
+
+This pipeline extends the Latent Consistency Pipeline to allow it to take an input image.
+
+```py
+from diffusers import DiffusionPipeline
+import torch
+
+pipe = DiffusionPipeline.from_pretrained("SimianLuo/LCM_Dreamshaper_v7", custom_pipeline="latent_consistency_img2img")
+
+# To save GPU memory, torch.float16 can be used, but it may compromise image quality.
+pipe.to(torch_device="cuda", torch_dtype=torch.float32)
+```
+
+- 2. Run inference with as little as 4 steps:
+
+```py
+from PIL import Image
+
+prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
+
+input_image = Image.open("myimg.png")
+
+strength = 0.5  # strength=0: no change, strength=1: completely overwrite the input image
+
+# Can be set to 1~50 steps. LCM supports fast inference with as few as 4 steps. Recommended: 1~8 steps.
+num_inference_steps = 4
+
+images = pipe(prompt=prompt, image=input_image, strength=strength, num_inference_steps=num_inference_steps, guidance_scale=8.0, lcm_origin_steps=50, output_type="pil").images
+```
+
+
+
+### Latent Consistency Interpolation Pipeline
+
+This pipeline extends the Latent Consistency Pipeline to allow for interpolation of the latent space between multiple prompts. It is similar to the [Stable Diffusion Interpolate](https://github.com/huggingface/diffusers/blob/main/examples/community/interpolate_stable_diffusion.py) and [unCLIP Interpolate](https://github.com/huggingface/diffusers/blob/main/examples/community/unclip_text_interpolation.py) community pipelines.
+
+```py
+import torch
+import numpy as np
+
+from diffusers import DiffusionPipeline
+
+pipe = DiffusionPipeline.from_pretrained("SimianLuo/LCM_Dreamshaper_v7", custom_pipeline="latent_consistency_interpolate")
+
+# To save GPU memory, torch.float16 can be used, but it may compromise image quality.
+pipe.to(torch_device="cuda", torch_dtype=torch.float32)
+
+prompts = [
+ "Self-portrait oil painting, a beautiful cyborg with golden hair, Margot Robbie, 8k",
+ "Self-portrait oil painting, an extremely strong man, body builder, Huge Jackman, 8k",
+ "An astronaut floating in space, renaissance art, realistic, high quality, 8k",
+ "Oil painting of a cat, cute, dream-like",
+ "Hugging face emoji, cute, realistic"
+]
+num_inference_steps = 4
+num_interpolation_steps = 60
+seed = 1337
+
+torch.manual_seed(seed)
+np.random.seed(seed)
+
+images = pipe(
+ prompt=prompts,
+ height=512,
+ width=512,
+ num_inference_steps=num_inference_steps,
+ num_interpolation_steps=num_interpolation_steps,
+ guidance_scale=8.0,
+ embedding_interpolation_type="lerp",
+ latent_interpolation_type="slerp",
+ process_batch_size=4, # Make it higher or lower based on your GPU memory
+    generator=torch.Generator().manual_seed(seed),
+)
+
+assert len(images) == (len(prompts) - 1) * num_interpolation_steps
+```
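+
+If the call returns a flat list of PIL frames, as the assert above implies, a small follow-up (a sketch, not part of the pipeline) is to write each interpolation frame to disk so the sequence can later be assembled into a video:
+
+```py
+import os
+
+# Save each interpolation frame; `images` is assumed to be the flat list of PIL images from above.
+os.makedirs("lcm_interpolation_frames", exist_ok=True)
+for idx, frame in enumerate(images):
+    frame.save(os.path.join("lcm_interpolation_frames", f"frame_{idx:04d}.png"))
+```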
+
+### ControlNet + T2I Adapter Pipeline
+This pipeline combines both ControlNet and T2IAdapter into a single pipeline, where the forward pass is executed once.
+It receives `control_image` and `adapter_image`, as well as `controlnet_conditioning_scale` and `adapter_conditioning_scale`, for the ControlNet and Adapter modules, respectively. Whenever `adapter_conditioning_scale = 0` or `controlnet_conditioning_scale = 0`, it will act as a full ControlNet module or as a full T2IAdapter module, respectively; a short sketch of these two degenerate configurations follows the example below.
+
+```py
+import cv2
+import numpy as np
+import torch
+from controlnet_aux.midas import MidasDetector
+from PIL import Image
+
+from diffusers import AutoencoderKL, ControlNetModel, MultiAdapter, T2IAdapter
+from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
+from diffusers.utils import load_image
+from examples.community.pipeline_stable_diffusion_xl_controlnet_adapter import (
+ StableDiffusionXLControlNetAdapterPipeline,
+)
+
+controlnet_depth = ControlNetModel.from_pretrained(
+ "diffusers/controlnet-depth-sdxl-1.0",
+ torch_dtype=torch.float16,
+ variant="fp16",
+ use_safetensors=True
+)
+adapter_depth = T2IAdapter.from_pretrained(
+ "TencentARC/t2i-adapter-depth-midas-sdxl-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+)
+vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True)
+
+pipe = StableDiffusionXLControlNetAdapterPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ controlnet=controlnet_depth,
+ adapter=adapter_depth,
+ vae=vae,
+ variant="fp16",
+ use_safetensors=True,
+ torch_dtype=torch.float16,
+)
+pipe = pipe.to("cuda")
+pipe.enable_xformers_memory_efficient_attention()
+# pipe.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
+midas_depth = MidasDetector.from_pretrained(
+ "valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
+).to("cuda")
+
+prompt = "a tiger sitting on a park bench"
+img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+
+image = load_image(img_url).resize((1024, 1024))
+
+depth_image = midas_depth(
+ image, detect_resolution=512, image_resolution=1024
+)
+
+strength = 0.5
+
+images = pipe(
+ prompt,
+ control_image=depth_image,
+ adapter_image=depth_image,
+ num_inference_steps=30,
+ controlnet_conditioning_scale=strength,
+ adapter_conditioning_scale=strength,
+).images
+images[0].save("controlnet_and_adapter.png")
+
+```
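+
+As noted above, zeroing one of the conditioning scales reduces the combined pipeline to its other half. A quick sketch of the two degenerate configurations, reusing `pipe`, `prompt`, and `depth_image` from the example:
+
+```py
+# ControlNet only: the T2I-Adapter branch is disabled by its zero scale.
+controlnet_only = pipe(
+    prompt,
+    control_image=depth_image,
+    adapter_image=depth_image,
+    num_inference_steps=30,
+    controlnet_conditioning_scale=1.0,
+    adapter_conditioning_scale=0.0,
+).images[0]
+
+# T2I-Adapter only: the ControlNet branch is disabled by its zero scale.
+adapter_only = pipe(
+    prompt,
+    control_image=depth_image,
+    adapter_image=depth_image,
+    num_inference_steps=30,
+    controlnet_conditioning_scale=0.0,
+    adapter_conditioning_scale=1.0,
+).images[0]
+```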
+
+### ControlNet + T2I Adapter + Inpainting Pipeline
+```py
+import cv2
+import numpy as np
+import torch
+from controlnet_aux.midas import MidasDetector
+from PIL import Image
+
+from diffusers import AutoencoderKL, ControlNetModel, MultiAdapter, T2IAdapter
+from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
+from diffusers.utils import load_image
+from examples.community.pipeline_stable_diffusion_xl_controlnet_adapter_inpaint import (
+ StableDiffusionXLControlNetAdapterInpaintPipeline,
+)
+
+controlnet_depth = ControlNetModel.from_pretrained(
+ "diffusers/controlnet-depth-sdxl-1.0",
+ torch_dtype=torch.float16,
+ variant="fp16",
+ use_safetensors=True
+)
+adapter_depth = T2IAdapter.from_pretrained(
+ "TencentARC/t2i-adapter-depth-midas-sdxl-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+)
+vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True)
+
+pipe = StableDiffusionXLControlNetAdapterInpaintPipeline.from_pretrained(
+ "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
+ controlnet=controlnet_depth,
+ adapter=adapter_depth,
+ vae=vae,
+ variant="fp16",
+ use_safetensors=True,
+ torch_dtype=torch.float16,
+)
+pipe = pipe.to("cuda")
+pipe.enable_xformers_memory_efficient_attention()
+# pipe.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
+midas_depth = MidasDetector.from_pretrained(
+ "valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
+).to("cuda")
+
+prompt = "a tiger sitting on a park bench"
+img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+image = load_image(img_url).resize((1024, 1024))
+mask_image = load_image(mask_url).resize((1024, 1024))
+
+depth_image = midas_depth(
+ image, detect_resolution=512, image_resolution=1024
+)
+
+strength = 0.4
+
+images = pipe(
+ prompt,
+ image=image,
+ mask_image=mask_image,
+ control_image=depth_image,
+ adapter_image=depth_image,
+ num_inference_steps=30,
+ controlnet_conditioning_scale=strength,
+ adapter_conditioning_scale=strength,
+ strength=0.7,
+).images
+images[0].save("controlnet_and_adapter_inpaint.png")
+
+```
\ No newline at end of file
diff --git a/diffusers/examples/community/bit_diffusion.py b/diffusers/examples/community/bit_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..18d5fca5619e3f420128288399aa000037d1feec
--- /dev/null
+++ b/diffusers/examples/community/bit_diffusion.py
@@ -0,0 +1,264 @@
+from typing import Optional, Tuple, Union
+
+import torch
+from einops import rearrange, reduce
+
+from diffusers import DDIMScheduler, DDPMScheduler, DiffusionPipeline, ImagePipelineOutput, UNet2DConditionModel
+from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput
+from diffusers.schedulers.scheduling_ddpm import DDPMSchedulerOutput
+
+
+BITS = 8
+
+
+# convert to bit representations and back taken from https://github.com/lucidrains/bit-diffusion/blob/main/bit_diffusion/bit_diffusion.py
+def decimal_to_bits(x, bits=BITS):
+ """expects image tensor ranging from 0 to 1, outputs bit tensor ranging from -1 to 1"""
+ device = x.device
+
+ x = (x * 255).int().clamp(0, 255)
+
+ mask = 2 ** torch.arange(bits - 1, -1, -1, device=device)
+ mask = rearrange(mask, "d -> d 1 1")
+ x = rearrange(x, "b c h w -> b c 1 h w")
+
+ bits = ((x & mask) != 0).float()
+ bits = rearrange(bits, "b c d h w -> b (c d) h w")
+ bits = bits * 2 - 1
+ return bits
+
+
+def bits_to_decimal(x, bits=BITS):
+ """expects bits from -1 to 1, outputs image tensor from 0 to 1"""
+ device = x.device
+
+ x = (x > 0).int()
+ mask = 2 ** torch.arange(bits - 1, -1, -1, device=device, dtype=torch.int32)
+
+ mask = rearrange(mask, "d -> d 1 1")
+ x = rearrange(x, "b (c d) h w -> b c d h w", d=8)
+ dec = reduce(x * mask, "b c d h w -> b c h w", "sum")
+ return (dec / 255).clamp(0.0, 1.0)
+
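+# Illustrative shapes (not in the original file): decimal_to_bits maps a (B, 3, H, W)
+# image in [0, 1] to a (B, 3 * BITS, H, W) tensor of ±1 bit planes, and bits_to_decimal
+# inverts that mapping up to 8-bit quantization.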
+
+# modified scheduler step functions for clamping the predicted x_0 between -bit_scale and +bit_scale
+def ddim_bit_scheduler_step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ sample: torch.FloatTensor,
+ eta: float = 0.0,
+ use_clipped_model_output: bool = True,
+ generator=None,
+ return_dict: bool = True,
+) -> Union[DDIMSchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+ eta (`float`): weight of noise for added noise in diffusion step.
+ use_clipped_model_output (`bool`): TODO
+ generator: random number generator.
+ return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class
+ Returns:
+ [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
+ # Ideally, read DDIM paper in-detail understanding
+
+    # Notation (<variable name> -> <name in paper>
+ # - pred_noise_t -> e_theta(x_t, t)
+ # - pred_original_sample -> f_theta(x_t, t) or x_0
+ # - std_dev_t -> sigma_t
+ # - eta -> η
+ # - pred_sample_direction -> "direction pointing to x_t"
+ # - pred_prev_sample -> "x_t-1"
+
+ # 1. get previous step value (=t-1)
+ prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
+
+ # 2. compute alphas, betas
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+
+ beta_prod_t = 1 - alpha_prod_t
+
+ # 3. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+
+ # 4. Clip "predicted x_0"
+ scale = self.bit_scale
+ if self.config.clip_sample:
+ pred_original_sample = torch.clamp(pred_original_sample, -scale, scale)
+
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
+ variance = self._get_variance(timestep, prev_timestep)
+ std_dev_t = eta * variance ** (0.5)
+
+ if use_clipped_model_output:
+ # the model_output is always re-derived from the clipped x_0 in Glide
+ model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
+
+ # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+
+ if eta > 0:
+ # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
+ device = model_output.device if torch.is_tensor(model_output) else "cpu"
+ noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device)
+ variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise
+
+ prev_sample = prev_sample + variance
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
+
+
+def ddpm_bit_scheduler_step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ sample: torch.FloatTensor,
+ prediction_type="epsilon",
+ generator=None,
+ return_dict: bool = True,
+) -> Union[DDPMSchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+ prediction_type (`str`, default `epsilon`):
+ indicates whether the model predicts the noise (epsilon), or the samples (`sample`).
+ generator: random number generator.
+ return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class
+ Returns:
+ [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+ """
+ t = timestep
+
+ if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
+ model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
+ else:
+ predicted_variance = None
+
+ # 1. compute alphas, betas
+ alpha_prod_t = self.alphas_cumprod[t]
+ alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ # 2. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
+ if prediction_type == "epsilon":
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+ elif prediction_type == "sample":
+ pred_original_sample = model_output
+ else:
+ raise ValueError(f"Unsupported prediction_type {prediction_type}.")
+
+ # 3. Clip "predicted x_0"
+ scale = self.bit_scale
+ if self.config.clip_sample:
+ pred_original_sample = torch.clamp(pred_original_sample, -scale, scale)
+
+ # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t
+ current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t
+
+ # 5. Compute predicted previous sample µ_t
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
+
+ # 6. Add noise
+ variance = 0
+ if t > 0:
+ noise = torch.randn(
+ model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator
+ ).to(model_output.device)
+ variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise
+
+ pred_prev_sample = pred_prev_sample + variance
+
+ if not return_dict:
+ return (pred_prev_sample,)
+
+ return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
+
+
+class BitDiffusion(DiffusionPipeline):
+ def __init__(
+ self,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDIMScheduler, DDPMScheduler],
+ bit_scale: Optional[float] = 1.0,
+ ):
+ super().__init__()
+ self.bit_scale = bit_scale
+ self.scheduler.step = (
+ ddim_bit_scheduler_step if isinstance(scheduler, DDIMScheduler) else ddpm_bit_scheduler_step
+ )
+
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ height: Optional[int] = 256,
+ width: Optional[int] = 256,
+ num_inference_steps: Optional[int] = 50,
+ generator: Optional[torch.Generator] = None,
+ batch_size: Optional[int] = 1,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[Tuple, ImagePipelineOutput]:
+ latents = torch.randn(
+ (batch_size, self.unet.config.in_channels, height, width),
+ generator=generator,
+ )
+ latents = decimal_to_bits(latents) * self.bit_scale
+ latents = latents.to(self.device)
+
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ for t in self.progress_bar(self.scheduler.timesteps):
+ # predict the noise residual
+ noise_pred = self.unet(latents, t).sample
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents).prev_sample
+
+ image = bits_to_decimal(latents)
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/diffusers/examples/community/checkpoint_merger.py b/diffusers/examples/community/checkpoint_merger.py
new file mode 100644
index 0000000000000000000000000000000000000000..10381020bf631d6a7baaf76a38b80d11bc7f4b09
--- /dev/null
+++ b/diffusers/examples/community/checkpoint_merger.py
@@ -0,0 +1,280 @@
+import glob
+import os
+from typing import Dict, List, Union
+
+import safetensors.torch
+import torch
+from huggingface_hub import snapshot_download
+
+from diffusers import DiffusionPipeline, __version__
+from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
+from diffusers.utils import CONFIG_NAME, DIFFUSERS_CACHE, ONNX_WEIGHTS_NAME, WEIGHTS_NAME
+
+
+class CheckpointMergerPipeline(DiffusionPipeline):
+ """
+ A class that supports merging diffusion models based on the discussion here:
+ https://github.com/huggingface/diffusers/issues/877
+
+ Example usage:-
+
+ pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="checkpoint_merger.py")
+
+ merged_pipe = pipe.merge(["CompVis/stable-diffusion-v1-4","prompthero/openjourney"], interp = 'inv_sigmoid', alpha = 0.8, force = True)
+
+ merged_pipe.to('cuda')
+
+ prompt = "An astronaut riding a unicycle on Mars"
+
+ results = merged_pipe(prompt)
+
+ ## For more details, see the docstring for the merge method.
+
+ """
+
+ def __init__(self):
+ self.register_to_config()
+ super().__init__()
+
+ def _compare_model_configs(self, dict0, dict1):
+ if dict0 == dict1:
+ return True
+ else:
+ config0, meta_keys0 = self._remove_meta_keys(dict0)
+ config1, meta_keys1 = self._remove_meta_keys(dict1)
+ if config0 == config1:
+ print(f"Warning !: Mismatch in keys {meta_keys0} and {meta_keys1}.")
+ return True
+ return False
+
+ def _remove_meta_keys(self, config_dict: Dict):
+ meta_keys = []
+ temp_dict = config_dict.copy()
+ for key in config_dict.keys():
+ if key.startswith("_"):
+ temp_dict.pop(key)
+ meta_keys.append(key)
+ return (temp_dict, meta_keys)
+
+ @torch.no_grad()
+ def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike]], **kwargs):
+ """
+ Returns a new pipeline object of the class 'DiffusionPipeline' with the merged checkpoints(weights) of the models passed
+ in the argument 'pretrained_model_name_or_path_list' as a list.
+
+ Parameters:
+ -----------
+ pretrained_model_name_or_path_list : A list of valid pretrained model names in the HuggingFace hub or paths to locally stored models in the HuggingFace format.
+
+ **kwargs:
+ Supports all the default DiffusionPipeline.get_config_dict kwargs viz..
+
+ cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map.
+
+ alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
+ would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
+
+ interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_diff" and None.
+ Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_diff" is supported.
+
+ force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
+
+ """
+ # Default kwargs from DiffusionPipeline
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ resume_download = kwargs.pop("resume_download", False)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ revision = kwargs.pop("revision", None)
+ torch_dtype = kwargs.pop("torch_dtype", None)
+ device_map = kwargs.pop("device_map", None)
+
+ alpha = kwargs.pop("alpha", 0.5)
+ interp = kwargs.pop("interp", None)
+
+ print("Received list", pretrained_model_name_or_path_list)
+ print(f"Combining with alpha={alpha}, interpolation mode={interp}")
+
+ checkpoint_count = len(pretrained_model_name_or_path_list)
+        # Ignore result from model_index_json comparison of the two checkpoints
+ force = kwargs.pop("force", False)
+
+ # If less than 2 checkpoints, nothing to merge. If more than 3, not supported for now.
+ if checkpoint_count > 3 or checkpoint_count < 2:
+ raise ValueError(
+ "Received incorrect number of checkpoints to merge. Ensure that either 2 or 3 checkpoints are being"
+ " passed."
+ )
+
+ print("Received the right number of checkpoints")
+ # chkpt0, chkpt1 = pretrained_model_name_or_path_list[0:2]
+ # chkpt2 = pretrained_model_name_or_path_list[2] if checkpoint_count == 3 else None
+
+ # Validate that the checkpoints can be merged
+ # Step 1: Load the model config and compare the checkpoints. We'll compare the model_index.json first while ignoring the keys starting with '_'
+ config_dicts = []
+ for pretrained_model_name_or_path in pretrained_model_name_or_path_list:
+ config_dict = DiffusionPipeline.load_config(
+ pretrained_model_name_or_path,
+ cache_dir=cache_dir,
+ resume_download=resume_download,
+ force_download=force_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ )
+ config_dicts.append(config_dict)
+
+ comparison_result = True
+ for idx in range(1, len(config_dicts)):
+ comparison_result &= self._compare_model_configs(config_dicts[idx - 1], config_dicts[idx])
+ if not force and comparison_result is False:
+ raise ValueError("Incompatible checkpoints. Please check model_index.json for the models.")
+ print(config_dicts[0], config_dicts[1])
+ print("Compatible model_index.json files found")
+ # Step 2: Basic Validation has succeeded. Let's download the models and save them into our local files.
+ cached_folders = []
+ for pretrained_model_name_or_path, config_dict in zip(pretrained_model_name_or_path_list, config_dicts):
+ folder_names = [k for k in config_dict.keys() if not k.startswith("_")]
+ allow_patterns = [os.path.join(k, "*") for k in folder_names]
+ allow_patterns += [
+ WEIGHTS_NAME,
+ SCHEDULER_CONFIG_NAME,
+ CONFIG_NAME,
+ ONNX_WEIGHTS_NAME,
+ DiffusionPipeline.config_name,
+ ]
+ requested_pipeline_class = config_dict.get("_class_name")
+ user_agent = {"diffusers": __version__, "pipeline_class": requested_pipeline_class}
+
+ cached_folder = (
+ pretrained_model_name_or_path
+ if os.path.isdir(pretrained_model_name_or_path)
+ else snapshot_download(
+ pretrained_model_name_or_path,
+ cache_dir=cache_dir,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ allow_patterns=allow_patterns,
+ user_agent=user_agent,
+ )
+ )
+ print("Cached Folder", cached_folder)
+ cached_folders.append(cached_folder)
+
+ # Step 3:-
+ # Load the first checkpoint as a diffusion pipeline and modify its module state_dict in place
+ final_pipe = DiffusionPipeline.from_pretrained(
+ cached_folders[0], torch_dtype=torch_dtype, device_map=device_map
+ )
+ final_pipe.to(self.device)
+
+ checkpoint_path_2 = None
+ if len(cached_folders) > 2:
+ checkpoint_path_2 = os.path.join(cached_folders[2])
+
+ if interp == "sigmoid":
+ theta_func = CheckpointMergerPipeline.sigmoid
+ elif interp == "inv_sigmoid":
+ theta_func = CheckpointMergerPipeline.inv_sigmoid
+ elif interp == "add_diff":
+ theta_func = CheckpointMergerPipeline.add_difference
+ else:
+ theta_func = CheckpointMergerPipeline.weighted_sum
+
+ # Find each module's state dict.
+ for attr in final_pipe.config.keys():
+ if not attr.startswith("_"):
+ checkpoint_path_1 = os.path.join(cached_folders[1], attr)
+ if os.path.exists(checkpoint_path_1):
+ files = [
+ *glob.glob(os.path.join(checkpoint_path_1, "*.safetensors")),
+ *glob.glob(os.path.join(checkpoint_path_1, "*.bin")),
+ ]
+ checkpoint_path_1 = files[0] if len(files) > 0 else None
+ if len(cached_folders) < 3:
+ checkpoint_path_2 = None
+ else:
+ checkpoint_path_2 = os.path.join(cached_folders[2], attr)
+ if os.path.exists(checkpoint_path_2):
+ files = [
+ *glob.glob(os.path.join(checkpoint_path_2, "*.safetensors")),
+ *glob.glob(os.path.join(checkpoint_path_2, "*.bin")),
+ ]
+ checkpoint_path_2 = files[0] if len(files) > 0 else None
+ # For an attr if both checkpoint_path_1 and 2 are None, ignore.
+ # If atleast one is present, deal with it according to interp method, of course only if the state_dict keys match.
+ if checkpoint_path_1 is None and checkpoint_path_2 is None:
+ print(f"Skipping {attr}: not present in 2nd or 3d model")
+ continue
+ try:
+ module = getattr(final_pipe, attr)
+ if isinstance(module, bool): # ignore requires_safety_checker boolean
+ continue
+ theta_0 = getattr(module, "state_dict")
+ theta_0 = theta_0()
+
+ update_theta_0 = getattr(module, "load_state_dict")
+ theta_1 = (
+ safetensors.torch.load_file(checkpoint_path_1)
+ if (checkpoint_path_1.endswith(".safetensors"))
+ else torch.load(checkpoint_path_1, map_location="cpu")
+ )
+ theta_2 = None
+ if checkpoint_path_2:
+ theta_2 = (
+ safetensors.torch.load_file(checkpoint_path_2)
+ if (checkpoint_path_2.endswith(".safetensors"))
+ else torch.load(checkpoint_path_2, map_location="cpu")
+ )
+
+ if not theta_0.keys() == theta_1.keys():
+ print(f"Skipping {attr}: key mismatch")
+ continue
+ if theta_2 and not theta_1.keys() == theta_2.keys():
+ print(f"Skipping {attr}:y mismatch")
+ except Exception as e:
+ print(f"Skipping {attr} do to an unexpected error: {str(e)}")
+ continue
+ print(f"MERGING {attr}")
+
+ for key in theta_0.keys():
+ if theta_2:
+ theta_0[key] = theta_func(theta_0[key], theta_1[key], theta_2[key], alpha)
+ else:
+ theta_0[key] = theta_func(theta_0[key], theta_1[key], None, alpha)
+
+ del theta_1
+ del theta_2
+ update_theta_0(theta_0)
+
+ del theta_0
+ return final_pipe
+
+ @staticmethod
+ def weighted_sum(theta0, theta1, theta2, alpha):
+ return ((1 - alpha) * theta0) + (alpha * theta1)
+
+ # Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
+ @staticmethod
+ def sigmoid(theta0, theta1, theta2, alpha):
+ alpha = alpha * alpha * (3 - (2 * alpha))
+ return theta0 + ((theta1 - theta0) * alpha)
+
+ # Inverse Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
+ @staticmethod
+ def inv_sigmoid(theta0, theta1, theta2, alpha):
+ import math
+
+ alpha = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0)
+ return theta0 + ((theta1 - theta0) * alpha)
+
+ @staticmethod
+ def add_difference(theta0, theta1, theta2, alpha):
+ return theta0 + (theta1 - theta2) * (1.0 - alpha)
diff --git a/diffusers/examples/community/clip_guided_images_mixing_stable_diffusion.py b/diffusers/examples/community/clip_guided_images_mixing_stable_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..399f5b14506dec70c539459f0af9c95b6ee58a6f
--- /dev/null
+++ b/diffusers/examples/community/clip_guided_images_mixing_stable_diffusion.py
@@ -0,0 +1,455 @@
+# -*- coding: utf-8 -*-
+import inspect
+from typing import Optional, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from torch.nn import functional as F
+from torchvision import transforms
+from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer
+
+from diffusers import (
+ AutoencoderKL,
+ DDIMScheduler,
+ DiffusionPipeline,
+ DPMSolverMultistepScheduler,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+ UNet2DConditionModel,
+)
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.utils import PIL_INTERPOLATION
+from diffusers.utils.torch_utils import randn_tensor
+
+
+def preprocess(image, w, h):
+ if isinstance(image, torch.Tensor):
+ return image
+ elif isinstance(image, PIL.Image.Image):
+ image = [image]
+
+ if isinstance(image[0], PIL.Image.Image):
+ image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image.transpose(0, 3, 1, 2)
+ image = 2.0 * image - 1.0
+ image = torch.from_numpy(image)
+ elif isinstance(image[0], torch.Tensor):
+ image = torch.cat(image, dim=0)
+ return image
+
+
+def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
+    inputs_are_torch = False
+    if not isinstance(v0, np.ndarray):
+        inputs_are_torch = True
+        input_device = v0.device
+        v0 = v0.cpu().numpy()
+        v1 = v1.cpu().numpy()
+
+ dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
+ if np.abs(dot) > DOT_THRESHOLD:
+ v2 = (1 - t) * v0 + t * v1
+ else:
+ theta_0 = np.arccos(dot)
+ sin_theta_0 = np.sin(theta_0)
+ theta_t = theta_0 * t
+ sin_theta_t = np.sin(theta_t)
+ s0 = np.sin(theta_0 - theta_t) / sin_theta_0
+ s1 = sin_theta_t / sin_theta_0
+ v2 = s0 * v0 + s1 * v1
+
+ if inputs_are_torch:
+ v2 = torch.from_numpy(v2).to(input_device)
+
+ return v2
+
+
+def spherical_dist_loss(x, y):
+ x = F.normalize(x, dim=-1)
+ y = F.normalize(y, dim=-1)
+ return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
+
+
+def set_requires_grad(model, value):
+ for param in model.parameters():
+ param.requires_grad = value
+
+
+class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline):
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ clip_model: CLIPModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler, DPMSolverMultistepScheduler],
+ feature_extractor: CLIPFeatureExtractor,
+ coca_model=None,
+ coca_tokenizer=None,
+ coca_transform=None,
+ ):
+ super().__init__()
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ clip_model=clip_model,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ feature_extractor=feature_extractor,
+ coca_model=coca_model,
+ coca_tokenizer=coca_tokenizer,
+ coca_transform=coca_transform,
+ )
+ self.feature_extractor_size = (
+ feature_extractor.size
+ if isinstance(feature_extractor.size, int)
+ else feature_extractor.size["shortest_edge"]
+ )
+ self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
+ set_requires_grad(self.text_encoder, False)
+ set_requires_grad(self.clip_model, False)
+
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = self.unet.config.attention_head_dim // 2
+ self.unet.set_attention_slice(slice_size)
+
+ def disable_attention_slicing(self):
+ self.enable_attention_slicing(None)
+
+ def freeze_vae(self):
+ set_requires_grad(self.vae, False)
+
+ def unfreeze_vae(self):
+ set_requires_grad(self.vae, True)
+
+ def freeze_unet(self):
+ set_requires_grad(self.unet, False)
+
+ def unfreeze_unet(self):
+ set_requires_grad(self.unet, True)
+
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start:]
+
+ return timesteps, num_inference_steps - t_start
+
+ def prepare_latents(self, image, timestep, batch_size, dtype, device, generator=None):
+ if not isinstance(image, torch.Tensor):
+ raise ValueError(f"`image` has to be of type `torch.Tensor` but is {type(image)}")
+
+ image = image.to(device=device, dtype=dtype)
+
+ if isinstance(generator, list):
+ init_latents = [
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+ ]
+ init_latents = torch.cat(init_latents, dim=0)
+ else:
+ init_latents = self.vae.encode(image).latent_dist.sample(generator)
+
+ # Hardcode 0.18215 because stable-diffusion-2-base has not self.vae.config.scaling_factor
+ init_latents = 0.18215 * init_latents
+ init_latents = init_latents.repeat_interleave(batch_size, dim=0)
+
+ noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype)
+
+ # get latents
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+ latents = init_latents
+
+ return latents
+
+ def get_image_description(self, image):
+ transformed_image = self.coca_transform(image).unsqueeze(0)
+ with torch.no_grad(), torch.cuda.amp.autocast():
+ generated = self.coca_model.generate(transformed_image.to(device=self.device, dtype=self.coca_model.dtype))
+ generated = self.coca_tokenizer.decode(generated[0].cpu().numpy())
+ return generated.split("<end_of_text>")[0].replace("<start_of_text>", "").rstrip(" .,")
+
+ def get_clip_image_embeddings(self, image, batch_size):
+ clip_image_input = self.feature_extractor.preprocess(image)
+ clip_image_features = torch.from_numpy(clip_image_input["pixel_values"][0]).unsqueeze(0).to(self.device).half()
+ image_embeddings_clip = self.clip_model.get_image_features(clip_image_features)
+ image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
+ image_embeddings_clip = image_embeddings_clip.repeat_interleave(batch_size, dim=0)
+ return image_embeddings_clip
+
+ @torch.enable_grad()
+ def cond_fn(
+ self,
+ latents,
+ timestep,
+ index,
+ text_embeddings,
+ noise_pred_original,
+ original_image_embeddings_clip,
+ clip_guidance_scale,
+ ):
+ latents = latents.detach().requires_grad_()
+
+ latent_model_input = self.scheduler.scale_model_input(latents, timestep)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample
+
+ if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler, DPMSolverMultistepScheduler)):
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
+ beta_prod_t = 1 - alpha_prod_t
+ # compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)
+
+ fac = torch.sqrt(beta_prod_t)
+ sample = pred_original_sample * (fac) + latents * (1 - fac)
+ elif isinstance(self.scheduler, LMSDiscreteScheduler):
+ sigma = self.scheduler.sigmas[index]
+ sample = latents - sigma * noise_pred
+ else:
+ raise ValueError(f"scheduler type {type(self.scheduler)} not supported")
+
+ # Hardcode 0.18215 because stable-diffusion-2-base does not have self.vae.config.scaling_factor
+ sample = 1 / 0.18215 * sample
+ image = self.vae.decode(sample).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+
+ image = transforms.Resize(self.feature_extractor_size)(image)
+ image = self.normalize(image).to(latents.dtype)
+
+ image_embeddings_clip = self.clip_model.get_image_features(image)
+ image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
+
+ loss = spherical_dist_loss(image_embeddings_clip, original_image_embeddings_clip).mean() * clip_guidance_scale
+
+ grads = -torch.autograd.grad(loss, latents)[0]
+
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ latents = latents.detach() + grads * (sigma**2)
+ noise_pred = noise_pred_original
+ else:
+ noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads
+ return noise_pred, latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ style_image: Union[torch.FloatTensor, PIL.Image.Image],
+ content_image: Union[torch.FloatTensor, PIL.Image.Image],
+ style_prompt: Optional[str] = None,
+ content_prompt: Optional[str] = None,
+ height: Optional[int] = 512,
+ width: Optional[int] = 512,
+ noise_strength: float = 0.6,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ batch_size: Optional[int] = 1,
+ eta: float = 0.0,
+ clip_guidance_scale: Optional[float] = 100,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ slerp_latent_style_strength: float = 0.8,
+ slerp_prompt_style_strength: float = 0.1,
+ slerp_clip_image_style_strength: float = 0.1,
+ ):
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(f"You have passed {batch_size} batch_size, but only {len(generator)} generators.")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if isinstance(generator, torch.Generator) and batch_size > 1:
+ generator = [generator] + [None] * (batch_size - 1)
+
+ coca_is_none = [
+ ("model", self.coca_model is None),
+ ("tokenizer", self.coca_tokenizer is None),
+ ("transform", self.coca_transform is None),
+ ]
+ coca_is_none = [x[0] for x in coca_is_none if x[1]]
+ coca_is_none_str = ", ".join(coca_is_none)
+ # generate prompts with coca model if prompt is None
+ if content_prompt is None:
+ if len(coca_is_none):
+ raise ValueError(
+ f"Content prompt is None and CoCa [{coca_is_none_str}] is None."
+ f"Set prompt or pass Coca [{coca_is_none_str}] to DiffusionPipeline."
+ )
+ content_prompt = self.get_image_description(content_image)
+ if style_prompt is None:
+ if len(coca_is_none):
+ raise ValueError(
+ f"Style prompt is None and CoCa [{coca_is_none_str}] is None."
+ f" Set prompt or pass Coca [{coca_is_none_str}] to DiffusionPipeline."
+ )
+ style_prompt = self.get_image_description(style_image)
+
+ # get prompt text embeddings for content and style
+ content_text_input = self.tokenizer(
+ content_prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ content_text_embeddings = self.text_encoder(content_text_input.input_ids.to(self.device))[0]
+
+ style_text_input = self.tokenizer(
+ style_prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ style_text_embeddings = self.text_encoder(style_text_input.input_ids.to(self.device))[0]
+
+ text_embeddings = slerp(slerp_prompt_style_strength, content_text_embeddings, style_text_embeddings)
+
+ # duplicate text embeddings for each generation per prompt
+ text_embeddings = text_embeddings.repeat_interleave(batch_size, dim=0)
+
+ # set timesteps
+ accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
+ extra_set_kwargs = {}
+ if accepts_offset:
+ extra_set_kwargs["offset"] = 1
+
+ self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
+ # Some schedulers like PNDM have timesteps as arrays
+ # It's more optimized to move all timesteps to the correct device beforehand
+ self.scheduler.timesteps = self.scheduler.timesteps.to(self.device)
+
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, noise_strength, self.device)
+ latent_timestep = timesteps[:1].repeat(batch_size)
+
+ # Preprocess image
+ preprocessed_content_image = preprocess(content_image, width, height)
+ content_latents = self.prepare_latents(
+ preprocessed_content_image, latent_timestep, batch_size, text_embeddings.dtype, self.device, generator
+ )
+
+ preprocessed_style_image = preprocess(style_image, width, height)
+ style_latents = self.prepare_latents(
+ preprocessed_style_image, latent_timestep, batch_size, text_embeddings.dtype, self.device, generator
+ )
+
+ latents = slerp(slerp_latent_style_strength, content_latents, style_latents)
+
+ if clip_guidance_scale > 0:
+ content_clip_image_embedding = self.get_clip_image_embeddings(content_image, batch_size)
+ style_clip_image_embedding = self.get_clip_image_embeddings(style_image, batch_size)
+ clip_image_embeddings = slerp(
+ slerp_clip_image_style_strength, content_clip_image_embedding, style_clip_image_embedding
+ )
+
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ max_length = content_text_input.input_ids.shape[-1]
+ uncond_input = self.tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt")
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+ # duplicate unconditional embeddings for each generation per prompt
+ uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size, dim=0)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ # get the initial random noise unless the user supplied it
+
+ # Unlike in other pipelines, latents need to be generated in the target device
+ # for 1-to-1 results reproducibility with the CompVis implementation.
+ # However this currently doesn't work in `mps`.
+ latents_shape = (batch_size, self.unet.config.in_channels, height // 8, width // 8)
+ latents_dtype = text_embeddings.dtype
+ if latents is None:
+ if self.device.type == "mps":
+ # randn does not work reproducibly on mps
+ latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+ self.device
+ )
+ else:
+ latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
+ else:
+ if latents.shape != latents_shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+ latents = latents.to(self.device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform classifier free guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # perform clip guidance
+ if clip_guidance_scale > 0:
+ text_embeddings_for_guidance = (
+ text_embeddings.chunk(2)[1] if do_classifier_free_guidance else text_embeddings
+ )
+ noise_pred, latents = self.cond_fn(
+ latents,
+ t,
+ i,
+ text_embeddings_for_guidance,
+ noise_pred,
+ clip_image_embeddings,
+ clip_guidance_scale,
+ )
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ progress_bar.update()
+ # Hardcode 0.18215 because stable-diffusion-2-base does not have self.vae.config.scaling_factor
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, None)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)
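
The pipeline above mixes a content image and a style image: both are encoded to latents, and the latents, the prompt embeddings, and the CLIP image embeddings are each interpolated with `slerp` before denoising, while `cond_fn` applies an extra CLIP-guidance gradient at every step. Below is a minimal usage sketch, assuming the file is loadable as the community pipeline `clip_guided_images_mixing_stable_diffusion` and that the model ids and image files shown are available; they are illustrative and not taken from this diff.

```python
# Minimal sketch (assumed: custom_pipeline name, model ids, and local image files).
import torch
from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPModel
from diffusers import DiffusionPipeline

clip_model = CLIPModel.from_pretrained(
    "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.float16
)
feature_extractor = CLIPFeatureExtractor.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K")

pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="clip_guided_images_mixing_stable_diffusion",  # assumed name
    clip_model=clip_model,
    feature_extractor=feature_extractor,
    torch_dtype=torch.float16,
).to("cuda")

content = Image.open("content.jpg").convert("RGB").resize((512, 512))  # hypothetical inputs
style = Image.open("style.jpg").convert("RGB").resize((512, 512))

# CoCa components are left as None here, so both prompts are passed explicitly.
image = pipe(
    content_image=content,
    style_image=style,
    content_prompt="a photo of a mountain lake at sunrise",
    style_prompt="an oil painting with thick brush strokes",
    noise_strength=0.6,
    clip_guidance_scale=100,
    slerp_latent_style_strength=0.8,
    slerp_prompt_style_strength=0.1,
    slerp_clip_image_style_strength=0.1,
).images[0]
image.save("mixed.png")
```

If CoCa model, tokenizer, and transform are registered instead, the prompts can be omitted and the pipeline captions the images itself via `get_image_description`.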
diff --git a/diffusers/examples/community/clip_guided_stable_diffusion.py b/diffusers/examples/community/clip_guided_stable_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f4ab2ab9f4ad2417d6dbf40e1fd2e479df88b73
--- /dev/null
+++ b/diffusers/examples/community/clip_guided_stable_diffusion.py
@@ -0,0 +1,347 @@
+import inspect
+from typing import List, Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torchvision import transforms
+from transformers import CLIPImageProcessor, CLIPModel, CLIPTextModel, CLIPTokenizer
+
+from diffusers import (
+ AutoencoderKL,
+ DDIMScheduler,
+ DiffusionPipeline,
+ DPMSolverMultistepScheduler,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+ UNet2DConditionModel,
+)
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
+
+
+class MakeCutouts(nn.Module):
+ def __init__(self, cut_size, cut_power=1.0):
+ super().__init__()
+
+ self.cut_size = cut_size
+ self.cut_power = cut_power
+
+ def forward(self, pixel_values, num_cutouts):
+ sideY, sideX = pixel_values.shape[2:4]
+ max_size = min(sideX, sideY)
+ min_size = min(sideX, sideY, self.cut_size)
+ cutouts = []
+ for _ in range(num_cutouts):
+ size = int(torch.rand([]) ** self.cut_power * (max_size - min_size) + min_size)
+ offsetx = torch.randint(0, sideX - size + 1, ())
+ offsety = torch.randint(0, sideY - size + 1, ())
+ cutout = pixel_values[:, :, offsety : offsety + size, offsetx : offsetx + size]
+ cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
+ return torch.cat(cutouts)
+
+
+def spherical_dist_loss(x, y):
+ x = F.normalize(x, dim=-1)
+ y = F.normalize(y, dim=-1)
+ return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
+
+
+def set_requires_grad(model, value):
+ for param in model.parameters():
+ param.requires_grad = value
+
+
+class CLIPGuidedStableDiffusion(DiffusionPipeline):
+ """CLIP guided stable diffusion based on the amazing repo by @crowsonkb and @Jack000
+ - https://github.com/Jack000/glid-3-xl
+ - https://github.dev/crowsonkb/k-diffusion
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ clip_model: CLIPModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler, DPMSolverMultistepScheduler],
+ feature_extractor: CLIPImageProcessor,
+ ):
+ super().__init__()
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ clip_model=clip_model,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ feature_extractor=feature_extractor,
+ )
+
+ self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
+ self.cut_out_size = (
+ feature_extractor.size
+ if isinstance(feature_extractor.size, int)
+ else feature_extractor.size["shortest_edge"]
+ )
+ self.make_cutouts = MakeCutouts(self.cut_out_size)
+
+ set_requires_grad(self.text_encoder, False)
+ set_requires_grad(self.clip_model, False)
+
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = self.unet.config.attention_head_dim // 2
+ self.unet.set_attention_slice(slice_size)
+
+ def disable_attention_slicing(self):
+ self.enable_attention_slicing(None)
+
+ def freeze_vae(self):
+ set_requires_grad(self.vae, False)
+
+ def unfreeze_vae(self):
+ set_requires_grad(self.vae, True)
+
+ def freeze_unet(self):
+ set_requires_grad(self.unet, False)
+
+ def unfreeze_unet(self):
+ set_requires_grad(self.unet, True)
+
+ @torch.enable_grad()
+ def cond_fn(
+ self,
+ latents,
+ timestep,
+ index,
+ text_embeddings,
+ noise_pred_original,
+ text_embeddings_clip,
+ clip_guidance_scale,
+ num_cutouts,
+ use_cutouts=True,
+ ):
+ latents = latents.detach().requires_grad_()
+
+ latent_model_input = self.scheduler.scale_model_input(latents, timestep)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample
+
+ if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler, DPMSolverMultistepScheduler)):
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
+ beta_prod_t = 1 - alpha_prod_t
+ # compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)
+
+ fac = torch.sqrt(beta_prod_t)
+ sample = pred_original_sample * (fac) + latents * (1 - fac)
+ elif isinstance(self.scheduler, LMSDiscreteScheduler):
+ sigma = self.scheduler.sigmas[index]
+ sample = latents - sigma * noise_pred
+ else:
+ raise ValueError(f"scheduler type {type(self.scheduler)} not supported")
+
+ sample = 1 / self.vae.config.scaling_factor * sample
+ image = self.vae.decode(sample).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+
+ if use_cutouts:
+ image = self.make_cutouts(image, num_cutouts)
+ else:
+ image = transforms.Resize(self.cut_out_size)(image)
+ image = self.normalize(image).to(latents.dtype)
+
+ image_embeddings_clip = self.clip_model.get_image_features(image)
+ image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
+
+ if use_cutouts:
+ dists = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip)
+ dists = dists.view([num_cutouts, sample.shape[0], -1])
+ loss = dists.sum(2).mean(0).sum() * clip_guidance_scale
+ else:
+ loss = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip).mean() * clip_guidance_scale
+
+ grads = -torch.autograd.grad(loss, latents)[0]
+
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ latents = latents.detach() + grads * (sigma**2)
+ noise_pred = noise_pred_original
+ else:
+ noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads
+ return noise_pred, latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ height: Optional[int] = 512,
+ width: Optional[int] = 512,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ clip_guidance_scale: Optional[float] = 100,
+ clip_prompt: Optional[Union[str, List[str]]] = None,
+ num_cutouts: Optional[int] = 4,
+ use_cutouts: Optional[bool] = True,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ ):
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ # get prompt text embeddings
+ text_input = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+ # duplicate text embeddings for each generation per prompt
+ text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+
+ if clip_guidance_scale > 0:
+ if clip_prompt is not None:
+ clip_text_input = self.tokenizer(
+ clip_prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ ).input_ids.to(self.device)
+ else:
+ clip_text_input = text_input.input_ids.to(self.device)
+ text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
+ text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
+ # duplicate text embeddings clip for each generation per prompt
+ text_embeddings_clip = text_embeddings_clip.repeat_interleave(num_images_per_prompt, dim=0)
+
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ max_length = text_input.input_ids.shape[-1]
+ uncond_input = self.tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt")
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+ # duplicate unconditional embeddings for each generation per prompt
+ uncond_embeddings = uncond_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ # get the initial random noise unless the user supplied it
+
+ # Unlike in other pipelines, latents need to be generated in the target device
+ # for 1-to-1 results reproducibility with the CompVis implementation.
+ # However this currently doesn't work in `mps`.
+ latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
+ latents_dtype = text_embeddings.dtype
+ if latents is None:
+ if self.device.type == "mps":
+ # randn does not work reproducibly on mps
+ latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+ self.device
+ )
+ else:
+ latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
+ else:
+ if latents.shape != latents_shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+ latents = latents.to(self.device)
+
+ # set timesteps
+ accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
+ extra_set_kwargs = {}
+ if accepts_offset:
+ extra_set_kwargs["offset"] = 1
+
+ self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
+
+ # Some schedulers like PNDM have timesteps as arrays
+ # It's more optimized to move all timesteps to correct device beforehand
+ timesteps_tensor = self.scheduler.timesteps.to(self.device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+
+ for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform classifier free guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # perform clip guidance
+ if clip_guidance_scale > 0:
+ text_embeddings_for_guidance = (
+ text_embeddings.chunk(2)[1] if do_classifier_free_guidance else text_embeddings
+ )
+ noise_pred, latents = self.cond_fn(
+ latents,
+ t,
+ i,
+ text_embeddings_for_guidance,
+ noise_pred,
+ text_embeddings_clip,
+ clip_guidance_scale,
+ num_cutouts,
+ use_cutouts,
+ )
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # scale and decode the image latents with vae
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, None)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)
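
`clip_guided_stable_diffusion.py` is the text-to-image variant: besides classifier-free guidance, `cond_fn` decodes the current x_0 prediction, embeds it with CLIP (optionally through random cutouts), and subtracts the gradient of the spherical distance to the CLIP text embedding from the predicted noise. A minimal sketch, assuming the `custom_pipeline` name and model ids below resolve as written:

```python
# Minimal text-to-image sketch for CLIPGuidedStableDiffusion (ids and name are assumptions).
import torch
from transformers import CLIPImageProcessor, CLIPModel
from diffusers import DiffusionPipeline

clip_model = CLIPModel.from_pretrained(
    "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.float16
)
feature_extractor = CLIPImageProcessor.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K")

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    custom_pipeline="clip_guided_stable_diffusion",  # assumed name
    clip_model=clip_model,
    feature_extractor=feature_extractor,
    torch_dtype=torch.float16,
).to("cuda")
pipe.enable_attention_slicing()

image = pipe(
    prompt="fantasy forest landscape, golden hour, highly detailed digital painting",
    clip_prompt="fantasy forest landscape",  # optional separate prompt for CLIP guidance
    num_inference_steps=50,
    guidance_scale=7.5,
    clip_guidance_scale=100,
    num_cutouts=4,
    use_cutouts=False,
).images[0]
image.save("clip_guided.png")
```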
diff --git a/diffusers/examples/community/clip_guided_stable_diffusion_img2img.py b/diffusers/examples/community/clip_guided_stable_diffusion_img2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..2dbc9bef9ffebdbd2324d44df9198450d4f270ae
--- /dev/null
+++ b/diffusers/examples/community/clip_guided_stable_diffusion_img2img.py
@@ -0,0 +1,493 @@
+import inspect
+from typing import List, Optional, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torchvision import transforms
+from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer
+
+from diffusers import (
+ AutoencoderKL,
+ DDIMScheduler,
+ DiffusionPipeline,
+ DPMSolverMultistepScheduler,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+ UNet2DConditionModel,
+)
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.utils import PIL_INTERPOLATION, deprecate
+from diffusers.utils.torch_utils import randn_tensor
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```
+ from io import BytesIO
+
+ import requests
+ import torch
+ from diffusers import DiffusionPipeline
+ from PIL import Image
+ from transformers import CLIPFeatureExtractor, CLIPModel
+
+ feature_extractor = CLIPFeatureExtractor.from_pretrained(
+ "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
+ )
+ clip_model = CLIPModel.from_pretrained(
+ "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.float16
+ )
+
+
+ guided_pipeline = DiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ custom_pipeline="clip_guided_stable_diffusion_img2img",
+ clip_model=clip_model,
+ feature_extractor=feature_extractor,
+ torch_dtype=torch.float16,
+ )
+ guided_pipeline.enable_attention_slicing()
+ guided_pipeline = guided_pipeline.to("cuda")
+
+ prompt = "fantasy book cover, full moon, fantasy forest landscape, golden vector elements, fantasy magic, dark light night, intricate, elegant, sharp focus, illustration, highly detailed, digital painting, concept art, matte, art by WLOP and Artgerm and Albert Bierstadt, masterpiece"
+
+ url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+
+ response = requests.get(url)
+ init_image = Image.open(BytesIO(response.content)).convert("RGB")
+
+ image = guided_pipeline(
+ prompt=prompt,
+ num_inference_steps=30,
+ image=init_image,
+ strength=0.75,
+ guidance_scale=7.5,
+ clip_guidance_scale=100,
+ num_cutouts=4,
+ use_cutouts=False,
+ ).images[0]
+ display(image)
+ ```
+"""
+
+
+def preprocess(image, w, h):
+ if isinstance(image, torch.Tensor):
+ return image
+ elif isinstance(image, PIL.Image.Image):
+ image = [image]
+
+ if isinstance(image[0], PIL.Image.Image):
+ image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image.transpose(0, 3, 1, 2)
+ image = 2.0 * image - 1.0
+ image = torch.from_numpy(image)
+ elif isinstance(image[0], torch.Tensor):
+ image = torch.cat(image, dim=0)
+ return image
+
+
+class MakeCutouts(nn.Module):
+ def __init__(self, cut_size, cut_power=1.0):
+ super().__init__()
+
+ self.cut_size = cut_size
+ self.cut_power = cut_power
+
+ def forward(self, pixel_values, num_cutouts):
+ sideY, sideX = pixel_values.shape[2:4]
+ max_size = min(sideX, sideY)
+ min_size = min(sideX, sideY, self.cut_size)
+ cutouts = []
+ for _ in range(num_cutouts):
+ size = int(torch.rand([]) ** self.cut_power * (max_size - min_size) + min_size)
+ offsetx = torch.randint(0, sideX - size + 1, ())
+ offsety = torch.randint(0, sideY - size + 1, ())
+ cutout = pixel_values[:, :, offsety : offsety + size, offsetx : offsetx + size]
+ cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
+ return torch.cat(cutouts)
+
+
+def spherical_dist_loss(x, y):
+ x = F.normalize(x, dim=-1)
+ y = F.normalize(y, dim=-1)
+ return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
+
+
+def set_requires_grad(model, value):
+ for param in model.parameters():
+ param.requires_grad = value
+
+
+class CLIPGuidedStableDiffusion(DiffusionPipeline):
+ """CLIP guided stable diffusion based on the amazing repo by @crowsonkb and @Jack000
+ - https://github.com/Jack000/glid-3-xl
+ - https://github.dev/crowsonkb/k-diffusion
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ clip_model: CLIPModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler, DPMSolverMultistepScheduler],
+ feature_extractor: CLIPFeatureExtractor,
+ ):
+ super().__init__()
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ clip_model=clip_model,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ feature_extractor=feature_extractor,
+ )
+
+ self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
+ self.cut_out_size = (
+ feature_extractor.size
+ if isinstance(feature_extractor.size, int)
+ else feature_extractor.size["shortest_edge"]
+ )
+ self.make_cutouts = MakeCutouts(self.cut_out_size)
+
+ set_requires_grad(self.text_encoder, False)
+ set_requires_grad(self.clip_model, False)
+
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = self.unet.config.attention_head_dim // 2
+ self.unet.set_attention_slice(slice_size)
+
+ def disable_attention_slicing(self):
+ self.enable_attention_slicing(None)
+
+ def freeze_vae(self):
+ set_requires_grad(self.vae, False)
+
+ def unfreeze_vae(self):
+ set_requires_grad(self.vae, True)
+
+ def freeze_unet(self):
+ set_requires_grad(self.unet, False)
+
+ def unfreeze_unet(self):
+ set_requires_grad(self.unet, True)
+
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start:]
+
+ return timesteps, num_inference_steps - t_start
+
+ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+
+ image = image.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if isinstance(generator, list):
+ init_latents = [
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+ ]
+ init_latents = torch.cat(init_latents, dim=0)
+ else:
+ init_latents = self.vae.encode(image).latent_dist.sample(generator)
+
+ init_latents = self.vae.config.scaling_factor * init_latents
+
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ deprecation_message = (
+ f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
+ " your script to pass as many initial images as text prompts to suppress this warning."
+ )
+ deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ init_latents = torch.cat([init_latents], dim=0)
+
+ shape = init_latents.shape
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ # get latents
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+ latents = init_latents
+
+ return latents
+
+ @torch.enable_grad()
+ def cond_fn(
+ self,
+ latents,
+ timestep,
+ index,
+ text_embeddings,
+ noise_pred_original,
+ text_embeddings_clip,
+ clip_guidance_scale,
+ num_cutouts,
+ use_cutouts=True,
+ ):
+ latents = latents.detach().requires_grad_()
+
+ latent_model_input = self.scheduler.scale_model_input(latents, timestep)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample
+
+ if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler, DPMSolverMultistepScheduler)):
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
+ beta_prod_t = 1 - alpha_prod_t
+ # compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)
+
+ fac = torch.sqrt(beta_prod_t)
+ sample = pred_original_sample * (fac) + latents * (1 - fac)
+ elif isinstance(self.scheduler, LMSDiscreteScheduler):
+ sigma = self.scheduler.sigmas[index]
+ sample = latents - sigma * noise_pred
+ else:
+ raise ValueError(f"scheduler type {type(self.scheduler)} not supported")
+
+ sample = 1 / self.vae.config.scaling_factor * sample
+ image = self.vae.decode(sample).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+
+ if use_cutouts:
+ image = self.make_cutouts(image, num_cutouts)
+ else:
+ image = transforms.Resize(self.cut_out_size)(image)
+ image = self.normalize(image).to(latents.dtype)
+
+ image_embeddings_clip = self.clip_model.get_image_features(image)
+ image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
+
+ if use_cutouts:
+ dists = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip)
+ dists = dists.view([num_cutouts, sample.shape[0], -1])
+ loss = dists.sum(2).mean(0).sum() * clip_guidance_scale
+ else:
+ loss = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip).mean() * clip_guidance_scale
+
+ grads = -torch.autograd.grad(loss, latents)[0]
+
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ latents = latents.detach() + grads * (sigma**2)
+ noise_pred = noise_pred_original
+ else:
+ noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads
+ return noise_pred, latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ height: Optional[int] = 512,
+ width: Optional[int] = 512,
+ image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+ strength: float = 0.8,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ clip_guidance_scale: Optional[float] = 100,
+ clip_prompt: Optional[Union[str, List[str]]] = None,
+ num_cutouts: Optional[int] = 4,
+ use_cutouts: Optional[bool] = True,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ ):
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ # get prompt text embeddings
+ text_input = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+ # duplicate text embeddings for each generation per prompt
+ text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+
+ # set timesteps
+ accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
+ extra_set_kwargs = {}
+ if accepts_offset:
+ extra_set_kwargs["offset"] = 1
+
+ self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
+ # Some schedulers like PNDM have timesteps as arrays
+ # It's more optimized to move all timesteps to the correct device beforehand
+ self.scheduler.timesteps = self.scheduler.timesteps.to(self.device)
+
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, self.device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # Preprocess image
+ image = preprocess(image, width, height)
+ latents = self.prepare_latents(
+ image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, self.device, generator
+ )
+
+ if clip_guidance_scale > 0:
+ if clip_prompt is not None:
+ clip_text_input = self.tokenizer(
+ clip_prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ ).input_ids.to(self.device)
+ else:
+ clip_text_input = text_input.input_ids.to(self.device)
+ text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
+ text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
+ # duplicate text embeddings clip for each generation per prompt
+ text_embeddings_clip = text_embeddings_clip.repeat_interleave(num_images_per_prompt, dim=0)
+
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ max_length = text_input.input_ids.shape[-1]
+ uncond_input = self.tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt")
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+ # duplicate unconditional embeddings for each generation per prompt
+ uncond_embeddings = uncond_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ # get the initial random noise unless the user supplied it
+
+ # Unlike in other pipelines, latents need to be generated in the target device
+ # for 1-to-1 results reproducibility with the CompVis implementation.
+ # However this currently doesn't work in `mps`.
+ latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
+ latents_dtype = text_embeddings.dtype
+ if latents is None:
+ if self.device.type == "mps":
+ # randn does not work reproducibly on mps
+ latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+ self.device
+ )
+ else:
+ latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
+ else:
+ if latents.shape != latents_shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+ latents = latents.to(self.device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+
+ with self.progress_bar(total=num_inference_steps):
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform classifier free guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # perform clip guidance
+ if clip_guidance_scale > 0:
+ text_embeddings_for_guidance = (
+ text_embeddings.chunk(2)[1] if do_classifier_free_guidance else text_embeddings
+ )
+ noise_pred, latents = self.cond_fn(
+ latents,
+ t,
+ i,
+ text_embeddings_for_guidance,
+ noise_pred,
+ text_embeddings_clip,
+ clip_guidance_scale,
+ num_cutouts,
+ use_cutouts,
+ )
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # scale and decode the image latents with vae
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, None)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)
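
In the img2img variant above, `strength` decides how much of the schedule is actually run: `get_timesteps` keeps only the last `int(num_inference_steps * strength)` timesteps and `prepare_latents` noises the encoded input image to the first of them. A small arithmetic sketch with illustrative values:

```python
# Illustration of get_timesteps() above; pure arithmetic, no scheduler required.
num_inference_steps, strength = 50, 0.75
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 37
t_start = max(num_inference_steps - init_timestep, 0)                          # 13
print(f"{init_timestep} denoising steps run, skipping the first {t_start} of the schedule")
```

With `strength=1.0` the full schedule is used and the input image only seeds the noised starting latent; smaller values keep the result closer to the input.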
diff --git a/diffusers/examples/community/composable_stable_diffusion.py b/diffusers/examples/community/composable_stable_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..444d3375c3d162de2e4c2c4089b19ffe176fb081
--- /dev/null
+++ b/diffusers/examples/community/composable_stable_diffusion.py
@@ -0,0 +1,582 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Callable, List, Optional, Union
+
+import torch
+from packaging import version
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers import DiffusionPipeline
+from diffusers.configuration_utils import FrozenDict
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers import (
+ DDIMScheduler,
+ DPMSolverMultistepScheduler,
+ EulerAncestralDiscreteScheduler,
+ EulerDiscreteScheduler,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+)
+from diffusers.utils import deprecate, is_accelerate_available, logging
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class ComposableStableDiffusionPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[
+ DDIMScheduler,
+ PNDMScheduler,
+ LMSDiscreteScheduler,
+ EulerDiscreteScheduler,
+ EulerAncestralDiscreteScheduler,
+ DPMSolverMultistepScheduler,
+ ],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+ steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_sequential_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+ `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
+ """
+ if is_accelerate_available():
+ from accelerate import cpu_offload
+ else:
+ raise ImportError("Please install accelerate via `pip install accelerate`")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
+ if cpu_offloaded_model is not None:
+ cpu_offload(cpu_offloaded_model, device)
+
+ if self.safety_checker is not None:
+ # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
+ # fix by only offloading self.safety_checker for now
+ cpu_offload(self.safety_checker.vision_model, device)
+
+ @property
+ def _execution_device(self):
+ r"""
+ Returns the device on which the pipeline's models will be executed. After calling
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+ hooks.
+ """
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
+ return self.device
+ for module in self.unet.modules():
+ if (
+ hasattr(module, "_hf_hook")
+ and hasattr(module._hf_hook, "execution_device")
+ and module._hf_hook.execution_device is not None
+ ):
+ return torch.device(module._hf_hook.execution_device)
+ return self.device
+
+ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `list(int)`):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ """
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ text_embeddings = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ text_embeddings = text_embeddings[0]
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = text_embeddings.shape
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ uncond_embeddings = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ uncond_embeddings = uncond_embeddings[0]
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ return text_embeddings
+
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ else:
+ has_nsfw_concept = None
+ return image, has_nsfw_concept
+
+ def decode_latents(self, latents):
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(self, prompt, height, width, callback_steps):
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if latents is None:
+ if device.type == "mps":
+ # randn does not work reproducibly on mps
+ latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
+ else:
+ latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ weights: Optional[str] = "",
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ weights (`str`, *optional*, defaults to `""`):
+ A "|"-separated list of guidance weights, one per "|"-separated prompt component (e.g. `"7.5 | 7.5"`).
+ If empty, `guidance_scale` is used as the weight for every component.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(prompt, height, width, callback_steps)
+
+ # 2. Define call parameters
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ if "|" in prompt:
+ prompt = [x.strip() for x in prompt.split("|")]
+ print(f"composing {prompt}...")
+
+ if not weights:
+ # specify weights for prompts (excluding the unconditional score)
+ print("using equal positive weights (conjunction) for all prompts...")
+ weights = torch.tensor([guidance_scale] * len(prompt), device=self.device).reshape(-1, 1, 1, 1)
+ else:
+ # set prompt weight for each
+ num_prompts = len(prompt) if isinstance(prompt, list) else 1
+ weights = [float(w.strip()) for w in weights.split("|")]
+ # guidance scale as the default
+ if len(weights) < num_prompts:
+ weights.append(guidance_scale)
+ else:
+ weights = weights[:num_prompts]
+ assert len(weights) == len(prompt), "weights specified are not equal to the number of prompts"
+ weights = torch.tensor(weights, device=self.device).reshape(-1, 1, 1, 1)
+ else:
+ weights = guidance_scale
+
+ # 3. Encode input prompt
+ text_embeddings = self._encode_prompt(
+ prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ text_embeddings.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # composable diffusion
+ if isinstance(prompt, list) and batch_size == 1:
+ # remove extra unconditional embedding
+ # N = one unconditional embed + conditional embeds
+ text_embeddings = text_embeddings[len(prompt) - 1 :]
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = []
+ for j in range(text_embeddings.shape[0]):
+ noise_pred.append(
+ self.unet(latent_model_input[:1], t, encoder_hidden_states=text_embeddings[j : j + 1]).sample
+ )
+ noise_pred = torch.cat(noise_pred, dim=0)
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred[:1], noise_pred[1:]
+ noise_pred = noise_pred_uncond + (weights * (noise_pred_text - noise_pred_uncond)).sum(
+ dim=0, keepdims=True
+ )
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ # 8. Post-processing
+ image = self.decode_latents(latents)
+
+ # 9. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
+
+ # 10. Convert to PIL
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
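+
+
+if __name__ == "__main__":
+    # Usage sketch only (not part of the upstream diff): it assumes this file is exposed as the
+    # `composable_stable_diffusion` community pipeline and that a CUDA device is available; the
+    # checkpoint id and the prompts are illustrative.
+    from diffusers import DiffusionPipeline
+
+    pipe = DiffusionPipeline.from_pretrained(
+        "runwayml/stable-diffusion-v1-5",
+        custom_pipeline="composable_stable_diffusion",
+    ).to("cuda")
+    # "|" splits the prompt into components that are denoised jointly; `weights` optionally
+    # assigns one guidance weight per component, otherwise `guidance_scale` is used for all.
+    image = pipe(
+        "a red sports car | a snowy mountain road",
+        weights="7.5 | 7.5",
+        num_inference_steps=50,
+    ).images[0]
+    image.save("composable_diffusion.png")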
diff --git a/diffusers/examples/community/ddim_noise_comparative_analysis.py b/diffusers/examples/community/ddim_noise_comparative_analysis.py
new file mode 100644
index 0000000000000000000000000000000000000000..482c0a5826d27c04ff2767d4d67ca4475642a0da
--- /dev/null
+++ b/diffusers/examples/community/ddim_noise_comparative_analysis.py
@@ -0,0 +1,190 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple, Union
+
+import PIL.Image
+import torch
+from torchvision import transforms
+
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from diffusers.schedulers import DDIMScheduler
+from diffusers.utils.torch_utils import randn_tensor
+
+
+trans = transforms.Compose(
+ [
+ transforms.Resize((256, 256)),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+)
+
+
+def preprocess(image):
+ if isinstance(image, torch.Tensor):
+ return image
+ elif isinstance(image, PIL.Image.Image):
+ image = [image]
+
+ image = [trans(img.convert("RGB")) for img in image]
+ image = torch.stack(image)
+ return image
+
+
+class DDIMNoiseComparativeAnalysisPipeline(DiffusionPipeline):
+ r"""
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
+ [`DDPMScheduler`], or [`DDIMScheduler`].
+ """
+
+ def __init__(self, unet, scheduler):
+ super().__init__()
+
+ # make sure scheduler can always be converted to DDIM
+ scheduler = DDIMScheduler.from_config(scheduler.config)
+
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ def check_inputs(self, strength):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start:]
+
+ return timesteps, num_inference_steps - t_start
+
+ def prepare_latents(self, image, timestep, batch_size, dtype, device, generator=None):
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+
+ init_latents = image.to(device=device, dtype=dtype)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ shape = init_latents.shape
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ # get latents
+ print("add noise to latents at timestep", timestep)
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+ latents = init_latents
+
+ return latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+ strength: float = 0.8,
+ batch_size: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ eta: float = 0.0,
+ num_inference_steps: int = 50,
+ use_clipped_model_output: Optional[bool] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ r"""
+ Args:
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process.
+ strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
+ will be used as a starting point, adding more noise to it the larger the `strength`. The number of
+ denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
+ be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+ batch_size (`int`, *optional*, defaults to 1):
+ The number of images to generate.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ eta (`float`, *optional*, defaults to 0.0):
+ The eta parameter which controls the scale of the variance (0 is DDIM and 1 is one type of DDPM).
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ use_clipped_model_output (`bool`, *optional*, defaults to `None`):
+ if `True` or `False`, see documentation for `DDIMScheduler.step`. If `None`, nothing is passed
+ downstream to the scheduler. So use `None` for schedulers which don't support this argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
+ True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(strength)
+
+ # 2. Preprocess image
+ image = preprocess(image)
+
+ # 3. set timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=self.device)
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, self.device)
+ latent_timestep = timesteps[:1].repeat(batch_size)
+
+ # 4. Prepare latent variables
+ latents = self.prepare_latents(image, latent_timestep, batch_size, self.unet.dtype, self.device, generator)
+ image = latents
+
+ # 5. Denoising loop
+ for t in self.progress_bar(timesteps):
+ # 1. predict noise model_output
+ model_output = self.unet(image, t).sample
+
+ # 2. predict previous mean of image x_t-1 and add variance depending on eta
+ # eta corresponds to η in paper and should be between [0, 1]
+ # do x_t -> x_t-1
+ image = self.scheduler.step(
+ model_output,
+ t,
+ image,
+ eta=eta,
+ use_clipped_model_output=use_clipped_model_output,
+ generator=generator,
+ ).prev_sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, latent_timestep.item())
+
+ return ImagePipelineOutput(images=image)
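+
+
+if __name__ == "__main__":
+    # Usage sketch only (not part of the upstream diff): the checkpoint is the unconditional
+    # DDPM cat model used elsewhere in this README; the image URL is an illustrative assumption.
+    from diffusers import DiffusionPipeline
+    from diffusers.utils import load_image
+
+    pipe = DiffusionPipeline.from_pretrained(
+        "google/ddpm-cat-256", custom_pipeline="ddim_noise_comparative_analysis"
+    ).to("cuda")
+    init_image = load_image("https://example.com/cat.png")  # any RGB image; `preprocess` resizes it to 256x256
+    # Lower `strength` adds less noise, so more of the input image survives the noise/denoise round trip.
+    for strength in (0.1, 0.5, 0.9):
+        images, noising_timestep = pipe(
+            image=init_image, strength=strength, num_inference_steps=50, return_dict=False
+        )
+        images[0].save(f"denoised_strength_{strength}.png")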
diff --git a/diffusers/examples/community/edict_pipeline.py b/diffusers/examples/community/edict_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac977f79abecd281c07e780c76023216afb1a5f6
--- /dev/null
+++ b/diffusers/examples/community/edict_pipeline.py
@@ -0,0 +1,264 @@
+from typing import Optional
+
+import torch
+from PIL import Image
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline, UNet2DConditionModel
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.utils import deprecate
+
+
+class EDICTPipeline(DiffusionPipeline):
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: DDIMScheduler,
+ mixing_coeff: float = 0.93,
+ leapfrog_steps: bool = True,
+ ):
+ self.mixing_coeff = mixing_coeff
+ self.leapfrog_steps = leapfrog_steps
+
+ super().__init__()
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ )
+
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ def _encode_prompt(
+ self, prompt: str, negative_prompt: Optional[str] = None, do_classifier_free_guidance: bool = False
+ ):
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ prompt_embeds = self.text_encoder(text_inputs.input_ids.to(self.device)).last_hidden_state
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=self.device)
+
+ if do_classifier_free_guidance:
+ uncond_tokens = "" if negative_prompt is None else negative_prompt
+
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ negative_prompt_embeds = self.text_encoder(uncond_input.input_ids.to(self.device)).last_hidden_state
+
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ def denoise_mixing_layer(self, x: torch.Tensor, y: torch.Tensor):
+ x = self.mixing_coeff * x + (1 - self.mixing_coeff) * y
+ y = self.mixing_coeff * y + (1 - self.mixing_coeff) * x
+
+ return [x, y]
+
+ def noise_mixing_layer(self, x: torch.Tensor, y: torch.Tensor):
+ y = (y - (1 - self.mixing_coeff) * x) / self.mixing_coeff
+ x = (x - (1 - self.mixing_coeff) * y) / self.mixing_coeff
+
+ return [x, y]
+
+ def _get_alpha_and_beta(self, t: torch.Tensor):
+ # self.scheduler.alphas_cumprod always lives on the CPU, so index it with a plain Python int
+ t = int(t)
+
+ alpha_prod = self.scheduler.alphas_cumprod[t] if t >= 0 else self.scheduler.final_alpha_cumprod
+
+ return alpha_prod, 1 - alpha_prod
+
+ def noise_step(
+ self,
+ base: torch.Tensor,
+ model_input: torch.Tensor,
+ model_output: torch.Tensor,
+ timestep: torch.Tensor,
+ ):
+ prev_timestep = timestep - self.scheduler.config.num_train_timesteps / self.scheduler.num_inference_steps
+
+ alpha_prod_t, beta_prod_t = self._get_alpha_and_beta(timestep)
+ alpha_prod_t_prev, beta_prod_t_prev = self._get_alpha_and_beta(prev_timestep)
+
+ a_t = (alpha_prod_t_prev / alpha_prod_t) ** 0.5
+ b_t = -a_t * (beta_prod_t**0.5) + beta_prod_t_prev**0.5
+
+ next_model_input = (base - b_t * model_output) / a_t
+
+ return model_input, next_model_input.to(base.dtype)
+
+ def denoise_step(
+ self,
+ base: torch.Tensor,
+ model_input: torch.Tensor,
+ model_output: torch.Tensor,
+ timestep: torch.Tensor,
+ ):
+ prev_timestep = timestep - self.scheduler.config.num_train_timesteps / self.scheduler.num_inference_steps
+
+ alpha_prod_t, beta_prod_t = self._get_alpha_and_beta(timestep)
+ alpha_prod_t_prev, beta_prod_t_prev = self._get_alpha_and_beta(prev_timestep)
+
+ a_t = (alpha_prod_t_prev / alpha_prod_t) ** 0.5
+ b_t = -a_t * (beta_prod_t**0.5) + beta_prod_t_prev**0.5
+ next_model_input = a_t * base + b_t * model_output
+
+ return model_input, next_model_input.to(base.dtype)
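+
+    # Note on the two steps above: `denoise_step` applies the deterministic DDIM update
+    #     x_prev = a_t * x_t + b_t * eps_theta,
+    # with a_t = sqrt(alpha_bar_prev / alpha_bar_t) and b_t = sqrt(1 - alpha_bar_prev) - a_t * sqrt(1 - alpha_bar_t),
+    # while `noise_step` inverts it exactly: x_t = (x_prev - b_t * eps_theta) / a_t. Applying these updates
+    # alternately to the two coupled latents is what keeps the EDICT transformation invertible.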
+
+ @torch.no_grad()
+ def decode_latents(self, latents: torch.Tensor):
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ return image
+
+ @torch.no_grad()
+ def prepare_latents(
+ self,
+ image: Image.Image,
+ text_embeds: torch.Tensor,
+ timesteps: torch.Tensor,
+ guidance_scale: float,
+ generator: Optional[torch.Generator] = None,
+ ):
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ image = image.to(device=self.device, dtype=text_embeds.dtype)
+ latent = self.vae.encode(image).latent_dist.sample(generator)
+
+ latent = self.vae.config.scaling_factor * latent
+
+ coupled_latents = [latent.clone(), latent.clone()]
+
+ for i, t in tqdm(enumerate(timesteps), total=len(timesteps)):
+ coupled_latents = self.noise_mixing_layer(x=coupled_latents[0], y=coupled_latents[1])
+
+ # j - model_input index, k - base index
+ for j in range(2):
+ k = j ^ 1
+
+ if self.leapfrog_steps:
+ if i % 2 == 0:
+ k, j = j, k
+
+ model_input = coupled_latents[j]
+ base = coupled_latents[k]
+
+ latent_model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input
+
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds).sample
+
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ base, model_input = self.noise_step(
+ base=base,
+ model_input=model_input,
+ model_output=noise_pred,
+ timestep=t,
+ )
+
+ coupled_latents[k] = model_input
+
+ return coupled_latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ base_prompt: str,
+ target_prompt: str,
+ image: Image.Image,
+ guidance_scale: float = 3.0,
+ num_inference_steps: int = 50,
+ strength: float = 0.8,
+ negative_prompt: Optional[str] = None,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ ):
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ image = self.image_processor.preprocess(image)
+
+ base_embeds = self._encode_prompt(base_prompt, negative_prompt, do_classifier_free_guidance)
+ target_embeds = self._encode_prompt(target_prompt, negative_prompt, do_classifier_free_guidance)
+
+ self.scheduler.set_timesteps(num_inference_steps, self.device)
+
+ t_limit = num_inference_steps - int(num_inference_steps * strength)
+ fwd_timesteps = self.scheduler.timesteps[t_limit:]
+ bwd_timesteps = fwd_timesteps.flip(0)
+
+ coupled_latents = self.prepare_latents(image, base_embeds, bwd_timesteps, guidance_scale, generator)
+
+ for i, t in tqdm(enumerate(fwd_timesteps), total=len(fwd_timesteps)):
+ # j - model_input index, k - base index
+ for k in range(2):
+ j = k ^ 1
+
+ if self.leapfrog_steps:
+ if i % 2 == 1:
+ k, j = j, k
+
+ model_input = coupled_latents[j]
+ base = coupled_latents[k]
+
+ latent_model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input
+
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=target_embeds).sample
+
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ base, model_input = self.denoise_step(
+ base=base,
+ model_input=model_input,
+ model_output=noise_pred,
+ timestep=t,
+ )
+
+ coupled_latents[k] = model_input
+
+ coupled_latents = self.denoise_mixing_layer(x=coupled_latents[0], y=coupled_latents[1])
+
+ # either one is fine
+ final_latent = coupled_latents[0]
+
+ if output_type not in ["latent", "pt", "np", "pil"]:
+ deprecation_message = (
+ f"the output_type {output_type} is outdated. Please make sure to set it to one of these instead: "
+ "`pil`, `np`, `pt`, `latent`"
+ )
+ deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
+ output_type = "np"
+
+ if output_type == "latent":
+ image = final_latent
+ else:
+ image = self.decode_latents(final_latent)
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ return image
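+
+
+if __name__ == "__main__":
+    # Usage sketch only (not part of the upstream diff): the checkpoint id, image path, and prompts
+    # are illustrative assumptions; EDICT expects a DDIM scheduler configured as below.
+    from PIL import Image
+
+    from diffusers import DDIMScheduler, DiffusionPipeline
+
+    scheduler = DDIMScheduler(
+        beta_start=0.00085,
+        beta_end=0.012,
+        beta_schedule="scaled_linear",
+        clip_sample=False,
+        set_alpha_to_one=False,
+    )
+    pipe = DiffusionPipeline.from_pretrained(
+        "CompVis/stable-diffusion-v1-4",
+        custom_pipeline="edict_pipeline",
+        scheduler=scheduler,
+    ).to("cuda")
+    init_image = Image.open("dog.png").convert("RGB").resize((512, 512))
+    # `base_prompt` should describe the input image, `target_prompt` the desired edit.
+    edited = pipe(
+        base_prompt="a photo of a dog",
+        target_prompt="a photo of a golden retriever",
+        image=init_image,
+        num_inference_steps=50,
+    )
+    edited[0].save("edict_edit.png")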
diff --git a/diffusers/examples/community/iadb.py b/diffusers/examples/community/iadb.py
new file mode 100644
index 0000000000000000000000000000000000000000..6089e49fc621e8bfaa78440b372d7a28c4aef3a3
--- /dev/null
+++ b/diffusers/examples/community/iadb.py
@@ -0,0 +1,149 @@
+from typing import List, Optional, Tuple, Union
+
+import torch
+
+from diffusers import DiffusionPipeline
+from diffusers.configuration_utils import ConfigMixin
+from diffusers.pipelines.pipeline_utils import ImagePipelineOutput
+from diffusers.schedulers.scheduling_utils import SchedulerMixin
+
+
+class IADBScheduler(SchedulerMixin, ConfigMixin):
+ """
+ IADBScheduler is a scheduler for the Iterative α-(de)Blending denoising method. It is simple and minimalist.
+
+ For more details, see the original paper: https://arxiv.org/abs/2305.03486 and the blog post: https://ggx-research.github.io/publication/2023/05/10/publication-iadb.html
+ """
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ x_alpha: torch.FloatTensor,
+ ) -> torch.FloatTensor:
+ """
+ Predict the sample at the previous timestep by reversing the ODE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model. It is the direction from x0 to x1.
+ timestep (`float`): current timestep in the diffusion chain.
+ x_alpha (`torch.FloatTensor`): x_alpha sample for the current timestep
+
+ Returns:
+ `torch.FloatTensor`: the sample at the previous timestep
+
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ alpha = timestep / self.num_inference_steps
+ alpha_next = (timestep + 1) / self.num_inference_steps
+
+ d = model_output
+
+ x_alpha = x_alpha + (alpha_next - alpha) * d
+
+ return x_alpha
+
+ def set_timesteps(self, num_inference_steps: int):
+ self.num_inference_steps = num_inference_steps
+
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ alpha: torch.FloatTensor,
+ ) -> torch.FloatTensor:
+ return original_samples * alpha + noise * (1 - alpha)
+
+ def __len__(self):
+ return self.config.num_train_timesteps
+
+
+class IADBPipeline(DiffusionPipeline):
+ r"""
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
+ [`DDPMScheduler`], or [`DDIMScheduler`].
+ """
+
+ def __init__(self, unet, scheduler):
+ super().__init__()
+
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ batch_size: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ num_inference_steps: int = 50,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ r"""
+ Args:
+ batch_size (`int`, *optional*, defaults to 1):
+ The number of images to generate.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
+ True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+
+ # Sample gaussian noise to begin loop
+ if isinstance(self.unet.config.sample_size, int):
+ image_shape = (
+ batch_size,
+ self.unet.config.in_channels,
+ self.unet.config.sample_size,
+ self.unet.config.sample_size,
+ )
+ else:
+ image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ image = torch.randn(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype)
+
+ # set step values
+ self.scheduler.set_timesteps(num_inference_steps)
+ x_alpha = image.clone()
+ for t in self.progress_bar(range(num_inference_steps)):
+ alpha = t / num_inference_steps
+
+ # 1. predict noise model_output
+ model_output = self.unet(x_alpha, torch.tensor(alpha, device=x_alpha.device)).sample
+
+ # 2. step
+ x_alpha = self.scheduler.step(model_output, t, x_alpha)
+
+ image = (x_alpha * 0.5 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
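+
+
+if __name__ == "__main__":
+    # Usage sketch only (not part of the upstream diff): "some-org/iadb-butterflies" is a placeholder
+    # for a UNet2DModel checkpoint trained with the IADB objective (the UNet predicts the x0 -> x1
+    # direction and receives alpha in place of the usual timestep).
+    from diffusers import UNet2DModel
+
+    unet = UNet2DModel.from_pretrained("some-org/iadb-butterflies")
+    pipe = IADBPipeline(unet=unet, scheduler=IADBScheduler()).to("cuda")
+    images = pipe(batch_size=4, num_inference_steps=128).images
+    for i, img in enumerate(images):
+        img.save(f"iadb_sample_{i}.png")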
diff --git a/diffusers/examples/community/imagic_stable_diffusion.py b/diffusers/examples/community/imagic_stable_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..acd09c7e0bf406cea672242370dac93b70a99da3
--- /dev/null
+++ b/diffusers/examples/community/imagic_stable_diffusion.py
@@ -0,0 +1,496 @@
+"""
+ modeled after the textual_inversion.py / train_dreambooth.py and the work
+ of justinpinkney here: https://github.com/justinpinkney/stable-diffusion/blob/main/notebooks/imagic.ipynb
+"""
+import inspect
+import warnings
+from typing import List, Optional, Union
+
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn.functional as F
+from accelerate import Accelerator
+
+# TODO: remove and import from diffusers.utils when the new version of diffusers is released
+from packaging import version
+from tqdm.auto import tqdm
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers import DiffusionPipeline
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from diffusers.utils import logging
+
+
+if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.Resampling.BILINEAR,
+ "bilinear": PIL.Image.Resampling.BILINEAR,
+ "bicubic": PIL.Image.Resampling.BICUBIC,
+ "lanczos": PIL.Image.Resampling.LANCZOS,
+ "nearest": PIL.Image.Resampling.NEAREST,
+ }
+else:
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.LINEAR,
+ "bilinear": PIL.Image.BILINEAR,
+ "bicubic": PIL.Image.BICUBIC,
+ "lanczos": PIL.Image.LANCZOS,
+ "nearest": PIL.Image.NEAREST,
+ }
+# ------------------------------------------------------------------------------
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def preprocess(image):
+ w, h = image.size
+ w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32
+ image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ return 2.0 * image - 1.0
+
+
+class ImagicStableDiffusionPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for imagic image editing.
+ See paper here: https://arxiv.org/pdf/2210.09276.pdf
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ ):
+ super().__init__()
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ # Populated by `train`; initialized here so `__call__` can raise a clear error if `train` was never run.
+ self.text_embeddings_orig = None
+ self.text_embeddings = None
+
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+ Args:
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+ `attention_head_dim` must be a multiple of `slice_size`.
+ """
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = self.unet.config.attention_head_dim // 2
+ self.unet.set_attention_slice(slice_size)
+
+ def disable_attention_slicing(self):
+ r"""
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+ back to computing attention in one step.
+ """
+ # set slice_size = `None` to disable `attention slicing`
+ self.enable_attention_slicing(None)
+
+ def train(
+ self,
+ prompt: Union[str, List[str]],
+ image: Union[torch.FloatTensor, PIL.Image.Image],
+ height: Optional[int] = 512,
+ width: Optional[int] = 512,
+ generator: Optional[torch.Generator] = None,
+ embedding_learning_rate: float = 0.001,
+ diffusion_model_learning_rate: float = 2e-6,
+ text_embedding_optimization_steps: int = 500,
+ model_fine_tuning_optimization_steps: int = 1000,
+ **kwargs,
+ ):
+ r"""
+ Fine-tunes the pipeline on a single image: the text embedding is optimized first, then the UNet, so that
+ `__call__` can later interpolate between the original and optimized embeddings to edit the image.
+ Args:
+ prompt (`str` or `List[str]`):
+ The target text prompt describing the desired edit.
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
+ The image to be edited; it is encoded to latents and used as the reconstruction target.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ embedding_learning_rate (`float`, *optional*, defaults to 0.001):
+ Learning rate used while optimizing the text embedding.
+ diffusion_model_learning_rate (`float`, *optional*, defaults to 2e-6):
+ Learning rate used while fine-tuning the UNet.
+ text_embedding_optimization_steps (`int`, *optional*, defaults to 500):
+ Number of optimization steps for the text embedding.
+ model_fine_tuning_optimization_steps (`int`, *optional*, defaults to 1000):
+ Number of fine-tuning steps for the UNet.
+ Returns:
+ `None`: the optimized embeddings are stored on the pipeline as `text_embeddings_orig` and
+ `text_embeddings` and are consumed by `__call__`.
+ """
+ accelerator = Accelerator(
+ gradient_accumulation_steps=1,
+ mixed_precision="fp16",
+ )
+
+ if "torch_device" in kwargs:
+ device = kwargs.pop("torch_device")
+ warnings.warn(
+ "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+ " Consider using `pipe.to(torch_device)` instead."
+ )
+
+ if device is None:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ self.to(device)
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ # Freeze vae and unet
+ self.vae.requires_grad_(False)
+ self.unet.requires_grad_(False)
+ self.text_encoder.requires_grad_(False)
+ self.unet.eval()
+ self.vae.eval()
+ self.text_encoder.eval()
+
+ if accelerator.is_main_process:
+ accelerator.init_trackers(
+ "imagic",
+ config={
+ "embedding_learning_rate": embedding_learning_rate,
+ "text_embedding_optimization_steps": text_embedding_optimization_steps,
+ },
+ )
+
+ # get text embeddings for prompt
+ text_input = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_embeddings = torch.nn.Parameter(
+ self.text_encoder(text_input.input_ids.to(self.device))[0], requires_grad=True
+ )
+ text_embeddings = text_embeddings.detach()
+ text_embeddings.requires_grad_()
+ text_embeddings_orig = text_embeddings.clone()
+
+ # Initialize the optimizer
+ optimizer = torch.optim.Adam(
+ [text_embeddings], # only optimize the embeddings
+ lr=embedding_learning_rate,
+ )
+
+ if isinstance(image, PIL.Image.Image):
+ image = preprocess(image)
+
+ latents_dtype = text_embeddings.dtype
+ image = image.to(device=self.device, dtype=latents_dtype)
+ init_latent_image_dist = self.vae.encode(image).latent_dist
+ image_latents = init_latent_image_dist.sample(generator=generator)
+ image_latents = 0.18215 * image_latents
+
+ progress_bar = tqdm(range(text_embedding_optimization_steps), disable=not accelerator.is_local_main_process)
+ progress_bar.set_description("Steps")
+
+ global_step = 0
+
+ logger.info("First optimizing the text embedding to better reconstruct the init image")
+ for _ in range(text_embedding_optimization_steps):
+ with accelerator.accumulate(text_embeddings):
+ # Sample noise that we'll add to the latents
+ noise = torch.randn(image_latents.shape).to(image_latents.device)
+ timesteps = torch.randint(1000, (1,), device=image_latents.device)
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = self.scheduler.add_noise(image_latents, noise, timesteps)
+
+ # Predict the noise residual
+ noise_pred = self.unet(noisy_latents, timesteps, text_embeddings).sample
+
+ loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+ accelerator.backward(loss)
+
+ optimizer.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ logs = {"loss": loss.detach().item()} # , "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ accelerator.wait_for_everyone()
+
+ text_embeddings.requires_grad_(False)
+
+ # Now we fine tune the unet to better reconstruct the image
+ self.unet.requires_grad_(True)
+ self.unet.train()
+ optimizer = torch.optim.Adam(
+ self.unet.parameters(), # only optimize unet
+ lr=diffusion_model_learning_rate,
+ )
+ progress_bar = tqdm(range(model_fine_tuning_optimization_steps), disable=not accelerator.is_local_main_process)
+
+ logger.info("Next fine tuning the entire model to better reconstruct the init image")
+ for _ in range(model_fine_tuning_optimization_steps):
+ with accelerator.accumulate(self.unet.parameters()):
+ # Sample noise that we'll add to the latents
+ noise = torch.randn(image_latents.shape).to(image_latents.device)
+ timesteps = torch.randint(1000, (1,), device=image_latents.device)
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = self.scheduler.add_noise(image_latents, noise, timesteps)
+
+ # Predict the noise residual
+ noise_pred = self.unet(noisy_latents, timesteps, text_embeddings).sample
+
+ loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+ accelerator.backward(loss)
+
+ optimizer.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ logs = {"loss": loss.detach().item()} # , "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ accelerator.wait_for_everyone()
+ self.text_embeddings_orig = text_embeddings_orig
+ self.text_embeddings = text_embeddings
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ alpha: float = 1.2,
+ height: Optional[int] = 512,
+ width: Optional[int] = 512,
+ num_inference_steps: Optional[int] = 50,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ guidance_scale: float = 7.5,
+ eta: float = 0.0,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+ Args:
+ alpha (`float`, *optional*, defaults to 1.2):
+ Interpolation factor between the target-prompt embedding and the image-optimized embedding: smaller
+ values stay closer to the input image, larger values apply the text edit more strongly.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+ if self.text_embeddings is None:
+ raise ValueError("Please run the pipe.train() before trying to generate an image.")
+ if self.text_embeddings_orig is None:
+ raise ValueError("Please run the pipe.train() before trying to generate an image.")
+
+ text_embeddings = alpha * self.text_embeddings_orig + (1 - alpha) * self.text_embeddings
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens = [""]
+ max_length = self.tokenizer.model_max_length
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.view(1, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ # get the initial random noise unless the user supplied it
+
+ # Unlike in other pipelines, latents need to be generated in the target device
+ # for 1-to-1 results reproducibility with the CompVis implementation.
+ # However this currently doesn't work in `mps`.
+ latents_shape = (1, self.unet.config.in_channels, height // 8, width // 8)
+ latents_dtype = text_embeddings.dtype
+ if self.device.type == "mps":
+ # randn does not exist on mps
+ latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+ self.device
+ )
+ else:
+ latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
+
+ # set timesteps
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ # Some schedulers like PNDM have timesteps as arrays
+ # It's more optimized to move all timesteps to correct device beforehand
+ timesteps_tensor = self.scheduler.timesteps.to(self.device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
+ self.device
+ )
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
+ )
+ else:
+ has_nsfw_concept = None
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
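+
+
+if __name__ == "__main__":
+    # Usage sketch only (not part of the upstream diff): checkpoint id, image path, and prompt are
+    # illustrative assumptions; `train` fine-tunes on a single image, so a GPU is effectively required.
+    import torch
+    from PIL import Image
+
+    from diffusers import DiffusionPipeline
+
+    pipe = DiffusionPipeline.from_pretrained(
+        "CompVis/stable-diffusion-v1-4",
+        custom_pipeline="imagic_stable_diffusion",
+    ).to("cuda")
+    generator = torch.Generator("cuda").manual_seed(0)
+    init_image = Image.open("dog.png").convert("RGB")
+    # Step 1: optimize the text embedding, then fine-tune the UNet, to reconstruct `init_image`.
+    pipe.train("a photo of a dog sitting on a bench", image=init_image, generator=generator)
+    # Step 2: interpolate between the optimized and original embeddings and sample the edited image.
+    result = pipe(alpha=1.2, guidance_scale=7.5, num_inference_steps=50)
+    result.images[0].save("imagic_edit.png")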
diff --git a/diffusers/examples/community/img2img_inpainting.py b/diffusers/examples/community/img2img_inpainting.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ee8355d49a6b85b2278f6dc5c8e8a7c40adbd92
--- /dev/null
+++ b/diffusers/examples/community/img2img_inpainting.py
@@ -0,0 +1,464 @@
+import inspect
+from typing import Callable, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers import DiffusionPipeline
+from diffusers.configuration_utils import FrozenDict
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from diffusers.utils import deprecate, logging
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def prepare_mask_and_masked_image(image, mask):
+ image = np.array(image.convert("RGB"))
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ mask = np.array(mask.convert("L"))
+ mask = mask.astype(np.float32) / 255.0
+ mask = mask[None, None]
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+ mask = torch.from_numpy(mask)
+
+ masked_image = image * (mask < 0.5)
+
+ return mask, masked_image
+
+
+def check_size(image, height, width):
+ if isinstance(image, PIL.Image.Image):
+ w, h = image.size
+ elif isinstance(image, torch.Tensor):
+ *_, h, w = image.shape
+
+ if h != height or w != width:
+ raise ValueError(f"Image size should be {height}x{width}, but got {h}x{w}")
+
+
+def overlay_inner_image(image, inner_image, paste_offset: Tuple[int, int] = (0, 0)):
+ inner_image = inner_image.convert("RGBA")
+ image = image.convert("RGB")
+
+ image.paste(inner_image, paste_offset, inner_image)
+ image = image.convert("RGB")
+
+ return image
+
+
+class ImageToImageInpaintingPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-guided image-to-image inpainting using Stable Diffusion. *This is an experimental feature*.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+ `attention_head_dim` must be a multiple of `slice_size`.
+ """
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = self.unet.config.attention_head_dim // 2
+ self.unet.set_attention_slice(slice_size)
+
+ def disable_attention_slicing(self):
+ r"""
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+ back to computing attention in one step.
+ """
+ # set slice_size = `None` to disable `attention slicing`
+ self.enable_attention_slicing(None)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ image: Union[torch.FloatTensor, PIL.Image.Image],
+ inner_image: Union[torch.FloatTensor, PIL.Image.Image],
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image],
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ image (`torch.Tensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
+ be masked out with `mask_image` and repainted according to `prompt`.
+ inner_image (`torch.Tensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch which will be overlaid onto `image`. Non-transparent
+ regions of `inner_image` must fit inside white pixels in `mask_image`. Expects four channels, with
+ the last channel representing the alpha channel, which will be used to blend `inner_image` with
+ `image`. If no alpha channel is provided, the image will be forcibly converted to RGBA.
+ mask_image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+ repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
+ to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
+ instead of 3, so the expected shape would be `(B, H, W, 1)`.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ # check if input sizes are correct
+ check_size(image, height, width)
+ check_size(inner_image, height, width)
+ check_size(mask_image, height, width)
+
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+ removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+ text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+ text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = text_embeddings.shape
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""]
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ # get the initial random noise unless the user supplied it
+ # Unlike in other pipelines, latents need to be generated in the target device
+ # for 1-to-1 results reproducibility with the CompVis implementation.
+ # However this currently doesn't work in `mps`.
+ num_channels_latents = self.vae.config.latent_channels
+ latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
+ latents_dtype = text_embeddings.dtype
+ if latents is None:
+ if self.device.type == "mps":
+ # randn does not exist on mps
+ latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+ self.device
+ )
+ else:
+ latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
+ else:
+ if latents.shape != latents_shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+ latents = latents.to(self.device)
+
+ # overlay the inner image
+ image = overlay_inner_image(image, inner_image)
+
+ # prepare mask and masked_image
+ mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
+ mask = mask.to(device=self.device, dtype=text_embeddings.dtype)
+ masked_image = masked_image.to(device=self.device, dtype=text_embeddings.dtype)
+
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ mask = torch.nn.functional.interpolate(mask, size=(height // 8, width // 8))
+
+ # encode the mask image into latents space so we can concatenate it to the latents
+ masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
+ masked_image_latents = 0.18215 * masked_image_latents
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ mask = mask.repeat(batch_size * num_images_per_prompt, 1, 1, 1)
+ masked_image_latents = masked_image_latents.repeat(batch_size * num_images_per_prompt, 1, 1, 1)
+
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+ masked_image_latents = (
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+ )
+
+ num_channels_mask = mask.shape[1]
+ num_channels_masked_image = masked_image_latents.shape[1]
+
+ if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
+ raise ValueError(
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ " `pipeline.unet` or your `mask_image` or `image` input."
+ )
+
+ # set timesteps
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ # Some schedulers like PNDM have timesteps as arrays
+ # It's more optimized to move all timesteps to correct device beforehand
+ timesteps_tensor = self.scheduler.timesteps.to(self.device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ # concat latents, mask, masked_image_latents in the channel dimension
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
+ self.device
+ )
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
+ )
+ else:
+ has_nsfw_concept = None
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
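A minimal usage sketch for the `ImageToImageInpaintingPipeline` added above. The checkpoint, file names, and prompt are illustrative assumptions; the hard requirements from the code are that `image`, `inner_image`, and `mask_image` all match the requested `height` x `width`, that `inner_image` carries an alpha channel, and that the base checkpoint's UNet expects 4 latent + 1 mask + 4 masked-image channels (9 in total):

```python
import torch
from PIL import Image
from diffusers import DiffusionPipeline

# Hypothetical local files; all three inputs must match the pipeline's height x width.
init_image = Image.open("scene.png").convert("RGB").resize((512, 512))
inner_image = Image.open("subject.png").convert("RGBA").resize((512, 512))  # alpha marks the pasted region
mask_image = Image.open("mask.png").convert("L").resize((512, 512))         # white = repaint, black = keep

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",   # inpainting UNet with 9 input channels
    custom_pipeline="img2img_inpainting",
    torch_dtype=torch.float16,
).to("cuda")

result = pipe(
    prompt="a cozy living room, photorealistic",
    image=init_image,
    inner_image=inner_image,
    mask_image=mask_image,
    num_inference_steps=50,
    guidance_scale=7.5,
).images[0]
result.save("inpainted.png")
```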
diff --git a/diffusers/examples/community/interpolate_stable_diffusion.py b/diffusers/examples/community/interpolate_stable_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..70e4d025a037284e7a3cfea54074b2ec473dea84
--- /dev/null
+++ b/diffusers/examples/community/interpolate_stable_diffusion.py
@@ -0,0 +1,525 @@
+import inspect
+import time
+from pathlib import Path
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers import DiffusionPipeline
+from diffusers.configuration_utils import FrozenDict
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from diffusers.utils import deprecate, logging
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
+ """helper function to spherically interpolate two arrays v1 v2"""
+
+ inputs_are_torch = isinstance(v0, torch.Tensor)  # defined for both branches so numpy inputs skip the torch round-trip below
+ if inputs_are_torch:
+ input_device = v0.device
+ v0 = v0.cpu().numpy()
+ v1 = v1.cpu().numpy()
+
+ dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
+ if np.abs(dot) > DOT_THRESHOLD:
+ v2 = (1 - t) * v0 + t * v1
+ else:
+ theta_0 = np.arccos(dot)
+ sin_theta_0 = np.sin(theta_0)
+ theta_t = theta_0 * t
+ sin_theta_t = np.sin(theta_t)
+ s0 = np.sin(theta_0 - theta_t) / sin_theta_0
+ s1 = sin_theta_t / sin_theta_0
+ v2 = s0 * v0 + s1 * v1
+
+ if inputs_are_torch:
+ v2 = torch.from_numpy(v2).to(input_device)
+
+ return v2
+
+
+class StableDiffusionWalkPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+ `attention_head_dim` must be a multiple of `slice_size`.
+ """
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = self.unet.config.attention_head_dim // 2
+ self.unet.set_attention_slice(slice_size)
+
+ def disable_attention_slicing(self):
+ r"""
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+ back to computing attention in one step.
+ """
+ # set slice_size = `None` to disable `attention slicing`
+ self.enable_attention_slicing(None)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Optional[Union[str, List[str]]] = None,
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ text_embeddings: Optional[torch.FloatTensor] = None,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*, defaults to `None`):
+ The prompt or prompts to guide the image generation. If not provided, `text_embeddings` is required.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ text_embeddings (`torch.FloatTensor`, *optional*, defaults to `None`):
+ Pre-generated text embeddings to be used as inputs for image generation. Can be used in place of
+ `prompt` to avoid re-computing the embeddings. If not provided, the embeddings will be generated from
+ the supplied `prompt`.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if text_embeddings is None:
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+ removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+ text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+ text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+ else:
+ batch_size = text_embeddings.shape[0]
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = text_embeddings.shape
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = self.tokenizer.model_max_length
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ # get the initial random noise unless the user supplied it
+
+ # Unlike in other pipelines, latents need to be generated in the target device
+ # for 1-to-1 results reproducibility with the CompVis implementation.
+ # However this currently doesn't work in `mps`.
+ latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
+ latents_dtype = text_embeddings.dtype
+ if latents is None:
+ if self.device.type == "mps":
+ # randn does not work reproducibly on mps
+ latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+ self.device
+ )
+ else:
+ latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
+ else:
+ if latents.shape != latents_shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+ latents = latents.to(self.device)
+
+ # set timesteps
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ # Some schedulers like PNDM have timesteps as arrays
+ # It's more optimized to move all timesteps to correct device beforehand
+ timesteps_tensor = self.scheduler.timesteps.to(self.device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
+ self.device
+ )
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
+ )
+ else:
+ has_nsfw_concept = None
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+ def embed_text(self, text):
+ """takes in text and turns it into text embeddings"""
+ text_input = self.tokenizer(
+ text,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ with torch.no_grad():
+ embed = self.text_encoder(text_input.input_ids.to(self.device))[0]
+ return embed
+
+ def get_noise(self, seed, dtype=torch.float32, height=512, width=512):
+ """Takes in random seed and returns corresponding noise vector"""
+ return torch.randn(
+ (1, self.unet.config.in_channels, height // 8, width // 8),
+ generator=torch.Generator(device=self.device).manual_seed(seed),
+ device=self.device,
+ dtype=dtype,
+ )
+
+ def walk(
+ self,
+ prompts: List[str],
+ seeds: List[int],
+ num_interpolation_steps: Optional[int] = 6,
+ output_dir: Optional[str] = "./dreams",
+ name: Optional[str] = None,
+ batch_size: Optional[int] = 1,
+ height: Optional[int] = 512,
+ width: Optional[int] = 512,
+ guidance_scale: Optional[float] = 7.5,
+ num_inference_steps: Optional[int] = 50,
+ eta: Optional[float] = 0.0,
+ ) -> List[str]:
+ """
+ Walks through a series of prompts and seeds, interpolating between them and saving the results to disk.
+
+ Args:
+ prompts (`List[str]`):
+ List of prompts to generate images for.
+ seeds (`List[int]`):
+ List of seeds corresponding to provided prompts. Must be the same length as prompts.
+ num_interpolation_steps (`int`, *optional*, defaults to 6):
+ Number of interpolation steps to take between prompts.
+ output_dir (`str`, *optional*, defaults to `./dreams`):
+ Directory to save the generated images to.
+ name (`str`, *optional*, defaults to `None`):
+ Subdirectory of `output_dir` to save the generated images to. If `None`, the name will
+ be the current time.
+ batch_size (`int`, *optional*, defaults to 1):
+ Number of images to generate at once.
+ height (`int`, *optional*, defaults to 512):
+ Height of the generated images.
+ width (`int`, *optional*, defaults to 512):
+ Width of the generated images.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+
+ Returns:
+ `List[str]`: List of paths to the generated images.
+ """
+ if not len(prompts) == len(seeds):
+ raise ValueError(
+ f"Number of prompts and seeds must be equalGot {len(prompts)} prompts and {len(seeds)} seeds"
+ )
+
+ name = name or time.strftime("%Y%m%d-%H%M%S")
+ save_path = Path(output_dir) / name
+ save_path.mkdir(exist_ok=True, parents=True)
+
+ frame_idx = 0
+ frame_filepaths = []
+ for prompt_a, prompt_b, seed_a, seed_b in zip(prompts, prompts[1:], seeds, seeds[1:]):
+ # Embed Text
+ embed_a = self.embed_text(prompt_a)
+ embed_b = self.embed_text(prompt_b)
+
+ # Get Noise
+ noise_dtype = embed_a.dtype
+ noise_a = self.get_noise(seed_a, noise_dtype, height, width)
+ noise_b = self.get_noise(seed_b, noise_dtype, height, width)
+
+ noise_batch, embeds_batch = None, None
+ T = np.linspace(0.0, 1.0, num_interpolation_steps)
+ for i, t in enumerate(T):
+ noise = slerp(float(t), noise_a, noise_b)
+ embed = torch.lerp(embed_a, embed_b, t)
+
+ noise_batch = noise if noise_batch is None else torch.cat([noise_batch, noise], dim=0)
+ embeds_batch = embed if embeds_batch is None else torch.cat([embeds_batch, embed], dim=0)
+
+ batch_is_ready = embeds_batch.shape[0] == batch_size or i + 1 == T.shape[0]
+ if batch_is_ready:
+ outputs = self(
+ latents=noise_batch,
+ text_embeddings=embeds_batch,
+ height=height,
+ width=width,
+ guidance_scale=guidance_scale,
+ eta=eta,
+ num_inference_steps=num_inference_steps,
+ )
+ noise_batch, embeds_batch = None, None
+
+ for image in outputs["images"]:
+ frame_filepath = str(save_path / f"frame_{frame_idx:06d}.png")
+ image.save(frame_filepath)
+ frame_filepaths.append(frame_filepath)
+ frame_idx += 1
+ return frame_filepaths
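A minimal sketch of driving the `StableDiffusionWalkPipeline` added above through its `walk` method, which spherically interpolates the seed noise and linearly interpolates the prompt embeddings between consecutive prompt/seed pairs, writing one PNG per interpolation step. The checkpoint, prompts, seeds, and output directory below are illustrative assumptions:

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="interpolate_stable_diffusion",
    torch_dtype=torch.float16,
).to("cuda")

# One frame per interpolation step is written to ./dreams/<name>/frame_XXXXXX.png
frame_paths = pipe.walk(
    prompts=["a photo of a cat", "a photo of a dog"],
    seeds=[42, 1337],
    num_interpolation_steps=8,
    num_inference_steps=50,
    guidance_scale=7.5,
    output_dir="./dreams",
)
print(frame_paths)
```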
diff --git a/diffusers/examples/community/latent_consistency_img2img.py b/diffusers/examples/community/latent_consistency_img2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2dffdfe3167d1752e0eb8d85ac67466280cad34
--- /dev/null
+++ b/diffusers/examples/community/latent_consistency_img2img.py
@@ -0,0 +1,827 @@
+# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
+# and https://github.com/hojonathanho/diffusion
+
+import math
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers import AutoencoderKL, ConfigMixin, DiffusionPipeline, SchedulerMixin, UNet2DConditionModel, logging
+from diffusers.configuration_utils import register_to_config
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.utils import BaseOutput
+from diffusers.utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class LatentConsistencyModelImg2ImgPipeline(DiffusionPipeline):
+ _optional_components = ["scheduler"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: "LCMSchedulerWithTimestamp",
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ scheduler = (
+ scheduler
+ if scheduler is not None
+ else LCMSchedulerWithTimestamp(
+ beta_start=0.00085, beta_end=0.0120, beta_schedule="scaled_linear", prediction_type="epsilon"
+ )
+ )
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ prompt_embeds: None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ """
+
+ if prompt is not None and isinstance(prompt, str):
+ pass
+ elif prompt is not None and isinstance(prompt, list):
+ len(prompt)
+ else:
+ prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ if self.text_encoder is not None:
+ prompt_embeds_dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ prompt_embeds_dtype = self.unet.dtype
+ else:
+ prompt_embeds_dtype = prompt_embeds.dtype
+
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # Don't need to get uncond prompt embedding because of LCM Guided Distillation
+ return prompt_embeds
+
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ def prepare_latents(
+ self,
+ image,
+ timestep,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ latents=None,
+ generator=None,
+ ):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+
+ image = image.to(device=device, dtype=dtype)
+
+ # batch_size = batch_size * num_images_per_prompt
+
+ if image.shape[1] == 4:
+ init_latents = image
+
+ else:
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ elif isinstance(generator, list):
+ init_latents = [
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+ ]
+ init_latents = torch.cat(init_latents, dim=0)
+ else:
+ init_latents = self.vae.encode(image).latent_dist.sample(generator)
+
+ init_latents = self.vae.config.scaling_factor * init_latents
+
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ (
+ f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
+ " your script to pass as many initial images as text prompts to suppress this warning."
+ )
+ # deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ init_latents = torch.cat([init_latents], dim=0)
+
+ shape = init_latents.shape
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ # get latents
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+ latents = init_latents
+
+ return latents
+
+ if latents is None:
+ latents = torch.randn(shape, dtype=dtype).to(device)
+ else:
+ latents = latents.to(device)
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def get_w_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+ """
+ see https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+ Args:
+ w: torch.Tensor: guidance-scale values to embed, one per sample
+ embedding_dim: int: dimension of the embeddings to generate
+ dtype: data type of the generated embeddings
+ Returns:
+ embedding vectors with shape `(len(w), embedding_dim)`
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+ return timesteps, num_inference_steps - t_start
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: PipelineImageInput = None,
+ strength: float = 0.8,
+ height: Optional[int] = 768,
+ width: Optional[int] = 768,
+ guidance_scale: float = 7.5,
+ num_images_per_prompt: Optional[int] = 1,
+ latents: Optional[torch.FloatTensor] = None,
+ num_inference_steps: int = 4,
+ lcm_origin_steps: int = 50,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # do_classifier_free_guidance = guidance_scale > 0.0 # In LCM Implementation: cfg_noise = noise_cond + cfg_scale * (noise_cond - noise_uncond) , (cfg_scale > 0.0 using CFG)
+
+ # 3. Encode input prompt
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ prompt_embeds=prompt_embeds,
+ )
+
+ # 3.5 encode image
+ image = self.image_processor.preprocess(image)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(strength, num_inference_steps, lcm_origin_steps)
+ # timesteps = self.scheduler.timesteps
+ # timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, 1.0, device)
+ timesteps = self.scheduler.timesteps
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ print("timesteps: ", timesteps)
+
+ # 5. Prepare latent variable
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ image,
+ latent_timestep,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ latents,
+ )
+ bs = batch_size * num_images_per_prompt
+
+ # 6. Get Guidance Scale Embedding
+ w = torch.tensor(guidance_scale).repeat(bs)
+ w_embedding = self.get_w_embedding(w, embedding_dim=256).to(device=device, dtype=latents.dtype)
+
+ # 7. LCM MultiStep Sampling Loop:
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ ts = torch.full((bs,), t, device=device, dtype=torch.long)
+ latents = latents.to(prompt_embeds.dtype)
+
+ # model prediction (v-prediction, eps, x)
+ model_pred = self.unet(
+ latents,
+ ts,
+ timestep_cond=w_embedding,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents, denoised = self.scheduler.step(model_pred, i, t, latents, return_dict=False)
+
+ # # call the callback, if provided
+ # if i == len(timesteps) - 1:
+ progress_bar.update()
+
+ denoised = denoised.to(prompt_embeds.dtype)
+ if not output_type == "latent":
+ image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = denoised
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
+class LCMSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's `step` function output.
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ denoised (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
+ `denoised` can be used to preview progress or for guidance.
+ """
+
+ prev_sample: torch.FloatTensor
+ denoised: Optional[torch.FloatTensor] = None
+
+
+# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
+def betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ max_beta=0.999,
+ alpha_transform_type="cosine",
+):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
+ Choose from `cosine` or `exp`
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+ if alpha_transform_type == "cosine":
+
+ def alpha_bar_fn(t):
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ elif alpha_transform_type == "exp":
+
+ def alpha_bar_fn(t):
+ return math.exp(t * -12.0)
+
+ else:
+ raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
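+# Worked example for betas_for_alpha_bar (illustrative): with the default cosine
+# transform, alpha_bar_fn(0) is close to 1 and alpha_bar_fn(1) is close to 0, so the
+# returned betas start near zero and grow toward the max_beta cap (0.999) at the end
+# of the schedule.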
+
+
+def rescale_zero_terminal_snr(betas):
+ """
+ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
+ Args:
+ betas (`torch.FloatTensor`):
+ the betas that the scheduler is being initialized with.
+ Returns:
+ `torch.FloatTensor`: rescaled betas with zero terminal SNR
+ """
+ # Convert betas to alphas_bar_sqrt
+ alphas = 1.0 - betas
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+ # Store old values.
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+ # Shift so the last timestep is zero.
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+ # Scale so the first timestep is back to the old value.
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+ # Convert alphas_bar_sqrt to betas
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
+ alphas = torch.cat([alphas_bar[0:1], alphas])
+ betas = 1 - alphas
+
+ return betas
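+# Sanity check for rescale_zero_terminal_snr (illustrative): after rescaling, the
+# cumulative product of (1 - beta) ends at exactly zero, i.e. the terminal timestep has
+# zero SNR and carries pure noise, as required by Algorithm 1 of the referenced paper.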
+
+
+class LCMSchedulerWithTimestamp(SchedulerMixin, ConfigMixin):
+ """
+ This class modifies LCMScheduler so that `set_timesteps` additionally takes a `strength` argument
+ (the image-to-image denoising strength), which scales how many of the original LCM training
+ timesteps are used.
+
+
+ `LCMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
+ non-Markovian guidance.
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+ methods the library implements for all schedulers such as loading and saving.
+ Args:
+ num_train_timesteps (`int`, defaults to 1000):
+ The number of diffusion steps to train the model.
+ beta_start (`float`, defaults to 0.0001):
+ The starting `beta` value of inference.
+ beta_end (`float`, defaults to 0.02):
+ The final `beta` value.
+ beta_schedule (`str`, defaults to `"linear"`):
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`np.ndarray`, *optional*):
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
+ clip_sample (`bool`, defaults to `True`):
+ Clip the predicted sample for numerical stability.
+ clip_sample_range (`float`, defaults to 1.0):
+ The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
+ set_alpha_to_one (`bool`, defaults to `True`):
+ Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
+ there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+ otherwise it uses the alpha value at step 0.
+ steps_offset (`int`, defaults to 0):
+ An offset added to the inference steps. You can use a combination of `offset=1` and
+ `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
+ Diffusion.
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
+ `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
+ Video](https://imagen.research.google/video/paper.pdf) paper).
+ thresholding (`bool`, defaults to `False`):
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
+ as Stable Diffusion.
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
+ sample_max_value (`float`, defaults to 1.0):
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
+ timestep_spacing (`str`, defaults to `"leading"`):
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
+ """
+
+ # _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ clip_sample: bool = True,
+ set_alpha_to_one: bool = True,
+ steps_offset: int = 0,
+ prediction_type: str = "epsilon",
+ thresholding: bool = False,
+ dynamic_thresholding_ratio: float = 0.995,
+ clip_sample_range: float = 1.0,
+ sample_max_value: float = 1.0,
+ timestep_spacing: str = "leading",
+ rescale_betas_zero_snr: bool = False,
+ ):
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+ # Rescale for zero SNR
+ if rescale_betas_zero_snr:
+ self.betas = rescale_zero_terminal_snr(self.betas)
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+ # At every step in ddim, we are looking into the previous alphas_cumprod
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
+ # whether we use the final alpha of the "non-previous" one.
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ # setable values
+ self.num_inference_steps = None
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
+
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+ Args:
+ sample (`torch.FloatTensor`):
+ The input sample.
+ timestep (`int`, *optional*):
+ The current timestep in the diffusion chain.
+ Returns:
+ `torch.FloatTensor`:
+ A scaled input sample.
+ """
+ return sample
+
+ def _get_variance(self, timestep, prev_timestep):
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+
+ return variance
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
+ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+ """
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
+ https://arxiv.org/abs/2205.11487
+ """
+ dtype = sample.dtype
+ batch_size, channels, height, width = sample.shape
+
+ if dtype not in (torch.float32, torch.float64):
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
+
+ # Flatten sample for doing quantile calculation along each image
+ sample = sample.reshape(batch_size, channels * height * width)
+
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
+
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+ s = torch.clamp(
+ s, min=1, max=self.config.sample_max_value
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+ sample = sample.reshape(batch_size, channels, height, width)
+ sample = sample.to(dtype)
+
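+ # Illustrative behavior (assuming `sample_max_value` >= 2): if the 99.5th-percentile
+ # (dynamic_thresholding_ratio) absolute value of a sample is s = 2.0, the sample is
+ # clamped to [-2, 2] and divided by s, so the result lies in [-1, 1]; samples already
+ # inside [-1, 1] are left unchanged because s is clamped to a minimum of 1.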
+ return sample
+
+ def set_timesteps(
+ self, strength, num_inference_steps: int, lcm_origin_steps: int, device: Union[str, torch.device] = None
+ ):
+ """
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+ Args:
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model.
+ """
+
+ if num_inference_steps > self.config.num_train_timesteps:
+ raise ValueError(
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
+ f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+ f" maximal {self.config.num_train_timesteps} timesteps."
+ )
+
+ self.num_inference_steps = num_inference_steps
+
+ # LCM Timesteps Setting: # Linear Spacing
+ c = self.config.num_train_timesteps // lcm_origin_steps
+ lcm_origin_timesteps = (
+ np.asarray(list(range(1, int(lcm_origin_steps * strength) + 1))) * c - 1
+ ) # LCM Training Steps Schedule
+ skipping_step = len(lcm_origin_timesteps) // num_inference_steps
+ timesteps = lcm_origin_timesteps[::-skipping_step][:num_inference_steps] # LCM Inference Steps Schedule
+
+ self.timesteps = torch.from_numpy(timesteps.copy()).to(device)
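+ # Worked example (illustrative): with num_train_timesteps=1000, lcm_origin_steps=50 and
+ # strength=1.0, c = 20 and lcm_origin_timesteps = [19, 39, ..., 999]; for
+ # num_inference_steps=4 the skipping step is 12, giving timesteps [999, 759, 519, 279].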
+
+ def get_scalings_for_boundary_condition_discrete(self, t):
+ self.sigma_data = 0.5 # Default: 0.5
+
+ # Dividing t by 0.1 makes this almost a delta function at t = 0.
+ c_skip = self.sigma_data**2 / ((t / 0.1) ** 2 + self.sigma_data**2)
+ c_out = (t / 0.1) / ((t / 0.1) ** 2 + self.sigma_data**2) ** 0.5
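+ # Boundary-condition check (illustrative): at t = 0 these give c_skip = 1 and c_out = 0,
+ # so the consistency function reduces to the identity; for large t (e.g. t = 999),
+ # c_skip is approximately 0 and c_out approximately 1, so the denoised prediction dominates.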
+ return c_skip, c_out
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timeindex: int,
+ timestep: int,
+ sample: torch.FloatTensor,
+ eta: float = 0.0,
+ use_clipped_model_output: bool = False,
+ generator=None,
+ variance_noise: Optional[torch.FloatTensor] = None,
+ return_dict: bool = True,
+ ) -> Union[LCMSchedulerOutput, Tuple]:
+ """
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
+ process from the learned model outputs (most often the predicted noise).
+ Args:
+ model_output (`torch.FloatTensor`):
+ The direct output from learned diffusion model.
+ timestep (`float`):
+ The current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ A current instance of a sample created by the diffusion process.
+ eta (`float`):
+ The weight of noise for added noise in diffusion step.
+ use_clipped_model_output (`bool`, defaults to `False`):
+ If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
+ because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
+ clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
+ `use_clipped_model_output` has no effect.
+ generator (`torch.Generator`, *optional*):
+ A random number generator.
+ variance_noise (`torch.FloatTensor`):
+ Alternative to generating noise with `generator` by directly providing the noise for the variance
+ itself. Useful for methods such as [`CycleDiffusion`].
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
+ Returns:
+ [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
+ If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
+ tuple is returned where the first element is the sample tensor.
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ # 1. get previous step value
+ prev_timeindex = timeindex + 1
+ if prev_timeindex < len(self.timesteps):
+ prev_timestep = self.timesteps[prev_timeindex]
+ else:
+ prev_timestep = timestep
+
+ # 2. compute alphas, betas
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ # 3. Get scalings for boundary conditions
+ c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)
+
+ # 4. Different Parameterization:
+ parameterization = self.config.prediction_type
+
+ if parameterization == "epsilon": # noise-prediction
+ pred_x0 = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
+
+ elif parameterization == "sample": # x-prediction
+ pred_x0 = model_output
+
+ elif parameterization == "v_prediction": # v-prediction
+ pred_x0 = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
+
+ # 4. Denoise model output using boundary conditions
+ denoised = c_out * pred_x0 + c_skip * sample
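+ # This is the consistency-model parameterization f(x, t) = c_skip(t) * x + c_out(t) * pred_x0,
+ # which satisfies the boundary condition f(x, 0) = x given the scalings above.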
+
+ # 5. Sample z ~ N(0, I), For MultiStep Inference
+ # Noise is not used for one-step sampling.
+ if len(self.timesteps) > 1:
+ noise = torch.randn(model_output.shape).to(model_output.device)
+ prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
+ else:
+ prev_sample = denoised
+
+ if not return_dict:
+ return (prev_sample, denoised)
+
+ return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.FloatTensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+ alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+ timesteps = timesteps.to(original_samples.device)
+
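+ # Standard forward-diffusion closed form: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise,
+ # evaluated per sample via the broadcasting below.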
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
+ def get_velocity(
+ self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
+ ) -> torch.FloatTensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as sample
+ alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
+ timesteps = timesteps.to(sample.device)
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(sample.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
+ return velocity
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/examples/community/latent_consistency_interpolate.py b/diffusers/examples/community/latent_consistency_interpolate.py
new file mode 100644
index 0000000000000000000000000000000000000000..1058bf6598c86cc83d090c43333957d26b1cd9a3
--- /dev/null
+++ b/diffusers/examples/community/latent_consistency_interpolate.py
@@ -0,0 +1,1051 @@
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
+from diffusers.schedulers import LCMScheduler
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> import numpy as np
+
+ >>> from diffusers import DiffusionPipeline
+
+ >>> pipe = DiffusionPipeline.from_pretrained("SimianLuo/LCM_Dreamshaper_v7", custom_pipeline="latent_consistency_interpolate")
+ >>> # To save GPU memory, torch.float16 can be used, but it may compromise image quality.
+ >>> pipe.to(torch_device="cuda", torch_dtype=torch.float32)
+
+ >>> prompts = ["A cat", "A dog", "A horse"]
+ >>> num_inference_steps = 4
+ >>> num_interpolation_steps = 24
+ >>> seed = 1337
+
+ >>> torch.manual_seed(seed)
+ >>> np.random.seed(seed)
+
+ >>> images = pipe(
+ prompt=prompts,
+ height=512,
+ width=512,
+ num_inference_steps=num_inference_steps,
+ num_interpolation_steps=num_interpolation_steps,
+ guidance_scale=8.0,
+ embedding_interpolation_type="lerp",
+ latent_interpolation_type="slerp",
+ process_batch_size=4, # Make it higher or lower based on your GPU memory
+ generator=torch.Generator(seed),
+ )
+
+ >>> # Save the images as a video
+ >>> import imageio
+ >>> from PIL import Image
+
+ >>> def pil_to_video(images: List[Image.Image], filename: str, fps: int = 60) -> None:
+ frames = [np.array(image) for image in images]
+ with imageio.get_writer(filename, fps=fps) as video_writer:
+ for frame in frames:
+ video_writer.append_data(frame)
+
+ >>> pil_to_video(images, "lcm_interpolate.mp4", fps=24)
+ ```
+"""
+
+
+def lerp(
+ v0: Union[torch.Tensor, np.ndarray],
+ v1: Union[torch.Tensor, np.ndarray],
+ t: Union[float, torch.Tensor, np.ndarray],
+) -> Union[torch.Tensor, np.ndarray]:
+ """
+ Linearly interpolate between two vectors/tensors.
+
+ Args:
+ v0 (`torch.Tensor` or `np.ndarray`): First vector/tensor.
+ v1 (`torch.Tensor` or `np.ndarray`): Second vector/tensor.
+ t: (`float`, `torch.Tensor`, or `np.ndarray`):
+ Interpolation factor. If float, must be between 0 and 1. If np.ndarray or
+ torch.Tensor, must be one dimensional with values between 0 and 1.
+
+ Returns:
+ Union[torch.Tensor, np.ndarray]
+ Interpolated vector/tensor between v0 and v1.
+ """
+ inputs_are_torch = False
+ t_is_float = False
+
+ if isinstance(v0, torch.Tensor):
+ inputs_are_torch = True
+ input_device = v0.device
+ v0 = v0.cpu().numpy()
+ v1 = v1.cpu().numpy()
+
+ if isinstance(t, torch.Tensor):
+ inputs_are_torch = True
+ input_device = t.device
+ t = t.cpu().numpy()
+ elif isinstance(t, float):
+ t_is_float = True
+ t = np.array([t])
+
+ t = t[..., None]
+ v0 = v0[None, ...]
+ v1 = v1[None, ...]
+ v2 = (1 - t) * v0 + t * v1
+
+ if t_is_float and v0.ndim > 1:
+ assert v2.shape[0] == 1
+ v2 = np.squeeze(v2, axis=0)
+ if inputs_are_torch:
+ v2 = torch.from_numpy(v2).to(input_device)
+
+ return v2
+
+
+def slerp(
+ v0: Union[torch.Tensor, np.ndarray],
+ v1: Union[torch.Tensor, np.ndarray],
+ t: Union[float, torch.Tensor, np.ndarray],
+ DOT_THRESHOLD=0.9995,
+) -> Union[torch.Tensor, np.ndarray]:
+ """
+ Spherical linear interpolation between two vectors/tensors.
+
+ Args:
+ v0 (`torch.Tensor` or `np.ndarray`): First vector/tensor.
+ v1 (`torch.Tensor` or `np.ndarray`): Second vector/tensor.
+ t: (`float`, `torch.Tensor`, or `np.ndarray`):
+ Interpolation factor. If float, must be between 0 and 1. If np.ndarray or
+ torch.Tensor, must be one dimensional with values between 0 and 1.
+ DOT_THRESHOLD (`float`, *optional*, default=0.9995):
+ Threshold for when to use linear interpolation instead of spherical interpolation.
+
+ Returns:
+ `torch.Tensor` or `np.ndarray`:
+ Interpolated vector/tensor between v0 and v1.
+ """
+ inputs_are_torch = False
+ t_is_float = False
+
+ if isinstance(v0, torch.Tensor):
+ inputs_are_torch = True
+ input_device = v0.device
+ v0 = v0.cpu().numpy()
+ v1 = v1.cpu().numpy()
+
+ if isinstance(t, torch.Tensor):
+ inputs_are_torch = True
+ input_device = t.device
+ t = t.cpu().numpy()
+ elif isinstance(t, float):
+ t_is_float = True
+ t = np.array([t], dtype=v0.dtype)
+
+ dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
+ if np.abs(dot) > DOT_THRESHOLD:
+ # v1 and v2 are close to parallel
+ # Use linear interpolation instead
+ v2 = lerp(v0, v1, t)
+ else:
+ theta_0 = np.arccos(dot)
+ sin_theta_0 = np.sin(theta_0)
+ theta_t = theta_0 * t
+ sin_theta_t = np.sin(theta_t)
+ s0 = np.sin(theta_0 - theta_t) / sin_theta_0
+ s1 = sin_theta_t / sin_theta_0
+ s0 = s0[..., None]
+ s1 = s1[..., None]
+ v0 = v0[None, ...]
+ v1 = v1[None, ...]
+ v2 = s0 * v0 + s1 * v1
+
+ if t_is_float and v0.ndim > 1:
+ assert v2.shape[0] == 1
+ v2 = np.squeeze(v2, axis=0)
+ if inputs_are_torch:
+ v2 = torch.from_numpy(v2).to(input_device)
+
+ return v2
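+# Illustrative usage (shapes assumed): lerp(torch.zeros(4), torch.ones(4), 0.25) returns a
+# length-4 tensor filled with 0.25, while passing t = np.linspace(0, 1, 8) stacks the 8
+# interpolants along a new leading dimension; slerp behaves the same way but interpolates
+# along the arc between the two (flattened) vectors, falling back to lerp when they are
+# nearly parallel (|dot| > DOT_THRESHOLD).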
+
+
+class LatentConsistencyModelWalkPipeline(
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
+ r"""
+ Pipeline for text-to-image generation using a latent consistency model.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+ text_encoder ([`~transformers.CLIPTextModel`]):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ tokenizer ([`~transformers.CLIPTokenizer`]):
+ A `CLIPTokenizer` to tokenize text.
+ unet ([`UNet2DConditionModel`]):
+ A `UNet2DConditionModel` to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Currently only
+ supports [`LCMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ about a model's potential harms.
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
+ requires_safety_checker (`bool`, *optional*, defaults to `True`):
+ Whether the pipeline requires a safety checker component.
+ """
+
+ model_cpu_offload_seq = "text_encoder->unet->vae"
+ _optional_components = ["safety_checker", "feature_extractor"]
+ _exclude_from_cpu_offload = ["safety_checker"]
+ _callback_tensor_inputs = ["latents", "denoised", "prompt_embeds", "w_embedding"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: LCMScheduler,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and
+ for processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
+ r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
+
+ The suffixes after the scaling factors represent the stages where they are being applied.
+
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
+ that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+ Args:
+ s1 (`float`):
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+ mitigate "oversmoothing effect" in the enhanced denoising process.
+ s2 (`float`):
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+ mitigate "oversmoothing effect" in the enhanced denoising process.
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+ """
+ if not hasattr(self, "unet"):
+ raise ValueError("The pipeline must have `unet` for using FreeU.")
+ self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
+ def disable_freeu(self):
+ """Disables the FreeU mechanism if enabled."""
+ self.unet.disable_freeu()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ if clip_skip is None:
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+ prompt_embeds = prompt_embeds[0]
+ else:
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
+ )
+ # Access the `hidden_states` first, that contains a tuple of
+ # all the hidden states from the encoder layers. Then index into
+ # the tuple to access the hidden states from the desired layer.
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+ # We also need to apply the final LayerNorm here to not mess with the
+ # representations. The `last_hidden_states` that we typically use for
+ # obtaining the final prompt representations passes through the LayerNorm
+ # layer.
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
+
+ if self.text_encoder is not None:
+ prompt_embeds_dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ prompt_embeds_dtype = self.unet.dtype
+ else:
+ prompt_embeds_dtype = prompt_embeds.dtype
+
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ w (`torch.Tensor`):
+ guidance scale values to generate embedding vectors for (scaled internally by 1000)
+ embedding_dim (`int`, *optional*, defaults to 512):
+ dimension of the embeddings to generate
+ dtype:
+ data type of the generated embeddings
+
+ Returns:
+ `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
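+ # Illustrative usage (values assumed): self.get_guidance_scale_embedding(torch.tensor([7.0]), embedding_dim=256)
+ # returns a (1, 256) sinusoidal embedding; the sampling loop below passes guidance_scale - 1
+ # as `w` and uses the UNet's `time_cond_proj_dim` as `embedding_dim`.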
+ return emb
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Currently StableDiffusionPipeline.check_inputs with negative prompt stuff removed
+ def check_inputs(
+ self,
+ prompt: Union[str, List[str]],
+ height: int,
+ width: int,
+ callback_steps: int,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ @torch.no_grad()
+ def interpolate_embedding(
+ self,
+ start_embedding: torch.FloatTensor,
+ end_embedding: torch.FloatTensor,
+ num_interpolation_steps: Union[int, List[int]],
+ interpolation_type: str,
+ ) -> torch.FloatTensor:
+ if interpolation_type == "lerp":
+ interpolation_fn = lerp
+ elif interpolation_type == "slerp":
+ interpolation_fn = slerp
+ else:
+ raise ValueError(
+ f"embedding_interpolation_type must be one of ['lerp', 'slerp'], got {interpolation_type}."
+ )
+
+ embedding = torch.cat([start_embedding, end_embedding])
+ steps = torch.linspace(0, 1, num_interpolation_steps, dtype=embedding.dtype).cpu().numpy()
+ steps = np.expand_dims(steps, axis=tuple(range(1, embedding.ndim)))
+ interpolations = []
+
+ # Interpolate between text embeddings
+ # TODO(aryan): Think of a better way of doing this
+ # See if it can be done in parallel instead
+ for i in range(embedding.shape[0] - 1):
+ interpolations.append(interpolation_fn(embedding[i], embedding[i + 1], steps).squeeze(dim=1))
+
+ interpolations = torch.cat(interpolations)
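+ # Shape note (illustrative): for start/end embeddings of the usual CLIP shape (1, 77, 768)
+ # and num_interpolation_steps=24, this returns a (24, 77, 768) tensor of interpolated
+ # prompt embeddings.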
+ return interpolations
+
+ @torch.no_grad()
+ def interpolate_latent(
+ self,
+ start_latent: torch.FloatTensor,
+ end_latent: torch.FloatTensor,
+ num_interpolation_steps: Union[int, List[int]],
+ interpolation_type: str,
+ ) -> torch.FloatTensor:
+ if interpolation_type == "lerp":
+ interpolation_fn = lerp
+ elif interpolation_type == "slerp":
+ interpolation_fn = slerp
+
+ latent = torch.cat([start_latent, end_latent])
+ steps = torch.linspace(0, 1, num_interpolation_steps, dtype=latent.dtype).cpu().numpy()
+ steps = np.expand_dims(steps, axis=tuple(range(1, latent.ndim)))
+ interpolations = []
+
+ # Interpolate between latents
+ # TODO: Think of a better way of doing this
+ # See if it can be done in parallel instead
+ for i in range(latent.shape[0] - 1):
+ interpolations.append(interpolation_fn(latent[i], latent[i + 1], steps).squeeze(dim=1))
+
+ return torch.cat(interpolations)
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 4,
+ num_interpolation_steps: int = 8,
+ original_inference_steps: int = None,
+ guidance_scale: float = 8.5,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ embedding_interpolation_type: str = "lerp",
+ latent_interpolation_type: str = "slerp",
+ process_batch_size: int = 4,
+ **kwargs,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 4):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ original_inference_steps (`int`, *optional*):
+ The original number of inference steps used to generate a linearly-spaced timestep schedule, from
+ which we will draw `num_inference_steps` evenly spaced timesteps as our final timestep schedule,
+ following the Skipping-Step method in the paper (see Section 4.3). If not set this will default to the
+ scheduler's `original_inference_steps` attribute.
+ guidance_scale (`float`, *optional*, defaults to 8.5):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ Note that the original latent consistency models paper uses a different CFG formulation where the
+ guidance scales are decreased by 1 (so in the paper formulation CFG is enabled when `guidance_scale >
+ 0`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ embedding_interpolation_type (`str`, *optional*, defaults to `"lerp"`):
+ The type of interpolation to use for interpolating between text embeddings. Choose between `"lerp"` and `"slerp"`.
+ latent_interpolation_type (`str`, *optional*, defaults to `"slerp"`):
+ The type of interpolation to use for interpolating between latents. Choose between `"lerp"` and `"slerp"`.
+ process_batch_size (`int`, *optional*, defaults to 4):
+ The batch size to use for processing the images. This is useful when generating a large number of images
+ and you want to avoid running out of memory.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
+ "not-safe-for-work" (nsfw) content.
+ """
+
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+ )
+
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(prompt, height, width, callback_steps, prompt_embeds, callback_on_step_end_tensor_inputs)
+ self._guidance_scale = guidance_scale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+ if batch_size < 2:
+ raise ValueError(f"`prompt` must have length of atleast 2 but found {batch_size}")
+ if num_images_per_prompt != 1:
+ raise ValueError("`num_images_per_prompt` must be `1` as no other value is supported yet")
+ if prompt_embeds is not None:
+ raise ValueError("`prompt_embeds` must be None since it is not supported yet")
+ if latents is not None:
+ raise ValueError("`latents` must be None since it is not supported yet")
+
+ device = self._execution_device
+ # do_classifier_free_guidance = guidance_scale > 1.0
+
+ lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+
+ self.scheduler.set_timesteps(num_inference_steps, device, original_inference_steps=original_inference_steps)
+ timesteps = self.scheduler.timesteps
+ num_channels_latents = self.unet.config.in_channels
+ # bs = batch_size * num_images_per_prompt
+
+ # 3. Encode initial input prompt
+ prompt_embeds_1, _ = self.encode_prompt(
+ prompt[:1],
+ device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=False,
+ negative_prompt=None,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=None,
+ lora_scale=lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ # 4. Prepare initial latent variables
+ latents_1 = self.prepare_latents(
+ 1,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds_1.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, None)
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+ images = []
+
+ # 5. Iterate over prompts and perform latent walk. Note that we do this two prompts at a time
+ # otherwise the memory usage ends up being too high.
+ with self.progress_bar(total=batch_size - 1) as prompt_progress_bar:
+ for i in range(1, batch_size):
+ # 6. Encode current prompt
+ prompt_embeds_2, _ = self.encode_prompt(
+ prompt[i : i + 1],
+ device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=False,
+ negative_prompt=None,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=None,
+ lora_scale=lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ # 7. Prepare current latent variables
+ latents_2 = self.prepare_latents(
+ 1,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds_2.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 8. Interpolate between previous and current prompt embeddings and latents
+ inference_embeddings = self.interpolate_embedding(
+ start_embedding=prompt_embeds_1,
+ end_embedding=prompt_embeds_2,
+ num_interpolation_steps=num_interpolation_steps,
+ interpolation_type=embedding_interpolation_type,
+ )
+ inference_latents = self.interpolate_latent(
+ start_latent=latents_1,
+ end_latent=latents_2,
+ num_interpolation_steps=num_interpolation_steps,
+ interpolation_type=latent_interpolation_type,
+ )
+ next_prompt_embeds = inference_embeddings[-1:].detach().clone()
+ next_latents = inference_latents[-1:].detach().clone()
+ bs = num_interpolation_steps
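+ # Each prompt pair contributes `num_interpolation_steps` frames; the last interpolant is
+ # carried over as the start of the next pair (see the end of this loop), so the latent
+ # walk stays continuous across prompts.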
+
+ # 9. Perform inference in batches. Note the use of `process_batch_size` to control the batch size
+ # of the inference. This is useful for reducing memory usage and can be configured based on the
+ # available GPU memory.
+ with self.progress_bar(
+ total=(bs + process_batch_size - 1) // process_batch_size
+ ) as batch_progress_bar:
+ for batch_index in range(0, bs, process_batch_size):
+ batch_inference_latents = inference_latents[batch_index : batch_index + process_batch_size]
+ batch_inference_embeddings = inference_embeddings[
+ batch_index : batch_index + process_batch_size
+ ]
+
+ self.scheduler.set_timesteps(
+ num_inference_steps, device, original_inference_steps=original_inference_steps
+ )
+ timesteps = self.scheduler.timesteps
+
+ current_bs = batch_inference_embeddings.shape[0]
+ w = torch.tensor(self.guidance_scale - 1).repeat(current_bs)
+ w_embedding = self.get_guidance_scale_embedding(
+ w, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents_1.dtype)
+
+ # 10. Perform inference for current batch
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for index, t in enumerate(timesteps):
+ batch_inference_latents = batch_inference_latents.to(batch_inference_embeddings.dtype)
+
+ # model prediction (v-prediction, eps, x)
+ model_pred = self.unet(
+ batch_inference_latents,
+ t,
+ timestep_cond=w_embedding,
+ encoder_hidden_states=batch_inference_embeddings,
+ cross_attention_kwargs=self.cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ batch_inference_latents, denoised = self.scheduler.step(
+ model_pred, t, batch_inference_latents, **extra_step_kwargs, return_dict=False
+ )
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, index, t, callback_kwargs)
+
+ batch_inference_latents = callback_outputs.pop("latents", batch_inference_latents)
+ batch_inference_embeddings = callback_outputs.pop(
+ "prompt_embeds", batch_inference_embeddings
+ )
+ w_embedding = callback_outputs.pop("w_embedding", w_embedding)
+ denoised = callback_outputs.pop("denoised", denoised)
+
+ # call the callback, if provided
+ if index == len(timesteps) - 1 or (
+ (index + 1) > num_warmup_steps and (index + 1) % self.scheduler.order == 0
+ ):
+ progress_bar.update()
+ if callback is not None and index % callback_steps == 0:
+ step_idx = index // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, batch_inference_latents)
+
+ denoised = denoised.to(batch_inference_embeddings.dtype)
+
+ # Note: This is not supported because you would get black images in your latent walk if
+ # NSFW concept is detected
+ # if not output_type == "latent":
+ # image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0]
+ # image, has_nsfw_concept = self.run_safety_checker(image, device, inference_embeddings.dtype)
+ # else:
+ # image = denoised
+ # has_nsfw_concept = None
+
+ # if has_nsfw_concept is None:
+ # do_denormalize = [True] * image.shape[0]
+ # else:
+ # do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0]
+ do_denormalize = [True] * image.shape[0]
+ has_nsfw_concept = None
+
+ image = self.image_processor.postprocess(
+ image, output_type=output_type, do_denormalize=do_denormalize
+ )
+ images.append(image)
+
+ batch_progress_bar.update()
+
+ prompt_embeds_1 = next_prompt_embeds
+ latents_1 = next_latents
+
+ prompt_progress_bar.update()
+
+ # 11. Determine what should be returned
+ if output_type == "pil":
+ images = [image for image_list in images for image in image_list]
+ elif output_type == "np":
+ images = np.concatenate(images)
+ elif output_type == "pt":
+ images = torch.cat(images)
+ else:
+ raise ValueError("`output_type` must be one of 'pil', 'np' or 'pt'.")
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (images, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)
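+
+ # For illustration of the `process_batch_size` batching used in this call: with
+ # num_interpolation_steps=16 and process_batch_size=4, each prompt pair is processed in
+ # (16 + 4 - 1) // 4 = 4 mini-batches of 4 interpolated samples; with process_batch_size=5 the
+ # mini-batch sizes are 5, 5, 5 and 1. Smaller values trade throughput for lower peak GPU memory.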
diff --git a/diffusers/examples/community/latent_consistency_txt2img.py b/diffusers/examples/community/latent_consistency_txt2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..85bcc2cf94cbec548af5f7f62a78909b28b752ed
--- /dev/null
+++ b/diffusers/examples/community/latent_consistency_txt2img.py
@@ -0,0 +1,728 @@
+# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
+# and https://github.com/hojonathanho/diffusion
+
+import math
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers import AutoencoderKL, ConfigMixin, DiffusionPipeline, SchedulerMixin, UNet2DConditionModel, logging
+from diffusers.configuration_utils import register_to_config
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.utils import BaseOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class LatentConsistencyModelPipeline(DiffusionPipeline):
+ _optional_components = ["scheduler"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: "LCMScheduler",
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ scheduler = (
+ scheduler
+ if scheduler is not None
+ else LCMScheduler(
+ beta_start=0.00085, beta_end=0.0120, beta_schedule="scaled_linear", prediction_type="epsilon"
+ )
+ )
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ """
+
+ if prompt is not None and isinstance(prompt, str):
+ pass
+ elif prompt is not None and isinstance(prompt, list):
+ len(prompt)
+ else:
+ prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ if self.text_encoder is not None:
+ prompt_embeds_dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ prompt_embeds_dtype = self.unet.dtype
+ else:
+ prompt_embeds_dtype = prompt_embeds.dtype
+
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # Don't need to get uncond prompt embedding because of LCM Guided Distillation
+ return prompt_embeds
+
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if latents is None:
+ latents = torch.randn(shape, dtype=dtype).to(device)
+ else:
+ latents = latents.to(device)
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def get_w_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+ """
+ see https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+ Args:
+ w: torch.Tensor: guidance scale values for which to generate embedding vectors
+ embedding_dim: int: dimension of the embeddings to generate
+ dtype: data type of the generated embeddings
+ Returns:
+ embedding vectors with shape `(len(w), embedding_dim)`
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
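+
+ # For example, w = torch.tensor([8.0]) with embedding_dim=256 is scaled to 8000.0 and mapped to a
+ # (1, 256) tensor whose first 128 entries are sines and last 128 are cosines over log-spaced
+ # frequencies, mirroring the usual sinusoidal timestep embedding.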
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ height: Optional[int] = 768,
+ width: Optional[int] = 768,
+ guidance_scale: float = 7.5,
+ num_images_per_prompt: Optional[int] = 1,
+ latents: Optional[torch.FloatTensor] = None,
+ num_inference_steps: int = 4,
+ lcm_origin_steps: int = 50,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # do_classifier_free_guidance = guidance_scale > 0.0 # In LCM Implementation: cfg_noise = noise_cond + cfg_scale * (noise_cond - noise_uncond) , (cfg_scale > 0.0 using CFG)
+
+ # 3. Encode input prompt
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ prompt_embeds=prompt_embeds,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, lcm_origin_steps)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variable
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ latents,
+ )
+ bs = batch_size * num_images_per_prompt
+
+ # 6. Get Guidance Scale Embedding
+ w = torch.tensor(guidance_scale).repeat(bs)
+ w_embedding = self.get_w_embedding(w, embedding_dim=256).to(device=device, dtype=latents.dtype)
+
+ # 7. LCM MultiStep Sampling Loop:
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ ts = torch.full((bs,), t, device=device, dtype=torch.long)
+ latents = latents.to(prompt_embeds.dtype)
+
+ # model prediction (v-prediction, eps, x)
+ model_pred = self.unet(
+ latents,
+ ts,
+ timestep_cond=w_embedding,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents, denoised = self.scheduler.step(model_pred, i, t, latents, return_dict=False)
+
+ # # call the callback, if provided
+ # if i == len(timesteps) - 1:
+ progress_bar.update()
+
+ denoised = denoised.to(prompt_embeds.dtype)
+ if not output_type == "latent":
+ image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = denoised
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
+class LCMSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's `step` function output.
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
+ `pred_original_sample` can be used to preview progress or for guidance.
+ """
+
+ prev_sample: torch.FloatTensor
+ denoised: Optional[torch.FloatTensor] = None
+
+
+# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
+def betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ max_beta=0.999,
+ alpha_transform_type="cosine",
+):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
+ Choose from `cosine` or `exp`
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+ if alpha_transform_type == "cosine":
+
+ def alpha_bar_fn(t):
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ elif alpha_transform_type == "exp":
+
+ def alpha_bar_fn(t):
+ return math.exp(t * -12.0)
+
+ else:
+ raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
+
+
+def rescale_zero_terminal_snr(betas):
+ """
+ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
+ Args:
+ betas (`torch.FloatTensor`):
+ the betas that the scheduler is being initialized with.
+ Returns:
+ `torch.FloatTensor`: rescaled betas with zero terminal SNR
+ """
+ # Convert betas to alphas_bar_sqrt
+ alphas = 1.0 - betas
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+ # Store old values.
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+ # Shift so the last timestep is zero.
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+ # Scale so the first timestep is back to the old value.
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+ # Convert alphas_bar_sqrt to betas
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
+ alphas = torch.cat([alphas_bar[0:1], alphas])
+ betas = 1 - alphas
+
+ return betas
+
+
+class LCMScheduler(SchedulerMixin, ConfigMixin):
+ """
+ `LCMScheduler` implements the multi-step sampling procedure of Latent Consistency Models (LCMs),
+ combining a consistency-model denoising step with DDPM-style re-noising between inference steps.
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+ methods the library implements for all schedulers such as loading and saving.
+ Args:
+ num_train_timesteps (`int`, defaults to 1000):
+ The number of diffusion steps to train the model.
+ beta_start (`float`, defaults to 0.0001):
+ The starting `beta` value of inference.
+ beta_end (`float`, defaults to 0.02):
+ The final `beta` value.
+ beta_schedule (`str`, defaults to `"linear"`):
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`np.ndarray`, *optional*):
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
+ clip_sample (`bool`, defaults to `True`):
+ Clip the predicted sample for numerical stability.
+ clip_sample_range (`float`, defaults to 1.0):
+ The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
+ set_alpha_to_one (`bool`, defaults to `True`):
+ Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
+ there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+ otherwise it uses the alpha value at step 0.
+ steps_offset (`int`, defaults to 0):
+ An offset added to the inference steps. You can use a combination of `offset=1` and
+ `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
+ Diffusion.
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
+ `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
+ Video](https://imagen.research.google/video/paper.pdf) paper).
+ thresholding (`bool`, defaults to `False`):
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
+ as Stable Diffusion.
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
+ sample_max_value (`float`, defaults to 1.0):
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
+ timestep_spacing (`str`, defaults to `"leading"`):
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
+ """
+
+ # _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ clip_sample: bool = True,
+ set_alpha_to_one: bool = True,
+ steps_offset: int = 0,
+ prediction_type: str = "epsilon",
+ thresholding: bool = False,
+ dynamic_thresholding_ratio: float = 0.995,
+ clip_sample_range: float = 1.0,
+ sample_max_value: float = 1.0,
+ timestep_spacing: str = "leading",
+ rescale_betas_zero_snr: bool = False,
+ ):
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+ raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+
+ # Rescale for zero SNR
+ if rescale_betas_zero_snr:
+ self.betas = rescale_zero_terminal_snr(self.betas)
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+ # At every step in ddim, we are looking into the previous alphas_cumprod
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
+ # whether we use the final alpha of the "non-previous" one.
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ # setable values
+ self.num_inference_steps = None
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
+
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+ Args:
+ sample (`torch.FloatTensor`):
+ The input sample.
+ timestep (`int`, *optional*):
+ The current timestep in the diffusion chain.
+ Returns:
+ `torch.FloatTensor`:
+ A scaled input sample.
+ """
+ return sample
+
+ def _get_variance(self, timestep, prev_timestep):
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+
+ return variance
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
+ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+ """
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
+ https://arxiv.org/abs/2205.11487
+ """
+ dtype = sample.dtype
+ batch_size, channels, height, width = sample.shape
+
+ if dtype not in (torch.float32, torch.float64):
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
+
+ # Flatten sample for doing quantile calculation along each image
+ sample = sample.reshape(batch_size, channels * height * width)
+
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
+
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+ s = torch.clamp(
+ s, min=1, max=self.config.sample_max_value
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+ sample = sample.reshape(batch_size, channels, height, width)
+ sample = sample.to(dtype)
+
+ return sample
+
+ def set_timesteps(self, num_inference_steps: int, lcm_origin_steps: int, device: Union[str, torch.device] = None):
+ """
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+ Args:
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model.
+ """
+
+ if num_inference_steps > self.config.num_train_timesteps:
+ raise ValueError(
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
+ f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+ f" maximal {self.config.num_train_timesteps} timesteps."
+ )
+
+ self.num_inference_steps = num_inference_steps
+
+ # LCM Timesteps Setting: # Linear Spacing
+ c = self.config.num_train_timesteps // lcm_origin_steps
+ lcm_origin_timesteps = np.asarray(list(range(1, lcm_origin_steps + 1))) * c - 1 # LCM Training Steps Schedule
+ skipping_step = len(lcm_origin_timesteps) // num_inference_steps
+ timesteps = lcm_origin_timesteps[::-skipping_step][:num_inference_steps] # LCM Inference Steps Schedule
+
+ self.timesteps = torch.from_numpy(timesteps.copy()).to(device)
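+
+ # For example, with the default num_train_timesteps=1000, num_inference_steps=4 and
+ # lcm_origin_steps=50: c = 20, lcm_origin_timesteps = [19, 39, ..., 999], skipping_step = 12,
+ # and the resulting inference schedule is timesteps = [999, 759, 519, 279].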
+
+ def get_scalings_for_boundary_condition_discrete(self, t):
+ self.sigma_data = 0.5 # Default: 0.5
+
+ # Dividing t by 0.1 makes c_skip act almost like a delta function at t=0.
+ c_skip = self.sigma_data**2 / ((t / 0.1) ** 2 + self.sigma_data**2)
+ c_out = (t / 0.1) / ((t / 0.1) ** 2 + self.sigma_data**2) ** 0.5
+ return c_skip, c_out
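+
+ # Boundary-condition check: at t=0 this gives c_skip=1 and c_out=0, so the sample is passed
+ # through unchanged, while for large t (e.g. t=999) c_skip≈0 and c_out≈1, so the denoised
+ # estimate is dominated by the predicted x0.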
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timeindex: int,
+ timestep: int,
+ sample: torch.FloatTensor,
+ eta: float = 0.0,
+ use_clipped_model_output: bool = False,
+ generator=None,
+ variance_noise: Optional[torch.FloatTensor] = None,
+ return_dict: bool = True,
+ ) -> Union[LCMSchedulerOutput, Tuple]:
+ """
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
+ process from the learned model outputs (most often the predicted noise).
+ Args:
+ model_output (`torch.FloatTensor`):
+ The direct output from learned diffusion model.
+ timestep (`float`):
+ The current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ A current instance of a sample created by the diffusion process.
+ eta (`float`):
+ The weight of noise for added noise in diffusion step.
+ use_clipped_model_output (`bool`, defaults to `False`):
+ If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
+ because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
+ clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
+ `use_clipped_model_output` has no effect.
+ generator (`torch.Generator`, *optional*):
+ A random number generator.
+ variance_noise (`torch.FloatTensor`):
+ Alternative to generating noise with `generator` by directly providing the noise for the variance
+ itself. Useful for methods such as [`CycleDiffusion`].
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
+ Returns:
+ [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
+ If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
+ tuple is returned where the first element is the sample tensor.
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ # 1. get previous step value
+ prev_timeindex = timeindex + 1
+ if prev_timeindex < len(self.timesteps):
+ prev_timestep = self.timesteps[prev_timeindex]
+ else:
+ prev_timestep = timestep
+
+ # 2. compute alphas, betas
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ # 3. Get scalings for boundary conditions
+ c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)
+
+ # 4. Different Parameterization:
+ parameterization = self.config.prediction_type
+
+ if parameterization == "epsilon": # noise-prediction
+ pred_x0 = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
+
+ elif parameterization == "sample": # x-prediction
+ pred_x0 = model_output
+
+ elif parameterization == "v_prediction": # v-prediction
+ pred_x0 = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
+ else:
+ raise ValueError(f"Unsupported prediction_type: {parameterization}")
+
+ # 5. Denoise model output using boundary conditions
+ denoised = c_out * pred_x0 + c_skip * sample
+
+ # 6. Sample z ~ N(0, I) for multi-step inference
+ # Noise is not used for one-step sampling.
+ if len(self.timesteps) > 1:
+ noise = torch.randn(model_output.shape).to(model_output.device)
+ prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
+ else:
+ prev_sample = denoised
+
+ if not return_dict:
+ return (prev_sample, denoised)
+
+ return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)
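+
+ # In multi-step sampling the denoised estimate is re-noised to the previous timestep,
+ # prev_sample = sqrt(alpha_bar_prev) * denoised + sqrt(1 - alpha_bar_prev) * z with z ~ N(0, I);
+ # with a single inference step the denoised estimate itself is returned.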
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.FloatTensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+ alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+ timesteps = timesteps.to(original_samples.device)
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
+ def get_velocity(
+ self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
+ ) -> torch.FloatTensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as sample
+ alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
+ timesteps = timesteps.to(sample.device)
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(sample.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
+ return velocity
+
+ def __len__(self):
+ return self.config.num_train_timesteps
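+
+
+# A minimal usage sketch for this community pipeline, guarded so importing the module has no side
+# effects. The checkpoint name and the `custom_pipeline` loading path are assumptions; any
+# LCM-distilled checkpoint whose UNet is conditioned via `time_cond_proj_dim` should behave similarly.
+if __name__ == "__main__":
+ from diffusers import DiffusionPipeline
+
+ pipe = DiffusionPipeline.from_pretrained(
+ "SimianLuo/LCM_Dreamshaper_v7", custom_pipeline="latent_consistency_txt2img"
+ )
+ pipe.to("cuda")
+ # LCM needs only a handful of steps; guidance is distilled into the model via the w embedding.
+ images = pipe(
+ prompt="a photo of an astronaut riding a horse on mars",
+ num_inference_steps=4,
+ guidance_scale=8.0,
+ lcm_origin_steps=50,
+ ).images
+ images[0].save("lcm_txt2img.png")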
diff --git a/diffusers/examples/community/llm_grounded_diffusion.py b/diffusers/examples/community/llm_grounded_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..d47c99bb2990ebf34b68d769a2975da643dc96d2
--- /dev/null
+++ b/diffusers/examples/community/llm_grounded_diffusion.py
@@ -0,0 +1,1015 @@
+# Copyright 2023 Long Lian, the GLIGEN Authors, and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This is a single file implementation of LMD+. See README.md for examples.
+
+import ast
+import gc
+import math
+import warnings
+from collections.abc import Iterable
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import torch
+import torch.nn.functional as F
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.models.attention import Attention, GatedSelfAttentionDense
+from diffusers.models.attention_processor import AttnProcessor2_0
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
+from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import logging, replace_example_docstring
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import DiffusionPipeline
+
+ >>> pipe = DiffusionPipeline.from_pretrained(
+ ... "longlian/lmd_plus",
+ ... custom_pipeline="llm_grounded_diffusion",
+ ... variant="fp16", torch_dtype=torch.float16
+ ... )
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> # Generate an image described by the prompt and
+ >>> # insert objects described by text at the region defined by bounding boxes
+ >>> prompt = "a waterfall and a modern high speed train in a beautiful forest with fall foliage"
+ >>> boxes = [[0.1387, 0.2051, 0.4277, 0.7090], [0.4980, 0.4355, 0.8516, 0.7266]]
+ >>> phrases = ["a waterfall", "a modern high speed train"]
+
+ >>> images = pipe(
+ ... prompt=prompt,
+ ... phrases=phrases,
+ ... boxes=boxes,
+ ... gligen_scheduled_sampling_beta=0.4,
+ ... output_type="pil",
+ ... num_inference_steps=50,
+ ... lmd_guidance_kwargs={}
+ ... ).images
+
+ >>> images[0].save("./lmd_plus_generation.jpg")
+
+ >>> # Generate directly from a text prompt and an LLM response
+ >>> prompt = "a waterfall and a modern high speed train in a beautiful forest with fall foliage"
+ >>> phrases, boxes, bg_prompt, neg_prompt = pipe.parse_llm_response(\"""
+ [('a waterfall', [71, 105, 148, 258]), ('a modern high speed train', [255, 223, 181, 149])]
+ Background prompt: A beautiful forest with fall foliage
+ Negative prompt:
+ \""")
+
+ >>> images = pipe(
+ ... prompt=prompt,
+ ... negative_prompt=neg_prompt,
+ ... phrases=phrases,
+ ... boxes=boxes,
+ ... gligen_scheduled_sampling_beta=0.4,
+ ... output_type="pil",
+ ... num_inference_steps=50,
+ ... lmd_guidance_kwargs={}
+ ... ).images
+
+ >>> images[0].save("./lmd_plus_generation.jpg")
+
+ ```
+"""
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+# All keys in Stable Diffusion models: [('down', 0, 0, 0), ('down', 0, 1, 0), ('down', 1, 0, 0), ('down', 1, 1, 0), ('down', 2, 0, 0), ('down', 2, 1, 0), ('mid', 0, 0, 0), ('up', 1, 0, 0), ('up', 1, 1, 0), ('up', 1, 2, 0), ('up', 2, 0, 0), ('up', 2, 1, 0), ('up', 2, 2, 0), ('up', 3, 0, 0), ('up', 3, 1, 0), ('up', 3, 2, 0)]
+# Note that the first up block is `UpBlock2D` rather than `CrossAttnUpBlock2D` and does not have attention. The last index is always 0 in our case since we have one `BasicTransformerBlock` in each `Transformer2DModel`.
+DEFAULT_GUIDANCE_ATTN_KEYS = [("mid", 0, 0, 0), ("up", 1, 0, 0), ("up", 1, 1, 0), ("up", 1, 2, 0)]
+
+
+def convert_attn_keys(key):
+ """Convert the attention key from tuple format to the torch state format"""
+
+ if key[0] == "mid":
+ assert key[1] == 0, f"mid block only has one block but the index is {key[1]}"
+ return f"{key[0]}_block.attentions.{key[2]}.transformer_blocks.{key[3]}.attn2.processor"
+
+ return f"{key[0]}_blocks.{key[1]}.attentions.{key[2]}.transformer_blocks.{key[3]}.attn2.processor"
+
+
+DEFAULT_GUIDANCE_ATTN_KEYS = [convert_attn_keys(key) for key in DEFAULT_GUIDANCE_ATTN_KEYS]
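+# After conversion the guidance keys are cross-attention ("attn2") processor names in the UNet, e.g.
+# "mid_block.attentions.0.transformer_blocks.0.attn2.processor" and
+# "up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor".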
+
+
+def scale_proportion(obj_box, H, W):
+ # Separately rounding box_w and box_h to allow shift invariant box sizes. Otherwise box sizes may change when both coordinates being rounded end with ".5".
+ x_min, y_min = round(obj_box[0] * W), round(obj_box[1] * H)
+ box_w, box_h = round((obj_box[2] - obj_box[0]) * W), round((obj_box[3] - obj_box[1]) * H)
+ x_max, y_max = x_min + box_w, y_min + box_h
+
+ x_min, y_min = max(x_min, 0), max(y_min, 0)
+ x_max, y_max = min(x_max, W), min(y_max, H)
+
+ return x_min, y_min, x_max, y_max
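+
+# For example, scale_proportion((0.25, 0.25, 0.75, 0.75), H=64, W=64) returns (16, 16, 48, 48): the
+# box width and height are rounded separately so box sizes stay shift-invariant, then clamped to the canvas.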
+
+
+# Adapted from the parent class `AttnProcessor2_0`
+class AttnProcessorWithHook(AttnProcessor2_0):
+ def __init__(self, attn_processor_key, hidden_size, cross_attention_dim, hook=None, fast_attn=True, enabled=True):
+ super().__init__()
+ self.attn_processor_key = attn_processor_key
+ self.hidden_size = hidden_size
+ self.cross_attention_dim = cross_attention_dim
+ self.hook = hook
+ self.fast_attn = fast_attn
+ self.enabled = enabled
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ temb=None,
+ scale: float = 1.0,
+ ):
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states, scale=scale)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states, scale=scale)
+ value = attn.to_v(encoder_hidden_states, scale=scale)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ if (self.hook is not None and self.enabled) or not self.fast_attn:
+ query_batch_dim = attn.head_to_batch_dim(query)
+ key_batch_dim = attn.head_to_batch_dim(key)
+ value_batch_dim = attn.head_to_batch_dim(value)
+ attention_probs = attn.get_attention_scores(query_batch_dim, key_batch_dim, attention_mask)
+
+ if self.hook is not None and self.enabled:
+ # Call the hook with query, key, value, and attention maps
+ self.hook(self.attn_processor_key, query_batch_dim, key_batch_dim, value_batch_dim, attention_probs)
+
+ if self.fast_attn:
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attention_mask is not None:
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+ else:
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states, scale=scale)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
+ r"""
+ Pipeline for layout-grounded text-to-image generation using LLM-grounded Diffusion (LMD+): https://arxiv.org/pdf/2305.13655.pdf.
+
+ This model inherits from [`StableDiffusionPipeline`] and aims at implementing the pipeline with minimal modifications. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ This is a simplified implementation that does not perform latent or attention transfer from single object generation to overall generation. The final image is generated directly with attention and adapters control.
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+ text_encoder ([`~transformers.CLIPTextModel`]):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ tokenizer ([`~transformers.CLIPTokenizer`]):
+ A `CLIPTokenizer` to tokenize text.
+ unet ([`UNet2DConditionModel`]):
+ A `UNet2DConditionModel` to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ about a model's potential harms.
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
+ requires_safety_checker (bool):
+ Whether a safety checker is needed for this pipeline.
+ """
+
+ objects_text = "Objects: "
+ bg_prompt_text = "Background prompt: "
+ bg_prompt_text_no_trailing_space = bg_prompt_text.rstrip()
+ neg_prompt_text = "Negative prompt: "
+ neg_prompt_text_no_trailing_space = neg_prompt_text.rstrip()
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__(
+ vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
+ )
+
+ self.register_attn_hooks(unet)
+ self._saved_attn = None
+
+ def attn_hook(self, name, query, key, value, attention_probs):
+ if name in DEFAULT_GUIDANCE_ATTN_KEYS:
+ self._saved_attn[name] = attention_probs
+
+ @classmethod
+ def convert_box(cls, box, height, width):
+ # box: x, y, w, h (in 512 format) -> x_min, y_min, x_max, y_max
+ x_min, y_min = box[0] / width, box[1] / height
+ w_box, h_box = box[2] / width, box[3] / height
+
+ x_max, y_max = x_min + w_box, y_min + h_box
+
+ return x_min, y_min, x_max, y_max
+
+ @classmethod
+ def _parse_response_with_negative(cls, text):
+ if not text:
+ raise ValueError("LLM response is empty")
+
+ if cls.objects_text in text:
+ text = text.split(cls.objects_text)[1]
+
+ text_split = text.split(cls.bg_prompt_text_no_trailing_space)
+ if len(text_split) == 2:
+ gen_boxes, text_rem = text_split
+ else:
+ raise ValueError(f"LLM response is incomplete: {text}")
+
+ text_split = text_rem.split(cls.neg_prompt_text_no_trailing_space)
+
+ if len(text_split) == 2:
+ bg_prompt, neg_prompt = text_split
+ else:
+ raise ValueError(f"LLM response is incomplete: {text}")
+
+ try:
+ gen_boxes = ast.literal_eval(gen_boxes)
+ except SyntaxError as e:
+ # Sometimes the response is in plain text
+ if "No objects" in gen_boxes or gen_boxes.strip() == "":
+ gen_boxes = []
+ else:
+ raise e
+ bg_prompt = bg_prompt.strip()
+ neg_prompt = neg_prompt.strip()
+
+ # LLM may return "None" to mean no negative prompt provided.
+ if neg_prompt == "None":
+ neg_prompt = ""
+
+ return gen_boxes, bg_prompt, neg_prompt
+
+ @classmethod
+ def parse_llm_response(cls, response, canvas_height=512, canvas_width=512):
+ # Infer from spec
+ gen_boxes, bg_prompt, neg_prompt = cls._parse_response_with_negative(text=response)
+
+ gen_boxes = sorted(gen_boxes, key=lambda gen_box: gen_box[0])
+
+ phrases = [name for name, _ in gen_boxes]
+ boxes = [cls.convert_box(box, height=canvas_height, width=canvas_width) for _, box in gen_boxes]
+
+ return phrases, boxes, bg_prompt, neg_prompt
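+
+ # For example, feeding the LLM response shown in EXAMPLE_DOC_STRING through parse_llm_response
+ # with the default 512x512 canvas yields phrases sorted alphabetically,
+ # ["a modern high speed train", "a waterfall"], normalized boxes such as
+ # (71, 105, 148, 258) -> (~0.1387, ~0.2051, ~0.4277, ~0.7090),
+ # bg_prompt "A beautiful forest with fall foliage" and an empty negative prompt.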
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_steps,
+ phrases,
+ boxes,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ phrase_indices=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt is None and phrase_indices is None:
+ raise ValueError("If the prompt is None, the phrase_indices cannot be None")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if len(phrases) != len(boxes):
+ raise ValueError(
+ "length of `phrases` and `boxes` has to be same, but"
+ f" got: `phrases` {len(phrases)} != `boxes` {len(boxes)}"
+ )
+
+ def register_attn_hooks(self, unet):
+ """Registering hooks to obtain the attention maps for guidance"""
+
+ attn_procs = {}
+
+ for name in unet.attn_processors.keys():
+ # Only obtain the queries and keys from cross-attention
+ if name.endswith("attn1.processor") or name.endswith("fuser.attn.processor"):
+ # Keep the same attn_processors for self-attention (no hooks for self-attention)
+ attn_procs[name] = unet.attn_processors[name]
+ continue
+
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+
+ if name.startswith("mid_block"):
+ hidden_size = unet.config.block_out_channels[-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = unet.config.block_out_channels[block_id]
+
+ attn_procs[name] = AttnProcessorWithHook(
+ attn_processor_key=name,
+ hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim,
+ hook=self.attn_hook,
+ fast_attn=True,
+ # Not enabled by default
+ enabled=False,
+ )
+
+ unet.set_attn_processor(attn_procs)
+
+ def enable_fuser(self, enabled=True):
+ for module in self.unet.modules():
+ if isinstance(module, GatedSelfAttentionDense):
+ module.enabled = enabled
+
+ def enable_attn_hook(self, enabled=True):
+ for module in self.unet.attn_processors.values():
+ if isinstance(module, AttnProcessorWithHook):
+ module.enabled = enabled
+
+ def get_token_map(self, prompt, padding="do_not_pad", verbose=False):
+ """Get a list of mapping: prompt index to str (prompt in a list of token str)"""
+ fg_prompt_tokens = self.tokenizer([prompt], padding=padding, max_length=77, return_tensors="np")
+ input_ids = fg_prompt_tokens["input_ids"][0]
+
+ token_map = []
+ for ind, item in enumerate(input_ids.tolist()):
+ token = self.tokenizer._convert_id_to_token(item)
+
+ if verbose:
+ logger.info(f"{ind}, {token} ({item})")
+
+ token_map.append(token)
+
+ return token_map
+
+ def get_phrase_indices(self, prompt, phrases, token_map=None, add_suffix_if_not_found=False, verbose=False):
+ for obj in phrases:
+ # Suffix the prompt with object name for attention guidance if object is not in the prompt, using "|" to separate the prompt and the suffix
+ if obj not in prompt:
+ prompt += "| " + obj
+
+ if token_map is None:
+ # We allow using a pre-computed token map.
+ token_map = self.get_token_map(prompt=prompt, padding="do_not_pad", verbose=verbose)
+ token_map_str = " ".join(token_map)
+
+ phrase_indices = []
+
+ for obj in phrases:
+ phrase_token_map = self.get_token_map(prompt=obj, padding="do_not_pad", verbose=verbose)
+ # Remove the BOS and EOS tokens from the phrase token map
+ phrase_token_map = phrase_token_map[1:-1]
+ phrase_token_map_len = len(phrase_token_map)
+ phrase_token_map_str = " ".join(phrase_token_map)
+
+ if verbose:
+ logger.info("Full str:", token_map_str, "Substr:", phrase_token_map_str, "Phrase:", phrases)
+
+ # Count the number of tokens before the substring.
+ # The prefix before the substring ends with a separator space, which is dropped via the -1 in the index.
+ obj_first_index = len(token_map_str[: token_map_str.index(phrase_token_map_str) - 1].split(" "))
+
+ obj_position = list(range(obj_first_index, obj_first_index + phrase_token_map_len))
+ phrase_indices.append(obj_position)
+
+ if add_suffix_if_not_found:
+ return phrase_indices, prompt
+
+ return phrase_indices
+
+ def add_ca_loss_per_attn_map_to_loss(
+ self,
+ loss,
+ attn_map,
+ object_number,
+ bboxes,
+ phrase_indices,
+ fg_top_p=0.2,
+ bg_top_p=0.2,
+ fg_weight=1.0,
+ bg_weight=1.0,
+ ):
+ # b is the number of heads, not batch
+ b, i, j = attn_map.shape
+ H = W = int(math.sqrt(i))
+ for obj_idx in range(object_number):
+ obj_loss = 0
+ mask = torch.zeros(size=(H, W), device="cuda")
+ obj_boxes = bboxes[obj_idx]
+
+ # We support two-level nesting (one box per phrase) and three-level nesting (multiple boxes per phrase)
+ if not isinstance(obj_boxes[0], Iterable):
+ obj_boxes = [obj_boxes]
+
+ for obj_box in obj_boxes:
+ # x_min, y_min, x_max, y_max = int(obj_box[0] * W), int(obj_box[1] * H), int(obj_box[2] * W), int(obj_box[3] * H)
+ x_min, y_min, x_max, y_max = scale_proportion(obj_box, H=H, W=W)
+ mask[y_min:y_max, x_min:x_max] = 1
+
+ for obj_position in phrase_indices[obj_idx]:
+ # Could potentially optimize to compute this for loop in batch.
+ # Could crop the ref cross attention before saving to save memory.
+
+ # shape: (b, H * W)
+ ca_map_obj = attn_map[:, :, obj_position]
+ k_fg = (mask.sum() * fg_top_p).long().clamp_(min=1)
+ k_bg = ((1 - mask).sum() * bg_top_p).long().clamp_(min=1)
+
+ mask_1d = mask.view(1, -1)
+
+ # Max-based loss function
+
+ # Take the topk over spatial dimension, and then take the sum over heads dim
+ # The mean is over k_fg and k_bg dimension, so we don't need to sum and divide on our own.
+ obj_loss += (1 - (ca_map_obj * mask_1d).topk(k=k_fg).values.mean(dim=1)).sum(dim=0) * fg_weight
+ obj_loss += ((ca_map_obj * (1 - mask_1d)).topk(k=k_bg).values.mean(dim=1)).sum(dim=0) * bg_weight
+
+ loss += obj_loss / len(phrase_indices[obj_idx])
+
+ return loss
+
+ def compute_ca_loss(self, saved_attn, bboxes, phrase_indices, guidance_attn_keys, verbose=False, **kwargs):
+ """
+ The `saved_attn` is supposed to be passed to `save_attn_to_dict` in `cross_attention_kwargs` prior to computing this loss.
+ `AttnProcessor` will put attention maps into the `save_attn_to_dict`.
+
+ `index` is the timestep.
+ `ref_ca_word_token_only`: This has precedence over `ref_ca_last_token_only` (i.e., if both are enabled, we take the token from word rather than the last token).
+ `ref_ca_last_token_only`: `ref_ca_saved_attn` comes from the attention map of the last token of the phrase in single object generation, so we apply it only to the last token of the phrase in overall generation if this is set to True. If set to False, `ref_ca_saved_attn` will be applied to all the text tokens.
+ """
+ loss = torch.tensor(0).float().cuda()
+ object_number = len(bboxes)
+ if object_number == 0:
+ return loss
+
+ for attn_key in guidance_attn_keys:
+ # We only have 1 cross attention for mid.
+
+ attn_map_integrated = saved_attn[attn_key]
+ if not attn_map_integrated.is_cuda:
+ attn_map_integrated = attn_map_integrated.cuda()
+ # Example dimension: [20, 64, 77]
+ attn_map = attn_map_integrated.squeeze(dim=0)
+
+ loss = self.add_ca_loss_per_attn_map_to_loss(
+ loss, attn_map, object_number, bboxes, phrase_indices, **kwargs
+ )
+
+ num_attn = len(guidance_attn_keys)
+
+ if num_attn > 0:
+ loss = loss / (object_number * num_attn)
+
+ return loss
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ gligen_scheduled_sampling_beta: float = 0.3,
+ phrases: List[str] = None,
+ boxes: List[List[float]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ clip_skip: Optional[int] = None,
+ lmd_guidance_kwargs: Optional[Dict[str, Any]] = {},
+ phrase_indices: Optional[List[int]] = None,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ phrases (`List[str]`):
+ The phrases to guide what to include in each of the regions defined by the corresponding
+ `boxes`. There should only be one phrase per bounding box.
+ boxes (`List[List[float]]`):
+ The bounding boxes that identify rectangular regions of the image that are going to be filled with the
+ content described by the corresponding `phrases`. Each rectangular box is defined as a
+ `List[float]` of 4 elements `[xmin, ymin, xmax, ymax]` where each value is between [0,1].
+ gligen_scheduled_sampling_beta (`float`, defaults to 0.3):
+ Scheduled Sampling factor from [GLIGEN: Open-Set Grounded Text-to-Image
+ Generation](https://arxiv.org/pdf/2301.07093.pdf). Scheduled Sampling factor is only varied for
+ scheduled sampling during inference for improved quality and controllability.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that calls every `callback_steps` steps during inference. The function is called with the
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
+ every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ lmd_guidance_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to `latent_lmd_guidance` function. Useful keys include `loss_scale` (the guidance strength), `loss_threshold` (when loss is lower than this value, the guidance is not applied anymore), `max_iter` (the number of iterations of guidance for each step), and `guidance_timesteps` (the number of diffusion timesteps to apply guidance on). See `latent_lmd_guidance` for implementation details.
+ phrase_indices (`list` of `list`, *optional*): The indices of the tokens of each phrase in the overall prompt. If omitted, the pipeline will match the first token subsequence. The pipeline will append the missing phrases to the end of the prompt by default.
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
+ "not-safe-for-work" (nsfw) content.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ callback_steps,
+ phrases,
+ boxes,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ phrase_indices,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ if phrase_indices is None:
+ phrase_indices, prompt = self.get_phrase_indices(prompt, phrases, add_suffix_if_not_found=True)
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ if phrase_indices is None:
+ phrase_indices = []
+ prompt_parsed = []
+ for prompt_item in prompt:
+ phrase_indices_parsed_item, prompt_parsed_item = self.get_phrase_indices(
+ prompt_item, add_suffix_if_not_found=True
+ )
+ phrase_indices.append(phrase_indices_parsed_item)
+ prompt_parsed.append(prompt_parsed_item)
+ prompt = prompt_parsed
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ clip_skip=clip_skip,
+ )
+
+ cond_prompt_embeds = prompt_embeds
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 5.1 Prepare GLIGEN variables
+ max_objs = 30
+ if len(boxes) > max_objs:
+ warnings.warn(
+ f"More that {max_objs} objects found. Only first {max_objs} objects will be processed.",
+ FutureWarning,
+ )
+ phrases = phrases[:max_objs]
+ boxes = boxes[:max_objs]
+
+ n_objs = len(boxes)
+ if n_objs:
+ # prepare batched input to the PositionNet (boxes, phrases, mask)
+ # Get tokens for phrases from pre-trained CLIPTokenizer
+ tokenizer_inputs = self.tokenizer(phrases, padding=True, return_tensors="pt").to(device)
+            # The phrase tokens are passed through the same pre-trained text encoder
+            # to obtain their pooled text features
+ _text_embeddings = self.text_encoder(**tokenizer_inputs).pooler_output
+
+        # Each entity described in `phrases` is denoted by a bounding box; we represent
+        # the location information as (xmin, ymin, xmax, ymax)
+ cond_boxes = torch.zeros(max_objs, 4, device=device, dtype=self.text_encoder.dtype)
+ if n_objs:
+ cond_boxes[:n_objs] = torch.tensor(boxes)
+ text_embeddings = torch.zeros(
+ max_objs, self.unet.config.cross_attention_dim, device=device, dtype=self.text_encoder.dtype
+ )
+ if n_objs:
+ text_embeddings[:n_objs] = _text_embeddings
+        # Generate a mask for each object, i.e. each entity described by `phrases`
+ masks = torch.zeros(max_objs, device=device, dtype=self.text_encoder.dtype)
+ masks[:n_objs] = 1
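+        # At this point (before batching) the grounding tensors are zero-padded up to max_objs:
+        #   cond_boxes: (max_objs, 4), text_embeddings: (max_objs, cross_attention_dim), masks: (max_objs,)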
+
+ repeat_batch = batch_size * num_images_per_prompt
+ cond_boxes = cond_boxes.unsqueeze(0).expand(repeat_batch, -1, -1).clone()
+ text_embeddings = text_embeddings.unsqueeze(0).expand(repeat_batch, -1, -1).clone()
+ masks = masks.unsqueeze(0).expand(repeat_batch, -1).clone()
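+        # For classifier-free guidance the grounding inputs are duplicated and the masks of the unconditional
+        # half of the batch are zeroed, so grounding only affects the conditional branch.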
+ if do_classifier_free_guidance:
+ repeat_batch = repeat_batch * 2
+ cond_boxes = torch.cat([cond_boxes] * 2)
+ text_embeddings = torch.cat([text_embeddings] * 2)
+ masks = torch.cat([masks] * 2)
+ masks[: repeat_batch // 2] = 0
+ if cross_attention_kwargs is None:
+ cross_attention_kwargs = {}
+ cross_attention_kwargs["gligen"] = {
+ "boxes": cond_boxes,
+ "positive_embeddings": text_embeddings,
+ "masks": masks,
+ }
+
+ num_grounding_steps = int(gligen_scheduled_sampling_beta * len(timesteps))
+ self.enable_fuser(True)
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ loss_attn = torch.tensor(10000.0)
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # Scheduled sampling
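+                # (after the first `num_grounding_steps` steps, the GLIGEN fuser is disabled and the remaining
+                # steps run as standard denoising)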
+ if i == num_grounding_steps:
+ self.enable_fuser(False)
+
+ if latents.shape[1] != 4:
+ latents = torch.randn_like(latents[:, :4])
+
+ # 7.1 Perform LMD guidance
+ if boxes:
+ latents, loss_attn = self.latent_lmd_guidance(
+ cond_prompt_embeds,
+ index=i,
+ boxes=boxes,
+ phrase_indices=phrase_indices,
+ t=t,
+ latents=latents,
+ loss=loss_attn,
+ **lmd_guidance_kwargs,
+ )
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ ).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+ @torch.set_grad_enabled(True)
+ def latent_lmd_guidance(
+ self,
+ cond_embeddings,
+ index,
+ boxes,
+ phrase_indices,
+ t,
+ latents,
+ loss,
+ *,
+ loss_scale=20,
+ loss_threshold=5.0,
+ max_iter=[3] * 5 + [2] * 5 + [1] * 5,
+ guidance_timesteps=15,
+ cross_attention_kwargs=None,
+ guidance_attn_keys=DEFAULT_GUIDANCE_ATTN_KEYS,
+ verbose=False,
+ clear_cache=False,
+ unet_additional_kwargs={},
+ guidance_callback=None,
+ **kwargs,
+ ):
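+        """
+        Iteratively update `latents` at denoising step `index` by gradient descent on the cross-attention loss
+        from `compute_ca_loss`, with the gradient scaled in the spirit of classifier guidance. Guidance only runs
+        while `index < guidance_timesteps`, for at most `max_iter` iterations per step (`max_iter` may be a
+        per-step list), and stops early once `loss / loss_scale` drops below `loss_threshold`.
+        """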
+ scheduler, unet = self.scheduler, self.unet
+
+ iteration = 0
+
+ if index < guidance_timesteps:
+ if isinstance(max_iter, list):
+ max_iter = max_iter[index]
+
+ if verbose:
+ logger.info(
+ f"time index {index}, loss: {loss.item()/loss_scale:.3f} (de-scaled with scale {loss_scale:.1f}), loss threshold: {loss_threshold:.3f}"
+ )
+
+ try:
+ self.enable_attn_hook(enabled=True)
+
+ while (
+ loss.item() / loss_scale > loss_threshold and iteration < max_iter and index < guidance_timesteps
+ ):
+ self._saved_attn = {}
+
+ latents.requires_grad_(True)
+ latent_model_input = latents
+ latent_model_input = scheduler.scale_model_input(latent_model_input, t)
+
+ unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=cond_embeddings,
+ cross_attention_kwargs=cross_attention_kwargs,
+ **unet_additional_kwargs,
+ )
+
+ # update latents with guidance
+ loss = (
+ self.compute_ca_loss(
+ saved_attn=self._saved_attn,
+ bboxes=boxes,
+ phrase_indices=phrase_indices,
+ guidance_attn_keys=guidance_attn_keys,
+ verbose=verbose,
+ **kwargs,
+ )
+ * loss_scale
+ )
+
+ if torch.isnan(loss):
+ raise RuntimeError("**Loss is NaN**")
+
+ # This callback allows visualizations.
+ if guidance_callback is not None:
+ guidance_callback(self, latents, loss, iteration, index)
+
+ self._saved_attn = None
+
+ grad_cond = torch.autograd.grad(loss.requires_grad_(True), [latents])[0]
+
+ latents.requires_grad_(False)
+
+ # Scaling with classifier guidance
+ alpha_prod_t = scheduler.alphas_cumprod[t]
+ # Classifier guidance: https://arxiv.org/pdf/2105.05233.pdf
+ # DDIM: https://arxiv.org/pdf/2010.02502.pdf
+ scale = (1 - alpha_prod_t) ** (0.5)
+ latents = latents - scale * grad_cond
+
+ iteration += 1
+
+ if clear_cache:
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ if verbose:
+ logger.info(
+ f"time index {index}, loss: {loss.item()/loss_scale:.3f}, loss threshold: {loss_threshold:.3f}, iteration: {iteration}"
+ )
+
+ finally:
+ self.enable_attn_hook(enabled=False)
+
+ return latents, loss
diff --git a/diffusers/examples/community/lpw_stable_diffusion.py b/diffusers/examples/community/lpw_stable_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..7249e033186f4bba2e42e324cd7628e16e5762e6
--- /dev/null
+++ b/diffusers/examples/community/lpw_stable_diffusion.py
@@ -0,0 +1,1471 @@
+import inspect
+import re
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from packaging import version
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers import DiffusionPipeline
+from diffusers.configuration_utils import FrozenDict
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ PIL_INTERPOLATION,
+ deprecate,
+ is_accelerate_available,
+ is_accelerate_version,
+ logging,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+# ------------------------------------------------------------------------------
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+re_attention = re.compile(
+ r"""
+\\\(|
+\\\)|
+\\\[|
+\\]|
+\\\\|
+\\|
+\(|
+\[|
+:([+-]?[.\d]+)\)|
+\)|
+]|
+[^\\()\[\]:]+|
+:
+""",
+ re.X,
+)
+
+
+def parse_prompt_attention(text):
+ """
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
+ Accepted tokens are:
+ (abc) - increases attention to abc by a multiplier of 1.1
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
+ [abc] - decreases attention to abc by a multiplier of 1.1
+ \\( - literal character '('
+ \\[ - literal character '['
+ \\) - literal character ')'
+ \\] - literal character ']'
+ \\ - literal character '\'
+ anything else - just text
+ >>> parse_prompt_attention('normal text')
+ [['normal text', 1.0]]
+ >>> parse_prompt_attention('an (important) word')
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
+ >>> parse_prompt_attention('(unbalanced')
+ [['unbalanced', 1.1]]
+ >>> parse_prompt_attention('\\(literal\\]')
+ [['(literal]', 1.0]]
+ >>> parse_prompt_attention('(unnecessary)(parens)')
+ [['unnecessaryparens', 1.1]]
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
+ [['a ', 1.0],
+ ['house', 1.5730000000000004],
+ [' ', 1.1],
+ ['on', 1.0],
+ [' a ', 1.1],
+ ['hill', 0.55],
+ [', sun, ', 1.1],
+ ['sky', 1.4641000000000006],
+ ['.', 1.1]]
+ """
+
+ res = []
+ round_brackets = []
+ square_brackets = []
+
+ round_bracket_multiplier = 1.1
+ square_bracket_multiplier = 1 / 1.1
+
+ def multiply_range(start_position, multiplier):
+ for p in range(start_position, len(res)):
+ res[p][1] *= multiplier
+
+ for m in re_attention.finditer(text):
+ text = m.group(0)
+ weight = m.group(1)
+
+ if text.startswith("\\"):
+ res.append([text[1:], 1.0])
+ elif text == "(":
+ round_brackets.append(len(res))
+ elif text == "[":
+ square_brackets.append(len(res))
+ elif weight is not None and len(round_brackets) > 0:
+ multiply_range(round_brackets.pop(), float(weight))
+ elif text == ")" and len(round_brackets) > 0:
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
+ elif text == "]" and len(square_brackets) > 0:
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
+ else:
+ res.append([text, 1.0])
+
+ for pos in round_brackets:
+ multiply_range(pos, round_bracket_multiplier)
+
+ for pos in square_brackets:
+ multiply_range(pos, square_bracket_multiplier)
+
+ if len(res) == 0:
+ res = [["", 1.0]]
+
+ # merge runs of identical weights
+ i = 0
+ while i + 1 < len(res):
+ if res[i][1] == res[i + 1][1]:
+ res[i][0] += res[i + 1][0]
+ res.pop(i + 1)
+ else:
+ i += 1
+
+ return res
+
+
+def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_length: int):
+ r"""
+ Tokenize a list of prompts and return its tokens with weights of each token.
+
+ No padding, starting or ending token is included.
+ """
+ tokens = []
+ weights = []
+ truncated = False
+ for text in prompt:
+ texts_and_weights = parse_prompt_attention(text)
+ text_token = []
+ text_weight = []
+ for word, weight in texts_and_weights:
+ # tokenize and discard the starting and the ending token
+ token = pipe.tokenizer(word).input_ids[1:-1]
+ text_token += token
+ # copy the weight by length of token
+ text_weight += [weight] * len(token)
+ # stop if the text is too long (longer than truncation limit)
+ if len(text_token) > max_length:
+ truncated = True
+ break
+ # truncate
+ if len(text_token) > max_length:
+ truncated = True
+ text_token = text_token[:max_length]
+ text_weight = text_weight[:max_length]
+ tokens.append(text_token)
+ weights.append(text_weight)
+ if truncated:
+ logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
+ return tokens, weights
+
+
+def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77):
+ r"""
+ Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
+ """
+ max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
+ weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
+ for i in range(len(tokens)):
+ tokens[i] = [bos] + tokens[i] + [pad] * (max_length - 1 - len(tokens[i]) - 1) + [eos]
+ if no_boseos_middle:
+ weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
+ else:
+ w = []
+ if len(weights[i]) == 0:
+ w = [1.0] * weights_length
+ else:
+ for j in range(max_embeddings_multiples):
+ w.append(1.0) # weight for starting token in this chunk
+ w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
+ w.append(1.0) # weight for ending token in this chunk
+ w += [1.0] * (weights_length - len(w))
+ weights[i] = w[:]
+
+ return tokens, weights
+
+
+def get_unweighted_text_embeddings(
+ pipe: DiffusionPipeline,
+ text_input: torch.Tensor,
+ chunk_length: int,
+ no_boseos_middle: Optional[bool] = True,
+):
+ """
+    When the length of the tokens is a multiple of the capacity of the text encoder, the input is split into
+    chunks that are sent to the text encoder individually.
+ """
+ max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
+ if max_embeddings_multiples > 1:
+ text_embeddings = []
+ for i in range(max_embeddings_multiples):
+ # extract the i-th chunk
+ text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
+
+ # cover the head and the tail by the starting and the ending tokens
+ text_input_chunk[:, 0] = text_input[0, 0]
+ text_input_chunk[:, -1] = text_input[0, -1]
+ text_embedding = pipe.text_encoder(text_input_chunk)[0]
+
+ if no_boseos_middle:
+ if i == 0:
+ # discard the ending token
+ text_embedding = text_embedding[:, :-1]
+ elif i == max_embeddings_multiples - 1:
+ # discard the starting token
+ text_embedding = text_embedding[:, 1:]
+ else:
+ # discard both starting and ending tokens
+ text_embedding = text_embedding[:, 1:-1]
+
+ text_embeddings.append(text_embedding)
+ text_embeddings = torch.concat(text_embeddings, axis=1)
+ else:
+ text_embeddings = pipe.text_encoder(text_input)[0]
+ return text_embeddings
+
+
+def get_weighted_text_embeddings(
+ pipe: DiffusionPipeline,
+ prompt: Union[str, List[str]],
+ uncond_prompt: Optional[Union[str, List[str]]] = None,
+ max_embeddings_multiples: Optional[int] = 3,
+ no_boseos_middle: Optional[bool] = False,
+ skip_parsing: Optional[bool] = False,
+ skip_weighting: Optional[bool] = False,
+):
+ r"""
+    Prompts can be assigned local weights using brackets. For example, the prompt
+    'A (very beautiful) masterpiece' highlights the words 'very beautiful',
+    and the embedding tokens corresponding to those words are multiplied by a constant, 1.1.
+
+    Also, to regularize the embedding, the weighted embedding is rescaled to preserve the original mean.
+
+ Args:
+ pipe (`DiffusionPipeline`):
+ Pipe to provide access to the tokenizer and the text encoder.
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ uncond_prompt (`str` or `List[str]`):
+            The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
+            is provided, the embeddings of prompt and uncond_prompt are concatenated.
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
+ no_boseos_middle (`bool`, *optional*, defaults to `False`):
+            If the length of the text tokens is a multiple of the capacity of the text encoder, whether to keep the
+            starting and ending tokens in each of the middle chunks.
+ skip_parsing (`bool`, *optional*, defaults to `False`):
+ Skip the parsing of brackets.
+ skip_weighting (`bool`, *optional*, defaults to `False`):
+ Skip the weighting. When the parsing is skipped, it is forced True.
+ """
+ max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
+ if isinstance(prompt, str):
+ prompt = [prompt]
+
+ if not skip_parsing:
+ prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
+ if uncond_prompt is not None:
+ if isinstance(uncond_prompt, str):
+ uncond_prompt = [uncond_prompt]
+ uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
+ else:
+ prompt_tokens = [
+ token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids
+ ]
+ prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
+ if uncond_prompt is not None:
+ if isinstance(uncond_prompt, str):
+ uncond_prompt = [uncond_prompt]
+ uncond_tokens = [
+ token[1:-1]
+ for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
+ ]
+ uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
+
+ # round up the longest length of tokens to a multiple of (model_max_length - 2)
+ max_length = max([len(token) for token in prompt_tokens])
+ if uncond_prompt is not None:
+ max_length = max(max_length, max([len(token) for token in uncond_tokens]))
+
+ max_embeddings_multiples = min(
+ max_embeddings_multiples,
+ (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
+ )
+ max_embeddings_multiples = max(1, max_embeddings_multiples)
+ max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
+
+ # pad the length of tokens and weights
+ bos = pipe.tokenizer.bos_token_id
+ eos = pipe.tokenizer.eos_token_id
+ pad = getattr(pipe.tokenizer, "pad_token_id", eos)
+ prompt_tokens, prompt_weights = pad_tokens_and_weights(
+ prompt_tokens,
+ prompt_weights,
+ max_length,
+ bos,
+ eos,
+ pad,
+ no_boseos_middle=no_boseos_middle,
+ chunk_length=pipe.tokenizer.model_max_length,
+ )
+ prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
+ if uncond_prompt is not None:
+ uncond_tokens, uncond_weights = pad_tokens_and_weights(
+ uncond_tokens,
+ uncond_weights,
+ max_length,
+ bos,
+ eos,
+ pad,
+ no_boseos_middle=no_boseos_middle,
+ chunk_length=pipe.tokenizer.model_max_length,
+ )
+ uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
+
+ # get the embeddings
+ text_embeddings = get_unweighted_text_embeddings(
+ pipe,
+ prompt_tokens,
+ pipe.tokenizer.model_max_length,
+ no_boseos_middle=no_boseos_middle,
+ )
+ prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=text_embeddings.device)
+ if uncond_prompt is not None:
+ uncond_embeddings = get_unweighted_text_embeddings(
+ pipe,
+ uncond_tokens,
+ pipe.tokenizer.model_max_length,
+ no_boseos_middle=no_boseos_middle,
+ )
+ uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=uncond_embeddings.device)
+
+ # assign weights to the prompts and normalize in the sense of mean
+ # TODO: should we normalize by chunk or in a whole (current implementation)?
+ if (not skip_parsing) and (not skip_weighting):
+ previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
+ text_embeddings *= prompt_weights.unsqueeze(-1)
+ current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
+ text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
+ if uncond_prompt is not None:
+ previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
+ uncond_embeddings *= uncond_weights.unsqueeze(-1)
+ current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
+ uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
+
+ if uncond_prompt is not None:
+ return text_embeddings, uncond_embeddings
+ return text_embeddings, None
+
+
+def preprocess_image(image, batch_size):
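+    # Resize to a multiple of 8, convert to a float tensor scaled to [-1, 1], and tile to (batch_size, 3, H, W).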
+ w, h = image.size
+ w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
+ image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
+ image = np.array(image).astype(np.float32) / 255.0
+ image = np.vstack([image[None].transpose(0, 3, 1, 2)] * batch_size)
+ image = torch.from_numpy(image)
+ return 2.0 * image - 1.0
+
+
+def preprocess_mask(mask, batch_size, scale_factor=8):
+ if not isinstance(mask, torch.FloatTensor):
+ mask = mask.convert("L")
+ w, h = mask.size
+ w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
+ mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
+ mask = np.array(mask).astype(np.float32) / 255.0
+ mask = np.tile(mask, (4, 1, 1))
+ mask = np.vstack([mask[None]] * batch_size)
+ mask = 1 - mask # repaint white, keep black
+ mask = torch.from_numpy(mask)
+ return mask
+
+ else:
+ valid_mask_channel_sizes = [1, 3]
+ # if mask channel is fourth tensor dimension, permute dimensions to pytorch standard (B, C, H, W)
+ if mask.shape[3] in valid_mask_channel_sizes:
+ mask = mask.permute(0, 3, 1, 2)
+ elif mask.shape[1] not in valid_mask_channel_sizes:
+ raise ValueError(
+ f"Mask channel dimension of size in {valid_mask_channel_sizes} should be second or fourth dimension,"
+ f" but received mask of shape {tuple(mask.shape)}"
+ )
+ # (potentially) reduce mask channel dimension from 3 to 1 for broadcasting to latent shape
+ mask = mask.mean(dim=1, keepdim=True)
+ h, w = mask.shape[-2:]
+ h, w = (x - x % 8 for x in (h, w)) # resize to integer multiple of 8
+ mask = torch.nn.functional.interpolate(mask, (h // scale_factor, w // scale_factor))
+ return mask
+
+
+class StableDiffusionLongPromptWeightingPipeline(
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
+ r"""
+    Pipeline for text-to-image generation using Stable Diffusion without a token length limit, with support for
+    parsing weights in the prompt.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(
+ requires_safety_checker=requires_safety_checker,
+ )
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+ steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
+ def enable_sequential_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
+ Note that offloading happens on a submodule basis. Memory savings are higher than with
+ `enable_model_cpu_offload`, but performance is lower.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
+ from accelerate import cpu_offload
+ else:
+ raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
+ cpu_offload(cpu_offloaded_model, device)
+
+ if self.safety_checker is not None:
+ cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ if self.safety_checker is not None:
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
+ def _execution_device(self):
+ r"""
+ Returns the device on which the pipeline's models will be executed. After calling
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+ hooks.
+ """
+ if not hasattr(self.unet, "_hf_hook"):
+ return self.device
+ for module in self.unet.modules():
+ if (
+ hasattr(module, "_hf_hook")
+ and hasattr(module._hf_hook, "execution_device")
+ and module._hf_hook.execution_device is not None
+ ):
+ return torch.device(module._hf_hook.execution_device)
+ return self.device
+
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ max_embeddings_multiples=3,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+            prompt (`str` or `List[str]`):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
+ """
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if negative_prompt_embeds is None:
+ if negative_prompt is None:
+ negative_prompt = [""] * batch_size
+ elif isinstance(negative_prompt, str):
+ negative_prompt = [negative_prompt] * batch_size
+ if batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ if prompt_embeds is None or negative_prompt_embeds is None:
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, self.tokenizer)
+
+ prompt_embeds1, negative_prompt_embeds1 = get_weighted_text_embeddings(
+ pipe=self,
+ prompt=prompt,
+ uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
+ max_embeddings_multiples=max_embeddings_multiples,
+ )
+ if prompt_embeds is None:
+ prompt_embeds = prompt_embeds1
+ if negative_prompt_embeds is None:
+ negative_prompt_embeds = negative_prompt_embeds1
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ bs_embed, seq_len, _ = negative_prompt_embeds.shape
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ strength,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
+ if is_text2img:
+ return self.scheduler.timesteps.to(device), num_inference_steps
+ else:
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
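+            # e.g. num_inference_steps=50, strength=0.8 -> init_timestep=40, t_start=10: only the last 40 of the
+            # 50 scheduler timesteps are used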
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+ return timesteps, num_inference_steps - t_start
+
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ else:
+ has_nsfw_concept = None
+ return image, has_nsfw_concept
+
+ def decode_latents(self, latents):
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def prepare_latents(
+ self,
+ image,
+ timestep,
+ num_images_per_prompt,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
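+        # Text-to-image: sample pure noise scaled by the scheduler's init sigma. Image-to-image / inpainting:
+        # encode `image` with the VAE, add noise for `timestep`, and also return the original latents and the
+        # noise (kept for mask blending later in the inpainting path).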
+ if image is None:
+ batch_size = batch_size * num_images_per_prompt
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents, None, None
+ else:
+ image = image.to(device=self.device, dtype=dtype)
+ init_latent_dist = self.vae.encode(image).latent_dist
+ init_latents = init_latent_dist.sample(generator=generator)
+ init_latents = self.vae.config.scaling_factor * init_latents
+
+ # Expand init_latents for batch_size and num_images_per_prompt
+ init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
+ init_latents_orig = init_latents
+
+ # add noise to latents using the timesteps
+ noise = randn_tensor(init_latents.shape, generator=generator, device=self.device, dtype=dtype)
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+ latents = init_latents
+ return latents, init_latents_orig, noise
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ strength: float = 0.8,
+ num_images_per_prompt: Optional[int] = 1,
+ add_predicted_noise: Optional[bool] = False,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_embeddings_multiples: Optional[int] = 3,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process.
+ mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+ replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
+ PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
+ contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
+ `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
+ number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
+ noise will be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+            add_predicted_noise (`bool`, *optional*, defaults to False):
+ Use predicted noise instead of random noise when constructing noisy versions of the original image in
+ the reverse diffusion process
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ is_cancelled_callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. If the function returns
+ `True`, the inference will be cancelled.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+
+ Returns:
+ `None` if cancelled by `is_cancelled_callback`,
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is `True`, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt, height, width, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ max_embeddings_multiples,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+ dtype = prompt_embeds.dtype
+
+ # 4. Preprocess image and mask
+ if isinstance(image, PIL.Image.Image):
+ image = preprocess_image(image, batch_size)
+ if image is not None:
+ image = image.to(device=self.device, dtype=dtype)
+ if isinstance(mask_image, PIL.Image.Image):
+ mask_image = preprocess_mask(mask_image, batch_size, self.vae_scale_factor)
+ if mask_image is not None:
+ mask = mask_image.to(device=self.device, dtype=dtype)
+ mask = torch.cat([mask] * num_images_per_prompt)
+ else:
+ mask = None
+
+ # 5. set timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 6. Prepare latent variables
+ latents, init_latents_orig, noise = self.prepare_latents(
+ image,
+ latent_timestep,
+ num_images_per_prompt,
+ batch_size,
+ self.unet.config.in_channels,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 8. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ ).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ if mask is not None:
+ # masking
+ if add_predicted_noise:
+ init_latents_proper = self.scheduler.add_noise(
+ init_latents_orig, noise_pred_uncond, torch.tensor([t])
+ )
+ else:
+ init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
+ latents = (init_latents_proper * mask) + (latents * (1 - mask))
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if i % callback_steps == 0:
+ if callback is not None:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+ if is_cancelled_callback is not None and is_cancelled_callback():
+ return None
+
+ if output_type == "latent":
+ image = latents
+ has_nsfw_concept = None
+ elif output_type == "pil":
+ # 9. Post-processing
+ image = self.decode_latents(latents)
+
+ # 10. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ # 11. Convert to PIL
+ image = self.numpy_to_pil(image)
+ else:
+ # 9. Post-processing
+ image = self.decode_latents(latents)
+
+ # 10. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return image, has_nsfw_concept
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+ def text2img(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_embeddings_multiples: Optional[int] = 3,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ r"""
+ Function for text-to-image generation.
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ is_cancelled_callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. If the function returns
+ `True`, the inference will be cancelled.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+
+ Returns:
+ `None` if cancelled by `is_cancelled_callback`,
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is `True`, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ return self.__call__(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ height=height,
+ width=width,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ latents=latents,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_embeddings_multiples=max_embeddings_multiples,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ is_cancelled_callback=is_cancelled_callback,
+ callback_steps=callback_steps,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
+
+ def img2img(
+ self,
+ image: Union[torch.FloatTensor, PIL.Image.Image],
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ strength: float = 0.8,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: Optional[float] = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_embeddings_multiples: Optional[int] = 3,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ r"""
+ Function for image-to-image generation.
+ Args:
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process.
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
+ `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
+ number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
+ noise will be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference. This parameter will be modulated by `strength`.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ is_cancelled_callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. If the function returns
+ `True`, the inference will be cancelled.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+
+ Returns:
+ `None` if cancelled by `is_cancelled_callback`,
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is `True`, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ return self.__call__(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ image=image,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ strength=strength,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_embeddings_multiples=max_embeddings_multiples,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ is_cancelled_callback=is_cancelled_callback,
+ callback_steps=callback_steps,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
+
+ def inpaint(
+ self,
+ image: Union[torch.FloatTensor, PIL.Image.Image],
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image],
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ strength: float = 0.8,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ num_images_per_prompt: Optional[int] = 1,
+ add_predicted_noise: Optional[bool] = False,
+ eta: Optional[float] = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_embeddings_multiples: Optional[int] = 3,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ r"""
+ Function for inpaint.
+ Args:
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process. This is the image whose masked region will be inpainted.
+ mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+ replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
+ PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
+ contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
+ is 1, the denoising process will be run on the masked area for the full number of iterations specified
+ in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
+ noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
+ the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ add_predicted_noise (`bool`, *optional*, defaults to `False`):
+ Use predicted noise instead of random noise when constructing noisy versions of the original image in
+ the reverse diffusion process.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ is_cancelled_callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. If the function returns
+ `True`, the inference will be cancelled.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+
+ Returns:
+ `None` if cancelled by `is_cancelled_callback`,
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is `True`, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ return self.__call__(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ image=image,
+ mask_image=mask_image,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ strength=strength,
+ num_images_per_prompt=num_images_per_prompt,
+ add_predicted_noise=add_predicted_noise,
+ eta=eta,
+ generator=generator,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_embeddings_multiples=max_embeddings_multiples,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ is_cancelled_callback=is_cancelled_callback,
+ callback_steps=callback_steps,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
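
To make the combined interface above concrete, here is a minimal usage sketch. It assumes the file is loaded as a diffusers community pipeline via `custom_pipeline` (by its file name, `lpw_stable_diffusion`) and that a Stable Diffusion checkpoint such as `runwayml/stable-diffusion-v1-5` is available; adjust both to your environment. `text2img`, `img2img` and `inpaint` are the thin wrappers defined above, all routed through the same `__call__`.

```python
import torch
from diffusers import DiffusionPipeline

# Load the long-prompt-weighting community pipeline (assumed to resolve by file name).
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    custom_pipeline="lpw_stable_diffusion",
    torch_dtype=torch.float16,
).to("cuda")

# Bracket weighting: (word) boosts attention by 1.1x, (word:1.3) sets an explicit weight, [word] divides by 1.1.
prompt = "a photo of a cat, (best quality:1.3), (masterpiece), [cartoon]"
negative_prompt = "lowres, blurry, bad anatomy"

# Text-to-image; prompts longer than the 77-token CLIP window are split into chunks
# (up to max_embeddings_multiples of them) by the pipeline's weighted prompt-embedding helper.
image = pipe.text2img(
    prompt,
    negative_prompt=negative_prompt,
    width=512,
    height=512,
    num_inference_steps=30,
    max_embeddings_multiples=3,
).images[0]

# Image-to-image reuses the same __call__ with a starting image and a strength.
edited = pipe.img2img(image=image, prompt=prompt, strength=0.6).images[0]
```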
diff --git a/diffusers/examples/community/lpw_stable_diffusion_onnx.py b/diffusers/examples/community/lpw_stable_diffusion_onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..87c2944dbc44a063b33c5a9be9ff7bbdbcb77544
--- /dev/null
+++ b/diffusers/examples/community/lpw_stable_diffusion_onnx.py
@@ -0,0 +1,1148 @@
+import inspect
+import re
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from packaging import version
+from transformers import CLIPImageProcessor, CLIPTokenizer
+
+import diffusers
+from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, SchedulerMixin
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.utils import logging
+
+
+try:
+ from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE
+except ImportError:
+ ORT_TO_NP_TYPE = {
+ "tensor(bool)": np.bool_,
+ "tensor(int8)": np.int8,
+ "tensor(uint8)": np.uint8,
+ "tensor(int16)": np.int16,
+ "tensor(uint16)": np.uint16,
+ "tensor(int32)": np.int32,
+ "tensor(uint32)": np.uint32,
+ "tensor(int64)": np.int64,
+ "tensor(uint64)": np.uint64,
+ "tensor(float16)": np.float16,
+ "tensor(float)": np.float32,
+ "tensor(double)": np.float64,
+ }
+
+try:
+ from diffusers.utils import PIL_INTERPOLATION
+except ImportError:
+ if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.Resampling.BILINEAR,
+ "bilinear": PIL.Image.Resampling.BILINEAR,
+ "bicubic": PIL.Image.Resampling.BICUBIC,
+ "lanczos": PIL.Image.Resampling.LANCZOS,
+ "nearest": PIL.Image.Resampling.NEAREST,
+ }
+ else:
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.LINEAR,
+ "bilinear": PIL.Image.BILINEAR,
+ "bicubic": PIL.Image.BICUBIC,
+ "lanczos": PIL.Image.LANCZOS,
+ "nearest": PIL.Image.NEAREST,
+ }
+# ------------------------------------------------------------------------------
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+re_attention = re.compile(
+ r"""
+\\\(|
+\\\)|
+\\\[|
+\\]|
+\\\\|
+\\|
+\(|
+\[|
+:([+-]?[.\d]+)\)|
+\)|
+]|
+[^\\()\[\]:]+|
+:
+""",
+ re.X,
+)
+
+
+def parse_prompt_attention(text):
+ """
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
+ Accepted tokens are:
+ (abc) - increases attention to abc by a multiplier of 1.1
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
+ [abc] - decreases attention to abc by a multiplier of 1.1
+ \\( - literal character '('
+ \\[ - literal character '['
+ \\) - literal character ')'
+ \\] - literal character ']'
+ \\ - literal character '\'
+ anything else - just text
+ >>> parse_prompt_attention('normal text')
+ [['normal text', 1.0]]
+ >>> parse_prompt_attention('an (important) word')
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
+ >>> parse_prompt_attention('(unbalanced')
+ [['unbalanced', 1.1]]
+ >>> parse_prompt_attention('\\(literal\\]')
+ [['(literal]', 1.0]]
+ >>> parse_prompt_attention('(unnecessary)(parens)')
+ [['unnecessaryparens', 1.1]]
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
+ [['a ', 1.0],
+ ['house', 1.5730000000000004],
+ [' ', 1.1],
+ ['on', 1.0],
+ [' a ', 1.1],
+ ['hill', 0.55],
+ [', sun, ', 1.1],
+ ['sky', 1.4641000000000006],
+ ['.', 1.1]]
+ """
+
+ res = []
+ round_brackets = []
+ square_brackets = []
+
+ round_bracket_multiplier = 1.1
+ square_bracket_multiplier = 1 / 1.1
+
+ def multiply_range(start_position, multiplier):
+ for p in range(start_position, len(res)):
+ res[p][1] *= multiplier
+
+ for m in re_attention.finditer(text):
+ text = m.group(0)
+ weight = m.group(1)
+
+ if text.startswith("\\"):
+ res.append([text[1:], 1.0])
+ elif text == "(":
+ round_brackets.append(len(res))
+ elif text == "[":
+ square_brackets.append(len(res))
+ elif weight is not None and len(round_brackets) > 0:
+ multiply_range(round_brackets.pop(), float(weight))
+ elif text == ")" and len(round_brackets) > 0:
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
+ elif text == "]" and len(square_brackets) > 0:
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
+ else:
+ res.append([text, 1.0])
+
+ for pos in round_brackets:
+ multiply_range(pos, round_bracket_multiplier)
+
+ for pos in square_brackets:
+ multiply_range(pos, square_bracket_multiplier)
+
+ if len(res) == 0:
+ res = [["", 1.0]]
+
+ # merge runs of identical weights
+ i = 0
+ while i + 1 < len(res):
+ if res[i][1] == res[i + 1][1]:
+ res[i][0] += res[i + 1][0]
+ res.pop(i + 1)
+ else:
+ i += 1
+
+ return res
+
+
+def get_prompts_with_weights(pipe, prompt: List[str], max_length: int):
+ r"""
+ Tokenize a list of prompts and return the tokens of each prompt together with the weight of each token.
+
+ No padding, starting, or ending token is included.
+ """
+ tokens = []
+ weights = []
+ truncated = False
+ for text in prompt:
+ texts_and_weights = parse_prompt_attention(text)
+ text_token = []
+ text_weight = []
+ for word, weight in texts_and_weights:
+ # tokenize and discard the starting and the ending token
+ token = pipe.tokenizer(word, return_tensors="np").input_ids[0, 1:-1]
+ text_token += list(token)
+ # copy the weight by length of token
+ text_weight += [weight] * len(token)
+ # stop if the text is too long (longer than truncation limit)
+ if len(text_token) > max_length:
+ truncated = True
+ break
+ # truncate
+ if len(text_token) > max_length:
+ truncated = True
+ text_token = text_token[:max_length]
+ text_weight = text_weight[:max_length]
+ tokens.append(text_token)
+ weights.append(text_weight)
+ if truncated:
+ logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
+ return tokens, weights
+
+
+def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77):
+ r"""
+ Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
+ """
+ max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
+ weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
+ for i in range(len(tokens)):
+ tokens[i] = [bos] + tokens[i] + [pad] * (max_length - 1 - len(tokens[i]) - 1) + [eos]
+ if no_boseos_middle:
+ weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
+ else:
+ w = []
+ if len(weights[i]) == 0:
+ w = [1.0] * weights_length
+ else:
+ for j in range(max_embeddings_multiples):
+ w.append(1.0) # weight for starting token in this chunk
+ w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
+ w.append(1.0) # weight for ending token in this chunk
+ w += [1.0] * (weights_length - len(w))
+ weights[i] = w[:]
+
+ return tokens, weights
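+
+ # Worked example (illustrative): for a 120-token prompt with model_max_length=77,
+ # get_weighted_text_embeddings below computes max_length = (77 - 2) * 2 + 2 = 152, so each
+ # token list becomes [bos] + tokens + [pad] * 30 + [eos] and the matching weights are padded
+ # with 1.0 (bos/eos positions always receive weight 1.0).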
+
+
+def get_unweighted_text_embeddings(
+ pipe,
+ text_input: np.array,
+ chunk_length: int,
+ no_boseos_middle: Optional[bool] = True,
+):
+ """
+ When the token sequence is longer than the capacity of the text encoder,
+ it is split into chunks that are sent to the text encoder individually.
+ """
+ max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
+ if max_embeddings_multiples > 1:
+ text_embeddings = []
+ for i in range(max_embeddings_multiples):
+ # extract the i-th chunk
+ text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].copy()
+
+ # cover the head and the tail by the starting and the ending tokens
+ text_input_chunk[:, 0] = text_input[0, 0]
+ text_input_chunk[:, -1] = text_input[0, -1]
+
+ text_embedding = pipe.text_encoder(input_ids=text_input_chunk)[0]
+
+ if no_boseos_middle:
+ if i == 0:
+ # discard the ending token
+ text_embedding = text_embedding[:, :-1]
+ elif i == max_embeddings_multiples - 1:
+ # discard the starting token
+ text_embedding = text_embedding[:, 1:]
+ else:
+ # discard both starting and ending tokens
+ text_embedding = text_embedding[:, 1:-1]
+
+ text_embeddings.append(text_embedding)
+ text_embeddings = np.concatenate(text_embeddings, axis=1)
+ else:
+ text_embeddings = pipe.text_encoder(input_ids=text_input)[0]
+ return text_embeddings
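+
+ # Shape sketch (assuming the SD 1.x CLIP text encoder, hidden size 768): a padded input of
+ # shape (batch, 152) with chunk_length=77 is split into two 77-token chunks whose first and
+ # last positions are overwritten with bos/eos; with no_boseos_middle=True the inner bos/eos
+ # embeddings are dropped again, so the concatenated result has shape (batch, 152, 768).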
+
+
+def get_weighted_text_embeddings(
+ pipe,
+ prompt: Union[str, List[str]],
+ uncond_prompt: Optional[Union[str, List[str]]] = None,
+ max_embeddings_multiples: Optional[int] = 4,
+ no_boseos_middle: Optional[bool] = False,
+ skip_parsing: Optional[bool] = False,
+ skip_weighting: Optional[bool] = False,
+ **kwargs,
+):
+ r"""
+ Prompts can be assigned local weights using brackets. For example, the
+ prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
+ and the embedding tokens corresponding to those words are multiplied by a constant, 1.1.
+
+ Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.
+
+ Args:
+ pipe (`OnnxStableDiffusionPipeline`):
+ Pipe to provide access to the tokenizer and the text encoder.
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ uncond_prompt (`str` or `List[str]`):
+ The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
+ is provided, the embeddings of `prompt` and `uncond_prompt` are concatenated.
+ max_embeddings_multiples (`int`, *optional*, defaults to `4`):
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
+ no_boseos_middle (`bool`, *optional*, defaults to `False`):
+ When the token sequence spans multiple text-encoder chunks, whether to keep the starting and
+ ending tokens of each intermediate chunk.
+ skip_parsing (`bool`, *optional*, defaults to `False`):
+ Skip the parsing of brackets.
+ skip_weighting (`bool`, *optional*, defaults to `False`):
+ Skip the weighting. When parsing is skipped, this is forced to `True`.
+ """
+ max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
+ if isinstance(prompt, str):
+ prompt = [prompt]
+
+ if not skip_parsing:
+ prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
+ if uncond_prompt is not None:
+ if isinstance(uncond_prompt, str):
+ uncond_prompt = [uncond_prompt]
+ uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
+ else:
+ prompt_tokens = [
+ token[1:-1]
+ for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True, return_tensors="np").input_ids
+ ]
+ prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
+ if uncond_prompt is not None:
+ if isinstance(uncond_prompt, str):
+ uncond_prompt = [uncond_prompt]
+ uncond_tokens = [
+ token[1:-1]
+ for token in pipe.tokenizer(
+ uncond_prompt,
+ max_length=max_length,
+ truncation=True,
+ return_tensors="np",
+ ).input_ids
+ ]
+ uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
+
+ # round up the longest length of tokens to a multiple of (model_max_length - 2)
+ max_length = max([len(token) for token in prompt_tokens])
+ if uncond_prompt is not None:
+ max_length = max(max_length, max([len(token) for token in uncond_tokens]))
+
+ max_embeddings_multiples = min(
+ max_embeddings_multiples,
+ (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
+ )
+ max_embeddings_multiples = max(1, max_embeddings_multiples)
+ max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
+
+ # pad the length of tokens and weights
+ bos = pipe.tokenizer.bos_token_id
+ eos = pipe.tokenizer.eos_token_id
+ pad = getattr(pipe.tokenizer, "pad_token_id", eos)
+ prompt_tokens, prompt_weights = pad_tokens_and_weights(
+ prompt_tokens,
+ prompt_weights,
+ max_length,
+ bos,
+ eos,
+ pad,
+ no_boseos_middle=no_boseos_middle,
+ chunk_length=pipe.tokenizer.model_max_length,
+ )
+ prompt_tokens = np.array(prompt_tokens, dtype=np.int32)
+ if uncond_prompt is not None:
+ uncond_tokens, uncond_weights = pad_tokens_and_weights(
+ uncond_tokens,
+ uncond_weights,
+ max_length,
+ bos,
+ eos,
+ pad,
+ no_boseos_middle=no_boseos_middle,
+ chunk_length=pipe.tokenizer.model_max_length,
+ )
+ uncond_tokens = np.array(uncond_tokens, dtype=np.int32)
+
+ # get the embeddings
+ text_embeddings = get_unweighted_text_embeddings(
+ pipe,
+ prompt_tokens,
+ pipe.tokenizer.model_max_length,
+ no_boseos_middle=no_boseos_middle,
+ )
+ prompt_weights = np.array(prompt_weights, dtype=text_embeddings.dtype)
+ if uncond_prompt is not None:
+ uncond_embeddings = get_unweighted_text_embeddings(
+ pipe,
+ uncond_tokens,
+ pipe.tokenizer.model_max_length,
+ no_boseos_middle=no_boseos_middle,
+ )
+ uncond_weights = np.array(uncond_weights, dtype=uncond_embeddings.dtype)
+
+ # assign weights to the prompts and normalize in the sense of mean
+ # TODO: should we normalize by chunk or in a whole (current implementation)?
+ if (not skip_parsing) and (not skip_weighting):
+ previous_mean = text_embeddings.mean(axis=(-2, -1))
+ text_embeddings *= prompt_weights[:, :, None]
+ text_embeddings *= (previous_mean / text_embeddings.mean(axis=(-2, -1)))[:, None, None]
+ if uncond_prompt is not None:
+ previous_mean = uncond_embeddings.mean(axis=(-2, -1))
+ uncond_embeddings *= uncond_weights[:, :, None]
+ uncond_embeddings *= (previous_mean / uncond_embeddings.mean(axis=(-2, -1)))[:, None, None]
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ if uncond_prompt is not None:
+ return text_embeddings, uncond_embeddings
+
+ return text_embeddings
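+
+ # Usage sketch (illustrative, with `pipe` an instance of the pipeline class defined below):
+ # for classifier-free guidance the conditional and unconditional embeddings can be obtained as
+ #
+ #   cond, uncond = get_weighted_text_embeddings(
+ #       pipe,
+ #       prompt="a (masterpiece:1.3) photo of a cat",
+ #       uncond_prompt="lowres, blurry",
+ #       max_embeddings_multiples=3,
+ #   )
+ #   text_embeddings = np.concatenate([uncond, cond])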
+
+
+def preprocess_image(image):
+ w, h = image.size
+ w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32
+ image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ return 2.0 * image - 1.0
+
+
+def preprocess_mask(mask, scale_factor=8):
+ mask = mask.convert("L")
+ w, h = mask.size
+ w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32
+ mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
+ mask = np.array(mask).astype(np.float32) / 255.0
+ mask = np.tile(mask, (4, 1, 1))
+ mask = mask[None].transpose(0, 1, 2, 3) # add a batch dimension; the identity transpose is a no-op
+ mask = 1 - mask # repaint white, keep black
+ return mask
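+
+ # Shape sketch: a 512x512 PIL mask with scale_factor=8 is converted to grayscale, resized to
+ # 64x64, tiled over the 4 latent channels and batched, yielding an array of shape (1, 4, 64, 64)
+ # in which masked (white) pixels end up as 0 and preserved (black) pixels as 1.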
+
+
+class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion without a token length limit, with support for
+ parsing weights in the prompt.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+ """
+
+ if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
+
+ def __init__(
+ self,
+ vae_encoder: OnnxRuntimeModel,
+ vae_decoder: OnnxRuntimeModel,
+ text_encoder: OnnxRuntimeModel,
+ tokenizer: CLIPTokenizer,
+ unet: OnnxRuntimeModel,
+ scheduler: SchedulerMixin,
+ safety_checker: OnnxRuntimeModel,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__(
+ vae_encoder=vae_encoder,
+ vae_decoder=vae_decoder,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ requires_safety_checker=requires_safety_checker,
+ )
+ self.__init__additional__()
+
+ else:
+
+ def __init__(
+ self,
+ vae_encoder: OnnxRuntimeModel,
+ vae_decoder: OnnxRuntimeModel,
+ text_encoder: OnnxRuntimeModel,
+ tokenizer: CLIPTokenizer,
+ unet: OnnxRuntimeModel,
+ scheduler: SchedulerMixin,
+ safety_checker: OnnxRuntimeModel,
+ feature_extractor: CLIPImageProcessor,
+ ):
+ super().__init__(
+ vae_encoder=vae_encoder,
+ vae_decoder=vae_decoder,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.__init__additional__()
+
+ def __init__additional__(self):
+ self.unet.config.in_channels = 4
+ self.vae_scale_factor = 8
+
+ def _encode_prompt(
+ self,
+ prompt,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ max_embeddings_multiples,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `list(int)`):
+ prompt to be encoded
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
+ """
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+ if negative_prompt is None:
+ negative_prompt = [""] * batch_size
+ elif isinstance(negative_prompt, str):
+ negative_prompt = [negative_prompt] * batch_size
+ if batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
+ pipe=self,
+ prompt=prompt,
+ uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
+ max_embeddings_multiples=max_embeddings_multiples,
+ )
+
+ text_embeddings = text_embeddings.repeat(num_images_per_prompt, 0)
+ if do_classifier_free_guidance:
+ uncond_embeddings = uncond_embeddings.repeat(num_images_per_prompt, 0)
+ text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
+
+ return text_embeddings
+
+ def check_inputs(self, prompt, height, width, strength, callback_steps):
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ def get_timesteps(self, num_inference_steps, strength, is_text2img):
+ if is_text2img:
+ return self.scheduler.timesteps, num_inference_steps
+ else:
+ # get the original timestep using init_timestep
+ offset = self.scheduler.config.get("steps_offset", 0)
+ init_timestep = int(num_inference_steps * strength) + offset
+ init_timestep = min(init_timestep, num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep + offset, 0)
+ timesteps = self.scheduler.timesteps[t_start:]
+ return timesteps, num_inference_steps - t_start
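+
+ # Worked example: with num_inference_steps=50, strength=0.6 and steps_offset=0, img2img uses
+ # init_timestep = 30 and t_start = 20, so only the last 30 scheduler timesteps are run;
+ # text2img (image is None) keeps all 50.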
+
+ def run_safety_checker(self, image):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(
+ self.numpy_to_pil(image), return_tensors="np"
+ ).pixel_values.astype(image.dtype)
+ # calling the safety_checker directly with batch size > 1 raises an error, so run it image by image
+ images, has_nsfw_concept = [], []
+ for i in range(image.shape[0]):
+ image_i, has_nsfw_concept_i = self.safety_checker(
+ clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
+ )
+ images.append(image_i)
+ has_nsfw_concept.append(has_nsfw_concept_i[0])
+ image = np.concatenate(images)
+ else:
+ has_nsfw_concept = None
+ return image, has_nsfw_concept
+
+ def decode_latents(self, latents):
+ latents = 1 / 0.18215 * latents
+ # image = self.vae_decoder(latent_sample=latents)[0]
+ # the half-precision VAE decoder can produce incorrect results when batch size > 1, so decode latents one at a time
+ image = np.concatenate(
+ [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
+ )
+ image = np.clip(image / 2 + 0.5, 0, 1)
+ image = image.transpose((0, 2, 3, 1))
+ return image
+
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def prepare_latents(self, image, timestep, batch_size, height, width, dtype, generator, latents=None):
+ if image is None:
+ shape = (
+ batch_size,
+ self.unet.config.in_channels,
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor,
+ )
+
+ if latents is None:
+ latents = torch.randn(shape, generator=generator, device="cpu").numpy().astype(dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = (torch.from_numpy(latents) * self.scheduler.init_noise_sigma).numpy()
+ return latents, None, None
+ else:
+ init_latents = self.vae_encoder(sample=image)[0]
+ init_latents = 0.18215 * init_latents
+ init_latents = np.concatenate([init_latents] * batch_size, axis=0)
+ init_latents_orig = init_latents
+ shape = init_latents.shape
+
+ # add noise to latents using the timesteps
+ noise = torch.randn(shape, generator=generator, device="cpu").numpy().astype(dtype)
+ latents = self.scheduler.add_noise(
+ torch.from_numpy(init_latents), torch.from_numpy(noise), timestep
+ ).numpy()
+ return latents, init_latents_orig, noise
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ image: Union[np.ndarray, PIL.Image.Image] = None,
+ mask_image: Union[np.ndarray, PIL.Image.Image] = None,
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ strength: float = 0.8,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[np.ndarray] = None,
+ max_embeddings_multiples: Optional[int] = 3,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
+ callback_steps: int = 1,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ image (`np.ndarray` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process.
+ mask_image (`np.ndarray` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+ replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
+ PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
+ contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
+ `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
+ number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
+ noise will be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`np.ndarray`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
+ is_cancelled_callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. If the function returns
+ `True`, the inference will be cancelled.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Returns:
+ `None` if cancelled by `is_cancelled_callback`,
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is `True`, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(prompt, height, width, strength, callback_steps)
+
+ # 2. Define call parameters
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_embeddings = self._encode_prompt(
+ prompt,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ max_embeddings_multiples,
+ )
+ dtype = text_embeddings.dtype
+
+ # 4. Preprocess image and mask
+ if isinstance(image, PIL.Image.Image):
+ image = preprocess_image(image)
+ if image is not None:
+ image = image.astype(dtype)
+ if isinstance(mask_image, PIL.Image.Image):
+ mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
+ if mask_image is not None:
+ mask = mask_image.astype(dtype)
+ mask = np.concatenate([mask] * batch_size * num_images_per_prompt)
+ else:
+ mask = None
+
+ # 5. set timesteps
+ self.scheduler.set_timesteps(num_inference_steps)
+ timestep_dtype = next(
+ (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
+ )
+ timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, image is None)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 6. Prepare latent variables
+ latents, init_latents_orig, noise = self.prepare_latents(
+ image,
+ latent_timestep,
+ batch_size * num_images_per_prompt,
+ height,
+ width,
+ dtype,
+ generator,
+ latents,
+ )
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 8. Denoising loop
+ for i, t in enumerate(self.progress_bar(timesteps)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
+ latent_model_input = latent_model_input.numpy()
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ sample=latent_model_input,
+ timestep=np.array([t], dtype=timestep_dtype),
+ encoder_hidden_states=text_embeddings,
+ )
+ noise_pred = noise_pred[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ scheduler_output = self.scheduler.step(
+ torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
+ )
+ latents = scheduler_output.prev_sample.numpy()
+
+ if mask is not None:
+ # masking
+ init_latents_proper = self.scheduler.add_noise(
+ torch.from_numpy(init_latents_orig),
+ torch.from_numpy(noise),
+ t,
+ ).numpy()
+ latents = (init_latents_proper * mask) + (latents * (1 - mask))
+
+ # call the callback, if provided
+ if i % callback_steps == 0:
+ if callback is not None:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+ if is_cancelled_callback is not None and is_cancelled_callback():
+ return None
+
+ # 9. Post-processing
+ image = self.decode_latents(latents)
+
+ # 10. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image)
+
+ # 11. Convert to PIL
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return image, has_nsfw_concept
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+ def text2img(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[np.ndarray] = None,
+ max_embeddings_multiples: Optional[int] = 3,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
+ callback_steps: int = 1,
+ **kwargs,
+ ):
+ r"""
+ Function for text-to-image generation.
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`np.ndarray`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+ The maximum allowed length of the prompt embeddings, as a multiple of the text encoder's maximum output length.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ return self.__call__(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ height=height,
+ width=width,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ latents=latents,
+ max_embeddings_multiples=max_embeddings_multiples,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ callback_steps=callback_steps,
+ **kwargs,
+ )
+
+ def img2img(
+ self,
+ image: Union[np.ndarray, PIL.Image.Image],
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ strength: float = 0.8,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: Optional[float] = 0.0,
+ generator: Optional[torch.Generator] = None,
+ max_embeddings_multiples: Optional[int] = 3,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
+ callback_steps: int = 1,
+ **kwargs,
+ ):
+ r"""
+ Function for image-to-image generation.
+ Args:
+ image (`np.ndarray` or `PIL.Image.Image`):
+ `Image`, or ndarray representing an image batch, that will be used as the starting point for the
+ process.
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
+ `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
+ number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
+ noise will be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference. This parameter will be modulated by `strength`.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+ The maximum allowed length of the prompt embeddings, as a multiple of the text encoder's maximum output length.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ return self.__call__(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ image=image,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ strength=strength,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ max_embeddings_multiples=max_embeddings_multiples,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ callback_steps=callback_steps,
+ **kwargs,
+ )
+
+ def inpaint(
+ self,
+ image: Union[np.ndarray, PIL.Image.Image],
+ mask_image: Union[np.ndarray, PIL.Image.Image],
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ strength: float = 0.8,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: Optional[float] = 0.0,
+ generator: Optional[torch.Generator] = None,
+ max_embeddings_multiples: Optional[int] = 3,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
+ callback_steps: int = 1,
+ **kwargs,
+ ):
+ r"""
+ Function for inpaint.
+ Args:
+ image (`np.ndarray` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process. This is the image whose masked region will be inpainted.
+ mask_image (`np.ndarray` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+ replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
+ PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
+ contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
+ is 1, the denoising process will be run on the masked area for the full number of iterations specified
+ in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
+ noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
+ the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+ The maximum allowed length of the prompt embeddings, as a multiple of the text encoder's maximum output length.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ return self.__call__(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ image=image,
+ mask_image=mask_image,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ strength=strength,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ max_embeddings_multiples=max_embeddings_multiples,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ callback_steps=callback_steps,
+ **kwargs,
+ )
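+ # Illustrative usage sketch (hedged): `pipe` is assumed to be an already
+ # loaded instance of this pipeline and `mask` a PIL mask image; prompts and
+ # settings below are assumptions, not prescriptions. All three wrappers
+ # above delegate to __call__, which dispatches on whether `image` and
+ # `mask_image` are provided.
+ #
+ #   txt = pipe.text2img("a (red:1.3) cat", num_inference_steps=25).images[0]
+ #   var = pipe.img2img(image=txt, prompt="a blue cat", strength=0.6).images[0]
+ #   fix = pipe.inpaint(image=var, mask_image=mask, prompt="a dog", strength=0.75).images[0]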
diff --git a/diffusers/examples/community/lpw_stable_diffusion_xl.py b/diffusers/examples/community/lpw_stable_diffusion_xl.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb955a688643f7f557482db4e761c674fe2f0506
--- /dev/null
+++ b/diffusers/examples/community/lpw_stable_diffusion_xl.py
@@ -0,0 +1,1308 @@
+## ----------------------------------------------------------
+# An SDXL pipeline that can take weighted prompts of unlimited length
+#
+# Author: Andrew Zhu
+# Github: https://github.com/xhinker
+# Medium: https://medium.com/@xhinker
+## -----------------------------------------------------------
+
+import inspect
+import os
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+
+from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.models.attention_processor import (
+ AttnProcessor2_0,
+ LoRAAttnProcessor2_0,
+ LoRAXFormersAttnProcessor,
+ XFormersAttnProcessor,
+)
+from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ is_accelerate_available,
+ is_accelerate_version,
+ is_invisible_watermark_available,
+ logging,
+ replace_example_docstring,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+if is_invisible_watermark_available():
+ from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+
+
+def parse_prompt_attention(text):
+ """
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
+ Accepted tokens are:
+ (abc) - increases attention to abc by a multiplier of 1.1
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
+ [abc] - decreases attention to abc by a multiplier of 1.1
+ \\( - literal character '('
+ \\[ - literal character '['
+ \\) - literal character ')'
+ \\] - literal character ']'
+ \\ - literal character '\'
+ anything else - just text
+
+ >>> parse_prompt_attention('normal text')
+ [['normal text', 1.0]]
+ >>> parse_prompt_attention('an (important) word')
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
+ >>> parse_prompt_attention('(unbalanced')
+ [['unbalanced', 1.1]]
+ >>> parse_prompt_attention('\\(literal\\]')
+ [['(literal]', 1.0]]
+ >>> parse_prompt_attention('(unnecessary)(parens)')
+ [['unnecessaryparens', 1.1]]
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
+ [['a ', 1.0],
+ ['house', 1.5730000000000004],
+ [' ', 1.1],
+ ['on', 1.0],
+ [' a ', 1.1],
+ ['hill', 0.55],
+ [', sun, ', 1.1],
+ ['sky', 1.4641000000000006],
+ ['.', 1.1]]
+ """
+ import re
+
+ re_attention = re.compile(
+ r"""
+ \\\(|\\\)|\\\[|\\]|\\\\|\\|\(|\[|:([+-]?[.\d]+)\)|
+ \)|]|[^\\()\[\]:]+|:
+ """,
+ re.X,
+ )
+
+ re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)
+
+ res = []
+ round_brackets = []
+ square_brackets = []
+
+ round_bracket_multiplier = 1.1
+ square_bracket_multiplier = 1 / 1.1
+
+ def multiply_range(start_position, multiplier):
+ for p in range(start_position, len(res)):
+ res[p][1] *= multiplier
+
+ for m in re_attention.finditer(text):
+ text = m.group(0)
+ weight = m.group(1)
+
+ if text.startswith("\\"):
+ res.append([text[1:], 1.0])
+ elif text == "(":
+ round_brackets.append(len(res))
+ elif text == "[":
+ square_brackets.append(len(res))
+ elif weight is not None and len(round_brackets) > 0:
+ multiply_range(round_brackets.pop(), float(weight))
+ elif text == ")" and len(round_brackets) > 0:
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
+ elif text == "]" and len(square_brackets) > 0:
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
+ else:
+ parts = re.split(re_break, text)
+ for i, part in enumerate(parts):
+ if i > 0:
+ res.append(["BREAK", -1])
+ res.append([part, 1.0])
+
+ for pos in round_brackets:
+ multiply_range(pos, round_bracket_multiplier)
+
+ for pos in square_brackets:
+ multiply_range(pos, square_bracket_multiplier)
+
+ if len(res) == 0:
+ res = [["", 1.0]]
+
+ # merge runs of identical weights
+ i = 0
+ while i + 1 < len(res):
+ if res[i][1] == res[i + 1][1]:
+ res[i][0] += res[i + 1][0]
+ res.pop(i + 1)
+ else:
+ i += 1
+
+ return res
+
+
+def get_prompts_tokens_with_weights(clip_tokenizer: CLIPTokenizer, prompt: str):
+ """
+ Get prompt token ids and weights; this function works for both the prompt and the negative prompt
+
+ Args:
+ clip_tokenizer (CLIPTokenizer)
+ A CLIPTokenizer
+ prompt (str)
+ A prompt string with weights
+
+ Returns:
+ text_tokens (list)
+ A list containing the token ids
+ text_weights (list)
+ A list containing the corresponding weight of each token id
+
+ Example:
+ import torch
+ from transformers import CLIPTokenizer
+
+ clip_tokenizer = CLIPTokenizer.from_pretrained(
+ "stablediffusionapi/deliberate-v2"
+ , subfolder = "tokenizer"
+ , dtype = torch.float16
+ )
+
+ token_id_list, token_weight_list = get_prompts_tokens_with_weights(
+ clip_tokenizer = clip_tokenizer
+ ,prompt = "a (red:1.5) cat"*70
+ )
+ """
+ texts_and_weights = parse_prompt_attention(prompt)
+ text_tokens, text_weights = [], []
+ for word, weight in texts_and_weights:
+ # tokenize and discard the starting and the ending token
+ token = clip_tokenizer(word, truncation=False).input_ids[1:-1] # so that a prompt of any length can be tokenized
+ # the returned token is a 1d list: [320, 1125, 539, 320]
+
+ # merge the new tokens into the running token list: text_tokens
+ text_tokens = [*text_tokens, *token]
+
+ # each token chunk will come with one weight, like ['red cat', 2.0]
+ # need to expand weight for each token.
+ chunk_weights = [weight] * len(token)
+
+ # append the weight back to the weight holder: text_weights
+ text_weights = [*text_weights, *chunk_weights]
+ return text_tokens, text_weights
+
+
+def group_tokens_and_weights(token_ids: list, weights: list, pad_last_block=False):
+ """
+ Produce tokens and weights in groups and pad the missing tokens
+
+ Args:
+ token_ids (list)
+ The token ids from tokenizer
+ weights (list)
+ The weights list from function get_prompts_tokens_with_weights
+ pad_last_block (bool)
+ Controls whether to pad the last token chunk to 75 tokens with eos
+ Returns:
+ new_token_ids (2d list)
+ new_weights (2d list)
+
+ Example:
+ token_groups,weight_groups = group_tokens_and_weights(
+ token_ids = token_id_list
+ , weights = token_weight_list
+ )
+ """
+ bos, eos = 49406, 49407
+
+ # this will be a 2d list
+ new_token_ids = []
+ new_weights = []
+ while len(token_ids) >= 75:
+ # get the first 75 tokens
+ head_75_tokens = [token_ids.pop(0) for _ in range(75)]
+ head_75_weights = [weights.pop(0) for _ in range(75)]
+
+ # extract token ids and weights
+ temp_77_token_ids = [bos] + head_75_tokens + [eos]
+ temp_77_weights = [1.0] + head_75_weights + [1.0]
+
+ # add 77 token and weights chunk to the holder list
+ new_token_ids.append(temp_77_token_ids)
+ new_weights.append(temp_77_weights)
+
+ # pad the remaining tokens
+ if len(token_ids) > 0:
+ padding_len = 75 - len(token_ids) if pad_last_block else 0
+
+ temp_77_token_ids = [bos] + token_ids + [eos] * padding_len + [eos]
+ new_token_ids.append(temp_77_token_ids)
+
+ temp_77_weights = [1.0] + weights + [1.0] * padding_len + [1.0]
+ new_weights.append(temp_77_weights)
+
+ return new_token_ids, new_weights
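+# Illustrative sketch of the grouping above (the ids below are made-up
+# placeholders, not real tokenizer output): 160 ids split into 75-token
+# groups, each wrapped with bos/eos, give three 77-id chunks.
+#
+#   ids, ws = list(range(160)), [1.0] * 160
+#   groups, weight_groups = group_tokens_and_weights(ids, ws, pad_last_block=True)
+#   # len(groups) == 3 and all(len(g) == 77 for g in groups)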
+
+
+def get_weighted_text_embeddings_sdxl(
+ pipe: StableDiffusionXLPipeline,
+ prompt: str = "",
+ prompt_2: str = None,
+ neg_prompt: str = "",
+ neg_prompt_2: str = None,
+ num_images_per_prompt: int = 1,
+):
+ """
+ This function can process long prompts with weights; there is no length limitation
+ for Stable Diffusion XL
+
+ Args:
+ pipe (StableDiffusionXLPipeline)
+ prompt (str)
+ prompt_2 (str)
+ neg_prompt (str)
+ neg_prompt_2 (str)
+ num_images_per_prompt (int)
+ Returns:
+ prompt_embeds (torch.Tensor)
+ neg_prompt_embeds (torch.Tensor)
+ """
+ if prompt_2:
+ prompt = f"{prompt} {prompt_2}"
+
+ if neg_prompt_2:
+ neg_prompt = f"{neg_prompt} {neg_prompt_2}"
+
+ eos = pipe.tokenizer.eos_token_id
+
+ # tokenizer 1
+ prompt_tokens, prompt_weights = get_prompts_tokens_with_weights(pipe.tokenizer, prompt)
+
+ neg_prompt_tokens, neg_prompt_weights = get_prompts_tokens_with_weights(pipe.tokenizer, neg_prompt)
+
+ # tokenizer 2
+ prompt_tokens_2, prompt_weights_2 = get_prompts_tokens_with_weights(pipe.tokenizer_2, prompt)
+
+ neg_prompt_tokens_2, neg_prompt_weights_2 = get_prompts_tokens_with_weights(pipe.tokenizer_2, neg_prompt)
+
+ # padding the shorter one for prompt set 1
+ prompt_token_len = len(prompt_tokens)
+ neg_prompt_token_len = len(neg_prompt_tokens)
+
+ if prompt_token_len > neg_prompt_token_len:
+ # padding the neg_prompt with eos token
+ neg_prompt_tokens = neg_prompt_tokens + [eos] * abs(prompt_token_len - neg_prompt_token_len)
+ neg_prompt_weights = neg_prompt_weights + [1.0] * abs(prompt_token_len - neg_prompt_token_len)
+ else:
+ # padding the prompt
+ prompt_tokens = prompt_tokens + [eos] * abs(prompt_token_len - neg_prompt_token_len)
+ prompt_weights = prompt_weights + [1.0] * abs(prompt_token_len - neg_prompt_token_len)
+
+ # padding the shorter one for token set 2
+ prompt_token_len_2 = len(prompt_tokens_2)
+ neg_prompt_token_len_2 = len(neg_prompt_tokens_2)
+
+ if prompt_token_len_2 > neg_prompt_token_len_2:
+ # padding the neg_prompt with eos token
+ neg_prompt_tokens_2 = neg_prompt_tokens_2 + [eos] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
+ neg_prompt_weights_2 = neg_prompt_weights_2 + [1.0] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
+ else:
+ # padding the prompt
+ prompt_tokens_2 = prompt_tokens_2 + [eos] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
+ prompt_weights_2 = prompt_weights_2 + [1.0] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
+
+ embeds = []
+ neg_embeds = []
+
+ prompt_token_groups, prompt_weight_groups = group_tokens_and_weights(prompt_tokens.copy(), prompt_weights.copy())
+
+ neg_prompt_token_groups, neg_prompt_weight_groups = group_tokens_and_weights(
+ neg_prompt_tokens.copy(), neg_prompt_weights.copy()
+ )
+
+ prompt_token_groups_2, prompt_weight_groups_2 = group_tokens_and_weights(
+ prompt_tokens_2.copy(), prompt_weights_2.copy()
+ )
+
+ neg_prompt_token_groups_2, neg_prompt_weight_groups_2 = group_tokens_and_weights(
+ neg_prompt_tokens_2.copy(), neg_prompt_weights_2.copy()
+ )
+
+ # encode the prompt one 77-token group at a time
+ for i in range(len(prompt_token_groups)):
+ # get positive prompt embeddings with weights
+ token_tensor = torch.tensor([prompt_token_groups[i]], dtype=torch.long, device=pipe.device)
+ weight_tensor = torch.tensor(prompt_weight_groups[i], dtype=torch.float16, device=pipe.device)
+
+ token_tensor_2 = torch.tensor([prompt_token_groups_2[i]], dtype=torch.long, device=pipe.device)
+
+ # use first text encoder
+ prompt_embeds_1 = pipe.text_encoder(token_tensor.to(pipe.device), output_hidden_states=True)
+ prompt_embeds_1_hidden_states = prompt_embeds_1.hidden_states[-2]
+
+ # use second text encoder
+ prompt_embeds_2 = pipe.text_encoder_2(token_tensor_2.to(pipe.device), output_hidden_states=True)
+ prompt_embeds_2_hidden_states = prompt_embeds_2.hidden_states[-2]
+ pooled_prompt_embeds = prompt_embeds_2[0]
+
+ prompt_embeds_list = [prompt_embeds_1_hidden_states, prompt_embeds_2_hidden_states]
+ token_embedding = torch.concat(prompt_embeds_list, dim=-1).squeeze(0)
+
+ for j in range(len(weight_tensor)):
+ if weight_tensor[j] != 1.0:
+ token_embedding[j] = (
+ token_embedding[-1] + (token_embedding[j] - token_embedding[-1]) * weight_tensor[j]
+ )
+
+ token_embedding = token_embedding.unsqueeze(0)
+ embeds.append(token_embedding)
+
+ # get negative prompt embeddings with weights
+ neg_token_tensor = torch.tensor([neg_prompt_token_groups[i]], dtype=torch.long, device=pipe.device)
+ neg_token_tensor_2 = torch.tensor([neg_prompt_token_groups_2[i]], dtype=torch.long, device=pipe.device)
+ neg_weight_tensor = torch.tensor(neg_prompt_weight_groups[i], dtype=torch.float16, device=pipe.device)
+
+ # use first text encoder
+ neg_prompt_embeds_1 = pipe.text_encoder(neg_token_tensor.to(pipe.device), output_hidden_states=True)
+ neg_prompt_embeds_1_hidden_states = neg_prompt_embeds_1.hidden_states[-2]
+
+ # use second text encoder
+ neg_prompt_embeds_2 = pipe.text_encoder_2(neg_token_tensor_2.to(pipe.device), output_hidden_states=True)
+ neg_prompt_embeds_2_hidden_states = neg_prompt_embeds_2.hidden_states[-2]
+ negative_pooled_prompt_embeds = neg_prompt_embeds_2[0]
+
+ neg_prompt_embeds_list = [neg_prompt_embeds_1_hidden_states, neg_prompt_embeds_2_hidden_states]
+ neg_token_embedding = torch.concat(neg_prompt_embeds_list, dim=-1).squeeze(0)
+
+ for z in range(len(neg_weight_tensor)):
+ if neg_weight_tensor[z] != 1.0:
+ neg_token_embedding[z] = (
+ neg_token_embedding[-1] + (neg_token_embedding[z] - neg_token_embedding[-1]) * neg_weight_tensor[z]
+ )
+
+ neg_token_embedding = neg_token_embedding.unsqueeze(0)
+ neg_embeds.append(neg_token_embedding)
+
+ prompt_embeds = torch.cat(embeds, dim=1)
+ negative_prompt_embeds = torch.cat(neg_embeds, dim=1)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ seq_len = negative_prompt_embeds.shape[1]
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
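+# Hedged usage sketch: the four tensors above can be fed to an SDXL pipeline
+# in place of raw prompts. `pipe` is assumed to be an already loaded
+# StableDiffusionXLPipeline (or this custom pipeline) on the right device.
+#
+#   pe, npe, ppe, nppe = get_weighted_text_embeddings_sdxl(
+#       pipe, prompt="a (white:1.4) cat" * 30, neg_prompt="blur, low quality"
+#   )
+#   image = pipe(
+#       prompt_embeds=pe,
+#       negative_prompt_embeds=npe,
+#       pooled_prompt_embeds=ppe,
+#       negative_pooled_prompt_embeds=nppe,
+#   ).images[0]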
+
+
+# -------------------------------------------------------------------------------------------------------------------------------
+# reuse the backbone code from StableDiffusionXLPipeline
+# -------------------------------------------------------------------------------------------------------------------------------
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipe = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0"
+ , torch_dtype = torch.float16
+ , use_safetensors = True
+ , variant = "fp16"
+ , custom_pipeline = "lpw_stable_diffusion_xl",
+ )
+
+ prompt = "a white cat running on the grass"*20
+ prompt2 = "play a football"*20
+ prompt = f"{prompt},{prompt2}"
+ neg_prompt = "blur, low quality"
+
+ pipe.to("cuda")
+ images = pipe(
+ prompt = prompt
+ , negative_prompt = neg_prompt
+ ).images[0]
+
+ pipe.to("cpu")
+ torch.cuda.empty_cache()
+ images
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
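+# Hedged example of where the rescaling above typically slots into a
+# classifier-free-guidance update (variable names are assumptions):
+#
+#   noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+#   if guidance_rescale > 0.0:
+#       noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)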
+
+
+class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion XL.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ In addition the pipeline inherits the following loading methods:
+ - *LoRA*: [`StableDiffusionXLPipeline.load_lora_weights`]
+ - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
+
+ as well as the following saving methods:
+ - *LoRA*: [`loaders.StableDiffusionXLPipeline.save_lora_weights`]
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([` CLIPTextModelWithProjection`]):
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+ specifically the
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+ variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`CLIPTokenizer`):
+ Second Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ text_encoder_2: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ tokenizer_2: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ force_zeros_for_empty_prompt: bool = True,
+ add_watermarker: Optional[bool] = None,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ unet=unet,
+ scheduler=scheduler,
+ )
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.default_sample_size = self.unet.config.sample_size
+
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
+
+ if add_watermarker:
+ self.watermark = StableDiffusionXLWatermarker()
+ else:
+ self.watermark = None
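+ # Hedged worked example for the scale factor above, assuming the usual SDXL
+ # VAE with four entries in `block_out_channels`: 2 ** (4 - 1) = 8, so a
+ # 1024x1024 image corresponds to 128x128 latents.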
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ model_sequence = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+ model_sequence.extend([self.unet, self.vae])
+
+ hook = None
+ for cpu_offloaded_model in model_sequence:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ def encode_prompt(
+ self,
+ prompt: str,
+ prompt_2: Optional[str] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[str] = None,
+ negative_prompt_2: Optional[str] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Define tokenizers and text encoders
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+ text_encoders = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ # textual inversion: process multi-vector tokens if necessary
+ prompt_embeds_list = []
+ prompts = [prompt, prompt_2]
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = text_encoder(
+ text_input_ids.to(device),
+ output_hidden_states=True,
+ )
+
+ # We are always interested only in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+ # get unconditional embeddings for classifier free guidance
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
+
+ uncond_tokens: List[str]
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt, negative_prompt_2]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = [negative_prompt, negative_prompt_2]
+
+ negative_prompt_embeds_list = []
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = tokenizer(
+ negative_prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ negative_prompt_embeds = text_encoder(
+ uncond_input.input_ids.to(device),
+ output_hidden_states=True,
+ )
+ # We are always interested only in the pooled output of the final text encoder
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+ if do_classifier_free_guidance:
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
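+ # Hedged usage sketch: during sampling the four tensors returned above are
+ # typically stacked for classifier-free guidance roughly as follows
+ # (variable names are assumptions):
+ #   prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ #   add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)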
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
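+ # Hedged sketch of what the helper above returns (scheduler names are
+ # assumptions about common configurations):
+ #   DDIMScheduler, whose `step` accepts both parameters:
+ #       {"eta": 0.0, "generator": generator}
+ #   EulerDiscreteScheduler, whose `step` has no `eta` parameter:
+ #       {"generator": generator}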
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ negative_prompt_2=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
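+ # Hedged shape sketch for the latents above, assuming the usual SDXL setup
+ # (4 latent channels, vae_scale_factor of 8): a 1024x1024 request with
+ # batch_size=1 yields a (1, 4, 128, 128) tensor, already scaled by
+ # self.scheduler.init_noise_sigma.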
+
+ def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+
+ passed_add_embed_dim = (
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
+ )
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+ if expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ return add_time_ids
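+ # Hedged worked example for the dimension check above, assuming the standard
+ # SDXL base configuration (addition_time_embed_dim=256, projection_dim=1280):
+ # original_size + crops_coords_top_left + target_size gives 6 ids, so
+ # 6 * 256 + 1280 = 2816, which must match unet.add_embedding.linear_1.in_features.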
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+ def upcast_vae(self):
+ dtype = self.vae.dtype
+ self.vae.to(dtype=torch.float32)
+ use_torch_2_0_or_xformers = isinstance(
+ self.vae.decoder.mid_block.attentions[0].processor,
+ (
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+ LoRAXFormersAttnProcessor,
+ LoRAAttnProcessor2_0,
+ ),
+ )
+ # if xformers or torch_2_0 is used attention block does not need
+ # to be in float32 which can save lots of memory
+ if use_torch_2_0_or_xformers:
+ self.vae.post_quant_conv.to(dtype)
+ self.vae.decoder.conv_in.to(dtype)
+ self.vae.decoder.mid_block.to(dtype)
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: str = None,
+ prompt_2: Optional[str] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[str] = None,
+ negative_prompt_2: Optional[str] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Optional[Tuple[int, int]] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Optional[Tuple[int, int]] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str`):
+ The prompt to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ prompt_2 (`str`):
+ The prompt to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ denoising_end (`float`, *optional*):
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str`):
+ The prompt not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str`):
+ The prompt not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
+ of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf), where `guidance_rescale` corresponds to `φ` in equation 16.
+ Guidance rescale should fix overexposure when using zero terminal SNR.
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+ # 0. Default height and width to unet
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ callback_steps,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ # NOTE: the lora-scale lookup below is a no-op here; its result is discarded because
+ # the weighted-embedding helper below performs the prompt encoding itself.
+ (cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None)
+
+ negative_prompt = negative_prompt if negative_prompt is not None else ""
+
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = get_weighted_text_embeddings_sdxl(
+ pipe=self, prompt=prompt, neg_prompt=negative_prompt, num_images_per_prompt=num_images_per_prompt
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Prepare added time ids & embeddings
+ add_text_embeds = pooled_prompt_embeds
+ add_time_ids = self._get_add_time_ids(
+ original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
+ )
+
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+ # 8. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ # 8.1 Apply denoising_end
+ if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+ else:
+ image = latents
+ return StableDiffusionXLPipelineOutput(images=image)
+
+ # apply watermark if available
+ if self.watermark is not None:
+ image = self.watermark.apply_watermark(image)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)
+
+ # Override to properly handle the loading and unloading of the additional text encoder.
+ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+ # We could have accessed the unet config from `lora_state_dict()` too. We pass
+ # it here explicitly to be able to tell that it's coming from an SDXL
+ # pipeline.
+ state_dict, network_alphas = self.lora_state_dict(
+ pretrained_model_name_or_path_or_dict,
+ unet_config=self.unet.config,
+ **kwargs,
+ )
+ self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
+
+ text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
+ if len(text_encoder_state_dict) > 0:
+ self.load_lora_into_text_encoder(
+ text_encoder_state_dict,
+ network_alphas=network_alphas,
+ text_encoder=self.text_encoder,
+ prefix="text_encoder",
+ lora_scale=self.lora_scale,
+ )
+
+ text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
+ if len(text_encoder_2_state_dict) > 0:
+ self.load_lora_into_text_encoder(
+ text_encoder_2_state_dict,
+ network_alphas=network_alphas,
+ text_encoder=self.text_encoder_2,
+ prefix="text_encoder_2",
+ lora_scale=self.lora_scale,
+ )
+
+ @classmethod
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = False,
+ ):
+ state_dict = {}
+
+ def pack_weights(layers, prefix):
+ layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
+ layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
+ return layers_state_dict
+
+ state_dict.update(pack_weights(unet_lora_layers, "unet"))
+
+ if text_encoder_lora_layers and text_encoder_2_lora_layers:
+ state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
+ state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
+
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ def _remove_text_encoder_monkey_patch(self):
+ self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
+ self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
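For orientation, a minimal usage sketch of the weighted-prompt SDXL pipeline added above follows. The checkpoint id and the `custom_pipeline` name are illustrative assumptions (not taken from this diff), and it assumes `get_weighted_text_embeddings_sdxl` implements the usual `(word:weight)` emphasis syntax.

```python
# Usage sketch only: checkpoint id and custom_pipeline name are assumptions.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    custom_pipeline="lpw_stable_diffusion_xl",  # assumed community-pipeline name for this file
).to("cuda")

# Parenthesised tokens are up-weighted and bracketed tokens down-weighted by the
# weighted-embedding helper used inside __call__.
image = pipe(
    prompt="a (photorealistic:1.3) portrait of an astronaut, [blurry]",
    negative_prompt="low quality, deformed",
    num_inference_steps=30,
    guidance_scale=5.0,
).images[0]
image.save("weighted_prompt.png")
```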
diff --git a/diffusers/examples/community/magic_mix.py b/diffusers/examples/community/magic_mix.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3d118f84bfcdcdabad76a074515efec42f20bb7
--- /dev/null
+++ b/diffusers/examples/community/magic_mix.py
@@ -0,0 +1,152 @@
+from typing import Union
+
+import torch
+from PIL import Image
+from torchvision import transforms as tfms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from diffusers import (
+ AutoencoderKL,
+ DDIMScheduler,
+ DiffusionPipeline,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+ UNet2DConditionModel,
+)
+
+
+class MagicMixPipeline(DiffusionPipeline):
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler],
+ ):
+ super().__init__()
+
+ self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
+
+ # convert PIL image to latents
+ def encode(self, img):
+ with torch.no_grad():
+ latent = self.vae.encode(tfms.ToTensor()(img).unsqueeze(0).to(self.device) * 2 - 1)
+ latent = 0.18215 * latent.latent_dist.sample()
+ return latent
+
+ # convert latents to PIL image
+ def decode(self, latent):
+ latent = (1 / 0.18215) * latent
+ with torch.no_grad():
+ img = self.vae.decode(latent).sample
+ img = (img / 2 + 0.5).clamp(0, 1)
+ img = img.detach().cpu().permute(0, 2, 3, 1).numpy()
+ img = (img * 255).round().astype("uint8")
+ return Image.fromarray(img[0])
+
+ # convert the prompt into text embeddings, together with the unconditional embeddings
+ def prep_text(self, prompt):
+ text_input = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_embedding = self.text_encoder(text_input.input_ids.to(self.device))[0]
+
+ uncond_input = self.tokenizer(
+ "",
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ uncond_embedding = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+ return torch.cat([uncond_embedding, text_embedding])
+
+ def __call__(
+ self,
+ img: Image.Image,
+ prompt: str,
+ kmin: float = 0.3,
+ kmax: float = 0.6,
+ mix_factor: float = 0.5,
+ seed: int = 42,
+ steps: int = 50,
+ guidance_scale: float = 7.5,
+ ) -> Image.Image:
+ tmin = steps - int(kmin * steps)
+ tmax = steps - int(kmax * steps)
+
+ text_embeddings = self.prep_text(prompt)
+
+ self.scheduler.set_timesteps(steps)
+
+ width, height = img.size
+ encoded = self.encode(img)
+
+ torch.manual_seed(seed)
+ noise = torch.randn(
+ (1, self.unet.config.in_channels, height // 8, width // 8),
+ ).to(self.device)
+
+ latents = self.scheduler.add_noise(
+ encoded,
+ noise,
+ timesteps=self.scheduler.timesteps[tmax],
+ )
+
+ input = torch.cat([latents] * 2)
+
+ input = self.scheduler.scale_model_input(input, self.scheduler.timesteps[tmax])
+
+ with torch.no_grad():
+ pred = self.unet(
+ input,
+ self.scheduler.timesteps[tmax],
+ encoder_hidden_states=text_embeddings,
+ ).sample
+
+ pred_uncond, pred_text = pred.chunk(2)
+ pred = pred_uncond + guidance_scale * (pred_text - pred_uncond)
+
+ latents = self.scheduler.step(pred, self.scheduler.timesteps[tmax], latents).prev_sample
+
+ for i, t in enumerate(tqdm(self.scheduler.timesteps)):
+ if i > tmax:
+ if i < tmin: # layout generation phase
+ orig_latents = self.scheduler.add_noise(
+ encoded,
+ noise,
+ timesteps=t,
+ )
+
+ input = (
+ (mix_factor * latents) + (1 - mix_factor) * orig_latents
+ ) # interpolate between layout noise and conditionally generated noise to preserve layout semantics
+ input = torch.cat([input] * 2)
+
+ else: # content generation phase
+ input = torch.cat([latents] * 2)
+
+ input = self.scheduler.scale_model_input(input, t)
+
+ with torch.no_grad():
+ pred = self.unet(
+ input,
+ t,
+ encoder_hidden_states=text_embeddings,
+ ).sample
+
+ pred_uncond, pred_text = pred.chunk(2)
+ pred = pred_uncond + guidance_scale * (pred_text - pred_uncond)
+
+ latents = self.scheduler.step(pred, t, latents).prev_sample
+
+ return self.decode(latents)
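A minimal usage sketch for `MagicMixPipeline` is shown below. The checkpoint id, scheduler choice and input image path are illustrative assumptions; `kmin`/`kmax` bound the layout-preservation phase of the schedule and `mix_factor` controls the blend between the noised layout latents and the conditionally generated latents during that phase.

```python
# Usage sketch only: checkpoint id, scheduler and file names are assumptions.
from diffusers import DDIMScheduler, DiffusionPipeline
from PIL import Image

pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="magic_mix",  # assumed community-pipeline name for this file
    scheduler=DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler"),
).to("cuda")

layout_img = Image.open("phone.jpg").convert("RGB")  # image supplying the layout/shape

# The call returns a single PIL image (see decode() above).
mixed = pipe(
    img=layout_img,
    prompt="bed",
    kmin=0.3,
    kmax=0.6,
    mix_factor=0.5,
    steps=50,
    guidance_scale=7.5,
)
mixed.save("magic_mix.png")
```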
diff --git a/diffusers/examples/community/masked_stable_diffusion_img2img.py b/diffusers/examples/community/masked_stable_diffusion_img2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b08086c7da90e95958d56f777783ee08d4bd8a5
--- /dev/null
+++ b/diffusers/examples/community/masked_stable_diffusion_img2img.py
@@ -0,0 +1,262 @@
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL.Image
+import torch
+
+from diffusers import StableDiffusionImg2ImgPipeline
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+
+
+class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
+ debug_save = False
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Union[
+ torch.FloatTensor,
+ PIL.Image.Image,
+ np.ndarray,
+ List[torch.FloatTensor],
+ List[PIL.Image.Image],
+ List[np.ndarray],
+ ] = None,
+ strength: float = 0.8,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: Optional[float] = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ mask: Union[
+ torch.FloatTensor,
+ PIL.Image.Image,
+ np.ndarray,
+ List[torch.FloatTensor],
+ List[PIL.Image.Image],
+ List[np.ndarray],
+ ] = None,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image` or tensor representing an image batch to be used as the starting point. Can also accept image
+ latents as `image`, but if passing latents directly it is not encoded again.
+ strength (`float`, *optional*, defaults to 0.8):
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+ essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference. This parameter is modulated by `strength`.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that calls every `callback_steps` steps during inference. The function is called with the
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
+ every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ mask (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`, *optional*):
+ A mask with non-zero elements for the area to be inpainted. If not specified, no mask is applied.
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
+ "not-safe-for-work" (nsfw) content.
+ """
+ # code adapted from parent class StableDiffusionImg2ImgPipeline
+
+ # 0. Check inputs. Raise error if not correct
+ self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
+
+ # 1. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 2. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 3. Preprocess image
+ image = self.image_processor.preprocess(image)
+
+ # 4. set timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 5. Prepare latent variables
+ # it is sampled from the latent distribution of the VAE
+ latents = self.prepare_latents(
+ image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
+ )
+
+ # mean of the latent distribution
+ init_latents = [
+ self.vae.encode(image.to(device=device, dtype=prompt_embeds.dtype)[i : i + 1]).latent_dist.mean
+ for i in range(batch_size)
+ ]
+ init_latents = torch.cat(init_latents, dim=0)
+
+ # 6. create latent mask
+ latent_mask = self._make_latent_mask(latents, mask)
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 8. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if latent_mask is not None:
+ latents = torch.lerp(init_latents * self.vae.config.scaling_factor, latents, latent_mask)
+ noise_pred = torch.lerp(torch.zeros_like(noise_pred), noise_pred, latent_mask)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if not output_type == "latent":
+ scaled = latents / self.vae.config.scaling_factor
+ if latent_mask is not None:
+ # scaled = latents / self.vae.config.scaling_factor * latent_mask + init_latents * (1 - latent_mask)
+ scaled = torch.lerp(init_latents, scaled, latent_mask)
+ image = self.vae.decode(scaled, return_dict=False)[0]
+ if self.debug_save:
+ image_gen = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image_gen = self.image_processor.postprocess(image_gen, output_type=output_type, do_denormalize=[True])
+ image_gen[0].save("from_latent.png")
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+ def _make_latent_mask(self, latents, mask):
+ if mask is not None:
+ latent_mask = []
+ if not isinstance(mask, list):
+ tmp_mask = [mask]
+ else:
+ tmp_mask = mask
+ _, l_channels, l_height, l_width = latents.shape
+ for m in tmp_mask:
+ if not isinstance(m, PIL.Image.Image):
+ if len(m.shape) == 2:
+ m = m[..., np.newaxis]
+ if m.max() > 1:
+ m = m / 255.0
+ m = self.image_processor.numpy_to_pil(m)[0]
+ if m.mode != "L":
+ m = m.convert("L")
+ resized = self.image_processor.resize(m, l_height, l_width)
+ if self.debug_save:
+ resized.save("latent_mask.png")
+ latent_mask.append(np.repeat(np.array(resized)[np.newaxis, :, :], l_channels, axis=0))
+ latent_mask = torch.as_tensor(np.stack(latent_mask)).to(latents)
+ latent_mask = latent_mask / latent_mask.max()
+ return latent_mask
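Below is a minimal usage sketch for `MaskedStableDiffusionImg2ImgPipeline`. The checkpoint id, file names and the `custom_pipeline` name are illustrative assumptions; the key point is that non-zero mask pixels mark the area to repaint while zero pixels keep the original content (see `_make_latent_mask` above).

```python
# Usage sketch only: checkpoint id, custom_pipeline name and file paths are assumptions.
import numpy as np
from PIL import Image
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    custom_pipeline="masked_stable_diffusion_img2img",  # assumed community-pipeline name
).to("cuda")

init_image = Image.open("photo.png").convert("RGB").resize((512, 512))
# Grayscale mask: white (255) = repaint, black (0) = keep the original content.
mask = np.array(Image.open("mask.png").convert("L").resize((512, 512)))

result = pipe(
    prompt="a red sports car",
    image=init_image,
    mask=mask,
    strength=0.75,
    num_inference_steps=50,
    guidance_scale=7.5,
).images[0]
result.save("masked_img2img.png")
```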
diff --git a/diffusers/examples/community/mixture_canvas.py b/diffusers/examples/community/mixture_canvas.py
new file mode 100644
index 0000000000000000000000000000000000000000..3737183e5513dee9028335f5e10dc5b7c39ce088
--- /dev/null
+++ b/diffusers/examples/community/mixture_canvas.py
@@ -0,0 +1,501 @@
+import re
+from copy import deepcopy
+from dataclasses import asdict, dataclass
+from enum import Enum
+from typing import List, Optional, Union
+
+import numpy as np
+import torch
+from numpy import exp, pi, sqrt
+from torchvision.transforms.functional import resize
+from tqdm.auto import tqdm
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
+from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+
+
+def preprocess_image(image):
+ from PIL import Image
+
+ """Preprocess an input image
+
+ Same as
+ https://github.com/huggingface/diffusers/blob/1138d63b519e37f0ce04e027b9f4a3261d27c628/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py#L44
+ """
+ w, h = image.size
+ w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32
+ image = image.resize((w, h), resample=Image.LANCZOS)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ return 2.0 * image - 1.0
+
+
+@dataclass
+class CanvasRegion:
+ """Class defining a rectangular region in the canvas"""
+
+ row_init: int # Region starting row in pixel space (included)
+ row_end: int # Region end row in pixel space (not included)
+ col_init: int # Region starting column in pixel space (included)
+ col_end: int # Region end column in pixel space (not included)
+ region_seed: int = None # Seed for random operations in this region
+ noise_eps: float = 0.0 # Deviation of a zero-mean gaussian noise to be applied over the latents in this region. Useful for slightly "rerolling" latents
+
+ def __post_init__(self):
+ # Initialize arguments if not specified
+ if self.region_seed is None:
+ self.region_seed = np.random.randint(9999999999)
+ # Check coordinates are non-negative
+ for coord in [self.row_init, self.row_end, self.col_init, self.col_end]:
+ if coord < 0:
+ raise ValueError(
+ f"A CanvasRegion must be defined with non-negative indices, found ({self.row_init}, {self.row_end}, {self.col_init}, {self.col_end})"
+ )
+ # Check coordinates are divisible by 8, else we end up with nasty rounding error when mapping to latent space
+ for coord in [self.row_init, self.row_end, self.col_init, self.col_end]:
+ if coord // 8 != coord / 8:
+ raise ValueError(
+ f"A CanvasRegion must be defined with locations divisible by 8, found ({self.row_init}-{self.row_end}, {self.col_init}-{self.col_end})"
+ )
+ # Check noise eps is non-negative
+ if self.noise_eps < 0:
+ raise ValueError(f"A CanvasRegion must be defined noises eps non-negative, found {self.noise_eps}")
+ # Compute coordinates for this region in latent space
+ self.latent_row_init = self.row_init // 8
+ self.latent_row_end = self.row_end // 8
+ self.latent_col_init = self.col_init // 8
+ self.latent_col_end = self.col_end // 8
+
+ @property
+ def width(self):
+ return self.col_end - self.col_init
+
+ @property
+ def height(self):
+ return self.row_end - self.row_init
+
+ def get_region_generator(self, device="cpu"):
+ """Creates a torch.Generator based on the random seed of this region"""
+ # Initialize region generator
+ return torch.Generator(device).manual_seed(self.region_seed)
+
+ @property
+ def __dict__(self):
+ return asdict(self)
+
+
+class MaskModes(Enum):
+ """Modes in which the influence of diffuser is masked"""
+
+ CONSTANT = "constant"
+ GAUSSIAN = "gaussian"
+ QUARTIC = "quartic" # See https://en.wikipedia.org/wiki/Kernel_(statistics)
+
+
+@dataclass
+class DiffusionRegion(CanvasRegion):
+ """Abstract class defining a region where some class of diffusion process is acting"""
+
+ pass
+
+
+@dataclass
+class Text2ImageRegion(DiffusionRegion):
+ """Class defining a region where a text guided diffusion process is acting"""
+
+ prompt: str = "" # Text prompt guiding the diffuser in this region
+ guidance_scale: float = 7.5 # Guidance scale of the diffuser in this region. If None, randomize
+ mask_type: MaskModes = MaskModes.GAUSSIAN.value # Kind of weight mask applied to this region
+ mask_weight: float = 1.0 # Global weights multiplier of the mask
+ tokenized_prompt = None # Tokenized prompt
+ encoded_prompt = None # Encoded prompt
+
+ def __post_init__(self):
+ super().__post_init__()
+ # Mask weight cannot be negative
+ if self.mask_weight < 0:
+ raise ValueError(
+ f"A Text2ImageRegion must be defined with non-negative mask weight, found {self.mask_weight}"
+ )
+ # Mask type must be an actual known mask
+ if self.mask_type not in [e.value for e in MaskModes]:
+ raise ValueError(
+ f"A Text2ImageRegion was defined with mask {self.mask_type}, which is not an accepted mask ({[e.value for e in MaskModes]})"
+ )
+ # Randomize arguments if given as None
+ if self.guidance_scale is None:
+ self.guidance_scale = np.random.randint(5, 30)
+ # Clean prompt
+ self.prompt = re.sub(" +", " ", self.prompt).replace("\n", " ")
+
+ def tokenize_prompt(self, tokenizer):
+ """Tokenizes the prompt for this diffusion region using a given tokenizer"""
+ self.tokenized_prompt = tokenizer(
+ self.prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ def encode_prompt(self, text_encoder, device):
+ """Encodes the previously tokenized prompt for this diffusion region using a given encoder"""
+ assert self.tokenized_prompt is not None, ValueError(
+ "Prompt in diffusion region must be tokenized before encoding"
+ )
+ self.encoded_prompt = text_encoder(self.tokenized_prompt.input_ids.to(device))[0]
+
+
+@dataclass
+class Image2ImageRegion(DiffusionRegion):
+ """Class defining a region where an image guided diffusion process is acting"""
+
+ reference_image: torch.FloatTensor = None
+ strength: float = 0.8 # Strength of the image
+
+ def __post_init__(self):
+ super().__post_init__()
+ if self.reference_image is None:
+ raise ValueError("Must provide a reference image when creating an Image2ImageRegion")
+ if self.strength < 0 or self.strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {self.strength}")
+ # Rescale image to region shape
+ self.reference_image = resize(self.reference_image, size=[self.height, self.width])
+
+ def encode_reference_image(self, encoder, device, generator, cpu_vae=False):
+ """Encodes the reference image for this Image2Image region into the latent space"""
+ # Place encoder in CPU or not following the parameter cpu_vae
+ if cpu_vae:
+ # Note here we use mean instead of sample, to avoid moving also generator to CPU, which is troublesome
+ self.reference_latents = encoder.cpu().encode(self.reference_image).latent_dist.mean.to(device)
+ else:
+ self.reference_latents = encoder.encode(self.reference_image.to(device)).latent_dist.sample(
+ generator=generator
+ )
+ self.reference_latents = 0.18215 * self.reference_latents
+
+ @property
+ def __dict__(self):
+ # This class requires special casting to dict because of the reference_image tensor. Otherwise it cannot be cast to JSON
+
+ # Get all basic fields from parent class
+ super_fields = {key: getattr(self, key) for key in DiffusionRegion.__dataclass_fields__.keys()}
+ # Pack other fields
+ return {**super_fields, "reference_image": self.reference_image.cpu().tolist(), "strength": self.strength}
+
+
+class RerollModes(Enum):
+ """Modes in which the reroll regions operate"""
+
+ RESET = "reset" # Completely reset the random noise in the region
+ EPSILON = "epsilon" # Alter slightly the latents in the region
+
+
+@dataclass
+class RerollRegion(CanvasRegion):
+ """Class defining a rectangular canvas region in which initial latent noise will be rerolled"""
+
+ reroll_mode: RerollModes = RerollModes.RESET.value
+
+
+@dataclass
+class MaskWeightsBuilder:
+ """Auxiliary class to compute a tensor of weights for a given diffusion region"""
+
+ latent_space_dim: int # Size of the U-net latent space
+ nbatch: int = 1 # Batch size in the U-net
+
+ def compute_mask_weights(self, region: DiffusionRegion) -> torch.tensor:
+ """Computes a tensor of weights for a given diffusion region"""
+ MASK_BUILDERS = {
+ MaskModes.CONSTANT.value: self._constant_weights,
+ MaskModes.GAUSSIAN.value: self._gaussian_weights,
+ MaskModes.QUARTIC.value: self._quartic_weights,
+ }
+ return MASK_BUILDERS[region.mask_type](region)
+
+ def _constant_weights(self, region: DiffusionRegion) -> torch.tensor:
+ """Computes a tensor of constant for a given diffusion region"""
+ latent_width = region.latent_col_end - region.latent_col_init
+ latent_height = region.latent_row_end - region.latent_row_init
+ return torch.ones(self.nbatch, self.latent_space_dim, latent_height, latent_width) * region.mask_weight
+
+ def _gaussian_weights(self, region: DiffusionRegion) -> torch.tensor:
+ """Generates a gaussian mask of weights for tile contributions"""
+ latent_width = region.latent_col_end - region.latent_col_init
+ latent_height = region.latent_row_end - region.latent_row_init
+
+ var = 0.01
+ midpoint = (latent_width - 1) / 2 # -1 because index goes from 0 to latent_width - 1
+ x_probs = [
+ exp(-(x - midpoint) * (x - midpoint) / (latent_width * latent_width) / (2 * var)) / sqrt(2 * pi * var)
+ for x in range(latent_width)
+ ]
+ midpoint = (latent_height - 1) / 2
+ y_probs = [
+ exp(-(y - midpoint) * (y - midpoint) / (latent_height * latent_height) / (2 * var)) / sqrt(2 * pi * var)
+ for y in range(latent_height)
+ ]
+
+ weights = np.outer(y_probs, x_probs) * region.mask_weight
+ return torch.tile(torch.tensor(weights), (self.nbatch, self.latent_space_dim, 1, 1))
+
+ def _quartic_weights(self, region: DiffusionRegion) -> torch.tensor:
+ """Generates a quartic mask of weights for tile contributions
+
+ The quartic kernel has bounded support over the diffusion region, and a smooth decay to the region limits.
+ """
+ quartic_constant = 15.0 / 16.0
+
+ support = (np.array(range(region.latent_col_init, region.latent_col_end)) - region.latent_col_init) / (
+ region.latent_col_end - region.latent_col_init - 1
+ ) * 1.99 - (1.99 / 2.0)
+ x_probs = quartic_constant * np.square(1 - np.square(support))
+ support = (np.array(range(region.latent_row_init, region.latent_row_end)) - region.latent_row_init) / (
+ region.latent_row_end - region.latent_row_init - 1
+ ) * 1.99 - (1.99 / 2.0)
+ y_probs = quartic_constant * np.square(1 - np.square(support))
+
+ weights = np.outer(y_probs, x_probs) * region.mask_weight
+ return torch.tile(torch.tensor(weights), (self.nbatch, self.latent_space_dim, 1, 1))
+
+
+class StableDiffusionCanvasPipeline(DiffusionPipeline):
+ """Stable Diffusion pipeline that mixes several diffusers in the same canvas"""
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPFeatureExtractor,
+ ):
+ super().__init__()
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+
+ def decode_latents(self, latents, cpu_vae=False):
+ """Decodes a given array of latents into pixel space"""
+ # scale and decode the image latents with vae
+ if cpu_vae:
+ lat = deepcopy(latents).cpu()
+ vae = deepcopy(self.vae).cpu()
+ else:
+ lat = latents
+ vae = self.vae
+
+ lat = 1 / 0.18215 * lat
+ image = vae.decode(lat).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+ return self.numpy_to_pil(image)
+
+ def get_latest_timestep_img2img(self, num_inference_steps, strength):
+ """Finds the latest timesteps where an img2img strength does not impose latents anymore"""
+ # get the original timestep using init_timestep
+ offset = self.scheduler.config.get("steps_offset", 0)
+ init_timestep = int(num_inference_steps * (1 - strength)) + offset
+ init_timestep = min(init_timestep, num_inference_steps)
+
+ t_start = min(max(num_inference_steps - init_timestep + offset, 0), num_inference_steps - 1)
+ latest_timestep = self.scheduler.timesteps[t_start]
+
+ return latest_timestep
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ canvas_height: int,
+ canvas_width: int,
+ regions: List[DiffusionRegion],
+ num_inference_steps: Optional[int] = 50,
+ seed: Optional[int] = 12345,
+ reroll_regions: Optional[List[RerollRegion]] = None,
+ cpu_vae: Optional[bool] = False,
+ decode_steps: Optional[bool] = False,
+ ):
+ if reroll_regions is None:
+ reroll_regions = []
+ batch_size = 1
+
+ if decode_steps:
+ steps_images = []
+
+ # Prepare scheduler
+ self.scheduler.set_timesteps(num_inference_steps, device=self.device)
+
+ # Split diffusion regions by their kind
+ text2image_regions = [region for region in regions if isinstance(region, Text2ImageRegion)]
+ image2image_regions = [region for region in regions if isinstance(region, Image2ImageRegion)]
+
+ # Prepare text embeddings
+ for region in text2image_regions:
+ region.tokenize_prompt(self.tokenizer)
+ region.encode_prompt(self.text_encoder, self.device)
+
+ # Create original noisy latents using the timesteps
+ latents_shape = (batch_size, self.unet.config.in_channels, canvas_height // 8, canvas_width // 8)
+ generator = torch.Generator(self.device).manual_seed(seed)
+ init_noise = torch.randn(latents_shape, generator=generator, device=self.device)
+
+ # Reset latents in seed reroll regions, if requested
+ for region in reroll_regions:
+ if region.reroll_mode == RerollModes.RESET.value:
+ region_shape = (
+ latents_shape[0],
+ latents_shape[1],
+ region.latent_row_end - region.latent_row_init,
+ region.latent_col_end - region.latent_col_init,
+ )
+ init_noise[
+ :,
+ :,
+ region.latent_row_init : region.latent_row_end,
+ region.latent_col_init : region.latent_col_end,
+ ] = torch.randn(region_shape, generator=region.get_region_generator(self.device), device=self.device)
+
+ # Apply epsilon noise to regions: first diffusion regions, then reroll regions
+ all_eps_rerolls = regions + [r for r in reroll_regions if r.reroll_mode == RerollModes.EPSILON.value]
+ for region in all_eps_rerolls:
+ if region.noise_eps > 0:
+ region_noise = init_noise[
+ :,
+ :,
+ region.latent_row_init : region.latent_row_end,
+ region.latent_col_init : region.latent_col_end,
+ ]
+ eps_noise = (
+ torch.randn(
+ region_noise.shape, generator=region.get_region_generator(self.device), device=self.device
+ )
+ * region.noise_eps
+ )
+ init_noise[
+ :,
+ :,
+ region.latent_row_init : region.latent_row_end,
+ region.latent_col_init : region.latent_col_end,
+ ] += eps_noise
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = init_noise * self.scheduler.init_noise_sigma
+
+ # Get unconditional embeddings for classifier free guidance in text2image regions
+ for region in text2image_regions:
+ max_length = region.tokenized_prompt.input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+ )
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ region.encoded_prompt = torch.cat([uncond_embeddings, region.encoded_prompt])
+
+ # Prepare image latents
+ for region in image2image_regions:
+ region.encode_reference_image(self.vae, device=self.device, generator=generator)
+
+ # Prepare mask of weights for each region
+ mask_builder = MaskWeightsBuilder(latent_space_dim=self.unet.config.in_channels, nbatch=batch_size)
+ mask_weights = [mask_builder.compute_mask_weights(region).to(self.device) for region in text2image_regions]
+
+ # Diffusion timesteps
+ for i, t in tqdm(enumerate(self.scheduler.timesteps)):
+ # Diffuse each region
+ noise_preds_regions = []
+
+ # text2image regions
+ for region in text2image_regions:
+ region_latents = latents[
+ :,
+ :,
+ region.latent_row_init : region.latent_row_end,
+ region.latent_col_init : region.latent_col_end,
+ ]
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([region_latents] * 2)
+ # scale model input following scheduler rules
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=region.encoded_prompt)["sample"]
+ # perform guidance
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred_region = noise_pred_uncond + region.guidance_scale * (noise_pred_text - noise_pred_uncond)
+ noise_preds_regions.append(noise_pred_region)
+
+ # Merge noise predictions for all tiles
+ noise_pred = torch.zeros(latents.shape, device=self.device)
+ contributors = torch.zeros(latents.shape, device=self.device)
+ # Add each tile contribution to overall latents
+ for region, noise_pred_region, mask_weights_region in zip(
+ text2image_regions, noise_preds_regions, mask_weights
+ ):
+ noise_pred[
+ :,
+ :,
+ region.latent_row_init : region.latent_row_end,
+ region.latent_col_init : region.latent_col_end,
+ ] += noise_pred_region * mask_weights_region
+ contributors[
+ :,
+ :,
+ region.latent_row_init : region.latent_row_end,
+ region.latent_col_init : region.latent_col_end,
+ ] += mask_weights_region
+ # Average overlapping areas with more than 1 contributor
+ noise_pred /= contributors
+ noise_pred = torch.nan_to_num(
+ noise_pred
+ ) # Replace NaNs by zeros: NaN can appear if a position is not covered by any DiffusionRegion
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents).prev_sample
+
+ # Image2Image regions: override latents generated by the scheduler
+ for region in image2image_regions:
+ influence_step = self.get_latest_timestep_img2img(num_inference_steps, region.strength)
+ # Only override in the timesteps before the last influence step of the image (given by its strength)
+ if t > influence_step:
+ timestep = t.repeat(batch_size)
+ region_init_noise = init_noise[
+ :,
+ :,
+ region.latent_row_init : region.latent_row_end,
+ region.latent_col_init : region.latent_col_end,
+ ]
+ region_latents = self.scheduler.add_noise(region.reference_latents, region_init_noise, timestep)
+ latents[
+ :,
+ :,
+ region.latent_row_init : region.latent_row_end,
+ region.latent_col_init : region.latent_col_end,
+ ] = region_latents
+
+ if decode_steps:
+ steps_images.append(self.decode_latents(latents, cpu_vae))
+
+ # scale and decode the image latents with vae
+ image = self.decode_latents(latents, cpu_vae)
+
+ output = {"images": image}
+ if decode_steps:
+ output = {**output, "steps_images": steps_images}
+ return output
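A minimal usage sketch for `StableDiffusionCanvasPipeline` follows. It assumes this file is importable locally as `mixture_canvas.py` and that the checkpoint id is available; region coordinates must be multiples of 8 so they map cleanly onto the latent grid, and overlapping regions are averaged using the Gaussian masks built above.

```python
# Usage sketch only: local module name and checkpoint id are assumptions.
from diffusers import LMSDiscreteScheduler
from mixture_canvas import StableDiffusionCanvasPipeline, Text2ImageRegion

scheduler = LMSDiscreteScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
)
pipe = StableDiffusionCanvasPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", scheduler=scheduler
).to("cuda")

# Two text-guided regions sharing a 256-pixel horizontal overlap on a 640x1408 canvas.
output = pipe(
    canvas_height=640,
    canvas_width=1408,
    regions=[
        Text2ImageRegion(0, 640, 0, 832, guidance_scale=8, prompt="a mountain lake at sunrise, highly detailed"),
        Text2ImageRegion(0, 640, 576, 1408, guidance_scale=8, prompt="a pine forest in morning mist, highly detailed"),
    ],
    num_inference_steps=50,
    seed=1234,
)
output["images"][0].save("canvas.png")
```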
diff --git a/diffusers/examples/community/mixture_tiling.py b/diffusers/examples/community/mixture_tiling.py
new file mode 100644
index 0000000000000000000000000000000000000000..f92ae0e1d35934d90d59af470a3bae588bddb869
--- /dev/null
+++ b/diffusers/examples/community/mixture_tiling.py
@@ -0,0 +1,405 @@
+import inspect
+from copy import deepcopy
+from enum import Enum
+from typing import List, Optional, Tuple, Union
+
+import torch
+from tqdm.auto import tqdm
+
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
+from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from diffusers.utils import logging
+
+
+try:
+ from ligo.segments import segment
+ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+except ImportError:
+ raise ImportError("Please install transformers and ligo-segments to use the mixture pipeline")
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> from diffusers import LMSDiscreteScheduler, DiffusionPipeline
+
+ >>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
+ >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=scheduler, custom_pipeline="mixture_tiling")
+ >>> pipeline.to("cuda")
+
+ >>> image = pipeline(
+ >>> prompt=[[
+ >>> "A charming house in the countryside, by jakub rozalski, sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece",
+ >>> "A dirt road in the countryside crossing pastures, by jakub rozalski, sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece",
+ >>> "An old and rusty giant robot lying on a dirt road, by jakub rozalski, dark sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece"
+ >>> ]],
+ >>> tile_height=640,
+ >>> tile_width=640,
+ >>> tile_row_overlap=0,
+ >>> tile_col_overlap=256,
+ >>> guidance_scale=8,
+ >>> seed=7178915308,
+ >>> num_inference_steps=50,
+ >>> )["images"][0]
+ ```
+"""
+
+
+def _tile2pixel_indices(tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap):
+ """Given a tile row and column numbers returns the range of pixels affected by that tiles in the overall image
+
+ Returns a tuple with:
+ - Starting coordinates of rows in pixel space
+ - Ending coordinates of rows in pixel space
+ - Starting coordinates of columns in pixel space
+ - Ending coordinates of columns in pixel space
+ """
+ px_row_init = 0 if tile_row == 0 else tile_row * (tile_height - tile_row_overlap)
+ px_row_end = px_row_init + tile_height
+ px_col_init = 0 if tile_col == 0 else tile_col * (tile_width - tile_col_overlap)
+ px_col_end = px_col_init + tile_width
+ return px_row_init, px_row_end, px_col_init, px_col_end
+
+
+def _pixel2latent_indices(px_row_init, px_row_end, px_col_init, px_col_end):
+ """Translates coordinates in pixel space to coordinates in latent space"""
+ return px_row_init // 8, px_row_end // 8, px_col_init // 8, px_col_end // 8
+
+
+def _tile2latent_indices(tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap):
+ """Given a tile row and column numbers returns the range of latents affected by that tiles in the overall image
+
+ Returns a tuple with:
+ - Starting coordinates of rows in latent space
+ - Ending coordinates of rows in latent space
+ - Starting coordinates of columns in latent space
+ - Ending coordinates of columns in latent space
+ """
+ px_row_init, px_row_end, px_col_init, px_col_end = _tile2pixel_indices(
+ tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap
+ )
+ return _pixel2latent_indices(px_row_init, px_row_end, px_col_init, px_col_end)
+
+
+def _tile2latent_exclusive_indices(
+ tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap, rows, columns
+):
+ """Given a tile row and column numbers returns the range of latents affected only by that tile in the overall image
+
+ Returns a tuple with:
+ - Starting coordinates of rows in latent space
+ - Ending coordinates of rows in latent space
+ - Starting coordinates of columns in latent space
+ - Ending coordinates of columns in latent space
+ """
+ row_init, row_end, col_init, col_end = _tile2latent_indices(
+ tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap
+ )
+ row_segment = segment(row_init, row_end)
+ col_segment = segment(col_init, col_end)
+ # Iterate over the rest of tiles, clipping the region for the current tile
+ for row in range(rows):
+ for column in range(columns):
+ if row != tile_row and column != tile_col:
+ clip_row_init, clip_row_end, clip_col_init, clip_col_end = _tile2latent_indices(
+ row, column, tile_width, tile_height, tile_row_overlap, tile_col_overlap
+ )
+ row_segment = row_segment - segment(clip_row_init, clip_row_end)
+ col_segment = col_segment - segment(clip_col_init, clip_col_end)
+ # return row_init, row_end, col_init, col_end
+ return row_segment[0], row_segment[1], col_segment[0], col_segment[1]
+
+
+class StableDiffusionExtrasMixin:
+ """Mixin providing additional convenience method to Stable Diffusion pipelines"""
+
+ def decode_latents(self, latents, cpu_vae=False):
+ """Decodes a given array of latents into pixel space"""
+ # scale and decode the image latents with vae
+ if cpu_vae:
+ lat = deepcopy(latents).cpu()
+ vae = deepcopy(self.vae).cpu()
+ else:
+ lat = latents
+ vae = self.vae
+
+ lat = 1 / 0.18215 * lat
+ image = vae.decode(lat).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+ return self.numpy_to_pil(image)
+
+
+class StableDiffusionTilingPipeline(DiffusionPipeline, StableDiffusionExtrasMixin):
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPFeatureExtractor,
+ ):
+ super().__init__()
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+
+ class SeedTilesMode(Enum):
+ """Modes in which the latents of a particular tile can be re-seeded"""
+
+ FULL = "full"
+ EXCLUSIVE = "exclusive"
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[List[str]]],
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ eta: Optional[float] = 0.0,
+ seed: Optional[int] = None,
+ tile_height: Optional[int] = 512,
+ tile_width: Optional[int] = 512,
+ tile_row_overlap: Optional[int] = 256,
+ tile_col_overlap: Optional[int] = 256,
+ guidance_scale_tiles: Optional[List[List[float]]] = None,
+ seed_tiles: Optional[List[List[int]]] = None,
+ seed_tiles_mode: Optional[Union[str, List[List[str]]]] = "full",
+ seed_reroll_regions: Optional[List[Tuple[int, int, int, int, int]]] = None,
+ cpu_vae: Optional[bool] = False,
+ ):
+ r"""
+ Function to run the diffusion pipeline with tiling support.
+
+ Args:
+ prompt: either a single string (no tiling) or a list of lists with all the prompts to use (one list for each row of tiles). This will also define the tiling structure.
+ num_inference_steps: number of diffusion steps.
+ guidance_scale: classifier-free guidance scale.
+ seed: general random seed to initialize latents.
+ tile_height: height in pixels of each grid tile.
+ tile_width: width in pixels of each grid tile.
+ tile_row_overlap: number of overlap pixels between tiles in consecutive rows.
+ tile_col_overlap: number of overlap pixels between tiles in consecutive columns.
+ guidance_scale_tiles: specific weights for classifier-free guidance in each tile. If None, the value provided in guidance_scale will be used.
+ seed_tiles: specific seeds for the initialization latents in each tile. These will override the latents generated for the whole canvas using the standard seed parameter.
+ seed_tiles_mode: either "full" or "exclusive". If "full", all the latents affected by the tile will be overridden. If "exclusive", only the latents that are affected exclusively by this tile (and no other tiles) will be overridden.
+ seed_reroll_regions: a list of tuples in the form (start row, end row, start column, end column, seed) defining regions in pixel space for which the latents will be overridden using the given seed. Takes priority over seed_tiles.
+ cpu_vae: the decoder from latent space to pixel space can require too much GPU RAM for large images. If you run into out-of-memory errors at the end of the generation process, try setting this parameter to True to run the decoder on the CPU. Slower, but it should run without memory issues.
+
+ Examples:
+
+ Returns:
+            A dictionary with an `images` entry containing the list of generated PIL images.
+
+ """
+ if not isinstance(prompt, list) or not all(isinstance(row, list) for row in prompt):
+ raise ValueError(f"`prompt` has to be a list of lists but is {type(prompt)}")
+ grid_rows = len(prompt)
+ grid_cols = len(prompt[0])
+ if not all(len(row) == grid_cols for row in prompt):
+ raise ValueError("All prompt rows must have the same number of prompt columns")
+ if not isinstance(seed_tiles_mode, str) and (
+ not isinstance(seed_tiles_mode, list) or not all(isinstance(row, list) for row in seed_tiles_mode)
+ ):
+            raise ValueError(f"`seed_tiles_mode` has to be a string or list of lists but is {type(seed_tiles_mode)}")
+ if isinstance(seed_tiles_mode, str):
+ seed_tiles_mode = [[seed_tiles_mode for _ in range(len(row))] for row in prompt]
+
+ modes = [mode.value for mode in self.SeedTilesMode]
+ if any(mode not in modes for row in seed_tiles_mode for mode in row):
+ raise ValueError(f"Seed tiles mode must be one of {modes}")
+ if seed_reroll_regions is None:
+ seed_reroll_regions = []
+ batch_size = 1
+
+ # create original noisy latents using the timesteps
+ height = tile_height + (grid_rows - 1) * (tile_height - tile_row_overlap)
+ width = tile_width + (grid_cols - 1) * (tile_width - tile_col_overlap)
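+        # e.g. a single row of three 512x512 tiles with 256 px of column overlap gives a 512x1024 canvas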
+ latents_shape = (batch_size, self.unet.config.in_channels, height // 8, width // 8)
+ generator = torch.Generator("cuda").manual_seed(seed)
+ latents = torch.randn(latents_shape, generator=generator, device=self.device)
+
+ # overwrite latents for specific tiles if provided
+ if seed_tiles is not None:
+ for row in range(grid_rows):
+ for col in range(grid_cols):
+ if (seed_tile := seed_tiles[row][col]) is not None:
+ mode = seed_tiles_mode[row][col]
+ if mode == self.SeedTilesMode.FULL.value:
+ row_init, row_end, col_init, col_end = _tile2latent_indices(
+ row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap
+ )
+ else:
+ row_init, row_end, col_init, col_end = _tile2latent_exclusive_indices(
+ row,
+ col,
+ tile_width,
+ tile_height,
+ tile_row_overlap,
+ tile_col_overlap,
+ grid_rows,
+ grid_cols,
+ )
+ tile_generator = torch.Generator("cuda").manual_seed(seed_tile)
+ tile_shape = (latents_shape[0], latents_shape[1], row_end - row_init, col_end - col_init)
+ latents[:, :, row_init:row_end, col_init:col_end] = torch.randn(
+ tile_shape, generator=tile_generator, device=self.device
+ )
+
+ # overwrite again for seed reroll regions
+ for row_init, row_end, col_init, col_end, seed_reroll in seed_reroll_regions:
+ row_init, row_end, col_init, col_end = _pixel2latent_indices(
+ row_init, row_end, col_init, col_end
+ ) # to latent space coordinates
+ reroll_generator = torch.Generator("cuda").manual_seed(seed_reroll)
+ region_shape = (latents_shape[0], latents_shape[1], row_end - row_init, col_end - col_init)
+ latents[:, :, row_init:row_end, col_init:col_end] = torch.randn(
+ region_shape, generator=reroll_generator, device=self.device
+ )
+
+ # Prepare scheduler
+ accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
+ extra_set_kwargs = {}
+ if accepts_offset:
+ extra_set_kwargs["offset"] = 1
+ self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
+ # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ latents = latents * self.scheduler.sigmas[0]
+
+ # get prompts text embeddings
+ text_input = [
+ [
+ self.tokenizer(
+ col,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ for col in row
+ ]
+ for row in prompt
+ ]
+ text_embeddings = [[self.text_encoder(col.input_ids.to(self.device))[0] for col in row] for row in text_input]
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0 # TODO: also active if any tile has guidance scale
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ for i in range(grid_rows):
+ for j in range(grid_cols):
+ max_length = text_input[i][j].input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+ )
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings[i][j] = torch.cat([uncond_embeddings, text_embeddings[i][j]])
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+        # Mask of weights for each tile's contribution strength
+ tile_weights = self._gaussian_weights(tile_width, tile_height, batch_size)
+
+ # Diffusion timesteps
+ for i, t in tqdm(enumerate(self.scheduler.timesteps)):
+ # Diffuse each tile
+ noise_preds = []
+ for row in range(grid_rows):
+ noise_preds_row = []
+ for col in range(grid_cols):
+ px_row_init, px_row_end, px_col_init, px_col_end = _tile2latent_indices(
+ row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap
+ )
+ tile_latents = latents[:, :, px_row_init:px_row_end, px_col_init:px_col_end]
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([tile_latents] * 2) if do_classifier_free_guidance else tile_latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings[row][col])[
+ "sample"
+ ]
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ guidance = (
+ guidance_scale
+ if guidance_scale_tiles is None or guidance_scale_tiles[row][col] is None
+ else guidance_scale_tiles[row][col]
+ )
+ noise_pred_tile = noise_pred_uncond + guidance * (noise_pred_text - noise_pred_uncond)
+ noise_preds_row.append(noise_pred_tile)
+ noise_preds.append(noise_preds_row)
+ # Stitch noise predictions for all tiles
+ noise_pred = torch.zeros(latents.shape, device=self.device)
+ contributors = torch.zeros(latents.shape, device=self.device)
+ # Add each tile contribution to overall latents
+ for row in range(grid_rows):
+ for col in range(grid_cols):
+ px_row_init, px_row_end, px_col_init, px_col_end = _tile2latent_indices(
+ row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap
+ )
+ noise_pred[:, :, px_row_init:px_row_end, px_col_init:px_col_end] += (
+ noise_preds[row][col] * tile_weights
+ )
+ contributors[:, :, px_row_init:px_row_end, px_col_init:px_col_end] += tile_weights
+ # Average overlapping areas with more than 1 contributor
+ noise_pred /= contributors
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents).prev_sample
+
+ # scale and decode the image latents with vae
+ image = self.decode_latents(latents, cpu_vae)
+
+ return {"images": image}
+
+ def _gaussian_weights(self, tile_width, tile_height, nbatches):
+ """Generates a gaussian mask of weights for tile contributions"""
+ import numpy as np
+ from numpy import exp, pi, sqrt
+
+ latent_width = tile_width // 8
+ latent_height = tile_height // 8
+
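+        # Gaussian over normalized tile coordinates: weights peak at the tile centre and decay
+        # towards the borders, so overlapping tiles blend smoothly when stitched together.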
+ var = 0.01
+ midpoint = (latent_width - 1) / 2 # -1 because index goes from 0 to latent_width - 1
+ x_probs = [
+ exp(-(x - midpoint) * (x - midpoint) / (latent_width * latent_width) / (2 * var)) / sqrt(2 * pi * var)
+ for x in range(latent_width)
+ ]
+        midpoint = (latent_height - 1) / 2  # -1 because index goes from 0 to latent_height - 1
+ y_probs = [
+ exp(-(y - midpoint) * (y - midpoint) / (latent_height * latent_height) / (2 * var)) / sqrt(2 * pi * var)
+ for y in range(latent_height)
+ ]
+
+ weights = np.outer(y_probs, x_probs)
+ return torch.tile(torch.tensor(weights, device=self.device), (nbatches, self.unet.config.in_channels, 1, 1))
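
A minimal usage sketch for the tiling pipeline defined above. The checkpoint and the `custom_pipeline` identifier are assumptions (point them at however this file is exposed in the repo); the call follows the `__call__` signature above, with one inner prompt list per row of tiles:

```python
from diffusers import DiffusionPipeline

# Assumption: the tiling pipeline above is loadable as a custom pipeline under the
# name "mixture_tiling"; replace with the actual identifier or a local path.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    custom_pipeline="mixture_tiling",
).to("cuda")

# One row of three 512x512 tiles with 256 px of column overlap -> a 512x1024 canvas.
result = pipe(
    prompt=[["a lake shore", "a mountain range", "a dense forest"]],
    seed=7,
    tile_height=512,
    tile_width=512,
    tile_row_overlap=256,
    tile_col_overlap=256,
    guidance_scale=7.5,
    num_inference_steps=50,
)
result["images"][0].save("tiled_landscape.png")
```
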
diff --git a/diffusers/examples/community/multilingual_stable_diffusion.py b/diffusers/examples/community/multilingual_stable_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..7597efd215afb82a04443f8ee1f4f5f7b4e04d77
--- /dev/null
+++ b/diffusers/examples/community/multilingual_stable_diffusion.py
@@ -0,0 +1,437 @@
+import inspect
+from typing import Callable, List, Optional, Union
+
+import torch
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTokenizer,
+ MBart50TokenizerFast,
+ MBartForConditionalGeneration,
+ pipeline,
+)
+
+from diffusers import DiffusionPipeline
+from diffusers.configuration_utils import FrozenDict
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from diffusers.utils import deprecate, logging
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def detect_language(pipe, prompt, batch_size):
+ """helper function to detect language(s) of prompt"""
+
+ if batch_size == 1:
+ preds = pipe(prompt, top_k=1, truncation=True, max_length=128)
+ return preds[0]["label"]
+ else:
+ detected_languages = []
+ for p in prompt:
+ preds = pipe(p, top_k=1, truncation=True, max_length=128)
+ detected_languages.append(preds[0]["label"])
+
+ return detected_languages
+
+
+def translate_prompt(prompt, translation_tokenizer, translation_model, device):
+ """helper function to translate prompt to English"""
+
+ encoded_prompt = translation_tokenizer(prompt, return_tensors="pt").to(device)
+ generated_tokens = translation_model.generate(**encoded_prompt, max_new_tokens=1000)
+ en_trans = translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
+
+ return en_trans[0]
+
+
+class MultilingualStableDiffusion(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion in different languages.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ detection_pipeline ([`pipeline`]):
+ Transformers pipeline to detect prompt's language.
+ translation_model ([`MBartForConditionalGeneration`]):
+ Model to translate prompt to English, if necessary. Please refer to the
+ [model card](https://huggingface.co/docs/transformers/model_doc/mbart) for details.
+ translation_tokenizer ([`MBart50TokenizerFast`]):
+ Tokenizer of the translation model.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ def __init__(
+ self,
+ detection_pipeline: pipeline,
+ translation_model: MBartForConditionalGeneration,
+ translation_tokenizer: MBart50TokenizerFast,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+                "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ self.register_modules(
+ detection_pipeline=detection_pipeline,
+ translation_model=translation_model,
+ translation_tokenizer=translation_tokenizer,
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+ `attention_head_dim` must be a multiple of `slice_size`.
+ """
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = self.unet.config.attention_head_dim // 2
+ self.unet.set_attention_slice(slice_size)
+
+ def disable_attention_slicing(self):
+ r"""
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+ back to computing attention in one step.
+ """
+ # set slice_size = `None` to disable `attention slicing`
+ self.enable_attention_slicing(None)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation. Can be in different languages.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ # detect language and translate if necessary
+ prompt_language = detect_language(self.detection_pipeline, prompt, batch_size)
+ if batch_size == 1 and prompt_language != "en":
+ prompt = translate_prompt(prompt, self.translation_tokenizer, self.translation_model, self.device)
+
+ if isinstance(prompt, list):
+ for index in range(batch_size):
+ if prompt_language[index] != "en":
+ p = translate_prompt(
+ prompt[index], self.translation_tokenizer, self.translation_model, self.device
+ )
+ prompt[index] = p
+
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+ removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+ text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+ text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = text_embeddings.shape
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ # detect language and translate it if necessary
+ negative_prompt_language = detect_language(self.detection_pipeline, negative_prompt, batch_size)
+ if negative_prompt_language != "en":
+ negative_prompt = translate_prompt(
+ negative_prompt, self.translation_tokenizer, self.translation_model, self.device
+ )
+ if isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ # detect language and translate it if necessary
+ if isinstance(negative_prompt, list):
+ negative_prompt_languages = detect_language(self.detection_pipeline, negative_prompt, batch_size)
+ for index in range(batch_size):
+ if negative_prompt_languages[index] != "en":
+ p = translate_prompt(
+ negative_prompt[index], self.translation_tokenizer, self.translation_model, self.device
+ )
+ negative_prompt[index] = p
+ uncond_tokens = negative_prompt
+
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ # get the initial random noise unless the user supplied it
+
+ # Unlike in other pipelines, latents need to be generated in the target device
+ # for 1-to-1 results reproducibility with the CompVis implementation.
+ # However this currently doesn't work in `mps`.
+ latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
+ latents_dtype = text_embeddings.dtype
+ if latents is None:
+ if self.device.type == "mps":
+ # randn does not work reproducibly on mps
+ latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+ self.device
+ )
+ else:
+ latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
+ else:
+ if latents.shape != latents_shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+ latents = latents.to(self.device)
+
+ # set timesteps
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ # Some schedulers like PNDM have timesteps as arrays
+ # It's more optimized to move all timesteps to correct device beforehand
+ timesteps_tensor = self.scheduler.timesteps.to(self.device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
+ self.device
+ )
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
+ )
+ else:
+ has_nsfw_concept = None
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
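
The multilingual pipeline above adds three components on top of a standard Stable Diffusion setup: a language-identification pipeline, an mBART-50 translation model, and its tokenizer. A minimal assembly sketch; the detection/translation checkpoints and the `custom_pipeline` identifier are assumptions (any language-ID model returning ISO codes and any many-to-English mBART-50 checkpoint should do):

```python
import torch
from transformers import MBart50TokenizerFast, MBartForConditionalGeneration, pipeline

from diffusers import DiffusionPipeline

device = "cuda"

# Assumed checkpoints for language detection and translation into English.
detector = pipeline(
    "text-classification", model="papluca/xlm-roberta-base-language-detection", device=0
)
trans_tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-one-mmt")
trans_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-one-mmt").to(device)

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    custom_pipeline="multilingual_stable_diffusion",  # assumed identifier for the file added above
    detection_pipeline=detector,
    translation_model=trans_model,
    translation_tokenizer=trans_tokenizer,
    torch_dtype=torch.float16,
).to(device)

image = pipe("Una casa en la playa", num_inference_steps=30).images[0]
image.save("beach_house.png")
```
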
diff --git a/diffusers/examples/community/one_step_unet.py b/diffusers/examples/community/one_step_unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d34bfd83191d63483bc562cb54cc887660cdffa
--- /dev/null
+++ b/diffusers/examples/community/one_step_unet.py
@@ -0,0 +1,24 @@
+#!/usr/bin/env python3
+import torch
+
+from diffusers import DiffusionPipeline
+
+
+class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
+ def __init__(self, unet, scheduler):
+ super().__init__()
+
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ def __call__(self):
+ image = torch.randn(
+ (1, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),
+ )
+ timestep = 1
+
+ model_output = self.unet(image, timestep).sample
+ scheduler_output = self.scheduler.step(model_output, timestep, image).prev_sample
+
+ result = scheduler_output - scheduler_output + torch.ones_like(scheduler_output)
+
+ return result
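
`one_step_unet.py` is only a smoke test: one UNet forward pass, one scheduler step, and an all-ones tensor as output. A quick sketch of exercising it, reusing the `google/ddpm-cat-256` checkpoint from the quickstart (importing the class from wherever this module ends up is an assumption):

```python
from diffusers import DDPMScheduler, UNet2DModel

# Assumption: the class above is importable from the module added in this diff.
from one_step_unet import UnetSchedulerOneForwardPipeline

unet = UNet2DModel.from_pretrained("google/ddpm-cat-256")
scheduler = DDPMScheduler.from_pretrained("google/ddpm-cat-256")

pipe = UnetSchedulerOneForwardPipeline(unet=unet, scheduler=scheduler)
out = pipe()
print(out.shape, out.unique())  # torch.Size([1, 3, 256, 256]) tensor([1.])
```
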
diff --git a/diffusers/examples/community/pipeline_fabric.py b/diffusers/examples/community/pipeline_fabric.py
new file mode 100644
index 0000000000000000000000000000000000000000..080d0c221727a29ce6f5601b684b3d7b0afe0601
--- /dev/null
+++ b/diffusers/examples/community/pipeline_fabric.py
@@ -0,0 +1,751 @@
+# Copyright 2023 FABRIC authors and the HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import List, Optional, Union
+
+import torch
+from packaging import version
+from PIL import Image
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from diffusers import AutoencoderKL, UNet2DConditionModel
+from diffusers.configuration_utils import FrozenDict
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.models.attention import BasicTransformerBlock
+from diffusers.models.attention_processor import LoRAAttnProcessor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.schedulers import EulerAncestralDiscreteScheduler, KarrasDiffusionSchedulers
+from diffusers.utils import (
+ deprecate,
+ logging,
+ replace_example_docstring,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> from diffusers import DiffusionPipeline
+ >>> import torch
+
+ >>> model_id = "dreamlike-art/dreamlike-photoreal-2.0"
+        >>> pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, custom_pipeline="pipeline_fabric")
+ >>> pipe = pipe.to("cuda")
+ >>> prompt = "a giant standing in a fantasy landscape best quality"
+ >>> liked = [] # list of images for positive feedback
+ >>> disliked = [] # list of images for negative feedback
+ >>> image = pipe(prompt, num_images=4, liked=liked, disliked=disliked).images[0]
+ ```
+"""
+
+
+class FabricCrossAttnProcessor:
+ def __init__(self):
+        self.attention_probs = None
+
+ def __call__(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ weights=None,
+ lora_scale=1.0,
+ ):
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if isinstance(attn.processor, LoRAAttnProcessor):
+ query = attn.to_q(hidden_states) + lora_scale * attn.processor.to_q_lora(hidden_states)
+ else:
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ if isinstance(attn.processor, LoRAAttnProcessor):
+ key = attn.to_k(encoder_hidden_states) + lora_scale * attn.processor.to_k_lora(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states) + lora_scale * attn.processor.to_v_lora(encoder_hidden_states)
+ else:
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query = attn.head_to_batch_dim(query)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+
+ if weights is not None:
+ if weights.shape[0] != 1:
+ weights = weights.repeat_interleave(attn.heads, dim=0)
+ attention_probs = attention_probs * weights[:, None]
+ attention_probs = attention_probs / attention_probs.sum(dim=-1, keepdim=True)
+
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ if isinstance(attn.processor, LoRAAttnProcessor):
+ hidden_states = attn.to_out[0](hidden_states) + lora_scale * attn.processor.to_out_lora(hidden_states)
+ else:
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ return hidden_states
+
+
+class FabricPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion and conditioning the results using feedback images.
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`~transformers.CLIPTextModel`]):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ tokenizer ([`~transformers.CLIPTokenizer`]):
+ A `CLIPTokenizer` to tokenize text.
+ unet ([`UNet2DConditionModel`]):
+ A `UNet2DConditionModel` to denoise the encoded image latents.
+ scheduler ([`EulerAncestralDiscreteScheduler`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ about a model's potential harms.
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ unet=unet,
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ if self.text_encoder is not None:
+ prompt_embeds_dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ prompt_embeds_dtype = self.unet.dtype
+ else:
+ prompt_embeds_dtype = prompt_embeds.dtype
+
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+                # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ def get_unet_hidden_states(self, z_all, t, prompt_embd):
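+        # Run one UNet pass over the (noised) feedback latents while hooking every self-attention
+        # (attn1) block, caching the hidden states fed to it; they are re-injected later as extra context.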
+ cached_hidden_states = []
+ for module in self.unet.modules():
+ if isinstance(module, BasicTransformerBlock):
+
+ def new_forward(self, hidden_states, *args, **kwargs):
+ cached_hidden_states.append(hidden_states.clone().detach().cpu())
+ return self.old_forward(hidden_states, *args, **kwargs)
+
+ module.attn1.old_forward = module.attn1.forward
+ module.attn1.forward = new_forward.__get__(module.attn1)
+
+ # run forward pass to cache hidden states, output can be discarded
+ _ = self.unet(z_all, t, encoder_hidden_states=prompt_embd)
+
+ # restore original forward pass
+ for module in self.unet.modules():
+ if isinstance(module, BasicTransformerBlock):
+ module.attn1.forward = module.attn1.old_forward
+ del module.attn1.old_forward
+
+ return cached_hidden_states
+
+ def unet_forward_with_cached_hidden_states(
+ self,
+ z_all,
+ t,
+ prompt_embd,
+ cached_pos_hiddens: Optional[List[torch.Tensor]] = None,
+ cached_neg_hiddens: Optional[List[torch.Tensor]] = None,
+ pos_weights=(0.8, 0.8),
+ neg_weights=(0.5, 0.5),
+ ):
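+        # FABRIC feedback injection: temporarily patch each self-attention (attn1) block so the cached
+        # hidden states from liked/disliked images are appended as extra keys/values, re-weighted by
+        # pos_weights / neg_weights, for this single UNet forward pass.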
+ if cached_pos_hiddens is None and cached_neg_hiddens is None:
+ return self.unet(z_all, t, encoder_hidden_states=prompt_embd)
+
+ local_pos_weights = torch.linspace(*pos_weights, steps=len(self.unet.down_blocks) + 1)[:-1].tolist()
+ local_neg_weights = torch.linspace(*neg_weights, steps=len(self.unet.down_blocks) + 1)[:-1].tolist()
+ for block, pos_weight, neg_weight in zip(
+ self.unet.down_blocks + [self.unet.mid_block] + self.unet.up_blocks,
+ local_pos_weights + [pos_weights[1]] + local_pos_weights[::-1],
+ local_neg_weights + [neg_weights[1]] + local_neg_weights[::-1],
+ ):
+ for module in block.modules():
+ if isinstance(module, BasicTransformerBlock):
+
+ def new_forward(
+ self,
+ hidden_states,
+ pos_weight=pos_weight,
+ neg_weight=neg_weight,
+ **kwargs,
+ ):
+ cond_hiddens, uncond_hiddens = hidden_states.chunk(2, dim=0)
+ batch_size, d_model = cond_hiddens.shape[:2]
+ device, dtype = hidden_states.device, hidden_states.dtype
+
+ weights = torch.ones(batch_size, d_model, device=device, dtype=dtype)
+ out_pos = self.old_forward(hidden_states)
+ out_neg = self.old_forward(hidden_states)
+
+ if cached_pos_hiddens is not None:
+ cached_pos_hs = cached_pos_hiddens.pop(0).to(hidden_states.device)
+ cond_pos_hs = torch.cat([cond_hiddens, cached_pos_hs], dim=1)
+ pos_weights = weights.clone().repeat(1, 1 + cached_pos_hs.shape[1] // d_model)
+ pos_weights[:, d_model:] = pos_weight
+ attn_with_weights = FabricCrossAttnProcessor()
+ out_pos = attn_with_weights(
+ self,
+ cond_hiddens,
+ encoder_hidden_states=cond_pos_hs,
+ weights=pos_weights,
+ )
+ else:
+ out_pos = self.old_forward(cond_hiddens)
+
+ if cached_neg_hiddens is not None:
+ cached_neg_hs = cached_neg_hiddens.pop(0).to(hidden_states.device)
+ uncond_neg_hs = torch.cat([uncond_hiddens, cached_neg_hs], dim=1)
+ neg_weights = weights.clone().repeat(1, 1 + cached_neg_hs.shape[1] // d_model)
+ neg_weights[:, d_model:] = neg_weight
+ attn_with_weights = FabricCrossAttnProcessor()
+ out_neg = attn_with_weights(
+ self,
+ uncond_hiddens,
+ encoder_hidden_states=uncond_neg_hs,
+ weights=neg_weights,
+ )
+ else:
+ out_neg = self.old_forward(uncond_hiddens)
+
+ out = torch.cat([out_pos, out_neg], dim=0)
+ return out
+
+ module.attn1.old_forward = module.attn1.forward
+ module.attn1.forward = new_forward.__get__(module.attn1)
+
+ out = self.unet(z_all, t, encoder_hidden_states=prompt_embd)
+
+ # restore original forward pass
+ for module in self.unet.modules():
+ if isinstance(module, BasicTransformerBlock):
+ module.attn1.forward = module.attn1.old_forward
+ del module.attn1.old_forward
+
+ return out
+
+    def preprocess_feedback_images(self, images, vae, dim, device, dtype, generator) -> torch.Tensor:
+ images_t = [self.image_to_tensor(img, dim, dtype) for img in images]
+ images_t = torch.stack(images_t).to(device)
+ latents = vae.config.scaling_factor * vae.encode(images_t).latent_dist.sample(generator)
+
+ return torch.cat([latents], dim=0)
+
+ def check_inputs(
+ self,
+ prompt,
+ negative_prompt=None,
+ liked=None,
+ disliked=None,
+ height=None,
+ width=None,
+ ):
+ if prompt is None:
+            raise ValueError("Provide `prompt`. Cannot leave `prompt` undefined.")
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and (
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ if liked is not None and not isinstance(liked, list):
+ raise ValueError(f"`liked` has to be of type `list` but is {type(liked)}")
+
+ if disliked is not None and not isinstance(disliked, list):
+ raise ValueError(f"`disliked` has to be of type `list` but is {type(disliked)}")
+
+ if height is not None and not isinstance(height, int):
+ raise ValueError(f"`height` has to be of type `int` but is {type(height)}")
+
+ if width is not None and not isinstance(width, int):
+ raise ValueError(f"`width` has to be of type `int` but is {type(width)}")
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Optional[Union[str, List[str]]] = "",
+ negative_prompt: Optional[Union[str, List[str]]] = "lowres, bad anatomy, bad hands, cropped, worst quality",
+ liked: Optional[Union[List[str], List[Image.Image]]] = [],
+ disliked: Optional[Union[List[str], List[Image.Image]]] = [],
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ height: int = 512,
+ width: int = 512,
+ return_dict: bool = True,
+ num_images: int = 4,
+ guidance_scale: float = 7.0,
+ num_inference_steps: int = 20,
+ output_type: Optional[str] = "pil",
+ feedback_start_ratio: float = 0.33,
+ feedback_end_ratio: float = 0.66,
+ min_weight: float = 0.05,
+ max_weight: float = 0.8,
+ neg_scale: float = 0.5,
+ pos_bottleneck_scale: float = 1.0,
+ neg_bottleneck_scale: float = 1.0,
+ latents: Optional[torch.FloatTensor] = None,
+ ):
+ r"""
+ The call function to the pipeline for generation. Generate a trajectory of images with binary feedback. The
+ feedback can be given as a list of liked and disliked images.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+ liked (`List[Image.Image]` or `List[str]`, *optional*):
+ Encourages images with liked features.
+ disliked (`List[Image.Image]` or `List[str]`, *optional*):
+ Discourages images with disliked features.
+ generator (`torch.Generator` or `List[torch.Generator]` or `int`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) or an `int` to
+ make generation deterministic.
+ height (`int`, *optional*, defaults to 512):
+ Height of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ Width of the generated image.
+ num_images (`int`, *optional*, defaults to 4):
+ The number of images to generate per prompt.
+ guidance_scale (`float`, *optional*, defaults to 7.0):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ num_inference_steps (`int`, *optional*, defaults to 20):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ feedback_start_ratio (`float`, *optional*, defaults to `.33`):
+ Start point for providing feedback (between 0 and 1).
+ feedback_end_ratio (`float`, *optional*, defaults to `.66`):
+ End point for providing feedback (between 0 and 1).
+ min_weight (`float`, *optional*, defaults to `.05`):
+ Minimum weight for feedback.
+            max_weight (`float`, *optional*, defaults to `.8`):
+ Maximum weight for feedback.
+ neg_scale (`float`, *optional*, defaults to `.5`):
+ Scale factor for negative feedback.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.fabric.FabricPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
+ "not-safe-for-work" (nsfw) content.
+
+ """
+
+ self.check_inputs(prompt, negative_prompt, liked, disliked)
+
+ device = self._execution_device
+ dtype = self.unet.dtype
+
+ if isinstance(prompt, str) and prompt is not None:
+ batch_size = 1
+ elif isinstance(prompt, list) and prompt is not None:
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if isinstance(negative_prompt, list):
+            # a list of negative prompts must match the prompt batch size
+            assert len(negative_prompt) == batch_size
+
+ shape = (
+ batch_size * num_images,
+ self.unet.config.in_channels,
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor,
+ )
+ latent_noise = randn_tensor(
+ shape,
+ device=device,
+ dtype=dtype,
+ generator=generator,
+ )
+
+ positive_latents = (
+ self.preprocess_feedback_images(liked, self.vae, (height, width), device, dtype, generator)
+ if liked and len(liked) > 0
+ else torch.tensor(
+ [],
+ device=device,
+ dtype=dtype,
+ )
+ )
+ negative_latents = (
+ self.preprocess_feedback_images(disliked, self.vae, (height, width), device, dtype, generator)
+ if disliked and len(disliked) > 0
+ else torch.tensor(
+ [],
+ device=device,
+ dtype=dtype,
+ )
+ )
+
+ do_classifier_free_guidance = guidance_scale > 0.1
+
+ (prompt_neg_embs, prompt_pos_embs) = self._encode_prompt(
+ prompt,
+ device,
+ num_images,
+ do_classifier_free_guidance,
+ negative_prompt,
+ ).split([num_images * batch_size, num_images * batch_size])
+
+ batched_prompt_embd = torch.cat([prompt_pos_embs, prompt_neg_embs], dim=0)
+
+ null_tokens = self.tokenizer(
+ [""],
+ return_tensors="pt",
+ max_length=self.tokenizer.model_max_length,
+ padding="max_length",
+ truncation=True,
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = null_tokens.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ null_prompt_emb = self.text_encoder(
+ input_ids=null_tokens.input_ids.to(device),
+ attention_mask=attention_mask,
+ ).last_hidden_state
+
+ null_prompt_emb = null_prompt_emb.to(device=device, dtype=dtype)
+
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+ latent_noise = latent_noise * self.scheduler.init_noise_sigma
+
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+
+ ref_start_idx = round(len(timesteps) * feedback_start_ratio)
+ ref_end_idx = round(len(timesteps) * feedback_end_ratio)
+
+ with self.progress_bar(total=num_inference_steps) as pbar:
+ for i, t in enumerate(timesteps):
+ sigma = self.scheduler.sigma_t[t] if hasattr(self.scheduler, "sigma_t") else 0
+ if hasattr(self.scheduler, "sigmas"):
+ sigma = self.scheduler.sigmas[i]
+
+ alpha_hat = 1 / (sigma**2 + 1)
+
+ z_single = self.scheduler.scale_model_input(latent_noise, t)
+ z_all = torch.cat([z_single] * 2, dim=0)
+ z_ref = torch.cat([positive_latents, negative_latents], dim=0)
+
+ if i >= ref_start_idx and i <= ref_end_idx:
+ weight_factor = max_weight
+ else:
+ weight_factor = min_weight
+
+ pos_ws = (weight_factor, weight_factor * pos_bottleneck_scale)
+ neg_ws = (weight_factor * neg_scale, weight_factor * neg_scale * neg_bottleneck_scale)
+
+ if z_ref.size(0) > 0 and weight_factor > 0:
+ noise = torch.randn_like(z_ref)
+ if isinstance(self.scheduler, EulerAncestralDiscreteScheduler):
+ z_ref_noised = (alpha_hat**0.5 * z_ref + (1 - alpha_hat) ** 0.5 * noise).type(dtype)
+ else:
+ z_ref_noised = self.scheduler.add_noise(z_ref, noise, t)
+
+ ref_prompt_embd = torch.cat(
+ [null_prompt_emb] * (len(positive_latents) + len(negative_latents)), dim=0
+ )
+ cached_hidden_states = self.get_unet_hidden_states(z_ref_noised, t, ref_prompt_embd)
+
+ n_pos, n_neg = positive_latents.shape[0], negative_latents.shape[0]
+ cached_pos_hs, cached_neg_hs = [], []
+ for hs in cached_hidden_states:
+ cached_pos, cached_neg = hs.split([n_pos, n_neg], dim=0)
+ cached_pos = cached_pos.view(1, -1, *cached_pos.shape[2:]).expand(num_images, -1, -1)
+ cached_neg = cached_neg.view(1, -1, *cached_neg.shape[2:]).expand(num_images, -1, -1)
+ cached_pos_hs.append(cached_pos)
+ cached_neg_hs.append(cached_neg)
+
+ if n_pos == 0:
+ cached_pos_hs = None
+ if n_neg == 0:
+ cached_neg_hs = None
+ else:
+ cached_pos_hs, cached_neg_hs = None, None
+ unet_out = self.unet_forward_with_cached_hidden_states(
+ z_all,
+ t,
+ prompt_embd=batched_prompt_embd,
+ cached_pos_hiddens=cached_pos_hs,
+ cached_neg_hiddens=cached_neg_hs,
+ pos_weights=pos_ws,
+ neg_weights=neg_ws,
+ )[0]
+
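+ # standard classifier-free guidance: push the unconditional prediction towards the conditional one by guidance_scale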
+ noise_cond, noise_uncond = unet_out.chunk(2)
+ guidance = noise_cond - noise_uncond
+ noise_pred = noise_uncond + guidance_scale * guidance
+ latent_noise = self.scheduler.step(noise_pred, t, latent_noise)[0]
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ pbar.update()
+
+ y = self.vae.decode(latent_noise / self.vae.config.scaling_factor, return_dict=False)[0]
+ imgs = self.image_processor.postprocess(
+ y,
+ output_type=output_type,
+ )
+
+ if not return_dict:
+ return imgs
+
+ return StableDiffusionPipelineOutput(imgs, False)
+
+ def image_to_tensor(self, image: Union[str, Image.Image], dim: tuple, dtype):
+ """
+ Convert a PIL image (or a path to an image file) to a torch tensor for further processing.
+ """
+ if isinstance(image, str):
+ image = Image.open(image)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image = self.image_processor.preprocess(image, height=dim[0], width=dim[1])[0]
+ return image.type(dtype)
diff --git a/diffusers/examples/community/pipeline_prompt2prompt.py b/diffusers/examples/community/pipeline_prompt2prompt.py
new file mode 100644
index 0000000000000000000000000000000000000000..59b8e691bde365528e636543724be84e2f4aa61d
--- /dev/null
+++ b/diffusers/examples/community/pipeline_prompt2prompt.py
@@ -0,0 +1,861 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import abc
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from ...src.diffusers.models.attention import Attention
+from ...src.diffusers.pipelines.stable_diffusion import StableDiffusionPipeline, StableDiffusionPipelineOutput
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+class Prompt2PromptPipeline(StableDiffusionPipeline):
+ r"""
+ Args:
+ Prompt-to-Prompt-Pipeline for text-to-image generation using Stable Diffusion. This model inherits from
+ [`StableDiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for
+ all the pipelines (such as downloading or saving, running on a particular device, etc.)
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]):
+ Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+
+ The keyword arguments to configure the edit are:
+ - edit_type (`str`): The edit type to apply. Can be one of `replace`, `refine`, `reweight`.
+ - n_cross_replace (`int`): Number of diffusion steps in which cross attention should be replaced.
+ - n_self_replace (`int`): Number of diffusion steps in which self attention should be replaced.
+ - local_blend_words (`List[str]`, *optional*, defaults to `None`): Determines which area should be
+ changed. If `None`, the whole image can be changed.
+ - equalizer_words (`List[str]`, *optional*, defaults to `None`): Required for edit type `reweight`.
+ Determines which words should be enhanced.
+ - equalizer_strengths (`List[float]`, *optional*, defaults to `None`): Required for edit type `reweight`.
+ Determines how much the words in `equalizer_words` should be enhanced.
+
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
+ using zero terminal SNR.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+
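+ # build the Prompt-to-Prompt attention controller described by cross_attention_kwargs and hook it into every attention layer of the UNet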
+ self.controller = create_controller(
+ prompt, cross_attention_kwargs, num_inference_steps, tokenizer=self.tokenizer, device=self.device
+ )
+ self.register_attention_control(self.controller) # add attention controller
+
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(prompt, height, width, callback_steps)
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # step callback
+ latents = self.controller.step_callback(latents)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ # 8. Post-processing
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ # 9. Run safety checker
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+ def register_attention_control(self, controller):
+ attn_procs = {}
+ cross_att_count = 0
+ for name in self.unet.attn_processors.keys():
+ if name.startswith("mid_block"):
+ place_in_unet = "mid"
+ elif name.startswith("up_blocks"):
+ place_in_unet = "up"
+ elif name.startswith("down_blocks"):
+ place_in_unet = "down"
+ else:
+ continue
+ cross_att_count += 1
+ attn_procs[name] = P2PCrossAttnProcessor(controller=controller, place_in_unet=place_in_unet)
+
+ self.unet.set_attn_processor(attn_procs)
+ controller.num_att_layers = cross_att_count
+
+
+class P2PCrossAttnProcessor:
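+ # drop-in replacement for the default attention processor that additionally routes the attention probabilities through the Prompt-to-Prompt controller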
+ def __init__(self, controller, place_in_unet):
+ super().__init__()
+ self.controller = controller
+ self.place_in_unet = place_in_unet
+
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+ batch_size, sequence_length, _ = hidden_states.shape
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ query = attn.to_q(hidden_states)
+
+ is_cross = encoder_hidden_states is not None
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query = attn.head_to_batch_dim(query)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+
+ # one line change vs. the default processor: let the controller record or edit the attention map
+ self.controller(attention_probs, is_cross, self.place_in_unet)
+
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ return hidden_states
+
+
+def create_controller(
+ prompts: List[str], cross_attention_kwargs: Dict, num_inference_steps: int, tokenizer, device
+) -> AttentionControl:
+ edit_type = cross_attention_kwargs.get("edit_type", None)
+ local_blend_words = cross_attention_kwargs.get("local_blend_words", None)
+ equalizer_words = cross_attention_kwargs.get("equalizer_words", None)
+ equalizer_strengths = cross_attention_kwargs.get("equalizer_strengths", None)
+ n_cross_replace = cross_attention_kwargs.get("n_cross_replace", 0.4)
+ n_self_replace = cross_attention_kwargs.get("n_self_replace", 0.4)
+
+ # only replace
+ if edit_type == "replace" and local_blend_words is None:
+ return AttentionReplace(
+ prompts, num_inference_steps, n_cross_replace, n_self_replace, tokenizer=tokenizer, device=device
+ )
+
+ # replace + localblend
+ if edit_type == "replace" and local_blend_words is not None:
+ lb = LocalBlend(prompts, local_blend_words, tokenizer=tokenizer, device=device)
+ return AttentionReplace(
+ prompts, num_inference_steps, n_cross_replace, n_self_replace, lb, tokenizer=tokenizer, device=device
+ )
+
+ # only refine
+ if edit_type == "refine" and local_blend_words is None:
+ return AttentionRefine(
+ prompts, num_inference_steps, n_cross_replace, n_self_replace, tokenizer=tokenizer, device=device
+ )
+
+ # refine + localblend
+ if edit_type == "refine" and local_blend_words is not None:
+ lb = LocalBlend(prompts, local_blend_words, tokenizer=tokenizer, device=device)
+ return AttentionRefine(
+ prompts, num_inference_steps, n_cross_replace, n_self_replace, lb, tokenizer=tokenizer, device=device
+ )
+
+ # reweight
+ if edit_type == "reweight":
+ assert (
+ equalizer_words is not None and equalizer_strengths is not None
+ ), "To use reweight edit, please specify equalizer_words and equalizer_strengths."
+ assert len(equalizer_words) == len(
+ equalizer_strengths
+ ), "equalizer_words and equalizer_strengths must be of same length."
+ equalizer = get_equalizer(prompts[1], equalizer_words, equalizer_strengths, tokenizer=tokenizer)
+ return AttentionReweight(
+ prompts,
+ num_inference_steps,
+ n_cross_replace,
+ n_self_replace,
+ tokenizer=tokenizer,
+ device=device,
+ equalizer=equalizer,
+ )
+
+ raise ValueError(f"Edit type {edit_type} not recognized. Use one of: replace, refine, reweight.")
+
+
+class AttentionControl(abc.ABC):
+ def step_callback(self, x_t):
+ return x_t
+
+ def between_steps(self):
+ return
+
+ @property
+ def num_uncond_att_layers(self):
+ return 0
+
+ @abc.abstractmethod
+ def forward(self, attn, is_cross: bool, place_in_unet: str):
+ raise NotImplementedError
+
+ def __call__(self, attn, is_cross: bool, place_in_unet: str):
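+ # only the second half of the batch (the conditional branch) is handed to forward() for editing; the unconditional half is left untouched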
+ if self.cur_att_layer >= self.num_uncond_att_layers:
+ h = attn.shape[0]
+ attn[h // 2 :] = self.forward(attn[h // 2 :], is_cross, place_in_unet)
+ self.cur_att_layer += 1
+ if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
+ self.cur_att_layer = 0
+ self.cur_step += 1
+ self.between_steps()
+ return attn
+
+ def reset(self):
+ self.cur_step = 0
+ self.cur_att_layer = 0
+
+ def __init__(self):
+ self.cur_step = 0
+ self.num_att_layers = -1
+ self.cur_att_layer = 0
+
+
+class EmptyControl(AttentionControl):
+ def forward(self, attn, is_cross: bool, place_in_unet: str):
+ return attn
+
+
+class AttentionStore(AttentionControl):
+ @staticmethod
+ def get_empty_store():
+ return {"down_cross": [], "mid_cross": [], "up_cross": [], "down_self": [], "mid_self": [], "up_self": []}
+
+ def forward(self, attn, is_cross: bool, place_in_unet: str):
+ key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
+ if attn.shape[1] <= 32**2: # avoid memory overhead
+ self.step_store[key].append(attn)
+ return attn
+
+ def between_steps(self):
+ if len(self.attention_store) == 0:
+ self.attention_store = self.step_store
+ else:
+ for key in self.attention_store:
+ for i in range(len(self.attention_store[key])):
+ self.attention_store[key][i] += self.step_store[key][i]
+ self.step_store = self.get_empty_store()
+
+ def get_average_attention(self):
+ average_attention = {
+ key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store
+ }
+ return average_attention
+
+ def reset(self):
+ super(AttentionStore, self).reset()
+ self.step_store = self.get_empty_store()
+ self.attention_store = {}
+
+ def __init__(self):
+ super(AttentionStore, self).__init__()
+ self.step_store = self.get_empty_store()
+ self.attention_store = {}
+
+
+class LocalBlend:
+ def __call__(self, x_t, attention_store):
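+ # build a spatial mask from the stored cross-attention maps of the selected words and keep the edit only inside that mask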
+ k = 1
+ maps = attention_store["down_cross"][2:4] + attention_store["up_cross"][:3]
+ maps = [item.reshape(self.alpha_layers.shape[0], -1, 1, 16, 16, self.max_num_words) for item in maps]
+ maps = torch.cat(maps, dim=1)
+ maps = (maps * self.alpha_layers).sum(-1).mean(1)
+ mask = F.max_pool2d(maps, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k))
+ mask = F.interpolate(mask, size=(x_t.shape[2:]))
+ mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0]
+ mask = mask.gt(self.threshold)
+ mask = (mask[:1] + mask[1:]).float()
+ x_t = x_t[:1] + mask * (x_t - x_t[:1])
+ return x_t
+
+ def __init__(
+ self, prompts: List[str], words: List[List[str]], tokenizer, device, threshold=0.3, max_num_words=77
+ ):
+ self.max_num_words = max_num_words
+
+ alpha_layers = torch.zeros(len(prompts), 1, 1, 1, 1, self.max_num_words)
+ for i, (prompt, words_) in enumerate(zip(prompts, words)):
+ if isinstance(words_, str):
+ words_ = [words_]
+ for word in words_:
+ ind = get_word_inds(prompt, word, tokenizer)
+ alpha_layers[i, :, :, :, :, ind] = 1
+ self.alpha_layers = alpha_layers.to(device)
+ self.threshold = threshold
+
+
+class AttentionControlEdit(AttentionStore, abc.ABC):
+ def step_callback(self, x_t):
+ if self.local_blend is not None:
+ x_t = self.local_blend(x_t, self.attention_store)
+ return x_t
+
+ def replace_self_attention(self, attn_base, att_replace):
+ if att_replace.shape[2] <= 16**2:
+ return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape)
+ else:
+ return att_replace
+
+ @abc.abstractmethod
+ def replace_cross_attention(self, attn_base, att_replace):
+ raise NotImplementedError
+
+ def forward(self, attn, is_cross: bool, place_in_unet: str):
+ super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet)
+ # FIXME: not replacing correctly yet
+ if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]):
+ h = attn.shape[0] // (self.batch_size)
+ attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
+ attn_base, attn_replace = attn[0], attn[1:]
+ if is_cross:
+ alpha_words = self.cross_replace_alpha[self.cur_step]
+ attn_replace_new = (
+ self.replace_cross_attention(attn_base, attn_replace) * alpha_words
+ + (1 - alpha_words) * attn_replace
+ )
+ attn[1:] = attn_replace_new
+ else:
+ attn[1:] = self.replace_self_attention(attn_base, attn_replace)
+ attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
+ return attn
+
+ def __init__(
+ self,
+ prompts,
+ num_steps: int,
+ cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
+ self_replace_steps: Union[float, Tuple[float, float]],
+ local_blend: Optional[LocalBlend],
+ tokenizer,
+ device,
+ ):
+ super(AttentionControlEdit, self).__init__()
+ # add tokenizer and device here
+
+ self.tokenizer = tokenizer
+ self.device = device
+
+ self.batch_size = len(prompts)
+ self.cross_replace_alpha = get_time_words_attention_alpha(
+ prompts, num_steps, cross_replace_steps, self.tokenizer
+ ).to(self.device)
+ if isinstance(self_replace_steps, float):
+ self_replace_steps = 0, self_replace_steps
+ self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])
+ self.local_blend = local_blend # defined externally and passed in
+
+
+class AttentionReplace(AttentionControlEdit):
+ def replace_cross_attention(self, attn_base, att_replace):
+ return torch.einsum("hpw,bwn->bhpn", attn_base, self.mapper)
+
+ def __init__(
+ self,
+ prompts,
+ num_steps: int,
+ cross_replace_steps: float,
+ self_replace_steps: float,
+ local_blend: Optional[LocalBlend] = None,
+ tokenizer=None,
+ device=None,
+ ):
+ super(AttentionReplace, self).__init__(
+ prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend, tokenizer, device
+ )
+ self.mapper = get_replacement_mapper(prompts, self.tokenizer).to(self.device)
+
+
+class AttentionRefine(AttentionControlEdit):
+ def replace_cross_attention(self, attn_base, att_replace):
+ attn_base_replace = attn_base[:, :, self.mapper].permute(2, 0, 1, 3)
+ attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas)
+ return attn_replace
+
+ def __init__(
+ self,
+ prompts,
+ num_steps: int,
+ cross_replace_steps: float,
+ self_replace_steps: float,
+ local_blend: Optional[LocalBlend] = None,
+ tokenizer=None,
+ device=None,
+ ):
+ super(AttentionRefine, self).__init__(
+ prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend, tokenizer, device
+ )
+ self.mapper, alphas = get_refinement_mapper(prompts, self.tokenizer)
+ self.mapper, alphas = self.mapper.to(self.device), alphas.to(self.device)
+ self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1])
+
+
+class AttentionReweight(AttentionControlEdit):
+ def replace_cross_attention(self, attn_base, att_replace):
+ if self.prev_controller is not None:
+ attn_base = self.prev_controller.replace_cross_attention(attn_base, att_replace)
+ attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :]
+ return attn_replace
+
+ def __init__(
+ self,
+ prompts,
+ num_steps: int,
+ cross_replace_steps: float,
+ self_replace_steps: float,
+ equalizer,
+ local_blend: Optional[LocalBlend] = None,
+ controller: Optional[AttentionControlEdit] = None,
+ tokenizer=None,
+ device=None,
+ ):
+ super(AttentionReweight, self).__init__(
+ prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend, tokenizer, device
+ )
+ self.equalizer = equalizer.to(self.device)
+ self.prev_controller = controller
+
+
+### util functions for all Edits
+def update_alpha_time_word(
+ alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int, word_inds: Optional[torch.Tensor] = None
+):
+ if isinstance(bounds, float):
+ bounds = 0, bounds
+ start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0])
+ if word_inds is None:
+ word_inds = torch.arange(alpha.shape[2])
+ alpha[:start, prompt_ind, word_inds] = 0
+ alpha[start:end, prompt_ind, word_inds] = 1
+ alpha[end:, prompt_ind, word_inds] = 0
+ return alpha
+
+
+def get_time_words_attention_alpha(
+ prompts, num_steps, cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]], tokenizer, max_num_words=77
+):
+ if not isinstance(cross_replace_steps, dict):
+ cross_replace_steps = {"default_": cross_replace_steps}
+ if "default_" not in cross_replace_steps:
+ cross_replace_steps["default_"] = (0.0, 1.0)
+ alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words)
+ for i in range(len(prompts) - 1):
+ alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"], i)
+ for key, item in cross_replace_steps.items():
+ if key != "default_":
+ inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]
+ for i, ind in enumerate(inds):
+ if len(ind) > 0:
+ alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)
+ alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words)
+ return alpha_time_words
+
+
+### util functions for LocalBlend and ReplacementEdit
+def get_word_inds(text: str, word_place: Union[int, str], tokenizer):
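+ # map a word in `text` (given by position or by string match) to the indices of its sub-word tokens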
+ split_text = text.split(" ")
+ if isinstance(word_place, str):
+ word_place = [i for i, word in enumerate(split_text) if word_place == word]
+ elif isinstance(word_place, int):
+ word_place = [word_place]
+ out = []
+ if len(word_place) > 0:
+ words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
+ cur_len, ptr = 0, 0
+
+ for i in range(len(words_encode)):
+ cur_len += len(words_encode[i])
+ if ptr in word_place:
+ out.append(i + 1)
+ if cur_len >= len(split_text[ptr]):
+ ptr += 1
+ cur_len = 0
+ return np.array(out)
+
+
+### util functions for ReplacementEdit
+def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77):
+ words_x = x.split(" ")
+ words_y = y.split(" ")
+ if len(words_x) != len(words_y):
+ raise ValueError(
+ f"attention replacement edit can only be applied on prompts with the same length"
+ f" but prompt A has {len(words_x)} words and prompt B has {len(words_y)} words."
+ )
+ inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]]
+ inds_source = [get_word_inds(x, i, tokenizer) for i in inds_replace]
+ inds_target = [get_word_inds(y, i, tokenizer) for i in inds_replace]
+ mapper = np.zeros((max_len, max_len))
+ i = j = 0
+ cur_inds = 0
+ while i < max_len and j < max_len:
+ if cur_inds < len(inds_source) and inds_source[cur_inds][0] == i:
+ inds_source_, inds_target_ = inds_source[cur_inds], inds_target[cur_inds]
+ if len(inds_source_) == len(inds_target_):
+ mapper[inds_source_, inds_target_] = 1
+ else:
+ ratio = 1 / len(inds_target_)
+ for i_t in inds_target_:
+ mapper[inds_source_, i_t] = ratio
+ cur_inds += 1
+ i += len(inds_source_)
+ j += len(inds_target_)
+ elif cur_inds < len(inds_source):
+ mapper[i, j] = 1
+ i += 1
+ j += 1
+ else:
+ mapper[j, j] = 1
+ i += 1
+ j += 1
+
+ return torch.from_numpy(mapper).float()
+
+
+def get_replacement_mapper(prompts, tokenizer, max_len=77):
+ x_seq = prompts[0]
+ mappers = []
+ for i in range(1, len(prompts)):
+ mapper = get_replacement_mapper_(x_seq, prompts[i], tokenizer, max_len)
+ mappers.append(mapper)
+ return torch.stack(mappers)
+
+
+### util functions for ReweightEdit
+def get_equalizer(
+ text: str, word_select: Union[int, Tuple[int, ...]], values: Union[List[float], Tuple[float, ...]], tokenizer
+):
+ if isinstance(word_select, (int, str)):
+ word_select = (word_select,)
+ equalizer = torch.ones(len(values), 77)
+ values = torch.tensor(values, dtype=torch.float32)
+ for word in word_select:
+ inds = get_word_inds(text, word, tokenizer)
+ equalizer[:, inds] = values
+ return equalizer
+
+
+### util functions for RefinementEdit
+class ScoreParams:
+ def __init__(self, gap, match, mismatch):
+ self.gap = gap
+ self.match = match
+ self.mismatch = mismatch
+
+ def mis_match_char(self, x, y):
+ if x != y:
+ return self.mismatch
+ else:
+ return self.match
+
+
+def get_matrix(size_x, size_y, gap):
+ matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)
+ matrix[0, 1:] = (np.arange(size_y) + 1) * gap
+ matrix[1:, 0] = (np.arange(size_x) + 1) * gap
+ return matrix
+
+
+def get_traceback_matrix(size_x, size_y):
+ matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)
+ matrix[0, 1:] = 1
+ matrix[1:, 0] = 2
+ matrix[0, 0] = 4
+ return matrix
+
+
+def global_align(x, y, score):
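+ # dynamic-programming global alignment (Needleman-Wunsch style) of the two token sequences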
+ matrix = get_matrix(len(x), len(y), score.gap)
+ trace_back = get_traceback_matrix(len(x), len(y))
+ for i in range(1, len(x) + 1):
+ for j in range(1, len(y) + 1):
+ left = matrix[i, j - 1] + score.gap
+ up = matrix[i - 1, j] + score.gap
+ diag = matrix[i - 1, j - 1] + score.mis_match_char(x[i - 1], y[j - 1])
+ matrix[i, j] = max(left, up, diag)
+ if matrix[i, j] == left:
+ trace_back[i, j] = 1
+ elif matrix[i, j] == up:
+ trace_back[i, j] = 2
+ else:
+ trace_back[i, j] = 3
+ return matrix, trace_back
+
+
+def get_aligned_sequences(x, y, trace_back):
+ x_seq = []
+ y_seq = []
+ i = len(x)
+ j = len(y)
+ mapper_y_to_x = []
+ while i > 0 or j > 0:
+ if trace_back[i, j] == 3:
+ x_seq.append(x[i - 1])
+ y_seq.append(y[j - 1])
+ i = i - 1
+ j = j - 1
+ mapper_y_to_x.append((j, i))
+ elif trace_back[i][j] == 1:
+ x_seq.append("-")
+ y_seq.append(y[j - 1])
+ j = j - 1
+ mapper_y_to_x.append((j, -1))
+ elif trace_back[i][j] == 2:
+ x_seq.append(x[i - 1])
+ y_seq.append("-")
+ i = i - 1
+ elif trace_back[i][j] == 4:
+ break
+ mapper_y_to_x.reverse()
+ return x_seq, y_seq, torch.tensor(mapper_y_to_x, dtype=torch.int64)
+
+
+def get_mapper(x: str, y: str, tokenizer, max_len=77):
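+ # align the tokenized prompts and record, for each target token, the index of its source counterpart; alphas mark tokens that have a counterpart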
+ x_seq = tokenizer.encode(x)
+ y_seq = tokenizer.encode(y)
+ score = ScoreParams(0, 1, -1)
+ matrix, trace_back = global_align(x_seq, y_seq, score)
+ mapper_base = get_aligned_sequences(x_seq, y_seq, trace_back)[-1]
+ alphas = torch.ones(max_len)
+ alphas[: mapper_base.shape[0]] = mapper_base[:, 1].ne(-1).float()
+ mapper = torch.zeros(max_len, dtype=torch.int64)
+ mapper[: mapper_base.shape[0]] = mapper_base[:, 1]
+ mapper[mapper_base.shape[0] :] = len(y_seq) + torch.arange(max_len - len(y_seq))
+ return mapper, alphas
+
+
+def get_refinement_mapper(prompts, tokenizer, max_len=77):
+ x_seq = prompts[0]
+ mappers, alphas = [], []
+ for i in range(1, len(prompts)):
+ mapper, alpha = get_mapper(x_seq, prompts[i], tokenizer, max_len)
+ mappers.append(mapper)
+ alphas.append(alpha)
+ return torch.stack(mappers), torch.stack(alphas)
diff --git a/diffusers/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py b/diffusers/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..d801de86cc7024d1faedd9416994beedcb8762e2
--- /dev/null
+++ b/diffusers/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py
@@ -0,0 +1,1463 @@
+# Copyright 2023 TencentARC and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn.functional as F
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.models import AutoencoderKL, ControlNetModel, MultiAdapter, T2IAdapter, UNet2DConditionModel
+from diffusers.models.attention_processor import (
+ AttnProcessor2_0,
+ LoRAAttnProcessor2_0,
+ LoRAXFormersAttnProcessor,
+ XFormersAttnProcessor,
+)
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ PIL_INTERPOLATION,
+ USE_PEFT_BACKEND,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import T2IAdapter, StableDiffusionXLAdapterPipeline, ControlNetModel, DDPMScheduler
+ >>> from diffusers.utils import load_image
+ >>> from controlnet_aux.midas import MidasDetector
+
+ >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+ >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+ >>> image = load_image(img_url).resize((1024, 1024))
+ >>> mask_image = load_image(mask_url).resize((1024, 1024))
+
+ >>> midas_depth = MidasDetector.from_pretrained(
+ ... "valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
+ ... ).to("cuda")
+
+ >>> depth_image = midas_depth(
+ ... image, detect_resolution=512, image_resolution=1024
+ ... )
+
+ >>> model_id = "stabilityai/stable-diffusion-xl-base-1.0"
+
+ >>> adapter = T2IAdapter.from_pretrained(
+ ... "Adapter/t2iadapter",
+ ... subfolder="sketch_sdxl_1.0",
+ ... torch_dtype=torch.float16,
+ ... adapter_type="full_adapter_xl",
+ ... )
+
+ >>> controlnet = ControlNetModel.from_pretrained(
+ ... "diffusers/controlnet-depth-sdxl-1.0",
+ ... torch_dtype=torch.float16,
+ ... variant="fp16",
+ ... use_safetensors=True
+ ... ).to("cuda")
+
+ >>> scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
+
+ >>> pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
+ ... model_id,
+ ... adapter=adapter,
+ ... controlnet=controlnet,
+ ... torch_dtype=torch.float16,
+ ... variant="fp16",
+ ... scheduler=scheduler
+ ... ).to("cuda")
+
+ >>> strength = 0.5
+
+ >>> generator = torch.manual_seed(42)
+ >>> sketch_image_out = pipe(
+ ... prompt="a photo of a tiger sitting on a park bench",
+ ... negative_prompt="extra digit, fewer digits, cropped, worst quality, low quality",
+ ... adapter_image=depth_image,
+ ... control_image=mask_image,
+ ... adapter_conditioning_scale=strength,
+ ... controlnet_conditioning_scale=strength,
+ ... generator=generator,
+ ... guidance_scale=7.5,
+ ... ).images[0]
+ ```
+"""
+
+
+def _preprocess_adapter_image(image, height, width):
+ if isinstance(image, torch.Tensor):
+ return image
+ elif isinstance(image, PIL.Image.Image):
+ image = [image]
+
+ if isinstance(image[0], PIL.Image.Image):
+ image = [np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])) for i in image]
+ image = [
+ i[None, ..., None] if i.ndim == 2 else i[None, ...] for i in image
+ ] # expand [h, w] or [h, w, c] to [b, h, w, c]
+ image = np.concatenate(image, axis=0)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ elif isinstance(image[0], torch.Tensor):
+ if image[0].ndim == 3:
+ image = torch.stack(image, dim=0)
+ elif image[0].ndim == 4:
+ image = torch.cat(image, dim=0)
+ else:
+ raise ValueError(
+ f"Invalid image tensor! Expecting image tensor with 3 or 4 dimension, but recive: {image[0].ndim}"
+ )
+ return image
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+class StableDiffusionXLControlNetAdapterPipeline(
+ DiffusionPipeline, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
+):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter
+ https://arxiv.org/abs/2302.08453
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ adapter ([`T2IAdapter`] or [`MultiAdapter`] or `List[T2IAdapter]`):
+ Provides additional conditioning to the unet during the denoising process. If you set multiple Adapters as a
+ list, the outputs from each Adapter are added together to create one combined additional conditioning.
+ adapter_weights (`List[float]`, *optional*, defaults to None):
+ List of floats representing the weight by which each adapter's output will be multiplied before the outputs
+ are added together.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
+ _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ text_encoder_2: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ tokenizer_2: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ adapter: Union[T2IAdapter, MultiAdapter, List[T2IAdapter]],
+ controlnet: Union[ControlNetModel, MultiControlNetModel],
+ scheduler: KarrasDiffusionSchedulers,
+ force_zeros_for_empty_prompt: bool = True,
+ ):
+ super().__init__()
+
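+ # wrap a plain list/tuple of controlnets into a MultiControlNetModel so they can be handled uniformly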
+ if isinstance(controlnet, (list, tuple)):
+ controlnet = MultiControlNetModel(controlnet)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ unet=unet,
+ adapter=adapter,
+ controlnet=controlnet,
+ scheduler=scheduler,
+ )
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.control_image_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
+ )
+ self.default_sample_size = self.unet.config.sample_size
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: str,
+ prompt_2: Optional[str] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[str] = None,
+ negative_prompt_2: Optional[str] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Define tokenizers and text encoders
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+ text_encoders = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # textual inversion: process multi-vector tokens if necessary
+ prompt_embeds_list = []
+ prompts = [prompt, prompt_2]
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+
+ # We are always interested only in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ if clip_skip is None:
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ else:
+ # "2" because SDXL always indexes from the penultimate layer.
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
+
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+ # get unconditional embeddings for classifier free guidance
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
+
+ # normalize str to list
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt_2 = (
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+ )
+
+ uncond_tokens: List[str]
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = [negative_prompt, negative_prompt_2]
+
+ negative_prompt_embeds_list = []
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = tokenizer(
+ negative_prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ negative_prompt_embeds = text_encoder(
+ uncond_input.input_ids.to(device),
+ output_hidden_states=True,
+ )
+ # We are always interested only in the pooled output of the final text encoder
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+ if self.text_encoder_2 is not None:
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ if self.text_encoder_2 is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+ if do_classifier_free_guidance:
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
+ def check_image(self, image, prompt, prompt_embeds):
+ image_is_pil = isinstance(image, PIL.Image.Image)
+ image_is_tensor = isinstance(image, torch.Tensor)
+ image_is_np = isinstance(image, np.ndarray)
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
+ image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
+
+ if (
+ not image_is_pil
+ and not image_is_tensor
+ and not image_is_np
+ and not image_is_pil_list
+ and not image_is_tensor_list
+ and not image_is_np_list
+ ):
+ raise TypeError(
+ f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
+ )
+
+ if image_is_pil:
+ image_batch_size = 1
+ else:
+ image_batch_size = len(image)
+
+ if prompt is not None and isinstance(prompt, str):
+ prompt_batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ prompt_batch_size = len(prompt)
+ elif prompt_embeds is not None:
+ prompt_batch_size = prompt_embeds.shape[0]
+
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
+ raise ValueError(
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ negative_prompt_2=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ def check_conditions(
+ self,
+ prompt,
+ prompt_embeds,
+ adapter_image,
+ control_image,
+ adapter_conditioning_scale,
+ controlnet_conditioning_scale,
+ control_guidance_start,
+ control_guidance_end,
+ ):
+ # controlnet checks
+ if not isinstance(control_guidance_start, (tuple, list)):
+ control_guidance_start = [control_guidance_start]
+
+ if not isinstance(control_guidance_end, (tuple, list)):
+ control_guidance_end = [control_guidance_end]
+
+ if len(control_guidance_start) != len(control_guidance_end):
+ raise ValueError(
+ f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
+ )
+
+ if isinstance(self.controlnet, MultiControlNetModel):
+ if len(control_guidance_start) != len(self.controlnet.nets):
+ raise ValueError(
+ f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
+ )
+
+ for start, end in zip(control_guidance_start, control_guidance_end):
+ if start >= end:
+ raise ValueError(
+ f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
+ )
+ if start < 0.0:
+ raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
+ if end > 1.0:
+ raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
+
+ # Check controlnet `image`
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
+ self.controlnet, torch._dynamo.eval_frame.OptimizedModule
+ )
+ if (
+ isinstance(self.controlnet, ControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
+ ):
+ self.check_image(control_image, prompt, prompt_embeds)
+ elif (
+ isinstance(self.controlnet, MultiControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+ ):
+ if not isinstance(control_image, list):
+ raise TypeError("For multiple controlnets: `control_image` must be type `list`")
+
+ # When `image` is a nested list:
+ # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
+ elif any(isinstance(i, list) for i in control_image):
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
+ elif len(control_image) != len(self.controlnet.nets):
+ raise ValueError(
+ f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(control_image)} images and {len(self.controlnet.nets)} ControlNets."
+ )
+
+ for image_ in control_image:
+ self.check_image(image_, prompt, prompt_embeds)
+ else:
+ assert False
+
+ # Check `controlnet_conditioning_scale`
+ if (
+ isinstance(self.controlnet, ControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
+ ):
+ if not isinstance(controlnet_conditioning_scale, float):
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
+ elif (
+ isinstance(self.controlnet, MultiControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+ ):
+ if isinstance(controlnet_conditioning_scale, list):
+ if any(isinstance(i, list) for i in controlnet_conditioning_scale):
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
+ elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
+ self.controlnet.nets
+ ):
+ raise ValueError(
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
+ " the same length as the number of controlnets"
+ )
+ else:
+ assert False
+
+ # adapter checks
+ if isinstance(self.adapter, T2IAdapter) or is_compiled and isinstance(self.adapter._orig_mod, T2IAdapter):
+ self.check_image(adapter_image, prompt, prompt_embeds)
+ elif (
+ isinstance(self.adapter, MultiAdapter) or is_compiled and isinstance(self.adapter._orig_mod, MultiAdapter)
+ ):
+ if not isinstance(adapter_image, list):
+ raise TypeError("For multiple adapters: `adapter_image` must be type `list`")
+
+ # When `image` is a nested list:
+ # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
+ elif any(isinstance(i, list) for i in adapter_image):
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
+ elif len(adapter_image) != len(self.adapter.adapters):
+ raise ValueError(
+ f"For multiple adapters: `image` must have the same length as the number of adapters, but got {len(adapter_image)} images and {len(self.adapters.nets)} Adapters."
+ )
+
+ for image_ in adapter_image:
+ self.check_image(image_, prompt, prompt_embeds)
+ else:
+ assert False
+
+ # Check `adapter_conditioning_scale`
+ if isinstance(self.adapter, T2IAdapter) or is_compiled and isinstance(self.adapter._orig_mod, T2IAdapter):
+ if not isinstance(adapter_conditioning_scale, float):
+ raise TypeError("For single adapter: `adapter_conditioning_scale` must be type `float`.")
+ elif (
+ isinstance(self.adapter, MultiAdapter) or is_compiled and isinstance(self.adapter._orig_mod, MultiAdapter)
+ ):
+ if isinstance(adapter_conditioning_scale, list):
+ if any(isinstance(i, list) for i in adapter_conditioning_scale):
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
+ elif isinstance(adapter_conditioning_scale, list) and len(adapter_conditioning_scale) != len(
+ self.adapter.adapters
+ ):
+ raise ValueError(
+ "For multiple adapters: When `adapter_conditioning_scale` is specified as `list`, it must have"
+ " the same length as the number of adapters"
+ )
+ else:
+ assert False
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids
+ def _get_add_time_ids(
+ self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
+ ):
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
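+ # SDXL micro-conditioning: (original_h, original_w, crop_top, crop_left, target_h, target_w),
+ # e.g. [1024, 1024, 0, 0, 1024, 1024] for an uncropped 1024x1024 generation.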
+
+ passed_add_embed_dim = (
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
+ )
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+ if expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ return add_time_ids
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+ def upcast_vae(self):
+ dtype = self.vae.dtype
+ self.vae.to(dtype=torch.float32)
+ use_torch_2_0_or_xformers = isinstance(
+ self.vae.decoder.mid_block.attentions[0].processor,
+ (
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+ LoRAXFormersAttnProcessor,
+ LoRAAttnProcessor2_0,
+ ),
+ )
+ # if xformers or torch_2_0 is used attention block does not need
+ # to be in float32 which can save lots of memory
+ if use_torch_2_0_or_xformers:
+ self.vae.post_quant_conv.to(dtype)
+ self.vae.decoder.conv_in.to(dtype)
+ self.vae.decoder.mid_block.to(dtype)
+
+ # Copied from diffusers.pipelines.t2i_adapter.pipeline_stable_diffusion_adapter.StableDiffusionAdapterPipeline._default_height_width
+ def _default_height_width(self, height, width, image):
+ # NOTE: It is possible that a list of images have different
+ # dimensions for each image, so just checking the first image
+ # is not _exactly_ correct, but it is simple.
+ while isinstance(image, list):
+ image = image[0]
+
+ if height is None:
+ if isinstance(image, PIL.Image.Image):
+ height = image.height
+ elif isinstance(image, torch.Tensor):
+ height = image.shape[-2]
+
+ # round down to nearest multiple of `self.adapter.downscale_factor`
+ height = (height // self.adapter.downscale_factor) * self.adapter.downscale_factor
+
+ if width is None:
+ if isinstance(image, PIL.Image.Image):
+ width = image.width
+ elif isinstance(image, torch.Tensor):
+ width = image.shape[-1]
+
+ # round down to nearest multiple of `self.adapter.downscale_factor`
+ width = (width // self.adapter.downscale_factor) * self.adapter.downscale_factor
+
+ return height, width
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
+ r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
+
+ The suffixes after the scaling factors represent the stages where they are being applied.
+
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
+ that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+ Args:
+ s1 (`float`):
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+ mitigate "oversmoothing effect" in the enhanced denoising process.
+ s2 (`float`):
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+ mitigate "oversmoothing effect" in the enhanced denoising process.
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+ """
+ if not hasattr(self, "unet"):
+ raise ValueError("The pipeline must have `unet` for using FreeU.")
+ self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
+ def disable_freeu(self):
+ """Disables the FreeU mechanism if enabled."""
+ self.unet.disable_freeu()
+
+ def prepare_control_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+ image_batch_size = image.shape[0]
+
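+ # A single conditioning image is broadcast to the whole batch; otherwise each image is repeated once per generated sample.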
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ image = torch.cat([image] * 2)
+
+ return image
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ adapter_image: PipelineImageInput = None,
+ control_image: PipelineImageInput = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Optional[Tuple[int, int]] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Optional[Tuple[int, int]] = None,
+ negative_original_size: Optional[Tuple[int, int]] = None,
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+ negative_target_size: Optional[Tuple[int, int]] = None,
+ adapter_conditioning_scale: Union[float, List[float]] = 1.0,
+ adapter_conditioning_factor: float = 1.0,
+ clip_skip: Optional[int] = None,
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+ guess_mode: bool = False,
+ control_guidance_start: Union[float, List[float]] = 0.0,
+ control_guidance_end: Union[float, List[float]] = 1.0,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ adapter_image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]` or `List[List[PIL.Image.Image]]`):
+ The Adapter input condition. The Adapter uses this input condition to generate guidance for the UNet. If the
+ type is specified as `torch.FloatTensor`, it is passed to the Adapter as is. `PIL.Image.Image` can also be
+ accepted as an image. The control image is automatically resized to fit the output image.
+ control_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
+ `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+ The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
+ specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
+ accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
+ and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in
+ `init`, images must be passed as a list such that each element of the list can be correctly batched for
+ input to a single ControlNet.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ denoising_end (`float`, *optional*):
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a target image resolution. It should be the same
+ as `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added to the
+ residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set the
+ corresponding scale as a list.
+ adapter_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The outputs of the adapter are multiplied by `adapter_conditioning_scale` before they are added to the
+ residual in the original unet. If multiple adapters are specified in init, you can set the
+ corresponding scale as a list.
+ adapter_conditioning_factor (`float`, *optional*, defaults to 1.0):
+ The fraction of timesteps for which adapter should be applied. If `adapter_conditioning_factor` is
+ `0.0`, adapter is not applied at all. If `adapter_conditioning_factor` is `1.0`, adapter is applied for
+ all timesteps. If `adapter_conditioning_factor` is `0.5`, adapter is applied for half of the timesteps.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+ adapter = self.adapter._orig_mod if is_compiled_module(self.adapter) else self.adapter
+
+ # 0. Default height and width to unet
+
+ height, width = self._default_height_width(height, width, adapter_image)
+ device = self._execution_device
+
+ if isinstance(adapter, MultiAdapter):
+ adapter_input = []
+
+ for one_image in adapter_image:
+ one_image = _preprocess_adapter_image(one_image, height, width)
+ one_image = one_image.to(device=device, dtype=adapter.dtype)
+ adapter_input.append(one_image)
+ else:
+ adapter_input = _preprocess_adapter_image(adapter_image, height, width)
+ adapter_input = adapter_input.to(device=device, dtype=adapter.dtype)
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 0.1 align format for control guidance
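+ # Scalars are broadcast to lists so that every ControlNet gets its own (start, end) guidance window.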
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
+ control_guidance_start, control_guidance_end = (
+ mult * [control_guidance_start],
+ mult * [control_guidance_end],
+ )
+
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
+ if isinstance(adapter, MultiAdapter) and isinstance(adapter_conditioning_scale, float):
+ adapter_conditioning_scale = [adapter_conditioning_scale] * len(adapter.adapters)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ )
+
+ self.check_conditions(
+ prompt,
+ prompt_embeds,
+ adapter_image,
+ control_image,
+ adapter_conditioning_scale,
+ controlnet_conditioning_scale,
+ control_guidance_start,
+ control_guidance_end,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ clip_skip=clip_skip,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Prepare added time ids & embeddings & adapter features
+ if isinstance(adapter, MultiAdapter):
+ adapter_state = adapter(adapter_input, adapter_conditioning_scale)
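+ # MultiAdapter applies the per-adapter conditioning scales internally, so this loop leaves the states unchanged.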
+ for k, v in enumerate(adapter_state):
+ adapter_state[k] = v
+ else:
+ adapter_state = adapter(adapter_input)
+ for k, v in enumerate(adapter_state):
+ adapter_state[k] = v * adapter_conditioning_scale
+ if num_images_per_prompt > 1:
+ for k, v in enumerate(adapter_state):
+ adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1)
+ if do_classifier_free_guidance:
+ for k, v in enumerate(adapter_state):
+ adapter_state[k] = torch.cat([v] * 2, dim=0)
+
+ # 7.2 Prepare control images
+ if isinstance(controlnet, ControlNetModel):
+ control_image = self.prepare_control_image(
+ image=control_image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+ elif isinstance(controlnet, MultiControlNetModel):
+ control_images = []
+
+ for control_image_ in control_image:
+ control_image_ = self.prepare_control_image(
+ image=control_image_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+
+ control_images.append(control_image_)
+
+ control_image = control_images
+ else:
+ raise ValueError(f"{controlnet.__class__} is not supported.")
+
+ # 8.2 Create tensor stating which controlnets to keep
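+ # Each entry is 1.0 while step i falls inside the corresponding [start, end] guidance window and 0.0 otherwise,
+ # e.g. with start=0.0, end=0.5 and 20 steps the ControlNet is only active for steps 0-9.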
+ controlnet_keep = []
+ for i in range(len(timesteps)):
+ keeps = [
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+ for s, e in zip(control_guidance_start, control_guidance_end)
+ ]
+ if isinstance(self.controlnet, MultiControlNetModel):
+ controlnet_keep.append(keeps)
+ else:
+ controlnet_keep.append(keeps[0])
+
+ add_text_embeds = pooled_prompt_embeds
+ if self.text_encoder_2 is None:
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+ else:
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+ add_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ if negative_original_size is not None and negative_target_size is not None:
+ negative_add_time_ids = self._get_add_time_ids(
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ else:
+ negative_add_time_ids = add_time_ids
+
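+ # For classifier-free guidance the unconditional embeddings are prepended so both branches run in a single UNet forward pass.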
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+ # 8. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ # 7.1 Apply denoising_end
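+ # Only the first `denoising_end` fraction of the schedule is run; the remaining (smaller) timesteps are dropped.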
+ if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+
+ if i < int(num_inference_steps * adapter_conditioning_factor):
+ down_intrablock_additional_residuals = [state.clone() for state in adapter_state]
+ else:
+ down_intrablock_additional_residuals = None
+
+ # ----------- ControlNet
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input_controlnet = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ # concat latents, mask, masked_image_latents in the channel dimension
+ latent_model_input_controlnet = self.scheduler.scale_model_input(latent_model_input_controlnet, t)
+
+ # controlnet(s) inference
+ if guess_mode and do_classifier_free_guidance:
+ # Infer ControlNet only for the conditional batch.
+ control_model_input = latents
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+ controlnet_added_cond_kwargs = {
+ "text_embeds": add_text_embeds.chunk(2)[1],
+ "time_ids": add_time_ids.chunk(2)[1],
+ }
+ else:
+ control_model_input = latent_model_input_controlnet
+ controlnet_prompt_embeds = prompt_embeds
+ controlnet_added_cond_kwargs = added_cond_kwargs
+
+ if isinstance(controlnet_keep[i], list):
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+ else:
+ controlnet_cond_scale = controlnet_conditioning_scale
+ if isinstance(controlnet_cond_scale, list):
+ controlnet_cond_scale = controlnet_cond_scale[0]
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
+ control_model_input,
+ t,
+ encoder_hidden_states=controlnet_prompt_embeds,
+ controlnet_cond=control_image,
+ conditioning_scale=cond_scale,
+ guess_mode=guess_mode,
+ added_cond_kwargs=controlnet_added_cond_kwargs,
+ return_dict=False,
+ )
+
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ down_intrablock_additional_residuals=down_intrablock_additional_residuals, # t2iadapter
+ down_block_additional_residuals=down_block_res_samples, # controlnet
+ mid_block_additional_residual=mid_block_res_sample, # controlnet
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+ else:
+ image = latents
+ return StableDiffusionXLPipelineOutput(images=image)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)
diff --git a/diffusers/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py b/diffusers/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc612edbc20ebeccea9a3f8366acb0bc17c305d4
--- /dev/null
+++ b/diffusers/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py
@@ -0,0 +1,1896 @@
+# Copyright 2023 Jake Babbidge, TencentARC and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# ignore the entire file for precommit
+# type: ignore
+
+import inspect
+from collections.abc import Callable
+from typing import Any, List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+import torch.nn.functional as F
+from transformers import (
+ CLIPTextModel,
+ CLIPTextModelWithProjection,
+ CLIPTokenizer,
+)
+
+from diffusers import DiffusionPipeline
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import (
+ FromSingleFileMixin,
+ LoraLoaderMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ TextualInversionLoaderMixin,
+)
+from diffusers.models import (
+ AutoencoderKL,
+ ControlNetModel,
+ MultiAdapter,
+ T2IAdapter,
+ UNet2DConditionModel,
+)
+from diffusers.models.attention_processor import (
+ AttnProcessor2_0,
+ LoRAAttnProcessor2_0,
+ LoRAXFormersAttnProcessor,
+ XFormersAttnProcessor,
+)
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
+from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ PIL_INTERPOLATION,
+ USE_PEFT_BACKEND,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import ControlNetModel, DiffusionPipeline, T2IAdapter
+ >>> from diffusers.utils import load_image
+ >>> from PIL import Image
+ >>> from controlnet_aux.midas import MidasDetector
+
+ >>> adapter = T2IAdapter.from_pretrained(
+ ... "TencentARC/t2i-adapter-sketch-sdxl-1.0", torch_dtype=torch.float16, variant="fp16"
+ ... ).to("cuda")
+
+ >>> controlnet = ControlNetModel.from_pretrained(
+ ... "diffusers/controlnet-depth-sdxl-1.0",
+ ... torch_dtype=torch.float16,
+ ... variant="fp16",
+ ... use_safetensors=True
+ ... ).to("cuda")
+
+ >>> pipe = DiffusionPipeline.from_pretrained(
+ ... "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
+ ... torch_dtype=torch.float16,
+ ... variant="fp16",
+ ... use_safetensors=True,
+ ... custom_pipeline="stable_diffusion_xl_adapter_controlnet_inpaint",
+ ... adapter=adapter,
+ ... controlnet=controlnet,
+ ... ).to("cuda")
+
+ >>> prompt = "a tiger sitting on a park bench"
+ >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+ >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+ >>> image = load_image(img_url).resize((1024, 1024))
+ >>> mask_image = load_image(mask_url).resize((1024, 1024))
+
+ >>> midas_depth = MidasDetector.from_pretrained(
+ ... "valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
+ ... ).to("cuda")
+
+ >>> depth_image = midas_depth(
+ ... image, detect_resolution=512, image_resolution=1024
+ ... )
+
+ >>> strength = 0.4
+
+ >>> generator = torch.manual_seed(42)
+
+ >>> result_image = pipe(
+ ... image=image,
+ ... mask_image=mask_image,
+ ... adapter_image=depth_image,
+ ... control_image=depth_image,
+ ... controlnet_conditioning_scale=strength,
+ ... adapter_conditioning_scale=strength,
+ ... strength=0.7,
+ ... generator=generator,
+ ... prompt=prompt,
+ ... negative_prompt="extra digit, fewer digits, cropped, worst quality, low quality",
+ ... num_inference_steps=50
+ ... ).images[0]
+ ```
+"""
+
+
+def _preprocess_adapter_image(image, height, width):
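+ # PIL inputs (single image or list) are resized and converted to a float tensor in [0, 1] with shape
+ # (batch, channels, height, width); tensor inputs are returned as-is and lists of tensors are stacked.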
+ if isinstance(image, torch.Tensor):
+ return image
+ elif isinstance(image, PIL.Image.Image):
+ image = [image]
+
+ if isinstance(image[0], PIL.Image.Image):
+ image = [np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])) for i in image]
+ image = [
+ i[None, ..., None] if i.ndim == 2 else i[None, ...] for i in image
+ ] # expand [h, w] or [h, w, c] to [b, h, w, c]
+ image = np.concatenate(image, axis=0)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ elif isinstance(image[0], torch.Tensor):
+ if image[0].ndim == 3:
+ image = torch.stack(image, dim=0)
+ elif image[0].ndim == 4:
+ image = torch.cat(image, dim=0)
+ else:
+ raise ValueError(
+ f"Invalid image tensor! Expecting image tensor with 3 or 4 dimension, but recive: {image[0].ndim}"
+ )
+ return image
+
+
+def mask_pil_to_torch(mask, height, width):
+ # preprocess mask
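+ # Converts a PIL image, numpy array, or list of either into a float tensor of shape (batch, 1, height, width); PIL inputs are scaled to [0, 1].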
+ if isinstance(mask, (PIL.Image.Image, np.ndarray)):
+ mask = [mask]
+
+ if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
+ mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
+ mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
+ mask = mask.astype(np.float32) / 255.0
+ elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
+ mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
+
+ mask = torch.from_numpy(mask)
+ return mask
+
+
+def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
+ """
+ Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
+ converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
+ ``image`` and ``1`` for the ``mask``.
+
+ The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
+ binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
+
+ Args:
+ image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
+ It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
+ ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
+ mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
+ It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
+ ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
+
+
+ Raises:
+ ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range.
+ ValueError: ``torch.Tensor`` masks should be in the ``[0, 1]`` range.
+ ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
+ TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not (or the other way around).
+
+ Returns:
+ tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
+ dimensions: ``batch x channels x height x width``.
+ """
+
+ # TODO(Yiyi) - need to clean this up later
+ if image is None:
+ raise ValueError("`image` input cannot be undefined.")
+
+ if mask is None:
+ raise ValueError("`mask_image` input cannot be undefined.")
+
+ if isinstance(image, torch.Tensor):
+ if not isinstance(mask, torch.Tensor):
+ mask = mask_pil_to_torch(mask, height, width)
+
+ if image.ndim == 3:
+ image = image.unsqueeze(0)
+
+ # Batch and add channel dim for single mask
+ if mask.ndim == 2:
+ mask = mask.unsqueeze(0).unsqueeze(0)
+
+ # Batch single mask or add channel dim
+ if mask.ndim == 3:
+ # Single batched mask, no channel dim or single mask not batched but channel dim
+ if mask.shape[0] == 1:
+ mask = mask.unsqueeze(0)
+
+ # Batched masks no channel dim
+ else:
+ mask = mask.unsqueeze(1)
+
+ assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
+ # assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
+ assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
+
+ # Check image is in [-1, 1]
+ # if image.min() < -1 or image.max() > 1:
+ # raise ValueError("Image should be in [-1, 1] range")
+
+ # Check mask is in [0, 1]
+ if mask.min() < 0 or mask.max() > 1:
+ raise ValueError("Mask should be in [0, 1] range")
+
+ # Binarize mask
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+
+ # Image as float32
+ image = image.to(dtype=torch.float32)
+ elif isinstance(mask, torch.Tensor):
+ raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
+ else:
+ # preprocess image
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
+ image = [image]
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+ # resize all images w.r.t. the passed height and width
+ image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+ image = np.concatenate([i[None, :] for i in image], axis=0)
+
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ mask = mask_pil_to_torch(mask, height, width)
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+
+ if image.shape[1] == 4:
+ # images are already in latent space and thus can't be masked;
+ # set masked_image to None and assume the checkpoint is not an inpainting checkpoint.
+ # TODO(Yiyi) - need to clean this up later
+ masked_image = None
+ else:
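+ # Zero out the region to be inpainted (mask >= 0.5) and keep the rest of the image.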
+ masked_image = image * (mask < 0.5)
+
+ # n.b. ensure backwards compatibility as old function does not return image
+ if return_image:
+ return mask, masked_image, image
+
+ return mask, masked_image
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+class StableDiffusionXLControlNetAdapterInpaintPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
+ r"""
+ Pipeline for text-guided image inpainting using Stable Diffusion XL augmented with T2I-Adapter and ControlNet
+ https://arxiv.org/abs/2302.08453
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ adapter ([`T2IAdapter`] or [`MultiAdapter`] or `List[T2IAdapter]`):
+ Provides additional conditioning to the unet during the denoising process. If you set multiple Adapters as a
+ list, the outputs from each Adapter are added together to create one combined additional conditioning.
+ adapter_weights (`List[float]`, *optional*, defaults to None):
+ List of floats representing the weight by which each adapter's output will be multiplied before adding them
+ together.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ controlnet ([`ControlNetModel`] or `List[ControlNetModel]` or [`MultiControlNetModel`]):
+ Provides additional conditioning to the unet during the denoising process. If you set multiple
+ ControlNets as a list, the outputs from each ControlNet are added together to create one combined
+ additional conditioning.
+ requires_aesthetics_score (`bool`, *optional*, defaults to `False`):
+ Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the
+ config of `stabilityai/stable-diffusion-xl-refiner-1-0`.
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `True`):
+ Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
+ `stabilityai/stable-diffusion-xl-base-1-0`.
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ text_encoder_2: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ tokenizer_2: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ adapter: Union[T2IAdapter, MultiAdapter],
+ controlnet: Union[ControlNetModel, MultiControlNetModel],
+ scheduler: KarrasDiffusionSchedulers,
+ requires_aesthetics_score: bool = False,
+ force_zeros_for_empty_prompt: bool = True,
+ ):
+ super().__init__()
+
+ if isinstance(controlnet, (list, tuple)):
+ controlnet = MultiControlNetModel(controlnet)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ unet=unet,
+ adapter=adapter,
+ controlnet=controlnet,
+ scheduler=scheduler,
+ )
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+ self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.control_image_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
+ )
+ self.default_sample_size = self.unet.config.sample_size
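+
+ # Illustrative construction sketch (the checkpoint ids below are assumptions, not pinned by this repo):
+ #
+ #     controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch.float16)
+ #     adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-sketch-sdxl-1.0", torch_dtype=torch.float16)
+ #     pipe = StableDiffusionXLControlNetAdapterInpaintPipeline.from_pretrained(
+ #         "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, adapter=adapter, torch_dtype=torch.float16
+ #     )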
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
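+
+ # Usage note: slicing/tiling only affect the VAE encode/decode passes. A typical pattern is to call
+ # `pipe.enable_vae_tiling()` (or `enable_vae_slicing()`) once before generation when decoding large
+ # images would otherwise exhaust GPU memory, and to disable it again afterwards.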
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: str,
+ prompt_2: Optional[str] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[str] = None,
+ negative_prompt_2: Optional[str] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Define tokenizers and text encoders
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+ text_encoders = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # textual inversion: process multi-vector tokens if necessary
+ prompt_embeds_list = []
+ prompts = [prompt, prompt_2]
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+
+ # We are always only interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ if clip_skip is None:
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ else:
+ # "2" because SDXL always indexes from the penultimate layer.
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
+
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+ # get unconditional embeddings for classifier free guidance
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
+
+ # normalize str to list
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt_2 = (
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+ )
+
+ uncond_tokens: List[str]
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = [negative_prompt, negative_prompt_2]
+
+ negative_prompt_embeds_list = []
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = tokenizer(
+ negative_prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ negative_prompt_embeds = text_encoder(
+ uncond_input.input_ids.to(device),
+ output_hidden_states=True,
+ )
+ # We are always only interested in the pooled output of the final text encoder
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+ if self.text_encoder_2 is not None:
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ if self.text_encoder_2 is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+ if do_classifier_free_guidance:
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
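+
+ # Illustrative call sketch (argument values are placeholders):
+ #
+ #     (prompt_embeds, negative_prompt_embeds,
+ #      pooled_prompt_embeds, negative_pooled_prompt_embeds) = self.encode_prompt(
+ #         prompt="a photo of a cat", device=device, num_images_per_prompt=1, do_classifier_free_guidance=True
+ #     )
+ #
+ # `prompt_embeds` concatenates the hidden states of both text encoders along the feature dimension,
+ # while the pooled embeddings come from the final (second) text encoder.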
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
+ def check_image(self, image, prompt, prompt_embeds):
+ image_is_pil = isinstance(image, PIL.Image.Image)
+ image_is_tensor = isinstance(image, torch.Tensor)
+ image_is_np = isinstance(image, np.ndarray)
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
+ image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
+
+ if (
+ not image_is_pil
+ and not image_is_tensor
+ and not image_is_np
+ and not image_is_pil_list
+ and not image_is_tensor_list
+ and not image_is_np_list
+ ):
+ raise TypeError(
+ f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
+ )
+
+ if image_is_pil:
+ image_batch_size = 1
+ else:
+ image_batch_size = len(image)
+
+ if prompt is not None and isinstance(prompt, str):
+ prompt_batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ prompt_batch_size = len(prompt)
+ elif prompt_embeds is not None:
+ prompt_batch_size = prompt_embeds.shape[0]
+
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
+ raise ValueError(
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ negative_prompt_2=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ def check_conditions(
+ self,
+ prompt,
+ prompt_embeds,
+ adapter_image,
+ control_image,
+ adapter_conditioning_scale,
+ controlnet_conditioning_scale,
+ control_guidance_start,
+ control_guidance_end,
+ ):
+ # controlnet checks
+ if not isinstance(control_guidance_start, (tuple, list)):
+ control_guidance_start = [control_guidance_start]
+
+ if not isinstance(control_guidance_end, (tuple, list)):
+ control_guidance_end = [control_guidance_end]
+
+ if len(control_guidance_start) != len(control_guidance_end):
+ raise ValueError(
+ f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
+ )
+
+ if isinstance(self.controlnet, MultiControlNetModel):
+ if len(control_guidance_start) != len(self.controlnet.nets):
+ raise ValueError(
+ f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
+ )
+
+ for start, end in zip(control_guidance_start, control_guidance_end):
+ if start >= end:
+ raise ValueError(
+ f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
+ )
+ if start < 0.0:
+ raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
+ if end > 1.0:
+ raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
+
+ # Check controlnet `image`
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
+ self.controlnet, torch._dynamo.eval_frame.OptimizedModule
+ )
+ if (
+ isinstance(self.controlnet, ControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
+ ):
+ self.check_image(control_image, prompt, prompt_embeds)
+ elif (
+ isinstance(self.controlnet, MultiControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+ ):
+ if not isinstance(control_image, list):
+ raise TypeError("For multiple controlnets: `control_image` must be type `list`")
+
+ # When `image` is a nested list:
+ # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
+ elif any(isinstance(i, list) for i in control_image):
+ raise ValueError("A single batch of multiple conditionings is not supported at the moment.")
+ elif len(control_image) != len(self.controlnet.nets):
+ raise ValueError(
+ f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(control_image)} images and {len(self.controlnet.nets)} ControlNets."
+ )
+
+ for image_ in control_image:
+ self.check_image(image_, prompt, prompt_embeds)
+ else:
+ assert False
+
+ # Check `controlnet_conditioning_scale`
+ if (
+ isinstance(self.controlnet, ControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
+ ):
+ if not isinstance(controlnet_conditioning_scale, float):
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
+ elif (
+ isinstance(self.controlnet, MultiControlNetModel)
+ or is_compiled
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+ ):
+ if isinstance(controlnet_conditioning_scale, list):
+ if any(isinstance(i, list) for i in controlnet_conditioning_scale):
+ raise ValueError("A single batch of multiple conditionings is not supported at the moment.")
+ elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
+ self.controlnet.nets
+ ):
+ raise ValueError(
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
+ " the same length as the number of controlnets"
+ )
+ else:
+ assert False
+
+ # adapter checks
+ if isinstance(self.adapter, T2IAdapter) or is_compiled and isinstance(self.adapter._orig_mod, T2IAdapter):
+ self.check_image(adapter_image, prompt, prompt_embeds)
+ elif (
+ isinstance(self.adapter, MultiAdapter) or is_compiled and isinstance(self.adapter._orig_mod, MultiAdapter)
+ ):
+ if not isinstance(adapter_image, list):
+ raise TypeError("For multiple adapters: `adapter_image` must be type `list`")
+
+ # When `image` is a nested list:
+ # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
+ elif any(isinstance(i, list) for i in adapter_image):
+ raise ValueError("A single batch of multiple conditionings is not supported at the moment.")
+ elif len(adapter_image) != len(self.adapter.adapters):
+ raise ValueError(
+ f"For multiple adapters: `image` must have the same length as the number of adapters, but got {len(adapter_image)} images and {len(self.adapter.adapters)} Adapters."
+ )
+
+ for image_ in adapter_image:
+ self.check_image(image_, prompt, prompt_embeds)
+ else:
+ assert False
+
+ # Check `adapter_conditioning_scale`
+ if isinstance(self.adapter, T2IAdapter) or is_compiled and isinstance(self.adapter._orig_mod, T2IAdapter):
+ if not isinstance(adapter_conditioning_scale, float):
+ raise TypeError("For single adapter: `adapter_conditioning_scale` must be type `float`.")
+ elif (
+ isinstance(self.adapter, MultiAdapter) or is_compiled and isinstance(self.adapter._orig_mod, MultiAdapter)
+ ):
+ if isinstance(adapter_conditioning_scale, list):
+ if any(isinstance(i, list) for i in adapter_conditioning_scale):
+ raise ValueError("A single batch of multiple conditionings is not supported at the moment.")
+ elif isinstance(adapter_conditioning_scale, list) and len(adapter_conditioning_scale) != len(
+ self.adapter.adapters
+ ):
+ raise ValueError(
+ "For multiple adapters: When `adapter_conditioning_scale` is specified as `list`, it must have"
+ " the same length as the number of adapters"
+ )
+ else:
+ assert False
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ image=None,
+ timestep=None,
+ is_strength_max=True,
+ add_noise=True,
+ return_noise=False,
+ return_image_latents=False,
+ ):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if (image is None or timestep is None) and not is_strength_max:
+ raise ValueError(
+ "Since strength < 1, initial latents are to be initialised as a combination of Image + Noise. "
+ "However, either the image or the noise timestep has not been provided."
+ )
+
+ if image.shape[1] == 4:
+ image_latents = image.to(device=device, dtype=dtype)
+ elif return_image_latents or (latents is None and not is_strength_max):
+ image = image.to(device=device, dtype=dtype)
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+
+ image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+
+ if latents is None and add_noise:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ # if strength is 1. then initialise the latents to noise, else initial to image + noise
+ latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
+ # if pure noise then scale the initial latents by the Scheduler's init sigma
+ latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
+ elif add_noise:
+ noise = latents.to(device)
+ latents = noise * self.scheduler.init_noise_sigma
+ else:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = image_latents.to(device)
+
+ outputs = (latents,)
+
+ if return_noise:
+ outputs += (noise,)
+
+ if return_image_latents:
+ outputs += (image_latents,)
+
+ return outputs
+
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ dtype = image.dtype
+ if self.vae.config.force_upcast:
+ image = image.float()
+ self.vae.to(dtype=torch.float32)
+
+ if isinstance(generator, list):
+ image_latents = [
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
+
+ if self.vae.config.force_upcast:
+ self.vae.to(dtype)
+
+ image_latents = image_latents.to(dtype)
+ image_latents = self.vae.config.scaling_factor * image_latents
+
+ return image_latents
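+
+ # Note: the returned latents are already multiplied by `vae.config.scaling_factor` (0.13025 for the
+ # SDXL VAE), so `prepare_latents` can mix them with scheduler noise without any further scaling.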
+
+ def prepare_mask_latents(
+ self,
+ mask,
+ masked_image,
+ batch_size,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ do_classifier_free_guidance,
+ ):
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ mask = torch.nn.functional.interpolate(
+ mask,
+ size=(
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor,
+ ),
+ )
+ mask = mask.to(device=device, dtype=dtype)
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ if mask.shape[0] < batch_size:
+ if not batch_size % mask.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+
+ masked_image_latents = None
+ if masked_image is not None:
+ masked_image = masked_image.to(device=device, dtype=dtype)
+ masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(
+ batch_size // masked_image_latents.shape[0], 1, 1, 1
+ )
+
+ masked_image_latents = (
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+ )
+
+ # aligning device to prevent device errors when concatenating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+
+ return mask, masked_image_latents
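+
+ # Shape note: for a 9-channel inpainting UNet the per-step model input is the concatenation of the
+ # 4 latent channels, the 1-channel down-sampled mask and the 4 masked-image latent channels
+ # (4 + 1 + 4 = 9). For a standard 4-channel UNet only the latents are fed to the UNet and the mask
+ # is instead used to blend the original image latents back in during denoising.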
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
+ # get the original timestep using init_timestep
+ if denoising_start is None:
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+ t_start = max(num_inference_steps - init_timestep, 0)
+ else:
+ t_start = 0
+
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+ # Strength is irrelevant if we directly request a timestep to start at;
+ # that is, strength is determined by the denoising_start instead.
+ if denoising_start is not None:
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (denoising_start * self.scheduler.config.num_train_timesteps)
+ )
+ )
+
+ num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
+ if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
+ # if the scheduler is a 2nd order scheduler we might have to do +1
+ # because `num_inference_steps` might be even given that every timestep
+ # (except the highest one) is duplicated. If `num_inference_steps` is even it would
+ # mean that we cut the timesteps in the middle of the denoising step
+ # (between the 1st and 2nd derivative) which leads to incorrect results. By adding 1
+ # we ensure that the denoising process always ends after the 2nd derivative step of the scheduler
+ num_inference_steps = num_inference_steps + 1
+
+ # because t_n+1 >= t_n, we slice the timesteps starting from the end
+ timesteps = timesteps[-num_inference_steps:]
+ return timesteps, num_inference_steps
+
+ return timesteps, num_inference_steps - t_start
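+
+ # Worked example (first-order scheduler, no `denoising_start`): with `num_inference_steps=50` and
+ # `strength=0.6`, `init_timestep = min(int(50 * 0.6), 50) = 30` and `t_start = 50 - 30 = 20`, so the
+ # last 30 of the 50 scheduled timesteps are used and `(timesteps, 30)` is returned.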
+
+ def _get_add_time_ids(
+ self,
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ aesthetic_score,
+ negative_aesthetic_score,
+ dtype,
+ text_encoder_projection_dim=None,
+ ):
+ if self.config.requires_aesthetics_score:
+ add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
+ add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,))
+ else:
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_neg_time_ids = list(original_size + crops_coords_top_left + target_size)
+
+ passed_add_embed_dim = (
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
+ )
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+ if (
+ expected_add_embed_dim > passed_add_embed_dim
+ and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim
+ ):
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} are correctly used by the model."
+ )
+ elif (
+ expected_add_embed_dim < passed_add_embed_dim
+ and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim
+ ):
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
+ )
+ elif expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)
+
+ return add_time_ids, add_neg_time_ids
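+
+ # Composition note: without `requires_aesthetics_score` the ids are
+ # `original_size + crops_coords_top_left + target_size` (6 values); with it they are
+ # `original_size + crops_coords_top_left + (aesthetic_score,)` (5 values). Each value is embedded with
+ # `addition_time_embed_dim` (256 for SDXL) and concatenated with the pooled text projection, which is
+ # what the `add_embedding.linear_1.in_features` check above validates (e.g. 6 * 256 + 1280 = 2816 for
+ # the SDXL base UNet).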
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+ def upcast_vae(self):
+ dtype = self.vae.dtype
+ self.vae.to(dtype=torch.float32)
+ use_torch_2_0_or_xformers = isinstance(
+ self.vae.decoder.mid_block.attentions[0].processor,
+ (
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+ LoRAXFormersAttnProcessor,
+ LoRAAttnProcessor2_0,
+ ),
+ )
+ # if xformers or torch_2_0 is used attention block does not need
+ # to be in float32 which can save lots of memory
+ if use_torch_2_0_or_xformers:
+ self.vae.post_quant_conv.to(dtype)
+ self.vae.decoder.conv_in.to(dtype)
+ self.vae.decoder.mid_block.to(dtype)
+
+ # Copied from diffusers.pipelines.t2i_adapter.pipeline_stable_diffusion_adapter.StableDiffusionAdapterPipeline._default_height_width
+ def _default_height_width(self, height, width, image):
+ # NOTE: It is possible that a list of images have different
+ # dimensions for each image, so just checking the first image
+ # is not _exactly_ correct, but it is simple.
+ while isinstance(image, list):
+ image = image[0]
+
+ if height is None:
+ if isinstance(image, PIL.Image.Image):
+ height = image.height
+ elif isinstance(image, torch.Tensor):
+ height = image.shape[-2]
+
+ # round down to nearest multiple of `self.adapter.downscale_factor`
+ height = (height // self.adapter.downscale_factor) * self.adapter.downscale_factor
+
+ if width is None:
+ if isinstance(image, PIL.Image.Image):
+ width = image.width
+ elif isinstance(image, torch.Tensor):
+ width = image.shape[-1]
+
+ # round down to nearest multiple of `self.adapter.downscale_factor`
+ width = (width // self.adapter.downscale_factor) * self.adapter.downscale_factor
+
+ return height, width
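+
+ # Example: with a hypothetical `downscale_factor` of 16, a 600x800 adapter image yields defaults of
+ # height = (600 // 16) * 16 = 592 and width = 800 (already a multiple of 16).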
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
+ r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
+
+ The suffixes after the scaling factors represent the stages where they are being applied.
+
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
+ that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+ Args:
+ s1 (`float`):
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+ mitigate "oversmoothing effect" in the enhanced denoising process.
+ s2 (`float`):
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+ mitigate "oversmoothing effect" in the enhanced denoising process.
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+ """
+ if not hasattr(self, "unet"):
+ raise ValueError("The pipeline must have `unet` for using FreeU.")
+ self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
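+
+ # Usage sketch: for SDXL the FreeU authors suggest values along the lines of
+ # `pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.3, b2=1.4)`; see the linked repository for the values
+ # currently recommended per model family.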
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
+ def disable_freeu(self):
+ """Disables the FreeU mechanism if enabled."""
+ self.unet.disable_freeu()
+
+ def prepare_control_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ image = torch.cat([image] * 2)
+
+ return image
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Optional[Union[str, list[str]]] = None,
+ prompt_2: Optional[Union[str, list[str]]] = None,
+ image: Optional[Union[torch.Tensor, PIL.Image.Image]] = None,
+ mask_image: Optional[Union[torch.Tensor, PIL.Image.Image]] = None,
+ adapter_image: PipelineImageInput = None,
+ control_image: PipelineImageInput = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ strength: float = 0.9999,
+ num_inference_steps: int = 50,
+ denoising_start: Optional[float] = None,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[Union[str, list[str]]] = None,
+ negative_prompt_2: Optional[Union[str, list[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, list[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Optional[tuple[int, int]] = None,
+ crops_coords_top_left: Optional[tuple[int, int]] = (0, 0),
+ target_size: Optional[tuple[int, int]] = None,
+ adapter_conditioning_scale: Optional[Union[float, list[float]]] = 1.0,
+ cond_tau: float = 1.0,
+ aesthetic_score: float = 6.0,
+ negative_aesthetic_score: float = 2.5,
+ controlnet_conditioning_scale: Union[float, list[float]] = 1.0,
+ guess_mode: bool = False,
+ control_guidance_start: Union[float, list[float]] = 0.0,
+ control_guidance_end: Union[float, list[float]] = 1.0,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
+ be masked out with `mask_image` and repainted according to `prompt`.
+ mask_image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+ repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
+ to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
+ instead of 3, so the expected shape would be `(B, H, W, 1)`.
+ adapter_image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]` or `List[List[PIL.Image.Image]]`):
+ The Adapter input condition. Adapter uses this input condition to generate guidance to Unet. If the
+ type is specified as `torch.FloatTensor`, it is passed to the Adapter as is. `PIL.Image.Image` can also
+ be accepted as an image. The control image is automatically resized to fit the output image.
+ control_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
+ `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+ The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
+ specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
+ accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
+ and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in
+ `init`, images must be passed as a list such that each element of the list can be correctly batched for
+ input to a single ControlNet.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ strength (`float`, *optional*, defaults to 0.9999):
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+ essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ denoising_start (`float`, *optional*):
+ When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
+ bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
+ it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
+ strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
+ is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
+ denoising_end (`float`, *optional*):
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionAdapterPipelineOutput`]
+ instead of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+ `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
+ not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added to
+ the residual in the original unet. If multiple ControlNets are specified in init, you can set the
+ corresponding scale as a list.
+ adapter_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The outputs of the adapter are multiplied by `adapter_conditioning_scale` before they are added to the
+ residual in the original unet. If multiple adapters are specified in init, you can set the
+ corresponding scale as a list.
+ aesthetic_score (`float`, *optional*, defaults to 6.0):
+ Used to simulate an aesthetic score of the generated image by influencing the positive text condition.
+ Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ negative_aesthetic_score (`float`, *optional*, defaults to 2.5):
+ Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to
+ simulate an aesthetic score of the generated image by influencing the negative text condition.
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+ # 0. Default height and width to unet
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+ adapter = self.adapter._orig_mod if is_compiled_module(self.adapter) else self.adapter
+ height, width = self._default_height_width(height, width, adapter_image)
+ device = self._execution_device
+
+ adapter_input = _preprocess_adapter_image(adapter_image, height, width).to(device)
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 0.1 align format for control guidance
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
+ control_guidance_start, control_guidance_end = (
+ mult * [control_guidance_start],
+ mult * [control_guidance_end],
+ )
+
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
+ if isinstance(adapter, MultiAdapter) and isinstance(adapter_conditioning_scale, float):
+ adapter_conditioning_scale = [adapter_conditioning_scale] * len(adapter.adapters)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ )
+
+ self.check_conditions(
+ prompt,
+ prompt_embeds,
+ adapter_image,
+ control_image,
+ adapter_conditioning_scale,
+ controlnet_conditioning_scale,
+ control_guidance_start,
+ control_guidance_end,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ )
+
+ # 4. set timesteps
+ def denoising_value_valid(dnv):
+ return isinstance(dnv, float) and 0 < dnv < 1
+
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_timesteps(
+ num_inference_steps,
+ strength,
+ device,
+ denoising_start=denoising_start if denoising_value_valid(denoising_start) else None,
+ )
+ # check that number of inference steps is not < 1 - as this doesn't make sense
+ if num_inference_steps < 1:
+ raise ValueError(
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline "
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
+ is_strength_max = strength == 1.0
+
+ # 5. Preprocess mask and image - resizes image and mask w.r.t height and width
+ mask, masked_image, init_image = prepare_mask_and_masked_image(
+ image, mask_image, height, width, return_image=True
+ )
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.vae.config.latent_channels
+ num_channels_unet = self.unet.config.in_channels
+ return_image_latents = num_channels_unet == 4
+
+ add_noise = denoising_start is None
+ latents_outputs = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ image=init_image,
+ timestep=latent_timestep,
+ is_strength_max=is_strength_max,
+ add_noise=add_noise,
+ return_noise=True,
+ return_image_latents=return_image_latents,
+ )
+
+ if return_image_latents:
+ latents, noise, image_latents = latents_outputs
+ else:
+ latents, noise = latents_outputs
+
+ # 7. Prepare mask latent variables
+ mask, masked_image_latents = self.prepare_mask_latents(
+ mask,
+ masked_image,
+ batch_size * num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ do_classifier_free_guidance,
+ )
+
+ # 8. Check that sizes of mask, masked image and latents match
+ if num_channels_unet == 9:
+ # default case for runwayml/stable-diffusion-inpainting
+ num_channels_mask = mask.shape[1]
+ num_channels_masked_image = masked_image_latents.shape[1]
+ if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
+ raise ValueError(
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ " `pipeline.unet` or your `mask_image` or `image` input."
+ )
+ elif num_channels_unet != 4:
+ raise ValueError(
+ f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
+ )
+
+ # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 10. Prepare added time ids & embeddings & adapter features
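+        # `adapter_state` is a list of multi-resolution feature maps produced by the T2I-Adapter;
+        # each map is scaled by `adapter_conditioning_scale`, repeated for `num_images_per_prompt`,
+        # and duplicated for classifier-free guidance below.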
+ adapter_input = adapter_input.type(latents.dtype)
+ adapter_state = adapter(adapter_input)
+ for k, v in enumerate(adapter_state):
+ adapter_state[k] = v * adapter_conditioning_scale
+ if num_images_per_prompt > 1:
+ for k, v in enumerate(adapter_state):
+ adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1)
+ if do_classifier_free_guidance:
+ for k, v in enumerate(adapter_state):
+ adapter_state[k] = torch.cat([v] * 2, dim=0)
+
+ # 10.2 Prepare control images
+ if isinstance(controlnet, ControlNetModel):
+ control_image = self.prepare_control_image(
+ image=control_image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+ elif isinstance(controlnet, MultiControlNetModel):
+ control_images = []
+
+ for control_image_ in control_image:
+ control_image_ = self.prepare_control_image(
+ image=control_image_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+
+ control_images.append(control_image_)
+
+ control_image = control_images
+ else:
+ raise ValueError(f"{controlnet.__class__} is not supported.")
+
+        # 10.3 Create tensor stating which controlnets to keep
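+        # keeps[j] is 1.0 only while the current step fraction lies inside
+        # [control_guidance_start[j], control_guidance_end[j]]; outside that window the j-th
+        # ControlNet's conditioning scale is multiplied by 0 and it contributes nothing.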
+ controlnet_keep = []
+ for i in range(len(timesteps)):
+ keeps = [
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+ for s, e in zip(control_guidance_start, control_guidance_end)
+ ]
+ if isinstance(self.controlnet, MultiControlNetModel):
+ controlnet_keep.append(keeps)
+ else:
+ controlnet_keep.append(keeps[0])
+ # ----------------------------------------------------------------
+
+ add_text_embeds = pooled_prompt_embeds
+ if self.text_encoder_2 is None:
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+ else:
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+ add_time_ids, add_neg_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ aesthetic_score,
+ negative_aesthetic_score,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+ add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device)
+
+ # 11. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ # 11.1 Apply denoising_end
+ if (
+ denoising_end is not None
+ and denoising_start is not None
+ and denoising_value_valid(denoising_end)
+ and denoising_value_valid(denoising_start)
+ and denoising_start >= denoising_end
+ ):
+ raise ValueError(
+ f"`denoising_start`: {denoising_start} cannot be larger than or equal to `denoising_end`: "
+ + f" {denoising_end} when using type float."
+ )
+ elif denoising_end is not None and denoising_value_valid(denoising_end):
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ if num_channels_unet == 9:
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+
+ # predict the noise residual
+ added_cond_kwargs = {
+ "text_embeds": add_text_embeds,
+ "time_ids": add_time_ids,
+ }
+
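+                # inject the T2I-Adapter residuals only during the first `cond_tau` fraction of the steps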
+ if i < int(num_inference_steps * cond_tau):
+ down_block_additional_residuals = [state.clone() for state in adapter_state]
+ else:
+ down_block_additional_residuals = None
+
+ # ----------- ControlNet
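+                # the ControlNet branch builds its own scaled latent input (conditional-only when
+                # `guess_mode` is used with classifier-free guidance); its residuals are passed to the
+                # UNet together with the T2I-Adapter residuals further below.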
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input_controlnet = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+                # scale the ControlNet latent input (mask and masked-image channels are not concatenated for the ControlNet branch)
+ latent_model_input_controlnet = self.scheduler.scale_model_input(latent_model_input_controlnet, t)
+
+ # controlnet(s) inference
+ if guess_mode and do_classifier_free_guidance:
+ # Infer ControlNet only for the conditional batch.
+ control_model_input = latents
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+ controlnet_added_cond_kwargs = {
+ "text_embeds": add_text_embeds.chunk(2)[1],
+ "time_ids": add_time_ids.chunk(2)[1],
+ }
+ else:
+ control_model_input = latent_model_input_controlnet
+ controlnet_prompt_embeds = prompt_embeds
+ controlnet_added_cond_kwargs = added_cond_kwargs
+
+ if isinstance(controlnet_keep[i], list):
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+ else:
+ controlnet_cond_scale = controlnet_conditioning_scale
+ if isinstance(controlnet_cond_scale, list):
+ controlnet_cond_scale = controlnet_cond_scale[0]
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
+ control_model_input,
+ t,
+ encoder_hidden_states=controlnet_prompt_embeds,
+ controlnet_cond=control_image,
+ conditioning_scale=cond_scale,
+ guess_mode=guess_mode,
+ added_cond_kwargs=controlnet_added_cond_kwargs,
+ return_dict=False,
+ )
+
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ down_intrablock_additional_residuals=down_block_additional_residuals, # t2iadapter
+ down_block_additional_residuals=down_block_res_samples, # controlnet
+ mid_block_additional_residual=mid_block_res_sample, # controlnet
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(
+ noise_pred,
+ noise_pred_text,
+ guidance_rescale=guidance_rescale,
+ )
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(
+ noise_pred,
+ t,
+ latents,
+ **extra_step_kwargs,
+ return_dict=False,
+ )[0]
+
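+                # for a standard 4-channel UNet (no inpainting channels), re-noise the known image
+                # latents to the next timestep and blend them back outside the mask, so only the
+                # masked region is denoised freely.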
+ if num_channels_unet == 4:
+ init_latents_proper = image_latents
+ if do_classifier_free_guidance:
+ init_mask, _ = mask.chunk(2)
+ else:
+ init_mask = mask
+
+ if i < len(timesteps) - 1:
+ noise_timestep = timesteps[i + 1]
+ init_latents_proper = self.scheduler.add_noise(
+ init_latents_proper,
+ noise,
+ torch.tensor([noise_timestep]),
+ )
+
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ if output_type != "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ else:
+ image = latents
+ return StableDiffusionXLPipelineOutput(images=image)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)
diff --git a/diffusers/examples/community/pipeline_zero1to3.py b/diffusers/examples/community/pipeline_zero1to3.py
new file mode 100644
index 0000000000000000000000000000000000000000..600cf2dc1b6309b8a9f0521595c4efdf1edb86d7
--- /dev/null
+++ b/diffusers/examples/community/pipeline_zero1to3.py
@@ -0,0 +1,893 @@
+# A diffuser version implementation of Zero1to3 (https://github.com/cvlab-columbia/zero123), ICCV 2023
+# by Xin Kong
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import kornia
+import numpy as np
+import PIL.Image
+import torch
+from packaging import version
+from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection
+
+# from ...configuration_utils import FrozenDict
+# from ...models import AutoencoderKL, UNet2DConditionModel
+# from ...schedulers import KarrasDiffusionSchedulers
+# from ...utils import (
+# deprecate,
+# is_accelerate_available,
+# is_accelerate_version,
+# logging,
+# randn_tensor,
+# replace_example_docstring,
+# )
+# from ..pipeline_utils import DiffusionPipeline
+# from . import StableDiffusionPipelineOutput
+# from .safety_checker import StableDiffusionSafetyChecker
+from diffusers import AutoencoderKL, DiffusionPipeline, UNet2DConditionModel
+from diffusers.configuration_utils import ConfigMixin, FrozenDict
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ deprecate,
+ is_accelerate_available,
+ is_accelerate_version,
+ logging,
+ replace_example_docstring,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+# todo
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> import torch
+        >>> from diffusers import DiffusionPipeline
+        >>> from PIL import Image
+
+        >>> # illustrative checkpoint id; any zero123 weights converted to the diffusers format
+        >>> # (including a `cc_projection` component) can be used here
+        >>> pipe = DiffusionPipeline.from_pretrained(
+        ...     "kxic/zero123-165000", custom_pipeline="pipeline_zero1to3", torch_dtype=torch.float16
+        ... )
+        >>> pipe = pipe.to("cuda")
+
+        >>> cond_image = Image.open("input_view.png")  # single reference view of the object
+        >>> # each pose is a [polar, azimuth, radius] offset of the target view w.r.t. the input view
+        >>> image = pipe(input_imgs=cond_image, prompt_imgs=cond_image, poses=[[30.0, 60.0, 0.0]]).images[0]
+        ```
+"""
+
+
+class CCProjection(ModelMixin, ConfigMixin):
+ def __init__(self, in_channel=772, out_channel=768):
+ super().__init__()
+ self.in_channel = in_channel
+ self.out_channel = out_channel
+ self.projection = torch.nn.Linear(in_channel, out_channel)
+
+ def forward(self, x):
+ return self.projection(x)
+
+
+class Zero1to3StableDiffusionPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for single view conditioned novel view generation using Zero1to3.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ image_encoder ([`CLIPVisionModelWithProjection`]):
+ Frozen CLIP image-encoder. Stable Diffusion Image Variation uses the vision portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection),
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ cc_projection ([`CCProjection`]):
+            Projection layer to project the concatenated CLIP features and pose embeddings to the original CLIP feature size.
+ """
+
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ image_encoder: CLIPVisionModelWithProjection,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPFeatureExtractor,
+ cc_projection: CCProjection,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+                "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} has `clip_sample` set to True."
+                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+                " config accordingly, as leaving `clip_sample` set to True might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ image_encoder=image_encoder,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ cc_projection=cc_projection,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+ # self.model_mode = None
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+ steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def enable_sequential_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
+ Note that offloading happens on a submodule basis. Memory savings are higher than with
+ `enable_model_cpu_offload`, but performance is lower.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
+ from accelerate import cpu_offload
+ else:
+ raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
+ cpu_offload(cpu_offloaded_model, device)
+
+ if self.safety_checker is not None:
+ cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
+
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ if self.safety_checker is not None:
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ @property
+ def _execution_device(self):
+ r"""
+ Returns the device on which the pipeline's models will be executed. After calling
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+ hooks.
+ """
+ if not hasattr(self.unet, "_hf_hook"):
+ return self.device
+ for module in self.unet.modules():
+ if (
+ hasattr(module, "_hf_hook")
+ and hasattr(module._hf_hook, "execution_device")
+ and module._hf_hook.execution_device is not None
+ ):
+ return torch.device(module._hf_hook.execution_device)
+ return self.device
+
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead.
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ """
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ def CLIP_preprocess(self, x):
+ dtype = x.dtype
+ # following openai's implementation
+ # TODO HF OpenAI CLIP preprocessing issue https://github.com/huggingface/transformers/issues/22505#issuecomment-1650170741
+ # follow openai preprocessing to keep exact same, input tensor [-1, 1], otherwise the preprocessing will be different, https://github.com/huggingface/transformers/pull/22608
+ if isinstance(x, torch.Tensor):
+ if x.min() < -1.0 or x.max() > 1.0:
+ raise ValueError("Expected input tensor to have values in the range [-1, 1]")
+ x = kornia.geometry.resize(
+ x.to(torch.float32), (224, 224), interpolation="bicubic", align_corners=True, antialias=False
+ ).to(dtype=dtype)
+ x = (x + 1.0) / 2.0
+ # renormalize according to clip
+ x = kornia.enhance.normalize(
+ x, torch.Tensor([0.48145466, 0.4578275, 0.40821073]), torch.Tensor([0.26862954, 0.26130258, 0.27577711])
+ )
+ return x
+
+ # from image_variation
+ def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):
+ dtype = next(self.image_encoder.parameters()).dtype
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+
+ if isinstance(image, torch.Tensor):
+ # Batch single image
+ if image.ndim == 3:
+ assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
+ image = image.unsqueeze(0)
+
+ assert image.ndim == 4, "Image must have 4 dimensions"
+
+ # Check image is in [-1, 1]
+ if image.min() < -1 or image.max() > 1:
+ raise ValueError("Image should be in [-1, 1] range")
+ else:
+ # preprocess image
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
+ image = [image]
+
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+ image = np.concatenate([i[None, :] for i in image], axis=0)
+
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ image = image.to(device=device, dtype=dtype)
+
+ image = self.CLIP_preprocess(image)
+ # if not isinstance(image, torch.Tensor):
+ # # 0-255
+ # print("Warning: image is processed by hf's preprocess, which is different from openai original's.")
+ # image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
+ image_embeddings = self.image_encoder(image).image_embeds.to(dtype=dtype)
+ image_embeddings = image_embeddings.unsqueeze(1)
+
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = image_embeddings.shape
+ image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
+ image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ negative_prompt_embeds = torch.zeros_like(image_embeddings)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])
+
+ return image_embeddings
+
+ def _encode_pose(self, pose, device, num_images_per_prompt, do_classifier_free_guidance):
+ dtype = next(self.cc_projection.parameters()).dtype
+ if isinstance(pose, torch.Tensor):
+ pose_embeddings = pose.unsqueeze(1).to(device=device, dtype=dtype)
+ else:
+ if isinstance(pose[0], list):
+ pose = torch.Tensor(pose)
+ else:
+ pose = torch.Tensor([pose])
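+            # build a 4-d conditioning vector [x_rad, sin(y_rad), cos(y_rad), z] from the (x, y, z)
+            # pose, where x and y are angles given in degrees and z is a radius/distance offset
+            # (the convention used by the original zero123 implementation).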
+ x, y, z = pose[:, 0].unsqueeze(1), pose[:, 1].unsqueeze(1), pose[:, 2].unsqueeze(1)
+ pose_embeddings = (
+ torch.cat([torch.deg2rad(x), torch.sin(torch.deg2rad(y)), torch.cos(torch.deg2rad(y)), z], dim=-1)
+ .unsqueeze(1)
+ .to(device=device, dtype=dtype)
+ ) # B, 1, 4
+ # duplicate pose embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = pose_embeddings.shape
+ pose_embeddings = pose_embeddings.repeat(1, num_images_per_prompt, 1)
+ pose_embeddings = pose_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+ if do_classifier_free_guidance:
+ negative_prompt_embeds = torch.zeros_like(pose_embeddings)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ pose_embeddings = torch.cat([negative_prompt_embeds, pose_embeddings])
+ return pose_embeddings
+
+ def _encode_image_with_pose(self, image, pose, device, num_images_per_prompt, do_classifier_free_guidance):
+ img_prompt_embeds = self._encode_image(image, device, num_images_per_prompt, False)
+ pose_prompt_embeds = self._encode_pose(pose, device, num_images_per_prompt, False)
+ prompt_embeds = torch.cat([img_prompt_embeds, pose_prompt_embeds], dim=-1)
+ prompt_embeds = self.cc_projection(prompt_embeds)
+ # prompt_embeds = img_prompt_embeds
+ # follow 0123, add negative prompt, after projection
+ if do_classifier_free_guidance:
+ negative_prompt = torch.zeros_like(prompt_embeds)
+ prompt_embeds = torch.cat([negative_prompt, prompt_embeds])
+ return prompt_embeds
+
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ else:
+ has_nsfw_concept = None
+ return image, has_nsfw_concept
+
+ def decode_latents(self, latents):
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(self, image, height, width, callback_steps):
+ if (
+ not isinstance(image, torch.Tensor)
+ and not isinstance(image, PIL.Image.Image)
+ and not isinstance(image, list)
+ ):
+ raise ValueError(
+ "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
+ f" {type(image)}"
+ )
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def prepare_img_latents(self, image, batch_size, dtype, device, generator=None, do_classifier_free_guidance=False):
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+
+ if isinstance(image, torch.Tensor):
+ # Batch single image
+ if image.ndim == 3:
+ assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
+ image = image.unsqueeze(0)
+
+ assert image.ndim == 4, "Image must have 4 dimensions"
+
+ # Check image is in [-1, 1]
+ if image.min() < -1 or image.max() > 1:
+ raise ValueError("Image should be in [-1, 1] range")
+ else:
+ # preprocess image
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
+ image = [image]
+
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+ image = np.concatenate([i[None, :] for i in image], axis=0)
+
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ image = image.to(device=device, dtype=dtype)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if isinstance(generator, list):
+ init_latents = [
+                self.vae.encode(image[i : i + 1]).latent_dist.mode()  # mode() is deterministic; it takes no generator
+                for i in range(batch_size)
+ ]
+ init_latents = torch.cat(init_latents, dim=0)
+ else:
+ init_latents = self.vae.encode(image).latent_dist.mode()
+
+ # init_latents = self.vae.config.scaling_factor * init_latents # todo in original zero123's inference gradio_new.py, model.encode_first_stage() is not scaled by scaling_factor
+ if batch_size > init_latents.shape[0]:
+ # init_latents = init_latents.repeat(batch_size // init_latents.shape[0], 1, 1, 1)
+ num_images_per_prompt = batch_size // init_latents.shape[0]
+ # duplicate image latents for each generation per prompt, using mps friendly method
+ bs_embed, emb_c, emb_h, emb_w = init_latents.shape
+ init_latents = init_latents.unsqueeze(1)
+ init_latents = init_latents.repeat(1, num_images_per_prompt, 1, 1, 1)
+ init_latents = init_latents.view(bs_embed * num_images_per_prompt, emb_c, emb_h, emb_w)
+
+ # init_latents = torch.cat([init_latents]*2) if do_classifier_free_guidance else init_latents # follow zero123
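+        # the unconditional branch uses all-zero image latents for classifier-free guidance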
+ init_latents = (
+ torch.cat([torch.zeros_like(init_latents), init_latents]) if do_classifier_free_guidance else init_latents
+ )
+
+ init_latents = init_latents.to(device=device, dtype=dtype)
+ return init_latents
+
+ # def load_cc_projection(self, pretrained_weights=None):
+ # self.cc_projection = torch.nn.Linear(772, 768)
+ # torch.nn.init.eye_(list(self.cc_projection.parameters())[0][:768, :768])
+ # torch.nn.init.zeros_(list(self.cc_projection.parameters())[1])
+ # if pretrained_weights is not None:
+ # self.cc_projection.load_state_dict(pretrained_weights)
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ input_imgs: Union[torch.FloatTensor, PIL.Image.Image] = None,
+ prompt_imgs: Union[torch.FloatTensor, PIL.Image.Image] = None,
+ poses: Union[List[float], List[List[float]]] = None,
+ torch_dtype=torch.float32,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 3.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_conditioning_scale: float = 1.0,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ input_imgs (`PIL` or `List[PIL]`, *optional*):
+ The single input image for each 3D object
+ prompt_imgs (`PIL` or `List[PIL]`, *optional*):
+ Same as input_imgs, but will be used later as an image prompt condition, encoded by CLIP feature
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 3.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead.
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ # input_image = hint_imgs
+ self.check_inputs(input_imgs, height, width, callback_steps)
+
+ # 2. Define call parameters
+ if isinstance(input_imgs, PIL.Image.Image):
+ batch_size = 1
+ elif isinstance(input_imgs, list):
+ batch_size = len(input_imgs)
+ else:
+ batch_size = input_imgs.shape[0]
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input image with pose as prompt
+ prompt_embeds = self._encode_image_with_pose(
+ prompt_imgs, poses, device, num_images_per_prompt, do_classifier_free_guidance
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ 4,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare image latents
+ img_latents = self.prepare_img_latents(
+ input_imgs,
+ batch_size * num_images_per_prompt,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ do_classifier_free_guidance,
+ )
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 8. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
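+                # condition the zero123 UNet on the reference view by concatenating the encoded image
+                # latents with the noisy latents along the channel dimension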
+ latent_model_input = torch.cat([latent_model_input, img_latents], dim=1)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ # latents = self.scheduler.step(noise_pred.to(dtype=torch.float32), t, latents.to(dtype=torch.float32)).prev_sample.to(prompt_embeds.dtype)
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+        # 9. Post-processing
+        has_nsfw_concept = None
+        if output_type == "latent":
+            image = latents
+        elif output_type == "pil":
+            image = self.decode_latents(latents)
+            # 10. Convert to PIL
+            image = self.numpy_to_pil(image)
+        else:
+            image = self.decode_latents(latents)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/examples/community/run_onnx_controlnet.py b/diffusers/examples/community/run_onnx_controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed9b2331841480b04f237be997685287128ba63c
--- /dev/null
+++ b/diffusers/examples/community/run_onnx_controlnet.py
@@ -0,0 +1,911 @@
+import argparse
+import inspect
+import os
+import time
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from PIL import Image
+from transformers import CLIPTokenizer
+
+from diffusers import OnnxRuntimeModel, StableDiffusionImg2ImgPipeline, UniPCMultistepScheduler
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ deprecate,
+ logging,
+ replace_example_docstring,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> # !pip install opencv-python transformers accelerate
+ >>> from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, UniPCMultistepScheduler
+ >>> from diffusers.utils import load_image
+ >>> import numpy as np
+ >>> import torch
+
+ >>> import cv2
+ >>> from PIL import Image
+
+ >>> # download an image
+ >>> image = load_image(
+ ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
+ ... )
+ >>> np_image = np.array(image)
+
+ >>> # get canny image
+ >>> np_image = cv2.Canny(np_image, 100, 200)
+ >>> np_image = np_image[:, :, None]
+ >>> np_image = np.concatenate([np_image, np_image, np_image], axis=2)
+ >>> canny_image = Image.fromarray(np_image)
+
+ >>> # load control net and stable diffusion v1-5
+ >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
+ >>> pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
+ ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
+ ... )
+
+ >>> # speed up diffusion process with faster scheduler and memory optimization
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> # generate image
+ >>> generator = torch.manual_seed(0)
+ >>> image = pipe(
+ ... "futuristic-looking woman",
+ ... num_inference_steps=20,
+ ... generator=generator,
+ ... image=image,
+ ... control_image=canny_image,
+ ... ).images[0]
+ ```
+"""
+
+
+def prepare_image(image):
+ if isinstance(image, torch.Tensor):
+ # Batch single image
+ if image.ndim == 3:
+ image = image.unsqueeze(0)
+
+ image = image.to(dtype=torch.float32)
+ else:
+ # preprocess image
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
+ image = [image]
+
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+ image = np.concatenate([i[None, :] for i in image], axis=0)
+
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ return image
+
+
+class OnnxStableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
+ vae_encoder: OnnxRuntimeModel
+ vae_decoder: OnnxRuntimeModel
+ text_encoder: OnnxRuntimeModel
+ tokenizer: CLIPTokenizer
+ unet: OnnxRuntimeModel
+ scheduler: KarrasDiffusionSchedulers
+
+ def __init__(
+ self,
+ vae_encoder: OnnxRuntimeModel,
+ vae_decoder: OnnxRuntimeModel,
+ text_encoder: OnnxRuntimeModel,
+ tokenizer: CLIPTokenizer,
+ unet: OnnxRuntimeModel,
+ scheduler: KarrasDiffusionSchedulers,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae_encoder=vae_encoder,
+ vae_decoder=vae_decoder,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ )
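+        # standard Stable Diffusion VAE downscaling factor (2 ** (4 - 1) == 8), hard-coded here
+        # rather than derived from a VAE config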
+ self.vae_scale_factor = 2 ** (4 - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
+ self.control_image_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
+ )
+
+ def _encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: Optional[int],
+ do_classifier_free_guidance: bool,
+ negative_prompt: Optional[str],
+ prompt_embeds: Optional[np.ndarray] = None,
+ negative_prompt_embeds: Optional[np.ndarray] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ prompt to be encoded
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ prompt_embeds (`np.ndarray`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`np.ndarray`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ """
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="np",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
+
+ if not np.array_equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
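+ # the ONNX text encoder expects int32 token ids; its first output is used as the prompt embedding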
+ prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
+
+ prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt] * batch_size
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="np",
+ )
+ negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
+
+ if do_classifier_free_guidance:
+ negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ warnings.warn(
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor instead",
+ FutureWarning,
+ )
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ num_controlnet,
+ prompt,
+ image,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ controlnet_conditioning_scale=1.0,
+ control_guidance_start=0.0,
+ control_guidance_end=1.0,
+ ):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # Check `image`
+ if num_controlnet == 1:
+ self.check_image(image, prompt, prompt_embeds)
+ elif num_controlnet > 1:
+ if not isinstance(image, list):
+ raise TypeError("For multiple controlnets: `image` must be type `list`")
+
+ # When `image` is a nested list:
+ # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
+ elif any(isinstance(i, list) for i in image):
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
+ elif len(image) != num_controlnet:
+ raise ValueError(
+ f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {num_controlnet} ControlNets."
+ )
+
+ for image_ in image:
+ self.check_image(image_, prompt, prompt_embeds)
+ else:
+ assert False
+
+ # Check `controlnet_conditioning_scale`
+ if num_controlnet == 1:
+ if not isinstance(controlnet_conditioning_scale, float):
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
+ elif num_controlnet > 1:
+ if isinstance(controlnet_conditioning_scale, list):
+ if any(isinstance(i, list) for i in controlnet_conditioning_scale):
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
+ elif (
+ isinstance(controlnet_conditioning_scale, list)
+ and len(controlnet_conditioning_scale) != num_controlnet
+ ):
+ raise ValueError(
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
+ " the same length as the number of controlnets"
+ )
+ else:
+ assert False
+
+ if len(control_guidance_start) != len(control_guidance_end):
+ raise ValueError(
+ f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
+ )
+
+ if num_controlnet > 1:
+ if len(control_guidance_start) != num_controlnet:
+ raise ValueError(
+ f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {num_controlnet} controlnets available. Make sure to provide {num_controlnet}."
+ )
+
+ for start, end in zip(control_guidance_start, control_guidance_end):
+ if start >= end:
+ raise ValueError(
+ f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
+ )
+ if start < 0.0:
+ raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
+ if end > 1.0:
+ raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
+
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
+ def check_image(self, image, prompt, prompt_embeds):
+ image_is_pil = isinstance(image, PIL.Image.Image)
+ image_is_tensor = isinstance(image, torch.Tensor)
+ image_is_np = isinstance(image, np.ndarray)
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
+ image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
+
+ if (
+ not image_is_pil
+ and not image_is_tensor
+ and not image_is_np
+ and not image_is_pil_list
+ and not image_is_tensor_list
+ and not image_is_np_list
+ ):
+ raise TypeError(
+ f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
+ )
+
+ if image_is_pil:
+ image_batch_size = 1
+ else:
+ image_batch_size = len(image)
+
+ if prompt is not None and isinstance(prompt, str):
+ prompt_batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ prompt_batch_size = len(prompt)
+ elif prompt_embeds is not None:
+ prompt_batch_size = prompt_embeds.shape[0]
+
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
+ raise ValueError(
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
+ )
+
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
+ def prepare_control_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ image = torch.cat([image] * 2)
+
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+ return timesteps, num_inference_steps - t_start
+
+ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+
+ image = image.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+
+ if image.shape[1] == 4:
+ init_latents = image
+
+ else:
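+ # encode the image with the ONNX VAE encoder, then apply the SD v1 latent scaling factor (0.18215)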
+ _image = image.cpu().detach().numpy()
+ init_latents = self.vae_encoder(sample=_image)[0]
+ init_latents = torch.from_numpy(init_latents).to(device=device, dtype=dtype)
+ init_latents = 0.18215 * init_latents
+
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ deprecation_message = (
+ f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
+ " your script to pass as many initial images as text prompts to suppress this warning."
+ )
+ deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ init_latents = torch.cat([init_latents], dim=0)
+
+ shape = init_latents.shape
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ # get latents
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+ latents = init_latents
+
+ return latents
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ num_controlnet: int,
+ fp16: bool = True,
+ prompt: Union[str, List[str]] = None,
+ image: Union[
+ torch.FloatTensor,
+ PIL.Image.Image,
+ np.ndarray,
+ List[torch.FloatTensor],
+ List[PIL.Image.Image],
+ List[np.ndarray],
+ ] = None,
+ control_image: Union[
+ torch.FloatTensor,
+ PIL.Image.Image,
+ np.ndarray,
+ List[torch.FloatTensor],
+ List[PIL.Image.Image],
+ List[np.ndarray],
+ ] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ strength: float = 0.8,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_conditioning_scale: Union[float, List[float]] = 0.8,
+ guess_mode: bool = False,
+ control_guidance_start: Union[float, List[float]] = 0.0,
+ control_guidance_end: Union[float, List[float]] = 1.0,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
+ `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+ The initial image to be used as the starting point for the image generation process. Can also accept
+ image latents as `image`; if latents are passed directly, they will not be encoded again.
+ control_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
+ `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+ The ControlNet input condition. ControlNet uses this input condition to generate guidance for the
+ UNet. If the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is.
+ `PIL.Image.Image` can also be accepted as an image. The dimensions of the output image default to
+ `image`'s dimensions. If height and/or width are passed, `image` is resized accordingly. If multiple
+ ControlNets are specified in init, images must be passed as a list such that each element of the
+ list can be correctly batched for input to a single ControlNet.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2 of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`] and is ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 0.8):
+ The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
+ to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
+ corresponding scale as a list. Note that by default a smaller conditioning scale is used here than in
+ [`~StableDiffusionControlNetPipeline.__call__`].
+ guess_mode (`bool`, *optional*, defaults to `False`):
+ In this mode, the ControlNet encoder will try its best to recognize the content of the input image even
+ if you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
+ The percentage of total steps at which the controlnet starts applying.
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The percentage of total steps at which the controlnet stops applying.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ if fp16:
+ torch_dtype = torch.float16
+ np_dtype = np.float16
+ else:
+ torch_dtype = torch.float32
+ np_dtype = np.float32
+
+ # align format for control guidance
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+ mult = num_controlnet
+ control_guidance_start, control_guidance_end = (
+ mult * [control_guidance_start],
+ mult * [control_guidance_end],
+ )
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ num_controlnet,
+ prompt,
+ control_image,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ controlnet_conditioning_scale,
+ control_guidance_start,
+ control_guidance_end,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ if num_controlnet > 1 and isinstance(controlnet_conditioning_scale, float):
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * num_controlnet
+
+ # 3. Encode input prompt
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+ # 4. Prepare image
+ image = self.image_processor.preprocess(image).to(dtype=torch.float32)
+
+ # 5. Prepare controlnet_conditioning_image
+ if num_controlnet == 1:
+ control_image = self.prepare_control_image(
+ image=control_image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=torch_dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+ elif num_controlnet > 1:
+ control_images = []
+
+ for control_image_ in control_image:
+ control_image_ = self.prepare_control_image(
+ image=control_image_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=torch_dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+
+ control_images.append(control_image_)
+
+ control_image = control_images
+ else:
+ assert False
+
+ # 6. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 7. Prepare latent variables
+ latents = self.prepare_latents(
+ image,
+ latent_timestep,
+ batch_size,
+ num_images_per_prompt,
+ torch_dtype,
+ device,
+ generator,
+ )
+
+ # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 8.1 Create a list stating which controlnets to keep
+ controlnet_keep = []
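+ # keeps[j] is 1.0 while step i lies inside controlnet j's [start, end] fraction of the schedule, 0.0 otherwise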
+ for i in range(len(timesteps)):
+ keeps = [
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+ for s, e in zip(control_guidance_start, control_guidance_end)
+ ]
+ controlnet_keep.append(keeps[0] if num_controlnet == 1 else keeps)
+
+ # 9. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ if isinstance(controlnet_keep[i], list):
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+ else:
+ controlnet_cond_scale = controlnet_conditioning_scale
+ if isinstance(controlnet_cond_scale, list):
+ controlnet_cond_scale = controlnet_cond_scale[0]
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
+
+ # predict the noise residual
+ _latent_model_input = latent_model_input.cpu().detach().numpy()
+ _prompt_embeds = np.array(prompt_embeds, dtype=np_dtype)
+ _t = np.array([t.cpu().detach().numpy()], dtype=np_dtype)
+
+ if num_controlnet == 1:
+ control_images = np.array([control_image], dtype=np_dtype)
+ else:
+ control_images = []
+ for _control_img in control_image:
+ _control_img = _control_img.cpu().detach().numpy()
+ control_images.append(_control_img)
+ control_images = np.array(control_images, dtype=np_dtype)
+
+ control_scales = np.array(cond_scale, dtype=np_dtype)
+ control_scales = np.resize(control_scales, (num_controlnet, 1))
+
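+ # the ONNX UNet used by this example is expected to bundle the ControlNet(s): it takes the
+ # control images and per-controlnet scales as extra inputs and returns the fused noise prediction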
+ noise_pred = self.unet(
+ sample=_latent_model_input,
+ timestep=_t,
+ encoder_hidden_states=_prompt_embeds,
+ controlnet_conds=control_images,
+ conditioning_scales=control_scales,
+ )[0]
+ noise_pred = torch.from_numpy(noise_pred).to(device)
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if not output_type == "latent":
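+ # undo the 0.18215 latent scaling and decode with the ONNX VAE decoder; no safety checker is run here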
+ _latents = latents.cpu().detach().numpy() / 0.18215
+ _latents = np.array(_latents, dtype=np_dtype)
+ image = self.vae_decoder(latent_sample=_latents)[0]
+ image = torch.from_numpy(image).to(device, dtype=torch.float32)
+ has_nsfw_concept = None
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--sd_model",
+ type=str,
+ required=True,
+ help="Path to the `diffusers` checkpoint to convert (either a local directory or on the Hub).",
+ )
+
+ parser.add_argument(
+ "--onnx_model_dir",
+ type=str,
+ required=True,
+ help="Path to the ONNX directory",
+ )
+
+ parser.add_argument("--qr_img_path", type=str, required=True, help="Path to the qr code image")
+
+ args = parser.parse_args()
+
+ qr_image = Image.open(args.qr_img_path)
+ qr_image = qr_image.resize((512, 512))
+
+ # init stable diffusion pipeline
+ pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(args.sd_model)
+ pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
+
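+ # prefer the CUDA execution provider, falling back to CPU if it is unavailable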
+ provider = ["CUDAExecutionProvider", "CPUExecutionProvider"]
+ onnx_pipeline = OnnxStableDiffusionControlNetImg2ImgPipeline(
+ vae_encoder=OnnxRuntimeModel.from_pretrained(
+ os.path.join(args.onnx_model_dir, "vae_encoder"), provider=provider
+ ),
+ vae_decoder=OnnxRuntimeModel.from_pretrained(
+ os.path.join(args.onnx_model_dir, "vae_decoder"), provider=provider
+ ),
+ text_encoder=OnnxRuntimeModel.from_pretrained(
+ os.path.join(args.onnx_model_dir, "text_encoder"), provider=provider
+ ),
+ tokenizer=pipeline.tokenizer,
+ unet=OnnxRuntimeModel.from_pretrained(os.path.join(args.onnx_model_dir, "unet"), provider=provider),
+ scheduler=pipeline.scheduler,
+ )
+ onnx_pipeline = onnx_pipeline.to("cuda")
+
+ prompt = "a cute cat fly to the moon"
+ negative_prompt = "paintings, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, age spot, glans, nsfw, nipples, necklace, worst quality, low quality, watermark, username, signature, multiple breasts, lowres, bad anatomy, bad hands, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, bad feet, single color, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, ugly, blurry, bad anatomy, bad proportions, extra limbs, disfigured, bad anatomy, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, mutated hands, fused fingers, too many fingers, long neck, bad body perspect"
+
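+ # run the pipeline several times; the first call typically includes ONNX Runtime warm-up,
+ # so later iterations give a steadier per-image latency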
+ for i in range(10):
+ start_time = time.time()
+ image = onnx_pipeline(
+ num_controlnet=2,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ image=qr_image,
+ control_image=[qr_image, qr_image],
+ width=512,
+ height=512,
+ strength=0.75,
+ num_inference_steps=20,
+ num_images_per_prompt=1,
+ controlnet_conditioning_scale=[0.8, 0.8],
+ control_guidance_start=[0.3, 0.3],
+ control_guidance_end=[0.9, 0.9],
+ ).images[0]
+ print(time.time() - start_time)
+ image.save("output_qr_code.png")
diff --git a/diffusers/examples/community/run_tensorrt_controlnet.py b/diffusers/examples/community/run_tensorrt_controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..aece5484e304ef21a8ae205938fcd4d9d6c1fa83
--- /dev/null
+++ b/diffusers/examples/community/run_tensorrt_controlnet.py
@@ -0,0 +1,1022 @@
+import argparse
+import atexit
+import inspect
+import os
+import time
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL.Image
+import pycuda.driver as cuda
+import tensorrt as trt
+import torch
+from PIL import Image
+from pycuda.tools import make_default_context
+from transformers import CLIPTokenizer
+
+from diffusers import OnnxRuntimeModel, StableDiffusionImg2ImgPipeline, UniPCMultistepScheduler
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ deprecate,
+ logging,
+ replace_example_docstring,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+# Initialize CUDA
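+# make_default_context pushes a pycuda context for this process; it is popped again at exit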
+cuda.init()
+context = make_default_context()
+device = context.get_device()
+atexit.register(context.pop)
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
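+# deserialize a prebuilt TensorRT engine from disk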
+def load_engine(trt_runtime, engine_path):
+ with open(engine_path, "rb") as f:
+ engine_data = f.read()
+ engine = trt_runtime.deserialize_cuda_engine(engine_data)
+ return engine
+
+
+class TensorRTModel:
+ def __init__(
+ self,
+ trt_engine_path,
+ **kwargs,
+ ):
+ cuda.init()
+ stream = cuda.Stream()
+ TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)
+ trt.init_libnvinfer_plugins(TRT_LOGGER, "")
+ trt_runtime = trt.Runtime(TRT_LOGGER)
+ engine = load_engine(trt_runtime, trt_engine_path)
+ context = engine.create_execution_context()
+
+ # allocates memory for network inputs/outputs on both CPU and GPU
+ host_inputs = []
+ cuda_inputs = []
+ host_outputs = []
+ cuda_outputs = []
+ bindings = []
+ input_names = []
+ output_names = []
+
+ for binding in engine:
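+ # allocate a page-locked host buffer and a matching device buffer for every binding,
+ # recorded in the engine's binding order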
+ datatype = engine.get_binding_dtype(binding)
+ if datatype == trt.DataType.HALF:
+ dtype = np.float16
+ else:
+ dtype = np.float32
+
+ shape = tuple(engine.get_binding_shape(binding))
+ host_mem = cuda.pagelocked_empty(shape, dtype)
+ cuda_mem = cuda.mem_alloc(host_mem.nbytes)
+ bindings.append(int(cuda_mem))
+
+ if engine.binding_is_input(binding):
+ host_inputs.append(host_mem)
+ cuda_inputs.append(cuda_mem)
+ input_names.append(binding)
+ else:
+ host_outputs.append(host_mem)
+ cuda_outputs.append(cuda_mem)
+ output_names.append(binding)
+
+ self.stream = stream
+ self.context = context
+ self.engine = engine
+
+ self.host_inputs = host_inputs
+ self.cuda_inputs = cuda_inputs
+ self.host_outputs = host_outputs
+ self.cuda_outputs = cuda_outputs
+ self.bindings = bindings
+ self.batch_size = engine.max_batch_size
+
+ self.input_names = input_names
+ self.output_names = output_names
+
+ def __call__(self, **kwargs):
+ context = self.context
+ stream = self.stream
+ bindings = self.bindings
+
+ host_inputs = self.host_inputs
+ cuda_inputs = self.cuda_inputs
+ host_outputs = self.host_outputs
+ cuda_outputs = self.cuda_outputs
+
+ for idx, input_name in enumerate(self.input_names):
+ _input = kwargs[input_name]
+ np.copyto(host_inputs[idx], _input)
+ # transfer input data to the GPU
+ cuda.memcpy_htod_async(cuda_inputs[idx], host_inputs[idx], stream)
+
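+ # launch inference asynchronously on the stream; outputs are copied back and synchronized below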
+ context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
+
+ result = {}
+ for idx, output_name in enumerate(self.output_names):
+ # transfer predictions back from the GPU
+ cuda.memcpy_dtoh_async(host_outputs[idx], cuda_outputs[idx], stream)
+ result[output_name] = host_outputs[idx]
+
+ stream.synchronize()
+
+ return result
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> # !pip install opencv-python transformers accelerate
+ >>> from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, UniPCMultistepScheduler
+ >>> from diffusers.utils import load_image
+ >>> import numpy as np
+ >>> import torch
+
+ >>> import cv2
+ >>> from PIL import Image
+
+ >>> # download an image
+ >>> image = load_image(
+ ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
+ ... )
+ >>> np_image = np.array(image)
+
+ >>> # get canny image
+ >>> np_image = cv2.Canny(np_image, 100, 200)
+ >>> np_image = np_image[:, :, None]
+ >>> np_image = np.concatenate([np_image, np_image, np_image], axis=2)
+ >>> canny_image = Image.fromarray(np_image)
+
+ >>> # load control net and stable diffusion v1-5
+ >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
+ >>> pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
+ ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
+ ... )
+
+ >>> # speed up diffusion process with faster scheduler and memory optimization
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> # generate image
+ >>> generator = torch.manual_seed(0)
+ >>> image = pipe(
+ ... "futuristic-looking woman",
+ ... num_inference_steps=20,
+ ... generator=generator,
+ ... image=image,
+ ... control_image=canny_image,
+ ... ).images[0]
+ ```
+"""
+
+
+def prepare_image(image):
+ if isinstance(image, torch.Tensor):
+ # Batch single image
+ if image.ndim == 3:
+ image = image.unsqueeze(0)
+
+ image = image.to(dtype=torch.float32)
+ else:
+ # preprocess image
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
+ image = [image]
+
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+ image = np.concatenate([i[None, :] for i in image], axis=0)
+
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ return image
+
+
+class TensorRTStableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
+ vae_encoder: OnnxRuntimeModel
+ vae_decoder: OnnxRuntimeModel
+ text_encoder: OnnxRuntimeModel
+ tokenizer: CLIPTokenizer
+ unet: TensorRTModel
+ scheduler: KarrasDiffusionSchedulers
+
+ def __init__(
+ self,
+ vae_encoder: OnnxRuntimeModel,
+ vae_decoder: OnnxRuntimeModel,
+ text_encoder: OnnxRuntimeModel,
+ tokenizer: CLIPTokenizer,
+ unet: TensorRTModel,
+ scheduler: KarrasDiffusionSchedulers,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae_encoder=vae_encoder,
+ vae_decoder=vae_decoder,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ )
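+ # same hardcoded SD v1 VAE downsampling factor (8) as in the ONNX pipeline above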
+ self.vae_scale_factor = 2 ** (4 - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
+ self.control_image_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
+ )
+
+ def _encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: Optional[int],
+ do_classifier_free_guidance: bool,
+ negative_prompt: Optional[str],
+ prompt_embeds: Optional[np.ndarray] = None,
+ negative_prompt_embeds: Optional[np.ndarray] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ prompt to be encoded
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ prompt_embeds (`np.ndarray`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`np.ndarray`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ """
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="np",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
+
+ if not np.array_equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
+
+ prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt] * batch_size
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="np",
+ )
+ negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
+
+ if do_classifier_free_guidance:
+ negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ warnings.warn(
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor instead",
+ FutureWarning,
+ )
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ num_controlnet,
+ prompt,
+ image,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ controlnet_conditioning_scale=1.0,
+ control_guidance_start=0.0,
+ control_guidance_end=1.0,
+ ):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # Check `image`
+ if num_controlnet == 1:
+ self.check_image(image, prompt, prompt_embeds)
+ elif num_controlnet > 1:
+ if not isinstance(image, list):
+ raise TypeError("For multiple controlnets: `image` must be type `list`")
+
+ # When `image` is a nested list:
+ # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
+ elif any(isinstance(i, list) for i in image):
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
+ elif len(image) != num_controlnet:
+ raise ValueError(
+ f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {num_controlnet} ControlNets."
+ )
+
+ for image_ in image:
+ self.check_image(image_, prompt, prompt_embeds)
+ else:
+ assert False
+
+ # Check `controlnet_conditioning_scale`
+ if num_controlnet == 1:
+ if not isinstance(controlnet_conditioning_scale, float):
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
+ elif num_controlnet > 1:
+ if isinstance(controlnet_conditioning_scale, list):
+ if any(isinstance(i, list) for i in controlnet_conditioning_scale):
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
+ elif (
+ isinstance(controlnet_conditioning_scale, list)
+ and len(controlnet_conditioning_scale) != num_controlnet
+ ):
+ raise ValueError(
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
+ " the same length as the number of controlnets"
+ )
+ else:
+ assert False
+
+ if len(control_guidance_start) != len(control_guidance_end):
+ raise ValueError(
+ f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
+ )
+
+ if num_controlnet > 1:
+ if len(control_guidance_start) != num_controlnet:
+ raise ValueError(
+ f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {num_controlnet} controlnets available. Make sure to provide {num_controlnet}."
+ )
+
+ for start, end in zip(control_guidance_start, control_guidance_end):
+ if start >= end:
+ raise ValueError(
+ f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
+ )
+ if start < 0.0:
+ raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
+ if end > 1.0:
+ raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
+
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
+ def check_image(self, image, prompt, prompt_embeds):
+ image_is_pil = isinstance(image, PIL.Image.Image)
+ image_is_tensor = isinstance(image, torch.Tensor)
+ image_is_np = isinstance(image, np.ndarray)
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
+ image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
+
+ if (
+ not image_is_pil
+ and not image_is_tensor
+ and not image_is_np
+ and not image_is_pil_list
+ and not image_is_tensor_list
+ and not image_is_np_list
+ ):
+ raise TypeError(
+ f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
+ )
+
+ if image_is_pil:
+ image_batch_size = 1
+ else:
+ image_batch_size = len(image)
+
+ if prompt is not None and isinstance(prompt, str):
+ prompt_batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ prompt_batch_size = len(prompt)
+ elif prompt_embeds is not None:
+ prompt_batch_size = prompt_embeds.shape[0]
+
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
+ raise ValueError(
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
+ )
+
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
+ def prepare_control_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ image = torch.cat([image] * 2)
+
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+ return timesteps, num_inference_steps - t_start
+
+ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+
+ image = image.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+
+ if image.shape[1] == 4:
+ init_latents = image
+
+ else:
+ _image = image.cpu().detach().numpy()
+ init_latents = self.vae_encoder(sample=_image)[0]
+ init_latents = torch.from_numpy(init_latents).to(device=device, dtype=dtype)
+ init_latents = 0.18215 * init_latents
+
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ deprecation_message = (
+ f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
+ " your script to pass as many initial images as text prompts to suppress this warning."
+ )
+ deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ init_latents = torch.cat([init_latents], dim=0)
+
+ shape = init_latents.shape
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ # get latents
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+ latents = init_latents
+
+ return latents
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ num_controlnet: int,
+ fp16: bool = True,
+ prompt: Union[str, List[str]] = None,
+ image: Union[
+ torch.FloatTensor,
+ PIL.Image.Image,
+ np.ndarray,
+ List[torch.FloatTensor],
+ List[PIL.Image.Image],
+ List[np.ndarray],
+ ] = None,
+ control_image: Union[
+ torch.FloatTensor,
+ PIL.Image.Image,
+ np.ndarray,
+ List[torch.FloatTensor],
+ List[PIL.Image.Image],
+ List[np.ndarray],
+ ] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ strength: float = 0.8,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_conditioning_scale: Union[float, List[float]] = 0.8,
+ guess_mode: bool = False,
+ control_guidance_start: Union[float, List[float]] = 0.0,
+ control_guidance_end: Union[float, List[float]] = 1.0,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
+ `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+ The initial image to be used as the starting point for the image generation process. Can also accept
+ image latents as `image`; if latents are passed directly, they will not be encoded again.
+ control_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
+ `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+ The ControlNet input condition. ControlNet uses this input condition to generate guidance for the
+ UNet. If the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is.
+ `PIL.Image.Image` can also be accepted as an image. The dimensions of the output image default to
+ `image`'s dimensions. If height and/or width are passed, `image` is resized accordingly. If multiple
+ ControlNets are specified in init, images must be passed as a list such that each element of the
+ list can be correctly batched for input to a single ControlNet.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2 of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`] and is ignored for other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 0.8):
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
+ to the residual in the original UNet. If multiple ControlNets are specified in init, you can set the
+ corresponding scale as a list. Note that by default a smaller conditioning scale is used here than in
+ [`~StableDiffusionControlNetPipeline.__call__`].
+ guess_mode (`bool`, *optional*, defaults to `False`):
+ In this mode, the ControlNet encoder will try its best to recognize the content of the input image even
+ if you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
+ The percentage of total steps at which the controlnet starts applying.
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The percentage of total steps at which the controlnet stops applying.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ if fp16:
+ torch_dtype = torch.float16
+ np_dtype = np.float16
+ else:
+ torch_dtype = torch.float32
+ np_dtype = np.float32
+
+ # align format for control guidance
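+ # broadcast scalar start/end values so that there is one (start, end) pair per ControlNet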
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+ mult = num_controlnet
+ control_guidance_start, control_guidance_end = (
+ mult * [control_guidance_start],
+ mult * [control_guidance_end],
+ )
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ num_controlnet,
+ prompt,
+ control_image,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ controlnet_conditioning_scale,
+ control_guidance_start,
+ control_guidance_end,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ if num_controlnet > 1 and isinstance(controlnet_conditioning_scale, float):
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * num_controlnet
+
+ # 3. Encode input prompt
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+ # 4. Prepare image
+ image = self.image_processor.preprocess(image).to(dtype=torch.float32)
+
+ # 5. Prepare controlnet_conditioning_image
+ if num_controlnet == 1:
+ control_image = self.prepare_control_image(
+ image=control_image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=torch_dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+ elif num_controlnet > 1:
+ control_images = []
+
+ for control_image_ in control_image:
+ control_image_ = self.prepare_control_image(
+ image=control_image_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=torch_dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+
+ control_images.append(control_image_)
+
+ control_image = control_images
+ else:
+ raise ValueError(f"`num_controlnet` must be a positive integer but is {num_controlnet}")
+
+ # 5. Prepare timesteps
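+ # for img2img, `strength` trims the schedule: only the last `strength * num_inference_steps` steps are
+ # run, and the input image latents are noised to the first of those timesteps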
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 6. Prepare latent variables
+ latents = self.prepare_latents(
+ image,
+ latent_timestep,
+ batch_size,
+ num_images_per_prompt,
+ torch_dtype,
+ device,
+ generator,
+ )
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7.1 Create tensor stating which controlnets to keep
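+ # controlnet_keep[i] is 1.0 while step i falls inside the [control_guidance_start, control_guidance_end]
+ # window (expressed as fractions of the schedule) and 0.0 otherwise, per ControlNet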
+ controlnet_keep = []
+ for i in range(len(timesteps)):
+ keeps = [
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+ for s, e in zip(control_guidance_start, control_guidance_end)
+ ]
+ controlnet_keep.append(keeps[0] if num_controlnet == 1 else keeps)
+
+ # 8. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ if isinstance(controlnet_keep[i], list):
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+ else:
+ controlnet_cond_scale = controlnet_conditioning_scale
+ if isinstance(controlnet_cond_scale, list):
+ controlnet_cond_scale = controlnet_cond_scale[0]
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
+
+ # predict the noise residual
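+ # the combined UNet + ControlNet here is a TensorRT engine, so all inputs are first converted to numpy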
+ _latent_model_input = latent_model_input.cpu().detach().numpy()
+ _prompt_embeds = np.array(prompt_embeds, dtype=np_dtype)
+ _t = np.array([t.cpu().detach().numpy()], dtype=np_dtype)
+
+ if num_controlnet == 1:
+ control_images = np.array([control_image], dtype=np_dtype)
+ else:
+ control_images = []
+ for _control_img in control_image:
+ _control_img = _control_img.cpu().detach().numpy()
+ control_images.append(_control_img)
+ control_images = np.array(control_images, dtype=np_dtype)
+
+ control_scales = np.array(cond_scale, dtype=np_dtype)
+ control_scales = np.resize(control_scales, (num_controlnet, 1))
+
+ noise_pred = self.unet(
+ sample=_latent_model_input,
+ timestep=_t,
+ encoder_hidden_states=_prompt_embeds,
+ controlnet_conds=control_images,
+ conditioning_scales=control_scales,
+ )["noise_pred"]
+ noise_pred = torch.from_numpy(noise_pred).to(device)
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if output_type != "latent":
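+ # decode with the ONNX VAE decoder; 0.18215 is the Stable Diffusion VAE scaling factor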
+ _latents = latents.cpu().detach().numpy() / 0.18215
+ _latents = np.array(_latents, dtype=np_dtype)
+ image = self.vae_decoder(latent_sample=_latents)[0]
+ image = torch.from_numpy(image).to(device, dtype=torch.float32)
+ has_nsfw_concept = None
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--sd_model",
+ type=str,
+ required=True,
+ help="Path to the `diffusers` checkpoint to convert (either a local directory or on the Hub).",
+ )
+
+ parser.add_argument(
+ "--onnx_model_dir",
+ type=str,
+ required=True,
+ help="Path to the ONNX directory",
+ )
+
+ parser.add_argument(
+ "--unet_engine_path",
+ type=str,
+ required=True,
+ help="Path to the unet + controlnet tensorrt model",
+ )
+
+ parser.add_argument("--qr_img_path", type=str, required=True, help="Path to the qr code image")
+
+ args = parser.parse_args()
+
+ qr_image = Image.open(args.qr_img_path)
+ qr_image = qr_image.resize((512, 512))
+
+ # init stable diffusion pipeline
+ pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(args.sd_model)
+ pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
+
+ provider = ["CUDAExecutionProvider", "CPUExecutionProvider"]
+ onnx_pipeline = TensorRTStableDiffusionControlNetImg2ImgPipeline(
+ vae_encoder=OnnxRuntimeModel.from_pretrained(
+ os.path.join(args.onnx_model_dir, "vae_encoder"), provider=provider
+ ),
+ vae_decoder=OnnxRuntimeModel.from_pretrained(
+ os.path.join(args.onnx_model_dir, "vae_decoder"), provider=provider
+ ),
+ text_encoder=OnnxRuntimeModel.from_pretrained(
+ os.path.join(args.onnx_model_dir, "text_encoder"), provider=provider
+ ),
+ tokenizer=pipeline.tokenizer,
+ unet=TensorRTModel(args.unet_engine_path),
+ scheduler=pipeline.scheduler,
+ )
+ onnx_pipeline = onnx_pipeline.to("cuda")
+
+ prompt = "a cute cat fly to the moon"
+ negative_prompt = "paintings, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, age spot, glans, nsfw, nipples, necklace, worst quality, low quality, watermark, username, signature, multiple breasts, lowres, bad anatomy, bad hands, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, bad feet, single color, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, ugly, blurry, bad anatomy, bad proportions, extra limbs, disfigured, bad anatomy, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, mutated hands, fused fingers, too many fingers, long neck, bad body perspect"
+
+ for i in range(10):
+ start_time = time.time()
+ image = onnx_pipeline(
+ num_controlnet=2,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ image=qr_image,
+ control_image=[qr_image, qr_image],
+ width=512,
+ height=512,
+ strength=0.75,
+ num_inference_steps=20,
+ num_images_per_prompt=1,
+ controlnet_conditioning_scale=[0.8, 0.8],
+ control_guidance_start=[0.3, 0.3],
+ control_guidance_end=[0.9, 0.9],
+ ).images[0]
+ print(time.time() - start_time)
+ image.save("output_qr_code.png")
diff --git a/diffusers/examples/community/sd_text2img_k_diffusion.py b/diffusers/examples/community/sd_text2img_k_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..9371ac8819ed2e9f39e7a009380c6493719d245c
--- /dev/null
+++ b/diffusers/examples/community/sd_text2img_k_diffusion.py
@@ -0,0 +1,476 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+import warnings
+from typing import Callable, List, Optional, Union
+
+import torch
+from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
+
+from diffusers import DiffusionPipeline, LMSDiscreteScheduler
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.utils import is_accelerate_available, logging
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class ModelWrapper:
+ def __init__(self, model, alphas_cumprod):
+ self.model = model
+ self.alphas_cumprod = alphas_cumprod
+
+ def apply_model(self, *args, **kwargs):
+ if len(args) == 3:
+ encoder_hidden_states = args[-1]
+ args = args[:2]
+ if kwargs.get("cond", None) is not None:
+ encoder_hidden_states = kwargs.pop("cond")
+ return self.model(*args, encoder_hidden_states=encoder_hidden_states, **kwargs).sample
+
+
+class StableDiffusionPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae,
+ text_encoder,
+ tokenizer,
+ unet,
+ scheduler,
+ safety_checker,
+ feature_extractor,
+ ):
+ super().__init__()
+
+ if safety_checker is None:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ # get correct sigmas from LMS
+ scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+
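+ # wrap the UNet so k-diffusion's CompVis denoisers can call it with `cond=` as the text conditioning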
+ model = ModelWrapper(unet, scheduler.alphas_cumprod)
+ if scheduler.config.prediction_type == "v_prediction":
+ self.k_diffusion_model = CompVisVDenoiser(model)
+ else:
+ self.k_diffusion_model = CompVisDenoiser(model)
+
+ def set_sampler(self, scheduler_type: str):
+ warnings.warn("The `set_sampler` method is deprecated, please use `set_scheduler` instead.")
+ return self.set_scheduler(scheduler_type)
+
+ def set_scheduler(self, scheduler_type: str):
+ library = importlib.import_module("k_diffusion")
+ sampling = getattr(library, "sampling")
+ self.sampler = getattr(sampling, scheduler_type)
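+ # e.g. passing "sample_heun" selects `k_diffusion.sampling.sample_heun` as the sampler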
+
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+ `attention_head_dim` must be a multiple of `slice_size`.
+ """
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = self.unet.config.attention_head_dim // 2
+ self.unet.set_attention_slice(slice_size)
+
+ def disable_attention_slicing(self):
+ r"""
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+ back to computing attention in one step.
+ """
+ # set slice_size = `None` to disable `attention slicing`
+ self.enable_attention_slicing(None)
+
+ def enable_sequential_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to
+ `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
+ """
+ if is_accelerate_available():
+ from accelerate import cpu_offload
+ else:
+ raise ImportError("Please install accelerate via `pip install accelerate`")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
+ if cpu_offloaded_model is not None:
+ cpu_offload(cpu_offloaded_model, device)
+
+ @property
+ def _execution_device(self):
+ r"""
+ Returns the device on which the pipeline's models will be executed. After calling
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+ hooks.
+ """
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
+ return self.device
+ for module in self.unet.modules():
+ if (
+ hasattr(module, "_hf_hook")
+ and hasattr(module._hf_hook, "execution_device")
+ and module._hf_hook.execution_device is not None
+ ):
+ return torch.device(module._hf_hook.execution_device)
+ return self.device
+
+ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `list(int)`):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ """
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
+
+ if not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ text_embeddings = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ text_embeddings = text_embeddings[0]
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = text_embeddings.shape
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ uncond_embeddings = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ uncond_embeddings = uncond_embeddings[0]
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ return text_embeddings
+
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ else:
+ has_nsfw_concept = None
+ return image, has_nsfw_concept
+
+ def decode_latents(self, latents):
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ def check_inputs(self, prompt, height, width, callback_steps):
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // 8, width // 8)
+ if latents is None:
+ if device.type == "mps":
+ # randn does not work reproducibly on mps
+ latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
+ else:
+ latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ # note: the initial noise is scaled by the scheduler's first sigma in `__call__` (latents * sigmas[0])
+ return latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`] and is ignored for other schedulers.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(prompt, height, width, callback_steps)
+
+ # 2. Define call parameters
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = True
+ if guidance_scale <= 1.0:
+ raise ValueError("This pipeline always applies classifier-free guidance; `guidance_scale` has to be > 1.0")
+
+ # 3. Encode input prompt
+ text_embeddings = self._encode_prompt(
+ prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=text_embeddings.device)
+ sigmas = self.scheduler.sigmas
+ sigmas = sigmas.to(text_embeddings.dtype)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ text_embeddings.dtype,
+ device,
+ generator,
+ latents,
+ )
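+ # k-diffusion expects the initial noise scaled by the largest sigma instead of `init_noise_sigma`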
+ latents = latents * sigmas[0]
+ self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device)
+ self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(latents.device)
+
+ def model_fn(x, t):
+ latent_model_input = torch.cat([x] * 2)
+
+ noise_pred = self.k_diffusion_model(latent_model_input, t, cond=text_embeddings)
+
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+ return noise_pred
+
+ latents = self.sampler(model_fn, latents, sigmas)
+
+ # 8. Post-processing
+ image = self.decode_latents(latents)
+
+ # 9. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
+
+ # 10. Convert to PIL
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/examples/community/seed_resize_stable_diffusion.py b/diffusers/examples/community/seed_resize_stable_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..9318277b8f01119aee386e026471211e1b87a1e8
--- /dev/null
+++ b/diffusers/examples/community/seed_resize_stable_diffusion.py
@@ -0,0 +1,367 @@
+"""
+ modified based on diffusion library from Huggingface: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+"""
+import inspect
+from typing import Callable, List, Optional, Union
+
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers import DiffusionPipeline
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from diffusers.utils import logging
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class SeedResizeStableDiffusionPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ ):
+ super().__init__()
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+ `attention_head_dim` must be a multiple of `slice_size`.
+ """
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = self.unet.config.attention_head_dim // 2
+ self.unet.set_attention_slice(slice_size)
+
+ def disable_attention_slicing(self):
+ r"""
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+ back to computing attention in one step.
+ """
+ # set slice_size = `None` to disable `attention slicing`
+ self.enable_attention_slicing(None)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ text_embeddings: Optional[torch.FloatTensor] = None,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`] and is ignored for other schedulers.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+ removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+ text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+
+ if text_embeddings is None:
+ text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = text_embeddings.shape
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""]
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ # get the initial random noise unless the user supplied it
+
+ # Unlike in other pipelines, latents need to be generated in the target device
+ # for 1-to-1 results reproducibility with the CompVis implementation.
+ # However this currently doesn't work in `mps`.
+ latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
+ latents_shape_reference = (batch_size * num_images_per_prompt, self.unet.config.in_channels, 64, 64)
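+ # the reference grid is fixed at 64x64 latents (512x512 pixels); noise drawn on it is re-used below so
+ # the same seed keeps a similar composition at other resolutions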
+ latents_dtype = text_embeddings.dtype
+ if latents is None:
+ if self.device.type == "mps":
+ # randn does not exist on mps
+ latents_reference = torch.randn(
+ latents_shape_reference, generator=generator, device="cpu", dtype=latents_dtype
+ ).to(self.device)
+ latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+ self.device
+ )
+ else:
+ latents_reference = torch.randn(
+ latents_shape_reference, generator=generator, device=self.device, dtype=latents_dtype
+ )
+ latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
+ else:
+ if latents.shape != latents_shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+ # user-supplied latents still need a reference noise tensor for the seed-resize blending below
+ latents_reference = torch.randn(
+ latents_shape_reference, generator=generator, device="cpu", dtype=latents_dtype
+ ).to(self.device)
+ latents = latents.to(self.device)
+
+ # This is the key part of the pipeline where we
+ # try to ensure that the generated images w/ the same seed
+ # but different sizes actually result in similar images
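+ # dx/dy locate the reference grid relative to the target grid: if the target is larger, the reference
+ # noise is pasted centered into it; if smaller, a centered crop of the reference is used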
+ dx = (latents_shape[3] - latents_shape_reference[3]) // 2
+ dy = (latents_shape[2] - latents_shape_reference[2]) // 2
+ w = latents_shape_reference[3] if dx >= 0 else latents_shape_reference[3] + 2 * dx
+ h = latents_shape_reference[2] if dy >= 0 else latents_shape_reference[2] + 2 * dy
+ tx = 0 if dx < 0 else dx
+ ty = 0 if dy < 0 else dy
+ dx = max(-dx, 0)
+ dy = max(-dy, 0)
+ latents[:, :, ty : ty + h, tx : tx + w] = latents_reference[:, :, dy : dy + h, dx : dx + w]
+
+ # set timesteps
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ # Some schedulers like PNDM have timesteps as arrays
+ # It's more optimized to move all timesteps to correct device beforehand
+ timesteps_tensor = self.scheduler.timesteps.to(self.device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
+ self.device
+ )
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
+ )
+ else:
+ has_nsfw_concept = None
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/examples/community/speech_to_image_diffusion.py b/diffusers/examples/community/speech_to_image_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..63bcfb662517e27e1afafa7040ba3c508cd8c90c
--- /dev/null
+++ b/diffusers/examples/community/speech_to_image_diffusion.py
@@ -0,0 +1,262 @@
+import inspect
+from typing import Callable, List, Optional, Union
+
+import torch
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTokenizer,
+ WhisperForConditionalGeneration,
+ WhisperProcessor,
+)
+
+from diffusers import (
+ AutoencoderKL,
+ DDIMScheduler,
+ DiffusionPipeline,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+ UNet2DConditionModel,
+)
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.utils import logging
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class SpeechToImagePipeline(DiffusionPipeline):
+ def __init__(
+ self,
+ speech_model: WhisperForConditionalGeneration,
+ speech_processor: WhisperProcessor,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ ):
+ super().__init__()
+
+ if safety_checker is None:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ self.register_modules(
+ speech_model=speech_model,
+ speech_processor=speech_processor,
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ feature_extractor=feature_extractor,
+ )
+
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ if slice_size == "auto":
+ slice_size = self.unet.config.attention_head_dim // 2
+ self.unet.set_attention_slice(slice_size)
+
+ def disable_attention_slicing(self):
+ self.enable_attention_slicing(None)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ audio,
+ sampling_rate=16_000,
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ **kwargs,
+ ):
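+ # transcribe the input audio with Whisper; the decoded text is used as the Stable Diffusion prompt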
+ inputs = self.speech_processor.feature_extractor(
+ audio, return_tensors="pt", sampling_rate=sampling_rate
+ ).input_features.to(self.device)
+ predicted_ids = self.speech_model.generate(inputs, max_length=480_000)
+
+ prompt = self.speech_processor.tokenizer.batch_decode(predicted_ids, skip_special_tokens=True, normalize=True)[
+ 0
+ ]
+
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+ removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+ text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+ text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = text_embeddings.shape
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ # get the initial random noise unless the user supplied it
+
+ # Unlike in other pipelines, latents need to be generated in the target device
+ # for 1-to-1 results reproducibility with the CompVis implementation.
+ # However this currently doesn't work in `mps`.
+ latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
+ latents_dtype = text_embeddings.dtype
+ if latents is None:
+ if self.device.type == "mps":
+ # randn does not exist on mps
+ latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+ self.device
+ )
+ else:
+ latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
+ else:
+ if latents.shape != latents_shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+ latents = latents.to(self.device)
+
+ # set timesteps
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ # Some schedulers like PNDM have timesteps as arrays
+ # It's more optimized to move all timesteps to correct device beforehand
+ timesteps_tensor = self.scheduler.timesteps.to(self.device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return image
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)
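The `__call__` body above first transcribes the input audio into a text prompt and then runs a standard Stable Diffusion text-to-image loop with classifier-free guidance. A minimal usage sketch follows; the `speech_to_image_diffusion` community-pipeline name, the Whisper checkpoints, and the `audio` / `sampling_rate` call signature are assumptions inferred from the body above, not confirmed by this excerpt:

```python
import numpy as np
import torch
from diffusers import DiffusionPipeline
from transformers import WhisperForConditionalGeneration, WhisperProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"

# `audio` is a 1-D waveform; load a real clip with librosa/torchaudio/datasets as you prefer.
audio = np.zeros(16_000 * 5, dtype=np.float32)  # placeholder: 5 s of silence at 16 kHz
sampling_rate = 16_000

# Assumption: the file above is exposed as the "speech_to_image_diffusion" community pipeline.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    custom_pipeline="speech_to_image_diffusion",
    speech_model=WhisperForConditionalGeneration.from_pretrained("openai/whisper-small"),
    speech_processor=WhisperProcessor.from_pretrained("openai/whisper-small"),
).to(device)

image = pipe(audio, sampling_rate=sampling_rate, num_inference_steps=30).images[0]
image.save("speech_to_image.png")
```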
diff --git a/diffusers/examples/community/stable_diffusion_comparison.py b/diffusers/examples/community/stable_diffusion_comparison.py
new file mode 100644
index 0000000000000000000000000000000000000000..7997a0cc01864dfe2ac0e37f8f5b4d5559c0ca4c
--- /dev/null
+++ b/diffusers/examples/community/stable_diffusion_comparison.py
@@ -0,0 +1,405 @@
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers import (
+ AutoencoderKL,
+ DDIMScheduler,
+ DiffusionPipeline,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+ StableDiffusionPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+
+
+pipe1_model_id = "CompVis/stable-diffusion-v1-1"
+pipe2_model_id = "CompVis/stable-diffusion-v1-2"
+pipe3_model_id = "CompVis/stable-diffusion-v1-3"
+pipe4_model_id = "CompVis/stable-diffusion-v1-4"
+
+
+class StableDiffusionComparisonPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for parallel comparison of Stable Diffusion v1-v4
+ This pipeline inherits from DiffusionPipeline and depends on the use of an Auth Token for
+ downloading pre-trained checkpoints from Hugging Face Hub.
+ If using Hugging Face Hub, pass the Model ID for Stable Diffusion v1.4 as the previous 3 checkpoints will be loaded
+ automatically.
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionMegaSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ self.pipe1 = StableDiffusionPipeline.from_pretrained(pipe1_model_id)
+ self.pipe2 = StableDiffusionPipeline.from_pretrained(pipe2_model_id)
+ self.pipe3 = StableDiffusionPipeline.from_pretrained(pipe3_model_id)
+ self.pipe4 = StableDiffusionPipeline(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ requires_safety_checker=requires_safety_checker,
+ )
+
+ self.register_modules(pipeline1=self.pipe1, pipeline2=self.pipe2, pipeline3=self.pipe3, pipeline4=self.pipe4)
+
+ @property
+ def layers(self) -> Dict[str, Any]:
+ return {k: getattr(self, k) for k in self.config.keys() if not k.startswith("_")}
+
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+ Args:
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+ `attention_head_dim` must be a multiple of `slice_size`.
+ """
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = self.unet.config.attention_head_dim // 2
+ self.unet.set_attention_slice(slice_size)
+
+ def disable_attention_slicing(self):
+ r"""
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+ back to computing attention in one step.
+ """
+ # set slice_size = `None` to disable `attention slicing`
+ self.enable_attention_slicing(None)
+
+ @torch.no_grad()
+ def text2img_sd1_1(
+ self,
+ prompt: Union[str, List[str]],
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ **kwargs,
+ ):
+ return self.pipe1(
+ prompt=prompt,
+ height=height,
+ width=width,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ latents=latents,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ callback_steps=callback_steps,
+ **kwargs,
+ )
+
+ @torch.no_grad()
+ def text2img_sd1_2(
+ self,
+ prompt: Union[str, List[str]],
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ **kwargs,
+ ):
+ return self.pipe2(
+ prompt=prompt,
+ height=height,
+ width=width,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ latents=latents,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ callback_steps=callback_steps,
+ **kwargs,
+ )
+
+ @torch.no_grad()
+ def text2img_sd1_3(
+ self,
+ prompt: Union[str, List[str]],
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ **kwargs,
+ ):
+ return self.pipe3(
+ prompt=prompt,
+ height=height,
+ width=width,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ latents=latents,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ callback_steps=callback_steps,
+ **kwargs,
+ )
+
+ @torch.no_grad()
+ def text2img_sd1_4(
+ self,
+ prompt: Union[str, List[str]],
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ **kwargs,
+ ):
+ return self.pipe4(
+ prompt=prompt,
+ height=height,
+ width=width,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ latents=latents,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ callback_steps=callback_steps,
+ **kwargs,
+ )
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation. It produces 4 results by running the four
+ pipelines for Stable Diffusion v1.1-v1.4 sequentially on the same inputs.
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ height (`int`, optional, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, optional, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, optional, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, optional, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ eta (`float`, optional, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, optional):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.FloatTensor`, optional):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, optional, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, optional, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ self.to(device)
+
+ # Checks if the height and width are divisible by 8 or not
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` must be divisible by 8 but are {height} and {width}.")
+
+ # Get first result from Stable Diffusion Checkpoint v1.1
+ res1 = self.text2img_sd1_1(
+ prompt=prompt,
+ height=height,
+ width=width,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ latents=latents,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ callback_steps=callback_steps,
+ **kwargs,
+ )
+
+ # Get first result from Stable Diffusion Checkpoint v1.2
+ res2 = self.text2img_sd1_2(
+ prompt=prompt,
+ height=height,
+ width=width,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ latents=latents,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ callback_steps=callback_steps,
+ **kwargs,
+ )
+
+ # Get first result from Stable Diffusion Checkpoint v1.3
+ res3 = self.text2img_sd1_3(
+ prompt=prompt,
+ height=height,
+ width=width,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ latents=latents,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ callback_steps=callback_steps,
+ **kwargs,
+ )
+
+ # Get first result from Stable Diffusion Checkpoint v1.4
+ res4 = self.text2img_sd1_4(
+ prompt=prompt,
+ height=height,
+ width=width,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ latents=latents,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ callback_steps=callback_steps,
+ **kwargs,
+ )
+
+ # Get all result images into a single list and pass it via StableDiffusionPipelineOutput for final result
+ return StableDiffusionPipelineOutput(images=[res1[0], res2[0], res3[0], res4[0]], nsfw_content_detected=None)
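The comparison pipeline simply fans the same arguments out to the four SD v1.x checkpoints and packs the per-checkpoint results into one `StableDiffusionPipelineOutput`. A minimal usage sketch, assuming the file is resolved by name as the `stable_diffusion_comparison` community pipeline (the components loaded from the v1-4 repo are reused for `pipe4`, while v1-1 to v1-3 are downloaded inside `__init__`):

```python
from diffusers import DiffusionPipeline

# Assumption: community pipelines are resolved by file name ("stable_diffusion_comparison").
pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="stable_diffusion_comparison",
)

# __call__ moves the pipeline to CUDA itself when available and returns one entry per checkpoint.
output = pipe(prompt="an astronaut riding a horse on mars", num_inference_steps=30)
for idx, res in enumerate(output.images, start=1):
    img = res[0] if isinstance(res, list) else res  # each entry holds the images of one checkpoint
    img.save(f"sd_v1_{idx}.png")
```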
diff --git a/diffusers/examples/community/stable_diffusion_controlnet_img2img.py b/diffusers/examples/community/stable_diffusion_controlnet_img2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2b92fff0fb551d55d70013fa8d5624b09143f3d
--- /dev/null
+++ b/diffusers/examples/community/stable_diffusion_controlnet_img2img.py
@@ -0,0 +1,990 @@
+# Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers import AutoencoderKL, ControlNetModel, DiffusionPipeline, UNet2DConditionModel, logging
+from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ PIL_INTERPOLATION,
+ is_accelerate_available,
+ is_accelerate_version,
+ replace_example_docstring,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import numpy as np
+ >>> import torch
+ >>> from PIL import Image
+ >>> from diffusers import ControlNetModel, UniPCMultistepScheduler
+ >>> from diffusers.utils import load_image
+
+ >>> input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")
+
+ >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
+
+ >>> pipe_controlnet = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+ controlnet=controlnet,
+ safety_checker=None,
+ torch_dtype=torch.float16
+ )
+
+ >>> pipe_controlnet.scheduler = UniPCMultistepScheduler.from_config(pipe_controlnet.scheduler.config)
+ >>> pipe_controlnet.enable_xformers_memory_efficient_attention()
+ >>> pipe_controlnet.enable_model_cpu_offload()
+
+ # using image with edges for our canny controlnet
+ >>> control_image = load_image(
+ "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/vermeer_canny_edged.png")
+
+
+ >>> result_img = pipe_controlnet(controlnet_conditioning_image=control_image,
+ image=input_image,
+ prompt="an android robot, cyberpank, digitl art masterpiece",
+ num_inference_steps=20).images[0]
+
+ >>> result_img.show()
+ ```
+"""
+
+
+def prepare_image(image):
+ if isinstance(image, torch.Tensor):
+ # Batch single image
+ if image.ndim == 3:
+ image = image.unsqueeze(0)
+
+ image = image.to(dtype=torch.float32)
+ else:
+ # preprocess image
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
+ image = [image]
+
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+ image = np.concatenate([i[None, :] for i in image], axis=0)
+
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ return image
+
+
+def prepare_controlnet_conditioning_image(
+ controlnet_conditioning_image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance,
+):
+ if not isinstance(controlnet_conditioning_image, torch.Tensor):
+ if isinstance(controlnet_conditioning_image, PIL.Image.Image):
+ controlnet_conditioning_image = [controlnet_conditioning_image]
+
+ if isinstance(controlnet_conditioning_image[0], PIL.Image.Image):
+ controlnet_conditioning_image = [
+ np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]))[None, :]
+ for i in controlnet_conditioning_image
+ ]
+ controlnet_conditioning_image = np.concatenate(controlnet_conditioning_image, axis=0)
+ controlnet_conditioning_image = np.array(controlnet_conditioning_image).astype(np.float32) / 255.0
+ controlnet_conditioning_image = controlnet_conditioning_image.transpose(0, 3, 1, 2)
+ controlnet_conditioning_image = torch.from_numpy(controlnet_conditioning_image)
+ elif isinstance(controlnet_conditioning_image[0], torch.Tensor):
+ controlnet_conditioning_image = torch.cat(controlnet_conditioning_image, dim=0)
+
+ image_batch_size = controlnet_conditioning_image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ controlnet_conditioning_image = controlnet_conditioning_image.repeat_interleave(repeat_by, dim=0)
+
+ controlnet_conditioning_image = controlnet_conditioning_image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance:
+ controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)
+
+ return controlnet_conditioning_image
+
+
+class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
+ """
+ Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
+ """
+
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ if isinstance(controlnet, (list, tuple)):
+ controlnet = MultiControlNetModel(controlnet)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ controlnet=controlnet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+ steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_sequential_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+ text_encoder, vae, controlnet, and safety checker have their state dicts saved to CPU and then are moved to a
+ `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
+ Note that offloading happens on a submodule basis. Memory savings are higher than with
+ `enable_model_cpu_offload`, but performance is lower.
+ """
+ if is_accelerate_available():
+ from accelerate import cpu_offload
+ else:
+ raise ImportError("Please install accelerate via `pip install accelerate`")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.controlnet]:
+ cpu_offload(cpu_offloaded_model, device)
+
+ if self.safety_checker is not None:
+ cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
+
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ hook = None
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ if self.safety_checker is not None:
+ # the safety checker can offload the vae again
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+
+ # control net hook has to be manually offloaded as it alternates with unet
+ cpu_offload_with_hook(self.controlnet, device)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ @property
+ def _execution_device(self):
+ r"""
+ Returns the device on which the pipeline's models will be executed. After calling
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+ hooks.
+ """
+ if not hasattr(self.unet, "_hf_hook"):
+ return self.device
+ for module in self.unet.modules():
+ if (
+ hasattr(module, "_hf_hook")
+ and hasattr(module._hf_hook, "execution_device")
+ and module._hf_hook.execution_device is not None
+ ):
+ return torch.device(module._hf_hook.execution_device)
+ return self.device
+
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ """
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ else:
+ has_nsfw_concept = None
+ return image, has_nsfw_concept
+
+ def decode_latents(self, latents):
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_controlnet_conditioning_image(self, image, prompt, prompt_embeds):
+ image_is_pil = isinstance(image, PIL.Image.Image)
+ image_is_tensor = isinstance(image, torch.Tensor)
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
+
+ if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list:
+ raise TypeError(
+ "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
+ )
+
+ if image_is_pil:
+ image_batch_size = 1
+ elif image_is_tensor:
+ image_batch_size = image.shape[0]
+ elif image_is_pil_list:
+ image_batch_size = len(image)
+ elif image_is_tensor_list:
+ image_batch_size = len(image)
+ else:
+ raise ValueError("controlnet condition image is not valid")
+
+ if prompt is not None and isinstance(prompt, str):
+ prompt_batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ prompt_batch_size = len(prompt)
+ elif prompt_embeds is not None:
+ prompt_batch_size = prompt_embeds.shape[0]
+ else:
+ raise ValueError("prompt or prompt_embeds are not valid")
+
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
+ raise ValueError(
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
+ )
+
+ def check_inputs(
+ self,
+ prompt,
+ image,
+ controlnet_conditioning_image,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ strength=None,
+ controlnet_guidance_start=None,
+ controlnet_guidance_end=None,
+ controlnet_conditioning_scale=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # check controlnet condition image
+
+ if isinstance(self.controlnet, ControlNetModel):
+ self.check_controlnet_conditioning_image(controlnet_conditioning_image, prompt, prompt_embeds)
+ elif isinstance(self.controlnet, MultiControlNetModel):
+ if not isinstance(controlnet_conditioning_image, list):
+ raise TypeError("For multiple controlnets: `image` must be type `list`")
+
+ if len(controlnet_conditioning_image) != len(self.controlnet.nets):
+ raise ValueError(
+ "For multiple controlnets: `image` must have the same length as the number of controlnets."
+ )
+
+ for image_ in controlnet_conditioning_image:
+ self.check_controlnet_conditioning_image(image_, prompt, prompt_embeds)
+ else:
+ assert False
+
+ # Check `controlnet_conditioning_scale`
+
+ if isinstance(self.controlnet, ControlNetModel):
+ if not isinstance(controlnet_conditioning_scale, float):
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
+ elif isinstance(self.controlnet, MultiControlNetModel):
+ if isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
+ self.controlnet.nets
+ ):
+ raise ValueError(
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
+ " the same length as the number of controlnets"
+ )
+ else:
+ assert False
+
+ if isinstance(image, torch.Tensor):
+ if image.ndim != 3 and image.ndim != 4:
+ raise ValueError("`image` must have 3 or 4 dimensions")
+
+ if image.ndim == 3:
+ image_batch_size = 1
+ image_channels, image_height, image_width = image.shape
+ elif image.ndim == 4:
+ image_batch_size, image_channels, image_height, image_width = image.shape
+ else:
+ assert False
+
+ if image_channels != 3:
+ raise ValueError("`image` must have 3 channels")
+
+ if image.min() < -1 or image.max() > 1:
+ raise ValueError("`image` should be in range [-1, 1]")
+
+ if self.vae.config.latent_channels != self.unet.config.in_channels:
+ raise ValueError(
+ f"The config of `pipeline.unet` expects {self.unet.config.in_channels} but received"
+ f" latent channels: {self.vae.config.latent_channels},"
+ f" Please verify the config of `pipeline.unet` and the `pipeline.vae`"
+ )
+
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of `strength` should in [0.0, 1.0] but is {strength}")
+
+ if controlnet_guidance_start < 0 or controlnet_guidance_start > 1:
+ raise ValueError(
+ f"The value of `controlnet_guidance_start` should in [0.0, 1.0] but is {controlnet_guidance_start}"
+ )
+
+ if controlnet_guidance_end < 0 or controlnet_guidance_end > 1:
+ raise ValueError(
+ f"The value of `controlnet_guidance_end` should in [0.0, 1.0] but is {controlnet_guidance_end}"
+ )
+
+ if controlnet_guidance_start > controlnet_guidance_end:
+ raise ValueError(
+ "The value of `controlnet_guidance_start` should be less than `controlnet_guidance_end`, but got"
+ f" `controlnet_guidance_start` {controlnet_guidance_start} >= `controlnet_guidance_end` {controlnet_guidance_end}"
+ )
+
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start:]
+
+ return timesteps, num_inference_steps - t_start
+
+ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+
+ image = image.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if isinstance(generator, list):
+ init_latents = [
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+ ]
+ init_latents = torch.cat(init_latents, dim=0)
+ else:
+ init_latents = self.vae.encode(image).latent_dist.sample(generator)
+
+ init_latents = self.vae.config.scaling_factor * init_latents
+
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ elif batch_size > init_latents.shape[0]:
+ # duplicate the image latents so they match the effective prompt batch size
+ init_latents = init_latents.repeat(batch_size // init_latents.shape[0], 1, 1, 1)
+
+ shape = init_latents.shape
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ # get latents
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+ latents = init_latents
+
+ return latents
+
+ def _default_height_width(self, height, width, image):
+ if isinstance(image, list):
+ image = image[0]
+
+ if height is None:
+ if isinstance(image, PIL.Image.Image):
+ height = image.height
+ elif isinstance(image, torch.Tensor):
+ height = image.shape[3]
+
+ height = (height // 8) * 8 # round down to nearest multiple of 8
+
+ if width is None:
+ if isinstance(image, PIL.Image.Image):
+ width = image.width
+ elif isinstance(image, torch.Tensor):
+ width = image.shape[2]
+
+ width = (width // 8) * 8 # round down to nearest multiple of 8
+
+ return height, width
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Union[torch.Tensor, PIL.Image.Image] = None,
+ controlnet_conditioning_image: Union[
+ torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]
+ ] = None,
+ strength: float = 0.8,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+ controlnet_guidance_start: float = 0.0,
+ controlnet_guidance_end: float = 1.0,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ image (`torch.Tensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to be used as the starting point for the img2img
+ process. It is noised according to `strength` and then denoised, guided by `prompt` and the controlnet.
+ controlnet_conditioning_image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`):
+ The ControlNet input condition. ControlNet uses this input condition to generate guidance for the UNet. If
+ the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
+ also be accepted as an image. The control image is automatically resized to fit the output image.
+ strength (`float`, *optional*):
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
+ will be used as a starting point, adding more noise to it the larger the `strength`. The number of
+ denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
+ be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
+ The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
+ to the residual in the original unet.
+ controlnet_guidance_start (`float`, *optional*, defaults to 0.0):
+ The fraction of total steps at which the controlnet starts being applied. Must be between 0 and 1.
+ controlnet_guidance_end (`float`, *optional*, defaults to 1.0):
+ The fraction of total steps at which the controlnet stops being applied. Must be between 0 and 1 and
+ not less than `controlnet_guidance_start`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height, width = self._default_height_width(height, width, controlnet_conditioning_image)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ image,
+ controlnet_conditioning_image,
+ height,
+ width,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ strength,
+ controlnet_guidance_start,
+ controlnet_guidance_end,
+ controlnet_conditioning_scale,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)
+
+ # 3. Encode input prompt
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+
+ # 4. Prepare image, and controlnet_conditioning_image
+ image = prepare_image(image)
+
+ # condition image(s)
+ if isinstance(self.controlnet, ControlNetModel):
+ controlnet_conditioning_image = prepare_controlnet_conditioning_image(
+ controlnet_conditioning_image=controlnet_conditioning_image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=self.controlnet.dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ )
+ elif isinstance(self.controlnet, MultiControlNetModel):
+ controlnet_conditioning_images = []
+
+ for image_ in controlnet_conditioning_image:
+ image_ = prepare_controlnet_conditioning_image(
+ controlnet_conditioning_image=image_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=self.controlnet.dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ )
+
+ controlnet_conditioning_images.append(image_)
+
+ controlnet_conditioning_image = controlnet_conditioning_images
+ else:
+ assert False
+
+ # 5. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 6. Prepare latent variables
+ latents = self.prepare_latents(
+ image,
+ latent_timestep,
+ batch_size,
+ num_images_per_prompt,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ )
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 8. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # compute the percentage of total steps we are at
+ current_sampling_percent = i / len(timesteps)
+
+ if (
+ current_sampling_percent < controlnet_guidance_start
+ or current_sampling_percent > controlnet_guidance_end
+ ):
+ # do not apply the controlnet
+ down_block_res_samples = None
+ mid_block_res_sample = None
+ else:
+ # apply the controlnet
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ controlnet_cond=controlnet_conditioning_image,
+ conditioning_scale=controlnet_conditioning_scale,
+ return_dict=False,
+ )
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ ).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ # If we do sequential model offloading, let's offload unet and controlnet
+ # manually for max memory savings
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.unet.to("cpu")
+ self.controlnet.to("cpu")
+ torch.cuda.empty_cache()
+
+ if output_type == "latent":
+ image = latents
+ has_nsfw_concept = None
+ elif output_type == "pil":
+ # 8. Post-processing
+ image = self.decode_latents(latents)
+
+ # 9. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ # 10. Convert to PIL
+ image = self.numpy_to_pil(image)
+ else:
+ # 8. Post-processing
+ image = self.decode_latents(latents)
+
+ # 9. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
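
The `controlnet_guidance_start` / `controlnet_guidance_end` arguments in the denoising loop above define a window, expressed as a fraction of the total steps, outside of which the ControlNet residuals are skipped and the bare UNet runs. A minimal standalone sketch of that gating, with illustrative values that are not taken from the patch:

```python
# Standalone sketch of the ControlNet guidance window used in the denoising loop above.
# The step count and window bounds are illustrative assumptions, not values from the patch.
num_steps = 20
controlnet_guidance_start, controlnet_guidance_end = 0.0, 0.8

for i in range(num_steps):
    current_sampling_percent = i / num_steps
    # mirrors the loop's check: apply the ControlNet only inside [start, end]
    apply_controlnet = controlnet_guidance_start <= current_sampling_percent <= controlnet_guidance_end
    print(f"step {i:2d}: apply_controlnet={apply_controlnet}")
```
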
diff --git a/diffusers/examples/community/stable_diffusion_controlnet_inpaint.py b/diffusers/examples/community/stable_diffusion_controlnet_inpaint.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8797336641827528f34929fb86b69cca5d5c818
--- /dev/null
+++ b/diffusers/examples/community/stable_diffusion_controlnet_inpaint.py
@@ -0,0 +1,1139 @@
+# Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn.functional as F
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers import AutoencoderKL, ControlNetModel, DiffusionPipeline, UNet2DConditionModel, logging
+from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ PIL_INTERPOLATION,
+ is_accelerate_available,
+ is_accelerate_version,
+ replace_example_docstring,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import numpy as np
+ >>> import torch
+ >>> from PIL import Image
+ >>> from stable_diffusion_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
+
+ >>> from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
+ >>> from diffusers import ControlNetModel, UniPCMultistepScheduler
+ >>> from diffusers.utils import load_image
+
+ >>> def ade_palette():
+ return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
+ [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
+ [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
+ [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
+ [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
+ [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
+ [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
+ [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
+ [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
+ [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
+ [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
+ [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
+ [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
+ [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
+ [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
+ [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
+ [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
+ [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
+ [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
+ [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
+ [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
+ [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
+ [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
+ [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
+ [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
+ [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
+ [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
+ [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
+ [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
+ [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
+ [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
+ [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
+ [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
+ [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
+ [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
+ [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
+ [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
+ [102, 255, 0], [92, 0, 255]]
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")
+ >>> image_segmentor = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-small")
+
+ >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-seg", torch_dtype=torch.float16)
+
+ >>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
+ "runwayml/stable-diffusion-inpainting", controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
+ )
+
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+ >>> pipe.enable_xformers_memory_efficient_attention()
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> def image_to_seg(image):
+ pixel_values = image_processor(image, return_tensors="pt").pixel_values
+ with torch.no_grad():
+ outputs = image_segmentor(pixel_values)
+ seg = image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
+ color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) # height, width, 3
+ palette = np.array(ade_palette())
+ for label, color in enumerate(palette):
+ color_seg[seg == label, :] = color
+ color_seg = color_seg.astype(np.uint8)
+ seg_image = Image.fromarray(color_seg)
+ return seg_image
+
+ >>> image = load_image(
+ "https://github.com/CompVis/latent-diffusion/raw/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+ )
+
+ >>> mask_image = load_image(
+ "https://github.com/CompVis/latent-diffusion/raw/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+ )
+
+ >>> controlnet_conditioning_image = image_to_seg(image)
+
+ >>> image = pipe(
+ "Face of a yellow cat, high resolution, sitting on a park bench",
+ image,
+ mask_image,
+ controlnet_conditioning_image,
+ num_inference_steps=20,
+ ).images[0]
+
+ >>> image.save("out.png")
+ ```
+"""
+
+
+def prepare_image(image):
+ if isinstance(image, torch.Tensor):
+ # Batch single image
+ if image.ndim == 3:
+ image = image.unsqueeze(0)
+
+ image = image.to(dtype=torch.float32)
+ else:
+ # preprocess image
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
+ image = [image]
+
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+ image = np.concatenate([i[None, :] for i in image], axis=0)
+
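+ # convert HWC uint8 pixels in [0, 255] to NCHW float tensors in [-1, 1] for the VAE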
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ return image
+
+
+def prepare_mask_image(mask_image):
+ if isinstance(mask_image, torch.Tensor):
+ if mask_image.ndim == 2:
+ # Batch and add channel dim for single mask
+ mask_image = mask_image.unsqueeze(0).unsqueeze(0)
+ elif mask_image.ndim == 3 and mask_image.shape[0] == 1:
+ # Single mask, the 0'th dimension is considered to be
+ # the existing batch size of 1
+ mask_image = mask_image.unsqueeze(0)
+ elif mask_image.ndim == 3 and mask_image.shape[0] != 1:
+ # Batch of mask, the 0'th dimension is considered to be
+ # the batching dimension
+ mask_image = mask_image.unsqueeze(1)
+
+ # Binarize mask
+ mask_image[mask_image < 0.5] = 0
+ mask_image[mask_image >= 0.5] = 1
+ else:
+ # preprocess mask
+ if isinstance(mask_image, (PIL.Image.Image, np.ndarray)):
+ mask_image = [mask_image]
+
+ if isinstance(mask_image, list) and isinstance(mask_image[0], PIL.Image.Image):
+ mask_image = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask_image], axis=0)
+ mask_image = mask_image.astype(np.float32) / 255.0
+ elif isinstance(mask_image, list) and isinstance(mask_image[0], np.ndarray):
+ mask_image = np.concatenate([m[None, None, :] for m in mask_image], axis=0)
+
+ mask_image[mask_image < 0.5] = 0
+ mask_image[mask_image >= 0.5] = 1
+ mask_image = torch.from_numpy(mask_image)
+
+ return mask_image
+
+
+def prepare_controlnet_conditioning_image(
+ controlnet_conditioning_image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance,
+):
+ if not isinstance(controlnet_conditioning_image, torch.Tensor):
+ if isinstance(controlnet_conditioning_image, PIL.Image.Image):
+ controlnet_conditioning_image = [controlnet_conditioning_image]
+
+ if isinstance(controlnet_conditioning_image[0], PIL.Image.Image):
+ controlnet_conditioning_image = [
+ np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]))[None, :]
+ for i in controlnet_conditioning_image
+ ]
+ controlnet_conditioning_image = np.concatenate(controlnet_conditioning_image, axis=0)
+ controlnet_conditioning_image = np.array(controlnet_conditioning_image).astype(np.float32) / 255.0
+ controlnet_conditioning_image = controlnet_conditioning_image.transpose(0, 3, 1, 2)
+ controlnet_conditioning_image = torch.from_numpy(controlnet_conditioning_image)
+ elif isinstance(controlnet_conditioning_image[0], torch.Tensor):
+ controlnet_conditioning_image = torch.cat(controlnet_conditioning_image, dim=0)
+
+ image_batch_size = controlnet_conditioning_image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ controlnet_conditioning_image = controlnet_conditioning_image.repeat_interleave(repeat_by, dim=0)
+
+ controlnet_conditioning_image = controlnet_conditioning_image.to(device=device, dtype=dtype)
+
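+ # duplicate the conditioning image so the unconditional branch of classifier-free guidance sees it too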
+ if do_classifier_free_guidance:
+ controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)
+
+ return controlnet_conditioning_image
+
+
+class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
+ """
+ Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
+ """
+
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ if isinstance(controlnet, (list, tuple)):
+ controlnet = MultiControlNetModel(controlnet)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ controlnet=controlnet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+ steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_sequential_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+ text_encoder, vae, controlnet, and safety checker have their state dicts saved to CPU and then are moved to a
+ `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
+ Note that offloading happens on a submodule basis. Memory savings are higher than with
+ `enable_model_cpu_offload`, but performance is lower.
+ """
+ if is_accelerate_available():
+ from accelerate import cpu_offload
+ else:
+ raise ImportError("Please install accelerate via `pip install accelerate`")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.controlnet]:
+ cpu_offload(cpu_offloaded_model, device)
+
+ if self.safety_checker is not None:
+ cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
+
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ hook = None
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ if self.safety_checker is not None:
+ # the safety checker can offload the vae again
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+
+ # control net hook has to be manually offloaded as it alternates with unet
+ cpu_offload_with_hook(self.controlnet, device)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ @property
+ def _execution_device(self):
+ r"""
+ Returns the device on which the pipeline's models will be executed. After calling
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+ hooks.
+ """
+ if not hasattr(self.unet, "_hf_hook"):
+ return self.device
+ for module in self.unet.modules():
+ if (
+ hasattr(module, "_hf_hook")
+ and hasattr(module._hf_hook, "execution_device")
+ and module._hf_hook.execution_device is not None
+ ):
+ return torch.device(module._hf_hook.execution_device)
+ return self.device
+
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead.
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ """
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ else:
+ has_nsfw_concept = None
+ return image, has_nsfw_concept
+
+ def decode_latents(self, latents):
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_controlnet_conditioning_image(self, image, prompt, prompt_embeds):
+ image_is_pil = isinstance(image, PIL.Image.Image)
+ image_is_tensor = isinstance(image, torch.Tensor)
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
+
+ if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list:
+ raise TypeError(
+ "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
+ )
+
+ if image_is_pil:
+ image_batch_size = 1
+ elif image_is_tensor:
+ image_batch_size = image.shape[0]
+ elif image_is_pil_list:
+ image_batch_size = len(image)
+ elif image_is_tensor_list:
+ image_batch_size = len(image)
+ else:
+ raise ValueError("controlnet condition image is not valid")
+
+ if prompt is not None and isinstance(prompt, str):
+ prompt_batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ prompt_batch_size = len(prompt)
+ elif prompt_embeds is not None:
+ prompt_batch_size = prompt_embeds.shape[0]
+ else:
+ raise ValueError("prompt or prompt_embeds are not valid")
+
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
+ raise ValueError(
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
+ )
+
+ def check_inputs(
+ self,
+ prompt,
+ image,
+ mask_image,
+ controlnet_conditioning_image,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ controlnet_conditioning_scale=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # check controlnet condition image
+ if isinstance(self.controlnet, ControlNetModel):
+ self.check_controlnet_conditioning_image(controlnet_conditioning_image, prompt, prompt_embeds)
+ elif isinstance(self.controlnet, MultiControlNetModel):
+ if not isinstance(controlnet_conditioning_image, list):
+ raise TypeError("For multiple controlnets: `image` must be type `list`")
+ if len(controlnet_conditioning_image) != len(self.controlnet.nets):
+ raise ValueError(
+ "For multiple controlnets: `image` must have the same length as the number of controlnets."
+ )
+ for image_ in controlnet_conditioning_image:
+ self.check_controlnet_conditioning_image(image_, prompt, prompt_embeds)
+ else:
+ assert False
+
+ # Check `controlnet_conditioning_scale`
+ if isinstance(self.controlnet, ControlNetModel):
+ if not isinstance(controlnet_conditioning_scale, float):
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
+ elif isinstance(self.controlnet, MultiControlNetModel):
+ if isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
+ self.controlnet.nets
+ ):
+ raise ValueError(
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
+ " the same length as the number of controlnets"
+ )
+ else:
+ assert False
+
+ if isinstance(image, torch.Tensor) and not isinstance(mask_image, torch.Tensor):
+ raise TypeError("if `image` is a tensor, `mask_image` must also be a tensor")
+
+ if isinstance(image, PIL.Image.Image) and not isinstance(mask_image, PIL.Image.Image):
+ raise TypeError("if `image` is a PIL image, `mask_image` must also be a PIL image")
+
+ if isinstance(image, torch.Tensor):
+ if image.ndim != 3 and image.ndim != 4:
+ raise ValueError("`image` must have 3 or 4 dimensions")
+
+ if mask_image.ndim != 2 and mask_image.ndim != 3 and mask_image.ndim != 4:
+ raise ValueError("`mask_image` must have 2, 3, or 4 dimensions")
+
+ if image.ndim == 3:
+ image_batch_size = 1
+ image_channels, image_height, image_width = image.shape
+ elif image.ndim == 4:
+ image_batch_size, image_channels, image_height, image_width = image.shape
+ else:
+ assert False
+
+ if mask_image.ndim == 2:
+ mask_image_batch_size = 1
+ mask_image_channels = 1
+ mask_image_height, mask_image_width = mask_image.shape
+ elif mask_image.ndim == 3:
+ mask_image_channels = 1
+ mask_image_batch_size, mask_image_height, mask_image_width = mask_image.shape
+ elif mask_image.ndim == 4:
+ mask_image_batch_size, mask_image_channels, mask_image_height, mask_image_width = mask_image.shape
+
+ if image_channels != 3:
+ raise ValueError("`image` must have 3 channels")
+
+ if mask_image_channels != 1:
+ raise ValueError("`mask_image` must have 1 channel")
+
+ if image_batch_size != mask_image_batch_size:
+ raise ValueError("`image` and `mask_image` must have the same batch sizes")
+
+ if image_height != mask_image_height or image_width != mask_image_width:
+ raise ValueError("`image` and `mask_image` must have the same height and width dimensions")
+
+ if image.min() < -1 or image.max() > 1:
+ raise ValueError("`image` should be in range [-1, 1]")
+
+ if mask_image.min() < 0 or mask_image.max() > 1:
+ raise ValueError("`mask_image` should be in range [0, 1]")
+ else:
+ mask_image_channels = 1
+ image_channels = 3
+
+ single_image_latent_channels = self.vae.config.latent_channels
+
+ total_latent_channels = single_image_latent_channels * 2 + mask_image_channels
+
+ if total_latent_channels != self.unet.config.in_channels:
+ raise ValueError(
+ f"The config of `pipeline.unet` expects {self.unet.config.in_channels} input channels but received"
+ f" non-inpainting latent channels: {single_image_latent_channels},"
+ f" mask channels: {mask_image_channels}, and masked image channels: {single_image_latent_channels}."
+ f" Please verify the config of `pipeline.unet` and the `mask_image` and `image` inputs."
+ )
+
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+
+ return latents
+
+ def prepare_mask_latents(self, mask_image, batch_size, height, width, dtype, device, do_classifier_free_guidance):
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ mask_image = F.interpolate(mask_image, size=(height // self.vae_scale_factor, width // self.vae_scale_factor))
+ mask_image = mask_image.to(device=device, dtype=dtype)
+
+ # duplicate mask for each generation per prompt, using mps friendly method
+ if mask_image.shape[0] < batch_size:
+ if not batch_size % mask_image.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask_image.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask_image = mask_image.repeat(batch_size // mask_image.shape[0], 1, 1, 1)
+
+ mask_image = torch.cat([mask_image] * 2) if do_classifier_free_guidance else mask_image
+
+ mask_image_latents = mask_image
+
+ return mask_image_latents
+
+ def prepare_masked_image_latents(
+ self, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
+ ):
+ masked_image = masked_image.to(device=device, dtype=dtype)
+
+ # encode the mask image into latents space so we can concatenate it to the latents
+ if isinstance(generator, list):
+ masked_image_latents = [
+ self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i])
+ for i in range(batch_size)
+ ]
+ masked_image_latents = torch.cat(masked_image_latents, dim=0)
+ else:
+ masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
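+ # scale by the VAE scaling factor so the masked-image latents match the scale of the noise latents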
+ masked_image_latents = self.vae.config.scaling_factor * masked_image_latents
+
+ # duplicate masked_image_latents for each generation per prompt, using mps friendly method
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
+
+ masked_image_latents = (
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+ )
+
+ # aligning device to prevent device errors when concatenating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+ return masked_image_latents
+
+ def _default_height_width(self, height, width, image):
+ if isinstance(image, list):
+ image = image[0]
+
+ if height is None:
+ if isinstance(image, PIL.Image.Image):
+ height = image.height
+ elif isinstance(image, torch.Tensor):
+ height = image.shape[3]
+
+ height = (height // 8) * 8 # round down to nearest multiple of 8
+
+ if width is None:
+ if isinstance(image, PIL.Image.Image):
+ width = image.width
+ elif isinstance(image, torch.Tensor):
+ width = image.shape[2]
+
+ width = (width // 8) * 8 # round down to nearest multiple of 8
+
+ return height, width
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Union[torch.Tensor, PIL.Image.Image] = None,
+ mask_image: Union[torch.Tensor, PIL.Image.Image] = None,
+ controlnet_conditioning_image: Union[
+ torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]
+ ] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ image (`torch.Tensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
+ be masked out with `mask_image` and repainted according to `prompt`.
+ mask_image (`torch.Tensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+ repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
+ to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
+ instead of 3, so the expected shape would be `(B, 1, H, W)`.
+ controlnet_conditioning_image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`):
+ The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
+ the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
+ also be accepted as an image. The control image is automatically resized to fit the output image.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead.
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
+ The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
+ to the residual in the original unet.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height, width = self._default_height_width(height, width, controlnet_conditioning_image)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ image,
+ mask_image,
+ controlnet_conditioning_image,
+ height,
+ width,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ controlnet_conditioning_scale,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)
+
+ # 3. Encode input prompt
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+
+ # 4. Prepare mask, image, and controlnet_conditioning_image
+ image = prepare_image(image)
+
+ mask_image = prepare_mask_image(mask_image)
+
+ # condition image(s)
+ if isinstance(self.controlnet, ControlNetModel):
+ controlnet_conditioning_image = prepare_controlnet_conditioning_image(
+ controlnet_conditioning_image=controlnet_conditioning_image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=self.controlnet.dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ )
+ elif isinstance(self.controlnet, MultiControlNetModel):
+ controlnet_conditioning_images = []
+
+ for image_ in controlnet_conditioning_image:
+ image_ = prepare_controlnet_conditioning_image(
+ controlnet_conditioning_image=image_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=self.controlnet.dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ )
+ controlnet_conditioning_images.append(image_)
+
+ controlnet_conditioning_image = controlnet_conditioning_images
+ else:
+ assert False
+
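+ # zero out the regions to be repainted (mask >= 0.5); only the preserved pixels are encoded below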
+ masked_image = image * (mask_image < 0.5)
+
+ # 5. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.vae.config.latent_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ mask_image_latents = self.prepare_mask_latents(
+ mask_image,
+ batch_size * num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ do_classifier_free_guidance,
+ )
+
+ masked_image_latents = self.prepare_masked_image_latents(
+ masked_image,
+ batch_size * num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ do_classifier_free_guidance,
+ )
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 8. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ non_inpainting_latent_model_input = (
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ )
+
+ non_inpainting_latent_model_input = self.scheduler.scale_model_input(
+ non_inpainting_latent_model_input, t
+ )
+
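+ # the inpainting UNet takes the noisy latents concatenated with the mask and masked-image latents
+ # along the channel dimension (4 + 1 + 4 = 9 channels for the standard SD inpainting UNet)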
+ inpainting_latent_model_input = torch.cat(
+ [non_inpainting_latent_model_input, mask_image_latents, masked_image_latents], dim=1
+ )
+
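+ # the ControlNet only sees the plain 4-channel latents, since pretrained ControlNets
+ # expect the standard (non-inpainting) UNet input rather than the 9-channel inpainting input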
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
+ non_inpainting_latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ controlnet_cond=controlnet_conditioning_image,
+ conditioning_scale=controlnet_conditioning_scale,
+ return_dict=False,
+ )
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ inpainting_latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ ).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ # If we do sequential model offloading, let's offload unet and controlnet
+ # manually for max memory savings
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.unet.to("cpu")
+ self.controlnet.to("cpu")
+ torch.cuda.empty_cache()
+
+ if output_type == "latent":
+ image = latents
+ has_nsfw_concept = None
+ elif output_type == "pil":
+ # 8. Post-processing
+ image = self.decode_latents(latents)
+
+ # 9. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ # 10. Convert to PIL
+ image = self.numpy_to_pil(image)
+ else:
+ # 8. Post-processing
+ image = self.decode_latents(latents)
+
+ # 9. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
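
The docstring example in `stable_diffusion_controlnet_inpaint.py` imports `StableDiffusionControlNetInpaintPipeline` directly from the file. A minimal alternative sketch uses diffusers' `custom_pipeline` loader; whether the community name resolves to this exact revision depends on how the community folder is distributed, so treat the snippet as an assumption rather than part of the patch:

```python
import torch
from diffusers import ControlNetModel, DiffusionPipeline

# assumption: the community name resolves to the stable_diffusion_controlnet_inpaint.py file added above
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-seg", torch_dtype=torch.float16)
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    custom_pipeline="stable_diffusion_controlnet_inpaint",
    controlnet=controlnet,
    safety_checker=None,
    torch_dtype=torch.float16,
).to("cuda")
# `pipe` then accepts (prompt, image, mask_image, controlnet_conditioning_image, ...) as in the docstring example
```
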
diff --git a/diffusers/examples/community/stable_diffusion_controlnet_inpaint_img2img.py b/diffusers/examples/community/stable_diffusion_controlnet_inpaint_img2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..96ad3c39239d0f12dce61807cf5c86a8bd35195f
--- /dev/null
+++ b/diffusers/examples/community/stable_diffusion_controlnet_inpaint_img2img.py
@@ -0,0 +1,1120 @@
+# Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn.functional as F
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers import AutoencoderKL, ControlNetModel, DiffusionPipeline, UNet2DConditionModel, logging
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ PIL_INTERPOLATION,
+ is_accelerate_available,
+ is_accelerate_version,
+ replace_example_docstring,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import numpy as np
+ >>> import torch
+ >>> from PIL import Image
+ >>> from stable_diffusion_controlnet_inpaint_img2img import StableDiffusionControlNetInpaintImg2ImgPipeline
+
+ >>> from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
+ >>> from diffusers import ControlNetModel, UniPCMultistepScheduler
+ >>> from diffusers.utils import load_image
+
+ >>> def ade_palette():
+ return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
+ [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
+ [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
+ [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
+ [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
+ [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
+ [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
+ [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
+ [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
+ [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
+ [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
+ [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
+ [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
+ [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
+ [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
+ [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
+ [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
+ [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
+ [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
+ [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
+ [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
+ [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
+ [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
+ [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
+ [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
+ [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
+ [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
+ [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
+ [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
+ [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
+ [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
+ [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
+ [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
+ [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
+ [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
+ [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
+ [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
+ [102, 255, 0], [92, 0, 255]]
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")
+ >>> image_segmentor = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-small")
+
+ >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-seg", torch_dtype=torch.float16)
+
+ >>> pipe = StableDiffusionControlNetInpaintImg2ImgPipeline.from_pretrained(
+ "runwayml/stable-diffusion-inpainting", controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
+ )
+
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+ >>> pipe.enable_xformers_memory_efficient_attention()
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> def image_to_seg(image):
+ pixel_values = image_processor(image, return_tensors="pt").pixel_values
+ with torch.no_grad():
+ outputs = image_segmentor(pixel_values)
+ seg = image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
+ color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) # height, width, 3
+ palette = np.array(ade_palette())
+ for label, color in enumerate(palette):
+ color_seg[seg == label, :] = color
+ color_seg = color_seg.astype(np.uint8)
+ seg_image = Image.fromarray(color_seg)
+ return seg_image
+
+ >>> image = load_image(
+ "https://github.com/CompVis/latent-diffusion/raw/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+ )
+
+ >>> mask_image = load_image(
+ "https://github.com/CompVis/latent-diffusion/raw/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+ )
+
+ >>> controlnet_conditioning_image = image_to_seg(image)
+
+ >>> image = pipe(
+ "Face of a yellow cat, high resolution, sitting on a park bench",
+ image,
+ mask_image,
+ controlnet_conditioning_image,
+ num_inference_steps=20,
+ ).images[0]
+
+ >>> image.save("out.png")
+ ```
+"""
+
+
+def prepare_image(image):
+ if isinstance(image, torch.Tensor):
+ # Batch single image
+ if image.ndim == 3:
+ image = image.unsqueeze(0)
+
+ image = image.to(dtype=torch.float32)
+ else:
+ # preprocess image
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
+ image = [image]
+
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+ image = np.concatenate([i[None, :] for i in image], axis=0)
+
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ return image
+
+
+def prepare_mask_image(mask_image):
+ if isinstance(mask_image, torch.Tensor):
+ if mask_image.ndim == 2:
+ # Batch and add channel dim for single mask
+ mask_image = mask_image.unsqueeze(0).unsqueeze(0)
+ elif mask_image.ndim == 3 and mask_image.shape[0] == 1:
+ # Single mask, the 0'th dimension is considered to be
+ # the existing batch size of 1
+ mask_image = mask_image.unsqueeze(0)
+ elif mask_image.ndim == 3 and mask_image.shape[0] != 1:
+ # Batch of mask, the 0'th dimension is considered to be
+ # the batching dimension
+ mask_image = mask_image.unsqueeze(1)
+
+ # Binarize mask
+ mask_image[mask_image < 0.5] = 0
+ mask_image[mask_image >= 0.5] = 1
+ else:
+ # preprocess mask
+ if isinstance(mask_image, (PIL.Image.Image, np.ndarray)):
+ mask_image = [mask_image]
+
+ if isinstance(mask_image, list) and isinstance(mask_image[0], PIL.Image.Image):
+ mask_image = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask_image], axis=0)
+ mask_image = mask_image.astype(np.float32) / 255.0
+ elif isinstance(mask_image, list) and isinstance(mask_image[0], np.ndarray):
+ mask_image = np.concatenate([m[None, None, :] for m in mask_image], axis=0)
+
+ mask_image[mask_image < 0.5] = 0
+ mask_image[mask_image >= 0.5] = 1
+ mask_image = torch.from_numpy(mask_image)
+
+ return mask_image
+
+
+def prepare_controlnet_conditioning_image(
+ controlnet_conditioning_image, width, height, batch_size, num_images_per_prompt, device, dtype
+):
+ if not isinstance(controlnet_conditioning_image, torch.Tensor):
+ if isinstance(controlnet_conditioning_image, PIL.Image.Image):
+ controlnet_conditioning_image = [controlnet_conditioning_image]
+
+ if isinstance(controlnet_conditioning_image[0], PIL.Image.Image):
+ controlnet_conditioning_image = [
+ np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]))[None, :]
+ for i in controlnet_conditioning_image
+ ]
+ controlnet_conditioning_image = np.concatenate(controlnet_conditioning_image, axis=0)
+ controlnet_conditioning_image = np.array(controlnet_conditioning_image).astype(np.float32) / 255.0
+ controlnet_conditioning_image = controlnet_conditioning_image.transpose(0, 3, 1, 2)
+ controlnet_conditioning_image = torch.from_numpy(controlnet_conditioning_image)
+ elif isinstance(controlnet_conditioning_image[0], torch.Tensor):
+ controlnet_conditioning_image = torch.cat(controlnet_conditioning_image, dim=0)
+
+ image_batch_size = controlnet_conditioning_image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
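+ # expand the conditioning image so that every generated sample has its own copy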
+ controlnet_conditioning_image = controlnet_conditioning_image.repeat_interleave(repeat_by, dim=0)
+
+ controlnet_conditioning_image = controlnet_conditioning_image.to(device=device, dtype=dtype)
+
+ return controlnet_conditioning_image
+
+
+class StableDiffusionControlNetInpaintImg2ImgPipeline(DiffusionPipeline):
+ """
+ Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
+ """
+
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ controlnet: ControlNetModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ controlnet=controlnet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
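+ # the VAE halves the resolution once per down block, so latents are smaller than pixels by this factor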
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+ steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_sequential_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+ text_encoder, vae, controlnet, and safety checker have their state dicts saved to CPU and then are moved to a
+ `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
+ Note that offloading happens on a submodule basis. Memory savings are higher than with
+ `enable_model_cpu_offload`, but performance is lower.
+ """
+ if is_accelerate_available():
+ from accelerate import cpu_offload
+ else:
+ raise ImportError("Please install accelerate via `pip install accelerate`")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.controlnet]:
+ cpu_offload(cpu_offloaded_model, device)
+
+ if self.safety_checker is not None:
+ cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
+
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ hook = None
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ if self.safety_checker is not None:
+ # the safety checker can offload the vae again
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+
+ # the controlnet hook has to be offloaded manually as it alternates with the unet
+ cpu_offload_with_hook(self.controlnet, device)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ @property
+ def _execution_device(self):
+ r"""
+ Returns the device on which the pipeline's models will be executed. After calling
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+ hooks.
+ """
+ if not hasattr(self.unet, "_hf_hook"):
+ return self.device
+ for module in self.unet.modules():
+ if (
+ hasattr(module, "_hf_hook")
+ and hasattr(module._hf_hook, "execution_device")
+ and module._hf_hook.execution_device is not None
+ ):
+ return torch.device(module._hf_hook.execution_device)
+ return self.device
+
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead.
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ """
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ else:
+ has_nsfw_concept = None
+ return image, has_nsfw_concept
+
+ def decode_latents(self, latents):
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ image,
+ mask_image,
+ controlnet_conditioning_image,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ strength=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ controlnet_cond_image_is_pil = isinstance(controlnet_conditioning_image, PIL.Image.Image)
+ controlnet_cond_image_is_tensor = isinstance(controlnet_conditioning_image, torch.Tensor)
+ controlnet_cond_image_is_pil_list = isinstance(controlnet_conditioning_image, list) and isinstance(
+ controlnet_conditioning_image[0], PIL.Image.Image
+ )
+ controlnet_cond_image_is_tensor_list = isinstance(controlnet_conditioning_image, list) and isinstance(
+ controlnet_conditioning_image[0], torch.Tensor
+ )
+
+ if (
+ not controlnet_cond_image_is_pil
+ and not controlnet_cond_image_is_tensor
+ and not controlnet_cond_image_is_pil_list
+ and not controlnet_cond_image_is_tensor_list
+ ):
+ raise TypeError(
+ "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
+ )
+
+ if controlnet_cond_image_is_pil:
+ controlnet_cond_image_batch_size = 1
+ elif controlnet_cond_image_is_tensor:
+ controlnet_cond_image_batch_size = controlnet_conditioning_image.shape[0]
+ elif controlnet_cond_image_is_pil_list:
+ controlnet_cond_image_batch_size = len(controlnet_conditioning_image)
+ elif controlnet_cond_image_is_tensor_list:
+ controlnet_cond_image_batch_size = len(controlnet_conditioning_image)
+
+ if prompt is not None and isinstance(prompt, str):
+ prompt_batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ prompt_batch_size = len(prompt)
+ elif prompt_embeds is not None:
+ prompt_batch_size = prompt_embeds.shape[0]
+
+ if controlnet_cond_image_batch_size != 1 and controlnet_cond_image_batch_size != prompt_batch_size:
+ raise ValueError(
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {controlnet_cond_image_batch_size}, prompt batch size: {prompt_batch_size}"
+ )
+
+ if isinstance(image, torch.Tensor) and not isinstance(mask_image, torch.Tensor):
+ raise TypeError("if `image` is a tensor, `mask_image` must also be a tensor")
+
+ if isinstance(image, PIL.Image.Image) and not isinstance(mask_image, PIL.Image.Image):
+ raise TypeError("if `image` is a PIL image, `mask_image` must also be a PIL image")
+
+ if isinstance(image, torch.Tensor):
+ if image.ndim != 3 and image.ndim != 4:
+ raise ValueError("`image` must have 3 or 4 dimensions")
+
+ if mask_image.ndim != 2 and mask_image.ndim != 3 and mask_image.ndim != 4:
+ raise ValueError("`mask_image` must have 2, 3, or 4 dimensions")
+
+ if image.ndim == 3:
+ image_batch_size = 1
+ image_channels, image_height, image_width = image.shape
+ elif image.ndim == 4:
+ image_batch_size, image_channels, image_height, image_width = image.shape
+
+ if mask_image.ndim == 2:
+ mask_image_batch_size = 1
+ mask_image_channels = 1
+ mask_image_height, mask_image_width = mask_image.shape
+ elif mask_image.ndim == 3:
+ mask_image_channels = 1
+ mask_image_batch_size, mask_image_height, mask_image_width = mask_image.shape
+ elif mask_image.ndim == 4:
+ mask_image_batch_size, mask_image_channels, mask_image_height, mask_image_width = mask_image.shape
+
+ if image_channels != 3:
+ raise ValueError("`image` must have 3 channels")
+
+ if mask_image_channels != 1:
+ raise ValueError("`mask_image` must have 1 channel")
+
+ if image_batch_size != mask_image_batch_size:
+ raise ValueError("`image` and `mask_image` mush have the same batch sizes")
+
+ if image_height != mask_image_height or image_width != mask_image_width:
+ raise ValueError("`image` and `mask_image` must have the same height and width dimensions")
+
+ if image.min() < -1 or image.max() > 1:
+ raise ValueError("`image` should be in range [-1, 1]")
+
+ if mask_image.min() < 0 or mask_image.max() > 1:
+ raise ValueError("`mask_image` should be in range [0, 1]")
+ else:
+ mask_image_channels = 1
+ image_channels = 3
+
+ single_image_latent_channels = self.vae.config.latent_channels
+
+ total_latent_channels = single_image_latent_channels * 2 + mask_image_channels
+
+ if total_latent_channels != self.unet.config.in_channels:
+ raise ValueError(
+ f"The config of `pipeline.unet` expects {self.unet.config.in_channels} but received"
+ f" non inpainting latent channels: {single_image_latent_channels},"
+ f" mask channels: {mask_image_channels}, and masked image channels: {single_image_latent_channels}."
+ f" Please verify the config of `pipeline.unet` and the `mask_image` and `image` inputs."
+ )
+
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start:]
+
+ return timesteps, num_inference_steps - t_start
+
+ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+ )
+
+ image = image.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if isinstance(generator, list):
+ init_latents = [
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+ ]
+ init_latents = torch.cat(init_latents, dim=0)
+ else:
+ init_latents = self.vae.encode(image).latent_dist.sample(generator)
+
+ init_latents = self.vae.config.scaling_factor * init_latents
+
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ elif batch_size > init_latents.shape[0]:
+ # duplicate the image latents so that every prompt/sample gets its own copy
+ init_latents = init_latents.repeat(batch_size // init_latents.shape[0], 1, 1, 1)
+
+ shape = init_latents.shape
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ # get latents
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+ latents = init_latents
+
+ return latents
+
+ def prepare_mask_latents(self, mask_image, batch_size, height, width, dtype, device, do_classifier_free_guidance):
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ mask_image = F.interpolate(mask_image, size=(height // self.vae_scale_factor, width // self.vae_scale_factor))
+ mask_image = mask_image.to(device=device, dtype=dtype)
+
+ # duplicate mask for each generation per prompt, using mps friendly method
+ if mask_image.shape[0] < batch_size:
+ if not batch_size % mask_image.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask_image.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask_image = mask_image.repeat(batch_size // mask_image.shape[0], 1, 1, 1)
+
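+ # duplicate for classifier-free guidance so the mask matches the doubled latent batch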
+ mask_image = torch.cat([mask_image] * 2) if do_classifier_free_guidance else mask_image
+
+ mask_image_latents = mask_image
+
+ return mask_image_latents
+
+ def prepare_masked_image_latents(
+ self, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
+ ):
+ masked_image = masked_image.to(device=device, dtype=dtype)
+
+ # encode the mask image into latents space so we can concatenate it to the latents
+ if isinstance(generator, list):
+ masked_image_latents = [
+ self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i])
+ for i in range(batch_size)
+ ]
+ masked_image_latents = torch.cat(masked_image_latents, dim=0)
+ else:
+ masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
+ masked_image_latents = self.vae.config.scaling_factor * masked_image_latents
+
+ # duplicate masked_image_latents for each generation per prompt, using mps friendly method
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
+
+ masked_image_latents = (
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+ )
+
+ # aligning device to prevent device errors when concatenating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+ return masked_image_latents
+
+ def _default_height_width(self, height, width, image):
+ if isinstance(image, list):
+ image = image[0]
+
+ if height is None:
+ if isinstance(image, PIL.Image.Image):
+ height = image.height
+ elif isinstance(image, torch.Tensor):
+ height = image.shape[2]
+
+ height = (height // 8) * 8 # round down to nearest multiple of 8
+
+ if width is None:
+ if isinstance(image, PIL.Image.Image):
+ width = image.width
+ elif isinstance(image, torch.Tensor):
+ width = image.shape[3]
+
+ width = (width // 8) * 8 # round down to nearest multiple of 8
+
+ return height, width
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Union[torch.Tensor, PIL.Image.Image] = None,
+ mask_image: Union[torch.Tensor, PIL.Image.Image] = None,
+ controlnet_conditioning_image: Union[
+ torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]
+ ] = None,
+ strength: float = 0.8,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_conditioning_scale: float = 1.0,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ image (`torch.Tensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
+ be masked out with `mask_image` and repainted according to `prompt`.
+ mask_image (`torch.Tensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+ repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
+ to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
+ instead of 3, so the expected shape would be `(B, 1, H, W)`.
+ controlnet_conditioning_image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`):
+ The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
+ the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
+ also be accepted as an image. The control image is automatically resized to fit the output image.
+ strength (`float`, *optional*):
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
+ will be used as a starting point, adding more noise to it the larger the `strength`. The number of
+ denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
+ be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead.
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
+ The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
+ to the residual in the original unet.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height, width = self._default_height_width(height, width, controlnet_conditioning_image)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ image,
+ mask_image,
+ controlnet_conditioning_image,
+ height,
+ width,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ strength,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+
+ # 4. Prepare mask, image, and controlnet_conditioning_image
+ image = prepare_image(image)
+
+ mask_image = prepare_mask_image(mask_image)
+
+ controlnet_conditioning_image = prepare_controlnet_conditioning_image(
+ controlnet_conditioning_image,
+ width,
+ height,
+ batch_size * num_images_per_prompt,
+ num_images_per_prompt,
+ device,
+ self.controlnet.dtype,
+ )
+
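+ # keep only the pixels that are preserved (mask < 0.5); the region to be repainted is zeroed out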
+ masked_image = image * (mask_image < 0.5)
+
+ # 5. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
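+ # the first retained timestep controls how much noise is added to the image latents (img2img strength)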
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 6. Prepare latent variables
+ latents = self.prepare_latents(
+ image,
+ latent_timestep,
+ batch_size,
+ num_images_per_prompt,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ )
+
+ mask_image_latents = self.prepare_mask_latents(
+ mask_image,
+ batch_size * num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ do_classifier_free_guidance,
+ )
+
+ masked_image_latents = self.prepare_masked_image_latents(
+ masked_image,
+ batch_size * num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ do_classifier_free_guidance,
+ )
+
+ if do_classifier_free_guidance:
+ controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 8. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ non_inpainting_latent_model_input = (
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ )
+
+ non_inpainting_latent_model_input = self.scheduler.scale_model_input(
+ non_inpainting_latent_model_input, t
+ )
+
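+ # the inpainting UNet takes the latents concatenated with the mask and masked-image latents along the
+ # channel dim (e.g. 4 + 1 + 4 = 9 channels for Stable Diffusion inpainting checkpoints)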
+ inpainting_latent_model_input = torch.cat(
+ [non_inpainting_latent_model_input, mask_image_latents, masked_image_latents], dim=1
+ )
+
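+ # the ControlNet sees only the plain (non-inpainting) latents plus the conditioning image and returns
+ # residuals that are added inside the UNet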
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
+ non_inpainting_latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ controlnet_cond=controlnet_conditioning_image,
+ return_dict=False,
+ )
+
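+ # scale the ControlNet residuals by the conditioning strength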
+ down_block_res_samples = [
+ down_block_res_sample * controlnet_conditioning_scale
+ for down_block_res_sample in down_block_res_samples
+ ]
+ mid_block_res_sample *= controlnet_conditioning_scale
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ inpainting_latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ ).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ # If we do sequential model offloading, let's offload unet and controlnet
+ # manually for max memory savings
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.unet.to("cpu")
+ self.controlnet.to("cpu")
+ torch.cuda.empty_cache()
+
+ if output_type == "latent":
+ image = latents
+ has_nsfw_concept = None
+ elif output_type == "pil":
+ # 8. Post-processing
+ image = self.decode_latents(latents)
+
+ # 9. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ # 10. Convert to PIL
+ image = self.numpy_to_pil(image)
+ else:
+ # 8. Post-processing
+ image = self.decode_latents(latents)
+
+ # 9. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/examples/community/stable_diffusion_controlnet_reference.py b/diffusers/examples/community/stable_diffusion_controlnet_reference.py
new file mode 100644
index 0000000000000000000000000000000000000000..358fc1c6dc67f4c98a3f9c2d4bab75a027e19938
--- /dev/null
+++ b/diffusers/examples/community/stable_diffusion_controlnet_reference.py
@@ -0,0 +1,838 @@
+# Inspired by: https://github.com/Mikubill/sd-webui-controlnet/discussions/1236 and https://github.com/Mikubill/sd-webui-controlnet/discussions/1280
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+
+from diffusers import StableDiffusionControlNetPipeline
+from diffusers.models import ControlNetModel
+from diffusers.models.attention import BasicTransformerBlock
+from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
+from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.utils import logging
+from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import cv2
+ >>> import torch
+ >>> import numpy as np
+ >>> from PIL import Image
+ >>> from diffusers import ControlNetModel, UniPCMultistepScheduler
+ >>> from diffusers.utils import load_image
+
+ >>> input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")
+
+ >>> # get canny image
+ >>> image = cv2.Canny(np.array(input_image), 100, 200)
+ >>> image = image[:, :, None]
+ >>> image = np.concatenate([image, image, image], axis=2)
+ >>> canny_image = Image.fromarray(image)
+
+ >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
+ >>> pipe = StableDiffusionControlNetReferencePipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+ controlnet=controlnet,
+ safety_checker=None,
+ torch_dtype=torch.float16
+ ).to('cuda:0')
+
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+
+ >>> result_img = pipe(ref_image=input_image,
+ prompt="1girl",
+ image=canny_image,
+ num_inference_steps=20,
+ reference_attn=True,
+ reference_adain=True).images[0]
+
+ >>> result_img.show()
+ ```
+"""
+
+
+def torch_dfs(model: torch.nn.Module):
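+ # depth-first traversal that returns the module together with all of its descendants;
+ # used to collect the modules whose forward methods are patched for reference guidance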
+ result = [model]
+ for child in model.children():
+ result += torch_dfs(child)
+ return result
+
+
+class StableDiffusionControlNetReferencePipeline(StableDiffusionControlNetPipeline):
+ def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):
+ refimage = refimage.to(device=device, dtype=dtype)
+
+ # encode the mask image into latents space so we can concatenate it to the latents
+ if isinstance(generator, list):
+ ref_image_latents = [
+ self.vae.encode(refimage[i : i + 1]).latent_dist.sample(generator=generator[i])
+ for i in range(batch_size)
+ ]
+ ref_image_latents = torch.cat(ref_image_latents, dim=0)
+ else:
+ ref_image_latents = self.vae.encode(refimage).latent_dist.sample(generator=generator)
+ ref_image_latents = self.vae.config.scaling_factor * ref_image_latents
+
+ # duplicate mask and ref_image_latents for each generation per prompt, using mps friendly method
+ if ref_image_latents.shape[0] < batch_size:
+ if not batch_size % ref_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {ref_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ ref_image_latents = ref_image_latents.repeat(batch_size // ref_image_latents.shape[0], 1, 1, 1)
+
+ ref_image_latents = torch.cat([ref_image_latents] * 2) if do_classifier_free_guidance else ref_image_latents
+
+ # aligning device to prevent device errors when concatenating it with the latent model input
+ ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
+ return ref_image_latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Union[
+ torch.FloatTensor,
+ PIL.Image.Image,
+ np.ndarray,
+ List[torch.FloatTensor],
+ List[PIL.Image.Image],
+ List[np.ndarray],
+ ] = None,
+ ref_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+ guess_mode: bool = False,
+ attention_auto_machine_weight: float = 1.0,
+ gn_auto_machine_weight: float = 1.0,
+ style_fidelity: float = 0.5,
+ reference_attn: bool = True,
+ reference_adain: bool = True,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
+ `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+ The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
+ the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
+ also be accepted as an image. The dimensions of the output image default to `image`'s dimensions. If
+ height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
+ specified in init, images must be passed as a list such that each element of the list can be correctly
+ batched for input to a single controlnet.
+ ref_image (`torch.FloatTensor`, `PIL.Image.Image`):
+ The Reference Control input condition. Reference Control uses this input condition to generate guidance to Unet. If
+ the type is specified as `torch.FloatTensor`, it is passed to Reference Control as is. `PIL.Image.Image` can
+ also be accepted as an image.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
+ to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
+ corresponding scale as a list.
+ guess_mode (`bool`, *optional*, defaults to `False`):
+ In this mode, the ControlNet encoder will try its best to recognize the content of the input image even if
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
+ attention_auto_machine_weight (`float`):
+ Weight of using reference query for self attention's context.
+ If attention_auto_machine_weight=1.0, use reference query for all self attention's context.
+ gn_auto_machine_weight (`float`):
+ Weight of using reference adain. If gn_auto_machine_weight=2.0, use all reference adain plugins.
+ style_fidelity (`float`):
+ Style fidelity of ref_uncond_xt. If style_fidelity=1.0, the reference (control) dominates;
+ if style_fidelity=0.0, the prompt dominates; intermediate values balance the two.
+ reference_attn (`bool`):
+ Whether to use reference query for self attention's context.
+ reference_adain (`bool`):
+ Whether to use reference adain.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ assert reference_attn or reference_adain, "`reference_attn` or `reference_adain` must be True."
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ image,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ controlnet_conditioning_scale,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
+
+ global_pool_conditions = (
+ controlnet.config.global_pool_conditions
+ if isinstance(controlnet, ControlNetModel)
+ else controlnet.nets[0].config.global_pool_conditions
+ )
+ guess_mode = guess_mode or global_pool_conditions
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Prepare image
+ if isinstance(controlnet, ControlNetModel):
+ image = self.prepare_image(
+ image=image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+ height, width = image.shape[-2:]
+ elif isinstance(controlnet, MultiControlNetModel):
+ images = []
+
+ for image_ in image:
+ image_ = self.prepare_image(
+ image=image_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+
+ images.append(image_)
+
+ image = images
+ height, width = image[0].shape[-2:]
+ else:
+ assert False
+
+ # 5. Preprocess reference image
+ ref_image = self.prepare_image(
+ image=ref_image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=prompt_embeds.dtype,
+ )
+
+ # 6. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 7. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 8. Prepare reference latent variables
+ ref_image_latents = self.prepare_ref_latents(
+ ref_image,
+ batch_size * num_images_per_prompt,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ do_classifier_free_guidance,
+ )
+
+ # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 9. Modify self attention and group norm
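+ # "write" caches features from the reference image pass; "read" reuses them during the denoising pass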
+ MODE = "write"
+ uc_mask = (
+ torch.Tensor([1] * batch_size * num_images_per_prompt + [0] * batch_size * num_images_per_prompt)
+ .type_as(ref_image_latents)
+ .bool()
+ )
+
+ def hacked_basic_transformer_inner_forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ ):
+ if self.use_ada_layer_norm:
+ norm_hidden_states = self.norm1(hidden_states, timestep)
+ elif self.use_ada_layer_norm_zero:
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ else:
+ norm_hidden_states = self.norm1(hidden_states)
+
+ # 1. Self-Attention
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+ if self.only_cross_attention:
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ else:
+ if MODE == "write":
+ self.bank.append(norm_hidden_states.detach().clone())
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ if MODE == "read":
+ if attention_auto_machine_weight > self.attn_weight:
+ attn_output_uc = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=torch.cat([norm_hidden_states] + self.bank, dim=1),
+ # attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ attn_output_c = attn_output_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ attn_output_c[uc_mask] = self.attn1(
+ norm_hidden_states[uc_mask],
+ encoder_hidden_states=norm_hidden_states[uc_mask],
+ **cross_attention_kwargs,
+ )
+ attn_output = style_fidelity * attn_output_c + (1.0 - style_fidelity) * attn_output_uc
+ self.bank.clear()
+ else:
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ if self.use_ada_layer_norm_zero:
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ hidden_states = attn_output + hidden_states
+
+ if self.attn2 is not None:
+ norm_hidden_states = (
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+ )
+
+ # 2. Cross-Attention
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ **cross_attention_kwargs,
+ )
+ hidden_states = attn_output + hidden_states
+
+ # 3. Feed-forward
+ norm_hidden_states = self.norm3(hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ ff_output = self.ff(norm_hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+ hidden_states = ff_output + hidden_states
+
+ return hidden_states
+
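+        # The block-level hooks below implement the reference AdaIN path: during the "write" pass they bank
+        # the per-sample, per-channel mean/variance of the reference activations, and during the "read" pass
+        # they renormalize the current activations to those banked statistics, blended via `style_fidelity`
+        # so that the unconditional half stays untouched under classifier-free guidance.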
+ def hacked_mid_forward(self, *args, **kwargs):
+ eps = 1e-6
+ x = self.original_forward(*args, **kwargs)
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append(mean)
+ self.var_bank.append(var)
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank) / float(len(self.mean_bank))
+ var_acc = sum(self.var_bank) / float(len(self.var_bank))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ x_uc = (((x - mean) / std) * std_acc) + mean_acc
+ x_c = x_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ x_c[uc_mask] = x[uc_mask]
+ x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc
+ self.mean_bank = []
+ self.var_bank = []
+ return x
+
+        def hacked_CrossAttnDownBlock2D_forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ):
+ eps = 1e-6
+
+ # TODO(Patrick, William) - attention mask is not used
+ output_states = ()
+
+ for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ output_states = output_states + (hidden_states,)
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+ def hacked_DownBlock2D_forward(self, hidden_states, temb=None, *args, **kwargs):
+ eps = 1e-6
+
+ output_states = ()
+
+ for i, resnet in enumerate(self.resnets):
+ hidden_states = resnet(hidden_states, temb)
+
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ output_states = output_states + (hidden_states,)
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+ def hacked_CrossAttnUpBlock2D_forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ upsample_size: Optional[int] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ):
+ eps = 1e-6
+ # TODO(Patrick, William) - attention mask is not used
+ for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+ def hacked_UpBlock2D_forward(
+ self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, *args, **kwargs
+ ):
+ eps = 1e-6
+ for i, resnet in enumerate(self.resnets):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+ hidden_states = resnet(hidden_states, temb)
+
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
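+        # Attach the hooks: the functions above are bound to the existing modules via `__get__`, replacing
+        # their `forward` methods in place. `attn_weight` and `gn_weight` rank the modules so that the
+        # `attention_auto_machine_weight` and `gn_auto_machine_weight` thresholds can restrict reference
+        # control to a subset of blocks.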
+ if reference_attn:
+ attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock)]
+ attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
+
+ for i, module in enumerate(attn_modules):
+ module._original_inner_forward = module.forward
+ module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
+ module.bank = []
+ module.attn_weight = float(i) / float(len(attn_modules))
+
+ if reference_adain:
+ gn_modules = [self.unet.mid_block]
+ self.unet.mid_block.gn_weight = 0
+
+ down_blocks = self.unet.down_blocks
+ for w, module in enumerate(down_blocks):
+ module.gn_weight = 1.0 - float(w) / float(len(down_blocks))
+ gn_modules.append(module)
+
+ up_blocks = self.unet.up_blocks
+ for w, module in enumerate(up_blocks):
+ module.gn_weight = float(w) / float(len(up_blocks))
+ gn_modules.append(module)
+
+ for i, module in enumerate(gn_modules):
+ if getattr(module, "original_forward", None) is None:
+ module.original_forward = module.forward
+ if i == 0:
+ # mid_block
+ module.forward = hacked_mid_forward.__get__(module, torch.nn.Module)
+ elif isinstance(module, CrossAttnDownBlock2D):
+                    module.forward = hacked_CrossAttnDownBlock2D_forward.__get__(module, CrossAttnDownBlock2D)
+ elif isinstance(module, DownBlock2D):
+ module.forward = hacked_DownBlock2D_forward.__get__(module, DownBlock2D)
+ elif isinstance(module, CrossAttnUpBlock2D):
+ module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D)
+ elif isinstance(module, UpBlock2D):
+ module.forward = hacked_UpBlock2D_forward.__get__(module, UpBlock2D)
+ module.mean_bank = []
+ module.var_bank = []
+ module.gn_weight *= 2
+
+ # 11. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # controlnet(s) inference
+ if guess_mode and do_classifier_free_guidance:
+ # Infer ControlNet only for the conditional batch.
+ control_model_input = latents
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+ else:
+ control_model_input = latent_model_input
+ controlnet_prompt_embeds = prompt_embeds
+
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
+ control_model_input,
+ t,
+ encoder_hidden_states=controlnet_prompt_embeds,
+ controlnet_cond=image,
+ conditioning_scale=controlnet_conditioning_scale,
+ guess_mode=guess_mode,
+ return_dict=False,
+ )
+
+ if guess_mode and do_classifier_free_guidance:
+                    # Inferred ControlNet only for the conditional batch.
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
+ # add 0 to the unconditional batch to keep it unchanged.
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
+
+ # ref only part
+ noise = randn_tensor(
+ ref_image_latents.shape, generator=generator, device=device, dtype=ref_image_latents.dtype
+ )
+ ref_xt = self.scheduler.add_noise(
+ ref_image_latents,
+ noise,
+ t.reshape(
+ 1,
+ ),
+ )
+ ref_xt = self.scheduler.scale_model_input(ref_xt, t)
+
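+                # Each denoising step runs the UNet twice: a "write" pass on the noised reference latents to
+                # fill the attention/statistic banks, then the regular "read" pass on the actual latents,
+                # which consumes (and clears) those banks.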
+ MODE = "write"
+ self.unet(
+ ref_xt,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )
+
+ # predict the noise residual
+ MODE = "read"
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ # If we do sequential model offloading, let's offload unet and controlnet
+ # manually for max memory savings
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.unet.to("cpu")
+ self.controlnet.to("cpu")
+ torch.cuda.empty_cache()
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/examples/community/stable_diffusion_ipex.py b/diffusers/examples/community/stable_diffusion_ipex.py
new file mode 100644
index 0000000000000000000000000000000000000000..385227db0b7010ab6e9b6ffff96a96f661119205
--- /dev/null
+++ b/diffusers/examples/community/stable_diffusion_ipex.py
@@ -0,0 +1,852 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import intel_extension_for_pytorch as ipex
+import torch
+from packaging import version
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
+from diffusers.configuration_utils import FrozenDict
+from diffusers.loaders import TextualInversionLoaderMixin
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ deprecate,
+ is_accelerate_available,
+ is_accelerate_version,
+ logging,
+ replace_example_docstring,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+        >>> from diffusers import DiffusionPipeline
+
+        >>> pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", custom_pipeline="stable_diffusion_ipex")
+
+        >>> prompt = "a photo of an astronaut riding a horse on mars"
+        >>> num_inference_steps = 50
+
+        >>> # For Float32 (image height/width must match the values used at inference time)
+        >>> pipe.prepare_for_ipex(prompt, dtype=torch.float32, height=512, width=512)
+        >>> # For BFloat16
+        >>> pipe.prepare_for_ipex(prompt, dtype=torch.bfloat16, height=512, width=512)
+
+        >>> # For Float32 (height/width must be consistent with 'prepare_for_ipex()')
+        >>> image = pipe(prompt, num_inference_steps=num_inference_steps, height=512, width=512).images[0]
+        >>> # For BFloat16
+        >>> with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
+        >>>     image = pipe(prompt, num_inference_steps=num_inference_steps, height=512, width=512).images[0]
+ ```
+"""
+
+
+class StableDiffusionIPEXPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion on IPEX.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPFeatureExtractor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+                "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ def get_input_example(self, prompt, height=None, width=None, guidance_scale=7.5, num_images_per_prompt=1):
+ prompt_embeds = None
+ negative_prompt_embeds = None
+ negative_prompt = None
+ callback_steps = 1
+ generator = None
+ latents = None
+
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+
+ device = "cpu"
+        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+
+        # 4. Prepare latent variables
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ self.unet.in_channels,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+ dummy = torch.ones(1, dtype=torch.int32)
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, dummy)
+
+ unet_input_example = (latent_model_input, dummy, prompt_embeds)
+ vae_decoder_input_example = latents
+
+ return unet_input_example, vae_decoder_input_example
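+    # The tuples returned by `get_input_example` are only used to record shapes and dtypes when the UNet and
+    # the VAE decoder are traced with `torch.jit.trace` in `prepare_for_ipex` below; they are not used for
+    # actual generation.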
+
+    def prepare_for_ipex(self, prompt, dtype=torch.float32, height=None, width=None, guidance_scale=7.5):
+ self.unet = self.unet.to(memory_format=torch.channels_last)
+ self.vae.decoder = self.vae.decoder.to(memory_format=torch.channels_last)
+ self.text_encoder = self.text_encoder.to(memory_format=torch.channels_last)
+ if self.safety_checker is not None:
+ self.safety_checker = self.safety_checker.to(memory_format=torch.channels_last)
+
+        unet_input_example, vae_decoder_input_example = self.get_input_example(prompt, height, width, guidance_scale)
+
+ # optimize with ipex
+ if dtype == torch.bfloat16:
+ self.unet = ipex.optimize(self.unet.eval(), dtype=torch.bfloat16, inplace=True)
+ self.vae.decoder = ipex.optimize(self.vae.decoder.eval(), dtype=torch.bfloat16, inplace=True)
+ self.text_encoder = ipex.optimize(self.text_encoder.eval(), dtype=torch.bfloat16, inplace=True)
+ if self.safety_checker is not None:
+ self.safety_checker = ipex.optimize(self.safety_checker.eval(), dtype=torch.bfloat16, inplace=True)
+ elif dtype == torch.float32:
+ self.unet = ipex.optimize(
+ self.unet.eval(),
+ dtype=torch.float32,
+ inplace=True,
+ weights_prepack=True,
+ auto_kernel_selection=False,
+ )
+ self.vae.decoder = ipex.optimize(
+ self.vae.decoder.eval(),
+ dtype=torch.float32,
+ inplace=True,
+ weights_prepack=True,
+ auto_kernel_selection=False,
+ )
+ self.text_encoder = ipex.optimize(
+ self.text_encoder.eval(),
+ dtype=torch.float32,
+ inplace=True,
+ weights_prepack=True,
+ auto_kernel_selection=False,
+ )
+ if self.safety_checker is not None:
+ self.safety_checker = ipex.optimize(
+ self.safety_checker.eval(),
+ dtype=torch.float32,
+ inplace=True,
+ weights_prepack=True,
+ auto_kernel_selection=False,
+ )
+ else:
+            raise ValueError("The value of 'dtype' should be 'torch.bfloat16' or 'torch.float32'!")
+
+ # trace unet model to get better performance on IPEX
+ with torch.cpu.amp.autocast(enabled=dtype == torch.bfloat16), torch.no_grad():
+ unet_trace_model = torch.jit.trace(self.unet, unet_input_example, check_trace=False, strict=False)
+ unet_trace_model = torch.jit.freeze(unet_trace_model)
+ self.unet.forward = unet_trace_model.forward
+
+ # trace vae.decoder model to get better performance on IPEX
+ with torch.cpu.amp.autocast(enabled=dtype == torch.bfloat16), torch.no_grad():
+            vae_decoder_trace_model = torch.jit.trace(
+                self.vae.decoder, vae_decoder_input_example, check_trace=False, strict=False
+            )
+            vae_decoder_trace_model = torch.jit.freeze(vae_decoder_trace_model)
+            self.vae.decoder.forward = vae_decoder_trace_model.forward
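+    # NOTE: after `prepare_for_ipex`, `self.unet.forward` and `self.vae.decoder.forward` point to frozen
+    # TorchScript modules traced for a fixed input shape, which is why the height/width (and dtype) passed
+    # to `__call__` must match the values used here.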
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+ steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def enable_sequential_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
+ Note that offloading happens on a submodule basis. Memory savings are higher than with
+ `enable_model_cpu_offload`, but performance is lower.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
+ from accelerate import cpu_offload
+ else:
+ raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
+ cpu_offload(cpu_offloaded_model, device)
+
+ if self.safety_checker is not None:
+ cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
+
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ if self.safety_checker is not None:
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ @property
+ def _execution_device(self):
+ r"""
+ Returns the device on which the pipeline's models will be executed. After calling
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+ hooks.
+ """
+ if not hasattr(self.unet, "_hf_hook"):
+ return self.device
+ for module in self.unet.modules():
+ if (
+ hasattr(module, "_hf_hook")
+ and hasattr(module._hf_hook, "execution_device")
+ and module._hf_hook.execution_device is not None
+ ):
+ return torch.device(module._hf_hook.execution_device)
+ return self.device
+
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead.
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ """
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+            # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
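+            # Resulting layout: [negative_prompt_embeds; prompt_embeds]. For a typical SD 1.5 checkpoint with
+            # the CLIP ViT-L/14 text encoder (an assumption, not enforced here) this has shape
+            # (2 * batch_size * num_images_per_prompt, 77, 768).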
+
+ return prompt_embeds
+
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ else:
+ has_nsfw_concept = None
+ return image, has_nsfw_concept
+
+ def decode_latents(self, latents):
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
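+        # e.g. with a typical SD 1.5 UNet (4 latent channels) and `vae_scale_factor=8`, height=width=512
+        # yields a latent shape of (batch_size, 4, 64, 64); the exact numbers depend on the loaded checkpoint.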
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead.
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds)["sample"]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if output_type == "latent":
+ image = latents
+ has_nsfw_concept = None
+ elif output_type == "pil":
+ # 8. Post-processing
+ image = self.decode_latents(latents)
+
+ # 9. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ # 10. Convert to PIL
+ image = self.numpy_to_pil(image)
+ else:
+ # 8. Post-processing
+ image = self.decode_latents(latents)
+
+ # 9. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/examples/community/stable_diffusion_mega.py b/diffusers/examples/community/stable_diffusion_mega.py
new file mode 100644
index 0000000000000000000000000000000000000000..faed00b49d40373eb16fdc9bd83c5f3f627c7710
--- /dev/null
+++ b/diffusers/examples/community/stable_diffusion_mega.py
@@ -0,0 +1,228 @@
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import PIL.Image
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers import (
+ AutoencoderKL,
+ DDIMScheduler,
+ DiffusionPipeline,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+ StableDiffusionImg2ImgPipeline,
+ StableDiffusionInpaintPipelineLegacy,
+ StableDiffusionPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.configuration_utils import FrozenDict
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.utils import deprecate, logging
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class StableDiffusionMegaPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+        safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+                "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ @property
+ def components(self) -> Dict[str, Any]:
+ return {k: getattr(self, k) for k in self.config.keys() if not k.startswith("_")}
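+    # `components` collects everything registered on the pipeline config (vae, text_encoder, tokenizer, unet,
+    # scheduler, safety_checker, feature_extractor, requires_safety_checker), so the task-specific methods
+    # below can rebuild the corresponding pipelines with `Pipeline(**self.components)` while sharing weights.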
+
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+ `attention_head_dim` must be a multiple of `slice_size`.
+ """
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = self.unet.config.attention_head_dim // 2
+ self.unet.set_attention_slice(slice_size)
+
+ def disable_attention_slicing(self):
+ r"""
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+ back to computing attention in one step.
+ """
+ # set slice_size = `None` to disable `attention slicing`
+ self.enable_attention_slicing(None)
+
+ @torch.no_grad()
+ def inpaint(
+ self,
+ prompt: Union[str, List[str]],
+ image: Union[torch.FloatTensor, PIL.Image.Image],
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image],
+ strength: float = 0.8,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: Optional[float] = 0.0,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ ):
+ # This method delegates to StableDiffusionInpaintPipelineLegacy; for more information, please see: https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion
+ return StableDiffusionInpaintPipelineLegacy(**self.components)(
+ prompt=prompt,
+ image=image,
+ mask_image=mask_image,
+ strength=strength,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ callback_steps=callback_steps,
+ )
+
+ @torch.no_grad()
+ def img2img(
+ self,
+ prompt: Union[str, List[str]],
+ image: Union[torch.FloatTensor, PIL.Image.Image],
+ strength: float = 0.8,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: Optional[float] = 0.0,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ **kwargs,
+ ):
+ # For more information on how this function works, please see: https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion#diffusers.StableDiffusionImg2ImgPipeline
+ return StableDiffusionImg2ImgPipeline(**self.components)(
+ prompt=prompt,
+ image=image,
+ strength=strength,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ callback_steps=callback_steps,
+ )
+
+ @torch.no_grad()
+ def text2img(
+ self,
+ prompt: Union[str, List[str]],
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ ):
+ # For more information on how this function works, please see: https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion#diffusers.StableDiffusionPipeline
+ return StableDiffusionPipeline(**self.components)(
+ prompt=prompt,
+ height=height,
+ width=width,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ latents=latents,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ callback_steps=callback_steps,
+ )
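+
+
+ # A minimal usage sketch, not part of the pipeline class itself. It assumes this file is
+ # exposed as the `stable_diffusion_mega` community pipeline and that the
+ # `runwayml/stable-diffusion-v1-5` checkpoint is available; adjust both to your setup.
+ if __name__ == "__main__":
+     import torch
+     from diffusers import DiffusionPipeline
+
+     pipe = DiffusionPipeline.from_pretrained(
+         "runwayml/stable-diffusion-v1-5",
+         custom_pipeline="stable_diffusion_mega",
+         torch_dtype=torch.float16,
+     ).to("cuda")
+     # Trade a little speed for lower memory use (halves the attention head size per slice).
+     pipe.enable_attention_slicing()
+
+     # The three entry points reuse the same registered components.
+     image = pipe.text2img("An astronaut riding a horse").images[0]
+     image = pipe.img2img(prompt="oil painting of a dog", image=image, strength=0.75).images[0]
+     image.save("mega_pipeline_demo.png")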
diff --git a/diffusers/examples/community/stable_diffusion_reference.py b/diffusers/examples/community/stable_diffusion_reference.py
new file mode 100644
index 0000000000000000000000000000000000000000..505470574a0be6d112870f7e1c78f1549d53eaec
--- /dev/null
+++ b/diffusers/examples/community/stable_diffusion_reference.py
@@ -0,0 +1,797 @@
+# Inspired by: https://github.com/Mikubill/sd-webui-controlnet/discussions/1236 and https://github.com/Mikubill/sd-webui-controlnet/discussions/1280
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+
+from diffusers import StableDiffusionPipeline
+from diffusers.models.attention import BasicTransformerBlock
+from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg
+from diffusers.utils import PIL_INTERPOLATION, logging
+from diffusers.utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import UniPCMultistepScheduler
+ >>> from diffusers.utils import load_image
+
+ >>> input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")
+
+ >>> pipe = StableDiffusionReferencePipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+ safety_checker=None,
+ torch_dtype=torch.float16
+ ).to('cuda:0')
+
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+
+ >>> result_img = pipe(ref_image=input_image,
+ prompt="1girl",
+ num_inference_steps=20,
+ reference_attn=True,
+ reference_adain=True).images[0]
+
+ >>> result_img.show()
+ ```
+"""
+
+
+def torch_dfs(model: torch.nn.Module):
+ result = [model]
+ for child in model.children():
+ result += torch_dfs(child)
+ return result
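+
+ # torch_dfs flattens a module tree into a pre-order list: the module itself first, then
+ # all of its descendants. For example, torch_dfs(nn.Sequential(nn.Linear(4, 4), nn.ReLU()))
+ # yields [Sequential, Linear, ReLU]; the pipeline below uses it to collect every
+ # BasicTransformerBlock inside the UNet before patching their forward methods.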
+
+
+class StableDiffusionReferencePipeline(StableDiffusionPipeline):
+ def _default_height_width(self, height, width, image):
+ # NOTE: images in a list may have different sizes, so checking only the first
+ # image is not _exactly_ correct, but it is simple.
+ while isinstance(image, list):
+ image = image[0]
+
+ if height is None:
+ if isinstance(image, PIL.Image.Image):
+ height = image.height
+ elif isinstance(image, torch.Tensor):
+ height = image.shape[2]
+
+ height = (height // 8) * 8 # round down to nearest multiple of 8
+
+ if width is None:
+ if isinstance(image, PIL.Image.Image):
+ width = image.width
+ elif isinstance(image, torch.Tensor):
+ width = image.shape[3]
+
+ width = (width // 8) * 8 # round down to nearest multiple of 8
+
+ return height, width
+
+ def prepare_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ if not isinstance(image, torch.Tensor):
+ if isinstance(image, PIL.Image.Image):
+ image = [image]
+
+ if isinstance(image[0], PIL.Image.Image):
+ images = []
+
+ for image_ in image:
+ image_ = image_.convert("RGB")
+ image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
+ image_ = np.array(image_)
+ image_ = image_[None, :]
+ images.append(image_)
+
+ image = images
+
+ image = np.concatenate(image, axis=0)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = (image - 0.5) / 0.5
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ elif isinstance(image[0], torch.Tensor):
+ image = torch.cat(image, dim=0)
+
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ image = torch.cat([image] * 2)
+
+ return image
+
+ def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):
+ refimage = refimage.to(device=device, dtype=dtype)
+
+ # encode the reference image into latent space so we can concatenate it to the latents
+ if isinstance(generator, list):
+ ref_image_latents = [
+ self.vae.encode(refimage[i : i + 1]).latent_dist.sample(generator=generator[i])
+ for i in range(batch_size)
+ ]
+ ref_image_latents = torch.cat(ref_image_latents, dim=0)
+ else:
+ ref_image_latents = self.vae.encode(refimage).latent_dist.sample(generator=generator)
+ ref_image_latents = self.vae.config.scaling_factor * ref_image_latents
+
+ # duplicate ref_image_latents for each generation per prompt, using mps friendly method
+ if ref_image_latents.shape[0] < batch_size:
+ if not batch_size % ref_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {ref_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ ref_image_latents = ref_image_latents.repeat(batch_size // ref_image_latents.shape[0], 1, 1, 1)
+
+ # aligning device to prevent device errors when concatenating it with the latent model input
+ ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
+ return ref_image_latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ ref_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ attention_auto_machine_weight: float = 1.0,
+ gn_auto_machine_weight: float = 1.0,
+ style_fidelity: float = 0.5,
+ reference_attn: bool = True,
+ reference_adain: bool = True,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ ref_image (`torch.FloatTensor`, `PIL.Image.Image`):
+ The Reference Control input condition. Reference Control uses this input condition to generate guidance for the UNet. If
+ the type is specified as `torch.FloatTensor`, it is passed to Reference Control as is. `PIL.Image.Image` can
+ also be accepted as an image.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text
+ `prompt`, usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
+ attention_auto_machine_weight (`float`):
+ Weight for using the reference query as the self-attention context. If
+ `attention_auto_machine_weight=1.0`, the reference query is used in every self-attention layer.
+ gn_auto_machine_weight (`float`):
+ Weight of the reference AdaIN. If `gn_auto_machine_weight=2.0`, all reference AdaIN plugins are used.
+ style_fidelity (`float`):
+ Style fidelity of `ref_uncond_xt`. If `style_fidelity=1.0`, the reference control is prioritized; if
+ `style_fidelity=0.0`, the prompt is prioritized; intermediate values balance the two.
+ reference_attn (`bool`):
+ Whether to use reference query for self attention's context.
+ reference_adain (`bool`):
+ Whether to use reference adain.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ assert reference_attn or reference_adain, "`reference_attn` or `reference_adain` must be True."
+
+ # 0. Default height and width to unet
+ height, width = self._default_height_width(height, width, ref_image)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Preprocess reference image
+ ref_image = self.prepare_image(
+ image=ref_image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=prompt_embeds.dtype,
+ )
+
+ # 5. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 7. Prepare reference latent variables
+ ref_image_latents = self.prepare_ref_latents(
+ ref_image,
+ batch_size * num_images_per_prompt,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ do_classifier_free_guidance,
+ )
+
+ # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 9. Modify self attention and group norm
+ MODE = "write"
+ uc_mask = (
+ torch.Tensor([1] * batch_size * num_images_per_prompt + [0] * batch_size * num_images_per_prompt)
+ .type_as(ref_image_latents)
+ .bool()
+ )
+
+ def hacked_basic_transformer_inner_forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ ):
+ if self.use_ada_layer_norm:
+ norm_hidden_states = self.norm1(hidden_states, timestep)
+ elif self.use_ada_layer_norm_zero:
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ else:
+ norm_hidden_states = self.norm1(hidden_states)
+
+ # 1. Self-Attention
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+ if self.only_cross_attention:
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ else:
+ if MODE == "write":
+ self.bank.append(norm_hidden_states.detach().clone())
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ if MODE == "read":
+ if attention_auto_machine_weight > self.attn_weight:
+ attn_output_uc = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=torch.cat([norm_hidden_states] + self.bank, dim=1),
+ # attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ attn_output_c = attn_output_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ attn_output_c[uc_mask] = self.attn1(
+ norm_hidden_states[uc_mask],
+ encoder_hidden_states=norm_hidden_states[uc_mask],
+ **cross_attention_kwargs,
+ )
+ attn_output = style_fidelity * attn_output_c + (1.0 - style_fidelity) * attn_output_uc
+ self.bank.clear()
+ else:
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ if self.use_ada_layer_norm_zero:
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ hidden_states = attn_output + hidden_states
+
+ if self.attn2 is not None:
+ norm_hidden_states = (
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+ )
+
+ # 2. Cross-Attention
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ **cross_attention_kwargs,
+ )
+ hidden_states = attn_output + hidden_states
+
+ # 3. Feed-forward
+ norm_hidden_states = self.norm3(hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ ff_output = self.ff(norm_hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+ hidden_states = ff_output + hidden_states
+
+ return hidden_states
+
+ def hacked_mid_forward(self, *args, **kwargs):
+ eps = 1e-6
+ x = self.original_forward(*args, **kwargs)
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append(mean)
+ self.var_bank.append(var)
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank) / float(len(self.mean_bank))
+ var_acc = sum(self.var_bank) / float(len(self.var_bank))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ x_uc = (((x - mean) / std) * std_acc) + mean_acc
+ x_c = x_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ x_c[uc_mask] = x[uc_mask]
+ x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc
+ self.mean_bank = []
+ self.var_bank = []
+ return x
+
+ def hack_CrossAttnDownBlock2D_forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ):
+ eps = 1e-6
+
+ # TODO(Patrick, William) - attention mask is not used
+ output_states = ()
+
+ for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ output_states = output_states + (hidden_states,)
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+ def hacked_DownBlock2D_forward(self, hidden_states, temb=None):
+ eps = 1e-6
+
+ output_states = ()
+
+ for i, resnet in enumerate(self.resnets):
+ hidden_states = resnet(hidden_states, temb)
+
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ output_states = output_states + (hidden_states,)
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+ def hacked_CrossAttnUpBlock2D_forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ upsample_size: Optional[int] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ):
+ eps = 1e-6
+ # TODO(Patrick, William) - attention mask is not used
+ for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+ def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
+ eps = 1e-6
+ for i, resnet in enumerate(self.resnets):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+ hidden_states = resnet(hidden_states, temb)
+
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+ if reference_attn:
+ attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock)]
+ attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
+
+ for i, module in enumerate(attn_modules):
+ module._original_inner_forward = module.forward
+ module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
+ module.bank = []
+ module.attn_weight = float(i) / float(len(attn_modules))
+
+ if reference_adain:
+ gn_modules = [self.unet.mid_block]
+ self.unet.mid_block.gn_weight = 0
+
+ down_blocks = self.unet.down_blocks
+ for w, module in enumerate(down_blocks):
+ module.gn_weight = 1.0 - float(w) / float(len(down_blocks))
+ gn_modules.append(module)
+
+ up_blocks = self.unet.up_blocks
+ for w, module in enumerate(up_blocks):
+ module.gn_weight = float(w) / float(len(up_blocks))
+ gn_modules.append(module)
+
+ for i, module in enumerate(gn_modules):
+ if getattr(module, "original_forward", None) is None:
+ module.original_forward = module.forward
+ if i == 0:
+ # mid_block
+ module.forward = hacked_mid_forward.__get__(module, torch.nn.Module)
+ elif isinstance(module, CrossAttnDownBlock2D):
+ module.forward = hack_CrossAttnDownBlock2D_forward.__get__(module, CrossAttnDownBlock2D)
+ elif isinstance(module, DownBlock2D):
+ module.forward = hacked_DownBlock2D_forward.__get__(module, DownBlock2D)
+ elif isinstance(module, CrossAttnUpBlock2D):
+ module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D)
+ elif isinstance(module, UpBlock2D):
+ module.forward = hacked_UpBlock2D_forward.__get__(module, UpBlock2D)
+ module.mean_bank = []
+ module.var_bank = []
+ module.gn_weight *= 2
+
+ # 10. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # ref only part
+ noise = randn_tensor(
+ ref_image_latents.shape, generator=generator, device=device, dtype=ref_image_latents.dtype
+ )
+ ref_xt = self.scheduler.add_noise(
+ ref_image_latents,
+ noise,
+ t.reshape(
+ 1,
+ ),
+ )
+ ref_xt = torch.cat([ref_xt] * 2) if do_classifier_free_guidance else ref_xt
+ ref_xt = self.scheduler.scale_model_input(ref_xt, t)
+
+ MODE = "write"
+ self.unet(
+ ref_xt,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )
+
+ # predict the noise residual
+ MODE = "read"
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
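+
+
+ # The "write"/"read" hooks above implement an AdaIN-style statistic transfer: during the
+ # reference ("write") pass each block stores the per-channel mean and variance of its
+ # features, and during the denoising ("read") pass the current features are renormalized
+ # to those reference statistics. A minimal standalone sketch of that single step follows;
+ # the helper name is illustrative and is not used by the pipeline itself.
+ def _adain_transfer_sketch(x: torch.Tensor, ref: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
+     # Per-channel statistics over the spatial dimensions, matching torch.var_mean(..., correction=0) above.
+     var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
+     ref_var, ref_mean = torch.var_mean(ref, dim=(2, 3), keepdim=True, correction=0)
+     std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+     ref_std = torch.maximum(ref_var, torch.zeros_like(ref_var) + eps) ** 0.5
+     # Normalize the current features, then re-scale and re-shift them with the reference statistics.
+     return ((x - mean) / std) * ref_std + ref_mean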
diff --git a/diffusers/examples/community/stable_diffusion_repaint.py b/diffusers/examples/community/stable_diffusion_repaint.py
new file mode 100644
index 0000000000000000000000000000000000000000..4da46b3708159f5f389660d92056b8a6fe605825
--- /dev/null
+++ b/diffusers/examples/community/stable_diffusion_repaint.py
@@ -0,0 +1,958 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from packaging import version
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers import AutoencoderKL, DiffusionPipeline, UNet2DConditionModel
+from diffusers.configuration_utils import FrozenDict, deprecate
+from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import (
+ StableDiffusionSafetyChecker,
+)
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ is_accelerate_available,
+ is_accelerate_version,
+ logging,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def prepare_mask_and_masked_image(image, mask):
+ """
+ Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
+ converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
+ ``image`` and ``1`` for the ``mask``.
+ The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
+ binarized (``mask >= 0.5``) and cast to ``torch.float32`` too.
+ Args:
+ image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
+ It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
+ ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
+ mask (Union[np.array, PIL.Image, torch.Tensor]): The mask to apply to the image, i.e. regions to inpaint.
+ It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
+ ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
+ Raises:
+ ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range.
+ ValueError: ``torch.Tensor`` mask should be in the ``[0, 1]`` range.
+ ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
+ TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not (or the other way around).
+ Returns:
+ tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
+ dimensions: ``batch x channels x height x width``.
+ """
+ if isinstance(image, torch.Tensor):
+ if not isinstance(mask, torch.Tensor):
+ raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)}) is not")
+
+ # Batch single image
+ if image.ndim == 3:
+ assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
+ image = image.unsqueeze(0)
+
+ # Batch and add channel dim for single mask
+ if mask.ndim == 2:
+ mask = mask.unsqueeze(0).unsqueeze(0)
+
+ # Batch single mask or add channel dim
+ if mask.ndim == 3:
+ # Single batched mask, no channel dim or single mask not batched but channel dim
+ if mask.shape[0] == 1:
+ mask = mask.unsqueeze(0)
+
+ # Batched masks no channel dim
+ else:
+ mask = mask.unsqueeze(1)
+
+ assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
+ assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
+ assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
+
+ # Check image is in [-1, 1]
+ if image.min() < -1 or image.max() > 1:
+ raise ValueError("Image should be in [-1, 1] range")
+
+ # Check mask is in [0, 1]
+ if mask.min() < 0 or mask.max() > 1:
+ raise ValueError("Mask should be in [0, 1] range")
+
+ # Binarize mask
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+
+ # Image as float32
+ image = image.to(dtype=torch.float32)
+ elif isinstance(mask, torch.Tensor):
+ raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)}) is not")
+ else:
+ # preprocess image
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
+ image = [image]
+
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+ image = np.concatenate([i[None, :] for i in image], axis=0)
+
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ # preprocess mask
+ if isinstance(mask, (PIL.Image.Image, np.ndarray)):
+ mask = [mask]
+
+ if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
+ mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
+ mask = mask.astype(np.float32) / 255.0
+ elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
+ mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
+
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+ mask = torch.from_numpy(mask)
+
+ # masked_image = image * (mask >= 0.5)
+ masked_image = image
+
+ return mask, masked_image
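+
+
+ # Shape sketch: a 512x512 RGB PIL image with an L-mode PIL mask comes back as a float32
+ # mask of shape (1, 1, 512, 512) with values in {0, 1} and a float32 masked_image of
+ # shape (1, 3, 512, 512) in [-1, 1]. Unlike the standard inpainting preprocessing, this
+ # RePaint variant keeps the full image (the masking line is commented out above), since
+ # the RePaint sampling loop re-injects the noised known region at every denoising step.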
+
+
+class StableDiffusionRepaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+ r"""
+ Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+ In addition the pipeline inherits the following loading methods:
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+ - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+ as well as the following saving methods:
+ - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration"
+ " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
+ " sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to"
+ " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face"
+ " Hub, it would be very nice if you could open a Pull request for the"
+ " `scheduler/scheduler_config.json` file"
+ )
+ deprecate(
+ "skip_prk_steps not set",
+ "1.0.0",
+ deprecation_message,
+ standard_warn=False,
+ )
+ new_config = dict(scheduler.config)
+ new_config["skip_prk_steps"] = True
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+ # Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4
+ if unet.config.in_channels != 4:
+ logger.warning(
+ f"You have loaded a UNet with {unet.config.in_channels} input channels, whereas by default,"
+ f" {self.__class__} assumes that `pipeline.unet` has 4 input channels (for `num_channels_latents`)."
+ " If you did not intend to modify this behavior, please check whether you have loaded the right"
+ " checkpoint."
+ )
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
+ def enable_sequential_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+ `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
+ Note that offloading happens on a submodule basis. Memory savings are higher than with
+ `enable_model_cpu_offload`, but performance is lower.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
+ from accelerate import cpu_offload
+ else:
+ raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
+ cpu_offload(cpu_offloaded_model, device)
+
+ if self.safety_checker is not None:
+ cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ if self.safety_checker is not None:
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
+ def _execution_device(self):
+ r"""
+ Returns the device on which the pipeline's models will be executed. After calling
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+ hooks.
+ """
+ if not hasattr(self.unet, "_hf_hook"):
+ return self.device
+ for module in self.unet.modules():
+ if (
+ hasattr(module, "_hf_hook")
+ and hasattr(module._hf_hook, "execution_device")
+ and module._hf_hook.execution_device is not None
+ ):
+ return torch.device(module._hf_hook.execution_device)
+ return self.device
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ """
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+            # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ else:
+ has_nsfw_concept = None
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def prepare_mask_latents(
+ self,
+ mask,
+ masked_image,
+ batch_size,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ do_classifier_free_guidance,
+ ):
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ mask = torch.nn.functional.interpolate(
+ mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
+ )
+ mask = mask.to(device=device, dtype=dtype)
+
+ masked_image = masked_image.to(device=device, dtype=dtype)
+
+ # encode the mask image into latents space so we can concatenate it to the latents
+ if isinstance(generator, list):
+ masked_image_latents = [
+ self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i])
+ for i in range(batch_size)
+ ]
+ masked_image_latents = torch.cat(masked_image_latents, dim=0)
+ else:
+ masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
+ masked_image_latents = self.vae.config.scaling_factor * masked_image_latents
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ if mask.shape[0] < batch_size:
+ if not batch_size % mask.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
+
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+ masked_image_latents = (
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+ )
+
+ # aligning device to prevent device errors when concating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+ return mask, masked_image_latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ jump_length: Optional[int] = 10,
+ jump_n_sample: Optional[int] = 10,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass
+                `prompt_embeds` instead.
+ image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
+ be masked out with `mask_image` and repainted according to `prompt`.
+ mask_image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+ repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
+ to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
+ instead of 3, so the expected shape would be `(B, H, W, 1)`.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ jump_length (`int`, *optional*, defaults to 10):
+                The number of steps taken forward in time before going backward in time for a single jump ("j" in
+                the RePaint paper). Take a look at Figures 9 and 10 in https://arxiv.org/pdf/2201.09865.pdf.
+ jump_n_sample (`int`, *optional*, defaults to 10):
+                The number of times we will make a forward time jump for a given chosen time sample. Take a look
+                at Figures 9 and 10 in https://arxiv.org/pdf/2201.09865.pdf.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages generating images that are closely linked to the text
+                `prompt`, usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if
+                `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ Examples:
+ ```py
+ >>> import PIL
+ >>> import requests
+ >>> import torch
+ >>> from io import BytesIO
+        >>> from diffusers import DiffusionPipeline, RePaintScheduler
+ >>> def download_image(url):
+ ... response = requests.get(url)
+ ... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
+ >>> base_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/"
+ >>> img_url = base_url + "overture-creations-5sI6fQgYIuo.png"
+        >>> mask_url = base_url + "overture-creations-5sI6fQgYIuo_mask.png"
+ >>> init_image = download_image(img_url).resize((512, 512))
+ >>> mask_image = download_image(mask_url).resize((512, 512))
+ >>> pipe = DiffusionPipeline.from_pretrained(
+ ... "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16, custom_pipeline="stable_diffusion_repaint",
+ ... )
+ >>> pipe.scheduler = RePaintScheduler.from_config(pipe.scheduler.config)
+ >>> pipe = pipe.to("cuda")
+ >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
+ >>> image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
+ ```
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ )
+
+ if image is None:
+ raise ValueError("`image` input cannot be undefined.")
+
+ if mask_image is None:
+ raise ValueError("`mask_image` input cannot be undefined.")
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+
+ # 4. Preprocess mask and image
+ mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
+
+ # 5. set timesteps
+ self.scheduler.set_timesteps(num_inference_steps, jump_length, jump_n_sample, device)
+ self.scheduler.eta = eta
+
+ timesteps = self.scheduler.timesteps
+ # latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.vae.config.latent_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 7. Prepare mask latent variables
+ mask, masked_image_latents = self.prepare_mask_latents(
+ mask,
+ masked_image,
+ batch_size * num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ do_classifier_free_guidance=False, # We do not need duplicate mask and image
+ )
+
+ # 8. Check that sizes of mask, masked image and latents match
+ # num_channels_mask = mask.shape[1]
+ # num_channels_masked_image = masked_image_latents.shape[1]
+ if num_channels_latents != self.unet.config.in_channels:
+ raise ValueError(
+                f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+                f" {self.unet.config.in_channels} input channels but received `num_channels_latents`:"
+                f" {num_channels_latents}. Please verify the config of `pipeline.unet` or your `mask_image`"
+                " or `image` input."
+ )
+
+ # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ t_last = timesteps[0] + 1
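+        # `t_last` tracks the previously processed timestep: whenever the RePaint schedule jumps
+        # forward in time (t >= t_last), the latents are re-noised via `scheduler.undo_step`
+        # instead of being denoised.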
+
+ # 10. Denoising loop
+ with self.progress_bar(total=len(timesteps)) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if t >= t_last:
+ # compute the reverse: x_t-1 -> x_t
+ latents = self.scheduler.undo_step(latents, t_last, generator)
+ progress_bar.update()
+ t_last = t
+ continue
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ # concat latents, mask, masked_image_latents in the channel dimension
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+ # latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(
+ noise_pred,
+ t,
+ latents,
+ masked_image_latents,
+ mask,
+ **extra_step_kwargs,
+ ).prev_sample
+
+ # call the callback, if provided
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ t_last = t
+
+ # 11. Post-processing
+ image = self.decode_latents(latents)
+
+ # 12. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ # 13. Convert to PIL
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/examples/community/stable_diffusion_tensorrt_img2img.py b/diffusers/examples/community/stable_diffusion_tensorrt_img2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..041cf3a12dbdd847ab28f7942f9b997de47cf9b5
--- /dev/null
+++ b/diffusers/examples/community/stable_diffusion_tensorrt_img2img.py
@@ -0,0 +1,1055 @@
+#
+# Copyright 2023 The HuggingFace Inc. team.
+# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import os
+from collections import OrderedDict
+from copy import copy
+from typing import List, Optional, Union
+
+import numpy as np
+import onnx
+import onnx_graphsurgeon as gs
+import PIL.Image
+import tensorrt as trt
+import torch
+from huggingface_hub import snapshot_download
+from onnx import shape_inference
+from polygraphy import cuda
+from polygraphy.backend.common import bytes_from_path
+from polygraphy.backend.onnx.loader import fold_constants
+from polygraphy.backend.trt import (
+ CreateConfig,
+ Profile,
+ engine_from_bytes,
+ engine_from_network,
+ network_from_onnx_path,
+ save_engine,
+)
+from polygraphy.backend.trt import util as trt_util
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.stable_diffusion import (
+ StableDiffusionImg2ImgPipeline,
+ StableDiffusionPipelineOutput,
+ StableDiffusionSafetyChecker,
+)
+from diffusers.schedulers import DDIMScheduler
+from diffusers.utils import DIFFUSERS_CACHE, logging
+
+
+"""
+Installation instructions
+python3 -m pip install --upgrade transformers diffusers>=0.16.0
+python3 -m pip install --upgrade tensorrt>=8.6.1
+python3 -m pip install --upgrade polygraphy>=0.47.0 onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com
+python3 -m pip install onnxruntime
+"""
+
+TRT_LOGGER = trt.Logger(trt.Logger.ERROR)
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+# Map of numpy dtype -> torch dtype
+numpy_to_torch_dtype_dict = {
+ np.uint8: torch.uint8,
+ np.int8: torch.int8,
+ np.int16: torch.int16,
+ np.int32: torch.int32,
+ np.int64: torch.int64,
+ np.float16: torch.float16,
+ np.float32: torch.float32,
+ np.float64: torch.float64,
+ np.complex64: torch.complex64,
+ np.complex128: torch.complex128,
+}
+if np.version.full_version >= "1.24.0":
+ numpy_to_torch_dtype_dict[np.bool_] = torch.bool
+else:
+ numpy_to_torch_dtype_dict[np.bool] = torch.bool
+
+# Map of torch dtype -> numpy dtype
+torch_to_numpy_dtype_dict = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()}
+
+
+def device_view(t):
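+    # Zero-copy wrapper: exposes the tensor's GPU data pointer, shape and dtype as a
+    # Polygraphy DeviceView so it can be bound directly to the TensorRT execution context.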
+ return cuda.DeviceView(ptr=t.data_ptr(), shape=t.shape, dtype=torch_to_numpy_dtype_dict[t.dtype])
+
+
+def preprocess_image(image):
+ """
+    image: PIL.Image.Image, resized to a multiple of 32 and converted to a (1, 3, H, W) float tensor in [-1, 1].
+ """
+ w, h = image.size
+ w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32
+ image = image.resize((w, h))
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).contiguous()
+ return 2.0 * image - 1.0
+
+
+class Engine:
+ def __init__(self, engine_path):
+ self.engine_path = engine_path
+ self.engine = None
+ self.context = None
+ self.buffers = OrderedDict()
+ self.tensors = OrderedDict()
+
+ def __del__(self):
+ [buf.free() for buf in self.buffers.values() if isinstance(buf, cuda.DeviceArray)]
+ del self.engine
+ del self.context
+ del self.buffers
+ del self.tensors
+
+ def build(
+ self,
+ onnx_path,
+ fp16,
+ input_profile=None,
+ enable_preview=False,
+ enable_all_tactics=False,
+ timing_cache=None,
+ workspace_size=0,
+ ):
+ logger.warning(f"Building TensorRT engine for {onnx_path}: {self.engine_path}")
+ p = Profile()
+ if input_profile:
+ for name, dims in input_profile.items():
+ assert len(dims) == 3
+ p.add(name, min=dims[0], opt=dims[1], max=dims[2])
+
+ config_kwargs = {}
+
+ config_kwargs["preview_features"] = [trt.PreviewFeature.DISABLE_EXTERNAL_TACTIC_SOURCES_FOR_CORE_0805]
+ if enable_preview:
+ # Faster dynamic shapes made optional since it increases engine build time.
+ config_kwargs["preview_features"].append(trt.PreviewFeature.FASTER_DYNAMIC_SHAPES_0805)
+ if workspace_size > 0:
+ config_kwargs["memory_pool_limits"] = {trt.MemoryPoolType.WORKSPACE: workspace_size}
+ if not enable_all_tactics:
+ config_kwargs["tactic_sources"] = []
+
+ engine = engine_from_network(
+ network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]),
+ config=CreateConfig(fp16=fp16, profiles=[p], load_timing_cache=timing_cache, **config_kwargs),
+ save_timing_cache=timing_cache,
+ )
+ save_engine(engine, path=self.engine_path)
+
+ def load(self):
+ logger.warning(f"Loading TensorRT engine: {self.engine_path}")
+ self.engine = engine_from_bytes(bytes_from_path(self.engine_path))
+
+ def activate(self):
+ self.context = self.engine.create_execution_context()
+
+ def allocate_buffers(self, shape_dict=None, device="cuda"):
+ for idx in range(trt_util.get_bindings_per_profile(self.engine)):
+ binding = self.engine[idx]
+ if shape_dict and binding in shape_dict:
+ shape = shape_dict[binding]
+ else:
+ shape = self.engine.get_binding_shape(binding)
+ dtype = trt.nptype(self.engine.get_binding_dtype(binding))
+ if self.engine.binding_is_input(binding):
+ self.context.set_binding_shape(idx, shape)
+ tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device)
+ self.tensors[binding] = tensor
+ self.buffers[binding] = cuda.DeviceView(ptr=tensor.data_ptr(), shape=shape, dtype=dtype)
+
+ def infer(self, feed_dict, stream):
+ start_binding, end_binding = trt_util.get_active_profile_bindings(self.context)
+ # shallow copy of ordered dict
+ device_buffers = copy(self.buffers)
+ for name, buf in feed_dict.items():
+ assert isinstance(buf, cuda.DeviceView)
+ device_buffers[name] = buf
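+        # execute_async_v2 expects one pointer slot per engine binding; slots before the active
+        # profile's first binding are filled with 0 (null) placeholders.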
+ bindings = [0] * start_binding + [buf.ptr for buf in device_buffers.values()]
+ noerror = self.context.execute_async_v2(bindings=bindings, stream_handle=stream.ptr)
+ if not noerror:
+ raise ValueError("ERROR: inference failed.")
+
+ return self.tensors
+
+
+class Optimizer:
+ def __init__(self, onnx_graph):
+ self.graph = gs.import_onnx(onnx_graph)
+
+ def cleanup(self, return_onnx=False):
+ self.graph.cleanup().toposort()
+ if return_onnx:
+ return gs.export_onnx(self.graph)
+
+ def select_outputs(self, keep, names=None):
+ self.graph.outputs = [self.graph.outputs[o] for o in keep]
+ if names:
+ for i, name in enumerate(names):
+ self.graph.outputs[i].name = name
+
+ def fold_constants(self, return_onnx=False):
+ onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True)
+ self.graph = gs.import_onnx(onnx_graph)
+ if return_onnx:
+ return onnx_graph
+
+ def infer_shapes(self, return_onnx=False):
+ onnx_graph = gs.export_onnx(self.graph)
+ if onnx_graph.ByteSize() > 2147483648:
+ raise TypeError("ERROR: model size exceeds supported 2GB limit")
+ else:
+ onnx_graph = shape_inference.infer_shapes(onnx_graph)
+
+ self.graph = gs.import_onnx(onnx_graph)
+ if return_onnx:
+ return onnx_graph
+
+
+class BaseModel:
+ def __init__(self, model, fp16=False, device="cuda", max_batch_size=16, embedding_dim=768, text_maxlen=77):
+ self.model = model
+ self.name = "SD Model"
+ self.fp16 = fp16
+ self.device = device
+
+ self.min_batch = 1
+ self.max_batch = max_batch_size
+ self.min_image_shape = 256 # min image resolution: 256x256
+ self.max_image_shape = 1024 # max image resolution: 1024x1024
+ self.min_latent_shape = self.min_image_shape // 8
+ self.max_latent_shape = self.max_image_shape // 8
+
+ self.embedding_dim = embedding_dim
+ self.text_maxlen = text_maxlen
+
+ def get_model(self):
+ return self.model
+
+ def get_input_names(self):
+ pass
+
+ def get_output_names(self):
+ pass
+
+ def get_dynamic_axes(self):
+ return None
+
+ def get_sample_input(self, batch_size, image_height, image_width):
+ pass
+
+ def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
+ return None
+
+ def get_shape_dict(self, batch_size, image_height, image_width):
+ return None
+
+ def optimize(self, onnx_graph):
+ opt = Optimizer(onnx_graph)
+ opt.cleanup()
+ opt.fold_constants()
+ opt.infer_shapes()
+ onnx_opt_graph = opt.cleanup(return_onnx=True)
+ return onnx_opt_graph
+
+ def check_dims(self, batch_size, image_height, image_width):
+ assert batch_size >= self.min_batch and batch_size <= self.max_batch
+        assert image_height % 8 == 0 and image_width % 8 == 0
+ latent_height = image_height // 8
+ latent_width = image_width // 8
+ assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape
+ assert latent_width >= self.min_latent_shape and latent_width <= self.max_latent_shape
+ return (latent_height, latent_width)
+
+ def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_shape):
+ min_batch = batch_size if static_batch else self.min_batch
+ max_batch = batch_size if static_batch else self.max_batch
+ latent_height = image_height // 8
+ latent_width = image_width // 8
+ min_image_height = image_height if static_shape else self.min_image_shape
+ max_image_height = image_height if static_shape else self.max_image_shape
+ min_image_width = image_width if static_shape else self.min_image_shape
+ max_image_width = image_width if static_shape else self.max_image_shape
+ min_latent_height = latent_height if static_shape else self.min_latent_shape
+ max_latent_height = latent_height if static_shape else self.max_latent_shape
+ min_latent_width = latent_width if static_shape else self.min_latent_shape
+ max_latent_width = latent_width if static_shape else self.max_latent_shape
+ return (
+ min_batch,
+ max_batch,
+ min_image_height,
+ max_image_height,
+ min_image_width,
+ max_image_width,
+ min_latent_height,
+ max_latent_height,
+ min_latent_width,
+ max_latent_width,
+ )
+
+
+def getOnnxPath(model_name, onnx_dir, opt=True):
+ return os.path.join(onnx_dir, model_name + (".opt" if opt else "") + ".onnx")
+
+
+def getEnginePath(model_name, engine_dir):
+ return os.path.join(engine_dir, model_name + ".plan")
+
+
+def build_engines(
+ models: dict,
+ engine_dir,
+ onnx_dir,
+ onnx_opset,
+ opt_image_height,
+ opt_image_width,
+ opt_batch_size=1,
+ force_engine_rebuild=False,
+ static_batch=False,
+ static_shape=True,
+ enable_preview=False,
+ enable_all_tactics=False,
+ timing_cache=None,
+ max_workspace_size=0,
+):
+ built_engines = {}
+ if not os.path.isdir(onnx_dir):
+ os.makedirs(onnx_dir)
+ if not os.path.isdir(engine_dir):
+ os.makedirs(engine_dir)
+
+ # Export models to ONNX
+ for model_name, model_obj in models.items():
+ engine_path = getEnginePath(model_name, engine_dir)
+ if force_engine_rebuild or not os.path.exists(engine_path):
+ logger.warning("Building Engines...")
+ logger.warning("Engine build can take a while to complete")
+ onnx_path = getOnnxPath(model_name, onnx_dir, opt=False)
+ onnx_opt_path = getOnnxPath(model_name, onnx_dir)
+ if force_engine_rebuild or not os.path.exists(onnx_opt_path):
+ if force_engine_rebuild or not os.path.exists(onnx_path):
+ logger.warning(f"Exporting model: {onnx_path}")
+ model = model_obj.get_model()
+ with torch.inference_mode(), torch.autocast("cuda"):
+ inputs = model_obj.get_sample_input(opt_batch_size, opt_image_height, opt_image_width)
+ torch.onnx.export(
+ model,
+ inputs,
+ onnx_path,
+ export_params=True,
+ opset_version=onnx_opset,
+ do_constant_folding=True,
+ input_names=model_obj.get_input_names(),
+ output_names=model_obj.get_output_names(),
+ dynamic_axes=model_obj.get_dynamic_axes(),
+ )
+ del model
+ torch.cuda.empty_cache()
+ gc.collect()
+ else:
+ logger.warning(f"Found cached model: {onnx_path}")
+
+ # Optimize onnx
+ if force_engine_rebuild or not os.path.exists(onnx_opt_path):
+                logger.warning(f"Generating optimized model: {onnx_opt_path}")
+ onnx_opt_graph = model_obj.optimize(onnx.load(onnx_path))
+ onnx.save(onnx_opt_graph, onnx_opt_path)
+ else:
+ logger.warning(f"Found cached optimized model: {onnx_opt_path} ")
+
+ # Build TensorRT engines
+ for model_name, model_obj in models.items():
+ engine_path = getEnginePath(model_name, engine_dir)
+ engine = Engine(engine_path)
+ onnx_path = getOnnxPath(model_name, onnx_dir, opt=False)
+ onnx_opt_path = getOnnxPath(model_name, onnx_dir)
+
+ if force_engine_rebuild or not os.path.exists(engine.engine_path):
+ engine.build(
+ onnx_opt_path,
+ fp16=True,
+ input_profile=model_obj.get_input_profile(
+ opt_batch_size,
+ opt_image_height,
+ opt_image_width,
+ static_batch=static_batch,
+ static_shape=static_shape,
+ ),
+ enable_preview=enable_preview,
+ timing_cache=timing_cache,
+ workspace_size=max_workspace_size,
+ )
+ built_engines[model_name] = engine
+
+ # Load and activate TensorRT engines
+ for model_name, model_obj in models.items():
+ engine = built_engines[model_name]
+ engine.load()
+ engine.activate()
+
+ return built_engines
+
+
+def runEngine(engine, feed_dict, stream):
+ return engine.infer(feed_dict, stream)
+
+
+class CLIP(BaseModel):
+ def __init__(self, model, device, max_batch_size, embedding_dim):
+ super(CLIP, self).__init__(
+ model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim
+ )
+ self.name = "CLIP"
+
+ def get_input_names(self):
+ return ["input_ids"]
+
+ def get_output_names(self):
+ return ["text_embeddings", "pooler_output"]
+
+ def get_dynamic_axes(self):
+ return {"input_ids": {0: "B"}, "text_embeddings": {0: "B"}}
+
+ def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
+ self.check_dims(batch_size, image_height, image_width)
+ min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims(
+ batch_size, image_height, image_width, static_batch, static_shape
+ )
+ return {
+ "input_ids": [(min_batch, self.text_maxlen), (batch_size, self.text_maxlen), (max_batch, self.text_maxlen)]
+ }
+
+ def get_shape_dict(self, batch_size, image_height, image_width):
+ self.check_dims(batch_size, image_height, image_width)
+ return {
+ "input_ids": (batch_size, self.text_maxlen),
+ "text_embeddings": (batch_size, self.text_maxlen, self.embedding_dim),
+ }
+
+ def get_sample_input(self, batch_size, image_height, image_width):
+ self.check_dims(batch_size, image_height, image_width)
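+        # Token ids are int32 here (and are cast to int32 at inference time in __encode_prompt),
+        # which is what the exported ONNX/TensorRT CLIP engine expects.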
+ return torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device)
+
+ def optimize(self, onnx_graph):
+ opt = Optimizer(onnx_graph)
+ opt.select_outputs([0]) # delete graph output#1
+ opt.cleanup()
+ opt.fold_constants()
+ opt.infer_shapes()
+ opt.select_outputs([0], names=["text_embeddings"]) # rename network output
+ opt_onnx_graph = opt.cleanup(return_onnx=True)
+ return opt_onnx_graph
+
+
+def make_CLIP(model, device, max_batch_size, embedding_dim, inpaint=False):
+ return CLIP(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim)
+
+
+class UNet(BaseModel):
+ def __init__(
+ self, model, fp16=False, device="cuda", max_batch_size=16, embedding_dim=768, text_maxlen=77, unet_dim=4
+ ):
+ super(UNet, self).__init__(
+ model=model,
+ fp16=fp16,
+ device=device,
+ max_batch_size=max_batch_size,
+ embedding_dim=embedding_dim,
+ text_maxlen=text_maxlen,
+ )
+ self.unet_dim = unet_dim
+ self.name = "UNet"
+
+ def get_input_names(self):
+ return ["sample", "timestep", "encoder_hidden_states"]
+
+ def get_output_names(self):
+ return ["latent"]
+
+ def get_dynamic_axes(self):
+ return {
+ "sample": {0: "2B", 2: "H", 3: "W"},
+ "encoder_hidden_states": {0: "2B"},
+ "latent": {0: "2B", 2: "H", 3: "W"},
+ }
+
+ def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ (
+ min_batch,
+ max_batch,
+ _,
+ _,
+ _,
+ _,
+ min_latent_height,
+ max_latent_height,
+ min_latent_width,
+ max_latent_width,
+ ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
+ return {
+ "sample": [
+ (2 * min_batch, self.unet_dim, min_latent_height, min_latent_width),
+ (2 * batch_size, self.unet_dim, latent_height, latent_width),
+ (2 * max_batch, self.unet_dim, max_latent_height, max_latent_width),
+ ],
+ "encoder_hidden_states": [
+ (2 * min_batch, self.text_maxlen, self.embedding_dim),
+ (2 * batch_size, self.text_maxlen, self.embedding_dim),
+ (2 * max_batch, self.text_maxlen, self.embedding_dim),
+ ],
+ }
+
+ def get_shape_dict(self, batch_size, image_height, image_width):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ return {
+ "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width),
+ "encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim),
+ "latent": (2 * batch_size, 4, latent_height, latent_width),
+ }
+
+ def get_sample_input(self, batch_size, image_height, image_width):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ dtype = torch.float16 if self.fp16 else torch.float32
+ return (
+ torch.randn(
+ 2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device
+ ),
+ torch.tensor([1.0], dtype=torch.float32, device=self.device),
+ torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device),
+ )
+
+
+def make_UNet(model, device, max_batch_size, embedding_dim, inpaint=False):
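+    # Inpainting UNets take 9 input channels (4 latent + 1 mask + 4 masked-image latents);
+    # standard text-to-image / img2img UNets take 4.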
+ return UNet(
+ model,
+ fp16=True,
+ device=device,
+ max_batch_size=max_batch_size,
+ embedding_dim=embedding_dim,
+ unet_dim=(9 if inpaint else 4),
+ )
+
+
+class VAE(BaseModel):
+ def __init__(self, model, device, max_batch_size, embedding_dim):
+ super(VAE, self).__init__(
+ model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim
+ )
+ self.name = "VAE decoder"
+
+ def get_input_names(self):
+ return ["latent"]
+
+ def get_output_names(self):
+ return ["images"]
+
+ def get_dynamic_axes(self):
+ return {"latent": {0: "B", 2: "H", 3: "W"}, "images": {0: "B", 2: "8H", 3: "8W"}}
+
+ def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ (
+ min_batch,
+ max_batch,
+ _,
+ _,
+ _,
+ _,
+ min_latent_height,
+ max_latent_height,
+ min_latent_width,
+ max_latent_width,
+ ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
+ return {
+ "latent": [
+ (min_batch, 4, min_latent_height, min_latent_width),
+ (batch_size, 4, latent_height, latent_width),
+ (max_batch, 4, max_latent_height, max_latent_width),
+ ]
+ }
+
+ def get_shape_dict(self, batch_size, image_height, image_width):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ return {
+ "latent": (batch_size, 4, latent_height, latent_width),
+ "images": (batch_size, 3, image_height, image_width),
+ }
+
+ def get_sample_input(self, batch_size, image_height, image_width):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ return torch.randn(batch_size, 4, latent_height, latent_width, dtype=torch.float32, device=self.device)
+
+
+def make_VAE(model, device, max_batch_size, embedding_dim, inpaint=False):
+ return VAE(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim)
+
+
+class TorchVAEEncoder(torch.nn.Module):
+ def __init__(self, model):
+ super().__init__()
+ self.vae_encoder = model
+
+ def forward(self, x):
+ return self.vae_encoder.encode(x).latent_dist.sample()
+
+
+class VAEEncoder(BaseModel):
+ def __init__(self, model, device, max_batch_size, embedding_dim):
+ super(VAEEncoder, self).__init__(
+ model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim
+ )
+ self.name = "VAE encoder"
+
+ def get_model(self):
+ vae_encoder = TorchVAEEncoder(self.model)
+ return vae_encoder
+
+ def get_input_names(self):
+ return ["images"]
+
+ def get_output_names(self):
+ return ["latent"]
+
+ def get_dynamic_axes(self):
+ return {"images": {0: "B", 2: "8H", 3: "8W"}, "latent": {0: "B", 2: "H", 3: "W"}}
+
+ def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
+ assert batch_size >= self.min_batch and batch_size <= self.max_batch
+ min_batch = batch_size if static_batch else self.min_batch
+ max_batch = batch_size if static_batch else self.max_batch
+ self.check_dims(batch_size, image_height, image_width)
+ (
+ min_batch,
+ max_batch,
+ min_image_height,
+ max_image_height,
+ min_image_width,
+ max_image_width,
+ _,
+ _,
+ _,
+ _,
+ ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
+
+ return {
+ "images": [
+ (min_batch, 3, min_image_height, min_image_width),
+ (batch_size, 3, image_height, image_width),
+ (max_batch, 3, max_image_height, max_image_width),
+ ]
+ }
+
+ def get_shape_dict(self, batch_size, image_height, image_width):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ return {
+ "images": (batch_size, 3, image_height, image_width),
+ "latent": (batch_size, 4, latent_height, latent_width),
+ }
+
+ def get_sample_input(self, batch_size, image_height, image_width):
+ self.check_dims(batch_size, image_height, image_width)
+ return torch.randn(batch_size, 3, image_height, image_width, dtype=torch.float32, device=self.device)
+
+
+def make_VAEEncoder(model, device, max_batch_size, embedding_dim, inpaint=False):
+ return VAEEncoder(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim)
+
+
+class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
+ r"""
+ Pipeline for image-to-image generation using TensorRT accelerated Stable Diffusion.
+
+ This model inherits from [`StableDiffusionImg2ImgPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: DDIMScheduler,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPFeatureExtractor,
+ requires_safety_checker: bool = True,
+ stages=["clip", "unet", "vae", "vae_encoder"],
+ image_height: int = 512,
+ image_width: int = 512,
+ max_batch_size: int = 16,
+ # ONNX export parameters
+ onnx_opset: int = 17,
+ onnx_dir: str = "onnx",
+ # TensorRT engine build parameters
+ engine_dir: str = "engine",
+ build_preview_features: bool = True,
+ force_engine_rebuild: bool = False,
+ timing_cache: str = "timing_cache",
+ ):
+ super().__init__(
+ vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
+ )
+
+ self.vae.forward = self.vae.decode
+
+ self.stages = stages
+ self.image_height, self.image_width = image_height, image_width
+ self.inpaint = False
+ self.onnx_opset = onnx_opset
+ self.onnx_dir = onnx_dir
+ self.engine_dir = engine_dir
+ self.force_engine_rebuild = force_engine_rebuild
+ self.timing_cache = timing_cache
+ self.build_static_batch = False
+ self.build_dynamic_shape = False
+ self.build_preview_features = build_preview_features
+
+ self.max_batch_size = max_batch_size
+ # TODO: Restrict batch size to 4 for larger image dimensions as a WAR for TensorRT limitation.
+ if self.build_dynamic_shape or self.image_height > 512 or self.image_width > 512:
+ self.max_batch_size = 4
+
+ self.stream = None # loaded in loadResources()
+ self.models = {} # loaded in __loadModels()
+ self.engine = {} # loaded in build_engines()
+
+ def __loadModels(self):
+ # Load pipeline models
+ self.embedding_dim = self.text_encoder.config.hidden_size
+ models_args = {
+ "device": self.torch_device,
+ "max_batch_size": self.max_batch_size,
+ "embedding_dim": self.embedding_dim,
+ "inpaint": self.inpaint,
+ }
+ if "clip" in self.stages:
+ self.models["clip"] = make_CLIP(self.text_encoder, **models_args)
+ if "unet" in self.stages:
+ self.models["unet"] = make_UNet(self.unet, **models_args)
+ if "vae" in self.stages:
+ self.models["vae"] = make_VAE(self.vae, **models_args)
+ if "vae_encoder" in self.stages:
+ self.models["vae_encoder"] = make_VAEEncoder(self.vae, **models_args)
+
+ @classmethod
+ def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ revision = kwargs.pop("revision", None)
+
+ cls.cached_folder = (
+ pretrained_model_name_or_path
+ if os.path.isdir(pretrained_model_name_or_path)
+ else snapshot_download(
+ pretrained_model_name_or_path,
+ cache_dir=cache_dir,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ )
+ )
+
+ def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings: bool = False):
+ super().to(torch_device, silence_dtype_warnings=silence_dtype_warnings)
+
+ self.onnx_dir = os.path.join(self.cached_folder, self.onnx_dir)
+ self.engine_dir = os.path.join(self.cached_folder, self.engine_dir)
+ self.timing_cache = os.path.join(self.cached_folder, self.timing_cache)
+
+ # set device
+ self.torch_device = self._execution_device
+ logger.warning(f"Running inference on device: {self.torch_device}")
+
+ # load models
+ self.__loadModels()
+
+ # build engines
+ self.engine = build_engines(
+ self.models,
+ self.engine_dir,
+ self.onnx_dir,
+ self.onnx_opset,
+ opt_image_height=self.image_height,
+ opt_image_width=self.image_width,
+ force_engine_rebuild=self.force_engine_rebuild,
+ static_batch=self.build_static_batch,
+ static_shape=not self.build_dynamic_shape,
+ enable_preview=self.build_preview_features,
+ timing_cache=self.timing_cache,
+ )
+
+ return self
+
+ def __initialize_timesteps(self, timesteps, strength):
+ self.scheduler.set_timesteps(timesteps)
+ offset = self.scheduler.steps_offset if hasattr(self.scheduler, "steps_offset") else 0
+ init_timestep = int(timesteps * strength) + offset
+ init_timestep = min(init_timestep, timesteps)
+ t_start = max(timesteps - init_timestep + offset, 0)
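+        # Worked example: with timesteps=50, strength=0.8 and offset=0, init_timestep = 40 and
+        # t_start = 10, so denoising runs over the last 40 scheduler timesteps.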
+ timesteps = self.scheduler.timesteps[t_start:].to(self.torch_device)
+ return timesteps, t_start
+
+ def __preprocess_images(self, batch_size, images=()):
+ init_images = []
+ for image in images:
+ image = image.to(self.torch_device).float()
+ image = image.repeat(batch_size, 1, 1, 1)
+ init_images.append(image)
+ return tuple(init_images)
+
+ def __encode_image(self, init_image):
+ init_latents = runEngine(self.engine["vae_encoder"], {"images": device_view(init_image)}, self.stream)[
+ "latent"
+ ]
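+        # 0.18215 is the Stable Diffusion v1.x VAE latent scaling factor (vae.config.scaling_factor).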
+ init_latents = 0.18215 * init_latents
+ return init_latents
+
+ def __encode_prompt(self, prompt, negative_prompt):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if
+                `guidance_scale` is less than `1`).
+ """
+ # Tokenize prompt
+ text_input_ids = (
+ self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ .input_ids.type(torch.int32)
+ .to(self.torch_device)
+ )
+
+ text_input_ids_inp = device_view(text_input_ids)
+ # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt
+ text_embeddings = runEngine(self.engine["clip"], {"input_ids": text_input_ids_inp}, self.stream)[
+ "text_embeddings"
+ ].clone()
+
+ # Tokenize negative prompt
+ uncond_input_ids = (
+ self.tokenizer(
+ negative_prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ .input_ids.type(torch.int32)
+ .to(self.torch_device)
+ )
+ uncond_input_ids_inp = device_view(uncond_input_ids)
+ uncond_embeddings = runEngine(self.engine["clip"], {"input_ids": uncond_input_ids_inp}, self.stream)[
+ "text_embeddings"
+ ]
+
+ # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16)
+
+ return text_embeddings
+
+ def __denoise_latent(
+ self, latents, text_embeddings, timesteps=None, step_offset=0, mask=None, masked_image_latents=None
+ ):
+ if not isinstance(timesteps, torch.Tensor):
+ timesteps = self.scheduler.timesteps
+ for step_index, timestep in enumerate(timesteps):
+ # Expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2)
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep)
+ if isinstance(mask, torch.Tensor):
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+
+ # Predict the noise residual
+ timestep_float = timestep.float() if timestep.dtype != torch.float32 else timestep
+
+ sample_inp = device_view(latent_model_input)
+ timestep_inp = device_view(timestep_float)
+ embeddings_inp = device_view(text_embeddings)
+ noise_pred = runEngine(
+ self.engine["unet"],
+ {"sample": sample_inp, "timestep": timestep_inp, "encoder_hidden_states": embeddings_inp},
+ self.stream,
+ )["latent"]
+
+ # Perform guidance
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample
+
+ latents = 1.0 / 0.18215 * latents
+ return latents
+
+ def __decode_latent(self, latents):
+ images = runEngine(self.engine["vae"], {"latent": device_view(latents)}, self.stream)["images"]
+ images = (images / 2 + 0.5).clamp(0, 1)
+ return images.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ def __loadResources(self, image_height, image_width, batch_size):
+ self.stream = cuda.Stream()
+
+ # Allocate buffers for TensorRT engine bindings
+ for model_name, obj in self.models.items():
+ self.engine[model_name].allocate_buffers(
+ shape_dict=obj.get_shape_dict(batch_size, image_height, image_width), device=self.torch_device
+ )
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+ strength: float = 0.8,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass
+                `prompt_embeds` instead.
+ image (`PIL.Image.Image`):
+                `Image`, or tensor representing an image batch, that will be used as the starting point for the
+                image-to-image generation process.
+ strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
+ will be used as a starting point, adding more noise to it the larger the `strength`. The number of
+ denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
+ be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages generating images that are closely linked to the text
+                `prompt`, usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if
+                `guidance_scale` is less than `1`).
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+
+ """
+ self.generator = generator
+ self.denoising_steps = num_inference_steps
+ self.guidance_scale = guidance_scale
+
+ # Pre-compute latent input scales and linear multistep coefficients
+ self.scheduler.set_timesteps(self.denoising_steps, device=self.torch_device)
+
+ # Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ prompt = [prompt]
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"Expected prompt to be of type list or str but got {type(prompt)}")
+
+ if negative_prompt is None:
+ negative_prompt = [""] * batch_size
+
+ if negative_prompt is not None and isinstance(negative_prompt, str):
+ negative_prompt = [negative_prompt]
+
+ assert len(prompt) == len(negative_prompt)
+
+ if batch_size > self.max_batch_size:
+ raise ValueError(
+ f"Batch size {len(prompt)} is larger than allowed {self.max_batch_size}. If dynamic shape is used, then maximum batch size is 4"
+ )
+
+ # load resources
+ self.__loadResources(self.image_height, self.image_width, batch_size)
+
+ with torch.inference_mode(), torch.autocast("cuda"), trt.Runtime(TRT_LOGGER):
+ # Initialize timesteps
+ timesteps, t_start = self.__initialize_timesteps(self.denoising_steps, strength)
+ latent_timestep = timesteps[:1].repeat(batch_size)
+
+ # Pre-process input image
+ if isinstance(image, PIL.Image.Image):
+ image = preprocess_image(image)
+ init_image = self.__preprocess_images(batch_size, (image,))[0]
+
+ # VAE encode init image
+ init_latents = self.__encode_image(init_image)
+
+ # Add noise to latents using timesteps
+ noise = torch.randn(
+ init_latents.shape, generator=self.generator, device=self.torch_device, dtype=torch.float32
+ )
+ latents = self.scheduler.add_noise(init_latents, noise, latent_timestep)
+
+ # CLIP text encoder
+ text_embeddings = self.__encode_prompt(prompt, negative_prompt)
+
+ # UNet denoiser
+ latents = self.__denoise_latent(latents, text_embeddings, timesteps=timesteps, step_offset=t_start)
+
+ # VAE decode latent
+ images = self.__decode_latent(latents)
+
+ images = self.numpy_to_pil(images)
+ return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=None)
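+
+
+# A minimal call sketch (illustrative, not executed by this module): it assumes the pipeline above has
+# already been loaded as a diffusers community pipeline and moved to "cuda" so that its TensorRT
+# engines are built; the prompt and file names are placeholders.
+#
+#   result = pipe(
+#       "a fantasy landscape, trending on artstation",
+#       image=init_image,        # PIL.Image.Image
+#       strength=0.75,
+#       num_inference_steps=50,
+#       guidance_scale=7.5,
+#   )
+#   result.images[0].save("out.png")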
diff --git a/diffusers/examples/community/stable_diffusion_tensorrt_inpaint.py b/diffusers/examples/community/stable_diffusion_tensorrt_inpaint.py
new file mode 100644
index 0000000000000000000000000000000000000000..71fa1b0a5f1124f1cbbef4d03ba2fcc420706f3b
--- /dev/null
+++ b/diffusers/examples/community/stable_diffusion_tensorrt_inpaint.py
@@ -0,0 +1,1107 @@
+#
+# Copyright 2023 The HuggingFace Inc. team.
+# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import os
+from collections import OrderedDict
+from copy import copy
+from typing import List, Optional, Union
+
+import numpy as np
+import onnx
+import onnx_graphsurgeon as gs
+import PIL.Image
+import tensorrt as trt
+import torch
+from huggingface_hub import snapshot_download
+from onnx import shape_inference
+from polygraphy import cuda
+from polygraphy.backend.common import bytes_from_path
+from polygraphy.backend.onnx.loader import fold_constants
+from polygraphy.backend.trt import (
+ CreateConfig,
+ Profile,
+ engine_from_bytes,
+ engine_from_network,
+ network_from_onnx_path,
+ save_engine,
+)
+from polygraphy.backend.trt import util as trt_util
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.stable_diffusion import (
+ StableDiffusionInpaintPipeline,
+ StableDiffusionPipelineOutput,
+ StableDiffusionSafetyChecker,
+)
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image
+from diffusers.schedulers import DDIMScheduler
+from diffusers.utils import DIFFUSERS_CACHE, logging
+
+
+"""
+Installation instructions
+python3 -m pip install --upgrade transformers diffusers>=0.16.0
+python3 -m pip install --upgrade tensorrt>=8.6.1
+python3 -m pip install --upgrade polygraphy>=0.47.0 onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com
+python3 -m pip install onnxruntime
+"""
+
+TRT_LOGGER = trt.Logger(trt.Logger.ERROR)
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+# Map of numpy dtype -> torch dtype
+numpy_to_torch_dtype_dict = {
+ np.uint8: torch.uint8,
+ np.int8: torch.int8,
+ np.int16: torch.int16,
+ np.int32: torch.int32,
+ np.int64: torch.int64,
+ np.float16: torch.float16,
+ np.float32: torch.float32,
+ np.float64: torch.float64,
+ np.complex64: torch.complex64,
+ np.complex128: torch.complex128,
+}
+if np.version.full_version >= "1.24.0":
+ numpy_to_torch_dtype_dict[np.bool_] = torch.bool
+else:
+ numpy_to_torch_dtype_dict[np.bool] = torch.bool
+
+# Map of torch dtype -> numpy dtype
+torch_to_numpy_dtype_dict = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()}
+
+
+def device_view(t):
+ return cuda.DeviceView(ptr=t.data_ptr(), shape=t.shape, dtype=torch_to_numpy_dtype_dict[t.dtype])
+
+
+def preprocess_image(image):
+ """
+ image: torch.Tensor
+ """
+ w, h = image.size
+ w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32
+ image = image.resize((w, h))
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).contiguous()
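+ # e.g. a 769x513 (WxH) PIL image is resized to 768x512 and returned below as a
+ # (1, 3, 512, 768) tensor scaled to [-1, 1]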
+ return 2.0 * image - 1.0
+
+
+class Engine:
+ def __init__(self, engine_path):
+ self.engine_path = engine_path
+ self.engine = None
+ self.context = None
+ self.buffers = OrderedDict()
+ self.tensors = OrderedDict()
+
+ def __del__(self):
+ [buf.free() for buf in self.buffers.values() if isinstance(buf, cuda.DeviceArray)]
+ del self.engine
+ del self.context
+ del self.buffers
+ del self.tensors
+
+ def build(
+ self,
+ onnx_path,
+ fp16,
+ input_profile=None,
+ enable_preview=False,
+ enable_all_tactics=False,
+ timing_cache=None,
+ workspace_size=0,
+ ):
+ logger.warning(f"Building TensorRT engine for {onnx_path}: {self.engine_path}")
+ p = Profile()
+ if input_profile:
+ for name, dims in input_profile.items():
+ assert len(dims) == 3
+ p.add(name, min=dims[0], opt=dims[1], max=dims[2])
+
+ config_kwargs = {}
+
+ config_kwargs["preview_features"] = [trt.PreviewFeature.DISABLE_EXTERNAL_TACTIC_SOURCES_FOR_CORE_0805]
+ if enable_preview:
+ # Faster dynamic shapes made optional since it increases engine build time.
+ config_kwargs["preview_features"].append(trt.PreviewFeature.FASTER_DYNAMIC_SHAPES_0805)
+ if workspace_size > 0:
+ config_kwargs["memory_pool_limits"] = {trt.MemoryPoolType.WORKSPACE: workspace_size}
+ if not enable_all_tactics:
+ config_kwargs["tactic_sources"] = []
+
+ engine = engine_from_network(
+ network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]),
+ config=CreateConfig(fp16=fp16, profiles=[p], load_timing_cache=timing_cache, **config_kwargs),
+ save_timing_cache=timing_cache,
+ )
+ save_engine(engine, path=self.engine_path)
+
+ def load(self):
+ logger.warning(f"Loading TensorRT engine: {self.engine_path}")
+ self.engine = engine_from_bytes(bytes_from_path(self.engine_path))
+
+ def activate(self):
+ self.context = self.engine.create_execution_context()
+
+ def allocate_buffers(self, shape_dict=None, device="cuda"):
+ for idx in range(trt_util.get_bindings_per_profile(self.engine)):
+ binding = self.engine[idx]
+ if shape_dict and binding in shape_dict:
+ shape = shape_dict[binding]
+ else:
+ shape = self.engine.get_binding_shape(binding)
+ dtype = trt.nptype(self.engine.get_binding_dtype(binding))
+ if self.engine.binding_is_input(binding):
+ self.context.set_binding_shape(idx, shape)
+ tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device)
+ self.tensors[binding] = tensor
+ self.buffers[binding] = cuda.DeviceView(ptr=tensor.data_ptr(), shape=shape, dtype=dtype)
+
+ def infer(self, feed_dict, stream):
+ start_binding, end_binding = trt_util.get_active_profile_bindings(self.context)
+ # shallow copy of ordered dict
+ device_buffers = copy(self.buffers)
+ for name, buf in feed_dict.items():
+ assert isinstance(buf, cuda.DeviceView)
+ device_buffers[name] = buf
+ bindings = [0] * start_binding + [buf.ptr for buf in device_buffers.values()]
+ noerror = self.context.execute_async_v2(bindings=bindings, stream_handle=stream.ptr)
+ if not noerror:
+ raise ValueError("ERROR: inference failed.")
+
+ return self.tensors
+
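+# Typical Engine lifecycle, as driven by `build_engines` and the pipeline below (illustrative sketch;
+# the paths follow the `getOnnxPath` / `getEnginePath` conventions defined further down):
+#
+#   stream = cuda.Stream()
+#   engine = Engine("engine/vae.plan")
+#   engine.build("onnx/vae.opt.onnx", fp16=True, input_profile=...)   # one-off, can be slow
+#   engine.load()                                  # deserialize the .plan file
+#   engine.activate()                              # create the TensorRT execution context
+#   engine.allocate_buffers(shape_dict=..., device="cuda")
+#   images = engine.infer({"latent": device_view(latents)}, stream)["images"]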
+
+class Optimizer:
+ def __init__(self, onnx_graph):
+ self.graph = gs.import_onnx(onnx_graph)
+
+ def cleanup(self, return_onnx=False):
+ self.graph.cleanup().toposort()
+ if return_onnx:
+ return gs.export_onnx(self.graph)
+
+ def select_outputs(self, keep, names=None):
+ self.graph.outputs = [self.graph.outputs[o] for o in keep]
+ if names:
+ for i, name in enumerate(names):
+ self.graph.outputs[i].name = name
+
+ def fold_constants(self, return_onnx=False):
+ onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True)
+ self.graph = gs.import_onnx(onnx_graph)
+ if return_onnx:
+ return onnx_graph
+
+ def infer_shapes(self, return_onnx=False):
+ onnx_graph = gs.export_onnx(self.graph)
+ if onnx_graph.ByteSize() > 2147483648:
+ raise TypeError("ERROR: model size exceeds supported 2GB limit")
+ else:
+ onnx_graph = shape_inference.infer_shapes(onnx_graph)
+
+ self.graph = gs.import_onnx(onnx_graph)
+ if return_onnx:
+ return onnx_graph
+
+
+class BaseModel:
+ def __init__(self, model, fp16=False, device="cuda", max_batch_size=16, embedding_dim=768, text_maxlen=77):
+ self.model = model
+ self.name = "SD Model"
+ self.fp16 = fp16
+ self.device = device
+
+ self.min_batch = 1
+ self.max_batch = max_batch_size
+ self.min_image_shape = 256 # min image resolution: 256x256
+ self.max_image_shape = 1024 # max image resolution: 1024x1024
+ self.min_latent_shape = self.min_image_shape // 8
+ self.max_latent_shape = self.max_image_shape // 8
+
+ self.embedding_dim = embedding_dim
+ self.text_maxlen = text_maxlen
+
+ def get_model(self):
+ return self.model
+
+ def get_input_names(self):
+ pass
+
+ def get_output_names(self):
+ pass
+
+ def get_dynamic_axes(self):
+ return None
+
+ def get_sample_input(self, batch_size, image_height, image_width):
+ pass
+
+ def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
+ return None
+
+ def get_shape_dict(self, batch_size, image_height, image_width):
+ return None
+
+ def optimize(self, onnx_graph):
+ opt = Optimizer(onnx_graph)
+ opt.cleanup()
+ opt.fold_constants()
+ opt.infer_shapes()
+ onnx_opt_graph = opt.cleanup(return_onnx=True)
+ return onnx_opt_graph
+
+ def check_dims(self, batch_size, image_height, image_width):
+ assert batch_size >= self.min_batch and batch_size <= self.max_batch
+ assert image_height % 8 == 0 or image_width % 8 == 0
+ latent_height = image_height // 8
+ latent_width = image_width // 8
+ assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape
+ assert latent_width >= self.min_latent_shape and latent_width <= self.max_latent_shape
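+ # e.g. check_dims(1, 512, 512) passes the checks and returns the 64x64 latent size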
+ return (latent_height, latent_width)
+
+ def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_shape):
+ min_batch = batch_size if static_batch else self.min_batch
+ max_batch = batch_size if static_batch else self.max_batch
+ latent_height = image_height // 8
+ latent_width = image_width // 8
+ min_image_height = image_height if static_shape else self.min_image_shape
+ max_image_height = image_height if static_shape else self.max_image_shape
+ min_image_width = image_width if static_shape else self.min_image_shape
+ max_image_width = image_width if static_shape else self.max_image_shape
+ min_latent_height = latent_height if static_shape else self.min_latent_shape
+ max_latent_height = latent_height if static_shape else self.max_latent_shape
+ min_latent_width = latent_width if static_shape else self.min_latent_shape
+ max_latent_width = latent_width if static_shape else self.max_latent_shape
+ return (
+ min_batch,
+ max_batch,
+ min_image_height,
+ max_image_height,
+ min_image_width,
+ max_image_width,
+ min_latent_height,
+ max_latent_height,
+ min_latent_width,
+ max_latent_width,
+ )
+
+
+def getOnnxPath(model_name, onnx_dir, opt=True):
+ return os.path.join(onnx_dir, model_name + (".opt" if opt else "") + ".onnx")
+
+
+def getEnginePath(model_name, engine_dir):
+ return os.path.join(engine_dir, model_name + ".plan")
+
+
+def build_engines(
+ models: dict,
+ engine_dir,
+ onnx_dir,
+ onnx_opset,
+ opt_image_height,
+ opt_image_width,
+ opt_batch_size=1,
+ force_engine_rebuild=False,
+ static_batch=False,
+ static_shape=True,
+ enable_preview=False,
+ enable_all_tactics=False,
+ timing_cache=None,
+ max_workspace_size=0,
+):
+ built_engines = {}
+ if not os.path.isdir(onnx_dir):
+ os.makedirs(onnx_dir)
+ if not os.path.isdir(engine_dir):
+ os.makedirs(engine_dir)
+
+ # Export models to ONNX
+ for model_name, model_obj in models.items():
+ engine_path = getEnginePath(model_name, engine_dir)
+ if force_engine_rebuild or not os.path.exists(engine_path):
+ logger.warning("Building Engines...")
+ logger.warning("Engine build can take a while to complete")
+ onnx_path = getOnnxPath(model_name, onnx_dir, opt=False)
+ onnx_opt_path = getOnnxPath(model_name, onnx_dir)
+ if force_engine_rebuild or not os.path.exists(onnx_opt_path):
+ if force_engine_rebuild or not os.path.exists(onnx_path):
+ logger.warning(f"Exporting model: {onnx_path}")
+ model = model_obj.get_model()
+ with torch.inference_mode(), torch.autocast("cuda"):
+ inputs = model_obj.get_sample_input(opt_batch_size, opt_image_height, opt_image_width)
+ torch.onnx.export(
+ model,
+ inputs,
+ onnx_path,
+ export_params=True,
+ opset_version=onnx_opset,
+ do_constant_folding=True,
+ input_names=model_obj.get_input_names(),
+ output_names=model_obj.get_output_names(),
+ dynamic_axes=model_obj.get_dynamic_axes(),
+ )
+ del model
+ torch.cuda.empty_cache()
+ gc.collect()
+ else:
+ logger.warning(f"Found cached model: {onnx_path}")
+
+ # Optimize onnx
+ if force_engine_rebuild or not os.path.exists(onnx_opt_path):
+ logger.warning(f"Generating optimized model: {onnx_opt_path}")
+ onnx_opt_graph = model_obj.optimize(onnx.load(onnx_path))
+ onnx.save(onnx_opt_graph, onnx_opt_path)
+ else:
+ logger.warning(f"Found cached optimized model: {onnx_opt_path} ")
+
+ # Build TensorRT engines
+ for model_name, model_obj in models.items():
+ engine_path = getEnginePath(model_name, engine_dir)
+ engine = Engine(engine_path)
+ onnx_path = getOnnxPath(model_name, onnx_dir, opt=False)
+ onnx_opt_path = getOnnxPath(model_name, onnx_dir)
+
+ if force_engine_rebuild or not os.path.exists(engine.engine_path):
+ engine.build(
+ onnx_opt_path,
+ fp16=True,
+ input_profile=model_obj.get_input_profile(
+ opt_batch_size,
+ opt_image_height,
+ opt_image_width,
+ static_batch=static_batch,
+ static_shape=static_shape,
+ ),
+ enable_preview=enable_preview,
+ timing_cache=timing_cache,
+ workspace_size=max_workspace_size,
+ )
+ built_engines[model_name] = engine
+
+ # Load and activate TensorRT engines
+ for model_name, model_obj in models.items():
+ engine = built_engines[model_name]
+ engine.load()
+ engine.activate()
+
+ return built_engines
+
+
+def runEngine(engine, feed_dict, stream):
+ return engine.infer(feed_dict, stream)
+
+
+class CLIP(BaseModel):
+ def __init__(self, model, device, max_batch_size, embedding_dim):
+ super(CLIP, self).__init__(
+ model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim
+ )
+ self.name = "CLIP"
+
+ def get_input_names(self):
+ return ["input_ids"]
+
+ def get_output_names(self):
+ return ["text_embeddings", "pooler_output"]
+
+ def get_dynamic_axes(self):
+ return {"input_ids": {0: "B"}, "text_embeddings": {0: "B"}}
+
+ def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
+ self.check_dims(batch_size, image_height, image_width)
+ min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims(
+ batch_size, image_height, image_width, static_batch, static_shape
+ )
+ return {
+ "input_ids": [(min_batch, self.text_maxlen), (batch_size, self.text_maxlen), (max_batch, self.text_maxlen)]
+ }
+
+ def get_shape_dict(self, batch_size, image_height, image_width):
+ self.check_dims(batch_size, image_height, image_width)
+ return {
+ "input_ids": (batch_size, self.text_maxlen),
+ "text_embeddings": (batch_size, self.text_maxlen, self.embedding_dim),
+ }
+
+ def get_sample_input(self, batch_size, image_height, image_width):
+ self.check_dims(batch_size, image_height, image_width)
+ return torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device)
+
+ def optimize(self, onnx_graph):
+ opt = Optimizer(onnx_graph)
+ opt.select_outputs([0]) # delete graph output#1
+ opt.cleanup()
+ opt.fold_constants()
+ opt.infer_shapes()
+ opt.select_outputs([0], names=["text_embeddings"]) # rename network output
+ opt_onnx_graph = opt.cleanup(return_onnx=True)
+ return opt_onnx_graph
+
+
+def make_CLIP(model, device, max_batch_size, embedding_dim, inpaint=False):
+ return CLIP(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim)
+
+
+class UNet(BaseModel):
+ def __init__(
+ self, model, fp16=False, device="cuda", max_batch_size=16, embedding_dim=768, text_maxlen=77, unet_dim=4
+ ):
+ super(UNet, self).__init__(
+ model=model,
+ fp16=fp16,
+ device=device,
+ max_batch_size=max_batch_size,
+ embedding_dim=embedding_dim,
+ text_maxlen=text_maxlen,
+ )
+ self.unet_dim = unet_dim
+ self.name = "UNet"
+
+ def get_input_names(self):
+ return ["sample", "timestep", "encoder_hidden_states"]
+
+ def get_output_names(self):
+ return ["latent"]
+
+ def get_dynamic_axes(self):
+ return {
+ "sample": {0: "2B", 2: "H", 3: "W"},
+ "encoder_hidden_states": {0: "2B"},
+ "latent": {0: "2B", 2: "H", 3: "W"},
+ }
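+ # NOTE: the "2B" / "2 * batch_size" dimensions here and in the profiles below account for the
+ # concatenated unconditional + conditional batch used for classifier-free guidance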
+
+ def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ (
+ min_batch,
+ max_batch,
+ _,
+ _,
+ _,
+ _,
+ min_latent_height,
+ max_latent_height,
+ min_latent_width,
+ max_latent_width,
+ ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
+ return {
+ "sample": [
+ (2 * min_batch, self.unet_dim, min_latent_height, min_latent_width),
+ (2 * batch_size, self.unet_dim, latent_height, latent_width),
+ (2 * max_batch, self.unet_dim, max_latent_height, max_latent_width),
+ ],
+ "encoder_hidden_states": [
+ (2 * min_batch, self.text_maxlen, self.embedding_dim),
+ (2 * batch_size, self.text_maxlen, self.embedding_dim),
+ (2 * max_batch, self.text_maxlen, self.embedding_dim),
+ ],
+ }
+
+ def get_shape_dict(self, batch_size, image_height, image_width):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ return {
+ "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width),
+ "encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim),
+ "latent": (2 * batch_size, 4, latent_height, latent_width),
+ }
+
+ def get_sample_input(self, batch_size, image_height, image_width):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ dtype = torch.float16 if self.fp16 else torch.float32
+ return (
+ torch.randn(
+ 2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device
+ ),
+ torch.tensor([1.0], dtype=torch.float32, device=self.device),
+ torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device),
+ )
+
+
+def make_UNet(model, device, max_batch_size, embedding_dim, inpaint=False, unet_dim=4):
+ return UNet(
+ model,
+ fp16=True,
+ device=device,
+ max_batch_size=max_batch_size,
+ embedding_dim=embedding_dim,
+ unet_dim=unet_dim,
+ )
+
+
+class VAE(BaseModel):
+ def __init__(self, model, device, max_batch_size, embedding_dim):
+ super(VAE, self).__init__(
+ model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim
+ )
+ self.name = "VAE decoder"
+
+ def get_input_names(self):
+ return ["latent"]
+
+ def get_output_names(self):
+ return ["images"]
+
+ def get_dynamic_axes(self):
+ return {"latent": {0: "B", 2: "H", 3: "W"}, "images": {0: "B", 2: "8H", 3: "8W"}}
+
+ def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ (
+ min_batch,
+ max_batch,
+ _,
+ _,
+ _,
+ _,
+ min_latent_height,
+ max_latent_height,
+ min_latent_width,
+ max_latent_width,
+ ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
+ return {
+ "latent": [
+ (min_batch, 4, min_latent_height, min_latent_width),
+ (batch_size, 4, latent_height, latent_width),
+ (max_batch, 4, max_latent_height, max_latent_width),
+ ]
+ }
+
+ def get_shape_dict(self, batch_size, image_height, image_width):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ return {
+ "latent": (batch_size, 4, latent_height, latent_width),
+ "images": (batch_size, 3, image_height, image_width),
+ }
+
+ def get_sample_input(self, batch_size, image_height, image_width):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ return torch.randn(batch_size, 4, latent_height, latent_width, dtype=torch.float32, device=self.device)
+
+
+def make_VAE(model, device, max_batch_size, embedding_dim, inpaint=False):
+ return VAE(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim)
+
+
+class TorchVAEEncoder(torch.nn.Module):
+ def __init__(self, model):
+ super().__init__()
+ self.vae_encoder = model
+
+ def forward(self, x):
+ return self.vae_encoder.encode(x).latent_dist.sample()
+
+
+class VAEEncoder(BaseModel):
+ def __init__(self, model, device, max_batch_size, embedding_dim):
+ super(VAEEncoder, self).__init__(
+ model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim
+ )
+ self.name = "VAE encoder"
+
+ def get_model(self):
+ vae_encoder = TorchVAEEncoder(self.model)
+ return vae_encoder
+
+ def get_input_names(self):
+ return ["images"]
+
+ def get_output_names(self):
+ return ["latent"]
+
+ def get_dynamic_axes(self):
+ return {"images": {0: "B", 2: "8H", 3: "8W"}, "latent": {0: "B", 2: "H", 3: "W"}}
+
+ def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
+ assert batch_size >= self.min_batch and batch_size <= self.max_batch
+ min_batch = batch_size if static_batch else self.min_batch
+ max_batch = batch_size if static_batch else self.max_batch
+ self.check_dims(batch_size, image_height, image_width)
+ (
+ min_batch,
+ max_batch,
+ min_image_height,
+ max_image_height,
+ min_image_width,
+ max_image_width,
+ _,
+ _,
+ _,
+ _,
+ ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
+
+ return {
+ "images": [
+ (min_batch, 3, min_image_height, min_image_width),
+ (batch_size, 3, image_height, image_width),
+ (max_batch, 3, max_image_height, max_image_width),
+ ]
+ }
+
+ def get_shape_dict(self, batch_size, image_height, image_width):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ return {
+ "images": (batch_size, 3, image_height, image_width),
+ "latent": (batch_size, 4, latent_height, latent_width),
+ }
+
+ def get_sample_input(self, batch_size, image_height, image_width):
+ self.check_dims(batch_size, image_height, image_width)
+ return torch.randn(batch_size, 3, image_height, image_width, dtype=torch.float32, device=self.device)
+
+
+def make_VAEEncoder(model, device, max_batch_size, embedding_dim, inpaint=False):
+ return VAEEncoder(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim)
+
+
+class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
+ r"""
+ Pipeline for inpainting using TensorRT accelerated Stable Diffusion.
+
+ This model inherits from [`StableDiffusionInpaintPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: DDIMScheduler,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPFeatureExtractor,
+ requires_safety_checker: bool = True,
+ stages=["clip", "unet", "vae", "vae_encoder"],
+ image_height: int = 512,
+ image_width: int = 512,
+ max_batch_size: int = 16,
+ # ONNX export parameters
+ onnx_opset: int = 17,
+ onnx_dir: str = "onnx",
+ # TensorRT engine build parameters
+ engine_dir: str = "engine",
+ build_preview_features: bool = True,
+ force_engine_rebuild: bool = False,
+ timing_cache: str = "timing_cache",
+ ):
+ super().__init__(
+ vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
+ )
+
+ self.vae.forward = self.vae.decode
+
+ self.stages = stages
+ self.image_height, self.image_width = image_height, image_width
+ self.inpaint = True
+ self.onnx_opset = onnx_opset
+ self.onnx_dir = onnx_dir
+ self.engine_dir = engine_dir
+ self.force_engine_rebuild = force_engine_rebuild
+ self.timing_cache = timing_cache
+ self.build_static_batch = False
+ self.build_dynamic_shape = False
+ self.build_preview_features = build_preview_features
+
+ self.max_batch_size = max_batch_size
+ # TODO: Restrict batch size to 4 for larger image dimensions as a WAR for TensorRT limitation.
+ if self.build_dynamic_shape or self.image_height > 512 or self.image_width > 512:
+ self.max_batch_size = 4
+
+ self.stream = None # loaded in loadResources()
+ self.models = {} # loaded in __loadModels()
+ self.engine = {} # loaded in build_engines()
+
+ def __loadModels(self):
+ # Load pipeline models
+ self.embedding_dim = self.text_encoder.config.hidden_size
+ models_args = {
+ "device": self.torch_device,
+ "max_batch_size": self.max_batch_size,
+ "embedding_dim": self.embedding_dim,
+ "inpaint": self.inpaint,
+ }
+ if "clip" in self.stages:
+ self.models["clip"] = make_CLIP(self.text_encoder, **models_args)
+ if "unet" in self.stages:
+ self.models["unet"] = make_UNet(self.unet, **models_args, unet_dim=self.unet.config.in_channels)
+ if "vae" in self.stages:
+ self.models["vae"] = make_VAE(self.vae, **models_args)
+ if "vae_encoder" in self.stages:
+ self.models["vae_encoder"] = make_VAEEncoder(self.vae, **models_args)
+
+ @classmethod
+ def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ revision = kwargs.pop("revision", None)
+
+ cls.cached_folder = (
+ pretrained_model_name_or_path
+ if os.path.isdir(pretrained_model_name_or_path)
+ else snapshot_download(
+ pretrained_model_name_or_path,
+ cache_dir=cache_dir,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ )
+ )
+
+ def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings: bool = False):
+ super().to(torch_device, silence_dtype_warnings=silence_dtype_warnings)
+
+ self.onnx_dir = os.path.join(self.cached_folder, self.onnx_dir)
+ self.engine_dir = os.path.join(self.cached_folder, self.engine_dir)
+ self.timing_cache = os.path.join(self.cached_folder, self.timing_cache)
+
+ # set device
+ self.torch_device = self._execution_device
+ logger.warning(f"Running inference on device: {self.torch_device}")
+
+ # load models
+ self.__loadModels()
+
+ # build engines
+ self.engine = build_engines(
+ self.models,
+ self.engine_dir,
+ self.onnx_dir,
+ self.onnx_opset,
+ opt_image_height=self.image_height,
+ opt_image_width=self.image_width,
+ force_engine_rebuild=self.force_engine_rebuild,
+ static_batch=self.build_static_batch,
+ static_shape=not self.build_dynamic_shape,
+ enable_preview=self.build_preview_features,
+ timing_cache=self.timing_cache,
+ )
+
+ return self
+
+ def __initialize_timesteps(self, num_inference_steps, strength):
+ self.scheduler.set_timesteps(num_inference_steps)
+ offset = self.scheduler.config.steps_offset if hasattr(self.scheduler, "steps_offset") else 0
+ init_timestep = int(num_inference_steps * strength) + offset
+ init_timestep = min(init_timestep, num_inference_steps)
+ t_start = max(num_inference_steps - init_timestep + offset, 0)
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :].to(self.torch_device)
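+ # e.g. with num_inference_steps=50, strength=0.8, no steps_offset and a first-order scheduler:
+ # init_timestep=40, t_start=10, so the last 40 timesteps are kept and 40 denoising steps run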
+ return timesteps, num_inference_steps - t_start
+
+ def __preprocess_images(self, batch_size, images=()):
+ init_images = []
+ for image in images:
+ image = image.to(self.torch_device).float()
+ image = image.repeat(batch_size, 1, 1, 1)
+ init_images.append(image)
+ return tuple(init_images)
+
+ def __encode_image(self, init_image):
+ init_latents = runEngine(self.engine["vae_encoder"], {"images": device_view(init_image)}, self.stream)[
+ "latent"
+ ]
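+ # 0.18215 is the standard Stable Diffusion VAE scaling factor (vae.config.scaling_factor)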
+ init_latents = 0.18215 * init_latents
+ return init_latents
+
+ def __encode_prompt(self, prompt, negative_prompt):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if
+ `guidance_scale` is less than `1`).
+ """
+ # Tokenize prompt
+ text_input_ids = (
+ self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ .input_ids.type(torch.int32)
+ .to(self.torch_device)
+ )
+
+ text_input_ids_inp = device_view(text_input_ids)
+ # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt
+ text_embeddings = runEngine(self.engine["clip"], {"input_ids": text_input_ids_inp}, self.stream)[
+ "text_embeddings"
+ ].clone()
+
+ # Tokenize negative prompt
+ uncond_input_ids = (
+ self.tokenizer(
+ negative_prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ .input_ids.type(torch.int32)
+ .to(self.torch_device)
+ )
+ uncond_input_ids_inp = device_view(uncond_input_ids)
+ uncond_embeddings = runEngine(self.engine["clip"], {"input_ids": uncond_input_ids_inp}, self.stream)[
+ "text_embeddings"
+ ]
+
+ # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16)
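+ # For a batch of B prompts the result is a (2*B, text_maxlen, embedding_dim) float16 tensor
+ # (e.g. (2, 77, 768) for an SD 1.x text encoder with B=1), unconditional embeddings first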
+
+ return text_embeddings
+
+ def __denoise_latent(
+ self, latents, text_embeddings, timesteps=None, step_offset=0, mask=None, masked_image_latents=None
+ ):
+ if not isinstance(timesteps, torch.Tensor):
+ timesteps = self.scheduler.timesteps
+ for step_index, timestep in enumerate(timesteps):
+ # Expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2)
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep)
+ if isinstance(mask, torch.Tensor):
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+
+ # Predict the noise residual
+ timestep_float = timestep.float() if timestep.dtype != torch.float32 else timestep
+
+ sample_inp = device_view(latent_model_input)
+ timestep_inp = device_view(timestep_float)
+ embeddings_inp = device_view(text_embeddings)
+ noise_pred = runEngine(
+ self.engine["unet"],
+ {"sample": sample_inp, "timestep": timestep_inp, "encoder_hidden_states": embeddings_inp},
+ self.stream,
+ )["latent"]
+
+ # Perform guidance
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample
+
+ latents = 1.0 / 0.18215 * latents
+ return latents
+
+ def __decode_latent(self, latents):
+ images = runEngine(self.engine["vae"], {"latent": device_view(latents)}, self.stream)["images"]
+ images = (images / 2 + 0.5).clamp(0, 1)
+ return images.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ def __loadResources(self, image_height, image_width, batch_size):
+ self.stream = cuda.Stream()
+
+ # Allocate buffers for TensorRT engine bindings
+ for model_name, obj in self.models.items():
+ self.engine[model_name].allocate_buffers(
+ shape_dict=obj.get_shape_dict(batch_size, image_height, image_width), device=self.torch_device
+ )
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+ strength: float = 1.0,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
+ be masked out with `mask_image` and repainted according to `prompt`.
+ mask_image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+ repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
+ to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
+ instead of 3, so the expected shape would be `(B, H, W, 1)`.
+ strength (`float`, *optional*, defaults to 1.0):
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
+ will be used as a starting point, adding more noise to it the larger the `strength`. The number of
+ denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
+ be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the
+ text `prompt`, usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if
+ `guidance_scale` is less than `1`).
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+
+ """
+ self.generator = generator
+ self.denoising_steps = num_inference_steps
+ self.guidance_scale = guidance_scale
+
+ # Pre-compute latent input scales and linear multistep coefficients
+ self.scheduler.set_timesteps(self.denoising_steps, device=self.torch_device)
+
+ # Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ prompt = [prompt]
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"Expected prompt to be of type list or str but got {type(prompt)}")
+
+ if negative_prompt is None:
+ negative_prompt = [""] * batch_size
+
+ if negative_prompt is not None and isinstance(negative_prompt, str):
+ negative_prompt = [negative_prompt]
+
+ assert len(prompt) == len(negative_prompt)
+
+ if batch_size > self.max_batch_size:
+ raise ValueError(
+ f"Batch size {len(prompt)} is larger than allowed {self.max_batch_size}. If dynamic shape is used, then maximum batch size is 4"
+ )
+
+ # Validate image dimensions
+ mask_width, mask_height = mask_image.size
+ if mask_height != self.image_height or mask_width != self.image_width:
+ raise ValueError(
+ f"Input image height and width {self.image_height} and {self.image_width} are not equal to "
+ f"the respective dimensions of the mask image {mask_height} and {mask_width}"
+ )
+
+ # load resources
+ self.__loadResources(self.image_height, self.image_width, batch_size)
+
+ with torch.inference_mode(), torch.autocast("cuda"), trt.Runtime(TRT_LOGGER):
+ # Spatial dimensions of latent tensor
+ latent_height = self.image_height // 8
+ latent_width = self.image_width // 8
+
+ # Pre-process input images
+ mask, masked_image, init_image = self.__preprocess_images(
+ batch_size,
+ prepare_mask_and_masked_image(
+ image,
+ mask_image,
+ self.image_height,
+ self.image_width,
+ return_image=True,
+ ),
+ )
+
+ mask = torch.nn.functional.interpolate(mask, size=(latent_height, latent_width))
+ mask = torch.cat([mask] * 2)
+
+ # Initialize timesteps
+ timesteps, t_start = self.__initialize_timesteps(self.denoising_steps, strength)
+
+ # Timestep at which the initial noise is added (e.g. the 50% mark when strength is 0.5)
+ latent_timestep = timesteps[:1].repeat(batch_size)
+ # If strength is 1, the latents are initialised with pure noise
+ is_strength_max = strength == 1.0
+
+ # Pre-initialize latents
+ num_channels_latents = self.vae.config.latent_channels
+ latents_outputs = self.prepare_latents(
+ batch_size,
+ num_channels_latents,
+ self.image_height,
+ self.image_width,
+ torch.float32,
+ self.torch_device,
+ generator,
+ image=init_image,
+ timestep=latent_timestep,
+ is_strength_max=is_strength_max,
+ )
+
+ latents = latents_outputs[0]
+
+ # VAE encode masked image
+ masked_latents = self.__encode_image(masked_image)
+ masked_latents = torch.cat([masked_latents] * 2)
+
+ # CLIP text encoder
+ text_embeddings = self.__encode_prompt(prompt, negative_prompt)
+
+ # UNet denoiser
+ latents = self.__denoise_latent(
+ latents,
+ text_embeddings,
+ timesteps=timesteps,
+ step_offset=t_start,
+ mask=mask,
+ masked_image_latents=masked_latents,
+ )
+
+ # VAE decode latent
+ images = self.__decode_latent(latents)
+
+ images = self.numpy_to_pil(images)
+ return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=None)
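+
+
+# Note: `__call__` checks that `mask_image` matches the pipeline's configured `image_height` x
+# `image_width` (512x512 by default) and raises a ValueError otherwise.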
diff --git a/diffusers/examples/community/stable_diffusion_tensorrt_txt2img.py b/diffusers/examples/community/stable_diffusion_tensorrt_txt2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..b51f3176b958263c174e9cbb16d28e1575c8d1fb
--- /dev/null
+++ b/diffusers/examples/community/stable_diffusion_tensorrt_txt2img.py
@@ -0,0 +1,928 @@
+#
+# Copyright 2023 The HuggingFace Inc. team.
+# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import os
+from collections import OrderedDict
+from copy import copy
+from typing import List, Optional, Union
+
+import numpy as np
+import onnx
+import onnx_graphsurgeon as gs
+import tensorrt as trt
+import torch
+from huggingface_hub import snapshot_download
+from onnx import shape_inference
+from polygraphy import cuda
+from polygraphy.backend.common import bytes_from_path
+from polygraphy.backend.onnx.loader import fold_constants
+from polygraphy.backend.trt import (
+ CreateConfig,
+ Profile,
+ engine_from_bytes,
+ engine_from_network,
+ network_from_onnx_path,
+ save_engine,
+)
+from polygraphy.backend.trt import util as trt_util
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.stable_diffusion import (
+ StableDiffusionPipeline,
+ StableDiffusionPipelineOutput,
+ StableDiffusionSafetyChecker,
+)
+from diffusers.schedulers import DDIMScheduler
+from diffusers.utils import DIFFUSERS_CACHE, logging
+
+
+"""
+Installation instructions
+python3 -m pip install --upgrade transformers diffusers>=0.16.0
+python3 -m pip install --upgrade tensorrt>=8.6.1
+python3 -m pip install --upgrade polygraphy>=0.47.0 onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com
+python3 -m pip install onnxruntime
+"""
+
+TRT_LOGGER = trt.Logger(trt.Logger.ERROR)
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+# Map of numpy dtype -> torch dtype
+numpy_to_torch_dtype_dict = {
+ np.uint8: torch.uint8,
+ np.int8: torch.int8,
+ np.int16: torch.int16,
+ np.int32: torch.int32,
+ np.int64: torch.int64,
+ np.float16: torch.float16,
+ np.float32: torch.float32,
+ np.float64: torch.float64,
+ np.complex64: torch.complex64,
+ np.complex128: torch.complex128,
+}
+if np.version.full_version >= "1.24.0":
+ numpy_to_torch_dtype_dict[np.bool_] = torch.bool
+else:
+ numpy_to_torch_dtype_dict[np.bool] = torch.bool
+
+# Map of torch dtype -> numpy dtype
+torch_to_numpy_dtype_dict = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()}
+
+
+def device_view(t):
+ return cuda.DeviceView(ptr=t.data_ptr(), shape=t.shape, dtype=torch_to_numpy_dtype_dict[t.dtype])
+
+
+class Engine:
+ def __init__(self, engine_path):
+ self.engine_path = engine_path
+ self.engine = None
+ self.context = None
+ self.buffers = OrderedDict()
+ self.tensors = OrderedDict()
+
+ def __del__(self):
+ [buf.free() for buf in self.buffers.values() if isinstance(buf, cuda.DeviceArray)]
+ del self.engine
+ del self.context
+ del self.buffers
+ del self.tensors
+
+ def build(
+ self,
+ onnx_path,
+ fp16,
+ input_profile=None,
+ enable_preview=False,
+ enable_all_tactics=False,
+ timing_cache=None,
+ workspace_size=0,
+ ):
+ logger.warning(f"Building TensorRT engine for {onnx_path}: {self.engine_path}")
+ p = Profile()
+ if input_profile:
+ for name, dims in input_profile.items():
+ assert len(dims) == 3
+ p.add(name, min=dims[0], opt=dims[1], max=dims[2])
+
+ config_kwargs = {}
+
+ config_kwargs["preview_features"] = [trt.PreviewFeature.DISABLE_EXTERNAL_TACTIC_SOURCES_FOR_CORE_0805]
+ if enable_preview:
+ # Faster dynamic shapes made optional since it increases engine build time.
+ config_kwargs["preview_features"].append(trt.PreviewFeature.FASTER_DYNAMIC_SHAPES_0805)
+ if workspace_size > 0:
+ config_kwargs["memory_pool_limits"] = {trt.MemoryPoolType.WORKSPACE: workspace_size}
+ if not enable_all_tactics:
+ config_kwargs["tactic_sources"] = []
+
+ engine = engine_from_network(
+ network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]),
+ config=CreateConfig(fp16=fp16, profiles=[p], load_timing_cache=timing_cache, **config_kwargs),
+ save_timing_cache=timing_cache,
+ )
+ save_engine(engine, path=self.engine_path)
+
+ def load(self):
+ logger.warning(f"Loading TensorRT engine: {self.engine_path}")
+ self.engine = engine_from_bytes(bytes_from_path(self.engine_path))
+
+ def activate(self):
+ self.context = self.engine.create_execution_context()
+
+ def allocate_buffers(self, shape_dict=None, device="cuda"):
+ for idx in range(trt_util.get_bindings_per_profile(self.engine)):
+ binding = self.engine[idx]
+ if shape_dict and binding in shape_dict:
+ shape = shape_dict[binding]
+ else:
+ shape = self.engine.get_binding_shape(binding)
+ dtype = trt.nptype(self.engine.get_binding_dtype(binding))
+ if self.engine.binding_is_input(binding):
+ self.context.set_binding_shape(idx, shape)
+ tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device)
+ self.tensors[binding] = tensor
+ self.buffers[binding] = cuda.DeviceView(ptr=tensor.data_ptr(), shape=shape, dtype=dtype)
+
+ def infer(self, feed_dict, stream):
+ start_binding, end_binding = trt_util.get_active_profile_bindings(self.context)
+ # shallow copy of ordered dict
+ device_buffers = copy(self.buffers)
+ for name, buf in feed_dict.items():
+ assert isinstance(buf, cuda.DeviceView)
+ device_buffers[name] = buf
+ bindings = [0] * start_binding + [buf.ptr for buf in device_buffers.values()]
+ noerror = self.context.execute_async_v2(bindings=bindings, stream_handle=stream.ptr)
+ if not noerror:
+ raise ValueError("ERROR: inference failed.")
+
+ return self.tensors
+
+
+class Optimizer:
+ def __init__(self, onnx_graph):
+ self.graph = gs.import_onnx(onnx_graph)
+
+ def cleanup(self, return_onnx=False):
+ self.graph.cleanup().toposort()
+ if return_onnx:
+ return gs.export_onnx(self.graph)
+
+ def select_outputs(self, keep, names=None):
+ self.graph.outputs = [self.graph.outputs[o] for o in keep]
+ if names:
+ for i, name in enumerate(names):
+ self.graph.outputs[i].name = name
+
+ def fold_constants(self, return_onnx=False):
+ onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True)
+ self.graph = gs.import_onnx(onnx_graph)
+ if return_onnx:
+ return onnx_graph
+
+ def infer_shapes(self, return_onnx=False):
+ onnx_graph = gs.export_onnx(self.graph)
+ if onnx_graph.ByteSize() > 2147483648:
+ raise TypeError("ERROR: model size exceeds supported 2GB limit")
+ else:
+ onnx_graph = shape_inference.infer_shapes(onnx_graph)
+
+ self.graph = gs.import_onnx(onnx_graph)
+ if return_onnx:
+ return onnx_graph
+
+
+class BaseModel:
+ def __init__(self, model, fp16=False, device="cuda", max_batch_size=16, embedding_dim=768, text_maxlen=77):
+ self.model = model
+ self.name = "SD Model"
+ self.fp16 = fp16
+ self.device = device
+
+ self.min_batch = 1
+ self.max_batch = max_batch_size
+ self.min_image_shape = 256 # min image resolution: 256x256
+ self.max_image_shape = 1024 # max image resolution: 1024x1024
+ self.min_latent_shape = self.min_image_shape // 8
+ self.max_latent_shape = self.max_image_shape // 8
+
+ self.embedding_dim = embedding_dim
+ self.text_maxlen = text_maxlen
+
+ def get_model(self):
+ return self.model
+
+ def get_input_names(self):
+ pass
+
+ def get_output_names(self):
+ pass
+
+ def get_dynamic_axes(self):
+ return None
+
+ def get_sample_input(self, batch_size, image_height, image_width):
+ pass
+
+ def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
+ return None
+
+ def get_shape_dict(self, batch_size, image_height, image_width):
+ return None
+
+ def optimize(self, onnx_graph):
+ opt = Optimizer(onnx_graph)
+ opt.cleanup()
+ opt.fold_constants()
+ opt.infer_shapes()
+ onnx_opt_graph = opt.cleanup(return_onnx=True)
+ return onnx_opt_graph
+
+ def check_dims(self, batch_size, image_height, image_width):
+ assert batch_size >= self.min_batch and batch_size <= self.max_batch
+ assert image_height % 8 == 0 or image_width % 8 == 0
+ latent_height = image_height // 8
+ latent_width = image_width // 8
+ assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape
+ assert latent_width >= self.min_latent_shape and latent_width <= self.max_latent_shape
+ return (latent_height, latent_width)
+
+ def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_shape):
+ min_batch = batch_size if static_batch else self.min_batch
+ max_batch = batch_size if static_batch else self.max_batch
+ latent_height = image_height // 8
+ latent_width = image_width // 8
+ min_image_height = image_height if static_shape else self.min_image_shape
+ max_image_height = image_height if static_shape else self.max_image_shape
+ min_image_width = image_width if static_shape else self.min_image_shape
+ max_image_width = image_width if static_shape else self.max_image_shape
+ min_latent_height = latent_height if static_shape else self.min_latent_shape
+ max_latent_height = latent_height if static_shape else self.max_latent_shape
+ min_latent_width = latent_width if static_shape else self.min_latent_shape
+ max_latent_width = latent_width if static_shape else self.max_latent_shape
+ return (
+ min_batch,
+ max_batch,
+ min_image_height,
+ max_image_height,
+ min_image_width,
+ max_image_width,
+ min_latent_height,
+ max_latent_height,
+ min_latent_width,
+ max_latent_width,
+ )
+
+
+def getOnnxPath(model_name, onnx_dir, opt=True):
+ return os.path.join(onnx_dir, model_name + (".opt" if opt else "") + ".onnx")
+
+
+def getEnginePath(model_name, engine_dir):
+ return os.path.join(engine_dir, model_name + ".plan")
+
+
+def build_engines(
+ models: dict,
+ engine_dir,
+ onnx_dir,
+ onnx_opset,
+ opt_image_height,
+ opt_image_width,
+ opt_batch_size=1,
+ force_engine_rebuild=False,
+ static_batch=False,
+ static_shape=True,
+ enable_preview=False,
+ enable_all_tactics=False,
+ timing_cache=None,
+ max_workspace_size=0,
+):
+ built_engines = {}
+ if not os.path.isdir(onnx_dir):
+ os.makedirs(onnx_dir)
+ if not os.path.isdir(engine_dir):
+ os.makedirs(engine_dir)
+
+ # Export models to ONNX
+ for model_name, model_obj in models.items():
+ engine_path = getEnginePath(model_name, engine_dir)
+ if force_engine_rebuild or not os.path.exists(engine_path):
+ logger.warning("Building Engines...")
+ logger.warning("Engine build can take a while to complete")
+ onnx_path = getOnnxPath(model_name, onnx_dir, opt=False)
+ onnx_opt_path = getOnnxPath(model_name, onnx_dir)
+ if force_engine_rebuild or not os.path.exists(onnx_opt_path):
+ if force_engine_rebuild or not os.path.exists(onnx_path):
+ logger.warning(f"Exporting model: {onnx_path}")
+ model = model_obj.get_model()
+ with torch.inference_mode(), torch.autocast("cuda"):
+ inputs = model_obj.get_sample_input(opt_batch_size, opt_image_height, opt_image_width)
+ torch.onnx.export(
+ model,
+ inputs,
+ onnx_path,
+ export_params=True,
+ opset_version=onnx_opset,
+ do_constant_folding=True,
+ input_names=model_obj.get_input_names(),
+ output_names=model_obj.get_output_names(),
+ dynamic_axes=model_obj.get_dynamic_axes(),
+ )
+ del model
+ torch.cuda.empty_cache()
+ gc.collect()
+ else:
+ logger.warning(f"Found cached model: {onnx_path}")
+
+ # Optimize onnx
+ if force_engine_rebuild or not os.path.exists(onnx_opt_path):
+ logger.warning(f"Generating optimized model: {onnx_opt_path}")
+ onnx_opt_graph = model_obj.optimize(onnx.load(onnx_path))
+ onnx.save(onnx_opt_graph, onnx_opt_path)
+ else:
+ logger.warning(f"Found cached optimized model: {onnx_opt_path} ")
+
+ # Build TensorRT engines
+ for model_name, model_obj in models.items():
+ engine_path = getEnginePath(model_name, engine_dir)
+ engine = Engine(engine_path)
+ onnx_path = getOnnxPath(model_name, onnx_dir, opt=False)
+ onnx_opt_path = getOnnxPath(model_name, onnx_dir)
+
+ if force_engine_rebuild or not os.path.exists(engine.engine_path):
+ engine.build(
+ onnx_opt_path,
+ fp16=True,
+ input_profile=model_obj.get_input_profile(
+ opt_batch_size,
+ opt_image_height,
+ opt_image_width,
+ static_batch=static_batch,
+ static_shape=static_shape,
+ ),
+ enable_preview=enable_preview,
+ timing_cache=timing_cache,
+ workspace_size=max_workspace_size,
+ )
+ built_engines[model_name] = engine
+
+ # Load and activate TensorRT engines
+ for model_name, model_obj in models.items():
+ engine = built_engines[model_name]
+ engine.load()
+ engine.activate()
+
+ return built_engines
+
+
+def runEngine(engine, feed_dict, stream):
+ return engine.infer(feed_dict, stream)
+
+
+class CLIP(BaseModel):
+ def __init__(self, model, device, max_batch_size, embedding_dim):
+ super(CLIP, self).__init__(
+ model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim
+ )
+ self.name = "CLIP"
+
+ def get_input_names(self):
+ return ["input_ids"]
+
+ def get_output_names(self):
+ return ["text_embeddings", "pooler_output"]
+
+ def get_dynamic_axes(self):
+ return {"input_ids": {0: "B"}, "text_embeddings": {0: "B"}}
+
+ def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
+ self.check_dims(batch_size, image_height, image_width)
+ min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims(
+ batch_size, image_height, image_width, static_batch, static_shape
+ )
+ return {
+ "input_ids": [(min_batch, self.text_maxlen), (batch_size, self.text_maxlen), (max_batch, self.text_maxlen)]
+ }
+
+ def get_shape_dict(self, batch_size, image_height, image_width):
+ self.check_dims(batch_size, image_height, image_width)
+ return {
+ "input_ids": (batch_size, self.text_maxlen),
+ "text_embeddings": (batch_size, self.text_maxlen, self.embedding_dim),
+ }
+
+ def get_sample_input(self, batch_size, image_height, image_width):
+ self.check_dims(batch_size, image_height, image_width)
+ return torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device)
+
+ def optimize(self, onnx_graph):
+ opt = Optimizer(onnx_graph)
+ opt.select_outputs([0]) # delete graph output#1
+ opt.cleanup()
+ opt.fold_constants()
+ opt.infer_shapes()
+ opt.select_outputs([0], names=["text_embeddings"]) # rename network output
+ opt_onnx_graph = opt.cleanup(return_onnx=True)
+ return opt_onnx_graph
+
+
+def make_CLIP(model, device, max_batch_size, embedding_dim, inpaint=False):
+ return CLIP(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim)
+
+
+class UNet(BaseModel):
+ def __init__(
+ self, model, fp16=False, device="cuda", max_batch_size=16, embedding_dim=768, text_maxlen=77, unet_dim=4
+ ):
+ super(UNet, self).__init__(
+ model=model,
+ fp16=fp16,
+ device=device,
+ max_batch_size=max_batch_size,
+ embedding_dim=embedding_dim,
+ text_maxlen=text_maxlen,
+ )
+ self.unet_dim = unet_dim
+ self.name = "UNet"
+
+ def get_input_names(self):
+ return ["sample", "timestep", "encoder_hidden_states"]
+
+ def get_output_names(self):
+ return ["latent"]
+
+ def get_dynamic_axes(self):
+ return {
+ "sample": {0: "2B", 2: "H", 3: "W"},
+ "encoder_hidden_states": {0: "2B"},
+ "latent": {0: "2B", 2: "H", 3: "W"},
+ }
+
+ def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ (
+ min_batch,
+ max_batch,
+ _,
+ _,
+ _,
+ _,
+ min_latent_height,
+ max_latent_height,
+ min_latent_width,
+ max_latent_width,
+ ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
+ return {
+ "sample": [
+ (2 * min_batch, self.unet_dim, min_latent_height, min_latent_width),
+ (2 * batch_size, self.unet_dim, latent_height, latent_width),
+ (2 * max_batch, self.unet_dim, max_latent_height, max_latent_width),
+ ],
+ "encoder_hidden_states": [
+ (2 * min_batch, self.text_maxlen, self.embedding_dim),
+ (2 * batch_size, self.text_maxlen, self.embedding_dim),
+ (2 * max_batch, self.text_maxlen, self.embedding_dim),
+ ],
+ }
+
+ def get_shape_dict(self, batch_size, image_height, image_width):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ return {
+ "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width),
+ "encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim),
+ "latent": (2 * batch_size, 4, latent_height, latent_width),
+ }
+
+ def get_sample_input(self, batch_size, image_height, image_width):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ dtype = torch.float16 if self.fp16 else torch.float32
+ return (
+ torch.randn(
+ 2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device
+ ),
+ torch.tensor([1.0], dtype=torch.float32, device=self.device),
+ torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device),
+ )
+
+
+def make_UNet(model, device, max_batch_size, embedding_dim, inpaint=False):
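+ # Inpainting UNets take 9 input channels (4 noisy latents + 4 masked-image latents + 1 mask);
+ # plain text-to-image UNets take 4.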
+ return UNet(
+ model,
+ fp16=True,
+ device=device,
+ max_batch_size=max_batch_size,
+ embedding_dim=embedding_dim,
+ unet_dim=(9 if inpaint else 4),
+ )
+
+
+class VAE(BaseModel):
+ def __init__(self, model, device, max_batch_size, embedding_dim):
+ super(VAE, self).__init__(
+ model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim
+ )
+ self.name = "VAE decoder"
+
+ def get_input_names(self):
+ return ["latent"]
+
+ def get_output_names(self):
+ return ["images"]
+
+ def get_dynamic_axes(self):
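+ # the VAE decoder upsamples the latents by a factor of 8, hence the "8H"/"8W" output axes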
+ return {"latent": {0: "B", 2: "H", 3: "W"}, "images": {0: "B", 2: "8H", 3: "8W"}}
+
+ def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ (
+ min_batch,
+ max_batch,
+ _,
+ _,
+ _,
+ _,
+ min_latent_height,
+ max_latent_height,
+ min_latent_width,
+ max_latent_width,
+ ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
+ return {
+ "latent": [
+ (min_batch, 4, min_latent_height, min_latent_width),
+ (batch_size, 4, latent_height, latent_width),
+ (max_batch, 4, max_latent_height, max_latent_width),
+ ]
+ }
+
+ def get_shape_dict(self, batch_size, image_height, image_width):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ return {
+ "latent": (batch_size, 4, latent_height, latent_width),
+ "images": (batch_size, 3, image_height, image_width),
+ }
+
+ def get_sample_input(self, batch_size, image_height, image_width):
+ latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+ return torch.randn(batch_size, 4, latent_height, latent_width, dtype=torch.float32, device=self.device)
+
+
+def make_VAE(model, device, max_batch_size, embedding_dim, inpaint=False):
+ return VAE(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim)
+
+
+class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using TensorRT accelerated Stable Diffusion.
+
+ This model inherits from [`StableDiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
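+
+ Example (a minimal usage sketch, not taken from this file; the checkpoint name, prompt, and the community
+ pipeline name passed to `custom_pipeline` are illustrative assumptions):
+
+ ```py
+ >>> import torch
+ >>> from diffusers import DDIMScheduler, DiffusionPipeline
+
+ >>> # hypothetical checkpoint; any Stable Diffusion checkpoint matching the default 768x768 size should work
+ >>> scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="scheduler")
+ >>> pipe = DiffusionPipeline.from_pretrained(
+ ... "stabilityai/stable-diffusion-2-1",
+ ... custom_pipeline="stable_diffusion_tensorrt_txt2img",
+ ... torch_dtype=torch.float16,
+ ... scheduler=scheduler,
+ ... )
+ >>> # set the cache folder before moving to CUDA so ONNX models and TensorRT engines land next to the weights
+ >>> pipe.set_cached_folder("stabilityai/stable-diffusion-2-1")
+ >>> pipe = pipe.to("cuda")
+ >>> image = pipe("a beautiful photograph of Mt. Fuji during cherry blossom").images[0]
+ ```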
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: DDIMScheduler,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPFeatureExtractor,
+ requires_safety_checker: bool = True,
+ stages=["clip", "unet", "vae"],
+ image_height: int = 768,
+ image_width: int = 768,
+ max_batch_size: int = 16,
+ # ONNX export parameters
+ onnx_opset: int = 17,
+ onnx_dir: str = "onnx",
+ # TensorRT engine build parameters
+ engine_dir: str = "engine",
+ build_preview_features: bool = True,
+ force_engine_rebuild: bool = False,
+ timing_cache: str = "timing_cache",
+ ):
+ super().__init__(
+ vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
+ )
+
+ self.vae.forward = self.vae.decode
+
+ self.stages = stages
+ self.image_height, self.image_width = image_height, image_width
+ self.inpaint = False
+ self.onnx_opset = onnx_opset
+ self.onnx_dir = onnx_dir
+ self.engine_dir = engine_dir
+ self.force_engine_rebuild = force_engine_rebuild
+ self.timing_cache = timing_cache
+ self.build_static_batch = False
+ self.build_dynamic_shape = False
+ self.build_preview_features = build_preview_features
+
+ self.max_batch_size = max_batch_size
+ # TODO: Restrict batch size to 4 for larger image dimensions as a workaround for a TensorRT limitation.
+ if self.build_dynamic_shape or self.image_height > 512 or self.image_width > 512:
+ self.max_batch_size = 4
+
+ self.stream = None # loaded in loadResources()
+ self.models = {} # loaded in __loadModels()
+ self.engine = {} # loaded in build_engines()
+
+ def __loadModels(self):
+ # Load pipeline models
+ self.embedding_dim = self.text_encoder.config.hidden_size
+ models_args = {
+ "device": self.torch_device,
+ "max_batch_size": self.max_batch_size,
+ "embedding_dim": self.embedding_dim,
+ "inpaint": self.inpaint,
+ }
+ if "clip" in self.stages:
+ self.models["clip"] = make_CLIP(self.text_encoder, **models_args)
+ if "unet" in self.stages:
+ self.models["unet"] = make_UNet(self.unet, **models_args)
+ if "vae" in self.stages:
+ self.models["vae"] = make_VAE(self.vae, **models_args)
+
+ @classmethod
+ def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ revision = kwargs.pop("revision", None)
+
+ cls.cached_folder = (
+ pretrained_model_name_or_path
+ if os.path.isdir(pretrained_model_name_or_path)
+ else snapshot_download(
+ pretrained_model_name_or_path,
+ cache_dir=cache_dir,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ )
+ )
+
+ def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings: bool = False):
+ super().to(torch_device, silence_dtype_warnings=silence_dtype_warnings)
+
+ self.onnx_dir = os.path.join(self.cached_folder, self.onnx_dir)
+ self.engine_dir = os.path.join(self.cached_folder, self.engine_dir)
+ self.timing_cache = os.path.join(self.cached_folder, self.timing_cache)
+
+ # set device
+ self.torch_device = self._execution_device
+ logger.warning(f"Running inference on device: {self.torch_device}")
+
+ # load models
+ self.__loadModels()
+
+ # build engines
+ self.engine = build_engines(
+ self.models,
+ self.engine_dir,
+ self.onnx_dir,
+ self.onnx_opset,
+ opt_image_height=self.image_height,
+ opt_image_width=self.image_width,
+ force_engine_rebuild=self.force_engine_rebuild,
+ static_batch=self.build_static_batch,
+ static_shape=not self.build_dynamic_shape,
+ enable_preview=self.build_preview_features,
+ timing_cache=self.timing_cache,
+ )
+
+ return self
+
+ def __encode_prompt(self, prompt, negative_prompt):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ """
+ # Tokenize prompt
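+ # NOTE: token ids are cast to int32 because the CLIP TensorRT engine is exported with int32 inputs
+ # (see CLIP.get_sample_input above)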
+ text_input_ids = (
+ self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ .input_ids.type(torch.int32)
+ .to(self.torch_device)
+ )
+
+ text_input_ids_inp = device_view(text_input_ids)
+ # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt
+ text_embeddings = runEngine(self.engine["clip"], {"input_ids": text_input_ids_inp}, self.stream)[
+ "text_embeddings"
+ ].clone()
+
+ # Tokenize negative prompt
+ uncond_input_ids = (
+ self.tokenizer(
+ negative_prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ .input_ids.type(torch.int32)
+ .to(self.torch_device)
+ )
+ uncond_input_ids_inp = device_view(uncond_input_ids)
+ uncond_embeddings = runEngine(self.engine["clip"], {"input_ids": uncond_input_ids_inp}, self.stream)[
+ "text_embeddings"
+ ]
+
+ # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16)
+
+ return text_embeddings
+
+ def __denoise_latent(
+ self, latents, text_embeddings, timesteps=None, step_offset=0, mask=None, masked_image_latents=None
+ ):
+ if not isinstance(timesteps, torch.Tensor):
+ timesteps = self.scheduler.timesteps
+ for step_index, timestep in enumerate(timesteps):
+ # Expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2)
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep)
+ if isinstance(mask, torch.Tensor):
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+
+ # Predict the noise residual
+ timestep_float = timestep.float() if timestep.dtype != torch.float32 else timestep
+
+ sample_inp = device_view(latent_model_input)
+ timestep_inp = device_view(timestep_float)
+ embeddings_inp = device_view(text_embeddings)
+ noise_pred = runEngine(
+ self.engine["unet"],
+ {"sample": sample_inp, "timestep": timestep_inp, "encoder_hidden_states": embeddings_inp},
+ self.stream,
+ )["latent"]
+
+ # Perform guidance
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample
+
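+ # undo the SD VAE scaling factor (0.18215) so the latents can be decoded by the VAE engine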
+ latents = 1.0 / 0.18215 * latents
+ return latents
+
+ def __decode_latent(self, latents):
+ images = runEngine(self.engine["vae"], {"latent": device_view(latents)}, self.stream)["images"]
+ images = (images / 2 + 0.5).clamp(0, 1)
+ return images.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ def __loadResources(self, image_height, image_width, batch_size):
+ self.stream = cuda.Stream()
+
+ # Allocate buffers for TensorRT engine bindings
+ for model_name, obj in self.models.items():
+ self.engine[model_name].allocate_buffers(
+ shape_dict=obj.get_shape_dict(batch_size, image_height, image_width), device=self.torch_device
+ )
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+
+ """
+ self.generator = generator
+ self.denoising_steps = num_inference_steps
+ self.guidance_scale = guidance_scale
+
+ # Pre-compute latent input scales and linear multistep coefficients
+ self.scheduler.set_timesteps(self.denoising_steps, device=self.torch_device)
+
+ # Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ prompt = [prompt]
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"Expected prompt to be of type list or str but got {type(prompt)}")
+
+ if negative_prompt is None:
+ negative_prompt = [""] * batch_size
+
+ if negative_prompt is not None and isinstance(negative_prompt, str):
+ negative_prompt = [negative_prompt]
+
+ assert len(prompt) == len(negative_prompt)
+
+ if batch_size > self.max_batch_size:
+ raise ValueError(
+ f"Batch size {len(prompt)} is larger than allowed {self.max_batch_size}. If dynamic shape is used, then maximum batch size is 4"
+ )
+
+ # load resources
+ self.__loadResources(self.image_height, self.image_width, batch_size)
+
+ with torch.inference_mode(), torch.autocast("cuda"), trt.Runtime(TRT_LOGGER):
+ # CLIP text encoder
+ text_embeddings = self.__encode_prompt(prompt, negative_prompt)
+
+ # Pre-initialize latents
+ num_channels_latents = self.unet.in_channels
+ latents = self.prepare_latents(
+ batch_size,
+ num_channels_latents,
+ self.image_height,
+ self.image_width,
+ torch.float32,
+ self.torch_device,
+ generator,
+ )
+
+ # UNet denoiser
+ latents = self.__denoise_latent(latents, text_embeddings)
+
+ # VAE decode latent
+ images = self.__decode_latent(latents)
+
+ images, has_nsfw_concept = self.run_safety_checker(images, self.torch_device, text_embeddings.dtype)
+ images = self.numpy_to_pil(images)
+ return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/examples/community/stable_diffusion_xl_reference.py b/diffusers/examples/community/stable_diffusion_xl_reference.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d2b1c7711284704c051cb9cdda8706e3192c2d1
--- /dev/null
+++ b/diffusers/examples/community/stable_diffusion_xl_reference.py
@@ -0,0 +1,807 @@
+# Based on stable_diffusion_reference.py
+
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+
+from diffusers import StableDiffusionXLPipeline
+from diffusers.models.attention import BasicTransformerBlock
+from diffusers.models.unet_2d_blocks import (
+ CrossAttnDownBlock2D,
+ CrossAttnUpBlock2D,
+ DownBlock2D,
+ UpBlock2D,
+)
+from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
+from diffusers.utils import PIL_INTERPOLATION, logging
+from diffusers.utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import UniPCMultistepScheduler
+ >>> from diffusers.utils import load_image
+
+ >>> input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")
+
+ >>> pipe = StableDiffusionXLReferencePipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16,
+ use_safetensors=True,
+ variant="fp16").to('cuda:0')
+
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+ >>> result_img = pipe(ref_image=input_image,
+ prompt="1girl",
+ num_inference_steps=20,
+ reference_attn=True,
+ reference_adain=True).images[0]
+
+ >>> result_img.show()
+ ```
+"""
+
+
+def torch_dfs(model: torch.nn.Module):
+ result = [model]
+ for child in model.children():
+ result += torch_dfs(child)
+ return result
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+
+
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
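+
+ Concretely: `noise_cfg <- guidance_rescale * noise_cfg * std(noise_pred_text) / std(noise_cfg)
+ + (1 - guidance_rescale) * noise_cfg`, with the standard deviations taken per sample over all non-batch dimensions.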
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
+ def _default_height_width(self, height, width, image):
+ # NOTE: It is possible that a list of images have different
+ # dimensions for each image, so just checking the first image
+ # is not _exactly_ correct, but it is simple.
+ while isinstance(image, list):
+ image = image[0]
+
+ if height is None:
+ if isinstance(image, PIL.Image.Image):
+ height = image.height
+ elif isinstance(image, torch.Tensor):
+ height = image.shape[2]
+
+ height = (height // 8) * 8 # round down to nearest multiple of 8
+
+ if width is None:
+ if isinstance(image, PIL.Image.Image):
+ width = image.width
+ elif isinstance(image, torch.Tensor):
+ width = image.shape[3]
+
+ width = (width // 8) * 8
+
+ return height, width
+
+ def prepare_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ if not isinstance(image, torch.Tensor):
+ if isinstance(image, PIL.Image.Image):
+ image = [image]
+
+ if isinstance(image[0], PIL.Image.Image):
+ images = []
+
+ for image_ in image:
+ image_ = image_.convert("RGB")
+ image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
+ image_ = np.array(image_)
+ image_ = image_[None, :]
+ images.append(image_)
+
+ image = images
+
+ image = np.concatenate(image, axis=0)
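+ # normalize pixel values to [-1, 1], the range the VAE encoder expects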
+ image = np.array(image).astype(np.float32) / 255.0
+ image = (image - 0.5) / 0.5
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+
+ elif isinstance(image[0], torch.Tensor):
+ image = torch.stack(image, dim=0)
+
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ image = torch.cat([image] * 2)
+
+ return image
+
+ def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):
+ refimage = refimage.to(device=device)
+ if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
+ self.upcast_vae()
+ refimage = refimage.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+ if refimage.dtype != self.vae.dtype:
+ refimage = refimage.to(dtype=self.vae.dtype)
+ # encode the reference image into latent space so we can concatenate it to the latents
+ if isinstance(generator, list):
+ ref_image_latents = [
+ self.vae.encode(refimage[i : i + 1]).latent_dist.sample(generator=generator[i])
+ for i in range(batch_size)
+ ]
+ ref_image_latents = torch.cat(ref_image_latents, dim=0)
+ else:
+ ref_image_latents = self.vae.encode(refimage).latent_dist.sample(generator=generator)
+ ref_image_latents = self.vae.config.scaling_factor * ref_image_latents
+
+ # duplicate ref_image_latents for each generation per prompt, using mps friendly method
+ if ref_image_latents.shape[0] < batch_size:
+ if not batch_size % ref_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {ref_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ ref_image_latents = ref_image_latents.repeat(batch_size // ref_image_latents.shape[0], 1, 1, 1)
+
+ ref_image_latents = torch.cat([ref_image_latents] * 2) if do_classifier_free_guidance else ref_image_latents
+
+ # align device and dtype to prevent errors when concatenating with the latent model input
+ ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
+ return ref_image_latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ ref_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Optional[Tuple[int, int]] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Optional[Tuple[int, int]] = None,
+ attention_auto_machine_weight: float = 1.0,
+ gn_auto_machine_weight: float = 1.0,
+ style_fidelity: float = 0.5,
+ reference_attn: bool = True,
+ reference_adain: bool = True,
+ ):
+ assert reference_attn or reference_adain, "`reference_attn` or `reference_adain` must be True."
+
+ # 0. Default height and width to unet
+ # height, width = self._default_height_width(height, width, ref_image)
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ callback_steps,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+ # 4. Preprocess reference image
+ ref_image = self.prepare_image(
+ image=ref_image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=prompt_embeds.dtype,
+ )
+
+ # 5. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+
+ timesteps = self.scheduler.timesteps
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+ # 7. Prepare reference latent variables
+ ref_image_latents = self.prepare_ref_latents(
+ ref_image,
+ batch_size * num_images_per_prompt,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ do_classifier_free_guidance,
+ )
+
+ # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 9. Modify self-attention and group norm
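+ # A "write" pass on the noised reference latents caches self-attention hidden states (reference_attn) and
+ # GroupNorm feature statistics (reference_adain) inside the patched blocks below; the subsequent "read"
+ # pass, i.e. the actual denoising step, re-uses those cached values to steer generation.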
+ MODE = "write"
+ uc_mask = (
+ torch.Tensor([1] * batch_size * num_images_per_prompt + [0] * batch_size * num_images_per_prompt)
+ .type_as(ref_image_latents)
+ .bool()
+ )
+
+ def hacked_basic_transformer_inner_forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ ):
+ if self.use_ada_layer_norm:
+ norm_hidden_states = self.norm1(hidden_states, timestep)
+ elif self.use_ada_layer_norm_zero:
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ else:
+ norm_hidden_states = self.norm1(hidden_states)
+
+ # 1. Self-Attention
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+ if self.only_cross_attention:
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ else:
+ if MODE == "write":
+ self.bank.append(norm_hidden_states.detach().clone())
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ if MODE == "read":
+ if attention_auto_machine_weight > self.attn_weight:
+ attn_output_uc = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=torch.cat([norm_hidden_states] + self.bank, dim=1),
+ # attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ attn_output_c = attn_output_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ attn_output_c[uc_mask] = self.attn1(
+ norm_hidden_states[uc_mask],
+ encoder_hidden_states=norm_hidden_states[uc_mask],
+ **cross_attention_kwargs,
+ )
+ attn_output = style_fidelity * attn_output_c + (1.0 - style_fidelity) * attn_output_uc
+ self.bank.clear()
+ else:
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ if self.use_ada_layer_norm_zero:
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ hidden_states = attn_output + hidden_states
+
+ if self.attn2 is not None:
+ norm_hidden_states = (
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+ )
+
+ # 2. Cross-Attention
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ **cross_attention_kwargs,
+ )
+ hidden_states = attn_output + hidden_states
+
+ # 3. Feed-forward
+ norm_hidden_states = self.norm3(hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ ff_output = self.ff(norm_hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+ hidden_states = ff_output + hidden_states
+
+ return hidden_states
+
+ def hacked_mid_forward(self, *args, **kwargs):
+ eps = 1e-6
+ x = self.original_forward(*args, **kwargs)
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append(mean)
+ self.var_bank.append(var)
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank) / float(len(self.mean_bank))
+ var_acc = sum(self.var_bank) / float(len(self.var_bank))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ x_uc = (((x - mean) / std) * std_acc) + mean_acc
+ x_c = x_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ x_c[uc_mask] = x[uc_mask]
+ x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc
+ self.mean_bank = []
+ self.var_bank = []
+ return x
+
+ def hack_CrossAttnDownBlock2D_forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ):
+ eps = 1e-6
+
+ # TODO(Patrick, William) - attention mask is not used
+ output_states = ()
+
+ for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ output_states = output_states + (hidden_states,)
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+ def hacked_DownBlock2D_forward(self, hidden_states, temb=None):
+ eps = 1e-6
+
+ output_states = ()
+
+ for i, resnet in enumerate(self.resnets):
+ hidden_states = resnet(hidden_states, temb)
+
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ output_states = output_states + (hidden_states,)
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+ def hacked_CrossAttnUpBlock2D_forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ upsample_size: Optional[int] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ):
+ eps = 1e-6
+ # TODO(Patrick, William) - attention mask is not used
+ for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+ def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
+ eps = 1e-6
+ for i, resnet in enumerate(self.resnets):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+ hidden_states = resnet(hidden_states, temb)
+
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+ if reference_attn:
+ attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock)]
+ attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
+
+ for i, module in enumerate(attn_modules):
+ module._original_inner_forward = module.forward
+ module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
+ module.bank = []
+ module.attn_weight = float(i) / float(len(attn_modules))
+
+ if reference_adain:
+ gn_modules = [self.unet.mid_block]
+ self.unet.mid_block.gn_weight = 0
+
+ down_blocks = self.unet.down_blocks
+ for w, module in enumerate(down_blocks):
+ module.gn_weight = 1.0 - float(w) / float(len(down_blocks))
+ gn_modules.append(module)
+
+ up_blocks = self.unet.up_blocks
+ for w, module in enumerate(up_blocks):
+ module.gn_weight = float(w) / float(len(up_blocks))
+ gn_modules.append(module)
+
+ for i, module in enumerate(gn_modules):
+ if getattr(module, "original_forward", None) is None:
+ module.original_forward = module.forward
+ if i == 0:
+ # mid_block
+ module.forward = hacked_mid_forward.__get__(module, torch.nn.Module)
+ elif isinstance(module, CrossAttnDownBlock2D):
+ module.forward = hack_CrossAttnDownBlock2D_forward.__get__(module, CrossAttnDownBlock2D)
+ elif isinstance(module, DownBlock2D):
+ module.forward = hacked_DownBlock2D_forward.__get__(module, DownBlock2D)
+ elif isinstance(module, CrossAttnUpBlock2D):
+ module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D)
+ elif isinstance(module, UpBlock2D):
+ module.forward = hacked_UpBlock2D_forward.__get__(module, UpBlock2D)
+ module.mean_bank = []
+ module.var_bank = []
+ module.gn_weight *= 2
+
+ # 10. Prepare added time ids & embeddings
+ add_text_embeds = pooled_prompt_embeds
+ add_time_ids = self._get_add_time_ids(
+ original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
+ )
+
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+ # 11. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ # 11.1 Apply denoising_end
+ if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+
+ # ref only part
+ noise = randn_tensor(
+ ref_image_latents.shape, generator=generator, device=device, dtype=ref_image_latents.dtype
+ )
+ ref_xt = self.scheduler.add_noise(
+ ref_image_latents,
+ noise,
+ t.reshape(
+ 1,
+ ),
+ )
+ ref_xt = self.scheduler.scale_model_input(ref_xt, t)
+
+ MODE = "write"
+
+ self.unet(
+ ref_xt,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )
+
+ # predict the noise residual
+ MODE = "read"
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+ else:
+ image = latents
+ return StableDiffusionXLPipelineOutput(images=image)
+
+ # apply watermark if available
+ if self.watermark is not None:
+ image = self.watermark.apply_watermark(image)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)
diff --git a/diffusers/examples/community/stable_unclip.py b/diffusers/examples/community/stable_unclip.py
new file mode 100644
index 0000000000000000000000000000000000000000..6acca20d6a78e5c76c80bc150ae48b3fcc7b0f71
--- /dev/null
+++ b/diffusers/examples/community/stable_unclip.py
@@ -0,0 +1,288 @@
+import types
+from typing import List, Optional, Tuple, Union
+
+import torch
+from transformers import CLIPTextModelWithProjection, CLIPTokenizer
+from transformers.models.clip.modeling_clip import CLIPTextModelOutput
+
+from diffusers.models import PriorTransformer
+from diffusers.pipelines import DiffusionPipeline, StableDiffusionImageVariationPipeline
+from diffusers.schedulers import UnCLIPScheduler
+from diffusers.utils import logging
+from diffusers.utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):
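+ # Unlike the stock image-variation pipeline (which runs a CLIP vision encoder here), this override treats
+ # `image` as a precomputed CLIP image embedding produced by the prior and only reshapes/duplicates it.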
+ image = image.to(device=device)
+ image_embeddings = image # take image as image_embeddings
+ image_embeddings = image_embeddings.unsqueeze(1)
+
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = image_embeddings.shape
+ image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
+ image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ uncond_embeddings = torch.zeros_like(image_embeddings)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ image_embeddings = torch.cat([uncond_embeddings, image_embeddings])
+
+ return image_embeddings
+
+
+class StableUnCLIPPipeline(DiffusionPipeline):
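+ r"""
+ Two-stage text-to-image pipeline: a `PriorTransformer` first maps the CLIP text embedding of the prompt to a
+ CLIP image embedding, which is then turned into an image by a `StableDiffusionImageVariationPipeline`
+ ("lambdalabs/sd-image-variations-diffusers") whose `_encode_image` is replaced in `__init__` so that it
+ consumes the predicted embedding directly.
+
+ Example (a minimal sketch; it assumes the pipeline is loaded as a community pipeline from a checkpoint that
+ exposes `prior`, `tokenizer`, `text_encoder`, and `prior_scheduler` components, and the checkpoint name below
+ is only an illustration):
+
+ ```py
+ >>> import torch
+ >>> from diffusers import DiffusionPipeline
+
+ >>> pipe = DiffusionPipeline.from_pretrained(
+ ... "kakaobrain/karlo-v1-alpha", # hypothetical prior checkpoint
+ ... torch_dtype=torch.float16,
+ ... custom_pipeline="stable_unclip",
+ ... decoder_pipe_kwargs={"image_encoder": None},
+ ... ).to("cuda")
+ >>> image = pipe("a shiba inu wearing a beret and black turtleneck").images[0]
+ ```
+ """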
+ def __init__(
+ self,
+ prior: PriorTransformer,
+ tokenizer: CLIPTokenizer,
+ text_encoder: CLIPTextModelWithProjection,
+ prior_scheduler: UnCLIPScheduler,
+ decoder_pipe_kwargs: Optional[dict] = None,
+ ):
+ super().__init__()
+
+ decoder_pipe_kwargs = {"image_encoder": None} if decoder_pipe_kwargs is None else decoder_pipe_kwargs
+
+ decoder_pipe_kwargs["torch_dtype"] = decoder_pipe_kwargs.get("torch_dtype", None) or prior.dtype
+
+ self.decoder_pipe = StableDiffusionImageVariationPipeline.from_pretrained(
+ "lambdalabs/sd-image-variations-diffusers", **decoder_pipe_kwargs
+ )
+
+ # replace `_encode_image` method
+ self.decoder_pipe._encode_image = types.MethodType(_encode_image, self.decoder_pipe)
+
+ self.register_modules(
+ prior=prior,
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ prior_scheduler=prior_scheduler,
+ )
+
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
+ text_attention_mask: Optional[torch.Tensor] = None,
+ ):
+ if text_model_output is None:
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ text_mask = text_inputs.attention_mask.bool().to(device)
+
+ if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+ removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+ text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+
+ text_encoder_output = self.text_encoder(text_input_ids.to(device))
+
+ text_embeddings = text_encoder_output.text_embeds
+ text_encoder_hidden_states = text_encoder_output.last_hidden_state
+
+ else:
+ batch_size = text_model_output[0].shape[0]
+ text_embeddings, text_encoder_hidden_states = text_model_output[0], text_model_output[1]
+ text_mask = text_attention_mask
+
+ text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+ text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
+
+ if do_classifier_free_guidance:
+ uncond_tokens = [""] * batch_size
+
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ uncond_text_mask = uncond_input.attention_mask.bool().to(device)
+ uncond_embeddings_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
+
+ uncond_embeddings = uncond_embeddings_text_encoder_output.text_embeds
+ uncond_text_encoder_hidden_states = uncond_embeddings_text_encoder_output.last_hidden_state
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt)
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len)
+
+ seq_len = uncond_text_encoder_hidden_states.shape[1]
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
+ batch_size * num_images_per_prompt, seq_len, -1
+ )
+ uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
+
+ # done duplicates
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+ text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
+
+ text_mask = torch.cat([uncond_text_mask, text_mask])
+
+ return text_embeddings, text_encoder_hidden_states, text_mask
+
+ @property
+ def _execution_device(self):
+ r"""
+ Returns the device on which the pipeline's models will be executed. After calling
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+ hooks.
+ """
+ if self.device != torch.device("meta") or not hasattr(self.prior, "_hf_hook"):
+ return self.device
+ for module in self.prior.modules():
+ if (
+ hasattr(module, "_hf_hook")
+ and hasattr(module._hf_hook, "execution_device")
+ and module._hf_hook.execution_device is not None
+ ):
+ return torch.device(module._hf_hook.execution_device)
+ return self.device
+
+ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ latents = latents * scheduler.init_noise_sigma
+ return latents
+
+ def to(self, torch_device: Optional[Union[str, torch.device]] = None):
+ self.decoder_pipe.to(torch_device)
+ super().to(torch_device)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_images_per_prompt: int = 1,
+ prior_num_inference_steps: int = 25,
+ generator: Optional[torch.Generator] = None,
+ prior_latents: Optional[torch.FloatTensor] = None,
+ text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
+ text_attention_mask: Optional[torch.Tensor] = None,
+ prior_guidance_scale: float = 4.0,
+ decoder_guidance_scale: float = 8.0,
+ decoder_num_inference_steps: int = 50,
+ decoder_num_images_per_prompt: Optional[int] = 1,
+ decoder_eta: float = 0.0,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ ):
+ if prompt is not None:
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ else:
+ batch_size = text_model_output[0].shape[0]
+
+ device = self._execution_device
+
+ batch_size = batch_size * num_images_per_prompt
+
+ do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0
+
+ text_embeddings, text_encoder_hidden_states, text_mask = self._encode_prompt(
+ prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask
+ )
+
+ # prior
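+ # The prior iteratively denoises a latent of size `embedding_dim` into a CLIP image embedding
+ # conditioned on the text embeddings.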
+
+ self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device)
+ prior_timesteps_tensor = self.prior_scheduler.timesteps
+
+ embedding_dim = self.prior.config.embedding_dim
+
+ prior_latents = self.prepare_latents(
+ (batch_size, embedding_dim),
+ text_embeddings.dtype,
+ device,
+ generator,
+ prior_latents,
+ self.prior_scheduler,
+ )
+
+ for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([prior_latents] * 2) if do_classifier_free_guidance else prior_latents
+
+ predicted_image_embedding = self.prior(
+ latent_model_input,
+ timestep=t,
+ proj_embedding=text_embeddings,
+ encoder_hidden_states=text_encoder_hidden_states,
+ attention_mask=text_mask,
+ ).predicted_image_embedding
+
+ if do_classifier_free_guidance:
+ predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)
+ predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * (
+ predicted_image_embedding_text - predicted_image_embedding_uncond
+ )
+
+ if i + 1 == prior_timesteps_tensor.shape[0]:
+ prev_timestep = None
+ else:
+ prev_timestep = prior_timesteps_tensor[i + 1]
+
+ prior_latents = self.prior_scheduler.step(
+ predicted_image_embedding,
+ timestep=t,
+ sample=prior_latents,
+ generator=generator,
+ prev_timestep=prev_timestep,
+ ).prev_sample
+
+ prior_latents = self.prior.post_process_latents(prior_latents)
+
+ image_embeddings = prior_latents
+
+ output = self.decoder_pipe(
+ image=image_embeddings,
+ height=height,
+ width=width,
+ num_inference_steps=decoder_num_inference_steps,
+ guidance_scale=decoder_guidance_scale,
+ generator=generator,
+ output_type=output_type,
+ return_dict=return_dict,
+ num_images_per_prompt=decoder_num_images_per_prompt,
+ eta=decoder_eta,
+ )
+ return output
diff --git a/diffusers/examples/community/text_inpainting.py b/diffusers/examples/community/text_inpainting.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd02049a4afb06af9eecedf866e1b387c3cb62be
--- /dev/null
+++ b/diffusers/examples/community/text_inpainting.py
@@ -0,0 +1,302 @@
+from typing import Callable, List, Optional, Union
+
+import PIL.Image
+import torch
+from transformers import (
+ CLIPImageProcessor,
+ CLIPSegForImageSegmentation,
+ CLIPSegProcessor,
+ CLIPTextModel,
+ CLIPTokenizer,
+)
+
+from diffusers import DiffusionPipeline
+from diffusers.configuration_utils import FrozenDict
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from diffusers.utils import deprecate, is_accelerate_available, logging
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class TextInpainting(DiffusionPipeline):
+ r"""
+ Pipeline for text based inpainting using Stable Diffusion.
+ Uses CLIPSeg to get a mask from the given text, then calls the Inpainting pipeline with the generated mask
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ segmentation_model ([`CLIPSegForImageSegmentation`]):
+ CLIPSeg model to generate a mask from the given text. Please refer to the [model card](https://huggingface.co/docs/transformers/model_doc/clipseg) for details.
+ segmentation_processor ([`CLIPSegProcessor`]):
+ CLIPSeg processor to prepare the image and text inputs for the segmentation model. Please refer to the
+ [model card](https://huggingface.co/docs/transformers/model_doc/clipseg) for details.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ def __init__(
+ self,
+ segmentation_model: CLIPSegForImageSegmentation,
+ segmentation_processor: CLIPSegProcessor,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration"
+ " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
+ " sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to"
+ " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face"
+ " Hub, it would be very nice if you could open a Pull request for the"
+ " `scheduler/scheduler_config.json` file"
+ )
+ deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["skip_prk_steps"] = True
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ self.register_modules(
+ segmentation_model=segmentation_model,
+ segmentation_processor=segmentation_processor,
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+ `attention_head_dim` must be a multiple of `slice_size`.
+ """
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = self.unet.config.attention_head_dim // 2
+ self.unet.set_attention_slice(slice_size)
+
+ def disable_attention_slicing(self):
+ r"""
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+ back to computing attention in one step.
+ """
+ # set slice_size = `None` to disable `attention slicing`
+ self.enable_attention_slicing(None)
+
+ def enable_sequential_cpu_offload(self):
+ r"""
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+ `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
+ """
+ if is_accelerate_available():
+ from accelerate import cpu_offload
+ else:
+ raise ImportError("Please install accelerate via `pip install accelerate`")
+
+ device = torch.device("cuda")
+
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
+ if cpu_offloaded_model is not None:
+ cpu_offload(cpu_offloaded_model, device)
+
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
+ def _execution_device(self):
+ r"""
+ Returns the device on which the pipeline's models will be executed. After calling
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+ hooks.
+ """
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
+ return self.device
+ for module in self.unet.modules():
+ if (
+ hasattr(module, "_hf_hook")
+ and hasattr(module._hf_hook, "execution_device")
+ and module._hf_hook.execution_device is not None
+ ):
+ return torch.device(module._hf_hook.execution_device)
+ return self.device
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ image: Union[torch.FloatTensor, PIL.Image.Image],
+ text: str,
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
+ be masked out with the mask generated from `text` and repainted according to `prompt`.
+ text (`str`):
+ The text to use to generate the mask.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+
+ # We use the input text to generate the mask
+ inputs = self.segmentation_processor(
+ text=[text], images=[image], padding="max_length", return_tensors="pt"
+ ).to(self.device)
+ outputs = self.segmentation_model(**inputs)
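+ # CLIPSeg returns per-pixel logits; the sigmoid below turns them into a soft [0, 1] mask, which is converted to PIL and resized to the input image size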
+ mask = torch.sigmoid(outputs.logits).cpu().detach().unsqueeze(-1).numpy()
+ mask_pil = self.numpy_to_pil(mask)[0].resize(image.size)
+
+ # Run inpainting pipeline with the generated mask
+ inpainting_pipeline = StableDiffusionInpaintPipeline(
+ vae=self.vae,
+ text_encoder=self.text_encoder,
+ tokenizer=self.tokenizer,
+ unet=self.unet,
+ scheduler=self.scheduler,
+ safety_checker=self.safety_checker,
+ feature_extractor=self.feature_extractor,
+ )
+ return inpainting_pipeline(
+ prompt=prompt,
+ image=image,
+ mask_image=mask_pil,
+ height=height,
+ width=width,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ latents=latents,
+ output_type=output_type,
+ return_dict=return_dict,
+ callback=callback,
+ callback_steps=callback_steps,
+ )
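A minimal usage sketch for the `TextInpainting` community pipeline above. Community pipelines are loaded through `DiffusionPipeline.from_pretrained` with the `custom_pipeline` argument; the checkpoint names and the input file below are illustrative assumptions, not part of this diff:

```python
import torch
from PIL import Image
from transformers import CLIPSegForImageSegmentation, CLIPSegProcessor
from diffusers import DiffusionPipeline

# Assumed checkpoints: a CLIPSeg segmentation model and a Stable Diffusion inpainting model.
seg_model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
seg_processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    custom_pipeline="text_inpainting",
    segmentation_model=seg_model,
    segmentation_processor=seg_processor,
)
pipe = pipe.to("cuda")

image = Image.open("room.png").convert("RGB").resize((512, 512))  # placeholder input image
result = pipe(prompt="a red leather sofa", image=image, text="the couch").images[0]
result.save("text_inpainting_result.png")
```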
diff --git a/diffusers/examples/community/tiled_upscaling.py b/diffusers/examples/community/tiled_upscaling.py
new file mode 100644
index 0000000000000000000000000000000000000000..81f4f9e4c626c1ebf820686516d64ba2c2fc5a93
--- /dev/null
+++ b/diffusers/examples/community/tiled_upscaling.py
@@ -0,0 +1,298 @@
+# Copyright 2023 Peter Willemsen. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from PIL import Image
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline
+from diffusers.schedulers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler
+
+
+def make_transparency_mask(size, overlap_pixels, remove_borders=[]):
+ size_x = size[0] - overlap_pixels * 2
+ size_y = size[1] - overlap_pixels * 2
+ for letter in ["l", "r"]:
+ if letter in remove_borders:
+ size_x += overlap_pixels
+ for letter in ["t", "b"]:
+ if letter in remove_borders:
+ size_y += overlap_pixels
+ mask = np.ones((size_y, size_x), dtype=np.uint8) * 255
+ mask = np.pad(mask, mode="linear_ramp", pad_width=overlap_pixels, end_values=0)
+
+ if "l" in remove_borders:
+ mask = mask[:, overlap_pixels : mask.shape[1]]
+ if "r" in remove_borders:
+ mask = mask[:, 0 : mask.shape[1] - overlap_pixels]
+ if "t" in remove_borders:
+ mask = mask[overlap_pixels : mask.shape[0], :]
+ if "b" in remove_borders:
+ mask = mask[0 : mask.shape[0] - overlap_pixels, :]
+ return mask
+
+
+def clamp(n, smallest, largest):
+ return max(smallest, min(n, largest))
+
+
+def clamp_rect(rect: List[int], min_coords: List[int], max_coords: List[int]):
+ return (
+ clamp(rect[0], min_coords[0], max_coords[0]),
+ clamp(rect[1], min_coords[1], max_coords[1]),
+ clamp(rect[2], min_coords[0], max_coords[0]),
+ clamp(rect[3], min_coords[1], max_coords[1]),
+ )
+
+
+def add_overlap_rect(rect: List[int], overlap: int, image_size: List[int]):
+ rect = list(rect)
+ rect[0] -= overlap
+ rect[1] -= overlap
+ rect[2] += overlap
+ rect[3] += overlap
+ rect = clamp_rect(rect, [0, 0], [image_size[0], image_size[1]])
+ return rect
+
+
+def squeeze_tile(tile, original_image, original_slice, slice_x):
+ result = Image.new("RGB", (tile.size[0] + original_slice, tile.size[1]))
+ result.paste(
+ original_image.resize((tile.size[0], tile.size[1]), Image.BICUBIC).crop(
+ (slice_x, 0, slice_x + original_slice, tile.size[1])
+ ),
+ (0, 0),
+ )
+ result.paste(tile, (original_slice, 0))
+ return result
+
+
+def unsqueeze_tile(tile, original_image_slice):
+ crop_rect = (original_image_slice * 4, 0, tile.size[0], tile.size[1])
+ tile = tile.crop(crop_rect)
+ return tile
+
+
+def next_divisible(n, d):
+ divisor = n % d
+ return n - divisor
+
+
+class StableDiffusionTiledUpscalePipeline(StableDiffusionUpscalePipeline):
+ r"""
+ Pipeline for tile-based text-guided image super-resolution using Stable Diffusion 2, trading memory for compute
+ to create gigantic images.
+
+ This model inherits from [`StableDiffusionUpscalePipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ low_res_scheduler ([`SchedulerMixin`]):
+ A scheduler used to add initial noise to the low res conditioning image. It must be an instance of
+ [`DDPMScheduler`].
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ low_res_scheduler: DDPMScheduler,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ max_noise_level: int = 350,
+ ):
+ super().__init__(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ low_res_scheduler=low_res_scheduler,
+ scheduler=scheduler,
+ max_noise_level=max_noise_level,
+ )
+
+ def _process_tile(self, original_image_slice, x, y, tile_size, tile_border, image, final_image, **kwargs):
+ torch.manual_seed(0)
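+ # Crop this tile (clamped so edge tiles stay inside the image), grow it by the overlap border and prepend a
+ # squeezed slice of the original image for global context; the upscaled result is pasted back with a feathered mask.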
+ crop_rect = (
+ min(image.size[0] - (tile_size + original_image_slice), x * tile_size),
+ min(image.size[1] - (tile_size + original_image_slice), y * tile_size),
+ min(image.size[0], (x + 1) * tile_size),
+ min(image.size[1], (y + 1) * tile_size),
+ )
+ crop_rect_with_overlap = add_overlap_rect(crop_rect, tile_border, image.size)
+ tile = image.crop(crop_rect_with_overlap)
+ translated_slice_x = ((crop_rect[0] + ((crop_rect[2] - crop_rect[0]) / 2)) / image.size[0]) * tile.size[0]
+ translated_slice_x = translated_slice_x - (original_image_slice / 2)
+ translated_slice_x = max(0, translated_slice_x)
+ to_input = squeeze_tile(tile, image, original_image_slice, translated_slice_x)
+ orig_input_size = to_input.size
+ to_input = to_input.resize((tile_size, tile_size), Image.BICUBIC)
+ upscaled_tile = super(StableDiffusionTiledUpscalePipeline, self).__call__(image=to_input, **kwargs).images[0]
+ upscaled_tile = upscaled_tile.resize((orig_input_size[0] * 4, orig_input_size[1] * 4), Image.BICUBIC)
+ upscaled_tile = unsqueeze_tile(upscaled_tile, original_image_slice)
+ upscaled_tile = upscaled_tile.resize((tile.size[0] * 4, tile.size[1] * 4), Image.BICUBIC)
+ remove_borders = []
+ if x == 0:
+ remove_borders.append("l")
+ elif crop_rect[2] == image.size[0]:
+ remove_borders.append("r")
+ if y == 0:
+ remove_borders.append("t")
+ elif crop_rect[3] == image.size[1]:
+ remove_borders.append("b")
+ transparency_mask = Image.fromarray(
+ make_transparency_mask(
+ (upscaled_tile.size[0], upscaled_tile.size[1]), tile_border * 4, remove_borders=remove_borders
+ ),
+ mode="L",
+ )
+ final_image.paste(
+ upscaled_tile, (crop_rect_with_overlap[0] * 4, crop_rect_with_overlap[1] * 4), transparency_mask
+ )
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ image: Union[PIL.Image.Image, List[PIL.Image.Image]],
+ num_inference_steps: int = 75,
+ guidance_scale: float = 9.0,
+ noise_level: int = 50,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ tile_size: int = 128,
+ tile_border: int = 32,
+ original_image_slice: int = 32,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`):
+ `Image` or tensor representing an image batch which will be upscaled.
+ num_inference_steps (`int`, *optional*, defaults to 75):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 9.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ tile_size (`int`, *optional*):
+ The size of the tiles in pixels. Too big a value can result in an OOM error.
+ tile_border (`int`, *optional*):
+ The number of pixels around a tile to consider (bigger means fewer seams; too big can lead to an OOM error).
+ original_image_slice (`int`, *optional*):
+ The amount of pixels of the original image to calculate with the current tile (bigger means more depth
+ is preserved and less blur occurs in the final image; too big can lead to an OOM error or loss of detail).
+ callback (`Callable`, *optional*):
+ A function that will be called after each processed tile with a single argument, a dict,
+ that contains the (partially) processed image under "image",
+ as well as the progress (0 to 1, where 1 is completed) under "progress".
+
+ Returns: A PIL.Image that is 4 times larger than the original input image.
+
+ """
+
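+ # The output canvas is 4x the input size; process the image tile by tile and report progress through the callback.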
+ final_image = Image.new("RGB", (image.size[0] * 4, image.size[1] * 4))
+ tcx = math.ceil(image.size[0] / tile_size)
+ tcy = math.ceil(image.size[1] / tile_size)
+ total_tile_count = tcx * tcy
+ current_count = 0
+ for y in range(tcy):
+ for x in range(tcx):
+ self._process_tile(
+ original_image_slice,
+ x,
+ y,
+ tile_size,
+ tile_border,
+ image,
+ final_image,
+ prompt=prompt,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ noise_level=noise_level,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ eta=eta,
+ generator=generator,
+ latents=latents,
+ )
+ current_count += 1
+ if callback is not None:
+ callback({"progress": current_count / total_tile_count, "image": final_image})
+ return final_image
+
+
+def main():
+ # Run a demo
+ model_id = "stabilityai/stable-diffusion-x4-upscaler"
+ pipe = StableDiffusionTiledUpscalePipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
+ pipe = pipe.to("cuda")
+ image = Image.open("../../docs/source/imgs/diffusers_library.jpg")
+
+ def callback(obj):
+ print(f"progress: {obj['progress']:.4f}")
+ obj["image"].save("diffusers_library_progress.jpg")
+
+ final_image = pipe(image=image, prompt="Black font, white background, vector", noise_level=40, callback=callback)
+ final_image.save("diffusers_library.jpg")
+
+
+if __name__ == "__main__":
+ main()
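The tiles are blended by pasting each upscaled tile with an alpha mask whose overlap border ramps linearly down to zero, so adjacent tiles cross-fade instead of leaving hard seams. A small sketch of that mask construction (the same `np.pad` linear-ramp trick used by `make_transparency_mask` above, shown here with no removed borders):

```python
import numpy as np

overlap = 8
# Solid 255 interior, then an 8-pixel linear ramp down to 0 on every side.
interior = np.ones((48, 48), dtype=np.uint8) * 255
mask = np.pad(interior, mode="linear_ramp", pad_width=overlap, end_values=0)

print(mask.shape)                # (64, 64): interior plus the ramped overlap border
print(mask[32, 32], mask[0, 0])  # 255 in the interior, 0 at the outer edge of the ramp
```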
diff --git a/diffusers/examples/community/unclip_image_interpolation.py b/diffusers/examples/community/unclip_image_interpolation.py
new file mode 100644
index 0000000000000000000000000000000000000000..95548b152c0702d7474f38df0989b992f49848de
--- /dev/null
+++ b/diffusers/examples/community/unclip_image_interpolation.py
@@ -0,0 +1,496 @@
+import inspect
+from typing import List, Optional, Union
+
+import PIL.Image
+import torch
+from torch.nn import functional as F
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModelWithProjection,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+)
+
+from diffusers import (
+ DiffusionPipeline,
+ ImagePipelineOutput,
+ UnCLIPScheduler,
+ UNet2DConditionModel,
+ UNet2DModel,
+)
+from diffusers.pipelines.unclip import UnCLIPTextProjModel
+from diffusers.utils import is_accelerate_available, logging
+from diffusers.utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def slerp(val, low, high):
+ """
+ Find the interpolation point between the 'low' and 'high' values for the given 'val'. See https://en.wikipedia.org/wiki/Slerp for more details on the topic.
+ """
+ low_norm = low / torch.norm(low)
+ high_norm = high / torch.norm(high)
+ omega = torch.acos((low_norm * high_norm))
+ so = torch.sin(omega)
+ res = (torch.sin((1.0 - val) * omega) / so) * low + (torch.sin(val * omega) / so) * high
+ return res
+
+
+class UnCLIPImageInterpolationPipeline(DiffusionPipeline):
+ """
+ Pipeline to generate interpolations between two input images using unCLIP
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ text_encoder ([`CLIPTextModelWithProjection`]):
+ Frozen text-encoder.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `image_encoder`.
+ image_encoder ([`CLIPVisionModelWithProjection`]):
+ Frozen CLIP image-encoder. unCLIP Image Variation uses the vision portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection),
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_proj ([`UnCLIPTextProjModel`]):
+ Utility class to prepare and combine the embeddings before they are passed to the decoder.
+ decoder ([`UNet2DConditionModel`]):
+ The decoder to invert the image embedding into an image.
+ super_res_first ([`UNet2DModel`]):
+ Super resolution unet. Used in all but the last step of the super resolution diffusion process.
+ super_res_last ([`UNet2DModel`]):
+ Super resolution unet. Used in the last step of the super resolution diffusion process.
+ decoder_scheduler ([`UnCLIPScheduler`]):
+ Scheduler used in the decoder denoising process. Just a modified DDPMScheduler.
+ super_res_scheduler ([`UnCLIPScheduler`]):
+ Scheduler used in the super resolution denoising process. Just a modified DDPMScheduler.
+
+ """
+
+ decoder: UNet2DConditionModel
+ text_proj: UnCLIPTextProjModel
+ text_encoder: CLIPTextModelWithProjection
+ tokenizer: CLIPTokenizer
+ feature_extractor: CLIPImageProcessor
+ image_encoder: CLIPVisionModelWithProjection
+ super_res_first: UNet2DModel
+ super_res_last: UNet2DModel
+
+ decoder_scheduler: UnCLIPScheduler
+ super_res_scheduler: UnCLIPScheduler
+
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip_image_variation.UnCLIPImageVariationPipeline.__init__
+ def __init__(
+ self,
+ decoder: UNet2DConditionModel,
+ text_encoder: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ text_proj: UnCLIPTextProjModel,
+ feature_extractor: CLIPImageProcessor,
+ image_encoder: CLIPVisionModelWithProjection,
+ super_res_first: UNet2DModel,
+ super_res_last: UNet2DModel,
+ decoder_scheduler: UnCLIPScheduler,
+ super_res_scheduler: UnCLIPScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ decoder=decoder,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ text_proj=text_proj,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ super_res_first=super_res_first,
+ super_res_last=super_res_last,
+ decoder_scheduler=decoder_scheduler,
+ super_res_scheduler=super_res_scheduler,
+ )
+
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
+ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ latents = latents * scheduler.init_noise_sigma
+ return latents
+
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip_image_variation.UnCLIPImageVariationPipeline._encode_prompt
+ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ text_mask = text_inputs.attention_mask.bool().to(device)
+ text_encoder_output = self.text_encoder(text_input_ids.to(device))
+
+ prompt_embeds = text_encoder_output.text_embeds
+ text_encoder_hidden_states = text_encoder_output.last_hidden_state
+
+ prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
+
+ if do_classifier_free_guidance:
+ uncond_tokens = [""] * batch_size
+
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ uncond_text_mask = uncond_input.attention_mask.bool().to(device)
+ negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
+
+ negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
+ uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+
+ seq_len = negative_prompt_embeds.shape[1]
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)
+
+ seq_len = uncond_text_encoder_hidden_states.shape[1]
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
+ batch_size * num_images_per_prompt, seq_len, -1
+ )
+ uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
+
+ # done duplicates
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+ text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
+
+ text_mask = torch.cat([uncond_text_mask, text_mask])
+
+ return prompt_embeds, text_encoder_hidden_states, text_mask
+
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip_image_variation.UnCLIPImageVariationPipeline._encode_image
+ def _encode_image(self, image, device, num_images_per_prompt, image_embeddings: Optional[torch.Tensor] = None):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if image_embeddings is None:
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ image_embeddings = self.image_encoder(image).image_embeds
+
+ image_embeddings = image_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+
+ return image_embeddings
+
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip_image_variation.UnCLIPImageVariationPipeline.enable_sequential_cpu_offload
+ def enable_sequential_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
+ models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU only
+ when their specific submodule has its `forward` method called.
+ """
+ if is_accelerate_available():
+ from accelerate import cpu_offload
+ else:
+ raise ImportError("Please install accelerate via `pip install accelerate`")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ models = [
+ self.decoder,
+ self.text_proj,
+ self.text_encoder,
+ self.super_res_first,
+ self.super_res_last,
+ ]
+ for cpu_offloaded_model in models:
+ if cpu_offloaded_model is not None:
+ cpu_offload(cpu_offloaded_model, device)
+
+ @property
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._execution_device
+ def _execution_device(self):
+ r"""
+ Returns the device on which the pipeline's models will be executed. After calling
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+ hooks.
+ """
+ if self.device != torch.device("meta") or not hasattr(self.decoder, "_hf_hook"):
+ return self.device
+ for module in self.decoder.modules():
+ if (
+ hasattr(module, "_hf_hook")
+ and hasattr(module._hf_hook, "execution_device")
+ and module._hf_hook.execution_device is not None
+ ):
+ return torch.device(module._hf_hook.execution_device)
+ return self.device
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ image: Optional[Union[List[PIL.Image.Image], torch.FloatTensor]] = None,
+ steps: int = 5,
+ decoder_num_inference_steps: int = 25,
+ super_res_num_inference_steps: int = 7,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ image_embeddings: Optional[torch.Tensor] = None,
+ decoder_latents: Optional[torch.FloatTensor] = None,
+ super_res_latents: Optional[torch.FloatTensor] = None,
+ decoder_guidance_scale: float = 8.0,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ ):
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ image (`List[PIL.Image.Image]` or `torch.FloatTensor`):
+ The images to use for the image interpolation. Only a list of two PIL Images is accepted. If you provide a tensor, it needs to comply with the
+ configuration of
+ [this](https://huggingface.co/fusing/karlo-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json)
+ `CLIPImageProcessor` while having a size of two in the 0th dimension. Can be left as `None` only when `image_embeddings` are passed.
+ steps (`int`, *optional*, defaults to 5):
+ The number of interpolation images to generate.
+ decoder_num_inference_steps (`int`, *optional*, defaults to 25):
+ The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality
+ image at the expense of slower inference.
+ super_res_num_inference_steps (`int`, *optional*, defaults to 7):
+ The number of denoising steps for super resolution. More denoising steps usually lead to a higher
+ quality image at the expense of slower inference.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ image_embeddings (`torch.Tensor`, *optional*):
+ Pre-defined image embeddings that can be derived from the image encoder. Pre-defined image embeddings
+ can be passed for tasks like image interpolation. `image` can then be left as `None`.
+ decoder_latents (`torch.FloatTensor` of shape (batch size, channels, height, width), *optional*):
+ Pre-generated noisy latents to be used as inputs for the decoder.
+ super_res_latents (`torch.FloatTensor` of shape (batch size, channels, super res height, super res width), *optional*):
+ Pre-generated noisy latents to be used as inputs for the super resolution.
+ decoder_guidance_scale (`float`, *optional*, defaults to 8.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+ """
+
+ batch_size = steps
+
+ device = self._execution_device
+
+ if isinstance(image, List):
+ if len(image) != 2:
+ raise AssertionError(
+ f"Expected 'image' List to be of size 2, but passed 'image' length is {len(image)}"
+ )
+ elif not (isinstance(image[0], PIL.Image.Image) and isinstance(image[1], PIL.Image.Image)):
+ raise AssertionError(
+ f"Expected 'image' List to contain PIL.Image.Image, but passed 'image' contents are {type(image[0])} and {type(image[1])}"
+ )
+ elif isinstance(image, torch.FloatTensor):
+ if image.shape[0] != 2:
+ raise AssertionError(
+ f"Expected 'image' to be torch.FloatTensor of shape 2 in 0th dimension, but passed 'image' size is {image.shape[0]}"
+ )
+ elif isinstance(image_embeddings, torch.Tensor):
+ if image_embeddings.shape[0] != 2:
+ raise AssertionError(
+ f"Expected 'image_embeddings' to be torch.FloatTensor of shape 2 in 0th dimension, but passed 'image_embeddings' shape is {image_embeddings.shape[0]}"
+ )
+ else:
+ raise AssertionError(
+ f"Expected 'image' or 'image_embeddings' to be not None with types List[PIL.Image] or Torch.FloatTensor respectively. Received {type(image)} and {type(image_embeddings)} repsectively"
+ )
+
+ original_image_embeddings = self._encode_image(
+ image=image, device=device, num_images_per_prompt=1, image_embeddings=image_embeddings
+ )
+
+ image_embeddings = []
+
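+ # Interpolate between the two encoded images with `slerp`, producing one image embedding per interpolation step.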
+ for interp_step in torch.linspace(0, 1, steps):
+ temp_image_embeddings = slerp(
+ interp_step, original_image_embeddings[0], original_image_embeddings[1]
+ ).unsqueeze(0)
+ image_embeddings.append(temp_image_embeddings)
+
+ image_embeddings = torch.cat(image_embeddings).to(device)
+
+ do_classifier_free_guidance = decoder_guidance_scale > 1.0
+
+ prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
+ prompt=["" for i in range(steps)],
+ device=device,
+ num_images_per_prompt=1,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ )
+
+ text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
+ image_embeddings=image_embeddings,
+ prompt_embeds=prompt_embeds,
+ text_encoder_hidden_states=text_encoder_hidden_states,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ )
+
+ if device.type == "mps":
+ # HACK: MPS: There is a panic when padding bool tensors,
+ # so cast to int tensor for the pad and back to bool afterwards
+ text_mask = text_mask.type(torch.int)
+ decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)
+ decoder_text_mask = decoder_text_mask.type(torch.bool)
+ else:
+ decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True)
+
+ self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
+ decoder_timesteps_tensor = self.decoder_scheduler.timesteps
+
+ num_channels_latents = self.decoder.config.in_channels
+ height = self.decoder.config.sample_size
+ width = self.decoder.config.sample_size
+
+ # Get the decoder latents for 1 step and then repeat the same tensor for the entire batch to keep same noise across all interpolation steps.
+ decoder_latents = self.prepare_latents(
+ (1, num_channels_latents, height, width),
+ text_encoder_hidden_states.dtype,
+ device,
+ generator,
+ decoder_latents,
+ self.decoder_scheduler,
+ )
+ decoder_latents = decoder_latents.repeat((batch_size, 1, 1, 1))
+
+ for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents
+
+ noise_pred = self.decoder(
+ sample=latent_model_input,
+ timestep=t,
+ encoder_hidden_states=text_encoder_hidden_states,
+ class_labels=additive_clip_time_embeddings,
+ attention_mask=decoder_text_mask,
+ ).sample
+
+ if do_classifier_free_guidance:
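+ # The decoder predicts noise plus a learned variance; split off the variance channels, apply guidance to the
+ # noise only, then re-attach the variance for the scheduler step.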
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1)
+ noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1)
+ noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond)
+ noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
+
+ if i + 1 == decoder_timesteps_tensor.shape[0]:
+ prev_timestep = None
+ else:
+ prev_timestep = decoder_timesteps_tensor[i + 1]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ decoder_latents = self.decoder_scheduler.step(
+ noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator
+ ).prev_sample
+
+ decoder_latents = decoder_latents.clamp(-1, 1)
+
+ image_small = decoder_latents
+
+ # done decoder
+
+ # super res
+
+ self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
+ super_res_timesteps_tensor = self.super_res_scheduler.timesteps
+
+ channels = self.super_res_first.config.in_channels // 2
+ height = self.super_res_first.config.sample_size
+ width = self.super_res_first.config.sample_size
+
+ super_res_latents = self.prepare_latents(
+ (batch_size, channels, height, width),
+ image_small.dtype,
+ device,
+ generator,
+ super_res_latents,
+ self.super_res_scheduler,
+ )
+
+ if device.type == "mps":
+ # MPS does not support many interpolations
+ image_upscaled = F.interpolate(image_small, size=[height, width])
+ else:
+ interpolate_antialias = {}
+ if "antialias" in inspect.signature(F.interpolate).parameters:
+ interpolate_antialias["antialias"] = True
+
+ image_upscaled = F.interpolate(
+ image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias
+ )
+
+ for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)):
+ # no classifier free guidance
+
+ if i == super_res_timesteps_tensor.shape[0] - 1:
+ unet = self.super_res_last
+ else:
+ unet = self.super_res_first
+
+ latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1)
+
+ noise_pred = unet(
+ sample=latent_model_input,
+ timestep=t,
+ ).sample
+
+ if i + 1 == super_res_timesteps_tensor.shape[0]:
+ prev_timestep = None
+ else:
+ prev_timestep = super_res_timesteps_tensor[i + 1]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ super_res_latents = self.super_res_scheduler.step(
+ noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator
+ ).prev_sample
+
+ image = super_res_latents
+ # done super res
+
+ # post processing
+
+ image = image * 0.5 + 0.5
+ image = image.clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
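A minimal usage sketch for the `UnCLIPImageInterpolationPipeline` above. The checkpoint name and input file names are illustrative assumptions; any unCLIP image-variation checkpoint providing the modules registered in `__init__` should work:

```python
import torch
from PIL import Image
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "kakaobrain/karlo-v1-alpha-image-variations",  # assumed unCLIP image-variation checkpoint
    custom_pipeline="unclip_image_interpolation",
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

images = [Image.open("start.jpg"), Image.open("end.jpg")]  # placeholder input images

generator = torch.Generator(device="cuda").manual_seed(42)
output = pipe(image=images, steps=6, generator=generator)
for i, frame in enumerate(output.images):
    frame.save(f"unclip_interpolation_{i}.png")
```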
diff --git a/diffusers/examples/community/unclip_text_interpolation.py b/diffusers/examples/community/unclip_text_interpolation.py
new file mode 100644
index 0000000000000000000000000000000000000000..764299433b4cb7e4e21c87051428fddc51253e44
--- /dev/null
+++ b/diffusers/examples/community/unclip_text_interpolation.py
@@ -0,0 +1,574 @@
+import inspect
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch.nn import functional as F
+from transformers import CLIPTextModelWithProjection, CLIPTokenizer
+from transformers.models.clip.modeling_clip import CLIPTextModelOutput
+
+from diffusers import (
+ DiffusionPipeline,
+ ImagePipelineOutput,
+ PriorTransformer,
+ UnCLIPScheduler,
+ UNet2DConditionModel,
+ UNet2DModel,
+)
+from diffusers.pipelines.unclip import UnCLIPTextProjModel
+from diffusers.utils import is_accelerate_available, logging
+from diffusers.utils.torch_utils import randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def slerp(val, low, high):
+ """
+ Find the interpolation point between the 'low' and 'high' values for the given 'val'. See https://en.wikipedia.org/wiki/Slerp for more details on the topic.
+ """
+ low_norm = low / torch.norm(low)
+ high_norm = high / torch.norm(high)
+ omega = torch.acos((low_norm * high_norm))
+ so = torch.sin(omega)
+ res = (torch.sin((1.0 - val) * omega) / so) * low + (torch.sin(val * omega) / so) * high
+ return res
+
+
+class UnCLIPTextInterpolationPipeline(DiffusionPipeline):
+
+ """
+ Pipeline for prompt-to-prompt interpolation on CLIP text embeddings, using unCLIP (DALL-E 2) to decode them into images.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ text_encoder ([`CLIPTextModelWithProjection`]):
+ Frozen text-encoder.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ prior ([`PriorTransformer`]):
+ The canonical unCLIP prior to approximate the image embedding from the text embedding.
+ text_proj ([`UnCLIPTextProjModel`]):
+ Utility class to prepare and combine the embeddings before they are passed to the decoder.
+ decoder ([`UNet2DConditionModel`]):
+ The decoder to invert the image embedding into an image.
+ super_res_first ([`UNet2DModel`]):
+ Super resolution unet. Used in all but the last step of the super resolution diffusion process.
+ super_res_last ([`UNet2DModel`]):
+ Super resolution unet. Used in the last step of the super resolution diffusion process.
+ prior_scheduler ([`UnCLIPScheduler`]):
+ Scheduler used in the prior denoising process. Just a modified DDPMScheduler.
+ decoder_scheduler ([`UnCLIPScheduler`]):
+ Scheduler used in the decoder denoising process. Just a modified DDPMScheduler.
+ super_res_scheduler ([`UnCLIPScheduler`]):
+ Scheduler used in the super resolution denoising process. Just a modified DDPMScheduler.
+
+ """
+
+ prior: PriorTransformer
+ decoder: UNet2DConditionModel
+ text_proj: UnCLIPTextProjModel
+ text_encoder: CLIPTextModelWithProjection
+ tokenizer: CLIPTokenizer
+ super_res_first: UNet2DModel
+ super_res_last: UNet2DModel
+
+ prior_scheduler: UnCLIPScheduler
+ decoder_scheduler: UnCLIPScheduler
+ super_res_scheduler: UnCLIPScheduler
+
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.__init__
+ def __init__(
+ self,
+ prior: PriorTransformer,
+ decoder: UNet2DConditionModel,
+ text_encoder: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ text_proj: UnCLIPTextProjModel,
+ super_res_first: UNet2DModel,
+ super_res_last: UNet2DModel,
+ prior_scheduler: UnCLIPScheduler,
+ decoder_scheduler: UnCLIPScheduler,
+ super_res_scheduler: UnCLIPScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ prior=prior,
+ decoder=decoder,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ text_proj=text_proj,
+ super_res_first=super_res_first,
+ super_res_last=super_res_last,
+ prior_scheduler=prior_scheduler,
+ decoder_scheduler=decoder_scheduler,
+ super_res_scheduler=super_res_scheduler,
+ )
+
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
+ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ latents = latents * scheduler.init_noise_sigma
+ return latents
+
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
+ text_attention_mask: Optional[torch.Tensor] = None,
+ ):
+ if text_model_output is None:
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ text_mask = text_inputs.attention_mask.bool().to(device)
+
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+ text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+
+ text_encoder_output = self.text_encoder(text_input_ids.to(device))
+
+ prompt_embeds = text_encoder_output.text_embeds
+ text_encoder_hidden_states = text_encoder_output.last_hidden_state
+
+ else:
+ batch_size = text_model_output[0].shape[0]
+ prompt_embeds, text_encoder_hidden_states = text_model_output[0], text_model_output[1]
+ text_mask = text_attention_mask
+
+ prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
+
+ if do_classifier_free_guidance:
+ uncond_tokens = [""] * batch_size
+
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ uncond_text_mask = uncond_input.attention_mask.bool().to(device)
+ negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
+
+ negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
+ uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+
+ seq_len = negative_prompt_embeds.shape[1]
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)
+
+ seq_len = uncond_text_encoder_hidden_states.shape[1]
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
+ batch_size * num_images_per_prompt, seq_len, -1
+ )
+ uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
+
+ # done duplicates
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+ text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
+
+ text_mask = torch.cat([uncond_text_mask, text_mask])
+
+ return prompt_embeds, text_encoder_hidden_states, text_mask
+
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.enable_sequential_cpu_offload
+ def enable_sequential_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
+ models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU only
+ when their specific submodule has its `forward` method called.
+ """
+ if is_accelerate_available():
+ from accelerate import cpu_offload
+ else:
+ raise ImportError("Please install accelerate via `pip install accelerate`")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ # TODO: self.prior.post_process_latents is not covered by the offload hooks, so it fails if added to the list
+ models = [
+ self.decoder,
+ self.text_proj,
+ self.text_encoder,
+ self.super_res_first,
+ self.super_res_last,
+ ]
+ for cpu_offloaded_model in models:
+ if cpu_offloaded_model is not None:
+ cpu_offload(cpu_offloaded_model, device)
+
+ @property
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._execution_device
+ def _execution_device(self):
+ r"""
+ Returns the device on which the pipeline's models will be executed. After calling
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+ hooks.
+ """
+ if self.device != torch.device("meta") or not hasattr(self.decoder, "_hf_hook"):
+ return self.device
+ for module in self.decoder.modules():
+ if (
+ hasattr(module, "_hf_hook")
+ and hasattr(module._hf_hook, "execution_device")
+ and module._hf_hook.execution_device is not None
+ ):
+ return torch.device(module._hf_hook.execution_device)
+ return self.device
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ start_prompt: str,
+ end_prompt: str,
+ steps: int = 5,
+ prior_num_inference_steps: int = 25,
+ decoder_num_inference_steps: int = 25,
+ super_res_num_inference_steps: int = 7,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ prior_guidance_scale: float = 4.0,
+ decoder_guidance_scale: float = 8.0,
+ enable_sequential_cpu_offload=True,
+ gpu_id=0,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ ):
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ start_prompt (`str`):
+ The prompt to start the image generation interpolation from.
+ end_prompt (`str`):
+ The prompt to end the image generation interpolation at.
+ steps (`int`, *optional*, defaults to 5):
+ The number of steps over which to interpolate from start_prompt to end_prompt. The pipeline returns
+ the same number of images as this value.
+ prior_num_inference_steps (`int`, *optional*, defaults to 25):
+ The number of denoising steps for the prior. More denoising steps usually lead to a higher quality
+ image at the expense of slower inference.
+ decoder_num_inference_steps (`int`, *optional*, defaults to 25):
+ The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality
+ image at the expense of slower inference.
+ super_res_num_inference_steps (`int`, *optional*, defaults to 7):
+ The number of denoising steps for super resolution. More denoising steps usually lead to a higher
+ quality image at the expense of slower inference.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ prior_guidance_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+            decoder_guidance_scale (`float`, *optional*, defaults to 8.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ enable_sequential_cpu_offload (`bool`, *optional*, defaults to `True`):
+                If `True`, offloads all models to CPU using accelerate, significantly reducing memory usage. The
+                pipeline's models have their state dicts saved to CPU and are then moved to `torch.device('meta')`,
+                loaded to GPU only when their specific submodule has its `forward` method called.
+            gpu_id (`int`, *optional*, defaults to `0`):
+                The GPU id passed to `enable_sequential_cpu_offload`. Only used when
+                `enable_sequential_cpu_offload` is set to `True`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+ """
+
+ if not isinstance(start_prompt, str) or not isinstance(end_prompt, str):
+ raise ValueError(
+ f"`start_prompt` and `end_prompt` should be of type `str` but got {type(start_prompt)} and"
+ f" {type(end_prompt)} instead"
+ )
+
+ if enable_sequential_cpu_offload:
+ self.enable_sequential_cpu_offload(gpu_id=gpu_id)
+
+ device = self._execution_device
+
+ # Turn the prompts into embeddings.
+ inputs = self.tokenizer(
+ [start_prompt, end_prompt],
+ padding="max_length",
+ truncation=True,
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ )
+ inputs.to(device)
+ text_model_output = self.text_encoder(**inputs)
+
+ text_attention_mask = torch.max(inputs.attention_mask[0], inputs.attention_mask[1])
+ text_attention_mask = torch.cat([text_attention_mask.unsqueeze(0)] * steps).to(device)
+
+ # Interpolate from the start to end prompt using slerp and add the generated images to an image output pipeline
+ batch_text_embeds = []
+ batch_last_hidden_state = []
+
+ for interp_val in torch.linspace(0, 1, steps):
+ text_embeds = slerp(interp_val, text_model_output.text_embeds[0], text_model_output.text_embeds[1])
+ last_hidden_state = slerp(
+ interp_val, text_model_output.last_hidden_state[0], text_model_output.last_hidden_state[1]
+ )
+ batch_text_embeds.append(text_embeds.unsqueeze(0))
+ batch_last_hidden_state.append(last_hidden_state.unsqueeze(0))
+
+ batch_text_embeds = torch.cat(batch_text_embeds)
+ batch_last_hidden_state = torch.cat(batch_last_hidden_state)
+
+ text_model_output = CLIPTextModelOutput(
+ text_embeds=batch_text_embeds, last_hidden_state=batch_last_hidden_state
+ )
+
+ batch_size = text_model_output[0].shape[0]
+
+ do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0
+
+ prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
+ prompt=None,
+ device=device,
+ num_images_per_prompt=1,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ text_model_output=text_model_output,
+ text_attention_mask=text_attention_mask,
+ )
+
+ # prior
+
+ self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device)
+ prior_timesteps_tensor = self.prior_scheduler.timesteps
+
+ embedding_dim = self.prior.config.embedding_dim
+
+ prior_latents = self.prepare_latents(
+ (batch_size, embedding_dim),
+ prompt_embeds.dtype,
+ device,
+ generator,
+ None,
+ self.prior_scheduler,
+ )
+
+ for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([prior_latents] * 2) if do_classifier_free_guidance else prior_latents
+
+ predicted_image_embedding = self.prior(
+ latent_model_input,
+ timestep=t,
+ proj_embedding=prompt_embeds,
+ encoder_hidden_states=text_encoder_hidden_states,
+ attention_mask=text_mask,
+ ).predicted_image_embedding
+
+ if do_classifier_free_guidance:
+ predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)
+ predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * (
+ predicted_image_embedding_text - predicted_image_embedding_uncond
+ )
+
+ if i + 1 == prior_timesteps_tensor.shape[0]:
+ prev_timestep = None
+ else:
+ prev_timestep = prior_timesteps_tensor[i + 1]
+
+ prior_latents = self.prior_scheduler.step(
+ predicted_image_embedding,
+ timestep=t,
+ sample=prior_latents,
+ generator=generator,
+ prev_timestep=prev_timestep,
+ ).prev_sample
+
+ prior_latents = self.prior.post_process_latents(prior_latents)
+
+ image_embeddings = prior_latents
+
+ # done prior
+
+ # decoder
+
+ text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
+ image_embeddings=image_embeddings,
+ prompt_embeds=prompt_embeds,
+ text_encoder_hidden_states=text_encoder_hidden_states,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ )
+
+ if device.type == "mps":
+ # HACK: MPS: There is a panic when padding bool tensors,
+ # so cast to int tensor for the pad and back to bool afterwards
+ text_mask = text_mask.type(torch.int)
+ decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)
+ decoder_text_mask = decoder_text_mask.type(torch.bool)
+ else:
+ decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True)
+
+ self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
+ decoder_timesteps_tensor = self.decoder_scheduler.timesteps
+
+ num_channels_latents = self.decoder.config.in_channels
+ height = self.decoder.config.sample_size
+ width = self.decoder.config.sample_size
+
+ decoder_latents = self.prepare_latents(
+ (batch_size, num_channels_latents, height, width),
+ text_encoder_hidden_states.dtype,
+ device,
+ generator,
+ None,
+ self.decoder_scheduler,
+ )
+
+ for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents
+
+ noise_pred = self.decoder(
+ sample=latent_model_input,
+ timestep=t,
+ encoder_hidden_states=text_encoder_hidden_states,
+ class_labels=additive_clip_time_embeddings,
+ attention_mask=decoder_text_mask,
+ ).sample
+
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1)
+ noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1)
+ noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond)
+ noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
+
+ if i + 1 == decoder_timesteps_tensor.shape[0]:
+ prev_timestep = None
+ else:
+ prev_timestep = decoder_timesteps_tensor[i + 1]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ decoder_latents = self.decoder_scheduler.step(
+ noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator
+ ).prev_sample
+
+ decoder_latents = decoder_latents.clamp(-1, 1)
+
+ image_small = decoder_latents
+
+ # done decoder
+
+ # super res
+
+ self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
+ super_res_timesteps_tensor = self.super_res_scheduler.timesteps
+
+ channels = self.super_res_first.config.in_channels // 2
+ height = self.super_res_first.config.sample_size
+ width = self.super_res_first.config.sample_size
+
+ super_res_latents = self.prepare_latents(
+ (batch_size, channels, height, width),
+ image_small.dtype,
+ device,
+ generator,
+ None,
+ self.super_res_scheduler,
+ )
+
+ if device.type == "mps":
+ # MPS does not support many interpolations
+ image_upscaled = F.interpolate(image_small, size=[height, width])
+ else:
+ interpolate_antialias = {}
+ if "antialias" in inspect.signature(F.interpolate).parameters:
+ interpolate_antialias["antialias"] = True
+
+ image_upscaled = F.interpolate(
+ image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias
+ )
+
+ for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)):
+ # no classifier free guidance
+
+ if i == super_res_timesteps_tensor.shape[0] - 1:
+ unet = self.super_res_last
+ else:
+ unet = self.super_res_first
+
+ latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1)
+
+ noise_pred = unet(
+ sample=latent_model_input,
+ timestep=t,
+ ).sample
+
+ if i + 1 == super_res_timesteps_tensor.shape[0]:
+ prev_timestep = None
+ else:
+ prev_timestep = super_res_timesteps_tensor[i + 1]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ super_res_latents = self.super_res_scheduler.step(
+ noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator
+ ).prev_sample
+
+ image = super_res_latents
+ # done super res
+
+ # post processing
+
+ image = image * 0.5 + 0.5
+ image = image.clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
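+
+
+# Usage sketch (not part of the pipeline above; the checkpoint and `custom_pipeline` name are
+# assumptions based on the unCLIP community examples, so adjust them to your setup):
+#     pipe = DiffusionPipeline.from_pretrained(
+#         "kakaobrain/karlo-v1-alpha", custom_pipeline="unclip_text_interpolation", torch_dtype=torch.float16
+#     ).to("cuda")
+#     output = pipe(start_prompt="a photo of a cat", end_prompt="a photo of a dog", steps=5)
+#     images = output.images  # `steps` images interpolating from start_prompt to end_prompt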
diff --git a/diffusers/examples/community/wildcard_stable_diffusion.py b/diffusers/examples/community/wildcard_stable_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a5ea350b857e5069ca42aa278bfe65d23bbe15f
--- /dev/null
+++ b/diffusers/examples/community/wildcard_stable_diffusion.py
@@ -0,0 +1,419 @@
+import inspect
+import os
+import random
+import re
+from dataclasses import dataclass
+from typing import Callable, Dict, List, Optional, Union
+
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers import DiffusionPipeline
+from diffusers.configuration_utils import FrozenDict
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from diffusers.utils import deprecate, logging
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+global_re_wildcard = re.compile(r"__([^_]*)__")
+
+
+def get_filename(path: str):
+ # this doesn't work on Windows
+ return os.path.basename(path).split(".txt")[0]
+
+
+def read_wildcard_values(path: str):
+ with open(path, encoding="utf8") as f:
+ return f.read().splitlines()
+
+
+def grab_wildcard_values(wildcard_option_dict: Dict[str, List[str]] = {}, wildcard_files: List[str] = []):
+ for wildcard_file in wildcard_files:
+ filename = get_filename(wildcard_file)
+ read_values = read_wildcard_values(wildcard_file)
+ if filename not in wildcard_option_dict:
+ wildcard_option_dict[filename] = []
+ wildcard_option_dict[filename].extend(read_values)
+ return wildcard_option_dict
+
+
+def replace_prompt_with_wildcards(
+ prompt: str, wildcard_option_dict: Dict[str, List[str]] = {}, wildcard_files: List[str] = []
+):
+ new_prompt = prompt
+
+ # get wildcard options
+ wildcard_option_dict = grab_wildcard_values(wildcard_option_dict, wildcard_files)
+
+ for m in global_re_wildcard.finditer(new_prompt):
+ wildcard_value = m.group()
+ replace_value = random.choice(wildcard_option_dict[wildcard_value.strip("__")])
+ new_prompt = new_prompt.replace(wildcard_value, replace_value, 1)
+
+ return new_prompt
+
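+# Illustrative example (the values are hypothetical): calling
+#     replace_prompt_with_wildcards("a __animal__ sitting on a __object__",
+#                                   {"animal": ["dog", "cat"]}, ["object.txt"])
+# could return "a cat sitting on a chair", with "cat" drawn from the dict and "chair" drawn
+# from one of the lines of object.txt.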
+
+@dataclass
+class WildcardStableDiffusionOutput(StableDiffusionPipelineOutput):
+ prompts: List[str]
+
+
+class WildcardStableDiffusionPipeline(DiffusionPipeline):
+ r"""
+ Example Usage:
+ pipe = WildcardStableDiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ torch_dtype=torch.float16,
+ )
+ prompt = "__animal__ sitting on a __object__ wearing a __clothing__"
+ out = pipe(
+ prompt,
+ wildcard_option_dict={
+ "clothing":["hat", "shirt", "scarf", "beret"]
+ },
+ wildcard_files=["object.txt", "animal.txt"],
+ num_prompt_samples=1
+ )
+
+
+    Pipeline for text-to-image generation with wildcards using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
+ feature_extractor ([`CLIPImageProcessor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ wildcard_option_dict: Dict[str, List[str]] = {},
+ wildcard_files: List[str] = [],
+ num_prompt_samples: Optional[int] = 1,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+            wildcard_option_dict (`Dict[str, List[str]]`, *optional*):
+                Dict mapping a wildcard name to a list of possible replacements. For example, for the prompt
+                "A __animal__ sitting on a chair", `wildcard_option_dict` can provide values for "animal" such as
+                `{"animal": ["dog", "cat", "fox"]}`.
+            wildcard_files (`List[str]`, *optional*):
+                List of `.txt` files containing wildcard replacements, one value per line. For example, for the
+                prompt "A __animal__ sitting on a chair", you could pass `["animal.txt"]`.
+            num_prompt_samples (`int`, *optional*, defaults to 1):
+                Number of times to sample wildcards for each prompt provided.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+
+ if isinstance(prompt, str):
+ prompt = [
+ replace_prompt_with_wildcards(prompt, wildcard_option_dict, wildcard_files)
+ for i in range(num_prompt_samples)
+ ]
+ batch_size = len(prompt)
+ elif isinstance(prompt, list):
+ prompt_list = []
+ for p in prompt:
+ for i in range(num_prompt_samples):
+ prompt_list.append(replace_prompt_with_wildcards(p, wildcard_option_dict, wildcard_files))
+ prompt = prompt_list
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+ removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+ text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+ text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = text_embeddings.shape
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ # get the initial random noise unless the user supplied it
+
+ # Unlike in other pipelines, latents need to be generated in the target device
+ # for 1-to-1 results reproducibility with the CompVis implementation.
+ # However this currently doesn't work in `mps`.
+ latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
+ latents_dtype = text_embeddings.dtype
+ if latents is None:
+ if self.device.type == "mps":
+ # randn does not exist on mps
+ latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+ self.device
+ )
+ else:
+ latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
+ else:
+ if latents.shape != latents_shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+ latents = latents.to(self.device)
+
+ # set timesteps
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ # Some schedulers like PNDM have timesteps as arrays
+ # It's more optimized to move all timesteps to correct device beforehand
+ timesteps_tensor = self.scheduler.timesteps.to(self.device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
+ self.device
+ )
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
+ )
+ else:
+ has_nsfw_concept = None
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return WildcardStableDiffusionOutput(images=image, nsfw_content_detected=has_nsfw_concept, prompts=prompt)
diff --git a/diffusers/examples/conftest.py b/diffusers/examples/conftest.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a48d18d1cc739f3fbf52c84a9c77afbf5694803
--- /dev/null
+++ b/diffusers/examples/conftest.py
@@ -0,0 +1,45 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# tests directory-specific settings - this file is run automatically
+# by pytest before any tests are run
+
+import sys
+import warnings
+from os.path import abspath, dirname, join
+
+
+# allow having multiple repository checkouts and not needing to remember to rerun
+# 'pip install -e .[dev]' when switching between checkouts and running tests.
+git_repo_path = abspath(join(dirname(dirname(dirname(__file__))), "src"))
+sys.path.insert(1, git_repo_path)
+
+
+# silence FutureWarning warnings in tests since often we can't act on them until
+# they become normal warnings - i.e. the tests still need to test the current functionality
+warnings.simplefilter(action="ignore", category=FutureWarning)
+
+
+def pytest_addoption(parser):
+ from diffusers.utils.testing_utils import pytest_addoption_shared
+
+ pytest_addoption_shared(parser)
+
+
+def pytest_terminal_summary(terminalreporter):
+ from diffusers.utils.testing_utils import pytest_terminal_summary_main
+
+ make_reports = terminalreporter.config.getoption("--make-reports")
+ if make_reports:
+ pytest_terminal_summary_main(terminalreporter, id=make_reports)
diff --git a/diffusers/examples/consistency_distillation/README.md b/diffusers/examples/consistency_distillation/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..c584736dfe820cef1a1c73b1b979909734333e92
--- /dev/null
+++ b/diffusers/examples/consistency_distillation/README.md
@@ -0,0 +1,104 @@
+# Latent Consistency Distillation Example:
+
+[Latent Consistency Models (LCMs)](https://arxiv.org/abs/2310.04378) enable fast, few-step inference by distilling a latent diffusion model. This example demonstrates how to use latent consistency distillation to distill stable-diffusion-v1.5 for few-step inference.
+
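+Once a model has been distilled with the scripts below, it can be run in very few steps with the `LCMScheduler`. The snippet below is only a minimal inference sketch, not part of the training scripts: the model path is a placeholder for your own `--output_dir` (depending on how you export the distilled UNet, you may instead load the teacher pipeline and swap in the distilled `unet`), and the 4-step, `guidance_scale=1.0` settings mirror the validation code in these scripts.
+
+```python
+import torch
+
+from diffusers import DiffusionPipeline, LCMScheduler
+
+# Placeholder path: point this at the pipeline exported from a finished distillation run.
+pipe = DiffusionPipeline.from_pretrained("path/to/lcm-distilled-model", torch_dtype=torch.float16)
+pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+pipe.to("cuda")
+
+# Distilled LCMs only need a handful of denoising steps.
+image = pipe(
+    "portrait photo of a girl, highly detailed, 8k",
+    num_inference_steps=4,
+    guidance_scale=1.0,
+).images[0]
+```
+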
+## Full model distillation
+
+### Running locally with PyTorch
+
+#### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+Then cd in the example folder and run
+```bash
+pip install -r requirements.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or for a default accelerate configuration without answering questions about your environment
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell e.g. a notebook
+
+```python
+from accelerate.utils import write_basic_config
+write_basic_config()
+```
+
+When running `accelerate config`, setting torch compile mode to True can give dramatic speedups.
+
+
+#### Example with LAION-A6+ dataset
+
+```bash
+export MODEL_DIR="runwayml/stable-diffusion-v1-5"
+PROGRAM="train_lcm_distill_sd_wds.py \
+ --pretrained_teacher_model=$MODEL_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --mixed_precision=fp16 \
+ --resolution=512 \
+ --learning_rate=1e-6 --loss_type="huber" --ema_decay=0.95 --adam_weight_decay=0.0 \
+ --max_train_steps=1000 \
+ --max_train_samples=4000000 \
+ --dataloader_num_workers=8 \
+ --train_shards_path_or_url='pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-min512-data/{00000..01210}.tar -' \
+ --validation_steps=200 \
+ --checkpointing_steps=200 --checkpoints_total_limit=10 \
+ --train_batch_size=12 \
+ --gradient_checkpointing --enable_xformers_memory_efficient_attention \
+ --gradient_accumulation_steps=1 \
+ --use_8bit_adam \
+ --resume_from_checkpoint=latest \
+ --report_to=wandb \
+ --seed=453645634 \
+ --push_to_hub \
+```
+
+## LCM-LoRA
+
+Instead of fine-tuning the full model, we can also just train a LoRA that can be injected into any model fine-tuned from Stable Diffusion v1.5.
+
+### Example with LAION-A6+ dataset
+
+```bash
+export MODEL_DIR="runwayml/stable-diffusion-v1-5"
+PROGRAM="train_lcm_distill_lora_sd_wds.py \
+ --pretrained_teacher_model=$MODEL_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --mixed_precision=fp16 \
+ --resolution=512 \
+ --lora_rank=64 \
+ --learning_rate=1e-6 --loss_type="huber" --adam_weight_decay=0.0 \
+ --max_train_steps=1000 \
+ --max_train_samples=4000000 \
+ --dataloader_num_workers=8 \
+ --train_shards_path_or_url='pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-min512-data/{00000..01210}.tar -' \
+ --validation_steps=200 \
+ --checkpointing_steps=200 --checkpoints_total_limit=10 \
+ --train_batch_size=12 \
+ --gradient_checkpointing --enable_xformers_memory_efficient_attention \
+ --gradient_accumulation_steps=1 \
+ --use_8bit_adam \
+ --resume_from_checkpoint=latest \
+ --report_to=wandb \
+ --seed=453645634 \
+ --push_to_hub \
+```
\ No newline at end of file
diff --git a/diffusers/examples/consistency_distillation/README_sdxl.md b/diffusers/examples/consistency_distillation/README_sdxl.md
new file mode 100644
index 0000000000000000000000000000000000000000..00577f9fa2b8d910a9decbcb9e492c3094086653
--- /dev/null
+++ b/diffusers/examples/consistency_distillation/README_sdxl.md
@@ -0,0 +1,106 @@
+# Latent Consistency Distillation Example:
+
+[Latent Consistency Models (LCMs)](https://arxiv.org/abs/2310.04378) enable fast, few-step inference by distilling a latent diffusion model. This example demonstrates how to use latent consistency distillation to distill SDXL for few-step inference.
+
+## Full model distillation
+
+### Running locally with PyTorch
+
+#### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+Then cd in the example folder and run
+```bash
+pip install -r requirements.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or for a default accelerate configuration without answering questions about your environment
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell e.g. a notebook
+
+```python
+from accelerate.utils import write_basic_config
+write_basic_config()
+```
+
+When running `accelerate config`, setting torch compile mode to True can give dramatic speedups.
+
+
+#### Example with LAION-A6+ dataset
+
+```bash
+export MODEL_DIR="stabilityai/stable-diffusion-xl-base-1.0"
+PROGRAM="train_lcm_distill_sdxl_wds.py \
+ --pretrained_teacher_model=$MODEL_DIR \
+ --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
+ --output_dir=$OUTPUT_DIR \
+ --mixed_precision=fp16 \
+ --resolution=1024 \
+ --learning_rate=1e-6 --loss_type="huber" --use_fix_crop_and_size --ema_decay=0.95 --adam_weight_decay=0.0 \
+ --max_train_steps=1000 \
+ --max_train_samples=4000000 \
+ --dataloader_num_workers=8 \
+ --train_shards_path_or_url='pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-min512-data/{00000..01210}.tar -' \
+ --validation_steps=200 \
+ --checkpointing_steps=200 --checkpoints_total_limit=10 \
+ --train_batch_size=12 \
+ --gradient_checkpointing --enable_xformers_memory_efficient_attention \
+ --gradient_accumulation_steps=1 \
+ --use_8bit_adam \
+ --resume_from_checkpoint=latest \
+ --report_to=wandb \
+ --seed=453645634 \
+ --push_to_hub \
+```
+
+## LCM-LoRA
+
+Instead of fine-tuning the full model, we can also just train a LoRA that can be injected into any SDXL model.
+
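+Once a LoRA has been trained with the command below, it can be loaded into a standard SDXL pipeline together with the `LCMScheduler` for few-step inference. This is a minimal sketch (the LoRA path is a placeholder for your own `--output_dir`), mirroring the `load_lora_weights` / `fuse_lora` usage in these scripts' validation code:
+
+```python
+import torch
+
+from diffusers import LCMScheduler, StableDiffusionXLPipeline
+
+pipe = StableDiffusionXLPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+).to("cuda")
+pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+
+# Placeholder path: point this at the LoRA weights produced by the training run below.
+pipe.load_lora_weights("path/to/lcm-lora-sdxl")
+pipe.fuse_lora()
+
+image = pipe(
+    "close-up photography of an old man standing in the rain",
+    num_inference_steps=4,
+    guidance_scale=1.0,
+).images[0]
+```
+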
+### Example with LAION-A6+ dataset
+
+```bash
+export MODEL_DIR="stabilityai/stable-diffusion-xl-base-1.0"
+PROGRAM="train_lcm_distill_lora_sdxl_wds.py \
+ --pretrained_teacher_model=$MODEL_DIR \
+ --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
+ --output_dir=$OUTPUT_DIR \
+ --mixed_precision=fp16 \
+ --resolution=1024 \
+ --lora_rank=64 \
+ --learning_rate=1e-6 --loss_type="huber" --use_fix_crop_and_size --adam_weight_decay=0.0 \
+ --max_train_steps=1000 \
+ --max_train_samples=4000000 \
+ --dataloader_num_workers=8 \
+ --train_shards_path_or_url='pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-min512-data/{00000..01210}.tar -' \
+ --validation_steps=200 \
+ --checkpointing_steps=200 --checkpoints_total_limit=10 \
+ --train_batch_size=12 \
+ --gradient_checkpointing --enable_xformers_memory_efficient_attention \
+ --gradient_accumulation_steps=1 \
+ --use_8bit_adam \
+ --resume_from_checkpoint=latest \
+ --report_to=wandb \
+ --seed=453645634 \
+ --push_to_hub \
+```
\ No newline at end of file
diff --git a/diffusers/examples/consistency_distillation/requirements.txt b/diffusers/examples/consistency_distillation/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..09fb84270a8a897c2082748fff25220b94d4532e
--- /dev/null
+++ b/diffusers/examples/consistency_distillation/requirements.txt
@@ -0,0 +1,7 @@
+accelerate>=0.16.0
+torchvision
+transformers>=4.25.1
+ftfy
+tensorboard
+Jinja2
+webdataset
\ No newline at end of file
diff --git a/diffusers/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py b/diffusers/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fa8d2c578320859e206310d59763e73707d2a7d
--- /dev/null
+++ b/diffusers/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
@@ -0,0 +1,1321 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+
+import argparse
+import functools
+import gc
+import itertools
+import json
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+from typing import List, Union
+
+import accelerate
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import torchvision.transforms.functional as TF
+import transformers
+import webdataset as wds
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from braceexpand import braceexpand
+from huggingface_hub import create_repo
+from packaging import version
+from peft import LoraConfig, get_peft_model, get_peft_model_state_dict
+from torch.utils.data import default_collate
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, CLIPTextModel, PretrainedConfig
+from webdataset.tariterators import (
+ base_plus_ext,
+ tar_file_expander,
+ url_opener,
+ valid_sample,
+)
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ LCMScheduler,
+ StableDiffusionPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+MAX_SEQ_LENGTH = 77
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.18.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def get_module_kohya_state_dict(module, prefix: str, dtype: torch.dtype, adapter_name: str = "default"):
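+    # Convert PEFT-style LoRA keys into the Kohya-ss layout expected by `pipeline.load_lora_weights`
+    # in `log_validation` below: the "base_model.model" prefix is replaced with `prefix`,
+    # "lora_A"/"lora_B" are renamed to "lora_down"/"lora_up", all dots except the last two become
+    # underscores, and an `alpha` entry is added for every "lora_down" weight.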
+ kohya_ss_state_dict = {}
+ for peft_key, weight in get_peft_model_state_dict(module, adapter_name=adapter_name).items():
+ kohya_key = peft_key.replace("base_model.model", prefix)
+ kohya_key = kohya_key.replace("lora_A", "lora_down")
+ kohya_key = kohya_key.replace("lora_B", "lora_up")
+ kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2)
+ kohya_ss_state_dict[kohya_key] = weight.to(dtype)
+
+ # Set alpha parameter
+ if "lora_down" in kohya_key:
+ alpha_key = f'{kohya_key.split(".")[0]}.alpha'
+ kohya_ss_state_dict[alpha_key] = torch.tensor(module.peft_config[adapter_name].lora_alpha).to(dtype)
+
+ return kohya_ss_state_dict
+
+
+def filter_keys(key_set):
+ def _f(dictionary):
+ return {k: v for k, v in dictionary.items() if k in key_set}
+
+ return _f
+
+
+def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
+ """Return function over iterator that groups key, value pairs into samples.
+
+    :param keys: function that splits the key into key and extension (base_plus_ext)
+    :param lcase: convert suffixes to lower case (Default value = True)
+ """
+ current_sample = None
+ for filesample in data:
+ assert isinstance(filesample, dict)
+ fname, value = filesample["fname"], filesample["data"]
+ prefix, suffix = keys(fname)
+ if prefix is None:
+ continue
+ if lcase:
+ suffix = suffix.lower()
+ # FIXME webdataset version throws if suffix in current_sample, but we have a potential for
+ # this happening in the current LAION400m dataset if a tar ends with same prefix as the next
+ # begins, rare, but can happen since prefix aren't unique across tar files in that dataset
+ if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample:
+ if valid_sample(current_sample):
+ yield current_sample
+ current_sample = {"__key__": prefix, "__url__": filesample["__url__"]}
+ if suffixes is None or suffix in suffixes:
+ current_sample[suffix] = value
+ if valid_sample(current_sample):
+ yield current_sample
+
+
+def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue):
+ # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
+ streams = url_opener(src, handler=handler)
+ files = tar_file_expander(streams, handler=handler)
+ samples = group_by_keys_nothrow(files, handler=handler)
+ return samples
+
+
+class WebdatasetFilter:
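+    # Keeps a LAION-style webdataset sample only if its "json" metadata reports original dimensions
+    # of at least `min_size` on both sides and a watermark probability of at most `max_pwatermark`;
+    # samples with missing or malformed metadata are filtered out.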
+ def __init__(self, min_size=1024, max_pwatermark=0.5):
+ self.min_size = min_size
+ self.max_pwatermark = max_pwatermark
+
+ def __call__(self, x):
+ try:
+ if "json" in x:
+ x_json = json.loads(x["json"])
+ filter_size = (x_json.get("original_width", 0.0) or 0.0) >= self.min_size and x_json.get(
+ "original_height", 0
+ ) >= self.min_size
+ filter_watermark = (x_json.get("pwatermark", 1.0) or 1.0) <= self.max_pwatermark
+ return filter_size and filter_watermark
+ else:
+ return False
+ except Exception:
+ return False
+
+
+class Text2ImageDataset:
+ def __init__(
+ self,
+ train_shards_path_or_url: Union[str, List[str]],
+ num_train_examples: int,
+ per_gpu_batch_size: int,
+ global_batch_size: int,
+ num_workers: int,
+ resolution: int = 512,
+ shuffle_buffer_size: int = 1000,
+ pin_memory: bool = False,
+ persistent_workers: bool = False,
+ ):
+ if not isinstance(train_shards_path_or_url, str):
+ train_shards_path_or_url = [list(braceexpand(urls)) for urls in train_shards_path_or_url]
+ # flatten list using itertools
+ train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url))
+
+ def transform(example):
+ # resize image
+ image = example["image"]
+ image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR)
+
+ # get crop coordinates and crop image
+ c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))
+ image = TF.crop(image, c_top, c_left, resolution, resolution)
+ image = TF.to_tensor(image)
+ image = TF.normalize(image, [0.5], [0.5])
+
+ example["image"] = image
+ return example
+
+ processing_pipeline = [
+ wds.decode("pil", handler=wds.ignore_and_continue),
+ wds.rename(image="jpg;png;jpeg;webp", text="text;txt;caption", handler=wds.warn_and_continue),
+ wds.map(filter_keys({"image", "text"})),
+ wds.map(transform),
+ wds.to_tuple("image", "text"),
+ ]
+
+ # Create train dataset and loader
+ pipeline = [
+ wds.ResampledShards(train_shards_path_or_url),
+ tarfile_to_samples_nothrow,
+ wds.shuffle(shuffle_buffer_size),
+ *processing_pipeline,
+ wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
+ ]
+
+ num_worker_batches = math.ceil(num_train_examples / (global_batch_size * num_workers)) # per dataloader worker
+ num_batches = num_worker_batches * num_workers
+ num_samples = num_batches * global_batch_size
+
+ # each worker is iterating over this
+ self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)
+ self._train_dataloader = wds.WebLoader(
+ self._train_dataset,
+ batch_size=None,
+ shuffle=False,
+ num_workers=num_workers,
+ pin_memory=pin_memory,
+ persistent_workers=persistent_workers,
+ )
+ # add meta-data to dataloader instance for convenience
+ self._train_dataloader.num_batches = num_batches
+ self._train_dataloader.num_samples = num_samples
+
+ @property
+ def train_dataset(self):
+ return self._train_dataset
+
+ @property
+ def train_dataloader(self):
+ return self._train_dataloader
+
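+# Illustrative usage (all values are placeholders):
+#     dataset = Text2ImageDataset(
+#         train_shards_path_or_url="pipe:aws s3 cp s3://bucket/{00000..00999}.tar -",
+#         num_train_examples=4_000_000,
+#         per_gpu_batch_size=12,
+#         global_batch_size=12 * 8,  # per-GPU batch size x number of processes
+#         num_workers=8,
+#         resolution=512,
+#     )
+#     for image_batch, text_batch in dataset.train_dataloader:
+#         ...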
+
+def log_validation(vae, unet, args, accelerator, weight_dtype, step):
+ logger.info("Running validation... ")
+
+ unet = accelerator.unwrap_model(unet)
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_teacher_model,
+ vae=vae,
+ scheduler=LCMScheduler.from_pretrained(args.pretrained_teacher_model, subfolder="scheduler"),
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ safety_checker=None,
+ )
+ pipeline.set_progress_bar_config(disable=True)
+
+ lora_state_dict = get_module_kohya_state_dict(unet, "lora_unet", weight_dtype)
+ pipeline.load_lora_weights(lora_state_dict)
+ pipeline.fuse_lora()
+
+ pipeline = pipeline.to(accelerator.device, dtype=weight_dtype)
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ validation_prompts = [
+ "portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography",
+ "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+ "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
+ ]
+
+ image_logs = []
+
+ for _, prompt in enumerate(validation_prompts):
+ images = []
+ with torch.autocast("cuda", dtype=weight_dtype):
+ images = pipeline(
+ prompt=prompt,
+ num_inference_steps=4,
+ num_images_per_prompt=4,
+ generator=generator,
+ guidance_scale=1.0,
+ ).images
+ image_logs.append({"validation_prompt": prompt, "images": images})
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ formatted_images = []
+ for image in images:
+ formatted_images.append(np.asarray(image))
+
+ formatted_images = np.stack(formatted_images)
+
+ tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ formatted_images = []
+
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ for image in images:
+ image = wandb.Image(image, caption=validation_prompt)
+ formatted_images.append(image)
+
+ tracker.log({"validation": formatted_images})
+ else:
+ logger.warn(f"image logging not implemented for {tracker.name}")
+
+ del pipeline
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ return image_logs
+
+
+# From LatentConsistencyModel.get_guidance_scale_embedding
+def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+        w (`torch.Tensor`):
+            guidance scale values at which to generate embedding vectors
+ embedding_dim (`int`, *optional*, defaults to 512):
+ dimension of the embeddings to generate
+ dtype:
+ data type of the generated embeddings
+
+ Returns:
+        `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
+
+def append_dims(x, target_dims):
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+ dims_to_append = target_dims - x.ndim
+ if dims_to_append < 0:
+ raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
+ return x[(...,) + (None,) * dims_to_append]
+
+
+# From LCMScheduler.get_scalings_for_boundary_condition_discrete
+def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
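+    # With the hard-coded `timestep / 0.1` scaling below (equivalent to multiplying by the default
+    # `timestep_scaling=10`, which is otherwise unused here), c_skip -> 1 and c_out -> 0 as the
+    # timestep approaches 0, enforcing the consistency-model boundary condition f(x, 0) = x.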
+ c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2)
+ c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5
+ return c_skip, c_out
+
+
+# Compare LCMScheduler.step, Step 4
+def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas):
+ if prediction_type == "epsilon":
+ sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
+ alphas = extract_into_tensor(alphas, timesteps, sample.shape)
+ pred_x_0 = (sample - sigmas * model_output) / alphas
+ elif prediction_type == "v_prediction":
+ pred_x_0 = alphas[timesteps] * sample - sigmas[timesteps] * model_output
+ else:
+ raise ValueError(f"Prediction type {prediction_type} currently not supported.")
+
+ return pred_x_0
+
+
+def extract_into_tensor(a, t, x_shape):
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+class DDIMSolver:
+ def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):
+ # DDIM sampling parameters
+ step_ratio = timesteps // ddim_timesteps
+ self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1
+ self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
+ self.ddim_alpha_cumprods_prev = np.asarray(
+ [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
+ )
+ # convert to torch tensors
+ self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
+ self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
+ self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev)
+
+ def to(self, device):
+ self.ddim_timesteps = self.ddim_timesteps.to(device)
+ self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
+ self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)
+ return self
+
+ def ddim_step(self, pred_x0, pred_noise, timestep_index):
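+        # Deterministic DDIM update (eta = 0):
+        #     x_prev = sqrt(alpha_bar_prev) * pred_x0 + sqrt(1 - alpha_bar_prev) * pred_noise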
+ alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)
+ dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
+ x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
+ return x_prev
+
+
+@torch.no_grad()
+def update_ema(target_params, source_params, rate=0.99):
+ """
+ Update target parameters to be closer to those of source parameters using
+ an exponential moving average.
+
+ :param target_params: the target parameter sequence.
+ :param source_params: the source parameter sequence.
+ :param rate: the EMA rate (closer to 1 means slower).
+ """
+ for targ, src in zip(target_params, source_params):
+ targ.detach().mul_(rate).add_(src, alpha=1 - rate)
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ # ----------Model Checkpoint Loading Arguments----------
+ parser.add_argument(
+ "--pretrained_teacher_model",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained LDM teacher model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--teacher_revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained LDM teacher model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained LDM model identifier from huggingface.co/models.",
+ )
+ # ----------Training Arguments----------
+ # ----General Training Arguments----
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="lcm-xl-distilled",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ # ----Logging----
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ # ----Checkpointing----
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ # ----Image Processing----
+ parser.add_argument(
+ "--train_shards_path_or_url",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ # ----Dataloader----
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ # ----Batch Size and Training Steps----
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ # ----Learning Rate----
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ # ----Optimizer (Adam)----
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ # ----Diffusion Training Arguments----
+ parser.add_argument(
+ "--proportion_empty_prompts",
+ type=float,
+ default=0,
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+ )
+ # ----Latent Consistency Distillation (LCD) Specific Arguments----
+ parser.add_argument(
+ "--w_min",
+ type=float,
+ default=5.0,
+ required=False,
+ help=(
+ "The minimum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG"
+ " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as"
+ " compared to the original paper."
+ ),
+ )
+ parser.add_argument(
+ "--w_max",
+ type=float,
+ default=15.0,
+ required=False,
+ help=(
+ "The maximum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG"
+ " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as"
+ " compared to the original paper."
+ ),
+ )
+ parser.add_argument(
+ "--num_ddim_timesteps",
+ type=int,
+ default=50,
+ help="The number of timesteps to use for DDIM sampling.",
+ )
+ parser.add_argument(
+ "--loss_type",
+ type=str,
+ default="l2",
+ choices=["l2", "huber"],
+ help="The type of loss to use for the LCD loss.",
+ )
+ parser.add_argument(
+ "--huber_c",
+ type=float,
+ default=0.001,
+ help="The huber loss parameter. Only used if `--loss_type=huber`.",
+ )
+ parser.add_argument(
+ "--lora_rank",
+ type=int,
+ default=64,
+ help="The rank of the LoRA projection matrix.",
+ )
+ # ----Mixed Precision----
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--cast_teacher_unet",
+ action="store_true",
+ help="Whether to cast the teacher U-Net to the precision specified by `--mixed_precision`.",
+ )
+ # ----Training Optimizations----
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ # ----Distributed Training----
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ # ----------Validation Arguments----------
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=200,
+ help="Run validation every X steps.",
+ )
+ # ----------Huggingface Hub Arguments-----------
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ # ----------Accelerate Arguments----------
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="text2image-fine-tune",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers for"
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
+
+ return args
+
+
+# Adapted from pipelines.StableDiffusionPipeline.encode_prompt
+def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train=True):
+ captions = []
+ for caption in prompt_batch:
+ if random.random() < proportion_empty_prompts:
+ captions.append("")
+ elif isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+
+ with torch.no_grad():
+ text_inputs = tokenizer(
+ captions,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device))[0]
+
+ return prompt_embeds
+
+
+def main(args):
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be divided by the number of processes, assuming batches are multiplied by the number of processes.
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
+ exist_ok=True,
+ token=args.hub_token,
+ private=True,
+ ).repo_id
+
+ # 1. Create the noise scheduler and the desired noise schedule.
+ noise_scheduler = DDPMScheduler.from_pretrained(
+ args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
+ )
+
+ # The scheduler calculates the alpha and sigma schedule for us
+ alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
+ sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
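+ # The DDIM solver precomputes the coarse DDIM timestep grid and the corresponding alpha_cumprod values used for the teacher's one-step ODE (DDIM) update.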
+ solver = DDIMSolver(
+ noise_scheduler.alphas_cumprod.numpy(),
+ timesteps=noise_scheduler.config.num_train_timesteps,
+ ddim_timesteps=args.num_ddim_timesteps,
+ )
+
+ # 2. Load the tokenizer from the teacher checkpoint.
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False
+ )
+
+ # 3. Load the text encoder from the teacher checkpoint.
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision
+ )
+
+ # 4. Load the VAE from the teacher checkpoint.
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_teacher_model,
+ subfolder="vae",
+ revision=args.teacher_revision,
+ )
+
+ # 5. Load the teacher U-Net from the teacher checkpoint.
+ teacher_unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
+ )
+
+ # 6. Freeze teacher vae, text_encoder, and teacher_unet
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+ teacher_unet.requires_grad_(False)
+
+ # 7. Create online (`unet`) student U-Nets.
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
+ )
+ unet.train()
+
+ # Check that all trainable models are in full precision
+ low_precision_error_string = (
+ " Please make sure to always have all model weights in full float32 precision when starting training - even if"
+ " doing mixed precision training, copy of the weights should still be float32."
+ )
+
+ if accelerator.unwrap_model(unet).dtype != torch.float32:
+ raise ValueError(
+ f"Controlnet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
+ )
+
+ # 8. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer.
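+ # The target modules cover the attention projections, transformer feed-forward layers, ResNet convolutions, and time-embedding projections of the U-Net.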
+ lora_config = LoraConfig(
+ r=args.lora_rank,
+ target_modules=[
+ "to_q",
+ "to_k",
+ "to_v",
+ "to_out.0",
+ "proj_in",
+ "proj_out",
+ "ff.net.0.proj",
+ "ff.net.2",
+ "conv1",
+ "conv2",
+ "conv_shortcut",
+ "downsamplers.0.conv",
+ "upsamplers.0.conv",
+ "time_emb_proj",
+ ],
+ )
+ unet = get_peft_model(unet, lora_config)
+
+ # 9. Handle mixed precision and device placement
+ # For mixed precision training we cast all non-trainable weights to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ # The VAE is in float32 to avoid NaN losses.
+ vae.to(accelerator.device)
+ if args.pretrained_vae_model_name_or_path is not None:
+ vae.to(dtype=weight_dtype)
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ # Move teacher_unet to device, optionally cast to weight_dtype
+ teacher_unet.to(accelerator.device)
+ if args.cast_teacher_unet:
+ teacher_unet.to(dtype=weight_dtype)
+
+ # Also move the alpha and sigma noise schedules to accelerator.device.
+ alpha_schedule = alpha_schedule.to(accelerator.device)
+ sigma_schedule = sigma_schedule.to(accelerator.device)
+ solver = solver.to(accelerator.device)
+
+ # 10. Handle saving and loading of checkpoints
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ unet_ = accelerator.unwrap_model(unet)
+ lora_state_dict = get_peft_model_state_dict(unet_, adapter_name="default")
+ StableDiffusionPipeline.save_lora_weights(os.path.join(output_dir, "unet_lora"), lora_state_dict)
+ # save weights in peft format to be able to load them back
+ unet_.save_pretrained(output_dir)
+
+ for _, model in enumerate(models):
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ # load the LoRA into the model
+ unet_ = accelerator.unwrap_model(unet)
+ unet_.load_adapter(input_dir, "default", is_trainable=True)
+
+ for _ in range(len(models)):
+ # pop models so that they are not loaded again
+ models.pop()
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # 11. Enable optimizations
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warn(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ teacher_unet.enable_xformers_memory_efficient_attention()
+ # target_unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # 12. Optimizer creation
+ optimizer = optimizer_class(
+ unet.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Compute the text embeddings used to condition the student and teacher U-Nets.
+ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tokenizer, is_train=True):
+ prompt_embeds = encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train)
+ return {"prompt_embeds": prompt_embeds}
+
+ dataset = Text2ImageDataset(
+ train_shards_path_or_url=args.train_shards_path_or_url,
+ num_train_examples=args.max_train_samples,
+ per_gpu_batch_size=args.train_batch_size,
+ global_batch_size=args.train_batch_size * accelerator.num_processes,
+ num_workers=args.dataloader_num_workers,
+ resolution=args.resolution,
+ shuffle_buffer_size=1000,
+ pin_memory=True,
+ persistent_workers=True,
+ )
+ train_dataloader = dataset.train_dataloader
+
+ compute_embeddings_fn = functools.partial(
+ compute_embeddings,
+ proportion_empty_prompts=0,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps,
+ num_training_steps=args.max_train_steps,
+ )
+
+ # Prepare everything with our `accelerator`.
+ unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
+
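+ # Precompute the empty-prompt embeddings once; they are reused as the unconditional input for the teacher's CFG at every step.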
+ uncond_input_ids = tokenizer(
+ [""] * args.train_batch_size, return_tensors="pt", padding="max_length", max_length=77
+ ).input_ids.to(accelerator.device)
+ uncond_prompt_embeds = text_encoder(uncond_input_ids)[0]
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num batches each epoch = {train_dataloader.num_batches}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ image, text, _, _ = batch
+
+ image = image.to(accelerator.device, non_blocking=True)
+ encoded_text = compute_embeddings_fn(text)
+
+ pixel_values = image.to(dtype=weight_dtype)
+ if vae.dtype != weight_dtype:
+ vae.to(dtype=weight_dtype)
+
+ # encode pixel values with batch size of at most 32
+ latents = []
+ for i in range(0, pixel_values.shape[0], 32):
+ latents.append(vae.encode(pixel_values[i : i + 32]).latent_dist.sample())
+ latents = torch.cat(latents, dim=0)
+
+ latents = latents * vae.config.scaling_factor
+ latents = latents.to(weight_dtype)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+
+ # Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias.
+ topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
+ index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
+ start_timesteps = solver.ddim_timesteps[index]
+ timesteps = start_timesteps - topk
+ timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
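+ # start_timesteps are the sampled DDIM grid timesteps t_{n+k}; timesteps are topk (= num_train_timesteps / num_ddim_timesteps) raw timesteps earlier (t_n), clamped at 0.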
+
+ # 20.4.4. Get boundary scalings for start_timesteps and (end) timesteps.
+ c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
+ c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
+ c_skip, c_out = scalings_for_boundary_conditions(timesteps)
+ c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
+
+ # 20.4.5. Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
+ noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)
+
+ # 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
+ w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
+ w = w.reshape(bsz, 1, 1, 1)
+ w = w.to(device=latents.device, dtype=latents.dtype)
+
+ # 20.4.8. Prepare prompt embeds and unet_added_conditions
+ prompt_embeds = encoded_text.pop("prompt_embeds")
+
+ # 20.4.9. Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k}
+ noise_pred = unet(
+ noisy_model_input,
+ start_timesteps,
+ timestep_cond=None,
+ encoder_hidden_states=prompt_embeds.float(),
+ added_cond_kwargs=encoded_text,
+ ).sample
+
+ pred_x_0 = predicted_origin(
+ noise_pred,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+
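+ # Consistency-model parameterization of the online student prediction: f_theta(z, t) = c_skip * z + c_out * pred_x_0.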
+ model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
+
+ # 20.4.10. Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after
+ # noisy_latents with both the conditioning embedding c and unconditional embedding 0
+ # Get teacher model prediction on noisy_latents and conditional embedding
+ with torch.no_grad():
+ with torch.autocast("cuda"):
+ cond_teacher_output = teacher_unet(
+ noisy_model_input.to(weight_dtype),
+ start_timesteps,
+ encoder_hidden_states=prompt_embeds.to(weight_dtype),
+ ).sample
+ cond_pred_x0 = predicted_origin(
+ cond_teacher_output,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+
+ # Get teacher model prediction on noisy_latents and unconditional embedding
+ uncond_teacher_output = teacher_unet(
+ noisy_model_input.to(weight_dtype),
+ start_timesteps,
+ encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
+ ).sample
+ uncond_pred_x0 = predicted_origin(
+ uncond_teacher_output,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+
+ # 20.4.11. Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation)
+ pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
+ pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output)
+ x_prev = solver.ddim_step(pred_x0, pred_noise, index)
+
+ # 20.4.12. Get target LCM prediction on x_prev, w, c, t_n
+ with torch.no_grad():
+ with torch.autocast("cuda", dtype=weight_dtype):
+ target_noise_pred = unet(
+ x_prev.float(),
+ timesteps,
+ timestep_cond=None,
+ encoder_hidden_states=prompt_embeds.float(),
+ ).sample
+ pred_x_0 = predicted_origin(
+ target_noise_pred,
+ timesteps,
+ x_prev,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+ target = c_skip * x_prev + c_out * pred_x_0
+
+ # 20.4.13. Calculate loss
+ if args.loss_type == "l2":
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ elif args.loss_type == "huber":
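+ # Pseudo-Huber loss: sqrt(d^2 + c^2) - c, which is quadratic near zero and approximately linear for large residuals.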
+ loss = torch.mean(
+ torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
+ )
+
+ # 20.4.14. Backpropagate on the online student model (`unet`)
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=True)
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if global_step % args.validation_steps == 0:
+ log_validation(vae, unet, args, accelerator, weight_dtype, global_step)
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = accelerator.unwrap_model(unet)
+ unet.save_pretrained(args.output_dir)
+ lora_state_dict = get_peft_model_state_dict(unet, adapter_name="default")
+ StableDiffusionPipeline.save_lora_weights(os.path.join(args.output_dir, "unet_lora"), lora_state_dict)
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py b/diffusers/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
new file mode 100644
index 0000000000000000000000000000000000000000..25faedf714b925a772dc48c745ae9ae5514786a6
--- /dev/null
+++ b/diffusers/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
@@ -0,0 +1,1377 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import copy
+import functools
+import gc
+import itertools
+import json
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+from typing import List, Union
+
+import accelerate
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import torchvision.transforms.functional as TF
+import transformers
+import webdataset as wds
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from braceexpand import braceexpand
+from huggingface_hub import create_repo
+from packaging import version
+from peft import LoraConfig, get_peft_model, get_peft_model_state_dict
+from torch.utils.data import default_collate
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+from webdataset.tariterators import (
+ base_plus_ext,
+ tar_file_expander,
+ url_opener,
+ valid_sample,
+)
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ LCMScheduler,
+ StableDiffusionXLPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+MAX_SEQ_LENGTH = 77
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.18.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def get_module_kohya_state_dict(module, prefix: str, dtype: torch.dtype, adapter_name: str = "default"):
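+ # Rename PEFT LoRA keys ("lora_A"/"lora_B") to the Kohya-ss convention ("lora_down"/"lora_up") and add
+ # per-module alpha entries so the validation pipeline can load them with `load_lora_weights`.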
+ kohya_ss_state_dict = {}
+ for peft_key, weight in get_peft_model_state_dict(module, adapter_name=adapter_name).items():
+ kohya_key = peft_key.replace("base_model.model", prefix)
+ kohya_key = kohya_key.replace("lora_A", "lora_down")
+ kohya_key = kohya_key.replace("lora_B", "lora_up")
+ kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2)
+ kohya_ss_state_dict[kohya_key] = weight.to(dtype)
+
+ # Set alpha parameter
+ if "lora_down" in kohya_key:
+ alpha_key = f'{kohya_key.split(".")[0]}.alpha'
+ kohya_ss_state_dict[alpha_key] = torch.tensor(module.peft_config[adapter_name].lora_alpha).to(dtype)
+
+ return kohya_ss_state_dict
+
+
+def filter_keys(key_set):
+ def _f(dictionary):
+ return {k: v for k, v in dictionary.items() if k in key_set}
+
+ return _f
+
+
+def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
+ """Return function over iterator that groups key, value pairs into samples.
+
+ :param keys: function that splits the key into key and extension (base_plus_ext)
+ :param lcase: convert suffixes to lower case (Default value = True)
+ """
+ current_sample = None
+ for filesample in data:
+ assert isinstance(filesample, dict)
+ fname, value = filesample["fname"], filesample["data"]
+ prefix, suffix = keys(fname)
+ if prefix is None:
+ continue
+ if lcase:
+ suffix = suffix.lower()
+ # FIXME the webdataset version throws if the suffix is already in current_sample, but this can
+ # happen in the current LAION-400M dataset if a tar ends with the same prefix as the next one
+ # begins with; rare, but possible since prefixes aren't unique across tar files in that dataset
+ if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample:
+ if valid_sample(current_sample):
+ yield current_sample
+ current_sample = {"__key__": prefix, "__url__": filesample["__url__"]}
+ if suffixes is None or suffix in suffixes:
+ current_sample[suffix] = value
+ if valid_sample(current_sample):
+ yield current_sample
+
+
+def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue):
+ # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
+ streams = url_opener(src, handler=handler)
+ files = tar_file_expander(streams, handler=handler)
+ samples = group_by_keys_nothrow(files, handler=handler)
+ return samples
+
+
+class WebdatasetFilter:
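+ # Keeps only samples whose LAION-style json metadata reports both sides >= min_size and a watermark probability <= max_pwatermark.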
+ def __init__(self, min_size=1024, max_pwatermark=0.5):
+ self.min_size = min_size
+ self.max_pwatermark = max_pwatermark
+
+ def __call__(self, x):
+ try:
+ if "json" in x:
+ x_json = json.loads(x["json"])
+ filter_size = (x_json.get("original_width", 0.0) or 0.0) >= self.min_size and x_json.get(
+ "original_height", 0
+ ) >= self.min_size
+ filter_watermark = (x_json.get("pwatermark", 1.0) or 1.0) <= self.max_pwatermark
+ return filter_size and filter_watermark
+ else:
+ return False
+ except Exception:
+ return False
+
+
+class Text2ImageDataset:
+ def __init__(
+ self,
+ train_shards_path_or_url: Union[str, List[str]],
+ num_train_examples: int,
+ per_gpu_batch_size: int,
+ global_batch_size: int,
+ num_workers: int,
+ resolution: int = 1024,
+ shuffle_buffer_size: int = 1000,
+ pin_memory: bool = False,
+ persistent_workers: bool = False,
+ use_fix_crop_and_size: bool = False,
+ ):
+ if not isinstance(train_shards_path_or_url, str):
+ train_shards_path_or_url = [list(braceexpand(urls)) for urls in train_shards_path_or_url]
+ # flatten list using itertools
+ train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url))
+
+ def get_orig_size(json):
+ if use_fix_crop_and_size:
+ return (resolution, resolution)
+ else:
+ return (int(json.get("original_width", 0.0)), int(json.get("original_height", 0.0)))
+
+ def transform(example):
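+ # Resize the shorter side to `resolution`, take a random square crop, and normalize pixel values to [-1, 1].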
+ # resize image
+ image = example["image"]
+ image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR)
+
+ # get crop coordinates and crop image
+ c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))
+ image = TF.crop(image, c_top, c_left, resolution, resolution)
+ image = TF.to_tensor(image)
+ image = TF.normalize(image, [0.5], [0.5])
+
+ example["image"] = image
+ example["crop_coords"] = (c_top, c_left) if not use_fix_crop_and_size else (0, 0)
+ return example
+
+ processing_pipeline = [
+ wds.decode("pil", handler=wds.ignore_and_continue),
+ wds.rename(
+ image="jpg;png;jpeg;webp", text="text;txt;caption", orig_size="json", handler=wds.warn_and_continue
+ ),
+ wds.map(filter_keys({"image", "text", "orig_size"})),
+ wds.map_dict(orig_size=get_orig_size),
+ wds.map(transform),
+ wds.to_tuple("image", "text", "orig_size", "crop_coords"),
+ ]
+
+ # Create train dataset and loader
+ pipeline = [
+ wds.ResampledShards(train_shards_path_or_url),
+ tarfile_to_samples_nothrow,
+ wds.select(WebdatasetFilter(min_size=960)),
+ wds.shuffle(shuffle_buffer_size),
+ *processing_pipeline,
+ wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
+ ]
+
+ num_worker_batches = math.ceil(num_train_examples / (global_batch_size * num_workers)) # per dataloader worker
+ num_batches = num_worker_batches * num_workers
+ num_samples = num_batches * global_batch_size
+
+ # each worker is iterating over this
+ self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)
+ self._train_dataloader = wds.WebLoader(
+ self._train_dataset,
+ batch_size=None,
+ shuffle=False,
+ num_workers=num_workers,
+ pin_memory=pin_memory,
+ persistent_workers=persistent_workers,
+ )
+ # add meta-data to dataloader instance for convenience
+ self._train_dataloader.num_batches = num_batches
+ self._train_dataloader.num_samples = num_samples
+
+ @property
+ def train_dataset(self):
+ return self._train_dataset
+
+ @property
+ def train_dataloader(self):
+ return self._train_dataloader
+
+
+def log_validation(vae, unet, args, accelerator, weight_dtype, step):
+ logger.info("Running validation... ")
+
+ unet = accelerator.unwrap_model(unet)
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_teacher_model,
+ vae=vae,
+ scheduler=LCMScheduler.from_pretrained(args.pretrained_teacher_model, subfolder="scheduler"),
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ lora_state_dict = get_module_kohya_state_dict(unet, "lora_unet", weight_dtype)
+ pipeline.load_lora_weights(lora_state_dict)
+ pipeline.fuse_lora()
+
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ validation_prompts = [
+ "portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography",
+ "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+ "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
+ ]
+
+ image_logs = []
+
+ for _, prompt in enumerate(validation_prompts):
+ images = []
+ with torch.autocast("cuda", dtype=weight_dtype):
+ images = pipeline(
+ prompt=prompt,
+ num_inference_steps=4,
+ num_images_per_prompt=4,
+ generator=generator,
+ guidance_scale=0.0,
+ ).images
+ image_logs.append({"validation_prompt": prompt, "images": images})
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ formatted_images = []
+ for image in images:
+ formatted_images.append(np.asarray(image))
+
+ formatted_images = np.stack(formatted_images)
+
+ tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ formatted_images = []
+
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ for image in images:
+ image = wandb.Image(image, caption=validation_prompt)
+ formatted_images.append(image)
+
+ tracker.log({"validation": formatted_images})
+ else:
+ logger.warn(f"image logging not implemented for {tracker.name}")
+
+ del pipeline
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ return image_logs
+
+
+def append_dims(x, target_dims):
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+ dims_to_append = target_dims - x.ndim
+ if dims_to_append < 0:
+ raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
+ return x[(...,) + (None,) * dims_to_append]
+
+
+# From LCMScheduler.get_scalings_for_boundary_condition_discrete
+def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
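+ # Boundary-condition scalings for the consistency model; timestep / 0.1 equals timestep * 10, i.e. the default `timestep_scaling` (the argument itself is not used here).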
+ c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2)
+ c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5
+ return c_skip, c_out
+
+
+# Compare LCMScheduler.step, Step 4
+def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas):
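+ # Recover x_0 from the model output: (x_t - sigma_t * eps) / alpha_t for epsilon-prediction, alpha_t * x_t - sigma_t * v for v-prediction.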
+ if prediction_type == "epsilon":
+ sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
+ alphas = extract_into_tensor(alphas, timesteps, sample.shape)
+ pred_x_0 = (sample - sigmas * model_output) / alphas
+ elif prediction_type == "v_prediction":
+ alphas = extract_into_tensor(alphas, timesteps, sample.shape)
+ sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
+ pred_x_0 = alphas * sample - sigmas * model_output
+ else:
+ raise ValueError(f"Prediction type {prediction_type} currently not supported.")
+
+ return pred_x_0
+
+
+def extract_into_tensor(a, t, x_shape):
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+class DDIMSolver:
+ def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):
+ # DDIM sampling parameters
+ step_ratio = timesteps // ddim_timesteps
+
+ self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1
+ self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
+ self.ddim_alpha_cumprods_prev = np.asarray(
+ [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
+ )
+ # convert to torch tensors
+ self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
+ self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
+ self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev)
+
+ def to(self, device):
+ self.ddim_timesteps = self.ddim_timesteps.to(device)
+ self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
+ self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)
+ return self
+
+ def ddim_step(self, pred_x0, pred_noise, timestep_index):
+ alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)
+ dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
+ x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
+ return x_prev
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ # ----------Model Checkpoint Loading Arguments----------
+ parser.add_argument(
+ "--pretrained_teacher_model",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained LDM teacher model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--teacher_revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained LDM teacher model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained LDM model identifier from huggingface.co/models.",
+ )
+ # ----------Training Arguments----------
+ # ----General Training Arguments----
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="lcm-xl-distilled",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ # ----Logging----
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ # ----Checkpointing----
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ # ----Image Processing----
+ parser.add_argument(
+ "--train_shards_path_or_url",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=1024,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--use_fix_crop_and_size",
+ action="store_true",
+ help="Whether or not to use the fixed crop and size for the teacher model.",
+ default=False,
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ # ----Dataloader----
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ # ----Batch Size and Training Steps----
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ # ----Learning Rate----
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ # ----Optimizer (Adam)----
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ # ----Diffusion Training Arguments----
+ parser.add_argument(
+ "--proportion_empty_prompts",
+ type=float,
+ default=0,
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+ )
+ # ----Latent Consistency Distillation (LCD) Specific Arguments----
+ parser.add_argument(
+ "--w_min",
+ type=float,
+ default=3.0,
+ required=False,
+ help=(
+ "The minimum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG"
+ " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as"
+ " compared to the original paper."
+ ),
+ )
+ parser.add_argument(
+ "--w_max",
+ type=float,
+ default=15.0,
+ required=False,
+ help=(
+ "The maximum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG"
+ " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as"
+ " compared to the original paper."
+ ),
+ )
+ parser.add_argument(
+ "--num_ddim_timesteps",
+ type=int,
+ default=50,
+ help="The number of timesteps to use for DDIM sampling.",
+ )
+ parser.add_argument(
+ "--loss_type",
+ type=str,
+ default="l2",
+ choices=["l2", "huber"],
+ help="The type of loss to use for the LCD loss.",
+ )
+ parser.add_argument(
+ "--huber_c",
+ type=float,
+ default=0.001,
+ help="The huber loss parameter. Only used if `--loss_type=huber`.",
+ )
+ parser.add_argument(
+ "--lora_rank",
+ type=int,
+ default=64,
+ help="The rank of the LoRA projection matrix.",
+ )
+ # ----Mixed Precision----
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--cast_teacher_unet",
+ action="store_true",
+ help="Whether to cast the teacher U-Net to the precision specified by `--mixed_precision`.",
+ )
+ # ----Training Optimizations----
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ # ----Distributed Training----
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ # ----------Validation Arguments----------
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=200,
+ help="Run validation every X steps.",
+ )
+ # ----------Huggingface Hub Arguments-----------
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ # ----------Accelerate Arguments----------
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="text2image-fine-tune",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers for"
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
+
+ return args
+
+
+# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
+def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train=True):
+ prompt_embeds_list = []
+
+ captions = []
+ for caption in prompt_batch:
+ if random.random() < proportion_empty_prompts:
+ captions.append("")
+ elif isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+
+ with torch.no_grad():
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+ text_inputs = tokenizer(
+ captions,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_embeds = text_encoder(
+ text_input_ids.to(text_encoder.device),
+ output_hidden_states=True,
+ )
+
+ # We are only interested in the pooled output of the final text encoder.
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
+ return prompt_embeds, pooled_prompt_embeds
+
+
+def main(args):
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be divided by the number of processes, assuming batches are multiplied by the number of processes.
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
+ exist_ok=True,
+ token=args.hub_token,
+ private=True,
+ ).repo_id
+
+ # 1. Create the noise scheduler and the desired noise schedule.
+ noise_scheduler = DDPMScheduler.from_pretrained(
+ args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
+ )
+
+ # The scheduler calculates the alpha and sigma schedule for us
+ alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
+ sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
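+ # alpha_schedule[t] = sqrt(alpha_bar_t), sigma_schedule[t] = sqrt(1 - alpha_bar_t); `predicted_origin` indexes these per timestep to recover x_0.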
+ solver = DDIMSolver(
+ noise_scheduler.alphas_cumprod.numpy(),
+ timesteps=noise_scheduler.config.num_train_timesteps,
+ ddim_timesteps=args.num_ddim_timesteps,
+ )
+
+ # 2. Load tokenizers from SD-XL checkpoint.
+ tokenizer_one = AutoTokenizer.from_pretrained(
+ args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False
+ )
+ tokenizer_two = AutoTokenizer.from_pretrained(
+ args.pretrained_teacher_model, subfolder="tokenizer_2", revision=args.teacher_revision, use_fast=False
+ )
+
+ # 3. Load text encoders from SD-XL checkpoint.
+ # import correct text encoder classes
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_teacher_model, args.teacher_revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_teacher_model, args.teacher_revision, subfolder="text_encoder_2"
+ )
+
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
+ args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision
+ )
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
+ args.pretrained_teacher_model, subfolder="text_encoder_2", revision=args.teacher_revision
+ )
+
+ # 4. Load VAE from SD-XL checkpoint (or more stable VAE)
+ vae_path = (
+ args.pretrained_teacher_model
+ if args.pretrained_vae_model_name_or_path is None
+ else args.pretrained_vae_model_name_or_path
+ )
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.teacher_revision,
+ )
+
+ # 5. Load teacher U-Net from SD-XL checkpoint
+ teacher_unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
+ )
+
+ # 6. Freeze teacher vae, text_encoders, and teacher_unet
+ vae.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+ teacher_unet.requires_grad_(False)
+
+ # 7. Create online (`unet`) student U-Nets.
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
+ )
+ unet.train()
+
+ # Check that all trainable models are in full precision
+ low_precision_error_string = (
+ " Please make sure to always have all model weights in full float32 precision when starting training - even if"
+ " doing mixed precision training, copy of the weights should still be float32."
+ )
+
+ if accelerator.unwrap_model(unet).dtype != torch.float32:
+ raise ValueError(
+ f"Controlnet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
+ )
+
+ # 8. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer.
+ lora_config = LoraConfig(
+ r=args.lora_rank,
+ target_modules=[
+ "to_q",
+ "to_k",
+ "to_v",
+ "to_out.0",
+ "proj_in",
+ "proj_out",
+ "ff.net.0.proj",
+ "ff.net.2",
+ "conv1",
+ "conv2",
+ "conv_shortcut",
+ "downsamplers.0.conv",
+ "upsamplers.0.conv",
+ "time_emb_proj",
+ ],
+ )
+ unet = get_peft_model(unet, lora_config)
+
+ # 9. Handle mixed precision and device placement
+ # For mixed precision training we cast all non-trainable weights to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ # The VAE is in float32 to avoid NaN losses.
+ vae.to(accelerator.device)
+ if args.pretrained_vae_model_name_or_path is not None:
+ vae.to(dtype=weight_dtype)
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+
+ # Move teacher_unet to device, optionally cast to weight_dtype
+ teacher_unet.to(accelerator.device)
+ if args.cast_teacher_unet:
+ teacher_unet.to(dtype=weight_dtype)
+
+ # Also move the alpha and sigma noise schedules to accelerator.device.
+ alpha_schedule = alpha_schedule.to(accelerator.device)
+ sigma_schedule = sigma_schedule.to(accelerator.device)
+ solver = solver.to(accelerator.device)
+
+ # 10. Handle saving and loading of checkpoints
+ # `accelerate` >= 0.16.0 has better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ unet_ = accelerator.unwrap_model(unet)
+ lora_state_dict = get_peft_model_state_dict(unet_, adapter_name="default")
+ StableDiffusionXLPipeline.save_lora_weights(os.path.join(output_dir, "unet_lora"), lora_state_dict)
+ # save weights in peft format to be able to load them back
+ unet_.save_pretrained(output_dir)
+
+ for _, model in enumerate(models):
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ # load the LoRA into the model
+ unet_ = accelerator.unwrap_model(unet)
+ unet_.load_adapter(input_dir, "default", is_trainable=True)
+
+ for _ in range(len(models)):
+ # pop models so that they are not loaded again
+ models.pop()
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # 11. Enable optimizations
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warn(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ teacher_unet.enable_xformers_memory_efficient_attention()
+ # target_unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # 12. Optimizer creation
+ optimizer = optimizer_class(
+ unet.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # 13. Dataset creation and data processing
+ # Here, we compute not just the text embeddings but also the additional embeddings
+ # needed for the SD XL UNet to operate.
+ def compute_embeddings(
+ prompt_batch, original_sizes, crop_coords, proportion_empty_prompts, text_encoders, tokenizers, is_train=True
+ ):
+ target_size = (args.resolution, args.resolution)
+ original_sizes = list(map(list, zip(*original_sizes)))
+ crops_coords_top_left = list(map(list, zip(*crop_coords)))
+
+ original_sizes = torch.tensor(original_sizes, dtype=torch.long)
+ crops_coords_top_left = torch.tensor(crops_coords_top_left, dtype=torch.long)
+
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
+ prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train
+ )
+ add_text_embeds = pooled_prompt_embeds
+
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
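+ # SD-XL micro-conditioning: time_ids packs (original_size, crops_coords_top_left, target_size) for each sample.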
+ add_time_ids = list(target_size)
+ add_time_ids = torch.tensor([add_time_ids])
+ add_time_ids = add_time_ids.repeat(len(prompt_batch), 1)
+ add_time_ids = torch.cat([original_sizes, crops_coords_top_left, add_time_ids], dim=-1)
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype)
+
+ prompt_embeds = prompt_embeds.to(accelerator.device)
+ add_text_embeds = add_text_embeds.to(accelerator.device)
+ unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+
+ return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}
+
+ dataset = Text2ImageDataset(
+ train_shards_path_or_url=args.train_shards_path_or_url,
+ num_train_examples=args.max_train_samples,
+ per_gpu_batch_size=args.train_batch_size,
+ global_batch_size=args.train_batch_size * accelerator.num_processes,
+ num_workers=args.dataloader_num_workers,
+ resolution=args.resolution,
+ shuffle_buffer_size=1000,
+ pin_memory=True,
+ persistent_workers=True,
+ use_fix_crop_and_size=args.use_fix_crop_and_size,
+ )
+ train_dataloader = dataset.train_dataloader
+
+ # Text embeddings are computed on the fly for each batch with the frozen text encoders.
+ text_encoders = [text_encoder_one, text_encoder_two]
+ tokenizers = [tokenizer_one, tokenizer_two]
+
+ compute_embeddings_fn = functools.partial(
+ compute_embeddings,
+ proportion_empty_prompts=0,
+ text_encoders=text_encoders,
+ tokenizers=tokenizers,
+ )
+
+ # 14. LR Scheduler creation
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps,
+ num_training_steps=args.max_train_steps,
+ )
+
+ # 15. Prepare for training
+ # Prepare everything with our `accelerator`.
+ unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
+
+ # Create uncond embeds for classifier free guidance
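+ # Shapes: 77 tokens x 2048 (= 768 + 1280, the two SD-XL text encoders' hidden sizes concatenated) and 1280 for the pooled embeds.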
+ uncond_prompt_embeds = torch.zeros(args.train_batch_size, 77, 2048).to(accelerator.device)
+ uncond_pooled_prompt_embeds = torch.zeros(args.train_batch_size, 1280).to(accelerator.device)
+
+ # 16. Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num batches each epoch = {train_dataloader.num_batches}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ image, text, orig_size, crop_coords = batch
+
+ image = image.to(accelerator.device, non_blocking=True)
+ encoded_text = compute_embeddings_fn(text, orig_size, crop_coords)
+
+ if args.pretrained_vae_model_name_or_path is not None:
+ pixel_values = image.to(dtype=weight_dtype)
+ if vae.dtype != weight_dtype:
+ vae.to(dtype=weight_dtype)
+ else:
+ pixel_values = image
+
+ # encode pixel values with batch size of at most 8
+ latents = []
+ for i in range(0, pixel_values.shape[0], 8):
+ latents.append(vae.encode(pixel_values[i : i + 8]).latent_dist.sample())
+ latents = torch.cat(latents, dim=0)
+
+ latents = latents * vae.config.scaling_factor
+ if args.pretrained_vae_model_name_or_path is None:
+ latents = latents.to(weight_dtype)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+
+ # Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias.
+ topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
+ index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
+ start_timesteps = solver.ddim_timesteps[index]
+ timesteps = start_timesteps - topk
+ timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
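+ # start_timesteps = t_{n+k} on the DDIM grid; timesteps = t_n is the earlier point the student must be consistent with (clamped at 0).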
+
+ # 20.4.4. Get boundary scalings for start_timesteps and (end) timesteps.
+ c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
+ c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
+ c_skip, c_out = scalings_for_boundary_conditions(timesteps)
+ c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
+
+ # 20.4.5. Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
+ noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)
+
+ # 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
+ w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
+ w = w.reshape(bsz, 1, 1, 1)
+ w = w.to(device=latents.device, dtype=latents.dtype)
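+ # w has shape (bsz, 1, 1, 1) so it broadcasts over channel and spatial dims in the teacher CFG combination below.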
+
+ # 20.4.8. Prepare prompt embeds and unet_added_conditions
+ prompt_embeds = encoded_text.pop("prompt_embeds")
+
+ # 20.4.9. Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k}
+ noise_pred = unet(
+ noisy_model_input,
+ start_timesteps,
+ timestep_cond=None,
+ encoder_hidden_states=prompt_embeds.float(),
+ added_cond_kwargs=encoded_text,
+ ).sample
+
+ pred_x_0 = predicted_origin(
+ noise_pred,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+
+ model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
+
+ # 20.4.10. Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after
+ # noisy_latents with both the conditioning embedding c and unconditional embedding 0
+ # Get teacher model prediction on noisy_latents and conditional embedding
+ with torch.no_grad():
+ with torch.autocast("cuda"):
+ cond_teacher_output = teacher_unet(
+ noisy_model_input.to(weight_dtype),
+ start_timesteps,
+ encoder_hidden_states=prompt_embeds.to(weight_dtype),
+ added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()},
+ ).sample
+ cond_pred_x0 = predicted_origin(
+ cond_teacher_output,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+
+ # Get teacher model prediction on noisy_latents and unconditional embedding
+ uncond_added_conditions = copy.deepcopy(encoded_text)
+ uncond_added_conditions["text_embeds"] = uncond_pooled_prompt_embeds
+ uncond_teacher_output = teacher_unet(
+ noisy_model_input.to(weight_dtype),
+ start_timesteps,
+ encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
+ added_cond_kwargs={k: v.to(weight_dtype) for k, v in uncond_added_conditions.items()},
+ ).sample
+ uncond_pred_x0 = predicted_origin(
+ uncond_teacher_output,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+
+ # 20.4.11. Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation)
+ pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
+ pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output)
+ x_prev = solver.ddim_step(pred_x0, pred_noise, index)
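+ # x_prev approximates the augmented PF-ODE state at t_n; it is used (without gradients) as input for the target prediction below.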
+
+ # 20.4.12. Get target LCM prediction on x_prev, w, c, t_n
+ with torch.no_grad():
+ with torch.autocast("cuda", enabled=True, dtype=weight_dtype):
+ target_noise_pred = unet(
+ x_prev.float(),
+ timesteps,
+ timestep_cond=None,
+ encoder_hidden_states=prompt_embeds.float(),
+ added_cond_kwargs=encoded_text,
+ ).sample
+ pred_x_0 = predicted_origin(
+ target_noise_pred,
+ timesteps,
+ x_prev,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+ target = c_skip * x_prev + c_out * pred_x_0
+
+ # 20.4.13. Calculate loss
+ if args.loss_type == "l2":
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ elif args.loss_type == "huber":
+ loss = torch.mean(
+ torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
+ )
+
+ # 20.4.14. Backpropagate on the online student model (`unet`)
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=True)
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if global_step % args.validation_steps == 0:
+ log_validation(vae, unet, args, accelerator, weight_dtype, global_step)
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = accelerator.unwrap_model(unet)
+ unet.save_pretrained(args.output_dir)
+ lora_state_dict = get_peft_model_state_dict(unet, adapter_name="default")
+ StableDiffusionXLPipeline.save_lora_weights(os.path.join(args.output_dir, "unet_lora"), lora_state_dict)
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/consistency_distillation/train_lcm_distill_sd_wds.py b/diffusers/examples/consistency_distillation/train_lcm_distill_sd_wds.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec4bf432f03d8fd6aa65c8ba168b23dbc7474da8
--- /dev/null
+++ b/diffusers/examples/consistency_distillation/train_lcm_distill_sd_wds.py
@@ -0,0 +1,1302 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import functools
+import gc
+import itertools
+import json
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+from typing import List, Union
+
+import accelerate
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import torchvision.transforms.functional as TF
+import transformers
+import webdataset as wds
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from braceexpand import braceexpand
+from huggingface_hub import create_repo
+from packaging import version
+from torch.utils.data import default_collate
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, CLIPTextModel, PretrainedConfig
+from webdataset.tariterators import (
+ base_plus_ext,
+ tar_file_expander,
+ url_opener,
+ valid_sample,
+)
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ LCMScheduler,
+ StableDiffusionPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+MAX_SEQ_LENGTH = 77
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.18.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def filter_keys(key_set):
+ def _f(dictionary):
+ return {k: v for k, v in dictionary.items() if k in key_set}
+
+ return _f
+
+
+def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
+ """Return function over iterator that groups key, value pairs into samples.
+
+ :param keys: function that splits the key into key and extension (base_plus_ext)
+ :param lcase: convert suffixes to lower case (Default value = True)
+ """
+ current_sample = None
+ for filesample in data:
+ assert isinstance(filesample, dict)
+ fname, value = filesample["fname"], filesample["data"]
+ prefix, suffix = keys(fname)
+ if prefix is None:
+ continue
+ if lcase:
+ suffix = suffix.lower()
+ # FIXME: the webdataset version throws if the suffix is already in current_sample, but this can
+ # happen in the current LAION-400M dataset if one tar ends with the same prefix the next one
+ # begins with. It is rare, but possible, since prefixes aren't unique across tar files in that dataset.
+ if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample:
+ if valid_sample(current_sample):
+ yield current_sample
+ current_sample = {"__key__": prefix, "__url__": filesample["__url__"]}
+ if suffixes is None or suffix in suffixes:
+ current_sample[suffix] = value
+ if valid_sample(current_sample):
+ yield current_sample
+
+
+def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue):
+ # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
+ streams = url_opener(src, handler=handler)
+ files = tar_file_expander(streams, handler=handler)
+ samples = group_by_keys_nothrow(files, handler=handler)
+ return samples
+
+
+class WebdatasetFilter:
+ def __init__(self, min_size=1024, max_pwatermark=0.5):
+ self.min_size = min_size
+ self.max_pwatermark = max_pwatermark
+
+ def __call__(self, x):
+ try:
+ if "json" in x:
+ x_json = json.loads(x["json"])
+ filter_size = (x_json.get("original_width", 0.0) or 0.0) >= self.min_size and x_json.get(
+ "original_height", 0
+ ) >= self.min_size
+ filter_watermark = (x_json.get("pwatermark", 1.0) or 1.0) <= self.max_pwatermark
+ return filter_size and filter_watermark
+ else:
+ return False
+ except Exception:
+ return False
+
+
+class Text2ImageDataset:
+ def __init__(
+ self,
+ train_shards_path_or_url: Union[str, List[str]],
+ num_train_examples: int,
+ per_gpu_batch_size: int,
+ global_batch_size: int,
+ num_workers: int,
+ resolution: int = 512,
+ shuffle_buffer_size: int = 1000,
+ pin_memory: bool = False,
+ persistent_workers: bool = False,
+ ):
+ if not isinstance(train_shards_path_or_url, str):
+ train_shards_path_or_url = [list(braceexpand(urls)) for urls in train_shards_path_or_url]
+ # flatten list using itertools
+ train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url))
+
+ def transform(example):
+ # resize image
+ image = example["image"]
+ image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR)
+
+ # get crop coordinates and crop image
+ c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))
+ image = TF.crop(image, c_top, c_left, resolution, resolution)
+ image = TF.to_tensor(image)
+ image = TF.normalize(image, [0.5], [0.5])
+
+ example["image"] = image
+ return example
+
+ processing_pipeline = [
+ wds.decode("pil", handler=wds.ignore_and_continue),
+ wds.rename(image="jpg;png;jpeg;webp", text="text;txt;caption", handler=wds.warn_and_continue),
+ wds.map(filter_keys({"image", "text"})),
+ wds.map(transform),
+ wds.to_tuple("image", "text"),
+ ]
+
+ # Create train dataset and loader
+ pipeline = [
+ wds.ResampledShards(train_shards_path_or_url),
+ tarfile_to_samples_nothrow,
+ wds.shuffle(shuffle_buffer_size),
+ *processing_pipeline,
+ wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
+ ]
+
+ num_worker_batches = math.ceil(num_train_examples / (global_batch_size * num_workers)) # per dataloader worker
+ num_batches = num_worker_batches * num_workers
+ num_samples = num_batches * global_batch_size
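+ # NOTE: this sizing assumes num_workers > 0; with num_workers == 0 the computation above would divide by zero.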
+
+ # each worker is iterating over this
+ self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)
+ self._train_dataloader = wds.WebLoader(
+ self._train_dataset,
+ batch_size=None,
+ shuffle=False,
+ num_workers=num_workers,
+ pin_memory=pin_memory,
+ persistent_workers=persistent_workers,
+ )
+ # add meta-data to dataloader instance for convenience
+ self._train_dataloader.num_batches = num_batches
+ self._train_dataloader.num_samples = num_samples
+
+ @property
+ def train_dataset(self):
+ return self._train_dataset
+
+ @property
+ def train_dataloader(self):
+ return self._train_dataloader
+
+
+def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="target"):
+ logger.info("Running validation... ")
+
+ unet = accelerator.unwrap_model(unet)
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_teacher_model,
+ vae=vae,
+ unet=unet,
+ scheduler=LCMScheduler.from_pretrained(args.pretrained_teacher_model, subfolder="scheduler"),
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ validation_prompts = [
+ "portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography",
+ "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+ "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
+ ]
+
+ image_logs = []
+
+ for _, prompt in enumerate(validation_prompts):
+ images = []
+ with torch.autocast("cuda"):
+ images = pipeline(
+ prompt=prompt,
+ num_inference_steps=4,
+ num_images_per_prompt=4,
+ generator=generator,
+ ).images
+ image_logs.append({"validation_prompt": prompt, "images": images})
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ formatted_images = []
+ for image in images:
+ formatted_images.append(np.asarray(image))
+
+ formatted_images = np.stack(formatted_images)
+
+ tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ formatted_images = []
+
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ for image in images:
+ image = wandb.Image(image, caption=validation_prompt)
+ formatted_images.append(image)
+
+ tracker.log({f"validation/{name}": formatted_images})
+ else:
+ logger.warn(f"image logging not implemented for {tracker.name}")
+
+ del pipeline
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ return image_logs
+
+
+# From LatentConsistencyModel.get_guidance_scale_embedding
+def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ w (`torch.Tensor`):
+ guidance scale values at which to generate embedding vectors
+ embedding_dim (`int`, *optional*, defaults to 512):
+ dimension of the embeddings to generate
+ dtype:
+ data type of the generated embeddings
+
+ Returns:
+ `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
+
+def append_dims(x, target_dims):
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+ dims_to_append = target_dims - x.ndim
+ if dims_to_append < 0:
+ raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
+ return x[(...,) + (None,) * dims_to_append]
+
+
+# From LCMScheduler.get_scalings_for_boundary_condition_discrete
+def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
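+ # Note: `timestep / 0.1` equals `timestep * timestep_scaling` with the default timestep_scaling=10.0; the argument is otherwise unused here.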
+ c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2)
+ c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5
+ return c_skip, c_out
+
+
+# Compare LCMScheduler.step, Step 4
+def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas):
+ if prediction_type == "epsilon":
+ sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
+ alphas = extract_into_tensor(alphas, timesteps, sample.shape)
+ pred_x_0 = (sample - sigmas * model_output) / alphas
+ elif prediction_type == "v_prediction":
+ pred_x_0 = alphas[timesteps] * sample - sigmas[timesteps] * model_output
+ else:
+ raise ValueError(f"Prediction type {prediction_type} currently not supported.")
+
+ return pred_x_0
+
+
+def extract_into_tensor(a, t, x_shape):
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+class DDIMSolver:
+ def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):
+ # DDIM sampling parameters
+ step_ratio = timesteps // ddim_timesteps
+ self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1
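+ # e.g. with timesteps=1000 and ddim_timesteps=50 this yields [19, 39, ..., 999], i.e. points spaced step_ratio=20 apart.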
+ self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
+ self.ddim_alpha_cumprods_prev = np.asarray(
+ [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
+ )
+ # convert to torch tensors
+ self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
+ self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
+ self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev)
+
+ def to(self, device):
+ self.ddim_timesteps = self.ddim_timesteps.to(device)
+ self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
+ self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)
+ return self
+
+ def ddim_step(self, pred_x0, pred_noise, timestep_index):
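+ # Deterministic DDIM update: x_prev = sqrt(alpha_bar_prev) * x0_pred + sqrt(1 - alpha_bar_prev) * noise_pred.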
+ alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)
+ dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
+ x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
+ return x_prev
+
+
+@torch.no_grad()
+def update_ema(target_params, source_params, rate=0.99):
+ """
+ Update target parameters to be closer to those of source parameters using
+ an exponential moving average.
+
+ :param target_params: the target parameter sequence.
+ :param source_params: the source parameter sequence.
+ :param rate: the EMA rate (closer to 1 means slower).
+ """
+ for targ, src in zip(target_params, source_params):
+ targ.detach().mul_(rate).add_(src, alpha=1 - rate)
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ # ----------Model Checkpoint Loading Arguments----------
+ parser.add_argument(
+ "--pretrained_teacher_model",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained LDM teacher model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--teacher_revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained LDM teacher model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained LDM model identifier from huggingface.co/models.",
+ )
+ # ----------Training Arguments----------
+ # ----General Training Arguments----
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="lcm-xl-distilled",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ # ----Logging----
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ # ----Checkpointing----
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ # ----Image Processing----
+ parser.add_argument(
+ "--train_shards_path_or_url",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ # ----Dataloader----
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ # ----Batch Size and Training Steps----
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ # ----Learning Rate----
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ # ----Optimizer (Adam)----
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ # ----Diffusion Training Arguments----
+ parser.add_argument(
+ "--proportion_empty_prompts",
+ type=float,
+ default=0,
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+ )
+ # ----Latent Consistency Distillation (LCD) Specific Arguments----
+ parser.add_argument(
+ "--w_min",
+ type=float,
+ default=5.0,
+ required=False,
+ help=(
+ "The minimum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG"
+ " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as"
+ " compared to the original paper."
+ ),
+ )
+ parser.add_argument(
+ "--w_max",
+ type=float,
+ default=15.0,
+ required=False,
+ help=(
+ "The maximum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG"
+ " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as"
+ " compared to the original paper."
+ ),
+ )
+ parser.add_argument(
+ "--num_ddim_timesteps",
+ type=int,
+ default=50,
+ help="The number of timesteps to use for DDIM sampling.",
+ )
+ parser.add_argument(
+ "--unet_time_cond_proj_dim",
+ type=int,
+ default=256,
+ help=(
+ "The dimension of the guidance-scale (`time_cond`) embedding projection in the student U-Net, used when"
+ " the teacher U-Net does not set `time_cond_proj_dim`."
+ ),
+ )
+ parser.add_argument(
+ "--loss_type",
+ type=str,
+ default="l2",
+ choices=["l2", "huber"],
+ help="The type of loss to use for the LCD loss.",
+ )
+ parser.add_argument(
+ "--huber_c",
+ type=float,
+ default=0.001,
+ help="The huber loss parameter. Only used if `--loss_type=huber`.",
+ )
+ # ----Exponential Moving Average (EMA)----
+ parser.add_argument(
+ "--ema_decay",
+ type=float,
+ default=0.95,
+ required=False,
+ help="The exponential moving average (EMA) rate or decay factor.",
+ )
+ # ----Mixed Precision----
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--cast_teacher_unet",
+ action="store_true",
+ help="Whether to cast the teacher U-Net to the precision specified by `--mixed_precision`.",
+ )
+ # ----Training Optimizations----
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ # ----Distributed Training----
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ # ----------Validation Arguments----------
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=200,
+ help="Run validation every X steps.",
+ )
+ # ----------Huggingface Hub Arguments-----------
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ # ----------Accelerate Arguments----------
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="text2image-fine-tune",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers for"
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
+
+ return args
+
+
+# Adapted from pipelines.StableDiffusionPipeline.encode_prompt
+def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train=True):
+ captions = []
+ for caption in prompt_batch:
+ if random.random() < proportion_empty_prompts:
+ captions.append("")
+ elif isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+
+ with torch.no_grad():
+ text_inputs = tokenizer(
+ captions,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device))[0]
+
+ return prompt_embeds
+
+
+def main(args):
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be divided by the number of processes, assuming batches are multiplied by the number of processes.
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
+ exist_ok=True,
+ token=args.hub_token,
+ private=True,
+ ).repo_id
+
+ # 1. Create the noise scheduler and the desired noise schedule.
+ noise_scheduler = DDPMScheduler.from_pretrained(
+ args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
+ )
+
+ # The scheduler calculates the alpha and sigma schedule for us
+ alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
+ sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
+ solver = DDIMSolver(
+ noise_scheduler.alphas_cumprod.numpy(),
+ timesteps=noise_scheduler.config.num_train_timesteps,
+ ddim_timesteps=args.num_ddim_timesteps,
+ )
+
+ # 2. Load the tokenizer from the SD-1.5 teacher checkpoint.
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False
+ )
+
+ # 3. Load the text encoder from the SD-1.5 checkpoint.
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision
+ )
+
+ # 4. Load the VAE from the SD-1.5 teacher checkpoint
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_teacher_model,
+ subfolder="vae",
+ revision=args.teacher_revision,
+ )
+
+ # 5. Load the teacher U-Net from the SD-1.5 checkpoint
+ teacher_unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
+ )
+
+ # 6. Freeze teacher vae, text_encoder, and teacher_unet
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+ teacher_unet.requires_grad_(False)
+
+ # 7. Create the online (`unet`) student U-Net. It will be updated by the optimizer (i.e. via backpropagation).
+ # Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None
+ if teacher_unet.config.time_cond_proj_dim is None:
+ teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim
+ unet = UNet2DConditionModel(**teacher_unet.config)
+ # load teacher_unet weights into unet
+ unet.load_state_dict(teacher_unet.state_dict(), strict=False)
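+ # strict=False: the student's newly added guidance-scale (time_cond) projection has no counterpart in the teacher weights.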
+ unet.train()
+
+ # 8. Create the target (`target_unet`) student U-Net. Its parameters are updated via EMA (Polyak averaging), not by the optimizer.
+ # Initialize from unet
+ target_unet = UNet2DConditionModel(**teacher_unet.config)
+ target_unet.load_state_dict(unet.state_dict())
+ target_unet.train()
+ target_unet.requires_grad_(False)
+
+ # Check that all trainable models are in full precision
+ low_precision_error_string = (
+ " Please make sure to always have all model weights in full float32 precision when starting training - even if"
+ " doing mixed precision training, copy of the weights should still be float32."
+ )
+
+ if accelerator.unwrap_model(unet).dtype != torch.float32:
+ raise ValueError(
+ f"Controlnet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
+ )
+
+ # 9. Handle mixed precision and device placement
+ # For mixed precision training we cast all non-trainable weights to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ # The VAE is in float32 to avoid NaN losses.
+ vae.to(accelerator.device)
+ if args.pretrained_vae_model_name_or_path is not None:
+ vae.to(dtype=weight_dtype)
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ # Move target_unet and teacher_unet to device; optionally cast the teacher to weight_dtype
+ target_unet.to(accelerator.device)
+ teacher_unet.to(accelerator.device)
+ if args.cast_teacher_unet:
+ teacher_unet.to(dtype=weight_dtype)
+
+ # Also move the alpha and sigma noise schedules to accelerator.device.
+ alpha_schedule = alpha_schedule.to(accelerator.device)
+ sigma_schedule = sigma_schedule.to(accelerator.device)
+ solver = solver.to(accelerator.device)
+
+ # 10. Handle saving and loading of checkpoints
+ # `accelerate` >= 0.16.0 has better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ target_unet.save_pretrained(os.path.join(output_dir, "unet_target"))
+
+ for i, model in enumerate(models):
+ model.save_pretrained(os.path.join(output_dir, "unet"))
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ load_model = UNet2DConditionModel.from_pretrained(os.path.join(input_dir, "unet_target"))
+ target_unet.load_state_dict(load_model.state_dict())
+ target_unet.to(accelerator.device)
+ del load_model
+
+ for i in range(len(models)):
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # 11. Enable optimizations
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warn(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ teacher_unet.enable_xformers_memory_efficient_attention()
+ target_unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # 13. Optimizer creation
+ optimizer = optimizer_class(
+ unet.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Compute the text embeddings used to condition the online student and teacher U-Nets.
+ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tokenizer, is_train=True):
+ prompt_embeds = encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train)
+ return {"prompt_embeds": prompt_embeds}
+
+ dataset = Text2ImageDataset(
+ train_shards_path_or_url=args.train_shards_path_or_url,
+ num_train_examples=args.max_train_samples,
+ per_gpu_batch_size=args.train_batch_size,
+ global_batch_size=args.train_batch_size * accelerator.num_processes,
+ num_workers=args.dataloader_num_workers,
+ resolution=args.resolution,
+ shuffle_buffer_size=1000,
+ pin_memory=True,
+ persistent_workers=True,
+ )
+ train_dataloader = dataset.train_dataloader
+
+ compute_embeddings_fn = functools.partial(
+ compute_embeddings,
+ proportion_empty_prompts=0,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps,
+ num_training_steps=args.max_train_steps,
+ )
+
+ # Prepare everything with our `accelerator`.
+ unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
+
+ uncond_input_ids = tokenizer(
+ [""] * args.train_batch_size, return_tensors="pt", padding="max_length", max_length=77
+ ).input_ids.to(accelerator.device)
+ uncond_prompt_embeds = text_encoder(uncond_input_ids)[0]
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num batches each epoch = {train_dataloader.num_batches}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ image, text, _, _ = batch
+
+ image = image.to(accelerator.device, non_blocking=True)
+ encoded_text = compute_embeddings_fn(text)
+
+ pixel_values = image.to(dtype=weight_dtype)
+ if vae.dtype != weight_dtype:
+ vae.to(dtype=weight_dtype)
+
+ # encode pixel values with batch size of at most 32
+ latents = []
+ for i in range(0, pixel_values.shape[0], 32):
+ latents.append(vae.encode(pixel_values[i : i + 32]).latent_dist.sample())
+ latents = torch.cat(latents, dim=0)
+
+ latents = latents * vae.config.scaling_factor
+ latents = latents.to(weight_dtype)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+
+ # Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias.
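+ # Each sample picks a random index n into the DDIM grid; start_timesteps = t_{n+k} is the noising timestep fed
+ # to the student and teacher, and timesteps = t_n = t_{n+k} - topk is the target timestep one DDIM grid step
+ # earlier (clamped at 0), where topk = num_train_timesteps // num_ddim_timesteps is the grid stride.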
+ topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
+ index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
+ start_timesteps = solver.ddim_timesteps[index]
+ timesteps = start_timesteps - topk
+ timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
+
+ # 20.4.4. Get boundary scalings for start_timesteps and (end) timesteps.
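+ # From scalings_for_boundary_conditions with sigma_data = 0.5:
+ # c_skip(t) = sigma_data^2 / ((t / 0.1)^2 + sigma_data^2), c_out(t) = (t / 0.1) / sqrt((t / 0.1)^2 + sigma_data^2),
+ # so at t = 0 we get c_skip = 1 and c_out = 0, i.e. the consistency model reduces to the identity at the boundary.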
+ c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
+ c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
+ c_skip, c_out = scalings_for_boundary_conditions(timesteps)
+ c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
+
+ # 20.4.5. Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
+ noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)
+
+ # 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
+ w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
+ w_embedding = guidance_scale_embedding(w, embedding_dim=args.unet_time_cond_proj_dim)
+ w = w.reshape(bsz, 1, 1, 1)
+ # Move to U-Net device and dtype
+ w = w.to(device=latents.device, dtype=latents.dtype)
+ w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype)
+
+ # 20.4.8. Prepare prompt embeds and unet_added_conditions
+ prompt_embeds = encoded_text.pop("prompt_embeds")
+
+ # 20.4.9. Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k}
+ noise_pred = unet(
+ noisy_model_input,
+ start_timesteps,
+ timestep_cond=w_embedding,
+ encoder_hidden_states=prompt_embeds.float(),
+ added_cond_kwargs=encoded_text,
+ ).sample
+
+ pred_x_0 = predicted_origin(
+ noise_pred,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+
+ model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
+
+ # 20.4.10. Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after
+ # noisy_latents with both the conditioning embedding c and unconditional embedding 0
+ # Get teacher model prediction on noisy_latents and conditional embedding
+ with torch.no_grad():
+ with torch.autocast("cuda"):
+ cond_teacher_output = teacher_unet(
+ noisy_model_input.to(weight_dtype),
+ start_timesteps,
+ encoder_hidden_states=prompt_embeds.to(weight_dtype),
+ ).sample
+ cond_pred_x0 = predicted_origin(
+ cond_teacher_output,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+
+ # Get teacher model prediction on noisy_latents and unconditional embedding
+ uncond_teacher_output = teacher_unet(
+ noisy_model_input.to(weight_dtype),
+ start_timesteps,
+ encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
+ ).sample
+ uncond_pred_x0 = predicted_origin(
+ uncond_teacher_output,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+
+ # 20.4.11. Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation)
+ pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
+ pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output)
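+ # solver.ddim_step applies the deterministic (eta = 0) DDIM update:
+ # x_prev = sqrt(alpha_bar_prev) * pred_x0 + sqrt(1 - alpha_bar_prev) * pred_noise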
+ x_prev = solver.ddim_step(pred_x0, pred_noise, index)
+
+ # 20.4.12. Get target LCM prediction on x_prev, w, c, t_n
+ with torch.no_grad():
+ with torch.autocast("cuda", dtype=weight_dtype):
+ target_noise_pred = target_unet(
+ x_prev.float(),
+ timesteps,
+ timestep_cond=w_embedding,
+ encoder_hidden_states=prompt_embeds.float(),
+ ).sample
+ pred_x_0 = predicted_origin(
+ target_noise_pred,
+ timesteps,
+ x_prev,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+ target = c_skip * x_prev + c_out * pred_x_0
+
+ # 20.4.13. Calculate loss
+ if args.loss_type == "l2":
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ elif args.loss_type == "huber":
+ loss = torch.mean(
+ torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
+ )
+
+ # 20.4.14. Backpropagate on the online student model (`unet`)
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=True)
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ # 20.4.15. Make EMA update to target student model parameters
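+ # update_ema performs target_param <- args.ema_decay * target_param + (1 - args.ema_decay) * online_param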
+ update_ema(target_unet.parameters(), unet.parameters(), args.ema_decay)
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if global_step % args.validation_steps == 0:
+ log_validation(vae, target_unet, args, accelerator, weight_dtype, global_step, "target")
+ log_validation(vae, unet, args, accelerator, weight_dtype, global_step, "online")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = accelerator.unwrap_model(unet)
+ unet.save_pretrained(os.path.join(args.output_dir, "unet"))
+
+ target_unet = accelerator.unwrap_model(target_unet)
+ target_unet.save_pretrained(os.path.join(args.output_dir, "unet_target"))
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py b/diffusers/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d2b1e10320856bbf481ea003dec3ff45d32ffdf
--- /dev/null
+++ b/diffusers/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
@@ -0,0 +1,1399 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import copy
+import functools
+import gc
+import itertools
+import json
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+from typing import List, Union
+
+import accelerate
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import torchvision.transforms.functional as TF
+import transformers
+import webdataset as wds
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from braceexpand import braceexpand
+from huggingface_hub import create_repo
+from packaging import version
+from torch.utils.data import default_collate
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+from webdataset.tariterators import (
+ base_plus_ext,
+ tar_file_expander,
+ url_opener,
+ valid_sample,
+)
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ LCMScheduler,
+ StableDiffusionXLPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+MAX_SEQ_LENGTH = 77
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.18.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def filter_keys(key_set):
+ def _f(dictionary):
+ return {k: v for k, v in dictionary.items() if k in key_set}
+
+ return _f
+
+
+def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
+ """Return function over iterator that groups key, value pairs into samples.
+
+ :param keys: function that splits the key into key and extension (base_plus_ext)
+ :param lcase: convert suffixes to lower case (Default value = True)
+ """
+ current_sample = None
+ for filesample in data:
+ assert isinstance(filesample, dict)
+ fname, value = filesample["fname"], filesample["data"]
+ prefix, suffix = keys(fname)
+ if prefix is None:
+ continue
+ if lcase:
+ suffix = suffix.lower()
+ # FIXME webdataset version throws if suffix in current_sample, but we have a potential for
+ # this happening in the current LAION400m dataset if a tar ends with same prefix as the next
+ # begins; rare, but can happen since prefixes aren't unique across tar files in that dataset
+ if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample:
+ if valid_sample(current_sample):
+ yield current_sample
+ current_sample = {"__key__": prefix, "__url__": filesample["__url__"]}
+ if suffixes is None or suffix in suffixes:
+ current_sample[suffix] = value
+ if valid_sample(current_sample):
+ yield current_sample
+
+
+def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue):
+ # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
+ streams = url_opener(src, handler=handler)
+ files = tar_file_expander(streams, handler=handler)
+ samples = group_by_keys_nothrow(files, handler=handler)
+ return samples
+
+
+class WebdatasetFilter:
+ def __init__(self, min_size=1024, max_pwatermark=0.5):
+ self.min_size = min_size
+ self.max_pwatermark = max_pwatermark
+
+ def __call__(self, x):
+ try:
+ if "json" in x:
+ x_json = json.loads(x["json"])
+ filter_size = (x_json.get("original_width", 0.0) or 0.0) >= self.min_size and x_json.get(
+ "original_height", 0
+ ) >= self.min_size
+ filter_watermark = (x_json.get("pwatermark", 1.0) or 1.0) <= self.max_pwatermark
+ return filter_size and filter_watermark
+ else:
+ return False
+ except Exception:
+ return False
+
+
+class Text2ImageDataset:
+ def __init__(
+ self,
+ train_shards_path_or_url: Union[str, List[str]],
+ num_train_examples: int,
+ per_gpu_batch_size: int,
+ global_batch_size: int,
+ num_workers: int,
+ resolution: int = 1024,
+ shuffle_buffer_size: int = 1000,
+ pin_memory: bool = False,
+ persistent_workers: bool = False,
+ use_fix_crop_and_size: bool = False,
+ ):
+ if not isinstance(train_shards_path_or_url, str):
+ train_shards_path_or_url = [list(braceexpand(urls)) for urls in train_shards_path_or_url]
+ # flatten list using itertools
+ train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url))
+
+ def get_orig_size(json):
+ if use_fix_crop_and_size:
+ return (resolution, resolution)
+ else:
+ return (int(json.get("original_width", 0.0)), int(json.get("original_height", 0.0)))
+
+ def transform(example):
+ # resize image
+ image = example["image"]
+ image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR)
+
+ # get crop coordinates and crop image
+ c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))
+ image = TF.crop(image, c_top, c_left, resolution, resolution)
+ image = TF.to_tensor(image)
+ image = TF.normalize(image, [0.5], [0.5])
+
+ example["image"] = image
+ example["crop_coords"] = (c_top, c_left) if not use_fix_crop_and_size else (0, 0)
+ return example
+
+ processing_pipeline = [
+ wds.decode("pil", handler=wds.ignore_and_continue),
+ wds.rename(
+ image="jpg;png;jpeg;webp", text="text;txt;caption", orig_size="json", handler=wds.warn_and_continue
+ ),
+ wds.map(filter_keys({"image", "text", "orig_size"})),
+ wds.map_dict(orig_size=get_orig_size),
+ wds.map(transform),
+ wds.to_tuple("image", "text", "orig_size", "crop_coords"),
+ ]
+
+ # Create train dataset and loader
+ pipeline = [
+ wds.ResampledShards(train_shards_path_or_url),
+ tarfile_to_samples_nothrow,
+ wds.select(WebdatasetFilter(min_size=960)),
+ wds.shuffle(shuffle_buffer_size),
+ *processing_pipeline,
+ wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
+ ]
+
+ num_worker_batches = math.ceil(num_train_examples / (global_batch_size * num_workers)) # per dataloader worker
+ num_batches = num_worker_batches * num_workers
+ num_samples = num_batches * global_batch_size
+
+ # each worker is iterating over this
+ self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)
+ self._train_dataloader = wds.WebLoader(
+ self._train_dataset,
+ batch_size=None,
+ shuffle=False,
+ num_workers=num_workers,
+ pin_memory=pin_memory,
+ persistent_workers=persistent_workers,
+ )
+ # add meta-data to dataloader instance for convenience
+ self._train_dataloader.num_batches = num_batches
+ self._train_dataloader.num_samples = num_samples
+
+ @property
+ def train_dataset(self):
+ return self._train_dataset
+
+ @property
+ def train_dataloader(self):
+ return self._train_dataloader
+
+
+def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="target"):
+ logger.info("Running validation... ")
+
+ unet = accelerator.unwrap_model(unet)
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_teacher_model,
+ vae=vae,
+ unet=unet,
+ scheduler=LCMScheduler.from_pretrained(args.pretrained_teacher_model, subfolder="scheduler"),
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ validation_prompts = [
+ "portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography",
+ "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+ "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
+ ]
+
+ image_logs = []
+
+ for _, prompt in enumerate(validation_prompts):
+ images = []
+ with torch.autocast("cuda"):
+ images = pipeline(
+ prompt=prompt,
+ num_inference_steps=4,
+ num_images_per_prompt=4,
+ generator=generator,
+ ).images
+ image_logs.append({"validation_prompt": prompt, "images": images})
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ formatted_images = []
+ for image in images:
+ formatted_images.append(np.asarray(image))
+
+ formatted_images = np.stack(formatted_images)
+
+ tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ formatted_images = []
+
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ for image in images:
+ image = wandb.Image(image, caption=validation_prompt)
+ formatted_images.append(image)
+
+ tracker.log({f"validation/{name}": formatted_images})
+ else:
+ logger.warn(f"image logging not implemented for {tracker.name}")
+
+ del pipeline
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ return image_logs
+
+
+def append_dims(x, target_dims):
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+ dims_to_append = target_dims - x.ndim
+ if dims_to_append < 0:
+ raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
+ return x[(...,) + (None,) * dims_to_append]
+
+
+# From LCMScheduler.get_scalings_for_boundary_condition_discrete
+def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
+ c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2)
+ c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5
+ return c_skip, c_out
+
+
+# Compare LCMScheduler.step, Step 4
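+# For epsilon-prediction: x_0 = (x_t - sigma_t * eps) / alpha_t, where alpha_t = sqrt(alphas_cumprod[t]) and
+# sigma_t = sqrt(1 - alphas_cumprod[t]) (the alpha/sigma schedules built from the teacher's DDPMScheduler below).
+# For v-prediction: x_0 = alpha_t * x_t - sigma_t * v.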
+def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas):
+ if prediction_type == "epsilon":
+ sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
+ alphas = extract_into_tensor(alphas, timesteps, sample.shape)
+ pred_x_0 = (sample - sigmas * model_output) / alphas
+ elif prediction_type == "v_prediction":
+ sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
+ alphas = extract_into_tensor(alphas, timesteps, sample.shape)
+ pred_x_0 = alphas * sample - sigmas * model_output
+ else:
+ raise ValueError(f"Prediction type {prediction_type} currently not supported.")
+
+ return pred_x_0
+
+
+def extract_into_tensor(a, t, x_shape):
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+@torch.no_grad()
+def update_ema(target_params, source_params, rate=0.99):
+ """
+ Update target parameters to be closer to those of source parameters using
+ an exponential moving average.
+
+ :param target_params: the target parameter sequence.
+ :param source_params: the source parameter sequence.
+ :param rate: the EMA rate (closer to 1 means slower).
+ """
+ for targ, src in zip(target_params, source_params):
+ targ.detach().mul_(rate).add_(src, alpha=1 - rate)
+
+
+# From LatentConsistencyModel.get_guidance_scale_embedding
+def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ w (`torch.Tensor`):
+ guidance scale values at which to generate embedding vectors
+ embedding_dim (`int`, *optional*, defaults to 512):
+ dimension of the embeddings to generate
+ dtype:
+ data type of the generated embeddings
+
+ Returns:
+ `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
+
+class DDIMSolver:
+ def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):
+ # DDIM sampling parameters
+ step_ratio = timesteps // ddim_timesteps
+
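+ # e.g. timesteps=1000, ddim_timesteps=50 gives step_ratio=20 and a grid [19, 39, ..., 999],
+ # i.e. ddim_timesteps evenly spaced points ending at the last training timestep.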
+ self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1
+ self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
+ self.ddim_alpha_cumprods_prev = np.asarray(
+ [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
+ )
+ # convert to torch tensors
+ self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
+ self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
+ self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev)
+
+ def to(self, device):
+ self.ddim_timesteps = self.ddim_timesteps.to(device)
+ self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
+ self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)
+ return self
+
+ def ddim_step(self, pred_x0, pred_noise, timestep_index):
+ alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)
+ dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
+ x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
+ return x_prev
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ # ----------Model Checkpoint Loading Arguments----------
+ parser.add_argument(
+ "--pretrained_teacher_model",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained LDM teacher model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--teacher_revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained LDM teacher model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained LDM model identifier from huggingface.co/models.",
+ )
+ # ----------Training Arguments----------
+ # ----General Training Arguments----
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="lcm-xl-distilled",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ # ----Logging----
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ # ----Checkpointing----
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ # ----Image Processing----
+ parser.add_argument(
+ "--train_shards_path_or_url",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=1024,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--use_fix_crop_and_size",
+ action="store_true",
+ help="Whether or not to use the fixed crop and size for the teacher model.",
+ default=False,
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ # ----Dataloader----
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ # ----Batch Size and Training Steps----
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ # ----Learning Rate----
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ # ----Optimizer (Adam)----
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ # ----Diffusion Training Arguments----
+ parser.add_argument(
+ "--proportion_empty_prompts",
+ type=float,
+ default=0,
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+ )
+ # ----Latent Consistency Distillation (LCD) Specific Arguments----
+ parser.add_argument(
+ "--w_min",
+ type=float,
+ default=3.0,
+ required=False,
+ help=(
+ "The minimum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG"
+ " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as"
+ " compared to the original paper."
+ ),
+ )
+ parser.add_argument(
+ "--w_max",
+ type=float,
+ default=15.0,
+ required=False,
+ help=(
+ "The maximum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG"
+ " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as"
+ " compared to the original paper."
+ ),
+ )
+ parser.add_argument(
+ "--num_ddim_timesteps",
+ type=int,
+ default=50,
+ help="The number of timesteps to use for DDIM sampling.",
+ )
+ parser.add_argument(
+ "--loss_type",
+ type=str,
+ default="l2",
+ choices=["l2", "huber"],
+ help="The type of loss to use for the LCD loss.",
+ )
+ parser.add_argument(
+ "--huber_c",
+ type=float,
+ default=0.001,
+ help="The huber loss parameter. Only used if `--loss_type=huber`.",
+ )
+ # ----Exponential Moving Average (EMA)----
+ parser.add_argument(
+ "--ema_decay",
+ type=float,
+ default=0.95,
+ required=False,
+ help="The exponential moving average (EMA) rate or decay factor.",
+ )
+ # ----Mixed Precision----
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--cast_teacher_unet",
+ action="store_true",
+ help="Whether to cast the teacher U-Net to the precision specified by `--mixed_precision`.",
+ )
+ # ----Training Optimizations----
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ # ----Distributed Training----
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ # ----------Validation Arguments----------
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=200,
+ help="Run validation every X steps.",
+ )
+ # ----------Huggingface Hub Arguments-----------
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ # ----------Accelerate Arguments----------
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="text2image-fine-tune",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers for"
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
+
+ return args
+
+
+# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
+def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train=True):
+ prompt_embeds_list = []
+
+ captions = []
+ for caption in prompt_batch:
+ if random.random() < proportion_empty_prompts:
+ captions.append("")
+ elif isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+
+ with torch.no_grad():
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+ text_inputs = tokenizer(
+ captions,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_embeds = text_encoder(
+ text_input_ids.to(text_encoder.device),
+ output_hidden_states=True,
+ )
+
+ # We are only interested in the pooled output of the final (last) text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
+ return prompt_embeds, pooled_prompt_embeds
+
+
+def main(args):
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be divided by the number of processes, assuming batches are multiplied by the number of processes
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
+ exist_ok=True,
+ token=args.hub_token,
+ private=True,
+ ).repo_id
+
+ # 1. Create the noise scheduler and the desired noise schedule.
+ noise_scheduler = DDPMScheduler.from_pretrained(
+ args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
+ )
+
+ # The scheduler calculates the alpha and sigma schedule for us
+ alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
+ sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
+ solver = DDIMSolver(
+ noise_scheduler.alphas_cumprod.numpy(),
+ timesteps=noise_scheduler.config.num_train_timesteps,
+ ddim_timesteps=args.num_ddim_timesteps,
+ )
+
+ # 2. Load tokenizers from SD-XL checkpoint.
+ tokenizer_one = AutoTokenizer.from_pretrained(
+ args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False
+ )
+ tokenizer_two = AutoTokenizer.from_pretrained(
+ args.pretrained_teacher_model, subfolder="tokenizer_2", revision=args.teacher_revision, use_fast=False
+ )
+
+ # 3. Load text encoders from SD-XL checkpoint.
+ # import correct text encoder classes
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_teacher_model, args.teacher_revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_teacher_model, args.teacher_revision, subfolder="text_encoder_2"
+ )
+
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
+ args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision
+ )
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
+ args.pretrained_teacher_model, subfolder="text_encoder_2", revision=args.teacher_revision
+ )
+
+ # 4. Load VAE from SD-XL checkpoint (or more stable VAE)
+ vae_path = (
+ args.pretrained_teacher_model
+ if args.pretrained_vae_model_name_or_path is None
+ else args.pretrained_vae_model_name_or_path
+ )
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.teacher_revision,
+ )
+
+ # 5. Load teacher U-Net from SD-XL checkpoint
+ teacher_unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
+ )
+
+ # 6. Freeze teacher vae, text_encoders, and teacher_unet
+ vae.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+ teacher_unet.requires_grad_(False)
+
+ # 7. Create the online (`unet`) student U-Net. This will be updated by the optimizer (e.g. via backpropagation).
+ # Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None
+ if teacher_unet.config.time_cond_proj_dim is None:
+ teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim
+ unet = UNet2DConditionModel(**teacher_unet.config)
+ # load teacher_unet weights into unet
+ unet.load_state_dict(teacher_unet.state_dict(), strict=False)
+ unet.train()
+
+ # 8. Create the target (`ema_unet`) student U-Net. Its parameters will be updated via EMA updates (Polyak averaging).
+ # Initialize from unet
+ target_unet = UNet2DConditionModel(**teacher_unet.config)
+ target_unet.load_state_dict(unet.state_dict())
+ target_unet.train()
+ target_unet.requires_grad_(False)
+
+ # Check that all trainable models are in full precision
+ low_precision_error_string = (
+ " Please make sure to always have all model weights in full float32 precision when starting training - even if"
+ " doing mixed precision training, copy of the weights should still be float32."
+ )
+
+ if accelerator.unwrap_model(unet).dtype != torch.float32:
+ raise ValueError(
+ f"Controlnet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
+ )
+
+ # 9. Handle mixed precision and device placement
+ # For mixed precision training we cast all non-trainable weights to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ # The VAE is in float32 to avoid NaN losses.
+ vae.to(accelerator.device)
+ if args.pretrained_vae_model_name_or_path is not None:
+ vae.to(dtype=weight_dtype)
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+ target_unet.to(accelerator.device)
+ # Move teacher_unet to device, optionally cast to weight_dtype
+ teacher_unet.to(accelerator.device)
+ if args.cast_teacher_unet:
+ teacher_unet.to(dtype=weight_dtype)
+
+ # Also move the alpha and sigma noise schedules to accelerator.device.
+ alpha_schedule = alpha_schedule.to(accelerator.device)
+ sigma_schedule = sigma_schedule.to(accelerator.device)
+ solver = solver.to(accelerator.device)
+
+ # 10. Handle saving and loading of checkpoints
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ target_unet.save_pretrained(os.path.join(output_dir, "unet_target"))
+
+ for i, model in enumerate(models):
+ model.save_pretrained(os.path.join(output_dir, "unet"))
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ load_model = UNet2DConditionModel.from_pretrained(os.path.join(input_dir, "unet_target"))
+ target_unet.load_state_dict(load_model.state_dict())
+ target_unet.to(accelerator.device)
+ del load_model
+
+ for i in range(len(models)):
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # 11. Enable optimizations
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warn(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ teacher_unet.enable_xformers_memory_efficient_attention()
+ target_unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # 12. Optimizer creation
+ optimizer = optimizer_class(
+ unet.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # 13. Dataset creation and data processing
+ # Here, we compute not just the text embeddings but also the additional embeddings
+ # needed for the SD XL UNet to operate.
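+ # Besides the concatenated prompt embeddings, this returns the pooled text embeddings (`text_embeds`) and the
+ # micro-conditioning `time_ids` = (original_size, crop_coords_top_left, target_size), flattened to 6 integers per sample.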
+ def compute_embeddings(
+ prompt_batch, original_sizes, crop_coords, proportion_empty_prompts, text_encoders, tokenizers, is_train=True
+ ):
+ target_size = (args.resolution, args.resolution)
+ original_sizes = list(map(list, zip(*original_sizes)))
+ crops_coords_top_left = list(map(list, zip(*crop_coords)))
+
+ original_sizes = torch.tensor(original_sizes, dtype=torch.long)
+ crops_coords_top_left = torch.tensor(crops_coords_top_left, dtype=torch.long)
+
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
+ prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train
+ )
+ add_text_embeds = pooled_prompt_embeds
+
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
+ add_time_ids = list(target_size)
+ add_time_ids = torch.tensor([add_time_ids])
+ add_time_ids = add_time_ids.repeat(len(prompt_batch), 1)
+ add_time_ids = torch.cat([original_sizes, crops_coords_top_left, add_time_ids], dim=-1)
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype)
+
+ prompt_embeds = prompt_embeds.to(accelerator.device)
+ add_text_embeds = add_text_embeds.to(accelerator.device)
+ unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+
+ return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}
+
+ dataset = Text2ImageDataset(
+ train_shards_path_or_url=args.train_shards_path_or_url,
+ num_train_examples=args.max_train_samples,
+ per_gpu_batch_size=args.train_batch_size,
+ global_batch_size=args.train_batch_size * accelerator.num_processes,
+ num_workers=args.dataloader_num_workers,
+ resolution=args.resolution,
+ shuffle_buffer_size=1000,
+ pin_memory=True,
+ persistent_workers=True,
+ use_fix_crop_and_size=args.use_fix_crop_and_size,
+ )
+ train_dataloader = dataset.train_dataloader
+
+ # Let's first compute all the embeddings so that we can free up the text encoders
+ # from memory.
+ text_encoders = [text_encoder_one, text_encoder_two]
+ tokenizers = [tokenizer_one, tokenizer_two]
+
+ compute_embeddings_fn = functools.partial(
+ compute_embeddings,
+ proportion_empty_prompts=0,
+ text_encoders=text_encoders,
+ tokenizers=tokenizers,
+ )
+
+ # 14. LR Scheduler creation
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps,
+ num_training_steps=args.max_train_steps,
+ )
+
+ # 15. Prepare for training
+ # Prepare everything with our `accelerator`.
+ unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
+
+ # Create uncond embeds for classifier free guidance
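+ # The SD-XL UNet consumes the concatenation of both text encoders' hidden states (for the SD-XL base encoders,
+ # 768 + 1280 = 2048 channels over the 77-token max sequence length) plus the 1280-dim pooled embedding of the
+ # second text encoder, hence the shapes below.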
+ uncond_prompt_embeds = torch.zeros(args.train_batch_size, 77, 2048).to(accelerator.device)
+ uncond_pooled_prompt_embeds = torch.zeros(args.train_batch_size, 1280).to(accelerator.device)
+
+ # 16. Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num batches each epoch = {train_dataloader.num_batches}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ image, text, orig_size, crop_coords = batch
+
+ image = image.to(accelerator.device, non_blocking=True)
+ encoded_text = compute_embeddings_fn(text, orig_size, crop_coords)
+
+ if args.pretrained_vae_model_name_or_path is not None:
+ pixel_values = image.to(dtype=weight_dtype)
+ if vae.dtype != weight_dtype:
+ vae.to(dtype=weight_dtype)
+ else:
+ pixel_values = image
+
+ # encode pixel values with batch size of at most 8
+ latents = []
+ for i in range(0, pixel_values.shape[0], 8):
+ latents.append(vae.encode(pixel_values[i : i + 8]).latent_dist.sample())
+ latents = torch.cat(latents, dim=0)
+
+ latents = latents * vae.config.scaling_factor
+ if args.pretrained_vae_model_name_or_path is None:
+ latents = latents.to(weight_dtype)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+
+ # Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias.
+ topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
+ index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
+ start_timesteps = solver.ddim_timesteps[index]
+ timesteps = start_timesteps - topk
+ timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
+
+ # 20.4.4. Get boundary scalings for start_timesteps and (end) timesteps.
+ c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
+ c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
+ c_skip, c_out = scalings_for_boundary_conditions(timesteps)
+ c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
+
+ # 20.4.5. Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
+ noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)
+
+ # 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
+ w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
+ w = w.reshape(bsz, 1, 1, 1)
+ w = w.to(device=latents.device, dtype=latents.dtype)
+
+ # 20.4.8. Prepare prompt embeds and unet_added_conditions
+ prompt_embeds = encoded_text.pop("prompt_embeds")
+
+ # 20.4.9. Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k}
+ noise_pred = unet(
+ noisy_model_input,
+ start_timesteps,
+ timestep_cond=None,
+ encoder_hidden_states=prompt_embeds.float(),
+ added_cond_kwargs=encoded_text,
+ ).sample
+
+ pred_x_0 = predicted_origin(
+ noise_pred,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+
+ model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
+
+ # 20.4.10. Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after
+ # noisy_latents with both the conditioning embedding c and unconditional embedding 0
+ # Get teacher model prediction on noisy_latents and conditional embedding
+ with torch.no_grad():
+ with torch.autocast("cuda"):
+ cond_teacher_output = teacher_unet(
+ noisy_model_input.to(weight_dtype),
+ start_timesteps,
+ encoder_hidden_states=prompt_embeds.to(weight_dtype),
+ added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()},
+ ).sample
+ cond_pred_x0 = predicted_origin(
+ cond_teacher_output,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+
+ # Get teacher model prediction on noisy_latents and unconditional embedding
+ uncond_added_conditions = copy.deepcopy(encoded_text)
+ uncond_added_conditions["text_embeds"] = uncond_pooled_prompt_embeds
+ uncond_teacher_output = teacher_unet(
+ noisy_model_input.to(weight_dtype),
+ start_timesteps,
+ encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
+ added_cond_kwargs={k: v.to(weight_dtype) for k, v in uncond_added_conditions.items()},
+ ).sample
+ uncond_pred_x0 = predicted_origin(
+ uncond_teacher_output,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+
+ # 20.4.11. Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation)
+ pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
+ pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output)
+ x_prev = solver.ddim_step(pred_x0, pred_noise, index)
+
+ # 20.4.12. Get target LCM prediction on x_prev, w, c, t_n
+ with torch.no_grad():
+ with torch.autocast("cuda", dtype=weight_dtype):
+ target_noise_pred = target_unet(
+ x_prev.float(),
+ timesteps,
+ timestep_cond=None,
+ encoder_hidden_states=prompt_embeds.float(),
+ added_cond_kwargs=encoded_text,
+ ).sample
+ pred_x_0 = predicted_origin(
+ target_noise_pred,
+ timesteps,
+ x_prev,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+ target = c_skip * x_prev + c_out * pred_x_0
+
+ # 20.4.13. Calculate loss
+ if args.loss_type == "l2":
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ elif args.loss_type == "huber":
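+ # Pseudo-Huber loss: approximately L2 for small residuals and L1 for large ones; huber_c sets the transition scale.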
+ loss = torch.mean(
+ torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
+ )
+
+ # 20.4.14. Backpropagate on the online student model (`unet`)
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=True)
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ # 20.4.15. Make EMA update to target student model parameters
+ update_ema(target_unet.parameters(), unet.parameters(), args.ema_decay)
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if global_step % args.validation_steps == 0:
+ log_validation(vae, target_unet, args, accelerator, weight_dtype, global_step, "target")
+ log_validation(vae, unet, args, accelerator, weight_dtype, global_step, "online")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = accelerator.unwrap_model(unet)
+ unet.save_pretrained(os.path.join(args.output_dir, "unet"))
+
+ target_unet = accelerator.unwrap_model(target_unet)
+ target_unet.save_pretrained(os.path.join(args.output_dir, "unet_target"))
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/controlnet/README.md b/diffusers/examples/controlnet/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..15b0170d512034bc21786f12f5ab3ccd35143f94
--- /dev/null
+++ b/diffusers/examples/controlnet/README.md
@@ -0,0 +1,465 @@
+# ControlNet training example
+
+[Adding Conditional Control to Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.05543) by Lvmin Zhang and Maneesh Agrawala.
+
+This example is based on the [training example in the original ControlNet repository](https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md). It trains a ControlNet to fill circles using a [small synthetic dataset](https://huggingface.co/datasets/fusing/fill50k).
+
+## Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+Then cd into the example folder and run
+```bash
+pip install -r requirements.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or for a default accelerate configuration without answering questions about your environment
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell (e.g., a notebook)
+
+```python
+from accelerate.utils import write_basic_config
+write_basic_config()
+```
+
+## Circle filling dataset
+
+The original dataset is hosted in the [ControlNet repo](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip). We re-uploaded it to be compatible with `datasets` [here](https://huggingface.co/datasets/fusing/fill50k). Note that `datasets` handles dataloading within the training script.
+
+Our training examples use [Stable Diffusion 1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5) as the original set of ControlNet models were trained from it. However, ControlNet can be trained to augment any Stable Diffusion compatible model (such as [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) or [stabilityai/stable-diffusion-2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1)).
+
+## Training
+
+Our training examples use two test conditioning images. They can be downloaded by running
+
+```sh
+wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png
+
+wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
+```
+
+
+```bash
+export MODEL_DIR="runwayml/stable-diffusion-v1-5"
+export OUTPUT_DIR="path to save model"
+
+accelerate launch train_controlnet.py \
+ --pretrained_model_name_or_path=$MODEL_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --dataset_name=fusing/fill50k \
+ --resolution=512 \
+ --learning_rate=1e-5 \
+ --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
+ --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
+ --train_batch_size=4
+```
+
+This default configuration requires ~38GB VRAM.
+
+By default, the training script logs outputs to TensorBoard. Pass `--report_to wandb` to use Weights and Biases instead.
+
+Gradient accumulation with a smaller batch size can be used to reduce training requirements to ~20 GB VRAM.
+
+```bash
+export MODEL_DIR="runwayml/stable-diffusion-v1-5"
+export OUTPUT_DIR="path to save model"
+
+accelerate launch train_controlnet.py \
+ --pretrained_model_name_or_path=$MODEL_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --dataset_name=fusing/fill50k \
+ --resolution=512 \
+ --learning_rate=1e-5 \
+ --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
+ --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4
+```
+
+## Training with multiple GPUs
+
+`accelerate` allows for seamless multi-GPU training. Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch)
+for running distributed training with `accelerate`. Here is an example command:
+
+```bash
+export MODEL_DIR="runwayml/stable-diffusion-v1-5"
+export OUTPUT_DIR="path to save model"
+
+accelerate launch --mixed_precision="fp16" --multi_gpu train_controlnet.py \
+ --pretrained_model_name_or_path=$MODEL_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --dataset_name=fusing/fill50k \
+ --resolution=512 \
+ --learning_rate=1e-5 \
+ --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
+ --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
+ --train_batch_size=4 \
+ --mixed_precision="fp16" \
+ --tracker_project_name="controlnet-demo" \
+ --report_to=wandb
+```
+
+## Example results
+
+#### After 300 steps with batch size 8
+
+| | |
+|-------------------|:-------------------------:|
+| | red circle with blue background |
+![conditioning image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png) | ![red circle with blue background](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/red_circle_with_blue_background_300_steps.png) |
+| | cyan circle with brown floral background |
+![conditioning image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png) | ![cyan circle with brown floral background](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/cyan_circle_with_brown_floral_background_300_steps.png) |
+
+
+#### After 6000 steps with batch size 8:
+
+| | |
+|-------------------|:-------------------------:|
+| | red circle with blue background |
+![conditioning image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png) | ![red circle with blue background](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/red_circle_with_blue_background_6000_steps.png) |
+| | cyan circle with brown floral background |
+![conditioning image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png) | ![cyan circle with brown floral background](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/cyan_circle_with_brown_floral_background_6000_steps.png) |
+
+## Training on a 16 GB GPU
+
+Optimizations:
+- Gradient checkpointing
+- bitsandbytes' 8-bit optimizer
+
+[bitsandbytes install instructions](https://github.com/TimDettmers/bitsandbytes#requirements--installation).
+
+```bash
+export MODEL_DIR="runwayml/stable-diffusion-v1-5"
+export OUTPUT_DIR="path to save model"
+
+accelerate launch train_controlnet.py \
+ --pretrained_model_name_or_path=$MODEL_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --dataset_name=fusing/fill50k \
+ --resolution=512 \
+ --learning_rate=1e-5 \
+ --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
+ --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --use_8bit_adam
+```
+
+## Training on a 12 GB GPU
+
+Optimizations:
+- Gradient checkpointing
+- bitsandbytes' 8-bit optimizer
+- xformers
+- set grads to none
+
+```bash
+export MODEL_DIR="runwayml/stable-diffusion-v1-5"
+export OUTPUT_DIR="path to save model"
+
+accelerate launch train_controlnet.py \
+ --pretrained_model_name_or_path=$MODEL_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --dataset_name=fusing/fill50k \
+ --resolution=512 \
+ --learning_rate=1e-5 \
+ --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
+ --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --use_8bit_adam \
+ --enable_xformers_memory_efficient_attention \
+ --set_grads_to_none
+```
+
+When using `enable_xformers_memory_efficient_attention`, please make sure `xformers` is installed by running `pip install xformers`.
+
+## Training on an 8 GB GPU
+
+We have not exhaustively tested DeepSpeed support for ControlNet. While the configuration does
+save memory, we have not confirmed that it trains successfully. You will very likely
+have to make changes to the config to get a successful training run.
+
+Optimizations:
+- Gradient checkpointing
+- xformers
+- set grads to none
+- DeepSpeed stage 2 with parameter and optimizer offloading
+- fp16 mixed precision
+
+[DeepSpeed](https://www.deepspeed.ai/) can offload tensors from VRAM to either
+CPU or NVME. This requires significantly more RAM (about 25 GB).
+
+Use `accelerate config` to enable DeepSpeed stage 2.
+
+The relevant parts of the resulting accelerate config file are
+
+```yaml
+compute_environment: LOCAL_MACHINE
+deepspeed_config:
+ gradient_accumulation_steps: 4
+ offload_optimizer_device: cpu
+ offload_param_device: cpu
+ zero3_init_flag: false
+ zero_stage: 2
+distributed_type: DEEPSPEED
+```
+
+See [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more DeepSpeed configuration options.
+
+Changing the default Adam optimizer to DeepSpeed's Adam
+(`deepspeed.ops.adam.DeepSpeedCPUAdam`) gives a substantial speedup, but
+it requires a CUDA toolchain with the same version as PyTorch. The 8-bit optimizer
+does not currently appear to be compatible with DeepSpeed.
+
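+For orientation, here is a rough, hypothetical sketch of constructing DeepSpeed's CPU Adam for the ControlNet parameters; the exact wiring into the script depends on your accelerate/DeepSpeed setup, and `controlnet`/`args` refer to objects defined inside `train_controlnet.py`.
+
+```python
+# Hypothetical sketch only: instantiating DeepSpeed's CPU Adam in place of torch.optim.AdamW
+# inside train_controlnet.py (before accelerator.prepare). The exact integration depends on
+# your accelerate/DeepSpeed configuration.
+from deepspeed.ops.adam import DeepSpeedCPUAdam
+
+optimizer = DeepSpeedCPUAdam(
+    controlnet.parameters(),            # `controlnet` is the module trained by this script
+    lr=args.learning_rate,
+    betas=(args.adam_beta1, args.adam_beta2),
+    weight_decay=args.adam_weight_decay,
+    eps=args.adam_epsilon,
+)
+```
+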
+```bash
+export MODEL_DIR="runwayml/stable-diffusion-v1-5"
+export OUTPUT_DIR="path to save model"
+
+accelerate launch train_controlnet.py \
+ --pretrained_model_name_or_path=$MODEL_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --dataset_name=fusing/fill50k \
+ --resolution=512 \
+ --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
+ --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --enable_xformers_memory_efficient_attention \
+ --set_grads_to_none \
+ --mixed_precision fp16
+```
+
+## Performing inference with the trained ControlNet
+
+The trained ControlNet can be used for inference with the original ControlNet pipeline, simply by loading the newly trained weights.
+Set `base_model_path` and `controlnet_path` to the values `--pretrained_model_name_or_path` and
+`--output_dir` were respectively set to in the training script.
+
+```py
+from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
+from diffusers.utils import load_image
+import torch
+
+base_model_path = "path to model"
+controlnet_path = "path to controlnet"
+
+controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
+pipe = StableDiffusionControlNetPipeline.from_pretrained(
+ base_model_path, controlnet=controlnet, torch_dtype=torch.float16
+)
+
+# speed up diffusion process with faster scheduler and memory optimization
+pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+# remove following line if xformers is not installed or when using Torch 2.0.
+pipe.enable_xformers_memory_efficient_attention()
+# memory optimization.
+pipe.enable_model_cpu_offload()
+
+control_image = load_image("./conditioning_image_1.png")
+prompt = "pale golden rod circle with old lace background"
+
+# generate image
+generator = torch.manual_seed(0)
+image = pipe(
+ prompt, num_inference_steps=20, generator=generator, image=control_image
+).images[0]
+image.save("./output.png")
+```
+
+## Training with Flax/JAX
+
+For faster training on TPUs and GPUs, you can leverage the Flax training example. Follow the instructions above to get the model and dataset before running the script.
+
+### Running on Google Cloud TPU
+
+See below for commands to set up a TPU VM (`--accelerator-type v4-8`). For more details about how to set up and use TPUs, refer to [Cloud docs for single VM setup](https://cloud.google.com/tpu/docs/run-calculation-jax).
+
+First create a single TPUv4-8 VM and connect to it:
+
+```bash
+ZONE=us-central2-b
+TPU_TYPE=v4-8
+VM_NAME=hg_flax
+
+gcloud alpha compute tpus tpu-vm create $VM_NAME \
+ --zone $ZONE \
+ --accelerator-type $TPU_TYPE \
+ --version tpu-vm-v4-base
+
+gcloud alpha compute tpus tpu-vm ssh $VM_NAME --zone $ZONE
+```
+
+Once connected, install JAX `0.4.5`:
+
+```bash
+pip install "jax[tpu]==0.4.5" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+```
+
+To verify that JAX was correctly installed, you can run the following in a Python shell:
+
+```python
+import jax
+jax.device_count()
+```
+
+This should display the number of TPU cores, which is 4 on a TPUv4-8 VM.
+
+Then install Diffusers and the library's training dependencies:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Then cd into the example folder and run
+
+```bash
+pip install -U -r requirements_flax.txt
+```
+
+If you want to use Weights and Biases logging, you should also install `wandb` now
+
+```bash
+pip install wandb
+```
+
+
+Now let's download the two conditioning images that we will use to run validation during training in order to track our progress
+
+```bash
+wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png
+wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
+```
+
+We encourage you to store or share your model with the community. To use the Hugging Face Hub, please log in to your Hugging Face account, or [create one](https://huggingface.co/docs/diffusers/main/en/training/hf.co/join) if you don't have one already:
+
+```bash
+huggingface-cli login
+```
+
+Make sure you have the `MODEL_DIR`, `OUTPUT_DIR` and `HUB_MODEL_ID` environment variables set. `OUTPUT_DIR` specifies where the model is saved locally, and `HUB_MODEL_ID` the repository it is pushed to on the Hub:
+
+```bash
+export MODEL_DIR="runwayml/stable-diffusion-v1-5"
+export OUTPUT_DIR="runs/fill-circle-{timestamp}"
+export HUB_MODEL_ID="controlnet-fill-circle"
+```
+
+And finally start the training
+
+```bash
+python3 train_controlnet_flax.py \
+ --pretrained_model_name_or_path=$MODEL_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --dataset_name=fusing/fill50k \
+ --resolution=512 \
+ --learning_rate=1e-5 \
+ --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
+ --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
+ --validation_steps=1000 \
+ --train_batch_size=2 \
+ --revision="non-ema" \
+ --from_pt \
+ --report_to="wandb" \
+ --tracker_project_name=$HUB_MODEL_ID \
+ --num_train_epochs=11 \
+ --push_to_hub \
+ --hub_model_id=$HUB_MODEL_ID
+```
+
+Since we passed the `--push_to_hub` flag, it will automatically create a model repo under your Hugging Face account based on `$HUB_MODEL_ID`. By the end of training, the final checkpoint will be automatically stored on the Hub. You can find an example model repo [here](https://huggingface.co/YiYiXu/fill-circle-controlnet).
+
+Our training script also provides limited support for streaming large datasets from the Hugging Face Hub. In order to enable streaming, one must also set `--max_train_samples`. Here is an example command (from [this blog article](https://huggingface.co/blog/train-your-controlnet)):
+
+```bash
+export MODEL_DIR="runwayml/stable-diffusion-v1-5"
+export OUTPUT_DIR="runs/uncanny-faces-{timestamp}"
+export HUB_MODEL_ID="controlnet-uncanny-faces"
+
+python3 train_controlnet_flax.py \
+ --pretrained_model_name_or_path=$MODEL_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --dataset_name=multimodalart/facesyntheticsspigacaptioned \
+ --streaming \
+ --conditioning_image_column=spiga_seg \
+ --image_column=image \
+ --caption_column=image_caption \
+ --resolution=512 \
+ --max_train_samples 100000 \
+ --learning_rate=1e-5 \
+ --train_batch_size=1 \
+ --revision="flax" \
+ --report_to="wandb" \
+ --tracker_project_name=$HUB_MODEL_ID
+```
+
+Note, however, that TPU performance might get bottlenecked because streaming with `datasets` is not optimized for images. To ensure maximum throughput, we encourage you to explore the following options (a minimal streaming sketch follows the list):
+
+* [Webdataset](https://webdataset.github.io/webdataset/)
+* [TorchData](https://github.com/pytorch/data)
+* [TensorFlow Datasets](https://www.tensorflow.org/datasets/tfless_tfds)
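+
+As a rough illustration of the streaming pattern, here is a minimal, hypothetical sketch using `webdataset`; the shard URL pattern and the per-sample keys (`jpg`/`txt`) are assumptions about how your data is packaged.
+
+```python
+import webdataset as wds
+
+# Hypothetical shard location; replace with your own WebDataset-style tar shards.
+shards = "https://example.com/shards/train-{000000..000127}.tar"
+
+dataset = (
+    wds.WebDataset(shards)
+    .shuffle(1000)                # shuffle within a rolling buffer of samples
+    .decode("pil")                # decode image bytes into PIL images
+    .to_tuple("jpg;png", "txt")   # yield (image, caption) tuples
+)
+
+for image, caption in dataset:
+    # feed the pair into your preprocessing / training loop here
+    break
+```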
+
+When working with a larger dataset, you may need to run the training process for a long time, and it's useful to save regular checkpoints along the way. You can use the following argument to enable intermediate checkpointing:
+
+```bash
+ --checkpointing_steps=500
+```
+This will save the trained model in subfolders of your `output_dir`. Each subfolder is named after the number of steps performed so far; for example, a checkpoint saved after 500 training steps would be stored in a subfolder named `500`.
+
+You can then start your training from this saved checkpoint with
+
+```bash
+ --controlnet_model_name_or_path="./control_out/500"
+```
+
+We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556) which helps to achieve faster convergence by rebalancing the loss. To use it, one needs to set the `--snr_gamma` argument. The recommended value when using it is `5.0`.
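+
+For intuition, Min-SNR weighting clamps the per-timestep signal-to-noise ratio at `snr_gamma` and uses the result to weight each sample's loss. Below is a rough PyTorch sketch of the idea under an epsilon-prediction objective; it is illustrative only and not the script's actual (Flax) implementation.
+
+```python
+import torch
+from diffusers import DDPMScheduler
+
+def min_snr_weights(noise_scheduler, timesteps, snr_gamma=5.0):
+    # SNR(t) = alpha_bar_t / (1 - alpha_bar_t)
+    alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device)[timesteps]
+    snr = alphas_cumprod / (1.0 - alphas_cumprod)
+    # clamp SNR at snr_gamma, then divide by SNR (epsilon-prediction case)
+    return torch.minimum(snr, torch.full_like(snr, snr_gamma)) / snr
+
+noise_scheduler = DDPMScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
+timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (4,))
+weights = min_snr_weights(noise_scheduler, timesteps)  # multiply the per-sample MSE by these weights
+```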
+
+We also support gradient accumulation, a technique that lets you use an effective batch size bigger than your machine could normally fit into memory. Use the `gradient_accumulation_steps` argument to set the number of accumulation steps. The ControlNet author recommends gradient accumulation for better convergence. Read more [here](https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md#more-consideration-sudden-converge-phenomenon-and-gradient-accumulation).
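+
+A minimal, generic PyTorch sketch of what gradient accumulation does is shown below; the tiny model and dummy data are placeholders, and the training scripts handle all of this for you when you pass the flag.
+
+```python
+import torch
+
+accumulation_steps = 4
+model = torch.nn.Linear(8, 1)
+optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
+dataloader = [(torch.randn(2, 8), torch.randn(2, 1)) for _ in range(8)]  # dummy micro-batches
+
+optimizer.zero_grad()
+for step, (x, y) in enumerate(dataloader):
+    loss = torch.nn.functional.mse_loss(model(x), y) / accumulation_steps  # scale so summed gradients match one big batch
+    loss.backward()                                   # gradients accumulate in .grad across micro-batches
+    if (step + 1) % accumulation_steps == 0:
+        optimizer.step()                              # one optimizer update per accumulation_steps micro-batches
+        optimizer.zero_grad()
+```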
+
+You can **profile your code** with:
+
+```bash
+ --profile_steps=5
+```
+
+Refer to the [JAX documentation on profiling](https://jax.readthedocs.io/en/latest/profiling.html). To inspect the profile trace, you'll have to install and start Tensorboard with the profile plugin:
+
+```bash
+pip install tensorflow tensorboard-plugin-profile
+tensorboard --logdir runs/fill-circle-100steps-20230411_165612/
+```
+
+The profile can then be inspected at http://localhost:6006/#profile
+
+Sometimes you'll get version conflicts (error messages like `Duplicate plugins for name projector`), which means that you have to uninstall and reinstall all versions of Tensorflow/Tensorboard (e.g. with `pip uninstall tensorflow tf-nightly tensorboard tb-nightly tensorboard-plugin-profile && pip install tf-nightly tbp-nightly tensorboard-plugin-profile`).
+
+Note that the debugging functionality of the Tensorboard `profile` plugin is still under active development. Not all views are fully functional, and for example the `trace_viewer` cuts off events after 1M (which can result in all your device traces getting lost if you for example profile the compilation step by accident).
+
+## Support for Stable Diffusion XL
+
+We provide a training script for training a ControlNet with [Stable Diffusion XL](https://huggingface.co/papers/2307.01952). Please refer to [README_sdxl.md](./README_sdxl.md) for more details.
diff --git a/diffusers/examples/controlnet/README_sdxl.md b/diffusers/examples/controlnet/README_sdxl.md
new file mode 100644
index 0000000000000000000000000000000000000000..4a7797b9572c7319a5fc123b787d3c0b20ceb5aa
--- /dev/null
+++ b/diffusers/examples/controlnet/README_sdxl.md
@@ -0,0 +1,131 @@
+# ControlNet training example for Stable Diffusion XL (SDXL)
+
+The `train_controlnet_sdxl.py` script shows how to implement the ControlNet training procedure and adapt it for [Stable Diffusion XL](https://huggingface.co/papers/2307.01952).
+
+## Running locally with PyTorch
+
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+Then cd into the `examples/controlnet` folder and run
+```bash
+pip install -r requirements_sdxl.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or for a default accelerate configuration without answering questions about your environment
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell (e.g., a notebook)
+
+```python
+from accelerate.utils import write_basic_config
+write_basic_config()
+```
+
+When running `accelerate config`, setting the torch compile mode to True can give dramatic speedups.
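+
+Roughly speaking, this amounts to wrapping the model with `torch.compile` (PyTorch >= 2.0), which Accelerate can apply for you when the config enables it. A small, illustrative sketch (the checkpoint name is just an example):
+
+```python
+import torch
+from diffusers import UNet2DConditionModel
+
+unet = UNet2DConditionModel.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16
+)
+unet = torch.compile(unet, mode="reduce-overhead")  # compiled forward passes can be substantially faster
+```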
+
+## Circle filling dataset
+
+The original dataset is hosted in the [ControlNet repo](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip). We re-uploaded it to be compatible with `datasets` [here](https://huggingface.co/datasets/fusing/fill50k). Note that `datasets` handles dataloading within the training script.
+
+## Training
+
+Our training examples use two test conditioning images. They can be downloaded by running
+
+```sh
+wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png
+
+wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
+```
+
+Then run `huggingface-cli login` to log into your Hugging Face account. This is needed to be able to push the trained ControlNet parameters to Hugging Face Hub.
+
+```bash
+export MODEL_DIR="stabilityai/stable-diffusion-xl-base-1.0"
+export OUTPUT_DIR="path to save model"
+
+accelerate launch train_controlnet_sdxl.py \
+ --pretrained_model_name_or_path=$MODEL_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --dataset_name=fusing/fill50k \
+ --mixed_precision="fp16" \
+ --resolution=1024 \
+ --learning_rate=1e-5 \
+ --max_train_steps=15000 \
+ --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
+ --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
+ --validation_steps=100 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --report_to="wandb" \
+ --seed=42 \
+ --push_to_hub
+```
+
+To better track our training experiments, we're using the following flags in the command above:
+
+* `report_to="wandb"` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
+* `validation_image`, `validation_prompt`, and `validation_steps` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
+
+Our experiments were conducted on a single 40GB A100 GPU.
+
+### Inference
+
+Once training is done, we can perform inference like so:
+
+```python
+from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
+from diffusers.utils import load_image
+import torch
+
+base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
+controlnet_path = "path to controlnet"
+
+controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
+pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
+ base_model_path, controlnet=controlnet, torch_dtype=torch.float16
+)
+
+# speed up diffusion process with faster scheduler and memory optimization
+pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+# remove following line if xformers is not installed or when using Torch 2.0.
+pipe.enable_xformers_memory_efficient_attention()
+# memory optimization.
+pipe.enable_model_cpu_offload()
+
+control_image = load_image("./conditioning_image_1.png")
+prompt = "pale golden rod circle with old lace background"
+
+# generate image
+generator = torch.manual_seed(0)
+image = pipe(
+ prompt, num_inference_steps=20, generator=generator, image=control_image
+).images[0]
+image.save("./output.png")
+```
+
+## Notes
+
+### Specifying a better VAE
+
+SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument, `--pretrained_vae_model_name_or_path`, that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
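+
+For inference, the same VAE can be swapped directly into the SDXL ControlNet pipeline. A short sketch, with checkpoint names mirroring the placeholders used elsewhere in this README:
+
+```python
+import torch
+from diffusers import AutoencoderKL, ControlNetModel, StableDiffusionXLControlNetPipeline
+
+# Pair the numerically stabler fp16 VAE with the trained ControlNet at inference time.
+vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
+controlnet = ControlNetModel.from_pretrained("path to controlnet", torch_dtype=torch.float16)
+pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0", vae=vae, controlnet=controlnet, torch_dtype=torch.float16
+)
+pipe.enable_model_cpu_offload()
+```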
diff --git a/diffusers/examples/controlnet/requirements.txt b/diffusers/examples/controlnet/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d19c62296702868c768596bdd866dd5b504e4180
--- /dev/null
+++ b/diffusers/examples/controlnet/requirements.txt
@@ -0,0 +1,6 @@
+accelerate>=0.16.0
+torchvision
+transformers>=4.25.1
+ftfy
+tensorboard
+datasets
diff --git a/diffusers/examples/controlnet/requirements_flax.txt b/diffusers/examples/controlnet/requirements_flax.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b6eb64e254625ee8eff2ef126d67adfd5b6994dc
--- /dev/null
+++ b/diffusers/examples/controlnet/requirements_flax.txt
@@ -0,0 +1,9 @@
+transformers>=4.25.1
+datasets
+flax
+optax
+torch
+torchvision
+ftfy
+tensorboard
+Jinja2
diff --git a/diffusers/examples/controlnet/requirements_sdxl.txt b/diffusers/examples/controlnet/requirements_sdxl.txt
new file mode 100644
index 0000000000000000000000000000000000000000..5ab6e9932e10a1e5337f3bc3faa8a192f4f60a52
--- /dev/null
+++ b/diffusers/examples/controlnet/requirements_sdxl.txt
@@ -0,0 +1,8 @@
+accelerate>=0.16.0
+torchvision
+transformers>=4.25.1
+ftfy
+tensorboard
+Jinja2
+datasets
+wandb
diff --git a/diffusers/examples/controlnet/train_controlnet.py b/diffusers/examples/controlnet/train_controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..63b6767a6f8f47765daa8cd19201797557ccd5f6
--- /dev/null
+++ b/diffusers/examples/controlnet/train_controlnet.py
@@ -0,0 +1,1128 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+
+import accelerate
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from PIL import Image
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ ControlNetModel,
+ DDPMScheduler,
+ StableDiffusionControlNetPipeline,
+ UNet2DConditionModel,
+ UniPCMultistepScheduler,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.24.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def image_grid(imgs, rows, cols):
+ assert len(imgs) == rows * cols
+
+ w, h = imgs[0].size
+ grid = Image.new("RGB", size=(cols * w, rows * h))
+
+ for i, img in enumerate(imgs):
+ grid.paste(img, box=(i % cols * w, i // cols * h))
+ return grid
+
+
+def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, accelerator, weight_dtype, step):
+ logger.info("Running validation... ")
+
+ controlnet = accelerator.unwrap_model(controlnet)
+
+ pipeline = StableDiffusionControlNetPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ controlnet=controlnet,
+ safety_checker=None,
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ if len(args.validation_image) == len(args.validation_prompt):
+ validation_images = args.validation_image
+ validation_prompts = args.validation_prompt
+ elif len(args.validation_image) == 1:
+ validation_images = args.validation_image * len(args.validation_prompt)
+ validation_prompts = args.validation_prompt
+ elif len(args.validation_prompt) == 1:
+ validation_images = args.validation_image
+ validation_prompts = args.validation_prompt * len(args.validation_image)
+ else:
+ raise ValueError(
+ "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
+ )
+
+ image_logs = []
+
+ for validation_prompt, validation_image in zip(validation_prompts, validation_images):
+ validation_image = Image.open(validation_image).convert("RGB")
+
+ images = []
+
+ for _ in range(args.num_validation_images):
+ with torch.autocast("cuda"):
+ image = pipeline(
+ validation_prompt, validation_image, num_inference_steps=20, generator=generator
+ ).images[0]
+
+ images.append(image)
+
+ image_logs.append(
+ {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
+ )
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+
+ formatted_images = []
+
+ formatted_images.append(np.asarray(validation_image))
+
+ for image in images:
+ formatted_images.append(np.asarray(image))
+
+ formatted_images = np.stack(formatted_images)
+
+ tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ formatted_images = []
+
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+
+ formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
+
+ for image in images:
+ image = wandb.Image(image, caption=validation_prompt)
+ formatted_images.append(image)
+
+ tracker.log({"validation": formatted_images})
+ else:
+ logger.warn(f"image logging not implemented for {tracker.name}")
+
+ return image_logs
+
+
+def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ revision=revision,
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "RobertaSeriesModelWithTransformation":
+ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
+
+ return RobertaSeriesModelWithTransformation
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def save_model_card(repo_id: str, image_logs=None, base_model: str = None, repo_folder=None):
+ img_str = ""
+ if image_logs is not None:
+ img_str = "You can find some example images below.\n"
+ for i, log in enumerate(image_logs):
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+ validation_image.save(os.path.join(repo_folder, "image_control.png"))
+ img_str += f"prompt: {validation_prompt}\n"
+ images = [validation_image] + images
+ image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
+ img_str += f"![images_{i}](./images_{i}.png)\n"
+
+ yaml = f"""
+---
+license: creativeml-openrail-m
+base_model: {base_model}
+tags:
+- stable-diffusion
+- stable-diffusion-diffusers
+- text-to-image
+- diffusers
+- controlnet
+inference: true
+---
+ """
+ model_card = f"""
+# controlnet-{repo_id}
+
+These are controlnet weights trained on {base_model} with new type of conditioning.
+{img_str}
+"""
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
+ f.write(yaml + model_card)
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--controlnet_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
+ " If not specified controlnet weights are initialized from unet.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help=(
+ "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
+ " float32 precision."
+ ),
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="controlnet-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
+ "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
+ "instructions."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-6,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--set_grads_to_none",
+ action="store_true",
+ help=(
+ "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
+ " behaviors, so disable this argument if it causes any problems. More info:"
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
+ ),
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
+ )
+ parser.add_argument(
+ "--conditioning_image_column",
+ type=str,
+ default="conditioning_image",
+ help="The column of the dataset containing the controlnet conditioning image.",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--proportion_empty_prompts",
+ type=float,
+ default=0,
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ nargs="+",
+ help=(
+ "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
+ " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
+ " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
+ ),
+ )
+ parser.add_argument(
+ "--validation_image",
+ type=str,
+ default=None,
+ nargs="+",
+ help=(
+ "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
+ " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
+ " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
+ " `--validation_image` that will be used with all `--validation_prompt`s."
+ ),
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=100,
+ help=(
+ "Run validation every X steps. Validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
+ " and logging the images."
+ ),
+ )
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="train_controlnet",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers for"
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Specify either `--dataset_name` or `--train_data_dir`")
+
+ if args.dataset_name is not None and args.train_data_dir is not None:
+ raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")
+
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
+
+ if args.validation_prompt is not None and args.validation_image is None:
+ raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
+
+ if args.validation_prompt is None and args.validation_image is not None:
+ raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
+
+ if (
+ args.validation_image is not None
+ and args.validation_prompt is not None
+ and len(args.validation_image) != 1
+ and len(args.validation_prompt) != 1
+ and len(args.validation_image) != len(args.validation_prompt)
+ ):
+ raise ValueError(
+ "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
+ " or the same number of `--validation_prompt`s and `--validation_image`s"
+ )
+
+ if args.resolution % 8 != 0:
+ raise ValueError(
+ "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
+ )
+
+ return args
+
+
+def make_train_dataset(args, tokenizer, accelerator):
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ else:
+ if args.train_data_dir is not None:
+ dataset = load_dataset(
+ args.train_data_dir,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ if args.image_column is None:
+ image_column = column_names[0]
+ logger.info(f"image column defaulting to {image_column}")
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ if args.caption_column is None:
+ caption_column = column_names[1]
+ logger.info(f"caption column defaulting to {caption_column}")
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ if args.conditioning_image_column is None:
+ conditioning_image_column = column_names[2]
+ logger.info(f"conditioning image column defaulting to {conditioning_image_column}")
+ else:
+ conditioning_image_column = args.conditioning_image_column
+ if conditioning_image_column not in column_names:
+ raise ValueError(
+ f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ def tokenize_captions(examples, is_train=True):
+ captions = []
+ for caption in examples[caption_column]:
+ if random.random() < args.proportion_empty_prompts:
+ captions.append("")
+ elif isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+ else:
+ raise ValueError(
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
+ )
+ inputs = tokenizer(
+ captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ return inputs.input_ids
+
+ image_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ conditioning_image_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution),
+ transforms.ToTensor(),
+ ]
+ )
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ images = [image_transforms(image) for image in images]
+
+ conditioning_images = [image.convert("RGB") for image in examples[conditioning_image_column]]
+ conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]
+
+ examples["pixel_values"] = images
+ examples["conditioning_pixel_values"] = conditioning_images
+ examples["input_ids"] = tokenize_captions(examples)
+
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ return train_dataset
+
+
+def collate_fn(examples):
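+ # Batch the per-example tensors: image tensors are stacked as contiguous float tensors, token ids are stacked as-is.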
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
+ conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ input_ids = torch.stack([example["input_ids"] for example in examples])
+
+ return {
+ "pixel_values": pixel_values,
+ "conditioning_pixel_values": conditioning_pixel_values,
+ "input_ids": input_ids,
+ }
+
+
+def main(args):
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizer
+ if args.tokenizer_name:
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
+ elif args.pretrained_model_name_or_path:
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ use_fast=False,
+ )
+
+ # import correct text encoder class
+ text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder = text_encoder_cls.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+ )
+
+ if args.controlnet_model_name_or_path:
+ logger.info("Loading existing controlnet weights")
+ controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path)
+ else:
+ logger.info("Initializing controlnet weights from unet")
+ controlnet = ControlNetModel.from_unet(unet)
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ i = len(weights) - 1
+
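+ # Pop the weights so accelerate's default serialization skips them; the ControlNet is saved in diffusers format below instead.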
+ while len(weights) > 0:
+ weights.pop()
+ model = models[i]
+
+ sub_dir = "controlnet"
+ model.save_pretrained(os.path.join(output_dir, sub_dir))
+
+ i -= 1
+
+ def load_model_hook(models, input_dir):
+ while len(models) > 0:
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = ControlNetModel.from_pretrained(input_dir, subfolder="controlnet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ vae.requires_grad_(False)
+ unet.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+ controlnet.train()
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warn(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ controlnet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ if args.gradient_checkpointing:
+ controlnet.enable_gradient_checkpointing()
+
+ # Check that all trainable models are in full precision
+ low_precision_error_string = (
+ " Please make sure to always have all model weights in full float32 precision when starting training - even if"
+ " doing mixed precision training, copy of the weights should still be float32."
+ )
+
+ if accelerator.unwrap_model(controlnet).dtype != torch.float32:
+ raise ValueError(
+ f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet).dtype}. {low_precision_error_string}"
+ )
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimizer creation
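+ # Only the ControlNet parameters are trained; the UNet, VAE and text encoder stay frozen.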
+ params_to_optimize = controlnet.parameters()
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ train_dataset = make_train_dataset(args, tokenizer, accelerator)
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ controlnet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move vae, unet and text_encoder to device and cast to weight_dtype
+ vae.to(accelerator.device, dtype=weight_dtype)
+ unet.to(accelerator.device, dtype=weight_dtype)
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.

+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+
+ # tensorboard cannot handle list types for config
+ tracker_config.pop("validation_prompt")
+ tracker_config.pop("validation_image")
+
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ image_logs = None
+ for epoch in range(first_epoch, args.num_train_epochs):
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(controlnet):
+ # Convert images to latent space
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
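+ # vae.config.scaling_factor (typically 0.18215 for Stable Diffusion v1.x VAEs) rescales latents to roughly unit variance.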
+ latents = latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+ controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
+
+ down_block_res_samples, mid_block_res_sample = controlnet(
+ noisy_latents,
+ timesteps,
+ encoder_hidden_states=encoder_hidden_states,
+ controlnet_cond=controlnet_image,
+ return_dict=False,
+ )
+
+ # Predict the noise residual
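+ # The ControlNet residuals are added to the UNet's down-block skip connections and to its mid-block output.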
+ model_pred = unet(
+ noisy_latents,
+ timesteps,
+ encoder_hidden_states=encoder_hidden_states,
+ down_block_additional_residuals=[
+ sample.to(dtype=weight_dtype) for sample in down_block_res_samples
+ ],
+ mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
+ ).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = controlnet.parameters()
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=args.set_grads_to_none)
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if args.validation_prompt is not None and global_step % args.validation_steps == 0:
+ image_logs = log_validation(
+ vae,
+ text_encoder,
+ tokenizer,
+ unet,
+ controlnet,
+ args,
+ accelerator,
+ weight_dtype,
+ global_step,
+ )
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ controlnet = accelerator.unwrap_model(controlnet)
+ controlnet.save_pretrained(args.output_dir)
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ image_logs=image_logs,
+ base_model=args.pretrained_model_name_or_path,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/controlnet/train_controlnet_flax.py b/diffusers/examples/controlnet/train_controlnet_flax.py
new file mode 100644
index 0000000000000000000000000000000000000000..b658f689358d24a845d41e83df60302ab4cf16bb
--- /dev/null
+++ b/diffusers/examples/controlnet/train_controlnet_flax.py
@@ -0,0 +1,1137 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import math
+import os
+import random
+import time
+from pathlib import Path
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import optax
+import torch
+import torch.utils.checkpoint
+import transformers
+from datasets import load_dataset, load_from_disk
+from flax import jax_utils
+from flax.core.frozen_dict import unfreeze
+from flax.training import train_state
+from flax.training.common_utils import shard
+from huggingface_hub import create_repo, upload_folder
+from PIL import Image, PngImagePlugin
+from torch.utils.data import IterableDataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTokenizer, FlaxCLIPTextModel, set_seed
+
+from diffusers import (
+ FlaxAutoencoderKL,
+ FlaxControlNetModel,
+ FlaxDDPMScheduler,
+ FlaxStableDiffusionControlNetPipeline,
+ FlaxUNet2DConditionModel,
+)
+from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
+
+
+# To prevent an error that occurs when there are abnormally large compressed data chunks in the png image,
+# see more https://github.com/python-pillow/Pillow/issues/5610
+LARGE_ENOUGH_NUMBER = 100
+PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.24.0.dev0")
+
+logger = logging.getLogger(__name__)
+
+
+def log_validation(pipeline, pipeline_params, controlnet_params, tokenizer, args, rng, weight_dtype):
+ logger.info("Running validation...")
+
+ pipeline_params = pipeline_params.copy()
+ pipeline_params["controlnet"] = controlnet_params
+
+ num_samples = jax.device_count()
+ prng_seed = jax.random.split(rng, jax.device_count())
+
+ if len(args.validation_image) == len(args.validation_prompt):
+ validation_images = args.validation_image
+ validation_prompts = args.validation_prompt
+ elif len(args.validation_image) == 1:
+ validation_images = args.validation_image * len(args.validation_prompt)
+ validation_prompts = args.validation_prompt
+ elif len(args.validation_prompt) == 1:
+ validation_images = args.validation_image
+ validation_prompts = args.validation_prompt * len(args.validation_image)
+ else:
+ raise ValueError(
+ "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
+ )
+
+ image_logs = []
+
+ for validation_prompt, validation_image in zip(validation_prompts, validation_images):
+ prompts = num_samples * [validation_prompt]
+ prompt_ids = pipeline.prepare_text_inputs(prompts)
+ prompt_ids = shard(prompt_ids)
+
+ validation_image = Image.open(validation_image).convert("RGB")
+ processed_image = pipeline.prepare_image_inputs(num_samples * [validation_image])
+ processed_image = shard(processed_image)
+ images = pipeline(
+ prompt_ids=prompt_ids,
+ image=processed_image,
+ params=pipeline_params,
+ prng_seed=prng_seed,
+ num_inference_steps=50,
+ jit=True,
+ ).images
+
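+ # Merge the (num_devices, per-device batch) axes back into a single batch of images.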
+ images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
+ images = pipeline.numpy_to_pil(images)
+
+ image_logs.append(
+ {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
+ )
+
+ if args.report_to == "wandb":
+ formatted_images = []
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+
+ formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
+ for image in images:
+ image = wandb.Image(image, caption=validation_prompt)
+ formatted_images.append(image)
+
+ wandb.log({"validation": formatted_images})
+ else:
+ logger.warn(f"image logging not implemented for {args.report_to}")
+
+ return image_logs
+
+
+def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):
+ img_str = ""
+ if image_logs is not None:
+ for i, log in enumerate(image_logs):
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+ validation_image.save(os.path.join(repo_folder, "image_control.png"))
+ img_str += f"prompt: {validation_prompt}\n"
+ images = [validation_image] + images
+ make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
+ img_str += f"![images_{i})](./images_{i}.png)\n"
+
+ yaml = f"""
+---
+license: creativeml-openrail-m
+base_model: {base_model}
+tags:
+- stable-diffusion
+- stable-diffusion-diffusers
+- text-to-image
+- diffusers
+- controlnet
+- jax-diffusers-event
+inference: true
+---
+ """
+ model_card = f"""
+# controlnet- {repo_id}
+
+These are controlnet weights trained on {base_model} with a new type of conditioning. You can find some example images below.\n
+{img_str}
+"""
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
+ f.write(yaml + model_card)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--controlnet_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
+ " If not specified controlnet weights are initialized from unet.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--from_pt",
+ action="store_true",
+ help="Load the pretrained model from a PyTorch checkpoint.",
+ )
+ parser.add_argument(
+ "--controlnet_revision",
+ type=str,
+ default=None,
+ help="Revision of controlnet model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--profile_steps",
+ type=int,
+ default=0,
+ help="How many training steps to profile in the beginning.",
+ )
+ parser.add_argument(
+ "--profile_validation",
+ action="store_true",
+ help="Whether to profile the (last) validation.",
+ )
+ parser.add_argument(
+ "--profile_memory",
+ action="store_true",
+ help="Whether to dump an initial (before training loop) and a final (at program end) memory profile.",
+ )
+ parser.add_argument(
+ "--ccache",
+ type=str,
+ default=None,
+ help="Enables compilation cache.",
+ )
+ parser.add_argument(
+ "--controlnet_from_pt",
+ action="store_true",
+ help="Load the controlnet model from a PyTorch checkpoint.",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="runs/{timestamp}",
+ help="The output directory where the model predictions and checkpoints will be written. "
+ "Can contain placeholders: {timestamp}.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=0, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=5000,
+ help=("Save a checkpoint of the training state every X updates."),
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_steps",
+ type=int,
+ default=100,
+ help=("log training metric every X steps to `--report_t`"),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="wandb",
+ help=('The integration to report the results and logs to. Currently only supported platforms are `"wandb"`'),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default="no",
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose"
+ "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
+ "and an Nvidia Ampere GPU."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument("--streaming", action="store_true", help="To stream a large dataset from Hub.")
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training dataset. By default it will use `load_dataset` method to load a custom dataset from the folder."
+ "Folder must contain a dataset script as described here https://huggingface.co/docs/datasets/dataset_script) ."
+ "If `--load_from_disk` flag is passed, it will use `load_from_disk` method instead. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--load_from_disk",
+ action="store_true",
+ help=(
+ "If True, will load a dataset that was previously saved using `save_to_disk` from `--train_data_dir`"
+ "See more https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.Dataset.load_from_disk"
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
+ )
+ parser.add_argument(
+ "--conditioning_image_column",
+ type=str,
+ default="conditioning_image",
+ help="The column of the dataset containing the controlnet conditioning image.",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set. Needed if `streaming` is set to True."
+ ),
+ )
+ parser.add_argument(
+ "--proportion_empty_prompts",
+ type=float,
+ default=0,
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ nargs="+",
+ help=(
+ "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
+ " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
+ " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
+ ),
+ )
+ parser.add_argument(
+ "--validation_image",
+ type=str,
+ default=None,
+ nargs="+",
+ help=(
+ "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
+ " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
+ " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
+ " `--validation_image` that will be used with all `--validation_prompt`s."
+ ),
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=100,
+ help=(
+ "Run validation every X steps. Validation consists of running the prompt"
+ " `args.validation_prompt` and logging the images."
+ ),
+ )
+ parser.add_argument("--wandb_entity", type=str, default=None, help=("The wandb entity to use (for teams)."))
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="train_controlnet_flax",
+ help=("The `project` argument passed to wandb"),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps", type=int, default=1, help="Number of steps to accumulate gradients over"
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+
+ args = parser.parse_args()
+ args.output_dir = args.output_dir.replace("{timestamp}", time.strftime("%Y%m%d_%H%M%S"))
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+ if args.dataset_name is not None and args.train_data_dir is not None:
+ raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")
+
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
+
+ if args.validation_prompt is not None and args.validation_image is None:
+ raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
+
+ if args.validation_prompt is None and args.validation_image is not None:
+ raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
+
+ if (
+ args.validation_image is not None
+ and args.validation_prompt is not None
+ and len(args.validation_image) != 1
+ and len(args.validation_prompt) != 1
+ and len(args.validation_image) != len(args.validation_prompt)
+ ):
+ raise ValueError(
+ "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
+ " or the same number of `--validation_prompt`s and `--validation_image`s"
+ )
+
+ # This idea comes from
+ # https://github.com/borisdayma/dalle-mini/blob/d2be512d4a6a9cda2d63ba04afc33038f98f705f/src/dalle_mini/data.py#L370
+ if args.streaming and args.max_train_samples is None:
+ raise ValueError("You must specify `max_train_samples` when using dataset streaming.")
+
+ return args
+
+
+def make_train_dataset(args, tokenizer, batch_size=None):
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ streaming=args.streaming,
+ )
+ else:
+ if args.train_data_dir is not None:
+ if args.load_from_disk:
+ dataset = load_from_disk(
+ args.train_data_dir,
+ )
+ else:
+ dataset = load_dataset(
+ args.train_data_dir,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ if isinstance(dataset["train"], IterableDataset):
+ column_names = next(iter(dataset["train"])).keys()
+ else:
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ if args.image_column is None:
+ image_column = column_names[0]
+ logger.info(f"image column defaulting to {image_column}")
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ if args.caption_column is None:
+ caption_column = column_names[1]
+ logger.info(f"caption column defaulting to {caption_column}")
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ if args.conditioning_image_column is None:
+ conditioning_image_column = column_names[2]
+ logger.info(f"conditioning image column defaulting to {caption_column}")
+ else:
+ conditioning_image_column = args.conditioning_image_column
+ if conditioning_image_column not in column_names:
+ raise ValueError(
+ f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ def tokenize_captions(examples, is_train=True):
+ captions = []
+ for caption in examples[caption_column]:
+ if random.random() < args.proportion_empty_prompts:
+ captions.append("")
+ elif isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+ else:
+ raise ValueError(
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
+ )
+ inputs = tokenizer(
+ captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ return inputs.input_ids
+
+ image_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ conditioning_image_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution),
+ transforms.ToTensor(),
+ ]
+ )
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ images = [image_transforms(image) for image in images]
+
+ conditioning_images = [image.convert("RGB") for image in examples[conditioning_image_column]]
+ conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]
+
+ examples["pixel_values"] = images
+ examples["conditioning_pixel_values"] = conditioning_images
+ examples["input_ids"] = tokenize_captions(examples)
+
+ return examples
+
+ if jax.process_index() == 0:
+ if args.max_train_samples is not None:
+ if args.streaming:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).take(args.max_train_samples)
+ else:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ if args.streaming:
+ train_dataset = dataset["train"].map(
+ preprocess_train,
+ batched=True,
+ batch_size=batch_size,
+ remove_columns=list(dataset["train"].features.keys()),
+ )
+ else:
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ return train_dataset
+
+
+def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
+ conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ input_ids = torch.stack([example["input_ids"] for example in examples])
+
+ batch = {
+ "pixel_values": pixel_values,
+ "conditioning_pixel_values": conditioning_pixel_values,
+ "input_ids": input_ids,
+ }
+ batch = {k: v.numpy() for k, v in batch.items()}
+ return batch
+
+
+def get_params_to_save(params):
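+ # Un-replicate the pmapped params: take the copy from the first device and pull it to host memory.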
+ return jax.device_get(jax.tree_util.tree_map(lambda x: x[0], params))
+
+
+def main():
+ args = parse_args()
+
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ # Setup logging, we only want one process per machine to log things on the screen.
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
+ if jax.process_index() == 0:
+ transformers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+
+ # wandb init
+ if jax.process_index() == 0 and args.report_to == "wandb":
+ wandb.init(
+ entity=args.wandb_entity,
+ project=args.tracker_project_name,
+ job_type="train",
+ config=args,
+ )
+
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ rng = jax.random.PRNGKey(0)
+
+ # Handle the repository creation
+ if jax.process_index() == 0:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizer and add the placeholder token as a additional special token
+ if args.tokenizer_name:
+ tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+ elif args.pretrained_model_name_or_path:
+ tokenizer = CLIPTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
+ )
+ else:
+ raise NotImplementedError("No tokenizer specified!")
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ total_train_batch_size = args.train_batch_size * jax.local_device_count() * args.gradient_accumulation_steps
+ train_dataset = make_train_dataset(args, tokenizer, batch_size=total_train_batch_size)
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=not args.streaming,
+ collate_fn=collate_fn,
+ batch_size=total_train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ drop_last=True,
+ )
+
+ weight_dtype = jnp.float32
+ if args.mixed_precision == "fp16":
+ weight_dtype = jnp.float16
+ elif args.mixed_precision == "bf16":
+ weight_dtype = jnp.bfloat16
+
+ # Load models and create wrapper for stable diffusion
+ text_encoder = FlaxCLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ dtype=weight_dtype,
+ revision=args.revision,
+ from_pt=args.from_pt,
+ )
+ vae, vae_params = FlaxAutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path,
+ revision=args.revision,
+ subfolder="vae",
+ dtype=weight_dtype,
+ from_pt=args.from_pt,
+ )
+ unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="unet",
+ dtype=weight_dtype,
+ revision=args.revision,
+ from_pt=args.from_pt,
+ )
+
+ if args.controlnet_model_name_or_path:
+ logger.info("Loading existing controlnet weights")
+ controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
+ args.controlnet_model_name_or_path,
+ revision=args.controlnet_revision,
+ from_pt=args.controlnet_from_pt,
+ dtype=jnp.float32,
+ )
+ else:
+ logger.info("Initializing controlnet weights from unet")
+ rng, rng_params = jax.random.split(rng)
+
+ controlnet = FlaxControlNetModel(
+ in_channels=unet.config.in_channels,
+ down_block_types=unet.config.down_block_types,
+ only_cross_attention=unet.config.only_cross_attention,
+ block_out_channels=unet.config.block_out_channels,
+ layers_per_block=unet.config.layers_per_block,
+ attention_head_dim=unet.config.attention_head_dim,
+ cross_attention_dim=unet.config.cross_attention_dim,
+ use_linear_projection=unet.config.use_linear_projection,
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
+ freq_shift=unet.config.freq_shift,
+ )
+ controlnet_params = controlnet.init_weights(rng=rng_params)
+ controlnet_params = unfreeze(controlnet_params)
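+ # Initialize the ControlNet encoder with the weights of the matching blocks from the pretrained UNet.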
+ for key in [
+ "conv_in",
+ "time_embedding",
+ "down_blocks_0",
+ "down_blocks_1",
+ "down_blocks_2",
+ "down_blocks_3",
+ "mid_block",
+ ]:
+ controlnet_params[key] = unet_params[key]
+
+ pipeline, pipeline_params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ tokenizer=tokenizer,
+ controlnet=controlnet,
+ safety_checker=None,
+ dtype=weight_dtype,
+ revision=args.revision,
+ from_pt=args.from_pt,
+ )
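+ # Replicate the frozen pipeline weights across devices so the jitted validation pipeline can run in parallel.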
+ pipeline_params = jax_utils.replicate(pipeline_params)
+
+ # Optimization
+ if args.scale_lr:
+ args.learning_rate = args.learning_rate * total_train_batch_size
+
+ constant_scheduler = optax.constant_schedule(args.learning_rate)
+
+ adamw = optax.adamw(
+ learning_rate=constant_scheduler,
+ b1=args.adam_beta1,
+ b2=args.adam_beta2,
+ eps=args.adam_epsilon,
+ weight_decay=args.adam_weight_decay,
+ )
+
+ optimizer = optax.chain(
+ optax.clip_by_global_norm(args.max_grad_norm),
+ adamw,
+ )
+
+ state = train_state.TrainState.create(apply_fn=controlnet.__call__, params=controlnet_params, tx=optimizer)
+
+ noise_scheduler, noise_scheduler_state = FlaxDDPMScheduler.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="scheduler"
+ )
+
+ # Initialize our training
+ validation_rng, train_rngs = jax.random.split(rng)
+ train_rngs = jax.random.split(train_rngs, jax.local_device_count())
+
+ def compute_snr(timesteps):
+ """
+ Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
+ """
+ alphas_cumprod = noise_scheduler_state.common.alphas_cumprod
+ sqrt_alphas_cumprod = alphas_cumprod**0.5
+ sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
+
+ alpha = sqrt_alphas_cumprod[timesteps]
+ sigma = sqrt_one_minus_alphas_cumprod[timesteps]
+ # Compute SNR.
+ snr = (alpha / sigma) ** 2
+ return snr
+
+ def train_step(state, unet_params, text_encoder_params, vae_params, batch, train_rng):
+ # reshape batch, add grad_step_dim if gradient_accumulation_steps > 1
+ if args.gradient_accumulation_steps > 1:
+ grad_steps = args.gradient_accumulation_steps
+ batch = jax.tree_map(lambda x: x.reshape((grad_steps, x.shape[0] // grad_steps) + x.shape[1:]), batch)
+
+ def compute_loss(params, minibatch, sample_rng):
+ # Convert images to latent space
+ vae_outputs = vae.apply(
+ {"params": vae_params}, minibatch["pixel_values"], deterministic=True, method=vae.encode
+ )
+ latents = vae_outputs.latent_dist.sample(sample_rng)
+ # (NHWC) -> (NCHW)
+ latents = jnp.transpose(latents, (0, 3, 1, 2))
+ latents = latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise_rng, timestep_rng = jax.random.split(sample_rng)
+ noise = jax.random.normal(noise_rng, latents.shape)
+ # Sample a random timestep for each image
+ bsz = latents.shape[0]
+ timesteps = jax.random.randint(
+ timestep_rng,
+ (bsz,),
+ 0,
+ noise_scheduler.config.num_train_timesteps,
+ )
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(
+ minibatch["input_ids"],
+ params=text_encoder_params,
+ train=False,
+ )[0]
+
+ controlnet_cond = minibatch["conditioning_pixel_values"]
+
+ # Predict the noise residual and compute loss
+ down_block_res_samples, mid_block_res_sample = controlnet.apply(
+ {"params": params},
+ noisy_latents,
+ timesteps,
+ encoder_hidden_states,
+ controlnet_cond,
+ train=True,
+ return_dict=False,
+ )
+
+ model_pred = unet.apply(
+ {"params": unet_params},
+ noisy_latents,
+ timesteps,
+ encoder_hidden_states,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ ).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ loss = (target - model_pred) ** 2
+
+ if args.snr_gamma is not None:
+ snr = jnp.array(compute_snr(timesteps))
+ if noise_scheduler.config.prediction_type == "v_prediction":
+ # Velocity objective requires that we add one to SNR values before we divide by them.
+ snr = snr + 1
+ snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr
+ loss = loss * snr_loss_weights
+
+ loss = loss.mean()
+
+ return loss
+
+ grad_fn = jax.value_and_grad(compute_loss)
+
+ # get a minibatch (one gradient accumulation slice)
+ def get_minibatch(batch, grad_idx):
+ return jax.tree_util.tree_map(
+ lambda x: jax.lax.dynamic_index_in_dim(x, grad_idx, keepdims=False),
+ batch,
+ )
+
+ def loss_and_grad(grad_idx, train_rng):
+ # create minibatch for the grad step
+ minibatch = get_minibatch(batch, grad_idx) if grad_idx is not None else batch
+ sample_rng, train_rng = jax.random.split(train_rng, 2)
+ loss, grad = grad_fn(state.params, minibatch, sample_rng)
+ return loss, grad, train_rng
+
+ if args.gradient_accumulation_steps == 1:
+ loss, grad, new_train_rng = loss_and_grad(None, train_rng)
+ else:
+ init_loss_grad_rng = (
+ 0.0, # initial value for cumul_loss
+ jax.tree_map(jnp.zeros_like, state.params), # initial value for cumul_grad
+ train_rng, # initial value for train_rng
+ )
+
+ def cumul_grad_step(grad_idx, loss_grad_rng):
+ cumul_loss, cumul_grad, train_rng = loss_grad_rng
+ loss, grad, new_train_rng = loss_and_grad(grad_idx, train_rng)
+ cumul_loss, cumul_grad = jax.tree_map(jnp.add, (cumul_loss, cumul_grad), (loss, grad))
+ return cumul_loss, cumul_grad, new_train_rng
+
+ loss, grad, new_train_rng = jax.lax.fori_loop(
+ 0,
+ args.gradient_accumulation_steps,
+ cumul_grad_step,
+ init_loss_grad_rng,
+ )
+ loss, grad = jax.tree_map(lambda x: x / args.gradient_accumulation_steps, (loss, grad))
+
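+ # Average the gradients across the data-parallel devices before the optimizer update.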
+ grad = jax.lax.pmean(grad, "batch")
+
+ new_state = state.apply_gradients(grads=grad)
+
+ metrics = {"loss": loss}
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
+
+ def l2(xs):
+ return jnp.sqrt(sum([jnp.vdot(x, x) for x in jax.tree_util.tree_leaves(xs)]))
+
+ metrics["l2_grads"] = l2(jax.tree_util.tree_leaves(grad))
+
+ return new_state, metrics, new_train_rng
+
+ # Create parallel version of the train step
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
+
+ # Replicate the train state on each device
+ state = jax_utils.replicate(state)
+ unet_params = jax_utils.replicate(unet_params)
+ text_encoder_params = jax_utils.replicate(text_encoder.params)
+ vae_params = jax_utils.replicate(vae_params)
+
+ # Train!
+ if args.streaming:
+ dataset_length = args.max_train_samples
+ else:
+ dataset_length = len(train_dataloader)
+ num_update_steps_per_epoch = math.ceil(dataset_length / args.gradient_accumulation_steps)
+
+ # Scheduler and math around the number of training steps.
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {args.max_train_samples if args.streaming else len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
+ logger.info(f" Total optimization steps = {args.num_train_epochs * num_update_steps_per_epoch}")
+
+ if jax.process_index() == 0 and args.report_to == "wandb":
+ wandb.define_metric("*", step_metric="train/step")
+ wandb.define_metric("train/step", step_metric="walltime")
+ wandb.config.update(
+ {
+ "num_train_examples": args.max_train_samples if args.streaming else len(train_dataset),
+ "total_train_batch_size": total_train_batch_size,
+ "total_optimization_step": args.num_train_epochs * num_update_steps_per_epoch,
+ "num_devices": jax.device_count(),
+ "controlnet_params": sum(np.prod(x.shape) for x in jax.tree_util.tree_leaves(state.params)),
+ }
+ )
+
+ global_step = step0 = 0
+ epochs = tqdm(
+ range(args.num_train_epochs),
+ desc="Epoch ... ",
+ position=0,
+ disable=jax.process_index() > 0,
+ )
+ if args.profile_memory:
+ jax.profiler.save_device_memory_profile(os.path.join(args.output_dir, "memory_initial.prof"))
+ t00 = t0 = time.monotonic()
+ for epoch in epochs:
+ # ======================== Training ================================
+
+ train_metrics = []
+ train_metric = None
+
+ steps_per_epoch = (
+ args.max_train_samples // total_train_batch_size
+ if args.streaming or args.max_train_samples
+ else len(train_dataset) // total_train_batch_size
+ )
+ train_step_progress_bar = tqdm(
+ total=steps_per_epoch,
+ desc="Training...",
+ position=1,
+ leave=False,
+ disable=jax.process_index() > 0,
+ )
+ # train
+ for batch in train_dataloader:
+ if args.profile_steps and global_step == 1:
+ train_metric["loss"].block_until_ready()
+ jax.profiler.start_trace(args.output_dir)
+ if args.profile_steps and global_step == 1 + args.profile_steps:
+ train_metric["loss"].block_until_ready()
+ jax.profiler.stop_trace()
+
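+ # shard() adds a leading device axis so each local device receives its slice of the batch.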
+ batch = shard(batch)
+ with jax.profiler.StepTraceAnnotation("train", step_num=global_step):
+ state, train_metric, train_rngs = p_train_step(
+ state, unet_params, text_encoder_params, vae_params, batch, train_rngs
+ )
+ train_metrics.append(train_metric)
+
+ train_step_progress_bar.update(1)
+
+ global_step += 1
+ if global_step >= args.max_train_steps:
+ break
+
+ if (
+ args.validation_prompt is not None
+ and global_step % args.validation_steps == 0
+ and jax.process_index() == 0
+ ):
+ _ = log_validation(
+ pipeline, pipeline_params, state.params, tokenizer, args, validation_rng, weight_dtype
+ )
+
+ if global_step % args.logging_steps == 0 and jax.process_index() == 0:
+ if args.report_to == "wandb":
+ train_metrics = jax_utils.unreplicate(train_metrics)
+ train_metrics = jax.tree_util.tree_map(lambda *m: jnp.array(m).mean(), *train_metrics)
+ wandb.log(
+ {
+ "walltime": time.monotonic() - t00,
+ "train/step": global_step,
+ "train/epoch": global_step / dataset_length,
+ "train/steps_per_sec": (global_step - step0) / (time.monotonic() - t0),
+ **{f"train/{k}": v for k, v in train_metrics.items()},
+ }
+ )
+ t0, step0 = time.monotonic(), global_step
+ train_metrics = []
+ if global_step % args.checkpointing_steps == 0 and jax.process_index() == 0:
+ controlnet.save_pretrained(
+ f"{args.output_dir}/{global_step}",
+ params=get_params_to_save(state.params),
+ )
+
+ train_metric = jax_utils.unreplicate(train_metric)
+ train_step_progress_bar.close()
+ epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})")
+
+ # Final validation & store model.
+ if jax.process_index() == 0:
+ if args.validation_prompt is not None:
+ if args.profile_validation:
+ jax.profiler.start_trace(args.output_dir)
+ image_logs = log_validation(
+ pipeline, pipeline_params, state.params, tokenizer, args, validation_rng, weight_dtype
+ )
+ if args.profile_validation:
+ jax.profiler.stop_trace()
+ else:
+ image_logs = None
+
+ controlnet.save_pretrained(
+ args.output_dir,
+ params=get_params_to_save(state.params),
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ image_logs=image_logs,
+ base_model=args.pretrained_model_name_or_path,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ if args.profile_memory:
+ jax.profiler.save_device_memory_profile(os.path.join(args.output_dir, "memory_final.prof"))
+ logger.info("Finished training.")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/controlnet/train_controlnet_sdxl.py b/diffusers/examples/controlnet/train_controlnet_sdxl.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4fa96dae8ffcf8b1282422ec93d9532608b1a4b
--- /dev/null
+++ b/diffusers/examples/controlnet/train_controlnet_sdxl.py
@@ -0,0 +1,1237 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import functools
+import gc
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+
+import accelerate
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from PIL import Image
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ ControlNetModel,
+ DDPMScheduler,
+ StableDiffusionXLControlNetPipeline,
+ UNet2DConditionModel,
+ UniPCMultistepScheduler,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
+from diffusers.utils.import_utils import is_xformers_available
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.24.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step):
+ logger.info("Running validation... ")
+
+ controlnet = accelerator.unwrap_model(controlnet)
+
+ pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ unet=unet,
+ controlnet=controlnet,
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ if len(args.validation_image) == len(args.validation_prompt):
+ validation_images = args.validation_image
+ validation_prompts = args.validation_prompt
+ elif len(args.validation_image) == 1:
+ validation_images = args.validation_image * len(args.validation_prompt)
+ validation_prompts = args.validation_prompt
+ elif len(args.validation_prompt) == 1:
+ validation_images = args.validation_image
+ validation_prompts = args.validation_prompt * len(args.validation_image)
+ else:
+ raise ValueError(
+ "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
+ )
+
+ image_logs = []
+
+ for validation_prompt, validation_image in zip(validation_prompts, validation_images):
+ validation_image = Image.open(validation_image).convert("RGB")
+ validation_image = validation_image.resize((args.resolution, args.resolution))
+
+ images = []
+
+ for _ in range(args.num_validation_images):
+ with torch.autocast("cuda"):
+ image = pipeline(
+ prompt=validation_prompt, image=validation_image, num_inference_steps=20, generator=generator
+ ).images[0]
+ images.append(image)
+
+ image_logs.append(
+ {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
+ )
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+
+ formatted_images = []
+
+ formatted_images.append(np.asarray(validation_image))
+
+ for image in images:
+ formatted_images.append(np.asarray(image))
+
+ formatted_images = np.stack(formatted_images)
+
+ tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ formatted_images = []
+
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+
+ formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
+
+ for image in images:
+ image = wandb.Image(image, caption=validation_prompt)
+ formatted_images.append(image)
+
+ tracker.log({"validation": formatted_images})
+ else:
+ logger.warn(f"image logging not implemented for {tracker.name}")
+
+ del pipeline
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ return image_logs
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def save_model_card(repo_id: str, image_logs=None, base_model=None, repo_folder=None):
+ img_str = ""
+ if image_logs is not None:
+ img_str = "You can find some example images below.\n"
+ for i, log in enumerate(image_logs):
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+ validation_image.save(os.path.join(repo_folder, "image_control.png"))
+ img_str += f"prompt: {validation_prompt}\n"
+ images = [validation_image] + images
+ make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
+ img_str += f"![images_{i})](./images_{i}.png)\n"
+
+ yaml = f"""
+---
+license: openrail++
+base_model: {base_model}
+tags:
+- stable-diffusion-xl
+- stable-diffusion-xl-diffusers
+- text-to-image
+- diffusers
+- controlnet
+inference: true
+---
+ """
+ model_card = f"""
+# controlnet-{repo_id}
+
+These are controlnet weights trained on {base_model} with new type of conditioning.
+{img_str}
+"""
+
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
+ f.write(yaml + model_card)
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--controlnet_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
+ " If not specified controlnet weights are initialized from unet.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help=(
+ "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
+ " float32 precision."
+ ),
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="controlnet-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--crops_coords_top_left_h",
+ type=int,
+ default=0,
+ help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
+ )
+ parser.add_argument(
+ "--crops_coords_top_left_w",
+ type=int,
+ default=0,
+ help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
+ "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
+ "instructions."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-6,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--set_grads_to_none",
+ action="store_true",
+ help=(
+ "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
+ " behaviors, so disable this argument if it causes any problems. More info:"
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
+ ),
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
+ )
+ parser.add_argument(
+ "--conditioning_image_column",
+ type=str,
+ default="conditioning_image",
+ help="The column of the dataset containing the controlnet conditioning image.",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--proportion_empty_prompts",
+ type=float,
+ default=0,
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ nargs="+",
+ help=(
+ "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
+ " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
+ " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
+ ),
+ )
+ parser.add_argument(
+ "--validation_image",
+ type=str,
+ default=None,
+ nargs="+",
+ help=(
+ "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
+ " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
+ " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
+ " `--validation_image` that will be used with all `--validation_prompt`s."
+ ),
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=100,
+ help=(
+ "Run validation every X steps. Validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
+ " and logging the images."
+ ),
+ )
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="sd_xl_train_controlnet",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers for"
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Specify either `--dataset_name` or `--train_data_dir`")
+
+ if args.dataset_name is not None and args.train_data_dir is not None:
+ raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")
+
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
+
+ if args.validation_prompt is not None and args.validation_image is None:
+ raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
+
+ if args.validation_prompt is None and args.validation_image is not None:
+ raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
+
+ if (
+ args.validation_image is not None
+ and args.validation_prompt is not None
+ and len(args.validation_image) != 1
+ and len(args.validation_prompt) != 1
+ and len(args.validation_image) != len(args.validation_prompt)
+ ):
+ raise ValueError(
+ "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
+ " or the same number of `--validation_prompt`s and `--validation_image`s"
+ )
+
+ if args.resolution % 8 != 0:
+ raise ValueError(
+ "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
+ )
+
+ return args
+
+
+def get_train_dataset(args, accelerator):
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ else:
+ if args.train_data_dir is not None:
+ dataset = load_dataset(
+ args.train_data_dir,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ if args.image_column is None:
+ image_column = column_names[0]
+ logger.info(f"image column defaulting to {image_column}")
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ if args.caption_column is None:
+ caption_column = column_names[1]
+ logger.info(f"caption column defaulting to {caption_column}")
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ if args.conditioning_image_column is None:
+ conditioning_image_column = column_names[2]
+ logger.info(f"conditioning image column defaulting to {conditioning_image_column}")
+ else:
+ conditioning_image_column = args.conditioning_image_column
+ if conditioning_image_column not in column_names:
+ raise ValueError(
+ f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ with accelerator.main_process_first():
+ train_dataset = dataset["train"].shuffle(seed=args.seed)
+ if args.max_train_samples is not None:
+ train_dataset = train_dataset.select(range(args.max_train_samples))
+ return train_dataset
+
+
+# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
+def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train=True):
+ prompt_embeds_list = []
+
+ captions = []
+ for caption in prompt_batch:
+ if random.random() < proportion_empty_prompts:
+ captions.append("")
+ elif isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+
+ with torch.no_grad():
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+ text_inputs = tokenizer(
+ captions,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_embeds = text_encoder(
+ text_input_ids.to(text_encoder.device),
+ output_hidden_states=True,
+ )
+
+ # We are only interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
+ return prompt_embeds, pooled_prompt_embeds
+
+
+def prepare_train_dataset(dataset, accelerator):
+ image_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ conditioning_image_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution),
+ transforms.ToTensor(),
+ ]
+ )
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[args.image_column]]
+ images = [image_transforms(image) for image in images]
+
+ conditioning_images = [image.convert("RGB") for image in examples[args.conditioning_image_column]]
+ conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]
+
+ examples["pixel_values"] = images
+ examples["conditioning_pixel_values"] = conditioning_images
+
+ return examples
+
+ with accelerator.main_process_first():
+ dataset = dataset.with_transform(preprocess_train)
+
+ return dataset
+
+
+def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
+ conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ prompt_ids = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples])
+
+ add_text_embeds = torch.stack([torch.tensor(example["text_embeds"]) for example in examples])
+ add_time_ids = torch.stack([torch.tensor(example["time_ids"]) for example in examples])
+
+ return {
+ "pixel_values": pixel_values,
+ "conditioning_pixel_values": conditioning_pixel_values,
+ "prompt_ids": prompt_ids,
+ "unet_added_conditions": {"text_embeds": add_text_embeds, "time_ids": add_time_ids},
+ }
+
+
+def main(args):
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizers
+ tokenizer_one = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
+ )
+ tokenizer_two = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
+ )
+
+ # import correct text encoder classes
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+ )
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
+ )
+ vae_path = (
+ args.pretrained_model_name_or_path
+ if args.pretrained_vae_model_name_or_path is None
+ else args.pretrained_vae_model_name_or_path
+ )
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ )
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+ )
+
+ if args.controlnet_model_name_or_path:
+ logger.info("Loading existing controlnet weights")
+ controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path)
+ else:
+ logger.info("Initializing controlnet weights from unet")
+ controlnet = ControlNetModel.from_unet(unet)
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ i = len(weights) - 1
+
+ while len(weights) > 0:
+ weights.pop()
+ model = models[i]
+
+ sub_dir = "controlnet"
+ model.save_pretrained(os.path.join(output_dir, sub_dir))
+
+ i -= 1
+
+ def load_model_hook(models, input_dir):
+ while len(models) > 0:
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = ControlNetModel.from_pretrained(input_dir, subfolder="controlnet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ vae.requires_grad_(False)
+ unet.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+ controlnet.train()
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warn(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ controlnet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ if args.gradient_checkpointing:
+ controlnet.enable_gradient_checkpointing()
+ unet.enable_gradient_checkpointing()
+
+ # Check that all trainable models are in full precision
+ low_precision_error_string = (
+ " Please make sure to always have all model weights in full float32 precision when starting training - even if"
+ " doing mixed precision training, copy of the weights should still be float32."
+ )
+
+ if accelerator.unwrap_model(controlnet).dtype != torch.float32:
+ raise ValueError(
+ f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet).dtype}. {low_precision_error_string}"
+ )
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimizer creation
+ params_to_optimize = controlnet.parameters()
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move vae, unet and text_encoder to device and cast to weight_dtype
+ # The VAE is in float32 to avoid NaN losses.
+ if args.pretrained_vae_model_name_or_path is not None:
+ vae.to(accelerator.device, dtype=weight_dtype)
+ else:
+ vae.to(accelerator.device, dtype=torch.float32)
+ unet.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+
+ # Here, we compute not just the text embeddings but also the additional embeddings
+ # needed for the SD XL UNet to operate.
+ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizers, is_train=True):
+ original_size = (args.resolution, args.resolution)
+ target_size = (args.resolution, args.resolution)
+ crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)
+ prompt_batch = batch[args.caption_column]
+
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
+ prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train
+ )
+ add_text_embeds = pooled_prompt_embeds
+
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
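+ # The SDXL UNet expects six micro-conditioning values packed in this order:
+ # (original_height, original_width, crop_top, crop_left, target_height, target_width).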
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_time_ids = torch.tensor([add_time_ids])
+
+ prompt_embeds = prompt_embeds.to(accelerator.device)
+ add_text_embeds = add_text_embeds.to(accelerator.device)
+ add_time_ids = add_time_ids.repeat(len(prompt_batch), 1)
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype)
+ unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+
+ return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}
+
+ # Let's first compute all the embeddings so that we can free up the text encoders
+ # from memory.
+ text_encoders = [text_encoder_one, text_encoder_two]
+ tokenizers = [tokenizer_one, tokenizer_two]
+ train_dataset = get_train_dataset(args, accelerator)
+ compute_embeddings_fn = functools.partial(
+ compute_embeddings,
+ text_encoders=text_encoders,
+ tokenizers=tokenizers,
+ proportion_empty_prompts=args.proportion_empty_prompts,
+ )
+ with accelerator.main_process_first():
+ from datasets.fingerprint import Hasher
+
+ # fingerprint used by the cache for the other processes to load the result
+ # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
+ new_fingerprint = Hasher.hash(args)
+ train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)
+
+ del text_encoders, tokenizers
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ # Then get the training dataset ready to be passed to the dataloader.
+ train_dataset = prepare_train_dataset(train_dataset, accelerator)
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ controlnet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers are initialized automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+
+ # tensorboard cannot handle list types for config
+ tracker_config.pop("validation_prompt")
+ tracker_config.pop("validation_image")
+
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ image_logs = None
+ for epoch in range(first_epoch, args.num_train_epochs):
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(controlnet):
+ # Convert images to latent space
+ if args.pretrained_vae_model_name_or_path is not None:
+ pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
+ else:
+ pixel_values = batch["pixel_values"]
+ latents = vae.encode(pixel_values).latent_dist.sample()
+ latents = latents * vae.config.scaling_factor
+ if args.pretrained_vae_model_name_or_path is None:
+ latents = latents.to(weight_dtype)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # ControlNet conditioning.
+ controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
+ down_block_res_samples, mid_block_res_sample = controlnet(
+ noisy_latents,
+ timesteps,
+ encoder_hidden_states=batch["prompt_ids"],
+ added_cond_kwargs=batch["unet_added_conditions"],
+ controlnet_cond=controlnet_image,
+ return_dict=False,
+ )
+
+ # Predict the noise residual
+ model_pred = unet(
+ noisy_latents,
+ timesteps,
+ encoder_hidden_states=batch["prompt_ids"],
+ added_cond_kwargs=batch["unet_added_conditions"],
+ down_block_additional_residuals=[
+ sample.to(dtype=weight_dtype) for sample in down_block_res_samples
+ ],
+ mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
+ ).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = controlnet.parameters()
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=args.set_grads_to_none)
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if args.validation_prompt is not None and global_step % args.validation_steps == 0:
+ image_logs = log_validation(
+ vae, unet, controlnet, args, accelerator, weight_dtype, global_step
+ )
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ controlnet = accelerator.unwrap_model(controlnet)
+ controlnet.save_pretrained(args.output_dir)
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ image_logs=image_logs,
+ base_model=args.pretrained_model_name_or_path,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/custom_diffusion/README.md b/diffusers/examples/custom_diffusion/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..e686933feb51d09fd73c9b705620132ee1713f3c
--- /dev/null
+++ b/diffusers/examples/custom_diffusion/README.md
@@ -0,0 +1,280 @@
+# Custom Diffusion training example
+
+[Custom Diffusion](https://arxiv.org/abs/2212.04488) is a method to customize text-to-image models like Stable Diffusion given just a few (4~5) images of a subject.
+The `train_custom_diffusion.py` script shows how to implement the training procedure and adapt it for stable diffusion.
+
+## Running locally with PyTorch
+
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+Then cd into the example folder and run
+
+```bash
+pip install -r requirements.txt
+pip install clip-retrieval
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or for a default accelerate configuration without answering questions about your environment
+
+```bash
+accelerate config default
+```
+
+Or, if your environment doesn't support an interactive shell (e.g. a notebook)
+
+```python
+from accelerate.utils import write_basic_config
+write_basic_config()
+```
+### Cat example 😺
+
+Now let's get our dataset. Download the dataset from [here](https://www.cs.cmu.edu/~custom-diffusion/assets/data.zip) and unzip it.
+
+We also collect 200 real images using `clip-retrieval` and combine them with the target images in the training dataset as a regularization set. This prevents overfitting to the given target images. The flags `with_prior_preservation` and `real_prior`, together with `prior_loss_weight=1.`, enable this regularization.
+The `class_prompt` should be the category name of the target image (e.g. `cat`). The collected real images have text captions similar to the `class_prompt` and are saved in `class_data_dir`. You can disable `real_prior` to use generated images as the regularization set instead. To collect the real images, run this command before training.
+
+```bash
+pip install clip-retrieval
+python retrieve.py --class_prompt cat --class_data_dir real_reg/samples_cat --num_class_images 200
+```
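+
+Once the command finishes, `class_data_dir` contains an `images/` folder plus matching `caption.txt`, `urls.txt`, and `images.txt` files (this layout comes from `retrieve.py`, included later in this diff). Below is a minimal sketch, not part of the training script, for sanity-checking the retrieved regularization set; the path is the `--class_data_dir` used above:
+
+```python
+from pathlib import Path
+
+class_data_dir = Path("real_reg/samples_cat")  # same value as --class_data_dir above
+
+# retrieve.py writes one caption/url/file line per successfully downloaded image
+captions = (class_data_dir / "caption.txt").read_text().splitlines()
+paths = (class_data_dir / "images.txt").read_text().splitlines()
+images = sorted((class_data_dir / "images").glob("*.jpg"))
+
+print(f"{len(images)} images, {len(captions)} captions, {len(paths)} recorded paths")
+assert len(images) == len(captions) == len(paths), "regularization set looks incomplete; re-run retrieve.py"
+```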
+
+**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export OUTPUT_DIR="path-to-save-model"
+export INSTANCE_DIR="./data/cat"
+
+accelerate launch train_custom_diffusion.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --class_data_dir=./real_reg/samples_cat/ \
+ --with_prior_preservation --real_prior --prior_loss_weight=1.0 \
+ --class_prompt="cat" --num_class_images=200 \
+ --instance_prompt="photo of a cat" \
+ --resolution=512 \
+ --train_batch_size=2 \
+ --learning_rate=1e-5 \
+ --lr_warmup_steps=0 \
+ --max_train_steps=250 \
+ --scale_lr --hflip \
+ --modifier_token ""
+```
+
+**Use `--enable_xformers_memory_efficient_attention` for faster training with lower VRAM requirement (16GB per GPU). Follow [this guide](https://github.com/facebookresearch/xformers) for installation instructions.**
+
+To track your experiments using Weights and Biases (`wandb`) and to save intermediate results (which we HIGHLY recommend), follow these steps:
+
+* Install `wandb`: `pip install wandb`.
+* Authorize: `wandb login`.
+* Then specify a `validation_prompt` and set `report_to` to `wandb` while launching training. You can also configure the following related arguments:
+ * `num_validation_images`
+ * `validation_steps`
+
+Here is an example command:
+
+```bash
+accelerate launch train_custom_diffusion.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --class_data_dir=./real_reg/samples_cat/ \
+ --with_prior_preservation --real_prior --prior_loss_weight=1.0 \
+ --class_prompt="cat" --num_class_images=200 \
+ --instance_prompt="photo of a cat" \
+ --resolution=512 \
+ --train_batch_size=2 \
+ --learning_rate=1e-5 \
+ --lr_warmup_steps=0 \
+ --max_train_steps=250 \
+ --scale_lr --hflip \
+ --modifier_token "" \
+ --validation_prompt=" cat sitting in a bucket" \
+ --report_to="wandb"
+```
+
+Here is an example [Weights and Biases page](https://wandb.ai/sayakpaul/custom-diffusion/runs/26ghrcau) where you can check out the intermediate results along with other training details.
+
+If you specify `--push_to_hub`, the learned parameters will be pushed to a repository on the Hugging Face Hub. Here is an [example repository](https://huggingface.co/sayakpaul/custom-diffusion-cat).
+
+### Training on multiple concepts 🐱🪵
+
+Provide a [json](https://github.com/adobe-research/custom-diffusion/blob/main/assets/concept_list.json) file with the info about each concept, similar to [this](https://github.com/ShivamShrirao/diffusers/blob/main/examples/dreambooth/train_dreambooth.py).
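+
+For reference, here is a minimal sketch that writes such a file from Python; the field names (`instance_prompt`, `class_prompt`, `instance_data_dir`, `class_data_dir`) are assumptions based on the linked examples, so verify them against those files and against how `--concepts_list` is parsed in `train_custom_diffusion.py` before training:
+
+```python
+import json
+
+# Illustrative only -- verify the field names against the linked concept_list.json before use.
+concepts_list = [
+    {
+        "instance_prompt": "photo of a <new1> cat",
+        "class_prompt": "cat",
+        "instance_data_dir": "./data/cat",
+        "class_data_dir": "./real_reg/samples_cat/",
+    },
+    {
+        "instance_prompt": "photo of a <new2> wooden pot",
+        "class_prompt": "wooden pot",
+        "instance_data_dir": "./data/wooden_pot",
+        "class_data_dir": "./real_reg/samples_wooden_pot/",
+    },
+]
+
+with open("concept_list.json", "w") as f:
+    json.dump(concepts_list, f, indent=4)
+```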
+
+To collect the real images run this command for each concept in the json file.
+
+```bash
+pip install clip-retrieval
+python retrieve.py --class_prompt {} --class_data_dir {} --num_class_images 200
+```
+
+And then we're ready to start training!
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_custom_diffusion.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --output_dir=$OUTPUT_DIR \
+ --concepts_list=./concept_list.json \
+ --with_prior_preservation --real_prior --prior_loss_weight=1.0 \
+ --resolution=512 \
+ --train_batch_size=2 \
+ --learning_rate=1e-5 \
+ --lr_warmup_steps=0 \
+ --max_train_steps=500 \
+ --num_class_images=200 \
+ --scale_lr --hflip \
+ --modifier_token "+"
+```
+
+Here is an example [Weights and Biases page](https://wandb.ai/sayakpaul/custom-diffusion/runs/3990tzkg) where you can check out the intermediate results along with other training details.
+
+### Training on human faces
+
+For fine-tuning on human faces we found the following configuration to work better: `learning_rate=5e-6`, `max_train_steps=1000 to 2000`, and `freeze_model=crossattn` with at least 15-20 images.
+
+To collect the real images, run this command before training.
+
+```bash
+pip install clip-retrieval
+python retrieve.py --class_prompt person --class_data_dir real_reg/samples_person --num_class_images 200
+```
+
+Then start training!
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export OUTPUT_DIR="path-to-save-model"
+export INSTANCE_DIR="path-to-images"
+
+accelerate launch train_custom_diffusion.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --class_data_dir=./real_reg/samples_person/ \
+ --with_prior_preservation --real_prior --prior_loss_weight=1.0 \
+ --class_prompt="person" --num_class_images=200 \
+ --instance_prompt="photo of a person" \
+ --resolution=512 \
+ --train_batch_size=2 \
+ --learning_rate=5e-6 \
+ --lr_warmup_steps=0 \
+ --max_train_steps=1000 \
+ --scale_lr --hflip --noaug \
+ --freeze_model crossattn \
+ --modifier_token "" \
+ --enable_xformers_memory_efficient_attention
+```
+
+## Inference
+
+Once you have trained a model using the above command, you can run inference with the code below. Make sure to include the `modifier token` (e.g. `<new1>` in the above example) in your prompt.
+
+```python
+import torch
+from diffusers import DiffusionPipeline
+
+pipe = DiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
+).to("cuda")
+pipe.unet.load_attn_procs(
+ "path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin"
+)
+pipe.load_textual_inversion("path-to-save-model", weight_name=".bin")
+
+image = pipe(
+ " cat sitting in a bucket",
+ num_inference_steps=100,
+ guidance_scale=6.0,
+ eta=1.0,
+).images[0]
+image.save("cat.png")
+```
+
+It's possible to directly load these parameters from a Hub repository:
+
+```python
+import torch
+from huggingface_hub.repocard import RepoCard
+from diffusers import DiffusionPipeline
+
+model_id = "sayakpaul/custom-diffusion-cat"
+card = RepoCard.load(model_id)
+base_model_id = card.data.to_dict()["base_model"]
+
+pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to(
+"cuda")
+pipe.unet.load_attn_procs(model_id, weight_name="pytorch_custom_diffusion_weights.bin")
+pipe.load_textual_inversion(model_id, weight_name="<new1>.bin")
+
+image = pipe(
+ " cat sitting in a bucket",
+ num_inference_steps=100,
+ guidance_scale=6.0,
+ eta=1.0,
+).images[0]
+image.save("cat.png")
+```
+
+Here is an example of performing inference with multiple concepts:
+
+```python
+import torch
+from huggingface_hub.repocard import RepoCard
+from diffusers import DiffusionPipeline
+
+model_id = "sayakpaul/custom-diffusion-cat-wooden-pot"
+card = RepoCard.load(model_id)
+base_model_id = card.data.to_dict()["base_model"]
+
+pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to(
+"cuda")
+pipe.unet.load_attn_procs(model_id, weight_name="pytorch_custom_diffusion_weights.bin")
+pipe.load_textual_inversion(model_id, weight_name="<new1>.bin")
+pipe.load_textual_inversion(model_id, weight_name="<new2>.bin")
+
+image = pipe(
+ "the cat sculpture in the style of a wooden pot",
+ num_inference_steps=100,
+ guidance_scale=6.0,
+ eta=1.0,
+).images[0]
+image.save("multi-subject.png")
+```
+
+Here, `cat` and `wooden pot` refer to the multiple concepts.
+
+### Inference from a training checkpoint
+
+You can also perform inference from one of the complete checkpoints saved during the training process, if you used the `--checkpointing_steps` argument.
+
+TODO.
+
+## Set grads to none
+To save even more memory, pass the `--set_grads_to_none` argument to the script. This will set grads to None instead of zero. However, be aware that it changes certain behaviors, so if you start experiencing any problems, remove this argument.
+
+More info: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html
+
+## Experimental results
+You can refer to [our webpage](https://www.cs.cmu.edu/~custom-diffusion/) that discusses our experiments in detail. We also released a more extensive dataset of 101 concepts for evaluating model customization methods. For more details please refer to our [dataset webpage](https://www.cs.cmu.edu/~custom-diffusion/dataset.html).
\ No newline at end of file
diff --git a/diffusers/examples/custom_diffusion/requirements.txt b/diffusers/examples/custom_diffusion/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7d93f3d03bd8eba09b8cab5e570d15380456b66a
--- /dev/null
+++ b/diffusers/examples/custom_diffusion/requirements.txt
@@ -0,0 +1,6 @@
+accelerate
+torchvision
+transformers>=4.25.1
+ftfy
+tensorboard
+Jinja2
diff --git a/diffusers/examples/custom_diffusion/retrieve.py b/diffusers/examples/custom_diffusion/retrieve.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f050c15227b2e1157a38a0b7155f6c515df575d
--- /dev/null
+++ b/diffusers/examples/custom_diffusion/retrieve.py
@@ -0,0 +1,87 @@
+# Copyright 2023 Custom Diffusion authors. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import os
+from io import BytesIO
+from pathlib import Path
+
+import requests
+from clip_retrieval.clip_client import ClipClient
+from PIL import Image
+from tqdm import tqdm
+
+
+def retrieve(class_prompt, class_data_dir, num_class_images):
+ factor = 1.5
+ num_images = int(factor * num_class_images)
+ client = ClipClient(
+ url="https://knn.laion.ai/knn-service", indice_name="laion_400m", num_images=num_images, aesthetic_weight=0.1
+ )
+
+ os.makedirs(f"{class_data_dir}/images", exist_ok=True)
+ if len(list(Path(f"{class_data_dir}/images").iterdir())) >= num_class_images:
+ return
+
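+ # The KNN service may return fewer results than requested, so keep enlarging the
+ # request (up to ~10k images) until there are enough candidates to cover num_class_images.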
+ while True:
+ class_images = client.query(text=class_prompt)
+ if len(class_images) >= factor * num_class_images or num_images > 1e4:
+ break
+ else:
+ num_images = int(factor * num_images)
+ client = ClipClient(
+ url="https://knn.laion.ai/knn-service",
+ indice_name="laion_400m",
+ num_images=num_images,
+ aesthetic_weight=0.1,
+ )
+
+ count = 0
+ total = 0
+ pbar = tqdm(desc="downloading real regularization images", total=num_class_images)
+
+ with open(f"{class_data_dir}/caption.txt", "w") as f1, open(f"{class_data_dir}/urls.txt", "w") as f2, open(
+ f"{class_data_dir}/images.txt", "w"
+ ) as f3:
+ while total < num_class_images:
+ images = class_images[count]
+ count += 1
+ try:
+ img = requests.get(images["url"], timeout=30)
+ if img.status_code == 200:
+ _ = Image.open(BytesIO(img.content))
+ with open(f"{class_data_dir}/images/{total}.jpg", "wb") as f:
+ f.write(img.content)
+ f1.write(images["caption"] + "\n")
+ f2.write(images["url"] + "\n")
+ f3.write(f"{class_data_dir}/images/{total}.jpg" + "\n")
+ total += 1
+ pbar.update(1)
+ else:
+ continue
+ except Exception:
+ continue
+ return
+
+
+def parse_args():
+ parser = argparse.ArgumentParser("", add_help=False)
+ parser.add_argument("--class_prompt", help="text prompt to retrieve images", required=True, type=str)
+ parser.add_argument("--class_data_dir", help="path to save images", required=True, type=str)
+ parser.add_argument("--num_class_images", help="number of images to download", default=200, type=int)
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ retrieve(args.class_prompt, args.class_data_dir, args.num_class_images)
diff --git a/diffusers/examples/custom_diffusion/train_custom_diffusion.py b/diffusers/examples/custom_diffusion/train_custom_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7f78841a81a08af515727150877482255fde826
--- /dev/null
+++ b/diffusers/examples/custom_diffusion/train_custom_diffusion.py
@@ -0,0 +1,1340 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2023 Custom Diffusion authors and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import itertools
+import json
+import logging
+import math
+import os
+import random
+import shutil
+import warnings
+from pathlib import Path
+
+import numpy as np
+import safetensors
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from huggingface_hub import HfApi, create_repo
+from huggingface_hub.utils import insecure_hashlib
+from packaging import version
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ DiffusionPipeline,
+ DPMSolverMultistepScheduler,
+ UNet2DConditionModel,
+)
+from diffusers.loaders import AttnProcsLayers
+from diffusers.models.attention_processor import (
+ CustomDiffusionAttnProcessor,
+ CustomDiffusionAttnProcessor2_0,
+ CustomDiffusionXFormersAttnProcessor,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.24.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def freeze_params(params):
+ for param in params:
+ param.requires_grad = False
+
+
+def save_model_card(repo_id: str, images=None, base_model=None, prompt=None, repo_folder=None):
+ img_str = ""
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ img_str += f"![img_{i}](./image_{i}.png)\n"
+
+ yaml = f"""
+---
+license: creativeml-openrail-m
+base_model: {base_model}
+instance_prompt: {prompt}
+tags:
+- stable-diffusion
+- stable-diffusion-diffusers
+- text-to-image
+- diffusers
+- custom-diffusion
+inference: true
+---
+ """
+ model_card = f"""
+# Custom Diffusion - {repo_id}
+
+These are Custom Diffusion adaption weights for {base_model}. The weights were trained on {prompt} using [Custom Diffusion](https://www.cs.cmu.edu/~custom-diffusion). You can find some example images in the following. \n
+{img_str}
+
+\nFor more details on the training, please follow [this link](https://github.com/huggingface/diffusers/blob/main/examples/custom_diffusion).
+"""
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
+ f.write(yaml + model_card)
+
+
+def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ revision=revision,
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "RobertaSeriesModelWithTransformation":
+ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
+
+ return RobertaSeriesModelWithTransformation
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def collate_fn(examples, with_prior_preservation):
+ input_ids = [example["instance_prompt_ids"] for example in examples]
+ pixel_values = [example["instance_images"] for example in examples]
+ mask = [example["mask"] for example in examples]
+ # Concat class and instance examples for prior preservation.
+ # We do this to avoid doing two forward passes.
+ if with_prior_preservation:
+ input_ids += [example["class_prompt_ids"] for example in examples]
+ pixel_values += [example["class_images"] for example in examples]
+ mask += [example["class_mask"] for example in examples]
+
+ input_ids = torch.cat(input_ids, dim=0)
+ pixel_values = torch.stack(pixel_values)
+ mask = torch.stack(mask)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ mask = mask.to(memory_format=torch.contiguous_format).float()
+
+ batch = {"input_ids": input_ids, "pixel_values": pixel_values, "mask": mask.unsqueeze(1)}
+ return batch
+
+
+class PromptDataset(Dataset):
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
+
+ def __init__(self, prompt, num_samples):
+ self.prompt = prompt
+ self.num_samples = num_samples
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ example = {}
+ example["prompt"] = self.prompt
+ example["index"] = index
+ return example
+
+
+class CustomDiffusionDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+    It pre-processes the images and tokenizes the prompts.
+ """
+
+ def __init__(
+ self,
+ concepts_list,
+ tokenizer,
+ size=512,
+ mask_size=64,
+ center_crop=False,
+ with_prior_preservation=False,
+ num_class_images=200,
+ hflip=False,
+ aug=True,
+ ):
+ self.size = size
+ self.mask_size = mask_size
+ self.center_crop = center_crop
+ self.tokenizer = tokenizer
+ self.interpolation = Image.BILINEAR
+ self.aug = aug
+
+ self.instance_images_path = []
+ self.class_images_path = []
+ self.with_prior_preservation = with_prior_preservation
+ for concept in concepts_list:
+ inst_img_path = [
+ (x, concept["instance_prompt"]) for x in Path(concept["instance_data_dir"]).iterdir() if x.is_file()
+ ]
+ self.instance_images_path.extend(inst_img_path)
+
+ if with_prior_preservation:
+ class_data_root = Path(concept["class_data_dir"])
+ if os.path.isdir(class_data_root):
+ class_images_path = list(class_data_root.iterdir())
+ class_prompt = [concept["class_prompt"] for _ in range(len(class_images_path))]
+ else:
+ with open(class_data_root, "r") as f:
+ class_images_path = f.read().splitlines()
+ with open(concept["class_prompt"], "r") as f:
+ class_prompt = f.read().splitlines()
+
+ class_img_path = list(zip(class_images_path, class_prompt))
+ self.class_images_path.extend(class_img_path[:num_class_images])
+
+ random.shuffle(self.instance_images_path)
+ self.num_instance_images = len(self.instance_images_path)
+ self.num_class_images = len(self.class_images_path)
+ self._length = max(self.num_class_images, self.num_instance_images)
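+        # hflip is a bool, so the flip probability is 0.5 when --hflip is passed and 0.0 otherwise.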
+ self.flip = transforms.RandomHorizontalFlip(0.5 * hflip)
+
+ self.image_transforms = transforms.Compose(
+ [
+ self.flip,
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return self._length
+
+ def preprocess(self, image, scale, resample):
+ outer, inner = self.size, scale
+ factor = self.size // self.mask_size
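+        # `factor` converts image-space coordinates into mask coordinates, so the valid-region
+        # mask built below lines up with the latent resolution used for the loss (mask_size).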
+ if scale > self.size:
+ outer, inner = scale, self.size
+ top, left = np.random.randint(0, outer - inner + 1), np.random.randint(0, outer - inner + 1)
+ image = image.resize((scale, scale), resample=resample)
+ image = np.array(image).astype(np.uint8)
+ image = (image / 127.5 - 1.0).astype(np.float32)
+ instance_image = np.zeros((self.size, self.size, 3), dtype=np.float32)
+ mask = np.zeros((self.size // factor, self.size // factor))
+ if scale > self.size:
+ instance_image = image[top : top + inner, left : left + inner, :]
+ mask = np.ones((self.size // factor, self.size // factor))
+ else:
+ instance_image[top : top + inner, left : left + inner, :] = image
+ mask[
+ top // factor + 1 : (top + scale) // factor - 1, left // factor + 1 : (left + scale) // factor - 1
+ ] = 1.0
+ return instance_image, mask
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image, instance_prompt = self.instance_images_path[index % self.num_instance_images]
+ instance_image = Image.open(instance_image)
+ if not instance_image.mode == "RGB":
+ instance_image = instance_image.convert("RGB")
+ instance_image = self.flip(instance_image)
+
+ # apply resize augmentation and create a valid image region mask
+ random_scale = self.size
+ if self.aug:
+ random_scale = (
+ np.random.randint(self.size // 3, self.size + 1)
+ if np.random.uniform() < 0.66
+ else np.random.randint(int(1.2 * self.size), int(1.4 * self.size))
+ )
+ instance_image, mask = self.preprocess(instance_image, random_scale, self.interpolation)
+
+ if random_scale < 0.6 * self.size:
+ instance_prompt = np.random.choice(["a far away ", "very small "]) + instance_prompt
+ elif random_scale > self.size:
+ instance_prompt = np.random.choice(["zoomed in ", "close up "]) + instance_prompt
+
+ example["instance_images"] = torch.from_numpy(instance_image).permute(2, 0, 1)
+ example["mask"] = torch.from_numpy(mask)
+ example["instance_prompt_ids"] = self.tokenizer(
+ instance_prompt,
+ truncation=True,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ ).input_ids
+
+ if self.with_prior_preservation:
+ class_image, class_prompt = self.class_images_path[index % self.num_class_images]
+ class_image = Image.open(class_image)
+ if not class_image.mode == "RGB":
+ class_image = class_image.convert("RGB")
+ example["class_images"] = self.image_transforms(class_image)
+ example["class_mask"] = torch.ones_like(example["mask"])
+ example["class_prompt_ids"] = self.tokenizer(
+ class_prompt,
+ truncation=True,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ ).input_ids
+
+ return example
+
+
+def save_new_embed(text_encoder, modifier_token_id, accelerator, args, output_dir, safe_serialization=True):
+ """Saves the new token embeddings from the text encoder."""
+ logger.info("Saving embeddings")
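+    # Each modifier token's learned embedding is written to its own file so it can later be
+    # loaded back with `pipeline.load_textual_inversion` (see the final inference step in `main`).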
+ learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight
+ for x, y in zip(modifier_token_id, args.modifier_token):
+ learned_embeds_dict = {}
+ learned_embeds_dict[y] = learned_embeds[x]
+        filename = f"{output_dir}/{y}.safetensors" if safe_serialization else f"{output_dir}/{y}.bin"
+
+ if safe_serialization:
+ safetensors.torch.save_file(learned_embeds_dict, filename, metadata={"format": "pt"})
+ else:
+ torch.save(learned_embeds_dict, filename)
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Custom Diffusion training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ help="A folder containing the training data of instance images.",
+ )
+ parser.add_argument(
+ "--class_data_dir",
+ type=str,
+ default=None,
+ help="A folder containing the training data of class images.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default=None,
+ help="The prompt with identifier specifying the instance",
+ )
+ parser.add_argument(
+ "--class_prompt",
+ type=str,
+ default=None,
+ help="The prompt to specify images in the same class as provided instance images.",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=2,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=50,
+        help=(
+            "Run validation every X steps. Validation consists of running the prompt"
+            " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--with_prior_preservation",
+ default=False,
+ action="store_true",
+ help="Flag to add prior preservation loss.",
+ )
+ parser.add_argument(
+ "--real_prior",
+ default=False,
+ action="store_true",
+        help="Use real images (retrieved with retrieve.py) as the prior instead of generated class images.",
+ )
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+ parser.add_argument(
+ "--num_class_images",
+ type=int,
+ default=200,
+        help=(
+            "Minimal number of class images for the prior preservation loss. If there are not enough images already present in"
+ " class_data_dir, additional images will be sampled with class_prompt."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="custom-diffusion-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=250,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+        help="Number of update steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-5,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=2,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--freeze_model",
+ type=str,
+ default="crossattn_kv",
+ choices=["crossattn_kv", "crossattn"],
+        help="Set to `crossattn` to fine-tune all parameters in the cross-attention layers; `crossattn_kv` trains only the key and value projections.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+        help=(
+            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--prior_generation_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp32", "fp16", "bf16"],
+        help=(
+            "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to fp16 if a GPU is available, else fp32."
+ ),
+ )
+ parser.add_argument(
+ "--concepts_list",
+ type=str,
+ default=None,
+ help="Path to json containing multiple concepts, will overwrite parameters like instance_prompt, class_prompt, etc.",
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--set_grads_to_none",
+ action="store_true",
+        help=(
+            "Save more memory by setting grads to None instead of zero. Be aware that this changes certain"
+ " behaviors, so disable this argument if it causes any problems. More info:"
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
+ ),
+ )
+ parser.add_argument(
+ "--modifier_token",
+ type=str,
+ default=None,
+ help="A token to use as a modifier for the concept.",
+ )
+ parser.add_argument(
+ "--initializer_token", type=str, default="ktn+pll+ucd", help="A token to use as initializer word."
+ )
+ parser.add_argument("--hflip", action="store_true", help="Apply horizontal flip data augmentation.")
+ parser.add_argument(
+ "--noaug",
+ action="store_true",
+        help="Don't apply data augmentation when this flag is enabled.",
+ )
+ parser.add_argument(
+ "--no_safe_serialization",
+ action="store_true",
+        help="If specified, save the checkpoint in the original PyTorch format instead of `safetensors`.",
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.with_prior_preservation:
+ if args.concepts_list is None:
+ if args.class_data_dir is None:
+ raise ValueError("You must specify a data directory for class images.")
+ if args.class_prompt is None:
+ raise ValueError("You must specify prompt for class images.")
+ else:
+ # logger is not available yet
+        if args.class_data_dir is not None:
+            warnings.warn("You do not need --class_data_dir without --with_prior_preservation.")
+        if args.class_prompt is not None:
+            warnings.warn("You do not need --class_prompt without --with_prior_preservation.")
+
+ return args
+
+
+def main(args):
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+ import wandb
+
+    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate.
+ # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
+ # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("custom-diffusion", config=vars(args))
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+ if args.concepts_list is None:
+ args.concepts_list = [
+ {
+ "instance_prompt": args.instance_prompt,
+ "class_prompt": args.class_prompt,
+ "instance_data_dir": args.instance_data_dir,
+ "class_data_dir": args.class_data_dir,
+ }
+ ]
+ else:
+ with open(args.concepts_list, "r") as f:
+ args.concepts_list = json.load(f)
+
+ # Generate class images if prior preservation is enabled.
+ if args.with_prior_preservation:
+ for i, concept in enumerate(args.concepts_list):
+ class_images_dir = Path(concept["class_data_dir"])
+ if not class_images_dir.exists():
+ class_images_dir.mkdir(parents=True, exist_ok=True)
+ if args.real_prior:
+ assert (
+ class_images_dir / "images"
+ ).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}"
+ assert (
+ len(list((class_images_dir / "images").iterdir())) == args.num_class_images
+ ), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}"
+ assert (
+ class_images_dir / "caption.txt"
+ ).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}"
+ assert (
+ class_images_dir / "images.txt"
+ ).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}"
+ concept["class_prompt"] = os.path.join(class_images_dir, "caption.txt")
+ concept["class_data_dir"] = os.path.join(class_images_dir, "images.txt")
+ args.concepts_list[i] = concept
+ accelerator.wait_for_everyone()
+ else:
+ cur_class_images = len(list(class_images_dir.iterdir()))
+
+ if cur_class_images < args.num_class_images:
+ torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+ if args.prior_generation_precision == "fp32":
+ torch_dtype = torch.float32
+ elif args.prior_generation_precision == "fp16":
+ torch_dtype = torch.float16
+ elif args.prior_generation_precision == "bf16":
+ torch_dtype = torch.bfloat16
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ torch_dtype=torch_dtype,
+ safety_checker=None,
+ revision=args.revision,
+ )
+ pipeline.set_progress_bar_config(disable=True)
+
+ num_new_images = args.num_class_images - cur_class_images
+ logger.info(f"Number of class images to sample: {num_new_images}.")
+
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+ sample_dataloader = accelerator.prepare(sample_dataloader)
+ pipeline.to(accelerator.device)
+
+ for example in tqdm(
+ sample_dataloader,
+ desc="Generating class images",
+ disable=not accelerator.is_local_main_process,
+ ):
+ images = pipeline(example["prompt"]).images
+
+ for i, image in enumerate(images):
+ hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ image_filename = (
+ class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+ )
+ image.save(image_filename)
+
+ del pipeline
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizer
+ if args.tokenizer_name:
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.tokenizer_name,
+ revision=args.revision,
+ use_fast=False,
+ )
+ elif args.pretrained_model_name_or_path:
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ use_fast=False,
+ )
+
+ # import correct text encoder class
+ text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder = text_encoder_cls.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+ )
+
+ # Adding a modifier token which is optimized ####
+ # Code taken from https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py
+ modifier_token_id = []
+ initializer_token_id = []
+ if args.modifier_token is not None:
+ args.modifier_token = args.modifier_token.split("+")
+ args.initializer_token = args.initializer_token.split("+")
+ if len(args.modifier_token) > len(args.initializer_token):
+ raise ValueError("You must specify + separated initializer token for each modifier token.")
+ for modifier_token, initializer_token in zip(
+ args.modifier_token, args.initializer_token[: len(args.modifier_token)]
+ ):
+ # Add the placeholder token in tokenizer
+ num_added_tokens = tokenizer.add_tokens(modifier_token)
+ if num_added_tokens == 0:
+ raise ValueError(
+ f"The tokenizer already contains the token {modifier_token}. Please pass a different"
+ " `modifier_token` that is not already in the tokenizer."
+ )
+
+ # Convert the initializer_token, placeholder_token to ids
+ token_ids = tokenizer.encode([initializer_token], add_special_tokens=False)
+ print(token_ids)
+ # Check if initializer_token is a single token or a sequence of tokens
+ if len(token_ids) > 1:
+ raise ValueError("The initializer token must be a single token.")
+
+ initializer_token_id.append(token_ids[0])
+ modifier_token_id.append(tokenizer.convert_tokens_to_ids(modifier_token))
+
+ # Resize the token embeddings as we are adding new special tokens to the tokenizer
+ text_encoder.resize_token_embeddings(len(tokenizer))
+
+ # Initialise the newly added placeholder token with the embeddings of the initializer token
+ token_embeds = text_encoder.get_input_embeddings().weight.data
+ for x, y in zip(modifier_token_id, initializer_token_id):
+ token_embeds[x] = token_embeds[y]
+
+ # Freeze all parameters except for the token embeddings in text encoder
+ params_to_freeze = itertools.chain(
+ text_encoder.text_model.encoder.parameters(),
+ text_encoder.text_model.final_layer_norm.parameters(),
+ text_encoder.text_model.embeddings.position_embedding.parameters(),
+ )
+ freeze_params(params_to_freeze)
+ ########################################################
+ ########################################################
+
+ vae.requires_grad_(False)
+ if args.modifier_token is None:
+ text_encoder.requires_grad_(False)
+ unet.requires_grad_(False)
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ if accelerator.mixed_precision != "fp16" and args.modifier_token is not None:
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+ unet.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+
+ attention_class = (
+ CustomDiffusionAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else CustomDiffusionAttnProcessor
+ )
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warn(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ attention_class = CustomDiffusionXFormersAttnProcessor
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # now we will add new Custom Diffusion weights to the attention layers
+ # It's important to realize here how many attention weights will be added and of which sizes
+ # The sizes of the attention layers consist only of two different variables:
+ # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
+ # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.
+
+ # Let's first see how many attention processors we will have to set.
+ # For Stable Diffusion, it should be equal to:
+ # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
+ # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
+    # - up blocks (2x attention layers) * (3x transformer layers) * (3x up blocks) = 18
+ # => 32 layers
+
+ # Only train key, value projection layers if freeze_model = 'crossattn_kv' else train all params in the cross attention layer
+ train_kv = True
+    train_q_out = args.freeze_model != "crossattn_kv"
+ custom_diffusion_attn_procs = {}
+
+ st = unet.state_dict()
+ for name, _ in unet.attn_processors.items():
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+ if name.startswith("mid_block"):
+ hidden_size = unet.config.block_out_channels[-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = unet.config.block_out_channels[block_id]
+ layer_name = name.split(".processor")[0]
+ weights = {
+ "to_k_custom_diffusion.weight": st[layer_name + ".to_k.weight"],
+ "to_v_custom_diffusion.weight": st[layer_name + ".to_v.weight"],
+ }
+ if train_q_out:
+ weights["to_q_custom_diffusion.weight"] = st[layer_name + ".to_q.weight"]
+ weights["to_out_custom_diffusion.0.weight"] = st[layer_name + ".to_out.0.weight"]
+ weights["to_out_custom_diffusion.0.bias"] = st[layer_name + ".to_out.0.bias"]
+ if cross_attention_dim is not None:
+ custom_diffusion_attn_procs[name] = attention_class(
+ train_kv=train_kv,
+ train_q_out=train_q_out,
+ hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim,
+ ).to(unet.device)
+ custom_diffusion_attn_procs[name].load_state_dict(weights)
+ else:
+ custom_diffusion_attn_procs[name] = attention_class(
+ train_kv=False,
+ train_q_out=False,
+ hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim,
+ )
+ del st
+ unet.set_attn_processor(custom_diffusion_attn_procs)
+ custom_diffusion_layers = AttnProcsLayers(unet.attn_processors)
+
+ accelerator.register_for_checkpointing(custom_diffusion_layers)
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+ if args.modifier_token is not None:
+ text_encoder.gradient_checkpointing_enable()
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+ if args.with_prior_preservation:
+ args.learning_rate = args.learning_rate * 2.0
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimizer creation
+ optimizer = optimizer_class(
+ itertools.chain(text_encoder.get_input_embeddings().parameters(), custom_diffusion_layers.parameters())
+ if args.modifier_token is not None
+ else custom_diffusion_layers.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Dataset and DataLoaders creation:
+ train_dataset = CustomDiffusionDataset(
+ concepts_list=args.concepts_list,
+ tokenizer=tokenizer,
+ with_prior_preservation=args.with_prior_preservation,
+ size=args.resolution,
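+        # Infer the latent (mask) resolution by encoding one dummy image through the VAE.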
+ mask_size=vae.encode(
+ torch.randn(1, 3, args.resolution, args.resolution).to(dtype=weight_dtype).to(accelerator.device)
+ )
+ .latent_dist.sample()
+ .size()[-1],
+ center_crop=args.center_crop,
+ num_class_images=args.num_class_images,
+ hflip=args.hflip,
+ aug=not args.noaug,
+ )
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ )
+
+ # Prepare everything with our `accelerator`.
+ if args.modifier_token is not None:
+ custom_diffusion_layers, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ custom_diffusion_layers, text_encoder, optimizer, train_dataloader, lr_scheduler
+ )
+ else:
+ custom_diffusion_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ custom_diffusion_layers, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ unet.train()
+ if args.modifier_token is not None:
+ text_encoder.train()
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet), accelerator.accumulate(text_encoder):
+ # Convert images to latent space
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
+ latents = latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+ # Predict the noise residual
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ if args.with_prior_preservation:
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+ target, target_prior = torch.chunk(target, 2, dim=0)
+ mask = torch.chunk(batch["mask"], 2, dim=0)[0]
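+                    # The mask marks the valid (non-padded) image region produced by the resize
+                    # augmentation, so padded pixels do not contribute to the instance loss.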
+ # Compute instance loss
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ loss = ((loss * mask).sum([1, 2, 3]) / mask.sum([1, 2, 3])).mean()
+
+ # Compute prior loss
+ prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+
+ # Add the prior loss to the instance loss.
+ loss = loss + args.prior_loss_weight * prior_loss
+ else:
+ mask = batch["mask"]
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ loss = ((loss * mask).sum([1, 2, 3]) / mask.sum([1, 2, 3])).mean()
+ accelerator.backward(loss)
+ # Zero out the gradients for all token embeddings except the newly added
+ # embeddings for the concept, as we only want to optimize the concept embeddings
+ if args.modifier_token is not None:
+ if accelerator.num_processes > 1:
+ grads_text_encoder = text_encoder.module.get_input_embeddings().weight.grad
+ else:
+ grads_text_encoder = text_encoder.get_input_embeddings().weight.grad
+ # Get the index for tokens that we want to zero the grads for
+ index_grads_to_zero = torch.arange(len(tokenizer)) != modifier_token_id[0]
+                    for i in range(1, len(modifier_token_id)):
+ index_grads_to_zero = index_grads_to_zero & (
+ torch.arange(len(tokenizer)) != modifier_token_id[i]
+ )
+ grads_text_encoder.data[index_grads_to_zero, :] = grads_text_encoder.data[
+ index_grads_to_zero, :
+ ].fill_(0)
+
+ if accelerator.sync_gradients:
+ params_to_clip = (
+ itertools.chain(text_encoder.parameters(), custom_diffusion_layers.parameters())
+ if args.modifier_token is not None
+ else custom_diffusion_layers.parameters()
+ )
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=args.set_grads_to_none)
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ images = []
+
+ if args.validation_prompt is not None and global_step % args.validation_steps == 0:
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ # create pipeline
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ unet=accelerator.unwrap_model(unet),
+ text_encoder=accelerator.unwrap_model(text_encoder),
+ tokenizer=tokenizer,
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+ images = [
+ pipeline(args.validation_prompt, num_inference_steps=25, generator=generator, eta=1.0).images[
+ 0
+ ]
+ for _ in range(args.num_validation_images)
+ ]
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ # Save the custom diffusion layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = unet.to(torch.float32)
+ unet.save_attn_procs(args.output_dir, safe_serialization=not args.no_safe_serialization)
+ save_new_embed(
+ text_encoder,
+ modifier_token_id,
+ accelerator,
+ args,
+ args.output_dir,
+ safe_serialization=not args.no_safe_serialization,
+ )
+
+ # Final inference
+ # Load previous pipeline
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
+ )
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+ pipeline = pipeline.to(accelerator.device)
+
+ # load attention processors
+ weight_name = (
+ "pytorch_custom_diffusion_weights.safetensors"
+ if not args.no_safe_serialization
+ else "pytorch_custom_diffusion_weights.bin"
+ )
+ pipeline.unet.load_attn_procs(args.output_dir, weight_name=weight_name)
+ for token in args.modifier_token:
+ token_weight_name = f"{token}.safetensors" if not args.no_safe_serialization else f"{token}.bin"
+ pipeline.load_textual_inversion(args.output_dir, weight_name=token_weight_name)
+
+ # run inference
+ if args.validation_prompt and args.num_validation_images > 0:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ images = [
+ pipeline(args.validation_prompt, num_inference_steps=25, generator=generator, eta=1.0).images[0]
+ for _ in range(args.num_validation_images)
+ ]
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "test": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ prompt=args.instance_prompt,
+ repo_folder=args.output_dir,
+ )
+ api = HfApi(token=args.hub_token)
+ api.upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/dreambooth/README.md b/diffusers/examples/dreambooth/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..0579e337939d36d2cceb637f7f3eeec6ffd8fefe
--- /dev/null
+++ b/diffusers/examples/dreambooth/README.md
@@ -0,0 +1,747 @@
+# DreamBooth training example
+
+[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text-to-image models like Stable Diffusion given just a few (3-5) images of a subject.
+The `train_dreambooth.py` script shows how to implement the training procedure and adapt it for Stable Diffusion.
+
+
+## Running locally with PyTorch
+
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+Then cd in the example folder and run
+```bash
+pip install -r requirements.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or for a default accelerate configuration without answering questions about your environment
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell (e.g., a notebook)
+
+```python
+from accelerate.utils import write_basic_config
+write_basic_config()
+```
+
+When running `accelerate config`, specifying torch compile mode as True can give dramatic speedups.
+
+### Dog toy example
+
+Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
+
+Let's first download it locally:
+
+```python
+from huggingface_hub import snapshot_download
+
+local_dir = "./dog"
+snapshot_download(
+ "diffusers/dog-example",
+ local_dir=local_dir, repo_type="dataset",
+ ignore_patterns=".gitattributes",
+)
+```
+
+And launch the training using:
+
+**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export INSTANCE_DIR="dog"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_dreambooth.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --instance_prompt="a photo of sks dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=1 \
+ --learning_rate=5e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --max_train_steps=400 \
+ --push_to_hub
+```
+
+### Training with prior-preservation loss
+
+Prior-preservation is used to avoid overfitting and language drift. Refer to the paper to learn more about it. For prior-preservation we first generate images using the model with a class prompt and then use those during training along with our data; a short sketch of how the two losses are combined is shown after the launch command below.
+According to the paper, it's recommended to generate `num_epochs * num_samples` images for prior-preservation. 200-300 works well for most cases. The `num_class_images` flag sets the number of images to generate with the class prompt. You can place existing images in `class_data_dir`, and the training script will generate any additional images so that `num_class_images` are present in `class_data_dir` during training time.
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export INSTANCE_DIR="dog"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_dreambooth.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --class_data_dir=$CLASS_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --with_prior_preservation --prior_loss_weight=1.0 \
+ --instance_prompt="a photo of sks dog" \
+ --class_prompt="a photo of dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=1 \
+ --learning_rate=5e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --num_class_images=200 \
+ --max_train_steps=800 \
+ --push_to_hub
+```
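+
+For reference, here is a minimal, self-contained sketch of how the instance loss and the prior-preservation loss are combined (dummy tensors and illustrative names, not the script's exact variables):
+
+```python
+import torch
+import torch.nn.functional as F
+
+prior_loss_weight = 1.0  # corresponds to --prior_loss_weight
+
+# Dummy noise predictions/targets: the first half of the batch holds instance examples and the
+# second half holds class (prior) examples, since both are concatenated into one batch at collate time.
+model_pred = torch.randn(4, 4, 64, 64)
+target = torch.randn(4, 4, 64, 64)
+
+model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+target, target_prior = torch.chunk(target, 2, dim=0)
+
+instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+loss = instance_loss + prior_loss_weight * prior_loss
+```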
+
+
+### Training on a 16GB GPU:
+
+With the help of gradient checkpointing and the 8-bit optimizer from bitsandbytes, it's possible to train DreamBooth on a 16GB GPU.
+
+To install `bitsandbytes` please refer to this [readme](https://github.com/TimDettmers/bitsandbytes#requirements--installation).
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export INSTANCE_DIR="dog"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_dreambooth.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --class_data_dir=$CLASS_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --with_prior_preservation --prior_loss_weight=1.0 \
+ --instance_prompt="a photo of sks dog" \
+ --class_prompt="a photo of dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=2 --gradient_checkpointing \
+ --use_8bit_adam \
+ --learning_rate=5e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --num_class_images=200 \
+ --max_train_steps=800 \
+ --push_to_hub
+```
+
+
+### Training on a 12GB GPU:
+
+It is possible to run DreamBooth on a 12GB GPU by using the following optimizations:
+- [gradient checkpointing and the 8-bit optimizer](#training-on-a-16gb-gpu)
+- [xformers](#training-with-xformers)
+- [setting grads to none](#set-grads-to-none)
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export INSTANCE_DIR="dog"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_dreambooth.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --class_data_dir=$CLASS_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --with_prior_preservation --prior_loss_weight=1.0 \
+ --instance_prompt="a photo of sks dog" \
+ --class_prompt="a photo of dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=1 --gradient_checkpointing \
+ --use_8bit_adam \
+ --enable_xformers_memory_efficient_attention \
+ --set_grads_to_none \
+ --learning_rate=2e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --num_class_images=200 \
+ --max_train_steps=800 \
+ --push_to_hub
+```
+
+
+### Training on an 8 GB GPU:
+
+By using [DeepSpeed](https://www.deepspeed.ai/) it's possible to offload some
+tensors from VRAM to either CPU or NVME allowing to train with less VRAM.
+
+DeepSpeed needs to be enabled with `accelerate config`. During configuration,
+answer yes to "Do you want to use DeepSpeed?". With DeepSpeed stage 2, fp16
+mixed precision, and offloading both parameters and optimizer state to CPU, it's
+possible to train on under 8 GB of VRAM, with the drawback of requiring significantly
+more RAM (about 25 GB). See the [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more DeepSpeed configuration options.
+
+Changing the default Adam optimizer to DeepSpeed's special version of Adam,
+`deepspeed.ops.adam.DeepSpeedCPUAdam`, gives a substantial speedup, but enabling
+it requires a CUDA toolchain with the same version as PyTorch. The 8-bit optimizer
+does not seem to be compatible with DeepSpeed at the moment.
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export INSTANCE_DIR="dog"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch --mixed_precision="fp16" train_dreambooth.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --class_data_dir=$CLASS_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --with_prior_preservation --prior_loss_weight=1.0 \
+ --instance_prompt="a photo of sks dog" \
+ --class_prompt="a photo of dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --sample_batch_size=1 \
+ --gradient_accumulation_steps=1 --gradient_checkpointing \
+ --learning_rate=5e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --num_class_images=200 \
+ --max_train_steps=800 \
+ --push_to_hub
+```
+
+### Fine-tune text encoder with the UNet.
+
+The script also allows you to fine-tune the `text_encoder` along with the `unet`. It has been observed experimentally that fine-tuning the `text_encoder` gives much better results, especially on faces.
+Pass the `--train_text_encoder` argument to the script to enable training the `text_encoder`.
+
+___Note: Training the text encoder requires more memory; with this option the training won't fit on a 16GB GPU. It needs at least 24GB VRAM.___
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export INSTANCE_DIR="dog"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_dreambooth.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_text_encoder \
+ --instance_data_dir=$INSTANCE_DIR \
+ --class_data_dir=$CLASS_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --with_prior_preservation --prior_loss_weight=1.0 \
+ --instance_prompt="a photo of sks dog" \
+ --class_prompt="a photo of dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --use_8bit_adam \
+ --gradient_checkpointing \
+ --learning_rate=2e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --num_class_images=200 \
+ --max_train_steps=800 \
+ --push_to_hub
+```
+
+### Using DreamBooth for pipelines other than Stable Diffusion
+
+The [AltDiffusion pipeline](https://huggingface.co/docs/diffusers/api/pipelines/alt_diffusion) also supports dreambooth fine-tuning. The process is the same as above; all you need to do is replace the `MODEL_NAME` like this:
+
+```
+export MODEL_NAME="CompVis/stable-diffusion-v1-4" --> export MODEL_NAME="BAAI/AltDiffusion-m9"
+or
+export MODEL_NAME="CompVis/stable-diffusion-v1-4" --> export MODEL_NAME="BAAI/AltDiffusion"
+```
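+
+If you fine-tuned an AltDiffusion base model, inference then goes through `AltDiffusionPipeline` rather than `StableDiffusionPipeline`. A minimal, hedged sketch (the model path is a placeholder for your own `OUTPUT_DIR`):
+
+```python
+from diffusers import AltDiffusionPipeline
+import torch
+
+# Load the DreamBooth-finetuned AltDiffusion model (path is a placeholder).
+pipe = AltDiffusionPipeline.from_pretrained(
+    "path-to-your-trained-model", torch_dtype=torch.float16
+).to("cuda")
+image = pipe("A photo of sks dog in a bucket", num_inference_steps=50).images[0]
+image.save("dog-bucket.png")
+```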
+
+### Inference
+
+Once you have trained a model using the above command, you can run inference simply using the `StableDiffusionPipeline`. Make sure to include the identifier (e.g. `sks` in the above example) in your prompt.
+
+```python
+from diffusers import StableDiffusionPipeline
+import torch
+
+model_id = "path-to-your-trained-model"
+pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+
+prompt = "A photo of sks dog in a bucket"
+image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
+
+image.save("dog-bucket.png")
+```
+
+### Inference from a training checkpoint
+
+You can also perform inference from one of the checkpoints saved during the training process, if you used the `--checkpointing_steps` argument. Please refer to [the documentation](https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint) to see how to do it.
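+
+As a hedged sketch of what that looks like (assuming `accelerate>=0.16.0`, a checkpoint saved at step 400, and the same base model as above; the paths are placeholders and the linked documentation is the authoritative reference):
+
+```python
+from diffusers import DiffusionPipeline, UNet2DConditionModel
+import torch
+
+# Load the UNet weights saved inside the checkpoint folder (placeholder path).
+unet = UNet2DConditionModel.from_pretrained(
+    "path-to-save-model/checkpoint-400/unet", torch_dtype=torch.float16
+)
+# If you trained with --train_text_encoder, also load it from the checkpoint:
+# from transformers import CLIPTextModel
+# text_encoder = CLIPTextModel.from_pretrained("path-to-save-model/checkpoint-400/text_encoder")
+
+pipeline = DiffusionPipeline.from_pretrained(
+    "CompVis/stable-diffusion-v1-4", unet=unet, torch_dtype=torch.float16
+).to("cuda")
+image = pipeline("A photo of sks dog in a bucket", num_inference_steps=50).images[0]
+image.save("dog-bucket-checkpoint.png")
+```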
+
+## Training with Low-Rank Adaptation of Large Language Models (LoRA)
+
+Low-Rank Adaptation of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
+
+In a nutshell, LoRA allows you to adapt pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights (a toy sketch follows the list below). This has a couple of advantages:
+- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
+- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
+- LoRA attention layers allow you to control the extent to which the model is adapted towards new training images via a `scale` parameter.
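+
+As a toy illustration of the idea (this is *not* the diffusers implementation; the dimensions, rank, and `scale` value are made up), the pretrained weight stays frozen and only the small rank-decomposition factors are trained:
+
+```python
+import torch
+
+d_out, d_in, rank = 320, 768, 4                          # hypothetical layer sizes and LoRA rank
+W = torch.randn(d_out, d_in)                             # frozen pretrained weight
+A = torch.nn.Parameter(0.01 * torch.randn(rank, d_in))   # trainable down-projection
+B = torch.nn.Parameter(torch.zeros(d_out, rank))         # trainable up-projection, starts at zero
+scale = 1.0                                              # `scale` knob: 0 keeps the original model
+
+x = torch.randn(2, d_in)
+y = x @ (W + scale * (B @ A)).T                          # adapted forward pass; W itself never changes
+```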
+
+[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in
+the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.
+
+### Training
+
+Let's get started with a simple example. We will re-use the dog example of the [previous section](#dog-toy-example).
+
+First, you need to set up your dreambooth training example as explained in the [installation section](#Installing-the-dependencies).
+Next, let's download the dog dataset. Download images from [here](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ) and save them in a directory. Make sure to set `INSTANCE_DIR` to the name of your directory further below. This will be our training data.
+
+Now, you can launch the training. Here we will use [Stable Diffusion 1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5).
+
+**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
+
+**___Note: It is quite useful to monitor the training progress by regularly generating sample images during training. [wandb](https://docs.wandb.ai/quickstart) is a nice solution to easily see generated images during training. All you need to do is run `pip install wandb` before training and pass `--report_to="wandb"` to automatically log images.___**
+
+
+```bash
+export MODEL_NAME="runwayml/stable-diffusion-v1-5"
+export INSTANCE_DIR="dog"
+export OUTPUT_DIR="path-to-save-model"
+```
+
+For this example we want to directly store the trained LoRA embeddings on the Hub, so
+we need to be logged in and add the `--push_to_hub` flag.
+
+```bash
+huggingface-cli login
+```
+
+Now we can start training!
+
+```bash
+accelerate launch train_dreambooth_lora.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --instance_prompt="a photo of sks dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=1 \
+ --checkpointing_steps=100 \
+ --learning_rate=1e-4 \
+ --report_to="wandb" \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --max_train_steps=500 \
+ --validation_prompt="A photo of sks dog in a bucket" \
+ --validation_epochs=50 \
+ --seed="0" \
+ --push_to_hub
+```
+
+**___Note: When using LoRA we can use a much higher learning rate compared to vanilla dreambooth. Here we
+use *1e-4* instead of the usual *2e-6*.___**
+
+The final LoRA embedding weights have been uploaded to [patrickvonplaten/lora_dreambooth_dog_example](https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example). **___Note: [The final weights](https://huggingface.co/patrickvonplaten/lora/blob/main/pytorch_attn_procs.bin) are only 3 MB in size, which is orders of magnitude smaller than the original model.___**
+
+The training results are summarized [here](https://api.wandb.ai/report/patrickvonplaten/xm6cd5q5).
+You can use the `Step` slider to see how the model learned the features of our subject as it trained.
+
+Optionally, we can also train additional LoRA layers for the text encoder. Specify the `--train_text_encoder` argument above for that. If you're interested in learning more about how we
+enable this support, check out this [PR](https://github.com/huggingface/diffusers/pull/2918).
+
+With the default hyperparameters from the above, the training seems to go in a positive direction. Check out [this panel](https://wandb.ai/sayakpaul/dreambooth-lora/reports/test-23-04-17-17-00-13---Vmlldzo0MDkwNjMy). The trained LoRA layers are available [here](https://huggingface.co/sayakpaul/dreambooth).
+
+
+### Inference
+
+After training, LoRA weights can be loaded very easily into the original pipeline. First, you need to
+load the original pipeline:
+
+```python
+from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
+import torch
+
+pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
+pipe.to("cuda")
+```
+
+Next, we can load the adapter layers into the UNet with the [`load_attn_procs` function](https://huggingface.co/docs/diffusers/api/loaders#diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs).
+
+```python
+pipe.unet.load_attn_procs("patrickvonplaten/lora_dreambooth_dog_example")
+```
+
+Finally, we can run the model for inference.
+
+```python
+image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0]
+```
+
+If you are loading the LoRA parameters from the Hub and if the Hub repository has
+a `base_model` tag (such as [this](https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example/blob/main/README.md?code=true#L4)), then
+you can do:
+
+```py
+from huggingface_hub.repocard import RepoCard
+
+lora_model_id = "patrickvonplaten/lora_dreambooth_dog_example"
+card = RepoCard.load(lora_model_id)
+base_model_id = card.data.to_dict()["base_model"]
+
+pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
+...
+```
+
+If you used `--train_text_encoder` during training, then use `pipe.load_lora_weights()` to load the LoRA
+weights. For example:
+
+```python
+from huggingface_hub.repocard import RepoCard
+from diffusers import StableDiffusionPipeline
+import torch
+
+lora_model_id = "sayakpaul/dreambooth-text-encoder-test"
+card = RepoCard.load(lora_model_id)
+base_model_id = card.data.to_dict()["base_model"]
+
+pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
+pipe = pipe.to("cuda")
+pipe.load_lora_weights(lora_model_id)
+image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0]
+```
+
+Note that the use of [`LoraLoaderMixin.load_lora_weights`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights) is preferred to [`UNet2DConditionLoadersMixin.load_attn_procs`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs) for loading LoRA parameters. This is because
+`LoraLoaderMixin.load_lora_weights` can handle the following situations:
+
+* LoRA parameters that don't have separate identifiers for the UNet and the text encoder (such as [`"patrickvonplaten/lora_dreambooth_dog_example"`](https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example)). So, you can just do:
+
+ ```py
+ pipe.load_lora_weights(lora_model_path)
+ ```
+
+* LoRA parameters that have separate identifiers for the UNet and the text encoder such as: [`"sayakpaul/dreambooth"`](https://huggingface.co/sayakpaul/dreambooth).
+
+## Training with Flax/JAX
+
+For faster training on TPUs and GPUs you can leverage the flax training example. Follow the instructions above to get the model and dataset before running the script.
+
+___Note: The Flax example doesn't yet support features like gradient checkpointing and gradient accumulation, so to use Flax for faster training we will need cards with more than 30GB of memory.___
+
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+```bash
+pip install -U -r requirements_flax.txt
+```
+
+
+### Training without prior preservation loss
+
+```bash
+export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
+export INSTANCE_DIR="dog"
+export OUTPUT_DIR="path-to-save-model"
+
+python train_dreambooth_flax.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --instance_prompt="a photo of sks dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --learning_rate=5e-6 \
+ --max_train_steps=400
+```
+
+
+### Training with prior preservation loss
+
+```bash
+export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
+export INSTANCE_DIR="dog"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+python train_dreambooth_flax.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --class_data_dir=$CLASS_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --with_prior_preservation --prior_loss_weight=1.0 \
+ --instance_prompt="a photo of sks dog" \
+ --class_prompt="a photo of dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --learning_rate=5e-6 \
+ --num_class_images=200 \
+ --max_train_steps=800
+```
+
+
+### Fine-tune text encoder with the UNet.
+
+```bash
+export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
+export INSTANCE_DIR="dog"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+python train_dreambooth_flax.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_text_encoder \
+ --instance_data_dir=$INSTANCE_DIR \
+ --class_data_dir=$CLASS_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --with_prior_preservation --prior_loss_weight=1.0 \
+ --instance_prompt="a photo of sks dog" \
+ --class_prompt="a photo of dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --learning_rate=2e-6 \
+ --num_class_images=200 \
+ --max_train_steps=800
+```
+
+### Training with xformers:
+You can enable memory efficient attention by [installing xFormers](https://github.com/facebookresearch/xformers#installing-xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script. This is not available with the Flax/JAX implementation.
+
+You can also use Dreambooth to train the specialized in-painting model. See [the script in the research folder for details](https://github.com/huggingface/diffusers/tree/main/examples/research_projects/dreambooth_inpaint).
+
+### Set grads to none
+
+To save even more memory, pass the `--set_grads_to_none` argument to the script. This will set grads to None instead of zero. However, be aware that it changes certain behaviors, so if you start experiencing any problems, remove this argument.
+
+More info: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html
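+
+For reference, the flag corresponds to PyTorch's `set_to_none` option on `Optimizer.zero_grad`; a minimal sketch of the behaviour it toggles:
+
+```python
+import torch
+
+model = torch.nn.Linear(8, 8)
+optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
+
+loss = model(torch.randn(4, 8)).sum()
+loss.backward()
+optimizer.step()
+
+# Instead of filling gradient tensors with zeros, drop them entirely (set them
+# to None), which saves a little memory between optimizer steps.
+optimizer.zero_grad(set_to_none=True)
+```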
+
+### Experimental results
+You can refer to [this blog post](https://huggingface.co/blog/dreambooth) that discusses some of our DreamBooth experiments in detail. Specifically, it recommends a set of DreamBooth-specific tips and tricks that we have found to work well for a variety of subjects.
+
+## IF
+
+You can use the lora and full dreambooth scripts to train the text-to-image [IF model](https://huggingface.co/DeepFloyd/IF-I-XL-v1.0) and the stage II upscaler
+[IF model](https://huggingface.co/DeepFloyd/IF-II-L-v1.0).
+
+Note that IF has a predicted variance, and our finetuning scripts only train the model's predicted error, so for finetuned IF models we switch to a fixed
+variance schedule. The full finetuning scripts will update the scheduler config for the full saved model. However, when loading saved LoRA weights, you
+must also update the pipeline's scheduler config.
+
+```py
+from diffusers import DiffusionPipeline
+
+pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0")
+
+pipe.load_lora_weights("")  # fill in the local path or Hub repo id of your trained IF LoRA weights
+
+# Update scheduler config to fixed variance schedule
+pipe.scheduler = pipe.scheduler.__class__.from_config(pipe.scheduler.config, variance_type="fixed_small")
+```
+
+Additionally, a few alternative cli flags are needed for IF.
+
+`--resolution=64`: IF is a pixel space diffusion model. In order to operate on uncompressed pixels, the input images are of a much smaller resolution.
+
+`--pre_compute_text_embeddings`: IF uses [T5](https://huggingface.co/docs/transformers/model_doc/t5) for its text encoder. In order to save GPU memory, we pre-compute all text embeddings and then de-allocate T5.
+
+`--tokenizer_max_length=77`: T5 has a longer default text length, but the default IF encoding procedure uses a smaller number.
+
+`--text_encoder_use_attention_mask`: T5 requires the attention mask to be passed to the text encoder.
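+
+To make the last three flags concrete, here is a rough, hedged sketch of what pre-computing the T5 embeddings amounts to (a conceptual illustration, not the exact code path of the training script; the IF repository is gated, so you need to accept its license and be logged in):
+
+```python
+from transformers import AutoTokenizer, T5EncoderModel
+import torch
+
+model_id = "DeepFloyd/IF-I-XL-v1.0"
+tokenizer = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer")
+text_encoder = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder")
+
+# --tokenizer_max_length=77: pad/truncate to 77 tokens instead of T5's longer default.
+inputs = tokenizer(
+    "a sks dog", padding="max_length", max_length=77, truncation=True, return_tensors="pt"
+)
+with torch.no_grad():
+    # --text_encoder_use_attention_mask: the attention mask is passed to the encoder.
+    prompt_embeds = text_encoder(inputs.input_ids, attention_mask=inputs.attention_mask)[0]
+
+del text_encoder  # --pre_compute_text_embeddings: T5 can now be de-allocated to free memory
+```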
+
+### Tips and Tricks
+We find LoRA to be sufficient for finetuning the stage I model as the low resolution of the model makes representing fine-grained detail hard regardless.
+
+For common and/or not visually complex object concepts, you can get away with not fine-tuning the upscaler. Just be sure to adjust the prompt passed to the
+upscaler to remove the new token from the instance prompt. For example, if your stage I prompt is "a sks dog", use "a dog" for your stage II prompt.
+
+For fine-grained detail like faces that aren't present in the original training set, we find that full finetuning of the stage II upscaler is better than
+LoRA finetuning of stage II.
+
+For fine-grained detail like faces, we find that lower learning rates along with larger batch sizes work best.
+
+For stage II, we find that lower learning rates are also needed.
+
+We found experimentally that the DDPM scheduler with its default, larger number of denoising steps sometimes works better than the DPM Solver scheduler
+used in the training scripts.
+
+### Stage II additional validation images
+
+Stage II validation requires images to upscale, so we can download a downsized version of the training set:
+
+```py
+from huggingface_hub import snapshot_download
+
+local_dir = "./dog_downsized"
+snapshot_download(
+ "diffusers/dog-example-downsized",
+ local_dir=local_dir,
+ repo_type="dataset",
+ ignore_patterns=".gitattributes",
+)
+```
+
+### IF stage I LoRA Dreambooth
+This training configuration requires ~28 GB VRAM.
+
+```sh
+export MODEL_NAME="DeepFloyd/IF-I-XL-v1.0"
+export INSTANCE_DIR="dog"
+export OUTPUT_DIR="dreambooth_dog_lora"
+
+accelerate launch train_dreambooth_lora.py \
+ --report_to wandb \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --instance_prompt="a sks dog" \
+ --resolution=64 \
+ --train_batch_size=4 \
+ --gradient_accumulation_steps=1 \
+ --learning_rate=5e-6 \
+ --scale_lr \
+ --max_train_steps=1200 \
+ --validation_prompt="a sks dog" \
+ --validation_epochs=25 \
+ --checkpointing_steps=100 \
+ --pre_compute_text_embeddings \
+ --tokenizer_max_length=77 \
+ --text_encoder_use_attention_mask
+```
+
+### IF stage II LoRA Dreambooth
+
+`--validation_images`: These images are upscaled during validation steps.
+
+`--class_labels_conditioning=timesteps`: Pass additional conditioning to the UNet needed for stage II.
+
+`--learning_rate=1e-6`: Lower learning rate than stage I.
+
+`--resolution=256`: The upscaler expects higher resolution inputs.
+
+```sh
+export MODEL_NAME="DeepFloyd/IF-II-L-v1.0"
+export INSTANCE_DIR="dog"
+export OUTPUT_DIR="dreambooth_dog_upscale"
+export VALIDATION_IMAGES="dog_downsized/image_1.png dog_downsized/image_2.png dog_downsized/image_3.png dog_downsized/image_4.png"
+
+python train_dreambooth_lora.py \
+ --report_to wandb \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --instance_prompt="a sks dog" \
+ --resolution=256 \
+ --train_batch_size=4 \
+ --gradient_accumulation_steps=1 \
+ --learning_rate=1e-6 \
+ --max_train_steps=2000 \
+ --validation_prompt="a sks dog" \
+ --validation_epochs=100 \
+ --checkpointing_steps=500 \
+ --pre_compute_text_embeddings \
+ --tokenizer_max_length=77 \
+ --text_encoder_use_attention_mask \
+ --validation_images $VALIDATION_IMAGES \
+ --class_labels_conditioning=timesteps
+```
+
+### IF Stage I Full Dreambooth
+`--skip_save_text_encoder`: When training the full model, this will skip saving the entire T5 with the finetuned model. You can still load the pipeline
+with a T5 loaded from the original model.
+
+`--use_8bit_adam`: Due to the size of the optimizer states, we recommend training the full XL IF model with 8-bit Adam.
+
+`--learning_rate=1e-7`: For full dreambooth, IF requires very low learning rates. With higher learning rates model quality will degrade. Note that it is
+likely the learning rate can be increased with larger batch sizes.
+
+Using 8-bit Adam and a batch size of 4, the model can be trained in ~48 GB VRAM.
+
+`--validation_scheduler`: Set a particular scheduler via a string. We found that it is better to use the DDPMScheduler for validation when training DeepFloyd IF.
+
+```sh
+export MODEL_NAME="DeepFloyd/IF-I-XL-v1.0"
+
+export INSTANCE_DIR="dog"
+export OUTPUT_DIR="dreambooth_if"
+
+accelerate launch train_dreambooth.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --instance_prompt="a photo of sks dog" \
+ --resolution=64 \
+ --train_batch_size=4 \
+ --gradient_accumulation_steps=1 \
+ --learning_rate=1e-7 \
+ --max_train_steps=150 \
+ --validation_prompt "a photo of sks dog" \
+ --validation_steps 25 \
+ --text_encoder_use_attention_mask \
+ --tokenizer_max_length 77 \
+ --pre_compute_text_embeddings \
+ --use_8bit_adam \
+ --set_grads_to_none \
+ --skip_save_text_encoder \
+ --validation_scheduler DDPMScheduler \
+ --push_to_hub
+```
+
+### IF Stage II Full Dreambooth
+
+`--learning_rate=5e-6`: With a smaller effective batch size of 4, we found that learning rates as low as 1e-8 were required.
+
+`--resolution=256`: The upscaler expects higher resolution inputs.
+
+`--train_batch_size=2` and `--gradient_accumulation_steps=6`: We found that full training of stage II particularly with
+faces required large effective batch sizes.
+
+```sh
+export MODEL_NAME="DeepFloyd/IF-II-L-v1.0"
+export INSTANCE_DIR="dog"
+export OUTPUT_DIR="dreambooth_dog_upscale"
+export VALIDATION_IMAGES="dog_downsized/image_1.png dog_downsized/image_2.png dog_downsized/image_3.png dog_downsized/image_4.png"
+
+accelerate launch train_dreambooth.py \
+ --report_to wandb \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --instance_prompt="a sks dog" \
+ --resolution=256 \
+ --train_batch_size=2 \
+ --gradient_accumulation_steps=6 \
+ --learning_rate=5e-6 \
+ --max_train_steps=2000 \
+ --validation_prompt="a sks dog" \
+ --validation_steps=150 \
+ --checkpointing_steps=500 \
+ --pre_compute_text_embeddings \
+ --tokenizer_max_length=77 \
+ --text_encoder_use_attention_mask \
+ --validation_images $VALIDATION_IMAGES \
+ --class_labels_conditioning timesteps \
+ --validation_scheduler DDPMScheduler \
+ --push_to_hub
+```
+
+## Stable Diffusion XL
+
+We support fine-tuning of the UNet shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) with DreamBooth and LoRA via the `train_dreambooth_lora_sdxl.py` script. Please refer to the docs [here](./README_sdxl.md).
diff --git a/diffusers/examples/dreambooth/README_sdxl.md b/diffusers/examples/dreambooth/README_sdxl.md
new file mode 100644
index 0000000000000000000000000000000000000000..d78d1ef5d2dd27b30317bdda6ba50120ab4e934c
--- /dev/null
+++ b/diffusers/examples/dreambooth/README_sdxl.md
@@ -0,0 +1,207 @@
+# DreamBooth training example for Stable Diffusion XL (SDXL)
+
+[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few (3~5) images of a subject.
+
+The `train_dreambooth_lora_sdxl.py` script shows how to implement the training procedure and adapt it for [Stable Diffusion XL](https://huggingface.co/papers/2307.01952).
+
+> 💡 **Note**: For now, we only allow DreamBooth fine-tuning of the SDXL UNet via LoRA. LoRA is a parameter-efficient fine-tuning technique introduced in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
+
+## Running locally with PyTorch
+
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+Then cd into the `examples/dreambooth` folder and run
+```bash
+pip install -r requirements_sdxl.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or for a default accelerate configuration without answering questions about your environment
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell (e.g., a notebook)
+
+```python
+from accelerate.utils import write_basic_config
+write_basic_config()
+```
+
+When running `accelerate config`, specifying torch compile mode as True can yield dramatic speedups.
+
+### Dog toy example
+
+Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
+
+Let's first download it locally:
+
+```python
+from huggingface_hub import snapshot_download
+
+local_dir = "./dog"
+snapshot_download(
+ "diffusers/dog-example",
+ local_dir=local_dir, repo_type="dataset",
+ ignore_patterns=".gitattributes",
+)
+```
+
+This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
+
+Now, we can launch training using:
+
+```bash
+export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
+export INSTANCE_DIR="dog"
+export OUTPUT_DIR="lora-trained-xl"
+export VAE_PATH="madebyollin/sdxl-vae-fp16-fix"
+
+accelerate launch train_dreambooth_lora_sdxl.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --pretrained_vae_model_name_or_path=$VAE_PATH \
+ --output_dir=$OUTPUT_DIR \
+ --mixed_precision="fp16" \
+ --instance_prompt="a photo of sks dog" \
+ --resolution=1024 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --learning_rate=1e-5 \
+ --report_to="wandb" \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --max_train_steps=500 \
+ --validation_prompt="A photo of sks dog in a bucket" \
+ --validation_epochs=25 \
+ --seed="0" \
+ --push_to_hub
+```
+
+To better track our training experiments, we're using the following flags in the command above:
+
+* `report_to="wandb"` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
+* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
+
+Our experiments were conducted on a single 40GB A100 GPU.
+
+### Dog toy example with < 16GB VRAM
+
+By making use of [`gradient_checkpointing`](https://pytorch.org/docs/stable/checkpoint.html) (which is natively supported in Diffusers), [`xformers`](https://github.com/facebookresearch/xformers), and [`bitsandbytes`](https://github.com/TimDettmers/bitsandbytes) libraries, you can train SDXL LoRAs with less than 16GB of VRAM by adding the following flags to your accelerate launch command:
+
+```diff
++ --enable_xformers_memory_efficient_attention \
++ --gradient_checkpointing \
++ --use_8bit_adam \
++ --mixed_precision="fp16" \
+```
+
+and making sure that you have the following libraries installed:
+
+```
+bitsandbytes>=0.40.0
+xformers>=0.0.20
+```
+
+### Inference
+
+Once training is done, we can perform inference like so:
+
+```python
+from huggingface_hub.repocard import RepoCard
+from diffusers import DiffusionPipeline
+import torch
+
+lora_model_id = <"lora-sdxl-dreambooth-id">
+card = RepoCard.load(lora_model_id)
+base_model_id = card.data.to_dict()["base_model"]
+
+pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
+pipe = pipe.to("cuda")
+pipe.load_lora_weights(lora_model_id)
+image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0]
+image.save("sks_dog.png")
+```
+
+We can further refine the outputs with the [Refiner](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0):
+
+```python
+from huggingface_hub.repocard import RepoCard
+from diffusers import DiffusionPipeline, StableDiffusionXLImg2ImgPipeline
+import torch
+
+lora_model_id = <"lora-sdxl-dreambooth-id">
+card = RepoCard.load(lora_model_id)
+base_model_id = card.data.to_dict()["base_model"]
+
+# Load the base pipeline and load the LoRA parameters into it.
+pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
+pipe = pipe.to("cuda")
+pipe.load_lora_weights(lora_model_id)
+
+# Load the refiner.
+refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
+)
+refiner.to("cuda")
+
+prompt = "A picture of a sks dog in a bucket"
+generator = torch.Generator("cuda").manual_seed(0)
+
+# Run inference.
+image = pipe(prompt=prompt, output_type="latent", generator=generator).images[0]
+image = refiner(prompt=prompt, image=image[None, :], generator=generator).images[0]
+image.save("refined_sks_dog.png")
+```
+
+Here's a side-by-side comparison of the pipeline outputs with and without the Refiner:
+
+| Without Refiner | With Refiner |
+|---|---|
+| ![](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/sd_xl/sks_dog.png) | ![](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/sd_xl/refined_sks_dog.png) |
+
+### Training with text encoder(s)
+
+Alongside the UNet, LoRA fine-tuning of the text encoders is also supported. To do so, just specify `--train_text_encoder` while launching training. Please keep the following points in mind (a short sketch follows the list):
+
+* SDXL has two text encoders. So, we fine-tune both using LoRA.
+* When not fine-tuning the text encoders, we ALWAYS precompute the text embeddings to save memory.
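+
+As a small sketch of what the flag touches (just to show the two encoders; this is not part of the training script):
+
+```python
+from diffusers import StableDiffusionXLPipeline
+import torch
+
+pipe = StableDiffusionXLPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+)
+# SDXL carries two text encoders; --train_text_encoder attaches LoRA layers to both.
+print(type(pipe.text_encoder).__name__)    # CLIPTextModel
+print(type(pipe.text_encoder_2).__name__)  # CLIPTextModelWithProjection
+```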
+
+### Specifying a better VAE
+
+SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument, `--pretrained_vae_model_name_or_path`, that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
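+
+The same VAE can also be swapped in at inference time; a minimal sketch:
+
+```python
+from diffusers import AutoencoderKL, StableDiffusionXLPipeline
+import torch
+
+# Load the numerically more stable fp16 VAE and plug it into the SDXL pipeline.
+vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
+pipe = StableDiffusionXLPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0", vae=vae, torch_dtype=torch.float16
+).to("cuda")
+```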
+
+## Notes
+
+In our experiments, we found that SDXL yields good initial results without extensive hyperparameter tuning. For example, without fine-tuning the text encoders and without using prior-preservation, we observed decent results. We didn't explore further hyper-parameter tuning experiments, but we do encourage the community to explore this avenue further and share their results with us 🤗
+
+## Results
+
+You can explore the results from a couple of our internal experiments by checking out this link: [https://wandb.ai/sayakpaul/dreambooth-lora-sd-xl](https://wandb.ai/sayakpaul/dreambooth-lora-sd-xl). Specifically, we used the same script with the exact same hyperparameters on the following datasets:
+
+* [Dogs](https://huggingface.co/datasets/diffusers/dog-example)
+* [Starbucks logo](https://huggingface.co/datasets/diffusers/starbucks-example)
+* [Mr. Potato Head](https://huggingface.co/datasets/diffusers/potato-head-example)
+* [Keramer face](https://huggingface.co/datasets/diffusers/keramer-face-example)
+
+## Running on a free-tier Colab Notebook
+
+Check out [this notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_.ipynb).
diff --git a/diffusers/examples/dreambooth/requirements.txt b/diffusers/examples/dreambooth/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7a612982f4abbaa64f83db52e411a1235a372259
--- /dev/null
+++ b/diffusers/examples/dreambooth/requirements.txt
@@ -0,0 +1,6 @@
+accelerate>=0.16.0
+torchvision
+transformers>=4.25.1
+ftfy
+tensorboard
+Jinja2
diff --git a/diffusers/examples/dreambooth/requirements_flax.txt b/diffusers/examples/dreambooth/requirements_flax.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8f85ad523a3b46b65abf0138c05ecdd656e6845c
--- /dev/null
+++ b/diffusers/examples/dreambooth/requirements_flax.txt
@@ -0,0 +1,8 @@
+transformers>=4.25.1
+flax
+optax
+torch
+torchvision
+ftfy
+tensorboard
+Jinja2
diff --git a/diffusers/examples/dreambooth/requirements_sdxl.txt b/diffusers/examples/dreambooth/requirements_sdxl.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7a612982f4abbaa64f83db52e411a1235a372259
--- /dev/null
+++ b/diffusers/examples/dreambooth/requirements_sdxl.txt
@@ -0,0 +1,6 @@
+accelerate>=0.16.0
+torchvision
+transformers>=4.25.1
+ftfy
+tensorboard
+Jinja2
diff --git a/diffusers/examples/dreambooth/train_dreambooth.py b/diffusers/examples/dreambooth/train_dreambooth.py
new file mode 100644
index 0000000000000000000000000000000000000000..92b57b7286736b4318a4a6261b7c4ad5a306a738
--- /dev/null
+++ b/diffusers/examples/dreambooth/train_dreambooth.py
@@ -0,0 +1,1422 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import copy
+import gc
+import importlib
+import itertools
+import logging
+import math
+import os
+import shutil
+import warnings
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, model_info, upload_folder
+from huggingface_hub.utils import insecure_hashlib
+from packaging import version
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ DiffusionPipeline,
+ StableDiffusionPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import compute_snr
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.24.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def save_model_card(
+ repo_id: str,
+ images=None,
+ base_model=str,
+ train_text_encoder=False,
+ prompt=str,
+ repo_folder=None,
+ pipeline: DiffusionPipeline = None,
+):
+ img_str = ""
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ img_str += f"![img_{i}](./image_{i}.png)\n"
+
+ yaml = f"""
+---
+license: creativeml-openrail-m
+base_model: {base_model}
+instance_prompt: {prompt}
+tags:
+- {'stable-diffusion' if isinstance(pipeline, StableDiffusionPipeline) else 'if'}
+- {'stable-diffusion-diffusers' if isinstance(pipeline, StableDiffusionPipeline) else 'if-diffusers'}
+- text-to-image
+- diffusers
+- dreambooth
+inference: true
+---
+ """
+ model_card = f"""
+# DreamBooth - {repo_id}
+
+This is a dreambooth model derived from {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/).
+You can find some example images in the following. \n
+{img_str}
+
+DreamBooth for the text encoder was enabled: {train_text_encoder}.
+"""
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
+ f.write(yaml + model_card)
+
+
+def log_validation(
+ text_encoder,
+ tokenizer,
+ unet,
+ vae,
+ args,
+ accelerator,
+ weight_dtype,
+ global_step,
+ prompt_embeds,
+ negative_prompt_embeds,
+):
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+
+ pipeline_args = {}
+
+ if vae is not None:
+ pipeline_args["vae"] = vae
+
+ if text_encoder is not None:
+ text_encoder = accelerator.unwrap_model(text_encoder)
+
+ # create pipeline (note: unet and vae are loaded again in float32)
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ unet=accelerator.unwrap_model(unet),
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ **pipeline_args,
+ )
+
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
+ scheduler_args = {}
+
+ if "variance_type" in pipeline.scheduler.config:
+ variance_type = pipeline.scheduler.config.variance_type
+
+ if variance_type in ["learned", "learned_range"]:
+ variance_type = "fixed_small"
+
+ scheduler_args["variance_type"] = variance_type
+
+ module = importlib.import_module("diffusers")
+ scheduler_class = getattr(module, args.validation_scheduler)
+ pipeline.scheduler = scheduler_class.from_config(pipeline.scheduler.config, **scheduler_args)
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.pre_compute_text_embeddings:
+ pipeline_args = {
+ "prompt_embeds": prompt_embeds,
+ "negative_prompt_embeds": negative_prompt_embeds,
+ }
+ else:
+ pipeline_args = {"prompt": args.validation_prompt}
+
+ # run inference
+ generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
+ images = []
+ if args.validation_images is None:
+ for _ in range(args.num_validation_images):
+ with torch.autocast("cuda"):
+ image = pipeline(**pipeline_args, num_inference_steps=25, generator=generator).images[0]
+ images.append(image)
+ else:
+ for image in args.validation_images:
+ image = Image.open(image)
+ image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
+ images.append(image)
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, global_step, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ return images
+
+
+def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ revision=revision,
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "RobertaSeriesModelWithTransformation":
+ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
+
+ return RobertaSeriesModelWithTransformation
+ elif model_class == "T5EncoderModel":
+ from transformers import T5EncoderModel
+
+ return T5EncoderModel
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help=(
+ "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
+ " float32 precision."
+ ),
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ required=True,
+ help="A folder containing the training data of instance images.",
+ )
+ parser.add_argument(
+ "--class_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of class images.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default=None,
+ required=True,
+ help="The prompt with identifier specifying the instance",
+ )
+ parser.add_argument(
+ "--class_prompt",
+ type=str,
+ default=None,
+ help="The prompt to specify images in the same class as provided instance images.",
+ )
+ parser.add_argument(
+ "--with_prior_preservation",
+ default=False,
+ action="store_true",
+ help="Flag to add prior preservation loss.",
+ )
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+ parser.add_argument(
+ "--num_class_images",
+ type=int,
+ default=100,
+ help=(
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
+ " class_data_dir, additional images will be sampled with class_prompt."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="text-inversion-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--train_text_encoder",
+ action="store_true",
+ help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
+ "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
+ "instructions."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=(
+ "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
+ " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
+ " for more details"
+ ),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-6,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=100,
+ help=(
+ "Run validation every X steps. Validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
+ " and logging the images."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--prior_generation_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp32", "fp16", "bf16"],
+ help=(
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--set_grads_to_none",
+ action="store_true",
+ help=(
+ "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
+ " behaviors, so disable this argument if it causes any problems. More info:"
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
+ ),
+ )
+
+ parser.add_argument(
+ "--offset_noise",
+ action="store_true",
+ default=False,
+ help=(
+ "Fine-tuning against a modified noise"
+ " See: https://www.crosslabs.org//blog/diffusion-with-offset-noise for more information."
+ ),
+ )
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument(
+ "--pre_compute_text_embeddings",
+ action="store_true",
+ help="Whether or not to pre-compute text embeddings. If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. This is not compatible with `--train_text_encoder`.",
+ )
+ parser.add_argument(
+ "--tokenizer_max_length",
+ type=int,
+ default=None,
+ required=False,
+ help="The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.",
+ )
+ parser.add_argument(
+ "--text_encoder_use_attention_mask",
+ action="store_true",
+ required=False,
+ help="Whether to use attention mask for the text encoder",
+ )
+ parser.add_argument(
+ "--skip_save_text_encoder", action="store_true", required=False, help="Set to not save text encoder"
+ )
+ parser.add_argument(
+ "--validation_images",
+ required=False,
+ default=None,
+ nargs="+",
+ help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.",
+ )
+ parser.add_argument(
+ "--class_labels_conditioning",
+ required=False,
+ default=None,
+ help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.",
+ )
+ parser.add_argument(
+ "--validation_scheduler",
+ type=str,
+ default="DPMSolverMultistepScheduler",
+ choices=["DPMSolverMultistepScheduler", "DDPMScheduler"],
+ help="Select which scheduler to use for validation. DDPMScheduler is recommended for DeepFloyd IF.",
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.with_prior_preservation:
+ if args.class_data_dir is None:
+ raise ValueError("You must specify a data directory for class images.")
+ if args.class_prompt is None:
+ raise ValueError("You must specify prompt for class images.")
+ else:
+ # logger is not available yet
+ if args.class_data_dir is not None:
+ warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
+ if args.class_prompt is not None:
+ warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
+
+ if args.train_text_encoder and args.pre_compute_text_embeddings:
+ raise ValueError("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`")
+
+ return args
+
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+ It pre-processes the images and the tokenizes prompts.
+ """
+
+ def __init__(
+ self,
+ instance_data_root,
+ instance_prompt,
+ tokenizer,
+ class_data_root=None,
+ class_prompt=None,
+ class_num=None,
+ size=512,
+ center_crop=False,
+ encoder_hidden_states=None,
+ class_prompt_encoder_hidden_states=None,
+ tokenizer_max_length=None,
+ ):
+ self.size = size
+ self.center_crop = center_crop
+ self.tokenizer = tokenizer
+ self.encoder_hidden_states = encoder_hidden_states
+ self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states
+ self.tokenizer_max_length = tokenizer_max_length
+
+ self.instance_data_root = Path(instance_data_root)
+ if not self.instance_data_root.exists():
+ raise ValueError(f"Instance {self.instance_data_root} images root doesn't exists.")
+
+ self.instance_images_path = list(Path(instance_data_root).iterdir())
+ self.num_instance_images = len(self.instance_images_path)
+ self.instance_prompt = instance_prompt
+ self._length = self.num_instance_images
+
+ if class_data_root is not None:
+ self.class_data_root = Path(class_data_root)
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
+ self.class_images_path = list(self.class_data_root.iterdir())
+ if class_num is not None:
+ self.num_class_images = min(len(self.class_images_path), class_num)
+ else:
+ self.num_class_images = len(self.class_images_path)
+ self._length = max(self.num_class_images, self.num_instance_images)
+ self.class_prompt = class_prompt
+ else:
+ self.class_data_root = None
+
+ self.image_transforms = transforms.Compose(
+ [
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
+ instance_image = exif_transpose(instance_image)
+
+ if not instance_image.mode == "RGB":
+ instance_image = instance_image.convert("RGB")
+ example["instance_images"] = self.image_transforms(instance_image)
+
+ if self.encoder_hidden_states is not None:
+ example["instance_prompt_ids"] = self.encoder_hidden_states
+ else:
+ text_inputs = tokenize_prompt(
+ self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length
+ )
+ example["instance_prompt_ids"] = text_inputs.input_ids
+ example["instance_attention_mask"] = text_inputs.attention_mask
+
+ if self.class_data_root:
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
+ class_image = exif_transpose(class_image)
+
+ if not class_image.mode == "RGB":
+ class_image = class_image.convert("RGB")
+ example["class_images"] = self.image_transforms(class_image)
+
+ if self.class_prompt_encoder_hidden_states is not None:
+ example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states
+ else:
+ class_text_inputs = tokenize_prompt(
+ self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length
+ )
+ example["class_prompt_ids"] = class_text_inputs.input_ids
+ example["class_attention_mask"] = class_text_inputs.attention_mask
+
+ return example
+
+
+def collate_fn(examples, with_prior_preservation=False):
+ has_attention_mask = "instance_attention_mask" in examples[0]
+
+ input_ids = [example["instance_prompt_ids"] for example in examples]
+ pixel_values = [example["instance_images"] for example in examples]
+
+ if has_attention_mask:
+ attention_mask = [example["instance_attention_mask"] for example in examples]
+
+ # Concat class and instance examples for prior preservation.
+ # We do this to avoid doing two forward passes.
+ if with_prior_preservation:
+ input_ids += [example["class_prompt_ids"] for example in examples]
+ pixel_values += [example["class_images"] for example in examples]
+
+ if has_attention_mask:
+ attention_mask += [example["class_attention_mask"] for example in examples]
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
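+ # Each prompt entry is a [1, seq_len] token tensor (or a pre-computed embedding), so concatenating along dim 0 forms the batch.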
+ input_ids = torch.cat(input_ids, dim=0)
+
+ batch = {
+ "input_ids": input_ids,
+ "pixel_values": pixel_values,
+ }
+
+ if has_attention_mask:
+ attention_mask = torch.cat(attention_mask, dim=0)
+ batch["attention_mask"] = attention_mask
+
+ return batch
+
+
+class PromptDataset(Dataset):
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
+
+ def __init__(self, prompt, num_samples):
+ self.prompt = prompt
+ self.num_samples = num_samples
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ example = {}
+ example["prompt"] = self.prompt
+ example["index"] = index
+ return example
+
+
+def model_has_vae(args):
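+ # Check for a vae/config.json either inside a local checkpoint directory or among the files of the Hub repo.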
+ config_file_name = os.path.join("vae", AutoencoderKL.config_name)
+ if os.path.isdir(args.pretrained_model_name_or_path):
+ config_file_name = os.path.join(args.pretrained_model_name_or_path, config_file_name)
+ return os.path.isfile(config_file_name)
+ else:
+ files_in_repo = model_info(args.pretrained_model_name_or_path, revision=args.revision).siblings
+ return any(file.rfilename == config_file_name for file in files_in_repo)
+
+
+def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
+ if tokenizer_max_length is not None:
+ max_length = tokenizer_max_length
+ else:
+ max_length = tokenizer.model_max_length
+
+ text_inputs = tokenizer(
+ prompt,
+ truncation=True,
+ padding="max_length",
+ max_length=max_length,
+ return_tensors="pt",
+ )
+
+ return text_inputs
+
+
+def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None):
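+ # Returns the text encoder's last hidden state; the attention mask is only passed through when explicitly enabled.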
+ text_input_ids = input_ids.to(text_encoder.device)
+
+ if text_encoder_use_attention_mask:
+ attention_mask = attention_mask.to(text_encoder.device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = text_encoder(
+ text_input_ids,
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ return prompt_embeds
+
+
+def main(args):
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
+ # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
+ # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
+ if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
+ raise ValueError(
+ "Gradient accumulation is not supported when training the text encoder in distributed training. "
+ "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Generate class images if prior preservation is enabled.
+ if args.with_prior_preservation:
+ class_images_dir = Path(args.class_data_dir)
+ if not class_images_dir.exists():
+ class_images_dir.mkdir(parents=True)
+ cur_class_images = len(list(class_images_dir.iterdir()))
+
+ if cur_class_images < args.num_class_images:
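+ # Generate only the missing class images with the base pipeline; fp16 is used on CUDA unless --prior_generation_precision overrides it.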
+ torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+ if args.prior_generation_precision == "fp32":
+ torch_dtype = torch.float32
+ elif args.prior_generation_precision == "fp16":
+ torch_dtype = torch.float16
+ elif args.prior_generation_precision == "bf16":
+ torch_dtype = torch.bfloat16
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ torch_dtype=torch_dtype,
+ safety_checker=None,
+ revision=args.revision,
+ )
+ pipeline.set_progress_bar_config(disable=True)
+
+ num_new_images = args.num_class_images - cur_class_images
+ logger.info(f"Number of class images to sample: {num_new_images}.")
+
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+ sample_dataloader = accelerator.prepare(sample_dataloader)
+ pipeline.to(accelerator.device)
+
+ for example in tqdm(
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
+ ):
+ images = pipeline(example["prompt"]).images
+
+ for i, image in enumerate(images):
+ hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+ image.save(image_filename)
+
+ del pipeline
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizer
+ if args.tokenizer_name:
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
+ elif args.pretrained_model_name_or_path:
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ use_fast=False,
+ )
+
+ # import correct text encoder class
+ text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder = text_encoder_cls.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+
+ if model_has_vae(args):
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
+ )
+ else:
+ vae = None
+
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+ )
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ for model in models:
+ sub_dir = "unet" if isinstance(model, type(accelerator.unwrap_model(unet))) else "text_encoder"
+ model.save_pretrained(os.path.join(output_dir, sub_dir))
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ while len(models) > 0:
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ if isinstance(model, type(accelerator.unwrap_model(text_encoder))):
+ # load transformers style into model
+ load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder")
+ model.config = load_model.config
+ else:
+ # load diffusers style into model
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ if vae is not None:
+ vae.requires_grad_(False)
+
+ if not args.train_text_encoder:
+ text_encoder.requires_grad_(False)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+ if args.train_text_encoder:
+ text_encoder.gradient_checkpointing_enable()
+
+ # Check that all trainable models are in full precision
+ low_precision_error_string = (
+ "Please make sure to always have all model weights in full float32 precision when starting training - even if"
+ " doing mixed precision training. copy of the weights should still be float32."
+ )
+
+ if accelerator.unwrap_model(unet).dtype != torch.float32:
+ raise ValueError(
+ f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
+ )
+
+ if args.train_text_encoder and accelerator.unwrap_model(text_encoder).dtype != torch.float32:
+ raise ValueError(
+ f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}."
+ f" {low_precision_error_string}"
+ )
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimizer creation
+ params_to_optimize = (
+ itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
+ )
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ if args.pre_compute_text_embeddings:
+
+ def compute_text_embeddings(prompt):
+ with torch.no_grad():
+ text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length)
+ prompt_embeds = encode_prompt(
+ text_encoder,
+ text_inputs.input_ids,
+ text_inputs.attention_mask,
+ text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
+ )
+
+ return prompt_embeds
+
+ pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt)
+ validation_prompt_negative_prompt_embeds = compute_text_embeddings("")
+
+ if args.validation_prompt is not None:
+ validation_prompt_encoder_hidden_states = compute_text_embeddings(args.validation_prompt)
+ else:
+ validation_prompt_encoder_hidden_states = None
+
+ if args.class_prompt is not None:
+ pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.class_prompt)
+ else:
+ pre_computed_class_prompt_encoder_hidden_states = None
+
+ text_encoder = None
+ tokenizer = None
+
+ gc.collect()
+ torch.cuda.empty_cache()
+ else:
+ pre_computed_encoder_hidden_states = None
+ validation_prompt_encoder_hidden_states = None
+ validation_prompt_negative_prompt_embeds = None
+ pre_computed_class_prompt_encoder_hidden_states = None
+
+ # Dataset and DataLoaders creation:
+ train_dataset = DreamBoothDataset(
+ instance_data_root=args.instance_data_dir,
+ instance_prompt=args.instance_prompt,
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
+ class_prompt=args.class_prompt,
+ class_num=args.num_class_images,
+ tokenizer=tokenizer,
+ size=args.resolution,
+ center_crop=args.center_crop,
+ encoder_hidden_states=pre_computed_encoder_hidden_states,
+ class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states,
+ tokenizer_max_length=args.tokenizer_max_length,
+ )
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ if args.train_text_encoder:
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+ )
+ else:
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move vae and text_encoder to device and cast to weight_dtype
+ if vae is not None:
+ vae.to(accelerator.device, dtype=weight_dtype)
+
+ if not args.train_text_encoder and text_encoder is not None:
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = vars(copy.deepcopy(args))
+ tracker_config.pop("validation_images")
+ accelerator.init_trackers("dreambooth", config=tracker_config)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ unet.train()
+ if args.train_text_encoder:
+ text_encoder.train()
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
+
+ if vae is not None:
+ # Convert images to latent space
+ model_input = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
+ model_input = model_input * vae.config.scaling_factor
+ else:
+ model_input = pixel_values
+
+ # Sample noise that we'll add to the model input
+ if args.offset_noise:
+ noise = torch.randn_like(model_input) + 0.1 * torch.randn(
+ model_input.shape[0], model_input.shape[1], 1, 1, device=model_input.device
+ )
+ else:
+ noise = torch.randn_like(model_input)
+ bsz, channels, height, width = model_input.shape
+ # Sample a random timestep for each image
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
+ )
+ timesteps = timesteps.long()
+
+ # Add noise to the model input according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ if args.pre_compute_text_embeddings:
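+ # When embeddings are pre-computed, the collate_fn packed the encoder hidden states into batch["input_ids"].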
+ encoder_hidden_states = batch["input_ids"]
+ else:
+ encoder_hidden_states = encode_prompt(
+ text_encoder,
+ batch["input_ids"],
+ batch["attention_mask"],
+ text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
+ )
+
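+ # Some UNet variants take twice as many input channels (an extra conditioning image is normally concatenated there); duplicate the noisy latents to fill the expected shape.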
+ if accelerator.unwrap_model(unet).config.in_channels == channels * 2:
+ noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)
+
+ if args.class_labels_conditioning == "timesteps":
+ class_labels = timesteps
+ else:
+ class_labels = None
+
+ # Predict the noise residual
+ model_pred = unet(
+ noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels
+ ).sample
+
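+ # A model that also predicts a variance outputs 6 channels here; keep only the first half (the actual prediction).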
+ if model_pred.shape[1] == 6:
+ model_pred, _ = torch.chunk(model_pred, 2, dim=1)
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(model_input, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ if args.with_prior_preservation:
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+ target, target_prior = torch.chunk(target, 2, dim=0)
+ # Compute prior loss
+ prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+
+ # Compute instance loss
+ if args.snr_gamma is None:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+ # This is discussed in Section 4.2 of the same paper.
+ snr = compute_snr(noise_scheduler, timesteps)
+ base_weight = (
+ torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
+ )
+
+ if noise_scheduler.config.prediction_type == "v_prediction":
+ # Velocity objective needs to be floored to an SNR weight of one.
+ mse_loss_weights = base_weight + 1
+ else:
+ # Epsilon and sample both use the same loss weights.
+ mse_loss_weights = base_weight
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+
+ if args.with_prior_preservation:
+ # Add the prior loss to the instance loss.
+ loss = loss + args.prior_loss_weight * prior_loss
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = (
+ itertools.chain(unet.parameters(), text_encoder.parameters())
+ if args.train_text_encoder
+ else unet.parameters()
+ )
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=args.set_grads_to_none)
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ images = []
+
+ if args.validation_prompt is not None and global_step % args.validation_steps == 0:
+ images = log_validation(
+ text_encoder,
+ tokenizer,
+ unet,
+ vae,
+ args,
+ accelerator,
+ weight_dtype,
+ global_step,
+ validation_prompt_encoder_hidden_states,
+ validation_prompt_negative_prompt_embeds,
+ )
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ pipeline_args = {}
+
+ if text_encoder is not None:
+ pipeline_args["text_encoder"] = accelerator.unwrap_model(text_encoder)
+
+ if args.skip_save_text_encoder:
+ pipeline_args["text_encoder"] = None
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ unet=accelerator.unwrap_model(unet),
+ revision=args.revision,
+ **pipeline_args,
+ )
+
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
+ scheduler_args = {}
+
+ if "variance_type" in pipeline.scheduler.config:
+ variance_type = pipeline.scheduler.config.variance_type
+
+ if variance_type in ["learned", "learned_range"]:
+ variance_type = "fixed_small"
+
+ scheduler_args["variance_type"] = variance_type
+
+ pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+
+ pipeline.save_pretrained(args.output_dir)
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ train_text_encoder=args.train_text_encoder,
+ prompt=args.instance_prompt,
+ repo_folder=args.output_dir,
+ pipeline=pipeline,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/dreambooth/train_dreambooth_flax.py b/diffusers/examples/dreambooth/train_dreambooth_flax.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e8c385133e22d99dd04764069ed302d0c15c470
--- /dev/null
+++ b/diffusers/examples/dreambooth/train_dreambooth_flax.py
@@ -0,0 +1,695 @@
+import argparse
+import logging
+import math
+import os
+from pathlib import Path
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import optax
+import torch
+import torch.utils.checkpoint
+import transformers
+from flax import jax_utils
+from flax.training import train_state
+from flax.training.common_utils import shard
+from huggingface_hub import create_repo, upload_folder
+from huggingface_hub.utils import insecure_hashlib
+from jax.experimental.compilation_cache import compilation_cache as cc
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
+
+from diffusers import (
+ FlaxAutoencoderKL,
+ FlaxDDPMScheduler,
+ FlaxPNDMScheduler,
+ FlaxStableDiffusionPipeline,
+ FlaxUNet2DConditionModel,
+)
+from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker
+from diffusers.utils import check_min_version
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.24.0.dev0")
+
+# Cache compiled models across invocations of this script.
+cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
+
+logger = logging.getLogger(__name__)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained vae or vae identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ required=True,
+ help="A folder containing the training data of instance images.",
+ )
+ parser.add_argument(
+ "--class_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of class images.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default=None,
+ help="The prompt with identifier specifying the instance",
+ )
+ parser.add_argument(
+ "--class_prompt",
+ type=str,
+ default=None,
+ help="The prompt to specify images in the same class as provided instance images.",
+ )
+ parser.add_argument(
+ "--with_prior_preservation",
+ default=False,
+ action="store_true",
+ help="Flag to add prior preservation loss.",
+ )
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+ parser.add_argument(
+ "--num_class_images",
+ type=int,
+ default=100,
+ help=(
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
+ " class_data_dir, additional images will be sampled with class_prompt."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="text-inversion-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--save_steps", type=int, default=None, help="Save a checkpoint every X steps.")
+ parser.add_argument("--seed", type=int, default=0, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-6,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default="no",
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose"
+ "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
+ "and an Nvidia Ampere GPU."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.instance_data_dir is None:
+ raise ValueError("You must specify a train data directory.")
+
+ if args.with_prior_preservation:
+ if args.class_data_dir is None:
+ raise ValueError("You must specify a data directory for class images.")
+ if args.class_prompt is None:
+ raise ValueError("You must specify prompt for class images.")
+
+ return args
+
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+ It pre-processes the images and tokenizes the prompts.
+ """
+
+ def __init__(
+ self,
+ instance_data_root,
+ instance_prompt,
+ tokenizer,
+ class_data_root=None,
+ class_prompt=None,
+ class_num=None,
+ size=512,
+ center_crop=False,
+ ):
+ self.size = size
+ self.center_crop = center_crop
+ self.tokenizer = tokenizer
+
+ self.instance_data_root = Path(instance_data_root)
+ if not self.instance_data_root.exists():
+ raise ValueError("Instance images root doesn't exists.")
+
+ self.instance_images_path = list(Path(instance_data_root).iterdir())
+ self.num_instance_images = len(self.instance_images_path)
+ self.instance_prompt = instance_prompt
+ self._length = self.num_instance_images
+
+ if class_data_root is not None:
+ self.class_data_root = Path(class_data_root)
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
+ self.class_images_path = list(self.class_data_root.iterdir())
+ if class_num is not None:
+ self.num_class_images = min(len(self.class_images_path), class_num)
+ else:
+ self.num_class_images = len(self.class_images_path)
+ self._length = max(self.num_class_images, self.num_instance_images)
+ self.class_prompt = class_prompt
+ else:
+ self.class_data_root = None
+
+ self.image_transforms = transforms.Compose(
+ [
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
+ if not instance_image.mode == "RGB":
+ instance_image = instance_image.convert("RGB")
+ example["instance_images"] = self.image_transforms(instance_image)
+ example["instance_prompt_ids"] = self.tokenizer(
+ self.instance_prompt,
+ padding="do_not_pad",
+ truncation=True,
+ max_length=self.tokenizer.model_max_length,
+ ).input_ids
+
+ if self.class_data_root:
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
+ if not class_image.mode == "RGB":
+ class_image = class_image.convert("RGB")
+ example["class_images"] = self.image_transforms(class_image)
+ example["class_prompt_ids"] = self.tokenizer(
+ self.class_prompt,
+ padding="do_not_pad",
+ truncation=True,
+ max_length=self.tokenizer.model_max_length,
+ ).input_ids
+
+ return example
+
+
+class PromptDataset(Dataset):
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
+
+ def __init__(self, prompt, num_samples):
+ self.prompt = prompt
+ self.num_samples = num_samples
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ example = {}
+ example["prompt"] = self.prompt
+ example["index"] = index
+ return example
+
+
+def get_params_to_save(params):
+ return jax.device_get(jax.tree_util.tree_map(lambda x: x[0], params))
+
+
+def main():
+ args = parse_args()
+
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ # Set up logging; we only want one process per machine to log things on the screen.
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
+ if jax.process_index() == 0:
+ transformers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ rng = jax.random.PRNGKey(args.seed)
+
+ if args.with_prior_preservation:
+ class_images_dir = Path(args.class_data_dir)
+ if not class_images_dir.exists():
+ class_images_dir.mkdir(parents=True)
+ cur_class_images = len(list(class_images_dir.iterdir()))
+
+ if cur_class_images < args.num_class_images:
+ pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path, safety_checker=None, revision=args.revision
+ )
+ pipeline.set_progress_bar_config(disable=True)
+
+ num_new_images = args.num_class_images - cur_class_images
+ logger.info(f"Number of class images to sample: {num_new_images}.")
+
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+ total_sample_batch_size = args.sample_batch_size * jax.local_device_count()
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=total_sample_batch_size)
+
+ for example in tqdm(
+ sample_dataloader, desc="Generating class images", disable=not jax.process_index() == 0
+ ):
+ prompt_ids = pipeline.prepare_inputs(example["prompt"])
+ prompt_ids = shard(prompt_ids)
+ p_params = jax_utils.replicate(params)
+ rng = jax.random.split(rng)[0]
+ sample_rng = jax.random.split(rng, jax.device_count())
+ images = pipeline(prompt_ids, p_params, sample_rng, jit=True).images
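+ # The jitted pipeline returns images shaped (num_devices, per_device_batch, H, W, C); flatten the device axis before converting to PIL.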
+ images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
+ images = pipeline.numpy_to_pil(np.array(images))
+
+ for i, image in enumerate(images):
+ hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+ image.save(image_filename)
+
+ del pipeline
+
+ # Handle the repository creation
+ if jax.process_index() == 0:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizer
+ if args.tokenizer_name:
+ tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+ elif args.pretrained_model_name_or_path:
+ tokenizer = CLIPTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
+ )
+ else:
+ raise NotImplementedError("No tokenizer specified!")
+
+ train_dataset = DreamBoothDataset(
+ instance_data_root=args.instance_data_dir,
+ instance_prompt=args.instance_prompt,
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
+ class_prompt=args.class_prompt,
+ class_num=args.num_class_images,
+ tokenizer=tokenizer,
+ size=args.resolution,
+ center_crop=args.center_crop,
+ )
+
+ def collate_fn(examples):
+ input_ids = [example["instance_prompt_ids"] for example in examples]
+ pixel_values = [example["instance_images"] for example in examples]
+
+ # Concat class and instance examples for prior preservation.
+ # We do this to avoid doing two forward passes.
+ if args.with_prior_preservation:
+ input_ids += [example["class_prompt_ids"] for example in examples]
+ pixel_values += [example["class_images"] for example in examples]
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
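+ # Prompts were tokenized with padding="do_not_pad" in the dataset; pad them to the model max length here so they can be stacked into a single tensor.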
+ input_ids = tokenizer.pad(
+ {"input_ids": input_ids}, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
+ ).input_ids
+
+ batch = {
+ "input_ids": input_ids,
+ "pixel_values": pixel_values,
+ }
+ batch = {k: v.numpy() for k, v in batch.items()}
+ return batch
+
+ total_train_batch_size = args.train_batch_size * jax.local_device_count()
+ if len(train_dataset) < total_train_batch_size:
+ raise ValueError(
+ f"Training batch size is {total_train_batch_size}, but your dataset only contains"
+ f" {len(train_dataset)} images. Please, use a larger dataset or reduce the effective batch size. Note that"
+ f" there are {jax.local_device_count()} parallel devices, so your batch size can't be smaller than that."
+ )
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=total_train_batch_size, shuffle=True, collate_fn=collate_fn, drop_last=True
+ )
+
+ weight_dtype = jnp.float32
+ if args.mixed_precision == "fp16":
+ weight_dtype = jnp.float16
+ elif args.mixed_precision == "bf16":
+ weight_dtype = jnp.bfloat16
+
+ if args.pretrained_vae_name_or_path:
+ # TODO(patil-suraj): Upload flax weights for the VAE
+ vae_arg, vae_kwargs = (args.pretrained_vae_name_or_path, {"from_pt": True})
+ else:
+ vae_arg, vae_kwargs = (args.pretrained_model_name_or_path, {"subfolder": "vae", "revision": args.revision})
+
+ # Load models and create wrapper for stable diffusion
+ text_encoder = FlaxCLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", dtype=weight_dtype, revision=args.revision
+ )
+ vae, vae_params = FlaxAutoencoderKL.from_pretrained(
+ vae_arg,
+ dtype=weight_dtype,
+ **vae_kwargs,
+ )
+ unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", dtype=weight_dtype, revision=args.revision
+ )
+
+ # Optimization
+ if args.scale_lr:
+ args.learning_rate = args.learning_rate * total_train_batch_size
+
+ constant_scheduler = optax.constant_schedule(args.learning_rate)
+
+ adamw = optax.adamw(
+ learning_rate=constant_scheduler,
+ b1=args.adam_beta1,
+ b2=args.adam_beta2,
+ eps=args.adam_epsilon,
+ weight_decay=args.adam_weight_decay,
+ )
+
+ optimizer = optax.chain(
+ optax.clip_by_global_norm(args.max_grad_norm),
+ adamw,
+ )
+
+ unet_state = train_state.TrainState.create(apply_fn=unet.__call__, params=unet_params, tx=optimizer)
+ text_encoder_state = train_state.TrainState.create(
+ apply_fn=text_encoder.__call__, params=text_encoder.params, tx=optimizer
+ )
+
+ noise_scheduler = FlaxDDPMScheduler(
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
+ )
+ noise_scheduler_state = noise_scheduler.create_state()
+
+ # Initialize our training
+ train_rngs = jax.random.split(rng, jax.local_device_count())
+
+ def train_step(unet_state, text_encoder_state, vae_params, batch, train_rng):
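+ # One pmapped training step: encode images to latents, add noise, predict it with the UNet, and apply the AdamW update (optionally also to the text encoder).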
+ dropout_rng, sample_rng, new_train_rng = jax.random.split(train_rng, 3)
+
+ if args.train_text_encoder:
+ params = {"text_encoder": text_encoder_state.params, "unet": unet_state.params}
+ else:
+ params = {"unet": unet_state.params}
+
+ def compute_loss(params):
+ # Convert images to latent space
+ vae_outputs = vae.apply(
+ {"params": vae_params}, batch["pixel_values"], deterministic=True, method=vae.encode
+ )
+ latents = vae_outputs.latent_dist.sample(sample_rng)
+ # (NHWC) -> (NCHW)
+ latents = jnp.transpose(latents, (0, 3, 1, 2))
+ latents = latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise_rng, timestep_rng = jax.random.split(sample_rng)
+ noise = jax.random.normal(noise_rng, latents.shape)
+ # Sample a random timestep for each image
+ bsz = latents.shape[0]
+ timesteps = jax.random.randint(
+ timestep_rng,
+ (bsz,),
+ 0,
+ noise_scheduler.config.num_train_timesteps,
+ )
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ if args.train_text_encoder:
+ encoder_hidden_states = text_encoder_state.apply_fn(
+ batch["input_ids"], params=params["text_encoder"], dropout_rng=dropout_rng, train=True
+ )[0]
+ else:
+ encoder_hidden_states = text_encoder(
+ batch["input_ids"], params=text_encoder_state.params, train=False
+ )[0]
+
+ # Predict the noise residual
+ model_pred = unet.apply(
+ {"params": params["unet"]}, noisy_latents, timesteps, encoder_hidden_states, train=True
+ ).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ if args.with_prior_preservation:
+ # Chunk the model_pred and target into two parts and compute the loss on each part separately.
+ model_pred, model_pred_prior = jnp.split(model_pred, 2, axis=0)
+ target, target_prior = jnp.split(target, 2, axis=0)
+
+ # Compute instance loss
+ loss = (target - model_pred) ** 2
+ loss = loss.mean()
+
+ # Compute prior loss
+ prior_loss = (target_prior - model_pred_prior) ** 2
+ prior_loss = prior_loss.mean()
+
+ # Add the prior loss to the instance loss.
+ loss = loss + args.prior_loss_weight * prior_loss
+ else:
+ loss = (target - model_pred) ** 2
+ loss = loss.mean()
+
+ return loss
+
+ grad_fn = jax.value_and_grad(compute_loss)
+ loss, grad = grad_fn(params)
+ grad = jax.lax.pmean(grad, "batch")
+
+ new_unet_state = unet_state.apply_gradients(grads=grad["unet"])
+ if args.train_text_encoder:
+ new_text_encoder_state = text_encoder_state.apply_gradients(grads=grad["text_encoder"])
+ else:
+ new_text_encoder_state = text_encoder_state
+
+ metrics = {"loss": loss}
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
+
+ return new_unet_state, new_text_encoder_state, metrics, new_train_rng
+
+ # Create parallel version of the train step
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0, 1))
+
+ # Replicate the train state on each device
+ unet_state = jax_utils.replicate(unet_state)
+ text_encoder_state = jax_utils.replicate(text_encoder_state)
+ vae_params = jax_utils.replicate(vae_params)
+
+ # Train!
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader))
+
+ # Scheduler and math around the number of training steps.
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+
+ def checkpoint(step=None):
+ # Create the pipeline using the trained modules and save it.
+ scheduler, _ = FlaxPNDMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
+ safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained(
+ "CompVis/stable-diffusion-safety-checker", from_pt=True
+ )
+ pipeline = FlaxStableDiffusionPipeline(
+ text_encoder=text_encoder,
+ vae=vae,
+ unet=unet,
+ tokenizer=tokenizer,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32"),
+ )
+
+ outdir = os.path.join(args.output_dir, str(step)) if step else args.output_dir
+ pipeline.save_pretrained(
+ outdir,
+ params={
+ "text_encoder": get_params_to_save(text_encoder_state.params),
+ "vae": get_params_to_save(vae_params),
+ "unet": get_params_to_save(unet_state.params),
+ "safety_checker": safety_checker.params,
+ },
+ )
+
+ if args.push_to_hub:
+ message = f"checkpoint-{step}" if step is not None else "End of training"
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message=message,
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ global_step = 0
+
+ epochs = tqdm(range(args.num_train_epochs), desc="Epoch ... ", position=0)
+ for epoch in epochs:
+ # ======================== Training ================================
+
+ train_metrics = []
+
+ steps_per_epoch = len(train_dataset) // total_train_batch_size
+ train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False)
+ # train
+ for batch in train_dataloader:
+ batch = shard(batch)
+ unet_state, text_encoder_state, train_metric, train_rngs = p_train_step(
+ unet_state, text_encoder_state, vae_params, batch, train_rngs
+ )
+ train_metrics.append(train_metric)
+
+ train_step_progress_bar.update(jax.local_device_count())
+
+ global_step += 1
+ if jax.process_index() == 0 and args.save_steps and global_step % args.save_steps == 0:
+ checkpoint(global_step)
+ if global_step >= args.max_train_steps:
+ break
+
+ train_metric = jax_utils.unreplicate(train_metric)
+
+ train_step_progress_bar.close()
+ epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})")
+
+ if jax.process_index() == 0:
+ checkpoint()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/dreambooth/train_dreambooth_lora.py b/diffusers/examples/dreambooth/train_dreambooth_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..b82dfa38c1727bbfafad62a305af7c0e92ebdf74
--- /dev/null
+++ b/diffusers/examples/dreambooth/train_dreambooth_lora.py
@@ -0,0 +1,1464 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import copy
+import gc
+import itertools
+import logging
+import math
+import os
+import shutil
+import warnings
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+from huggingface_hub.utils import insecure_hashlib
+from packaging import version
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ DiffusionPipeline,
+ DPMSolverMultistepScheduler,
+ StableDiffusionPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.loaders import LoraLoaderMixin
+from diffusers.models.attention_processor import (
+ AttnAddedKVProcessor,
+ AttnAddedKVProcessor2_0,
+ SlicedAttnAddedKVProcessor,
+)
+from diffusers.models.lora import LoRALinearLayer
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import unet_lora_state_dict
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.24.0.dev0")
+
+logger = get_logger(__name__)
+
+
+# TODO: This function should be removed once training scripts are rewritten in PEFT
+def text_encoder_lora_state_dict(text_encoder):
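+ # Collect the LoRA layers attached to the CLIP attention projections (q/k/v/out) into a flat state dict.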
+ state_dict = {}
+
+ def text_encoder_attn_modules(text_encoder):
+ from transformers import CLIPTextModel, CLIPTextModelWithProjection
+
+ attn_modules = []
+
+ if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
+ for i, layer in enumerate(text_encoder.text_model.encoder.layers):
+ name = f"text_model.encoder.layers.{i}.self_attn"
+ mod = layer.self_attn
+ attn_modules.append((name, mod))
+
+ return attn_modules
+
+ for name, module in text_encoder_attn_modules(text_encoder):
+ for k, v in module.q_proj.lora_linear_layer.state_dict().items():
+ state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
+
+ for k, v in module.k_proj.lora_linear_layer.state_dict().items():
+ state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
+
+ for k, v in module.v_proj.lora_linear_layer.state_dict().items():
+ state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
+
+ for k, v in module.out_proj.lora_linear_layer.state_dict().items():
+ state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
+
+ return state_dict
+
+
+def save_model_card(
+ repo_id: str,
+ images=None,
+ base_model=str,
+ train_text_encoder=False,
+ prompt=str,
+ repo_folder=None,
+ pipeline: DiffusionPipeline = None,
+):
+ img_str = ""
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ img_str += f"![img_{i}](./image_{i}.png)\n"
+
+ yaml = f"""
+---
+license: creativeml-openrail-m
+base_model: {base_model}
+instance_prompt: {prompt}
+tags:
+- {'stable-diffusion' if isinstance(pipeline, StableDiffusionPipeline) else 'if'}
+- {'stable-diffusion-diffusers' if isinstance(pipeline, StableDiffusionPipeline) else 'if-diffusers'}
+- text-to-image
+- diffusers
+- lora
+inference: true
+---
+ """
+ model_card = f"""
+# LoRA DreamBooth - {repo_id}
+
+These are LoRA adaptation weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images below. \n
+{img_str}
+
+LoRA for the text encoder was enabled: {train_text_encoder}.
+"""
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
+ f.write(yaml + model_card)
+
+
+def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ revision=revision,
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "RobertaSeriesModelWithTransformation":
+ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
+
+ return RobertaSeriesModelWithTransformation
+ elif model_class == "T5EncoderModel":
+ from transformers import T5EncoderModel
+
+ return T5EncoderModel
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ required=True,
+ help="A folder containing the training data of instance images.",
+ )
+ parser.add_argument(
+ "--class_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of class images.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default=None,
+ required=True,
+ help="The prompt with identifier specifying the instance",
+ )
+ parser.add_argument(
+ "--class_prompt",
+ type=str,
+ default=None,
+ help="The prompt to specify images in the same class as provided instance images.",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=50,
+ help=(
+ "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--with_prior_preservation",
+ default=False,
+ action="store_true",
+ help="Flag to add prior preservation loss.",
+ )
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+ parser.add_argument(
+ "--num_class_images",
+ type=int,
+ default=100,
+ help=(
+            "Minimal number of class images for prior preservation loss. If there are not enough images already present in"
+ " class_data_dir, additional images will be sampled with class_prompt."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="lora-dreambooth-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+            "The resolution for input images; all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--train_text_encoder",
+ action="store_true",
+ help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+            "Save a checkpoint of the training state every X updates. These checkpoints can be used as final"
+            " checkpoints in case they are better than the last one, and are also suitable for resuming"
+            " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--prior_generation_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp32", "fp16", "bf16"],
+ help=(
+            "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to fp16 if a GPU is available, else fp32."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--pre_compute_text_embeddings",
+ action="store_true",
+ help="Whether or not to pre-compute text embeddings. If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. This is not compatible with `--train_text_encoder`.",
+ )
+ parser.add_argument(
+ "--tokenizer_max_length",
+ type=int,
+ default=None,
+ required=False,
+ help="The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.",
+ )
+ parser.add_argument(
+ "--text_encoder_use_attention_mask",
+ action="store_true",
+ required=False,
+ help="Whether to use attention mask for the text encoder",
+ )
+ parser.add_argument(
+ "--validation_images",
+ required=False,
+ default=None,
+ nargs="+",
+ help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.",
+ )
+ parser.add_argument(
+ "--class_labels_conditioning",
+ required=False,
+ default=None,
+ help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.",
+ )
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.with_prior_preservation:
+ if args.class_data_dir is None:
+ raise ValueError("You must specify a data directory for class images.")
+ if args.class_prompt is None:
+            raise ValueError("You must specify a prompt for class images.")
+ else:
+ # logger is not available yet
+ if args.class_data_dir is not None:
+            warnings.warn("You need not use --class_data_dir without --with_prior_preservation; it will be ignored.")
+ if args.class_prompt is not None:
+            warnings.warn("You need not use --class_prompt without --with_prior_preservation; it will be ignored.")
+
+ if args.train_text_encoder and args.pre_compute_text_embeddings:
+ raise ValueError("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`")
+
+ return args
+
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+    It pre-processes the images and tokenizes the prompts.
+ """
+
+ def __init__(
+ self,
+ instance_data_root,
+ instance_prompt,
+ tokenizer,
+ class_data_root=None,
+ class_prompt=None,
+ class_num=None,
+ size=512,
+ center_crop=False,
+ encoder_hidden_states=None,
+ class_prompt_encoder_hidden_states=None,
+ tokenizer_max_length=None,
+ ):
+ self.size = size
+ self.center_crop = center_crop
+ self.tokenizer = tokenizer
+ self.encoder_hidden_states = encoder_hidden_states
+ self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states
+ self.tokenizer_max_length = tokenizer_max_length
+
+ self.instance_data_root = Path(instance_data_root)
+ if not self.instance_data_root.exists():
+            raise ValueError("Instance images root doesn't exist.")
+
+ self.instance_images_path = list(Path(instance_data_root).iterdir())
+ self.num_instance_images = len(self.instance_images_path)
+ self.instance_prompt = instance_prompt
+ self._length = self.num_instance_images
+
+ if class_data_root is not None:
+ self.class_data_root = Path(class_data_root)
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
+ self.class_images_path = list(self.class_data_root.iterdir())
+ if class_num is not None:
+ self.num_class_images = min(len(self.class_images_path), class_num)
+ else:
+ self.num_class_images = len(self.class_images_path)
+ self._length = max(self.num_class_images, self.num_instance_images)
+ self.class_prompt = class_prompt
+ else:
+ self.class_data_root = None
+
+ self.image_transforms = transforms.Compose(
+ [
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
+ instance_image = exif_transpose(instance_image)
+
+ if not instance_image.mode == "RGB":
+ instance_image = instance_image.convert("RGB")
+ example["instance_images"] = self.image_transforms(instance_image)
+
+ if self.encoder_hidden_states is not None:
+ example["instance_prompt_ids"] = self.encoder_hidden_states
+ else:
+ text_inputs = tokenize_prompt(
+ self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length
+ )
+ example["instance_prompt_ids"] = text_inputs.input_ids
+ example["instance_attention_mask"] = text_inputs.attention_mask
+
+ if self.class_data_root:
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
+ class_image = exif_transpose(class_image)
+
+ if not class_image.mode == "RGB":
+ class_image = class_image.convert("RGB")
+ example["class_images"] = self.image_transforms(class_image)
+
+ if self.class_prompt_encoder_hidden_states is not None:
+ example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states
+ else:
+ class_text_inputs = tokenize_prompt(
+ self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length
+ )
+ example["class_prompt_ids"] = class_text_inputs.input_ids
+ example["class_attention_mask"] = class_text_inputs.attention_mask
+
+ return example
+
+
+def collate_fn(examples, with_prior_preservation=False):
+ has_attention_mask = "instance_attention_mask" in examples[0]
+
+ input_ids = [example["instance_prompt_ids"] for example in examples]
+ pixel_values = [example["instance_images"] for example in examples]
+
+ if has_attention_mask:
+ attention_mask = [example["instance_attention_mask"] for example in examples]
+
+ # Concat class and instance examples for prior preservation.
+ # We do this to avoid doing two forward passes.
+ if with_prior_preservation:
+ input_ids += [example["class_prompt_ids"] for example in examples]
+ pixel_values += [example["class_images"] for example in examples]
+ if has_attention_mask:
+ attention_mask += [example["class_attention_mask"] for example in examples]
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ input_ids = torch.cat(input_ids, dim=0)
+
+ batch = {
+ "input_ids": input_ids,
+ "pixel_values": pixel_values,
+ }
+
+ if has_attention_mask:
+        # Batch the per-example attention masks so they line up with `input_ids`.
+        attention_mask = torch.cat(attention_mask, dim=0)
+        batch["attention_mask"] = attention_mask
+
+ return batch
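+
+# Illustrative sketch (comment only, not executed): with `--with_prior_preservation`
+# and train_batch_size=2, the instance examples come first and the class examples are
+# appended along the batch dimension, e.g.
+#   pixel_values: [inst_0, inst_1, class_0, class_1]  -> shape (4, 3, H, W)
+#   input_ids:    [inst_0, inst_1, class_0, class_1]  -> shape (4, seq_len)
+# The training loop later splits this back apart with `torch.chunk(..., 2, dim=0)` to
+# compute the instance loss and the prior-preservation loss separately.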
+
+
+class PromptDataset(Dataset):
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
+
+ def __init__(self, prompt, num_samples):
+ self.prompt = prompt
+ self.num_samples = num_samples
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ example = {}
+ example["prompt"] = self.prompt
+ example["index"] = index
+ return example
+
+
+def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
+ if tokenizer_max_length is not None:
+ max_length = tokenizer_max_length
+ else:
+ max_length = tokenizer.model_max_length
+
+ text_inputs = tokenizer(
+ prompt,
+ truncation=True,
+ padding="max_length",
+ max_length=max_length,
+ return_tensors="pt",
+ )
+
+ return text_inputs
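+
+# Example (illustrative, not executed): for a CLIP tokenizer, whose model_max_length
+# is typically 77, a single prompt string such as
+#   tokenize_prompt(tokenizer, "a photo of sks dog")
+# yields `input_ids` and `attention_mask` of shape (1, 77), because of
+# `padding="max_length"` together with `return_tensors="pt"`.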
+
+
+def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None):
+ text_input_ids = input_ids.to(text_encoder.device)
+
+ if text_encoder_use_attention_mask:
+ attention_mask = attention_mask.to(text_encoder.device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = text_encoder(
+ text_input_ids,
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ return prompt_embeds
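+
+# Example (illustrative, not executed): combined with `tokenize_prompt` above,
+#   encode_prompt(text_encoder, text_inputs.input_ids, text_inputs.attention_mask)
+# returns the encoder's last hidden state, e.g. of shape (1, 77, 768) for a
+# CLIP ViT-L/14 text encoder, which is what the UNet later receives as
+# `encoder_hidden_states`.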
+
+
+def main(args):
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+ import wandb
+
+ # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
+ # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
+ # TODO (sayakpaul): Remove this check when gradient accumulation with two models is enabled in accelerate.
+ if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
+ raise ValueError(
+ "Gradient accumulation is not supported when training the text encoder in distributed training. "
+ "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Generate class images if prior preservation is enabled.
+ if args.with_prior_preservation:
+ class_images_dir = Path(args.class_data_dir)
+ if not class_images_dir.exists():
+ class_images_dir.mkdir(parents=True)
+ cur_class_images = len(list(class_images_dir.iterdir()))
+
+ if cur_class_images < args.num_class_images:
+ torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+ if args.prior_generation_precision == "fp32":
+ torch_dtype = torch.float32
+ elif args.prior_generation_precision == "fp16":
+ torch_dtype = torch.float16
+ elif args.prior_generation_precision == "bf16":
+ torch_dtype = torch.bfloat16
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ torch_dtype=torch_dtype,
+ safety_checker=None,
+ revision=args.revision,
+ )
+ pipeline.set_progress_bar_config(disable=True)
+
+ num_new_images = args.num_class_images - cur_class_images
+ logger.info(f"Number of class images to sample: {num_new_images}.")
+
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+ sample_dataloader = accelerator.prepare(sample_dataloader)
+ pipeline.to(accelerator.device)
+
+ for example in tqdm(
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
+ ):
+ images = pipeline(example["prompt"]).images
+
+ for i, image in enumerate(images):
+ hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+ image.save(image_filename)
+
+ del pipeline
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizer
+ if args.tokenizer_name:
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
+ elif args.pretrained_model_name_or_path:
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ use_fast=False,
+ )
+
+ # import correct text encoder class
+ text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder = text_encoder_cls.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+ try:
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
+ )
+ except OSError:
+ # IF does not have a VAE so let's just set it to None
+ # We don't have to error out here
+ vae = None
+
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+ )
+
+ # We only train the additional adapter LoRA layers
+ if vae is not None:
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+ unet.requires_grad_(False)
+
+    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision,
+    # as these weights are only used for inference and keeping them in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ unet.to(accelerator.device, dtype=weight_dtype)
+ if vae is not None:
+ vae.to(accelerator.device, dtype=weight_dtype)
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warn(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+ if args.train_text_encoder:
+ text_encoder.gradient_checkpointing_enable()
+
+    # Now we will add new LoRA weights to the attention layers.
+    # It's important to realize here how many attention weights will be added, and of which sizes.
+    # The sizes of the attention layers are determined by only two variables:
+    # 1) the "hidden_size", which increases according to `unet.config.block_out_channels`.
+    # 2) the "cross attention size", which is set to `unet.config.cross_attention_dim`.
+
+ # Let's first see how many attention processors we will have to set.
+ # For Stable Diffusion, it should be equal to:
+ # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
+ # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
+ # - up blocks (2x attention layers) * (3x transformer layers) * (3x up blocks) = 18
+ # => 32 layers
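+    # As a concrete example (illustrative; exact shapes depend on the checkpoint config),
+    # a Stable Diffusion v1.x UNet has block_out_channels=(320, 640, 1280, 1280) and
+    # cross_attention_dim=768, so the LoRA pair added to a cross-attention `to_k`/`to_v`
+    # maps 768 -> rank -> hidden_size, while `to_q`/`to_out[0]` map hidden_size -> rank -> hidden_size.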
+
+ # Set correct lora layers
+ unet_lora_parameters = []
+ for attn_processor_name, attn_processor in unet.attn_processors.items():
+ # Parse the attention module.
+ attn_module = unet
+ for n in attn_processor_name.split(".")[:-1]:
+ attn_module = getattr(attn_module, n)
+
+ # Set the `lora_layer` attribute of the attention-related matrices.
+ attn_module.to_q.set_lora_layer(
+ LoRALinearLayer(
+ in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
+ )
+ )
+ attn_module.to_k.set_lora_layer(
+ LoRALinearLayer(
+ in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
+ )
+ )
+ attn_module.to_v.set_lora_layer(
+ LoRALinearLayer(
+ in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
+ )
+ )
+ attn_module.to_out[0].set_lora_layer(
+ LoRALinearLayer(
+ in_features=attn_module.to_out[0].in_features,
+ out_features=attn_module.to_out[0].out_features,
+ rank=args.rank,
+ )
+ )
+
+ # Accumulate the LoRA params to optimize.
+ unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
+ unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
+ unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
+ unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
+
+ if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
+ attn_module.add_k_proj.set_lora_layer(
+ LoRALinearLayer(
+ in_features=attn_module.add_k_proj.in_features,
+ out_features=attn_module.add_k_proj.out_features,
+ rank=args.rank,
+ )
+ )
+ attn_module.add_v_proj.set_lora_layer(
+ LoRALinearLayer(
+ in_features=attn_module.add_v_proj.in_features,
+ out_features=attn_module.add_v_proj.out_features,
+ rank=args.rank,
+ )
+ )
+ unet_lora_parameters.extend(attn_module.add_k_proj.lora_layer.parameters())
+ unet_lora_parameters.extend(attn_module.add_v_proj.lora_layer.parameters())
+
+ # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
+    # Instead, we monkey-patch the forward calls of its attention blocks.
+ if args.train_text_encoder:
+ # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
+ text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder, dtype=torch.float32, rank=args.rank)
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+            # There are only two options here: either just the unet attention processor layers,
+            # or both the unet and text encoder attention layers.
+ unet_lora_layers_to_save = None
+ text_encoder_lora_layers_to_save = None
+
+ for model in models:
+ if isinstance(model, type(accelerator.unwrap_model(unet))):
+ unet_lora_layers_to_save = unet_lora_state_dict(model)
+ elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
+ text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model)
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ LoraLoaderMixin.save_lora_weights(
+ output_dir,
+ unet_lora_layers=unet_lora_layers_to_save,
+ text_encoder_lora_layers=text_encoder_lora_layers_to_save,
+ )
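+
+            # Note (illustrative): `save_lora_weights` writes a single serialized LoRA file
+            # into `output_dir` (by default a safetensors file named
+            # `pytorch_lora_weights.safetensors`), which is the same file name loaded via
+            # `pipeline.load_lora_weights(...)` during the final inference step below.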
+
+ def load_model_hook(models, input_dir):
+ unet_ = None
+ text_encoder_ = None
+
+ while len(models) > 0:
+ model = models.pop()
+
+ if isinstance(model, type(accelerator.unwrap_model(unet))):
+ unet_ = model
+ elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
+ text_encoder_ = model
+ else:
+                raise ValueError(f"unexpected model to load: {model.__class__}")
+
+ lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
+ LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
+ LoraLoaderMixin.load_lora_into_text_encoder(
+ lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_
+ )
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+    # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimizer creation
+ params_to_optimize = (
+ itertools.chain(unet_lora_parameters, text_lora_parameters)
+ if args.train_text_encoder
+ else unet_lora_parameters
+ )
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ if args.pre_compute_text_embeddings:
+
+ def compute_text_embeddings(prompt):
+ with torch.no_grad():
+ text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length)
+ prompt_embeds = encode_prompt(
+ text_encoder,
+ text_inputs.input_ids,
+ text_inputs.attention_mask,
+ text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
+ )
+
+ return prompt_embeds
+
+ pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt)
+ validation_prompt_negative_prompt_embeds = compute_text_embeddings("")
+
+ if args.validation_prompt is not None:
+ validation_prompt_encoder_hidden_states = compute_text_embeddings(args.validation_prompt)
+ else:
+ validation_prompt_encoder_hidden_states = None
+
+ if args.class_prompt is not None:
+ pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.class_prompt)
+ else:
+ pre_computed_class_prompt_encoder_hidden_states = None
+
+ text_encoder = None
+ tokenizer = None
+
+ gc.collect()
+ torch.cuda.empty_cache()
+ else:
+ pre_computed_encoder_hidden_states = None
+ validation_prompt_encoder_hidden_states = None
+ validation_prompt_negative_prompt_embeds = None
+ pre_computed_class_prompt_encoder_hidden_states = None
+
+ # Dataset and DataLoaders creation:
+ train_dataset = DreamBoothDataset(
+ instance_data_root=args.instance_data_dir,
+ instance_prompt=args.instance_prompt,
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
+ class_prompt=args.class_prompt,
+ class_num=args.num_class_images,
+ tokenizer=tokenizer,
+ size=args.resolution,
+ center_crop=args.center_crop,
+ encoder_hidden_states=pre_computed_encoder_hidden_states,
+ class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states,
+ tokenizer_max_length=args.tokenizer_max_length,
+ )
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ if args.train_text_encoder:
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+ )
+ else:
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
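+    # Worked example (hypothetical numbers): with 400 batches per epoch and
+    # gradient_accumulation_steps=4, num_update_steps_per_epoch = ceil(400 / 4) = 100;
+    # with --max_train_steps=250 this yields num_train_epochs = ceil(250 / 100) = 3,
+    # and the loop below stops early once global_step reaches 250.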
+
+ # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers are initialized automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = vars(copy.deepcopy(args))
+ tracker_config.pop("validation_images")
+ accelerator.init_trackers("dreambooth-lora", config=tracker_config)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+            # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ unet.train()
+ if args.train_text_encoder:
+ text_encoder.train()
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
+
+ if vae is not None:
+ # Convert images to latent space
+ model_input = vae.encode(pixel_values).latent_dist.sample()
+ model_input = model_input * vae.config.scaling_factor
+ else:
+ model_input = pixel_values
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(model_input)
+ bsz, channels, height, width = model_input.shape
+ # Sample a random timestep for each image
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
+ )
+ timesteps = timesteps.long()
+
+ # Add noise to the model input according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ if args.pre_compute_text_embeddings:
+ encoder_hidden_states = batch["input_ids"]
+ else:
+ encoder_hidden_states = encode_prompt(
+ text_encoder,
+ batch["input_ids"],
+ batch["attention_mask"],
+ text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
+ )
+
+ if accelerator.unwrap_model(unet).config.in_channels == channels * 2:
+ noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)
+
+ if args.class_labels_conditioning == "timesteps":
+ class_labels = timesteps
+ else:
+ class_labels = None
+
+ # Predict the noise residual
+ model_pred = unet(
+ noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels
+ ).sample
+
+                # If the model predicts variance, throw away that prediction; we only train on the
+                # simplified training objective. This means that all schedulers using the fine-tuned
+                # model must be configured to use one of the fixed variance types.
+ if model_pred.shape[1] == 6:
+ model_pred, _ = torch.chunk(model_pred, 2, dim=1)
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(model_input, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ if args.with_prior_preservation:
+                    # Chunk the model prediction and target into two parts and compute the loss on each part separately.
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+ target, target_prior = torch.chunk(target, 2, dim=0)
+
+ # Compute instance loss
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+ # Compute prior loss
+ prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+
+ # Add the prior loss to the instance loss.
+ loss = loss + args.prior_loss_weight * prior_loss
+ else:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = (
+ itertools.chain(unet_lora_parameters, text_lora_parameters)
+ if args.train_text_encoder
+ else unet_lora_parameters
+ )
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ # create pipeline
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ unet=accelerator.unwrap_model(unet),
+ text_encoder=None if args.pre_compute_text_embeddings else accelerator.unwrap_model(text_encoder),
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
+ scheduler_args = {}
+
+ if "variance_type" in pipeline.scheduler.config:
+ variance_type = pipeline.scheduler.config.variance_type
+
+ if variance_type in ["learned", "learned_range"]:
+ variance_type = "fixed_small"
+
+ scheduler_args["variance_type"] = variance_type
+
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
+ pipeline.scheduler.config, **scheduler_args
+ )
+
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ if args.pre_compute_text_embeddings:
+ pipeline_args = {
+ "prompt_embeds": validation_prompt_encoder_hidden_states,
+ "negative_prompt_embeds": validation_prompt_negative_prompt_embeds,
+ }
+ else:
+ pipeline_args = {"prompt": args.validation_prompt}
+
+ if args.validation_images is None:
+ images = []
+ for _ in range(args.num_validation_images):
+ with torch.cuda.amp.autocast():
+ image = pipeline(**pipeline_args, generator=generator).images[0]
+ images.append(image)
+ else:
+ images = []
+ for image in args.validation_images:
+ image = Image.open(image)
+ with torch.cuda.amp.autocast():
+ image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
+ images.append(image)
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = accelerator.unwrap_model(unet)
+ unet = unet.to(torch.float32)
+ unet_lora_layers = unet_lora_state_dict(unet)
+
+ if text_encoder is not None and args.train_text_encoder:
+ text_encoder = accelerator.unwrap_model(text_encoder)
+ text_encoder = text_encoder.to(torch.float32)
+ text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder)
+ else:
+ text_encoder_lora_layers = None
+
+ LoraLoaderMixin.save_lora_weights(
+ save_directory=args.output_dir,
+ unet_lora_layers=unet_lora_layers,
+ text_encoder_lora_layers=text_encoder_lora_layers,
+ )
+
+ # Final inference
+ # Load previous pipeline
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
+ )
+
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
+ scheduler_args = {}
+
+ if "variance_type" in pipeline.scheduler.config:
+ variance_type = pipeline.scheduler.config.variance_type
+
+ if variance_type in ["learned", "learned_range"]:
+ variance_type = "fixed_small"
+
+ scheduler_args["variance_type"] = variance_type
+
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+
+ pipeline = pipeline.to(accelerator.device)
+
+ # load attention processors
+ pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors")
+
+ # run inference
+ images = []
+ if args.validation_prompt and args.num_validation_images > 0:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ images = [
+ pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+ for _ in range(args.num_validation_images)
+ ]
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "test": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ train_text_encoder=args.train_text_encoder,
+ prompt=args.instance_prompt,
+ repo_folder=args.output_dir,
+ pipeline=pipeline,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
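+
+# Example launch command (illustrative only; all flags correspond to `parse_args` above,
+# adjust the base model, data paths and hyper-parameters to your own setup):
+#
+#   accelerate launch train_dreambooth_lora.py \
+#     --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
+#     --instance_data_dir="./instance_images" \
+#     --instance_prompt="a photo of sks dog" \
+#     --output_dir="lora-dreambooth-model" \
+#     --resolution=512 --train_batch_size=1 --gradient_accumulation_steps=1 \
+#     --learning_rate=5e-4 --lr_scheduler="constant" --lr_warmup_steps=0 \
+#     --max_train_steps=500 --checkpointing_steps=100 --seed=0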
diff --git a/diffusers/examples/dreambooth/train_dreambooth_lora_sdxl.py b/diffusers/examples/dreambooth/train_dreambooth_lora_sdxl.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4e7887c1c13bb1ac15b49000f78d3965e7dd365
--- /dev/null
+++ b/diffusers/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -0,0 +1,1732 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import gc
+import itertools
+import logging
+import math
+import os
+import shutil
+import warnings
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+from huggingface_hub.utils import insecure_hashlib
+from packaging import version
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ DPMSolverMultistepScheduler,
+ StableDiffusionXLPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.loaders import LoraLoaderMixin
+from diffusers.models.lora import LoRALinearLayer
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import compute_snr, unet_lora_state_dict
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.24.0.dev0")
+
+logger = get_logger(__name__)
+
+
+# TODO: This function should be removed once training scripts are rewritten in PEFT
+def text_encoder_lora_state_dict(text_encoder):
+ state_dict = {}
+
+ def text_encoder_attn_modules(text_encoder):
+ from transformers import CLIPTextModel, CLIPTextModelWithProjection
+
+ attn_modules = []
+
+ if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
+ for i, layer in enumerate(text_encoder.text_model.encoder.layers):
+ name = f"text_model.encoder.layers.{i}.self_attn"
+ mod = layer.self_attn
+ attn_modules.append((name, mod))
+
+ return attn_modules
+
+ for name, module in text_encoder_attn_modules(text_encoder):
+ for k, v in module.q_proj.lora_linear_layer.state_dict().items():
+ state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
+
+ for k, v in module.k_proj.lora_linear_layer.state_dict().items():
+ state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
+
+ for k, v in module.v_proj.lora_linear_layer.state_dict().items():
+ state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
+
+ for k, v in module.out_proj.lora_linear_layer.state_dict().items():
+ state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
+
+ return state_dict
+
+
+def save_model_card(
+ repo_id: str,
+ images=None,
+    base_model=None,
+    train_text_encoder=False,
+    instance_prompt=None,
+    validation_prompt=None,
+ repo_folder=None,
+ vae_path=None,
+):
+ img_str = "widget:\n" if images else ""
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ img_str += f"""
+ - text: '{validation_prompt if validation_prompt else ' ' }'
+ output:
+ url:
+ "image_{i}.png"
+ """
+
+ yaml = f"""
+---
+tags:
+- stable-diffusion-xl
+- stable-diffusion-xl-diffusers
+- text-to-image
+- diffusers
+- lora
+- template:sd-lora
+{img_str}
+base_model: {base_model}
+instance_prompt: {instance_prompt}
+license: openrail++
+---
+ """
+
+ model_card = f"""
+# SDXL LoRA DreamBooth - {repo_id}
+
+
+
+## Model description
+
+These are {repo_id} LoRA adaptation weights for {base_model}.
+
+The weights were trained using [DreamBooth](https://dreambooth.github.io/).
+
+LoRA for the text encoder was enabled: {train_text_encoder}.
+
+Special VAE used for training: {vae_path}.
+
+## Trigger words
+
+You should use {instance_prompt} to trigger the image generation.
+
+## Download model
+
+Weights for this model are available in Safetensors format.
+
+[Download]({repo_id}/tree/main) them in the Files & versions tab.
+
+"""
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
+ f.write(yaml + model_card)
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ help=("A folder containing the training data. "),
+ )
+
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+
+ parser.add_argument(
+ "--image_column",
+ type=str,
+ default="image",
+        help="The column of the dataset containing the target image. By "
+        "default, the standard Image Dataset maps 'file_name' "
+        "to 'image'.",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default=None,
+ help="The column of the dataset containing the instance prompt for each image",
+ )
+
+ parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
+
+ parser.add_argument(
+ "--class_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of class images.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default=None,
+ required=True,
+ help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
+ )
+ parser.add_argument(
+ "--class_prompt",
+ type=str,
+ default=None,
+ help="The prompt to specify images in the same class as provided instance images.",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=50,
+ help=(
+ "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--with_prior_preservation",
+ default=False,
+ action="store_true",
+ help="Flag to add prior preservation loss.",
+ )
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+ parser.add_argument(
+ "--num_class_images",
+ type=int,
+ default=100,
+ help=(
+            "Minimal number of class images for prior preservation loss. If there are not enough images already present in"
+ " class_data_dir, additional images will be sampled with class_prompt."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="lora-dreambooth-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=1024,
+ help=(
+            "The resolution for input images; all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--crops_coords_top_left_h",
+ type=int,
+ default=0,
+ help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
+ )
+ parser.add_argument(
+ "--crops_coords_top_left_w",
+ type=int,
+ default=0,
+        help=("Coordinate for (the width) to be included in the crop coordinate embeddings needed by SDXL UNet."),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--train_text_encoder",
+ action="store_true",
+ help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+            "Save a checkpoint of the training state every X updates. These checkpoints can be used as final"
+            " checkpoints in case they are better than the last one, and are also suitable for resuming"
+            " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+
+ parser.add_argument(
+ "--text_encoder_lr",
+ type=float,
+ default=5e-6,
+ help="Text encoder learning rate to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+
+ parser.add_argument(
+ "--optimizer",
+ type=str,
+ default="AdamW",
+ help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
+ )
+
+ parser.add_argument(
+ "--use_8bit_adam",
+ action="store_true",
+ help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
+ )
+
+ parser.add_argument(
+ "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--prodigy_beta3",
+ type=float,
+ default=None,
+ help="coefficients for computing the Prodidy stepsize using running averages. If set to None, "
+ "uses the value of square root of beta2. Ignored if optimizer is adamW",
+ )
+ parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
+ parser.add_argument(
+ "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
+ )
+
+ parser.add_argument(
+ "--adam_epsilon",
+ type=float,
+ default=1e-08,
+ help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
+ )
+
+ parser.add_argument(
+ "--prodigy_use_bias_correction",
+ type=bool,
+ default=True,
+ help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW",
+ )
+ parser.add_argument(
+ "--prodigy_safeguard_warmup",
+ type=bool,
+ default=True,
+ help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
+ "Ignored if optimizer is adamW",
+ )
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--prior_generation_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp32", "fp16", "bf16"],
+ help=(
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if args.dataset_name is None and args.instance_data_dir is None:
+ raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`")
+
+ if args.dataset_name is not None and args.instance_data_dir is not None:
+ raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`")
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.with_prior_preservation:
+ if args.class_data_dir is None:
+ raise ValueError("You must specify a data directory for class images.")
+ if args.class_prompt is None:
+ raise ValueError("You must specify prompt for class images.")
+ else:
+ # logger is not available yet
+ if args.class_data_dir is not None:
+ warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
+ if args.class_prompt is not None:
+ warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
+
+ return args
+
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+ It pre-processes the images.
+ """
+
+ def __init__(
+ self,
+ instance_data_root,
+ instance_prompt,
+ class_prompt,
+ class_data_root=None,
+ class_num=None,
+ size=1024,
+ repeats=1,
+ center_crop=False,
+ ):
+ self.size = size
+ self.center_crop = center_crop
+
+ self.instance_prompt = instance_prompt
+ self.custom_instance_prompts = None
+ self.class_prompt = class_prompt
+
+ # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
+ # we load the training data using load_dataset
+ if args.dataset_name is not None:
+ try:
+ from datasets import load_dataset
+ except ImportError:
+ raise ImportError(
+ "You are trying to load your data using the datasets library. If you wish to train using custom "
+ "captions please install the datasets library: `pip install datasets`. If you wish to load a "
+ "local folder containing images only, specify --instance_data_dir instead."
+ )
+ # Downloading and loading a dataset from the hub.
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ # Preprocessing the datasets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ if args.image_column is None:
+ image_column = column_names[0]
+ logger.info(f"image column defaulting to {image_column}")
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ instance_images = dataset["train"][image_column]
+
+ if args.caption_column is None:
+ logger.info(
+ "No caption column provided, defaulting to instance_prompt for all images. If your dataset "
+ "contains captions/prompts for the images, make sure to specify the "
+ "column as --caption_column"
+ )
+ self.custom_instance_prompts = None
+ else:
+ if args.caption_column not in column_names:
+ raise ValueError(
+ f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ custom_instance_prompts = dataset["train"][args.caption_column]
+ # create final list of captions according to --repeats
+ self.custom_instance_prompts = []
+ for caption in custom_instance_prompts:
+ self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))
+ else:
+ self.instance_data_root = Path(instance_data_root)
+ if not self.instance_data_root.exists():
+ raise ValueError("Instance images root doesn't exists.")
+
+ instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]
+ self.custom_instance_prompts = None
+
+ self.instance_images = []
+ for img in instance_images:
+ self.instance_images.extend(itertools.repeat(img, repeats))
+ self.num_instance_images = len(self.instance_images)
+ self._length = self.num_instance_images
+
+ if class_data_root is not None:
+ self.class_data_root = Path(class_data_root)
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
+ self.class_images_path = list(self.class_data_root.iterdir())
+ if class_num is not None:
+ self.num_class_images = min(len(self.class_images_path), class_num)
+ else:
+ self.num_class_images = len(self.class_images_path)
+ self._length = max(self.num_class_images, self.num_instance_images)
+ else:
+ self.class_data_root = None
+
+ self.image_transforms = transforms.Compose(
+ [
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image = self.instance_images[index % self.num_instance_images]
+ instance_image = exif_transpose(instance_image)
+
+ if not instance_image.mode == "RGB":
+ instance_image = instance_image.convert("RGB")
+ example["instance_images"] = self.image_transforms(instance_image)
+
+ if self.custom_instance_prompts:
+ caption = self.custom_instance_prompts[index % self.num_instance_images]
+ if caption:
+ example["instance_prompt"] = caption
+ else:
+ example["instance_prompt"] = self.instance_prompt
+
+ else: # no custom prompts were provided, so fall back to the shared instance prompt
+ example["instance_prompt"] = self.instance_prompt
+
+ if self.class_data_root:
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
+ class_image = exif_transpose(class_image)
+
+ if not class_image.mode == "RGB":
+ class_image = class_image.convert("RGB")
+ example["class_images"] = self.image_transforms(class_image)
+ example["class_prompt"] = self.class_prompt
+
+ return example
+
+
+def collate_fn(examples, with_prior_preservation=False):
+ pixel_values = [example["instance_images"] for example in examples]
+ prompts = [example["instance_prompt"] for example in examples]
+
+ # Concat class and instance examples for prior preservation.
+ # We do this to avoid doing two forward passes.
+ if with_prior_preservation:
+ pixel_values += [example["class_images"] for example in examples]
+ prompts += [example["class_prompt"] for example in examples]
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ batch = {"pixel_values": pixel_values, "prompts": prompts}
+ return batch
+
+
+class PromptDataset(Dataset):
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
+
+ def __init__(self, prompt, num_samples):
+ self.prompt = prompt
+ self.num_samples = num_samples
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ example = {}
+ example["prompt"] = self.prompt
+ example["index"] = index
+ return example
+
+
+def tokenize_prompt(tokenizer, prompt):
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ return text_input_ids
+
+
+# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
+def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
+ prompt_embeds_list = []
+
+ for i, text_encoder in enumerate(text_encoders):
+ if tokenizers is not None:
+ tokenizer = tokenizers[i]
+ text_input_ids = tokenize_prompt(tokenizer, prompt)
+ else:
+ assert text_input_ids_list is not None
+ text_input_ids = text_input_ids_list[i]
+
+ prompt_embeds = text_encoder(
+ text_input_ids.to(text_encoder.device),
+ output_hidden_states=True,
+ )
+
+ # We are only interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+ prompt_embeds_list.append(prompt_embeds)
+
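+ # Concatenate the per-encoder embeddings along the feature dimension; the SDXL UNet expects the
+ # hidden states of both text encoders concatenated this way.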
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
+ return prompt_embeds, pooled_prompt_embeds
+
+
+def main(args):
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ kwargs_handlers=[kwargs],
+ )
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+ import wandb
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Generate class images if prior preservation is enabled.
+ if args.with_prior_preservation:
+ class_images_dir = Path(args.class_data_dir)
+ if not class_images_dir.exists():
+ class_images_dir.mkdir(parents=True)
+ cur_class_images = len(list(class_images_dir.iterdir()))
+
+ if cur_class_images < args.num_class_images:
+ torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+ if args.prior_generation_precision == "fp32":
+ torch_dtype = torch.float32
+ elif args.prior_generation_precision == "fp16":
+ torch_dtype = torch.float16
+ elif args.prior_generation_precision == "bf16":
+ torch_dtype = torch.bfloat16
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ torch_dtype=torch_dtype,
+ revision=args.revision,
+ )
+ pipeline.set_progress_bar_config(disable=True)
+
+ num_new_images = args.num_class_images - cur_class_images
+ logger.info(f"Number of class images to sample: {num_new_images}.")
+
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+ sample_dataloader = accelerator.prepare(sample_dataloader)
+ pipeline.to(accelerator.device)
+
+ for example in tqdm(
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
+ ):
+ images = pipeline(example["prompt"]).images
+
+ for i, image in enumerate(images):
+ hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+ image.save(image_filename)
+
+ del pipeline
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizers
+ tokenizer_one = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
+ )
+ tokenizer_two = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
+ )
+
+ # import correct text encoder classes
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+ )
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
+ )
+ vae_path = (
+ args.pretrained_model_name_or_path
+ if args.pretrained_vae_model_name_or_path is None
+ else args.pretrained_vae_model_name_or_path
+ )
+ vae = AutoencoderKL.from_pretrained(
+ vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision
+ )
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+ )
+
+ # We only train the additional adapter LoRA layers
+ vae.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+ unet.requires_grad_(False)
+
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ unet.to(accelerator.device, dtype=weight_dtype)
+
+ # The VAE is always in float32 to avoid NaN losses.
+ vae.to(accelerator.device, dtype=torch.float32)
+
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warn(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, "
+ "please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+ if args.train_text_encoder:
+ text_encoder_one.gradient_checkpointing_enable()
+ text_encoder_two.gradient_checkpointing_enable()
+
+ # now we will add new LoRA weights to the attention layers
+ # Set correct lora layers
+ unet_lora_parameters = []
+ for attn_processor_name, attn_processor in unet.attn_processors.items():
+ # Parse the attention module.
+ attn_module = unet
+ for n in attn_processor_name.split(".")[:-1]:
+ attn_module = getattr(attn_module, n)
+
+ # Set the `lora_layer` attribute of the attention-related matrices.
+ attn_module.to_q.set_lora_layer(
+ LoRALinearLayer(
+ in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
+ )
+ )
+ attn_module.to_k.set_lora_layer(
+ LoRALinearLayer(
+ in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
+ )
+ )
+ attn_module.to_v.set_lora_layer(
+ LoRALinearLayer(
+ in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
+ )
+ )
+ attn_module.to_out[0].set_lora_layer(
+ LoRALinearLayer(
+ in_features=attn_module.to_out[0].in_features,
+ out_features=attn_module.to_out[0].out_features,
+ rank=args.rank,
+ )
+ )
+
+ # Accumulate the LoRA params to optimize.
+ unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
+ unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
+ unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
+ unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
+
+ # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
+ # So, instead, we monkey-patch the forward calls of its attention-blocks.
+ if args.train_text_encoder:
+ # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
+ text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(
+ text_encoder_one, dtype=torch.float32, rank=args.rank
+ )
+ text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(
+ text_encoder_two, dtype=torch.float32, rank=args.rank
+ )
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ # there are only two options here: either just the unet attention processor layers,
+ # or the unet plus the text encoder attention layers
+ unet_lora_layers_to_save = None
+ text_encoder_one_lora_layers_to_save = None
+ text_encoder_two_lora_layers_to_save = None
+
+ for model in models:
+ if isinstance(model, type(accelerator.unwrap_model(unet))):
+ unet_lora_layers_to_save = unet_lora_state_dict(model)
+ elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
+ text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
+ elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
+ text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ StableDiffusionXLPipeline.save_lora_weights(
+ output_dir,
+ unet_lora_layers=unet_lora_layers_to_save,
+ text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
+ text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
+ )
+
+ def load_model_hook(models, input_dir):
+ unet_ = None
+ text_encoder_one_ = None
+ text_encoder_two_ = None
+
+ while len(models) > 0:
+ model = models.pop()
+
+ if isinstance(model, type(accelerator.unwrap_model(unet))):
+ unet_ = model
+ elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
+ text_encoder_one_ = model
+ elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
+ text_encoder_two_ = model
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
+ LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
+
+ text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
+ LoraLoaderMixin.load_lora_into_text_encoder(
+ text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
+ )
+
+ text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
+ LoraLoaderMixin.load_lora_into_text_encoder(
+ text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
+ )
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Optimization parameters
+ unet_lora_parameters_with_lr = {"params": unet_lora_parameters, "lr": args.learning_rate}
+ if args.train_text_encoder:
+ # different learning rate for text encoder and unet
+ text_lora_parameters_one_with_lr = {
+ "params": text_lora_parameters_one,
+ "weight_decay": args.adam_weight_decay_text_encoder,
+ "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
+ }
+ text_lora_parameters_two_with_lr = {
+ "params": text_lora_parameters_two,
+ "weight_decay": args.adam_weight_decay_text_encoder,
+ "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
+ }
+ params_to_optimize = [
+ unet_lora_parameters_with_lr,
+ text_lora_parameters_one_with_lr,
+ text_lora_parameters_two_with_lr,
+ ]
+ else:
+ params_to_optimize = [unet_lora_parameters_with_lr]
+
+ # Optimizer creation
+ if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
+ logger.warn(
+ f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
+ "Defaulting to adamW"
+ )
+ args.optimizer = "adamw"
+
+ if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
+ logger.warn(
+ f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
+ f"set to {args.optimizer.lower()}"
+ )
+
+ if args.optimizer.lower() == "adamw":
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ if args.optimizer.lower() == "prodigy":
+ try:
+ import prodigyopt
+ except ImportError:
+ raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
+
+ optimizer_class = prodigyopt.Prodigy
+
+ if args.learning_rate <= 0.1:
+ logger.warn(
+ "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
+ )
+ if args.train_text_encoder and args.text_encoder_lr:
+ logger.warn(
+ f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:"
+ f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
+ f"When using prodigy only learning_rate is used as the initial learning rate."
+ )
+ # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be
+ # --learning_rate
+ params_to_optimize[1]["lr"] = args.learning_rate
+ params_to_optimize[2]["lr"] = args.learning_rate
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ beta3=args.prodigy_beta3,
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ decouple=args.prodigy_decouple,
+ use_bias_correction=args.prodigy_use_bias_correction,
+ safeguard_warmup=args.prodigy_safeguard_warmup,
+ )
+
+ # Dataset and DataLoaders creation:
+ train_dataset = DreamBoothDataset(
+ instance_data_root=args.instance_data_dir,
+ instance_prompt=args.instance_prompt,
+ class_prompt=args.class_prompt,
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
+ class_num=args.num_class_images,
+ size=args.resolution,
+ repeats=args.repeats,
+ center_crop=args.center_crop,
+ )
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Computes additional embeddings/ids required by the SDXL UNet.
+ # regular text embeddings (when `train_text_encoder` is not True)
+ # pooled text embeddings
+ # time ids
+
+ def compute_time_ids():
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
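+ # SDXL micro-conditions the UNet on (original_size, crops_coords_top_left, target_size),
+ # packed into a single `add_time_ids` tensor.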
+ original_size = (args.resolution, args.resolution)
+ target_size = (args.resolution, args.resolution)
+ crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_time_ids = torch.tensor([add_time_ids])
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
+ return add_time_ids
+
+ if not args.train_text_encoder:
+ tokenizers = [tokenizer_one, tokenizer_two]
+ text_encoders = [text_encoder_one, text_encoder_two]
+
+ def compute_text_embeddings(prompt, text_encoders, tokenizers):
+ with torch.no_grad():
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)
+ prompt_embeds = prompt_embeds.to(accelerator.device)
+ pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
+ return prompt_embeds, pooled_prompt_embeds
+
+ # Handle instance prompt.
+ instance_time_ids = compute_time_ids()
+
+ # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
+ # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
+ # the redundant encoding.
+ if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
+ instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
+ args.instance_prompt, text_encoders, tokenizers
+ )
+
+ # Handle class prompt for prior-preservation.
+ if args.with_prior_preservation:
+ class_time_ids = compute_time_ids()
+ if not args.train_text_encoder:
+ class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(
+ args.class_prompt, text_encoders, tokenizers
+ )
+
+ # Clear the memory here
+ if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
+ del tokenizers, text_encoders
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
+ # pack the statically computed variables appropriately here. This is so that we don't
+ # have to pass them to the dataloader.
+ add_time_ids = instance_time_ids
+ if args.with_prior_preservation:
+ add_time_ids = torch.cat([add_time_ids, class_time_ids], dim=0)
+
+ if not train_dataset.custom_instance_prompts:
+ if not args.train_text_encoder:
+ prompt_embeds = instance_prompt_hidden_states
+ unet_add_text_embeds = instance_pooled_prompt_embeds
+ if args.with_prior_preservation:
+ prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
+ unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0)
+ # if we're optimizing the text encoder (whether the instance prompt is used for all images or custom prompts are provided) we need to tokenize
+ # and encode the batch prompts on all training steps
+ else:
+ tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
+ tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt)
+ if args.with_prior_preservation:
+ class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt)
+ class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt)
+ tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
+ tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ if args.train_text_encoder:
+ unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler
+ )
+ else:
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("dreambooth-lora-sd-xl", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ unet.train()
+ if args.train_text_encoder:
+ text_encoder_one.train()
+ text_encoder_two.train()
+
+ # set top parameter requires_grad = True so that gradient checkpointing works
+ text_encoder_one.text_model.embeddings.requires_grad_(True)
+ text_encoder_two.text_model.embeddings.requires_grad_(True)
+
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+ prompts = batch["prompts"]
+
+ # encode batch prompts when custom prompts are provided for each image -
+ if train_dataset.custom_instance_prompts:
+ if not args.train_text_encoder:
+ prompt_embeds, unet_add_text_embeds = compute_text_embeddings(
+ prompts, text_encoders, tokenizers
+ )
+ else:
+ tokens_one = tokenize_prompt(tokenizer_one, prompts)
+ tokens_two = tokenize_prompt(tokenizer_two, prompts)
+
+ # Convert images to latent space
+ model_input = vae.encode(pixel_values).latent_dist.sample()
+ model_input = model_input * vae.config.scaling_factor
+ if args.pretrained_vae_model_name_or_path is None:
+ model_input = model_input.to(weight_dtype)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(model_input)
+ bsz = model_input.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
+ )
+ timesteps = timesteps.long()
+
+ # Add noise to the model input according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
+
+ # Calculate the elements to repeat depending on the use of prior-preservation and custom captions.
+ if not train_dataset.custom_instance_prompts:
+ elems_to_repeat_text_embeds = bsz // 2 if args.with_prior_preservation else bsz
+ elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz
+ else:
+ elems_to_repeat_text_embeds = 1
+ elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz
+
+ # Predict the noise residual
+ if not args.train_text_encoder:
+ unet_added_conditions = {
+ "time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1),
+ "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1),
+ }
+ prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
+ model_pred = unet(
+ noisy_model_input,
+ timesteps,
+ prompt_embeds_input,
+ added_cond_kwargs=unet_added_conditions,
+ ).sample
+ else:
+ unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)}
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
+ text_encoders=[text_encoder_one, text_encoder_two],
+ tokenizers=None,
+ prompt=None,
+ text_input_ids_list=[tokens_one, tokens_two],
+ )
+ unet_added_conditions.update(
+ {"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)}
+ )
+ prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
+ model_pred = unet(
+ noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions
+ ).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(model_input, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ if args.with_prior_preservation:
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+ target, target_prior = torch.chunk(target, 2, dim=0)
+
+ # Compute prior loss
+ prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+
+ if args.snr_gamma is None:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+ # This is discussed in Section 4.2 of the same paper.
+ snr = compute_snr(noise_scheduler, timesteps)
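+ # Min-SNR weighting: clamp each timestep's SNR at `snr_gamma` and normalize by the SNR,
+ # i.e. base_weight = min(SNR(t), snr_gamma) / SNR(t), computed below via torch.stack(...).min(...).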
+ base_weight = (
+ torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
+ )
+
+ if noise_scheduler.config.prediction_type == "v_prediction":
+ # Velocity objective needs to be floored to an SNR weight of one.
+ mse_loss_weights = base_weight + 1
+ else:
+ # Epsilon and sample both use the same loss weights.
+ mse_loss_weights = base_weight
+
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+
+ if args.with_prior_preservation:
+ # Add the prior loss to the instance loss.
+ loss = loss + args.prior_loss_weight * prior_loss
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = (
+ itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two)
+ if args.train_text_encoder
+ else unet_lora_parameters
+ )
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ # create pipeline
+ if not args.train_text_encoder:
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
+ )
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ text_encoder=accelerator.unwrap_model(text_encoder_one),
+ text_encoder_2=accelerator.unwrap_model(text_encoder_two),
+ unet=accelerator.unwrap_model(unet),
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
+ scheduler_args = {}
+
+ if "variance_type" in pipeline.scheduler.config:
+ variance_type = pipeline.scheduler.config.variance_type
+
+ if variance_type in ["learned", "learned_range"]:
+ variance_type = "fixed_small"
+
+ scheduler_args["variance_type"] = variance_type
+
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
+ pipeline.scheduler.config, **scheduler_args
+ )
+
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ pipeline_args = {"prompt": args.validation_prompt}
+
+ with torch.cuda.amp.autocast():
+ images = [
+ pipeline(**pipeline_args, generator=generator).images[0]
+ for _ in range(args.num_validation_images)
+ ]
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = accelerator.unwrap_model(unet)
+ unet = unet.to(torch.float32)
+ unet_lora_layers = unet_lora_state_dict(unet)
+
+ if args.train_text_encoder:
+ text_encoder_one = accelerator.unwrap_model(text_encoder_one)
+ text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one.to(torch.float32))
+ text_encoder_two = accelerator.unwrap_model(text_encoder_two)
+ text_encoder_2_lora_layers = text_encoder_lora_state_dict(text_encoder_two.to(torch.float32))
+ else:
+ text_encoder_lora_layers = None
+ text_encoder_2_lora_layers = None
+
+ StableDiffusionXLPipeline.save_lora_weights(
+ save_directory=args.output_dir,
+ unet_lora_layers=unet_lora_layers,
+ text_encoder_lora_layers=text_encoder_lora_layers,
+ text_encoder_2_lora_layers=text_encoder_2_lora_layers,
+ )
+
+ # Final inference
+ # Load previous pipeline
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path, vae=vae, revision=args.revision, torch_dtype=weight_dtype
+ )
+
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
+ scheduler_args = {}
+
+ if "variance_type" in pipeline.scheduler.config:
+ variance_type = pipeline.scheduler.config.variance_type
+
+ if variance_type in ["learned", "learned_range"]:
+ variance_type = "fixed_small"
+
+ scheduler_args["variance_type"] = variance_type
+
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+
+ # load attention processors
+ pipeline.load_lora_weights(args.output_dir)
+
+ # run inference
+ images = []
+ if args.validation_prompt and args.num_validation_images > 0:
+ pipeline = pipeline.to(accelerator.device)
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ images = [
+ pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+ for _ in range(args.num_validation_images)
+ ]
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "test": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ train_text_encoder=args.train_text_encoder,
+ instance_prompt=args.instance_prompt,
+ validation_prompt=args.validation_prompt,
+ repo_folder=args.output_dir,
+ vae_path=args.pretrained_vae_model_name_or_path,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
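+
+# Example launch command (a sketch only: the script filename, model name, data directory, prompt and
+# hyperparameter values below are illustrative placeholders, not defaults enforced by this script):
+#
+#   accelerate launch train_dreambooth_lora_sdxl.py \
+#     --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
+#     --instance_data_dir="./instance_images" \
+#     --instance_prompt="a photo of sks dog" \
+#     --resolution=1024 \
+#     --train_batch_size=1 \
+#     --gradient_accumulation_steps=4 \
+#     --learning_rate=1e-4 \
+#     --max_train_steps=500 \
+#     --mixed_precision="fp16" \
+#     --output_dir="lora-dreambooth-model"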
diff --git a/diffusers/examples/inference/README.md b/diffusers/examples/inference/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..52d66be8e228d312f1d079e6c8123448b6fa86fd
--- /dev/null
+++ b/diffusers/examples/inference/README.md
@@ -0,0 +1,8 @@
+# Inference Examples
+
+**The inference examples folder is deprecated and will be removed in a future version**.
+**Officially supported inference examples can be found in the [Pipelines folder](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines)**.
+
+- For `Image-to-Image text-guided generation with Stable Diffusion`, please have a look at the official [Pipeline examples](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines#examples)
+- For `In-painting using Stable Diffusion`, please have a look at the official [Pipeline examples](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines#examples)
+- For `Tweak prompts reusing seeds and latents`, please have a look at the official [Pipeline examples](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines#examples)
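+
+For quick reference, a minimal image-to-image sketch with the pipeline API (the checkpoint name and the local `init.png` path below are placeholders):
+
+```python
+import torch
+from PIL import Image
+
+from diffusers import StableDiffusionImg2ImgPipeline
+
+pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
+
+init_image = Image.open("init.png").convert("RGB").resize((768, 512))  # placeholder input image
+image = pipe(prompt="A fantasy landscape", image=init_image, strength=0.75, guidance_scale=7.5).images[0]
+image.save("fantasy_landscape.png")
+```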
diff --git a/diffusers/examples/inference/image_to_image.py b/diffusers/examples/inference/image_to_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..86b46c4e606e039cb2ad80b341b2685694f883b4
--- /dev/null
+++ b/diffusers/examples/inference/image_to_image.py
@@ -0,0 +1,9 @@
+import warnings
+
+from diffusers import StableDiffusionImg2ImgPipeline # noqa F401
+
+
+warnings.warn(
+ "The `image_to_image.py` script is outdated. Please use directly `from diffusers import"
+ " StableDiffusionImg2ImgPipeline` instead."
+)
diff --git a/diffusers/examples/inference/inpainting.py b/diffusers/examples/inference/inpainting.py
new file mode 100644
index 0000000000000000000000000000000000000000..8aad208ff34eb4d4ba1c6acfdfe0f97ac9afc4bc
--- /dev/null
+++ b/diffusers/examples/inference/inpainting.py
@@ -0,0 +1,9 @@
+import warnings
+
+from diffusers import StableDiffusionInpaintPipeline as StableDiffusionInpaintPipeline # noqa F401
+
+
+warnings.warn(
+ "The `inpainting.py` script is outdated. Please use directly `from diffusers import"
+ " StableDiffusionInpaintPipeline` instead."
+)
diff --git a/diffusers/examples/instruct_pix2pix/README.md b/diffusers/examples/instruct_pix2pix/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..6e615c282cad7a50a5b2d305412ecff712e1ed34
--- /dev/null
+++ b/diffusers/examples/instruct_pix2pix/README.md
@@ -0,0 +1,196 @@
+# InstructPix2Pix training example
+
+[InstructPix2Pix](https://arxiv.org/abs/2211.09800) is a method to fine-tune text-conditioned diffusion models such that they can follow an edit instruction for an input image. Models fine-tuned using this method take the following as inputs:
+
+The output is an "edited" image that reflects the edit instruction applied on the input image:
+
+The `train_instruct_pix2pix.py` script shows how to implement the training procedure and adapt it for Stable Diffusion.
+
+***Disclaimer: Even though `train_instruct_pix2pix.py` implements the InstructPix2Pix
+training procedure while being faithful to the [original implementation](https://github.com/timothybrooks/instruct-pix2pix), we have only tested it on a [small-scale dataset](https://huggingface.co/datasets/fusing/instructpix2pix-1000-samples). This can impact the end results. For better results, we recommend longer training runs with a larger dataset. [Here](https://huggingface.co/datasets/timbrooks/instructpix2pix-clip-filtered) you can find a large dataset for InstructPix2Pix training.***
+
+## Running locally with PyTorch
+
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+Then cd into the example folder and run
+```bash
+pip install -r requirements.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or for a default accelerate configuration without answering questions about your environment
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell, e.g. a notebook
+
+```python
+from accelerate.utils import write_basic_config
+write_basic_config()
+```
+
+### Toy example
+
+As mentioned before, we'll use a [small toy dataset](https://huggingface.co/datasets/fusing/instructpix2pix-1000-samples) for training. The dataset
+is a smaller version of the [original dataset](https://huggingface.co/datasets/timbrooks/instructpix2pix-clip-filtered) used in the InstructPix2Pix paper.
+
+Configure environment variables such as the dataset identifier and the Stable Diffusion
+checkpoint:
+
+```bash
+export MODEL_NAME="runwayml/stable-diffusion-v1-5"
+export DATASET_ID="fusing/instructpix2pix-1000-samples"
+```
+
+Now, we can launch training:
+
+```bash
+accelerate launch --mixed_precision="fp16" train_instruct_pix2pix.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --dataset_name=$DATASET_ID \
+ --enable_xformers_memory_efficient_attention \
+ --resolution=256 --random_flip \
+ --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \
+ --max_train_steps=15000 \
+ --checkpointing_steps=5000 --checkpoints_total_limit=1 \
+ --learning_rate=5e-05 --max_grad_norm=1 --lr_warmup_steps=0 \
+ --conditioning_dropout_prob=0.05 \
+ --mixed_precision=fp16 \
+ --seed=42 \
+ --push_to_hub
+```
+
+Additionally, we support performing validation inference to monitor training progress
+with Weights and Biases. You can enable this feature with `report_to="wandb"`:
+
+```bash
+accelerate launch --mixed_precision="fp16" train_instruct_pix2pix.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --dataset_name=$DATASET_ID \
+ --enable_xformers_memory_efficient_attention \
+ --resolution=256 --random_flip \
+ --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \
+ --max_train_steps=15000 \
+ --checkpointing_steps=5000 --checkpoints_total_limit=1 \
+ --learning_rate=5e-05 --max_grad_norm=1 --lr_warmup_steps=0 \
+ --conditioning_dropout_prob=0.05 \
+ --mixed_precision=fp16 \
+ --val_image_url="https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png" \
+ --validation_prompt="make the mountains snowy" \
+ --seed=42 \
+ --report_to=wandb \
+ --push_to_hub
+```
+
+We recommend this type of validation as it can be useful for model debugging. Note that you need `wandb` installed to use this. You can install `wandb` by running `pip install wandb`.
+
+[Here](https://wandb.ai/sayakpaul/instruct-pix2pix/runs/ctr3kovq), you can find an example training run that includes some validation samples and the training hyperparameters.
+
+***Note: In the original paper, the authors observed that even when the model is trained with an image resolution of 256x256, it generalizes well to bigger resolutions such as 512x512. This is likely because of the larger dataset they used during training.***
+
+## Training with multiple GPUs
+
+`accelerate` allows for seamless multi-GPU training. Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch)
+for running distributed training with `accelerate`. Here is an example command:
+
+```bash
+accelerate launch --mixed_precision="fp16" --multi_gpu train_instruct_pix2pix.py \
+ --pretrained_model_name_or_path=runwayml/stable-diffusion-v1-5 \
+ --dataset_name=sayakpaul/instructpix2pix-1000-samples \
+ --use_ema \
+ --enable_xformers_memory_efficient_attention \
+ --resolution=512 --random_flip \
+ --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \
+ --max_train_steps=15000 \
+ --checkpointing_steps=5000 --checkpoints_total_limit=1 \
+ --learning_rate=5e-05 --lr_warmup_steps=0 \
+ --conditioning_dropout_prob=0.05 \
+ --mixed_precision=fp16 \
+ --seed=42 \
+ --push_to_hub
+```
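+
+If you only want to use a subset of the machine's GPUs, one common (script-agnostic) pattern is to combine `CUDA_VISIBLE_DEVICES` with accelerate's `--num_processes` flag, e.g.:
+
+```bash
+# Use only the first two GPUs; the remaining flags stay as in the command above.
+CUDA_VISIBLE_DEVICES=0,1 accelerate launch --num_processes=2 --multi_gpu --mixed_precision="fp16" \
+ train_instruct_pix2pix.py \
+ --pretrained_model_name_or_path=runwayml/stable-diffusion-v1-5 \
+ --dataset_name=sayakpaul/instructpix2pix-1000-samples \
+ --resolution=512 --train_batch_size=4 \
+ --max_train_steps=15000 --seed=42
+```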
+
+## Inference
+
+Once training is complete, we can perform inference:
+
+```python
+import PIL
+import requests
+import torch
+from diffusers import StableDiffusionInstructPix2PixPipeline
+
+model_id = "your_model_id" # <- replace this
+pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+generator = torch.Generator("cuda").manual_seed(0)
+
+url = "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/test_pix2pix_4.png"
+
+
+def download_image(url):
+ image = PIL.Image.open(requests.get(url, stream=True).raw)
+ image = PIL.ImageOps.exif_transpose(image)
+ image = image.convert("RGB")
+ return image
+
+image = download_image(url)
+prompt = "wipe out the lake"
+num_inference_steps = 20
+image_guidance_scale = 1.5
+guidance_scale = 10
+
+edited_image = pipe(prompt,
+ image=image,
+ num_inference_steps=num_inference_steps,
+ image_guidance_scale=image_guidance_scale,
+ guidance_scale=guidance_scale,
+ generator=generator,
+).images[0]
+edited_image.save("edited_image.png")
+```
+
+An example model repo obtained using this training script can be found
+here - [sayakpaul/instruct-pix2pix](https://huggingface.co/sayakpaul/instruct-pix2pix).
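+
+If you trained without `--push_to_hub`, you can also point `from_pretrained` at the local output directory instead of a Hub id (by default `instruct-pix2pix-model`, unless you changed `--output_dir`); a minimal sketch:
+
+```python
+import torch
+from diffusers import StableDiffusionInstructPix2PixPipeline
+
+# Load the pipeline that the training script saved at the end of training.
+pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
+    "instruct-pix2pix-model", torch_dtype=torch.float16
+).to("cuda")
+```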
+
+We encourage you to play with the following three parameters to control
+speed and quality during inference:
+
+* `num_inference_steps`
+* `image_guidance_scale`
+* `guidance_scale`
+
+Particularly, `image_guidance_scale` and `guidance_scale` can have a profound impact
+on the generated ("edited") image (see [here](https://twitter.com/RisingSayak/status/1628392199196151808?s=20) for an example).
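+
+To get a feel for their effect, here is a minimal sketch that sweeps `guidance_scale` while keeping everything else fixed (it reuses `pipe`, `image`, and `generator` from the inference snippet above; the output filenames are just illustrative):
+
+```python
+# Hypothetical sweep over guidance_scale; tweak the values (and image_guidance_scale) to taste.
+for gs in [3.0, 7.5, 10.0]:
+    edited = pipe(
+        prompt,
+        image=image,
+        num_inference_steps=num_inference_steps,
+        image_guidance_scale=image_guidance_scale,
+        guidance_scale=gs,
+        generator=generator,
+    ).images[0]
+    edited.save(f"edited_gs_{gs}.png")
+```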
+
+If you're looking for some interesting ways to use the InstructPix2Pix training methodology, we welcome you to check out this blog post: [Instruction-tuning Stable Diffusion with InstructPix2Pix](https://huggingface.co/blog/instruction-tuning-sd).
+
+## Stable Diffusion XL
+
+There's an equivalent `train_instruct_pix2pix_sdxl.py` script for [Stable Diffusion XL](https://huggingface.co/papers/2307.01952). Please refer to the docs [here](./README_sdxl.md) to learn more.
diff --git a/diffusers/examples/instruct_pix2pix/README_sdxl.md b/diffusers/examples/instruct_pix2pix/README_sdxl.md
new file mode 100644
index 0000000000000000000000000000000000000000..b8c2ffdc817526ca88a05f21117fff82ba31a9c0
--- /dev/null
+++ b/diffusers/examples/instruct_pix2pix/README_sdxl.md
@@ -0,0 +1,197 @@
+# InstructPix2Pix SDXL training example
+
+***This is based on the original InstructPix2Pix training example.***
+
+[Stable Diffusion XL](https://huggingface.co/papers/2307.01952) (or SDXL) is the latest image generation model that is tailored towards more photorealistic outputs with more detailed imagery and composition compared to previous SD models. It leverages a three times larger UNet backbone. The increase in model parameters is mainly due to more attention blocks and a larger cross-attention context, as SDXL uses a second text encoder.
+
+The `train_instruct_pix2pix_sdxl.py` script shows how to implement the training procedure and adapt it for Stable Diffusion XL.
+
+***Disclaimer: Even though `train_instruct_pix2pix_sdxl.py` implements the InstructPix2Pix
+training procedure while being faithful to the [original implementation](https://github.com/timothybrooks/instruct-pix2pix), we have only tested it on a [small-scale dataset](https://huggingface.co/datasets/fusing/instructpix2pix-1000-samples). This can impact the end results. For better results, we recommend longer training runs with a larger dataset. [Here](https://huggingface.co/datasets/timbrooks/instructpix2pix-clip-filtered) you can find a large dataset for InstructPix2Pix training.***
+
+## Running locally with PyTorch
+
+### Installing the dependencies
+
+Refer to the original InstructPix2Pix training example for installing the dependencies.
+
+You will also need to get access to SDXL by filling out the [form](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0).
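+
+Once access is granted, you will typically also need to authenticate locally so that `from_pretrained` can download the gated weights; one common way (assuming you already have a Hugging Face access token) is:
+
+```bash
+huggingface-cli login
+```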
+
+### Toy example
+
+As mentioned before, we'll use a [small toy dataset](https://huggingface.co/datasets/fusing/instructpix2pix-1000-samples) for training. The dataset
+is a smaller version of the [original dataset](https://huggingface.co/datasets/timbrooks/instructpix2pix-clip-filtered) used in the InstructPix2Pix paper.
+
+Configure environment variables such as the dataset identifier and the Stable Diffusion
+checkpoint:
+
+```bash
+export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
+export DATASET_ID="fusing/instructpix2pix-1000-samples"
+```
+
+Now, we can launch training:
+
+```bash
+accelerate launch train_instruct_pix2pix_sdxl.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --dataset_name=$DATASET_ID \
+ --enable_xformers_memory_efficient_attention \
+ --resolution=256 --random_flip \
+ --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \
+ --max_train_steps=15000 \
+ --checkpointing_steps=5000 --checkpoints_total_limit=1 \
+ --learning_rate=5e-05 --max_grad_norm=1 --lr_warmup_steps=0 \
+ --conditioning_dropout_prob=0.05 \
+ --seed=42 \
+ --push_to_hub
+```
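+
+The SDXL script also exposes `--pretrained_vae_model_name_or_path` and `--vae_precision` (see the argument parser in `train_instruct_pix2pix_sdxl.py`), which can help when the stock SDXL VAE produces NaNs due to large activation values. A hedged variant of the command above, with a placeholder VAE id that you would substitute yourself:
+
+```bash
+accelerate launch train_instruct_pix2pix_sdxl.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --pretrained_vae_model_name_or_path=<path-or-id-of-a-more-stable-vae> \
+ --vae_precision="fp16" \
+ --dataset_name=$DATASET_ID \
+ --resolution=256 --random_flip \
+ --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \
+ --max_train_steps=15000 \
+ --learning_rate=5e-05 --max_grad_norm=1 --lr_warmup_steps=0 \
+ --conditioning_dropout_prob=0.05 \
+ --seed=42
+```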
+
+Additionally, we support performing validation inference to monitor training progress
+with Weights and Biases. You can enable this feature with `report_to="wandb"`:
+
+```bash
+accelerate launch train_instruct_pix2pix_sdxl.py \
+ --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \
+ --dataset_name=$DATASET_ID \
+ --use_ema \
+ --enable_xformers_memory_efficient_attention \
+ --resolution=512 --random_flip \
+ --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \
+ --max_train_steps=15000 \
+ --checkpointing_steps=5000 --checkpoints_total_limit=1 \
+ --learning_rate=5e-05 --lr_warmup_steps=0 \
+ --conditioning_dropout_prob=0.05 \
+ --seed=42 \
+ --val_image_url_or_path="https://datasets-server.huggingface.co/assets/fusing/instructpix2pix-1000-samples/--/fusing--instructpix2pix-1000-samples/train/23/input_image/image.jpg" \
+ --validation_prompt="make it in japan" \
+ --report_to=wandb \
+ --push_to_hub
+```
+
+We recommend this type of validation as it can be useful for model debugging. Note that you need `wandb` installed to use this. You can install `wandb` by running `pip install wandb`.
+
+[Here](https://wandb.ai/sayakpaul/instruct-pix2pix/runs/ctr3kovq), you can find an example training run that includes some validation samples and the training hyperparameters.
+
+***Note: In the original paper, the authors observed that even when the model is trained with an image resolution of 256x256, it generalizes well to bigger resolutions such as 512x512. This is likely because of the larger dataset they used during training.***
+
+## Training with multiple GPUs
+
+`accelerate` allows for seamless multi-GPU training. Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch)
+for running distributed training with `accelerate`. Here is an example command:
+
+```bash
+accelerate launch --mixed_precision="fp16" --multi_gpu train_instruct_pix2pix_sdxl.py \
+ --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \
+ --dataset_name=$DATASET_ID \
+ --use_ema \
+ --enable_xformers_memory_efficient_attention \
+ --resolution=512 --random_flip \
+ --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \
+ --max_train_steps=15000 \
+ --checkpointing_steps=5000 --checkpoints_total_limit=1 \
+ --learning_rate=5e-05 --lr_warmup_steps=0 \
+ --conditioning_dropout_prob=0.05 \
+ --seed=42 \
+ --val_image_url_or_path="https://datasets-server.huggingface.co/assets/fusing/instructpix2pix-1000-samples/--/fusing--instructpix2pix-1000-samples/train/23/input_image/image.jpg" \
+ --validation_prompt="make it in japan" \
+ --report_to=wandb \
+ --push_to_hub
+```
+
+## Inference
+
+Once training is complete, we can perform inference:
+
+```python
+import PIL
+import requests
+import torch
+from diffusers import StableDiffusionXLInstructPix2PixPipeline
+
+model_id = "your_model_id" # <- replace this
+pipe = StableDiffusionXLInstructPix2PixPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+generator = torch.Generator("cuda").manual_seed(0)
+
+url = "https://datasets-server.huggingface.co/assets/fusing/instructpix2pix-1000-samples/--/fusing--instructpix2pix-1000-samples/train/23/input_image/image.jpg"
+
+
+def download_image(url):
+ image = PIL.Image.open(requests.get(url, stream=True).raw)
+ image = PIL.ImageOps.exif_transpose(image)
+ image = image.convert("RGB")
+ return image
+
+image = download_image(url)
+prompt = "make it Japan"
+num_inference_steps = 20
+image_guidance_scale = 1.5
+guidance_scale = 10
+
+edited_image = pipe(prompt,
+ image=image,
+ num_inference_steps=num_inference_steps,
+ image_guidance_scale=image_guidance_scale,
+ guidance_scale=guidance_scale,
+ generator=generator,
+).images[0]
+edited_image.save("edited_image.png")
+```
+
+We encourage you to play with the following three parameters to control
+speed and quality during inference:
+
+* `num_inference_steps`
+* `image_guidance_scale`
+* `guidance_scale`
+
+Particularly, `image_guidance_scale` and `guidance_scale` can have a profound impact
+on the generated ("edited") image (see [here](https://twitter.com/RisingSayak/status/1628392199196151808?s=20) for an example).
+
+If you're looking for some interesting ways to use the InstructPix2Pix training methodology, we welcome you to check out this blog post: [Instruction-tuning Stable Diffusion with InstructPix2Pix](https://huggingface.co/blog/instruction-tuning-sd).
+
+## Comparison between SD and SDXL
+
+We aim to understand the differences resulting from the use of SD-1.5 and SDXL-0.9 as pretrained models. To achieve this, we trained on the [small toy dataset](https://huggingface.co/datasets/fusing/instructpix2pix-1000-samples) using both of these pretrained models. The training script is as follows:
+
+```bash
+# Choose one of the two pretrained models to compare:
+export MODEL_NAME="runwayml/stable-diffusion-v1-5"
+# export MODEL_NAME="stabilityai/stable-diffusion-xl-base-0.9"
+export DATASET_ID="fusing/instructpix2pix-1000-samples"
+
+accelerate launch train_instruct_pix2pix.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --dataset_name=$DATASET_ID \
+ --use_ema \
+ --enable_xformers_memory_efficient_attention \
+ --resolution=512 --random_flip \
+ --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \
+ --max_train_steps=15000 \
+ --checkpointing_steps=5000 --checkpoints_total_limit=1 \
+ --learning_rate=5e-05 --lr_warmup_steps=0 \
+ --conditioning_dropout_prob=0.05 \
+ --seed=42 \
+ --val_image_url="https://datasets-server.huggingface.co/assets/fusing/instructpix2pix-1000-samples/--/fusing--instructpix2pix-1000-samples/train/23/input_image/image.jpg" \
+ --validation_prompt="make it in Japan" \
+ --report_to=wandb \
+ --push_to_hub
+```
+
+We discovered that, compared to training with SD-1.5 as the pretrained model, SDXL-0.9 results in a lower training loss (SD-1.5 reaches 0.0599, while SDXL reaches 0.0254). Moreover, from a visual perspective, the results obtained using SDXL show fewer artifacts and richer detail. Notably, SDXL starts to preserve the structure of the original image earlier on.
+
+For an intuitive visual comparison, we observed, at each training step, what kind of result could be achieved when editing the same input image from the dataset with "make it in Japan" as the prompt. It can be seen that SDXL starts preserving the details of the original image earlier, resulting in higher-fidelity outcomes sooner.
+
+* SD-1.5: https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sd_ip2p_training_val_img_progress.gif
+
+
diff --git a/diffusers/examples/instruct_pix2pix/requirements.txt b/diffusers/examples/instruct_pix2pix/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e18cc9e4215eaa760c8d29c946396dba9ff2c9ac
--- /dev/null
+++ b/diffusers/examples/instruct_pix2pix/requirements.txt
@@ -0,0 +1,6 @@
+accelerate>=0.16.0
+torchvision
+transformers>=4.25.1
+datasets
+ftfy
+tensorboard
\ No newline at end of file
diff --git a/diffusers/examples/instruct_pix2pix/train_instruct_pix2pix.py b/diffusers/examples/instruct_pix2pix/train_instruct_pix2pix.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9b1c9cc5b3b42cfc37717be1398f2e83731d23c
--- /dev/null
+++ b/diffusers/examples/instruct_pix2pix/train_instruct_pix2pix.py
@@ -0,0 +1,1009 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Script to fine-tune Stable Diffusion for InstructPix2Pix."""
+
+import argparse
+import logging
+import math
+import os
+import shutil
+from pathlib import Path
+
+import accelerate
+import datasets
+import numpy as np
+import PIL
+import requests
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+
+import diffusers
+from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInstructPix2PixPipeline, UNet2DConditionModel
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel
+from diffusers.utils import check_min_version, deprecate, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.24.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+DATASET_NAME_MAPPING = {
+ "fusing/instructpix2pix-1000-samples": ("input_image", "edit_prompt", "edited_image"),
+}
+WANDB_TABLE_COL_NAMES = ["original_image", "edited_image", "edit_prompt"]
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script for InstructPix2Pix.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--original_image_column",
+ type=str,
+ default="input_image",
+ help="The column of the dataset containing the original image on which edits were made.",
+ )
+ parser.add_argument(
+ "--edited_image_column",
+ type=str,
+ default="edited_image",
+ help="The column of the dataset containing the edited image.",
+ )
+ parser.add_argument(
+ "--edit_prompt_column",
+ type=str,
+ default="edit_prompt",
+ help="The column of the dataset containing the edit instruction.",
+ )
+ parser.add_argument(
+ "--val_image_url",
+ type=str,
+ default=None,
+ help="URL to the original image that you would like to edit (used during inference for debugging purposes).",
+ )
+ parser.add_argument(
+ "--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference."
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=1,
+ help=(
+ "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="instruct-pix2pix-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=256,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--conditioning_dropout_prob",
+ type=float,
+ default=None,
+ help="Conditioning dropout probability. Drops out the conditionings (image and edit prompt) used in training InstructPix2Pix. See section 3.2.1 in the paper: https://arxiv.org/abs/2211.09800.",
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+ parser.add_argument(
+ "--non_ema_revision",
+ type=str,
+ default=None,
+ required=False,
+ help=(
+ "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
+ " remote repository specified with --pretrained_model_name_or_path."
+ ),
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10 and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ # default to using the same revision for the non-ema model if not specified
+ if args.non_ema_revision is None:
+ args.non_ema_revision = args.revision
+
+ return args
+
+
+def convert_to_np(image, resolution):
+ image = image.convert("RGB").resize((resolution, resolution))
+ return np.array(image).transpose(2, 0, 1)
+
+
+def download_image(url):
+ image = PIL.Image.open(requests.get(url, stream=True).raw)
+ image = PIL.ImageOps.exif_transpose(image)
+ image = image.convert("RGB")
+ return image
+
+
+def main():
+ args = parse_args()
+
+ if args.non_ema_revision is not None:
+ deprecate(
+ "non_ema_revision!=None",
+ "0.15.0",
+ message=(
+ "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
+ " use `--variant=non_ema` instead."
+ ),
+ )
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+ import wandb
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load scheduler, tokenizer and models.
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ tokenizer = CLIPTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
+ )
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
+ )
+
+ # InstructPix2Pix uses an additional image for conditioning. To accommodate that,
+ # it uses 8 channels (instead of 4) in the first (conv) layer of the UNet. This UNet is
+ # then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized
+ # from the pre-trained checkpoints. For the extra channels added to the first layer, they are
+ # initialized to zero.
+ logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.")
+ in_channels = 8
+ out_channels = unet.conv_in.out_channels
+ unet.register_to_config(in_channels=in_channels)
+
+ with torch.no_grad():
+ new_conv_in = nn.Conv2d(
+ in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding
+ )
+ new_conv_in.weight.zero_()
+ new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
+ unet.conv_in = new_conv_in
+
+ # Freeze vae and text_encoder
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+
+ # Create EMA for the unet.
+ if args.use_ema:
+ ema_unet = EMAModel(unet.parameters(), model_cls=UNet2DConditionModel, model_config=unet.config)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warn(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ if args.use_ema:
+ ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
+
+ for i, model in enumerate(models):
+ model.save_pretrained(os.path.join(output_dir, "unet"))
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ if args.use_ema:
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
+ ema_unet.load_state_dict(load_model.state_dict())
+ ema_unet.to(accelerator.device)
+ del load_model
+
+ for i in range(len(models)):
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Initialize the optimizer
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
+ )
+
+ optimizer_cls = bnb.optim.AdamW8bit
+ else:
+ optimizer_cls = torch.optim.AdamW
+
+ optimizer = optimizer_cls(
+ unet.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/main/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
+ if args.original_image_column is None:
+ original_image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
+ else:
+ original_image_column = args.original_image_column
+ if original_image_column not in column_names:
+ raise ValueError(
+ f"--original_image_column' value '{args.original_image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.edit_prompt_column is None:
+ edit_prompt_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
+ else:
+ edit_prompt_column = args.edit_prompt_column
+ if edit_prompt_column not in column_names:
+ raise ValueError(
+ f"--edit_prompt_column' value '{args.edit_prompt_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.edited_image_column is None:
+ edited_image_column = dataset_columns[2] if dataset_columns is not None else column_names[2]
+ else:
+ edited_image_column = args.edited_image_column
+ if edited_image_column not in column_names:
+ raise ValueError(
+ f"--edited_image_column' value '{args.edited_image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+
+ # Preprocessing the datasets.
+ # We need to tokenize input captions and transform the images.
+ def tokenize_captions(captions):
+ inputs = tokenizer(
+ captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ return inputs.input_ids
+
+ # Preprocessing the datasets.
+ train_transforms = transforms.Compose(
+ [
+ transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
+ transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
+ ]
+ )
+
+ def preprocess_images(examples):
+ original_images = np.concatenate(
+ [convert_to_np(image, args.resolution) for image in examples[original_image_column]]
+ )
+ edited_images = np.concatenate(
+ [convert_to_np(image, args.resolution) for image in examples[edited_image_column]]
+ )
+ # We need to ensure that the original and the edited images undergo the same
+ # augmentation transforms.
+ images = np.concatenate([original_images, edited_images])
+ images = torch.tensor(images)
+ images = 2 * (images / 255) - 1
+ return train_transforms(images)
+
+ def preprocess_train(examples):
+ # Preprocess images.
+ preprocessed_images = preprocess_images(examples)
+ # Since the original and edited images were concatenated before
+ # applying the transformations, we need to separate them and reshape
+ # them accordingly.
+ original_images, edited_images = preprocessed_images.chunk(2)
+ original_images = original_images.reshape(-1, 3, args.resolution, args.resolution)
+ edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution)
+
+ # Collate the preprocessed images into the `examples`.
+ examples["original_pixel_values"] = original_images
+ examples["edited_pixel_values"] = edited_images
+
+ # Preprocess the captions.
+ captions = list(examples[edit_prompt_column])
+ examples["input_ids"] = tokenize_captions(captions)
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ def collate_fn(examples):
+ original_pixel_values = torch.stack([example["original_pixel_values"] for example in examples])
+ original_pixel_values = original_pixel_values.to(memory_format=torch.contiguous_format).float()
+ edited_pixel_values = torch.stack([example["edited_pixel_values"] for example in examples])
+ edited_pixel_values = edited_pixel_values.to(memory_format=torch.contiguous_format).float()
+ input_ids = torch.stack([example["input_ids"] for example in examples])
+ return {
+ "original_pixel_values": original_pixel_values,
+ "edited_pixel_values": edited_pixel_values,
+ "input_ids": input_ids,
+ }
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ )
+
+ # Prepare everything with our `accelerator`.
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ if args.use_ema:
+ ema_unet.to(accelerator.device)
+
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move text_encode and vae to gpu and cast to weight_dtype
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("instruct-pix2pix", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ resume_global_step = global_step * args.gradient_accumulation_steps
+ first_epoch = global_step // num_update_steps_per_epoch
+ resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
+ progress_bar.set_description("Steps")
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ unet.train()
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ # Skip steps until we reach the resumed step
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
+ if step % args.gradient_accumulation_steps == 0:
+ progress_bar.update(1)
+ continue
+
+ with accelerator.accumulate(unet):
+ # We want to learn the denoising process w.r.t the edited images which
+ # are conditioned on the original image (which was edited) and the edit instruction.
+ # So, first, convert images to latent space.
+ latents = vae.encode(batch["edited_pixel_values"].to(weight_dtype)).latent_dist.sample()
+ latents = latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the text embedding for conditioning.
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+ # Get the additional image embedding for conditioning.
+ # Instead of getting a diagonal Gaussian here, we simply take the mode.
+ original_image_embeds = vae.encode(batch["original_pixel_values"].to(weight_dtype)).latent_dist.mode()
+
+ # Conditioning dropout to support classifier-free guidance during inference. For more details
+ # check out the section 3.2.1 of the original paper https://arxiv.org/abs/2211.09800.
+ if args.conditioning_dropout_prob is not None:
+ random_p = torch.rand(bsz, device=latents.device, generator=generator)
+ # Sample masks for the edit prompts.
+ prompt_mask = random_p < 2 * args.conditioning_dropout_prob
+ prompt_mask = prompt_mask.reshape(bsz, 1, 1)
+ # Final text conditioning.
+ null_conditioning = text_encoder(tokenize_captions([""]).to(accelerator.device))[0]
+ encoder_hidden_states = torch.where(prompt_mask, null_conditioning, encoder_hidden_states)
+
+ # Sample masks for the original images.
+ image_mask_dtype = original_image_embeds.dtype
+ image_mask = 1 - (
+ (random_p >= args.conditioning_dropout_prob).to(image_mask_dtype)
+ * (random_p < 3 * args.conditioning_dropout_prob).to(image_mask_dtype)
+ )
+ image_mask = image_mask.reshape(bsz, 1, 1, 1)
+ # Final image conditioning.
+ original_image_embeds = image_mask * original_image_embeds
+
+ # Concatenate the `original_image_embeds` with the `noisy_latents`.
+ concatenated_noisy_latents = torch.cat([noisy_latents, original_image_embeds], dim=1)
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ # Predict the noise residual and compute loss
+ model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states).sample
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ if args.use_ema:
+ ema_unet.step(unet.parameters())
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if (
+ (args.val_image_url is not None)
+ and (args.validation_prompt is not None)
+ and (epoch % args.validation_epochs == 0)
+ ):
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ # create pipeline
+ if args.use_ema:
+ # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
+ ema_unet.store(unet.parameters())
+ ema_unet.copy_to(unet.parameters())
+ # The models need unwrapping for compatibility in distributed training mode.
+ pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ unet=accelerator.unwrap_model(unet),
+ text_encoder=accelerator.unwrap_model(text_encoder),
+ vae=accelerator.unwrap_model(vae),
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ original_image = download_image(args.val_image_url)
+ edited_images = []
+ with torch.autocast(
+ str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"
+ ):
+ for _ in range(args.num_validation_images):
+ edited_images.append(
+ pipeline(
+ args.validation_prompt,
+ image=original_image,
+ num_inference_steps=20,
+ image_guidance_scale=1.5,
+ guidance_scale=7,
+ generator=generator,
+ ).images[0]
+ )
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "wandb":
+ wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
+ for edited_image in edited_images:
+ wandb_table.add_data(
+ wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
+ )
+ tracker.log({"validation": wandb_table})
+ if args.use_ema:
+ # Switch back to the original UNet parameters.
+ ema_unet.restore(unet.parameters())
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = accelerator.unwrap_model(unet)
+ if args.use_ema:
+ ema_unet.copy_to(unet.parameters())
+
+ pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ text_encoder=accelerator.unwrap_model(text_encoder),
+ vae=accelerator.unwrap_model(vae),
+ unet=unet,
+ revision=args.revision,
+ )
+ pipeline.save_pretrained(args.output_dir)
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ if args.validation_prompt is not None:
+ edited_images = []
+ pipeline = pipeline.to(accelerator.device)
+ with torch.autocast(str(accelerator.device).replace(":0", "")):
+ for _ in range(args.num_validation_images):
+ edited_images.append(
+ pipeline(
+ args.validation_prompt,
+ image=original_image,
+ num_inference_steps=20,
+ image_guidance_scale=1.5,
+ guidance_scale=7,
+ generator=generator,
+ ).images[0]
+ )
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "wandb":
+ wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
+ for edited_image in edited_images:
+ wandb_table.add_data(
+ wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
+ )
+ tracker.log({"test": wandb_table})
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py b/diffusers/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b503cb29275d0d4c28ced48f88f08a37f616f2c
--- /dev/null
+++ b/diffusers/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
@@ -0,0 +1,1219 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2023 Harutatsu Akiyama and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import math
+import os
+import shutil
+import warnings
+from pathlib import Path
+from urllib.parse import urlparse
+
+import accelerate
+import datasets
+import numpy as np
+import PIL
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from PIL import Image
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
+from diffusers.optimization import get_scheduler
+from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_instruct_pix2pix import (
+ StableDiffusionXLInstructPix2PixPipeline,
+)
+from diffusers.training_utils import EMAModel
+from diffusers.utils import check_min_version, deprecate, is_wandb_available, load_image
+from diffusers.utils.import_utils import is_xformers_available
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.24.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+DATASET_NAME_MAPPING = {
+ "fusing/instructpix2pix-1000-samples": ("file_name", "edited_image", "edit_prompt"),
+}
+WANDB_TABLE_COL_NAMES = ["file_name", "edited_image", "edit_prompt"]
+TORCH_DTYPE_MAPPING = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Script to train Stable Diffusion XL for InstructPix2Pix.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--vae_precision",
+ type=str,
+ choices=["fp32", "fp16", "bf16"],
+ default="fp32",
+ help=(
+ "The vanilla SDXL 1.0 VAE can cause NaNs due to large activation values. Some custom models might already have a solution"
+ " to this problem, and this flag allows you to use mixed precision to stabilize training."
+ ),
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--original_image_column",
+ type=str,
+ default="input_image",
+ help="The column of the dataset containing the original image on which edits were made.",
+ )
+ parser.add_argument(
+ "--edited_image_column",
+ type=str,
+ default="edited_image",
+ help="The column of the dataset containing the edited image.",
+ )
+ parser.add_argument(
+ "--edit_prompt_column",
+ type=str,
+ default="edit_prompt",
+ help="The column of the dataset containing the edit instruction.",
+ )
+ parser.add_argument(
+ "--val_image_url_or_path",
+ type=str,
+ default=None,
+ help="URL to the original image that you would like to edit (used during inference for debugging purposes).",
+ )
+ parser.add_argument(
+ "--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference."
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=100,
+ help=(
+ "Run fine-tuning validation every X steps. The validation process consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="instruct-pix2pix-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=256,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this resolution."
+ ),
+ )
+ parser.add_argument(
+ "--crops_coords_top_left_h",
+ type=int,
+ default=0,
+ help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
+ )
+ parser.add_argument(
+ "--crops_coords_top_left_w",
+ type=int,
+ default=0,
+ help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--conditioning_dropout_prob",
+ type=float,
+ default=None,
+ help="Conditioning dropout probability. Drops out the conditionings (image and edit prompt) used in training InstructPix2Pix. See section 3.2.1 in the paper: https://arxiv.org/abs/2211.09800.",
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+ parser.add_argument(
+ "--non_ema_revision",
+ type=str,
+ default=None,
+ required=False,
+ help=(
+ "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
+ " remote repository specified with --pretrained_model_name_or_path."
+ ),
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ # default to using the same revision for the non-ema model if not specified
+ if args.non_ema_revision is None:
+ args.non_ema_revision = args.revision
+
+ return args
+
+
+def convert_to_np(image, resolution):
+ if isinstance(image, str):
+ image = PIL.Image.open(image)
+ image = image.convert("RGB").resize((resolution, resolution))
+ return np.array(image).transpose(2, 0, 1)
+
+
+def main():
+ args = parse_args()
+
+ if args.non_ema_revision is not None:
+ deprecate(
+ "non_ema_revision!=None",
+ "0.15.0",
+ message=(
+ "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
+ " use `--variant=non_ema` instead."
+ ),
+ )
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+ import wandb
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ vae_path = (
+ args.pretrained_model_name_or_path
+ if args.pretrained_vae_model_name_or_path is None
+ else args.pretrained_vae_model_name_or_path
+ )
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ )
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+ )
+
+ # InstructPix2Pix uses an additional image for conditioning. To accommodate that,
+ # it uses 8 channels (instead of 4) in the first (conv) layer of the UNet. This UNet is
+ # then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized
+ # from the pre-trained checkpoints. For the extra channels added to the first layer, they are
+ # initialized to zero.
+ logger.info("Initializing the XL InstructPix2Pix UNet from the pretrained UNet.")
+ in_channels = 8
+ out_channels = unet.conv_in.out_channels
+ unet.register_to_config(in_channels=in_channels)
+
+ with torch.no_grad():
+ new_conv_in = nn.Conv2d(
+ in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding
+ )
+ new_conv_in.weight.zero_()
+ new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
+ unet.conv_in = new_conv_in
+
+ # Create EMA for the unet.
+ if args.use_ema:
+ ema_unet = EMAModel(unet.parameters(), model_cls=UNet2DConditionModel, model_config=unet.config)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warn(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ if args.use_ema:
+ ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
+
+ for i, model in enumerate(models):
+ model.save_pretrained(os.path.join(output_dir, "unet"))
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ if args.use_ema:
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
+ ema_unet.load_state_dict(load_model.state_dict())
+ ema_unet.to(accelerator.device)
+ del load_model
+
+ for i in range(len(models)):
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Initialize the optimizer
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
+ )
+
+ optimizer_cls = bnb.optim.AdamW8bit
+ else:
+ optimizer_cls = torch.optim.AdamW
+
+ optimizer = optimizer_cls(
+ unet.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/main/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
+ if args.original_image_column is None:
+ original_image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
+ else:
+ original_image_column = args.original_image_column
+ if original_image_column not in column_names:
+ raise ValueError(
+ f"--original_image_column' value '{args.original_image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.edit_prompt_column is None:
+ edit_prompt_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
+ else:
+ edit_prompt_column = args.edit_prompt_column
+ if edit_prompt_column not in column_names:
+ raise ValueError(
+ f"--edit_prompt_column' value '{args.edit_prompt_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.edited_image_column is None:
+ edited_image_column = dataset_columns[2] if dataset_columns is not None else column_names[2]
+ else:
+ edited_image_column = args.edited_image_column
+ if edited_image_column not in column_names:
+ raise ValueError(
+ f"--edited_image_column' value '{args.edited_image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ warnings.warn(f"weight_dtype {weight_dtype} may cause nan during vae encoding", UserWarning)
+
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+ warnings.warn(f"weight_dtype {weight_dtype} may cause nan during vae encoding", UserWarning)
+
+ # Preprocessing the datasets.
+ # We need to tokenize input captions and transform the images.
+ def tokenize_captions(captions, tokenizer):
+ inputs = tokenizer(
+ captions,
+ max_length=tokenizer.model_max_length,
+ padding="max_length",
+ truncation=True,
+ return_tensors="pt",
+ )
+ return inputs.input_ids
+
+ # Preprocessing the datasets.
+ train_transforms = transforms.Compose(
+ [
+ transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
+ transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
+ ]
+ )
+
+ def preprocess_images(examples):
+ original_images = np.concatenate(
+ [convert_to_np(image, args.resolution) for image in examples[original_image_column]]
+ )
+ edited_images = np.concatenate(
+ [convert_to_np(image, args.resolution) for image in examples[edited_image_column]]
+ )
+ # We need to ensure that the original and the edited images undergo the same
+ # augmentation transforms.
+ images = np.concatenate([original_images, edited_images])
+ images = torch.tensor(images)
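+ # Scale pixel values from [0, 255] to [-1, 1], the range expected by the VAE.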
+ images = 2 * (images / 255) - 1
+ return train_transforms(images)
+
+ # Load scheduler, tokenizer and models.
+ tokenizer_1 = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
+ )
+ tokenizer_2 = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
+ )
+ text_encoder_cls_1 = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
+ text_encoder_cls_2 = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+ )
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder_1 = text_encoder_cls_1.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+ text_encoder_2 = text_encoder_cls_2.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
+ )
+
+ # We ALWAYS pre-compute the additional condition embeddings needed for SDXL
+ # UNet as the model is already big and it uses two text encoders.
+ text_encoder_1.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_2.to(accelerator.device, dtype=weight_dtype)
+ tokenizers = [tokenizer_1, tokenizer_2]
+ text_encoders = [text_encoder_1, text_encoder_2]
+
+ # Freeze vae and text_encoders
+ vae.requires_grad_(False)
+ text_encoder_1.requires_grad_(False)
+ text_encoder_2.requires_grad_(False)
+
+ # Set UNet to trainable.
+ unet.train()
+
+ # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
+ def encode_prompt(text_encoders, tokenizers, prompt):
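+ # Each prompt is run through both SDXL tokenizer/text-encoder pairs; the penultimate hidden
+ # states are concatenated, and the pooled output of the second encoder is kept separately.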
+ prompt_embeds_list = []
+
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = text_encoder(
+ text_input_ids.to(text_encoder.device),
+ output_hidden_states=True,
+ )
+
+ # We are only interested in the pooled output of the final text encoder.
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
+ return prompt_embeds, pooled_prompt_embeds
+
+ # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
+ def encode_prompts(text_encoders, tokenizers, prompts):
+ prompt_embeds_all = []
+ pooled_prompt_embeds_all = []
+
+ for prompt in prompts:
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)
+ prompt_embeds_all.append(prompt_embeds)
+ pooled_prompt_embeds_all.append(pooled_prompt_embeds)
+
+ return torch.stack(prompt_embeds_all), torch.stack(pooled_prompt_embeds_all)
+
+ # Adapted from examples.dreambooth.train_dreambooth_lora_sdxl
+ # Here, we compute not just the text embeddings but also the additional embeddings
+ # needed for the SD XL UNet to operate.
+ def compute_embeddings_for_prompts(prompts, text_encoders, tokenizers):
+ with torch.no_grad():
+ prompt_embeds_all, pooled_prompt_embeds_all = encode_prompts(text_encoders, tokenizers, prompts)
+ add_text_embeds_all = pooled_prompt_embeds_all
+
+ prompt_embeds_all = prompt_embeds_all.to(accelerator.device)
+ add_text_embeds_all = add_text_embeds_all.to(accelerator.device)
+ return prompt_embeds_all, add_text_embeds_all
+
+ # Get null conditioning
+ def compute_null_conditioning():
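+ # Embed the empty prompt "" with both text encoders; this replaces the text conditioning
+ # when it is dropped to enable classifier-free guidance.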
+ null_conditioning_list = []
+ for a_tokenizer, a_text_encoder in zip(tokenizers, text_encoders):
+ null_conditioning_list.append(
+ a_text_encoder(
+ tokenize_captions([""], tokenizer=a_tokenizer).to(accelerator.device),
+ output_hidden_states=True,
+ ).hidden_states[-2]
+ )
+ return torch.concat(null_conditioning_list, dim=-1)
+
+ null_conditioning = compute_null_conditioning()
+
+ def compute_time_ids():
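+ # SDXL micro-conditioning: the original size, crop top-left corner and target size are packed
+ # into `add_time_ids` and fed to the UNet through `added_cond_kwargs`.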
+ crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)
+ original_size = target_size = (args.resolution, args.resolution)
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_time_ids = torch.tensor([add_time_ids], dtype=weight_dtype)
+ return add_time_ids.to(accelerator.device).repeat(args.train_batch_size, 1)
+
+ add_time_ids = compute_time_ids()
+
+ def preprocess_train(examples):
+ # Preprocess images.
+ preprocessed_images = preprocess_images(examples)
+ # Since the original and edited images were concatenated before
+ # applying the transformations, we need to separate them and reshape
+ # them accordingly.
+ original_images, edited_images = preprocessed_images.chunk(2)
+ original_images = original_images.reshape(-1, 3, args.resolution, args.resolution)
+ edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution)
+
+ # Collate the preprocessed images into the `examples`.
+ examples["original_pixel_values"] = original_images
+ examples["edited_pixel_values"] = edited_images
+
+ # Preprocess the captions.
+ captions = list(examples[edit_prompt_column])
+ prompt_embeds_all, add_text_embeds_all = compute_embeddings_for_prompts(captions, text_encoders, tokenizers)
+ examples["prompt_embeds"] = prompt_embeds_all
+ examples["add_text_embeds"] = add_text_embeds_all
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ def collate_fn(examples):
+ original_pixel_values = torch.stack([example["original_pixel_values"] for example in examples])
+ original_pixel_values = original_pixel_values.to(memory_format=torch.contiguous_format).float()
+ edited_pixel_values = torch.stack([example["edited_pixel_values"] for example in examples])
+ edited_pixel_values = edited_pixel_values.to(memory_format=torch.contiguous_format).float()
+ prompt_embeds = torch.concat([example["prompt_embeds"] for example in examples], dim=0)
+ add_text_embeds = torch.concat([example["add_text_embeds"] for example in examples], dim=0)
+ return {
+ "original_pixel_values": original_pixel_values,
+ "edited_pixel_values": edited_pixel_values,
+ "prompt_embeds": prompt_embeds,
+ "add_text_embeds": add_text_embeds,
+ }
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+ )
+
+ # Prepare everything with our `accelerator`.
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ if args.use_ema:
+ ema_unet.to(accelerator.device)
+
+ # Move vae, unet and text_encoder to device and cast to weight_dtype
+ # The VAE is in float32 to avoid NaN losses.
+ if args.pretrained_vae_model_name_or_path is not None:
+ vae.to(accelerator.device, dtype=weight_dtype)
+ else:
+ vae.to(accelerator.device, dtype=TORCH_DTYPE_MAPPING[args.vae_precision])
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("instruct-pix2pix-xl", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ # We want to learn the denoising process w.r.t the edited images which
+ # are conditioned on the original image (which was edited) and the edit instruction.
+ # So, first, convert images to latent space.
+ if args.pretrained_vae_model_name_or_path is not None:
+ edited_pixel_values = batch["edited_pixel_values"].to(dtype=weight_dtype)
+ else:
+ edited_pixel_values = batch["edited_pixel_values"]
+ latents = vae.encode(edited_pixel_values).latent_dist.sample()
+ latents = latents * vae.config.scaling_factor
+ if args.pretrained_vae_model_name_or_path is None:
+ latents = latents.to(weight_dtype)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # SDXL additional inputs
+ encoder_hidden_states = batch["prompt_embeds"]
+ add_text_embeds = batch["add_text_embeds"]
+
+ # Get the additional image embedding for conditioning.
+ # Instead of getting a diagonal Gaussian here, we simply take the mode.
+ if args.pretrained_vae_model_name_or_path is not None:
+ original_pixel_values = batch["original_pixel_values"].to(dtype=weight_dtype)
+ else:
+ original_pixel_values = batch["original_pixel_values"]
+ original_image_embeds = vae.encode(original_pixel_values).latent_dist.sample()
+ if args.pretrained_vae_model_name_or_path is None:
+ original_image_embeds = original_image_embeds.to(weight_dtype)
+
+ # Conditioning dropout to support classifier-free guidance during inference. For more details
+ # check out the section 3.2.1 of the original paper https://arxiv.org/abs/2211.09800.
+ if args.conditioning_dropout_prob is not None:
+ random_p = torch.rand(bsz, device=latents.device, generator=generator)
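+ # With dropout probability p: the text is dropped for random_p in [0, 2p) and the image for
+ # random_p in [p, 3p), so both are dropped together on [p, 2p) and neither on [3p, 1).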
+ # Sample masks for the edit prompts.
+ prompt_mask = random_p < 2 * args.conditioning_dropout_prob
+ prompt_mask = prompt_mask.reshape(bsz, 1, 1)
+ # Final text conditioning.
+ encoder_hidden_states = torch.where(prompt_mask, null_conditioning, encoder_hidden_states)
+
+ # Sample masks for the original images.
+ image_mask_dtype = original_image_embeds.dtype
+ image_mask = 1 - (
+ (random_p >= args.conditioning_dropout_prob).to(image_mask_dtype)
+ * (random_p < 3 * args.conditioning_dropout_prob).to(image_mask_dtype)
+ )
+ image_mask = image_mask.reshape(bsz, 1, 1, 1)
+ # Final image conditioning.
+ original_image_embeds = image_mask * original_image_embeds
+
+ # Concatenate the `original_image_embeds` with the `noisy_latents`.
+ concatenated_noisy_latents = torch.cat([noisy_latents, original_image_embeds], dim=1)
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ # Predict the noise residual and compute loss
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+
+ model_pred = unet(
+ concatenated_noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
+ ).sample
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ if args.use_ema:
+ ema_unet.step(unet.parameters())
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ ### BEGIN: Perform validation every `validation_steps` steps
+ if global_step % args.validation_steps == 0 or global_step == 1:
+ if (args.val_image_url_or_path is not None) and (args.validation_prompt is not None):
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+
+ # create pipeline
+ if args.use_ema:
+ # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
+ ema_unet.store(unet.parameters())
+ ema_unet.copy_to(unet.parameters())
+
+ # The models need unwrapping for compatibility in distributed training mode.
+ pipeline = StableDiffusionXLInstructPix2PixPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ unet=accelerator.unwrap_model(unet),
+ text_encoder=text_encoder_1,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer_1,
+ tokenizer_2=tokenizer_2,
+ vae=vae,
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ # Save validation images
+ val_save_dir = os.path.join(args.output_dir, "validation_images")
+ if not os.path.exists(val_save_dir):
+ os.makedirs(val_save_dir)
+
+ original_image = (
+ lambda image_url_or_path: load_image(image_url_or_path)
+ if urlparse(image_url_or_path).scheme
+ else Image.open(image_url_or_path).convert("RGB")
+ )(args.val_image_url_or_path)
+ with torch.autocast(
+ str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"
+ ):
+ edited_images = []
+ for val_img_idx in range(args.num_validation_images):
+ a_val_img = pipeline(
+ args.validation_prompt,
+ image=original_image,
+ num_inference_steps=20,
+ image_guidance_scale=1.5,
+ guidance_scale=7,
+ generator=generator,
+ ).images[0]
+ edited_images.append(a_val_img)
+ a_val_img.save(os.path.join(val_save_dir, f"step_{global_step}_val_img_{val_img_idx}.png"))
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "wandb":
+ wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
+ for edited_image in edited_images:
+ wandb_table.add_data(
+ wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
+ )
+ tracker.log({"validation": wandb_table})
+ if args.use_ema:
+ # Switch back to the original UNet parameters.
+ ema_unet.restore(unet.parameters())
+
+ del pipeline
+ torch.cuda.empty_cache()
+ ### END: Perform validation every `validation_steps` steps
+
+ if global_step >= args.max_train_steps:
+ break
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = accelerator.unwrap_model(unet)
+ if args.use_ema:
+ ema_unet.copy_to(unet.parameters())
+
+ pipeline = StableDiffusionXLInstructPix2PixPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ text_encoder=text_encoder_1,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer_1,
+ tokenizer_2=tokenizer_2,
+ vae=vae,
+ unet=unet,
+ revision=args.revision,
+ )
+ pipeline.save_pretrained(args.output_dir)
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ if args.validation_prompt is not None and args.val_image_url_or_path is not None:
+ edited_images = []
+ pipeline = pipeline.to(accelerator.device)
+ with torch.autocast(str(accelerator.device).replace(":0", "")):
+ for _ in range(args.num_validation_images):
+ edited_images.append(
+ pipeline(
+ args.validation_prompt,
+ image=original_image,
+ num_inference_steps=20,
+ image_guidance_scale=1.5,
+ guidance_scale=7,
+ generator=generator,
+ ).images[0]
+ )
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "wandb":
+ wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
+ for edited_image in edited_images:
+ wandb_table.add_data(
+ wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
+ )
+ tracker.log({"test": wandb_table})
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/kandinsky2_2/text_to_image/README.md b/diffusers/examples/kandinsky2_2/text_to_image/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..6e5a1835593fa7e3c9bec8bfdf2ee4e9ace7af71
--- /dev/null
+++ b/diffusers/examples/kandinsky2_2/text_to_image/README.md
@@ -0,0 +1,317 @@
+# Kandinsky2.2 text-to-image fine-tuning
+
+Kandinsky 2.2 includes a prior pipeline that generates image embeddings from text prompts, and a decoder pipeline that generates the output image based on the image embeddings. We provide `train_text_to_image_prior.py` and `train_text_to_image_decoder.py` scripts to show you how to fine-tune the Kandinsky prior and decoder models separately based on your own dataset. To achieve the best results, you should fine-tune **_both_** your prior and decoder models.
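+
+For orientation, the sketch below shows how the two stages fit together at inference time with the stock `kandinsky-community` checkpoints; the checkpoints you fine-tune with these scripts can be loaded in their place:
+
+```python
+from diffusers import KandinskyV22PriorPipeline, KandinskyV22Pipeline
+import torch
+
+# Prior: text prompt -> image embeddings
+prior = KandinskyV22PriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16).to("cuda")
+# Decoder: image embeddings -> image
+decoder = KandinskyV22Pipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16).to("cuda")
+
+image_embeds, negative_image_embeds = prior("A robot pokemon, 4k photo").to_tuple()
+image = decoder(image_embeds=image_embeds, negative_image_embeds=negative_image_embeds, height=768, width=768).images[0]
+image.save("robot-pokemon.png")
+```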
+
+___Note___:
+
+___This script is experimental. The script fine-tunes the whole model, and often the model overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparameters to get the best result on your dataset.___
+
+
+## Running locally with PyTorch
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Then cd into the example folder and run
+```bash
+pip install -r requirements.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+For this example we want to directly store the trained weights on the Hub, so we need to be logged in and add the `--push_to_hub` flag.
+
+___
+
+### Pokemon example
+
+For all our examples, we will directly store the trained weights on the Hub, so we need to be logged in and add the `--push_to_hub` flag. In order to do that, you have to be a registered user on the 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to the [User Access Tokens](https://huggingface.co/docs/hub/security-tokens) guide.
+
+Run the following command to authenticate your token
+
+```bash
+huggingface-cli login
+```
+
+We also use [Weights and Biases](https://docs.wandb.ai/quickstart) logging by default, because it is really useful to monitor the training progress by regularly generating sample images during training. To install wandb, run
+
+```bash
+pip install wandb
+```
+
+To disable wandb logging, remove the `--report_to="wandb"` and `--validation_prompts="A robot pokemon, 4k photo"` flags from the examples below.
+
+#### Fine-tune decoder
+
+
+
+```bash
+export DATASET_NAME="lambdalabs/pokemon-blip-captions"
+
+accelerate launch --mixed_precision="fp16" train_text_to_image_decoder.py \
+ --dataset_name=$DATASET_NAME \
+ --resolution=768 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --max_train_steps=15000 \
+ --learning_rate=1e-05 \
+ --max_grad_norm=1 \
+ --checkpoints_total_limit=3 \
+ --lr_scheduler="constant" --lr_warmup_steps=0 \
+ --validation_prompts="A robot pokemon, 4k photo" \
+ --report_to="wandb" \
+ --push_to_hub \
+ --output_dir="kandi2-decoder-pokemon-model"
+```
+
+
+
+To train on your own training files, prepare the dataset according to the format required by `datasets`. You can find the instructions for how to do that in the [ImageFolder with metadata](https://huggingface.co/docs/datasets/en/image_load#imagefolder-with-metadata) guide.
+If you wish to use custom loading logic, you should modify the script; we have left pointers for that in the training script.
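+
+As a quick sanity check that your folder is in a format 🤗 Datasets can read (a `metadata.jsonl` next to the images, as described in the guide linked above), you can try loading it directly; the column names printed here are illustrative:
+
+```python
+from datasets import load_dataset
+
+dataset = load_dataset("imagefolder", data_dir="path_to_your_dataset", split="train")
+print(dataset.column_names)  # should include "image" plus your caption column, e.g. "text"
+```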
+
+```bash
+export TRAIN_DIR="path_to_your_dataset"
+
+accelerate launch --mixed_precision="fp16" train_text_to_image_decoder.py \
+ --train_data_dir=$TRAIN_DIR \
+ --resolution=768 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --max_train_steps=15000 \
+ --learning_rate=1e-05 \
+ --max_grad_norm=1 \
+ --checkpoints_total_limit=3 \
+ --lr_scheduler="constant" --lr_warmup_steps=0 \
+ --validation_prompts="A robot pokemon, 4k photo" \
+ --report_to="wandb" \
+ --push_to_hub \
+ --output_dir="kandi22-decoder-pokemon-model"
+```
+
+
+Once the training is finished, the model will be saved in the `output_dir` specified in the command. In this example it's `kandi22-decoder-pokemon-model`. To load the fine-tuned model for inference, just pass that path to `AutoPipelineForText2Image`:
+
+```python
+from diffusers import AutoPipelineForText2Image
+import torch
+
+output_dir = "kandi22-decoder-pokemon-model"  # the --output_dir used during training
+pipe = AutoPipelineForText2Image.from_pretrained(output_dir, torch_dtype=torch.float16)
+pipe.enable_model_cpu_offload()
+
+prompt = "A robot pokemon, 4k photo"
+images = pipe(prompt=prompt).images
+images[0].save("robot-pokemon.png")
+```
+
+Checkpoints only save the UNet, so to run inference from a checkpoint, just load the UNet:
+```python
+from diffusers import AutoPipelineForText2Image, UNet2DConditionModel
+import torch
+
+model_path = "path_to_saved_model"
+
+# <N> is the global step of the checkpoint you want to load, e.g. checkpoint-5000
+unet = UNet2DConditionModel.from_pretrained(model_path + "/checkpoint-<N>/unet")
+
+pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", unet=unet, torch_dtype=torch.float16)
+pipe.enable_model_cpu_offload()
+
+image = pipe(prompt="A robot pokemon, 4k photo").images[0]
+image.save("robot-pokemon.png")
+```
+
+#### Fine-tune prior
+
+You can fine-tune the Kandinsky prior model with the `train_text_to_image_prior.py` script. Note that we currently do not support `--gradient_checkpointing` for prior model fine-tuning.
+
+
+
+
+```bash
+export DATASET_NAME="lambdalabs/pokemon-blip-captions"
+
+accelerate launch --mixed_precision="fp16" train_text_to_image_prior.py \
+ --dataset_name=$DATASET_NAME \
+ --resolution=768 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --max_train_steps=15000 \
+ --learning_rate=1e-05 \
+ --max_grad_norm=1 \
+ --checkpoints_total_limit=3 \
+ --lr_scheduler="constant" --lr_warmup_steps=0 \
+ --validation_prompts="A robot pokemon, 4k photo" \
+ --report_to="wandb" \
+ --push_to_hub \
+ --output_dir="kandi2-prior-pokemon-model"
+```
+
+
+
+To perform inference with the fine-tuned prior model, you will need to first create a prior pipeline by passing the `output_dir` to `DiffusionPipeline`. Then create a `KandinskyV22CombinedPipeline` from a pretrained or fine-tuned decoder checkpoint along with all the modules of the prior pipeline you just created.
+
+```python
+from diffusers import AutoPipelineForText2Image, DiffusionPipeline
+import torch
+
+output_dir = "kandi2-prior-pokemon-model"  # the --output_dir used when training the prior
+pipe_prior = DiffusionPipeline.from_pretrained(output_dir, torch_dtype=torch.float16)
+prior_components = {"prior_" + k: v for k, v in pipe_prior.components.items()}
+pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", **prior_components, torch_dtype=torch.float16)
+
+pipe.enable_model_cpu_offload()
+prompt = "A robot pokemon, 4k photo"
+negative_prompt = ""  # or any negative prompt
+images = pipe(prompt=prompt, negative_prompt=negative_prompt).images
+images[0].save("robot-pokemon.png")
+```
+
+If you want to use a fine-tuned decoder checkpoint along with your fine-tuned prior checkpoint, you can simply replace "kandinsky-community/kandinsky-2-2-decoder" in the above code with your custom model repo name. Note that in order to create a `KandinskyV22CombinedPipeline`, your model repository needs to have a prior tag. If you have created your model repo using our training script, the prior tag is automatically included.
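+
+For example (the repo names below are hypothetical placeholders for your own pushed checkpoints):
+
+```python
+from diffusers import AutoPipelineForText2Image, DiffusionPipeline
+import torch
+
+pipe_prior = DiffusionPipeline.from_pretrained("your-username/kandi2-prior-pokemon-model", torch_dtype=torch.float16)
+prior_components = {"prior_" + k: v for k, v in pipe_prior.components.items()}
+pipe = AutoPipelineForText2Image.from_pretrained("your-username/kandi22-decoder-pokemon-model", **prior_components, torch_dtype=torch.float16)
+pipe.enable_model_cpu_offload()
+
+image = pipe(prompt="A robot pokemon, 4k photo").images[0]
+```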
+
+#### Training with multiple GPUs
+
+`accelerate` allows for seamless multi-GPU training. Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch)
+for running distributed training with `accelerate`. Here is an example command:
+
+```bash
+export DATASET_NAME="lambdalabs/pokemon-blip-captions"
+
+accelerate launch --mixed_precision="fp16" --multi_gpu train_text_to_image_decoder.py \
+ --dataset_name=$DATASET_NAME \
+ --resolution=768 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --max_train_steps=15000 \
+ --learning_rate=1e-05 \
+ --max_grad_norm=1 \
+ --checkpoints_total_limit=3 \
+ --lr_scheduler="constant" --lr_warmup_steps=0 \
+ --validation_prompts="A robot pokemon, 4k photo" \
+ --report_to="wandb" \
+ --push_to_hub \
+ --output_dir="kandi2-decoder-pokemon-model"
+```
+
+
+#### Training with Min-SNR weighting
+
+We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556) which helps achieve faster convergence
+by rebalancing the loss. Enable the `--snr_gamma` argument and set it to the recommended
+value of 5.0.
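+
+As a rough sketch of what this weighting does (not a verbatim copy of the training script; it assumes the epsilon-prediction objective and the `compute_snr` helper from `diffusers.training_utils`):
+
+```python
+import torch
+import torch.nn.functional as F
+from diffusers.training_utils import compute_snr
+
+def min_snr_mse_loss(model_pred, target, noise_scheduler, timesteps, snr_gamma=5.0):
+    # Per-sample MSE re-weighted by min(SNR, gamma) / SNR.
+    snr = compute_snr(noise_scheduler, timesteps)
+    weights = torch.stack([snr, snr_gamma * torch.ones_like(snr)], dim=1).min(dim=1)[0] / snr
+    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+    loss = loss.mean(dim=list(range(1, loss.ndim))) * weights
+    return loss.mean()
+```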
+
+
+## Training with LoRA
+
+Low-Rank Adaptation of Large Language Models (LoRA) was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
+
+In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:
+
+- Previous pretrained weights are kept frozen, so the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
+- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
+- LoRA attention layers make it possible to control the extent to which the model is adapted to new training images via a `scale` parameter.
+
+[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.
+
+With LoRA, it's possible to fine-tune Kandinsky 2.2 on a custom image-caption dataset
+on consumer GPUs like the Tesla T4 or Tesla V100.
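+
+The core idea, stripped of all diffusers plumbing, is a frozen weight plus a trainable low-rank update (an illustrative sketch only, with made-up dimensions):
+
+```python
+import torch
+
+d_out, d_in, rank, scale = 320, 768, 4, 1.0
+W = torch.randn(d_out, d_in)                # frozen pretrained weight
+A = torch.randn(rank, d_in) * 0.01          # trainable down-projection
+B = torch.zeros(d_out, rank)                # trainable up-projection, zero-initialized
+
+x = torch.randn(2, d_in)
+y = x @ (W + scale * (B @ A)).T             # adapted forward pass; only A and B are trained
+```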
+
+### Training
+
+First, you need to set up your development environment as explained in the [Running locally with PyTorch](#running-locally-with-pytorch) section above. Make sure to set the `DATASET_NAME` environment variable. Here, we will use [Kandinsky 2.2](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder) and the [Pokemon dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions).
+
+
+#### Train decoder
+
+```bash
+export DATASET_NAME="lambdalabs/pokemon-blip-captions"
+
+accelerate launch --mixed_precision="fp16" train_text_to_image_decoder_lora.py \
+ --dataset_name=$DATASET_NAME --caption_column="text" \
+ --resolution=768 \
+ --train_batch_size=1 \
+ --num_train_epochs=100 --checkpointing_steps=5000 \
+ --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
+ --seed=42 \
+ --rank=4 \
+ --gradient_checkpointing \
+ --output_dir="kandi22-decoder-pokemon-lora" \
+ --validation_prompt="cute dragon creature" --report_to="wandb" \
+ --push_to_hub \
+```
+
+#### Train prior
+
+```bash
+export DATASET_NAME="lambdalabs/pokemon-blip-captions"
+
+accelerate launch --mixed_precision="fp16" train_text_to_image_prior_lora.py \
+ --dataset_name=$DATASET_NAME --caption_column="text" \
+ --resolution=768 \
+ --train_batch_size=1 \
+ --num_train_epochs=100 --checkpointing_steps=5000 \
+ --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
+ --seed=42 \
+ --rank=4 \
+ --output_dir="kandi22-prior-pokemon-lora" \
+ --validation_prompt="cute dragon creature" --report_to="wandb" \
+ --push_to_hub
+```
+
+**___Note: When using LoRA we can use a much higher learning rate compared to non-LoRA fine-tuning. Here we use *1e-4* instead of the usual *1e-5*. Also, by using LoRA, it's possible to run the above scripts on consumer GPUs like the T4 or V100.___**
+
+
+### Inference
+
+#### Inference using fine-tuned LoRA checkpoint for decoder
+
+Once you have trained a Kandinsky decoder model using the above command, inference can be done with the `AutoPipelineForText2Image` after loading the trained LoRA weights. You need to pass the `output_dir` for loading the LoRA weights, which in this case is `kandi22-decoder-pokemon-lora`.
+
+
+```python
+from diffusers import AutoPipelineForText2Image
+import torch
+
+pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16)
+output_dir = "kandi22-decoder-pokemon-lora"  # the --output_dir used for LoRA training
+pipe.unet.load_attn_procs(output_dir)
+pipe.enable_model_cpu_offload()
+
+prompt = "A robot pokemon, 4k photo"
+image = pipe(prompt=prompt).images[0]
+image.save("robot_pokemon.png")
+```
+
+#### Inference using fine-tuned LoRA checkpoint for prior
+
+```python
+from diffusers import AutoPipelineForText2Image
+import torch
+
+pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16)
+output_dir = "kandi22-prior-pokemon-lora"  # the --output_dir used for prior LoRA training
+pipe.prior_prior.load_attn_procs(output_dir)
+pipe.enable_model_cpu_offload()
+
+prompt = "A robot pokemon, 4k photo"
+image = pipe(prompt=prompt).images[0]
+image.save("robot_pokemon.png")
+```
+
+### Training with xFormers
+
+You can enable memory efficient attention by [installing xFormers](https://huggingface.co/docs/diffusers/main/en/optimization/xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script.
+
+xFormers training is not available for fine-tuning the prior model.
+
+**Note**:
+
+According to [this issue](https://github.com/huggingface/diffusers/issues/2234#issuecomment-1416931212), xFormers `v0.0.16` cannot be used for training in some GPUs. If you observe that problem, please install a development version as indicated in that comment.
\ No newline at end of file
diff --git a/diffusers/examples/kandinsky2_2/text_to_image/requirements.txt b/diffusers/examples/kandinsky2_2/text_to_image/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..31b9026efdc2799b1d02e2e3f4d8dfc463737fdc
--- /dev/null
+++ b/diffusers/examples/kandinsky2_2/text_to_image/requirements.txt
@@ -0,0 +1,7 @@
+accelerate>=0.16.0
+torchvision
+transformers>=4.25.1
+datasets
+ftfy
+tensorboard
+Jinja2
diff --git a/diffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py b/diffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc0a64b42e4b73ca7ee713e11770770c8aea5b0a
--- /dev/null
+++ b/diffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
@@ -0,0 +1,917 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import math
+import os
+import shutil
+from pathlib import Path
+
+import accelerate
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.state import AcceleratorState
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from PIL import Image
+from tqdm import tqdm
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+from transformers.utils import ContextManagers
+
+import diffusers
+from diffusers import AutoPipelineForText2Image, DDPMScheduler, UNet2DConditionModel, VQModel
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel, compute_snr
+from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
+from diffusers.utils.import_utils import is_xformers_available
+
+
+if is_wandb_available():
+ import wandb
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.24.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+
+def save_model_card(
+ args,
+ repo_id: str,
+ images=None,
+ repo_folder=None,
+):
+ img_str = ""
+ if images is not None and len(images) > 0:
+ image_grid = make_image_grid(images, 1, len(args.validation_prompts))
+ image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png"))
+ img_str += "![val_imgs_grid](./val_imgs_grid.png)\n"
+
+ yaml = f"""
+---
+license: creativeml-openrail-m
+base_model: {args.pretrained_decoder_model_name_or_path}
+datasets:
+- {args.dataset_name}
+prior:
+- {args.pretrained_prior_model_name_or_path}
+tags:
+- kandinsky
+- text-to-image
+- diffusers
+inference: true
+---
+ """
+ model_card = f"""
+# Finetuning - {repo_id}
+
+This pipeline was finetuned from **{args.pretrained_decoder_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}. \n
+{img_str}
+
+## Pipeline usage
+
+You can use the pipeline like so:
+
+```python
+from diffusers import AutoPipelineForText2Image
+import torch
+
+pipeline = AutoPipelineForText2Image.from_pretrained("{repo_id}", torch_dtype=torch.float16)
+prompt = "{args.validation_prompts[0]}"
+image = pipeline(prompt).images[0]
+image.save("my_image.png")
+```
+
+## Training info
+
+These are the key hyperparameters used during training:
+
+* Epochs: {args.num_train_epochs}
+* Learning rate: {args.learning_rate}
+* Batch size: {args.train_batch_size}
+* Gradient accumulation steps: {args.gradient_accumulation_steps}
+* Image resolution: {args.resolution}
+* Mixed-precision: {args.mixed_precision}
+
+"""
+ wandb_info = ""
+ if is_wandb_available():
+ wandb_run_url = None
+ if wandb.run is not None:
+ wandb_run_url = wandb.run.url
+
+ if wandb_run_url is not None:
+ wandb_info = f"""
+More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}).
+"""
+
+ model_card += wandb_info
+
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
+ f.write(yaml + model_card)
+
+
+def log_validation(vae, image_encoder, image_processor, unet, args, accelerator, weight_dtype, epoch):
+ logger.info("Running validation... ")
+
+ pipeline = AutoPipelineForText2Image.from_pretrained(
+ args.pretrained_decoder_model_name_or_path,
+ vae=accelerator.unwrap_model(vae),
+ prior_image_encoder=accelerator.unwrap_model(image_encoder),
+ prior_image_processor=image_processor,
+ unet=accelerator.unwrap_model(unet),
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ images = []
+ for i in range(len(args.validation_prompts)):
+ with torch.autocast("cuda"):
+ image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
+
+ images.append(image)
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+ else:
+ logger.warn(f"image logging not implemented for {tracker.name}")
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ return images
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of finetuning Kandinsky 2.2.")
+ parser.add_argument(
+ "--pretrained_decoder_model_name_or_path",
+ type=str,
+ default="kandinsky-community/kandinsky-2-2-decoder",
+ required=False,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_prior_model_name_or_path",
+ type=str,
+ default="kandinsky-community/kandinsky-2-2-prior",
+ required=False,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--validation_prompts",
+ type=str,
+ default=None,
+ nargs="+",
+ help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="kandi_2_2-model-finetuned",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+            "The resolution for input images; all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+        help="Number of update steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+        help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument(
+ "--adam_weight_decay",
+ type=float,
+ default=0.0,
+ required=False,
+        help="Weight decay to use.",
+ )
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=5,
+ help="Run validation every X epochs.",
+ )
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="text2image-fine-tune",
+ help=(
+            "The `project_name` argument passed to Accelerator.init_trackers. For"
+            " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ return args
+
+
+def main():
+ args = parse_args()
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+ accelerator_project_config = ProjectConfiguration(
+ total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
+ )
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder="scheduler")
+ image_processor = CLIPImageProcessor.from_pretrained(
+ args.pretrained_prior_model_name_or_path, subfolder="image_processor"
+ )
+
+ def deepspeed_zero_init_disabled_context_manager():
+ """
+ returns either a context list that includes one that will disable zero.Init or an empty context list
+ """
+ deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
+ if deepspeed_plugin is None:
+ return []
+
+ return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
+
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+ with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
+ vae = VQModel.from_pretrained(
+ args.pretrained_decoder_model_name_or_path, subfolder="movq", torch_dtype=weight_dtype
+ ).eval()
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+ args.pretrained_prior_model_name_or_path, subfolder="image_encoder", torch_dtype=weight_dtype
+ ).eval()
+ unet = UNet2DConditionModel.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder="unet")
+
+ # Freeze vae and image_encoder
+ vae.requires_grad_(False)
+ image_encoder.requires_grad_(False)
+
+ # Set unet to trainable.
+ unet.train()
+
+ # Create EMA for the unet.
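+    # The EMA copy tracks an exponential moving average of the UNet weights; it is updated after every
+    # optimizer step and copied back into the UNet for the final export when --use_ema is set.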
+ if args.use_ema:
+ ema_unet = UNet2DConditionModel.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder="unet")
+ ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
+ ema_unet.to(accelerator.device)
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warn(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if args.use_ema:
+ ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
+
+ for i, model in enumerate(models):
+ model.save_pretrained(os.path.join(output_dir, "unet"))
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ if args.use_ema:
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
+ ema_unet.load_state_dict(load_model.state_dict())
+ ema_unet.to(accelerator.device)
+ del load_model
+
+ for i in range(len(models)):
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
+ )
+
+ optimizer_cls = bnb.optim.AdamW8bit
+ else:
+ optimizer_cls = torch.optim.AdamW
+
+ optimizer = optimizer_cls(
+ unet.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ image_column = args.image_column
+ if image_column not in column_names:
+        raise ValueError(f"'--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}")
+
+ def center_crop(image):
+ width, height = image.size
+ new_size = min(width, height)
+ left = (width - new_size) / 2
+ top = (height - new_size) / 2
+ right = (width + new_size) / 2
+ bottom = (height + new_size) / 2
+ return image.crop((left, top, right, bottom))
+
+ def train_transforms(img):
+ img = center_crop(img)
+ img = img.resize((args.resolution, args.resolution), resample=Image.BICUBIC, reducing_gap=1)
+ img = np.array(img).astype(np.float32) / 127.5 - 1
+ img = torch.from_numpy(np.transpose(img, [2, 0, 1]))
+ return img
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ examples["pixel_values"] = [train_transforms(image) for image in images]
+ examples["clip_pixel_values"] = image_processor(images, return_tensors="pt").pixel_values
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ clip_pixel_values = torch.stack([example["clip_pixel_values"] for example in examples])
+ clip_pixel_values = clip_pixel_values.to(memory_format=torch.contiguous_format).float()
+ return {"pixel_values": pixel_values, "clip_pixel_values": clip_pixel_values}
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+ )
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+    # Move image_encoder and vae to gpu and cast to weight_dtype
+ image_encoder.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+ tracker_config.pop("validation_prompts")
+ accelerator.init_trackers(args.tracker_project_name, tracker_config)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ # Convert images to latent space
+ images = batch["pixel_values"].to(weight_dtype)
+ clip_images = batch["clip_pixel_values"].to(weight_dtype)
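+                # Encode images into the MoVQ (VQModel) latent space used by the Kandinsky decoder; the
+                # diffusion process below operates on these latents rather than on raw pixels.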
+ latents = vae.encode(images).latents
+ image_embeds = image_encoder(clip_images).image_embeds
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ target = noise
+
+ # Predict the noise residual and compute loss
+ added_cond_kwargs = {"image_embeds": image_embeds}
+
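+                # The decoder UNet predicts extra variance channels on top of the noise; only the first 4
+                # (noise) channels are kept for the loss, hence the [:, :4] slice.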
+ model_pred = unet(noisy_latents, timesteps, None, added_cond_kwargs=added_cond_kwargs).sample[:, :4]
+
+ if args.snr_gamma is None:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+ # This is discussed in Section 4.2 of the same paper.
+ snr = compute_snr(noise_scheduler, timesteps)
+ if noise_scheduler.config.prediction_type == "v_prediction":
+ # Velocity objective requires that we add one to SNR values before we divide by them.
+ snr = snr + 1
+ mse_loss_weights = (
+ torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
+ )
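+                    # i.e. each sample is weighted by min(SNR, snr_gamma) / SNR, which down-weights easy, high-SNR timesteps.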
+
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ if args.use_ema:
+ ema_unet.step(unet.parameters())
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
+ if args.use_ema:
+ # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
+ ema_unet.store(unet.parameters())
+ ema_unet.copy_to(unet.parameters())
+ log_validation(
+ vae,
+ image_encoder,
+ image_processor,
+ unet,
+ args,
+ accelerator,
+ weight_dtype,
+ global_step,
+ )
+ if args.use_ema:
+ # Switch back to the original UNet parameters.
+ ema_unet.restore(unet.parameters())
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = accelerator.unwrap_model(unet)
+ if args.use_ema:
+ ema_unet.copy_to(unet.parameters())
+
+ pipeline = AutoPipelineForText2Image.from_pretrained(
+ args.pretrained_decoder_model_name_or_path,
+ vae=vae,
+ unet=unet,
+ )
+ pipeline.decoder_pipe.save_pretrained(args.output_dir)
+
+ # Run a final round of inference.
+ images = []
+ if args.validation_prompts is not None:
+ logger.info("Running inference for collecting generated images...")
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.torch_dtype = weight_dtype
+ pipeline.set_progress_bar_config(disable=True)
+ pipeline.enable_model_cpu_offload()
+
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ for i in range(len(args.validation_prompts)):
+ with torch.autocast("cuda"):
+ image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
+ images.append(image)
+
+ if args.push_to_hub:
+ save_model_card(args, repo_id, images, repo_folder=args.output_dir)
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py b/diffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f968aa8b8b36c3a0da204bb3f5be27a206bb303
--- /dev/null
+++ b/diffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py
@@ -0,0 +1,798 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fine-tuning script for Kandinsky with support for LoRA."""
+
+import argparse
+import logging
+import math
+import os
+import shutil
+from pathlib import Path
+
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from PIL import Image
+from tqdm import tqdm
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+
+import diffusers
+from diffusers import AutoPipelineForText2Image, DDPMScheduler, UNet2DConditionModel, VQModel
+from diffusers.loaders import AttnProcsLayers
+from diffusers.models.attention_processor import LoRAAttnAddedKVProcessor
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import compute_snr
+from diffusers.utils import check_min_version, is_wandb_available
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.24.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+
+def save_model_card(repo_id: str, images=None, base_model: str = None, dataset_name: str = None, repo_folder=None):
+ img_str = ""
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ img_str += f"![img_{i}](./image_{i}.png)\n"
+
+ yaml = f"""
+---
+license: creativeml-openrail-m
+base_model: {base_model}
+tags:
+- kandinsky
+- text-to-image
+- diffusers
+- lora
+inference: true
+---
+ """
+ model_card = f"""
+# LoRA text2image fine-tuning - {repo_id}
+These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n
+{img_str}
+"""
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
+ f.write(yaml + model_card)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of finetuning Kandinsky 2.2 with LoRA.")
+ parser.add_argument(
+ "--pretrained_decoder_model_name_or_path",
+ type=str,
+ default="kandinsky-community/kandinsky-2-2-decoder",
+ required=False,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_prior_model_name_or_path",
+ type=str,
+ default="kandinsky-community/kandinsky-2-2-prior",
+ required=False,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference."
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=1,
+ help=(
+ "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
+            " `args.validation_prompt` `args.num_validation_images` times."
+ ),
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="kandi_2_2-model-finetuned-lora",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+            "The resolution for input images; all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+        help="Number of update steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=0.0, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ return args
+
+
+def main():
+ args = parse_args()
+ logging_dir = Path(args.output_dir, args.logging_dir)
+ accelerator_project_config = ProjectConfiguration(
+ total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
+ )
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+ import wandb
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+ # Load scheduler, tokenizer and models.
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder="scheduler")
+ image_processor = CLIPImageProcessor.from_pretrained(
+ args.pretrained_prior_model_name_or_path, subfolder="image_processor"
+ )
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+ args.pretrained_prior_model_name_or_path, subfolder="image_encoder"
+ )
+
+ vae = VQModel.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder="movq")
+
+ unet = UNet2DConditionModel.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder="unet")
+ # freeze parameters of models to save more memory
+ unet.requires_grad_(False)
+ vae.requires_grad_(False)
+
+ image_encoder.requires_grad_(False)
+
+    # For mixed precision training we cast all non-trainable weights (vae, image_encoder and the non-LoRA unet) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+    # Move unet, vae and image_encoder to device and cast to weight_dtype
+ unet.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+ image_encoder.to(accelerator.device, dtype=weight_dtype)
+
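+    # Attach a LoRA added-KV attention processor to every attention layer of the frozen UNet. Self-attention
+    # layers ("attn1") get no cross_attention_dim, and the hidden size depends on which UNet block the layer sits in.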
+ lora_attn_procs = {}
+ for name in unet.attn_processors.keys():
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+ if name.startswith("mid_block"):
+ hidden_size = unet.config.block_out_channels[-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = unet.config.block_out_channels[block_id]
+
+ lora_attn_procs[name] = LoRAAttnAddedKVProcessor(
+ hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim,
+ rank=args.rank,
+ )
+
+ unet.set_attn_processor(lora_attn_procs)
+
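+    # AttnProcsLayers wraps the LoRA processors in a single nn.Module so that only these parameters are
+    # passed to the optimizer and serialized in checkpoints.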
+ lora_layers = AttnProcsLayers(unet.attn_processors)
+
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
+ )
+
+ optimizer_cls = bnb.optim.AdamW8bit
+ else:
+ optimizer_cls = torch.optim.AdamW
+
+ optimizer = optimizer_cls(
+ lora_layers.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ image_column = args.image_column
+ if image_column not in column_names:
+        raise ValueError(f"'--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}")
+
+ def center_crop(image):
+ width, height = image.size
+ new_size = min(width, height)
+ left = (width - new_size) / 2
+ top = (height - new_size) / 2
+ right = (width + new_size) / 2
+ bottom = (height + new_size) / 2
+ return image.crop((left, top, right, bottom))
+
+ def train_transforms(img):
+ img = center_crop(img)
+ img = img.resize((args.resolution, args.resolution), resample=Image.BICUBIC, reducing_gap=1)
+ img = np.array(img).astype(np.float32) / 127.5 - 1
+ img = torch.from_numpy(np.transpose(img, [2, 0, 1]))
+ return img
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ examples["pixel_values"] = [train_transforms(image) for image in images]
+ examples["clip_pixel_values"] = image_processor(images, return_tensors="pt").pixel_values
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ clip_pixel_values = torch.stack([example["clip_pixel_values"] for example in examples])
+ clip_pixel_values = clip_pixel_values.to(memory_format=torch.contiguous_format).float()
+ return {"pixel_values": pixel_values, "clip_pixel_values": clip_pixel_values}
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+ )
+ # Prepare everything with our `accelerator`.
+ lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ lora_layers, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("text2image-fine-tune", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ unet.train()
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ # Convert images to latent space
+ images = batch["pixel_values"].to(weight_dtype)
+ clip_images = batch["clip_pixel_values"].to(weight_dtype)
+ latents = vae.encode(images).latents
+ image_embeds = image_encoder(clip_images).image_embeds
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ target = noise
+
+ # Predict the noise residual and compute loss
+ added_cond_kwargs = {"image_embeds": image_embeds}
+
+ model_pred = unet(noisy_latents, timesteps, None, added_cond_kwargs=added_cond_kwargs).sample[:, :4]
+
+ if args.snr_gamma is None:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+ # This is discussed in Section 4.2 of the same paper.
+ snr = compute_snr(noise_scheduler, timesteps)
+ if noise_scheduler.config.prediction_type == "v_prediction":
+ # Velocity objective requires that we add one to SNR values before we divide by them.
+ snr = snr + 1
+ mse_loss_weights = (
+ torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
+ )
+
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = lora_layers.parameters()
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ # create pipeline
+ pipeline = AutoPipelineForText2Image.from_pretrained(
+ args.pretrained_decoder_model_name_or_path,
+ unet=accelerator.unwrap_model(unet),
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device)
+ if args.seed is not None:
+ generator = generator.manual_seed(args.seed)
+ images = []
+ for _ in range(args.num_validation_images):
+ images.append(
+ pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
+ )
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = unet.to(torch.float32)
+ unet.save_attn_procs(args.output_dir)
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_decoder_model_name_or_path,
+ dataset_name=args.dataset_name,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ # Final inference
+ # Load previous pipeline
+ pipeline = AutoPipelineForText2Image.from_pretrained(
+ args.pretrained_decoder_model_name_or_path, torch_dtype=weight_dtype
+ )
+ pipeline = pipeline.to(accelerator.device)
+
+ # load attention processors
+ pipeline.unet.load_attn_procs(args.output_dir)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device)
+ if args.seed is not None:
+ generator = generator.manual_seed(args.seed)
+ images = []
+ for _ in range(args.num_validation_images):
+ images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])
+
+ if accelerator.is_main_process:
+ for tracker in accelerator.trackers:
+ if len(images) != 0:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "test": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
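The Kandinsky 2.2 training scripts in this diff share the same Min-SNR loss rebalancing when `--snr_gamma` is set: each sample's MSE term is scaled by `min(SNR_t, snr_gamma) / SNR_t` (with the SNR shifted by one for v-prediction). The standalone sketch below only illustrates that weighting with made-up values; `snr` and `snr_gamma` here are placeholders standing in for `compute_snr(noise_scheduler, timesteps)` and `args.snr_gamma`, not outputs of the scripts.

```python
import torch

# Hypothetical per-sample SNR values and gamma, standing in for
# compute_snr(noise_scheduler, timesteps) and args.snr_gamma.
snr = torch.tensor([0.2, 1.0, 5.0, 25.0])
snr_gamma = 5.0

# Min-SNR weighting from https://arxiv.org/abs/2303.09556:
# clamp the SNR at gamma, then divide by the SNR itself.
mse_loss_weights = torch.minimum(snr, torch.full_like(snr, snr_gamma)) / snr
print(mse_loss_weights)  # tensor([1.0000, 1.0000, 1.0000, 0.2000])
```

Low-noise (high-SNR) timesteps get down-weighted while the rest keep a weight of one, which is the rebalancing the `--snr_gamma` flag switches on in these scripts.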
diff --git a/diffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py b/diffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py
new file mode 100644
index 0000000000000000000000000000000000000000..317e4178c04c8be99dbb7d8c71b0f5abad83d523
--- /dev/null
+++ b/diffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py
@@ -0,0 +1,830 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fine-tuning script for Stable Diffusion for text2image with support for LoRA."""
+
+import argparse
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from tqdm import tqdm
+from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
+
+import diffusers
+from diffusers import AutoPipelineForText2Image, DDPMScheduler, PriorTransformer
+from diffusers.loaders import AttnProcsLayers
+from diffusers.models.attention_processor import LoRAAttnProcessor
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import compute_snr
+from diffusers.utils import check_min_version, is_wandb_available
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.24.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+
+def save_model_card(repo_id: str, images=None, base_model: str = None, dataset_name: str = None, repo_folder=None):
+ img_str = ""
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ img_str += f"![img_{i}](./image_{i}.png)\n"
+
+ yaml = f"""
+---
+license: creativeml-openrail-m
+base_model: {base_model}
+tags:
+- kandinsky
+- text-to-image
+- diffusers
+- lora
+inference: true
+---
+ """
+ model_card = f"""
+# LoRA text2image fine-tuning - {repo_id}
+These are LoRA adaptation weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images below. \n
+{img_str}
+"""
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
+ f.write(yaml + model_card)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of finetuning Kandinsky 2.2.")
+ parser.add_argument(
+ "--pretrained_decoder_model_name_or_path",
+ type=str,
+ default="kandinsky-community/kandinsky-2-2-decoder",
+ required=False,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_prior_model_name_or_path",
+ type=str,
+ default="kandinsky-community/kandinsky-2-2-prior",
+ required=False,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference."
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=1,
+ help=(
+ "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="kandi_2_2-model-finetuned-lora",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="learning rate",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument(
+ "--adam_weight_decay",
+ type=float,
+ default=0.0,
+ required=False,
+ help="weight decay_to_use",
+ )
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ return args
+
+
+DATASET_NAME_MAPPING = {
+ "lambdalabs/pokemon-blip-captions": ("image", "text"),
+}
+
+
+def main():
+ args = parse_args()
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(
+ total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
+ )
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+ import wandb
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+ # Load scheduler, image_processor, tokenizer and models.
+ noise_scheduler = DDPMScheduler(beta_schedule="squaredcos_cap_v2", prediction_type="sample")
+ image_processor = CLIPImageProcessor.from_pretrained(
+ args.pretrained_prior_model_name_or_path, subfolder="image_processor"
+ )
+ tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="tokenizer")
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+ args.pretrained_prior_model_name_or_path, subfolder="image_encoder"
+ )
+ text_encoder = CLIPTextModelWithProjection.from_pretrained(
+ args.pretrained_prior_model_name_or_path, subfolder="text_encoder"
+ )
+ prior = PriorTransformer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")
+ # freeze parameters of models to save more memory
+ image_encoder.requires_grad_(False)
+ prior.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move image_encoder, text_encoder and prior to device and cast to weight_dtype
+ prior.to(accelerator.device, dtype=weight_dtype)
+ image_encoder.to(accelerator.device, dtype=weight_dtype)
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+ lora_attn_procs = {}
+ for name in prior.attn_processors.keys():
+ lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=2048, rank=args.rank)
+
+ prior.set_attn_processor(lora_attn_procs)
+ lora_layers = AttnProcsLayers(prior.attn_processors)
+
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
+ )
+
+ optimizer_cls = bnb.optim.AdamW8bit
+ else:
+ optimizer_cls = torch.optim.AdamW
+
+ optimizer = optimizer_cls(
+ lora_layers.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # Get the column names for input/target.
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
+ if args.image_column is None:
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.caption_column is None:
+ caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
+ )
+
+ # Preprocessing the datasets.
+ # We need to tokenize input captions and transform the images.
+ def tokenize_captions(examples, is_train=True):
+ captions = []
+ for caption in examples[caption_column]:
+ if isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+ else:
+ raise ValueError(
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
+ )
+ inputs = tokenizer(
+ captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ text_input_ids = inputs.input_ids
+ text_mask = inputs.attention_mask.bool()
+ return text_input_ids, text_mask
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ examples["clip_pixel_values"] = image_processor(images, return_tensors="pt").pixel_values
+ examples["text_input_ids"], examples["text_mask"] = tokenize_captions(examples)
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ def collate_fn(examples):
+ clip_pixel_values = torch.stack([example["clip_pixel_values"] for example in examples])
+ clip_pixel_values = clip_pixel_values.to(memory_format=torch.contiguous_format).float()
+ text_input_ids = torch.stack([example["text_input_ids"] for example in examples])
+ text_mask = torch.stack([example["text_mask"] for example in examples])
+ return {"clip_pixel_values": clip_pixel_values, "text_input_ids": text_input_ids, "text_mask": text_mask}
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+ )
+ clip_mean = prior.clip_mean.clone()
+ clip_std = prior.clip_std.clone()
+ lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ lora_layers, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("text2image-fine-tune", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ clip_mean = clip_mean.to(weight_dtype).to(accelerator.device)
+ clip_std = clip_std.to(weight_dtype).to(accelerator.device)
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ prior.train()
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(prior):
+ # Unpack the token ids, text mask and CLIP pixel values from the batch
+ text_input_ids, text_mask, clip_images = (
+ batch["text_input_ids"],
+ batch["text_mask"],
+ batch["clip_pixel_values"].to(weight_dtype),
+ )
+ with torch.no_grad():
+ text_encoder_output = text_encoder(text_input_ids)
+ prompt_embeds = text_encoder_output.text_embeds
+ text_encoder_hidden_states = text_encoder_output.last_hidden_state
+
+ image_embeds = image_encoder(clip_images).image_embeds
+ # Sample noise that we'll add to the image_embeds
+ noise = torch.randn_like(image_embeds)
+ bsz = image_embeds.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=image_embeds.device
+ )
+ timesteps = timesteps.long()
+ image_embeds = (image_embeds - clip_mean) / clip_std
+ noisy_latents = noise_scheduler.add_noise(image_embeds, noise, timesteps)
+
+ target = image_embeds
+
+ # Predict the image embedding and compute loss
+ model_pred = prior(
+ noisy_latents,
+ timestep=timesteps,
+ proj_embedding=prompt_embeds,
+ encoder_hidden_states=text_encoder_hidden_states,
+ attention_mask=text_mask,
+ ).predicted_image_embedding
+
+ if args.snr_gamma is None:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Note that this script trains the prior with prediction_type="sample" (the target is the
+ # clean image embedding), while the weighting below keeps the epsilon-prediction form
+ # discussed in Section 4.2 of the same paper.
+ snr = compute_snr(noise_scheduler, timesteps)
+ if noise_scheduler.config.prediction_type == "v_prediction":
+ # Velocity objective requires that we add one to SNR values before we divide by them.
+ snr = snr + 1
+ mse_loss_weights = (
+ torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
+ )
+
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(lora_layers.parameters(), args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ # create pipeline
+ pipeline = AutoPipelineForText2Image.from_pretrained(
+ args.pretrained_decoder_model_name_or_path,
+ prior_prior=accelerator.unwrap_model(prior),
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device)
+ if args.seed is not None:
+ generator = generator.manual_seed(args.seed)
+ images = []
+ for _ in range(args.num_validation_images):
+ images.append(
+ pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
+ )
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ prior = prior.to(torch.float32)
+ prior.save_attn_procs(args.output_dir)
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_prior_model_name_or_path,
+ dataset_name=args.dataset_name,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ # Final inference
+ # Load previous pipeline
+ pipeline = AutoPipelineForText2Image.from_pretrained(
+ args.pretrained_decoder_model_name_or_path, torch_dtype=weight_dtype
+ )
+ pipeline = pipeline.to(accelerator.device)
+
+ # load attention processors
+ pipeline.prior_prior.load_attn_procs(args.output_dir)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device)
+ if args.seed is not None:
+ generator = generator.manual_seed(args.seed)
+ images = []
+ for _ in range(args.num_validation_images):
+ images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])
+
+ if accelerator.is_main_process:
+ for tracker in accelerator.trackers:
+ if len(images) != 0:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "test": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
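Once `train_text_to_image_lora_prior.py` has saved its LoRA attention processors with `prior.save_attn_procs(args.output_dir)`, they can be attached to the combined Kandinsky 2.2 pipeline the same way the script's own final-inference block does. The sketch below is a minimal, illustrative example; the output directory (the script's default `--output_dir`) and the prompt are assumptions to be replaced with values from your own run.

```python
import torch
from diffusers import AutoPipelineForText2Image

# Combined Kandinsky 2.2 pipeline (prior + decoder).
pipeline = AutoPipelineForText2Image.from_pretrained(
    "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
)
pipeline.to("cuda")

# Attach the LoRA attention processors saved by the training script.
# "kandi_2_2-model-finetuned-lora" is the script's default --output_dir.
pipeline.prior_prior.load_attn_procs("kandi_2_2-model-finetuned-lora")

image = pipeline("A robot pokemon, 4k photo", num_inference_steps=30).images[0]
image.save("robot_pokemon.png")
```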
diff --git a/diffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py b/diffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e6d06074012956c2010914fc8bcfc60e78ddbc0
--- /dev/null
+++ b/diffusers/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
@@ -0,0 +1,945 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+
+import accelerate
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.state import AcceleratorState
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from tqdm import tqdm
+from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
+from transformers.utils import ContextManagers
+
+import diffusers
+from diffusers import AutoPipelineForText2Image, DDPMScheduler, PriorTransformer
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel, compute_snr
+from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
+
+
+if is_wandb_available():
+ import wandb
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.24.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+DATASET_NAME_MAPPING = {
+ "lambdalabs/pokemon-blip-captions": ("image", "text"),
+}
+
+
+def save_model_card(
+ args,
+ repo_id: str,
+ images=None,
+ repo_folder=None,
+):
+ img_str = ""
+ if len(images) > 0:
+ image_grid = make_image_grid(images, 1, len(args.validation_prompts))
+ image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png"))
+ img_str += "![val_imgs_grid](./val_imgs_grid.png)\n"
+
+ yaml = f"""
+---
+license: creativeml-openrail-m
+base_model: {args.pretrained_prior_model_name_or_path}
+datasets:
+- {args.dataset_name}
+tags:
+- kandinsky
+- text-to-image
+- diffusers
+inference: true
+---
+ """
+ model_card = f"""
+# Finetuning - {repo_id}
+
+This pipeline was fine-tuned from **{args.pretrained_prior_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the fine-tuned pipeline using the following prompts: {args.validation_prompts}. \n
+{img_str}
+
+## Pipeline usage
+
+You can use the pipeline like so:
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+pipe_prior = DiffusionPipeline.from_pretrained("{repo_id}", torch_dtype=torch.float16)
+pipe_t2i = DiffusionPipeline.from_pretrained("{args.pretrained_decoder_model_name_or_path}", torch_dtype=torch.float16)
+prompt = "{args.validation_prompts[0]}"
+image_embeds, negative_image_embeds = pipe_prior(prompt, guidance_scale=1.0).to_tuple()
+image = pipe_t2i(image_embeds=image_embeds, negative_image_embeds=negative_image_embeds).images[0]
+image.save("my_image.png")
+```
+
+## Training info
+
+These are the key hyperparameters used during training:
+
+* Epochs: {args.num_train_epochs}
+* Learning rate: {args.learning_rate}
+* Batch size: {args.train_batch_size}
+* Gradient accumulation steps: {args.gradient_accumulation_steps}
+* Image resolution: {args.resolution}
+* Mixed-precision: {args.mixed_precision}
+
+"""
+ wandb_info = ""
+ if is_wandb_available():
+ wandb_run_url = None
+ if wandb.run is not None:
+ wandb_run_url = wandb.run.url
+
+ if wandb_run_url is not None:
+ wandb_info = f"""
+More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}).
+"""
+
+ model_card += wandb_info
+
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
+ f.write(yaml + model_card)
+
+
+def log_validation(
+ image_encoder, image_processor, text_encoder, tokenizer, prior, args, accelerator, weight_dtype, epoch
+):
+ logger.info("Running validation... ")
+
+ pipeline = AutoPipelineForText2Image.from_pretrained(
+ args.pretrained_decoder_model_name_or_path,
+ prior_image_encoder=accelerator.unwrap_model(image_encoder),
+ prior_image_processor=image_processor,
+ prior_text_encoder=accelerator.unwrap_model(text_encoder),
+ prior_tokenizer=tokenizer,
+ prior_prior=accelerator.unwrap_model(prior),
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ images = []
+ for i in range(len(args.validation_prompts)):
+ with torch.autocast("cuda"):
+ image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
+
+ images.append(image)
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+ else:
+ logger.warning(f"image logging not implemented for {tracker.name}")
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ return images
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of finetuning Kandinsky 2.2.")
+ parser.add_argument(
+ "--pretrained_decoder_model_name_or_path",
+ type=str,
+ default="kandinsky-community/kandinsky-2-2-decoder",
+ required=False,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_prior_model_name_or_path",
+ type=str,
+ default="kandinsky-community/kandinsky-2-2-prior",
+ required=False,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--validation_prompts",
+ type=str,
+ default=None,
+ nargs="+",
+ help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="kandi_2_2-model-finetuned",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="learning rate",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument(
+ "--adam_weight_decay",
+ type=float,
+ default=0.0,
+ required=False,
+ help="weight decay_to_use",
+ )
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=5,
+ help="Run validation every X epochs.",
+ )
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="text2image-fine-tune",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers for"
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ return args
+
+
+def main():
+ args = parse_args()
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+ accelerator_project_config = ProjectConfiguration(
+ total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
+ )
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load scheduler, image_processor, tokenizer and models.
+ noise_scheduler = DDPMScheduler(beta_schedule="squaredcos_cap_v2", prediction_type="sample")
+ image_processor = CLIPImageProcessor.from_pretrained(
+ args.pretrained_prior_model_name_or_path, subfolder="image_processor"
+ )
+ tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="tokenizer")
+
+ def deepspeed_zero_init_disabled_context_manager():
+ """
+ Returns a list containing a context manager that disables zero.Init when DeepSpeed is in use, or an empty list otherwise.
+ """
+ deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
+ if deepspeed_plugin is None:
+ return []
+
+ return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
+
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+ with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+ args.pretrained_prior_model_name_or_path, subfolder="image_encoder", torch_dtype=weight_dtype
+ ).eval()
+ text_encoder = CLIPTextModelWithProjection.from_pretrained(
+ args.pretrained_prior_model_name_or_path, subfolder="text_encoder", torch_dtype=weight_dtype
+ ).eval()
+
+ prior = PriorTransformer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")
+
+ # Freeze text_encoder and image_encoder
+ text_encoder.requires_grad_(False)
+ image_encoder.requires_grad_(False)
+
+ # Set prior to trainable.
+ prior.train()
+
+ # Create EMA for the prior.
+ if args.use_ema:
+ ema_prior = PriorTransformer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")
+ ema_prior = EMAModel(ema_prior.parameters(), model_cls=PriorTransformer, model_config=ema_prior.config)
+ ema_prior.to(accelerator.device)
+
+ # `accelerate` >= 0.16.0 has better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if args.use_ema:
+ ema_prior.save_pretrained(os.path.join(output_dir, "prior_ema"))
+
+ for i, model in enumerate(models):
+ model.save_pretrained(os.path.join(output_dir, "prior"))
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ if args.use_ema:
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, "prior_ema"), PriorTransformer)
+ ema_prior.load_state_dict(load_model.state_dict())
+ ema_prior.to(accelerator.device)
+ del load_model
+
+ for i in range(len(models)):
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = PriorTransformer.from_pretrained(input_dir, subfolder="prior")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
+ )
+
+ optimizer_cls = bnb.optim.AdamW8bit
+ else:
+ optimizer_cls = torch.optim.AdamW
+ optimizer = optimizer_cls(
+ prior.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # Get the column names for input/target.
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
+ if args.image_column is None:
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.caption_column is None:
+ caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
+ )
+
+ # Preprocessing the datasets.
+ # We need to tokenize input captions and transform the images.
+ def tokenize_captions(examples, is_train=True):
+ captions = []
+ for caption in examples[caption_column]:
+ if isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+ else:
+ raise ValueError(
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
+ )
+ inputs = tokenizer(
+ captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ text_input_ids = inputs.input_ids
+ text_mask = inputs.attention_mask.bool()
+ return text_input_ids, text_mask
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ examples["clip_pixel_values"] = image_processor(images, return_tensors="pt").pixel_values
+ examples["text_input_ids"], examples["text_mask"] = tokenize_captions(examples)
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ def collate_fn(examples):
+ clip_pixel_values = torch.stack([example["clip_pixel_values"] for example in examples])
+ clip_pixel_values = clip_pixel_values.to(memory_format=torch.contiguous_format).float()
+ text_input_ids = torch.stack([example["text_input_ids"] for example in examples])
+ text_mask = torch.stack([example["text_mask"] for example in examples])
+ return {"clip_pixel_values": clip_pixel_values, "text_input_ids": text_input_ids, "text_mask": text_mask}
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+ )
+
+ clip_mean = prior.clip_mean.clone()
+ clip_std = prior.clip_std.clone()
+
+ prior, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ prior, optimizer, train_dataloader, lr_scheduler
+ )
+
+ image_encoder.to(accelerator.device, dtype=weight_dtype)
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+ tracker_config.pop("validation_prompts")
+ accelerator.init_trackers(args.tracker_project_name, tracker_config)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ clip_mean = clip_mean.to(weight_dtype).to(accelerator.device)
+ clip_std = clip_std.to(weight_dtype).to(accelerator.device)
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(prior):
+ # Convert images to latent space
+ text_input_ids, text_mask, clip_images = (
+ batch["text_input_ids"],
+ batch["text_mask"],
+ batch["clip_pixel_values"].to(weight_dtype),
+ )
+ with torch.no_grad():
+ text_encoder_output = text_encoder(text_input_ids)
+ prompt_embeds = text_encoder_output.text_embeds
+ text_encoder_hidden_states = text_encoder_output.last_hidden_state
+
+ image_embeds = image_encoder(clip_images).image_embeds
+ # Sample noise that we'll add to the image_embeds
+ noise = torch.randn_like(image_embeds)
+ bsz = image_embeds.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=image_embeds.device
+ )
+ timesteps = timesteps.long()
+ image_embeds = (image_embeds - clip_mean) / clip_std
+ noisy_latents = noise_scheduler.add_noise(image_embeds, noise, timesteps)
+
+ target = image_embeds
+
+ # Predict the noise residual and compute loss
+ model_pred = prior(
+ noisy_latents,
+ timestep=timesteps,
+ proj_embedding=prompt_embeds,
+ encoder_hidden_states=text_encoder_hidden_states,
+ attention_mask=text_mask,
+ ).predicted_image_embedding
+
+ if args.snr_gamma is None:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+ # This is discussed in Section 4.2 of the same paper.
+ snr = compute_snr(noise_scheduler, timesteps)
+ if noise_scheduler.config.prediction_type == "v_prediction":
+ # Velocity objective requires that we add one to SNR values before we divide by them.
+ snr = snr + 1
+ mse_loss_weights = (
+ torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
+ )
+
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(prior.parameters(), args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ if args.use_ema:
+ ema_prior.step(prior.parameters())
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
+ if args.use_ema:
+ # Store the prior parameters temporarily and load the EMA parameters to perform inference.
+ ema_prior.store(prior.parameters())
+ ema_prior.copy_to(prior.parameters())
+ log_validation(
+ image_encoder,
+ image_processor,
+ text_encoder,
+ tokenizer,
+ prior,
+ args,
+ accelerator,
+ weight_dtype,
+ global_step,
+ )
+ if args.use_ema:
+ # Switch back to the original prior parameters.
+ ema_prior.restore(prior.parameters())
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ prior = accelerator.unwrap_model(prior)
+ if args.use_ema:
+ ema_prior.copy_to(prior.parameters())
+
+ pipeline = AutoPipelineForText2Image.from_pretrained(
+ args.pretrained_decoder_model_name_or_path,
+ prior_image_encoder=image_encoder,
+ prior_text_encoder=text_encoder,
+ prior_prior=prior,
+ )
+ pipeline.prior_pipe.save_pretrained(args.output_dir)
+
+ # Run a final round of inference.
+ images = []
+ if args.validation_prompts is not None:
+ logger.info("Running inference for collecting generated images...")
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.torch_dtype = weight_dtype
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ for i in range(len(args.validation_prompts)):
+ with torch.autocast("cuda"):
+ image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
+ images.append(image)
+
+ if args.push_to_hub:
+ save_model_card(args, repo_id, images, repo_folder=args.output_dir)
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/reinforcement_learning/README.md b/diffusers/examples/reinforcement_learning/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..17881d584a4043156b784a152253b0f83598ced9
--- /dev/null
+++ b/diffusers/examples/reinforcement_learning/README.md
@@ -0,0 +1,22 @@
+# Overview
+
+These examples show how to run [Diffuser](https://arxiv.org/abs/2205.09991) in Diffusers.
+The script `run_diffuser_locomotion.py` can be used in two ways, controlled by the variable `n_guide_steps`.
+
+When `n_guide_steps=0`, the trajectories are sampled from the diffusion model without value guidance, so they are not steered to maximize reward in the environment.
+By default, `n_guide_steps=2` to match the original implementation.
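+
+For example, the following minimal sketch samples a single unguided plan. It assumes the same environment and checkpoint used in `run_diffuser_locomotion.py`, and that the pipeline call accepts an `n_guide_steps` keyword, as the `config` dict in that script suggests:
+
+```python
+import d4rl  # noqa
+import gym
+from diffusers.experimental import ValueGuidedRLPipeline
+
+env = gym.make("hopper-medium-v2")
+pipeline = ValueGuidedRLPipeline.from_pretrained(
+    "bglick13/hopper-medium-v2-value-function-hor32", env=env
+)
+
+obs = env.reset()
+# n_guide_steps=0 skips the value-network guidance, so sampling is faster
+# but the plan is not steered towards high reward.
+denorm_actions = pipeline(obs, planning_horizon=32, n_guide_steps=0)
+```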
+
+
+You will need some RL-specific requirements to run the examples:
+
+```
+pip install -f https://download.pytorch.org/whl/torch_stable.html \
+ free-mujoco-py \
+ einops \
+ gym==0.24.1 \
+ protobuf==3.20.1 \
+ git+https://github.com/rail-berkeley/d4rl.git \
+ mediapy \
+ Pillow==9.0.0
+```
diff --git a/diffusers/examples/reinforcement_learning/run_diffuser_locomotion.py b/diffusers/examples/reinforcement_learning/run_diffuser_locomotion.py
new file mode 100644
index 0000000000000000000000000000000000000000..adf6d1443d1c2e7caca7bdc1a26da1f2f186b8f9
--- /dev/null
+++ b/diffusers/examples/reinforcement_learning/run_diffuser_locomotion.py
@@ -0,0 +1,59 @@
+import d4rl # noqa
+import gym
+import tqdm
+from diffusers.experimental import ValueGuidedRLPipeline
+
+
+config = {
+ "n_samples": 64,
+ "horizon": 32,
+ "num_inference_steps": 20,
+ "n_guide_steps": 2, # can set to 0 for faster sampling, does not use value network
+ "scale_grad_by_std": True,
+ "scale": 0.1,
+ "eta": 0.0,
+ "t_grad_cutoff": 2,
+ "device": "cpu",
+}
+
+
+if __name__ == "__main__":
+ env_name = "hopper-medium-v2"
+ env = gym.make(env_name)
+
+ pipeline = ValueGuidedRLPipeline.from_pretrained(
+ "bglick13/hopper-medium-v2-value-function-hor32",
+ env=env,
+ )
+
+ env.seed(0)
+ obs = env.reset()
+ total_reward = 0
+ total_score = 0
+ T = 1000
+ rollout = [obs.copy()]
+ try:
+ for t in tqdm.tqdm(range(T)):
+ # call the policy
+ denorm_actions = pipeline(obs, planning_horizon=32)
+
+ # execute action in environment
+ next_observation, reward, terminal, _ = env.step(denorm_actions)
+ score = env.get_normalized_score(total_reward)
+
+ # update return
+ total_reward += reward
+ total_score += score
+ print(
+ f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}, Score: {score}, Total Score:"
+ f" {total_score}"
+ )
+
+ # save observations for rendering
+ rollout.append(next_observation.copy())
+
+ obs = next_observation
+ except KeyboardInterrupt:
+ pass
+
+ print(f"Total reward: {total_reward}")
diff --git a/diffusers/examples/research_projects/README.md b/diffusers/examples/research_projects/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..ef50d423e68ff5c641e4419bd30f84787aebf839
--- /dev/null
+++ b/diffusers/examples/research_projects/README.md
@@ -0,0 +1,14 @@
+# Research projects
+
+This folder contains various research projects using 🧨 Diffusers.
+They are not really maintained by the core maintainers of this library and often require a specific version of Diffusers that is indicated in the requirements file of each folder.
+Updating them to the most recent version of the library will require some work.
+
+To use any of them, just run the command
+
+```
+pip install -r requirements.txt
+```
+inside the folder of your choice.
+
+If you need help with any of those, please open an issue where you directly ping the author(s), as indicated at the top of the README of each folder.
diff --git a/diffusers/examples/research_projects/colossalai/README.md b/diffusers/examples/research_projects/colossalai/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..be94950b772eaaac994104f0787d9ffbfc769f63
--- /dev/null
+++ b/diffusers/examples/research_projects/colossalai/README.md
@@ -0,0 +1,111 @@
+# [DreamBooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth) by [colossalai](https://github.com/hpcaitech/ColossalAI.git)
+
+[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text-to-image models like Stable Diffusion given just a few (3~5) images of a subject.
+The `train_dreambooth_colossalai.py` script shows how to implement the training procedure and adapt it for Stable Diffusion.
+
+By keeping model data in both CPU and GPU memory and moving it to the computing device only when necessary, [Gemini](https://www.colossalai.org/docs/advanced_tutorials/meet_gemini), the heterogeneous memory manager of [Colossal-AI](https://github.com/hpcaitech/ColossalAI), can break through the GPU memory wall by using GPU and CPU memory (CPU DRAM or NVMe SSD) together. Moreover, the model scale can be increased further by combining heterogeneous training with other parallel approaches, such as data, tensor, and pipeline parallelism.
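+
+As a rough illustration, the training script in this diff wraps the UNet with Gemini's ZeRO DDP as sketched below (a minimal sketch mirroring the `gemini_zero_dpp` helper in `train_dreambooth_colossalai.py`; the exact API can differ between ColossalAI versions):
+
+```python
+import torch
+from colossalai.nn.parallel import GeminiDDP
+from colossalai.utils import get_current_device
+
+
+def wrap_with_gemini(model: torch.nn.Module, placement_policy: str = "auto") -> torch.nn.Module:
+    # Gemini decides, chunk by chunk, whether parameters live in GPU or CPU memory
+    # and moves them to the compute device only when they are needed.
+    return GeminiDDP(
+        model,
+        device=get_current_device(),
+        placement_policy=placement_policy,  # "cpu", "cuda", or "auto"
+        pin_memory=True,
+        search_range_mb=64,
+    )
+```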
+
+## Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+```bash
+pip install -r requirements.txt
+```
+
+## Install [ColossalAI](https://github.com/hpcaitech/ColossalAI.git)
+
+**From PyPI**
+```bash
+pip install colossalai
+```
+
+**From source**
+
+```bash
+git clone https://github.com/hpcaitech/ColossalAI.git
+cd ColossalAI
+
+# install colossalai
+pip install .
+```
+
+## Dataset for Teyvat BLIP captions
+The dataset used to train the [Teyvat characters text-to-image model](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion).
+
+BLIP-generated captions for character images from the [Genshin Impact Fandom wiki](https://genshin-impact.fandom.com/wiki/Character#Playable_Characters) and the [BiliGame wiki for Genshin Impact](https://wiki.biligame.com/ys/%E8%A7%92%E8%89%B2).
+
+Each row of the dataset contains `image` and `text` keys. `image` is a PIL PNG of varying size, and `text` is the accompanying caption. Only a train split is provided.
+
+The `text` includes the tags `Teyvat`, `Name`, `Element`, `Weapon`, `Region`, `Model type`, and `Description`; the `Description` is generated with the [pre-trained BLIP model](https://github.com/salesforce/BLIP).
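+
+A quick way to inspect a sample with 🤗 Datasets is sketched below (the dataset path is a placeholder; substitute the actual location or Hub id of the Teyvat BLIP captions dataset):
+
+```python
+from datasets import load_dataset
+
+# "path-or-id-of-teyvat-blip-captions" is a placeholder, not a real dataset id.
+dataset = load_dataset("path-or-id-of-teyvat-blip-captions", split="train")
+
+sample = dataset[0]
+print(sample["text"])  # caption containing the Teyvat/Name/Element/... tags
+sample["image"]        # PIL image of varying size
+```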
+
+## Training
+
+The argument `placement` can be `cpu`, `auto`, or `cuda`. With `cpu`, the required GPU memory can be reduced to about 4GB at the cost of slower training; with `cuda`, GPU memory usage can also be roughly halved while keeping training fast; with `auto`, a more balanced trade-off between speed and memory is chosen.
+
+**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export INSTANCE_DIR="path-to-instance-images"
+export OUTPUT_DIR="path-to-save-model"
+
+torchrun --nproc_per_node 2 train_dreambooth_colossalai.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --instance_prompt="a photo of sks dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --learning_rate=5e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --max_train_steps=400 \
+ --placement="cuda"
+```
+
+
+### Training with prior-preservation loss
+
+Prior-preservation is used to avoid overfitting and language-drift. Refer to the paper to learn more about it. For prior-preservation we first generate images using the model with a class prompt and then use those during training along with our data.
+According to the paper, it's recommended to generate `num_epochs * num_samples` images for prior-preservation. 200-300 works well for most cases. The `num_class_images` flag sets the number of images to generate with the class prompt. You can place existing images in `class_data_dir`, and the training script will generate any additional images so that `num_class_images` are present in `class_data_dir` during training time.
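+
+Under the hood, the class-image generation step of the script works roughly as sketched below (a simplified sketch of the logic in `train_dreambooth_colossalai.py`, not a drop-in replacement); the full training command follows:
+
+```python
+from pathlib import Path
+
+import torch
+from diffusers import DiffusionPipeline
+
+class_images_dir = Path("path-to-class-images")  # placeholder
+class_images_dir.mkdir(parents=True, exist_ok=True)
+cur_class_images = len(list(class_images_dir.iterdir()))
+
+num_class_images = 200
+if cur_class_images < num_class_images:
+    pipeline = DiffusionPipeline.from_pretrained(
+        "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16, safety_checker=None
+    ).to("cuda")
+    for i in range(num_class_images - cur_class_images):
+        # Generate the missing class images with the class prompt.
+        image = pipeline("a photo of dog").images[0]
+        image.save(class_images_dir / f"{cur_class_images + i}.jpg")
+```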
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export INSTANCE_DIR="path-to-instance-images"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+torchrun --nproc_per_node 2 train_dreambooth_colossalai.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --class_data_dir=$CLASS_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --with_prior_preservation --prior_loss_weight=1.0 \
+ --instance_prompt="a photo of sks dog" \
+ --class_prompt="a photo of dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --learning_rate=5e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --max_train_steps=800 \
+ --placement="cuda"
+```
+
+## Inference
+
+Once you have trained a model using the above command, inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the identifier (e.g. `sks` in the example above) in your prompt.
+
+```python
+from diffusers import StableDiffusionPipeline
+import torch
+
+model_id = "path-to-save-model"
+pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+
+prompt = "A photo of sks dog in a bucket"
+image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
+
+image.save("dog-bucket.png")
+```
diff --git a/diffusers/examples/research_projects/colossalai/inference.py b/diffusers/examples/research_projects/colossalai/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b115c2d2b8f5bcdb3a0c053a6c71b91a965c573
--- /dev/null
+++ b/diffusers/examples/research_projects/colossalai/inference.py
@@ -0,0 +1,12 @@
+import torch
+
+from diffusers import StableDiffusionPipeline
+
+
+model_id = "path-to-your-trained-model"
+pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+
+prompt = "A photo of sks dog in a bucket"
+image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
+
+image.save("dog-bucket.png")
diff --git a/diffusers/examples/research_projects/colossalai/requirement.txt b/diffusers/examples/research_projects/colossalai/requirement.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f80467dcff521bfed1fa72109e1e01e92ab05646
--- /dev/null
+++ b/diffusers/examples/research_projects/colossalai/requirement.txt
@@ -0,0 +1,7 @@
+diffusers
+torch
+torchvision
+ftfy
+tensorboard
+Jinja2
+transformers
\ No newline at end of file
diff --git a/diffusers/examples/research_projects/colossalai/train_dreambooth_colossalai.py b/diffusers/examples/research_projects/colossalai/train_dreambooth_colossalai.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cebd2b811759f71dbdf00ee6f131a459bc9a1d5
--- /dev/null
+++ b/diffusers/examples/research_projects/colossalai/train_dreambooth_colossalai.py
@@ -0,0 +1,671 @@
+import argparse
+import math
+import os
+from pathlib import Path
+
+import colossalai
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
+from colossalai.nn.parallel.utils import get_static_torch_model
+from colossalai.utils import get_current_device
+from colossalai.utils.model.colo_init_context import ColoInitContext
+from huggingface_hub import create_repo, upload_folder
+from huggingface_hub.utils import insecure_hashlib
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
+from diffusers.optimization import get_scheduler
+
+
+disable_existing_loggers()
+logger = get_dist_logger()
+
+
+def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ revision=args.revision,
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "RobertaSeriesModelWithTransformation":
+ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
+
+ return RobertaSeriesModelWithTransformation
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ required=True,
+ help="A folder containing the training data of instance images.",
+ )
+ parser.add_argument(
+ "--class_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of class images.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default="a photo of sks dog",
+ required=False,
+ help="The prompt with identifier specifying the instance",
+ )
+ parser.add_argument(
+ "--class_prompt",
+ type=str,
+ default=None,
+ help="The prompt to specify images in the same class as provided instance images.",
+ )
+ parser.add_argument(
+ "--with_prior_preservation",
+ default=False,
+ action="store_true",
+ help="Flag to add prior preservation loss.",
+ )
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+ parser.add_argument(
+ "--num_class_images",
+ type=int,
+ default=100,
+ help=(
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
+ " class_data_dir, additional images will be sampled with class_prompt."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="text-inversion-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--placement",
+ type=str,
+ default="cpu",
+ help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-6,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.with_prior_preservation:
+ if args.class_data_dir is None:
+ raise ValueError("You must specify a data directory for class images.")
+ if args.class_prompt is None:
+ raise ValueError("You must specify prompt for class images.")
+ else:
+ if args.class_data_dir is not None:
+ logger.warning("You need not use --class_data_dir without --with_prior_preservation.")
+ if args.class_prompt is not None:
+ logger.warning("You need not use --class_prompt without --with_prior_preservation.")
+
+ return args
+
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+ It pre-processes the images and the tokenizes prompts.
+ """
+
+ def __init__(
+ self,
+ instance_data_root,
+ instance_prompt,
+ tokenizer,
+ class_data_root=None,
+ class_prompt=None,
+ size=512,
+ center_crop=False,
+ ):
+ self.size = size
+ self.center_crop = center_crop
+ self.tokenizer = tokenizer
+
+ self.instance_data_root = Path(instance_data_root)
+ if not self.instance_data_root.exists():
+ raise ValueError("Instance images root doesn't exists.")
+
+ self.instance_images_path = list(Path(instance_data_root).iterdir())
+ self.num_instance_images = len(self.instance_images_path)
+ self.instance_prompt = instance_prompt
+ self._length = self.num_instance_images
+
+ if class_data_root is not None:
+ self.class_data_root = Path(class_data_root)
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
+ self.class_images_path = list(self.class_data_root.iterdir())
+ self.num_class_images = len(self.class_images_path)
+ self._length = max(self.num_class_images, self.num_instance_images)
+ self.class_prompt = class_prompt
+ else:
+ self.class_data_root = None
+
+ self.image_transforms = transforms.Compose(
+ [
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
+ if not instance_image.mode == "RGB":
+ instance_image = instance_image.convert("RGB")
+ example["instance_images"] = self.image_transforms(instance_image)
+ example["instance_prompt_ids"] = self.tokenizer(
+ self.instance_prompt,
+ padding="do_not_pad",
+ truncation=True,
+ max_length=self.tokenizer.model_max_length,
+ ).input_ids
+
+ if self.class_data_root:
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
+ if not class_image.mode == "RGB":
+ class_image = class_image.convert("RGB")
+ example["class_images"] = self.image_transforms(class_image)
+ example["class_prompt_ids"] = self.tokenizer(
+ self.class_prompt,
+ padding="do_not_pad",
+ truncation=True,
+ max_length=self.tokenizer.model_max_length,
+ ).input_ids
+
+ return example
+
+
+class PromptDataset(Dataset):
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
+
+ def __init__(self, prompt, num_samples):
+ self.prompt = prompt
+ self.num_samples = num_samples
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ example = {}
+ example["prompt"] = self.prompt
+ example["index"] = index
+ return example
+
+
+# Gemini + ZeRO DDP
+def gemini_zero_dpp(model: torch.nn.Module, placememt_policy: str = "auto"):
+ from colossalai.nn.parallel import GeminiDDP
+
+ model = GeminiDDP(
+ model, device=get_current_device(), placement_policy=placememt_policy, pin_memory=True, search_range_mb=64
+ )
+ return model
+
+
+def main(args):
+ if args.seed is None:
+ colossalai.launch_from_torch(config={})
+ else:
+ colossalai.launch_from_torch(config={}, seed=args.seed)
+
+ local_rank = gpc.get_local_rank(ParallelMode.DATA)
+ world_size = gpc.get_world_size(ParallelMode.DATA)
+
+ if args.with_prior_preservation:
+ class_images_dir = Path(args.class_data_dir)
+ if not class_images_dir.exists():
+ class_images_dir.mkdir(parents=True)
+ cur_class_images = len(list(class_images_dir.iterdir()))
+
+ if cur_class_images < args.num_class_images:
+ torch_dtype = torch.float16 if get_current_device() == "cuda" else torch.float32
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ torch_dtype=torch_dtype,
+ safety_checker=None,
+ revision=args.revision,
+ )
+ pipeline.set_progress_bar_config(disable=True)
+
+ num_new_images = args.num_class_images - cur_class_images
+ logger.info(f"Number of class images to sample: {num_new_images}.")
+
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+ pipeline.to(get_current_device())
+
+ for example in tqdm(
+ sample_dataloader,
+ desc="Generating class images",
+ disable=not local_rank == 0,
+ ):
+ images = pipeline(example["prompt"]).images
+
+ for i, image in enumerate(images):
+ hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+ image.save(image_filename)
+
+ del pipeline
+
+ # Handle the repository creation
+ if local_rank == 0:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizer
+ if args.tokenizer_name:
+ logger.info(f"Loading tokenizer from {args.tokenizer_name}", ranks=[0])
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.tokenizer_name,
+ revision=args.revision,
+ use_fast=False,
+ )
+ elif args.pretrained_model_name_or_path:
+ logger.info("Loading tokenizer from pretrained model", ranks=[0])
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ use_fast=False,
+ )
+ # import correct text encoder class
+ text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path)
+
+ # Load models and create wrapper for stable diffusion
+
+ logger.info(f"Loading text_encoder from {args.pretrained_model_name_or_path}", ranks=[0])
+
+ text_encoder = text_encoder_cls.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ revision=args.revision,
+ )
+
+ logger.info(f"Loading AutoencoderKL from {args.pretrained_model_name_or_path}", ranks=[0])
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="vae",
+ revision=args.revision,
+ )
+
+ logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
+ with ColoInitContext(device=get_current_device()):
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False
+ )
+
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ if args.scale_lr:
+ args.learning_rate = args.learning_rate * args.train_batch_size * world_size
+
+ unet = gemini_zero_dpp(unet, args.placement)
+
+ # config optimizer for colossalai zero
+ optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)
+
+ # load noise_scheduler
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+
+ # prepare dataset
+ logger.info(f"Prepare dataset from {args.instance_data_dir}", ranks=[0])
+ train_dataset = DreamBoothDataset(
+ instance_data_root=args.instance_data_dir,
+ instance_prompt=args.instance_prompt,
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
+ class_prompt=args.class_prompt,
+ tokenizer=tokenizer,
+ size=args.resolution,
+ center_crop=args.center_crop,
+ )
+
+ def collate_fn(examples):
+ input_ids = [example["instance_prompt_ids"] for example in examples]
+ pixel_values = [example["instance_images"] for example in examples]
+
+ # Concat class and instance examples for prior preservation.
+ # We do this to avoid doing two forward passes.
+ if args.with_prior_preservation:
+ input_ids += [example["class_prompt_ids"] for example in examples]
+ pixel_values += [example["class_images"] for example in examples]
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ input_ids = tokenizer.pad(
+ {"input_ids": input_ids},
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ return_tensors="pt",
+ ).input_ids
+
+ batch = {
+ "input_ids": input_ids,
+ "pixel_values": pixel_values,
+ }
+ return batch
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader))
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps,
+ num_training_steps=args.max_train_steps,
+ )
+ weight_dtype = torch.float32
+ if args.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif args.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move text_encode and vae to gpu.
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ vae.to(get_current_device(), dtype=weight_dtype)
+ text_encoder.to(get_current_device(), dtype=weight_dtype)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader))
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # Train!
+ total_batch_size = args.train_batch_size * world_size
+
+ logger.info("***** Running training *****", ranks=[0])
+ logger.info(f" Num examples = {len(train_dataset)}", ranks=[0])
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}", ranks=[0])
+ logger.info(f" Num Epochs = {args.num_train_epochs}", ranks=[0])
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}", ranks=[0])
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}", ranks=[0])
+ logger.info(f" Total optimization steps = {args.max_train_steps}", ranks=[0])
+
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(range(args.max_train_steps), disable=not local_rank == 0)
+ progress_bar.set_description("Steps")
+ global_step = 0
+
+ torch.cuda.synchronize()
+ for epoch in range(args.num_train_epochs):
+ unet.train()
+ for step, batch in enumerate(train_dataloader):
+ torch.cuda.reset_peak_memory_stats()
+ # Move batch to gpu
+ for key, value in batch.items():
+ batch[key] = value.to(get_current_device(), non_blocking=True)
+
+ # Convert images to latent space
+ optimizer.zero_grad()
+
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
+ latents = latents * 0.18215
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+ # Predict the noise residual
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ if args.with_prior_preservation:
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+ target, target_prior = torch.chunk(target, 2, dim=0)
+
+ # Compute instance loss
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
+
+ # Compute prior loss
+ prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+
+ # Add the prior loss to the instance loss.
+ loss = loss + args.prior_loss_weight * prior_loss
+ else:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+ optimizer.backward(loss)
+
+ optimizer.step()
+ lr_scheduler.step()
+ logger.info(f"max GPU_mem cost is {torch.cuda.max_memory_allocated()/2**20} MB", ranks=[0])
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ progress_bar.update(1)
+ global_step += 1
+ logs = {
+ "loss": loss.detach().item(),
+ "lr": optimizer.param_groups[0]["lr"],
+ } # lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step % args.save_steps == 0:
+ torch.cuda.synchronize()
+ torch_unet = get_static_torch_model(unet)
+ if local_rank == 0:
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ unet=torch_unet,
+ revision=args.revision,
+ )
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ pipeline.save_pretrained(save_path)
+ logger.info(f"Saving model checkpoint to {save_path}", ranks=[0])
+ if global_step >= args.max_train_steps:
+ break
+
+ torch.cuda.synchronize()
+ unet = get_static_torch_model(unet)
+
+ if local_rank == 0:
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ unet=unet,
+ revision=args.revision,
+ )
+
+ pipeline.save_pretrained(args.output_dir)
+ logger.info(f"Saving model checkpoint to {args.output_dir}", ranks=[0])
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/research_projects/controlnet/train_controlnet_webdataset.py b/diffusers/examples/research_projects/controlnet/train_controlnet_webdataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..3122a3952b33e2a1a06108c340a9bc6bc7523f05
--- /dev/null
+++ b/diffusers/examples/research_projects/controlnet/train_controlnet_webdataset.py
@@ -0,0 +1,1460 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import functools
+import gc
+import itertools
+import json
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+from typing import List, Optional, Union
+
+import accelerate
+import cv2
+import numpy as np
+import torch
+import torch.utils.checkpoint
+import transformers
+import webdataset as wds
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from braceexpand import braceexpand
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from PIL import Image
+from torch.utils.data import default_collate
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, DPTFeatureExtractor, DPTForDepthEstimation, PretrainedConfig
+from webdataset.tariterators import (
+ base_plus_ext,
+ tar_file_expander,
+ url_opener,
+ valid_sample,
+)
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ ControlNetModel,
+ EulerDiscreteScheduler,
+ StableDiffusionXLControlNetPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+MAX_SEQ_LENGTH = 77
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.18.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def filter_keys(key_set):
+ def _f(dictionary):
+ return {k: v for k, v in dictionary.items() if k in key_set}
+
+ return _f
+
+
+def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
+ """Return function over iterator that groups key, value pairs into samples.
+
+ :param keys: function that splits the key into key and extension (base_plus_ext)
+ :param lcase: convert suffixes to lower case (Default value = True)
+ """
+ current_sample = None
+ for filesample in data:
+ assert isinstance(filesample, dict)
+ fname, value = filesample["fname"], filesample["data"]
+ prefix, suffix = keys(fname)
+ if prefix is None:
+ continue
+ if lcase:
+ suffix = suffix.lower()
+ # FIXME webdataset version throws if suffix in current_sample, but we have a potential for
+ # this happening in the current LAION400m dataset if a tar ends with same prefix as the next
+ # begins, rare, but can happen since prefix aren't unique across tar files in that dataset
+ if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample:
+ if valid_sample(current_sample):
+ yield current_sample
+ current_sample = {"__key__": prefix, "__url__": filesample["__url__"]}
+ if suffixes is None or suffix in suffixes:
+ current_sample[suffix] = value
+ if valid_sample(current_sample):
+ yield current_sample
+
+
+def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue):
+ # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
+ streams = url_opener(src, handler=handler)
+ files = tar_file_expander(streams, handler=handler)
+ samples = group_by_keys_nothrow(files, handler=handler)
+ return samples
+
+
+def control_transform(image):
+ image = np.array(image)
+
+ low_threshold = 100
+ high_threshold = 200
+
+ image = cv2.Canny(image, low_threshold, high_threshold)
+ image = image[:, :, None]
+ image = np.concatenate([image, image, image], axis=2)
+ control_image = Image.fromarray(image)
+ return control_image
+
+
+def canny_image_transform(example, resolution=1024):
+ image = example["image"]
+ image = transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR)(image)
+ # get crop coordinates
+ c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))
+ image = transforms.functional.crop(image, c_top, c_left, resolution, resolution)
+ control_image = control_transform(image)
+
+ image = transforms.ToTensor()(image)
+ image = transforms.Normalize([0.5], [0.5])(image)
+ control_image = transforms.ToTensor()(control_image)
+
+ example["image"] = image
+ example["control_image"] = control_image
+ example["crop_coords"] = (c_top, c_left)
+
+ return example
+
+
+def depth_image_transform(example, feature_extractor, resolution=1024):
+ image = example["image"]
+ image = transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR)(image)
+ # get crop coordinates
+ c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))
+ image = transforms.functional.crop(image, c_top, c_left, resolution, resolution)
+
+ control_image = feature_extractor(images=image, return_tensors="pt").pixel_values.squeeze(0)
+
+ image = transforms.ToTensor()(image)
+ image = transforms.Normalize([0.5], [0.5])(image)
+
+ example["image"] = image
+ example["control_image"] = control_image
+ example["crop_coords"] = (c_top, c_left)
+
+ return example
+
+
+class WebdatasetFilter:
+ def __init__(self, min_size=1024, max_pwatermark=0.5):
+ self.min_size = min_size
+ self.max_pwatermark = max_pwatermark
+
+ def __call__(self, x):
+ try:
+ if "json" in x:
+ x_json = json.loads(x["json"])
+ filter_size = (x_json.get("original_width", 0.0) or 0.0) >= self.min_size and x_json.get(
+ "original_height", 0
+ ) >= self.min_size
+ filter_watermark = (x_json.get("pwatermark", 1.0) or 1.0) <= self.max_pwatermark
+ return filter_size and filter_watermark
+ else:
+ return False
+ except Exception:
+ return False
+
+
+class Text2ImageDataset:
+ def __init__(
+ self,
+ train_shards_path_or_url: Union[str, List[str]],
+ eval_shards_path_or_url: Union[str, List[str]],
+ num_train_examples: int,
+ per_gpu_batch_size: int,
+ global_batch_size: int,
+ num_workers: int,
+ resolution: int = 256,
+ center_crop: bool = True,
+ random_flip: bool = False,
+ shuffle_buffer_size: int = 1000,
+ pin_memory: bool = False,
+ persistent_workers: bool = False,
+ control_type: str = "canny",
+ feature_extractor: Optional[DPTFeatureExtractor] = None,
+ ):
+ if not isinstance(train_shards_path_or_url, str):
+ train_shards_path_or_url = [list(braceexpand(urls)) for urls in train_shards_path_or_url]
+ # flatten list using itertools
+ train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url))
+
+ if not isinstance(eval_shards_path_or_url, str):
+ eval_shards_path_or_url = [list(braceexpand(urls)) for urls in eval_shards_path_or_url]
+ # flatten list using itertools
+ eval_shards_path_or_url = list(itertools.chain.from_iterable(eval_shards_path_or_url))
+
+ def get_orig_size(json):
+ return (int(json.get("original_width", 0.0)), int(json.get("original_height", 0.0)))
+
+ if control_type == "canny":
+ image_transform = functools.partial(canny_image_transform, resolution=resolution)
+ elif control_type == "depth":
+ image_transform = functools.partial(
+ depth_image_transform, feature_extractor=feature_extractor, resolution=resolution
+ )
+
+ processing_pipeline = [
+ wds.decode("pil", handler=wds.ignore_and_continue),
+ wds.rename(
+ image="jpg;png;jpeg;webp",
+ control_image="jpg;png;jpeg;webp",
+ text="text;txt;caption",
+ orig_size="json",
+ handler=wds.warn_and_continue,
+ ),
+ wds.map(filter_keys({"image", "control_image", "text", "orig_size"})),
+ wds.map_dict(orig_size=get_orig_size),
+ wds.map(image_transform),
+ wds.to_tuple("image", "control_image", "text", "orig_size", "crop_coords"),
+ ]
+
+ # Create train dataset and loader
+ pipeline = [
+ wds.ResampledShards(train_shards_path_or_url),
+ tarfile_to_samples_nothrow,
+ wds.select(WebdatasetFilter(min_size=512)),
+ wds.shuffle(shuffle_buffer_size),
+ *processing_pipeline,
+ wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
+ ]
+
+ num_worker_batches = math.ceil(num_train_examples / (global_batch_size * num_workers)) # per dataloader worker
+ num_batches = num_worker_batches * num_workers
+ num_samples = num_batches * global_batch_size
+
+ # each worker is iterating over this
+ self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)
+ self._train_dataloader = wds.WebLoader(
+ self._train_dataset,
+ batch_size=None,
+ shuffle=False,
+ num_workers=num_workers,
+ pin_memory=pin_memory,
+ persistent_workers=persistent_workers,
+ )
+ # add meta-data to dataloader instance for convenience
+ self._train_dataloader.num_batches = num_batches
+ self._train_dataloader.num_samples = num_samples
+
+ # Create eval dataset and loader
+ pipeline = [
+ wds.SimpleShardList(eval_shards_path_or_url),
+ wds.split_by_worker,
+ wds.tarfile_to_samples(handler=wds.ignore_and_continue),
+ *processing_pipeline,
+ wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
+ ]
+ self._eval_dataset = wds.DataPipeline(*pipeline)
+ self._eval_dataloader = wds.WebLoader(
+ self._eval_dataset,
+ batch_size=None,
+ shuffle=False,
+ num_workers=num_workers,
+ pin_memory=pin_memory,
+ persistent_workers=persistent_workers,
+ )
+
+ @property
+ def train_dataset(self):
+ return self._train_dataset
+
+ @property
+ def train_dataloader(self):
+ return self._train_dataloader
+
+ @property
+ def eval_dataset(self):
+ return self._eval_dataset
+
+ @property
+ def eval_dataloader(self):
+ return self._eval_dataloader
+
+
+def image_grid(imgs, rows, cols):
+ assert len(imgs) == rows * cols
+
+ w, h = imgs[0].size
+ grid = Image.new("RGB", size=(cols * w, rows * h))
+
+ for i, img in enumerate(imgs):
+ grid.paste(img, box=(i % cols * w, i // cols * h))
+ return grid
+
+
+def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step):
+ logger.info("Running validation... ")
+
+ controlnet = accelerator.unwrap_model(controlnet)
+
+ pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ unet=unet,
+ controlnet=controlnet,
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ # pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ if len(args.validation_image) == len(args.validation_prompt):
+ validation_images = args.validation_image
+ validation_prompts = args.validation_prompt
+ elif len(args.validation_image) == 1:
+ validation_images = args.validation_image * len(args.validation_prompt)
+ validation_prompts = args.validation_prompt
+ elif len(args.validation_prompt) == 1:
+ validation_images = args.validation_image
+ validation_prompts = args.validation_prompt * len(args.validation_image)
+ else:
+ raise ValueError(
+ "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
+ )
+
+ image_logs = []
+
+ for validation_prompt, validation_image in zip(validation_prompts, validation_images):
+ validation_image = Image.open(validation_image).convert("RGB")
+ validation_image = validation_image.resize((args.resolution, args.resolution))
+
+ images = []
+
+ for _ in range(args.num_validation_images):
+ with torch.autocast("cuda"):
+ image = pipeline(
+ validation_prompt, image=validation_image, num_inference_steps=20, generator=generator
+ ).images[0]
+ images.append(image)
+
+ image_logs.append(
+ {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
+ )
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+
+ formatted_images = []
+
+ formatted_images.append(np.asarray(validation_image))
+
+ for image in images:
+ formatted_images.append(np.asarray(image))
+
+ formatted_images = np.stack(formatted_images)
+
+ tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ formatted_images = []
+
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+
+ formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
+
+ for image in images:
+ image = wandb.Image(image, caption=validation_prompt)
+ formatted_images.append(image)
+
+ tracker.log({"validation": formatted_images})
+ else:
+ logger.warn(f"image logging not implemented for {tracker.name}")
+
+ del pipeline
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ return image_logs
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):
+ img_str = ""
+ if image_logs is not None:
+ img_str = "You can find some example images below.\n"
+ for i, log in enumerate(image_logs):
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+ validation_image.save(os.path.join(repo_folder, "image_control.png"))
+ img_str += f"prompt: {validation_prompt}\n"
+ images = [validation_image] + images
+ image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
+ img_str += f"![images_{i})](./images_{i}.png)\n"
+
+ yaml = f"""
+---
+license: creativeml-openrail-m
+base_model: {base_model}
+tags:
+- stable-diffusion-xl
+- stable-diffusion-xl-diffusers
+- text-to-image
+- diffusers
+- controlnet
+inference: true
+---
+ """
+ model_card = f"""
+# controlnet-{repo_id}
+
+These are controlnet weights trained on {base_model} with a new type of conditioning.
+{img_str}
+"""
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
+ f.write(yaml + model_card)
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--controlnet_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
+ " If not specified controlnet weights are initialized from unet.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help=(
+ "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
+ " float32 precision."
+ ),
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="controlnet-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--crops_coords_top_left_h",
+ type=int,
+ default=0,
+ help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
+ )
+ parser.add_argument(
+ "--crops_coords_top_left_w",
+ type=int,
+ default=0,
+ help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
+ "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
+ "instructions."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=3,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-6,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=1,
+ help=("Number of subprocesses to use for data loading."),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--set_grads_to_none",
+ action="store_true",
+ help=(
+ "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
+ " behaviors, so disable this argument if it causes any problems. More info:"
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
+ ),
+ )
+ parser.add_argument(
+ "--train_shards_path_or_url",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--eval_shards_path_or_url",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
+ )
+ parser.add_argument(
+ "--conditioning_image_column",
+ type=str,
+ default="conditioning_image",
+ help="The column of the dataset containing the controlnet conditioning image.",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--proportion_empty_prompts",
+ type=float,
+ default=0,
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ nargs="+",
+ help=(
+ "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
+ " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
+ " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
+ ),
+ )
+ parser.add_argument(
+ "--validation_image",
+ type=str,
+ default=None,
+ nargs="+",
+ help=(
+ "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
+ " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
+ " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
+ " `--validation_image` that will be used with all `--validation_prompt`s."
+ ),
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=100,
+ help=(
+ "Run validation every X steps. Validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
+ " and logging the images."
+ ),
+ )
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="sd_xl_train_controlnet",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers for"
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+ parser.add_argument(
+ "--control_type",
+ type=str,
+ default="canny",
+ help=("The type of controlnet conditioning image to use. One of `canny`, `depth`" " Defaults to `canny`."),
+ )
+ parser.add_argument(
+ "--transformer_layers_per_block",
+ type=str,
+ default=None,
+ help=("The number of layers per block in the transformer. If None, defaults to" " `args.transformer_layers`."),
+ )
+ parser.add_argument(
+ "--old_style_controlnet",
+ action="store_true",
+ default=False,
+ help=(
+ "Use the old style controlnet, which is a single transformer layer with"
+ " a single head. Defaults to False."
+ ),
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
+
+ if args.validation_prompt is not None and args.validation_image is None:
+ raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
+
+ if args.validation_prompt is None and args.validation_image is not None:
+ raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
+
+ if (
+ args.validation_image is not None
+ and args.validation_prompt is not None
+ and len(args.validation_image) != 1
+ and len(args.validation_prompt) != 1
+ and len(args.validation_image) != len(args.validation_prompt)
+ ):
+ raise ValueError(
+ "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
+ " or the same number of `--validation_prompt`s and `--validation_image`s"
+ )
+
+ if args.resolution % 8 != 0:
+ raise ValueError(
+ "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
+ )
+
+ return args
+
+
+# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
+def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train=True):
+ prompt_embeds_list = []
+
+ captions = []
+ for caption in prompt_batch:
+ if random.random() < proportion_empty_prompts:
+ captions.append("")
+ elif isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+
+ with torch.no_grad():
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+ text_inputs = tokenizer(
+ captions,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_embeds = text_encoder(
+ text_input_ids.to(text_encoder.device),
+ output_hidden_states=True,
+ )
+
+ # We are only interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+ prompt_embeds_list.append(prompt_embeds)
+
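+ # SDXL uses two text encoders; collect the penultimate hidden states of each and concatenate them along the feature dimension.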
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
+ return prompt_embeds, pooled_prompt_embeds
+
+
+def main(args):
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
+ exist_ok=True,
+ token=args.hub_token,
+ private=True,
+ ).repo_id
+
+ # Load the tokenizers
+ tokenizer_one = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
+ )
+ tokenizer_two = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
+ )
+
+ # import correct text encoder classes
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+ )
+
+ # Load scheduler and models
+ # noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ noise_scheduler = EulerDiscreteScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
+ )
+ vae_path = (
+ args.pretrained_model_name_or_path
+ if args.pretrained_vae_model_name_or_path is None
+ else args.pretrained_vae_model_name_or_path
+ )
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ )
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, use_auth_token=True
+ )
+
+ if args.controlnet_model_name_or_path:
+ logger.info("Loading existing controlnet weights")
+ pre_controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path)
+ else:
+ logger.info("Initializing controlnet weights from unet")
+ pre_controlnet = ControlNetModel.from_unet(unet)
+
+ if args.transformer_layers_per_block is not None:
+ transformer_layers_per_block = [int(x) for x in args.transformer_layers_per_block.split(",")]
+ down_block_types = ["DownBlock2D" if l == 0 else "CrossAttnDownBlock2D" for l in transformer_layers_per_block]
+ controlnet = ControlNetModel.from_config(
+ pre_controlnet.config,
+ down_block_types=down_block_types,
+ transformer_layers_per_block=transformer_layers_per_block,
+ )
+ controlnet.load_state_dict(pre_controlnet.state_dict(), strict=False)
+ del pre_controlnet
+ else:
+ controlnet = pre_controlnet
+
+ if args.control_type == "depth":
+ feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
+ depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas")
+ depth_model.requires_grad_(False)
+ else:
+ feature_extractor = None
+ depth_model = None
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
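+ # Pop the weights so `accelerator.save_state` does not also dump the raw state dicts; each tracked model is saved in diffusers format under a `controlnet` subfolder instead.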
+ i = len(weights) - 1
+
+ while len(weights) > 0:
+ weights.pop()
+ model = models[i]
+
+ sub_dir = "controlnet"
+ model.save_pretrained(os.path.join(output_dir, sub_dir))
+
+ i -= 1
+
+ def load_model_hook(models, input_dir):
+ while len(models) > 0:
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = ControlNetModel.from_pretrained(input_dir, subfolder="controlnet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ vae.requires_grad_(False)
+ unet.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+ controlnet.train()
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ controlnet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ if args.gradient_checkpointing:
+ controlnet.enable_gradient_checkpointing()
+
+ # Check that all trainable models are in full precision
+ low_precision_error_string = (
+ " Please make sure to always have all model weights in full float32 precision when starting training - even if"
+ " doing mixed precision training, copy of the weights should still be float32."
+ )
+
+ if accelerator.unwrap_model(controlnet).dtype != torch.float32:
+ raise ValueError(
+ f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet).dtype}. {low_precision_error_string}"
+ )
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimizer creation
+ params_to_optimize = controlnet.parameters()
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
+ # as these models are only used for inference; keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move vae, unet and text_encoder to device and cast to weight_dtype
+ # The VAE is in float32 to avoid NaN losses.
+ if args.pretrained_vae_model_name_or_path is not None:
+ vae.to(accelerator.device, dtype=weight_dtype)
+ else:
+ vae.to(accelerator.device, dtype=torch.float32)
+ unet.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+ if args.control_type == "depth":
+ depth_model.to(accelerator.device, dtype=weight_dtype)
+
+ # Here, we compute not just the text embeddings but also the additional embeddings
+ # needed for the SD XL UNet to operate.
+ def compute_embeddings(
+ prompt_batch, original_sizes, crop_coords, proportion_empty_prompts, text_encoders, tokenizers, is_train=True
+ ):
+ target_size = (args.resolution, args.resolution)
+ original_sizes = list(map(list, zip(*original_sizes)))
+ crops_coords_top_left = list(map(list, zip(*crop_coords)))
+
+ original_sizes = torch.tensor(original_sizes, dtype=torch.long)
+ crops_coords_top_left = torch.tensor(crops_coords_top_left, dtype=torch.long)
+
+ # crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
+ prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train
+ )
+ add_text_embeds = pooled_prompt_embeds
+
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
+ # add_time_ids = list(crops_coords_top_left + target_size)
+ add_time_ids = list(target_size)
+ add_time_ids = torch.tensor([add_time_ids])
+ add_time_ids = add_time_ids.repeat(len(prompt_batch), 1)
+ # add_time_ids = torch.cat([torch.tensor(original_sizes, dtype=torch.long), add_time_ids], dim=-1)
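+ # Each row of `add_time_ids` is the SDXL micro-conditioning tuple (original_size, crops_coords_top_left, target_size).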
+ add_time_ids = torch.cat([original_sizes, crops_coords_top_left, add_time_ids], dim=-1)
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype)
+
+ prompt_embeds = prompt_embeds.to(accelerator.device)
+ add_text_embeds = add_text_embeds.to(accelerator.device)
+ unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+
+ return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}
+
+ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
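+ # Map each sampled timestep to its index in the Euler scheduler's timestep schedule, gather the matching sigma, and reshape it to broadcast over the latents.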
+ sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype)
+ schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device)
+ timesteps = timesteps.to(accelerator.device)
+
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < n_dim:
+ sigma = sigma.unsqueeze(-1)
+ return sigma
+
+ dataset = Text2ImageDataset(
+ train_shards_path_or_url=args.train_shards_path_or_url,
+ eval_shards_path_or_url=args.eval_shards_path_or_url,
+ num_train_examples=args.max_train_samples,
+ per_gpu_batch_size=args.train_batch_size,
+ global_batch_size=args.train_batch_size * accelerator.num_processes,
+ num_workers=args.dataloader_num_workers,
+ resolution=args.resolution,
+ center_crop=False,
+ random_flip=False,
+ shuffle_buffer_size=1000,
+ pin_memory=True,
+ persistent_workers=True,
+ control_type=args.control_type,
+ feature_extractor=feature_extractor,
+ )
+ train_dataloader = dataset.train_dataloader
+
+ # Let's first compute all the embeddings so that we can free up the text encoders
+ # from memory.
+ text_encoders = [text_encoder_one, text_encoder_two]
+ tokenizers = [tokenizer_one, tokenizer_two]
+
+ compute_embeddings_fn = functools.partial(
+ compute_embeddings,
+ proportion_empty_prompts=args.proportion_empty_prompts,
+ text_encoders=text_encoders,
+ tokenizers=tokenizers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ controlnet, optimizer, lr_scheduler = accelerator.prepare(controlnet, optimizer, lr_scheduler)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+
+ # tensorboard cannot handle list types for config
+ tracker_config.pop("validation_prompt")
+ tracker_config.pop("validation_image")
+
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num batches each epoch = {train_dataloader.num_batches}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ image_logs = None
+ for epoch in range(first_epoch, args.num_train_epochs):
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(controlnet):
+ image, control_image, text, orig_size, crop_coords = batch
+
+ encoded_text = compute_embeddings_fn(text, orig_size, crop_coords)
+ image = image.to(accelerator.device, non_blocking=True)
+ control_image = control_image.to(accelerator.device, non_blocking=True)
+
+ if args.pretrained_vae_model_name_or_path is not None:
+ pixel_values = image.to(dtype=weight_dtype)
+ if vae.dtype != weight_dtype:
+ vae.to(dtype=weight_dtype)
+ else:
+ pixel_values = image
+
+ # latents = vae.encode(pixel_values).latent_dist.sample()
+ # encode pixel values with batch size of at most 8
+ latents = []
+ for i in range(0, pixel_values.shape[0], 8):
+ latents.append(vae.encode(pixel_values[i : i + 8]).latent_dist.sample())
+ latents = torch.cat(latents, dim=0)
+
+ latents = latents * vae.config.scaling_factor
+ if args.pretrained_vae_model_name_or_path is None:
+ latents = latents.to(weight_dtype)
+
+ if args.control_type == "depth":
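+ # Compute the depth conditioning on the fly: run the DPT depth estimator on the control image, upsample the prediction to the input image size, and min-max normalize it per sample.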
+ control_image = control_image.to(weight_dtype)
+ with torch.autocast("cuda"):
+ depth_map = depth_model(control_image).predicted_depth
+ depth_map = torch.nn.functional.interpolate(
+ depth_map.unsqueeze(1),
+ size=image.shape[2:],
+ mode="bicubic",
+ align_corners=False,
+ )
+ depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
+ depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
+ depth_map = (depth_map - depth_min) / (depth_max - depth_min)
+ control_image = (depth_map * 255.0).to(torch.uint8).float() / 255.0 # hack to match inference
+ control_image = torch.cat([control_image] * 3, dim=1)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
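+ # Scale the model input the same way EulerDiscreteScheduler.scale_model_input does: divide by sqrt(sigma^2 + 1).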
+ sigmas = get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
+ inp_noisy_latents = noisy_latents / ((sigmas**2 + 1) ** 0.5)
+
+ # ControlNet conditioning.
+ controlnet_image = control_image.to(dtype=weight_dtype)
+ prompt_embeds = encoded_text.pop("prompt_embeds")
+ down_block_res_samples, mid_block_res_sample = controlnet(
+ inp_noisy_latents,
+ timesteps,
+ encoder_hidden_states=prompt_embeds,
+ added_cond_kwargs=encoded_text,
+ controlnet_cond=controlnet_image,
+ return_dict=False,
+ )
+
+ # Predict the noise residual
+ model_pred = unet(
+ inp_noisy_latents,
+ timesteps,
+ encoder_hidden_states=prompt_embeds,
+ added_cond_kwargs=encoded_text,
+ down_block_additional_residuals=[
+ sample.to(dtype=weight_dtype) for sample in down_block_res_samples
+ ],
+ mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
+ ).sample
+
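+ # Convert the epsilon prediction to a prediction of the clean latents (x0 = x_t - sigma * eps) and weight the squared error by 1/sigma^2, so the loss below is computed against the denoised latents.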
+ model_pred = model_pred * (-sigmas) + noisy_latents
+ weighing = sigmas**-2.0
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = latents # compute loss against the denoised latents
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+ loss = torch.mean(
+ (weighing.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), 1
+ )
+ loss = loss.mean()
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = controlnet.parameters()
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=args.set_grads_to_none)
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if args.validation_prompt is not None and global_step % args.validation_steps == 0:
+ image_logs = log_validation(
+ vae, unet, controlnet, args, accelerator, weight_dtype, global_step
+ )
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ controlnet = accelerator.unwrap_model(controlnet)
+ controlnet.save_pretrained(args.output_dir)
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ image_logs=image_logs,
+ base_model=args.pretrained_model_name_or_path,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/research_projects/dreambooth_inpaint/README.md b/diffusers/examples/research_projects/dreambooth_inpaint/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..dec919587935ec6e08a08e9299d62b0edc17449c
--- /dev/null
+++ b/diffusers/examples/research_projects/dreambooth_inpaint/README.md
@@ -0,0 +1,118 @@
+# Dreambooth for the inpainting model
+
+This script was added by @thedarkzeno.
+
+Please note that this script is not actively maintained; you can, however, open an issue and tag @thedarkzeno or @patil-suraj.
+
+```bash
+export MODEL_NAME="runwayml/stable-diffusion-inpainting"
+export INSTANCE_DIR="path-to-instance-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_dreambooth_inpaint.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --instance_prompt="a photo of sks dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=1 \
+ --learning_rate=5e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --max_train_steps=400
+```
+
+### Training with prior-preservation loss
+
+Prior-preservation is used to avoid overfitting and language drift. Refer to the paper to learn more about it. For prior-preservation, we first generate images using the model with a class prompt and then use those images during training along with our data.
+According to the paper, it's recommended to generate `num_epochs * num_samples` images for prior-preservation. 200-300 images work well for most cases.
+
+```bash
+export MODEL_NAME="runwayml/stable-diffusion-inpainting"
+export INSTANCE_DIR="path-to-instance-images"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_dreambooth_inpaint.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --class_data_dir=$CLASS_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --with_prior_preservation --prior_loss_weight=1.0 \
+ --instance_prompt="a photo of sks dog" \
+ --class_prompt="a photo of dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=1 \
+ --learning_rate=5e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --num_class_images=200 \
+ --max_train_steps=800
+```
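+
+Under the hood, the script only generates class images itself when `$CLASS_DIR` contains fewer than `--num_class_images` files. Conceptually, that generation step looks roughly like the sketch below (it mirrors the full-image-mask trick used in `train_dreambooth_inpaint.py`; the model id, prompt, and paths are just the placeholders from the command above):
+
+```python
+import torch
+from PIL import Image
+from diffusers import StableDiffusionInpaintPipeline
+
+pipe = StableDiffusionInpaintPipeline.from_pretrained(
+ "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, safety_checker=None
+).to("cuda")
+
+# Inpainting with a full white mask over a dummy image is effectively text-to-image,
+# which is how the class images for prior preservation are produced.
+dummy_image = Image.new("RGB", (512, 512))
+full_mask = Image.new("L", (512, 512), 255)
+
+for i in range(200): # --num_class_images
+ image = pipe(prompt="a photo of dog", image=dummy_image, mask_image=full_mask).images[0]
+ image.save(f"path-to-class-images/{i}.jpg")
+```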
+
+
+### Training with gradient checkpointing and 8-bit optimizer:
+
+With the help of gradient checkpointing and the 8-bit optimizer from bitsandbytes, it's possible to train DreamBooth on a 16GB GPU.
+
+To install `bitsandbytes`, please refer to this [readme](https://github.com/TimDettmers/bitsandbytes#requirements--installation).
+
+```bash
+export MODEL_NAME="runwayml/stable-diffusion-inpainting"
+export INSTANCE_DIR="path-to-instance-images"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_dreambooth_inpaint.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --class_data_dir=$CLASS_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --with_prior_preservation --prior_loss_weight=1.0 \
+ --instance_prompt="a photo of sks dog" \
+ --class_prompt="a photo of dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=2 --gradient_checkpointing \
+ --use_8bit_adam \
+ --learning_rate=5e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --num_class_images=200 \
+ --max_train_steps=800
+```
+
+### Fine-tune text encoder with the UNet.
+
+The script also allows you to fine-tune the `text_encoder` along with the `unet`. It has been observed experimentally that fine-tuning the `text_encoder` gives much better results, especially on faces.
+Pass the `--train_text_encoder` argument to the script to enable training the `text_encoder`.
+
+___Note: Training the text encoder requires more memory; with this option the training won't fit on a 16GB GPU. It needs at least 24GB of VRAM.___
+
+```bash
+export MODEL_NAME="runwayml/stable-diffusion-inpainting"
+export INSTANCE_DIR="path-to-instance-images"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_dreambooth_inpaint.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_text_encoder \
+ --instance_data_dir=$INSTANCE_DIR \
+ --class_data_dir=$CLASS_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --with_prior_preservation --prior_loss_weight=1.0 \
+ --instance_prompt="a photo of sks dog" \
+ --class_prompt="a photo of dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --use_8bit_adam \
+ --gradient_checkpointing \
+ --learning_rate=2e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --num_class_images=200 \
+ --max_train_steps=800
+```
diff --git a/diffusers/examples/research_projects/dreambooth_inpaint/requirements.txt b/diffusers/examples/research_projects/dreambooth_inpaint/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..aad6387026f181053d1872fd1961c7b56e86f1df
--- /dev/null
+++ b/diffusers/examples/research_projects/dreambooth_inpaint/requirements.txt
@@ -0,0 +1,7 @@
+diffusers==0.9.0
+accelerate>=0.16.0
+torchvision
+transformers>=4.21.0
+ftfy
+tensorboard
+Jinja2
diff --git a/diffusers/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint.py b/diffusers/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e82a45c024f5c79059b73bb5272547659c1b8a7
--- /dev/null
+++ b/diffusers/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint.py
@@ -0,0 +1,812 @@
+import argparse
+import itertools
+import math
+import os
+import random
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+from huggingface_hub.utils import insecure_hashlib
+from PIL import Image, ImageDraw
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ StableDiffusionInpaintPipeline,
+ StableDiffusionPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.13.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def prepare_mask_and_masked_image(image, mask):
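+ # Normalize the RGB image to [-1, 1], binarize the mask at 0.5, and zero out the masked region to build the masked image.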
+ image = np.array(image.convert("RGB"))
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ mask = np.array(mask.convert("L"))
+ mask = mask.astype(np.float32) / 255.0
+ mask = mask[None, None]
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+ mask = torch.from_numpy(mask)
+
+ masked_image = image * (mask < 0.5)
+
+ return mask, masked_image
+
+
+# generate random masks
+def random_mask(im_shape, ratio=1, mask_full_image=False):
+ mask = Image.new("L", im_shape, 0)
+ draw = ImageDraw.Draw(mask)
+ size = (random.randint(0, int(im_shape[0] * ratio)), random.randint(0, int(im_shape[1] * ratio)))
+ # use this to always mask the whole image
+ if mask_full_image:
+ size = (int(im_shape[0] * ratio), int(im_shape[1] * ratio))
+ limits = (im_shape[0] - size[0] // 2, im_shape[1] - size[1] // 2)
+ center = (random.randint(size[0] // 2, limits[0]), random.randint(size[1] // 2, limits[1]))
+ draw_type = random.randint(0, 1)
+ if draw_type == 0 or mask_full_image:
+ draw.rectangle(
+ (center[0] - size[0] // 2, center[1] - size[1] // 2, center[0] + size[0] // 2, center[1] + size[1] // 2),
+ fill=255,
+ )
+ else:
+ draw.ellipse(
+ (center[0] - size[0] // 2, center[1] - size[1] // 2, center[0] + size[0] // 2, center[1] + size[1] // 2),
+ fill=255,
+ )
+
+ return mask
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ required=True,
+ help="A folder containing the training data of instance images.",
+ )
+ parser.add_argument(
+ "--class_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of class images.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default=None,
+ help="The prompt with identifier specifying the instance",
+ )
+ parser.add_argument(
+ "--class_prompt",
+ type=str,
+ default=None,
+ help="The prompt to specify images in the same class as provided instance images.",
+ )
+ parser.add_argument(
+ "--with_prior_preservation",
+ default=False,
+ action="store_true",
+ help="Flag to add prior preservation loss.",
+ )
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+ parser.add_argument(
+ "--num_class_images",
+ type=int,
+ default=100,
+ help=(
+ "Minimal class images for prior preservation loss. If not have enough images, additional images will be"
+ " sampled with class_prompt."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="text-inversion-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-6,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default="no",
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose"
+ "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
+ "and an Nvidia Ampere GPU."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint and are suitable for resuming training"
+ " using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=(
+ "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
+ " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
+ " for more docs"
+ ),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.instance_data_dir is None:
+ raise ValueError("You must specify a train data directory.")
+
+ if args.with_prior_preservation:
+ if args.class_data_dir is None:
+ raise ValueError("You must specify a data directory for class images.")
+ if args.class_prompt is None:
+ raise ValueError("You must specify prompt for class images.")
+
+ return args
+
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+ It pre-processes the images and tokenizes the prompts.
+ """
+
+ def __init__(
+ self,
+ instance_data_root,
+ instance_prompt,
+ tokenizer,
+ class_data_root=None,
+ class_prompt=None,
+ size=512,
+ center_crop=False,
+ ):
+ self.size = size
+ self.center_crop = center_crop
+ self.tokenizer = tokenizer
+
+ self.instance_data_root = Path(instance_data_root)
+ if not self.instance_data_root.exists():
+ raise ValueError("Instance images root doesn't exists.")
+
+ self.instance_images_path = list(Path(instance_data_root).iterdir())
+ self.num_instance_images = len(self.instance_images_path)
+ self.instance_prompt = instance_prompt
+ self._length = self.num_instance_images
+
+ if class_data_root is not None:
+ self.class_data_root = Path(class_data_root)
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
+ self.class_images_path = list(self.class_data_root.iterdir())
+ self.num_class_images = len(self.class_images_path)
+ self._length = max(self.num_class_images, self.num_instance_images)
+ self.class_prompt = class_prompt
+ else:
+ self.class_data_root = None
+
+ self.image_transforms_resize_and_crop = transforms.Compose(
+ [
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ ]
+ )
+
+ self.image_transforms = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
+ if not instance_image.mode == "RGB":
+ instance_image = instance_image.convert("RGB")
+ instance_image = self.image_transforms_resize_and_crop(instance_image)
+
+ example["PIL_images"] = instance_image
+ example["instance_images"] = self.image_transforms(instance_image)
+
+ example["instance_prompt_ids"] = self.tokenizer(
+ self.instance_prompt,
+ padding="do_not_pad",
+ truncation=True,
+ max_length=self.tokenizer.model_max_length,
+ ).input_ids
+
+ if self.class_data_root:
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
+ if not class_image.mode == "RGB":
+ class_image = class_image.convert("RGB")
+ class_image = self.image_transforms_resize_and_crop(class_image)
+ example["class_images"] = self.image_transforms(class_image)
+ example["class_PIL_images"] = class_image
+ example["class_prompt_ids"] = self.tokenizer(
+ self.class_prompt,
+ padding="do_not_pad",
+ truncation=True,
+ max_length=self.tokenizer.model_max_length,
+ ).input_ids
+
+ return example
+
+
+class PromptDataset(Dataset):
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
+
+ def __init__(self, prompt, num_samples):
+ self.prompt = prompt
+ self.num_samples = num_samples
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ example = {}
+ example["prompt"] = self.prompt
+ example["index"] = index
+ return example
+
+
+def main():
+ args = parse_args()
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ project_config = ProjectConfiguration(
+ total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
+ )
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with="tensorboard",
+ project_config=project_config,
+ )
+
+ # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
+ # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
+ # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
+ if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
+ raise ValueError(
+ "Gradient accumulation is not supported when training the text encoder in distributed training. "
+ "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
+ )
+
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ if args.with_prior_preservation:
+ class_images_dir = Path(args.class_data_dir)
+ if not class_images_dir.exists():
+ class_images_dir.mkdir(parents=True)
+ cur_class_images = len(list(class_images_dir.iterdir()))
+
+ if cur_class_images < args.num_class_images:
+ torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+ pipeline = StableDiffusionInpaintPipeline.from_pretrained(
+ args.pretrained_model_name_or_path, torch_dtype=torch_dtype, safety_checker=None
+ )
+ pipeline.set_progress_bar_config(disable=True)
+
+ num_new_images = args.num_class_images - cur_class_images
+ logger.info(f"Number of class images to sample: {num_new_images}.")
+
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+ sample_dataloader = torch.utils.data.DataLoader(
+ sample_dataset, batch_size=args.sample_batch_size, num_workers=1
+ )
+
+ sample_dataloader = accelerator.prepare(sample_dataloader)
+ pipeline.to(accelerator.device)
+ transform_to_pil = transforms.ToPILImage()
+ for example in tqdm(
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
+ ):
+ bsz = len(example["prompt"])
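+ # The inpainting pipeline needs an init image and a mask, so we pass a random image with a full-image mask; with everything masked, each class image is generated from the class prompt alone.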
+ fake_images = torch.rand((3, args.resolution, args.resolution))
+ transform_to_pil = transforms.ToPILImage()
+ fake_pil_images = transform_to_pil(fake_images)
+
+ fake_mask = random_mask((args.resolution, args.resolution), ratio=1, mask_full_image=True)
+
+ images = pipeline(prompt=example["prompt"], mask_image=fake_mask, image=fake_pil_images).images
+
+ for i, image in enumerate(images):
+ hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+ image.save(image_filename)
+
+ del pipeline
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizer
+ if args.tokenizer_name:
+ tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+ elif args.pretrained_model_name_or_path:
+ tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
+
+ # Load models and create wrapper for stable diffusion
+ text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
+ unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
+
+ vae.requires_grad_(False)
+ if not args.train_text_encoder:
+ text_encoder.requires_grad_(False)
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+ if args.train_text_encoder:
+ text_encoder.gradient_checkpointing_enable()
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ params_to_optimize = (
+ itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
+ )
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+
+ train_dataset = DreamBoothDataset(
+ instance_data_root=args.instance_data_dir,
+ instance_prompt=args.instance_prompt,
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
+ class_prompt=args.class_prompt,
+ tokenizer=tokenizer,
+ size=args.resolution,
+ center_crop=args.center_crop,
+ )
+
+ def collate_fn(examples):
+ input_ids = [example["instance_prompt_ids"] for example in examples]
+ pixel_values = [example["instance_images"] for example in examples]
+
+ # Concat class and instance examples for prior preservation.
+ # We do this to avoid doing two forward passes.
+ if args.with_prior_preservation:
+ input_ids += [example["class_prompt_ids"] for example in examples]
+ pixel_values += [example["class_images"] for example in examples]
+ prior_pil = [example["class_PIL_images"] for example in examples]
+
+ masks = []
+ masked_images = []
+ for example in examples:
+ pil_image = example["PIL_images"]
+ # generate a random mask
+ mask = random_mask(pil_image.size, 1, False)
+ # prepare mask and masked image
+ mask, masked_image = prepare_mask_and_masked_image(pil_image, mask)
+
+ masks.append(mask)
+ masked_images.append(masked_image)
+
+ if args.with_prior_preservation:
+ for pil_image in prior_pil:
+ # generate a random mask
+ mask = random_mask(pil_image.size, 1, False)
+ # prepare mask and masked image
+ mask, masked_image = prepare_mask_and_masked_image(pil_image, mask)
+
+ masks.append(mask)
+ masked_images.append(masked_image)
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
+ masks = torch.stack(masks)
+ masked_images = torch.stack(masked_images)
+ batch = {"input_ids": input_ids, "pixel_values": pixel_values, "masks": masks, "masked_images": masked_images}
+ return batch
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ )
+
+ if args.train_text_encoder:
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+ )
+ else:
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+ accelerator.register_for_checkpointing(lr_scheduler)
+
+ weight_dtype = torch.float32
+ if args.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif args.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move text_encoder and vae to gpu.
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ vae.to(accelerator.device, dtype=weight_dtype)
+ if not args.train_text_encoder:
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("dreambooth", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ resume_global_step = global_step * args.gradient_accumulation_steps
+ first_epoch = global_step // num_update_steps_per_epoch
+ resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
+ progress_bar.set_description("Steps")
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ unet.train()
+ for step, batch in enumerate(train_dataloader):
+ # Skip steps until we reach the resumed step
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
+ if step % args.gradient_accumulation_steps == 0:
+ progress_bar.update(1)
+ continue
+
+ with accelerator.accumulate(unet):
+ # Convert images to latent space
+
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
+ latents = latents * vae.config.scaling_factor
+
+ # Convert masked images to latent space
+ masked_latents = vae.encode(
+ batch["masked_images"].reshape(batch["pixel_values"].shape).to(dtype=weight_dtype)
+ ).latent_dist.sample()
+ masked_latents = masked_latents * vae.config.scaling_factor
+
+ masks = batch["masks"]
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ mask = torch.stack(
+ [
+ torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8))
+ for mask in masks
+ ]
+ )
+ mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # concatenate the noised latents with the mask and the masked latents
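+ # (4 latent channels + 1 mask channel + 4 masked-image latent channels = 9, matching the input channels expected by the inpainting UNet)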
+ latent_model_input = torch.cat([noisy_latents, mask, masked_latents], dim=1)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+ # Predict the noise residual
+ noise_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ if args.with_prior_preservation:
+ # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
+ noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
+ target, target_prior = torch.chunk(target, 2, dim=0)
+
+ # Compute instance loss
+ loss = F.mse_loss(noise_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
+
+ # Compute prior loss
+ prior_loss = F.mse_loss(noise_pred_prior.float(), target_prior.float(), reduction="mean")
+
+ # Add the prior loss to the instance loss.
+ loss = loss + args.prior_loss_weight * prior_loss
+ else:
+ loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = (
+ itertools.chain(unet.parameters(), text_encoder.parameters())
+ if args.train_text_encoder
+ else unet.parameters()
+ )
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ accelerator.wait_for_everyone()
+
+ # Create the pipeline using the trained modules and save it.
+ if accelerator.is_main_process:
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ unet=accelerator.unwrap_model(unet),
+ text_encoder=accelerator.unwrap_model(text_encoder),
+ )
+ pipeline.save_pretrained(args.output_dir)
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint_lora.py b/diffusers/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d79b2ceadaf01aa5cfdaf92dcf2a997dd61a629
--- /dev/null
+++ b/diffusers/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint_lora.py
@@ -0,0 +1,831 @@
+import argparse
+import math
+import os
+import random
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+from huggingface_hub.utils import insecure_hashlib
+from PIL import Image, ImageDraw
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel
+from diffusers.loaders import AttnProcsLayers
+from diffusers.models.attention_processor import LoRAAttnProcessor
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version
+from diffusers.utils.import_utils import is_xformers_available
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.13.0.dev0")
+
+logger = get_logger(__name__)
+
+
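+# Scales the image to [-1, 1], binarizes the mask to {0, 1}, and zeroes out the masked region to build the masked image.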
+def prepare_mask_and_masked_image(image, mask):
+ image = np.array(image.convert("RGB"))
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ mask = np.array(mask.convert("L"))
+ mask = mask.astype(np.float32) / 255.0
+ mask = mask[None, None]
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+ mask = torch.from_numpy(mask)
+
+ masked_image = image * (mask < 0.5)
+
+ return mask, masked_image
+
+
+# generate random masks
+def random_mask(im_shape, ratio=1, mask_full_image=False):
+ mask = Image.new("L", im_shape, 0)
+ draw = ImageDraw.Draw(mask)
+ size = (random.randint(0, int(im_shape[0] * ratio)), random.randint(0, int(im_shape[1] * ratio)))
+ # use this to always mask the whole image
+ if mask_full_image:
+ size = (int(im_shape[0] * ratio), int(im_shape[1] * ratio))
+ limits = (im_shape[0] - size[0] // 2, im_shape[1] - size[1] // 2)
+ center = (random.randint(size[0] // 2, limits[0]), random.randint(size[1] // 2, limits[1]))
+ draw_type = random.randint(0, 1)
+ if draw_type == 0 or mask_full_image:
+ draw.rectangle(
+ (center[0] - size[0] // 2, center[1] - size[1] // 2, center[0] + size[0] // 2, center[1] + size[1] // 2),
+ fill=255,
+ )
+ else:
+ draw.ellipse(
+ (center[0] - size[0] // 2, center[1] - size[1] // 2, center[0] + size[0] // 2, center[1] + size[1] // 2),
+ fill=255,
+ )
+
+ return mask
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ required=True,
+ help="A folder containing the training data of instance images.",
+ )
+ parser.add_argument(
+ "--class_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of class images.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default=None,
+ help="The prompt with identifier specifying the instance",
+ )
+ parser.add_argument(
+ "--class_prompt",
+ type=str,
+ default=None,
+ help="The prompt to specify images in the same class as provided instance images.",
+ )
+ parser.add_argument(
+ "--with_prior_preservation",
+ default=False,
+ action="store_true",
+ help="Flag to add prior preservation loss.",
+ )
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+ parser.add_argument(
+ "--num_class_images",
+ type=int,
+ default=100,
+ help=(
+ "Minimal class images for prior preservation loss. If not have enough images, additional images will be"
+ " sampled with class_prompt."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="dreambooth-inpaint-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-6,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default="no",
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose"
+ "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
+ "and an Nvidia Ampere GPU."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint and are suitable for resuming training"
+ " using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=(
+ "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
+ " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
+ " for more docs"
+ ),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.instance_data_dir is None:
+ raise ValueError("You must specify a train data directory.")
+
+ if args.with_prior_preservation:
+ if args.class_data_dir is None:
+ raise ValueError("You must specify a data directory for class images.")
+ if args.class_prompt is None:
+ raise ValueError("You must specify prompt for class images.")
+
+ return args
+
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+ It pre-processes the images and tokenizes the prompts.
+ """
+
+ def __init__(
+ self,
+ instance_data_root,
+ instance_prompt,
+ tokenizer,
+ class_data_root=None,
+ class_prompt=None,
+ size=512,
+ center_crop=False,
+ ):
+ self.size = size
+ self.center_crop = center_crop
+ self.tokenizer = tokenizer
+
+ self.instance_data_root = Path(instance_data_root)
+ if not self.instance_data_root.exists():
+ raise ValueError("Instance images root doesn't exists.")
+
+ self.instance_images_path = list(Path(instance_data_root).iterdir())
+ self.num_instance_images = len(self.instance_images_path)
+ self.instance_prompt = instance_prompt
+ self._length = self.num_instance_images
+
+ if class_data_root is not None:
+ self.class_data_root = Path(class_data_root)
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
+ self.class_images_path = list(self.class_data_root.iterdir())
+ self.num_class_images = len(self.class_images_path)
+ self._length = max(self.num_class_images, self.num_instance_images)
+ self.class_prompt = class_prompt
+ else:
+ self.class_data_root = None
+
+ self.image_transforms_resize_and_crop = transforms.Compose(
+ [
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ ]
+ )
+
+ self.image_transforms = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
+ if not instance_image.mode == "RGB":
+ instance_image = instance_image.convert("RGB")
+ instance_image = self.image_transforms_resize_and_crop(instance_image)
+
+ example["PIL_images"] = instance_image
+ example["instance_images"] = self.image_transforms(instance_image)
+
+ example["instance_prompt_ids"] = self.tokenizer(
+ self.instance_prompt,
+ padding="do_not_pad",
+ truncation=True,
+ max_length=self.tokenizer.model_max_length,
+ ).input_ids
+
+ if self.class_data_root:
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
+ if not class_image.mode == "RGB":
+ class_image = class_image.convert("RGB")
+ class_image = self.image_transforms_resize_and_crop(class_image)
+ example["class_images"] = self.image_transforms(class_image)
+ example["class_PIL_images"] = class_image
+ example["class_prompt_ids"] = self.tokenizer(
+ self.class_prompt,
+ padding="do_not_pad",
+ truncation=True,
+ max_length=self.tokenizer.model_max_length,
+ ).input_ids
+
+ return example
+
+
+class PromptDataset(Dataset):
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
+
+ def __init__(self, prompt, num_samples):
+ self.prompt = prompt
+ self.num_samples = num_samples
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ example = {}
+ example["prompt"] = self.prompt
+ example["index"] = index
+ return example
+
+
+def main():
+ args = parse_args()
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(
+ total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
+ )
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with="tensorboard",
+ project_config=accelerator_project_config,
+ )
+
+ # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
+ # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
+ # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
+ if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
+ raise ValueError(
+ "Gradient accumulation is not supported when training the text encoder in distributed training. "
+ "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
+ )
+
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ if args.with_prior_preservation:
+ class_images_dir = Path(args.class_data_dir)
+ if not class_images_dir.exists():
+ class_images_dir.mkdir(parents=True)
+ cur_class_images = len(list(class_images_dir.iterdir()))
+
+ if cur_class_images < args.num_class_images:
+ torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+ pipeline = StableDiffusionInpaintPipeline.from_pretrained(
+ args.pretrained_model_name_or_path, torch_dtype=torch_dtype, safety_checker=None
+ )
+ pipeline.set_progress_bar_config(disable=True)
+
+ num_new_images = args.num_class_images - cur_class_images
+ logger.info(f"Number of class images to sample: {num_new_images}.")
+
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+ sample_dataloader = torch.utils.data.DataLoader(
+ sample_dataset, batch_size=args.sample_batch_size, num_workers=1
+ )
+
+ sample_dataloader = accelerator.prepare(sample_dataloader)
+ pipeline.to(accelerator.device)
+ transform_to_pil = transforms.ToPILImage()
+ for example in tqdm(
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
+ ):
+ bsz = len(example["prompt"])
+ fake_images = torch.rand((3, args.resolution, args.resolution))
+ transform_to_pil = transforms.ToPILImage()
+ fake_pil_images = transform_to_pil(fake_images)
+
+ fake_mask = random_mask((args.resolution, args.resolution), ratio=1, mask_full_image=True)
+
+ images = pipeline(prompt=example["prompt"], mask_image=fake_mask, image=fake_pil_images).images
+
+ for i, image in enumerate(images):
+ hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+ image.save(image_filename)
+
+ del pipeline
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizer
+ if args.tokenizer_name:
+ tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+ elif args.pretrained_model_name_or_path:
+ tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
+
+ # Load models and create wrapper for stable diffusion
+ text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
+ unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
+
+ # We only train the additional adapter LoRA layers
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+ unet.requires_grad_(False)
+
+ weight_dtype = torch.float32
+ if args.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif args.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move unet, vae and text_encoder to the accelerator device and cast them to weight_dtype.
+ # For mixed precision training we cast the frozen base weights to half-precision, since only the
+ # LoRA layers added below are trained; keeping the frozen weights in full precision is not required.
+ unet.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # now we will add new LoRA weights to the attention layers
+ # It's important to realize here how many attention weights will be added and of which sizes
+ # The sizes of the attention layers consist only of two different variables:
+ # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
+ # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.
+
+ # Let's first see how many attention processors we will have to set.
+ # For Stable Diffusion, it should be equal to:
+ # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
+ # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
+ # - up blocks (2x attention layers) * (3x transformer layers) * (3x up blocks) = 18
+ # => 32 layers
+
+ # Set correct lora layers
+ lora_attn_procs = {}
+ for name in unet.attn_processors.keys():
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+ if name.startswith("mid_block"):
+ hidden_size = unet.config.block_out_channels[-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = unet.config.block_out_channels[block_id]
+
+ lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
+
+ unet.set_attn_processor(lora_attn_procs)
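+ # AttnProcsLayers wraps the LoRA attention processors in a single nn.Module so their parameters can be optimized and checkpointed together.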
+ lora_layers = AttnProcsLayers(unet.attn_processors)
+
+ accelerator.register_for_checkpointing(lora_layers)
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ optimizer = optimizer_class(
+ lora_layers.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+
+ train_dataset = DreamBoothDataset(
+ instance_data_root=args.instance_data_dir,
+ instance_prompt=args.instance_prompt,
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
+ class_prompt=args.class_prompt,
+ tokenizer=tokenizer,
+ size=args.resolution,
+ center_crop=args.center_crop,
+ )
+
+ def collate_fn(examples):
+ input_ids = [example["instance_prompt_ids"] for example in examples]
+ pixel_values = [example["instance_images"] for example in examples]
+
+ # Concat class and instance examples for prior preservation.
+ # We do this to avoid doing two forward passes.
+ if args.with_prior_preservation:
+ input_ids += [example["class_prompt_ids"] for example in examples]
+ pixel_values += [example["class_images"] for example in examples]
+ prior_pil = [example["class_PIL_images"] for example in examples]
+
+ masks = []
+ masked_images = []
+ for example in examples:
+ pil_image = example["PIL_images"]
+ # generate a random mask
+ mask = random_mask(pil_image.size, 1, False)
+ # prepare mask and masked image
+ mask, masked_image = prepare_mask_and_masked_image(pil_image, mask)
+
+ masks.append(mask)
+ masked_images.append(masked_image)
+
+ if args.with_prior_preservation:
+ for pil_image in prior_pil:
+ # generate a random mask
+ mask = random_mask(pil_image.size, 1, False)
+ # prepare mask and masked image
+ mask, masked_image = prepare_mask_and_masked_image(pil_image, mask)
+
+ masks.append(mask)
+ masked_images.append(masked_image)
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
+ masks = torch.stack(masks)
+ masked_images = torch.stack(masked_images)
+ batch = {"input_ids": input_ids, "pixel_values": pixel_values, "masks": masks, "masked_images": masked_images}
+ return batch
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ )
+
+ # Prepare everything with our `accelerator`.
+ lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ lora_layers, optimizer, train_dataloader, lr_scheduler
+ )
+ # accelerator.register_for_checkpointing(lr_scheduler)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("dreambooth-inpaint-lora", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ resume_global_step = global_step * args.gradient_accumulation_steps
+ first_epoch = global_step // num_update_steps_per_epoch
+ resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
+ progress_bar.set_description("Steps")
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ unet.train()
+ for step, batch in enumerate(train_dataloader):
+ # Skip steps until we reach the resumed step
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
+ if step % args.gradient_accumulation_steps == 0:
+ progress_bar.update(1)
+ continue
+
+ with accelerator.accumulate(unet):
+ # Convert images to latent space
+
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
+ latents = latents * vae.config.scaling_factor
+
+ # Convert masked images to latent space
+ masked_latents = vae.encode(
+ batch["masked_images"].reshape(batch["pixel_values"].shape).to(dtype=weight_dtype)
+ ).latent_dist.sample()
+ masked_latents = masked_latents * vae.config.scaling_factor
+
+ masks = batch["masks"]
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ mask = torch.stack(
+ [
+ torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8))
+ for mask in masks
+ ]
+ ).to(dtype=weight_dtype)
+ mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # concatenate the noised latents with the mask and the masked latents
+ latent_model_input = torch.cat([noisy_latents, mask, masked_latents], dim=1)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+ # Predict the noise residual
+ noise_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ if args.with_prior_preservation:
+ # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
+ noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
+ target, target_prior = torch.chunk(target, 2, dim=0)
+
+ # Compute instance loss
+ loss = F.mse_loss(noise_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
+
+ # Compute prior loss
+ prior_loss = F.mse_loss(noise_pred_prior.float(), target_prior.float(), reduction="mean")
+
+ # Add the prior loss to the instance loss.
+ loss = loss + args.prior_loss_weight * prior_loss
+ else:
+ loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = lora_layers.parameters()
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ accelerator.wait_for_everyone()
+
+ # Save the lora layers
+ if accelerator.is_main_process:
+ unet = unet.to(torch.float32)
+ unet.save_attn_procs(args.output_dir)
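+ # The saved LoRA weights can later be loaded back into a pipeline's UNet for inference, e.g. with `pipe.unet.load_attn_procs(args.output_dir)`.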
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/research_projects/intel_opts/README.md b/diffusers/examples/research_projects/intel_opts/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..6b25679efbe90d556244e7aa6bee3e863c28b069
--- /dev/null
+++ b/diffusers/examples/research_projects/intel_opts/README.md
@@ -0,0 +1,37 @@
+## Diffusers examples with Intel optimizations
+
+**This research project is not actively maintained by the diffusers team. For any questions or comments, please make sure to tag @hshen14 .**
+
+This project provides diffusers examples with Intel optimizations, such as Bfloat16 for training/fine-tuning acceleration and 8-bit integer (INT8) for inference acceleration, on Intel platforms.
+
+## Accelerating the fine-tuning for textual inversion
+
+We accelerate the fine-tuning for textual inversion with Intel Extension for PyTorch. The [examples](textual_inversion) enable both single-node and multi-node distributed training with Bfloat16 support on Intel Xeon Scalable Processors.
+
+## Accelerating the inference for Stable Diffusion using Bfloat16
+
+We start with Bfloat16 inference acceleration using Intel Extension for PyTorch. The [script](inference_bf16.py) is designed to work with standard Stable Diffusion models and runs them in Bfloat16.
+```bash
+pip install diffusers transformers accelerate scipy safetensors
+
+export KMP_BLOCKTIME=1
+export KMP_SETTINGS=1
+export KMP_AFFINITY=granularity=fine,compact,1,0
+
+# Intel OpenMP
+export OMP_NUM_THREADS=< Cores to use >
+export LD_PRELOAD=${LD_PRELOAD}:/path/to/lib/libiomp5.so
+# Jemalloc is a recommended malloc implementation that emphasizes fragmentation avoidance and scalable concurrency support.
+export LD_PRELOAD=${LD_PRELOAD}:/path/to/lib/libjemalloc.so
+export MALLOC_CONF="oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:-1,muzzy_decay_ms:9000000000"
+
+# Launch with default DDIM
+numactl --membind <node id> -C <cpu list> python inference_bf16.py
+# Launch with DPMSolverMultistepScheduler
+numactl --membind <node id> -C <cpu list> python inference_bf16.py --dpm
+
+```
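+
+For reference, the core of `inference_bf16.py` boils down to optimizing the pipeline's submodules with `ipex.optimize` and running inference under a CPU Bfloat16 autocast context. A minimal sketch (the checkpoint id and prompt below are placeholders) looks like this:
+
+```python
+import intel_extension_for_pytorch as ipex
+import torch
+
+from diffusers import StableDiffusionPipeline
+
+# placeholder checkpoint; any standard Stable Diffusion model should work
+pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cpu")
+
+# optimize the compute-heavy submodules for Bfloat16 execution on CPU
+pipe.unet = ipex.optimize(pipe.unet.eval(), dtype=torch.bfloat16, inplace=True)
+pipe.vae = ipex.optimize(pipe.vae.eval(), dtype=torch.bfloat16, inplace=True)
+pipe.text_encoder = ipex.optimize(pipe.text_encoder.eval(), dtype=torch.bfloat16, inplace=True)
+
+# run inference under a Bfloat16 autocast context
+with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
+    image = pipe("a photo of an astronaut riding a horse").images[0]
+image.save("generated.png")
+```
+
+The full script additionally converts the submodules to the channels_last memory format and wires up the `--dpm` and `--steps` options.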
+
+## Accelerating the inference for Stable Diffusion using INT8
+
+Coming soon ...
diff --git a/diffusers/examples/research_projects/intel_opts/inference_bf16.py b/diffusers/examples/research_projects/intel_opts/inference_bf16.py
new file mode 100644
index 0000000000000000000000000000000000000000..96ec709f433cd13dad0b93d5368d61e169b9df28
--- /dev/null
+++ b/diffusers/examples/research_projects/intel_opts/inference_bf16.py
@@ -0,0 +1,56 @@
+import argparse
+
+import intel_extension_for_pytorch as ipex
+import torch
+
+from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline
+
+
+parser = argparse.ArgumentParser("Stable Diffusion script with intel optimization", add_help=False)
+parser.add_argument("--dpm", action="store_true", help="Enable DPMSolver or not")
+parser.add_argument("--steps", default=None, type=int, help="Num inference steps")
+args = parser.parse_args()
+
+
+device = "cpu"
+prompt = "a lovely in red dress and hat, in the snowly and brightly night, with many brighly buildings"
+
+model_id = "path-to-your-trained-model"
+pipe = StableDiffusionPipeline.from_pretrained(model_id)
+if args.dpm:
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
+pipe = pipe.to(device)
+
+# to channels last
+pipe.unet = pipe.unet.to(memory_format=torch.channels_last)
+pipe.vae = pipe.vae.to(memory_format=torch.channels_last)
+pipe.text_encoder = pipe.text_encoder.to(memory_format=torch.channels_last)
+if pipe.requires_safety_checker:
+ pipe.safety_checker = pipe.safety_checker.to(memory_format=torch.channels_last)
+
+# optimize with ipex
+sample = torch.randn(2, 4, 64, 64)
+timestep = torch.rand(1) * 999
+encoder_hidden_states = torch.randn(2, 77, 768)
+input_example = (sample, timestep, encoder_hidden_states)
+try:
+ pipe.unet = ipex.optimize(pipe.unet.eval(), dtype=torch.bfloat16, inplace=True, sample_input=input_example)
+except Exception:
+ pipe.unet = ipex.optimize(pipe.unet.eval(), dtype=torch.bfloat16, inplace=True)
+pipe.vae = ipex.optimize(pipe.vae.eval(), dtype=torch.bfloat16, inplace=True)
+pipe.text_encoder = ipex.optimize(pipe.text_encoder.eval(), dtype=torch.bfloat16, inplace=True)
+if pipe.requires_safety_checker:
+ pipe.safety_checker = ipex.optimize(pipe.safety_checker.eval(), dtype=torch.bfloat16, inplace=True)
+
+# compute
+seed = 666
+generator = torch.Generator(device).manual_seed(seed)
+generate_kwargs = {"generator": generator}
+if args.steps is not None:
+ generate_kwargs["num_inference_steps"] = args.steps
+
+with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
+ image = pipe(prompt, **generate_kwargs).images[0]
+
+# save image
+image.save("generated.png")
diff --git a/diffusers/examples/research_projects/intel_opts/textual_inversion/README.md b/diffusers/examples/research_projects/intel_opts/textual_inversion/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..14e8b160fb1fb2de72cd37ddb4e4abcab83356fa
--- /dev/null
+++ b/diffusers/examples/research_projects/intel_opts/textual_inversion/README.md
@@ -0,0 +1,68 @@
+## Textual Inversion fine-tuning example
+
+[Textual inversion](https://arxiv.org/abs/2208.01618) is a method to personalize text2image models like stable diffusion on your own images using just 3-5 examples.
+The `textual_inversion.py` script shows how to implement the training procedure and adapt it for stable diffusion.
+
+## Training with Intel Extension for PyTorch
+
+Intel Extension for PyTorch provides optimizations for faster training and inference on CPUs. You can leverage the training example `textual_inversion.py`. Follow the [instructions](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion) to get the model and [dataset](https://huggingface.co/sd-concepts-library/dicoo2) before running the script.
+
+The example supports both single node and multi-node distributed training:
+
+### Single node training
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export DATA_DIR="path-to-dir-containing-dicoo-images"
+
+python textual_inversion.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_data_dir=$DATA_DIR \
+ --learnable_property="object" \
+ --placeholder_token="" --initializer_token="toy" \
+ --seed=7 \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=1 \
+ --max_train_steps=3000 \
+ --learning_rate=2.5e-03 --scale_lr \
+ --output_dir="textual_inversion_dicoo"
+```
+
+Note: Bfloat16 is available on Intel Xeon Scalable Processors such as Cooper Lake and Sapphire Rapids. You may not get a performance speedup without Bfloat16 support.
+
+### Multi-node distributed training
+
+Before running the script, make sure the oneCCL bindings for PyTorch (used for distributed communication) are installed:
+
+```bash
+python -m pip install oneccl_bind_pt==1.13 -f https://developer.intel.com/ipex-whl-stable-cpu
+```
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export DATA_DIR="path-to-dir-containing-dicoo-images"
+
+oneccl_bindings_for_pytorch_path=$(python -c "from oneccl_bindings_for_pytorch import cwd; print(cwd)")
+source $oneccl_bindings_for_pytorch_path/env/setvars.sh
+
+python -m intel_extension_for_pytorch.cpu.launch --distributed \
+ --hostfile hostfile --nnodes 2 --nproc_per_node 2 textual_inversion.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_data_dir=$DATA_DIR \
+ --learnable_property="object" \
+ --placeholder_token="" --initializer_token="toy" \
+ --seed=7 \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=1 \
+ --max_train_steps=750 \
+ --learning_rate=2.5e-03 --scale_lr \
+ --output_dir="textual_inversion_dicoo"
+```
+The above is a simple example of distributed training on 2 nodes with 2 processes on each node. Add the right hostnames or IP addresses (one per line) to the `hostfile` and make sure the 2 nodes can reach each other. For more details, please refer to the [user guide](https://github.com/intel/torch-ccl).
+
+
+### Reference
+
+We published a [Medium blog](https://medium.com/intel-analytics-software/personalized-stable-diffusion-with-few-shot-fine-tuning-on-a-single-cpu-f01a3316b13) on how to create your own Stable Diffusion model on CPUs using textual inversion. Try it out if you are interested.
diff --git a/diffusers/examples/research_projects/intel_opts/textual_inversion/requirements.txt b/diffusers/examples/research_projects/intel_opts/textual_inversion/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..af7ed6b21f6fb4518930a37786199643b1c60ece
--- /dev/null
+++ b/diffusers/examples/research_projects/intel_opts/textual_inversion/requirements.txt
@@ -0,0 +1,7 @@
+accelerate>=0.16.0
+torchvision
+transformers>=4.21.0
+ftfy
+tensorboard
+Jinja2
+intel_extension_for_pytorch>=1.13
diff --git a/diffusers/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py b/diffusers/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff24130c9b61e932e14687250a0ad0e95a5c7089
--- /dev/null
+++ b/diffusers/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py
@@ -0,0 +1,635 @@
+import argparse
+import itertools
+import math
+import os
+import random
+from pathlib import Path
+
+import intel_extension_for_pytorch as ipex
+import numpy as np
+import PIL
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+
+# TODO: remove and import from diffusers.utils when the new version of diffusers is released
+from packaging import version
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers.optimization import get_scheduler
+from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
+from diffusers.utils import check_min_version
+
+
+if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.Resampling.BILINEAR,
+ "bilinear": PIL.Image.Resampling.BILINEAR,
+ "bicubic": PIL.Image.Resampling.BICUBIC,
+ "lanczos": PIL.Image.Resampling.LANCZOS,
+ "nearest": PIL.Image.Resampling.NEAREST,
+ }
+else:
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.LINEAR,
+ "bilinear": PIL.Image.BILINEAR,
+ "bicubic": PIL.Image.BICUBIC,
+ "lanczos": PIL.Image.LANCZOS,
+ "nearest": PIL.Image.NEAREST,
+ }
+# ------------------------------------------------------------------------------
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.13.0.dev0")
+
+
+logger = get_logger(__name__)
+
+
+def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):
+ logger.info("Saving embeddings")
+ learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
+ learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
+ torch.save(learned_embeds_dict, save_path)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--save_steps",
+ type=int,
+ default=500,
+ help="Save learned_embeds.bin every X update steps.",
+ )
+ parser.add_argument(
+ "--only_save_embeds",
+ action="store_true",
+ default=False,
+ help="Save only the embeddings for the new concept.",
+ )
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
+ )
+ parser.add_argument(
+ "--placeholder_token",
+ type=str,
+ default=None,
+ required=True,
+ help="A token to use as a placeholder for the concept.",
+ )
+ parser.add_argument(
+ "--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word."
+ )
+ parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
+ parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="text-inversion-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution."
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=5000,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of update steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=True,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default="no",
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose "
+ "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10 "
+ "and an Nvidia Ampere GPU."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.train_data_dir is None:
+ raise ValueError("You must specify a train data directory.")
+
+ return args
+
+
+imagenet_templates_small = [
+ "a photo of a {}",
+ "a rendering of a {}",
+ "a cropped photo of the {}",
+ "the photo of a {}",
+ "a photo of a clean {}",
+ "a photo of a dirty {}",
+ "a dark photo of the {}",
+ "a photo of my {}",
+ "a photo of the cool {}",
+ "a close-up photo of a {}",
+ "a bright photo of the {}",
+ "a cropped photo of a {}",
+ "a photo of the {}",
+ "a good photo of the {}",
+ "a photo of one {}",
+ "a close-up photo of the {}",
+ "a rendition of the {}",
+ "a photo of the clean {}",
+ "a rendition of a {}",
+ "a photo of a nice {}",
+ "a good photo of a {}",
+ "a photo of the nice {}",
+ "a photo of the small {}",
+ "a photo of the weird {}",
+ "a photo of the large {}",
+ "a photo of a cool {}",
+ "a photo of a small {}",
+]
+
+imagenet_style_templates_small = [
+ "a painting in the style of {}",
+ "a rendering in the style of {}",
+ "a cropped painting in the style of {}",
+ "the painting in the style of {}",
+ "a clean painting in the style of {}",
+ "a dirty painting in the style of {}",
+ "a dark painting in the style of {}",
+ "a picture in the style of {}",
+ "a cool painting in the style of {}",
+ "a close-up painting in the style of {}",
+ "a bright painting in the style of {}",
+ "a cropped painting in the style of {}",
+ "a good painting in the style of {}",
+ "a close-up painting in the style of {}",
+ "a rendition in the style of {}",
+ "a nice painting in the style of {}",
+ "a small painting in the style of {}",
+ "a weird painting in the style of {}",
+ "a large painting in the style of {}",
+]
+
+
+class TextualInversionDataset(Dataset):
+ def __init__(
+ self,
+ data_root,
+ tokenizer,
+ learnable_property="object", # [object, style]
+ size=512,
+ repeats=100,
+ interpolation="bicubic",
+ flip_p=0.5,
+ set="train",
+ placeholder_token="*",
+ center_crop=False,
+ ):
+ self.data_root = data_root
+ self.tokenizer = tokenizer
+ self.learnable_property = learnable_property
+ self.size = size
+ self.placeholder_token = placeholder_token
+ self.center_crop = center_crop
+ self.flip_p = flip_p
+
+ self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
+
+ self.num_images = len(self.image_paths)
+ self._length = self.num_images
+
+ if set == "train":
+ self._length = self.num_images * repeats
+
+ self.interpolation = {
+ "linear": PIL_INTERPOLATION["linear"],
+ "bilinear": PIL_INTERPOLATION["bilinear"],
+ "bicubic": PIL_INTERPOLATION["bicubic"],
+ "lanczos": PIL_INTERPOLATION["lanczos"],
+ }[interpolation]
+
+ self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
+ self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, i):
+ example = {}
+ image = Image.open(self.image_paths[i % self.num_images])
+
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+
+ placeholder_string = self.placeholder_token
+ text = random.choice(self.templates).format(placeholder_string)
+
+ example["input_ids"] = self.tokenizer(
+ text,
+ padding="max_length",
+ truncation=True,
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ ).input_ids[0]
+
+ # default to score-sde preprocessing
+ img = np.array(image).astype(np.uint8)
+
+ if self.center_crop:
+ crop = min(img.shape[0], img.shape[1])
+ (
+ h,
+ w,
+ ) = (
+ img.shape[0],
+ img.shape[1],
+ )
+ img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
+
+ image = Image.fromarray(img)
+ image = image.resize((self.size, self.size), resample=self.interpolation)
+
+ image = self.flip_transform(image)
+ image = np.array(image).astype(np.uint8)
+ image = (image / 127.5 - 1.0).astype(np.float32)
+
+ example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
+ return example
+
+
+def freeze_params(params):
+ for param in params:
+ param.requires_grad = False
+
+
+def main():
+ args = parse_args()
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with="tensorboard",
+ project_config=accelerator_project_config,
+ )
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizer and add the placeholder token as an additional special token
+ if args.tokenizer_name:
+ tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+ elif args.pretrained_model_name_or_path:
+ tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
+
+ # Add the placeholder token in tokenizer
+ num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
+ if num_added_tokens == 0:
+ raise ValueError(
+ f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
+ " `placeholder_token` that is not already in the tokenizer."
+ )
+
+ # Convert the initializer_token, placeholder_token to ids
+ token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
+ # Check if initializer_token is a single token or a sequence of tokens
+ if len(token_ids) > 1:
+ raise ValueError("The initializer token must be a single token.")
+
+ initializer_token_id = token_ids[0]
+ placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
+
+ # Load models and create wrapper for stable diffusion
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ revision=args.revision,
+ )
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="vae",
+ revision=args.revision,
+ )
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="unet",
+ revision=args.revision,
+ )
+
+ # Resize the token embeddings as we are adding new special tokens to the tokenizer
+ text_encoder.resize_token_embeddings(len(tokenizer))
+
+ # Initialise the newly added placeholder token with the embeddings of the initializer token
+ token_embeds = text_encoder.get_input_embeddings().weight.data
+ token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
+
+ # Freeze vae and unet
+ freeze_params(vae.parameters())
+ freeze_params(unet.parameters())
+ # Freeze all parameters except for the token embeddings in text encoder
+ params_to_freeze = itertools.chain(
+ text_encoder.text_model.encoder.parameters(),
+ text_encoder.text_model.final_layer_norm.parameters(),
+ text_encoder.text_model.embeddings.position_embedding.parameters(),
+ )
+ freeze_params(params_to_freeze)
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Initialize the optimizer
+ optimizer = torch.optim.AdamW(
+ text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+
+ train_dataset = TextualInversionDataset(
+ data_root=args.train_data_dir,
+ tokenizer=tokenizer,
+ size=args.resolution,
+ placeholder_token=args.placeholder_token,
+ repeats=args.repeats,
+ learnable_property=args.learnable_property,
+ center_crop=args.center_crop,
+ set="train",
+ )
+ train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True)
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ )
+
+ text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ text_encoder, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # Move vae and unet to device
+ vae.to(accelerator.device)
+ unet.to(accelerator.device)
+
+ # Keep vae and unet in eval mode as we don't train them
+ vae.eval()
+ unet.eval()
+
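+ # Apply IPEX optimizations to the frozen unet and vae for bf16 execution on CPU.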
+ unet = ipex.optimize(unet, dtype=torch.bfloat16, inplace=True)
+ vae = ipex.optimize(vae, dtype=torch.bfloat16, inplace=True)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("textual_inversion", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
+ progress_bar.set_description("Steps")
+ global_step = 0
+
+ text_encoder.train()
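+ # ipex.optimize prepares the trainable text encoder and its optimizer for bf16 training on CPU.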
+ text_encoder, optimizer = ipex.optimize(text_encoder, optimizer=optimizer, dtype=torch.bfloat16)
+
+ for epoch in range(args.num_train_epochs):
+ for step, batch in enumerate(train_dataloader):
+ with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
+ with accelerator.accumulate(text_encoder):
+ # Convert images to latent space
+ latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
+ latents = latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn(latents.shape).to(latents.device)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
+ ).long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+ # Predict the noise residual
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ loss = F.mse_loss(model_pred, target, reduction="none").mean([1, 2, 3]).mean()
+ accelerator.backward(loss)
+
+ # Zero out the gradients for all token embeddings except the newly added
+ # embeddings for the concept, as we only want to optimize the concept embeddings
+ if accelerator.num_processes > 1:
+ grads = text_encoder.module.get_input_embeddings().weight.grad
+ else:
+ grads = text_encoder.get_input_embeddings().weight.grad
+ # Get the index for tokens that we want to zero the grads for
+ index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
+ grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+ if global_step % args.save_steps == 0:
+ save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
+ save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ accelerator.wait_for_everyone()
+
+ # Create the pipeline using the trained modules and save it.
+ if accelerator.is_main_process:
+ if args.push_to_hub and args.only_save_embeds:
+ logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
+ save_full_model = True
+ else:
+ save_full_model = not args.only_save_embeds
+ if save_full_model:
+ pipeline = StableDiffusionPipeline(
+ text_encoder=accelerator.unwrap_model(text_encoder),
+ vae=vae,
+ unet=unet,
+ tokenizer=tokenizer,
+ scheduler=PNDMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler"),
+ safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
+ feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32"),
+ )
+ pipeline.save_pretrained(args.output_dir)
+ # Save the newly trained embeddings
+ save_path = os.path.join(args.output_dir, "learned_embeds.bin")
+ save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/research_projects/intel_opts/textual_inversion_dfq/README.md b/diffusers/examples/research_projects/intel_opts/textual_inversion_dfq/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..4a227cdb4d63585cc0f0ab76424be8a0b2c5b604
--- /dev/null
+++ b/diffusers/examples/research_projects/intel_opts/textual_inversion_dfq/README.md
@@ -0,0 +1,93 @@
+# Distillation for quantization on Textual Inversion models to personalize text2image
+
+[Textual inversion](https://arxiv.org/abs/2208.01618) is a method to personalize text2image models like Stable Diffusion on your own images. _By using just 3-5 images, new concepts can be taught to Stable Diffusion and the model personalized on your own images._
+The `textual_inversion.py` script shows how to implement the training procedure and adapt it for Stable Diffusion.
+We have enabled distillation for quantization in `textual_inversion.py` to do quantization-aware training as well as distillation on the model generated by the Textual Inversion method.
+
+## Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+```bash
+pip install -r requirements.txt
+```
+
+## Prepare Datasets
+
+One picture from the Hugging Face Hub repository [sd-concepts-library/dicoo2](https://huggingface.co/sd-concepts-library/dicoo2) is needed; save it to the `./dicoo` directory.
+
+## Get an FP32 Textual Inversion model
+
+Use the following command to fine-tune the Stable Diffusion model on the above dataset to obtain the FP32 Textual Inversion model.
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export DATA_DIR="./dicoo"
+
+accelerate launch textual_inversion.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_data_dir=$DATA_DIR \
+ --learnable_property="object" \
+ --placeholder_token="<dicoo>" --initializer_token="toy" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --max_train_steps=3000 \
+ --learning_rate=5.0e-04 --scale_lr \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --output_dir="dicoo_model"
+```
+
+## Do distillation for quantization
+
+Distillation for quantization is a method that combines [intermediate layer knowledge distillation](https://github.com/intel/neural-compressor/blob/master/docs/source/distillation.md#intermediate-layer-knowledge-distillation) and [quantization aware training](https://github.com/intel/neural-compressor/blob/master/docs/source/quantization.md#quantization-aware-training) in the same training process to improve the performance of the quantized model. Given an FP32 model, the distillation for quantization approach takes this model itself as the teacher and transfers the knowledge of the specified layers to the student model, i.e. the quantized version of the FP32 model, during the quantization aware training process.
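+
+Under the hood, the `textual_inversion.py` script added below wires these two objectives together with Intel Neural Compressor. The following is only a rough, illustrative sketch of that wiring (the layer mappings are trimmed and the paths are placeholders; the full mapping list and the training loop live in the script):
+
+```python
+import copy
+
+from diffusers import UNet2DConditionModel
+from neural_compressor import QuantizationAwareTrainingConfig
+from neural_compressor.config import DistillationConfig, IntermediateLayersKnowledgeDistillationLossConfig
+from neural_compressor.training import prepare_compression
+
+# The FP32 UNet produced by the Textual Inversion run above serves as its own teacher.
+unet = UNet2DConditionModel.from_pretrained("dicoo_model", subfolder="unet")
+teacher_unet = copy.deepcopy(unet)
+
+# Trimmed layer mappings; the full list is defined in textual_inversion.py.
+layer_mappings = [[["conv_in"]], [["time_embedding"]], [["conv_out"]]]
+distillation_criterion = IntermediateLayersKnowledgeDistillationLossConfig(
+    layer_mappings=layer_mappings,
+    loss_types=["MSE"] * len(layer_mappings),
+    loss_weights=[1.0 / len(layer_mappings)] * len(layer_mappings),
+    add_origin_loss=True,
+)
+confs = [
+    QuantizationAwareTrainingConfig(),
+    DistillationConfig(teacher_model=teacher_unet, criterion=distillation_criterion),
+]
+
+compression_manager = prepare_compression(unet, confs)
+compression_manager.callbacks.on_train_begin()
+model = compression_manager.model  # the wrapped student model to train
+# ... run the usual denoising training loop on `model`, folding the distillation
+#     terms into the loss via compression_manager.callbacks.on_after_compute_loss(...)
+compression_manager.callbacks.on_train_end()
+model.save("int8_model")
+```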
+
+Once you have the FP32 Textual Inversion model, the following command takes it as input, performs distillation for quantization, and generates the INT8 Textual Inversion model.
+
+```bash
+export FP32_MODEL_NAME="./dicoo_model"
+export DATA_DIR="./dicoo"
+
+accelerate launch textual_inversion.py \
+ --pretrained_model_name_or_path=$FP32_MODEL_NAME \
+ --train_data_dir=$DATA_DIR \
+ --use_ema --learnable_property="object" \
+ --placeholder_token="<dicoo>" --initializer_token="toy" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --max_train_steps=300 \
+ --learning_rate=5.0e-04 --max_grad_norm=3 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --output_dir="int8_model" \
+ --do_quantization --do_distillation --verify_loading
+```
+
+After the distillation for quantization process, the quantized UNet is roughly 4 times smaller (3279 MB -> 827 MB).
+
+## Inference
+
+Once you have trained an INT8 model with the above command, inference can be done simply using the `text2images.py` script. Make sure to include the `placeholder_token` in your prompt.
+
+```bash
+export INT8_MODEL_NAME="./int8_model"
+
+python text2images.py \
+ --pretrained_model_name_or_path=$INT8_MODEL_NAME \
+ --caption "a lovely <dicoo> in red dress and hat, in the snowy and bright night, with many bright buildings." \
+ --images_num 4
+```
+
+Below is a comparison of images generated by the FP32 model (left) and the INT8 model (right).
+
diff --git a/diffusers/examples/research_projects/intel_opts/textual_inversion_dfq/requirements.txt b/diffusers/examples/research_projects/intel_opts/textual_inversion_dfq/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..cbd4c957be441a1aaf9a52e7ff02d772cb9d302b
--- /dev/null
+++ b/diffusers/examples/research_projects/intel_opts/textual_inversion_dfq/requirements.txt
@@ -0,0 +1,7 @@
+accelerate
+torchvision
+transformers>=4.25.0
+ftfy
+tensorboard
+modelcards
+neural-compressor
\ No newline at end of file
diff --git a/diffusers/examples/research_projects/intel_opts/textual_inversion_dfq/text2images.py b/diffusers/examples/research_projects/intel_opts/textual_inversion_dfq/text2images.py
new file mode 100644
index 0000000000000000000000000000000000000000..a99d727712eb44b875576443837c81a442c72a6f
--- /dev/null
+++ b/diffusers/examples/research_projects/intel_opts/textual_inversion_dfq/text2images.py
@@ -0,0 +1,112 @@
+import argparse
+import math
+import os
+
+import torch
+from neural_compressor.utils.pytorch import load
+from PIL import Image
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-m",
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "-c",
+ "--caption",
+ type=str,
+ default="robotic cat with wings",
+ help="Text used to generate images.",
+ )
+ parser.add_argument(
+ "-n",
+ "--images_num",
+ type=int,
+ default=4,
+ help="How many images to generate.",
+ )
+ parser.add_argument(
+ "-s",
+ "--seed",
+ type=int,
+ default=42,
+ help="Seed for random process.",
+ )
+ parser.add_argument(
+ "-ci",
+ "--cuda_id",
+ type=int,
+ default=0,
+ help="cuda_id.",
+ )
+ args = parser.parse_args()
+ return args
+
+
+def image_grid(imgs, rows, cols):
+ if not len(imgs) == rows * cols:
+ raise ValueError("The specified number of rows and columns are not correct.")
+
+ w, h = imgs[0].size
+ grid = Image.new("RGB", size=(cols * w, rows * h))
+ grid_w, grid_h = grid.size
+
+ for i, img in enumerate(imgs):
+ grid.paste(img, box=(i % cols * w, i // cols * h))
+ return grid
+
+
+def generate_images(
+ pipeline,
+ prompt="robotic cat with wings",
+ guidance_scale=7.5,
+ num_inference_steps=50,
+ num_images_per_prompt=1,
+ seed=42,
+):
+ generator = torch.Generator(pipeline.device).manual_seed(seed)
+ images = pipeline(
+ prompt,
+ guidance_scale=guidance_scale,
+ num_inference_steps=num_inference_steps,
+ generator=generator,
+ num_images_per_prompt=num_images_per_prompt,
+ ).images
+ _rows = int(math.sqrt(num_images_per_prompt))
+ grid = image_grid(images, rows=_rows, cols=num_images_per_prompt // _rows)
+ return grid, images
+
+
+args = parse_args()
+# Load models and create wrapper for stable diffusion
+tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
+text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
+vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
+unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
+
+pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path, text_encoder=text_encoder, vae=vae, unet=unet, tokenizer=tokenizer
+)
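+# Disable the safety checker so the generated images are returned unfiltered.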
+pipeline.safety_checker = lambda images, clip_input: (images, False)
+if os.path.exists(os.path.join(args.pretrained_model_name_or_path, "best_model.pt")):
+ unet = load(args.pretrained_model_name_or_path, model=unet)
+ unet.eval()
+ setattr(pipeline, "unet", unet)
+else:
+ unet = unet.to(torch.device("cuda", args.cuda_id))
+pipeline = pipeline.to(unet.device)
+grid, images = generate_images(pipeline, prompt=args.caption, num_images_per_prompt=args.images_num, seed=args.seed)
+grid.save(os.path.join(args.pretrained_model_name_or_path, "{}.png".format("_".join(args.caption.split()))))
+dirname = os.path.join(args.pretrained_model_name_or_path, "_".join(args.caption.split()))
+os.makedirs(dirname, exist_ok=True)
+for idx, image in enumerate(images):
+ image.save(os.path.join(dirname, "{}.png".format(idx + 1)))
diff --git a/diffusers/examples/research_projects/intel_opts/textual_inversion_dfq/textual_inversion.py b/diffusers/examples/research_projects/intel_opts/textual_inversion_dfq/textual_inversion.py
new file mode 100644
index 0000000000000000000000000000000000000000..43667187596ef9b038d2ad03cbbdb02c0a6e40cf
--- /dev/null
+++ b/diffusers/examples/research_projects/intel_opts/textual_inversion_dfq/textual_inversion.py
@@ -0,0 +1,996 @@
+import argparse
+import itertools
+import math
+import os
+import random
+from pathlib import Path
+from typing import Iterable
+
+import numpy as np
+import PIL
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from accelerate import Accelerator
+from accelerate.utils import ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+from neural_compressor.utils import logger
+from packaging import version
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers.optimization import get_scheduler
+from diffusers.utils import make_image_grid
+
+
+if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.Resampling.BILINEAR,
+ "bilinear": PIL.Image.Resampling.BILINEAR,
+ "bicubic": PIL.Image.Resampling.BICUBIC,
+ "lanczos": PIL.Image.Resampling.LANCZOS,
+ "nearest": PIL.Image.Resampling.NEAREST,
+ }
+else:
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.LINEAR,
+ "bilinear": PIL.Image.BILINEAR,
+ "bicubic": PIL.Image.BICUBIC,
+ "lanczos": PIL.Image.LANCZOS,
+ "nearest": PIL.Image.NEAREST,
+ }
+# ------------------------------------------------------------------------------
+
+
+def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):
+ logger.info("Saving embeddings")
+ learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
+ learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
+ torch.save(learned_embeds_dict, save_path)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Example of distillation for quantization on Textual Inversion.")
+ parser.add_argument(
+ "--save_steps",
+ type=int,
+ default=500,
+ help="Save learned_embeds.bin every X update steps.",
+ )
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
+ )
+ parser.add_argument(
+ "--placeholder_token",
+ type=str,
+ default=None,
+ required=True,
+ help="A token to use as a placeholder for the concept.",
+ )
+ parser.add_argument(
+ "--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word."
+ )
+ parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
+ parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="text-inversion-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=5000,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of update steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default="no",
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose "
+ "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10 "
+ "and an Nvidia Ampere GPU."
+ ),
+ )
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--do_quantization", action="store_true", help="Whether or not to do quantization.")
+ parser.add_argument("--do_distillation", action="store_true", help="Whether or not to do distillation.")
+ parser.add_argument(
+ "--verify_loading", action="store_true", help="Whether or not to verify the loading of the quantized model."
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.train_data_dir is None:
+ raise ValueError("You must specify a train data directory.")
+
+ return args
+
+
+imagenet_templates_small = [
+ "a photo of a {}",
+ "a rendering of a {}",
+ "a cropped photo of the {}",
+ "the photo of a {}",
+ "a photo of a clean {}",
+ "a photo of a dirty {}",
+ "a dark photo of the {}",
+ "a photo of my {}",
+ "a photo of the cool {}",
+ "a close-up photo of a {}",
+ "a bright photo of the {}",
+ "a cropped photo of a {}",
+ "a photo of the {}",
+ "a good photo of the {}",
+ "a photo of one {}",
+ "a close-up photo of the {}",
+ "a rendition of the {}",
+ "a photo of the clean {}",
+ "a rendition of a {}",
+ "a photo of a nice {}",
+ "a good photo of a {}",
+ "a photo of the nice {}",
+ "a photo of the small {}",
+ "a photo of the weird {}",
+ "a photo of the large {}",
+ "a photo of a cool {}",
+ "a photo of a small {}",
+]
+
+imagenet_style_templates_small = [
+ "a painting in the style of {}",
+ "a rendering in the style of {}",
+ "a cropped painting in the style of {}",
+ "the painting in the style of {}",
+ "a clean painting in the style of {}",
+ "a dirty painting in the style of {}",
+ "a dark painting in the style of {}",
+ "a picture in the style of {}",
+ "a cool painting in the style of {}",
+ "a close-up painting in the style of {}",
+ "a bright painting in the style of {}",
+ "a cropped painting in the style of {}",
+ "a good painting in the style of {}",
+ "a close-up painting in the style of {}",
+ "a rendition in the style of {}",
+ "a nice painting in the style of {}",
+ "a small painting in the style of {}",
+ "a weird painting in the style of {}",
+ "a large painting in the style of {}",
+]
+
+
+# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
+class EMAModel:
+ """
+ Exponential Moving Average of models weights
+ """
+
+ def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
+ parameters = list(parameters)
+ self.shadow_params = [p.clone().detach() for p in parameters]
+
+ self.decay = decay
+ self.optimization_step = 0
+
+ def get_decay(self, optimization_step):
+ """
+ Compute the decay factor for the exponential moving average.
+ """
+ value = (1 + optimization_step) / (10 + optimization_step)
+ return 1 - min(self.decay, value)
+
+ @torch.no_grad()
+ def step(self, parameters):
+ parameters = list(parameters)
+
+ self.optimization_step += 1
+ self.decay = self.get_decay(self.optimization_step)
+
+ for s_param, param in zip(self.shadow_params, parameters):
+ if param.requires_grad:
+ tmp = self.decay * (s_param - param)
+ s_param.sub_(tmp)
+ else:
+ s_param.copy_(param)
+
+ torch.cuda.empty_cache()
+
+ def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
+ """
+ Copy current averaged parameters into given collection of parameters.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored moving averages. If `None`, the
+ parameters with which this `ExponentialMovingAverage` was
+ initialized will be used.
+ """
+ parameters = list(parameters)
+ for s_param, param in zip(self.shadow_params, parameters):
+ param.data.copy_(s_param.data)
+
+ def to(self, device=None, dtype=None) -> None:
+ r"""Move internal buffers of the ExponentialMovingAverage to `device`.
+ Args:
+ device: like `device` argument to `torch.Tensor.to`
+ """
+ # .to() on the tensors handles None correctly
+ self.shadow_params = [
+ p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
+ for p in self.shadow_params
+ ]
+
+
+class TextualInversionDataset(Dataset):
+ def __init__(
+ self,
+ data_root,
+ tokenizer,
+ learnable_property="object", # [object, style]
+ size=512,
+ repeats=100,
+ interpolation="bicubic",
+ flip_p=0.5,
+ set="train",
+ placeholder_token="*",
+ center_crop=False,
+ ):
+ self.data_root = data_root
+ self.tokenizer = tokenizer
+ self.learnable_property = learnable_property
+ self.size = size
+ self.placeholder_token = placeholder_token
+ self.center_crop = center_crop
+ self.flip_p = flip_p
+
+ self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
+
+ self.num_images = len(self.image_paths)
+ self._length = self.num_images
+
+ if set == "train":
+ self._length = self.num_images * repeats
+
+ self.interpolation = {
+ "linear": PIL_INTERPOLATION["linear"],
+ "bilinear": PIL_INTERPOLATION["bilinear"],
+ "bicubic": PIL_INTERPOLATION["bicubic"],
+ "lanczos": PIL_INTERPOLATION["lanczos"],
+ }[interpolation]
+
+ self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
+ self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, i):
+ example = {}
+ image = Image.open(self.image_paths[i % self.num_images])
+
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+
+ placeholder_string = self.placeholder_token
+ text = random.choice(self.templates).format(placeholder_string)
+
+ example["input_ids"] = self.tokenizer(
+ text,
+ padding="max_length",
+ truncation=True,
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ ).input_ids[0]
+
+ # default to score-sde preprocessing
+ img = np.array(image).astype(np.uint8)
+
+ if self.center_crop:
+ crop = min(img.shape[0], img.shape[1])
+ (
+ h,
+ w,
+ ) = (
+ img.shape[0],
+ img.shape[1],
+ )
+ img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
+
+ image = Image.fromarray(img)
+ image = image.resize((self.size, self.size), resample=self.interpolation)
+
+ image = self.flip_transform(image)
+ image = np.array(image).astype(np.uint8)
+ image = (image / 127.5 - 1.0).astype(np.float32)
+
+ example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
+ return example
+
+
+def freeze_params(params):
+ for param in params:
+ param.requires_grad = False
+
+
+def generate_images(pipeline, prompt="", guidance_scale=7.5, num_inference_steps=50, num_images_per_prompt=1, seed=42):
+ generator = torch.Generator(pipeline.device).manual_seed(seed)
+ images = pipeline(
+ prompt,
+ guidance_scale=guidance_scale,
+ num_inference_steps=num_inference_steps,
+ generator=generator,
+ num_images_per_prompt=num_images_per_prompt,
+ ).images
+ _rows = int(math.sqrt(num_images_per_prompt))
+ grid = make_image_grid(images, rows=_rows, cols=num_images_per_prompt // _rows)
+ return grid
+
+
+def main():
+ args = parse_args()
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with="tensorboard",
+ project_config=accelerator_project_config,
+ )
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizer and add the placeholder token as an additional special token
+ if args.tokenizer_name:
+ tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+ elif args.pretrained_model_name_or_path:
+ tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
+
+ # Load models and create wrapper for stable diffusion
+ noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ revision=args.revision,
+ )
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="vae",
+ revision=args.revision,
+ )
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="unet",
+ revision=args.revision,
+ )
+
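+ # With --do_quantization or --do_distillation the UNet is trained (text encoder frozen); otherwise only the new token embedding is trained.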
+ train_unet = False
+ # Freeze vae and unet
+ freeze_params(vae.parameters())
+ if not args.do_quantization and not args.do_distillation:
+ # Add the placeholder token in tokenizer
+ num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
+ if num_added_tokens == 0:
+ raise ValueError(
+ f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
+ " `placeholder_token` that is not already in the tokenizer."
+ )
+
+ # Convert the initializer_token, placeholder_token to ids
+ token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
+ # Check if initializer_token is a single token or a sequence of tokens
+ if len(token_ids) > 1:
+ raise ValueError("The initializer token must be a single token.")
+
+ initializer_token_id = token_ids[0]
+ placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
+ # Resize the token embeddings as we are adding new special tokens to the tokenizer
+ text_encoder.resize_token_embeddings(len(tokenizer))
+
+ # Initialise the newly added placeholder token with the embeddings of the initializer token
+ token_embeds = text_encoder.get_input_embeddings().weight.data
+ token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
+
+ freeze_params(unet.parameters())
+ # Freeze all parameters except for the token embeddings in text encoder
+ params_to_freeze = itertools.chain(
+ text_encoder.text_model.encoder.parameters(),
+ text_encoder.text_model.final_layer_norm.parameters(),
+ text_encoder.text_model.embeddings.position_embedding.parameters(),
+ )
+ freeze_params(params_to_freeze)
+ else:
+ train_unet = True
+ freeze_params(text_encoder.parameters())
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Initialize the optimizer
+ optimizer = torch.optim.AdamW(
+ # only optimize the unet or embeddings of text_encoder
+ unet.parameters() if train_unet else text_encoder.get_input_embeddings().parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ train_dataset = TextualInversionDataset(
+ data_root=args.train_data_dir,
+ tokenizer=tokenizer,
+ size=args.resolution,
+ placeholder_token=args.placeholder_token,
+ repeats=args.repeats,
+ learnable_property=args.learnable_property,
+ center_crop=args.center_crop,
+ set="train",
+ )
+ train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True)
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ )
+
+ if not train_unet:
+ text_encoder = accelerator.prepare(text_encoder)
+ unet.to(accelerator.device)
+ unet.eval()
+ else:
+ unet = accelerator.prepare(unet)
+ text_encoder.to(accelerator.device)
+ text_encoder.eval()
+ optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
+
+ # Move vae to device
+ vae.to(accelerator.device)
+
+ # Keep vae in eval mode as we don't train it
+ vae.eval()
+
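+ # Set later by prepare_compression() when quantization/distillation is enabled; train_func reads it via closure.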
+ compression_manager = None
+
+ def train_func(model):
+ if train_unet:
+ unet_ = model
+ text_encoder_ = text_encoder
+ else:
+ unet_ = unet
+ text_encoder_ = model
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("textual_inversion", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
+ progress_bar.set_description("Steps")
+ global_step = 0
+
+ if train_unet and args.use_ema:
+ ema_unet = EMAModel(unet_.parameters())
+
+ for epoch in range(args.num_train_epochs):
+ model.train()
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(model):
+ # Convert images to latent space
+ latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
+ latents = latents * 0.18215
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn(latents.shape).to(latents.device)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
+ ).long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder_(batch["input_ids"])[0]
+
+ # Predict the noise residual
+ model_pred = unet_(noisy_latents, timesteps, encoder_hidden_states).sample
+
+ loss = F.mse_loss(model_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+ if train_unet and compression_manager:
+ unet_inputs = {
+ "sample": noisy_latents,
+ "timestep": timesteps,
+ "encoder_hidden_states": encoder_hidden_states,
+ }
+ loss = compression_manager.callbacks.on_after_compute_loss(unet_inputs, model_pred, loss)
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+
+ if train_unet:
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(unet_.parameters(), args.max_grad_norm)
+ else:
+ # Zero out the gradients for all token embeddings except the newly added
+ # embeddings for the concept, as we only want to optimize the concept embeddings
+ if accelerator.num_processes > 1:
+ grads = text_encoder_.module.get_input_embeddings().weight.grad
+ else:
+ grads = text_encoder_.get_input_embeddings().weight.grad
+ # Get the index for tokens that we want to zero the grads for
+ index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
+ grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ if train_unet and args.use_ema:
+ ema_unet.step(unet_.parameters())
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+ if not train_unet and global_step % args.save_steps == 0:
+ save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
+ save_progress(text_encoder_, placeholder_token_id, accelerator, args, save_path)
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+ accelerator.wait_for_everyone()
+
+ if train_unet and args.use_ema:
+ ema_unet.copy_to(unet_.parameters())
+
+ if not train_unet:
+ return text_encoder_
+
+ if not train_unet:
+ text_encoder = train_func(text_encoder)
+ else:
+ import copy
+
+ model = copy.deepcopy(unet)
+ confs = []
+ if args.do_quantization:
+ from neural_compressor import QuantizationAwareTrainingConfig
+
+ q_conf = QuantizationAwareTrainingConfig()
+ confs.append(q_conf)
+
+ if args.do_distillation:
+ teacher_model = copy.deepcopy(model)
+
+ def attention_fetcher(x):
+ return x.sample
+
+ layer_mappings = [
+ [
+ [
+ "conv_in",
+ ]
+ ],
+ [
+ [
+ "time_embedding",
+ ]
+ ],
+ [["down_blocks.0.attentions.0", attention_fetcher]],
+ [["down_blocks.0.attentions.1", attention_fetcher]],
+ [
+ [
+ "down_blocks.0.resnets.0",
+ ]
+ ],
+ [
+ [
+ "down_blocks.0.resnets.1",
+ ]
+ ],
+ [
+ [
+ "down_blocks.0.downsamplers.0",
+ ]
+ ],
+ [["down_blocks.1.attentions.0", attention_fetcher]],
+ [["down_blocks.1.attentions.1", attention_fetcher]],
+ [
+ [
+ "down_blocks.1.resnets.0",
+ ]
+ ],
+ [
+ [
+ "down_blocks.1.resnets.1",
+ ]
+ ],
+ [
+ [
+ "down_blocks.1.downsamplers.0",
+ ]
+ ],
+ [["down_blocks.2.attentions.0", attention_fetcher]],
+ [["down_blocks.2.attentions.1", attention_fetcher]],
+ [
+ [
+ "down_blocks.2.resnets.0",
+ ]
+ ],
+ [
+ [
+ "down_blocks.2.resnets.1",
+ ]
+ ],
+ [
+ [
+ "down_blocks.2.downsamplers.0",
+ ]
+ ],
+ [
+ [
+ "down_blocks.3.resnets.0",
+ ]
+ ],
+ [
+ [
+ "down_blocks.3.resnets.1",
+ ]
+ ],
+ [
+ [
+ "up_blocks.0.resnets.0",
+ ]
+ ],
+ [
+ [
+ "up_blocks.0.resnets.1",
+ ]
+ ],
+ [
+ [
+ "up_blocks.0.resnets.2",
+ ]
+ ],
+ [
+ [
+ "up_blocks.0.upsamplers.0",
+ ]
+ ],
+ [["up_blocks.1.attentions.0", attention_fetcher]],
+ [["up_blocks.1.attentions.1", attention_fetcher]],
+ [["up_blocks.1.attentions.2", attention_fetcher]],
+ [
+ [
+ "up_blocks.1.resnets.0",
+ ]
+ ],
+ [
+ [
+ "up_blocks.1.resnets.1",
+ ]
+ ],
+ [
+ [
+ "up_blocks.1.resnets.2",
+ ]
+ ],
+ [
+ [
+ "up_blocks.1.upsamplers.0",
+ ]
+ ],
+ [["up_blocks.2.attentions.0", attention_fetcher]],
+ [["up_blocks.2.attentions.1", attention_fetcher]],
+ [["up_blocks.2.attentions.2", attention_fetcher]],
+ [
+ [
+ "up_blocks.2.resnets.0",
+ ]
+ ],
+ [
+ [
+ "up_blocks.2.resnets.1",
+ ]
+ ],
+ [
+ [
+ "up_blocks.2.resnets.2",
+ ]
+ ],
+ [
+ [
+ "up_blocks.2.upsamplers.0",
+ ]
+ ],
+ [["up_blocks.3.attentions.0", attention_fetcher]],
+ [["up_blocks.3.attentions.1", attention_fetcher]],
+ [["up_blocks.3.attentions.2", attention_fetcher]],
+ [
+ [
+ "up_blocks.3.resnets.0",
+ ]
+ ],
+ [
+ [
+ "up_blocks.3.resnets.1",
+ ]
+ ],
+ [
+ [
+ "up_blocks.3.resnets.2",
+ ]
+ ],
+ [["mid_block.attentions.0", attention_fetcher]],
+ [
+ [
+ "mid_block.resnets.0",
+ ]
+ ],
+ [
+ [
+ "mid_block.resnets.1",
+ ]
+ ],
+ [
+ [
+ "conv_out",
+ ]
+ ],
+ ]
+ layer_names = [layer_mapping[0][0] for layer_mapping in layer_mappings]
+ if not set(layer_names).issubset([n[0] for n in model.named_modules()]):
+ raise ValueError(
+ "Provided model is not compatible with the default layer_mappings, "
+ 'please use the model fine-tuned from "CompVis/stable-diffusion-v1-4", '
+ "or modify the layer_mappings variable to fit your model."
+ f"\nDefault layer_mappings are as such:\n{layer_mappings}"
+ )
+ from neural_compressor.config import DistillationConfig, IntermediateLayersKnowledgeDistillationLossConfig
+
+ distillation_criterion = IntermediateLayersKnowledgeDistillationLossConfig(
+ layer_mappings=layer_mappings,
+ loss_types=["MSE"] * len(layer_mappings),
+ loss_weights=[1.0 / len(layer_mappings)] * len(layer_mappings),
+ add_origin_loss=True,
+ )
+ d_conf = DistillationConfig(teacher_model=teacher_model, criterion=distillation_criterion)
+ confs.append(d_conf)
+
+ from neural_compressor.training import prepare_compression
+
+ compression_manager = prepare_compression(model, confs)
+ compression_manager.callbacks.on_train_begin()
+ model = compression_manager.model
+ train_func(model)
+ compression_manager.callbacks.on_train_end()
+
+ # Save the resulting model and its corresponding configuration in the given directory
+ model.save(args.output_dir)
+
+ logger.info(f"Optimized model saved to: {args.output_dir}.")
+
+ # change to framework model for further use
+ model = model.model
+
+    # Create the pipeline using the trained modules and save it.
+ templates = imagenet_style_templates_small if args.learnable_property == "style" else imagenet_templates_small
+ prompt = templates[0].format(args.placeholder_token)
+ if accelerator.is_main_process:
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ text_encoder=accelerator.unwrap_model(text_encoder),
+ vae=vae,
+ unet=accelerator.unwrap_model(unet),
+ tokenizer=tokenizer,
+ )
+ pipeline.save_pretrained(args.output_dir)
+ pipeline = pipeline.to(unet.device)
+ baseline_model_images = generate_images(pipeline, prompt=prompt, seed=args.seed)
+ baseline_model_images.save(
+ os.path.join(args.output_dir, "{}_baseline_model.png".format("_".join(prompt.split())))
+ )
+
+ if not train_unet:
+ # Also save the newly trained embeddings
+ save_path = os.path.join(args.output_dir, "learned_embeds.bin")
+ save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
+ else:
+ setattr(pipeline, "unet", accelerator.unwrap_model(model))
+ if args.do_quantization:
+ pipeline = pipeline.to(torch.device("cpu"))
+
+ optimized_model_images = generate_images(pipeline, prompt=prompt, seed=args.seed)
+ optimized_model_images.save(
+ os.path.join(args.output_dir, "{}_optimized_model.png".format("_".join(prompt.split())))
+ )
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+ if args.do_quantization and args.verify_loading:
+ # Load the model obtained after Intel Neural Compressor quantization
+ from neural_compressor.utils.pytorch import load
+
+ loaded_model = load(args.output_dir, model=unet)
+ loaded_model.eval()
+
+ setattr(pipeline, "unet", loaded_model)
+ if args.do_quantization:
+ pipeline = pipeline.to(torch.device("cpu"))
+
+ loaded_model_images = generate_images(pipeline, prompt=prompt, seed=args.seed)
+ if loaded_model_images != optimized_model_images:
+ logger.info("The quantized model was not successfully loaded.")
+ else:
+ logger.info("The quantized model was successfully loaded.")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/research_projects/lora/README.md b/diffusers/examples/research_projects/lora/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..b5d72403166f9b4017751c3d47f79a9eb3f535d8
--- /dev/null
+++ b/diffusers/examples/research_projects/lora/README.md
@@ -0,0 +1,83 @@
+# Stable Diffusion text-to-image fine-tuning
+This extended LoRA training script was authored by [haofanwang](https://github.com/haofanwang).
+This is an experimental LoRA extension of [this example](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py). It additionally supports adding LoRA layers to the text encoder.
+
+## Training with LoRA
+
+Low-Rank Adaptation of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
+
+In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:
+
+- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
+- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
+- LoRA attention layers allow control over the extent to which the model is adapted toward new training images via a `scale` parameter (a minimal sketch of this idea follows the list).
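+
+Concretely, the sketch below applies a rank-decomposition update to a single frozen linear layer in plain PyTorch. It only illustrates the idea behind the bullet points above (the actual script injects LoRA weights via PEFT or the diffusers attention processors), and the `rank` and `alpha` values are arbitrary.
+
+```python
+import torch
+import torch.nn as nn
+
+
+class LoRALinear(nn.Module):
+    """Sketch only: a frozen base layer plus a trainable rank-decomposition update."""
+
+    def __init__(self, base: nn.Linear, rank: int = 4, alpha: float = 32.0):
+        super().__init__()
+        self.base = base
+        self.base.requires_grad_(False)  # pretrained weights stay frozen
+        self.lora_down = nn.Linear(base.in_features, rank, bias=False)
+        self.lora_up = nn.Linear(rank, base.out_features, bias=False)
+        nn.init.normal_(self.lora_down.weight, std=1.0 / rank)
+        nn.init.zeros_(self.lora_up.weight)  # the update starts as a no-op
+        self.scale = alpha / rank  # scaling factor applied to the low-rank path
+
+    def forward(self, x):
+        # Frozen output plus the scaled low-rank update; only the lora_* weights are trained.
+        return self.base(x) + self.scale * self.lora_up(self.lora_down(x))
+
+
+layer = LoRALinear(nn.Linear(320, 320))
+out = layer(torch.randn(1, 77, 320))  # same shape as the base layer's output
+```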
+
+[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.
+
+With LoRA, it's possible to fine-tune Stable Diffusion on a custom image-caption pair dataset
+on consumer GPUs like the Tesla T4 or Tesla V100.
+
+### Training
+
+First, you need to set up your development environment as explained in the [installation section](#installing-the-dependencies). Make sure to set the `MODEL_NAME` and `DATASET_NAME` environment variables. Here, we will use [Stable Diffusion v1-4](https://hf.co/CompVis/stable-diffusion-v1-4) and the [Pokémon BLIP captions dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions).
+
+**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
+
+**___Note: It is quite useful to monitor the training progress by regularly generating sample images during training. [Weights and Biases](https://docs.wandb.ai/quickstart) is a nice solution to easily see the generated images during training. All you need to do is run `pip install wandb` before training to automatically log images.___**
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export DATASET_NAME="lambdalabs/pokemon-blip-captions"
+```
+
+For this example we want to directly store the trained LoRA embeddings on the Hub, so
+we need to be logged in and add the `--push_to_hub` flag.
+
+```bash
+huggingface-cli login
+```
+
+Now we can start training!
+
+```bash
+accelerate launch --mixed_precision="fp16" train_text_to_image_lora.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --dataset_name=$DATASET_NAME --caption_column="text" \
+ --resolution=512 --random_flip \
+ --train_batch_size=1 \
+ --num_train_epochs=100 --checkpointing_steps=5000 \
+ --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
+ --seed=42 \
+ --output_dir="sd-pokemon-model-lora" \
+  --validation_prompt="cute dragon creature" --report_to="wandb" \
+ --use_peft \
+ --lora_r=4 --lora_alpha=32 \
+ --lora_text_encoder_r=4 --lora_text_encoder_alpha=32
+```
+
+The above command will also run inference as fine-tuning progresses and log the results to Weights and Biases.
+
+**___Note: When using LoRA we can use a much higher learning rate compared to non-LoRA fine-tuning. Here we use *1e-4* instead of the usual *1e-5*. Also, by using LoRA, it's possible to run `train_text_to_image_lora.py` on consumer GPUs like the T4 or V100.___**
+
+The final LoRA embedding weights have been uploaded to [sayakpaul/sd-model-finetuned-lora-t4](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4). **___Note: [The final weights](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4/blob/main/pytorch_lora_weights.bin) are only 3 MB in size, which is orders of magnitude smaller than the original model.___**
+
+You can check some inference samples that were logged during the course of the fine-tuning process [here](https://wandb.ai/sayakpaul/text2image-fine-tune/runs/q4lc0xsw).
+
+### Inference
+
+Once you have trained a model using the above command, inference can be done simply with the `StableDiffusionPipeline` after loading the trained LoRA weights.
+You need to pass the `output_dir` for loading the LoRA weights which, in this case, is `sd-pokemon-model-lora`.
+
+```python
+from diffusers import StableDiffusionPipeline
+import torch
+
+model_path = "sayakpaul/sd-model-finetuned-lora-t4"
+pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16)
+pipe.unet.load_attn_procs(model_path)
+pipe.to("cuda")
+
+prompt = "A pokemon with green eyes and red legs."
+image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
+image.save("pokemon.png")
+```
\ No newline at end of file
diff --git a/diffusers/examples/research_projects/lora/requirements.txt b/diffusers/examples/research_projects/lora/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..89a1b73e70728b2395ca5f121f22def70f2076f9
--- /dev/null
+++ b/diffusers/examples/research_projects/lora/requirements.txt
@@ -0,0 +1,8 @@
+accelerate>=0.16.0
+torchvision
+transformers>=4.25.1
+datasets
+ftfy
+tensorboard
+Jinja2
+git+https://github.com/huggingface/peft.git
\ No newline at end of file
diff --git a/diffusers/examples/research_projects/lora/train_text_to_image_lora.py b/diffusers/examples/research_projects/lora/train_text_to_image_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..d69284042af4dd1552378d7293221afe2ec05788
--- /dev/null
+++ b/diffusers/examples/research_projects/lora/train_text_to_image_lora.py
@@ -0,0 +1,1014 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fine-tuning script for Stable Diffusion for text2image with support for LoRA."""
+
+import argparse
+import itertools
+import json
+import logging
+import math
+import os
+import random
+from pathlib import Path
+
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+
+import diffusers
+from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
+from diffusers.loaders import AttnProcsLayers
+from diffusers.models.attention_processor import LoRAAttnProcessor
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.14.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+
+def save_model_card(repo_id: str, images=None, base_model: str = None, dataset_name: str = None, repo_folder=None):
+ img_str = ""
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ img_str += f"![img_{i}](./image_{i}.png)\n"
+
+ yaml = f"""
+---
+license: creativeml-openrail-m
+base_model: {base_model}
+tags:
+- stable-diffusion
+- stable-diffusion-diffusers
+- text-to-image
+- diffusers
+- lora
+inference: true
+---
+ """
+ model_card = f"""
+# LoRA text2image fine-tuning - {repo_id}
+These are LoRA adaptation weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n
+{img_str}
+"""
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
+ f.write(yaml + model_card)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference."
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=1,
+ help=(
+ "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="sd-model-finetuned-lora",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
+
+ # lora args
+ parser.add_argument("--use_peft", action="store_true", help="Whether to use peft to support lora")
+ parser.add_argument("--lora_r", type=int, default=4, help="Lora rank, only used if use_lora is True")
+ parser.add_argument("--lora_alpha", type=int, default=32, help="Lora alpha, only used if lora is True")
+ parser.add_argument("--lora_dropout", type=float, default=0.0, help="Lora dropout, only used if use_lora is True")
+ parser.add_argument(
+ "--lora_bias",
+ type=str,
+ default="none",
+ help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora is True",
+ )
+ parser.add_argument(
+ "--lora_text_encoder_r",
+ type=int,
+ default=4,
+ help="Lora rank for text encoder, only used if `use_lora` and `train_text_encoder` are True",
+ )
+ parser.add_argument(
+ "--lora_text_encoder_alpha",
+ type=int,
+ default=32,
+ help="Lora alpha for text encoder, only used if `use_lora` and `train_text_encoder` are True",
+ )
+ parser.add_argument(
+ "--lora_text_encoder_dropout",
+ type=float,
+ default=0.0,
+ help="Lora dropout for text encoder, only used if `use_lora` and `train_text_encoder` are True",
+ )
+ parser.add_argument(
+ "--lora_text_encoder_bias",
+ type=str,
+ default="none",
+ help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora and `train_text_encoder` are True",
+ )
+
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=(
+ "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
+ " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
+ " for more docs"
+ ),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ return args
+
+
+DATASET_NAME_MAPPING = {
+ "lambdalabs/pokemon-blip-captions": ("image", "text"),
+}
+
+
+def main():
+ args = parse_args()
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(
+ total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
+ )
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+ import wandb
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load scheduler, tokenizer and models.
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ tokenizer = CLIPTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
+ )
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+ )
+
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ if args.use_peft:
+ from peft import LoraConfig, LoraModel, get_peft_model_state_dict, set_peft_model_state_dict
+
+ UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"]
+ TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"]
+
+ config = LoraConfig(
+ r=args.lora_r,
+ lora_alpha=args.lora_alpha,
+ target_modules=UNET_TARGET_MODULES,
+ lora_dropout=args.lora_dropout,
+ bias=args.lora_bias,
+ )
+ unet = LoraModel(config, unet)
+
+ vae.requires_grad_(False)
+ if args.train_text_encoder:
+ config = LoraConfig(
+ r=args.lora_text_encoder_r,
+ lora_alpha=args.lora_text_encoder_alpha,
+ target_modules=TEXT_ENCODER_TARGET_MODULES,
+ lora_dropout=args.lora_text_encoder_dropout,
+ bias=args.lora_text_encoder_bias,
+ )
+ text_encoder = LoraModel(config, text_encoder)
+ else:
+ # freeze parameters of models to save more memory
+ unet.requires_grad_(False)
+ vae.requires_grad_(False)
+
+ text_encoder.requires_grad_(False)
+
+ # now we will add new LoRA weights to the attention layers
+ # It's important to realize here how many attention weights will be added and of which sizes
+ # The sizes of the attention layers consist only of two different variables:
+ # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
+ # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.
+
+ # Let's first see how many attention processors we will have to set.
+ # For Stable Diffusion, it should be equal to:
+ # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
+ # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
+ # - up blocks (2x attention layers) * (3x transformer layers) * (3x down blocks) = 18
+ # => 32 layers
+
+ # Set correct lora layers
+ lora_attn_procs = {}
+ for name in unet.attn_processors.keys():
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+ if name.startswith("mid_block"):
+ hidden_size = unet.config.block_out_channels[-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = unet.config.block_out_channels[block_id]
+
+ lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
+
+ unet.set_attn_processor(lora_attn_procs)
+ lora_layers = AttnProcsLayers(unet.attn_processors)
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ vae.to(accelerator.device, dtype=weight_dtype)
+ if not args.train_text_encoder:
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warn(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Initialize the optimizer
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
+ )
+
+ optimizer_cls = bnb.optim.AdamW8bit
+ else:
+ optimizer_cls = torch.optim.AdamW
+
+ if args.use_peft:
+ # Optimizer creation
+ params_to_optimize = (
+ itertools.chain(unet.parameters(), text_encoder.parameters())
+ if args.train_text_encoder
+ else unet.parameters()
+ )
+ optimizer = optimizer_cls(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+ else:
+ optimizer = optimizer_cls(
+ lora_layers.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
+ if args.image_column is None:
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.caption_column is None:
+ caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
+ )
+
+ # Preprocessing the datasets.
+ # We need to tokenize input captions and transform the images.
+ def tokenize_captions(examples, is_train=True):
+ captions = []
+ for caption in examples[caption_column]:
+ if isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+ else:
+ raise ValueError(
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
+ )
+ inputs = tokenizer(
+ captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ return inputs.input_ids
+
+ # Preprocessing the datasets.
+ train_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
+ transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ examples["pixel_values"] = [train_transforms(image) for image in images]
+ examples["input_ids"] = tokenize_captions(examples)
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ input_ids = torch.stack([example["input_ids"] for example in examples])
+ return {"pixel_values": pixel_values, "input_ids": input_ids}
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ )
+
+ # Prepare everything with our `accelerator`.
+ if args.use_peft:
+ if args.train_text_encoder:
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+ )
+ else:
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+ else:
+ lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ lora_layers, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("text2image-fine-tune", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ resume_global_step = global_step * args.gradient_accumulation_steps
+ first_epoch = global_step // num_update_steps_per_epoch
+ resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
+ progress_bar.set_description("Steps")
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ unet.train()
+ if args.train_text_encoder:
+ text_encoder.train()
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ # Skip steps until we reach the resumed step
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
+ if step % args.gradient_accumulation_steps == 0:
+ progress_bar.update(1)
+ continue
+
+ with accelerator.accumulate(unet):
+ # Convert images to latent space
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
+ latents = latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ # Predict the noise residual and compute loss
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ if args.use_peft:
+ params_to_clip = (
+ itertools.chain(unet.parameters(), text_encoder.parameters())
+ if args.train_text_encoder
+ else unet.parameters()
+ )
+ else:
+ params_to_clip = lora_layers.parameters()
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ # create pipeline
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ unet=accelerator.unwrap_model(unet),
+ text_encoder=accelerator.unwrap_model(text_encoder),
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+ images = []
+ for _ in range(args.num_validation_images):
+ images.append(
+ pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
+ )
+
+ if accelerator.is_main_process:
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ if args.use_peft:
+ lora_config = {}
+ unwarpped_unet = accelerator.unwrap_model(unet)
+ state_dict = get_peft_model_state_dict(unwarpped_unet, state_dict=accelerator.get_state_dict(unet))
+ lora_config["peft_config"] = unwarpped_unet.get_peft_config_as_dict(inference=True)
+ if args.train_text_encoder:
+ unwarpped_text_encoder = accelerator.unwrap_model(text_encoder)
+ text_encoder_state_dict = get_peft_model_state_dict(
+ unwarpped_text_encoder, state_dict=accelerator.get_state_dict(text_encoder)
+ )
+ text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()}
+ state_dict.update(text_encoder_state_dict)
+ lora_config["text_encoder_peft_config"] = unwarpped_text_encoder.get_peft_config_as_dict(
+ inference=True
+ )
+
+ accelerator.save(state_dict, os.path.join(args.output_dir, f"{global_step}_lora.pt"))
+ with open(os.path.join(args.output_dir, f"{global_step}_lora_config.json"), "w") as f:
+ json.dump(lora_config, f)
+ else:
+ unet = unet.to(torch.float32)
+ unet.save_attn_procs(args.output_dir)
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ dataset_name=args.dataset_name,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ # Final inference
+ # Load previous pipeline
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
+ )
+
+ if args.use_peft:
+
+ def load_and_set_lora_ckpt(pipe, ckpt_dir, global_step, device, dtype):
+ with open(os.path.join(args.output_dir, f"{global_step}_lora_config.json"), "r") as f:
+ lora_config = json.load(f)
+ print(lora_config)
+
+ checkpoint = os.path.join(args.output_dir, f"{global_step}_lora.pt")
+ lora_checkpoint_sd = torch.load(checkpoint)
+ unet_lora_ds = {k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k}
+ text_encoder_lora_ds = {
+ k.replace("text_encoder_", ""): v for k, v in lora_checkpoint_sd.items() if "text_encoder_" in k
+ }
+
+ unet_config = LoraConfig(**lora_config["peft_config"])
+ pipe.unet = LoraModel(unet_config, pipe.unet)
+ set_peft_model_state_dict(pipe.unet, unet_lora_ds)
+
+ if "text_encoder_peft_config" in lora_config:
+ text_encoder_config = LoraConfig(**lora_config["text_encoder_peft_config"])
+ pipe.text_encoder = LoraModel(text_encoder_config, pipe.text_encoder)
+ set_peft_model_state_dict(pipe.text_encoder, text_encoder_lora_ds)
+
+ if dtype in (torch.float16, torch.bfloat16):
+ pipe.unet.half()
+ pipe.text_encoder.half()
+
+ pipe.to(device)
+ return pipe
+
+ pipeline = load_and_set_lora_ckpt(pipeline, args.output_dir, global_step, accelerator.device, weight_dtype)
+
+ else:
+ pipeline = pipeline.to(accelerator.device)
+ # load attention processors
+ pipeline.unet.load_attn_procs(args.output_dir)
+
+ # run inference
+ if args.seed is not None:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+ else:
+ generator = None
+ images = []
+ for _ in range(args.num_validation_images):
+ images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])
+
+ if accelerator.is_main_process:
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "test": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/research_projects/mulit_token_textual_inversion/README.md b/diffusers/examples/research_projects/mulit_token_textual_inversion/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..1303f73c175636466061110775cf1c905b4aba9a
--- /dev/null
+++ b/diffusers/examples/research_projects/mulit_token_textual_inversion/README.md
@@ -0,0 +1,143 @@
+## [Deprecated] Multi Token Textual Inversion
+
+**IMPORTANT: This research project is deprecated. Multi Token Textual Inversion is now supported natively in [the official textual inversion example](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion#running-locally-with-pytorch).**
+
+The author of this project is [Isamu Isozaki](https://github.com/isamu-isozaki) - please make sure to tag the author in issues and PRs, as well as @patrickvonplaten.
+
+We add multi-token support to textual inversion. The following options were added:
+1. `num_vec_per_token`: the number of vectors used to represent the concept token
+2. `progressive_tokens`: progressively train the concept, starting from 1 token and growing to 2 tokens, etc.
+3. `progressive_tokens_max_steps`: the maximum number of steps until we start full training
+4. `vector_shuffle`: shuffle the concept's vectors during training
+
+Feel free to add these options to your training! In practice, `num_vec_per_token` around 10 combined with `vector_shuffle` works great (a short usage sketch follows)!
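+
+The snippet below shows how these options fit together, assuming the `MultiTokenCLIPTokenizer` from `multi_token_clip.py` in this folder; the checkpoint id and the `<cat-toy>` placeholder simply mirror the example further down this README:
+
+```python
+from multi_token_clip import MultiTokenCLIPTokenizer
+
+# Load the tokenizer of the base model (same checkpoint as in the training command below).
+tokenizer = MultiTokenCLIPTokenizer.from_pretrained(
+    "runwayml/stable-diffusion-v1-5", subfolder="tokenizer"
+)
+
+# num_vec_per_token=3 registers <cat-toy>_0, <cat-toy>_1 and <cat-toy>_2 behind the scenes.
+tokenizer.add_placeholder_tokens("<cat-toy>", num_vec_per_token=3)
+
+# At encode time the single placeholder is expanded into its sub-tokens;
+# vector_shuffle randomizes their order before encoding.
+input_ids = tokenizer.encode("a photo of <cat-toy>", vector_shuffle=True)
+```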
+
+## Textual Inversion fine-tuning example
+
+[Textual inversion](https://arxiv.org/abs/2208.01618) is a method to personalize text2image models like stable diffusion on your own images using just 3-5 examples.
+The `textual_inversion.py` script shows how to implement the training procedure and adapt it for stable diffusion.
+
+## Running on Colab
+
+Colab for training
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb)
+
+Colab for inference
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_conceptualizer_inference.ipynb)
+
+## Running locally with PyTorch
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Then cd in the example folder and run
+```bash
+pip install -r requirements.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+
+### Cat toy example
+
+You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-5`, so you'll need to visit [its card](https://huggingface.co/runwayml/stable-diffusion-v1-5), read the license and tick the checkbox if you agree.
+
+You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).
+
+Run the following command to authenticate your token
+
+```bash
+huggingface-cli login
+```
+
+If you have already cloned the repo, then you won't need to go through these steps.
+
+
+
+Now let's get our dataset. Download 3-4 images from [here](https://drive.google.com/drive/folders/1fmJMs25nxS_rSNqS5hTcRdLem_YQXbq5) and save them in a directory. This will be our training data.
+
+And launch the training using
+
+**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
+
+```bash
+export MODEL_NAME="runwayml/stable-diffusion-v1-5"
+export DATA_DIR="path-to-dir-containing-images"
+
+accelerate launch textual_inversion.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_data_dir=$DATA_DIR \
+ --learnable_property="object" \
+  --placeholder_token="<cat-toy>" --initializer_token="toy" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --max_train_steps=3000 \
+ --learning_rate=5.0e-04 --scale_lr \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --output_dir="textual_inversion_cat"
+```
+
+A full training run takes ~1 hour on one V100 GPU.
+
+### Inference
+
+Once you have trained a model using the above command, inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `placeholder_token` in your prompt.
+
+```python
+import torch
+
+from diffusers import StableDiffusionPipeline
+
+model_id = "path-to-your-trained-model"
+pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+
+prompt = "A <cat-toy> backpack"
+
+image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
+
+image.save("cat-backpack.png")
+```
+
+
+## Training with Flax/JAX
+
+For faster training on TPUs and GPUs you can leverage the flax training example. Follow the instructions above to get the model and dataset before running the script.
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+```bash
+pip install -U -r requirements_flax.txt
+```
+
+```bash
+export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
+export DATA_DIR="path-to-dir-containing-images"
+
+python textual_inversion_flax.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_data_dir=$DATA_DIR \
+ --learnable_property="object" \
+  --placeholder_token="<cat-toy>" --initializer_token="toy" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --max_train_steps=3000 \
+ --learning_rate=5.0e-04 --scale_lr \
+ --output_dir="textual_inversion_cat"
+```
+It should be at least 70% faster than the PyTorch script with the same configuration.
+
+### Training with xformers:
+You can enable memory efficient attention by [installing xFormers](https://github.com/facebookresearch/xformers#installing-xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script. This is not available with the Flax/JAX implementation.
diff --git a/diffusers/examples/research_projects/mulit_token_textual_inversion/multi_token_clip.py b/diffusers/examples/research_projects/mulit_token_textual_inversion/multi_token_clip.py
new file mode 100644
index 0000000000000000000000000000000000000000..4388771b840df36ffa3a986dc9a2ad81ac7ee425
--- /dev/null
+++ b/diffusers/examples/research_projects/mulit_token_textual_inversion/multi_token_clip.py
@@ -0,0 +1,103 @@
+"""
+The main idea of this code is to spare users the hassle of typing multiple tokens for a concept, i.e.
+a photo of <concept>_0 <concept>_1 ... and so on,
+and instead just write
+a photo of <concept>
+which gets translated to the above. This needs to work for both inference and training.
+For inference,
+the tokenizer encodes the text, so we want logic for our tokenizer to replace the placeholder token with
+its underlying vectors.
+For training,
+we want to abstract away some logic, namely
+1. Adding tokens
+2. Updating the gradient mask
+3. Saving embeddings
+into our util class here.
+
+TODO:
+1. have tokenizer keep track of concept, multiconcept pairs and replace during encode call x
+2. have mechanism for adding tokens x
+3. have mechanism for saving embeddings x
+4. get mask to update x
+5. Loading tokens from embedding x
+6. Integrate to training x
+7. Test
+"""
+import copy
+import random
+
+from transformers import CLIPTokenizer
+
+
+class MultiTokenCLIPTokenizer(CLIPTokenizer):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.token_map = {}
+
+ def try_adding_tokens(self, placeholder_token, *args, **kwargs):
+ num_added_tokens = super().add_tokens(placeholder_token, *args, **kwargs)
+ if num_added_tokens == 0:
+ raise ValueError(
+ f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
+ " `placeholder_token` that is not already in the tokenizer."
+ )
+
+ def add_placeholder_tokens(self, placeholder_token, *args, num_vec_per_token=1, **kwargs):
+ output = []
+ if num_vec_per_token == 1:
+ self.try_adding_tokens(placeholder_token, *args, **kwargs)
+ output.append(placeholder_token)
+ else:
+ output = []
+ for i in range(num_vec_per_token):
+ ith_token = placeholder_token + f"_{i}"
+ self.try_adding_tokens(ith_token, *args, **kwargs)
+ output.append(ith_token)
+ # handle cases where there is a new placeholder token that contains the current placeholder token but is larger
+ for token in self.token_map:
+ if token in placeholder_token:
+ raise ValueError(
+ f"The tokenizer already has placeholder token {token} that can get confused with"
+                    f" {placeholder_token}. Keep placeholder tokens independent."
+ )
+ self.token_map[placeholder_token] = output
+
+ def replace_placeholder_tokens_in_text(self, text, vector_shuffle=False, prop_tokens_to_load=1.0):
+ """
+ Here, we replace the placeholder tokens in text recorded in token_map so that the text_encoder
+ can encode them
+ vector_shuffle was inspired by https://github.com/rinongal/textual_inversion/pull/119
+ where shuffling tokens were found to force the model to learn the concepts more descriptively.
+ """
+ if isinstance(text, list):
+ output = []
+ for i in range(len(text)):
+ output.append(self.replace_placeholder_tokens_in_text(text[i], vector_shuffle=vector_shuffle))
+ return output
+ for placeholder_token in self.token_map:
+ if placeholder_token in text:
+ tokens = self.token_map[placeholder_token]
+ tokens = tokens[: 1 + int(len(tokens) * prop_tokens_to_load)]
+ if vector_shuffle:
+ tokens = copy.copy(tokens)
+ random.shuffle(tokens)
+ text = text.replace(placeholder_token, " ".join(tokens))
+ return text
+
+ def __call__(self, text, *args, vector_shuffle=False, prop_tokens_to_load=1.0, **kwargs):
+ return super().__call__(
+ self.replace_placeholder_tokens_in_text(
+ text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
+ ),
+ *args,
+ **kwargs,
+ )
+
+ def encode(self, text, *args, vector_shuffle=False, prop_tokens_to_load=1.0, **kwargs):
+ return super().encode(
+ self.replace_placeholder_tokens_in_text(
+ text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
+ ),
+ *args,
+ **kwargs,
+ )
diff --git a/diffusers/examples/research_projects/mulit_token_textual_inversion/requirements.txt b/diffusers/examples/research_projects/mulit_token_textual_inversion/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7a612982f4abbaa64f83db52e411a1235a372259
--- /dev/null
+++ b/diffusers/examples/research_projects/mulit_token_textual_inversion/requirements.txt
@@ -0,0 +1,6 @@
+accelerate>=0.16.0
+torchvision
+transformers>=4.25.1
+ftfy
+tensorboard
+Jinja2
diff --git a/diffusers/examples/research_projects/mulit_token_textual_inversion/requirements_flax.txt b/diffusers/examples/research_projects/mulit_token_textual_inversion/requirements_flax.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8f85ad523a3b46b65abf0138c05ecdd656e6845c
--- /dev/null
+++ b/diffusers/examples/research_projects/mulit_token_textual_inversion/requirements_flax.txt
@@ -0,0 +1,8 @@
+transformers>=4.25.1
+flax
+optax
+torch
+torchvision
+ftfy
+tensorboard
+Jinja2
diff --git a/diffusers/examples/research_projects/mulit_token_textual_inversion/textual_inversion.py b/diffusers/examples/research_projects/mulit_token_textual_inversion/textual_inversion.py
new file mode 100644
index 0000000000000000000000000000000000000000..63b6c3860a2967db967561581fa060f5dae64082
--- /dev/null
+++ b/diffusers/examples/research_projects/mulit_token_textual_inversion/textual_inversion.py
@@ -0,0 +1,927 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import math
+import os
+import random
+from pathlib import Path
+
+import numpy as np
+import PIL
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+from multi_token_clip import MultiTokenCLIPTokenizer
+
+# TODO: remove and import from diffusers.utils when the new version of diffusers is released
+from packaging import version
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ DiffusionPipeline,
+ DPMSolverMultistepScheduler,
+ StableDiffusionPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.Resampling.BILINEAR,
+ "bilinear": PIL.Image.Resampling.BILINEAR,
+ "bicubic": PIL.Image.Resampling.BICUBIC,
+ "lanczos": PIL.Image.Resampling.LANCZOS,
+ "nearest": PIL.Image.Resampling.NEAREST,
+ }
+else:
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.LINEAR,
+ "bilinear": PIL.Image.BILINEAR,
+ "bicubic": PIL.Image.BICUBIC,
+ "lanczos": PIL.Image.LANCZOS,
+ "nearest": PIL.Image.NEAREST,
+ }
+# ------------------------------------------------------------------------------
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.14.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def add_tokens(tokenizer, text_encoder, placeholder_token, num_vec_per_token=1, initializer_token=None):
+ """
+ Add tokens to the tokenizer and set the initial value of token embeddings
+ """
+ tokenizer.add_placeholder_tokens(placeholder_token, num_vec_per_token=num_vec_per_token)
+ text_encoder.resize_token_embeddings(len(tokenizer))
+ token_embeds = text_encoder.get_input_embeddings().weight.data
+ placeholder_token_ids = tokenizer.encode(placeholder_token, add_special_tokens=False)
+ if initializer_token:
+ token_ids = tokenizer.encode(initializer_token, add_special_tokens=False)
+ for i, placeholder_token_id in enumerate(placeholder_token_ids):
+ token_embeds[placeholder_token_id] = token_embeds[token_ids[i * len(token_ids) // num_vec_per_token]]
+ else:
+ for i, placeholder_token_id in enumerate(placeholder_token_ids):
+ token_embeds[placeholder_token_id] = torch.randn_like(token_embeds[placeholder_token_id])
+ return placeholder_token
+
+
+def save_progress(tokenizer, text_encoder, accelerator, save_path):
+ for placeholder_token in tokenizer.token_map:
+ placeholder_token_ids = tokenizer.encode(placeholder_token, add_special_tokens=False)
+ learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_ids]
+ if len(placeholder_token_ids) == 1:
+ learned_embeds = learned_embeds[None]
+ learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
+ torch.save(learned_embeds_dict, save_path)
+
+
+def load_multitoken_tokenizer(tokenizer, text_encoder, learned_embeds_dict):
+ for placeholder_token in learned_embeds_dict:
+ placeholder_embeds = learned_embeds_dict[placeholder_token]
+ num_vec_per_token = placeholder_embeds.shape[0]
+ placeholder_embeds = placeholder_embeds.to(dtype=text_encoder.dtype)
+ add_tokens(tokenizer, text_encoder, placeholder_token, num_vec_per_token=num_vec_per_token)
+ placeholder_token_ids = tokenizer.encode(placeholder_token, add_special_tokens=False)
+ token_embeds = text_encoder.get_input_embeddings().weight.data
+ for i, placeholder_token_id in enumerate(placeholder_token_ids):
+ token_embeds[placeholder_token_id] = placeholder_embeds[i]
+
+
+def load_multitoken_tokenizer_from_automatic(tokenizer, text_encoder, automatic_dict, placeholder_token):
+ """
+ Automatic1111's tokens have format
+ {'string_to_token': {'*': 265}, 'string_to_param': {'*': tensor([[ 0.0833, 0.0030, 0.0057, ..., -0.0264, -0.0616, -0.0529],
+ [ 0.0058, -0.0190, -0.0584, ..., -0.0025, -0.0945, -0.0490],
+ [ 0.0916, 0.0025, 0.0365, ..., -0.0685, -0.0124, 0.0728],
+ [ 0.0812, -0.0199, -0.0100, ..., -0.0581, -0.0780, 0.0254]],
+ requires_grad=True)}, 'name': 'FloralMarble-400', 'step': 399, 'sd_checkpoint': '4bdfc29c', 'sd_checkpoint_name': 'SD2.1-768'}
+ """
+ learned_embeds_dict = {}
+ learned_embeds_dict[placeholder_token] = automatic_dict["string_to_param"]["*"]
+ load_multitoken_tokenizer(tokenizer, text_encoder, learned_embeds_dict)
+
+
+def get_mask(tokenizer, accelerator):
+ # Get the mask of the weights that won't change
+ mask = torch.ones(len(tokenizer)).to(accelerator.device, dtype=torch.bool)
+ for placeholder_token in tokenizer.token_map:
+ placeholder_token_ids = tokenizer.encode(placeholder_token, add_special_tokens=False)
+ for i in range(len(placeholder_token_ids)):
+ mask = mask & (torch.arange(len(tokenizer)) != placeholder_token_ids[i]).to(accelerator.device)
+ return mask
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--progressive_tokens_max_steps",
+ type=int,
+ default=2000,
+ help="The number of steps until all tokens will be used.",
+ )
+ parser.add_argument(
+ "--progressive_tokens",
+ action="store_true",
+ help="Progressively train the tokens. For example, first train for 1 token, then 2 tokens and so on.",
+ )
+    parser.add_argument("--vector_shuffle", action="store_true", help="Shuffle tokens during training.")
+ parser.add_argument(
+ "--num_vec_per_token",
+ type=int,
+ default=1,
+ help=(
+ "The number of vectors used to represent the placeholder token. The higher the number, the better the"
+ " result at the cost of editability. This can be fixed by prompt editing."
+ ),
+ )
+ parser.add_argument(
+ "--save_steps",
+ type=int,
+ default=500,
+ help="Save learned_embeds.bin every X updates steps.",
+ )
+ parser.add_argument(
+ "--only_save_embeds",
+ action="store_true",
+ default=False,
+ help="Save only the embeddings for the new concept.",
+ )
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
+ )
+ parser.add_argument(
+ "--placeholder_token",
+ type=str,
+ default=None,
+ required=True,
+ help="A token to use as a placeholder for the concept.",
+ )
+ parser.add_argument(
+ "--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word."
+ )
+ parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
+ parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="text-inversion-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution."
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=5000,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default="no",
+ choices=["no", "fp16", "bf16"],
+ help=(
+            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires"
+            " PyTorch >= 1.10 and an Nvidia Ampere GPU."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=50,
+ help=(
+ "Run validation every X epochs. Validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
+ " and logging the images."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=(
+ "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
+ " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
+ " for more docs"
+ ),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.train_data_dir is None:
+ raise ValueError("You must specify a train data directory.")
+
+ return args
+
+
+imagenet_templates_small = [
+ "a photo of a {}",
+ "a rendering of a {}",
+ "a cropped photo of the {}",
+ "the photo of a {}",
+ "a photo of a clean {}",
+ "a photo of a dirty {}",
+ "a dark photo of the {}",
+ "a photo of my {}",
+ "a photo of the cool {}",
+ "a close-up photo of a {}",
+ "a bright photo of the {}",
+ "a cropped photo of a {}",
+ "a photo of the {}",
+ "a good photo of the {}",
+ "a photo of one {}",
+ "a close-up photo of the {}",
+ "a rendition of the {}",
+ "a photo of the clean {}",
+ "a rendition of a {}",
+ "a photo of a nice {}",
+ "a good photo of a {}",
+ "a photo of the nice {}",
+ "a photo of the small {}",
+ "a photo of the weird {}",
+ "a photo of the large {}",
+ "a photo of a cool {}",
+ "a photo of a small {}",
+]
+
+imagenet_style_templates_small = [
+ "a painting in the style of {}",
+ "a rendering in the style of {}",
+ "a cropped painting in the style of {}",
+ "the painting in the style of {}",
+ "a clean painting in the style of {}",
+ "a dirty painting in the style of {}",
+ "a dark painting in the style of {}",
+ "a picture in the style of {}",
+ "a cool painting in the style of {}",
+ "a close-up painting in the style of {}",
+ "a bright painting in the style of {}",
+ "a cropped painting in the style of {}",
+ "a good painting in the style of {}",
+ "a close-up painting in the style of {}",
+ "a rendition in the style of {}",
+ "a nice painting in the style of {}",
+ "a small painting in the style of {}",
+ "a weird painting in the style of {}",
+ "a large painting in the style of {}",
+]
+
+
+class TextualInversionDataset(Dataset):
+ def __init__(
+ self,
+ data_root,
+ tokenizer,
+ learnable_property="object", # [object, style]
+ size=512,
+ repeats=100,
+ interpolation="bicubic",
+ flip_p=0.5,
+ set="train",
+ placeholder_token="*",
+ center_crop=False,
+ vector_shuffle=False,
+ progressive_tokens=False,
+ ):
+ self.data_root = data_root
+ self.tokenizer = tokenizer
+ self.learnable_property = learnable_property
+ self.size = size
+ self.placeholder_token = placeholder_token
+ self.center_crop = center_crop
+ self.flip_p = flip_p
+ self.vector_shuffle = vector_shuffle
+ self.progressive_tokens = progressive_tokens
+ self.prop_tokens_to_load = 0
+
+ self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
+
+ self.num_images = len(self.image_paths)
+ self._length = self.num_images
+
+ if set == "train":
+ self._length = self.num_images * repeats
+
+ self.interpolation = {
+ "linear": PIL_INTERPOLATION["linear"],
+ "bilinear": PIL_INTERPOLATION["bilinear"],
+ "bicubic": PIL_INTERPOLATION["bicubic"],
+ "lanczos": PIL_INTERPOLATION["lanczos"],
+ }[interpolation]
+
+ self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
+ self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, i):
+ example = {}
+ image = Image.open(self.image_paths[i % self.num_images])
+
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+
+ placeholder_string = self.placeholder_token
+ text = random.choice(self.templates).format(placeholder_string)
+
+ example["input_ids"] = self.tokenizer.encode(
+ text,
+ padding="max_length",
+ truncation=True,
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ vector_shuffle=self.vector_shuffle,
+ prop_tokens_to_load=self.prop_tokens_to_load if self.progressive_tokens else 1.0,
+ )[0]
+
+ # default to score-sde preprocessing
+ img = np.array(image).astype(np.uint8)
+
+ if self.center_crop:
+ crop = min(img.shape[0], img.shape[1])
+ (
+ h,
+ w,
+ ) = (
+ img.shape[0],
+ img.shape[1],
+ )
+ img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
+
+ image = Image.fromarray(img)
+ image = image.resize((self.size, self.size), resample=self.interpolation)
+
+ image = self.flip_transform(image)
+ image = np.array(image).astype(np.uint8)
+ image = (image / 127.5 - 1.0).astype(np.float32)
+
+ example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
+ return example
+
+
+def main():
+ args = parse_args()
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+ accelerator_project_config = ProjectConfiguration(
+ total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
+ )
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+ import wandb
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load tokenizer
+ if args.tokenizer_name:
+ tokenizer = MultiTokenCLIPTokenizer.from_pretrained(args.tokenizer_name)
+ elif args.pretrained_model_name_or_path:
+ tokenizer = MultiTokenCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+ )
+ if is_xformers_available():
+ try:
+ unet.enable_xformers_memory_efficient_attention()
+ except Exception as e:
+ logger.warning(
+ "Could not enable memory efficient attention. Make sure xformers is installed"
+ f" correctly and a GPU is available: {e}"
+ )
+ add_tokens(tokenizer, text_encoder, args.placeholder_token, args.num_vec_per_token, args.initializer_token)
+
+ # Freeze vae and unet
+ vae.requires_grad_(False)
+ unet.requires_grad_(False)
+ # Freeze all parameters except for the token embeddings in text encoder
+ text_encoder.text_model.encoder.requires_grad_(False)
+ text_encoder.text_model.final_layer_norm.requires_grad_(False)
+ text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
+
+ if args.gradient_checkpointing:
+ # Keep unet in train mode if we are using gradient checkpointing to save memory.
+        # Dropout is always 0 in the unet, so it doesn't matter whether it is in eval or train mode.
+ unet.train()
+ text_encoder.gradient_checkpointing_enable()
+ unet.enable_gradient_checkpointing()
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warn(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Initialize the optimizer
+ optimizer = torch.optim.AdamW(
+ text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Dataset and DataLoaders creation:
+ train_dataset = TextualInversionDataset(
+ data_root=args.train_data_dir,
+ tokenizer=tokenizer,
+ size=args.resolution,
+ placeholder_token=args.placeholder_token,
+ repeats=args.repeats,
+ learnable_property=args.learnable_property,
+ center_crop=args.center_crop,
+ set="train",
+ )
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ )
+
+ # Prepare everything with our `accelerator`.
+ text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ text_encoder, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # For mixed precision training we cast the unet and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move vae and unet to device and cast to weight_dtype
+ unet.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("textual_inversion", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ resume_global_step = global_step * args.gradient_accumulation_steps
+ first_epoch = global_step // num_update_steps_per_epoch
+ resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
+ progress_bar.set_description("Steps")
+
+ # keep original embeddings as reference
+ orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ text_encoder.train()
+ for step, batch in enumerate(train_dataloader):
+ # Skip steps until we reach the resumed step
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
+ if step % args.gradient_accumulation_steps == 0:
+ progress_bar.update(1)
+ continue
+ if args.progressive_tokens:
+ train_dataset.prop_tokens_to_load = float(global_step) / args.progressive_tokens_max_steps
+
+ with accelerator.accumulate(text_encoder):
+ # Convert images to latent space
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
+ latents = latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype)
+
+ # Predict the noise residual
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+ accelerator.backward(loss)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Let's make sure we don't update any embedding weights besides the newly added token
+ index_no_updates = get_mask(tokenizer, accelerator)
+ with torch.no_grad():
+ accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
+ index_no_updates
+ ] = orig_embeds_params[index_no_updates]
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+ if global_step % args.save_steps == 0:
+ save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
+ save_progress(tokenizer, text_encoder, accelerator, save_path)
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process and args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ # create pipeline (note: unet and vae are loaded again in float32)
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ text_encoder=accelerator.unwrap_model(text_encoder),
+ tokenizer=tokenizer,
+ unet=unet,
+ vae=vae,
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = (
+ None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
+ )
+ images = []
+ for _ in range(args.num_validation_images):
+ with torch.autocast("cuda"):
+ image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+ images.append(image)
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+    # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ if args.push_to_hub and args.only_save_embeds:
+ logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
+ save_full_model = True
+ else:
+ save_full_model = not args.only_save_embeds
+ if save_full_model:
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ text_encoder=accelerator.unwrap_model(text_encoder),
+ vae=vae,
+ unet=unet,
+ tokenizer=tokenizer,
+ )
+ pipeline.save_pretrained(args.output_dir)
+ # Save the newly trained embeddings
+ save_path = os.path.join(args.output_dir, "learned_embeds.bin")
+ save_progress(tokenizer, text_encoder, accelerator, save_path)
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
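The training script above saves the learned vectors to `learned_embeds.bin` via `save_progress`. As a rough inference sketch, they can be loaded back with `load_multitoken_tokenizer`; this assumes `runwayml/stable-diffusion-v1-5` as the base model, the default `text-inversion-model` output directory, a hypothetical `<cat-toy>` placeholder, and that the two scripts above are importable from the working directory:

```python
# Sketch only: the model id, output path and placeholder token are assumptions.
import torch
from transformers import CLIPTextModel

from diffusers import StableDiffusionPipeline
from multi_token_clip import MultiTokenCLIPTokenizer
from textual_inversion import load_multitoken_tokenizer

model_id = "runwayml/stable-diffusion-v1-5"
tokenizer = MultiTokenCLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")

# learned_embeds.bin maps {placeholder_token: (num_vec_per_token, hidden_dim) tensor};
# load_multitoken_tokenizer re-adds the placeholder tokens and copies the embeddings in.
learned_embeds_dict = torch.load("text-inversion-model/learned_embeds.bin", map_location="cpu")
load_multitoken_tokenizer(tokenizer, text_encoder, learned_embeds_dict)

pipe = StableDiffusionPipeline.from_pretrained(
    model_id, tokenizer=tokenizer, text_encoder=text_encoder
).to("cuda")

# Use the same placeholder token the embeddings were trained with.
image = pipe("a photo of a <cat-toy> on the beach").images[0]
image.save("cat-toy.png")
```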
diff --git a/diffusers/examples/research_projects/mulit_token_textual_inversion/textual_inversion_flax.py b/diffusers/examples/research_projects/mulit_token_textual_inversion/textual_inversion_flax.py
new file mode 100644
index 0000000000000000000000000000000000000000..ecc89f98298e3e4205581fee1689761c519bc4e4
--- /dev/null
+++ b/diffusers/examples/research_projects/mulit_token_textual_inversion/textual_inversion_flax.py
@@ -0,0 +1,654 @@
+import argparse
+import logging
+import math
+import os
+import random
+from pathlib import Path
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import optax
+import PIL
+import torch
+import torch.utils.checkpoint
+import transformers
+from flax import jax_utils
+from flax.training import train_state
+from flax.training.common_utils import shard
+from huggingface_hub import create_repo, upload_folder
+
+# TODO: remove and import from diffusers.utils when the new version of diffusers is released
+from packaging import version
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
+
+from diffusers import (
+ FlaxAutoencoderKL,
+ FlaxDDPMScheduler,
+ FlaxPNDMScheduler,
+ FlaxStableDiffusionPipeline,
+ FlaxUNet2DConditionModel,
+)
+from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker
+from diffusers.utils import check_min_version
+
+
+if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.Resampling.BILINEAR,
+ "bilinear": PIL.Image.Resampling.BILINEAR,
+ "bicubic": PIL.Image.Resampling.BICUBIC,
+ "lanczos": PIL.Image.Resampling.LANCZOS,
+ "nearest": PIL.Image.Resampling.NEAREST,
+ }
+else:
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.LINEAR,
+ "bilinear": PIL.Image.BILINEAR,
+ "bicubic": PIL.Image.BICUBIC,
+ "lanczos": PIL.Image.LANCZOS,
+ "nearest": PIL.Image.NEAREST,
+ }
+# ------------------------------------------------------------------------------
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.14.0.dev0")
+
+logger = logging.getLogger(__name__)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
+ )
+ parser.add_argument(
+ "--placeholder_token",
+ type=str,
+ default=None,
+ required=True,
+ help="A token to use as a placeholder for the concept.",
+ )
+ parser.add_argument(
+ "--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word."
+ )
+ parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
+ parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="text-inversion-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution."
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=5000,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=True,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument(
+ "--use_auth_token",
+ action="store_true",
+ help=(
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script with"
+ " private models)."
+ ),
+ )
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.train_data_dir is None:
+ raise ValueError("You must specify a train data directory.")
+
+ return args
+
+
+imagenet_templates_small = [
+ "a photo of a {}",
+ "a rendering of a {}",
+ "a cropped photo of the {}",
+ "the photo of a {}",
+ "a photo of a clean {}",
+ "a photo of a dirty {}",
+ "a dark photo of the {}",
+ "a photo of my {}",
+ "a photo of the cool {}",
+ "a close-up photo of a {}",
+ "a bright photo of the {}",
+ "a cropped photo of a {}",
+ "a photo of the {}",
+ "a good photo of the {}",
+ "a photo of one {}",
+ "a close-up photo of the {}",
+ "a rendition of the {}",
+ "a photo of the clean {}",
+ "a rendition of a {}",
+ "a photo of a nice {}",
+ "a good photo of a {}",
+ "a photo of the nice {}",
+ "a photo of the small {}",
+ "a photo of the weird {}",
+ "a photo of the large {}",
+ "a photo of a cool {}",
+ "a photo of a small {}",
+]
+
+imagenet_style_templates_small = [
+ "a painting in the style of {}",
+ "a rendering in the style of {}",
+ "a cropped painting in the style of {}",
+ "the painting in the style of {}",
+ "a clean painting in the style of {}",
+ "a dirty painting in the style of {}",
+ "a dark painting in the style of {}",
+ "a picture in the style of {}",
+ "a cool painting in the style of {}",
+ "a close-up painting in the style of {}",
+ "a bright painting in the style of {}",
+ "a cropped painting in the style of {}",
+ "a good painting in the style of {}",
+ "a close-up painting in the style of {}",
+ "a rendition in the style of {}",
+ "a nice painting in the style of {}",
+ "a small painting in the style of {}",
+ "a weird painting in the style of {}",
+ "a large painting in the style of {}",
+]
+
+
+class TextualInversionDataset(Dataset):
+ def __init__(
+ self,
+ data_root,
+ tokenizer,
+ learnable_property="object", # [object, style]
+ size=512,
+ repeats=100,
+ interpolation="bicubic",
+ flip_p=0.5,
+ set="train",
+ placeholder_token="*",
+ center_crop=False,
+ ):
+ self.data_root = data_root
+ self.tokenizer = tokenizer
+ self.learnable_property = learnable_property
+ self.size = size
+ self.placeholder_token = placeholder_token
+ self.center_crop = center_crop
+ self.flip_p = flip_p
+
+ self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
+
+ self.num_images = len(self.image_paths)
+ self._length = self.num_images
+
+ if set == "train":
+ self._length = self.num_images * repeats
+
+ self.interpolation = {
+ "linear": PIL_INTERPOLATION["linear"],
+ "bilinear": PIL_INTERPOLATION["bilinear"],
+ "bicubic": PIL_INTERPOLATION["bicubic"],
+ "lanczos": PIL_INTERPOLATION["lanczos"],
+ }[interpolation]
+
+ self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
+ self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, i):
+ example = {}
+ image = Image.open(self.image_paths[i % self.num_images])
+
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+
+ placeholder_string = self.placeholder_token
+ text = random.choice(self.templates).format(placeholder_string)
+
+ example["input_ids"] = self.tokenizer(
+ text,
+ padding="max_length",
+ truncation=True,
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ ).input_ids[0]
+
+ # default to score-sde preprocessing
+ img = np.array(image).astype(np.uint8)
+
+ if self.center_crop:
+ crop = min(img.shape[0], img.shape[1])
+ (
+ h,
+ w,
+ ) = (
+ img.shape[0],
+ img.shape[1],
+ )
+ img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
+
+ image = Image.fromarray(img)
+ image = image.resize((self.size, self.size), resample=self.interpolation)
+
+ image = self.flip_transform(image)
+ image = np.array(image).astype(np.uint8)
+ image = (image / 127.5 - 1.0).astype(np.float32)
+
+ example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
+ return example
+
+
+def resize_token_embeddings(model, new_num_tokens, initializer_token_id, placeholder_token_id, rng):
+ if model.config.vocab_size == new_num_tokens or new_num_tokens is None:
+ return
+ model.config.vocab_size = new_num_tokens
+
+ params = model.params
+ old_embeddings = params["text_model"]["embeddings"]["token_embedding"]["embedding"]
+ old_num_tokens, emb_dim = old_embeddings.shape
+
+ initializer = jax.nn.initializers.normal()
+
+ new_embeddings = initializer(rng, (new_num_tokens, emb_dim))
+ new_embeddings = new_embeddings.at[:old_num_tokens].set(old_embeddings)
+ new_embeddings = new_embeddings.at[placeholder_token_id].set(new_embeddings[initializer_token_id])
+ params["text_model"]["embeddings"]["token_embedding"]["embedding"] = new_embeddings
+
+ model.params = params
+ return model
+
+
+def get_params_to_save(params):
+ return jax.device_get(jax.tree_util.tree_map(lambda x: x[0], params))
+
+
+def main():
+ args = parse_args()
+
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ if jax.process_index() == 0:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ # Setup logging, we only want one process per machine to log things on the screen.
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
+ if jax.process_index() == 0:
+ transformers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+
+    # Load the tokenizer and add the placeholder token as an additional special token
+ if args.tokenizer_name:
+ tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+ elif args.pretrained_model_name_or_path:
+ tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
+
+ # Add the placeholder token in tokenizer
+ num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
+ if num_added_tokens == 0:
+ raise ValueError(
+ f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
+ " `placeholder_token` that is not already in the tokenizer."
+ )
+
+ # Convert the initializer_token, placeholder_token to ids
+ token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
+ # Check if initializer_token is a single token or a sequence of tokens
+ if len(token_ids) > 1:
+ raise ValueError("The initializer token must be a single token.")
+
+ initializer_token_id = token_ids[0]
+ placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
+
+ # Load models and create wrapper for stable diffusion
+ text_encoder = FlaxCLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
+ vae, vae_params = FlaxAutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
+ unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
+
+ # Create sampling rng
+ rng = jax.random.PRNGKey(args.seed)
+ rng, _ = jax.random.split(rng)
+ # Resize the token embeddings as we are adding new special tokens to the tokenizer
+ text_encoder = resize_token_embeddings(
+ text_encoder, len(tokenizer), initializer_token_id, placeholder_token_id, rng
+ )
+ original_token_embeds = text_encoder.params["text_model"]["embeddings"]["token_embedding"]["embedding"]
+
+ train_dataset = TextualInversionDataset(
+ data_root=args.train_data_dir,
+ tokenizer=tokenizer,
+ size=args.resolution,
+ placeholder_token=args.placeholder_token,
+ repeats=args.repeats,
+ learnable_property=args.learnable_property,
+ center_crop=args.center_crop,
+ set="train",
+ )
+
+ def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ input_ids = torch.stack([example["input_ids"] for example in examples])
+
+ batch = {"pixel_values": pixel_values, "input_ids": input_ids}
+ batch = {k: v.numpy() for k, v in batch.items()}
+
+ return batch
+
+ total_train_batch_size = args.train_batch_size * jax.local_device_count()
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=total_train_batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn
+ )
+
+ # Optimization
+ if args.scale_lr:
+ args.learning_rate = args.learning_rate * total_train_batch_size
+
+ constant_scheduler = optax.constant_schedule(args.learning_rate)
+
+ optimizer = optax.adamw(
+ learning_rate=constant_scheduler,
+ b1=args.adam_beta1,
+ b2=args.adam_beta2,
+ eps=args.adam_epsilon,
+ weight_decay=args.adam_weight_decay,
+ )
+
+ def create_mask(params, label_fn):
+ def _map(params, mask, label_fn):
+ for k in params:
+ if label_fn(k):
+ mask[k] = "token_embedding"
+ else:
+ if isinstance(params[k], dict):
+ mask[k] = {}
+ _map(params[k], mask[k], label_fn)
+ else:
+ mask[k] = "zero"
+
+ mask = {}
+ _map(params, mask, label_fn)
+ return mask
+
+ def zero_grads():
+ # from https://github.com/deepmind/optax/issues/159#issuecomment-896459491
+ def init_fn(_):
+ return ()
+
+ def update_fn(updates, state, params=None):
+ return jax.tree_util.tree_map(jnp.zeros_like, updates), ()
+
+ return optax.GradientTransformation(init_fn, update_fn)
+
+ # Zero out gradients of layers other than the token embedding layer
+ tx = optax.multi_transform(
+ {"token_embedding": optimizer, "zero": zero_grads()},
+ create_mask(text_encoder.params, lambda s: s == "token_embedding"),
+ )
+
+ state = train_state.TrainState.create(apply_fn=text_encoder.__call__, params=text_encoder.params, tx=tx)
+
+ noise_scheduler = FlaxDDPMScheduler(
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
+ )
+ noise_scheduler_state = noise_scheduler.create_state()
+
+ # Initialize our training
+ train_rngs = jax.random.split(rng, jax.local_device_count())
+
+ # Define gradient train step fn
+ def train_step(state, vae_params, unet_params, batch, train_rng):
+ dropout_rng, sample_rng, new_train_rng = jax.random.split(train_rng, 3)
+
+ def compute_loss(params):
+ vae_outputs = vae.apply(
+ {"params": vae_params}, batch["pixel_values"], deterministic=True, method=vae.encode
+ )
+ latents = vae_outputs.latent_dist.sample(sample_rng)
+ # (NHWC) -> (NCHW)
+ latents = jnp.transpose(latents, (0, 3, 1, 2))
+ latents = latents * vae.config.scaling_factor
+
+ noise_rng, timestep_rng = jax.random.split(sample_rng)
+ noise = jax.random.normal(noise_rng, latents.shape)
+ bsz = latents.shape[0]
+ timesteps = jax.random.randint(
+ timestep_rng,
+ (bsz,),
+ 0,
+ noise_scheduler.config.num_train_timesteps,
+ )
+ noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)
+ encoder_hidden_states = state.apply_fn(
+ batch["input_ids"], params=params, dropout_rng=dropout_rng, train=True
+ )[0]
+ # Predict the noise residual and compute loss
+ model_pred = unet.apply(
+ {"params": unet_params}, noisy_latents, timesteps, encoder_hidden_states, train=False
+ ).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ loss = (target - model_pred) ** 2
+ loss = loss.mean()
+
+ return loss
+
+ grad_fn = jax.value_and_grad(compute_loss)
+ loss, grad = grad_fn(state.params)
+ grad = jax.lax.pmean(grad, "batch")
+ new_state = state.apply_gradients(grads=grad)
+
+ # Keep the token embeddings fixed except the newly added embeddings for the concept,
+ # as we only want to optimize the concept embeddings
+ token_embeds = original_token_embeds.at[placeholder_token_id].set(
+ new_state.params["text_model"]["embeddings"]["token_embedding"]["embedding"][placeholder_token_id]
+ )
+ new_state.params["text_model"]["embeddings"]["token_embedding"]["embedding"] = token_embeds
+
+ metrics = {"loss": loss}
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
+ return new_state, metrics, new_train_rng
+
+ # Create parallel version of the train and eval step
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
+
+ # Replicate the train state on each device
+ state = jax_utils.replicate(state)
+ vae_params = jax_utils.replicate(vae_params)
+ unet_params = jax_utils.replicate(unet_params)
+
+ # Train!
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader))
+
+ # Scheduler and math around the number of training steps.
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+
+ global_step = 0
+
+ epochs = tqdm(range(args.num_train_epochs), desc=f"Epoch ... (1/{args.num_train_epochs})", position=0)
+ for epoch in epochs:
+ # ======================== Training ================================
+
+ train_metrics = []
+
+ steps_per_epoch = len(train_dataset) // total_train_batch_size
+ train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False)
+ # train
+ for batch in train_dataloader:
+ batch = shard(batch)
+ state, train_metric, train_rngs = p_train_step(state, vae_params, unet_params, batch, train_rngs)
+ train_metrics.append(train_metric)
+
+ train_step_progress_bar.update(1)
+ global_step += 1
+
+ if global_step >= args.max_train_steps:
+ break
+
+ train_metric = jax_utils.unreplicate(train_metric)
+
+ train_step_progress_bar.close()
+ epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})")
+
+    # Create the pipeline using the trained modules and save it.
+ if jax.process_index() == 0:
+ scheduler = FlaxPNDMScheduler(
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
+ )
+ safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained(
+ "CompVis/stable-diffusion-safety-checker", from_pt=True
+ )
+ pipeline = FlaxStableDiffusionPipeline(
+ text_encoder=text_encoder,
+ vae=vae,
+ unet=unet,
+ tokenizer=tokenizer,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32"),
+ )
+
+ pipeline.save_pretrained(
+ args.output_dir,
+ params={
+ "text_encoder": get_params_to_save(state.params),
+ "vae": get_params_to_save(vae_params),
+ "unet": get_params_to_save(unet_params),
+ "safety_checker": safety_checker.params,
+ },
+ )
+
+ # Also save the newly trained embeddings
+ learned_embeds = get_params_to_save(state.params)["text_model"]["embeddings"]["token_embedding"]["embedding"][
+ placeholder_token_id
+ ]
+ learned_embeds_dict = {args.placeholder_token: learned_embeds}
+ jnp.save(os.path.join(args.output_dir, "learned_embeds.npy"), learned_embeds_dict)
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/research_projects/multi_subject_dreambooth/README.md b/diffusers/examples/research_projects/multi_subject_dreambooth/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..5fff305f82be263fc08ffc222f003701d42dfe7f
--- /dev/null
+++ b/diffusers/examples/research_projects/multi_subject_dreambooth/README.md
@@ -0,0 +1,338 @@
+# Multi Subject DreamBooth training
+
+[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few (3-5) images of a subject.
+This `train_multi_subject_dreambooth.py` script shows how to implement the training procedure for one or more subjects and adapt it for stable diffusion. Note that this code is based on the `examples/dreambooth/train_dreambooth.py` script as of 01/06/2022.
+
+This script was added by @kopsahlong, and is not actively maintained. However, if you come across anything that could use fixing, feel free to open an issue and tag @kopsahlong.
+
+## Running locally with PyTorch
+### Installing the dependencies
+
+Before running the script, make sure to install the library's training dependencies:
+
+To start, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+Then cd into the folder `diffusers/examples/research_projects/multi_subject_dreambooth` and run the following:
+```bash
+pip install -r requirements.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or for a default accelerate configuration without answering questions about your environment
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell e.g. a notebook
+
+```python
+from accelerate.utils import write_basic_config
+write_basic_config()
+```
+
+### Multi Subject Training Example
+In order to have your model learn multiple concepts at once, we simply add the additional data directories and prompts to our `instance_data_dir` and `instance_prompt` (as well as `class_data_dir` and `class_prompt` if `--with_prior_preservation` is specified) as one comma-separated string.
+
+See an example with 2 subjects below, which learns a model for one dog subject and one human subject:
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export OUTPUT_DIR="path-to-save-model"
+
+# Subject 1
+export INSTANCE_DIR_1="path-to-instance-images-concept-1"
+export INSTANCE_PROMPT_1="a photo of a sks dog"
+export CLASS_DIR_1="path-to-class-images-dog"
+export CLASS_PROMPT_1="a photo of a dog"
+
+# Subject 2
+export INSTANCE_DIR_2="path-to-instance-images-concept-2"
+export INSTANCE_PROMPT_2="a photo of a t@y person"
+export CLASS_DIR_2="path-to-class-images-person"
+export CLASS_PROMPT_2="a photo of a person"
+
+accelerate launch train_multi_subject_dreambooth.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir="$INSTANCE_DIR_1,$INSTANCE_DIR_2" \
+ --output_dir=$OUTPUT_DIR \
+ --train_text_encoder \
+ --instance_prompt="$INSTANCE_PROMPT_1,$INSTANCE_PROMPT_2" \
+ --with_prior_preservation \
+ --prior_loss_weight=1.0 \
+ --class_data_dir="$CLASS_DIR_1,$CLASS_DIR_2" \
+ --class_prompt="$CLASS_PROMPT_1,$CLASS_PROMPT_2"\
+ --num_class_images=50 \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=1 \
+ --learning_rate=1e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --max_train_steps=1500
+```
+
+This example shows training for 2 subjects, but please note that the model can be trained on any number of new concepts. This can be done by appending the corresponding directories and prompts to the comma-separated strings, as sketched below.
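+
+For instance, a third subject could be added by extending the same comma-separated strings; the paths and the `z@q` token below are hypothetical placeholders shown only to illustrate the pattern (if `--with_prior_preservation` is used, extend `class_data_dir` and `class_prompt` in the same way):
+
+```bash
+# Subject 3 (hypothetical example values)
+export INSTANCE_DIR_3="path-to-instance-images-concept-3"
+export INSTANCE_PROMPT_3="a photo of a z@q cat"
+
+accelerate launch train_multi_subject_dreambooth.py \
+  --pretrained_model_name_or_path=$MODEL_NAME \
+  --instance_data_dir="$INSTANCE_DIR_1,$INSTANCE_DIR_2,$INSTANCE_DIR_3" \
+  --instance_prompt="$INSTANCE_PROMPT_1,$INSTANCE_PROMPT_2,$INSTANCE_PROMPT_3" \
+  --output_dir=$OUTPUT_DIR \
+  --resolution=512 \
+  --train_batch_size=1 \
+  --learning_rate=1e-6 \
+  --lr_scheduler="constant" \
+  --lr_warmup_steps=0 \
+  --max_train_steps=1500
+```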
+
+Note also that in this script, `sks` and `t@y` were used as tokens to learn the new subjects ([this thread](https://github.com/XavierXiao/Dreambooth-Stable-Diffusion/issues/71) inspired the use of `t@y` as our second identifier). However, there may be better rare tokens to experiment with, and results also seemed to be good when more intuitive words are used.
+
+**Important**: New parameters have been added to the script, making it possible to validate training progress by
+generating images at specified steps. Also, because packing several prompts into a single comma-separated text field is
+error-prone (commas very commonly appear inside regular prompt text), we
+introduce the `concepts_list` parameter: it lets you point to a JSON file where you define the
+configuration for each subject that you want to train.
+
+An example of how to generate the file:
+```python
+import json
+
+# here we are using parameters for prior-preservation and validation as well.
+concepts_list = [
+ {
+ "instance_prompt": "drawing of a t@y meme",
+ "class_prompt": "drawing of a meme",
+ "instance_data_dir": "/some_folder/meme_toy",
+ "class_data_dir": "/data/meme",
+ "validation_prompt": "drawing of a t@y meme about football in Uruguay",
+ "validation_negative_prompt": "black and white"
+ },
+ {
+ "instance_prompt": "drawing of a sks sir",
+ "class_prompt": "drawing of a sir",
+ "instance_data_dir": "/some_other_folder/sir_sks",
+ "class_data_dir": "/data/sir",
+ "validation_prompt": "drawing of a sks sir with the Uruguayan sun in his chest",
+ "validation_negative_prompt": "an old man",
+ "validation_guidance_scale": 20,
+ "validation_number_images": 3,
+ "validation_inference_steps": 10
+ }
+]
+
+with open("concepts_list.json", "w") as f:
+ json.dump(concepts_list, f, indent=4)
+```
+And then just point to the file when executing the script:
+
+```bash
+# exports...
+accelerate launch train_multi_subject_dreambooth.py \
+# more parameters...
+--concepts_list="concepts_list.json"
+```
+
+You can use the script's built-in help to get a better sense of each parameter.
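+
+For example, the full list of arguments and their descriptions is printed by the help that `argparse` generates automatically:
+
+```bash
+python train_multi_subject_dreambooth.py --help
+```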
+
+### Inference
+
+Once you have trained a model using the above command, inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the identifier (e.g. `sks` in the above example) in your prompt.
+
+```python
+from diffusers import StableDiffusionPipeline
+import torch
+
+model_id = "path-to-your-trained-model"
+pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+
+prompt = "A photo of a t@y person petting an sks dog"
+image = pipe(prompt, num_inference_steps=200, guidance_scale=7.5).images[0]
+
+image.save("person-petting-dog.png")
+```
+
+### Inference from a training checkpoint
+
+You can also perform inference from one of the checkpoints saved during the training process, if you used the `--checkpointing_steps` argument. Please refer to [the documentation](https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint) to see how to do it.
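+
+The training script here also exposes a `--resume_from_checkpoint` argument, so a run can be continued from one of these checkpoints; pass either a checkpoint folder name such as `checkpoint-500`, or `"latest"` to pick the most recent one. A minimal sketch, following the abbreviated command form used earlier in this README:
+
+```bash
+# exports...
+accelerate launch train_multi_subject_dreambooth.py \
+# same parameters as the original run...
+--resume_from_checkpoint="latest"
+```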
+
+## Additional Dreambooth documentation
+Because the `train_multi_subject_dreambooth.py` script here was forked from an original version of `train_dreambooth.py` in the `examples/dreambooth` folder, I've included the original applicable training documentation for single subject examples below.
+
+This should explain how to play with training variables such as prior preservation, fine-tuning the text encoder, etc., which are still applicable to our multi-subject training code. Note also that the examples below, which are single-subject examples, also work with `train_multi_subject_dreambooth.py`, as this script supports 1 (or more) subjects.
+
+### Single subject dog toy example
+
+Let's get our dataset. Download images from [here](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ) and save them in a directory. This will be our training data.
+
+And launch the training using
+
+**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export INSTANCE_DIR="path-to-instance-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_dreambooth.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --instance_prompt="a photo of sks dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=1 \
+ --learning_rate=5e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --max_train_steps=400
+```
+
+### Training with prior-preservation loss
+
+Prior-preservation is used to avoid overfitting and language-drift. Refer to the paper to learn more about it. For prior-preservation we first generate images using the model with a class prompt and then use those during training along with our data.
+According to the paper, it's recommended to generate `num_epochs * num_samples` images for prior-preservation. 200-300 works well for most cases. The `num_class_images` flag sets the number of images to generate with the class prompt. You can place existing images in `class_data_dir`, and the training script will generate any additional images so that `num_class_images` are present in `class_data_dir` during training time.
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export INSTANCE_DIR="path-to-instance-images"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_dreambooth.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --class_data_dir=$CLASS_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --with_prior_preservation --prior_loss_weight=1.0 \
+ --instance_prompt="a photo of sks dog" \
+ --class_prompt="a photo of dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=1 \
+ --learning_rate=5e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --num_class_images=200 \
+ --max_train_steps=800
+```
+
+
+### Training on a 16GB GPU:
+
+With the help of gradient checkpointing and the 8-bit optimizer from bitsandbytes, it's possible to train DreamBooth on a 16GB GPU.
+
+To install `bitsandbytes`, please refer to this [readme](https://github.com/TimDettmers/bitsandbytes#requirements--installation).
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export INSTANCE_DIR="path-to-instance-images"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_dreambooth.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --class_data_dir=$CLASS_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --with_prior_preservation --prior_loss_weight=1.0 \
+ --instance_prompt="a photo of sks dog" \
+ --class_prompt="a photo of dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=2 --gradient_checkpointing \
+ --use_8bit_adam \
+ --learning_rate=5e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --num_class_images=200 \
+ --max_train_steps=800
+```
+
+### Training on an 8 GB GPU:
+
+By using [DeepSpeed](https://www.deepspeed.ai/) it's possible to offload some
+tensors from VRAM to either CPU or NVMe, allowing training with less VRAM.
+
+DeepSpeed needs to be enabled with `accelerate config`. During configuration,
+answer yes to "Do you want to use DeepSpeed?". With DeepSpeed stage 2, fp16
+mixed precision, and offloading of both parameters and optimizer state to CPU, it's
+possible to train on under 8 GB of VRAM, with the drawback of requiring significantly
+more system RAM (about 25 GB). See the [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more DeepSpeed configuration options.
+
+Changing the default Adam optimizer to DeepSpeed's optimized version of Adam,
+`deepspeed.ops.adam.DeepSpeedCPUAdam`, gives a substantial speedup, but enabling
+it requires a CUDA toolchain with the same version as PyTorch. The 8-bit optimizer
+does not seem to be compatible with DeepSpeed at the moment.
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export INSTANCE_DIR="path-to-instance-images"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch --mixed_precision="fp16" train_dreambooth.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --class_data_dir=$CLASS_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --with_prior_preservation --prior_loss_weight=1.0 \
+ --instance_prompt="a photo of sks dog" \
+ --class_prompt="a photo of dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --sample_batch_size=1 \
+ --gradient_accumulation_steps=1 --gradient_checkpointing \
+ --learning_rate=5e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --num_class_images=200 \
+ --max_train_steps=800
+```
+
+### Fine-tune text encoder with the UNet.
+
+The script also allows you to fine-tune the `text_encoder` along with the `unet`. It's been observed experimentally that fine-tuning the `text_encoder` gives much better results, especially on faces.
+Pass the `--train_text_encoder` argument to the script to enable training the `text_encoder`.
+
+___Note: Training the text encoder requires more memory; with this option, the training won't fit on a 16GB GPU. It needs at least 24GB of VRAM.___
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export INSTANCE_DIR="path-to-instance-images"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_dreambooth.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_text_encoder \
+ --instance_data_dir=$INSTANCE_DIR \
+ --class_data_dir=$CLASS_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --with_prior_preservation --prior_loss_weight=1.0 \
+ --instance_prompt="a photo of sks dog" \
+ --class_prompt="a photo of dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --use_8bit_adam \
+ --gradient_checkpointing \
+ --learning_rate=2e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --num_class_images=200 \
+ --max_train_steps=800
+```
+
+### Using DreamBooth for other pipelines than Stable Diffusion
+
+AltDiffusion also supports DreamBooth now; the training command is basically the same as above.
+All you need to do is change the `pretrained_model_name_or_path` to another architecture such as [`AltDiffusion`](https://huggingface.co/docs/diffusers/api/pipelines/alt_diffusion), like this:
+
+```
+export MODEL_NAME="CompVis/stable-diffusion-v1-4" --> export MODEL_NAME="BAAI/AltDiffusion-m9"
+or
+export MODEL_NAME="CompVis/stable-diffusion-v1-4" --> export MODEL_NAME="BAAI/AltDiffusion"
+```
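+
+A complete command then looks like the single-subject example above with only the model name swapped (a sketch; adjust the paths and hyperparameters to your setup as in the earlier examples):
+
+```bash
+export MODEL_NAME="BAAI/AltDiffusion-m9"
+export INSTANCE_DIR="path-to-instance-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_dreambooth.py \
+  --pretrained_model_name_or_path=$MODEL_NAME \
+  --instance_data_dir=$INSTANCE_DIR \
+  --output_dir=$OUTPUT_DIR \
+  --instance_prompt="a photo of sks dog" \
+  --resolution=512 \
+  --train_batch_size=1 \
+  --gradient_accumulation_steps=1 \
+  --learning_rate=5e-6 \
+  --lr_scheduler="constant" \
+  --lr_warmup_steps=0 \
+  --max_train_steps=400
+```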
+
+### Training with xformers:
+You can enable memory-efficient attention by [installing xFormers](https://github.com/facebookresearch/xformers#installing-xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script, as shown below. This is not available with the Flax/JAX implementation.
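+
+For instance, following the abbreviated command form used earlier in this README, the flag is simply appended to the launch command (all other arguments unchanged):
+
+```bash
+# exports...
+accelerate launch train_multi_subject_dreambooth.py \
+# more parameters...
+--enable_xformers_memory_efficient_attention
+```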
+
+You can also use Dreambooth to train the specialized in-painting model. See [the script in the research folder for details](https://github.com/huggingface/diffusers/tree/main/examples/research_projects/dreambooth_inpaint).
\ No newline at end of file
diff --git a/diffusers/examples/research_projects/multi_subject_dreambooth/requirements.txt b/diffusers/examples/research_projects/multi_subject_dreambooth/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e19b0ce60bf407bec21a9b85f9232cad957bfa6f
--- /dev/null
+++ b/diffusers/examples/research_projects/multi_subject_dreambooth/requirements.txt
@@ -0,0 +1,6 @@
+accelerate>=0.16.0
+torchvision
+transformers>=4.25.1
+ftfy
+tensorboard
+Jinja2
\ No newline at end of file
diff --git a/diffusers/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py b/diffusers/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py
new file mode 100644
index 0000000000000000000000000000000000000000..d58c4009b69a4cc6c817f084040d21a3f9570be6
--- /dev/null
+++ b/diffusers/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py
@@ -0,0 +1,1185 @@
+import argparse
+import itertools
+import json
+import logging
+import math
+import uuid
+import warnings
+from os import environ, listdir, makedirs
+from os.path import basename, join
+from pathlib import Path
+from typing import List
+
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+from huggingface_hub.utils import insecure_hashlib
+from PIL import Image
+from torch import dtype
+from torch.nn import Module
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ DiffusionPipeline,
+ DPMSolverMultistepScheduler,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.13.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def log_validation_images_to_tracker(
+ images: List[np.array], label: str, validation_prompt: str, accelerator: Accelerator, epoch: int
+):
+ logger.info(f"Logging images to tracker for validation prompt: {validation_prompt}.")
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{label}_{epoch}_{i}: {validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+
+# TODO: Add `prompt_embeds` and `negative_prompt_embeds` parameters to the function when `pre_compute_text_embeddings`
+# argument is implemented.
+def generate_validation_images(
+ text_encoder: Module,
+ tokenizer: Module,
+ unet: Module,
+ vae: Module,
+ arguments: argparse.Namespace,
+ accelerator: Accelerator,
+ weight_dtype: dtype,
+):
+ logger.info("Running validation images.")
+
+ pipeline_args = {}
+
+ if text_encoder is not None:
+ pipeline_args["text_encoder"] = accelerator.unwrap_model(text_encoder)
+
+ if vae is not None:
+ pipeline_args["vae"] = vae
+
+ # create pipeline (note: unet and vae are loaded again in float32)
+ pipeline = DiffusionPipeline.from_pretrained(
+ arguments.pretrained_model_name_or_path,
+ tokenizer=tokenizer,
+ unet=accelerator.unwrap_model(unet),
+ revision=arguments.revision,
+ torch_dtype=weight_dtype,
+ **pipeline_args,
+ )
+
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the
+ # scheduler to ignore it
+ scheduler_args = {}
+
+ if "variance_type" in pipeline.scheduler.config:
+ variance_type = pipeline.scheduler.config.variance_type
+
+ if variance_type in ["learned", "learned_range"]:
+ variance_type = "fixed_small"
+
+ scheduler_args["variance_type"] = variance_type
+
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ generator = (
+ None if arguments.seed is None else torch.Generator(device=accelerator.device).manual_seed(arguments.seed)
+ )
+
+ images_sets = []
+ for vp, nvi, vnp, vis, vgs in zip(
+ arguments.validation_prompt,
+ arguments.validation_number_images,
+ arguments.validation_negative_prompt,
+ arguments.validation_inference_steps,
+ arguments.validation_guidance_scale,
+ ):
+ images = []
+ if vp is not None:
+ logger.info(
+ f"Generating {nvi} images with prompt: '{vp}', negative prompt: '{vnp}', inference steps: {vis}, "
+ f"guidance scale: {vgs}."
+ )
+
+ pipeline_args = {"prompt": vp, "negative_prompt": vnp, "num_inference_steps": vis, "guidance_scale": vgs}
+
+ # run inference
+ # TODO: it would be good to measure whether it's faster to run inference on all images at once, one at a
+ # time or in small batches
+ for _ in range(nvi):
+ with torch.autocast("cuda"):
+ image = pipeline(**pipeline_args, num_images_per_prompt=1, generator=generator).images[0]
+ images.append(image)
+
+ images_sets.append(images)
+
+ del pipeline
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ return images_sets
+
+
+def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ revision=revision,
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "RobertaSeriesModelWithTransformation":
+ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
+
+ return RobertaSeriesModelWithTransformation
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of instance images.",
+ )
+ parser.add_argument(
+ "--class_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of class images.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default=None,
+ required=False,
+ help="The prompt with identifier specifying the instance",
+ )
+ parser.add_argument(
+ "--class_prompt",
+ type=str,
+ default=None,
+ help="The prompt to specify images in the same class as provided instance images.",
+ )
+ parser.add_argument(
+ "--with_prior_preservation",
+ default=False,
+ action="store_true",
+ help="Flag to add prior preservation loss.",
+ )
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+ parser.add_argument(
+ "--num_class_images",
+ type=int,
+ default=100,
+ help=(
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
+ " class_data_dir, additional images will be sampled with class_prompt."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="text-inversion-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=(
+ "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
+ " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
+ " for more docs"
+ ),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-6,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=None,
+ help=(
+ "Run validation every X steps. Validation consists of running the prompt(s) `validation_prompt` "
+ "multiple times (`validation_number_images`) and logging the images."
+ ),
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning. You can use commas to "
+ "define multiple negative prompts. This parameter can be defined also within the file given by "
+ "`concepts_list` parameter in the respective subject.",
+ )
+ parser.add_argument(
+ "--validation_number_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with the validation parameters given. This "
+ "can be defined within the file given by `concepts_list` parameter in the respective subject.",
+ )
+ parser.add_argument(
+ "--validation_negative_prompt",
+ type=str,
+ default=None,
+ help="A negative prompt that is used during validation to verify that the model is learning. You can use commas"
+ " to define multiple negative prompts, each one corresponding to a validation prompt. This parameter can "
+ "be defined also within the file given by `concepts_list` parameter in the respective subject.",
+ )
+ parser.add_argument(
+ "--validation_inference_steps",
+ type=int,
+ default=25,
+ help="Number of inference steps (denoising steps) to run during validation. This can be defined within the "
+ "file given by `concepts_list` parameter in the respective subject.",
+ )
+ parser.add_argument(
+ "--validation_guidance_scale",
+ type=float,
+ default=7.5,
+ help="To control how much the image generation process follows the text prompt. This can be defined within the "
+ "file given by `concepts_list` parameter in the respective subject.",
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--prior_generation_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp32", "fp16", "bf16"],
+ help=(
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--set_grads_to_none",
+ action="store_true",
+ help=(
+ "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
+ " behaviors, so disable this argument if it causes any problems. More info:"
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
+ ),
+ )
+ parser.add_argument(
+ "--concepts_list",
+ type=str,
+ default=None,
+ help="Path to json file containing a list of multiple concepts, will overwrite parameters like instance_prompt,"
+ " class_prompt, etc.",
+ )
+
+ if input_args:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if not args.concepts_list and (not args.instance_data_dir or not args.instance_prompt):
+ raise ValueError(
+ "You must specify either instance parameters (data directory, prompt, etc.) or use "
+ "the `concept_list` parameter and specify them within the file."
+ )
+
+ if args.concepts_list:
+ if args.instance_prompt:
+ raise ValueError("If you are using `concepts_list` parameter, define the instance prompt within the file.")
+ if args.instance_data_dir:
+ raise ValueError(
+ "If you are using `concepts_list` parameter, define the instance data directory within the file."
+ )
+ if args.validation_steps and (args.validation_prompt or args.validation_negative_prompt):
+ raise ValueError(
+ "If you are using `concepts_list` parameter, define validation parameters for "
+ "each subject within the file:\n - `validation_prompt`."
+ "\n - `validation_negative_prompt`.\n - `validation_guidance_scale`."
+ "\n - `validation_number_images`.\n - `validation_prompt`."
+ "\n - `validation_inference_steps`.\nThe `validation_steps` parameter is the only one "
+ "that needs to be defined outside the file."
+ )
+
+ env_local_rank = int(environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.with_prior_preservation:
+ if not args.concepts_list:
+ if not args.class_data_dir:
+ raise ValueError("You must specify a data directory for class images.")
+ if not args.class_prompt:
+ raise ValueError("You must specify prompt for class images.")
+ else:
+ if args.class_data_dir:
+ raise ValueError(
+ "If you are using `concepts_list` parameter, define the class data directory within the file."
+ )
+ if args.class_prompt:
+ raise ValueError(
+ "If you are using `concepts_list` parameter, define the class prompt within the file."
+ )
+ else:
+ # logger is not available yet
+ if not args.class_data_dir:
+ warnings.warn(
+ "Ignoring `class_data_dir` parameter, you need to use it together with `with_prior_preservation`."
+ )
+ if not args.class_prompt:
+ warnings.warn(
+ "Ignoring `class_prompt` parameter, you need to use it together with `with_prior_preservation`."
+ )
+
+ return args
+
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+ It pre-processes the images and then tokenizes prompts.
+ """
+
+ def __init__(
+ self,
+ instance_data_root,
+ instance_prompt,
+ tokenizer,
+ class_data_root=None,
+ class_prompt=None,
+ size=512,
+ center_crop=False,
+ ):
+ self.size = size
+ self.center_crop = center_crop
+ self.tokenizer = tokenizer
+
+ self.instance_data_root = []
+ self.instance_images_path = []
+ self.num_instance_images = []
+ self.instance_prompt = []
+ self.class_data_root = [] if class_data_root is not None else None
+ self.class_images_path = []
+ self.num_class_images = []
+ self.class_prompt = []
+ self._length = 0
+
+ for i in range(len(instance_data_root)):
+ self.instance_data_root.append(Path(instance_data_root[i]))
+ if not self.instance_data_root[i].exists():
+ raise ValueError("Instance images root doesn't exists.")
+
+ self.instance_images_path.append(list(Path(instance_data_root[i]).iterdir()))
+ self.num_instance_images.append(len(self.instance_images_path[i]))
+ self.instance_prompt.append(instance_prompt[i])
+ self._length += self.num_instance_images[i]
+
+ if class_data_root is not None:
+ self.class_data_root.append(Path(class_data_root[i]))
+ self.class_data_root[i].mkdir(parents=True, exist_ok=True)
+ self.class_images_path.append(list(self.class_data_root[i].iterdir()))
+                self.num_class_images.append(len(self.class_images_path[i]))
+ if self.num_class_images[i] > self.num_instance_images[i]:
+ self._length -= self.num_instance_images[i]
+ self._length += self.num_class_images[i]
+ self.class_prompt.append(class_prompt[i])
+
+ self.image_transforms = transforms.Compose(
+ [
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ for i in range(len(self.instance_images_path)):
+ instance_image = Image.open(self.instance_images_path[i][index % self.num_instance_images[i]])
+ if not instance_image.mode == "RGB":
+ instance_image = instance_image.convert("RGB")
+ example[f"instance_images_{i}"] = self.image_transforms(instance_image)
+ example[f"instance_prompt_ids_{i}"] = self.tokenizer(
+ self.instance_prompt[i],
+ truncation=True,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ ).input_ids
+
+ if self.class_data_root:
+ for i in range(len(self.class_data_root)):
+ class_image = Image.open(self.class_images_path[i][index % self.num_class_images[i]])
+ if not class_image.mode == "RGB":
+ class_image = class_image.convert("RGB")
+ example[f"class_images_{i}"] = self.image_transforms(class_image)
+ example[f"class_prompt_ids_{i}"] = self.tokenizer(
+ self.class_prompt[i],
+ truncation=True,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ ).input_ids
+
+ return example
+
+
+def collate_fn(num_instances, examples, with_prior_preservation=False):
+ input_ids = []
+ pixel_values = []
+
+ for i in range(num_instances):
+ input_ids += [example[f"instance_prompt_ids_{i}"] for example in examples]
+ pixel_values += [example[f"instance_images_{i}"] for example in examples]
+
+ # Concat class and instance examples for prior preservation.
+ # We do this to avoid doing two forward passes.
+ if with_prior_preservation:
+ for i in range(num_instances):
+ input_ids += [example[f"class_prompt_ids_{i}"] for example in examples]
+ pixel_values += [example[f"class_images_{i}"] for example in examples]
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ input_ids = torch.cat(input_ids, dim=0)
+
+ batch = {
+ "input_ids": input_ids,
+ "pixel_values": pixel_values,
+ }
+ return batch
+
+
+class PromptDataset(Dataset):
+ """A simple dataset to prepare the prompts to generate class images on multiple GPUs."""
+
+ def __init__(self, prompt, num_samples):
+ self.prompt = prompt
+ self.num_samples = num_samples
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ example = {}
+ example["prompt"] = self.prompt
+ example["index"] = index
+ return example
+
+
+def main(args):
+ logging_dir = Path(args.output_dir, args.logging_dir)
+ accelerator_project_config = ProjectConfiguration(
+ total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
+ )
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
+ # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
+ # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
+ if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
+ raise ValueError(
+ "Gradient accumulation is not supported when training the text encoder in distributed training. "
+ "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
+ )
+
+ instance_data_dir = []
+ instance_prompt = []
+ class_data_dir = [] if args.with_prior_preservation else None
+ class_prompt = [] if args.with_prior_preservation else None
+ if args.concepts_list:
+ with open(args.concepts_list, "r") as f:
+ concepts_list = json.load(f)
+
+ if args.validation_steps:
+ args.validation_prompt = []
+ args.validation_number_images = []
+ args.validation_negative_prompt = []
+ args.validation_inference_steps = []
+ args.validation_guidance_scale = []
+
+ for concept in concepts_list:
+ instance_data_dir.append(concept["instance_data_dir"])
+ instance_prompt.append(concept["instance_prompt"])
+
+ if args.with_prior_preservation:
+ try:
+ class_data_dir.append(concept["class_data_dir"])
+ class_prompt.append(concept["class_prompt"])
+ except KeyError:
+ raise KeyError(
+ "`class_data_dir` or `class_prompt` not found in concepts_list while using "
+ "`with_prior_preservation`."
+ )
+ else:
+ if "class_data_dir" in concept:
+ warnings.warn(
+ "Ignoring `class_data_dir` key, to use it you need to enable `with_prior_preservation`."
+ )
+ if "class_prompt" in concept:
+ warnings.warn(
+ "Ignoring `class_prompt` key, to use it you need to enable `with_prior_preservation`."
+ )
+
+ if args.validation_steps:
+ args.validation_prompt.append(concept.get("validation_prompt", None))
+ args.validation_number_images.append(concept.get("validation_number_images", 4))
+ args.validation_negative_prompt.append(concept.get("validation_negative_prompt", None))
+ args.validation_inference_steps.append(concept.get("validation_inference_steps", 25))
+ args.validation_guidance_scale.append(concept.get("validation_guidance_scale", 7.5))
+ else:
+ # Parse instance and class inputs, and double check that lengths match
+ instance_data_dir = args.instance_data_dir.split(",")
+ instance_prompt = args.instance_prompt.split(",")
+ assert all(
+ x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)]
+ ), "Instance data dir and prompt inputs are not of the same length."
+
+ if args.with_prior_preservation:
+ class_data_dir = args.class_data_dir.split(",")
+ class_prompt = args.class_prompt.split(",")
+ assert all(
+ x == len(instance_data_dir)
+ for x in [len(instance_data_dir), len(instance_prompt), len(class_data_dir), len(class_prompt)]
+ ), "Instance & class data dir or prompt inputs are not of the same length."
+
+ if args.validation_steps:
+ validation_prompts = args.validation_prompt.split(",")
+ num_of_validation_prompts = len(validation_prompts)
+ args.validation_prompt = validation_prompts
+ args.validation_number_images = [args.validation_number_images] * num_of_validation_prompts
+
+ negative_validation_prompts = [None] * num_of_validation_prompts
+ if args.validation_negative_prompt:
+ negative_validation_prompts = args.validation_negative_prompt.split(",")
+ while len(negative_validation_prompts) < num_of_validation_prompts:
+ negative_validation_prompts.append(None)
+ args.validation_negative_prompt = negative_validation_prompts
+
+ assert num_of_validation_prompts == len(
+ negative_validation_prompts
+ ), "The length of negative prompts for validation is greater than the number of validation prompts."
+ args.validation_inference_steps = [args.validation_inference_steps] * num_of_validation_prompts
+ args.validation_guidance_scale = [args.validation_guidance_scale] * num_of_validation_prompts
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Generate class images if prior preservation is enabled.
+ if args.with_prior_preservation:
+ for i in range(len(class_data_dir)):
+ class_images_dir = Path(class_data_dir[i])
+ if not class_images_dir.exists():
+ class_images_dir.mkdir(parents=True)
+ cur_class_images = len(list(class_images_dir.iterdir()))
+
+ if cur_class_images < args.num_class_images:
+ torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+ if args.prior_generation_precision == "fp32":
+ torch_dtype = torch.float32
+ elif args.prior_generation_precision == "fp16":
+ torch_dtype = torch.float16
+ elif args.prior_generation_precision == "bf16":
+ torch_dtype = torch.bfloat16
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ torch_dtype=torch_dtype,
+ safety_checker=None,
+ revision=args.revision,
+ )
+ pipeline.set_progress_bar_config(disable=True)
+
+ num_new_images = args.num_class_images - cur_class_images
+ logger.info(f"Number of class images to sample: {num_new_images}.")
+
+ sample_dataset = PromptDataset(class_prompt[i], num_new_images)
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+ sample_dataloader = accelerator.prepare(sample_dataloader)
+ pipeline.to(accelerator.device)
+
+ for example in tqdm(
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
+ ):
+ images = pipeline(example["prompt"]).images
+
+ for ii, image in enumerate(images):
+ hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ image_filename = (
+ class_images_dir / f"{example['index'][ii] + cur_class_images}-{hash_image}.jpg"
+ )
+ image.save(image_filename)
+
+ # Clean up the memory deleting one-time-use variables.
+ del pipeline
+ del sample_dataloader
+ del sample_dataset
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizer
+ tokenizer = None
+ if args.tokenizer_name:
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
+ elif args.pretrained_model_name_or_path:
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ use_fast=False,
+ )
+
+ # import correct text encoder class
+ text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder = text_encoder_cls.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+ )
+
+ vae.requires_grad_(False)
+ if not args.train_text_encoder:
+ text_encoder.requires_grad_(False)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+ if args.train_text_encoder:
+ text_encoder.gradient_checkpointing_enable()
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimizer creation
+ params_to_optimize = (
+ itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
+ )
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Dataset and DataLoaders creation:
+ train_dataset = DreamBoothDataset(
+ instance_data_root=instance_data_dir,
+ instance_prompt=instance_prompt,
+ class_data_root=class_data_dir,
+ class_prompt=class_prompt,
+ tokenizer=tokenizer,
+ size=args.resolution,
+ center_crop=args.center_crop,
+ )
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ collate_fn=lambda examples: collate_fn(len(instance_data_dir), examples, args.with_prior_preservation),
+ num_workers=1,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ if args.train_text_encoder:
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+ )
+ else:
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move vae and text_encoder to device and cast to weight_dtype
+ vae.to(accelerator.device, dtype=weight_dtype)
+ if not args.train_text_encoder:
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("dreambooth", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = basename(args.resume_from_checkpoint)
+ else:
+            # Get the most recent checkpoint
+ dirs = listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ resume_global_step = global_step * args.gradient_accumulation_steps
+ first_epoch = global_step // num_update_steps_per_epoch
+ resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
+ progress_bar.set_description("Steps")
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ unet.train()
+ if args.train_text_encoder:
+ text_encoder.train()
+ for step, batch in enumerate(train_dataloader):
+ # Skip steps until we reach the resumed step
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
+ if step % args.gradient_accumulation_steps == 0:
+ progress_bar.update(1)
+ continue
+
+ with accelerator.accumulate(unet):
+ # Convert images to latent space
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
+ latents = latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ time_steps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
+ )
+ time_steps = time_steps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, time_steps)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+ # Predict the noise residual
+ model_pred = unet(noisy_latents, time_steps, encoder_hidden_states).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, time_steps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ if args.with_prior_preservation:
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+ target, target_prior = torch.chunk(target, 2, dim=0)
+
+ # Compute instance loss
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+ # Compute prior loss
+ prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+
+ # Add the prior loss to the instance loss.
+ loss = loss + args.prior_loss_weight * prior_loss
+ else:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = (
+ itertools.chain(unet.parameters(), text_encoder.parameters())
+ if args.train_text_encoder
+ else unet.parameters()
+ )
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=args.set_grads_to_none)
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ save_path = join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if (
+ args.validation_steps
+ and any(args.validation_prompt)
+ and global_step % args.validation_steps == 0
+ ):
+ images_set = generate_validation_images(
+ text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype
+ )
+ for images, validation_prompt in zip(images_set, args.validation_prompt):
+ if len(images) > 0:
+ label = str(uuid.uuid1())[:8] # generate an id for each set of images
+ log_validation_images_to_tracker(
+ images, label, validation_prompt, accelerator, global_step
+ )
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ unet=accelerator.unwrap_model(unet),
+ text_encoder=accelerator.unwrap_model(text_encoder),
+ revision=args.revision,
+ )
+ pipeline.save_pretrained(args.output_dir)
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/research_projects/onnxruntime/README.md b/diffusers/examples/research_projects/onnxruntime/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..204d9c951c996fedabc169d9a32781be9f4c4cc1
--- /dev/null
+++ b/diffusers/examples/research_projects/onnxruntime/README.md
@@ -0,0 +1,5 @@
+## Diffusers examples with ONNXRuntime optimizations
+
+**This research project is not actively maintained by the diffusers team. For any questions or comments, please contact Prathik Rao (prathikr), Sunghoon Choi (hanbitmyths), Ashwini Khade (askhade), or Peng Wang (pengwa) on GitHub.**
+
+This project aims to provide diffusers examples with ONNX Runtime optimizations for training/fine-tuning unconditional image generation, text-to-image, and textual inversion models. Please see the individual directories for details on how to run each task with ONNX Runtime.
\ No newline at end of file
diff --git a/diffusers/examples/research_projects/onnxruntime/text_to_image/README.md b/diffusers/examples/research_projects/onnxruntime/text_to_image/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..cd9397939ac2399ac161f19623430636a4c3c9ad
--- /dev/null
+++ b/diffusers/examples/research_projects/onnxruntime/text_to_image/README.md
@@ -0,0 +1,74 @@
+# Stable Diffusion text-to-image fine-tuning
+
+The `train_text_to_image.py` script shows how to fine-tune a Stable Diffusion model on your own dataset.
+
+___Note___:
+
+___This script is experimental. The script fine-tunes the whole model, and the model often overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparameters to get the best result on your dataset.___
+
+
+## Running locally with PyTorch
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the installation up to date, as we update the example scripts frequently and they install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Then cd into the example folder and run
+```bash
+pip install -r requirements.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
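+
+If you prefer not to answer the interactive prompts, a default configuration can be written instead (assuming a reasonably recent Accelerate release):
+
+```bash
+accelerate config default
+```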
+
+### Pokemon example
+
+You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree.
+
+You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).
+
+Run the following command to authenticate your token:
+
+```bash
+huggingface-cli login
+```
+
+If you have already cloned the repo, then you won't need to go through these steps.
+
+
+
+## Use ONNXRuntime to accelerate training
+To leverage ONNX Runtime to accelerate training, use the `train_text_to_image.py` script in this directory.
+
+The command to train a `UNet2DConditionModel` (with the DDPM training objective) on the Pokemon dataset with ONNX Runtime:
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export dataset_name="lambdalabs/pokemon-blip-captions"
+accelerate launch --mixed_precision="fp16" train_text_to_image.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --dataset_name=$dataset_name \
+ --use_ema \
+ --resolution=512 --center_crop --random_flip \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --max_train_steps=15000 \
+ --learning_rate=1e-05 \
+ --max_grad_norm=1 \
+ --lr_scheduler="constant" --lr_warmup_steps=0 \
+ --output_dir="sd-pokemon-model"
+```
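+
+Once training finishes, the fine-tuned pipeline is saved to `--output_dir` and can be loaded with the regular diffusers API. Below is a minimal inference sketch, assuming the `sd-pokemon-model` output directory from the command above:
+
+```python
+import torch
+
+from diffusers import StableDiffusionPipeline
+
+# Load the fine-tuned pipeline saved by the training script.
+pipe = StableDiffusionPipeline.from_pretrained("sd-pokemon-model", torch_dtype=torch.float16)
+pipe.to("cuda")
+
+image = pipe(prompt="yoda").images[0]
+image.save("yoda-pokemon.png")
+```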
+
+Please contact Prathik Rao (prathikr), Sunghoon Choi (hanbitmyths), Ashwini Khade (askhade), or Peng Wang (pengwa) on github with any questions.
\ No newline at end of file
diff --git a/diffusers/examples/research_projects/onnxruntime/text_to_image/requirements.txt b/diffusers/examples/research_projects/onnxruntime/text_to_image/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2dbadea4474aac24b501e61a4b05f24168ac85be
--- /dev/null
+++ b/diffusers/examples/research_projects/onnxruntime/text_to_image/requirements.txt
@@ -0,0 +1,7 @@
+accelerate>=0.16.0
+torchvision
+transformers>=4.25.1
+datasets
+ftfy
+tensorboard
+modelcards
diff --git a/diffusers/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py b/diffusers/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7100788cde23496cc31a71ae6b8d2f65ea7a9b6
--- /dev/null
+++ b/diffusers/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py
@@ -0,0 +1,943 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import math
+import os
+import random
+from pathlib import Path
+
+import accelerate
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.state import AcceleratorState
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer as ORT_FP16_Optimizer
+from onnxruntime.training.ortmodule import ORTModule
+from packaging import version
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+from transformers.utils import ContextManagers
+
+import diffusers
+from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel, compute_snr
+from diffusers.utils import check_min_version, deprecate, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+if is_wandb_available():
+ import wandb
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.17.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+DATASET_NAME_MAPPING = {
+ "lambdalabs/pokemon-blip-captions": ("image", "text"),
+}
+
+
+def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch):
+ logger.info("Running validation... ")
+
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=accelerator.unwrap_model(vae),
+ text_encoder=accelerator.unwrap_model(text_encoder),
+ tokenizer=tokenizer,
+ unet=accelerator.unwrap_model(unet),
+ safety_checker=None,
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ images = []
+ for i in range(len(args.validation_prompts)):
+ with torch.autocast("cuda"):
+ image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
+
+ images.append(image)
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+ else:
+ logger.warn(f"image logging not implemented for {tracker.name}")
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--input_pertubation", type=float, default=0, help="The scale of input pretubation. Recommended 0.1."
+ )
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--validation_prompts",
+ type=str,
+ default=None,
+ nargs="+",
+ help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="sd-model-finetuned",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+ parser.add_argument(
+ "--non_ema_revision",
+ type=str,
+ default=None,
+ required=False,
+ help=(
+ "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
+ " remote repository specified with --pretrained_model_name_or_path."
+ ),
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=(
+ "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
+ " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
+ " for more docs"
+ ),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=5,
+ help="Run validation every X epochs.",
+ )
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="text2image-fine-tune",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers for"
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ # default to using the same revision for the non-ema model if not specified
+ if args.non_ema_revision is None:
+ args.non_ema_revision = args.revision
+
+ return args
+
+
+def main():
+ args = parse_args()
+
+ if args.non_ema_revision is not None:
+ deprecate(
+ "non_ema_revision!=None",
+ "0.15.0",
+ message=(
+ "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
+ " use `--variant=non_ema` instead."
+ ),
+ )
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+ accelerator_project_config = ProjectConfiguration(
+ total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
+ )
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load scheduler, tokenizer and models.
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ tokenizer = CLIPTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
+ )
+
+ def deepspeed_zero_init_disabled_context_manager():
+ """
+ returns either a context list that includes one that will disable zero.Init or an empty context list
+ """
+ deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
+ if deepspeed_plugin is None:
+ return []
+
+ return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
+
+ # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3.
+ # For this to work properly all models must be run through `accelerate.prepare`. But accelerate
+ # will try to assign the same optimizer with the same weights to all models during
+ # `deepspeed.initialize`, which of course doesn't work.
+ #
+ # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2
+ # frozen models from being partitioned during `zero.Init` which gets called during
+ # `from_pretrained` So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding
+ # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded.
+ with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
+ )
+
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
+ )
+
+ # Freeze vae and text_encoder
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+
+ # Create EMA for the unet.
+ if args.use_ema:
+ ema_unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+ )
+ ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warn(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ if args.use_ema:
+ ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
+
+ for i, model in enumerate(models):
+ model.save_pretrained(os.path.join(output_dir, "unet"))
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ if args.use_ema:
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
+ ema_unet.load_state_dict(load_model.state_dict())
+ ema_unet.to(accelerator.device)
+ del load_model
+
+ for i in range(len(models)):
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Initialize the optimizer
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
+ )
+
+ optimizer_cls = bnb.optim.AdamW8bit
+ else:
+ optimizer_cls = torch.optim.AdamW
+
+ optimizer = optimizer_cls(
+ unet.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
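+ # Wrap the base optimizer with ONNX Runtime's FP16 optimizer so parameter updates work with ORT mixed-precision training.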
+ optimizer = ORT_FP16_Optimizer(optimizer)
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
+ if args.image_column is None:
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.caption_column is None:
+ caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
+ )
+
+ # Preprocessing the datasets.
+ # We need to tokenize input captions and transform the images.
+ def tokenize_captions(examples, is_train=True):
+ captions = []
+ for caption in examples[caption_column]:
+ if isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+ else:
+ raise ValueError(
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
+ )
+ inputs = tokenizer(
+ captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ return inputs.input_ids
+
+ # Preprocessing the datasets.
+ train_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
+ transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ examples["pixel_values"] = [train_transforms(image) for image in images]
+ examples["input_ids"] = tokenize_captions(examples)
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ input_ids = torch.stack([example["input_ids"] for example in examples])
+ return {"pixel_values": pixel_values, "input_ids": input_ids}
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ )
+
+ # Prepare everything with our `accelerator`.
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ if args.use_ema:
+ ema_unet.to(accelerator.device)
+
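+ # Wrap the UNet with ORTModule so its forward/backward passes are executed through ONNX Runtime.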
+ unet = ORTModule(unet)
+
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move text_encoder and vae to gpu and cast to weight_dtype
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+ tracker_config.pop("validation_prompts")
+ accelerator.init_trackers(args.tracker_project_name, tracker_config)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ resume_global_step = global_step * args.gradient_accumulation_steps
+ first_epoch = global_step // num_update_steps_per_epoch
+ resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
+ progress_bar.set_description("Steps")
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ unet.train()
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ # Skip steps until we reach the resumed step
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
+ if step % args.gradient_accumulation_steps == 0:
+ progress_bar.update(1)
+ continue
+
+ with accelerator.accumulate(unet):
+ # Convert images to latent space
+ latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
+ latents = latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ if args.noise_offset:
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
+ noise += args.noise_offset * torch.randn(
+ (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
+ )
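+ # Input perturbation: add an extra perturbation to the noise used to construct the noisy latents
+ # (the prediction target remains the unperturbed noise).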
+ if args.input_pertubation:
+ new_noise = noise + args.input_pertubation * torch.randn_like(noise)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ if args.input_pertubation:
+ noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps)
+ else:
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ # Predict the noise residual and compute loss
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+ if args.snr_gamma is None:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+ # This is discussed in Section 4.2 of the same paper.
+ snr = compute_snr(noise_scheduler, timesteps)
+ if noise_scheduler.config.prediction_type == "v_prediction":
+ # Velocity objective requires that we add one to SNR values before we divide by them.
+ snr = snr + 1
+ mse_loss_weights = (
+ torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
+ )
+
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ if args.use_ema:
+ ema_unet.step(unet.parameters())
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
+ if args.use_ema:
+ # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
+ ema_unet.store(unet.parameters())
+ ema_unet.copy_to(unet.parameters())
+ log_validation(
+ vae,
+ text_encoder,
+ tokenizer,
+ unet,
+ args,
+ accelerator,
+ weight_dtype,
+ global_step,
+ )
+ if args.use_ema:
+ # Switch back to the original UNet parameters.
+ ema_unet.restore(unet.parameters())
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = accelerator.unwrap_model(unet)
+ if args.use_ema:
+ ema_unet.copy_to(unet.parameters())
+
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ text_encoder=text_encoder,
+ vae=vae,
+ unet=unet,
+ revision=args.revision,
+ )
+ pipeline.save_pretrained(args.output_dir)
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/research_projects/onnxruntime/textual_inversion/README.md b/diffusers/examples/research_projects/onnxruntime/textual_inversion/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..9f08983eaaadef4ca750e9791373898f33ee5f0b
--- /dev/null
+++ b/diffusers/examples/research_projects/onnxruntime/textual_inversion/README.md
@@ -0,0 +1,94 @@
+## Textual Inversion fine-tuning example
+
+[Textual inversion](https://arxiv.org/abs/2208.01618) is a method to personalize text2image models like stable diffusion on your own images using just 3-5 examples.
+The `textual_inversion.py` script shows how to implement the training procedure and adapt it for stable diffusion.
+
+## Running on Colab
+
+Colab for training
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb)
+
+Colab for inference
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_conceptualizer_inference.ipynb)
+
+## Running locally with PyTorch
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the installation up to date, as we update the example scripts frequently and they install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Then cd into the example folder and run
+```bash
+pip install -r requirements.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+
+### Cat toy example
+
+You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-5`, so you'll need to visit [its card](https://huggingface.co/runwayml/stable-diffusion-v1-5), read the license and tick the checkbox if you agree.
+
+You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).
+
+Run the following command to authenticate your token:
+
+```bash
+huggingface-cli login
+```
+
+If you have already cloned the repo, then you won't need to go through these steps.
+
+
+
+Now let's get our dataset. For this example we will use some cat images: https://huggingface.co/datasets/diffusers/cat_toy_example .
+
+Let's first download it locally:
+
+```py
+from huggingface_hub import snapshot_download
+
+local_dir = "./cat"
+snapshot_download("diffusers/cat_toy_example", local_dir=local_dir, repo_type="dataset", ignore_patterns=".gitattributes")
+```
+
+This will be our training data.
+Now we can launch training with ONNX Runtime, as shown in the next section.
+
+## Use ONNXRuntime to accelerate training
+To leverage ONNX Runtime to accelerate training, use the `textual_inversion.py` script in this directory.
+
+The command to train on custom data with ONNX Runtime:
+
+```bash
+export MODEL_NAME="runwayml/stable-diffusion-v1-5"
+export DATA_DIR="path-to-dir-containing-images"
+
+accelerate launch textual_inversion.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_data_dir=$DATA_DIR \
+ --learnable_property="object" \
+ --placeholder_token="" --initializer_token="toy" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --max_train_steps=3000 \
+ --learning_rate=5.0e-04 --scale_lr \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --output_dir="textual_inversion_cat"
+```
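+
+Once training finishes, the learned embedding is saved under `--output_dir` as `learned_embeds.bin` (see `--save_steps`). Below is a minimal inference sketch, assuming the `textual_inversion_cat` output directory from the command above and the `<cat-toy>` placeholder token:
+
+```python
+import torch
+
+from diffusers import StableDiffusionPipeline
+
+pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+# Load the learned embedding saved by the training script into the pipeline.
+pipe.load_textual_inversion("textual_inversion_cat", weight_name="learned_embeds.bin")
+pipe.to("cuda")
+
+image = pipe("A <cat-toy> backpack", num_inference_steps=50).images[0]
+image.save("cat-backpack.png")
+```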
+
+Please contact Prathik Rao (prathikr), Sunghoon Choi (hanbitmyths), Ashwini Khade (askhade), or Peng Wang (pengwa) on github with any questions.
\ No newline at end of file
diff --git a/diffusers/examples/research_projects/onnxruntime/textual_inversion/requirements.txt b/diffusers/examples/research_projects/onnxruntime/textual_inversion/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c1a94eac83e6eb9a3a2dd11672b5d73f794ca3d1
--- /dev/null
+++ b/diffusers/examples/research_projects/onnxruntime/textual_inversion/requirements.txt
@@ -0,0 +1,6 @@
+accelerate>=0.16.0
+torchvision
+transformers>=4.25.1
+ftfy
+tensorboard
+modelcards
diff --git a/diffusers/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py b/diffusers/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py
new file mode 100644
index 0000000000000000000000000000000000000000..59b5089d07b4c3041e6103f844c730e8f91caa4c
--- /dev/null
+++ b/diffusers/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py
@@ -0,0 +1,946 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import math
+import os
+import random
+import warnings
+from pathlib import Path
+
+import numpy as np
+import PIL
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer as ORT_FP16_Optimizer
+from onnxruntime.training.ortmodule import ORTModule
+
+# TODO: remove and import from diffusers.utils when the new version of diffusers is released
+from packaging import version
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ DiffusionPipeline,
+ DPMSolverMultistepScheduler,
+ StableDiffusionPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+if is_wandb_available():
+ import wandb
+
+if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.Resampling.BILINEAR,
+ "bilinear": PIL.Image.Resampling.BILINEAR,
+ "bicubic": PIL.Image.Resampling.BICUBIC,
+ "lanczos": PIL.Image.Resampling.LANCZOS,
+ "nearest": PIL.Image.Resampling.NEAREST,
+ }
+else:
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.LINEAR,
+ "bilinear": PIL.Image.BILINEAR,
+ "bicubic": PIL.Image.BICUBIC,
+ "lanczos": PIL.Image.LANCZOS,
+ "nearest": PIL.Image.NEAREST,
+ }
+# ------------------------------------------------------------------------------
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.17.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def save_model_card(repo_id: str, images=None, base_model: str = "", repo_folder=None):
+ img_str = ""
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ img_str += f"![img_{i}](./image_{i}.png)\n"
+
+ yaml = f"""
+---
+license: creativeml-openrail-m
+base_model: {base_model}
+tags:
+- stable-diffusion
+- stable-diffusion-diffusers
+- text-to-image
+- diffusers
+- textual_inversion
+inference: true
+---
+ """
+ model_card = f"""
+# Textual inversion text2image fine-tuning - {repo_id}
+These are textual inversion adaptation weights for {base_model}. You can find some example images below. \n
+{img_str}
+"""
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
+ f.write(yaml + model_card)
+
+
+def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ # create pipeline (note: unet and vae are loaded again in float32)
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ text_encoder=accelerator.unwrap_model(text_encoder),
+ tokenizer=tokenizer,
+ unet=unet,
+ vae=vae,
+ safety_checker=None,
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
+ images = []
+ for _ in range(args.num_validation_images):
+ with torch.autocast("cuda"):
+ image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+ images.append(image)
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ torch.cuda.empty_cache()
+ return images
+
+
+def save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path):
+ logger.info("Saving embeddings")
+ learned_embeds = (
+ accelerator.unwrap_model(text_encoder)
+ .get_input_embeddings()
+ .weight[min(placeholder_token_ids) : max(placeholder_token_ids) + 1]
+ )
+ learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
+ torch.save(learned_embeds_dict, save_path)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--save_steps",
+ type=int,
+ default=500,
+ help="Save learned_embeds.bin every X updates steps.",
+ )
+ parser.add_argument(
+ "--save_as_full_pipeline",
+ action="store_true",
+ help="Save the complete stable diffusion pipeline.",
+ )
+ parser.add_argument(
+ "--num_vectors",
+ type=int,
+ default=1,
+ help="How many textual inversion vectors shall be used to learn the concept.",
+ )
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
+ )
+ parser.add_argument(
+ "--placeholder_token",
+ type=str,
+ default=None,
+ required=True,
+ help="A token to use as a placeholder for the concept.",
+ )
+ parser.add_argument(
+ "--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word."
+ )
+ parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
+ parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="text-inversion-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution."
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=5000,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default="no",
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose"
+ "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
+ "and an Nvidia Ampere GPU."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=100,
+ help=(
+ "Run validation every X steps. Validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
+ " and logging the images."
+ ),
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=None,
+ help=(
+ "Deprecated in favor of validation_steps. Run validation every X epochs. Validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
+ " and logging the images."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=(
+ "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
+ " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
+ " for more docs"
+ ),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.train_data_dir is None:
+ raise ValueError("You must specify a train data directory.")
+
+ return args
+
+
+imagenet_templates_small = [
+ "a photo of a {}",
+ "a rendering of a {}",
+ "a cropped photo of the {}",
+ "the photo of a {}",
+ "a photo of a clean {}",
+ "a photo of a dirty {}",
+ "a dark photo of the {}",
+ "a photo of my {}",
+ "a photo of the cool {}",
+ "a close-up photo of a {}",
+ "a bright photo of the {}",
+ "a cropped photo of a {}",
+ "a photo of the {}",
+ "a good photo of the {}",
+ "a photo of one {}",
+ "a close-up photo of the {}",
+ "a rendition of the {}",
+ "a photo of the clean {}",
+ "a rendition of a {}",
+ "a photo of a nice {}",
+ "a good photo of a {}",
+ "a photo of the nice {}",
+ "a photo of the small {}",
+ "a photo of the weird {}",
+ "a photo of the large {}",
+ "a photo of a cool {}",
+ "a photo of a small {}",
+]
+
+imagenet_style_templates_small = [
+ "a painting in the style of {}",
+ "a rendering in the style of {}",
+ "a cropped painting in the style of {}",
+ "the painting in the style of {}",
+ "a clean painting in the style of {}",
+ "a dirty painting in the style of {}",
+ "a dark painting in the style of {}",
+ "a picture in the style of {}",
+ "a cool painting in the style of {}",
+ "a close-up painting in the style of {}",
+ "a bright painting in the style of {}",
+ "a cropped painting in the style of {}",
+ "a good painting in the style of {}",
+ "a close-up painting in the style of {}",
+ "a rendition in the style of {}",
+ "a nice painting in the style of {}",
+ "a small painting in the style of {}",
+ "a weird painting in the style of {}",
+ "a large painting in the style of {}",
+]
+
+
+class TextualInversionDataset(Dataset):
+ def __init__(
+ self,
+ data_root,
+ tokenizer,
+ learnable_property="object", # [object, style]
+ size=512,
+ repeats=100,
+ interpolation="bicubic",
+ flip_p=0.5,
+ set="train",
+ placeholder_token="*",
+ center_crop=False,
+ ):
+ self.data_root = data_root
+ self.tokenizer = tokenizer
+ self.learnable_property = learnable_property
+ self.size = size
+ self.placeholder_token = placeholder_token
+ self.center_crop = center_crop
+ self.flip_p = flip_p
+
+ self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
+
+ self.num_images = len(self.image_paths)
+ self._length = self.num_images
+
+ if set == "train":
+ self._length = self.num_images * repeats
+
+ self.interpolation = {
+ "linear": PIL_INTERPOLATION["linear"],
+ "bilinear": PIL_INTERPOLATION["bilinear"],
+ "bicubic": PIL_INTERPOLATION["bicubic"],
+ "lanczos": PIL_INTERPOLATION["lanczos"],
+ }[interpolation]
+
+ self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
+ self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, i):
+ example = {}
+ image = Image.open(self.image_paths[i % self.num_images])
+
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+
+ placeholder_string = self.placeholder_token
+ text = random.choice(self.templates).format(placeholder_string)
+
+ example["input_ids"] = self.tokenizer(
+ text,
+ padding="max_length",
+ truncation=True,
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ ).input_ids[0]
+
+ # default to score-sde preprocessing
+ img = np.array(image).astype(np.uint8)
+
+ if self.center_crop:
+ crop = min(img.shape[0], img.shape[1])
+ (
+ h,
+ w,
+ ) = (
+ img.shape[0],
+ img.shape[1],
+ )
+ img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
+
+ image = Image.fromarray(img)
+ image = image.resize((self.size, self.size), resample=self.interpolation)
+
+ image = self.flip_transform(image)
+ image = np.array(image).astype(np.uint8)
+ image = (image / 127.5 - 1.0).astype(np.float32)
+
+ example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
+ return example
+
+
+def main():
+ args = parse_args()
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+ accelerator_project_config = ProjectConfiguration(
+ total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
+ )
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load tokenizer
+ if args.tokenizer_name:
+ tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+ elif args.pretrained_model_name_or_path:
+ tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+ )
+
+ # Add the placeholder token in tokenizer
+ placeholder_tokens = [args.placeholder_token]
+
+ if args.num_vectors < 1:
+ raise ValueError(f"--num_vectors has to be larger or equal to 1, but is {args.num_vectors}")
+
+ # add dummy tokens for multi-vector
+ additional_tokens = []
+ for i in range(1, args.num_vectors):
+ additional_tokens.append(f"{args.placeholder_token}_{i}")
+ placeholder_tokens += additional_tokens
+
+ num_added_tokens = tokenizer.add_tokens(placeholder_tokens)
+ if num_added_tokens != args.num_vectors:
+ raise ValueError(
+ f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
+ " `placeholder_token` that is not already in the tokenizer."
+ )
+
+ # Convert the initializer_token, placeholder_token to ids
+ token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
+ # Check if initializer_token is a single token or a sequence of tokens
+ if len(token_ids) > 1:
+ raise ValueError("The initializer token must be a single token.")
+
+ initializer_token_id = token_ids[0]
+ placeholder_token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens)
+
+ # Resize the token embeddings as we are adding new special tokens to the tokenizer
+ text_encoder.resize_token_embeddings(len(tokenizer))
+
+ # Initialise the newly added placeholder token with the embeddings of the initializer token
+ token_embeds = text_encoder.get_input_embeddings().weight.data
+ with torch.no_grad():
+ for token_id in placeholder_token_ids:
+ token_embeds[token_id] = token_embeds[initializer_token_id].clone()
+
+ # Freeze vae and unet
+ vae.requires_grad_(False)
+ unet.requires_grad_(False)
+ # Freeze all parameters except for the token embeddings in text encoder
+ text_encoder.text_model.encoder.requires_grad_(False)
+ text_encoder.text_model.final_layer_norm.requires_grad_(False)
+ text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
+
+ if args.gradient_checkpointing:
+ # Keep unet in train mode if we are using gradient checkpointing to save memory.
+ # The dropout cannot be != 0 so it doesn't matter if we are in eval or train mode.
+ unet.train()
+ text_encoder.gradient_checkpointing_enable()
+ unet.enable_gradient_checkpointing()
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warn(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Initialize the optimizer
+ optimizer = torch.optim.AdamW(
+ text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ optimizer = ORT_FP16_Optimizer(optimizer)
+
+ # Dataset and DataLoaders creation:
+ train_dataset = TextualInversionDataset(
+ data_root=args.train_data_dir,
+ tokenizer=tokenizer,
+ size=args.resolution,
+ placeholder_token=args.placeholder_token,
+ repeats=args.repeats,
+ learnable_property=args.learnable_property,
+ center_crop=args.center_crop,
+ set="train",
+ )
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
+ )
+ if args.validation_epochs is not None:
+ warnings.warn(
+ f"FutureWarning: You are doing logging with validation_epochs={args.validation_epochs}."
+ " Deprecated validation_epochs in favor of `validation_steps`"
+ f"Setting `args.validation_steps` to {args.validation_epochs * len(train_dataset)}",
+ FutureWarning,
+ stacklevel=2,
+ )
+ args.validation_steps = args.validation_epochs * len(train_dataset)
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ )
+
+ # Prepare everything with our `accelerator`.
+ text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ text_encoder, optimizer, train_dataloader, lr_scheduler
+ )
+
+ text_encoder = ORTModule(text_encoder)
+ unet = ORTModule(unet)
+ vae = ORTModule(vae)
+
+ # For mixed precision training we cast the unet and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move vae and unet to device and cast to weight_dtype
+ unet.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("textual_inversion", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ resume_global_step = global_step * args.gradient_accumulation_steps
+ first_epoch = global_step // num_update_steps_per_epoch
+ resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
+ progress_bar.set_description("Steps")
+
+ # keep original embeddings as reference
+ orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ text_encoder.train()
+ for step, batch in enumerate(train_dataloader):
+ # Skip steps until we reach the resumed step
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
+ if step % args.gradient_accumulation_steps == 0:
+ progress_bar.update(1)
+ continue
+
+ with accelerator.accumulate(text_encoder):
+ # Convert images to latent space
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
+ latents = latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype)
+
+ # Predict the noise residual
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+ accelerator.backward(loss)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Let's make sure we don't update any embedding weights besides the newly added token
+ index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
+ index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
+
+ with torch.no_grad():
+ accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
+ index_no_updates
+ ] = orig_embeds_params[index_no_updates]
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ images = []
+ progress_bar.update(1)
+ global_step += 1
+ if global_step % args.save_steps == 0:
+ save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
+ save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path)
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if args.validation_prompt is not None and global_step % args.validation_steps == 0:
+ images = log_validation(
+ text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch
+ )
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ if args.push_to_hub and not args.save_as_full_pipeline:
+ logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
+ save_full_model = True
+ else:
+ save_full_model = args.save_as_full_pipeline
+ if save_full_model:
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ text_encoder=accelerator.unwrap_model(text_encoder),
+ vae=vae,
+ unet=unet,
+ tokenizer=tokenizer,
+ )
+ pipeline.save_pretrained(args.output_dir)
+ # Save the newly trained embeddings
+ save_path = os.path.join(args.output_dir, "learned_embeds.bin")
+ save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path)
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/research_projects/onnxruntime/unconditional_image_generation/README.md b/diffusers/examples/research_projects/onnxruntime/unconditional_image_generation/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..c28ecefc9a3002b2f6c6d3d97e53047e82ab2733
--- /dev/null
+++ b/diffusers/examples/research_projects/onnxruntime/unconditional_image_generation/README.md
@@ -0,0 +1,50 @@
+## Training examples
+
+Creating a training image set is [described in a different document](https://huggingface.co/docs/datasets/image_process#image-datasets).
+
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the installation up to date, as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Then cd into the example folder and run:
+```bash
+pip install -r requirements.txt
+```
+
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+#### Use ONNXRuntime to accelerate training
+
+In order to leverage ONNX Runtime to accelerate training, please use `train_unconditional.py`.
+
+The command to train a DDPM UNet model on the Oxford Flowers dataset with ONNX Runtime:
+
+```bash
+accelerate launch train_unconditional.py \
+ --dataset_name="huggan/flowers-102-categories" \
+ --resolution=64 --center_crop --random_flip \
+ --output_dir="ddpm-ema-flowers-64" \
+ --use_ema \
+ --train_batch_size=16 \
+ --num_epochs=1 \
+ --gradient_accumulation_steps=1 \
+ --learning_rate=1e-4 \
+ --lr_warmup_steps=500 \
+ --mixed_precision=fp16
+```
+
+Please contact Prathik Rao (prathikr), Sunghoon Choi (hanbitmyths), Ashwini Khade (askhade), or Peng Wang (pengwa) on github with any questions.
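
The ONNX Runtime acceleration described above comes down to two wrappers from `onnxruntime-training`, both of which the script in this diff imports: `ORTModule` around the PyTorch module and `FP16_Optimizer` around the optimizer. The following minimal sketch shows that pattern in isolation; the toy model, data, and device selection are placeholders for illustration only, and it assumes an `onnxruntime-training` build is installed.

```python
# Minimal sketch of the ORT training integration used by the example scripts:
# wrap the module with ORTModule and the optimizer with FP16_Optimizer.
# The toy model and data are placeholders; only the two wrappers mirror the script.
import torch
from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer as ORT_FP16_Optimizer
from onnxruntime.training.ortmodule import ORTModule

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(64, 64).to(device)      # stand-in for the diffusion UNet
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

model = ORTModule(model)                        # forward/backward now run through ONNX Runtime
optimizer = ORT_FP16_Optimizer(optimizer)       # ORT-aware fp16 gradient handling, as in the script

batch = torch.randn(16, 64, device=device)
loss = model(batch).pow(2).mean()               # placeholder loss
loss.backward()
optimizer.step()
optimizer.zero_grad()
```
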
diff --git a/diffusers/examples/research_projects/onnxruntime/unconditional_image_generation/requirements.txt b/diffusers/examples/research_projects/onnxruntime/unconditional_image_generation/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ca21143c42d9c08bf693fcc8cd11fed53acb895f
--- /dev/null
+++ b/diffusers/examples/research_projects/onnxruntime/unconditional_image_generation/requirements.txt
@@ -0,0 +1,4 @@
+accelerate>=0.16.0
+torchvision
+datasets
+tensorboard
\ No newline at end of file
diff --git a/diffusers/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py b/diffusers/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cad9f2fbed9f5856d96156f5216e850943413a7
--- /dev/null
+++ b/diffusers/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py
@@ -0,0 +1,687 @@
+import argparse
+import inspect
+import logging
+import math
+import os
+from pathlib import Path
+
+import accelerate
+import datasets
+import torch
+import torch.nn.functional as F
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer as ORT_FP16_Optimizer
+from onnxruntime.training.ortmodule import ORTModule
+from packaging import version
+from torchvision import transforms
+from tqdm.auto import tqdm
+
+import diffusers
+from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel
+from diffusers.utils import check_min_version, is_accelerate_version, is_tensorboard_available, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.17.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+
+def _extract_into_tensor(arr, timesteps, broadcast_shape):
+ """
+ Extract values from a 1-D numpy array for a batch of indices.
+
+ :param arr: the 1-D numpy array.
+ :param timesteps: a tensor of indices into the array to extract.
+ :param broadcast_shape: a larger shape of K dimensions with the batch
+ dimension equal to the length of timesteps.
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
+ """
+ if not isinstance(arr, torch.Tensor):
+ arr = torch.from_numpy(arr)
+ res = arr[timesteps].float().to(timesteps.device)
+ while len(res.shape) < len(broadcast_shape):
+ res = res[..., None]
+ return res.expand(broadcast_shape)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that HF Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--model_config_name_or_path",
+ type=str,
+ default=None,
+ help="The config of the UNet model to train, leave as None to use standard DDPM configuration.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="ddpm-model-64",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--overwrite_output_dir", action="store_true")
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=64,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ default=False,
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--eval_batch_size", type=int, default=16, help="The number of images to generate for evaluation."
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
+ " process."
+ ),
+ )
+ parser.add_argument("--num_epochs", type=int, default=100)
+ parser.add_argument("--save_images_epochs", type=int, default=10, help="How often to save images during training.")
+ parser.add_argument(
+ "--save_model_epochs", type=int, default=10, help="How often to save the model during training."
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="cosine",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.95, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument(
+ "--adam_weight_decay", type=float, default=1e-6, help="Weight decay magnitude for the Adam optimizer."
+ )
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer.")
+ parser.add_argument(
+ "--use_ema",
+ action="store_true",
+ help="Whether to use Exponential Moving Average for the final model weights.",
+ )
+ parser.add_argument("--ema_inv_gamma", type=float, default=1.0, help="The inverse gamma value for the EMA decay.")
+ parser.add_argument("--ema_power", type=float, default=3 / 4, help="The power value for the EMA decay.")
+ parser.add_argument("--ema_max_decay", type=float, default=0.9999, help="The maximum decay magnitude for EMA.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--hub_private_repo", action="store_true", help="Whether or not to create a private repository."
+ )
+ parser.add_argument(
+ "--logger",
+ type=str,
+ default="tensorboard",
+ choices=["tensorboard", "wandb"],
+ help=(
+ "Whether to use [tensorboard](https://www.tensorflow.org/tensorboard) or [wandb](https://www.wandb.ai)"
+ " for experiment tracking and logging of model metrics and model checkpoints"
+ ),
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default="no",
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose"
+ "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
+ "and an Nvidia Ampere GPU."
+ ),
+ )
+ parser.add_argument(
+ "--prediction_type",
+ type=str,
+ default="epsilon",
+ choices=["epsilon", "sample"],
+ help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
+ )
+ parser.add_argument("--ddpm_num_steps", type=int, default=1000)
+ parser.add_argument("--ddpm_num_inference_steps", type=int, default=1000)
+ parser.add_argument("--ddpm_beta_schedule", type=str, default="linear")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=(
+ "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
+ " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
+ " for more docs"
+ ),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("You must specify either a dataset name from the hub or a train data directory.")
+
+ return args
+
+
+def main(args):
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+ accelerator_project_config = ProjectConfiguration(
+ total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
+ )
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.logger,
+ project_config=accelerator_project_config,
+ )
+
+ if args.logger == "tensorboard":
+ if not is_tensorboard_available():
+ raise ImportError("Make sure to install tensorboard if you want to use it for logging during training.")
+
+ elif args.logger == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+ import wandb
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ if args.use_ema:
+ ema_model.save_pretrained(os.path.join(output_dir, "unet_ema"))
+
+ for i, model in enumerate(models):
+ model.save_pretrained(os.path.join(output_dir, "unet"))
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ if args.use_ema:
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DModel)
+ ema_model.load_state_dict(load_model.state_dict())
+ ema_model.to(accelerator.device)
+ del load_model
+
+ for i in range(len(models)):
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = UNet2DModel.from_pretrained(input_dir, subfolder="unet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Initialize the model
+ if args.model_config_name_or_path is None:
+ model = UNet2DModel(
+ sample_size=args.resolution,
+ in_channels=3,
+ out_channels=3,
+ layers_per_block=2,
+ block_out_channels=(128, 128, 256, 256, 512, 512),
+ down_block_types=(
+ "DownBlock2D",
+ "DownBlock2D",
+ "DownBlock2D",
+ "DownBlock2D",
+ "AttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ up_block_types=(
+ "UpBlock2D",
+ "AttnUpBlock2D",
+ "UpBlock2D",
+ "UpBlock2D",
+ "UpBlock2D",
+ "UpBlock2D",
+ ),
+ )
+ else:
+ config = UNet2DModel.load_config(args.model_config_name_or_path)
+ model = UNet2DModel.from_config(config)
+
+ # Create EMA for the model.
+ if args.use_ema:
+ ema_model = EMAModel(
+ model.parameters(),
+ decay=args.ema_max_decay,
+ use_ema_warmup=True,
+ inv_gamma=args.ema_inv_gamma,
+ power=args.ema_power,
+ model_cls=UNet2DModel,
+ model_config=model.config,
+ )
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warn(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ model.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # Initialize the scheduler
+ accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
+ if accepts_prediction_type:
+ noise_scheduler = DDPMScheduler(
+ num_train_timesteps=args.ddpm_num_steps,
+ beta_schedule=args.ddpm_beta_schedule,
+ prediction_type=args.prediction_type,
+ )
+ else:
+ noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)
+
+ # Initialize the optimizer
+ optimizer = torch.optim.AdamW(
+ model.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ optimizer = ORT_FP16_Optimizer(optimizer)
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ split="train",
+ )
+ else:
+ dataset = load_dataset("imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train")
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets and DataLoaders creation.
+ augmentations = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
+ transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def transform_images(examples):
+ images = [augmentations(image.convert("RGB")) for image in examples["image"]]
+ return {"input": images}
+
+ logger.info(f"Dataset size: {len(dataset)}")
+
+ dataset.set_transform(transform_images)
+ train_dataloader = torch.utils.data.DataLoader(
+ dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
+ )
+
+ # Initialize the learning rate scheduler
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=(len(train_dataloader) * args.num_epochs),
+ )
+
+ # Prepare everything with our `accelerator`.
+ model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ model, optimizer, train_dataloader, lr_scheduler
+ )
+
+ if args.use_ema:
+ ema_model.to(accelerator.device)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ run = os.path.split(__file__)[-1].split(".")[0]
+ accelerator.init_trackers(run)
+
+ model = ORTModule(model)
+
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ max_train_steps = args.num_epochs * num_update_steps_per_epoch
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(dataset)}")
+ logger.info(f" Num Epochs = {args.num_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {max_train_steps}")
+
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ resume_global_step = global_step * args.gradient_accumulation_steps
+ first_epoch = global_step // num_update_steps_per_epoch
+ resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+
+ # Train!
+ for epoch in range(first_epoch, args.num_epochs):
+ model.train()
+ progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process)
+ progress_bar.set_description(f"Epoch {epoch}")
+ for step, batch in enumerate(train_dataloader):
+ # Skip steps until we reach the resumed step
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
+ if step % args.gradient_accumulation_steps == 0:
+ progress_bar.update(1)
+ continue
+
+ clean_images = batch["input"]
+ # Sample noise that we'll add to the images
+ noise = torch.randn(
+ clean_images.shape, dtype=(torch.float32 if args.mixed_precision == "no" else torch.float16)
+ ).to(clean_images.device)
+ bsz = clean_images.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=clean_images.device
+ ).long()
+
+ # Add noise to the clean images according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
+
+ with accelerator.accumulate(model):
+ # Predict the noise residual
+ model_output = model(noisy_images, timesteps, return_dict=False)[0]
+
+ if args.prediction_type == "epsilon":
+ loss = F.mse_loss(model_output, noise) # this could have different weights!
+ elif args.prediction_type == "sample":
+ alpha_t = _extract_into_tensor(
+ noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1)
+ )
+ snr_weights = alpha_t / (1 - alpha_t)
+ loss = snr_weights * F.mse_loss(
+ model_output, clean_images, reduction="none"
+ ) # use SNR weighting from distillation paper
+ loss = loss.mean()
+ else:
+ raise ValueError(f"Unsupported prediction type: {args.prediction_type}")
+
+ accelerator.backward(loss)
+
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(model.parameters(), 1.0)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ if args.use_ema:
+ ema_model.step(model.parameters())
+ progress_bar.update(1)
+ global_step += 1
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
+ if args.use_ema:
+ logs["ema_decay"] = ema_model.cur_decay_value
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+ progress_bar.close()
+
+ accelerator.wait_for_everyone()
+
+ # Generate sample images for visual inspection
+ if accelerator.is_main_process:
+ if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1:
+ unet = accelerator.unwrap_model(model)
+
+ if args.use_ema:
+ ema_model.store(unet.parameters())
+ ema_model.copy_to(unet.parameters())
+
+ pipeline = DDPMPipeline(
+ unet=unet,
+ scheduler=noise_scheduler,
+ )
+
+ generator = torch.Generator(device=pipeline.device).manual_seed(0)
+ # run pipeline in inference (sample random noise and denoise)
+ images = pipeline(
+ generator=generator,
+ batch_size=args.eval_batch_size,
+ num_inference_steps=args.ddpm_num_inference_steps,
+ output_type="numpy",
+ ).images
+
+ if args.use_ema:
+ ema_model.restore(unet.parameters())
+
+ # denormalize the images and save to tensorboard
+ images_processed = (images * 255).round().astype("uint8")
+
+ if args.logger == "tensorboard":
+ if is_accelerate_version(">=", "0.17.0.dev0"):
+ tracker = accelerator.get_tracker("tensorboard", unwrap=True)
+ else:
+ tracker = accelerator.get_tracker("tensorboard")
+ tracker.add_images("test_samples", images_processed.transpose(0, 3, 1, 2), epoch)
+ elif args.logger == "wandb":
+ # Upcoming `log_images` helper coming in https://github.com/huggingface/accelerate/pull/962/files
+ accelerator.get_tracker("wandb").log(
+ {"test_samples": [wandb.Image(img) for img in images_processed], "epoch": epoch},
+ step=global_step,
+ )
+
+ if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
+ # save the model
+ unet = accelerator.unwrap_model(model)
+
+ if args.use_ema:
+ ema_model.store(unet.parameters())
+ ema_model.copy_to(unet.parameters())
+
+ pipeline = DDPMPipeline(
+ unet=unet,
+ scheduler=noise_scheduler,
+ )
+
+ pipeline.save_pretrained(args.output_dir)
+
+ if args.use_ema:
+ ema_model.restore(unet.parameters())
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message=f"Epoch {epoch}",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/research_projects/rdm/README.md b/diffusers/examples/research_projects/rdm/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..cfd755e9c9c3d030953fdc0c40834c0d5f51c04d
--- /dev/null
+++ b/diffusers/examples/research_projects/rdm/README.md
@@ -0,0 +1,5 @@
+## Retrieval Augmented Diffusion Models
+
+**This research project is not actively maintained by the diffusers team. For any questions or comments, please contact Isamu Isozaki (isamu-isozaki) on GitHub.**
+
+The aim of this project is to provide retrieval augmented diffusion models to diffusers!
\ No newline at end of file
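
To make the intent of `pipeline_rdm.py` (added below) concrete, here is a hedged usage sketch of `RDMPipeline`. Only the constructor arguments are taken from the class definition itself; the checkpoint path, the CLIP checkpoint, and the prompt-first `__call__` convention are illustrative assumptions, not taken from the diff.

```python
# Hedged usage sketch for the RDMPipeline defined in pipeline_rdm.py below.
# The checkpoint path is a placeholder and the prompt-first __call__ convention
# is assumed; only the constructor signature comes from the class itself.
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTokenizer

from diffusers import AutoencoderKL, DDIMScheduler, UNet2DConditionModel
from pipeline_rdm import RDMPipeline  # local module added by this diff

clip_id = "openai/clip-vit-large-patch14"  # CLIP variant named in the class docstring
rdm_path = "path/to/rdm-checkpoint"        # placeholder

pipe = RDMPipeline(
    vae=AutoencoderKL.from_pretrained(rdm_path, subfolder="vae"),
    clip=CLIPModel.from_pretrained(clip_id),
    tokenizer=CLIPTokenizer.from_pretrained(clip_id),
    unet=UNet2DConditionModel.from_pretrained(rdm_path, subfolder="unet"),
    scheduler=DDIMScheduler.from_pretrained(rdm_path, subfolder="scheduler"),
    feature_extractor=CLIPFeatureExtractor.from_pretrained(clip_id),
    retriever=None,  # optionally pass a Retriever to condition on retrieved neighbour images
).to("cuda")

# Assumes the standard pipeline output object with a list of images.
image = pipe("a photograph of an astronaut riding a horse").images[0]
```
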
diff --git a/diffusers/examples/research_projects/rdm/pipeline_rdm.py b/diffusers/examples/research_projects/rdm/pipeline_rdm.py
new file mode 100644
index 0000000000000000000000000000000000000000..28b4cacb831917562b23fba9a68abf998dafde0f
--- /dev/null
+++ b/diffusers/examples/research_projects/rdm/pipeline_rdm.py
@@ -0,0 +1,453 @@
+import inspect
+from typing import Callable, List, Optional, Union
+
+import torch
+from PIL import Image
+from retriever import Retriever, normalize_images, preprocess_images
+from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTokenizer
+
+from diffusers import (
+ AutoencoderKL,
+ DDIMScheduler,
+ DiffusionPipeline,
+ DPMSolverMultistepScheduler,
+ EulerAncestralDiscreteScheduler,
+ EulerDiscreteScheduler,
+ ImagePipelineOutput,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+ UNet2DConditionModel,
+ logging,
+)
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.utils import is_accelerate_available, randn_tensor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class RDMPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using Retrieval Augmented Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ clip ([`CLIPModel`]):
+ Frozen CLIP model. Retrieval Augmented Diffusion uses the CLIP model, specifically the
+ [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that preprocesses the retrieved images so they can be embedded with the CLIP image encoder.
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ clip: CLIPModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[
+ DDIMScheduler,
+ PNDMScheduler,
+ LMSDiscreteScheduler,
+ EulerDiscreteScheduler,
+ EulerAncestralDiscreteScheduler,
+ DPMSolverMultistepScheduler,
+ ],
+ feature_extractor: CLIPFeatureExtractor,
+ retriever: Optional[Retriever] = None,
+ ):
+ super().__init__()
+ self.register_modules(
+ vae=vae,
+ clip=clip,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ feature_extractor=feature_extractor,
+ )
+ # The attribute setup here and the helper methods below are copied from the Stable Diffusion pipeline.
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.retriever = retriever
+
+ def enable_xformers_memory_efficient_attention(self):
+ r"""
+ Enable memory efficient attention as implemented in xformers.
+
+ When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
+ time. Speed up at training time is not guaranteed.
+
+ Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
+ is used.
+ """
+ self.unet.set_use_memory_efficient_attention_xformers(True)
+
+ def disable_xformers_memory_efficient_attention(self):
+ r"""
+ Disable memory efficient attention as implemented in xformers.
+ """
+ self.unet.set_use_memory_efficient_attention_xformers(False)
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+ steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+ `attention_head_dim` must be a multiple of `slice_size`.
+ """
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ if isinstance(self.unet.config.attention_head_dim, int):
+ slice_size = self.unet.config.attention_head_dim // 2
+ else:
+ slice_size = self.unet.config.attention_head_dim[0] // 2
+ self.unet.set_attention_slice(slice_size)
+
+ def disable_attention_slicing(self):
+ r"""
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+ back to computing attention in one step.
+ """
+ # set slice_size = `None` to disable `attention slicing`
+ self.enable_attention_slicing(None)
+
+ def enable_sequential_cpu_offload(self):
+ r"""
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the unet,
+ clip and vae have their state dicts saved to CPU and then are moved to a
+ `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
+ """
+ if is_accelerate_available():
+ from accelerate import cpu_offload
+ else:
+ raise ImportError("Please install accelerate via `pip install accelerate`")
+
+ device = torch.device("cuda")
+
+ for cpu_offloaded_model in [self.unet, self.clip, self.vae]:
+ if cpu_offloaded_model is not None:
+ cpu_offload(cpu_offloaded_model, device)
+
+ @property
+ def _execution_device(self):
+ r"""
+ Returns the device on which the pipeline's models will be executed. After calling
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+ hooks.
+ """
+ if not hasattr(self.unet, "_hf_hook"):
+ return self.device
+ for module in self.unet.modules():
+ if (
+ hasattr(module, "_hf_hook")
+ and hasattr(module._hf_hook, "execution_device")
+ and module._hf_hook.execution_device is not None
+ ):
+ return torch.device(module._hf_hook.execution_device)
+ return self.device
+
+ def _encode_prompt(self, prompt):
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+ removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+ text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+ prompt_embeds = self.clip.get_text_features(text_input_ids.to(self.device))
+ prompt_embeds = prompt_embeds / torch.linalg.norm(prompt_embeds, dim=-1, keepdim=True)
+ prompt_embeds = prompt_embeds[:, None, :]
+ return prompt_embeds
+
+ def _encode_image(self, retrieved_images, batch_size):
+ if len(retrieved_images[0]) == 0:
+ return None
+ for i in range(len(retrieved_images)):
+ retrieved_images[i] = normalize_images(retrieved_images[i])
+ retrieved_images[i] = preprocess_images(retrieved_images[i], self.feature_extractor).to(
+ self.clip.device, dtype=self.clip.dtype
+ )
+ _, c, h, w = retrieved_images[0].shape
+
+ retrieved_images = torch.reshape(torch.cat(retrieved_images, dim=0), (-1, c, h, w))
+ image_embeddings = self.clip.get_image_features(retrieved_images)
+ image_embeddings = image_embeddings / torch.linalg.norm(image_embeddings, dim=-1, keepdim=True)
+ _, d = image_embeddings.shape
+ image_embeddings = torch.reshape(image_embeddings, (batch_size, -1, d))
+ return image_embeddings
+
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def retrieve_images(self, retrieved_images, prompt_embeds, knn=10):
+ if self.retriever is not None:
+ additional_images = self.retriever.retrieve_imgs_batch(prompt_embeds[:, 0].cpu(), knn).total_examples
+ for i in range(len(retrieved_images)):
+ retrieved_images[i] += additional_images[i][self.retriever.config.image_column]
+ return retrieved_images
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ retrieved_images: Optional[List[Image.Image]] = None,
+ height: int = 768,
+ width: int = 768,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+ knn: Optional[int] = 10,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ height (`int`, *optional*, defaults to 768):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 768):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages the model to generate images closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
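+ retrieved_images (`List[PIL.Image.Image]`, *optional*):
+ Images to condition the generation on, in addition to any images fetched by the pipeline's retriever.
+ knn (`int`, *optional*, defaults to 10):
+ The number of nearest-neighbor images to retrieve for each prompt when a retriever is configured.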
+
+ Returns:
+ [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ if retrieved_images is not None:
+ retrieved_images = [retrieved_images for _ in range(batch_size)]
+ else:
+ retrieved_images = [[] for _ in range(batch_size)]
+ device = self._execution_device
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
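+
+ # encode the prompt with CLIP, retrieve nearest-neighbor images for it (if a retriever is set), and
+ # append their CLIP image embeddings to the text embedding so the UNet is conditioned on both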
+ if prompt_embeds is None:
+ prompt_embeds = self._encode_prompt(prompt)
+ retrieved_images = self.retrieve_images(retrieved_images, prompt_embeds, knn=knn)
+ image_embeddings = self._encode_image(retrieved_images, batch_size)
+ if image_embeddings is not None:
+ prompt_embeds = torch.cat([prompt_embeds, image_embeddings], dim=1)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_embeddings = torch.zeros_like(prompt_embeds).to(prompt_embeds.device)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([uncond_embeddings, prompt_embeds])
+ # get the initial random noise unless the user supplied it
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # set timesteps
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ # Some schedulers like PNDM have timesteps as arrays
+ # It's more optimized to move all timesteps to the correct device beforehand
+ timesteps_tensor = self.scheduler.timesteps.to(self.device)
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+
+ for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ else:
+ image = latents
+
+ image = self.image_processor.postprocess(
+ image, output_type=output_type, do_denormalize=[True] * image.shape[0]
+ )
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/diffusers/examples/research_projects/rdm/retriever.py b/diffusers/examples/research_projects/rdm/retriever.py
new file mode 100644
index 0000000000000000000000000000000000000000..16518ed1bc42f85565b584bf11b843d00dc220bc
--- /dev/null
+++ b/diffusers/examples/research_projects/rdm/retriever.py
@@ -0,0 +1,250 @@
+import os
+from typing import List
+
+import faiss
+import numpy as np
+import torch
+from datasets import Dataset, load_dataset
+from PIL import Image
+from transformers import CLIPFeatureExtractor, CLIPModel, PretrainedConfig
+
+from diffusers import logging
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def normalize_images(images: List[Image.Image]):
+ images = [np.array(image) for image in images]
+ images = [image / 127.5 - 1 for image in images]
+ return images
+
+
+def preprocess_images(images: List[np.ndarray], feature_extractor: CLIPFeatureExtractor) -> torch.FloatTensor:
+ """
+ Preprocesses a list of images into a batch of tensors.
+
+ Args:
+ images (:obj:`List[np.ndarray]`):
+ A list of images to preprocess.
+
+ Returns:
+ :obj:`torch.FloatTensor`: A batch of tensors.
+ """
+ images = [np.array(image) for image in images]
+ images = [(image + 1.0) / 2.0 for image in images]
+ images = feature_extractor(images, return_tensors="pt").pixel_values
+ return images
+
+
+class IndexConfig(PretrainedConfig):
+ def __init__(
+ self,
+ clip_name_or_path="openai/clip-vit-large-patch14",
+ dataset_name="Isamu136/oxford_pets_with_l14_emb",
+ image_column="image",
+ index_name="embeddings",
+ index_path=None,
+ dataset_set="train",
+ metric_type=faiss.METRIC_L2,
+ faiss_device=-1,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.clip_name_or_path = clip_name_or_path
+ self.dataset_name = dataset_name
+ self.image_column = image_column
+ self.index_name = index_name
+ self.index_path = index_path
+ self.dataset_set = dataset_set
+ self.metric_type = metric_type
+ self.faiss_device = faiss_device
+
+
+class Index:
+ """
+ Each index for a retrieval model is specific to the CLIP model and the dataset used.
+ """
+
+ def __init__(self, config: IndexConfig, dataset: Dataset):
+ self.config = config
+ self.dataset = dataset
+ self.index_initialized = False
+ self.index_name = config.index_name
+ self.index_path = config.index_path
+ self.init_index()
+
+ def set_index_name(self, index_name: str):
+ self.index_name = index_name
+
+ def init_index(self):
+ if not self.index_initialized:
+ if self.index_path and self.index_name:
+ try:
+ self.dataset.add_faiss_index(
+ column=self.index_name, metric_type=self.config.metric_type, device=self.config.faiss_device
+ )
+ self.index_initialized = True
+ except Exception as e:
+ print(e)
+ logger.info("Index not initialized")
+ if self.index_name in self.dataset.features:
+ self.dataset.add_faiss_index(column=self.index_name)
+ self.index_initialized = True
+
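+ # Computes a CLIP image embedding for every image in the dataset and adds a FAISS index over them
+ # (only if the index has not been initialized yet).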
+ def build_index(
+ self,
+ model=None,
+ feature_extractor: CLIPFeatureExtractor = None,
+ torch_dtype=torch.float32,
+ ):
+ if not self.index_initialized:
+ model = model or CLIPModel.from_pretrained(self.config.clip_name_or_path).to(dtype=torch_dtype)
+ feature_extractor = feature_extractor or CLIPFeatureExtractor.from_pretrained(
+ self.config.clip_name_or_path
+ )
+ self.dataset = get_dataset_with_emb_from_clip_model(
+ self.dataset,
+ model,
+ feature_extractor,
+ image_column=self.config.image_column,
+ index_name=self.config.index_name,
+ )
+ self.init_index()
+
+ def retrieve_imgs(self, vec, k: int = 20):
+ vec = np.array(vec).astype(np.float32)
+ return self.dataset.get_nearest_examples(self.index_name, vec, k=k)
+
+ def retrieve_imgs_batch(self, vec, k: int = 20):
+ vec = np.array(vec).astype(np.float32)
+ return self.dataset.get_nearest_examples_batch(self.index_name, vec, k=k)
+
+ def retrieve_indices(self, vec, k: int = 20):
+ vec = np.array(vec).astype(np.float32)
+ return self.dataset.search(self.index_name, vec, k=k)
+
+ def retrieve_indices_batch(self, vec, k: int = 20):
+ vec = np.array(vec).astype(np.float32)
+ return self.dataset.search_batch(self.index_name, vec, k=k)
+
+
+class Retriever:
+ def __init__(
+ self,
+ config: IndexConfig,
+ index: Index = None,
+ dataset: Dataset = None,
+ model=None,
+ feature_extractor: CLIPFeatureExtractor = None,
+ ):
+ self.config = config
+ self.index = index or self._build_index(config, dataset, model=model, feature_extractor=feature_extractor)
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ retriever_name_or_path: str,
+ index: Index = None,
+ dataset: Dataset = None,
+ model=None,
+ feature_extractor: CLIPFeatureExtractor = None,
+ **kwargs,
+ ):
+ config = kwargs.pop("config", None) or IndexConfig.from_pretrained(retriever_name_or_path, **kwargs)
+ return cls(config, index=index, dataset=dataset, model=model, feature_extractor=feature_extractor)
+
+ @staticmethod
+ def _build_index(
+ config: IndexConfig, dataset: Dataset = None, model=None, feature_extractor: CLIPFeatureExtractor = None
+ ):
+ dataset = dataset or load_dataset(config.dataset_name)
+ dataset = dataset[config.dataset_set]
+ index = Index(config, dataset)
+ index.build_index(model=model, feature_extractor=feature_extractor)
+ return index
+
+ def save_pretrained(self, save_directory):
+ os.makedirs(save_directory, exist_ok=True)
+ if self.config.index_path is None:
+ index_path = os.path.join(save_directory, "hf_dataset_index.faiss")
+ self.index.dataset.get_index(self.config.index_name).save(index_path)
+ self.config.index_path = index_path
+ self.config.save_pretrained(save_directory)
+
+ def init_retrieval(self):
+ logger.info("initializing retrieval")
+ self.index.init_index()
+
+ def retrieve_imgs(self, embeddings: np.ndarray, k: int):
+ return self.index.retrieve_imgs(embeddings, k)
+
+ def retrieve_imgs_batch(self, embeddings: np.ndarray, k: int):
+ return self.index.retrieve_imgs_batch(embeddings, k)
+
+ def retrieve_indices(self, embeddings: np.ndarray, k: int):
+ return self.index.retrieve_indices(embeddings, k)
+
+ def retrieve_indices_batch(self, embeddings: np.ndarray, k: int):
+ return self.index.retrieve_indices_batch(embeddings, k)
+
+ def __call__(
+ self,
+ embeddings,
+ k: int = 20,
+ ):
+ return self.index.retrieve_imgs(embeddings, k)
+
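+# Illustrative usage sketch (not part of the original code; it relies on the default IndexConfig dataset
+# and CLIP checkpoint defined above):
+#
+# config = IndexConfig()
+# retriever = Retriever(config) # builds a FAISS index over the dataset's CLIP image embeddings
+# nearest = retriever(query_embedding, k=5) # returns the 5 nearest dataset examples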
+
+def map_txt_to_clip_feature(clip_model, tokenizer, prompt):
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ if text_input_ids.shape[-1] > tokenizer.model_max_length:
+ removed_text = tokenizer.batch_decode(text_input_ids[:, tokenizer.model_max_length :])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+ text_input_ids = text_input_ids[:, : tokenizer.model_max_length]
+ text_embeddings = clip_model.get_text_features(text_input_ids.to(clip_model.device))
+ text_embeddings = text_embeddings / torch.linalg.norm(text_embeddings, dim=-1, keepdim=True)
+ text_embeddings = text_embeddings[:, None, :]
+ return text_embeddings[0][0].cpu().detach().numpy()
+
+
+def map_img_to_model_feature(model, feature_extractor, imgs, device):
+ for i, image in enumerate(imgs):
+ if not image.mode == "RGB":
+ imgs[i] = image.convert("RGB")
+ imgs = normalize_images(imgs)
+ retrieved_images = preprocess_images(imgs, feature_extractor).to(device)
+ image_embeddings = model(retrieved_images)
+ image_embeddings = image_embeddings / torch.linalg.norm(image_embeddings, dim=-1, keepdim=True)
+ image_embeddings = image_embeddings[None, ...]
+ return image_embeddings.cpu().detach().numpy()[0][0]
+
+
+def get_dataset_with_emb_from_model(dataset, model, feature_extractor, image_column="image", index_name="embeddings"):
+ return dataset.map(
+ lambda example: {
+ index_name: map_img_to_model_feature(model, feature_extractor, [example[image_column]], model.device)
+ }
+ )
+
+
+def get_dataset_with_emb_from_clip_model(
+ dataset, clip_model, feature_extractor, image_column="image", index_name="embeddings"
+):
+ return dataset.map(
+ lambda example: {
+ index_name: map_img_to_model_feature(
+ clip_model.get_image_features, feature_extractor, [example[image_column]], clip_model.device
+ )
+ }
+ )
diff --git a/diffusers/examples/research_projects/realfill/README.md b/diffusers/examples/research_projects/realfill/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..b70f425368e56c5df34b897f1df2c6bc00d9396d
--- /dev/null
+++ b/diffusers/examples/research_projects/realfill/README.md
@@ -0,0 +1,118 @@
+# RealFill
+
+[RealFill](https://arxiv.org/abs/2309.16668) is a method to personalize text2image inpainting models like stable diffusion inpainting given just a few (1~5) images of a scene.
+The `train_realfill.py` script shows how to implement the training procedure for stable diffusion inpainting.
+
+
+## Running locally with PyTorch
+
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+Change into the `realfill` folder and run:
+```bash
+cd realfill
+pip install -r requirements.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or, for a default accelerate configuration without answering questions about your environment:
+
+```bash
+accelerate config default
+```
+
+Or, if your environment doesn't support an interactive shell (e.g., a notebook):
+
+```python
+from accelerate.utils import write_basic_config
+write_basic_config()
+```
+
+When running `accelerate config`, setting torch compile mode to True can give dramatic speedups.
+
+### Toy example
+
+Now let's fill the real. For this example, we will use some images from the flower girl example in the paper.
+
+We already provide some images for testing at [this link](https://github.com/thuanz123/realfill/tree/main/data/flowerwoman).
+
+You only have to launch the training using:
+
+```bash
+export MODEL_NAME="stabilityai/stable-diffusion-2-inpainting"
+export TRAIN_DIR="data/flowerwoman"
+export OUTPUT_DIR="flowerwoman-model"
+
+accelerate launch train_realfill.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_data_dir=$TRAIN_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --resolution=512 \
+ --train_batch_size=16 \
+ --gradient_accumulation_steps=1 \
+ --unet_learning_rate=2e-4 \
+ --text_encoder_learning_rate=4e-5 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=100 \
+ --max_train_steps=2000 \
+ --lora_rank=8 \
+ --lora_dropout=0.1 \
+ --lora_alpha=16 \
+```
+
+### Training on a low-memory GPU:
+
+It is possible to run RealFill on a low-memory GPU by using the following optimizations:
+- [gradient checkpointing and the 8-bit optimizer](#training-with-gradient-checkpointing-and-8-bit-optimizers)
+- [xformers](#training-with-xformers)
+- [setting grads to none](#set-grads-to-none)
+
+```bash
+export MODEL_NAME="stabilityai/stable-diffusion-2-inpainting"
+export TRAIN_DIR="data/flowerwoman"
+export OUTPUT_DIR="flowerwoman-model"
+
+accelerate launch train_realfill.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_data_dir=$TRAIN_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --resolution=512 \
+ --train_batch_size=16 \
+ --gradient_accumulation_steps=1 --gradient_checkpointing \
+ --use_8bit_adam \
+ --enable_xformers_memory_efficient_attention \
+ --set_grads_to_none \
+ --unet_learning_rate=2e-4 \
+ --text_encoder_learning_rate=4e-5 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=100 \
+ --max_train_steps=2000 \
+ --lora_rank=8 \
+ --lora_dropout=0.1 \
+ --lora_alpha=16 \
+```
+
+### Training with gradient checkpointing and 8-bit optimizers:
+
+With the help of gradient checkpointing and the 8-bit optimizer from bitsandbytes, it's possible to train RealFill on a 16GB GPU.
+
+To install `bitsandbytes` please refer to this [readme](https://github.com/TimDettmers/bitsandbytes#requirements--installation).
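+
+As a sketch, the toy-example command above could be run with both optimizations enabled as follows (the
+batch size and gradient accumulation values here are illustrative and should be tuned to your GPU):
+
+```bash
+accelerate launch train_realfill.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_data_dir=$TRAIN_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --resolution=512 \
+ --train_batch_size=4 \
+ --gradient_accumulation_steps=4 --gradient_checkpointing \
+ --use_8bit_adam \
+ --unet_learning_rate=2e-4 \
+ --text_encoder_learning_rate=4e-5 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=100 \
+ --max_train_steps=2000
+```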
+
+### Training with xformers:
+You can enable memory efficient attention by [installing xFormers](https://github.com/facebookresearch/xformers#installing-xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script.
+
+### Set grads to none
+
+To save even more memory, pass the `--set_grads_to_none` argument to the script. This will set grads to None instead of zero. However, be aware that it changes certain behaviors, so if you start experiencing any problems, remove this argument.
+
+More info: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html
+
+## Acknowledgements
+This repo is built upon the DreamBooth code from diffusers, and we thank the developers for their great work and effort in releasing the source code. Furthermore, a special "thank you" to the RealFill authors for publishing such amazing work.
diff --git a/diffusers/examples/research_projects/realfill/infer.py b/diffusers/examples/research_projects/realfill/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..3153307c4ad31892df7607d3ecf9b2af79b65187
--- /dev/null
+++ b/diffusers/examples/research_projects/realfill/infer.py
@@ -0,0 +1,91 @@
+import argparse
+import os
+
+import torch
+from PIL import Image, ImageFilter
+from transformers import CLIPTextModel
+
+from diffusers import DPMSolverMultistepScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel
+
+
+parser = argparse.ArgumentParser(description="Inference")
+parser.add_argument(
+ "--model_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+)
+parser.add_argument(
+ "--validation_image",
+ type=str,
+ default=None,
+ required=True,
+ help="The directory of the validation image",
+)
+parser.add_argument(
+ "--validation_mask",
+ type=str,
+ default=None,
+ required=True,
+ help="The directory of the validation mask",
+)
+parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="./test-infer/",
+ help="The output directory where predictions are saved",
+)
+parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible inference.")
+
+args = parser.parse_args()
+
+if __name__ == "__main__":
+ os.makedirs(args.output_dir, exist_ok=True)
+ generator = None
+
+ # create & load model
+ pipe = StableDiffusionInpaintPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-2-inpainting", torch_dtype=torch.float32, revision=None
+ )
+
+ pipe.unet = UNet2DConditionModel.from_pretrained(
+ args.model_path,
+ subfolder="unet",
+ revision=None,
+ )
+ pipe.text_encoder = CLIPTextModel.from_pretrained(
+ args.model_path,
+ subfolder="text_encoder",
+ revision=None,
+ )
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
+ pipe = pipe.to("cuda")
+
+ if args.seed is not None:
+ generator = torch.Generator(device="cuda").manual_seed(args.seed)
+
+ image = Image.open(args.validation_image)
+ mask_image = Image.open(args.validation_mask)
+
+ results = pipe(
+ ["a photo of sks"] * 16,
+ image=image,
+ mask_image=mask_image,
+ num_inference_steps=25,
+ guidance_scale=5,
+ generator=generator,
+ ).images
+
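+ # slightly expand and soften the mask so the generated content blends into the original image
+ # when composited below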
+ erode_kernel = ImageFilter.MaxFilter(3)
+ mask_image = mask_image.filter(erode_kernel)
+
+ blur_kernel = ImageFilter.BoxBlur(1)
+ mask_image = mask_image.filter(blur_kernel)
+
+ for idx, result in enumerate(results):
+ result = Image.composite(result, image, mask_image)
+ result.save(f"{args.output_dir}/{idx}.png")
+
+ del pipe
+ torch.cuda.empty_cache()
diff --git a/diffusers/examples/research_projects/realfill/requirements.txt b/diffusers/examples/research_projects/realfill/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3827f0852a2089703ba9c433e024f09515aa37cf
--- /dev/null
+++ b/diffusers/examples/research_projects/realfill/requirements.txt
@@ -0,0 +1,9 @@
+diffusers==0.20.1
+accelerate==0.23.0
+transformers==4.34.0
+peft==0.5.0
+torch==2.0.1
+torchvision>=0.16
+ftfy==6.1.1
+tensorboard==2.14.0
+Jinja2==3.1.2
diff --git a/diffusers/examples/research_projects/realfill/train_realfill.py b/diffusers/examples/research_projects/realfill/train_realfill.py
new file mode 100644
index 0000000000000000000000000000000000000000..e251d8d1769caea8e9e5a5d903a2a582533c01a9
--- /dev/null
+++ b/diffusers/examples/research_projects/realfill/train_realfill.py
@@ -0,0 +1,977 @@
+import argparse
+import copy
+import itertools
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import torchvision.transforms.v2 as transforms_v2
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import set_seed
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from peft import LoraConfig, PeftModel, get_peft_model
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from torch.utils.data import Dataset
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, CLIPTextModel
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ DPMSolverMultistepScheduler,
+ StableDiffusionInpaintPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.20.1")
+
+logger = get_logger(__name__)
+
+
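+# Builds a random binary training mask: carves a random number of rectangular holes (side lengths
+# between 3% and 25% of `resolution`, kept inside a 1% margin) out of an all-ones mask, then inverts
+# the result with 50% probability.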
+def make_mask(images, resolution, times=30):
+ mask, times = torch.ones_like(images[0:1, :, :]), np.random.randint(1, times)
+ min_size, max_size, margin = np.array([0.03, 0.25, 0.01]) * resolution
+ max_size = min(max_size, resolution - margin * 2)
+
+ for _ in range(times):
+ width = np.random.randint(int(min_size), int(max_size))
+ height = np.random.randint(int(min_size), int(max_size))
+
+ x_start = np.random.randint(int(margin), resolution - int(margin) - width + 1)
+ y_start = np.random.randint(int(margin), resolution - int(margin) - height + 1)
+ mask[:, y_start : y_start + height, x_start : x_start + width] = 0
+
+ mask = 1 - mask if random.random() < 0.5 else mask
+ return mask
+
+
+def save_model_card(
+ repo_id: str,
+ images=None,
+ base_model: str = "",
+ repo_folder=None,
+):
+ img_str = ""
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ img_str += f"![img_{i}](./image_{i}.png)\n"
+
+ yaml = f"""
+---
+license: creativeml-openrail-m
+base_model: {base_model}
+prompt: "a photo of sks"
+tags:
+- stable-diffusion-inpainting
+- stable-diffusion-inpainting-diffusers
+- text-to-image
+- diffusers
+- realfill
+inference: true
+---
+ """
+ model_card = f"""
+# RealFill - {repo_id}
+
+This is a realfill model derived from {base_model}. The weights were trained using [RealFill](https://realfill.github.io/).
+You can find some example images in the following. \n
+{img_str}
+"""
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
+ f.write(yaml + model_card)
+
+
+def log_validation(
+ text_encoder,
+ tokenizer,
+ unet,
+ args,
+ accelerator,
+ weight_dtype,
+ epoch,
+):
+ logger.info(f"Running validation... \nGenerating {args.num_validation_images} images")
+
+ # create pipeline (note: unet and vae are loaded again in float32)
+ pipeline = StableDiffusionInpaintPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ tokenizer=tokenizer,
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+
+ # set `keep_fp32_wrapper` to True because we do not want to remove
+ # mixed precision hooks while we are still training
+ pipeline.unet = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
+ pipeline.text_encoder = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ target_dir = Path(args.train_data_dir) / "target"
+ target_image, target_mask = target_dir / "target.png", target_dir / "mask.png"
+ image, mask_image = Image.open(target_image), Image.open(target_mask)
+
+ if image.mode != "RGB":
+ image = image.convert("RGB")
+
+ images = []
+ for _ in range(args.num_validation_images):
+ # keep the original `image` as the conditioning input for every sample
+ result = pipeline(
+ prompt="a photo of sks",
+ image=image,
+ mask_image=mask_image,
+ num_inference_steps=25,
+ guidance_scale=5,
+ generator=generator,
+ ).images[0]
+ images.append(result)
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log({"validation": [wandb.Image(image, caption=str(i)) for i, image in enumerate(images)]})
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ return images
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ required=True,
+ help="A folder containing the training data of images.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation.",
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=100,
+ help=(
+ "Run RealFill validation every X steps. Validation consists of generating"
+ " `args.num_validation_images` images from the target image and mask."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="realfill-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images; all images in the train/validation dataset will be resized to this"
+ " resolution."
+ ),
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of update steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--unet_learning_rate",
+ type=float,
+ default=2e-4,
+ help="Learning rate to use for unet.",
+ )
+ parser.add_argument(
+ "--text_encoder_learning_rate",
+ type=float,
+ default=4e-5,
+ help="Learning rate to use for text encoder.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--wandb_key",
+ type=str,
+ default=None,
+ help=("If report_to is set to wandb, the API key used to log in to wandb."),
+ )
+ parser.add_argument(
+ "--wandb_project_name",
+ type=str,
+ default=None,
+ help=("If report_to is set to wandb, the wandb project name used for log tracking."),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--set_grads_to_none",
+ action="store_true",
+ help=(
+ "Save more memory by setting grads to None instead of zero. Be aware that this changes certain"
+ " behaviors, so disable this argument if it causes any problems. More info:"
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
+ ),
+ )
+ parser.add_argument(
+ "--lora_rank",
+ type=int,
+ default=16,
+ help=("The dimension of the LoRA update matrices."),
+ )
+ parser.add_argument(
+ "--lora_alpha",
+ type=int,
+ default=27,
+ help=("The alpha constant of the LoRA update matrices."),
+ )
+ parser.add_argument(
+ "--lora_dropout",
+ type=float,
+ default=0.0,
+ help="The dropout rate of the LoRA update matrices.",
+ )
+ parser.add_argument(
+ "--lora_bias",
+ type=str,
+ default="none",
+ help="The bias type of the LoRA update matrices. Must be 'none', 'all' or 'lora_only'.",
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ return args
+
+
+class RealFillDataset(Dataset):
+ """
+ A dataset to prepare the training and conditioning images and
+ the masks with the dummy prompt for fine-tuning the model.
+ It pre-processes the images, masks and tokenizes the prompts.
+ """
+
+ def __init__(
+ self,
+ train_data_root,
+ tokenizer,
+ size=512,
+ ):
+ self.size = size
+ self.tokenizer = tokenizer
+
+ self.ref_data_root = Path(train_data_root) / "ref"
+ self.target_image = Path(train_data_root) / "target" / "target.png"
+ self.target_mask = Path(train_data_root) / "target" / "mask.png"
+ if not (self.ref_data_root.exists() and self.target_image.exists() and self.target_mask.exists()):
+ raise ValueError("Train images root doesn't exist.")
+
+ self.train_images_path = list(self.ref_data_root.iterdir()) + [self.target_image]
+ self.num_train_images = len(self.train_images_path)
+ self.train_prompt = "a photo of sks"
+
+ self.transform = transforms_v2.Compose(
+ [
+ transforms_v2.ToImage(),
+ transforms_v2.RandomResize(size, int(1.125 * size)),
+ transforms_v2.RandomCrop(size),
+ transforms_v2.ToDtype(torch.float32, scale=True),
+ transforms_v2.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return self.num_train_images
+
+ def __getitem__(self, index):
+ example = {}
+
+ image = Image.open(self.train_images_path[index])
+ image = exif_transpose(image)
+
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+
+ if index < len(self) - 1:
+ weighting = Image.new("L", image.size)
+ else:
+ weighting = Image.open(self.target_mask)
+ weighting = exif_transpose(weighting)
+
+ image, weighting = self.transform(image, weighting)
+ example["images"], example["weightings"] = image, weighting < 0
+
+ if random.random() < 0.1:
+ example["masks"] = torch.ones_like(example["images"][0:1, :, :])
+ else:
+ example["masks"] = make_mask(example["images"], self.size)
+
+ example["conditioning_images"] = example["images"] * (example["masks"] < 0.5)
+
+ train_prompt = "" if random.random() < 0.1 else self.train_prompt
+ example["prompt_ids"] = self.tokenizer(
+ train_prompt,
+ truncation=True,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ ).input_ids
+
+ return example
+
+
+def collate_fn(examples):
+ input_ids = [example["prompt_ids"] for example in examples]
+ images = [example["images"] for example in examples]
+
+ masks = [example["masks"] for example in examples]
+ weightings = [example["weightings"] for example in examples]
+ conditioning_images = [example["conditioning_images"] for example in examples]
+
+ images = torch.stack(images)
+ images = images.to(memory_format=torch.contiguous_format).float()
+
+ masks = torch.stack(masks)
+ masks = masks.to(memory_format=torch.contiguous_format).float()
+
+ weightings = torch.stack(weightings)
+ weightings = weightings.to(memory_format=torch.contiguous_format).float()
+
+ conditioning_images = torch.stack(conditioning_images)
+ conditioning_images = conditioning_images.to(memory_format=torch.contiguous_format).float()
+
+ input_ids = torch.cat(input_ids, dim=0)
+
+ batch = {
+ "input_ids": input_ids,
+ "images": images,
+ "masks": masks,
+ "weightings": weightings,
+ "conditioning_images": conditioning_images,
+ }
+ return batch
+
+
+def main(args):
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_dir=logging_dir,
+ )
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ wandb.login(key=args.wandb_key)
+ wandb.init(project=args.wandb_project_name)
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizer
+ if args.tokenizer_name:
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
+ elif args.pretrained_model_name_or_path:
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ use_fast=False,
+ )
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+ )
+
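+ # Inject LoRA adapters into the attention projection layers of the UNet and text encoder;
+ # with PEFT, the base weights stay frozen and only the low-rank adapters are trained.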
+ config = LoraConfig(
+ r=args.lora_rank,
+ lora_alpha=args.lora_alpha,
+ target_modules=["to_k", "to_q", "to_v", "key", "query", "value"],
+ lora_dropout=args.lora_dropout,
+ bias=args.lora_bias,
+ )
+ unet = get_peft_model(unet, config)
+
+ config = LoraConfig(
+ r=args.lora_rank,
+ lora_alpha=args.lora_alpha,
+ target_modules=["k_proj", "q_proj", "v_proj"],
+ lora_dropout=args.lora_dropout,
+ bias=args.lora_bias,
+ )
+ text_encoder = get_peft_model(text_encoder, config)
+
+ vae.requires_grad_(False)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warn(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+ text_encoder.gradient_checkpointing_enable()
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ for model in models:
+ sub_dir = (
+ "unet"
+ if isinstance(model.base_model.model, type(accelerator.unwrap_model(unet).base_model.model))
+ else "text_encoder"
+ )
+ model.save_pretrained(os.path.join(output_dir, sub_dir))
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ while len(models) > 0:
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ sub_dir = (
+ "unet"
+ if isinstance(model.base_model.model, type(accelerator.unwrap_model(unet).base_model.model))
+ else "text_encoder"
+ )
+ model_cls = (
+ UNet2DConditionModel
+ if isinstance(model.base_model.model, type(accelerator.unwrap_model(unet).base_model.model))
+ else CLIPTextModel
+ )
+
+ load_model = model_cls.from_pretrained(args.pretrained_model_name_or_path, subfolder=sub_dir)
+ load_model = PeftModel.from_pretrained(load_model, input_dir, subfolder=sub_dir)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.unet_learning_rate = (
+ args.unet_learning_rate
+ * args.gradient_accumulation_steps
+ * args.train_batch_size
+ * accelerator.num_processes
+ )
+
+ args.text_encoder_learning_rate = (
+ args.text_encoder_learning_rate
+ * args.gradient_accumulation_steps
+ * args.train_batch_size
+ * accelerator.num_processes
+ )
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimizer creation
+ optimizer = optimizer_class(
+ [
+ {"params": unet.parameters(), "lr": args.unet_learning_rate},
+ {"params": text_encoder.parameters(), "lr": args.text_encoder_learning_rate},
+ ],
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Dataset and DataLoaders creation:
+ train_dataset = RealFillDataset(
+ train_data_root=args.train_data_dir,
+ tokenizer=tokenizer,
+ size=args.resolution,
+ )
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ collate_fn=collate_fn,
+ num_workers=1,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(
+ unet, text_encoder, optimizer, train_dataloader
+ )
+
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision,
+ # as these weights are only used for inference; keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move vae to device and cast to weight_dtype
+ vae.to(accelerator.device, dtype=weight_dtype)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = vars(copy.deepcopy(args))
+ accelerator.init_trackers("realfill", config=tracker_config)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ unet.train()
+ text_encoder.train()
+
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet, text_encoder):
+ # Convert images to latent space
+ latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
+ latents = latents * 0.18215
+
+ # Convert masked images to latent space
+ conditionings = vae.encode(batch["conditioning_images"].to(dtype=weight_dtype)).latent_dist.sample()
+ conditionings = conditionings * 0.18215
+
+ # Downsample mask and weighting so that they match with the latents
+ masks, size = batch["masks"].to(dtype=weight_dtype), latents.shape[2:]
+ masks = F.interpolate(masks, size=size)
+
+ weightings = batch["weightings"].to(dtype=weight_dtype)
+ weightings = F.interpolate(weightings, size=size)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Concatenate noisy latents, masks and conditionings to get inputs to unet
+ inputs = torch.cat([noisy_latents, masks, conditionings], dim=1)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+ # Predict the noise residual
+ model_pred = unet(inputs, timesteps, encoder_hidden_states).sample
+
+ # Compute the diffusion loss
+ assert noise_scheduler.config.prediction_type == "epsilon"
+ loss = (weightings * F.mse_loss(model_pred.float(), noise.float(), reduction="none")).mean()
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = itertools.chain(unet.parameters(), text_encoder.parameters())
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=args.set_grads_to_none)
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ if args.report_to == "wandb":
+ accelerator.print(progress_bar)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if global_step % args.validation_steps == 0:
+ log_validation(
+ text_encoder,
+ tokenizer,
+ unet,
+ args,
+ accelerator,
+ weight_dtype,
+ global_step,
+ )
+
+ logs = {"loss": loss.detach().item()}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ # Merge the LoRA layers into the base weights and save the full pipeline
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ pipeline = StableDiffusionInpaintPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ unet=accelerator.unwrap_model(unet, keep_fp32_wrapper=True).merge_and_unload(),
+ text_encoder=accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True).merge_and_unload(),
+ revision=args.revision,
+ )
+
+ pipeline.save_pretrained(args.output_dir)
+
+ # Final inference
+ images = log_validation(
+ text_encoder,
+ tokenizer,
+ unet,
+ args,
+ accelerator,
+ weight_dtype,
+ global_step,
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/research_projects/sdxl_flax/README.md b/diffusers/examples/research_projects/sdxl_flax/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..612fdf1edd43dc343d0e57e8f4747790bb3ae444
--- /dev/null
+++ b/diffusers/examples/research_projects/sdxl_flax/README.md
@@ -0,0 +1,243 @@
+# Stable Diffusion XL for JAX + TPUv5e
+
+[TPU v5e](https://cloud.google.com/blog/products/compute/how-cloud-tpu-v5e-accelerates-large-scale-ai-inference) is a new generation of TPUs from Google Cloud. It is the most cost-effective, versatile, and scalable Cloud TPU to date, which makes it ideal for serving and scaling large diffusion models.
+
+[JAX](https://github.com/google/jax) is a high-performance numerical computation library that is well-suited to develop and deploy diffusion models:
+
+- **High performance**. All JAX operations are implemented in terms of operations in [XLA](https://www.tensorflow.org/xla/) - the Accelerated Linear Algebra compiler.
+
+- **Compilation**. JAX uses just-in-time (jit) compilation of JAX Python functions so they can be executed efficiently in XLA. To get the best performance, we must use static shapes for jitted functions, because JAX transforms work by tracing a function to determine its effect on inputs of a specific shape and type. When a new shape is introduced to an already compiled function, compilation is retriggered for the new shape, which can greatly reduce performance (see the sketch after this list). **Note**: JIT compilation is particularly well-suited for text-to-image generation because all inputs and outputs (image input / output sizes) are static.
+
+- **Parallelization**. Workloads can be scaled across multiple devices using JAX's [pmap](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html), which expresses single-program multiple-data (SPMD) programs. Applying pmap to a function compiles it with XLA, then executes it in parallel across XLA devices. For text-to-image generation workloads, this means that increasing the number of images rendered simultaneously is straightforward to implement and doesn't compromise performance.
+
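+The following is a minimal, standalone sketch (not part of the SDXL scripts) that illustrates why static shapes matter for jitted functions: the first call for a given input shape triggers tracing and compilation, while a new shape triggers a recompilation.
+
+```python
+import jax
+import jax.numpy as jnp
+
+
+@jax.jit
+def scale(x):
+    # Traced and compiled once per distinct input shape/dtype.
+    return x * 2.0
+
+
+scale(jnp.ones((4, 4)))  # first call: traces and compiles for shape (4, 4)
+scale(jnp.ones((4, 4)))  # same shape: reuses the compiled executable
+scale(jnp.ones((8, 8)))  # new shape: triggers recompilation
+```
+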
+👉 Try it out for yourself:
+
+[![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/google/sdxl)
+
+## Stable Diffusion XL pipeline in JAX
+
+Upon having access to a TPU VM (TPUs higher than version 3), you should first install
+a TPU-compatible version of JAX:
+```
+pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+```
+
+Next, we can install [flax](https://github.com/google/flax) and the diffusers library:
+
+```
+pip install flax diffusers transformers
+```
+
+In [sdxl_single.py](./sdxl_single.py) we give a simple example of how to write a text-to-image generation pipeline in JAX using [StabilityAI's Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0).
+
+Let's explain it step-by-step:
+
+**Imports and Setup**
+
+```python
+import jax
+import jax.numpy as jnp
+import numpy as np
+from flax.jax_utils import replicate
+from diffusers import FlaxStableDiffusionXLPipeline
+
+from jax.experimental.compilation_cache import compilation_cache as cc
+cc.initialize_cache("/tmp/sdxl_cache")
+import time
+
+NUM_DEVICES = jax.device_count()
+```
+
+First, we import the necessary libraries:
+- `jax` provides the primitives for TPU operations
+- `flax.jax_utils` contains some useful utility functions for `Flax`, a neural network library built on top of JAX
+- `diffusers` has all the code that is relevant for SDXL.
+- We also initialize a cache to speed up the JAX model compilation.
+- We automatically determine the number of available TPU devices.
+
+**1. Downloading Model and Loading Pipeline**
+
+```python
+pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", revision="refs/pr/95", split_head_dim=True
+)
+```
+Here, a pre-trained model `stable-diffusion-xl-base-1.0` from the namespace `stabilityai` is loaded. It returns a pipeline for inference and its parameters.
+
+**2. Casting Parameter Types**
+
+```python
+scheduler_state = params.pop("scheduler")
+params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
+params["scheduler"] = scheduler_state
+```
+This section adjusts the data types of the model parameters.
+We convert all parameters to `bfloat16` to speed up computation with the model weights.
+**Note** that the scheduler parameters are **not** converted to `bfloat16`, as the loss
+in precision would degrade the pipeline's performance too significantly.
+
+**3. Define Inputs to Pipeline**
+
+```python
+default_prompt = ...
+default_neg_prompt = ...
+default_seed = 33
+default_guidance_scale = 5.0
+default_num_steps = 25
+```
+Here, various default inputs for the pipeline are set, including the prompt, negative prompt, random seed, guidance scale, and the number of inference steps.
+
+**4. Tokenizing Inputs**
+
+```python
+def tokenize_prompt(prompt, neg_prompt):
+ prompt_ids = pipeline.prepare_inputs(prompt)
+ neg_prompt_ids = pipeline.prepare_inputs(neg_prompt)
+ return prompt_ids, neg_prompt_ids
+```
+This function tokenizes the given prompts. It's essential because the text encoders of SDXL don't understand raw text; they work with numbers. Tokenization converts text to numbers.
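+
+As a rough usage sketch (the exact array shapes depend on how `prepare_inputs` pads and stacks the token IDs):
+
+```python
+prompt_ids, neg_prompt_ids = tokenize_prompt(
+    "a photo of an astronaut riding a horse", "blurry, low quality"
+)
+# Both results are integer arrays of token IDs, padded to the tokenizers' maximum length (77 for CLIP).
+```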
+
+**5. Parallelization and Replication**
+
+```python
+p_params = replicate(params)
+
+def replicate_all(prompt_ids, neg_prompt_ids, seed):
+ ...
+```
+To utilize JAX's parallel capabilities, the parameters and input tensors are duplicated across devices. The `replicate_all` function also ensures that every device produces a different image by creating a unique random seed for each device.
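+
+The body of `replicate_all`, taken from [sdxl_single.py](./sdxl_single.py), is short enough to show in full:
+
+```python
+def replicate_all(prompt_ids, neg_prompt_ids, seed):
+    p_prompt_ids = replicate(prompt_ids)
+    p_neg_prompt_ids = replicate(neg_prompt_ids)
+    rng = jax.random.PRNGKey(seed)
+    rng = jax.random.split(rng, NUM_DEVICES)  # one PRNG key per device -> each device draws different noise
+    return p_prompt_ids, p_neg_prompt_ids, rng
+```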
+
+**6. Putting Everything Together**
+
+```python
+def generate(...):
+ ...
+```
+This function integrates all the steps to produce the desired outputs from the model. It takes in prompts, tokenizes them, replicates them across devices, runs them through the pipeline, and converts the images to a format that's more interpretable (PIL format).
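+
+For completeness, this is the full `generate` function as defined in [sdxl_single.py](./sdxl_single.py):
+
+```python
+def generate(
+    prompt,
+    negative_prompt,
+    seed=default_seed,
+    guidance_scale=default_guidance_scale,
+    num_inference_steps=default_num_steps,
+):
+    prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
+    prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
+    images = pipeline(
+        prompt_ids,
+        p_params,
+        rng,
+        num_inference_steps=num_inference_steps,
+        neg_prompt_ids=neg_prompt_ids,
+        guidance_scale=guidance_scale,
+        jit=True,
+    ).images
+
+    # convert the images to PIL
+    images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
+    return pipeline.numpy_to_pil(np.array(images))
+```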
+
+**7. Compilation Step**
+
+```python
+start = time.time()
+print("Compiling ...")
+generate(default_prompt, default_neg_prompt)
+print(f"Compiled in {time.time() - start}")
+```
+The initial run of the `generate` function will be slow because JAX compiles the function during this call. By running it once here, subsequent calls will be much faster. This section measures and prints the compilation time.
+
+**8. Fast Inference**
+
+```python
+start = time.time()
+prompt = ...
+neg_prompt = ...
+images = generate(prompt, neg_prompt)
+print(f"Inference in {time.time() - start}")
+```
+Now that the function is compiled, this section shows how to use it for fast inference. It measures and prints the inference time.
+
+In summary, the code demonstrates how to load a pre-trained model using Flax and JAX, prepare it for inference, and run it efficiently using JAX's capabilities.
+
+## Ahead of Time (AOT) Compilation
+
+`FlaxStableDiffusionXLPipeline` takes care of parallelization across multiple devices when called with `jit=True`. Now let's build the parallelization ourselves.
+
+For this we will be using a JAX feature called [Ahead of Time](https://jax.readthedocs.io/en/latest/aot.html) (AOT) lowering and compilation. AOT allows us to fully compile prior to execution time and gives us control over different parts of the compilation process.
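+
+As a quick, standalone illustration of the AOT workflow (independent of the diffusion pipeline), lowering and compiling a trivial jitted function looks like this:
+
+```python
+import jax
+import jax.numpy as jnp
+
+
+def scale(x):
+    return x * 2.0
+
+
+# Lower for a concrete input shape/dtype, then compile ahead of time.
+compiled = jax.jit(scale).lower(jnp.ones((4, 4))).compile()
+compiled(jnp.ones((4, 4)))  # runs the precompiled executable; inputs must keep the same shape/dtype
+```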
+
+In [sdxl_single_aot.py](./sdxl_single_aot.py) we give a simple example of how to write our own parallelization logic for a text-to-image generation pipeline in JAX using [StabilityAI's Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0).
+
+We add an `aot_compile` function that compiles the `pipeline._generate` function
+while telling JAX which input arguments are static, that is, arguments that
+are known at compile time and won't change. In our case, these are `num_inference_steps`,
+`height`, `width` and `return_latents`.
+
+Once the function is compiled, these parameters are omitted from future calls and
+cannot be changed without modifying the code and recompiling.
+
+```python
+def aot_compile(
+ prompt=default_prompt,
+ negative_prompt=default_neg_prompt,
+ seed=default_seed,
+ guidance_scale=default_guidance_scale,
+ num_inference_steps=default_num_steps
+):
+ prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
+ prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
+ g = jnp.array([guidance_scale] * prompt_ids.shape[0], dtype=jnp.float32)
+ g = g[:, None]
+
+ return pmap(
+ pipeline._generate, static_broadcasted_argnums=[3, 4, 5, 9]
+ ).lower(
+ prompt_ids,
+ p_params,
+ rng,
+ num_inference_steps, # num_inference_steps
+ height, # height
+ width, # width
+ g,
+ None,
+ neg_prompt_ids,
+ False # return_latents
+ ).compile()
+```
+
+Next we can compile the generate function by executing `aot_compile`.
+
+```python
+start = time.time()
+print("Compiling ...")
+p_generate = aot_compile()
+print(f"Compiled in {time.time() - start}")
+```
+And again we put everything together in a `generate` function.
+
+```python
+def generate(
+ prompt,
+ negative_prompt,
+ seed=default_seed,
+ guidance_scale=default_guidance_scale
+):
+ prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
+ prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
+ g = jnp.array([guidance_scale] * prompt_ids.shape[0], dtype=jnp.float32)
+ g = g[:, None]
+ images = p_generate(
+ prompt_ids,
+ p_params,
+ rng,
+ g,
+ None,
+ neg_prompt_ids)
+
+ # convert the images to PIL
+ images = images.reshape((images.shape[0] * images.shape[1], ) + images.shape[-3:])
+ return pipeline.numpy_to_pil(np.array(images))
+```
+
+The first forward pass after AOT compilation still takes a while longer than
+subsequent passes. This is because on the first pass, JAX uses Python dispatch, which
+fills the C++ dispatch cache.
+When using jit, this extra step is done automatically, but when using AOT compilation,
+it doesn't happen until the first function call is made.
+
+```python
+start = time.time()
+prompt = "photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang"
+neg_prompt = "cartoon, illustration, animation. face. male, female"
+images = generate(prompt, neg_prompt)
+print(f"First inference in {time.time() - start}")
+```
+
+From this point forward, any call to `generate` should run at the fast, steady
+inference time.
+
+```python
+start = time.time()
+prompt = "photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang"
+neg_prompt = "cartoon, illustration, animation. face. male, female"
+images = generate(prompt, neg_prompt)
+print(f"Inference in {time.time() - start}")
+```
diff --git a/diffusers/examples/research_projects/sdxl_flax/sdxl_single.py b/diffusers/examples/research_projects/sdxl_flax/sdxl_single.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b9b862d99b5c2d583d28395ccb2191e50bef7a4
--- /dev/null
+++ b/diffusers/examples/research_projects/sdxl_flax/sdxl_single.py
@@ -0,0 +1,106 @@
+# Show best practices for SDXL JAX
+import time
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from flax.jax_utils import replicate
+
+# Let's cache the model compilation, so that it doesn't take as long the next time around.
+from jax.experimental.compilation_cache import compilation_cache as cc
+
+from diffusers import FlaxStableDiffusionXLPipeline
+
+
+cc.initialize_cache("/tmp/sdxl_cache")
+
+
+NUM_DEVICES = jax.device_count()
+
+# 1. Let's start by downloading the model and loading it into our pipeline class
+# Adhering to JAX's functional approach, the model's parameters are returned separately and
+# will have to be passed to the pipeline during inference
+pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", revision="refs/pr/95", split_head_dim=True
+)
+
+# 2. We cast all parameters to bfloat16 EXCEPT the scheduler which we leave in
+# float32 to keep maximal precision
+scheduler_state = params.pop("scheduler")
+params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
+params["scheduler"] = scheduler_state
+
+# 3. Next, we define the different inputs to the pipeline
+default_prompt = "a colorful photo of a castle in the middle of a forest with trees and bushes, by Ismail Inceoglu, shadows, high contrast, dynamic shading, hdr, detailed vegetation, digital painting, digital drawing, detailed painting, a detailed digital painting, gothic art, featured on deviantart"
+default_neg_prompt = "fog, grainy, purple"
+default_seed = 33
+default_guidance_scale = 5.0
+default_num_steps = 25
+
+
+# 4. In order to be able to compile the pipeline
+# all inputs have to be tensors or strings
+# Let's tokenize the prompt and negative prompt
+def tokenize_prompt(prompt, neg_prompt):
+ prompt_ids = pipeline.prepare_inputs(prompt)
+ neg_prompt_ids = pipeline.prepare_inputs(neg_prompt)
+ return prompt_ids, neg_prompt_ids
+
+
+# 5. To make full use of JAX's parallelization capabilities
+# the parameters and input tensors are duplicated across devices
+# To make sure every device generates a different image, we create
+# different seeds for each image. The model parameters won't change
+# during inference so we do not wrap them into a function
+p_params = replicate(params)
+
+
+def replicate_all(prompt_ids, neg_prompt_ids, seed):
+ p_prompt_ids = replicate(prompt_ids)
+ p_neg_prompt_ids = replicate(neg_prompt_ids)
+ rng = jax.random.PRNGKey(seed)
+ rng = jax.random.split(rng, NUM_DEVICES)
+ return p_prompt_ids, p_neg_prompt_ids, rng
+
+
+# 6. Let's now put it all together in a generate function
+def generate(
+ prompt,
+ negative_prompt,
+ seed=default_seed,
+ guidance_scale=default_guidance_scale,
+ num_inference_steps=default_num_steps,
+):
+ prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
+ prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
+ images = pipeline(
+ prompt_ids,
+ p_params,
+ rng,
+ num_inference_steps=num_inference_steps,
+ neg_prompt_ids=neg_prompt_ids,
+ guidance_scale=guidance_scale,
+ jit=True,
+ ).images
+
+ # convert the images to PIL
+ images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
+ return pipeline.numpy_to_pil(np.array(images))
+
+
+# 7. Remember that the first call will compile the function and hence be very slow. Let's run generate once
+# so that the pipeline call is compiled
+start = time.time()
+print("Compiling ...")
+generate(default_prompt, default_neg_prompt)
+print(f"Compiled in {time.time() - start}")
+
+# 8. Now the model forward pass will run very quickly, let's try it again
+start = time.time()
+prompt = "photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang"
+neg_prompt = "cartoon, illustration, animation. face. male, female"
+images = generate(prompt, neg_prompt)
+print(f"Inference in {time.time() - start}")
+
+for i, image in enumerate(images):
+ image.save(f"castle_{i}.png")
diff --git a/diffusers/examples/research_projects/sdxl_flax/sdxl_single_aot.py b/diffusers/examples/research_projects/sdxl_flax/sdxl_single_aot.py
new file mode 100644
index 0000000000000000000000000000000000000000..58447fd86daf21d0bf98ed05698986ddafc231d9
--- /dev/null
+++ b/diffusers/examples/research_projects/sdxl_flax/sdxl_single_aot.py
@@ -0,0 +1,143 @@
+import time
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from flax.jax_utils import replicate
+from jax import pmap
+
+# Let's cache the model compilation, so that it doesn't take as long the next time around.
+from jax.experimental.compilation_cache import compilation_cache as cc
+
+from diffusers import FlaxStableDiffusionXLPipeline
+
+
+cc.initialize_cache("/tmp/sdxl_cache")
+
+
+NUM_DEVICES = jax.device_count()
+
+# 1. Let's start by downloading the model and loading it into our pipeline class
+# Adhering to JAX's functional approach, the model's parameters are returned separately and
+# will have to be passed to the pipeline during inference
+pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", revision="refs/pr/95", split_head_dim=True
+)
+
+# 2. We cast all parameters to bfloat16 EXCEPT the scheduler which we leave in
+# float32 to keep maximal precision
+scheduler_state = params.pop("scheduler")
+params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
+params["scheduler"] = scheduler_state
+
+# 3. Next, we define the different inputs to the pipeline
+default_prompt = "a colorful photo of a castle in the middle of a forest with trees and bushes, by Ismail Inceoglu, shadows, high contrast, dynamic shading, hdr, detailed vegetation, digital painting, digital drawing, detailed painting, a detailed digital painting, gothic art, featured on deviantart"
+default_neg_prompt = "fog, grainy, purple"
+default_seed = 33
+default_guidance_scale = 5.0
+default_num_steps = 25
+width = 1024
+height = 1024
+
+
+# 4. In order to be able to compile the pipeline
+# all inputs have to be tensors or strings
+# Let's tokenize the prompt and negative prompt
+def tokenize_prompt(prompt, neg_prompt):
+ prompt_ids = pipeline.prepare_inputs(prompt)
+ neg_prompt_ids = pipeline.prepare_inputs(neg_prompt)
+ return prompt_ids, neg_prompt_ids
+
+
+# 5. To make full use of JAX's parallelization capabilities
+# the parameters and input tensors are duplicated across devices
+# To make sure every device generates a different image, we create
+# different seeds for each image. The model parameters won't change
+# during inference so we do not wrap them into a function
+p_params = replicate(params)
+
+
+def replicate_all(prompt_ids, neg_prompt_ids, seed):
+ p_prompt_ids = replicate(prompt_ids)
+ p_neg_prompt_ids = replicate(neg_prompt_ids)
+ rng = jax.random.PRNGKey(seed)
+ rng = jax.random.split(rng, NUM_DEVICES)
+ return p_prompt_ids, p_neg_prompt_ids, rng
+
+
+# 6. To compile the pipeline._generate function, we must pass all parameters
+# to the function and tell JAX which are static arguments, that is, arguments that
+# are known at compile time and won't change. In our case, it is num_inference_steps,
+# height, width and return_latents.
+# Once the function is compiled, these parameters are omitted from future calls and
+# cannot be changed without modifying the code and recompiling.
+def aot_compile(
+ prompt=default_prompt,
+ negative_prompt=default_neg_prompt,
+ seed=default_seed,
+ guidance_scale=default_guidance_scale,
+ num_inference_steps=default_num_steps,
+):
+ prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
+ prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
+ g = jnp.array([guidance_scale] * prompt_ids.shape[0], dtype=jnp.float32)
+ g = g[:, None]
+
+ return (
+ pmap(pipeline._generate, static_broadcasted_argnums=[3, 4, 5, 9])
+ .lower(
+ prompt_ids,
+ p_params,
+ rng,
+ num_inference_steps, # num_inference_steps
+ height, # height
+ width, # width
+ g,
+ None,
+ neg_prompt_ids,
+ False, # return_latents
+ )
+ .compile()
+ )
+
+
+start = time.time()
+print("Compiling ...")
+p_generate = aot_compile()
+print(f"Compiled in {time.time() - start}")
+
+
+# 7. Let's now put it all together in a generate function.
+def generate(prompt, negative_prompt, seed=default_seed, guidance_scale=default_guidance_scale):
+ prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
+ prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
+ g = jnp.array([guidance_scale] * prompt_ids.shape[0], dtype=jnp.float32)
+ g = g[:, None]
+ images = p_generate(prompt_ids, p_params, rng, g, None, neg_prompt_ids)
+
+ # convert the images to PIL
+ images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
+ return pipeline.numpy_to_pil(np.array(images))
+
+
+# 8. The first forward pass after AOT compilation still takes a while longer than
+# subsequent passes. This is because on the first pass, JAX uses Python dispatch, which
+# fills the C++ dispatch cache.
+# When using jit, this extra step is done automatically, but when using AOT compilation,
+# it doesn't happen until the function call is made.
+start = time.time()
+prompt = "photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang"
+neg_prompt = "cartoon, illustration, animation. face. male, female"
+images = generate(prompt, neg_prompt)
+print(f"First inference in {time.time() - start}")
+
+# 9. From this point forward, any call to generate should run at the fast, steady
+# inference time.
+start = time.time()
+prompt = "photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang"
+neg_prompt = "cartoon, illustration, animation. face. male, female"
+images = generate(prompt, neg_prompt)
+print(f"Inference in {time.time() - start}")
+
+for i, image in enumerate(images):
+ image.save(f"castle_{i}.png")
diff --git a/diffusers/examples/t2i_adapter/README.md b/diffusers/examples/t2i_adapter/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..7d7491950d0ec22ed6f9badd8307efee73789786
--- /dev/null
+++ b/diffusers/examples/t2i_adapter/README.md
@@ -0,0 +1 @@
+We don't yet support training T2I-Adapters on Stable Diffusion. For training T2I-Adapters on Stable Diffusion XL, refer to [README_sdxl.md](./README_sdxl.md).
\ No newline at end of file
diff --git a/diffusers/examples/t2i_adapter/README_sdxl.md b/diffusers/examples/t2i_adapter/README_sdxl.md
new file mode 100644
index 0000000000000000000000000000000000000000..d583341c367f49e9c0d734a8b96473bdd59f79b6
--- /dev/null
+++ b/diffusers/examples/t2i_adapter/README_sdxl.md
@@ -0,0 +1,131 @@
+# T2I-Adapter training example for Stable Diffusion XL (SDXL)
+
+The `train_t2i_adapter_sdxl.py` script shows how to implement the [T2I-Adapter training procedure](https://hf.co/papers/2302.08453) for [Stable Diffusion XL](https://huggingface.co/papers/2307.01952).
+
+## Running locally with PyTorch
+
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+Then cd into the `examples/t2i_adapter` folder and run
+```bash
+pip install -r requirements.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or for a default accelerate configuration without answering questions about your environment
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell (e.g., a notebook)
+
+```python
+from accelerate.utils import write_basic_config
+write_basic_config()
+```
+
+When running `accelerate config`, setting torch compile mode to True can give dramatic speedups.
+
+## Circle filling dataset
+
+The original dataset is hosted in the [ControlNet repo](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip). We re-uploaded it to be compatible with `datasets` [here](https://huggingface.co/datasets/fusing/fill50k). Note that `datasets` handles dataloading within the training script.
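+
+If you want to inspect the dataset outside the training script, you can load it directly with 🤗 Datasets (a minimal sketch, assuming the column layout used by the training script: `image`, `conditioning_image`, `text`):
+
+```python
+from datasets import load_dataset
+
+dataset = load_dataset("fusing/fill50k", split="train")
+print(dataset[0].keys())  # expected: image, conditioning_image, text
+```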
+
+## Training
+
+Our training examples use two test conditioning images. They can be downloaded by running
+
+```sh
+wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png
+
+wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
+```
+
+Then run `huggingface-cli login` to log into your Hugging Face account. This is needed to push the trained T2I-Adapter parameters to the Hugging Face Hub.
+
+```bash
+export MODEL_DIR="stabilityai/stable-diffusion-xl-base-1.0"
+export OUTPUT_DIR="path to save model"
+
+accelerate launch train_t2i_adapter_sdxl.py \
+ --pretrained_model_name_or_path=$MODEL_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --dataset_name=fusing/fill50k \
+ --mixed_precision="fp16" \
+ --resolution=1024 \
+ --learning_rate=1e-5 \
+ --max_train_steps=15000 \
+ --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
+ --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
+ --validation_steps=100 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --report_to="wandb" \
+ --seed=42 \
+ --push_to_hub
+```
+
+To better track our training experiments, we're using the following flags in the command above:
+
+* `report_to="wandb"` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
+* `validation_image`, `validation_prompt`, and `validation_steps` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
+
+Our experiments were conducted on a single 40GB A100 GPU.
+
+### Inference
+
+Once training is done, we can perform inference like so:
+
+```python
+from diffusers import StableDiffusionXLAdapterPipeline, T2IAdapter, EulerAncestralDiscreteScheduler
+from diffusers.utils import load_image
+import torch
+
+base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
+adapter_path = "path to adapter"
+
+adapter = T2IAdapter.from_pretrained(adapter_path, torch_dtype=torch.float16)
+pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
+ base_model_path, adapter=adapter, torch_dtype=torch.float16
+)
+
+# speed up diffusion process with faster scheduler and memory optimization
+pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
+# remove the following line if xformers is not installed or when using Torch 2.0.
+pipe.enable_xformers_memory_efficient_attention()
+# memory optimization.
+pipe.enable_model_cpu_offload()
+
+control_image = load_image("./conditioning_image_1.png")
+prompt = "pale golden rod circle with old lace background"
+
+# generate image
+generator = torch.manual_seed(0)
+image = pipe(
+ prompt, num_inference_steps=20, generator=generator, image=control_image
+).images[0]
+image.save("./output.png")
+```
+
+## Notes
+
+### Specifying a better VAE
+
+SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument, `--pretrained_vae_model_name_or_path`, that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
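+
+For example, to train against the fp16-fix VAE you would add the flag to the launch command shown above (a sketch; keep the remaining arguments from your setup):
+
+```bash
+accelerate launch train_t2i_adapter_sdxl.py \
+  --pretrained_model_name_or_path=$MODEL_DIR \
+  --pretrained_vae_model_name_or_path="madebyollin/sdxl-vae-fp16-fix" \
+  --output_dir=$OUTPUT_DIR \
+  --dataset_name=fusing/fill50k \
+  --mixed_precision="fp16" \
+  --resolution=1024 \
+  --train_batch_size=1 \
+  --max_train_steps=15000
+```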
diff --git a/diffusers/examples/t2i_adapter/requirements.txt b/diffusers/examples/t2i_adapter/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2955535b192796e2618393012560a7b534fbea23
--- /dev/null
+++ b/diffusers/examples/t2i_adapter/requirements.txt
@@ -0,0 +1,8 @@
+transformers>=4.25.1
+accelerate>=0.16.0
+safetensors
+datasets
+torchvision
+ftfy
+tensorboard
+wandb
\ No newline at end of file
diff --git a/diffusers/examples/t2i_adapter/train_t2i_adapter_sdxl.py b/diffusers/examples/t2i_adapter/train_t2i_adapter_sdxl.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1c9113bbd9d857d52ece64d1049253ab79c1e21
--- /dev/null
+++ b/diffusers/examples/t2i_adapter/train_t2i_adapter_sdxl.py
@@ -0,0 +1,1290 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import functools
+import gc
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+
+import accelerate
+import numpy as np
+import torch
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from PIL import Image
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ EulerDiscreteScheduler,
+ StableDiffusionXLAdapterPipeline,
+ T2IAdapter,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+MAX_SEQ_LENGTH = 77
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.24.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def image_grid(imgs, rows, cols):
+ assert len(imgs) == rows * cols
+
+ w, h = imgs[0].size
+ grid = Image.new("RGB", size=(cols * w, rows * h))
+
+ for i, img in enumerate(imgs):
+ grid.paste(img, box=(i % cols * w, i // cols * h))
+ return grid
+
+
+def log_validation(vae, unet, adapter, args, accelerator, weight_dtype, step):
+ logger.info("Running validation... ")
+
+ adapter = accelerator.unwrap_model(adapter)
+
+ pipeline = StableDiffusionXLAdapterPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ unet=unet,
+ adapter=adapter,
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ if len(args.validation_image) == len(args.validation_prompt):
+ validation_images = args.validation_image
+ validation_prompts = args.validation_prompt
+ elif len(args.validation_image) == 1:
+ validation_images = args.validation_image * len(args.validation_prompt)
+ validation_prompts = args.validation_prompt
+ elif len(args.validation_prompt) == 1:
+ validation_images = args.validation_image
+ validation_prompts = args.validation_prompt * len(args.validation_image)
+ else:
+ raise ValueError(
+ "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
+ )
+
+ image_logs = []
+
+ for validation_prompt, validation_image in zip(validation_prompts, validation_images):
+ validation_image = Image.open(validation_image).convert("RGB")
+ validation_image = validation_image.resize((args.resolution, args.resolution))
+
+ images = []
+
+ for _ in range(args.num_validation_images):
+ with torch.autocast("cuda"):
+ image = pipeline(
+ prompt=validation_prompt, image=validation_image, num_inference_steps=20, generator=generator
+ ).images[0]
+ images.append(image)
+
+ image_logs.append(
+ {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
+ )
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+
+ formatted_images = []
+
+ formatted_images.append(np.asarray(validation_image))
+
+ for image in images:
+ formatted_images.append(np.asarray(image))
+
+ formatted_images = np.stack(formatted_images)
+
+ tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ formatted_images = []
+
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+
+ formatted_images.append(wandb.Image(validation_image, caption="adapter conditioning"))
+
+ for image in images:
+ image = wandb.Image(image, caption=validation_prompt)
+ formatted_images.append(image)
+
+ tracker.log({"validation": formatted_images})
+ else:
+ logger.warn(f"image logging not implemented for {tracker.name}")
+
+ del pipeline
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ return image_logs
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):
+ img_str = ""
+ if image_logs is not None:
+ img_str = "You can find some example images below.\n"
+ for i, log in enumerate(image_logs):
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+ validation_image.save(os.path.join(repo_folder, "image_control.png"))
+ img_str += f"prompt: {validation_prompt}\n"
+ images = [validation_image] + images
+ image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
+ img_str += f"![images_{i})](./images_{i}.png)\n"
+
+ yaml = f"""
+---
+license: creativeml-openrail-m
+base_model: {base_model}
+tags:
+- stable-diffusion-xl
+- stable-diffusion-xl-diffusers
+- text-to-image
+- diffusers
+- t2iadapter
+inference: true
+---
+ """
+ model_card = f"""
+# t2iadapter-{repo_id}
+
+These are t2iadapter weights trained on {base_model} with new type of conditioning.
+{img_str}
+"""
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
+ f.write(yaml + model_card)
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a T2I-Adapter training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--adapter_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained adapter model or model identifier from huggingface.co/models."
+ " If not specified adapter weights are initialized w.r.t the configurations of SDXL.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help=(
+ "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
+ " float32 precision."
+ ),
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="t2iadapter-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=1024,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--detection_resolution",
+ type=int,
+ default=None,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--crops_coords_top_left_h",
+ type=int,
+ default=0,
+ help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
+ )
+ parser.add_argument(
+ "--crops_coords_top_left_w",
+ type=int,
+ default=0,
+ help=("Coordinate for (the width) to be included in the crop coordinate embeddings needed by SDXL UNet."),
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
+ "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
+ "instructions."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=3,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-6,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=1,
+ help=("Number of subprocesses to use for data loading."),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--set_grads_to_none",
+ action="store_true",
+ help=(
+ "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
+ " behaviors, so disable this argument if it causes any problems. More info:"
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
+ ),
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
+ )
+ parser.add_argument(
+ "--conditioning_image_column",
+ type=str,
+ default="conditioning_image",
+ help="The column of the dataset containing the adapter conditioning image.",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--proportion_empty_prompts",
+ type=float,
+ default=0,
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ nargs="+",
+ help=(
+ "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
+ " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
+ " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
+ ),
+ )
+ parser.add_argument(
+ "--validation_image",
+ type=str,
+ default=None,
+ nargs="+",
+ help=(
+ "A set of paths to the t2iadapter conditioning images to be evaluated every `--validation_steps`"
+ " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
+ " single `--validation_prompt` to be used with all `--validation_image`s, or a single"
+ " `--validation_image` that will be used with all `--validation_prompt`s."
+ ),
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=100,
+ help=(
+ "Run validation every X steps. Validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
+ " and logging the images."
+ ),
+ )
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="sd_xl_train_t2iadapter",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers for"
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Specify either `--dataset_name` or `--train_data_dir`")
+
+ if args.dataset_name is not None and args.train_data_dir is not None:
+ raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")
+
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
+
+ if args.validation_prompt is not None and args.validation_image is None:
+ raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
+
+ if args.validation_prompt is None and args.validation_image is not None:
+ raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
+
+ if (
+ args.validation_image is not None
+ and args.validation_prompt is not None
+ and len(args.validation_image) != 1
+ and len(args.validation_prompt) != 1
+ and len(args.validation_image) != len(args.validation_prompt)
+ ):
+ raise ValueError(
+ "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
+ " or the same number of `--validation_prompt`s and `--validation_image`s"
+ )
+
+ if args.resolution % 8 != 0:
+ raise ValueError(
+ "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the t2iadapter encoder."
+ )
+
+ return args
+
+
+def get_train_dataset(args, accelerator):
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ else:
+ if args.train_data_dir is not None:
+ dataset = load_dataset(
+ args.train_data_dir,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ if args.image_column is None:
+ image_column = column_names[0]
+ logger.info(f"image column defaulting to {image_column}")
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ if args.caption_column is None:
+ caption_column = column_names[1]
+ logger.info(f"caption column defaulting to {caption_column}")
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ if args.conditioning_image_column is None:
+ conditioning_image_column = column_names[2]
+ logger.info(f"conditioning image column defaulting to {conditioning_image_column}")
+ else:
+ conditioning_image_column = args.conditioning_image_column
+ if conditioning_image_column not in column_names:
+ raise ValueError(
+ f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+
+ with accelerator.main_process_first():
+ train_dataset = dataset["train"].shuffle(seed=args.seed)
+ if args.max_train_samples is not None:
+ train_dataset = train_dataset.select(range(args.max_train_samples))
+ return train_dataset
+
+
+# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
+def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train=True):
+ prompt_embeds_list = []
+
+ captions = []
+ for caption in prompt_batch:
+ if random.random() < proportion_empty_prompts:
+ captions.append("")
+ elif isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+
+ with torch.no_grad():
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+ text_inputs = tokenizer(
+ captions,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_embeds = text_encoder(
+ text_input_ids.to(text_encoder.device),
+ output_hidden_states=True,
+ )
+
+ # We are only interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
+ return prompt_embeds, pooled_prompt_embeds
+
+
+def prepare_train_dataset(dataset, accelerator):
+ image_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ conditioning_image_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution),
+ transforms.ToTensor(),
+ ]
+ )
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[args.image_column]]
+ images = [image_transforms(image) for image in images]
+
+ conditioning_images = [image.convert("RGB") for image in examples[args.conditioning_image_column]]
+ conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]
+
+ examples["pixel_values"] = images
+ examples["conditioning_pixel_values"] = conditioning_images
+
+ return examples
+
+ with accelerator.main_process_first():
+ dataset = dataset.with_transform(preprocess_train)
+
+ return dataset
+
+
+def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
+ conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ prompt_ids = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples])
+
+ add_text_embeds = torch.stack([torch.tensor(example["text_embeds"]) for example in examples])
+ add_time_ids = torch.stack([torch.tensor(example["time_ids"]) for example in examples])
+
+ return {
+ "pixel_values": pixel_values,
+ "conditioning_pixel_values": conditioning_pixel_values,
+ "prompt_ids": prompt_ids,
+ "unet_added_conditions": {"text_embeds": add_text_embeds, "time_ids": add_time_ids},
+ }
+
+
+def main(args):
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Log the configuration on every process to make debugging easier.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
+ exist_ok=True,
+ token=args.hub_token,
+ private=True,
+ ).repo_id
+
+ # Load the tokenizers
+ tokenizer_one = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
+ )
+ tokenizer_two = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
+ )
+
+ # import correct text encoder classes
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+ )
+
+ # Load scheduler and models
+ noise_scheduler = EulerDiscreteScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
+ )
+ vae_path = (
+ args.pretrained_model_name_or_path
+ if args.pretrained_vae_model_name_or_path is None
+ else args.pretrained_vae_model_name_or_path
+ )
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ )
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+ )
+
+ if args.adapter_model_name_or_path:
+ logger.info("Loading existing adapter weights.")
+ t2iadapter = T2IAdapter.from_pretrained(args.adapter_model_name_or_path)
+ else:
+ logger.info("Initializing t2iadapter weights.")
+ t2iadapter = T2IAdapter(
+ in_channels=3,
+ channels=(320, 640, 1280, 1280),
+ num_res_blocks=2,
+ downscale_factor=16,
+ adapter_type="full_adapter_xl",
+ )
+
+ # `accelerate` 0.16.0 and above has better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ i = len(weights) - 1
+
+ while len(weights) > 0:
+ weights.pop()
+ model = models[i]
+
+ sub_dir = "t2iadapter"
+ model.save_pretrained(os.path.join(output_dir, sub_dir))
+
+ i -= 1
+
+ def load_model_hook(models, input_dir):
+ while len(models) > 0:
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = T2IAdapter.from_pretrained(os.path.join(input_dir, "t2iadapter"))
+
+ if args.control_type != "style":
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ vae.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+ t2iadapter.train()
+ unet.train()
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # Check that all trainable models are in full precision
+ low_precision_error_string = (
+ " Please make sure to always have all model weights in full float32 precision when starting training - even if"
+ " doing mixed precision training, copy of the weights should still be float32."
+ )
+
+ if accelerator.unwrap_model(t2iadapter).dtype != torch.float32:
+ raise ValueError(
+ f"Controlnet loaded as datatype {accelerator.unwrap_model(t2iadapter).dtype}. {low_precision_error_string}"
+ )
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimizer creation
+ params_to_optimize = t2iadapter.parameters()
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move vae, unet and text encoders to device and cast to weight_dtype.
+ # The default SDXL VAE is kept in float32 to avoid NaN losses; a user-provided VAE (e.g. an fp16-fixed variant) is cast to weight_dtype.
+ if args.pretrained_vae_model_name_or_path is not None:
+ vae.to(accelerator.device, dtype=weight_dtype)
+ else:
+ vae.to(accelerator.device, dtype=torch.float32)
+ unet.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+
+ # Here, we compute not just the text embeddings but also the additional embeddings
+ # needed for the SD XL UNet to operate.
+ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizers, is_train=True):
+ original_size = (args.resolution, args.resolution)
+ target_size = (args.resolution, args.resolution)
+ crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)
+ prompt_batch = batch[args.caption_column]
+
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
+ prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train
+ )
+ add_text_embeds = pooled_prompt_embeds
+
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
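+ # The six values (original H/W, crop top-left, target H/W) are SDXL's size/crop micro-conditioning;
+ # the UNet embeds them and adds them to the time embedding via `added_cond_kwargs`.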
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_time_ids = torch.tensor([add_time_ids])
+
+ prompt_embeds = prompt_embeds.to(accelerator.device)
+ add_text_embeds = add_text_embeds.to(accelerator.device)
+ add_time_ids = add_time_ids.repeat(len(prompt_batch), 1)
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype)
+ unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+
+ return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}
+
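+ # Map each sampled timestep back to its sigma in the Euler schedule; the sigma is reshaped to
+ # `n_dim` dimensions so it broadcasts elementwise against the latent tensors.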
+ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
+ sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype)
+ schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device)
+ timesteps = timesteps.to(accelerator.device)
+
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < n_dim:
+ sigma = sigma.unsqueeze(-1)
+ return sigma
+
+ # Let's first compute all the embeddings so that we can free up the text encoders
+ # from memory.
+ text_encoders = [text_encoder_one, text_encoder_two]
+ tokenizers = [tokenizer_one, tokenizer_two]
+ train_dataset = get_train_dataset(args, accelerator)
+ compute_embeddings_fn = functools.partial(
+ compute_embeddings,
+ proportion_empty_prompts=args.proportion_empty_prompts,
+ text_encoders=text_encoders,
+ tokenizers=tokenizers,
+ )
+ with accelerator.main_process_first():
+ from datasets.fingerprint import Hasher
+
+ # fingerprint used by the cache for the other processes to load the result
+ # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
+ new_fingerprint = Hasher.hash(args)
+ train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)
+
+ # Then get the training dataset ready to be passed to the dataloader.
+ train_dataset = prepare_train_dataset(train_dataset, accelerator)
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps,
+ num_training_steps=args.max_train_steps,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ t2iadapter, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ t2iadapter, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+
+ # tensorboard cannot handle list types for config
+ tracker_config.pop("validation_prompt")
+ tracker_config.pop("validation_image")
+
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ image_logs = None
+ for epoch in range(first_epoch, args.num_train_epochs):
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(t2iadapter):
+ if args.pretrained_vae_model_name_or_path is not None:
+ pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
+ else:
+ pixel_values = batch["pixel_values"]
+
+ # encode pixel values with a batch size of at most 8 to avoid OOM
+ latents = []
+ for i in range(0, pixel_values.shape[0], 8):
+ latents.append(vae.encode(pixel_values[i : i + 8]).latent_dist.sample())
+ latents = torch.cat(latents, dim=0)
+ latents = latents * vae.config.scaling_factor
+ if args.pretrained_vae_model_name_or_path is None:
+ latents = latents.to(weight_dtype)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+
+ # Cubic sampling to sample a random timestep for each image.
+ # For more details about why cubic sampling is used, refer to section 3.4 of https://arxiv.org/abs/2302.08453
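+ # With u ~ U(0, 1), t = (1 - u^3) * T places most of the probability mass on large (high-noise)
+ # timesteps, which is where the adapter's guidance has the most influence.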
+ timesteps = torch.rand((bsz,), device=latents.device)
+ timesteps = (1 - timesteps**3) * noise_scheduler.config.num_train_timesteps
+ timesteps = timesteps.long().to(noise_scheduler.timesteps.dtype)
+ timesteps = timesteps.clamp(0, noise_scheduler.config.num_train_timesteps - 1)
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Scale the noisy latents for the UNet
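+ # Dividing by sqrt(sigma^2 + 1) mirrors `EulerDiscreteScheduler.scale_model_input`, so the UNet sees
+ # inputs scaled the same way during training as it does at inference time.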
+ sigmas = get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
+ inp_noisy_latents = noisy_latents / ((sigmas**2 + 1) ** 0.5)
+
+ # Adapter conditioning.
+ t2iadapter_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
+ down_block_additional_residuals = t2iadapter(t2iadapter_image)
+ down_block_additional_residuals = [
+ sample.to(dtype=weight_dtype) for sample in down_block_additional_residuals
+ ]
+
+ # Predict the noise residual
+ model_pred = unet(
+ inp_noisy_latents,
+ timesteps,
+ encoder_hidden_states=batch["prompt_ids"],
+ added_cond_kwargs=batch["unet_added_conditions"],
+ down_block_additional_residuals=down_block_additional_residuals,
+ ).sample
+
+ # Denoise the latents
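+ # For epsilon prediction, `model_pred * (-sigmas) + noisy_latents` equals x_t - sigma * eps_pred, i.e. the
+ # model's estimate of x_0; weighting the squared error by sigma^-2 makes the loss below equivalent to a
+ # plain MSE on the predicted noise.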
+ denoised_latents = model_pred * (-sigmas) + noisy_latents
+ weighing = sigmas**-2.0
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = latents # the loss is computed against the denoised latents
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ # MSE loss
+ loss = torch.mean(
+ (weighing.float() * (denoised_latents.float() - target.float()) ** 2).reshape(target.shape[0], -1),
+ dim=1,
+ )
+ loss = loss.mean()
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = t2iadapter.parameters()
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=args.set_grads_to_none)
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if args.validation_prompt is not None and global_step % args.validation_steps == 0:
+ image_logs = log_validation(
+ vae,
+ unet,
+ t2iadapter,
+ args,
+ accelerator,
+ weight_dtype,
+ global_step,
+ )
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ t2iadapter = accelerator.unwrap_model(t2iadapter)
+ t2iadapter.save_pretrained(args.output_dir)
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ image_logs=image_logs,
+ base_model=args.pretrained_model_name_or_path,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/test_examples.py b/diffusers/examples/test_examples.py
new file mode 100644
index 0000000000000000000000000000000000000000..292c433a3395d8e8d6a8fd32b4fe7844f44854ca
--- /dev/null
+++ b/diffusers/examples/test_examples.py
@@ -0,0 +1,1725 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+import os
+import shutil
+import subprocess
+import sys
+import tempfile
+import unittest
+from typing import List
+
+import safetensors
+from accelerate.utils import write_basic_config
+
+from diffusers import DiffusionPipeline, UNet2DConditionModel
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+
+
+# These utils relate to ensuring the right error message is received when running scripts
+class SubprocessCallException(Exception):
+ pass
+
+
+def run_command(command: List[str], return_stdout=False):
+ """
+ Runs `command` with `subprocess.check_output` and optionally returns its `stdout`. Also properly captures
+ and reports any error that occurs while running `command`.
+ """
+ try:
+ output = subprocess.check_output(command, stderr=subprocess.STDOUT)
+ if return_stdout:
+ if hasattr(output, "decode"):
+ output = output.decode("utf-8")
+ return output
+ except subprocess.CalledProcessError as e:
+ raise SubprocessCallException(
+ f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
+ ) from e
+
+
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class ExamplesTestsAccelerate(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ super().setUpClass()
+ cls._tmpdir = tempfile.mkdtemp()
+ cls.configPath = os.path.join(cls._tmpdir, "default_config.yml")
+
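+ # `write_basic_config` creates a minimal accelerate config for the current machine, so every test
+ # launches the example scripts the same way a user would with `accelerate launch`.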
+ write_basic_config(save_location=cls.configPath)
+ cls._launch_args = ["accelerate", "launch", "--config_file", cls.configPath]
+
+ @classmethod
+ def tearDownClass(cls):
+ super().tearDownClass()
+ shutil.rmtree(cls._tmpdir)
+
+ def test_train_unconditional(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/unconditional_image_generation/train_unconditional.py
+ --dataset_name hf-internal-testing/dummy_image_class_data
+ --model_config_name_or_path diffusers/ddpm_dummy
+ --resolution 64
+ --output_dir {tmpdir}
+ --train_batch_size 2
+ --num_epochs 1
+ --gradient_accumulation_steps 1
+ --ddpm_num_inference_steps 2
+ --learning_rate 1e-3
+ --lr_warmup_steps 5
+ """.split()
+
+ run_command(self._launch_args + test_args, return_stdout=True)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
+
+ def test_textual_inversion(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/textual_inversion/textual_inversion.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --train_data_dir docs/source/en/imgs
+ --learnable_property object
+ --placeholder_token <cat-toy>
+ --initializer_token a
+ --validation_prompt <cat-toy>
+ --validation_steps 1
+ --save_steps 1
+ --num_vectors 2
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "learned_embeds.safetensors")))
+
+ def test_dreambooth(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
+
+ def test_dreambooth_if(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-if-pipe
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --pre_compute_text_embeddings
+ --tokenizer_max_length=77
+ --text_encoder_use_attention_mask
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
+
+ def test_dreambooth_checkpointing(self):
+ instance_prompt = "photo"
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 5, checkpointing_steps == 2
+ # Should create checkpoints at steps 2, 4
+
+ initial_run_args = f"""
+ examples/dreambooth/train_dreambooth.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt {instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 5
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ # check can run the original fully trained output pipeline
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(instance_prompt, num_inference_steps=2)
+
+ # check checkpoint directories exist
+ self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
+ self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
+
+ # check can run an intermediate checkpoint
+ unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet")
+ pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None)
+ pipe(instance_prompt, num_inference_steps=2)
+
+ # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
+ shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
+
+ # Run training script for 7 total steps resuming from checkpoint 4
+
+ resume_run_args = f"""
+ examples/dreambooth/train_dreambooth.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt {instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 7
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-4
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ # check can run new fully trained pipeline
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(instance_prompt, num_inference_steps=2)
+
+ # check old checkpoints do not exist
+ self.assertFalse(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
+
+ # check new checkpoints exist
+ self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
+ self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-6")))
+
+ def test_dreambooth_lora(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"unet"` in their names.
+ starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
+ self.assertTrue(starts_with_unet)
+
+ def test_dreambooth_lora_with_text_encoder(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --train_text_encoder
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # check `text_encoder` is present at all.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ keys = lora_state_dict.keys()
+ is_text_encoder_present = any(k.startswith("text_encoder") for k in keys)
+ self.assertTrue(is_text_encoder_present)
+
+ # the names of the keys of the state dict should either start with `unet`
+ # or `text_encoder`.
+ is_correct_naming = all(k.startswith("unet") or k.startswith("text_encoder") for k in keys)
+ self.assertTrue(is_correct_naming)
+
+ def test_dreambooth_lora_if_model(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-if-pipe
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --pre_compute_text_embeddings
+ --tokenizer_max_length=77
+ --text_encoder_use_attention_mask
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"unet"` in their names.
+ starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
+ self.assertTrue(starts_with_unet)
+
+ def test_dreambooth_lora_sdxl(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"unet"` in their names.
+ starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
+ self.assertTrue(starts_with_unet)
+
+ def test_dreambooth_lora_sdxl_with_text_encoder(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --train_text_encoder
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when also training the text encoders, every parameter in the state dict should start
+ # with `"unet"`, `"text_encoder"` or `"text_encoder_2"` in its name.
+ keys = lora_state_dict.keys()
+ starts_with_unet = all(
+ k.startswith("unet") or k.startswith("text_encoder") or k.startswith("text_encoder_2") for k in keys
+ )
+ self.assertTrue(starts_with_unet)
+
+ def test_dreambooth_lora_sdxl_custom_captions(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --caption_column text
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ def test_dreambooth_lora_sdxl_text_encoder_custom_captions(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --caption_column text
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --train_text_encoder
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ def test_dreambooth_lora_sdxl_checkpointing_checkpoints_total_limit(self):
+ pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora_sdxl.py
+ --pretrained_model_name_or_path {pipeline_path}
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 7
+ --checkpointing_steps=2
+ --checkpoints_total_limit=2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ pipe = DiffusionPipeline.from_pretrained(pipeline_path)
+ pipe.load_lora_weights(tmpdir)
+ pipe("a prompt", num_inference_steps=2)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ # checkpoint-2 should have been deleted
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_dreambooth_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit(self):
+ pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora_sdxl.py
+ --pretrained_model_name_or_path {pipeline_path}
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 7
+ --checkpointing_steps=2
+ --checkpoints_total_limit=2
+ --train_text_encoder
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ pipe = DiffusionPipeline.from_pretrained(pipeline_path)
+ pipe.load_lora_weights(tmpdir)
+ pipe("a prompt", num_inference_steps=2)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ # checkpoint-2 should have been deleted
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_custom_diffusion(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/custom_diffusion/train_custom_diffusion.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt <new1>
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 1.0e-05
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --modifier_token <new1>
+ --no_safe_serialization
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_custom_diffusion_weights.bin")))
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, ".bin")))
+
+ def test_text_to_image(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/text_to_image/train_text_to_image.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
+
+ def test_text_to_image_checkpointing(self):
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
+ prompt = "a prompt"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 5, checkpointing_steps == 2
+ # Should create checkpoints at steps 2, 4
+
+ initial_run_args = f"""
+ examples/text_to_image/train_text_to_image.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 5
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(prompt, num_inference_steps=2)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4"},
+ )
+
+ # check can run an intermediate checkpoint
+ unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet")
+ pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None)
+ pipe(prompt, num_inference_steps=2)
+
+ # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
+ shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
+
+ # Run training script for 7 total steps resuming from checkpoint 4
+
+ resume_run_args = f"""
+ examples/text_to_image/train_text_to_image.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 7
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-4
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ # check can run new fully trained pipeline
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(prompt, num_inference_steps=2)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {
+ # no checkpoint-2 -> check old checkpoints do not exist
+ # check new checkpoints exist
+ "checkpoint-4",
+ "checkpoint-6",
+ },
+ )
+
+ def test_text_to_image_checkpointing_use_ema(self):
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
+ prompt = "a prompt"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 5, checkpointing_steps == 2
+ # Should create checkpoints at steps 2, 4
+
+ initial_run_args = f"""
+ examples/text_to_image/train_text_to_image.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 5
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --use_ema
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(prompt, num_inference_steps=2)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4"},
+ )
+
+ # check can run an intermediate checkpoint
+ unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet")
+ pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None)
+ pipe(prompt, num_inference_steps=2)
+
+ # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
+ shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
+
+ # Run training script for 7 total steps resuming from checkpoint 4
+
+ resume_run_args = f"""
+ examples/text_to_image/train_text_to_image.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 7
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-4
+ --use_ema
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ # check can run new fully trained pipeline
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(prompt, num_inference_steps=2)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {
+ # no checkpoint-2 -> check old checkpoints do not exist
+ # check new checkpoints exist
+ "checkpoint-4",
+ "checkpoint-6",
+ },
+ )
+
+ def test_text_to_image_checkpointing_checkpoints_total_limit(self):
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
+ prompt = "a prompt"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
+ # Should create checkpoints at steps 2, 4, 6
+ # with checkpoint at step 2 deleted
+
+ initial_run_args = f"""
+ examples/text_to_image/train_text_to_image.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 7
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --checkpoints_total_limit=2
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(prompt, num_inference_steps=2)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ # checkpoint-2 should have been deleted
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_text_to_image_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
+ prompt = "a prompt"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 9, checkpointing_steps == 2
+ # Should create checkpoints at steps 2, 4, 6, 8
+
+ initial_run_args = f"""
+ examples/text_to_image/train_text_to_image.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 9
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(prompt, num_inference_steps=2)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
+ )
+
+ # Resume training; the next checkpoint is written at step 10, where checkpoint-2 and checkpoint-4
+ # both have to be removed (rather than a single previous checkpoint) to respect the new total limit.
+
+ resume_run_args = f"""
+ examples/text_to_image/train_text_to_image.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 11
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-8
+ --checkpoints_total_limit=3
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(prompt, num_inference_steps=2)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-6", "checkpoint-8", "checkpoint-10"},
+ )
+
+ def test_text_to_image_sdxl(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/text_to_image/train_text_to_image_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
+
+ def test_text_to_image_lora_checkpointing_checkpoints_total_limit(self):
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
+ prompt = "a prompt"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
+ # Should create checkpoints at steps 2, 4, 6
+ # with checkpoint at step 2 deleted
+
+ initial_run_args = f"""
+ examples/text_to_image/train_text_to_image_lora.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 7
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --checkpoints_total_limit=2
+ --seed=0
+ --num_validation_images=0
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(
+ "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
+ )
+ pipe.load_lora_weights(tmpdir)
+ pipe(prompt, num_inference_steps=2)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ # checkpoint-2 should have been deleted
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_text_to_image_lora_sdxl_checkpointing_checkpoints_total_limit(self):
+ prompt = "a prompt"
+ pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
+ # Should create checkpoints at steps 2, 4, 6
+ # with checkpoint at step 2 deleted
+
+ initial_run_args = f"""
+ examples/text_to_image/train_text_to_image_lora_sdxl.py
+ --pretrained_model_name_or_path {pipeline_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 7
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(pipeline_path)
+ pipe.load_lora_weights(tmpdir)
+ pipe(prompt, num_inference_steps=2)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ # checkpoint-2 should have been deleted
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_text_to_image_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit(self):
+ prompt = "a prompt"
+ pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
+ # Should create checkpoints at steps 2, 4, 6
+ # with checkpoint at step 2 deleted
+
+ initial_run_args = f"""
+ examples/text_to_image/train_text_to_image_lora_sdxl.py
+ --pretrained_model_name_or_path {pipeline_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 7
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --train_text_encoder
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(pipeline_path)
+ pipe.load_lora_weights(tmpdir)
+ pipe(prompt, num_inference_steps=2)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ # checkpoint-2 should have been deleted
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_text_to_image_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
+ prompt = "a prompt"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 9, checkpointing_steps == 2
+ # Should create checkpoints at steps 2, 4, 6, 8
+
+ initial_run_args = f"""
+ examples/text_to_image/train_text_to_image_lora.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 9
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --seed=0
+ --num_validation_images=0
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(
+ "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
+ )
+ pipe.load_lora_weights(tmpdir)
+ pipe(prompt, num_inference_steps=2)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
+ )
+
+ # Resume training; the next checkpoint is written at step 10, where checkpoint-2 and checkpoint-4
+ # both have to be removed (rather than a single previous checkpoint) to respect the new total limit.
+
+ resume_run_args = f"""
+ examples/text_to_image/train_text_to_image_lora.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 11
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-8
+ --checkpoints_total_limit=3
+ --seed=0
+ --num_validation_images=0
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(
+ "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
+ )
+ pipe.load_lora_weights(tmpdir)
+ pipe(prompt, num_inference_steps=2)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-6", "checkpoint-8", "checkpoint-10"},
+ )
+
+ def test_unconditional_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ initial_run_args = f"""
+ examples/unconditional_image_generation/train_unconditional.py
+ --dataset_name hf-internal-testing/dummy_image_class_data
+ --model_config_name_or_path diffusers/ddpm_dummy
+ --resolution 64
+ --output_dir {tmpdir}
+ --train_batch_size 1
+ --num_epochs 1
+ --gradient_accumulation_steps 1
+ --ddpm_num_inference_steps 2
+ --learning_rate 1e-3
+ --lr_warmup_steps 5
+ --checkpointing_steps=2
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ # checkpoint-2 should have been deleted
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_unconditional_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ initial_run_args = f"""
+ examples/unconditional_image_generation/train_unconditional.py
+ --dataset_name hf-internal-testing/dummy_image_class_data
+ --model_config_name_or_path diffusers/ddpm_dummy
+ --resolution 64
+ --output_dir {tmpdir}
+ --train_batch_size 1
+ --num_epochs 1
+ --gradient_accumulation_steps 1
+ --ddpm_num_inference_steps 2
+ --learning_rate 1e-3
+ --lr_warmup_steps 5
+ --checkpointing_steps=1
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-1", "checkpoint-2", "checkpoint-3", "checkpoint-4", "checkpoint-5", "checkpoint-6"},
+ )
+
+ resume_run_args = f"""
+ examples/unconditional_image_generation/train_unconditional.py
+ --dataset_name hf-internal-testing/dummy_image_class_data
+ --model_config_name_or_path diffusers/ddpm_dummy
+ --resolution 64
+ --output_dir {tmpdir}
+ --train_batch_size 1
+ --num_epochs 2
+ --gradient_accumulation_steps 1
+ --ddpm_num_inference_steps 2
+ --learning_rate 1e-3
+ --lr_warmup_steps 5
+ --resume_from_checkpoint=checkpoint-6
+ --checkpointing_steps=2
+ --checkpoints_total_limit=3
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-8", "checkpoint-10", "checkpoint-12"},
+ )
+
+ def test_textual_inversion_checkpointing(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/textual_inversion/textual_inversion.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --train_data_dir docs/source/en/imgs
+ --learnable_property object
+ --placeholder_token <cat-toy>
+ --initializer_token a
+ --validation_prompt <cat-toy>
+ --validation_steps 1
+ --save_steps 1
+ --num_vectors 2
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 3
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=1
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-3"},
+ )
+
+ def test_textual_inversion_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/textual_inversion/textual_inversion.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --train_data_dir docs/source/en/imgs
+ --learnable_property object
+ --placeholder_token <cat-toy>
+ --initializer_token a
+ --validation_prompt <cat-toy>
+ --validation_steps 1
+ --save_steps 1
+ --num_vectors 2
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 3
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=1
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-1", "checkpoint-2", "checkpoint-3"},
+ )
+
+ resume_run_args = f"""
+ examples/textual_inversion/textual_inversion.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --train_data_dir docs/source/en/imgs
+ --learnable_property object
+ --placeholder_token <cat-toy>
+ --initializer_token a
+ --validation_prompt <cat-toy>
+ --validation_steps 1
+ --save_steps 1
+ --num_vectors 2
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 4
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=1
+ --resume_from_checkpoint=checkpoint-3
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-3", "checkpoint-4"},
+ )
+
+ def test_instruct_pix2pix_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/instruct_pix2pix/train_instruct_pix2pix.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --dataset_name=hf-internal-testing/instructpix2pix-10-samples
+ --resolution=64
+ --random_flip
+ --train_batch_size=1
+ --max_train_steps=7
+ --checkpointing_steps=2
+ --checkpoints_total_limit=2
+ --output_dir {tmpdir}
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_instruct_pix2pix_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/instruct_pix2pix/train_instruct_pix2pix.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --dataset_name=hf-internal-testing/instructpix2pix-10-samples
+ --resolution=64
+ --random_flip
+ --train_batch_size=1
+ --max_train_steps=9
+ --checkpointing_steps=2
+ --output_dir {tmpdir}
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
+ )
+
+ resume_run_args = f"""
+ examples/instruct_pix2pix/train_instruct_pix2pix.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --dataset_name=hf-internal-testing/instructpix2pix-10-samples
+ --resolution=64
+ --random_flip
+ --train_batch_size=1
+ --max_train_steps=11
+ --checkpointing_steps=2
+ --output_dir {tmpdir}
+ --seed=0
+ --resume_from_checkpoint=checkpoint-8
+ --checkpoints_total_limit=3
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-6", "checkpoint-8", "checkpoint-10"},
+ )
+
+ def test_dreambooth_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir=docs/source/en/imgs
+ --output_dir={tmpdir}
+ --instance_prompt=prompt
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=6
+ --checkpoints_total_limit=2
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_dreambooth_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir=docs/source/en/imgs
+ --output_dir={tmpdir}
+ --instance_prompt=prompt
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=9
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
+ )
+
+ resume_run_args = f"""
+ examples/dreambooth/train_dreambooth.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir=docs/source/en/imgs
+ --output_dir={tmpdir}
+ --instance_prompt=prompt
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=11
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-8
+ --checkpoints_total_limit=3
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-6", "checkpoint-8", "checkpoint-10"},
+ )
+
+ def test_dreambooth_lora_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir=docs/source/en/imgs
+ --output_dir={tmpdir}
+ --instance_prompt=prompt
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=6
+ --checkpoints_total_limit=2
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_dreambooth_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir=docs/source/en/imgs
+ --output_dir={tmpdir}
+ --instance_prompt=prompt
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=9
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
+ )
+
+ resume_run_args = f"""
+ examples/dreambooth/train_dreambooth_lora.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir=docs/source/en/imgs
+ --output_dir={tmpdir}
+ --instance_prompt=prompt
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=11
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-8
+ --checkpoints_total_limit=3
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-6", "checkpoint-8", "checkpoint-10"},
+ )
+
+ def test_controlnet_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/controlnet/train_controlnet.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --dataset_name=hf-internal-testing/fill10
+ --output_dir={tmpdir}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=6
+ --checkpoints_total_limit=2
+ --checkpointing_steps=2
+ --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_controlnet_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/controlnet/train_controlnet.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --dataset_name=hf-internal-testing/fill10
+ --output_dir={tmpdir}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
+ --max_train_steps=9
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
+ )
+
+ resume_run_args = f"""
+ examples/controlnet/train_controlnet.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --dataset_name=hf-internal-testing/fill10
+ --output_dir={tmpdir}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
+ --max_train_steps=11
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-8
+ --checkpoints_total_limit=3
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-8", "checkpoint-10", "checkpoint-12"},
+ )
+
+ def test_controlnet_sdxl(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/controlnet/train_controlnet_sdxl.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --dataset_name=hf-internal-testing/fill10
+ --output_dir={tmpdir}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet-sdxl
+ --max_train_steps=9
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
+
+ def test_t2i_adapter_sdxl(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/t2i_adapter/train_t2i_adapter_sdxl.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --adapter_model_name_or_path=hf-internal-testing/tiny-adapter
+ --dataset_name=hf-internal-testing/fill10
+ --output_dir={tmpdir}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=9
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
+
+ def test_custom_diffusion_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/custom_diffusion/train_custom_diffusion.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir=docs/source/en/imgs
+ --output_dir={tmpdir}
+ --instance_prompt=<new1>
+ --resolution=64
+ --train_batch_size=1
+ --modifier_token=<new1>
+ --dataloader_num_workers=0
+ --max_train_steps=6
+ --checkpoints_total_limit=2
+ --checkpointing_steps=2
+ --no_safe_serialization
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/custom_diffusion/train_custom_diffusion.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir=docs/source/en/imgs
+ --output_dir={tmpdir}
+ --instance_prompt=<new1>
+ --resolution=64
+ --train_batch_size=1
+ --modifier_token=<new1>
+ --dataloader_num_workers=0
+ --max_train_steps=9
+ --checkpointing_steps=2
+ --no_safe_serialization
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
+ )
+
+ resume_run_args = f"""
+ examples/custom_diffusion/train_custom_diffusion.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir=docs/source/en/imgs
+ --output_dir={tmpdir}
+ --instance_prompt=<new1>
+ --resolution=64
+ --train_batch_size=1
+ --modifier_token=<new1>
+ --dataloader_num_workers=0
+ --max_train_steps=11
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-8
+ --checkpoints_total_limit=3
+ --no_safe_serialization
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-6", "checkpoint-8", "checkpoint-10"},
+ )
+
+ def test_text_to_image_lora_sdxl(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/text_to_image/train_text_to_image_lora_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ def test_text_to_image_lora_sdxl_with_text_encoder(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/text_to_image/train_text_to_image_lora_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --train_text_encoder
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"unet"` or `"text_encoder"` or `"text_encoder_2"` in their names.
+ keys = lora_state_dict.keys()
+ starts_with_unet = all(
+ k.startswith("unet") or k.startswith("text_encoder") or k.startswith("text_encoder_2") for k in keys
+ )
+ self.assertTrue(starts_with_unet)
diff --git a/diffusers/examples/text_to_image/README.md b/diffusers/examples/text_to_image/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..7b9f4013c7467d4d6cb171765bf2991aacdf35ee
--- /dev/null
+++ b/diffusers/examples/text_to_image/README.md
@@ -0,0 +1,323 @@
+# Stable Diffusion text-to-image fine-tuning
+
+The `train_text_to_image.py` script shows how to fine-tune the Stable Diffusion model on your own dataset.
+
+___Note___:
+
+___This script is experimental. The script fine-tunes the whole model, and oftentimes the model overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparameters to get the best results on your dataset.___
+
+
+## Running locally with PyTorch
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Then cd into the example folder and run
+```bash
+pip install -r requirements.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+### Pokemon example
+
+You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree.
+
+You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).
+
+Run the following command to authenticate your token
+
+```bash
+huggingface-cli login
+```
+
+If you have already cloned the repo, then you won't need to go through these steps.
+
+
+
+#### Hardware
+With `gradient_checkpointing` and `mixed_precision`, it should be possible to fine-tune the model on a single 24GB GPU. For a higher `batch_size` and faster training, it's better to use GPUs with more than 30GB of memory.
+
+**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export DATASET_NAME="lambdalabs/pokemon-blip-captions"
+
+accelerate launch --mixed_precision="fp16" train_text_to_image.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --dataset_name=$DATASET_NAME \
+ --use_ema \
+ --resolution=512 --center_crop --random_flip \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --max_train_steps=15000 \
+ --learning_rate=1e-05 \
+ --max_grad_norm=1 \
+ --lr_scheduler="constant" --lr_warmup_steps=0 \
+ --output_dir="sd-pokemon-model"
+```
+
+
+
+To run on your own training files, prepare the dataset according to the format required by `datasets`; you can find the instructions for how to do that in this [document](https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder-with-metadata). A minimal sketch of the expected layout is shown below.
+If you wish to use custom loading logic, you should modify the script; we have left pointers for that in the training script.
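+
+For reference, here is a minimal sketch of the expected `imagefolder` layout with a `metadata.jsonl` file and how 🤗 Datasets loads it. The file names below are hypothetical, and the caption column should match `--caption_column` (which defaults to `text`):
+
+```python
+from datasets import load_dataset
+
+# Hypothetical layout:
+#   path_to_your_dataset/
+#   ├── metadata.jsonl   # one JSON object per line, e.g. {"file_name": "0001.png", "text": "a caption"}
+#   ├── 0001.png
+#   └── 0002.png
+dataset = load_dataset("imagefolder", data_dir="path_to_your_dataset", split="train")
+print(dataset[0]["text"])  # the caption associated with the first image
+```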
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export TRAIN_DIR="path_to_your_dataset"
+
+accelerate launch --mixed_precision="fp16" train_text_to_image.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_data_dir=$TRAIN_DIR \
+ --use_ema \
+ --resolution=512 --center_crop --random_flip \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --max_train_steps=15000 \
+ --learning_rate=1e-05 \
+ --max_grad_norm=1 \
+ --lr_scheduler="constant" --lr_warmup_steps=0 \
+ --output_dir="sd-pokemon-model"
+```
+
+
+Once the training is finished, the model will be saved in the `output_dir` specified in the command. In this example it's `sd-pokemon-model`. To load the fine-tuned model for inference, just pass that path to `StableDiffusionPipeline`:
+
+
+```python
+from diffusers import StableDiffusionPipeline
+import torch
+
+model_path = "path_to_saved_model"
+pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
+pipe.to("cuda")
+
+image = pipe(prompt="yoda").images[0]
+image.save("yoda-pokemon.png")
+```
+
+Checkpoints only save the `unet`, so to run inference from a checkpoint, just load the `unet`:
+```python
+from diffusers import StableDiffusionPipeline, UNet2DConditionModel
+import torch
+
+model_path = "path_to_saved_model"
+
+# Replace <N> with the step of the checkpoint you want to load, e.g. "checkpoint-500".
+unet = UNet2DConditionModel.from_pretrained(model_path + "/checkpoint-<N>/unet", torch_dtype=torch.float16)
+
+# Pass the base model that the fine-tuning started from, e.g. "CompVis/stable-diffusion-v1-4".
+pipe = StableDiffusionPipeline.from_pretrained("<initial model>", unet=unet, torch_dtype=torch.float16)
+pipe.to("cuda")
+
+image = pipe(prompt="yoda").images[0]
+image.save("yoda-pokemon.png")
+```
+
+#### Training with multiple GPUs
+
+`accelerate` allows for seamless multi-GPU training. Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch)
+for running distributed training with `accelerate`. Here is an example command:
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export DATASET_NAME="lambdalabs/pokemon-blip-captions"
+
+accelerate launch --mixed_precision="fp16" --multi_gpu train_text_to_image.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --dataset_name=$DATASET_NAME \
+ --use_ema \
+ --resolution=512 --center_crop --random_flip \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --max_train_steps=15000 \
+ --learning_rate=1e-05 \
+ --max_grad_norm=1 \
+ --lr_scheduler="constant" --lr_warmup_steps=0 \
+ --output_dir="sd-pokemon-model"
+```
+
+
+#### Training with Min-SNR weighting
+
+We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556) which helps to achieve faster convergence
+by rebalancing the loss. In order to use it, one needs to set the `--snr_gamma` argument. The recommended
+value when using it is 5.0.
+
+You can find [this project on Weights and Biases](https://wandb.ai/sayakpaul/text2image-finetune-minsnr) that compares the loss surfaces of the following setups:
+
+* Training without the Min-SNR weighting strategy
+* Training with the Min-SNR weighting strategy (`snr_gamma` set to 5.0)
+* Training with the Min-SNR weighting strategy (`snr_gamma` set to 1.0)
+
+For our small Pokemon dataset, the effects of the Min-SNR weighting strategy might not appear pronounced, but we believe they will be more pronounced for larger datasets.
+
+Also, note that in this example we either predict `epsilon` (i.e., the noise) or use `v_prediction`. The formulation of the Min-SNR weighting strategy that we have used holds for both of these cases.
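+
+To make the weighting concrete, here is a minimal sketch (not the exact code from the training script) of how a Min-SNR weight could be applied to a per-sample MSE loss, assuming `snr` holds the signal-to-noise ratio of each sampled timestep:
+
+```python
+import torch
+import torch.nn.functional as F
+
+
+def min_snr_weighted_loss(model_pred, target, snr, snr_gamma=5.0, prediction_type="epsilon"):
+    # Clamp the per-timestep SNR at snr_gamma, then normalize depending on the prediction target.
+    weights = torch.minimum(snr, torch.full_like(snr, snr_gamma))
+    if prediction_type == "epsilon":
+        weights = weights / snr
+    elif prediction_type == "v_prediction":
+        weights = weights / (snr + 1)
+    # Per-sample MSE averaged over all non-batch dimensions, then reweighted and averaged over the batch.
+    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+    loss = loss.mean(dim=list(range(1, loss.ndim))) * weights
+    return loss.mean()
+```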
+
+## Training with LoRA
+
+Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
+
+In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:
+
+- The previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
+- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
+- LoRA attention layers allow controlling the extent to which the model is adapted toward new training images via a `scale` parameter.
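+
+To make the rank-decomposition idea concrete, here is a simplified, illustrative sketch (not the implementation used by the training script): a frozen linear layer is augmented with a trainable low-rank update `scale * B @ A`, and only the two small matrices receive gradients.
+
+```python
+import torch.nn as nn
+
+
+class LoRALinear(nn.Module):
+    """Wraps a frozen nn.Linear and adds a trainable low-rank update (illustrative only)."""
+
+    def __init__(self, base: nn.Linear, rank: int = 4, scale: float = 1.0):
+        super().__init__()
+        self.base = base
+        for p in self.base.parameters():
+            p.requires_grad_(False)  # keep the pretrained weights frozen
+        self.lora_down = nn.Linear(base.in_features, rank, bias=False)  # "A"
+        self.lora_up = nn.Linear(rank, base.out_features, bias=False)  # "B"
+        nn.init.zeros_(self.lora_up.weight)  # start as a no-op so training begins from the base model
+        self.scale = scale
+
+    def forward(self, x):
+        return self.base(x) + self.scale * self.lora_up(self.lora_down(x))
+```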
+
+[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.
+
+With LoRA, it's possible to fine-tune Stable Diffusion on a custom image-caption pair dataset
+on consumer GPUs like the Tesla T4 or Tesla V100.
+
+### Training
+
+First, you need to set up your development environment as is explained in the [installation section](#installing-the-dependencies). Make sure to set the `MODEL_NAME` and `DATASET_NAME` environment variables. Here, we will use [Stable Diffusion v1-4](https://hf.co/CompVis/stable-diffusion-v1-4) and the [Pokemons dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions).
+
+**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
+
+**___Note: It is quite useful to monitor the training progress by regularly generating sample images during training. [Weights and Biases](https://docs.wandb.ai/quickstart) is a nice solution to easily see the generated images during training. All you need to do is run `pip install wandb` before training to automatically log images.___**
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export DATASET_NAME="lambdalabs/pokemon-blip-captions"
+```
+
+For this example we want to directly store the trained LoRA embeddings on the Hub, so
+we need to be logged in and add the `--push_to_hub` flag.
+
+```bash
+huggingface-cli login
+```
+
+Now we can start training!
+
+```bash
+accelerate launch --mixed_precision="fp16" train_text_to_image_lora.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --dataset_name=$DATASET_NAME --caption_column="text" \
+ --resolution=512 --random_flip \
+ --train_batch_size=1 \
+ --num_train_epochs=100 --checkpointing_steps=5000 \
+ --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
+ --seed=42 \
+ --output_dir="sd-pokemon-model-lora" \
+ --validation_prompt="cute dragon creature" --report_to="wandb"
+```
+
+The above command will also run inference as fine-tuning progresses and log the results to Weights and Biases.
+
+**___Note: When using LoRA we can use a much higher learning rate compared to non-LoRA fine-tuning. Here we use *1e-4* instead of the usual *1e-5*. Also, by using LoRA, it's possible to run `train_text_to_image_lora.py` on consumer GPUs like the T4 or V100.___**
+
+The final LoRA embedding weights have been uploaded to [sayakpaul/sd-model-finetuned-lora-t4](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4). **___Note: [The final weights](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4/blob/main/pytorch_lora_weights.bin) are only 3 MB in size, which is orders of magnitude smaller than the original model.___**
+
+You can check some inference samples that were logged during the course of the fine-tuning process [here](https://wandb.ai/sayakpaul/text2image-fine-tune/runs/q4lc0xsw).
+
+### Inference
+
+Once you have trained a model using the above command, you can run inference simply with the `StableDiffusionPipeline` after loading the trained LoRA weights. You
+need to pass the `output_dir` for loading the LoRA weights, which, in this case, is `sd-pokemon-model-lora`.
+
+```python
+from diffusers import StableDiffusionPipeline
+import torch
+
+model_path = "sayakpaul/sd-model-finetuned-lora-t4"
+pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16)
+pipe.unet.load_attn_procs(model_path)
+pipe.to("cuda")
+
+prompt = "A pokemon with green eyes and red legs."
+image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
+image.save("pokemon.png")
+```
+
+If you are loading the LoRA parameters from the Hub and if the Hub repository has
+a `base_model` tag (such as [this](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4/blob/main/README.md?code=true#L4)), then
+you can do:
+
+```py
+from huggingface_hub.repocard import RepoCard
+
+lora_model_id = "sayakpaul/sd-model-finetuned-lora-t4"
+card = RepoCard.load(lora_model_id)
+base_model_id = card.data.to_dict()["base_model"]
+
+pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
+...
+```
+
+## Training with Flax/JAX
+
+For faster training on TPUs and GPUs, you can leverage the Flax training example. Follow the instructions above to get the model and dataset before running the script.
+
+**___Note: The Flax example doesn't yet support features like gradient checkpointing, gradient accumulation, etc., so to use Flax for faster training we will need >30GB cards or a TPU v3.___**
+
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+```bash
+pip install -U -r requirements_flax.txt
+```
+
+```bash
+export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
+export DATASET_NAME="lambdalabs/pokemon-blip-captions"
+
+python train_text_to_image_flax.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --dataset_name=$DATASET_NAME \
+ --resolution=512 --center_crop --random_flip \
+ --train_batch_size=1 \
+ --mixed_precision="fp16" \
+ --max_train_steps=15000 \
+ --learning_rate=1e-05 \
+ --max_grad_norm=1 \
+ --output_dir="sd-pokemon-model"
+```
+
+To run on your own training files, prepare the dataset according to the format required by `datasets`; you can find the instructions for how to do that in this [document](https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder-with-metadata).
+If you wish to use custom loading logic, you should modify the script; we have left pointers for that in the training script.
+
+```bash
+export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
+export TRAIN_DIR="path_to_your_dataset"
+
+python train_text_to_image_flax.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_data_dir=$TRAIN_DIR \
+ --resolution=512 --center_crop --random_flip \
+ --train_batch_size=1 \
+ --mixed_precision="fp16" \
+ --max_train_steps=15000 \
+ --learning_rate=1e-05 \
+ --max_grad_norm=1 \
+ --output_dir="sd-pokemon-model"
+```
+
+### Training with xFormers:
+
+You can enable memory efficient attention by [installing xFormers](https://huggingface.co/docs/diffusers/main/en/optimization/xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script.
+
+xFormers training is not available for Flax/JAX.
+
+**Note**:
+
+According to [this issue](https://github.com/huggingface/diffusers/issues/2234#issuecomment-1416931212), xFormers `v0.0.16` cannot be used for training in some GPUs. If you observe that problem, please install a development version as indicated in that comment.
+
+## Stable Diffusion XL
+
+* We support fine-tuning the UNet shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) via the `train_text_to_image_sdxl.py` script. Please refer to the docs [here](./README_sdxl.md).
+* We also support fine-tuning of the UNet and Text Encoder shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) with LoRA via the `train_text_to_image_lora_sdxl.py` script. Please refer to the docs [here](./README_sdxl.md).
diff --git a/diffusers/examples/text_to_image/README_sdxl.md b/diffusers/examples/text_to_image/README_sdxl.md
new file mode 100644
index 0000000000000000000000000000000000000000..75c9cb126472065e623a14724eed54673e918275
--- /dev/null
+++ b/diffusers/examples/text_to_image/README_sdxl.md
@@ -0,0 +1,225 @@
+# Stable Diffusion XL text-to-image fine-tuning
+
+The `train_text_to_image_sdxl.py` script shows how to fine-tune Stable Diffusion XL (SDXL) on your own dataset.
+
+🚨 This script is experimental. The script fine-tunes the whole model, and oftentimes the model overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparameters to get the best results on your dataset. 🚨
+
+## Running locally with PyTorch
+
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+Then cd into the `examples/text_to_image` folder and run
+```bash
+pip install -r requirements_sdxl.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or for a default accelerate configuration without answering questions about your environment
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell (e.g., a notebook)
+
+```python
+from accelerate.utils import write_basic_config
+write_basic_config()
+```
+
+When running `accelerate config`, if you set torch compile mode to True, there can be dramatic speedups.
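+
+As a rough illustration of what torch compile does (a sketch, not what `accelerate config` writes for you), you can also compile the UNet manually with PyTorch 2.0+; the first call is slow because the graph is captured and optimized, and later calls with the same input shapes reuse it:
+
+```python
+import torch
+from diffusers import DiffusionPipeline
+
+pipe = DiffusionPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+).to("cuda")
+# Compile only the UNet, which dominates the compute during denoising.
+pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+```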
+
+### Training
+
+```bash
+export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
+export VAE_NAME="madebyollin/sdxl-vae-fp16-fix"
+export DATASET_NAME="lambdalabs/pokemon-blip-captions"
+
+accelerate launch train_text_to_image_sdxl.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --pretrained_vae_model_name_or_path=$VAE_NAME \
+ --dataset_name=$DATASET_NAME \
+ --enable_xformers_memory_efficient_attention \
+ --resolution=512 --center_crop --random_flip \
+ --proportion_empty_prompts=0.2 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 --gradient_checkpointing \
+ --max_train_steps=10000 \
+ --use_8bit_adam \
+ --learning_rate=1e-06 --lr_scheduler="constant" --lr_warmup_steps=0 \
+ --mixed_precision="fp16" \
+ --report_to="wandb" \
+ --validation_prompt="a cute Sundar Pichai creature" --validation_epochs 5 \
+ --checkpointing_steps=5000 \
+ --output_dir="sdxl-pokemon-model" \
+ --push_to_hub
+```
+
+**Notes**:
+
+* The `train_text_to_image_sdxl.py` script pre-computes text embeddings and the VAE encodings and keeps them in memory. While this might not be a problem for smaller datasets like [`lambdalabs/pokemon-blip-captions`](https://hf.co/datasets/lambdalabs/pokemon-blip-captions), it can definitely lead to memory problems when the script is used on a larger dataset. In that case, you would want to serialize these pre-computed representations to disk separately and load them during the fine-tuning process. Refer to [this PR](https://github.com/huggingface/diffusers/pull/4505) for a more in-depth discussion.
+* The training script is compute-intensive and may not run on a consumer GPU like Tesla T4.
+* The training command shown above performs intermediate quality validation in between the training epochs and logs the results to Weights and Biases. `--report_to`, `--validation_prompt`, and `--validation_epochs` are the relevant CLI arguments here.
+* SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument, namely `--pretrained_vae_model_name_or_path`, that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
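+
+For reference, the same fp16-friendly VAE can also be swapped in at inference time; a minimal sketch:
+
+```python
+import torch
+from diffusers import AutoencoderKL, DiffusionPipeline
+
+vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
+pipe = DiffusionPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0", vae=vae, torch_dtype=torch.float16
+).to("cuda")
+```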
+
+### Inference
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+model_path = "you-model-id-goes-here" # <-- change this
+pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
+pipe.to("cuda")
+
+prompt = "A pokemon with green eyes and red legs."
+image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
+image.save("pokemon.png")
+```
+
+### Inference in PyTorch XLA
+```python
+from diffusers import DiffusionPipeline
+import torch
+import torch_xla.core.xla_model as xm
+from time import time
+
+model_id = "stabilityai/stable-diffusion-xl-base-1.0"
+pipe = DiffusionPipeline.from_pretrained(model_id)
+
+device = xm.xla_device()
+pipe.to(device)
+
+prompt = "A pokemon with green eyes and red legs."
+start = time()
+image = pipe(prompt, num_inference_steps=inference_steps).images[0]
+print(f'Compilation time is {time()-start} sec')
+image.save("pokemon.png")
+
+start = time()
+image = pipe(prompt, num_inference_steps=inference_steps).images[0]
+print(f'Inference time is {time()-start} sec after compilation')
+```
+
+Note: There is a warmup step in PyTorch XLA. This takes longer because of
+compilation and optimization. To see the real benefits of PyTorch XLA and
+the speedup, we need to call the pipe again with an input of the same length
+as the original prompt, so that the optimized graph is reused and we get the
+performance boost.
+
+## LoRA training example for Stable Diffusion XL (SDXL)
+
+Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
+
+In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:
+
+- The previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
+- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
+- LoRA attention layers allow controlling the extent to which the model is adapted toward new training images via a `scale` parameter.
+
+[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.
+
+With LoRA, it's possible to fine-tune Stable Diffusion on a custom image-caption pair dataset
+on consumer GPUs like the Tesla T4 or Tesla V100.
+
+### Training
+
+First, you need to set up your development environment as is explained in the [installation section](#installing-the-dependencies). Make sure to set the `MODEL_NAME` and `DATASET_NAME` environment variables and, optionally, the `VAE_NAME` variable. Here, we will use [Stable Diffusion XL 1.0-base](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and the [Pokemons dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions).
+
+**___Note: It is quite useful to monitor the training progress by regularly generating sample images during training. [Weights and Biases](https://docs.wandb.ai/quickstart) is a nice solution to easily see the generated images during training. All you need to do is run `pip install wandb` before training to automatically log images.___**
+
+```bash
+export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
+export VAE_NAME="madebyollin/sdxl-vae-fp16-fix"
+export DATASET_NAME="lambdalabs/pokemon-blip-captions"
+```
+
+For this example we want to directly store the trained LoRA embeddings on the Hub, so
+we need to be logged in and add the `--push_to_hub` flag.
+
+```bash
+huggingface-cli login
+```
+
+Now we can start training!
+
+```bash
+accelerate launch train_text_to_image_lora_sdxl.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --pretrained_vae_model_name_or_path=$VAE_NAME \
+ --dataset_name=$DATASET_NAME --caption_column="text" \
+ --resolution=1024 --random_flip \
+ --train_batch_size=1 \
+ --num_train_epochs=2 --checkpointing_steps=500 \
+ --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
+ --mixed_precision="fp16" \
+ --seed=42 \
+ --output_dir="sd-pokemon-model-lora-sdxl" \
+ --validation_prompt="cute dragon creature" --report_to="wandb" \
+ --push_to_hub
+```
+
+The above command will also run inference as fine-tuning progresses and log the results to Weights and Biases.
+
+**Notes**:
+
+* SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument, namely `--pretrained_vae_model_name_or_path`, that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
+
+### Finetuning the text encoder and UNet
+
+The script also allows you to finetune the `text_encoder` along with the `unet`.
+
+🚨 Training the text encoder requires additional memory.
+
+Pass the `--train_text_encoder` argument to the training script to enable finetuning the `text_encoder` and `unet`:
+
+```bash
+accelerate launch train_text_to_image_lora_sdxl.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --dataset_name=$DATASET_NAME --caption_column="text" \
+ --resolution=1024 --random_flip \
+ --train_batch_size=1 \
+ --num_train_epochs=2 --checkpointing_steps=500 \
+ --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
+ --seed=42 \
+ --output_dir="sd-pokemon-model-lora-sdxl-txt" \
+ --train_text_encoder \
+ --validation_prompt="cute dragon creature" --report_to="wandb" \
+ --push_to_hub
+```
+
+### Inference
+
+Once you have trained a model using the above command, you can run inference simply with the `DiffusionPipeline` after loading the trained LoRA weights. You
+need to pass the `output_dir` for loading the LoRA weights, which, in this case, is `sd-pokemon-model-lora-sdxl`.
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+model_path = "takuoko/sd-pokemon-model-lora-sdxl"
+pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
+pipe.to("cuda")
+pipe.load_lora_weights(model_path)
+
+prompt = "A pokemon with green eyes and red legs."
+image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
+image.save("pokemon.png")
+```
diff --git a/diffusers/examples/text_to_image/requirements.txt b/diffusers/examples/text_to_image/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..31b9026efdc2799b1d02e2e3f4d8dfc463737fdc
--- /dev/null
+++ b/diffusers/examples/text_to_image/requirements.txt
@@ -0,0 +1,7 @@
+accelerate>=0.16.0
+torchvision
+transformers>=4.25.1
+datasets
+ftfy
+tensorboard
+Jinja2
diff --git a/diffusers/examples/text_to_image/requirements_flax.txt b/diffusers/examples/text_to_image/requirements_flax.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b6eb64e254625ee8eff2ef126d67adfd5b6994dc
--- /dev/null
+++ b/diffusers/examples/text_to_image/requirements_flax.txt
@@ -0,0 +1,9 @@
+transformers>=4.25.1
+datasets
+flax
+optax
+torch
+torchvision
+ftfy
+tensorboard
+Jinja2
diff --git a/diffusers/examples/text_to_image/requirements_sdxl.txt b/diffusers/examples/text_to_image/requirements_sdxl.txt
new file mode 100644
index 0000000000000000000000000000000000000000..cdd3336e3617dc8b13d46b1a3b5529bf32abda44
--- /dev/null
+++ b/diffusers/examples/text_to_image/requirements_sdxl.txt
@@ -0,0 +1,7 @@
+accelerate>=0.22.0
+torchvision
+transformers>=4.25.1
+ftfy
+tensorboard
+Jinja2
+datasets
diff --git a/diffusers/examples/text_to_image/train_text_to_image.py b/diffusers/examples/text_to_image/train_text_to_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..628a0c9d7d96a0f4855f02c6278f79d4d6baf35c
--- /dev/null
+++ b/diffusers/examples/text_to_image/train_text_to_image.py
@@ -0,0 +1,1066 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+
+import accelerate
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.state import AcceleratorState
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+from transformers.utils import ContextManagers
+
+import diffusers
+from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel, compute_snr
+from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid
+from diffusers.utils.import_utils import is_xformers_available
+
+
+if is_wandb_available():
+ import wandb
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.24.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+DATASET_NAME_MAPPING = {
+ "lambdalabs/pokemon-blip-captions": ("image", "text"),
+}
+
+
+def save_model_card(
+ args,
+ repo_id: str,
+ images=None,
+ repo_folder=None,
+):
+ img_str = ""
+ if len(images) > 0:
+ image_grid = make_image_grid(images, 1, len(args.validation_prompts))
+ image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png"))
+ img_str += "![val_imgs_grid](./val_imgs_grid.png)\n"
+
+ yaml = f"""
+---
+license: creativeml-openrail-m
+base_model: {args.pretrained_model_name_or_path}
+datasets:
+- {args.dataset_name}
+tags:
+- stable-diffusion
+- stable-diffusion-diffusers
+- text-to-image
+- diffusers
+inference: true
+---
+ """
+ model_card = f"""
+# Text-to-image finetuning - {repo_id}
+
+This pipeline was finetuned from **{args.pretrained_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \n
+{img_str}
+
+## Pipeline usage
+
+You can use the pipeline like so:
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+pipeline = DiffusionPipeline.from_pretrained("{repo_id}", torch_dtype=torch.float16)
+prompt = "{args.validation_prompts[0]}"
+image = pipeline(prompt).images[0]
+image.save("my_image.png")
+```
+
+## Training info
+
+These are the key hyperparameters used during training:
+
+* Epochs: {args.num_train_epochs}
+* Learning rate: {args.learning_rate}
+* Batch size: {args.train_batch_size}
+* Gradient accumulation steps: {args.gradient_accumulation_steps}
+* Image resolution: {args.resolution}
+* Mixed-precision: {args.mixed_precision}
+
+"""
+ wandb_info = ""
+ if is_wandb_available():
+ wandb_run_url = None
+ if wandb.run is not None:
+ wandb_run_url = wandb.run.url
+
+ if wandb_run_url is not None:
+ wandb_info = f"""
+More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}).
+"""
+
+ model_card += wandb_info
+
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
+ f.write(yaml + model_card)
+
+
+def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch):
+ logger.info("Running validation... ")
+
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=accelerator.unwrap_model(vae),
+ text_encoder=accelerator.unwrap_model(text_encoder),
+ tokenizer=tokenizer,
+ unet=accelerator.unwrap_model(unet),
+ safety_checker=None,
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ images = []
+ for i in range(len(args.validation_prompts)):
+ with torch.autocast("cuda"):
+ image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
+
+ images.append(image)
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+ else:
+ logger.warn(f"image logging not implemented for {tracker.name}")
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ return images
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1."
+ )
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--validation_prompts",
+ type=str,
+ default=None,
+ nargs="+",
+ help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="sd-model-finetuned",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+ parser.add_argument(
+ "--non_ema_revision",
+ type=str,
+ default=None,
+ required=False,
+ help=(
+ "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
+ " remote repository specified with --pretrained_model_name_or_path."
+ ),
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--prediction_type",
+ type=str,
+ default=None,
+ help="The prediction_type to use for training. Choose between 'epsilon' and 'v_prediction', or leave `None`. If left as `None`, the scheduler's default prediction type (`noise_scheduler.config.prediction_type`) is used.",
+ )
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=5,
+ help="Run validation every X epochs.",
+ )
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="text2image-fine-tune",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers. For"
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ # default to using the same revision for the non-ema model if not specified
+ if args.non_ema_revision is None:
+ args.non_ema_revision = args.revision
+
+ return args
+
+
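+# A minimal, illustrative launch command for this script (the model and dataset names
+# below are placeholders; substitute your own, and adjust the script path as needed):
+#
+#   accelerate launch --mixed_precision=fp16 train_text_to_image.py \
+#     --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
+#     --dataset_name="lambdalabs/pokemon-blip-captions" \
+#     --resolution=512 --center_crop --random_flip \
+#     --train_batch_size=1 --gradient_accumulation_steps=4 \
+#     --max_train_steps=15000 --learning_rate=1e-05 \
+#     --lr_scheduler="constant" --lr_warmup_steps=0 \
+#     --output_dir="sd-model-finetuned"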
+def main():
+ args = parse_args()
+
+ if args.non_ema_revision is not None:
+ deprecate(
+ "non_ema_revision!=None",
+ "0.15.0",
+ message=(
+ "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
+ " use `--variant=non_ema` instead."
+ ),
+ )
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load scheduler, tokenizer and models.
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ tokenizer = CLIPTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
+ )
+
+ def deepspeed_zero_init_disabled_context_manager():
+ """
+ Returns either a list containing a context manager that disables zero.Init, or an empty list.
+ """
+ deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
+ if deepspeed_plugin is None:
+ return []
+
+ return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
+
+ # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3.
+ # For this to work properly all models must be run through `accelerate.prepare`. But accelerate
+ # will try to assign the same optimizer with the same weights to all models during
+ # `deepspeed.initialize`, which of course doesn't work.
+ #
+ # For now the following workaround will partially support Deepspeed ZeRO-3 by excluding the two
+ # frozen models from being partitioned during `zero.Init`, which gets called during
+ # `from_pretrained`. So CLIPTextModel and AutoencoderKL will not enjoy parameter sharding
+ # across multiple GPUs and only UNet2DConditionModel will get ZeRO sharded.
+ with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
+ )
+
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
+ )
+
+ # Freeze vae and text_encoder and set unet to trainable
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+ unet.train()
+
+ # Create EMA for the unet.
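+ # (The EMA copy tracks an exponential moving average of the UNet weights; it is not updated
+ # by the optimizer directly and is typically the version exported for inference.)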
+ if args.use_ema:
+ ema_unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+ )
+ ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warn(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ if args.use_ema:
+ ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
+
+ for i, model in enumerate(models):
+ model.save_pretrained(os.path.join(output_dir, "unet"))
+
+ # make sure to pop the weights so that the corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ if args.use_ema:
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
+ ema_unet.load_state_dict(load_model.state_dict())
+ ema_unet.to(accelerator.device)
+ del load_model
+
+ for i in range(len(models)):
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load the diffusers-format weights into the model
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
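+ # (Linear scaling rule: the base learning rate is multiplied by the effective global batch
+ # size, i.e. gradient accumulation steps * per-device batch size * number of processes.)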
+
+ # Initialize the optimizer
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
+ )
+
+ optimizer_cls = bnb.optim.AdamW8bit
+ else:
+ optimizer_cls = torch.optim.AdamW
+
+ optimizer = optimizer_cls(
+ unet.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ data_dir=args.train_data_dir,
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
+ if args.image_column is None:
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"--image_column value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.caption_column is None:
+ caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"--caption_column value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
+ )
+
+ # Preprocessing the datasets.
+ # We need to tokenize input captions and transform the images.
+ def tokenize_captions(examples, is_train=True):
+ captions = []
+ for caption in examples[caption_column]:
+ if isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+ else:
+ raise ValueError(
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
+ )
+ inputs = tokenizer(
+ captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ return inputs.input_ids
+
+ # Preprocessing the datasets.
+ train_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
+ transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ examples["pixel_values"] = [train_transforms(image) for image in images]
+ examples["input_ids"] = tokenize_captions(examples)
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ input_ids = torch.stack([example["input_ids"] for example in examples])
+ return {"pixel_values": pixel_values, "input_ids": input_ids}
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ )
+
+ # Prepare everything with our `accelerator`.
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ if args.use_ema:
+ ema_unet.to(accelerator.device)
+
+ # For mixed precision training we cast the non-trainable weights (vae and text_encoder) to half-precision,
+ # as these weights are only used for inference and keeping them in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ args.mixed_precision = accelerator.mixed_precision
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+ args.mixed_precision = accelerator.mixed_precision
+
+ # Move text_encoder and vae to the GPU and cast to weight_dtype
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers are initialized automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+ tracker_config.pop("validation_prompts")
+ accelerator.init_trackers(args.tracker_project_name, tracker_config)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ # Convert images to latent space
+ latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
+ latents = latents * vae.config.scaling_factor
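+ # (vae.config.scaling_factor, 0.18215 for SD v1.x VAEs, rescales the latents to roughly unit variance.)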
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ if args.noise_offset:
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
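+ # (Offset noise adds a per-channel constant shift to the sampled noise, which helps the
+ # model learn to generate very dark or very bright images.)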
+ noise += args.noise_offset * torch.randn(
+ (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
+ )
+ if args.input_perturbation:
+ new_noise = noise + args.input_perturbation * torch.randn_like(noise)
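+ # (Input perturbation adds a small amount of extra noise on top of the sampled noise before
+ # the forward diffusion step, which has been reported to speed up convergence.)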
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ if args.input_perturbation:
+ noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps)
+ else:
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+ # Get the target for loss depending on the prediction type
+ if args.prediction_type is not None:
+ # set prediction_type of scheduler if defined
+ noise_scheduler.register_to_config(prediction_type=args.prediction_type)
+
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ # Predict the noise residual and compute loss
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+ if args.snr_gamma is None:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+ # This is discussed in Section 4.2 of the same paper.
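+ # The resulting per-sample weight is min(SNR(t), snr_gamma) / SNR(t) (with SNR(t) replaced by
+ # SNR(t) + 1 for v-prediction), which caps the contribution of low-noise, high-SNR timesteps.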
+ snr = compute_snr(noise_scheduler, timesteps)
+ if noise_scheduler.config.prediction_type == "v_prediction":
+ # Velocity objective requires that we add one to SNR values before we divide by them.
+ snr = snr + 1
+ mse_loss_weights = (
+ torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
+ )
+
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ if args.use_ema:
+ ema_unet.step(unet.parameters())
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
+ if args.use_ema:
+ # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
+ ema_unet.store(unet.parameters())
+ ema_unet.copy_to(unet.parameters())
+ log_validation(
+ vae,
+ text_encoder,
+ tokenizer,
+ unet,
+ args,
+ accelerator,
+ weight_dtype,
+ global_step,
+ )
+ if args.use_ema:
+ # Switch back to the original UNet parameters.
+ ema_unet.restore(unet.parameters())
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = accelerator.unwrap_model(unet)
+ if args.use_ema:
+ ema_unet.copy_to(unet.parameters())
+
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ text_encoder=text_encoder,
+ vae=vae,
+ unet=unet,
+ revision=args.revision,
+ )
+ pipeline.save_pretrained(args.output_dir)
+
+ # Run a final round of inference.
+ images = []
+ if args.validation_prompts is not None:
+ logger.info("Running inference for collecting generated images...")
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.torch_dtype = weight_dtype
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ for i in range(len(args.validation_prompts)):
+ with torch.autocast("cuda"):
+ image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
+ images.append(image)
+
+ if args.push_to_hub:
+ save_model_card(args, repo_id, images, repo_folder=args.output_dir)
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/text_to_image/train_text_to_image_flax.py b/diffusers/examples/text_to_image/train_text_to_image_flax.py
new file mode 100644
index 0000000000000000000000000000000000000000..e62d03c730b197c350798087e2667197d72e1713
--- /dev/null
+++ b/diffusers/examples/text_to_image/train_text_to_image_flax.py
@@ -0,0 +1,592 @@
+import argparse
+import logging
+import math
+import os
+import random
+from pathlib import Path
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import optax
+import torch
+import torch.utils.checkpoint
+import transformers
+from datasets import load_dataset
+from flax import jax_utils
+from flax.training import train_state
+from flax.training.common_utils import shard
+from huggingface_hub import create_repo, upload_folder
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
+
+from diffusers import (
+ FlaxAutoencoderKL,
+ FlaxDDPMScheduler,
+ FlaxPNDMScheduler,
+ FlaxStableDiffusionPipeline,
+ FlaxUNet2DConditionModel,
+)
+from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker
+from diffusers.utils import check_min_version
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.24.0.dev0")
+
+logger = logging.getLogger(__name__)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="sd-model-finetuned",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=0, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images. All images in the train/validation dataset will be resized to this"
+ " resolution."
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="Whether to randomly flip images horizontally.",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default="no",
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose"
+ " between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10"
+ " and an Nvidia Ampere GPU."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--from_pt",
+ action="store_true",
+ default=False,
+ help="Flag to indicate whether to convert models from PyTorch.",
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ return args
+
+
+dataset_name_mapping = {
+ "lambdalabs/pokemon-blip-captions": ("image", "text"),
+}
+
+
+def get_params_to_save(params):
+ return jax.device_get(jax.tree_util.tree_map(lambda x: x[0], params))
+
+
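+# A minimal, illustrative launch command (model/dataset names are placeholders; substitute
+# your own):
+#
+#   python train_text_to_image_flax.py \
+#     --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
+#     --dataset_name="lambdalabs/pokemon-blip-captions" \
+#     --resolution=512 --center_crop --random_flip \
+#     --train_batch_size=1 --max_train_steps=15000 \
+#     --learning_rate=1e-05 --output_dir="sd-model-finetuned"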
+def main():
+ args = parse_args()
+
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ # Set up logging; we only want one process per machine to log things to the screen.
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
+ if jax.process_index() == 0:
+ transformers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if jax.process_index() == 0:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, data_dir=args.train_data_dir
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ dataset_columns = dataset_name_mapping.get(args.dataset_name, None)
+ if args.image_column is None:
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"--image_column value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.caption_column is None:
+ caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"--caption_column value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
+ )
+
+ # Preprocessing the datasets.
+ # We need to tokenize input captions and transform the images.
+ def tokenize_captions(examples, is_train=True):
+ captions = []
+ for caption in examples[caption_column]:
+ if isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+ else:
+ raise ValueError(
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
+ )
+ inputs = tokenizer(captions, max_length=tokenizer.model_max_length, padding="do_not_pad", truncation=True)
+ input_ids = inputs.input_ids
+ return input_ids
+
+ train_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
+ transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ examples["pixel_values"] = [train_transforms(image) for image in images]
+ examples["input_ids"] = tokenize_captions(examples)
+
+ return examples
+
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ input_ids = [example["input_ids"] for example in examples]
+
+ padded_tokens = tokenizer.pad(
+ {"input_ids": input_ids}, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
+ )
+ batch = {
+ "pixel_values": pixel_values,
+ "input_ids": padded_tokens.input_ids,
+ }
+ batch = {k: v.numpy() for k, v in batch.items()}
+
+ return batch
+
+ total_train_batch_size = args.train_batch_size * jax.local_device_count()
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=total_train_batch_size, drop_last=True
+ )
+
+ weight_dtype = jnp.float32
+ if args.mixed_precision == "fp16":
+ weight_dtype = jnp.float16
+ elif args.mixed_precision == "bf16":
+ weight_dtype = jnp.bfloat16
+
+ # Load models and create wrapper for stable diffusion
+ tokenizer = CLIPTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ from_pt=args.from_pt,
+ revision=args.revision,
+ subfolder="tokenizer",
+ )
+ text_encoder = FlaxCLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ from_pt=args.from_pt,
+ revision=args.revision,
+ subfolder="text_encoder",
+ dtype=weight_dtype,
+ )
+ vae, vae_params = FlaxAutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path,
+ from_pt=args.from_pt,
+ revision=args.revision,
+ subfolder="vae",
+ dtype=weight_dtype,
+ )
+ unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ from_pt=args.from_pt,
+ revision=args.revision,
+ subfolder="unet",
+ dtype=weight_dtype,
+ )
+
+ # Optimization
+ if args.scale_lr:
+ args.learning_rate = args.learning_rate * total_train_batch_size
+
+ constant_scheduler = optax.constant_schedule(args.learning_rate)
+
+ adamw = optax.adamw(
+ learning_rate=constant_scheduler,
+ b1=args.adam_beta1,
+ b2=args.adam_beta2,
+ eps=args.adam_epsilon,
+ weight_decay=args.adam_weight_decay,
+ )
+
+ optimizer = optax.chain(
+ optax.clip_by_global_norm(args.max_grad_norm),
+ adamw,
+ )
+
+ state = train_state.TrainState.create(apply_fn=unet.__call__, params=unet_params, tx=optimizer)
+
+ noise_scheduler = FlaxDDPMScheduler(
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
+ )
+ noise_scheduler_state = noise_scheduler.create_state()
+
+ # Initialize our training
+ rng = jax.random.PRNGKey(args.seed)
+ train_rngs = jax.random.split(rng, jax.local_device_count())
+
+ def train_step(state, text_encoder_params, vae_params, batch, train_rng):
+ dropout_rng, sample_rng, new_train_rng = jax.random.split(train_rng, 3)
+
+ def compute_loss(params):
+ # Convert images to latent space
+ vae_outputs = vae.apply(
+ {"params": vae_params}, batch["pixel_values"], deterministic=True, method=vae.encode
+ )
+ latents = vae_outputs.latent_dist.sample(sample_rng)
+ # (NHWC) -> (NCHW)
+ latents = jnp.transpose(latents, (0, 3, 1, 2))
+ latents = latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise_rng, timestep_rng = jax.random.split(sample_rng)
+ noise = jax.random.normal(noise_rng, latents.shape)
+ # Sample a random timestep for each image
+ bsz = latents.shape[0]
+ timesteps = jax.random.randint(
+ timestep_rng,
+ (bsz,),
+ 0,
+ noise_scheduler.config.num_train_timesteps,
+ )
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(
+ batch["input_ids"],
+ params=text_encoder_params,
+ train=False,
+ )[0]
+
+ # Predict the noise residual and compute loss
+ model_pred = unet.apply(
+ {"params": params}, noisy_latents, timesteps, encoder_hidden_states, train=True
+ ).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ loss = (target - model_pred) ** 2
+ loss = loss.mean()
+
+ return loss
+
+ grad_fn = jax.value_and_grad(compute_loss)
+ loss, grad = grad_fn(state.params)
+ grad = jax.lax.pmean(grad, "batch")
+
+ new_state = state.apply_gradients(grads=grad)
+
+ metrics = {"loss": loss}
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
+
+ return new_state, metrics, new_train_rng
+
+ # Create parallel version of the train step
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
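+ # (jax.pmap compiles train_step once and replicates it across all local devices; gradients and
+ # metrics are averaged across devices inside train_step via jax.lax.pmean over the "batch" axis.)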
+
+ # Replicate the train state on each device
+ state = jax_utils.replicate(state)
+ text_encoder_params = jax_utils.replicate(text_encoder.params)
+ vae_params = jax_utils.replicate(vae_params)
+
+ # Train!
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader))
+
+ # Scheduler and math around the number of training steps.
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+
+ global_step = 0
+
+ epochs = tqdm(range(args.num_train_epochs), desc="Epoch ... ", position=0)
+ for epoch in epochs:
+ # ======================== Training ================================
+
+ train_metrics = []
+
+ steps_per_epoch = len(train_dataset) // total_train_batch_size
+ train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False)
+ # train
+ for batch in train_dataloader:
+ batch = shard(batch)
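+ # (shard() reshapes the leading batch dimension to (local_device_count, per_device_batch, ...)
+ # so that each device receives its own slice.)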
+ state, train_metric, train_rngs = p_train_step(state, text_encoder_params, vae_params, batch, train_rngs)
+ train_metrics.append(train_metric)
+
+ train_step_progress_bar.update(1)
+
+ global_step += 1
+ if global_step >= args.max_train_steps:
+ break
+
+ train_metric = jax_utils.unreplicate(train_metric)
+
+ train_step_progress_bar.close()
+ epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})")
+
+ # Create the pipeline using the trained modules and save it.
+ if jax.process_index() == 0:
+ scheduler = FlaxPNDMScheduler(
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
+ )
+ safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained(
+ "CompVis/stable-diffusion-safety-checker", from_pt=True
+ )
+ pipeline = FlaxStableDiffusionPipeline(
+ text_encoder=text_encoder,
+ vae=vae,
+ unet=unet,
+ tokenizer=tokenizer,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32"),
+ )
+
+ pipeline.save_pretrained(
+ args.output_dir,
+ params={
+ "text_encoder": get_params_to_save(text_encoder_params),
+ "vae": get_params_to_save(vae_params),
+ "unet": get_params_to_save(state.params),
+ "safety_checker": safety_checker.params,
+ },
+ )
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/text_to_image/train_text_to_image_lora.py b/diffusers/examples/text_to_image/train_text_to_image_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7309196dec85a81550540d5c3292026961cf6a0
--- /dev/null
+++ b/diffusers/examples/text_to_image/train_text_to_image_lora.py
@@ -0,0 +1,975 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fine-tuning script for Stable Diffusion for text2image with support for LoRA."""
+
+import argparse
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+
+import diffusers
+from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
+from diffusers.models.lora import LoRALinearLayer
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import compute_snr
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.24.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+
+# TODO: This function should be removed once training scripts are rewritten in PEFT
+def text_encoder_lora_state_dict(text_encoder):
+ state_dict = {}
+
+ def text_encoder_attn_modules(text_encoder):
+ from transformers import CLIPTextModel, CLIPTextModelWithProjection
+
+ attn_modules = []
+
+ if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
+ for i, layer in enumerate(text_encoder.text_model.encoder.layers):
+ name = f"text_model.encoder.layers.{i}.self_attn"
+ mod = layer.self_attn
+ attn_modules.append((name, mod))
+
+ return attn_modules
+
+ for name, module in text_encoder_attn_modules(text_encoder):
+ for k, v in module.q_proj.lora_linear_layer.state_dict().items():
+ state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
+
+ for k, v in module.k_proj.lora_linear_layer.state_dict().items():
+ state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
+
+ for k, v in module.v_proj.lora_linear_layer.state_dict().items():
+ state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
+
+ for k, v in module.out_proj.lora_linear_layer.state_dict().items():
+ state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
+
+ return state_dict
+
+
+def save_model_card(repo_id: str, images=None, base_model: str = None, dataset_name: str = None, repo_folder=None):
+ img_str = ""
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ img_str += f"![img_{i}](./image_{i}.png)\n"
+
+ yaml = f"""
+---
+license: creativeml-openrail-m
+base_model: {base_model}
+tags:
+- stable-diffusion
+- stable-diffusion-diffusers
+- text-to-image
+- diffusers
+- lora
+inference: true
+---
+ """
+ model_card = f"""
+# LoRA text2image fine-tuning - {repo_id}
+These are LoRA adaptation weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images below. \n
+{img_str}
+"""
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
+ f.write(yaml + model_card)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference."
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=1,
+ help=(
+ "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
+ " `args.validation_prompt` `args.num_validation_images` times."
+ ),
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="sd-model-finetuned-lora",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images. All images in the train/validation dataset will be resized to this"
+ " resolution."
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="Whether to randomly flip images horizontally.",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of update steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--prediction_type",
+ type=str,
+ default=None,
+ help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.",
+ )
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ return args
+
+
+DATASET_NAME_MAPPING = {
+ "lambdalabs/pokemon-blip-captions": ("image", "text"),
+}
+
+
+def main():
+ args = parse_args()
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+ import wandb
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+ # Load scheduler, tokenizer and models.
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ tokenizer = CLIPTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
+ )
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+ )
+ # freeze parameters of models to save more memory
+ unet.requires_grad_(False)
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ unet.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ # now we will add new LoRA weights to the attention layers
+ # It's important to realize here how many attention weights will be added and of which sizes
+ # The sizes of the attention layers consist only of two different variables:
+ # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
+ # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.
+
+ # Let's first see how many attention processors we will have to set.
+ # For Stable Diffusion, it should be equal to:
+ # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
+ # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
+ # - up blocks (2x attention layers) * (3x transformer layers) * (3x up blocks) = 18
+ # => 32 layers
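+ # As a quick sanity check (assuming the standard Stable Diffusion v1.x UNet),
+ # `len(unet.attn_processors)` should report 32 entries, matching the count above.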
+
+ # Set correct lora layers
+ unet_lora_parameters = []
+ for attn_processor_name, attn_processor in unet.attn_processors.items():
+ # Parse the attention module.
+ attn_module = unet
+ for n in attn_processor_name.split(".")[:-1]:
+ attn_module = getattr(attn_module, n)
+
+ # Set the `lora_layer` attribute of the attention-related matrices.
+ attn_module.to_q.set_lora_layer(
+ LoRALinearLayer(
+ in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
+ )
+ )
+ attn_module.to_k.set_lora_layer(
+ LoRALinearLayer(
+ in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
+ )
+ )
+
+ attn_module.to_v.set_lora_layer(
+ LoRALinearLayer(
+ in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
+ )
+ )
+ attn_module.to_out[0].set_lora_layer(
+ LoRALinearLayer(
+ in_features=attn_module.to_out[0].in_features,
+ out_features=attn_module.to_out[0].out_features,
+ rank=args.rank,
+ )
+ )
+
+ # Accumulate the LoRA params to optimize.
+ unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
+ unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
+ unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
+ unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warn(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
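+ # i.e. the base learning rate is scaled linearly with the effective global batch size
+ # (per-device batch size * gradient accumulation steps * number of processes).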
+
+ # Initialize the optimizer
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
+ )
+
+ optimizer_cls = bnb.optim.AdamW8bit
+ else:
+ optimizer_cls = torch.optim.AdamW
+
+ optimizer = optimizer_cls(
+ unet_lora_parameters,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ data_dir=args.train_data_dir,
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
+ if args.image_column is None:
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.caption_column is None:
+ caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
+ )
+
+ # Preprocessing the datasets.
+ # We need to tokenize input captions and transform the images.
+ def tokenize_captions(examples, is_train=True):
+ captions = []
+ for caption in examples[caption_column]:
+ if isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+ else:
+ raise ValueError(
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
+ )
+ inputs = tokenizer(
+ captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ return inputs.input_ids
+
+ # Preprocessing the datasets.
+ train_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
+ transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
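+ # Note: Normalize([0.5], [0.5]) maps pixel values from [0, 1] to [-1, 1], the range the VAE expects.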
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ examples["pixel_values"] = [train_transforms(image) for image in images]
+ examples["input_ids"] = tokenize_captions(examples)
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ input_ids = torch.stack([example["input_ids"] for example in examples])
+ return {"pixel_values": pixel_values, "input_ids": input_ids}
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ )
+
+ # Prepare everything with our `accelerator`.
+ unet_lora_parameters, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet_lora_parameters, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("text2image-fine-tune", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ unet.train()
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ # Convert images to latent space
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
+ latents = latents * vae.config.scaling_factor
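+ # The scaling factor (0.18215 for the SD v1.x VAE) rescales the latents to roughly unit variance.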
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ if args.noise_offset:
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
+ noise += args.noise_offset * torch.randn(
+ (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
+ )
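+ # Offset noise adds a per-(sample, channel) constant on top of the Gaussian noise,
+ # which helps the model learn images that are much darker or brighter than average.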
+
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
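+ # For DDPM this computes: noisy_latents = sqrt(alpha_prod_t) * latents + sqrt(1 - alpha_prod_t) * noise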
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+ # Get the target for loss depending on the prediction type
+ if args.prediction_type is not None:
+ # set prediction_type of scheduler if defined
+ noise_scheduler.register_to_config(prediction_type=args.prediction_type)
+
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ # Predict the noise residual and compute loss
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+ if args.snr_gamma is None:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+ # This is discussed in Section 4.2 of the same paper.
+ snr = compute_snr(noise_scheduler, timesteps)
+ if noise_scheduler.config.prediction_type == "v_prediction":
+ # Velocity objective requires that we add one to SNR values before we divide by them.
+ snr = snr + 1
+ mse_loss_weights = (
+ torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
+ )
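+ # Equivalent to min(snr, snr_gamma) / snr, so high-SNR (low-noise) timesteps are down-weighted.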
+
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = unet_lora_parameters
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ # create pipeline
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ unet=accelerator.unwrap_model(unet),
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device)
+ if args.seed is not None:
+ generator = generator.manual_seed(args.seed)
+ images = []
+ for _ in range(args.num_validation_images):
+ images.append(
+ pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
+ )
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = unet.to(torch.float32)
+ unet.save_attn_procs(args.output_dir)
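+ # This writes the LoRA attention weights to `output_dir`; they are reloaded below with
+ # `pipeline.unet.load_attn_procs` for the final inference.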
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ dataset_name=args.dataset_name,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ # Final inference
+ # Load previous pipeline
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
+ )
+ pipeline = pipeline.to(accelerator.device)
+
+ # load attention processors
+ pipeline.unet.load_attn_procs(args.output_dir)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device)
+ if args.seed is not None:
+ generator = generator.manual_seed(args.seed)
+ images = []
+ for _ in range(args.num_validation_images):
+ images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])
+
+ if accelerator.is_main_process:
+ for tracker in accelerator.trackers:
+ if len(images) != 0:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "test": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/text_to_image/train_text_to_image_lora_sdxl.py b/diffusers/examples/text_to_image/train_text_to_image_lora_sdxl.py
new file mode 100644
index 0000000000000000000000000000000000000000..96bfe9e16783e417f95bd91e93421b48ce650e49
--- /dev/null
+++ b/diffusers/examples/text_to_image/train_text_to_image_lora_sdxl.py
@@ -0,0 +1,1296 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fine-tuning script for Stable Diffusion XL for text2image with support for LoRA."""
+
+import argparse
+import itertools
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+from typing import Dict
+
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from torchvision import transforms
+from torchvision.transforms.functional import crop
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ StableDiffusionXLPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.loaders import LoraLoaderMixin
+from diffusers.models.lora import LoRALinearLayer
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import compute_snr
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.24.0.dev0")
+
+logger = get_logger(__name__)
+
+
+# TODO: This function should be removed once training scripts are rewritten in PEFT
+def text_encoder_lora_state_dict(text_encoder):
+ state_dict = {}
+
+ def text_encoder_attn_modules(text_encoder):
+ from transformers import CLIPTextModel, CLIPTextModelWithProjection
+
+ attn_modules = []
+
+ if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
+ for i, layer in enumerate(text_encoder.text_model.encoder.layers):
+ name = f"text_model.encoder.layers.{i}.self_attn"
+ mod = layer.self_attn
+ attn_modules.append((name, mod))
+
+ return attn_modules
+
+ for name, module in text_encoder_attn_modules(text_encoder):
+ for k, v in module.q_proj.lora_linear_layer.state_dict().items():
+ state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
+
+ for k, v in module.k_proj.lora_linear_layer.state_dict().items():
+ state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
+
+ for k, v in module.v_proj.lora_linear_layer.state_dict().items():
+ state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
+
+ for k, v in module.out_proj.lora_linear_layer.state_dict().items():
+ state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
+
+ return state_dict
+
+
+def save_model_card(
+ repo_id: str,
+ images=None,
+ base_model=None,
+ dataset_name=None,
+ train_text_encoder=False,
+ repo_folder=None,
+ vae_path=None,
+):
+ img_str = ""
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ img_str += f"![img_{i}](./image_{i}.png)\n"
+
+ yaml = f"""
+---
+license: creativeml-openrail-m
+base_model: {base_model}
+dataset: {dataset_name}
+tags:
+- stable-diffusion-xl
+- stable-diffusion-xl-diffusers
+- text-to-image
+- diffusers
+- lora
+inference: true
+---
+ """
+ model_card = f"""
+# LoRA text2image fine-tuning - {repo_id}
+
+These are LoRA adaptation weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images below.\n
+{img_str}
+
+LoRA for the text encoder was enabled: {train_text_encoder}.
+
+Special VAE used for training: {vae_path}.
+"""
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
+ f.write(yaml + model_card)
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=1,
+ help=(
+ "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="sd-model-finetuned-lora",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=1024,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_text_encoder",
+ action="store_true",
+ help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--prediction_type",
+ type=str,
+ default=None,
+ help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.",
+ )
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ return args
+
+
+DATASET_NAME_MAPPING = {
+ "lambdalabs/pokemon-blip-captions": ("image", "text"),
+}
+
+
+def unet_attn_processors_state_dict(unet) -> Dict[str, torch.Tensor]:
+ """
+ Returns:
+ a state dict containing just the attention processor parameters.
+ """
+ attn_processors = unet.attn_processors
+
+ attn_processors_state_dict = {}
+
+ for attn_processor_key, attn_processor in attn_processors.items():
+ for parameter_key, parameter in attn_processor.state_dict().items():
+ attn_processors_state_dict[f"{attn_processor_key}.{parameter_key}"] = parameter
+
+ return attn_processors_state_dict
+
+
+def tokenize_prompt(tokenizer, prompt):
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ return text_input_ids
+
+
+# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
+def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
+ prompt_embeds_list = []
+
+ for i, text_encoder in enumerate(text_encoders):
+ if tokenizers is not None:
+ tokenizer = tokenizers[i]
+ text_input_ids = tokenize_prompt(tokenizer, prompt)
+ else:
+ assert text_input_ids_list is not None
+ text_input_ids = text_input_ids_list[i]
+
+ prompt_embeds = text_encoder(
+ text_input_ids.to(text_encoder.device),
+ output_hidden_states=True,
+ )
+
+ # We are only ever interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
+ return prompt_embeds, pooled_prompt_embeds
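+ # For the SDXL base checkpoints this concatenates the penultimate hidden states of both text
+ # encoders (768 + 1280 = 2048 features), while the pooled embedding comes from the second,
+ # larger text encoder only.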
+
+
+def main(args):
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ kwargs_handlers=[kwargs],
+ )
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+ import wandb
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizers
+ tokenizer_one = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
+ )
+ tokenizer_two = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
+ )
+
+ # import correct text encoder classes
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+ )
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
+ )
+ vae_path = (
+ args.pretrained_model_name_or_path
+ if args.pretrained_vae_model_name_or_path is None
+ else args.pretrained_vae_model_name_or_path
+ )
+ vae = AutoencoderKL.from_pretrained(
+ vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision
+ )
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+ )
+
+ # We only train the additional adapter LoRA layers
+ vae.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+ unet.requires_grad_(False)
+
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ # The VAE is in float32 to avoid NaN losses.
+ unet.to(accelerator.device, dtype=weight_dtype)
+ if args.pretrained_vae_model_name_or_path is None:
+ vae.to(accelerator.device, dtype=torch.float32)
+ else:
+ vae.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warn(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # now we will add new LoRA weights to the attention layers
+ # Set correct lora layers
+ unet_lora_parameters = []
+ for attn_processor_name, attn_processor in unet.attn_processors.items():
+ # Parse the attention module.
+ attn_module = unet
+ for n in attn_processor_name.split(".")[:-1]:
+ attn_module = getattr(attn_module, n)
+
+ # Set the `lora_layer` attribute of the attention-related matrices.
+ attn_module.to_q.set_lora_layer(
+ LoRALinearLayer(
+ in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
+ )
+ )
+ attn_module.to_k.set_lora_layer(
+ LoRALinearLayer(
+ in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
+ )
+ )
+ attn_module.to_v.set_lora_layer(
+ LoRALinearLayer(
+ in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
+ )
+ )
+ attn_module.to_out[0].set_lora_layer(
+ LoRALinearLayer(
+ in_features=attn_module.to_out[0].in_features,
+ out_features=attn_module.to_out[0].out_features,
+ rank=args.rank,
+ )
+ )
+
+ # Accumulate the LoRA params to optimize.
+ unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
+ unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
+ unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
+ unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
+
+ # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
+ # Instead, we monkey-patch the forward calls of its attention blocks.
+ if args.train_text_encoder:
+ # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
+ text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(
+ text_encoder_one, dtype=torch.float32, rank=args.rank
+ )
+ text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(
+ text_encoder_two, dtype=torch.float32, rank=args.rank
+ )
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ # There are only two options here: either just the unet attention processor layers,
+ # or the unet attention processor layers together with the text encoder attention layers.
+ unet_lora_layers_to_save = None
+ text_encoder_one_lora_layers_to_save = None
+ text_encoder_two_lora_layers_to_save = None
+
+ for model in models:
+ if isinstance(model, type(accelerator.unwrap_model(unet))):
+ unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
+ elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
+ text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
+ elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
+ text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ # make sure to pop the weight so that the corresponding model is not saved again
+ weights.pop()
+
+ StableDiffusionXLPipeline.save_lora_weights(
+ output_dir,
+ unet_lora_layers=unet_lora_layers_to_save,
+ text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
+ text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
+ )
+
+ def load_model_hook(models, input_dir):
+ unet_ = None
+ text_encoder_one_ = None
+ text_encoder_two_ = None
+
+ while len(models) > 0:
+ model = models.pop()
+
+ if isinstance(model, type(accelerator.unwrap_model(unet))):
+ unet_ = model
+ elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
+ text_encoder_one_ = model
+ elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
+ text_encoder_two_ = model
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
+ LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
+
+ text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
+ LoraLoaderMixin.load_lora_into_text_encoder(
+ text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
+ )
+
+ text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
+ LoraLoaderMixin.load_lora_into_text_encoder(
+ text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
+ )
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimizer creation
+ params_to_optimize = (
+ itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two)
+ if args.train_text_encoder
+ else unet_lora_parameters
+ )
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, data_dir=args.train_data_dir
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
+ if args.image_column is None:
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.caption_column is None:
+ caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
+ )
+
+ # Preprocessing the datasets.
+ # We need to tokenize input captions and transform the images.
+ def tokenize_captions(examples, is_train=True):
+ captions = []
+ for caption in examples[caption_column]:
+ if isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+ else:
+ raise ValueError(
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
+ )
+ tokens_one = tokenize_prompt(tokenizer_one, captions)
+ tokens_two = tokenize_prompt(tokenizer_two, captions)
+ return tokens_one, tokens_two
+
+ # Preprocessing the datasets.
+ train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
+ train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)
+ train_flip = transforms.RandomHorizontalFlip(p=1.0)
+ train_transforms = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ # image aug
+ original_sizes = []
+ all_images = []
+ crop_top_lefts = []
+ for image in images:
+ original_sizes.append((image.height, image.width))
+ image = train_resize(image)
+ if args.center_crop:
+ y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
+ x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
+ image = train_crop(image)
+ else:
+ y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
+ image = crop(image, y1, x1, h, w)
+ if args.random_flip and random.random() < 0.5:
+ # flip
+ x1 = image.width - x1
+ image = train_flip(image)
+ crop_top_left = (y1, x1)
+ crop_top_lefts.append(crop_top_left)
+ image = train_transforms(image)
+ all_images.append(image)
+
+ examples["original_sizes"] = original_sizes
+ examples["crop_top_lefts"] = crop_top_lefts
+ examples["pixel_values"] = all_images
+ tokens_one, tokens_two = tokenize_captions(examples)
+ examples["input_ids_one"] = tokens_one
+ examples["input_ids_two"] = tokens_two
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ original_sizes = [example["original_sizes"] for example in examples]
+ crop_top_lefts = [example["crop_top_lefts"] for example in examples]
+ input_ids_one = torch.stack([example["input_ids_one"] for example in examples])
+ input_ids_two = torch.stack([example["input_ids_two"] for example in examples])
+ return {
+ "pixel_values": pixel_values,
+ "input_ids_one": input_ids_one,
+ "input_ids_two": input_ids_two,
+ "original_sizes": original_sizes,
+ "crop_top_lefts": crop_top_lefts,
+ }
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+ )
+
+ # Prepare everything with our `accelerator`.
+ if args.train_text_encoder:
+ unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler
+ )
+ else:
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers are initialized automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("text2image-fine-tune", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ unet.train()
+ if args.train_text_encoder:
+ text_encoder_one.train()
+ text_encoder_two.train()
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ # Convert images to latent space
+ if args.pretrained_vae_model_name_or_path is not None:
+ pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
+ else:
+ pixel_values = batch["pixel_values"]
+
+ model_input = vae.encode(pixel_values).latent_dist.sample()
+ model_input = model_input * vae.config.scaling_factor
+ if args.pretrained_vae_model_name_or_path is None:
+ model_input = model_input.to(weight_dtype)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(model_input)
+ if args.noise_offset:
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
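+ # Offset noise adds a small constant per sample and channel (note the (B, C, 1, 1) shape below),
+ # which, per the linked post, helps the model generate very dark or very bright images.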
+ noise += args.noise_offset * torch.randn(
+ (model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device
+ )
+
+ bsz = model_input.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
+ )
+ timesteps = timesteps.long()
+
+ # Add noise to the model input according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
+
+ # time ids
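+ # SDXL micro-conditioning: each sample is conditioned on a 6-element vector of
+ # (original_height, original_width, crop_top, crop_left, target_height, target_width).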
+ def compute_time_ids(original_size, crops_coords_top_left):
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
+ target_size = (args.resolution, args.resolution)
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_time_ids = torch.tensor([add_time_ids])
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
+ return add_time_ids
+
+ add_time_ids = torch.cat(
+ [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
+ )
+
+ # Predict the noise residual
+ unet_added_conditions = {"time_ids": add_time_ids}
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
+ text_encoders=[text_encoder_one, text_encoder_two],
+ tokenizers=None,
+ prompt=None,
+ text_input_ids_list=[batch["input_ids_one"], batch["input_ids_two"]],
+ )
+ unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
+ model_pred = unet(
+ noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions
+ ).sample
+
+ # Get the target for loss depending on the prediction type
+ if args.prediction_type is not None:
+ # set prediction_type of scheduler if defined
+ noise_scheduler.register_to_config(prediction_type=args.prediction_type)
+
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(model_input, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ if args.snr_gamma is None:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+ # This is discussed in Section 4.2 of the same paper.
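+ # The resulting per-sample weight is min(SNR(t), snr_gamma) / SNR(t), which caps the contribution
+ # of low-noise (high-SNR) timesteps instead of letting them dominate the loss.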
+ snr = compute_snr(noise_scheduler, timesteps)
+ if noise_scheduler.config.prediction_type == "v_prediction":
+ # Velocity objective requires that we add one to SNR values before we divide by them.
+ snr = snr + 1
+ mse_loss_weights = (
+ torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
+ )
+
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = (
+ itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two)
+ if args.train_text_encoder
+ else unet_lora_parameters
+ )
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
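+ # e.g. with --checkpoints_total_limit=3 and three checkpoints already on disk, the oldest one is deleted first.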
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ # create pipeline
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ text_encoder=accelerator.unwrap_model(text_encoder_one),
+ text_encoder_2=accelerator.unwrap_model(text_encoder_two),
+ unet=accelerator.unwrap_model(unet),
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ pipeline_args = {"prompt": args.validation_prompt}
+
+ with torch.cuda.amp.autocast():
+ images = [
+ pipeline(**pipeline_args, generator=generator).images[0]
+ for _ in range(args.num_validation_images)
+ ]
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = accelerator.unwrap_model(unet)
+ unet_lora_layers = unet_attn_processors_state_dict(unet)
+
+ if args.train_text_encoder:
+ text_encoder_one = accelerator.unwrap_model(text_encoder_one)
+ text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one)
+ text_encoder_two = accelerator.unwrap_model(text_encoder_two)
+ text_encoder_2_lora_layers = text_encoder_lora_state_dict(text_encoder_two)
+ else:
+ text_encoder_lora_layers = None
+ text_encoder_2_lora_layers = None
+
+ StableDiffusionXLPipeline.save_lora_weights(
+ save_directory=args.output_dir,
+ unet_lora_layers=unet_lora_layers,
+ text_encoder_lora_layers=text_encoder_lora_layers,
+ text_encoder_2_lora_layers=text_encoder_2_lora_layers,
+ )
+
+ del unet
+ del text_encoder_one
+ del text_encoder_two
+ del text_encoder_lora_layers
+ del text_encoder_2_lora_layers
+ torch.cuda.empty_cache()
+
+ # Final inference
+ # Load previous pipeline
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path, vae=vae, revision=args.revision, torch_dtype=weight_dtype
+ )
+ pipeline = pipeline.to(accelerator.device)
+
+ # load attention processors
+ pipeline.load_lora_weights(args.output_dir)
+
+ # run inference
+ images = []
+ if args.validation_prompt and args.num_validation_images > 0:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ images = [
+ pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+ for _ in range(args.num_validation_images)
+ ]
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "test": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ dataset_name=args.dataset_name,
+ train_text_encoder=args.train_text_encoder,
+ repo_folder=args.output_dir,
+ vae_path=args.pretrained_vae_model_name_or_path,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/text_to_image/train_text_to_image_sdxl.py b/diffusers/examples/text_to_image/train_text_to_image_sdxl.py
new file mode 100644
index 0000000000000000000000000000000000000000..041464e701cc0ca40556e6faee810f2ecc8e406d
--- /dev/null
+++ b/diffusers/examples/text_to_image/train_text_to_image_sdxl.py
@@ -0,0 +1,1258 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fine-tuning script for Stable Diffusion XL for text2image."""
+
+import argparse
+import functools
+import gc
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+
+import accelerate
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from torchvision import transforms
+from torchvision.transforms.functional import crop
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ StableDiffusionXLPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel, compute_snr
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.24.0.dev0")
+
+logger = get_logger(__name__)
+
+
+DATASET_NAME_MAPPING = {
+ "lambdalabs/pokemon-blip-captions": ("image", "text"),
+}
+
+
+def save_model_card(
+ repo_id: str,
+ images=None,
+ validation_prompt=None,
+ base_model: str = None,
+ dataset_name: str = None,
+ repo_folder=None,
+ vae_path=None,
+):
+ img_str = ""
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ img_str += f"![img_{i}](./image_{i}.png)\n"
+
+ yaml = f"""
+---
+license: creativeml-openrail-m
+base_model: {base_model}
+dataset: {dataset_name}
+tags:
+- stable-diffusion-xl
+- stable-diffusion-xl-diffusers
+- text-to-image
+- diffusers
+inference: true
+---
+ """
+ model_card = f"""
+# Text-to-image finetuning - {repo_id}
+
+This pipeline was finetuned from **{base_model}** on the **{dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompt: {validation_prompt}: \n
+{img_str}
+
+Special VAE used for training: {vae_path}.
+"""
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
+ f.write(yaml + model_card)
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=1,
+ help=(
+ "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--proportion_empty_prompts",
+ type=float,
+ default=0,
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="sdxl-model-finetuned",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=1024,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--timestep_bias_strategy",
+ type=str,
+ default="none",
+ choices=["earlier", "later", "range", "none"],
+ help=(
+ "The timestep bias strategy, which may help direct the model toward learning low or high frequency details."
+ " Choices: ['earlier', 'later', 'range', 'none']."
+ " The default is 'none', which means no bias is applied, and training proceeds normally."
+ " The value of 'later' will increase the frequency of the model's final training timesteps."
+ ),
+ )
+ parser.add_argument(
+ "--timestep_bias_multiplier",
+ type=float,
+ default=1.0,
+ help=(
+ "The multiplier for the bias. Defaults to 1.0, which means no bias is applied."
+ " A value of 2.0 will double the weight of the bias, and a value of 0.5 will halve it."
+ ),
+ )
+ parser.add_argument(
+ "--timestep_bias_begin",
+ type=int,
+ default=0,
+ help=(
+ "When using `--timestep_bias_strategy=range`, the beginning (inclusive) timestep to bias."
+ " Defaults to zero, which equates to having no specific bias."
+ ),
+ )
+ parser.add_argument(
+ "--timestep_bias_end",
+ type=int,
+ default=1000,
+ help=(
+ "When using `--timestep_bias_strategy=range`, the final timestep (inclusive) to bias."
+ " Defaults to 1000, which is the number of timesteps that Stable Diffusion is trained on."
+ ),
+ )
+ parser.add_argument(
+ "--timestep_bias_portion",
+ type=float,
+ default=0.25,
+ help=(
+ "The portion of timesteps to bias. Defaults to 0.25, which 25% of timesteps will be biased."
+ " A value of 0.5 will bias one half of the timesteps. The value provided for `--timestep_bias_strategy` determines"
+ " whether the biased portions are in the earlier or later timesteps."
+ ),
+ )
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--prediction_type",
+ type=str,
+ default=None,
+ help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.",
+ )
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
+
+ return args
+
+
+# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
+def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, caption_column, is_train=True):
+ prompt_embeds_list = []
+ prompt_batch = batch[caption_column]
+
+ captions = []
+ for caption in prompt_batch:
+ if random.random() < proportion_empty_prompts:
+ captions.append("")
+ elif isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+
+ with torch.no_grad():
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+ text_inputs = tokenizer(
+ captions,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_embeds = text_encoder(
+ text_input_ids.to(text_encoder.device),
+ output_hidden_states=True,
+ )
+
+ # We are only interested in the pooled output of the final text encoder.
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
+ return {"prompt_embeds": prompt_embeds.cpu(), "pooled_prompt_embeds": pooled_prompt_embeds.cpu()}
+
+
+def compute_vae_encodings(batch, vae):
+ images = batch.pop("pixel_values")
+ pixel_values = torch.stack(list(images))
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ pixel_values = pixel_values.to(vae.device, dtype=vae.dtype)
+
+ with torch.no_grad():
+ model_input = vae.encode(pixel_values).latent_dist.sample()
+ model_input = model_input * vae.config.scaling_factor
+ return {"model_input": model_input.cpu()}
+
+
+def generate_timestep_weights(args, num_timesteps):
+ weights = torch.ones(num_timesteps)
+
+ # Determine the indices to bias
+ num_to_bias = int(args.timestep_bias_portion * num_timesteps)
+
+ if args.timestep_bias_strategy == "later":
+ bias_indices = slice(-num_to_bias, None)
+ elif args.timestep_bias_strategy == "earlier":
+ bias_indices = slice(0, num_to_bias)
+ elif args.timestep_bias_strategy == "range":
+ # Out of the possible 1000 timesteps, we might want to focus on e.g. 200-500.
+ range_begin = args.timestep_bias_begin
+ range_end = args.timestep_bias_end
+ if range_begin < 0:
+ raise ValueError(
+ "When using the range strategy for timestep bias, you must provide a beginning timestep greater or equal to zero."
+ )
+ if range_end > num_timesteps:
+ raise ValueError(
+ "When using the range strategy for timestep bias, you must provide an ending timestep smaller than the number of timesteps."
+ )
+ bias_indices = slice(range_begin, range_end)
+ else: # 'none' or any other string
+ return weights
+ if args.timestep_bias_multiplier <= 0:
+ raise ValueError(
+ "The parameter --timestep_bias_multiplier is not intended to be used to disable the training of specific timesteps."
+ " If it was intended to disable timestep bias, use `--timestep_bias_strategy none` instead."
+ " A timestep bias multiplier less than or equal to 0 is not allowed."
+ )
+
+ # Apply the bias
+ weights[bias_indices] *= args.timestep_bias_multiplier
+
+ # Normalize
+ weights /= weights.sum()
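+ # The normalized weights form a categorical distribution over timesteps, sampled later with
+ # torch.multinomial in the training loop.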
+
+ return weights
+
+
+def main(args):
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+ import wandb
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizers
+ tokenizer_one = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
+ )
+ tokenizer_two = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
+ )
+
+ # import correct text encoder classes
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+ )
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ # Check for terminal SNR in combination with SNR Gamma
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
+ )
+ vae_path = (
+ args.pretrained_model_name_or_path
+ if args.pretrained_vae_model_name_or_path is None
+ else args.pretrained_vae_model_name_or_path
+ )
+ vae = AutoencoderKL.from_pretrained(
+ vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision
+ )
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+ )
+
+ # Freeze vae and text encoders.
+ vae.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+ # Set unet as trainable.
+ unet.train()
+
+ # For mixed precision training we cast all non-trainable weights to half-precision,
+ # as these weights are only used for inference and keeping them in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ # The VAE is in float32 to avoid NaN losses.
+ vae.to(accelerator.device, dtype=torch.float32)
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+
+ # Create EMA for the unet.
+ if args.use_ema:
+ ema_unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+ )
+ ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
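+ # The EMA copy is kept on the side: its weights are swapped in before validation and the final export
+ # (see ema_unet.store(...) / ema_unet.copy_to(...) further down) and saved under "unet_ema" in checkpoints.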
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warn(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ if args.use_ema:
+ ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
+
+ for i, model in enumerate(models):
+ model.save_pretrained(os.path.join(output_dir, "unet"))
+
+ # make sure to pop the weight so that the corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ if args.use_ema:
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
+ ema_unet.load_state_dict(load_model.state_dict())
+ ema_unet.to(accelerator.device)
+ del load_model
+
+ for i in range(len(models)):
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load the diffusers-style weights into the model
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimizer creation
+ params_to_optimize = unet.parameters()
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
+ if args.image_column is None:
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.caption_column is None:
+ caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
+ )
+
+ # Preprocessing the datasets.
+ train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
+ train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)
+ train_flip = transforms.RandomHorizontalFlip(p=1.0)
+ train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ # image aug
+ original_sizes = []
+ all_images = []
+ crop_top_lefts = []
+ for image in images:
+ original_sizes.append((image.height, image.width))
+ image = train_resize(image)
+ if args.center_crop:
+ y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
+ x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
+ image = train_crop(image)
+ else:
+ y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
+ image = crop(image, y1, x1, h, w)
+ if args.random_flip and random.random() < 0.5:
+ # flip
+ x1 = image.width - x1
+ image = train_flip(image)
+ crop_top_left = (y1, x1)
+ crop_top_lefts.append(crop_top_left)
+ image = train_transforms(image)
+ all_images.append(image)
+
+ examples["original_sizes"] = original_sizes
+ examples["crop_top_lefts"] = crop_top_lefts
+ examples["pixel_values"] = all_images
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ # Let's first compute all the embeddings so that we can free up the text encoders
+ # from memory. We will pre-compute the VAE encodings too.
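+ # Once the embeddings and latents are cached, the text encoders, tokenizers and VAE are deleted
+ # (see the del below) to free GPU memory for the UNet.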
+ text_encoders = [text_encoder_one, text_encoder_two]
+ tokenizers = [tokenizer_one, tokenizer_two]
+ compute_embeddings_fn = functools.partial(
+ encode_prompt,
+ text_encoders=text_encoders,
+ tokenizers=tokenizers,
+ proportion_empty_prompts=args.proportion_empty_prompts,
+ caption_column=args.caption_column,
+ )
+ compute_vae_encodings_fn = functools.partial(compute_vae_encodings, vae=vae)
+ with accelerator.main_process_first():
+ from datasets.fingerprint import Hasher
+
+ # fingerprint used by the cache for the other processes to load the result
+ # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
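+ # Hashing the args keys the datasets cache, so the non-main processes (released after
+ # main_process_first) load the pre-computed embeddings from the cache instead of recomputing them.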
+ new_fingerprint = Hasher.hash(args)
+ new_fingerprint_for_vae = Hasher.hash("vae")
+ train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)
+ train_dataset = train_dataset.map(
+ compute_vae_encodings_fn,
+ batched=True,
+ batch_size=args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps,
+ new_fingerprint=new_fingerprint_for_vae,
+ )
+
+ del text_encoders, tokenizers, vae
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def collate_fn(examples):
+ model_input = torch.stack([torch.tensor(example["model_input"]) for example in examples])
+ original_sizes = [example["original_sizes"] for example in examples]
+ crop_top_lefts = [example["crop_top_lefts"] for example in examples]
+ prompt_embeds = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples])
+ pooled_prompt_embeds = torch.stack([torch.tensor(example["pooled_prompt_embeds"]) for example in examples])
+
+ return {
+ "model_input": model_input,
+ "prompt_embeds": prompt_embeds,
+ "pooled_prompt_embeds": pooled_prompt_embeds,
+ "original_sizes": original_sizes,
+ "crop_top_lefts": crop_top_lefts,
+ }
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+ )
+
+ # Prepare everything with our `accelerator`.
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("text2image-fine-tune-sdxl", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ # Sample noise that we'll add to the latents
+ model_input = batch["model_input"].to(accelerator.device)
+ noise = torch.randn_like(model_input)
+ if args.noise_offset:
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
+ noise += args.noise_offset * torch.randn(
+ (model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device
+ )
+
+ bsz = model_input.shape[0]
+ if args.timestep_bias_strategy == "none":
+ # Sample a random timestep for each image without bias.
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
+ )
+ else:
+ # Sample a random timestep for each image, potentially biased by the timestep weights.
+ # Biasing the timestep weights allows us to spend less time training irrelevant timesteps.
+ weights = generate_timestep_weights(args, noise_scheduler.config.num_train_timesteps).to(
+ model_input.device
+ )
+ timesteps = torch.multinomial(weights, bsz, replacement=True).long()
+
+ # Add noise to the model input according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
+
+ # time ids
+ def compute_time_ids(original_size, crops_coords_top_left):
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
+ target_size = (args.resolution, args.resolution)
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_time_ids = torch.tensor([add_time_ids])
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
+ return add_time_ids
+
+ add_time_ids = torch.cat(
+ [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
+ )
+
+ # Predict the noise residual
+ unet_added_conditions = {"time_ids": add_time_ids}
+ prompt_embeds = batch["prompt_embeds"].to(accelerator.device)
+ pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(accelerator.device)
+ unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
+ model_pred = unet(
+ noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions
+ ).sample
+
+ # Get the target for loss depending on the prediction type
+ if args.prediction_type is not None:
+ # set prediction_type of scheduler if defined
+ noise_scheduler.register_to_config(prediction_type=args.prediction_type)
+
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(model_input, noise, timesteps)
+ elif noise_scheduler.config.prediction_type == "sample":
+ # We set the target to latents here, but the model_pred will return the noise sample prediction.
+ target = model_input
+ # We will have to subtract the noise residual from the prediction to get the target sample.
+ model_pred = model_pred - noise
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ if args.snr_gamma is None:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+ # This is discussed in Section 4.2 of the same paper.
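+ # In short: weight(t) = min(SNR(t), snr_gamma) / SNR(t), i.e. low-noise timesteps are down-weighted.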
+ snr = compute_snr(noise_scheduler, timesteps)
+ if noise_scheduler.config.prediction_type == "v_prediction":
+ # Velocity objective requires that we add one to SNR values before we divide by them.
+ snr = snr + 1
+ mse_loss_weights = (
+ torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
+ )
+
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = unet.parameters()
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ if args.use_ema:
+ # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
+ ema_unet.store(unet.parameters())
+ ema_unet.copy_to(unet.parameters())
+
+ # create pipeline
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ )
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ unet=accelerator.unwrap_model(unet),
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ if args.prediction_type is not None:
+ scheduler_args = {"prediction_type": args.prediction_type}
+ pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ pipeline_args = {"prompt": args.validation_prompt}
+
+ with torch.cuda.amp.autocast():
+ images = [
+ pipeline(**pipeline_args, generator=generator, num_inference_steps=25).images[0]
+ for _ in range(args.num_validation_images)
+ ]
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = accelerator.unwrap_model(unet)
+ if args.use_ema:
+ ema_unet.copy_to(unet.parameters())
+
+ # Serialize pipeline.
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path, unet=unet, vae=vae, revision=args.revision, torch_dtype=weight_dtype
+ )
+ if args.prediction_type is not None:
+ scheduler_args = {"prediction_type": args.prediction_type}
+ pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+ pipeline.save_pretrained(args.output_dir)
+
+ # run inference
+ images = []
+ if args.validation_prompt and args.num_validation_images > 0:
+ pipeline = pipeline.to(accelerator.device)
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ with torch.cuda.amp.autocast():
+ images = [
+ pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+ for _ in range(args.num_validation_images)
+ ]
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "test": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id=repo_id,
+ images=images,
+ validation_prompt=args.validation_prompt,
+ base_model=args.pretrained_model_name_or_path,
+ dataset_name=args.dataset_name,
+ repo_folder=args.output_dir,
+ vae_path=args.pretrained_vae_model_name_or_path,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/textual_inversion/README.md b/diffusers/examples/textual_inversion/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..0a1d8a459fc6719671191bd770e80ce7d0660606
--- /dev/null
+++ b/diffusers/examples/textual_inversion/README.md
@@ -0,0 +1,148 @@
+## Textual Inversion fine-tuning example
+
+[Textual inversion](https://arxiv.org/abs/2208.01618) is a method to personalize text2image models like stable diffusion on your own images using just 3-5 examples.
+The `textual_inversion.py` script shows how to implement the training procedure and adapt it for stable diffusion.
+
+## Running on Colab
+
+Colab for training
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb)
+
+Colab for inference
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_conceptualizer_inference.ipynb)
+
+## Running locally with PyTorch
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the installation up to date, as we update the example scripts frequently and they have some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Then cd into the example folder and run:
+```bash
+pip install -r requirements.txt
+```
+
+And initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
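+
+Or, to write a default 🤗 Accelerate configuration file without answering the interactive prompts, run:
+
+```bash
+accelerate config default
+```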
+
+### Cat toy example
+
+First, let's login so that we can upload the checkpoint to the Hub during training:
+
+```bash
+huggingface-cli login
+```
+
+Now let's get our dataset. For this example we will use some cat images: https://huggingface.co/datasets/diffusers/cat_toy_example .
+
+Let's first download it locally:
+
+```py
+from huggingface_hub import snapshot_download
+
+local_dir = "./cat"
+snapshot_download("diffusers/cat_toy_example", local_dir=local_dir, repo_type="dataset", ignore_patterns=".gitattributes")
+```
+
+This will be our training data.
+Now we can launch the training using:
+
+**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
+
+```bash
+export MODEL_NAME="runwayml/stable-diffusion-v1-5"
+export DATA_DIR="./cat"
+
+accelerate launch textual_inversion.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_data_dir=$DATA_DIR \
+ --learnable_property="object" \
+ --placeholder_token="" \
+ --initializer_token="toy" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --max_train_steps=3000 \
+ --learning_rate=5.0e-04 \
+ --scale_lr \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --push_to_hub \
+ --output_dir="textual_inversion_cat"
+```
+
+A full training run takes ~1 hour on one V100 GPU.
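+
+The script also saves the full training state every `--checkpointing_steps` updates (500 by default), so an interrupted run can be resumed from the most recent checkpoint by adding:
+
+```bash
+--resume_from_checkpoint="latest"
+```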
+
+**Note**: As described in [the official paper](https://arxiv.org/abs/2208.01618)
+only one embedding vector is used for the placeholder token, *e.g.* `"<cat-toy>"`.
+However, one can also add multiple embedding vectors for the placeholder token
+to increase the number of fine-tuneable parameters. This can help the model to learn
+more complex details. To use multiple embedding vectors, set `--num_vectors`
+to a number larger than one, *e.g.*:
+```bash
+--num_vectors 5
+```
+
+The saved textual inversion vectors will then be larger in size compared to the default case.
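+
+One way to verify this is to inspect the saved embedding file (a minimal sketch, assuming the `--output_dir` from the command above and the default safetensors serialization):
+
+```py
+from safetensors.torch import load_file
+
+state = load_file("textual_inversion_cat/learned_embeds.safetensors")
+# A single entry keyed by the placeholder token; with `--num_vectors 5` the tensor
+# has shape (5, hidden_size) instead of (1, hidden_size).
+print({token: tensor.shape for token, tensor in state.items()})
+```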
+
+### Inference
+
+Once you have trained a model using the above command, inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `placeholder_token` in your prompt.
+
+```python
+from diffusers import StableDiffusionPipeline
+import torch
+
+model_id = "path-to-your-trained-model"
+pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+
+prompt = "A backpack"
+
+image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
+
+image.save("cat-backpack.png")
+```
+
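+The script always writes the learned embedding file (`learned_embeds.safetensors` by default) to the output directory, even when the full pipeline is not saved. A minimal sketch of loading just that file into a base pipeline with `load_textual_inversion`, assuming the `--output_dir` from the command above:
+
+```python
+from diffusers import StableDiffusionPipeline
+import torch
+
+pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
+
+# Load only the learned embedding produced by textual_inversion.py.
+pipe.load_textual_inversion("textual_inversion_cat/learned_embeds.safetensors", token="<cat-toy>")
+
+image = pipe("A <cat-toy> backpack", num_inference_steps=50, guidance_scale=7.5).images[0]
+image.save("cat-backpack-from-embeds.png")
+```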
+
+## Training with Flax/JAX
+
+For faster training on TPUs and GPUs you can leverage the flax training example. Follow the instructions above to get the model and dataset before running the script.
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+```bash
+pip install -U -r requirements_flax.txt
+```
+
+```bash
+export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
+export DATA_DIR="path-to-dir-containing-images"
+
+python textual_inversion_flax.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_data_dir=$DATA_DIR \
+ --learnable_property="object" \
+ --placeholder_token="" \
+ --initializer_token="toy" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --max_train_steps=3000 \
+ --learning_rate=5.0e-04 \
+ --scale_lr \
+ --output_dir="textual_inversion_cat"
+```
+It should be at least 70% faster than the PyTorch script with the same configuration.
+
+### Training with xformers
+You can enable memory-efficient attention by [installing xFormers](https://github.com/facebookresearch/xformers#installing-xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script. This is not available with the Flax/JAX implementation.
diff --git a/diffusers/examples/textual_inversion/requirements.txt b/diffusers/examples/textual_inversion/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7a612982f4abbaa64f83db52e411a1235a372259
--- /dev/null
+++ b/diffusers/examples/textual_inversion/requirements.txt
@@ -0,0 +1,6 @@
+accelerate>=0.16.0
+torchvision
+transformers>=4.25.1
+ftfy
+tensorboard
+Jinja2
diff --git a/diffusers/examples/textual_inversion/requirements_flax.txt b/diffusers/examples/textual_inversion/requirements_flax.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8f85ad523a3b46b65abf0138c05ecdd656e6845c
--- /dev/null
+++ b/diffusers/examples/textual_inversion/requirements_flax.txt
@@ -0,0 +1,8 @@
+transformers>=4.25.1
+flax
+optax
+torch
+torchvision
+ftfy
+tensorboard
+Jinja2
diff --git a/diffusers/examples/textual_inversion/textual_inversion.py b/diffusers/examples/textual_inversion/textual_inversion.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ce998aab1fbb20ea83318e769f8527c1e9179f2
--- /dev/null
+++ b/diffusers/examples/textual_inversion/textual_inversion.py
@@ -0,0 +1,990 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import math
+import os
+import random
+import shutil
+import warnings
+from pathlib import Path
+
+import numpy as np
+import PIL
+import safetensors
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+
+# TODO: remove and import from diffusers.utils when the new version of diffusers is released
+from packaging import version
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ DiffusionPipeline,
+ DPMSolverMultistepScheduler,
+ StableDiffusionPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+if is_wandb_available():
+ import wandb
+
+if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.Resampling.BILINEAR,
+ "bilinear": PIL.Image.Resampling.BILINEAR,
+ "bicubic": PIL.Image.Resampling.BICUBIC,
+ "lanczos": PIL.Image.Resampling.LANCZOS,
+ "nearest": PIL.Image.Resampling.NEAREST,
+ }
+else:
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.LINEAR,
+ "bilinear": PIL.Image.BILINEAR,
+ "bicubic": PIL.Image.BICUBIC,
+ "lanczos": PIL.Image.LANCZOS,
+ "nearest": PIL.Image.NEAREST,
+ }
+# ------------------------------------------------------------------------------
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.24.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def save_model_card(repo_id: str, images=None, base_model: str = None, repo_folder=None):
+ img_str = ""
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ img_str += f"![img_{i}](./image_{i}.png)\n"
+
+ yaml = f"""
+---
+license: creativeml-openrail-m
+base_model: {base_model}
+tags:
+- stable-diffusion
+- stable-diffusion-diffusers
+- text-to-image
+- diffusers
+- textual_inversion
+inference: true
+---
+ """
+ model_card = f"""
+# Textual inversion text2image fine-tuning - {repo_id}
+These are textual inversion adaptation weights for {base_model}. You can find some example images in the following. \n
+{img_str}
+"""
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
+ f.write(yaml + model_card)
+
+
+def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ # create pipeline (note: unet and vae are loaded again in float32)
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ text_encoder=accelerator.unwrap_model(text_encoder),
+ tokenizer=tokenizer,
+ unet=unet,
+ vae=vae,
+ safety_checker=None,
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
+ images = []
+ for _ in range(args.num_validation_images):
+ with torch.autocast("cuda"):
+ image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+ images.append(image)
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ torch.cuda.empty_cache()
+ return images
+
+
+def save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path, safe_serialization=True):
+ logger.info("Saving embeddings")
+ learned_embeds = (
+ accelerator.unwrap_model(text_encoder)
+ .get_input_embeddings()
+ .weight[min(placeholder_token_ids) : max(placeholder_token_ids) + 1]
+ )
+ learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
+
+ if safe_serialization:
+ safetensors.torch.save_file(learned_embeds_dict, save_path, metadata={"format": "pt"})
+ else:
+ torch.save(learned_embeds_dict, save_path)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--save_steps",
+ type=int,
+ default=500,
+ help="Save learned_embeds.bin every X updates steps.",
+ )
+ parser.add_argument(
+ "--save_as_full_pipeline",
+ action="store_true",
+ help="Save the complete stable diffusion pipeline.",
+ )
+ parser.add_argument(
+ "--num_vectors",
+ type=int,
+ default=1,
+ help="How many textual inversion vectors shall be used to learn the concept.",
+ )
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
+ )
+ parser.add_argument(
+ "--placeholder_token",
+ type=str,
+ default=None,
+ required=True,
+ help="A token to use as a placeholder for the concept.",
+ )
+ parser.add_argument(
+ "--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word."
+ )
+ parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
+ parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="text-inversion-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution."
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=5000,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default="no",
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose"
+ "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
+ "and an Nvidia Ampere GPU."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=100,
+ help=(
+ "Run validation every X steps. Validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
+ " and logging the images."
+ ),
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=None,
+ help=(
+ "Deprecated in favor of validation_steps. Run validation every X epochs. Validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
+ " and logging the images."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--no_safe_serialization",
+ action="store_true",
+ help="If specified save the checkpoint not in `safetensors` format, but in original PyTorch format instead.",
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.train_data_dir is None:
+ raise ValueError("You must specify a train data directory.")
+
+ return args
+
+
+imagenet_templates_small = [
+ "a photo of a {}",
+ "a rendering of a {}",
+ "a cropped photo of the {}",
+ "the photo of a {}",
+ "a photo of a clean {}",
+ "a photo of a dirty {}",
+ "a dark photo of the {}",
+ "a photo of my {}",
+ "a photo of the cool {}",
+ "a close-up photo of a {}",
+ "a bright photo of the {}",
+ "a cropped photo of a {}",
+ "a photo of the {}",
+ "a good photo of the {}",
+ "a photo of one {}",
+ "a close-up photo of the {}",
+ "a rendition of the {}",
+ "a photo of the clean {}",
+ "a rendition of a {}",
+ "a photo of a nice {}",
+ "a good photo of a {}",
+ "a photo of the nice {}",
+ "a photo of the small {}",
+ "a photo of the weird {}",
+ "a photo of the large {}",
+ "a photo of a cool {}",
+ "a photo of a small {}",
+]
+
+imagenet_style_templates_small = [
+ "a painting in the style of {}",
+ "a rendering in the style of {}",
+ "a cropped painting in the style of {}",
+ "the painting in the style of {}",
+ "a clean painting in the style of {}",
+ "a dirty painting in the style of {}",
+ "a dark painting in the style of {}",
+ "a picture in the style of {}",
+ "a cool painting in the style of {}",
+ "a close-up painting in the style of {}",
+ "a bright painting in the style of {}",
+ "a cropped painting in the style of {}",
+ "a good painting in the style of {}",
+ "a close-up painting in the style of {}",
+ "a rendition in the style of {}",
+ "a nice painting in the style of {}",
+ "a small painting in the style of {}",
+ "a weird painting in the style of {}",
+ "a large painting in the style of {}",
+]
+
+
+class TextualInversionDataset(Dataset):
+ def __init__(
+ self,
+ data_root,
+ tokenizer,
+ learnable_property="object", # [object, style]
+ size=512,
+ repeats=100,
+ interpolation="bicubic",
+ flip_p=0.5,
+ set="train",
+ placeholder_token="*",
+ center_crop=False,
+ ):
+ self.data_root = data_root
+ self.tokenizer = tokenizer
+ self.learnable_property = learnable_property
+ self.size = size
+ self.placeholder_token = placeholder_token
+ self.center_crop = center_crop
+ self.flip_p = flip_p
+
+ self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
+
+ self.num_images = len(self.image_paths)
+ self._length = self.num_images
+
+ if set == "train":
+ self._length = self.num_images * repeats
+
+ self.interpolation = {
+ "linear": PIL_INTERPOLATION["linear"],
+ "bilinear": PIL_INTERPOLATION["bilinear"],
+ "bicubic": PIL_INTERPOLATION["bicubic"],
+ "lanczos": PIL_INTERPOLATION["lanczos"],
+ }[interpolation]
+
+ self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
+ self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, i):
+ example = {}
+ image = Image.open(self.image_paths[i % self.num_images])
+
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+
+ placeholder_string = self.placeholder_token
+ text = random.choice(self.templates).format(placeholder_string)
+
+ example["input_ids"] = self.tokenizer(
+ text,
+ padding="max_length",
+ truncation=True,
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ ).input_ids[0]
+
+ # default to score-sde preprocessing
+ img = np.array(image).astype(np.uint8)
+
+ if self.center_crop:
+ crop = min(img.shape[0], img.shape[1])
+ (
+ h,
+ w,
+ ) = (
+ img.shape[0],
+ img.shape[1],
+ )
+ img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
+
+ image = Image.fromarray(img)
+ image = image.resize((self.size, self.size), resample=self.interpolation)
+
+ image = self.flip_transform(image)
+ image = np.array(image).astype(np.uint8)
+ image = (image / 127.5 - 1.0).astype(np.float32)
+
+ example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
+ return example
+
+
+def main():
+ args = parse_args()
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load tokenizer
+ if args.tokenizer_name:
+ tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+ elif args.pretrained_model_name_or_path:
+ tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+ )
+
+ # Add the placeholder token in tokenizer
+ placeholder_tokens = [args.placeholder_token]
+
+ if args.num_vectors < 1:
+ raise ValueError(f"--num_vectors has to be larger or equal to 1, but is {args.num_vectors}")
+
+ # add dummy tokens for multi-vector
+ additional_tokens = []
+ for i in range(1, args.num_vectors):
+ additional_tokens.append(f"{args.placeholder_token}_{i}")
+ placeholder_tokens += additional_tokens
+
+ num_added_tokens = tokenizer.add_tokens(placeholder_tokens)
+ if num_added_tokens != args.num_vectors:
+ raise ValueError(
+ f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
+ " `placeholder_token` that is not already in the tokenizer."
+ )
+
+ # Convert the initializer_token, placeholder_token to ids
+ token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
+ # Check if initializer_token is a single token or a sequence of tokens
+ if len(token_ids) > 1:
+ raise ValueError("The initializer token must be a single token.")
+
+ initializer_token_id = token_ids[0]
+ placeholder_token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens)
+
+ # Resize the token embeddings as we are adding new special tokens to the tokenizer
+ text_encoder.resize_token_embeddings(len(tokenizer))
+
+ # Initialise the newly added placeholder token with the embeddings of the initializer token
+ token_embeds = text_encoder.get_input_embeddings().weight.data
+ with torch.no_grad():
+ for token_id in placeholder_token_ids:
+ token_embeds[token_id] = token_embeds[initializer_token_id].clone()
+
+ # Freeze vae and unet
+ vae.requires_grad_(False)
+ unet.requires_grad_(False)
+ # Freeze all parameters except for the token embeddings in text encoder
+ text_encoder.text_model.encoder.requires_grad_(False)
+ text_encoder.text_model.final_layer_norm.requires_grad_(False)
+ text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
+
+ if args.gradient_checkpointing:
+ # Keep unet in train mode if we are using gradient checkpointing to save memory.
+ # The dropout cannot be != 0 so it doesn't matter if we are in eval or train mode.
+ unet.train()
+ text_encoder.gradient_checkpointing_enable()
+ unet.enable_gradient_checkpointing()
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warn(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Initialize the optimizer
+ optimizer = torch.optim.AdamW(
+ text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Dataset and DataLoaders creation:
+ train_dataset = TextualInversionDataset(
+ data_root=args.train_data_dir,
+ tokenizer=tokenizer,
+ size=args.resolution,
+ placeholder_token=(" ".join(tokenizer.convert_ids_to_tokens(placeholder_token_ids))),
+ repeats=args.repeats,
+ learnable_property=args.learnable_property,
+ center_crop=args.center_crop,
+ set="train",
+ )
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
+ )
+ if args.validation_epochs is not None:
+ warnings.warn(
+ f"FutureWarning: You are doing logging with validation_epochs={args.validation_epochs}."
+ " Deprecated validation_epochs in favor of `validation_steps`"
+ f"Setting `args.validation_steps` to {args.validation_epochs * len(train_dataset)}",
+ FutureWarning,
+ stacklevel=2,
+ )
+ args.validation_steps = args.validation_epochs * len(train_dataset)
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ )
+
+ # Prepare everything with our `accelerator`.
+ text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ text_encoder, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # For mixed precision training we cast all non-trainable weights (vae and unet) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move vae and unet to device and cast to weight_dtype
+ unet.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("textual_inversion", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ # keep original embeddings as reference
+ orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ text_encoder.train()
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(text_encoder):
+ # Convert images to latent space
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
+ latents = latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype)
+
+ # Predict the noise residual
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+ accelerator.backward(loss)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Let's make sure we don't update any embedding weights besides the newly added token
+ index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
+ index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
+
+ with torch.no_grad():
+ accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
+ index_no_updates
+ ] = orig_embeds_params[index_no_updates]
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ images = []
+ progress_bar.update(1)
+ global_step += 1
+ if global_step % args.save_steps == 0:
+ weight_name = (
+ f"learned_embeds-steps-{global_step}.bin"
+ if args.no_safe_serialization
+ else f"learned_embeds-steps-{global_step}.safetensors"
+ )
+ save_path = os.path.join(args.output_dir, weight_name)
+ save_progress(
+ text_encoder,
+ placeholder_token_ids,
+ accelerator,
+ args,
+ save_path,
+ safe_serialization=not args.no_safe_serialization,
+ )
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if args.validation_prompt is not None and global_step % args.validation_steps == 0:
+ images = log_validation(
+ text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch
+ )
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ if args.push_to_hub and not args.save_as_full_pipeline:
+ logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
+ save_full_model = True
+ else:
+ save_full_model = args.save_as_full_pipeline
+ if save_full_model:
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ text_encoder=accelerator.unwrap_model(text_encoder),
+ vae=vae,
+ unet=unet,
+ tokenizer=tokenizer,
+ )
+ pipeline.save_pretrained(args.output_dir)
+ # Save the newly trained embeddings
+ weight_name = "learned_embeds.bin" if args.no_safe_serialization else "learned_embeds.safetensors"
+ save_path = os.path.join(args.output_dir, weight_name)
+ save_progress(
+ text_encoder,
+ placeholder_token_ids,
+ accelerator,
+ args,
+ save_path,
+ safe_serialization=not args.no_safe_serialization,
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/textual_inversion/textual_inversion_flax.py b/diffusers/examples/textual_inversion/textual_inversion_flax.py
new file mode 100644
index 0000000000000000000000000000000000000000..5de1a8d7c325916de57ad2639488233765669e25
--- /dev/null
+++ b/diffusers/examples/textual_inversion/textual_inversion_flax.py
@@ -0,0 +1,681 @@
+import argparse
+import logging
+import math
+import os
+import random
+from pathlib import Path
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import optax
+import PIL
+import torch
+import torch.utils.checkpoint
+import transformers
+from flax import jax_utils
+from flax.training import train_state
+from flax.training.common_utils import shard
+from huggingface_hub import create_repo, upload_folder
+
+# TODO: remove and import from diffusers.utils when the new version of diffusers is released
+from packaging import version
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
+
+from diffusers import (
+ FlaxAutoencoderKL,
+ FlaxDDPMScheduler,
+ FlaxPNDMScheduler,
+ FlaxStableDiffusionPipeline,
+ FlaxUNet2DConditionModel,
+)
+from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker
+from diffusers.utils import check_min_version
+
+
+if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.Resampling.BILINEAR,
+ "bilinear": PIL.Image.Resampling.BILINEAR,
+ "bicubic": PIL.Image.Resampling.BICUBIC,
+ "lanczos": PIL.Image.Resampling.LANCZOS,
+ "nearest": PIL.Image.Resampling.NEAREST,
+ }
+else:
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.LINEAR,
+ "bilinear": PIL.Image.BILINEAR,
+ "bicubic": PIL.Image.BICUBIC,
+ "lanczos": PIL.Image.LANCZOS,
+ "nearest": PIL.Image.NEAREST,
+ }
+# ------------------------------------------------------------------------------
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.24.0.dev0")
+
+logger = logging.getLogger(__name__)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
+ )
+ parser.add_argument(
+ "--placeholder_token",
+ type=str,
+ default=None,
+ required=True,
+ help="A token to use as a placeholder for the concept.",
+ )
+ parser.add_argument(
+ "--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word."
+ )
+ parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
+ parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="text-inversion-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution."
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=5000,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--save_steps",
+ type=int,
+ default=500,
+ help="Save learned_embeds.bin every X updates steps.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=True,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument(
+ "--use_auth_token",
+ action="store_true",
+ help=(
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script with"
+ " private models)."
+ ),
+ )
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.train_data_dir is None:
+ raise ValueError("You must specify a train data directory.")
+
+ return args
+
+
+imagenet_templates_small = [
+ "a photo of a {}",
+ "a rendering of a {}",
+ "a cropped photo of the {}",
+ "the photo of a {}",
+ "a photo of a clean {}",
+ "a photo of a dirty {}",
+ "a dark photo of the {}",
+ "a photo of my {}",
+ "a photo of the cool {}",
+ "a close-up photo of a {}",
+ "a bright photo of the {}",
+ "a cropped photo of a {}",
+ "a photo of the {}",
+ "a good photo of the {}",
+ "a photo of one {}",
+ "a close-up photo of the {}",
+ "a rendition of the {}",
+ "a photo of the clean {}",
+ "a rendition of a {}",
+ "a photo of a nice {}",
+ "a good photo of a {}",
+ "a photo of the nice {}",
+ "a photo of the small {}",
+ "a photo of the weird {}",
+ "a photo of the large {}",
+ "a photo of a cool {}",
+ "a photo of a small {}",
+]
+
+imagenet_style_templates_small = [
+ "a painting in the style of {}",
+ "a rendering in the style of {}",
+ "a cropped painting in the style of {}",
+ "the painting in the style of {}",
+ "a clean painting in the style of {}",
+ "a dirty painting in the style of {}",
+ "a dark painting in the style of {}",
+ "a picture in the style of {}",
+ "a cool painting in the style of {}",
+ "a close-up painting in the style of {}",
+ "a bright painting in the style of {}",
+ "a cropped painting in the style of {}",
+ "a good painting in the style of {}",
+ "a close-up painting in the style of {}",
+ "a rendition in the style of {}",
+ "a nice painting in the style of {}",
+ "a small painting in the style of {}",
+ "a weird painting in the style of {}",
+ "a large painting in the style of {}",
+]
+
+
+class TextualInversionDataset(Dataset):
+ def __init__(
+ self,
+ data_root,
+ tokenizer,
+ learnable_property="object", # [object, style]
+ size=512,
+ repeats=100,
+ interpolation="bicubic",
+ flip_p=0.5,
+ set="train",
+ placeholder_token="*",
+ center_crop=False,
+ ):
+ self.data_root = data_root
+ self.tokenizer = tokenizer
+ self.learnable_property = learnable_property
+ self.size = size
+ self.placeholder_token = placeholder_token
+ self.center_crop = center_crop
+ self.flip_p = flip_p
+
+ self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
+
+ self.num_images = len(self.image_paths)
+ self._length = self.num_images
+
+ if set == "train":
+ self._length = self.num_images * repeats
+
+ self.interpolation = {
+ "linear": PIL_INTERPOLATION["linear"],
+ "bilinear": PIL_INTERPOLATION["bilinear"],
+ "bicubic": PIL_INTERPOLATION["bicubic"],
+ "lanczos": PIL_INTERPOLATION["lanczos"],
+ }[interpolation]
+
+ self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
+ self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, i):
+ example = {}
+ image = Image.open(self.image_paths[i % self.num_images])
+
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+
+ placeholder_string = self.placeholder_token
+ text = random.choice(self.templates).format(placeholder_string)
+
+ example["input_ids"] = self.tokenizer(
+ text,
+ padding="max_length",
+ truncation=True,
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ ).input_ids[0]
+
+ # default to score-sde preprocessing
+ img = np.array(image).astype(np.uint8)
+
+ if self.center_crop:
+ crop = min(img.shape[0], img.shape[1])
+ (
+ h,
+ w,
+ ) = (
+ img.shape[0],
+ img.shape[1],
+ )
+ img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
+
+ image = Image.fromarray(img)
+ image = image.resize((self.size, self.size), resample=self.interpolation)
+
+ image = self.flip_transform(image)
+ image = np.array(image).astype(np.uint8)
+ image = (image / 127.5 - 1.0).astype(np.float32)
+
+ example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
+ return example
+
+
+def resize_token_embeddings(model, new_num_tokens, initializer_token_id, placeholder_token_id, rng):
+ if model.config.vocab_size == new_num_tokens or new_num_tokens is None:
+ return
+ model.config.vocab_size = new_num_tokens
+
+ params = model.params
+ old_embeddings = params["text_model"]["embeddings"]["token_embedding"]["embedding"]
+ old_num_tokens, emb_dim = old_embeddings.shape
+
+ initializer = jax.nn.initializers.normal()
+
+ new_embeddings = initializer(rng, (new_num_tokens, emb_dim))
+ new_embeddings = new_embeddings.at[:old_num_tokens].set(old_embeddings)
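+    # Seed the new placeholder token's embedding with the initializer token's embedding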
+ new_embeddings = new_embeddings.at[placeholder_token_id].set(new_embeddings[initializer_token_id])
+ params["text_model"]["embeddings"]["token_embedding"]["embedding"] = new_embeddings
+
+ model.params = params
+ return model
+
+
+def get_params_to_save(params):
+ return jax.device_get(jax.tree_util.tree_map(lambda x: x[0], params))
+
+
+def main():
+ args = parse_args()
+
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ if jax.process_index() == 0:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ # Setup logging, we only want one process per machine to log things on the screen.
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
+ if jax.process_index() == 0:
+ transformers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+
+    # Load the tokenizer and add the placeholder token as an additional special token
+ if args.tokenizer_name:
+ tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+ elif args.pretrained_model_name_or_path:
+ tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
+
+ # Add the placeholder token in tokenizer
+ num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
+ if num_added_tokens == 0:
+ raise ValueError(
+ f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
+ " `placeholder_token` that is not already in the tokenizer."
+ )
+
+ # Convert the initializer_token, placeholder_token to ids
+ token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
+ # Check if initializer_token is a single token or a sequence of tokens
+ if len(token_ids) > 1:
+ raise ValueError("The initializer token must be a single token.")
+
+ initializer_token_id = token_ids[0]
+ placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
+
+ # Load models and create wrapper for stable diffusion
+ text_encoder = FlaxCLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+ vae, vae_params = FlaxAutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
+ )
+ unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+ )
+
+ # Create sampling rng
+ rng = jax.random.PRNGKey(args.seed)
+ rng, _ = jax.random.split(rng)
+ # Resize the token embeddings as we are adding new special tokens to the tokenizer
+ text_encoder = resize_token_embeddings(
+ text_encoder, len(tokenizer), initializer_token_id, placeholder_token_id, rng
+ )
+ original_token_embeds = text_encoder.params["text_model"]["embeddings"]["token_embedding"]["embedding"]
+
+ train_dataset = TextualInversionDataset(
+ data_root=args.train_data_dir,
+ tokenizer=tokenizer,
+ size=args.resolution,
+ placeholder_token=args.placeholder_token,
+ repeats=args.repeats,
+ learnable_property=args.learnable_property,
+ center_crop=args.center_crop,
+ set="train",
+ )
+
+ def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ input_ids = torch.stack([example["input_ids"] for example in examples])
+
+ batch = {"pixel_values": pixel_values, "input_ids": input_ids}
+ batch = {k: v.numpy() for k, v in batch.items()}
+
+ return batch
+
+ total_train_batch_size = args.train_batch_size * jax.local_device_count()
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=total_train_batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn
+ )
+
+ # Optimization
+ if args.scale_lr:
+ args.learning_rate = args.learning_rate * total_train_batch_size
+
+ constant_scheduler = optax.constant_schedule(args.learning_rate)
+
+ optimizer = optax.adamw(
+ learning_rate=constant_scheduler,
+ b1=args.adam_beta1,
+ b2=args.adam_beta2,
+ eps=args.adam_epsilon,
+ weight_decay=args.adam_weight_decay,
+ )
+
+ def create_mask(params, label_fn):
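+        # Build a pytree of labels mirroring `params`: the token embedding subtree is labeled
+        # "token_embedding", every other weight is labeled "zero", so optax.multi_transform
+        # below only updates the embedding table and zeroes all other gradients.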
+ def _map(params, mask, label_fn):
+ for k in params:
+ if label_fn(k):
+ mask[k] = "token_embedding"
+ else:
+ if isinstance(params[k], dict):
+ mask[k] = {}
+ _map(params[k], mask[k], label_fn)
+ else:
+ mask[k] = "zero"
+
+ mask = {}
+ _map(params, mask, label_fn)
+ return mask
+
+ def zero_grads():
+ # from https://github.com/deepmind/optax/issues/159#issuecomment-896459491
+ def init_fn(_):
+ return ()
+
+ def update_fn(updates, state, params=None):
+ return jax.tree_util.tree_map(jnp.zeros_like, updates), ()
+
+ return optax.GradientTransformation(init_fn, update_fn)
+
+ # Zero out gradients of layers other than the token embedding layer
+ tx = optax.multi_transform(
+ {"token_embedding": optimizer, "zero": zero_grads()},
+ create_mask(text_encoder.params, lambda s: s == "token_embedding"),
+ )
+
+ state = train_state.TrainState.create(apply_fn=text_encoder.__call__, params=text_encoder.params, tx=tx)
+
+ noise_scheduler = FlaxDDPMScheduler(
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
+ )
+ noise_scheduler_state = noise_scheduler.create_state()
+
+ # Initialize our training
+ train_rngs = jax.random.split(rng, jax.local_device_count())
+
+ # Define gradient train step fn
+ def train_step(state, vae_params, unet_params, batch, train_rng):
+ dropout_rng, sample_rng, new_train_rng = jax.random.split(train_rng, 3)
+
+ def compute_loss(params):
+ vae_outputs = vae.apply(
+ {"params": vae_params}, batch["pixel_values"], deterministic=True, method=vae.encode
+ )
+ latents = vae_outputs.latent_dist.sample(sample_rng)
+ # (NHWC) -> (NCHW)
+ latents = jnp.transpose(latents, (0, 3, 1, 2))
+ latents = latents * vae.config.scaling_factor
+
+ noise_rng, timestep_rng = jax.random.split(sample_rng)
+ noise = jax.random.normal(noise_rng, latents.shape)
+ bsz = latents.shape[0]
+ timesteps = jax.random.randint(
+ timestep_rng,
+ (bsz,),
+ 0,
+ noise_scheduler.config.num_train_timesteps,
+ )
+ noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)
+ encoder_hidden_states = state.apply_fn(
+ batch["input_ids"], params=params, dropout_rng=dropout_rng, train=True
+ )[0]
+ # Predict the noise residual and compute loss
+ model_pred = unet.apply(
+ {"params": unet_params}, noisy_latents, timesteps, encoder_hidden_states, train=False
+ ).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ loss = (target - model_pred) ** 2
+ loss = loss.mean()
+
+ return loss
+
+ grad_fn = jax.value_and_grad(compute_loss)
+ loss, grad = grad_fn(state.params)
+ grad = jax.lax.pmean(grad, "batch")
+ new_state = state.apply_gradients(grads=grad)
+
+ # Keep the token embeddings fixed except the newly added embeddings for the concept,
+ # as we only want to optimize the concept embeddings
+ token_embeds = original_token_embeds.at[placeholder_token_id].set(
+ new_state.params["text_model"]["embeddings"]["token_embedding"]["embedding"][placeholder_token_id]
+ )
+ new_state.params["text_model"]["embeddings"]["token_embedding"]["embedding"] = token_embeds
+
+ metrics = {"loss": loss}
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
+ return new_state, metrics, new_train_rng
+
+ # Create parallel version of the train and eval step
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
+
+ # Replicate the train state on each device
+ state = jax_utils.replicate(state)
+ vae_params = jax_utils.replicate(vae_params)
+ unet_params = jax_utils.replicate(unet_params)
+
+ # Train!
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader))
+
+ # Scheduler and math around the number of training steps.
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+
+ global_step = 0
+
+ epochs = tqdm(range(args.num_train_epochs), desc=f"Epoch ... (1/{args.num_train_epochs})", position=0)
+ for epoch in epochs:
+ # ======================== Training ================================
+
+ train_metrics = []
+
+ steps_per_epoch = len(train_dataset) // total_train_batch_size
+ train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False)
+ # train
+ for batch in train_dataloader:
+ batch = shard(batch)
+ state, train_metric, train_rngs = p_train_step(state, vae_params, unet_params, batch, train_rngs)
+ train_metrics.append(train_metric)
+
+ train_step_progress_bar.update(1)
+ global_step += 1
+
+ if global_step >= args.max_train_steps:
+ break
+ if global_step % args.save_steps == 0:
+ learned_embeds = get_params_to_save(state.params)["text_model"]["embeddings"]["token_embedding"][
+ "embedding"
+ ][placeholder_token_id]
+ learned_embeds_dict = {args.placeholder_token: learned_embeds}
+ jnp.save(
+ os.path.join(args.output_dir, "learned_embeds-" + str(global_step) + ".npy"), learned_embeds_dict
+ )
+
+ train_metric = jax_utils.unreplicate(train_metric)
+
+ train_step_progress_bar.close()
+ epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})")
+
+    # Create the pipeline using the trained modules and save it.
+ if jax.process_index() == 0:
+ scheduler = FlaxPNDMScheduler(
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
+ )
+ safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained(
+ "CompVis/stable-diffusion-safety-checker", from_pt=True
+ )
+ pipeline = FlaxStableDiffusionPipeline(
+ text_encoder=text_encoder,
+ vae=vae,
+ unet=unet,
+ tokenizer=tokenizer,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32"),
+ )
+
+ pipeline.save_pretrained(
+ args.output_dir,
+ params={
+ "text_encoder": get_params_to_save(state.params),
+ "vae": get_params_to_save(vae_params),
+ "unet": get_params_to_save(unet_params),
+ "safety_checker": safety_checker.params,
+ },
+ )
+
+ # Also save the newly trained embeddings
+ learned_embeds = get_params_to_save(state.params)["text_model"]["embeddings"]["token_embedding"]["embedding"][
+ placeholder_token_id
+ ]
+ learned_embeds_dict = {args.placeholder_token: learned_embeds}
+ jnp.save(os.path.join(args.output_dir, "learned_embeds.npy"), learned_embeds_dict)
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/unconditional_image_generation/README.md b/diffusers/examples/unconditional_image_generation/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..d83dc928c7a1164b3e8896bcfa1ef5d417ea6b80
--- /dev/null
+++ b/diffusers/examples/unconditional_image_generation/README.md
@@ -0,0 +1,163 @@
+## Training an unconditional diffusion model
+
+Creating a training image set is [described in a different document](https://huggingface.co/docs/datasets/image_process#image-datasets).
+
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Then cd into the example folder and run
+```bash
+pip install -r requirements.txt
+```
+
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+### Unconditional Flowers
+
+The command to train a DDPM UNet model on the Oxford Flowers dataset:
+
+```bash
+accelerate launch train_unconditional.py \
+ --dataset_name="huggan/flowers-102-categories" \
+ --resolution=64 --center_crop --random_flip \
+ --output_dir="ddpm-ema-flowers-64" \
+ --train_batch_size=16 \
+ --num_epochs=100 \
+ --gradient_accumulation_steps=1 \
+ --use_ema \
+ --learning_rate=1e-4 \
+ --lr_warmup_steps=500 \
+ --mixed_precision=no \
+ --push_to_hub
+```
+An example trained model: https://huggingface.co/anton-l/ddpm-ema-flowers-64
+
+A full training run takes 2 hours on 4xV100 GPUs.
+
+
+
+
+### Unconditional Pokemon
+
+The command to train a DDPM UNet model on the Pokemon dataset:
+
+```bash
+accelerate launch train_unconditional.py \
+ --dataset_name="huggan/pokemon" \
+ --resolution=64 --center_crop --random_flip \
+ --output_dir="ddpm-ema-pokemon-64" \
+ --train_batch_size=16 \
+ --num_epochs=100 \
+ --gradient_accumulation_steps=1 \
+ --use_ema \
+ --learning_rate=1e-4 \
+ --lr_warmup_steps=500 \
+ --mixed_precision=no \
+ --push_to_hub
+```
+An example trained model: https://huggingface.co/anton-l/ddpm-ema-pokemon-64
+
+A full training run takes 2 hours on 4xV100 GPUs.
+
+
+
+### Training with multiple GPUs
+
+`accelerate` allows for seamless multi-GPU training. Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch)
+for running distributed training with `accelerate`. Here is an example command:
+
+```bash
+accelerate launch --mixed_precision="fp16" --multi_gpu train_unconditional.py \
+ --dataset_name="huggan/pokemon" \
+ --resolution=64 --center_crop --random_flip \
+ --output_dir="ddpm-ema-pokemon-64" \
+ --train_batch_size=16 \
+ --num_epochs=100 \
+ --gradient_accumulation_steps=1 \
+ --use_ema \
+ --learning_rate=1e-4 \
+ --lr_warmup_steps=500 \
+ --mixed_precision="fp16" \
+ --logger="wandb"
+```
+
+To be able to use Weights and Biases (`wandb`) as a logger you need to install the library: `pip install wandb`.
+
+### Using your own data
+
+To use your own dataset, there are 2 ways:
+- you can either provide your own folder as `--train_data_dir`
+- or you can upload your dataset to the hub (possibly as a private repo, if you prefer), and simply pass the `--dataset_name` argument.
+
+Below, we explain both in more detail.
+
+#### Provide the dataset as a folder
+
+If you provide your own folders with images, the script expects the following directory structure:
+
+```bash
+data_dir/xxx.png
+data_dir/xxy.png
+data_dir/[...]/xxz.png
+```
+
+In other words, the script will take care of gathering all images inside the folder. You can then run the script like this:
+
+```bash
+accelerate launch train_unconditional.py \
+    --train_data_dir <path-to-train-directory> \
+    <other-arguments>
+```
+
+Internally, the script will use the [`ImageFolder`](https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder) feature which will automatically turn the folders into 🤗 Dataset objects.
+
+#### Upload your data to the hub, as a (possibly private) repo
+
+It's very easy (and convenient) to upload your image dataset to the hub using the [`ImageFolder`](https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder) feature available in 🤗 Datasets. Simply do the following:
+
+```python
+from datasets import load_dataset
+
+# example 1: local folder
+dataset = load_dataset("imagefolder", data_dir="path_to_your_folder")
+
+# example 2: local files (supported formats are tar, gzip, zip, xz, rar, zstd)
+dataset = load_dataset("imagefolder", data_files="path_to_zip_file")
+
+# example 3: remote files (supported formats are tar, gzip, zip, xz, rar, zstd)
+dataset = load_dataset("imagefolder", data_files="https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_3367a.zip")
+
+# example 4: providing several splits
+dataset = load_dataset("imagefolder", data_files={"train": ["path/to/file1", "path/to/file2"], "test": ["path/to/file3", "path/to/file4"]})
+```
+
+`ImageFolder` will create an `image` column containing the PIL-encoded images.
+
+Next, push it to the hub!
+
+```python
+# assuming you have run the huggingface-cli login command in a terminal
+dataset.push_to_hub("name_of_your_dataset")
+
+# if you want to push to a private repo, simply pass private=True:
+dataset.push_to_hub("name_of_your_dataset", private=True)
+```
+
+and that's it! You can now train your model by simply setting the `--dataset_name` argument to the name of your dataset on the hub.
+
+More on this can also be found in [this blog post](https://huggingface.co/blog/image-search-datasets).
diff --git a/diffusers/examples/unconditional_image_generation/requirements.txt b/diffusers/examples/unconditional_image_generation/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f366720afd11e41945a3f29472be2048dbf98404
--- /dev/null
+++ b/diffusers/examples/unconditional_image_generation/requirements.txt
@@ -0,0 +1,3 @@
+accelerate>=0.16.0
+torchvision
+datasets
diff --git a/diffusers/examples/unconditional_image_generation/train_unconditional.py b/diffusers/examples/unconditional_image_generation/train_unconditional.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e552c9b3dde2dd9cbe7ad6a5c1a1829c1848390
--- /dev/null
+++ b/diffusers/examples/unconditional_image_generation/train_unconditional.py
@@ -0,0 +1,704 @@
+import argparse
+import inspect
+import logging
+import math
+import os
+import shutil
+from datetime import timedelta
+from pathlib import Path
+
+import accelerate
+import datasets
+import torch
+import torch.nn.functional as F
+from accelerate import Accelerator, InitProcessGroupKwargs
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from torchvision import transforms
+from tqdm.auto import tqdm
+
+import diffusers
+from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel
+from diffusers.utils import check_min_version, is_accelerate_version, is_tensorboard_available, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.24.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+
+def _extract_into_tensor(arr, timesteps, broadcast_shape):
+ """
+ Extract values from a 1-D numpy array for a batch of indices.
+
+ :param arr: the 1-D numpy array.
+ :param timesteps: a tensor of indices into the array to extract.
+ :param broadcast_shape: a larger shape of K dimensions with the batch
+ dimension equal to the length of timesteps.
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
+ """
+ if not isinstance(arr, torch.Tensor):
+ arr = torch.from_numpy(arr)
+ res = arr[timesteps].float().to(timesteps.device)
+ while len(res.shape) < len(broadcast_shape):
+ res = res[..., None]
+ return res.expand(broadcast_shape)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that HF Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--model_config_name_or_path",
+ type=str,
+ default=None,
+ help="The config of the UNet model to train, leave as None to use standard DDPM configuration.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="ddpm-model-64",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--overwrite_output_dir", action="store_true")
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=64,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ default=False,
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--eval_batch_size", type=int, default=16, help="The number of images to generate for evaluation."
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
+ " process."
+ ),
+ )
+ parser.add_argument("--num_epochs", type=int, default=100)
+ parser.add_argument("--save_images_epochs", type=int, default=10, help="How often to save images during training.")
+ parser.add_argument(
+ "--save_model_epochs", type=int, default=10, help="How often to save the model during training."
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="cosine",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.95, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument(
+ "--adam_weight_decay", type=float, default=1e-6, help="Weight decay magnitude for the Adam optimizer."
+ )
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer.")
+ parser.add_argument(
+ "--use_ema",
+ action="store_true",
+ help="Whether to use Exponential Moving Average for the final model weights.",
+ )
+ parser.add_argument("--ema_inv_gamma", type=float, default=1.0, help="The inverse gamma value for the EMA decay.")
+ parser.add_argument("--ema_power", type=float, default=3 / 4, help="The power value for the EMA decay.")
+ parser.add_argument("--ema_max_decay", type=float, default=0.9999, help="The maximum decay magnitude for EMA.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--hub_private_repo", action="store_true", help="Whether or not to create a private repository."
+ )
+ parser.add_argument(
+ "--logger",
+ type=str,
+ default="tensorboard",
+ choices=["tensorboard", "wandb"],
+ help=(
+ "Whether to use [tensorboard](https://www.tensorflow.org/tensorboard) or [wandb](https://www.wandb.ai)"
+ " for experiment tracking and logging of model metrics and model checkpoints"
+ ),
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default="no",
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose"
+ "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
+ "and an Nvidia Ampere GPU."
+ ),
+ )
+ parser.add_argument(
+ "--prediction_type",
+ type=str,
+ default="epsilon",
+ choices=["epsilon", "sample"],
+ help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
+ )
+ parser.add_argument("--ddpm_num_steps", type=int, default=1000)
+ parser.add_argument("--ddpm_num_inference_steps", type=int, default=1000)
+ parser.add_argument("--ddpm_beta_schedule", type=str, default="linear")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("You must specify either a dataset name from the hub or a train data directory.")
+
+ return args
+
+
+def main(args):
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=7200)) # a big number for high resolution or big dataset
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.logger,
+ project_config=accelerator_project_config,
+ kwargs_handlers=[kwargs],
+ )
+
+ if args.logger == "tensorboard":
+ if not is_tensorboard_available():
+ raise ImportError("Make sure to install tensorboard if you want to use it for logging during training.")
+
+ elif args.logger == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+ import wandb
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ if args.use_ema:
+ ema_model.save_pretrained(os.path.join(output_dir, "unet_ema"))
+
+ for i, model in enumerate(models):
+ model.save_pretrained(os.path.join(output_dir, "unet"))
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ if args.use_ema:
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DModel)
+ ema_model.load_state_dict(load_model.state_dict())
+ ema_model.to(accelerator.device)
+ del load_model
+
+ for i in range(len(models)):
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = UNet2DModel.from_pretrained(input_dir, subfolder="unet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Initialize the model
+ if args.model_config_name_or_path is None:
+ model = UNet2DModel(
+ sample_size=args.resolution,
+ in_channels=3,
+ out_channels=3,
+ layers_per_block=2,
+ block_out_channels=(128, 128, 256, 256, 512, 512),
+ down_block_types=(
+ "DownBlock2D",
+ "DownBlock2D",
+ "DownBlock2D",
+ "DownBlock2D",
+ "AttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ up_block_types=(
+ "UpBlock2D",
+ "AttnUpBlock2D",
+ "UpBlock2D",
+ "UpBlock2D",
+ "UpBlock2D",
+ "UpBlock2D",
+ ),
+ )
+ else:
+ config = UNet2DModel.load_config(args.model_config_name_or_path)
+ model = UNet2DModel.from_config(config)
+
+ # Create EMA for the model.
+ if args.use_ema:
+ ema_model = EMAModel(
+ model.parameters(),
+ decay=args.ema_max_decay,
+ use_ema_warmup=True,
+ inv_gamma=args.ema_inv_gamma,
+ power=args.ema_power,
+ model_cls=UNet2DModel,
+ model_config=model.config,
+ )
+
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ args.mixed_precision = accelerator.mixed_precision
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+ args.mixed_precision = accelerator.mixed_precision
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warn(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ model.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # Initialize the scheduler
+ accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
+ if accepts_prediction_type:
+ noise_scheduler = DDPMScheduler(
+ num_train_timesteps=args.ddpm_num_steps,
+ beta_schedule=args.ddpm_beta_schedule,
+ prediction_type=args.prediction_type,
+ )
+ else:
+ noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)
+
+ # Initialize the optimizer
+ optimizer = torch.optim.AdamW(
+ model.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ split="train",
+ )
+ else:
+ dataset = load_dataset("imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train")
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets and DataLoaders creation.
+ augmentations = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
+ transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def transform_images(examples):
+ images = [augmentations(image.convert("RGB")) for image in examples["image"]]
+ return {"input": images}
+
+ logger.info(f"Dataset size: {len(dataset)}")
+
+ dataset.set_transform(transform_images)
+ train_dataloader = torch.utils.data.DataLoader(
+ dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
+ )
+
+ # Initialize the learning rate scheduler
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=(len(train_dataloader) * args.num_epochs),
+ )
+
+ # Prepare everything with our `accelerator`.
+ model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ model, optimizer, train_dataloader, lr_scheduler
+ )
+
+ if args.use_ema:
+ ema_model.to(accelerator.device)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ run = os.path.split(__file__)[-1].split(".")[0]
+ accelerator.init_trackers(run)
+
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ max_train_steps = args.num_epochs * num_update_steps_per_epoch
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(dataset)}")
+ logger.info(f" Num Epochs = {args.num_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {max_train_steps}")
+
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ resume_global_step = global_step * args.gradient_accumulation_steps
+ first_epoch = global_step // num_update_steps_per_epoch
+ resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+
+ # Train!
+ for epoch in range(first_epoch, args.num_epochs):
+ model.train()
+ progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process)
+ progress_bar.set_description(f"Epoch {epoch}")
+ for step, batch in enumerate(train_dataloader):
+ # Skip steps until we reach the resumed step
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
+ if step % args.gradient_accumulation_steps == 0:
+ progress_bar.update(1)
+ continue
+
+ clean_images = batch["input"].to(weight_dtype)
+ # Sample noise that we'll add to the images
+ noise = torch.randn(clean_images.shape, dtype=weight_dtype, device=clean_images.device)
+ bsz = clean_images.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=clean_images.device
+ ).long()
+
+ # Add noise to the clean images according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
+
+ with accelerator.accumulate(model):
+ # Predict the noise residual
+ model_output = model(noisy_images, timesteps).sample
+
+ if args.prediction_type == "epsilon":
+ loss = F.mse_loss(model_output.float(), noise.float()) # this could have different weights!
+ elif args.prediction_type == "sample":
+ alpha_t = _extract_into_tensor(
+ noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1)
+ )
+ snr_weights = alpha_t / (1 - alpha_t)
+ # use SNR weighting from distillation paper
+ loss = snr_weights * F.mse_loss(model_output.float(), clean_images.float(), reduction="none")
+ loss = loss.mean()
+ else:
+ raise ValueError(f"Unsupported prediction type: {args.prediction_type}")
+
+ accelerator.backward(loss)
+
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(model.parameters(), 1.0)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ if args.use_ema:
+ ema_model.step(model.parameters())
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
+ if args.use_ema:
+ logs["ema_decay"] = ema_model.cur_decay_value
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+ progress_bar.close()
+
+ accelerator.wait_for_everyone()
+
+ # Generate sample images for visual inspection
+ if accelerator.is_main_process:
+ if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1:
+ unet = accelerator.unwrap_model(model)
+
+ if args.use_ema:
+ ema_model.store(unet.parameters())
+ ema_model.copy_to(unet.parameters())
+
+ pipeline = DDPMPipeline(
+ unet=unet,
+ scheduler=noise_scheduler,
+ )
+
+ generator = torch.Generator(device=pipeline.device).manual_seed(0)
+ # run pipeline in inference (sample random noise and denoise)
+ images = pipeline(
+ generator=generator,
+ batch_size=args.eval_batch_size,
+ num_inference_steps=args.ddpm_num_inference_steps,
+ output_type="numpy",
+ ).images
+
+ if args.use_ema:
+ ema_model.restore(unet.parameters())
+
+ # denormalize the images and save to tensorboard
+ images_processed = (images * 255).round().astype("uint8")
+
+ if args.logger == "tensorboard":
+ if is_accelerate_version(">=", "0.17.0.dev0"):
+ tracker = accelerator.get_tracker("tensorboard", unwrap=True)
+ else:
+ tracker = accelerator.get_tracker("tensorboard")
+ tracker.add_images("test_samples", images_processed.transpose(0, 3, 1, 2), epoch)
+ elif args.logger == "wandb":
+ # Upcoming `log_images` helper coming in https://github.com/huggingface/accelerate/pull/962/files
+ accelerator.get_tracker("wandb").log(
+ {"test_samples": [wandb.Image(img) for img in images_processed], "epoch": epoch},
+ step=global_step,
+ )
+
+ if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
+ # save the model
+ unet = accelerator.unwrap_model(model)
+
+ if args.use_ema:
+ ema_model.store(unet.parameters())
+ ema_model.copy_to(unet.parameters())
+
+ pipeline = DDPMPipeline(
+ unet=unet,
+ scheduler=noise_scheduler,
+ )
+
+ pipeline.save_pretrained(args.output_dir)
+
+ if args.use_ema:
+ ema_model.restore(unet.parameters())
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message=f"Epoch {epoch}",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/diffusers/examples/wuerstchen/text_to_image/README.md b/diffusers/examples/wuerstchen/text_to_image/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..5378e3ef5253e6b20e4b9fd775f0a058ada96b0a
--- /dev/null
+++ b/diffusers/examples/wuerstchen/text_to_image/README.md
@@ -0,0 +1,93 @@
+# Würstchen text-to-image fine-tuning
+
+## Running locally with PyTorch
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Then cd into the example folder and run
+```bash
+cd examples/wuerstchen/text_to_image
+pip install -r requirements.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+For this example we want to directly store the trained LoRA embeddings on the Hub, so we need to be logged in and add the `--push_to_hub` flag to the training script. To log in, run:
+```bash
+huggingface-cli login
+```
+
+## Prior training
+
+You can fine-tune the Würstchen prior model with the `train_text_to_image_prior.py` script. Note that we currently support `--gradient_checkpointing` for prior model fine-tuning, so you can use it in more GPU-memory-constrained setups.
+
+
+
+
+```bash
+export DATASET_NAME="lambdalabs/pokemon-blip-captions"
+
+accelerate launch train_text_to_image_prior.py \
+ --mixed_precision="fp16" \
+ --dataset_name=$DATASET_NAME \
+ --resolution=768 \
+ --train_batch_size=4 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --dataloader_num_workers=4 \
+ --max_train_steps=15000 \
+ --learning_rate=1e-05 \
+ --max_grad_norm=1 \
+ --checkpoints_total_limit=3 \
+ --lr_scheduler="constant" --lr_warmup_steps=0 \
+ --validation_prompts="A robot pokemon, 4k photo" \
+ --report_to="wandb" \
+ --push_to_hub \
+ --output_dir="wuerstchen-prior-pokemon-model"
+```
+
+
+## Training with LoRA
+
+Low-Rank Adaption of Large Language Models (or LoRA) was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
+
+In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights (a minimal sketch follows the list below). This has a couple of advantages:
+
+- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
+- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
+- LoRA attention layers allow controlling the extent to which the model is adapted toward new training images via a `scale` parameter.
+
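+To make this concrete, here is a minimal sketch of the idea in PyTorch: a frozen linear layer wrapped with a trainable low-rank update. It is illustrative only; the class name and the `rank`/`scale` values are assumptions, not the exact modules this training script patches.
+
+```python
+import torch.nn as nn
+
+
+class LoRALinear(nn.Module):
+    """Illustrative only: y = W x + scale * B(A x), with W frozen and only A, B trained."""
+
+    def __init__(self, base: nn.Linear, rank: int = 4, scale: float = 1.0):
+        super().__init__()
+        self.base = base
+        self.base.requires_grad_(False)  # pretrained weights stay frozen
+        self.lora_a = nn.Linear(base.in_features, rank, bias=False)   # down-projection (A)
+        self.lora_b = nn.Linear(rank, base.out_features, bias=False)  # up-projection (B)
+        nn.init.zeros_(self.lora_b.weight)  # start as a no-op so training begins at the pretrained model
+        self.scale = scale
+
+    def forward(self, x):
+        return self.base(x) + self.scale * self.lora_b(self.lora_a(x))
+
+
+# Only the rank-decomposition matrices are trainable: for a 1024x1024 layer and rank 4,
+# that is 2 * 1024 * 4 = 8192 parameters instead of roughly a million.
+layer = LoRALinear(nn.Linear(1024, 1024), rank=4)
+print(sum(p.numel() for p in layer.parameters() if p.requires_grad))  # 8192
+```
+
+Because only these small matrices receive gradients (and are saved), the resulting LoRA weights stay tiny and easy to share.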
+
+### Prior Training
+
+First, you need to set up your development environment as explained in the [installation](#Running-locally-with-PyTorch) section. Make sure to set the `DATASET_NAME` environment variable. Here, we will use the [Pokemon captions dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions).
+
+```bash
+export DATASET_NAME="lambdalabs/pokemon-blip-captions"
+
+accelerate launch train_text_to_image_prior_lora.py \
+ --mixed_precision="fp16" \
+ --dataset_name=$DATASET_NAME --caption_column="text" \
+ --resolution=768 \
+ --train_batch_size=8 \
+ --num_train_epochs=100 --checkpointing_steps=5000 \
+ --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
+ --seed=42 \
+ --rank=4 \
+ --validation_prompt="cute dragon creature" \
+ --report_to="wandb" \
+ --push_to_hub \
+ --output_dir="wuerstchen-prior-pokemon-lora"
+```
diff --git a/diffusers/examples/wuerstchen/text_to_image/__init__.py b/diffusers/examples/wuerstchen/text_to_image/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/diffusers/examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py b/diffusers/examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd551ebf1623af6d03cff998b258f19ee1294245
--- /dev/null
+++ b/diffusers/examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py
@@ -0,0 +1,23 @@
+import torch.nn as nn
+from torchvision.models import efficientnet_v2_l, efficientnet_v2_s
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.modeling_utils import ModelMixin
+
+
+class EfficientNetEncoder(ModelMixin, ConfigMixin):
+ @register_to_config
+ def __init__(self, c_latent=16, c_cond=1280, effnet="efficientnet_v2_s"):
+ super().__init__()
+
+ if effnet == "efficientnet_v2_s":
+ self.backbone = efficientnet_v2_s(weights="DEFAULT").features
+ else:
+ self.backbone = efficientnet_v2_l(weights="DEFAULT").features
+ self.mapper = nn.Sequential(
+ nn.Conv2d(c_cond, c_latent, kernel_size=1, bias=False),
+ nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1
+ )
+
+ def forward(self, x):
+ return self.mapper(self.backbone(x))
diff --git a/diffusers/examples/wuerstchen/text_to_image/requirements.txt b/diffusers/examples/wuerstchen/text_to_image/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a58ad09eca55e95f3d28da348fff8d7f40046764
--- /dev/null
+++ b/diffusers/examples/wuerstchen/text_to_image/requirements.txt
@@ -0,0 +1,7 @@
+accelerate>=0.16.0
+torchvision
+transformers>=4.25.1
+wandb
+huggingface-cli
+bitsandbytes
+deepspeed
diff --git a/diffusers/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py b/diffusers/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
new file mode 100644
index 0000000000000000000000000000000000000000..33de3d3bf777d52a42fa5a61669926867109e0cb
--- /dev/null
+++ b/diffusers/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
@@ -0,0 +1,888 @@
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.state import AcceleratorState, is_initialized
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, hf_hub_download, upload_folder
+from modeling_efficient_net_encoder import EfficientNetEncoder
+from torchvision import transforms
+from tqdm import tqdm
+from transformers import CLIPTextModel, PreTrainedTokenizerFast
+from transformers.utils import ContextManagers
+
+from diffusers import AutoPipelineForText2Image, DDPMWuerstchenScheduler, WuerstchenPriorPipeline
+from diffusers.loaders import AttnProcsLayers
+from diffusers.models.attention_processor import LoRAAttnProcessor
+from diffusers.optimization import get_scheduler
+from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS, WuerstchenPrior
+from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
+from diffusers.utils.logging import set_verbosity_error, set_verbosity_info
+
+
+if is_wandb_available():
+ import wandb
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.24.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+DATASET_NAME_MAPPING = {
+ "lambdalabs/pokemon-blip-captions": ("image", "text"),
+}
+
+
+def save_model_card(
+ args,
+ repo_id: str,
+ images=None,
+ repo_folder=None,
+):
+ img_str = ""
+ if len(images) > 0:
+ image_grid = make_image_grid(images, 1, len(args.validation_prompts))
+ image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png"))
+ img_str += "![val_imgs_grid](./val_imgs_grid.png)\n"
+
+ yaml = f"""
+---
+license: mit
+base_model: {args.pretrained_prior_model_name_or_path}
+datasets:
+- {args.dataset_name}
+tags:
+- wuerstchen
+- text-to-image
+- diffusers
+- lora
+inference: true
+---
+ """
+ model_card = f"""
+# LoRA Finetuning - {repo_id}
+
+This pipeline was finetuned from **{args.pretrained_prior_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \n
+{img_str}
+
+## Pipeline usage
+
+You can use the pipeline like so:
+
+```python
+from diffusers import AutoPipelineForText2Image
+import torch
+
+pipeline = AutoPipelineForText2Image.from_pretrained(
+ "{args.pretrained_decoder_model_name_or_path}", torch_dtype={args.weight_dtype}
+ )
+# load lora weights from folder:
+pipeline.prior_pipe.load_lora_weights("{repo_id}", torch_dtype={args.weight_dtype})
+
+prompt = "A robot pokemon, 4k photo"  # example prompt; swap in your own
+image = pipeline(prompt=prompt).images[0]
+image.save("my_image.png")
+```
+
+## Training info
+
+These are the key hyperparameters used during training:
+
+* LoRA rank: {args.rank}
+* Epochs: {args.num_train_epochs}
+* Learning rate: {args.learning_rate}
+* Batch size: {args.train_batch_size}
+* Gradient accumulation steps: {args.gradient_accumulation_steps}
+* Image resolution: {args.resolution}
+* Mixed-precision: {args.mixed_precision}
+
+"""
+ wandb_info = ""
+ if is_wandb_available():
+ wandb_run_url = None
+ if wandb.run is not None:
+ wandb_run_url = wandb.run.url
+
+ if wandb_run_url is not None:
+ wandb_info = f"""
+More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}).
+"""
+
+ model_card += wandb_info
+
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
+ f.write(yaml + model_card)
+
+
+def log_validation(text_encoder, tokenizer, attn_processors, args, accelerator, weight_dtype, epoch):
+ logger.info("Running validation... ")
+
+ pipeline = AutoPipelineForText2Image.from_pretrained(
+ args.pretrained_decoder_model_name_or_path,
+ prior_text_encoder=accelerator.unwrap_model(text_encoder),
+ prior_tokenizer=tokenizer,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.prior_prior.set_attn_processor(attn_processors)
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ images = []
+ for i in range(len(args.validation_prompts)):
+ with torch.autocast("cuda"):
+ image = pipeline(
+ args.validation_prompts[i],
+ prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
+ generator=generator,
+ height=args.resolution,
+ width=args.resolution,
+ ).images[0]
+
+ images.append(image)
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+ else:
+ logger.warn(f"image logging not implemented for {tracker.name}")
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ return images
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of finetuning Würstchen Prior.")
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+ parser.add_argument(
+ "--pretrained_decoder_model_name_or_path",
+ type=str,
+ default="warp-ai/wuerstchen",
+ required=False,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_prior_model_name_or_path",
+ type=str,
+ default="warp-ai/wuerstchen-prior",
+ required=False,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--validation_prompts",
+ type=str,
+ default=None,
+ nargs="+",
+ help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="wuerstchen-model-finetuned-lora",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images; all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of update steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="learning rate",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument(
+ "--adam_weight_decay",
+ type=float,
+ default=0.0,
+ required=False,
+ help="Weight decay to use.",
+ )
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=5,
+ help="Run validation every X epochs.",
+ )
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="text2image-fine-tune",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers. For"
+ " more information, see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ return args
+
+
+def main():
+ args = parse_args()
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+ accelerator_project_config = ProjectConfiguration(
+ total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
+ )
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load scheduler, effnet, tokenizer, clip_model
+ noise_scheduler = DDPMWuerstchenScheduler()
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(
+ args.pretrained_prior_model_name_or_path, subfolder="tokenizer"
+ )
+
+ def deepspeed_zero_init_disabled_context_manager():
+ """
+ Returns either a list containing a context manager that disables zero.Init, or an empty list.
+ """
+ deepspeed_plugin = AcceleratorState().deepspeed_plugin if is_initialized() else None
+ if deepspeed_plugin is None:
+ return []
+
+ return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
+
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
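+ # zero.Init is disabled while the frozen encoders below are loaded so that DeepSpeed ZeRO-3 does not
+ # try to shard models that are never passed to accelerator.prepare.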
+ with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
+ pretrained_checkpoint_file = hf_hub_download("dome272/wuerstchen", filename="model_v2_stage_b.pt")
+ state_dict = torch.load(pretrained_checkpoint_file, map_location="cpu")
+ image_encoder = EfficientNetEncoder()
+ image_encoder.load_state_dict(state_dict["effnet_state_dict"])
+ image_encoder.eval()
+
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_prior_model_name_or_path, subfolder="text_encoder", torch_dtype=weight_dtype
+ ).eval()
+
+ # Freeze text_encoder and image_encoder, cast them to weight_dtype, and move them to the device
+ text_encoder.requires_grad_(False)
+ image_encoder.requires_grad_(False)
+ image_encoder.to(accelerator.device, dtype=weight_dtype)
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ # load prior model, cast to weight_dtype and move to device
+ prior = WuerstchenPrior.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")
+ prior.to(accelerator.device, dtype=weight_dtype)
+
+ # lora attn processor
+ lora_attn_procs = {}
+ for name in prior.attn_processors.keys():
+ lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=prior.config["c"], rank=args.rank)
+ prior.set_attn_processor(lora_attn_procs)
+ lora_layers = AttnProcsLayers(prior.attn_processors)
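+ # Only the low-rank adapter parameters wrapped in `lora_layers` are handed to the optimizer below;
+ # the base prior weights are left untouched.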
+
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
+ )
+
+ optimizer_cls = bnb.optim.AdamW8bit
+ else:
+ optimizer_cls = torch.optim.AdamW
+ optimizer = optimizer_cls(
+ lora_layers.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # Get the column names for input/target.
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
+ if args.image_column is None:
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.caption_column is None:
+ caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
+ )
+
+ # Preprocessing the datasets.
+ # We need to tokenize input captions and transform the images
+ def tokenize_captions(examples, is_train=True):
+ captions = []
+ for caption in examples[caption_column]:
+ if isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+ else:
+ raise ValueError(
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
+ )
+ inputs = tokenizer(
+ captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ text_input_ids = inputs.input_ids
+ text_mask = inputs.attention_mask.bool()
+ return text_input_ids, text_mask
+
+ effnet_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True),
+ transforms.CenterCrop(args.resolution),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+ ]
+ )
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ examples["effnet_pixel_values"] = [effnet_transforms(image) for image in images]
+ examples["text_input_ids"], examples["text_mask"] = tokenize_captions(examples)
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ def collate_fn(examples):
+ effnet_pixel_values = torch.stack([example["effnet_pixel_values"] for example in examples])
+ effnet_pixel_values = effnet_pixel_values.to(memory_format=torch.contiguous_format).float()
+ text_input_ids = torch.stack([example["text_input_ids"] for example in examples])
+ text_mask = torch.stack([example["text_mask"] for example in examples])
+ return {"effnet_pixel_values": effnet_pixel_values, "text_input_ids": text_input_ids, "text_mask": text_mask}
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+ )
+
+ lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ lora_layers, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+ tracker_config.pop("validation_prompts")
+ accelerator.init_trackers(args.tracker_project_name, tracker_config)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ resume_global_step = global_step * args.gradient_accumulation_steps
+ first_epoch = global_step // num_update_steps_per_epoch
+ resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
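+ # Worked example with assumed numbers: after 1000 optimizer steps, with gradient_accumulation_steps=2
+ # and 400 update steps per epoch, training resumes at epoch 1000 // 400 = 2 and skips
+ # (1000 * 2) % (400 * 2) = 400 dataloader batches within that epoch.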
+
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
+ progress_bar.set_description("Steps")
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ prior.train()
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ # Skip steps until we reach the resumed step
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
+ if step % args.gradient_accumulation_steps == 0:
+ progress_bar.update(1)
+ continue
+
+ with accelerator.accumulate(prior):
+ # Convert images to latent space
+ text_input_ids, text_mask, effnet_images = (
+ batch["text_input_ids"],
+ batch["text_mask"],
+ batch["effnet_pixel_values"].to(weight_dtype),
+ )
+
+ with torch.no_grad():
+ text_encoder_output = text_encoder(text_input_ids, attention_mask=text_mask)
+ prompt_embeds = text_encoder_output.last_hidden_state
+ image_embeds = image_encoder(effnet_images)
+ # scale
+ image_embeds = image_embeds.add(1.0).div(42.0)
+
+ # Sample noise that we'll add to the image_embeds
+ noise = torch.randn_like(image_embeds)
+ bsz = image_embeds.shape[0]
+
+ # Sample a random timestep for each image
+ timesteps = torch.rand((bsz,), device=image_embeds.device, dtype=weight_dtype)
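+ # the Würstchen scheduler works with continuous timesteps drawn uniformly from [0, 1) rather than integer DDPM steps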
+
+ # add noise to latent
+ noisy_latents = noise_scheduler.add_noise(image_embeds, noise, timesteps)
+
+ # Predict the noise residual
+ pred_noise = prior(noisy_latents, timesteps, prompt_embeds)
+
+ # vanilla loss
+ loss = F.mse_loss(pred_noise.float(), noise.float(), reduction="mean")
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(lora_layers.parameters(), args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
+ log_validation(
+ text_encoder, tokenizer, prior.attn_processors, args, accelerator, weight_dtype, global_step
+ )
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ prior = prior.to(torch.float32)
+ WuerstchenPriorPipeline.save_lora_weights(
+ os.path.join(args.output_dir, "prior_lora"),
+ unet_lora_layers=lora_layers,
+ )
+
+ # Run a final round of inference.
+ images = []
+ if args.validation_prompts is not None:
+ logger.info("Running inference for collecting generated images...")
+ pipeline = AutoPipelineForText2Image.from_pretrained(
+ args.pretrained_decoder_model_name_or_path,
+ prior_text_encoder=accelerator.unwrap_model(text_encoder),
+ prior_tokenizer=tokenizer,
+ )
+ pipeline = pipeline.to(accelerator.device, torch_dtype=weight_dtype)
+ # load lora weights
+ pipeline.prior_pipe.load_lora_weights(os.path.join(args.output_dir, "prior_lora"))
+
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ for i in range(len(args.validation_prompts)):
+ with torch.autocast("cuda"):
+ image = pipeline(
+ args.validation_prompts[i],
+ prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
+ generator=generator,
+ width=args.resolution,
+ height=args.resolution,
+ ).images[0]
+ images.append(image)
+
+ if args.push_to_hub:
+ save_model_card(args, repo_id, images, repo_folder=args.output_dir)
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/examples/wuerstchen/text_to_image/train_text_to_image_prior.py b/diffusers/examples/wuerstchen/text_to_image/train_text_to_image_prior.py
new file mode 100644
index 0000000000000000000000000000000000000000..62450679f20131f473851c55a0e60dc8abc00147
--- /dev/null
+++ b/diffusers/examples/wuerstchen/text_to_image/train_text_to_image_prior.py
@@ -0,0 +1,925 @@
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+
+import accelerate
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.state import AcceleratorState, is_initialized
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, hf_hub_download, upload_folder
+from modeling_efficient_net_encoder import EfficientNetEncoder
+from packaging import version
+from torchvision import transforms
+from tqdm import tqdm
+from transformers import CLIPTextModel, PreTrainedTokenizerFast
+from transformers.utils import ContextManagers
+
+from diffusers import AutoPipelineForText2Image, DDPMWuerstchenScheduler
+from diffusers.optimization import get_scheduler
+from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS, WuerstchenPrior
+from diffusers.training_utils import EMAModel
+from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
+from diffusers.utils.logging import set_verbosity_error, set_verbosity_info
+
+
+if is_wandb_available():
+ import wandb
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.24.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+DATASET_NAME_MAPPING = {
+ "lambdalabs/pokemon-blip-captions": ("image", "text"),
+}
+
+
+def save_model_card(
+ args,
+ repo_id: str,
+ images=None,
+ repo_folder=None,
+):
+ img_str = ""
+ if len(images) > 0:
+ image_grid = make_image_grid(images, 1, len(args.validation_prompts))
+ image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png"))
+ img_str += "![val_imgs_grid](./val_imgs_grid.png)\n"
+
+ yaml = f"""
+---
+license: mit
+base_model: {args.pretrained_prior_model_name_or_path}
+datasets:
+- {args.dataset_name}
+tags:
+- wuerstchen
+- text-to-image
+- diffusers
+inference: true
+---
+ """
+ model_card = f"""
+# Finetuning - {repo_id}
+
+This pipeline was finetuned from **{args.pretrained_prior_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}.\n
+{img_str}
+
+## Pipeline usage
+
+You can use the pipeline like so:
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+pipe_prior = DiffusionPipeline.from_pretrained("{repo_id}", torch_dtype={args.weight_dtype})
+pipe_t2i = DiffusionPipeline.from_pretrained("{args.pretrained_decoder_model_name_or_path}", torch_dtype={args.weight_dtype})
+prompt = "{args.validation_prompts[0]}"
+(image_embeds,) = pipe_prior(prompt).to_tuple()
+image = pipe_t2i(image_embeddings=image_embeds, prompt=prompt).images[0]
+image.save("my_image.png")
+```
+
+## Training info
+
+These are the key hyperparameters used during training:
+
+* Epochs: {args.num_train_epochs}
+* Learning rate: {args.learning_rate}
+* Batch size: {args.train_batch_size}
+* Gradient accumulation steps: {args.gradient_accumulation_steps}
+* Image resolution: {args.resolution}
+* Mixed-precision: {args.mixed_precision}
+
+"""
+ wandb_info = ""
+ if is_wandb_available():
+ wandb_run_url = None
+ if wandb.run is not None:
+ wandb_run_url = wandb.run.url
+
+ if wandb_run_url is not None:
+ wandb_info = f"""
+More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}).
+"""
+
+ model_card += wandb_info
+
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
+ f.write(yaml + model_card)
+
+
+def log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dtype, epoch):
+ logger.info("Running validation... ")
+
+ pipeline = AutoPipelineForText2Image.from_pretrained(
+ args.pretrained_decoder_model_name_or_path,
+ prior_prior=accelerator.unwrap_model(prior),
+ prior_text_encoder=accelerator.unwrap_model(text_encoder),
+ prior_tokenizer=tokenizer,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ images = []
+ for i in range(len(args.validation_prompts)):
+ with torch.autocast("cuda"):
+ image = pipeline(
+ args.validation_prompts[i],
+ prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
+ generator=generator,
+ height=args.resolution,
+ width=args.resolution,
+ ).images[0]
+
+ images.append(image)
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+ else:
+ logger.warning(f"image logging not implemented for {tracker.name}")
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ return images
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of finetuning Würstchen Prior.")
+ parser.add_argument(
+ "--pretrained_decoder_model_name_or_path",
+ type=str,
+ default="warp-ai/wuerstchen",
+ required=False,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_prior_model_name_or_path",
+ type=str,
+ default="warp-ai/wuerstchen-prior",
+ required=False,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--validation_prompts",
+ type=str,
+ default=None,
+ nargs="+",
+ help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="wuerstchen-model-finetuned",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images; all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of update steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="learning rate",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument(
+ "--adam_weight_decay",
+ type=float,
+ default=0.0,
+ required=False,
+ help="Weight decay to use.",
+ )
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=5,
+ help="Run validation every X epochs.",
+ )
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="text2image-fine-tune",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers. For"
+ " more information, see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ return args
+
+
+def main():
+ args = parse_args()
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+ accelerator_project_config = ProjectConfiguration(
+ total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
+ )
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load scheduler, effnet, tokenizer, clip_model
+ noise_scheduler = DDPMWuerstchenScheduler()
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(
+ args.pretrained_prior_model_name_or_path, subfolder="tokenizer"
+ )
+
+ def deepspeed_zero_init_disabled_context_manager():
+ """
+ Returns either a list containing a context manager that disables zero.Init, or an empty list.
+ """
+ deepspeed_plugin = AcceleratorState().deepspeed_plugin if is_initialized() else None
+ if deepspeed_plugin is None:
+ return []
+
+ return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
+
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+ with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
+ pretrained_checkpoint_file = hf_hub_download("dome272/wuerstchen", filename="model_v2_stage_b.pt")
+ state_dict = torch.load(pretrained_checkpoint_file, map_location="cpu")
+ image_encoder = EfficientNetEncoder()
+ image_encoder.load_state_dict(state_dict["effnet_state_dict"])
+ image_encoder.eval()
+
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_prior_model_name_or_path, subfolder="text_encoder", torch_dtype=weight_dtype
+ ).eval()
+
+ # Freeze text_encoder and image_encoder
+ text_encoder.requires_grad_(False)
+ image_encoder.requires_grad_(False)
+
+ # load prior model
+ prior = WuerstchenPrior.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")
+
+ # Create EMA for the prior
+ if args.use_ema:
+ ema_prior = WuerstchenPrior.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")
+ ema_prior = EMAModel(ema_prior.parameters(), model_cls=WuerstchenPrior, model_config=ema_prior.config)
+ ema_prior.to(accelerator.device)
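+ # The EMA shadow weights are updated after every optimizer step (`ema_prior.step` below) and are
+ # temporarily swapped into the prior for validation and for the final saved pipeline.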
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if args.use_ema:
+ ema_prior.save_pretrained(os.path.join(output_dir, "prior_ema"))
+
+ for i, model in enumerate(models):
+ model.save_pretrained(os.path.join(output_dir, "prior"))
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ if args.use_ema:
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, "prior_ema"), WuerstchenPrior)
+ ema_prior.load_state_dict(load_model.state_dict())
+ ema_prior.to(accelerator.device)
+ del load_model
+
+ for i in range(len(models)):
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = WuerstchenPrior.from_pretrained(input_dir, subfolder="prior")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
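+ # With these hooks, `accelerator.save_state`/`load_state` (used for `--checkpointing_steps` and
+ # `--resume_from_checkpoint`) serialize the prior (and its EMA copy) in the diffusers folder layout.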
+
+ if args.gradient_checkpointing:
+ prior.enable_gradient_checkpointing()
+
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
+ )
+
+ optimizer_cls = bnb.optim.AdamW8bit
+ else:
+ optimizer_cls = torch.optim.AdamW
+ optimizer = optimizer_cls(
+ prior.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # Get the column names for input/target.
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
+ if args.image_column is None:
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.caption_column is None:
+ caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
+ )
+
+ # Preprocessing the datasets.
+ # We need to tokenize input captions and transform the images
+ def tokenize_captions(examples, is_train=True):
+ captions = []
+ for caption in examples[caption_column]:
+ if isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+ else:
+ raise ValueError(
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
+ )
+ inputs = tokenizer(
+ captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ text_input_ids = inputs.input_ids
+ text_mask = inputs.attention_mask.bool()
+ return text_input_ids, text_mask
+
+ effnet_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True),
+ transforms.CenterCrop(args.resolution),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+ ]
+ )
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ examples["effnet_pixel_values"] = [effnet_transforms(image) for image in images]
+ examples["text_input_ids"], examples["text_mask"] = tokenize_captions(examples)
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ def collate_fn(examples):
+ effnet_pixel_values = torch.stack([example["effnet_pixel_values"] for example in examples])
+ effnet_pixel_values = effnet_pixel_values.to(memory_format=torch.contiguous_format).float()
+ text_input_ids = torch.stack([example["text_input_ids"] for example in examples])
+ text_mask = torch.stack([example["text_mask"] for example in examples])
+ return {"effnet_pixel_values": effnet_pixel_values, "text_input_ids": text_input_ids, "text_mask": text_mask}
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+ )
+
+ prior, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ prior, optimizer, train_dataloader, lr_scheduler
+ )
+ image_encoder.to(accelerator.device, dtype=weight_dtype)
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+ tracker_config.pop("validation_prompts")
+ accelerator.init_trackers(args.tracker_project_name, tracker_config)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ resume_global_step = global_step * args.gradient_accumulation_steps
+ first_epoch = global_step // num_update_steps_per_epoch
+ resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
+ progress_bar.set_description("Steps")
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ prior.train()
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ # Skip steps until we reach the resumed step
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
+ if step % args.gradient_accumulation_steps == 0:
+ progress_bar.update(1)
+ continue
+
+ with accelerator.accumulate(prior):
+ # Convert images to latent space
+ text_input_ids, text_mask, effnet_images = (
+ batch["text_input_ids"],
+ batch["text_mask"],
+ batch["effnet_pixel_values"].to(weight_dtype),
+ )
+
+ with torch.no_grad():
+ text_encoder_output = text_encoder(text_input_ids, attention_mask=text_mask)
+ prompt_embeds = text_encoder_output.last_hidden_state
+ image_embeds = image_encoder(effnet_images)
+ # scale
+ image_embeds = image_embeds.add(1.0).div(42.0)
+
+ # Sample noise that we'll add to the image_embeds
+ noise = torch.randn_like(image_embeds)
+ bsz = image_embeds.shape[0]
+
+ # Sample a random timestep for each image
+ timesteps = torch.rand((bsz,), device=image_embeds.device, dtype=weight_dtype)
+
+ # add noise to latent
+ noisy_latents = noise_scheduler.add_noise(image_embeds, noise, timesteps)
+
+ # Predict the noise residual
+ pred_noise = prior(noisy_latents, timesteps, prompt_embeds)
+
+ # vanilla loss
+ loss = F.mse_loss(pred_noise.float(), noise.float(), reduction="mean")
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(prior.parameters(), args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ if args.use_ema:
+ ema_prior.step(prior.parameters())
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
+ if args.use_ema:
+ # Store the prior parameters temporarily and load the EMA parameters to perform inference.
+ ema_prior.store(prior.parameters())
+ ema_prior.copy_to(prior.parameters())
+ log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dtype, global_step)
+ if args.use_ema:
+ # Switch back to the original prior parameters.
+ ema_prior.restore(prior.parameters())
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ prior = accelerator.unwrap_model(prior)
+ if args.use_ema:
+ ema_prior.copy_to(prior.parameters())
+
+ pipeline = AutoPipelineForText2Image.from_pretrained(
+ args.pretrained_decoder_model_name_or_path,
+ prior_prior=prior,
+ prior_text_encoder=accelerator.unwrap_model(text_encoder),
+ prior_tokenizer=tokenizer,
+ )
+ pipeline.prior_pipe.save_pretrained(os.path.join(args.output_dir, "prior_pipeline"))
+
+ # Run a final round of inference.
+ images = []
+ if args.validation_prompts is not None:
+ logger.info("Running inference for collecting generated images...")
+ pipeline = pipeline.to(accelerator.device, torch_dtype=weight_dtype)
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ for i in range(len(args.validation_prompts)):
+ with torch.autocast("cuda"):
+ image = pipeline(
+ args.validation_prompts[i],
+ prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
+ generator=generator,
+ width=args.resolution,
+ height=args.resolution,
+ ).images[0]
+ images.append(image)
+
+ if args.push_to_hub:
+ save_model_card(args, repo_id, images, repo_folder=args.output_dir)
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/pyproject.toml b/diffusers/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..0612f2f9e059e4f62e8d393aab01b6c2d73f7108
--- /dev/null
+++ b/diffusers/pyproject.toml
@@ -0,0 +1,27 @@
+[tool.ruff]
+# Never enforce `E501` (line length violations).
+ignore = ["C901", "E501", "E741", "F402", "F823"]
+select = ["C", "E", "F", "I", "W"]
+line-length = 119
+
+# Ignore import violations in all `__init__.py` files.
+[tool.ruff.per-file-ignores]
+"__init__.py" = ["E402", "F401", "F403", "F811"]
+"src/diffusers/utils/dummy_*.py" = ["F401"]
+
+[tool.ruff.isort]
+lines-after-imports = 2
+known-first-party = ["diffusers"]
+
+[tool.ruff.format]
+# Like Black, use double quotes for strings.
+quote-style = "double"
+
+# Like Black, indent with spaces, rather than tabs.
+indent-style = "space"
+
+# Like Black, respect magic trailing commas.
+skip-magic-trailing-comma = false
+
+# Like Black, automatically detect the appropriate line ending.
+line-ending = "auto"
diff --git a/diffusers/scripts/__init__.py b/diffusers/scripts/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/diffusers/scripts/change_naming_configs_and_checkpoints.py b/diffusers/scripts/change_naming_configs_and_checkpoints.py
new file mode 100644
index 0000000000000000000000000000000000000000..01c4f88c2daf8b40f695bde7b07367e11ae4e3a2
--- /dev/null
+++ b/diffusers/scripts/change_naming_configs_and_checkpoints.py
@@ -0,0 +1,113 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Conversion script for the LDM checkpoints. """
+
+import argparse
+import json
+import os
+
+import torch
+from transformers.file_utils import has_file
+
+from diffusers import UNet2DConditionModel, UNet2DModel
+
+
+do_only_config = False
+do_only_weights = True
+do_only_renaming = False
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--repo_path",
+ default=None,
+ type=str,
+ required=True,
+ help="The config json file corresponding to the architecture.",
+ )
+
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
+
+ args = parser.parse_args()
+
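+    # Legacy -> current naming maps: config keys on the left are renamed to the diffusers
+    # config keys on the right, and state_dict prefixes are remapped analogously below.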
+ config_parameters_to_change = {
+ "image_size": "sample_size",
+ "num_res_blocks": "layers_per_block",
+ "block_channels": "block_out_channels",
+ "down_blocks": "down_block_types",
+ "up_blocks": "up_block_types",
+ "downscale_freq_shift": "freq_shift",
+ "resnet_num_groups": "norm_num_groups",
+ "resnet_act_fn": "act_fn",
+ "resnet_eps": "norm_eps",
+ "num_head_channels": "attention_head_dim",
+ }
+
+ key_parameters_to_change = {
+ "time_steps": "time_proj",
+ "mid": "mid_block",
+ "downsample_blocks": "down_blocks",
+ "upsample_blocks": "up_blocks",
+ }
+
+ subfolder = "" if has_file(args.repo_path, "config.json") else "unet"
+
+ with open(os.path.join(args.repo_path, subfolder, "config.json"), "r", encoding="utf-8") as reader:
+ text = reader.read()
+ config = json.loads(text)
+
+ if do_only_config:
+ for key in config_parameters_to_change.keys():
+ config.pop(key, None)
+
+ if has_file(args.repo_path, "config.json"):
+ model = UNet2DModel(**config)
+ else:
+ class_name = UNet2DConditionModel if "ldm-text2im-large-256" in args.repo_path else UNet2DModel
+ model = class_name(**config)
+
+ if do_only_config:
+ model.save_config(os.path.join(args.repo_path, subfolder))
+
+ config = dict(model.config)
+
+ if do_only_renaming:
+ for key, value in config_parameters_to_change.items():
+ if key in config:
+ config[value] = config[key]
+ del config[key]
+
+ config["down_block_types"] = [k.replace("UNetRes", "") for k in config["down_block_types"]]
+ config["up_block_types"] = [k.replace("UNetRes", "") for k in config["up_block_types"]]
+
+ if do_only_weights:
+ state_dict = torch.load(os.path.join(args.repo_path, subfolder, "diffusion_pytorch_model.bin"))
+
+ new_state_dict = {}
+ for param_key, param_value in state_dict.items():
+ if param_key.endswith(".op.bias") or param_key.endswith(".op.weight"):
+ continue
+ has_changed = False
+ for key, new_key in key_parameters_to_change.items():
+ if not has_changed and param_key.split(".")[0] == key:
+ new_state_dict[".".join([new_key] + param_key.split(".")[1:])] = param_value
+ has_changed = True
+ if not has_changed:
+ new_state_dict[param_key] = param_value
+
+ model.load_state_dict(new_state_dict)
+ model.save_pretrained(os.path.join(args.repo_path, subfolder))
diff --git a/diffusers/scripts/conversion_ldm_uncond.py b/diffusers/scripts/conversion_ldm_uncond.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2ebb3934b6696fd427c9bf09eb051cf7befe7f4
--- /dev/null
+++ b/diffusers/scripts/conversion_ldm_uncond.py
@@ -0,0 +1,56 @@
+import argparse
+
+from omegaconf import OmegaConf
+import torch
+
+from diffusers import DDIMScheduler, LDMPipeline, UNetLDMModel, VQModel
+
+
+def convert_ldm_original(checkpoint_path, config_path, output_path):
+ config = OmegaConf.load(config_path)
+ state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
+ keys = list(state_dict.keys())
+
+ # extract state_dict for VQVAE
+ first_stage_dict = {}
+ first_stage_key = "first_stage_model."
+ for key in keys:
+ if key.startswith(first_stage_key):
+ first_stage_dict[key.replace(first_stage_key, "")] = state_dict[key]
+
+ # extract state_dict for UNetLDM
+ unet_state_dict = {}
+ unet_key = "model.diffusion_model."
+ for key in keys:
+ if key.startswith(unet_key):
+ unet_state_dict[key.replace(unet_key, "")] = state_dict[key]
+
+ vqvae_init_args = config.model.params.first_stage_config.params
+ unet_init_args = config.model.params.unet_config.params
+
+ vqvae = VQModel(**vqvae_init_args).eval()
+ vqvae.load_state_dict(first_stage_dict)
+
+ unet = UNetLDMModel(**unet_init_args).eval()
+ unet.load_state_dict(unet_state_dict)
+
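+    # Rebuild the noise scheduler from the beta schedule stored in the original LDM config.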
+ noise_scheduler = DDIMScheduler(
+        num_train_timesteps=config.model.params.timesteps,
+ beta_schedule="scaled_linear",
+ beta_start=config.model.params.linear_start,
+ beta_end=config.model.params.linear_end,
+ clip_sample=False,
+ )
+
+ pipeline = LDMPipeline(vqvae, unet, noise_scheduler)
+ pipeline.save_pretrained(output_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--checkpoint_path", type=str, required=True)
+ parser.add_argument("--config_path", type=str, required=True)
+ parser.add_argument("--output_path", type=str, required=True)
+ args = parser.parse_args()
+
+ convert_ldm_original(args.checkpoint_path, args.config_path, args.output_path)
diff --git a/diffusers/scripts/convert_asymmetric_vqgan_to_diffusers.py b/diffusers/scripts/convert_asymmetric_vqgan_to_diffusers.py
new file mode 100644
index 0000000000000000000000000000000000000000..ffb735e18224a7ef48503367112f5ce8142bdf9c
--- /dev/null
+++ b/diffusers/scripts/convert_asymmetric_vqgan_to_diffusers.py
@@ -0,0 +1,184 @@
+import argparse
+import time
+from pathlib import Path
+from typing import Any, Dict, Literal
+
+import torch
+
+from diffusers import AsymmetricAutoencoderKL
+
+
+ASYMMETRIC_AUTOENCODER_KL_x_1_5_CONFIG = {
+ "in_channels": 3,
+ "out_channels": 3,
+ "down_block_types": [
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ ],
+ "down_block_out_channels": [128, 256, 512, 512],
+ "layers_per_down_block": 2,
+ "up_block_types": [
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ ],
+ "up_block_out_channels": [192, 384, 768, 768],
+ "layers_per_up_block": 3,
+ "act_fn": "silu",
+ "latent_channels": 4,
+ "norm_num_groups": 32,
+ "sample_size": 256,
+ "scaling_factor": 0.18215,
+}
+
+ASYMMETRIC_AUTOENCODER_KL_x_2_CONFIG = {
+ "in_channels": 3,
+ "out_channels": 3,
+ "down_block_types": [
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ ],
+ "down_block_out_channels": [128, 256, 512, 512],
+ "layers_per_down_block": 2,
+ "up_block_types": [
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ ],
+ "up_block_out_channels": [256, 512, 1024, 1024],
+ "layers_per_up_block": 5,
+ "act_fn": "silu",
+ "latent_channels": 4,
+ "norm_num_groups": 32,
+ "sample_size": 256,
+ "scaling_factor": 0.18215,
+}
+
+
+def convert_asymmetric_autoencoder_kl_state_dict(original_state_dict: Dict[str, Any]) -> Dict[str, Any]:
+ converted_state_dict = {}
+ for k, v in original_state_dict.items():
+ if k.startswith("encoder."):
+ converted_state_dict[
+ k.replace("encoder.down.", "encoder.down_blocks.")
+ .replace("encoder.mid.", "encoder.mid_block.")
+ .replace("encoder.norm_out.", "encoder.conv_norm_out.")
+ .replace(".downsample.", ".downsamplers.0.")
+ .replace(".nin_shortcut.", ".conv_shortcut.")
+ .replace(".block.", ".resnets.")
+ .replace(".block_1.", ".resnets.0.")
+ .replace(".block_2.", ".resnets.1.")
+ .replace(".attn_1.k.", ".attentions.0.to_k.")
+ .replace(".attn_1.q.", ".attentions.0.to_q.")
+ .replace(".attn_1.v.", ".attentions.0.to_v.")
+ .replace(".attn_1.proj_out.", ".attentions.0.to_out.0.")
+ .replace(".attn_1.norm.", ".attentions.0.group_norm.")
+ ] = v
+ elif k.startswith("decoder.") and "up_layers" not in k:
+ converted_state_dict[
+ k.replace("decoder.encoder.", "decoder.condition_encoder.")
+ .replace(".norm_out.", ".conv_norm_out.")
+ .replace(".up.0.", ".up_blocks.3.")
+ .replace(".up.1.", ".up_blocks.2.")
+ .replace(".up.2.", ".up_blocks.1.")
+ .replace(".up.3.", ".up_blocks.0.")
+ .replace(".block.", ".resnets.")
+ .replace("mid", "mid_block")
+ .replace(".0.upsample.", ".0.upsamplers.0.")
+ .replace(".1.upsample.", ".1.upsamplers.0.")
+ .replace(".2.upsample.", ".2.upsamplers.0.")
+ .replace(".nin_shortcut.", ".conv_shortcut.")
+ .replace(".block_1.", ".resnets.0.")
+ .replace(".block_2.", ".resnets.1.")
+ .replace(".attn_1.k.", ".attentions.0.to_k.")
+ .replace(".attn_1.q.", ".attentions.0.to_q.")
+ .replace(".attn_1.v.", ".attentions.0.to_v.")
+ .replace(".attn_1.proj_out.", ".attentions.0.to_out.0.")
+ .replace(".attn_1.norm.", ".attentions.0.group_norm.")
+ ] = v
+ elif k.startswith("quant_conv."):
+ converted_state_dict[k] = v
+ elif k.startswith("post_quant_conv."):
+ converted_state_dict[k] = v
+ else:
+ print(f" skipping key `{k}`")
+ # fix weights shape
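+    # (the original checkpoint stores the mid-block attention q/k/v/out projections as 1x1 convs;
+    # dropping the trailing spatial dims yields the 2D linear weights diffusers expects)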
+ for k, v in converted_state_dict.items():
+ if (
+ (k.startswith("encoder.mid_block.attentions.0") or k.startswith("decoder.mid_block.attentions.0"))
+ and k.endswith("weight")
+ and ("to_q" in k or "to_k" in k or "to_v" in k or "to_out" in k)
+ ):
+ converted_state_dict[k] = converted_state_dict[k][:, :, 0, 0]
+
+ return converted_state_dict
+
+
+def get_asymmetric_autoencoder_kl_from_original_checkpoint(
+ scale: Literal["1.5", "2"], original_checkpoint_path: str, map_location: torch.device
+) -> AsymmetricAutoencoderKL:
+ print("Loading original state_dict")
+ original_state_dict = torch.load(original_checkpoint_path, map_location=map_location)
+ original_state_dict = original_state_dict["state_dict"]
+ print("Converting state_dict")
+ converted_state_dict = convert_asymmetric_autoencoder_kl_state_dict(original_state_dict)
+ kwargs = ASYMMETRIC_AUTOENCODER_KL_x_1_5_CONFIG if scale == "1.5" else ASYMMETRIC_AUTOENCODER_KL_x_2_CONFIG
+ print("Initializing AsymmetricAutoencoderKL model")
+ asymmetric_autoencoder_kl = AsymmetricAutoencoderKL(**kwargs)
+ print("Loading weight from converted state_dict")
+ asymmetric_autoencoder_kl.load_state_dict(converted_state_dict)
+ asymmetric_autoencoder_kl.eval()
+ print("AsymmetricAutoencoderKL successfully initialized")
+ return asymmetric_autoencoder_kl
+
+
+if __name__ == "__main__":
+ start = time.time()
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--scale",
+ default=None,
+ type=str,
+ required=True,
+ help="Asymmetric VQGAN scale: `1.5` or `2`",
+ )
+ parser.add_argument(
+ "--original_checkpoint_path",
+ default=None,
+ type=str,
+ required=True,
+ help="Path to the original Asymmetric VQGAN checkpoint",
+ )
+ parser.add_argument(
+ "--output_path",
+ default=None,
+ type=str,
+ required=True,
+ help="Path to save pretrained AsymmetricAutoencoderKL model",
+ )
+ parser.add_argument(
+ "--map_location",
+ default="cpu",
+ type=str,
+ required=False,
+ help="The device passed to `map_location` when loading the checkpoint",
+ )
+ args = parser.parse_args()
+
+    assert args.scale in ["1.5", "2"], f"{args.scale} should be `1.5` or `2`"
+ assert Path(args.original_checkpoint_path).is_file()
+
+ asymmetric_autoencoder_kl = get_asymmetric_autoencoder_kl_from_original_checkpoint(
+ scale=args.scale,
+ original_checkpoint_path=args.original_checkpoint_path,
+ map_location=torch.device(args.map_location),
+ )
+ print("Saving pretrained AsymmetricAutoencoderKL")
+ asymmetric_autoencoder_kl.save_pretrained(args.output_path)
+ print(f"Done in {time.time() - start:.2f} seconds")
diff --git a/diffusers/scripts/convert_blipdiffusion_to_diffusers.py b/diffusers/scripts/convert_blipdiffusion_to_diffusers.py
new file mode 100644
index 0000000000000000000000000000000000000000..03cf67e5476bd43260d2829767fffc220a7801c1
--- /dev/null
+++ b/diffusers/scripts/convert_blipdiffusion_to_diffusers.py
@@ -0,0 +1,343 @@
+"""
+This script requires you to build `LAVIS` from source, since the pip version doesn't have BLIP Diffusion. Follow instructions here: https://github.com/salesforce/LAVIS/tree/main.
+"""
+
+import argparse
+import os
+import tempfile
+
+import torch
+from lavis.models import load_model_and_preprocess
+from transformers import CLIPTokenizer
+from transformers.models.blip_2.configuration_blip_2 import Blip2Config
+
+from diffusers import (
+ AutoencoderKL,
+ PNDMScheduler,
+ UNet2DConditionModel,
+)
+from diffusers.pipelines import BlipDiffusionPipeline
+from diffusers.pipelines.blip_diffusion.blip_image_processing import BlipImageProcessor
+from diffusers.pipelines.blip_diffusion.modeling_blip2 import Blip2QFormerModel
+from diffusers.pipelines.blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
+
+
+BLIP2_CONFIG = {
+ "vision_config": {
+ "hidden_size": 1024,
+ "num_hidden_layers": 23,
+ "num_attention_heads": 16,
+ "image_size": 224,
+ "patch_size": 14,
+ "intermediate_size": 4096,
+ "hidden_act": "quick_gelu",
+ },
+ "qformer_config": {
+ "cross_attention_frequency": 1,
+ "encoder_hidden_size": 1024,
+ "vocab_size": 30523,
+ },
+ "num_query_tokens": 16,
+}
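+# BLIP-2 settings used to instantiate the Q-Former below; sized to match the BLIP-Diffusion weights being converted.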
+blip2config = Blip2Config(**BLIP2_CONFIG)
+
+
+def qformer_model_from_original_config():
+ qformer = Blip2QFormerModel(blip2config)
+ return qformer
+
+
+def embeddings_from_original_checkpoint(model, diffuser_embeddings_prefix, original_embeddings_prefix):
+ embeddings = {}
+ embeddings.update(
+ {
+ f"{diffuser_embeddings_prefix}.word_embeddings.weight": model[
+ f"{original_embeddings_prefix}.word_embeddings.weight"
+ ]
+ }
+ )
+ embeddings.update(
+ {
+ f"{diffuser_embeddings_prefix}.position_embeddings.weight": model[
+ f"{original_embeddings_prefix}.position_embeddings.weight"
+ ]
+ }
+ )
+ embeddings.update(
+ {f"{diffuser_embeddings_prefix}.LayerNorm.weight": model[f"{original_embeddings_prefix}.LayerNorm.weight"]}
+ )
+ embeddings.update(
+ {f"{diffuser_embeddings_prefix}.LayerNorm.bias": model[f"{original_embeddings_prefix}.LayerNorm.bias"]}
+ )
+ return embeddings
+
+
+def proj_layer_from_original_checkpoint(model, diffuser_proj_prefix, original_proj_prefix):
+ proj_layer = {}
+ proj_layer.update({f"{diffuser_proj_prefix}.dense1.weight": model[f"{original_proj_prefix}.dense1.weight"]})
+ proj_layer.update({f"{diffuser_proj_prefix}.dense1.bias": model[f"{original_proj_prefix}.dense1.bias"]})
+ proj_layer.update({f"{diffuser_proj_prefix}.dense2.weight": model[f"{original_proj_prefix}.dense2.weight"]})
+ proj_layer.update({f"{diffuser_proj_prefix}.dense2.bias": model[f"{original_proj_prefix}.dense2.bias"]})
+ proj_layer.update({f"{diffuser_proj_prefix}.LayerNorm.weight": model[f"{original_proj_prefix}.LayerNorm.weight"]})
+ proj_layer.update({f"{diffuser_proj_prefix}.LayerNorm.bias": model[f"{original_proj_prefix}.LayerNorm.bias"]})
+ return proj_layer
+
+
+def attention_from_original_checkpoint(model, diffuser_attention_prefix, original_attention_prefix):
+ attention = {}
+ attention.update(
+ {
+ f"{diffuser_attention_prefix}.attention.query.weight": model[
+ f"{original_attention_prefix}.self.query.weight"
+ ]
+ }
+ )
+ attention.update(
+ {f"{diffuser_attention_prefix}.attention.query.bias": model[f"{original_attention_prefix}.self.query.bias"]}
+ )
+ attention.update(
+ {f"{diffuser_attention_prefix}.attention.key.weight": model[f"{original_attention_prefix}.self.key.weight"]}
+ )
+ attention.update(
+ {f"{diffuser_attention_prefix}.attention.key.bias": model[f"{original_attention_prefix}.self.key.bias"]}
+ )
+ attention.update(
+ {
+ f"{diffuser_attention_prefix}.attention.value.weight": model[
+ f"{original_attention_prefix}.self.value.weight"
+ ]
+ }
+ )
+ attention.update(
+ {f"{diffuser_attention_prefix}.attention.value.bias": model[f"{original_attention_prefix}.self.value.bias"]}
+ )
+ attention.update(
+ {f"{diffuser_attention_prefix}.output.dense.weight": model[f"{original_attention_prefix}.output.dense.weight"]}
+ )
+ attention.update(
+ {f"{diffuser_attention_prefix}.output.dense.bias": model[f"{original_attention_prefix}.output.dense.bias"]}
+ )
+ attention.update(
+ {
+ f"{diffuser_attention_prefix}.output.LayerNorm.weight": model[
+ f"{original_attention_prefix}.output.LayerNorm.weight"
+ ]
+ }
+ )
+ attention.update(
+ {
+ f"{diffuser_attention_prefix}.output.LayerNorm.bias": model[
+ f"{original_attention_prefix}.output.LayerNorm.bias"
+ ]
+ }
+ )
+ return attention
+
+
+def output_layers_from_original_checkpoint(model, diffuser_output_prefix, original_output_prefix):
+ output_layers = {}
+ output_layers.update({f"{diffuser_output_prefix}.dense.weight": model[f"{original_output_prefix}.dense.weight"]})
+ output_layers.update({f"{diffuser_output_prefix}.dense.bias": model[f"{original_output_prefix}.dense.bias"]})
+ output_layers.update(
+ {f"{diffuser_output_prefix}.LayerNorm.weight": model[f"{original_output_prefix}.LayerNorm.weight"]}
+ )
+ output_layers.update(
+ {f"{diffuser_output_prefix}.LayerNorm.bias": model[f"{original_output_prefix}.LayerNorm.bias"]}
+ )
+ return output_layers
+
+
+def encoder_from_original_checkpoint(model, diffuser_encoder_prefix, original_encoder_prefix):
+ encoder = {}
+ for i in range(blip2config.qformer_config.num_hidden_layers):
+ encoder.update(
+ attention_from_original_checkpoint(
+ model, f"{diffuser_encoder_prefix}.{i}.attention", f"{original_encoder_prefix}.{i}.attention"
+ )
+ )
+ encoder.update(
+ attention_from_original_checkpoint(
+ model, f"{diffuser_encoder_prefix}.{i}.crossattention", f"{original_encoder_prefix}.{i}.crossattention"
+ )
+ )
+
+ encoder.update(
+ {
+ f"{diffuser_encoder_prefix}.{i}.intermediate.dense.weight": model[
+ f"{original_encoder_prefix}.{i}.intermediate.dense.weight"
+ ]
+ }
+ )
+ encoder.update(
+ {
+ f"{diffuser_encoder_prefix}.{i}.intermediate.dense.bias": model[
+ f"{original_encoder_prefix}.{i}.intermediate.dense.bias"
+ ]
+ }
+ )
+ encoder.update(
+ {
+ f"{diffuser_encoder_prefix}.{i}.intermediate_query.dense.weight": model[
+ f"{original_encoder_prefix}.{i}.intermediate_query.dense.weight"
+ ]
+ }
+ )
+ encoder.update(
+ {
+ f"{diffuser_encoder_prefix}.{i}.intermediate_query.dense.bias": model[
+ f"{original_encoder_prefix}.{i}.intermediate_query.dense.bias"
+ ]
+ }
+ )
+
+ encoder.update(
+ output_layers_from_original_checkpoint(
+ model, f"{diffuser_encoder_prefix}.{i}.output", f"{original_encoder_prefix}.{i}.output"
+ )
+ )
+ encoder.update(
+ output_layers_from_original_checkpoint(
+ model, f"{diffuser_encoder_prefix}.{i}.output_query", f"{original_encoder_prefix}.{i}.output_query"
+ )
+ )
+ return encoder
+
+
+def visual_encoder_layer_from_original_checkpoint(model, diffuser_prefix, original_prefix):
+ visual_encoder_layer = {}
+
+ visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm1.weight": model[f"{original_prefix}.ln_1.weight"]})
+ visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm1.bias": model[f"{original_prefix}.ln_1.bias"]})
+ visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm2.weight": model[f"{original_prefix}.ln_2.weight"]})
+ visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm2.bias": model[f"{original_prefix}.ln_2.bias"]})
+ visual_encoder_layer.update(
+ {f"{diffuser_prefix}.self_attn.qkv.weight": model[f"{original_prefix}.attn.in_proj_weight"]}
+ )
+ visual_encoder_layer.update(
+ {f"{diffuser_prefix}.self_attn.qkv.bias": model[f"{original_prefix}.attn.in_proj_bias"]}
+ )
+ visual_encoder_layer.update(
+ {f"{diffuser_prefix}.self_attn.projection.weight": model[f"{original_prefix}.attn.out_proj.weight"]}
+ )
+ visual_encoder_layer.update(
+ {f"{diffuser_prefix}.self_attn.projection.bias": model[f"{original_prefix}.attn.out_proj.bias"]}
+ )
+ visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc1.weight": model[f"{original_prefix}.mlp.c_fc.weight"]})
+ visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc1.bias": model[f"{original_prefix}.mlp.c_fc.bias"]})
+ visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc2.weight": model[f"{original_prefix}.mlp.c_proj.weight"]})
+ visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc2.bias": model[f"{original_prefix}.mlp.c_proj.bias"]})
+
+ return visual_encoder_layer
+
+
+def visual_encoder_from_original_checkpoint(model, diffuser_prefix, original_prefix):
+ visual_encoder = {}
+
+ visual_encoder.update(
+ {
+ f"{diffuser_prefix}.embeddings.class_embedding": model[f"{original_prefix}.class_embedding"]
+ .unsqueeze(0)
+ .unsqueeze(0)
+ }
+ )
+ visual_encoder.update(
+ {
+ f"{diffuser_prefix}.embeddings.position_embedding": model[
+ f"{original_prefix}.positional_embedding"
+ ].unsqueeze(0)
+ }
+ )
+ visual_encoder.update(
+ {f"{diffuser_prefix}.embeddings.patch_embedding.weight": model[f"{original_prefix}.conv1.weight"]}
+ )
+ visual_encoder.update({f"{diffuser_prefix}.pre_layernorm.weight": model[f"{original_prefix}.ln_pre.weight"]})
+ visual_encoder.update({f"{diffuser_prefix}.pre_layernorm.bias": model[f"{original_prefix}.ln_pre.bias"]})
+
+ for i in range(blip2config.vision_config.num_hidden_layers):
+ visual_encoder.update(
+ visual_encoder_layer_from_original_checkpoint(
+ model, f"{diffuser_prefix}.encoder.layers.{i}", f"{original_prefix}.transformer.resblocks.{i}"
+ )
+ )
+
+ visual_encoder.update({f"{diffuser_prefix}.post_layernorm.weight": model["blip.ln_vision.weight"]})
+ visual_encoder.update({f"{diffuser_prefix}.post_layernorm.bias": model["blip.ln_vision.bias"]})
+
+ return visual_encoder
+
+
+def qformer_original_checkpoint_to_diffusers_checkpoint(model):
+ qformer_checkpoint = {}
+ qformer_checkpoint.update(embeddings_from_original_checkpoint(model, "embeddings", "blip.Qformer.bert.embeddings"))
+ qformer_checkpoint.update({"query_tokens": model["blip.query_tokens"]})
+ qformer_checkpoint.update(proj_layer_from_original_checkpoint(model, "proj_layer", "proj_layer"))
+ qformer_checkpoint.update(
+ encoder_from_original_checkpoint(model, "encoder.layer", "blip.Qformer.bert.encoder.layer")
+ )
+ qformer_checkpoint.update(visual_encoder_from_original_checkpoint(model, "visual_encoder", "blip.visual_encoder"))
+ return qformer_checkpoint
+
+
+def get_qformer(model):
+ print("loading qformer")
+
+ qformer = qformer_model_from_original_config()
+ qformer_diffusers_checkpoint = qformer_original_checkpoint_to_diffusers_checkpoint(model)
+
+ load_checkpoint_to_model(qformer_diffusers_checkpoint, qformer)
+
+ print("done loading qformer")
+ return qformer
+
+
+def load_checkpoint_to_model(checkpoint, model):
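+    # Serialize the converted checkpoint to a temporary file and load it back with strict=False,
+    # so any keys that were not converted are simply skipped instead of raising.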
+ with tempfile.NamedTemporaryFile(delete=False) as file:
+ torch.save(checkpoint, file.name)
+ del checkpoint
+ model.load_state_dict(torch.load(file.name), strict=False)
+
+ os.remove(file.name)
+
+
+def save_blip_diffusion_model(model, args):
+ qformer = get_qformer(model)
+ qformer.eval()
+
+ text_encoder = ContextCLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
+ vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
+
+ unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
+ vae.eval()
+ text_encoder.eval()
+ scheduler = PNDMScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="scaled_linear",
+ set_alpha_to_one=False,
+ skip_prk_steps=True,
+ )
+ tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
+ image_processor = BlipImageProcessor()
+ blip_diffusion = BlipDiffusionPipeline(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ vae=vae,
+ unet=unet,
+ scheduler=scheduler,
+ qformer=qformer,
+ image_processor=image_processor,
+ )
+ blip_diffusion.save_pretrained(args.checkpoint_path)
+
+
+def main(args):
+ model, _, _ = load_model_and_preprocess("blip_diffusion", "base", device="cpu", is_eval=True)
+ save_blip_diffusion_model(model.state_dict(), args)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
+ args = parser.parse_args()
+
+ main(args)
diff --git a/diffusers/scripts/convert_consistency_decoder.py b/diffusers/scripts/convert_consistency_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a294038a5a33999482d5d0e5c722245f8bbc8c5
--- /dev/null
+++ b/diffusers/scripts/convert_consistency_decoder.py
@@ -0,0 +1,1128 @@
+import math
+import os
+import urllib.request
+import warnings
+from argparse import ArgumentParser
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from huggingface_hub.utils import insecure_hashlib
+from safetensors.torch import load_file as stl
+from tqdm import tqdm
+
+from diffusers import AutoencoderKL, ConsistencyDecoderVAE, DiffusionPipeline, StableDiffusionPipeline, UNet2DModel
+from diffusers.models.embeddings import TimestepEmbedding
+from diffusers.models.unet_2d_blocks import ResnetDownsampleBlock2D, ResnetUpsampleBlock2D, UNetMidBlock2D
+from diffusers.models.vae import Encoder
+
+
+args = ArgumentParser()
+args.add_argument("--save_pretrained", required=False, default=None, type=str)
+args.add_argument("--test_image", required=True, type=str)
+args = args.parse_args()
+
+
+def _extract_into_tensor(arr, timesteps, broadcast_shape):
+    # from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/gaussian_diffusion.py#L895
+ res = arr[timesteps].float()
+ dims_to_append = len(broadcast_shape) - len(res.shape)
+ return res[(...,) + (None,) * dims_to_append]
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+ # from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/gaussian_diffusion.py#L45
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return torch.tensor(betas)
+
+
+def _download(url: str, root: str):
+ os.makedirs(root, exist_ok=True)
+ filename = os.path.basename(url)
+
+ expected_sha256 = url.split("/")[-2]
+ download_target = os.path.join(root, filename)
+
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
+
+ if os.path.isfile(download_target):
+ if insecure_hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
+ return download_target
+ else:
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
+
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
+ with tqdm(
+ total=int(source.info().get("Content-Length")),
+ ncols=80,
+ unit="iB",
+ unit_scale=True,
+ unit_divisor=1024,
+ ) as loop:
+ while True:
+ buffer = source.read(8192)
+ if not buffer:
+ break
+
+ output.write(buffer)
+ loop.update(len(buffer))
+
+ if insecure_hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
+        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
+
+ return download_target
+
+
+class ConsistencyDecoder:
+ def __init__(self, device="cuda:0", download_root=os.path.expanduser("~/.cache/clip")):
+ self.n_distilled_steps = 64
+ download_target = _download(
+ "https://openaipublic.azureedge.net/diff-vae/c9cebd3132dd9c42936d803e33424145a748843c8f716c0814838bdc8a2fe7cb/decoder.pt",
+ download_root,
+ )
+ self.ckpt = torch.jit.load(download_target).to(device)
+ self.device = device
+ sigma_data = 0.5
+ betas = betas_for_alpha_bar(1024, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2).to(device)
+ alphas = 1.0 - betas
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
+ self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
+ self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
+ sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)
+ sigmas = torch.sqrt(1.0 / alphas_cumprod - 1)
+ self.c_skip = sqrt_recip_alphas_cumprod * sigma_data**2 / (sigmas**2 + sigma_data**2)
+ self.c_out = sigmas * sigma_data / (sigmas**2 + sigma_data**2) ** 0.5
+ self.c_in = sqrt_recip_alphas_cumprod / (sigmas**2 + sigma_data**2) ** 0.5
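+        # c_skip / c_out / c_in are the consistency-model preconditioning coefficients derived
+        # from the cosine noise schedule above (sigma_data = 0.5).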
+
+ @staticmethod
+ def round_timesteps(timesteps, total_timesteps, n_distilled_steps, truncate_start=True):
+ with torch.no_grad():
+ space = torch.div(total_timesteps, n_distilled_steps, rounding_mode="floor")
+ rounded_timesteps = (torch.div(timesteps, space, rounding_mode="floor") + 1) * space
+ if truncate_start:
+ rounded_timesteps[rounded_timesteps == total_timesteps] -= space
+ else:
+ rounded_timesteps[rounded_timesteps == total_timesteps] -= space
+ rounded_timesteps[rounded_timesteps == 0] += space
+ return rounded_timesteps
+
+ @staticmethod
+ def ldm_transform_latent(z, extra_scale_factor=1):
+ channel_means = [0.38862467, 0.02253063, 0.07381133, -0.0171294]
+ channel_stds = [0.9654121, 1.0440036, 0.76147926, 0.77022034]
+
+ if len(z.shape) != 4:
+            raise ValueError(f"expected a 4D latent tensor, got shape {tuple(z.shape)}")
+
+ z = z * 0.18215
+ channels = [z[:, i] for i in range(z.shape[1])]
+
+ channels = [extra_scale_factor * (c - channel_means[i]) / channel_stds[i] for i, c in enumerate(channels)]
+ return torch.stack(channels, dim=1)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ features: torch.Tensor,
+ schedule=[1.0, 0.5],
+ generator=None,
+ ):
+ features = self.ldm_transform_latent(features)
+ ts = self.round_timesteps(
+ torch.arange(0, 1024),
+ 1024,
+ self.n_distilled_steps,
+ truncate_start=False,
+ )
+ shape = (
+ features.size(0),
+ 3,
+ 8 * features.size(2),
+ 8 * features.size(3),
+ )
+ x_start = torch.zeros(shape, device=features.device, dtype=features.dtype)
+ schedule_timesteps = [int((1024 - 1) * s) for s in schedule]
+ for i in schedule_timesteps:
+ t = ts[i].item()
+ t_ = torch.tensor([t] * features.shape[0]).to(self.device)
+ # noise = torch.randn_like(x_start)
+ noise = torch.randn(x_start.shape, dtype=x_start.dtype, generator=generator).to(device=x_start.device)
+ x_start = (
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t_, x_start.shape) * x_start
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t_, x_start.shape) * noise
+ )
+ c_in = _extract_into_tensor(self.c_in, t_, x_start.shape)
+
+ if isinstance(self.ckpt, UNet2DModel):
+ input = torch.concat([c_in * x_start, F.upsample_nearest(features, scale_factor=8)], dim=1)
+ model_output = self.ckpt(input, t_).sample
+ else:
+ model_output = self.ckpt(c_in * x_start, t_, features=features)
+
+ B, C = x_start.shape[:2]
+ model_output, _ = torch.split(model_output, C, dim=1)
+ pred_xstart = (
+ _extract_into_tensor(self.c_out, t_, x_start.shape) * model_output
+ + _extract_into_tensor(self.c_skip, t_, x_start.shape) * x_start
+ ).clamp(-1, 1)
+ x_start = pred_xstart
+ return x_start
+
+
+def save_image(image, name):
+ import numpy as np
+ from PIL import Image
+
+ image = image[0].cpu().numpy()
+ image = (image + 1.0) * 127.5
+ image = image.clip(0, 255).astype(np.uint8)
+ image = Image.fromarray(image.transpose(1, 2, 0))
+ image.save(name)
+
+
+def load_image(uri, size=None, center_crop=False):
+ import numpy as np
+ from PIL import Image
+
+ image = Image.open(uri)
+ if center_crop:
+ image = image.crop(
+ (
+ (image.width - min(image.width, image.height)) // 2,
+ (image.height - min(image.width, image.height)) // 2,
+ (image.width + min(image.width, image.height)) // 2,
+ (image.height + min(image.width, image.height)) // 2,
+ )
+ )
+ if size is not None:
+ image = image.resize(size)
+ image = torch.tensor(np.array(image).transpose(2, 0, 1)).unsqueeze(0).float()
+ image = image / 127.5 - 1.0
+ return image
+
+
+class TimestepEmbedding_(nn.Module):
+ def __init__(self, n_time=1024, n_emb=320, n_out=1280) -> None:
+ super().__init__()
+ self.emb = nn.Embedding(n_time, n_emb)
+ self.f_1 = nn.Linear(n_emb, n_out)
+ self.f_2 = nn.Linear(n_out, n_out)
+
+ def forward(self, x) -> torch.Tensor:
+ x = self.emb(x)
+ x = self.f_1(x)
+ x = F.silu(x)
+ return self.f_2(x)
+
+
+class ImageEmbedding(nn.Module):
+ def __init__(self, in_channels=7, out_channels=320) -> None:
+ super().__init__()
+ self.f = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
+
+ def forward(self, x) -> torch.Tensor:
+ return self.f(x)
+
+
+class ImageUnembedding(nn.Module):
+ def __init__(self, in_channels=320, out_channels=6) -> None:
+ super().__init__()
+ self.gn = nn.GroupNorm(32, in_channels)
+ self.f = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
+
+ def forward(self, x) -> torch.Tensor:
+ return self.f(F.silu(self.gn(x)))
+
+
+class ConvResblock(nn.Module):
+ def __init__(self, in_features=320, out_features=320) -> None:
+ super().__init__()
+ self.f_t = nn.Linear(1280, out_features * 2)
+
+ self.gn_1 = nn.GroupNorm(32, in_features)
+ self.f_1 = nn.Conv2d(in_features, out_features, kernel_size=3, padding=1)
+
+ self.gn_2 = nn.GroupNorm(32, out_features)
+ self.f_2 = nn.Conv2d(out_features, out_features, kernel_size=3, padding=1)
+
+ skip_conv = in_features != out_features
+ self.f_s = nn.Conv2d(in_features, out_features, kernel_size=1, padding=0) if skip_conv else nn.Identity()
+
+ def forward(self, x, t):
+ x_skip = x
+ t = self.f_t(F.silu(t))
+ t = t.chunk(2, dim=1)
+ t_1 = t[0].unsqueeze(dim=2).unsqueeze(dim=3) + 1
+ t_2 = t[1].unsqueeze(dim=2).unsqueeze(dim=3)
+
+ gn_1 = F.silu(self.gn_1(x))
+ f_1 = self.f_1(gn_1)
+
+ gn_2 = self.gn_2(f_1)
+
+ return self.f_s(x_skip) + self.f_2(F.silu(gn_2 * t_1 + t_2))
+
+
+# Also ConvResblock
+class Downsample(nn.Module):
+ def __init__(self, in_channels=320) -> None:
+ super().__init__()
+ self.f_t = nn.Linear(1280, in_channels * 2)
+
+ self.gn_1 = nn.GroupNorm(32, in_channels)
+ self.f_1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
+ self.gn_2 = nn.GroupNorm(32, in_channels)
+
+ self.f_2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
+
+ def forward(self, x, t) -> torch.Tensor:
+ x_skip = x
+
+ t = self.f_t(F.silu(t))
+ t_1, t_2 = t.chunk(2, dim=1)
+ t_1 = t_1.unsqueeze(2).unsqueeze(3) + 1
+ t_2 = t_2.unsqueeze(2).unsqueeze(3)
+
+ gn_1 = F.silu(self.gn_1(x))
+ avg_pool2d = F.avg_pool2d(gn_1, kernel_size=(2, 2), stride=None)
+
+ f_1 = self.f_1(avg_pool2d)
+ gn_2 = self.gn_2(f_1)
+
+ f_2 = self.f_2(F.silu(t_2 + (t_1 * gn_2)))
+
+ return f_2 + F.avg_pool2d(x_skip, kernel_size=(2, 2), stride=None)
+
+
+# Also ConvResblock
+class Upsample(nn.Module):
+ def __init__(self, in_channels=1024) -> None:
+ super().__init__()
+ self.f_t = nn.Linear(1280, in_channels * 2)
+
+ self.gn_1 = nn.GroupNorm(32, in_channels)
+ self.f_1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
+ self.gn_2 = nn.GroupNorm(32, in_channels)
+
+ self.f_2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
+
+ def forward(self, x, t) -> torch.Tensor:
+ x_skip = x
+
+ t = self.f_t(F.silu(t))
+ t_1, t_2 = t.chunk(2, dim=1)
+ t_1 = t_1.unsqueeze(2).unsqueeze(3) + 1
+ t_2 = t_2.unsqueeze(2).unsqueeze(3)
+
+ gn_1 = F.silu(self.gn_1(x))
+ upsample = F.upsample_nearest(gn_1, scale_factor=2)
+ f_1 = self.f_1(upsample)
+ gn_2 = self.gn_2(f_1)
+
+ f_2 = self.f_2(F.silu(t_2 + (t_1 * gn_2)))
+
+ return f_2 + F.upsample_nearest(x_skip, scale_factor=2)
+
+
+class ConvUNetVAE(nn.Module):
+ def __init__(self) -> None:
+ super().__init__()
+ self.embed_image = ImageEmbedding()
+ self.embed_time = TimestepEmbedding_()
+
+ down_0 = nn.ModuleList(
+ [
+ ConvResblock(320, 320),
+ ConvResblock(320, 320),
+ ConvResblock(320, 320),
+ Downsample(320),
+ ]
+ )
+ down_1 = nn.ModuleList(
+ [
+ ConvResblock(320, 640),
+ ConvResblock(640, 640),
+ ConvResblock(640, 640),
+ Downsample(640),
+ ]
+ )
+ down_2 = nn.ModuleList(
+ [
+ ConvResblock(640, 1024),
+ ConvResblock(1024, 1024),
+ ConvResblock(1024, 1024),
+ Downsample(1024),
+ ]
+ )
+ down_3 = nn.ModuleList(
+ [
+ ConvResblock(1024, 1024),
+ ConvResblock(1024, 1024),
+ ConvResblock(1024, 1024),
+ ]
+ )
+ self.down = nn.ModuleList(
+ [
+ down_0,
+ down_1,
+ down_2,
+ down_3,
+ ]
+ )
+
+ self.mid = nn.ModuleList(
+ [
+ ConvResblock(1024, 1024),
+ ConvResblock(1024, 1024),
+ ]
+ )
+
+ up_3 = nn.ModuleList(
+ [
+ ConvResblock(1024 * 2, 1024),
+ ConvResblock(1024 * 2, 1024),
+ ConvResblock(1024 * 2, 1024),
+ ConvResblock(1024 * 2, 1024),
+ Upsample(1024),
+ ]
+ )
+ up_2 = nn.ModuleList(
+ [
+ ConvResblock(1024 * 2, 1024),
+ ConvResblock(1024 * 2, 1024),
+ ConvResblock(1024 * 2, 1024),
+ ConvResblock(1024 + 640, 1024),
+ Upsample(1024),
+ ]
+ )
+ up_1 = nn.ModuleList(
+ [
+ ConvResblock(1024 + 640, 640),
+ ConvResblock(640 * 2, 640),
+ ConvResblock(640 * 2, 640),
+ ConvResblock(320 + 640, 640),
+ Upsample(640),
+ ]
+ )
+ up_0 = nn.ModuleList(
+ [
+ ConvResblock(320 + 640, 320),
+ ConvResblock(320 * 2, 320),
+ ConvResblock(320 * 2, 320),
+ ConvResblock(320 * 2, 320),
+ ]
+ )
+ self.up = nn.ModuleList(
+ [
+ up_0,
+ up_1,
+ up_2,
+ up_3,
+ ]
+ )
+
+ self.output = ImageUnembedding()
+
+ def forward(self, x, t, features) -> torch.Tensor:
+ converted = hasattr(self, "converted") and self.converted
+
+ x = torch.cat([x, F.upsample_nearest(features, scale_factor=8)], dim=1)
+
+ if converted:
+ t = self.time_embedding(self.time_proj(t))
+ else:
+ t = self.embed_time(t)
+
+ x = self.embed_image(x)
+
+ skips = [x]
+ for i, down in enumerate(self.down):
+ if converted and i in [0, 1, 2, 3]:
+ x, skips_ = down(x, t)
+ for skip in skips_:
+ skips.append(skip)
+ else:
+ for block in down:
+ x = block(x, t)
+ skips.append(x)
+ print(x.float().abs().sum())
+
+ if converted:
+ x = self.mid(x, t)
+ else:
+ for i in range(2):
+ x = self.mid[i](x, t)
+ print(x.float().abs().sum())
+
+ for i, up in enumerate(self.up[::-1]):
+ if converted and i in [0, 1, 2, 3]:
+ skip_4 = skips.pop()
+ skip_3 = skips.pop()
+ skip_2 = skips.pop()
+ skip_1 = skips.pop()
+ skips_ = (skip_1, skip_2, skip_3, skip_4)
+ x = up(x, skips_, t)
+ else:
+ for block in up:
+ if isinstance(block, ConvResblock):
+ x = torch.concat([x, skips.pop()], dim=1)
+ x = block(x, t)
+
+ return self.output(x)
+
+
+def rename_state_dict_key(k):
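+    # Translate the flat parameter names of the released safetensors file into ConvUNetVAE
+    # attribute paths (e.g. down_0_* -> down.0.*, *.w/*.b/*.g -> weight/bias).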
+ k = k.replace("blocks.", "")
+ for i in range(5):
+ k = k.replace(f"down_{i}_", f"down.{i}.")
+ k = k.replace(f"conv_{i}.", f"{i}.")
+ k = k.replace(f"up_{i}_", f"up.{i}.")
+ k = k.replace(f"mid_{i}", f"mid.{i}")
+ k = k.replace("upsamp.", "4.")
+ k = k.replace("downsamp.", "3.")
+ k = k.replace("f_t.w", "f_t.weight").replace("f_t.b", "f_t.bias")
+ k = k.replace("f_1.w", "f_1.weight").replace("f_1.b", "f_1.bias")
+ k = k.replace("f_2.w", "f_2.weight").replace("f_2.b", "f_2.bias")
+ k = k.replace("f_s.w", "f_s.weight").replace("f_s.b", "f_s.bias")
+ k = k.replace("f.w", "f.weight").replace("f.b", "f.bias")
+ k = k.replace("gn_1.g", "gn_1.weight").replace("gn_1.b", "gn_1.bias")
+ k = k.replace("gn_2.g", "gn_2.weight").replace("gn_2.b", "gn_2.bias")
+ k = k.replace("gn.g", "gn.weight").replace("gn.b", "gn.bias")
+ return k
+
+
+def rename_state_dict(sd, embedding):
+ sd = {rename_state_dict_key(k): v for k, v in sd.items()}
+ sd["embed_time.emb.weight"] = embedding["weight"]
+ return sd
+
+
+# encode with stable diffusion vae
+pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+pipe.vae.cuda()
+
+# construct original decoder with jitted model
+decoder_consistency = ConsistencyDecoder(device="cuda:0")
+
+# construct UNet code, overwrite the decoder with conv_unet_vae
+model = ConvUNetVAE()
+model.load_state_dict(
+ rename_state_dict(
+ stl("consistency_decoder.safetensors"),
+ stl("embedding.safetensors"),
+ )
+)
+model = model.cuda()
+
+decoder_consistency.ckpt = model
+
+image = load_image(args.test_image, size=(256, 256), center_crop=True)
+latent = pipe.vae.encode(image.half().cuda()).latent_dist.sample()
+
+# decode with gan
+sample_gan = pipe.vae.decode(latent).sample.detach()
+save_image(sample_gan, "gan.png")
+
+# decode with conv_unet_vae
+sample_consistency_orig = decoder_consistency(latent, generator=torch.Generator("cpu").manual_seed(0))
+save_image(sample_consistency_orig, "con_orig.png")
+
+
+########### conversion
+
+print("CONVERSION")
+
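+# Each original block is converted by re-keying its state_dict (gn_* -> norm*, f_* -> conv*/time_emb_proj)
+# and loading the result into the matching diffusers ResnetDownsample/Upsample/UNetMid block.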
+print("DOWN BLOCK ONE")
+
+block_one_sd_orig = model.down[0].state_dict()
+block_one_sd_new = {}
+
+for i in range(3):
+ block_one_sd_new[f"resnets.{i}.norm1.weight"] = block_one_sd_orig.pop(f"{i}.gn_1.weight")
+ block_one_sd_new[f"resnets.{i}.norm1.bias"] = block_one_sd_orig.pop(f"{i}.gn_1.bias")
+ block_one_sd_new[f"resnets.{i}.conv1.weight"] = block_one_sd_orig.pop(f"{i}.f_1.weight")
+ block_one_sd_new[f"resnets.{i}.conv1.bias"] = block_one_sd_orig.pop(f"{i}.f_1.bias")
+ block_one_sd_new[f"resnets.{i}.time_emb_proj.weight"] = block_one_sd_orig.pop(f"{i}.f_t.weight")
+ block_one_sd_new[f"resnets.{i}.time_emb_proj.bias"] = block_one_sd_orig.pop(f"{i}.f_t.bias")
+ block_one_sd_new[f"resnets.{i}.norm2.weight"] = block_one_sd_orig.pop(f"{i}.gn_2.weight")
+ block_one_sd_new[f"resnets.{i}.norm2.bias"] = block_one_sd_orig.pop(f"{i}.gn_2.bias")
+ block_one_sd_new[f"resnets.{i}.conv2.weight"] = block_one_sd_orig.pop(f"{i}.f_2.weight")
+ block_one_sd_new[f"resnets.{i}.conv2.bias"] = block_one_sd_orig.pop(f"{i}.f_2.bias")
+
+block_one_sd_new["downsamplers.0.norm1.weight"] = block_one_sd_orig.pop("3.gn_1.weight")
+block_one_sd_new["downsamplers.0.norm1.bias"] = block_one_sd_orig.pop("3.gn_1.bias")
+block_one_sd_new["downsamplers.0.conv1.weight"] = block_one_sd_orig.pop("3.f_1.weight")
+block_one_sd_new["downsamplers.0.conv1.bias"] = block_one_sd_orig.pop("3.f_1.bias")
+block_one_sd_new["downsamplers.0.time_emb_proj.weight"] = block_one_sd_orig.pop("3.f_t.weight")
+block_one_sd_new["downsamplers.0.time_emb_proj.bias"] = block_one_sd_orig.pop("3.f_t.bias")
+block_one_sd_new["downsamplers.0.norm2.weight"] = block_one_sd_orig.pop("3.gn_2.weight")
+block_one_sd_new["downsamplers.0.norm2.bias"] = block_one_sd_orig.pop("3.gn_2.bias")
+block_one_sd_new["downsamplers.0.conv2.weight"] = block_one_sd_orig.pop("3.f_2.weight")
+block_one_sd_new["downsamplers.0.conv2.bias"] = block_one_sd_orig.pop("3.f_2.bias")
+
+assert len(block_one_sd_orig) == 0
+
+block_one = ResnetDownsampleBlock2D(
+ in_channels=320,
+ out_channels=320,
+ temb_channels=1280,
+ num_layers=3,
+ add_downsample=True,
+ resnet_time_scale_shift="scale_shift",
+ resnet_eps=1e-5,
+)
+
+block_one.load_state_dict(block_one_sd_new)
+
+print("DOWN BLOCK TWO")
+
+block_two_sd_orig = model.down[1].state_dict()
+block_two_sd_new = {}
+
+for i in range(3):
+ block_two_sd_new[f"resnets.{i}.norm1.weight"] = block_two_sd_orig.pop(f"{i}.gn_1.weight")
+ block_two_sd_new[f"resnets.{i}.norm1.bias"] = block_two_sd_orig.pop(f"{i}.gn_1.bias")
+ block_two_sd_new[f"resnets.{i}.conv1.weight"] = block_two_sd_orig.pop(f"{i}.f_1.weight")
+ block_two_sd_new[f"resnets.{i}.conv1.bias"] = block_two_sd_orig.pop(f"{i}.f_1.bias")
+ block_two_sd_new[f"resnets.{i}.time_emb_proj.weight"] = block_two_sd_orig.pop(f"{i}.f_t.weight")
+ block_two_sd_new[f"resnets.{i}.time_emb_proj.bias"] = block_two_sd_orig.pop(f"{i}.f_t.bias")
+ block_two_sd_new[f"resnets.{i}.norm2.weight"] = block_two_sd_orig.pop(f"{i}.gn_2.weight")
+ block_two_sd_new[f"resnets.{i}.norm2.bias"] = block_two_sd_orig.pop(f"{i}.gn_2.bias")
+ block_two_sd_new[f"resnets.{i}.conv2.weight"] = block_two_sd_orig.pop(f"{i}.f_2.weight")
+ block_two_sd_new[f"resnets.{i}.conv2.bias"] = block_two_sd_orig.pop(f"{i}.f_2.bias")
+
+ if i == 0:
+ block_two_sd_new[f"resnets.{i}.conv_shortcut.weight"] = block_two_sd_orig.pop(f"{i}.f_s.weight")
+ block_two_sd_new[f"resnets.{i}.conv_shortcut.bias"] = block_two_sd_orig.pop(f"{i}.f_s.bias")
+
+block_two_sd_new["downsamplers.0.norm1.weight"] = block_two_sd_orig.pop("3.gn_1.weight")
+block_two_sd_new["downsamplers.0.norm1.bias"] = block_two_sd_orig.pop("3.gn_1.bias")
+block_two_sd_new["downsamplers.0.conv1.weight"] = block_two_sd_orig.pop("3.f_1.weight")
+block_two_sd_new["downsamplers.0.conv1.bias"] = block_two_sd_orig.pop("3.f_1.bias")
+block_two_sd_new["downsamplers.0.time_emb_proj.weight"] = block_two_sd_orig.pop("3.f_t.weight")
+block_two_sd_new["downsamplers.0.time_emb_proj.bias"] = block_two_sd_orig.pop("3.f_t.bias")
+block_two_sd_new["downsamplers.0.norm2.weight"] = block_two_sd_orig.pop("3.gn_2.weight")
+block_two_sd_new["downsamplers.0.norm2.bias"] = block_two_sd_orig.pop("3.gn_2.bias")
+block_two_sd_new["downsamplers.0.conv2.weight"] = block_two_sd_orig.pop("3.f_2.weight")
+block_two_sd_new["downsamplers.0.conv2.bias"] = block_two_sd_orig.pop("3.f_2.bias")
+
+assert len(block_two_sd_orig) == 0
+
+block_two = ResnetDownsampleBlock2D(
+ in_channels=320,
+ out_channels=640,
+ temb_channels=1280,
+ num_layers=3,
+ add_downsample=True,
+ resnet_time_scale_shift="scale_shift",
+ resnet_eps=1e-5,
+)
+
+block_two.load_state_dict(block_two_sd_new)
+
+print("DOWN BLOCK THREE")
+
+block_three_sd_orig = model.down[2].state_dict()
+block_three_sd_new = {}
+
+for i in range(3):
+ block_three_sd_new[f"resnets.{i}.norm1.weight"] = block_three_sd_orig.pop(f"{i}.gn_1.weight")
+ block_three_sd_new[f"resnets.{i}.norm1.bias"] = block_three_sd_orig.pop(f"{i}.gn_1.bias")
+ block_three_sd_new[f"resnets.{i}.conv1.weight"] = block_three_sd_orig.pop(f"{i}.f_1.weight")
+ block_three_sd_new[f"resnets.{i}.conv1.bias"] = block_three_sd_orig.pop(f"{i}.f_1.bias")
+ block_three_sd_new[f"resnets.{i}.time_emb_proj.weight"] = block_three_sd_orig.pop(f"{i}.f_t.weight")
+ block_three_sd_new[f"resnets.{i}.time_emb_proj.bias"] = block_three_sd_orig.pop(f"{i}.f_t.bias")
+ block_three_sd_new[f"resnets.{i}.norm2.weight"] = block_three_sd_orig.pop(f"{i}.gn_2.weight")
+ block_three_sd_new[f"resnets.{i}.norm2.bias"] = block_three_sd_orig.pop(f"{i}.gn_2.bias")
+ block_three_sd_new[f"resnets.{i}.conv2.weight"] = block_three_sd_orig.pop(f"{i}.f_2.weight")
+ block_three_sd_new[f"resnets.{i}.conv2.bias"] = block_three_sd_orig.pop(f"{i}.f_2.bias")
+
+ if i == 0:
+ block_three_sd_new[f"resnets.{i}.conv_shortcut.weight"] = block_three_sd_orig.pop(f"{i}.f_s.weight")
+ block_three_sd_new[f"resnets.{i}.conv_shortcut.bias"] = block_three_sd_orig.pop(f"{i}.f_s.bias")
+
+block_three_sd_new["downsamplers.0.norm1.weight"] = block_three_sd_orig.pop("3.gn_1.weight")
+block_three_sd_new["downsamplers.0.norm1.bias"] = block_three_sd_orig.pop("3.gn_1.bias")
+block_three_sd_new["downsamplers.0.conv1.weight"] = block_three_sd_orig.pop("3.f_1.weight")
+block_three_sd_new["downsamplers.0.conv1.bias"] = block_three_sd_orig.pop("3.f_1.bias")
+block_three_sd_new["downsamplers.0.time_emb_proj.weight"] = block_three_sd_orig.pop("3.f_t.weight")
+block_three_sd_new["downsamplers.0.time_emb_proj.bias"] = block_three_sd_orig.pop("3.f_t.bias")
+block_three_sd_new["downsamplers.0.norm2.weight"] = block_three_sd_orig.pop("3.gn_2.weight")
+block_three_sd_new["downsamplers.0.norm2.bias"] = block_three_sd_orig.pop("3.gn_2.bias")
+block_three_sd_new["downsamplers.0.conv2.weight"] = block_three_sd_orig.pop("3.f_2.weight")
+block_three_sd_new["downsamplers.0.conv2.bias"] = block_three_sd_orig.pop("3.f_2.bias")
+
+assert len(block_three_sd_orig) == 0
+
+block_three = ResnetDownsampleBlock2D(
+ in_channels=640,
+ out_channels=1024,
+ temb_channels=1280,
+ num_layers=3,
+ add_downsample=True,
+ resnet_time_scale_shift="scale_shift",
+ resnet_eps=1e-5,
+)
+
+block_three.load_state_dict(block_three_sd_new)
+
+print("DOWN BLOCK FOUR")
+
+block_four_sd_orig = model.down[3].state_dict()
+block_four_sd_new = {}
+
+for i in range(3):
+ block_four_sd_new[f"resnets.{i}.norm1.weight"] = block_four_sd_orig.pop(f"{i}.gn_1.weight")
+ block_four_sd_new[f"resnets.{i}.norm1.bias"] = block_four_sd_orig.pop(f"{i}.gn_1.bias")
+ block_four_sd_new[f"resnets.{i}.conv1.weight"] = block_four_sd_orig.pop(f"{i}.f_1.weight")
+ block_four_sd_new[f"resnets.{i}.conv1.bias"] = block_four_sd_orig.pop(f"{i}.f_1.bias")
+ block_four_sd_new[f"resnets.{i}.time_emb_proj.weight"] = block_four_sd_orig.pop(f"{i}.f_t.weight")
+ block_four_sd_new[f"resnets.{i}.time_emb_proj.bias"] = block_four_sd_orig.pop(f"{i}.f_t.bias")
+ block_four_sd_new[f"resnets.{i}.norm2.weight"] = block_four_sd_orig.pop(f"{i}.gn_2.weight")
+ block_four_sd_new[f"resnets.{i}.norm2.bias"] = block_four_sd_orig.pop(f"{i}.gn_2.bias")
+ block_four_sd_new[f"resnets.{i}.conv2.weight"] = block_four_sd_orig.pop(f"{i}.f_2.weight")
+ block_four_sd_new[f"resnets.{i}.conv2.bias"] = block_four_sd_orig.pop(f"{i}.f_2.bias")
+
+assert len(block_four_sd_orig) == 0
+
+block_four = ResnetDownsampleBlock2D(
+ in_channels=1024,
+ out_channels=1024,
+ temb_channels=1280,
+ num_layers=3,
+ add_downsample=False,
+ resnet_time_scale_shift="scale_shift",
+ resnet_eps=1e-5,
+)
+
+block_four.load_state_dict(block_four_sd_new)
+
+
+print("MID BLOCK 1")
+
+mid_block_one_sd_orig = model.mid.state_dict()
+mid_block_one_sd_new = {}
+
+for i in range(2):
+ mid_block_one_sd_new[f"resnets.{i}.norm1.weight"] = mid_block_one_sd_orig.pop(f"{i}.gn_1.weight")
+ mid_block_one_sd_new[f"resnets.{i}.norm1.bias"] = mid_block_one_sd_orig.pop(f"{i}.gn_1.bias")
+ mid_block_one_sd_new[f"resnets.{i}.conv1.weight"] = mid_block_one_sd_orig.pop(f"{i}.f_1.weight")
+ mid_block_one_sd_new[f"resnets.{i}.conv1.bias"] = mid_block_one_sd_orig.pop(f"{i}.f_1.bias")
+ mid_block_one_sd_new[f"resnets.{i}.time_emb_proj.weight"] = mid_block_one_sd_orig.pop(f"{i}.f_t.weight")
+ mid_block_one_sd_new[f"resnets.{i}.time_emb_proj.bias"] = mid_block_one_sd_orig.pop(f"{i}.f_t.bias")
+ mid_block_one_sd_new[f"resnets.{i}.norm2.weight"] = mid_block_one_sd_orig.pop(f"{i}.gn_2.weight")
+ mid_block_one_sd_new[f"resnets.{i}.norm2.bias"] = mid_block_one_sd_orig.pop(f"{i}.gn_2.bias")
+ mid_block_one_sd_new[f"resnets.{i}.conv2.weight"] = mid_block_one_sd_orig.pop(f"{i}.f_2.weight")
+ mid_block_one_sd_new[f"resnets.{i}.conv2.bias"] = mid_block_one_sd_orig.pop(f"{i}.f_2.bias")
+
+assert len(mid_block_one_sd_orig) == 0
+
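+# UNetMidBlock2D with num_layers=1 instantiates two resnets (one before and one after the attention slot),
+# matching the two ConvResblocks in the original mid stack; add_attention=False leaves the slot empty.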
+mid_block_one = UNetMidBlock2D(
+ in_channels=1024,
+ temb_channels=1280,
+ num_layers=1,
+ resnet_time_scale_shift="scale_shift",
+ resnet_eps=1e-5,
+ add_attention=False,
+)
+
+mid_block_one.load_state_dict(mid_block_one_sd_new)
+
+print("UP BLOCK ONE")
+
+up_block_one_sd_orig = model.up[-1].state_dict()
+up_block_one_sd_new = {}
+
+for i in range(4):
+ up_block_one_sd_new[f"resnets.{i}.norm1.weight"] = up_block_one_sd_orig.pop(f"{i}.gn_1.weight")
+ up_block_one_sd_new[f"resnets.{i}.norm1.bias"] = up_block_one_sd_orig.pop(f"{i}.gn_1.bias")
+ up_block_one_sd_new[f"resnets.{i}.conv1.weight"] = up_block_one_sd_orig.pop(f"{i}.f_1.weight")
+ up_block_one_sd_new[f"resnets.{i}.conv1.bias"] = up_block_one_sd_orig.pop(f"{i}.f_1.bias")
+ up_block_one_sd_new[f"resnets.{i}.time_emb_proj.weight"] = up_block_one_sd_orig.pop(f"{i}.f_t.weight")
+ up_block_one_sd_new[f"resnets.{i}.time_emb_proj.bias"] = up_block_one_sd_orig.pop(f"{i}.f_t.bias")
+ up_block_one_sd_new[f"resnets.{i}.norm2.weight"] = up_block_one_sd_orig.pop(f"{i}.gn_2.weight")
+ up_block_one_sd_new[f"resnets.{i}.norm2.bias"] = up_block_one_sd_orig.pop(f"{i}.gn_2.bias")
+ up_block_one_sd_new[f"resnets.{i}.conv2.weight"] = up_block_one_sd_orig.pop(f"{i}.f_2.weight")
+ up_block_one_sd_new[f"resnets.{i}.conv2.bias"] = up_block_one_sd_orig.pop(f"{i}.f_2.bias")
+ up_block_one_sd_new[f"resnets.{i}.conv_shortcut.weight"] = up_block_one_sd_orig.pop(f"{i}.f_s.weight")
+ up_block_one_sd_new[f"resnets.{i}.conv_shortcut.bias"] = up_block_one_sd_orig.pop(f"{i}.f_s.bias")
+
+up_block_one_sd_new["upsamplers.0.norm1.weight"] = up_block_one_sd_orig.pop("4.gn_1.weight")
+up_block_one_sd_new["upsamplers.0.norm1.bias"] = up_block_one_sd_orig.pop("4.gn_1.bias")
+up_block_one_sd_new["upsamplers.0.conv1.weight"] = up_block_one_sd_orig.pop("4.f_1.weight")
+up_block_one_sd_new["upsamplers.0.conv1.bias"] = up_block_one_sd_orig.pop("4.f_1.bias")
+up_block_one_sd_new["upsamplers.0.time_emb_proj.weight"] = up_block_one_sd_orig.pop("4.f_t.weight")
+up_block_one_sd_new["upsamplers.0.time_emb_proj.bias"] = up_block_one_sd_orig.pop("4.f_t.bias")
+up_block_one_sd_new["upsamplers.0.norm2.weight"] = up_block_one_sd_orig.pop("4.gn_2.weight")
+up_block_one_sd_new["upsamplers.0.norm2.bias"] = up_block_one_sd_orig.pop("4.gn_2.bias")
+up_block_one_sd_new["upsamplers.0.conv2.weight"] = up_block_one_sd_orig.pop("4.f_2.weight")
+up_block_one_sd_new["upsamplers.0.conv2.bias"] = up_block_one_sd_orig.pop("4.f_2.bias")
+
+assert len(up_block_one_sd_orig) == 0
+
+up_block_one = ResnetUpsampleBlock2D(
+ in_channels=1024,
+ prev_output_channel=1024,
+ out_channels=1024,
+ temb_channels=1280,
+ num_layers=4,
+ add_upsample=True,
+ resnet_time_scale_shift="scale_shift",
+ resnet_eps=1e-5,
+)
+
+up_block_one.load_state_dict(up_block_one_sd_new)
+
+print("UP BLOCK TWO")
+
+up_block_two_sd_orig = model.up[-2].state_dict()
+up_block_two_sd_new = {}
+
+for i in range(4):
+ up_block_two_sd_new[f"resnets.{i}.norm1.weight"] = up_block_two_sd_orig.pop(f"{i}.gn_1.weight")
+ up_block_two_sd_new[f"resnets.{i}.norm1.bias"] = up_block_two_sd_orig.pop(f"{i}.gn_1.bias")
+ up_block_two_sd_new[f"resnets.{i}.conv1.weight"] = up_block_two_sd_orig.pop(f"{i}.f_1.weight")
+ up_block_two_sd_new[f"resnets.{i}.conv1.bias"] = up_block_two_sd_orig.pop(f"{i}.f_1.bias")
+ up_block_two_sd_new[f"resnets.{i}.time_emb_proj.weight"] = up_block_two_sd_orig.pop(f"{i}.f_t.weight")
+ up_block_two_sd_new[f"resnets.{i}.time_emb_proj.bias"] = up_block_two_sd_orig.pop(f"{i}.f_t.bias")
+ up_block_two_sd_new[f"resnets.{i}.norm2.weight"] = up_block_two_sd_orig.pop(f"{i}.gn_2.weight")
+ up_block_two_sd_new[f"resnets.{i}.norm2.bias"] = up_block_two_sd_orig.pop(f"{i}.gn_2.bias")
+ up_block_two_sd_new[f"resnets.{i}.conv2.weight"] = up_block_two_sd_orig.pop(f"{i}.f_2.weight")
+ up_block_two_sd_new[f"resnets.{i}.conv2.bias"] = up_block_two_sd_orig.pop(f"{i}.f_2.bias")
+ up_block_two_sd_new[f"resnets.{i}.conv_shortcut.weight"] = up_block_two_sd_orig.pop(f"{i}.f_s.weight")
+ up_block_two_sd_new[f"resnets.{i}.conv_shortcut.bias"] = up_block_two_sd_orig.pop(f"{i}.f_s.bias")
+
+up_block_two_sd_new["upsamplers.0.norm1.weight"] = up_block_two_sd_orig.pop("4.gn_1.weight")
+up_block_two_sd_new["upsamplers.0.norm1.bias"] = up_block_two_sd_orig.pop("4.gn_1.bias")
+up_block_two_sd_new["upsamplers.0.conv1.weight"] = up_block_two_sd_orig.pop("4.f_1.weight")
+up_block_two_sd_new["upsamplers.0.conv1.bias"] = up_block_two_sd_orig.pop("4.f_1.bias")
+up_block_two_sd_new["upsamplers.0.time_emb_proj.weight"] = up_block_two_sd_orig.pop("4.f_t.weight")
+up_block_two_sd_new["upsamplers.0.time_emb_proj.bias"] = up_block_two_sd_orig.pop("4.f_t.bias")
+up_block_two_sd_new["upsamplers.0.norm2.weight"] = up_block_two_sd_orig.pop("4.gn_2.weight")
+up_block_two_sd_new["upsamplers.0.norm2.bias"] = up_block_two_sd_orig.pop("4.gn_2.bias")
+up_block_two_sd_new["upsamplers.0.conv2.weight"] = up_block_two_sd_orig.pop("4.f_2.weight")
+up_block_two_sd_new["upsamplers.0.conv2.bias"] = up_block_two_sd_orig.pop("4.f_2.bias")
+
+assert len(up_block_two_sd_orig) == 0
+
+up_block_two = ResnetUpsampleBlock2D(
+ in_channels=640,
+ prev_output_channel=1024,
+ out_channels=1024,
+ temb_channels=1280,
+ num_layers=4,
+ add_upsample=True,
+ resnet_time_scale_shift="scale_shift",
+ resnet_eps=1e-5,
+)
+
+up_block_two.load_state_dict(up_block_two_sd_new)
+
+print("UP BLOCK THREE")
+
+up_block_three_sd_orig = model.up[-3].state_dict()
+up_block_three_sd_new = {}
+
+for i in range(4):
+ up_block_three_sd_new[f"resnets.{i}.norm1.weight"] = up_block_three_sd_orig.pop(f"{i}.gn_1.weight")
+ up_block_three_sd_new[f"resnets.{i}.norm1.bias"] = up_block_three_sd_orig.pop(f"{i}.gn_1.bias")
+ up_block_three_sd_new[f"resnets.{i}.conv1.weight"] = up_block_three_sd_orig.pop(f"{i}.f_1.weight")
+ up_block_three_sd_new[f"resnets.{i}.conv1.bias"] = up_block_three_sd_orig.pop(f"{i}.f_1.bias")
+ up_block_three_sd_new[f"resnets.{i}.time_emb_proj.weight"] = up_block_three_sd_orig.pop(f"{i}.f_t.weight")
+ up_block_three_sd_new[f"resnets.{i}.time_emb_proj.bias"] = up_block_three_sd_orig.pop(f"{i}.f_t.bias")
+ up_block_three_sd_new[f"resnets.{i}.norm2.weight"] = up_block_three_sd_orig.pop(f"{i}.gn_2.weight")
+ up_block_three_sd_new[f"resnets.{i}.norm2.bias"] = up_block_three_sd_orig.pop(f"{i}.gn_2.bias")
+ up_block_three_sd_new[f"resnets.{i}.conv2.weight"] = up_block_three_sd_orig.pop(f"{i}.f_2.weight")
+ up_block_three_sd_new[f"resnets.{i}.conv2.bias"] = up_block_three_sd_orig.pop(f"{i}.f_2.bias")
+ up_block_three_sd_new[f"resnets.{i}.conv_shortcut.weight"] = up_block_three_sd_orig.pop(f"{i}.f_s.weight")
+ up_block_three_sd_new[f"resnets.{i}.conv_shortcut.bias"] = up_block_three_sd_orig.pop(f"{i}.f_s.bias")
+
+up_block_three_sd_new["upsamplers.0.norm1.weight"] = up_block_three_sd_orig.pop("4.gn_1.weight")
+up_block_three_sd_new["upsamplers.0.norm1.bias"] = up_block_three_sd_orig.pop("4.gn_1.bias")
+up_block_three_sd_new["upsamplers.0.conv1.weight"] = up_block_three_sd_orig.pop("4.f_1.weight")
+up_block_three_sd_new["upsamplers.0.conv1.bias"] = up_block_three_sd_orig.pop("4.f_1.bias")
+up_block_three_sd_new["upsamplers.0.time_emb_proj.weight"] = up_block_three_sd_orig.pop("4.f_t.weight")
+up_block_three_sd_new["upsamplers.0.time_emb_proj.bias"] = up_block_three_sd_orig.pop("4.f_t.bias")
+up_block_three_sd_new["upsamplers.0.norm2.weight"] = up_block_three_sd_orig.pop("4.gn_2.weight")
+up_block_three_sd_new["upsamplers.0.norm2.bias"] = up_block_three_sd_orig.pop("4.gn_2.bias")
+up_block_three_sd_new["upsamplers.0.conv2.weight"] = up_block_three_sd_orig.pop("4.f_2.weight")
+up_block_three_sd_new["upsamplers.0.conv2.bias"] = up_block_three_sd_orig.pop("4.f_2.bias")
+
+assert len(up_block_three_sd_orig) == 0
+
+up_block_three = ResnetUpsampleBlock2D(
+ in_channels=320,
+ prev_output_channel=1024,
+ out_channels=640,
+ temb_channels=1280,
+ num_layers=4,
+ add_upsample=True,
+ resnet_time_scale_shift="scale_shift",
+ resnet_eps=1e-5,
+)
+
+up_block_three.load_state_dict(up_block_three_sd_new)
+
+print("UP BLOCK FOUR")
+
+up_block_four_sd_orig = model.up[-4].state_dict()
+up_block_four_sd_new = {}
+
+for i in range(4):
+ up_block_four_sd_new[f"resnets.{i}.norm1.weight"] = up_block_four_sd_orig.pop(f"{i}.gn_1.weight")
+ up_block_four_sd_new[f"resnets.{i}.norm1.bias"] = up_block_four_sd_orig.pop(f"{i}.gn_1.bias")
+ up_block_four_sd_new[f"resnets.{i}.conv1.weight"] = up_block_four_sd_orig.pop(f"{i}.f_1.weight")
+ up_block_four_sd_new[f"resnets.{i}.conv1.bias"] = up_block_four_sd_orig.pop(f"{i}.f_1.bias")
+ up_block_four_sd_new[f"resnets.{i}.time_emb_proj.weight"] = up_block_four_sd_orig.pop(f"{i}.f_t.weight")
+ up_block_four_sd_new[f"resnets.{i}.time_emb_proj.bias"] = up_block_four_sd_orig.pop(f"{i}.f_t.bias")
+ up_block_four_sd_new[f"resnets.{i}.norm2.weight"] = up_block_four_sd_orig.pop(f"{i}.gn_2.weight")
+ up_block_four_sd_new[f"resnets.{i}.norm2.bias"] = up_block_four_sd_orig.pop(f"{i}.gn_2.bias")
+ up_block_four_sd_new[f"resnets.{i}.conv2.weight"] = up_block_four_sd_orig.pop(f"{i}.f_2.weight")
+ up_block_four_sd_new[f"resnets.{i}.conv2.bias"] = up_block_four_sd_orig.pop(f"{i}.f_2.bias")
+ up_block_four_sd_new[f"resnets.{i}.conv_shortcut.weight"] = up_block_four_sd_orig.pop(f"{i}.f_s.weight")
+ up_block_four_sd_new[f"resnets.{i}.conv_shortcut.bias"] = up_block_four_sd_orig.pop(f"{i}.f_s.bias")
+
+assert len(up_block_four_sd_orig) == 0
+
+up_block_four = ResnetUpsampleBlock2D(
+ in_channels=320,
+ prev_output_channel=640,
+ out_channels=320,
+ temb_channels=1280,
+ num_layers=4,
+ add_upsample=False,
+ resnet_time_scale_shift="scale_shift",
+ resnet_eps=1e-5,
+)
+
+up_block_four.load_state_dict(up_block_four_sd_new)
+
+print("initial projection (conv_in)")
+
+conv_in_sd_orig = model.embed_image.state_dict()
+conv_in_sd_new = {}
+
+conv_in_sd_new["weight"] = conv_in_sd_orig.pop("f.weight")
+conv_in_sd_new["bias"] = conv_in_sd_orig.pop("f.bias")
+
+assert len(conv_in_sd_orig) == 0
+
+block_out_channels = [320, 640, 1024, 1024]
+
+in_channels = 7
+conv_in_kernel = 3
+conv_in_padding = (conv_in_kernel - 1) // 2
+conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding)
+
+conv_in.load_state_dict(conv_in_sd_new)
+
+print("out projection (conv_out) (conv_norm_out)")
+out_channels = 6
+norm_num_groups = 32
+norm_eps = 1e-5
+act_fn = "silu"
+conv_out_kernel = 3
+conv_out_padding = (conv_out_kernel - 1) // 2
+conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
+# The original model applies the activation via torch.nn.functional, so no conv_act module is built here.
+# conv_act = get_activation(act_fn)
+conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding)
+
+conv_norm_out.load_state_dict(model.output.gn.state_dict())
+conv_out.load_state_dict(model.output.f.state_dict())
+
+print("timestep projection (time_proj) (time_embedding)")
+
+f1_sd = model.embed_time.f_1.state_dict()
+f2_sd = model.embed_time.f_2.state_dict()
+
+time_embedding_sd = {
+ "linear_1.weight": f1_sd.pop("weight"),
+ "linear_1.bias": f1_sd.pop("bias"),
+ "linear_2.weight": f2_sd.pop("weight"),
+ "linear_2.bias": f2_sd.pop("bias"),
+}
+
+assert len(f1_sd) == 0
+assert len(f2_sd) == 0
+
+time_embedding_type = "learned"
+num_train_timesteps = 1024
+time_embedding_dim = 1280
+
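+# Note: the original decoder embeds the timestep with a learned lookup table (nn.Embedding over
+# 1024 steps) rather than a sinusoidal projection, which is why the UNet2DModel built below is
+# configured with time_embedding_type="learned".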
+time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0])
+timestep_input_dim = block_out_channels[0]
+
+time_embedding = TimestepEmbedding(timestep_input_dim, time_embedding_dim)
+
+time_proj.load_state_dict(model.embed_time.emb.state_dict())
+time_embedding.load_state_dict(time_embedding_sd)
+
+print("CONVERT")
+
+time_embedding.to("cuda")
+time_proj.to("cuda")
+conv_in.to("cuda")
+
+block_one.to("cuda")
+block_two.to("cuda")
+block_three.to("cuda")
+block_four.to("cuda")
+
+mid_block_one.to("cuda")
+
+up_block_one.to("cuda")
+up_block_two.to("cuda")
+up_block_three.to("cuda")
+up_block_four.to("cuda")
+
+conv_norm_out.to("cuda")
+conv_out.to("cuda")
+
+model.time_proj = time_proj
+model.time_embedding = time_embedding
+model.embed_image = conv_in
+
+model.down[0] = block_one
+model.down[1] = block_two
+model.down[2] = block_three
+model.down[3] = block_four
+
+model.mid = mid_block_one
+
+model.up[-1] = up_block_one
+model.up[-2] = up_block_two
+model.up[-3] = up_block_three
+model.up[-4] = up_block_four
+
+model.output.gn = conv_norm_out
+model.output.f = conv_out
+
+model.converted = True
+
+sample_consistency_new = decoder_consistency(latent, generator=torch.Generator("cpu").manual_seed(0))
+save_image(sample_consistency_new, "con_new.png")
+
+assert (sample_consistency_orig == sample_consistency_new).all()
+
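+# Up to here the converted blocks were swapped into the original model in place and the output was
+# checked to match exactly. Next, the same weights are packed into a standalone diffusers
+# UNet2DModel whose config mirrors the hand-built blocks above.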
+print("making unet")
+
+unet = UNet2DModel(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ down_block_types=(
+ "ResnetDownsampleBlock2D",
+ "ResnetDownsampleBlock2D",
+ "ResnetDownsampleBlock2D",
+ "ResnetDownsampleBlock2D",
+ ),
+ up_block_types=(
+ "ResnetUpsampleBlock2D",
+ "ResnetUpsampleBlock2D",
+ "ResnetUpsampleBlock2D",
+ "ResnetUpsampleBlock2D",
+ ),
+ block_out_channels=block_out_channels,
+ layers_per_block=3,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ resnet_time_scale_shift="scale_shift",
+ time_embedding_type="learned",
+ num_train_timesteps=num_train_timesteps,
+ add_attention=False,
+)
+
+unet_state_dict = {}
+
+
+def add_state_dict(prefix, mod):
+ for k, v in mod.state_dict().items():
+ unet_state_dict[f"{prefix}.{k}"] = v
+
+
+add_state_dict("conv_in", conv_in)
+add_state_dict("time_proj", time_proj)
+add_state_dict("time_embedding", time_embedding)
+add_state_dict("down_blocks.0", block_one)
+add_state_dict("down_blocks.1", block_two)
+add_state_dict("down_blocks.2", block_three)
+add_state_dict("down_blocks.3", block_four)
+add_state_dict("mid_block", mid_block_one)
+add_state_dict("up_blocks.0", up_block_one)
+add_state_dict("up_blocks.1", up_block_two)
+add_state_dict("up_blocks.2", up_block_three)
+add_state_dict("up_blocks.3", up_block_four)
+add_state_dict("conv_norm_out", conv_norm_out)
+add_state_dict("conv_out", conv_out)
+
+unet.load_state_dict(unet_state_dict)
+
+print("running with diffusers unet")
+
+unet.to("cuda")
+
+decoder_consistency.ckpt = unet
+
+sample_consistency_new_2 = decoder_consistency(latent, generator=torch.Generator("cpu").manual_seed(0))
+save_image(sample_consistency_new_2, "con_new_2.png")
+
+assert (sample_consistency_orig == sample_consistency_new_2).all()
+
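+# Finally, assemble a ConsistencyDecoderVAE: the encoder and quant_conv come from the stock SD v1-5
+# VAE, and the converted UNet becomes decoder_unet. Since this model is cast to float16, the result
+# is compared by max/total absolute difference instead of exact equality.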
+print("running with diffusers model")
+
+Encoder.old_constructor = Encoder.__init__
+
+
+def new_constructor(self, **kwargs):
+ self.old_constructor(**kwargs)
+ self.constructor_arguments = kwargs
+
+
+Encoder.__init__ = new_constructor
+
+
+vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
+consistency_vae = ConsistencyDecoderVAE(
+ encoder_args=vae.encoder.constructor_arguments,
+ decoder_args=unet.config,
+ scaling_factor=vae.config.scaling_factor,
+ block_out_channels=vae.config.block_out_channels,
+ latent_channels=vae.config.latent_channels,
+)
+consistency_vae.encoder.load_state_dict(vae.encoder.state_dict())
+consistency_vae.quant_conv.load_state_dict(vae.quant_conv.state_dict())
+consistency_vae.decoder_unet.load_state_dict(unet.state_dict())
+
+consistency_vae.to(dtype=torch.float16, device="cuda")
+
+sample_consistency_new_3 = consistency_vae.decode(
+ 0.18215 * latent, generator=torch.Generator("cpu").manual_seed(0)
+).sample
+
+print("max difference")
+print((sample_consistency_orig - sample_consistency_new_3).abs().max())
+print("total difference")
+print((sample_consistency_orig - sample_consistency_new_3).abs().sum())
+# assert (sample_consistency_orig == sample_consistency_new_3).all()
+
+print("running with diffusers pipeline")
+
+pipe = DiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", vae=consistency_vae, torch_dtype=torch.float16
+)
+pipe.to("cuda")
+
+pipe("horse", generator=torch.Generator("cpu").manual_seed(0)).images[0].save("horse.png")
+
+
+if args.save_pretrained is not None:
+ consistency_vae.save_pretrained(args.save_pretrained)
diff --git a/diffusers/scripts/convert_consistency_to_diffusers.py b/diffusers/scripts/convert_consistency_to_diffusers.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f8b4ddca8efd5e4392754f12b67448adad8c0b7
--- /dev/null
+++ b/diffusers/scripts/convert_consistency_to_diffusers.py
@@ -0,0 +1,315 @@
+import argparse
+import os
+
+import torch
+
+from diffusers import (
+ CMStochasticIterativeScheduler,
+ ConsistencyModelPipeline,
+ UNet2DModel,
+)
+
+
+TEST_UNET_CONFIG = {
+ "sample_size": 32,
+ "in_channels": 3,
+ "out_channels": 3,
+ "layers_per_block": 2,
+ "num_class_embeds": 1000,
+ "block_out_channels": [32, 64],
+ "attention_head_dim": 8,
+ "down_block_types": [
+ "ResnetDownsampleBlock2D",
+ "AttnDownBlock2D",
+ ],
+ "up_block_types": [
+ "AttnUpBlock2D",
+ "ResnetUpsampleBlock2D",
+ ],
+ "resnet_time_scale_shift": "scale_shift",
+ "attn_norm_num_groups": 32,
+ "upsample_type": "resnet",
+ "downsample_type": "resnet",
+}
+
+IMAGENET_64_UNET_CONFIG = {
+ "sample_size": 64,
+ "in_channels": 3,
+ "out_channels": 3,
+ "layers_per_block": 3,
+ "num_class_embeds": 1000,
+ "block_out_channels": [192, 192 * 2, 192 * 3, 192 * 4],
+ "attention_head_dim": 64,
+ "down_block_types": [
+ "ResnetDownsampleBlock2D",
+ "AttnDownBlock2D",
+ "AttnDownBlock2D",
+ "AttnDownBlock2D",
+ ],
+ "up_block_types": [
+ "AttnUpBlock2D",
+ "AttnUpBlock2D",
+ "AttnUpBlock2D",
+ "ResnetUpsampleBlock2D",
+ ],
+ "resnet_time_scale_shift": "scale_shift",
+ "attn_norm_num_groups": 32,
+ "upsample_type": "resnet",
+ "downsample_type": "resnet",
+}
+
+LSUN_256_UNET_CONFIG = {
+ "sample_size": 256,
+ "in_channels": 3,
+ "out_channels": 3,
+ "layers_per_block": 2,
+ "num_class_embeds": None,
+ "block_out_channels": [256, 256, 256 * 2, 256 * 2, 256 * 4, 256 * 4],
+ "attention_head_dim": 64,
+ "down_block_types": [
+ "ResnetDownsampleBlock2D",
+ "ResnetDownsampleBlock2D",
+ "ResnetDownsampleBlock2D",
+ "AttnDownBlock2D",
+ "AttnDownBlock2D",
+ "AttnDownBlock2D",
+ ],
+ "up_block_types": [
+ "AttnUpBlock2D",
+ "AttnUpBlock2D",
+ "AttnUpBlock2D",
+ "ResnetUpsampleBlock2D",
+ "ResnetUpsampleBlock2D",
+ "ResnetUpsampleBlock2D",
+ ],
+ "resnet_time_scale_shift": "default",
+ "upsample_type": "resnet",
+ "downsample_type": "resnet",
+}
+
+CD_SCHEDULER_CONFIG = {
+ "num_train_timesteps": 40,
+ "sigma_min": 0.002,
+ "sigma_max": 80.0,
+}
+
+CT_IMAGENET_64_SCHEDULER_CONFIG = {
+ "num_train_timesteps": 201,
+ "sigma_min": 0.002,
+ "sigma_max": 80.0,
+}
+
+CT_LSUN_256_SCHEDULER_CONFIG = {
+ "num_train_timesteps": 151,
+ "sigma_min": 0.002,
+ "sigma_max": 80.0,
+}
+
+
+def str2bool(v):
+ """
+ https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
+ """
+ if isinstance(v, bool):
+ return v
+ if v.lower() in ("yes", "true", "t", "y", "1"):
+ return True
+ elif v.lower() in ("no", "false", "f", "n", "0"):
+ return False
+ else:
+ raise argparse.ArgumentTypeError("boolean value expected")
+
+
+def convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=False):
+ new_checkpoint[f"{new_prefix}.norm1.weight"] = checkpoint[f"{old_prefix}.in_layers.0.weight"]
+ new_checkpoint[f"{new_prefix}.norm1.bias"] = checkpoint[f"{old_prefix}.in_layers.0.bias"]
+ new_checkpoint[f"{new_prefix}.conv1.weight"] = checkpoint[f"{old_prefix}.in_layers.2.weight"]
+ new_checkpoint[f"{new_prefix}.conv1.bias"] = checkpoint[f"{old_prefix}.in_layers.2.bias"]
+ new_checkpoint[f"{new_prefix}.time_emb_proj.weight"] = checkpoint[f"{old_prefix}.emb_layers.1.weight"]
+ new_checkpoint[f"{new_prefix}.time_emb_proj.bias"] = checkpoint[f"{old_prefix}.emb_layers.1.bias"]
+ new_checkpoint[f"{new_prefix}.norm2.weight"] = checkpoint[f"{old_prefix}.out_layers.0.weight"]
+ new_checkpoint[f"{new_prefix}.norm2.bias"] = checkpoint[f"{old_prefix}.out_layers.0.bias"]
+ new_checkpoint[f"{new_prefix}.conv2.weight"] = checkpoint[f"{old_prefix}.out_layers.3.weight"]
+ new_checkpoint[f"{new_prefix}.conv2.bias"] = checkpoint[f"{old_prefix}.out_layers.3.bias"]
+
+ if has_skip:
+ new_checkpoint[f"{new_prefix}.conv_shortcut.weight"] = checkpoint[f"{old_prefix}.skip_connection.weight"]
+ new_checkpoint[f"{new_prefix}.conv_shortcut.bias"] = checkpoint[f"{old_prefix}.skip_connection.bias"]
+
+ return new_checkpoint
+
+
+def convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_dim=None):
+ weight_q, weight_k, weight_v = checkpoint[f"{old_prefix}.qkv.weight"].chunk(3, dim=0)
+ bias_q, bias_k, bias_v = checkpoint[f"{old_prefix}.qkv.bias"].chunk(3, dim=0)
+
+ new_checkpoint[f"{new_prefix}.group_norm.weight"] = checkpoint[f"{old_prefix}.norm.weight"]
+ new_checkpoint[f"{new_prefix}.group_norm.bias"] = checkpoint[f"{old_prefix}.norm.bias"]
+
+ new_checkpoint[f"{new_prefix}.to_q.weight"] = weight_q.squeeze(-1).squeeze(-1)
+ new_checkpoint[f"{new_prefix}.to_q.bias"] = bias_q.squeeze(-1).squeeze(-1)
+ new_checkpoint[f"{new_prefix}.to_k.weight"] = weight_k.squeeze(-1).squeeze(-1)
+ new_checkpoint[f"{new_prefix}.to_k.bias"] = bias_k.squeeze(-1).squeeze(-1)
+ new_checkpoint[f"{new_prefix}.to_v.weight"] = weight_v.squeeze(-1).squeeze(-1)
+ new_checkpoint[f"{new_prefix}.to_v.bias"] = bias_v.squeeze(-1).squeeze(-1)
+
+ new_checkpoint[f"{new_prefix}.to_out.0.weight"] = (
+ checkpoint[f"{old_prefix}.proj_out.weight"].squeeze(-1).squeeze(-1)
+ )
+ new_checkpoint[f"{new_prefix}.to_out.0.bias"] = checkpoint[f"{old_prefix}.proj_out.bias"].squeeze(-1).squeeze(-1)
+
+ return new_checkpoint
+
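+# convert_attention splits the fused conv-shaped qkv projection into separate to_q/to_k/to_v linear
+# weights (the trailing kernel dimensions are squeezed away) and maps proj_out to to_out.0; the
+# attention_dim argument is currently unused.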
+
+def con_pt_to_diffuser(checkpoint_path: str, unet_config):
+ checkpoint = torch.load(checkpoint_path, map_location="cpu")
+ new_checkpoint = {}
+
+ new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["time_embed.0.weight"]
+ new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["time_embed.0.bias"]
+ new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["time_embed.2.weight"]
+ new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["time_embed.2.bias"]
+
+ if unet_config["num_class_embeds"] is not None:
+ new_checkpoint["class_embedding.weight"] = checkpoint["label_emb.weight"]
+
+ new_checkpoint["conv_in.weight"] = checkpoint["input_blocks.0.0.weight"]
+ new_checkpoint["conv_in.bias"] = checkpoint["input_blocks.0.0.bias"]
+
+ down_block_types = unet_config["down_block_types"]
+ layers_per_block = unet_config["layers_per_block"]
+ attention_head_dim = unet_config["attention_head_dim"]
+ channels_list = unet_config["block_out_channels"]
+ current_layer = 1
+ prev_channels = channels_list[0]
+
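+    # current_layer walks the flat input_blocks list of the original checkpoint; a resnet only gets
+    # a conv_shortcut when it is the first layer of a block that changes the channel width.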
+ for i, layer_type in enumerate(down_block_types):
+ current_channels = channels_list[i]
+ downsample_block_has_skip = current_channels != prev_channels
+ if layer_type == "ResnetDownsampleBlock2D":
+ for j in range(layers_per_block):
+ new_prefix = f"down_blocks.{i}.resnets.{j}"
+ old_prefix = f"input_blocks.{current_layer}.0"
+                has_skip = j == 0 and downsample_block_has_skip
+ new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=has_skip)
+ current_layer += 1
+
+ elif layer_type == "AttnDownBlock2D":
+ for j in range(layers_per_block):
+ new_prefix = f"down_blocks.{i}.resnets.{j}"
+ old_prefix = f"input_blocks.{current_layer}.0"
+                has_skip = j == 0 and downsample_block_has_skip
+ new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=has_skip)
+ new_prefix = f"down_blocks.{i}.attentions.{j}"
+ old_prefix = f"input_blocks.{current_layer}.1"
+ new_checkpoint = convert_attention(
+ checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim
+ )
+ current_layer += 1
+
+ if i != len(down_block_types) - 1:
+ new_prefix = f"down_blocks.{i}.downsamplers.0"
+ old_prefix = f"input_blocks.{current_layer}.0"
+ new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
+ current_layer += 1
+
+ prev_channels = current_channels
+
+ # hardcoded the mid-block for now
+ new_prefix = "mid_block.resnets.0"
+ old_prefix = "middle_block.0"
+ new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
+ new_prefix = "mid_block.attentions.0"
+ old_prefix = "middle_block.1"
+ new_checkpoint = convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim)
+ new_prefix = "mid_block.resnets.1"
+ old_prefix = "middle_block.2"
+ new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
+
+ current_layer = 0
+ up_block_types = unet_config["up_block_types"]
+
+ for i, layer_type in enumerate(up_block_types):
+ if layer_type == "ResnetUpsampleBlock2D":
+ for j in range(layers_per_block + 1):
+ new_prefix = f"up_blocks.{i}.resnets.{j}"
+ old_prefix = f"output_blocks.{current_layer}.0"
+ new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=True)
+ current_layer += 1
+
+ if i != len(up_block_types) - 1:
+ new_prefix = f"up_blocks.{i}.upsamplers.0"
+ old_prefix = f"output_blocks.{current_layer-1}.1"
+ new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
+ elif layer_type == "AttnUpBlock2D":
+ for j in range(layers_per_block + 1):
+ new_prefix = f"up_blocks.{i}.resnets.{j}"
+ old_prefix = f"output_blocks.{current_layer}.0"
+ new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=True)
+ new_prefix = f"up_blocks.{i}.attentions.{j}"
+ old_prefix = f"output_blocks.{current_layer}.1"
+ new_checkpoint = convert_attention(
+ checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim
+ )
+ current_layer += 1
+
+ if i != len(up_block_types) - 1:
+ new_prefix = f"up_blocks.{i}.upsamplers.0"
+ old_prefix = f"output_blocks.{current_layer-1}.2"
+ new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
+
+ new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"]
+ new_checkpoint["conv_norm_out.bias"] = checkpoint["out.0.bias"]
+ new_checkpoint["conv_out.weight"] = checkpoint["out.2.weight"]
+ new_checkpoint["conv_out.bias"] = checkpoint["out.2.bias"]
+
+ return new_checkpoint
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--unet_path", default=None, type=str, required=True, help="Path to the unet.pt to convert.")
+ parser.add_argument(
+ "--dump_path", default=None, type=str, required=True, help="Path to output the converted UNet model."
+ )
+ parser.add_argument("--class_cond", default=True, type=str, help="Whether the model is class-conditional.")
+
+ args = parser.parse_args()
+ args.class_cond = str2bool(args.class_cond)
+
+ ckpt_name = os.path.basename(args.unet_path)
+ print(f"Checkpoint: {ckpt_name}")
+
+ # Get U-Net config
+ if "imagenet64" in ckpt_name:
+ unet_config = IMAGENET_64_UNET_CONFIG
+ elif "256" in ckpt_name and (("bedroom" in ckpt_name) or ("cat" in ckpt_name)):
+ unet_config = LSUN_256_UNET_CONFIG
+ elif "test" in ckpt_name:
+ unet_config = TEST_UNET_CONFIG
+ else:
+ raise ValueError(f"Checkpoint type {ckpt_name} is not currently supported.")
+
+ if not args.class_cond:
+ unet_config["num_class_embeds"] = None
+
+ converted_unet_ckpt = con_pt_to_diffuser(args.unet_path, unet_config)
+
+ image_unet = UNet2DModel(**unet_config)
+ image_unet.load_state_dict(converted_unet_ckpt)
+
+ # Get scheduler config
+ if "cd" in ckpt_name or "test" in ckpt_name:
+ scheduler_config = CD_SCHEDULER_CONFIG
+ elif "ct" in ckpt_name and "imagenet64" in ckpt_name:
+ scheduler_config = CT_IMAGENET_64_SCHEDULER_CONFIG
+ elif "ct" in ckpt_name and "256" in ckpt_name and (("bedroom" in ckpt_name) or ("cat" in ckpt_name)):
+ scheduler_config = CT_LSUN_256_SCHEDULER_CONFIG
+ else:
+ raise ValueError(f"Checkpoint type {ckpt_name} is not currently supported.")
+
+ cm_scheduler = CMStochasticIterativeScheduler(**scheduler_config)
+
+ consistency_model = ConsistencyModelPipeline(unet=image_unet, scheduler=cm_scheduler)
+ consistency_model.save_pretrained(args.dump_path)
diff --git a/diffusers/scripts/convert_dance_diffusion_to_diffusers.py b/diffusers/scripts/convert_dance_diffusion_to_diffusers.py
new file mode 100644
index 0000000000000000000000000000000000000000..d53d1f792e89be30e26cd701c178083e94699f00
--- /dev/null
+++ b/diffusers/scripts/convert_dance_diffusion_to_diffusers.py
@@ -0,0 +1,339 @@
+#!/usr/bin/env python3
+import argparse
+import math
+import os
+from copy import deepcopy
+
+import torch
+from audio_diffusion.models import DiffusionAttnUnet1D
+from diffusion import sampling
+from torch import nn
+
+from diffusers import DanceDiffusionPipeline, IPNDMScheduler, UNet1DModel
+
+
+MODELS_MAP = {
+ "gwf-440k": {
+ "url": "https://model-server.zqevans2.workers.dev/gwf-440k.ckpt",
+ "sample_rate": 48000,
+ "sample_size": 65536,
+ },
+ "jmann-small-190k": {
+ "url": "https://model-server.zqevans2.workers.dev/jmann-small-190k.ckpt",
+ "sample_rate": 48000,
+ "sample_size": 65536,
+ },
+ "jmann-large-580k": {
+ "url": "https://model-server.zqevans2.workers.dev/jmann-large-580k.ckpt",
+ "sample_rate": 48000,
+ "sample_size": 131072,
+ },
+ "maestro-uncond-150k": {
+ "url": "https://model-server.zqevans2.workers.dev/maestro-uncond-150k.ckpt",
+ "sample_rate": 16000,
+ "sample_size": 65536,
+ },
+ "unlocked-uncond-250k": {
+ "url": "https://model-server.zqevans2.workers.dev/unlocked-uncond-250k.ckpt",
+ "sample_rate": 16000,
+ "sample_size": 65536,
+ },
+ "honk-140k": {
+ "url": "https://model-server.zqevans2.workers.dev/honk-140k.ckpt",
+ "sample_rate": 16000,
+ "sample_size": 65536,
+ },
+}
+
+
+def alpha_sigma_to_t(alpha, sigma):
+ """Returns a timestep, given the scaling factors for the clean image and for
+ the noise."""
+ return torch.atan2(sigma, alpha) / math.pi * 2
+
+
+def get_crash_schedule(t):
+ sigma = torch.sin(t * math.pi / 2) ** 2
+ alpha = (1 - sigma**2) ** 0.5
+ return alpha_sigma_to_t(alpha, sigma)
+
+
+class Object(object):
+ pass
+
+
+class DiffusionUncond(nn.Module):
+ def __init__(self, global_args):
+ super().__init__()
+
+ self.diffusion = DiffusionAttnUnet1D(global_args, n_attn_layers=4)
+ self.diffusion_ema = deepcopy(self.diffusion)
+ self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
+
+
+def download(model_name):
+ url = MODELS_MAP[model_name]["url"]
+ os.system(f"wget {url} ./")
+
+ return f"./{model_name}.ckpt"
+
+
+DOWN_NUM_TO_LAYER = {
+ "1": "resnets.0",
+ "2": "attentions.0",
+ "3": "resnets.1",
+ "4": "attentions.1",
+ "5": "resnets.2",
+ "6": "attentions.2",
+}
+UP_NUM_TO_LAYER = {
+ "8": "resnets.0",
+ "9": "attentions.0",
+ "10": "resnets.1",
+ "11": "attentions.1",
+ "12": "resnets.2",
+ "13": "attentions.2",
+}
+MID_NUM_TO_LAYER = {
+ "1": "resnets.0",
+ "2": "attentions.0",
+ "3": "resnets.1",
+ "4": "attentions.1",
+ "5": "resnets.2",
+ "6": "attentions.2",
+ "8": "resnets.3",
+ "9": "attentions.3",
+ "10": "resnets.4",
+ "11": "attentions.4",
+ "12": "resnets.5",
+ "13": "attentions.5",
+}
+DEPTH_0_TO_LAYER = {
+ "0": "resnets.0",
+ "1": "resnets.1",
+ "2": "resnets.2",
+ "4": "resnets.0",
+ "5": "resnets.1",
+ "6": "resnets.2",
+}
+
+RES_CONV_MAP = {
+ "skip": "conv_skip",
+ "main.0": "conv_1",
+ "main.1": "group_norm_1",
+ "main.3": "conv_2",
+ "main.4": "group_norm_2",
+}
+
+ATTN_MAP = {
+ "norm": "group_norm",
+ "qkv_proj": ["query", "key", "value"],
+ "out_proj": ["proj_attn"],
+}
+
+
+def convert_resconv_naming(name):
+ if name.startswith("skip"):
+ return name.replace("skip", RES_CONV_MAP["skip"])
+
+ # name has to be of format main.{digit}
+ if not name.startswith("main."):
+ raise ValueError(f"ResConvBlock error with {name}")
+
+ return name.replace(name[:6], RES_CONV_MAP[name[:6]])
+
+
+def convert_attn_naming(name):
+ for key, value in ATTN_MAP.items():
+ if name.startswith(key) and not isinstance(value, list):
+ return name.replace(key, value)
+ elif name.startswith(key):
+ return [name.replace(key, v) for v in value]
+ raise ValueError(f"Attn error with {name}")
+
+
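+# rename() strips the repeated "main.7." wrappers of the original state dict, counting them as the
+# nesting depth, and uses that depth to decide whether a parameter belongs to a down block, the mid
+# block, or an up block, resolving the layer index with the *_NUM_TO_LAYER maps above.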
+def rename(input_string, max_depth=13):
+ string = input_string
+
+ if string.split(".")[0] == "timestep_embed":
+ return string.replace("timestep_embed", "time_proj")
+
+ depth = 0
+ if string.startswith("net.3."):
+ depth += 1
+ string = string[6:]
+ elif string.startswith("net."):
+ string = string[4:]
+
+ while string.startswith("main.7."):
+ depth += 1
+ string = string[7:]
+
+ if string.startswith("main."):
+ string = string[5:]
+
+ # mid block
+ if string[:2].isdigit():
+ layer_num = string[:2]
+ string_left = string[2:]
+ else:
+ layer_num = string[0]
+ string_left = string[1:]
+
+ if depth == max_depth:
+ new_layer = MID_NUM_TO_LAYER[layer_num]
+ prefix = "mid_block"
+ elif depth > 0 and int(layer_num) < 7:
+ new_layer = DOWN_NUM_TO_LAYER[layer_num]
+ prefix = f"down_blocks.{depth}"
+ elif depth > 0 and int(layer_num) > 7:
+ new_layer = UP_NUM_TO_LAYER[layer_num]
+ prefix = f"up_blocks.{max_depth - depth - 1}"
+ elif depth == 0:
+ new_layer = DEPTH_0_TO_LAYER[layer_num]
+ prefix = f"up_blocks.{max_depth - 1}" if int(layer_num) > 3 else "down_blocks.0"
+
+ if not string_left.startswith("."):
+ raise ValueError(f"Naming error with {input_string} and string_left: {string_left}.")
+
+ string_left = string_left[1:]
+
+ if "resnets" in new_layer:
+ string_left = convert_resconv_naming(string_left)
+ elif "attentions" in new_layer:
+ new_string_left = convert_attn_naming(string_left)
+ string_left = new_string_left
+
+ if not isinstance(string_left, list):
+ new_string = prefix + "." + new_layer + "." + string_left
+ else:
+ new_string = [prefix + "." + new_layer + "." + s for s in string_left]
+ return new_string
+
+
+def rename_orig_weights(state_dict):
+ new_state_dict = {}
+ for k, v in state_dict.items():
+ if k.endswith("kernel"):
+            # up- and downsample layers don't have trainable weights
+ continue
+
+ new_k = rename(k)
+
+ # check if we need to transform from Conv => Linear for attention
+ if isinstance(new_k, list):
+ new_state_dict = transform_conv_attns(new_state_dict, new_k, v)
+ else:
+ new_state_dict[new_k] = v
+
+ return new_state_dict
+
+
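+# transform_conv_attns turns conv-style attention weights into linear ones by dropping the trailing
+# kernel dimension, and splits a fused qkv tensor into three equal chunks when the renamed key
+# expands to [query, key, value].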
+def transform_conv_attns(new_state_dict, new_k, v):
+ if len(new_k) == 1:
+ if len(v.shape) == 3:
+ # weight
+ new_state_dict[new_k[0]] = v[:, :, 0]
+ else:
+ # bias
+ new_state_dict[new_k[0]] = v
+ else:
+ # qkv matrices
+        tripled_shape = v.shape[0]
+        single_shape = tripled_shape // 3
+ for i in range(3):
+ if len(v.shape) == 3:
+ new_state_dict[new_k[i]] = v[i * single_shape : (i + 1) * single_shape, :, 0]
+ else:
+ new_state_dict[new_k[i]] = v[i * single_shape : (i + 1) * single_shape]
+ return new_state_dict
+
+
+def main(args):
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ model_name = args.model_path.split("/")[-1].split(".")[0]
+ if not os.path.isfile(args.model_path):
+ assert (
+ model_name == args.model_path
+ ), f"Make sure to provide one of the official model names {MODELS_MAP.keys()}"
+ args.model_path = download(model_name)
+
+ sample_rate = MODELS_MAP[model_name]["sample_rate"]
+ sample_size = MODELS_MAP[model_name]["sample_size"]
+
+ config = Object()
+ config.sample_size = sample_size
+ config.sample_rate = sample_rate
+ config.latent_dim = 0
+
+ diffusers_model = UNet1DModel(sample_size=sample_size, sample_rate=sample_rate)
+ diffusers_state_dict = diffusers_model.state_dict()
+
+ orig_model = DiffusionUncond(config)
+ orig_model.load_state_dict(torch.load(args.model_path, map_location=device)["state_dict"])
+ orig_model = orig_model.diffusion_ema.eval()
+ orig_model_state_dict = orig_model.state_dict()
+ renamed_state_dict = rename_orig_weights(orig_model_state_dict)
+
+ renamed_minus_diffusers = set(renamed_state_dict.keys()) - set(diffusers_state_dict.keys())
+ diffusers_minus_renamed = set(diffusers_state_dict.keys()) - set(renamed_state_dict.keys())
+
+ assert len(renamed_minus_diffusers) == 0, f"Problem with {renamed_minus_diffusers}"
+ assert all(k.endswith("kernel") for k in list(diffusers_minus_renamed)), f"Problem with {diffusers_minus_renamed}"
+
+ for key, value in renamed_state_dict.items():
+ assert (
+ diffusers_state_dict[key].squeeze().shape == value.squeeze().shape
+ ), f"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. {value.shape}"
+ if key == "time_proj.weight":
+ value = value.squeeze()
+
+ diffusers_state_dict[key] = value
+
+ diffusers_model.load_state_dict(diffusers_state_dict)
+
+ steps = 100
+ seed = 33
+
+ diffusers_scheduler = IPNDMScheduler(num_train_timesteps=steps)
+
+ generator = torch.manual_seed(seed)
+ noise = torch.randn([1, 2, config.sample_size], generator=generator).to(device)
+
+ t = torch.linspace(1, 0, steps + 1, device=device)[:-1]
+ step_list = get_crash_schedule(t)
+
+ pipe = DanceDiffusionPipeline(unet=diffusers_model, scheduler=diffusers_scheduler)
+
+ generator = torch.manual_seed(33)
+ audio = pipe(num_inference_steps=steps, generator=generator).audios
+
+ generated = sampling.iplms_sample(orig_model, noise, step_list, {})
+ generated = generated.clamp(-1, 1)
+
+ diff_sum = (generated - audio).abs().sum()
+ diff_max = (generated - audio).abs().max()
+
+ if args.save:
+ pipe.save_pretrained(args.checkpoint_path)
+
+ print("Diff sum", diff_sum)
+ print("Diff max", diff_max)
+
+ assert diff_max < 1e-3, f"Diff max: {diff_max} is too much :-/"
+
+ print(f"Conversion for {model_name} successful!")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
+ parser.add_argument(
+ "--save", default=True, type=bool, required=False, help="Whether to save the converted model or not."
+ )
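+    # Caveat: with type=bool, argparse treats any non-empty string (e.g. "--save False") as True.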
+ parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
+ args = parser.parse_args()
+
+ main(args)
diff --git a/diffusers/scripts/convert_ddpm_original_checkpoint_to_diffusers.py b/diffusers/scripts/convert_ddpm_original_checkpoint_to_diffusers.py
new file mode 100644
index 0000000000000000000000000000000000000000..46595784b0bac0016b623b7122082275248363e9
--- /dev/null
+++ b/diffusers/scripts/convert_ddpm_original_checkpoint_to_diffusers.py
@@ -0,0 +1,431 @@
+import argparse
+import json
+
+import torch
+
+from diffusers import AutoencoderKL, DDPMPipeline, DDPMScheduler, UNet2DModel, VQModel
+
+
+def shave_segments(path, n_shave_prefix_segments=1):
+ """
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
+ """
+ if n_shave_prefix_segments >= 0:
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
+ else:
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
+
+
+def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item
+ new_item = new_item.replace("block.", "resnets.")
+ new_item = new_item.replace("conv_shorcut", "conv1")
+ new_item = new_item.replace("in_shortcut", "conv_shortcut")
+ new_item = new_item.replace("temb_proj", "time_emb_proj")
+
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
+
+
+def renew_attention_paths(old_list, n_shave_prefix_segments=0, in_mid=False):
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item
+
+ # In `model.mid`, the layer is called `attn`.
+ if not in_mid:
+ new_item = new_item.replace("attn", "attentions")
+ new_item = new_item.replace(".k.", ".key.")
+ new_item = new_item.replace(".v.", ".value.")
+ new_item = new_item.replace(".q.", ".query.")
+
+ new_item = new_item.replace("proj_out", "proj_attn")
+ new_item = new_item.replace("norm", "group_norm")
+
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
+
+
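+# assign_to_checkpoint applies a list of {"old", "new"} path mappings to copy tensors into the new
+# checkpoint, optionally splitting fused attention qkv tensors into query/key/value and rewriting
+# "down."/"up." prefixes along with any additional replacements.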
+def assign_to_checkpoint(
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
+):
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
+
+ if attention_paths_to_split is not None:
+ if config is None:
+ raise ValueError("Please specify the config if setting 'attention_paths_to_split' to 'True'.")
+
+ for path, path_map in attention_paths_to_split.items():
+ old_tensor = old_checkpoint[path]
+ channels = old_tensor.shape[0] // 3
+
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
+
+ num_heads = old_tensor.shape[0] // config.get("num_head_channels", 1) // 3
+
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
+
+ checkpoint[path_map["query"]] = query.reshape(target_shape).squeeze()
+ checkpoint[path_map["key"]] = key.reshape(target_shape).squeeze()
+ checkpoint[path_map["value"]] = value.reshape(target_shape).squeeze()
+
+ for path in paths:
+ new_path = path["new"]
+
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
+ continue
+
+ new_path = new_path.replace("down.", "down_blocks.")
+ new_path = new_path.replace("up.", "up_blocks.")
+
+ if additional_replacements is not None:
+ for replacement in additional_replacements:
+ new_path = new_path.replace(replacement["old"], replacement["new"])
+
+ if "attentions" in new_path:
+ checkpoint[new_path] = old_checkpoint[path["old"]].squeeze()
+ else:
+ checkpoint[new_path] = old_checkpoint[path["old"]]
+
+
+def convert_ddpm_checkpoint(checkpoint, config):
+ """
+ Takes a state dict and a config, and returns a converted checkpoint.
+ """
+ new_checkpoint = {}
+
+ new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["temb.dense.0.weight"]
+ new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["temb.dense.0.bias"]
+ new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["temb.dense.1.weight"]
+ new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["temb.dense.1.bias"]
+
+ new_checkpoint["conv_norm_out.weight"] = checkpoint["norm_out.weight"]
+ new_checkpoint["conv_norm_out.bias"] = checkpoint["norm_out.bias"]
+
+ new_checkpoint["conv_in.weight"] = checkpoint["conv_in.weight"]
+ new_checkpoint["conv_in.bias"] = checkpoint["conv_in.bias"]
+ new_checkpoint["conv_out.weight"] = checkpoint["conv_out.weight"]
+ new_checkpoint["conv_out.bias"] = checkpoint["conv_out.bias"]
+
+ num_down_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "down" in layer})
+ down_blocks = {
+ layer_id: [key for key in checkpoint if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
+ }
+
+ num_up_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "up" in layer})
+ up_blocks = {layer_id: [key for key in checkpoint if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
+
+ for i in range(num_down_blocks):
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
+
+ if any("downsample" in layer for layer in down_blocks[i]):
+ new_checkpoint[f"down_blocks.{i}.downsamplers.0.conv.weight"] = checkpoint[
+ f"down.{i}.downsample.op.weight"
+ ]
+ new_checkpoint[f"down_blocks.{i}.downsamplers.0.conv.bias"] = checkpoint[f"down.{i}.downsample.op.bias"]
+ # new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
+ # new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.bias'] = checkpoint[f'down.{i}.downsample.conv.bias']
+
+ if any("block" in layer for layer in down_blocks[i]):
+ num_blocks = len(
+ {".".join(shave_segments(layer, 2).split(".")[:2]) for layer in down_blocks[i] if "block" in layer}
+ )
+ blocks = {
+ layer_id: [key for key in down_blocks[i] if f"block.{layer_id}" in key]
+ for layer_id in range(num_blocks)
+ }
+
+ if num_blocks > 0:
+ for j in range(config["layers_per_block"]):
+ paths = renew_resnet_paths(blocks[j])
+ assign_to_checkpoint(paths, new_checkpoint, checkpoint)
+
+ if any("attn" in layer for layer in down_blocks[i]):
+ num_attn = len(
+ {".".join(shave_segments(layer, 2).split(".")[:2]) for layer in down_blocks[i] if "attn" in layer}
+ )
+ attns = {
+ layer_id: [key for key in down_blocks[i] if f"attn.{layer_id}" in key]
+ for layer_id in range(num_blocks)
+ }
+
+ if num_attn > 0:
+ for j in range(config["layers_per_block"]):
+ paths = renew_attention_paths(attns[j])
+ assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)
+
+ mid_block_1_layers = [key for key in checkpoint if "mid.block_1" in key]
+ mid_block_2_layers = [key for key in checkpoint if "mid.block_2" in key]
+ mid_attn_1_layers = [key for key in checkpoint if "mid.attn_1" in key]
+
+    # Mid block ("mid_new_2" keys are renamed to "mid_block" at the end of this function)
+ paths = renew_resnet_paths(mid_block_1_layers)
+ assign_to_checkpoint(
+ paths,
+ new_checkpoint,
+ checkpoint,
+ additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_1", "new": "resnets.0"}],
+ )
+
+ paths = renew_resnet_paths(mid_block_2_layers)
+ assign_to_checkpoint(
+ paths,
+ new_checkpoint,
+ checkpoint,
+ additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_2", "new": "resnets.1"}],
+ )
+
+ paths = renew_attention_paths(mid_attn_1_layers, in_mid=True)
+ assign_to_checkpoint(
+ paths,
+ new_checkpoint,
+ checkpoint,
+ additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "attn_1", "new": "attentions.0"}],
+ )
+
+ for i in range(num_up_blocks):
+ block_id = num_up_blocks - 1 - i
+
+ if any("upsample" in layer for layer in up_blocks[i]):
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[
+ f"up.{i}.upsample.conv.weight"
+ ]
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[f"up.{i}.upsample.conv.bias"]
+
+ if any("block" in layer for layer in up_blocks[i]):
+ num_blocks = len(
+ {".".join(shave_segments(layer, 2).split(".")[:2]) for layer in up_blocks[i] if "block" in layer}
+ )
+ blocks = {
+ layer_id: [key for key in up_blocks[i] if f"block.{layer_id}" in key] for layer_id in range(num_blocks)
+ }
+
+ if num_blocks > 0:
+ for j in range(config["layers_per_block"] + 1):
+ replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
+ paths = renew_resnet_paths(blocks[j])
+ assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
+
+ if any("attn" in layer for layer in up_blocks[i]):
+ num_attn = len(
+ {".".join(shave_segments(layer, 2).split(".")[:2]) for layer in up_blocks[i] if "attn" in layer}
+ )
+ attns = {
+ layer_id: [key for key in up_blocks[i] if f"attn.{layer_id}" in key] for layer_id in range(num_blocks)
+ }
+
+ if num_attn > 0:
+ for j in range(config["layers_per_block"] + 1):
+ replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
+ paths = renew_attention_paths(attns[j])
+ assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
+
+ new_checkpoint = {k.replace("mid_new_2", "mid_block"): v for k, v in new_checkpoint.items()}
+ return new_checkpoint
+
+
+def convert_vq_autoenc_checkpoint(checkpoint, config):
+ """
+ Takes a state dict and a config, and returns a converted checkpoint.
+ """
+ new_checkpoint = {}
+
+ new_checkpoint["encoder.conv_norm_out.weight"] = checkpoint["encoder.norm_out.weight"]
+ new_checkpoint["encoder.conv_norm_out.bias"] = checkpoint["encoder.norm_out.bias"]
+
+ new_checkpoint["encoder.conv_in.weight"] = checkpoint["encoder.conv_in.weight"]
+ new_checkpoint["encoder.conv_in.bias"] = checkpoint["encoder.conv_in.bias"]
+ new_checkpoint["encoder.conv_out.weight"] = checkpoint["encoder.conv_out.weight"]
+ new_checkpoint["encoder.conv_out.bias"] = checkpoint["encoder.conv_out.bias"]
+
+ new_checkpoint["decoder.conv_norm_out.weight"] = checkpoint["decoder.norm_out.weight"]
+ new_checkpoint["decoder.conv_norm_out.bias"] = checkpoint["decoder.norm_out.bias"]
+
+ new_checkpoint["decoder.conv_in.weight"] = checkpoint["decoder.conv_in.weight"]
+ new_checkpoint["decoder.conv_in.bias"] = checkpoint["decoder.conv_in.bias"]
+ new_checkpoint["decoder.conv_out.weight"] = checkpoint["decoder.conv_out.weight"]
+ new_checkpoint["decoder.conv_out.bias"] = checkpoint["decoder.conv_out.bias"]
+
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in checkpoint if "down" in layer})
+ down_blocks = {
+ layer_id: [key for key in checkpoint if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
+ }
+
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in checkpoint if "up" in layer})
+ up_blocks = {layer_id: [key for key in checkpoint if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
+
+ for i in range(num_down_blocks):
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
+
+ if any("downsample" in layer for layer in down_blocks[i]):
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = checkpoint[
+ f"encoder.down.{i}.downsample.conv.weight"
+ ]
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = checkpoint[
+ f"encoder.down.{i}.downsample.conv.bias"
+ ]
+
+ if any("block" in layer for layer in down_blocks[i]):
+ num_blocks = len(
+ {".".join(shave_segments(layer, 3).split(".")[:3]) for layer in down_blocks[i] if "block" in layer}
+ )
+ blocks = {
+ layer_id: [key for key in down_blocks[i] if f"block.{layer_id}" in key]
+ for layer_id in range(num_blocks)
+ }
+
+ if num_blocks > 0:
+ for j in range(config["layers_per_block"]):
+ paths = renew_resnet_paths(blocks[j])
+ assign_to_checkpoint(paths, new_checkpoint, checkpoint)
+
+ if any("attn" in layer for layer in down_blocks[i]):
+ num_attn = len(
+ {".".join(shave_segments(layer, 3).split(".")[:3]) for layer in down_blocks[i] if "attn" in layer}
+ )
+ attns = {
+ layer_id: [key for key in down_blocks[i] if f"attn.{layer_id}" in key]
+ for layer_id in range(num_blocks)
+ }
+
+ if num_attn > 0:
+ for j in range(config["layers_per_block"]):
+ paths = renew_attention_paths(attns[j])
+ assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)
+
+ mid_block_1_layers = [key for key in checkpoint if "mid.block_1" in key]
+ mid_block_2_layers = [key for key in checkpoint if "mid.block_2" in key]
+ mid_attn_1_layers = [key for key in checkpoint if "mid.attn_1" in key]
+
+    # Mid block ("mid_new_2" keys are renamed to "mid_block" at the end of this function)
+ paths = renew_resnet_paths(mid_block_1_layers)
+ assign_to_checkpoint(
+ paths,
+ new_checkpoint,
+ checkpoint,
+ additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_1", "new": "resnets.0"}],
+ )
+
+ paths = renew_resnet_paths(mid_block_2_layers)
+ assign_to_checkpoint(
+ paths,
+ new_checkpoint,
+ checkpoint,
+ additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_2", "new": "resnets.1"}],
+ )
+
+ paths = renew_attention_paths(mid_attn_1_layers, in_mid=True)
+ assign_to_checkpoint(
+ paths,
+ new_checkpoint,
+ checkpoint,
+ additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "attn_1", "new": "attentions.0"}],
+ )
+
+ for i in range(num_up_blocks):
+ block_id = num_up_blocks - 1 - i
+
+ if any("upsample" in layer for layer in up_blocks[i]):
+ new_checkpoint[f"decoder.up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[
+ f"decoder.up.{i}.upsample.conv.weight"
+ ]
+ new_checkpoint[f"decoder.up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[
+ f"decoder.up.{i}.upsample.conv.bias"
+ ]
+
+ if any("block" in layer for layer in up_blocks[i]):
+ num_blocks = len(
+ {".".join(shave_segments(layer, 3).split(".")[:3]) for layer in up_blocks[i] if "block" in layer}
+ )
+ blocks = {
+ layer_id: [key for key in up_blocks[i] if f"block.{layer_id}" in key] for layer_id in range(num_blocks)
+ }
+
+ if num_blocks > 0:
+ for j in range(config["layers_per_block"] + 1):
+ replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
+ paths = renew_resnet_paths(blocks[j])
+ assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
+
+ if any("attn" in layer for layer in up_blocks[i]):
+ num_attn = len(
+ {".".join(shave_segments(layer, 3).split(".")[:3]) for layer in up_blocks[i] if "attn" in layer}
+ )
+ attns = {
+ layer_id: [key for key in up_blocks[i] if f"attn.{layer_id}" in key] for layer_id in range(num_blocks)
+ }
+
+ if num_attn > 0:
+ for j in range(config["layers_per_block"] + 1):
+ replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
+ paths = renew_attention_paths(attns[j])
+ assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
+
+ new_checkpoint = {k.replace("mid_new_2", "mid_block"): v for k, v in new_checkpoint.items()}
+ new_checkpoint["quant_conv.weight"] = checkpoint["quant_conv.weight"]
+ new_checkpoint["quant_conv.bias"] = checkpoint["quant_conv.bias"]
+ if "quantize.embedding.weight" in checkpoint:
+ new_checkpoint["quantize.embedding.weight"] = checkpoint["quantize.embedding.weight"]
+ new_checkpoint["post_quant_conv.weight"] = checkpoint["post_quant_conv.weight"]
+ new_checkpoint["post_quant_conv.bias"] = checkpoint["post_quant_conv.bias"]
+
+ return new_checkpoint
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
+ )
+
+ parser.add_argument(
+ "--config_file",
+ default=None,
+ type=str,
+ required=True,
+ help="The config json file corresponding to the architecture.",
+ )
+
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
+
+ args = parser.parse_args()
+ checkpoint = torch.load(args.checkpoint_path)
+
+ with open(args.config_file) as f:
+ config = json.loads(f.read())
+
+ # unet case
+ key_prefix_set = {key.split(".")[0] for key in checkpoint.keys()}
+ if "encoder" in key_prefix_set and "decoder" in key_prefix_set:
+ converted_checkpoint = convert_vq_autoenc_checkpoint(checkpoint, config)
+ else:
+ converted_checkpoint = convert_ddpm_checkpoint(checkpoint, config)
+
+ if "ddpm" in config:
+ del config["ddpm"]
+
+ if config["_class_name"] == "VQModel":
+ model = VQModel(**config)
+ model.load_state_dict(converted_checkpoint)
+ model.save_pretrained(args.dump_path)
+ elif config["_class_name"] == "AutoencoderKL":
+ model = AutoencoderKL(**config)
+ model.load_state_dict(converted_checkpoint)
+ model.save_pretrained(args.dump_path)
+ else:
+ model = UNet2DModel(**config)
+ model.load_state_dict(converted_checkpoint)
+
+ scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))
+
+ pipe = DDPMPipeline(unet=model, scheduler=scheduler)
+ pipe.save_pretrained(args.dump_path)
diff --git a/diffusers/scripts/convert_diffusers_to_original_sdxl.py b/diffusers/scripts/convert_diffusers_to_original_sdxl.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f11ef45706898cf4408fdeecbb3b3249aa45d76
--- /dev/null
+++ b/diffusers/scripts/convert_diffusers_to_original_sdxl.py
@@ -0,0 +1,340 @@
+# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
+# *Only* converts the UNet, VAE, and Text Encoder.
+# Does not convert optimizer state or anything else.
+
+import argparse
+import os.path as osp
+import re
+
+import torch
+from safetensors.torch import load_file, save_file
+
+
+# =================#
+# UNet Conversion #
+# =================#
+
+unet_conversion_map = [
+ # (stable-diffusion, HF Diffusers)
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
+ ("input_blocks.0.0.weight", "conv_in.weight"),
+ ("input_blocks.0.0.bias", "conv_in.bias"),
+ ("out.0.weight", "conv_norm_out.weight"),
+ ("out.0.bias", "conv_norm_out.bias"),
+ ("out.2.weight", "conv_out.weight"),
+ ("out.2.bias", "conv_out.bias"),
+ # the following are for sdxl
+ ("label_emb.0.0.weight", "add_embedding.linear_1.weight"),
+ ("label_emb.0.0.bias", "add_embedding.linear_1.bias"),
+ ("label_emb.0.2.weight", "add_embedding.linear_2.weight"),
+ ("label_emb.0.2.bias", "add_embedding.linear_2.bias"),
+]
+
+unet_conversion_map_resnet = [
+ # (stable-diffusion, HF Diffusers)
+ ("in_layers.0", "norm1"),
+ ("in_layers.2", "conv1"),
+ ("out_layers.0", "norm2"),
+ ("out_layers.3", "conv2"),
+ ("emb_layers.1", "time_emb_proj"),
+ ("skip_connection", "conv_shortcut"),
+]
+
+unet_conversion_map_layer = []
+# hardcoded number of downblocks and resnets/attentions...
+# would need smarter logic for other networks.
+for i in range(3):
+ # loop over downblocks/upblocks
+
+ for j in range(2):
+ # loop over resnets/attentions for downblocks
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
+
+ if i > 0:
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
+
+ for j in range(4):
+ # loop over resnets/attentions for upblocks
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
+
+ if i < 2:
+ # no attention layers in up_blocks.0
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
+ sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
+
+ if i < 3:
+ # no downsample in down_blocks.3
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
+
+ # no upsample in up_blocks.3
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+ sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
+unet_conversion_map_layer.append(("output_blocks.2.2.conv.", "output_blocks.2.1.conv."))
+
+hf_mid_atn_prefix = "mid_block.attentions.0."
+sd_mid_atn_prefix = "middle_block.1."
+unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
+for j in range(2):
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
+ sd_mid_res_prefix = f"middle_block.{2*j}."
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
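+# The index arithmetic above mirrors how SD flattens the UNet: input_blocks holds the initial conv
+# plus, per level, its resnets and a downsampler (hence 3*i + j + 1 and 3*(i+1)), while
+# output_blocks holds three resnets per level with the upsampler attached to the last entry.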
+
+def convert_unet_state_dict(unet_state_dict):
+ # buyer beware: this is a *brittle* function,
+ # and correct output requires that all of these pieces interact in
+ # the exact order in which I have arranged them.
+ mapping = {k: k for k in unet_state_dict.keys()}
+ for sd_name, hf_name in unet_conversion_map:
+ mapping[hf_name] = sd_name
+ for k, v in mapping.items():
+ if "resnets" in k:
+ for sd_part, hf_part in unet_conversion_map_resnet:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ for k, v in mapping.items():
+ for sd_part, hf_part in unet_conversion_map_layer:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ new_state_dict = {sd_name: unet_state_dict[hf_name] for hf_name, sd_name in mapping.items()}
+ return new_state_dict
+
+
+# ================#
+# VAE Conversion #
+# ================#
+
+vae_conversion_map = [
+ # (stable-diffusion, HF Diffusers)
+ ("nin_shortcut", "conv_shortcut"),
+ ("norm_out", "conv_norm_out"),
+ ("mid.attn_1.", "mid_block.attentions.0."),
+]
+
+for i in range(4):
+ # down_blocks have two resnets
+ for j in range(2):
+ hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
+ sd_down_prefix = f"encoder.down.{i}.block.{j}."
+ vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
+
+ if i < 3:
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
+ sd_downsample_prefix = f"down.{i}.downsample."
+ vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
+
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+ sd_upsample_prefix = f"up.{3-i}.upsample."
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
+
+ # up_blocks have three resnets
+ # also, up blocks in hf are numbered in reverse from sd
+ for j in range(3):
+ hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
+ sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
+ vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
+
+# this part accounts for mid blocks in both the encoder and the decoder
+for i in range(2):
+ hf_mid_res_prefix = f"mid_block.resnets.{i}."
+ sd_mid_res_prefix = f"mid.block_{i+1}."
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+
+vae_conversion_map_attn = [
+ # (stable-diffusion, HF Diffusers)
+ ("norm.", "group_norm."),
+ # the following are for SDXL
+ ("q.", "to_q."),
+ ("k.", "to_k."),
+ ("v.", "to_v."),
+ ("proj_out.", "to_out.0."),
+]
+
+
+def reshape_weight_for_sd(w):
+ # convert HF linear weights to SD conv2d weights
+ return w.reshape(*w.shape, 1, 1)
+
+
+def convert_vae_state_dict(vae_state_dict):
+ mapping = {k: k for k in vae_state_dict.keys()}
+ for k, v in mapping.items():
+ for sd_part, hf_part in vae_conversion_map:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ for k, v in mapping.items():
+ if "attentions" in k:
+ for sd_part, hf_part in vae_conversion_map_attn:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
+ weights_to_convert = ["q", "k", "v", "proj_out"]
+ for k, v in new_state_dict.items():
+ for weight_name in weights_to_convert:
+ if f"mid.attn_1.{weight_name}.weight" in k:
+ print(f"Reshaping {k} for SD format")
+ new_state_dict[k] = reshape_weight_for_sd(v)
+ return new_state_dict
+
+
+# =========================#
+# Text Encoder Conversion #
+# =========================#
+
+
+textenc_conversion_lst = [
+ # (stable-diffusion, HF Diffusers)
+ ("transformer.resblocks.", "text_model.encoder.layers."),
+ ("ln_1", "layer_norm1"),
+ ("ln_2", "layer_norm2"),
+ (".c_fc.", ".fc1."),
+ (".c_proj.", ".fc2."),
+ (".attn", ".self_attn"),
+ ("ln_final.", "text_model.final_layer_norm."),
+ ("token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
+ ("positional_embedding", "text_model.embeddings.position_embedding.weight"),
+]
+protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
+textenc_pattern = re.compile("|".join(protected.keys()))
+
+# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
+code2idx = {"q": 0, "k": 1, "v": 2}
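+# q, k and v are concatenated in this order when rebuilding the fused
+# in_proj_weight / in_proj_bias used by torch.nn.MultiheadAttention below.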
+
+
+def convert_openclip_text_enc_state_dict(text_enc_dict):
+ new_state_dict = {}
+ capture_qkv_weight = {}
+ capture_qkv_bias = {}
+ for k, v in text_enc_dict.items():
+ if (
+ k.endswith(".self_attn.q_proj.weight")
+ or k.endswith(".self_attn.k_proj.weight")
+ or k.endswith(".self_attn.v_proj.weight")
+ ):
+ k_pre = k[: -len(".q_proj.weight")]
+ k_code = k[-len("q_proj.weight")]
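+            # k_code is the single character 'q', 'k' or 'v' sitting at this offset in the key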
+ if k_pre not in capture_qkv_weight:
+ capture_qkv_weight[k_pre] = [None, None, None]
+ capture_qkv_weight[k_pre][code2idx[k_code]] = v
+ continue
+
+ if (
+ k.endswith(".self_attn.q_proj.bias")
+ or k.endswith(".self_attn.k_proj.bias")
+ or k.endswith(".self_attn.v_proj.bias")
+ ):
+ k_pre = k[: -len(".q_proj.bias")]
+ k_code = k[-len("q_proj.bias")]
+ if k_pre not in capture_qkv_bias:
+ capture_qkv_bias[k_pre] = [None, None, None]
+ capture_qkv_bias[k_pre][code2idx[k_code]] = v
+ continue
+
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
+ new_state_dict[relabelled_key] = v
+
+ for k_pre, tensors in capture_qkv_weight.items():
+ if None in tensors:
+ raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
+ new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
+
+ for k_pre, tensors in capture_qkv_bias.items():
+ if None in tensors:
+ raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
+ new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
+
+ return new_state_dict
+
+
+def convert_openai_text_enc_state_dict(text_enc_dict):
+ return text_enc_dict
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
+ parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
+ parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
+ parser.add_argument(
+        "--use_safetensors", action="store_true", help="Save weights in safetensors format; default is ckpt."
+ )
+
+ args = parser.parse_args()
+
+ assert args.model_path is not None, "Must provide a model path!"
+
+ assert args.checkpoint_path is not None, "Must provide a checkpoint path!"
+
+ # Path for safetensors
+ unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.safetensors")
+ vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.safetensors")
+ text_enc_path = osp.join(args.model_path, "text_encoder", "model.safetensors")
+ text_enc_2_path = osp.join(args.model_path, "text_encoder_2", "model.safetensors")
+
+    # Load models from safetensors if they exist; otherwise fall back to the PyTorch .bin files
+ if osp.exists(unet_path):
+ unet_state_dict = load_file(unet_path, device="cpu")
+ else:
+ unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
+ unet_state_dict = torch.load(unet_path, map_location="cpu")
+
+ if osp.exists(vae_path):
+ vae_state_dict = load_file(vae_path, device="cpu")
+ else:
+ vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
+ vae_state_dict = torch.load(vae_path, map_location="cpu")
+
+ if osp.exists(text_enc_path):
+ text_enc_dict = load_file(text_enc_path, device="cpu")
+ else:
+ text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")
+ text_enc_dict = torch.load(text_enc_path, map_location="cpu")
+
+ if osp.exists(text_enc_2_path):
+ text_enc_2_dict = load_file(text_enc_2_path, device="cpu")
+ else:
+ text_enc_2_path = osp.join(args.model_path, "text_encoder_2", "pytorch_model.bin")
+ text_enc_2_dict = torch.load(text_enc_2_path, map_location="cpu")
+
+ # Convert the UNet model
+ unet_state_dict = convert_unet_state_dict(unet_state_dict)
+ unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
+
+ # Convert the VAE model
+ vae_state_dict = convert_vae_state_dict(vae_state_dict)
+ vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
+
+ text_enc_dict = convert_openai_text_enc_state_dict(text_enc_dict)
+ text_enc_dict = {"conditioner.embedders.0.transformer." + k: v for k, v in text_enc_dict.items()}
+
+ text_enc_2_dict = convert_openclip_text_enc_state_dict(text_enc_2_dict)
+ text_enc_2_dict = {"conditioner.embedders.1.model." + k: v for k, v in text_enc_2_dict.items()}
+
+ # Put together new checkpoint
+ state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict, **text_enc_2_dict}
+
+ if args.half:
+ state_dict = {k: v.half() for k, v in state_dict.items()}
+
+ if args.use_safetensors:
+ save_file(state_dict, args.checkpoint_path)
+ else:
+ state_dict = {"state_dict": state_dict}
+ torch.save(state_dict, args.checkpoint_path)
diff --git a/diffusers/scripts/convert_diffusers_to_original_stable_diffusion.py b/diffusers/scripts/convert_diffusers_to_original_stable_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..9da45211551e32acf34c883c1d6c5218a7bd6dd7
--- /dev/null
+++ b/diffusers/scripts/convert_diffusers_to_original_stable_diffusion.py
@@ -0,0 +1,333 @@
+# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
+# *Only* converts the UNet, VAE, and Text Encoder.
+# Does not convert optimizer state or any other thing.
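+#
+# Example invocation (hypothetical paths):
+#   python convert_diffusers_to_original_stable_diffusion.py \
+#       --model_path ./my-diffusers-pipeline --checkpoint_path ./model.safetensors --use_safetensors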
+
+import argparse
+import os.path as osp
+import re
+
+import torch
+from safetensors.torch import load_file, save_file
+
+
+# =================#
+# UNet Conversion #
+# =================#
+
+unet_conversion_map = [
+ # (stable-diffusion, HF Diffusers)
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
+ ("input_blocks.0.0.weight", "conv_in.weight"),
+ ("input_blocks.0.0.bias", "conv_in.bias"),
+ ("out.0.weight", "conv_norm_out.weight"),
+ ("out.0.bias", "conv_norm_out.bias"),
+ ("out.2.weight", "conv_out.weight"),
+ ("out.2.bias", "conv_out.bias"),
+]
+
+unet_conversion_map_resnet = [
+ # (stable-diffusion, HF Diffusers)
+ ("in_layers.0", "norm1"),
+ ("in_layers.2", "conv1"),
+ ("out_layers.0", "norm2"),
+ ("out_layers.3", "conv2"),
+ ("emb_layers.1", "time_emb_proj"),
+ ("skip_connection", "conv_shortcut"),
+]
+
+unet_conversion_map_layer = []
+# hardcoded number of downblocks and resnets/attentions...
+# would need smarter logic for other networks.
+for i in range(4):
+ # loop over downblocks/upblocks
+
+ for j in range(2):
+ # loop over resnets/attentions for downblocks
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
+
+ if i < 3:
+ # no attention layers in down_blocks.3
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
+
+ for j in range(3):
+ # loop over resnets/attentions for upblocks
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
+
+ if i > 0:
+ # no attention layers in up_blocks.0
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
+ sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
+
+ if i < 3:
+ # no downsample in down_blocks.3
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
+
+ # no upsample in up_blocks.3
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+ sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
+
+hf_mid_atn_prefix = "mid_block.attentions.0."
+sd_mid_atn_prefix = "middle_block.1."
+unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
+
+for j in range(2):
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
+ sd_mid_res_prefix = f"middle_block.{2*j}."
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
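+# As an example of the prefix pairs produced above (i = 0, j = 0):
+#   ("input_blocks.1.0.", "down_blocks.0.resnets.0.") and ("middle_block.0.", "mid_block.resnets.0.")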
+
+
+def convert_unet_state_dict(unet_state_dict):
+ # buyer beware: this is a *brittle* function,
+ # and correct output requires that all of these pieces interact in
+ # the exact order in which I have arranged them.
+ mapping = {k: k for k in unet_state_dict.keys()}
+ for sd_name, hf_name in unet_conversion_map:
+ mapping[hf_name] = sd_name
+ for k, v in mapping.items():
+ if "resnets" in k:
+ for sd_part, hf_part in unet_conversion_map_resnet:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ for k, v in mapping.items():
+ for sd_part, hf_part in unet_conversion_map_layer:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
+ return new_state_dict
+
+
+# ================#
+# VAE Conversion #
+# ================#
+
+vae_conversion_map = [
+ # (stable-diffusion, HF Diffusers)
+ ("nin_shortcut", "conv_shortcut"),
+ ("norm_out", "conv_norm_out"),
+ ("mid.attn_1.", "mid_block.attentions.0."),
+]
+
+for i in range(4):
+ # down_blocks have two resnets
+ for j in range(2):
+ hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
+ sd_down_prefix = f"encoder.down.{i}.block.{j}."
+ vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
+
+ if i < 3:
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
+ sd_downsample_prefix = f"down.{i}.downsample."
+ vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
+
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+ sd_upsample_prefix = f"up.{3-i}.upsample."
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
+
+ # up_blocks have three resnets
+ # also, up blocks in hf are numbered in reverse from sd
+ for j in range(3):
+ hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
+ sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
+ vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
+
+# this part accounts for mid blocks in both the encoder and the decoder
+for i in range(2):
+ hf_mid_res_prefix = f"mid_block.resnets.{i}."
+ sd_mid_res_prefix = f"mid.block_{i+1}."
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+
+vae_conversion_map_attn = [
+ # (stable-diffusion, HF Diffusers)
+ ("norm.", "group_norm."),
+ ("q.", "query."),
+ ("k.", "key."),
+ ("v.", "value."),
+ ("proj_out.", "proj_attn."),
+]
+
+
+def reshape_weight_for_sd(w):
+ # convert HF linear weights to SD conv2d weights
+ return w.reshape(*w.shape, 1, 1)
+
+
+def convert_vae_state_dict(vae_state_dict):
+ mapping = {k: k for k in vae_state_dict.keys()}
+ for k, v in mapping.items():
+ for sd_part, hf_part in vae_conversion_map:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ for k, v in mapping.items():
+ if "attentions" in k:
+ for sd_part, hf_part in vae_conversion_map_attn:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
+ weights_to_convert = ["q", "k", "v", "proj_out"]
+ for k, v in new_state_dict.items():
+ for weight_name in weights_to_convert:
+ if f"mid.attn_1.{weight_name}.weight" in k:
+ print(f"Reshaping {k} for SD format")
+ new_state_dict[k] = reshape_weight_for_sd(v)
+ return new_state_dict
+
+
+# =========================#
+# Text Encoder Conversion #
+# =========================#
+
+
+textenc_conversion_lst = [
+ # (stable-diffusion, HF Diffusers)
+ ("resblocks.", "text_model.encoder.layers."),
+ ("ln_1", "layer_norm1"),
+ ("ln_2", "layer_norm2"),
+ (".c_fc.", ".fc1."),
+ (".c_proj.", ".fc2."),
+ (".attn", ".self_attn"),
+ ("ln_final.", "transformer.text_model.final_layer_norm."),
+ ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
+ ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
+]
+protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
+textenc_pattern = re.compile("|".join(protected.keys()))
+
+# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
+code2idx = {"q": 0, "k": 1, "v": 2}
+
+
+def convert_text_enc_state_dict_v20(text_enc_dict):
+ new_state_dict = {}
+ capture_qkv_weight = {}
+ capture_qkv_bias = {}
+ for k, v in text_enc_dict.items():
+ if (
+ k.endswith(".self_attn.q_proj.weight")
+ or k.endswith(".self_attn.k_proj.weight")
+ or k.endswith(".self_attn.v_proj.weight")
+ ):
+ k_pre = k[: -len(".q_proj.weight")]
+ k_code = k[-len("q_proj.weight")]
+ if k_pre not in capture_qkv_weight:
+ capture_qkv_weight[k_pre] = [None, None, None]
+ capture_qkv_weight[k_pre][code2idx[k_code]] = v
+ continue
+
+ if (
+ k.endswith(".self_attn.q_proj.bias")
+ or k.endswith(".self_attn.k_proj.bias")
+ or k.endswith(".self_attn.v_proj.bias")
+ ):
+ k_pre = k[: -len(".q_proj.bias")]
+ k_code = k[-len("q_proj.bias")]
+ if k_pre not in capture_qkv_bias:
+ capture_qkv_bias[k_pre] = [None, None, None]
+ capture_qkv_bias[k_pre][code2idx[k_code]] = v
+ continue
+
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
+ new_state_dict[relabelled_key] = v
+
+ for k_pre, tensors in capture_qkv_weight.items():
+ if None in tensors:
+ raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
+ new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
+
+ for k_pre, tensors in capture_qkv_bias.items():
+ if None in tensors:
+ raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
+ new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
+
+ return new_state_dict
+
+
+def convert_text_enc_state_dict(text_enc_dict):
+ return text_enc_dict
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
+ parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
+ parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
+ parser.add_argument(
+        "--use_safetensors", action="store_true", help="Save weights in safetensors format; default is ckpt."
+ )
+
+ args = parser.parse_args()
+
+ assert args.model_path is not None, "Must provide a model path!"
+
+ assert args.checkpoint_path is not None, "Must provide a checkpoint path!"
+
+ # Path for safetensors
+ unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.safetensors")
+ vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.safetensors")
+ text_enc_path = osp.join(args.model_path, "text_encoder", "model.safetensors")
+
+    # Load models from safetensors if they exist; otherwise fall back to the PyTorch .bin files
+ if osp.exists(unet_path):
+ unet_state_dict = load_file(unet_path, device="cpu")
+ else:
+ unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
+ unet_state_dict = torch.load(unet_path, map_location="cpu")
+
+ if osp.exists(vae_path):
+ vae_state_dict = load_file(vae_path, device="cpu")
+ else:
+ vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
+ vae_state_dict = torch.load(vae_path, map_location="cpu")
+
+ if osp.exists(text_enc_path):
+ text_enc_dict = load_file(text_enc_path, device="cpu")
+ else:
+ text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")
+ text_enc_dict = torch.load(text_enc_path, map_location="cpu")
+
+ # Convert the UNet model
+ unet_state_dict = convert_unet_state_dict(unet_state_dict)
+ unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
+
+ # Convert the VAE model
+ vae_state_dict = convert_vae_state_dict(vae_state_dict)
+ vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
+
+ # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
+ is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
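+    # (the SD 1.x CLIP ViT-L text encoder has 12 layers, the SD 2.x OpenCLIP one has 24,
+    # so an index of 22 can only appear in a 2.x checkpoint)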
+
+ if is_v20_model:
+ # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
+ text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
+ text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict)
+ text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
+ else:
+ text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
+ text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
+
+ # Put together new checkpoint
+ state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
+ if args.half:
+ state_dict = {k: v.half() for k, v in state_dict.items()}
+
+ if args.use_safetensors:
+ save_file(state_dict, args.checkpoint_path)
+ else:
+ state_dict = {"state_dict": state_dict}
+ torch.save(state_dict, args.checkpoint_path)
diff --git a/diffusers/scripts/convert_dit_to_diffusers.py b/diffusers/scripts/convert_dit_to_diffusers.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc127f69555c260f594e70444b1540faa196e3fb
--- /dev/null
+++ b/diffusers/scripts/convert_dit_to_diffusers.py
@@ -0,0 +1,162 @@
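+# Script for converting a pretrained DiT (DiT-XL/2) checkpoint into a diffusers DiTPipeline.
+#
+# Example invocation (hypothetical output path):
+#   python convert_dit_to_diffusers.py --image_size 256 --checkpoint_path ./DiT-XL-2-256-diffusers
+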
+import argparse
+import os
+
+import torch
+from torchvision.datasets.utils import download_url
+
+from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, Transformer2DModel
+
+
+pretrained_models = {512: "DiT-XL-2-512x512.pt", 256: "DiT-XL-2-256x256.pt"}
+
+
+def download_model(model_name):
+ """
+ Downloads a pre-trained DiT model from the web.
+ """
+ local_path = f"pretrained_models/{model_name}"
+ if not os.path.isfile(local_path):
+ os.makedirs("pretrained_models", exist_ok=True)
+ web_path = f"https://dl.fbaipublicfiles.com/DiT/models/{model_name}"
+ download_url(web_path, "pretrained_models")
+ model = torch.load(local_path, map_location=lambda storage, loc: storage)
+ return model
+
+
+def main(args):
+ state_dict = download_model(pretrained_models[args.image_size])
+
+ state_dict["pos_embed.proj.weight"] = state_dict["x_embedder.proj.weight"]
+ state_dict["pos_embed.proj.bias"] = state_dict["x_embedder.proj.bias"]
+ state_dict.pop("x_embedder.proj.weight")
+ state_dict.pop("x_embedder.proj.bias")
+
+ for depth in range(28):
+ state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_1.weight"] = state_dict[
+ "t_embedder.mlp.0.weight"
+ ]
+ state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_1.bias"] = state_dict[
+ "t_embedder.mlp.0.bias"
+ ]
+ state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_2.weight"] = state_dict[
+ "t_embedder.mlp.2.weight"
+ ]
+ state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_2.bias"] = state_dict[
+ "t_embedder.mlp.2.bias"
+ ]
+ state_dict[f"transformer_blocks.{depth}.norm1.emb.class_embedder.embedding_table.weight"] = state_dict[
+ "y_embedder.embedding_table.weight"
+ ]
+
+ state_dict[f"transformer_blocks.{depth}.norm1.linear.weight"] = state_dict[
+ f"blocks.{depth}.adaLN_modulation.1.weight"
+ ]
+ state_dict[f"transformer_blocks.{depth}.norm1.linear.bias"] = state_dict[
+ f"blocks.{depth}.adaLN_modulation.1.bias"
+ ]
+
+ q, k, v = torch.chunk(state_dict[f"blocks.{depth}.attn.qkv.weight"], 3, dim=0)
+ q_bias, k_bias, v_bias = torch.chunk(state_dict[f"blocks.{depth}.attn.qkv.bias"], 3, dim=0)
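+        # DiT stores a fused qkv projection; split it into the separate
+        # to_q / to_k / to_v weights that the diffusers attention blocks expect.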
+
+ state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
+ state_dict[f"transformer_blocks.{depth}.attn1.to_q.bias"] = q_bias
+ state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
+ state_dict[f"transformer_blocks.{depth}.attn1.to_k.bias"] = k_bias
+ state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
+ state_dict[f"transformer_blocks.{depth}.attn1.to_v.bias"] = v_bias
+
+ state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict[
+ f"blocks.{depth}.attn.proj.weight"
+ ]
+ state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict[f"blocks.{depth}.attn.proj.bias"]
+
+ state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.weight"] = state_dict[f"blocks.{depth}.mlp.fc1.weight"]
+ state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.bias"] = state_dict[f"blocks.{depth}.mlp.fc1.bias"]
+ state_dict[f"transformer_blocks.{depth}.ff.net.2.weight"] = state_dict[f"blocks.{depth}.mlp.fc2.weight"]
+ state_dict[f"transformer_blocks.{depth}.ff.net.2.bias"] = state_dict[f"blocks.{depth}.mlp.fc2.bias"]
+
+ state_dict.pop(f"blocks.{depth}.attn.qkv.weight")
+ state_dict.pop(f"blocks.{depth}.attn.qkv.bias")
+ state_dict.pop(f"blocks.{depth}.attn.proj.weight")
+ state_dict.pop(f"blocks.{depth}.attn.proj.bias")
+ state_dict.pop(f"blocks.{depth}.mlp.fc1.weight")
+ state_dict.pop(f"blocks.{depth}.mlp.fc1.bias")
+ state_dict.pop(f"blocks.{depth}.mlp.fc2.weight")
+ state_dict.pop(f"blocks.{depth}.mlp.fc2.bias")
+ state_dict.pop(f"blocks.{depth}.adaLN_modulation.1.weight")
+ state_dict.pop(f"blocks.{depth}.adaLN_modulation.1.bias")
+
+ state_dict.pop("t_embedder.mlp.0.weight")
+ state_dict.pop("t_embedder.mlp.0.bias")
+ state_dict.pop("t_embedder.mlp.2.weight")
+ state_dict.pop("t_embedder.mlp.2.bias")
+ state_dict.pop("y_embedder.embedding_table.weight")
+
+ state_dict["proj_out_1.weight"] = state_dict["final_layer.adaLN_modulation.1.weight"]
+ state_dict["proj_out_1.bias"] = state_dict["final_layer.adaLN_modulation.1.bias"]
+ state_dict["proj_out_2.weight"] = state_dict["final_layer.linear.weight"]
+ state_dict["proj_out_2.bias"] = state_dict["final_layer.linear.bias"]
+
+ state_dict.pop("final_layer.linear.weight")
+ state_dict.pop("final_layer.linear.bias")
+ state_dict.pop("final_layer.adaLN_modulation.1.weight")
+ state_dict.pop("final_layer.adaLN_modulation.1.bias")
+
+ # DiT XL/2
+ transformer = Transformer2DModel(
+ sample_size=args.image_size // 8,
+ num_layers=28,
+ attention_head_dim=72,
+ in_channels=4,
+ out_channels=8,
+ patch_size=2,
+ attention_bias=True,
+ num_attention_heads=16,
+ activation_fn="gelu-approximate",
+ num_embeds_ada_norm=1000,
+ norm_type="ada_norm_zero",
+ norm_elementwise_affine=False,
+ )
+ transformer.load_state_dict(state_dict, strict=True)
+
+ scheduler = DDIMScheduler(
+ num_train_timesteps=1000,
+ beta_schedule="linear",
+ prediction_type="epsilon",
+ clip_sample=False,
+ )
+
+ vae = AutoencoderKL.from_pretrained(args.vae_model)
+
+ pipeline = DiTPipeline(transformer=transformer, vae=vae, scheduler=scheduler)
+
+ if args.save:
+ pipeline.save_pretrained(args.checkpoint_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--image_size",
+ default=256,
+ type=int,
+ required=False,
+ help="Image size of pretrained model, either 256 or 512.",
+ )
+ parser.add_argument(
+ "--vae_model",
+ default="stabilityai/sd-vae-ft-ema",
+ type=str,
+ required=False,
+ help="Path to pretrained VAE model, either stabilityai/sd-vae-ft-mse or stabilityai/sd-vae-ft-ema.",
+ )
+ parser.add_argument(
+ "--save", default=True, type=bool, required=False, help="Whether to save the converted pipeline or not."
+ )
+ parser.add_argument(
+ "--checkpoint_path", default=None, type=str, required=True, help="Path to the output pipeline."
+ )
+
+ args = parser.parse_args()
+ main(args)
diff --git a/diffusers/scripts/convert_gligen_to_diffusers.py b/diffusers/scripts/convert_gligen_to_diffusers.py
new file mode 100644
index 0000000000000000000000000000000000000000..816e4c112e6fc342db40a66722641a07412ddc22
--- /dev/null
+++ b/diffusers/scripts/convert_gligen_to_diffusers.py
@@ -0,0 +1,587 @@
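+# Script for converting a GLIGEN checkpoint into a diffusers StableDiffusionGLIGEN(TextImage)Pipeline.
+#
+# Example invocation (hypothetical paths):
+#   python convert_gligen_to_diffusers.py --checkpoint_path ./gligen.ckpt \
+#       --original_config_file ./gligen.yaml --attention_type gated --dump_path ./gligen-diffusers
+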
+import argparse
+import re
+
+import torch
+from transformers import (
+ CLIPProcessor,
+ CLIPTextModel,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+)
+
+from diffusers import (
+ AutoencoderKL,
+ DDIMScheduler,
+ StableDiffusionGLIGENPipeline,
+ StableDiffusionGLIGENTextImagePipeline,
+ UNet2DConditionModel,
+)
+from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
+ assign_to_checkpoint,
+ conv_attn_to_linear,
+ protected,
+ renew_attention_paths,
+ renew_resnet_paths,
+ renew_vae_attention_paths,
+ renew_vae_resnet_paths,
+ shave_segments,
+ textenc_conversion_map,
+ textenc_pattern,
+)
+from diffusers.utils import is_omegaconf_available
+from diffusers.utils.import_utils import BACKENDS_MAPPING
+
+
+def convert_open_clip_checkpoint(checkpoint):
+ checkpoint = checkpoint["text_encoder"]
+ text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
+
+ keys = list(checkpoint.keys())
+
+ text_model_dict = {}
+
+ if "cond_stage_model.model.text_projection" in checkpoint:
+ d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0])
+ else:
+ d_model = 1024
+
+ for key in keys:
+ if "resblocks.23" in key: # Diffusers drops the final layer and only uses the penultimate layer
+ continue
+ if key in textenc_conversion_map:
+ text_model_dict[textenc_conversion_map[key]] = checkpoint[key]
+ # if key.startswith("cond_stage_model.model.transformer."):
+ new_key = key[len("transformer.") :]
+ if new_key.endswith(".in_proj_weight"):
+ new_key = new_key[: -len(".in_proj_weight")]
+ new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
+ text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
+ text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :]
+ text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :]
+ elif new_key.endswith(".in_proj_bias"):
+ new_key = new_key[: -len(".in_proj_bias")]
+ new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
+ text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model]
+ text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2]
+ text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :]
+ else:
+ if key != "transformer.text_model.embeddings.position_ids":
+ new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
+
+ text_model_dict[new_key] = checkpoint[key]
+
+ if key == "transformer.text_model.embeddings.token_embedding.weight":
+ text_model_dict["text_model.embeddings.token_embedding.weight"] = checkpoint[key]
+
+ text_model_dict.pop("text_model.embeddings.transformer.text_model.embeddings.token_embedding.weight")
+
+ text_model.load_state_dict(text_model_dict)
+
+ return text_model
+
+
+def convert_gligen_vae_checkpoint(checkpoint, config):
+ checkpoint = checkpoint["autoencoder"]
+ vae_state_dict = {}
+ vae_key = "first_stage_model."
+ keys = list(checkpoint.keys())
+ for key in keys:
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
+
+ new_checkpoint = {}
+
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
+
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
+
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
+
+ # Retrieves the keys for the encoder down blocks only
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
+ down_blocks = {
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
+ }
+
+ # Retrieves the keys for the decoder up blocks only
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
+ up_blocks = {
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
+ }
+
+ for i in range(num_down_blocks):
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
+
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
+ f"encoder.down.{i}.downsample.conv.weight"
+ )
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
+ f"encoder.down.{i}.downsample.conv.bias"
+ )
+
+ paths = renew_vae_resnet_paths(resnets)
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
+ num_mid_res_blocks = 2
+ for i in range(1, num_mid_res_blocks + 1):
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
+
+ paths = renew_vae_resnet_paths(resnets)
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
+ paths = renew_vae_attention_paths(mid_attentions)
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+ conv_attn_to_linear(new_checkpoint)
+
+ for i in range(num_up_blocks):
+ block_id = num_up_blocks - 1 - i
+ resnets = [
+ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
+ ]
+
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
+ f"decoder.up.{block_id}.upsample.conv.weight"
+ ]
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
+ f"decoder.up.{block_id}.upsample.conv.bias"
+ ]
+
+ paths = renew_vae_resnet_paths(resnets)
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
+ num_mid_res_blocks = 2
+ for i in range(1, num_mid_res_blocks + 1):
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
+
+ paths = renew_vae_resnet_paths(resnets)
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
+ paths = renew_vae_attention_paths(mid_attentions)
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+ conv_attn_to_linear(new_checkpoint)
+
+ for key in new_checkpoint.keys():
+ if "encoder.mid_block.attentions.0" in key or "decoder.mid_block.attentions.0" in key:
+ if "query" in key:
+ new_checkpoint[key.replace("query", "to_q")] = new_checkpoint.pop(key)
+ if "value" in key:
+ new_checkpoint[key.replace("value", "to_v")] = new_checkpoint.pop(key)
+ if "key" in key:
+ new_checkpoint[key.replace("key", "to_k")] = new_checkpoint.pop(key)
+ if "proj_attn" in key:
+ new_checkpoint[key.replace("proj_attn", "to_out.0")] = new_checkpoint.pop(key)
+
+ return new_checkpoint
+
+
+def convert_gligen_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
+ unet_state_dict = {}
+ checkpoint = checkpoint["model"]
+ keys = list(checkpoint.keys())
+
+ unet_key = "model.diffusion_model."
+
+ if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
+        print(f"Checkpoint {path} has both EMA and non-EMA weights.")
+ print(
+ "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
+ " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
+ )
+ for key in keys:
+ if key.startswith("model.diffusion_model"):
+ flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
+ else:
+ if sum(k.startswith("model_ema") for k in keys) > 100:
+ print(
+ "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
+ " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
+ )
+ for key in keys:
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
+
+ new_checkpoint = {}
+
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
+
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
+
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
+
+ # Retrieves the keys for the input blocks only
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
+ input_blocks = {
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
+ for layer_id in range(num_input_blocks)
+ }
+
+ # Retrieves the keys for the middle blocks only
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
+ middle_blocks = {
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
+ for layer_id in range(num_middle_blocks)
+ }
+
+ # Retrieves the keys for the output blocks only
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
+ output_blocks = {
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
+ for layer_id in range(num_output_blocks)
+ }
+
+ for i in range(1, num_input_blocks):
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
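+        # e.g. for the standard SD UNet (layers_per_block == 2): input_blocks 1-2 map to
+        # down_blocks.0.resnets.0/1 and input_blocks.3 holds the first downsampler, and so on.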
+
+ resnets = [
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
+ ]
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
+
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
+ f"input_blocks.{i}.0.op.weight"
+ )
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
+ f"input_blocks.{i}.0.op.bias"
+ )
+
+ paths = renew_resnet_paths(resnets)
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ if len(attentions):
+ paths = renew_attention_paths(attentions)
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ resnet_0 = middle_blocks[0]
+ attentions = middle_blocks[1]
+ resnet_1 = middle_blocks[2]
+
+ resnet_0_paths = renew_resnet_paths(resnet_0)
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
+
+ resnet_1_paths = renew_resnet_paths(resnet_1)
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
+
+ attentions_paths = renew_attention_paths(attentions)
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
+ assign_to_checkpoint(
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ for i in range(num_output_blocks):
+ block_id = i // (config["layers_per_block"] + 1)
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
+ output_block_list = {}
+
+ for layer in output_block_layers:
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
+ if layer_id in output_block_list:
+ output_block_list[layer_id].append(layer_name)
+ else:
+ output_block_list[layer_id] = [layer_name]
+
+ if len(output_block_list) > 1:
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
+
+ resnet_0_paths = renew_resnet_paths(resnets)
+ paths = renew_resnet_paths(resnets)
+
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
+ f"output_blocks.{i}.{index}.conv.weight"
+ ]
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
+ f"output_blocks.{i}.{index}.conv.bias"
+ ]
+
+ # Clear attentions as they have been attributed above.
+ if len(attentions) == 2:
+ attentions = []
+
+ if len(attentions):
+ paths = renew_attention_paths(attentions)
+ meta_path = {
+ "old": f"output_blocks.{i}.1",
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
+ }
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+ else:
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
+ for path in resnet_0_paths:
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
+
+ new_checkpoint[new_path] = unet_state_dict[old_path]
+
+ for key in keys:
+ if "position_net" in key:
+ new_checkpoint[key] = unet_state_dict[key]
+
+ return new_checkpoint
+
+
+def create_vae_config(original_config, image_size: int):
+ vae_params = original_config.autoencoder.params.ddconfig
+ _ = original_config.autoencoder.params.embed_dim
+
+ block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
+
+ config = {
+ "sample_size": image_size,
+ "in_channels": vae_params.in_channels,
+ "out_channels": vae_params.out_ch,
+ "down_block_types": tuple(down_block_types),
+ "up_block_types": tuple(up_block_types),
+ "block_out_channels": tuple(block_out_channels),
+ "latent_channels": vae_params.z_channels,
+ "layers_per_block": vae_params.num_res_blocks,
+ }
+
+ return config
+
+
+def create_unet_config(original_config, image_size: int, attention_type):
+ unet_params = original_config.model.params
+ vae_params = original_config.autoencoder.params.ddconfig
+
+ block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
+
+ down_block_types = []
+ resolution = 1
+ for i in range(len(block_out_channels)):
+ block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
+ down_block_types.append(block_type)
+ if i != len(block_out_channels) - 1:
+ resolution *= 2
+
+ up_block_types = []
+ for i in range(len(block_out_channels)):
+ block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
+ up_block_types.append(block_type)
+ resolution //= 2
+
+ vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
+
+ head_dim = unet_params.num_heads if "num_heads" in unet_params else None
+ use_linear_projection = (
+ unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
+ )
+ if use_linear_projection:
+ if head_dim is None:
+ head_dim = [5, 10, 20, 20]
+
+ config = {
+ "sample_size": image_size // vae_scale_factor,
+ "in_channels": unet_params.in_channels,
+ "down_block_types": tuple(down_block_types),
+ "block_out_channels": tuple(block_out_channels),
+ "layers_per_block": unet_params.num_res_blocks,
+ "cross_attention_dim": unet_params.context_dim,
+ "attention_head_dim": head_dim,
+ "use_linear_projection": use_linear_projection,
+ "attention_type": attention_type,
+ }
+
+ return config
+
+
+def convert_gligen_to_diffusers(
+ checkpoint_path: str,
+ original_config_file: str,
+ attention_type: str,
+ image_size: int = 512,
+ extract_ema: bool = False,
+ num_in_channels: int = None,
+ device: str = None,
+):
+ if not is_omegaconf_available():
+ raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
+
+ from omegaconf import OmegaConf
+
+ if device is None:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ checkpoint = torch.load(checkpoint_path, map_location=device)
+ else:
+ checkpoint = torch.load(checkpoint_path, map_location=device)
+
+ if "global_step" in checkpoint:
+ checkpoint["global_step"]
+ else:
+ print("global_step key not found in model")
+
+ original_config = OmegaConf.load(original_config_file)
+
+ if num_in_channels is not None:
+ original_config["model"]["params"]["in_channels"] = num_in_channels
+
+ num_train_timesteps = original_config.diffusion.params.timesteps
+ beta_start = original_config.diffusion.params.linear_start
+ beta_end = original_config.diffusion.params.linear_end
+
+ scheduler = DDIMScheduler(
+ beta_end=beta_end,
+ beta_schedule="scaled_linear",
+ beta_start=beta_start,
+ num_train_timesteps=num_train_timesteps,
+ steps_offset=1,
+ clip_sample=False,
+ set_alpha_to_one=False,
+ prediction_type="epsilon",
+ )
+
+ # Convert the UNet2DConditionalModel model
+ unet_config = create_unet_config(original_config, image_size, attention_type)
+ unet = UNet2DConditionModel(**unet_config)
+
+ converted_unet_checkpoint = convert_gligen_unet_checkpoint(
+ checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
+ )
+
+ unet.load_state_dict(converted_unet_checkpoint)
+
+ # Convert the VAE model
+ vae_config = create_vae_config(original_config, image_size)
+ converted_vae_checkpoint = convert_gligen_vae_checkpoint(checkpoint, vae_config)
+
+ vae = AutoencoderKL(**vae_config)
+ vae.load_state_dict(converted_vae_checkpoint)
+
+ # Convert the text model
+ text_encoder = convert_open_clip_checkpoint(checkpoint)
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+
+ if attention_type == "gated-text-image":
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+
+ pipe = StableDiffusionGLIGENTextImagePipeline(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ image_encoder=image_encoder,
+ processor=processor,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=None,
+ feature_extractor=None,
+ )
+ elif attention_type == "gated":
+ pipe = StableDiffusionGLIGENPipeline(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=None,
+ feature_extractor=None,
+ )
+
+ return pipe
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
+ )
+ parser.add_argument(
+ "--original_config_file",
+ default=None,
+ type=str,
+ required=True,
+ help="The YAML config file corresponding to the gligen architecture.",
+ )
+ parser.add_argument(
+ "--num_in_channels",
+ default=None,
+ type=int,
+ help="The number of input channels. If `None` number of input channels will be automatically inferred.",
+ )
+ parser.add_argument(
+ "--extract_ema",
+ action="store_true",
+ help=(
+ "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
+ " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
+ " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
+ ),
+ )
+ parser.add_argument(
+ "--attention_type",
+ default=None,
+ type=str,
+ required=True,
+        help="Type of attention, e.g. gated or gated-text-image",
+ )
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
+ parser.add_argument("--device", type=str, help="Device to use.")
+ parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
+
+ args = parser.parse_args()
+
+ pipe = convert_gligen_to_diffusers(
+ checkpoint_path=args.checkpoint_path,
+ original_config_file=args.original_config_file,
+ attention_type=args.attention_type,
+ extract_ema=args.extract_ema,
+ num_in_channels=args.num_in_channels,
+ device=args.device,
+ )
+
+ if args.half:
+ pipe.to(torch_dtype=torch.float16)
+
+ pipe.save_pretrained(args.dump_path)
diff --git a/diffusers/scripts/convert_if.py b/diffusers/scripts/convert_if.py
new file mode 100644
index 0000000000000000000000000000000000000000..66d7f694c8e1f50d5c7aad09f9e465d16689d5f0
--- /dev/null
+++ b/diffusers/scripts/convert_if.py
@@ -0,0 +1,1257 @@
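+# Script for converting DeepFloyd IF checkpoints (stage 1 and the stage 2/3 super-resolution models)
+# into diffusers IFPipeline / IFSuperResolutionPipeline checkpoints.
+#
+# Example invocation for stage 1 (hypothetical paths):
+#   python convert_if.py --unet_config ./if_I.yaml --unet_checkpoint_path ./if_I.pt \
+#       --p_head_path ./p_head.npz --w_head_path ./w_head.npz --dump_path ./IF-I-diffusers
+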
+import argparse
+import inspect
+import os
+
+import numpy as np
+import torch
+from torch.nn import functional as F
+from transformers import CLIPConfig, CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, T5Tokenizer
+
+from diffusers import DDPMScheduler, IFPipeline, IFSuperResolutionPipeline, UNet2DConditionModel
+from diffusers.pipelines.deepfloyd_if.safety_checker import IFSafetyChecker
+
+
+try:
+ from omegaconf import OmegaConf
+except ImportError:
+ raise ImportError(
+ "OmegaConf is required to convert the IF checkpoints. Please install it with `pip install" " OmegaConf`."
+ )
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--dump_path", required=False, default=None, type=str)
+
+ parser.add_argument("--dump_path_stage_2", required=False, default=None, type=str)
+
+ parser.add_argument("--dump_path_stage_3", required=False, default=None, type=str)
+
+ parser.add_argument("--unet_config", required=False, default=None, type=str, help="Path to unet config file")
+
+ parser.add_argument(
+ "--unet_checkpoint_path", required=False, default=None, type=str, help="Path to unet checkpoint file"
+ )
+
+ parser.add_argument(
+ "--unet_checkpoint_path_stage_2",
+ required=False,
+ default=None,
+ type=str,
+ help="Path to stage 2 unet checkpoint file",
+ )
+
+ parser.add_argument(
+ "--unet_checkpoint_path_stage_3",
+ required=False,
+ default=None,
+ type=str,
+ help="Path to stage 3 unet checkpoint file",
+ )
+
+ parser.add_argument("--p_head_path", type=str, required=True)
+
+ parser.add_argument("--w_head_path", type=str, required=True)
+
+ args = parser.parse_args()
+
+ return args
+
+
+def main(args):
+ tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-xxl")
+ text_encoder = T5EncoderModel.from_pretrained("google/t5-v1_1-xxl")
+
+ feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
+ safety_checker = convert_safety_checker(p_head_path=args.p_head_path, w_head_path=args.w_head_path)
+
+ if args.unet_config is not None and args.unet_checkpoint_path is not None and args.dump_path is not None:
+ convert_stage_1_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args)
+
+ if args.unet_checkpoint_path_stage_2 is not None and args.dump_path_stage_2 is not None:
+ convert_super_res_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args, stage=2)
+
+ if args.unet_checkpoint_path_stage_3 is not None and args.dump_path_stage_3 is not None:
+ convert_super_res_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args, stage=3)
+
+
+def convert_stage_1_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args):
+ unet = get_stage_1_unet(args.unet_config, args.unet_checkpoint_path)
+
+ scheduler = DDPMScheduler(
+ variance_type="learned_range",
+ beta_schedule="squaredcos_cap_v2",
+ prediction_type="epsilon",
+ thresholding=True,
+ dynamic_thresholding_ratio=0.95,
+ sample_max_value=1.5,
+ )
+
+ pipe = IFPipeline(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ requires_safety_checker=True,
+ )
+
+ pipe.save_pretrained(args.dump_path)
+
+
+def convert_super_res_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args, stage):
+ if stage == 2:
+ unet_checkpoint_path = args.unet_checkpoint_path_stage_2
+ sample_size = None
+ dump_path = args.dump_path_stage_2
+ elif stage == 3:
+ unet_checkpoint_path = args.unet_checkpoint_path_stage_3
+ sample_size = 1024
+ dump_path = args.dump_path_stage_3
+ else:
+ assert False
+
+ unet = get_super_res_unet(unet_checkpoint_path, verify_param_count=False, sample_size=sample_size)
+
+ image_noising_scheduler = DDPMScheduler(
+ beta_schedule="squaredcos_cap_v2",
+ )
+
+ scheduler = DDPMScheduler(
+ variance_type="learned_range",
+ beta_schedule="squaredcos_cap_v2",
+ prediction_type="epsilon",
+ thresholding=True,
+ dynamic_thresholding_ratio=0.95,
+ sample_max_value=1.0,
+ )
+
+ pipe = IFSuperResolutionPipeline(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ unet=unet,
+ scheduler=scheduler,
+ image_noising_scheduler=image_noising_scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ requires_safety_checker=True,
+ )
+
+ pipe.save_pretrained(dump_path)
+
+
+def get_stage_1_unet(unet_config, unet_checkpoint_path):
+ original_unet_config = OmegaConf.load(unet_config)
+ original_unet_config = original_unet_config.params
+
+ unet_diffusers_config = create_unet_diffusers_config(original_unet_config)
+
+ unet = UNet2DConditionModel(**unet_diffusers_config)
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ unet_checkpoint = torch.load(unet_checkpoint_path, map_location=device)
+
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(
+ unet_checkpoint, unet_diffusers_config, path=unet_checkpoint_path
+ )
+
+ unet.load_state_dict(converted_unet_checkpoint)
+
+ return unet
+
+
+def convert_safety_checker(p_head_path, w_head_path):
+ state_dict = {}
+
+ # p head
+
+ p_head = np.load(p_head_path)
+
+ p_head_weights = p_head["weights"]
+ p_head_weights = torch.from_numpy(p_head_weights)
+ p_head_weights = p_head_weights.unsqueeze(0)
+
+ p_head_biases = p_head["biases"]
+ p_head_biases = torch.from_numpy(p_head_biases)
+ p_head_biases = p_head_biases.unsqueeze(0)
+
+ state_dict["p_head.weight"] = p_head_weights
+ state_dict["p_head.bias"] = p_head_biases
+
+ # w head
+
+ w_head = np.load(w_head_path)
+
+ w_head_weights = w_head["weights"]
+ w_head_weights = torch.from_numpy(w_head_weights)
+ w_head_weights = w_head_weights.unsqueeze(0)
+
+ w_head_biases = w_head["biases"]
+ w_head_biases = torch.from_numpy(w_head_biases)
+ w_head_biases = w_head_biases.unsqueeze(0)
+
+ state_dict["w_head.weight"] = w_head_weights
+ state_dict["w_head.bias"] = w_head_biases
+
+ # vision model
+
+ vision_model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
+ vision_model_state_dict = vision_model.state_dict()
+
+ for key, value in vision_model_state_dict.items():
+ key = f"vision_model.{key}"
+ state_dict[key] = value
+
+ # full model
+
+ config = CLIPConfig.from_pretrained("openai/clip-vit-large-patch14")
+ safety_checker = IFSafetyChecker(config)
+
+ safety_checker.load_state_dict(state_dict)
+
+ return safety_checker
+
+
+def create_unet_diffusers_config(original_unet_config, class_embed_type=None):
+ attention_resolutions = parse_list(original_unet_config.attention_resolutions)
+ attention_resolutions = [original_unet_config.image_size // int(res) for res in attention_resolutions]
+
+ channel_mult = parse_list(original_unet_config.channel_mult)
+ block_out_channels = [original_unet_config.model_channels * mult for mult in channel_mult]
+
+ down_block_types = []
+ resolution = 1
+
+ for i in range(len(block_out_channels)):
+ if resolution in attention_resolutions:
+ block_type = "SimpleCrossAttnDownBlock2D"
+ elif original_unet_config.resblock_updown:
+ block_type = "ResnetDownsampleBlock2D"
+ else:
+ block_type = "DownBlock2D"
+
+ down_block_types.append(block_type)
+
+ if i != len(block_out_channels) - 1:
+ resolution *= 2
+
+ up_block_types = []
+ for i in range(len(block_out_channels)):
+ if resolution in attention_resolutions:
+ block_type = "SimpleCrossAttnUpBlock2D"
+ elif original_unet_config.resblock_updown:
+ block_type = "ResnetUpsampleBlock2D"
+ else:
+ block_type = "UpBlock2D"
+ up_block_types.append(block_type)
+ resolution //= 2
+
+ head_dim = original_unet_config.num_head_channels
+
+ use_linear_projection = (
+ original_unet_config.use_linear_in_transformer
+ if "use_linear_in_transformer" in original_unet_config
+ else False
+ )
+ if use_linear_projection:
+ # stable diffusion 2-base-512 and 2-768
+ if head_dim is None:
+ head_dim = [5, 10, 20, 20]
+
+ projection_class_embeddings_input_dim = None
+
+ if class_embed_type is None:
+ if "num_classes" in original_unet_config:
+ if original_unet_config.num_classes == "sequential":
+ class_embed_type = "projection"
+ assert "adm_in_channels" in original_unet_config
+ projection_class_embeddings_input_dim = original_unet_config.adm_in_channels
+ else:
+ raise NotImplementedError(
+ f"Unknown conditional unet num_classes config: {original_unet_config.num_classes}"
+ )
+
+ config = {
+ "sample_size": original_unet_config.image_size,
+ "in_channels": original_unet_config.in_channels,
+ "down_block_types": tuple(down_block_types),
+ "block_out_channels": tuple(block_out_channels),
+ "layers_per_block": original_unet_config.num_res_blocks,
+ "cross_attention_dim": original_unet_config.encoder_channels,
+ "attention_head_dim": head_dim,
+ "use_linear_projection": use_linear_projection,
+ "class_embed_type": class_embed_type,
+ "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
+ "out_channels": original_unet_config.out_channels,
+ "up_block_types": tuple(up_block_types),
+ "upcast_attention": False, # TODO: guessing
+ "cross_attention_norm": "group_norm",
+ "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
+ "addition_embed_type": "text",
+ "act_fn": "gelu",
+ }
+
+ if original_unet_config.use_scale_shift_norm:
+ config["resnet_time_scale_shift"] = "scale_shift"
+
+ if "encoder_dim" in original_unet_config:
+ config["encoder_hid_dim"] = original_unet_config.encoder_dim
+
+ return config
+
+
+def convert_ldm_unet_checkpoint(unet_state_dict, config, path=None):
+ """
+ Takes a state dict and a config, and returns a converted checkpoint.
+ """
+ new_checkpoint = {}
+
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
+
+ if config["class_embed_type"] in [None, "identity"]:
+ # No parameters to port
+ ...
+ elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
+ new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
+ new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
+ new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
+ new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
+ else:
+ raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
+
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
+
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
+
+ # Retrieves the keys for the input blocks only
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
+ input_blocks = {
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
+ for layer_id in range(num_input_blocks)
+ }
+
+ # Retrieves the keys for the middle blocks only
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
+ middle_blocks = {
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
+ for layer_id in range(num_middle_blocks)
+ }
+
+ # Retrieves the keys for the output blocks only
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
+ output_blocks = {
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
+ for layer_id in range(num_output_blocks)
+ }
+
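+    # input_blocks.0 is conv_in (handled above); the remaining flat indices map onto
+    # (down block, layer inside the block), where each down block owns `layers_per_block`
+    # resnets plus one downsampler slot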
+ for i in range(1, num_input_blocks):
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
+
+ resnets = [
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
+ ]
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
+
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
+ f"input_blocks.{i}.0.op.weight"
+ )
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
+ f"input_blocks.{i}.0.op.bias"
+ )
+
+ paths = renew_resnet_paths(resnets)
+
+ # TODO need better check than i in [4, 8, 12, 16]
+ block_type = config["down_block_types"][block_id]
+ if (block_type == "ResnetDownsampleBlock2D" or block_type == "SimpleCrossAttnDownBlock2D") and i in [
+ 4,
+ 8,
+ 12,
+ 16,
+ ]:
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.downsamplers.0"}
+ else:
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
+
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ if len(attentions):
+ old_path = f"input_blocks.{i}.1"
+ new_path = f"down_blocks.{block_id}.attentions.{layer_in_block_id}"
+
+ assign_attention_to_checkpoint(
+ new_checkpoint=new_checkpoint,
+ unet_state_dict=unet_state_dict,
+ old_path=old_path,
+ new_path=new_path,
+ config=config,
+ )
+
+ paths = renew_attention_paths(attentions)
+ meta_path = {"old": old_path, "new": new_path}
+ assign_to_checkpoint(
+ paths,
+ new_checkpoint,
+ unet_state_dict,
+ additional_replacements=[meta_path],
+ config=config,
+ )
+
+ resnet_0 = middle_blocks[0]
+ attentions = middle_blocks[1]
+ resnet_1 = middle_blocks[2]
+
+ resnet_0_paths = renew_resnet_paths(resnet_0)
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
+
+ resnet_1_paths = renew_resnet_paths(resnet_1)
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
+
+ old_path = "middle_block.1"
+ new_path = "mid_block.attentions.0"
+
+ assign_attention_to_checkpoint(
+ new_checkpoint=new_checkpoint,
+ unet_state_dict=unet_state_dict,
+ old_path=old_path,
+ new_path=new_path,
+ config=config,
+ )
+
+ attentions_paths = renew_attention_paths(attentions)
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
+ assign_to_checkpoint(
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ for i in range(num_output_blocks):
+ block_id = i // (config["layers_per_block"] + 1)
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
+ output_block_list = {}
+
+ for layer in output_block_layers:
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
+ if layer_id in output_block_list:
+ output_block_list[layer_id].append(layer_name)
+ else:
+ output_block_list[layer_id] = [layer_name]
+
+ # len(output_block_list) == 1 -> resnet
+ # len(output_block_list) == 2 -> resnet, attention
+ # len(output_block_list) == 3 -> resnet, attention, upscale resnet
+
+ if len(output_block_list) > 1:
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
+
+ paths = renew_resnet_paths(resnets)
+
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
+
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
+ f"output_blocks.{i}.{index}.conv.weight"
+ ]
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
+ f"output_blocks.{i}.{index}.conv.bias"
+ ]
+
+ # Clear attentions as they have been attributed above.
+ if len(attentions) == 2:
+ attentions = []
+
+ if len(attentions):
+ old_path = f"output_blocks.{i}.1"
+ new_path = f"up_blocks.{block_id}.attentions.{layer_in_block_id}"
+
+ assign_attention_to_checkpoint(
+ new_checkpoint=new_checkpoint,
+ unet_state_dict=unet_state_dict,
+ old_path=old_path,
+ new_path=new_path,
+ config=config,
+ )
+
+ paths = renew_attention_paths(attentions)
+ meta_path = {
+ "old": old_path,
+ "new": new_path,
+ }
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ if len(output_block_list) == 3:
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.2" in key]
+ paths = renew_resnet_paths(resnets)
+ meta_path = {"old": f"output_blocks.{i}.2", "new": f"up_blocks.{block_id}.upsamplers.0"}
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+ else:
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
+ for path in resnet_0_paths:
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
+
+ new_checkpoint[new_path] = unet_state_dict[old_path]
+
+ if "encoder_proj.weight" in unet_state_dict:
+ new_checkpoint["encoder_hid_proj.weight"] = unet_state_dict.pop("encoder_proj.weight")
+ new_checkpoint["encoder_hid_proj.bias"] = unet_state_dict.pop("encoder_proj.bias")
+
+ if "encoder_pooling.0.weight" in unet_state_dict:
+ new_checkpoint["add_embedding.norm1.weight"] = unet_state_dict.pop("encoder_pooling.0.weight")
+ new_checkpoint["add_embedding.norm1.bias"] = unet_state_dict.pop("encoder_pooling.0.bias")
+
+ new_checkpoint["add_embedding.pool.positional_embedding"] = unet_state_dict.pop(
+ "encoder_pooling.1.positional_embedding"
+ )
+ new_checkpoint["add_embedding.pool.k_proj.weight"] = unet_state_dict.pop("encoder_pooling.1.k_proj.weight")
+ new_checkpoint["add_embedding.pool.k_proj.bias"] = unet_state_dict.pop("encoder_pooling.1.k_proj.bias")
+ new_checkpoint["add_embedding.pool.q_proj.weight"] = unet_state_dict.pop("encoder_pooling.1.q_proj.weight")
+ new_checkpoint["add_embedding.pool.q_proj.bias"] = unet_state_dict.pop("encoder_pooling.1.q_proj.bias")
+ new_checkpoint["add_embedding.pool.v_proj.weight"] = unet_state_dict.pop("encoder_pooling.1.v_proj.weight")
+ new_checkpoint["add_embedding.pool.v_proj.bias"] = unet_state_dict.pop("encoder_pooling.1.v_proj.bias")
+
+ new_checkpoint["add_embedding.proj.weight"] = unet_state_dict.pop("encoder_pooling.2.weight")
+ new_checkpoint["add_embedding.proj.bias"] = unet_state_dict.pop("encoder_pooling.2.bias")
+
+ new_checkpoint["add_embedding.norm2.weight"] = unet_state_dict.pop("encoder_pooling.3.weight")
+ new_checkpoint["add_embedding.norm2.bias"] = unet_state_dict.pop("encoder_pooling.3.bias")
+
+ return new_checkpoint
+
+
+def shave_segments(path, n_shave_prefix_segments=1):
+ """
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
+ """
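+    # e.g. shave_segments("input_blocks.3.0.in_layers.0.weight", 2) -> "0.in_layers.0.weight";
+    # a negative value drops segments from the end instead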
+ if n_shave_prefix_segments >= 0:
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
+ else:
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
+
+
+def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
+ """
+ Updates paths inside resnets to the new naming scheme (local renaming)
+ """
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item.replace("in_layers.0", "norm1")
+ new_item = new_item.replace("in_layers.2", "conv1")
+
+ new_item = new_item.replace("out_layers.0", "norm2")
+ new_item = new_item.replace("out_layers.3", "conv2")
+
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
+
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
+
+
+def renew_attention_paths(old_list, n_shave_prefix_segments=0):
+ """
+ Updates paths inside attentions to the new naming scheme (local renaming)
+ """
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item
+
+ if "qkv" in new_item:
+ continue
+
+ if "encoder_kv" in new_item:
+ continue
+
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
+
+ new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
+ new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
+
+ new_item = new_item.replace("norm_encoder.weight", "norm_cross.weight")
+ new_item = new_item.replace("norm_encoder.bias", "norm_cross.bias")
+
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
+
+
+def assign_attention_to_checkpoint(new_checkpoint, unet_state_dict, old_path, new_path, config):
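+    # the original checkpoint stores attention projections as fused 1x1 convs: drop the
+    # trailing conv dim and split the fused qkv into separate to_q / to_k / to_v linear
+    # weights (only to_q when the block is cross-attention-only); encoder_kv is split the
+    # same way into add_k_proj / add_v_proj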
+ qkv_weight = unet_state_dict.pop(f"{old_path}.qkv.weight")
+ qkv_weight = qkv_weight[:, :, 0]
+
+ qkv_bias = unet_state_dict.pop(f"{old_path}.qkv.bias")
+
+ is_cross_attn_only = "only_cross_attention" in config and config["only_cross_attention"]
+
+ split = 1 if is_cross_attn_only else 3
+
+ weights, bias = split_attentions(
+ weight=qkv_weight,
+ bias=qkv_bias,
+ split=split,
+ chunk_size=config["attention_head_dim"],
+ )
+
+ if is_cross_attn_only:
+ query_weight, q_bias = weights, bias
+ new_checkpoint[f"{new_path}.to_q.weight"] = query_weight[0]
+ new_checkpoint[f"{new_path}.to_q.bias"] = q_bias[0]
+ else:
+ [query_weight, key_weight, value_weight], [q_bias, k_bias, v_bias] = weights, bias
+ new_checkpoint[f"{new_path}.to_q.weight"] = query_weight
+ new_checkpoint[f"{new_path}.to_q.bias"] = q_bias
+ new_checkpoint[f"{new_path}.to_k.weight"] = key_weight
+ new_checkpoint[f"{new_path}.to_k.bias"] = k_bias
+ new_checkpoint[f"{new_path}.to_v.weight"] = value_weight
+ new_checkpoint[f"{new_path}.to_v.bias"] = v_bias
+
+ encoder_kv_weight = unet_state_dict.pop(f"{old_path}.encoder_kv.weight")
+ encoder_kv_weight = encoder_kv_weight[:, :, 0]
+
+ encoder_kv_bias = unet_state_dict.pop(f"{old_path}.encoder_kv.bias")
+
+ [encoder_k_weight, encoder_v_weight], [encoder_k_bias, encoder_v_bias] = split_attentions(
+ weight=encoder_kv_weight,
+ bias=encoder_kv_bias,
+ split=2,
+ chunk_size=config["attention_head_dim"],
+ )
+
+ new_checkpoint[f"{new_path}.add_k_proj.weight"] = encoder_k_weight
+ new_checkpoint[f"{new_path}.add_k_proj.bias"] = encoder_k_bias
+ new_checkpoint[f"{new_path}.add_v_proj.weight"] = encoder_v_weight
+ new_checkpoint[f"{new_path}.add_v_proj.bias"] = encoder_v_bias
+
+
+def assign_to_checkpoint(paths, checkpoint, old_checkpoint, additional_replacements=None, config=None):
+    """
+    This does the final conversion step: take locally converted weights, apply the global renaming plus any
+    additional replacements, and assign the weights to the new checkpoint. Attention qkv splitting is handled
+    separately in `assign_attention_to_checkpoint`.
+    """
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
+
+ for path in paths:
+ new_path = path["new"]
+
+ # Global renaming happens here
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
+
+ if additional_replacements is not None:
+ for replacement in additional_replacements:
+ new_path = new_path.replace(replacement["old"], replacement["new"])
+
+ # proj_attn.weight has to be converted from conv 1D to linear
+ if "proj_attn.weight" in new_path or "to_out.0.weight" in new_path:
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
+ else:
+ checkpoint[new_path] = old_checkpoint[path["old"]]
+
+
+# TODO maybe document and/or can do more efficiently (build indices in for loop and extract once for each split?)
+def split_attentions(*, weight, bias, split, chunk_size):
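+    # rows are taken in chunks of `chunk_size` (one attention head each) and assigned
+    # round-robin to the `split` output tensors, so e.g. split=3 separates a per-head
+    # interleaved fused qkv projection into q, k and v weights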
+ weights = [None] * split
+ biases = [None] * split
+
+ weights_biases_idx = 0
+
+ for starting_row_index in range(0, weight.shape[0], chunk_size):
+ row_indices = torch.arange(starting_row_index, starting_row_index + chunk_size)
+
+ weight_rows = weight[row_indices, :]
+ bias_rows = bias[row_indices]
+
+ if weights[weights_biases_idx] is None:
+ weights[weights_biases_idx] = weight_rows
+ biases[weights_biases_idx] = bias_rows
+ else:
+ assert weights[weights_biases_idx] is not None
+ weights[weights_biases_idx] = torch.concat([weights[weights_biases_idx], weight_rows])
+ biases[weights_biases_idx] = torch.concat([biases[weights_biases_idx], bias_rows])
+
+ weights_biases_idx = (weights_biases_idx + 1) % split
+
+ return weights, biases
+
+
+def parse_list(value):
+ if isinstance(value, str):
+ value = value.split(",")
+ value = [int(v) for v in value]
+ elif isinstance(value, list):
+ pass
+ else:
+ raise ValueError(f"Can't parse list for type: {type(value)}")
+
+ return value
+
+
+# below is copy and pasted from original convert_if_stage_2.py script
+
+
+def get_super_res_unet(unet_checkpoint_path, verify_param_count=True, sample_size=None):
+ orig_path = unet_checkpoint_path
+
+ original_unet_config = OmegaConf.load(os.path.join(orig_path, "config.yml"))
+ original_unet_config = original_unet_config.params
+
+ unet_diffusers_config = superres_create_unet_diffusers_config(original_unet_config)
+ unet_diffusers_config["time_embedding_dim"] = original_unet_config.model_channels * int(
+ original_unet_config.channel_mult.split(",")[-1]
+ )
+ if original_unet_config.encoder_dim != original_unet_config.encoder_channels:
+ unet_diffusers_config["encoder_hid_dim"] = original_unet_config.encoder_dim
+ unet_diffusers_config["class_embed_type"] = "timestep"
+ unet_diffusers_config["addition_embed_type"] = "text"
+
+ unet_diffusers_config["time_embedding_act_fn"] = "gelu"
+ unet_diffusers_config["resnet_skip_time_act"] = True
+ unet_diffusers_config["resnet_out_scale_factor"] = 1 / 0.7071
+ unet_diffusers_config["mid_block_scale_factor"] = 1 / 0.7071
+ unet_diffusers_config["only_cross_attention"] = (
+ bool(original_unet_config.disable_self_attentions)
+ if (
+ "disable_self_attentions" in original_unet_config
+ and isinstance(original_unet_config.disable_self_attentions, int)
+ )
+ else True
+ )
+
+ if sample_size is None:
+ unet_diffusers_config["sample_size"] = original_unet_config.image_size
+ else:
+ # The second upscaler unet's sample size is incorrectly specified
+ # in the config and is instead hardcoded in source
+ unet_diffusers_config["sample_size"] = sample_size
+
+ unet_checkpoint = torch.load(os.path.join(unet_checkpoint_path, "pytorch_model.bin"), map_location="cpu")
+
+    if verify_param_count:
+        # check that architecture matches - is a bit slow
+        # note: the boolean argument shadows the module-level verify_param_count() helper,
+        # so the function has to be looked up in globals() before calling it
+        globals()["verify_param_count"](orig_path, unet_diffusers_config)
+
+ converted_unet_checkpoint = superres_convert_ldm_unet_checkpoint(
+ unet_checkpoint, unet_diffusers_config, path=unet_checkpoint_path
+ )
+ converted_keys = converted_unet_checkpoint.keys()
+
+ model = UNet2DConditionModel(**unet_diffusers_config)
+ expected_weights = model.state_dict().keys()
+
+ diff_c_e = set(converted_keys) - set(expected_weights)
+ diff_e_c = set(expected_weights) - set(converted_keys)
+
+ assert len(diff_e_c) == 0, f"Expected, but not converted: {diff_e_c}"
+ assert len(diff_c_e) == 0, f"Converted, but not expected: {diff_c_e}"
+
+ model.load_state_dict(converted_unet_checkpoint)
+
+ return model
+
+
+def superres_create_unet_diffusers_config(original_unet_config):
+ attention_resolutions = parse_list(original_unet_config.attention_resolutions)
+ attention_resolutions = [original_unet_config.image_size // int(res) for res in attention_resolutions]
+
+ channel_mult = parse_list(original_unet_config.channel_mult)
+ block_out_channels = [original_unet_config.model_channels * mult for mult in channel_mult]
+
+ down_block_types = []
+ resolution = 1
+
+ for i in range(len(block_out_channels)):
+ if resolution in attention_resolutions:
+ block_type = "SimpleCrossAttnDownBlock2D"
+ elif original_unet_config.resblock_updown:
+ block_type = "ResnetDownsampleBlock2D"
+ else:
+ block_type = "DownBlock2D"
+
+ down_block_types.append(block_type)
+
+ if i != len(block_out_channels) - 1:
+ resolution *= 2
+
+ up_block_types = []
+ for i in range(len(block_out_channels)):
+ if resolution in attention_resolutions:
+ block_type = "SimpleCrossAttnUpBlock2D"
+ elif original_unet_config.resblock_updown:
+ block_type = "ResnetUpsampleBlock2D"
+ else:
+ block_type = "UpBlock2D"
+ up_block_types.append(block_type)
+ resolution //= 2
+
+ head_dim = original_unet_config.num_head_channels
+ use_linear_projection = (
+ original_unet_config.use_linear_in_transformer
+ if "use_linear_in_transformer" in original_unet_config
+ else False
+ )
+ if use_linear_projection:
+ # stable diffusion 2-base-512 and 2-768
+ if head_dim is None:
+ head_dim = [5, 10, 20, 20]
+
+ class_embed_type = None
+ projection_class_embeddings_input_dim = None
+
+ if "num_classes" in original_unet_config:
+ if original_unet_config.num_classes == "sequential":
+ class_embed_type = "projection"
+ assert "adm_in_channels" in original_unet_config
+ projection_class_embeddings_input_dim = original_unet_config.adm_in_channels
+ else:
+ raise NotImplementedError(
+ f"Unknown conditional unet num_classes config: {original_unet_config.num_classes}"
+ )
+
+ config = {
+ "in_channels": original_unet_config.in_channels,
+ "down_block_types": tuple(down_block_types),
+ "block_out_channels": tuple(block_out_channels),
+ "layers_per_block": tuple(original_unet_config.num_res_blocks),
+ "cross_attention_dim": original_unet_config.encoder_channels,
+ "attention_head_dim": head_dim,
+ "use_linear_projection": use_linear_projection,
+ "class_embed_type": class_embed_type,
+ "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
+ "out_channels": original_unet_config.out_channels,
+ "up_block_types": tuple(up_block_types),
+ "upcast_attention": False, # TODO: guessing
+ "cross_attention_norm": "group_norm",
+ "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
+ "act_fn": "gelu",
+ }
+
+ if original_unet_config.use_scale_shift_norm:
+ config["resnet_time_scale_shift"] = "scale_shift"
+
+ return config
+
+
+def superres_convert_ldm_unet_checkpoint(unet_state_dict, config, path=None, extract_ema=False):
+ """
+ Takes a state dict and a config, and returns a converted checkpoint.
+ """
+ new_checkpoint = {}
+
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
+
+ if config["class_embed_type"] is None:
+ # No parameters to port
+ ...
+ elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
+ new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["aug_proj.0.weight"]
+ new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["aug_proj.0.bias"]
+ new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["aug_proj.2.weight"]
+ new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["aug_proj.2.bias"]
+ else:
+ raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
+
+ if "encoder_proj.weight" in unet_state_dict:
+ new_checkpoint["encoder_hid_proj.weight"] = unet_state_dict["encoder_proj.weight"]
+ new_checkpoint["encoder_hid_proj.bias"] = unet_state_dict["encoder_proj.bias"]
+
+ if "encoder_pooling.0.weight" in unet_state_dict:
+ mapping = {
+ "encoder_pooling.0": "add_embedding.norm1",
+ "encoder_pooling.1": "add_embedding.pool",
+ "encoder_pooling.2": "add_embedding.proj",
+ "encoder_pooling.3": "add_embedding.norm2",
+ }
+ for key in unet_state_dict.keys():
+ if key.startswith("encoder_pooling"):
+ prefix = key[: len("encoder_pooling.0")]
+ new_key = key.replace(prefix, mapping[prefix])
+ new_checkpoint[new_key] = unet_state_dict[key]
+
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
+
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
+
+ # Retrieves the keys for the input blocks only
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
+ input_blocks = {
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
+ for layer_id in range(num_input_blocks)
+ }
+
+ # Retrieves the keys for the middle blocks only
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
+ middle_blocks = {
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
+ for layer_id in range(num_middle_blocks)
+ }
+
+ # Retrieves the keys for the output blocks only
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
+ output_blocks = {
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
+ for layer_id in range(num_output_blocks)
+ }
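+    # when layers_per_block is a per-block list, the cumulative sums of (layers + 1) give
+    # the flat input_blocks indices at which each down block ends, i.e. the downsampler
+    # positions; otherwise fall back to the hardcoded indices below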
+ if not isinstance(config["layers_per_block"], int):
+ layers_per_block_list = [e + 1 for e in config["layers_per_block"]]
+ layers_per_block_cumsum = list(np.cumsum(layers_per_block_list))
+ downsampler_ids = layers_per_block_cumsum
+ else:
+ # TODO need better check than i in [4, 8, 12, 16]
+ downsampler_ids = [4, 8, 12, 16]
+
+ for i in range(1, num_input_blocks):
+ if isinstance(config["layers_per_block"], int):
+ layers_per_block = config["layers_per_block"]
+ block_id = (i - 1) // (layers_per_block + 1)
+ layer_in_block_id = (i - 1) % (layers_per_block + 1)
+ else:
+ block_id = next(k for k, n in enumerate(layers_per_block_cumsum) if (i - 1) < n)
+ passed_blocks = layers_per_block_cumsum[block_id - 1] if block_id > 0 else 0
+ layer_in_block_id = (i - 1) - passed_blocks
+
+ resnets = [
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
+ ]
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
+
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
+ f"input_blocks.{i}.0.op.weight"
+ )
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
+ f"input_blocks.{i}.0.op.bias"
+ )
+
+ paths = renew_resnet_paths(resnets)
+
+ block_type = config["down_block_types"][block_id]
+ if (
+ block_type == "ResnetDownsampleBlock2D" or block_type == "SimpleCrossAttnDownBlock2D"
+ ) and i in downsampler_ids:
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.downsamplers.0"}
+ else:
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
+
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ if len(attentions):
+ old_path = f"input_blocks.{i}.1"
+ new_path = f"down_blocks.{block_id}.attentions.{layer_in_block_id}"
+
+ assign_attention_to_checkpoint(
+ new_checkpoint=new_checkpoint,
+ unet_state_dict=unet_state_dict,
+ old_path=old_path,
+ new_path=new_path,
+ config=config,
+ )
+
+ paths = renew_attention_paths(attentions)
+ meta_path = {"old": old_path, "new": new_path}
+ assign_to_checkpoint(
+ paths,
+ new_checkpoint,
+ unet_state_dict,
+ additional_replacements=[meta_path],
+ config=config,
+ )
+
+ resnet_0 = middle_blocks[0]
+ attentions = middle_blocks[1]
+ resnet_1 = middle_blocks[2]
+
+ resnet_0_paths = renew_resnet_paths(resnet_0)
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
+
+ resnet_1_paths = renew_resnet_paths(resnet_1)
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
+
+ old_path = "middle_block.1"
+ new_path = "mid_block.attentions.0"
+
+ assign_attention_to_checkpoint(
+ new_checkpoint=new_checkpoint,
+ unet_state_dict=unet_state_dict,
+ old_path=old_path,
+ new_path=new_path,
+ config=config,
+ )
+
+ attentions_paths = renew_attention_paths(attentions)
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
+ assign_to_checkpoint(
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+ if not isinstance(config["layers_per_block"], int):
+ layers_per_block_list = list(reversed([e + 1 for e in config["layers_per_block"]]))
+ layers_per_block_cumsum = list(np.cumsum(layers_per_block_list))
+
+ for i in range(num_output_blocks):
+ if isinstance(config["layers_per_block"], int):
+ layers_per_block = config["layers_per_block"]
+ block_id = i // (layers_per_block + 1)
+ layer_in_block_id = i % (layers_per_block + 1)
+ else:
+ block_id = next(k for k, n in enumerate(layers_per_block_cumsum) if i < n)
+ passed_blocks = layers_per_block_cumsum[block_id - 1] if block_id > 0 else 0
+ layer_in_block_id = i - passed_blocks
+
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
+ output_block_list = {}
+
+ for layer in output_block_layers:
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
+ if layer_id in output_block_list:
+ output_block_list[layer_id].append(layer_name)
+ else:
+ output_block_list[layer_id] = [layer_name]
+
+ # len(output_block_list) == 1 -> resnet
+ # len(output_block_list) == 2 -> resnet, attention or resnet, upscale resnet
+ # len(output_block_list) == 3 -> resnet, attention, upscale resnet
+
+ if len(output_block_list) > 1:
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
+
+ has_attention = True
+ if len(output_block_list) == 2 and any("in_layers" in k for k in output_block_list["1"]):
+ has_attention = False
+
+ maybe_attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
+
+ paths = renew_resnet_paths(resnets)
+
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
+
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
+ f"output_blocks.{i}.{index}.conv.weight"
+ ]
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
+ f"output_blocks.{i}.{index}.conv.bias"
+ ]
+
+                # this block has no attention layer
+ has_attention = False
+ maybe_attentions = []
+
+ if has_attention:
+ old_path = f"output_blocks.{i}.1"
+ new_path = f"up_blocks.{block_id}.attentions.{layer_in_block_id}"
+
+ assign_attention_to_checkpoint(
+ new_checkpoint=new_checkpoint,
+ unet_state_dict=unet_state_dict,
+ old_path=old_path,
+ new_path=new_path,
+ config=config,
+ )
+
+ paths = renew_attention_paths(maybe_attentions)
+ meta_path = {
+ "old": old_path,
+ "new": new_path,
+ }
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ if len(output_block_list) == 3 or (not has_attention and len(maybe_attentions) > 0):
+ layer_id = len(output_block_list) - 1
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.{layer_id}" in key]
+ paths = renew_resnet_paths(resnets)
+ meta_path = {"old": f"output_blocks.{i}.{layer_id}", "new": f"up_blocks.{block_id}.upsamplers.0"}
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+ else:
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
+ for path in resnet_0_paths:
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
+
+ new_checkpoint[new_path] = unet_state_dict[old_path]
+
+ return new_checkpoint
+
+
+def verify_param_count(orig_path, unet_diffusers_config):
+ if "-II-" in orig_path:
+ from deepfloyd_if.modules import IFStageII
+
+ if_II = IFStageII(device="cpu", dir_or_name=orig_path)
+ elif "-III-" in orig_path:
+ from deepfloyd_if.modules import IFStageIII
+
+ if_II = IFStageIII(device="cpu", dir_or_name=orig_path)
+ else:
+        raise ValueError(f"Weird name. Should have -II- or -III- in path: {orig_path}")
+
+ unet = UNet2DConditionModel(**unet_diffusers_config)
+
+ # in params
+ assert_param_count(unet.time_embedding, if_II.model.time_embed)
+ assert_param_count(unet.conv_in, if_II.model.input_blocks[:1])
+
+ # downblocks
+ assert_param_count(unet.down_blocks[0], if_II.model.input_blocks[1:4])
+ assert_param_count(unet.down_blocks[1], if_II.model.input_blocks[4:7])
+ assert_param_count(unet.down_blocks[2], if_II.model.input_blocks[7:11])
+
+ if "-II-" in orig_path:
+ assert_param_count(unet.down_blocks[3], if_II.model.input_blocks[11:17])
+ assert_param_count(unet.down_blocks[4], if_II.model.input_blocks[17:])
+ if "-III-" in orig_path:
+ assert_param_count(unet.down_blocks[3], if_II.model.input_blocks[11:15])
+ assert_param_count(unet.down_blocks[4], if_II.model.input_blocks[15:20])
+ assert_param_count(unet.down_blocks[5], if_II.model.input_blocks[20:])
+
+ # mid block
+ assert_param_count(unet.mid_block, if_II.model.middle_block)
+
+ # up block
+ if "-II-" in orig_path:
+ assert_param_count(unet.up_blocks[0], if_II.model.output_blocks[:6])
+ assert_param_count(unet.up_blocks[1], if_II.model.output_blocks[6:12])
+ assert_param_count(unet.up_blocks[2], if_II.model.output_blocks[12:16])
+ assert_param_count(unet.up_blocks[3], if_II.model.output_blocks[16:19])
+ assert_param_count(unet.up_blocks[4], if_II.model.output_blocks[19:])
+ if "-III-" in orig_path:
+ assert_param_count(unet.up_blocks[0], if_II.model.output_blocks[:5])
+ assert_param_count(unet.up_blocks[1], if_II.model.output_blocks[5:10])
+ assert_param_count(unet.up_blocks[2], if_II.model.output_blocks[10:14])
+ assert_param_count(unet.up_blocks[3], if_II.model.output_blocks[14:18])
+ assert_param_count(unet.up_blocks[4], if_II.model.output_blocks[18:21])
+ assert_param_count(unet.up_blocks[5], if_II.model.output_blocks[21:24])
+
+ # out params
+ assert_param_count(unet.conv_norm_out, if_II.model.out[0])
+ assert_param_count(unet.conv_out, if_II.model.out[2])
+
+ # make sure all model architecture has same param count
+ assert_param_count(unet, if_II.model)
+
+
+def assert_param_count(model_1, model_2):
+ count_1 = sum(p.numel() for p in model_1.parameters())
+ count_2 = sum(p.numel() for p in model_2.parameters())
+ assert count_1 == count_2, f"{model_1.__class__}: {count_1} != {model_2.__class__}: {count_2}"
+
+
+def superres_check_against_original(dump_path, unet_checkpoint_path):
+ model_path = dump_path
+ model = UNet2DConditionModel.from_pretrained(model_path)
+ model.to("cuda")
+ orig_path = unet_checkpoint_path
+
+ if "-II-" in orig_path:
+ from deepfloyd_if.modules import IFStageII
+
+ if_II_model = IFStageII(device="cuda", dir_or_name=orig_path, model_kwargs={"precision": "fp32"}).model
+ elif "-III-" in orig_path:
+ from deepfloyd_if.modules import IFStageIII
+
+ if_II_model = IFStageIII(device="cuda", dir_or_name=orig_path, model_kwargs={"precision": "fp32"}).model
+
+ batch_size = 1
+ channels = model.in_channels // 2
+ height = model.sample_size
+ width = model.sample_size
+ height = 1024
+ width = 1024
+
+ torch.manual_seed(0)
+
+ latents = torch.randn((batch_size, channels, height, width), device=model.device)
+ image_small = torch.randn((batch_size, channels, height // 4, width // 4), device=model.device)
+
+ interpolate_antialias = {}
+ if "antialias" in inspect.signature(F.interpolate).parameters:
+ interpolate_antialias["antialias"] = True
+ image_upscaled = F.interpolate(
+ image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias
+ )
+
+ latent_model_input = torch.cat([latents, image_upscaled], dim=1).to(model.dtype)
+ t = torch.tensor([5], device=model.device).to(model.dtype)
+
+ seq_len = 64
+ encoder_hidden_states = torch.randn((batch_size, seq_len, model.config.encoder_hid_dim), device=model.device).to(
+ model.dtype
+ )
+
+ fake_class_labels = torch.tensor([t], device=model.device).to(model.dtype)
+
+ with torch.no_grad():
+ out = if_II_model(latent_model_input, t, aug_steps=fake_class_labels, text_emb=encoder_hidden_states)
+
+ if_II_model.to("cpu")
+ del if_II_model
+ import gc
+
+ torch.cuda.empty_cache()
+ gc.collect()
+ print(50 * "=")
+
+ with torch.no_grad():
+ noise_pred = model(
+ sample=latent_model_input,
+ encoder_hidden_states=encoder_hidden_states,
+ class_labels=fake_class_labels,
+ timestep=t,
+ ).sample
+
+ print("Out shape", noise_pred.shape)
+ print("Diff", (out - noise_pred).abs().sum())
+
+
+if __name__ == "__main__":
+ main(parse_args())
diff --git a/diffusers/scripts/convert_k_upscaler_to_diffusers.py b/diffusers/scripts/convert_k_upscaler_to_diffusers.py
new file mode 100644
index 0000000000000000000000000000000000000000..62abedd737855ca0b0bc9abb75c9b6fb91d5bde2
--- /dev/null
+++ b/diffusers/scripts/convert_k_upscaler_to_diffusers.py
@@ -0,0 +1,297 @@
+import argparse
+
+import huggingface_hub
+import k_diffusion as K
+import torch
+
+from diffusers import UNet2DConditionModel
+
+
+UPSCALER_REPO = "pcuenq/k-upscaler"
+
+
+def resnet_to_diffusers_checkpoint(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix):
+ rv = {
+ # norm1
+ f"{diffusers_resnet_prefix}.norm1.linear.weight": checkpoint[f"{resnet_prefix}.main.0.mapper.weight"],
+ f"{diffusers_resnet_prefix}.norm1.linear.bias": checkpoint[f"{resnet_prefix}.main.0.mapper.bias"],
+ # conv1
+ f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.main.2.weight"],
+ f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.main.2.bias"],
+ # norm2
+ f"{diffusers_resnet_prefix}.norm2.linear.weight": checkpoint[f"{resnet_prefix}.main.4.mapper.weight"],
+ f"{diffusers_resnet_prefix}.norm2.linear.bias": checkpoint[f"{resnet_prefix}.main.4.mapper.bias"],
+ # conv2
+ f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.main.6.weight"],
+ f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.main.6.bias"],
+ }
+
+ if resnet.conv_shortcut is not None:
+ rv.update(
+ {
+ f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.skip.weight"],
+ }
+ )
+
+ return rv
+
+
+def self_attn_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix):
+ weight_q, weight_k, weight_v = checkpoint[f"{attention_prefix}.qkv_proj.weight"].chunk(3, dim=0)
+ bias_q, bias_k, bias_v = checkpoint[f"{attention_prefix}.qkv_proj.bias"].chunk(3, dim=0)
+ rv = {
+ # norm
+ f"{diffusers_attention_prefix}.norm1.linear.weight": checkpoint[f"{attention_prefix}.norm_in.mapper.weight"],
+ f"{diffusers_attention_prefix}.norm1.linear.bias": checkpoint[f"{attention_prefix}.norm_in.mapper.bias"],
+ # to_q
+ f"{diffusers_attention_prefix}.attn1.to_q.weight": weight_q.squeeze(-1).squeeze(-1),
+ f"{diffusers_attention_prefix}.attn1.to_q.bias": bias_q,
+ # to_k
+ f"{diffusers_attention_prefix}.attn1.to_k.weight": weight_k.squeeze(-1).squeeze(-1),
+ f"{diffusers_attention_prefix}.attn1.to_k.bias": bias_k,
+ # to_v
+ f"{diffusers_attention_prefix}.attn1.to_v.weight": weight_v.squeeze(-1).squeeze(-1),
+ f"{diffusers_attention_prefix}.attn1.to_v.bias": bias_v,
+ # to_out
+ f"{diffusers_attention_prefix}.attn1.to_out.0.weight": checkpoint[f"{attention_prefix}.out_proj.weight"]
+ .squeeze(-1)
+ .squeeze(-1),
+ f"{diffusers_attention_prefix}.attn1.to_out.0.bias": checkpoint[f"{attention_prefix}.out_proj.bias"],
+ }
+
+ return rv
+
+
+def cross_attn_to_diffusers_checkpoint(
+ checkpoint, *, diffusers_attention_prefix, diffusers_attention_index, attention_prefix
+):
+ weight_k, weight_v = checkpoint[f"{attention_prefix}.kv_proj.weight"].chunk(2, dim=0)
+ bias_k, bias_v = checkpoint[f"{attention_prefix}.kv_proj.bias"].chunk(2, dim=0)
+
+ rv = {
+ # norm2 (ada groupnorm)
+ f"{diffusers_attention_prefix}.norm{diffusers_attention_index}.linear.weight": checkpoint[
+ f"{attention_prefix}.norm_dec.mapper.weight"
+ ],
+ f"{diffusers_attention_prefix}.norm{diffusers_attention_index}.linear.bias": checkpoint[
+ f"{attention_prefix}.norm_dec.mapper.bias"
+ ],
+ # layernorm on encoder_hidden_state
+ f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.norm_cross.weight": checkpoint[
+ f"{attention_prefix}.norm_enc.weight"
+ ],
+ f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.norm_cross.bias": checkpoint[
+ f"{attention_prefix}.norm_enc.bias"
+ ],
+ # to_q
+ f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_q.weight": checkpoint[
+ f"{attention_prefix}.q_proj.weight"
+ ]
+ .squeeze(-1)
+ .squeeze(-1),
+ f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_q.bias": checkpoint[
+ f"{attention_prefix}.q_proj.bias"
+ ],
+ # to_k
+ f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_k.weight": weight_k.squeeze(-1).squeeze(-1),
+ f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_k.bias": bias_k,
+ # to_v
+ f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_v.weight": weight_v.squeeze(-1).squeeze(-1),
+ f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_v.bias": bias_v,
+ # to_out
+ f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_out.0.weight": checkpoint[
+ f"{attention_prefix}.out_proj.weight"
+ ]
+ .squeeze(-1)
+ .squeeze(-1),
+ f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_out.0.bias": checkpoint[
+ f"{attention_prefix}.out_proj.bias"
+ ],
+ }
+
+ return rv
+
+
+def block_to_diffusers_checkpoint(block, checkpoint, block_idx, block_type):
+ block_prefix = "inner_model.u_net.u_blocks" if block_type == "up" else "inner_model.u_net.d_blocks"
+ block_prefix = f"{block_prefix}.{block_idx}"
+
+ diffusers_checkpoint = {}
+
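+    # n is the number of original sub-modules per layer group in the flat block list
+    # (resnet, optionally followed by self- and/or cross-attention); it is used below as
+    # the stride to locate each resnet / attention in the original checkpoint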
+ if not hasattr(block, "attentions"):
+ n = 1 # resnet only
+ elif not block.attentions[0].add_self_attention:
+ n = 2 # resnet -> cross-attention
+ else:
+        n = 3  # resnet -> self-attention -> cross-attention
+
+ for resnet_idx, resnet in enumerate(block.resnets):
+ # diffusers_resnet_prefix = f"{diffusers_up_block_prefix}.resnets.{resnet_idx}"
+ diffusers_resnet_prefix = f"{block_type}_blocks.{block_idx}.resnets.{resnet_idx}"
+ idx = n * resnet_idx if block_type == "up" else n * resnet_idx + 1
+        resnet_prefix = f"{block_prefix}.{idx}"
+
+ diffusers_checkpoint.update(
+ resnet_to_diffusers_checkpoint(
+ resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
+ )
+ )
+
+ if hasattr(block, "attentions"):
+ for attention_idx, attention in enumerate(block.attentions):
+ diffusers_attention_prefix = f"{block_type}_blocks.{block_idx}.attentions.{attention_idx}"
+            idx = n * attention_idx + 1 if block_type == "up" else n * attention_idx + 2
+            self_attention_prefix = f"{block_prefix}.{idx}"
+            cross_attention_index = 1 if not attention.add_self_attention else 2
+            idx = (
+                n * attention_idx + cross_attention_index
+                if block_type == "up"
+                else n * attention_idx + cross_attention_index + 1
+            )
+            cross_attention_prefix = f"{block_prefix}.{idx}"
+
+ diffusers_checkpoint.update(
+ cross_attn_to_diffusers_checkpoint(
+ checkpoint,
+ diffusers_attention_prefix=diffusers_attention_prefix,
+ diffusers_attention_index=2,
+ attention_prefix=cross_attention_prefix,
+ )
+ )
+
+ if attention.add_self_attention is True:
+ diffusers_checkpoint.update(
+ self_attn_to_diffusers_checkpoint(
+ checkpoint,
+ diffusers_attention_prefix=diffusers_attention_prefix,
+ attention_prefix=self_attention_prefix,
+ )
+ )
+
+ return diffusers_checkpoint
+
+
+def unet_to_diffusers_checkpoint(model, checkpoint):
+ diffusers_checkpoint = {}
+
+ # pre-processing
+ diffusers_checkpoint.update(
+ {
+ "conv_in.weight": checkpoint["inner_model.proj_in.weight"],
+ "conv_in.bias": checkpoint["inner_model.proj_in.bias"],
+ }
+ )
+
+ # timestep and class embedding
+ diffusers_checkpoint.update(
+ {
+ "time_proj.weight": checkpoint["inner_model.timestep_embed.weight"].squeeze(-1),
+ "time_embedding.linear_1.weight": checkpoint["inner_model.mapping.0.weight"],
+ "time_embedding.linear_1.bias": checkpoint["inner_model.mapping.0.bias"],
+ "time_embedding.linear_2.weight": checkpoint["inner_model.mapping.2.weight"],
+ "time_embedding.linear_2.bias": checkpoint["inner_model.mapping.2.bias"],
+ "time_embedding.cond_proj.weight": checkpoint["inner_model.mapping_cond.weight"],
+ }
+ )
+
+ # down_blocks
+ for down_block_idx, down_block in enumerate(model.down_blocks):
+ diffusers_checkpoint.update(block_to_diffusers_checkpoint(down_block, checkpoint, down_block_idx, "down"))
+
+ # up_blocks
+ for up_block_idx, up_block in enumerate(model.up_blocks):
+ diffusers_checkpoint.update(block_to_diffusers_checkpoint(up_block, checkpoint, up_block_idx, "up"))
+
+ # post-processing
+ diffusers_checkpoint.update(
+ {
+ "conv_out.weight": checkpoint["inner_model.proj_out.weight"],
+ "conv_out.bias": checkpoint["inner_model.proj_out.bias"],
+ }
+ )
+
+ return diffusers_checkpoint
+
+
+def unet_model_from_original_config(original_config):
+ in_channels = original_config["input_channels"] + original_config["unet_cond_dim"]
+ out_channels = original_config["input_channels"] + (1 if original_config["has_variance"] else 0)
+
+ block_out_channels = original_config["channels"]
+
+ assert (
+ len(set(original_config["depths"])) == 1
+    ), "UNet2DConditionModel currently does not support blocks with different numbers of layers"
+ layers_per_block = original_config["depths"][0]
+
+ class_labels_dim = original_config["mapping_cond_dim"]
+ cross_attention_dim = original_config["cross_cond_dim"]
+
+ attn1_types = []
+ attn2_types = []
+ for s, c in zip(original_config["self_attn_depths"], original_config["cross_attn_depths"]):
+ if s:
+ a1 = "self"
+ a2 = "cross" if c else None
+ elif c:
+ a1 = "cross"
+ a2 = None
+ else:
+ a1 = None
+ a2 = None
+ attn1_types.append(a1)
+ attn2_types.append(a2)
+
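+    # note: attn1_types / attn2_types are derived from the original config for reference
+    # only; they are not passed to the constructor below, whose block types are hardcoded
+    # for this specific upscaler config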
+ unet = UNet2DConditionModel(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ down_block_types=("KDownBlock2D", "KCrossAttnDownBlock2D", "KCrossAttnDownBlock2D", "KCrossAttnDownBlock2D"),
+ mid_block_type=None,
+ up_block_types=("KCrossAttnUpBlock2D", "KCrossAttnUpBlock2D", "KCrossAttnUpBlock2D", "KUpBlock2D"),
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ act_fn="gelu",
+ norm_num_groups=None,
+ cross_attention_dim=cross_attention_dim,
+ attention_head_dim=64,
+ time_cond_proj_dim=class_labels_dim,
+ resnet_time_scale_shift="scale_shift",
+ time_embedding_type="fourier",
+ timestep_post_act="gelu",
+ conv_in_kernel=1,
+ conv_out_kernel=1,
+ )
+
+ return unet
+
+
+def main(args):
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ orig_config_path = huggingface_hub.hf_hub_download(UPSCALER_REPO, "config_laion_text_cond_latent_upscaler_2.json")
+ orig_weights_path = huggingface_hub.hf_hub_download(
+ UPSCALER_REPO, "laion_text_cond_latent_upscaler_2_1_00470000_slim.pth"
+ )
+ print(f"loading original model configuration from {orig_config_path}")
+ print(f"loading original model checkpoint from {orig_weights_path}")
+
+ print("converting to diffusers unet")
+ orig_config = K.config.load_config(open(orig_config_path))["model"]
+ model = unet_model_from_original_config(orig_config)
+
+ orig_checkpoint = torch.load(orig_weights_path, map_location=device)["model_ema"]
+ converted_checkpoint = unet_to_diffusers_checkpoint(model, orig_checkpoint)
+
+ model.load_state_dict(converted_checkpoint, strict=True)
+ model.save_pretrained(args.dump_path)
+ print(f"saving converted unet model in {args.dump_path}")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
+ args = parser.parse_args()
+
+ main(args)
diff --git a/diffusers/scripts/convert_kakao_brain_unclip_to_diffusers.py b/diffusers/scripts/convert_kakao_brain_unclip_to_diffusers.py
new file mode 100644
index 0000000000000000000000000000000000000000..b02cb498bb9b87ca1516a9799be47273f5a67a85
--- /dev/null
+++ b/diffusers/scripts/convert_kakao_brain_unclip_to_diffusers.py
@@ -0,0 +1,1159 @@
+import argparse
+import tempfile
+
+import torch
+from accelerate import load_checkpoint_and_dispatch
+from transformers import CLIPTextModelWithProjection, CLIPTokenizer
+
+from diffusers import UnCLIPPipeline, UNet2DConditionModel, UNet2DModel
+from diffusers.models.prior_transformer import PriorTransformer
+from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
+from diffusers.schedulers.scheduling_unclip import UnCLIPScheduler
+
+
+r"""
+Example - From the diffusers root directory:
+
+Download weights:
+```sh
+$ wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/efdf6206d8ed593961593dc029a8affa/decoder-ckpt-step%3D01000000-of-01000000.ckpt
+$ wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/4226b831ae0279020d134281f3c31590/improved-sr-ckpt-step%3D1.2M.ckpt
+$ wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/85626483eaca9f581e2a78d31ff905ca/prior-ckpt-step%3D01000000-of-01000000.ckpt
+$ wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/0b62380a75e56f073e2844ab5199153d/ViT-L-14_stats.th
+```
+
+Convert the model:
+```sh
+$ python scripts/convert_kakao_brain_unclip_to_diffusers.py \
+ --decoder_checkpoint_path ./decoder-ckpt-step\=01000000-of-01000000.ckpt \
+ --super_res_unet_checkpoint_path ./improved-sr-ckpt-step\=1.2M.ckpt \
+ --prior_checkpoint_path ./prior-ckpt-step\=01000000-of-01000000.ckpt \
+ --clip_stat_path ./ViT-L-14_stats.th \
+      --dump_path <dump_path>
+```
+"""
+
+
+# prior
+
+PRIOR_ORIGINAL_PREFIX = "model"
+
+# Uses default arguments
+PRIOR_CONFIG = {}
+
+
+def prior_model_from_original_config():
+ model = PriorTransformer(**PRIOR_CONFIG)
+
+ return model
+
+
+def prior_original_checkpoint_to_diffusers_checkpoint(model, checkpoint, clip_stats_checkpoint):
+ diffusers_checkpoint = {}
+
+ # .time_embed.0 -> .time_embedding.linear_1
+ diffusers_checkpoint.update(
+ {
+ "time_embedding.linear_1.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.0.weight"],
+ "time_embedding.linear_1.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.0.bias"],
+ }
+ )
+
+ # .clip_img_proj -> .proj_in
+ diffusers_checkpoint.update(
+ {
+ "proj_in.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.clip_img_proj.weight"],
+ "proj_in.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.clip_img_proj.bias"],
+ }
+ )
+
+ # .text_emb_proj -> .embedding_proj
+ diffusers_checkpoint.update(
+ {
+ "embedding_proj.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_emb_proj.weight"],
+ "embedding_proj.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_emb_proj.bias"],
+ }
+ )
+
+ # .text_enc_proj -> .encoder_hidden_states_proj
+ diffusers_checkpoint.update(
+ {
+ "encoder_hidden_states_proj.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_enc_proj.weight"],
+ "encoder_hidden_states_proj.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_enc_proj.bias"],
+ }
+ )
+
+ # .positional_embedding -> .positional_embedding
+ diffusers_checkpoint.update({"positional_embedding": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.positional_embedding"]})
+
+ # .prd_emb -> .prd_embedding
+ diffusers_checkpoint.update({"prd_embedding": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.prd_emb"]})
+
+ # .time_embed.2 -> .time_embedding.linear_2
+ diffusers_checkpoint.update(
+ {
+ "time_embedding.linear_2.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.2.weight"],
+ "time_embedding.linear_2.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.2.bias"],
+ }
+ )
+
+ # .resblocks. -> .transformer_blocks.
+ for idx in range(len(model.transformer_blocks)):
+ diffusers_transformer_prefix = f"transformer_blocks.{idx}"
+ original_transformer_prefix = f"{PRIOR_ORIGINAL_PREFIX}.transformer.resblocks.{idx}"
+
+ # .attn -> .attn1
+ diffusers_attention_prefix = f"{diffusers_transformer_prefix}.attn1"
+ original_attention_prefix = f"{original_transformer_prefix}.attn"
+ diffusers_checkpoint.update(
+ prior_attention_to_diffusers(
+ checkpoint,
+ diffusers_attention_prefix=diffusers_attention_prefix,
+ original_attention_prefix=original_attention_prefix,
+ attention_head_dim=model.attention_head_dim,
+ )
+ )
+
+ # .mlp -> .ff
+ diffusers_ff_prefix = f"{diffusers_transformer_prefix}.ff"
+ original_ff_prefix = f"{original_transformer_prefix}.mlp"
+ diffusers_checkpoint.update(
+ prior_ff_to_diffusers(
+ checkpoint, diffusers_ff_prefix=diffusers_ff_prefix, original_ff_prefix=original_ff_prefix
+ )
+ )
+
+ # .ln_1 -> .norm1
+ diffusers_checkpoint.update(
+ {
+ f"{diffusers_transformer_prefix}.norm1.weight": checkpoint[
+ f"{original_transformer_prefix}.ln_1.weight"
+ ],
+ f"{diffusers_transformer_prefix}.norm1.bias": checkpoint[f"{original_transformer_prefix}.ln_1.bias"],
+ }
+ )
+
+ # .ln_2 -> .norm3
+ diffusers_checkpoint.update(
+ {
+ f"{diffusers_transformer_prefix}.norm3.weight": checkpoint[
+ f"{original_transformer_prefix}.ln_2.weight"
+ ],
+ f"{diffusers_transformer_prefix}.norm3.bias": checkpoint[f"{original_transformer_prefix}.ln_2.bias"],
+ }
+ )
+
+ # .final_ln -> .norm_out
+ diffusers_checkpoint.update(
+ {
+ "norm_out.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.final_ln.weight"],
+ "norm_out.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.final_ln.bias"],
+ }
+ )
+
+ # .out_proj -> .proj_to_clip_embeddings
+ diffusers_checkpoint.update(
+ {
+ "proj_to_clip_embeddings.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.out_proj.weight"],
+ "proj_to_clip_embeddings.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.out_proj.bias"],
+ }
+ )
+
+ # clip stats
+ clip_mean, clip_std = clip_stats_checkpoint
+ clip_mean = clip_mean[None, :]
+ clip_std = clip_std[None, :]
+
+ diffusers_checkpoint.update({"clip_mean": clip_mean, "clip_std": clip_std})
+
+ return diffusers_checkpoint
+
+
+def prior_attention_to_diffusers(
+ checkpoint, *, diffusers_attention_prefix, original_attention_prefix, attention_head_dim
+):
+ diffusers_checkpoint = {}
+
+ # .c_qkv -> .{to_q, to_k, to_v}
+ [q_weight, k_weight, v_weight], [q_bias, k_bias, v_bias] = split_attentions(
+ weight=checkpoint[f"{original_attention_prefix}.c_qkv.weight"],
+ bias=checkpoint[f"{original_attention_prefix}.c_qkv.bias"],
+ split=3,
+ chunk_size=attention_head_dim,
+ )
+
+ diffusers_checkpoint.update(
+ {
+ f"{diffusers_attention_prefix}.to_q.weight": q_weight,
+ f"{diffusers_attention_prefix}.to_q.bias": q_bias,
+ f"{diffusers_attention_prefix}.to_k.weight": k_weight,
+ f"{diffusers_attention_prefix}.to_k.bias": k_bias,
+ f"{diffusers_attention_prefix}.to_v.weight": v_weight,
+ f"{diffusers_attention_prefix}.to_v.bias": v_bias,
+ }
+ )
+
+ # .c_proj -> .to_out.0
+ diffusers_checkpoint.update(
+ {
+ f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{original_attention_prefix}.c_proj.weight"],
+ f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{original_attention_prefix}.c_proj.bias"],
+ }
+ )
+
+ return diffusers_checkpoint
+
+
+def prior_ff_to_diffusers(checkpoint, *, diffusers_ff_prefix, original_ff_prefix):
+ diffusers_checkpoint = {
+ # .c_fc -> .net.0.proj
+ f"{diffusers_ff_prefix}.net.{0}.proj.weight": checkpoint[f"{original_ff_prefix}.c_fc.weight"],
+ f"{diffusers_ff_prefix}.net.{0}.proj.bias": checkpoint[f"{original_ff_prefix}.c_fc.bias"],
+ # .c_proj -> .net.2
+ f"{diffusers_ff_prefix}.net.{2}.weight": checkpoint[f"{original_ff_prefix}.c_proj.weight"],
+ f"{diffusers_ff_prefix}.net.{2}.bias": checkpoint[f"{original_ff_prefix}.c_proj.bias"],
+ }
+
+ return diffusers_checkpoint
+
+
+# done prior
+
+
+# decoder
+
+DECODER_ORIGINAL_PREFIX = "model"
+
+# We are hardcoding the model configuration for now. If we need to generalize to more model configurations, we can
+# update it then.
+DECODER_CONFIG = {
+ "sample_size": 64,
+ "layers_per_block": 3,
+ "down_block_types": (
+ "ResnetDownsampleBlock2D",
+ "SimpleCrossAttnDownBlock2D",
+ "SimpleCrossAttnDownBlock2D",
+ "SimpleCrossAttnDownBlock2D",
+ ),
+ "up_block_types": (
+ "SimpleCrossAttnUpBlock2D",
+ "SimpleCrossAttnUpBlock2D",
+ "SimpleCrossAttnUpBlock2D",
+ "ResnetUpsampleBlock2D",
+ ),
+ "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
+ "block_out_channels": (320, 640, 960, 1280),
+ "in_channels": 3,
+ "out_channels": 6,
+ "cross_attention_dim": 1536,
+ "class_embed_type": "identity",
+ "attention_head_dim": 64,
+ "resnet_time_scale_shift": "scale_shift",
+}
+
+
+def decoder_model_from_original_config():
+ model = UNet2DConditionModel(**DECODER_CONFIG)
+
+ return model
+
+
+def decoder_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
+ diffusers_checkpoint = {}
+
+ original_unet_prefix = DECODER_ORIGINAL_PREFIX
+ num_head_channels = DECODER_CONFIG["attention_head_dim"]
+
+ diffusers_checkpoint.update(unet_time_embeddings(checkpoint, original_unet_prefix))
+ diffusers_checkpoint.update(unet_conv_in(checkpoint, original_unet_prefix))
+
+ # .input_blocks -> .down_blocks
+
+ original_down_block_idx = 1
+
+ for diffusers_down_block_idx in range(len(model.down_blocks)):
+ checkpoint_update, num_original_down_blocks = unet_downblock_to_diffusers_checkpoint(
+ model,
+ checkpoint,
+ diffusers_down_block_idx=diffusers_down_block_idx,
+ original_down_block_idx=original_down_block_idx,
+ original_unet_prefix=original_unet_prefix,
+ num_head_channels=num_head_channels,
+ )
+
+ original_down_block_idx += num_original_down_blocks
+
+ diffusers_checkpoint.update(checkpoint_update)
+
+ # done .input_blocks -> .down_blocks
+
+ diffusers_checkpoint.update(
+ unet_midblock_to_diffusers_checkpoint(
+ model,
+ checkpoint,
+ original_unet_prefix=original_unet_prefix,
+ num_head_channels=num_head_channels,
+ )
+ )
+
+ # .output_blocks -> .up_blocks
+
+ original_up_block_idx = 0
+
+ for diffusers_up_block_idx in range(len(model.up_blocks)):
+ checkpoint_update, num_original_up_blocks = unet_upblock_to_diffusers_checkpoint(
+ model,
+ checkpoint,
+ diffusers_up_block_idx=diffusers_up_block_idx,
+ original_up_block_idx=original_up_block_idx,
+ original_unet_prefix=original_unet_prefix,
+ num_head_channels=num_head_channels,
+ )
+
+ original_up_block_idx += num_original_up_blocks
+
+ diffusers_checkpoint.update(checkpoint_update)
+
+ # done .output_blocks -> .up_blocks
+
+ diffusers_checkpoint.update(unet_conv_norm_out(checkpoint, original_unet_prefix))
+ diffusers_checkpoint.update(unet_conv_out(checkpoint, original_unet_prefix))
+
+ return diffusers_checkpoint
+
+
+# done decoder
+
+# text proj
+
+
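+# The UnCLIPTextProjModel combines the CLIP text and image embeddings into the conditioning the
+# decoder consumes: extra context tokens for cross attention and an additive contribution to the
+# time embedding, which is why it is sized from the decoder's time embedding dimension and
+# cross_attention_dim below.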
+def text_proj_from_original_config():

+ # Matches the conditional UNet constructor, where the projected time embedding dimension is
+ # computed as block_out_channels[0] * 4
+ time_embed_dim = DECODER_CONFIG["block_out_channels"][0] * 4
+
+ cross_attention_dim = DECODER_CONFIG["cross_attention_dim"]
+
+ model = UnCLIPTextProjModel(time_embed_dim=time_embed_dim, cross_attention_dim=cross_attention_dim)
+
+ return model
+
+
+# Note that the input checkpoint is the original decoder checkpoint
+def text_proj_original_checkpoint_to_diffusers_checkpoint(checkpoint):
+ diffusers_checkpoint = {
+ # .text_seq_proj.0 -> .encoder_hidden_states_proj
+ "encoder_hidden_states_proj.weight": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_seq_proj.0.weight"],
+ "encoder_hidden_states_proj.bias": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_seq_proj.0.bias"],
+ # .text_seq_proj.1 -> .text_encoder_hidden_states_norm
+ "text_encoder_hidden_states_norm.weight": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_seq_proj.1.weight"],
+ "text_encoder_hidden_states_norm.bias": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_seq_proj.1.bias"],
+ # .clip_tok_proj -> .clip_extra_context_tokens_proj
+ "clip_extra_context_tokens_proj.weight": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.clip_tok_proj.weight"],
+ "clip_extra_context_tokens_proj.bias": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.clip_tok_proj.bias"],
+ # .text_feat_proj -> .embedding_proj
+ "embedding_proj.weight": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_feat_proj.weight"],
+ "embedding_proj.bias": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_feat_proj.bias"],
+ # .cf_param -> .learned_classifier_free_guidance_embeddings
+ "learned_classifier_free_guidance_embeddings": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.cf_param"],
+ # .clip_emb -> .clip_image_embeddings_project_to_time_embeddings
+ "clip_image_embeddings_project_to_time_embeddings.weight": checkpoint[
+ f"{DECODER_ORIGINAL_PREFIX}.clip_emb.weight"
+ ],
+ "clip_image_embeddings_project_to_time_embeddings.bias": checkpoint[
+ f"{DECODER_ORIGINAL_PREFIX}.clip_emb.bias"
+ ],
+ }
+
+ return diffusers_checkpoint
+
+
+# done text proj
+
+# super res unet first steps
+
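+# Karlo's super-resolution UNet is converted as two plain `UNet2DModel`s: this one covers all but
+# the last denoising step, while a separate "last step" model handles the final step (matching the
+# UnCLIP pipeline's `super_res_first` / `super_res_last` components).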
+SUPER_RES_UNET_FIRST_STEPS_PREFIX = "model_first_steps"
+
+SUPER_RES_UNET_FIRST_STEPS_CONFIG = {
+ "sample_size": 256,
+ "layers_per_block": 3,
+ "down_block_types": (
+ "ResnetDownsampleBlock2D",
+ "ResnetDownsampleBlock2D",
+ "ResnetDownsampleBlock2D",
+ "ResnetDownsampleBlock2D",
+ ),
+ "up_block_types": (
+ "ResnetUpsampleBlock2D",
+ "ResnetUpsampleBlock2D",
+ "ResnetUpsampleBlock2D",
+ "ResnetUpsampleBlock2D",
+ ),
+ "block_out_channels": (320, 640, 960, 1280),
+ "in_channels": 6,
+ "out_channels": 3,
+ "add_attention": False,
+}
+
+
+def super_res_unet_first_steps_model_from_original_config():
+ model = UNet2DModel(**SUPER_RES_UNET_FIRST_STEPS_CONFIG)
+
+ return model
+
+
+def super_res_unet_first_steps_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
+ diffusers_checkpoint = {}
+
+ original_unet_prefix = SUPER_RES_UNET_FIRST_STEPS_PREFIX
+
+ diffusers_checkpoint.update(unet_time_embeddings(checkpoint, original_unet_prefix))
+ diffusers_checkpoint.update(unet_conv_in(checkpoint, original_unet_prefix))
+
+ # .input_blocks ->