I have a custom forward implementation for a PyTorch loss. Training works well, and I've checked that loss.grad_fn is not None.
I'm trying to understand two things:
- How can this function be differentiable when there is an if-else statement on the path from input to output?
- Does the path from gt (ground-truth input) to loss (output) need to be differentiable, or only the path from pred (prediction input)?
Here is the source code:
```python
import torch
import torch.nn as nn


class FocalLoss(nn.Module):
    def __init__(self):
        super(FocalLoss, self).__init__()

    def forward(self, pred, gt):
        # Masks for positive (gt == 1) and negative (gt < 1) locations
        pos_inds = gt.eq(1).float()
        neg_inds = gt.lt(1).float()
        # Negatives whose gt is close to 1 get a small weight (1 - gt)^4
        neg_weights = torch.pow(1 - gt, 4)

        pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
        neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds

        num_pos = pos_inds.float().sum()
        pos_loss_s = pos_loss.sum()
        neg_loss_s = neg_loss.sum()

        if num_pos == 0:
            loss = -neg_loss_s
        else:
            loss = -(pos_loss_s + neg_loss_s) / num_pos
        return loss
```
Answer:
The if statement is not part of the computational graph. It is part of the code that builds the graph dynamically (i.e. the forward function), but it is not itself a node in that graph. The principle to follow is to ask yourself whether you can backtrack from the output to the leaves of the graph (tensors that have no parents in the graph, i.e. inputs and parameters) by following the grad_fn callbacks of each node, backpropagating through the graph. You can only do so if each of the operators is differentiable: in programming terms, each one implements a backward operation (exposed on the resulting tensor as grad_fn).
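To make the "backtrack via grad_fn" idea concrete, here is a minimal sketch on a toy tensor rather than on the focal loss itself. The helper node_names is a hypothetical name; it only relies on the real grad_fn and next_functions attributes. The walk lists the backward nodes recorded by the branch that actually ran, and none of them corresponds to the if statement.

```python
import torch


def node_names(fn):
    """Walk backward from a grad_fn to the leaves via next_functions (depth-first)."""
    if fn is None:  # inputs that don't require grad show up as None
        return []
    names = [type(fn).__name__]
    for next_fn, _ in fn.next_functions:
        names.extend(node_names(next_fn))
    return names


x = torch.tensor([0.5, 2.0], requires_grad=True)  # leaf tensor
s = x.sum()
if s > 1:  # plain Python control flow, evaluated once during the forward pass
    y = s * 3
else:
    y = -s

names = node_names(y.grad_fn)
print(names)  # only the ops of the branch that ran; AccumulateGrad marks the leaf x

y.backward()
print(x.grad)  # tensor([3., 3.])
```

Only the multiplication and the sum show up; the comparison s > 1 was consumed by Python and left no node behind.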
In your example, depending on whether num_pos is equal to 0 or not, the resulting loss tensor depends either on neg_loss_s alone or on both pos_loss_s and neg_loss_s. In either case, the resulting loss tensor remains attached to the input pred:
- one way, through the neg_loss_s node;
- the other way, through the pos_loss_s and neg_loss_s nodes.

In your setup, either way, the operation is differentiable.
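A quick way to check this is to run both branches and confirm that pred receives a gradient each time. The sketch below assumes pred already holds probabilities in (0, 1), so the logs stay finite, and rewrites forward as a plain function, focal_loss (a hypothetical name), with the + restored in the num_pos > 0 branch.

```python
import torch


def focal_loss(pred, gt):
    """Same math as FocalLoss.forward from the question, as a plain function."""
    pos_inds = gt.eq(1).float()
    neg_inds = gt.lt(1).float()
    neg_weights = torch.pow(1 - gt, 4)
    pos_loss_s = (torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds).sum()
    neg_loss_s = (torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds).sum()
    num_pos = pos_inds.sum()
    if num_pos == 0:
        return -neg_loss_s                       # graph: pred -> neg_loss_s -> loss
    return -(pos_loss_s + neg_loss_s) / num_pos  # graph: pred -> both sums -> loss


pred = torch.tensor([0.2, 0.7, 0.9], requires_grad=True)
results = []
for gt in (torch.tensor([0.0, 1.0, 0.5]),   # one positive -> num_pos > 0 path
           torch.tensor([0.0, 0.3, 0.5])):  # no positives -> num_pos == 0 path
    pred.grad = None                         # reset between runs
    loss = focal_loss(pred, gt)
    loss.backward()
    results.append((type(loss.grad_fn).__name__, pred.grad.clone()))
    print(results[-1])
```

The last recorded node differs between the two runs (a division in one, a negation in the other), but in both cases the backward pass reaches pred.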
If gt is a ground-truth tensor, it doesn't require a gradient, so the path from it to the final loss doesn't need to be differentiable. That is the case in your example, where pos_inds and neg_inds are not differentiable because they come from comparison operators (eq and lt) that produce boolean masks.
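The sketch below, on a made-up gt tensor, shows why the masks drop out of the graph: comparison ops such as eq and lt return boolean tensors with no grad_fn, even when their input requires grad. It sets requires_grad=True on gt purely for illustration; in a normal training loop the ground truth would not require grad. Note that neg_weights, unlike the masks, is built from ordinary arithmetic and would be tracked with respect to gt if gt did require grad.

```python
import torch

# Illustration only: a ground-truth tensor normally does not require grad
gt = torch.tensor([0.0, 1.0, 0.5], requires_grad=True)

pos_inds = gt.eq(1).float()          # comparison -> bool -> float: no graph attached
neg_inds = gt.lt(1).float()
neg_weights = torch.pow(1 - gt, 4)   # plain arithmetic: still tracked

print(pos_inds.requires_grad, pos_inds.grad_fn)  # False None
print(neg_inds.requires_grad, neg_inds.grad_fn)  # False None
print(neg_weights.requires_grad)                 # True

# Backpropagating from a mask alone fails, because nothing was recorded for it
backward_failed = False
try:
    pos_inds.sum().backward()
except RuntimeError as err:
    backward_failed = True
    print("RuntimeError:", err)
```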
Answer:
PyTorch does not compute gradients of the loss function as a whole. It records the sequence of standard mathematical operations performed during the forward pass, such as log, exponentiation, multiplication and addition, and when backward() is called it computes the gradients by applying the chain rule through those recorded operations. Thus, the presence of if-else conditions doesn't matter to PyTorch, provided you use only standard (differentiable) PyTorch operations to compute your loss.
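A minimal sketch of the same idea on a hypothetical piecewise function: the Python if picks which operations run, autograd records only those, and the gradient is the derivative of whichever branch was taken. The second part illustrates the proviso: leaving PyTorch ops (here via .item() and a fresh torch.tensor) cuts the result off from the recorded graph.

```python
import torch


def piecewise(x):
    # Python picks the branch; autograd only records the ops inside it
    if x > 0:
        return x ** 2   # d/dx = 2x
    return -3 * x       # d/dx = -3


grads = {}
for value in (2.0, -1.0):
    x = torch.tensor(value, requires_grad=True)
    y = piecewise(x)
    y.backward()
    grads[value] = x.grad.item()
print(grads)  # {2.0: 4.0, -1.0: -3.0}

# Leaving PyTorch ops breaks the graph: the result is a new, untracked tensor
x = torch.tensor(2.0, requires_grad=True)
detached = torch.tensor(x.item() ** 2)
print(detached.requires_grad, detached.grad_fn)  # False None
```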
