|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
def load_state_dict(model, state_dict):
    """Load state_dict into model, handling DataParallel and DistributedDataParallel.

    Also checks for a "model" key in state_dict and, if present, loads the
    value stored under it instead of the outer dict.

    DataParallel prefixes state_dict keys with 'module.' when saving.
    If the model is not a DataParallel model but the state_dict is, then prefixes are removed.
    If the model is a DataParallel model but the state_dict is not, then prefixes are added.

    Keys are compared against the keys the model itself reports before any
    renaming happens, so a model that legitimately owns a child named
    ``module`` is not mistaken for a DataParallel checkpoint.

    Args:
        model (torch.nn.Module): Model to load the weights into.
        state_dict (dict): State dict, or a dict holding one under "model".

    Returns:
        torch.nn.Module: The same model object, with weights loaded.
    """
    prefix = 'module.'
    state_dict = state_dict.get('model', state_dict)

    do_prefix = isinstance(
        model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel))
    expected = set(model.state_dict())

    if set(state_dict) == expected:
        # Keys already line up with the model exactly; renaming them could
        # only break the load (e.g. a plain model with a child called `module`).
        state = dict(state_dict)
    elif do_prefix and {prefix + k for k in state_dict} == expected:
        # Wrapped model whose inner module also has 'module.'-prefixed keys:
        # every key needs the wrapper prefix, even ones that already start with it.
        state = {prefix + k: v for k, v in state_dict.items()}
    else:
        # General case: per-key prefix fix-up based on whether the model is wrapped.
        state = {}
        for k, v in state_dict.items():
            if do_prefix:
                if not k.startswith(prefix):
                    k = prefix + k
            elif k.startswith(prefix):
                k = k[len(prefix):]
            state[k] = v

    model.load_state_dict(state)
    print("Loaded successfully")
    return model
|
|
|
|
|
def load_wts(model, checkpoint_path):
    """Read a checkpoint file and load its weights into the model.

    The checkpoint is deserialized with ``torch.load`` using
    ``map_location='cpu'`` and then handed to ``load_state_dict``.

    Args:
        model: Model to load the weights into.
        checkpoint_path: Location of the checkpoint, passed to ``torch.load``.

    Returns:
        Whatever ``load_state_dict`` returns for the model.
    """
    return load_state_dict(
        model,
        torch.load(checkpoint_path, map_location='cpu'),
    )
|
|
|
|
|
def load_state_dict_from_url(model, url, **kwargs):
    """Fetch a checkpoint from a URL and load its weights into the model.

    The download goes through ``torch.hub.load_state_dict_from_url`` with
    ``map_location='cpu'``; any extra keyword arguments are forwarded to it.

    Args:
        model: Model to load the weights into.
        url: Address of the checkpoint.
        **kwargs: Extra options for ``torch.hub.load_state_dict_from_url``.

    Returns:
        Whatever ``load_state_dict`` returns for the model.
    """
    fetched = torch.hub.load_state_dict_from_url(
        url, map_location='cpu', **kwargs)
    return load_state_dict(model, fetched)
|
|
|
|
|
def load_state_from_resource(model, resource: str):
    """Loads weights to the model from a given resource.

    A resource can be of the following types:
        1. URL. Prefixed with "url::"
            e.g. url::http(s)://url.resource.com/ckpt.pt

        2. Local path. Prefixed with "local::"
            e.g. local::/path/to/ckpt.pt

    Args:
        model (torch.nn.Module): Model
        resource (str): resource string

    Returns:
        torch.nn.Module: Model with loaded weights

    Raises:
        ValueError: If the resource starts with neither "url::" nor "local::".
    """
    print(f"Using pretrained resource {resource}")

    url_prefix = 'url::'
    local_prefix = 'local::'

    if resource.startswith(url_prefix):
        # Drop only the leading prefix. Splitting on the marker would cut the
        # location short if the same text appears again later in the URL.
        url = resource[len(url_prefix):]
        return load_state_dict_from_url(model, url, progress=True)

    if resource.startswith(local_prefix):
        # Same reasoning as above: keep everything after the leading prefix.
        path = resource[len(local_prefix):]
        return load_wts(model, path)

    raise ValueError("Invalid resource type, only url:: and local:: are supported")
|
|