import json
import os
from pathlib import Path
from typing import Any, Dict, List, Union

import boto3
import torch
from diffusers.models.attention_processor import AttnProcessor2_0
from lora_diffusion import patch_pipe, tune_lora_scale
from pydash import chain

from internals.data.dataAccessor import getStyles
from internals.util.commons import download_file


class LoraStyle:
    """Registry of LoRA "styles" and a factory for the patchers that apply them.

    Call ``load(model_dir)`` first; it fetches the style list through
    ``getStyles()`` and falls back to a built-in set of styles under
    ``model_dir`` when none are available. Each style is a dict with the keys
    ``path``, ``weight``, ``type``, ``text`` and ``negativePrompt``, keyed by its
    tag. ``get_patcher`` returns an object exposing ``patch()``, ``kwargs()``
    and ``cleanup()`` for a given tag.
    """

    class LoraPatcher:
        """Applies a style with ``lora_diffusion.patch_pipe`` (any type other than "diffuser")."""

        def __init__(self, pipe, style: Dict[str, Any]):
            self.__style = style
            self.pipe = pipe

        @torch.inference_mode()
        def patch(self):
            """Inject the style's LoRA file into ``pipe`` and scale it to the style weight.

            Only ``.pt`` / ``.safetensors`` paths are passed to ``patch_pipe``;
            for any other extension nothing is injected and nothing is tuned.
            """
            path = self.__style["path"]
            if str(path).endswith((".pt", ".safetensors")):
                patch_pipe(self.pipe, self.__style["path"])
                # Tune only after an injection for *this* style: tuning an
                # un-patched pipe would re-scale whatever LoRA layers an earlier
                # patch left behind (cleanup() only zeroes their scale).
                tune_lora_scale(self.pipe.unet, self.__style["weight"])
                tune_lora_scale(self.pipe.text_encoder, self.__style["weight"])

        def kwargs(self):
            """Return extra pipeline call kwargs for this style (none)."""
            return {}

        def cleanup(self):
            """Set the LoRA scale of the unet and text encoder back to 0.0."""
            tune_lora_scale(self.pipe.unet, 0.0)
            tune_lora_scale(self.pipe.text_encoder, 0.0)

    class LoraDiffuserPatcher:
        """Applies a style through the pipeline's own ``load_lora_weights`` (type "diffuser")."""

        def __init__(self, pipe, style: Dict[str, Any]):
            self.__style = style
            self.pipe = pipe

        @torch.inference_mode()
        def patch(self):
            """Load the LoRA weights file at the style's path into ``pipe``."""
            path = self.__style["path"]
            self.pipe.load_lora_weights(
                os.path.dirname(path), weight_name=os.path.basename(path)
            )

        def kwargs(self):
            """Return extra pipeline call kwargs for this style (none)."""
            return {}

        def cleanup(self):
            """Undo ``patch`` via ``LoraStyle.unload_lora_weights``."""
            LoraStyle.unload_lora_weights(self.pipe)

    class EmptyLoraPatcher:
        """Patcher used for unknown style keys; it only resets LoRA state."""

        def __init__(self, pipe):
            self.pipe = pipe

        def patch(self):
            "Patch will act as cleanup, to tune down any corrupted lora"
            self.cleanup()

        def kwargs(self):
            """Return extra pipeline call kwargs (none)."""
            return {}

        def cleanup(self):
            """Zero LoRA scales and unload any diffusers-loaded LoRA weights."""
            tune_lora_scale(self.pipe.unet, 0.0)
            tune_lora_scale(self.pipe.text_encoder, 0.0)
            LoraStyle.unload_lora_weights(self.pipe)

    def load(self, model_dir: str):
        """Remember ``model_dir`` (used for default style paths) and fetch styles."""
        self.model = model_dir
        self.fetch_styles()

    def fetch_styles(self):
        """(Re)build the style table and verify every style file exists.

        Uses ``getStyles()["data"]`` when ``getStyles()`` returns a result,
        otherwise the built-in defaults. Requires ``load`` to have set
        ``self.model``.

        Raises:
            Exception: if any style's ``path`` does not exist on disk.
        """
        model_dir = self.model
        result = getStyles()
        if result is not None:
            self.__styles = self.__parse_styles(model_dir, result["data"])
        else:
            self.__styles = self.__get_default_styles(model_dir)
        self.__verify()

    def prepend_style_to_prompt(self, prompt: str, key: str) -> str:
        """Return ``prompt`` prefixed with the style's trigger text, if ``key`` is known.

        The style's ``text`` entries are joined with ", " and placed before the
        prompt; unknown keys return ``prompt`` unchanged.
        """
        if key in self.__styles:
            style = self.__styles[key]
            return f"{', '.join(style['text'])}, {prompt}"
        return prompt

    def get_patcher(
        self, pipe, key: str
    ) -> Union[LoraPatcher, LoraDiffuserPatcher, EmptyLoraPatcher]:
        """Return the patcher for style ``key`` bound to ``pipe``.

        "diffuser"-type styles use ``LoraDiffuserPatcher``, all other known
        styles use ``LoraPatcher``, and unknown keys get ``EmptyLoraPatcher``.
        """
        if key in self.__styles:
            style = self.__styles[key]
            if style["type"] == "diffuser":
                return self.LoraDiffuserPatcher(pipe, style)
            return self.LoraPatcher(pipe, style)
        return self.EmptyLoraPatcher(pipe)

    def __parse_styles(self, model_dir: str, data: List[Dict]) -> Dict:
        """Build the style table from ``getStyles()`` records, downloading files as needed.

        Each record is expected to carry a ``tag`` and an ``attributes`` JSON
        string containing ``path``, ``weight``, ``type``, ``text`` and
        ``negativePrompt``. Records are de-duplicated by ``tag`` (first one
        wins). Records with no attributes, or whose attributes lack ``path``,
        are skipped. Each referenced file is cached in ``~/.cache/lora``
        under the last ``/``-separated component of its path and downloaded
        only if not already present.

        Returns the built-in defaults when no usable record remains.
        """
        styles = {}
        download_dir = Path.home() / ".cache" / "lora"
        # parents=True: ~/.cache itself may not exist on a fresh machine.
        download_dir.mkdir(parents=True, exist_ok=True)

        data = chain(data).uniq_by(lambda x: x["tag"]).value()
        for item in data:
            if item["attributes"] is None:
                continue
            attr = json.loads(item["attributes"])
            if "path" not in attr:
                # Without a path there is no file to register; registering it
                # anyway would reuse the previous record's file_path.
                continue

            file_path = download_dir / attr["path"].split("/")[-1]
            if not file_path.exists():
                # attr["path"] is presumably a remote (S3) URI — TODO confirm
                # against download_file.
                s3_uri = attr["path"]
                download_file(s3_uri, file_path)

            styles[item["tag"]] = {
                "path": str(file_path),
                "weight": attr["weight"],
                "type": attr["type"],
                "text": attr["text"],
                "negativePrompt": attr["negativePrompt"],
            }

        if not styles:
            return self.__get_default_styles(model_dir)
        return styles

    def __get_default_styles(self, model_dir: str) -> Dict:
        """Return the built-in style table, with files resolved under ``model_dir``."""
        return {
            "nq6akX1CIp": {
                "path": model_dir + "/laur_style/nq6akX1CIp/final_lora.safetensors",
                "text": ["nq6akX1CIp style"],
                "weight": 0.5,
                "negativePrompt": [""],
                "type": "custom",
            },
            "ghibli": {
                "path": model_dir + "/laur_style/nq6akX1CIp/ghibli.bin",
                "text": ["ghibli style"],
                "weight": 1,
                "negativePrompt": [""],
                "type": "custom",
            },
            "eQAmnK2kB2": {
                "path": model_dir + "/laur_style/eQAmnK2kB2/final_lora.safetensors",
                "text": ["eQAmnK2kB2 style"],
                "weight": 0.5,
                "negativePrompt": [""],
                "type": "custom",
            },
            "to8contrast": {
                "path": model_dir + "/laur_style/rpjgusOgqD/final_lora.bin",
                "text": ["to8contrast style"],
                "weight": 0.5,
                "negativePrompt": [""],
                "type": "custom",
            },
            "sfrrfz8vge": {
                "path": model_dir + "/laur_style/replicate/sfrrfz8vge.safetensors",
                "text": ["sfrrfz8vge style"],
                "weight": 1.2,
                "negativePrompt": [""],
                "type": "custom",
            },
        }

    def __verify(self):
        "A method to verify if lora exists within the required path otherwise throw error"
        for tag, style in self.__styles.items():
            if not os.path.exists(style["path"]):
                raise Exception(
                    f"Lora style model {tag} not found at path: {style['path']}"
                )

    @staticmethod
    def unload_lora_weights(pipe):
        """Reset ``pipe``'s unet attention processors and text-encoder LoRA patch.

        Installs a fresh ``AttnProcessor2_0`` on the unet and calls the
        pipeline's private ``_remove_text_encoder_monkey_patch``.
        """
        pipe.unet.set_attn_processor(AttnProcessor2_0())  # for pytorch 2.0
        pipe._remove_text_encoder_monkey_patch()