|
import json |
|
import os |
|
from pathlib import Path |
|
from typing import Any, Dict, List, Union |
|
|
|
import boto3 |
|
import torch |
|
from diffusers.models.attention_processor import AttnProcessor2_0 |
|
from lora_diffusion import patch_pipe, tune_lora_scale |
|
from pydash import chain |
|
|
|
from internals.data.dataAccessor import getStyles |
|
from internals.util.commons import download_file |
|
|
|
|
|
class LoraStyle:
    """Registry of LoRA "styles" and factory for patchers that apply them.

    Call :meth:`load` first: it fills the style table either from the backend
    (``getStyles``) or from a built-in default table rooted at ``model_dir``,
    then checks that every style file exists on disk.

    Each style entry is a dict with the keys ``path``, ``weight``, ``type``,
    ``text`` and ``negativePrompt``.
    """

    class LoraPatcher:
        """Applies a style through ``lora_diffusion`` (``patch_pipe`` + scale).

        Returned by :meth:`LoraStyle.get_patcher` for every style whose
        ``type`` is not ``"diffuser"``.
        """

        def __init__(self, pipe, style: Dict[str, Any]):
            self.__style = style
            self.pipe = pipe

        @torch.inference_mode()
        def patch(self):
            """Patch ``pipe`` with the style's LoRA and set its scale.

            ``patch_pipe`` is only invoked for ``.pt`` / ``.safetensors``
            files; the scale is tuned on the UNet and text encoder regardless.
            """
            path = self.__style["path"]
            if str(path).endswith((".pt", ".safetensors")):
                patch_pipe(self.pipe, path)
            tune_lora_scale(self.pipe.unet, self.__style["weight"])
            tune_lora_scale(self.pipe.text_encoder, self.__style["weight"])

        def kwargs(self):
            """Extra keyword arguments for the pipeline call (none for this patcher)."""
            return {}

        def cleanup(self):
            """Tune the LoRA scale on UNet and text encoder down to 0.0.

            No unpatch/unload routine is called here; only the scale changes.
            """
            tune_lora_scale(self.pipe.unet, 0.0)
            tune_lora_scale(self.pipe.text_encoder, 0.0)

    class LoraDiffuserPatcher:
        """Applies a style via the pipeline's own ``load_lora_weights``.

        Returned by :meth:`LoraStyle.get_patcher` for styles whose ``type``
        is ``"diffuser"``.
        """

        def __init__(self, pipe, style: Dict[str, Any]):
            self.__style = style
            self.pipe = pipe

        @torch.inference_mode()
        def patch(self):
            """Load the style file into ``pipe`` (directory + file name split)."""
            path = self.__style["path"]
            self.pipe.load_lora_weights(
                os.path.dirname(path), weight_name=os.path.basename(path)
            )

        def kwargs(self):
            """Extra keyword arguments for the pipeline call (none for this patcher)."""
            return {}

        def cleanup(self):
            """Undo the load through :meth:`LoraStyle.unload_lora_weights`."""
            LoraStyle.unload_lora_weights(self.pipe)

    class EmptyLoraPatcher:
        """Patcher used when no style matches the requested key."""

        def __init__(self, pipe):
            self.pipe = pipe

        def patch(self):
            "Patch will act as cleanup, to tune down any corrupted lora"
            self.cleanup()

        def kwargs(self):
            """Extra keyword arguments for the pipeline call (none for this patcher)."""
            return {}

        def cleanup(self):
            """Zero any lora_diffusion scale and unload diffusers LoRA weights."""
            tune_lora_scale(self.pipe.unet, 0.0)
            tune_lora_scale(self.pipe.text_encoder, 0.0)
            LoraStyle.unload_lora_weights(self.pipe)

    def load(self, model_dir: str) -> None:
        """Remember ``model_dir`` and populate the style table.

        :param model_dir: root used by the default style table fallback.
        """
        self.model = model_dir
        self.fetch_styles()

    def fetch_styles(self) -> None:
        """(Re)build the style table and verify every style file exists.

        Uses the backend records from ``getStyles()`` when it returns a
        result, otherwise falls back to the built-in default table.

        :raises FileNotFoundError: if a style's file is missing on disk.
        """
        model_dir = self.model
        result = getStyles()
        if result is not None:
            self.__styles = self.__parse_styles(model_dir, result["data"])
        else:
            self.__styles = self.__get_default_styles(model_dir)
        self.__verify()

    def prepend_style_to_prompt(self, prompt: str, key: str) -> str:
        """Return ``prompt`` prefixed with the style's trigger text.

        The style's ``text`` entries are joined with ``", "`` and placed in
        front of ``prompt``. Unknown keys return ``prompt`` unchanged.
        """
        if key in self.__styles:
            style = self.__styles[key]
            return f"{', '.join(style['text'])}, {prompt}"
        return prompt

    def get_patcher(
        self, pipe, key: str
    ) -> Union[LoraPatcher, LoraDiffuserPatcher, EmptyLoraPatcher]:
        """Return the patcher that applies style ``key`` to ``pipe``.

        ``"diffuser"``-typed styles get a :class:`LoraDiffuserPatcher`, other
        known styles a :class:`LoraPatcher`, unknown keys an
        :class:`EmptyLoraPatcher`.
        """
        if key in self.__styles:
            style = self.__styles[key]
            if style["type"] == "diffuser":
                return self.LoraDiffuserPatcher(pipe, style)
            return self.LoraPatcher(pipe, style)
        return self.EmptyLoraPatcher(pipe)

    @staticmethod
    def __cache_path_for(download_dir: Path, uri: str) -> Path:
        """Return the local cache path for remote file ``uri``.

        Layout: ``<download_dir>/<sha256(uri)[:16]>/<basename of uri>``.
        Keying the sub-directory on the full URI stops two remote files that
        share a basename (e.g. several ``final_lora.safetensors``) from
        resolving to the same cached file, while keeping the basename so
        extension checks and ``weight_name`` lookups still work.
        """
        import hashlib

        digest = hashlib.sha256(uri.encode("utf-8")).hexdigest()[:16]
        return download_dir / digest / uri.split("/")[-1]

    def __parse_styles(self, model_dir: str, data: List[Dict]) -> Dict:
        """Build the style table from backend records, downloading files as needed.

        :param model_dir: only used for the default-table fallback.
        :param data: records carrying a ``tag`` and an ``attributes`` JSON
            string. Duplicate tags collapse to their first occurrence; records
            with null attributes or no ``path`` attribute are skipped.
        :returns: ``{tag: style}``, or the default table if nothing qualified.
        """
        styles = {}
        download_dir = Path.home() / ".cache" / "lora"
        # parents=True: ~/.cache itself may not exist on a fresh machine.
        download_dir.mkdir(parents=True, exist_ok=True)

        for item in chain(data).uniq_by(lambda x: x["tag"]).value():
            if item["attributes"] is None:
                continue
            attr = json.loads(item["attributes"])
            if "path" not in attr:
                continue

            s3_uri = attr["path"]
            file_path = self.__cache_path_for(download_dir, s3_uri)
            if not file_path.exists():
                file_path.parent.mkdir(parents=True, exist_ok=True)
                download_file(s3_uri, file_path)

            styles[item["tag"]] = {
                "path": str(file_path),
                "weight": attr["weight"],
                "type": attr["type"],
                "text": attr["text"],
                "negativePrompt": attr["negativePrompt"],
            }

        if not styles:
            return self.__get_default_styles(model_dir)
        return styles

    def __get_default_styles(self, model_dir: str) -> Dict:
        """Fallback style table; every file lives under ``<model_dir>/laur_style/``."""
        return {
            "nq6akX1CIp": {
                "path": model_dir + "/laur_style/nq6akX1CIp/final_lora.safetensors",
                "text": ["nq6akX1CIp style"],
                "weight": 0.5,
                "negativePrompt": [""],
                "type": "custom",
            },
            "ghibli": {
                "path": model_dir + "/laur_style/nq6akX1CIp/ghibli.bin",
                "text": ["ghibli style"],
                "weight": 1,
                "negativePrompt": [""],
                "type": "custom",
            },
            "eQAmnK2kB2": {
                "path": model_dir + "/laur_style/eQAmnK2kB2/final_lora.safetensors",
                "text": ["eQAmnK2kB2 style"],
                "weight": 0.5,
                "negativePrompt": [""],
                "type": "custom",
            },
            "to8contrast": {
                "path": model_dir + "/laur_style/rpjgusOgqD/final_lora.bin",
                "text": ["to8contrast style"],
                "weight": 0.5,
                "negativePrompt": [""],
                "type": "custom",
            },
            "sfrrfz8vge": {
                "path": model_dir + "/laur_style/replicate/sfrrfz8vge.safetensors",
                "text": ["sfrrfz8vge style"],
                "weight": 1.2,
                "negativePrompt": [""],
                "type": "custom",
            },
        }

    def __verify(self) -> None:
        """Ensure every style's file exists.

        :raises FileNotFoundError: (a subclass of ``Exception``, so existing
            broad handlers still catch it) naming the first missing style.
        """
        for tag, style in self.__styles.items():
            path = style["path"]
            if not os.path.exists(path):
                raise FileNotFoundError(
                    f"Lora style model {tag} not found at path: {path}"
                )

    @staticmethod
    def unload_lora_weights(pipe) -> None:
        """Reset the UNet attention processors and undo text-encoder patching.

        NOTE(review): ``_remove_text_encoder_monkey_patch`` is a private
        diffusers pipeline method and may not exist in newer diffusers
        releases; confirm against the pinned version.
        """
        pipe.unet.set_attn_processor(AttnProcessor2_0())
        pipe._remove_text_encoder_monkey_patch()
|
|