Convert from local app to an app suitable for a Hugging Face Space
app.py CHANGED
@@ -1,98 +1,100 @@
-import gradio as gr
-import torch
+import gradio as gr
+import torch
+from omegaconf import OmegaConf
+from gligen.task_grounded_generation import grounded_generation_box, load_ckpt, load_common_ckpt
+import json
+import numpy as np
+from PIL import Image, ImageDraw, ImageFont
+from functools import partial
+from collections import Counter
+import math
+import gc
+from gradio import processing_utils
+from typing import Optional
+import warnings
+from datetime import datetime
+from huggingface_hub import hf_hub_download
+hf_hub_download = partial(hf_hub_download, library_name="gligen_demo")
+import sys
+sys.tracebacklimit = 0
+def load_from_hf(repo_id, filename='diffusion_pytorch_model.bin', subfolder=None):
+    cache_file = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder)
+    return torch.load(cache_file, map_location='cpu')
+def load_ckpt_config_from_hf(modality):
+    ckpt = load_from_hf('gligen/demo_ckpts_legacy', filename=f'{modality}.pth', subfolder='model')
+    config = load_from_hf('gligen/demo_ckpts_legacy', filename=f'{modality}.pth', subfolder='config')
+    return ckpt, config
+def ckpt_load_helper(modality, is_inpaint, is_style, common_instances=None):
+    pretrained_ckpt_gligen, config = load_ckpt_config_from_hf(modality)
+    config = OmegaConf.create( config["_content"] ) # config used in training
+    config.alpha_scale = 1.0
+    config.model['params']['is_inpaint'] = is_inpaint
+    config.model['params']['is_style'] = is_style
+    if common_instances is None:
+        common_ckpt = load_from_hf('gligen/demo_ckpts_legacy', filename=f'common.pth', subfolder='model')
+        common_instances = load_common_ckpt(config, common_ckpt)
+    loaded_model_list = load_ckpt(config, pretrained_ckpt_gligen, common_instances)
+    return loaded_model_list, common_instances
+class Instance:
+    def __init__(self, capacity = 2):
+        self.model_type = 'base'
+        self.loaded_model_list = {}
+        self.counter = Counter()
+        self.global_counter = Counter()
+        self.loaded_model_list['base'], self.common_instances = ckpt_load_helper(
+            'gligen-generation-text-box',
+            is_inpaint=False, is_style=False, common_instances=None
+        )
+        self.capacity = capacity
+    def _log(self, model_type, batch_size, instruction, phrase_list):
+        self.counter[model_type] += 1
+        self.global_counter[model_type] += 1
+        current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+        print('[{}] Current: {}, All: {}. Samples: {}, prompt: {}, phrases: {}'.format(
+            current_time, dict(self.counter), dict(self.global_counter), batch_size, instruction, phrase_list
+        ))
+    def get_model(self, model_type, batch_size, instruction, phrase_list):
+        if model_type in self.loaded_model_list:
+            self._log(model_type, batch_size, instruction, phrase_list)
+            return self.loaded_model_list[model_type]
+        if self.capacity == len(self.loaded_model_list):
+            least_used_type = self.counter.most_common()[-1][0]
+            del self.loaded_model_list[least_used_type]
+            del self.counter[least_used_type]
+            gc.collect()
+            torch.cuda.empty_cache()
+        self.loaded_model_list[model_type] = self._get_model(model_type)
+        self._log(model_type, batch_size, instruction, phrase_list)
+        return self.loaded_model_list[model_type]
+    def _get_model(self, model_type):
+        if model_type == 'base':
+            return ckpt_load_helper(
+                'gligen-generation-text-box',
+                is_inpaint=False, is_style=False, common_instances=self.common_instances
+            )[0]
+        elif model_type == 'inpaint':
+            return ckpt_load_helper(
+                'gligen-inpainting-text-box',
+                is_inpaint=True, is_style=False, common_instances=self.common_instances
+            )[0]
+        elif model_type == 'style':
+            return ckpt_load_helper(
+                'gligen-generation-text-image-box',
+                is_inpaint=False, is_style=True, common_instances=self.common_instances
+            )[0]
+
+        assert False
+instance = Instance()
+def load_clip_model():
+    from transformers import CLIPProcessor, CLIPModel
+    version = "openai/clip-vit-large-patch14"
+    model = CLIPModel.from_pretrained(version).cuda()
+    processor = CLIPProcessor.from_pretrained(version)
+    return {
+        'version': version,
+        'model': model,
+        'processor': processor,
+    }
 
 clip_model = load_clip_model()
 
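In `Instance.get_model`, the loaded checkpoints behave as a capacity-bounded cache: at most `capacity` model variants stay resident, and when a new variant is requested at capacity, the type with the lowest usage count is evicted (`Counter.most_common()` sorts descending, so its last entry is the least used) before the replacement is loaded. Below is a minimal standalone sketch of that policy; `ModelCache` and the lambda loader are illustrative stand-ins for `Instance` and `ckpt_load_helper`, not names from the app:

```python
import gc
from collections import Counter

class ModelCache:
    """Keep at most `capacity` models; evict the least-used type when full."""

    def __init__(self, loader, capacity=2):
        self.loader = loader      # stand-in for ckpt_load_helper
        self.capacity = capacity
        self.models = {}
        self.counter = Counter()  # usage count per resident model type

    def get(self, model_type):
        if model_type in self.models:
            self.counter[model_type] += 1
            return self.models[model_type]
        if len(self.models) == self.capacity:
            # most_common() sorts by count descending; the last entry is least used
            least_used = self.counter.most_common()[-1][0]
            del self.models[least_used]
            del self.counter[least_used]
            gc.collect()          # the demo also calls torch.cuda.empty_cache() here
        self.models[model_type] = self.loader(model_type)
        self.counter[model_type] += 1
        return self.models[model_type]

cache = ModelCache(loader=lambda t: f"<{t} weights>", capacity=2)
cache.get('base'); cache.get('base'); cache.get('inpaint')
cache.get('style')  # at capacity: evicts 'inpaint', the least-used resident type
```

One consequence of dropping a type's count on eviction is that a freshly reloaded model restarts at count 1 and becomes the most likely next eviction candidate; the app keeps a separate `global_counter` purely for logging totals.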
@@ -139,7 +141,7 @@ class Blocks(gr.Blocks):
         }
 
         super(Blocks, self).__init__(theme, analytics_enabled, mode, title, css, **kwargs)
-
+        warnings.filterwarnings("ignore")
     def get_config_file(self):
         config = super(Blocks, self).get_config_file()
 
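`warnings.filterwarnings("ignore")` suppresses every warning process-wide for the lifetime of the Space. Where only a known-noisy call needs quieting, a scoped alternative is `warnings.catch_warnings`; a small sketch of the difference, with a hypothetical `noisy` function:

```python
import warnings

def noisy():
    warnings.warn("deprecated API")  # UserWarning, shown by default
    return 42

# Scoped suppression: the previous filter state is restored on exit.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    value = noisy()  # no warning emitted here

noisy()  # outside the block, the warning is emitted as usual
```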
@@ -204,17 +206,20 @@ def inference(task, language_instruction, grounding_instruction, inpainting_boxe
         inpainting_boxes_nodrop = inpainting_boxes_nodrop,
     )
 
+    get_model = partial(instance.get_model,
+                        batch_size=batch_size,
+                        instruction=language_instruction,
+                        phrase_list=phrase_list)
+    with torch.autocast(device_type='cuda', dtype=torch.float16):
+        if task == 'Grounded Generation':
+            if style_image == None:
+                return grounded_generation_box(get_model('base'), instruction, *args, **kwargs)
+            else:
+                return grounded_generation_box(get_model('style'), instruction, *args, **kwargs)
+        elif task == 'Grounded Inpainting':
+            assert image is not None
+            instruction['input_image'] = image.convert("RGB")
+            return grounded_generation_box(get_model('inpaint'), instruction, *args, **kwargs)
 
 
 def draw_box(boxes=[], texts=[], img=None):
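Two idioms do the work in the rewritten dispatch: `functools.partial` pre-binds the logging arguments so each branch can call `get_model('base')` with a single argument, and `torch.autocast` runs eligible GPU ops in float16. A minimal sketch of both, with a toy linear layer standing in for the GLIGEN models (assumes a CUDA device, as the Space has):

```python
import torch
from functools import partial

def get_model(model_type, batch_size, instruction):
    # stand-in for Instance.get_model: log the request, return a model
    print(f"serving {model_type}: batch={batch_size}, prompt={instruction!r}")
    return torch.nn.Linear(8, 8).cuda().eval()

# Pre-bind the bookkeeping arguments once, as inference() does.
get = partial(get_model, batch_size=1, instruction="a dog on a bench")

x = torch.randn(1, 8, device='cuda')
with torch.autocast(device_type='cuda', dtype=torch.float16):
    y = get('base')(x)  # the matmul runs in float16 under autocast
print(y.dtype)          # torch.float16
```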
@@ -264,6 +269,14 @@ def generate(task, language_instruction, grounding_texts, sketch_pad,
     boxes = state['boxes']
     grounding_texts = [x.strip() for x in grounding_texts.split(';')]
     assert len(boxes) == len(grounding_texts)
+    if len(boxes) != len(grounding_texts):
+        if len(boxes) < len(grounding_texts):
+            raise ValueError("""The number of boxes should be equal to the number of grounding objects.
+Number of boxes drawn: {}, number of grounding tokens: {}.
+Please draw boxes accordingly on the sketch pad.""".format(len(boxes), len(grounding_texts)))
+        grounding_texts = grounding_texts + [""] * (len(boxes) - len(grounding_texts))
+
+
     boxes = (np.asarray(boxes) / 512).tolist()
     grounding_instruction = json.dumps({obj: box for obj,box in zip(grounding_texts, boxes)})
 
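The added guard is asymmetric: fewer boxes than phrases is a hard error, while extra boxes are tolerated by padding the phrase list with empty strings so the later `zip` pairs every box with something. (As rendered, the preceding `assert` still fires first, so the padding path only runs when assertions are disabled.) The same logic as a standalone function, with a hypothetical name for illustration:

```python
def align_grounding_texts(boxes, grounding_texts):
    """Pad texts with "" for extra boxes; raise if any box is missing."""
    if len(boxes) < len(grounding_texts):
        raise ValueError(
            "The number of boxes should be equal to the number of grounding objects. "
            f"Number of boxes drawn: {len(boxes)}, number of grounding tokens: {len(grounding_texts)}. "
            "Please draw boxes accordingly on the sketch pad."
        )
    return grounding_texts + [""] * (len(boxes) - len(grounding_texts))

print(align_grounding_texts([[0, 0, 64, 64], [10, 10, 90, 90]], ["a cat"]))
# ['a cat', '']
```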
@@ -446,32 +459,21 @@ def clear(task, sketch_pad_trigger, batch_size, state, switch_task=False):
     state = {}
     return [None, sketch_pad_trigger, None, 1.0] + out_images + [state]
 
-css = """
-    min-height: var(--height) !important;
-}
-#mirrors a:hover {
-    cursor:pointer;
-}
-#paper-info a {
-    color:#008AD7;
-}
-#paper-info a:hover {
-    cursor: pointer;
-}
+css = """
+#img2img_image, #img2img_image > .fixed-height, #img2img_image > .fixed-height > div, #img2img_image > .fixed-height > div > img
+{
+    height: var(--height) !important;
+    max-height: var(--height) !important;
+    min-height: var(--height) !important;
+}
+#paper-info a {
+    color:#008AD7;
+    text-decoration: none;
+}
+#paper-info a:hover {
+    cursor: pointer;
+    text-decoration: none;
+}
 """
 
 rescale_js = """
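The new stylesheet targets Gradio components through their `elem_id` (`#img2img_image`, `#paper-info`); the string is handed to the custom `Blocks` subclass, which forwards `css` to `gr.Blocks`. A minimal sketch of the pattern under Gradio 3.x, with illustrative ids:

```python
import gradio as gr

css = """
#styled_image { border: 2px solid #008AD7; }
#notes a { color: #008AD7; text-decoration: none; }
"""

with gr.Blocks(css=css) as demo:
    gr.Image(label="Preview", elem_id="styled_image")   # matched by #styled_image
    gr.HTML('<div id="notes"><a href="#">paper link</a></div>')

if __name__ == "__main__":
    demo.launch()
```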
@@ -660,7 +662,7 @@ with Blocks(
                 use_style_cond = gr.Checkbox(value=False, label="Enable Style Condition", visible=False)
                 style_cond_image = gr.Image(type="pil", label="Style Condition", interactive=True, visible=False)
             with gr.Column(scale=4):
-                gr.
+                gr.HTML('<span style="font-size: 20px; font-weight: bold">Generated Images</span>')
                 with gr.Row():
                     out_gen_1 = gr.Image(type="pil", visible=True, show_label=False)
                     out_gen_2 = gr.Image(type="pil", visible=True, show_label=False)
@@ -800,4 +802,4 @@ with Blocks(
             queue=False)
 
 main.queue(concurrency_count=1, api_open=False)
-main.launch(share=False, show_api=False)
+main.launch(share=False, show_api=False, show_error=True)
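The launch change adds `show_error=True`, which surfaces Python exceptions in the browser UI instead of a generic failure box — the practical way to see what went wrong in a running Space. A minimal Gradio 3.x app with the same queue and launch settings:

```python
import gradio as gr

def echo(text):
    return text

with gr.Blocks() as main:
    inp = gr.Textbox(label="Input")
    out = gr.Textbox(label="Output")
    inp.submit(echo, inputs=inp, outputs=out)

# One worker, queue API page hidden, errors surfaced in the UI.
main.queue(concurrency_count=1, api_open=False)
main.launch(share=False, show_api=False, show_error=True)
```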