Ming Li commited on
Commit
b5515fe
·
1 Parent(s): 8fad46e
app.py CHANGED
@@ -7,15 +7,8 @@ import torch
7
 
8
  from app_canny import create_demo as create_demo_canny
9
  from app_depth import create_demo as create_demo_depth
10
- from app_ip2p import create_demo as create_demo_ip2p
11
  from app_lineart import create_demo as create_demo_lineart
12
- from app_mlsd import create_demo as create_demo_mlsd
13
- from app_normal import create_demo as create_demo_normal
14
- from app_openpose import create_demo as create_demo_openpose
15
- from app_scribble import create_demo as create_demo_scribble
16
- from app_scribble_interactive import create_demo as create_demo_scribble_interactive
17
  from app_segmentation import create_demo as create_demo_segmentation
18
- from app_shuffle import create_demo as create_demo_shuffle
19
  from app_softedge import create_demo as create_demo_softedge
20
  from model import Model
21
  from settings import ALLOW_CHANGING_BASE_MODEL, DEFAULT_MODEL_ID, SHOW_DUPLICATE_BUTTON
 
7
 
8
  from app_canny import create_demo as create_demo_canny
9
  from app_depth import create_demo as create_demo_depth
 
10
  from app_lineart import create_demo as create_demo_lineart
 
 
 
 
 
11
  from app_segmentation import create_demo as create_demo_segmentation
 
12
  from app_softedge import create_demo as create_demo_softedge
13
  from model import Model
14
  from settings import ALLOW_CHANGING_BASE_MODEL, DEFAULT_MODEL_ID, SHOW_DUPLICATE_BUTTON
app_canny.py CHANGED
@@ -16,8 +16,8 @@ def create_demo(process):
16
  with gr.Blocks() as demo:
17
  with gr.Row():
18
  with gr.Column():
19
- image = gr.Image()
20
- prompt = gr.Textbox(label="Prompt")
21
  run_button = gr.Button("Run")
22
  with gr.Accordion("Advanced options", open=False):
23
  num_samples = gr.Slider(
@@ -31,10 +31,10 @@ def create_demo(process):
31
  step=256,
32
  )
33
  canny_low_threshold = gr.Slider(
34
- label="Canny low threshold", minimum=1, maximum=255, value=100, step=1
35
  )
36
  canny_high_threshold = gr.Slider(
37
- label="Canny high threshold", minimum=1, maximum=255, value=200, step=1
38
  )
39
  num_steps = gr.Slider(label="Number of steps", minimum=1, maximum=100, value=20, step=1)
