Metric3D/mono/model/model_pipelines/__base_model__.py
JUGGHM's picture
Upload 62 files
8a32844
raw
history blame
619 Bytes
import torch
import torch.nn as nn
from mono.utils.comm import get_func
class BaseDepthModel(nn.Module):
    """Wrap a config-selected depth pipeline behind a uniform interface.

    The concrete pipeline is chosen by name from ``cfg.model.type`` and stored
    as ``self.depth_model``. ``forward`` returns
    ``(prediction, confidence, raw_output)``. ``inference`` returns
    ``(prediction, confidence)`` with autograd disabled.
    """

    def __init__(self, cfg, **kwargs) -> None:
        """Build the wrapped depth pipeline.

        Args:
            cfg: Config object. ``cfg.model.type`` is appended to the prefix
                ``'mono.model.model_pipelines.'`` to name the pipeline, and
                ``cfg`` itself is passed to the resulting callable.
            **kwargs: Accepted for call-site compatibility; currently unused.
        """
        super().__init__()
        model_type = cfg.model.type
        # NOTE(review): get_func presumably resolves a dotted path to a
        # class/factory; confirm against mono.utils.comm.
        self.depth_model = get_func('mono.model.model_pipelines.' + model_type)(cfg)

    def forward(self, data):
        """Run the wrapped pipeline on one batch.

        Args:
            data: Mapping of keyword arguments, unpacked as
                ``self.depth_model(**data)``.

        Returns:
            A tuple ``(output['prediction'], output['confidence'], output)``,
            where ``output`` is the pipeline's full result.

        Raises:
            KeyError: If ``output`` is a dict missing ``'prediction'`` or
                ``'confidence'``.
        """
        output = self.depth_model(**data)
        return output['prediction'], output['confidence'], output

    def inference(self, data):
        """Run the model with gradient tracking disabled.

        This method does not switch the module to ``eval()`` mode; that is
        left to the caller.

        Args:
            data: Same mapping accepted by :meth:`forward`.

        Returns:
            A tuple ``(prediction, confidence)``. The raw output is discarded.
        """
        with torch.no_grad():
            # Call the module instance instead of ``self.forward`` so that any
            # registered forward pre-hooks/hooks also run at inference time.
            pred_depth, confidence, _ = self(data)
        return pred_depth, confidence