torchsat.models.segmentation package

Submodules

torchsat.models.segmentation.pspnet module

class torchsat.models.segmentation.pspnet.PSPNet(num_classes, in_channels=3, backbone='resnet50', pretrained=True, use_aux=True)

Bases: torch.nn.modules.module.Module

PSPNet model. Currently, only 3-channel input is supported.

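Args:
num_classes (int): Number of output classes.
in_channels (int, optional): Number of input channels. Defaults to 3.
backbone (str, optional): Name of the backbone network. Defaults to 'resnet50'.
pretrained (bool, optional): Whether to use pre-trained backbone weights. Defaults to True.
use_aux (bool, optional): Whether to use the auxiliary output branch. Defaults to True.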
forward(x)

Defines the computation performed at every call, i.e. the forward pass on the input tensor x.

Note

Although the recipe for the forward pass is defined within this function, call the module instance (model(x)) rather than model.forward(x): calling the instance runs any registered hooks, while calling forward directly silently ignores them.
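A minimal sketch of the behavior described in the note, using only core PyTorch. TinyNet and record_hook are hypothetical names chosen for illustration and are not part of torchsat; the same rule applies to PSPNet and UNetResNet instances.

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in module, used only to illustrate hook behavior."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)


calls = []  # records the output shape each time the forward hook fires


def record_hook(module, inputs, output):
    # Forward hooks receive the module, its positional inputs, and its output.
    calls.append(output.shape)


model = TinyNet()
model.register_forward_hook(record_hook)

x = torch.randn(3, 4)
out_call = model(x)             # __call__ runs forward() and then the hook
out_forward = model.forward(x)  # calling forward() directly skips the hook

print(len(calls))  # 1
```

Both calls compute the same tensor; only the call through the instance triggers the hook.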

freeze_bn()
get_backbone_params()
get_decoder_params()
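The three methods above have no description here. Judging only from their names, they are likely meant for fine-tuning. The sketch below assumes that get_backbone_params() and get_decoder_params() return iterables of parameters and that freeze_bn() switches BatchNorm layers to eval mode. ToySegNet is a hypothetical stand-in that defines methods with the same names so the example runs without torchsat; it is not PSPNet's implementation.

```python
import torch
from torch import nn


class ToySegNet(nn.Module):
    """Hypothetical stand-in exposing the same method names as PSPNet."""

    def __init__(self, num_classes=2):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(),
        )
        self.decoder = nn.Conv2d(8, num_classes, kernel_size=1)

    def forward(self, x):
        return self.decoder(self.backbone(x))

    def get_backbone_params(self):
        return self.backbone.parameters()

    def get_decoder_params(self):
        return self.decoder.parameters()

    def freeze_bn(self):
        # Assumed behavior: stop updating BatchNorm running statistics.
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()


model = ToySegNet()

# Separate parameter groups: smaller learning rate for the backbone.
optimizer = torch.optim.SGD(
    [
        {"params": model.get_backbone_params(), "lr": 1e-3},
        {"params": model.get_decoder_params(), "lr": 1e-2},
    ],
    lr=1e-2,
    momentum=0.9,
)

model.train()
model.freeze_bn()  # call after train(), which would otherwise re-enable BN updates
```

A smaller backbone learning rate is a common choice when the backbone starts from pre-trained weights (pretrained=True is the default here); the specific values are arbitrary.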

torchsat.models.segmentation.unet module

class torchsat.models.segmentation.unet.UNetResNet(encoder_depth, num_classes, in_channels=3, num_filters=32, dropout_2d=0.0, pretrained=False, is_deconv=False)

Bases: torch.nn.modules.module.Module

PyTorch U-Net model with a ResNet-34, ResNet-101 or ResNet-152 encoder.

UNet: https://arxiv.org/abs/1505.04597
ResNet: https://arxiv.org/abs/1512.03385
Proposed by Alexander Buslaev: https://www.linkedin.com/in/al-buslaev/

Args:
encoder_depth (int): Depth of the ResNet encoder (34, 101 or 152).
num_classes (int): Number of output classes.
in_channels (int, optional): Number of input channels. Defaults to 3.
num_filters (int, optional): Number of filters in the last layer of the decoder. Defaults to 32.
dropout_2d (float, optional): Probability of the dropout layer applied before the output layer. Defaults to 0.0.
pretrained (bool, optional): If False, no pre-trained weights are used. If True, the ResNet encoder is pre-trained on ImageNet. Defaults to False.
is_deconv (bool, optional): If False, bilinear interpolation is used in the decoder. If True, deconvolution (transposed convolution) is used in the decoder. Defaults to False.
Raises:
ValueError
NotImplementedError
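To illustrate the is_deconv option, here is a sketch of the two usual ways a decoder step can double spatial resolution, written with core PyTorch. upsample_block is a hypothetical helper, not the UNetResNet decoder; the kernel size, stride and padding of the transposed convolution are assumptions chosen so both branches produce the same output shape. The last lines assume dropout_2d maps to channel-wise dropout (nn.Dropout2d); with the default of 0.0 it leaves the features unchanged.

```python
import torch
from torch import nn


def upsample_block(in_ch, out_ch, is_deconv):
    """Hypothetical decoder step that doubles spatial resolution."""
    if is_deconv:
        # Learned upsampling: kernel 4, stride 2, padding 1 maps H -> 2H exactly.
        return nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)
    # Fixed bilinear upsampling, then a 3x3 convolution to change the channel count.
    return nn.Sequential(
        nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
    )


x = torch.randn(1, 64, 16, 16)
y_bilinear = upsample_block(64, 32, is_deconv=False)(x)
y_deconv = upsample_block(64, 32, is_deconv=True)(x)

# Assumed mapping of dropout_2d: channel-wise dropout; p=0.0 is a no-op.
drop = nn.Dropout2d(p=0.0)
y_drop = drop(y_bilinear)

print(y_bilinear.shape, y_deconv.shape)
```

Bilinear interpolation adds no upsampling weights of its own, while the transposed convolution learns its upsampling filter, which is the trade-off the is_deconv flag selects between.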
forward(x)

Defines the computation performed at every call, i.e. the forward pass on the input tensor x.

Note

Although the recipe for the forward pass is defined within this function, call the module instance (model(x)) rather than model.forward(x): calling the instance runs any registered hooks, while calling forward directly silently ignores them.

torchsat.models.segmentation.unet.unet_resnet34(num_classes, in_channels=3, pretrained=False, **kwargs)
torchsat.models.segmentation.unet.unet_resnet101(num_classes, in_channels=3, pretrained=False, **kwargs)
torchsat.models.segmentation.unet.unet_resnet152(num_classes, in_channels=3, pretrained=False, **kwargs)

Module contents