# internals/util/lora_style.py
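"""Style LoRA management: fetches style definitions from the data store,
downloads their LoRA weights, and patches/unpatches diffusers pipelines.
"""
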
import json
import os
from pathlib import Path
from typing import Any, Dict, List, Union

import torch
from diffusers.models.attention_processor import AttnProcessor2_0
from lora_diffusion import patch_pipe, tune_lora_scale
from pydash import chain

from internals.data.dataAccessor import getStyles
from internals.util.commons import download_file

class LoraStyle:
    class LoraPatcher:
        """Applies a LoRA to a pipeline via the `lora_diffusion` package."""

        def __init__(self, pipe, style: Dict[str, Any]):
            self.__style = style
            self.pipe = pipe

        @torch.inference_mode()
        def patch(self):
            path = self.__style["path"]
            if str(path).endswith((".pt", ".safetensors")):
                patch_pipe(self.pipe, self.__style["path"])
                tune_lora_scale(self.pipe.unet, self.__style["weight"])
                tune_lora_scale(self.pipe.text_encoder, self.__style["weight"])

        def kwargs(self):
            return {}

        def cleanup(self):
            # Tune the LoRA scale down to zero rather than unpatching weights.
            tune_lora_scale(self.pipe.unet, 0.0)
            tune_lora_scale(self.pipe.text_encoder, 0.0)

    class LoraDiffuserPatcher:
        """Applies a LoRA to a pipeline via diffusers' built-in loader."""

        def __init__(self, pipe, style: Dict[str, Any]):
            self.__style = style
            self.pipe = pipe

        @torch.inference_mode()
        def patch(self):
            path = self.__style["path"]
            self.pipe.load_lora_weights(
                os.path.dirname(path), weight_name=os.path.basename(path)
            )

        def kwargs(self):
            return {}

        def cleanup(self):
            LoraStyle.unload_lora_weights(self.pipe)

    class EmptyLoraPatcher:
        """No-op patcher returned when the requested style key is unknown."""

        def __init__(self, pipe):
            self.pipe = pipe

        def patch(self):
            """Patch acts as cleanup, tuning down any leftover LoRA."""
            self.cleanup()

        def kwargs(self):
            return {}

        def cleanup(self):
            tune_lora_scale(self.pipe.unet, 0.0)
            tune_lora_scale(self.pipe.text_encoder, 0.0)
            LoraStyle.unload_lora_weights(self.pipe)

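    # All three patchers above share the same duck-typed interface
    # (patch / kwargs / cleanup), so callers can use whatever
    # `get_patcher` returns without isinstance checks.
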
    def load(self, model_dir: str):
        self.model = model_dir
        self.fetch_styles()

    def fetch_styles(self):
        """Load styles from the data store, falling back to the bundled defaults."""
        model_dir = self.model
        result = getStyles()
        if result is not None:
            self.__styles = self.__parse_styles(model_dir, result["data"])
        else:
            self.__styles = self.__get_default_styles(model_dir)
        self.__verify()

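    # Example (with the bundled "ghibli" default style):
    #   prepend_style_to_prompt("a castle on a hill", "ghibli")
    #   -> "ghibli style, a castle on a hill"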
    def prepend_style_to_prompt(self, prompt: str, key: str) -> str:
        if key in self.__styles:
            style = self.__styles[key]
            return f"{', '.join(style['text'])}, {prompt}"
        return prompt

    def get_patcher(
        self, pipe, key: str
    ) -> Union[LoraPatcher, LoraDiffuserPatcher, EmptyLoraPatcher]:
        if key in self.__styles:
            style = self.__styles[key]
            if style["type"] == "diffuser":
                return self.LoraDiffuserPatcher(pipe, style)
            return self.LoraPatcher(pipe, style)
        return self.EmptyLoraPatcher(pipe)

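    # Each item's "attributes" field is assumed to be a JSON string shaped
    # roughly like the following (field names inferred from the keys read
    # below; the values are illustrative only):
    #   {"path": "s3://bucket/loras/ghibli.safetensors", "weight": 1,
    #    "type": "diffuser", "text": ["ghibli style"], "negativePrompt": [""]}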
    def __parse_styles(self, model_dir: str, data: List[Dict]) -> Dict:
        styles = {}
        download_dir = Path.home() / ".cache" / "lora"
        download_dir.mkdir(parents=True, exist_ok=True)
        # Keep a single entry per style tag.
        data = chain(data).uniq_by(lambda x: x["tag"]).value()
        for item in data:
            if item["attributes"] is not None:
                attr = json.loads(item["attributes"])
                if "path" in attr:
                    # Download the LoRA weights once and cache them locally.
                    file_path = download_dir / attr["path"].split("/")[-1]
                    if not file_path.exists():
                        s3_uri = attr["path"]
                        download_file(s3_uri, file_path)
                    styles[item["tag"]] = {
                        "path": str(file_path),
                        "weight": attr["weight"],
                        "type": attr["type"],
                        "text": attr["text"],
                        "negativePrompt": attr["negativePrompt"],
                    }
        if len(styles) == 0:
            return self.__get_default_styles(model_dir)
        return styles

    def __get_default_styles(self, model_dir: str) -> Dict:
        return {
            "nq6akX1CIp": {
                "path": model_dir + "/laur_style/nq6akX1CIp/final_lora.safetensors",
                "text": ["nq6akX1CIp style"],
                "weight": 0.5,
                "negativePrompt": [""],
                "type": "custom",
            },
            "ghibli": {
                "path": model_dir + "/laur_style/nq6akX1CIp/ghibli.bin",
                "text": ["ghibli style"],
                "weight": 1,
                "negativePrompt": [""],
                "type": "custom",
            },
            "eQAmnK2kB2": {
                "path": model_dir + "/laur_style/eQAmnK2kB2/final_lora.safetensors",
                "text": ["eQAmnK2kB2 style"],
                "weight": 0.5,
                "negativePrompt": [""],
                "type": "custom",
            },
            "to8contrast": {
                "path": model_dir + "/laur_style/rpjgusOgqD/final_lora.bin",
                "text": ["to8contrast style"],
                "weight": 0.5,
                "negativePrompt": [""],
                "type": "custom",
            },
            "sfrrfz8vge": {
                "path": model_dir + "/laur_style/replicate/sfrrfz8vge.safetensors",
                "text": ["sfrrfz8vge style"],
                "weight": 1.2,
                "negativePrompt": [""],
                "type": "custom",
            },
        }

    def __verify(self):
        """Verify that every style's LoRA file exists on disk; raise otherwise."""
        for key, style in self.__styles.items():
            if not os.path.exists(style["path"]):
                raise Exception(
                    f"Lora style model {key} not found at path: {style['path']}"
                )

    @staticmethod
    def unload_lora_weights(pipe):
        # Reset the attention processors (PyTorch 2.0 default) and undo the
        # text-encoder monkey patch left behind by LoRA loading.
        pipe.unet.set_attn_processor(AttnProcessor2_0())
        pipe._remove_text_encoder_monkey_patch()
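

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module). Assumptions:
    # a diffusers StableDiffusionPipeline loads locally, and `model_dir`
    # points at a directory containing the default "laur_style" folders;
    # the model id and paths below are illustrative only.
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

    lora_style = LoraStyle()
    lora_style.load("/path/to/model_dir")  # hypothetical model directory

    patcher = lora_style.get_patcher(pipe, "ghibli")
    patcher.patch()
    prompt = lora_style.prepend_style_to_prompt("a castle on a hill", "ghibli")
    image = pipe(prompt, **patcher.kwargs()).images[0]
    patcher.cleanup()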