|
import torch |
|
import torch.nn as nn |
|
from mono.utils.comm import get_func |
|
|
|
class DensePredModel(nn.Module):
    """Encoder/decoder model assembled from a config object.

    The encoder and decoder are looked up by dotted name through
    ``get_func`` (presumably resolving ``mono.model.<prefix><type>`` to a
    class or factory — confirm against ``mono.utils.comm``).
    """

    def __init__(self, cfg) -> None:
        """Build the encoder and decoder described by ``cfg.model``.

        Args:
            cfg: Config with ``model.backbone`` and ``model.decode_head``
                sections, each carrying ``prefix`` and ``type`` entries.
                The backbone section is unpacked as keyword arguments into
                the encoder factory; the decoder factory receives the whole
                ``cfg``.
        """
        super().__init__()

        backbone_cfg = cfg.model.backbone
        backbone_name = 'mono.model.' + backbone_cfg.prefix + backbone_cfg.type
        build_backbone = get_func(backbone_name)
        # Every entry of the backbone section (including prefix/type) is
        # forwarded as a keyword argument to the backbone factory.
        self.encoder = build_backbone(**backbone_cfg)

        # The decode-head section is read only after the encoder exists,
        # matching the original evaluation order.
        head_cfg = cfg.model.decode_head
        head_name = 'mono.model.' + head_cfg.prefix + head_cfg.type
        build_head = get_func(head_name)
        self.decoder = build_head(cfg)

    def forward(self, input, **kwargs):
        """Run ``input`` through the encoder, then the decoder.

        Args:
            input: Whatever the configured encoder accepts.
            **kwargs: Passed through untouched to the decoder only.

        Returns:
            The decoder's output, unmodified.
        """
        return self.decoder(self.encoder(input), **kwargs)