I am going through this tutorial on how to customize the training loop.
The last example shows a GAN implemented with a custom training step, where only the `__init__`, `train_step`, and `compile` methods are defined:
```python
from tensorflow import keras


class GAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim):
        super(GAN, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super(GAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, real_images):
        if isinstance(real_images, tuple):
            real_images = real_images[0]
        ...
```
What happens if my model also has a custom `call()` method? Does `train_step()` override `call()`?
Aren't `call()` and `train_step()` both called by `fit()`, and what is the difference between the two?
Below is another piece of code "I" wrote, where I wonder which one `fit()` calls: `call()` or `train_step()`.
```python
import tensorflow as tf


class MyModel(tf.keras.Model):
    def __init__(self, vocab_size, embedding_dim, rnn_units):
        super().__init__()  # `self` must not be passed explicitly
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRU(rnn_units,
                                       return_sequences=True,
                                       return_state=True,
                                       reset_after=True)
        self.dense = tf.keras.layers.Dense(vocab_size)

    def call(self, inputs, states=None, return_state=False, training=False):
        x = inputs
        x = self.embedding(x, training=training)
        if states is None:
            states = self.gru.get_initial_state(x)
        x, states = self.gru(x, initial_state=states, training=training)
        x = self.dense(x, training=training)
        if return_state:
            return x, states
        else:
            return x

    @tf.function
    def train_step(self, inputs):
        # unpack the data
        inputs, labels = inputs
        with tf.GradientTape() as tape:
            predictions = self(inputs, training=True)  # forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compiled_loss(labels, predictions,
                                      regularization_losses=self.losses)
        # compute the gradients (use `self`, not a global `model` variable)
        grads = tape.gradient(loss, self.trainable_variables)
        # Update weights
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        # Update metrics (includes the metric that tracks the loss)
        self.compiled_metrics.update_state(labels, predictions)
        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}
```
Answer:
These are different concepts and are used like this:
`train_step` is called by `fit`. Basically, `fit` loops over the dataset and provides each batch to `train_step` (and then handles metrics, bookkeeping, etc., of course).
`call` is used when you, well, call the model. To be precise, writing `model(inputs)` or, in your case, `self(inputs)` uses the method `__call__`, but the `Model` class defines that method so that it in turn uses `call`.
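To make that delegation chain concrete, here is a toy, framework-free imitation of it. `ToyModel` and all of its methods are hypothetical and are not Keras code; real Keras adds graph compilation, callbacks, metrics and much more. The only point illustrated is the order of delegation: `fit` → `train_step` → `__call__` → `call`.

```python
# A toy, framework-free imitation of the dispatch chain described above.
# ToyModel is hypothetical; it only mimics the order in which Keras delegates.


class ToyModel:
    def __init__(self):
        self.log = []  # records which method ran, in order

    def __call__(self, inputs):
        # In Keras, Layer.__call__ does bookkeeping (building, masking, ...)
        # before delegating to the user-defined call().
        self.log.append("__call__")
        return self.call(inputs)

    def call(self, inputs):
        # Forward pass: here it just doubles each input value.
        self.log.append("call")
        return [2 * v for v in inputs]

    def train_step(self, batch):
        # One training step: a forward pass through self(...), then (in a
        # real model) the loss, the gradients and the weight update.
        self.log.append("train_step")
        predictions = self(batch)
        return {"n_predictions": len(predictions)}

    def fit(self, dataset, epochs=1):
        # fit() never calls call() directly: it loops over the batches
        # and hands each one to train_step().
        history = []
        for _ in range(epochs):
            for batch in dataset:
                history.append(self.train_step(batch))
        return history


model = ToyModel()
history = model.fit([[1, 2], [3, 4, 5]])
print(model.log)  # ['train_step', '__call__', 'call', 'train_step', '__call__', 'call']
```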
Those are the technical aspects. Intuitively:
`call` should define the forward pass of your model, i.e. how the input is transformed into the output.
`train_step` defines the logic of a single training step, usually with gradient descent. It will often make use of `call`, since a training step usually includes a forward pass of the model, whose output is needed to compute the loss and its gradients.
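A minimal TensorFlow sketch of that split, assuming TF 2.x with `tf.keras`. `TinyRegressor`, the `calls` counter dict, the hand-written MSE loss and the random data are all illustrative inventions. `call` holds only the forward pass; `train_step` reuses it through `self(x, training=True)` and adds the loss and the weight update. `run_eagerly=True` is set so the Python counters increase on every step rather than only while the step is traced into a graph.

```python
import numpy as np
import tensorflow as tf

# Hypothetical counters, only used to observe which methods fit() runs.
calls = {"train_step": 0, "call": 0}


class TinyRegressor(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(1)

    def call(self, inputs, training=False):
        # Forward pass: how inputs become outputs.
        calls["call"] += 1
        return self.dense(inputs)

    def train_step(self, data):
        # One optimisation step; it reuses call() through self(...).
        calls["train_step"] += 1
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = tf.reduce_mean(tf.square(y - y_pred))  # plain MSE
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return {"loss": loss}


x = np.random.rand(8, 3).astype("float32")
y = np.random.rand(8, 1).astype("float32")

model = TinyRegressor()
# run_eagerly=True: Python side effects run on every step, not just at trace time.
model.compile(optimizer="sgd", run_eagerly=True)
model.fit(x, y, batch_size=4, epochs=1, verbose=0)
print(calls)
```

With 8 samples and `batch_size=4`, `fit` should run `train_step` twice. `call` runs at least that often; depending on the Keras version it may also be traced once more while the model is being built.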
As for the GAN tutorial you linked, I would say it can actually be considered incomplete. It works without defining `call` because the custom `train_step` explicitly calls the `generator` and `discriminator` attributes (since these are ordinary Keras models, they can be called as usual). If you tried to call the GAN model directly, as in `gan(inputs)`, I would expect you to get an error (I did not test this). So you would always have to call `gan.generator(inputs)` to generate images, for example.
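A quick way to check that expectation, assuming TF 2.x with `tf.keras`. The `GAN` class below keeps only the tutorial's constructor, and the one-layer generator and discriminator are hypothetical stand-ins for the real ones. The exception is caught broadly because its exact type may differ between Keras versions (I believe the base `Model.call` raises `NotImplementedError` in TF 2.x, but treat that as an assumption).

```python
import tensorflow as tf


class GAN(tf.keras.Model):
    # Same constructor as in the tutorial; deliberately no call() method.
    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim


latent_dim = 3
# Hypothetical toy sub-models standing in for the tutorial's real ones.
generator = tf.keras.Sequential([tf.keras.Input(shape=(latent_dim,)),
                                 tf.keras.layers.Dense(4)])
discriminator = tf.keras.Sequential([tf.keras.Input(shape=(4,)),
                                     tf.keras.layers.Dense(1)])
gan = GAN(discriminator, generator, latent_dim)

z = tf.random.normal((2, latent_dim))
fake = gan.generator(z)  # works: the sub-model has its own call()
print(fake.shape)

error = None
try:
    gan(z)  # the GAN itself has no call() to delegate to
except Exception as e:  # exception type may vary across Keras versions
    error = e
print(type(error).__name__)
```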
Finally (this part may be a bit confusing), note that you can subclass `Model` to define a custom training step but then instantiate your subclass via the functional API (as in `model = MyCustomModel(inputs, outputs)`, where `MyCustomModel` is your subclass). In that case you can use `call` in the training step without ever defining it yourself, because the functional API takes care of that.
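A sketch of that pattern, similar to the one used in the Keras guide on customizing `fit()`, assuming TF 2.x with `tf.keras`. `CustomModel`, the single `Dense` layer, the MSE loss and the random data are illustrative. `CustomModel` overrides only `train_step`, yet `self(x, training=True)` works because constructing it from `inputs` and `outputs` gives it the functional API's `call`.

```python
import numpy as np
import tensorflow as tf


class CustomModel(tf.keras.Model):
    # Only train_step is overridden; call() is never written by hand.
    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            # self(...) works because the functional API supplies call().
            y_pred = self(x, training=True)
            loss = tf.reduce_mean(tf.square(y - y_pred))  # plain MSE
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return {"loss": loss}


# Build the subclass with the functional API: inputs -> layers -> outputs.
inputs = tf.keras.Input(shape=(3,))
outputs = tf.keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)

model.compile(optimizer="sgd")
x = np.random.rand(8, 3).astype("float32")
y = np.random.rand(8, 1).astype("float32")
history = model.fit(x, y, batch_size=4, epochs=2, verbose=0)

print(history.history["loss"])  # one loss value per epoch
print(model(x).shape)  # the inherited functional call() in action
```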
