kottu commited on
Commit
b758fc3
β€’
1 Parent(s): 265bfb8

Create app_sketch.py

Browse files
Files changed (1) hide show
  1. app_sketch.py +165 -0
app_sketch.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import gradio as gr
3
+ import PIL.Image
4
+ import torch
5
+ import torchvision.transforms.functional as TF
6
+
7
+ from model import Model
8
+ from utils import (
9
+ DEFAULT_STYLE_NAME,
10
+ MAX_SEED,
11
+ STYLE_NAMES,
12
+ apply_style,
13
+ randomize_seed_fn,
14
+ )
15
+
16
+
17
+ def create_demo(model: Model) -> gr.Blocks:
18
+ def run(
19
+ image: PIL.Image.Image,
20
+ prompt: str,
21
+ negative_prompt: str,
22
+ style_name: str = DEFAULT_STYLE_NAME,
23
+ num_steps: int = 25,
24
+ guidance_scale: float = 5,
25
+ adapter_conditioning_scale: float = 0.8,
26
+ adapter_conditioning_factor: float = 0.8,
27
+ seed: int = 0,
28
+ progress=gr.Progress(track_tqdm=True),
29
+ ) -> PIL.Image.Image:
30
+ image = image.convert("RGB")
31
+ image = TF.to_tensor(image) > 0.5
32
+ image = TF.to_pil_image(image.to(torch.float32))
33
+
34
+ prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
35
+
36
+ return model.run(
37
+ image=image,
38
+ prompt=prompt,
39
+ negative_prompt=negative_prompt,
40
+ adapter_name="sketch",
41
+ num_inference_steps=num_steps,
42
+ guidance_scale=guidance_scale,
43
+ adapter_conditioning_scale=adapter_conditioning_scale,
44
+ adapter_conditioning_factor=adapter_conditioning_factor,
45
+ seed=seed,
46
+ apply_preprocess=False,
47
+ )[1]
48
+
49
+ with gr.Blocks() as demo:
50
+ with gr.Row():
51
+ with gr.Column():
52
+ with gr.Group():
53
+ image = gr.Image(
54
+ source="canvas",
55
+ tool="sketch",
56
+ type="pil",
57
+ image_mode="L",
58
+ invert_colors=True,
59
+ shape=(1024, 1024),
60
+ brush_radius=4,
61
+ height=600,
62
+ )
63
+ prompt = gr.Textbox(label="Prompt")
64
+ style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
65
+ run_button = gr.Button("Run")
66
+ with gr.Accordion("Advanced options", open=False):
67
+ negative_prompt = gr.Textbox(
68
+ label="Negative prompt",
69
+ value=" extra digit, fewer digits, cropped, worst quality, low quality, glitch, deformed, mutated, ugly, disfigured",
70
+ )
71
+ num_steps = gr.Slider(
72
+ label="Number of steps",
73
+ minimum=1,
74
+ maximum=50,
75
+ step=1,
76
+ value=25,
77
+ )
78
+ guidance_scale = gr.Slider(
79
+ label="Guidance scale",
80
+ minimum=0.1,
81
+ maximum=10.0,
82
+ step=0.1,
83
+ value=5,
84
+ )
85
+ adapter_conditioning_scale = gr.Slider(
86
+ label="Adapter conditioning scale",
87
+ minimum=0.5,
88
+ maximum=1,
89
+ step=0.1,
90
+ value=0.8,
91
+ )
92
+ adapter_conditioning_factor = gr.Slider(
93
+ label="Adapter conditioning factor",
94
+ info="Fraction of timesteps for which adapter should be applied",
95
+ minimum=0.5,
96
+ maximum=1,
97
+ step=0.1,
98
+ value=0.8,
99
+ )
100
+ seed = gr.Slider(
101
+ label="Seed",
102
+ minimum=0,
103
+ maximum=MAX_SEED,
104
+ step=1,
105
+ value=0,
106
+ )
107
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
108
+ with gr.Column():
109
+ result = gr.Image(label="Result", height=600)
110
+
111
+ inputs = [
112
+ image,
113
+ prompt,
114
+ negative_prompt,
115
+ style,
116
+ num_steps,
117
+ guidance_scale,
118
+ adapter_conditioning_scale,
119
+ adapter_conditioning_factor,
120
+ seed,
121
+ ]
122
+ prompt.submit(
123
+ fn=randomize_seed_fn,
124
+ inputs=[seed, randomize_seed],
125
+ outputs=seed,
126
+ queue=False,
127
+ api_name=False,
128
+ ).then(
129
+ fn=run,
130
+ inputs=inputs,
131
+ outputs=result,
132
+ api_name=False,
133
+ )
134
+ negative_prompt.submit(
135
+ fn=randomize_seed_fn,
136
+ inputs=[seed, randomize_seed],
137
+ outputs=seed,
138
+ queue=False,
139
+ api_name=False,
140
+ ).then(
141
+ fn=run,
142
+ inputs=inputs,
143
+ outputs=result,
144
+ api_name=False,
145
+ )
146
+ run_button.click(
147
+ fn=randomize_seed_fn,
148
+ inputs=[seed, randomize_seed],
149
+ outputs=seed,
150
+ queue=False,
151
+ api_name=False,
152
+ ).then(
153
+ fn=run,
154
+ inputs=inputs,
155
+ outputs=result,
156
+ api_name=False,
157
+ )
158
+
159
+ return demo
160
+
161
+
162
+ if __name__ == "__main__":
163
+ model = Model("sketch")
164
+ demo = create_demo(model)
165
+ demo.queue(max_size=20).launch()