import os from annotator.annotator_path import models_path from modules import devices from annotator.uniformer.inference import init_segmentor, inference_segmentor, show_result_pyplot try: from mmseg.core.evaluation import get_palette except ImportError: from annotator.mmpkg.mmseg.core.evaluation import get_palette modeldir = os.path.join(models_path, "uniformer") checkpoint_file = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/upernet_global_small.pth" config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "upernet_global_small.py") old_modeldir = os.path.dirname(os.path.realpath(__file__)) model = None def unload_uniformer_model(): global model if model is not None: model = model.cpu() def apply_uniformer(img): global model if model is None: modelpath = os.path.join(modeldir, "upernet_global_small.pth") old_modelpath = os.path.join(old_modeldir, "upernet_global_small.pth") if os.path.exists(old_modelpath): modelpath = old_modelpath elif not os.path.exists(modelpath): from basicsr.utils.download_util import load_file_from_url load_file_from_url(checkpoint_file, model_dir=modeldir) model = init_segmentor(config_file, modelpath, device=devices.get_device_for("controlnet")) model = model.to(devices.get_device_for("controlnet")) if devices.get_device_for("controlnet").type == 'mps': # adaptive_avg_pool2d can fail on MPS, workaround with CPU import torch.nn.functional orig_adaptive_avg_pool2d = torch.nn.functional.adaptive_avg_pool2d def cpu_if_exception(input, *args, **kwargs): try: return orig_adaptive_avg_pool2d(input, *args, **kwargs) except: return orig_adaptive_avg_pool2d(input.cpu(), *args, **kwargs).to(input.device) try: torch.nn.functional.adaptive_avg_pool2d = cpu_if_exception result = inference_segmentor(model, img) finally: torch.nn.functional.adaptive_avg_pool2d = orig_adaptive_avg_pool2d else: result = inference_segmentor(model, img) res_img = show_result_pyplot(model, img, result, get_palette('ade'), opacity=1) return res_img