# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import filelock
import os
import tempfile
import numpy as np
import random
from urllib.parse import urlparse, unquote

import paddle

from paddleseg.utils import logger, seg_env
from paddleseg.utils.download import download_file_and_uncompress


@contextlib.contextmanager
def generate_tempdir(directory: str = None, **kwargs):
    """
    Context manager that yields the path of a temporary directory.

    Args:
        directory (str, optional): Parent directory in which the temporary
            directory is created. Any falsy value (None, '') falls back to
            `seg_env.TMP_HOME`. Default: None.
        **kwargs: Forwarded unchanged to `tempfile.TemporaryDirectory`.

    Yields:
        str: The path of the temporary directory. It is removed by
            `tempfile.TemporaryDirectory` when the context exits.
    """
    parent_dir = directory if directory else seg_env.TMP_HOME
    with tempfile.TemporaryDirectory(dir=parent_dir, **kwargs) as _dir:
        yield _dir


def load_entire_model(model, pretrained):
    """
    Load pretrained weights into `model`, or warn when none are given.

    Args:
        model: The model to load weights into; passed through to
            `load_pretrained_model`.
        pretrained (str|None): Path or url of the pretrained weights. When
            None, nothing is loaded and a warning is logged instead.
    """
    if pretrained is None:
        logger.warning('Not all pretrained params of {} are loaded, ' \
                       'training from scratch or a pretrained backbone.'.format(model.__class__.__name__))
        return
    load_pretrained_model(model, pretrained)


def download_pretrained_model(pretrained_model):
    """
    Download pretrained model from url.

    Args:
        pretrained_model (str): the url of pretrained weight

    Returns:
        str: the path of pretrained weight
    """
    assert urlparse(pretrained_model).netloc, "The url is not valid."

    pretrained_model = unquote(pretrained_model)
    url_parts = pretrained_model.split('/')
    savename = url_parts[-1]
    if savename.endswith(('tgz', 'tar.gz', 'tar', 'zip')):
        # Archive url: name the extracted folder after the archive, without
        # anything following the first dot.
        savename = savename.split('.')[0]
    else:
        # Not an archive name: use the parent url component as the name.
        savename = url_parts[-2]

    with generate_tempdir() as _dir:
        # The lock file is keyed on `savename`, so concurrent callers asking
        # for the same weights serialize on the download.
        with filelock.FileLock(os.path.join(seg_env.TMP_HOME, savename)):
            model_dir = download_file_and_uncompress(
                pretrained_model,
                savepath=_dir,
                extrapath=seg_env.PRETRAINED_MODEL_HOME,
                extraname=savename)
            weight_path = os.path.join(model_dir, 'model.pdparams')
    return weight_path


def load_pretrained_model(model, pretrained_model):
    """
    Copy matching parameters from a pretrained weight file into `model`.

    A parameter is copied only when its key exists in the pretrained state
    dict and its shape equals the shape of the model's own parameter; every
    other key is reported with a warning and keeps the model's current value.

    Args:
        model: The model to update. It must provide `state_dict()` and
            `set_dict()`.
        pretrained_model (str|None): Local path or url of the weights. A value
            with a network location is downloaded first. When None, nothing
            is loaded and an info message is logged.

    Raises:
        ValueError: If the (possibly downloaded) path does not exist.
    """
    if pretrained_model is None:
        logger.info(
            'No pretrained model to load, {} will be trained from scratch.'.
            format(model.__class__.__name__))
        return

    logger.info('Loading pretrained model from {}'.format(pretrained_model))

    if urlparse(pretrained_model).netloc:
        pretrained_model = download_pretrained_model(pretrained_model)

    if not os.path.exists(pretrained_model):
        raise ValueError('The pretrained model directory is not Found: {}'.
                         format(pretrained_model))

    para_state_dict = paddle.load(pretrained_model)
    model_state_dict = model.state_dict()
    num_params_loaded = 0
    # Only existing keys are reassigned, so the dict size never changes while
    # it is being iterated.
    for name in model_state_dict:
        if name not in para_state_dict:
            logger.warning("{} is not in pretrained model".format(name))
            continue

        pretrained_shape = para_state_dict[name].shape
        actual_shape = model_state_dict[name].shape
        if list(pretrained_shape) != list(actual_shape):
            logger.warning(
                "[SKIP] Shape of pretrained params {} doesn't match.(Pretrained: {}, Actual: {})"
                .format(name, pretrained_shape, actual_shape))
            continue

        model_state_dict[name] = para_state_dict[name]
        num_params_loaded += 1

    model.set_dict(model_state_dict)
    logger.info("There are {}/{} variables loaded into {}.".format(
        num_params_loaded,
        len(model_state_dict), model.__class__.__name__))


