"""MODNet webcam portrait-matting session.

The ONNX model is not downloaded directly: the upstream MODNet export script,
its helper modules and the pretrained checkpoint are fetched, and the export
script is run once to produce the ``.onnx`` file.
"""

import os
import shutil
import subprocess
import sys
from typing import List

import numpy as np
import pooch
from PIL import Image
from PIL.Image import Image as PILImage

from .CustomSession import CustomBaseSession

# Name of the scratch directory (inside the model home) that holds the
# downloaded export sources while the ONNX file is being generated.
_WORK_DIR_NAME = "modnet-w"

_CHECKPOINT_NAME = "modnet_webcam_portrait_matting.ckpt"

# (url, known_hash, file name, sub-directory relative to the scratch dir).
# The sub-directories mirror the upstream repository layout.
_EXPORT_SOURCES = (
    (
        "https://raw.githubusercontent.com/ZHKKKe/MODNet/master/onnx/export_onnx.py",
        "SHA256:647990c98c409fbf6a72cd2a2db5fe19d2e4b15a3a436ef0302be0582458b63e",
        "export_onnx.py",
        "",
    ),
    (
        "https://raw.githubusercontent.com/ZHKKKe/MODNet/master/onnx/modnet_onnx.py",
        "SHA256:0502cad1b7ab0bf2f866179960454c1e63df096390db05e93cb40145dbc26e1f",
        "modnet_onnx.py",
        "",
    ),
    (
        "https://raw.githubusercontent.com/ZHKKKe/MODNet/master/src/models/backbones/__init__.py",
        "SHA256:28a5fb95f7dcf9e365edbf42c6d2e8ea0ca4839e51fd7f11bd0547d2359fcd96",
        "__init__.py",
        "src/models/backbones",
    ),
    (
        "https://raw.githubusercontent.com/ZHKKKe/MODNet/master/src/models/backbones/mobilenetv2.py",
        "SHA256:e3cc8ad6a9933ba18a17a62d5f887c64e0721240871ea8b48742fb9a8a2c3199",
        "mobilenetv2.py",
        "src/models/backbones",
    ),
    (
        "https://raw.githubusercontent.com/ZHKKKe/MODNet/master/src/models/backbones/wrapper.py",
        "SHA256:41197be7eb96b8a60dc034b55d8c9340dd682a41441dcf2ce67238955dfa5607",
        "wrapper.py",
        "src/models/backbones",
    ),
    (
        "https://storage.openvinotoolkit.org/repositories/open_model_zoo/public/2022.2/modnet-webcam-portrait-matting/modnet_webcam_portrait_matting.ckpt",
        "SHA256:913b82b66558db39b6286c150f809017d7528c872b156eb14333c9c6cb52108b",
        _CHECKPOINT_NAME,
        "",
    ),
)


class ModnetWebcamSession(CustomBaseSession):
    """Session that produces a foreground mask with the MODNet webcam model."""

    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
        """Run the model on *img* and return a single-element list with the mask.

        The mask is an 8-bit greyscale (mode ``"L"``) image resized to
        ``img.size``.

        Args:
            img: Input image.
            *args: Unused.
            **kwargs: Unused.

        Returns:
            ``[mask]``.
        """
        # `normalize` and `inner_session` come from the base class; presumably
        # they build the ONNX Runtime input feed and hold the inference
        # session respectively -- confirm against CustomBaseSession.
        ort_outs = self.inner_session.run(
            None,
            self.normalize(img, (0.5, 0.5, 0.5), (1.0, 1.0, 1.0), (512, 512)),
        )

        pred = ort_outs[0][:, 0, :, :]

        # Stretch the prediction to the [0, 1] range.
        ma = np.max(pred)
        mi = np.min(pred)
        span = ma - mi
        if span > 0:
            pred = (pred - mi) / span
        else:
            # A flat (or NaN) output would make the division 0/0; casting the
            # resulting NaNs to uint8 is undefined, so emit an empty mask.
            pred = np.zeros_like(pred)
        pred = np.squeeze(pred)

        mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
        mask = mask.resize(img.size, Image.LANCZOS)

        return [mask]

    @classmethod
    def download_models(cls, *args, **kwargs) -> str:
        """Make sure the ONNX model file exists and return its path.

        If ``<u2net_home>/<name>.onnx`` is missing, the MODNet export script,
        its helper modules and the pretrained checkpoint are downloaded into a
        scratch directory, the export script is run with the current Python
        interpreter, and the scratch directory is removed again.

        Args:
            *args: Unused.
            **kwargs: Unused.

        Returns:
            Path of the ONNX model file.

        Raises:
            subprocess.CalledProcessError: If the export script exits with a
                non-zero status.
        """
        home = cls.u2net_home()
        fname = f"{cls.name()}.onnx"
        model_path = os.path.join(home, fname)

        if os.path.exists(model_path):
            return model_path

        work_dir = os.path.join(home, _WORK_DIR_NAME)

        try:
            for url, known_hash, target_name, subdir in _EXPORT_SOURCES:
                pooch.retrieve(
                    url,
                    known_hash,
                    fname=target_name,
                    path=os.path.join(work_dir, subdir),
                    progressbar=True,
                )

            export_script = os.path.join(work_dir, "export_onnx.py")

            # The upstream script uses a package-relative import, which fails
            # when the file is executed directly as a script.
            replace_line(
                export_script,
                "from . import modnet_onnx",
                "import modnet_onnx",
            )

            # check=True: without it a failed export would go unnoticed and
            # the returned path would point at a file that was never written.
            subprocess.run(
                [
                    sys.executable,
                    export_script,
                    "--ckpt-path=" + os.path.join(work_dir, _CHECKPOINT_NAME),
                    "--output-path=" + model_path,
                ],
                check=True,
            )
        finally:
            # Always drop the scratch directory, even when a step above failed.
            shutil.rmtree(work_dir, ignore_errors=True)

        return model_path

    @classmethod
    def name(cls, *args, **kwargs) -> str:
        """Return the model identifier, also used as the ONNX file stem."""
        return "modnet-w"


def replace_line(path: str, old: str, new: str):
    """Replace every occurrence of *old* with *new* in the text file at *path*.

    The file is read and written as UTF-8 and is rewritten even when nothing
    matched.

    Args:
        path: File to edit in place.
        old: Substring to search for.
        new: Replacement text.
    """
    with open(path, "r", encoding="utf-8") as file:
        data = file.readlines()

    # str.replace is a no-op on lines that do not contain `old`.
    data = [line.replace(old, new) for line in data]

    with open(path, "w", encoding="utf-8") as file:
        file.writelines(data)