How to save custom attributes with a custom model in TensorFlow?

Time:01-30

GOAL

I'm trying to create a custom model in TensorFlow using model subclassing. My goal is to create a model with some custom attributes, train it, save it, and, after loading it, get back the values of those custom attributes together with the model.

I've searched online for a solution, but I found nothing about this problem.

ISSUE

I've created a test custom model class with a self.custom_att attribute, which is a list. I trained it on random data, saved it, and loaded it again. After loading, the attribute still exists on the model object, but it has been turned into a ListWrapper object and it is empty.

QUESTION

How can I store this attribute so that the values it had before saving are still there after loading?

CODE

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense
import numpy as np
from tensorflow.keras.models import load_model


class CustomModel(Model):

    def __init__(self):
        super(CustomModel, self).__init__()
        self.in_dense = Dense(10, activation='relu')
        self.dense = Dense(30, activation='relu')
        self.out = Dense(3, activation='softmax')
        self.custom_att = ['custom_att1', 'custom_att2'] # <- this attribute I want to store

    def call(self, inputs, training=None, mask=None):
        x = self.in_dense(inputs)
        x = self.dense(x)
        x = self.out(x)
        return x

    def get_config(self):
        base_config = super(CustomModel, self).get_config()
        return {**base_config, 'custom_att': self.custom_att}


X = np.random.random((1000, 5))
y = np.random.random((1000, 3))

model = CustomModel()
model.build((1, 5))
model.compile(optimizer='adam', loss='mse', metrics=['accuracy'])
model.summary()
history = model.fit(X, y, epochs=1, validation_split=0.1)
model.save('models/testModel.model')

del model

model = load_model('models/testModel.model', custom_objects={'CustomModel': CustomModel}) # <- here attribute becomes ListWrapper([])
print(model.custom_att)

ENVIRONMENT

  • Python 3.8.5
  • TensorFlow 2.3.0

ANSWER

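I don't think a plain Python list will survive loading your model. Keras wraps list attributes in a ListWrapper so it can track any layers or variables stored inside them, but plain strings are not saved as part of the model state, so the wrapper is restored empty. Replace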

self.custom_att = ['custom_att1', 'custom_att2']

with a string tf.Variable (this requires import tensorflow as tf at the top of the script):

self.custom_att = tf.Variable(['custom_att1', 'custom_att2'])

After saving and loading the model as in the question, you should see something like this:

print(model.custom_att.numpy())
# [b'custom_att1' b'custom_att2']
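To check the idea outside Keras, here is a minimal sketch that stores the same strings in a tf.Variable on a plain tf.Module and round-trips it through tf.saved_model.save and tf.saved_model.load. The class name AttributeHolder and the helper round_trip are hypothetical. The trainable=False argument is an addition not in the answer; it is intended to keep a non-numeric variable out of trainable_variables, which seems sensible for a model that an optimizer will also see.

```python
import tempfile

import tensorflow as tf


class AttributeHolder(tf.Module):
    """Hypothetical module that keeps string metadata in a tf.Variable."""

    def __init__(self):
        super().__init__()
        # A variable is tracked state, so its value is written to the SavedModel.
        # trainable=False keeps it out of self.trainable_variables.
        self.custom_att = tf.Variable(['custom_att1', 'custom_att2'], trainable=False)


def round_trip(obj):
    """Save obj as a SavedModel in a fresh temporary directory and load it back."""
    export_dir = tempfile.mkdtemp()
    tf.saved_model.save(obj, export_dir)
    return tf.saved_model.load(export_dir)


restored = round_trip(AttributeHolder())
# String tensors come back as bytes, so decode each element.
names = [s.decode('utf-8') for s in restored.custom_att.numpy()]
print(names)  # ['custom_att1', 'custom_att2']
```

The restored object exposes the variable under the same attribute name, which is the behavior the answer relies on for the Keras model.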

The values come back as bytes (hence the b prefix). To get a regular Python string, decode it:

print(model.custom_att.numpy()[0].decode("utf-8"))
# custom_att1
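The line above decodes only the first element. A small sketch for decoding all of them, assuming the array is what .numpy() returns for a string tensor (a NumPy object array of bytes); decode_all is a hypothetical helper name, and raw is a stand-in for model.custom_att.numpy():

```python
import numpy as np


def decode_all(values, encoding='utf-8'):
    """Hypothetical helper: turn an array of bytes into a list of str."""
    return [v.decode(encoding) for v in values]


# Stand-in for model.custom_att.numpy(): an object array holding bytes.
raw = np.array([b'custom_att1', b'custom_att2'], dtype=object)
print(decode_all(raw))  # ['custom_att1', 'custom_att2']
```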