【Pytorch】テンソルの次元を入れ替え・変形する方法(reshape・transpose・permute)

スポンサーリンク

【Pytorch】テンソルの次元を入れ替え・変形する方法(reshape・transpose・permute)

Pytorchで定義したテンソルの次元を入れ替えたり変形する方法をまとめておく。

入れ替え・変形にはreshape・transpose・permuteを用いる。

元のテンソルとして以下を用いる。

0~5を要素とする1次元のものを定義。

a = torch.arange(6)  # 等差数列を作成 (numpy.arange)
print(a)
print(a.size())

# 出力
# tensor([0, 1, 2, 3, 4, 5])
# torch.Size([6])

reshapeを使って変形する

reshapeを使った変形は以下のコードのようになる。

b = a.reshape(1, 2, 3)
print(b) 
print(b.size())

# 出力
# tensor([[[0, 1, 2],
#          [3, 4, 5]]])
# torch.Size([1, 2, 3])

transposeを使って次元を入れ替える

transposeを使った次元の入れ替えは以下のコードのようになる。

これはテンソルbのdim0とdim2の次元を入れ替える場合。

c = b.transpose(0, 2)  # dim0とdim2の次元を入れ替える (numpy.transpose)
print(c) 
print(c.size())

# 出力
# tensor([[[0],
#          [3]],
# 
#         [[1],
#          [4]],
# 
#         [[2],
#          [5]]])
# torch.Size([3, 2, 1])

permuteを使って次元を入れ替える

permuteを使った次元の入れ替えは以下のコードのようになる。

c = b.permute(0, 2, 1)  # 指定したdimの順に次元を入れ替える
print(c) 
print(c.size())

# 出力
# tensor([[[0, 3],
#          [1, 4],
#          [2, 5]]])
# torch.Size([1, 3, 2])

以上がテンソルの次元を入れ替え・変形する方法。

人気記事

人気記事はこちら。

CUDA、cuDNNのバージョンをターミナルで調べるコマンド
【Pytorch】テンソルの次元を追加・削除する方法【dim】
【Pytorch】テンソルを連結する方法(cat・stack)
【Protobuf】"TypeError: Descriptors cannot not be created directly."を解決する【solved】
【Python】Tensorflowをダウングレード・アップグレードするコマンド

タイトルとURLをコピーしました