|
|
|
|
|
__author__ = 'cleardusk' |
|
|
|
import sys |
|
|
|
sys.path.append('..') |
|
|
|
import torch |
|
import tddfa.models |
|
from tddfa.utils.tddfa_util import load_model |
|
|
|
|
|
def convert_to_onnx(**kvs): |
|
|
|
size = kvs.get('size', 120) |
|
model = getattr(models, kvs.get('arch'))( |
|
num_classes=kvs.get('num_params', 62), |
|
widen_factor=kvs.get('widen_factor', 1), |
|
size=size, |
|
mode=kvs.get('mode', 'small') |
|
) |
|
checkpoint_fp = kvs.get('checkpoint_fp') |
|
model = load_model(model, checkpoint_fp) |
|
model.eval() |
|
|
|
|
|
batch_size = 1 |
|
dummy_input = torch.randn(batch_size, 3, size, size) |
|
wfp = checkpoint_fp.replace('.pth', '.onnx') |
|
torch.onnx.export( |
|
model, |
|
(dummy_input, ), |
|
wfp, |
|
input_names=['input'], |
|
output_names=['output'], |
|
do_constant_folding=True |
|
) |
|
print(f'Convert {checkpoint_fp} to {wfp} done.') |
|
return wfp |
|
|