Spaces:
Runtime error
Runtime error
Upload 159 files
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
- mmpl/__init__.py +0 -0
- mmpl/__pycache__/__init__.cpython-310.pyc +0 -0
- mmpl/__pycache__/registry.cpython-310.pyc +0 -0
- mmpl/datasets/__init__.py +9 -0
- mmpl/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
- mmpl/datasets/__pycache__/builder.cpython-310.pyc +0 -0
- mmpl/datasets/__pycache__/nwpu_ins_dataset.cpython-310.pyc +0 -0
- mmpl/datasets/__pycache__/pl_datamodule.cpython-310.pyc +0 -0
- mmpl/datasets/__pycache__/ssdd_ins_dataset.cpython-310.pyc +0 -0
- mmpl/datasets/__pycache__/whu_ins_dataset.cpython-310.pyc +0 -0
- mmpl/datasets/base_dataset.py +212 -0
- mmpl/datasets/builder.py +25 -0
- mmpl/datasets/custom.py +237 -0
- mmpl/datasets/nwpu_ins_dataset.py +59 -0
- mmpl/datasets/pl_datamodule.py +73 -0
- mmpl/datasets/ssdd_ins_dataset.py +54 -0
- mmpl/datasets/transforms/__init__.py +0 -0
- mmpl/datasets/transforms/__pycache__/__init__.cpython-310.pyc +0 -0
- mmpl/datasets/utils.py +243 -0
- mmpl/datasets/whu_ins_dataset.py +54 -0
- mmpl/engine/__init__.py +5 -0
- mmpl/engine/__pycache__/__init__.cpython-310.pyc +0 -0
- mmpl/engine/hooks/__init__.py +6 -0
- mmpl/engine/hooks/__pycache__/__init__.cpython-310.pyc +0 -0
- mmpl/engine/hooks/__pycache__/builder.cpython-310.pyc +0 -0
- mmpl/engine/hooks/__pycache__/ema_hook.cpython-310.pyc +0 -0
- mmpl/engine/hooks/__pycache__/param_scheduler_hook.cpython-310.pyc +0 -0
- mmpl/engine/hooks/__pycache__/pipeline_switch_hook.cpython-310.pyc +0 -0
- mmpl/engine/hooks/__pycache__/visualization_hook.cpython-310.pyc +0 -0
- mmpl/engine/hooks/__pycache__/yolov5_param_scheduler_hook.cpython-310.pyc +0 -0
- mmpl/engine/hooks/builder.py +31 -0
- mmpl/engine/hooks/ema_hook.py +240 -0
- mmpl/engine/hooks/param_scheduler_hook.py +128 -0
- mmpl/engine/hooks/pipeline_switch_hook.py +41 -0
- mmpl/engine/hooks/ppyoloe_param_scheduler_hook.py +96 -0
- mmpl/engine/hooks/switch_to_deploy_hook.py +21 -0
- mmpl/engine/hooks/visualization_hook.py +199 -0
- mmpl/engine/hooks/yolov5_param_scheduler_hook.py +111 -0
- mmpl/engine/hooks/yolox_mode_switch_hook.py +54 -0
- mmpl/engine/logger/__init__.py +1 -0
- mmpl/engine/logger/__pycache__/__init__.cpython-310.pyc +0 -0
- mmpl/engine/logger/__pycache__/builder.cpython-310.pyc +0 -0
- mmpl/engine/logger/builder.py +112 -0
- mmpl/engine/optimizers/__init__.py +0 -0
- mmpl/engine/runner/__init__.py +3 -0
- mmpl/engine/runner/__pycache__/__init__.cpython-310.pyc +0 -0
- mmpl/engine/runner/__pycache__/pl_runner.cpython-310.pyc +0 -0
- mmpl/engine/runner/pl_runner.py +941 -0
- mmpl/engine/strategies/__init__.py +1 -0
- mmpl/engine/strategies/__pycache__/__init__.cpython-310.pyc +0 -0
mmpl/__init__.py
ADDED
File without changes
|
mmpl/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (142 Bytes). View file
|
|
mmpl/__pycache__/registry.cpython-310.pyc
ADDED
Binary file (2.1 kB). View file
|
|
mmpl/datasets/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .builder import build_dataset
|
2 |
+
from .pl_datamodule import PLDataModule
|
3 |
+
from .nwpu_ins_dataset import NWPUInsSegDataset
|
4 |
+
from .whu_ins_dataset import WHUInsSegDataset
|
5 |
+
from .ssdd_ins_dataset import SSDDInsSegDataset
|
6 |
+
|
7 |
+
__all__ = [
|
8 |
+
'build_dataset', 'PLDataModule',
|
9 |
+
]
|
mmpl/datasets/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (449 Bytes). View file
|
|
mmpl/datasets/__pycache__/builder.cpython-310.pyc
ADDED
Binary file (1.01 kB). View file
|
|
mmpl/datasets/__pycache__/nwpu_ins_dataset.cpython-310.pyc
ADDED
Binary file (2.07 kB). View file
|
|
mmpl/datasets/__pycache__/pl_datamodule.cpython-310.pyc
ADDED
Binary file (2.58 kB). View file
|
|
mmpl/datasets/__pycache__/ssdd_ins_dataset.cpython-310.pyc
ADDED
Binary file (1.76 kB). View file
|
|
mmpl/datasets/__pycache__/whu_ins_dataset.cpython-310.pyc
ADDED
Binary file (1.76 kB). View file
|
|
mmpl/datasets/base_dataset.py
ADDED
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import os.path as osp
|
3 |
+
from os import PathLike
|
4 |
+
from typing import List, Optional, Sequence, Union
|
5 |
+
|
6 |
+
import mmengine
|
7 |
+
import numpy as np
|
8 |
+
from mmengine.dataset import BaseDataset as _BaseDataset
|
9 |
+
|
10 |
+
from .builder import DATASETS
|
11 |
+
|
12 |
+
|
13 |
+
def expanduser(path):
    """Expand a leading ``~`` or ``~user`` in ``path``.

    Args:
        path: Value to expand. Only ``str`` and ``os.PathLike`` inputs are
            passed to ``os.path.expanduser``.

    Returns:
        The expanded path for ``str``/``PathLike`` input; any other object
        is returned unchanged.
    """
    if not isinstance(path, (str, PathLike)):
        # Not path-like (e.g. ``None``): hand it back untouched.
        return path
    return osp.expanduser(path)
|
22 |
+
|
23 |
+
|
24 |
+
@DATASETS.register_module()
class BaseDataset(_BaseDataset):
    """Base dataset that extends :class:`mmengine.dataset.BaseDataset`.

    On top of the parent class it accepts the legacy ``classes`` argument,
    exposes ``CLASSES`` / ``class_to_idx`` / ``img_prefix`` shortcuts and
    provides a readable ``repr``. Annotation files are expected in the
    `OpenMMLab 2.0 style annotation format`.

    .. _OpenMMLab 2.0 style annotation format:
        https://github.com/open-mmlab/mmengine/blob/main/docs/zh_cn/tutorials/basedataset.md

    Args:
        ann_file (str): Annotation file path. ``~`` is expanded.
            Defaults to ''.
        metainfo (dict, optional): Meta information for the dataset, such as
            class information. Defaults to None.
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``. Defaults to ''.
        data_prefix (str | dict): Prefix for the data. A string is wrapped
            as ``dict(img_path=...)`` after ``~`` expansion. Defaults to ''.
        filter_cfg (dict, optional): Config for filtering data.
            Defaults to None.
        indices (int or Sequence[int], optional): Use only the first few
            samples of the annotation file, to train/test on a smaller
            dataset. Defaults to None, which means using all samples.
        serialize_data (bool): Whether to hold the data as serialized
            objects so that data loader workers can share memory with the
            master process instead of copying. Defaults to True.
        pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
        test_mode (bool): ``test_mode=True`` means the test phase.
            Defaults to False.
        lazy_init (bool): Whether to defer loading the annotation file.
            When only the meta information is needed (e.g. visualization),
            set ``lazy_init=True`` to skip loading annotations at
            construction time. Defaults to False.
        max_refetch (int): Maximum number of extra attempts to fetch a valid
            sample when ``prepare_data`` returns None. Defaults to 1000.
        classes (str | Sequence[str], optional): Names of the classes.

            - A string is taken as a file path whose every line is a class
              name.
            - A list or tuple of strings gives the class names directly.
            - None means the categories come from ``metainfo``, the
              annotation file or the class attribute ``METAINFO``.

            Defaults to None.
    """  # noqa: E501

    def __init__(self,
                 ann_file: str = '',
                 metainfo: Optional[dict] = None,
                 data_root: str = '',
                 data_prefix: Union[str, dict] = '',
                 filter_cfg: Optional[dict] = None,
                 indices: Optional[Union[int, Sequence[int]]] = None,
                 serialize_data: bool = True,
                 pipeline: Sequence = (),
                 test_mode: bool = False,
                 lazy_init: bool = False,
                 max_refetch: int = 1000,
                 classes: Union[str, Sequence[str], None] = None):
        # A bare string prefix is interpreted as the image directory.
        if isinstance(data_prefix, str):
            data_prefix = dict(img_path=expanduser(data_prefix))

        super().__init__(
            ann_file=expanduser(ann_file),
            # Fold the legacy ``classes`` argument into ``metainfo``.
            metainfo=self._compat_classes(metainfo, classes),
            data_root=data_root,
            data_prefix=data_prefix,
            filter_cfg=filter_cfg,
            indices=indices,
            serialize_data=serialize_data,
            pipeline=pipeline,
            test_mode=test_mode,
            lazy_init=lazy_init,
            max_refetch=max_refetch)

    @property
    def img_prefix(self):
        """The ``img_path`` entry of ``data_prefix``."""
        return self.data_prefix['img_path']

    @property
    def CLASSES(self):
        """The ``classes`` entry of the meta info, or None if it is unset."""
        return self._metainfo.get('classes', None)

    @property
    def class_to_idx(self):
        """Mapping from class name to its position in ``CLASSES``.

        Returns:
            dict: mapping from class name to class index.
        """
        return {name: index for index, name in enumerate(self.CLASSES)}

    def get_gt_labels(self):
        """Collect the ``gt_label`` of every sample.

        Returns:
            np.ndarray: one label entry per sample, in dataset order.
        """
        labels = [
            self.get_data_info(index)['gt_label']
            for index in range(len(self))
        ]
        return np.array(labels)

    def get_cat_ids(self, idx: int) -> List[int]:
        """Return the category id of one sample as a single-element list.

        Args:
            idx (int): Index of data.

        Returns:
            cat_ids (List[int]): ``[int(gt_label)]`` of the sample.
        """
        data_info = self.get_data_info(idx)
        return [int(data_info['gt_label'])]

    def _compat_classes(self, metainfo, classes):
        """Merge the old style ``classes`` argument into ``metainfo``.

        Entries already present in ``metainfo`` take precedence over the
        names derived from ``classes``.
        """
        if metainfo is None:
            metainfo = {}

        if classes is None:
            return metainfo

        if isinstance(classes, str):
            # A string is a path to a file listing one class name per line.
            names = mmengine.list_from_file(expanduser(classes))
        elif isinstance(classes, (tuple, list)):
            names = classes
        else:
            raise ValueError(f'Unsupported type {type(classes)} of classes.')

        return {'classes': tuple(names), **metainfo}

    def full_init(self):
        """Run the parent ``full_init`` and derive ``classes`` if needed."""
        super().full_init()

        meta = self._metainfo
        # Standard OpenMMLab 2.0 annotations carry ``categories``; convert
        # them to the internal ``classes`` tuple unless one is already set.
        if 'classes' in meta or 'categories' not in meta:
            return
        ordered = sorted(meta['categories'], key=lambda cat: cat['id'])
        meta['classes'] = tuple(cat['category_name'] for cat in ordered)

    def __repr__(self):
        """Summarise the dataset as a multi-line string.

        Returns:
            str: Formatted string.
        """
        title = 'Dataset ' + self.__class__.__name__

        if self._fully_initialized:
            size_line = f'Number of samples: \t{self.__len__()}'
        else:
            size_line = "Haven't been initialized"

        if self.CLASSES is None:
            class_line = 'The `CLASSES` meta info is not set.'
        else:
            class_line = f'Number of categories: \t{len(self.CLASSES)}'

        entries = [size_line, class_line, *self.extra_repr()]

        transforms = self.pipeline.transforms
        if len(transforms) > 0:
            entries.append('With transforms:')
            entries.extend(f' {t}' for t in transforms)

        indent = ' ' * 4
        return '\n'.join([title] + [indent + entry for entry in entries])

    def extra_repr(self) -> List[str]:
        """Extra lines appended to the ``repr`` of the dataset."""
        return [
            f'Annotation file: \t{self.ann_file}',
            f'Prefix of images: \t{self.img_prefix}',
        ]
|
mmpl/datasets/builder.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
from mmpl.registry import DATASETS
|
3 |
+
|
4 |
+
|
5 |
+
def build_dataset(cfg):
    """Build a dataset from a config through the ``DATASETS`` registry.

    Args:
        cfg: Dataset config, forwarded unchanged to ``DATASETS.build``.
            The examples below select the dataset class with a ``type``
            key and pass the remaining keys as constructor arguments.

    Returns:
        Whatever ``DATASETS.build(cfg)`` returns.

    Examples:
        NOTE(review): the example assumes a dataset type named ``'MNIST'``
        is registered in ``DATASETS``; no such class is visible in this
        package -- confirm before relying on it.

        >>> from mmpl.datasets import build_dataset
        >>> mnist_train = build_dataset(
        ...     dict(type='MNIST', data_prefix='data/mnist/', test_mode=False))
        >>> print(mnist_train)
        Dataset MNIST
            Number of samples: 60000
            Number of categories: 10
            Prefix of data: data/mnist/
        >>> mnist_test = build_dataset(
        ...     dict(type='MNIST', data_prefix='data/mnist/', test_mode=True))
        >>> print(mnist_test)
        Dataset MNIST
            Number of samples: 10000
            Number of categories: 10
            Prefix of data: data/mnist/
    """
    return DATASETS.build(cfg)
|
mmpl/datasets/custom.py
ADDED
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
|
3 |
+
|
4 |
+
from mmengine.fileio import (BaseStorageBackend, get_file_backend,
|
5 |
+
list_from_file)
|
6 |
+
from mmengine.logging import MMLogger
|
7 |
+
|
8 |
+
from mmcls.registry import DATASETS
|
9 |
+
from .base_dataset import BaseDataset
|
10 |
+
|
11 |
+
|
12 |
+
def find_folders(
    root: str,
    backend: Optional[BaseStorageBackend] = None
) -> Tuple[List[str], Dict[str, int]]:
    """Find classes by folders under a root.

    Args:
        root (string): root directory of folders
        backend (BaseStorageBackend | None): The file backend of the root.
            If None, auto infer backend from the root path. Defaults to None.

    Returns:
        Tuple[List[str], Dict[str, int]]:

        - folders: The sorted names of the sub folders directly under the
          root.
        - folder_to_idx: The map from folder name to class idx, i.e. the
          position of the folder in the sorted list.
    """
    # Pre-build file backend to prevent verbose file backend inference.
    backend = backend or get_file_backend(root, enable_singleton=True)
    # Only immediate sub-directories are listed; sorting makes the class
    # indices independent of the backend's listing order.
    folders = sorted(
        backend.list_dir_or_file(
            root,
            list_dir=True,
            list_file=False,
            recursive=False,
        ))
    folder_to_idx = {name: idx for idx, name in enumerate(folders)}
    return folders, folder_to_idx
|
41 |
+
|
42 |
+
|
43 |
+
def get_samples(
    root: str,
    folder_to_idx: Dict[str, int],
    is_valid_file: Callable,
    backend: Optional[BaseStorageBackend] = None,
) -> Tuple[list, set]:
    """Make dataset by walking all images under a root.

    Args:
        root (string): root directory of folders
        folder_to_idx (dict): the map from class name to class idx
        is_valid_file (Callable): A function that takes path of a file
            and check if the file is a valid sample file.
        backend (BaseStorageBackend | None): The file backend of the root.
            If None, auto infer backend from the root path. Defaults to None.

    Returns:
        Tuple[list, set]:

        - samples: a list of tuple where each element is (image, class_idx);
          the image path is relative to ``root`` (``folder/file``).
        - empty_folders: The folders don't have any valid files.
    """
    samples = []
    available_classes = set()
    # Pre-build file backend to prevent verbose file backend inference.
    backend = backend or get_file_backend(root, enable_singleton=True)

    # Folders and files are visited in sorted order so the sample list is
    # deterministic regardless of the backend's listing order.
    for folder_name in sorted(folder_to_idx):
        _dir = backend.join_path(root, folder_name)
        files = backend.list_dir_or_file(
            _dir,
            list_dir=False,
            list_file=True,
            recursive=True,
        )
        class_idx = folder_to_idx[folder_name]
        for file in sorted(files):
            if is_valid_file(file):
                path = backend.join_path(folder_name, file)
                samples.append((path, class_idx))
                available_classes.add(folder_name)

    empty_folders = set(folder_to_idx) - available_classes

    return samples, empty_folders
