File size: 619 Bytes
69a8f96 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
import torch
import torch.nn as nn
from mono.utils.comm import get_func
class BaseDepthModel(nn.Module):
    """Wrapper module that delegates to a depth pipeline chosen by config.

    The pipeline is obtained by passing the dotted name
    ``'mono.model.model_pipelines.' + cfg.model.type`` to ``get_func`` and
    calling the result with ``cfg``; it is stored as ``self.depth_model``.
    """

    def __init__(self, cfg, **kwargs) -> None:
        """Construct the wrapped pipeline.

        Args:
            cfg: Config object. ``cfg.model.type`` is appended to
                ``'mono.model.model_pipelines.'`` to form the lookup name,
                and ``cfg`` itself is passed to the looked-up callable.
            **kwargs: Accepted but not used here.
        """
        super().__init__()
        pipeline_path = 'mono.model.model_pipelines.' + cfg.model.type
        # NOTE(review): get_func presumably resolves a dotted path to a
        # class/factory -- confirm against mono.utils.comm.
        build_pipeline = get_func(pipeline_path)
        self.depth_model = build_pipeline(cfg)

    def forward(self, data):
        """Run the wrapped pipeline once.

        Args:
            data: Mapping whose items are passed as keyword arguments to
                ``self.depth_model``.

        Returns:
            tuple: ``(prediction, confidence, result)`` where ``prediction``
            and ``confidence`` are ``result['prediction']`` and
            ``result['confidence']``, and ``result`` is the full pipeline
            output. A missing key raises ``KeyError``.
        """
        result = self.depth_model(**data)
        prediction = result['prediction']
        confidence = result['confidence']
        return prediction, confidence, result

    def inference(self, data):
        """Call ``forward`` under ``torch.no_grad()``.

        Does not switch the module to eval mode; call ``.eval()`` beforehand
        if that is required.

        Args:
            data: Same mapping accepted by ``forward``.

        Returns:
            tuple: ``(prediction, confidence)``; the full output is discarded.
        """
        with torch.no_grad():
            prediction, confidence, _unused = self.forward(data)
        return prediction, confidence