File size: 2,316 Bytes
19b3da3 35575bb 19b3da3 b71808f 19b3da3 b71808f 19b3da3 b71808f 86248f3 b71808f 19b3da3 86248f3 19b3da3 35575bb 19b3da3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
import re
from typing import List, Optional

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

from internals.util.config import get_num_return_sequences
class PromptModifier:
    """Expand a short text prompt into longer Stable Diffusion prompts.

    Wraps the ``Gustavosta/MagicPrompt-Stable-Diffusion`` causal language
    model. The model and tokenizer are loaded lazily on first use, and every
    generated prompt is post-processed to substitute blacklisted phrases.
    """

    # Class-level default; ``load`` shadows it with an instance attribute once
    # the model and tokenizer have been created for that instance.
    __loaded = False

    def __init__(self, num_of_sequences: Optional[int] = 4):
        """Set up the blacklist; no model weights are loaded here.

        Args:
            num_of_sequences: Stored on the instance for backward
                compatibility only. ``modify`` takes the number of returned
                prompts from ``get_num_return_sequences()``.
        """
        # Maps a blacklisted phrase to the text that replaces it.
        self.__blacklist = {"alphonse mucha": "", "adolphe bouguereau": ""}
        self.__num_of_sequences = num_of_sequences

    def load(self):
        """Load the model and tokenizer once; later calls return immediately."""
        if self.__loaded:
            return

        self.prompter_model = AutoModelForCausalLM.from_pretrained(
            "Gustavosta/MagicPrompt-Stable-Diffusion"
        )
        self.prompter_tokenizer = AutoTokenizer.from_pretrained(
            "Gustavosta/MagicPrompt-Stable-Diffusion"
        )
        # Reuse EOS as the padding token and pad on the left so that
        # generation always continues from the real end of the prompt.
        self.prompter_tokenizer.pad_token = self.prompter_tokenizer.eos_token
        self.prompter_tokenizer.padding_side = "left"

        self.__loaded = True

    def modify(self, text: str, num_of_sequences: Optional[int] = None) -> List[str]:
        """Generate expanded variants of ``text`` using beam search.

        Args:
            text: The seed prompt; surrounding whitespace is stripped.
            num_of_sequences: Accepted for backward compatibility but not
                consulted; the number of returned prompts is taken from
                ``get_num_return_sequences()``.

        Returns:
            The decoded prompts with blacklisted phrases substituted.
        """
        self.load()

        eos_id = self.prompter_tokenizer.eos_token_id

        # NOTE(review): an earlier draft tokenized a list of restricted words
        # ("octane", "cyber") here, presumably to ban them during generation.
        # That path is disabled; filtering happens after decoding instead.

        # Assumes the config helper returns an int -- TODO confirm.
        num_return_sequences = get_num_return_sequences()

        generation_config = GenerationConfig(
            do_sample=False,
            max_new_tokens=75,
            # Beam search cannot return more sequences than it has beams, so
            # widen the beam when the configured count exceeds the default 4.
            num_beams=max(4, num_return_sequences),
            num_return_sequences=num_return_sequences,
            eos_token_id=eos_id,
            pad_token_id=eos_id,
            # A negative length penalty favours shorter beam hypotheses.
            length_penalty=-1.0,
        )

        encoded = self.prompter_tokenizer(text.strip(), return_tensors="pt")
        # Pass the attention mask explicitly: the pad token equals the EOS
        # token, so the model cannot reliably infer the mask on its own.
        outputs = self.prompter_model.generate(
            encoded.input_ids,
            attention_mask=encoded.attention_mask,
            generation_config=generation_config,
        )
        output_texts = self.prompter_tokenizer.batch_decode(
            outputs, skip_special_tokens=True
        )
        return self.__patch_blacklist_words(output_texts)

    def __patch_blacklist_words(self, texts: List[str]) -> List[str]:
        """Substitute every blacklisted phrase in ``texts``, ignoring case.

        Args:
            texts: Generated prompts to clean.

        Returns:
            A new list with each blacklisted phrase replaced by its mapped
            value. Phrases are applied in the blacklist's insertion order.
        """

        def scrub(text: str) -> str:
            for phrase, replacement in self.__blacklist.items():
                # ``re.escape`` keeps the phrase literal; the callable keeps
                # the replacement literal too (no backslash/group expansion).
                text = re.sub(
                    re.escape(phrase),
                    lambda _match, _repl=replacement: _repl,
                    text,
                    flags=re.IGNORECASE,
                )
            return text

        return [scrub(text) for text in texts]
|