# Hugging Face Space — status: Running on Zero
# Thanks: https://huggingface.co/spaces/stabilityai/stable-diffusion-3-medium
# Keep `spaces` as the very first import: ZeroGPU expects to be imported
# before torch / CUDA-using packages.
import spaces

import os
import random

import gradio as gr
import numpy as np
import torch
from diffusers import StableDiffusion3Pipeline
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# --- Model setup (runs once, at import time) ---------------------------------
device = "cuda"
# Dtype shared by both models below.
dtype = torch.bfloat16
repo = "stabilityai/stable-diffusion-3.5-large"
llm_repo = "microsoft/Phi-3-mini-4k-instruct"
# Hub access token, read once from the TOKEN environment variable / Space secret.
# A missing variable raises KeyError here, at startup, rather than mid-request.
hf_token = os.environ["TOKEN"]

# Text-to-image pipeline (Stable Diffusion 3.5 Large).
t2i = StableDiffusion3Pipeline.from_pretrained(repo, torch_dtype=dtype, token=hf_token).to(device)

# Small instruction-tuned LLM used by `infer` to expand the user's (Japanese)
# prompt into a richer English prompt before image generation.
model = AutoModelForCausalLM.from_pretrained(
    llm_repo,
    device_map=device,
    torch_dtype=dtype,
    trust_remote_code=True,
    token=hf_token,
)
tokenizer = AutoTokenizer.from_pretrained(llm_repo, token=hf_token)
upsampler = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)

# Sampling settings for the prompt upsampler. return_full_text=False asks the
# pipeline to return only the newly generated reply, not the input messages.
generation_args = {
    "max_new_tokens": 200,
    "return_full_text": False,
    "temperature": 0.7,
    "do_sample": True,
    "top_p": 0.95,
}

# Largest seed offered by the UI slider and the random draw (int32 max).
MAX_SEED = np.iinfo(np.int32).max
# Upper bound, in pixels, for the width/height sliders.
MAX_IMAGE_SIZE = 1344
# Instruction sent to the prompt-upsampling LLM. In English: "Expand the
# following prompt using your imagination and translate it into English: 「...」".
_UPSAMPLE_TEMPLATE = "次のプロンプトを想像を膨らませて英語に翻訳してください。「{}」"

# Two few-shot examples (Japanese request -> expanded English prompt) that show
# the LLM the expected style and length of its answer.
_FEW_SHOT_MESSAGES = (
    {"role": "user", "content": _UPSAMPLE_TEMPLATE.format("クールなアニメ風の女の子")},
    {"role": "assistant", "content": "An anime style illustration of a cool-looking teenage girl with an edgy, confident expression. She has piercing eyes, a slight smirk, and colorful hair that flows in the wind. "},
    {"role": "user", "content": _UPSAMPLE_TEMPLATE.format("実写風の女子高生")},
    {"role": "assistant", "content": "A photorealistic image of a female high school student standing on a city street. She is wearing a traditional Japanese school uniform, consisting of a navy blue blazer, a white blouse, and a knee-length plaid skirt. "},
)


def _build_upsample_messages(prompt):
    """Return a fresh chat transcript asking the LLM to expand and translate *prompt*."""
    # Copy the few-shot dicts so nothing downstream can mutate the shared constants.
    messages = [dict(message) for message in _FEW_SHOT_MESSAGES]
    messages.append({"role": "user", "content": _UPSAMPLE_TEMPLATE.format(prompt)})
    return messages


