Is there a differentiable alternative to K.cast?

Time:01-15

For a custom Keras loss function, I need to create a float tensor from a bool tensor. Unfortunately, K.cast() is not differentiable and therefore can't be used. Is there an alternative way to do this that is differentiable?

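from tensorflow.keras import backend as K
# y_pred comes from the loss function's arguments; tau is a scalar threshold
less_than_tau = y_pred < tau
less_than_tau = K.cast(less_than_tau, 'float32')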

Answer:

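Dr. Snoopy is right. The culprit is not K.cast itself but the comparison y_pred < tau: it is a step function, so its gradient is zero everywhere except at tau, where it is undefined. No cast can turn that into a useful gradient.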

The usual way to handle this in deep learning is to replace the hard function with a "soft", differentiable approximation, for example softmax instead of a hard (arg)max.
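To see why the hard threshold gives the optimizer nothing to work with, here is a minimal pure-Python sketch (no Keras involved) comparing finite-difference slopes of the step function and a sigmoid relaxation of it. The names hard_less_than, soft_less_than, numeric_grad and the temperature parameter are made up for illustration.

```python
import math

def hard_less_than(y, tau):
    """Step function: 1.0 if y < tau else 0.0, i.e. what `y_pred < tau` plus a cast computes."""
    return 1.0 if y < tau else 0.0

def soft_less_than(y, tau, temperature=0.1):
    """Sigmoid relaxation of hard_less_than.

    Close to 1.0 well below tau, close to 0.0 well above it, exactly 0.5 at tau.
    `temperature` (a hypothetical knob) sets the sharpness: smaller means closer to the step.
    """
    z = (tau - y) / temperature
    # Numerically stable logistic: avoids exp() overflow for large |z|.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def numeric_grad(f, y, eps=1e-4):
    """Central finite-difference estimate of df/dy."""
    return (f(y + eps) - f(y - eps)) / (2 * eps)

tau = 0.5
for y in (0.2, 0.45, 0.8):
    g_hard = numeric_grad(lambda v: hard_less_than(v, tau), y)
    g_soft = numeric_grad(lambda v: soft_less_than(v, tau), y)
    print(f"y={y}: hard grad={g_hard:.4f}, soft grad={g_soft:.4f}")
```

The hard version's slope is 0 at every sampled point, so backpropagation receives no signal; the sigmoid's slope is nonzero everywhere and largest near tau. Lowering the temperature brings the relaxation closer to the step but also shrinks the gradient away from tau, so it is a trade-off worth tuning.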

In your case, where the loss depends on whether y_pred is below tau, you would do something like this:

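from tensorflow.keras import backend as K
# Soft version of (y_pred < tau): near 1 well below tau, near 0 well above, 0.5 at tau
switch = K.sigmoid(tau - y_pred)
loss = switch * true_case + (1. - switch) * false_case  # true_case/false_case: your loss for each branch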
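Put together as a Keras loss, the idea might look like the sketch below. It assumes TensorFlow 2.x; make_soft_threshold_loss, temperature, and the two placeholder branch losses (squared error below tau, absolute error above it) are hypothetical stand-ins for whatever your loss actually needs. The second half uses tf.GradientTape to check that the soft version yields a gradient with respect to y_pred, while the hard comparison plus cast yields None.

```python
import tensorflow as tf

def make_soft_threshold_loss(tau, temperature=0.1):
    """Hypothetical factory returning a Keras-compatible loss(y_true, y_pred)."""
    def loss(y_true, y_pred):
        # Soft version of (y_pred < tau): ~1 well below tau, ~0 well above it.
        switch = tf.sigmoid((tau - y_pred) / temperature)
        # Placeholder branch losses; replace with your own.
        true_case = tf.square(y_true - y_pred)   # used where y_pred < tau
        false_case = tf.abs(y_true - y_pred)     # used where y_pred >= tau
        per_element = switch * true_case + (1.0 - switch) * false_case
        return tf.reduce_mean(per_element, axis=-1)
    return loss

y_true = tf.constant([[0.0, 1.0, 0.3]])
y_pred = tf.Variable([[0.2, 0.9, 0.6]])
tau = 0.5

# Soft threshold: gradient flows back to y_pred.
loss_fn = make_soft_threshold_loss(tau)
with tf.GradientTape() as tape:
    soft_loss = tf.reduce_mean(loss_fn(y_true, y_pred))
soft_grad = tape.gradient(soft_loss, y_pred)

# Hard threshold for comparison: the comparison op has no gradient.
with tf.GradientTape() as tape:
    hard_loss = tf.reduce_mean(tf.cast(y_pred < tau, tf.float32))
hard_grad = tape.gradient(hard_loss, y_pred)

print("soft grad:", soft_grad.numpy())
print("hard grad:", hard_grad)  # None
```

You would then pass the returned function to model.compile(loss=make_soft_threshold_loss(tau=0.5)), with tau and temperature chosen for your problem.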