"""Gradio demo: light control with a depth-map-conditioned ControlNet (Flax).

Loads the ``Nahrawy/controlnet-VIDIT-FAID`` ControlNet on top of Stable
Diffusion v1.5 and serves a Blocks UI that takes a prompt, a negative prompt
and a depth-map image, and returns the generated image(s).
"""
import cv2
import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np
from diffusers import FlaxControlNetModel, FlaxStableDiffusionControlNetPipeline
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image

# Load the ControlNet and the Stable Diffusion v1.5 base pipeline (bfloat16).
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    "Nahrawy/controlnet-VIDIT-FAID",
    dtype=jnp.bfloat16,
    revision="615ba4a457b95a0eba813bcc8caf842c03a4f7bd",
)
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    revision="flax",
    dtype=jnp.bfloat16,
)
# Attach the ControlNet weights and replicate the parameters across devices
# once at load time. This avoids mutating the global dict and re-replicating
# the full model on every request.
params["controlnet"] = controlnet_params
p_params = replicate(params)


def create_key(seed=0):
    """Return a JAX PRNG key built from ``seed``."""
    return jax.random.PRNGKey(seed)


def process_mask(image):
    """Convert a depth-map image array into a 512x512 single-channel array.

    ``gr.Image`` delivers RGB arrays by default, so the RGB->gray conversion
    is used. An array that is already 2-D (single channel) is only resized.
    """
    if image.ndim == 3:
        image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    return cv2.resize(image, (512, 512))


def infer(prompts, negative_prompts, image):
    """Generate images for a prompt, conditioned on a depth map.

    Args:
        prompts: Text prompt.
        negative_prompts: Negative text prompt.
        image: Depth-map image as a numpy array, as supplied by ``gr.Image``.

    Returns:
        A list of PIL images, one per local JAX device.

    Raises:
        gr.Error: If no conditioning image was provided.
    """
    if image is None:
        raise gr.Error("Please provide a depth map conditioning image.")

    # ``shard`` reshapes the leading axis to (local_device_count, -1, ...),
    # so the batch size must be a multiple of the local device count. A fixed
    # batch of 1 fails on multi-device hosts. Use one sample per device, and
    # give each device its own PRNG key.
    num_samples = jax.local_device_count()
    rng = jax.random.split(create_key(0), num_samples)

    mask = Image.fromarray(process_mask(image))
    prompt_ids = shard(pipe.prepare_text_inputs([prompts] * num_samples))
    negative_prompt_ids = shard(
        pipe.prepare_text_inputs([negative_prompts] * num_samples)
    )
    processed_image = shard(pipe.prepare_image_inputs([mask] * num_samples))

    output = pipe(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=p_params,
        prng_seed=rng,
        num_inference_steps=50,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images
    # Collapse the leading (device, per-device batch) axes into a single
    # batch axis before converting to PIL.
    output = output.reshape((num_samples,) + output.shape[-3:])
    return pipe.numpy_to_pil(np.asarray(output))


# Example inputs. e_images must stay aligned index-for-index with e_prompts:
# the four dog prompts use 0.png and the four house prompts use 2.png.
e_images = ["0.png"] * 4 + ["2.png"] * 4
e_prompts = [
    "a dog in the middle of the road, shadow on the ground,light direction north-east",
    "a dog in the middle of the road, shadow on the ground,light direction north-west",
    "a dog in the middle of the road, shadow on the ground,light direction south-west",
    "a dog in the middle of the road, shadow on the ground,light direction south-east",
    "a red rural house, shadow on the ground, light direction north",
    "a red rural house, shadow on the ground, light direction east",
    "a red rural house, shadow on the ground, light direction south",
    "a red rural house, shadow on the ground, light direction west",
]
NEGATIVE_PROMPT = "monochromatic, unrealistic, bad looking, full of glitches"
e_negative_prompts = [NEGATIVE_PROMPT] * len(e_prompts)

# Gradio example rows follow the input order: [prompt, negative prompt, image].
examples = [
    [prompt, negative_prompt, image]
    for image, prompt, negative_prompt in zip(e_images, e_prompts, e_negative_prompts)
]

title = " # ControlLight: Light control through ControlNet and Depth Maps conditioning"

with gr.Blocks() as demo:
    gr.Markdown(title)
    prompts = gr.Textbox(label="prompts")
    negative_prompts = gr.Textbox(label="negative_prompts")
    with gr.Row():
        with gr.Column():
            in_image = gr.Image(label="Depth Map Conditioning")
        with gr.Column():
            out_image = gr.Gallery(label="Generated Image")
    with gr.Row():
        btn = gr.Button("Run")
    with gr.Row():
        gr.Examples(
            examples=examples,
            inputs=[prompts, negative_prompts, in_image],
            outputs=out_image,
            fn=infer,
            cache_examples=True,
        )
    btn.click(
        fn=infer,
        inputs=[prompts, negative_prompts, in_image],
        outputs=out_image,
    )

if __name__ == "__main__":
    demo.launch()