|
import os,sys |
|
import folder_paths |
|
|
|
from PIL import Image |
|
import importlib.util |
|
import numpy as np |
|
|
|
import torch |
|
|
|
# Module-level flag: flipped to True further down once ``SimpleLama`` has been
# imported successfully (either directly or after a pip install).
_available = False
|
|
|
def is_installed(package):
    """Return True when ``package`` can be located by the import system.

    Args:
        package: Absolute module name, e.g. ``"simple_lama_inpainting"``.

    Returns:
        bool: True if a module spec is found (or the module is already
        loaded), False otherwise.
    """
    try:
        spec = importlib.util.find_spec(package)
    except ModuleNotFoundError:
        # Raised for dotted names whose parent package cannot be imported.
        return False
    except ValueError:
        # find_spec raises ValueError when the module is already present in
        # sys.modules but carries no usable __spec__. Such a module is
        # loaded, so report it as installed instead of crashing.
        return sys.modules.get(package) is not None
    return spec is not None
|
|
|
# Import SimpleLama, installing the ``simple_lama_inpainting`` package with pip
# first when it is missing. ``_available`` records whether the import worked.
if not is_installed('simple_lama_inpainting'):
    import subprocess
    from packaging import version

    # The automatic install is only attempted on torch >= 2.1.
    if version.parse(torch.__version__) >= version.parse('2.1'):
        print('#pip install simple_lama_inpainting')

        # Run pip through the current interpreter so the package lands in the
        # same environment this module is running in.
        result = subprocess.run(
            [sys.executable, '-s', '-m', 'pip', 'install', 'simple_lama_inpainting'],
            capture_output=True,
            text=True,
        )

        if result.returncode == 0:
            print("#install success")
            # The package was created after the interpreter started; clear
            # the import system's cached directory listings so the import
            # below can see it.
            importlib.invalidate_caches()
            from simple_lama_inpainting import SimpleLama
            _available = True
        else:
            print("#install error")
            # pip's output was captured above; show the error text so the
            # failure can be diagnosed.
            if result.stderr:
                print(result.stderr)
    else:
        print('#pls check your torch version >= 2.1')
else:
    from simple_lama_inpainting import SimpleLama
    _available = True
|
|
|
|
|
def get_lama_path():
    """Return the directory in which the LaMa model file is looked up.

    The first entry of ``folder_paths.get_folder_paths('lama')`` is used when
    that lookup succeeds; otherwise the path falls back to a ``lama``
    directory under ``folder_paths.models_dir``.

    Returns:
        str: Directory path for the LaMa model.
    """
    try:
        return folder_paths.get_folder_paths('lama')[0]
    except Exception:
        # Best-effort fallback for any lookup failure (e.g. no registered
        # 'lama' folder or an empty list). ``Exception`` rather than a bare
        # ``except`` so KeyboardInterrupt/SystemExit are not swallowed.
        return os.path.join(folder_paths.models_dir, "lama")
|
|
|
# Expected location of the LaMa TorchScript checkpoint.
llma_model_path = os.path.join(get_lama_path(), "big-lama.pt")

# Publish the checkpoint location through the LAMA_MODEL environment variable;
# an empty value is stored when the file is missing.
# NOTE(review): presumably simple_lama_inpainting reads LAMA_MODEL to locate
# the checkpoint -- confirm against that package.
if os.path.exists(llma_model_path):
    os.environ['LAMA_MODEL'] = llma_model_path
else:
    os.environ['LAMA_MODEL'] = ''
    print(f"## lama torchscript model not found: {llma_model_path},pls download from https://github.com/enesmsahin/simple-lama-inpainting/releases/download/v0.1.0/big-lama.pt")
|
|
|
|
|
def tensor2pil(image):
    """Convert a tensor into a PIL image.

    The tensor is moved to the CPU, singleton dimensions are squeezed away,
    values are multiplied by 255, clipped to 0..255 and cast to ``uint8``.
    """
    squeezed = image.cpu().numpy().squeeze()
    pixels = np.clip(255. * squeezed, 0, 255).astype(np.uint8)
    return Image.fromarray(pixels)
|
|
|
|
|
def pil2tensor(image):
    """Convert an image into a float32 tensor with a leading axis of size 1.

    Pixel values are divided by 255.0 after conversion to ``float32``.
    """
    scaled = np.array(image).astype(np.float32) / 255.0
    tensor = torch.from_numpy(scaled)
    return tensor.unsqueeze(0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LaMaInpainting:
    """Node that inpaints images with ``SimpleLama``, one mask per image.

    ``run`` receives lists of images and masks, runs the model on each
    image/mask pair and returns the list of results.
    """

    # Snapshot, taken at class-definition time, of whether SimpleLama could
    # be imported.
    available = _available

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs: a required image and a required mask."""
        return {
            "required": {
                "image": ("IMAGE",),
                "mask": ("MASK",),
            },
        }

    RETURN_TYPES = ("IMAGE",)

    FUNCTION = "run"

    CATEGORY = "♾️Mixlab/Image"

    INPUT_IS_LIST = True
    OUTPUT_IS_LIST = (True,)

    # The model instance is kept in a module-level global (not a class
    # attribute) so it is created once and reused across calls to ``run``.
    global simple_lama
    simple_lama = None

    def run(self, image, mask):
        """Inpaint every image using the mask at the same index.

        Args:
            image: Sequence of image tensors.
            mask: Sequence of mask tensors; ``mask[i]`` is applied to
                ``image[i]``.

        Returns:
            tuple: One-element tuple holding the list of result tensors.

        Raises:
            RuntimeError: If ``simple_lama_inpainting`` could not be imported.
            IndexError: If ``mask`` has fewer entries than ``image``.
        """
        global simple_lama

        if not _available:
            # Without this check the first call would fail with an opaque
            # NameError on ``SimpleLama``.
            raise RuntimeError(
                "simple_lama_inpainting is not available; "
                "install it with: pip install simple_lama_inpainting"
            )

        if simple_lama is None:
            simple_lama = SimpleLama()
        else:
            # A previous run may have moved the model to the CPU; bring it
            # back to the GPU when one is available.
            simple_lama.model.to("cuda" if torch.cuda.is_available() else "cpu")

        result = []
        try:
            for index, img in enumerate(image):
                picture = tensor2pil(img)
                # The mask is converted to single-channel ('L') mode.
                mask_l = tensor2pil(mask[index]).convert('L')
                result.append(pil2tensor(simple_lama(picture, mask_l)))
        finally:
            # Move the model off the GPU between runs, even when inference
            # failed. ``device`` may be a torch.device, which does not
            # reliably compare equal to the plain string 'cuda', so compare
            # the device type instead.
            if torch.device(simple_lama.device).type == 'cuda':
                simple_lama.model.to('cpu')

        return (result,)