jennysun committed on
Commit b96cf5d · 1 Parent(s): 5037c92

Convert from local app to an app suitable for a Hugging Face Space

Files changed (1)
  1. app.py +137 -135
app.py CHANGED
@@ -1,98 +1,100 @@
-import gradio as gr
-import torch
-import argparse
-from omegaconf import OmegaConf
-from gligen.task_grounded_generation import grounded_generation_box, load_ckpt
-from ldm.util import default_device
-
-import json
-import numpy as np
-from PIL import Image, ImageDraw, ImageFont
-from functools import partial
-import math
-from contextlib import nullcontext
-
-from gradio import processing_utils
-from typing import Optional
-
-from huggingface_hub import hf_hub_download
-hf_hub_download = partial(hf_hub_download, library_name="gligen_demo")
-
-import openai
-from gradio.components import Textbox, Text
-import os
-
-arg_bool = lambda x: x.lower() == 'true'
-device = default_device()
-
-print(f"GLIGEN uses {device.upper()} device.")
-if device == "cpu":
-    print("It will be sloooow. Consider using GPU support with CUDA or (in case of M1/M2 Apple Silicon) MPS.")
-elif device == "mps":
-    print("The fastest you can get on M1/2 Apple Silicon. Yet, still many opimizations are switched off and it will is much slower than CUDA.")
-
-def parse_option():
-    parser = argparse.ArgumentParser('GLIGen Demo', add_help=False)
-    parser.add_argument("--folder", type=str, default="create_samples", help="path to OUTPUT")
-    parser.add_argument("--official_ckpt", type=str, default='ckpts/sd-v1-4.ckpt', help="")
-    parser.add_argument("--guidance_scale", type=float, default=5, help="")
-    parser.add_argument("--alpha_scale", type=float, default=1, help="scale tanh(alpha). If 0, the behaviour is same as original model")
-    parser.add_argument("--load-text-box-generation", type=arg_bool, default=True, help="Load text-box generation pipeline.")
-    parser.add_argument("--load-text-box-inpainting", type=arg_bool, default=False, help="Load text-box inpainting pipeline.")
-    parser.add_argument("--load-text-image-box-generation", type=arg_bool, default=False, help="Load text-image-box generation pipeline.")
-    args = parser.parse_args()
-    return args
-args = parse_option()
-
-
-def load_from_hf(repo_id, filename='diffusion_pytorch_model.bin'):
-    cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
-    return torch.load(cache_file, map_location='cpu')
-
-def load_ckpt_config_from_hf(modality):
-    ckpt = load_from_hf(f'gligen/{modality}')
-    config = load_from_hf('gligen/demo_config_legacy', filename=f'{modality}.pth')
-    return ckpt, config
-
-
-if args.load_text_box_generation:
-    pretrained_ckpt_gligen, config = load_ckpt_config_from_hf('gligen-generation-text-box')
-    config = OmegaConf.create( config["_content"] ) # config used in training
-    config.update( vars(args) )
-    config.model['params']['is_inpaint'] = False
-    config.model['params']['is_style'] = False
-    loaded_model_list = load_ckpt(config, pretrained_ckpt_gligen)
-
-
-if args.load_text_box_inpainting:
-    pretrained_ckpt_gligen_inpaint, config = load_ckpt_config_from_hf('gligen-inpainting-text-box')
-    config = OmegaConf.create( config["_content"] ) # config used in training
-    config.update( vars(args) )
-    config.model['params']['is_inpaint'] = True
-    config.model['params']['is_style'] = False
-    loaded_model_list_inpaint = load_ckpt(config, pretrained_ckpt_gligen_inpaint)
-
-
-if args.load_text_image_box_generation:
-    pretrained_ckpt_gligen_style, config = load_ckpt_config_from_hf('gligen-generation-text-image-box')
-    config = OmegaConf.create( config["_content"] ) # config used in training
-    config.update( vars(args) )
-    config.model['params']['is_inpaint'] = False
-    config.model['params']['is_style'] = True
-    loaded_model_list_style = load_ckpt(config, pretrained_ckpt_gligen_style)
-
-
-def load_clip_model():
-    from transformers import CLIPProcessor, CLIPModel
-    version = "openai/clip-vit-large-patch14"
-    model = CLIPModel.from_pretrained(version).to(device)
-    processor = CLIPProcessor.from_pretrained(version)
-
-    return {
-        'version': version,
-        'model': model,
-        'processor': processor,
-    }
+import gradio as gr
+import torch
+from omegaconf import OmegaConf
+from gligen.task_grounded_generation import grounded_generation_box, load_ckpt, load_common_ckpt
+import json
+import numpy as np
+from PIL import Image, ImageDraw, ImageFont
+from functools import partial
+from collections import Counter
+import math
+import gc
+from gradio import processing_utils
+from typing import Optional
+import warnings
+from datetime import datetime
+from huggingface_hub import hf_hub_download
+hf_hub_download = partial(hf_hub_download, library_name="gligen_demo")
+import sys
+sys.tracebacklimit = 0
+def load_from_hf(repo_id, filename='diffusion_pytorch_model.bin', subfolder=None):
+    cache_file = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder)
+    return torch.load(cache_file, map_location='cpu')
+def load_ckpt_config_from_hf(modality):
+    ckpt = load_from_hf('gligen/demo_ckpts_legacy', filename=f'{modality}.pth', subfolder='model')
+    config = load_from_hf('gligen/demo_ckpts_legacy', filename=f'{modality}.pth', subfolder='config')
+    return ckpt, config
+def ckpt_load_helper(modality, is_inpaint, is_style, common_instances=None):
+    pretrained_ckpt_gligen, config = load_ckpt_config_from_hf(modality)
+    config = OmegaConf.create( config["_content"] ) # config used in training
+    config.alpha_scale = 1.0
+    config.model['params']['is_inpaint'] = is_inpaint
+    config.model['params']['is_style'] = is_style
+    if common_instances is None:
+        common_ckpt = load_from_hf('gligen/demo_ckpts_legacy', filename=f'common.pth', subfolder='model')
+        common_instances = load_common_ckpt(config, common_ckpt)
+    loaded_model_list = load_ckpt(config, pretrained_ckpt_gligen, common_instances)
+    return loaded_model_list, common_instances
+class Instance:
+    def __init__(self, capacity = 2):
+        self.model_type = 'base'
+        self.loaded_model_list = {}
+        self.counter = Counter()
+        self.global_counter = Counter()
+        self.loaded_model_list['base'], self.common_instances = ckpt_load_helper(
+            'gligen-generation-text-box',
+            is_inpaint=False, is_style=False, common_instances=None
+        )
+        self.capacity = capacity
+    def _log(self, model_type, batch_size, instruction, phrase_list):
+        self.counter[model_type] += 1
+        self.global_counter[model_type] += 1
+        current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+        print('[{}] Current: {}, All: {}. Samples: {}, prompt: {}, phrases: {}'.format(
+            current_time, dict(self.counter), dict(self.global_counter), batch_size, instruction, phrase_list
+        ))
+    def get_model(self, model_type, batch_size, instruction, phrase_list):
+        if model_type in self.loaded_model_list:
+            self._log(model_type, batch_size, instruction, phrase_list)
+            return self.loaded_model_list[model_type]
+        if self.capacity == len(self.loaded_model_list):
+            least_used_type = self.counter.most_common()[-1][0]
+            del self.loaded_model_list[least_used_type]
+            del self.counter[least_used_type]
+            gc.collect()
+            torch.cuda.empty_cache()
+        self.loaded_model_list[model_type] = self._get_model(model_type)
+        self._log(model_type, batch_size, instruction, phrase_list)
+        return self.loaded_model_list[model_type]
+    def _get_model(self, model_type):
+        if model_type == 'base':
+            return ckpt_load_helper(
+                'gligen-generation-text-box',
+                is_inpaint=False, is_style=False, common_instances=self.common_instances
+            )[0]
+        elif model_type == 'inpaint':
+            return ckpt_load_helper(
+                'gligen-inpainting-text-box',
+                is_inpaint=True, is_style=False, common_instances=self.common_instances
+            )[0]
+        elif model_type == 'style':
+            return ckpt_load_helper(
+                'gligen-generation-text-image-box',
+                is_inpaint=False, is_style=True, common_instances=self.common_instances
+            )[0]
+
+        assert False
+instance = Instance()
+def load_clip_model():
+    from transformers import CLIPProcessor, CLIPModel
+    version = "openai/clip-vit-large-patch14"
+    model = CLIPModel.from_pretrained(version).cuda()
+    processor = CLIPProcessor.from_pretrained(version)
+    return {
+        'version': version,
+        'model': model,
+        'processor': processor,
+    }
 
 clip_model = load_clip_model()
 
@@ -139,7 +141,7 @@ class Blocks(gr.Blocks):
         }
 
         super(Blocks, self).__init__(theme, analytics_enabled, mode, title, css, **kwargs)
-
+        warnings.filterwarnings("ignore")
     def get_config_file(self):
         config = super(Blocks, self).get_config_file()
 
@@ -204,17 +206,20 @@ def inference(task, language_instruction, grounding_instruction, inpainting_boxe
         inpainting_boxes_nodrop = inpainting_boxes_nodrop,
     )
 
-    # float16 autocasting only CUDA device
-    with torch.autocast(device_type='cuda', dtype=torch.float16) if device == "cuda" else nullcontext():
-        if task == 'Grounded Generation':
-            if style_image == None:
-                return grounded_generation_box(loaded_model_list, instruction, *args, **kwargs)
-            else:
-                return grounded_generation_box(loaded_model_list_style, instruction, *args, **kwargs)
-        elif task == 'Grounded Inpainting':
-            assert image is not None
-            instruction['input_image'] = image.convert("RGB")
-            return grounded_generation_box(loaded_model_list_inpaint, instruction, *args, **kwargs)
+    get_model = partial(instance.get_model,
+                        batch_size=batch_size,
+                        instruction=language_instruction,
+                        phrase_list=phrase_list)
+    with torch.autocast(device_type='cuda', dtype=torch.float16):
+        if task == 'Grounded Generation':
+            if style_image == None:
+                return grounded_generation_box(get_model('base'), instruction, *args, **kwargs)
+            else:
+                return grounded_generation_box(get_model('style'), instruction, *args, **kwargs)
+        elif task == 'Grounded Inpainting':
+            assert image is not None
+            instruction['input_image'] = image.convert("RGB")
+            return grounded_generation_box(get_model('inpaint'), instruction, *args, **kwargs)
 
 
 def draw_box(boxes=[], texts=[], img=None):
@@ -264,6 +269,14 @@ def generate(task, language_instruction, grounding_texts, sketch_pad,
     boxes = state['boxes']
     grounding_texts = [x.strip() for x in grounding_texts.split(';')]
     assert len(boxes) == len(grounding_texts)
+    if len(boxes) != len(grounding_texts):
+        if len(boxes) < len(grounding_texts):
+            raise ValueError("""The number of boxes should be equal to the number of grounding objects.
+Number of boxes drawn: {}, number of grounding tokens: {}.
+Please draw boxes accordingly on the sketch pad.""".format(len(boxes), len(grounding_texts)))
+        grounding_texts = grounding_texts + [""] * (len(boxes) - len(grounding_texts))
+
+
     boxes = (np.asarray(boxes) / 512).tolist()
     grounding_instruction = json.dumps({obj: box for obj,box in zip(grounding_texts, boxes)})
 
@@ -446,32 +459,21 @@ def clear(task, sketch_pad_trigger, batch_size, state, switch_task=False):
     state = {}
     return [None, sketch_pad_trigger, None, 1.0] + out_images + [state]
 
-css = """
-#generate-btn {
-    --tw-border-opacity: 1;
-    border-color: rgb(255 216 180 / var(--tw-border-opacity));
-    --tw-gradient-from: rgb(255 216 180 / .7);
-    --tw-gradient-to: rgb(255 216 180 / 0);
-    --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to);
-    --tw-gradient-to: rgb(255 176 102 / .8);
-    --tw-text-opacity: 1;
-    color: rgb(238 116 0 / var(--tw-text-opacity));
-}
-#img2img_image, #img2img_image > .h-60, #img2img_image > .h-60 > div, #img2img_image > .h-60 > div > img
-{
-    height: var(--height) !important;
-    max-height: var(--height) !important;
-    min-height: var(--height) !important;
-}
-#mirrors a:hover {
-    cursor:pointer;
-}
-#paper-info a {
-    color:#008AD7;
-}
-#paper-info a:hover {
-    cursor: pointer;
-}
+css = """
+#img2img_image, #img2img_image > .fixed-height, #img2img_image > .fixed-height > div, #img2img_image > .fixed-height > div > img
+{
+    height: var(--height) !important;
+    max-height: var(--height) !important;
+    min-height: var(--height) !important;
+}
+#paper-info a {
+    color:#008AD7;
+    text-decoration: none;
+}
+#paper-info a:hover {
+    cursor: pointer;
+    text-decoration: none;
+}
 """
 
 rescale_js = """
@@ -660,7 +662,7 @@ with Blocks(
                 use_style_cond = gr.Checkbox(value=False, label="Enable Style Condition", visible=False)
                 style_cond_image = gr.Image(type="pil", label="Style Condition", interactive=True, visible=False)
         with gr.Column(scale=4):
-            gr.Markdown("### Generated Images")
+            gr.HTML('<span style="font-size: 20px; font-weight: bold">Generated Images</span>')
             with gr.Row():
                 out_gen_1 = gr.Image(type="pil", visible=True, show_label=False)
                 out_gen_2 = gr.Image(type="pil", visible=True, show_label=False)
@@ -800,4 +802,4 @@ with Blocks(
         queue=False)
 
    main.queue(concurrency_count=1, api_open=False)
-    main.launch(share=False, show_api=False)
+    main.launch(share=False, show_api=False, show_error=True)
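
The new `Instance` class above keeps at most `capacity` GLIGEN pipelines resident and evicts the least-used one (via `Counter.most_common()[-1]`) before loading another. A minimal, self-contained sketch of that caching pattern, with a hypothetical `load_pipeline` standing in for `ckpt_load_helper(...)[0]` (names and return values are illustrative, not part of the commit):

```python
from collections import Counter

def load_pipeline(model_type):
    # Hypothetical stand-in for the real checkpoint loader; returns a dummy object.
    return f"pipeline<{model_type}>"

class ModelCache:
    def __init__(self, capacity=2):
        self.capacity = capacity
        self.loaded = {}          # model_type -> loaded pipeline
        self.counter = Counter()  # usage counts since each pipeline was (re)loaded

    def get(self, model_type):
        if model_type in self.loaded:
            self.counter[model_type] += 1
            return self.loaded[model_type]
        if len(self.loaded) == self.capacity:
            # Evict the least-used entry, mirroring Counter.most_common()[-1][0].
            least_used = self.counter.most_common()[-1][0]
            del self.loaded[least_used]
            del self.counter[least_used]
        self.loaded[model_type] = load_pipeline(model_type)
        self.counter[model_type] += 1
        return self.loaded[model_type]

cache = ModelCache(capacity=2)
print(cache.get('base'))     # loads 'base'
print(cache.get('inpaint'))  # loads 'inpaint'
print(cache.get('style'))    # evicts the least-used of the two, then loads 'style'
```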