Detach Tensor From Computation Graph

PyTorch
Author

Imad Dabbura

Published

October 4, 2022

We can think of a tensor that implements automatic differentiation as a regular tensor that has, among other attributes, the following attributes that capture its history: requires_grad, grad_fn, and grad.
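
For instance, here is a minimal sketch of what these attributes look like before and after a backward pass:

import torch

x = torch.tensor([[1., 2.]], requires_grad=True)
y = (x * 3).sum()

x.requires_grad     #=> True (autograd tracks operations on x)
y.grad_fn           #=> <SumBackward0 at 0x...> (records how y was created)
x.grad              #=> None (no gradient accumulated yet)

y.backward()
x.grad              #=> tensor([[3., 3.]]) (gradient of y w.r.t. x)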

Therefore, to detach a tensor from the computation graph, or when we don’t want to track temporary computations done on a tensor (such as during inference), we can do any of the following:

a = torch.tensor([[1., 2.]], requires_grad=True)
a  #=> tensor([[1., 2.]], requires_grad=True)

# The following will record the computation
b = a + 1
b.grad_fn  #=> <AddBackward0 at 0x7f81abadc8b0>

# All the following forms allow us to avoid recording the computation on the tensor
b = a.data + 1
b.grad_fn           #=> None
b.requires_grad     #=> False

b = a.detach() + 1
b.grad_fn           #=> None
b.requires_grad     #=> False

with torch.no_grad():
    b = a + 1
b.grad_fn           #=> None
b.requires_grad     #=> False

# Assigning to .data is useful for updates/initialization because it keeps the requires_grad attribute
a.data = torch.randint(10, size=(1, 2)).float()
a.grad_fn           #=> None
a.requires_grad     #=> True
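
This is handy, for example, in a manual optimization step: we can apply the update in place without autograd recording it, while the parameter keeps requiring gradients for the next iteration. The following is a minimal sketch with a made-up linear model (the names w, x, y_true, and lr are only for illustration):

import torch

# Hypothetical tiny linear model: one weight matrix, random data
w = torch.randn(2, 1, requires_grad=True)
x = torch.randn(5, 2)
y_true = torch.randn(5, 1)

loss = ((x @ w - y_true) ** 2).mean()
loss.backward()

lr = 0.1
# The update goes through .data, so it is not recorded by autograd,
# but w still has requires_grad=True afterwards
w.data -= lr * w.grad.data

# An equivalent (and nowadays more common) form:
# with torch.no_grad():
#     w -= lr * w.grad

w.grad.zero_()      # clear the accumulated gradient before the next step
w.requires_grad     #=> True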