PyTorch 中的“with torch no_grad”有什么作用?

“with ”torch.no_grad()的使用就像一个循环,其中循环内的每个张量都将requires_grad设置为False。这意味着当前与当前计算图相连的任何具有梯度的张量现在都与当前图分离。我们不再能够计算关于这个张量的梯度。

张量从当前图中分离,直到它在循环内。一旦它离开循环,如果张量是用梯度定义的,它就会再次附加到当前图。

让我们举几个例子来更好地理解它是如何工作的。

示例 1

在这个例子中,我们用requires_grad = true创建了一个张量 x 。接下来,我们定义这个张量 x 的函数 y 并将该函数放入 with循环中。现在 x 在循环内,所以它的requires_grad被设置为Falsetorch.no_grad()

在循环中,无法计算 y 相对于 x 的梯度。所以,y.requires_grad返回False

# import torch library
import torch

# define a torch tensor
x = torch.tensor(2., requires_grad = True)
print("x:", x)

# define a function y
with torch.no_grad():
   y = x ** 2
print("y:", y)

# check gradient for Y
print("y.requires_grad:", y.requires_grad)
输出结果
x: tensor(2., requires_grad=True)
y: tensor(4.)
y.requires_grad: False

示例 2

在这个例子中,我们在循环之外定义了函数z。所以,z.requires_grad返回True

# import torch library
import torch

# define three tensors
x = torch.tensor(2., requires_grad = False)
w = torch.tensor(3., requires_grad = True)
b = torch.tensor(1., requires_grad = True)

print("x:", x)
print("w:", w)
print("b:", b)

# define a function y
y = w * x + b
print("y:", y)

# define a function z
with torch.no_grad():
   z = w * x + b

print("z:", z)

# check if requires grad is true or not
print("y.requires_grad:", y.requires_grad)
print("z.requires_grad:", z.requires_grad)
输出结果
x: tensor(2.)
w: tensor(3., requires_grad=True)
b: tensor(1., requires_grad=True)
y: tensor(7., grad_fn=<AddBackward0>)
z: tensor(7.)
y.requires_grad: True
z.requires_grad: False