not-lain commited on
Commit
a845831
1 Parent(s): 6a999bd

integrate with transformers

Browse files

this pr will fix the integration with the `transformers` library

following https://github.com/huggingface/transformers/pull/29262 there is no need to further use the method i used in [pr-9](https://huggingface.co/briaai/RMBG-1.4/discussions/9)

I have fixed the `requirements.txt` and the `README.md` files for future users beforehand so no need to change those.

# this is some exrtra bit of code to test the pr before mergin
```
wget https://huggingface.co/briaai/RMBG-1.4/resolve/main/requirements.txt && pip install -qr requirements.txt
pip install -q git+https://github.com/huggingface/transformers.git
```
how to use before

```python
from transformers import pipeline
pipe = pipeline("image-segmentation",
model="briaai/RMBG-1.4",
revision ="refs/pr/21", # only when using the pr
trust_remote_code=True)
pipe("image_path.webp",out_name="myout.png") # applies mask and saves the extracted image as `myout.png`
```

also friendly tag to

@OriLib



Sincerely,

Hafedh Hichri

Files changed (6) hide show
  1. MyConfig.py +14 -0
  2. MyPipe.py +76 -0
  3. README.md +22 -32
  4. briarmbg.py +9 -7
  5. config.json +24 -3
  6. requirements.txt +2 -1
MyConfig.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from transformers import PretrainedConfig
3
+ from typing import List
4
+
5
+ class RMBGConfig(PretrainedConfig):
6
+ model_type = "SegformerForSemanticSegmentation"
7
+ def __init__(
8
+ self,
9
+ in_ch=3,
10
+ out_ch=1,
11
+ **kwargs):
12
+ self.in_ch = in_ch
13
+ self.out_ch = out_ch
14
+ super().__init__(**kwargs)
MyPipe.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch, os
3
+ import torch.nn.functional as F
4
+ from torchvision.transforms.functional import normalize
5
+ import numpy as np
6
+ from transformers import Pipeline
7
+ from skimage import io
8
+ from PIL import Image
9
+
10
+ class RMBGPipe(Pipeline):
11
+ def __init__(self,**kwargs):
12
+ Pipeline.__init__(self,**kwargs)
13
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
+ self.model.to(self.device)
15
+ self.model.eval()
16
+
17
+ def _sanitize_parameters(self, **kwargs):
18
+ # parse parameters
19
+ preprocess_kwargs = {}
20
+ postprocess_kwargs = {}
21
+ if "model_input_size" in kwargs :
22
+ preprocess_kwargs["model_input_size"] = kwargs["model_input_size"]
23
+ if "out_name" in kwargs:
24
+ postprocess_kwargs["out_name"] = kwargs["out_name"]
25
+ return preprocess_kwargs, {}, postprocess_kwargs
26
+
27
+ def preprocess(self,im_path:str,model_input_size: list=[1024,1024]):
28
+ # preprocess the input
29
+ orig_im = io.imread(im_path)
30
+ orig_im_size = orig_im.shape[0:2]
31
+ image = self.preprocess_image(orig_im, model_input_size).to(self.device)
32
+ inputs = {
33
+ "image":image,
34
+ "orig_im_size":orig_im_size,
35
+ "im_path" : im_path
36
+ }
37
+ return inputs
38
+
39
+ def _forward(self,inputs):
40
+ result = self.model(inputs.pop("image"))
41
+ inputs["result"] = result
42
+ return inputs
43
+ def postprocess(self,inputs,out_name = ""):
44
+ result = inputs.pop("result")
45
+ orig_im_size = inputs.pop("orig_im_size")
46
+ im_path = inputs.pop("im_path")
47
+ result_image = self.postprocess_image(result[0][0], orig_im_size)
48
+ if out_name != "" :
49
+ # if out_name is specified we save the image using that name
50
+ pil_im = Image.fromarray(result_image)
51
+ no_bg_image = Image.new("RGBA", pil_im.size, (0,0,0,0))
52
+ orig_image = Image.open(im_path)
53
+ no_bg_image.paste(orig_image, mask=pil_im)
54
+ no_bg_image.save(out_name)
55
+ else :
56
+ return result_image
57
+
58
+ # utilities functions
59
+ def preprocess_image(self,im: np.ndarray, model_input_size: list=[1024,1024]) -> torch.Tensor:
60
+ # same as utilities.py with minor modification
61
+ if len(im.shape) < 3:
62
+ im = im[:, :, np.newaxis]
63
+ # orig_im_size=im.shape[0:2]
64
+ im_tensor = torch.tensor(im, dtype=torch.float32).permute(2,0,1)
65
+ im_tensor = F.interpolate(torch.unsqueeze(im_tensor,0), size=model_input_size, mode='bilinear').type(torch.uint8)
66
+ image = torch.divide(im_tensor,255.0)
67
+ image = normalize(image,[0.5,0.5,0.5],[1.0,1.0,1.0])
68
+ return image
69
+ def postprocess_image(self,result: torch.Tensor, im_size: list)-> np.ndarray:
70
+ result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear') ,0)
71
+ ma = torch.max(result)
72
+ mi = torch.min(result)
73
+ result = (result-mi)/(ma-mi)
74
+ im_array = (result*255).permute(1,2,0).cpu().data.numpy().astype(np.uint8)
75
+ im_array = np.squeeze(im_array)
76
+ return im_array
README.md CHANGED
@@ -2,7 +2,7 @@
2
  license: other
