fasd / tddfa /utils /onnx.py
ozyman's picture
added depth model
ddadf19
raw
history blame
941 Bytes
# coding: utf-8
__author__ = 'cleardusk'
import sys
sys.path.append('..')
import torch
import tddfa.models
from tddfa.utils.tddfa_util import load_model
def convert_to_onnx(**kvs):
# 1. load model
size = kvs.get('size', 120)
model = getattr(models, kvs.get('arch'))(
num_classes=kvs.get('num_params', 62),
widen_factor=kvs.get('widen_factor', 1),
size=size,
mode=kvs.get('mode', 'small')
)
checkpoint_fp = kvs.get('checkpoint_fp')
model = load_model(model, checkpoint_fp)
model.eval()
# 2. convert
batch_size = 1
dummy_input = torch.randn(batch_size, 3, size, size)
wfp = checkpoint_fp.replace('.pth', '.onnx')
torch.onnx.export(
model,
(dummy_input, ),
wfp,
input_names=['input'],
output_names=['output'],
do_constant_folding=True
)
print(f'Convert {checkpoint_fp} to {wfp} done.')
return wfp