Update MyPipe.py
Browse filesadded support to pil image as input
MyPipe.py
CHANGED
@@ -3,6 +3,7 @@ import torch.nn.functional as F
|
|
3 |
from torchvision.transforms.functional import normalize
|
4 |
import numpy as np
|
5 |
from transformers import Pipeline
|
|
|
6 |
from skimage import io
|
7 |
from PIL import Image
|
8 |
|
@@ -23,34 +24,35 @@ class RMBGPipe(Pipeline):
|
|
23 |
postprocess_kwargs["return_mask"] = kwargs["return_mask"]
|
24 |
return preprocess_kwargs, {}, postprocess_kwargs
|
25 |
|
26 |
-
def preprocess(self,
|
27 |
# preprocess the input
|
28 |
-
orig_im =
|
|
|
29 |
orig_im_size = orig_im.shape[0:2]
|
30 |
-
|
31 |
inputs = {
|
32 |
-
"
|
33 |
"orig_im_size":orig_im_size,
|
34 |
-
"
|
35 |
}
|
36 |
return inputs
|
37 |
|
38 |
def _forward(self,inputs):
|
39 |
-
result = self.model(inputs.pop("
|
40 |
inputs["result"] = result
|
41 |
return inputs
|
42 |
|
43 |
def postprocess(self,inputs,return_mask:bool=False ):
|
44 |
result = inputs.pop("result")
|
45 |
orig_im_size = inputs.pop("orig_im_size")
|
46 |
-
|
47 |
result_image = self.postprocess_image(result[0][0], orig_im_size)
|
48 |
pil_im = Image.fromarray(result_image)
|
49 |
if return_mask ==True :
|
50 |
return pil_im
|
51 |
no_bg_image = Image.new("RGBA", pil_im.size, (0,0,0,0))
|
52 |
-
|
53 |
-
no_bg_image.paste(
|
54 |
return no_bg_image
|
55 |
|
56 |
# utilities functions
|
@@ -58,7 +60,6 @@ class RMBGPipe(Pipeline):
|
|
58 |
# same as utilities.py with minor modification
|
59 |
if len(im.shape) < 3:
|
60 |
im = im[:, :, np.newaxis]
|
61 |
-
# orig_im_size=im.shape[0:2]
|
62 |
im_tensor = torch.tensor(im, dtype=torch.float32).permute(2,0,1)
|
63 |
im_tensor = F.interpolate(torch.unsqueeze(im_tensor,0), size=model_input_size, mode='bilinear')
|
64 |
image = torch.divide(im_tensor,255.0)
|
|
|
3 |
from torchvision.transforms.functional import normalize
|
4 |
import numpy as np
|
5 |
from transformers import Pipeline
|
6 |
+
from transformers.image_utils import load_image
|
7 |
from skimage import io
|
8 |
from PIL import Image
|
9 |
|
|
|
24 |
postprocess_kwargs["return_mask"] = kwargs["return_mask"]
|
25 |
return preprocess_kwargs, {}, postprocess_kwargs
|
26 |
|
27 |
+
def preprocess(self,input_image,model_input_size: list=[1024,1024]):
|
28 |
# preprocess the input
|
29 |
+
orig_im = load_image(input_image)
|
30 |
+
orig_im = np.array(orig_im)
|
31 |
orig_im_size = orig_im.shape[0:2]
|
32 |
+
preprocessed_image = self.preprocess_image(orig_im, model_input_size).to(self.device)
|
33 |
inputs = {
|
34 |
+
"preprocessed_image":preprocessed_image,
|
35 |
"orig_im_size":orig_im_size,
|
36 |
+
"input_image" : input_image
|
37 |
}
|
38 |
return inputs
|
39 |
|
40 |
def _forward(self,inputs):
|
41 |
+
result = self.model(inputs.pop("preprocessed_image"))
|
42 |
inputs["result"] = result
|
43 |
return inputs
|
44 |
|
45 |
def postprocess(self,inputs,return_mask:bool=False ):
|
46 |
result = inputs.pop("result")
|
47 |
orig_im_size = inputs.pop("orig_im_size")
|
48 |
+
input_image = inputs.pop("input_image")
|
49 |
result_image = self.postprocess_image(result[0][0], orig_im_size)
|
50 |
pil_im = Image.fromarray(result_image)
|
51 |
if return_mask ==True :
|
52 |
return pil_im
|
53 |
no_bg_image = Image.new("RGBA", pil_im.size, (0,0,0,0))
|
54 |
+
input_image = load_image(input_image)
|
55 |
+
no_bg_image.paste(input_image, mask=pil_im)
|
56 |
return no_bg_image
|
57 |
|
58 |
# utilities functions
|
|
|
60 |
# same as utilities.py with minor modification
|
61 |
if len(im.shape) < 3:
|
62 |
im = im[:, :, np.newaxis]
|
|
|
63 |
im_tensor = torch.tensor(im, dtype=torch.float32).permute(2,0,1)
|
64 |
im_tensor = F.interpolate(torch.unsqueeze(im_tensor,0), size=model_input_size, mode='bilinear')
|
65 |
image = torch.divide(im_tensor,255.0)
|