|
88 |
+
|
89 |
+
|
90 |
+
# NOTE(review): ``DATASETS`` is imported from ``mmcls.registry`` at the top
# of this module, while the sibling dataset modules import it from
# ``mmpl.registry`` -- confirm which registry this class should join.
@DATASETS.register_module()
class CustomDataset(BaseDataset):
    """Custom dataset for classification.

    The dataset supports two kinds of annotation format.

    1. An annotation file is provided, and each line indicates a sample:

       The sample files: ::

           data_prefix/
           ├── folder_1
           │   ├── xxx.png
           │   ├── xxy.png
           │   └── ...
           └── folder_2
               ├── 123.png
               ├── nsdf3.png
               └── ...

       The annotation file (the first column is the image path and the second
       column is the index of category): ::

            folder_1/xxx.png 0
            folder_1/xxy.png 1
            folder_2/123.png 5
            folder_2/nsdf3.png 3
            ...

       Blank lines in the annotation file are ignored.

       Please specify the name of categories by the argument ``classes``
       or ``metainfo``.

    2. The samples are arranged in the specific way: ::

           data_prefix/
           ├── class_x
           │   ├── xxx.png
           │   ├── xxy.png
           │   ├── ...
           │   └── xxz.png
           └── class_y
               ├── 123.png
               ├── nsdf3.png
               ├── ...
               └── asd932_.png

    If the ``ann_file`` is specified, the dataset will be generated by the
    first way, otherwise, try the second way.

    Args:
        ann_file (str): Annotation file path. Defaults to ''.
        metainfo (dict, optional): Meta information for dataset, such as class
            information. Defaults to None.
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``. Defaults to ''.
        data_prefix (str | dict): Prefix for the data. Defaults to ''.
        extensions (Sequence[str]): A sequence of allowed extensions
            (matched case-insensitively). Defaults to
            ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif').
        lazy_init (bool): Whether to defer loading the annotation file.
            In some cases, such as visualization, only the meta information
            of the dataset is needed; set ``lazy_init=True`` to skip loading
            annotations at construction time. Defaults to False.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 ann_file: str = '',
                 metainfo: Optional[dict] = None,
                 data_root: str = '',
                 data_prefix: Union[str, dict] = '',
                 extensions: Sequence[str] = ('.jpg', '.jpeg', '.png', '.ppm',
                                              '.bmp', '.pgm', '.tif'),
                 lazy_init: bool = False,
                 **kwargs):
        assert (ann_file or data_prefix or data_root), \
            'One of `ann_file`, `data_root` and `data_prefix` must '\
            'be specified.'

        # Lower-cased and de-duplicated; ``is_valid_file`` compares against
        # the lower-cased file name.
        self.extensions = tuple({ext.lower() for ext in extensions})

        super().__init__(
            # The base class requires string ann_file but this class doesn't
            ann_file=ann_file,
            metainfo=metainfo,
            data_root=data_root,
            data_prefix=data_prefix,
            # Force to lazy_init for some modification before loading data.
            lazy_init=True,
            **kwargs)

        # Full initialize the dataset.
        if not lazy_init:
            self.full_init()

    def _find_samples(self):
        """Find samples by scanning the sub folders of ``img_prefix``.

        Sets ``classes`` in the meta info when it is not given and stores
        ``folder_to_idx`` on the instance.

        Returns:
            list: ``(relative_path, class_idx)`` tuples.

        Raises:
            RuntimeError: If no valid file is found.
        """
        classes, folder_to_idx = find_folders(self.img_prefix)
        samples, empty_classes = get_samples(
            self.img_prefix,
            folder_to_idx,
            is_valid_file=self.is_valid_file,
        )

        if len(samples) == 0:
            raise RuntimeError(
                f'Found 0 files in subfolders of: {self.data_prefix}. '
                f'Supported extensions are: {",".join(self.extensions)}')

        if self.CLASSES is not None:
            assert len(self.CLASSES) == len(classes), \
                f"The number of subfolders ({len(classes)}) doesn't match " \
                f'the number of specified classes ({len(self.CLASSES)}). ' \
                'Please check the data folder.'
        else:
            # No class names were supplied: use the folder names.
            self._metainfo['classes'] = tuple(classes)

        if empty_classes:
            logger = MMLogger.get_current_instance()
            logger.warning(
                'Found no valid file in the folder '
                f'{", ".join(empty_classes)}. '
                f"Supported extensions are: {', '.join(self.extensions)}")

        self.folder_to_idx = folder_to_idx

        return samples

    def load_data_list(self):
        """Load image paths and gt_labels.

        Returns:
            list[dict]: One ``{'img_path': ..., 'gt_label': ...}`` dict per
            sample, with ``img_path`` joined onto ``img_prefix``.
        """
        if not self.ann_file:
            samples = self._find_samples()
        else:
            lines = list_from_file(self.ann_file)
            # Each line is "<path> <label>"; split on the last space so paths
            # containing spaces still work. Blank lines are skipped because
            # they would otherwise break the two-value unpacking below.
            samples = [
                x.strip().rsplit(' ', 1) for x in lines if x.strip()
            ]

        # Pre-build file backend to prevent verbose file backend inference.
        backend = get_file_backend(self.img_prefix, enable_singleton=True)
        data_list = []
        for filename, gt_label in samples:
            img_path = backend.join_path(self.img_prefix, filename)
            info = {'img_path': img_path, 'gt_label': int(gt_label)}
            data_list.append(info)
        return data_list

    def is_valid_file(self, filename: str) -> bool:
        """Check if a file is a valid sample (by case-insensitive suffix)."""
        return filename.lower().endswith(self.extensions)
|
mmpl/datasets/nwpu_ins_dataset.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
|
3 |
+
from mmpl.registry import DATASETS
|
4 |
+
from mmdet.datasets.coco import CocoDataset
|
5 |
+
|
6 |
+
|
7 |
+
@DATASETS.register_module()
class NWPUInsSegDataset(CocoDataset):
    """Instance segmentation dataset for NWPU data in COCO annotation format.

    NOTE(review): the ten classes in ``METAINFO`` look like the NWPU VHR-10
    categories -- confirm against the annotation files.
    """

    METAINFO = {
        'classes': ['airplane', 'ship', 'storage_tank', 'baseball_diamond',
                    'tennis_court', 'basketball_court', 'ground_track_field',
                    'harbor', 'bridge', 'vehicle'],
        # One RGB-like triple per class, in the same order as ``classes``.
        'palette': [(220, 20, 60), (255, 0, 0), (0, 0, 142), (0, 0, 70),
                    (0, 60, 100), (0, 80, 100), (0, 0, 230),
                    (119, 11, 32), (0, 255, 0), (0, 0, 255)]
    }

    def filter_data(self) -> List[dict]:
        """Filter annotations according to filter_cfg.

        Nothing is filtered in test mode or when ``filter_cfg`` is None.
        Otherwise ``filter_cfg['filter_empty_gt']`` (default False) drops
        images without usable annotations and ``filter_cfg['min_size']``
        (default 0) drops images whose shorter side is smaller than it.

        Returns:
            List[dict]: Filtered results.
        """
        if self.test_mode or self.filter_cfg is None:
            return self.data_list

        filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False)
        min_size = self.filter_cfg.get('min_size', 0)

        # obtain images that contain annotation
        ids_with_ann = {data_info['img_id'] for data_info in self.data_list}
        # obtain images that contain annotations of the required categories
        ids_in_cat = set()
        for class_id in self.cat_ids:
            ids_in_cat |= set(self.cat_img_map[class_id])
        # merge the image id sets of the two conditions and use the merged set
        # to filter out images if filter_empty_gt=True
        ids_in_cat &= ids_with_ann

        valid_data_infos = []
        for data_info in self.data_list:
            img_id = data_info['img_id']
            width = data_info['width']
            height = data_info['height']
            # ``all`` of an empty sequence is True, so an image without any
            # instance is treated like an image whose instances are all
            # flagged as ignored.
            all_is_crowd = all(
                instance['ignore_flag'] == 1
                for instance in data_info['instances'])
            if filter_empty_gt and (img_id not in ids_in_cat or all_is_crowd):
                continue
            if min(width, height) >= min_size:
                valid_data_infos.append(data_info)

        return valid_data_infos
|
mmpl/datasets/pl_datamodule.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from mmpl.registry import DATASETS
|
2 |
+
import lightning.pytorch as pl
|
3 |
+
from torch.utils.data import DataLoader
|
4 |
+
from .builder import build_dataset
|
5 |
+
from mmengine.registry import FUNCTIONS
|
6 |
+
from functools import partial
|
7 |
+
|
8 |
+
|
9 |
+
def get_collate_fn(dataloader_cfg):
    """Build the collate function described by ``dataloader_cfg``.

    The ``'collate_fn'`` entry is removed from ``dataloader_cfg`` (callers
    pass the remaining keys straight to ``DataLoader`` together with the
    returned function). The nested collate config itself is not modified,
    so it can be shared or reused.

    Args:
        dataloader_cfg (dict): Dataloader config. Its optional
            ``'collate_fn'`` entry is a dict with a ``'type'`` key naming a
            function in the ``FUNCTIONS`` registry; the other keys are bound
            as keyword arguments. Defaults to ``dict(type='pseudo_collate')``
            when the entry is missing.

    Returns:
        functools.partial: The registered function with the extra keyword
        arguments bound.

    Raises:
        KeyError: If the collate config has no ``'type'`` key.
        TypeError: If no function is registered under that name.
    """
    # Copy before popping ``type`` so the caller's nested dict stays intact.
    collate_fn_cfg = dict(
        dataloader_cfg.pop('collate_fn', dict(type='pseudo_collate')))
    collate_fn_type = collate_fn_cfg.pop('type')
    collate_fn = FUNCTIONS.get(collate_fn_type)
    if collate_fn is None:
        # ``partial(None)`` would raise an opaque TypeError; keep the
        # exception type but say what is actually wrong.
        raise TypeError(
            f'collate_fn type {collate_fn_type!r} is not registered in '
            'FUNCTIONS')
    collate_fn = partial(collate_fn, **collate_fn_cfg)  # type: ignore
    return collate_fn
|
15 |
+
|
16 |
+
|
17 |
+
@DATASETS.register_module()
class PLDataModule(pl.LightningDataModule):
    """Lightning data module driven by dataloader config dicts.

    Each ``*_loader`` argument is a dict of ``DataLoader`` keyword arguments
    plus a ``'dataset'`` entry (built with ``build_dataset``) and an optional
    ``'collate_fn'`` entry (resolved by ``get_collate_fn``). The stored
    configs are never modified, so ``setup`` and the ``*_dataloader`` hooks
    can be called repeatedly.

    Args:
        train_loader (dict, optional): Config of the training dataloader.
        val_loader (dict, optional): Config of the validation dataloader.
        test_loader (dict, optional): Config of the test dataloader.
        predict_loader (dict, optional): Config of the predict dataloader.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 train_loader=None,
                 val_loader=None,
                 test_loader=None,
                 predict_loader=None,
                 **kwargs
                 ):
        super().__init__()
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.predict_loader = predict_loader
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None
        self.predict_dataset = None

    def prepare_data(self):
        """No download or preparation step is needed."""
        pass

    @staticmethod
    def _build_dataset(loader_cfg):
        """Build the dataset of ``loader_cfg`` without consuming the config."""
        return build_dataset(loader_cfg['dataset'])

    @staticmethod
    def _build_dataloader(dataset, loader_cfg):
        """Create a ``DataLoader`` for ``dataset`` from ``loader_cfg``."""
        # Work on a copy: ``get_collate_fn`` pops entries from the dict it is
        # given, and ``dataset`` is not a ``DataLoader`` argument.
        kwargs = {
            key: value
            for key, value in loader_cfg.items() if key != 'dataset'
        }
        collate_cfg = kwargs.get('collate_fn')
        if isinstance(collate_cfg, dict):
            # Protect the nested config from in-place edits as well.
            kwargs['collate_fn'] = dict(collate_cfg)
        collate_fn = get_collate_fn(kwargs)
        return DataLoader(dataset, collate_fn=collate_fn, **kwargs)

    def setup(self, stage: str):
        """Build the datasets needed for ``stage``.

        Args:
            stage (str): ``'fit'``, ``'validate'`` (``'val'`` is accepted as
                an alias), ``'test'`` or ``'predict'``.
        """
        if stage == "fit":
            self.train_dataset = self._build_dataset(self.train_loader)
            if self.val_loader is not None:
                self.val_dataset = self._build_dataset(self.val_loader)
        # Lightning passes "validate"; "val" is kept for existing callers.
        if stage in ("val", "validate"):
            if self.val_loader is not None:
                self.val_dataset = self._build_dataset(self.val_loader)
        if stage == "test":
            if self.test_loader is not None:
                self.test_dataset = self._build_dataset(self.test_loader)
        if stage == "predict":
            if self.predict_loader is not None:
                self.predict_dataset = self._build_dataset(
                    self.predict_loader)

    def train_dataloader(self):
        """Return the training ``DataLoader``."""
        return self._build_dataloader(self.train_dataset, self.train_loader)

    def val_dataloader(self):
        """Return the validation ``DataLoader``."""
        return self._build_dataloader(self.val_dataset, self.val_loader)

    def test_dataloader(self):
        """Return the test ``DataLoader``."""
        return self._build_dataloader(self.test_dataset, self.test_loader)

    def predict_dataloader(self):
        """Return the predict ``DataLoader``."""
        return self._build_dataloader(
            self.predict_dataset, self.predict_loader)
|
mmpl/datasets/ssdd_ins_dataset.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
from mmpl.registry import DATASETS
|
3 |
+
from mmdet.datasets.coco import CocoDataset
|
4 |
+
|
5 |
+
|
6 |
+
@DATASETS.register_module()
class SSDDInsSegDataset(CocoDataset):
    """Single-class (``ship``) instance segmentation dataset for SSDD data
    in COCO annotation format.
    """

    METAINFO = {
        'classes': ['ship'],
        'palette': [(0, 0, 255)]
    }

    def filter_data(self) -> List[dict]:
        """Filter annotations according to filter_cfg.

        Nothing is filtered when ``filter_cfg`` is None. Otherwise
        ``filter_cfg['filter_empty_gt']`` (default False) drops images
        without usable annotations and ``filter_cfg['min_size']``
        (default 0) drops images whose shorter side is smaller than it.

        Returns:
            List[dict]: Filtered results.
        """
        # NOTE: filtering is applied regardless of ``test_mode``; the
        # test-mode early return was commented out in the original code.
        if self.filter_cfg is None:
            return self.data_list

        filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False)
        min_size = self.filter_cfg.get('min_size', 0)

        # obtain images that contain annotation
        ids_with_ann = {data_info['img_id'] for data_info in self.data_list}
        # obtain images that contain annotations of the required categories
        ids_in_cat = set()
        for class_id in self.cat_ids:
            ids_in_cat |= set(self.cat_img_map[class_id])
        # merge the image id sets of the two conditions and use the merged set
        # to filter out images if filter_empty_gt=True
        ids_in_cat &= ids_with_ann

        valid_data_infos = []
        for data_info in self.data_list:
            img_id = data_info['img_id']
            width = data_info['width']
            height = data_info['height']
            # ``all`` of an empty sequence is True, so an image without any
            # instance is treated like an image whose instances are all
            # flagged as ignored.
            all_is_crowd = all(
                instance['ignore_flag'] == 1
                for instance in data_info['instances'])
            if filter_empty_gt and (img_id not in ids_in_cat or all_is_crowd):
                continue
            if min(width, height) >= min_size:
                valid_data_infos.append(data_info)

        return valid_data_infos
