I've come across many code fragments like the following for choosing an action that mix torch.no_grad and detach (where actor is some actor and SomeDistribution is your preferred distribution), and I'm wondering whether they make sense:
def f():
    with torch.no_grad():
        x = actor(observation)
    dist = SomeDistribution(x)
    sample = dist.sample()
    return sample.detach()
Isn't the detach in the return statement unnecessary, since x already has requires_grad set to False, so all computations using x should already be detached from the graph? Or do the computations after the torch.no_grad block somehow end up on the graph again, so that we need to detach them at the end (in which case it seems to me that no_grad would be unnecessary)?
Also, if I'm right, I suppose instead of omitting detach one could also omit torch.no_grad, and end up with the same functionality, but worse performance, so torch.no_grad is to be preferred?
Answer:
It may be redundant, but that depends on the internals of actor and SomeDistribution. In general, I can think of three cases where detach would be necessary in this code. Since you've already observed that x has requires_grad set to False, cases 2 and 3 don't apply to your specific situation.
1. If SomeDistribution has internal parameters (leaf tensors with requires_grad=True), then dist.sample() may produce a computation graph connecting sample to those parameters. Without detaching, that graph, including those parameters, would be kept in memory unnecessarily after returning. (The built-in torch.distributions classes generally implement sample() without gradient tracking, unlike rsample(), so this mainly concerns custom distributions.)
2. The default behavior inside a torch.no_grad context is that the results of tensor operations have requires_grad set to False. However, if actor(observation) for some reason explicitly sets requires_grad of its return value to True before returning, then a computation graph may be created that connects x to sample. Without detaching, that graph, including x, would be kept in memory unnecessarily after returning.
3. This one seems even more unlikely, but if actor(observation) simply returns a reference to observation, and observation.requires_grad is True, then a computation graph all the way from observation to sample may be constructed during dist.sample().
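The three cases can be reproduced with a minimal sketch. Everything here is hypothetical and chosen only for illustration: f_without_detach mirrors the question's f minus the final detach, LearnableShiftDist is a made-up custom distribution for case 1, and because the built-in Normal.sample() does not track gradients, cases 2 and 3 use a made-up DifferentiableSampleNormal whose sample() calls rsample().

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)


def f_without_detach(actor, make_dist, observation):
    # Same structure as the question's f, but without the final detach().
    with torch.no_grad():
        x = actor(observation)
    dist = make_dist(x)
    return dist.sample()


class LearnableShiftDist:
    """Hypothetical custom distribution for case 1: it owns a leaf tensor with
    requires_grad=True and implements sample() with differentiable ops."""

    def __init__(self, loc):
        self.loc = loc
        self.shift = torch.zeros((), requires_grad=True)

    def sample(self):
        return self.loc + self.shift + torch.randn_like(self.loc)


class DifferentiableSampleNormal(Normal):
    """Hypothetical Normal whose sample() is the reparameterized rsample(),
    so gradients can flow from the sample back to loc."""

    def sample(self, sample_shape=torch.Size()):
        return self.rsample(sample_shape)


def make_normal(x):
    return DifferentiableSampleNormal(x, torch.ones_like(x))


obs = torch.randn(3)
obs_requiring_grad = torch.randn(3, requires_grad=True)

# Case 1: ordinary actor, but the distribution has its own parameter.
s1 = f_without_detach(lambda o: 2.0 * o, LearnableShiftDist, obs)
# Case 2: the actor re-enables requires_grad on its output inside no_grad.
s2 = f_without_detach(lambda o: (2.0 * o).requires_grad_(), make_normal, obs)
# Case 3: the actor returns the observation itself, which requires grad.
s3 = f_without_detach(lambda o: o, make_normal, obs_requiring_grad)

for name, s in [("case 1", s1), ("case 2", s2), ("case 3", s3)]:
    # Each sample is attached to a graph until detach() is called.
    print(name, s.grad_fn is not None, s.detach().grad_fn is None)
```

Each line prints True True: the sample carries a graph, and detach() cuts it.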
As for the suggestion of removing the no_grad context and keeping only detach: without no_grad, the forward pass through actor builds a computation graph connecting the actor's parameters (and observation, if it requires gradients) to x, and this graph may extend to sample (together with any parameters of the distribution) if sampling is differentiable. The graph would be freed once f returns, since the detached sample no longer references it, but creating it takes time and memory, so there may be a performance penalty.
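The difference is visible on x itself. This sketch assumes a hypothetical nn.Linear actor, whose weight and bias require gradients by default; both variants compute the same values, but only the one without no_grad records a graph.

```python
import torch
from torch import nn
from torch.distributions import Normal

torch.manual_seed(0)

actor = nn.Linear(4, 2)  # hypothetical actor; weight and bias require grad
observation = torch.randn(1, 4)

# Variant A: forward pass inside no_grad, so no graph is recorded.
with torch.no_grad():
    x_a = actor(observation)

# Variant B: no no_grad, so autograd records a graph back to actor's parameters.
x_b = actor(observation)

print(x_a.grad_fn is None)                # True
print(x_b.grad_fn is not None)            # True
print(torch.allclose(x_a, x_b.detach()))  # True: same values either way

# Variant B still returns a graph-free sample thanks to detach(), but the
# graph for x_b was built (and held in memory) in the meantime.
sample_b = Normal(x_b, torch.ones_like(x_b)).sample().detach()
```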
In conclusion, it's safer to use both no_grad and detach, though whether either is actually necessary depends on the details of the distribution and the actor.
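A quick way to tell whether either call matters for a particular actor and distribution is to inspect requires_grad on the intermediate tensors. The sketch below assumes a hypothetical nn.Linear actor and a built-in Normal whose scale comes from a learnable log_std parameter (a common pattern, not something taken from the question). Because the built-in sample() runs without gradient tracking, the sample comes out graph-free even though the distribution itself is connected to a parameter.

```python
import torch
from torch import nn
from torch.distributions import Normal

torch.manual_seed(0)

# Hypothetical stand-ins for the question's actor and observation.
actor = nn.Linear(4, 2)
observation = torch.randn(1, 4)
# Hypothetical learnable log standard deviation.
log_std = nn.Parameter(torch.zeros(2))

with torch.no_grad():
    x = actor(observation)

dist = Normal(x, log_std.exp())  # scale is connected to log_std
sample = dist.sample()           # built-in sample() does not track gradients

print(x.requires_grad)           # False
print(dist.scale.requires_grad)  # True: the distribution holds a graph
print(sample.requires_grad)      # False: detach() would be a no-op here
```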
