import sys, os
import gradio as gr
## Install kgen on first run if it is not importable yet.
try:
    import kgen
except ImportError:
    import importlib
    import subprocess

    GH_TOKEN = os.getenv("GITHUB_TOKEN")
    # The repository may require authentication; embed the token as URL
    # credentials only when it is set (otherwise the URL would contain "None@").
    _auth = f"{GH_TOKEN}@" if GH_TOKEN else ""
    git_url = f"https://{_auth}github.com/KohakuBlueleaf/TITPOP-KGen@titpop"
    # Install into the interpreter that is running this script, and fail
    # loudly if pip fails instead of crashing later on `import kgen.models`.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", f"git+{git_url}"],
        check=True,
    )
    # Make the freshly installed package visible to the import system.
    importlib.invalidate_caches()
import re
import random
from time import time
import torch
from transformers import set_seed
if sys.platform != "win32":
    # Outside the Windows dev environment, use the real decorator from `spaces`.
    from spaces import GPU
else:
    # Windows dev environment: @spaces.GPU causes problems there, so fall back
    # to a pass-through decorator.
    def GPU(func):
        """No-op stand-in for ``spaces.GPU``: hand *func* back unchanged."""
        return func
import kgen.models as models
import kgen.executor.titpop as titpop
from kgen.formatter import seperate_tags, apply_format
from kgen.generate import generate
from diff import load_model, encode_prompts
from meta import DEFAULT_NEGATIVE_PROMPT, DEFAULT_FORMAT
# Module-level model loading: runs once when this file is imported/executed.
# SDXL pipeline object used by generate_image() below.
sdxl_pipe = load_model()
# Load the TITPOP prompt-generation model onto CUDA. The return value is
# discarded, so presumably kgen.models keeps a reference internally — confirm.
models.load_model(
    "KBlueLeaf/TITPOP-200M-dev",
    device="cuda",
    subfolder="dan-cc-coyo_epoch2",
)
# NOTE: at this point `generate` is still kgen.generate.generate (imported
# above); the app's own `generate` defined later in this file rebinds the
# name, so this call must stay before that definition. The output is
# discarded — presumably a short warm-up run (TODO confirm intent).
generate(max_new_tokens=4)
# Example Danbooru tag prompt pre-filled in the "Danbooru Tags" box.
DEFAULT_TAGS = "\n".join(
    [
        "1girl, king halo (umamusume), umamusume,",
        "ningen mame, ciloranko, ogipote, misu kasumi,",
        "solo, leaning forward, sky,",
        "masterpiece, absurdres, sensitive, newest",
    ]
)

# Example natural-language prompt pre-filled in the "Natural Language Prompt" box.
DEFAULT_NL = "An illustration of a girl"
def format_time(timing):
    """Render generation timing statistics as a Markdown snippet.

    Args:
        timing: dict with at least ``"total"`` (seconds) and
            ``"generate_pass"`` (number of passes). If it also has
            ``"generated_tokens"`` and ``"input_tokens"``, a token summary is
            appended. If ``"total_sampling"`` is present as well, per-stage
            rows are added from ``"prompt_process"``, ``"total_sampling"`` and
            ``"total_eval"``. Those three are divided by 1000 before being
            shown as seconds, so they are presumably milliseconds.

    Returns:
        A Markdown string with a "Process Time" table and, when token counts
        are available, a "Processed Tokens" list. A rate whose duration is
        zero (or not positive) is shown as ``nan`` instead of raising
        ``ZeroDivisionError``.
    """

    def per_second(count, seconds):
        # A stage that took no measurable time has no meaningful rate.
        return count / seconds if seconds > 0 else float("nan")

    total = timing["total"]
    generate_pass = timing["generate_pass"]
    result = (
        "\n### Process Time\n"
        f"| Total | {total:5.2f} sec / {generate_pass:5} Passes "
        f"| {per_second(generate_pass, total):7.2f} Passes Per Second|\n"
        "|-|-|-|\n"
    )

    if "generated_tokens" not in timing:
        return result

    generated_tokens = timing["generated_tokens"]
    input_tokens = timing["input_tokens"]

    if "total_sampling" in timing:
        sampling_time = timing["total_sampling"] / 1000
        process_time = timing["prompt_process"] / 1000
        model_time = timing["total_eval"] / 1000
        result += (
            f"| Process | {process_time:5.2f} sec / {input_tokens:5} Tokens "
            f"| {per_second(input_tokens, process_time):7.2f} Tokens Per Second|\n"
            f"| Sampling | {sampling_time:5.2f} sec / {generated_tokens:5} Tokens "
            f"| {per_second(generated_tokens, sampling_time):7.2f} Tokens Per Second|\n"
            f"| Eval | {model_time:5.2f} sec / {generated_tokens:5} Tokens "
            f"| {per_second(generated_tokens, model_time):7.2f} Tokens Per Second|\n"
        )

    result += (
        "\n### Processed Tokens:\n"
        f"* {input_tokens} Input Tokens\n"
        f"* {generated_tokens} Output Tokens\n"
    )
    return result
