|
import json |
|
import os |
|
from pathlib import Path |
|
from typing import Any, Dict, List, Union |
|
|
|
import boto3 |
|
import torch |
|
from diffusers.models.attention_processor import AttnProcessor2_0 |
|
from lora_diffusion import patch_pipe, tune_lora_scale |
|
from pydash import chain |
|
|
|
from internals.data.dataAccessor import getStyles |
|
from internals.util.commons import download_file |
|
from internals.util.config import get_is_sdxl |
|
|
|
|
|
class LoraStyle:
    """Registry of LoRA "styles" fetched from the backend, plus patchers that apply them.

    Typical flow (as used by the methods below): ``load(model_dir)`` fetches the style
    records, downloads missing LoRA files and verifies they exist on disk;
    ``get_patcher(pipe, key)`` returns an object whose ``patch()`` applies the style
    and ``cleanup()`` reverts it; ``prepend_style_to_prompt`` adds the style's text.
    """

    class LoraPatcher:
        """Applies a ``lora_diffusion``-format style (``.pt`` / ``.safetensors``)."""

        def __init__(self, pipe, style: Dict[str, Any]):
            # `pipe` is a list of pipelines; LoraStyle.get_patcher always passes a list.
            self.__style = style
            self.pipe = pipe

        @torch.inference_mode()
        def patch(self):
            """Inject the style's LoRA into every pipe and scale it to the style weight.

            Style files whose path does not end in ``.pt`` or ``.safetensors`` are
            skipped entirely.
            """

            def run(pipe):
                path = self.__style["path"]
                if str(path).endswith((".pt", ".safetensors")):
                    patch_pipe(pipe, path)
                    tune_lora_scale(pipe.unet, self.__style["weight"])
                    tune_lora_scale(pipe.text_encoder, self.__style["weight"])

            for p in self.pipe:
                run(p)

        def kwargs(self):
            """Extra pipeline-call kwargs for this style (none for this format)."""
            return {}

        def cleanup(self):
            """Neutralize the style by setting the LoRA scale of unet/text_encoder to 0.

            This method does not itself remove the injected layers.
            """

            def run(pipe):
                tune_lora_scale(pipe.unet, 0.0)
                tune_lora_scale(pipe.text_encoder, 0.0)

            for p in self.pipe:
                run(p)

    class LoraDiffuserPatcher:
        """Applies a diffusers-format LoRA through ``pipe.load_lora_weights``."""

        def __init__(self, pipe, style: Dict[str, Any]):
            # `pipe` is a list of pipelines; LoraStyle.get_patcher always passes a list.
            self.__style = style
            self.pipe = pipe

        @torch.inference_mode()
        def patch(self):
            """Load the style's LoRA file (split into directory + file name) into every pipe."""

            def run(pipe):
                path = self.__style["path"]
                pipe.load_lora_weights(
                    os.path.dirname(path), weight_name=os.path.basename(path)
                )

            for p in self.pipe:
                run(p)

        def kwargs(self):
            """Extra pipeline-call kwargs for this style (currently none)."""
            # NOTE(review): the style's "weight" is never applied for diffusers LoRAs;
            # diffusers usually takes the scale at call time (e.g. via
            # cross_attention_kwargs={"scale": w}) — confirm returning {} is intended.
            return {}

        def cleanup(self):
            """Unload the LoRA weights from every pipe."""

            def run(pipe):
                LoraStyle.unload_lora_weights(pipe)

            for p in self.pipe:
                run(p)

    class EmptyLoraPatcher:
        """Patcher used when no style matches; it only resets previously applied LoRAs."""

        def __init__(self, pipe):
            # `pipe` is a list of pipelines; LoraStyle.get_patcher always passes a list.
            self.pipe = pipe

        def patch(self):
            "Patch will act as cleanup, to tune down any corrupted lora"
            self.cleanup()

        def kwargs(self):
            """Extra pipeline-call kwargs (none)."""
            return {}

        def cleanup(self):
            """Zero the lora_diffusion scales and unload diffusers LoRA weights on every pipe."""

            def run(pipe):
                tune_lora_scale(pipe.unet, 0.0)
                tune_lora_scale(pipe.text_encoder, 0.0)
                LoraStyle.unload_lora_weights(pipe)

            for p in self.pipe:
                run(p)

    def __init__(self):
        # Start with an empty table so lookups before/without a successful fetch
        # (e.g. getStyles() returning None) don't raise AttributeError.
        self.__styles: Dict[str, Dict[str, Any]] = {}

    def load(self, model_dir: str):
        """Remember `model_dir` and fetch, download and verify the style table."""
        self.model = model_dir
        self.fetch_styles()

    def fetch_styles(self):
        """(Re)load styles from ``getStyles()`` and verify their files exist.

        If ``getStyles()`` returns ``None`` the current table is kept unchanged
        (empty on a fresh instance).

        Raises:
            Exception: if a style's LoRA file is missing on disk (see ``__verify``).
        """
        model_dir = self.model
        result = getStyles()
        if result is not None:
            self.__styles = self.__parse_styles(model_dir, result["data"])
        if not self.__styles:
            print("Warning: No styles found for Model")
        self.__verify()

    def prepend_style_to_prompt(self, prompt: str, key: str) -> str:
        """Return `prompt` prefixed with the style's text, comma-separated.

        Unknown keys return `prompt` unchanged. The style ``text`` may be a list of
        fragments or a single string; an empty list adds nothing.
        """
        style = self.__styles.get(key)
        if style is None:
            return prompt
        texts = style["text"]
        if isinstance(texts, str):
            # A bare string must not be joined character by character.
            texts = [texts]
        return ", ".join([*texts, prompt])

    def get_patcher(
        self, pipe: Union[Any, List], key: str
    ) -> Union[LoraPatcher, LoraDiffuserPatcher, EmptyLoraPatcher]:
        """Return a LoRA patcher for the style `key` applied to `pipe`.

        `pipe` may be a single pipeline or a list/tuple of pipelines. Styles with
        ``type == "diffuser"`` get a ``LoraDiffuserPatcher``, other known styles a
        ``LoraPatcher``, and unknown keys an ``EmptyLoraPatcher``.
        """
        if isinstance(pipe, list):
            pipes = pipe
        elif isinstance(pipe, tuple):
            pipes = list(pipe)
        else:
            pipes = [pipe]

        style = self.__styles.get(key)
        if style is None:
            return self.EmptyLoraPatcher(pipes)
        if style["type"] == "diffuser":
            return self.LoraDiffuserPatcher(pipes, style)
        return self.LoraPatcher(pipes, style)

    def __parse_styles(self, model_dir: str, data: List[Dict]) -> Dict:
        """Build the ``tag -> style`` table from backend records.

        Records are de-duplicated by ``tag`` (pydash ``uniq_by`` keeps the first
        occurrence). Records with ``attributes`` of ``None`` or without a ``path``
        attribute are skipped. Each LoRA file is cached under ``~/.cache/lora`` by
        its base name and only downloaded when absent; an existing file is trusted
        as-is. `model_dir` is currently unused.
        """
        styles = {}
        download_dir = Path.home() / ".cache" / "lora"
        # parents=True: ~/.cache may not exist yet (fresh machine / container).
        download_dir.mkdir(parents=True, exist_ok=True)
        data = chain(data).uniq_by(lambda x: x["tag"]).value()
        for item in data:
            if item["attributes"] is None:
                continue
            attr = json.loads(item["attributes"])
            if "path" not in attr:
                continue

            s3_uri = attr["path"]
            file_path = download_dir / s3_uri.split("/")[-1]
            if not file_path.exists():
                download_file(s3_uri, file_path)

            styles[item["tag"]] = {
                "path": str(file_path),
                "weight": attr["weight"],
                "type": attr["type"],
                "text": attr["text"],
                "negativePrompt": attr["negativePrompt"],
            }
        return styles

    def __verify(self):
        "A method to verify if lora exists within the required path otherwise throw error"
        for tag, style in self.__styles.items():
            if not os.path.exists(style["path"]):
                raise Exception(
                    f"Lora style model {tag} not found at path: {style['path']}"
                )

    @staticmethod
    def unload_lora_weights(pipe):
        """Unload LoRA weights via the pipeline's own ``unload_lora_weights``."""
        pipe.unload_lora_weights()
|
|