【Pytorch】テンソルを分割する方法(split・chunk)
Pytorchで宣言したテンソルの分割を行う方法をまとめる。
分割する前のテンソルを以下のコードで定義しておく。
a = torch.arange(10).reshape(2, 5)
print(a)
# 出力
# tensor([[0, 1, 2, 3, 4],
# [5, 6, 7, 8, 9]])
指定したサイズで分割する場合(spllit)
2要素ごとに分割する場合は以下のようなコードになる。
割り切れない場合は最後の要素が余りになる。
numpyではnumpy.splitと同じ。
b = torch.split(a, 2, dim=1) # 2要素ごとに分割(割り切れない場合は最後の要素が余りになる) (numpy.split)
print("torch.splitで2個ずつに分割:")
print(b)
print([c.size() for c in b])
# 出力
# torch.splitで2個ずつに分割:
# (tensor([[0, 1],
# [5, 6]]), tensor([[2, 3],
# [7, 8]]), tensor([[4],
# [9]]))
# [torch.Size([2, 2]), torch.Size([2, 2]), torch.Size([2, 1])]
また、リストを引数に指定すると、その要素のサイズで分割する。
例えば[1, 3, 1]を指定すると1, 3, 1個ずつに分割する。
b = torch.split(a, [1, 3, 1], dim=1) # リストを指定するとそのサイズで分割
print("torch.splitで1,3,1個ずつに分割:")
print(b)
# 出力
# torch.splitで1,3,1個ずつに分割:
# (tensor([[0],
# [5]]), tensor([[1, 2, 3],
# [6, 7, 8]]), tensor([[4],
# [9]]))
# [torch.Size([2, 1]), torch.Size([2, 3]), torch.Size([2, 1])]
指定した個数に分割する場合(chunk)
指定した個数にテンソルを分割するにはchunkを使う。
5グループに分割するには以下のようなコードになる。
b = torch.chunk(a, 5, dim=1) # 5グループに分割
print("chunkで5グループに分割:")
print(b)
print([c.size() for c in b])
# 出力
# chunkで5グループに分割:
# (tensor([[0],
# [5]]), tensor([[1],
# [6]]), tensor([[2],
# [7]]), tensor([[3],
# [8]]), tensor([[4],
# [9]]))
# [torch.Size([2, 1]), torch.Size([2, 1]), torch.Size([2, 1]), torch.Size([2, 1]), torch.Size([2, 1])]
3グループに分割する場合は以下のようなコードになる。
割り切れない場合は最後の要素が小さくなる。
b = torch.chunk(a, 3, dim=1) # 3グループに分割(割り切れない場合は最後の要素が小さくなる)
print("chunkで3グループに分割:")
print(b)
print([c.size() for c in b])
# 出力
# chunkで3グループに分割:
# (tensor([[0, 1],
# [5, 6]]), tensor([[2, 3],
# [7, 8]]), tensor([[4],
# [9]]))
# [torch.Size([2, 2]), torch.Size([2, 2]), torch.Size([2, 1])]
以上がテンソルを分割する方法の紹介。
人気記事
人気記事はこちら。