|
from collections import OrderedDict
|
|
import torch.nn as nn
|
|
|
|
|
|
class ModelBook:
|
|
"""Maintain the mapping between modules and their paths.
|
|
|
|
Example:
|
|
book = ModelBook(model_ft)
|
|
for p, m in book.conv2d_modules():
|
|
print('path:', p, 'num of filters:', m.out_channels)
|
|
assert m is book.get_module(p)
|
|
"""
|
|
|
|
def __init__(self, model):
|
|
self._model = model
|
|
self._modules = OrderedDict()
|
|
self._paths = OrderedDict()
|
|
path = []
|
|
self._construct(self._model, path)
|
|
|
|
def _construct(self, module, path):
|
|
if not module._modules:
|
|
return
|
|
for name, m in module._modules.items():
|
|
cur_path = tuple(path + [name])
|
|
self._paths[m] = cur_path
|
|
self._modules[cur_path] = m
|
|
self._construct(m, path + [name])
|
|
|
|
def conv2d_modules(self):
|
|
return self.modules(nn.Conv2d)
|
|
|
|
def linear_modules(self):
|
|
return self.modules(nn.Linear)
|
|
|
|
def modules(self, module_type=None):
|
|
for p, m in self._modules.items():
|
|
if not module_type or isinstance(m, module_type):
|
|
yield p, m
|
|
|
|
def num_of_conv2d_modules(self):
|
|
return self.num_of_modules(nn.Conv2d)
|
|
|
|
def num_of_conv2d_filters(self):
|
|
"""Return the sum of out_channels of all conv2d layers.
|
|
|
|
Here we treat the sub weight with size of [in_channels, h, w] as a single filter.
|
|
"""
|
|
num_filters = 0
|
|
for _, m in self.conv2d_modules():
|
|
num_filters += m.out_channels
|
|
return num_filters
|
|
|
|
def num_of_linear_modules(self):
|
|
return self.num_of_modules(nn.Linear)
|
|
|
|
def num_of_linear_filters(self):
|
|
num_filters = 0
|
|
for _, m in self.linear_modules():
|
|
num_filters += m.out_features
|
|
return num_filters
|
|
|
|
def num_of_modules(self, module_type=None):
|
|
num = 0
|
|
for p, m in self._modules.items():
|
|
if not module_type or isinstance(m, module_type):
|
|
num += 1
|
|
return num
|
|
|
|
def get_module(self, path):
|
|
return self._modules.get(path)
|
|
|
|
def get_path(self, module):
|
|
return self._paths.get(module)
|
|
|
|
def update(self, path, module):
|
|
old_module = self._modules[path]
|
|
del self._paths[old_module]
|
|
self._paths[module] = path
|
|
self._modules[path] = module
|
|
|