【Pytorch】テンソルの型を定義する方法(dtype)

スポンサーリンク

【Pytorch】テンソルの型を定義する方法(dtype)

Pytorchで定義したテンソルの型は、dtypeの引数を指定することで定義することができる。

テンソルの要素をfloat型にしたいときは以下のようにコードを書く。

a = torch.tensor([[1, 2], [3, 4]], dtype=torch.float)
print(a.dtype)
print(a)
print()

# 出力
# torch.float32
# tensor([[1., 2.],
#         [3., 4.]])

テンソルの要素をint型にしたいときは以下のようにコードを書く。

a = torch.ones((2, 3), dtype=torch.int)
print(a.dtype)
print(a)
print()

# 出力
# torch.int32
# tensor([[1, 1, 1],
#         [1, 1, 1]], dtype=torch.int32)

テンソルの要素をlong型にしたいときは以下のようにコードを書く。

a = torch.ones((2, 3), dtype=torch.long)
print(a.dtype)
print(a)
print()

# 出力
# torch.int64
# tensor([[1, 1, 1],
        [1, 1, 1]])

気記事

人気記事はこちら。

CUDA、cuDNNのバージョンをターミナルで調べるコマンド
【Pytorch】テンソルの次元を追加・削除する方法【dim】
【Pytorch】テンソルを連結する方法(cat・stack)
【Python】Tensorflowをダウングレード・アップグレードするコマンド
【Protobuf】"TypeError: Descriptors cannot not be created directly."を解決する【solved】

タイトルとURLをコピーしました