Tensor.detach() 在 PyTorch 中做什么?

Tensor.detach()用于从当前计算图中分离张量。它返回一个不需要梯度的新张量。

  • 当我们不需要为梯度计算跟踪张量时,我们将张量从当前计算图中分离。

  • 当我们需要将张量从 GPU 移动到 CPU 时,我们还需要分离张量。

语法

Tensor.detach()

它返回一个没有requires_grad = True的新张量。将不再计算关于该张量的梯度。

步骤

  • 导入火炬库。确保您已经安装了它。

import torch

  • 使用requires_grad = True创建 PyTorch 张量并打印张量。

x = torch.tensor(2.0, requires_grad = True)
print("x:", x)

  • 计算并可选择将此值分配给新变量。Tensor.detach()

x_detach = x.detach()

  • 在 之后打印张量。detach()操作被执行。

print("带有分离的张量:", x_detach)

示例 1

# import torch library
import torch

# create a tensor with requires_gradient=true
x = torch.tensor(2.0, requires_grad = True)

# print the tensor
print("Tensor:", x)

#tensor.detachoperation
x_detach = x.detach()
print("带有分离的张量:", x_detach)
输出结果
Tensor: tensor(2., requires_grad=True)
带有分离的张量: tensor(2.)

请注意,在上面的输出中,分离后的张量没有requires_grad = True

示例 2

# import torch library
import torch

# define a tensor with requires_grad=true
x = torch.rand(3, requires_grad = True)
print("x:", x)

# apply above tensor to use detach()
y = 3 + x
z = 3 * x.detach()

print("y:", y)
print("z:", z)
输出结果
x: tensor([0.5656, 0.8402, 0.6661], requires_grad=True)
y: tensor([3.5656, 3.8402, 3.6661], grad_fn=<AddBackward0>)
z: tensor([1.6968, 2.5207, 1.9984])