3
  license_name: bria-rmbg-1.4
4
  license_link: https://bria.ai/bria-huggingface-model-license-agreement/
5
- pipeline_tag: image-to-image
6
  tags:
7
  - remove background
8
  - background
@@ -10,6 +10,7 @@ tags:
10
  - Pytorch
11
  - vision
12
  - legal liability
 
13
 
14
  extra_gated_prompt: This model weights by BRIA AI can be obtained after a commercial license is agreed upon. Fill in the form below and we reach out to you.
15
  extra_gated_fields:
@@ -94,43 +95,32 @@ These modifications significantly improve the model’s accuracy and effectivene
94
 
95
  ## Installation
96
  ```bash
97
- git clone https://huggingface.co/briaai/RMBG-1.4
98
- cd RMBG-1.4/
99
- pip install -r requirements.txt
100
  ```
101
 
102
  ## Usage
103
 
104
  ```python
105
- from skimage import io
106
- import torch, os
107
- from PIL import Image
108
- from briarmbg import BriaRMBG
109
- from utilities import preprocess_image, postprocess_image
110
 
111
- im_path = f"{os.path.dirname(os.path.abspath(__file__))}/example_input.jpg"
112
-
113
- net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
114
-
115
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
116
- net.to(device)
117
-
118
- # prepare input
119
- model_input_size = [1024,1024]
120
- orig_im = io.imread(im_path)
121
- orig_im_size = orig_im.shape[0:2]
122
- image = preprocess_image(orig_im, model_input_size).to(device)
123
-
124
- # inference
125
- result=net(image)
126
 
127
- # post process
128
- result_image = postprocess_image(result[0][0], orig_im_size)
 
 
 
 
 
129
 
130
- # save result
131
- pil_im = Image.fromarray(result_image)
132
- no_bg_image = Image.new("RGBA", pil_im.size, (0,0,0,0))
133
- orig_image = Image.open(im_path)
134
- no_bg_image.paste(orig_image, mask=pil_im)
135
- no_bg_image.save("example_image_no_bg.png")
136
  ```
 
2
  license: other
3
  license_name: bria-rmbg-1.4
4
  license_link: https://bria.ai/bria-huggingface-model-license-agreement/
5
+ pipeline_tag: image-segmentation
6
  tags:
7
  - remove background
8
  - background
 
10
  - Pytorch
11
  - vision
12
  - legal liability
13
+ - transformers
14
 
15
  extra_gated_prompt: This model weights by BRIA AI can be obtained after a commercial license is agreed upon. Fill in the form below and we reach out to you.
16
  extra_gated_fields:
 
95
 
96
  ## Installation
97
  ```bash
98
+ wget https://huggingface.co/briaai/RMBG-1.4/resolve/main/requirements.txt && pip install -qr requirements.txt
 
 
99
  ```
