|
import torch |
|
from torchvision import transforms |
|
from PIL import Image |
|
import numpy as np |
|
import torch.nn.functional as F |
|
from u2net import U2NET |
|
import data_transforms |
|
from transformers import Pipeline |
|
|
|
class U2NetPipeline(Pipeline):
    """Segmentation pipeline that runs a U2NET model on a single image.

    The pipeline accepts a PIL image or a path to an image file and returns a
    ``uint8`` NumPy mask, min-max scaled to ``[0, 255]`` and resized to the
    height and width of the input image.
    """

    def __init__(self, model, **kwargs):
        """Store *model* and switch it to evaluation mode.

        Args:
            model: The network invoked in ``_forward`` with the image batch.
            **kwargs: Forwarded unchanged to ``Pipeline.__init__``.
        """
        super().__init__(model=model, **kwargs)
        self.model = model
        self.model.eval()

    @classmethod
    def from_pretrained(cls, model_path, **kwargs):
        """Build a pipeline from the ``u2net.pth`` checkpoint in *model_path*.

        Args:
            model_path: Directory that contains ``u2net.pth``.
            **kwargs: Forwarded unchanged to the constructor.

        Returns:
            A pipeline wrapping ``U2NET(3, 1)`` with the checkpoint loaded on CPU.
        """
        model = U2NET(3, 1)
        # NOTE(security): torch.load unpickles the checkpoint, which can execute
        # arbitrary code. Only load files from a trusted source; consider
        # ``weights_only=True`` if the installed torch version supports it.
        state_dict = torch.load(f"{model_path}/u2net.pth", map_location="cpu")
        model.load_state_dict(state_dict)
        return cls(model, **kwargs)

    def _sanitize_parameters(self, **kwargs):
        """Return empty preprocess, forward and postprocess parameter dicts.

        The pipeline has no call-time options, so every keyword is ignored.
        """
        return {}, {}, {}

    def preprocess(self, image):
        """Turn an image into the normalised batch fed to the model.

        Args:
            image: A ``PIL.Image.Image`` or a path to an image file.

        Returns:
            A dict with ``"image"`` (the 1024x1024 input batch) and
            ``"original_size"`` (the ``(height, width)`` of the input image).

        Raises:
            ValueError: If *image* is neither a PIL image nor a string path.
        """
        if isinstance(image, str):
            # convert() returns a fully loaded copy, so the file handle can be
            # released as soon as the conversion is done.
            with Image.open(image) as handle:
                image = handle.convert("RGB")
        elif isinstance(image, Image.Image):
            image = image.convert("RGB")
        else:
            raise ValueError("Input must be a PIL Image or a path to an image file")

        image = np.array(image)
        # Remember the caller's resolution *before* any resizing so that
        # postprocess can scale the mask back to it.
        original_size = tuple(image.shape[:2])

        transform = transforms.Compose(
            [data_transforms.RescaleT(320), data_transforms.ToTensorLab(flag=0)]
        )
        sample = transform(
            {
                "imidx": np.array([0]),
                "image": image,
                "label": np.zeros(image.shape[:2]),
            }
        )

        input_size = [1024, 1024]
        # ToTensorLab may hand back a float64 tensor (TODO confirm against
        # data_transforms); cast so the dtype matches default float32 weights.
        im_tensor = sample["image"].unsqueeze(0).float()
        im_tensor = F.interpolate(im_tensor, input_size, mode="bilinear")
        # NOTE(review): ToTensorLab presumably normalises already; the extra
        # /255 and Normalize are kept exactly as before - confirm that this
        # matches the preprocessing the checkpoint was trained with.
        image = torch.divide(im_tensor, 255.0)
        image = transforms.Normalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0])(image)

        return {"image": image, "original_size": original_size}

    def _forward(self, model_inputs):
        """Run the model without gradient tracking.

        Args:
            model_inputs: The dict produced by ``preprocess``.

        Returns:
            A dict with the raw model ``"outputs"`` and the untouched
            ``"original_size"``.
        """
        with torch.no_grad():
            outputs = self.model(model_inputs["image"])
        return {"outputs": outputs, "original_size": model_inputs["original_size"]}

    def postprocess(self, model_outputs):
        """Convert the model output into a ``uint8`` mask.

        Args:
            model_outputs: The dict produced by ``_forward``.

        Returns:
            A NumPy ``uint8`` array: the first model output resized to
            ``original_size``, squeezed, and min-max scaled to ``[0, 255]``.
            A constant prediction yields an all-zero mask.
        """
        prediction = model_outputs["outputs"]
        # The model may return a tensor, a flat sequence of tensors, or nested
        # sequences; always take the first tensor.
        while isinstance(prediction, (list, tuple)):
            prediction = prediction[0]
        # Bilinear interpolation with a 2-element size requires a 4-D input.
        while prediction.dim() < 4:
            prediction = prediction.unsqueeze(0)

        target_size = tuple(int(side) for side in model_outputs["original_size"])
        result = F.interpolate(
            prediction, size=target_size, mode="bilinear", align_corners=False
        )
        result = result.squeeze().cpu().numpy()

        ma, mi = result.max(), result.min()
        span = ma - mi
        if span == 0:
            # Avoid 0/0 -> NaN, whose cast to uint8 is undefined.
            return np.zeros(result.shape, dtype=np.uint8)
        result = (result - mi) / span
        return (result * 255).astype(np.uint8)
|
|
|
|
|
|
|
|