File size: 4,189 Bytes
1bc457e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from typing import List, Optional

from internals.data.task import Task
from internals.pipelines.commons import Text2Img
from internals.pipelines.img_classifier import ImageClassifier
from internals.pipelines.img_to_text import Image2Text
from internals.pipelines.prompt_modifier import PromptModifier
from internals.util.anomaly import remove_colors
from internals.util.avatar import Avatar
from internals.util.config import num_return_sequences
from internals.util.lora_style import LoraStyle


def get_patched_prompt(
    task: Task,
    avatar: Avatar,
    lora_style: LoraStyle,
    prompt_modifier: PromptModifier,
):
    """Build the styled prompt list and the styled original-prompt list for a task.

    When ``task.is_prompt_engineering()`` is true, the task prompt is expanded
    with ``prompt_modifier.modify``. Otherwise it is repeated
    ``num_return_sequences`` times. A second list always holds the raw task
    prompt repeated ``num_return_sequences`` times.

    Every entry of both lists is passed through ``avatar.add_code_names`` and
    then ``lora_style.prepend_style_to_prompt`` with the task's style. The
    resulting prompt list is printed.

    Returns:
        ``(prompt, ori_prompt)``: two new lists of strings, the (possibly
        engineered) prompts and the styled original prompts.
    """

    def add_style_and_character(prompts):
        # Return a fresh list instead of rewriting ``prompts`` in place: the
        # input may be the very object returned by ``prompt_modifier.modify``,
        # which this function does not own.
        return [
            lora_style.prepend_style_to_prompt(
                avatar.add_code_names(text), task.get_style()
            )
            for text in prompts
        ]

    raw_prompt = task.get_prompt()

    if task.is_prompt_engineering():
        prompt = prompt_modifier.modify(raw_prompt)
    else:
        prompt = [raw_prompt] * num_return_sequences

    # Style the original prompts first, then the engineered ones, matching the
    # historical order of calls into ``avatar`` / ``lora_style``.
    ori_prompt = add_style_and_character([raw_prompt] * num_return_sequences)
    prompt = add_style_and_character(prompt)

    print({"prompts": prompt})

    return (prompt, ori_prompt)


def get_patched_prompt_text2img(
    task: Task,
    avatar: Avatar,
    lora_style: LoraStyle,
    prompt_modifier: PromptModifier,
) -> Text2Img.Params:
    """Assemble the ``Text2Img.Params`` for a text-to-image task.

    Each prompt sent to the pipeline is decorated in two steps:
    ``avatar.add_code_names`` first, then
    ``lora_style.prepend_style_to_prompt`` with the task's style. The prompt
    is expanded by ``prompt_modifier.modify`` when prompt engineering is
    enabled. Otherwise the raw prompt is repeated ``num_return_sequences``
    times.

    Two-character tasks have both a left and a right prompt. For these, the
    params carry three parallel lists (``prompt``, ``prompt_left``,
    ``prompt_right``). Each entry is a decorated base prompt followed by
    whatever remains of the expanded prompt once every occurrence of the raw
    task prompt is removed.

    Other tasks get params with ``prompt``, the decorated raw prompt repeated,
    and ``modified_prompt``, the decorated expanded prompts.

    The params object is printed before being returned.
    """

    def decorate(text: str, prefix: str = "") -> str:
        # Code names first, then the LoRA style, then the optional prefix.
        with_names = avatar.add_code_names(text)
        styled = lora_style.prepend_style_to_prompt(with_names, task.get_style())
        return prefix + styled

    two_characters = task.get_prompt_left() and task.get_prompt_right()

    raw = task.get_prompt()
    if task.is_prompt_engineering():
        expanded = prompt_modifier.modify(raw)
    else:
        expanded = [raw] * num_return_sequences

    if not two_characters:
        modified = [decorate(candidate) for candidate in expanded]
        params = Text2Img.Params(
            prompt=[decorate(raw)] * num_return_sequences,
            modified_prompt=modified,
        )
    else:
        # A "2characters, " prefix used to be applied here; it is disabled.
        prefix = ""
        base, left, right = [], [], []
        for candidate in expanded:
            # Keep only the text the modifier added around the raw prompt.
            extra = candidate.replace(raw, "")
            base.append(decorate(raw, prefix) + extra)
            left.append(decorate(task.get_prompt_left(), prefix) + extra)
            right.append(decorate(task.get_prompt_right(), prefix) + extra)
        params = Text2Img.Params(
            prompt=base,
            prompt_left=left,
            prompt_right=right,
        )

    print(params)

    return params


def get_patched_prompt_tile_upscale(
    task: Task,
    avatar: Avatar,
    lora_style: LoraStyle,
    img_classifier: ImageClassifier,
    img2text: Image2Text,
):
    """Produce the text prompt used for a tile-upscale task.

    Steps:

    1. Start from the task prompt. If it is empty, caption the task image
       with ``img2text``.
    2. If the prompt has a BLIP-merge placeholder, caption the image and
       merge the caption in via ``task.PROMPT.merge_blip``.
    3. Pass the text through ``remove_colors``.
    4. Apply ``avatar.add_code_names`` and the task's LoRA style.
    5. If the task has no style, prepend the class returned by
       ``img_classifier.classify``.

    The result is stripped of surrounding whitespace, printed and returned.
    """
    prompt = task.get_prompt() or img2text.process(task.get_imageUrl())

    # Fold an image caption into the prompt when it asks for one.
    if task.PROMPT.has_placeholder_blip_merge():
        caption = img2text.process(task.get_imageUrl())
        prompt = task.PROMPT.merge_blip(caption)

    # Strip colour words that are treated as anomalies in the prompt.
    prompt = remove_colors(prompt)

    prompt = lora_style.prepend_style_to_prompt(
        avatar.add_code_names(prompt), task.get_style()
    )

    if task.get_style():
        class_name = ""
    else:
        class_name = img_classifier.classify(
            task.get_imageUrl(), task.get_width(), task.get_height()
        )

    # With an empty class name this only leaves a leading space, removed by strip().
    prompt = (class_name + " " + prompt).strip()

    print({"prompt": prompt})

    return prompt