# Metric3D/mono/model/monodepth_model.py
# Uploaded by JUGGHM ("Upload 62 files"), commit 8a32844.
import torch
import torch.nn as nn
from .model_pipelines.__base_model__ import BaseDepthModel
class DepthModel(BaseDepthModel):
    """Monocular depth model built on top of ``BaseDepthModel``.

    Network construction is delegated entirely to ``BaseDepthModel``;
    this subclass adds a gradient-free :meth:`inference` entry point.
    """

    def __init__(self, cfg, **kwargs):
        """Build the model from ``cfg``.

        Args:
            cfg: configuration object, forwarded unchanged to
                ``BaseDepthModel.__init__``.
            **kwargs: accepted so callers may pass extra keyword arguments
                (``get_monodepth_model`` forwards its own); currently ignored.
        """
        super().__init__(cfg)

    def inference(self, data):
        """Run ``self.forward`` with autograd disabled.

        Args:
            data: input passed unchanged to ``self.forward``; its expected
                structure is defined by ``BaseDepthModel`` (not visible here).

        Returns:
            The ``(pred_depth, confidence, output_dict)`` triple produced by
            ``self.forward``.
        """
        # no_grad avoids building the autograd graph, saving memory at test time.
        with torch.no_grad():
            pred_depth, confidence, output_dict = self.forward(data)
        return pred_depth, confidence, output_dict
def get_monodepth_model(
    cfg: dict,
    **kwargs
) -> nn.Module:
    """Instantiate a :class:`DepthModel` from ``cfg``.

    Args:
        cfg: model configuration forwarded to ``DepthModel``.
            NOTE(review): annotated as ``dict``; confirm the concrete config
            type expected by ``BaseDepthModel`` and tighten the annotation.
        **kwargs: extra keyword arguments forwarded to ``DepthModel``.

    Returns:
        The constructed depth model (checked to be an ``nn.Module``).
    """
    model = DepthModel(cfg, **kwargs)
    # Internal invariant check on our own construction, not input validation.
    assert isinstance(model, nn.Module), (
        f"expected an nn.Module, got {type(model).__name__}"
    )
    return model
def get_configured_monodepth_model(
    cfg: dict,
) -> nn.Module:
    """Build the monocular depth model described by ``cfg``.

    Args:
        cfg: model configuration, passed unchanged to ``get_monodepth_model``.

    Returns:
        The depth model as an ``nn.Module``.
    """
    return get_monodepth_model(cfg)