100
 
101
  ## Usage
102
 
103
  ```python
104
+ # How to use
 
 
 
 
105
 
106
+ either load the model
107
+ ```python
108
+ from transformers import AutoModelForImageSegmentation
109
+ model = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-1.4",trust_remote_code=True)
110
+ ```
 
 
 
 
 
 
 
 
 
 
111
 
112
+ or load the pipeline
113
+ ```python
114
+ from transformers import pipeline
115
+ pipe = pipeline("image-segmentation", model="briaai/RMBG-1.4", trust_remote_code=True)
116
+ numpy_mask = pipe("img_path") # outputs numpy mask
117
+ pipe("image_path",out_name="myout.png") # applies mask and saves the extracted image as `myout.png`
118
+ ```
119
 
120
+ # parameters :
121
+ for the pipeline you can use the following parameters :
122
+ * `model_input_size` : default to [1024,1024]
123
+ * `out_name` : if specified it will use the numpy mask to extract the image and save it using the `out_name`
124
+ * `preprocess_image` : original method created by briaai
125
+ * `postprocess_image` : original method created by briaai
126
  ```
briarmbg.py CHANGED
@@ -1,7 +1,9 @@
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
- from huggingface_hub import PyTorchModelHubMixin
 
5
 
6
  class REBNCONV(nn.Module):
7
  def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1):
@@ -345,12 +347,12 @@ class myrebnconv(nn.Module):
345
  return self.rl(self.bn(self.conv(x)))
346
 
347
 
348
- class BriaRMBG(nn.Module, PyTorchModelHubMixin):
349
-
350
- def __init__(self,config:dict={"in_ch":3,"out_ch":1}):
351
- super(BriaRMBG,self).__init__()
352
- in_ch=config["in_ch"]
353
- out_ch=config["out_ch"]
354
  self.conv_in = nn.Conv2d(in_ch,64,3,stride=2,padding=1)
355
  self.pool_in = nn.MaxPool2d(2,stride=2,ceil_mode=True)
356
 
 
1
+
2
  import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
+ from transformers import PreTrainedModel
6
+ from .MyConfig import RMBGConfig
7
 
8
  class REBNCONV(nn.Module):
9
  def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1):
 
347
  return self.rl(self.bn(self.conv(x)))
348
 
349
 
350
+ class BriaRMBG(PreTrainedModel):
351
+ config_class = RMBGConfig
352
+ def __init__(self,config):
353
+ super().__init__(config)
354
+ in_ch = config.in_ch # 3
355
+ out_ch = config.out_ch # 1
356
  self.conv_in = nn.Conv2d(in_ch,64,3,stride=2,padding=1)
357
  self.pool_in = nn.MaxPool2d(2,stride=2,ceil_mode=True)
358
 
config.json CHANGED
@@ -1,4 +1,25 @@
1
  {
2
- "in_ch":3,
3
- "out_ch":1
4
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  {
2
+ "_name_or_path": "briaai/RMBG-1.4",
3
+ "architectures": [
4
+ "BriaRMBG"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "MyConfig.RMBGConfig",
8
+ "AutoModelForImageSegmentation": "briarmbg.BriaRMBG"
9
+ },
10
+ "custom_pipelines": {
11
+ "image-segmentation": {
12
+ "impl": "MyPipe.RMBGPipe",
13
+ "pt": [
14
+ "AutoModelForImageSegmentation"
15
+ ],
16
+ "tf": [],
17
+ "type": "image"
18
+ }
19
+ },
20
+ "in_ch": 3,
21
+ "model_type": "SegformerForSemanticSegmentation",
22
+ "out_ch": 1,
23
+ "torch_dtype": "float32",
24
+ "transformers_version": "4.38.0.dev0"
25
+ }
requirements.txt CHANGED
@@ -4,4 +4,5 @@ pillow
4
  numpy
5
  typing
6
  scikit-image
7
- huggingface_hub
 
 
4
  numpy
5
  typing
6
  scikit-image
7
+ huggingface_hub
8
+ git+https://github.com/huggingface/transformers.git