【Pytorch】Autogradでテンソルの自動微分をする方法

スポンサーリンク

【Pytorch】Autogradでテンソルの自動微分をする方法

Pytorchでは、テンソルの自動微分をサポートするautogradの機能が提供されている。

テンソルは以下のような属性を持っている。

https://www.slideshare.net/yuyasoneoka/pytorch-80883065

autogradを用いるには、まず計算グラフを構築して、その後forward関数によって入力のテンソルから出力のテンソルに対する順伝播の計算を行う。

その後、backward関数(逆伝播)を呼ぶことにより、requires_grad=Trueを指定したすべてのテンソルの目的関数に関する勾配を計算する。

順伝播の計算

順伝播の計算を以下のコードで行う。

# 順伝播の計算
x = torch.randn(4, 4)
y = torch.randn(4, 1)

w = torch.randn(4, 1, requires_grad=True)
b = torch.randn(1, requires_grad=True)

y_pred = torch.matmul(x, w) + b

# 目的関数の定義
loss = (y_pred - y).pow(2).sum()

この状態だと、まだ勾配は計算されていない状態。

以下のコードで勾配を調べてもNoneが出力される。

# まだ勾配は計算されていない
print(x.grad)
print(y.grad)
print(w.grad)
print(b.grad)

# 出力
# None
# None
# None
# None

逆伝播の計算

backward関数を実行した後は、requires_grad=Trueを指定した変数については勾配が計算されている。

# 逆伝播
loss.backward()

以下で勾配を確認する。

# requires_grad=Trueを指定した変数は勾配が計算されている
print(x.grad)
print(y.grad)
print(w.grad)
print(b.grad)

# 出力
# None
# None
# tensor([[ -4.4812],
#         [-18.1819],
#         [ 29.5095],
#         [ 17.7008]])
# tensor([-13.6415])

テンソルの勾配計算を行わないようにする(detach()・torch.no_grad())

.detach()を使うことにより、テンソルの勾配計算を行わないようにすることができる。

x = torch.randn(4, 4)
y = torch.randn(4, 1)

w = torch.randn(4, 1, requires_grad=True)
b = torch.randn(1, requires_grad=True)
b = b.detach()  # bの勾配計算を停止

y_pred = torch.matmul(x, w) + b

loss = (y_pred - y).pow(2).sum()

loss.backward()

print(w.grad)  # 勾配を有する
print(b.grad)  # 勾配を有さない

# 出力
# tensor([[  6.7204],
#         [-10.1176],
#         [ 12.7670],
#         [  8.8713]])
# None

テンソルbについては勾配を持っていないことが確認できる。

また、with torch.no_grad()でくくることで、その中で定義したテンソルの勾配計算をまとめて停止させることが可能。

これは、学習済みのモデルを評価する際に、モデルがrequires_grad=Trueとなっているパラメータを有する場合でも、勾配計算を行わないようにしたいときなどに有用。

以下のコードのように書く。

with torch.no_grad():
    y_eval = torch.matmul(x, w) + b  # y_predと同様の計算を行う

print('requires_grad of y_pred:', y_pred.requires_grad)  # requires_grad=True
print('requires_grad of y_eval:', y_eval.requires_grad)  # requires_grad=False

# 出力
# requires_grad of y_pred: True
# requires_grad of y_eval: False

人気記事

人気記事はこちら。

CUDA、cuDNNのバージョンをターミナルで調べるコマンド
【Pytorch】テンソルを連結する方法(cat・stack)
【Pytorch】テンソルの次元を追加・削除する方法【dim】
【Protobuf】"TypeError: Descriptors cannot not be created directly."を解決する【solved】
【Python】Tensorflowをダウングレード・アップグレードするコマンド

タイトルとURLをコピーしました