|
mmpl/datasets/transforms/__init__.py
ADDED
File without changes
|
mmpl/datasets/transforms/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (162 Bytes). View file
|
|
mmpl/datasets/utils.py
ADDED
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import gzip
|
3 |
+
import hashlib
|
4 |
+
import os
|
5 |
+
import os.path
|
6 |
+
import shutil
|
7 |
+
import tarfile
|
8 |
+
import tempfile
|
9 |
+
import urllib.error
|
10 |
+
import urllib.request
|
11 |
+
import zipfile
|
12 |
+
|
13 |
+
from mmengine.fileio import LocalBackend, get_file_backend
|
14 |
+
|
15 |
+
__all__ = [
|
16 |
+
'rm_suffix', 'check_integrity', 'download_and_extract_archive',
|
17 |
+
'open_maybe_compressed_file'
|
18 |
+
]
|
19 |
+
|
20 |
+
|
21 |
+
def rm_suffix(s, suffix=None):
    """Cut ``s`` at the last occurrence of ``suffix``.

    Args:
        s (str): The string to shorten.
        suffix (str | None): Marker to search for from the right. When
            ``None``, the last ``'.'`` is used, which strips a file
            extension.

    Returns:
        str: Everything before the last occurrence of the marker, or ``s``
        unchanged when the marker does not occur at all.
    """
    marker = '.' if suffix is None else suffix
    cut = s.rfind(marker)
    # ``rfind`` returns -1 when the marker is absent; slicing with -1 would
    # silently drop the last character, so return the input untouched.
    if cut == -1:
        return s
    return s[:cut]
|
26 |
+
|
27 |
+
|
28 |
+
def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024):
    """Return the hexadecimal MD5 digest of the file at ``fpath``.

    Args:
        fpath (str): Path handed to ``get_file_backend`` to pick a backend.
        chunk_size (int): Number of bytes read per step when the backend is
            a ``LocalBackend``. Defaults to 1 MiB.

    Returns:
        str: The hex digest.
    """
    digest = hashlib.md5()
    backend = get_file_backend(fpath, enable_singleton=True)

    if not isinstance(backend, LocalBackend):
        # Other backends: hash the whole payload returned by ``get``.
        digest.update(backend.get(fpath))
        return digest.hexdigest()

    # Local file: feed the hash chunk by chunk.
    with open(fpath, 'rb') as stream:
        while True:
            block = stream.read(chunk_size)
            if block == b'':
                break
            digest.update(block)
    return digest.hexdigest()
|
39 |
+
|
40 |
+
|
41 |
+
def check_md5(fpath, md5, **kwargs):
    """Tell whether ``md5`` equals the MD5 hex digest of ``fpath``.

    Extra keyword arguments are forwarded to ``calculate_md5``.
    """
    actual = calculate_md5(fpath, **kwargs)
    return md5 == actual
|
43 |
+
|
44 |
+
|
45 |
+
def check_integrity(fpath, md5=None):
    """Check that ``fpath`` is an existing file and, optionally, its MD5.

    Args:
        fpath (str): Path to test.
        md5 (str | None): Expected MD5 hex digest. When ``None`` only the
            existence of a regular file is checked.

    Returns:
        bool: ``False`` if ``fpath`` is not a file; ``True`` if it is a file
        and no digest was given; otherwise the result of ``check_md5``.
    """
    is_file = os.path.isfile(fpath)
    if is_file and md5 is not None:
        return check_md5(fpath, md5)
    return is_file
|
51 |
+
|
52 |
+
|
53 |
+
def download_url_to_file(url, dst, hash_prefix=None, progress=True):
    """Download object at the given URL to a local path.

    Modified from
    https://pytorch.org/docs/stable/hub.html#torch.hub.download_url_to_file

    The payload is first written to a temporary file in the directory of
    ``dst`` and only moved to ``dst`` once the download (and the optional
    hash check) succeeded, so a broken download never overrides an existing
    file.

    Args:
        url (str): URL of the object to download
        dst (str): Full path where object will be saved,
            e.g. ``/tmp/temporary_file``
        hash_prefix (string, optional): If not None, the SHA256 hex digest
            of the downloaded file should start with ``hash_prefix``.
            Defaults to None.
        progress (bool): whether or not to display a progress bar.
            Defaults to True

    Raises:
        RuntimeError: If ``hash_prefix`` is given and the SHA256 digest of
            the downloaded data does not start with it.
    """
    # Imported before any resource is acquired, so a failing import cannot
    # leave an open connection or a stray temporary file behind.
    import rich.progress

    columns = [
        rich.progress.DownloadColumn(),
        rich.progress.BarColumn(bar_width=None),
        rich.progress.TimeRemainingColumn(),
    ]
    sha256 = hashlib.sha256() if hash_prefix is not None else None

    dst = os.path.expanduser(dst)
    dst_dir = os.path.dirname(dst)

    req = urllib.request.Request(url)
    # The context manager closes the response on every exit path.
    with urllib.request.urlopen(req) as u:
        file_size = None
        meta = u.info()
        if hasattr(meta, 'getheaders'):
            content_length = meta.getheaders('Content-Length')
        else:
            content_length = meta.get_all('Content-Length')
        if content_length is not None and len(content_length) > 0:
            file_size = int(content_length[0])

        f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
        try:
            with rich.progress.Progress(*columns) as pbar:
                task = pbar.add_task(
                    'download', total=file_size, visible=progress)
                while True:
                    buffer = u.read(8192)
                    if len(buffer) == 0:
                        break
                    f.write(buffer)
                    if sha256 is not None:
                        sha256.update(buffer)
                    pbar.update(task, advance=len(buffer))

            f.close()
            if sha256 is not None:
                digest = sha256.hexdigest()
                if digest[:len(hash_prefix)] != hash_prefix:
                    raise RuntimeError(
                        'invalid hash value (expected "{}", got "{}")'.format(
                            hash_prefix, digest))
            shutil.move(f.name, dst)
        finally:
            # After a successful move the temporary name no longer exists;
            # on failure this removes the partial download.
            f.close()
            if os.path.exists(f.name):
                os.remove(f.name)
|
118 |
+
|
119 |
+
|
120 |
+
def download_url(url, root, filename=None, md5=None):
    """Fetch ``url`` into the directory ``root`` unless it is already there.

    Args:
        url (str): URL to download file from.
        root (str): Directory to place downloaded file in. It is created
            when missing.
        filename (str | None): Name to save the file under.
            If empty or None, the basename of the URL is used.
        md5 (str | None): MD5 checksum of the download.
            If md5 is None, download without md5 check.

    Raises:
        RuntimeError: If, after downloading, the file is missing or fails
            the MD5 check.
    """
    root = os.path.expanduser(root)
    fpath = os.path.join(root, filename or os.path.basename(url))

    os.makedirs(root, exist_ok=True)

    # Nothing to do when a valid copy is already on disk.
    if check_integrity(fpath, md5):
        print(f'Using downloaded and verified file: {fpath}')
        return

    try:
        print(f'Downloading {url} to {fpath}')
        download_url_to_file(url, fpath)
    except (urllib.error.URLError, IOError) as e:
        if not url.startswith('https'):
            raise e
        # Retry once over plain http.
        url = url.replace('https:', 'http:')
        print('Failed download. Trying https -> http instead.'
              f' Downloading {url} to {fpath}')
        download_url_to_file(url, fpath)

    # check integrity of downloaded file
    if not check_integrity(fpath, md5):
        raise RuntimeError('File not found or corrupted.')
|
155 |
+
|
156 |
+
|
157 |
+
def _is_tarxz(filename):
    """Return True when ``filename`` ends with ``.tar.xz``."""
    has_suffix = filename.endswith('.tar.xz')
    return has_suffix


def _is_tar(filename):
    """Return True when ``filename`` ends with ``.tar``."""
    has_suffix = filename.endswith('.tar')
    return has_suffix


def _is_targz(filename):
    """Return True when ``filename`` ends with ``.tar.gz``."""
    has_suffix = filename.endswith('.tar.gz')
    return has_suffix


def _is_tgz(filename):
    """Return True when ``filename`` ends with ``.tgz``."""
    has_suffix = filename.endswith('.tgz')
    return has_suffix


def _is_gzip(filename):
    """Return True for a ``.gz`` name that is not a ``.tar.gz`` name."""
    if not filename.endswith('.gz'):
        return False
    return not filename.endswith('.tar.gz')


def _is_zip(filename):
    """Return True when ``filename`` ends with ``.zip``."""
    has_suffix = filename.endswith('.zip')
    return has_suffix
|
179 |
+
|
180 |
+
|
181 |
+
def extract_archive(from_path, to_path=None, remove_finished=False):
    """Extract an archive, picking the extractor from the file name suffix.

    Handled suffixes: ``.tar``, ``.tar.gz``, ``.tgz``, ``.tar.xz``, a plain
    ``.gz`` and ``.zip``.

    Args:
        from_path (str): Path of the archive.
        to_path (str | None): Directory to extract into. Defaults to the
            directory of ``from_path``. For a plain ``.gz`` file the output
            is a single file in that directory, named like the archive
            without its last extension.
        remove_finished (bool): Delete ``from_path`` after a successful
            extraction. Defaults to False.

    Raises:
        ValueError: If the suffix of ``from_path`` is not supported.
    """
    if to_path is None:
        to_path = os.path.dirname(from_path)

    # NOTE(security): ``extractall`` is called without any member filtering,
    # so member paths stored in the archive are trusted as-is. Only use this
    # on archives from sources you trust.
    if _is_tar(from_path):
        with tarfile.open(from_path, 'r') as tar:
            tar.extractall(path=to_path)
    elif _is_targz(from_path) or _is_tgz(from_path):
        with tarfile.open(from_path, 'r:gz') as tar:
            tar.extractall(path=to_path)
    elif _is_tarxz(from_path):
        with tarfile.open(from_path, 'r:xz') as tar:
            tar.extractall(path=to_path)
    elif _is_gzip(from_path):
        to_path = os.path.join(
            to_path,
            os.path.splitext(os.path.basename(from_path))[0])
        # Stream the decompressed data in chunks rather than holding the
        # whole payload in memory at once.
        with gzip.GzipFile(from_path) as zip_f, open(to_path, 'wb') as out_f:
            shutil.copyfileobj(zip_f, out_f)
    elif _is_zip(from_path):
        with zipfile.ZipFile(from_path, 'r') as z:
            z.extractall(to_path)
    else:
        raise ValueError(f'Extraction of {from_path} not supported')

    if remove_finished:
        os.remove(from_path)
|
208 |
+
|
209 |
+
|
210 |
+
def download_and_extract_archive(url,
                                 download_root,
                                 extract_root=None,
                                 filename=None,
                                 md5=None,
                                 remove_finished=False):
    """Download an archive with ``download_url`` and unpack it.

    Args:
        url (str): URL of the archive.
        download_root (str): Directory the archive is downloaded into.
        extract_root (str | None): Directory to extract into. Defaults to
            ``download_root``.
        filename (str | None): Name of the downloaded file. If empty or
            None, the basename of ``url`` is used.
        md5 (str | None): Forwarded to ``download_url``.
        remove_finished (bool): Forwarded to ``extract_archive``.
    """
    download_root = os.path.expanduser(download_root)
    extract_root = download_root if extract_root is None else extract_root
    filename = filename or os.path.basename(url)

    download_url(url, download_root, filename, md5)

    archive = os.path.join(download_root, filename)
    print(f'Extracting {archive} to {extract_root}')
    extract_archive(archive, extract_root, remove_finished)
|
227 |
+
|
228 |
+
|
229 |
+
def open_maybe_compressed_file(path: str):
    """Open ``path`` for binary reading, decompressing on the fly if needed.

    A string ending with ``'.gz'`` is opened with ``gzip.open`` and one
    ending with ``'.xz'`` with ``lzma.open``; any other string is opened
    with the builtin ``open``. An argument that is not a string is returned
    unchanged.
    """
    if not isinstance(path, str):
        return path

    if path.endswith('.gz'):
        import gzip
        opener = gzip.open
    elif path.endswith('.xz'):
        import lzma
        opener = lzma.open
    else:
        opener = open
    return opener(path, 'rb')
|
mmpl/datasets/whu_ins_dataset.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
from mmpl.registry import DATASETS
|
3 |
+
from mmdet.datasets.coco import CocoDataset
|
4 |
+
|
5 |
+
|
6 |
+
@DATASETS.register_module()
class WHUInsSegDataset(CocoDataset):
    """COCO-style instance segmentation dataset with a single ``building``
    class."""

    # One class, one palette entry.
    METAINFO = {
        'classes': ['building'],
        'palette': [(0, 255, 0)]
    }

    def filter_data(self) -> List[dict]:
        """Select the entries of ``self.data_list`` allowed by ``filter_cfg``.

        Two ``filter_cfg`` keys are read: ``filter_empty_gt`` (default
        ``False``) and ``min_size`` (default ``0``). There is no
        ``test_mode`` shortcut in this method, so the same filtering is
        applied regardless of that flag.

        Returns:
            List[dict]: ``self.data_list`` itself when ``filter_cfg`` is
            ``None``, otherwise a new list with the surviving entries.
        """
        cfg = self.filter_cfg
        if cfg is None:
            return self.data_list

        drop_empty = cfg.get('filter_empty_gt', False)
        min_side = cfg.get('min_size', 0)

        # Image ids present in the loaded data list.
        listed_ids = {info['img_id'] for info in self.data_list}
        # Image ids that have at least one annotation of a wanted category,
        # restricted to images that are actually in the data list.
        category_ids = set()
        for cat_id in self.cat_ids:
            category_ids.update(self.cat_img_map[cat_id])
        category_ids &= listed_ids

        kept = []
        for info in self.data_list:
            img_id = info['img_id']
            shorter_side = min(info['width'], info['height'])
            # True when every instance is flagged as ignored; also True for
            # an image without any instance.
            only_ignored = all([
                inst['ignore_flag'] == 1 for inst in info['instances']
            ])
            if drop_empty and (img_id not in category_ids or only_ignored):
                continue
            if shorter_side >= min_side:
                kept.append(info)

        return kept
|
mmpl/engine/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .runner import *
|
2 |
+
from .logger import *
|
3 |
+
from .hooks import *
|
4 |
+
from .visualization import *
|
5 |
+
from .strategies import *
|
mmpl/engine/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (254 Bytes). View file
|
|
mmpl/engine/hooks/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .builder import PL_HOOKS
|
2 |
+
from .pipeline_switch_hook import PipelineSwitchHook
|
3 |
+
from .yolov5_param_scheduler_hook import YOLOv5ParamSchedulerHook
|
4 |
+
from .ema_hook import EMAHook
|
5 |
+
from .param_scheduler_hook import ParamSchedulerHook
|
6 |
+
from .visualization_hook import DetVisualizationHook
|
mmpl/engine/hooks/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (500 Bytes). View file
|
|
mmpl/engine/hooks/__pycache__/builder.cpython-310.pyc
ADDED
Binary file (1.05 kB). View file
|
|
mmpl/engine/hooks/__pycache__/ema_hook.cpython-310.pyc
ADDED
Binary file (8.86 kB). View file
|
|
mmpl/engine/hooks/__pycache__/param_scheduler_hook.cpython-310.pyc
ADDED
Binary file (4.26 kB). View file
|
|
mmpl/engine/hooks/__pycache__/pipeline_switch_hook.cpython-310.pyc
ADDED
Binary file (1.55 kB). View file
|
|
mmpl/engine/hooks/__pycache__/visualization_hook.cpython-310.pyc
ADDED
Binary file (6.21 kB). View file
|
|
mmpl/engine/hooks/__pycache__/yolov5_param_scheduler_hook.cpython-310.pyc
ADDED
Binary file (3.85 kB). View file
|
|
mmpl/engine/hooks/builder.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import inspect
|
3 |
+
from typing import List, Union
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import lightning
|
8 |
+
|
9 |
+
from mmengine.config import Config, ConfigDict
|
10 |
+
from mmengine.device import is_npu_available
|
11 |
+
from mmpl.registry import HOOKS
|
12 |
+
|
13 |
+
|
14 |
+
def register_pl_hooks() -> List[str]:
    """Add every ``Callback`` class exposed by ``lightning.pytorch.callbacks``
    to the ``HOOKS`` registry.

    Attributes whose name starts with a double underscore are skipped.

    Returns:
        List[str]: Attribute names of the callbacks that were registered.
    """
    callbacks_module = lightning.pytorch.callbacks
    registered = []
    for attr_name in dir(callbacks_module):
        if attr_name.startswith('__'):
            continue
        candidate = getattr(callbacks_module, attr_name)
        is_callback_class = inspect.isclass(candidate) and issubclass(
            candidate, callbacks_module.Callback)
        if not is_callback_class:
            continue
        HOOKS.register_module(module=candidate)
        registered.append(attr_name)
    return registered


