File size: 5,231 Bytes
19b3da3 86248f3 19b3da3 7fbdac4 19b3da3 42ef134 19b3da3 42ef134 86248f3 42ef134 86248f3 42ef134 19b3da3 42ef134 19b3da3 42ef134 19b3da3 86248f3 42ef134 86248f3 42ef134 10230ea 19b3da3 86248f3 19b3da3 86248f3 8aeb9e5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
import json
import os
from pathlib import Path
from typing import Any, Dict, List, Union
import boto3
import torch
from diffusers.models.attention_processor import AttnProcessor2_0
from lora_diffusion import patch_pipe, tune_lora_scale
from pydash import chain
from internals.data.dataAccessor import getStyles
from internals.util.commons import download_file
from internals.util.config import get_is_sdxl
class LoraStyle:
    """Registry of LoRA "styles" fetched from the backend.

    Styles are loaded with ``load(model_dir)``. ``get_patcher(pipe, key)``
    returns an object exposing ``patch()``, ``kwargs()`` and ``cleanup()``
    that applies (or removes) the style's LoRA weights on one or more
    pipelines.
    """

    class LoraPatcher:
        """Applies a LoRA file via ``lora_diffusion`` (``.pt``/``.safetensors``)."""

        def __init__(self, pipe, style: Dict[str, Any]):
            # `pipe` is a list of pipelines; `get_patcher` normalizes it.
            self.__style = style
            self.pipe = pipe

        @torch.inference_mode()
        def patch(self):
            """Inject the style's LoRA into every pipeline and set its scale."""

            def run(pipe):
                path = self.__style["path"]
                if str(path).endswith((".pt", ".safetensors")):
                    patch_pipe(pipe, path)
                    # Only tune scales when something was injected here.
                    # Otherwise a previously injected (zeroed) LoRA would be
                    # re-enabled at this style's weight.
                    tune_lora_scale(pipe.unet, self.__style["weight"])
                    tune_lora_scale(pipe.text_encoder, self.__style["weight"])

            for p in self.pipe:
                run(p)

        def kwargs(self):
            """Return extra pipeline-call kwargs needed by this patcher (none)."""
            return {}

        def cleanup(self):
            """Turn the injected LoRA off by setting its scale to 0.0."""

            def run(pipe):
                tune_lora_scale(pipe.unet, 0.0)
                tune_lora_scale(pipe.text_encoder, 0.0)

            for p in self.pipe:
                run(p)

    class LoraDiffuserPatcher:
        """Applies a LoRA through the pipeline's own ``load_lora_weights``."""

        def __init__(self, pipe, style: Dict[str, Any]):
            # `pipe` is a list of pipelines; `get_patcher` normalizes it.
            self.__style = style
            self.pipe = pipe

        @torch.inference_mode()
        def patch(self):
            """Load the style's weight file into every pipeline."""

            def run(pipe):
                path = self.__style["path"]
                pipe.load_lora_weights(
                    os.path.dirname(path), weight_name=os.path.basename(path)
                )

            for p in self.pipe:
                run(p)

        def kwargs(self):
            """Return extra pipeline-call kwargs needed by this patcher (none)."""
            # NOTE(review): the style's "weight" is never applied for this
            # patcher type; confirm whether a call-time LoRA scale is expected.
            return {}

        def cleanup(self):
            """Unload LoRA weights from every pipeline."""

            def run(pipe):
                LoraStyle.unload_lora_weights(pipe)

            for p in self.pipe:
                run(p)

    class EmptyLoraPatcher:
        """No-op style: patching just clears any LoRA left on the pipelines."""

        def __init__(self, pipe):
            self.pipe = pipe

        def patch(self):
            "Patch will act as cleanup, to tune down any corrupted lora"
            self.cleanup()

        def kwargs(self):
            """Return extra pipeline-call kwargs needed by this patcher (none)."""
            return {}

        def cleanup(self):
            """Zero lora_diffusion scales and unload weights on every pipeline."""

            def run(pipe):
                tune_lora_scale(pipe.unet, 0.0)
                tune_lora_scale(pipe.text_encoder, 0.0)
                LoraStyle.unload_lora_weights(pipe)

            for p in self.pipe:
                run(p)

    def __init__(self):
        # Start with no styles so lookups work (falling back to the no-op
        # patcher) even before `load` or when the backend returns nothing.
        self.__styles: Dict[str, Dict[str, Any]] = {}

    def load(self, model_dir: str):
        """Remember `model_dir` and fetch/download/verify the styles."""
        self.model = model_dir
        self.fetch_styles()

    def fetch_styles(self):
        """Fetch style definitions, download their weights and verify them.

        If ``getStyles()`` returns None, the currently held styles (empty by
        default) are kept. Raises if a style's weight file is missing.
        """
        model_dir = self.model
        result = getStyles()
        if result is not None:
            self.__styles = self.__parse_styles(model_dir, result["data"])
        if not self.__styles:
            print("Warning: No styles found for Model")
        self.__verify()

    def prepend_style_to_prompt(self, prompt: str, key: str) -> str:
        """Prefix `prompt` with the style's trigger text, if `key` is known."""
        if key in self.__styles:
            style = self.__styles[key]
            return f"{', '.join(style['text'])}, {prompt}"
        return prompt

    def get_patcher(
        self, pipe: Union[Any, List], key: str
    ) -> Union[LoraPatcher, LoraDiffuserPatcher, EmptyLoraPatcher]:
        "Returns a lora patcher for the given `key` and `pipe`. `pipe` can also be a list of pipes"
        pipe = [pipe] if not isinstance(pipe, list) else pipe
        if key in self.__styles:
            style = self.__styles[key]
            if style["type"] == "diffuser":
                return self.LoraDiffuserPatcher(pipe, style)
            return self.LoraPatcher(pipe, style)
        return self.EmptyLoraPatcher(pipe)

    def __parse_styles(self, model_dir: str, data: List[Dict]) -> Dict:
        """Build ``{tag: style}`` from backend records, downloading weights.

        `model_dir` is currently unused. Records without ``attributes`` or
        without a ``path`` attribute are skipped. Weight files are cached
        under ``~/.cache/lora``.
        """
        styles = {}
        download_dir = Path.home() / ".cache" / "lora"
        # parents=True: `~/.cache` itself may not exist on a fresh machine.
        download_dir.mkdir(parents=True, exist_ok=True)
        # Drop records with a repeated tag.
        data = chain(data).uniq_by(lambda x: x["tag"]).value()
        for item in data:
            if item["attributes"] is None:
                continue
            attr = json.loads(item["attributes"])
            if "path" not in attr:
                continue
            s3_uri = attr["path"]
            # Cached by basename only: two styles whose remote files share a
            # basename would resolve to the same local file.
            file_path = download_dir / s3_uri.split("/")[-1]
            if not file_path.exists():
                # NOTE(review): an interrupted download may leave a partial
                # file that is treated as cached next time; confirm whether
                # download_file writes atomically.
                download_file(s3_uri, file_path)
            styles[item["tag"]] = {
                "path": str(file_path),
                "weight": attr["weight"],
                "type": attr["type"],
                "text": attr["text"],
                "negativePrompt": attr["negativePrompt"],
            }
        return styles

    def __verify(self):
        "A method to verify if lora exists within the required path otherwise throw error"
        for tag, style in self.__styles.items():
            path = style["path"]
            if not os.path.exists(path):
                raise Exception(f"Lora style model {tag} not found at path: {path}")

    @staticmethod
    def unload_lora_weights(pipe):
        """Remove LoRA weights previously loaded into `pipe`."""
        pipe.unload_lora_weights()
|