Adapter committed on
Commit
aa0bbd7
·
1 Parent(s): b8fb5b9
Files changed (4) hide show
  1. app.py +4 -2
  2. demo/demos.py +26 -0
  3. demo/model.py +62 -1
  4. ldm/modules/encoders/adapter.py +0 -1
app.py CHANGED
@@ -8,14 +8,14 @@ os.system('mim install mmcv-full==1.7.0')
8
 
9
  from demo.model import Model_all
10
  import gradio as gr
11
- from demo.demos import create_demo_keypose, create_demo_sketch, create_demo_draw, create_demo_seg, create_demo_depth, create_demo_depth_keypose, create_demo_color, create_demo_color_sketch, create_demo_openpose, create_demo_style_sketch
12
  import torch
13
  import subprocess
14
  import shlex
15
  from huggingface_hub import hf_hub_url
16
 
17
  urls = {
18
- 'TencentARC/T2I-Adapter':['models/t2iadapter_keypose_sd14v1.pth', 'models/t2iadapter_color_sd14v1.pth', 'models/t2iadapter_openpose_sd14v1.pth', 'models/t2iadapter_seg_sd14v1.pth', 'models/t2iadapter_sketch_sd14v1.pth', 'models/t2iadapter_depth_sd14v1.pth','third-party-models/body_pose_model.pth', "models/t2iadapter_style_sd14v1.pth"],
19
  'CompVis/stable-diffusion-v-1-4-original':['sd-v1-4.ckpt'],
20
  'andite/anything-v4.0':['anything-v4.0-pruned.ckpt', 'anything-v4.0.vae.pt'],
21
  }
@@ -67,6 +67,8 @@ with gr.Blocks(css='style.css') as demo:
67
  create_demo_openpose(model.process_openpose)
68
  with gr.TabItem('Keypose'):
69
  create_demo_keypose(model.process_keypose)
 
 
70
  with gr.TabItem('Sketch'):
71
  create_demo_sketch(model.process_sketch)
72
  with gr.TabItem('Draw'):
 
8
 
9
  from demo.model import Model_all
10
  import gradio as gr
11
+ from demo.demos import create_demo_keypose, create_demo_sketch, create_demo_draw, create_demo_seg, create_demo_depth, create_demo_depth_keypose, create_demo_color, create_demo_color_sketch, create_demo_openpose, create_demo_style_sketch, create_demo_canny
12
  import torch
13
  import subprocess
14
  import shlex
15
  from huggingface_hub import hf_hub_url
16
 
17
  urls = {
18
+ 'TencentARC/T2I-Adapter':['models/t2iadapter_keypose_sd14v1.pth', 'models/t2iadapter_color_sd14v1.pth', 'models/t2iadapter_openpose_sd14v1.pth', 'models/t2iadapter_seg_sd14v1.pth', 'models/t2iadapter_sketch_sd14v1.pth', 'models/t2iadapter_depth_sd14v1.pth','third-party-models/body_pose_model.pth', "models/t2iadapter_style_sd14v1.pth", "models/t2iadapter_canny_sd14v1.pth"],
19
  'CompVis/stable-diffusion-v-1-4-original':['sd-v1-4.ckpt'],
20
  'andite/anything-v4.0':['anything-v4.0-pruned.ckpt', 'anything-v4.0.vae.pt'],
21
  }
 
67
  create_demo_openpose(model.process_openpose)
68
  with gr.TabItem('Keypose'):
69
  create_demo_keypose(model.process_keypose)
70
+ with gr.TabItem('Canny'):
71
+ create_demo_canny(model.process_canny)
72
  with gr.TabItem('Sketch'):
73
  create_demo_sketch(model.process_sketch)
74
  with gr.TabItem('Draw'):
demo/demos.py CHANGED
@@ -90,6 +90,32 @@ def create_demo_sketch(process):
90
  run_button.click(fn=process, inputs=ips, outputs=[result])
91
  return demo
92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  def create_demo_color_sketch(process):
94
  with gr.Blocks() as demo:
95
  with gr.Row():
 
90
  run_button.click(fn=process, inputs=ips, outputs=[result])
91
  return demo
92
 
