File size: 3,151 Bytes
258d8c9
c74095e
975dc6e
c74095e
 
 
 
 
 
258d8c9
880828c
 
 
c74095e
 
3eed896
 
 
 
 
258d8c9
c74095e
 
 
 
 
258d8c9
 
090c9fa
 
c74095e
 
 
 
 
090c9fa
c74095e
febb26d
090c9fa
8118b09
c74095e
 
febb26d
090c9fa
8118b09
c74095e
 
 
 
 
 
 
090c9fa
c74095e
 
 
 
 
258d8c9
c131c56
125cff5
f9183eb
c131c56
880828c
 
 
 
806402e
 
 
 
880828c
ffe36d8
 
806402e
ffe36d8
 
 
 
 
880828c
ffe36d8
 
880828c
125cff5
5ad5f5e
 
125cff5
880828c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
import cv2

# Markup for the keypoint tool that the UI below embeds with gr.HTML.
# gr.HTML takes a single HTML string, not a list of lines, so read the whole
# file at once. The name `lines` is kept because the UI code refers to it.
# An explicit encoding avoids depending on the host's locale default.
with open("test.html", encoding="utf-8") as f:
    lines = f.read()

def create_key(seed=0):
    """Return a JAX PRNG key derived from the integer ``seed`` (default 0)."""
    key = jax.random.PRNGKey(seed)
    return key
def wandb_report(url):
    """Embed a web page (here, a Weights & Biases report) in the UI via an iframe.

    Args:
        url: address of the page to embed.

    Returns:
        A ``gr.HTML`` component whose value is the iframe markup.
    """
    import html

    # Quote and HTML-escape the URL so it forms a valid attribute value, and
    # close the element with a real ``</iframe>`` end tag. The previous
    # ``/frame>`` suffix left the iframe element unterminated.
    iframe = (
        f'<iframe src="{html.escape(url, quote=True)}" '
        'style="border:none;height:1024px;width:100%"></iframe>'
    )
    return gr.HTML(iframe)

# Run page passed to `wandb_report` when the UI is built.
report_url = 'https://wandb.ai/john-fozard/dog-cat-pose/runs/kmwcvae5'

# Load the pose ControlNet and the Stable Diffusion v1.5 pipeline wrapped
# around it, both in bfloat16. `controlnet_params` holds the ControlNet
# weights; `infer` installs them into the pipeline's `params` dict.
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    "JFoz/dog-cat-pose",
    dtype=jnp.bfloat16,
)
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    dtype=jnp.bfloat16,
    revision="flax",
    controlnet=controlnet,
)

def infer(prompts, negative_prompts, image, seed=0):
    """Generate images from a prompt, conditioned on a pose image.

    Args:
        prompts: text prompt (a single string, repeated once per sample).
        negative_prompts: negative text prompt, repeated the same way.
        image: conditioning image as an array accepted by
            ``PIL.Image.fromarray`` (presumably what ``gr.Image`` delivers;
            confirm).
        seed: integer seed for the JAX PRNG. Defaults to 0, the value that
            was previously hard-coded, so calls with the same inputs give the
            same output.

    Returns:
        A list of PIL images, one per sample (one sample per JAX device).

    Raises:
        gr.Error: if no conditioning image was supplied.
    """
    if image is None:
        # Show a readable message in the UI instead of a PIL traceback.
        raise gr.Error("Please provide a pose image.")

    params["controlnet"] = controlnet_params

    # `shard` reshapes the leading batch axis to (device_count, -1, ...), so
    # the batch must be a multiple of the device count. A fixed batch of 1
    # fails on multi-device hosts. On a single device this is still 1.
    num_samples = jax.device_count()
    rng = jax.random.split(create_key(seed), num_samples)
    pil_image = Image.fromarray(image)

    prompt_ids = shard(pipe.prepare_text_inputs([prompts] * num_samples))
    negative_prompt_ids = shard(
        pipe.prepare_text_inputs([negative_prompts] * num_samples)
    )
    processed_image = shard(pipe.prepare_image_inputs([pil_image] * num_samples))
    p_params = replicate(params)

    output = pipe(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=p_params,
        prng_seed=rng,
        num_inference_steps=50,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images

    # Merge the sharded leading axes into one batch axis and keep the last
    # three axes as the per-image dimensions.
    flat = np.asarray(output.reshape((num_samples,) + output.shape[-3:]))
    return pipe.numpy_to_pil(flat)


# Build the demo UI: the keypoint tool, prompt inputs, the conditioning
# image, a button that runs `infer`, an output gallery, and the W&B report.
with gr.Blocks(theme='kfahn/AnimalPose') as demo:
  gr.Markdown(
      """
      # Animal Pose Control Net
      ## This is a demo of Animal Pose ControlNet, which is a model trained on runwayml/stable-diffusion-v1-5 with new type of conditioning.
      [Dataset](https://huggingface.co/datasets/JFoz/dog-poses-controlnet-dataset)  
      [Diffusers model](https://huggingface.co/JFoz/dog-pose)  
      [Github](https://github.com/fi4cr/animalpose)   
      [Training Report](https://wandb.ai/john-fozard/dog-cat-pose/runs/kmwcvae5)
      """)
  with gr.Column():
    with gr.Row():
      # `lines` may be a list of lines or one string; joining covers both
      # and gives gr.HTML the single HTML string it expects.
      keypoint_tool = gr.HTML("".join(lines))
    with gr.Row():
      pos_prompts = gr.Textbox(label="Prompt")
    with gr.Row():
      neg_prompts = gr.Textbox(label="Negative Prompt")
    with gr.Row():
      image = gr.Image()
    with gr.Row():
      generate = gr.Button("Generate")
    with gr.Row():
      gallery = gr.Gallery(label="Output")
    with gr.Row():
      report = wandb_report(report_url)

  # Wire the inputs to `infer` inside this Blocks context. Previously a
  # standalone gr.Interface was created outside it and never launched, so
  # nothing in the launched UI could trigger generation.
  generate.click(
      fn=infer,
      inputs=[pos_prompts, neg_prompts, image],
      outputs=gallery,
  )

demo.launch()