# Registration happens once, as a side effect of importing this module.
PL_HOOKS = register_pl_hooks()
|
mmpl/engine/hooks/ema_hook.py
ADDED
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import itertools
|
3 |
+
import logging
|
4 |
+
from typing import Dict, Optional, Any
|
5 |
+
|
6 |
+
from lightning import Callback
|
7 |
+
from lightning.pytorch.utilities.types import STEP_OUTPUT
|
8 |
+
from mmengine.logging import print_log
|
9 |
+
from mmengine.model import is_model_wrapper
|
10 |
+
from mmpl.registry import HOOKS, MODELS
|
11 |
+
|
12 |
+
|
13 |
+
|
14 |
+
@HOOKS.register_module()
class EMAHook(Callback):
    """Callback that keeps an averaged (EMA) copy of the model while training.

    Note:
        - Validation and test epochs run with the averaged weights: they are
          swapped into the source model when the epoch starts and swapped
          back when it ends.
        - In a saved checkpoint the entries of ``state_dict`` and the
          ``module.``-prefixed entries of ``ema_state_dict`` are exchanged,
          so ``state_dict`` carries the averaged weights.
        - ``begin_iter`` and ``begin_epoch`` cannot both be non-zero.

    Args:
        ema_type (str): Registry name of the averaged-model class to build
            through ``MODELS``. Defaults to 'ExponentialMovingAverage'.
        strict_load (bool): ``strict`` flag used when loading state dicts
            into the averaged model. Defaults to False.
        begin_iter (int): Iteration from which the averaged model is
            updated. Defaults to 0.
        begin_epoch (int): Epoch from which the averaged model is updated.
            Defaults to 0.
        **kwargs: Extra entries added to the config used to build the
            averaged model.
    """

    priority = 'NORMAL'

    def __init__(self,
                 ema_type: str = 'ExponentialMovingAverage',
                 strict_load: bool = False,
                 begin_iter: int = 0,
                 begin_epoch: int = 0,
                 **kwargs):
        self.strict_load = strict_load
        self.ema_cfg = dict(type=ema_type, **kwargs)
        assert begin_iter == 0 or begin_epoch == 0, (
            '`begin_iter` and `begin_epoch` should not be both set.')
        assert begin_iter >= 0, (
            '`begin_iter` must larger than or equal to 0, '
            f'but got begin_iter: {begin_iter}')
        assert begin_epoch >= 0, (
            '`begin_epoch` must larger than or equal to 0, '
            f'but got begin_epoch: {begin_epoch}')
        self.begin_iter = begin_iter
        self.begin_epoch = begin_epoch
        # Epoch-based start only when ``begin_epoch`` was given; otherwise
        # the start is iteration-based (from iteration 0 by default).
        self.enabled_by_epoch = self.begin_epoch > 0

    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Build the averaged model from ``pl_module``.

        Args:
            trainer: The trainer running the fit loop (unused here).
            pl_module: The module being trained. If it is a model wrapper,
                its ``module`` attribute is used as the source model.
        """
        self.src_model = (
            pl_module.module if is_model_wrapper(pl_module) else pl_module)
        self.ema_model = MODELS.build(
            self.ema_cfg, default_args=dict(model=self.src_model))

    def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Assert that the configured start lies within the training length.

        Args:
            trainer: Provides ``max_epochs``, ``max_steps`` and the train
                dataloader used for the check.
            pl_module: Unused.
        """
        if self.enabled_by_epoch:
            assert self.begin_epoch <= trainer.max_epochs, (
                'self.begin_epoch should be smaller than or equal to '
                f'runner.max_epochs: {trainer.max_epochs}, but got '
                f'begin_epoch: {self.begin_epoch}')
            return
        assert self.begin_iter <= trainer.max_steps or self.begin_iter <= trainer.max_epochs * len(trainer.train_dataloader), (
            'self.begin_iter should be smaller than or equal to '
            f'runner.max_iters: {trainer.max_steps}, but got '
            f'begin_iter: {self.begin_iter}')

    def on_train_batch_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
    ) -> None:
        """Update the averaged model after a training batch.

        Once averaging has started the averaged model is updated from the
        source model; before that, the source state is copied into it
        verbatim.

        Args:
            trainer: Used to decide whether averaging has started.
            pl_module: Unused.
            outputs: Unused.
            batch: Unused.
            batch_idx: Unused.
        """
        if self._ema_started(trainer):
            self.ema_model.update_parameters(self.src_model)
            return
        ema_state = self.ema_model.module.state_dict()
        src_state = self.src_model.state_dict()
        for key, ema_tensor in ema_state.items():
            ema_tensor.data.copy_(src_state[key].data)

    def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Swap the averaged weights into the source model before
        validation."""
        self._swap_ema_parameters()

    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Swap the weights back after validation."""
        self._swap_ema_parameters()

    def on_test_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Swap the averaged weights into the source model before test."""
        self._swap_ema_parameters()

    def on_test_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Swap the weights back after test."""
        self._swap_ema_parameters()

    def on_save_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
    ) -> None:
        """Store the averaged model's state in the checkpoint.

        Args:
            trainer: Unused.
            pl_module: Unused.
            checkpoint: Checkpoint dict, modified in place.
        """
        checkpoint['ema_state_dict'] = self.ema_model.state_dict()
        # Exchange the dict entries (not the live parameters) so that the
        # checkpoint's ``state_dict`` carries the averaged weights and can
        # be loaded directly for deployment.
        self._swap_ema_state_dict(checkpoint)

    def on_load_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
    ) -> None:
        """Restore the averaged model's state from a checkpoint.

        Args:
            trainer: Its ``_checkpoint_connector._loaded_checkpoint`` is
                consulted to choose the loading path.
            pl_module: Unused.
            checkpoint: Checkpoint dict; may be modified in place.
        """
        from mmengine.runner.checkpoint import load_state_dict
        nothing_loaded = not trainer._checkpoint_connector._loaded_checkpoint
        if 'ema_state_dict' in checkpoint and nothing_loaded:
            # Undo the exchange done at save time, then restore the
            # averaged model from its own entry.
            self._swap_ema_state_dict(checkpoint)
            self.ema_model.load_state_dict(
                checkpoint['ema_state_dict'], strict=self.strict_load)
            return

        # Support load checkpoint without ema state dict.
        if not trainer._checkpoint_connector._loaded_checkpoint:
            print_log(
                'There is no `ema_state_dict` in checkpoint. '
                '`EMAHook` will make a copy of `state_dict` as the '
                'initial `ema_state_dict`', 'current', logging.WARNING)
        # NOTE(review): this copy runs for every checkpoint that reaches
        # this branch, not only when the warning above is printed - confirm
        # that this is the intended nesting.
        load_state_dict(
            self.ema_model.module,
            copy.deepcopy(checkpoint['state_dict']),
            strict=self.strict_load)

    def _swap_ema_parameters(self) -> None:
        """Exchange tensor data between the source and the averaged model.

        Buffers take part in the exchange only when the averaged model's
        ``update_buffers`` flag is set.
        """
        ema_module = self.ema_model.module
        if self.ema_model.update_buffers:
            averaged = itertools.chain(ema_module.parameters(),
                                       ema_module.buffers())
            source = itertools.chain(self.src_model.parameters(),
                                     self.src_model.buffers())
        else:
            averaged = ema_module.parameters()
            source = self.src_model.parameters()
        for p_avg, p_src in zip(averaged, source):
            held = p_avg.data.clone()
            p_avg.data.copy_(p_src.data)
            p_src.data.copy_(held)

    def _swap_ema_state_dict(self, checkpoint):
        """Exchange the entries of ``state_dict`` and ``ema_state_dict``.

        Only ``ema_state_dict`` keys starting with ``module.`` take part;
        each is paired with the ``state_dict`` key lacking that prefix.
        """
        model_state = checkpoint['state_dict']
        ema_state = checkpoint['ema_state_dict']
        prefix = 'module.'
        for key in ema_state:
            if not key.startswith(prefix):
                continue
            bare = key[len(prefix):]
            ema_state[key], model_state[bare] = (model_state[bare],
                                                 ema_state[key])

    def _ema_started(self, trainer) -> bool:
        """Tell whether averaging is active at the current step.

        Args:
            trainer: Provides ``current_epoch`` and ``global_step``.

        Returns:
            bool: ``True`` once the next epoch (epoch-based start) or the
            next step (iteration-based start) has reached the configured
            begin value.
        """
        if self.enabled_by_epoch:
            return trainer.current_epoch + 1 >= self.begin_epoch
        return trainer.global_step + 1 >= self.begin_iter
|
mmpl/engine/hooks/param_scheduler_hook.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Optional, Union, Any
|
2 |
+
|
3 |
+
from lightning.pytorch.utilities.types import STEP_OUTPUT
|
4 |
+
from mmengine.optim import _ParamScheduler
|
5 |
+
from mmpl.registry import HOOKS
|
6 |
+
from mmengine.utils import is_list_of
|
7 |
+
from lightning import Callback
|
8 |
+
|
9 |
+
DATA_BATCH = Optional[Union[dict, tuple, list]]
|
10 |
+
|
11 |
+
|
12 |
+
@HOOKS.register_module()
class ParamSchedulerHook(Callback):
    """Callback that steps the schedulers returned by
    ``pl_module.lr_schedulers()``: iteration-based ones after every training
    batch, epoch-based ones after every training epoch, and epoch-based ones
    flagged with ``need_val_args`` after every validation epoch."""

    priority = 'LOW'

    @staticmethod
    def _for_each_group(param_schedulers, step_fn):
        """Call ``step_fn`` on each scheduler list in ``param_schedulers``.

        A single ``_ParamScheduler`` is wrapped into a list, a list is
        passed as is, and for a dict every value is passed in turn.

        Raises:
            TypeError: If ``param_schedulers`` is none of the above.
        """
        if isinstance(param_schedulers, _ParamScheduler):
            param_schedulers = [param_schedulers]
        if isinstance(param_schedulers, list):
            step_fn(param_schedulers)
        elif isinstance(param_schedulers, dict):
            for group in param_schedulers.values():
                step_fn(group)
        else:
            raise TypeError(
                'runner.param_schedulers should be list of ParamScheduler or '
                'a dict containing list of ParamScheduler, '
                f'but got {param_schedulers}')

    def on_train_batch_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
    ) -> None:
        """Step every scheduler that is not epoch-based.

        Args:
            trainer: Unused.
            pl_module: Source of the schedulers.
            outputs: Unused.
            batch: Unused.
            batch_idx: Unused.
        """
        schedulers = pl_module.lr_schedulers()
        if schedulers is None:
            return

        def step_by_iter(group):
            assert isinstance(group, list)
            for scheduler in group:
                if not scheduler.by_epoch:
                    scheduler.step()

        self._for_each_group(schedulers, step_by_iter)

    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Step every epoch-based scheduler.

        Args:
            trainer: Unused.
            pl_module: Source of the schedulers.
        """
        schedulers = pl_module.lr_schedulers()
        if schedulers is None:
            return

        def step_by_epoch(group):
            assert isinstance(group, list)
            for scheduler in group:
                if scheduler.by_epoch:
                    scheduler.step()

        self._for_each_group(schedulers, step_by_epoch)

    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Step epoch-based schedulers that have a truthy ``need_val_args``
        attribute, passing them ``trainer.callback_metrics``.

        Args:
            trainer: Source of the metrics.
            pl_module: Source of the schedulers.

        Note:
            Nothing happens when there are no schedulers or no metrics. A
            group that is not a list of ``_ParamScheduler`` is skipped
            silently.
        """
        schedulers = pl_module.lr_schedulers()
        if schedulers is None:
            return

        metrics = trainer.callback_metrics
        if metrics is None:
            return

        def step_with_metrics(group):
            # Skip anything that is not a list of built schedulers.
            if not is_list_of(group, _ParamScheduler):
                return
            for scheduler in group:
                wants_metrics = getattr(scheduler, 'need_val_args', False)
                if scheduler.by_epoch and wants_metrics:
                    scheduler.step(metrics)

        self._for_each_group(schedulers, step_with_metrics)
|
mmpl/engine/hooks/pipeline_switch_hook.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from mmcv.transforms import Compose
|
2 |
+
from mmpl.registry import HOOKS
|
3 |
+
from lightning.pytorch.callbacks import Callback
|
4 |
+
|
5 |
+
|
6 |
+
@HOOKS.register_module()
class PipelineSwitchHook(Callback):
    """Replace the training dataset's transform pipeline at a chosen epoch.

    Args:
        switch_epoch (int): Epoch (compared with ``trainer.current_epoch``)
            at which the replacement happens.
        switch_pipeline (list[dict]): Transform configs handed to
            ``Compose`` to build the new pipeline.
    """

    def __init__(self, switch_epoch, switch_pipeline):
        self.switch_epoch = switch_epoch
        self.switch_pipeline = switch_pipeline
        # Becomes True once a dataloader restart has been forced.
        self._restart_dataloader = False

    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Swap the pipeline on the switch epoch; otherwise restore the
        dataloader's initialization flag if a restart was forced earlier."""
        current = trainer.current_epoch
        loader = trainer.train_dataloader

        if current != self.switch_epoch:
            # A restart was forced previously, so put the private
            # initialization flag back.
            if self._restart_dataloader:
                loader._DataLoader__initialized = True
            return

        if trainer.local_rank == 0:
            print('Switch pipeline now!')

        loader.dataset.pipeline = Compose(self.switch_pipeline)

        # With persistent workers the new pipeline is not picked up, so the
        # loader's private state is reset to force the workers to restart.
        # This is a hack that relies on DataLoader internals.
        if getattr(loader, 'persistent_workers', None) is True:
            loader._DataLoader__initialized = False
            loader._iterator = None
            self._restart_dataloader = True
|
mmpl/engine/hooks/ppyoloe_param_scheduler_hook.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import math
|
3 |
+
from typing import Optional
|
4 |
+
|
5 |
+
from mmengine.hooks import ParamSchedulerHook
|
6 |
+
from mmengine.runner import Runner
|
7 |
+
|
8 |
+
from mmyolo.registry import HOOKS
|
9 |
+
|
10 |
+
|
11 |
+
@HOOKS.register_module()
class PPYOLOEParamSchedulerHook(ParamSchedulerHook):
    """Hook that sets the optimizer learning rate for PPYOLOE training.

    The learning rate is ramped linearly during warmup and then follows a
    cosine curve down towards ``min_lr_ratio`` times the base learning rate.
    The warmup length is derived from the dataloader length at every
    iteration, with ``warmup_min_iter`` as a lower bound.

    Args:
        warmup_min_iter (int): Minimum warmup iters. Defaults to 1000.
        start_factor (float): Multiplier applied to the base learning rate
            at iteration 0; it moves linearly towards 1 over the warmup.
            Defaults to 0.
        warmup_epochs (int): Epochs for warmup. Defaults to 5.
        min_lr_ratio (float): Ratio of the base learning rate used as the
            floor of the cosine phase. Defaults to 0.0.
        total_epochs (int): Number of epochs spanned by the cosine
            schedule. Defaults to 360.
    """
    priority = 9

    def __init__(self,
                 warmup_min_iter: int = 1000,
                 start_factor: float = 0.,
                 warmup_epochs: int = 5,
                 min_lr_ratio: float = 0.0,
                 total_epochs: int = 360):
        # Schedule configuration.
        self.warmup_min_iter = warmup_min_iter
        self.warmup_epochs = warmup_epochs
        self.start_factor = start_factor
        self.min_lr_ratio = min_lr_ratio
        self.total_epochs = total_epochs

        # Runtime state; ``_base_lr`` is filled in by ``before_train``.
        self._warmup_end = False
        self._base_lr = None

    def before_train(self, runner: Runner):
        """Record the base and minimum learning rate of each param group.

        Args:
            runner (Runner): The runner of the training process.
        """
        groups = runner.optim_wrapper.optimizer.param_groups

        base_lrs = []
        for group in groups:
            # A group that has never been scheduled gets its current ``lr``
            # stored as ``initial_lr``.
            base_lrs.append(group.setdefault('initial_lr', group['lr']))

        self._base_lr = base_lrs
        self._min_lr = [lr * self.min_lr_ratio for lr in base_lrs]

    def before_train_iter(self,
                          runner: Runner,
                          batch_idx: int,
                          data_batch: Optional[dict] = None):
        """Write the scheduled learning rate into every param group.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (dict or tuple or list, optional): Data from dataloader.
        """
        cur_iters = runner.iter
        optimizer = runner.optim_wrapper.optimizer
        iters_per_epoch = len(runner.train_dataloader)

        # Warmup lasts ``warmup_epochs`` epochs, but never fewer than
        # ``warmup_min_iter`` iterations.
        warmup_total_iters = max(
            round(self.warmup_epochs * iters_per_epoch), self.warmup_min_iter)

        if cur_iters <= warmup_total_iters:
            # Linear warmup from ``start_factor`` to 1.
            progress = cur_iters / warmup_total_iters
            scale = self.start_factor * (1 - progress) + progress
            for idx, group in enumerate(optimizer.param_groups):
                group['lr'] = self._base_lr[idx] * scale
            return

        # Cosine phase: the cosine term is the same for every group.
        total_iters = self.total_epochs * iters_per_epoch
        cosine_term = math.cos(
            (cur_iters - warmup_total_iters) * math.pi /
            (total_iters - warmup_total_iters)) + 1.0
        for idx, group in enumerate(optimizer.param_groups):
            floor_lr = self._min_lr[idx]
            group['lr'] = floor_lr + (
                self._base_lr[idx] - floor_lr) * 0.5 * cosine_term