93
def create_demo_canny(process):
    """Build the Gradio UI for the Canny-conditioned T2I-Adapter demo.

    The left column collects the input image, the prompts and the sampling
    options; the right column shows the output gallery. Pressing the button
    calls ``process`` with the widget values in this order: input image,
    input type, background colour, prompt, negative prompt, positive prompt,
    fix-sampling flag, guidance scale, controlling strength, base model.

    Args:
        process: Callable invoked on click; its result is shown in the gallery.

    Returns:
        The ``gr.Blocks`` object that contains the demo.
    """
    # NOTE(review): the nesting of the ``with`` blocks was reconstructed from a
    # diff view that had lost its indentation -- confirm against the file.
    default_negative = 'ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face'
    default_positive = 'crafted, elegant, meticulous, magnificent, maximum details, extremely hyper aesthetic, intricately detailed'

    with gr.Blocks() as demo:
        # Title row.
        with gr.Row():
            gr.Markdown('## T2I-Adapter (Canny)')

        with gr.Row():
            # Left column: inputs and sampling controls.
            with gr.Column():
                image_box = gr.Image(source='upload', type="numpy")
                prompt_box = gr.Textbox(label="Prompt")
                negative_box = gr.Textbox(label="Negative Prompt", value=default_negative)
                positive_box = gr.Textbox(label="Positive Prompt", value=default_positive)
                with gr.Row():
                    input_kind = gr.inputs.Radio(
                        ['Canny', 'Image'],
                        type="value",
                        default='Image',
                        label='Input Types\n (You can input an image or a canny map)',
                    )
                    background = gr.inputs.Radio(
                        ['White', 'Black'],
                        type="value",
                        default='Black',
                        label='Color of the canny background\n (Only work for canny input)',
                    )
                run_button = gr.Button(label="Run")
                strength_slider = gr.Slider(
                    label="Controling Strength (The guidance strength of the canny to the result)",
                    minimum=0,
                    maximum=1,
                    value=1,
                    step=0.1,
                )
                scale_slider = gr.Slider(
                    label="Guidance Scale (Classifier free guidance)",
                    minimum=0.1,
                    maximum=30.0,
                    value=7.5,
                    step=0.1,
                )
                fixed_seed = gr.inputs.Radio(
                    ['True', 'False'],
                    type="value",
                    default='False',
                    label='Fix Sampling\n (Fix the random seed)',
                )
                checkpoint = gr.inputs.Radio(
                    ['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'],
                    type="value",
                    default='sd-v1-4.ckpt',
                    label='The base model you want to use',
                )
            # Right column: output gallery.
            with gr.Column():
                gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery")
                result = gallery.style(grid=2, height='auto')

        # The order here must match the positional parameters of ``process``.
        run_button.click(
            fn=process,
            inputs=[
                image_box,
                input_kind,
                background,
                prompt_box,
                negative_box,
                positive_box,
                fixed_seed,
                scale_slider,
                strength_slider,
                checkpoint,
            ],
            outputs=[result],
        )
    return demo
118
+
119
  def create_demo_color_sketch(process):
120
  with gr.Blocks() as demo:
121
  with gr.Row():
