# File: clearbg / u2net_pipeline.py
# Page-header residue from the hosting site: "aryanxxvii's picture",
# commit "Add config.json with model_type" (9961846).
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
import torch.nn.functional as F
from u2net import U2NET
import data_transforms
from transformers import Pipeline
class U2NetPipeline(Pipeline):
    """Image-segmentation pipeline that wraps a ``U2NET`` model.

    The pipeline accepts a PIL image (or a path to an image file), runs it
    through the wrapped model and returns a ``uint8`` mask resized back to
    the resolution of the input image.
    """

    def __init__(self, model, **kwargs):
        """Store *model* and switch it to evaluation mode.

        Args:
            model: The network to run; it is called directly on the
                preprocessed image tensor in ``_forward``.
            **kwargs: Forwarded unchanged to ``Pipeline.__init__``.
        """
        super().__init__(model=model, **kwargs)
        self.model = model
        self.model.eval()

    @classmethod
    def from_pretrained(cls, model_path, **kwargs):
        """Build a pipeline from the ``u2net.pth`` checkpoint in *model_path*.

        Args:
            model_path: Directory that contains ``u2net.pth``.
            **kwargs: Forwarded unchanged to the constructor.

        Returns:
            A new pipeline instance wrapping the loaded model.
        """
        # presumably (in_channels=3, out_channels=1) -- confirm against u2net.U2NET
        model = U2NET(3, 1)
        # NOTE(security): torch.load unpickles the checkpoint, which can execute
        # arbitrary code. Only load checkpoints from a trusted source.
        state_dict = torch.load(f"{model_path}/u2net.pth", map_location="cpu")
        model.load_state_dict(state_dict)
        return cls(model, **kwargs)

    def _sanitize_parameters(self, **kwargs):
        """Return empty preprocess / forward / postprocess parameter dicts.

        The pipeline takes no call-time options, so every keyword argument
        is ignored.
        """
        return {}, {}, {}

    def preprocess(self, image):
        """Convert *image* into the tensor fed to the model.

        Args:
            image: A ``PIL.Image.Image`` or a string path to an image file.

        Returns:
            A dict with ``"image"`` (the 1024x1024 input batch of size one)
            and ``"original_size"`` (``(height, width)`` of the input image,
            used by ``postprocess`` to resize the mask back).

        Raises:
            ValueError: If *image* is neither a PIL image nor a string path.
        """
        if isinstance(image, str):
            # convert() returns a fully loaded copy, so the file handle can be
            # released immediately instead of lingering until garbage collection.
            with Image.open(image) as opened:
                image = opened.convert("RGB")
        elif isinstance(image, Image.Image):
            image = image.convert("RGB")
        else:
            raise ValueError("Input must be a PIL Image or a path to an image file")

        pixels = np.array(image)
        # Record the true input resolution *before* any resizing. The previous
        # code read the shape of the already-resized tensor, so the mask was
        # never scaled back to the size of the image that was passed in.
        original_size = tuple(int(dim) for dim in pixels.shape[:2])

        transform = transforms.Compose(
            [data_transforms.RescaleT(320), data_transforms.ToTensorLab(flag=0)]
        )
        sample = transform(
            {
                "imidx": np.array([0]),
                "image": pixels,
                "label": np.zeros(pixels.shape[:2]),
            }
        )

        input_size = [1024, 1024]
        batch = sample["image"].unsqueeze(0)
        batch = F.interpolate(batch, input_size, mode="bilinear")
        # NOTE(review): ToTensorLab may already rescale/normalise the pixel
        # values; confirm that the additional /255 and Normalize are intended.
        batch = torch.divide(batch, 255.0)
        batch = transforms.Normalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0])(batch)
        return {"image": batch, "original_size": original_size}

    def _forward(self, model_inputs):
        """Run the model on the preprocessed tensor without tracking gradients.

        Args:
            model_inputs: The dict produced by ``preprocess``.

        Returns:
            A dict with the raw model ``"outputs"`` and the ``"original_size"``
            passed through untouched.
        """
        with torch.no_grad():
            outputs = self.model(model_inputs["image"])
        return {"outputs": outputs, "original_size": model_inputs["original_size"]}

    def postprocess(self, model_outputs):
        """Turn raw model outputs into a ``uint8`` mask at the input resolution.

        Args:
            model_outputs: The dict produced by ``_forward``.

        Returns:
            A 2-D ``numpy.uint8`` array of shape ``original_size`` whose values
            are min-max scaled to ``0..255``. A constant prediction yields an
            all-zero mask.
        """
        prediction = model_outputs["outputs"]
        # The model may return a (possibly nested) tuple/list of side outputs;
        # descend to the first tensor, which is the one used for the mask.
        while isinstance(prediction, (list, tuple)):
            prediction = prediction[0]

        # Bilinear interpolation requires a 4-D (N, C, H, W) tensor. The old
        # ``outputs[0][0]`` indexing could hand it a 3-D or 2-D tensor instead.
        while prediction.dim() < 4:
            prediction = prediction.unsqueeze(0)
        prediction = prediction[:1, :1]  # first batch item, first channel

        # int() also copes with one-element tensors, should the size have been
        # collated into tensors on its way through the pipeline.
        target_size = tuple(int(dim) for dim in model_outputs["original_size"])
        resized = F.interpolate(
            prediction, size=target_size, mode="bilinear", align_corners=False
        )
        mask = resized[0, 0].detach().cpu().numpy()

        lowest, highest = mask.min(), mask.max()
        value_range = highest - lowest
        if value_range > 0:
            mask = (mask - lowest) / value_range
        else:
            # Constant prediction: avoid 0/0 -> NaN and an undefined uint8 cast.
            mask = np.zeros_like(mask)
        return (mask * 255).astype(np.uint8)
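# The legacy load_model() helper below is no longer needed and is kept
# commented out for reference only; build the pipeline with
# U2NetPipeline.from_pretrained(model_path) instead.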
# def load_model():
# return U2NetPipeline("u2net.pth")