|
mmpl/engine/hooks/switch_to_deploy_hook.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
|
3 |
+
from mmengine.hooks import Hook
|
4 |
+
from mmengine.runner import Runner
|
5 |
+
|
6 |
+
from mmyolo.registry import HOOKS
|
7 |
+
from mmyolo.utils import switch_to_deploy
|
8 |
+
|
9 |
+
|
10 |
+
@HOOKS.register_module()
class SwitchToDeployHook(Hook):
    """Put the model into deploy mode before a test epoch.

    The model held by the runner is passed to ``switch_to_deploy``, which
    converts the multi-branch training structure into the single-path
    structure used for testing (faster and lighter on memory).
    """

    def before_test_epoch(self, runner: Runner):
        """Call ``switch_to_deploy`` on the runner's model."""
        model = runner.model
        switch_to_deploy(model)
|
mmpl/engine/hooks/visualization_hook.py
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os.path as osp
|
2 |
+
import warnings
|
3 |
+
from typing import Optional, Sequence, Any
|
4 |
+
|
5 |
+
import mmcv
|
6 |
+
from lightning import Callback
|
7 |
+
from mmengine.fileio import get
|
8 |
+
from mmengine.hooks import Hook
|
9 |
+
from mmengine.runner import Runner
|
10 |
+
from mmengine.utils import mkdir_or_exist
|
11 |
+
from mmengine.visualization import Visualizer
|
12 |
+
|
13 |
+
from mmpl.registry import HOOKS
|
14 |
+
from mmdet.structures import DetDataSample
|
15 |
+
|
16 |
+
|
17 |
+
@HOOKS.register_module()
class DetVisualizationHook(Callback):
    """Draw detection predictions during validation, testing and prediction.

    In the testing phase:

    1. If ``show`` is True, results are only displayed and nothing is
        stored, so the visualizer's ``vis_backends`` are cleared.
    2. If ``test_out_dir`` is given, drawn images are written to that
        directory.
    3. ``vis_backends`` take effect when neither ``show`` nor
        ``test_out_dir`` is specified; e.g. WandbVisBackend or
        TensorboardVisBackend can be used to store the results.

    Args:
        draw (bool): whether to draw prediction results. If it is False,
            no drawing is done. Defaults to False.
        interval (int): The interval of visualization. Defaults to 50.
        score_thr (float): The threshold to visualize the bboxes
            and masks. Defaults to 0.3.
        show (bool): Whether to display the drawn image. Default to False.
        wait_time (float): The interval of show (s). Defaults to 0.
        test_out_dir (str, optional): directory where painted images
            will be saved in testing process.
        backend_args (dict, optional): Arguments to instantiate the
            corresponding backend. Defaults to None.
    """

    def __init__(self,
                 draw: bool = False,
                 interval: int = 50,
                 score_thr: float = 0.3,
                 show: bool = False,
                 wait_time: float = 0.,
                 test_out_dir: Optional[str] = None,
                 backend_args: dict = None):
        self._visualizer: Visualizer = Visualizer.get_current_instance()

        self.draw = draw
        self.interval = interval
        self.score_thr = score_thr
        self.show = show
        self.wait_time = wait_time
        self.backend_args = backend_args
        self.test_out_dir = test_out_dir
        # Running count of drawn test/predict samples, used as ``step``.
        self._test_index = 0

        if self.show:
            # Display-only mode: drop every storage backend.
            self._visualizer._vis_backends = {}
            warnings.warn('The show is True, it means that only '
                          'the prediction results are visualized '
                          'without storing data, so vis_backends '
                          'needs to be excluded.')

    def _read_image(self, img_path):
        """Fetch ``img_path`` with ``get`` and decode it in RGB order."""
        raw = get(img_path, backend_args=self.backend_args)
        return mmcv.imfrombytes(raw, channel_order='rgb')

    def _draw_test_sample(self, data_sample) -> None:
        """Draw one test/predict sample and advance the sample counter."""
        self._test_index += 1

        img_path = data_sample.img_path
        img = self._read_image(img_path)

        # Only write a file when an output directory was configured.
        out_file = None
        if self.test_out_dir is not None:
            out_file = osp.join(self.test_out_dir, osp.basename(img_path))

        window_name = osp.basename(img_path) if self.show else 'test_img'
        self._visualizer.add_datasample(
            window_name,
            img,
            data_sample=data_sample,
            show=self.show,
            wait_time=self.wait_time,
            pred_score_thr=self.score_thr,
            out_file=out_file,
            step=self._test_index)

    def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
                       outputs: Sequence[DetDataSample]) -> None:
        """Draw the first sample of the batch every ``self.interval`` steps.

        Args:
            runner (:obj:`Runner`): The runner of the validation process.
            batch_idx (int): The index of the current batch in the val loop.
            data_batch (dict): Data from dataloader.
            outputs (Sequence[:obj:`DetDataSample`]): A batch of data samples
                that contain annotations and predictions.
        """
        if self.draw is False:
            return

        # The step mixes the training iteration with the batch index, so
        # different evaluations may visualize different images.
        step = runner.iter + batch_idx

        # Only the first sample of the batch is considered; its image is
        # loaded before the interval check.
        first = outputs[0]
        img_path = first.img_path
        img = self._read_image(img_path)

        if step % self.interval != 0:
            return

        self._visualizer.add_datasample(
            osp.basename(img_path) if self.show else 'val_img',
            img,
            data_sample=first,
            show=self.show,
            wait_time=self.wait_time,
            pred_score_thr=self.score_thr,
            step=step)

    def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
                        outputs: Sequence[DetDataSample]) -> None:
        """Draw every sample of a test batch.

        Args:
            runner (:obj:`Runner`): The runner of the testing process.
            batch_idx (int): The index of the current batch in the test loop.
            data_batch (dict): Data from dataloader.
            outputs (Sequence[:obj:`DetDataSample`]): A batch of data samples
                that contain annotations and predictions.
        """
        if self.draw is False:
            return

        if self.test_out_dir is not None:
            # NOTE(review): this join runs on every call and overwrites
            # ``self.test_out_dir`` with the joined path.
            self.test_out_dir = osp.join(runner.work_dir, runner.timestamp,
                                         self.test_out_dir)
            mkdir_or_exist(self.test_out_dir)

        for data_sample in outputs:
            self._draw_test_sample(data_sample)

    def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Resolve ``test_out_dir`` under the trainer's root dir and create it."""
        # if hasattr(trainer.datamodule, f'predict_dataset'):
        #     dataset = getattr(trainer.datamodule, f'predict_dataset')
        #     if hasattr(dataset, 'metainfo') and hasattr(self._visualizer, 'dataset_meta'):
        #         self._visualizer.dataset_meta = dataset.metainfo
        if self.test_out_dir is None:
            return
        self.test_out_dir = osp.join(trainer.default_root_dir, self.test_out_dir)
        mkdir_or_exist(self.test_out_dir)

    def on_predict_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Draw every sample of a prediction batch.

        Args:
            trainer: The trainer running prediction (unused here).
            pl_module: The module being run (unused here).
            outputs: Iterable of data samples; each must expose ``img_path``.
            batch: The input batch (unused here).
            batch_idx (int): Index of the batch (unused here).
            dataloader_idx (int): Index of the dataloader (unused here).
        """
        if self.draw is False:
            return

        for data_sample in outputs:
            self._draw_test_sample(data_sample)
|
mmpl/engine/hooks/yolov5_param_scheduler_hook.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Optional
|
3 |
+
import numpy as np
|
4 |
+
from typing import Dict, Optional, Union
|
5 |
+
from mmengine.registry import HOOKS
|
6 |
+
from .param_scheduler_hook import ParamSchedulerHook
|
7 |
+
|
8 |
+
# Type alias for batch data: a dict, tuple or list, or None.
DATA_BATCH = Optional[Union[dict, tuple, list]]
|
9 |
+
|
10 |
+
|
11 |
+
def linear_fn(lr_factor: float, max_epochs: int):
    """Build a linear schedule going from 1 at epoch 0 to ``lr_factor`` at
    epoch ``max_epochs``.

    Args:
        lr_factor (float): Value of the schedule at ``max_epochs``.
        max_epochs (int): Epoch at which ``lr_factor`` is reached.

    Returns:
        Callable: Maps an epoch ``x`` to its multiplicative factor.
    """

    def _linear(x):
        remaining = 1 - x / max_epochs
        return remaining * (1.0 - lr_factor) + lr_factor

    return _linear
|
14 |
+
|
15 |
+
|
16 |
+
def cosine_fn(lr_factor: float, max_epochs: int):
    """Build a half-cosine schedule going from 1 at epoch 0 to ``lr_factor``
    at epoch ``max_epochs``.

    Args:
        lr_factor (float): Value of the schedule at ``max_epochs``.
        max_epochs (int): Epoch at which ``lr_factor`` is reached.

    Returns:
        Callable: Maps an epoch ``x`` to its multiplicative factor.
    """

    def _cosine(x):
        # Rises from 0 to 1 as ``x`` goes from 0 to ``max_epochs``.
        progress = (1 - math.cos(x * math.pi / max_epochs)) / 2
        return progress * (lr_factor - 1) + 1

    return _cosine
|
20 |
+
|
21 |
+
|
22 |
+
@HOOKS.register_module()
class YOLOv5ParamSchedulerHook(ParamSchedulerHook):
    """Hook that drives the YOLOv5 learning-rate and momentum schedule.

    During warmup the learning rate (and momentum, where a group has one) is
    interpolated per iteration. Once warmup is over, the learning rate is
    set at the end of every training epoch from ``scheduler_fn``.

    Args:
        scheduler_type (str): Key into ``scheduler_maps``.
            Defaults to 'linear'.
        lr_factor (float): Passed to the schedule factory. Defaults to 0.01.
        max_epochs (int): Passed to the schedule factory. Defaults to 300.
        warmup_epochs (int): Warmup length in epochs. Defaults to 3.
        warmup_bias_lr (float): Warmup starting learning rate of the param
            group at index 2. Defaults to 0.1.
        warmup_momentum (float): Warmup starting momentum. Defaults to 0.8.
        warmup_mim_iter (int): Lower bound on the warmup length in
            iterations. Defaults to 500.
        **kwargs: Extra keyword arguments forwarded to the schedule factory.
    """
    priority = 9

    scheduler_maps = {'linear': linear_fn, 'cosine': cosine_fn}

    def __init__(self,
                 scheduler_type: str = 'linear',
                 lr_factor: float = 0.01,
                 max_epochs: int = 300,
                 warmup_epochs: int = 3,
                 warmup_bias_lr: float = 0.1,
                 warmup_momentum: float = 0.8,
                 warmup_mim_iter: int = 500,
                 **kwargs):

        assert scheduler_type in self.scheduler_maps

        # Warmup configuration.
        self.warmup_epochs = warmup_epochs
        self.warmup_bias_lr = warmup_bias_lr
        self.warmup_momentum = warmup_momentum
        self.warmup_mim_iter = warmup_mim_iter

        # Build the per-epoch factor function; the explicit arguments
        # override same-named entries in ``kwargs``.
        kwargs.update({'lr_factor': lr_factor, 'max_epochs': max_epochs})
        build_schedule = self.scheduler_maps[scheduler_type]
        self.scheduler_fn = build_schedule(**kwargs)

        # Runtime state, filled in by ``on_fit_start``.
        self._warmup_end = False
        self._base_lr = None
        self._base_momentum = None

    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Record base learning rate and momentum of the first optimizer."""
        groups = trainer.optimizers[0].param_groups

        base_lrs = []
        base_momenta = []
        for group in groups:
            # A group that has never been scheduled gets its current values
            # stored as the initial ones; -1 marks a group without momentum.
            base_lrs.append(group.setdefault('initial_lr', group['lr']))
            base_momenta.append(
                group.setdefault('initial_momentum',
                                 group.get('momentum', -1)))

        self._base_lr = base_lrs
        self._base_momentum = base_momenta

    def on_before_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", loss) -> None:
        """Apply the per-iteration warmup, or flag that warmup has ended."""
        cur_iters = trainer.global_step
        cur_epoch = trainer.current_epoch
        optimizer = trainer.optimizers[0]

        # Warmup lasts ``warmup_epochs`` epochs, but never fewer than
        # ``warmup_mim_iter`` iterations.
        warmup_total_iters = max(
            round(self.warmup_epochs * len(trainer.train_dataloader)),
            self.warmup_mim_iter)

        if cur_iters > warmup_total_iters:
            self._warmup_end = True
            return

        xp = [0, warmup_total_iters]
        epoch_factor = self.scheduler_fn(cur_epoch)
        for idx, group in enumerate(optimizer.param_groups):
            # Group 2 (bias) starts warmup at ``warmup_bias_lr``; every
            # other group starts at 0.
            start_lr = self.warmup_bias_lr if idx == 2 else 0.0
            target_lr = self._base_lr[idx] * epoch_factor
            group['lr'] = np.interp(cur_iters, xp, [start_lr, target_lr])

            if 'momentum' in group:
                group['momentum'] = np.interp(
                    cur_iters, xp,
                    [self.warmup_momentum, self._base_momentum[idx]])

    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """After warmup, set each group's learning rate for the epoch."""
        if not self._warmup_end:
            return

        epoch_factor = self.scheduler_fn(trainer.current_epoch)
        optimizer = trainer.optimizers[0]
        for idx, group in enumerate(optimizer.param_groups):
            group['lr'] = self._base_lr[idx] * epoch_factor
|
111 |
+
|
mmpl/engine/hooks/yolox_mode_switch_hook.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import copy
|
3 |
+
from typing import Sequence
|
4 |
+
|
5 |
+
from mmengine.hooks import Hook
|
6 |
+
from mmengine.model import is_model_wrapper
|
7 |
+
from mmengine.runner import Runner
|
8 |
+
|
9 |
+
from mmyolo.registry import HOOKS
|
10 |
+
|
11 |
+
|
12 |
+
@HOOKS.register_module()
class YOLOXModeSwitchHook(Hook):
    """Switch the mode of YOLOX near the end of training.

    At the switch epoch the training dataloader is rebuilt with a new
    pipeline and the auxiliary bbox regression loss is enabled on the
    model's ``bbox_head``.

    Args:
        num_last_epochs (int): Number of final epochs that run in the
            switched mode. Defaults to 15.
        new_train_pipeline (Sequence[dict], optional): Pipeline config
            installed on the rebuilt training dataset. Defaults to None.
    """

    def __init__(self,
                 num_last_epochs: int = 15,
                 new_train_pipeline: Sequence[dict] = None):
        self.num_last_epochs = num_last_epochs
        self.new_train_pipeline_cfg = new_train_pipeline

    def before_train_epoch(self, runner: Runner):
        """Rebuild the dataloader and enable the aux bbox loss when the
        switch epoch is reached."""
        epoch = runner.epoch
        model = runner.model
        # Unwrap a model wrapper to reach the underlying module.
        if is_model_wrapper(model):
            model = model.module

        switch_at = runner.max_epochs - self.num_last_epochs
        if (epoch + 1) != switch_at:
            return

        logger = runner.logger
        logger.info(f'New Pipeline: {self.new_train_pipeline_cfg}')

        # The dataloader is rebuilt from a copy of its config instead of
        # patching the live dataset: building a dataloader deep-copies the
        # dataset, which can disorder global objects such as FileClient.
        # This should be solved properly in the future.
        loader_cfg = copy.deepcopy(runner.cfg.train_dataloader)
        loader_cfg.dataset.pipeline = self.new_train_pipeline_cfg
        runner.train_loop.dataloader = Runner.build_dataloader(loader_cfg)

        logger.info('recreate the dataloader!')
        logger.info('Add additional bbox reg loss now!')
        model.bbox_head.use_bbox_aux = True
