File size: 4,571 Bytes
81b1a0e e797135 6284dc0 e797135 6284dc0 e797135 6be00d8 e797135 81b1a0e 69b7015 81b1a0e 327742a 6284dc0 327742a 81b1a0e 6284dc0 81b1a0e 327742a 81b1a0e 6284dc0 81b1a0e a10635a e7c2780 a10635a 81b1a0e de0b7d0 d967d62 fbe03e2 e797135 6284dc0 7ccb658 bfe6e38 0e5a7e4 a0ef2a3 bfe6e38 b59df1c 8aa2ae3 327742a 6284dc0 81b1a0e 6284dc0 81b1a0e 6284dc0 81b1a0e 6284dc0 81b1a0e fbe03e2 de5ed42 81b1a0e b59df1c e7c2780 b59df1c e797135 81b1a0e ec6f3d6 81b1a0e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
import os
import cv2
import numpy as np
import torch
import gradio as gr
import spaces
from glob import glob
from typing import Optional, Tuple
from PIL import Image
from gradio_imageslider import ImageSlider
from transformers import AutoModelForImageSegmentation
from torchvision import transforms
# Per PyTorch docs, 'high' lets float32 matmuls use faster reduced-precision
# internal math (e.g. TF32) where the hardware supports it.
torch.set_float32_matmul_precision('high')


def _skip_torchscript(fn):
    """Identity stand-in for ``torch.jit.script``: hand back ``fn`` uncompiled."""
    return fn


# Disable TorchScript compilation globally. NOTE(review): presumably done so the
# remote BiRefNet code loads without scripting issues -- confirm.
torch.jit.script = _skip_torchscript

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def array_to_pil_image(image: np.ndarray, size: Tuple[int, int] = (1024, 1024)) -> Image.Image:
    """Resize an image array with OpenCV and wrap the result as an RGB PIL image.

    Args:
        image: Image pixels as a NumPy array (e.g. from the Gradio image input).
        size: Target ``(width, height)``; passed to ``cv2.resize`` as its ``dsize``.

    Returns:
        The bilinearly resized image, converted to RGB mode.
    """
    resized = cv2.resize(image, size, interpolation=cv2.INTER_LINEAR)
    return Image.fromarray(resized).convert('RGB')