def _escape_brackets(text):
    """Return *text* with every ``(``, ``)``, ``[`` and ``]`` backslash-escaped."""
    return re.sub(r"([()\[\]])", r"\\\1", text)


@GPU
@torch.no_grad()
def generate(
    tags,
    nl_prompt,
    black_list,
    temp,
    output_format,
    target_length,
    top_p,
    min_p,
    top_k,
    seed,
    escape_brackets,
):
    """Stream TITPOP prompt refinements for the demo UI.

    Args:
        tags: comma-separated Danbooru tags entered by the user.
        nl_prompt: natural-language prompt (may be empty).
        black_list: comma-separated tags/phrases to ban during generation.
        temp, top_p, min_p, top_k: sampling settings forwarded to the runner.
        output_format: key into DEFAULT_FORMAT selecting the output template.
        target_length: tag-length target forwarded to the request parser.
        seed: sampling seed; coerced to ``int`` because the UI number widget
            may deliver a float.
        escape_brackets: if true, backslash-escape ``()[]`` in the outputs.

    Yields:
        ``(result, input_prompt, timing_markdown)``. *result* is the formatted
        generation so far, *input_prompt* is the user's input rendered in the
        same format, and *timing_markdown* comes from ``format_time``.
    """
    default_format = DEFAULT_FORMAT[output_format]
    # The black list is applied via a module-level setting of the executor.
    titpop.BAN_TAGS = [t.strip() for t in black_list.split(",") if t.strip()]
    generation_setting = {
        "seed": int(seed),
        "temperature": temp,
        "top_p": top_p,
        "min_p": min_p,
        "top_k": top_k,
    }

    # Render the user's own input in the chosen format (shown for comparison).
    inputs = seperate_tags(tags.split(","))
    if nl_prompt:
        if "<|extended|>" in default_format:
            inputs["extended"] = nl_prompt
        elif "<|generated|>" in default_format:
            inputs["generated"] = nl_prompt
    input_prompt = apply_format(inputs, default_format)
    if escape_brackets:
        input_prompt = _escape_brackets(input_prompt)

    # Parse from a fresh split rather than reusing `inputs`, which may have
    # had NL entries added above.
    meta, operations, general, nl_prompt = titpop.parse_titpop_request(
        seperate_tags(tags.split(",")),
        nl_prompt,
        tag_length_target=target_length,
        generate_extra_nl_prompt="<|generated|>" in default_format or not nl_prompt,
    )

    t0 = time()
    for result, timing in titpop.titpop_runner_generator(
        meta, operations, general, nl_prompt, **generation_setting
    ):
        result = apply_format(result, default_format)
        if escape_brackets:
            result = _escape_brackets(result)
        timing["total"] = time() - t0
        yield result, input_prompt, format_time(timing)
