|
from transformers import pipeline, set_seed, AutoTokenizer, AutoModelForSeq2SeqLM
|
import random |
|
import re |
|
|
|
import os,sys |
|
import folder_paths |
|
|
|
|
|
import importlib.util |
|
|
|
import comfy.utils |
|
|
|
import torch |
|
|
from lark import Lark, Transformer, v_args |
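
# ComfyUI prompt nodes: ChinesePrompt parses Chinese prompt text with a Lark grammar,
# translates it to English (Helsinki-NLP/opus-mt-zh-en) and can expand it with
# succinctly/text2image-prompt-generator; PromptGenerate only does the expansion.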
|
|
|
|
|
# Tracks whether optional dependencies (sentencepiece) are available.
_available = True
|
|
|
def get_text_generator_path():
    # Prefer a configured "prompt_generator" folder; otherwise use the default models dir.
    try:
        return folder_paths.get_folder_paths('prompt_generator')[0]
    except Exception:
        return os.path.join(folder_paths.models_dir, "prompt_generator")
|
|
|
prompt_generator=get_text_generator_path() |
|
|
|
text_generator_model_path = os.path.join(prompt_generator, "text2image-prompt-generator")
if not os.path.exists(text_generator_model_path):
    # Fall back to the Hugging Face Hub id when no local copy has been downloaded.
    print(f"## text_generator_model not found: {text_generator_model_path}, please download from https://huggingface.co/succinctly/text2image-prompt-generator/tree/main")
    text_generator_model_path = 'succinctly/text2image-prompt-generator'
|
|
|
zh_en_model_path = os.path.join(prompt_generator, "opus-mt-zh-en")
if not os.path.exists(zh_en_model_path):
    # Fall back to the Hugging Face Hub id when no local copy has been downloaded.
    print(f"## zh_en_model not found: {zh_en_model_path}, please download from https://huggingface.co/Helsinki-NLP/opus-mt-zh-en/tree/main")
    zh_en_model_path = 'Helsinki-NLP/opus-mt-zh-en'
|
|
|
|
|
def is_installed(package): |
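    """Return True if *package* can be located by importlib without importing it."""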
|
try: |
|
spec = importlib.util.find_spec(package) |
|
except ModuleNotFoundError: |
|
return False |
|
return spec is not None |
|
|
|
|
|
try: |
|
    if not is_installed('sentencepiece'):
|
import subprocess |
|
|
|
|
|
print('#pip install sentencepiece') |
|
|
|
result = subprocess.run([sys.executable, '-s', '-m', 'pip', 'install', 'sentencepiece'], capture_output=True, text=True) |
|
|
|
|
|
if result.returncode == 0 and is_installed('sentencepiece'): |
|
print("#install success") |
|
_available=True |
|
else: |
|
print("#install error") |
|
_available=False |
|
|
|
else: |
|
_available=True |
|
|
|
except Exception:
|
_available=False |
|
|
|
|
|
|
|
def translate(text): |
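    """Translate Chinese text to English with the opus-mt-zh-en model (loaded lazily)."""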
|
global text_pipe,zh_en_model,zh_en_tokenizer |
|
|
|
    if zh_en_model is None:
|
zh_en_model = AutoModelForSeq2SeqLM.from_pretrained(zh_en_model_path).eval() |
|
zh_en_tokenizer = AutoTokenizer.from_pretrained(zh_en_model_path,padding=True, truncation=True) |
|
|
|
zh_en_model.to("cuda" if torch.cuda.is_available() else "cpu") |
|
with torch.no_grad(): |
|
encoded = zh_en_tokenizer([text], return_tensors="pt") |
|
encoded.to(zh_en_model.device) |
|
sequences = zh_en_model.generate(**encoded) |
|
return zh_en_tokenizer.batch_decode(sequences, skip_special_tokens=True)[0] |
|
|
|
|
|
|
|
|
|
|
|
def text_generate(text_pipe, input_text, seed=None):
    if seed is None:
        seed = random.randint(100, 1000000)

    set_seed(seed)

    # Try up to 6 times to obtain at least one usable generated prompt.
    for _ in range(6):
        sequences = text_pipe(input_text, max_length=random.randint(60, 90), num_return_sequences=8)
        lines = []
        for sequence in sequences:
            line = sequence['generated_text'].strip()
            # Keep generations that add content and do not end mid-phrase.
            if line != input_text and len(line) > (len(input_text) + 4) and not line.endswith((":", "-", "—")):
                lines.append(line)

        result = "\n".join(lines)
        # Strip URL-like tokens and angle brackets from the generated text.
        result = re.sub(r'[^ ]+\.[^ ]+', '', result)
        result = result.replace("<", "").replace(">", "")
        if result != "":
            return result

    # Every attempt produced an empty result.
    return ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def correct_prompt_syntax(prompt=""): |
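    """Normalize punctuation and balance (), [], and <> brackets in each comma-separated prompt element."""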
|
|
|
|
|
corrected_elements = [] |
|
|
|
    # Normalize full-width (Chinese) punctuation to the ASCII equivalents used by prompt syntax.
    prompt = prompt.replace('（', '(').replace('）', ')').replace('，', ',').replace('；', ',').replace('。', '.').replace('：', ':')
|
|
|
prompt = re.sub(r'\s+', ' ', prompt).strip() |
|
prompt = prompt.replace("< ","<").replace(" >",">").replace("( ","(").replace(" )",")").replace("[ ","[").replace(' ]',']') |
|
|
|
|
|
prompt_elements = prompt.split(',') |
|
|
|
def balance_brackets(element, open_bracket, close_bracket): |
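        """Append missing closing brackets so that open and close counts match."""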
|
open_brackets_count = element.count(open_bracket) |
|
close_brackets_count = element.count(close_bracket) |
|
return element + close_bracket * (open_brackets_count - close_brackets_count) |
|
|
|
for element in prompt_elements: |
|
element = element.strip() |
|
|
|
|
|
if not element: |
|
continue |
|
|
|
|
|
if element[0] in '([': |
|
corrected_element = balance_brackets(element, '(', ')') if element[0] == '(' else balance_brackets(element, '[', ']') |
|
elif element[0] == '<': |
|
corrected_element = balance_brackets(element, '<', '>') |
|
else: |
|
|
|
corrected_element = element.lstrip(')]') |
|
|
|
corrected_elements.append(corrected_element) |
|
|
|
|
|
return ','.join(corrected_elements) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def detect_language(input_str): |
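    """Return "cn" or "en" depending on whether CJK or alphabetic characters dominate, else "unknown"."""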
|
|
|
count_cn = count_en = 0 |
|
for char in input_str: |
|
if '\u4e00' <= char <= '\u9fff': |
|
count_cn += 1 |
|
elif char.isalpha(): |
|
count_en += 1 |
|
|
|
|
|
if count_cn > count_en: |
|
return "cn" |
|
elif count_en > count_cn: |
|
return "en" |
|
else: |
|
return "unknow" |
|
|
|
|
|
|
|
|
|
|
|
grammar = """ |
|
start: sentence |
|
sentence: phrase ("," phrase)* |
|
phrase: emphasis | weight | word | lora | embedding | schedule |
|
emphasis: "(" sentence ")" -> emphasis |
|
| "[" sentence "]" -> weak_emphasis |
|
weight: "(" word ":" NUMBER ")" |
|
schedule: "[" word ":" word ":" NUMBER "]" |
|
lora: "<" WORD ":" WORD (":" NUMBER)? (":" NUMBER)? ">" |
|
embedding: "embedding" ":" WORD (":" NUMBER)? (":" NUMBER)? |
|
word: WORD |
|
|
|
NUMBER: /\s*-?\d+(\.\d+)?\s*/ |
|
WORD: /[^,:\(\)\[\]<>]+/ |
|
""" |
|
|
|
|
|
|
|
@v_args(inline=True) |
|
class ChinesePromptTranslate(Transformer): |
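    """Lark Transformer that rebuilds each prompt phrase, translating Chinese words to English along the way."""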
|
|
|
def sentence(self, *args): |
|
return ", ".join(args) |
|
|
|
def phrase(self, *args): |
|
return "".join(args) |
|
|
|
def emphasis(self, *args): |
|
|
|
return "(" + "".join(args) + ")" |
|
|
|
def weak_emphasis(self, *args): |
|
print('weak_emphasis:',args) |
|
return "[" + "".join(args) + "]" |
|
|
|
def embedding(self,*args): |
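        """Rebuild an embedding:name reference, keeping any optional numeric arguments."""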
|
print('prompt embedding',args[0]) |
|
if len(args) == 1: |
|
|
|
|
|
embedding_name = str(args[0]) |
|
return f"embedding:{embedding_name}" |
|
elif len(args) > 1: |
|
embedding_name,*numbers = args |
|
|
|
if len(numbers)==2: |
|
return f"embedding:{embedding_name}:{numbers[0]}:{numbers[1]}" |
|
elif len(numbers)==1: |
|
return f"embedding:{embedding_name}:{numbers[0]}" |
|
else: |
|
return f"embedding:{embedding_name}" |
|
|
|
    def lora(self, *args):
        print('lora prompt', *args)
        if len(args) == 1:
            # Only a single token was captured; treat it as the LoRA name.
            lora_name = str(args[0]).strip()
            return f"<lora:{lora_name}>"
        elif len(args) > 1:
            # args = ("lora" keyword, name, optional strength numbers)
            _, lora_name, *numbers = args
            lora_name = str(lora_name).strip()
            if len(numbers) == 2:
                return f"<lora:{lora_name}:{numbers[0]}:{numbers[1]}>"
            elif len(numbers) == 1:
                return f"<lora:{lora_name}:{numbers[0]}>"
            else:
                return f"<lora:{lora_name}>"
|
|
|
def weight(self, word,number): |
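        """Rebuild a (word:number) weight, translating the word to English."""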
|
translated_word = translate(str(word)).rstrip('.') |
|
return f"({translated_word}:{str(number).strip()})" |
|
|
|
def schedule(self,*args): |
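        """Rebuild a [word:word:number] prompt-editing schedule."""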
|
print('prompt schedule',args) |
|
data = [str(arg).strip() for arg in args] |
|
|
|
return f"[{':'.join(data)}]" |
|
|
|
def word(self, word): |
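        """Translate a bare word to English when it is detected as Chinese; otherwise pass it through."""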
|
|
|
if detect_language(str(word)) == "cn": |
|
return translate(str(word)).rstrip('.') |
|
else: |
|
return str(word).rstrip('.') |
|
|
|
class ChinesePrompt: |
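    """ComfyUI node that normalizes Chinese prompt text, translates it to English, and can optionally expand it with the prompt generator."""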
|
|
|
global _available |
|
available=_available |
|
|
|
@classmethod |
|
def INPUT_TYPES(s): |
|
return {"required": { |
|
"text": ("STRING",{"multiline": True,"default": "", "dynamicPrompts": False}), |
|
"generation": (["on","off"],{"default": "off"}), |
|
}, |
|
|
|
"optional":{ |
|
"seed":("INT", {"default": 100, "min": 100, "max": 1000000}), |
|
|
|
}, |
|
|
|
} |
|
|
|
RETURN_TYPES = ("STRING",) |
|
RETURN_NAMES = ("prompt",) |
|
|
|
FUNCTION = "run" |
|
|
|
CATEGORY = "♾️Mixlab/Prompt" |
|
OUTPUT_NODE = True |
|
INPUT_IS_LIST = True |
|
OUTPUT_IS_LIST = (True,) |
|
|
|
global text_pipe,zh_en_model,zh_en_tokenizer |
|
|
|
text_pipe= None |
|
zh_en_model=None |
|
zh_en_tokenizer=None |
|
|
|
def run(self,text,seed,generation): |
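        """Translate (and optionally expand) each prompt; inputs arrive as lists because INPUT_IS_LIST is enabled."""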
|
|
|
|
|
seed=seed[0] |
|
generation=generation[0] |
|
|
|
|
|
pbar = comfy.utils.ProgressBar(len(text)+1) |
|
texts = [correct_prompt_syntax(t) for t in text] |
|
|
|
global text_pipe,zh_en_model,zh_en_tokenizer |
|
        if zh_en_model is None:
|
zh_en_model = AutoModelForSeq2SeqLM.from_pretrained(zh_en_model_path).eval() |
|
zh_en_tokenizer = AutoTokenizer.from_pretrained(zh_en_model_path,padding=True, truncation=True) |
|
|
|
zh_en_model.to("cuda" if torch.cuda.is_available() else "cpu") |
|
|
|
|
|
        text_pipe = pipeline('text-generation', model=text_generator_model_path, device="cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
|
|
prompt_result=[] |
|
|
|
|
|
en_texts=[] |
|
|
|
for t in texts: |
|
if t: |
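                # Parse the phrase with the grammar; the transformer translates Chinese parts while rebuilding.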
|
|
|
parser = Lark(grammar, start="start", parser="lalr", transformer=ChinesePromptTranslate()) |
|
|
|
result = parser.parse(t).children |
|
|
|
|
|
en_texts.append(result[0]) |
|
|
|
zh_en_model.to('cpu') |
|
print("test en_text",en_texts) |
|
|
|
|
|
pbar.update(1) |
|
for t in en_texts: |
|
if generation=='on': |
|
prompt =text_generate(text_pipe,t,seed) |
|
|
|
lines = prompt.split("\n") |
|
longest_line = max(lines, key=len) |
|
|
|
prompt_result.append(longest_line) |
|
else: |
|
prompt_result.append(t) |
|
pbar.update(1) |
|
|
|
text_pipe.model.to('cpu') |
|
|
|
print('prompt_result',prompt_result,) |
|
|
|
if len(prompt_result)==0: |
|
prompt_result=[""] |
|
return { |
|
"ui":{ |
|
"prompt": prompt_result |
|
}, |
|
"result": (prompt_result,)} |
|
|
|
|
|
|
|
|
|
class PromptGenerate: |
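    """ComfyUI node that expands each input text into fuller prompts with the text2image-prompt-generator model."""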
|
|
|
global _available |
|
available=_available |
|
|
|
@classmethod |
|
def INPUT_TYPES(s): |
|
return {"required": { |
|
"text": ("STRING",{"multiline": True,"default": "", "dynamicPrompts": False}), |
|
}, |
|
|
|
"optional":{ |
|
"multiple": (["off","on"],), |
|
"seed":("INT", {"default": 100, "min": 100, "max": 1000000}), |
|
}, |
|
|
|
} |
|
|
|
RETURN_TYPES = ("STRING",) |
|
RETURN_NAMES = ("prompt",) |
|
|
|
FUNCTION = "run" |
|
|
|
CATEGORY = "♾️Mixlab/Prompt" |
|
OUTPUT_NODE = True |
|
INPUT_IS_LIST = True |
|
OUTPUT_IS_LIST = (True,) |
|
|
|
global text_pipe |
|
|
|
text_pipe= None |
|
|
|
|
|
def run(self,text,multiple,seed): |
|
global text_pipe |
|
|
|
seed=seed[0] |
|
|
|
multiple=multiple[0] |
|
|
|
|
|
pbar = comfy.utils.ProgressBar(len(text)) |
|
|
|
        text_pipe = pipeline('text-generation', model=text_generator_model_path, device="cuda" if torch.cuda.is_available() else "cpu")
|
|
|
prompt_result=[] |
|
|
|
for t in text: |
|
prompt =text_generate(text_pipe,t,seed) |
|
prompt = prompt.split("\n") |
|
if multiple=='off': |
|
prompt = [max(prompt, key=len)] |
|
|
|
for p in prompt: |
|
prompt_result.append(p) |
|
pbar.update(1) |
|
|
|
text_pipe.model.to('cpu') |
|
|
|
return { |
|
"ui":{ |
|
"prompt": prompt_result |
|
}, |
|
"result": (prompt_result,)} |
|
|