40
  guidance_scale = gr.Slider(label="Guidance scale", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
 
16
  with gr.Blocks() as demo:
17
  with gr.Row():
18
  with gr.Column():
19
+ image = gr.Image(value='images/canny_demo.jpg')
20
+ prompt = gr.Textbox(label="Prompt", value='BEAUTIFUL PORTRAIT PAINTINGS BY EMMA UBER')
21
  run_button = gr.Button("Run")
22
  with gr.Accordion("Advanced options", open=False):
23
  num_samples = gr.Slider(
 
31
  step=256,
32
  )
33
  canny_low_threshold = gr.Slider(
34
+ label="Canny low threshold", minimum=0, maximum=1.0, value=0.1, step=0.05
35
  )
36
  canny_high_threshold = gr.Slider(
37
+ label="Canny high threshold", minimum=0, maximum=1.0, value=0.2, step=0.05
38
  )
39
  num_steps = gr.Slider(label="Number of steps", minimum=1, maximum=100, value=20, step=1)
40
  guidance_scale = gr.Slider(label="Guidance scale", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
app_ip2p.py DELETED
@@ -1,87 +0,0 @@
1
- #!/usr/bin/env python
2
-
3
- import gradio as gr
4
-
5
- from settings import (
6
- DEFAULT_IMAGE_RESOLUTION,
7
- DEFAULT_NUM_IMAGES,
8
- MAX_IMAGE_RESOLUTION,
9
- MAX_NUM_IMAGES,
10
- MAX_SEED,
11
- )
12
- from utils import randomize_seed_fn
13
-
14
-
15
- def create_demo(process):
16
- with gr.Blocks() as demo:
17
- with gr.Row():
18
- with gr.Column():
19
- image = gr.Image()
20
- prompt = gr.Textbox(label="Prompt")
21
- run_button = gr.Button("Run")
22
- with gr.Accordion("Advanced options", open=False):
23
- num_samples = gr.Slider(
24
- label="Number of images", minimum=1, maximum=MAX_NUM_IMAGES, value=DEFAULT_NUM_IMAGES, step=1
25
- )
26
- image_resolution = gr.Slider(
27
- label="Image resolution",
28
- minimum=256,
29
- maximum=MAX_IMAGE_RESOLUTION,
30
- value=DEFAULT_IMAGE_RESOLUTION,
31
- step=256,
32
- )
33
- num_steps = gr.Slider(label="Number of steps", minimum=1, maximum=100, value=20, step=1)
34
- guidance_scale = gr.Slider(label="Guidance scale", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
35
- seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
36
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
37
- a_prompt = gr.Textbox(label="Additional prompt", value="high-quality, extremely detailed, 4K")
38
- n_prompt = gr.Textbox(
39
- label="Negative prompt",
40
- value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
41
- )
42
- with gr.Column():
43
- result = gr.Gallery(label="Output", show_label=False, columns=4, object_fit="scale-down")
44
- inputs = [
45
- image,
46
- prompt,
47
- a_prompt,
48
- n_prompt,
49
- num_samples,
50
- image_resolution,
51
- num_steps,
52
- guidance_scale,
53
- seed,
54
- ]
55
- prompt.submit(
56
- fn=randomize_seed_fn,
57
- inputs=[seed, randomize_seed],
58
- outputs=seed,
59
- queue=False,
60
- api_name=False,
61
- ).then(
62
- fn=process,
63
- inputs=inputs,
64
- outputs=result,
65
- api_name=False,
66
- )
67
- run_button.click(
68
- fn=randomize_seed_fn,
69
- inputs=[seed, randomize_seed],
70
- outputs=seed,
71
- queue=False,
72
- api_name=False,
73
- ).then(
74
- fn=process,
75
- inputs=inputs,
76
- outputs=result,
77
- api_name="ip2p",
78
- )
79
- return demo
80
-
81
-
82
- if __name__ == "__main__":
83
- from model import Model
84
-
85
- model = Model(task_name="ip2p")
86
- demo = create_demo(model.process_ip2p)
87
- demo.queue().launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app_mlsd.py DELETED
@@ -1,99 +0,0 @@
1
- #!/usr/bin/env python
2
-
3
- import gradio as gr
4
-
5
- from settings import (
6
- DEFAULT_IMAGE_RESOLUTION,
7
- DEFAULT_NUM_IMAGES,
8
- MAX_IMAGE_RESOLUTION,
9
- MAX_NUM_IMAGES,
10
- MAX_SEED,
11
- )
12
- from utils import randomize_seed_fn
13
-
14
-
15
- def create_demo(process):
16
- with gr.Blocks() as demo:
17
- with gr.Row():
18
- with gr.Column():
19
- image = gr.Image()
20
- prompt = gr.Textbox(label="Prompt")
21
- run_button = gr.Button("Run")
22
- with gr.Accordion("Advanced options", open=False):
23
- num_samples = gr.Slider(
24
- label="Number of images", minimum=1, maximum=MAX_NUM_IMAGES, value=DEFAULT_NUM_IMAGES, step=1
25
- )
26
- image_resolution = gr.Slider(
27
- label="Image resolution",
28
- minimum=256,
29
- maximum=MAX_IMAGE_RESOLUTION,
30
- value=DEFAULT_IMAGE_RESOLUTION,
31
- step=256,
32
- )
33
- preprocess_resolution = gr.Slider(
34
- label="Preprocess resolution", minimum=128, maximum=512, value=512, step=1
35
- )
36
- mlsd_value_threshold = gr.Slider(
37
- label="Hough value threshold (MLSD)", minimum=0.01, maximum=2.0, value=0.1, step=0.01
38
- )
39
- mlsd_distance_threshold = gr.Slider(
40
- label="Hough distance threshold (MLSD)", minimum=0.01, maximum=20.0, value=0.1, step=0.01
41
- )
42
- num_steps = gr.Slider(label="Number of steps", minimum=1, maximum=100, value=20, step=1)
43
- guidance_scale = gr.Slider(label="Guidance scale", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
44
- seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
45
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
46
- a_prompt = gr.Textbox(label="Additional prompt", value="high-quality, extremely detailed, 4K")
47
- n_prompt = gr.Textbox(
48
- label="Negative prompt",
49
- value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
50
- )
51
- with gr.Column():
52
- result = gr.Gallery(label="Output", show_label=False, columns=4, object_fit="scale-down")
53
- inputs = [
54
- image,
55
- prompt,
56
- a_prompt,
57
- n_prompt,
58
- num_samples,
59
- image_resolution,
60
- preprocess_resolution,
61
- num_steps,
62
- guidance_scale,
63
- seed,
64
- mlsd_value_threshold,
65
- mlsd_distance_threshold,
66
- ]
67
- prompt.submit(
68
- fn=randomize_seed_fn,
69
- inputs=[seed, randomize_seed],
70
- outputs=seed,
71
- queue=False,
72
- api_name=False,
73
- ).then(
74
- fn=process,
75
- inputs=inputs,
76
- outputs=result,
77
- api_name=False,
78
- )
79
- run_button.click(
80
- fn=randomize_seed_fn,
81
- inputs=[seed, randomize_seed],
82
- outputs=seed,
83
- queue=False,
84
- api_name=False,
85
- ).then(
86
- fn=process,
87
- inputs=inputs,
88
- outputs=result,
89
- api_name="mlsd",
90
- )
91
- return demo
92
-
93
-
94
- if __name__ == "__main__":
95
- from model import Model
96
-
97
- model = Model(task_name="MLSD")
98
- demo = create_demo(model.process_mlsd)
99
- demo.queue().launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app_normal.py DELETED
@@ -1,95 +0,0 @@
1
- #!/usr/bin/env python
2
-
3
- import gradio as gr
4
-
5
- from settings import (
6
- DEFAULT_IMAGE_RESOLUTION,
7
- DEFAULT_NUM_IMAGES,
8
- MAX_IMAGE_RESOLUTION,
9
- MAX_NUM_IMAGES,
10
- MAX_SEED,
11
- )
12
- from utils import randomize_seed_fn
13
-
14
-
15
- def create_demo(process):
16
- with gr.Blocks() as demo:
17
- with gr.Row():
18
- with gr.Column():
19
- image = gr.Image()
20
- prompt = gr.Textbox(label="Prompt")
21
- run_button = gr.Button("Run")
22
- with gr.Accordion("Advanced options", open=False):
23
- preprocessor_name = gr.Radio(
24
- label="Preprocessor", choices=["NormalBae", "None"], type="value", value="NormalBae"
25
- )
26
- num_samples = gr.Slider(
27
- label="Images", minimum=1, maximum=MAX_NUM_IMAGES, value=DEFAULT_NUM_IMAGES, step=1
28
- )
29
- image_resolution = gr.Slider(
30
- label="Image resolution",
31
- minimum=256,
32
- maximum=MAX_IMAGE_RESOLUTION,
33
- value=DEFAULT_IMAGE_RESOLUTION,
34
- step=256,
35
- )
36
- preprocess_resolution = gr.Slider(
37
- label="Preprocess resolution", minimum=128, maximum=512, value=384, step=1
38
- )
39
- num_steps = gr.Slider(label="Number of steps", minimum=1, maximum=100, value=20, step=1)
40
- guidance_scale = gr.Slider(label="Guidance scale", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
41
- seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
42
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
43
- a_prompt = gr.Textbox(label="Additional prompt", value="high-quality, extremely detailed, 4K")
44
- n_prompt = gr.Textbox(
45
- label="Negative prompt",
46
- value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
47
- )
48
- with gr.Column():
49
- result = gr.Gallery(label="Output", show_label=False, columns=4, object_fit="scale-down")
50
- inputs = [
51
- image,
52
- prompt,
53
- a_prompt,
54
- n_prompt,
55
- num_samples,
56
- image_resolution,
57
- preprocess_resolution,
58
- num_steps,
59
- guidance_scale,
60
- seed,
61
- preprocessor_name,
62
- ]
63
- prompt.submit(
64
- fn=randomize_seed_fn,
65
- inputs=[seed, randomize_seed],
66
- outputs=seed,
67
- queue=False,
68
- api_name=False,
69
- ).then(
70
- fn=process,
71
- inputs=inputs,
72
- outputs=result,
73
- api_name=False,
74
- )
75
- run_button.click(
76
- fn=randomize_seed_fn,
77
- inputs=[seed, randomize_seed],
78
- outputs=seed,
79
- queue=False,
80
- api_name=False,
81
- ).then(
82
- fn=process,
83
- inputs=inputs,
84
- outputs=result,
85
- api_name="normal",
86
- )
87
- return demo
88
-
89
-
90
- if __name__ == "__main__":
91
- from model import Model
92
-
93
- model = Model(task_name="NormalBae")
94
- demo = create_demo(model.process_normal)
95
- demo.queue().launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app_scribble.py DELETED
@@ -1,95 +0,0 @@
1
- #!/usr/bin/env python
2
-
3
- import gradio as gr
4
-
5
- from settings import (
6
- DEFAULT_IMAGE_RESOLUTION,
7
- DEFAULT_NUM_IMAGES,
8
- MAX_IMAGE_RESOLUTION,
9
- MAX_NUM_IMAGES,
10
- MAX_SEED,
11
- )
12
- from utils import randomize_seed_fn
13
-
14
-
15
- def create_demo(process):
16
- with gr.Blocks() as demo:
17
- with gr.Row():
18
- with gr.Column():
19
- image = gr.Image()
20
- prompt = gr.Textbox(label="Prompt")
21
- run_button = gr.Button("Run")
22
- with gr.Accordion("Advanced options", open=False):
23
- preprocessor_name = gr.Radio(
24
- label="Preprocessor", choices=["HED", "PidiNet", "None"], type="value", value="HED"
25
- )
26
- num_samples = gr.Slider(
27
- label="Number of images", minimum=1, maximum=MAX_NUM_IMAGES, value=DEFAULT_NUM_IMAGES, step=1
28
- )
29
- image_resolution = gr.Slider(
30
- label="Image resolution",
31
- minimum=256,
32
- maximum=MAX_IMAGE_RESOLUTION,
33
- value=DEFAULT_IMAGE_RESOLUTION,
34
- step=256,
35
- )
36
- preprocess_resolution = gr.Slider(
37
- label="Preprocess resolution", minimum=128, maximum=512, value=512, step=1
38
- )
39
- num_steps = gr.Slider(label="Number of steps", minimum=1, maximum=100, value=20, step=1)
40
- guidance_scale = gr.Slider(label="Guidance scale", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
41
- seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
42
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
43
- a_prompt = gr.Textbox(label="Additional prompt", value="high-quality, extremely detailed, 4K")
44
- n_prompt = gr.Textbox(
45
- label="Negative prompt",
46
- value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
47
- )
48
- with gr.Column():
49
- result = gr.Gallery(label="Output", show_label=False, columns=4, object_fit="scale-down")
50
- inputs = [
51
- image,
52
- prompt,
53
- a_prompt,
54
- n_prompt,
55
- num_samples,
56
- image_resolution,
57
- preprocess_resolution,
58
- num_steps,
59
- guidance_scale,
60
- seed,
61
- preprocessor_name,
62
- ]
63
- prompt.submit(
64
- fn=randomize_seed_fn,
65
- inputs=[seed, randomize_seed],
66
- outputs=seed,
67
- queue=False,
68
- api_name=False,
69
- ).then(
70
- fn=process,
71
- inputs=inputs,
72
- outputs=result,
73
- api_name=False,
74
- )
75
- run_button.click(
76
- fn=randomize_seed_fn,
77
- inputs=[seed, randomize_seed],
78
- outputs=seed,
79
- queue=False,
80
- api_name=False,
81
- ).then(
82
- fn=process,
83
- inputs=inputs,
84
- outputs=result,
85
- api_name="scribble",
86
- )
87
- return demo
88
-
89
-
90
- if __name__ == "__main__":
91
- from model import Model
92
-
93
- model = Model(task_name="scribble")
94
- demo = create_demo(model.process_scribble)
95
- demo.queue().launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app_scribble_interactive.py DELETED
@@ -1,115 +0,0 @@
1
- #!/usr/bin/env python
2
-
3
- import gradio as gr
4
- import numpy as np
5
-
6
- from settings import (
7
- DEFAULT_IMAGE_RESOLUTION,
8
- DEFAULT_NUM_IMAGES,
9
- MAX_IMAGE_RESOLUTION,
10
- MAX_NUM_IMAGES,
11
- MAX_SEED,
12
- )
13
- from utils import randomize_seed_fn
14
-
15
-
16
- def create_canvas(w, h):
17
- return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255
18
-
19
-
20
- def create_demo(process):
21
- with gr.Blocks() as demo:
22
- with gr.Row():
23
- with gr.Column():
24
- canvas_width = gr.Slider(
25
- label="Canvas width",
26
- minimum=256,
27
- maximum=MAX_IMAGE_RESOLUTION,
28
- value=DEFAULT_IMAGE_RESOLUTION,
29
- step=1,
30
- )
31
- canvas_height = gr.Slider(
32
- label="Canvas height",
33
- minimum=256,
34
- maximum=MAX_IMAGE_RESOLUTION,
35
- value=DEFAULT_IMAGE_RESOLUTION,
36
- step=1,
37
- )
38
- create_button = gr.Button("Open drawing canvas!")
39
- image = gr.Image(tool="sketch", brush_radius=10)
40
- prompt = gr.Textbox(label="Prompt")
41
- run_button = gr.Button("Run")
42
- with gr.Accordion("Advanced options", open=False):
43
- num_samples = gr.Slider(
44
- label="Number of images", minimum=1, maximum=MAX_NUM_IMAGES, value=DEFAULT_NUM_IMAGES, step=1
45
- )
46
- image_resolution = gr.Slider(
47
- label="Image resolution",
48
- minimum=256,
49
- maximum=MAX_IMAGE_RESOLUTION,
50
- value=DEFAULT_IMAGE_RESOLUTION,
51
- step=256,
52
- )
53
- num_steps = gr.Slider(label="Number of steps", minimum=1, maximum=100, value=20, step=1)
54
- guidance_scale = gr.Slider(label="Guidance scale", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
55
- seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
56
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
57
- a_prompt = gr.Textbox(label="Additional prompt", value="high-quality, extremely detailed, 4K")
58
- n_prompt = gr.Textbox(
59
- label="Negative prompt",
60
- value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
61
- )
62
- with gr.Column():
63
- result = gr.Gallery(label="Output", show_label=False, columns=4, object_fit="scale-down")
64
-
65
- create_button.click(
66
- fn=create_canvas,
67
- inputs=[canvas_width, canvas_height],
68
- outputs=image,
69
- queue=False,
70
- api_name=False,
71
- )
72
-
73
- inputs = [
74
- image,
75
- prompt,
76
- a_prompt,
77
- n_prompt,
78
- num_samples,
79
- image_resolution,
80
- num_steps,
81
- guidance_scale,
82
- seed,
83
- ]
84
- prompt.submit(
85
- fn=randomize_seed_fn,
86
- inputs=[seed, randomize_seed],
87
- outputs=seed,
88
- queue=False,
89
- api_name=False,
90
- ).then(
91
- fn=process,
92
- inputs=inputs,
93
- outputs=result,
94
- api_name=False,
95
- )
96
- run_button.click(
97
- fn=randomize_seed_fn,
98
- inputs=[seed, randomize_seed],
99
- outputs=seed,
100
- queue=False,
101
- api_name=False,
102
- ).then(
103
- fn=process,
104
- inputs=inputs,
105
- outputs=result,
106
- )
107
- return demo
108
-
109
-
110
- if __name__ == "__main__":
111
- from model import Model
112
-
113
- model = Model(task_name="scribble")
114
- demo = create_demo(model.process_scribble_interactive)
115
- demo.queue().launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app_segmentation.py CHANGED
@@ -17,7 +17,7 @@ def create_demo(process):
17
  with gr.Row():
18
  with gr.Column():
19
  image = gr.Image(value='images/seg_demo.png')
20
- prompt = gr.Textbox(label="Prompt", value='A large building with a pointed roof and several chimneys.')
21
  run_button = gr.Button("Run")
22
  with gr.Accordion("Advanced options", open=False):
23
  preprocessor_name = gr.Radio(
 
17
  with gr.Row():
18
  with gr.Column():
19
  image = gr.Image(value='images/seg_demo.png')
20
+ prompt = gr.Textbox(label="Prompt", value='A large building with a pointed roof and several chimneys')
21
  run_button = gr.Button("Run")
22
  with gr.Accordion("Advanced options", open=False):
23
  preprocessor_name = gr.Radio(
app_shuffle.py DELETED
@@ -1,91 +0,0 @@
1
- #!/usr/bin/env python
2
-
3
- import gradio as gr
4
-
5
- from settings import (
6
- DEFAULT_IMAGE_RESOLUTION,
7
- DEFAULT_NUM_IMAGES,
8
- MAX_IMAGE_RESOLUTION,
9
- MAX_NUM_IMAGES,
10
- MAX_SEED,
11
- )
12
- from utils import randomize_seed_fn
13
-
14
-
15
- def create_demo(process):
16
- with gr.Blocks() as demo:
17
- with gr.Row():
18
- with gr.Column():
19
- image = gr.Image()
20
- prompt = gr.Textbox(label="Prompt")
21
- run_button = gr.Button("Run")
22
- with gr.Accordion("Advanced options", open=False):
23
- preprocessor_name = gr.Radio(
24
- label="Preprocessor", choices=["ContentShuffle", "None"], type="value", value="ContentShuffle"
25
- )
26
- num_samples = gr.Slider(
27
- label="Number of images", minimum=1, maximum=MAX_NUM_IMAGES, value=DEFAULT_NUM_IMAGES, step=1
28
- )
29
- image_resolution = gr.Slider(
30
- label="Image resolution",
31
- minimum=256,
32
- maximum=MAX_IMAGE_RESOLUTION,
33
- value=DEFAULT_IMAGE_RESOLUTION,
34
- step=256,
35
- )
36
- num_steps = gr.Slider(label="Number of steps", minimum=1, maximum=100, value=20, step=1)
37
- guidance_scale = gr.Slider(label="Guidance scale", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
38
- seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
39
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
40
- a_prompt = gr.Textbox(label="Additional prompt", value="high-quality, extremely detailed, 4K")
41
- n_prompt = gr.Textbox(
42
- label="Negative prompt",
43
- value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
44
- )
45
- with gr.Column():
46
- result = gr.Gallery(label="Output", show_label=False, columns=4, object_fit="scale-down")
47
- inputs = [
48
- image,
49
- prompt,
50
- a_prompt,
51
- n_prompt,
52
- num_samples,
53
- image_resolution,
54
- num_steps,
55
- guidance_scale,
56
- seed,
57
- preprocessor_name,
58
- ]
59
- prompt.submit(
60
- fn=randomize_seed_fn,
61
- inputs=[seed, randomize_seed],
62
- outputs=seed,
63
- queue=False,
64
- api_name=False,
65
- ).then(
66
- fn=process,
67
- inputs=inputs,
68
- outputs=result,
69
- api_name=False,
70
- )
71
- run_button.click(
72
- fn=randomize_seed_fn,
73
- inputs=[seed, randomize_seed],
74
- outputs=seed,
75
- queue=False,
76
- api_name=False,
77
- ).then(
78
- fn=process,
79
- inputs=inputs,
80
- outputs=result,
81
- api_name="content-shuffle",
82
- )
83
- return demo
84
-
85
-
86
- if __name__ == "__main__":
87
- from model import Model
88
-
89
- model = Model(task_name="shuffle")
90
- demo = create_demo(model.process_shuffle)
91
- demo.queue().launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
checkpoints/canny/controlnet/config.json ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "ControlNetModel",
3
+ "_diffusers_version": "0.26.3",
4
+ "_name_or_path": "work_dirs/finetune/MultiGen20M_canny/ft_controlnet_sd15_canny_res512_bs256_lr1e-5_warmup100_iter5k_fp16ft0-1000/checkpoint-5000",
5
+ "act_fn": "silu",
6
+ "addition_embed_type": null,
7
+ "addition_embed_type_num_heads": 64,
8
+ "addition_time_embed_dim": null,
9
+ "attention_head_dim": 8,
10
+ "block_out_channels": [
11
+ 320,
12
+ 640,
13
+ 1280,
14
+ 1280
15
+ ],
16
+ "class_embed_type": null,
17
+ "conditioning_channels": 3,
18
+ "conditioning_embedding_out_channels": [
19
+ 16,
20
+ 32,
21
+ 96,
22
+ 256
23
+ ],
24
+ "controlnet_conditioning_channel_order": "rgb",
25
+ "cross_attention_dim": 768,
26
+ "down_block_types": [
27
+ "CrossAttnDownBlock2D",
28
+ "CrossAttnDownBlock2D",
29
+ "CrossAttnDownBlock2D",
30
+ "DownBlock2D"
31
+ ],
32
+ "downsample_padding": 1,
33
+ "encoder_hid_dim": null,
34
+ "encoder_hid_dim_type": null,
35
+ "flip_sin_to_cos": true,
36
+ "freq_shift": 0,
37
+ "global_pool_conditions": false,
38
+ "in_channels": 4,
39
+ "layers_per_block": 2,
40
+ "mid_block_scale_factor": 1,
41
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
42
+ "norm_eps": 1e-05,
43
+ "norm_num_groups": 32,
44
+ "num_attention_heads": null,
45
+ "num_class_embeds": null,
46
+ "only_cross_attention": false,
47
+ "projection_class_embeddings_input_dim": null,
48
+ "resnet_time_scale_shift": "default",
49
+ "transformer_layers_per_block": 1,
50
+ "upcast_attention": false,
51
+ "use_linear_projection": false
52
+ }
checkpoints/canny/controlnet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a3fd425077e65024addc5cf73c97195fcfd499b7a5e16868e4251b47cebb0d89
3
+ size 1445157120
checkpoints/depth/controlnet/config.json ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "ControlNetModel",
3
+ "_diffusers_version": "0.26.3",
4
+ "_name_or_path": "work_dirs/finetune/MultiGen20M_depth/ft_controlnet_sd15_depth_res512_bs256_lr1e-5_warmup100_iter5k_fp16ft0-200/checkpoint-5000",
5
+ "act_fn": "silu",
6
+ "addition_embed_type": null,
7
+ "addition_embed_type_num_heads": 64,
8
+ "addition_time_embed_dim": null,
9
+ "attention_head_dim": 8,
10
+ "block_out_channels": [
11
+ 320,
12
+ 640,
13
+ 1280,
14
+ 1280
15
+ ],
16
+ "class_embed_type": null,
17
+ "conditioning_channels": 3,
18
+ "conditioning_embedding_out_channels": [
19
+ 16,
20
+ 32,
21
+ 96,
22
+ 256
23
+ ],
24
+ "controlnet_conditioning_channel_order": "rgb",
25
+ "cross_attention_dim": 768,
26
+ "down_block_types": [
27
+ "CrossAttnDownBlock2D",
28
+ "CrossAttnDownBlock2D",
29
+ "CrossAttnDownBlock2D",
30
+ "DownBlock2D"
31
+ ],
32
+ "downsample_padding": 1,
33
+ "encoder_hid_dim": null,
34
+ "encoder_hid_dim_type": null,
35
+ "flip_sin_to_cos": true,
36
+ "freq_shift": 0,
37
+ "global_pool_conditions": false,
38
+ "in_channels": 4,
39
+ "layers_per_block": 2,
40
+ "mid_block_scale_factor": 1,
41
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
42
+ "norm_eps": 1e-05,
43
+ "norm_num_groups": 32,
44
+ "num_attention_heads": null,
45
+ "num_class_embeds": null,
46
+ "only_cross_attention": false,
47
+ "projection_class_embeddings_input_dim": null,
48
+ "resnet_time_scale_shift": "default",
49
+ "transformer_layers_per_block": 1,
50
+ "upcast_attention": false,
51
+ "use_linear_projection": false
52
+ }
checkpoints/depth/controlnet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d7450404d13ef888c9701433a3c17b2a86c021a6d042f9f5d2519602abd7f2f3
3
+ size 1445157120
checkpoints/hed/controlnet/config.json ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "ControlNetModel",
3
+ "_diffusers_version": "0.19.3",
4
+ "_name_or_path": "work_dirs/reward_model/MultiGen20M_Hed/reward_ft5k_controlnet_sd15_hed_res512_bs256_lr1e-5_warmup100_scale-1_iter5k_fp16_train0-1k_reward0-200/checkpoint-5000",
5
+ "act_fn": "silu",
6
+ "addition_embed_type": null,
7
+ "addition_embed_type_num_heads": 64,
8
+ "addition_time_embed_dim": null,
9
+ "attention_head_dim": 8,
10
+ "block_out_channels": [
11
+ 320,
12
+ 640,
13
+ 1280,
14
+ 1280
15
+ ],
16
+ "class_embed_type": null,
17
+ "conditioning_channels": 3,
18
+ "conditioning_embedding_out_channels": [
19
+ 16,
20
+ 32,
21
+ 96,
22
+ 256
23
+ ],
24
+ "controlnet_conditioning_channel_order": "rgb",
25
+ "cross_attention_dim": 768,
26
+ "down_block_types": [
27
+ "CrossAttnDownBlock2D",
28
+ "CrossAttnDownBlock2D",
29
+ "CrossAttnDownBlock2D",
30
+ "DownBlock2D"
31
+ ],
32
+ "downsample_padding": 1,
33
+ "encoder_hid_dim": null,
34
+ "encoder_hid_dim_type": null,
35
+ "flip_sin_to_cos": true,
36
+ "freq_shift": 0,
37
+ "global_pool_conditions": false,
38
+ "in_channels": 4,
39
+ "layers_per_block": 2,
40
+ "mid_block_scale_factor": 1,
41
+ "norm_eps": 1e-05,
42
+ "norm_num_groups": 32,
43
+ "num_attention_heads": null,
44
+ "num_class_embeds": null,
45
+ "only_cross_attention": false,
46
+ "projection_class_embeddings_input_dim": null,
47
+ "resnet_time_scale_shift": "default",
48
+ "transformer_layers_per_block": 1,
49
+ "upcast_attention": false,
50
+ "use_linear_projection": false
51
+ }
checkpoints/hed/controlnet/diffusion_pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:721a7b3ec6b26bc73887f9f6d8a4fc175b01c785c0f986f3b7f15cd520cecf8e
3
+ size 1445260234
checkpoints/lineart/controlnet/config.json ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "ControlNetModel",
3
+ "_diffusers_version": "0.19.3",
4
+ "_name_or_path": "work_dirs/reward_model/MultiGen20M_LineDrawing/reward_ft5k_controlnet_sd15_lineart_res512_bs256_lr1e-5_warmup100_scale-10_iter5k_fp16_train0-1k_reward0-200/checkpoint-5000",
5
+ "act_fn": "silu",
6
+ "addition_embed_type": null,
7
+ "addition_embed_type_num_heads": 64,
8
+ "addition_time_embed_dim": null,
9
+ "attention_head_dim": 8,
10
+ "block_out_channels": [
11
+ 320,
12
+ 640,
13
+ 1280,
14
+ 1280
15
+ ],
16
+ "class_embed_type": null,
17
+ "conditioning_channels": 3,
18
+ "conditioning_embedding_out_channels": [
19
+ 16,
20
+ 32,
21
+ 96,
22
+ 256
23
+ ],
24
+ "controlnet_conditioning_channel_order": "rgb",
25
+ "cross_attention_dim": 768,
26
+ "down_block_types": [
27
+ "CrossAttnDownBlock2D",
28
+ "CrossAttnDownBlock2D",
29
+ "CrossAttnDownBlock2D",
30
+ "DownBlock2D"
31
+ ],
32
+ "downsample_padding": 1,
33
+ "encoder_hid_dim": null,
34
+ "encoder_hid_dim_type": null,
35
+ "flip_sin_to_cos": true,
36
+ "freq_shift": 0,
37
+ "global_pool_conditions": false,
38
+ "in_channels": 4,
39
+ "layers_per_block": 2,
40
+ "mid_block_scale_factor": 1,
41
+ "norm_eps": 1e-05,
42
+ "norm_num_groups": 32,
43
+ "num_attention_heads": null,
44
+ "num_class_embeds": null,
45
+ "only_cross_attention": false,
46
+ "projection_class_embeddings_input_dim": null,
47
+ "resnet_time_scale_shift": "default",
48
+ "transformer_layers_per_block": 1,
49
+ "upcast_attention": false,
50
+ "use_linear_projection": false
51
+ }
checkpoints/lineart/controlnet/diffusion_pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3eab52fe2f7a3e2ad7841eeb7ce2d78271869ced226681ee83d83b8fa22a163a
3
+ size 1445260234
checkpoints/seg/FCN_controlnet/config.json ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "ControlNetModel",
3
+ "_diffusers_version": "0.26.3",
4
+ "_name_or_path": "work_dirs/finetune/Captioned_ADE20K/ft_controlnet_sd15_seg_res512_bs256_lr1e-5_warmup100_iter5k_fp16/checkpoint-5000/controlnet",
5
+ "act_fn": "silu",
6
+ "addition_embed_type": null,
7
+ "addition_embed_type_num_heads": 64,
8
+ "addition_time_embed_dim": null,
9
+ "attention_head_dim": 8,
10
+ "block_out_channels": [
11
+ 320,
12
+ 640,
13
+ 1280,
14
+ 1280
15
+ ],
16
+ "class_embed_type": null,
17
+ "conditioning_channels": 3,
18
+ "conditioning_embedding_out_channels": [
19
+ 16,
20
+ 32,
21
+ 96,
22
+ 256
23
+ ],
24
+ "controlnet_conditioning_channel_order": "rgb",
25
+ "cross_attention_dim": 768,
26
+ "down_block_types": [
27
+ "CrossAttnDownBlock2D",
28
+ "CrossAttnDownBlock2D",
29
+ "CrossAttnDownBlock2D",
30
+ "DownBlock2D"
31
+ ],
32
+ "downsample_padding": 1,
33
+ "encoder_hid_dim": null,
34
+ "encoder_hid_dim_type": null,
35
+ "flip_sin_to_cos": true,
36
+ "freq_shift": 0,
37
+ "global_pool_conditions": false,
38
+ "in_channels": 4,
39
+ "layers_per_block": 2,
40
+ "mid_block_scale_factor": 1,
41
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
42
+ "norm_eps": 1e-05,
43
+ "norm_num_groups": 32,
44
+ "num_attention_heads": null,
45
+ "num_class_embeds": null,
46
+ "only_cross_attention": false,
47
+ "projection_class_embeddings_input_dim": null,
48
+ "resnet_time_scale_shift": "default",
49
+ "transformer_layers_per_block": 1,
50
+ "upcast_attention": false,
51
+ "use_linear_projection": false
52
+ }
checkpoints/seg/FCN_controlnet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8c38bf06cd30bf31b4458cea39c6488a8f95c5ea7b9b5503c368aa0fef81a4e8
3
+ size 1445157120
checkpoints/seg/controlnet/config.json ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "ControlNetModel",
3
+ "_diffusers_version": "0.19.3",
4
+ "_name_or_path": "work_dirs/finetune/Captioned_ADE20K/ft_controlnet_sd15_seg_res512_bs256_lr1e-5_warmup100_iter5k_fp16/checkpoint-5000/controlnet",
5
+ "act_fn": "silu",
6
+ "addition_embed_type": null,
7
+ "addition_embed_type_num_heads": 64,
8
+ "addition_time_embed_dim": null,
9
+ "attention_head_dim": 8,
10
+ "block_out_channels": [
11
+ 320,
12
+ 640,
13
+ 1280,
14
+ 1280
15
+ ],
16
+ "class_embed_type": null,
17
+ "conditioning_channels": 3,
18
+ "conditioning_embedding_out_channels": [
19
+ 16,
20
+ 32,
21
+ 96,
22
+ 256
23
+ ],
24
+ "controlnet_conditioning_channel_order": "rgb",
25
+ "cross_attention_dim": 768,
26
+ "down_block_types": [
27
+ "CrossAttnDownBlock2D",
28
+ "CrossAttnDownBlock2D",
29
+ "CrossAttnDownBlock2D",
30
+ "DownBlock2D"
31
+ ],
32
+ "downsample_padding": 1,
33
+ "encoder_hid_dim": null,
34
+ "encoder_hid_dim_type": null,
35
+ "flip_sin_to_cos": true,
36
+ "freq_shift": 0,
37
+ "global_pool_conditions": false,
38
+ "in_channels": 4,
39
+ "layers_per_block": 2,
40
+ "mid_block_scale_factor": 1,
41
+ "norm_eps": 1e-05,
42
+ "norm_num_groups": 32,
43
+ "num_attention_heads": null,
44
+ "num_class_embeds": null,
45
+ "only_cross_attention": false,
46
+ "projection_class_embeddings_input_dim": null,
47
+ "resnet_time_scale_shift": "default",
48
+ "transformer_layers_per_block": 1,
49
+ "upcast_attention": false,
50
+ "use_linear_projection": false
51
+ }
checkpoints/seg/controlnet/diffusion_pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:11da8bb3cc8f400d097e49136741085af2ac87dc3508edd3fb7e15d99d963d96
3
+ size 1445260234
images/canny_demo.jpg ADDED
model.py CHANGED
@@ -18,15 +18,15 @@ from preprocessor import Preprocessor
18
  from settings import MAX_IMAGE_RESOLUTION, MAX_NUM_IMAGES
19
 
20
  CONTROLNET_MODEL_IDS = {
21
- "Canny": "../diffusers/work_dirs/reward_model/MultiGen20M_Canny/reward_ft5k_canny_res512_bs256_lr1e-5_warmup100_scale-10_iter10k_fp16_train0-1k_reward0-200_denormalized-img_gradients-with-threshold0.05-mse-loss/checkpoint-10000/controlnet",
22
 
23
- "softedge": "../diffusers/work_dirs/reward_model/MultiGen20M_Hed/reward_ft5k_controlnet_sd15_hed_res512_bs256_lr1e-5_warmup100_scale-1_iter10k_fp16_train0-1k_reward0-200/checkpoint-10000/controlnet",
24
 
25
- "segmentation": "../diffusers/work_dirs/reward_model/Captioned_ADE20K/reward_ft_controlnet_sd15_seg_res512_bs256_lr1e-5_warmup100_scale-0.5_iter5k_fp16_train0-1k_reward0-200_FCN-R101-d8/checkpoint-5000/controlnet",
26
 
27
- "depth": "../diffusers/work_dirs/reward_model/MultiGen20M_Depth/reward_ft5k_controlnet_sd15_depth_res512_bs256_lr1e-5_warmup100_scale-1.0_iter10k_fp16_train0-1k_reward0-200_mse-loss/checkpoint-10000/controlnet",
28
 
29
- "lineart": "../diffusers/work_dirs/reward_model/MultiGen20M_LineDrawing/reward_ft5k_controlnet_sd15_lineart_res512_bs256_lr1e-5_warmup100_scale-10_iter10k_fp16_train0-1k_reward0-200/checkpoint-10000/controlnet",
30
  }
31
 
32
 
 
18
  from settings import MAX_IMAGE_RESOLUTION, MAX_NUM_IMAGES
19
 
20
  CONTROLNET_MODEL_IDS = {
21
+ "Canny": "checkpoints/canny/controlnet",
22
 
23
+ "softedge": "checkpoints/hed/controlnet",
24
 
25
+ "segmentation": "checkpoints/seg/controlnet",
26
 
27
+ "depth": "checkpoints/depth/controlnet",
28
 
29
+ "lineart": "checkpoints/lineart/controlnet",
30
  }
31
 
32
 
preprocessor.py CHANGED
@@ -3,6 +3,7 @@ import gc
3
  import numpy as np
4
  import PIL.Image
5
  import torch
 
6
  from controlnet_aux import (
7
  CannyDetector,
8
  ContentShuffleDetector,
@@ -21,6 +22,32 @@ from cv_utils import resize_image
21
  from depth_estimator import DepthEstimator
22
  from image_segmentor import ImageSegmentor
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  class Preprocessor:
26
  MODEL_ID = "lllyasviel/Annotators"
@@ -49,7 +76,7 @@ class Preprocessor:
49
  elif name == "LineartAnime":
50
  self.model = LineartAnimeDetector.from_pretrained(self.MODEL_ID)
51
  elif name == "Canny":
52
- self.model = CannyDetector()
53
  elif name == "ContentShuffle":
54
  self.model = ContentShuffleDetector()
55
  elif name == "DPT":
@@ -70,7 +97,7 @@ class Preprocessor:
70
  image = HWC3(image)
71
  image = resize_image(image, resolution=detect_resolution)
72
  image = self.model(image, **kwargs)
73
- return PIL.Image.fromarray(image)
74
  elif self.name == "Midas":
75
  detect_resolution = kwargs.pop("detect_resolution", 512)
76
  image_resolution = kwargs.pop("image_resolution", 512)
 
3
  import numpy as np
4
  import PIL.Image
5
  import torch
6
+ import torchvision
7
  from controlnet_aux import (
8
  CannyDetector,
9
  ContentShuffleDetector,
 
22
  from depth_estimator import DepthEstimator
23
  from image_segmentor import ImageSegmentor
24
 
25
+ from kornia.core import Tensor
26
+ from kornia.filters import canny
27
+
28
+
29
+ class Canny:
30
+
31
+ def __call__(
32
+ self,
33
+ images: np.array,
34
+ low_threshold: float = 0.1,
35
+ high_threshold: float = 0.2,
36
+ kernel_size: tuple[int, int] | int = (5, 5),
37
+ sigma: tuple[float, float] | Tensor = (1, 1),
38
+ hysteresis: bool = True,
39
+ eps: float = 1e-6
40
+ ) -> torch.Tensor:
41
+
42
+ assert low_threshold is not None, "low_threshold must be provided"
43
+ assert high_threshold is not None, "high_threshold must be provided"
44
+
45
+ images = torch.from_numpy(images).permute(2, 0, 1).unsqueeze(0) / 255.0
46
+
47
+ images_tensor = canny(images, low_threshold, high_threshold, kernel_size, sigma, hysteresis, eps)[1]
48
+ images_tensor = (images_tensor[0][0].numpy() * 255).astype(np.uint8)
49
+ return images_tensor
50
+
51
 
52
  class Preprocessor:
53
  MODEL_ID = "lllyasviel/Annotators"
 
76
  elif name == "LineartAnime":
77
  self.model = LineartAnimeDetector.from_pretrained(self.MODEL_ID)
78
  elif name == "Canny":
79
+ self.model = Canny()
80
  elif name == "ContentShuffle":
81
  self.model = ContentShuffleDetector()
82
  elif name == "DPT":
 
97
  image = HWC3(image)
98
  image = resize_image(image, resolution=detect_resolution)
99
  image = self.model(image, **kwargs)
100
+ return PIL.Image.fromarray(image).convert('RGB')
101
  elif self.name == "Midas":
102
  detect_resolution = kwargs.pop("detect_resolution", 512)
103
  image_resolution = kwargs.pop("image_resolution", 512)
settings.py CHANGED
@@ -7,7 +7,7 @@ DEFAULT_MODEL_ID = os.getenv("DEFAULT_MODEL_ID", "runwayml/stable-diffusion-v1-5
7
  MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "4"))
8
  DEFAULT_NUM_IMAGES = min(MAX_NUM_IMAGES, int(os.getenv("DEFAULT_NUM_IMAGES", "4")))
9
  MAX_IMAGE_RESOLUTION = int(os.getenv("MAX_IMAGE_RESOLUTION", "768"))
10
- DEFAULT_IMAGE_RESOLUTION = min(MAX_IMAGE_RESOLUTION, int(os.getenv("DEFAULT_IMAGE_RESOLUTION", "768")))
11
 
12
  ALLOW_CHANGING_BASE_MODEL = os.getenv("SPACE_ID") != "hysts/ControlNet-v1-1"
13
  SHOW_DUPLICATE_BUTTON = os.getenv("SHOW_DUPLICATE_BUTTON") == "1"
 
7
  MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "4"))
8
  DEFAULT_NUM_IMAGES = min(MAX_NUM_IMAGES, int(os.getenv("DEFAULT_NUM_IMAGES", "4")))
9
  MAX_IMAGE_RESOLUTION = int(os.getenv("MAX_IMAGE_RESOLUTION", "768"))
10
+ DEFAULT_IMAGE_RESOLUTION = min(MAX_IMAGE_RESOLUTION, int(os.getenv("DEFAULT_IMAGE_RESOLUTION", "512")))
11
 
12
  ALLOW_CHANGING_BASE_MODEL = os.getenv("SPACE_ID") != "hysts/ControlNet-v1-1"
13
  SHOW_DUPLICATE_BUTTON = os.getenv("SHOW_DUPLICATE_BUTTON") == "1"