import sys
import os

import gradio as gr

## install kgen (TITPOP-KGen) from GitHub if it is not already present
try:
    import kgen
except ImportError:
    GH_TOKEN = os.getenv("GITHUB_TOKEN")
    git_url = f"https://{GH_TOKEN}@github.com/KohakuBlueleaf/TITPOP-KGen@titpop"
    ## call pip install
    os.system(f"pip install git+{git_url}")

import re
import random
from time import time

import torch
from transformers import set_seed

if sys.platform == "win32":
    # dev env on Windows, @spaces.GPU will cause problems
    def GPU(func):
        return func

else:
    from spaces import GPU

import kgen.models as models
import kgen.executor.titpop as titpop
from kgen.formatter import seperate_tags, apply_format
from kgen.generate import generate

from diff import load_model, encode_prompts
from meta import DEFAULT_NEGATIVE_PROMPT


sdxl_pipe = load_model()

models.load_model(
    "KBlueLeaf/TITPOP-200M-dev",
    device="cuda",
    subfolder="dan-cc-coyo_epoch2",
)
# warm-up pass; this calls kgen.generate.generate, which is shadowed by the
# local generate() defined below
generate(max_new_tokens=4)


DEFAULT_FORMAT = """<|special|>, <|characters|>, <|copyrights|>,
<|artist|>,
<|general|>,
<|extended|>.
<|quality|>, <|meta|>, <|rating|>
""".strip()

DEFAULT_TAGS = """
1girl, ningen mame, ciloranko, solo, dragon girl,
masterpiece, absurdres, safe, newest
""".strip()

DEFAULT_NL = """
An illustration of a girl
""".strip()


def format_time(timing):
    total = timing["total"]
    generate_pass = timing["generate_pass"]

    result = ""
    result += f"""
### Process Time
| Total | {total:5.2f} sec / {generate_pass:5} Passes | {generate_pass/total:7.2f} Passes Per Second|
|-|-|-|
"""
    if "generated_tokens" in timing:
        total_generated_tokens = timing["generated_tokens"]
        total_input_tokens = timing["input_tokens"]
    if "generated_tokens" in timing and "total_sampling" in timing:
        sampling_time = timing["total_sampling"] / 1000
        process_time = timing["prompt_process"] / 1000
        model_time = timing["total_eval"] / 1000
        result += f"""| Process | {process_time:5.2f} sec / {total_input_tokens:5} Tokens | {total_input_tokens/process_time:7.2f} Tokens Per Second|
| Sampling | {sampling_time:5.2f} sec / {total_generated_tokens:5} Tokens | {total_generated_tokens/sampling_time:7.2f} Tokens Per Second|
| Eval | {model_time:5.2f} sec / {total_generated_tokens:5} Tokens | {total_generated_tokens/model_time:7.2f} Tokens Per Second|
"""
    if "generated_tokens" in timing:
        result += f"""
### Processed Tokens:
* {total_input_tokens:} Input Tokens
* {total_generated_tokens:} Output Tokens
"""
    return result


@GPU
@torch.no_grad()
def generate(
    tags,
    nl_prompt,
    black_list,
    temp,
    target_length,
    top_p,
    min_p,
    top_k,
    seed,
    escape_brackets,
):
    titpop.BAN_TAGS = [t.strip() for t in black_list.split(",") if t.strip()]
    generation_setting = {
        "seed": int(seed),  # gr.Number may deliver the seed as a float
        "temperature": temp,
        "top_p": top_p,
        "min_p": min_p,
        "top_k": top_k,
    }
    inputs = seperate_tags(tags.split(","))
    if nl_prompt:
        if "<|extended|>" in DEFAULT_FORMAT:
            inputs["extended"] = nl_prompt
        elif "<|generated|>" in DEFAULT_FORMAT:
            inputs["generated"] = nl_prompt
    input_prompt = apply_format(inputs, DEFAULT_FORMAT)
    if escape_brackets:
        # escape () and [] so SD-style prompt parsers treat them literally
        input_prompt = re.sub(r"([()\[\]])", r"\\\1", input_prompt)

    meta, operations, general, nl_prompt = titpop.parse_titpop_request(
        seperate_tags(tags.split(",")),
        nl_prompt,
        tag_length_target=target_length,
        generate_extra_nl_prompt="<|generated|>" in DEFAULT_FORMAT or not nl_prompt,
    )
    t0 = time()
    for result, timing in titpop.titpop_runner_generator(
        meta, operations, general, nl_prompt, **generation_setting
    ):
        result = apply_format(result, DEFAULT_FORMAT)
        if escape_brackets:
            result = re.sub(r"([()\[\]])", r"\\\1", result)
        timing["total"] = time() - t0
        yield result, input_prompt, format_time(timing)

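
# Renders two images for side-by-side comparison: one from the TITPOP-extended
# prompt and one from the original input prompt, using the same seed and
# sampling settings so the pair differs only in the prompt.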
@GPU
@torch.no_grad()
def generate_image(
    seed,
    prompt,
    prompt2,
):
    torch.cuda.empty_cache()

    # first pass: the TITPOP-extended prompt
    prompt_embeds, negative_prompt_embeds, pooled_embeds2, neg_pooled_embeds2 = (
        encode_prompts(sdxl_pipe, prompt, DEFAULT_NEGATIVE_PROMPT)
    )
    set_seed(int(seed))  # gr.Number may deliver the seed as a float
    result = sdxl_pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_embeds2,
        negative_pooled_prompt_embeds=neg_pooled_embeds2,
        num_inference_steps=24,
        width=1024,
        height=1024,
        guidance_scale=6.0,
    ).images[0]

    # second pass: the original input prompt, with the same seed
    prompt_embeds, negative_prompt_embeds, pooled_embeds2, neg_pooled_embeds2 = (
        encode_prompts(sdxl_pipe, prompt2, DEFAULT_NEGATIVE_PROMPT)
    )
    set_seed(int(seed))
    result2 = sdxl_pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_embeds2,
        negative_pooled_prompt_embeds=neg_pooled_embeds2,
        num_inference_steps=24,
        width=1024,
        height=1024,
        guidance_scale=6.0,
    ).images[0]

    torch.cuda.empty_cache()
    return result2, result


if __name__ == "__main__":
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("""# TITPOP DEMO""")
        with gr.Accordion("Introduction and Instructions", open=False):
            gr.Markdown(
                """
### What is this: TITPOP
**The current implementation is somewhat inefficient, so image generation may be slower than expected.**
"""
            )
        with gr.Row():
            with gr.Column(scale=5):
                with gr.Row():
                    with gr.Column(scale=3):
                        tags_input = gr.TextArea(
                            label="Danbooru Tags",
                            lines=6,
                            show_copy_button=True,
                            interactive=True,
                            value=DEFAULT_TAGS,
                            placeholder="Enter Danbooru tags here",
                        )
                        nl_prompt_input = gr.Textbox(
                            label="Natural Language Prompt",
                            lines=6,
                            show_copy_button=True,
                            interactive=True,
                            value=DEFAULT_NL,
                            placeholder="Enter Natural Language Prompt here",
                        )
                        black_list = gr.TextArea(
                            label="Black List (separated by comma)",
                            lines=4,
                            interactive=True,
                            value="monochrome",
                            placeholder="Enter tag/nl black list here",
                        )
                    with gr.Column(scale=2):
                        target_length = gr.Dropdown(
                            label="Target Length",
                            choices=["very_short", "short", "long", "very_long"],
                            value="short",
                        )
                        temp = gr.Slider(
                            label="Temp",
                            minimum=0.0,
                            maximum=1.5,
                            value=0.5,
                            step=0.05,
                        )
                        top_p = gr.Slider(
                            label="Top P",
                            minimum=0.0,
                            maximum=1.0,
                            value=0.95,
                            step=0.05,
                        )
                        min_p = gr.Slider(
                            label="Min P",
                            minimum=0.0,
                            maximum=0.2,
                            value=0.05,
                            step=0.01,
                        )
                        top_k = gr.Slider(
                            label="Top K", minimum=0, maximum=120, value=60, step=1
                        )
                with gr.Row():
                    seed = gr.Number(
                        label="Seed",
                        minimum=0,
                        maximum=2147483647,
                        value=20090220,
                        step=1,
                    )
                    escape_brackets = gr.Checkbox(
                        label="Escape Brackets", value=False
                    )
                submit = gr.Button("TITPOP!", variant="primary")
                with gr.Accordion("Speed statistics", open=False):
                    cost_time = gr.Markdown()
            with gr.Column(scale=5):
                result = gr.TextArea(
                    label="Result", lines=8, show_copy_button=True, interactive=False
                )
                input_prompt = gr.Textbox(
                    label="Input Prompt", lines=1, interactive=False, visible=False
                )
                gen_img = gr.Button(
                    "Generate Image from Result", variant="primary", interactive=False
                )
                with gr.Row():
                    with gr.Column():
                        img1 = gr.Image(label="Original Prompt", interactive=False)
                    with gr.Column():
                        img2 = gr.Image(label="Generated Prompt", interactive=False)

        def generate_wrapper(*args):
            # keep the image button disabled while results are streaming in,
            # then enable it together with the final result
            yield "", "", "", gr.update(interactive=False)
            for i in generate(*args):
                yield *i, gr.update(interactive=False)
            yield *i, gr.update(interactive=True)

        submit.click(
            generate_wrapper,
            [
                tags_input,
                nl_prompt_input,
                black_list,
                temp,
                target_length,
                top_p,
                min_p,
                top_k,
                seed,
                escape_brackets,
            ],
            [
                result,
                input_prompt,
                cost_time,
                gen_img,
            ],
            queue=True,
        )
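
        # gen_img gets two click handlers: a non-queued one that disables the
        # "TITPOP!" button immediately, and a queued one that renders both
        # images and re-enables the button once they are ready.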
        gen_img.click(
            lambda *args: (*generate_image(*args), gr.update(interactive=True)),
            [seed, result, input_prompt],
            [img1, img2, submit],
            queue=True,
        )
        gen_img.click(
            lambda *args: gr.update(interactive=False),
            None,
            [submit],
            queue=False,
        )

    demo.launch()