【Pytorch】テンソルの型を定義する方法(dtype)
Pytorchで定義したテンソルの型は、dtypeの引数を指定することで定義することができる。
テンソルの要素をfloat型にしたいときは以下のようにコードを書く。
a = torch.tensor([[1, 2], [3, 4]], dtype=torch.float)
print(a.dtype)
print(a)
print()
# 出力
# torch.float32
# tensor([[1., 2.],
# [3., 4.]])
テンソルの要素をint型にしたいときは以下のようにコードを書く。
a = torch.ones((2, 3), dtype=torch.int)
print(a.dtype)
print(a)
print()
# 出力
# torch.int32
# tensor([[1, 1, 1],
# [1, 1, 1]], dtype=torch.int32)
テンソルの要素をlong型にしたいときは以下のようにコードを書く。
a = torch.ones((2, 3), dtype=torch.long)
print(a.dtype)
print(a)
print()
# 出力
# torch.int64
# tensor([[1, 1, 1],
[1, 1, 1]])
人気記事
人気記事はこちら。