def _render_sdxl(prompt, seed):
    """Render one 1024x1024 image for *prompt* with the module's SDXL pipeline."""
    prompt_embeds, negative_prompt_embeds, pooled_embeds, neg_pooled_embeds = (
        encode_prompts(sdxl_pipe, prompt, DEFAULT_NEGATIVE_PROMPT)
    )
    # Reseed right before denoising so every image starts from the same RNG state.
    set_seed(seed)
    return sdxl_pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_embeds,
        negative_pooled_prompt_embeds=neg_pooled_embeds,
        num_inference_steps=24,
        width=1024,
        height=1024,
        guidance_scale=6.0,
    ).images[0]


@GPU
@torch.no_grad()
def generate_image(
    seed,
    prompt,
    prompt2,
):
    """Render *prompt2* and then *prompt* with the same seed, streaming results.

    In the demo UI, *prompt2* is the user's original input prompt and *prompt*
    is the refined result.

    Yields:
        ``(image_from_prompt2, None)`` once the first image is ready, then
        ``(image_from_prompt2, image_from_prompt)``.
    """
    torch.cuda.empty_cache()
    # set_seed() forwards to numpy's seeding, which rejects floats.
    seed = int(seed)
    first_image = _render_sdxl(prompt2, seed)
    yield first_image, None
    second_image = _render_sdxl(prompt, seed)
    torch.cuda.empty_cache()
    yield first_image, second_image
