【Pytorch】テンソルを分割する方法(split・chunk)

スポンサーリンク

【Pytorch】テンソルを分割する方法(split・chunk)

Pytorchで宣言したテンソルの分割を行う方法をまとめる。

分割する前のテンソルを以下のコードで定義しておく。

a = torch.arange(10).reshape(2, 5)
print(a)

# 出力
# tensor([[0, 1, 2, 3, 4],
#         [5, 6, 7, 8, 9]])

指定したサイズで分割する場合(spllit)

2要素ごとに分割する場合は以下のようなコードになる。

割り切れない場合は最後の要素が余りになる。

numpyではnumpy.splitと同じ。

b = torch.split(a, 2, dim=1)  # 2要素ごとに分割(割り切れない場合は最後の要素が余りになる) (numpy.split)
print("torch.splitで2個ずつに分割:")
print(b)
print([c.size() for c in b])

# 出力
# torch.splitで2個ずつに分割:
# (tensor([[0, 1],
#         [5, 6]]), tensor([[2, 3],
#         [7, 8]]), tensor([[4],
#         [9]]))
# [torch.Size([2, 2]), torch.Size([2, 2]), torch.Size([2, 1])]

また、リストを引数に指定すると、その要素のサイズで分割する。

例えば[1, 3, 1]を指定すると1, 3, 1個ずつに分割する。

b = torch.split(a, [1, 3, 1], dim=1)  # リストを指定するとそのサイズで分割
print("torch.splitで1,3,1個ずつに分割:")
print(b)

# 出力
# torch.splitで1,3,1個ずつに分割:
# (tensor([[0],
#         [5]]), tensor([[1, 2, 3],
#         [6, 7, 8]]), tensor([[4],
#         [9]]))
# [torch.Size([2, 1]), torch.Size([2, 3]), torch.Size([2, 1])]
スポンサーリンク

指定した個数に分割する場合(chunk)

指定した個数にテンソルを分割するにはchunkを使う。

5グループに分割するには以下のようなコードになる。

b = torch.chunk(a, 5, dim=1) # 5グループに分割
print("chunkで5グループに分割:")
print(b)
print([c.size() for c in b])

# 出力
# chunkで5グループに分割:
# (tensor([[0],
#         [5]]), tensor([[1],
#         [6]]), tensor([[2],
#         [7]]), tensor([[3],
#         [8]]), tensor([[4],
#         [9]]))
# [torch.Size([2, 1]), torch.Size([2, 1]), torch.Size([2, 1]), torch.Size([2, 1]), torch.Size([2, 1])]

3グループに分割する場合は以下のようなコードになる。

割り切れない場合は最後の要素が小さくなる。

b = torch.chunk(a, 3, dim=1) # 3グループに分割(割り切れない場合は最後の要素が小さくなる)
print("chunkで3グループに分割:")
print(b)
print([c.size() for c in b])

# 出力
# chunkで3グループに分割:
# (tensor([[0, 1],
#         [5, 6]]), tensor([[2, 3],
#         [7, 8]]), tensor([[4],
#         [9]]))
# [torch.Size([2, 2]), torch.Size([2, 2]), torch.Size([2, 1])]

以上がテンソルを分割する方法の紹介。

人気記事

人気記事はこちら。

CUDA、cuDNNのバージョンをターミナルで調べるコマンド
【Pytorch】テンソルを連結する方法(cat・stack)
【Pytorch】テンソルの次元を追加・削除する方法【dim】
【Protobuf】"TypeError: Descriptors cannot not be created directly."を解決する【solved】
【Python】Tensorflowをダウングレード・アップグレードするコマンド
タイトルとURLをコピーしました