Segmentation
UNet2dPyTorch
class UNet2dPyTorch(num_classes, in_channels=1, depth=5, start_filts=64, up_mode='transpose', merge_mode='concat')
Bases: delira.models.abstract_network.AbstractPyTorchNetwork
The UNet2dPyTorch is a convolutional encoder-decoder neural network. Contextual spatial information about an input tensor (from the decoding, expansive pathway) is merged with information representing the localization of details (from the encoding, compressive pathway).
Notes
Differences to the original paper:
1. Padding is used in 3x3 convolutions to prevent loss of border pixels.
2. Merging outputs does not require cropping due to (1).
3. Residual connections can be used by specifying merge_mode='add'.
4. If non-parametric upsampling is used in the decoder pathway (specified by up_mode='upsample'), an additional 1x1 2d convolution follows the upsampling to reduce the channel dimensionality by a factor of 2. With up_mode='transpose', this channel halving is performed by the transpose convolution itself.
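The interplay of up_mode and merge_mode in points 3 and 4 can be illustrated with a single decoder block. This is a minimal sketch written for this page, not delira's internal helper: the class name `UpBlock2d`, its arguments, the 2x2 transpose-convolution kernel and the bilinear interpolation used for `'upsample'` are all assumptions.

```python
import torch
import torch.nn as nn


class UpBlock2d(nn.Module):
    """Hypothetical U-Net decoder block, for illustration only (not delira's code)."""

    def __init__(self, in_channels, out_channels, up_mode="transpose", merge_mode="concat"):
        super().__init__()
        if up_mode == "transpose":
            # Learned upsampling: doubles H and W and halves the channels in one step
            self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        elif up_mode == "upsample":
            # Non-parametric upsampling keeps the channel count,
            # so an extra 1x1 convolution performs the channel halving
            self.up = nn.Sequential(
                nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
                nn.Conv2d(in_channels, out_channels, kernel_size=1),
            )
        else:
            raise ValueError(f"unknown up_mode: {up_mode!r}")
        if merge_mode not in ("concat", "add"):
            raise ValueError(f"unknown merge_mode: {merge_mode!r}")
        self.merge_mode = merge_mode
        # 'concat' stacks both feature maps along the channel axis; 'add' sums them
        merged = 2 * out_channels if merge_mode == "concat" else out_channels
        # padding=1 keeps H and W unchanged, so no cropping is needed before merging
        self.conv = nn.Sequential(
            nn.Conv2d(merged, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, from_decoder, from_encoder):
        x = self.up(from_decoder)
        if self.merge_mode == "concat":
            x = torch.cat((x, from_encoder), dim=1)
        else:
            x = x + from_encoder  # residual-style merge
        return self.conv(x)


block = UpBlock2d(128, 64, up_mode="upsample", merge_mode="add")
deep = torch.randn(1, 128, 16, 16)  # coarse decoder features
skip = torch.randn(1, 64, 32, 32)   # encoder features one resolution level up
print(block(deep, skip).shape)  # torch.Size([1, 64, 32, 32])
```

A 3d variant of the same sketch would swap in `nn.Conv3d`, `nn.ConvTranspose3d` and `mode="trilinear"`.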
References
https://arxiv.org/abs/1505.04597
_build_model(num_classes, in_channels=3, depth=5, start_filts=64)
Builds the actual model.
- Parameters
num_classes (int) – number of output classes
in_channels (int) – number of channels for the input tensor (default: 3)
depth (int) – number of MaxPools in the U-Net (default: 5)
start_filts (int) – number of convolutional filters for the first conv (affects all other conv-filter numbers too; default: 64)
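start_filts affects all other filter counts. A common scheme, used in the original U-Net paper, doubles the filters at each encoder level and halves them again in the decoder; the hypothetical helper below tabulates that scheme. Whether delira follows it exactly, and how `depth` maps to the number of resolution levels, are assumptions here, so the helper takes the number of levels (`num_levels`) directly.

```python
def channel_plan(start_filts, num_levels, merge_mode="concat"):
    """Hypothetical helper: filter counts of a U-Net whose encoder doubles the
    filters at every level and whose decoder halves them again."""
    if merge_mode not in ("concat", "add"):
        raise ValueError(f"unknown merge_mode: {merge_mode!r}")
    # Encoder output channels: start_filts, 2*start_filts, 4*start_filts, ...
    encoder = [start_filts * 2 ** i for i in range(num_levels)]
    decoder = []  # (input channels of first decoder conv, output channels) per level
    channels = encoder[-1]
    for _ in range(num_levels - 1):
        out = channels // 2  # the upsampling step halves the channels
        # 'concat' stacks the skip connection on top, doubling the conv input
        merged = 2 * out if merge_mode == "concat" else out
        decoder.append((merged, out))
        channels = out
    return encoder, decoder


encoder, decoder = channel_plan(start_filts=64, num_levels=5)
print(encoder)  # [64, 128, 256, 512, 1024]
print(decoder)  # [(1024, 512), (512, 256), (256, 128), (128, 64)]
print(channel_plan(64, 5, merge_mode="add")[1])  # [(512, 512), (256, 256), (128, 128), (64, 64)]
```

Under this scheme, merge_mode='add' halves the input channels of each first decoder convolution compared with 'concat', which also reduces its parameter count.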
Notes
The helper functions and classes are defined within this function because delira can save the source code along with the weights, so that the network can be fully recovered without a manually created network instance; these helper functions therefore have to be saved too.
_init_kwargs = {}
static closure(model, data_dict: dict, optimizers: dict, criterions={}, metrics={}, fold=0, **kwargs)
Closure method to perform a single backpropagation step.
- Parameters
model (ClassificationNetworkBasePyTorch) – trainable model
data_dict (dict) – dictionary containing the data
optimizers (dict) – dictionary of optimizers to optimize model’s parameters
criterions (dict) – dict holding the criterions to calculate errors (gradients from different criterions will be accumulated)
metrics (dict) – dict holding the metrics to calculate
fold (int) – current fold in cross-validation (default: 0)
**kwargs – additional keyword arguments
- Returns
dict – Metric values (with same keys as input dict metrics)
dict – Loss values (with same keys as input dict criterions)
list – Arbitrary number of predictions as torch.Tensor
- Raises
AssertionError – if optimizers or criterions are empty or the optimizers are not specified
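The steps of such a closure can be sketched in plain PyTorch. This is a hedged illustration rather than delira's implementation: the function name `closure_sketch`, the `data_dict` keys `"data"` and `"label"`, the single optimizer stored under `"default"` and the stand-in model in the usage part are assumptions, and `fold`/`**kwargs` are omitted.

```python
import torch
import torch.nn as nn


def closure_sketch(model, data_dict, optimizers, criterions, metrics=None):
    """Hypothetical single training step; keys 'data', 'label', 'default' are assumptions."""
    assert optimizers and criterions, "optimizers and criterions must not be empty"
    metrics = metrics or {}
    preds = model(data_dict["data"])

    # Summing the losses and calling backward() once yields the sum of the
    # per-criterion gradients, i.e. the gradients are accumulated
    loss_vals, total_loss = {}, 0.0
    for name, crit in criterions.items():
        loss = crit(preds, data_dict["label"])
        loss_vals[name] = loss.item()
        total_loss = total_loss + loss

    optimizers["default"].zero_grad()
    total_loss.backward()
    optimizers["default"].step()

    # Metrics are computed without building an autograd graph
    with torch.no_grad():
        metric_vals = {name: float(fn(preds, data_dict["label"])) for name, fn in metrics.items()}
    return metric_vals, loss_vals, [preds.detach()]


# Usage with a one-layer stand-in "segmentation network" predicting 3 classes per pixel
torch.manual_seed(0)
model = nn.Conv2d(1, 3, kernel_size=3, padding=1)
batch = {"data": torch.randn(2, 1, 8, 8), "label": torch.randint(0, 3, (2, 8, 8))}
optimizers = {"default": torch.optim.SGD(model.parameters(), lr=0.1)}
criterions = {"ce": nn.CrossEntropyLoss()}
metrics = {"pixel_acc": lambda p, y: (p.argmax(dim=1) == y).float().mean()}
metric_vals, loss_vals, preds = closure_sketch(model, batch, optimizers, criterions, metrics)
```

The returned tuple mirrors the documented order: metric values, loss values, then a list of predictions.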
forward(x)
Feeds a tensor through the network.
- Parameters
x (torch.Tensor) – input tensor
- Returns
Prediction
- Return type
static prepare_batch(batch: dict, input_device, output_device)
Helper function to prepare network inputs and labels (converts them to the correct type and shape and pushes them to the correct devices).
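A conversion step of this kind can be sketched as follows. The function name and the key names `"data"` and `"label"` are hypothetical, and the sketch assumes NumPy inputs and integer label maps with a singleton channel axis that are meant for `torch.nn.CrossEntropyLoss`; delira's own method may differ.

```python
import numpy as np
import torch


def prepare_batch_sketch(batch, input_device, output_device):
    """Hypothetical helper: NumPy batch -> tensors of the right dtype on the right devices."""
    # Network input: float32 on the device that holds the model's input layer
    data = torch.from_numpy(batch["data"]).to(device=input_device, dtype=torch.float32)
    # Labels: int64 class indices, the target dtype torch.nn.CrossEntropyLoss expects
    label = torch.from_numpy(batch["label"]).to(device=output_device, dtype=torch.long)
    # Remove a singleton channel axis: (N, 1, H, W) -> (N, H, W), (N, 1, D, H, W) -> (N, D, H, W)
    if label.dim() == data.dim() and label.size(1) == 1:
        label = label.squeeze(1)
    return {"data": data, "label": label}


batch = {
    "data": np.random.rand(2, 1, 16, 16),                    # float64 image batch
    "label": np.random.randint(0, 3, size=(2, 1, 16, 16)),  # integer class map
}
prepared = prepare_batch_sketch(batch, "cpu", "cpu")
print(prepared["data"].dtype, prepared["label"].shape)  # torch.float32 torch.Size([2, 16, 16])
```

Passing separate input and output devices lets a model split across GPUs receive its inputs on the first device and compare predictions with labels on the last one.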
static weight_init(m)
Initializes weights with xavier_normal and biases with zeros.
- Parameters
m (torch.nn.Module) – module to initialize
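An initializer that takes a single module is typically applied recursively with `torch.nn.Module.apply`, which calls it on every submodule. The sketch below is hypothetical: which layer types delira's `weight_init` handles is not stated here, so it assumes 2d and 3d convolution and transpose-convolution layers.

```python
import torch.nn as nn

# Assumed set of layers to initialize (2d and 3d, regular and transpose convolutions)
CONV_TYPES = (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d)


def weight_init_sketch(m):
    """Xavier-normal weights and zero biases for convolution layers."""
    if isinstance(m, CONV_TYPES):
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)


net = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.ConvTranspose2d(8, 4, kernel_size=2, stride=2),
)
# apply() calls the function on every submodule and on net itself;
# the isinstance check skips the ReLU and the Sequential container
net.apply(weight_init_sketch)
```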
UNet3dPyTorch
class UNet3dPyTorch(num_classes, in_channels=3, depth=5, start_filts=64, up_mode='transpose', merge_mode='concat')
Bases: delira.models.abstract_network.AbstractPyTorchNetwork
The UNet3dPyTorch is a convolutional encoder-decoder neural network. Contextual spatial information about an input tensor (from the decoding, expansive pathway) is merged with information representing the localization of details (from the encoding, compressive pathway).
Notes
Differences to the original paper:
1. The network works on 3D data instead of 2D slices.
2. Padding is used in 3x3x3 convolutions to prevent loss of border voxels.
3. Merging outputs does not require cropping due to (2).
4. Residual connections can be used by specifying merge_mode='add'.
5. If non-parametric upsampling is used in the decoder pathway (specified by up_mode='upsample'), an additional 1x1x1 3d convolution follows the upsampling to reduce the channel dimensionality by a factor of 2. With up_mode='transpose', this channel halving is performed by the transpose convolution itself.
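Why padded 3x3x3 convolutions make cropping unnecessary follows from the standard convolution output-size formula applied per spatial axis, out = floor((n + 2p - k) / s) + 1. The helpers below are hypothetical and assume two stride-1 convolutions per resolution level, as in the original U-Net.

```python
def conv_out_shape(shape, kernel=3, padding=0, stride=1):
    """Spatial output shape of one convolution, applied independently to each axis."""
    return tuple((n + 2 * padding - kernel) // stride + 1 for n in shape)


def level_out_shape(shape, padding):
    # Two 3x3x3 convolutions per resolution level (assumed, as in the original U-Net)
    return conv_out_shape(conv_out_shape(shape, padding=padding), padding=padding)


volume = (32, 64, 64)  # (depth, height, width) of a hypothetical input volume
print(level_out_shape(volume, padding=0))  # (28, 60, 60): unpadded convs shrink each axis by 4
print(level_out_shape(volume, padding=1))  # (32, 64, 64): padding=1 preserves the shape
```

Without padding, each encoder feature map ends up larger than the decoder map it is merged with, which is why the original paper crops the skip connections before concatenation.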
References
https://arxiv.org/abs/1505.04597
_build_model(num_classes, in_channels=3, depth=5, start_filts=64)
Builds the actual model.
- Parameters
num_classes (int) – number of output classes
in_channels (int) – number of channels for the input tensor (default: 3)
depth (int) – number of MaxPools in the U-Net (default: 5)
start_filts (int) – number of convolutional filters for the first conv (affects all other conv-filter numbers too; default: 64)
Notes
The helper functions and classes are defined within this function because delira can save the source code along with the weights, so that the network can be fully recovered without a manually created network instance; these helper functions therefore have to be saved too.
_init_kwargs = {}
static closure(model, data_dict: dict, optimizers: dict, criterions={}, metrics={}, fold=0, **kwargs)
Closure method to perform a single backpropagation step.
- Parameters
model (ClassificationNetworkBasePyTorch) – trainable model
data_dict (dict) – dictionary containing the data
optimizers (dict) – dictionary of optimizers to optimize model’s parameters
criterions (dict) – dict holding the criterions to calculate errors (gradients from different criterions will be accumulated)
metrics (dict) – dict holding the metrics to calculate
fold (int) – current fold in cross-validation (default: 0)
**kwargs – additional keyword arguments
- Returns
dict – Metric values (with same keys as input dict metrics)
dict – Loss values (with same keys as input dict criterions)
list – Arbitrary number of predictions as torch.Tensor
- Raises
AssertionError – if optimizers or criterions are empty or the optimizers are not specified
forward(x)
Feeds a tensor through the network.
- Parameters
x (torch.Tensor) – input tensor
- Returns
Prediction
- Return type
static prepare_batch(batch: dict, input_device, output_device)
Helper function to prepare network inputs and labels (converts them to the correct type and shape and pushes them to the correct devices).
static weight_init(m)
Initializes weights with xavier_normal and biases with zeros.
- Parameters
m (torch.nn.Module) – module to initialize