from typing import List, Optional

from internals.data.task import Task
from internals.pipelines.commons import Text2Img
from internals.pipelines.img_classifier import ImageClassifier
from internals.pipelines.img_to_text import Image2Text
from internals.pipelines.prompt_modifier import PromptModifier
from internals.util.anomaly import remove_colors
from internals.util.avatar import Avatar
from internals.util.config import get_num_return_sequences
from internals.util.lora_style import LoraStyle


def _apply_avatar_and_style(
    prompt: str, avatar: Avatar, lora_style: LoraStyle, style
) -> str:
    """Run *prompt* through the avatar/style decoration shared by all builders.

    Applies, in order: ``avatar.add_code_names``,
    ``lora_style.prepend_style_to_prompt`` and
    ``lora_style.append_style_to_prompt`` (the latter two with *style*).
    """
    prompt = avatar.add_code_names(prompt)
    prompt = lora_style.prepend_style_to_prompt(prompt, style)
    return lora_style.append_style_to_prompt(prompt, style)


def get_patched_prompt(
    task: Task,
    avatar: Avatar,
    lora_style: LoraStyle,
    prompt_modifier: PromptModifier,
):
    """Build the styled prompt batch plus its un-engineered counterpart.

    Returns:
        A ``(prompt, ori_prompt)`` tuple of lists of styled prompts.
        ``prompt`` holds the output of ``prompt_modifier.modify`` when prompt
        engineering is enabled, otherwise the task prompt repeated
        ``get_num_return_sequences()`` times. ``ori_prompt`` always holds the
        raw task prompt repeated ``get_num_return_sequences()`` times.
    """
    style = task.get_style()
    raw_prompt = task.get_prompt()
    num_sequences = get_num_return_sequences()

    if task.is_prompt_engineering():
        candidates = prompt_modifier.modify(raw_prompt)
    else:
        candidates = [raw_prompt] * num_sequences

    # Every element of ori_prompt is the same string, so style it only once.
    ori_prompt = [
        _apply_avatar_and_style(raw_prompt, avatar, lora_style, style)
    ] * num_sequences
    # Build a new list instead of mutating ``candidates`` in place: it may be
    # a list owned by the prompt modifier (NOTE(review): confirm ownership).
    prompt = [
        _apply_avatar_and_style(candidate, avatar, lora_style, style)
        for candidate in candidates
    ]

    print({"prompts": prompt})
    return (prompt, ori_prompt)


def get_patched_prompt_text2img(
    task: Task,
    avatar: Avatar,
    lora_style: LoraStyle,
    prompt_modifier: PromptModifier,
) -> Text2Img.Params:
    """Build ``Text2Img.Params`` for the task's (possibly engineered) prompt.

    If both left and right prompts are set, each candidate produced by the
    modifier contributes its extra text (candidate with the base prompt text
    removed), appended to the styled base, left and right prompts.
    Otherwise ``prompt`` is the styled base prompt repeated
    ``get_num_return_sequences()`` times and ``modified_prompt`` holds the
    styled candidates.
    """
    style = task.get_style()
    base_prompt = task.get_prompt()

    def style_prompt(prompt: str, prepend: str = "") -> str:
        return prepend + _apply_avatar_and_style(prompt, avatar, lora_style, style)

    # Candidate generation is identical for both code paths below.
    if task.is_prompt_engineering():
        mod_prompt = prompt_modifier.modify(base_prompt)
    else:
        mod_prompt = [base_prompt] * get_num_return_sequences()

    prompt_left = task.get_prompt_left()
    prompt_right = task.get_prompt_right()
    if prompt_left and prompt_right:
        prepend = ""  # previously considered: "2characters, "
        # The styled base prompts do not depend on the candidate; compute once.
        styled_base = style_prompt(base_prompt, prepend)
        styled_left = style_prompt(prompt_left, prepend)
        styled_right = style_prompt(prompt_right, prepend)
        # The modifier's additions: each candidate minus the original prompt.
        extras = [candidate.replace(base_prompt, "") for candidate in mod_prompt]
        params = Text2Img.Params(
            prompt=[styled_base + extra for extra in extras],
            prompt_left=[styled_left + extra for extra in extras],
            prompt_right=[styled_right + extra for extra in extras],
        )
    else:
        styled_candidates = [style_prompt(candidate) for candidate in mod_prompt]
        params = Text2Img.Params(
            prompt=[style_prompt(base_prompt)] * get_num_return_sequences(),
            modified_prompt=styled_candidates,
        )
    print(params)
    return params


def get_patched_prompt_tile_upscale(
    task: Task,
    avatar: Avatar,
    lora_style: LoraStyle,
    img_classifier: ImageClassifier,
    img2text: Image2Text,
    is_sdxl=False,
) -> str:
    """Build the single prompt string used for tile upscaling.

    The base prompt is chosen in this order:
      - If the task prompt has a BLIP-merge placeholder, the result of
        ``task.PROMPT.merge_blip`` applied to an ``img2text`` caption.
      - Otherwise the task prompt.
      - Otherwise, if the task prompt is empty, an ``img2text`` caption.

    For non-SDXL pipelines, ``remove_colors`` is then applied. Avatar code
    names and the LoRA style are applied next. When the task has no style,
    the ``img_classifier`` class name is prefixed. The result is stripped.
    """
    image_url = task.get_imageUrl()

    if task.PROMPT.has_placeholder_blip_merge():
        # The merged prompt replaces any user/caption prompt entirely, so
        # caption the image only once instead of possibly twice.
        prompt = task.PROMPT.merge_blip(img2text.process(image_url))
    else:
        prompt = task.get_prompt() or img2text.process(image_url)

    # remove anomalies in prompt (only for non-SDXL pipelines)
    if not is_sdxl:
        prompt = remove_colors(prompt)

    style = task.get_style()
    prompt = _apply_avatar_and_style(prompt, avatar, lora_style, style)

    if style:
        class_name = ""
    else:
        class_name = img_classifier.classify(
            image_url, task.get_width(), task.get_height()
        )
    # strip() removes the separator space when class_name is empty.
    prompt = (class_name + " " + prompt).strip()

    print({"prompt": prompt})
    return prompt