def resume(model, optimizer, resume_model):
    """
    Restore model and optimizer state from a checkpoint directory.

    The directory is expected to contain `model.pdparams` and `model.pdopt`,
    and its name must end with `_<iteration>` (for example `iter_1000`).

    Args:
        model: Object providing `set_state_dict()` for the parameters.
        optimizer: Object providing `set_state_dict()` for the optimizer state.
        resume_model (str|None): Path of the checkpoint directory. When None,
            nothing is restored and an info message is logged.

    Returns:
        int|None: The iteration parsed from the directory name, or None when
            `resume_model` is None.

    Raises:
        ValueError: If `resume_model` does not exist, or if the text after the
            last underscore of the path is not an integer.
    """
    if resume_model is None:
        logger.info('No model needed to resume.')
        return None

    logger.info('Resume model from {}'.format(resume_model))
    if not os.path.exists(resume_model):
        raise ValueError(
            'Directory of the model needed to resume is not Found: {}'.
            format(resume_model))

    # normpath drops a trailing separator, so 'iter_1000/' still parses below.
    resume_model = os.path.normpath(resume_model)
    para_state_dict = paddle.load(os.path.join(resume_model, 'model.pdparams'))
    opti_state_dict = paddle.load(os.path.join(resume_model, 'model.pdopt'))
    model.set_state_dict(para_state_dict)
    optimizer.set_state_dict(opti_state_dict)

    # The iteration is the text after the last underscore of the path.
    return int(resume_model.split('_')[-1])


def worker_init_fn(worker_id):
    """
    Re-seed NumPy's global random generator.

    The seed is drawn from Python's `random` module in the range [0, 100000].

    Args:
        worker_id: Unused; presumably present so the function can be used as
            a data-loader worker init callback (confirm against the caller).
    """
    np.random.seed(random.randint(0, 100000))


def get_image_list(image_path):
    """
    Collect image paths from an image file, a list file or a directory.

    Args:
        image_path (str): One of
            - a single image file, recognised by its suffix;
            - a text file with one entry per line, where the first
              whitespace-separated field is an image path relative to the
              directory containing the list file (blank lines are ignored);
            - a directory, which is walked recursively. Files whose name
              starts with '.' and anything under a path containing
              '.ipynb_checkpoints' are skipped.

    Returns:
        tuple: `(image_list, image_dir)`, where `image_list` is a list of
            image paths and `image_dir` is the directory the paths were
            resolved against (None when `image_path` is a single image).

    Raises:
        FileNotFoundError: If `image_path` is neither a file nor a directory.
        RuntimeError: If no image was found.
    """
    valid_suffix = [
        '.JPEG', '.jpeg', '.JPG', '.jpg', '.BMP', '.bmp', '.PNG', '.png'
    ]
    image_list = []
    image_dir = None
    if os.path.isfile(image_path):
        if os.path.splitext(image_path)[-1] in valid_suffix:
            image_list.append(image_path)
        else:
            # Not an image itself: treat it as a list file whose entries are
            # relative to the directory that contains it.
            image_dir = os.path.dirname(image_path)
            with open(image_path, 'r') as f:
                for line in f:
                    fields = line.split()
                    # Skip blank lines; joining an empty name would add the
                    # directory itself to the image list.
                    if not fields:
                        continue
                    # Only the first field is the image path; any further
                    # fields on the line are ignored.
                    image_list.append(os.path.join(image_dir, fields[0]))
    elif os.path.isdir(image_path):
        image_dir = image_path
        for root, dirs, files in os.walk(image_path):
            # The check depends only on `root`, so do it once per directory.
            if '.ipynb_checkpoints' in root:
                continue
            for f in files:
                if f.startswith('.'):
                    continue
                if os.path.splitext(f)[-1] in valid_suffix:
                    image_list.append(os.path.join(root, f))
    else:
        raise FileNotFoundError(
            '`--image_path` is not found. it should be a path of image, or a file list containing image paths, or a directory including images.'
        )

    if len(image_list) == 0:
        raise RuntimeError(
            'There are not image file in `--image_path`={}'.format(image_path))

    return image_list, image_dir