import json
import os
from pathlib import Path
from typing import Any, Dict, List, Union

import boto3
import torch
from diffusers.models.attention_processor import AttnProcessor2_0
from lora_diffusion import patch_pipe, tune_lora_scale
from pydash import chain

from internals.data.dataAccessor import getStyles
from internals.util.config import get_is_sdxl
from internals.util.commons import download_file


class LoraStyle:
    """Registry of LoRA styles fetched via ``getStyles()``.

    Call :meth:`load` to fetch the style list and download the referenced
    weight files, then use :meth:`get_patcher` to obtain an object whose
    ``patch()`` applies a style to one or more pipelines and whose
    ``cleanup()`` removes it again.
    """

    class LoraPatcher:
        """Applies a LoRA through ``lora_diffusion`` (``.pt``/``.safetensors`` files)."""

        def __init__(self, pipe, style: Dict[str, Any]):
            # `pipe` is a list of pipelines (LoraStyle.get_patcher wraps single pipes).
            self.__style = style
            self.pipe = pipe

        @torch.inference_mode()
        def patch(self):
            """Inject the LoRA into each pipe and scale it to the style's weight."""

            def run(pipe):
                path = self.__style["path"]
                if str(path).endswith((".pt", ".safetensors")):
                    patch_pipe(pipe, self.__style["path"])
                    tune_lora_scale(pipe.unet, self.__style["weight"])
                    tune_lora_scale(pipe.text_encoder, self.__style["weight"])

            for p in self.pipe:
                run(p)

        def kwargs(self):
            """Extra pipeline call kwargs for this style (none)."""
            return {}

        def cleanup(self):
            """Neutralise the LoRA by scaling it to 0.0 on unet and text encoder."""

            def run(pipe):
                tune_lora_scale(pipe.unet, 0.0)
                tune_lora_scale(pipe.text_encoder, 0.0)

            for p in self.pipe:
                run(p)

    class LoraDiffuserPatcher:
        """Applies a LoRA using the pipeline's own ``load_lora_weights`` API."""

        def __init__(self, pipe, style: Dict[str, Any]):
            # `pipe` is a list of pipelines (LoraStyle.get_patcher wraps single pipes).
            self.__style = style
            self.pipe = pipe

        @torch.inference_mode()
        def patch(self):
            """Load the style's weight file into every pipe."""

            def run(pipe):
                path = self.__style["path"]
                # load_lora_weights takes the containing directory plus the file name.
                pipe.load_lora_weights(
                    os.path.dirname(path), weight_name=os.path.basename(path)
                )

            for p in self.pipe:
                run(p)

        def kwargs(self):
            """Extra pipeline call kwargs for this style (none)."""
            return {}

        def cleanup(self):
            """Unload the LoRA weights from every pipe."""

            def run(pipe):
                LoraStyle.unload_lora_weights(pipe)

            for p in self.pipe:
                run(p)

    class EmptyLoraPatcher:
        """No-op style: used when no style applies, and resets any leftover LoRA."""

        def __init__(self, pipe):
            self.pipe = pipe

        def patch(self):
            "Patch will act as cleanup, to tune down any corrupted lora"
            self.cleanup()

        def kwargs(self):
            """Extra pipeline call kwargs (none)."""
            return {}

        def cleanup(self):
            """Zero any lora_diffusion scale and unload diffusers LoRA weights."""

            def run(pipe):
                tune_lora_scale(pipe.unet, 0.0)
                tune_lora_scale(pipe.text_encoder, 0.0)
                LoraStyle.unload_lora_weights(pipe)

            for p in self.pipe:
                run(p)

    def __init__(self):
        # Start with no styles so a failed fetch (getStyles() -> None) or a
        # lookup before load() behaves as "style not found" instead of
        # raising AttributeError on the missing attribute.
        self.__styles: Dict[str, Dict[str, Any]] = {}

    def load(self, model_dir: str):
        """Remember `model_dir` and fetch/download the available styles."""
        self.model = model_dir
        self.fetch_styles()

    def fetch_styles(self):
        """Fetch styles via getStyles(), download their files, and verify them.

        Raises:
            Exception: if a parsed style's weight file is missing on disk.
        """
        model_dir = self.model
        result = getStyles()
        if result is not None:
            self.__styles = self.__parse_styles(model_dir, result["data"])
        if len(self.__styles) == 0:
            print("Warning: No styles found for Model")
        self.__verify()

    def prepend_style_to_prompt(self, prompt: str, key: str) -> str:
        """Prefix `prompt` with the style's trigger text; unknown keys pass through.

        NOTE(review): assumes style["text"] is a list of strings (it is
        joined with ", "); a plain string would be joined per character.
        """
        if key in self.__styles:
            style = self.__styles[key]
            return f"{', '.join(style['text'])}, {prompt}"
        return prompt

    def get_patcher(
        self, pipe: Union[Any, List], key: str
    ) -> Union[LoraPatcher, LoraDiffuserPatcher, EmptyLoraPatcher]:
        "Returns a lora patcher for the given `key` and `pipe`. `pipe` can also be a list of pipes"
        pipe = [pipe] if not isinstance(pipe, list) else pipe
        if get_is_sdxl():
            print("Warning: Lora is not supported on SDXL")
            return self.EmptyLoraPatcher(pipe)
        if key in self.__styles:
            style = self.__styles[key]
            if style["type"] == "diffuser":
                return self.LoraDiffuserPatcher(pipe, style)
            return self.LoraPatcher(pipe, style)
        return self.EmptyLoraPatcher(pipe)

    def __parse_styles(self, model_dir: str, data: List[Dict]) -> Dict:
        """Build a tag -> style mapping, downloading missing weight files.

        Entries are de-duplicated by "tag" (first occurrence wins). Entries
        without attributes, or whose attributes have no "path", are skipped.
        `model_dir` is currently unused.
        """
        styles = {}
        download_dir = Path.home() / ".cache" / "lora"
        # parents=True: ~/.cache may not exist yet on a fresh machine.
        download_dir.mkdir(parents=True, exist_ok=True)
        data = chain(data).uniq_by(lambda x: x["tag"]).value()
        for item in data:
            raw_attributes = item.get("attributes")
            if raw_attributes is None:
                continue
            attr = json.loads(raw_attributes)
            if "path" not in attr:
                continue
            # Files are cached by basename of the remote path.
            file_path = download_dir / attr["path"].split("/")[-1]
            if not file_path.exists():
                # NOTE(review): an interrupted download may leave a partial
                # file that passes this exists() check next time — confirm
                # download_file's behaviour.
                s3_uri = attr["path"]
                download_file(s3_uri, file_path)
            styles[item["tag"]] = {
                "path": str(file_path),
                "weight": attr["weight"],
                "type": attr["type"],
                "text": attr["text"],
                "negativePrompt": attr["negativePrompt"],
            }
        return styles

    def __verify(self):
        "A method to verify if lora exists within the required path otherwise throw error"
        for tag, style in self.__styles.items():
            if not os.path.exists(style["path"]):
                raise Exception(
                    f"Lora style model {tag} not found at path: {style['path']}"
                )

    @staticmethod
    def unload_lora_weights(pipe):
        """Remove diffusers-loaded LoRA weights from `pipe`."""
        pipe.unload_lora_weights()