Spaces:
Runtime error
Runtime error
File size: 667 Bytes
fd52b7f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
""" DeepLabv3 Model download and change the head for your prediction"""
from torchvision.models.segmentation.deeplabv3 import DeepLabHead
from torchvision import models
def createDeepLabv3(outputchannels=1):
    """Build a pretrained DeepLabv3 model with a custom classifier head.

    Args:
        outputchannels (int, optional): The number of output channels
            in your dataset masks. Must be at least 1. Defaults to 1.

    Returns:
        model: The DeepLabv3 model with the ResNet50 backbone, its
            classifier replaced by a fresh ``DeepLabHead`` producing
            ``outputchannels`` channels, already switched to training mode.

    Raises:
        ValueError: If ``outputchannels`` is smaller than 1.
    """
    # Fail fast: reject an unusable head size before the pretrained
    # weights are fetched and the backbone is built.
    if outputchannels < 1:
        raise ValueError(
            f"outputchannels must be a positive integer, got {outputchannels!r}"
        )

    # NOTE: `pretrained=True` is kept for compatibility with older
    # torchvision releases; newer releases deprecate it in favour of the
    # `weights=` argument.
    model = models.segmentation.deeplabv3_resnet50(pretrained=True)

    # Swap the stock classifier for a new head sized for this dataset.
    # 2048 is the input channel count the head is built with; it has to
    # match the feature map the backbone hands to the classifier.
    model.classifier = DeepLabHead(2048, outputchannels)

    # Set the model in training mode
    model.train()
    return model