Models
Diffusers contains pretrained models for popular algorithms and modules for creating the next set of diffusion models.
The primary function of these models is to denoise an input sample by modeling the distribution $p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)$.
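For context, in the DDPM formulation of Ho et al. this reverse transition is usually modeled as a Gaussian with a fixed variance $\sigma_t^2$, whose mean is computed from the network's noise prediction $\boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)$. The exact parameterization depends on the model and scheduler in use, so read the following as the standard DDPM case rather than a statement about every model in the library (here $\beta_t$ is the noise schedule, $\alpha_t = 1 - \beta_t$ and $\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s$):

```latex
p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t) = \mathcal{N}\big(\mathbf{x}_{t-1};\ \boldsymbol{\mu}_\theta(\mathbf{x}_t, t),\ \sigma_t^2 \mathbf{I}\big),
\qquad
\boldsymbol{\mu}_\theta(\mathbf{x}_t, t) = \frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}}\,\boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)\right)
```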
The models are built on the base class ModelMixin, which is a torch.nn.Module with basic functionality for saving and loading models, both locally and from the Hugging Face Hub.
ModelMixin
Base class for all models.
ModelMixin takes care of storing the configuration of the models and handles methods for loading, downloading and saving models.
- config_name (str) — A filename under which the model should be stored when calling save_pretrained().
disable_gradient_checkpointing
Deactivates gradient checkpointing for the current model.
Note that in other frameworks this feature can be referred to as “activation checkpointing” or “checkpoint activations”.
enable_gradient_checkpointing
Activates gradient checkpointing for the current model.
Note that in other frameworks this feature can be referred to as “activation checkpointing” or “checkpoint activations”.
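Gradient checkpointing trades compute for memory: instead of keeping every intermediate activation for the backward pass, only some are stored and the rest are recomputed on demand. The pure-Python toy below is a hypothetical illustration of that idea (the toy layers, `make_layer`, and the `every` spacing are made up for the example). It is not how Diffusers implements the feature, which relies on PyTorch's torch.utils.checkpoint inside its blocks; in practice you would simply call model.enable_gradient_checkpointing() before training.

```python
from typing import Callable, Dict, List

Layer = Callable[[float], float]


def make_layer(k: int) -> Layer:
    """Hypothetical toy layer: v -> 2v + k."""
    return lambda v: v * 2 + k


def forward_with_checkpoints(layers: List[Layer], x: float, every: int) -> Dict[int, float]:
    """Run all layers, keeping only the input of every `every`-th layer plus the final output."""
    saved: Dict[int, float] = {}
    for i, layer in enumerate(layers):
        if i % every == 0:
            saved[i] = x  # checkpoint: store the input to layer i
        x = layer(x)
    saved[len(layers)] = x  # the final output is always kept
    return saved


def activation_at(layers: List[Layer], saved: Dict[int, float], i: int) -> float:
    """Return the input to layer i, recomputing it from the nearest earlier checkpoint."""
    start = max(k for k in saved if k <= i)
    x = saved[start]
    for layer in layers[start:i]:
        x = layer(x)  # extra compute instead of stored memory
    return x


# Eight toy layers; with every=4 only 3 values are stored instead of 9
layers = [make_layer(k) for k in range(8)]
saved = forward_with_checkpoints(layers, 1.0, every=4)
```

Larger values of `every` store fewer activations but recompute more layers during the backward pass.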
from_pretrained
(pretrained_model_name_or_path: typing.Union[str, os.PathLike, NoneType], **kwargs)
Parameters
- pretrained_model_name_or_path (str or os.PathLike, optional) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co. Valid model ids should have an organization name, like google/ddpm-celebahq-256.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
- cache_dir (Union[str, os.PathLike], optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- torch_dtype (str or torch.dtype, optional) — Override the default torch.dtype and load the model under this dtype. If "auto" is passed, the dtype will be automatically derived from the model's weights.
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download (bool, optional, defaults to False) — Whether or not to resume an interrupted download. If False, incompletely received files are deleted; if True, the download is resumed from such a file if it exists.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- use_auth_token (str or bool, optional) — The token to use as HTTP bearer authorization for remote files. If True, the token generated when running huggingface-cli login (stored in ~/.huggingface) is used.
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id; since models and other artifacts on huggingface.co are stored in a git-based system, revision can be any identifier allowed by git.
- subfolder (str, optional, defaults to "") — If the relevant files are located inside a subfolder of the model repo (either remote on huggingface.co or downloaded locally), specify the folder name here.
- mirror (str, optional) — Mirror source to accelerate downloads in China. If you are in China and have accessibility problems, you can set this option to resolve them. Note that we do not guarantee the timeliness or safety of the mirror; please refer to the mirror site for more information.
Instantiate a pretrained PyTorch model from a pretrained model configuration.
The model is set in evaluation mode by default using model.eval() (Dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
The warning “Weights from XXX not initialized from pretrained model” means that the weights of XXX do not come pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning task.
The warning “Weights from XXX not used in YYY” means that the layer XXX is not used by YYY, therefore those weights are discarded.
You must be logged in (huggingface-cli login) to use private or gated models.
Activate the special “offline-mode” to use this method in a firewalled environment.
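The two accepted forms of pretrained_model_name_or_path can be told apart with a simple check, sketched below. The helper `resolve_source` and its return values are hypothetical and only illustrate the distinction described above (a local directory versus an organization-namespaced Hub id, optionally combined with subfolder); the real method additionally handles caching, revisions, authentication and downloads.

```python
import os
from typing import Tuple


def resolve_source(name_or_path: str, subfolder: str = "") -> Tuple[str, str]:
    """Hypothetical helper: classify the argument as a local directory or a Hub model id."""
    if os.path.isdir(name_or_path):
        # A local directory, e.g. one written by save_pretrained()
        location = os.path.join(name_or_path, subfolder) if subfolder else name_or_path
        return "local", location
    org, _, model = name_or_path.partition("/")
    if org and model and "/" not in model:
        # A Hub id namespaced under an organization, e.g. "google/ddpm-celebahq-256"
        location = f"{name_or_path}/{subfolder}" if subfolder else name_or_path
        return "hub", location
    raise ValueError(f"{name_or_path!r} is neither a local directory nor an '<org>/<model>' id")
```

For example, `resolve_source("google/ddpm-celebahq-256", "unet")` returns `("hub", "google/ddpm-celebahq-256/unet")`.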
num_parameters
(only_trainable: bool = False, exclude_embeddings: bool = False) → int
Get the number of (optionally, trainable or non-embedding) parameters in the module.
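The counting rule can be sketched over a plain description of a model's parameters. The `ParamInfo` record, the parameter names and the `count_parameters` helper below are hypothetical stand-ins (a real module would iterate over its torch.nn.Parameter objects and read their sizes and requires_grad flags); they only show how the two flags narrow the count.

```python
from dataclasses import dataclass
from math import prod
from typing import List, Tuple


@dataclass
class ParamInfo:
    name: str
    shape: Tuple[int, ...]
    trainable: bool = True
    is_embedding: bool = False


def count_parameters(params: List[ParamInfo], only_trainable: bool = False,
                     exclude_embeddings: bool = False) -> int:
    total = 0
    for p in params:
        if only_trainable and not p.trainable:
            continue  # skip frozen weights
        if exclude_embeddings and p.is_embedding:
            continue  # skip embedding tables
        total += prod(p.shape)  # number of scalars in this tensor
    return total


# Hypothetical parameters: an embedding table, a conv kernel and a frozen bias
params = [
    ParamInfo("time_embedding.weight", (1000, 128), is_embedding=True),
    ParamInfo("conv_in.weight", (64, 3, 3, 3)),
    ParamInfo("conv_in.bias", (64,), trainable=False),
]
```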
save_pretrained
(save_directory: typing.Union[str, os.PathLike], is_main_process: bool = True, save_function: typing.Callable = torch.save)
Parameters
- save_directory (str or os.PathLike) — Directory to which to save. Will be created if it doesn't exist.
- is_main_process (bool, optional, defaults to True) — Whether the process calling this is the main process or not. Useful during distributed training (e.g., on TPUs) when this function has to be called on all processes. In that case, set is_main_process=True only on the main process to avoid race conditions.
- save_function (Callable) — The function to use to save the state dictionary. Useful during distributed training (e.g., on TPUs) when torch.save needs to be replaced by another method.
Save a model and its configuration file to a directory, so that it can be reloaded using the from_pretrained() class method.
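How is_main_process and save_function fit together can be sketched as follows. This is a simplified, hypothetical stand-in: the `save_model_files` name, the file names and the use of json/pickle are assumptions for illustration, not the library's actual on-disk format. The directory is created if needed, only the main process writes the configuration file, and the weights are written through whatever callable is injected as save_function (a pickle dump here; in the PyTorch version the default is torch.save).

```python
import json
import os
import pickle
import tempfile
from typing import Any, Callable, Dict


def pickle_save(obj: Any, path: str) -> None:
    """Default writer for this sketch only."""
    with open(path, "wb") as f:
        pickle.dump(obj, f)


def save_model_files(state_dict: Dict[str, Any], config: Dict[str, Any], save_directory: str,
                     is_main_process: bool = True,
                     save_function: Callable[[Any, str], None] = pickle_save) -> None:
    """Hypothetical, simplified save routine; file names are illustrative only."""
    os.makedirs(save_directory, exist_ok=True)  # created if it doesn't exist
    if is_main_process:
        # Only the main process writes the config, so workers never race on this file
        with open(os.path.join(save_directory, "config.json"), "w") as f:
            json.dump(config, f)
    # Weights go through the injectable writer (e.g. a TPU-aware replacement)
    save_function(state_dict, os.path.join(save_directory, "weights.pkl"))


# Usage: one main-process save, then a worker that routes weights to a custom writer
with tempfile.TemporaryDirectory() as tmp:
    save_model_files({"w": [1.0, 2.0]}, {"in_channels": 3}, tmp)
    main_files = sorted(os.listdir(tmp))
    worker_dir = os.path.join(tmp, "worker")
    worker_calls = []
    save_model_files({"w": [0.0]}, {}, worker_dir, is_main_process=False,
                     save_function=lambda obj, path: worker_calls.append(os.path.basename(path)))
    worker_files = os.listdir(worker_dir)
```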
UNet2DOutput
class diffusers.models.unet_2d.UNet2DOutput(sample: FloatTensor)
UNet2DModel
class diffusers.UNet2DModel(sample_size: typing.Optional[int] = None, in_channels: int = 3, out_channels: int = 3, center_input_sample: bool = False, time_embedding_type: str = 'positional', freq_shift: int = 0, flip_sin_to_cos: bool = True, down_block_types: typing.Tuple[str] = ('DownBlock2D', 'AttnDownBlock2D', 'AttnDownBlock2D', 'AttnDownBlock2D'), up_block_types: typing.Tuple[str] = ('AttnUpBlock2D', 'AttnUpBlock2D', 'AttnUpBlock2D', 'UpBlock2D'), block_out_channels: typing.Tuple[int] = (224, 448, 672, 896), layers_per_block: int = 2, mid_block_scale_factor: float = 1, downsample_padding: int = 1, act_fn: str = 'silu', attention_head_dim: int = 8, norm_num_groups: int = 32, norm_eps: float = 1e-05)
Parameters
- sample_size (int, optional) — Input sample size.
- in_channels (int, optional, defaults to 3) — Number of channels in the input image.
- out_channels (int, optional, defaults to 3) — Number of channels in the output.
- center_input_sample (bool, optional, defaults to False) — Whether to center the input sample.
- time_embedding_type (str, optional, defaults to "positional") — Type of time embedding to use.
- freq_shift (int, optional, defaults to 0) — Frequency shift for Fourier time embedding.
- flip_sin_to_cos (bool, optional, defaults to True) — Whether to flip sin to cos for Fourier time embedding.
- down_block_types (Tuple[str], optional, defaults to ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")) — Tuple of downsample block types.
- up_block_types (Tuple[str], optional, defaults to ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")) — Tuple of upsample block types.
- block_out_channels (Tuple[int], optional, defaults to (224, 448, 672, 896)) — Tuple of block output channels.
- layers_per_block (int, optional, defaults to 2) — The number of layers per block.
- mid_block_scale_factor (float, optional, defaults to 1) — The scale factor for the mid block.
- downsample_padding (int, optional, defaults to 1) — The padding for the downsample convolution.
- act_fn (str, optional, defaults to "silu") — The activation function to use.
- attention_head_dim (int, optional, defaults to 8) — The attention head dimension.
- norm_num_groups (int, optional, defaults to 32) — The number of groups for the normalization.
- norm_eps (float, optional, defaults to 1e-5) — The epsilon for the normalization.
UNet2DModel is a 2D UNet model that takes in a noisy sample and a timestep and returns a sample-shaped output.
This model inherits from ModelMixin. Check the superclass documentation for the generic methods the library implements for all models (such as downloading or saving, etc.)
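To see how down_block_types and block_out_channels relate, the sketch below walks through the encoder side of the default configuration for a 256×256 input (the input size is an illustrative choice). It assumes, as a simplification, that each down block halves the spatial resolution except the last one, and that the i-th entry of block_out_channels is the channel count of the i-th down block; the helper `plan_down_path` is hypothetical and is not part of the library.

```python
from typing import List, Sequence, Tuple


def plan_down_path(sample_size: int, down_block_types: Sequence[str],
                   block_out_channels: Sequence[int]) -> List[Tuple[str, int, int]]:
    """Return (block type, output channels, spatial size after the block) for each down block."""
    if len(down_block_types) != len(block_out_channels):
        raise ValueError("down_block_types and block_out_channels must have the same length")
    plan = []
    size = sample_size
    for i, (block_type, channels) in enumerate(zip(down_block_types, block_out_channels)):
        is_final_block = i == len(block_out_channels) - 1
        if not is_final_block:
            size //= 2  # assumed stride-2 downsample in every block except the last
        plan.append((block_type, channels, size))
    return plan


default_plan = plan_down_path(
    256,
    ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
    (224, 448, 672, 896),
)
```

Under these assumptions the default configuration ends the down path at 32×32 with 896 channels.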
forward
(sample: FloatTensor, timestep: typing.Union[torch.Tensor, float, int], return_dict: bool = True) → UNet2DOutput or tuple
Parameters
- sample (torch.FloatTensor) — (batch, channel, height, width) noisy inputs tensor.
- timestep (torch.FloatTensor or float or int) — (batch) timesteps.
- return_dict (bool, optional, defaults to True) — Whether or not to return a UNet2DOutput instead of a plain tuple.
Returns
UNet2DOutput or tuple
UNet2DOutput if return_dict is True, otherwise a tuple. When returning a tuple, the first element is the sample tensor.
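The timestep and return_dict conventions above can be illustrated without any deep-learning framework. In this hypothetical sketch, `SampleOutput` stands in for UNet2DOutput and `toy_forward` stands in for forward; the batch is a list of scalars and the computation is a placeholder. A scalar timestep is broadcast to one value per batch element (an assumption made here for illustration), and the result is either wrapped in the output class or returned as a tuple whose first element is the sample.

```python
from dataclasses import dataclass
from typing import List, Sequence, Tuple, Union


@dataclass
class SampleOutput:
    """Hypothetical stand-in for UNet2DOutput: a single `sample` field."""
    sample: List[float]


def toy_forward(sample: List[float], timestep: Union[int, float, Sequence[float]],
                return_dict: bool = True) -> Union[SampleOutput, Tuple[List[float]]]:
    batch = len(sample)
    if isinstance(timestep, (int, float)):
        timesteps = [float(timestep)] * batch  # broadcast a scalar to the whole batch
    else:
        timesteps = [float(t) for t in timestep]
        if len(timesteps) != batch:
            raise ValueError("expected one timestep per batch element")
    out = [x * t for x, t in zip(sample, timesteps)]  # placeholder computation
    if not return_dict:
        return (out,)  # plain tuple: the sample is the first element
    return SampleOutput(sample=out)


as_output = toy_forward([1.0, 2.0], 3)
as_tuple = toy_forward([1.0, 2.0], [1, 2], return_dict=False)
```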
UNet2DConditionOutput
class diffusers.models.unet_2d_condition.UNet2DConditionOutput(sample: FloatTensor)
UNet2DConditionModel
class diffusers.UNet2DConditionModel(sample_size: typing.Optional[int] = None, in_channels: int = 4, out_channels: int = 4, center_input_sample: bool = False, flip_sin_to_cos: bool = True, freq_shift: int = 0, down_block_types: typing.Tuple[str] = ('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D'), up_block_types: typing.Tuple[str] = ('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D'), block_out_channels: typing.Tuple[int] = (320, 640, 1280, 1280), layers_per_block: int = 2, downsample_padding: int = 1, mid_block_scale_factor: float = 1, act_fn: str = 'silu', norm_num_groups: int = 32, norm_eps: float = 1e-05, cross_attention_dim: int = 1280, attention_head_dim: int = 8)
Parameters
- sample_size (int, optional) — The size of the input sample.
- in_channels (int, optional, defaults to 4) — The number of channels in the input sample.
- out_channels (int, optional, defaults to 4) — The number of channels in the output.
- center_input_sample (bool, optional, defaults to False) — Whether to center the input sample.
- flip_sin_to_cos (bool, optional, defaults to True) — Whether to flip the sin to cos in the time embedding.
- freq_shift (int, optional, defaults to 0) — The frequency shift to apply to the time embedding.
- down_block_types (Tuple[str], optional, defaults to ("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")) — The tuple of downsample blocks to use.
- up_block_types (Tuple[str], optional, defaults to ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")) — The tuple of upsample blocks to use.
- block_out_channels (Tuple[int], optional, defaults to (320, 640, 1280, 1280)) — The tuple of output channels for each block.
- layers_per_block (int, optional, defaults to 2) — The number of layers per block.
- downsample_padding (int, optional, defaults to 1) — The padding to use for the downsampling convolution.
- mid_block_scale_factor (float, optional, defaults to 1.0) — The scale factor to use for the mid block.
- act_fn (str, optional, defaults to "silu") — The activation function to use.
- norm_num_groups (int, optional, defaults to 32) — The number of groups to use for the normalization.
- norm_eps (float, optional, defaults to 1e-5) — The epsilon to use for the normalization.
- cross_attention_dim (int, optional, defaults to 1280) — The dimension of the cross attention features.
- attention_head_dim (int, optional, defaults to 8) — The dimension of the attention heads.
UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep and returns a sample-shaped output.
This model inherits from ModelMixin. Check the superclass documentation for the generic methods the library implements for all the models (such as downloading or saving, etc.)
forward
(sample: FloatTensor, timestep: typing.Union[torch.Tensor, float, int], encoder_hidden_states: Tensor, return_dict: bool = True) → UNet2DConditionOutput or tuple
Parameters
- sample (torch.FloatTensor) — (batch, channel, height, width) noisy inputs tensor.
- timestep (torch.FloatTensor or float or int) — (batch) timesteps.
- encoder_hidden_states (torch.FloatTensor) — (batch, sequence_length, feature_dim) encoder hidden states.
- return_dict (bool, optional, defaults to True) — Whether or not to return a models.unet_2d_condition.UNet2DConditionOutput instead of a plain tuple.
Returns
UNet2DConditionOutput or tuple
UNet2DConditionOutput if return_dict is True, otherwise a tuple. When returning a tuple, the first element is the sample tensor.
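A common source of errors is a mismatch between the conditioning tensor and the model configuration. The hypothetical `check_condition_shapes` helper below encodes the shape expectations under stated assumptions: encoder_hidden_states is laid out as (batch, sequence_length, feature_dim), its last dimension must equal the model's cross_attention_dim, its batch size must match that of sample, and the channel dimension of sample must equal in_channels. Shapes are plain tuples so the sketch runs without PyTorch; the 77-token sequence length in the example is only illustrative.

```python
from typing import Tuple


def check_condition_shapes(sample_shape: Tuple[int, int, int, int],
                           encoder_shape: Tuple[int, int, int],
                           in_channels: int = 4,
                           cross_attention_dim: int = 1280) -> None:
    """Raise ValueError if the (assumed) shape contract is violated."""
    batch, channels, _height, _width = sample_shape
    enc_batch, _seq_len, feature_dim = encoder_shape
    if channels != in_channels:
        raise ValueError(f"sample has {channels} channels, model expects {in_channels}")
    if enc_batch != batch:
        raise ValueError(f"batch mismatch: sample {batch} vs encoder_hidden_states {enc_batch}")
    if feature_dim != cross_attention_dim:
        raise ValueError(f"feature_dim {feature_dim} != cross_attention_dim {cross_attention_dim}")


# Passes with the default in_channels=4 and cross_attention_dim=1280
check_condition_shapes((2, 4, 64, 64), (2, 77, 1280))
```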
DecoderOutput
class diffusers.models.vae.DecoderOutput(sample: FloatTensor)
Output of decoding method.
VQEncoderOutput
class diffusers.models.vae.VQEncoderOutput(latents: FloatTensor)
Output of VQModel encoding method.
VQModel
class diffusers.VQModel(in_channels: int = 3, out_channels: int = 3, down_block_types: typing.Tuple[str] = ('DownEncoderBlock2D',), up_block_types: typing.Tuple[str] = ('UpDecoderBlock2D',), block_out_channels: typing.Tuple[int] = (64,), layers_per_block: int = 1, act_fn: str = 'silu', latent_channels: int = 3, sample_size: int = 32, num_vq_embeddings: int = 256, norm_num_groups: int = 32)
Parameters
- in_channels (int, optional, defaults to 3) — Number of channels in the input image.
- out_channels (int, optional, defaults to 3) — Number of channels in the output.
- down_block_types (Tuple[str], optional, defaults to ("DownEncoderBlock2D",)) — Tuple of downsample block types.
- up_block_types (Tuple[str], optional, defaults to ("UpDecoderBlock2D",)) — Tuple of upsample block types.
- block_out_channels (Tuple[int], optional, defaults to (64,)) — Tuple of block output channels.
- act_fn (str, optional, defaults to "silu") — The activation function to use.
- latent_channels (int, optional, defaults to 3) — Number of channels in the latent space.
- sample_size (int, optional, defaults to 32) — TODO
- num_vq_embeddings (int, optional, defaults to 256) — Number of codebook vectors in the VQ-VAE.
VQ-VAE model from the paper Neural Discrete Representation Learning by Aaron van den Oord, Oriol Vinyals and Koray Kavukcuoglu.
This model inherits from ModelMixin. Check the superclass documentation for the generic methods the library implements for all models (such as downloading or saving, etc.)
forward
(sample: FloatTensor, return_dict: bool = True)
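The "VQ" in VQ-VAE refers to vector quantization: each encoder output vector is replaced by its nearest entry in a learned codebook of num_vq_embeddings vectors. The pure-Python sketch below shows only that lookup step, using squared Euclidean distance and a tiny, made-up codebook; it omits everything else the model does (the encoder and decoder convolutions, the commitment loss and the straight-through gradient).

```python
from typing import List, Sequence, Tuple

Vector = Sequence[float]


def quantize(vectors: Sequence[Vector], codebook: Sequence[Vector]) -> Tuple[List[int], List[Vector]]:
    """Map each vector to the index and value of its nearest codebook entry."""
    indices: List[int] = []
    quantized: List[Vector] = []
    for v in vectors:
        # Squared Euclidean distance to every codebook vector
        distances = [sum((a - b) ** 2 for a, b in zip(v, code)) for code in codebook]
        best = min(range(len(codebook)), key=distances.__getitem__)
        indices.append(best)
        quantized.append(codebook[best])
    return indices, quantized


codebook = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]  # 3 illustrative codes of dimension 2
idx, q = quantize([(0.9, 0.1), (0.2, 0.7), (-0.1, 0.05)], codebook)
```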
AutoencoderKLOutput
class diffusers.models.vae.AutoencoderKLOutput(latent_dist: DiagonalGaussianDistribution)
Output of AutoencoderKL encoding method.
AutoencoderKL
class diffusers.AutoencoderKL(in_channels: int = 3, out_channels: int = 3, down_block_types: typing.Tuple[str] = ('DownEncoderBlock2D',), up_block_types: typing.Tuple[str] = ('UpDecoderBlock2D',), block_out_channels: typing.Tuple[int] = (64,), layers_per_block: int = 1, act_fn: str = 'silu', latent_channels: int = 4, norm_num_groups: int = 32, sample_size: int = 32)
Parameters
- in_channels (int, optional, defaults to 3) — Number of channels in the input image.
- out_channels (int, optional, defaults to 3) — Number of channels in the output.
- down_block_types (Tuple[str], optional, defaults to ("DownEncoderBlock2D",)) — Tuple of downsample block types.
- up_block_types (Tuple[str], optional, defaults to ("UpDecoderBlock2D",)) — Tuple of upsample block types.
- block_out_channels (Tuple[int], optional, defaults to (64,)) — Tuple of block output channels.
- act_fn (str, optional, defaults to "silu") — The activation function to use.
- latent_channels (int, optional, defaults to 4) — Number of channels in the latent space.
- sample_size (int, optional, defaults to 32) — TODO
Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma and Max Welling.
This model inherits from ModelMixin. Check the superclass documentation for the generic methods the library implements for all models (such as downloading or saving, etc.)
forward
(sample: FloatTensor, sample_posterior: bool = False, return_dict: bool = True, generator: typing.Optional[torch._C.Generator] = None)
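The encoder of AutoencoderKL predicts a diagonal Gaussian over the latents (the latent_dist field of AutoencoderKLOutput), and the sample_posterior flag of forward selects between drawing a random latent from it and using its mode, which for a Gaussian is the mean. The sketch below illustrates those two options and the KL term against a standard normal with plain Python floats. The `DiagonalGaussian` class and `encode_latent` function are hypothetical simplifications, not the library's DiagonalGaussianDistribution, and `random.Random` plays the role of the generator argument.

```python
import math
import random
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class DiagonalGaussian:
    mean: List[float]
    logvar: List[float]

    def mode(self) -> List[float]:
        return list(self.mean)  # the most likely latent is the mean

    def sample(self, generator: Optional[random.Random] = None) -> List[float]:
        rng = generator or random.Random()
        # Reparameterization: z = mean + std * eps with eps ~ N(0, 1)
        return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(self.mean, self.logvar)]

    def kl(self) -> float:
        # KL(N(mean, var) || N(0, 1)), summed over dimensions
        return 0.5 * sum(m * m + math.exp(lv) - 1.0 - lv for m, lv in zip(self.mean, self.logvar))


def encode_latent(dist: DiagonalGaussian, sample_posterior: bool = False,
                  generator: Optional[random.Random] = None) -> List[float]:
    return dist.sample(generator) if sample_posterior else dist.mode()


posterior = DiagonalGaussian(mean=[0.5, -1.0], logvar=[0.0, 0.0])
deterministic = encode_latent(posterior)  # sample_posterior=False -> the mean
```

Passing a seeded generator makes the sampled latent reproducible, which mirrors the purpose of the generator argument.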
FlaxModelMixin
Base class for all Flax models.
FlaxModelMixin takes care of storing the configuration of the models and handles methods for loading, downloading and saving models.
from_pretrained
(pretrained_model_name_or_path: typing.Union[str, os.PathLike], dtype: dtype = jax.numpy.float32, *model_args, **kwargs)
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co. Valid model ids are namespaced under a user or organization name, like CompVis/stable-diffusion-v1-4.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
- dtype (jax.numpy.dtype, optional, defaults to jax.numpy.float32) — The data type of the computation. Can be one of jax.numpy.float32, jax.numpy.float16 (on GPUs) and jax.numpy.bfloat16 (on TPUs). This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If specified, all the computation will be performed with the given dtype. Note that this only specifies the dtype of the computation and does not influence the dtype of the model parameters. If you wish to change the dtype of the model parameters, see FlaxModelMixin.to_fp16() and FlaxModelMixin.to_bf16().
- model_args (sequence of positional arguments, optional) — All remaining positional arguments will be passed to the underlying model's __init__ method.
- cache_dir (Union[str, os.PathLike], optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download (bool, optional, defaults to False) — Whether or not to resume an interrupted download. If False, incompletely received files are deleted; if True, the download is resumed from such a file if it exists.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id; since models and other artifacts on huggingface.co are stored in a git-based system, revision can be any identifier allowed by git.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file.
- kwargs (remaining dictionary of keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model's __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_config()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's __init__ function.
Instantiate a pretrained Flax model from a pretrained model configuration.
The warning “Weights from XXX not initialized from pretrained model” means that the weights of XXX do not come pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning task.
The warning “Weights from XXX not used in YYY” means that the layer XXX is not used by YYY, therefore those weights are discarded.
Examples:
>>> from diffusers import FlaxUNet2DConditionModel
>>> # Download model and configuration from huggingface.co and cache.
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4")
>>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("./test/saved_model/")
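The routing of kwargs when no configuration object is passed can be sketched as a simple split. The `split_kwargs` helper and the example keys (including the made-up `extra_flag`) are hypothetical; they only mirror the rule described for the kwargs parameter: keys that match a configuration attribute override it, and everything else is forwarded to the model's __init__.

```python
from typing import Any, Dict, Tuple


def split_kwargs(config: Dict[str, Any], kwargs: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Return (updated config, leftover kwargs destined for __init__)."""
    updated = dict(config)  # leave the loaded configuration untouched
    leftover: Dict[str, Any] = {}
    for key, value in kwargs.items():
        if key in config:
            updated[key] = value  # override a configuration attribute
        else:
            leftover[key] = value  # forwarded to the model's __init__
    return updated, leftover


loaded_config = {"in_channels": 4, "out_channels": 4, "cross_attention_dim": 1280}
config, init_kwargs = split_kwargs(loaded_config, {"cross_attention_dim": 768, "extra_flag": True})
```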
save_pretrained
(save_directory: typing.Union[str, os.PathLike], params: typing.Union[typing.Dict, flax.core.frozen_dict.FrozenDict], is_main_process: bool = True)
Parameters
- save_directory (str or os.PathLike) — Directory to which to save. Will be created if it doesn't exist.
- params (Union[Dict, FrozenDict]) — A PyTree of model parameters.
- is_main_process (bool, optional, defaults to True) — Whether the process calling this is the main process or not. Useful during distributed training (e.g., on TPUs) when this function has to be called on all processes. In that case, set is_main_process=True only on the main process to avoid race conditions.
Save a model and its configuration file to a directory, so that it can be reloaded using the from_pretrained() class method.
to_bf16
(params: typing.Union[typing.Dict, flax.core.frozen_dict.FrozenDict], mask: typing.Any = None)
Cast the floating-point params to jax.numpy.bfloat16. This returns a new params tree and does not cast the params in place.
This method can be used on TPU to explicitly convert the model parameters to bfloat16 precision to do full half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed.
Examples:
>>> from diffusers import FlaxUNet2DConditionModel
>>> # load model
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4")
>>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision
>>> params = model.to_bf16(params)
>>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
>>> # then pass the mask as follows
>>> from flax import traverse_util
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4")
>>> flat_params = traverse_util.flatten_dict(params)
>>> mask = {
...     path: (path[-2:] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
... for path in flat_params
... }
>>> mask = traverse_util.unflatten_dict(mask)
>>> params = model.to_bf16(params, mask)
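To make the precision trade-off concrete, the sketch below emulates what a float32 to bfloat16 cast does to a single Python float, using only the standard library. bfloat16 keeps float32's 8-bit exponent (so its range is essentially unchanged) but only 7 explicit mantissa bits. The `to_bfloat16` helper is a hypothetical illustration using round-half-to-even on the float32 bit pattern, with NaN and infinity handling omitted; it is not the JAX implementation, which you would normally reach through to_bf16.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Round a float to the nearest bfloat16 value (round-half-to-even), returned as a Python float."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]  # float32 bit pattern
    lsb = (bits >> 16) & 1  # lowest bit that bfloat16 keeps
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000  # round, then drop the low 16 mantissa bits
    return struct.unpack(">f", struct.pack(">I", bits))[0]


print(to_bfloat16(3.14159265))  # 3.140625
```

A value such as 1e38, which would overflow float16, stays finite in bfloat16, at the cost of roughly two to three significant decimal digits.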
to_fp16
(params: typing.Union[typing.Dict, flax.core.frozen_dict.FrozenDict], mask: typing.Any = None)
Cast the floating-point params to jax.numpy.float16. This returns a new params tree and does not cast the params in place.
This method can be used on GPU to explicitly convert the model parameters to float16 precision to do full half-precision training or to save weights in float16 for inference in order to save memory and improve speed.
Examples:
>>> from diffusers import FlaxUNet2DConditionModel
>>> # load model
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4")
>>> # By default, the model params will be in fp32, to cast these to float16
>>> params = model.to_fp16(params)
>>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
>>> # then pass the mask as follows
>>> from flax import traverse_util
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4")
>>> flat_params = traverse_util.flatten_dict(params)
>>> mask = {
...     path: (path[-2:] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
... for path in flat_params
... }
>>> mask = traverse_util.unflatten_dict(mask)
>>> params = model.to_fp16(params, mask)
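The mask argument can be understood with a framework-free sketch: flatten the nested parameter dict into tuple paths, decide per path whether to cast, and rebuild the tree. The helpers below (`flatten`, `unflatten`, `cast_tree`, `should_cast`) are hypothetical pure-Python analogues of flax.traverse_util.flatten_dict/unflatten_dict and of the masked cast; float16 rounding is emulated with the struct module's half-precision "e" format, and the parameter names are made up.

```python
import struct
from typing import Any, Callable, Dict, Tuple

Path = Tuple[str, ...]


def flatten(tree: Dict[str, Any], prefix: Path = ()) -> Dict[Path, Any]:
    flat: Dict[Path, Any] = {}
    for key, value in tree.items():
        if isinstance(value, dict):
            flat.update(flatten(value, prefix + (key,)))
        else:
            flat[prefix + (key,)] = value
    return flat


def unflatten(flat: Dict[Path, Any]) -> Dict[str, Any]:
    tree: Dict[str, Any] = {}
    for path, value in flat.items():
        node = tree
        for key in path[:-1]:
            node = node.setdefault(key, {})
        node[path[-1]] = value
    return tree


def to_float16(x: float) -> float:
    return struct.unpack("e", struct.pack("e", x))[0]  # round to IEEE half precision


def should_cast(path: Path) -> bool:
    # Same rule as the mask in the example above: do not cast LayerNorm bias and scale
    return path[-2:] not in (("LayerNorm", "bias"), ("LayerNorm", "scale"))


def cast_tree(params: Dict[str, Any], cast: Callable[[float], float],
              mask: Callable[[Path], bool]) -> Dict[str, Any]:
    """Return a new tree; leaves whose path is masked True are cast, the rest are kept."""
    return unflatten({p: cast(v) if mask(p) else v for p, v in flatten(params).items()})


params = {"block": {"Dense": {"kernel": 0.1}, "LayerNorm": {"scale": 0.1, "bias": 0.1}}}
half = cast_tree(params, to_float16, should_cast)
```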
to_fp32
(params: typing.Union[typing.Dict, flax.core.frozen_dict.FrozenDict], mask: typing.Any = None)
Cast the floating-point params to jax.numpy.float32. This method can be used to explicitly convert the model parameters to fp32 precision. This returns a new params tree and does not cast the params in place.
Examples:
>>> from diffusers import FlaxUNet2DConditionModel
>>> # Download model and configuration from huggingface.co
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4")
>>> # By default, the model params will be in fp32, to illustrate the use of this method,
>>> # we'll first cast to fp16 and back to fp32
>>> params = model.to_fp16(params)
>>> # now cast back to fp32
>>> params = model.to_fp32(params)
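One consequence worth keeping in mind for to_fp32: casting back does not restore precision lost in an earlier fp16 cast, because the discarded mantissa bits are gone. The round trip can be emulated for a single value with the standard library; `round_trip_fp16_fp32` is a hypothetical helper built on struct's "e" (half) and "f" (single) precision formats.

```python
import struct


def round_trip_fp16_fp32(x: float) -> float:
    """Cast a value to float16 and then to float32, returning the result as a Python float."""
    half = struct.unpack("e", struct.pack("e", x))[0]
    return struct.unpack("f", struct.pack("f", half))[0]


original = struct.unpack("f", struct.pack("f", 0.1))[0]  # 0.1 as stored in float32
restored = round_trip_fp16_fp32(0.1)
print(original, restored)  # 0.10000000149011612 0.0999755859375
```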
FlaxUNet2DConditionOutput
class diffusers.models.unet_2d_condition_flax.FlaxUNet2DConditionOutput(sample: ndarray)
replace
Returns a new object replacing the specified fields with new values.
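This replace method follows the immutable-update pattern of Flax struct dataclasses: the original output object is left untouched and a modified copy is returned. The standard library's frozen dataclasses behave the same way, so the hypothetical `ToyOutput` class below (not a Diffusers class) illustrates the semantics without JAX.

```python
from dataclasses import FrozenInstanceError, dataclass, replace
from typing import Tuple


@dataclass(frozen=True)
class ToyOutput:
    sample: Tuple[float, ...]


out = ToyOutput(sample=(1.0, 2.0))
updated = replace(out, sample=(0.5, 0.5))  # new object with the field swapped

try:
    out.sample = (9.0,)  # in-place mutation is rejected
    mutated = True
except FrozenInstanceError:
    mutated = False
```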
FlaxUNet2DConditionModel
class diffusers.FlaxUNet2DConditionModel(sample_size: int = 32, in_channels: int = 4, out_channels: int = 4, down_block_types: typing.Tuple[str] = ('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D'), up_block_types: typing.Tuple[str] = ('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D'), block_out_channels: typing.Tuple[int] = (320, 640, 1280, 1280), layers_per_block: int = 2, attention_head_dim: int = 8, cross_attention_dim: int = 1280, dropout: float = 0.0, dtype: dtype = jax.numpy.float32, freq_shift: int = 0, parent: typing.Union[typing.Type[flax.linen.module.Module], typing.Type[flax.core.scope.Scope], typing.Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object>, name: str = None)
Parameters
- sample_size (int, optional, defaults to 32) — The size of the input sample.
- in_channels (int, optional, defaults to 4) — The number of channels in the input sample.
- out_channels (int, optional, defaults to 4) — The number of channels in the output.
- down_block_types (Tuple[str], optional, defaults to ("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")) — The tuple of downsample blocks to use. The corresponding class names will be: “FlaxCrossAttnDownBlock2D”, “FlaxCrossAttnDownBlock2D”, “FlaxCrossAttnDownBlock2D”, “FlaxDownBlock2D”.
- up_block_types (Tuple[str], optional, defaults to ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")) — The tuple of upsample blocks to use. The corresponding class names will be: “FlaxUpBlock2D”, “FlaxCrossAttnUpBlock2D”, “FlaxCrossAttnUpBlock2D”, “FlaxCrossAttnUpBlock2D”.
- block_out_channels (Tuple[int], optional, defaults to (320, 640, 1280, 1280)) — The tuple of output channels for each block.
- layers_per_block (int, optional, defaults to 2) — The number of layers per block.
- attention_head_dim (int, optional, defaults to 8) — The dimension of the attention heads.
- cross_attention_dim (int, optional, defaults to 1280) — The dimension of the cross attention features.
- dropout (float, optional, defaults to 0) — Dropout probability for down, up and bottleneck blocks.
FlaxUNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep and returns a sample-shaped output.
This model inherits from FlaxModelMixin. Check the superclass documentation for the generic methods the library implements for all the models (such as downloading or saving, etc.)
Also, this model is a Flax Linen flax.linen.Module subclass. Use it as a regular Flax Linen Module and refer to the Flax documentation for all matters related to general usage and behavior.
Finally, this model supports inherent JAX features such as:
- Just-In-Time (JIT) compilation
- Automatic Differentiation
- Vectorization
- Parallelization
FlaxDecoderOutput
class diffusers.models.vae_flax.FlaxDecoderOutput(sample: ndarray)
Output of decoding method.
replace
Returns a new object replacing the specified fields with new values.
FlaxAutoencoderKLOutput
class diffusers.models.vae_flax.FlaxAutoencoderKLOutput(latent_dist: FlaxDiagonalGaussianDistribution)
Output of AutoencoderKL encoding method.
replace
Returns a new object replacing the specified fields with new values.
FlaxAutoencoderKL
class diffusers.FlaxAutoencoderKL(in_channels: int = 3, out_channels: int = 3, down_block_types: typing.Tuple[str] = ('DownEncoderBlock2D',), up_block_types: typing.Tuple[str] = ('UpDecoderBlock2D',), block_out_channels: typing.Tuple[int] = (64,), layers_per_block: int = 1, act_fn: str = 'silu', latent_channels: int = 4, norm_num_groups: int = 32, sample_size: int = 32, dtype: dtype = jax.numpy.float32, parent: typing.Union[typing.Type[flax.linen.module.Module], typing.Type[flax.core.scope.Scope], typing.Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object>, name: str = None)
Parameters
- in_channels (int, optional, defaults to 3) — Input channels.
- out_channels (int, optional, defaults to 3) — Output channels.
- down_block_types (Tuple[str], optional, defaults to ("DownEncoderBlock2D",)) — DownEncoder block type.
- up_block_types (Tuple[str], optional, defaults to ("UpDecoderBlock2D",)) — UpDecoder block type.
- block_out_channels (Tuple[int], optional, defaults to (64,)) — Tuple containing the number of output channels for each block.
- layers_per_block (int, optional, defaults to 1) — Number of ResNet layers for each block.
- act_fn (str, optional, defaults to "silu") — Activation function.
- latent_channels (int, optional, defaults to 4) — Latent space channels.
- norm_num_groups (int, optional, defaults to 32) — Number of groups for group normalization.
- sample_size (int, optional, defaults to 32) — Sample input size.
- dtype (jnp.dtype, optional, defaults to jnp.float32) — Parameters dtype.
Flax implementation of a Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma and Max Welling.
This model is a Flax Linen flax.linen.Module subclass. Use it as a regular Flax Linen Module and refer to the Flax documentation for all matters related to general usage and behavior.
Finally, this model supports inherent JAX features such as:
- Just-In-Time (JIT) compilation
- Automatic Differentiation
- Vectorization
- Parallelization