SkalskiP committed on
Commit
5d92a23
1 Parent(s): 867296e

image resolution dimensions divisible by 32 fix; advanced settings; debug mask mode

Browse files
Files changed (1) hide show
  1. app.py +105 -21
app.py CHANGED
@@ -1,6 +1,11 @@
1
- import torch
2
- import spaces
 
 
3
  import gradio as gr
 
 
 
4
  from diffusers import FluxInpaintPipeline
5
 
6
  MARKDOWN = """
@@ -11,39 +16,79 @@ creating this amazing model, and a big thanks to [Gothos](https://github.com/Got
11
  for taking it to the next level by enabling inpainting with the FLUX.
12
  """
13
 
 
 
14
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
15
 
16
  pipe = FluxInpaintPipeline.from_pretrained(
17
  "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(DEVICE)
18
 
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  @spaces.GPU()
21
- def process(input_image_editor, input_text, progress=gr.Progress(track_tqdm=True)):
 
 
 
 
 
 
 
 
22
  if not input_text:
23
  gr.Info("Please enter a text prompt.")
24
  return None
25
 
26
  image = input_image_editor['background']
27
- mask_image = input_image_editor['layers'][0]
28
 
29
  if not image:
30
  gr.Info("Please upload an image.")
31
  return None
32
 
33
- if not mask_image:
34
  gr.Info("Please draw a mask on the image.")
35
  return None
36
 
37
- width, height = image.size
 
 
38
 
 
 
 
39
  return pipe(
40
  prompt=input_text,
41
- image=image,
42
- mask_image=mask_image,
43
  width=width,
44
  height=height,
45
- strength=0.7
46
- ).images[0]
 
 
47
 
48
 
49
  with gr.Blocks() as demo:
@@ -57,27 +102,66 @@ with gr.Blocks() as demo:
57
  image_mode='RGB',
58
  layers=False,
59
  brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"))
60
- input_text_component = gr.Text(
61
- label="Prompt",
62
- show_label=False,
63
- max_lines=1,
64
- placeholder="Enter your prompt",
65
- container=False,
66
- )
67
- submit_button_component = gr.Button(
68
- value='Submit', variant='primary')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  with gr.Column():
70
  output_image_component = gr.Image(
71
  type='pil', image_mode='RGB', label='Generated image')
 
 
 
72
 
73
  submit_button_component.click(
74
  fn=process,
75
  inputs=[
76
  input_image_editor_component,
77
- input_text_component
 
 
 
 
78
  ],
79
  outputs=[
80
- output_image_component
 
81
  ]
82
  )
83
 
 
1
+ from typing import Tuple
2
+
3
+ import random
4
+ import numpy as np
5
  import gradio as gr
6
+ import spaces
7
+ import torch
8
+ from PIL import Image
9
  from diffusers import FluxInpaintPipeline
10
 
11
  MARKDOWN = """
 
16
  for taking it to the next level by enabling inpainting with the FLUX.
17
  """
18
 
19
+ MAX_SEED = np.iinfo(np.int32).max
20
+ MAX_IMAGE_SIZE = 2048
21
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
22
 
23
  pipe = FluxInpaintPipeline.from_pretrained(
24
  "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(DEVICE)
25
 
26
 
27
def resize_image_dimensions(
    original_resolution_wh: Tuple[int, int],
    maximum_dimension: int = 2048
) -> Tuple[int, int]:
    """Scale a resolution so its longest side matches ``maximum_dimension``.

    The aspect ratio is preserved as closely as possible, and both returned
    sides are rounded down to a multiple of 32 (but never below 32).

    Args:
        original_resolution_wh: ``(width, height)`` of the source image,
            both strictly positive.
        maximum_dimension: Target size of the longest side before rounding.
            Must be at least 32 so a non-zero multiple of 32 fits.

    Returns:
        ``(new_width, new_height)``, each a multiple of 32 in the range
        ``[32, maximum_dimension]``.

    Raises:
        ValueError: If a source dimension is not positive or
            ``maximum_dimension`` is smaller than 32.
    """
    width, height = original_resolution_wh

    if width <= 0 or height <= 0:
        raise ValueError(
            f"Image dimensions must be positive, got {width}x{height}.")
    if maximum_dimension < 32:
        raise ValueError(
            f"maximum_dimension must be at least 32, got {maximum_dimension}.")

    longest_side = max(width, height)

    # Exact integer arithmetic: float scaling followed by int() truncation
    # can land just below the target (e.g. 49 * (2048 / 49) is
    # 2047.9999999999998), which would then lose a whole 32-pixel step.
    new_width = width * maximum_dimension // longest_side
    new_height = height * maximum_dimension // longest_side

    # Snap down to a multiple of 32; very elongated images would otherwise
    # collapse to a zero-sized side, so keep at least one 32-pixel step.
    new_width = max(32, new_width - new_width % 32)
    new_height = max(32, new_height - new_height % 32)

    return new_width, new_height
48
+
49
+
50
@spaces.GPU()
def process(
    input_image_editor: dict,
    input_text: str,
    seed_slicer: int,
    randomize_seed_checkbox: bool,
    strength_slider: float,
    num_inference_steps_slider: int,
    progress=gr.Progress(track_tqdm=True)
):
    """Inpaint the masked area of the edited image according to a prompt.

    Args:
        input_image_editor: Image editor value; its ``'background'`` entry
            is used as the image and the first entry of ``'layers'`` as the
            mask.
        input_text: Text prompt passed to the pipeline.
        seed_slicer: Seed for the random generator; ignored when
            ``randomize_seed_checkbox`` is set.
        randomize_seed_checkbox: When true, draw a fresh seed in
            ``[0, MAX_SEED]`` instead of using ``seed_slicer``.
        strength_slider: Forwarded to the pipeline as ``strength``.
        num_inference_steps_slider: Forwarded to the pipeline as
            ``num_inference_steps``.
        progress: Gradio progress tracker (tracks tqdm output).

    Returns:
        ``(generated_image, resized_mask)``. When the prompt, image or mask
        is missing, an info message is shown and ``(None, None)`` is
        returned so that both output components still receive a value.
    """
    if not input_text:
        gr.Info("Please enter a text prompt.")
        return None, None

    image = input_image_editor['background']
    layers = input_image_editor['layers']
    # An empty layer list means nothing was drawn; treat it as a missing
    # mask instead of failing with IndexError.
    mask = layers[0] if layers else None

    if image is None:
        gr.Info("Please upload an image.")
        return None, None

    if mask is None:
        gr.Info("Please draw a mask on the image.")
        return None, None

    width, height = resize_image_dimensions(original_resolution_wh=image.size)
    resized_image = image.resize((width, height), Image.LANCZOS)
    # Nearest-neighbour resampling is kept for the mask, as before.
    resized_mask = mask.resize((width, height), Image.NEAREST)

    if randomize_seed_checkbox:
        seed_slicer = random.randint(0, MAX_SEED)
    # NOTE(review): slider values may arrive as floats; cast defensively
    # because the seed and the step count must be integers.
    generator = torch.Generator().manual_seed(int(seed_slicer))

    result = pipe(
        prompt=input_text,
        image=resized_image,
        mask_image=resized_mask,
        width=width,
        height=height,
        strength=strength_slider,
        generator=generator,
        num_inference_steps=int(num_inference_steps_slider)
    ).images[0]

    return result, resized_mask
92
 
93
 
94
  with gr.Blocks() as demo:
 
102
  image_mode='RGB',
103
  layers=False,
104
  brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"))
105
+
106
+ with gr.Row():
107
+ input_text_component = gr.Text(
108
+ label="Prompt",
109
+ show_label=False,
110
+ max_lines=1,
111
+ placeholder="Enter your prompt",
112
+ container=False,
113
+ )
114
+ submit_button_component = gr.Button(
115
+ value='Submit', variant='primary', scale=0)
116
+
117
+ with gr.Accordion("Advanced Settings", open=False):
118
+ seed_slicer_component = gr.Slider(
119
+ label="Seed",
120
+ minimum=0,
121
+ maximum=MAX_SEED,
122
+ step=1,
123
+ value=0,
124
+ )
125
+
126
+ randomize_seed_checkbox_component = gr.Checkbox(
127
+ label="Randomize seed", value=True)
128
+
129
+ with gr.Row():
130
+ strength_slider_component = gr.Slider(
131
+ label="Strength",
132
+ minimum=0,
133
+ maximum=1,
134
+ step=0.01,
135
+ value=0.75,
136
+ )
137
+
138
+ num_inference_steps_slider_component = gr.Slider(
139
+ label="Number of inference steps",
140
+ minimum=1,
141
+ maximum=50,
142
+ step=1,
143
+ value=20,
144
+ )
145
  with gr.Column():
146
  output_image_component = gr.Image(
147
  type='pil', image_mode='RGB', label='Generated image')
148
+ with gr.Accordion("Debug", open=False):
149
+ output_mask_component = gr.Image(
150
+ type='pil', image_mode='RGB', label='Input mask')
151
 
152
  submit_button_component.click(
153
  fn=process,
154
  inputs=[
155
  input_image_editor_component,
156
+ input_text_component,
157
+ seed_slicer_component,
158
+ randomize_seed_checkbox_component,
159
+ strength_slider_component,
160
+ num_inference_steps_slider_component
161
  ],
162
  outputs=[
163
+ output_image_component,
164
+ output_mask_component
165
  ]
166
  )
167