import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from einops import rearrange
from PIL import Image
from torchvision import transforms
from torchvision.models import MobileNet_V2_Weights

from annotator.annotator_path import models_path
from modules import devices

# One color triple per segmentation class; the list order must match the
# channel order of the network's 7-class output.
COLOR_BACKGROUND = (255, 255, 0)
COLOR_HAIR = (0, 0, 255)
COLOR_EYE = (255, 0, 0)
COLOR_MOUTH = (255, 255, 255)
COLOR_FACE = (0, 255, 0)
COLOR_SKIN = (0, 255, 255)
COLOR_CLOTHES = (255, 0, 255)
PALETTE = [COLOR_BACKGROUND, COLOR_HAIR, COLOR_EYE, COLOR_MOUTH, COLOR_FACE, COLOR_SKIN, COLOR_CLOTHES]


class UNet(nn.Module):
    """U-Net-style face segmentation network with a MobileNetV2 encoder."""

    def __init__(self):
        super().__init__()
        self.NUM_SEG_CLASSES = 7  # one class per entry in PALETTE

        # ImageNet-pretrained MobileNetV2 feature blocks serve as the encoder.
        mobilenet_v2 = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.IMAGENET1K_V1)
        mob_blocks = mobilenet_v2.features

        # Encoder stages; each stage halves the spatial resolution.
        self.en_block0 = nn.Sequential(  # 16 channels, 1/2 resolution
            mob_blocks[0],
            mob_blocks[1],
        )
        self.en_block1 = nn.Sequential(  # 24 channels, 1/4 resolution
            mob_blocks[2],
            mob_blocks[3],
        )
        self.en_block2 = nn.Sequential(  # 32 channels, 1/8 resolution
            mob_blocks[4],
            mob_blocks[5],
            mob_blocks[6],
        )
        self.en_block3 = nn.Sequential(  # 96 channels, 1/16 resolution
            mob_blocks[7],
            mob_blocks[8],
            mob_blocks[9],
            mob_blocks[10],
            mob_blocks[11],
            mob_blocks[12],
            mob_blocks[13],
        )
        self.en_block4 = nn.Sequential(  # 160 channels, 1/32 resolution
            mob_blocks[14],
            mob_blocks[15],
            mob_blocks[16],
        )

        # Decoder stages: nearest-neighbor upsampling followed by convolution.
        # Skip connections concatenate each decoder output with the matching
        # encoder feature map before the next stage, hence the doubled input
        # channel counts below.
        self.de_block4 = nn.Sequential(
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(160, 96, kernel_size=3, padding=1),
            nn.InstanceNorm2d(96),
            nn.LeakyReLU(0.1),
            nn.Dropout(p=0.2),
        )
        self.de_block3 = nn.Sequential(
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(96 * 2, 32, kernel_size=3, padding=1),
            nn.InstanceNorm2d(32),
            nn.LeakyReLU(0.1),
            nn.Dropout(p=0.2),
        )
        self.de_block2 = nn.Sequential(
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(32 * 2, 24, kernel_size=3, padding=1),
            nn.InstanceNorm2d(24),
            nn.LeakyReLU(0.1),
            nn.Dropout(p=0.2),
        )
        self.de_block1 = nn.Sequential(
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(24 * 2, 16, kernel_size=3, padding=1),
            nn.InstanceNorm2d(16),
            nn.LeakyReLU(0.1),
            nn.Dropout(p=0.2),
        )

        # Final stage maps the fused features to per-class probabilities.
        self.de_block0 = nn.Sequential(
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(16 * 2, self.NUM_SEG_CLASSES, kernel_size=3, padding=1),
            nn.Softmax2d(),
        )

    def forward(self, x):
        e0 = self.en_block0(x)
        e1 = self.en_block1(e0)
        e2 = self.en_block2(e1)
        e3 = self.en_block3(e2)
        e4 = self.en_block4(e3)

        # Decode with skip connections; each decoder output is interpolated to
        # the exact size of the matching encoder map so that odd input
        # dimensions cannot break the concatenation.
        d4 = self.de_block4(e4)
        d4 = F.interpolate(d4, size=e3.size()[2:], mode='bilinear', align_corners=True)
        c4 = torch.cat((d4, e3), 1)

        d3 = self.de_block3(c4)
        d3 = F.interpolate(d3, size=e2.size()[2:], mode='bilinear', align_corners=True)
        c3 = torch.cat((d3, e2), 1)

        d2 = self.de_block2(c3)
        d2 = F.interpolate(d2, size=e1.size()[2:], mode='bilinear', align_corners=True)
        c2 = torch.cat((d2, e1), 1)

        d1 = self.de_block1(c2)
        d1 = F.interpolate(d1, size=e0.size()[2:], mode='bilinear', align_corners=True)
        c1 = torch.cat((d1, e0), 1)
        y = self.de_block0(c1)

        return y

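# A minimal shape sanity check, as a hedged sketch: the decoder mirrors the
# five encoder downsamplings, so for a 512x512 RGB input the output should be
# a 7-channel map at the input resolution. Assumes the torchvision MobileNetV2
# weights are available (they are downloaded on first use).
#
#     net = UNet().eval()
#     with torch.no_grad():
#         out = net(torch.zeros(1, 3, 512, 512))
#     assert out.shape == (1, 7, 512, 512)

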
class AnimeFaceSegment:

    model_dir = os.path.join(models_path, "anime_face_segment")

    def __init__(self):
        self.model = None
        self.device = devices.get_device_for("controlnet")

    def load_model(self):
        remote_model_path = "https://huggingface.co/bdsqlsz/qinglong_controlnet-lllite/resolve/main/Annotators/UNet.pth"
        modelpath = os.path.join(self.model_dir, "UNet.pth")
        if not os.path.exists(modelpath):
            from basicsr.utils.download_util import load_file_from_url
            load_file_from_url(remote_model_path, model_dir=self.model_dir)
        net = UNet()
        ckpt = torch.load(modelpath, map_location=self.device)
        # Strip the 'module.' prefix that torch.nn.DataParallel adds to
        # checkpoint keys so the state dict matches this plain module.
        for key in list(ckpt.keys()):
            if 'module.' in key:
                ckpt[key.replace('module.', '')] = ckpt[key]
                del ckpt[key]
        net.load_state_dict(ckpt)
        net.eval()
        self.model = net.to(self.device)

    def unload_model(self):
        # Move the network to CPU to free VRAM; it is moved back on the next call.
        if self.model is not None:
            self.model.cpu()

    def __call__(self, input_image):
        if self.model is None:
            self.load_model()
        self.model.to(self.device)
        transform = transforms.Compose([
            transforms.Resize(512, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
        ])
        img = Image.fromarray(input_image)
        with torch.no_grad():
            img = transform(img).unsqueeze(dim=0).to(self.device)
            seg = self.model(img).squeeze(dim=0)
            seg = seg.cpu().numpy()
            # Convert from channel-first (C, H, W) to channel-last (H, W, C).
            img = rearrange(seg, 'c h w -> h w c')
            # Color each pixel by its most probable class.
            img = [[PALETTE[np.argmax(val)] for val in buf] for buf in img]
            return np.array(img).astype(np.uint8)
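

if __name__ == "__main__":
    # Minimal usage sketch, not part of the annotator API: the file names here
    # are illustrative assumptions, and running it requires the sd-webui
    # runtime (`modules.devices`). Model weights are downloaded on first use.
    segmenter = AnimeFaceSegment()
    source = np.array(Image.open("input.png").convert("RGB"))  # hypothetical input path
    color_map = segmenter(source)  # (H, W, 3) uint8 color-coded segmentation map
    Image.fromarray(color_map).save("anime_face_seg.png")  # hypothetical output path
    segmenter.unload_model()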