File size: 2,676 Bytes
6865e91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9db8331
23fb6b9
6865e91
 
 
23fb6b9
6865e91
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import cv2
import gradio as gr
from typing import Union, Tuple
from PIL import Image, ImageOps
import numpy as np
import torch

# Load the TorchScript model once at import time and switch it to eval mode.
# `run` feeds it a CPU tensor and converts the output with `.numpy()`, so the
# weights are pinned to the CPU: without `map_location`, a model that was saved
# from a CUDA device would fail to load on a CPU-only host (or end up on a
# different device than its input).
model = torch.jit.load('./model/model.pt', map_location='cpu').eval()

def resize_with_padding(
    img: Image.Image, expected_size: Tuple[int, int]
) -> Tuple[Image.Image, Tuple[int, int, int, int]]:
    """Fit *img* inside ``expected_size`` and pad it to exactly that size.

    The image is shrunk with ``Image.thumbnail`` and then centred on a canvas
    of ``expected_size`` using ``ImageOps.expand`` with its default fill.

    Args:
        img: Source image. It is not modified; a copy is resized.
        expected_size: Target ``(width, height)`` in pixels.

    Returns:
        A ``(padded_image, padding)`` pair, where ``padding`` is the border
        that was added, in the order passed to ``ImageOps.expand``:
        ``(left, top, right, bottom)``. When the missing amount is odd, the
        extra pixel goes to the right/bottom side.
    """
    # `thumbnail` resizes in place; copy first so the caller's image is untouched.
    img = img.copy()
    img.thumbnail((expected_size[0], expected_size[1]))

    delta_width = expected_size[0] - img.size[0]
    delta_height = expected_size[1] - img.size[1]
    pad_width = delta_width // 2
    pad_height = delta_height // 2
    padding = (pad_width, pad_height, delta_width - pad_width, delta_height - pad_height)
    return ImageOps.expand(img, padding), padding

def preprocess_image(
    img: Image.Image, size: int = 512
) -> Tuple[Image.Image, torch.Tensor, Tuple[int, int, int, int]]:
    """Turn a PIL image into a normalised, batched tensor for the model.

    Args:
        img: Input image. Any PIL mode is accepted; it is converted to RGB
            because the normalisation below is defined for three channels.
        size: Side length in pixels of the square canvas the image is fitted
            and padded to.

    Returns:
        A ``(pil_img, tensor, padding)`` triple:

        * ``pil_img`` - the resized and padded RGB image (``size`` x ``size``).
        * ``tensor`` - float32 tensor of shape ``(1, 3, size, size)`` holding
          the per-channel normalised pixels.
        * ``padding`` - the border added by ``resize_with_padding``.
    """
    # RGBA / grayscale / palette images would not broadcast against the
    # 3-channel mean and std below, so normalise the mode up front.
    if img.mode != "RGB":
        img = img.convert("RGB")

    pil_img, padding = resize_with_padding(img, (size, size))

    # Per-channel statistics (presumably the ImageNet mean/std the model was
    # trained with - confirm against the training pipeline).
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 1, 3)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 1, 3)

    # Scale to [0, 1], normalise, then reorder HWC -> CHW.
    img = (np.array(pil_img).astype(np.float32) / 255) - mean
    img = img / std
    img = np.transpose(img, (2, 0, 1))

    # `img[None]` adds the leading batch dimension.
    return pil_img, torch.tensor(img[None]), padding

def soft_blur_with_mask(
    image: Image.Image, mask: np.ndarray, padding: Tuple[int, int, int, int]
) -> Image.Image:
    """Blur the masked region of *image* and strip the padding border.

    Args:
        image: Padded RGB image, as returned by ``preprocess_image``.
        mask: 2-D array of per-pixel labels; it is cast to ``uint8`` and
            resized to the image size. Non-zero pixels select the blurred
            version of the image.
        padding: Border that was added around the image, as
            ``(left, top, right, bottom)``.

    Returns:
        The blended image with the padding cropped away.
    """
    image = np.array(image)
    # Create a blurred copy of the original image.
    blurred_image = cv2.GaussianBlur(image, (221, 221), sigmaX=20, sigmaY=20)
    image_height, image_width = image.shape[:2]
    mask = cv2.resize(mask.astype(np.uint8), (image_width, image_height), interpolation=cv2.INTER_NEAREST)
    # Blur the mask itself to get a softer mask with no hard edges.
    # NOTE(review): the blend below assumes mask values are 0 or 1; a
    # multi-class label map would produce weights above 1 - confirm the model
    # output is binary.
    mask = cv2.GaussianBlur(mask.astype(np.float32), (11, 11), 10, 10)[:, :, None]

    # Take the blurred image where the mask is positive and the original
    # image everywhere else.
    image = mask * blurred_image + (1.0 - mask) * image

    # Remove the border using each side's own padding: the right/bottom pads
    # can be one pixel larger than the left/top ones, so mirroring left/top
    # would leave a stray border row/column. Array shape is (height, width, c).
    pad_left, pad_top, pad_right, pad_bottom = padding
    image = image[pad_top:image_height - pad_bottom, pad_left:image_width - pad_right, :]
    return Image.fromarray(image.astype(np.uint8))

def run(image, size):
    """Gradio callback: blur the region of *image* selected by the model.

    Args:
        image: Input image, forwarded to ``preprocess_image``.
        size: Side length of the square canvas the image is fitted to before
            inference.

    Returns:
        The image produced by ``soft_blur_with_mask``.
    """
    padded_image, batch, border = preprocess_image(image, size=size)

    # Inference only: no autograd bookkeeping is needed.
    with torch.inference_mode():
        prediction = model(batch)

    # Pick the highest-scoring class per pixel, then drop the singleton dims.
    labels = prediction.argmax(dim=1)
    label_map = labels.numpy().squeeze()

    return soft_blur_with_mask(padded_image, label_map, border)

# --- Gradio UI wiring ---------------------------------------------------

# Image input, handed to `run` as a PIL image.
content_image_input = gr.inputs.Image(type="pil", label="Entrada")

# Radio selector whose value becomes the `size` argument of `run`
# (the side length of the square canvas used for inference).
model_image_size = gr.inputs.Radio(
    [256, 384, 512, 1024],
    label="Ajustar nivel de inferencia",
    default=512,
    type="value",
)

# The inputs are passed to `run` positionally, in the order listed here.
app_interface = gr.Interface(
    fn=run,
    inputs=[content_image_input, model_image_size],
    outputs="image",
)
app_interface.launch()