【Pytorch】テンソルを連結する方法(cat・stack)
Pytorchのテンソルを連結する方法をまとめておく。
連結方法としてはcatとstackがある。
既存のdimに沿ってテンソルを連結する(cat)
torch.catは既存のdimに沿ってテンソルを連結することができる。
以下のコードのように使う。
a = torch.arange(6).reshape(2, 3)
b = torch.arange(8).reshape(2, 4)
print(a)
print(a.size())
print()
print(b)
print(b.size())
print()
c = torch.cat([a, b], dim=1) # 既存のdimで連結する. 他のdimのsizeは揃っている必要がある (numpy.concatenate)
print("torch.catでdim1方向に連結:")
print(c)
print(c.size())
# 出力
# tensor([[0, 1, 2],
# [3, 4, 5]])
# torch.Size([2, 3])
#
# tensor([[0, 1, 2, 3],
# [4, 5, 6, 7]])
# torch.Size([2, 4])
#
# torch.catでdim1方向に連結:
# tensor([[0, 1, 2, 0, 1, 2, 3],
# [3, 4, 5, 4, 5, 6, 7]])
# torch.Size([2, 7])
新しいdimを作成しそれにそってテンソルを連結する(stack)
torch.stackは新しいdimを作成し、そのdimに沿ってテンソルを連結することができる。
以下のコードのように使う。
a = torch.arange(6).reshape(2, 3)
b = torch.arange(6).reshape(2, 3)
print(a)
print(a.size())
print()
c = torch.stack([a, b], dim=2) # 新しいdimで連結する. 既存のdimのsizeはすべて揃っている必要がある (numpy.stack)
print("torch.stackで新しい次元(dim2)方向に連結:")
print(c)
print(c.size())
# 出力
# torch.stackで新しい次元(dim2)方向に連結:
# tensor([[[0, 0],
# [1, 1],
# [2, 2]],
#
# [[3, 3],
# [4, 4],
# [5, 5]]])
# torch.Size([2, 3, 2])
以上がPytorchのテンソルを連結する方法のまとめ。
人気記事
人気記事はこちら。