tfwang committed
Commit 1066174
Parent: f057f04

Update app.py
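The change: the four hard-coded local checkpoint paths (./ckpt/base_edge.pt, ./ckpt/upsample_edge.pt, ./ckpt/base_mask.pt, ./ckpt/upsample_mask.pt) are replaced with huggingface_hub.hf_hub_download calls, so the demo fetches the PITI base and upsampler weights from the tfwang/PITI model repo instead of expecting them to be bundled with the Space; a matching import of hf_hub_download is added.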

Files changed (1)
  1. app.py +228 -227
app.py CHANGED
@@ -1,227 +1,228 @@
 """
 Train a diffusion model on images.
 """
 import gradio as gr
 import argparse
 from einops import rearrange
 from glide_text2im import dist_util, logger
 from torchvision.utils import make_grid
 from glide_text2im.script_util import (
     model_and_diffusion_defaults,
     create_model_and_diffusion,
     args_to_dict,
     add_dict_to_argparser,
 )
 from glide_text2im.image_datasets_sketch import get_tensor
 from glide_text2im.train_util import TrainLoop
 from glide_text2im.glide_util import sample
 import torch
 import os
 import torch as th
 import torchvision.utils as tvu
 import torch.distributed as dist
 from PIL import Image
 import cv2
 import numpy as np
+from huggingface_hub import hf_hub_download

 def run(image, mode, sample_c=1.3, num_samples=3, sample_step=100):
     parser, parser_up = create_argparser()

     args = parser.parse_args()
     args_up = parser_up.parse_args()
     dist_util.setup_dist()

     if mode == 'sketch':
         args.mode = 'coco-edge'
         args_up.mode = 'coco-edge'
-        args.model_path = './ckpt/base_edge.pt'
-        args.sr_model_path = './ckpt/upsample_edge.pt'
+        args.model_path = hf_hub_download(repo_id="tfwang/PITI", filename="base_edge.pt")
+        args.sr_model_path = hf_hub_download(repo_id="tfwang/PITI", filename="upsample_edge.pt")

     elif mode == 'mask':
         args.mode = 'coco'
         args_up.mode = 'coco'
-        args.model_path = './ckpt/base_mask.pt'
-        args.sr_model_path = './ckpt/upsample_mask.pt'
+        args.model_path = hf_hub_download(repo_id="tfwang/PITI", filename="base_mask.pt")
+        args.sr_model_path = hf_hub_download(repo_id="tfwang/PITI", filename="upsample_mask.pt")


     args.val_data_dir = image
     args.sample_c = sample_c
     args.num_samples = num_samples


     options=args_to_dict(args, model_and_diffusion_defaults(0.).keys())
     model, diffusion = create_model_and_diffusion(**options)

     options_up=args_to_dict(args_up, model_and_diffusion_defaults(True).keys())
     model_up, diffusion_up = create_model_and_diffusion(**options_up)


     if args.model_path:
         print('loading model')
         model_ckpt = dist_util.load_state_dict(args.model_path, map_location="cpu")

         model.load_state_dict(
             model_ckpt , strict=True )

     if args.sr_model_path:
         print('loading sr model')
         model_ckpt2 = dist_util.load_state_dict(args.sr_model_path, map_location="cpu")

         model_up.load_state_dict(
             model_ckpt2 , strict=True )


     model.to(dist_util.dev())
     model_up.to(dist_util.dev())
     model.eval()
     model_up.eval()

     ########### dataset
     # logger.log("creating data loader...")

     if args.mode == 'coco':
         pil_image = image
         label_pil = pil_image.convert("RGB").resize((256, 256), Image.NEAREST)
         label_tensor = get_tensor()(label_pil)

         data_dict = {"ref":label_tensor.unsqueeze(0).repeat(args.num_samples, 1, 1, 1)}

     elif args.mode == 'coco-edge':
         # pil_image = Image.open(image)
         pil_image = image
         label_pil = pil_image.convert("L").resize((256, 256), Image.NEAREST)

         im_dist = cv2.distanceTransform(255-np.array(label_pil), cv2.DIST_L1, 3)
         im_dist = np.clip((im_dist) , 0, 255).astype(np.uint8)
         im_dist = Image.fromarray(im_dist).convert("RGB")

         label_tensor = get_tensor()(im_dist)[:1]

         data_dict = {"ref":label_tensor.unsqueeze(0).repeat(args.num_samples, 1, 1, 1)}



     print("sampling...")


     sampled_imgs = []
     grid_imgs = []
     img_id = 0
     while (True):
         if img_id >= args.num_samples:
             break

         model_kwargs = data_dict
         with th.no_grad():
             samples_lr =sample(
                 glide_model= model,
                 glide_options= options,
                 side_x= 64,
                 side_y= 64,
                 prompt=model_kwargs,
                 batch_size= args.num_samples,
                 guidance_scale=args.sample_c,
                 device=dist_util.dev(),
                 prediction_respacing= str(sample_step),
                 upsample_enabled= False,
                 upsample_temp=0.997,
                 mode = args.mode,
             )

             samples_lr = samples_lr.clamp(-1, 1)

             tmp = (127.5*(samples_lr + 1.0)).int()
             model_kwargs['low_res'] = tmp/127.5 - 1.

             samples_hr =sample(
                 glide_model= model_up,
                 glide_options= options_up,
                 side_x=256,
                 side_y=256,
                 prompt=model_kwargs,
                 batch_size=args.num_samples,
                 guidance_scale=1,
                 device=dist_util.dev(),
                 prediction_respacing= "fast27",
                 upsample_enabled=True,
                 upsample_temp=0.997,
                 mode = args.mode,
             )


             samples_hr = samples_hr


             for hr in samples_hr:

                 hr = 255. * rearrange((hr.cpu().numpy()+1.0)*0.5, 'c h w -> h w c')
                 sample_img = Image.fromarray(hr.astype(np.uint8))
                 sampled_imgs.append(sample_img)
                 img_id += 1

             grid_imgs.append(samples_hr)

     grid = torch.stack(grid_imgs, 0)
     grid = rearrange(grid, 'n b c h w -> (n b) c h w')
     grid = make_grid(grid, nrow=2)
     # to image
     grid = 255. * rearrange((grid+1.0)*0.5, 'c h w -> h w c').cpu().numpy()

     return Image.fromarray(grid.astype(np.uint8))


 def create_argparser():
     defaults = dict(
         data_dir="",
         val_data_dir="",
         model_path="./base_edge.pt",
         sr_model_path="./upsample_edge.pt",
         encoder_path="",
         schedule_sampler="uniform",
         lr=1e-4,
         weight_decay=0.0,
         lr_anneal_steps=0,
         batch_size=2,
         microbatch=-1, # -1 disables microbatches
         ema_rate="0.9999", # comma-separated list of EMA values
         log_interval=100,
         save_interval=20000,
         resume_checkpoint="",
         use_fp16=False,
         fp16_scale_growth=1e-3,
         sample_c=1.,
         sample_respacing="100",
         uncond_p=0.2,
         num_samples=3,
         finetune_decoder = False,
         mode = '',
     )

     defaults_up = defaults
     defaults.update(model_and_diffusion_defaults())
     parser = argparse.ArgumentParser()
     add_dict_to_argparser(parser, defaults)

     defaults_up.update(model_and_diffusion_defaults(True))
     parser_up = argparse.ArgumentParser()
     add_dict_to_argparser(parser_up, defaults_up)

     return parser, parser_up

 image = gr.outputs.Image(type="pil", label="Sampled results")
 css = ".output-image{height: 528px !important} .output-carousel .output-image{height:272px !important} a{text-decoration: underline}"
 demo = gr.Interface(fn=run, inputs=[
     gr.inputs.Image(type="pil", label="Input Sketch" ) ,
     # gr.Image(image_mode="L", source="canvas", type="pil", shape=(256,256), invert_colors=False, tool="editor"),
     gr.inputs.Radio(label="Input Mode - The type of your input", choices=["mask", "sketch"],default="sketch"),
     gr.inputs.Slider(label="sample_c - The strength of classifier-free guidance",default=1.4, minimum=1.0, maximum=2.0),
     gr.inputs.Slider(label="Number of samples - How many samples you wish to generate", default=4, step=1, minimum=1, maximum=16),
     gr.inputs.Slider(label="Number of Steps - How many steps you want to use", default=100, step=10, minimum=50, maximum=1000),
     ],
     outputs=[image],
     css=css,
     title="Generate images from sketches with PITI",
     description="<div>By uploading a sketch map or a semantic map and pressing submit, you can generate images based on your input.</div>")

 demo.launch(enable_queue=True)

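For reference, a minimal sketch of what the new loading path does at runtime, assuming the Space has the huggingface_hub package installed (the repo_id and filename values are the ones used in the diff above):

from huggingface_hub import hf_hub_download

# The first call downloads base_edge.pt from the tfwang/PITI model repo into
# the local Hugging Face cache; subsequent calls return the cached path
# without re-downloading.
ckpt_path = hf_hub_download(repo_id="tfwang/PITI", filename="base_edge.pt")
print(ckpt_path)  # a path inside the cache, usable like any local file path

The returned path is then passed to dist_util.load_state_dict exactly as the old ./ckpt/*.pt paths were, so nothing downstream of the checkpoint loading changes.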