Spaces:
Sleeping
Sleeping
def load_model_weights(model, weights, multi_gpus, train=True):
    """Load checkpoint weights into ``model`` and return the model.

    Checkpoints saved from a multi-GPU-wrapped model carry a ``"module."``
    prefix on their keys (presumably added by a wrapper such as
    ``torch.nn.DataParallel`` -- confirm against the training code).  When
    the target model is not wrapped (``multi_gpus is False``) or is being
    loaded for inference (``train is False``), that prefix is stripped so
    the keys match the bare model.

    Args:
        model: Object exposing ``load_state_dict(state_dict)``.
        weights: Mapping from parameter name to value (a state dict).
        multi_gpus: ``False`` when the target model is not wrapped for
            multi-GPU use.
        train: ``False`` when the model is loaded for inference.

    Returns:
        The ``model`` that was passed in, after ``load_state_dict`` ran.
    """
    prefix = "module."

    # Identity checks against ``False`` are kept on purpose: callers passing
    # other falsy values (``None``, ``0``) keep the behaviour they had before.
    wants_plain_keys = (multi_gpus is False) or (train is False)

    # Match the exact leading "module." on each key.  A substring search on
    # the first key alone misfires on names such as "submodule.weight", and
    # blindly slicing seven characters corrupts keys that lack the prefix.
    if wants_plain_keys and any(key.startswith(prefix) for key in weights):
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in weights.items()
        }
    else:
        # Pass the original mapping through untouched so any extra
        # attributes it carries are preserved.
        state_dict = weights

    model.load_state_dict(state_dict)
    return model
# Stand-in context manager to use when mixed precision is failing
class dummy_context_mgr:
    """Context manager that does nothing on entry or exit."""

    def __enter__(self):
        """Enter the block; the ``as`` target is bound to ``None``."""
        return None

    def __exit__(self, exc_type, exc_value, traceback):
        """Leave the block; returning ``False`` lets exceptions propagate."""
        return False
# Function to read CSS from file
def read_css_from_file(filename):
    """Return the full text content of the CSS file at ``filename``.

    The file is decoded as UTF-8 explicitly, so the result does not depend
    on the platform's locale encoding.

    Args:
        filename: Path of the file to read.

    Returns:
        The file content as a single string.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the content is not valid UTF-8.
    """
    with open(filename, 'r', encoding='utf-8') as css_file:
        return css_file.read()