Learner's picture
Update app.py (#1)
e454d64
import gradio as gr
from PIL import Image
import os
# Diffusers
from diffusers import (
FlaxControlNetModel,
FlaxStableDiffusionControlNetPipeline
)
from diffusers.utils import load_image
# PyTorch
import torch
# Numpy
import numpy as np
# Jax
import jax
import jax.numpy as jnp
from jax import pmap
# Flax
import flax
from flax.jax_utils import replicate
from flax.training.common_utils import shard
# Ask XLA not to grab most of the accelerator memory up front; allocate on demand.
# NOTE(review): JAX reads this when its backend is first initialised, so this line
# must stay ahead of any JAX computation in the module.
os.environ.update(XLA_PYTHON_CLIENT_PREALLOCATE="false")
def create_key(seed: int = 0):
    """Build a JAX PRNG key from an integer seed.

    Args:
        seed: Integer seed for ``jax.random.PRNGKey``; defaults to 0.

    Returns:
        The key produced by ``jax.random.PRNGKey(seed)``.
    """
    key = jax.random.PRNGKey(seed)
    return key
# Hugging Face Hub repositories for the fine-tuned ControlNet and the SD v1.5 base model.
_CONTROLNET_REPO = "learner/jax-diffuser-event"
_BASE_MODEL_REPO = "runwayml/stable-diffusion-v1-5"

# Load the ControlNet (module + its parameter tree) in bfloat16.
# NOTE(review): `from_flax` does not appear to be a documented option of
# FlaxModelMixin.from_pretrained and is likely ignored; confirm before removing it.
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    _CONTROLNET_REPO,
    from_flax=True,
    dtype=jnp.bfloat16,
)

# Build the Flax ControlNet pipeline on top of the base model, loading the base
# weights from its PyTorch checkpoint (from_pt=True). No safety_checker argument is
# passed, so the pipeline's default is used.
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    _BASE_MODEL_REPO,
    controlnet=controlnet,
    from_pt=True,
    dtype=jnp.bfloat16,
)
def infer(prompts, negative_prompts, image, seed=0):
    """Generate battlemap image(s) from a prompt, a negative prompt and a sketch.

    Args:
        prompts: Positive text prompt.
        negative_prompts: Negative text prompt.
        image: Control image as an array accepted by ``PIL.Image.fromarray``
            (what Gradio's default ``"image"`` input delivers).
        seed: Integer seed for the PRNG key. Defaults to 0, so repeated calls with
            the same inputs produce the same output, as before.

    Returns:
        A list of PIL images with one entry per JAX device, so a single image on a
        single-device host.

    Raises:
        gr.Error: if no control image was provided.
    """
    if image is None:
        raise gr.Error("Please upload or sketch a battlemap image first.")

    # Make sure the pipeline runs with the fine-tuned ControlNet weights.
    params["controlnet"] = controlnet_params

    # `shard` reshapes the leading batch axis to (device_count, -1, ...), so the
    # batch size must be a multiple of the device count. With a hard-coded batch of 1
    # this crashed on multi-device hosts. Use one sample per device, which also lines
    # up with the one PRNG key per device that the jitted (pmapped) pipeline needs.
    num_samples = jax.device_count()
    rng = jax.random.split(create_key(seed), num_samples)

    battlemap_image = Image.fromarray(image)
    prompt_ids = shard(pipe.prepare_text_inputs([prompts] * num_samples))
    negative_prompt_ids = shard(
        pipe.prepare_text_inputs([negative_prompts] * num_samples)
    )
    processed_image = shard(
        pipe.prepare_image_inputs([battlemap_image] * num_samples)
    )

    output = pipe(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=replicate(params),  # one copy of the weights per device
        prng_seed=rng,
        num_inference_steps=50,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images

    # Flatten the per-device leading axes into a single batch axis of num_samples,
    # keeping the trailing three (per-image) dimensions, then convert to PIL.
    return pipe.numpy_to_pil(
        np.asarray(output.reshape((num_samples,) + output.shape[-3:]))
    )
title = "ControlNet + Stable Diffusion for Battlemaps"
description = """Sketch your game battlemap and add some prompts to let the magic happen 🪄.
Pretrained on battlemaps images.
By Orgrim, Karm and Robin
"""

# The three inputs map positionally onto infer(prompts, negative_prompts, image).
# infer returns the *list* produced by pipe.numpy_to_pil. A single "image" output
# component rejects lists ("Cannot process this value as an Image"), so the result
# is shown in a gallery instead.
demo = gr.Interface(
    fn=infer,
    inputs=["text", "text", "image"],
    outputs=gr.Gallery(label="Generated battlemaps"),
    title=title,
    description=description,
    examples=[
        ["underground, castle, cave, medieval, knights", "outside, sunny, modern, green", "map.png"],
    ],
)
demo.launch()