|
def load_model_weights(model, weights, multi_gpus, train=True):
    """Load a checkpoint state dict into ``model`` and return ``model``.

    Checkpoint keys may carry a ``'module.'`` prefix (presumably because the
    checkpoint was saved from a wrapper such as ``DataParallel`` -- confirm
    against the saving code).  Unless the caller is training with multiple
    GPUs, that prefix is stripped so the keys match an unwrapped model.

    Args:
        model: Object exposing ``load_state_dict(state_dict)``.
        weights: Mapping of parameter names to values (the checkpoint's
            state dict).
        multi_gpus: Truthy when the target model is used in a multi-GPU
            setup; together with ``train`` it decides whether the
            ``'module.'`` prefix is kept.
        train: Truthy when loading for training. The prefix is kept only
            when both ``multi_gpus`` and ``train`` are truthy.

    Returns:
        The same ``model`` object, after ``load_state_dict`` was called.
    """
    prefix = 'module.'

    # Keep the checkpoint keys as-is only when multi-GPU training is requested;
    # in every other case the target model is expected to lack the prefix.
    strip_prefix = (not multi_gpus) or (not train)

    if strip_prefix and any(key.startswith(prefix) for key in weights):
        # Strip only a true leading 'module.' -- keys such as 'submodule.w'
        # or unprefixed keys are left untouched.
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in weights.items()
        }
    else:
        # Nothing to rename: pass the original mapping through unchanged
        # (also preserves any extra attributes the mapping object carries).
        state_dict = weights

    model.load_state_dict(state_dict)
    return model
|
|
|
|
|
|
|
class dummy_context_mgr:
    """No-op context manager.

    Entering binds ``None`` to the ``as`` target; exiting does nothing and
    returns ``False``, so any exception raised in the ``with`` body
    propagates unchanged.
    """

    def __enter__(self):
        # Nothing to acquire; the `as` target receives None.
        result = None
        return result

    def __exit__(self, exc_type, exc_value, traceback):
        # A falsy (False) return tells the interpreter not to suppress
        # an in-flight exception.
        suppress = False
        return suppress
|
|
|
|
|
|
|
def read_css_from_file(filename, encoding='utf-8'):
    """Return the full text content of the file at ``filename``.

    Args:
        filename: Path of the file to read (intended for CSS stylesheets).
        encoding: Text encoding used to decode the file. Defaults to
            UTF-8 so the result does not depend on the platform's locale
            encoding, which ``open()`` would otherwise use.

    Returns:
        The decoded file contents as a ``str``.
    """
    with open(filename, 'r', encoding=encoding) as css_file:
        return css_file.read()
|
|