lmattingly13 commited on
Commit
049b8a4
·
1 Parent(s): 4b5247c

added resize method, now large images are working

Browse files
Files changed (1) hide show
  1. app.py +18 -2
app.py CHANGED
@@ -8,6 +8,8 @@ from PIL import Image
8
  from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
9
  import cv2
10
 
 
 
11
  title = "ControlNet for Cartoon-ifying"
12
  description = "This is a demo on ControlNet for changing images of people into cartoons of different styles."
13
  examples = [["./simpsons_human_1.jpg", "turn into a simpsons character", "./simpsons_animated_1.jpg"]]
@@ -30,17 +32,31 @@ pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
30
  "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.bfloat16
31
  )
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  def create_key(seed=0):
34
  return jax.random.PRNGKey(seed)
35
 
36
  def infer(prompts, image):
37
  params["controlnet"] = controlnet_params
38
 
 
39
  num_samples = 1 #jax.device_count()
40
  rng = create_key(0)
41
  rng = jax.random.split(rng, jax.device_count())
42
- im = image
43
- image = Image.fromarray(im)
44
 
45
  prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
46
  processed_image = pipe.prepare_image_inputs([image] * num_samples)
 
8
  from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
9
  import cv2
10
 
11
+
12
+
13
  title = "ControlNet for Cartoon-ifying"
14
  description = "This is a demo on ControlNet for changing images of people into cartoons of different styles."
15
  examples = [["./simpsons_human_1.jpg", "turn into a simpsons character", "./simpsons_animated_1.jpg"]]
 
32
  "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.bfloat16
33
  )
34
 
35
+ def resize_image(im, max_size):
36
+ im_np = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
37
+
38
+ height, width = im_np.shape[:2]
39
+
40
+ scale_factor = max_size / max(height, width)
41
+
42
+ resized_np = cv2.resize(im_np, (int(width * scale_factor), int(height * scale_factor)))
43
+
44
+ resized_im = Image.fromarray(resized_np)
45
+
46
+ return resized_im
47
+
48
  def create_key(seed=0):
49
  return jax.random.PRNGKey(seed)
50
 
51
  def infer(prompts, image):
52
  params["controlnet"] = controlnet_params
53
 
54
+ image = resize_image(image, 500)
55
  num_samples = 1 #jax.device_count()
56
  rng = create_key(0)
57
  rng = jax.random.split(rng, jax.device_count())
58
+ #im = image
59
+ #image = Image.fromarray(im)
60
 
61
  prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
62
  processed_image = pipe.prepare_image_inputs([image] * num_samples)