Defining a variable-depth Fully Convolutional Network in TensorFlow

Time:01-13

I am currently working with fully convolutional networks (FCNs) and want to make my code more flexible. My current model is more or less hard-coded to have three blocks in each of the encoder and decoder paths. My goal is to define a list containing the number of filters per stage and have the model built automatically from it.

I set my model up with two custom layers (EncoderBlock and DecoderBlock). The class FullyConvolutionalNetwork (a subclass of tf.keras.models.Model) uses these custom layers to build the final model.

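from abc import ABC

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.layers import Conv1D, Conv1DTranspose, Flatten, MaxPool1D
from tensorflow.keras.models import Model


# define the layers needed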
class EncoderBlock(layers.Layer):
    def __init__(self,
                 num_conv: int,
                 kernel: int = None,
                 stride: int = None,
                 activation: str = "relu",
                 padding: str = "same",
                 name: str = "Encoder1D",
                 data_format: str = "channels_last",
                 *args, **kwargs):
        super(EncoderBlock, self).__init__(name=name, *args, **kwargs)
        if kernel is None:
            kernel = 3
        if stride is None:
            stride = 1

        self.conv = Conv1D(num_conv, kernel, strides=stride, activation=activation,
                           padding=padding, data_format=data_format)
        self.pool = MaxPool1D(2, strides=2, padding=padding, data_format=data_format)

    @tf.function
    def call(self, input_tensor, training=False):
        x = self.conv(input_tensor)
        return self.pool(x)


class DecoderBlock(layers.Layer):
    def __init__(self,
                 num_conv: int,
                 kernel: int = None,
                 stride: int = None,
                 activation: str = "relu",
                 padding: str = "same",
                 name: str = "Decoder1D",
                 data_format: str = "channels_last",
                 *args, **kwargs):
        super(DecoderBlock, self).__init__(name=name, *args, **kwargs)
        if kernel is None:
            kernel = 3
        if stride is None:
            stride = 1

        self.conv = Conv1D(num_conv, kernel, strides=stride, activation=activation,
                           padding=padding, data_format=data_format)
        self.pool = Conv1DTranspose(num_conv, 3, strides=2, padding=padding, data_format=data_format)

    @tf.function
    def call(self, input_tensor, training=False):
        x = self.conv(input_tensor)
        return self.pool(x)


# now define the model
class FullyConvolutionalNetwork(Model, ABC):
    def __init__(self,
                 num_filter: list,
                 name: str = "FullyConvolutionalNetwork",
                 activation: str = "relu",
                 alpha: float = None,
                 data_format: str = "channels_last",
                 *args, **kwargs):
        super(FullyConvolutionalNetwork, self).__init__(name=name, *args, **kwargs)
        # encoder layers
        self.encoder1 = EncoderBlock(num_filter[0], name="enc1", data_format=data_format,
                                     activation=activation)
        self.encoder2 = EncoderBlock(num_filter[1], name="enc2", data_format=data_format,
                                     activation=activation)
        self.encoder3 = EncoderBlock(num_filter[2], name="enc3", data_format=data_format,
                                     activation=activation)

        # decoder layers
        self.decoder1 = DecoderBlock(num_filter[2], name="dec1", data_format=data_format,
                                     activation=activation)
        self.decoder2 = DecoderBlock(num_filter[1], name="dec2", data_format=data_format,
                                     activation=activation)
        self.decoder3 = DecoderBlock(num_filter[0], name="dec3", data_format=data_format,
                                     activation=activation)

        # output section
        self.conv1x1 = Conv1D(1, 1, strides=1, padding='same', name="1x1", data_format=data_format)
        self.flatten = Flatten()  # remove channel dimension!

    @tf.function
    def call(self, input_tensor, training=False):
        x = self.encoder1(input_tensor)
        x = self.encoder2(x)
        x = self.encoder3(x)

        x = self.decoder1(x)
        x = self.decoder2(x)
        x = self.decoder3(x)

        x = self.conv1x1(x)
        return self.flatten(x)


my_model = FullyConvolutionalNetwork(num_filter=[8, 16, 32])
my_model.build(input_shape=[None, 512, 10])
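One check matters once the depth becomes variable. Each EncoderBlock halves the sequence length (MaxPool1D with strides=2 and padding="same" rounds up), and each DecoderBlock doubles it (Conv1DTranspose with strides=2 and padding="same"), while the stride-1 Conv1D layers keep it unchanged. The flattened output therefore matches the input length only when that length is divisible by 2**len(num_filter). A small pure-Python sketch (the helper name trace_lengths is hypothetical) that traces the lengths:

```python
def trace_lengths(length: int, depth: int) -> list[int]:
    """Sequence length after each encoder block, then after each decoder block.

    Stride-1 Conv1D layers with padding="same" keep the length, so only the
    pooling and transposed-convolution layers are modelled here.
    """
    lengths = []
    for _ in range(depth):
        # MaxPool1D(2, strides=2, padding="same"): ceil(length / 2)
        length = (length + 1) // 2
        lengths.append(length)
    for _ in range(depth):
        # Conv1DTranspose(..., strides=2, padding="same"): length * 2
        length *= 2
        lengths.append(length)
    return lengths


print(trace_lengths(512, 3))  # [256, 128, 64, 128, 256, 512]
print(trace_lengths(10, 2))   # [5, 3, 6, 12]: output no longer matches the input length
```

With the 512-sample input used here, any depth up to 9 round-trips cleanly, since 512 = 2**9.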

I now want to make the depth of my model variable. I already achieved a variable number of instance attributes by writing as many layers to self.__dict__ as I need in the model's constructor, as shown here:

        # define symmetric encoder and decoder blocks
        self.num_filter = num_filter  # kept so that call() knows the depth
        for i, f in enumerate(num_filter):
            self.__dict__[f"encoder{i}"] = EncoderBlock(f, activation=activation, name=f"enc{i}", data_format=data_format)
            self.__dict__[f"decoder{i}"] = DecoderBlock(f, activation=activation, name=f"dec{i}", data_format=data_format)

In theory, this should replace the lines self.encoder1 = ..., self.encoder2 = ... and so on. However, I now fail to build the model from my dict. My understanding is that all layers are stored in model.__dict__, so I should be able to chain my layers in a loop:

    @tf.function
    def call(self, input_tensor, training=False):
        # assumes __init__ stored self.num_filter = num_filter
        for count, layer in enumerate([self.__dict__[f"encoder{i}"] for i in range(len(self.num_filter))]):
            if count == 0:
                x = layer(input_tensor)
            else:
                x = layer(x)
        # and same for the decoder layers...

        # output section
        x = self.conv1x1(x)
        return self.flatten(x)
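The "same for the decoder layers" step has one subtlety. In the hard-coded model, decoder1 gets num_filter[2], whereas the loop above builds decoder{i} with num_filter[i], so the decoders have to be applied in descending index order to reproduce the original architecture. A framework-free sketch (the helpers call_order and apply_in_sequence are hypothetical, with plain callables standing in for the Keras blocks) that lists the attribute names in the order call() should use them and chains blocks without the count == 0 branch:

```python
from typing import Callable, Iterable


def call_order(num_filter: list[int]) -> list[tuple[str, int]]:
    """(attribute name, filters) pairs in the order call() should apply them.

    Assumes encoder{i} and decoder{i} were both built with num_filter[i].
    """
    encoders = [(f"encoder{i}", f) for i, f in enumerate(num_filter)]
    # decoders run deepest-first, mirroring the encoder path
    decoders = [(f"decoder{i}", num_filter[i]) for i in reversed(range(len(num_filter)))]
    return encoders + decoders


def apply_in_sequence(x, blocks: Iterable[Callable]):
    """Thread x through each block; starting from x avoids a count == 0 special case."""
    for block in blocks:
        x = block(x)
    return x


print(call_order([8, 16, 32]))
# [('encoder0', 8), ('encoder1', 16), ('encoder2', 32), ('decoder2', 32), ('decoder1', 16), ('decoder0', 8)]
print(apply_in_sequence(3, [lambda v: v + 1, lambda v: v * 2]))  # 8
```

Inside the model, the same pattern reads x = input_tensor followed by x = getattr(self, name)(x) for each name. Because the loop runs over a Python list, tf.function simply unrolls it while tracing.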

I double-checked, and all layers are available in the dict. However, they are not added to my model: model.summary() shows only the last two layers (the 1x1 convolution and flatten). The same happens when I call the layers from my dict manually:

    @tf.function
    def call(self, input_tensor, training=False):
        x = self.__dict__["encoder0"](input_tensor)
        x = self.__dict__["encoder1"](x)
        ...

Do you have any advice on where and why my attempt fails? Thanks in advance!
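The symptom can be reproduced without TensorFlow. Keras layers and models register sublayers from inside their __setattr__ override, and a direct write to self.__dict__ never calls __setattr__. The classes below are a hypothetical toy stand-in for that mechanism, not Keras' actual implementation:

```python
class Block:
    """Hypothetical stand-in for a Keras layer."""


class TrackingModel:
    """Toy model that, like Keras, registers blocks assigned via __setattr__."""

    def __init__(self):
        # object.__setattr__ so that creating the registry is not itself routed through our hook
        object.__setattr__(self, "tracked", [])

    def __setattr__(self, name, value):
        if isinstance(value, Block):
            self.tracked.append(name)
        object.__setattr__(self, name, value)


class Net(TrackingModel):
    def __init__(self, depth):
        super().__init__()
        for i in range(depth):
            self.__dict__[f"direct{i}"] = Block()  # skips __setattr__: not tracked
            setattr(self, f"encoder{i}", Block())   # goes through __setattr__: tracked
        self.conv1x1 = Block()                      # normal assignment: tracked


net = Net(2)
print(net.tracked)                # ['encoder0', 'encoder1', 'conv1x1']
print("direct0" in net.__dict__)  # True: the attribute exists but is unregistered
```

The registry also records blocks in assignment order, which is the kind of ordering model.summary() reflects.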

Answer:

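Using setattr(self, key, value) instead of writing everything manually to self.__dict__ did the trick! Keras layers and models override __setattr__ so that every layer assigned as an attribute is registered as a sublayer; writing into self.__dict__ directly bypasses that hook, which is why only conv1x1 and flatten showed up in model.summary(). You just have to take care of the ordering of your layers, as they seem to be listed in the order in which they were assigned in def __init__(self, ...).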
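A minimal end-to-end sketch of the setattr approach, assuming TensorFlow 2.x with tf.keras. To keep it self-contained, the question's EncoderBlock and DecoderBlock are replaced by equivalent tf.keras.Sequential stacks with the same Conv1D, MaxPool1D and Conv1DTranspose settings, and the class name VariableDepthFCN is hypothetical:

```python
import tensorflow as tf
from tensorflow.keras import layers


class VariableDepthFCN(tf.keras.Model):
    """Encoder/decoder depth is taken from len(num_filter)."""

    def __init__(self, num_filter, activation="relu", **kwargs):
        super().__init__(**kwargs)
        self.depth = len(num_filter)
        for i, f in enumerate(num_filter):
            # setattr runs through Keras' __setattr__, so each block is tracked
            setattr(self, f"encoder{i}", tf.keras.Sequential([
                layers.Conv1D(f, 3, padding="same", activation=activation),
                layers.MaxPool1D(2, strides=2, padding="same"),
            ]))
            setattr(self, f"decoder{i}", tf.keras.Sequential([
                layers.Conv1D(f, 3, padding="same", activation=activation),
                layers.Conv1DTranspose(f, 3, strides=2, padding="same"),
            ]))
        self.conv1x1 = layers.Conv1D(1, 1, padding="same")
        self.flatten = layers.Flatten()  # (batch, length, 1) -> (batch, length)

    def call(self, inputs, training=False):
        x = inputs
        for i in range(self.depth):
            x = getattr(self, f"encoder{i}")(x)
        # decoder{i} was built with num_filter[i], so run them deepest-first
        for i in reversed(range(self.depth)):
            x = getattr(self, f"decoder{i}")(x)
        return self.flatten(self.conv1x1(x))


model = VariableDepthFCN([8, 16, 32])
out = model(tf.zeros((2, 512, 10)))
print(out.shape)  # (2, 512)
model.summary()   # lists all six blocks plus conv1x1 and flatten
```

If building attribute names from strings feels clumsy, assigning a plain Python list of layers (for example self.encoders = [...]) is also tracked by Keras, and the list keeps the order that call() needs.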