|
mmpl/engine/logger/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .builder import PL_LOGGERS
|
mmpl/engine/logger/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (198 Bytes). View file
|
|
mmpl/engine/logger/__pycache__/builder.cpython-310.pyc
ADDED
Binary file (3.15 kB). View file
|
|
mmpl/engine/logger/builder.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import inspect
|
3 |
+
from typing import List, Union
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import lightning
|
8 |
+
|
9 |
+
from mmengine.config import Config, ConfigDict
|
10 |
+
from mmengine.device import is_npu_available
|
11 |
+
from mmpl.registry import LOGGERS
|
12 |
+
|
13 |
+
|
14 |
+
def register_pl_loggers() -> List[str]:
    """Register loggers in ``lightning.pytorch.loggers`` to the ``LOGGERS`` registry.

    Every non-dunder attribute of the package that is a subclass of
    ``lightning.pytorch.loggers.logger.Logger`` is registered.

    Returns:
        List[str]: A list of registered loggers' name.
    """
    loggers_pkg = lightning.pytorch.loggers
    logger_base = loggers_pkg.logger.Logger

    registered = []
    for attr_name in dir(loggers_pkg):
        if attr_name.startswith('__'):
            continue
        candidate = getattr(loggers_pkg, attr_name)
        if not (inspect.isclass(candidate)
                and issubclass(candidate, logger_base)):
            continue
        LOGGERS.register_module(module=candidate)
        registered.append(attr_name)
    return registered
|
29 |
+
|
30 |
+
|
31 |
+
# Names of the Lightning logger classes registered into ``LOGGERS``; the
# registration itself runs here, when this module is imported.
PL_LOGGERS = register_pl_loggers()
|
32 |
+
|
33 |
+
|
34 |
+
def register_dadaptation_optimizers() -> List[str]:
    """Register optimizers in ``dadaptation`` to the ``OPTIMIZERS`` registry.

    Nothing is registered when ``dadaptation`` cannot be imported.

    NOTE(review): ``OPTIMIZERS`` is not among the names this module
    imports; confirm it is in scope before enabling the call below.

    Returns:
        List[str]: A list of registered optimizers' name.
    """
    try:
        import dadaptation
    except ImportError:
        return []

    registered = []
    for optim_name in ('DAdaptAdaGrad', 'DAdaptAdam', 'DAdaptSGD'):
        optim_cls = getattr(dadaptation, optim_name)
        is_optimizer = inspect.isclass(optim_cls) and issubclass(
            optim_cls, torch.optim.Optimizer)
        if is_optimizer:
            OPTIMIZERS.register_module(module=optim_cls)
            registered.append(optim_name)
    return registered
|
53 |
+
|
54 |
+
|
55 |
+
# DADAPTATION_OPTIMIZERS = register_dadaptation_optimizers()
|
56 |
+
|
57 |
+
|
58 |
+
def register_lion_optimizers() -> List[str]:
    """Register Lion optimizer to the ``OPTIMIZERS`` registry.

    Nothing is registered when ``lion_pytorch`` cannot be imported.

    NOTE(review): ``OPTIMIZERS`` is not among the names this module
    imports; confirm it is in scope before enabling the call below.

    Returns:
        List[str]: A list of registered optimizers' name.
    """
    try:
        from lion_pytorch import Lion
    except ImportError:
        return []

    OPTIMIZERS.register_module(module=Lion)
    return ['Lion']
|
73 |
+
|
74 |
+
|
75 |
+
# LION_OPTIMIZERS = register_lion_optimizers()
|
76 |
+
|
77 |
+
|
78 |
+
def build_optim_wrapper(model: nn.Module,
                        cfg: Union[dict, Config, ConfigDict]):
    """Build function of OptimWrapper.

    The ``constructor`` entry of ``cfg`` selects the optimizer wrapper
    constructor; ``DefaultOptimWrapperConstructor`` is used when it is
    absent. The constructor is built from the remaining config plus
    ``paramwise_cfg`` and then called with ``model``.

    NOTE(review): ``OPTIM_WRAPPER_CONSTRUCTORS`` is not among the names this
    module imports; confirm it is in scope before relying on this function.

    Args:
        model (nn.Module): Model to be optimized.
        cfg (dict): Config of optimizer wrapper, optimizer constructor and
            optimizer. It is deep-copied, so the caller's object is left
            untouched.

    Returns:
        OptimWrapper: The built optimizer wrapper.
    """
    wrapper_cfg = copy.deepcopy(cfg)
    constructor_type = wrapper_cfg.pop('constructor',
                                       'DefaultOptimWrapperConstructor')
    paramwise_cfg = wrapper_cfg.pop('paramwise_cfg', None)

    # The current generation of NPU (Ascend 910) only supports mixed
    # precision training, so the AMP wrapper is forced on that device.
    if is_npu_available():
        wrapper_cfg['type'] = 'AmpOptimWrapper'

    constructor_cfg = dict(
        type=constructor_type,
        optim_wrapper_cfg=wrapper_cfg,
        paramwise_cfg=paramwise_cfg)
    constructor = OPTIM_WRAPPER_CONSTRUCTORS.build(constructor_cfg)
    return constructor(model)
|
mmpl/engine/optimizers/__init__.py
ADDED
File without changes
|
mmpl/engine/runner/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .pl_runner import PLRunner
|
2 |
+
|
3 |
+
__all__ = ['PLRunner']
|
mmpl/engine/runner/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (220 Bytes). View file
|
|
mmpl/engine/runner/__pycache__/pl_runner.cpython-310.pyc
ADDED
Binary file (27 kB). View file
|
|
mmpl/engine/runner/pl_runner.py
ADDED
@@ -0,0 +1,941 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import os.path as osp
|
5 |
+
import pickle
|
6 |
+
import platform
|
7 |
+
import time
|
8 |
+
import warnings
|
9 |
+
from collections import OrderedDict
|
10 |
+
from functools import partial
|
11 |
+
from typing import Callable, Dict, List, Optional, Sequence, Union
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torch.nn as nn
|
15 |
+
from lightning.pytorch.loggers import Logger
|
16 |
+
from torch.nn.parallel.distributed import DistributedDataParallel
|
17 |
+
from torch.optim import Optimizer
|
18 |
+
from torch.utils.data import DataLoader
|
19 |
+
|
20 |
+
import mmengine
|
21 |
+
from mmengine.config import Config, ConfigDict
|
22 |
+
from mmengine.dataset import worker_init_fn
|
23 |
+
from mmengine.device import get_device
|
24 |
+
from mmengine.dist import (broadcast, get_dist_info, get_rank, init_dist,
|
25 |
+
is_distributed, master_only)
|
26 |
+
from mmengine.evaluator import Evaluator
|
27 |
+
from mmengine.fileio import FileClient, join_path
|
28 |
+
from mmengine.hooks import Hook
|
29 |
+
from mmengine.logging import MessageHub, MMLogger, print_log
|
30 |
+
from mmengine.model import (MMDistributedDataParallel, convert_sync_batchnorm,
|
31 |
+
is_model_wrapper, revert_sync_batchnorm)
|
32 |
+
from mmengine.optim import (OptimWrapper, OptimWrapperDict, _ParamScheduler,
|
33 |
+
build_optim_wrapper)
|
34 |
+
from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, FUNCTIONS,
|
35 |
+
HOOKS, LOG_PROCESSORS, LOOPS, MODEL_WRAPPERS,
|
36 |
+
OPTIM_WRAPPERS, PARAM_SCHEDULERS,
|
37 |
+
RUNNERS, VISUALIZERS, DefaultScope)
|
38 |
+
from mmengine.utils import digit_version, get_git_hash, is_seq_of
|
39 |
+
from mmengine.utils.dl_utils import (TORCH_VERSION, collect_env,
|
40 |
+
set_multi_processing)
|
41 |
+
from mmengine.visualization import Visualizer
|
42 |
+
from mmengine.runner.base_loop import BaseLoop
|
43 |
+
from mmengine.runner.checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
|
44 |
+
find_latest_checkpoint, get_state_dict,
|
45 |
+
save_checkpoint, weights_to_cpu)
|
46 |
+
from mmengine.runner.log_processor import LogProcessor
|
47 |
+
from mmengine.runner.loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop
|
48 |
+
from mmengine.runner.priority import Priority, get_priority
|
49 |
+
from mmengine.runner.utils import set_random_seed
|
50 |
+
|
51 |
+
ConfigType = Union[Dict, Config, ConfigDict]
|
52 |
+
ParamSchedulerType = Union[List[_ParamScheduler], Dict[str, List[_ParamScheduler]]]
|
53 |
+
OptimWrapperType = Union[OptimWrapper, OptimWrapperDict]
|
54 |
+
|
55 |
+
from mmpl.registry import MODELS, LOGGERS
|
56 |
+
import lightning.pytorch as pl
|
57 |
+
from mmpl.models import build_pler
|
58 |
+
|
59 |
+
|
60 |
+
@RUNNERS.register_module()
|
61 |
+
class PLRunner:
|
62 |
+
def __init__(
        self,
        trainer_cfg: Dict,
        model_cfg: Union[pl.LightningModule, Dict],
        datamodule_cfg: Optional[Dict] = None,
        cfg: Optional[ConfigType] = None
):
    """Build the Lightning trainer, the model and their helpers from config.

    Args:
        trainer_cfg (dict): Keyword arguments for ``pl.Trainer``. Must
            contain ``default_root_dir``. It is modified in place:
            ``compiled_model`` is popped, and ``logger``, ``callbacks`` and
            ``strategy`` are replaced by the built objects.
        model_cfg (LightningModule or dict): A ready model, or a config
            passed to ``build_pler``. When it is a dict, the full config is
            injected under the ``config_cfg`` key.
        datamodule_cfg (dict, optional): Stored for later use by ``run``.
        cfg (Config or dict, optional): The complete experiment config.
            ``None`` is treated as an empty config.

    Raises:
        TypeError: If ``cfg`` is not a ``Config``, a dict or ``None``.
    """
    # Keep untouched snapshots before `trainer_cfg` / `model_cfg` are
    # modified further down.
    self.trainer_cfg = copy.deepcopy(trainer_cfg)
    self.model_cfg = copy.deepcopy(model_cfg)
    self.datamodule_cfg = copy.deepcopy(datamodule_cfg)
    mmengine.mkdir_or_exist(trainer_cfg['default_root_dir'])

    timestamp = torch.tensor(time.time(), dtype=torch.float64)
    # broadcast timestamp from 0 process to other processes
    broadcast(timestamp)
    self.timestamp = time.strftime('%Y%m%d_%H%M%S',
                                   time.localtime(timestamp.item()))

    # Normalise `cfg` so that `self.cfg` is always a `Config`.
    if cfg is None:
        self.cfg = Config(dict())
    elif isinstance(cfg, Config):
        self.cfg = copy.deepcopy(cfg)
    elif isinstance(cfg, dict):
        self.cfg = Config(cfg)
    else:
        raise TypeError(
            f'cfg should be a Config object, a dict or None, but got {cfg}')

    compiled_model = trainer_cfg.pop('compiled_model', False)

    # build logger
    loggers = self.build_logger(
        trainer_cfg.get('logger', False),
        trainer_cfg.get('default_root_dir', f'{self.timestamp}')
    )
    trainer_cfg['logger'] = loggers

    # build visualizer used for writing log or visualizing all kinds of data
    self.visualizer = self.build_visualizer(
        self.cfg.get('visualizer', None),
        trainer_cfg.get('default_root_dir', f'{self.timestamp}')
    )
    if self.cfg:
        self.visualizer.add_config(self.cfg)

    # build callbacks
    callbacks = self.build_hooks(
        trainer_cfg.get('callbacks', None),
    )
    trainer_cfg['callbacks'] = callbacks

    # build strategy
    strategy = self.build_strategy(
        trainer_cfg.get('strategy', 'auto'),
    )
    trainer_cfg['strategy'] = strategy

    self.trainer = pl.Trainer(**trainer_cfg)

    # Only a dict config can carry the extra `config_cfg` entry; an already
    # built LightningModule is used as it is.
    if isinstance(model_cfg, dict):
        # A `Config` argument is snapshotted at this point, exactly as
        # before; for None / plain dict input fall back to `self.cfg`,
        # because neither of those has a `to_dict` method.
        config_source = cfg if isinstance(cfg, Config) else self.cfg
        model_cfg.update(
            {'config_cfg': copy.deepcopy(config_source).to_dict()})
    model = self.build_model(model_cfg)

    load_from = self.cfg.get('load_from', None)
    if load_from is not None:
        self.load_checkpoint(model, load_from)

    if compiled_model:
        # default, reduce-overhead, and max-autotune.
        self.model = torch.compile(model)
    else:
        self.model = model

    # dump `cfg` to `work_dir`
    self.dump_config()
    # # Collect and log environment information.
    # self._log_env(env_cfg)
    # log hooks information
    # self.logger.info(f'Hooks will be executed in the following '
    #                  f'order:\n{self.get_hooks_info()}')
|
135 |
+
|
136 |
+
def build_visualizer(
        self,
        visualizer: Optional[Union[Visualizer, Dict]] = None,
        default_root_dir='tmp'
) -> Visualizer:
    """Build a globally accessible Visualizer.

    Args:
        visualizer (Visualizer or dict, optional): An existing Visualizer
            (returned unchanged), a config dict, or ``None`` to use a
            default config with a single ``LocalVisBackend``.
            Defaults to None.
        default_root_dir (str): Its basename is used as the default
            visualizer name and ``<default_root_dir>/visualizer`` as the
            default save directory. Defaults to 'tmp'.

    Returns:
        Visualizer: The Visualizer built from ``visualizer``.

    Raises:
        TypeError: If ``visualizer`` is not a Visualizer, a dict or None.
    """
    # An instance needs no building.
    if isinstance(visualizer, Visualizer):
        return visualizer

    if visualizer is not None and not isinstance(visualizer, dict):
        raise TypeError(
            'visualizer should be Visualizer object, a dict or None, '
            f'but got {visualizer}')

    fallback_name = os.path.basename(default_root_dir)
    fallback_dir = default_root_dir+'/visualizer'

    if visualizer is None:
        return Visualizer.get_instance(
            name=fallback_name,
            vis_backends=[dict(type='LocalVisBackend')],
            save_dir=fallback_dir
        )

    # Fill in only the keys the user config left out (modifies the dict).
    visualizer.setdefault('name', fallback_name)
    visualizer.setdefault('save_dir', fallback_dir)
    return VISUALIZERS.build(visualizer)
|
174 |
+
|
175 |
+
def build_hooks(self, hooks: Union[Dict, List[Dict]] = None) -> List[Hook]:
    """Build hook objects from their configs via the ``HOOKS`` registry.

    Args:
        hooks (dict or list[dict], optional): One hook config or a list of
            them. Defaults to None.

    Returns:
        list[Hook]: The built hooks, in input order. ``None`` is returned
        unchanged when ``hooks`` is None.
    """
    if hooks is None:
        return hooks
    # A single config is treated as a one-element list.
    hook_cfgs = [hooks] if isinstance(hooks, dict) else hooks
    return [HOOKS.build(hook_cfg) for hook_cfg in hook_cfgs]
|
193 |
+
|
194 |
+
@classmethod
def from_cfg(cls, cfg: ConfigType) -> 'PLRunner':
    """Create a runner from a complete experiment config.

    Args:
        cfg (Config or dict): Must contain ``model_cfg``; ``trainer_cfg``
            and ``datamodule_cfg`` are read with ``get``.

    Returns:
        PLRunner: A runner built from a deep copy of ``cfg``, so the
        caller's config object is not modified.
    """
    cfg = copy.deepcopy(cfg)
    return cls(
        trainer_cfg=cfg.get('trainer_cfg'),
        model_cfg=cfg['model_cfg'],
        datamodule_cfg=cfg.get('datamodule_cfg'),
        cfg=cfg
    )
|
205 |
+
|
206 |
+
def build_logger(self, loggers: Union[Dict, List[Dict]] = None, default_root_dir='logger'):
    """Build logger objects from their configs via the ``LOGGERS`` registry.

    Args:
        loggers (dict or list[dict], optional): One logger config or a list
            of them. Any falsy value (``None``, ``False``, empty container)
            is returned unchanged.
        default_root_dir (str): Used as ``save_dir`` for configs that have
            none. Defaults to 'logger'.

    Returns:
        The list of built loggers, or ``loggers`` itself when it is falsy.
    """
    if not loggers:
        return loggers

    logger_cfgs = [loggers] if isinstance(loggers, Dict) else loggers
    built = []
    for logger_cfg in logger_cfgs:
        # The config dict is updated in place when it lacks a save_dir.
        if logger_cfg.get('save_dir', None) is None:
            logger_cfg['save_dir'] = default_root_dir
        mmengine.mkdir_or_exist(logger_cfg['save_dir'])
        built.append(LOGGERS.build(logger_cfg))
    return built
