integrate with transformers
#21
by
not-lain
- opened
- MyConfig.py +13 -0
- MyPipe.py +73 -0
- README.md +13 -34
- briarmbg.py +8 -7
- config.json +24 -3
- requirements.txt +2 -1
MyConfig.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PretrainedConfig
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
class RMBGConfig(PretrainedConfig):
|
5 |
+
model_type = "SegformerForSemanticSegmentation"
|
6 |
+
def __init__(
|
7 |
+
self,
|
8 |
+
in_ch=3,
|
9 |
+
out_ch=1,
|
10 |
+
**kwargs):
|
11 |
+
self.in_ch = in_ch
|
12 |
+
self.out_ch = out_ch
|
13 |
+
super().__init__(**kwargs)
|
MyPipe.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch, os
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from torchvision.transforms.functional import normalize
|
4 |
+
import numpy as np
|
5 |
+
from transformers import Pipeline
|
6 |
+
from skimage import io
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
class RMBGPipe(Pipeline):
|
10 |
+
def __init__(self,**kwargs):
|
11 |
+
Pipeline.__init__(self,**kwargs)
|
12 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
13 |
+
self.model.to(self.device)
|
14 |
+
self.model.eval()
|
15 |
+
|
16 |
+
def _sanitize_parameters(self, **kwargs):
|
17 |
+
# parse parameters
|
18 |
+
preprocess_kwargs = {}
|
19 |
+
postprocess_kwargs = {}
|
20 |
+
if "model_input_size" in kwargs :
|
21 |
+
preprocess_kwargs["model_input_size"] = kwargs["model_input_size"]
|
22 |
+
if "return_mask" in kwargs:
|
23 |
+
postprocess_kwargs["return_mask"] = kwargs["return_mask"]
|
24 |
+
return preprocess_kwargs, {}, postprocess_kwargs
|
25 |
+
|
26 |
+
def preprocess(self,im_path:str,model_input_size: list=[1024,1024]):
|
27 |
+
# preprocess the input
|
28 |
+
orig_im = io.imread(im_path)
|
29 |
+
orig_im_size = orig_im.shape[0:2]
|
30 |
+
image = self.preprocess_image(orig_im, model_input_size).to(self.device)
|
31 |
+
inputs = {
|
32 |
+
"image":image,
|
33 |
+
"orig_im_size":orig_im_size,
|
34 |
+
"im_path" : im_path
|
35 |
+
}
|
36 |
+
return inputs
|
37 |
+
|
38 |
+
def _forward(self,inputs):
|
39 |
+
result = self.model(inputs.pop("image"))
|
40 |
+
inputs["result"] = result
|
41 |
+
return inputs
|
42 |
+
def postprocess(self,inputs,return_mask:bool=False ):
|
43 |
+
result = inputs.pop("result")
|
44 |
+
orig_im_size = inputs.pop("orig_im_size")
|
45 |
+
im_path = inputs.pop("im_path")
|
46 |
+
result_image = self.postprocess_image(result[0][0], orig_im_size)
|
47 |
+
pil_im = Image.fromarray(result_image)
|
48 |
+
if return_mask ==True :
|
49 |
+
return pil_im
|
50 |
+
no_bg_image = Image.new("RGBA", pil_im.size, (0,0,0,0))
|
51 |
+
orig_image = Image.open(im_path)
|
52 |
+
no_bg_image.paste(orig_image, mask=pil_im)
|
53 |
+
return no_bg_image
|
54 |
+
|
55 |
+
# utilities functions
|
56 |
+
def preprocess_image(self,im: np.ndarray, model_input_size: list=[1024,1024]) -> torch.Tensor:
|
57 |
+
# same as utilities.py with minor modification
|
58 |
+
if len(im.shape) < 3:
|
59 |
+
im = im[:, :, np.newaxis]
|
60 |
+
# orig_im_size=im.shape[0:2]
|
61 |
+
im_tensor = torch.tensor(im, dtype=torch.float32).permute(2,0,1)
|
62 |
+
im_tensor = F.interpolate(torch.unsqueeze(im_tensor,0), size=model_input_size, mode='bilinear').type(torch.uint8)
|
63 |
+
image = torch.divide(im_tensor,255.0)
|
64 |
+
image = normalize(image,[0.5,0.5,0.5],[1.0,1.0,1.0])
|
65 |
+
return image
|
66 |
+
def postprocess_image(self,result: torch.Tensor, im_size: list)-> np.ndarray:
|
67 |
+
result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear') ,0)
|
68 |
+
ma = torch.max(result)
|
69 |
+
mi = torch.min(result)
|
70 |
+
result = (result-mi)/(ma-mi)
|
71 |
+
im_array = (result*255).permute(1,2,0).cpu().data.numpy().astype(np.uint8)
|
72 |
+
im_array = np.squeeze(im_array)
|
73 |
+
return im_array
|
README.md
CHANGED
@@ -2,7 +2,7 @@
|
|
2 |
license: other
|
3 |
license_name: bria-rmbg-1.4
|
4 |
license_link: https://bria.ai/bria-huggingface-model-license-agreement/
|
5 |
-
pipeline_tag: image-
|
6 |
tags:
|
7 |
- remove background
|
8 |
- background
|
@@ -10,6 +10,7 @@ tags:
|
|
10 |
- Pytorch
|
11 |
- vision
|
12 |
- legal liability
|
|
|
13 |
|
14 |
extra_gated_prompt: This model weights by BRIA AI can be obtained after a commercial license is agreed upon. Fill in the form below and we reach out to you.
|
15 |
extra_gated_fields:
|
@@ -94,43 +95,21 @@ These modifications significantly improve the model’s accuracy and effectivene
|
|
94 |
|
95 |
## Installation
|
96 |
```bash
|
97 |
-
|
98 |
-
cd RMBG-1.4/
|
99 |
-
pip install -r requirements.txt
|
100 |
```
|
101 |
|
102 |
## Usage
|
103 |
|
|
|
104 |
```python
|
105 |
-
from
|
106 |
-
|
107 |
-
|
108 |
-
from briarmbg import BriaRMBG
|
109 |
-
from utilities import preprocess_image, postprocess_image
|
110 |
-
|
111 |
-
im_path = f"{os.path.dirname(os.path.abspath(__file__))}/example_input.jpg"
|
112 |
-
|
113 |
-
net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
|
114 |
-
|
115 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
116 |
-
net.to(device)
|
117 |
-
|
118 |
-
# prepare input
|
119 |
-
model_input_size = [1024,1024]
|
120 |
-
orig_im = io.imread(im_path)
|
121 |
-
orig_im_size = orig_im.shape[0:2]
|
122 |
-
image = preprocess_image(orig_im, model_input_size).to(device)
|
123 |
-
|
124 |
-
# inference
|
125 |
-
result=net(image)
|
126 |
-
|
127 |
-
# post process
|
128 |
-
result_image = postprocess_image(result[0][0], orig_im_size)
|
129 |
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
```
|
|
|
2 |
license: other
|
3 |
license_name: bria-rmbg-1.4
|
4 |
license_link: https://bria.ai/bria-huggingface-model-license-agreement/
|
5 |
+
pipeline_tag: image-segmentation
|
6 |
tags:
|
7 |
- remove background
|
8 |
- background
|
|
|
10 |
- Pytorch
|
11 |
- vision
|
12 |
- legal liability
|
13 |
+
- transformers
|
14 |
|
15 |
extra_gated_prompt: This model weights by BRIA AI can be obtained after a commercial license is agreed upon. Fill in the form below and we reach out to you.
|
16 |
extra_gated_fields:
|
|
|
95 |
|
96 |
## Installation
|
97 |
```bash
|
98 |
+
wget https://huggingface.co/briaai/RMBG-1.4/resolve/main/requirements.txt && pip install -qr requirements.txt
|
|
|
|
|
99 |
```
|
100 |
|
101 |
## Usage
|
102 |
|
103 |
+
either load the model
|
104 |
```python
|
105 |
+
from transformers import AutoModelForImageSegmentation
|
106 |
+
model = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-1.4",trust_remote_code=True)
|
107 |
+
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
|
109 |
+
or load the pipeline
|
110 |
+
```python
|
111 |
+
from transformers import pipeline
|
112 |
+
pipe = pipeline("image-segmentation", model="briaai/RMBG-1.4", trust_remote_code=True)
|
113 |
+
pillow_mask = pipe("img_path",return_mask = True) # outputs a pillow mask
|
114 |
+
pillow_image = pipe("image_path") # applies mask on input and returns a pillow image
|
115 |
```
|
briarmbg.py
CHANGED
@@ -1,7 +1,8 @@
|
|
1 |
import torch
|
2 |
import torch.nn as nn
|
3 |
import torch.nn.functional as F
|
4 |
-
from
|
|
|
5 |
|
6 |
class REBNCONV(nn.Module):
|
7 |
def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1):
|
@@ -345,12 +346,12 @@ class myrebnconv(nn.Module):
|
|
345 |
return self.rl(self.bn(self.conv(x)))
|
346 |
|
347 |
|
348 |
-
class BriaRMBG(
|
349 |
-
|
350 |
-
def __init__(self,config
|
351 |
-
super(
|
352 |
-
in_ch=config
|
353 |
-
out_ch=config
|
354 |
self.conv_in = nn.Conv2d(in_ch,64,3,stride=2,padding=1)
|
355 |
self.pool_in = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
356 |
|
|
|
1 |
import torch
|
2 |
import torch.nn as nn
|
3 |
import torch.nn.functional as F
|
4 |
+
from transformers import PreTrainedModel
|
5 |
+
from .MyConfig import RMBGConfig
|
6 |
|
7 |
class REBNCONV(nn.Module):
|
8 |
def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1):
|
|
|
346 |
return self.rl(self.bn(self.conv(x)))
|
347 |
|
348 |
|
349 |
+
class BriaRMBG(PreTrainedModel):
|
350 |
+
config_class = RMBGConfig
|
351 |
+
def __init__(self,config):
|
352 |
+
super().__init__(config)
|
353 |
+
in_ch = config.in_ch # 3
|
354 |
+
out_ch = config.out_ch # 1
|
355 |
self.conv_in = nn.Conv2d(in_ch,64,3,stride=2,padding=1)
|
356 |
self.pool_in = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
357 |
|
config.json
CHANGED
@@ -1,4 +1,25 @@
|
|
1 |
{
|
2 |
-
"
|
3 |
-
"
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
{
|
2 |
+
"_name_or_path": "briaai/RMBG-1.4",
|
3 |
+
"architectures": [
|
4 |
+
"BriaRMBG"
|
5 |
+
],
|
6 |
+
"auto_map": {
|
7 |
+
"AutoConfig": "MyConfig.RMBGConfig",
|
8 |
+
"AutoModelForImageSegmentation": "briarmbg.BriaRMBG"
|
9 |
+
},
|
10 |
+
"custom_pipelines": {
|
11 |
+
"image-segmentation": {
|
12 |
+
"impl": "MyPipe.RMBGPipe",
|
13 |
+
"pt": [
|
14 |
+
"AutoModelForImageSegmentation"
|
15 |
+
],
|
16 |
+
"tf": [],
|
17 |
+
"type": "image"
|
18 |
+
}
|
19 |
+
},
|
20 |
+
"in_ch": 3,
|
21 |
+
"model_type": "SegformerForSemanticSegmentation",
|
22 |
+
"out_ch": 1,
|
23 |
+
"torch_dtype": "float32",
|
24 |
+
"transformers_version": "4.38.0.dev0"
|
25 |
+
}
|
requirements.txt
CHANGED
@@ -4,4 +4,5 @@ pillow
|
|
4 |
numpy
|
5 |
typing
|
6 |
scikit-image
|
7 |
-
huggingface_hub
|
|
|
|
4 |
numpy
|
5 |
typing
|
6 |
scikit-image
|
7 |
+
huggingface_hub
|
8 |
+
transformers==4.39.1
|