#
# Copyright 2023 The HuggingFace Inc. team.
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import os
from collections import OrderedDict
from copy import copy
from typing import List, Optional, Union

import numpy as np
import onnx
import onnx_graphsurgeon as gs
import PIL.Image
import tensorrt as trt
import torch
from huggingface_hub import snapshot_download
from onnx import shape_inference
from polygraphy import cuda
from polygraphy.backend.common import bytes_from_path
from polygraphy.backend.onnx.loader import fold_constants
from polygraphy.backend.trt import (
    CreateConfig,
    Profile,
    engine_from_bytes,
    engine_from_network,
    network_from_onnx_path,
    save_engine,
)
from polygraphy.backend.trt import util as trt_util
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import (
    StableDiffusionImg2ImgPipeline,
    StableDiffusionPipelineOutput,
    StableDiffusionSafetyChecker,
)
from diffusers.schedulers import DDIMScheduler
from diffusers.utils import DIFFUSERS_CACHE, logging
""" | |
Installation instructions | |
python3 -m pip install --upgrade transformers diffusers>=0.16.0 | |
python3 -m pip install --upgrade tensorrt>=8.6.1 | |
python3 -m pip install --upgrade polygraphy>=0.47.0 onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com | |
python3 -m pip install onnxruntime | |
""" | |

TRT_LOGGER = trt.Logger(trt.Logger.ERROR)
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# Map of numpy dtype -> torch dtype
numpy_to_torch_dtype_dict = {
    np.uint8: torch.uint8,
    np.int8: torch.int8,
    np.int16: torch.int16,
    np.int32: torch.int32,
    np.int64: torch.int64,
    np.float16: torch.float16,
    np.float32: torch.float32,
    np.float64: torch.float64,
    np.complex64: torch.complex64,
    np.complex128: torch.complex128,
}
if np.version.full_version >= "1.24.0":
    numpy_to_torch_dtype_dict[np.bool_] = torch.bool
else:
    numpy_to_torch_dtype_dict[np.bool] = torch.bool

# Map of torch dtype -> numpy dtype
torch_to_numpy_dtype_dict = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()}


def device_view(t):
    return cuda.DeviceView(ptr=t.data_ptr(), shape=t.shape, dtype=torch_to_numpy_dtype_dict[t.dtype])
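
# For example (a sketch): `device_view(torch.randn(2, 4, 64, 64, device="cuda"))` wraps the
# tensor's existing CUDA allocation as a polygraphy `DeviceView`, so the TensorRT engines can
# read it as an input without an extra host/device copy.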


def preprocess_image(image):
    """
    image: PIL.Image.Image - input image, resized to a multiple of 32 and rescaled to [-1, 1].
    """
    w, h = image.size
    w, h = (x - x % 32 for x in (w, h))  # resize to integer multiple of 32
    image = image.resize((w, h))
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image).contiguous()
    return 2.0 * image - 1.0
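
# For example, a 513x769 PIL image is snapped down to 512x768, so preprocess_image returns a
# (1, 3, 768, 512) float32 tensor with values in [-1, 1].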


class Engine:
    def __init__(self, engine_path):
        self.engine_path = engine_path
        self.engine = None
        self.context = None
        self.buffers = OrderedDict()
        self.tensors = OrderedDict()

    def __del__(self):
        [buf.free() for buf in self.buffers.values() if isinstance(buf, cuda.DeviceArray)]
        del self.engine
        del self.context
        del self.buffers
        del self.tensors

    def build(
        self,
        onnx_path,
        fp16,
        input_profile=None,
        enable_preview=False,
        enable_all_tactics=False,
        timing_cache=None,
        workspace_size=0,
    ):
        logger.warning(f"Building TensorRT engine for {onnx_path}: {self.engine_path}")
        p = Profile()
        if input_profile:
            for name, dims in input_profile.items():
                assert len(dims) == 3
                p.add(name, min=dims[0], opt=dims[1], max=dims[2])

        config_kwargs = {}
        config_kwargs["preview_features"] = [trt.PreviewFeature.DISABLE_EXTERNAL_TACTIC_SOURCES_FOR_CORE_0805]
        if enable_preview:
            # Faster dynamic shapes made optional since it increases engine build time.
            config_kwargs["preview_features"].append(trt.PreviewFeature.FASTER_DYNAMIC_SHAPES_0805)
        if workspace_size > 0:
            config_kwargs["memory_pool_limits"] = {trt.MemoryPoolType.WORKSPACE: workspace_size}
        if not enable_all_tactics:
            config_kwargs["tactic_sources"] = []

        engine = engine_from_network(
            network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]),
            config=CreateConfig(fp16=fp16, profiles=[p], load_timing_cache=timing_cache, **config_kwargs),
            save_timing_cache=timing_cache,
        )
        save_engine(engine, path=self.engine_path)

    def load(self):
        logger.warning(f"Loading TensorRT engine: {self.engine_path}")
        self.engine = engine_from_bytes(bytes_from_path(self.engine_path))

    def activate(self):
        self.context = self.engine.create_execution_context()

    def allocate_buffers(self, shape_dict=None, device="cuda"):
        for idx in range(trt_util.get_bindings_per_profile(self.engine)):
            binding = self.engine[idx]
            if shape_dict and binding in shape_dict:
                shape = shape_dict[binding]
            else:
                shape = self.engine.get_binding_shape(binding)
            dtype = trt.nptype(self.engine.get_binding_dtype(binding))
            if self.engine.binding_is_input(binding):
                self.context.set_binding_shape(idx, shape)
            tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device)
            self.tensors[binding] = tensor
            self.buffers[binding] = cuda.DeviceView(ptr=tensor.data_ptr(), shape=shape, dtype=dtype)

    def infer(self, feed_dict, stream):
        start_binding, end_binding = trt_util.get_active_profile_bindings(self.context)
        # shallow copy of ordered dict
        device_buffers = copy(self.buffers)
        for name, buf in feed_dict.items():
            assert isinstance(buf, cuda.DeviceView)
            device_buffers[name] = buf
        bindings = [0] * start_binding + [buf.ptr for buf in device_buffers.values()]
        noerror = self.context.execute_async_v2(bindings=bindings, stream_handle=stream.ptr)
        if not noerror:
            raise ValueError("ERROR: inference failed.")

        return self.tensors
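
# A minimal sketch of the intended Engine call order, mirroring how build_engines() and the
# pipeline's __loadResources() below drive it (the file name and shapes are illustrative only):
#
#     engine = Engine("unet.plan")
#     engine.load()                      # deserialize the .plan file
#     engine.activate()                  # create the TensorRT execution context
#     engine.allocate_buffers(shape_dict={"sample": (2, 4, 64, 64)}, device="cuda")
#     outputs = engine.infer({"sample": device_view(sample)}, stream=cuda.Stream())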


class Optimizer:
    def __init__(self, onnx_graph):
        self.graph = gs.import_onnx(onnx_graph)

    def cleanup(self, return_onnx=False):
        self.graph.cleanup().toposort()
        if return_onnx:
            return gs.export_onnx(self.graph)

    def select_outputs(self, keep, names=None):
        self.graph.outputs = [self.graph.outputs[o] for o in keep]
        if names:
            for i, name in enumerate(names):
                self.graph.outputs[i].name = name

    def fold_constants(self, return_onnx=False):
        onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True)
        self.graph = gs.import_onnx(onnx_graph)
        if return_onnx:
            return onnx_graph

    def infer_shapes(self, return_onnx=False):
        onnx_graph = gs.export_onnx(self.graph)
        if onnx_graph.ByteSize() > 2147483648:
            raise TypeError("ERROR: model size exceeds supported 2GB limit")
        else:
            onnx_graph = shape_inference.infer_shapes(onnx_graph)

        self.graph = gs.import_onnx(onnx_graph)
        if return_onnx:
            return onnx_graph


class BaseModel:
    def __init__(self, model, fp16=False, device="cuda", max_batch_size=16, embedding_dim=768, text_maxlen=77):
        self.model = model
        self.name = "SD Model"
        self.fp16 = fp16
        self.device = device

        self.min_batch = 1
        self.max_batch = max_batch_size
        self.min_image_shape = 256  # min image resolution: 256x256
        self.max_image_shape = 1024  # max image resolution: 1024x1024
        self.min_latent_shape = self.min_image_shape // 8
        self.max_latent_shape = self.max_image_shape // 8

        self.embedding_dim = embedding_dim
        self.text_maxlen = text_maxlen

    def get_model(self):
        return self.model

    def get_input_names(self):
        pass

    def get_output_names(self):
        pass

    def get_dynamic_axes(self):
        return None

    def get_sample_input(self, batch_size, image_height, image_width):
        pass

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
        return None

    def get_shape_dict(self, batch_size, image_height, image_width):
        return None

    def optimize(self, onnx_graph):
        opt = Optimizer(onnx_graph)
        opt.cleanup()
        opt.fold_constants()
        opt.infer_shapes()
        onnx_opt_graph = opt.cleanup(return_onnx=True)
        return onnx_opt_graph

    def check_dims(self, batch_size, image_height, image_width):
        assert batch_size >= self.min_batch and batch_size <= self.max_batch
        assert image_height % 8 == 0 and image_width % 8 == 0
        latent_height = image_height // 8
        latent_width = image_width // 8
        assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape
        assert latent_width >= self.min_latent_shape and latent_width <= self.max_latent_shape
        return (latent_height, latent_width)

    def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_shape):
        min_batch = batch_size if static_batch else self.min_batch
        max_batch = batch_size if static_batch else self.max_batch
        latent_height = image_height // 8
        latent_width = image_width // 8
        min_image_height = image_height if static_shape else self.min_image_shape
        max_image_height = image_height if static_shape else self.max_image_shape
        min_image_width = image_width if static_shape else self.min_image_shape
        max_image_width = image_width if static_shape else self.max_image_shape
        min_latent_height = latent_height if static_shape else self.min_latent_shape
        max_latent_height = latent_height if static_shape else self.max_latent_shape
        min_latent_width = latent_width if static_shape else self.min_latent_shape
        max_latent_width = latent_width if static_shape else self.max_latent_shape
        return (
            min_batch,
            max_batch,
            min_image_height,
            max_image_height,
            min_image_width,
            max_image_width,
            min_latent_height,
            max_latent_height,
            min_latent_width,
            max_latent_width,
        )


def getOnnxPath(model_name, onnx_dir, opt=True):
    return os.path.join(onnx_dir, model_name + (".opt" if opt else "") + ".onnx")


def getEnginePath(model_name, engine_dir):
    return os.path.join(engine_dir, model_name + ".plan")


def build_engines(
    models: dict,
    engine_dir,
    onnx_dir,
    onnx_opset,
    opt_image_height,
    opt_image_width,
    opt_batch_size=1,
    force_engine_rebuild=False,
    static_batch=False,
    static_shape=True,
    enable_preview=False,
    enable_all_tactics=False,
    timing_cache=None,
    max_workspace_size=0,
):
    built_engines = {}
    if not os.path.isdir(onnx_dir):
        os.makedirs(onnx_dir)
    if not os.path.isdir(engine_dir):
        os.makedirs(engine_dir)

    # Export models to ONNX
    for model_name, model_obj in models.items():
        engine_path = getEnginePath(model_name, engine_dir)
        if force_engine_rebuild or not os.path.exists(engine_path):
            logger.warning("Building Engines...")
            logger.warning("Engine build can take a while to complete")
            onnx_path = getOnnxPath(model_name, onnx_dir, opt=False)
            onnx_opt_path = getOnnxPath(model_name, onnx_dir)
            if force_engine_rebuild or not os.path.exists(onnx_opt_path):
                if force_engine_rebuild or not os.path.exists(onnx_path):
                    logger.warning(f"Exporting model: {onnx_path}")
                    model = model_obj.get_model()
                    with torch.inference_mode(), torch.autocast("cuda"):
                        inputs = model_obj.get_sample_input(opt_batch_size, opt_image_height, opt_image_width)
                        torch.onnx.export(
                            model,
                            inputs,
                            onnx_path,
                            export_params=True,
                            opset_version=onnx_opset,
                            do_constant_folding=True,
                            input_names=model_obj.get_input_names(),
                            output_names=model_obj.get_output_names(),
                            dynamic_axes=model_obj.get_dynamic_axes(),
                        )
                    del model
                    torch.cuda.empty_cache()
                    gc.collect()
                else:
                    logger.warning(f"Found cached model: {onnx_path}")

                # Optimize onnx
                if force_engine_rebuild or not os.path.exists(onnx_opt_path):
                    logger.warning(f"Generating optimized model: {onnx_opt_path}")
                    onnx_opt_graph = model_obj.optimize(onnx.load(onnx_path))
                    onnx.save(onnx_opt_graph, onnx_opt_path)
                else:
                    logger.warning(f"Found cached optimized model: {onnx_opt_path}")

    # Build TensorRT engines
    for model_name, model_obj in models.items():
        engine_path = getEnginePath(model_name, engine_dir)
        engine = Engine(engine_path)
        onnx_path = getOnnxPath(model_name, onnx_dir, opt=False)
        onnx_opt_path = getOnnxPath(model_name, onnx_dir)

        if force_engine_rebuild or not os.path.exists(engine.engine_path):
            engine.build(
                onnx_opt_path,
                fp16=True,
                input_profile=model_obj.get_input_profile(
                    opt_batch_size,
                    opt_image_height,
                    opt_image_width,
                    static_batch=static_batch,
                    static_shape=static_shape,
                ),
                enable_preview=enable_preview,
                timing_cache=timing_cache,
                workspace_size=max_workspace_size,
            )
        built_engines[model_name] = engine

    # Load and activate TensorRT engines
    for model_name, model_obj in models.items():
        engine = built_engines[model_name]
        engine.load()
        engine.activate()

    return built_engines


def runEngine(engine, feed_dict, stream):
    return engine.infer(feed_dict, stream)


class CLIP(BaseModel):
    def __init__(self, model, device, max_batch_size, embedding_dim):
        super(CLIP, self).__init__(
            model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim
        )
        self.name = "CLIP"

    def get_input_names(self):
        return ["input_ids"]

    def get_output_names(self):
        return ["text_embeddings", "pooler_output"]

    def get_dynamic_axes(self):
        return {"input_ids": {0: "B"}, "text_embeddings": {0: "B"}}

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
        self.check_dims(batch_size, image_height, image_width)
        min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims(
            batch_size, image_height, image_width, static_batch, static_shape
        )
        return {
            "input_ids": [(min_batch, self.text_maxlen), (batch_size, self.text_maxlen), (max_batch, self.text_maxlen)]
        }

    def get_shape_dict(self, batch_size, image_height, image_width):
        self.check_dims(batch_size, image_height, image_width)
        return {
            "input_ids": (batch_size, self.text_maxlen),
            "text_embeddings": (batch_size, self.text_maxlen, self.embedding_dim),
        }

    def get_sample_input(self, batch_size, image_height, image_width):
        self.check_dims(batch_size, image_height, image_width)
        return torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device)

    def optimize(self, onnx_graph):
        opt = Optimizer(onnx_graph)
        opt.select_outputs([0])  # delete graph output#1
        opt.cleanup()
        opt.fold_constants()
        opt.infer_shapes()
        opt.select_outputs([0], names=["text_embeddings"])  # rename network output
        opt_onnx_graph = opt.cleanup(return_onnx=True)
        return opt_onnx_graph


def make_CLIP(model, device, max_batch_size, embedding_dim, inpaint=False):
    return CLIP(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim)


class UNet(BaseModel):
    def __init__(
        self, model, fp16=False, device="cuda", max_batch_size=16, embedding_dim=768, text_maxlen=77, unet_dim=4
    ):
        super(UNet, self).__init__(
            model=model,
            fp16=fp16,
            device=device,
            max_batch_size=max_batch_size,
            embedding_dim=embedding_dim,
            text_maxlen=text_maxlen,
        )
        self.unet_dim = unet_dim
        self.name = "UNet"

    def get_input_names(self):
        return ["sample", "timestep", "encoder_hidden_states"]

    def get_output_names(self):
        return ["latent"]

    def get_dynamic_axes(self):
        return {
            "sample": {0: "2B", 2: "H", 3: "W"},
            "encoder_hidden_states": {0: "2B"},
            "latent": {0: "2B", 2: "H", 3: "W"},
        }

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        (
            min_batch,
            max_batch,
            _,
            _,
            _,
            _,
            min_latent_height,
            max_latent_height,
            min_latent_width,
            max_latent_width,
        ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
        return {
            "sample": [
                (2 * min_batch, self.unet_dim, min_latent_height, min_latent_width),
                (2 * batch_size, self.unet_dim, latent_height, latent_width),
                (2 * max_batch, self.unet_dim, max_latent_height, max_latent_width),
            ],
            "encoder_hidden_states": [
                (2 * min_batch, self.text_maxlen, self.embedding_dim),
                (2 * batch_size, self.text_maxlen, self.embedding_dim),
                (2 * max_batch, self.text_maxlen, self.embedding_dim),
            ],
        }
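
    # For a concrete feel of the profile above (a sketch): with batch_size=1, a 512x512 image
    # (64x64 latents), static_shape=True and the default min_batch=1 / max_batch=16, the
    # "sample" entry becomes min=(2, 4, 64, 64), opt=(2, 4, 64, 64), max=(32, 4, 64, 64);
    # the batch axis is doubled because classifier-free guidance runs the conditional and
    # unconditional branches in a single pass.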

    def get_shape_dict(self, batch_size, image_height, image_width):
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        return {
            "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width),
            "encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim),
            "latent": (2 * batch_size, 4, latent_height, latent_width),
        }

    def get_sample_input(self, batch_size, image_height, image_width):
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        dtype = torch.float16 if self.fp16 else torch.float32
        return (
            torch.randn(
                2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device
            ),
            torch.tensor([1.0], dtype=torch.float32, device=self.device),
            torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device),
        )


def make_UNet(model, device, max_batch_size, embedding_dim, inpaint=False):
    return UNet(
        model,
        fp16=True,
        device=device,
        max_batch_size=max_batch_size,
        embedding_dim=embedding_dim,
        unet_dim=(9 if inpaint else 4),
    )


class VAE(BaseModel):
    def __init__(self, model, device, max_batch_size, embedding_dim):
        super(VAE, self).__init__(
            model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim
        )
        self.name = "VAE decoder"

    def get_input_names(self):
        return ["latent"]

    def get_output_names(self):
        return ["images"]

    def get_dynamic_axes(self):
        return {"latent": {0: "B", 2: "H", 3: "W"}, "images": {0: "B", 2: "8H", 3: "8W"}}

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        (
            min_batch,
            max_batch,
            _,
            _,
            _,
            _,
            min_latent_height,
            max_latent_height,
            min_latent_width,
            max_latent_width,
        ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
        return {
            "latent": [
                (min_batch, 4, min_latent_height, min_latent_width),
                (batch_size, 4, latent_height, latent_width),
                (max_batch, 4, max_latent_height, max_latent_width),
            ]
        }

    def get_shape_dict(self, batch_size, image_height, image_width):
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        return {
            "latent": (batch_size, 4, latent_height, latent_width),
            "images": (batch_size, 3, image_height, image_width),
        }

    def get_sample_input(self, batch_size, image_height, image_width):
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        return torch.randn(batch_size, 4, latent_height, latent_width, dtype=torch.float32, device=self.device)


def make_VAE(model, device, max_batch_size, embedding_dim, inpaint=False):
    return VAE(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim)


class TorchVAEEncoder(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.vae_encoder = model

    def forward(self, x):
        return self.vae_encoder.encode(x).latent_dist.sample()


class VAEEncoder(BaseModel):
    def __init__(self, model, device, max_batch_size, embedding_dim):
        super(VAEEncoder, self).__init__(
            model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim
        )
        self.name = "VAE encoder"

    def get_model(self):
        vae_encoder = TorchVAEEncoder(self.model)
        return vae_encoder

    def get_input_names(self):
        return ["images"]

    def get_output_names(self):
        return ["latent"]

    def get_dynamic_axes(self):
        return {"images": {0: "B", 2: "8H", 3: "8W"}, "latent": {0: "B", 2: "H", 3: "W"}}

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
        assert batch_size >= self.min_batch and batch_size <= self.max_batch
        min_batch = batch_size if static_batch else self.min_batch
        max_batch = batch_size if static_batch else self.max_batch
        self.check_dims(batch_size, image_height, image_width)
        (
            min_batch,
            max_batch,
            min_image_height,
            max_image_height,
            min_image_width,
            max_image_width,
            _,
            _,
            _,
            _,
        ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
        return {
            "images": [
                (min_batch, 3, min_image_height, min_image_width),
                (batch_size, 3, image_height, image_width),
                (max_batch, 3, max_image_height, max_image_width),
            ]
        }

    def get_shape_dict(self, batch_size, image_height, image_width):
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        return {
            "images": (batch_size, 3, image_height, image_width),
            "latent": (batch_size, 4, latent_height, latent_width),
        }

    def get_sample_input(self, batch_size, image_height, image_width):
        self.check_dims(batch_size, image_height, image_width)
        return torch.randn(batch_size, 3, image_height, image_width, dtype=torch.float32, device=self.device)


def make_VAEEncoder(model, device, max_batch_size, embedding_dim, inpaint=False):
    return VAEEncoder(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim)


class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
    r"""
    Pipeline for image-to-image generation using TensorRT accelerated Stable Diffusion.

    This model inherits from [`StableDiffusionImg2ImgPipeline`]. Check the superclass documentation for the generic
    methods the library implements for all the pipelines (such as downloading or saving, running on a particular
    device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPFeatureExtractor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: DDIMScheduler,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
        requires_safety_checker: bool = True,
        stages=["clip", "unet", "vae", "vae_encoder"],
        image_height: int = 512,
        image_width: int = 512,
        max_batch_size: int = 16,
        # ONNX export parameters
        onnx_opset: int = 17,
        onnx_dir: str = "onnx",
        # TensorRT engine build parameters
        engine_dir: str = "engine",
        build_preview_features: bool = True,
        force_engine_rebuild: bool = False,
        timing_cache: str = "timing_cache",
    ):
        super().__init__(
            vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
        )

        self.vae.forward = self.vae.decode

        self.stages = stages
        self.image_height, self.image_width = image_height, image_width
        self.inpaint = False
        self.onnx_opset = onnx_opset
        self.onnx_dir = onnx_dir
        self.engine_dir = engine_dir
        self.force_engine_rebuild = force_engine_rebuild
        self.timing_cache = timing_cache
        self.build_static_batch = False
        self.build_dynamic_shape = False
        self.build_preview_features = build_preview_features

        self.max_batch_size = max_batch_size
        # TODO: Restrict batch size to 4 for larger image dimensions as a WAR for TensorRT limitation.
        if self.build_dynamic_shape or self.image_height > 512 or self.image_width > 512:
            self.max_batch_size = 4

        self.stream = None  # loaded in loadResources()
        self.models = {}  # loaded in __loadModels()
        self.engine = {}  # loaded in build_engines()

    def __loadModels(self):
        # Load pipeline models
        self.embedding_dim = self.text_encoder.config.hidden_size
        models_args = {
            "device": self.torch_device,
            "max_batch_size": self.max_batch_size,
            "embedding_dim": self.embedding_dim,
            "inpaint": self.inpaint,
        }
        if "clip" in self.stages:
            self.models["clip"] = make_CLIP(self.text_encoder, **models_args)
        if "unet" in self.stages:
            self.models["unet"] = make_UNet(self.unet, **models_args)
        if "vae" in self.stages:
            self.models["vae"] = make_VAE(self.vae, **models_args)
        if "vae_encoder" in self.stages:
            self.models["vae_encoder"] = make_VAEEncoder(self.vae, **models_args)

    @classmethod
    def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)

        cls.cached_folder = (
            pretrained_model_name_or_path
            if os.path.isdir(pretrained_model_name_or_path)
            else snapshot_download(
                pretrained_model_name_or_path,
                cache_dir=cache_dir,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
            )
        )

    def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings: bool = False):
        super().to(torch_device, silence_dtype_warnings=silence_dtype_warnings)

        self.onnx_dir = os.path.join(self.cached_folder, self.onnx_dir)
        self.engine_dir = os.path.join(self.cached_folder, self.engine_dir)
        self.timing_cache = os.path.join(self.cached_folder, self.timing_cache)

        # set device
        self.torch_device = self._execution_device
        logger.warning(f"Running inference on device: {self.torch_device}")

        # load models
        self.__loadModels()

        # build engines
        self.engine = build_engines(
            self.models,
            self.engine_dir,
            self.onnx_dir,
            self.onnx_opset,
            opt_image_height=self.image_height,
            opt_image_width=self.image_width,
            force_engine_rebuild=self.force_engine_rebuild,
            static_batch=self.build_static_batch,
            static_shape=not self.build_dynamic_shape,
            enable_preview=self.build_preview_features,
            timing_cache=self.timing_cache,
        )

        return self

    def __initialize_timesteps(self, timesteps, strength):
        self.scheduler.set_timesteps(timesteps)
        offset = self.scheduler.steps_offset if hasattr(self.scheduler, "steps_offset") else 0
        init_timestep = int(timesteps * strength) + offset
        init_timestep = min(init_timestep, timesteps)
        t_start = max(timesteps - init_timestep + offset, 0)
        timesteps = self.scheduler.timesteps[t_start:].to(self.torch_device)
        return timesteps, t_start
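
    # Worked example (assuming a scheduler steps_offset of 0): with num_inference_steps=50 and
    # strength=0.8, init_timestep = int(50 * 0.8) = 40 and t_start = 50 - 40 = 10, so denoising
    # skips the first 10 scheduler steps and runs the remaining 40 on the noised image latents.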

    def __preprocess_images(self, batch_size, images=()):
        init_images = []
        for image in images:
            image = image.to(self.torch_device).float()
            image = image.repeat(batch_size, 1, 1, 1)
            init_images.append(image)
        return tuple(init_images)

    def __encode_image(self, init_image):
        init_latents = runEngine(self.engine["vae_encoder"], {"images": device_view(init_image)}, self.stream)[
            "latent"
        ]
        init_latents = 0.18215 * init_latents
        return init_latents

    def __encode_prompt(self, prompt, negative_prompt):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if
                `guidance_scale` is less than `1`).
        """
        # Tokenize prompt
        text_input_ids = (
            self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            .input_ids.type(torch.int32)
            .to(self.torch_device)
        )

        text_input_ids_inp = device_view(text_input_ids)
        # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt
        text_embeddings = runEngine(self.engine["clip"], {"input_ids": text_input_ids_inp}, self.stream)[
            "text_embeddings"
        ].clone()

        # Tokenize negative prompt
        uncond_input_ids = (
            self.tokenizer(
                negative_prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            .input_ids.type(torch.int32)
            .to(self.torch_device)
        )
        uncond_input_ids_inp = device_view(uncond_input_ids)
        uncond_embeddings = runEngine(self.engine["clip"], {"input_ids": uncond_input_ids_inp}, self.stream)[
            "text_embeddings"
        ]

        # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16)

        return text_embeddings

    def __denoise_latent(
        self, latents, text_embeddings, timesteps=None, step_offset=0, mask=None, masked_image_latents=None
    ):
        if not isinstance(timesteps, torch.Tensor):
            timesteps = self.scheduler.timesteps
        for step_index, timestep in enumerate(timesteps):
            # Expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep)
            if isinstance(mask, torch.Tensor):
                latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)

            # Predict the noise residual
            timestep_float = timestep.float() if timestep.dtype != torch.float32 else timestep

            sample_inp = device_view(latent_model_input)
            timestep_inp = device_view(timestep_float)
            embeddings_inp = device_view(text_embeddings)
            noise_pred = runEngine(
                self.engine["unet"],
                {"sample": sample_inp, "timestep": timestep_inp, "encoder_hidden_states": embeddings_inp},
                self.stream,
            )["latent"]

            # Perform guidance
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

            latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample

        latents = 1.0 / 0.18215 * latents
        return latents

    def __decode_latent(self, latents):
        images = runEngine(self.engine["vae"], {"latent": device_view(latents)}, self.stream)["images"]
        images = (images / 2 + 0.5).clamp(0, 1)
        return images.cpu().permute(0, 2, 3, 1).float().numpy()

    def __loadResources(self, image_height, image_width, batch_size):
        self.stream = cuda.Stream()

        # Allocate buffers for TensorRT engine bindings
        for model_name, obj in self.models.items():
            self.engine[model_name].allocate_buffers(
                shape_dict=obj.get_shape_dict(batch_size, image_height, image_width), device=self.torch_device
            )

    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: Union[torch.FloatTensor, PIL.Image.Image] = None,
        strength: float = 0.8,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    ):
r""" | |
Function invoked when calling the pipeline for generation. | |
Args: | |
prompt (`str` or `List[str]`, *optional*): | |
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. | |
instead. | |
image (`PIL.Image.Image`): | |
`Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will | |
be masked out with `mask_image` and repainted according to `prompt`. | |
strength (`float`, *optional*, defaults to 0.8): | |
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` | |
will be used as a starting point, adding more noise to it the larger the `strength`. The number of | |
denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will | |
be maximum and the denoising process will run for the full number of iterations specified in | |
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`. | |
num_inference_steps (`int`, *optional*, defaults to 50): | |
The number of denoising steps. More denoising steps usually lead to a higher quality image at the | |
expense of slower inference. | |
guidance_scale (`float`, *optional*, defaults to 7.5): | |
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). | |
`guidance_scale` is defined as `w` of equation 2. of [Imagen | |
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > | |
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, | |
usually at the expense of lower image quality. | |
negative_prompt (`str` or `List[str]`, *optional*): | |
The prompt or prompts not to guide the image generation. If not defined, one has to pass | |
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. | |
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). | |
generator (`torch.Generator` or `List[torch.Generator]`, *optional*): | |
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) | |
to make generation deterministic. | |
""" | |
        self.generator = generator
        self.denoising_steps = num_inference_steps
        self.guidance_scale = guidance_scale

        # Pre-compute latent input scales and linear multistep coefficients
        self.scheduler.set_timesteps(self.denoising_steps, device=self.torch_device)

        # Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
            prompt = [prompt]
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"Expected prompt to be of type list or str but got {type(prompt)}")

        if negative_prompt is None:
            negative_prompt = [""] * batch_size

        if negative_prompt is not None and isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt]

        assert len(prompt) == len(negative_prompt)

        if batch_size > self.max_batch_size:
            raise ValueError(
                f"Batch size {len(prompt)} is larger than allowed {self.max_batch_size}. If dynamic shape is used, then maximum batch size is 4"
            )

        # load resources
        self.__loadResources(self.image_height, self.image_width, batch_size)

        with torch.inference_mode(), torch.autocast("cuda"), trt.Runtime(TRT_LOGGER):
            # Initialize timesteps
            timesteps, t_start = self.__initialize_timesteps(self.denoising_steps, strength)
            latent_timestep = timesteps[:1].repeat(batch_size)

            # Pre-process input image
            if isinstance(image, PIL.Image.Image):
                image = preprocess_image(image)
            init_image = self.__preprocess_images(batch_size, (image,))[0]

            # VAE encode init image
            init_latents = self.__encode_image(init_image)

            # Add noise to latents using timesteps
            noise = torch.randn(
                init_latents.shape, generator=self.generator, device=self.torch_device, dtype=torch.float32
            )
            latents = self.scheduler.add_noise(init_latents, noise, latent_timestep)

            # CLIP text encoder
            text_embeddings = self.__encode_prompt(prompt, negative_prompt)

            # UNet denoiser
            latents = self.__denoise_latent(latents, text_embeddings, timesteps=timesteps, step_offset=t_start)

            # VAE decode latent
            images = self.__decode_latent(latents)

        images = self.numpy_to_pil(images)
        return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=None)