|
218 |
+
|
219 |
+
def build_strategy(self, strategy='auto'):
    """Resolve the ``strategy`` argument for ``pl.Trainer``.

    Args:
        strategy (str or dict or object): A strategy name (returned as is),
            a config dict built through ``MODEL_WRAPPERS``, or any other
            object (returned as is). Defaults to 'auto'.

    Returns:
        The strategy name, the built strategy, or the input object.
    """
    if not isinstance(strategy, dict):
        # Strings and already-built objects are handed to the trainer as is.
        return strategy

    if strategy.get('type', '') == 'FSDPStrategy':
        from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
        import functools
        # Size-based wrapping is only a default: an `auto_wrap_policy`
        # supplied in the config is no longer overwritten.
        strategy.setdefault(
            'auto_wrap_policy',
            functools.partial(
                size_based_auto_wrap_policy, min_num_params=int(5e7)
            )
        )
    return MODEL_WRAPPERS.build(strategy)
|
238 |
+
|
239 |
+
def build_model(self, model: Union[pl.LightningModule, Dict]) -> pl.LightningModule:
    """Return a LightningModule, building it with ``build_pler`` if needed.

    Args:
        model (LightningModule or dict): A ready module (returned
            unchanged) or a config dict.

    Returns:
        LightningModule: The given or newly built module.

    Raises:
        TypeError: If ``model`` is neither a LightningModule nor a dict.
    """
    if isinstance(model, pl.LightningModule):
        return model
    if not isinstance(model, dict):
        raise TypeError('model should be a nn.Module object or dict, '
                        f'but got {model}')
    return build_pler(model)  # type: ignore
|
248 |
+
|
249 |
+
def _init_model_weights(self) -> None:
    """Call the model's ``init_weights`` if it has one, then broadcast every
    entry of its state dict."""
    # Unwrap a wrapper that exposes the real model as `.module`.
    model = getattr(self.model, 'module', self.model)
    if hasattr(model, 'init_weights'):
        model.init_weights()
        # sync params and buffers
        for tensor in model.state_dict().values():
            broadcast(tensor)
|
261 |
+
|
262 |
+
def get_hooks_info(self) -> str:
    """Return a text summary of the hooks triggered at each stage.

    NOTE(review): reads ``self.hooks``, which is not assigned in the visible
    part of this class -- confirm it is defined elsewhere.

    Returns:
        str: One section per stage in ``Hook.stages`` that has at least one
        hook, each listing ``(priority) classname`` lines.
    """
    per_stage: Dict[str, list] = {stage: [] for stage in Hook.stages}
    for hook in self.hooks:
        # Prefer the symbolic priority name; fall back to the raw value
        # when it is not a member of `Priority`.
        try:
            priority = Priority(hook.priority).name  # type: ignore
        except ValueError:
            priority = hook.priority  # type: ignore
        line = f'({priority:<12}) {hook.__class__.__name__:<35}'
        for stage in hook.get_triggered_stages():
            per_stage[stage].append(line)

    sections = [
        f'{stage}:\n' + '\n'.join(per_stage[stage]) + '\n -------------------- '
        for stage in Hook.stages
        if len(per_stage[stage]) > 0
    ]
    return '\n'.join(sections)
|
284 |
+
|
285 |
+
def load_or_resume(self) -> None:
|
286 |
+
"""load or resume checkpoint."""
|
287 |
+
if self._has_loaded:
|
288 |
+
return None
|
289 |
+
|
290 |
+
# decide to load from checkpoint or resume from checkpoint
|
291 |
+
resume_from = None
|
292 |
+
if self._resume and self._load_from is None:
|
293 |
+
# auto resume from the latest checkpoint
|
294 |
+
resume_from = find_latest_checkpoint(self.work_dir)
|
295 |
+
self.logger.info(
|
296 |
+
f'Auto resumed from the latest checkpoint {resume_from}.')
|
297 |
+
elif self._resume and self._load_from is not None:
|
298 |
+
# resume from the specified checkpoint
|
299 |
+
resume_from = self._load_from
|
300 |
+
|
301 |
+
if resume_from is not None:
|
302 |
+
self.resume(resume_from)
|
303 |
+
self._has_loaded = True
|
304 |
+
elif self._load_from is not None:
|
305 |
+
self.load_checkpoint(self._load_from)
|
306 |
+
self._has_loaded = True
|
307 |
+
|
308 |
+
@staticmethod
def build_datamodule(datamodule_cfg: Union[pl.LightningDataModule, Dict]):
    """Build a LightningDataModule via the ``DATASETS`` registry.

    Args:
        datamodule_cfg (LightningDataModule or dict or None): A ready
            datamodule (returned unchanged), a config dict, or ``None``.

    Returns:
        The given or built datamodule, or ``None`` when no config was
        given. ``__init__`` defaults ``datamodule_cfg`` to None and ``run``
        always calls this method, so None must not reach the registry.
    """
    if datamodule_cfg is None or isinstance(datamodule_cfg, pl.LightningDataModule):
        return datamodule_cfg
    # Build from a copy so the caller's config stays untouched.
    return DATASETS.build(copy.deepcopy(datamodule_cfg))
|
316 |
+
|
317 |
+
def run(self, status, *args, **kwargs):
    """Run one trainer entry point with the runner's model and datamodule.

    Args:
        status (str): One of 'fit', 'test', 'predict' or 'validate'; the
            ``self.trainer`` method with that name is called.
        *args, **kwargs: Forwarded to the trainer method.

    Returns:
        Whatever the selected trainer method returns.
    """
    assert status in ['fit', 'test', 'predict', 'validate']
    entry_point = getattr(self.trainer, status)
    # The datamodule is rebuilt from its config on every call.
    self.datamodule = self.build_datamodule(self.datamodule_cfg)
    # NOTE(review): `model` is passed by keyword, so positional `*args`
    # would presumably clash with it -- confirm callers only use kwargs.
    return entry_point(model=self.model, datamodule=self.datamodule, *args, **kwargs)
|
322 |
+
|
323 |
+
#
|
324 |
+
# if is_model_wrapper(self.model):
|
325 |
+
# ori_model = self.model.module
|
326 |
+
# else:
|
327 |
+
# ori_model = self.model
|
328 |
+
# assert hasattr(ori_model, 'train_step'), (
|
329 |
+
# 'If you want to train your model, please make sure your model '
|
330 |
+
# 'has implemented `train_step`.')
|
331 |
+
#
|
332 |
+
# if self._val_loop is not None:
|
333 |
+
# assert hasattr(ori_model, 'val_step'), (
|
334 |
+
# 'If you want to validate your model, please make sure your '
|
335 |
+
# 'model has implemented `val_step`.')
|
336 |
+
#
|
337 |
+
# if self._train_loop is None:
|
338 |
+
# raise RuntimeError(
|
339 |
+
# '`self._train_loop` should not be None when calling train '
|
340 |
+
# 'method. Please provide `train_dataloader`, `train_cfg`, '
|
341 |
+
# '`optimizer` and `param_scheduler` arguments when '
|
342 |
+
# 'initializing runner.')
|
343 |
+
#
|
344 |
+
# self._train_loop = self.build_train_loop(
|
345 |
+
# self._train_loop) # type: ignore
|
346 |
+
#
|
347 |
+
# # `build_optimizer` should be called before `build_param_scheduler`
|
348 |
+
# # because the latter depends on the former
|
349 |
+
# self.optim_wrapper = self.build_optim_wrapper(self.optim_wrapper)
|
350 |
+
# # Automatically scaling lr by linear scaling rule
|
351 |
+
# self.scale_lr(self.optim_wrapper, self.auto_scale_lr)
|
352 |
+
#
|
353 |
+
# if self.param_schedulers is not None:
|
354 |
+
# self.param_schedulers = self.build_param_scheduler( # type: ignore
|
355 |
+
# self.param_schedulers) # type: ignore
|
356 |
+
#
|
357 |
+
# if self._val_loop is not None:
|
358 |
+
# self._val_loop = self.build_val_loop(
|
359 |
+
# self._val_loop) # type: ignore
|
360 |
+
# # TODO: add a contextmanager to avoid calling `before_run` many times
|
361 |
+
# self.call_hook('before_run')
|
362 |
+
#
|
363 |
+
# # initialize the model weights
|
364 |
+
# self._init_model_weights()
|
365 |
+
# # make sure checkpoint-related hooks are triggered after `before_run`
|
366 |
+
# self.load_or_resume()
|
367 |
+
#
|
368 |
+
# # Initiate inner count of `optim_wrapper`.
|
369 |
+
# self.optim_wrapper.initialize_count_status(
|
370 |
+
# self.model,
|
371 |
+
# self._train_loop.iter, # type: ignore
|
372 |
+
# self._train_loop.max_iters) # type: ignore
|
373 |
+
#
|
374 |
+
# # Maybe compile the model according to options in self.cfg.compile
|
375 |
+
# # This must be called **AFTER** model has been wrapped.
|
376 |
+
# self._maybe_compile('train_step')
|
377 |
+
#
|
378 |
+
# model = self.train_loop.run() # type: ignore
|
379 |
+
# self.call_hook('after_run')
|
380 |
+
# return model
|
381 |
+
|
382 |
+
|
383 |
+
|
384 |
+
def register_hook(
        self,
        hook: Union[Hook, Dict],
        priority: Optional[Union[str, int, Priority]] = None) -> None:
    """Insert a hook into ``self._hooks``, keeping the list priority-ordered.

    Hooks with equal priority keep their registration order. The effective
    priority is chosen in this order:

    - the ``priority`` argument, when given;
    - ``hook['priority']``, when ``hook`` is a dict containing that key
      (the key is removed from the dict);
    - otherwise the ``priority`` attribute the hook object already has.

    Args:
        hook (:obj:`Hook` or dict): The hook, or a config for ``HOOKS``.
        priority (int or str or :obj:`Priority`, optional): Hook priority.
            Lower value means higher priority.

    Raises:
        TypeError: If ``hook`` is neither a Hook nor a dict.
    """
    if not isinstance(hook, (Hook, dict)):
        raise TypeError(
            f'hook should be an instance of Hook or dict, but got {hook}')

    cfg_priority = None
    if isinstance(hook, dict):
        # `priority` is not a constructor argument; take it out first.
        cfg_priority = hook.pop('priority', None)
        hook_obj = HOOKS.build(hook)
    else:
        hook_obj = hook

    chosen = priority if priority is not None else cfg_priority
    if chosen is not None:
        hook_obj.priority = chosen

    # Scan from the back for the last hook that does not rank after the new
    # one and insert right behind it; with no such hook, go to the front.
    insert_at = 0
    for idx in reversed(range(len(self._hooks))):
        if get_priority(hook_obj.priority) >= get_priority(
                self._hooks[idx].priority):
            insert_at = idx + 1
            break
    self._hooks.insert(insert_at, hook_obj)
|
436 |
+
|
437 |
+
def register_default_hooks(
        self,
        hooks: Optional[Dict[str, Union[Hook, Dict]]] = None) -> None:
    """Register the default hooks, optionally overridden by ``hooks``.

    Built-in defaults (priorities as listed in the original documentation):

    +----------------------+-------------------------+
    | Hooks                | Priority                |
    +======================+=========================+
    | RuntimeInfoHook      | VERY_HIGH (10)          |
    +----------------------+-------------------------+
    | IterTimerHook        | NORMAL (50)             |
    +----------------------+-------------------------+
    | DistSamplerSeedHook  | NORMAL (50)             |
    +----------------------+-------------------------+
    | LoggerHook           | BELOW_NORMAL (60)       |
    +----------------------+-------------------------+
    | ParamSchedulerHook   | LOW (70)                |
    +----------------------+-------------------------+
    | CheckpointHook       | VERY_LOW (90)           |
    +----------------------+-------------------------+

    Entries of ``hooks`` are merged into these defaults by name. A ``None``
    value under a default name removes that default, e.g.
    ``hooks = dict(timer=None)`` drops :obj:`IterTimerHook`. A ``None``
    value under any other name fails an assertion.

    Args:
        hooks (dict[str, Hook or dict], optional): Default hooks or configs
            to be registered.
    """
    default_hooks: dict = {
        'runtime_info': dict(type='RuntimeInfoHook'),
        'timer': dict(type='IterTimerHook'),
        'sampler_seed': dict(type='DistSamplerSeedHook'),
        'logger': dict(type='LoggerHook'),
        'param_scheduler': dict(type='ParamSchedulerHook'),
        'checkpoint': dict(type='CheckpointHook', interval=1),
    }
    if hooks is not None:
        for name, override in hooks.items():
            if override is None and name in default_hooks:
                # remove hook from default_hooks
                del default_hooks[name]
                continue
            assert override is not None
            default_hooks[name] = override

    for hook_or_cfg in default_hooks.values():
        self.register_hook(hook_or_cfg)
|
508 |
+
|
509 |
+
def register_custom_hooks(self, hooks: List[Union[Hook, Dict]]) -> None:
    """Register each given hook through :meth:`register_hook`.

    Args:
        hooks (list[Hook | dict]): Hooks or hook configs, registered in the
            order given.
    """
    for hook_or_cfg in hooks:
        self.register_hook(hook_or_cfg)
|
518 |
+
|
519 |
+
def register_hooks(
        self,
        default_hooks: Optional[Dict[str, Union[Hook, Dict]]] = None,
        custom_hooks: Optional[List[Union[Hook, Dict]]] = None) -> None:
    """Register the default hooks first, then any custom hooks.

    Args:
        default_hooks (dict[str, dict] or dict[str, Hook], optional):
            Overrides passed to :meth:`register_default_hooks`.
            Defaults to None.
        custom_hooks (list[dict] or list[Hook], optional): Extra hooks
            passed to :meth:`register_custom_hooks`; skipped when None.
            Defaults to None.
    """
    self.register_default_hooks(default_hooks)
    if custom_hooks is None:
        return
    self.register_custom_hooks(custom_hooks)
|
537 |
+
|
538 |
+
def resume(self,
           filename: str,
           resume_optimizer: bool = True,
           resume_param_scheduler: bool = True,
           map_location: Union[str, Callable] = 'default') -> None:
    """Restore training state (counters, seed, optimizer, schedulers).

    NOTE(review): this calls ``self.load_checkpoint(filename,
    map_location=...)`` and indexes the result, while the ``load_checkpoint``
    visible in this class takes ``(model, file)`` and returns nothing.
    ``train_loop``, ``_world_size``, ``auto_scale_lr``, ``_randomness_cfg``,
    ``message_hub`` and similar attributes are also not set in the visible
    ``__init__``. Confirm before relying on this method.

    Args:
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``.
        resume_optimizer (bool): Whether to resume optimizer state.
            Defaults to True.
        resume_param_scheduler (bool): Whether to resume param scheduler
            state. Defaults to True.
        map_location (str or callable): How to remap storage locations;
            'default' means the device returned by ``get_device()``.
            Defaults to 'default'.

    Raises:
        RuntimeError: If the checkpoint's ``gpu_ids`` count differs from
            ``self._world_size`` and ``auto_scale_lr`` is not enabled.
    """
    target = get_device() if map_location == 'default' else map_location
    checkpoint = self.load_checkpoint(filename, map_location=target)
    meta = checkpoint['meta']

    self.train_loop._epoch = meta['epoch']
    self.train_loop._iter = meta['iter']

    # check whether the number of GPU used for current experiment
    # is consistent with resuming from checkpoint
    if 'config' in meta:
        config = mmengine.Config.fromstring(
            meta['config'], file_format='.py')
        previous_gpu_ids = config.get('gpu_ids', None)
        gpu_count_changed = (
            previous_gpu_ids is not None and len(previous_gpu_ids) > 0
            and len(previous_gpu_ids) != self._world_size)
        if gpu_count_changed:
            # TODO, should we modify the iteration?
            self.logger.info(
                'Number of GPU used for current experiment is not '
                'consistent with resuming from checkpoint')
            if (self.auto_scale_lr is None
                    or not self.auto_scale_lr.get('enable', False)):
                raise RuntimeError(
                    'Cannot automatically rescale lr in resuming. Please '
                    'make sure the number of GPU is consistent with the '
                    'previous training state resuming from the checkpoint '
                    'or set `enable` in `auto_scale_lr to False.')

    # resume random seed: the seed stored in the checkpoint wins
    resumed_seed = meta.get('seed', None)
    current_seed = self._randomness_cfg.get('seed')
    if resumed_seed is not None and resumed_seed != current_seed:
        if current_seed is not None:
            print_log(
                f'The value of random seed in the '
                f'checkpoint "{resumed_seed}" is '
                f'different from the value in '
                f'`randomness` config "{current_seed}"',
                logger='current',
                level=logging.WARNING)
        self._randomness_cfg.update(seed=resumed_seed)
        self.set_randomness(**self._randomness_cfg)

    resumed_dataset_meta = meta.get('dataset_meta', None)
    dataset_meta = getattr(self.train_dataloader.dataset, 'metainfo', None)

    # `resumed_dataset_meta` and `dataset_meta` could be object like
    # np.ndarray, which cannot be directly judged as equal or not,
    # therefore we just compared their dumped results.
    if pickle.dumps(resumed_dataset_meta) != pickle.dumps(dataset_meta):
        print_log(
            'The dataset metainfo from the resumed checkpoint is '
            'different from the current training dataset, please '
            'check the correctness of the checkpoint or the training '
            'dataset.',
            logger='current',
            level=logging.WARNING)

    self.message_hub.load_state_dict(checkpoint['message_hub'])

    # resume optimizer
    if 'optimizer' in checkpoint and resume_optimizer:
        self.optim_wrapper = self.build_optim_wrapper(self.optim_wrapper)
        self.optim_wrapper.load_state_dict(  # type: ignore
            checkpoint['optimizer'])

    # resume param scheduler
    if resume_param_scheduler and self.param_schedulers is None:
        print_log(
            '`resume_param_scheduler` is True but `self.param_schedulers` '
            'is None, so skip resuming parameter schedulers',
            logger='current',
            level=logging.WARNING)
        resume_param_scheduler = False
    if 'param_schedulers' in checkpoint and resume_param_scheduler:
        self.param_schedulers = self.build_param_scheduler(  # type: ignore
            self.param_schedulers)  # type: ignore
        saved_states = checkpoint['param_schedulers']
        if isinstance(self.param_schedulers, dict):
            # One list of schedulers per name; pair each with its state.
            for name, schedulers in self.param_schedulers.items():
                for scheduler, state in zip(schedulers, saved_states[name]):
                    scheduler.load_state_dict(state)
        else:
            for scheduler, state in zip(
                    self.param_schedulers,  # type: ignore
                    saved_states):
                scheduler.load_state_dict(state)

    self._has_loaded = True

    self.logger.info(f'resumed epoch: {self.epoch}, iter: {self.iter}')
