如何缩小 PyTorch 中的张量?

torch.narrow()方法用于对 PyTorch 张量执行窄操作。它返回一个新的张量,它是原始输入张量的缩小版本。

例如,[4, 3] 的张量可以缩小为 [2, 3] 或 [4, 2] 大小的张量。我们可以一次缩小一个维度上的张量。在这里,我们不能将两个维度都缩小到 [2, 2] 的大小。我们也可以用来缩小张量的范围。Tensor.narrow()

语法

torch.narrow(input, dim, start, length)
Tensor.narrow(dim, start, length)

参数

  • 输入——它是要缩小的 PyTorch 张量。

  • dim – 这是我们必须缩小原始张量输入的维度。

  • 开始- 开始维度。

  • 长度– 从起始尺寸到结束尺寸的长度。

步骤

  • 导入火炬库。确保您已经安装了它。

import torch

  • 创建一个 PyTorch 张量并打印张量及其大小。

t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print("Tensor:\n", t)
print("张量的大小:", t.size()) # size 3x3

  • 计算并将值分配给变量。torch.narrow(input, dim, start, length)

t1 = torch.narrow(t, 0, 1, 2)

  • 缩小后打印结果张量及其大小。

print("Tensor after Narrowing:\n", t2)
print("缩小后的尺寸:", t2.size())

示例 1

在以下 Python 代码中,输入张量大小为 [3, 3]。我们使用dim = 0, start = 1 和length = 2 沿维度0缩小张量。它返回一个维度为[2, 3]的新张量。

请注意,新张量沿维度 0 变窄,沿维度 0 的长度更改为 2。

# import the library
import torch

# create a tensor
t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# print the created tensor
print("Tensor:\n", t)
print("张量的大小:", t.size())

# Narrow-down the tensor in dimension 0
t1 = torch.narrow(t, 0, 1, 2)
print("Tensor after Narrowing:\n", t1)
print("缩小后的尺寸:", t1.size())

# Narrow down the tensor in dimension 1
t2 = torch.narrow(t, 1, 1, 2)
print("Tensor after Narrowing:\n", t2)
print("缩小后的尺寸:", t2.size())
输出结果
Tensor:
 tensor([[1, 2, 3],
    [4, 5, 6],
    [7, 8, 9]])
张量的大小: torch.Size([3, 3])
Tensor after Narrowing:
 tensor([[4, 5, 6],
    [7, 8, 9]])
缩小后的尺寸: torch.Size([2, 3])
Tensor after Narrowing:
 tensor([[2, 3],
    [5, 6],
    [8, 9]])
缩小后的尺寸: torch.Size([3, 2])

示例 2

以下程序显示了如何使用.Tensor.narrow()

# import required library
import torch

# create a tensor
t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
# print the above created tensor
print("Tensor:\n", t)
print("张量的大小:", t.size())

# Narrow-down the tensor in dimension 0
t1 = t.narrow(0, 1, 2)
print("Tensor after Narrowing:\n", t1)
print("缩小后的尺寸:", t1.size())

# Narrow down the tensor in dimension 1
t2 = t.narrow(1, 0, 2)
print("Tensor after Narrowing:\n", t2)
print("缩小后的尺寸:", t2.size())
输出结果
Tensor:
 tensor([[ 1, 2, 3],
    [ 4, 5, 6],
    [ 7, 8, 9],
    [10, 11, 12]])
张量的大小: torch.Size([4, 3])
Tensor after Narrowing:
 tensor([[4, 5, 6],
    [7, 8, 9]])
缩小后的尺寸: torch.Size([2, 3])
Tensor after Narrowing:
 tensor([[ 1, 2],
    [ 4, 5],
    [ 7, 8],
    [10, 11]])
缩小后的尺寸: torch.Size([4, 2])