# CM2000112 / internals/util/lora_style.py
# Uploaded by jayparmr — "Upload folder using huggingface_hub" (commit 7fbdac4, 5.23 kB)
import json
import os
from pathlib import Path
from typing import Any, Dict, List, Union
import boto3
import torch
from diffusers.models.attention_processor import AttnProcessor2_0
from lora_diffusion import patch_pipe, tune_lora_scale
from pydash import chain
from internals.data.dataAccessor import getStyles
from internals.util.commons import download_file
from internals.util.config import get_is_sdxl
class LoraStyle:
    """Registry of LoRA "styles" fetched through ``getStyles``.

    Each style is keyed by its ``tag`` and records a local weights ``path``, a
    ``weight``, a loader ``type``, prompt ``text`` fragments and a
    ``negativePrompt``. ``get_patcher`` returns an object that applies a style
    to one or more pipelines (``patch``) and later removes it (``cleanup``).
    """

    class LoraPatcher:
        """Applies a ``lora_diffusion`` style to a list of pipelines."""

        def __init__(self, pipe, style: Dict[str, Any]):
            # `pipe` is a list of pipelines (LoraStyle.get_patcher normalizes it).
            self.__style = style
            self.pipe = pipe

        @torch.inference_mode()
        def patch(self):
            """Patch every pipeline with the style's weights and set their scale.

            Only ``.pt`` / ``.safetensors`` paths are patched. The unet and text
            encoder scales are tuned only for weights that were actually
            patched, so a stale LoRA from a previous style is not re-enabled.
            """

            def run(pipe):
                path = self.__style["path"]
                if str(path).endswith((".pt", ".safetensors")):
                    patch_pipe(pipe, path)
                    tune_lora_scale(pipe.unet, self.__style["weight"])
                    tune_lora_scale(pipe.text_encoder, self.__style["weight"])

            for p in self.pipe:
                run(p)

        def kwargs(self):
            """Extra pipeline-call kwargs needed by this patcher (none)."""
            return {}

        def cleanup(self):
            """Set the LoRA scale of each pipeline's unet and text encoder to 0.0."""

            def run(pipe):
                tune_lora_scale(pipe.unet, 0.0)
                tune_lora_scale(pipe.text_encoder, 0.0)

            for p in self.pipe:
                run(p)

    class LoraDiffuserPatcher:
        """Applies a style through the pipeline's ``load_lora_weights`` method."""

        def __init__(self, pipe, style: Dict[str, Any]):
            # `pipe` is a list of pipelines (LoraStyle.get_patcher normalizes it).
            self.__style = style
            self.pipe = pipe

        @torch.inference_mode()
        def patch(self):
            """Load the style's weight file into every pipeline."""

            def run(pipe):
                path = self.__style["path"]
                # The weight file is passed as its directory plus its file name.
                pipe.load_lora_weights(
                    os.path.dirname(path), weight_name=os.path.basename(path)
                )

            for p in self.pipe:
                run(p)

        def kwargs(self):
            """Extra pipeline-call kwargs needed by this patcher (none)."""
            return {}

        def cleanup(self):
            """Unload LoRA weights from every pipeline."""

            def run(pipe):
                LoraStyle.unload_lora_weights(pipe)

            for p in self.pipe:
                run(p)

    class EmptyLoraPatcher:
        """Patcher returned for unknown styles; it only resets LoRA state."""

        def __init__(self, pipe):
            # `pipe` is a list of pipelines (LoraStyle.get_patcher normalizes it).
            self.pipe = pipe

        def patch(self):
            """Act as cleanup, to tune down any LoRA left over from a previous style."""
            self.cleanup()

        def kwargs(self):
            """Extra pipeline-call kwargs needed by this patcher (none)."""
            return {}

        def cleanup(self):
            """Zero the lora_diffusion scales and unload LoRA weights on every pipeline."""

            def run(pipe):
                tune_lora_scale(pipe.unet, 0.0)
                tune_lora_scale(pipe.text_encoder, 0.0)
                LoraStyle.unload_lora_weights(pipe)

            for p in self.pipe:
                run(p)

    def __init__(self):
        # Start with no styles so lookups and fetch_styles work even when
        # getStyles() returns None or load() has not been called yet.
        self.__styles: Dict[str, Dict[str, Any]] = {}

    def load(self, model_dir: str):
        """Remember ``model_dir`` and fetch (and download) the styles."""
        self.model = model_dir
        self.fetch_styles()

    def fetch_styles(self):
        """Fetch style definitions via ``getStyles`` and verify their files exist.

        When ``getStyles`` returns ``None``, the previously loaded styles (empty
        on a fresh instance) are kept.

        Raises:
            Exception: if a registered style's weight file is missing on disk.
        """
        model_dir = self.model
        result = getStyles()
        if result is not None:
            self.__styles = self.__parse_styles(model_dir, result["data"])
        if not self.__styles:
            print("Warning: No styles found for Model")
        self.__verify()

    def prepend_style_to_prompt(self, prompt: str, key: str) -> str:
        """Return ``prompt`` prefixed with the ``text`` fragments of style ``key``.

        The fragments are joined with ", ". An unknown ``key`` returns
        ``prompt`` unchanged.
        """
        if key in self.__styles:
            style = self.__styles[key]
            # NOTE(review): assumes style["text"] is a list of strings; a plain
            # string would be split per character by join — confirm with the API.
            return f"{', '.join(style['text'])}, {prompt}"
        return prompt

    def get_patcher(
        self, pipe: Union[Any, List], key: str
    ) -> Union[LoraPatcher, LoraDiffuserPatcher, EmptyLoraPatcher]:
        """Return a LoRA patcher for style ``key`` applied to ``pipe``.

        ``pipe`` may be a single pipeline or a list/tuple of pipelines; a list
        is passed through as-is. A style of type ``"diffuser"`` gets a
        ``LoraDiffuserPatcher``, any other known style a ``LoraPatcher``, and an
        unknown key an ``EmptyLoraPatcher`` that resets LoRA state.
        """
        if isinstance(pipe, tuple):
            pipe = list(pipe)
        elif not isinstance(pipe, list):
            pipe = [pipe]

        style = self.__styles.get(key)
        if style is None:
            return self.EmptyLoraPatcher(pipe)
        if style["type"] == "diffuser":
            return self.LoraDiffuserPatcher(pipe, style)
        return self.LoraPatcher(pipe, style)

    def __parse_styles(self, model_dir: str, data: List[Dict]) -> Dict:
        """Build the ``tag -> style`` mapping, downloading weight files as needed.

        Entries are de-duplicated by ``tag`` (pydash ``uniq_by``). Entries whose
        ``attributes`` is None, or whose decoded attributes lack ``path``, are
        skipped. Weight files are cached in ``~/.cache/lora`` under the last
        segment of ``path`` and downloaded only when not already present.
        ``model_dir`` is currently unused.
        """
        styles = {}
        download_dir = Path.home() / ".cache" / "lora"
        # parents=True: ~/.cache itself may not exist on a fresh machine.
        download_dir.mkdir(parents=True, exist_ok=True)
        data = chain(data).uniq_by(lambda x: x["tag"]).value()
        for item in data:
            if item["attributes"] is None:
                continue
            attr = json.loads(item["attributes"])
            if "path" not in attr:
                continue
            # NOTE(review): the cache file is keyed only by basename, so two
            # styles whose remote paths share a file name share one cached file.
            file_path = download_dir / attr["path"].split("/")[-1]
            if not file_path.exists():
                s3_uri = attr["path"]
                download_file(s3_uri, file_path)
            styles[item["tag"]] = {
                "path": str(file_path),
                "weight": attr["weight"],
                "type": attr["type"],
                "text": attr["text"],
                "negativePrompt": attr["negativePrompt"],
            }
        return styles

    def __verify(self):
        """Raise if any registered style's weight file is missing on disk."""
        for tag, style in self.__styles.items():
            if not os.path.exists(style["path"]):
                raise Exception(
                    f"Lora style model {tag} not found at path: {style['path']}"
                )

    @staticmethod
    def unload_lora_weights(pipe):
        """Call ``pipe.unload_lora_weights()``."""
        pipe.unload_lora_weights()