# On ZeroGPU hardware a GPU is only attached while a @spaces.GPU-decorated
# function runs, so the whole request (LLM upsampling + diffusion) goes here.
@spaces.GPU
def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, progress=gr.Progress(track_tqdm=True)):
    """Upsample *prompt* with the LLM, then generate an image from the result.

    Args:
        prompt: User prompt (the UI expects Japanese); expanded and translated
            to English by `upsampler` before being sent to the image model.
        negative_prompt: Passed unchanged to the diffusion pipeline.
        seed: Seed for the torch generator; ignored when *randomize_seed* is true.
        randomize_seed: If true, draw a new seed in ``[0, MAX_SEED]``.
        width, height: Output size in pixels (cast to int; sliders may send floats).
        guidance_scale: Classifier-free guidance scale for the diffusion pipeline.
        num_inference_steps: Number of denoising steps (cast to int).
        progress: Gradio progress tracker; ``track_tqdm=True`` lets Gradio show
            the pipelines' tqdm bars. Not referenced directly.

    Returns:
        ``(image, seed, upsampled_prompt)``: the first generated image, the seed
        actually used (so the UI slider can show it), and the English prompt.
    """
    output = upsampler(_build_upsample_messages(prompt), **generation_args)
    # The LLM reply often carries leading/trailing whitespace; drop it.
    upsampled_prompt = output[0]["generated_text"].strip()
    print(upsampled_prompt)

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # manual_seed requires an int; Gradio sliders can deliver numeric floats.
    seed = int(seed)
    generator = torch.Generator().manual_seed(seed)

    image = t2i(
        prompt=upsampled_prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_inference_steps),
        width=int(width),
        height=int(height),
        generator=generator,
    ).images[0]
    return image, seed, upsampled_prompt
# Sample Japanese prompts offered below the input box.
examples = [
    "美味しい肉",                          # "delicious meat"
    "馬に乗った宇宙飛行士",                # "an astronaut riding a horse"
    "アニメ風の美少女",                    # "a beautiful anime-style girl"
    "女子高生の写真",                      # "a photo of a high-school girl"
    "寿司でできた家に入っているコーギー",  # "a corgi inside a house made of sushi"
    "バナナとアボカドが戦っている様子",    # "a banana and an avocado fighting"
]

# Page stylesheet: centre the main column and cap its width.
css = (
    "\n"
    "#col-container {\n"
    "margin: 0 auto;\n"
    "max-width: 580px;\n"
    "}\n"
)
# --- Gradio UI ---------------------------------------------------------------
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # Title: "SD3.5 Large that accepts Japanese input".
        gr.Markdown("""
        # 日本語が入力できる SD3.5 Large
        """)
        with gr.Row():
            prompt = gr.Text(
                label="プロンプト",  # "Prompt"
                show_label=False,
                max_lines=1,
                # "Describe the features of the image you want to create"
                placeholder="作りたい画像の特徴を入力してください",
                container=False,
            )
            run_button = gr.Button("実行", scale=0)  # "Run"
        result = gr.Image(label="結果", show_label=False)  # "Result"
        # Read-only box showing the English prompt the image was generated from.
        generated_prompt = gr.Textbox(label="生成に使ったプロンプト", show_label=False, interactive=False)

        with gr.Accordion("詳細設定", open=False):  # "Advanced settings"
            negative_prompt = gr.Text(
                label="ネガティブプロンプト",  # "Negative prompt"
                max_lines=1,
                # "Enter elements you want excluded from the image"
                placeholder="画像から排除したい要素を入力してください",
            )
            seed = gr.Slider(
                label="乱数のシード",  # "Random seed"
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="ランダム生成", value=True)  # "Randomize"
            with gr.Row():
                width = gr.Slider(
                    label="横",  # "Width"
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=64,
                    value=1024,
                )
                height = gr.Slider(
                    label="縦",  # "Height"
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=64,
                    value=1024,
                )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="プロンプトの忠実さ",  # "Prompt adherence" (guidance scale)
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=3.5,
                )
                num_inference_steps = gr.Slider(
                    label="推論回数",  # "Inference steps"
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=28,
                )

        # Clicking an example fills the prompt box.
        gr.Examples(
            examples=examples,
            inputs=[prompt],
        )

    # Run on button click or Enter in either text box. `seed` is also an output
    # so the slider shows the seed actually used (relevant when randomized).
    gr.on(
        triggers=[run_button.click, prompt.submit, negative_prompt.submit],
        fn=infer,
        inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
        outputs=[result, seed, generated_prompt],
    )

demo.launch()