# BRIA-Eraser-API / app.py
# Last commit: 82f14e3 ("Update app.py") by MishaF (verified)
import gradio as gr
import numpy as np
import os
from PIL import Image
import requests
from io import BytesIO
import io
import base64
# The Bria API token comes from a secret environment variable so it never
# appears in the source; it is None when the variable is not set.
hf_token = os.getenv("HF_TOKEN_API_DEMO")
# Header mapping attached to every request made by eraser_api_call.
auth_headers = dict(api_token=hf_token)
def convert_mask_image_to_base64_string(mask_image):
    """Encode an image as PNG and return it as a base64 string prefixed with ",".

    Despite its name, this is used for both the source image and the mask.

    Args:
        mask_image: Any object with a PIL-style ``save(fp, format=...)`` method.

    Returns:
        ``"," + <base64 of the PNG bytes>``. The leading comma is intentional:
        per the original author's note, the service's base64 decoder expects a
        prefix ending in "," (as in a data URL), so do not strip it.
    """
    with io.BytesIO() as png_buffer:
        mask_image.save(png_buffer, format="PNG")
        encoded_bytes = base64.b64encode(png_buffer.getvalue())
    return "," + encoded_bytes.decode('utf-8')
def download_image(url, timeout=60):
    """Download an image and return it as an RGB PIL image.

    Args:
        url: Address of the image (in this app, the ``result_url`` returned by
            the Bria eraser API).
        timeout: Seconds to wait for the server. ``requests`` has no default
            timeout, so without one a stalled connection would block the Gradio
            worker indefinitely.

    Returns:
        A ``PIL.Image.Image`` converted to RGB mode.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status. This
            replaces a confusing "cannot identify image" error from PIL when it
            is handed an error page.
        requests.Timeout: if the server does not respond within ``timeout``.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")
def eraser_api_call(image_base64_file, mask_base64_file, mask_type, timeout=120):
    """Send an image and mask to the Bria eraser endpoint and return the result.

    Args:
        image_base64_file: Base64-encoded source image, as produced by
            ``convert_mask_image_to_base64_string``.
        mask_base64_file: Base64-encoded mask in the same format.
        mask_type: Forwarded as the API's ``mask_type`` field (``predict``
            passes ``"manual"``).
        timeout: Seconds to wait for the API. ``requests`` has no default
            timeout and would otherwise block indefinitely.

    Returns:
        The erased image as an RGB ``PIL.Image.Image``, downloaded from the
        ``result_url`` in the API's JSON response.

    Raises:
        requests.HTTPError: if the API answers with a 4xx/5xx status.
        RuntimeError: if the JSON response does not contain a ``result_url``.
    """
    # HTTPS so the api_token header in auth_headers is not sent in clear text.
    url = "https://engine.prod.bria-api.com/v1/eraser"
    payload = {
        "file": image_base64_file,
        "mask_file": mask_base64_file,
        "mask_type": mask_type,
    }
    response = requests.post(url, json=payload, headers=auth_headers, timeout=timeout)
    response.raise_for_status()
    result = response.json()
    result_url = result.get("result_url")
    if not result_url:
        # Surface the API's own message instead of a bare KeyError.
        raise RuntimeError(f"Bria eraser API response has no 'result_url': {result}")
    return download_image(result_url)
def predict(dict):
    """Erase the painted region of the uploaded image using the Bria eraser API.

    Args:
        dict: Value of the ``gr.ImageEditor`` component. Reads ``"background"``
            (the uploaded image as an array) and ``"layers"`` (brush layers as
            arrays whose 4th channel is alpha). The name shadows the builtin
            ``dict``; it is kept so the signature stays unchanged.

    Returns:
        The erased image as an RGB ``PIL.Image.Image``.

    Raises:
        gr.Error: if no image was uploaded or nothing was painted. The user then
            sees a readable message instead of a traceback, and no API call is
            made with an empty mask.
    """
    editor_value = dict or {}

    background = editor_value.get("background")
    if background is None:
        raise gr.Error("Please upload an image first.")
    # convert("RGB") drops an alpha channel (same as slicing [:, :, :3]) and
    # also accepts grayscale arrays, which the slice would reject.
    init_image = Image.fromarray(np.asarray(background)).convert("RGB")

    layers = editor_value.get("layers") or []
    if not layers:
        raise gr.Error("Please paint over the area you want to erase.")
    # Anything painted on any layer belongs to the mask. The editor is built
    # with layers=False, so this is normally the single layer's alpha channel.
    alpha = np.max(np.stack([np.asarray(layer)[:, :, 3] for layer in layers]), axis=0)
    if not alpha.any():
        raise gr.Error("Please paint over the area you want to erase.")
    mask = Image.fromarray(alpha.astype(np.uint8), 'L')

    image_base64_file = convert_mask_image_to_base64_string(init_image)
    mask_base64_file = convert_mask_image_to_base64_string(mask)
    mask_type = "manual"
    return eraser_api_call(image_base64_file, mask_base64_file, mask_type)
# Custom stylesheet passed to gr.Blocks(css=...) when the UI is built.
# NOTE(review): several selectors match no elem_id or class assigned anywhere
# in this file's UI, and look like template leftovers:
#   #image_upload (the ImageEditor has no elem_id), #mask_radio, #word_mask,
#   #share-btn*, #prompt, .acknowledgments, and the unused @keyframes spin.
# The rules that do take effect target .gradio-container, #run_button,
# #output-img, #prompt-container and the .footer markup.
css = '''
.gradio-container{max-width: 1100px !important}
#image_upload{min-height:400px}
#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 400px}
#mask_radio .gr-form{background:transparent; border: none}
#word_mask{margin-top: .75em !important}
#word_mask textarea:disabled{opacity: 0.3}
.footer {margin-bottom: 45px;margin-top: 35px;text-align: center;border-bottom: 1px solid #e5e5e5}
.footer>p {font-size: .8rem; display: inline-block; padding: 0 10px;transform: translateY(10px);background: white}
.dark .footer {border-color: #303030}
.dark .footer>p {background: #0b0f19}
.acknowledgments h4{margin: 1.25em 0 .25em 0;font-weight: bold;font-size: 115%}
#image_upload .touch-none{display: flex}
@keyframes spin {
from {
transform: rotate(0deg);
}
to {
transform: rotate(360deg);
}
}
#share-btn-container {padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; max-width: 13rem; margin-left: auto;}
div#share-btn-container > div {flex-direction: row;background: black;align-items: center}
#share-btn-container:hover {background-color: #060606}
#share-btn {all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.5rem !important; padding-bottom: 0.5rem !important;right:0;}
#share-btn * {all: unset}
#share-btn-container div:nth-child(-n+2){width: auto !important;min-height: 0px !important;}
#share-btn-container .wrap {display: none !important}
#share-btn-container.hidden {display: none!important}
#prompt input{width: calc(100% - 160px);border-top-right-radius: 0px;border-bottom-right-radius: 0px;}
#run_button {
width: 100%;
height: 50px; /* Set a fixed height for the button */
display: flex;
align-items: center;
justify-content: center;
}
#output-img img, #image_upload img {
object-fit: contain; /* Ensure aspect ratio is preserved */
width: 100%;
height: auto; /* Let height adjust automatically */
}
#prompt-container{margin-top:-18px;}
#prompt-container .form{border-top-left-radius: 0;border-top-right-radius: 0}
#image_upload{border-bottom-left-radius: 0px;border-bottom-right-radius: 0px}
'''
# Root Gradio app, styled by the module-level `css` string.
image_blocks = gr.Blocks(css=css, elem_id="total-container")
with image_blocks as demo:
    # Header: title plus an HTML description of the eraser pipeline and links.
    with gr.Column(elem_id="col-container"):
        gr.Markdown("## BRIA Eraser API")
        gr.HTML('''
<p style="margin-bottom: 10px; font-size: 94%">
This demo showcases the BRIA Eraser capability, which allows users to remove specific elements or objects from images.<br>
The pipeline comprises multiple components, including <a href="https://huggingface.co/briaai/BRIA-2.3" target="_blank">briaai/BRIA-2.3</a>,
<a href="https://huggingface.co/briaai/BRIA-2.3-ControlNet-Inpainting" target="_blank">briaai/BRIA-2.3-ControlNet-Inpainting</a>,
and <a href="https://huggingface.co/briaai/BRIA-2.3-FAST-LORA" target="_blank">briaai/BRIA-2.3-FAST-LORA</a>, all trained on licensed data.<br>
This ensures full legal liability coverage for copyright and privacy infringement.<br>
Notes:<br>
- High-resolution images may take longer to process.<br>
- For multiple masks, results are better if all masks are included in inference.<br><br>
</p>
<p style="margin-bottom: 10px; font-size: 94%">
API Endpoint available on: <a href="https://fal.ai/models/fal-ai/bria/eraser" target="_blank">fal.ai</a><br>
ComfyUI node is available here: <a href="https://github.com/Bria-AI/ComfyUI-BRIA-API" target="_blank">ComfyUI Node</a>
</p>
''')
    # Two columns: editor and run button on the left, result on the right.
    with gr.Row():
        with gr.Column():
            # Upload-only editor with a single, fixed black brush and no
            # crop/transform tools. predict() builds the erase mask from the
            # alpha channel of the painted layer.
            image = gr.ImageEditor(sources=["upload"], layers=False, transforms=[],
                                   brush=gr.Brush(colors=["#000000"], color_mode="fixed"),
                                   )
            with gr.Row(elem_id="prompt-container", equal_height=True):
                with gr.Column(): # Wrap the button inside a Column
                    btn = gr.Button("Erase!", elem_id="run_button")
        with gr.Column():
            image_out = gr.Image(label="Output", elem_id="output-img")
    # Button click will trigger the inpainting function (no prompt required)
    btn.click(fn=predict, inputs=[image], outputs=[image_out], api_name='run')
    # NOTE(review): the footer credits Diffusers / Hugging Face, which looks
    # like template text carried over from another demo; confirm attribution.
    gr.HTML(
        """
<div class="footer">
<p>Model by <a href="https://huggingface.co/diffusers" style="text-decoration: underline;" target="_blank">Diffusers</a> - Gradio Demo by 🤗 Hugging Face
</p>
</div>
"""
    )
# Runs at import time. The queue is capped at 25 pending events (max_size).
# api_open=False and show_api=False close and hide Gradio's programmatic API,
# so the app is only usable through the UI.
image_blocks.queue(max_size=25,api_open=False).launch(show_api=False)