【Pytorch】テンソルの次元を追加・削除する方法【dim】
Pytorchでテンソル(Tensor)の次元を追加したり削除したりする方法をコード例でまとめておく。
元のテンソルとして以下を用いる。
0~5を要素とする1次元のものを定義。
a = torch.arange(6) # 等差数列を作成 (numpy.arange)
print(a)
print(a.size())
# 出力
# tensor([0, 1, 2, 3, 4, 5])
# torch.Size([6])
変形(reshape)
テンソルを別次元に変形するにはreshapeを用いる。
これはnumpyと同じ方法。
以下のコードで変形する。
b = a.reshape(2, 3) # 変形 (numpy.reshape)
print("reshape(2,3):")
print(b)
print(b.size())
# 出力
# reshape(2,3):
# tensor([[0, 1, 2],
# [3, 4, 5]])
# torch.Size([2, 3])
次元の追加(view)
次元を追加するにはviewを用いる。
以下のコードで追加。
b = a.view(1, 2, 3, 1) # dimの追加も可能
print("view(1,2,3,1):")
print(b)
print(b.size())
# 出力
# view(1,2,3,1):
# tensor([[[[0],
# [1],
# [2]],
#
# [[3],
# [4],
# [5]]]])
# torch.Size([1, 2, 3, 1])
a.viewでエラーが出る場合は以下のように.contiguous()を読んでからview()すると良いようだ。(参照:https://discuss.pytorch.org/t/runtimeerror-input-is-not-contiguous/930)
print(a.contiguous().view(1, 2, 3, 1))
# 出力
# tensor([[[[0],
# [1],
# [2]],
#
# [[3],
# [4],
# [5]]]])
また、最後のdimのsizeを2にして他のdimをつぶすには、以下のようにview(-1,2)というようにすればよい。
b = a.view(-1, 2) # 最後のdimのsizeを2に, 他のdimは"潰す"
print("view(-1,2)で最後のdimのsizeを2に, 他のdimは潰す:")
print(b)
print(b.size())
# 出力
# view(-1,2)で最後のdimのsizeを2に, 他のdimは潰す:
# tensor([[0, 1],
# [2, 3],
# [4, 5]])
# torch.Size([3, 2])
新しい次元を追加(unsqueeze)
新しい次元を追加するにはunsqueezeを以下のように用いる。
b = a.unsqueeze(dim=1) # 新しいdimを追加する。 a[:, None]でもOK (numpy.expand_dims)
print("unsqueeze(dim=1)で次元をdim1に追加:")
print(b)
print(b.size())
# 出力
# unsqueeze(dim=1)で次元をdim1に追加:
# tensor([[0],
# [1],
# [2],
# [3],
# [4],
# [5]])
# torch.Size([6, 1])
サイズが1の次元を削除
↑で次元を1つ追加したテンソルに対して、サイズ1の次元を削除して元のテンソルaに戻すにはsqueezeを以下のように用いる。
c = b.squeeze() # squeezeを用いるとsizeが1の次元が除去される (numpy.squeeze)
print("squeeze()でサイズ1の次元を削除:")
print(c)
print(c.size())
# 出力
# squeeze()でサイズ1の次元を削除:
# tensor([0, 1, 2, 3, 4, 5])
# torch.Size([6])
以上が、Pytorchでテンソル(Tensor)の次元を追加・削除する方法の紹介。
人気記事
人気記事はこちら。