File size: 4,189 Bytes
1bc457e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
from typing import List, Optional
from internals.data.task import Task
from internals.pipelines.commons import Text2Img
from internals.pipelines.img_classifier import ImageClassifier
from internals.pipelines.img_to_text import Image2Text
from internals.pipelines.prompt_modifier import PromptModifier
from internals.util.anomaly import remove_colors
from internals.util.avatar import Avatar
from internals.util.config import num_return_sequences
from internals.util.lora_style import LoraStyle
def get_patched_prompt(
    task: Task,
    avatar: Avatar,
    lora_style: LoraStyle,
    prompt_modifier: PromptModifier,
):
    """Build the styled prompt lists for a task.

    Args:
        task: Supplies the raw prompt, the style and the prompt-engineering flag.
        avatar: Substitutes character code names into each prompt.
        lora_style: Prepends the task's style to each prompt.
        prompt_modifier: Expands the raw prompt when prompt engineering is on.

    Returns:
        ``(prompt, ori_prompt)`` where ``prompt`` is the output of
        ``prompt_modifier.modify`` (or the raw prompt repeated
        ``num_return_sequences`` times when prompt engineering is off) and
        ``ori_prompt`` is the raw prompt repeated ``num_return_sequences``
        times. Both lists have code names and style applied.
    """
    style = task.get_style()

    def add_style_and_character(prompt: str) -> str:
        prompt = avatar.add_code_names(prompt)
        return lora_style.prepend_style_to_prompt(prompt, style)

    raw_prompt = task.get_prompt()
    if task.is_prompt_engineering():
        candidates = prompt_modifier.modify(raw_prompt)
    else:
        candidates = [raw_prompt] * num_return_sequences

    # Build fresh lists instead of mutating in place, so the list returned by
    # prompt_modifier.modify() is never altered. The original prompts are
    # styled first to keep the previous call order.
    ori_prompt = [add_style_and_character(raw_prompt) for _ in range(num_return_sequences)]
    prompt = [add_style_and_character(p) for p in candidates]

    print({"prompts": prompt})
    return (prompt, ori_prompt)
def get_patched_prompt_text2img(
    task: Task,
    avatar: Avatar,
    lora_style: LoraStyle,
    prompt_modifier: PromptModifier,
) -> Text2Img.Params:
    """Build ``Text2Img.Params`` for a task.

    Candidate prompts come from ``prompt_modifier.modify`` when prompt
    engineering is enabled. Otherwise the raw prompt is repeated
    ``num_return_sequences`` times.

    When the task has both a left and a right prompt, three parallel lists
    are produced: ``prompt``, ``prompt_left`` and ``prompt_right``. Each
    entry is the corresponding styled prompt followed by whatever text
    remains of a candidate after the raw prompt is removed from it.

    Otherwise ``prompt`` is the styled raw prompt repeated
    ``num_return_sequences`` times, and ``modified_prompt`` is the styled
    candidates.
    """

    def add_style_and_character(prompt: str, prepend: str = "") -> str:
        prompt = avatar.add_code_names(prompt)
        prompt = lora_style.prepend_style_to_prompt(prompt, task.get_style())
        return prepend + prompt

    base_prompt = task.get_prompt()
    # Both branches start from the same candidate list; compute it once.
    if task.is_prompt_engineering():
        mod_prompt = prompt_modifier.modify(base_prompt)
    else:
        mod_prompt = [base_prompt] * num_return_sequences

    if task.get_prompt_left() and task.get_prompt_right():
        # prepend = "2characters, "  (alternative prefix, currently disabled)
        prepend = ""
        # These do not depend on the candidate, so style them once.
        styled = add_style_and_character(base_prompt, prepend)
        styled_left = add_style_and_character(task.get_prompt_left(), prepend)
        styled_right = add_style_and_character(task.get_prompt_right(), prepend)
        # Remove the raw prompt text from each candidate, keeping only the
        # extra text the modifier added, and append that to each styled prompt.
        extras = [mp.replace(base_prompt, "") for mp in mod_prompt]
        params = Text2Img.Params(
            prompt=[styled + extra for extra in extras],
            prompt_left=[styled_left + extra for extra in extras],
            prompt_right=[styled_right + extra for extra in extras],
        )
    else:
        modified = [add_style_and_character(mp) for mp in mod_prompt]
        params = Text2Img.Params(
            prompt=[add_style_and_character(base_prompt)] * num_return_sequences,
            modified_prompt=modified,
        )
    print(params)
    return params
def get_patched_prompt_tile_upscale(
    task: Task,
    avatar: Avatar,
    lora_style: LoraStyle,
    img_classifier: ImageClassifier,
    img2text: Image2Text,
):
    """Build the prompt used for tile upscaling.

    Prompt selection:

    - The base prompt is the task's prompt. If that is empty, it is the output
      of ``img2text.process`` on the task's image.
    - If ``task.PROMPT`` has a blip-merge placeholder, the prompt is replaced
      by ``task.PROMPT.merge_blip(caption)``. The image caption is computed
      at most once across both uses.

    Post-processing:

    - The prompt is passed through ``remove_colors`` (anomaly cleanup).
    - Code names are substituted.
    - The task's style is prepended.
    - When the task has no style, the image classifier's result is
      prepended as well.

    Returns:
        The final prompt string with surrounding whitespace stripped.
    """
    # img2text caption; cached so the captioning model is not run twice on
    # the same image when both branches below need it.
    caption = None

    prompt = task.get_prompt()
    if not prompt:
        caption = img2text.process(task.get_imageUrl())
        prompt = caption

    # merge blip
    if task.PROMPT.has_placeholder_blip_merge():
        if caption is None:
            caption = img2text.process(task.get_imageUrl())
        prompt = task.PROMPT.merge_blip(caption)

    # remove anomalies in prompt
    prompt = remove_colors(prompt)
    prompt = avatar.add_code_names(prompt)
    prompt = lora_style.prepend_style_to_prompt(prompt, task.get_style())

    if not task.get_style():
        class_name = img_classifier.classify(
            task.get_imageUrl(), task.get_width(), task.get_height()
        )
    else:
        class_name = ""
    # strip() drops the separator space when class_name is empty.
    prompt = (class_name + " " + prompt).strip()

    print({"prompt": prompt})
    return prompt
|