import inspect import re from pathlib import Path from typing import Callable, List, Optional, Tuple, Union import diffusers import numpy as np import PIL import torch from accelerate import init_empty_weights from diffusers import ( AutoencoderKL, DDIMScheduler, EulerDiscreteScheduler, LCMScheduler, LMSDiscreteScheduler, PNDMScheduler, StableDiffusionXLPipeline, ) from diffusers.configuration_utils import FrozenDict from diffusers.utils.deprecation_utils import deprecate from einops import rearrange from PIL import Image from PIL.PngImagePlugin import PngInfo from safetensors.torch import load_file from tqdm import tqdm from transformers import ( CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, ) import external.llite.library.model_util as model_util import external.llite.library.sdxl_model_util as sdxl_model_util import external.llite.library.sdxl_original_unet as sdxl_original_unet import external.llite.library.sdxl_train_util as sdxl_train_util import external.llite.library.train_util as train_util from external.llite.library.original_unet import FlashAttentionFunction from external.llite.library.sdxl_original_unet import InferSdxlUNet2DConditionModel from external.llite.networks.control_net_lllite import ControlNetLLLite from external.llite.networks.lora import LoRANetwork from internals.pipelines.commons import AbstractPipeline from internals.util.cache import clear_cuda_and_gc from internals.util.commons import download_file class PipelineLike: def __init__( self, device, vae: AutoencoderKL, text_encoders: List[CLIPTextModel], tokenizers: List[CLIPTokenizer], unet: InferSdxlUNet2DConditionModel, scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], clip_skip: int, ): super().__init__() self.device = device self.clip_skip = clip_skip if ( hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1 ): deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " "to update the config accordingly as leaving `steps_offset` might led to incorrect results" " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" " file" ) deprecate( "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False ) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) if ( hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True ): deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." " `clip_sample` should be set to False in the configuration file. Please make sure to update the" " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" ) deprecate( "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False ) new_config = dict(scheduler.config) new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) self.vae = vae self.text_encoders = text_encoders self.tokenizers = tokenizers self.unet: InferSdxlUNet2DConditionModel = unet self.scheduler = scheduler self.safety_checker = None self.clip_vision_model: CLIPVisionModelWithProjection = None self.clip_vision_processor: CLIPImageProcessor = None self.clip_vision_strength = 0.0 # Textual Inversion self.token_replacements_list = [] for _ in range(len(self.text_encoders)): self.token_replacements_list.append({}) # ControlNet # not supported yet self.control_nets: List[ControlNetLLLite] = [] self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない # Textual Inversion def add_token_replacement(self, text_encoder_index, target_token_id, rep_token_ids): self.token_replacements_list[text_encoder_index][ target_token_id ] = rep_token_ids def set_enable_control_net(self, en: bool): self.control_net_enabled = en def preprocess_image(self, image): w, h = image.size # resize to integer multiple of 32 w, h = map(lambda x: x - x % 32, (w, h)) image = image.resize((w, h), resample=PIL.Image.LANCZOS) image = np.array(image).astype(np.float32) / 255.0 image = image[None].transpose(0, 3, 1, 2) image = torch.from_numpy(image) return 2.0 * image - 1.0 def get_unweighted_text_embeddings( self, text_encoder: CLIPTextModel, text_input: torch.Tensor, chunk_length: int, clip_skip: int, eos: int, pad: int, no_boseos_middle: Optional[bool] = True, ): """ When the length of tokens is a multiple of the capacity of the text encoder, it should be split into chunks and sent to the text encoder individually. """ max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) if max_embeddings_multiples > 1: text_embeddings = [] pool = None for i in range(max_embeddings_multiples): # extract the i-th chunk text_input_chunk = text_input[ :, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2 ].clone() # cover the head and the tail by the starting and the ending tokens text_input_chunk[:, 0] = text_input[0, 0] if pad == eos: # v1 text_input_chunk[:, -1] = text_input[0, -1] else: # v2 for j in range(len(text_input_chunk)): # 最後に普通の文字がある if ( text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad ): text_input_chunk[j, -1] = eos if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD text_input_chunk[j, 1] = eos # -2 is same for Text Encoder 1 and 2 enc_out = text_encoder( text_input_chunk, output_hidden_states=True, return_dict=True ) text_embedding = enc_out["hidden_states"][-2] if pool is None: # use 1st chunk, if provided pool = enc_out.get("text_embeds", None) if pool is not None: pool = train_util.pool_workaround( text_encoder, enc_out["last_hidden_state"], text_input_chunk, eos, ) if no_boseos_middle: if i == 0: # discard the ending token text_embedding = text_embedding[:, :-1] elif i == max_embeddings_multiples - 1: # discard the starting token text_embedding = text_embedding[:, 1:] else: # discard both starting and ending tokens text_embedding = text_embedding[:, 1:-1] text_embeddings.append(text_embedding) text_embeddings = torch.concat(text_embeddings, axis=1) else: enc_out = text_encoder( text_input, output_hidden_states=True, return_dict=True ) text_embeddings = enc_out["hidden_states"][-2] # text encoder 1 doesn't return this pool = enc_out.get("text_embeds", None) if pool is not None: pool = train_util.pool_workaround( text_encoder, enc_out["last_hidden_state"], text_input, eos ) return text_embeddings, pool def preprocess_mask(self, mask): mask = mask.convert("L") w, h = mask.size # resize to integer multiple of 32 w, h = map(lambda x: x - x % 32, (w, h)) mask = mask.resize((w // 8, h // 8), resample=PIL.Image.BILINEAR) # LANCZOS) mask = np.array(mask).astype(np.float32) / 255.0 mask = np.tile(mask, (4, 1, 1)) mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? mask = 1 - mask # repaint white, keep black mask = torch.from_numpy(mask) return mask def get_prompts_with_weights( self, tokenizer: CLIPTokenizer, token_replacer, prompt: List[str], max_length: int, ): r""" Tokenize a list of prompts and return its tokens with weights of each token. No padding, starting or ending token is included. """ tokens = [] weights = [] truncated = False def parse_prompt_attention(text): """ Parses a string with attention tokens and returns a list of pairs: text and its associated weight. Accepted tokens are: (abc) - increases attention to abc by a multiplier of 1.1 (abc:3.12) - increases attention to abc by a multiplier of 3.12 [abc] - decreases attention to abc by a multiplier of 1.1 \( - literal character '(' \[ - literal character '[' \) - literal character ')' \] - literal character ']' \\ - literal character '\' anything else - just text >>> parse_prompt_attention('normal text') [['normal text', 1.0]] >>> parse_prompt_attention('an (important) word') [['an ', 1.0], ['important', 1.1], [' word', 1.0]] >>> parse_prompt_attention('(unbalanced') [['unbalanced', 1.1]] >>> parse_prompt_attention('\(literal\]') [['(literal]', 1.0]] >>> parse_prompt_attention('(unnecessary)(parens)') [['unnecessaryparens', 1.1]] >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') [['a ', 1.0], ['house', 1.5730000000000004], [' ', 1.1], ['on', 1.0], [' a ', 1.1], ['hill', 0.55], [', sun, ', 1.1], ['sky', 1.4641000000000006], ['.', 1.1]] """ res = [] round_brackets = [] square_brackets = [] round_bracket_multiplier = 1.1 square_bracket_multiplier = 1 / 1.1 def multiply_range(start_position, multiplier): for p in range(start_position, len(res)): res[p][1] *= multiplier # keep break as separate token text = text.replace("BREAK", "\\BREAK\\") re_attention = re.compile( r""" \\\(| \\\)| \\\[| \\]| \\\\| \\| \(| \[| :([+-]?[.\d]+)\)| \)| ]| [^\\()\[\]:]+| : """, re.X, ) for m in re_attention.finditer(text): text = m.group(0) weight = m.group(1) if text.startswith("\\"): res.append([text[1:], 1.0]) elif text == "(": round_brackets.append(len(res)) elif text == "[": square_brackets.append(len(res)) elif weight is not None and len(round_brackets) > 0: multiply_range(round_brackets.pop(), float(weight)) elif text == ")" and len(round_brackets) > 0: multiply_range(round_brackets.pop(), round_bracket_multiplier) elif text == "]" and len(square_brackets) > 0: multiply_range(square_brackets.pop(), square_bracket_multiplier) else: res.append([text, 1.0]) for pos in round_brackets: multiply_range(pos, round_bracket_multiplier) for pos in square_brackets: multiply_range(pos, square_bracket_multiplier) if len(res) == 0: res = [["", 1.0]] # merge runs of identical weights i = 0 while i + 1 < len(res): if ( res[i][1] == res[i + 1][1] and res[i][0].strip() != "BREAK" and res[i + 1][0].strip() != "BREAK" ): res[i][0] += res[i + 1][0] res.pop(i + 1) else: i += 1 return res for text in prompt: texts_and_weights = parse_prompt_attention(text) text_token = [] text_weight = [] for word, weight in texts_and_weights: if word.strip() == "BREAK": # pad until next multiple of tokenizer's max token length pad_len = tokenizer.model_max_length - ( len(text_token) % tokenizer.model_max_length ) print(f"BREAK pad_len: {pad_len}") for i in range(pad_len): # v2のときEOSをつけるべきかどうかわからないぜ # if i == 0: # text_token.append(tokenizer.eos_token_id) # else: text_token.append(tokenizer.pad_token_id) text_weight.append(1.0) continue # tokenize and discard the starting and the ending token token = tokenizer(word).input_ids[1:-1] token = token_replacer(token) # for Textual Inversion text_token += token # copy the weight by length of token text_weight += [weight] * len(token) # stop if the text is too long (longer than truncation limit) if len(text_token) > max_length: truncated = True break # truncate if len(text_token) > max_length: truncated = True text_token = text_token[:max_length] text_weight = text_weight[:max_length] tokens.append(text_token) weights.append(text_weight) if truncated: print( "warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples" ) return tokens, weights def pad_tokens_and_weights( self, tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77, ): r""" Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. """ max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) weights_length = ( max_length if no_boseos_middle else max_embeddings_multiples * chunk_length ) for i in range(len(tokens)): tokens[i] = ( [bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i])) ) if no_boseos_middle: weights[i] = ( [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) ) else: w = [] if len(weights[i]) == 0: w = [1.0] * weights_length else: for j in range(max_embeddings_multiples): # weight for starting token in this chunk w.append(1.0) w += weights[i][ j * (chunk_length - 2) : min( len(weights[i]), (j + 1) * (chunk_length - 2) ) ] w.append(1.0) # weight for ending token in this chunk w += [1.0] * (weights_length - len(w)) weights[i] = w[:] return tokens, weights def get_unweighted_text_embeddings( self, text_encoder: CLIPTextModel, text_input: torch.Tensor, chunk_length: int, clip_skip: int, eos: int, pad: int, no_boseos_middle: Optional[bool] = True, ): """ When the length of tokens is a multiple of the capacity of the text encoder, it should be split into chunks and sent to the text encoder individually. """ max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) if max_embeddings_multiples > 1: text_embeddings = [] pool = None for i in range(max_embeddings_multiples): # extract the i-th chunk text_input_chunk = text_input[ :, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2 ].clone() # cover the head and the tail by the starting and the ending tokens text_input_chunk[:, 0] = text_input[0, 0] if pad == eos: # v1 text_input_chunk[:, -1] = text_input[0, -1] else: # v2 for j in range(len(text_input_chunk)): # 最後に普通の文字がある if ( text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad ): text_input_chunk[j, -1] = eos if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD text_input_chunk[j, 1] = eos # -2 is same for Text Encoder 1 and 2 enc_out = text_encoder( text_input_chunk, output_hidden_states=True, return_dict=True ) text_embedding = enc_out["hidden_states"][-2] if pool is None: # use 1st chunk, if provided pool = enc_out.get("text_embeds", None) if pool is not None: pool = train_util.pool_workaround( text_encoder, enc_out["last_hidden_state"], text_input_chunk, eos, ) if no_boseos_middle: if i == 0: # discard the ending token text_embedding = text_embedding[:, :-1] elif i == max_embeddings_multiples - 1: # discard the starting token text_embedding = text_embedding[:, 1:] else: # discard both starting and ending tokens text_embedding = text_embedding[:, 1:-1] text_embeddings.append(text_embedding) text_embeddings = torch.concat(text_embeddings, axis=1) else: enc_out = text_encoder( text_input, output_hidden_states=True, return_dict=True ) text_embeddings = enc_out["hidden_states"][-2] # text encoder 1 doesn't return this pool = enc_out.get("text_embeds", None) if pool is not None: pool = train_util.pool_workaround( text_encoder, enc_out["last_hidden_state"], text_input, eos ) return text_embeddings, pool def get_weighted_text_embeddings( self, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, prompt: Union[str, List[str]], uncond_prompt: Optional[Union[str, List[str]]] = None, max_embeddings_multiples: Optional[int] = 1, no_boseos_middle: Optional[bool] = False, skip_parsing: Optional[bool] = False, skip_weighting: Optional[bool] = False, clip_skip=None, token_replacer=None, device=None, **kwargs, ): max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 if isinstance(prompt, str): prompt = [prompt] # split the prompts with "AND". each prompt must have the same number of splits new_prompts = [] for p in prompt: new_prompts.extend(p.split(" AND ")) prompt = new_prompts if not skip_parsing: prompt_tokens, prompt_weights = self.get_prompts_with_weights( tokenizer, token_replacer, prompt, max_length - 2 ) if uncond_prompt is not None: if isinstance(uncond_prompt, str): uncond_prompt = [uncond_prompt] uncond_tokens, uncond_weights = self.get_prompts_with_weights( tokenizer, token_replacer, uncond_prompt, max_length - 2 ) else: prompt_tokens = [ token[1:-1] for token in tokenizer( prompt, max_length=max_length, truncation=True ).input_ids ] prompt_weights = [[1.0] * len(token) for token in prompt_tokens] if uncond_prompt is not None: if isinstance(uncond_prompt, str): uncond_prompt = [uncond_prompt] uncond_tokens = [ token[1:-1] for token in tokenizer( uncond_prompt, max_length=max_length, truncation=True ).input_ids ] uncond_weights = [[1.0] * len(token) for token in uncond_tokens] # round up the longest length of tokens to a multiple of (model_max_length - 2) max_length = max([len(token) for token in prompt_tokens]) if uncond_prompt is not None: max_length = max(max_length, max([len(token) for token in uncond_tokens])) max_embeddings_multiples = min( max_embeddings_multiples, (max_length - 1) // (tokenizer.model_max_length - 2) + 1, ) max_embeddings_multiples = max(1, max_embeddings_multiples) max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 # pad the length of tokens and weights bos = tokenizer.bos_token_id eos = tokenizer.eos_token_id pad = tokenizer.pad_token_id prompt_tokens, prompt_weights = self.pad_tokens_and_weights( prompt_tokens, prompt_weights, max_length, bos, eos, pad, no_boseos_middle=no_boseos_middle, chunk_length=tokenizer.model_max_length, ) prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device) if uncond_prompt is not None: uncond_tokens, uncond_weights = self.pad_tokens_and_weights( uncond_tokens, uncond_weights, max_length, bos, eos, pad, no_boseos_middle=no_boseos_middle, chunk_length=tokenizer.model_max_length, ) uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device) # get the embeddings text_embeddings, text_pool = self.get_unweighted_text_embeddings( text_encoder, prompt_tokens, tokenizer.model_max_length, clip_skip, eos, pad, no_boseos_middle=no_boseos_middle, ) prompt_weights = torch.tensor( prompt_weights, dtype=text_embeddings.dtype, device=device ) if uncond_prompt is not None: uncond_embeddings, uncond_pool = self.get_unweighted_text_embeddings( text_encoder, uncond_tokens, tokenizer.model_max_length, clip_skip, eos, pad, no_boseos_middle=no_boseos_middle, ) uncond_weights = torch.tensor( uncond_weights, dtype=uncond_embeddings.dtype, device=device ) # assign weights to the prompts and normalize in the sense of mean # TODO: should we normalize by chunk or in a whole (current implementation)? # →全体でいいんじゃないかな if (not skip_parsing) and (not skip_weighting): previous_mean = ( text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) ) text_embeddings *= prompt_weights.unsqueeze(-1) current_mean = ( text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) ) text_embeddings *= ( (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) ) if uncond_prompt is not None: previous_mean = ( uncond_embeddings.float() .mean(axis=[-2, -1]) .to(uncond_embeddings.dtype) ) uncond_embeddings *= uncond_weights.unsqueeze(-1) current_mean = ( uncond_embeddings.float() .mean(axis=[-2, -1]) .to(uncond_embeddings.dtype) ) uncond_embeddings *= ( (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) ) if uncond_prompt is not None: return ( text_embeddings, text_pool, uncond_embeddings, uncond_pool, prompt_tokens, ) return text_embeddings, text_pool, None, None, prompt_tokens def get_token_replacer(self, tokenizer): tokenizer_index = self.tokenizers.index(tokenizer) token_replacements = self.token_replacements_list[tokenizer_index] def replace_tokens(tokens): # print("replace_tokens", tokens, "=>", token_replacements) if isinstance(tokens, torch.Tensor): tokens = tokens.tolist() new_tokens = [] for token in tokens: if token in token_replacements: replacement = token_replacements[token] new_tokens.extend(replacement) else: new_tokens.append(token) return new_tokens return replace_tokens def set_control_nets(self, ctrl_nets): self.control_nets = ctrl_nets @torch.no_grad() def __call__( self, prompt: Union[str, List[str]], negative_prompt: Optional[Union[str, List[str]]] = None, init_image: Union[ torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image] ] = None, mask_image: Union[ torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image] ] = None, height: int = 1024, width: int = 1024, original_height: int = None, original_width: int = None, original_height_negative: int = None, original_width_negative: int = None, crop_top: int = 0, crop_left: int = 0, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_scale: float = None, strength: float = 0.8, # num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[torch.Generator] = None, latents: Optional[torch.FloatTensor] = None, max_embeddings_multiples: Optional[int] = 3, output_type: Optional[str] = "pil", vae_batch_size: float = None, return_latents: bool = False, # return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, is_cancelled_callback: Optional[Callable[[], bool]] = None, callback_steps: Optional[int] = 1, img2img_noise=None, clip_guide_images=None, **kwargs, ): # TODO support secondary prompt num_images_per_prompt = 1 # fixed because already prompt is repeated if isinstance(prompt, str): batch_size = 1 prompt = [prompt] elif isinstance(prompt, list): batch_size = len(prompt) else: raise ValueError( f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" ) reginonal_network = " AND " in prompt[0] vae_batch_size = ( batch_size if vae_batch_size is None else ( int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size)) ) ) if strength < 0 or strength > 1: raise ValueError( f"The value of strength should in [0.0, 1.0] but is {strength}" ) if height % 8 != 0 or width % 8 != 0: raise ValueError( f"`height` and `width` have to be divisible by 8 but are {height} and {width}." ) if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) # get prompt text embeddings # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 if not do_classifier_free_guidance and negative_scale is not None: print(f"negative_scale is ignored if guidance scalle <= 1.0") negative_scale = None # get unconditional embeddings for classifier free guidance if negative_prompt is None: negative_prompt = [""] * batch_size elif isinstance(negative_prompt, str): negative_prompt = [negative_prompt] * batch_size if batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) tes_text_embs = [] tes_uncond_embs = [] tes_real_uncond_embs = [] for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders): token_replacer = self.get_token_replacer(tokenizer) # use last text_pool, because it is from text encoder 2 ( text_embeddings, text_pool, uncond_embeddings, uncond_pool, _, ) = self.get_weighted_text_embeddings( tokenizer, text_encoder, prompt=prompt, uncond_prompt=negative_prompt if do_classifier_free_guidance else None, max_embeddings_multiples=max_embeddings_multiples, clip_skip=self.clip_skip, token_replacer=token_replacer, device=self.device, **kwargs, ) tes_text_embs.append(text_embeddings) tes_uncond_embs.append(uncond_embeddings) if negative_scale is not None: _, real_uncond_embeddings, _ = self.get_weighted_text_embeddings( token_replacer, prompt=prompt, # こちらのトークン長に合わせてuncondを作るので75トークン超で必須 uncond_prompt=[""] * batch_size, max_embeddings_multiples=max_embeddings_multiples, clip_skip=self.clip_skip, token_replacer=token_replacer, device=self.device, **kwargs, ) tes_real_uncond_embs.append(real_uncond_embeddings) # concat text encoder outputs text_embeddings = tes_text_embs[0] uncond_embeddings = tes_uncond_embs[0] for i in range(1, len(tes_text_embs)): text_embeddings = torch.cat( [text_embeddings, tes_text_embs[i]], dim=2 ) # n,77,2048 if do_classifier_free_guidance: uncond_embeddings = torch.cat( [uncond_embeddings, tes_uncond_embs[i]], dim=2 ) # n,77,2048 if do_classifier_free_guidance: if negative_scale is None: text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) else: text_embeddings = torch.cat( [uncond_embeddings, text_embeddings, real_uncond_embeddings] ) if self.control_nets: # ControlNetのhintにguide imageを流用する if isinstance(clip_guide_images, PIL.Image.Image): clip_guide_images = [clip_guide_images] if isinstance(clip_guide_images[0], PIL.Image.Image): clip_guide_images = [ self.preprocess_image(im) for im in clip_guide_images ] clip_guide_images = torch.cat(clip_guide_images) if isinstance(clip_guide_images, list): clip_guide_images = torch.stack(clip_guide_images) clip_guide_images = clip_guide_images.to( self.device, dtype=text_embeddings.dtype ) # create size embs if original_height is None: original_height = height if original_width is None: original_width = width if original_height_negative is None: original_height_negative = original_height if original_width_negative is None: original_width_negative = original_width if crop_top is None: crop_top = 0 if crop_left is None: crop_left = 0 emb1 = sdxl_train_util.get_timestep_embedding( torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256 ) uc_emb1 = sdxl_train_util.get_timestep_embedding( torch.FloatTensor( [original_height_negative, original_width_negative] ).unsqueeze(0), 256, ) emb2 = sdxl_train_util.get_timestep_embedding( torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256 ) emb3 = sdxl_train_util.get_timestep_embedding( torch.FloatTensor([height, width]).unsqueeze(0), 256 ) c_vector = ( torch.cat([emb1, emb2, emb3], dim=1) .to(self.device, dtype=text_embeddings.dtype) .repeat(batch_size, 1) ) uc_vector = ( torch.cat([uc_emb1, emb2, emb3], dim=1) .to(self.device, dtype=text_embeddings.dtype) .repeat(batch_size, 1) ) if reginonal_network: # use last pool for conditioning num_sub_prompts = len(text_pool) // batch_size text_pool = text_pool[ num_sub_prompts - 1 :: num_sub_prompts ] # last subprompt if init_image is not None and self.clip_vision_model is not None: print( f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}" ) vision_input = self.clip_vision_processor( init_image, return_tensors="pt", device=self.device ) pixel_values = vision_input["pixel_values"].to( self.device, dtype=text_embeddings.dtype ) clip_vision_embeddings = self.clip_vision_model( pixel_values=pixel_values, output_hidden_states=True, return_dict=True ) clip_vision_embeddings = clip_vision_embeddings.image_embeds if len(clip_vision_embeddings) == 1 and batch_size > 1: clip_vision_embeddings = clip_vision_embeddings.repeat((batch_size, 1)) clip_vision_embeddings = clip_vision_embeddings * self.clip_vision_strength assert ( clip_vision_embeddings.shape == text_pool.shape ), f"{clip_vision_embeddings.shape} != {text_pool.shape}" text_pool = clip_vision_embeddings # replace: same as ComfyUI (?) c_vector = torch.cat([text_pool, c_vector], dim=1) if do_classifier_free_guidance: uc_vector = torch.cat([uncond_pool, uc_vector], dim=1) vector_embeddings = torch.cat([uc_vector, c_vector]) else: vector_embeddings = c_vector # set timesteps self.scheduler.set_timesteps(num_inference_steps, self.device) latents_dtype = text_embeddings.dtype init_latents_orig = None mask = None if init_image is None: # get the initial random noise unless the user supplied it # Unlike in other pipelines, latents need to be generated in the target device # for 1-to-1 results reproducibility with the CompVis implementation. # However this currently doesn't work in `mps`. latents_shape = ( batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8, ) if latents is None: if self.device.type == "mps": # randn does not exist on mps latents = torch.randn( latents_shape, generator=generator, device="cpu", dtype=latents_dtype, ).to(self.device) else: latents = torch.randn( latents_shape, generator=generator, device=self.device, dtype=latents_dtype, ) else: if latents.shape != latents_shape: raise ValueError( f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}" ) latents = latents.to(self.device) timesteps = self.scheduler.timesteps.to(self.device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma else: # image to tensor if isinstance(init_image, PIL.Image.Image): init_image = [init_image] if isinstance(init_image[0], PIL.Image.Image): init_image = [self.preprocess_image(im) for im in init_image] init_image = torch.cat(init_image) if isinstance(init_image, list): init_image = torch.stack(init_image) # mask image to tensor if mask_image is not None: if isinstance(mask_image, PIL.Image.Image): mask_image = [mask_image] if isinstance(mask_image[0], PIL.Image.Image): mask_image = torch.cat( [self.preprocess_mask(im) for im in mask_image] ) # H*W, 0 for repaint # encode the init image into latents and scale the latents init_image = init_image.to(device=self.device, dtype=latents_dtype) if init_image.size()[-2:] == (height // 8, width // 8): init_latents = init_image else: if vae_batch_size >= batch_size: init_latent_dist = self.vae.encode( init_image.to(self.vae.dtype) ).latent_dist init_latents = init_latent_dist.sample(generator=generator) else: if torch.cuda.is_available(): torch.cuda.empty_cache() init_latents = [] for i in tqdm( range(0, min(batch_size, len(init_image)), vae_batch_size) ): init_latent_dist = self.vae.encode( ( init_image[i : i + vae_batch_size] if vae_batch_size > 1 else init_image[i].unsqueeze(0) ).to(self.vae.dtype) ).latent_dist init_latents.append( init_latent_dist.sample(generator=generator) ) init_latents = torch.cat(init_latents) init_latents = sdxl_model_util.VAE_SCALE_FACTOR * init_latents if len(init_latents) == 1: init_latents = init_latents.repeat((batch_size, 1, 1, 1)) init_latents_orig = init_latents # preprocess mask if mask_image is not None: mask = mask_image.to(device=self.device, dtype=latents_dtype) if len(mask) == 1: mask = mask.repeat((batch_size, 1, 1, 1)) # check sizes if not mask.shape == init_latents.shape: raise ValueError("The mask and init_image should be the same size!") # get the original timestep using init_timestep offset = self.scheduler.config.get("steps_offset", 0) init_timestep = int(num_inference_steps * strength) + offset init_timestep = min(init_timestep, num_inference_steps) timesteps = self.scheduler.timesteps[-init_timestep] timesteps = torch.tensor( [timesteps] * batch_size * num_images_per_prompt, device=self.device ) # add noise to latents using the timesteps latents = self.scheduler.add_noise(init_latents, img2img_noise, timesteps) t_start = max(num_inference_steps - init_timestep + offset, 0) timesteps = self.scheduler.timesteps[t_start:].to(self.device) # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set( inspect.signature(self.scheduler.step).parameters.keys() ) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta num_latent_input = ( (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1 ) if self.control_nets: # guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images) if self.control_net_enabled: for control_net, _ in self.control_nets: with torch.no_grad(): control_net.set_cond_image(clip_guide_images) else: for control_net, _ in self.control_nets: control_net.set_cond_image(None) each_control_net_enabled = [self.control_net_enabled] * len(self.control_nets) for i, t in enumerate(tqdm(timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # disable control net if ratio is set if self.control_nets and self.control_net_enabled: for j, ((control_net, ratio), enabled) in enumerate( zip(self.control_nets, each_control_net_enabled) ): if not enabled or ratio >= 1.0: continue if ratio < i / len(timesteps): print( f"ControlNet {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})" ) control_net.set_cond_image(None) each_control_net_enabled[j] = False # predict the noise residual # TODO Diffusers' ControlNet # if self.control_nets and self.control_net_enabled: # if reginonal_network: # num_sub_and_neg_prompts = len(text_embeddings) // batch_size # text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt # else: # text_emb_last = text_embeddings # # not working yet # noise_pred = original_control_net.call_unet_and_control_net( # i, # num_latent_input, # self.unet, # self.control_nets, # guided_hints, # i / len(timesteps), # latent_model_input, # t, # text_emb_last, # ).sample # else: noise_pred = self.unet( latent_model_input, t, text_embeddings, vector_embeddings ) # perform guidance if do_classifier_free_guidance: if negative_scale is None: noise_pred_uncond, noise_pred_text = noise_pred.chunk( num_latent_input ) # uncond by negative prompt noise_pred = noise_pred_uncond + guidance_scale * ( noise_pred_text - noise_pred_uncond ) else: ( noise_pred_negative, noise_pred_text, noise_pred_uncond, ) = noise_pred.chunk( num_latent_input ) # uncond is real uncond noise_pred = ( noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - negative_scale * (noise_pred_negative - noise_pred_uncond) ) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step( noise_pred, t, latents, **extra_step_kwargs ).prev_sample if mask is not None: # masking init_latents_proper = self.scheduler.add_noise( init_latents_orig, img2img_noise, torch.tensor([t]) ) latents = (init_latents_proper * mask) + (latents * (1 - mask)) # call the callback, if provided if i % callback_steps == 0: if callback is not None: callback(i, t, latents) if is_cancelled_callback is not None and is_cancelled_callback(): return None if return_latents: return latents latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents if vae_batch_size >= batch_size: image = self.vae.decode(latents.to(self.vae.dtype)).sample else: if torch.cuda.is_available(): torch.cuda.empty_cache() images = [] for i in tqdm(range(0, batch_size, vae_batch_size)): images.append( self.vae.decode( ( latents[i : i + vae_batch_size] if vae_batch_size > 1 else latents[i].unsqueeze(0) ).to(self.vae.dtype) ).sample ) image = torch.cat(images) image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() if torch.cuda.is_available(): torch.cuda.empty_cache() if output_type == "pil": # image = self.numpy_to_pil(image) image = (image * 255).round().astype("uint8") image = [Image.fromarray(im) for im in image] return image class SDXLLLiteImg2ImgPipeline: from diffusers import UNet2DConditionModel def __init__(self): self.SCHEDULER_LINEAR_START = 0.00085 self.SCHEDULER_LINEAR_END = 0.0120 self.SCHEDULER_TIMESTEPS = 1000 self.SCHEDLER_SCHEDULE = "scaled_linear" self.LATENT_CHANNELS = 4 self.DOWNSAMPLING_FACTOR = 8 def replace_unet_modules( self, unet: UNet2DConditionModel, mem_eff_attn, xformers, sdpa, ): if mem_eff_attn: print("Enable memory efficient attention for U-Net") # これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い unet.set_use_memory_efficient_attention(False, True) elif xformers: print("Enable xformers for U-Net") try: import xformers.ops except ImportError: raise ImportError("No xformers / xformersがインストールされていないようです") unet.set_use_memory_efficient_attention(True, False) elif sdpa: print("Enable SDPA for U-Net") unet.set_use_memory_efficient_attention(False, False) unet.set_use_sdpa(True) # TODO common train_util.py def replace_vae_modules( self, vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers, sdpa ): if mem_eff_attn: self.replace_vae_attn_to_memory_efficient() elif xformers: # replace_vae_attn_to_xformers() # 解像度によってxformersがエラーを出す? vae.set_use_memory_efficient_attention_xformers(True) # とりあえずこっちを使う elif sdpa: self.replace_vae_attn_to_sdpa() def replace_vae_attn_to_memory_efficient(self): print( "VAE Attention.forward has been replaced to FlashAttention (not xformers)" ) flash_func = FlashAttentionFunction def forward_flash_attn(self, hidden_states, **kwargs): q_bucket_size = 512 k_bucket_size = 1024 residual = hidden_states batch, channel, height, width = hidden_states.shape # norm hidden_states = self.group_norm(hidden_states) hidden_states = hidden_states.view( batch, channel, height * width ).transpose(1, 2) # proj to q, k, v query_proj = self.to_q(hidden_states) key_proj = self.to_k(hidden_states) value_proj = self.to_v(hidden_states) query_proj, key_proj, value_proj = map( lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj), ) out = flash_func.apply( query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size, ) out = rearrange(out, "b h n d -> b n (h d)") # compute next hidden_states # linear proj hidden_states = self.to_out[0](hidden_states) # dropout hidden_states = self.to_out[1](hidden_states) hidden_states = hidden_states.transpose(-1, -2).reshape( batch, channel, height, width ) # res connect and rescale hidden_states = (hidden_states + residual) / self.rescale_output_factor return hidden_states def forward_flash_attn_0_14(self, hidden_states, **kwargs): if not hasattr(self, "to_q"): self.to_q = self.query self.to_k = self.key self.to_v = self.value self.to_out = [self.proj_attn, torch.nn.Identity()] self.heads = self.num_heads return forward_flash_attn(self, hidden_states, **kwargs) if diffusers.__version__ < "0.15.0": diffusers.models.attention.AttentionBlock.forward = forward_flash_attn_0_14 else: diffusers.models.attention_processor.Attention.forward = forward_flash_attn def replace_vae_attn_to_xformers(self): print("VAE: Attention.forward has been replaced to xformers") import xformers.ops def forward_xformers(self, hidden_states, **kwargs): residual = hidden_states batch, channel, height, width = hidden_states.shape # norm hidden_states = self.group_norm(hidden_states) hidden_states = hidden_states.view( batch, channel, height * width ).transpose(1, 2) # proj to q, k, v query_proj = self.to_q(hidden_states) key_proj = self.to_k(hidden_states) value_proj = self.to_v(hidden_states) query_proj, key_proj, value_proj = map( lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj), ) query_proj = query_proj.contiguous() key_proj = key_proj.contiguous() value_proj = value_proj.contiguous() out = xformers.ops.memory_efficient_attention( query_proj, key_proj, value_proj, attn_bias=None ) out = rearrange(out, "b h n d -> b n (h d)") # compute next hidden_states # linear proj hidden_states = self.to_out[0](hidden_states) # dropout hidden_states = self.to_out[1](hidden_states) hidden_states = hidden_states.transpose(-1, -2).reshape( batch, channel, height, width ) # res connect and rescale hidden_states = (hidden_states + residual) / self.rescale_output_factor return hidden_states def forward_xformers_0_14(self, hidden_states, **kwargs): if not hasattr(self, "to_q"): self.to_q = self.query self.to_k = self.key self.to_v = self.value self.to_out = [self.proj_attn, torch.nn.Identity()] self.heads = self.num_heads return forward_xformers(self, hidden_states, **kwargs) if diffusers.__version__ < "0.15.0": diffusers.models.attention.AttentionBlock.forward = forward_xformers_0_14 else: diffusers.models.attention_processor.Attention.forward = forward_xformers def replace_vae_attn_to_sdpa(): print("VAE: Attention.forward has been replaced to sdpa") def forward_sdpa(self, hidden_states, **kwargs): residual = hidden_states batch, channel, height, width = hidden_states.shape # norm hidden_states = self.group_norm(hidden_states) hidden_states = hidden_states.view( batch, channel, height * width ).transpose(1, 2) # proj to q, k, v query_proj = self.to_q(hidden_states) key_proj = self.to_k(hidden_states) value_proj = self.to_v(hidden_states) query_proj, key_proj, value_proj = map( lambda t: rearrange(t, "b n (h d) -> b n h d", h=self.heads), (query_proj, key_proj, value_proj), ) out = torch.nn.functional.scaled_dot_product_attention( query_proj, key_proj, value_proj, attn_mask=None, dropout_p=0.0, is_causal=False, ) out = rearrange(out, "b n h d -> b n (h d)") # compute next hidden_states # linear proj hidden_states = self.to_out[0](hidden_states) # dropout hidden_states = self.to_out[1](hidden_states) hidden_states = hidden_states.transpose(-1, -2).reshape( batch, channel, height, width ) # res connect and rescale hidden_states = (hidden_states + residual) / self.rescale_output_factor return hidden_states def forward_sdpa_0_14(self, hidden_states, **kwargs): if not hasattr(self, "to_q"): self.to_q = self.query self.to_k = self.key self.to_v = self.value self.to_out = [self.proj_attn, torch.nn.Identity()] self.heads = self.num_heads return forward_sdpa(self, hidden_states, **kwargs) if diffusers.__version__ < "0.15.0": diffusers.models.attention.AttentionBlock.forward = forward_sdpa_0_14 else: diffusers.models.attention_processor.Attention.forward = forward_sdpa def load(self, pipeline: AbstractPipeline, controlnet_urls: Optional[List[str]]): pipeline.pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl") pipeline.pipe.fuse_lora() self.dtype = pipeline.pipe.dtype self.device = pipeline.pipe.device state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl( pipeline.pipe.unet.state_dict() ) with init_empty_weights(): original_unet = ( sdxl_original_unet.SdxlUNet2DConditionModel() ) # overwrite unet sdxl_model_util._load_state_dict_on_device( original_unet, state_dict, device=pipeline.pipe.device, dtype=pipeline.pipe.dtype, ) unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel( original_unet ) sched_init_args = {} has_steps_offset = True has_clip_sample = True scheduler_num_noises_per_step = 1 mem_eff = not (True or False) self.replace_unet_modules(unet, mem_eff, True, False) self.replace_vae_modules(pipeline.pipe.vae, mem_eff, True, False) scheduler_cls = LCMScheduler scheduler_module = diffusers.schedulers.scheduling_ddim if has_steps_offset: sched_init_args["steps_offset"] = 1 if has_clip_sample: sched_init_args["clip_sample"] = False class NoiseManager: def __init__(self): self.sampler_noises = None self.sampler_noise_index = 0 def reset_sampler_noises(self, noises): self.sampler_noise_index = 0 self.sampler_noises = noises def randn( self, shape, device=None, dtype=None, layout=None, generator=None ): # print("replacing", shape, len(self.sampler_noises), self.sampler_noise_index) if self.sampler_noises is not None and self.sampler_noise_index < len( self.sampler_noises ): noise = self.sampler_noises[self.sampler_noise_index] if shape != noise.shape: noise = None else: noise = None if noise == None: print( f"unexpected noise request: {self.sampler_noise_index}, {shape}" ) noise = torch.randn( shape, dtype=dtype, device=device, generator=generator ) self.sampler_noise_index += 1 return noise class TorchRandReplacer: def __init__(self, noise_manager): self.noise_manager = noise_manager def __getattr__(self, item): if item == "randn": return self.noise_manager.randn if hasattr(torch, item): return getattr(torch, item) raise AttributeError( "'{}' object has no attribute '{}'".format( type(self).__name__, item ) ) noise_manager = NoiseManager() if scheduler_module is not None: scheduler_module.torch = TorchRandReplacer(noise_manager) scheduler = scheduler_cls( num_train_timesteps=self.SCHEDULER_TIMESTEPS, beta_start=self.SCHEDULER_LINEAR_START, beta_end=self.SCHEDULER_LINEAR_END, beta_schedule=self.SCHEDLER_SCHEDULE, **sched_init_args, ) device = torch.device( pipeline.pipe.device if torch.cuda.is_available() else "cpu" ) # vae.to(vae_dtype).to(device) # vae.eval() # text_encoder1.to(dtype).to(device) # text_encoder2.to(dtype).to(device) print(pipeline.pipe.dtype) unet.to(pipeline.pipe.dtype).to(pipeline.pipe.device) # text_encoder1.eval() # text_encoder2.eval() unet.eval() control_nets: List[Tuple[ControlNetLLLite, float]] = [] for link in controlnet_urls: net_path = Path.home() / ".cache" / link.split("/")[-1] download_file(link, net_path) print(f"loading controlnet {net_path}") state_dict = load_file(net_path) mlp_dim = None cond_emb_dim = None for key, value in state_dict.items(): if mlp_dim is None and "down.0.weight" in key: mlp_dim = value.shape[0] elif cond_emb_dim is None and "conditioning1.0" in key: cond_emb_dim = value.shape[0] * 2 if mlp_dim is not None and cond_emb_dim is not None: break assert ( mlp_dim is not None and cond_emb_dim is not None ), f"invalid control net: {link}" multiplier = 0.2 # ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] ratio = 1.0 control_net = ControlNetLLLite( unet, cond_emb_dim, mlp_dim, multiplier=multiplier ) control_net.apply_to() control_net.load_state_dict(state_dict) control_net.to(pipeline.pipe.dtype).to(device) control_net.set_batch_cond_only(False, False) control_nets.append((control_net, ratio)) networks = [] self.pipe = PipelineLike( device, pipeline.pipe.vae, [pipeline.pipe.text_encoder, pipeline.pipe.text_encoder_2], [pipeline.pipe.tokenizer, pipeline.pipe.tokenizer_2], unet, scheduler, 2, ) self.pipe.set_control_nets(control_nets) clear_cuda_and_gc() pipeline.pipe.unload_lora_weights() pipeline.pipe.unfuse_lora() clear_cuda_and_gc() def __call__( self, prompt: str, negative_prompt: str, seed: int, image: Image.Image, condition_image: Union[Image.Image, List[Image.Image]], height: int = 1024, width: int = 1024, num_inference_steps: int = 24, guidance_scale=1.0, ): noise_shape = ( self.LATENT_CHANNELS, height // self.DOWNSAMPLING_FACTOR, width // self.DOWNSAMPLING_FACTOR, ) i2i_noises = torch.zeros( (1, *noise_shape), device=self.device, dtype=self.dtype ) i2i_noises[0] = torch.randn(noise_shape, device=self.device, dtype=self.dtype) images = self.pipe( prompt=prompt, negative_prompt=negative_prompt, seed=seed, init_image=image, height=height, width=width, strength=1.0, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, clip_guide_images=condition_image, img2img_noise=i2i_noises, ) return images