【Pytorch】テンソルの次元を入れ替え・変形する方法(reshape・transpose・permute)
Pytorchで定義したテンソルの次元を入れ替えたり変形する方法をまとめておく。
入れ替え・変形にはreshape・transpose・permuteを用いる。
元のテンソルとして以下を用いる。
0~5を要素とする1次元のものを定義。
a = torch.arange(6) # 等差数列を作成 (numpy.arange)
print(a)
print(a.size())
# 出力
# tensor([0, 1, 2, 3, 4, 5])
# torch.Size([6])
reshapeを使って変形する
reshapeを使った変形は以下のコードのようになる。
b = a.reshape(1, 2, 3)
print(b)
print(b.size())
# 出力
# tensor([[[0, 1, 2],
# [3, 4, 5]]])
# torch.Size([1, 2, 3])
transposeを使って次元を入れ替える
transposeを使った次元の入れ替えは以下のコードのようになる。
これはテンソルbのdim0とdim2の次元を入れ替える場合。
c = b.transpose(0, 2) # dim0とdim2の次元を入れ替える (numpy.transpose)
print(c)
print(c.size())
# 出力
# tensor([[[0],
# [3]],
#
# [[1],
# [4]],
#
# [[2],
# [5]]])
# torch.Size([3, 2, 1])
permuteを使って次元を入れ替える
permuteを使った次元の入れ替えは以下のコードのようになる。
c = b.permute(0, 2, 1) # 指定したdimの順に次元を入れ替える
print(c)
print(c.size())
# 出力
# tensor([[[0, 3],
# [1, 4],
# [2, 5]]])
# torch.Size([1, 3, 2])
以上がテンソルの次元を入れ替え・変形する方法。
人気記事
人気記事はこちら。