# File: runpod_custom_nodes/ComfyUI-Allor/session/ModnetPhotographicSession.py
# Committed by 3v324v23 (commit 1e3b872, stored via Git LFS)
import os
import shutil
import subprocess
import sys
from typing import List
import numpy as np
import pooch
from PIL import Image
from PIL.Image import Image as PILImage
from .CustomSession import CustomBaseSession
class ModnetPhotographicSession(CustomBaseSession):
    """Background-removal session backed by MODNet photographic portrait matting.

    No ready-made ONNX file is downloaded. Instead, :meth:`download_models`
    fetches MODNet's export script, the model sources it imports, and the
    official checkpoint, all pinned by SHA-256. It then runs the export
    script to produce ``<u2net_home>/<name>.onnx``.
    """

    # Temporary directory (under u2net_home) that holds the export sources.
    _WORK_DIR_NAME = "modnet-p"

    # Checkpoint file consumed by the export script.
    _CHECKPOINT_NAME = "modnet_photographic_portrait_matting.ckpt"

    # Hash-pinned files needed for the export, as
    # (url, known_hash, file name, sub-directory parts relative to the work dir).
    _EXPORT_SOURCES = (
        (
            "https://raw.githubusercontent.com/ZHKKKe/MODNet/master/onnx/export_onnx.py",
            "SHA256:647990c98c409fbf6a72cd2a2db5fe19d2e4b15a3a436ef0302be0582458b63e",
            "export_onnx.py",
            (),
        ),
        (
            "https://raw.githubusercontent.com/ZHKKKe/MODNet/master/onnx/modnet_onnx.py",
            "SHA256:0502cad1b7ab0bf2f866179960454c1e63df096390db05e93cb40145dbc26e1f",
            "modnet_onnx.py",
            (),
        ),
        (
            "https://raw.githubusercontent.com/ZHKKKe/MODNet/master/src/models/backbones/__init__.py",
            "SHA256:28a5fb95f7dcf9e365edbf42c6d2e8ea0ca4839e51fd7f11bd0547d2359fcd96",
            "__init__.py",
            ("src", "models", "backbones"),
        ),
        (
            "https://raw.githubusercontent.com/ZHKKKe/MODNet/master/src/models/backbones/mobilenetv2.py",
            "SHA256:e3cc8ad6a9933ba18a17a62d5f887c64e0721240871ea8b48742fb9a8a2c3199",
            "mobilenetv2.py",
            ("src", "models", "backbones"),
        ),
        (
            "https://raw.githubusercontent.com/ZHKKKe/MODNet/master/src/models/backbones/wrapper.py",
            "SHA256:41197be7eb96b8a60dc034b55d8c9340dd682a41441dcf2ce67238955dfa5607",
            "wrapper.py",
            ("src", "models", "backbones"),
        ),
        (
            "https://storage.openvinotoolkit.org/repositories/open_model_zoo/public/2022.2/modnet-photographic-portrait-matting/modnet_photographic_portrait_matting.ckpt",
            "SHA256:7c22235f0925deba15d4d63e53afcb654c47055bbcd98f56e393ab2584007ed8",
            _CHECKPOINT_NAME,
            (),
        ),
    )

    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
        """Predict a portrait alpha matte for *img*.

        The image is normalized to 512x512 and run through the ONNX session.
        The first output channel is min-max rescaled to [0, 1] and turned into
        an 8-bit ``"L"`` mask, which is resized back to ``img.size``.

        Args:
            img: Input image.

        Returns:
            A single-element list containing the mask image.
        """
        # NOTE(review): MODNet's reference inference scales inputs to [-1, 1]
        # (std 0.5), but std here is 1.0. Confirm against
        # CustomBaseSession.normalize before changing it.
        ort_outs = self.inner_session.run(
            None,
            self.normalize(img, (0.5, 0.5, 0.5), (1.0, 1.0, 1.0), (512, 512)),
        )
        # Assumes the output is (batch, channel, H, W); channel 0 is kept.
        # TODO: confirm the output layout.
        pred = ort_outs[0][:, 0, :, :]
        ma = np.max(pred)
        mi = np.min(pred)
        span = ma - mi
        if span > 0:
            pred = (pred - mi) / span
        else:
            # A constant prediction cannot be min-max rescaled (0 / 0 -> NaN,
            # and NaN -> uint8 is undefined). Keep the raw value clipped to
            # [0, 1] instead; the matte is presumably already in that range.
            pred = np.clip(pred, 0.0, 1.0)
        pred = np.squeeze(pred)
        mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
        mask = mask.resize(img.size, Image.LANCZOS)
        return [mask]

    @classmethod
    def download_models(cls, *args, **kwargs):
        """Ensure ``<u2net_home>/<name>.onnx`` exists, exporting it if missing.

        Returns:
            Path to the ONNX model file.

        Raises:
            subprocess.CalledProcessError: If the export script exits non-zero.
            RuntimeError: If the export script finished but produced no model.
        """
        fname = f"{cls.name()}.onnx"
        model_path = os.path.join(cls.u2net_home(), fname)
        if not os.path.exists(model_path):
            work_dir = os.path.join(cls.u2net_home(), cls._WORK_DIR_NAME)
            cls._fetch_export_sources(work_dir)
            try:
                cls._export_onnx(work_dir, model_path)
            finally:
                # Always remove the temporary sources, as before. ignore_errors
                # keeps a cleanup failure from masking the export error.
                shutil.rmtree(work_dir, ignore_errors=True)
            if not os.path.exists(model_path):
                raise RuntimeError(
                    f"MODNet ONNX export did not produce {model_path}"
                )
        return model_path

    @classmethod
    def _fetch_export_sources(cls, work_dir: str) -> None:
        """Download the hash-pinned export script, model sources and checkpoint."""
        for url, known_hash, file_name, subdir in cls._EXPORT_SOURCES:
            pooch.retrieve(
                url,
                known_hash,
                fname=file_name,
                path=os.path.join(work_dir, *subdir),
                progressbar=True,
            )

    @classmethod
    def _export_onnx(cls, work_dir: str, output_path: str) -> None:
        """Run MODNet's export script from *work_dir*, writing *output_path*."""
        script = os.path.join(work_dir, "export_onnx.py")
        # The script uses a package-relative import, but it is run here as a
        # plain script. Rewrite that import to an absolute one, which resolves
        # from the script's own directory (sys.path[0]).
        replace_line(script, "from . import modnet_onnx", "import modnet_onnx")
        subprocess.run(
            [
                sys.executable,
                script,
                "--ckpt-path=" + os.path.join(work_dir, cls._CHECKPOINT_NAME),
                "--output-path=" + output_path,
            ],
            check=True,  # fail loudly instead of returning a missing model path
        )

    @classmethod
    def name(cls, *args, **kwargs):
        """Return the session/model name; also the ONNX file stem."""
        return "modnet-p"
def replace_line(path: str, old: str, new: str) -> int:
    """Replace every occurrence of *old* with *new* in the file at *path*.

    The file is read and written as UTF-8. ``newline=""`` disables newline
    translation on both read and write, so existing line endings (e.g. CRLF)
    are preserved. Matching is done line by line, so *old* is never matched
    across a line break.

    Args:
        path: File to edit in place.
        old: Substring to search for.
        new: Replacement text.

    Returns:
        The number of lines that contained *old*. When this is 0 the file is
        not rewritten.
    """
    with open(path, "r", encoding="utf-8", newline="") as file:
        lines = file.readlines()

    changed = 0
    for i, line in enumerate(lines):
        if old in line:
            lines[i] = line.replace(old, new)
            changed += 1

    if changed:
        with open(path, "w", encoding="utf-8", newline="") as file:
            file.writelines(lines)
    return changed