RamAnanth1 commited on
Commit
6585503
1 Parent(s): e067e6a

Update app.py

Browse files

Add function to include human pose in control tasks

Files changed (1) hide show
  1. app.py +40 -1
app.py CHANGED
@@ -45,6 +45,8 @@ def process(input_image, prompt, input_control, a_prompt, n_prompt, num_samples,
45
  # TODO: Add other control tasks
46
  if input_control == "Scribble":
47
  return process_scribble(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta)
 
 
48
  return process_canny(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold)
49
 
50
  def process_canny(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold):
@@ -104,6 +106,42 @@ def process_scribble(input_image, prompt, a_prompt, n_prompt, num_samples, image
104
 
105
  results = [x_samples[i] for i in range(num_samples)]
106
  return [255 - detected_map] + results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
  def create_canvas(w, h):
109
  new_control_options = ["Interactive Scribble"]
@@ -113,7 +151,8 @@ def create_canvas(w, h):
113
  block = gr.Blocks().queue()
114
  control_task_list = [
115
  "Canny Edge Map",
116
- "Scribble"
 
117
  ]
118
  with block:
119
  gr.Markdown("## Adding Conditional Control to Text-to-Image Diffusion Models")
 
45
  # TODO: Add other control tasks
46
  if input_control == "Scribble":
47
  return process_scribble(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta)
48
+ elif input_control == "Pose":
49
+ return process_pose(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, image_resolution, ddim_steps, scale, seed, eta)
50
  return process_canny(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold)
51
 
52
  def process_canny(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold):
 
106
 
107
  results = [x_samples[i] for i in range(num_samples)]
108
  return [255 - detected_map] + results
109
+
110
+ def process_pose(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, scale, seed, eta):
111
+ with torch.no_grad():
112
+ input_image = HWC3(input_image)
113
+ detected_map, _ = apply_openpose(resize_image(input_image, detect_resolution))
114
+ detected_map = HWC3(detected_map)
115
+ img = resize_image(input_image, image_resolution)
116
+ H, W, C = img.shape
117
+
118
+ detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)
119
+
120
+ control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
121
+ control = torch.stack([control for _ in range(num_samples)], dim=0)
122
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
123
+
124
+ if seed == -1:
125
+ seed = random.randint(0, 65535)
126
+ seed_everything(seed)
127
+
128
+
129
+ cond = {"c_concat": [control], "c_crossattn": [pose_model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
130
+ un_cond = {"c_concat": [control], "c_crossattn": [pose_model.get_learned_conditioning([n_prompt] * num_samples)]}
131
+ shape = (4, H // 8, W // 8)
132
+
133
+
134
+ samples, intermediates = ddim_sampler_pose.sample(ddim_steps, num_samples,
135
+ shape, cond, verbose=False, eta=eta,
136
+ unconditional_guidance_scale=scale,
137
+ unconditional_conditioning=un_cond)
138
+
139
+
140
+ x_samples = pose_model.decode_first_stage(samples)
141
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
142
+
143
+ results = [x_samples[i] for i in range(num_samples)]
144
+ return [detected_map] + results
145
 
146
  def create_canvas(w, h):
147
  new_control_options = ["Interactive Scribble"]
 
151
  block = gr.Blocks().queue()
152
  control_task_list = [
153
  "Canny Edge Map",
154
+ "Scribble",
155
+ "Pose"
156
  ]
157
  with block:
158
  gr.Markdown("## Adding Conditional Control to Text-to-Image Diffusion Models")