【Pytorch】テンソルを連結する方法(cat・stack)

スポンサーリンク

【Pytorch】テンソルを連結する方法(cat・stack)

Pytorchのテンソルを連結する方法をまとめておく。

連結方法としてはcatstackがある。

既存のdimに沿ってテンソルを連結する(cat)

torch.catは既存のdimに沿ってテンソルを連結することができる。

以下のコードのように使う。

a = torch.arange(6).reshape(2, 3)
b = torch.arange(8).reshape(2, 4)
print(a)
print(a.size())
print()
print(b)
print(b.size())
print()

c = torch.cat([a, b], dim=1) # 既存のdimで連結する. 他のdimのsizeは揃っている必要がある (numpy.concatenate)
print("torch.catでdim1方向に連結:")
print(c)
print(c.size())

# 出力
# tensor([[0, 1, 2],
#         [3, 4, 5]])
# torch.Size([2, 3])
# 
# tensor([[0, 1, 2, 3],
#         [4, 5, 6, 7]])
# torch.Size([2, 4])
# 
# torch.catでdim1方向に連結:
# tensor([[0, 1, 2, 0, 1, 2, 3],
#         [3, 4, 5, 4, 5, 6, 7]])
# torch.Size([2, 7])

新しいdimを作成しそれにそってテンソルを連結する(stack)

torch.stackは新しいdimを作成し、そのdimに沿ってテンソルを連結することができる。

以下のコードのように使う。

a = torch.arange(6).reshape(2, 3)
b = torch.arange(6).reshape(2, 3)
print(a)
print(a.size())
print()

c = torch.stack([a, b], dim=2) # 新しいdimで連結する. 既存のdimのsizeはすべて揃っている必要がある (numpy.stack)
print("torch.stackで新しい次元(dim2)方向に連結:")
print(c)
print(c.size())

# 出力
# torch.stackで新しい次元(dim2)方向に連結:
# tensor([[[0, 0],
#          [1, 1],
#          [2, 2]],
# 
#         [[3, 3],
#          [4, 4],
#          [5, 5]]])
# torch.Size([2, 3, 2])

以上がPytorchのテンソルを連結する方法のまとめ。

人気記事

人気記事はこちら。

CUDA、cuDNNのバージョンをターミナルで調べるコマンド
【Pytorch】テンソルを連結する方法(cat・stack)
【Pytorch】テンソルの次元を追加・削除する方法【dim】
【Protobuf】"TypeError: Descriptors cannot not be created directly."を解決する【solved】
【Python】Tensorflowをダウングレード・アップグレードするコマンド
タイトルとURLをコピーしました