# Hugging Face Space — status when this file was captured: Runtime error
import gradio as gr | |
from PIL import Image | |
import os | |
# Diffusers | |
from diffusers import ( | |
FlaxControlNetModel, | |
FlaxStableDiffusionControlNetPipeline | |
) | |
from diffusers.utils import load_image | |
# PyTorch | |
import torch | |
# Numpy | |
import numpy as np | |
# Jax | |
import jax | |
import jax.numpy as jnp | |
from jax import pmap | |
# Flax | |
import flax | |
from flax.jax_utils import replicate | |
from flax.training.common_utils import shard | |
# Ask XLA not to preallocate accelerator memory up front; allocate on demand instead.
# NOTE(review): JAX reads this variable when its backend is initialised, so it only
# takes effect if no JAX computation has run before this line. Setting it above
# `import jax` would be the safer placement.
os.environ.update({"XLA_PYTHON_CLIENT_PREALLOCATE": "false"})
def create_key(seed=0):
    """Return a JAX PRNG key derived deterministically from the integer ``seed``."""
    key = jax.random.PRNGKey(seed)
    return key
# Load the ControlNet weights and the Stable Diffusion v1-5 ControlNet pipeline,
# both in bfloat16. Each loader returns a (model/pipeline, parameter tree) pair.
# NOTE(review): `from_flax` does not appear to be a documented option of the Flax
# `from_pretrained` loader; it is kept here unchanged — confirm whether it is needed.
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    "learner/jax-diffuser-event",
    dtype=jnp.bfloat16,
    from_flax=True,
)
# `from_pt=True` converts the PyTorch checkpoint of the base model to Flax weights.
# `safety_checker` is not overridden, so the pipeline uses whatever the repo provides.
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    dtype=jnp.bfloat16,
    controlnet=controlnet,
    from_pt=True,
)
# Cache for the device-replicated pipeline parameters. It is filled on the first
# request, so the full weight tree is copied to the devices only once.
_replicated_params = None


def _get_replicated_params():
    """Return ``params`` (with the ControlNet weights) replicated across all devices.

    The replicated tree is built on the first call and reused on later calls.
    """
    global _replicated_params
    if _replicated_params is None:
        params["controlnet"] = controlnet_params
        _replicated_params = replicate(params)
    return _replicated_params


def infer(prompts, negative_prompts, image, seed=0, num_inference_steps=50):
    """Generate a battlemap from a sketch with the ControlNet Stable Diffusion pipeline.

    Args:
        prompts: Positive text prompt (a single string from the first text box).
        negative_prompts: Negative text prompt (a single string).
        image: Sketch from the Gradio ``"image"`` input. It is passed to
            ``PIL.Image.fromarray``, so presumably a numpy array — confirm
            against the Gradio version in use.
        seed: Integer seed for the JAX PRNG (default 0, the previous fixed value).
        num_inference_steps: Number of denoising steps (default 50, as before).

    Returns:
        A single PIL image: the first generated sample.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload or draw a battlemap sketch first.")

    # shard() splits the leading batch axis across devices, so the batch size must
    # be a multiple of jax.device_count(); one sample per device is the minimum.
    num_samples = jax.device_count()
    rng = jax.random.split(create_key(seed), jax.device_count())

    battlemap_image = Image.fromarray(image)
    prompt_ids = shard(pipe.prepare_text_inputs([prompts] * num_samples))
    negative_prompt_ids = shard(
        pipe.prepare_text_inputs([negative_prompts] * num_samples)
    )
    processed_image = shard(pipe.prepare_image_inputs([battlemap_image] * num_samples))

    output = pipe(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=_get_replicated_params(),
        prng_seed=rng,
        num_inference_steps=num_inference_steps,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images

    # Collapse the leading (sharded) axes into one batch axis before converting to PIL.
    output_images = pipe.numpy_to_pil(
        np.asarray(output.reshape((num_samples,) + output.shape[-3:]))
    )
    # The Gradio "image" output component renders one image, not a list.
    return output_images[0]
title = "ControlNet + Stable Diffusion for Battlemaps"
description = """Sketch your game battlemap and add some prompts to let the magic happen 🪄.
Pretrained on battlemaps images.
By Orgrim, Karm and Robin
"""

# The input components map positionally onto infer()'s first three parameters:
# prompt, negative prompt, sketch image. Each example row uses the same order.
# NOTE(review): the example assumes "map.png" sits in the working directory.
demo = gr.Interface(
    fn=infer,
    inputs=["text", "text", "image"],
    outputs="image",
    title=title,
    description=description,
    examples=[
        ["underground, castle, cave, medieval, knights", "outside, sunny, modern, green", "map.png"],
    ],
)

# Launch only when run as a script. Importing the module (e.g. by Gradio's reload
# mode, which looks for `demo`) then builds the UI without starting a server.
if __name__ == "__main__":
    demo.launch()