# Source: Hugging Face upload "Upload 81 files" by fa0311 (commit b559e06).
from collections import OrderedDict
import torch.nn as nn
class ModelBook:
    """Maintain a bidirectional mapping between modules and their paths.

    A *path* is a tuple of attribute names leading from the root model to a
    submodule, e.g. ``('features', '0')``.

    Example:
        book = ModelBook(model_ft)
        for p, m in book.conv2d_modules():
            print('path:', p, 'num of filters:', m.out_channels)
            assert m is book.get_module(p)
    """

    def __init__(self, model):
        # Root model; not stored in the maps itself (its path would be ()).
        self._model = model
        # path tuple -> module, in registration (pre-order) order.
        self._modules = OrderedDict()
        # module -> path tuple (last registration wins for shared modules).
        self._paths = OrderedDict()
        self._construct(self._model, [])

    def _construct(self, module, path):
        """Recursively register every submodule of *module* under *path*."""
        for name, child in module._modules.items():
            # PyTorch permits a registered child slot to hold None; skip it
            # instead of polluting the maps with a None entry.
            if child is None:
                continue
            cur_path = tuple(path) + (name,)
            self._paths[child] = cur_path
            self._modules[cur_path] = child
            self._construct(child, list(cur_path))

    def conv2d_modules(self):
        """Yield (path, module) for every nn.Conv2d in the book."""
        return self.modules(nn.Conv2d)

    def linear_modules(self):
        """Yield (path, module) for every nn.Linear in the book."""
        return self.modules(nn.Linear)

    def modules(self, module_type=None):
        """Yield (path, module) pairs, optionally filtered by *module_type*.

        Args:
            module_type: a class (or tuple of classes) accepted by
                ``isinstance``; ``None`` yields every registered module.
        """
        for path, m in self._modules.items():
            if not module_type or isinstance(m, module_type):
                yield path, m

    def num_of_conv2d_modules(self):
        """Return the number of nn.Conv2d modules."""
        return self.num_of_modules(nn.Conv2d)

    def num_of_conv2d_filters(self):
        """Return the sum of out_channels of all conv2d layers.

        Here we treat the sub weight with size of [in_channels, h, w] as a
        single filter.
        """
        return sum(m.out_channels for _, m in self.conv2d_modules())

    def num_of_linear_modules(self):
        """Return the number of nn.Linear modules."""
        return self.num_of_modules(nn.Linear)

    def num_of_linear_filters(self):
        """Return the sum of out_features of all linear layers."""
        return sum(m.out_features for _, m in self.linear_modules())

    def num_of_modules(self, module_type=None):
        """Return how many registered modules match *module_type* (all if None)."""
        # Reuse modules() so the filtering semantics stay in one place.
        return sum(1 for _ in self.modules(module_type))

    def get_module(self, path):
        """Return the module registered at *path*, or None if unknown."""
        return self._modules.get(path)

    def get_path(self, module):
        """Return the path registered for *module*, or None if unknown."""
        return self._paths.get(module)

    def update(self, path, module):
        """Replace the entry at *path* with *module*, keeping both maps consistent.

        Fix over the original: entries for the old module's descendants are
        dropped (they no longer exist under the model) and the new module's
        submodules are registered, so lookups and counts stay accurate after
        replacing a container module. Raises KeyError if *path* is unknown.
        """
        if path not in self._modules:
            raise KeyError(path)
        depth = len(path)
        # Remove the old module and every stale descendant rooted at *path*.
        for sub_path in [p for p in self._modules if p[:depth] == path]:
            # pop() rather than del: a shared module holds only one reverse
            # entry, so a second path to it must not raise.
            self._paths.pop(self._modules[sub_path], None)
            del self._modules[sub_path]
        self._paths[module] = path
        self._modules[path] = module
        # Register the replacement's own submodules (no-op for leaf modules).
        self._construct(module, list(path))