demo/model.py CHANGED
@@ -74,7 +74,6 @@ def imshow_keypoints(img,
74
  if idx > 1:
75
  continue
76
  kpts = kpts['keypoints']
77
- # print(kpts)
78
  kpts = np.array(kpts, copy=False)
79
 
80
  # draw each point on image
@@ -138,6 +137,9 @@ class Model_all:
138
  self.sampler = PLMSSampler(self.base_model)
139
 
140
  # sketch part
 
 
 
141
  self.model_sketch = Adapter(channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
142
  use_conv=False).to(device)
143
  self.model_sketch.load_state_dict(torch.load("models/t2iadapter_sketch_sd14v1.pth", map_location=device))
@@ -277,6 +279,65 @@ class Model_all:
277
 
278
  return [im_edge, x_samples_ddim]
279
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
  @torch.no_grad()
281
  def process_color_sketch(self, input_img_sketch, input_img_color, type_in, type_in_color, w_sketch, w_color, color_back, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
282
  if self.current_base != base_model:
 
74
  if idx > 1:
75
  continue
76
  kpts = kpts['keypoints']
 
77
  kpts = np.array(kpts, copy=False)
78
 
79
  # draw each point on image
 
137
  self.sampler = PLMSSampler(self.base_model)
138
 
139
  # sketch part
140
+ self.model_canny = Adapter(channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
141
+ use_conv=False).to(device)
142
+ self.model_canny.load_state_dict(torch.load("models/t2iadapter_canny_sd14v1.pth", map_location=device))
143
  self.model_sketch = Adapter(channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
144
  use_conv=False).to(device)
145
  self.model_sketch.load_state_dict(torch.load("models/t2iadapter_sketch_sd14v1.pth", map_location=device))
 
279
 
280
  return [im_edge, x_samples_ddim]
281
 
282
@torch.no_grad()
def process_canny(self, input_img, type_in, color_back, prompt, neg_prompt, pos_prompt, fix_sample, scale,
                  con_strength, base_model):
    """Generate an image conditioned on a Canny edge map.

    Args:
        input_img: Input image array; it is resized to 512x512 with
            ``cv2.resize``. Must not be ``None``.
        type_in: ``'Canny'`` if ``input_img`` already is an edge map, or
            ``'Image'`` to run ``cv2.Canny`` on it first.
        color_back: ``'White'`` inverts a supplied edge map (``255 - im``);
            only consulted when ``type_in == 'Canny'``.
        prompt: Text prompt; ``pos_prompt`` is appended to it after ``', '``.
        neg_prompt: Text used as the unconditional conditioning.
        pos_prompt: Extra positive text appended to ``prompt``.
        fix_sample: The string ``'True'`` seeds the RNGs with 42.
        scale: Passed to the sampler as ``unconditional_guidance_scale``.
        con_strength: Number in [0, 1] from the UI slider; converted to a
            step index out of the sampling steps before reaching the sampler.
        base_model: Checkpoint file name under ``models/``; loaded only when
            it differs from ``self.current_base``.

    Returns:
        ``[im_edge, x_samples_ddim]``: the edge map used as the condition and
        the generated image as a ``uint8`` array in height-width-channel
        order.

    Raises:
        ValueError: If ``input_img`` is ``None`` or ``type_in`` is neither
            ``'Canny'`` nor ``'Image'``.
    """
    # Validate before the (expensive) checkpoint switch so bad requests fail
    # fast and with a readable message instead of inside cv2 / further down.
    if input_img is None:
        raise ValueError("No input image was provided for the Canny adapter.")
    if type_in not in ('Canny', 'Image'):
        raise ValueError(f"Unsupported input type {type_in!r}; expected 'Canny' or 'Image'.")

    if self.current_base != base_model:
        ckpt = os.path.join("models", base_model)
        # Load onto the device the model already uses rather than assuming CUDA.
        pl_sd = torch.load(ckpt, map_location=self.device)
        # Some checkpoints wrap the weights in a "state_dict" entry.
        sd = pl_sd.get("state_dict", pl_sd)
        self.base_model.load_state_dict(sd, strict=False)
        self.current_base = base_model
        if 'anything' in base_model.lower():
            self.load_vae()

    num_steps = 50
    # round(), not int(): (1 - 0.9) * 50 evaluates to 4.999... in floating
    # point and would otherwise be truncated to 4 instead of 5.
    con_strength = round((1 - con_strength) * num_steps)
    if fix_sample == 'True':
        seed_everything(42)
    im = cv2.resize(input_img, (512, 512))

    if type_in == 'Canny':
        if color_back == 'White':
            im = 255 - im
        im_edge = im.copy()
        # Keep a single channel and add batch/channel dims, scaled by 1/255.
        im = img2tensor(im)[0].unsqueeze(0).unsqueeze(0) / 255.
    else:  # type_in == 'Image'
        im = cv2.Canny(im, 100, 200)
        im = img2tensor(im[..., None], bgr2rgb=True, float32=True).unsqueeze(0) / 255.
        im_edge = tensor2img(im)

    # extract condition features
    c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
    nc = self.base_model.get_learned_conditioning([neg_prompt])
    features_adapter = self.model_canny(im.to(self.device))
    shape = [4, 64, 64]

    # sampling
    # NOTE(review): mode='sketch' is passed for the canny adapter as well;
    # presumably both share the same sampler path -- confirm.
    samples_ddim, _ = self.sampler.sample(S=num_steps,
                                          conditioning=c,
                                          batch_size=1,
                                          shape=shape,
                                          verbose=False,
                                          unconditional_guidance_scale=scale,
                                          unconditional_conditioning=nc,
                                          eta=0.0,
                                          x_T=None,
                                          features_adapter1=features_adapter,
                                          mode='sketch',
                                          con_strength=con_strength)

    # Decode, map from [-1, 1] to [0, 1], then convert the first sample to a
    # uint8 height-width-channel array.
    x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
    x_samples_ddim = x_samples_ddim.to('cpu')
    x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
    x_samples_ddim = 255. * x_samples_ddim
    x_samples_ddim = x_samples_ddim.astype(np.uint8)

    return [im_edge, x_samples_ddim]
340
+
341
  @torch.no_grad()
342
  def process_color_sketch(self, input_img_sketch, input_img_color, type_in, type_in_color, w_sketch, w_color, color_back, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
343
  if self.current_base != base_model:
ldm/modules/encoders/adapter.py CHANGED
@@ -64,7 +64,6 @@ class ResnetBlock(nn.Module):
64
  if in_c != out_c or sk==False:
65
  self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, ps)
66
  else:
67
- # print('n_in')
68
  self.in_conv = None
69
  self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1)
70
  self.act = nn.ReLU()
 
64
  if in_c != out_c or sk==False:
65
  self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, ps)
66
  else:
 
67
  self.in_conv = None
68
  self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1)
69
  self.act = nn.ReLU()