「PyTorch」の版間の差分

提供:ペチラボ書庫
ナビゲーションに移動 検索に移動
編集の要約なし
 
22行目: 22行目:
<syntaxhighlight inline>model.train()</syntaxhighlight>を実行しておく。
<syntaxhighlight inline>model.train()</syntaxhighlight>を実行しておく。
各イテレーションにおいて必要な処理は以下の通り。
各イテレーションにおいて必要な処理は以下の通り。
# <syntaxhighlight inline>torch.Optimizer.zero_grad()</syntaxhighlight>または <syntaxhighlight inline>nn.Module.zero_grad()</syntaxhighlight>[https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.zero_grad]で勾配を初期化しておく。
# モデルにデータを入力して出力を得る(順伝播)
# モデルにデータを入力して出力を得る(順伝播)
# [[損失関数]]を計算する。
# [[損失関数]]を計算する。
# 勾配を計算する前に<syntaxhighlight inline>torch.Optimizer.zero_grad()</syntaxhighlight>または <syntaxhighlight inline>nn.Module.zero_grad()</syntaxhighlight>[https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.zero_grad]で勾配を初期化しておく。
# 勾配を計算する。<syntaxhighlight inline>loss.backward()</syntaxhighlight>
# 勾配を計算する。<syntaxhighlight inline>loss.backward()</syntaxhighlight>
# 誤差逆伝播 <syntaxhighlight inline>optim.step()</syntaxhighlight>
# 誤差逆伝播 <syntaxhighlight inline>optim.step()</syntaxhighlight>

2023年2月10日 (金) 01:03時点における最新版

https://pytorch.org/

AIをつくろう

やりかた

データセットの準備

https://pytorch.org/docs/stable/data.html

  • Datasetを作成
    • 適宜、データ拡張
  • DataLoaderを作成

モデルの定義

モデルはtorch.nn.Moduleの派生クラスである必要がある。

torch.nn.Sequentialtorch.nn.ModuleList を使うこともできる。

学習

model.train()を実行しておく。 各イテレーションにおいて必要な処理は以下の通り。

  1. torch.Optimizer.zero_grad()または nn.Module.zero_grad()[1]で勾配を初期化しておく。
  2. モデルにデータを入力して出力を得る(順伝播)
  3. 損失関数を計算する。
  4. 勾配を計算する。loss.backward()
  5. 誤差逆伝播 optim.step()

テスト

model.eval()を実行しておく。 これにより学習時に有効になっていたdropoutやbatch normなどが無効化される。 また、with torch.no_grad():としておけば勾配の計算が省かれ、速くなる。