File size: 4,503 Bytes
1e3b872
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import shutil
import subprocess
import sys
from typing import List

import numpy as np
import pooch
from PIL import Image
from PIL.Image import Image as PILImage

from .CustomSession import CustomBaseSession


class ModnetPhotographicSession(CustomBaseSession):
    """Session wrapper for the MODNet photographic portrait matting model.

    The ONNX model is not downloaded directly: ``download_models`` fetches the
    MODNet export scripts plus a checkpoint and runs the export locally.
    """

    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
        """Run the model on *img* and return a single-element list with the mask.

        Args:
            img: Input image. Its ``size`` is used as the size of the returned
                mask.
            *args: Unused.
            **kwargs: Unused.

        Returns:
            A list containing one ``"L"``-mode image, min-max scaled to
            0..255 and resized to ``img.size`` with Lanczos resampling.
        """
        # ``normalize`` is inherited and not visible here; presumably it
        # resizes to 512x512 and applies the given mean/std -- TODO confirm.
        ort_outs = self.inner_session.run(
            None,
            self.normalize(img, (0.5, 0.5, 0.5), (1.0, 1.0, 1.0), (512, 512)),
        )

        # Keep channel 0 of the first output.
        pred = ort_outs[0][:, 0, :, :]

        ma = np.max(pred)
        mi = np.min(pred)
        span = ma - mi

        if span > 0:
            pred = (pred - mi) / span
        else:
            # A constant prediction would otherwise divide 0 by 0 and yield
            # NaNs, whose uint8 conversion is undefined; emit an empty mask.
            pred = np.zeros_like(pred)
        pred = np.squeeze(pred)

        mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
        mask = mask.resize(img.size, Image.LANCZOS)

        return [mask]

    @classmethod
    def download_models(cls, *args, **kwargs):
        """Ensure the ONNX model exists locally and return its path.

        When ``<u2net_home>/<name>.onnx`` is missing, the MODNet export
        scripts, the backbone sources they import, and the checkpoint are
        downloaded (hash-verified by ``pooch``) into a temporary ``modnet-p/``
        directory, the export script is run with the current interpreter, and
        the temporary directory is removed again.

        Args:
            *args: Unused.
            **kwargs: Unused.

        Returns:
            Path of the ONNX model file.

        Raises:
            subprocess.CalledProcessError: If the export script exits with a
                non-zero status.
        """
        fname = f"{cls.name()}.onnx"
        home = cls.u2net_home()
        model_path = os.path.join(home, fname)

        if os.path.exists(model_path):
            return model_path

        work_dir = os.path.join(home, "modnet-p/")
        backbones_dir = os.path.join(home, "modnet-p/src/models/backbones")
        export_script = os.path.join(home, "modnet-p/export_onnx.py")
        checkpoint = os.path.join(
            home, "modnet-p/modnet_photographic_portrait_matting.ckpt"
        )

        # (url, known hash, local file name, target directory)
        downloads = (
            (
                "https://raw.githubusercontent.com/ZHKKKe/MODNet/master/onnx/export_onnx.py",
                "SHA256:647990c98c409fbf6a72cd2a2db5fe19d2e4b15a3a436ef0302be0582458b63e",
                "export_onnx.py",
                work_dir,
            ),
            (
                "https://raw.githubusercontent.com/ZHKKKe/MODNet/master/onnx/modnet_onnx.py",
                "SHA256:0502cad1b7ab0bf2f866179960454c1e63df096390db05e93cb40145dbc26e1f",
                "modnet_onnx.py",
                work_dir,
            ),
            (
                "https://raw.githubusercontent.com/ZHKKKe/MODNet/master/src/models/backbones/__init__.py",
                "SHA256:28a5fb95f7dcf9e365edbf42c6d2e8ea0ca4839e51fd7f11bd0547d2359fcd96",
                "__init__.py",
                backbones_dir,
            ),
            (
                "https://raw.githubusercontent.com/ZHKKKe/MODNet/master/src/models/backbones/mobilenetv2.py",
                "SHA256:e3cc8ad6a9933ba18a17a62d5f887c64e0721240871ea8b48742fb9a8a2c3199",
                "mobilenetv2.py",
                backbones_dir,
            ),
            (
                "https://raw.githubusercontent.com/ZHKKKe/MODNet/master/src/models/backbones/wrapper.py",
                "SHA256:41197be7eb96b8a60dc034b55d8c9340dd682a41441dcf2ce67238955dfa5607",
                "wrapper.py",
                backbones_dir,
            ),
            (
                "https://storage.openvinotoolkit.org/repositories/open_model_zoo/public/2022.2/modnet-photographic-portrait-matting/modnet_photographic_portrait_matting.ckpt",
                "SHA256:7c22235f0925deba15d4d63e53afcb654c47055bbcd98f56e393ab2584007ed8",
                "modnet_photographic_portrait_matting.ckpt",
                work_dir,
            ),
        )

        try:
            for url, known_hash, target_name, target_dir in downloads:
                pooch.retrieve(
                    url,
                    known_hash,
                    fname=target_name,
                    path=target_dir,
                    progressbar=True,
                )

            # The script is executed as a plain file below, so its
            # package-relative import is swapped for an absolute one
            # (presumably because relative imports fail outside a package).
            replace_line(
                export_script,
                "from . import modnet_onnx",
                "import modnet_onnx",
            )

            # check=True: a failed export must not be mistaken for success,
            # otherwise a path to a missing model would be returned.
            subprocess.run(
                [
                    sys.executable,
                    export_script,
                    "--ckpt-path=" + checkpoint,
                    "--output-path=" + model_path,
                ],
                check=True,
            )
        finally:
            # Remove the scratch directory on failure as well as on success.
            if os.path.isdir(work_dir):
                shutil.rmtree(work_dir)

        return model_path

    @classmethod
    def name(cls, *args, **kwargs):
        """Return the model identifier, which is also the ONNX file stem."""
        return "modnet-p"


def replace_line(path: str, old: str, new: str):
    """Rewrite the text file at *path*, substituting *new* for *old*.

    The file is read as UTF-8 line by line, every occurrence of *old* within
    each line is replaced by *new*, and the result is written back to the
    same path. The file is rewritten even when nothing matched.

    Args:
        path: File to edit in place.
        old: Substring to search for within each line.
        new: Replacement text.
    """
    with open(path, "r", encoding="utf-8") as source:
        original_lines = list(source)

    # str.replace returns the line unchanged when ``old`` is absent, so no
    # separate membership test is needed.
    rewritten = [line.replace(old, new) for line in original_lines]

    with open(path, "w", encoding="utf-8") as sink:
        sink.write("".join(rewritten))