File size: 2,085 Bytes
0008ffa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
import torch.nn.functional as F
from u2net import U2NET
import data_transforms
from transformers import Pipeline

class U2NetPipeline(Pipeline):
    """transformers ``Pipeline`` that runs a U^2-Net model and returns a mask.

    ``preprocess`` turns a PIL image or image path into a model input tensor,
    ``_forward`` runs the network without gradient tracking, and
    ``postprocess`` resizes the network's first output back to the input
    image's (height, width) and min-max rescales it to an 8-bit mask.
    """

    def __init__(self, model, **kwargs):
        """Build the pipeline and load U^2-Net weights.

        Args:
            model: Path to a ``state_dict`` checkpoint for ``U2NET(3, 1)``;
                it is also forwarded to ``Pipeline.__init__``.
            **kwargs: Forwarded unchanged to ``Pipeline.__init__``.
        """
        # NOTE(review): transformers' Pipeline.__init__ usually expects a loaded
        # model object rather than a path string; confirm this works with the
        # installed transformers version.
        super().__init__(model=model, **kwargs)
        self.model = U2NET(3, 1)
        # NOTE(review): torch.load unpickles the checkpoint -- only load trusted
        # files (consider weights_only=True on torch >= 1.13).
        self.model.load_state_dict(torch.load(model, map_location="cpu"))
        self.model.eval()

    def _sanitize_parameters(self, **kwargs):
        """Accept no call-time options: empty preprocess/forward/postprocess kwargs."""
        return {}, {}, {}

    def preprocess(self, image):
        """Load and normalize an image for the network.

        Args:
            image: A ``PIL.Image.Image`` or a path to an image file.

        Returns:
            dict with ``"image"`` -- a float32 tensor of shape
            ``(1, 3, 1024, 1024)`` -- and ``"original_size"`` -- the input
            image's ``(height, width)``.

        Raises:
            ValueError: If ``image`` is neither a path string nor a PIL image.
        """
        if isinstance(image, str):
            image = Image.open(image).convert("RGB")
        elif isinstance(image, Image.Image):
            image = image.convert("RGB")
        else:
            raise ValueError("Input must be a PIL Image or a path to an image file")

        image = np.array(image)
        # Capture (height, width) now: `image` is rebound to the resized tensor
        # below, so reading its shape afterwards would always give 1024x1024.
        original_size = tuple(image.shape[:2])

        transform = transforms.Compose([data_transforms.RescaleT(320), data_transforms.ToTensorLab(flag=0)])
        sample = transform({"imidx": np.array([0]), "image": image, "label": np.zeros(image.shape[:2])})

        input_size = [1024, 1024]
        # Cast to float32 so the input dtype matches the model's float weights
        # (a tensor built from a float64 numpy array would be float64).
        im_tensor = sample["image"].unsqueeze(0).float()
        im_tensor = F.interpolate(im_tensor, input_size, mode="bilinear")
        # NOTE(review): ToTensorLab may already scale/normalize the image, in
        # which case dividing by 255 and re-normalizing double-normalizes.
        # Verify against the preprocessing the weights were trained with.
        image = torch.divide(im_tensor, 255.0)
        image = transforms.Normalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0])(image)

        return {"image": image, "original_size": original_size}

    def _forward(self, model_inputs):
        """Run the network without autograd, carrying ``original_size`` through."""
        with torch.no_grad():
            outputs = self.model(model_inputs["image"])
        return {"outputs": outputs, "original_size": model_inputs["original_size"]}

    def postprocess(self, model_outputs):
        """Convert the network's first output into an 8-bit mask.

        Args:
            model_outputs: dict produced by ``_forward``.

        Returns:
            ``np.uint8`` array of shape ``original_size`` with values in
            [0, 255]; all zeros when the prediction is constant.
        """
        pred = model_outputs["outputs"][0][0]
        # Bilinear F.interpolate requires a 4-D (N, C, H, W) tensor, but the
        # [0][0] indexing drops dimensions. View it as one single-channel map;
        # this is valid because the model is built as U2NET(3, 1).
        pred = pred.reshape(1, 1, *pred.shape[-2:])
        pred = F.interpolate(pred, size=model_outputs["original_size"], mode="bilinear", align_corners=False)
        # Index rather than squeeze() so 1-pixel-high/wide images keep 2 dims.
        result = pred[0, 0].cpu().numpy()

        ma, mi = result.max(), result.min()
        span = ma - mi
        if span > 0:
            result = (result - mi) / span
        else:
            # Constant prediction: avoid 0/0 (NaN -> undefined uint8 values).
            result = np.zeros_like(result)
        return (result * 255).astype(np.uint8)

def load_model(model_path="u2net.pth"):
    """Create a ``U2NetPipeline`` from a U^2-Net weights checkpoint.

    Args:
        model_path: Path to the ``state_dict`` checkpoint. The default,
            ``"u2net.pth"``, is relative to the current working directory.

    Returns:
        A ``U2NetPipeline`` whose model has the weights loaded and is in
        eval mode.
    """
    return U2NetPipeline(model_path)