Spaces:
Configuration error
Configuration error
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import os | |
import math | |
import cv2 | |
import numpy as np | |
import paddle | |
from paddleseg import utils | |
from paddleseg.core import infer | |
from paddleseg.utils import logger, progbar, visualize | |
def mkdir(path):
    """Ensure the parent directory of ``path`` exists.

    Only the directory component of ``path`` is created; ``path`` itself is
    never created or modified. Calling it again for an existing directory is
    a no-op.

    Args:
        path (str): A file path whose parent directory should exist.
    """
    sub_dir = os.path.dirname(path)
    # A bare file name has no directory component; os.makedirs('') would
    # raise FileNotFoundError, and the current directory already exists.
    if sub_dir:
        # exist_ok avoids the check-then-create race when several ranks save
        # results into the same directory at the same time.
        os.makedirs(sub_dir, exist_ok=True)
def partition_list(arr, m):
    """Split the list ``arr`` into exactly ``m`` contiguous pieces.

    Each piece holds ``ceil(len(arr) / m)`` items, except that trailing
    pieces may be shorter or empty. Always returning ``m`` pieces lets each
    of ``m`` workers index its own share, even when there are fewer items
    than workers or ``arr`` is empty.

    Args:
        arr (list): The sequence to split.
        m (int): Number of pieces; must be at least 1.

    Returns:
        list: ``m`` slices of ``arr`` whose concatenation equals ``arr``.

    Raises:
        ValueError: If ``m`` is less than 1.
    """
    if m < 1:
        raise ValueError("m must be a positive integer, got {}".format(m))
    n = int(math.ceil(len(arr) / float(m)))
    # Index by piece number rather than stepping by n: a step of 0 (empty
    # arr) is invalid for range(), and stepping stops early when there are
    # fewer than m non-empty pieces.
    return [arr[i * n:(i + 1) * n] for i in range(m)]
def preprocess(im_path, transforms):
    """Run ``transforms`` on one image path and batch the result as a tensor.

    Args:
        im_path (str): Image path, handed to ``transforms`` under the
            ``'img'`` key.
        transforms (callable): Called with ``{'img': im_path}``; must return
            a dict whose ``'img'`` entry supports numpy-style indexing
            (presumably a numpy array — confirm against the transform
            pipeline).

    Returns:
        dict: The dict produced by ``transforms``, with ``'img'`` replaced by
        a paddle tensor carrying an extra leading axis of length 1.
    """
    sample = transforms({'img': im_path})
    # Prepend a batch axis of size 1 before converting to a tensor.
    batched_img = sample['img'][np.newaxis, ...]
    sample['img'] = paddle.to_tensor(batched_img)
    return sample
def _relative_save_name(im_path, image_dir):
    """Derive the relative file name under which results for ``im_path`` are saved."""
    if image_dir is None:
        im_file = os.path.basename(im_path)
    elif im_path.startswith(image_dir):
        # Remove image_dir only as a leading prefix; str.replace would also
        # strip it wherever it re-appears deeper in the path.
        im_file = im_path[len(image_dir):]
    else:
        im_file = im_path
    # Strip every leading separator: os.path.join discards the output
    # directory when the second component starts with a separator.
    return im_file.lstrip('/\\')


