# Flux-Inpaint / app.py
# Hugging Face Space by rishh76 - commit 5383233 (verified), "Update app.py".
# File size 6.66 kB; Hub malware scan: no virus.
import io
import random
from typing import Tuple, Dict

import gradio as gr
import numpy as np
import requests
import torch
from diffusers import FluxInpaintPipeline
from PIL import Image
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# Markdown header rendered at the top of the demo page.
MARKDOWN_TEXT = """
# FLUX.1 Inpainting 🔥
Shoutout to [Black Forest Labs](https://huggingface.co/black-forest-labs) for
creating this amazing model, and a big thanks to [Gothos](https://github.com/Gothos)
for taking it to the next level by enabling inpainting with the FLUX.
"""

# Default target length (pixels) of the longer image side before inference.
DEFAULT_IMAGE_SIZE = 1024

# Upper bound for seeds (slider maximum and randomization range):
# the largest signed 32-bit integer.
MAX_SEED_VALUE = np.iinfo(np.int32).max

# Use the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE_TYPE = "cuda"
else:
    DEVICE_TYPE = "cpu"
# ---------------------------------------------------------------------------
# Model initialization
# ---------------------------------------------------------------------------
# Load the FLUX.1-schnell weights into the inpainting pipeline in bfloat16,
# then move the whole pipeline onto the device chosen above.
pipeline = FluxInpaintPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16,
)
pipeline = pipeline.to(DEVICE_TYPE)
def adjust_image_size(
    original_size: Tuple[int, int],
    max_dimension: int = DEFAULT_IMAGE_SIZE,
    multiple: int = 32,
) -> Tuple[int, int]:
    """Scale a size so its longer side is ``max_dimension``, snapped to ``multiple``.

    The aspect ratio is preserved up to the snapping, and images smaller than
    ``max_dimension`` are scaled up. Each side is rounded down to a multiple
    of ``multiple``, but never below ``multiple``. This keeps very thin images
    (e.g. 1024x20) from collapsing to a zero-length side, which the later
    ``resize`` call would reject.

    Args:
        original_size: ``(width, height)`` in pixels, e.g. ``PIL.Image.size``.
        max_dimension: Target length of the longer side.
        multiple: Granularity that each output side is snapped down to.

    Returns:
        ``(new_width, new_height)``.

    Raises:
        ValueError: If either input dimension or ``multiple`` is not positive.
    """
    width, height = original_size
    if width <= 0 or height <= 0:
        raise ValueError(f"Image dimensions must be positive, got {original_size!r}")
    if multiple <= 0:
        raise ValueError(f"multiple must be positive, got {multiple!r}")

    longest = max(width, height)

    def _snap(side: int) -> int:
        # Integer arithmetic gives the exact floor of side * max_dimension / longest.
        # The float form int(side * (max_dimension / longest)) can land at
        # e.g. 1023.999... and truncate to one pixel short.
        scaled = side * max_dimension // longest
        return max(multiple, scaled - scaled % multiple)

    return _snap(width), _snap(height)
def process_images(
    input_data: Dict,
    prompt: str,
    seed: int,
    randomize_seed: bool,
    strength: float,
    num_steps: int,
    progress=gr.Progress(track_tqdm=True)
):
    """Inpaint the masked region of the uploaded image according to ``prompt``.

    Args:
        input_data: Value of the ``gr.ImageEditor`` component. The
            ``'background'`` image is inpainted and the first entry of
            ``'layers'`` is used as the mask. It may be ``None`` or have an
            empty ``'layers'`` list; both are reported to the user instead of
            raising.
        prompt: Text prompt passed to the inpainting pipeline.
        seed: Seed for the torch generator. Ignored when ``randomize_seed``
            is true.
        randomize_seed: When true, draw a fresh seed from
            ``[0, MAX_SEED_VALUE]`` (inclusive).
        strength: Extent to which the masked area is transformed. Must be
            greater than 0; at 0 the pipeline has no denoising steps left
            and raises.
        num_steps: Number of denoising steps.
        progress: Gradio progress tracker (``track_tqdm=True``). It is not
            referenced in the body.

    Returns:
        ``(generated_image, resized_mask)`` on success. Returns
        ``(None, None)`` after showing a ``gr.Info`` message when the prompt,
        image or mask is missing, or when ``strength`` is not positive.
    """
    if not prompt:
        gr.Info("Please enter a text prompt.")
        return None, None

    # Guard both a missing editor value and an empty layer list before
    # indexing. Otherwise the user would get a traceback instead of the
    # "upload"/"draw a mask" hints below.
    editor_value = input_data or {}
    background_img = editor_value.get('background')
    layers = editor_value.get('layers') or []
    mask_img = layers[0] if layers else None

    if background_img is None:
        gr.Info("Please upload an image.")
        return None, None
    if mask_img is None:
        gr.Info("Please draw a mask on the image.")
        return None, None
    if strength <= 0:
        gr.Info("Strength must be greater than 0.")
        return None, None

    new_width, new_height = adjust_image_size(background_img.size)
    resized_bg = background_img.resize((new_width, new_height), Image.LANCZOS)
    resized_mask = mask_img.resize((new_width, new_height), Image.LANCZOS)

    if randomize_seed:
        seed = random.randint(0, MAX_SEED_VALUE)
    # Slider values may arrive as floats. manual_seed and
    # num_inference_steps expect integers, so cast explicitly.
    generator = torch.Generator().manual_seed(int(seed))

    result_image = pipeline(
        prompt=prompt,
        image=resized_bg,
        mask_image=resized_mask,
        width=new_width,
        height=new_height,
        strength=strength,
        generator=generator,
        num_inference_steps=int(num_steps),
    ).images[0]
    return result_image, resized_mask
