I am new to PyTorch but have a lot of experience with TensorFlow.
I would like to modify the gradient of just a tiny piece of the graph: just the derivative of the activation function of a single layer. This can be done easily in TensorFlow using tf.custom_gradient, which allows you to supply a customized gradient for any function.
I would like to do the same thing in PyTorch and I know that you can modify the backward() method, but that requires you to rewrite the derivative for the whole network defined in the forward() method, when I would just like to modify the gradient of a tiny piece of the graph. Is there something like tf.custom_gradient() in PyTorch? Thanks!
CodePudding user response:
You can do this in two ways:
1. Modifying the backward() function:
As you already said in your question, PyTorch also allows you to provide a custom backward implementation. However, in contrast to what you wrote, you do not need to rewrite the backward() of the entire model, only the backward() of the specific layer you want to change.
Here's a simple and nice tutorial that shows how this can be done.
For example, here is a custom clip activation that, instead of killing the gradients outside the [0, 1] domain, simply passes the gradients through as-is:
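import torch

class MyClip(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Forward pass: ordinary clipping to [0, 1].
        return torch.clip(x, 0., 1.)

    @staticmethod
    def backward(ctx, grad):
        # Backward pass: pass the incoming gradient through unchanged.
        return grad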
Now you can use the MyClip layer wherever you like in your model (call it as MyClip.apply(x)), and you do not need to worry about the overall backward function.
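To make the effect concrete, here is a minimal sketch that compares the gradient of MyClip with the gradient of the built-in torch.clip on the same input. The ClipLayer wrapper module and the sample input values are illustrative assumptions, not part of the original answer; the wrapper only shows one way to drop the function into an nn.Module-based model.

```python
import torch
import torch.nn as nn


class MyClip(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Forward pass: ordinary clipping to [0, 1].
        return torch.clip(x, 0., 1.)

    @staticmethod
    def backward(ctx, grad):
        # Backward pass: pass the incoming gradient through unchanged.
        return grad


class ClipLayer(nn.Module):
    """Hypothetical wrapper so the custom function can be used like a layer."""

    def forward(self, x):
        # Custom autograd Functions are invoked through .apply, not called directly.
        return MyClip.apply(x)


# Gradient through the custom layer.
x = torch.tensor([-1.0, 0.5, 2.0], requires_grad=True)
ClipLayer()(x).sum().backward()
custom_grad = x.grad.clone()

# Gradient through the built-in clip, for comparison.
x_ref = torch.tensor([-1.0, 0.5, 2.0], requires_grad=True)
torch.clip(x_ref, 0., 1.).sum().backward()
builtin_grad = x_ref.grad.clone()

print(custom_grad)   # gradient is passed through everywhere
print(builtin_grad)  # gradient is zeroed where the input was clipped
```

Only the backward of this one function is overridden; every other operation in the model keeps its regular autograd behavior.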
2. Using a backward hook
PyTorch allows you to attach hooks to different layers (i.e., sub-nn.Modules) of your network. You can register a hook on your layer with register_full_backward_hook. That hook function can modify the gradients:
The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of grad_input in subsequent computations.
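As a rough sketch of the hook approach, the snippet below reproduces the same pass-through behavior on a built-in clipping module. The choice of nn.Hardtanh(0., 1.) as the clip layer, the hook name pass_through_hook, and the sample input are assumptions made for illustration. Returning grad_output in place of grad_input only makes sense here because the activation is elementwise, so both have the same shape.

```python
import torch
import torch.nn as nn


def pass_through_hook(module, grad_input, grad_output):
    # Hypothetical hook: ignore the layer's own derivative and forward the
    # gradient w.r.t. the output as the gradient w.r.t. the input.
    # The arguments are not modified; a new tuple is returned instead.
    return tuple(g.clone() for g in grad_output)


clip = nn.Hardtanh(0., 1.)  # built-in module that clips to [0, 1]
handle = clip.register_full_backward_hook(pass_through_hook)

# Backward pass with the hook attached.
x = torch.tensor([-1.0, 0.5, 2.0], requires_grad=True)
clip(x).sum().backward()
hooked_grad = x.grad.clone()

# Remove the hook and repeat to see the layer's regular gradient.
handle.remove()
x.grad = None
clip(x).sum().backward()
plain_grad = x.grad.clone()

print(hooked_grad)  # gradient passed through everywhere
print(plain_grad)   # gradient zeroed outside [0, 1]
```

The handle returned by register_full_backward_hook can be used to detach the hook again, which is convenient if the modified gradient is only wanted during part of training.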