if __name__ == "__main__":
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        with gr.Accordion("Introduction and Instructions", open=False):
            gr.Markdown(
                """
## TITPOP Demo
### What is this
TITPOP is a tool to extend, generate, and refine the input prompt for T2I models.
It works on both Danbooru tags and natural language, which means you can use it with almost all existing T2I models.
You can think of it as the "pro max" version of [DTG](https://huggingface.co/KBlueLeaf/DanTagGen-delta-rev2).
### How to use this demo
1. Enter your tags (optional): put the desired tags into the "Danbooru Tags" box
2. Enter your NL prompt (optional): put the desired natural language prompt into the "Natural Language Prompt" box
3. Enter your black list (optional): put the desired black list into the "Black List" box
4. Adjust the settings: length, temp, top_p, min_p, top_k, seed ...
5. Click the "TITPOP!" button: you will see the refined prompt in the "Result" box
6. If you like the result, click the "Generate Image from Result" button
    * You will see 2 generated images: the left one is based on your prompt, the right one is based on the refined prompt
    * The backend is diffusers, which has no prompt-weighting mechanism, so Escape Brackets defaults to False
### Why is the inference code private? When will it be open sourced?
1. This model/tool is still under development; it is currently an early alpha version.
2. I'm doing some research and projects based on it.
3. The model is currently released under the CC-BY-NC-ND license. If you are interested, you can implement inference yourself.
4. Once the project/research is done, I will open source all these models/code under the Apache 2.0 license.
### Notice
**TITPOP is NOT a T2I model. It is a prompt generator, i.e. a text-to-text model.
The generated images come from the [Kohaku-XL-Zeta](https://huggingface.co/KBlueLeaf/Kohaku-XL-Zeta) model**
"""
            )
        with gr.Row():
            with gr.Column(scale=5):
                with gr.Row():
                    with gr.Column(scale=3):
                        tags_input = gr.TextArea(
                            label="Danbooru Tags",
                            lines=7,
                            show_copy_button=True,
                            interactive=True,
                            value=DEFAULT_TAGS,
                            placeholder="Enter danbooru tags here",
                        )
                        nl_prompt_input = gr.Textbox(
                            label="Natural Language Prompt",
                            lines=7,
                            show_copy_button=True,
                            interactive=True,
                            value=DEFAULT_NL,
                            placeholder="Enter Natural Language Prompt here",
                        )
                        black_list = gr.TextArea(
                            label="Black List (separated by comma)",
                            lines=4,
                            interactive=True,
                            value="monochrome",
                            placeholder="Enter tag/nl black list here",
                        )
                    with gr.Column(scale=2):
                        output_format = gr.Dropdown(
                            label="Output Format",
                            choices=list(DEFAULT_FORMAT.keys()),
                            value="Both, tag first (recommend)",
                        )
                        target_length = gr.Dropdown(
                            label="Target Length",
                            choices=["very_short", "short", "long", "very_long"],
                            value="long",
                        )
                        temp = gr.Slider(
                            label="Temp",
                            minimum=0.0,
                            maximum=1.5,
                            value=0.5,
                            step=0.05,
                        )
                        top_p = gr.Slider(
                            label="Top P",
                            minimum=0.0,
                            maximum=1.0,
                            value=0.95,
                            step=0.05,
                        )
                        min_p = gr.Slider(
                            label="Min P",
                            minimum=0.0,
                            maximum=0.2,
                            value=0.05,
                            step=0.01,
                        )
                        top_k = gr.Slider(
                            label="Top K", minimum=0, maximum=120, value=60, step=1
                        )
                with gr.Row():
                    # precision=0 makes Gradio deliver an int, which seeding requires.
                    seed = gr.Number(
                        label="Seed",
                        minimum=0,
                        maximum=2147483647,
                        value=20090220,
                        step=1,
                        precision=0,
                    )
                    escape_brackets = gr.Checkbox(
                        label="Escape Brackets", value=False
                    )
                submit = gr.Button("TITPOP!", variant="primary")
                with gr.Accordion("Speed statistics", open=False):
                    cost_time = gr.Markdown()
            with gr.Column(scale=5):
                result = gr.TextArea(
                    label="Result", lines=8, show_copy_button=True, interactive=False
                )
                input_prompt = gr.Textbox(
                    label="Input Prompt", lines=1, interactive=False, visible=False
                )
                gen_img = gr.Button(
                    "Generate Image from Result", variant="primary", interactive=False
                )
                with gr.Row():
                    with gr.Column():
                        img1 = gr.Image(label="Original Prompt", interactive=False)
                    with gr.Column():
                        img2 = gr.Image(label="Generated Prompt", interactive=False)

        def generate_wrapper(*args):
            """Stream prompt generation into the UI.

            Clears the outputs and keeps the image button disabled while
            generating. It is enabled only after at least one result exists.
            """
            yield "", "", "", gr.update(interactive=False)
            last = None
            for last in generate(*args):
                yield *last, gr.update(interactive=False)
            # Without any result there is nothing to render, so leave the
            # image button disabled (the original raised NameError here).
            if last is not None:
                yield *last, gr.update(interactive=True)

        submit.click(
            generate_wrapper,
            [
                tags_input,
                nl_prompt_input,
                black_list,
                temp,
                output_format,
                target_length,
                top_p,
                min_p,
                top_k,
                seed,
                escape_brackets,
            ],
            [
                result,
                input_prompt,
                cost_time,
                gen_img,
            ],
            queue=True,
        )

        def generate_image_wrapper(seed_value, refined_prompt, original_prompt):
            """Stream both images while keeping "TITPOP!" disabled, then re-enable it."""
            images = (None, None)
            try:
                for images in generate_image(seed_value, refined_prompt, original_prompt):
                    yield *images, gr.update(interactive=False)
            except Exception:
                # Re-enable "TITPOP!" before surfacing the error; otherwise the
                # button would stay disabled after a failed render.
                yield *images, gr.update(interactive=True)
                raise
            yield *images, gr.update(interactive=True)

        gen_img.click(
            generate_image_wrapper,
            [seed, result, input_prompt],
            [img1, img2, submit],
            queue=True,
        )
        # Disable "TITPOP!" immediately (unqueued) as soon as image
        # generation is requested.
        gen_img.click(
            lambda *args: gr.update(interactive=False),
            None,
            [submit],
            queue=False,
        )
    demo.launch()