# Gradio interface


def _load_remote_image(url: str, timeout: float = 30.0) -> Image.Image:
    """Download the image at ``url`` and return it fully decoded.

    This function is used to build the example inputs at startup. The
    request has a timeout and raises on an HTTP error status. Without these,
    it could hang or hand an error page to PIL. The body is read into memory,
    so the connection is released and the image does not depend on a live
    network stream.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    image = Image.open(io.BytesIO(response.content))
    image.load()  # decode now rather than lazily on first pixel access
    return image


# Both examples use the same source photo, so download it only once.
_EXAMPLE_BACKGROUND = _load_remote_image(
    "https://media.roboflow.com/spaces/doge-2-image.png"
)

with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN_TEXT)
    with gr.Row():
        with gr.Column():
            # White, fixed-colour brush: the drawn layer is what
            # process_images reads as the mask ('layers'[0]).
            img_editor = gr.ImageEditor(
                label='Image',
                type='pil',
                sources=["upload", "webcam"],
                image_mode='RGB',
                layers=False,
                brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"),
            )
            with gr.Row():
                text_input = gr.Text(
                    label="Prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter your prompt",
                    container=False,
                )
                submit_btn = gr.Button(
                    value='Submit', variant='primary', scale=0
                )
            with gr.Accordion("Advanced Settings", open=False):
                seed_slider = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED_VALUE,
                    step=1,
                    value=42,
                )
                random_seed_chkbox = gr.Checkbox(
                    label="Randomize seed", value=True
                )
                with gr.Row():
                    # Minimum is 0.01 rather than 0. At strength 0 the
                    # diffusers inpaint pipeline has zero denoising steps
                    # left and raises ValueError.
                    strength_slider = gr.Slider(
                        label="Strength",
                        info="Indicates extent to transform the reference `image`.",
                        minimum=0.01,
                        maximum=1,
                        step=0.01,
                        value=0.85,
                    )
                    steps_slider = gr.Slider(
                        label="Number of inference steps",
                        info="The number of denoising steps.",
                        minimum=1,
                        maximum=50,
                        step=1,
                        value=20,
                    )
        with gr.Column():
            output_img = gr.Image(
                type='pil', image_mode='RGB', label='Generated Image', format="png"
            )
            with gr.Accordion("Debug", open=False):
                output_mask = gr.Image(
                    type='pil', image_mode='RGB', label='Input Mask', format="png"
                )

    # The examples and the Submit button must feed process_images the same
    # components in the same order, so define the wiring once.
    inference_inputs = [
        img_editor,
        text_input,
        seed_slider,
        random_seed_chkbox,
        strength_slider,
        steps_slider,
    ]
    inference_outputs = [output_img, output_mask]

    gr.Examples(
        fn=process_images,
        examples=[
            [
                {
                    "background": _EXAMPLE_BACKGROUND,
                    "layers": [
                        _load_remote_image(
                            "https://media.roboflow.com/spaces/doge-2-mask-2.png"
                        ).convert("RGBA")
                    ],
                    "composite": _load_remote_image(
                        "https://media.roboflow.com/spaces/doge-2-composite-2.png"
                    ),
                },
                "little lion",
                42,
                False,
                0.85,
                30,
            ],
            [
                {
                    "background": _EXAMPLE_BACKGROUND,
                    "layers": [
                        _load_remote_image(
                            "https://media.roboflow.com/spaces/doge-2-mask-3.png"
                        ).convert("RGBA")
                    ],
                    "composite": _load_remote_image(
                        "https://media.roboflow.com/spaces/doge-2-composite-3.png"
                    ),
                },
                "tribal tattoos",
                42,
                False,
                0.85,
                30,
            ],
        ],
        inputs=inference_inputs,
        outputs=inference_outputs,
        run_on_click=True,
        cache_examples=True,
    )

    submit_btn.click(
        fn=process_images,
        inputs=inference_inputs,
        outputs=inference_outputs,
    )
# Start the web server only when run as a script (e.g. `python app.py`).
# Importing the module, for example to reuse `demo`, has no side effect.
if __name__ == "__main__":
    demo.launch(debug=False, show_error=True)