如何在 PyTorch 中执行展开操作?

Tensor.expand()属性用于执行展开操作。它沿着单例维度将张量扩展到新维度。

  • 扩展张量只会创建原始张量的新视图;它不会复制原始张量。

  • 如果您将特定维度设置为 -1,则张量将不会沿此维度展开。

  • 例如,如果我们有一个大小为 (3,1) 的张量,我们可以沿着大小为 1 的维度扩展这个张量。

步骤

要扩展张量,可以按照以下步骤操作 -

  • 导入火炬库。确保您已经安装了它。

import torch

  • 将至少具有一维的张量定义为单例。

t = torch.tensor([[1],[2],[3]])

  • 沿单例维度展开张量。沿非单一维度展开将引发运行时错误(参见示例 3)。

t_exp = t.expand(3,2)

  • 显示扩展的张量。

print("Tensor after expand:\n", t_exp)

示例 1

以下 Python 程序展示了如何将大小为 (3,1) 的张量扩展为大小为 (3,2) 的张量。它沿维度大小 1 扩展张量。大小为 3 的另一个维度保持不变。

# import required libraries
import torch

# create a tensor
t = torch.tensor([[1],[2],[3]])

# display the tensor
print("Tensor:\n", t)
print("Size of Tensor:\n", t.size())

# expand the tensor
exp = t.expand(3,2)
print("Tensor after expansion:\n", exp)
输出结果
Tensor:
 tensor([[1],
    [2],
    [3]])
Size of Tensor:
 torch.Size([3, 1])
Tensor after expansion:
 tensor([[1, 1],
    [2, 2],
    [3, 3]])

示例 2

以下 Python 程序将大小为 (1,3) 的张量扩展为大小为 (3,3) 的张量。它沿维度大小为 1 扩展张量。

# import required libraries
import torch

# create a tensor
t = torch.tensor([[1,2,3]])

# display the tensor
print("Tensor:\n", t)

# size of tensor is [1,3]
print("Size of Tensor:\n", t.size())

# expand the tensor
expandedTensor = t.expand(3,-1)

print("Expanded Tensor:\n", expandedTensor)
print("Size of expanded tensor:\n", expandedTensor.size())
输出结果
Tensor:
 tensor([[1, 2, 3]])
Size of Tensor:
 torch.Size([1, 3])
Expanded Tensor:
 tensor([[1, 2, 3],
    [1, 2, 3],
    [1, 2, 3]])
Size of expanded tensor:
 torch.Size([3, 3])

示例 3

在下面的 Python 程序中,我们尝试沿非单一维度扩展张量,因此它引发了运行时错误。

# import required libraries
import torch

# create a tensor
t = torch.tensor([[1,2,3]])

# display the tensor
print("Tensor:\n", t)

# size of tensor is [1,3]
print("Size of Tensor:\n", t.size())
t.expand(3,4)
输出结果
Tensor:
 tensor([[1, 2, 3]])
Size of Tensor:
 torch.Size([1, 3])


RuntimeError: The expanded size of the tensor (4) must match the existing size (3) at non-singleton dimension 1. Target sizes: [3, 4]. Tensor sizes: [1, 3]