KyanChen committed
Commit 1c3eb47 · 1 Parent(s): 3e06e1c

Upload 159 files

This view is limited to 50 files because it contains too many changes.

Files changed (50)
  1. mmpl/__init__.py +0 -0
  2. mmpl/__pycache__/__init__.cpython-310.pyc +0 -0
  3. mmpl/__pycache__/registry.cpython-310.pyc +0 -0
  4. mmpl/datasets/__init__.py +9 -0
  5. mmpl/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
  6. mmpl/datasets/__pycache__/builder.cpython-310.pyc +0 -0
  7. mmpl/datasets/__pycache__/nwpu_ins_dataset.cpython-310.pyc +0 -0
  8. mmpl/datasets/__pycache__/pl_datamodule.cpython-310.pyc +0 -0
  9. mmpl/datasets/__pycache__/ssdd_ins_dataset.cpython-310.pyc +0 -0
  10. mmpl/datasets/__pycache__/whu_ins_dataset.cpython-310.pyc +0 -0
  11. mmpl/datasets/base_dataset.py +212 -0
  12. mmpl/datasets/builder.py +25 -0
  13. mmpl/datasets/custom.py +237 -0
  14. mmpl/datasets/nwpu_ins_dataset.py +59 -0
  15. mmpl/datasets/pl_datamodule.py +73 -0
  16. mmpl/datasets/ssdd_ins_dataset.py +54 -0
  17. mmpl/datasets/transforms/__init__.py +0 -0
  18. mmpl/datasets/transforms/__pycache__/__init__.cpython-310.pyc +0 -0
  19. mmpl/datasets/utils.py +243 -0
  20. mmpl/datasets/whu_ins_dataset.py +54 -0
  21. mmpl/engine/__init__.py +5 -0
  22. mmpl/engine/__pycache__/__init__.cpython-310.pyc +0 -0
  23. mmpl/engine/hooks/__init__.py +6 -0
  24. mmpl/engine/hooks/__pycache__/__init__.cpython-310.pyc +0 -0
  25. mmpl/engine/hooks/__pycache__/builder.cpython-310.pyc +0 -0
  26. mmpl/engine/hooks/__pycache__/ema_hook.cpython-310.pyc +0 -0
  27. mmpl/engine/hooks/__pycache__/param_scheduler_hook.cpython-310.pyc +0 -0
  28. mmpl/engine/hooks/__pycache__/pipeline_switch_hook.cpython-310.pyc +0 -0
  29. mmpl/engine/hooks/__pycache__/visualization_hook.cpython-310.pyc +0 -0
  30. mmpl/engine/hooks/__pycache__/yolov5_param_scheduler_hook.cpython-310.pyc +0 -0
  31. mmpl/engine/hooks/builder.py +31 -0
  32. mmpl/engine/hooks/ema_hook.py +240 -0
  33. mmpl/engine/hooks/param_scheduler_hook.py +128 -0
  34. mmpl/engine/hooks/pipeline_switch_hook.py +41 -0
  35. mmpl/engine/hooks/ppyoloe_param_scheduler_hook.py +96 -0
  36. mmpl/engine/hooks/switch_to_deploy_hook.py +21 -0
  37. mmpl/engine/hooks/visualization_hook.py +199 -0
  38. mmpl/engine/hooks/yolov5_param_scheduler_hook.py +111 -0
  39. mmpl/engine/hooks/yolox_mode_switch_hook.py +54 -0
  40. mmpl/engine/logger/__init__.py +1 -0
  41. mmpl/engine/logger/__pycache__/__init__.cpython-310.pyc +0 -0
  42. mmpl/engine/logger/__pycache__/builder.cpython-310.pyc +0 -0
  43. mmpl/engine/logger/builder.py +112 -0
  44. mmpl/engine/optimizers/__init__.py +0 -0
  45. mmpl/engine/runner/__init__.py +3 -0
  46. mmpl/engine/runner/__pycache__/__init__.cpython-310.pyc +0 -0
  47. mmpl/engine/runner/__pycache__/pl_runner.cpython-310.pyc +0 -0
  48. mmpl/engine/runner/pl_runner.py +941 -0
  49. mmpl/engine/strategies/__init__.py +1 -0
  50. mmpl/engine/strategies/__pycache__/__init__.cpython-310.pyc +0 -0
mmpl/__init__.py ADDED
File without changes
mmpl/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (142 Bytes).
mmpl/__pycache__/registry.cpython-310.pyc ADDED
Binary file (2.1 kB).
mmpl/datasets/__init__.py ADDED
@@ -0,0 +1,9 @@
+ from .builder import build_dataset
+ from .pl_datamodule import PLDataModule
+ from .nwpu_ins_dataset import NWPUInsSegDataset
+ from .whu_ins_dataset import WHUInsSegDataset
+ from .ssdd_ins_dataset import SSDDInsSegDataset
+
+ __all__ = [
+     'build_dataset', 'PLDataModule',
+ ]
mmpl/datasets/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (449 Bytes).
mmpl/datasets/__pycache__/builder.cpython-310.pyc ADDED
Binary file (1.01 kB).
mmpl/datasets/__pycache__/nwpu_ins_dataset.cpython-310.pyc ADDED
Binary file (2.07 kB).
mmpl/datasets/__pycache__/pl_datamodule.cpython-310.pyc ADDED
Binary file (2.58 kB).
mmpl/datasets/__pycache__/ssdd_ins_dataset.cpython-310.pyc ADDED
Binary file (1.76 kB).
mmpl/datasets/__pycache__/whu_ins_dataset.cpython-310.pyc ADDED
Binary file (1.76 kB).
mmpl/datasets/base_dataset.py ADDED
@@ -0,0 +1,212 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import os.path as osp
+ from os import PathLike
+ from typing import List, Optional, Sequence, Union
+
+ import mmengine
+ import numpy as np
+ from mmengine.dataset import BaseDataset as _BaseDataset
+
+ from .builder import DATASETS
+
+
+ def expanduser(path):
+     """Expand ~ and ~user constructions.
+
+     If user or $HOME is unknown, do nothing.
+     """
+     if isinstance(path, (str, PathLike)):
+         return osp.expanduser(path)
+     else:
+         return path
+
+
+ @DATASETS.register_module()
+ class BaseDataset(_BaseDataset):
+     """Base dataset for image classification task.
+
+     This dataset support annotation file in `OpenMMLab 2.0 style annotation
+     format`.
+
+     .. _OpenMMLab 2.0 style annotation format:
+         https://github.com/open-mmlab/mmengine/blob/main/docs/zh_cn/tutorials/basedataset.md
+
+     Comparing with the :class:`mmengine.BaseDataset`, this class implemented
+     several useful methods.
+
+     Args:
+         ann_file (str): Annotation file path.
+         metainfo (dict, optional): Meta information for dataset, such as class
+             information. Defaults to None.
+         data_root (str): The root directory for ``data_prefix`` and
+             ``ann_file``. Defaults to ''.
+         data_prefix (str | dict): Prefix for training data. Defaults to ''.
+         filter_cfg (dict, optional): Config for filter data. Defaults to None.
+         indices (int or Sequence[int], optional): Support using first few
+             data in annotation file to facilitate training/testing on a smaller
+             dataset. Defaults to None, which means using all ``data_infos``.
+         serialize_data (bool): Whether to hold memory using serialized objects,
+             when enabled, data loader workers can use shared RAM from master
+             process instead of making a copy. Defaults to True.
+         pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
+         test_mode (bool): ``test_mode=True`` means in test phase.
+             Defaults to False.
+         lazy_init (bool): Whether to load annotation during instantiation.
+             In some cases, such as visualization, only the meta information of
+             the dataset is needed, which is not necessary to load annotation
+             file. ``Basedataset`` can skip load annotations to save time by set
+             ``lazy_init=False``. Defaults to False.
+         max_refetch (int): If ``Basedataset.prepare_data`` get a None img.
+             The maximum extra number of cycles to get a valid image.
+             Defaults to 1000.
+         classes (str | Sequence[str], optional): Specify names of classes.
+
+             - If is string, it should be a file path, and the every line of
+               the file is a name of a class.
+             - If is a sequence of string, every item is a name of class.
+             - If is None, use categories information in ``metainfo`` argument,
+               annotation file or the class attribute ``METAINFO``.
+
+             Defaults to None.
+     """  # noqa: E501
+
+     def __init__(self,
+                  ann_file: str = '',
+                  metainfo: Optional[dict] = None,
+                  data_root: str = '',
+                  data_prefix: Union[str, dict] = '',
+                  filter_cfg: Optional[dict] = None,
+                  indices: Optional[Union[int, Sequence[int]]] = None,
+                  serialize_data: bool = True,
+                  pipeline: Sequence = (),
+                  test_mode: bool = False,
+                  lazy_init: bool = False,
+                  max_refetch: int = 1000,
+                  classes: Union[str, Sequence[str], None] = None):
+         if isinstance(data_prefix, str):
+             data_prefix = dict(img_path=expanduser(data_prefix))
+
+         ann_file = expanduser(ann_file)
+         metainfo = self._compat_classes(metainfo, classes)
+
+         super().__init__(
+             ann_file=ann_file,
+             metainfo=metainfo,
+             data_root=data_root,
+             data_prefix=data_prefix,
+             filter_cfg=filter_cfg,
+             indices=indices,
+             serialize_data=serialize_data,
+             pipeline=pipeline,
+             test_mode=test_mode,
+             lazy_init=lazy_init,
+             max_refetch=max_refetch)
+
+     @property
+     def img_prefix(self):
+         """The prefix of images."""
+         return self.data_prefix['img_path']
+
+     @property
+     def CLASSES(self):
+         """Return all categories names."""
+         return self._metainfo.get('classes', None)
+
+     @property
+     def class_to_idx(self):
+         """Map mapping class name to class index.
+
+         Returns:
+             dict: mapping from class name to class index.
+         """
+
+         return {cat: i for i, cat in enumerate(self.CLASSES)}
+
+     def get_gt_labels(self):
+         """Get all ground-truth labels (categories).
+
+         Returns:
+             np.ndarray: categories for all images.
+         """
+
+         gt_labels = np.array(
+             [self.get_data_info(i)['gt_label'] for i in range(len(self))])
+         return gt_labels
+
+     def get_cat_ids(self, idx: int) -> List[int]:
+         """Get category id by index.
+
+         Args:
+             idx (int): Index of data.
+
+         Returns:
+             cat_ids (List[int]): Image category of specified index.
+         """
+
+         return [int(self.get_data_info(idx)['gt_label'])]
+
+     def _compat_classes(self, metainfo, classes):
+         """Merge the old style ``classes`` arguments to ``metainfo``."""
+         if isinstance(classes, str):
+             # take it as a file path
+             class_names = mmengine.list_from_file(expanduser(classes))
+         elif isinstance(classes, (tuple, list)):
+             class_names = classes
+         elif classes is not None:
+             raise ValueError(f'Unsupported type {type(classes)} of classes.')
+
+         if metainfo is None:
+             metainfo = {}
+
+         if classes is not None:
+             metainfo = {'classes': tuple(class_names), **metainfo}
+
+         return metainfo
+
+     def full_init(self):
+         """Load annotation file and set ``BaseDataset._fully_initialized`` to
+         True."""
+         super().full_init()
+
+         # To support the standard OpenMMLab 2.0 annotation format. Generate
+         # metainfo in internal format from standard metainfo format.
+         if 'categories' in self._metainfo and 'classes' not in self._metainfo:
+             categories = sorted(
+                 self._metainfo['categories'], key=lambda x: x['id'])
+             self._metainfo['classes'] = tuple(
+                 [cat['category_name'] for cat in categories])
+
+     def __repr__(self):
+         """Print the basic information of the dataset.
+
+         Returns:
+             str: Formatted string.
+         """
+         head = 'Dataset ' + self.__class__.__name__
+         body = []
+         if self._fully_initialized:
+             body.append(f'Number of samples: \t{self.__len__()}')
+         else:
+             body.append("Haven't been initialized")
+
+         if self.CLASSES is not None:
+             body.append(f'Number of categories: \t{len(self.CLASSES)}')
+         else:
+             body.append('The `CLASSES` meta info is not set.')
+
+         body.extend(self.extra_repr())
+
+         if len(self.pipeline.transforms) > 0:
+             body.append('With transforms:')
+             for t in self.pipeline.transforms:
+                 body.append(f'    {t}')
+
+         lines = [head] + [' ' * 4 + line for line in body]
+         return '\n'.join(lines)
+
+     def extra_repr(self) -> List[str]:
+         """The extra repr information of the dataset."""
+         body = []
+         body.append(f'Annotation file: \t{self.ann_file}')
+         body.append(f'Prefix of images: \t{self.img_prefix}')
+         return body
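A brief illustration (not part of the commit) of the ``classes`` compatibility argument handled by ``_compat_classes`` above; the annotation path is a placeholder and ``lazy_init=True`` keeps the sketch from reading any annotation file:

from mmpl.datasets.base_dataset import BaseDataset

# classes given as an explicit sequence of names
ds = BaseDataset(ann_file='ann.json',        # placeholder path, never read here
                 classes=['cat', 'dog'],
                 lazy_init=True)
print(ds.CLASSES)    # ('cat', 'dog'), folded into metainfo['classes']

# classes may also be a text file with one class name per line, e.g.
# BaseDataset(ann_file='ann.json', classes='classes.txt', lazy_init=True)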
mmpl/datasets/builder.py ADDED
@@ -0,0 +1,25 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ from mmpl.registry import DATASETS
+
+
+ def build_dataset(cfg):
+     """Build dataset.
+
+     Examples:
+         >>> from mmpl.datasets import build_dataset
+         >>> mnist_train = build_dataset(
+         ...     dict(type='MNIST', data_prefix='data/mnist/', test_mode=False))
+         >>> print(mnist_train)
+         Dataset MNIST
+             Number of samples:  60000
+             Number of categories:  10
+             Prefix of data:  data/mnist/
+         >>> mnist_test = build_dataset(
+         ...     dict(type='MNIST', data_prefix='data/mnist/', test_mode=True))
+         >>> print(mnist_test)
+         Dataset MNIST
+             Number of samples:  10000
+             Number of categories:  10
+             Prefix of data:  data/mnist/
+     """
+     return DATASETS.build(cfg)
mmpl/datasets/custom.py ADDED
@@ -0,0 +1,237 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
+
+ from mmengine.fileio import (BaseStorageBackend, get_file_backend,
+                              list_from_file)
+ from mmengine.logging import MMLogger
+
+ from mmcls.registry import DATASETS
+ from .base_dataset import BaseDataset
+
+
+ def find_folders(
+     root: str,
+     backend: Optional[BaseStorageBackend] = None
+ ) -> Tuple[List[str], Dict[str, int]]:
+     """Find classes by folders under a root.
+
+     Args:
+         root (string): root directory of folders
+         backend (BaseStorageBackend | None): The file backend of the root.
+             If None, auto infer backend from the root path. Defaults to None.
+
+     Returns:
+         Tuple[List[str], Dict[str, int]]:
+
+         - folders: The name of sub folders under the root.
+         - folder_to_idx: The map from folder name to class idx.
+     """
+     # Pre-build file backend to prevent verbose file backend inference.
+     backend = backend or get_file_backend(root, enable_singleton=True)
+     folders = list(
+         backend.list_dir_or_file(
+             root,
+             list_dir=True,
+             list_file=False,
+             recursive=False,
+         ))
+     folders.sort()
+     folder_to_idx = {folders[i]: i for i in range(len(folders))}
+     return folders, folder_to_idx
+
+
+ def get_samples(
+     root: str,
+     folder_to_idx: Dict[str, int],
+     is_valid_file: Callable,
+     backend: Optional[BaseStorageBackend] = None,
+ ):
+     """Make dataset by walking all images under a root.
+
+     Args:
+         root (string): root directory of folders
+         folder_to_idx (dict): the map from class name to class idx
+         is_valid_file (Callable): A function that takes path of a file
+             and check if the file is a valid sample file.
+         backend (BaseStorageBackend | None): The file backend of the root.
+             If None, auto infer backend from the root path. Defaults to None.
+
+     Returns:
+         Tuple[list, set]:
+
+         - samples: a list of tuple where each element is (image, class_idx)
+         - empty_folders: The folders don't have any valid files.
+     """
+     samples = []
+     available_classes = set()
+     # Pre-build file backend to prevent verbose file backend inference.
+     backend = backend or get_file_backend(root, enable_singleton=True)
+
+     for folder_name in sorted(list(folder_to_idx.keys())):
+         _dir = backend.join_path(root, folder_name)
+         files = backend.list_dir_or_file(
+             _dir,
+             list_dir=False,
+             list_file=True,
+             recursive=True,
+         )
+         for file in sorted(list(files)):
+             if is_valid_file(file):
+                 path = backend.join_path(folder_name, file)
+                 item = (path, folder_to_idx[folder_name])
+                 samples.append(item)
+                 available_classes.add(folder_name)
+
+     empty_folders = set(folder_to_idx.keys()) - available_classes
+
+     return samples, empty_folders
+
+
+ @DATASETS.register_module()
+ class CustomDataset(BaseDataset):
+     """Custom dataset for classification.
+
+     The dataset supports two kinds of annotation format.
+
+     1. An annotation file is provided, and each line indicates a sample:
+
+        The sample files: ::
+
+            data_prefix/
+            ├── folder_1
+            │   ├── xxx.png
+            │   ├── xxy.png
+            │   └── ...
+            └── folder_2
+                ├── 123.png
+                ├── nsdf3.png
+                └── ...
+
+        The annotation file (the first column is the image path and the second
+        column is the index of category): ::
+
+            folder_1/xxx.png 0
+            folder_1/xxy.png 1
+            folder_2/123.png 5
+            folder_2/nsdf3.png 3
+            ...
+
+        Please specify the name of categories by the argument ``classes``
+        or ``metainfo``.
+
+     2. The samples are arranged in the specific way: ::
+
+            data_prefix/
+            ├── class_x
+            │   ├── xxx.png
+            │   ├── xxy.png
+            │   └── ...
+            │       └── xxz.png
+            └── class_y
+                ├── 123.png
+                ├── nsdf3.png
+                ├── ...
+                └── asd932_.png
+
+     If the ``ann_file`` is specified, the dataset will be generated by the
+     first way, otherwise, try the second way.
+
+     Args:
+         ann_file (str): Annotation file path. Defaults to ''.
+         metainfo (dict, optional): Meta information for dataset, such as class
+             information. Defaults to None.
+         data_root (str): The root directory for ``data_prefix`` and
+             ``ann_file``. Defaults to ''.
+         data_prefix (str | dict): Prefix for the data. Defaults to ''.
+         extensions (Sequence[str]): A sequence of allowed extensions. Defaults
+             to ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif').
+         lazy_init (bool): Whether to load annotation during instantiation.
+             In some cases, such as visualization, only the meta information of
+             the dataset is needed, which is not necessary to load annotation
+             file. ``Basedataset`` can skip load annotations to save time by set
+             ``lazy_init=False``. Defaults to False.
+         **kwargs: Other keyword arguments in :class:`BaseDataset`.
+     """
+
+     def __init__(self,
+                  ann_file: str = '',
+                  metainfo: Optional[dict] = None,
+                  data_root: str = '',
+                  data_prefix: Union[str, dict] = '',
+                  extensions: Sequence[str] = ('.jpg', '.jpeg', '.png', '.ppm',
+                                               '.bmp', '.pgm', '.tif'),
+                  lazy_init: bool = False,
+                  **kwargs):
+         assert (ann_file or data_prefix or data_root), \
+             'One of `ann_file`, `data_root` and `data_prefix` must '\
+             'be specified.'
+
+         self.extensions = tuple(set([i.lower() for i in extensions]))
+
+         super().__init__(
+             # The base class requires string ann_file but this class doesn't
+             ann_file=ann_file,
+             metainfo=metainfo,
+             data_root=data_root,
+             data_prefix=data_prefix,
+             # Force to lazy_init for some modification before loading data.
+             lazy_init=True,
+             **kwargs)
+
+         # Full initialize the dataset.
+         if not lazy_init:
+             self.full_init()
+
+     def _find_samples(self):
+         """find samples from ``data_prefix``."""
+         classes, folder_to_idx = find_folders(self.img_prefix)
+         samples, empty_classes = get_samples(
+             self.img_prefix,
+             folder_to_idx,
+             is_valid_file=self.is_valid_file,
+         )
+
+         if len(samples) == 0:
+             raise RuntimeError(
+                 f'Found 0 files in subfolders of: {self.data_prefix}. '
+                 f'Supported extensions are: {",".join(self.extensions)}')
+
+         if self.CLASSES is not None:
+             assert len(self.CLASSES) == len(classes), \
+                 f"The number of subfolders ({len(classes)}) doesn't match " \
+                 f'the number of specified classes ({len(self.CLASSES)}). ' \
+                 'Please check the data folder.'
+         else:
+             self._metainfo['classes'] = tuple(classes)
+
+         if empty_classes:
+             logger = MMLogger.get_current_instance()
+             logger.warning(
+                 'Found no valid file in the folder '
+                 f'{", ".join(empty_classes)}. '
+                 f"Supported extensions are: {', '.join(self.extensions)}")
+
+         self.folder_to_idx = folder_to_idx
+
+         return samples
+
+     def load_data_list(self):
+         """Load image paths and gt_labels."""
+         if not self.ann_file:
+             samples = self._find_samples()
+         else:
+             lines = list_from_file(self.ann_file)
+             samples = [x.strip().rsplit(' ', 1) for x in lines]
+
+         # Pre-build file backend to prevent verbose file backend inference.
+         backend = get_file_backend(self.img_prefix, enable_singleton=True)
+         data_list = []
+         for filename, gt_label in samples:
+             img_path = backend.join_path(self.img_prefix, filename)
+             info = {'img_path': img_path, 'gt_label': int(gt_label)}
+             data_list.append(info)
+         return data_list
+
+     def is_valid_file(self, filename: str) -> bool:
+         """Check if a file is a valid sample."""
+         return filename.lower().endswith(self.extensions)
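A minimal usage sketch for the ``CustomDataset`` added above, assuming the folder-based layout from the docstring (one sub-folder per class); the path is a placeholder and the pipeline is left empty for brevity:

from mmpl.datasets.custom import CustomDataset

dataset = CustomDataset(
    data_prefix='data/my_images/',   # assumed layout: class_x/, class_y/, ...
    pipeline=[],                     # transforms omitted in this sketch
)
print(dataset.CLASSES)               # tuple of sorted sub-folder names
print(len(dataset))                  # number of valid image files found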
mmpl/datasets/nwpu_ins_dataset.py ADDED
@@ -0,0 +1,59 @@
+ from typing import List
+
+ from mmpl.registry import DATASETS
+ from mmdet.datasets.coco import CocoDataset
+
+
+ @DATASETS.register_module()
+ class NWPUInsSegDataset(CocoDataset):
+     """NWPU instance segmentation dataset in COCO format."""
+
+     METAINFO = {
+         'classes': ['airplane', 'ship', 'storage_tank', 'baseball_diamond',
+                     'tennis_court', 'basketball_court', 'ground_track_field',
+                     'harbor', 'bridge', 'vehicle'],
+         'palette': [(220, 20, 60), (255, 0, 0), (0, 0, 142), (0, 0, 70),
+                     (0, 60, 100), (0, 80, 100), (0, 0, 230),
+                     (119, 11, 32), (0, 255, 0), (0, 0, 255)]
+     }
+
+     def filter_data(self) -> List[dict]:
+         """Filter annotations according to filter_cfg.
+
+         Returns:
+             List[dict]: Filtered results.
+         """
+         if self.test_mode:
+             return self.data_list
+
+         if self.filter_cfg is None:
+             return self.data_list
+
+         filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False)
+         min_size = self.filter_cfg.get('min_size', 0)
+
+         # obtain images that contain annotation
+         ids_with_ann = set(data_info['img_id'] for data_info in self.data_list)
+         # obtain images that contain annotations of the required categories
+         ids_in_cat = set()
+         for i, class_id in enumerate(self.cat_ids):
+             ids_in_cat |= set(self.cat_img_map[class_id])
+         # merge the image id sets of the two conditions and use the merged set
+         # to filter out images if self.filter_empty_gt=True
+         ids_in_cat &= ids_with_ann
+
+         valid_data_infos = []
+         for i, data_info in enumerate(self.data_list):
+             img_id = data_info['img_id']
+             width = data_info['width']
+             height = data_info['height']
+             all_is_crowd = all([
+                 instance['ignore_flag'] == 1
+                 for instance in data_info['instances']
+             ])
+             if filter_empty_gt and (img_id not in ids_in_cat or all_is_crowd):
+                 continue
+             if min(width, height) >= min_size:
+                 valid_data_infos.append(data_info)
+
+         return valid_data_infos
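For reference, a hedged sketch of building the dataset above through the registry, assuming COCO-style annotations under a local data directory (paths and filter thresholds are illustrative only):

from mmpl.datasets import build_dataset

nwpu_train = build_dataset(dict(
    type='NWPUInsSegDataset',
    data_root='data/NWPU/',                       # assumed local root
    ann_file='annotations/instances_train.json',  # assumed annotation file
    data_prefix=dict(img='images/'),
    filter_cfg=dict(filter_empty_gt=True, min_size=32),
    pipeline=[],                                  # transforms omitted
))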
mmpl/datasets/pl_datamodule.py ADDED
@@ -0,0 +1,73 @@
+ from mmpl.registry import DATASETS
+ import lightning.pytorch as pl
+ from torch.utils.data import DataLoader
+ from .builder import build_dataset
+ from mmengine.registry import FUNCTIONS
+ from functools import partial
+
+
+ def get_collate_fn(dataloader_cfg):
+     collate_fn_cfg = dataloader_cfg.pop('collate_fn', dict(type='pseudo_collate'))
+     collate_fn_type = collate_fn_cfg.pop('type')
+     collate_fn = FUNCTIONS.get(collate_fn_type)
+     collate_fn = partial(collate_fn, **collate_fn_cfg)  # type: ignore
+     return collate_fn
+
+
+ @DATASETS.register_module()
+ class PLDataModule(pl.LightningDataModule):
+     def __init__(self,
+                  train_loader=None,
+                  val_loader=None,
+                  test_loader=None,
+                  predict_loader=None,
+                  **kwargs
+                  ):
+         super().__init__()
+         self.train_loader = train_loader
+         self.val_loader = val_loader
+         self.test_loader = test_loader
+         self.predict_loader = predict_loader
+         self.train_dataset = None
+         self.val_dataset = None
+         self.test_dataset = None
+         self.predict_dataset = None
+
+     def prepare_data(self):
+         pass
+
+     def setup(self, stage: str):
+         if stage == "fit":
+             dataset_cfg = self.train_loader.pop('dataset')
+             self.train_dataset = build_dataset(dataset_cfg)
+             if self.val_loader is not None:
+                 dataset_cfg = self.val_loader.pop('dataset')
+                 self.val_dataset = build_dataset(dataset_cfg)
+         if stage == "val":
+             if self.val_loader is not None:
+                 dataset_cfg = self.val_loader.pop('dataset')
+                 self.val_dataset = build_dataset(dataset_cfg)
+         if stage == "test":
+             if self.test_loader is not None:
+                 dataset_cfg = self.test_loader.pop('dataset')
+                 self.test_dataset = build_dataset(dataset_cfg)
+         if stage == "predict":
+             if self.predict_loader is not None:
+                 dataset_cfg = self.predict_loader.pop('dataset')
+                 self.predict_dataset = build_dataset(dataset_cfg)
+
+     def train_dataloader(self):
+         collate_fn = get_collate_fn(self.train_loader)
+         return DataLoader(self.train_dataset, collate_fn=collate_fn, **self.train_loader)
+
+     def val_dataloader(self):
+         collate_fn = get_collate_fn(self.val_loader)
+         return DataLoader(self.val_dataset, collate_fn=collate_fn, **self.val_loader)
+
+     def test_dataloader(self):
+         collate_fn = get_collate_fn(self.test_loader)
+         return DataLoader(self.test_dataset, collate_fn=collate_fn, **self.test_loader)
+
+     def predict_dataloader(self):
+         collate_fn = get_collate_fn(self.predict_loader)
+         return DataLoader(self.predict_dataset, collate_fn=collate_fn, **self.predict_loader)
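A minimal sketch of wiring ``PLDataModule`` into a Lightning ``Trainer``: each loader dict mirrors ``torch.utils.data.DataLoader`` arguments plus a ``dataset`` config consumed by ``build_dataset`` in ``setup()``. Dataset paths, batch sizes and the ``model`` object below are assumptions:

import lightning.pytorch as pl
from mmpl.datasets import PLDataModule

datamodule = PLDataModule(
    train_loader=dict(
        batch_size=2, num_workers=2, shuffle=True,
        dataset=dict(type='WHUInsSegDataset',
                     data_root='data/WHU/',               # assumed path
                     ann_file='annotations/train.json',
                     data_prefix=dict(img='images/'),
                     pipeline=[])),
    val_loader=dict(
        batch_size=1, num_workers=2,
        dataset=dict(type='WHUInsSegDataset',
                     data_root='data/WHU/',
                     ann_file='annotations/val.json',
                     data_prefix=dict(img='images/'),
                     test_mode=True,
                     pipeline=[])),
)
# trainer = pl.Trainer(max_epochs=12)
# trainer.fit(model, datamodule=datamodule)   # `model` is a LightningModule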
mmpl/datasets/ssdd_ins_dataset.py ADDED
@@ -0,0 +1,54 @@
+ from typing import List
+ from mmpl.registry import DATASETS
+ from mmdet.datasets.coco import CocoDataset
+
+
+ @DATASETS.register_module()
+ class SSDDInsSegDataset(CocoDataset):
+     """SSDD ship instance segmentation dataset in COCO format."""
+
+     METAINFO = {
+         'classes': ['ship'],
+         'palette': [(0, 0, 255)]
+     }
+
+     def filter_data(self) -> List[dict]:
+         """Filter annotations according to filter_cfg.
+
+         Returns:
+             List[dict]: Filtered results.
+         """
+         # if self.test_mode:
+         #     return self.data_list
+
+         if self.filter_cfg is None:
+             return self.data_list
+
+         filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False)
+         min_size = self.filter_cfg.get('min_size', 0)
+
+         # obtain images that contain annotation
+         ids_with_ann = set(data_info['img_id'] for data_info in self.data_list)
+         # obtain images that contain annotations of the required categories
+         ids_in_cat = set()
+         for i, class_id in enumerate(self.cat_ids):
+             ids_in_cat |= set(self.cat_img_map[class_id])
+         # merge the image id sets of the two conditions and use the merged set
+         # to filter out images if self.filter_empty_gt=True
+         ids_in_cat &= ids_with_ann
+
+         valid_data_infos = []
+         for i, data_info in enumerate(self.data_list):
+             img_id = data_info['img_id']
+             width = data_info['width']
+             height = data_info['height']
+             all_is_crowd = all([
+                 instance['ignore_flag'] == 1
+                 for instance in data_info['instances']
+             ])
+             if filter_empty_gt and (img_id not in ids_in_cat or all_is_crowd):
+                 continue
+             if min(width, height) >= min_size:
+                 valid_data_infos.append(data_info)
+
+         return valid_data_infos
mmpl/datasets/transforms/__init__.py ADDED
File without changes
mmpl/datasets/transforms/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (162 Bytes).
mmpl/datasets/utils.py ADDED
@@ -0,0 +1,243 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import gzip
+ import hashlib
+ import os
+ import os.path
+ import shutil
+ import tarfile
+ import tempfile
+ import urllib.error
+ import urllib.request
+ import zipfile
+
+ from mmengine.fileio import LocalBackend, get_file_backend
+
+ __all__ = [
+     'rm_suffix', 'check_integrity', 'download_and_extract_archive',
+     'open_maybe_compressed_file'
+ ]
+
+
+ def rm_suffix(s, suffix=None):
+     if suffix is None:
+         return s[:s.rfind('.')]
+     else:
+         return s[:s.rfind(suffix)]
+
+
+ def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024):
+     md5 = hashlib.md5()
+     backend = get_file_backend(fpath, enable_singleton=True)
+     if isinstance(backend, LocalBackend):
+         # Enable chunk update for local file.
+         with open(fpath, 'rb') as f:
+             for chunk in iter(lambda: f.read(chunk_size), b''):
+                 md5.update(chunk)
+     else:
+         md5.update(backend.get(fpath))
+     return md5.hexdigest()
+
+
+ def check_md5(fpath, md5, **kwargs):
+     return md5 == calculate_md5(fpath, **kwargs)
+
+
+ def check_integrity(fpath, md5=None):
+     if not os.path.isfile(fpath):
+         return False
+     if md5 is None:
+         return True
+     return check_md5(fpath, md5)
+
+
+ def download_url_to_file(url, dst, hash_prefix=None, progress=True):
+     """Download object at the given URL to a local path.
+
+     Modified from
+     https://pytorch.org/docs/stable/hub.html#torch.hub.download_url_to_file
+
+     Args:
+         url (str): URL of the object to download
+         dst (str): Full path where object will be saved,
+             e.g. ``/tmp/temporary_file``
+         hash_prefix (string, optional): If not None, the SHA256 downloaded
+             file should start with ``hash_prefix``. Defaults to None.
+         progress (bool): whether or not to display a progress bar to stderr.
+             Defaults to True
+     """
+     file_size = None
+     req = urllib.request.Request(url)
+     u = urllib.request.urlopen(req)
+     meta = u.info()
+     if hasattr(meta, 'getheaders'):
+         content_length = meta.getheaders('Content-Length')
+     else:
+         content_length = meta.get_all('Content-Length')
+     if content_length is not None and len(content_length) > 0:
+         file_size = int(content_length[0])
+
+     # We deliberately save it in a temp file and move it after download is
+     # complete. This prevents a local file being overridden by a broken
+     # download.
+     dst = os.path.expanduser(dst)
+     dst_dir = os.path.dirname(dst)
+     f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
+
+     import rich.progress
+     columns = [
+         rich.progress.DownloadColumn(),
+         rich.progress.BarColumn(bar_width=None),
+         rich.progress.TimeRemainingColumn(),
+     ]
+     try:
+         if hash_prefix is not None:
+             sha256 = hashlib.sha256()
+         with rich.progress.Progress(*columns) as pbar:
+             task = pbar.add_task('download', total=file_size, visible=progress)
+             while True:
+                 buffer = u.read(8192)
+                 if len(buffer) == 0:
+                     break
+                 f.write(buffer)
+                 if hash_prefix is not None:
+                     sha256.update(buffer)
+                 pbar.update(task, advance=len(buffer))
+
+         f.close()
+         if hash_prefix is not None:
+             digest = sha256.hexdigest()
+             if digest[:len(hash_prefix)] != hash_prefix:
+                 raise RuntimeError(
+                     'invalid hash value (expected "{}", got "{}")'.format(
+                         hash_prefix, digest))
+         shutil.move(f.name, dst)
+     finally:
+         f.close()
+         if os.path.exists(f.name):
+             os.remove(f.name)
+
+
+ def download_url(url, root, filename=None, md5=None):
+     """Download a file from a url and place it in root.
+
+     Args:
+         url (str): URL to download file from.
+         root (str): Directory to place downloaded file in.
+         filename (str | None): Name to save the file under.
+             If filename is None, use the basename of the URL.
+         md5 (str | None): MD5 checksum of the download.
+             If md5 is None, download without md5 check.
+     """
+     root = os.path.expanduser(root)
+     if not filename:
+         filename = os.path.basename(url)
+     fpath = os.path.join(root, filename)
+
+     os.makedirs(root, exist_ok=True)
+
+     if check_integrity(fpath, md5):
+         print(f'Using downloaded and verified file: {fpath}')
+     else:
+         try:
+             print(f'Downloading {url} to {fpath}')
+             download_url_to_file(url, fpath)
+         except (urllib.error.URLError, IOError) as e:
+             if url[:5] == 'https':
+                 url = url.replace('https:', 'http:')
+                 print('Failed download. Trying https -> http instead.'
+                       f' Downloading {url} to {fpath}')
+                 download_url_to_file(url, fpath)
+             else:
+                 raise e
+         # check integrity of downloaded file
+         if not check_integrity(fpath, md5):
+             raise RuntimeError('File not found or corrupted.')
+
+
+ def _is_tarxz(filename):
+     return filename.endswith('.tar.xz')
+
+
+ def _is_tar(filename):
+     return filename.endswith('.tar')
+
+
+ def _is_targz(filename):
+     return filename.endswith('.tar.gz')
+
+
+ def _is_tgz(filename):
+     return filename.endswith('.tgz')
+
+
+ def _is_gzip(filename):
+     return filename.endswith('.gz') and not filename.endswith('.tar.gz')
+
+
+ def _is_zip(filename):
+     return filename.endswith('.zip')
+
+
+ def extract_archive(from_path, to_path=None, remove_finished=False):
+     if to_path is None:
+         to_path = os.path.dirname(from_path)
+
+     if _is_tar(from_path):
+         with tarfile.open(from_path, 'r') as tar:
+             tar.extractall(path=to_path)
+     elif _is_targz(from_path) or _is_tgz(from_path):
+         with tarfile.open(from_path, 'r:gz') as tar:
+             tar.extractall(path=to_path)
+     elif _is_tarxz(from_path):
+         with tarfile.open(from_path, 'r:xz') as tar:
+             tar.extractall(path=to_path)
+     elif _is_gzip(from_path):
+         to_path = os.path.join(
+             to_path,
+             os.path.splitext(os.path.basename(from_path))[0])
+         with open(to_path, 'wb') as out_f, gzip.GzipFile(from_path) as zip_f:
+             out_f.write(zip_f.read())
+     elif _is_zip(from_path):
+         with zipfile.ZipFile(from_path, 'r') as z:
+             z.extractall(to_path)
+     else:
+         raise ValueError(f'Extraction of {from_path} not supported')
+
+     if remove_finished:
+         os.remove(from_path)
+
+
+ def download_and_extract_archive(url,
+                                  download_root,
+                                  extract_root=None,
+                                  filename=None,
+                                  md5=None,
+                                  remove_finished=False):
+     download_root = os.path.expanduser(download_root)
+     if extract_root is None:
+         extract_root = download_root
+     if not filename:
+         filename = os.path.basename(url)
+
+     download_url(url, download_root, filename, md5)
+
+     archive = os.path.join(download_root, filename)
+     print(f'Extracting {archive} to {extract_root}')
+     extract_archive(archive, extract_root, remove_finished)
+
+
+ def open_maybe_compressed_file(path: str):
+     """Return a file object that possibly decompresses 'path' on the fly.
+
+     Decompression occurs when argument `path` is a string and ends with '.gz'
+     or '.xz'.
+     """
+     if not isinstance(path, str):
+         return path
+     if path.endswith('.gz'):
+         import gzip
+         return gzip.open(path, 'rb')
+     if path.endswith('.xz'):
+         import lzma
+         return lzma.open(path, 'rb')
+     return open(path, 'rb')
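The download helpers above compose as follows; a short example with a placeholder URL, skipping the MD5 check:

from mmpl.datasets.utils import download_and_extract_archive

download_and_extract_archive(
    url='https://example.com/dataset.zip',   # placeholder URL
    download_root='data/downloads/',
    extract_root='data/',
    md5=None,                                # None skips checksum verification
    remove_finished=False,
)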
mmpl/datasets/whu_ins_dataset.py ADDED
@@ -0,0 +1,54 @@
+ from typing import List
+ from mmpl.registry import DATASETS
+ from mmdet.datasets.coco import CocoDataset
+
+
+ @DATASETS.register_module()
+ class WHUInsSegDataset(CocoDataset):
+     """WHU building instance segmentation dataset in COCO format."""
+
+     METAINFO = {
+         'classes': ['building'],
+         'palette': [(0, 255, 0)]
+     }
+
+     def filter_data(self) -> List[dict]:
+         """Filter annotations according to filter_cfg.
+
+         Returns:
+             List[dict]: Filtered results.
+         """
+         # if self.test_mode:
+         #     return self.data_list
+
+         if self.filter_cfg is None:
+             return self.data_list
+
+         filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False)
+         min_size = self.filter_cfg.get('min_size', 0)
+
+         # obtain images that contain annotation
+         ids_with_ann = set(data_info['img_id'] for data_info in self.data_list)
+         # obtain images that contain annotations of the required categories
+         ids_in_cat = set()
+         for i, class_id in enumerate(self.cat_ids):
+             ids_in_cat |= set(self.cat_img_map[class_id])
+         # merge the image id sets of the two conditions and use the merged set
+         # to filter out images if self.filter_empty_gt=True
+         ids_in_cat &= ids_with_ann
+
+         valid_data_infos = []
+         for i, data_info in enumerate(self.data_list):
+             img_id = data_info['img_id']
+             width = data_info['width']
+             height = data_info['height']
+             all_is_crowd = all([
+                 instance['ignore_flag'] == 1
+                 for instance in data_info['instances']
+             ])
+             if filter_empty_gt and (img_id not in ids_in_cat or all_is_crowd):
+                 continue
+             if min(width, height) >= min_size:
+                 valid_data_infos.append(data_info)
+
+         return valid_data_infos
mmpl/engine/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from .runner import *
+ from .logger import *
+ from .hooks import *
+ from .visualization import *
+ from .strategies import *
mmpl/engine/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (254 Bytes).
mmpl/engine/hooks/__init__.py ADDED
@@ -0,0 +1,6 @@
+ from .builder import PL_HOOKS
+ from .pipeline_switch_hook import PipelineSwitchHook
+ from .yolov5_param_scheduler_hook import YOLOv5ParamSchedulerHook
+ from .ema_hook import EMAHook
+ from .param_scheduler_hook import ParamSchedulerHook
+ from .visualization_hook import DetVisualizationHook
mmpl/engine/hooks/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (500 Bytes).
mmpl/engine/hooks/__pycache__/builder.cpython-310.pyc ADDED
Binary file (1.05 kB).
mmpl/engine/hooks/__pycache__/ema_hook.cpython-310.pyc ADDED
Binary file (8.86 kB).
mmpl/engine/hooks/__pycache__/param_scheduler_hook.cpython-310.pyc ADDED
Binary file (4.26 kB).
mmpl/engine/hooks/__pycache__/pipeline_switch_hook.cpython-310.pyc ADDED
Binary file (1.55 kB).
mmpl/engine/hooks/__pycache__/visualization_hook.cpython-310.pyc ADDED
Binary file (6.21 kB).
mmpl/engine/hooks/__pycache__/yolov5_param_scheduler_hook.cpython-310.pyc ADDED
Binary file (3.85 kB).
mmpl/engine/hooks/builder.py ADDED
@@ -0,0 +1,31 @@
+ import copy
+ import inspect
+ from typing import List, Union
+
+ import torch
+ import torch.nn as nn
+ import lightning
+
+ from mmengine.config import Config, ConfigDict
+ from mmengine.device import is_npu_available
+ from mmpl.registry import HOOKS
+
+
+ def register_pl_hooks() -> List[str]:
+     """Register callbacks in ``lightning.pytorch.callbacks`` to the ``HOOKS`` registry.
+
+     Returns:
+         List[str]: A list of registered callbacks' name.
+     """
+     pl_hooks = []
+     for module_name in dir(lightning.pytorch.callbacks):
+         if module_name.startswith('__'):
+             continue
+         _hook = getattr(lightning.pytorch.callbacks, module_name)
+         if inspect.isclass(_hook) and issubclass(_hook, lightning.pytorch.callbacks.Callback):
+             HOOKS.register_module(module=_hook)
+             pl_hooks.append(module_name)
+     return pl_hooks
+
+
+ PL_HOOKS = register_pl_hooks()
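Because ``register_pl_hooks()`` mirrors every class in ``lightning.pytorch.callbacks`` into the ``HOOKS`` registry, built-in Lightning callbacks can be built from config dicts next to the custom hooks in this package. A sketch (the monitored metric name is an assumption):

from mmpl.registry import HOOKS
from mmpl.engine.hooks import PL_HOOKS  # importing the module runs register_pl_hooks()

checkpoint_cb = HOOKS.build(dict(
    type='ModelCheckpoint',   # lightning.pytorch.callbacks.ModelCheckpoint
    monitor='val_loss',       # assumed metric logged by the LightningModule
    save_top_k=1,
))
lr_monitor_cb = HOOKS.build(dict(type='LearningRateMonitor', logging_interval='step'))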
mmpl/engine/hooks/ema_hook.py ADDED
@@ -0,0 +1,240 @@
+ import copy
+ import itertools
+ import logging
+ from typing import Dict, Optional, Any
+
+ from lightning import Callback
+ from lightning.pytorch.utilities.types import STEP_OUTPUT
+ from mmengine.logging import print_log
+ from mmengine.model import is_model_wrapper
+ from mmpl.registry import HOOKS, MODELS
+
+
+
+ @HOOKS.register_module()
+ class EMAHook(Callback):
+     """A Hook to apply Exponential Moving Average (EMA) on the model during
+     training.
+
+     Note:
+         - EMAHook takes priority over CheckpointHook.
+         - The original model parameters are actually saved in ema field after
+           train.
+         - ``begin_iter`` and ``begin_epoch`` cannot be set at the same time.
+
+     Args:
+         ema_type (str): The type of EMA strategy to use. You can find the
+             supported strategies in :mod:`mmengine.model.averaged_model`.
+             Defaults to 'ExponentialMovingAverage'.
+         strict_load (bool): Whether to strictly enforce that the keys of
+             ``state_dict`` in checkpoint match the keys returned by
+             ``self.module.state_dict``. Defaults to False.
+             Changed in v0.3.0.
+         begin_iter (int): The number of iteration to enable ``EMAHook``.
+             Defaults to 0.
+         begin_epoch (int): The number of epoch to enable ``EMAHook``.
+             Defaults to 0.
+         **kwargs: Keyword arguments passed to subclasses of
+             :obj:`BaseAveragedModel`
+     """
+
+     priority = 'NORMAL'
+
+     def __init__(self,
+                  ema_type: str = 'ExponentialMovingAverage',
+                  strict_load: bool = False,
+                  begin_iter: int = 0,
+                  begin_epoch: int = 0,
+                  **kwargs):
+         self.strict_load = strict_load
+         self.ema_cfg = dict(type=ema_type, **kwargs)
+         assert not (begin_iter != 0 and begin_epoch != 0), (
+             '`begin_iter` and `begin_epoch` should not be both set.')
+         assert begin_iter >= 0, (
+             '`begin_iter` must larger than or equal to 0, '
+             f'but got begin_iter: {begin_iter}')
+         assert begin_epoch >= 0, (
+             '`begin_epoch` must larger than or equal to 0, '
+             f'but got begin_epoch: {begin_epoch}')
+         self.begin_iter = begin_iter
+         self.begin_epoch = begin_epoch
+         # If `begin_epoch` and `begin_iter` are not set, `EMAHook` will be
+         # enabled at 0 iteration.
+         self.enabled_by_epoch = self.begin_epoch > 0
+
+     def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+         """Create an ema copy of the model.
+
+         Args:
+             runner (Runner): The runner of the training process.
+         """
+         model = pl_module
+         if is_model_wrapper(model):
+             model = model.module
+         self.src_model = model
+         self.ema_model = MODELS.build(
+             self.ema_cfg, default_args=dict(model=self.src_model))
+
+     def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+         """Check the begin_epoch/iter is smaller than max_epochs/iters.
+
+         Args:
+             runner (Runner): The runner of the training process.
+         """
+         if self.enabled_by_epoch:
+             assert self.begin_epoch <= trainer.max_epochs, (
+                 'self.begin_epoch should be smaller than or equal to '
+                 f'runner.max_epochs: {trainer.max_epochs}, but got '
+                 f'begin_epoch: {self.begin_epoch}')
+         else:
+             assert self.begin_iter <= trainer.max_steps or self.begin_iter <= trainer.max_epochs * len(trainer.train_dataloader), (
+                 'self.begin_iter should be smaller than or equal to '
+                 f'runner.max_iters: {trainer.max_steps}, but got '
+                 f'begin_iter: {self.begin_iter}')
+
+     def on_train_batch_end(
+         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
+     ) -> None:
+         """Update ema parameter.
+
+         Args:
+             runner (Runner): The runner of the training process.
+             batch_idx (int): The index of the current batch in the train loop.
+             data_batch (Sequence[dict], optional): Data from dataloader.
+                 Defaults to None.
+             outputs (dict, optional): Outputs from model. Defaults to None.
+         """
+         if self._ema_started(trainer):
+             self.ema_model.update_parameters(self.src_model)
+         else:
+             ema_params = self.ema_model.module.state_dict()
+             src_params = self.src_model.state_dict()
+             for k, p in ema_params.items():
+                 p.data.copy_(src_params[k].data)
+
+     def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+         """We load parameter values from ema model to source model before
+         validation.
+
+         Args:
+             runner (Runner): The runner of the training process.
+         """
+         self._swap_ema_parameters()
+
+     def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+         """We recover source model's parameter from ema model after validation.
+
+         Args:
+             runner (Runner): The runner of the validation process.
+             metrics (Dict[str, float], optional): Evaluation results of all
+                 metrics on validation dataset. The keys are the names of the
+                 metrics, and the values are corresponding results.
+         """
+         self._swap_ema_parameters()
+
+     def on_test_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+         """We load parameter values from ema model to source model before test.
+
+         Args:
+             runner (Runner): The runner of the training process.
+         """
+         self._swap_ema_parameters()
+
+     def on_test_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+         """We recover source model's parameter from ema model after test.
+
+         Args:
+             runner (Runner): The runner of the testing process.
+             metrics (Dict[str, float], optional): Evaluation results of all
+                 metrics on test dataset. The keys are the names of the
+                 metrics, and the values are corresponding results.
+         """
+         self._swap_ema_parameters()
+
+     def on_save_checkpoint(
+         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
+     ) -> None:
+         """Save ema parameters to checkpoint.
+
+         Args:
+             runner (Runner): The runner of the testing process.
+         """
+         checkpoint['ema_state_dict'] = self.ema_model.state_dict()
+         # Save ema parameters to the source model's state dict so that we
+         # can directly load the averaged model weights for deployment.
+         # Swapping the state_dict key-values instead of swapping model
+         # parameters because the state_dict is a shallow copy of model
+         # parameters.
+         self._swap_ema_state_dict(checkpoint)
+
+     def on_load_checkpoint(
+         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
+     ) -> None:
+         """Resume ema parameters from checkpoint.
+
+         Args:
+             runner (Runner): The runner of the testing process.
+         """
+         from mmengine.runner.checkpoint import load_state_dict
+         if 'ema_state_dict' in checkpoint and not trainer._checkpoint_connector._loaded_checkpoint:
+             # The original model parameters are actually saved in ema
+             # field swap the weights back to resume ema state.
+             self._swap_ema_state_dict(checkpoint)
+             self.ema_model.load_state_dict(
+                 checkpoint['ema_state_dict'], strict=self.strict_load)
+
+         # Support load checkpoint without ema state dict.
+         else:
+             if not trainer._checkpoint_connector._loaded_checkpoint:
+                 print_log(
+                     'There is no `ema_state_dict` in checkpoint. '
+                     '`EMAHook` will make a copy of `state_dict` as the '
+                     'initial `ema_state_dict`', 'current', logging.WARNING)
+                 load_state_dict(
+                     self.ema_model.module,
+                     copy.deepcopy(checkpoint['state_dict']),
+                     strict=self.strict_load)
+
+     def _swap_ema_parameters(self) -> None:
+         """Swap the parameter of model with ema_model."""
+         avg_param = (
+             itertools.chain(self.ema_model.module.parameters(),
+                             self.ema_model.module.buffers())
+             if self.ema_model.update_buffers else
+             self.ema_model.module.parameters())
+         src_param = (
+             itertools.chain(self.src_model.parameters(),
+                             self.src_model.buffers())
+             if self.ema_model.update_buffers else self.src_model.parameters())
+         for p_avg, p_src in zip(avg_param, src_param):
+             tmp = p_avg.data.clone()
+             p_avg.data.copy_(p_src.data)
+             p_src.data.copy_(tmp)
+
+     def _swap_ema_state_dict(self, checkpoint):
+         """Swap the state dict values of model with ema_model."""
+         model_state = checkpoint['state_dict']
+         ema_state = checkpoint['ema_state_dict']
+         for k in ema_state:
+             if k[:7] == 'module.':
+                 tmp = ema_state[k]
+                 ema_state[k] = model_state[k[7:]]
+                 model_state[k[7:]] = tmp
+
+     def _ema_started(self, trainer) -> bool:
+         """Whether ``EMAHook`` has been initialized at current iteration or
+         epoch.
+
+         :attr:`ema_model` will be initialized when ``runner.iter`` or
+         ``runner.epoch`` is greater than ``self.begin`` for the first time.
+
+         Args:
+             runner (Runner): Runner of the training, validation process.
+
+         Returns:
+             bool: Whether ``EMAHook`` has been initialized.
+         """
+         if self.enabled_by_epoch:
+             return trainer.current_epoch + 1 >= self.begin_epoch
+         else:
+             return trainer.global_step + 1 >= self.begin_iter
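``EMAHook`` subclasses the Lightning ``Callback`` interface directly, so it can be handed to the ``Trainer``; extra keyword arguments are forwarded to the averaged-model class named by ``ema_type``. A sketch assuming the default ``ExponentialMovingAverage`` from mmengine (the momentum value and epoch counts are illustrative):

import lightning.pytorch as pl
from mmpl.engine.hooks import EMAHook

ema_hook = EMAHook(
    ema_type='ExponentialMovingAverage',
    momentum=0.0002,   # forwarded to ExponentialMovingAverage
    begin_epoch=1,     # start updating EMA weights after the first epoch
)
# trainer = pl.Trainer(max_epochs=12, callbacks=[ema_hook])
# trainer.fit(model, datamodule=datamodule)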
mmpl/engine/hooks/param_scheduler_hook.py ADDED
@@ -0,0 +1,128 @@
+ from typing import Dict, Optional, Union, Any
+
+ from lightning.pytorch.utilities.types import STEP_OUTPUT
+ from mmengine.optim import _ParamScheduler
+ from mmpl.registry import HOOKS
+ from mmengine.utils import is_list_of
+ from lightning import Callback
+
+ DATA_BATCH = Optional[Union[dict, tuple, list]]
+
+
+ @HOOKS.register_module()
+ class ParamSchedulerHook(Callback):
+     """A hook to update some hyper-parameters in optimizer, e.g., learning rate
+     and momentum."""
+
+     priority = 'LOW'
+
+     def on_train_batch_end(
+         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
+     ) -> None:
+         """Call step function for each scheduler after each training iteration.
+
+         Args:
+             runner (Runner): The runner of the training process.
+             batch_idx (int): The index of the current batch in the train loop.
+             data_batch (dict or tuple or list, optional): Data from dataloader.
+                 In order to keep this interface consistent with other hooks,
+                 we keep ``data_batch`` here.
+             outputs (dict, optional): Outputs from model.
+                 In order to keep this interface consistent with other hooks, we
+                 keep ``data_batch`` here.
+         """
+         param_schedulers = pl_module.lr_schedulers()
+         if param_schedulers is None:
+             return
+
+         def step(param_schedulers):
+             assert isinstance(param_schedulers, list)
+             for scheduler in param_schedulers:
+                 if not scheduler.by_epoch:
+                     scheduler.step()
+         if isinstance(param_schedulers, _ParamScheduler):
+             param_schedulers = [param_schedulers]
+         if isinstance(param_schedulers, list):
+             step(param_schedulers)
+         elif isinstance(param_schedulers, dict):
+             for param_schedulers in param_schedulers.values():
+                 step(param_schedulers)
+         else:
+             raise TypeError(
+                 'runner.param_schedulers should be list of ParamScheduler or '
+                 'a dict containing list of ParamScheduler, '
+                 f'but got {param_schedulers}')
+
+     def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+         """Call step function for each scheduler after each training epoch.
+
+         Args:
+             runner (Runner): The runner of the training process.
+         """
+         param_schedulers = pl_module.lr_schedulers()
+         if param_schedulers is None:
+             return
+
+         def step(param_schedulers):
+             assert isinstance(param_schedulers, list)
+             for scheduler in param_schedulers:
+                 if scheduler.by_epoch:
+                     scheduler.step()
+         if isinstance(param_schedulers, _ParamScheduler):
+             param_schedulers = [param_schedulers]
+         if isinstance(param_schedulers, list):
+             step(param_schedulers)
+         elif isinstance(param_schedulers, dict):
+             for param_schedulers in param_schedulers.values():
+                 step(param_schedulers)
+         else:
+             raise TypeError(
+                 'runner.param_schedulers should be list of ParamScheduler or '
+                 'a dict containing list of ParamScheduler, '
+                 f'but got {param_schedulers}')
+
+     def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+         """Call step function for each scheduler which has attribute
+         ``need_val_args`` after each validation epoch.
+
+         Args:
+             runner (Runner): The runner of the validation process.
+             metrics (Dict[str, float], optional): Evaluation results of all
+                 metrics on validation dataset. The keys are the names of the
+                 metrics, and the values are corresponding results.
+
+         Note:
+             if ``runner.param_schedulers`` is not built before,
+             the hook ``after_val_epoch`` will be skipped.
+         """
+         param_schedulers = pl_module.lr_schedulers()
+         if param_schedulers is None:
+             return
+
+         # avoid counting scheduler._global_step
+         # it has counted in after_train_* hook
+         metrics = trainer.callback_metrics
+         if metrics is None:
+             return
+
+         def step(param_schedulers):
+             # check param_schedulers is list and built
+             if not is_list_of(param_schedulers, _ParamScheduler):
+                 return
+
+             for scheduler in param_schedulers:
+                 if (scheduler.by_epoch
+                         and getattr(scheduler, 'need_val_args', False)):
+                     scheduler.step(metrics)
+         if isinstance(param_schedulers, _ParamScheduler):
+             param_schedulers = [param_schedulers]
+         if isinstance(param_schedulers, list):
+             step(param_schedulers)
+         elif isinstance(param_schedulers, dict):
+             for param_schedulers in param_schedulers.values():
+                 step(param_schedulers)
+         else:
+             raise TypeError(
+                 'runner.param_schedulers should be list of ParamScheduler or '
+                 'a dict containing list of ParamScheduler, '
+                 f'but got {param_schedulers}')
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from mmcv.transforms import Compose
2
+ from mmpl.registry import HOOKS
3
+ from lightning.pytorch.callbacks import Callback
4
+
5
+
6
+ @HOOKS.register_module()
7
+ class PipelineSwitchHook(Callback):
8
+ """Switch data pipeline at switch_epoch.
9
+
10
+ Args:
11
+ switch_epoch (int): switch pipeline at this epoch.
12
+ switch_pipeline (list[dict]): the pipeline to switch to.
13
+ """
14
+
15
+ def __init__(self, switch_epoch, switch_pipeline):
16
+ self.switch_epoch = switch_epoch
17
+ self.switch_pipeline = switch_pipeline
18
+ self._restart_dataloader = False
19
+
20
+ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
21
+ """switch pipeline."""
22
+ epoch = trainer.current_epoch
23
+ train_loader = trainer.train_dataloader
24
+ if epoch == self.switch_epoch:
25
+ if trainer.local_rank == 0:
26
+ print('Switch pipeline now!')
27
+ # The dataset pipeline cannot be updated when persistent_workers
28
+ # is True, so we need to force the dataloader's multi-process
29
+ # restart. This is a very hacky approach.
30
+ train_loader.dataset.pipeline = Compose(self.switch_pipeline)
31
+ if hasattr(train_loader, 'persistent_workers'
32
+ ) and train_loader.persistent_workers is True:
33
+ train_loader._DataLoader__initialized = False
34
+ train_loader._iterator = None
35
+ self._restart_dataloader = True
36
+
37
+ else:
38
+ # Once the restart is complete, we need to restore
39
+ # the initialization flag.
40
+ if self._restart_dataloader:
41
+ train_loader._DataLoader__initialized = True
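A hedged usage sketch: with this repo's convention of building trainer callbacks from config dicts (see PLRunner.build_hooks later in this upload), the hook could be configured as below. The switch epoch and the transform entries are illustrative assumptions, not values taken from this repository's configs.

# Hedged example only: transform types and switch_epoch are placeholders.
switch_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(640, 640), keep_ratio=True),
    dict(type='PackDetInputs'),
]

callbacks = [
    dict(
        type='PipelineSwitchHook',
        switch_epoch=280,                 # drop heavy augmentation from this epoch on
        switch_pipeline=switch_pipeline),
]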
mmpl/engine/hooks/ppyoloe_param_scheduler_hook.py ADDED
@@ -0,0 +1,96 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import math
3
+ from typing import Optional
4
+
5
+ from mmengine.hooks import ParamSchedulerHook
6
+ from mmengine.runner import Runner
7
+
8
+ from mmyolo.registry import HOOKS
9
+
10
+
11
+ @HOOKS.register_module()
12
+ class PPYOLOEParamSchedulerHook(ParamSchedulerHook):
13
+ """A hook to update learning rate and momentum in optimizer of PPYOLOE. We
14
+ use this hook to implement adaptive computation for `warmup_total_iters`,
15
+ which is not possible with the built-in ParamScheduler in mmyolo.
16
+
17
+ Args:
18
+ warmup_min_iter (int): Minimum warmup iters. Defaults to 1000.
19
+ start_factor (float): The number we multiply learning rate in the
20
+ first epoch. The multiplication factor changes towards end_factor
21
+ in the following epochs. Defaults to 0.
22
+ warmup_epochs (int): Epochs for warmup. Defaults to 5.
23
+ min_lr_ratio (float): Minimum learning rate ratio.
24
+ total_epochs (int): In PPYOLOE, `total_epochs` is set to
25
+ training_epochs x 1.2. Defaults to 360.
26
+ """
27
+ priority = 9
28
+
29
+ def __init__(self,
30
+ warmup_min_iter: int = 1000,
31
+ start_factor: float = 0.,
32
+ warmup_epochs: int = 5,
33
+ min_lr_ratio: float = 0.0,
34
+ total_epochs: int = 360):
35
+
36
+ self.warmup_min_iter = warmup_min_iter
37
+ self.start_factor = start_factor
38
+ self.warmup_epochs = warmup_epochs
39
+ self.min_lr_ratio = min_lr_ratio
40
+ self.total_epochs = total_epochs
41
+
42
+ self._warmup_end = False
43
+ self._base_lr = None
44
+
45
+ def before_train(self, runner: Runner):
46
+ """Operations before train.
47
+
48
+ Args:
49
+ runner (Runner): The runner of the training process.
50
+ """
51
+ optimizer = runner.optim_wrapper.optimizer
52
+ for group in optimizer.param_groups:
53
+ # If the param is never be scheduled, record the current value
54
+ # as the initial value.
55
+ group.setdefault('initial_lr', group['lr'])
56
+
57
+ self._base_lr = [
58
+ group['initial_lr'] for group in optimizer.param_groups
59
+ ]
60
+ self._min_lr = [i * self.min_lr_ratio for i in self._base_lr]
61
+
62
+ def before_train_iter(self,
63
+ runner: Runner,
64
+ batch_idx: int,
65
+ data_batch: Optional[dict] = None):
66
+ """Operations before each training iteration.
67
+
68
+ Args:
69
+ runner (Runner): The runner of the training process.
70
+ batch_idx (int): The index of the current batch in the train loop.
71
+ data_batch (dict or tuple or list, optional): Data from dataloader.
72
+ """
73
+ cur_iters = runner.iter
74
+ optimizer = runner.optim_wrapper.optimizer
75
+ dataloader_len = len(runner.train_dataloader)
76
+
77
+ # The minimum warmup is self.warmup_min_iter
78
+ warmup_total_iters = max(
79
+ round(self.warmup_epochs * dataloader_len), self.warmup_min_iter)
80
+
81
+ if cur_iters <= warmup_total_iters:
82
+ # warm up
83
+ alpha = cur_iters / warmup_total_iters
84
+ factor = self.start_factor * (1 - alpha) + alpha
85
+
86
+ for group_idx, param in enumerate(optimizer.param_groups):
87
+ param['lr'] = self._base_lr[group_idx] * factor
88
+ else:
89
+ for group_idx, param in enumerate(optimizer.param_groups):
90
+ total_iters = self.total_epochs * dataloader_len
91
+ lr = self._min_lr[group_idx] + (
92
+ self._base_lr[group_idx] -
93
+ self._min_lr[group_idx]) * 0.5 * (
94
+ math.cos((cur_iters - warmup_total_iters) * math.pi /
95
+ (total_iters - warmup_total_iters)) + 1.0)
96
+ param['lr'] = lr
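The warmup-plus-cosine schedule implemented above can be reproduced as a pure function for inspection. The standalone sketch below mirrors the formulas in before_train_iter; the example numbers (500 iterations per epoch, base LR 0.01) are arbitrary assumptions.

import math

def ppyoloe_lr(cur_iter, base_lr, dataloader_len, warmup_epochs=5,
               warmup_min_iter=1000, start_factor=0.0, min_lr_ratio=0.0,
               total_epochs=360):
    """Standalone sketch of the schedule used by the hook above."""
    warmup_total_iters = max(round(warmup_epochs * dataloader_len),
                             warmup_min_iter)
    min_lr = base_lr * min_lr_ratio
    if cur_iter <= warmup_total_iters:
        # linear warmup from start_factor * base_lr to base_lr
        alpha = cur_iter / warmup_total_iters
        return base_lr * (start_factor * (1 - alpha) + alpha)
    # cosine decay from base_lr down to min_lr over the remaining iterations
    total_iters = total_epochs * dataloader_len
    return min_lr + (base_lr - min_lr) * 0.5 * (
        math.cos((cur_iter - warmup_total_iters) * math.pi /
                 (total_iters - warmup_total_iters)) + 1.0)

# With, e.g., 500 iterations per epoch and base_lr = 0.01 (arbitrary numbers):
print(ppyoloe_lr(100, 0.01, 500))    # still in warmup
print(ppyoloe_lr(50000, 0.01, 500))  # on the cosine tail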
mmpl/engine/hooks/switch_to_deploy_hook.py ADDED
@@ -0,0 +1,21 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+
3
+ from mmengine.hooks import Hook
4
+ from mmengine.runner import Runner
5
+
6
+ from mmyolo.registry import HOOKS
7
+ from mmyolo.utils import switch_to_deploy
8
+
9
+
10
+ @HOOKS.register_module()
11
+ class SwitchToDeployHook(Hook):
12
+ """Switch to deploy mode before testing.
13
+
14
+ This hook converts the multi-channel structure of the training network
15
+ (high performance) to the one-way structure of the testing network (fast
16
+ speed and memory saving).
17
+ """
18
+
19
+ def before_test_epoch(self, runner: Runner):
20
+ """Switch to deploy mode before testing."""
21
+ switch_to_deploy(runner.model)
mmpl/engine/hooks/visualization_hook.py ADDED
@@ -0,0 +1,199 @@
1
+ import os.path as osp
2
+ import warnings
3
+ from typing import Optional, Sequence, Any
4
+
5
+ import mmcv
6
+ from lightning import Callback
7
+ from mmengine.fileio import get
8
+ from mmengine.hooks import Hook
9
+ from mmengine.runner import Runner
10
+ from mmengine.utils import mkdir_or_exist
11
+ from mmengine.visualization import Visualizer
12
+
13
+ from mmpl.registry import HOOKS
14
+ from mmdet.structures import DetDataSample
15
+
16
+
17
+ @HOOKS.register_module()
18
+ class DetVisualizationHook(Callback):
19
+ """Detection Visualization Hook. Used to visualize validation and testing
20
+ process prediction results.
21
+
22
+ In the testing phase:
23
+
24
+ 1. If ``show`` is True, it means that only the prediction results are
25
+ visualized without storing data, so ``vis_backends`` needs to
26
+ be excluded.
27
+ 2. If ``test_out_dir`` is specified, the prediction results will be
28
+ saved to ``test_out_dir``. To avoid ``vis_backends`` also storing
29
+ data, ``vis_backends`` needs to be excluded.
30
+ 3. ``vis_backends`` takes effect only if the user specifies neither
31
+ ``show`` nor ``test_out_dir``. You can set ``vis_backends`` to
32
+ WandbVisBackend or TensorboardVisBackend to store the prediction
33
+ results in Wandb or Tensorboard.
34
+
35
+ Args:
36
+ draw (bool): whether to draw prediction results. If it is False,
37
+ it means that no drawing will be done. Defaults to False.
38
+ interval (int): The interval of visualization. Defaults to 50.
39
+ score_thr (float): The threshold to visualize the bboxes
40
+ and masks. Defaults to 0.3.
41
+ show (bool): Whether to display the drawn image. Default to False.
42
+ wait_time (float): The interval of show (s). Defaults to 0.
43
+ test_out_dir (str, optional): directory where painted images
44
+ will be saved in testing process.
45
+ backend_args (dict, optional): Arguments to instantiate the
46
+ corresponding backend. Defaults to None.
47
+ """
48
+
49
+ def __init__(self,
50
+ draw: bool = False,
51
+ interval: int = 50,
52
+ score_thr: float = 0.3,
53
+ show: bool = False,
54
+ wait_time: float = 0.,
55
+ test_out_dir: Optional[str] = None,
56
+ backend_args: dict = None):
57
+ self._visualizer: Visualizer = Visualizer.get_current_instance()
58
+ self.interval = interval
59
+ self.score_thr = score_thr
60
+ self.show = show
61
+ if self.show:
62
+ # No need to think about vis backends.
63
+ self._visualizer._vis_backends = {}
64
+ warnings.warn('show is True, so only the '
65
+ 'prediction results will be visualized '
66
+ 'and no data will be stored; '
67
+ 'vis_backends has been cleared.')
68
+
69
+ self.wait_time = wait_time
70
+ self.backend_args = backend_args
71
+ self.draw = draw
72
+ self.test_out_dir = test_out_dir
73
+ self._test_index = 0
74
+
75
+ def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
76
+ outputs: Sequence[DetDataSample]) -> None:
77
+ """Run after every ``self.interval`` validation iterations.
78
+
79
+ Args:
80
+ runner (:obj:`Runner`): The runner of the validation process.
81
+ batch_idx (int): The index of the current batch in the val loop.
82
+ data_batch (dict): Data from dataloader.
83
+ outputs (Sequence[:obj:`DetDataSample`]]): A batch of data samples
84
+ that contain annotations and predictions.
85
+ """
86
+ if self.draw is False:
87
+ return
88
+
89
+ # There is no guarantee that the same batch of images
90
+ # is visualized for each evaluation.
91
+ total_curr_iter = runner.iter + batch_idx
92
+
93
+ # Visualize only the first data
94
+ img_path = outputs[0].img_path
95
+ img_bytes = get(img_path, backend_args=self.backend_args)
96
+ img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
97
+
98
+ if total_curr_iter % self.interval == 0:
99
+ self._visualizer.add_datasample(
100
+ osp.basename(img_path) if self.show else 'val_img',
101
+ img,
102
+ data_sample=outputs[0],
103
+ show=self.show,
104
+ wait_time=self.wait_time,
105
+ pred_score_thr=self.score_thr,
106
+ step=total_curr_iter)
107
+
108
+ def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
109
+ outputs: Sequence[DetDataSample]) -> None:
110
+ """Run after every testing iterations.
111
+
112
+ Args:
113
+ runner (:obj:`Runner`): The runner of the testing process.
114
+ batch_idx (int): The index of the current batch in the val loop.
115
+ data_batch (dict): Data from dataloader.
116
+ outputs (Sequence[:obj:`DetDataSample`]): A batch of data samples
117
+ that contain annotations and predictions.
118
+ """
119
+ if self.draw is False:
120
+ return
121
+
122
+ if self.test_out_dir is not None:
123
+ self.test_out_dir = osp.join(runner.work_dir, runner.timestamp,
124
+ self.test_out_dir)
125
+ mkdir_or_exist(self.test_out_dir)
126
+
127
+ for data_sample in outputs:
128
+ self._test_index += 1
129
+
130
+ img_path = data_sample.img_path
131
+ img_bytes = get(img_path, backend_args=self.backend_args)
132
+ img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
133
+
134
+ out_file = None
135
+ if self.test_out_dir is not None:
136
+ out_file = osp.basename(img_path)
137
+ out_file = osp.join(self.test_out_dir, out_file)
138
+
139
+ self._visualizer.add_datasample(
140
+ osp.basename(img_path) if self.show else 'test_img',
141
+ img,
142
+ data_sample=data_sample,
143
+ show=self.show,
144
+ wait_time=self.wait_time,
145
+ pred_score_thr=self.score_thr,
146
+ out_file=out_file,
147
+ step=self._test_index)
148
+
149
+ def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
150
+ # if hasattr(trainer.datamodule, f'predict_dataset'):
151
+ # dataset = getattr(trainer.datamodule, f'predict_dataset')
152
+ # if hasattr(dataset, 'metainfo') and hasattr(self._visualizer, 'dataset_meta'):
153
+ # self._visualizer.dataset_meta = dataset.metainfo
154
+ if self.test_out_dir is not None:
155
+ self.test_out_dir = osp.join(trainer.default_root_dir, self.test_out_dir)
156
+ mkdir_or_exist(self.test_out_dir)
157
+
158
+ def on_predict_batch_end(
159
+ self,
160
+ trainer: "pl.Trainer",
161
+ pl_module: "pl.LightningModule",
162
+ outputs: Any,
163
+ batch: Any,
164
+ batch_idx: int,
165
+ dataloader_idx: int = 0,
166
+ ) -> None:
167
+ """Run after every testing iterations.
168
+
169
+ Args:
170
+ trainer (pl.Trainer): The trainer of the prediction process.
171
+ pl_module (pl.LightningModule): The LightningModule being run.
172
+ batch (Any): Data from the dataloader.
173
+ outputs (Sequence[:obj:`DetDataSample`]): A batch of data samples
174
+ that contain annotations and predictions.
175
+ """
176
+ if self.draw is False:
177
+ return
178
+
179
+ for data_sample in outputs:
180
+ self._test_index += 1
181
+
182
+ img_path = data_sample.img_path
183
+ img_bytes = get(img_path, backend_args=self.backend_args)
184
+ img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
185
+
186
+ out_file = None
187
+ if self.test_out_dir is not None:
188
+ out_file = osp.basename(img_path)
189
+ out_file = osp.join(self.test_out_dir, out_file)
190
+
191
+ self._visualizer.add_datasample(
192
+ osp.basename(img_path) if self.show else 'test_img',
193
+ img,
194
+ data_sample=data_sample,
195
+ show=self.show,
196
+ wait_time=self.wait_time,
197
+ pred_score_thr=self.score_thr,
198
+ out_file=out_file,
199
+ step=self._test_index)
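A hedged configuration sketch for this hook: only the keyword arguments are taken from the __init__ above; placing the dict in a callbacks list and the 'vis' output directory name are assumptions about how this repo wires its hooks.

# Hedged example only: callbacks-list placement and 'vis' are assumptions.
callbacks = [
    dict(
        type='DetVisualizationHook',
        draw=True,            # drawing is disabled by default
        score_thr=0.3,        # hide low-confidence boxes and masks
        test_out_dir='vis'),  # painted images are written under this directory
]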
mmpl/engine/hooks/yolov5_param_scheduler_hook.py ADDED
@@ -0,0 +1,111 @@
1
+ import math
2
+ from typing import Optional
3
+ import numpy as np
4
+ from typing import Dict, Optional, Union
5
+ from mmengine.registry import HOOKS
6
+ from .param_scheduler_hook import ParamSchedulerHook
7
+
8
+ DATA_BATCH = Optional[Union[dict, tuple, list]]
9
+
10
+
11
+ def linear_fn(lr_factor: float, max_epochs: int):
12
+ """Generate linear function."""
13
+ return lambda x: (1 - x / max_epochs) * (1.0 - lr_factor) + lr_factor
14
+
15
+
16
+ def cosine_fn(lr_factor: float, max_epochs: int):
17
+ """Generate cosine function."""
18
+ return lambda x: (
19
+ (1 - math.cos(x * math.pi / max_epochs)) / 2) * (lr_factor - 1) + 1
20
+
21
+
22
+ @HOOKS.register_module()
23
+ class YOLOv5ParamSchedulerHook(ParamSchedulerHook):
24
+ """A hook to update learning rate and momentum in optimizer of YOLOv5."""
25
+ priority = 9
26
+
27
+ scheduler_maps = {'linear': linear_fn, 'cosine': cosine_fn}
28
+
29
+ def __init__(self,
30
+ scheduler_type: str = 'linear',
31
+ lr_factor: float = 0.01,
32
+ max_epochs: int = 300,
33
+ warmup_epochs: int = 3,
34
+ warmup_bias_lr: float = 0.1,
35
+ warmup_momentum: float = 0.8,
36
+ warmup_mim_iter: int = 500,
37
+ **kwargs):
38
+
39
+ assert scheduler_type in self.scheduler_maps
40
+
41
+ self.warmup_epochs = warmup_epochs
42
+ self.warmup_bias_lr = warmup_bias_lr
43
+ self.warmup_momentum = warmup_momentum
44
+ self.warmup_mim_iter = warmup_mim_iter
45
+
46
+ kwargs.update({'lr_factor': lr_factor, 'max_epochs': max_epochs})
47
+ self.scheduler_fn = self.scheduler_maps[scheduler_type](**kwargs)
48
+
49
+ self._warmup_end = False
50
+ self._base_lr = None
51
+ self._base_momentum = None
52
+
53
+
54
+ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
55
+ optimizer = trainer.optimizers[0]
56
+ for group in optimizer.param_groups:
57
+ # If the param is never be scheduled, record the current value
58
+ # as the initial value.
59
+ group.setdefault('initial_lr', group['lr'])
60
+ group.setdefault('initial_momentum', group.get('momentum', -1))
61
+
62
+ self._base_lr = [
63
+ group['initial_lr'] for group in optimizer.param_groups
64
+ ]
65
+ self._base_momentum = [
66
+ group['initial_momentum'] for group in optimizer.param_groups
67
+ ]
68
+
69
+ def on_before_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", loss) -> None:
70
+ cur_iters = trainer.global_step
71
+ cur_epoch = trainer.current_epoch
72
+ optimizer = trainer.optimizers[0]
73
+
74
+ # The minimum warmup is self.warmup_mim_iter
75
+ warmup_total_iters = max(
76
+ round(self.warmup_epochs * len(trainer.train_dataloader)),
77
+ self.warmup_mim_iter)
78
+
79
+ if cur_iters <= warmup_total_iters:
80
+ xp = [0, warmup_total_iters]
81
+ for group_idx, param in enumerate(optimizer.param_groups):
82
+ if group_idx == 2:
83
+ # bias learning rate will be handled specially
84
+ yp = [
85
+ self.warmup_bias_lr,
86
+ self._base_lr[group_idx] * self.scheduler_fn(cur_epoch)
87
+ ]
88
+ else:
89
+ yp = [
90
+ 0.0,
91
+ self._base_lr[group_idx] * self.scheduler_fn(cur_epoch)
92
+ ]
93
+ param['lr'] = np.interp(cur_iters, xp, yp)
94
+
95
+ if 'momentum' in param:
96
+ param['momentum'] = np.interp(
97
+ cur_iters, xp,
98
+ [self.warmup_momentum, self._base_momentum[group_idx]])
99
+ else:
100
+ self._warmup_end = True
101
+
102
+ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
103
+ if not self._warmup_end:
104
+ return
105
+
106
+ cur_epoch = trainer.current_epoch
107
+ optimizer = trainer.optimizers[0]
108
+ for group_idx, param in enumerate(optimizer.param_groups):
109
+ param['lr'] = self._base_lr[group_idx] * self.scheduler_fn(
110
+ cur_epoch)
111
+
mmpl/engine/hooks/yolox_mode_switch_hook.py ADDED
@@ -0,0 +1,54 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import copy
3
+ from typing import Sequence
4
+
5
+ from mmengine.hooks import Hook
6
+ from mmengine.model import is_model_wrapper
7
+ from mmengine.runner import Runner
8
+
9
+ from mmyolo.registry import HOOKS
10
+
11
+
12
+ @HOOKS.register_module()
13
+ class YOLOXModeSwitchHook(Hook):
14
+ """Switch the mode of YOLOX during training.
15
+
16
+ This hook turns off the mosaic and mixup data augmentation and switches
17
+ to use L1 loss in bbox_head.
18
+
19
+ Args:
20
+ num_last_epochs (int): The number of latter epochs in the end of the
21
+ training to close the data augmentation and switch to L1 loss.
22
+ Defaults to 15.
23
+ """
24
+
25
+ def __init__(self,
26
+ num_last_epochs: int = 15,
27
+ new_train_pipeline: Sequence[dict] = None):
28
+ self.num_last_epochs = num_last_epochs
29
+ self.new_train_pipeline_cfg = new_train_pipeline
30
+
31
+ def before_train_epoch(self, runner: Runner):
32
+ """Close mosaic and mixup augmentation and switches to use L1 loss."""
33
+ epoch = runner.epoch
34
+ model = runner.model
35
+ if is_model_wrapper(model):
36
+ model = model.module
37
+
38
+ if (epoch + 1) == runner.max_epochs - self.num_last_epochs:
39
+ runner.logger.info(f'New Pipeline: {self.new_train_pipeline_cfg}')
40
+
41
+ train_dataloader_cfg = copy.deepcopy(runner.cfg.train_dataloader)
42
+ train_dataloader_cfg.dataset.pipeline = self.new_train_pipeline_cfg
43
+ # Note: Why rebuild the dataset?
44
+ # Because build_dataloader makes a deep copy of the dataset,
45
+ # it can lead to potential risks, such as the data of the global
46
+ # FileClient instance being disordered.
47
+ # This problem needs to be solved in the future.
48
+ new_train_dataloader = Runner.build_dataloader(
49
+ train_dataloader_cfg)
50
+ runner.train_loop.dataloader = new_train_dataloader
51
+
52
+ runner.logger.info('recreate the dataloader!')
53
+ runner.logger.info('Add additional bbox reg loss now!')
54
+ model.bbox_head.use_bbox_aux = True
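Note that, unlike most hooks in this directory, this one keeps mmengine's Hook/Runner interface rather than a Lightning Callback, so it would be registered through mmengine-style custom_hooks. A hedged sketch, in which the pipeline entries are illustrative placeholders:

# Hedged example only: the pipeline entries are placeholders.
custom_hooks = [
    dict(
        type='YOLOXModeSwitchHook',
        num_last_epochs=15,
        new_train_pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='Resize', scale=(640, 640), keep_ratio=True),
            dict(type='PackDetInputs'),
        ]),
]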
mmpl/engine/logger/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .builder import PL_LOGGERS
mmpl/engine/logger/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (198 Bytes). View file
 
mmpl/engine/logger/__pycache__/builder.cpython-310.pyc ADDED
Binary file (3.15 kB). View file
 
mmpl/engine/logger/builder.py ADDED
@@ -0,0 +1,112 @@
1
+ import copy
2
+ import inspect
3
+ from typing import List, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import lightning
8
+
9
+ from mmengine.config import Config, ConfigDict
10
+ from mmengine.device import is_npu_available
11
+ from mmpl.registry import LOGGERS
12
+
13
+
14
+ def register_pl_loggers() -> List[str]:
15
+ """Register loggers in ``lightning.pytorch.loggers`` to the ``LOGGERS`` registry.
16
+
17
+ Returns:
18
+ List[str]: A list of registered loggers' names.
19
+ """
20
+ pl_loggers = []
21
+ for module_name in dir(lightning.pytorch.loggers):
22
+ if module_name.startswith('__'):
23
+ continue
24
+ _logger = getattr(lightning.pytorch.loggers, module_name)
25
+ if inspect.isclass(_logger) and issubclass(_logger, lightning.pytorch.loggers.logger.Logger):
26
+ LOGGERS.register_module(module=_logger)
27
+ pl_loggers.append(module_name)
28
+ return pl_loggers
29
+
30
+
31
+ PL_LOGGERS = register_pl_loggers()
32
+
33
+
34
+ def register_dadaptation_optimizers() -> List[str]:
35
+ """Register optimizers in ``dadaptation`` to the ``OPTIMIZERS`` registry.
36
+
37
+ Returns:
38
+ List[str]: A list of registered optimizers' name.
39
+ """
40
+ dadaptation_optimizers = []
41
+ try:
42
+ import dadaptation
43
+ except ImportError:
44
+ pass
45
+ else:
46
+ for module_name in ['DAdaptAdaGrad', 'DAdaptAdam', 'DAdaptSGD']:
47
+ _optim = getattr(dadaptation, module_name)
48
+ if inspect.isclass(_optim) and issubclass(_optim,
49
+ torch.optim.Optimizer):
50
+ OPTIMIZERS.register_module(module=_optim)
51
+ dadaptation_optimizers.append(module_name)
52
+ return dadaptation_optimizers
53
+
54
+
55
+ # DADAPTATION_OPTIMIZERS = register_dadaptation_optimizers()
56
+
57
+
58
+ def register_lion_optimizers() -> List[str]:
59
+ """Register Lion optimizer to the ``OPTIMIZERS`` registry.
60
+
61
+ Returns:
62
+ List[str]: A list of registered optimizers' name.
63
+ """
64
+ optimizers = []
65
+ try:
66
+ from lion_pytorch import Lion
67
+ except ImportError:
68
+ pass
69
+ else:
70
+ OPTIMIZERS.register_module(module=Lion)
71
+ optimizers.append('Lion')
72
+ return optimizers
73
+
74
+
75
+ # LION_OPTIMIZERS = register_lion_optimizers()
76
+
77
+
78
+ def build_optim_wrapper(model: nn.Module,
79
+ cfg: Union[dict, Config, ConfigDict]):
80
+ """Build function of OptimWrapper.
81
+
82
+ If ``constructor`` is set in the ``cfg``, this method will build an
83
+ optimizer wrapper constructor, and use optimizer wrapper constructor to
84
+ build the optimizer wrapper. If ``constructor`` is not set, the
85
+ ``DefaultOptimWrapperConstructor`` will be used by default.
86
+
87
+ Args:
88
+ model (nn.Module): Model to be optimized.
89
+ cfg (dict): Config of optimizer wrapper, optimizer constructor and
90
+ optimizer.
91
+
92
+ Returns:
93
+ OptimWrapper: The built optimizer wrapper.
94
+ """
95
+ optim_wrapper_cfg = copy.deepcopy(cfg)
96
+ constructor_type = optim_wrapper_cfg.pop('constructor',
97
+ 'DefaultOptimWrapperConstructor')
98
+ paramwise_cfg = optim_wrapper_cfg.pop('paramwise_cfg', None)
99
+
100
+ # Since the current generation of NPU(Ascend 910) only supports
101
+ # mixed precision training, here we turn on mixed precision by default
102
+ # on the NPU to make the training normal
103
+ if is_npu_available():
104
+ optim_wrapper_cfg['type'] = 'AmpOptimWrapper'
105
+
106
+ optim_wrapper_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build(
107
+ dict(
108
+ type=constructor_type,
109
+ optim_wrapper_cfg=optim_wrapper_cfg,
110
+ paramwise_cfg=paramwise_cfg))
111
+ optim_wrapper = optim_wrapper_constructor(model)
112
+ return optim_wrapper
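Because register_pl_loggers() runs at import time, any Lightning logger can afterwards be built from a plain config dict through the LOGGERS registry. A minimal sketch, assuming mmpl.engine.logger has been imported and TensorBoard is installed; the save_dir path is an arbitrary example:

import mmpl.engine.logger  # noqa: F401  (importing runs register_pl_loggers)
from mmpl.registry import LOGGERS

logger_cfg = dict(type='TensorBoardLogger', save_dir='work_dirs/tb')
tb_logger = LOGGERS.build(logger_cfg)
print(type(tb_logger).__name__)  # TensorBoardLogger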
mmpl/engine/optimizers/__init__.py ADDED
File without changes
mmpl/engine/runner/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .pl_runner import PLRunner
2
+
3
+ __all__ = ['PLRunner']
mmpl/engine/runner/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (220 Bytes). View file
 
mmpl/engine/runner/__pycache__/pl_runner.cpython-310.pyc ADDED
Binary file (27 kB). View file
 
mmpl/engine/runner/pl_runner.py ADDED
@@ -0,0 +1,941 @@
1
+ import copy
2
+ import logging
3
+ import os
4
+ import os.path as osp
5
+ import pickle
6
+ import platform
7
+ import time
8
+ import warnings
9
+ from collections import OrderedDict
10
+ from functools import partial
11
+ from typing import Callable, Dict, List, Optional, Sequence, Union
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ from lightning.pytorch.loggers import Logger
16
+ from torch.nn.parallel.distributed import DistributedDataParallel
17
+ from torch.optim import Optimizer
18
+ from torch.utils.data import DataLoader
19
+
20
+ import mmengine
21
+ from mmengine.config import Config, ConfigDict
22
+ from mmengine.dataset import worker_init_fn
23
+ from mmengine.device import get_device
24
+ from mmengine.dist import (broadcast, get_dist_info, get_rank, init_dist,
25
+ is_distributed, master_only)
26
+ from mmengine.evaluator import Evaluator
27
+ from mmengine.fileio import FileClient, join_path
28
+ from mmengine.hooks import Hook
29
+ from mmengine.logging import MessageHub, MMLogger, print_log
30
+ from mmengine.model import (MMDistributedDataParallel, convert_sync_batchnorm,
31
+ is_model_wrapper, revert_sync_batchnorm)
32
+ from mmengine.optim import (OptimWrapper, OptimWrapperDict, _ParamScheduler,
33
+ build_optim_wrapper)
34
+ from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, FUNCTIONS,
35
+ HOOKS, LOG_PROCESSORS, LOOPS, MODEL_WRAPPERS,
36
+ OPTIM_WRAPPERS, PARAM_SCHEDULERS,
37
+ RUNNERS, VISUALIZERS, DefaultScope)
38
+ from mmengine.utils import digit_version, get_git_hash, is_seq_of
39
+ from mmengine.utils.dl_utils import (TORCH_VERSION, collect_env,
40
+ set_multi_processing)
41
+ from mmengine.visualization import Visualizer
42
+ from mmengine.runner.base_loop import BaseLoop
43
+ from mmengine.runner.checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
44
+ find_latest_checkpoint, get_state_dict,
45
+ save_checkpoint, weights_to_cpu)
46
+ from mmengine.runner.log_processor import LogProcessor
47
+ from mmengine.runner.loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop
48
+ from mmengine.runner.priority import Priority, get_priority
49
+ from mmengine.runner.utils import set_random_seed
50
+
51
+ ConfigType = Union[Dict, Config, ConfigDict]
52
+ ParamSchedulerType = Union[List[_ParamScheduler], Dict[str, List[_ParamScheduler]]]
53
+ OptimWrapperType = Union[OptimWrapper, OptimWrapperDict]
54
+
55
+ from mmpl.registry import MODELS, LOGGERS
56
+ import lightning.pytorch as pl
57
+ from mmpl.models import build_pler
58
+
59
+
60
+ @RUNNERS.register_module()
61
+ class PLRunner:
62
+ def __init__(
63
+ self,
64
+ trainer_cfg: Dict,
65
+ model_cfg: Union[pl.LightningModule, Dict],
66
+ datamodule_cfg: Optional[Dict] = None,
67
+ cfg: Optional[ConfigType] = None
68
+ ):
69
+ self.trainer_cfg = copy.deepcopy(trainer_cfg)
70
+ self.model_cfg = copy.deepcopy(model_cfg)
71
+ self.datamodule_cfg = copy.deepcopy(datamodule_cfg)
72
+ mmengine.mkdir_or_exist(trainer_cfg['default_root_dir'])
73
+
74
+ timestamp = torch.tensor(time.time(), dtype=torch.float64)
75
+ # broadcast timestamp from 0 process to other processes
76
+ broadcast(timestamp)
77
+ self.timestamp = time.strftime('%Y%m%d_%H%M%S',
78
+ time.localtime(timestamp.item()))
79
+
80
+ if cfg is not None:
81
+ if isinstance(cfg, Config):
82
+ self.cfg = copy.deepcopy(cfg)
83
+ elif isinstance(cfg, dict):
84
+ self.cfg = Config(cfg)
85
+ else:
86
+ self.cfg = Config(dict())
87
+
88
+ compiled_model = trainer_cfg.pop('compiled_model', False)
89
+
90
+ # build logger
91
+ loggers = self.build_logger(
92
+ trainer_cfg.get('logger', False),
93
+ trainer_cfg.get('default_root_dir', f'{self.timestamp}')
94
+ )
95
+ trainer_cfg['logger'] = loggers
96
+
97
+ # build visualizer used for writing log or visualizing all kinds of data
98
+ self.visualizer = self.build_visualizer(
99
+ self.cfg.get('visualizer', None),
100
+ trainer_cfg.get('default_root_dir', f'{self.timestamp}')
101
+ )
102
+ if self.cfg:
103
+ self.visualizer.add_config(self.cfg)
104
+
105
+ # build callbacks
106
+ callbacks = self.build_hooks(
107
+ trainer_cfg.get('callbacks', None),
108
+ )
109
+ trainer_cfg['callbacks'] = callbacks
110
+
111
+ # build strategy
112
+ strategy = self.build_strategy(
113
+ trainer_cfg.get('strategy', 'auto'),
114
+ )
115
+ trainer_cfg['strategy'] = strategy
116
+
117
+ self.trainer = pl.Trainer(**trainer_cfg)
118
+ model_cfg.update({'config_cfg': copy.deepcopy(cfg).to_dict()})
119
+ model = self.build_model(model_cfg)
120
+ if cfg.get('load_from', None) is not None:
121
+ self.load_checkpoint(model, cfg['load_from'])
122
+ if compiled_model:
123
+ # default, reduce-overhead, and max-autotune.
124
+ self.model = torch.compile(model)
125
+ else:
126
+ self.model = model
127
+
128
+ # dump `cfg` to `work_dir`
129
+ self.dump_config()
130
+ # # Collect and log environment information.
131
+ # self._log_env(env_cfg)
132
+ # log hooks information
133
+ # self.logger.info(f'Hooks will be executed in the following '
134
+ # f'order:\n{self.get_hooks_info()}')
135
+
136
+ def build_visualizer(
137
+ self,
138
+ visualizer: Optional[Union[Visualizer,
139
+ Dict]] = None,
140
+ default_root_dir = 'tmp'
141
+ ) -> Visualizer:
142
+ """Build a globally accessible Visualizer.
143
+
144
+ Args:
145
+ visualizer (Visualizer or dict, optional): A Visualizer object
146
+ or a dict to build Visualizer object. If ``visualizer`` is a
147
+ Visualizer object, just returns itself. If not specified,
148
+ default config will be used to build Visualizer object.
149
+ Defaults to None.
150
+
151
+ Returns:
152
+ Visualizer: A Visualizer object build from ``visualizer``.
153
+ """
154
+ if visualizer is None:
155
+ visualizer = dict(
156
+ name=os.path.basename(default_root_dir),
157
+ vis_backends=[dict(type='LocalVisBackend')],
158
+ save_dir=default_root_dir+'/visualizer'
159
+ )
160
+ return Visualizer.get_instance(**visualizer)
161
+
162
+ if isinstance(visualizer, Visualizer):
163
+ return visualizer
164
+
165
+ if isinstance(visualizer, dict):
166
+ # ensure visualizer containing name key
167
+ visualizer.setdefault('name', os.path.basename(default_root_dir))
168
+ visualizer.setdefault('save_dir', default_root_dir+'/visualizer')
169
+ return VISUALIZERS.build(visualizer)
170
+ else:
171
+ raise TypeError(
172
+ 'visualizer should be Visualizer object, a dict or None, '
173
+ f'but got {visualizer}')
174
+
175
+ def build_hooks(self, hooks: Union[Dict, List[Dict]] = None) -> List[Hook]:
176
+ """Build hooks from config.
177
+
178
+ Args:
179
+ hooks_cfg (dict): Config dict of hooks.
180
+
181
+ Returns:
182
+ list[Hook]: A list of hooks.
183
+ """
184
+ if hooks is not None:
185
+ if isinstance(hooks, dict):
186
+ hooks = [hooks]
187
+ tmp_hooks = []
188
+ for hook in hooks:
189
+ hook = HOOKS.build(hook)
190
+ tmp_hooks.append(hook)
191
+ hooks = tmp_hooks
192
+ return hooks
193
+
194
+ @classmethod
195
+ def from_cfg(cls, cfg: ConfigType) -> 'PLRunner':
196
+ cfg = copy.deepcopy(cfg)
197
+ runner = cls(
198
+ trainer_cfg=cfg.get('trainer_cfg'),
199
+ model_cfg=cfg['model_cfg'],
200
+ datamodule_cfg=cfg.get('datamodule_cfg'),
201
+ cfg=cfg
202
+ )
203
+
204
+ return runner
205
+
206
+ def build_logger(self, loggers: Union[Dict, List[Dict]] = None, default_root_dir='logger'):
207
+ if loggers is not None and loggers:
208
+ if isinstance(loggers, Dict):
209
+ loggers = [loggers]
210
+ tmp_loggers = []
211
+ for logger in loggers:
212
+ if logger.get('save_dir', None) is None:
213
+ logger['save_dir'] = default_root_dir
214
+ mmengine.mkdir_or_exist(logger['save_dir'])
215
+ tmp_loggers.append(LOGGERS.build(logger))
216
+ loggers = tmp_loggers
217
+ return loggers
218
+
219
+ def build_strategy(self, strategy='auto'):
220
+ if isinstance(strategy, str):
221
+ return strategy
222
+ elif isinstance(strategy, dict):
223
+ if strategy.get('type', '') == 'FSDPStrategy':
224
+ from torch.distributed.fsdp import CPUOffload
225
+ from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
226
+ import functools
227
+ strategy.update(
228
+ dict(
229
+ # cpu_offload=CPUOffload(offload_params=True),
230
+ auto_wrap_policy=functools.partial(
231
+ size_based_auto_wrap_policy, min_num_params=int(5e7)
232
+ )
233
+ )
234
+ )
235
+ strategy = MODEL_WRAPPERS.build(strategy)
236
+ return strategy
237
+ return strategy
238
+
239
+ def build_model(self, model: Union[pl.LightningModule, Dict]) -> pl.LightningModule:
240
+ if isinstance(model, pl.LightningModule):
241
+ return model
242
+ elif isinstance(model, dict):
243
+ model = build_pler(model)
244
+ return model # type: ignore
245
+ else:
246
+ raise TypeError('model should be a pl.LightningModule object or dict, '
247
+ f'but got {model}')
248
+
249
+ def _init_model_weights(self) -> None:
250
+ """Initialize the model weights if the model has
251
+ :meth:`init_weights`"""
252
+ if hasattr(self.model, 'module'):
253
+ model = self.model.module
254
+ else:
255
+ model = self.model
256
+ if hasattr(model, 'init_weights'):
257
+ model.init_weights()
258
+ # sync params and buffers
259
+ for name, params in model.state_dict().items():
260
+ broadcast(params)
261
+
262
+ def get_hooks_info(self) -> str:
263
+ # Get hooks info in each stage
264
+ stage_hook_map: Dict[str, list] = {stage: [] for stage in Hook.stages}
265
+ for hook in self.hooks:
266
+ try:
267
+ priority = Priority(hook.priority).name # type: ignore
268
+ except ValueError:
269
+ priority = hook.priority # type: ignore
270
+ classname = hook.__class__.__name__
271
+ hook_info = f'({priority:<12}) {classname:<35}'
272
+ for trigger_stage in hook.get_triggered_stages():
273
+ stage_hook_map[trigger_stage].append(hook_info)
274
+
275
+ stage_hook_infos = []
276
+ for stage in Hook.stages:
277
+ hook_infos = stage_hook_map[stage]
278
+ if len(hook_infos) > 0:
279
+ info = f'{stage}:\n'
280
+ info += '\n'.join(hook_infos)
281
+ info += '\n -------------------- '
282
+ stage_hook_infos.append(info)
283
+ return '\n'.join(stage_hook_infos)
284
+
285
+ def load_or_resume(self) -> None:
286
+ """load or resume checkpoint."""
287
+ if self._has_loaded:
288
+ return None
289
+
290
+ # decide to load from checkpoint or resume from checkpoint
291
+ resume_from = None
292
+ if self._resume and self._load_from is None:
293
+ # auto resume from the latest checkpoint
294
+ resume_from = find_latest_checkpoint(self.work_dir)
295
+ self.logger.info(
296
+ f'Auto resumed from the latest checkpoint {resume_from}.')
297
+ elif self._resume and self._load_from is not None:
298
+ # resume from the specified checkpoint
299
+ resume_from = self._load_from
300
+
301
+ if resume_from is not None:
302
+ self.resume(resume_from)
303
+ self._has_loaded = True
304
+ elif self._load_from is not None:
305
+ self.load_checkpoint(self._load_from)
306
+ self._has_loaded = True
307
+
308
+ @staticmethod
309
+ def build_datamodule(datamodule_cfg: Union[pl.LightningDataModule, Dict]):
310
+ if isinstance(datamodule_cfg, pl.LightningDataModule):
311
+ return datamodule_cfg
312
+ datamodule_cfg = copy.deepcopy(datamodule_cfg)
313
+ # build datamodule
314
+ datamodule = DATASETS.build(datamodule_cfg)
315
+ return datamodule
316
+
317
+ def run(self, status, *args, **kwargs):
318
+ assert status in ['fit', 'test', 'predict', 'validate']
319
+ trainer_func = self.trainer.__getattribute__(status)
320
+ self.datamodule = self.build_datamodule(self.datamodule_cfg)
321
+ return trainer_func(model=self.model, datamodule=self.datamodule, *args, **kwargs)
322
+
323
+ #
324
+ # if is_model_wrapper(self.model):
325
+ # ori_model = self.model.module
326
+ # else:
327
+ # ori_model = self.model
328
+ # assert hasattr(ori_model, 'train_step'), (
329
+ # 'If you want to train your model, please make sure your model '
330
+ # 'has implemented `train_step`.')
331
+ #
332
+ # if self._val_loop is not None:
333
+ # assert hasattr(ori_model, 'val_step'), (
334
+ # 'If you want to validate your model, please make sure your '
335
+ # 'model has implemented `val_step`.')
336
+ #
337
+ # if self._train_loop is None:
338
+ # raise RuntimeError(
339
+ # '`self._train_loop` should not be None when calling train '
340
+ # 'method. Please provide `train_dataloader`, `train_cfg`, '
341
+ # '`optimizer` and `param_scheduler` arguments when '
342
+ # 'initializing runner.')
343
+ #
344
+ # self._train_loop = self.build_train_loop(
345
+ # self._train_loop) # type: ignore
346
+ #
347
+ # # `build_optimizer` should be called before `build_param_scheduler`
348
+ # # because the latter depends on the former
349
+ # self.optim_wrapper = self.build_optim_wrapper(self.optim_wrapper)
350
+ # # Automatically scaling lr by linear scaling rule
351
+ # self.scale_lr(self.optim_wrapper, self.auto_scale_lr)
352
+ #
353
+ # if self.param_schedulers is not None:
354
+ # self.param_schedulers = self.build_param_scheduler( # type: ignore
355
+ # self.param_schedulers) # type: ignore
356
+ #
357
+ # if self._val_loop is not None:
358
+ # self._val_loop = self.build_val_loop(
359
+ # self._val_loop) # type: ignore
360
+ # # TODO: add a contextmanager to avoid calling `before_run` many times
361
+ # self.call_hook('before_run')
362
+ #
363
+ # # initialize the model weights
364
+ # self._init_model_weights()
365
+ # # make sure checkpoint-related hooks are triggered after `before_run`
366
+ # self.load_or_resume()
367
+ #
368
+ # # Initiate inner count of `optim_wrapper`.
369
+ # self.optim_wrapper.initialize_count_status(
370
+ # self.model,
371
+ # self._train_loop.iter, # type: ignore
372
+ # self._train_loop.max_iters) # type: ignore
373
+ #
374
+ # # Maybe compile the model according to options in self.cfg.compile
375
+ # # This must be called **AFTER** model has been wrapped.
376
+ # self._maybe_compile('train_step')
377
+ #
378
+ # model = self.train_loop.run() # type: ignore
379
+ # self.call_hook('after_run')
380
+ # return model
381
+
382
+
383
+
384
+ def register_hook(
385
+ self,
386
+ hook: Union[Hook, Dict],
387
+ priority: Optional[Union[str, int, Priority]] = None) -> None:
388
+ """Register a hook into the hook list.
389
+
390
+ The hook will be inserted into a priority queue, with the specified
391
+ priority (See :class:`Priority` for details of priorities).
392
+ For hooks with the same priority, they will be triggered in the same
393
+ order as they are registered.
394
+
395
+ Priority of hook will be decided with the following priority:
396
+
397
+ - ``priority`` argument. If ``priority`` is given, it will be priority
398
+ of hook.
399
+ - If ``hook`` argument is a dict and ``priority`` in it, the priority
400
+ will be the value of ``hook['priority']``.
401
+ - If ``hook`` argument is a dict but ``priority`` not in it or ``hook``
402
+ is an instance of ``hook``, the priority will be ``hook.priority``.
403
+
404
+ Args:
405
+ hook (:obj:`Hook` or dict): The hook to be registered.
406
+ priority (int or str or :obj:`Priority`, optional): Hook priority.
407
+ Lower value means higher priority.
408
+ """
409
+ if not isinstance(hook, (Hook, dict)):
410
+ raise TypeError(
411
+ f'hook should be an instance of Hook or dict, but got {hook}')
412
+
413
+ _priority = None
414
+ if isinstance(hook, dict):
415
+ if 'priority' in hook:
416
+ _priority = hook.pop('priority')
417
+
418
+ hook_obj = HOOKS.build(hook)
419
+ else:
420
+ hook_obj = hook
421
+
422
+ if priority is not None:
423
+ hook_obj.priority = priority
424
+ elif _priority is not None:
425
+ hook_obj.priority = _priority
426
+
427
+ inserted = False
428
+ for i in range(len(self._hooks) - 1, -1, -1):
429
+ if get_priority(hook_obj.priority) >= get_priority(
430
+ self._hooks[i].priority):
431
+ self._hooks.insert(i + 1, hook_obj)
432
+ inserted = True
433
+ break
434
+ if not inserted:
435
+ self._hooks.insert(0, hook_obj)
436
+
437
+ def register_default_hooks(
438
+ self,
439
+ hooks: Optional[Dict[str, Union[Hook, Dict]]] = None) -> None:
440
+ """Register default hooks into hook list.
441
+
442
+ ``hooks`` will be registered into runner to execute some default
443
+ actions like updating model parameters or saving checkpoints.
444
+
445
+ Default hooks and their priorities:
446
+
447
+ +----------------------+-------------------------+
448
+ | Hooks | Priority |
449
+ +======================+=========================+
450
+ | RuntimeInfoHook | VERY_HIGH (10) |
451
+ +----------------------+-------------------------+
452
+ | IterTimerHook | NORMAL (50) |
453
+ +----------------------+-------------------------+
454
+ | DistSamplerSeedHook | NORMAL (50) |
455
+ +----------------------+-------------------------+
456
+ | LoggerHook | BELOW_NORMAL (60) |
457
+ +----------------------+-------------------------+
458
+ | ParamSchedulerHook | LOW (70) |
459
+ +----------------------+-------------------------+
460
+ | CheckpointHook | VERY_LOW (90) |
461
+ +----------------------+-------------------------+
462
+
463
+ If ``hooks`` is None, above hooks will be registered by
464
+ default::
465
+
466
+ default_hooks = dict(
467
+ runtime_info=dict(type='RuntimeInfoHook'),
468
+ timer=dict(type='IterTimerHook'),
469
+ sampler_seed=dict(type='DistSamplerSeedHook'),
470
+ logger=dict(type='LoggerHook'),
471
+ param_scheduler=dict(type='ParamSchedulerHook'),
472
+ checkpoint=dict(type='CheckpointHook', interval=1),
473
+ )
474
+
475
+ If not None, ``hooks`` will be merged into ``default_hooks``.
476
+ If there are None value in default_hooks, the corresponding item will
477
+ be popped from ``default_hooks``::
478
+
479
+ hooks = dict(timer=None)
480
+
481
+ The final registered default hooks will be :obj:`RuntimeInfoHook`,
482
+ :obj:`DistSamplerSeedHook`, :obj:`LoggerHook`,
483
+ :obj:`ParamSchedulerHook` and :obj:`CheckpointHook`.
484
+
485
+ Args:
486
+ hooks (dict[str, Hook or dict], optional): Default hooks or configs
487
+ to be registered.
488
+ """
489
+ default_hooks: dict = dict(
490
+ runtime_info=dict(type='RuntimeInfoHook'),
491
+ timer=dict(type='IterTimerHook'),
492
+ sampler_seed=dict(type='DistSamplerSeedHook'),
493
+ logger=dict(type='LoggerHook'),
494
+ param_scheduler=dict(type='ParamSchedulerHook'),
495
+ checkpoint=dict(type='CheckpointHook', interval=1),
496
+ )
497
+ if hooks is not None:
498
+ for name, hook in hooks.items():
499
+ if name in default_hooks and hook is None:
500
+ # remove hook from _default_hooks
501
+ default_hooks.pop(name)
502
+ else:
503
+ assert hook is not None
504
+ default_hooks[name] = hook
505
+
506
+ for hook in default_hooks.values():
507
+ self.register_hook(hook)
508
+
509
+ def register_custom_hooks(self, hooks: List[Union[Hook, Dict]]) -> None:
510
+ """Register custom hooks into hook list.
511
+
512
+ Args:
513
+ hooks (list[Hook | dict]): List of hooks or configs to be
514
+ registered.
515
+ """
516
+ for hook in hooks:
517
+ self.register_hook(hook)
518
+
519
+ def register_hooks(
520
+ self,
521
+ default_hooks: Optional[Dict[str, Union[Hook, Dict]]] = None,
522
+ custom_hooks: Optional[List[Union[Hook, Dict]]] = None) -> None:
523
+ """Register default hooks and custom hooks into hook list.
524
+
525
+ Args:
526
+ default_hooks (dict[str, dict] or dict[str, Hook], optional): Hooks
527
+ to execute default actions like updating model parameters and
528
+ saving checkpoints. Defaults to None.
529
+ custom_hooks (list[dict] or list[Hook], optional): Hooks to execute
530
+ custom actions like visualizing images processed by pipeline.
531
+ Defaults to None.
532
+ """
533
+ self.register_default_hooks(default_hooks)
534
+
535
+ if custom_hooks is not None:
536
+ self.register_custom_hooks(custom_hooks)
537
+
538
+ def resume(self,
539
+ filename: str,
540
+ resume_optimizer: bool = True,
541
+ resume_param_scheduler: bool = True,
542
+ map_location: Union[str, Callable] = 'default') -> None:
543
+ """Resume model from checkpoint.
544
+
545
+ Args:
546
+ filename (str): Accept local filepath, URL, ``torchvision://xxx``,
547
+ ``open-mmlab://xxx``.
548
+ resume_optimizer (bool): Whether to resume optimizer state.
549
+ Defaults to True.
550
+ resume_param_scheduler (bool): Whether to resume param scheduler
551
+ state. Defaults to True.
552
+ map_location (str or callable):A string or a callable function to
553
+ specifying how to remap storage locations.
554
+ map_location (str or callable): A string or a callable function
553
+ specifying how to remap storage locations.
556
+ if map_location == 'default':
557
+ device = get_device()
558
+ checkpoint = self.load_checkpoint(filename, map_location=device)
559
+ else:
560
+ checkpoint = self.load_checkpoint(
561
+ filename, map_location=map_location)
562
+
563
+ self.train_loop._epoch = checkpoint['meta']['epoch']
564
+ self.train_loop._iter = checkpoint['meta']['iter']
565
+
566
+ # check whether the number of GPU used for current experiment
567
+ # is consistent with resuming from checkpoint
568
+ if 'config' in checkpoint['meta']:
569
+ config = mmengine.Config.fromstring(
570
+ checkpoint['meta']['config'], file_format='.py')
571
+ previous_gpu_ids = config.get('gpu_ids', None)
572
+ if (previous_gpu_ids is not None and len(previous_gpu_ids) > 0
573
+ and len(previous_gpu_ids) != self._world_size):
574
+ # TODO, should we modify the iteration?
575
+ self.logger.info(
576
+ 'Number of GPU used for current experiment is not '
577
+ 'consistent with resuming from checkpoint')
578
+ if (self.auto_scale_lr is None
579
+ or not self.auto_scale_lr.get('enable', False)):
580
+ raise RuntimeError(
581
+ 'Cannot automatically rescale lr in resuming. Please '
582
+ 'make sure the number of GPU is consistent with the '
583
+ 'previous training state resuming from the checkpoint '
584
+ 'or set `enable` in `auto_scale_lr to False.')
585
+
586
+ # resume random seed
587
+ resumed_seed = checkpoint['meta'].get('seed', None)
588
+ current_seed = self._randomness_cfg.get('seed')
589
+ if resumed_seed is not None and resumed_seed != current_seed:
590
+ if current_seed is not None:
591
+ print_log(
592
+ f'The value of random seed in the '
593
+ f'checkpoint "{resumed_seed}" is '
594
+ f'different from the value in '
595
+ f'`randomness` config "{current_seed}"',
596
+ logger='current',
597
+ level=logging.WARNING)
598
+ self._randomness_cfg.update(seed=resumed_seed)
599
+ self.set_randomness(**self._randomness_cfg)
600
+
601
+ resumed_dataset_meta = checkpoint['meta'].get('dataset_meta', None)
602
+ dataset_meta = getattr(self.train_dataloader.dataset, 'metainfo', None)
603
+
604
+ # `resumed_dataset_meta` and `dataset_meta` could be object like
605
+ # np.ndarray, which cannot be directly judged as equal or not,
606
+ # therefore we just compared their dumped results.
607
+ if pickle.dumps(resumed_dataset_meta) != pickle.dumps(dataset_meta):
608
+ print_log(
609
+ 'The dataset metainfo from the resumed checkpoint is '
610
+ 'different from the current training dataset, please '
611
+ 'check the correctness of the checkpoint or the training '
612
+ 'dataset.',
613
+ logger='current',
614
+ level=logging.WARNING)
615
+
616
+ self.message_hub.load_state_dict(checkpoint['message_hub'])
617
+
618
+ # resume optimizer
619
+ if 'optimizer' in checkpoint and resume_optimizer:
620
+ self.optim_wrapper = self.build_optim_wrapper(self.optim_wrapper)
621
+ self.optim_wrapper.load_state_dict( # type: ignore
622
+ checkpoint['optimizer'])
623
+
624
+ # resume param scheduler
625
+ if resume_param_scheduler and self.param_schedulers is None:
626
+ print_log(
627
+ '`resume_param_scheduler` is True but `self.param_schedulers` '
628
+ 'is None, so skip resuming parameter schedulers',
629
+ logger='current',
630
+ level=logging.WARNING)
631
+ resume_param_scheduler = False
632
+ if 'param_schedulers' in checkpoint and resume_param_scheduler:
633
+ self.param_schedulers = self.build_param_scheduler( # type: ignore
634
+ self.param_schedulers) # type: ignore
635
+ if isinstance(self.param_schedulers, dict):
636
+ for name, schedulers in self.param_schedulers.items():
637
+ for scheduler, ckpt_scheduler in zip(
638
+ schedulers, checkpoint['param_schedulers'][name]):
639
+ scheduler.load_state_dict(ckpt_scheduler)
640
+ else:
641
+ for scheduler, ckpt_scheduler in zip(
642
+ self.param_schedulers, # type: ignore
643
+ checkpoint['param_schedulers']):
644
+ scheduler.load_state_dict(ckpt_scheduler)
645
+
646
+ self._has_loaded = True
647
+
648
+ self.logger.info(f'resumed epoch: {self.epoch}, iter: {self.iter}')
649
+
650
+ # def load_checkpoint(self,
651
+ # filename: str,
652
+ # model,
653
+ # map_location: Union[str, Callable] = 'cpu',
654
+ # strict: bool = False,
655
+ # revise_keys: list = [(r'^module.', '')]):
656
+ # """Load checkpoint from given ``filename``.
657
+ #
658
+ # Args:
659
+ # filename (str): Accept local filepath, URL, ``torchvision://xxx``,
660
+ # ``open-mmlab://xxx``.
661
+ # map_location (str or callable): A string or a callable function to
662
+ # specifying how to remap storage locations.
663
+ # Defaults to 'cpu'.
664
+ # strict (bool): strict (bool): Whether to allow different params for
665
+ # the model and checkpoint.
666
+ # revise_keys (list): A list of customized keywords to modify the
667
+ # state_dict in checkpoint. Each item is a (pattern, replacement)
668
+ # pair of the regular expression operations. Defaults to strip
669
+ # the prefix 'module.' by [(r'^module\\.', '')].
670
+ # """
671
+ # checkpoint = _load_checkpoint(filename, map_location=map_location)
672
+ #
673
+ # if is_model_wrapper(model):
674
+ # model = model.module
675
+ # else:
676
+ # model = model
677
+ #
678
+ # checkpoint = _load_checkpoint_to_model(
679
+ # model, checkpoint, strict, revise_keys=revise_keys)
680
+ #
681
+ # print(f'Load checkpoint from {filename}')
682
+ #
683
+ # return checkpoint
684
+ def load_checkpoint(self, model, file):
685
+
686
+ if isinstance(file, str):
687
+ file_path = file
688
+ state_dict = torch.load(file_path, map_location='cpu')['state_dict']
689
+ elif isinstance(file, dict):
690
+ file_path = file['file_path']
691
+ state_dict = torch.load(file_path, map_location='cpu')['state_dict']
692
+ for delete_key in file['delete_keys']:
693
+ del state_dict[delete_key]
694
+ else:
695
+ raise TypeError('file must be str or dict')
696
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
697
+ print('load from:', file_path)
698
+ print('load model missing_keys:', missing_keys)
699
+ print('load model unexpected_keys:', unexpected_keys)
700
+
701
+ @master_only
702
+ def save_checkpoint(
703
+ self,
704
+ out_dir: str,
705
+ filename: str,
706
+ file_client_args: Optional[dict] = None,
707
+ save_optimizer: bool = True,
708
+ save_param_scheduler: bool = True,
709
+ meta: dict = None,
710
+ by_epoch: bool = True,
711
+ backend_args: Optional[dict] = None,
712
+ ):
713
+ """Save checkpoints.
714
+
715
+ ``CheckpointHook`` invokes this method to save checkpoints
716
+ periodically.
717
+
718
+ Args:
719
+ out_dir (str): The directory that checkpoints are saved.
720
+ filename (str): The checkpoint filename.
721
+ file_client_args (dict, optional): Arguments to instantiate a
722
+ FileClient. See :class:`mmengine.fileio.FileClient` for
723
+ details. Defaults to None. It will be deprecated in future.
724
+ Please use `backend_args` instead.
725
+ save_optimizer (bool): Whether to save the optimizer to
726
+ the checkpoint. Defaults to True.
727
+ save_param_scheduler (bool): Whether to save the param_scheduler
728
+ to the checkpoint. Defaults to True.
729
+ meta (dict, optional): The meta information to be saved in the
730
+ checkpoint. Defaults to None.
731
+ by_epoch (bool): Whether the scheduled momentum is updated by
732
+ epochs. Defaults to True.
733
+ backend_args (dict, optional): Arguments to instantiate the
734
+ prefix of uri corresponding backend. Defaults to None.
735
+ New in v0.2.0.
736
+ """
737
+ if meta is None:
738
+ meta = {}
739
+ elif not isinstance(meta, dict):
740
+ raise TypeError(
741
+ f'meta should be a dict or None, but got {type(meta)}')
742
+
743
+ if by_epoch:
744
+ # self.epoch increments 1 after
745
+ # `self.call_hook('after_train_epoch)` but `save_checkpoint` is
746
+ # called by `after_train_epoch`` method of `CheckpointHook` so
747
+ # `epoch` should be `self.epoch + 1`
748
+ meta.update(epoch=self.epoch + 1, iter=self.iter)
749
+ else:
750
+ meta.update(epoch=self.epoch, iter=self.iter + 1)
751
+
752
+ if file_client_args is not None:
753
+ warnings.warn(
754
+ '"file_client_args" will be deprecated in future. '
755
+ 'Please use "backend_args" instead', DeprecationWarning)
756
+ if backend_args is not None:
757
+ raise ValueError(
758
+ '"file_client_args" and "backend_args" cannot be set at '
759
+ 'the same time.')
760
+
761
+ file_client = FileClient.infer_client(file_client_args, out_dir)
762
+ filepath = file_client.join_path(out_dir, filename)
763
+ else:
764
+ filepath = join_path( # type: ignore
765
+ out_dir, filename, backend_args=backend_args)
766
+
767
+ meta.update(
768
+ cfg=self.cfg.pretty_text,
769
+ seed=self.seed,
770
+ experiment_name=self.experiment_name,
771
+ time=time.strftime('%Y%m%d_%H%M%S', time.localtime()),
772
+ mmengine_version=mmengine.__version__ + get_git_hash())
773
+
774
+ if hasattr(self.train_dataloader.dataset, 'metainfo'):
775
+ meta.update(dataset_meta=self.train_dataloader.dataset.metainfo)
776
+
777
+ if is_model_wrapper(self.model):
778
+ model = self.model.module
779
+ else:
780
+ model = self.model
781
+
782
+ checkpoint = {
783
+ 'meta': meta,
784
+ 'state_dict': weights_to_cpu(get_state_dict(model)),
785
+ 'message_hub': self.message_hub.state_dict()
786
+ }
787
+ # save optimizer state dict to checkpoint
788
+ if save_optimizer:
789
+ if isinstance(self.optim_wrapper, OptimWrapper):
790
+ checkpoint['optimizer'] = self.optim_wrapper.state_dict()
791
+ else:
792
+ raise TypeError(
793
+ 'self.optim_wrapper should be an `OptimWrapper` '
794
+ 'or `OptimWrapperDict` instance, but got '
795
+ f'{self.optim_wrapper}')
796
+
797
+ # save param scheduler state dict
798
+ if save_param_scheduler and self.param_schedulers is None:
799
+ print_log(
800
+ '`save_param_scheduler` is True but `self.param_schedulers` '
801
+ 'is None, so skip saving parameter schedulers',
802
+ logger='current',
803
+ level=logging.WARNING)
804
+ save_param_scheduler = False
805
+ if save_param_scheduler:
806
+ if isinstance(self.param_schedulers, dict):
807
+ checkpoint['param_schedulers'] = dict()
808
+ for name, schedulers in self.param_schedulers.items():
809
+ checkpoint['param_schedulers'][name] = []
810
+ for scheduler in schedulers:
811
+ state_dict = scheduler.state_dict()
812
+ checkpoint['param_schedulers'][name].append(state_dict)
813
+ else:
814
+ checkpoint['param_schedulers'] = []
815
+ for scheduler in self.param_schedulers: # type: ignore
816
+ state_dict = scheduler.state_dict() # type: ignore
817
+ checkpoint['param_schedulers'].append(state_dict)
818
+
819
+ self.call_hook('before_save_checkpoint', checkpoint=checkpoint)
820
+ save_checkpoint(checkpoint, filepath)
821
+
822
+ @master_only
823
+ def dump_config(self) -> None:
824
+ """Dump config to `work_dir`."""
+ version = ''
825
+ if len(self.trainer.loggers) > 0:
826
+ version = self.trainer.loggers[0].version
827
+ version = version if isinstance(version, str) else f"version_{version}"
828
+ if version == '':
829
+ # if no loggers, use default_root_dir
830
+ version = 'version'
831
+
832
833
+ if self.cfg.filename is not None:
834
+ filename = osp.basename(self.cfg.filename)
835
+ else:
836
+ filename = f'{self.timestamp}.py'
837
+ path = f'{self.trainer.default_root_dir}/{version}_{filename}'
838
+
839
+ self.cfg.dump(path)
840
+
841
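A made-up example of the naming scheme above: with a single logger whose `version` is `3` and a config loaded from `seg_sam.py`, the file is dumped to `{trainer.default_root_dir}/version_3_seg_sam.py`; without any logger the prefix falls back to the literal `version_`, and without a config filename the timestamp is used instead.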
+     def _check_scheduler_cfg(
+             self, param_scheduler: Optional[Union[dict, list,
+                                                   _ParamScheduler]]) -> None:
+         """Check that `param_scheduler` can be parsed to a list of parameter
+         schedulers, or a `dict` of which each value is a list of parameter
+         schedulers.
+
+         If only one optimizer is used, the parsed config should be a
+         list of parameter scheduler configs or instances. If multiple
+         optimizers are used, the parsed config should be `dict`.
+         Its key should be consistent with the optimizer `dict` and its value
+         should be a list of parameter scheduler configs or instances. See
+         :meth:`build_param_scheduler` for more details.
+
+         Examples:
+             >>> # valid scheduler:
+             >>> # empty scheduler
+             >>> scheduler = None
+             >>> # Single scheduler
+             >>> scheduler = dict(type='MultiStepLR', milestones=[1, 2])
+             >>> # Single list schedulers
+             >>> scheduler = [dict(type='MultiStepLR', milestones=[1, 2]),
+             >>>              dict(type='MultiStepLR', milestones=[2, 3])]
+             >>> # `dict` of schedulers
+             >>> scheduler = dict(linear1=dict(type='MultiStepLR', milestones=[1, 2]),
+             >>>                  linear2=dict(type='MultiStepLR', milestones=[1, 2]))
+             >>> # `dict` of `list` of schedulers
+             >>> scheduler = dict(linear1=[dict(type='MultiStepLR', milestones=[1, 2])],
+             >>>                  linear2=[dict(type='MultiStepLR', milestones=[1, 2])])
+             >>> # Single built scheduler
+             >>> from mmengine.optim import MultiStepLR
+             >>> scheduler = MultiStepLR(milestones=[1, 2], optimizer=optimizer)
+             >>> # Single built list schedulers
+             >>> scheduler = [MultiStepLR(milestones=[1, 2], optimizer=optimizer)]
+             >>> # dict of built scheduler
+             >>> scheduler = dict(linear1=MultiStepLR(milestones=[1, 2], optimizer=optimizer),
+             >>>                  linear2=MultiStepLR(milestones=[1, 2], optimizer=optimizer))
+             >>> # dict of built list schedulers
+             >>> scheduler = dict(linear1=[MultiStepLR(milestones=[1, 2], optimizer=optimizer)],
+             >>>                  linear2=[MultiStepLR(milestones=[1, 2], optimizer=optimizer)])
+
+         Args:
+             param_scheduler (dict or list): The original parameter scheduler.
+         """  # noqa: E501
+         param_schedulers: Union[dict, list, _ParamScheduler]
+         if param_scheduler is None:
+             return
+         if isinstance(param_scheduler, _ParamScheduler):
+             return
+         if is_seq_of(param_scheduler, _ParamScheduler):
+             return
+
+         if is_seq_of(param_scheduler, dict):
+             for _param_scheduler in param_scheduler:
+                 assert 'type' in _param_scheduler, (
+                     'Each parameter scheduler should contain the key type, '
+                     f'but got {_param_scheduler}')
+         elif isinstance(param_scheduler, dict):
+             if 'type' not in param_scheduler:
+                 for key, _param_scheduler in param_scheduler.items():
+                     assert isinstance(
+                         _param_scheduler,
+                         (dict, tuple, list, _ParamScheduler)), (
+                             'Each value of `param_scheduler` should be a '
+                             f'dict or a list, but got {_param_scheduler} with '
+                             f'type {type(_param_scheduler)}')
+
+         else:
+             raise TypeError(
+                 '`param_scheduler` should be a `_ParamScheduler`, `dict`, '
+                 f'list or a tuple, but got {type(param_scheduler)}. If '
+                 '`param_scheduler` is a list of dict, it means a list of '
+                 'scheduler configs for single optimizer. If it is a dict and '
+                 'contains key `type`, it means a scheduler config for a '
+                 'single optimizer. If it does not contain key `type`, it '
+                 'means multiple lists of schedulers for multiple optimizers.')
+
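To make the dict form concrete, a hypothetical config fragment that would pass this check for two optimizers registered under the keys `linear1` and `linear2` (the names are placeholders and must mirror the optimizer dict) could look like:

    # Sketch only: scheduler types are the mmengine.optim schedulers already
    # named in the docstring; keys must match the optimizer dict keys.
    param_scheduler = dict(
        linear1=[dict(type='MultiStepLR', milestones=[8, 11])],
        linear2=[dict(type='CosineAnnealingLR', T_max=12)],
    )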
+     def _log_env(self, env_cfg: dict) -> None:
+         """Log the environment information of the current task.
+
+         Args:
+             env_cfg (dict): The environment config of the runner.
+         """
+         # Collect and log environment information.
+         env = collect_env()
+         runtime_env = OrderedDict()
+         runtime_env.update(env_cfg)
+         runtime_env.update(self._randomness_cfg)
+         runtime_env['Distributed launcher'] = self._launcher
+         runtime_env['Distributed training'] = self._distributed
+         runtime_env['GPU number'] = self._world_size
+
+         env_info = '\n    ' + '\n    '.join(f'{k}: {v}'
+                                             for k, v in env.items())
+         runtime_env_info = '\n    ' + '\n    '.join(
+             f'{k}: {v}' for k, v in runtime_env.items())
+         dash_line = '-' * 60
+         self.logger.info('\n' + dash_line + '\nSystem environment:' +
+                          env_info + '\n'
+                          '\nRuntime environment:' + runtime_env_info + '\n' +
+                          dash_line + '\n')
+         self.logger.info(f'Config:\n{self.cfg.pretty_text}')
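The environment table that `_log_env` writes can be reproduced on its own. A rough sketch, assuming mmengine's `collect_env` helper (an ordered mapping of system facts) is importable from `mmengine.utils.dl_utils`; the `runtime_env` contents here are placeholders:

    from collections import OrderedDict

    from mmengine.utils.dl_utils import collect_env

    env = collect_env()                 # Python/CUDA/PyTorch versions, GPUs, ...
    runtime_env = OrderedDict(seed=42)  # placeholder for the runner's runtime settings
    lines = [f'{k}: {v}' for k, v in {**env, **runtime_env}.items()]
    print('-' * 60 + '\nSystem environment:\n    ' + '\n    '.join(lines))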
mmpl/engine/strategies/__init__.py ADDED
@@ -0,0 +1 @@
+ from .builder import PL_MODEL_WRAPPERS
mmpl/engine/strategies/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (209 Bytes). View file