Pytorch学习之torch.split函数
一、简介
torch.split
用于将一个张量(tensor)沿指定维度(dim)拆分为多个子张量。这个函数对于处理需要按块拆分数据的任务非常有用,例如在自然语言处理和图像处理中的数据预处理。
二、语法和参数
torch.split(tensor, split_size_or_sections, dim=0)
参数:
tensor
:要拆分的张量。split_size_or_sections
:一个整数或者一个包含每个子张量大小的列表。- 当为【整数】时,表示每个子张量的大小,最后一个子张量可能会小于这个大小。
- 当为【列表】时,表示每个子张量的大小。
dim
:沿着哪个维度进行拆分,默认值为0。
三、实例
例1:按固定大小
拆分
import torch
# 创建一个示例张量
tensor = torch.arange(10)
print("原始张量:", tensor)
# 按大小为3拆分张量
split_tensors = torch.split(tensor, 3)
print("拆分后的张量:")
for i, t in enumerate(split_tensors):
print(f"子张量 {i+1}: {t}")
输出:
原始张量: tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
拆分后的张量:
子张量 1: tensor([0, 1, 2])
子张量 2: tensor([3, 4, 5])
子张量 3: tensor([6, 7, 8])
子张量 4: tensor([9])
例2:按指定大小列表
拆分
import torch
# 创建一个示例张量
tensor = torch.arange(10)
print("原始张量:", tensor)
# 按指定大小列表拆分张量
split_sizes = [2, 3, 5]
split_tensors = torch.split(tensor, split_sizes)
print("拆分后的张量:")
for i, t in enumerate(split_tensors):
print(f"子张量 {i+1}: {t}")
输出:
原始张量: tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
拆分后的张量:
子张量 1: tensor([0, 1])
子张量 2: tensor([2, 3, 4])
子张量 3: tensor([5, 6, 7, 8, 9])
例3:按不同维度
拆分
import torch
# 创建一个示例二维张量
tensor = torch.arange(12).reshape(3, 4)
print("原始张量:\n", tensor)
# 按大小为2沿维度1拆分张量
split_tensors = torch.split(tensor, 2, dim=1)
print("拆分后的张量:")
for i, t in enumerate(split_tensors):
print(f"子张量 {i+1}:\n{t}")
输出:
原始张量:
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
拆分后的张量:
子张量 1:
tensor([[0, 1],
[4, 5],
[8, 9]])
子张量 2:
tensor([[ 2, 3],
[ 6, 7],
[10, 11]])
例4:按不同维度
指定大小列表
拆分
import torch
# 创建一个示例二维张量
tensor = torch.arange(12).reshape(3, 4)
print("原始张量:\n", tensor)
# 按指定大小列表拆分张量
split_tensors = torch.split(tensor, [1, 2, 1], dim=1)
print("拆分后的张量:")
for i, t in enumerate(split_tensors):
print(f"子张量 {i+1}:\n{t}")
输出:
原始张量:
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
拆分后的张量:
子张量 1:
tensor([[0],
[4],
[8]])
子张量 2:
tensor([[ 1, 2],
[ 5, 6],
[ 9, 10]])
子张量 3:
tensor([[ 3],
[ 7],
[11]])
例5:高维张量拆分
import torch
# 创建一个示例三维张量
tensor = torch.arange(12).reshape(3, 2, 2)
print("原始张量:\n", tensor)
# # 按指定大小列表拆分高维张量
split_tensors = torch.split(tensor, [1, 1], dim=1)
print("拆分后的张量:")
for i, t in enumerate(split_tensors):
print(f"子张量 {i+1}:\n{t}")
输出:
原始张量:
tensor([[[ 0, 1],
[ 2, 3]],
[[ 4, 5],
[ 6, 7]],
[[ 8, 9],
[10, 11]]])
拆分后的张量:
子张量 1:
tensor([[[0, 1]],
[[4, 5]],
[[8, 9]]])
子张量 2:
tensor([[[ 2, 3]],
[[ 6, 7]],
[[10, 11]]])
四、注意事项
-
当
split_size_or_sections
为整数时,最后一个子张量的大小可能小于这个整数。 -
确保拆分的维度大小能够被
split_size_or_sections
整除,否则最后一个子张量会包含剩余的元素。 -
在高维张量上操作时,注意选择合适的维度进行拆分。
-
同时只能在一个维度上进行拆分,不能同时拆分多个维度,比如下面这样的语句就是不合规范的:
split_tensors = torch.split(tensor, [[1, 2, 1], [2, 1]], dim=[1, 0])