File size: 6,105 Bytes
19b3da3 86248f3 19b3da3 86248f3 19b3da3 86248f3 19b3da3 86248f3 19b3da3 86248f3 19b3da3 86248f3 8aeb9e5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 |
import json
import os
from pathlib import Path
from typing import Any, Dict, List, Union
import boto3
import torch
from diffusers.models.attention_processor import AttnProcessor2_0
from lora_diffusion import patch_pipe, tune_lora_scale
from pydash import chain
from internals.data.dataAccessor import getStyles
from internals.util.commons import download_file
class LoraStyle:
    """Registry of LoRA styles plus a factory for objects that apply them.

    Typical use: create an instance, call :meth:`load` with the model
    directory, then use :meth:`prepend_style_to_prompt` and
    :meth:`get_patcher` with a style key. Every patcher exposes the same
    three methods: ``patch()``, ``kwargs()`` and ``cleanup()``.
    """

    class LoraPatcher:
        """Patcher that applies a style through ``lora_diffusion``."""

        def __init__(self, pipe, style: Dict[str, Any]):
            self.__style = style
            self.pipe = pipe

        @torch.inference_mode()
        def patch(self):
            """Patch the pipeline with the style file and set its scale.

            Only ``.pt`` and ``.safetensors`` files are handled; for any
            other extension this method does nothing.
            """
            path = self.__style["path"]
            if str(path).endswith((".pt", ".safetensors")):
                weight = self.__style["weight"]
                patch_pipe(self.pipe, path)
                tune_lora_scale(self.pipe.unet, weight)
                tune_lora_scale(self.pipe.text_encoder, weight)

        def kwargs(self):
            """Return extra pipeline call arguments (none for this patcher)."""
            return {}

        def cleanup(self):
            """Set the LoRA scale of the unet and text encoder to 0.0."""
            tune_lora_scale(self.pipe.unet, 0.0)
            tune_lora_scale(self.pipe.text_encoder, 0.0)

    class LoraDiffuserPatcher:
        """Patcher that applies a style via the pipeline's own LoRA loader."""

        def __init__(self, pipe, style: Dict[str, Any]):
            self.__style = style
            self.pipe = pipe

        @torch.inference_mode()
        def patch(self):
            """Load the style file with ``pipe.load_lora_weights``.

            The path is split into directory and file name because the
            loader is called with the directory plus ``weight_name``.
            """
            # NOTE(review): the style's "weight" is never read here, so the
            # configured strength has no effect for this patcher — confirm
            # that is intended.
            directory, file_name = os.path.split(self.__style["path"])
            self.pipe.load_lora_weights(directory, weight_name=file_name)

        def kwargs(self):
            """Return extra pipeline call arguments (none for this patcher)."""
            return {}

        def cleanup(self):
            """Unload the LoRA weights from the pipeline."""
            LoraStyle.unload_lora_weights(self.pipe)

    class EmptyLoraPatcher:
        """Patcher returned for unknown style keys; it only cleans up."""

        def __init__(self, pipe):
            self.pipe = pipe

        def patch(self):
            """Patch will act as cleanup, to tune down any corrupted lora."""
            self.cleanup()

        def kwargs(self):
            """Return extra pipeline call arguments (none for this patcher)."""
            return {}

        def cleanup(self):
            """Zero the LoRA scales and unload LoRA weights from the pipeline."""
            tune_lora_scale(self.pipe.unet, 0.0)
            tune_lora_scale(self.pipe.text_encoder, 0.0)
            LoraStyle.unload_lora_weights(self.pipe)

    def __init__(self) -> None:
        # Start with an empty registry so lookups made before load() fall
        # back to "no style" instead of raising AttributeError.
        self.__styles: Dict[str, Dict[str, Any]] = {}

    def load(self, model_dir: str):
        """Remember ``model_dir`` and populate the style registry."""
        self.model = model_dir
        self.fetch_styles()

    def fetch_styles(self):
        """Rebuild the style registry and verify every style file exists.

        Uses the result of ``getStyles()`` when it is not ``None``,
        otherwise the built-in defaults.

        Raises:
            Exception: if a registered style's file is missing on disk.
        """
        model_dir = self.model
        result = getStyles()
        if result is None:
            self.__styles = self.__get_default_styles(model_dir)
        else:
            self.__styles = self.__parse_styles(model_dir, result["data"])
        self.__verify()

    def prepend_style_to_prompt(self, prompt: str, key: str) -> str:
        """Return ``prompt`` prefixed with the trigger text of style ``key``.

        The style's ``text`` entries are joined with ``", "`` and placed
        before the prompt. Unknown keys return ``prompt`` unchanged.
        """
        style = self.__styles.get(key)
        if style is None:
            return prompt
        return f"{', '.join(style['text'])}, {prompt}"

    def get_patcher(
        self, pipe, key: str
    ) -> Union[LoraPatcher, LoraDiffuserPatcher, EmptyLoraPatcher]:
        """Return the patcher matching style ``key`` for ``pipe``.

        Styles whose ``type`` is ``"diffuser"`` get a
        ``LoraDiffuserPatcher``, every other known style a ``LoraPatcher``,
        and unknown keys an ``EmptyLoraPatcher``.
        """
        style = self.__styles.get(key)
        if style is None:
            return self.EmptyLoraPatcher(pipe)
        if style["type"] == "diffuser":
            return self.LoraDiffuserPatcher(pipe, style)
        return self.LoraPatcher(pipe, style)

    def __parse_styles(self, model_dir: str, data: List[Dict]) -> Dict:
        """Build the style registry from fetched style records.

        Records are de-duplicated by ``tag`` (first occurrence wins).
        Records without ``attributes`` or without a ``path`` attribute are
        skipped. Missing files are fetched into ``~/.cache/lora`` with
        ``download_file``. Falls back to the default styles when no record
        yields a style.
        """
        styles = {}
        download_dir = Path.home() / ".cache" / "lora"
        # parents=True: ~/.cache itself may not exist yet (fresh container
        # or user), in which case a plain mkdir raises FileNotFoundError.
        download_dir.mkdir(parents=True, exist_ok=True)

        data = chain(data).uniq_by(lambda x: x["tag"]).value()
        for item in data:
            if item["attributes"] is None:
                continue
            attr = json.loads(item["attributes"])
            if "path" not in attr:
                continue

            s3_uri = attr["path"]
            # NOTE(review): the cache file is keyed by base name only, so
            # two styles whose remote files share a name would reuse the
            # same local file — confirm remote names are unique.
            file_path = download_dir / s3_uri.split("/")[-1]
            if not file_path.exists():
                download_file(s3_uri, file_path)

            styles[item["tag"]] = {
                "path": str(file_path),
                "weight": attr["weight"],
                "type": attr["type"],
                "text": attr["text"],
                "negativePrompt": attr["negativePrompt"],
            }

        if not styles:
            return self.__get_default_styles(model_dir)
        return styles

    def __get_default_styles(self, model_dir: str) -> Dict:
        """Return the built-in styles, with paths rooted at ``model_dir``."""
        style_root = f"{model_dir}/laur_style"
        return {
            "nq6akX1CIp": {
                "path": f"{style_root}/nq6akX1CIp/final_lora.safetensors",
                "text": ["nq6akX1CIp style"],
                "weight": 0.5,
                "negativePrompt": [""],
                "type": "custom",
            },
            "ghibli": {
                "path": f"{style_root}/nq6akX1CIp/ghibli.bin",
                "text": ["ghibli style"],
                "weight": 1,
                "negativePrompt": [""],
                "type": "custom",
            },
            "eQAmnK2kB2": {
                "path": f"{style_root}/eQAmnK2kB2/final_lora.safetensors",
                "text": ["eQAmnK2kB2 style"],
                "weight": 0.5,
                "negativePrompt": [""],
                "type": "custom",
            },
            "to8contrast": {
                "path": f"{style_root}/rpjgusOgqD/final_lora.bin",
                "text": ["to8contrast style"],
                "weight": 0.5,
                "negativePrompt": [""],
                "type": "custom",
            },
            "sfrrfz8vge": {
                "path": f"{style_root}/replicate/sfrrfz8vge.safetensors",
                "text": ["sfrrfz8vge style"],
                "weight": 1.2,
                "negativePrompt": [""],
                "type": "custom",
            },
        }

    def __verify(self):
        """Check that every registered style file exists on disk.

        Raises:
            Exception: naming the first style whose ``path`` does not exist.
        """
        for tag, style in self.__styles.items():
            if not os.path.exists(style["path"]):
                raise Exception(
                    f"Lora style model {tag} not found at path: {style['path']}"
                )

    @staticmethod
    def unload_lora_weights(pipe):
        """Call ``pipe.unload_lora_weights()``."""
        pipe.unload_lora_weights()
|