【Pytorch】テンソルの次元を追加・削除する方法【dim】

スポンサーリンク

【Pytorch】テンソルの次元を追加・削除する方法【dim】

Pytorchでテンソル(Tensor)の次元を追加したり削除したりする方法をコード例でまとめておく。

元のテンソルとして以下を用いる。

0~5を要素とする1次元のものを定義。

a = torch.arange(6)  # 等差数列を作成 (numpy.arange)
print(a)
print(a.size())

# 出力
# tensor([0, 1, 2, 3, 4, 5])
# torch.Size([6])

変形(reshape)

テンソルを別次元に変形するにはreshapeを用いる。

これはnumpyと同じ方法。

以下のコードで変形する。

b = a.reshape(2, 3)  # 変形 (numpy.reshape)
print("reshape(2,3):")
print(b)
print(b.size())

# 出力
# reshape(2,3):
# tensor([[0, 1, 2],
#        [3, 4, 5]])
# torch.Size([2, 3])

次元の追加(view)

次元を追加するにはviewを用いる。

以下のコードで追加。

b = a.view(1, 2, 3, 1) # dimの追加も可能
print("view(1,2,3,1):")
print(b) 
print(b.size())

# 出力
# view(1,2,3,1):
# tensor([[[[0],
#           [1],
#           [2]],
# 
#          [[3],
#           [4],
#           [5]]]])
# torch.Size([1, 2, 3, 1])

a.viewでエラーが出る場合は以下のように.contiguous()を読んでからview()すると良いようだ。(参照:https://discuss.pytorch.org/t/runtimeerror-input-is-not-contiguous/930

print(a.contiguous().view(1, 2, 3, 1))

# 出力
# tensor([[[[0],
#           [1],
#           [2]],
#
#          [[3],
#           [4],
#           [5]]]])

また、最後のdimのsizeを2にして他のdimをつぶすには、以下のようにview(-1,2)というようにすればよい。

b = a.view(-1, 2) # 最後のdimのsizeを2に, 他のdimは"潰す"
print("view(-1,2)で最後のdimのsizeを2に, 他のdimは潰す:")
print(b) 
print(b.size())

# 出力
# view(-1,2)で最後のdimのsizeを2に, 他のdimは潰す:
# tensor([[0, 1],
#         [2, 3],
#         [4, 5]])
# torch.Size([3, 2])
スポンサーリンク

新しい次元を追加(unsqueeze)

新しい次元を追加するにはunsqueezeを以下のように用いる。

b = a.unsqueeze(dim=1)  # 新しいdimを追加する。 a[:, None]でもOK (numpy.expand_dims)
print("unsqueeze(dim=1)で次元をdim1に追加:")
print(b)
print(b.size())

# 出力
# unsqueeze(dim=1)で次元をdim1に追加:
# tensor([[0],
#         [1],
#         [2],
#         [3],
#         [4],
#         [5]])
# torch.Size([6, 1])

サイズが1の次元を削除

↑で次元を1つ追加したテンソルに対して、サイズ1の次元を削除して元のテンソルaに戻すにはsqueezeを以下のように用いる。

c = b.squeeze()  # squeezeを用いるとsizeが1の次元が除去される (numpy.squeeze)
print("squeeze()でサイズ1の次元を削除:")
print(c)
print(c.size())

# 出力
# squeeze()でサイズ1の次元を削除:
# tensor([0, 1, 2, 3, 4, 5])
# torch.Size([6])

以上が、Pytorchでテンソル(Tensor)の次元を追加・削除する方法の紹介。

人気記事

人気記事はこちら。

CUDA、cuDNNのバージョンをターミナルで調べるコマンド
【Pytorch】テンソルを連結する方法(cat・stack)
【Pytorch】テンソルの次元を追加・削除する方法【dim】
【Protobuf】"TypeError: Descriptors cannot not be created directly."を解決する【solved】
【Python】Tensorflowをダウングレード・アップグレードするコマンド

タイトルとURLをコピーしました