|
import torch |
|
from u2net import U2NET |
|
from torchvision import transforms |
|
import numpy as np |
|
from PIL import Image |
|
import torch.nn.functional as F |
|
import data_transforms |
|
|
|
|
|
def load_model(weights_path="u2net.pth"):
    """Build a ``U2NET(3, 1)`` network and load its weights on the CPU.

    Args:
        weights_path: Path of the checkpoint holding the state dict.
            Defaults to ``"u2net.pth"`` in the current working directory.

    Returns:
        The model with the state dict loaded, switched to evaluation mode.
    """
    model = U2NET(3, 1)
    # NOTE: torch.load unpickles the file, so only load checkpoints that come
    # from a trusted source.
    state_dict = torch.load(weights_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model
|
|
|
|
|
def preprocess(image):
    """Run the ``RescaleT(320)`` + ``ToTensorLab(flag=0)`` pipeline on an image.

    Args:
        image: Array exposing ``.shape``; its first two dimensions are used
            as the size of the placeholder label.
            NOTE(review): presumably an ``(H, W, C)`` numpy image -- confirm
            against ``data_transforms``.

    Returns:
        Whatever the composed transform returns for the sample dict built
        from ``image`` and an all-zero label.
    """
    transform = transforms.Compose(
        [data_transforms.RescaleT(320), data_transforms.ToTensorLab(flag=0)]
    )
    # The transforms receive an "image"/"label" pair; no label is available
    # here, so pass an all-zero placeholder matching the image's first two
    # dimensions (built directly, without an image-sized scratch array).
    label = np.zeros(image.shape[0:2])
    sample = transform({"imidx": np.array([0]), "image": image, "label": label})
    return sample
|
|
|
|
|
def infer(model, image, input_size=(1024, 1024)):
    """Run ``model`` on a single image and return a min-max normalised map.

    Args:
        model: Callable network. Its output is indexed as ``output[0]``; when
            that element is itself a list/tuple of maps, its first entry is
            used. The selected map must be a 4-D tensor.
        image: Array-like of shape ``(height, width, channels)``.
            NOTE(review): values are presumably in the 0-255 range, since
            they are truncated to ``uint8`` and divided by 255 -- confirm.
        input_size: ``(height, width)`` the image is resized to before it is
            passed to the model.

    Returns:
        ``numpy.ndarray`` holding the selected map resized back to the
        image's original height and width, with the leading batch dimension
        squeezed out and values rescaled to ``[0, 1]``. A constant map
        yields all zeros instead of NaN.
    """
    original_size = tuple(image.shape[0:2])

    # HWC -> CHW, then add the batch dimension expected by interpolate().
    im_tensor = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = F.interpolate(
        torch.unsqueeze(im_tensor, 0),
        size=tuple(input_size),
        mode="bilinear",
        align_corners=False,
    ).type(torch.uint8)  # truncate the resized values to integer levels
    model_input = torch.divide(im_tensor, 255.0)

    # Inference only: without no_grad() the output would require grad and the
    # final conversion to numpy would raise a RuntimeError.
    with torch.no_grad():
        outputs = model(model_input)

        # Accept both a flat tuple of maps and a nested list of maps.
        prediction = outputs[0]
        if not torch.is_tensor(prediction):
            prediction = prediction[0]

        prediction = F.interpolate(
            prediction, size=original_size, mode="bilinear", align_corners=False
        )
        prediction = torch.squeeze(prediction, 0)

        lowest = prediction.min()
        span = prediction.max() - lowest
        if span == 0:
            # Constant map: avoid 0/0 producing NaN.
            prediction = torch.zeros_like(prediction)
        else:
            prediction = (prediction - lowest) / span

    return prediction.detach().cpu().numpy()
|
|