如何在 PyTorch 中执行置换操作?

torch.permute()方法用于对 PyTorch 张量执行置换操作。它返回维度排列后的输入张量的视图。它不会复制原始张量。

例如,维度为 [2, 3] 的张量可以置换为 [3, 2]。我们还可以使用 置换具有新维度的张量。Tensor.permute()

语法

torch.permute(input,dims)

参数

  • 输入– PyTorch 张量。

  • dims – 所需维度的元组。

步骤

  • 导入火炬库。确保您已经安装了它。

import torch

  • 创建一个 PyTorch 张量并打印张量和张量的大小。

t = torch.tensor([[1,2],[3,4],[5,6]])
print("Tensor:\n", t)
print("张量的大小:", t.size()) # size 3x2

  • 计算并将值分配给变量。它不会改变原始张量inputtorch.permute(input, dims)

t1 = torch.permute(t, (1,0))

  • 在置换操作后打印结果张量及其大小。

print("Tensor after Permuting:\n", t1)
print("置换后的大小:", t1.size())

示例 1

在以下 Python 程序中,输入张量的维度为 [3,2]。我们使用 dims = (1, 0) 用新的维度 [2,3] 置换张量。

# import the torch library
import torch

# create a tensor
t = torch.tensor([[1,2],[3,4],[5,6]])

# print the created tensor
print("Tensor:\n", t)
print("张量的大小:", t.size())

# perform permute operation
t1 = torch.permute(t,(1,0))

# print the permuted tensor
print("Tensor after Permuting:\n", t1)
print("置换后的大小:", t1.size())
输出结果
Tensor:
 tensor([[1, 2],
    [3, 4],
    [5, 6]])
张量的大小: torch.Size([3, 2])
Tensor after Permuting:
 tensor([[1, 3, 5],
   [2, 4, 6]])
置换后的大小: torch.Size([2, 3])

示例 2

在以下 Python 代码中,输入张量大小为 [2,3,1]。我们使用dims = (0,2,1)。它给出了维度为 [2,1,3] 的输入张量的视图。

# import torch library
import torch

# create a tensor
t = torch.randn(2,3,1)

# print the created tensor
print("Tensor:\n", t)
print("张量的大小:", t.size())

# perform permute
t1 = torch.permute(t, (0,2,1))

# print the resultant tensor
print("Tensor after Permuting:\n", t1)
print("置换后的大小:", t1.size())
输出结果
Tensor:
 tensor([[[ 1.5285],
    [-0.2401],
    [ 0.2378]],

    [[ 0.4733],
     [-1.7317],
     [ 0.7557]]])
张量的大小: torch.Size([2, 3, 1])
Tensor after Permuting:
 tensor([[[ 1.5285, -0.2401, 0.2378]],

    [[ 0.4733, -1.7317, 0.7557]]])
置换后的大小: torch.Size([2, 1, 3])