class ImagePreprocessor:
    """Turn an already-resized PIL image into a normalized tensor for BiRefNet.

    No resizing happens here. Callers resize first (``array_to_pil_image``
    uses ``cv2.resize``), which keeps this consistent with the cv2 resize used
    in training and avoids resizing twice.
    """

    def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
        # ``resolution`` is accepted for interface compatibility but is not used,
        # because resizing is done before the image reaches this class.
        self.transform_image = transforms.Compose([
            transforms.ToTensor(),
            # Per-channel mean / std (the standard ImageNet statistics).
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def proc(self, image: Image.Image) -> torch.Tensor:
        """Return ``image`` as a normalized float tensor laid out as (C, H, W)."""
        return self.transform_image(image)
# UI label (shown in the "Weights" radio) -> checkpoint name under the
# ``zhengpeng7`` Hugging Face Hub namespace.
usage_to_weights_file = {
    'General': 'BiRefNet',
    'General-Lite': 'BiRefNet_T',
    'Portrait': 'BiRefNet-portrait',
    'DIS': 'BiRefNet-DIS5K',
    'HRSOD': 'BiRefNet-HRSOD',
    'COD': 'BiRefNet-COD',
    'DIS-TR_TEs': 'BiRefNet-DIS5K-TR_TEs',
}

# Load the 'General' checkpoint at start-up so a model exists before the first
# request. ``trust_remote_code=True`` runs the model code shipped in the Hub repo.
_default_weights_repo = f"zhengpeng7/{usage_to_weights_file['General']}"
birefnet = AutoModelForImageSegmentation.from_pretrained(_default_weights_repo, trust_remote_code=True)
birefnet.to(device)
birefnet.eval()
def _parse_resolution(resolution: Optional[str], image_shape: Tuple[int, int]) -> Tuple[int, int]:
    """Parse a ``'WxH'`` string into the model input size ``(width, height)``.

    An empty or missing ``resolution`` falls back to the image's own size
    (``image_shape`` is ``(height, width)``). Each side is rounded down to a
    multiple of 32 (NOTE(review): presumably a stride requirement of the
    network -- confirm). Each side is also clamped to at least 32, so tiny
    values such as ``'16x16'`` cannot produce a zero-sized resize.

    Raises:
        gr.Error: if the text is not two integers separated by ``x``.
    """
    text = (resolution or '').strip().lower()
    if not text:
        text = f"{image_shape[1]}x{image_shape[0]}"
    try:
        width, height = (int(side) for side in text.split('x'))
    except ValueError:
        raise gr.Error(f"Invalid resolution {resolution!r}: expected `WxH`, e.g. `1024x1024`.") from None
    return max(32, width // 32 * 32), max(32, height // 32 * 32)


@spaces.GPU
def predict(
    image: np.ndarray,
    resolution: str,
    weights_file: Optional[str]
) -> Tuple[np.ndarray, np.ndarray]:
    """Segment the main subject of ``image`` with BiRefNet.

    Args:
        image: Input image array from the Gradio image component (``None`` when
            nothing was uploaded).
        resolution: ``'WxH'`` model input size; empty means "use the image size".
        weights_file: Key of ``usage_to_weights_file``; ``None`` selects ``'General'``.

    Returns:
        ``(image, cutout)``. ``image`` is the untouched input. ``cutout`` is the
        input's RGB pixels multiplied by the predicted mask, so masked-out areas
        go towards black. Both are at the original image size.

    Raises:
        gr.Error: if no image was given or ``resolution`` is malformed.
    """
    global birefnet
    if image is None:
        raise gr.Error('Please upload an image first.')

    weights_repo = '/'.join((
        'zhengpeng7',
        usage_to_weights_file[weights_file if weights_file is not None else 'General'],
    ))
    print('Using weights:', weights_repo)
    # Reload only when a different checkpoint is requested, instead of
    # re-instantiating the model on every request. transformers records the
    # loaded repo id as ``name_or_path``. If that attribute is absent, we fall
    # back to reloading, as before.
    if getattr(birefnet, 'name_or_path', None) != weights_repo:
        birefnet = AutoModelForImageSegmentation.from_pretrained(weights_repo, trust_remote_code=True)
        birefnet.to(device)
        birefnet.eval()

    image_shape = image.shape[:2]  # (height, width)
    model_size = _parse_resolution(resolution, image_shape)

    # Resize with cv2, normalize, and add a batch dimension for the model.
    image_pil = array_to_pil_image(image, model_size)
    image_proc = ImagePreprocessor(resolution=model_size).proc(image_pil).unsqueeze(0)

    with torch.no_grad():
        # [-1] takes the last of the model's outputs (presumably the final
        # prediction); sigmoid squashes it into (0, 1).
        scaled_pred_tensor = birefnet(image_proc.to(device))[-1].sigmoid().cpu()

    # Upsample the mask back to the original image size -> (H, W) array.
    pred = torch.nn.functional.interpolate(
        scaled_pred_tensor, size=image_shape, mode='bilinear', align_corners=True
    ).squeeze().numpy()

    # Mask the original full-resolution pixels rather than the model-resolution
    # copy, which would be blurred by the down-/up-scale round trip.
    original_rgb = np.asarray(Image.fromarray(image).convert('RGB'))
    image_pred = (pred[..., np.newaxis] * original_rgb).astype(np.uint8)
    return image, image_pred
# Example rows are [image_path, resolution]; they carry no value for the
# "Weights" input. Sorting keeps the example order independent of the filesystem.
examples = [[example_path, '1024x1024'] for example_path in sorted(glob('examples/*'))]

# Repeat the last example at a lower (faster) resolution. The guard stops a
# missing or empty ``examples/`` folder from crashing start-up with an IndexError.
if examples:
    examples.append([examples[-1][0], '512x512'])
# Input widgets after the image, matching predict()'s (image, resolution, weights_file).
resolution_box = gr.Textbox(lines=1, placeholder="Type the resolution (`WxH`) you want, e.g., `512x512`. Higher resolutions can be much slower for inference.", label="Resolution")
weights_radio = gr.Radio(list(usage_to_weights_file), value='General', label="Weights", info="Choose the weights you want.")

demo = gr.Interface(
    fn=predict,
    inputs=['image', resolution_box, weights_radio],
    # Displays the two images predict() returns: the input and the masked cutout.
    outputs=ImageSlider(),
    examples=examples,
    title='Online demo for `Bilateral Reference for High-Resolution Dichotomous Image Segmentation`',
    description=(
        'Upload a picture, our model will extract a highly accurate segmentation of the subject in it. :)'
        '\nThe resolution used in our training was `1024x1024`, which is thus the suggested resolution to obtain good results!\n Ours codes can be found at https://github.com/ZhengPeng7/BiRefNet.\n We also maintain the HF model of BiRefNet at https://huggingface.co/ZhengPeng7/birefnet for easier access.'
    ),
)

demo.launch(debug=True)
|