def predict(model,
            model_path,
            transforms,
            image_list,
            image_dir=None,
            save_dir='output',
            aug_pred=False,
            scales=1.0,
            flip_horizontal=True,
            flip_vertical=False,
            is_slide=False,
            stride=None,
            crop_size=None,
            custom_color=None):
    """
    Predict and visualize the images in image_list.

    For every image two files are written: a blended overlay of the image and
    its prediction under ``<save_dir>/added_prediction`` (same relative name as
    the input), and a pseudo-color PNG of the prediction under
    ``<save_dir>/pseudo_color_prediction``. When more than one rank is
    running, ``image_list`` is partitioned across ranks and each rank predicts
    only its own share.

    Args:
        model (nn.Layer): Used to predict for input image.
        model_path (str): The path of pretrained model.
        transforms (transform.Compose): Preprocess for input image.
        image_list (list): A list of image paths to be predicted.
        image_dir (str, optional): The root directory of the images predicted. Output files keep
            the image's path relative to this directory. Default: None, keep only the file name.
        save_dir (str, optional): The directory to save the visualized results. Default: 'output'.
        aug_pred (bool, optional): Whether to use multi-scale and flip augmentation for prediction. Default: False.
        scales (list|float, optional): Scales for augment. It is valid when `aug_pred` is True. Default: 1.0.
        flip_horizontal (bool, optional): Whether to use flip horizontally augment. It is valid when `aug_pred` is True. Default: True.
        flip_vertical (bool, optional): Whether to use flip vertically augment. It is valid when `aug_pred` is True. Default: False.
        is_slide (bool, optional): Whether to predict by sliding window. Default: False.
        stride (tuple|list, optional): The stride of sliding window, the first is width and the second is height.
            It should be provided when `is_slide` is True.
        crop_size (tuple|list, optional): The crop size of sliding window, the first is width and the second is height.
            It should be provided when `is_slide` is True.
        custom_color (list, optional): Save images with a custom color map. Default: None, use paddleseg's default color map.
    """
    utils.utils.load_entire_model(model, model_path)
    model.eval()
    nranks = paddle.distributed.get_world_size()
    local_rank = paddle.distributed.get_rank()
    if nranks > 1:
        img_lists = partition_list(image_list, nranks)
    else:
        img_lists = [image_list]
    # Guard against receiving fewer pieces than ranks (e.g. more ranks than
    # images): a rank without a piece simply has nothing to predict.
    local_images = img_lists[local_rank] if local_rank < len(img_lists) else []

    added_saved_dir = os.path.join(save_dir, 'added_prediction')
    pred_saved_dir = os.path.join(save_dir, 'pseudo_color_prediction')

    logger.info("Start to predict...")
    # Size the progress bar by this rank's own share of the images.
    progbar_pred = progbar.Progbar(target=len(local_images), verbose=1)
    color_map = visualize.get_color_map_list(256, custom_color=custom_color)
    with paddle.no_grad():
        for i, im_path in enumerate(local_images):
            data = preprocess(im_path, transforms)

            if aug_pred:
                pred, _ = infer.aug_inference(
                    model,
                    data['img'],
                    trans_info=data['trans_info'],
                    scales=scales,
                    flip_horizontal=flip_horizontal,
                    flip_vertical=flip_vertical,
                    is_slide=is_slide,
                    stride=stride,
                    crop_size=crop_size)
            else:
                pred, _ = infer.inference(
                    model,
                    data['img'],
                    trans_info=data['trans_info'],
                    is_slide=is_slide,
                    stride=stride,
                    crop_size=crop_size)
            # Drop all size-1 axes (e.g. batch) before converting to uint8
            # label values for visualization.
            pred = paddle.squeeze(pred)
            pred = pred.numpy().astype('uint8')

            im_file = _relative_save_name(im_path, image_dir)

            # save added image
            added_image = utils.visualize.visualize(
                im_path, pred, color_map, weight=0.6)
            added_image_path = os.path.join(added_saved_dir, im_file)
            mkdir(added_image_path)
            # cv2.imwrite reports failure (e.g. unsupported extension) by
            # returning False rather than raising.
            if not cv2.imwrite(added_image_path, added_image):
                logger.warning(
                    "Failed to write image: {}".format(added_image_path))

            # save pseudo color prediction
            pred_mask = utils.visualize.get_pseudo_color_map(pred, color_map)
            pred_saved_path = os.path.join(
                pred_saved_dir, os.path.splitext(im_file)[0] + ".png")
            mkdir(pred_saved_path)
            pred_mask.save(pred_saved_path)

            progbar_pred.update(i + 1)