|
649 |
+
|
650 |
+
# def load_checkpoint(self,
|
651 |
+
# filename: str,
|
652 |
+
# model,
|
653 |
+
# map_location: Union[str, Callable] = 'cpu',
|
654 |
+
# strict: bool = False,
|
655 |
+
# revise_keys: list = [(r'^module.', '')]):
|
656 |
+
# """Load checkpoint from given ``filename``.
|
657 |
+
#
|
658 |
+
# Args:
|
659 |
+
# filename (str): Accept local filepath, URL, ``torchvision://xxx``,
|
660 |
+
# ``open-mmlab://xxx``.
|
661 |
+
# map_location (str or callable): A string or a callable function to
|
662 |
+
# specifying how to remap storage locations.
|
663 |
+
# Defaults to 'cpu'.
|
664 |
+
# strict (bool): strict (bool): Whether to allow different params for
|
665 |
+
# the model and checkpoint.
|
666 |
+
# revise_keys (list): A list of customized keywords to modify the
|
667 |
+
# state_dict in checkpoint. Each item is a (pattern, replacement)
|
668 |
+
# pair of the regular expression operations. Defaults to strip
|
669 |
+
# the prefix 'module.' by [(r'^module\\.', '')].
|
670 |
+
# """
|
671 |
+
# checkpoint = _load_checkpoint(filename, map_location=map_location)
|
672 |
+
#
|
673 |
+
# if is_model_wrapper(model):
|
674 |
+
# model = model.module
|
675 |
+
# else:
|
676 |
+
# model = model
|
677 |
+
#
|
678 |
+
# checkpoint = _load_checkpoint_to_model(
|
679 |
+
# model, checkpoint, strict, revise_keys=revise_keys)
|
680 |
+
#
|
681 |
+
# print(f'Load checkpoint from {filename}')
|
682 |
+
#
|
683 |
+
# return checkpoint
|
684 |
+
def load_checkpoint(self, model, file):
    """Load weights from a checkpoint file into ``model`` (non-strict).

    The file is read with ``torch.load`` onto the CPU and must contain a
    ``'state_dict'`` entry.

    Args:
        model: Object with a ``load_state_dict(state_dict, strict=...)``
            method returning ``(missing_keys, unexpected_keys)``.
        file (str or dict): A checkpoint path, or a dict with ``file_path``
            and an optional ``delete_keys`` list of state-dict keys to drop
            before loading. Keys absent from the checkpoint are ignored.

    Raises:
        TypeError: If ``file`` is neither a str nor a dict.
    """
    if isinstance(file, str):
        file_path, delete_keys = file, ()
    elif isinstance(file, dict):
        file_path = file['file_path']
        # `delete_keys` is optional; a missing or None value drops nothing.
        delete_keys = file.get('delete_keys') or ()
    else:
        raise TypeError('file must be str or dict')

    state_dict = torch.load(file_path, map_location='cpu')['state_dict']
    for delete_key in delete_keys:
        # Dropping a key that is not there is harmless, so do not fail.
        state_dict.pop(delete_key, None)

    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    print('load from:', file_path)
    print('load model missing_keys:', missing_keys)
    print('load model unexpected_keys:', unexpected_keys)
|
700 |
+
|
701 |
+
@master_only
def save_checkpoint(
    self,
    out_dir: str,
    filename: str,
    file_client_args: Optional[dict] = None,
    save_optimizer: bool = True,
    save_param_scheduler: bool = True,
    meta: dict = None,
    by_epoch: bool = True,
    backend_args: Optional[dict] = None,
):
    """Assemble a checkpoint dict and write it to ``out_dir/filename``.

    NOTE(review): reads ``self.epoch``, ``self.iter``, ``self.seed``,
    ``self.experiment_name``, ``self.train_dataloader``,
    ``self.message_hub``, ``self.optim_wrapper``, ``self.param_schedulers``
    and ``self.call_hook``, none of which are set in the visible
    ``__init__`` -- confirm they exist before relying on this method.

    Args:
        out_dir (str): The directory that checkpoints are saved.
        filename (str): The checkpoint filename.
        file_client_args (dict, optional): Arguments to instantiate a
            FileClient. Deprecated in favour of ``backend_args``; a
            DeprecationWarning is issued when given. Defaults to None.
        save_optimizer (bool): Whether to save the optimizer to
            the checkpoint. Defaults to True.
        save_param_scheduler (bool): Whether to save the param_scheduler
            to the checkpoint. Defaults to True.
        meta (dict, optional): Extra meta information; the given dict is
            updated in place. Defaults to None.
        by_epoch (bool): If True the stored epoch is ``self.epoch + 1``;
            otherwise the stored iter is ``self.iter + 1``.
            Defaults to True.
        backend_args (dict, optional): Arguments for the storage backend
            used to join the path. Defaults to None.

    Raises:
        TypeError: If ``meta`` is not a dict or None, or if
            ``save_optimizer`` is set and ``self.optim_wrapper`` is not an
            ``OptimWrapper``.
        ValueError: If both ``file_client_args`` and ``backend_args`` are
            given.
    """
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise TypeError(
            f'meta should be a dict or None, but got {type(meta)}')

    # self.epoch increments 1 after
    # `self.call_hook('after_train_epoch)` but `save_checkpoint` is
    # called by `after_train_epoch`` method of `CheckpointHook` so
    # `epoch` should be `self.epoch + 1`
    if by_epoch:
        meta.update(epoch=self.epoch + 1, iter=self.iter)
    else:
        meta.update(epoch=self.epoch, iter=self.iter + 1)

    if file_client_args is None:
        filepath = join_path(  # type: ignore
            out_dir, filename, backend_args=backend_args)
    else:
        warnings.warn(
            '"file_client_args" will be deprecated in future. '
            'Please use "backend_args" instead', DeprecationWarning)
        if backend_args is not None:
            raise ValueError(
                '"file_client_args" and "backend_args" cannot be set at '
                'the same time.')

        file_client = FileClient.infer_client(file_client_args, out_dir)
        filepath = file_client.join_path(out_dir, filename)

    meta.update(
        cfg=self.cfg.pretty_text,
        seed=self.seed,
        experiment_name=self.experiment_name,
        time=time.strftime('%Y%m%d_%H%M%S', time.localtime()),
        mmengine_version=mmengine.__version__ + get_git_hash())

    train_dataset = self.train_dataloader.dataset
    if hasattr(train_dataset, 'metainfo'):
        meta.update(dataset_meta=train_dataset.metainfo)

    # Save the wrapped model's weights, not the wrapper's.
    model = self.model.module if is_model_wrapper(self.model) else self.model

    checkpoint = {
        'meta': meta,
        'state_dict': weights_to_cpu(get_state_dict(model)),
        'message_hub': self.message_hub.state_dict()
    }

    # save optimizer state dict to checkpoint
    if save_optimizer:
        if not isinstance(self.optim_wrapper, OptimWrapper):
            raise TypeError(
                'self.optim_wrapper should be an `OptimWrapper` '
                'or `OptimWrapperDict` instance, but got '
                f'{self.optim_wrapper}')
        checkpoint['optimizer'] = self.optim_wrapper.state_dict()

    # save param scheduler state dict
    if save_param_scheduler and self.param_schedulers is None:
        print_log(
            '`save_param_scheduler` is True but `self.param_schedulers` '
            'is None, so skip saving parameter schedulers',
            logger='current',
            level=logging.WARNING)
        save_param_scheduler = False
    if save_param_scheduler:
        if isinstance(self.param_schedulers, dict):
            checkpoint['param_schedulers'] = {
                name: [scheduler.state_dict() for scheduler in schedulers]
                for name, schedulers in self.param_schedulers.items()
            }
        else:
            checkpoint['param_schedulers'] = [
                scheduler.state_dict()  # type: ignore
                for scheduler in self.param_schedulers  # type: ignore
            ]

    self.call_hook('before_save_checkpoint', checkpoint=checkpoint)
    save_checkpoint(checkpoint, filepath)
|
821 |
+
|
822 |
+
@master_only
def dump_config(self) -> None:
    """Dump the config into the trainer's ``default_root_dir``.

    The file is written to ``<default_root_dir>/<version>_<filename>``:

    - ``<version>`` comes from the first trainer logger. A string version
      is used verbatim; any other value is rendered as
      ``version_<value>``. When there is no logger, or the logger reports
      an empty string, the literal ``'version'`` is used instead.
    - ``<filename>`` is the basename of ``self.cfg.filename``, or
      ``<self.timestamp>.py`` when the config has no backing file.
    """
    loggers = self.trainer.loggers
    version = loggers[0].version if len(loggers) > 0 else ''
    if not isinstance(version, str):
        version = f'version_{version}'
    if version == '':
        # No usable logger version: fall back to a fixed prefix so the
        # dumped file still gets a well-formed name.
        version = 'version'

    if self.cfg.filename is not None:
        filename = osp.basename(self.cfg.filename)
    else:
        filename = f'{self.timestamp}.py'
    path = f'{self.trainer.default_root_dir}/{version}_{filename}'

    self.cfg.dump(path)
|
840 |
+
|
841 |
+
def _check_scheduler_cfg(
        self, param_scheduler: Optional[Union[dict, list,
                                              _ParamScheduler]]) -> None:
    """Validate the structure of a ``param_scheduler`` config.

    This method only checks the config; it does not build any scheduler
    and returns nothing.

    If only one optimizer is used, the config should be a
    list of parameter scheduler configs or instances. If multiple
    optimizers are used, the config should be a `dict`.
    Its key should be consistent with the optimizer `dict` and its value
    should be a list of parameter scheduler configs or instances. See
    :meth:`build_param_scheduler` for more details.

    Examples:
        >>> # valid scheduler:
        >>> # empty scheduler
        >>> scheduler = None
        >>> # Single scheduler
        >>> scheduler = dict(type='MultiStepLR', milestones=[1, 2])
        >>> # Single list schedulers
        >>> scheduler = [dict(type='MultiStepLR', milestones=[1, 2]),
        >>>              dict(type='MultiStepLR', milestones=[2, 3])]
        >>> # `dict` of schedulers
        >>> scheduler = dict(linear1=dict(type='MultiStepLR', milestones=[1, 2]),
        >>>                  linear2=dict(type='MultiStepLR', milestones=[1, 2]))
        >>> # `dict` of `list` of schedulers
        >>> scheduler = dict(linear1=[dict(type='MultiStepLR', milestones=[1, 2])],
        >>>                  linear2=[dict(type='MultiStepLR', milestones=[1, 2])])
        >>> # Single built scheduler
        >>> from mmengine.optim import MultiStepLR
        >>> scheduler = MultiStepLR(milestones=[1, 2], optimizer=optimizer)
        >>> # Single built list schedulers
        >>> scheduler = [MultiStepLR(milestones=[1, 2], optimizer=optimizer)]
        >>> # dict of built scheduler
        >>> scheduler = dict(linear1=MultiStepLR(milestones=[1, 2], optimizer=optimizer),
        >>>                  linear2=MultiStepLR(milestones=[1, 2], optimizer=optimizer))
        >>> # dict of built list schedulers
        >>> scheduler = dict(linear1=[MultiStepLR(milestones=[1, 2], optimizer=optimizer)],
        >>>                  linear2=[MultiStepLR(milestones=[1, 2], optimizer=optimizer)])

    Args:
        param_scheduler (dict or list): The original parameter scheduler.

    Raises:
        AssertionError: If an entry of a sequence of dict configs lacks
            the key ``type``, or if a value of a ``dict`` without the
            key ``type`` is not a dict, tuple, list or
            ``_ParamScheduler``.
        TypeError: If ``param_scheduler`` matches none of the accepted
            forms.
    """  # noqa: E501
    # ``None``, a built scheduler, or a sequence of built schedulers need
    # no further inspection.
    if param_scheduler is None:
        return
    if isinstance(param_scheduler, _ParamScheduler):
        return
    if is_seq_of(param_scheduler, _ParamScheduler):
        return

    if is_seq_of(param_scheduler, dict):
        # Sequence of scheduler configs for a single optimizer.
        for _param_scheduler in param_scheduler:
            assert 'type' in _param_scheduler, (
                'Each parameter scheduler should contain the key type, '
                f'but got {_param_scheduler}')
    elif isinstance(param_scheduler, dict):
        # A dict carrying ``type`` is a single scheduler config and is
        # accepted as is; otherwise every value is checked.
        if 'type' not in param_scheduler:
            for _param_scheduler in param_scheduler.values():
                # Report the type of the offending value itself, not the
                # type of the ``_ParamScheduler`` class.
                assert isinstance(
                    _param_scheduler,
                    (dict, tuple, list, _ParamScheduler)), (
                        'Each value of `param_scheduler` should be a '
                        f'dict or a list, but got {_param_scheduler} with '
                        f'type {type(_param_scheduler)}')
    else:
        raise TypeError(
            '`param_scheduler` should be a `_ParamScheduler`, `dict`, '
            f'list or a tuple, but got {type(param_scheduler)}. If '
            '`param_scheduler` is a list of dict, it means a list of '
            'scheduler configs for single optimizer. If it is a dict and '
            'contains key `type`, it means a scheduler config for a '
            'single optimizer. If it does not contain key `type`, it '
            'means multiple lists of schedulers for multiple optimizers.')
|
916 |
+
|
917 |
+
def _log_env(self, env_cfg: dict) -> None:
    """Log the system environment, the runtime environment and the config.

    Two messages are sent to ``self.logger`` at INFO level: a report
    framed by dashed lines with a "System environment" and a
    "Runtime environment" section, followed by the pretty-printed config.

    Args:
        env_cfg (dict): The environment config of the runner; its items
            open the runtime section of the report.
    """

    def _as_lines(mapping) -> str:
        # One ``key: value`` entry per line, each preceded by a newline.
        return '\n ' + '\n '.join(
            f'{name}: {value}' for name, value in mapping.items())

    system_env = collect_env()

    # Later entries override earlier ones with the same key.
    runtime = OrderedDict()
    runtime.update(env_cfg)
    runtime.update(self._randomness_cfg)
    runtime['Distributed launcher'] = self._launcher
    runtime['Distributed training'] = self._distributed
    runtime['GPU number'] = self._world_size

    system_section = _as_lines(system_env)
    runtime_section = _as_lines(runtime)
    rule = '-' * 60
    report = (f'\n{rule}'
              f'\nSystem environment:{system_section}\n'
              f'\nRuntime environment:{runtime_section}\n'
              f'{rule}\n')
    self.logger.info(report)
    self.logger.info(f'Config:\n{self.cfg.pretty_text}')
|
mmpl/engine/strategies/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .builder import PL_MODEL_WRAPPERS
|
mmpl/engine/strategies/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (209 Bytes). View file
|
|